## Dataset Manager

In [None]:
import os
import random
from PIL import Image
from typing import Optional, Union

import torch
from torchvision import transforms

class DataManager:
    """Index a multi-domain image dataset and draw random (image, label) samples from it.

    Layout read under ``root_dir``:

    * every sub-directory other than ``train``/``test`` is a domain, and each of its
      sub-directories is counted as a class;
    * ``<split>/<domain>_<split>.txt`` (split in ``train``/``test``) lists one sample per
      line as ``<image path relative to root_dir> <integer label>``.
    """

    def __init__(self, root_dir: str = "./data", size: tuple[int, int] = (256, 256)) -> None:
        """Discover the domains and classes and load every split file.

        Prints a short summary: domains found, classes missing from some domains and the
        total class count.

        Args:
            root_dir: Dataset root directory.
            size: Target size passed to ``Image.resize`` for every sampled image.

        Raises:
            FileNotFoundError: If ``root_dir`` or a ``<split>/<domain>_<split>.txt`` file is missing.
            ValueError: If a split-file line has no integer label after the image path.
        """
        self.root_dir = root_dir
        self.size = size
        # Converts PIL images to tensors when ``sample(..., return_tensors=True)``.
        self.transforms = transforms.ToTensor()

        # Only directories count as domains (stray files such as README or .DS_Store are
        # skipped). Sorting makes the domain order independent of the filesystem listing order.
        self.domains = sorted(
            entry
            for entry in os.listdir(root_dir)
            if entry not in ("train", "test") and os.path.isdir(os.path.join(root_dir, entry))
        )
        print(f"{len(self.domains)} domains found: {self.domains}")

        # class name -> number of domains that contain a directory for that class
        domains_per_class: dict[str, int] = {}
        # split -> domain -> list of [relative image path, label]
        self.data = {"train": {}, "test": {}}
        for domain in self.domains:
            domain_dir = os.path.join(root_dir, domain)
            for class_name in os.listdir(domain_dir):
                if os.path.isdir(os.path.join(domain_dir, class_name)):
                    domains_per_class[class_name] = domains_per_class.get(class_name, 0) + 1
            for split in self.data:
                self.data[split][domain] = self._read_split_file(domain, split)

        self._report_class_coverage(domains_per_class)

    def _read_split_file(self, domain: str, split: str) -> list[list]:
        """Parse ``<root_dir>/<split>/<domain>_<split>.txt`` into ``[relative_path, label]`` pairs."""
        split_file = os.path.join(self.root_dir, split, f"{domain}_{split}.txt")
        entries = []
        with open(split_file, encoding="utf-8") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue  # tolerate blank / trailing lines
                # rsplit keeps an image path that contains spaces intact; the label is the last token.
                img_path, label = line.rsplit(maxsplit=1)
                entries.append([img_path, int(label)])
        return entries

    def _report_class_coverage(self, domains_per_class: dict[str, int]) -> None:
        """Print the classes that are missing from some domains, then the total class count."""
        incomplete = 0
        for class_name, count in domains_per_class.items():
            if count != len(self.domains):
                print(f"class {class_name} only found in {count} domains")
                incomplete += 1

        if incomplete:
            print(f"{incomplete} classes are in minority across the domains")
        else:
            print("All classes are present across all the domains")

        print(f"Total number of classes: {len(domains_per_class)}")

    def sample(self, domain: str, split: str = "train", batch_size: int = 1, return_tensors: bool = False) -> Union[
        Optional[tuple[Image.Image, int]],
        Optional[tuple[list[Image.Image], list[int]]],
        Optional[tuple[torch.Tensor, torch.Tensor]]
    ]:
        """Randomly sample (with replacement) images and labels from one domain and split.

        Args:
            domain: Domain to sample from; must be one of ``self.domains``.
            split: Dataset split, ``"train"`` or ``"test"``. Defaults to ``"train"``.
            batch_size: Number of samples to draw. Defaults to 1.
            return_tensors: If True, return a stacked image tensor and a label tensor instead
                of PIL images and ints. Defaults to False.

        Returns:
            ``None`` (after printing a message) if ``domain`` is unknown. Otherwise:

            * ``return_tensors=True``: ``(images, labels)`` with the per-image
              ``ToTensor`` outputs stacked along a new leading dimension, and a label tensor;
            * ``batch_size == 1``: a single ``(PIL image, label)`` pair;
            * otherwise: ``(list of PIL images, list of labels)``.
        """
        if domain not in self.domains:
            print(f"Domain: {domain} not found in the dataset, available domains are: {self.domains}")
            return None

        # random.choices samples with replacement, so the same item can appear twice in a batch.
        samples = random.choices(self.data[split][domain], k=batch_size)

        images = []
        labels = []
        for img_path, label in samples:
            # Close the file handle deterministically. convert() returns a new image, so the
            # result stays valid after the source image is closed.
            with Image.open(os.path.join(self.root_dir, img_path)) as img:
                images.append(img.convert("RGB").resize(self.size))
            labels.append(label)

        if return_tensors:
            return torch.stack([self.transforms(img) for img in images]), torch.tensor(labels)
        if batch_size == 1:
            return images[0], labels[0]
        return images, labels

