In [None]:
# hide
# skip
!git clone https://github.com/benihime91/gale # install gale on colab
!pip install -e "gale[dev]"

In [None]:
# default_exp classification.model.meta_arch.vision_transformer

In [None]:
# hide
%load_ext nb_black
%load_ext autoreload
%autoreload 2
%matplotlib inline

<IPython.core.display.Javascript object>

In [None]:
# hide
import warnings

from nbdev.export import *
from nbdev.showdoc import *
from timm.utils import *

warnings.filterwarnings("ignore")

setup_default_logging()

<IPython.core.display.Javascript object>

# Meta Architectures: Vision Transformer (ViT)
> Pretrained Vision Transformers from timm, adapted for use in gale

In [None]:
# export
import logging
from collections import namedtuple
from dataclasses import dataclass
from typing import *

import timm
import torch
from fastcore.all import store_attr, use_kwargs_dict
from omegaconf import MISSING, DictConfig, OmegaConf
from pytorch_lightning.core.memory import get_human_readable_count
from timm.optim.optim_factory import add_weight_decay

from gale.core_classes import BasicModule
from gale.torch_utils import trainable_params
from gale.utils.activs import ACTIVATION_REGISTRY
from gale.utils.shape_spec import ShapeSpec

_logger = logging.getLogger(__name__)

<IPython.core.display.Javascript object>

In [None]:
# export
# @TODO: Add support for Discriminative Lr's
class VisionTransformer(BasicModule):
    """
    An interface to create a Vision Transformer (ViT) from timm. For the available models check:
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    """

    # return type of `hypers`: the `lr`s and `wd`s of every optimizer param group
    _hypers = namedtuple("hypers", field_names=["lr", "wd"])

    @use_kwargs_dict(
        keep=True,
        num_classes=1000,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
    )
    def __init__(
        self,
        model_name: str,
        input_shape: ShapeSpec,
        lr: float = 1e-03,
        wd: float = 1e-05,
        pretrained: bool = True,
        freeze_to: Optional[int] = None,
        finetune: Optional[bool] = None,
        act: Optional[str] = None,
        reset_classifier: bool = True,
        filter_wd: bool = True,
        **kwargs,
    ):
        """
        Arguments:
        1. `input_shape` (ShapeSpec): input image shape. For ViT `height=width`; check the above link for the available model shapes.
        2. `model_name` (str): name of the ViT model; check the above link for the available models.
        3. `pretrained` (bool): load weights pretrained on imagenet.
        4. `act` (str): name of the activation layer. Must be registered in `ACTIVATION_REGISTRY`.
        5. `num_classes` (int): number of output classes.
        6. `drop_rate` (float): dropout rate.
        7. `attn_drop_rate` (float): attention dropout rate.
        8. `drop_path_rate` (float): stochastic depth rate.
        9. `reset_classifier` (bool): resets the weights of the classifier.
        10. `freeze_to` (int): freeze the parameter groups of the model up to n.
        11. `finetune` (bool): freeze all the layers and keep only the classifier (`head` / `head_dist`) trainable.
        12. `lr` (float): learning rate set on every param group returned by `build_param_dicts`.
        13. `wd` (float): weight decay used by `build_param_dicts`.
        14. `filter_wd` (bool): if True, `build_param_dicts` uses timm's `add_weight_decay`, which keeps the
            params named by the model's `no_weight_decay()` out of the weight-decayed group.

        Any other keyword arguments are forwarded to `timm.create_model`.
        """
        super(VisionTransformer, self).__init__()
        # ViT models only accept square inputs of the resolution they were built for
        assert input_shape.height == input_shape.width, (
            f"ViT expects a square input, got height={input_shape.height} "
            f"and width={input_shape.width}"
        )
        in_chans = input_shape.channels

        model_kwargs = dict(kwargs)
        if act is not None:
            # timm's VisionTransformer receives the activation class as `act_layer`
            # (an `act=` keyword is not a constructor argument of the model)
            model_kwargs["act_layer"] = ACTIVATION_REGISTRY.get(act)

        # create model from timm
        self.model = timm.create_model(
            model_name, pretrained=pretrained, in_chans=in_chans, **model_kwargs
        )

        if reset_classifier:
            # `use_kwargs_dict` only documents `num_classes` in the signature, it does not
            # inject it into `kwargs`; fall back to the head size timm already built.
            num_cls = kwargs.get("num_classes", self.model.num_classes)
            self.model.reset_classifier(num_cls)

        if freeze_to is not None:
            self.freeze_to(freeze_to)

        if finetune:
            if freeze_to is not None and isinstance(freeze_to, int):
                msg = "You have sprecified freeze_to along with finetune"
                _logger.warning(msg)
            _logger.info("Freezing all the model parameters except for the classifier")
            self.freeze()

            # names of the classifier sub-modules on timm ViT models
            classifier = ["head", "head_dist"]

            for name, module in self.model.named_children():
                if name in classifier:
                    for p in module.parameters():
                        p.requires_grad_(True)

        store_attr("wd, lr, filter_wd, input_shape")

    def forward(self, batched_inputs: torch.Tensor) -> torch.Tensor:
        """
        Runs the batched_inputs through the created model.
        """
        out = self.model(batched_inputs)
        return out

    @classmethod
    def from_config_dict(cls, cfg: DictConfig):
        """
        Instantiate the Meta Architecture from gale config.

        Reads the input shape from `cfg.input` and the init arguments from
        `cfg.model.meta_architecture.init_args`.
        """
        # fmt: off
        input_shape = ShapeSpec(cfg.input.channels, cfg.input.height, cfg.input.width)
        _logger.debug(f"Inputs: {input_shape}")
        instance = super().from_config_dict(cfg.model.meta_architecture.init_args, input_shape=input_shape)
        param_count = get_human_readable_count(sum(m.numel() for m in instance.parameters()))
        _logger.debug('{} created, param count: {}.'.format(cfg.model.meta_architecture.init_args.model_name, param_count))
        # fmt: on
        return instance

    def build_param_dicts(self):
        """
        Builds up the parameter groups for optimization.

        Returns a list of param-group dicts (`params`, `lr`, `weight_decay`) that can be
        passed directly to a torch optimizer. With `filter_wd=True` the groups come from
        timm's `add_weight_decay`, skipping the names in `model.no_weight_decay()`;
        otherwise a single group holds all trainable params.
        """
        if self.filter_wd:
            param_lists = add_weight_decay(
                self.model,
                weight_decay=self.wd,
                skip_list=self.model.no_weight_decay(),
            )
            for group in param_lists:
                group["lr"] = self.lr
        else:
            ps = trainable_params(self.model)
            # a list (not a bare dict) with the `weight_decay` key torch optimizers read,
            # matching the shape of the filtered branch above
            param_lists = [dict(params=ps, lr=self.lr, weight_decay=self.wd)]
        return param_lists

    @property
    def hypers(self) -> Tuple:
        """
        Returns the `lr` and `wd` of each param group
        produced by `build_param_dicts`.
        """
        groups = self.build_param_dicts()
        lrs = [g["lr"] for g in groups]
        wds = [g["weight_decay"] for g in groups]
        return self._hypers(lrs, wds)

