Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup losses tutorial #251

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
74 changes: 53 additions & 21 deletions tutorials/classy_loss.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,28 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Creating a Classy Loss"
"# Creating a custom loss"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Loss functions are crucial because they define the objective to optimize for during training. Classy Vision can work directly with loss functions defined in [PyTorch](https://pytorch.org/docs/stable/_modules/torch/nn/modules/loss.html) without the need for any wrapper classes, but during research it's common to create custom losses with hyperparameters. Using `ClassyLoss` you can expose these hyperparameters via a configuration file.\n",
"\n",
"This tutorial will demonstrate: (1) how to create a custom loss within Classy Vision; (2) how to integrate your loss with Classy Vision's configuration system; (3) how to use a ClassyLoss independently, without other classy vision abstractions.\n",
"\n",
"## 1. Defining a loss\n",
"\n",
"Creating a new loss in Classy Vision is as simple as adding a new loss within PyTorch. The loss has to derive from `ClassyLoss` (which inherits from [`torch.nn.Module`](https://pytorch.org/docs/stable/nn.html#module)), and implement a `forward` method.\n",
"\n",
"**Note**: The forward method should take the right arguments depending on the task the loss will be used for. For instance, a `ClassificationTask` passes the `output` and `target` to `forward`."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from classy_vision.losses import ClassyLoss\n",
Expand All @@ -39,15 +43,13 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can start using this loss for training."
"Now we can start using this loss for training. Take a look at our [Getting started](https://classyvision.ai/tutorials/getting_started) tutorial for more details on how to train a model from a Jupyter notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from classy_vision.tasks import ClassificationTask\n",
Expand All @@ -60,14 +62,16 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Integrate it with the configuration system\n",
"\n",
"To be able to use the registration mechanism to be able to pick up the loss from a configuration, we need to do two additional things -\n",
"- Implement a `from_config` method\n",
"- Add the `register_loss` decorator to `MyLoss`"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -86,7 +90,7 @@
" return cls(alpha=config[\"alpha\"])\n",
" \n",
" def forward(self, output, target):\n",
" return (output - target).pow(2) * self.alpha"
" return (output - target).pow(2).sum() * self.alpha"
]
},
{
Expand All @@ -98,20 +102,48 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(3.5601)\n"
]
}
],
"source": [
"from classy_vision.losses import build_loss\n",
"import torch\n",
"\n",
"loss_config = {\n",
" \"name\": \"my_loss\",\n",
" \"alpha\": 5\n",
"}\n",
"my_loss = build_loss(loss_config)\n",
"assert isinstance(my_loss, MyLoss)"
"assert isinstance(my_loss, MyLoss)\n",
"\n",
"# ClassyLoss inherits from torch.nn.Module, so it works as expected\n",
"with torch.no_grad():\n",
" y_hat, target = torch.rand((1, 10)), torch.rand((1, 10))\n",
" print(my_loss(y_hat, target))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that your loss is integrated with the configuration system, you can train it using `classy_train.py` as described in the [Getting started](https://classyvision.ai/tutorials/getting_started) tutorial, no further changes are needed! Just make sure the code defining your model is in the `losses` folder of your classy project."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Conclusion\n",
"\n",
"In this tutorial, we learned how to make your loss compatible with Classy Vision and how to integrate it with the configuration system. Refer to our documentation to learn more about [ClassyLoss](https://classyvision.ai/api/losses.html)."
]
}
],
Expand All @@ -125,9 +157,9 @@
"bento/extensions/theme/main.css": true
},
"kernelspec": {
"display_name": "Classy Vision",
"display_name": "Python 3",
"language": "python",
"name": "bento_kernel_classy_vision"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -139,7 +171,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5+"
"version": "3.6.7"
}
},
"nbformat": 4,
Expand Down