# Index the dataset under the default root with the default 256x256 resize target.
data_manager = DataManager(root_dir="./data", size=(256, 256))

In [None]:
# Draw 10 random training samples from the "real" domain as tensors: img holds the images, idx the labels.
img, idx = data_manager.sample("real", batch_size=10, return_tensors=True)

## ResNet

In [None]:
import torch
import torch.nn as nn
from collections import OrderedDict

class ResidualBlock(nn.Module):
    """Generic residual wrapper computing ``blocks(x) + shortcut-or-identity(x)``.

    Subclasses replace ``self.blocks`` (the main path) and ``self.shortcut`` (the
    projection used when ``should_apply_shortcut`` is True). Both default to ``nn.Identity``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.blocks = nn.Identity()
        self.shortcut = nn.Identity()

    def forward(self, x):
        residual = x
        if self.should_apply_shortcut:
            residual = self.shortcut(x)
        out = self.blocks(x)
        # Out-of-place add: when ``blocks`` returns its input unchanged (e.g. nn.Identity),
        # an in-place ``+=`` would overwrite the caller's tensor (and ``residual``, which aliases it).
        return out + residual

    @property
    def should_apply_shortcut(self):
        """True when input and output channel counts differ, so the shortcut must project."""
        return self.in_channels != self.out_channels

class ResNetResidualBlock(ResidualBlock):
    """Residual block whose output width is ``out_channels * expansion``.

    When the input width differs from that expanded width, the shortcut becomes a
    1x1 convolution (stride ``downsampling``, no bias) followed by BatchNorm.
    Otherwise the shortcut is set to ``None`` and the identity is used.

    ``conv`` is the convolution factory that subclasses use to build their main path.
    NOTE(review): the default ``conv3x3`` is not defined in this notebook; it must
    come from an earlier cell or module (confirm it exists before this class is defined).
    """

    def __init__(self, in_channels, out_channels, expansion=1, downsampling=1, conv=conv3x3, *args, **kwargs):
        super().__init__(in_channels, out_channels)
        self.expansion = expansion
        self.downsampling = downsampling
        self.conv = conv

        if not self.should_apply_shortcut:
            self.shortcut = None
        else:
            projection = OrderedDict()
            projection['conv'] = nn.Conv2d(self.in_channels, self.expanded_channels, kernel_size=1,
                                           stride=self.downsampling, bias=False)
            projection['bn'] = nn.BatchNorm2d(self.expanded_channels)
            self.shortcut = nn.Sequential(projection)

    @property
    def expanded_channels(self):
        """Channel count produced by the block: ``out_channels * expansion``."""
        return self.expansion * self.out_channels

    @property
    def should_apply_shortcut(self):
        """True when the input width differs from the expanded output width."""
        return self.expanded_channels != self.in_channels
        
def conv_bn(in_channels, out_channels, conv, *args, **kwargs):
    """Return ``conv(in_channels, out_channels, ...)`` followed by ``BatchNorm2d(out_channels)``.

    The result is an ``nn.Sequential`` whose children are named ``'conv'`` and ``'bn'``.
    Extra positional/keyword arguments are forwarded to ``conv``.
    """
    layers = OrderedDict()
    layers['conv'] = conv(in_channels, out_channels, *args, **kwargs)
    layers['bn'] = nn.BatchNorm2d(out_channels)
    return nn.Sequential(layers)

class ResNetBasicBlock(ResNetResidualBlock):
    """Basic residual block: conv+BN -> activation -> conv+BN, with expansion factor 1.

    The first convolution carries the stride (``downsampling``). Both convolutions are
    built with the ``conv`` factory inherited from ``ResNetResidualBlock`` and ``bias=False``.
    """

    expansion = 1

    def __init__(self, in_channels, out_channels, activation=nn.ReLU, *args, **kwargs):
        super().__init__(in_channels, out_channels, *args, **kwargs)
        first = conv_bn(self.in_channels, self.out_channels, conv=self.conv,
                        bias=False, stride=self.downsampling)
        act = activation()
        second = conv_bn(self.out_channels, self.expanded_channels, conv=self.conv, bias=False)
        self.blocks = nn.Sequential(first, act, second)
    

class ResNetBottleNeckBlock(ResNetResidualBlock):
    """Bottleneck residual block with expansion factor 4.

    Main path: kernel-1 conv+BN -> activation -> kernel-3 conv+BN (carries the stride)
    -> activation -> kernel-1 conv+BN widening to ``out_channels * 4``.
    """

    expansion = 4

    def __init__(self, in_channels, out_channels, activation=nn.ReLU, *args, **kwargs):
        super().__init__(in_channels, out_channels, expansion=4, *args, **kwargs)
        squeeze = conv_bn(self.in_channels, self.out_channels, self.conv, kernel_size=1)
        squeeze_act = activation()
        spatial = conv_bn(self.out_channels, self.out_channels, self.conv,
                          kernel_size=3, stride=self.downsampling)
        spatial_act = activation()
        widen = conv_bn(self.out_channels, self.expanded_channels, self.conv, kernel_size=1)
        self.blocks = nn.Sequential(squeeze, squeeze_act, spatial, spatial_act, widen)
    

class ResNetLayer(nn.Module):
    """One ResNet stage: ``n`` residual blocks of type ``block`` applied in sequence."""

    def __init__(self, in_channels, out_channels, block=ResNetBasicBlock, n=1, *args, **kwargs):
        super().__init__()
        # Only the first block may change the channel count; when it does, it uses stride 2.
        stride = 1 if in_channels == out_channels else 2
        head = block(in_channels, out_channels, *args, **kwargs, downsampling=stride)
        # Subsequent blocks consume the previous block's expanded output at stride 1.
        tail = [
            block(out_channels * block.expansion, out_channels, downsampling=1, *args, **kwargs)
            for _ in range(n - 1)
        ]
        self.blocks = nn.Sequential(head, *tail)

    def forward(self, x):
        return self.blocks(x)

class ResNetEncoder(nn.Module):
    """Feature extractor: a stem ("gate") followed by one ``ResNetLayer`` per stage.

    Args:
        in_channels: Number of input image channels.
        blocks_sizes: Base channel width of each stage.
        deepths: Number of residual blocks in each stage (historical spelling, kept
            for existing callers).
        activation: Activation module class used throughout.
        block: Residual block class (e.g. ``ResNetBasicBlock``).
        depths: Correctly spelled alias for ``deepths``; takes precedence when given.
            Previously a ``depths=`` argument fell into ``**kwargs`` and was silently ignored.

    Raises:
        ValueError: If the number of depths differs from the number of stage widths.
    """

    def __init__(self, in_channels=3, blocks_sizes=(64, 128, 256, 512), deepths=(2, 2, 2, 2),
                 activation=nn.ReLU, block=ResNetBasicBlock, *args, depths=None, **kwargs):
        super().__init__()
        if depths is not None:
            deepths = depths
        # Copy into fresh lists: tuple defaults avoid shared mutable defaults, and callers'
        # sequences are never aliased.
        blocks_sizes = list(blocks_sizes)
        deepths = list(deepths)
        if len(deepths) != len(blocks_sizes):
            # zip() below would otherwise silently drop the extra stages.
            raise ValueError(
                f"got {len(deepths)} depths for {len(blocks_sizes)} stage widths; they must match"
            )

        self.blocks_sizes = blocks_sizes

        self.gate = nn.Sequential(
            nn.Conv2d(in_channels, self.blocks_sizes[0], kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(self.blocks_sizes[0]),
            activation(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        self.in_out_block_sizes = list(zip(blocks_sizes, blocks_sizes[1:]))
        stages = [
            ResNetLayer(blocks_sizes[0], blocks_sizes[0], n=deepths[0], activation=activation,
                        block=block, *args, **kwargs)
        ]
        for (stage_in, stage_out), n in zip(self.in_out_block_sizes, deepths[1:]):
            # The previous stage emits stage_in * expansion channels.
            stages.append(
                ResNetLayer(stage_in * block.expansion, stage_out, n=n, activation=activation,
                            block=block, *args, **kwargs)
            )
        self.blocks = nn.ModuleList(stages)

    def forward(self, x):
        x = self.gate(x)
        for stage in self.blocks:
            x = stage(x)
        return x

class ResnetDecoder(nn.Module):
    """Classification head: global average pooling to 1x1, flatten, then a linear layer."""

    def __init__(self, in_features, n_classes):
        super().__init__()
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.decoder = nn.Linear(in_features, n_classes)

    def forward(self, x):
        pooled = self.avg(x)
        batch = pooled.size(0)
        flat = pooled.view(batch, -1)
        return self.decoder(flat)

class ResNet(nn.Module):
    """ResNet classifier: a ``ResNetEncoder`` feature extractor followed by a ``ResnetDecoder`` head."""

    def __init__(self, in_channels, n_classes, *args, **kwargs):
        """Build the encoder/head pair.

        Args:
            in_channels: Number of channels of the input images.
            n_classes: Number of output logits produced by the head.
            *args, **kwargs: Forwarded unchanged to ``ResNetEncoder`` (e.g. ``block``, ``deepths``).
        """
        super().__init__()
        self.n_classes = n_classes
        self.in_channels = in_channels
        self.encoder = ResNetEncoder(in_channels, *args, **kwargs)
        # The head's input width is the channel count emitted by the last residual block.
        self.decoder = ResnetDecoder(self.encoder.blocks[-1].blocks[-1].expanded_channels, n_classes)

    def forward(self, x):
        """Return class logits for the input batch ``x``."""
        x = self.encoder(x)
        x = self.decoder(x)
        return x

    @staticmethod
    def load_model(model_name: str, n_classes: Optional[int] = None, in_channels: int = 3) -> 'ResNet':
        """Build a named ResNet variant, or rebuild one from a checkpoint written by ``save_model``.

        Args:
            model_name: One of ``resnet18/34/50/101/152`` (case-insensitive), or a path ending
                in ``.pth``. For a checkpoint, the architecture, ``n_classes`` and ``in_channels``
                come from the file and the saved weights are loaded.
            n_classes: Number of output classes; required unless loading a checkpoint.
            in_channels: Number of input channels (ignored when loading a checkpoint).

        Raises:
            ValueError: If the architecture is unknown or ``n_classes`` is missing.
        """
        model_configs = {
            "resnet18": {"block": ResNetBasicBlock, "depths": [2, 2, 2, 2]},
            "resnet34": {"block": ResNetBasicBlock, "depths": [3, 4, 6, 3]},
            "resnet50": {"block": ResNetBottleNeckBlock, "depths": [3, 4, 6, 3]},
            "resnet101": {"block": ResNetBottleNeckBlock, "depths": [3, 4, 23, 3]},
            "resnet152": {"block": ResNetBottleNeckBlock, "depths": [3, 8, 36, 3]}
        }

        # Remember the checkpoint itself: model_name is overwritten below, so a second
        # ``.pth`` test on model_name would never fire and the weights would never be loaded.
        ckpt = None
        if model_name.endswith(".pth"):
            # NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted files.
            # map_location="cpu" lets checkpoints saved on GPU load on CPU-only machines;
            # load_state_dict copies the tensors into the freshly built (CPU) model anyway.
            ckpt = torch.load(model_name, map_location="cpu", weights_only=False)
            model_name = ckpt["model_type"]
            n_classes = ckpt["n_classes"]
            in_channels = ckpt["in_channels"]

        config = model_configs.get(model_name.lower())
        if config is None:
            raise ValueError(f"{model_name} not implemented, available models are: {list(model_configs.keys())}")
        if n_classes is None:
            raise ValueError("n_classes must be provided when building a model by name")

        # ResNetEncoder spells its depth parameter ``deepths``; a ``depths=`` keyword would
        # fall into **kwargs and every variant would silently get the default [2, 2, 2, 2].
        # NOTE: checkpoints written while that bug was present have the default depths and
        # will not match the correct architecture.
        model = ResNet(in_channels, n_classes, block=config["block"], deepths=config["depths"])
        model.model_name = model_name

        if ckpt is not None:
            model.load_state_dict(ckpt["state_dict"])

        return model

    def save_model(self, path: str) -> None:
        """Write architecture name, class/channel counts and weights to ``path`` for ``load_model``.

        Raises:
            AttributeError: If the model has no ``model_name`` (it was not built via ``load_model``).
        """
        model_name = getattr(self, "model_name", None)
        if model_name is None:
            raise AttributeError(
                "model_name is not set; build the model with ResNet.load_model or assign model_name before saving"
            )
        ckpt = {
            "model_type": model_name,
            "n_classes": self.n_classes,
            "in_channels": self.in_channels,
            "state_dict": self.state_dict()
        }
        torch.save(ckpt, path)

## Base Model Training

In [2]:
from utils import DataManager
from models import ResNet

In [7]:
data_manager = DataManager(root_dir="./data")

# n_classes should match the class total that DataManager prints for this dataset.
model = ResNet.load_model("resnet50", n_classes=345, in_channels=3)

5 domains found: ['real', 'quickdraw', 'clipart', 'painting', 'sketch']
All classes are present across all the domains
Total number of classes: 345
