In [None]:
# hide
# skip
!git clone https://github.com/benihime91/gale # install gale on colab
!pip install -e "gale[dev]"

In [None]:
# default_exp classification.modelling.meta_arch.vit

In [None]:
# hide
%load_ext nb_black
%load_ext autoreload
%autoreload 2
%matplotlib inline

<IPython.core.display.Javascript object>

In [None]:
# hide
import warnings

from nbdev.export import *
from nbdev.showdoc import *

warnings.filterwarnings("ignore")

<IPython.core.display.Javascript object>

# Meta Architectures : Vision Transformer (ViT) 
> Pretrained Vision Transformers from timm, adapted for use in gale

In [None]:
# export
import logging
from typing import *

import torch
from fastcore.all import store_attr, use_kwargs_dict
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.core.memory import get_human_readable_count
from timm import create_model
from timm.models.vision_transformer import VisionTransformer
from timm.optim.optim_factory import add_weight_decay

from gale.core.classes import GaleModule
from gale.core.nn.activations import ACTIVATION_REGISTRY
from gale.core.nn.shape_spec import ShapeSpec
from gale.core.nn.utils import trainable_params

_logger = logging.getLogger(__name__)

<IPython.core.display.Javascript object>

In [None]:
# export
# @TODO: Add support for Discriminative Lr's
class ViT(GaleModule):
    """
    An interface to create a Vision Transformer from timm. For the available models check :
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    """

    @use_kwargs_dict(
        keep=True,
        num_classes=1000,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
    )
    def __init__(
        self,
        model_name: str,
        input_shape: ShapeSpec,
        lr: float = 1e-03,
        wd: float = 1e-05,
        pretrained: bool = True,
        freeze_to: Optional[int] = None,
        finetune: Optional[bool] = None,
        act: Optional[str] = None,
        reset_classifier: bool = True,
        filter_wd: bool = True,
        **kwargs,
    ):
        """
        Arguments:
        1. `input_shape` (ShapeSpec): input image shape. For ViT `height=width`; check the above link for available model shapes.
        2. `model_name` (str): name of the ViT model; check the above link for available models.
        3. `pretrained` (bool): load weights pretrained on imagenet.
        4. `act` (str): name of the activation layer. Must be registered in `ACTIVATION_REGISTRY`. Forwarded to timm as `act_layer`.
        5. `num_classes` (int): number of output classes.
        6. `drop_rate` (float): dropout rate.
        7. `attn_drop_rate` (float): attention dropout rate.
        8. `drop_path_rate` (float): stochastic depth rate.
        9. `reset_classifier` (bool): resets the weights of the classifier.
        10. `freeze_to` (int): Freeze the parameter groups of the model up to n.
        11. `finetune` (bool): Freeze all the layers and keep only the `classifier` trainable.
        12. `lr` (float): learning rate set on every parameter group returned by `build_param_dicts`.
        13. `wd` (float): weight decay used by `build_param_dicts`.
        14. `filter_wd` (bool): if True, build the param groups with timm's `add_weight_decay`, which puts the
            parameters listed by `model.no_weight_decay()` into a group with zero weight decay.
        """
        super(ViT, self).__init__()
        # ViT variants are created for square inputs (height == width)
        assert input_shape.height == input_shape.width, (
            f"ViT expects a square input, got height={input_shape.height}, "
            f"width={input_shape.width}"
        )
        in_chans = input_shape.channels

        if act is not None:
            # timm's VisionTransformer takes the activation as `act_layer` (it has no `act`
            # argument). Only forward it when requested so timm's default activation (GELU in the
            # printed model) is used otherwise.
            # NOTE(review): assumes the registry returns an nn.Module class, since timm
            # instantiates `act_layer()` itself — confirm against ACTIVATION_REGISTRY.
            kwargs["act_layer"] = ACTIVATION_REGISTRY.get(act)

        # create model from timm
        # fmt: off
        self.model: VisionTransformer = create_model(model_name, pretrained, in_chans=in_chans, **kwargs)
        # fmt: on
        assert isinstance(self.model, VisionTransformer)

        if reset_classifier:
            # `use_kwargs_dict` only adds `num_classes=1000` to the *displayed* signature; it does
            # not inject that default into `kwargs`. Apply the same default here instead of raising
            # a KeyError when the caller omits `num_classes`.
            num_cls = kwargs.get("num_classes", 1000)
            self.model.reset_classifier(num_cls)

        if freeze_to is not None:
            self.freeze_to(freeze_to)

        if finetune:
            if freeze_to is not None:
                _logger.warning("You have specified freeze_to along with finetune")
            _logger.info("Freezing all the model parameters except for the classifier")
            self.freeze()

            # names of the classifier sub-modules of the timm ViT that stay trainable
            classifier = ("head", "head_dist")
            for name, module in self.model.named_children():
                if name in classifier:
                    # re-enables gradients for every parameter of the classifier module
                    module.requires_grad_(True)

        store_attr("wd, lr, filter_wd")

    def forward(self, batched_inputs: torch.Tensor) -> torch.Tensor:
        """
        Runs the batched_inputs through the created model and returns its output.
        """
        return self.model(batched_inputs)

    @classmethod
    def from_config_dict(cls, cfg: DictConfig):
        """
        Instantiate the Meta Architecture from gale config.

        The input shape is read from `cfg.input` (channels, height, width) and the
        constructor arguments from `cfg.model.meta_architecture.init_args`.
        """
        # fmt: off
        input_shape = ShapeSpec(cfg.input.channels, cfg.input.height, cfg.input.width)
        _logger.debug(f"Inputs: {input_shape}")
        instance = super().from_config_dict(cfg.model.meta_architecture.init_args, input_shape=input_shape)
        param_count = get_human_readable_count(sum(p.numel() for p in instance.parameters()))
        _logger.debug('{} created, param count: {}.'.format(cfg.model.meta_architecture.init_args.model_name, param_count))
        # fmt: on
        return instance

    def build_param_dicts(self):
        """
        Builds up the parameter groups for optimization.

        Always returns a *list* of param-group dicts, each carrying `params`, `lr` and
        `weight_decay`. Every branch therefore has the same type and can be passed to a
        `torch.optim` optimizer or iterated by `get_lrs`.

        - `filter_wd=True`: groups come from timm's `add_weight_decay`, with the names returned
          by `self.model.no_weight_decay()` passed as the skip list; `lr` is set on each group.
        - `filter_wd=False`: a single group holding all trainable parameters of the model.
        """
        if self.filter_wd:
            param_lists = add_weight_decay(
                self.model,
                weight_decay=self.wd,
                skip_list=self.model.no_weight_decay(),
            )
            for group in param_lists:
                group["lr"] = self.lr
        else:
            ps = trainable_params(self.model)
            # wrapped in a list (a bare dict is not a valid param-group iterable) and keyed with
            # `weight_decay`, the key torch optimizers and the timm branch above use
            param_lists = [dict(params=ps, lr=self.lr, weight_decay=self.wd)]
        return param_lists

    def get_lrs(self) -> List:
        """
        Returns a List containing the Lrs' for each parameter group,
        in the same order as `build_param_dicts`. This is required to build schedulers
        like `torch.optim.lr_scheduler.OneCycleLR` which need
        the max lrs' for all the Param Groups.
        """
        return [group["lr"] for group in self.build_param_dicts()]