<IPython.core.display.Javascript object>

In [None]:
show_doc(VisionTransformer)

<h2 id="VisionTransformer" class="doc_header"><code>class</code> <code>VisionTransformer</code><a href="" class="source_link" style="float:right">[source]</a></h2>

> <code>VisionTransformer</code>(**`model_name`**:`str`, **`input_shape`**:`ShapeSpec`, **`lr`**:`float`=*`0.001`*, **`wd`**:`float`=*`1e-05`*, **`pretrained`**:`bool`=*`True`*, **`freeze_to`**:`Optional`\[`int`\]=*`None`*, **`finetune`**:`Optional`\[`bool`\]=*`None`*, **`act`**:`Optional`\[`str`\]=*`None`*, **`reset_classifier`**:`bool`=*`True`*, **`filter_wd`**:`bool`=*`True`*, **`num_classes`**=*`1000`*, **`drop_rate`**=*`0.0`*, **`attn_drop_rate`**=*`0.0`*, **`drop_path_rate`**=*`0.0`*, **\*\*`kwargs`**) :: [`BasicModule`](/gale/core-classes.html#BasicModule)

Abstract class offering interface which should be implemented by all `Backbones`,
`Heads` and `Meta Archs` in gale.

<IPython.core.display.Javascript object>

**Arguments :**

1. `input_shape` (ShapeSpec): input image shape. For ViT `height=width`; check the above link for the available model shapes.
2. `model_name` (str): name of the ViT model; check the above link for the available models.
3. `pretrained` (bool): load weights pretrained on imagenet.
4. `act` (str): name of the activation layer. Must be registered in `ACTIVATION_REGISTRY`.
5. `num_classes` (int): number of output classes.
6. `drop_rate` (float): dropout rate.
7. `attn_drop_rate` (float): attention dropout rate.
8. `drop_path_rate` (float): stochastic depth rate.
9. `reset_classifier` (bool): resets the weights of the classifier.
10. `freeze_to` (int): freeze the parameter groups of the model up to n.
11. `finetune` (bool): freeze all the layers and keep only the `classifier` trainable.

In [None]:
inp = ShapeSpec(3, 224, 224)

# randomly initialised `vit_small_patch16_224` with a fresh 10-class head;
# `finetune=True` leaves only that head trainable
m = VisionTransformer(
    model_name="vit_small_patch16_224",
    input_shape=inp,
    pretrained=False,
    reset_classifier=True,
    finetune=True,
    num_classes=10,
)

Freezing all the model parameters except for the classifier


<IPython.core.display.Javascript object>

In [None]:
i = torch.randn(2, inp.channels, inp.height, inp.width)
o = m(i)
# the classifier was reset to `num_classes=10`, so each image yields 10 logits
assert o.shape == (2, 10)

<IPython.core.display.Javascript object>

Similar to `GeneralizedImageClassifier`, we can also instantiate `ViT` from a config. `ViT` requires neither a `backbone` nor a `head` configuration; we only need the initialization arguments for the ViT model named in `model_name`.

> Note: Your input shape must match the dimensions that the chosen Vision Transformer model supports. Unlike `GeneralizedImageClassifier`, `ViT` depends on the input shape.

### DataClass

In [None]:
from dataclasses import dataclass, field
from omegaconf import MISSING, OmegaConf

<IPython.core.display.Javascript object>

In [None]:
# export
@dataclass
class VisionTransformerDataClass:
    """
    Structured config (for `OmegaConf.structured`) holding the `init_args` of `VisionTransformer`.
    Fields left as `MISSING` (`model_name`, `num_classes`) must be supplied by the user.
    """

    model_name: str = MISSING
    lr: float = 1e-03
    wd: float = 1e-05
    pretrained: bool = False
    freeze_to: Optional[int] = None
    finetune: Optional[bool] = True
    reset_classifier: bool = True
    filter_wd: bool = True
    drop_rate: float = 0.0
    attn_drop_rate: float = 0.0
    drop_path_rate: float = 0.0
    num_classes: int = MISSING
    # name of an activation registered in `ACTIVATION_REGISTRY`; None means no override.
    # Appended last so positional construction of the fields above keeps working.
    act: Optional[str] = None

<IPython.core.display.Javascript object>

Here is how a `VisionTransformer` can be instantiated from the config:

In [None]:
# collapse-output
inp = ShapeSpec(3, 224, 224)

# typed init args for the ViT meta-architecture
meta_args = OmegaConf.structured(
    VisionTransformerDataClass(model_name="vit_small_patch16_224", num_classes=2)
)

meta_cfg = OmegaConf.create({"name": "ViT"})
meta_cfg.init_args = meta_args

input_cfg = OmegaConf.create(
    {"channels": inp.channels, "height": inp.height, "width": inp.width}
)

C = OmegaConf.create({"input": input_cfg, "model": {}})
C.model.meta_architecture = meta_cfg

print(OmegaConf.to_yaml(C, resolve=True))

input:
  channels: 3
  height: 224
  width: 224
model:
  meta_architecture:
    name: ViT
    init_args:
      model_name: vit_small_patch16_224
      lr: 0.001
      wd: 1.0e-05
      pretrained: false
      freeze_to: null
      finetune: true
      reset_classifier: true
      filter_wd: true
      drop_rate: 0.0
      attn_drop_rate: 0.0
      drop_path_rate: 0.0
      num_classes: 2



<IPython.core.display.Javascript object>

In [None]:
m = VisionTransformer.from_config_dict(C)
shape = (m.input_shape.channels, m.input_shape.height, m.input_shape.width)
inp = torch.randn(2, *shape)
o = m(inp)
# one logit per class configured in `init_args.num_classes`
assert o.shape == (2, C.model.meta_architecture.init_args.num_classes)
o

Freezing all the model parameters except for the classifier


tensor([[-0.0844,  0.1789],
        [-0.0287, -0.1308]], grad_fn=<AddmmBackward>)

<IPython.core.display.Javascript object>

In [None]:
# hide
# cuda
import pytorch_lightning as pl
import torchmetrics
import torchvision.transforms as T
from fastcore.all import Path
from nbdev.export import Config
from torch import optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

from gale.collections.callbacks.notebook import NotebookProgressCallback
from gale.collections.download import download_and_extract_archive
from gale.schedules import WarmupStepLR
from gale.utils.display import show_images

URL = "https://download.pytorch.org/tutorial/hymenoptera_data.zip"
nbs_path = Path(Config().path("nbs_path"))
data_path = nbs_path / "data"

# fetch (or reuse an already downloaded and verified copy of) the toy hymenoptera
# dataset and extract it under `data_path`
download_and_extract_archive(url=URL, download_root=data_path)

Using downloaded and verified file: /Users/ayushman/Desktop/gale/nbs/data/hymenoptera_data.zip
Extracting /Users/ayushman/Desktop/gale/nbs/data/hymenoptera_data.zip to /Users/ayushman/Desktop/gale/nbs/data


<IPython.core.display.Javascript object>

In [None]:
# hide
# cuda
# ImageNet channel statistics, shared by the train and validation pipelines
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
normalize = T.Normalize(IMAGENET_MEAN, IMAGENET_STD)

train_tfms = T.Compose(
    [T.RandomResizedCrop(224), T.RandomHorizontalFlip(), T.ToTensor(), normalize]
)
valid_tfms = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor(), normalize])
data_transforms = {"train": train_tfms, "val": valid_tfms}

