In [1]:
import torch
import torch.nn as nn
import torchvision.models as models

from wrappers import get_dataset
from utils import plot_image

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Show the names of the model builders exposed by torchvision.models.
models.list_models()

['alexnet',
 'convnext_base',
 'convnext_large',
 'convnext_small',
 'convnext_tiny',
 'deeplabv3_mobilenet_v3_large',
 'deeplabv3_resnet101',
 'deeplabv3_resnet50',
 'densenet121',
 'densenet161',
 'densenet169',
 'densenet201',
 'efficientnet_b0',
 'efficientnet_b1',
 'efficientnet_b2',
 'efficientnet_b3',
 'efficientnet_b4',
 'efficientnet_b5',
 'efficientnet_b6',
 'efficientnet_b7',
 'efficientnet_v2_l',
 'efficientnet_v2_m',
 'efficientnet_v2_s',
 'fasterrcnn_mobilenet_v3_large_320_fpn',
 'fasterrcnn_mobilenet_v3_large_fpn',
 'fasterrcnn_resnet50_fpn',
 'fasterrcnn_resnet50_fpn_v2',
 'fcn_resnet101',
 'fcn_resnet50',
 'fcos_resnet50_fpn',
 'googlenet',
 'inception_v3',
 'keypointrcnn_resnet50_fpn',
 'lraspp_mobilenet_v3_large',
 'maskrcnn_resnet50_fpn',
 'maskrcnn_resnet50_fpn_v2',
 'maxvit_t',
 'mc3_18',
 'mnasnet0_5',
 'mnasnet0_75',
 'mnasnet1_0',
 'mnasnet1_3',
 'mobilenet_v2',
 'mobilenet_v3_large',
 'mobilenet_v3_small',
 'mvit_v1_b',
 'mvit_v2_s',
 'quantized_googlenet',
 '

In [3]:
class Classifier(nn.Module):
    """Pretrained torchvision backbone followed by a small MLP classification head.

    The network is stored in ``self.model`` (an ``nn.ModuleList``) as three stages
    that are applied in order:

    1. the preprocessing transform returned by ``weights.transforms()``,
    2. the pretrained backbone with its classification layer(s) replaced by
       ``nn.Identity`` so that it outputs feature vectors,
    3. a head ``Linear(n_features, n_hidden) -> ReLU -> Linear(n_hidden, n_classes)``.

    Args:
        backbone: Backbone name, matched case-insensitively against the keys of
            ``_BACKBONES``.
        n_classes: Number of outputs produced by the head.
        n_hidden: Width of the hidden layer of the head.

    Raises:
        NotImplementedError: If ``backbone`` is not one of the supported names.
    """

    # Supported backbones. The key is both the accepted (lower-cased) name and the
    # name of the builder function in ``torchvision.models``; the value is
    # (name of the weights enum, kind of classification head to strip).
    # Only names are stored so that ``models`` is resolved lazily, when a backbone
    # is actually loaded.
    _BACKBONES = {
        'resnet18': ('ResNet18_Weights', 'fc'),
        'resnet34': ('ResNet34_Weights', 'fc'),
        'resnet50': ('ResNet50_Weights', 'fc'),
        'resnet101': ('ResNet101_Weights', 'fc'),
        'resnet152': ('ResNet152_Weights', 'fc'),
        'mobilenet_v3_small': ('MobileNet_V3_Small_Weights', 'classifier'),
        'mobilenet_v3_large': ('MobileNet_V3_Large_Weights', 'classifier'),
        'maxvit_t': ('MaxVit_T_Weights', 'maxvit'),
    }

    def __init__(self,
                 backbone: str,
                 n_classes: int,
                 n_hidden: int = 128,
                 ):
        super().__init__()

        self.n_classes = n_classes
        self.n_features = None  # filled in by _setup_layers from the chosen backbone
        self.n_hidden = n_hidden

        self.model = None
        self._setup_layers(backbone)

    def _setup_layers(self, backbone):
        """Build preprocess, backbone and head, and chain them in ``self.model``."""
        self.n_features, preprocess, backbone = self._load_backbone(backbone)
        # The head reads self.n_features, so it must be created after the backbone.
        head = self._load_head()
        self.model = nn.ModuleList([
            preprocess,
            backbone,
            head,
        ])

    def forward(self, x):
        """Apply preprocessing, backbone and head in order; return the head output."""
        for layer in self.model:
            x = layer(x)
        return x

    def extract_features(self, x):
        """Apply every stage except the head; return the backbone features."""
        for layer in self.model[:-1]:
            x = layer(x)
        return x

    @staticmethod
    def _detach_head(model, head_kind):
        """Strip the classification head of ``model`` in place.

        Args:
            model: Backbone instance whose head is replaced by ``nn.Identity``.
            head_kind: ``'fc'`` (head is ``model.fc``), ``'classifier'`` (whole
                ``model.classifier`` is dropped) or ``'maxvit'`` (only entries
                3-5 of ``model.classifier`` are dropped).

        Returns:
            The ``in_features`` of the first removed linear layer, i.e. the size of
            the feature vector the stripped backbone now outputs.

        Raises:
            ValueError: If ``head_kind`` is not one of the kinds above.
        """
        if head_kind == 'fc':
            n_features = model.fc.in_features
            model.fc = nn.Identity()
        elif head_kind == 'classifier':
            n_features = model.classifier[0].in_features
            model.classifier = nn.Identity()
        elif head_kind == 'maxvit':
            # Keep classifier[0:3] and neutralise only the trailing layers.
            n_features = model.classifier[3].in_features
            for i in (3, 4, 5):
                model.classifier[i] = nn.Identity()
        else:
            raise ValueError(head_kind)
        return n_features

    @staticmethod
    def _load_backbone(backbone):
        """Load a pretrained backbone and strip its classification head.

        Args:
            backbone: Backbone name (case-insensitive); see ``_BACKBONES``.

        Returns:
            Tuple ``(n_features, preprocess, model)``: the feature size output by
            the stripped backbone, the transform from ``weights.transforms()``,
            and the backbone module itself.

        Raises:
            NotImplementedError: If ``backbone`` is not supported. The exception
                carries the name exactly as it was passed in.
        """
        name = backbone.lower()
        spec = Classifier._BACKBONES.get(name)
        if spec is None:
            raise NotImplementedError(backbone)

        weights_name, head_kind = spec
        weights = getattr(models, weights_name).IMAGENET1K_V1
        preprocess = weights.transforms()
        # The builder function in torchvision.models has the same name as the key.
        model = getattr(models, name)(weights=weights)
        n_features = Classifier._detach_head(model, head_kind)
        return n_features, preprocess, model

    def _load_head(self):
        """Return the head: ``Linear -> ReLU -> Linear`` ending in ``n_classes`` outputs."""
        head = nn.Sequential(
            nn.Linear(self.n_features, self.n_hidden),
            nn.ReLU(),
            nn.Linear(self.n_hidden, self.n_classes),
        )
        return head

In [4]:
# Mini-batch size handed to get_dataset.
batch_size = 32
# get_dataset is project-local (imported from wrappers); it returns three splits
# plus the class count.
# NOTE(review): the splits appear to be iterables yielding dict batches with
# "image" and "label" keys (that is how `train` is consumed further down) —
# confirm against wrappers.get_dataset.
train, validation, test, n_classes = get_dataset("letter_recognition", batch_size)

In [5]:
# Display the class count returned by get_dataset.
n_classes

26

In [6]:
# Build the classifier on a MobileNetV3-Small backbone; the head has a single
# 128-unit hidden layer and one output per class.
my_model = Classifier(backbone="mobilenet_v3_small", n_classes=n_classes, n_hidden=128)

In [7]:
# Display the module tree: preprocessing transform, backbone, head.
my_model

Classifier(
  (model): ModuleList(
    (0): ImageClassification(
        crop_size=[224]
        resize_size=[256]
        mean=[0.485, 0.456, 0.406]
        std=[0.229, 0.224, 0.225]
        interpolation=InterpolationMode.BILINEAR
    )
    (1): MobileNetV3(
      (features): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (1): InvertedResidual(
          (block): Sequential(
            (0): Conv2dNormActivation(
              (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=16, bias=False)
              (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): SqueezeExcitation(
              (avgpool): AdaptiveAvgPool2d(output_si

In [8]:
# Sanity-check the still-untrained model on one training batch.
# The pass runs in eval mode and under no_grad so that inspecting predictions
# neither updates the backbone's BatchNorm running statistics nor builds an
# autograd graph. The model's previous train/eval mode is restored afterwards.
# With 26 classes, chance-level accuracy is about 1/26 ≈ 0.038.
was_training = my_model.training
my_model.eval()
try:
    with torch.no_grad():
        for batch in train:
            x, y = batch["image"], batch["label"]
            pred = my_model(x)
            ypred = nn.functional.softmax(pred, dim=1)
            yhat = ypred.argmax(dim=1)  # predicted class per sample
            acc = (y == yhat).sum() / len(y)
            # (true label, predicted label) pairs, followed by batch accuracy.
            print([(a.item(), b.item()) for a, b in zip(y, yhat)])
            print(acc.item())
            break  # a single batch is enough for a smoke test
finally:
    my_model.train(was_training)

[(8, 13), (0, 13), (5, 19), (24, 13), (9, 7), (18, 19), (6, 13), (22, 13), (15, 10), (19, 19), (10, 13), (12, 19), (4, 13), (19, 6), (6, 7), (4, 13), (24, 13), (4, 19), (22, 19), (17, 10), (18, 13), (2, 7), (1, 19), (24, 14), (23, 10), (19, 10), (8, 10), (14, 7), (8, 7), (15, 10), (24, 13), (23, 7)]
0.03125