<IPython.core.display.Javascript object>

In [None]:
show_doc(ViT)

<h2 id="ViT" class="doc_header"><code>class</code> <code>ViT</code><a href="" class="source_link" style="float:right">[source]</a></h2>

> <code>ViT</code>(**`model_name`**:`str`, **`input_shape`**:`ShapeSpec`, **`lr`**:`float`=*`0.001`*, **`wd`**:`float`=*`1e-05`*, **`pretrained`**:`bool`=*`True`*, **`freeze_to`**:`Optional`\[`int`\]=*`None`*, **`finetune`**:`Optional`\[`bool`\]=*`None`*, **`act`**:`Optional`\[`str`\]=*`None`*, **`reset_classifier`**:`bool`=*`True`*, **`filter_wd`**:`bool`=*`True`*, **`num_classes`**=*`1000`*, **`drop_rate`**=*`0.0`*, **`attn_drop_rate`**=*`0.0`*, **`drop_path_rate`**=*`0.0`*, **\*\*`kwargs`**) :: [`GaleModule`](/gale/core.classes.html#GaleModule)

An interface to create a Vision Transformer from timm. For the available models, check:
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

<IPython.core.display.Javascript object>

**Arguments :**

1. `input_shape` (ShapeSpec): input image shape. For ViT `height=width`; check the above link for available model shapes.
2. `model_name` (str): name of the ViT model; check the above link for available models.
3. `pretrained` (bool): load weights pretrained on ImageNet.
4. `act` (str): name of the activation layer. Must be registered in `ACTIVATION_REGISTRY`.
5. `num_classes` (int): number of output classes.
6. `drop_rate` (float): dropout rate.
7. `attn_drop_rate` (float): attention dropout rate.
8. `drop_path_rate` (float): stochastic depth rate.
9. `reset_classifier` (bool): resets the weights of the classifier.
10. `freeze_to` (int): freeze the parameter groups of the model up to n.
11. `finetune` (bool): freeze all the layers and keep only the `classifier` trainable.

In [None]:
inp = ShapeSpec(3, 224, 224)

# A randomly initialised ViT with a fresh 10-class head. Because `finetune=True`,
# only the classifier parameters are left trainable.
m = ViT(
    model_name="vit_small_patch16_224",
    input_shape=inp,
    pretrained=False,
    reset_classifier=True,
    num_classes=10,
    finetune=True,
)

<IPython.core.display.Javascript object>

In [None]:
# forward a random batch of 2 images; expect one logit per class (num_classes=10 above)
i = torch.randn(2, inp.channels, inp.height, inp.width)
o = m(i)
assert o.shape == (2, 10), o.shape

<IPython.core.display.Javascript object>

Similar to `GeneralizedImageClassifier`, we can also instantiate `ViT` from a config. `ViT` requires neither a `backbone` nor a `head` configuration; we only need the initialization arguments for the ViT model named in `model_name`.

> Note: Your input shape must match the dimensions that the chosen Vision Transformer model supports. Unlike `GeneralizedImageClassifier`, `ViT` depends on the input shape.

In [None]:
from dataclasses import dataclass, field
from omegaconf import MISSING, OmegaConf

<IPython.core.display.Javascript object>

In [None]:
@dataclass
class ModelConf:
    """Init arguments forwarded to `ViT` through `cfg.model.meta_architecture.init_args`."""

    model_name: str = "vit_small_patch16_224"
    pretrained: bool = False
    finetune: bool = True
    reset_classifier: bool = True
    num_classes: int = 10


inp = ShapeSpec(3, 224, 224)

# `input` section: the image shape the ViT is built for.
# Uses a dedicated name instead of reusing `i`, which holds an input tensor elsewhere.
input_cfg = OmegaConf.create()
input_cfg.channels = inp.channels
input_cfg.height = inp.height
input_cfg.width = inp.width

# `model.meta_architecture` section: which class to build and its init args
meta_args = OmegaConf.structured(ModelConf())

meta = OmegaConf.create()
meta.name = "ViT"
meta.init_args = meta_args

C = OmegaConf.create()
C.input = input_cfg
C.model = OmegaConf.create()
C.model.meta_architecture = meta

<IPython.core.display.Javascript object>

In [None]:
m = ViT.from_config_dict(C)

# Building through the config must give the same wrapper and timm model types
# as constructing `ViT` directly.
assert isinstance(m, ViT) and isinstance(m.model, VisionTransformer)

<IPython.core.display.Javascript object>

In [None]:
# collapse-output
# print the full module hierarchy of the model built from the config
print(m)

ViT(
  (model): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=False)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=2304, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=2304, out_features=768, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (1): Block(
        (norm1): LayerNorm((768,), eps=1

<IPython.core.display.Javascript object>

In [None]:
# forward a random batch of 2 images through the config-built model;
# expect one logit per class requested in the config
i = torch.randn(2, inp.channels, inp.height, inp.width)
o = m(i)
assert o.shape == (2, C.model.meta_architecture.init_args.num_classes), o.shape

<IPython.core.display.Javascript object>

## Export-

In [None]:
# hide
# regenerate the library modules from the exported cells of all project notebooks
notebook2script()

Converted 00_core.utils.logger.ipynb.
Converted 00a_core.utils.visualize.ipynb.
Converted 00b_core.utils.structures.ipynb.
Converted 01_core.nn.utils.ipynb.
Converted 01a_core.nn.losses.ipynb.
Converted 02_core.nn.optim.optimizers.ipynb.
Converted 02a_core.nn.optim.lr_schedulers.ipynb.
Converted 03_core.classes.ipynb.
Converted 04_classification.modelling.backbones.ipynb.
Converted 04a_classification.modelling.heads.ipynb.
Converted 04b_classification.modelling.meta_arch.common.ipynb.
Converted 04b_classification.modelling.meta_arch.vit.ipynb.
Converted 05_classification.data.common.ipynb.
Converted 05a_classification.data.transforms.ipynb.
Converted 05b_classification.data.build.ipynb.
Converted 06_classification.task.ipynb.
Converted 07_collections.pandas.ipynb.
Converted 07a_collections.callbacks.notebook.ipynb.
Converted 07b_collections.callbacks.ema.ipynb.
Converted index.ipynb.


<IPython.core.display.Javascript object>