In [None]:
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, random_split
import kornia as K
import pytorch_lightning as pl
import torch.nn as nn
import torch
from PIL import Image
import numpy as np
from torchvision.models import resnet18, ResNet18_Weights
import torch.nn.init as init
import torchmetrics
import torch.optim as optim
import yaml
from pytorch_lightning.callbacks import ModelCheckpoint


In [None]:
# Install on a fresh environment with: %pip install -q -U wandb
import wandb
from pytorch_lightning.loggers import WandbLogger
# Authenticate with Weights & Biases (may prompt for an API key); main() also calls wandb.login().
wandb.login()

# Data

In [2]:
class PreProcess(nn.Module):
    """Per-sample conversion of a PIL image (or numpy array) to a channels-first float tensor.

    Used as the ``transform`` of the CIFAR-10 datasets. Pixel values are divided by
    ``max_val``, so 8-bit images end up in [0, 1]. No random augmentation is done here.

    Args:
        max_val: input value that maps to 1.0 (255 for 8-bit images).
    """

    def __init__(self, max_val: float = 255.0) -> None:
        super().__init__()
        self._max_val: float = max_val  # divisor that scales pixel values into [0, 1]

    @torch.no_grad()  # pure data conversion; no autograd graph is needed
    def forward(self, x: "Image.Image") -> torch.Tensor:
        """Convert ``x`` to a float32 tensor with channels first, divided by ``max_val``.

        Shapes: HxW -> 1xHxW, HxWxC -> CxHxW, BxHxWxC -> BxCxHxW.

        Raises:
            ValueError: if ``x`` does not convert to a 2-, 3- or 4-D array.
        """
        # np.array makes a fresh writable copy, so torch.from_numpy can share it safely.
        array = np.array(x)
        tensor = torch.from_numpy(array)
        if array.ndim == 2:
            tensor = tensor.unsqueeze(0)
        elif array.ndim == 3:
            tensor = tensor.permute(2, 0, 1)
        elif array.ndim == 4:
            tensor = tensor.permute(0, 3, 1, 2)
        else:
            raise ValueError(f"expected a 2-, 3- or 4-D image, got shape {array.shape}")
        # Scale with a real assignment (an annotation such as `t: expr` is never executed).
        return tensor.float().div(self._max_val)
    
# Data Module for Lightning
class CIFAR10DataModule(pl.LightningDataModule):
    """LightningDataModule for CIFAR-10 with a reproducible train/validation split.

    Args:
        batch_size: training batch size; validation and test loaders use 4x this.
        data_dir: directory where torchvision stores/downloads CIFAR-10.
        num_workers: worker processes per DataLoader.
        val_size: number of training images held out for validation (5000 -> 45k/5k split).
        split_seed: seed of the generator used by ``random_split``, so every ``setup``
            call (and every run) yields the same partition regardless of global RNG state.
    """

    def __init__(self, batch_size=32,
                 data_dir="Python/Lightning_Resnet19/DataModules/datasets/CIFAR10",
                 num_workers=7, val_size=5000, split_seed=1):
        super().__init__()
        self.path = data_dir
        self.classes = 10  # number of CIFAR-10 labels; main() reads this as num_classes
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_size = val_size
        self.split_seed = split_seed
        # Per-sample conversion only (PIL -> float tensor); random flips are applied
        # batch-wise by the LightningModule.
        self.transform = PreProcess()

    def prepare_data(self):
        # Download both splits if missing; no state is assigned here.
        CIFAR10(root=self.path, train=True, download=True)
        CIFAR10(root=self.path, train=False, download=True)

    def setup(self, stage=None):
        # "validate" needs the held-out split too (e.g. trainer.validate(..., datamodule=dm)).
        if stage in ('fit', 'validate', None):
            cifar10_full = CIFAR10(root=self.path, train=True, transform=self.transform)
            train_size = len(cifar10_full) - self.val_size
            self.train_dataset, self.val_dataset = random_split(
                cifar10_full,
                [train_size, self.val_size],
                generator=torch.Generator().manual_seed(self.split_seed),
            )

        if stage in ('test', None):
            self.test_dataset = CIFAR10(root=self.path, train=False, transform=self.transform)

    def _make_loader(self, dataset, batch_size, shuffle=False, persistent=True):
        """Shared DataLoader factory so all loaders use the same worker settings."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            # DataLoader rejects persistent_workers=True when num_workers == 0.
            persistent_workers=persistent and self.num_workers > 0,
            pin_memory=True,
        )

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, self.batch_size, shuffle=True)

    def val_dataloader(self):
        # No gradients during evaluation, so larger batches are affordable.
        return self._make_loader(self.val_dataset, self.batch_size * 4)

    def test_dataloader(self):
        # The test loop runs once; keeping workers alive between epochs buys nothing.
        return self._make_loader(self.test_dataset, self.batch_size * 4, persistent=False)


# Models

In [2]:
def initialize_weights(m):
    """He/constant initialisation for conv, batch-norm and linear layers.

    Intended for ``model.apply(initialize_weights)``: it is called once per submodule
    and is a no-op for module types not listed below.

    - ``nn.Conv2d`` / ``nn.Linear``: He-normal weights (``nonlinearity='relu'``), zero bias.
    - ``nn.BatchNorm2d``: weight 1, bias 0 (identity affine transform).

    Layers built with ``bias=False`` or ``affine=False`` have ``None`` parameters;
    those are skipped instead of raising.

    Args:
        m: module to initialise in place.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        init.kaiming_normal_(m.weight, nonlinearity='relu')
        if m.bias is not None:
            init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        if m.weight is not None:
            init.ones_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)

