In [None]:
# default_exp classification.modelling.meta_arch.common

In [None]:
# hide
%load_ext nb_black
%load_ext autoreload
%autoreload 2

<IPython.core.display.Javascript object>

In [None]:
# hide
import warnings

from nbdev.export import *
from nbdev.showdoc import *

# Silence every warning category so the rendered documentation stays clean.
# NOTE(review): this also hides deprecation warnings raised by dependencies;
# consider narrowing the filter (category/module) if those become relevant.
warnings.filterwarnings("ignore")

<IPython.core.display.Javascript object>

# Meta Architectures: Generalized Image Classifier
> Default Model Architectures for Image Classification

In [None]:
# export
import logging
from typing import *

import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.core.memory import get_human_readable_count
from torch.nn import Module

from gale.classification.modelling.backbones import ImageClassificationBackbone
from gale.classification.modelling.build import build_backbone, build_head
from gale.classification.modelling.heads import ImageClassificationHead
from gale.core.classes import GaleModule
from gale.core.nn.shape_spec import ShapeSpec

# Module-level logger; once this cell is exported by nbdev, `__name__` is the module's dotted name.
_logger = logging.getLogger(__name__)

<IPython.core.display.Javascript object>

In [None]:
# hide
from gale.core.utils.logger import setup_logger

setup_logger()
# Inside the notebook `__name__` is not the exported module name, so re-bind `_logger`
# to the module named in `# default_exp`; notebook logs then match the exported code's logs.
_logger = logging.getLogger("gale.classification.modelling.meta_arch.common")

<IPython.core.display.Javascript object>

In [None]:
# export
class GeneralizedImageClassifier(GaleModule):
    """
    A General Image Classifier. Any model that contains the following 2 components:
    1. Feature extractor (aka backbone)
    2. Image Classification head (Pooling + Classifier)
    """

    def __init__(
        self,
        backbone: ImageClassificationBackbone,
        head: ImageClassificationHead,
    ):
        """
        Arguments:
        1. `backbone`: an `ImageClassificationBackbone` module, must follow gale's backbone interface
        2. `head`: a head containing the classifier and the pooling layer, must be an instance of
        `ImageClassificationHead`.
        """
        super().__init__()
        # Validate *before* registering the sub-modules so an invalid argument is never
        # attached to the instance; the messages make a failure self-explanatory.
        assert isinstance(backbone, ImageClassificationBackbone), (
            f"`backbone` must be an ImageClassificationBackbone, got {type(backbone).__name__}"
        )
        assert isinstance(head, ImageClassificationHead), (
            f"`head` must be an ImageClassificationHead, got {type(head).__name__}"
        )
        self.backbone = backbone
        self.head = head

    def forward(self, batched_inputs: torch.Tensor) -> torch.Tensor:
        """
        Runs the batched_inputs through `backbone` followed by the `head`.
        Returns a Tensor which contains the logits for the batched_inputs.
        """
        features = self.backbone(batched_inputs)
        return self.head(features)

    @staticmethod
    def _readable_param_count(module: Module) -> str:
        """Human readable count of every parameter in `module` (e.g. "11.2 M")."""
        return get_human_readable_count(sum(p.numel() for p in module.parameters()))

    @classmethod
    def from_config_dict(cls, cfg: DictConfig):
        """
        Instantiate the Meta Architecture from gale config.

        `cfg.input` must provide `channels`, `height` and `width`; `cfg.model` must
        provide the `backbone` and `head` nodes consumed by `build_backbone` / `build_head`.
        Raises `ValueError` if `cfg.model.backbone` or `cfg.model.head` is missing.
        """
        # fmt: off
        # Membership test instead of `hasattr`: some OmegaConf releases (e.g. 2.0) return `None`
        # for a missing key of a non-struct DictConfig, so `hasattr` would wrongly report it present.
        if "backbone" not in cfg.model:
            _logger.error("Configuration for model backbone not found")
            raise ValueError("Configuration for model backbone not found (expected `cfg.model.backbone`)")

        if "head" not in cfg.model:
            _logger.error("Configuration for model head not found")
            raise ValueError("Configuration for model head not found (expected `cfg.model.head`)")

        input_shape = ShapeSpec(cfg.input.channels, cfg.input.height, cfg.input.width)
        _logger.debug(f"Inputs: {input_shape}")

        backbone = build_backbone(cfg, input_shape=input_shape)
        _logger.debug('Backbone {} created, param count: {}.'.format(cfg.model.backbone.name, cls._readable_param_count(backbone)))

        # The head is built from the backbone's reported output shape.
        head = build_head(cfg, backbone.output_shape())
        _logger.debug('Head {} created, param count: {}.'.format(cfg.model.head.name, cls._readable_param_count(head)))

        instance = cls(backbone=backbone, head=head)
        # Keep a resolved, plain-container copy of the model config on the instance.
        instance._cfg = OmegaConf.to_container(cfg.model, resolve=True)
        instance.input_shape = input_shape
        _logger.info("Model created, param count: {}.".format(cls._readable_param_count(instance)))
        # fmt: on
        return instance

    def build_param_dicts(self):
        """
        Builds up the parameter dicts for optimization: the backbone's parameter
        groups followed by the head's parameter groups.
        """
        # NOTE(review): assumes both `build_param_dicts` return lists, since they are joined with `+`.
        return self.backbone.build_param_dicts() + self.head.build_param_dicts()

    def get_lrs(self) -> List:
        """
        Returns a List containing the lr of each parameter group, in the order
        produced by `build_param_dicts`. This is required to build schedulers
        like `torch.optim.lr_scheduler.OneCycleLR`, which need the max lr of
        every param group.
        """
        return [param_group["lr"] for param_group in self.build_param_dicts()]

