### Creating a Classy Model

Creating a new model in Classy Vision is as simple as creating one within PyTorch. The model needs to derive from `ClassyModel` (which inherits from [`torch.nn.Module`](https://pytorch.org/docs/stable/nn.html#module)), call `super().__init__()` with the appropriate arguments, and implement a `forward` method.

In this tutorial, we will focus on creating an image model, so it should expect an input tensor of shape `(N, C, H, W)`, where `N` is the batch size, `C` is the number of channels, and `H` and `W` are the height and width of the image, respectively.
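To make the `(N, C, H, W)` layout concrete, here is a dependency-free sketch that uses plain Python lists in place of a tensor. It averages every channel over `H` and `W`, which is the reduction that `nn.AdaptiveAvgPool2d((1, 1))` followed by a reshape performs in the model. The result is an `(N, C)` matrix that a linear layer can consume. The helper `global_avg_pool` is hypothetical and written only for this illustration; it is not part of Classy Vision or PyTorch.

```python
# Dependency-free sketch: average each channel over H and W for a batch
# stored as nested lists in (N, C, H, W) order.
from typing import List

Batch = List[List[List[List[float]]]]  # indexed as batch[n][c][h][w]


def global_avg_pool(batch: Batch) -> List[List[float]]:
    """Reduce an (N, C, H, W) batch to (N, C) by averaging over H and W."""
    pooled = []
    for image in batch:  # one image: (C, H, W)
        channel_means = []
        for channel in image:  # one channel: (H, W)
            values = [v for row in channel for v in row]
            channel_means.append(sum(values) / len(values))
        pooled.append(channel_means)
    return pooled


# A batch of N=2 RGB images (C=3), each 2x2 pixels (H=W=2).
batch = [
    [[[1.0, 3.0], [5.0, 7.0]], [[0.0, 0.0], [0.0, 4.0]], [[2.0, 2.0], [2.0, 2.0]]],
    [[[8.0, 8.0], [8.0, 8.0]], [[1.0, 2.0], [3.0, 4.0]], [[0.0, 1.0], [0.0, 1.0]]],
]
print(global_avg_pool(batch))  # [[4.0, 1.0, 2.0], [8.0, 2.5, 0.5]]
```

With a real tensor `x` of shape `(N, C, H, W)`, `x.mean(dim=(2, 3))` computes the same `(N, C)` result.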

```python
import torch.nn as nn

from classy_vision.models import ClassyModel


class MyModel(ClassyModel):
    def __init__(self, num_classes):
        super().__init__(num_classes)
        # create an average pool layer
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # we expect an RGB image
        num_channels = 3
        self.fc = nn.Linear(num_channels, num_classes)

    def forward(self, x):
        # perform average pooling
        out = self.avgpool(x)

        # reshape the output and apply the fc layer
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out
```
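The pooling layer has no weights, so all of the learnable parameters in `MyModel` live in `self.fc`. The arithmetic below also shows why `nn.Linear` is built with `num_channels` inputs: after pooling and reshaping, each image is a vector of length `C`. The helper `linear_param_count` is hypothetical. The sketch also assumes that `ClassyModel` itself contributes no parameters.

```python
# Back-of-the-envelope parameter count for MyModel, assuming only the
# layers defined in __init__ hold parameters.
def linear_param_count(in_features: int, out_features: int, bias: bool = True) -> int:
    """nn.Linear stores an (out_features, in_features) weight plus an optional bias."""
    return in_features * out_features + (out_features if bias else 0)


num_channels = 3      # RGB input, as in MyModel
num_classes = 1000    # value used when instantiating the model
avgpool_params = 0    # AdaptiveAvgPool2d has no learnable parameters
fc_params = linear_param_count(num_channels, num_classes)

total = avgpool_params + fc_params
print(total)  # 4000
```

If that assumption holds, `sum(p.numel() for p in my_model.parameters())` on an instantiated model should report the same figure.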

Now we can start using this model for training.

```python
from classy_vision.tasks import ClassificationTask

my_model = MyModel(num_classes=1000)
my_task = ClassificationTask().set_model(my_model)
```

To use the registration mechanism, which lets the model be picked up from a configuration, we need to do two additional things:
- Implement a `from_config` method
- Add the `register_model` decorator to `MyModel`
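The general pattern behind these two steps is a name-to-class registry. A decorator records each class under a string name, and a builder looks up `config["name"]` and delegates the remaining settings to the class's `from_config`. The sketch below is a simplified, hypothetical illustration of that pattern. It is **not** Classy Vision's implementation. `MODEL_REGISTRY`, `register`, `build_from_config`, and `ToyModel` are names invented for this example.

```python
# Hypothetical, simplified name -> class registry illustrating the pattern.
# This is NOT Classy Vision's implementation; all names here are illustrative.
from typing import Any, Callable, Dict, Type

MODEL_REGISTRY: Dict[str, Type[Any]] = {}


def register(name: str) -> Callable[[Type[Any]], Type[Any]]:
    """Class decorator that records `cls` under `name`."""
    def decorator(cls: Type[Any]) -> Type[Any]:
        if name in MODEL_REGISTRY:
            raise ValueError(f"Model '{name}' is already registered")
        MODEL_REGISTRY[name] = cls
        return cls  # return the class unchanged so it can still be used directly
    return decorator


def build_from_config(config: Dict[str, Any]) -> Any:
    """Look up the class named in config["name"] and delegate to its from_config."""
    cls = MODEL_REGISTRY[config["name"]]
    return cls.from_config(config)


@register("toy_model")
class ToyModel:
    def __init__(self, num_classes: int) -> None:
        self.num_classes = num_classes

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ToyModel":
        # validate the config before constructing, mirroring MyModel.from_config
        if "num_classes" not in config:
            raise ValueError('Need "num_classes" in config for ToyModel')
        return cls(num_classes=config["num_classes"])


model = build_from_config({"name": "toy_model", "num_classes": 10})
print(type(model).__name__, model.num_classes)  # ToyModel 10
```

Because the decorator returns the class unchanged, the class can still be instantiated directly. Registration only adds a lookup path by name.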

```python
import torch.nn as nn

from classy_vision.models import ClassyModel, register_model


@register_model("my_model")
class MyModel(ClassyModel):
    def __init__(self, num_classes):
        super().__init__(num_classes)
        # create an average pool layer
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # we expect an RGB image
        num_channels = 3
        self.fc = nn.Linear(num_channels, num_classes)

    @classmethod
    def from_config(cls, config):
        if "num_classes" not in config:
            raise ValueError('Need "num_classes" in config for MyModel')
        return cls(num_classes=config["num_classes"])

    def forward(self, x):
        # perform average pooling
        out = self.avgpool(x)

        # reshape the output and apply the fc layer
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out
```

Now we can start using this model in our configurations.

```python
from classy_vision.models import build_model

model_config = {
    "name": "my_model",
    "num_classes": 1000
}
my_model = build_model(model_config)
assert isinstance(my_model, MyModel)
```
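Configurations are often stored in JSON files rather than written inline. The sketch below assumes, for illustration only, a file layout in which the model settings sit under a top-level `"model"` key. It parses the text with the standard `json` module and extracts the dictionary that `build_model` receives. The final call is left commented out because it needs Classy Vision installed and `MyModel` registered.

```python
# Sketch: reading the model section of a JSON config with the standard library.
# Assumes (for illustration) that model settings live under a top-level "model" key.
import json

config_text = """
{
    "model": {
        "name": "my_model",
        "num_classes": 1000
    }
}
"""

full_config = json.loads(config_text)
model_config = full_config["model"]
print(model_config)  # {'name': 'my_model', 'num_classes': 1000}

# With Classy Vision installed and MyModel registered under "my_model",
# this dict can be passed to build_model:
# my_model = build_model(model_config)
```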