# ResNet19_fc model
class Resnet19_fc(nn.Module):
    """Pretrained torchvision ResNet-18 whose classifier is replaced by a two-layer MLP.

    The new head is ``Linear(F, F // 2) -> ReLU -> Dropout(0.2) -> Linear(F // 2, num_classes)``,
    where ``F`` is the backbone's original ``fc.in_features``.

    Args:
        num_classes: size of the output logit vector.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.epoch = 0  # assigned here but not read anywhere in the visible code
        self.resnet_base = resnet18(weights=ResNet18_Weights.DEFAULT)
        in_dim = self.resnet_base.fc.in_features
        hidden_dim = in_dim // 2
        head_layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(hidden_dim, num_classes),
        ]
        self.resnet_base.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        return self.resnet_base(x)

# ResNet19_conv model
class Resnet19_conv(nn.Module):
    """Pretrained torchvision ResNet-18 with one extra conv stage and a linear classifier.

    A shape-preserving ``Conv2d(512, 512, 3, padding=1, bias=False) -> BatchNorm2d -> ReLU``
    is appended to the end of ``layer4`` under the name ``additional_conv_layer``, and the
    stock ``fc`` is replaced by ``nn.Linear(in_features, num_classes)``.

    Args:
        num_classes: size of the output logit vector.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.epoch = 0  # assigned here but not read anywhere in the visible code
        base = resnet18(weights=ResNet18_Weights.DEFAULT)
        base.layer4.add_module(
            "additional_conv_layer",
            nn.Sequential(
                nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(512),
                nn.ReLU(),
            ),
        )
        base.fc = nn.Linear(base.fc.in_features, num_classes)
        self.resnet_base = base

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        return self.resnet_base(x)

class Basic_block(nn.Module):
    """Residual block: (conv3x3 -> BN -> ReLU -> conv3x3 -> BN) + shortcut, then ReLU.

    The shortcut is the identity (an empty ``nn.Sequential``) unless the block changes
    resolution (``stride != 1``) or width (``in_channels != out_channels``). In that case
    a strided 1x1 conv followed by batch-norm projects the input.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the block.
        stride: stride of the first 3x3 conv and of the projection shortcut.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0),
                nn.BatchNorm2d(out_channels),
            )
        else:
            shortcut = nn.Sequential()  # identity
        self.shortcut = shortcut
        self.relu_out = nn.ReLU()

    def forward(self, x):
        """Apply the residual branch, add the shortcut, and rectify."""
        residual = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        return self.relu_out(residual + self.shortcut(x))

class Resnet19_snn(nn.Module):
    """ResNet-19-style CIFAR classifier built from ``Basic_block``.

    Despite the name, it contains only standard conv / batch-norm / ReLU layers.
    The layers are, in order:
    - a stem: 3x3 conv (3 -> 128) -> BN -> ReLU;
    - three residual stages of 3 x 128ch, 3 x 256ch and 2 x 512ch blocks, where the
      first block of stages 2 and 3 uses stride 2;
    - global average pooling;
    - a 512 -> 256 -> ``num_classes`` MLP head.

    Args:
        num_classes: size of the output logit vector.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.epoch = 0  # assigned here but not read anywhere in the visible code

        # Stem
        self.conv1 = nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(128)
        self.relu = nn.ReLU()

        # Residual stages (construction order kept so parameter init and keys are stable)
        self.block1 = nn.Sequential(*(Basic_block(128, 128, stride=1) for _ in range(3)))
        self.block2 = nn.Sequential(
            Basic_block(128, 256, stride=2),
            *(Basic_block(256, 256, stride=1) for _ in range(2)),
        )
        self.block3 = nn.Sequential(
            Basic_block(256, 512, stride=2),
            Basic_block(512, 512, stride=1),
        )

        # Head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(512, 256)
        self.relu2 = nn.ReLU()
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        features = self.relu(self.bn1(self.conv1(x)))
        for stage in (self.block1, self.block2, self.block3):
            features = stage(features)
        pooled = torch.flatten(self.avgpool(features), 1)  # (B, 512, 1, 1) -> (B, 512)
        return self.fc2(self.relu2(self.fc1(pooled)))
    