<IPython.core.display.Javascript object>

In [None]:
# Render the documentation for the class and its constructor.
show_doc(GeneralizedImageClassifier)
show_doc(GeneralizedImageClassifier.__init__)

<h2 id="GeneralizedImageClassifier" class="doc_header"><code>class</code> <code>GeneralizedImageClassifier</code><a href="" class="source_link" style="float:right">[source]</a></h2>

> <code>GeneralizedImageClassifier</code>(**`backbone`**:[`ImageClassificationBackbone`](/gale/classification.modelling.backbones.html#ImageClassificationBackbone), **`head`**:[`ImageClassificationHead`](/gale/classification.modelling.heads.html#ImageClassificationHead)) :: [`GaleModule`](/gale/core.classes.html#GaleModule)

A General Image Classifier. Any model that contains the following 2 components:
1. Feature extractor (aka backbone)
2. Image Classification head (Pooling + Classifier)

<h4 id="GeneralizedImageClassifier.__init__" class="doc_header"><code>GeneralizedImageClassifier.__init__</code><a href="__main__.py#L9" class="source_link" style="float:right">[source]</a></h4>

> <code>GeneralizedImageClassifier.__init__</code>(**`backbone`**:[`ImageClassificationBackbone`](/gale/classification.modelling.backbones.html#ImageClassificationBackbone), **`head`**:[`ImageClassificationHead`](/gale/classification.modelling.heads.html#ImageClassificationHead))

Arguments:
1. `backbone`: a [`ImageClassificationBackbone`](/gale/classification.modelling.backbones.html#ImageClassificationBackbone) module, must follow gale's backbone interface
2. `head`: a head containing the classifier and the pooling layer, must be an instance of
[`ImageClassificationHead`](/gale/classification.modelling.heads.html#ImageClassificationHead).

<IPython.core.display.Javascript object>

In [None]:
# Render the documentation for the config-driven constructor.
show_doc(GeneralizedImageClassifier.from_config_dict)

<h4 id="GeneralizedImageClassifier.from_config_dict" class="doc_header"><code>GeneralizedImageClassifier.from_config_dict</code><a href="__main__.py#L35" class="source_link" style="float:right">[source]</a></h4>

> <code>GeneralizedImageClassifier.from_config_dict</code>(**`cfg`**:`DictConfig`)

Instantiate the Meta Architecture from gale config

<IPython.core.display.Javascript object>

In [None]:
# Render the documentation for the forward pass.
show_doc(GeneralizedImageClassifier.forward)

<h4 id="GeneralizedImageClassifier.forward" class="doc_header"><code>GeneralizedImageClassifier.forward</code><a href="__main__.py#L26" class="source_link" style="float:right">[source]</a></h4>

> <code>GeneralizedImageClassifier.forward</code>(**`batched_inputs`**:`Tensor`)

Runs the batched_inputs through `backbone` followed by the `head`.
Returns a Tensor which contains the logits for the batched_inputs.

<IPython.core.display.Javascript object>

In [None]:
# Render the documentation for the optimizer-related helpers
# (`get_lrs` was previously missing from the rendered docs).
show_doc(GeneralizedImageClassifier.build_param_dicts)
show_doc(GeneralizedImageClassifier.get_lrs)

<h4 id="GeneralizedImageClassifier.build_param_dicts" class="doc_header"><code>GeneralizedImageClassifier.build_param_dicts</code><a href="__main__.py#L71" class="source_link" style="float:right">[source]</a></h4>

> <code>GeneralizedImageClassifier.build_param_dicts</code>()

Builds up the parameter dicts for optimization.

<IPython.core.display.Javascript object>

> Note: Any Custom Meta Architecture that you build should inherit from `GaleModule` and must be registered in `META_ARCH_REGISTRY`.

## Instantiation via config

In [None]:
from dataclasses import dataclass, field
from omegaconf import MISSING, OmegaConf

<IPython.core.display.Javascript object>

Dataclasses used for config creation:

In [None]:
@dataclass
class BackboneConf:
    """Default `init_args` for the `ResNetBackbone` config assembled in this notebook."""

    model_name: str = field(default="resnet18")
    act: Any = field(default=None)
    lr: Any = field(default=0.001)
    lr_div: Any = field(default=10)
    wd: Any = field(default=0.0)
    freeze_at: int = field(default=2)
    pretrained: bool = field(default=True)
    drop_block_rate: Any = field(default=None)
    drop_path_rate: Any = field(default=None)
    bn_tf: bool = field(default=False)


@dataclass
class HeadConf:
    """Default `init_args` for the `FastaiHead` config assembled in this notebook.

    `num_classes` defaults to OmegaConf's `MISSING` marker, so it must be supplied.
    """

    num_classes: int = MISSING
    act: str = "ReLU"
    lin_ftrs: Any = None
    ps: Any = 0.5
    concat_pool: bool = True
    first_bn: bool = True
    bn_final: bool = False
    lr: float = 0.002
    wd: float = 0.0  # float literal to match the `float` annotation (was the int `0`)
    filter_wd: bool = False

<IPython.core.display.Javascript object>

In [None]:
# Descriptive names keep each config node's role explicit (and avoid reusing short
# names such as `m` for unrelated objects elsewhere in the notebook).
backbone_args = OmegaConf.structured(BackboneConf())
head_args = OmegaConf.structured(HeadConf(num_classes=10))

# Backbone config
backbone_cfg = OmegaConf.create()
backbone_cfg.name = "ResNetBackbone"
backbone_cfg.init_args = backbone_args

# Head config
head_cfg = OmegaConf.create()
head_cfg.name = "FastaiHead"
head_cfg.init_args = head_args

# Input config
input_cfg = OmegaConf.create()
input_cfg.channels = 3
input_cfg.height = 224
input_cfg.width = 224

# Model config
model_cfg = OmegaConf.create()
model_cfg.backbone = backbone_cfg
model_cfg.head = head_cfg

# Full config, built in one expression instead of reassigning `conf` to itself.
conf = OmegaConf.structured(OmegaConf.create(dict(input=input_cfg, model=model_cfg)))

<IPython.core.display.Javascript object>

In [None]:
model = GeneralizedImageClassifier.from_config_dict(conf)

# Smoke-test a forward pass on a random batch shaped from the model's input spec.
inputs = torch.randn(2, model.input_shape.channels, model.input_shape.height, model.input_shape.width)
with torch.no_grad():  # demo inference only; no autograd graph is needed
    logits = model(inputs)

assert logits.shape[0] == inputs.shape[0], "the batch dimension should be preserved"

[32m[04/25 22:32:16 gale.classification.modelling.meta_arch.common]: [0mInputs: ShapeSpec(channels=3, height=224, width=224)
[32m[04/25 22:32:16 gale.classification.modelling.meta_arch.common]: [0mBackbone ResNetBackbone created, param count: 11.2 M.
[32m[04/25 22:32:16 gale.classification.modelling.meta_arch.common]: [0mHead FastaiHead created, param count: 532 K.
[32m[04/25 22:32:16 gale.classification.modelling.meta_arch.common]: [0mModel created, param count: 11.7 M.


<IPython.core.display.Javascript object>

## Export

In [None]:
# hide
# Convert the `# export` cells of the project's notebooks into library modules (nbdev).
notebook2script()

Converted 00_core.utils.logger.ipynb.
Converted 00a_core.utils.visualize.ipynb.
Converted 00b_core.utils.structures.ipynb.
Converted 01_core.nn.utils.ipynb.
Converted 01a_core.nn.losses.ipynb.
Converted 02_core.nn.optim.optimizers.ipynb.
Converted 02a_core.nn.optim.lr_schedulers.ipynb.
Converted 03_core.classes.ipynb.
Converted 04_classification.modelling.backbones.ipynb.
Converted 04a_classification.modelling.heads.ipynb.
Converted 04b_classification.modelling.meta_arch.common.ipynb.
Converted 04b_classification.modelling.meta_arch.vit.ipynb.
Converted 05_classification.data.common.ipynb.
Converted 05_classification.data.transforms.ipynb.
Converted 06_collections.pandas.ipynb.
Converted 06a_collections.callbacks.notebook.ipynb.
Converted 06b_collections.callbacks.ema.ipynb.
Converted index.ipynb.


<IPython.core.display.Javascript object>