hymenoptera_dir = data_path / "hymenoptera_data"
training_data = ImageFolder(
    hymenoptera_dir / "train", transform=data_transforms["train"]
)
validation_data = ImageFolder(
    hymenoptera_dir / "val", transform=data_transforms["val"]
)

train_dl = DataLoader(training_data, batch_size=32, shuffle=True)
valid_dl = DataLoader(validation_data, batch_size=32, shuffle=False)

<IPython.core.display.Javascript object>

In [None]:
# hide
# cuda
class Learner(pl.LightningModule):
    """
    Minimal LightningModule used to smoke-test training a gale meta-architecture.

    Expects `model` to expose `build_param_dicts()` (used to build the optimizer).
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.train_metric = torchmetrics.Accuracy()
        self.valid_metric = torchmetrics.Accuracy()
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, xb):
        return self.model(xb)

    def _shared_step(self, batch: Any, metric):
        """Runs one batch through the model; returns `(loss, accuracy)` measured with `metric`."""
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        # explicit class dimension: implicit-dim `softmax` is deprecated
        # (for 2-D logits the implicit choice was also dim=1)
        acc = metric(torch.nn.functional.softmax(y_hat, dim=1), y)
        return loss, acc

    def training_step(self, batch: Any, batch_idx: int):
        loss, acc = self._shared_step(batch, self.train_metric)
        self.log_dict(dict(loss=loss, acc=acc))
        return loss

    def validation_step(self, batch: Any, batch_idx: int):
        loss, acc = self._shared_step(batch, self.valid_metric)
        self.log_dict(dict(val_loss=loss, val_acc=acc))

    def configure_optimizers(self):
        # per-group lr / weight decay come from the model itself
        param_groups = self.model.build_param_dicts()
        opt = optim.AdamW(param_groups)
        sch = WarmupStepLR(
            opt,
            num_decays=2,
            warmup_epochs=1,
            decay_rate=0.1,
            epochs=self.trainer.max_epochs,
        )
        return [opt], [sch]

<IPython.core.display.Javascript object>

In [None]:
# hide
# cuda
# seed torch/numpy/python RNGs so model init, shuffling and augmentation are reproducible
SEED = 42
pl.seed_everything(SEED)

cbs = [
    NotebookProgressCallback(),
    pl.callbacks.LearningRateMonitor(logging_interval="epoch", log_momentum=True),
]

logger = pl.loggers.TensorBoardLogger(
    save_dir="lightning_logs/", name="my_model", default_hp_metric=False
)

trainer = pl.Trainer(max_epochs=7, callbacks=cbs, log_every_n_steps=1, logger=logger)

model = VisionTransformer.from_config_dict(C)
learn = Learner(model)

# dataloaders passed positionally: the train keyword was renamed across Lightning versions
trainer.fit(learn, train_dl, valid_dl)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
Freezing all the model parameters except for the classifier

  | Name         | Type              | Params
---------------------------------------------------
0 | model        | VisionTransformer | 48.0 M
1 | train_metric | Accuracy          | 0     
2 | valid_metric | Accuracy          | 0     
3 | loss_fn      | CrossEntropyLoss  | 0     
---------------------------------------------------
1.5 K     Trainable params
48.0 M    Non-trainable params
48.0 M    Total params
191.948   Total estimated model params size (MB)


epoch,val_loss,val_acc,loss,acc,time,samples/s
0,0.902881,0.48366,0.931735,0.45,94.0353,0.1382
1,0.86319,0.568627,0.902018,0.55,94.2788,0.1379
2,0.778777,0.522876,0.761685,0.55,99.1448,0.1311


1

<IPython.core.display.Javascript object>

## Export-

In [None]:
# hide
# write this notebook's `# export` cells to the library module named by `# default_exp`
notebook2script("04b_classification.model.meta_arch.vit.ipynb")

Converted 04b_classification.model.meta_arch.vit.ipynb.


<IPython.core.display.Javascript object>