def select_model(model_type, num_classes):
    """Build one of the classifier variants defined in this notebook.

    Args:
        model_type (str): Which model to build:
            - 'fc': ``Resnet19_fc`` (pretrained ResNet-18 with an MLP head).
            - 'conv': ``Resnet19_conv`` (pretrained ResNet-18 with an extra conv stage).
            - 'snn': ``Resnet19_snn`` (not pretrained; weights re-initialised with
              ``initialize_weights``).
        num_classes (int): Number of classes for the final output layer.

    Returns:
        torch.nn.Module: The selected model. No device placement is done here.

    Raises:
        ValueError: If ``model_type`` is not one of 'fc', 'conv', 'snn'.
    """
    if model_type == 'fc':
        model = Resnet19_fc(num_classes)
    elif model_type == 'conv':
        model = Resnet19_conv(num_classes)
    elif model_type == 'snn':
        model = Resnet19_snn(num_classes)
        # No pretrained weights here, so replace PyTorch's default init with He/constant init.
        model.apply(initialize_weights)
    else:
        raise ValueError(
            f"Model type not found: {model_type!r} (expected 'fc', 'conv' or 'snn')"
        )

    return model

# Lightning module

In [None]:
class DataAugmentation(nn.Module):
    """Batch-level random flips implemented with Kornia.

    Each call applies ``RandomHorizontalFlip(p=0.5)`` followed by ``RandomVerticalFlip(p=0.5)``.

    Args:
        apply_color_jitter: accepted for interface compatibility but currently ignored;
            no colour augmentation is added whatever its value.
    """

    def __init__(self, apply_color_jitter: bool = False) -> None:
        super().__init__()

        self._max_val: float = 255.0  # kept for compatibility; forward() does not use it

        flips = [
            K.augmentation.RandomHorizontalFlip(p=0.5),
            K.augmentation.RandomVerticalFlip(p=0.5),
        ]
        self.transforms = nn.Sequential(*flips)

    @torch.no_grad()  # augmentation is never differentiated through
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a randomly flipped copy of the batch ``x`` (expected BxCxHxW)."""
        return self.transforms(x)
    
class ResNetLightningModule(pl.LightningModule):
    """LightningModule wrapping a classifier with its loss, accuracy metrics, augmentation and optimiser.

    Args:
        model: ``nn.Module`` mapping an image batch to ``num_classes`` logits.
        config: mapping of hyperparameters (``main()`` loads it from ``config.yaml``).
            It uses these keys:
            - ``loss``: only ``'cross_entropy'`` is supported.
            - ``optimizer``: ``'adam'`` or ``'sgd'``.
            - ``lr``: learning rate.
            - ``momentum`` and ``weight_decay``: used for SGD only.
            - ``scheduler``: ``'step_lr'`` enables ``StepLR`` with ``step_size`` and ``gamma``.
        num_classes: number of target classes, used by the accuracy metrics.

    Raises:
        ValueError: if ``config['loss']`` is not supported.
    """

    def __init__(self, model, config, num_classes):
        super().__init__()
        self.model = model
        self.config = config
        # Random flips, applied to training batches only (see training_step).
        self.transform = DataAugmentation()
        # One metric object per stage so accumulated states never mix.
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        # Loss function
        if self.config['loss'] == 'cross_entropy':
            self.criterion = nn.CrossEntropyLoss()
        else:
            # Fail at construction instead of with an AttributeError inside training_step.
            raise ValueError(f"Unsupported loss: {self.config['loss']!r}")
        # `model` is already stored as a submodule; keep it out of the saved hparams.
        self.save_hyperparameters(ignore=['model'])

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(self.transform(x))
        loss = self.criterion(y_hat, y)
        # Update the metric, then log the Metric object itself so Lightning computes the
        # epoch-level value and resets the state at epoch end.
        self.train_acc(y_hat, y)
        self.log('train_loss', loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log('train_acc', self.train_acc, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)  # no augmentation at evaluation time
        loss = self.criterion(y_hat, y)
        self.val_acc(y_hat, y)
        self.log('val_loss', loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log('val_acc', self.val_acc, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        acc = self.test_acc(logits, y)
        self.log("test_acc", self.test_acc)
        return acc

    def configure_optimizers(self):
        """Build the optimiser (and optional StepLR scheduler) described by ``config``.

        Raises:
            ValueError: if ``config['optimizer']`` is not ``'adam'`` or ``'sgd'``.
        """
        optimizer_name = self.config['optimizer']
        if optimizer_name == 'adam':
            optimizer = optim.Adam(self.model.parameters(), lr=self.config['lr'])
        elif optimizer_name == 'sgd':
            optimizer = optim.SGD(self.model.parameters(), lr=self.config['lr'],
                                  momentum=self.config['momentum'],
                                  weight_decay=self.config['weight_decay'])
        else:
            # Previously fell through to an UnboundLocalError on `optimizer`.
            raise ValueError(f"Unsupported optimizer: {optimizer_name!r}")

        if self.config['scheduler'] == 'step_lr':
            scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=self.config['step_size'],
                                                  gamma=self.config['gamma'])
            return [optimizer], [scheduler]
        return optimizer

<IPython.core.display.Javascript object>

# Main

In [None]:
def main():
    """Train one model variant on CIFAR-10 with W&B logging and top-3 checkpointing.

    Hyperparameters are read from ``Python/Lightning_Resnet19/config.yaml``. This
    function uses ``max_epochs``; ``ResNetLightningModule`` reads the rest.
    """
    pl.seed_everything(1)
    wandb.login()

    # Load config file
    with open('Python/Lightning_Resnet19/config.yaml', 'r') as file:
        config = yaml.safe_load(file)

    # Data Module
    dm = CIFAR10DataModule()

    # Model: one of 'fc', 'conv', 'snn' (see select_model)
    model_name = 'snn'
    num_classes = dm.classes
    model = select_model(model_name, num_classes)

    # Lightning Module
    lit_model = ResNetLightningModule(model, config, num_classes)

    # Logger. log_model must be the bool True (or "all"): WandbLogger compares it with
    # `is True`, so the string "True" silently disabled checkpoint uploads.
    wandb_logger = WandbLogger(project="lit_resnet19", name=model_name, log_model=True)

    # Checkpointing. Lightning inserts "metric=" itself (auto_insert_metric_name), giving
    # e.g. "snn-epoch=3-val_loss=0.52-val_acc=0.81.ckpt". ':' is avoided because it is
    # invalid in Windows file names.
    checkpoint_cb = ModelCheckpoint(
        dirpath="./checkpoints",
        save_on_train_epoch_end=True,
        filename=f"{model_name}-{{epoch}}-{{val_loss:.2f}}-{{val_acc:.2f}}",
        save_top_k=3,
        monitor="val_acc",
        mode="max",
    )

    # Trainer
    trainer = pl.Trainer(max_epochs=config['max_epochs'], logger=wandb_logger,
                         callbacks=[checkpoint_cb])
    try:
        trainer.fit(lit_model, datamodule=dm)
    except BaseException:
        # Close the W&B run as failed so a re-run in the same kernel starts a fresh run.
        wandb.finish(exit_code=1)
        raise
    wandb.finish()

In [None]:
if __name__ == "__main__":
    # In a notebook kernel __name__ is "__main__", so running this cell starts training.
    main()