# Building Custom Models

We will build a custom ResNet model by wrapping the `ResNet` class of the timm library into a `PreTrainedModel`.

## Writing a custom configuration

The configuration of a model is an object that contains all the information needed to build the model. Since the model receives nothing but a `config` when it is initialized, that object needs to be as complete as possible.

We will expose a few arguments of the timm `ResNet` class that we may want to tweak. Different configurations will then give us the different types of ResNet that are possible.
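To see how `block_type` and `layers` select a ResNet variant, here is a small dependency-free sketch that computes the conventional depth of a few well-known presets. The `PRESETS` table and the `nominal_depth` helper are illustrative only and not part of timm; the count assumes the usual convention of every convolution inside the residual blocks, plus one stem convolution and the final fully connected layer.

```python
# Illustrative only: how block_type and layers map to the classic ResNet depths.
# CONVS_PER_BLOCK, PRESETS and nominal_depth are hypothetical helpers, not timm code.
from typing import List

CONVS_PER_BLOCK = {"basic": 2, "bottleneck": 3}

PRESETS = {
    "resnet18": {"block_type": "basic", "layers": [2, 2, 2, 2]},
    "resnet34": {"block_type": "basic", "layers": [3, 4, 6, 3]},
    "resnet50": {"block_type": "bottleneck", "layers": [3, 4, 6, 3]},
    "resnet101": {"block_type": "bottleneck", "layers": [3, 4, 23, 3]},
}


def nominal_depth(block_type: str, layers: List[int]) -> int:
    """Convolutions in all residual blocks, plus the stem conv and the final fc layer."""
    return CONVS_PER_BLOCK[block_type] * sum(layers) + 2


for name, kwargs in PRESETS.items():
    print(name, nominal_depth(**kwargs))
```

The `d` variants such as ResNet50d keep the base name even though `stem_type='deep'` replaces the single stem convolution with three smaller ones.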

```python
from transformers import PretrainedConfig
from typing import List


class ResnetConfig(PretrainedConfig):
    model_type = 'resnet'

    def __init__(
            self,
            block_type='bottleneck',
            layers: List[int] = [3, 4, 6, 3],
            num_classes: int = 1000,
            input_channels: int = 3,
            cardinality: int = 1,
            base_width: int = 64,
            stem_width: int = 64,
            stem_type: str = "",
            avg_down: bool = False,
            **kwargs
    ):
        if block_type not in ['basic', 'bottleneck']:
            raise ValueError(f"`block_type` must be one of ['basic', 'bottleneck'], got {block_type}.")
        if stem_type not in ['', 'deep', 'deep-tiered']:
            raise ValueError(f"`stem_type` must be one of ['', 'deep', 'deep-tiered'], got {stem_type}.")

        self.block_type = block_type
        self.layers = layers
        self.num_classes = num_classes
        self.input_channels = input_channels
        self.cardinality = cardinality
        self.base_width = base_width
        self.stem_width = stem_width
        self.stem_type = stem_type
        self.avg_down = avg_down
        super().__init__(**kwargs)
```

Three important things to note:
* we have to inherit from `PretrainedConfig`,
* the `__init__` of our `PretrainedConfig` subclass must accept any `kwargs`,
* those `kwargs` need to be passed to the superclass `__init__`.
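The reason for the last two rules is easiest to see with a toy example. The sketch below uses a hypothetical `ToyBaseConfig` in place of `PretrainedConfig`; it is not the real class, only a simplified model of the relevant behavior: the saved dictionary also contains keys owned by the base class, and reloading passes every saved key back to `__init__` as a keyword argument. A subclass without `**kwargs` therefore cannot be reloaded.

```python
# ToyBaseConfig is a hypothetical, simplified stand-in for PretrainedConfig,
# written only to show why subclasses must accept and forward **kwargs.
class ToyBaseConfig:
    model_type = ""

    def __init__(self, **kwargs):
        # Options owned by the base class (the real class has many more).
        self.name_or_path = kwargs.pop("name_or_path", "")
        self.architectures = kwargs.pop("architectures", None)
        # Leftover keys are kept as attributes.
        for key, value in kwargs.items():
            setattr(self, key, value)

    def to_dict(self):
        output = dict(vars(self))
        output["model_type"] = self.model_type
        return output

    @classmethod
    def from_dict(cls, config_dict):
        config_dict = dict(config_dict)
        config_dict.pop("model_type", None)  # class attribute, not an argument
        return cls(**config_dict)  # every saved key becomes a keyword argument


class ToyResnetConfig(ToyBaseConfig):
    model_type = "resnet"

    def __init__(self, block_type="bottleneck", stem_width=64, **kwargs):
        self.block_type = block_type
        self.stem_width = stem_width
        super().__init__(**kwargs)  # forward base-class options such as `architectures`


class BrokenConfig(ToyBaseConfig):
    def __init__(self, block_type="bottleneck", stem_width=64):  # no **kwargs
        self.block_type = block_type
        self.stem_width = stem_width
        super().__init__()


saved = ToyResnetConfig(stem_width=32, architectures=["ResnetModel"]).to_dict()
reloaded = ToyResnetConfig.from_dict(saved)
print(reloaded.stem_width, reloaded.architectures)

try:
    BrokenConfig.from_dict(saved)
    reload_error = None
except TypeError as err:  # the base-class keys have nowhere to go
    reload_error = err
print("reload without **kwargs:", reload_error)
```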

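The inheritance makes sure we get all the functionality of the Hugging Face Transformers library, such as `save_pretrained` and `from_pretrained`. The other two constraints come from the fact that a `PretrainedConfig` has more fields than the ones we set: when a config is reloaded with `from_pretrained`, those fields are passed back to `__init__`, so our config must accept them and send them on to the superclass.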

Defining a `model_type` for our configuration is not mandatory, unless we want to register our model with the auto classes.

With the configuration class written, we can easily create a configuration for a ResNet50d and save it:

```python
resnet50d_config = ResnetConfig(
    block_type='bottleneck',
    stem_width=32,
    stem_type='deep',
    avg_down=True
)
resnet50d_config.save_pretrained('custom-resnet')
```

This will save a file named `config.json` inside the folder `custom-resnet`. We can reload our config with the `from_pretrained` method:
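Conceptually, saving a config boils down to serializing its attributes to JSON. The hypothetical `save_config` and `load_config` helpers below mimic that round trip with the standard library only; they are not the Transformers implementation, and the real `config.json` also contains extra keys added by the base class, such as `transformers_version`.

```python
import json
import os
import tempfile


# Hypothetical helpers (not the Transformers API) that mimic how a config is
# stored: a single JSON file named config.json inside a directory.
def save_config(config_dict: dict, save_directory: str) -> str:
    os.makedirs(save_directory, exist_ok=True)
    path = os.path.join(save_directory, "config.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(config_dict, f, indent=2, sort_keys=True)
    return path


def load_config(save_directory: str) -> dict:
    with open(os.path.join(save_directory, "config.json"), encoding="utf-8") as f:
        return json.load(f)


# A subset of the values held by resnet50d_config.
resnet50d_values = {
    "block_type": "bottleneck",
    "layers": [3, 4, 6, 3],
    "stem_width": 32,
    "stem_type": "deep",
    "avg_down": True,
    "model_type": "resnet",
}

with tempfile.TemporaryDirectory() as tmp:
    folder = os.path.join(tmp, "custom-resnet")
    save_config(resnet50d_values, folder)
    files = sorted(os.listdir(folder))
    reloaded = load_config(folder)

print(files, reloaded == resnet50d_values)
```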

```python
resnet50d_config = ResnetConfig.from_pretrained('custom-resnet')
```

## Writing a custom model

Now that we have our ResNet configuration, we can move on to writing the model. We will write two models: one that extracts the hidden features from a batch of images, and one that is suitable for image classification.

```python
from transformers import PreTrainedModel
from timm.models.resnet import BasicBlock, Bottleneck, ResNet

BLOCK_MAPPING = {
    'basic': BasicBlock,
    'bottleneck': Bottleneck
}


class ResnetModel(PreTrainedModel):
    config_class = ResnetConfig # ResnetConfig class is defined in the previous section

    def __init__(self, config):
        super().__init__(config)

        # Mapping the block types to the actual block classes
        block_layer = BLOCK_MAPPING[config.block_type]

        self.model = ResNet(
            block_layer,
            config.layers,
            num_classes=config.num_classes,
            in_chans=config.input_channels,
            cardinality=config.cardinality,
            base_width=config.base_width,
            stem_width=config.stem_width,
            stem_type=config.stem_type,
            avg_down=config.avg_down
        )

    def forward(self, tensor):
        return self.model.forward_features(tensor)
```

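For the image classification task, the `forward` method runs the full timm model (including its pooling and classification head) to produce logits, and also computes a cross-entropy loss when labels are provided: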

```python
import torch

class ResnetModelForImageClassification(PreTrainedModel):
    config_class = ResnetConfig

    def __init__(self, config):
        super().__init__(config)
        block_layer = BLOCK_MAPPING[config.block_type]

        self.model = ResNet(
            block_layer,
            config.layers,
            num_classes=config.num_classes,
            in_chans=config.input_channels,
            cardinality=config.cardinality,
            base_width=config.base_width,
            stem_width=config.stem_width,
            stem_type=config.stem_type,
            avg_down=config.avg_down
        )

    def forward(self, tensor, labels=None):
        logits = self.model(tensor)

        if labels is not None:
            loss = torch.nn.functional.cross_entropy(logits, labels)
            return {'loss': loss, 'logits': logits}
        else:
            return {'logits': logits}
```

We inherit both classes from `PreTrainedModel` and call the superclass initialization with the `config`, much as we would call `super().__init__()` when writing a regular `torch.nn.Module`. The `config_class` attribute is not mandatory, unless we want to register our model with the auto classes.

Our model can return anything we want. Returning a dictionary that includes the loss when labels are passed, as `ResnetModelForImageClassification` does, makes the model directly usable inside the `Trainer` class. With another output format, we need to write our own training loop or use another library for training.
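The contract between the classification model and a training loop can be sketched without any dependencies. Below, `cross_entropy` reimplements for plain lists what `torch.nn.functional.cross_entropy` computes with its default settings (log-softmax followed by negative log-likelihood, averaged over the batch), and `extract_loss` is a simplified paraphrase of how a `Trainer`-style loop picks the loss from the outputs. Both helpers are hypothetical illustrations, not library code.

```python
import math


def cross_entropy(logits, labels):
    # Mean over the batch of log(sum(exp(logits))) - logits[label].
    total = 0.0
    for row, label in zip(logits, labels):
        max_logit = max(row)  # subtract the max for numerical stability
        log_sum_exp = max_logit + math.log(sum(math.exp(x - max_logit) for x in row))
        total += log_sum_exp - row[label]
    return total / len(labels)


def forward(logits, labels=None):
    # Same output contract as ResnetModelForImageClassification.forward.
    if labels is not None:
        return {"loss": cross_entropy(logits, labels), "logits": logits}
    return {"logits": logits}


def extract_loss(outputs):
    # Simplified paraphrase: a "loss" key for dict outputs, else the first element.
    return outputs["loss"] if isinstance(outputs, dict) else outputs[0]


batch_logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
outputs = forward(batch_logits, labels=[0, 2])
print(extract_loss(outputs))
print("loss" in forward(batch_logits))  # no labels, so no loss key
```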

Now that we have our model classes, we can create one:

```python
resnet50d = ResnetModel(resnet50d_config)
resnet50d
```

```
ResnetModel(
  (model): ResNet(
    (conv1): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        ...
```

```python
resnet50d_ic = ResnetModelForImageClassification(resnet50d_config)
resnet50d_ic
```

```
ResnetModelForImageClassification(
  (model): ResNet(
    (conv1): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        ...
```

```python
sample_data = torch.randn(1, 3, 224, 224)

features = resnet50d(sample_data)
logits = resnet50d_ic(sample_data)['logits']

features.size(), logits.size()
```

```
(torch.Size([1, 2048, 7, 7]), torch.Size([1, 1000]))
```
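The feature shape can be checked by hand. In the standard ResNet layout, the resolution is halved five times (the stride-2 stem convolution, the stride-2 max pool, and the first block of `layer2`, `layer3` and `layer4`), and with the padding used here each halving gives `ceil(size / 2)`. The channel count is the 512 base channels of `layer4` times the block expansion (4 for `Bottleneck`). The helpers below are an illustrative back-of-the-envelope sketch, not timm code, and assume the default output stride of 32.

```python
import math


def feature_map_size(size: int, num_halvings: int = 5) -> int:
    # Each stride-2 stage (3x3 conv or max pool, padding 1) maps size -> ceil(size / 2).
    for _ in range(num_halvings):
        size = math.ceil(size / 2)
    return size


def final_channels(block_type: str) -> int:
    # layer4 has 512 base channels; Bottleneck blocks expand them by a factor of 4.
    expansion = {"basic": 1, "bottleneck": 4}[block_type]
    return 512 * expansion


print(final_channels("bottleneck"), feature_map_size(224))
```

The logits, on the other hand, have shape `[1, 1000]` because the classification model pools these features globally and feeds them to a fully connected layer with `num_classes` outputs.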

Since we did not change the default ResNet50d architecture, we can reuse the pretrained weights that timm provides for `resnet50d` by copying its state dict into our wrapped model:

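```python
import timm

pretrained_model = timm.create_model('resnet50d', pretrained=True)
# Copy the timm weights into the wrapped timm ResNet; the parameter names match
# because both networks were built with the same architecture arguments.
resnet50d.model.load_state_dict(pretrained_model.state_dict())
```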

```
<All keys matched successfully>
```
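The `<All keys matched successfully>` message comes from comparing parameter names: with the default strict loading, `load_state_dict` reports keys the target model expects but did not receive (missing) and keys it received but does not know (unexpected). The dependency-free sketch below reproduces only the name comparison (PyTorch also checks tensor shapes), with made-up key names. It also shows why the weights are loaded into `resnet50d.model` rather than `resnet50d`: inside the wrapper, every timm parameter name is prefixed with the attribute name `model.`.

```python
# Illustrative sketch of the name check behind strict state-dict loading.
# compare_state_dict_keys is a hypothetical helper; the key names are made up.
def compare_state_dict_keys(target_keys, source_keys):
    target, source = set(target_keys), set(source_keys)
    missing = sorted(target - source)     # expected by the model, not provided
    unexpected = sorted(source - target)  # provided, but unknown to the model
    return missing, unexpected


timm_keys = ["conv1.0.weight", "bn1.weight", "fc.weight", "fc.bias"]
print(compare_state_dict_keys(timm_keys, timm_keys))

# The wrapper stores the timm network as `self.model`, so its own keys are prefixed.
wrapper_keys = ["model." + key for key in timm_keys]
missing, unexpected = compare_state_dict_keys(wrapper_keys, timm_keys)
print(len(missing), len(unexpected))
```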

## Registering a model with custom code to the `AutoClass`

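As long as our config has a `model_type` attribute that differs from the existing model types, and our model classes have the right `config_class` attributes, we can add them to the auto classes. Note that recent versions of Transformers ship a native ResNet whose `model_type` is already `'resnet'`; with those versions, registering under that name raises a `ValueError`, so pick a distinct name (for example `'custom-resnet'`) and use it both as the config's `model_type` and in `AutoConfig.register`: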

```python
from transformers import AutoConfig, AutoModel, AutoModelForImageClassification

AutoConfig.register('resnet', ResnetConfig)
AutoModel.register(ResnetConfig, ResnetModel)
AutoModelForImageClassification.register(ResnetConfig, ResnetModelForImageClassification)
```

The first argument when registering our custom config with `AutoConfig` needs to match the `model_type` of our custom config, and the first argument when registering our custom models with an auto model class needs to match the `config_class` of those models.
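Both matching rules, and the name collision, can be illustrated with a toy registry. The sketch below is a hypothetical, simplified model of what registration checks and how an auto class resolves a `model_type` to a model class; it is not how Transformers implements its auto classes internally, and all names in it are made up for illustration.

```python
# Hypothetical, simplified registry illustrating the two matching rules.
CONFIG_REGISTRY = {"resnet": object}  # pretend a built-in "resnet" already exists
MODEL_REGISTRY = {}


def register_config(model_type, config_cls):
    if config_cls.model_type != model_type:
        raise ValueError("the name must match config_cls.model_type")
    if model_type in CONFIG_REGISTRY:
        raise ValueError(f"'{model_type}' is already used")
    CONFIG_REGISTRY[model_type] = config_cls


def register_model(config_cls, model_cls):
    if model_cls.config_class is not config_cls:
        raise ValueError("config_cls must match model_cls.config_class")
    MODEL_REGISTRY[config_cls] = model_cls


def model_class_for(model_type):
    # model_type -> config class -> model class, roughly what an auto class resolves
    return MODEL_REGISTRY[CONFIG_REGISTRY[model_type]]


class ToyConfig:
    model_type = "custom-resnet"


class ToyModel:
    config_class = ToyConfig


register_config("custom-resnet", ToyConfig)
register_model(ToyConfig, ToyModel)
print(model_class_for("custom-resnet").__name__)


class ClashingConfig:
    model_type = "resnet"


try:
    register_config("resnet", ClashingConfig)
    clash_error = None
except ValueError as err:
    clash_error = err
print(clash_error)
```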

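## Using a model with custom code

A model whose code lives in a repository on the Hub can be loaded with any auto class by passing `trust_remote_code=True`. Because this executes the repository's code on our machine, it is strongly encouraged to also pass a commit hash as `revision`, so that the code cannot change after we have reviewed it: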

```python
from transformers import AutoModelForImageClassification

model = AutoModelForImageClassification.from_pretrained("sgugger/custom-resnet50d", trust_remote_code=True)
```

```python
commit_hash = "ed94a7c6247d8aedce4647f00f20de6875b5b292"
model = AutoModelForImageClassification.from_pretrained(
    "sgugger/custom-resnet50d", trust_remote_code=True, revision=commit_hash
)
```