Skip to content

Commit

Permalink
Classy model tutorial (#215) (#215)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #215

Added a tutorial for creating a model. This is a very basic tutorial which skips the following concepts-
- Attachable blocks
- Heads
- Optimizer params

The primary reason is that the API for these isn't the best and we will probably change it in the next iteration
Pull Request resolved: #248

Test Plan: {F222475868}

Differential Revision: D18511058

Pulled By: vreis

fbshipit-source-id: a3dfe79d0a2b6d0fa19b9a3ff78a232fc41dae74
  • Loading branch information
vreis authored and facebook-github-bot committed Nov 23, 2019
1 parent 1f99989 commit 888ce39
Show file tree
Hide file tree
Showing 2 changed files with 206 additions and 3 deletions.
203 changes: 203 additions & 0 deletions tutorials/classy_model.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Creating a custom model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This tutorial will demonstrate: (1) how to create a custom model within Classy Vision; (2) how to integrate your model with Classy Vision's configuration system; (3) how to use the model for training and inference;\n",
"\n",
"## 1. Defining a model\n",
"\n",
"Creating a new model in Classy Vision is the simple as creating one within PyTorch. The model needs to derive from `ClassyModel` and implement a `forward` method to perform inference. `ClassyModel` inherits from [`torch.nn.Module`](https://pytorch.org/docs/stable/nn.html#module), so it works exactly as you would expect."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"\n",
"from classy_vision.models import ClassyModel\n",
"\n",
"\n",
"class MyModel(ClassyModel):\n",
" def __init__(self, num_classes):\n",
" super().__init__()\n",
" \n",
" # Average all the pixels, generate one output per class\n",
" self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n",
" num_channels = 3\n",
" self.fc = nn.Linear(num_channels, num_classes)\n",
" \n",
" def forward(self, x):\n",
" # perform average pooling\n",
" out = self.avgpool(x)\n",
"\n",
" # reshape the output and apply the fc layer\n",
" out = out.reshape(out.size(0), -1)\n",
" out = self.fc(out)\n",
" return out"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can start using this model for training. Take a look at our [Getting started](https://classyvision.ai/tutorials/getting_started) tutorial for more details on how to train a model from a Jupyter notebook."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from classy_vision.tasks import ClassificationTask\n",
"\n",
"my_model = MyModel(num_classes=1000)\n",
"my_task = ClassificationTask().set_model(my_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Integrate it with the configuration system\n",
"\n",
"Classy Vision is also able to read a configuration file and instantiate the model. This is useful to keep your experiments organized and reproducible. For that, you have to:\n",
"\n",
"- Implement a `from_config` method\n",
"- Add the `register_model` decorator to `MyModel`"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"\n",
"from classy_vision.models import ClassyModel, register_model\n",
"\n",
"\n",
"@register_model(\"my_model\")\n",
"class MyModel(ClassyModel):\n",
" def __init__(self, num_classes):\n",
" super().__init__()\n",
" \n",
" # Average all the pixels, generate one output per class\n",
" self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n",
" num_channels = 3\n",
" self.fc = nn.Linear(num_channels, num_classes)\n",
"\n",
" @classmethod\n",
" def from_config(cls, config):\n",
" # This method takes a configuration dictionary \n",
" # and returns an instance of the class. In this case, \n",
" # we'll let the number of classes be configurable.\n",
" return cls(num_classes=config[\"num_classes\"])\n",
" \n",
" def forward(self, x):\n",
" # perform average pooling\n",
" out = self.avgpool(x)\n",
"\n",
" # reshape the output and apply the fc layer\n",
" out = out.reshape(out.size(0), -1)\n",
" out = self.fc(out)\n",
" return out"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can start using this model in our configurations. The argument passed to `register_model` is used to identify the model class in the configuration:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[-0.1141, -0.4042, 0.2112]])\n"
]
}
],
"source": [
"from classy_vision.models import build_model\n",
"import torch\n",
"\n",
"model_config = {\n",
" \"name\": \"my_model\",\n",
" \"num_classes\": 3\n",
"}\n",
"my_model = build_model(model_config)\n",
"assert isinstance(my_model, MyModel)\n",
"\n",
"# my_model inherits from torch.nn.Module, so inference works as usual:\n",
"x = torch.rand((1, 3, 200, 200))\n",
"with torch.no_grad():\n",
" print(my_model(x))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that your model is integrated with the configuration system, you can train it using `classy_train.py` as described in the [Getting started](https://classyvision.ai/tutorials/getting_started) tutorial, no further changes are needed! Just make sure the code defining your model is in the `models` folder of your classy project."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Conclusion\n",
"\n",
"In this tutorial, we learned how to make your model compatible with Classy Vision and how to integrate it with the configuration system. Refer to our documentation to learn more about [ClassyModel](https://classyvision.ai/api/models.html)."
]
}
],
"metadata": {
"bento_stylesheets": {
"bento/extensions/flow/main.css": true,
"bento/extensions/kernel_selector/main.css": true,
"bento/extensions/kernel_ui/main.css": true,
"bento/extensions/new_kernel/main.css": true,
"bento/extensions/system_usage/main.css": true,
"bento/extensions/theme/main.css": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
6 changes: 3 additions & 3 deletions website/tutorials.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
{
"id": "getting_started",
"title": "Getting started"
},{
"id": "classy_model",
"title": "Creating a custom model"
},{
"id": "classy_loss",
"title": "Creating a custom loss"
Expand All @@ -12,9 +15,6 @@
},{
"id": "fine_tuning",
"title": "Fine Tuning"
},{
"id": "wsl_model_predict",
"title": "WSL Model Predict"
}
]
}

0 comments on commit 888ce39

Please sign in to comment.