Skip to content

Commit

Permalink
Cleanup losses tutorial (#251)
Browse files Browse the repository at this point in the history
Summary:
Change the losses tutorial to be more structured -- following Jessica's
feedback. I think this looks much cleaner.
Pull Request resolved: #251

Differential Revision: D18697317

Pulled By: vreis

fbshipit-source-id: 84df08436ce9328f5db4a389d399f4f023ebc0b6
  • Loading branch information
vreis authored and facebook-github-bot committed Nov 26, 2019
1 parent f826711 commit 3c25363
Showing 1 changed file with 53 additions and 21 deletions.
74 changes: 53 additions & 21 deletions tutorials/classy_loss.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,28 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Creating a Classy Loss"
"# Creating a custom loss"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Loss functions are crucial because they define the objective to optimize for during training. Classy Vision can work directly with loss functions defined in [PyTorch](https://pytorch.org/docs/stable/_modules/torch/nn/modules/loss.html) without the need for any wrapper classes, but during research it's common to create custom losses with hyperparameters. Using `ClassyLoss` you can expose these hyperparameters via a configuration file.\n",
"\n",
"This tutorial will demonstrate: (1) how to create a custom loss within Classy Vision; (2) how to integrate your loss with Classy Vision's configuration system; (3) how to use a ClassyLoss independently, without other classy vision abstractions.\n",
"\n",
"## 1. Defining a loss\n",
"\n",
"Creating a new loss in Classy Vision is as simple as adding a new loss within PyTorch. The loss has to derive from `ClassyLoss` (which inherits from [`torch.nn.Module`](https://pytorch.org/docs/stable/nn.html#module)), and implement a `forward` method.\n",
"\n",
"**Note**: The forward method should take the right arguments depending on the task the loss will be used for. For instance, a `ClassificationTask` passes the `output` and `target` to `forward`."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from classy_vision.losses import ClassyLoss\n",
Expand All @@ -39,15 +43,13 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can start using this loss for training."
"Now we can start using this loss for training. Take a look at our [Getting started](https://classyvision.ai/tutorials/getting_started) tutorial for more details on how to train a model from a Jupyter notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from classy_vision.tasks import ClassificationTask\n",
Expand All @@ -60,14 +62,16 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Integrate it with the configuration system\n",
"\n",
"To be able to use the registration mechanism to be able to pick up the loss from a configuration, we need to do two additional things -\n",
"- Implement a `from_config` method\n",
"- Add the `register_loss` decorator to `MyLoss`"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -86,7 +90,7 @@
" return cls(alpha=config[\"alpha\"])\n",
" \n",
" def forward(self, output, target):\n",
" return (output - target).pow(2) * self.alpha"
" return (output - target).pow(2).sum() * self.alpha"
]
},
{
Expand All @@ -98,20 +102,48 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(3.5601)\n"
]
}
],
"source": [
"from classy_vision.losses import build_loss\n",
"import torch\n",
"\n",
"loss_config = {\n",
" \"name\": \"my_loss\",\n",
" \"alpha\": 5\n",
"}\n",
"my_loss = build_loss(loss_config)\n",
"assert isinstance(my_loss, MyLoss)"
"assert isinstance(my_loss, MyLoss)\n",
"\n",
"# ClassyLoss inherits from torch.nn.Module, so it works as expected\n",
"with torch.no_grad():\n",
" y_hat, target = torch.rand((1, 10)), torch.rand((1, 10))\n",
" print(my_loss(y_hat, target))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that your loss is integrated with the configuration system, you can train it using `classy_train.py` as described in the [Getting started](https://classyvision.ai/tutorials/getting_started) tutorial, no further changes are needed! Just make sure the code defining your model is in the `losses` folder of your classy project."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Conclusion\n",
"\n",
"In this tutorial, we learned how to make your loss compatible with Classy Vision and how to integrate it with the configuration system. Refer to our documentation to learn more about [ClassyLoss](https://classyvision.ai/api/losses.html)."
]
}
],
Expand All @@ -125,9 +157,9 @@
"bento/extensions/theme/main.css": true
},
"kernelspec": {
"display_name": "Classy Vision",
"display_name": "Python 3",
"language": "python",
"name": "bento_kernel_classy_vision"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -139,7 +171,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5+"
"version": "3.6.7"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 3c25363

Please sign in to comment.