In [None]:
import numpy
import matplotlib.pyplot

import yaml 

import wandb

import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F

import torchmetrics

import lightning as L
from pytorch_lightning.loggers import WandbLogger

In [None]:
# Smoke test: fetch the default ResNet-34 weights, build the network from them,
# and grab the preprocessing transform that ships with those weights.
backbone_weights = torchvision.models.ResNet34_Weights.DEFAULT
backbone = torchvision.models.resnet34(weights=backbone_weights)
preprocessor = backbone_weights.transforms()

In [None]:
from pprint import pprint

# Read the run configuration and echo it so the settings are visible in the notebook.
# NOTE(review): yaml.load with FullLoader can build more than plain scalars/lists/dicts;
# if parameters.yaml may ever come from an untrusted source, consider yaml.safe_load.
with open("parameters.yaml", "r") as config_file:
    parameters = yaml.load(config_file, Loader=yaml.FullLoader)

pprint(parameters)

In [None]:
# fully-supervised fine-tuning
class Backbone(L.LightningModule):
    
    def __init__(self, n_classes, user_parameters):
        super().__init__()
        
        self.n_classes = n_classes
        self.user_parameters = user_parameters
        
        self.backbone_weights = torchvision.models.ResNet34_Weights.DEFAULT
        self.preprocessor = self.backbone_weights.transforms()
        self.backbone = torchvision.models.resnet34(weights=self.backbone_weights)
        self.n_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()
        self.head = nn.Linear(self.n_features, self.n_classes)
        
        match self.user_parameters['Loss_Function']['loss_function']:
            case 'cross_entropy':
                self.loss_function = F.cross_entropy
            case 'qwk':
                from WeightedKappaLoss import WeightedKappaLoss
                self.loss_function = WeightedKappaLoss(self.n_classes, mode='quadratic')
            case _:
                self.loss_function = F.cross_entropy  # defaults to cross entropy

        self.save_hyperparameters()  # wandb
        
    def forward(self, x):
        x_processed = self.preprocessor(x)
        x_features = self.backbone(x_processed)
        return self.head(x_features)
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            [
                {"params": self.backbone.parameters(),
                 "name": "backbone"},
                {"params": self.head.parameters(),
                 "name": "head"},
            ],  
            lr=self.user_parameters['Optimizer']['lr'],
            weight_decay=self.user_parameters['Optimizer']['weight_decay'],
        )
        
        return optimizer
        
    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = F.cross_entropy(logits, y)
        
        probas = F.softmax(logits, dim=1)
        y_hat = probas.argmax(dim=1)
        
        accuracy = torchmetrics.functional.accuracy(
            y_hat, y, task='multiclass', num_classes=self.n_classes
        )
        
        qwk = torchmetrics.functional.cohen_kappa(
            y_hat, y, task='multiclass', num_classes=self.n_classes,
            weights='quadratic'
        )   
        
        recall = torchmetrics.functional.recall(
            y_hat, y, task='multiclass', num_classes=self.n_classes,
        )
        
        self.log("train/loss", loss)  # wandb
        self.log("train/accuracy", accuracy)  # wandb
        self.log("train/recall", recall)  # wandb
        self.log("train/qwk", qwk)  # wandb
        
        self.log("trainable_parameters", sum([p.numel() for p in self.parameters() if p.requires_grad]))
        
        return loss
    
    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        loss = F.cross_entropy(logits, y)
        
        probas = F.softmax(logits, dim=1)
        y_hat = probas.argmax(dim=1)
        
        accuracy = torchmetrics.functional.accuracy(
            y_hat, y, task='multiclass', num_classes=self.n_classes
        )
        
        qwk = torchmetrics.functional.cohen_kappa(
            y_hat, y, task='multiclass', num_classes=self.n_classes,
            weights='quadratic'
        )
        
        
        self.log("val/loss", loss)  # wandb
        self.log("val/accuracy", accuracy)  # wandb
        self.log("val/qwk", qwk)  # wandb

    
        

In [None]:
# Early instantiation to check that the module builds from the loaded configuration.
model = Backbone(
    n_classes=3,
    user_parameters=parameters,
)

In [None]:
# Take the logger from the same `lightning` package that provides the Trainer,
# LightningModule and callbacks used in this notebook, instead of mixing in the
# separate `pytorch_lightning` distribution.
wandb_logger = L.pytorch.loggers.WandbLogger(
    project='debug-runs',
    config=parameters,
)

In [None]:
# Synthetic datasets for exercising the training pipeline without real data.
fake_train_data = torchvision.datasets.FakeData(
    size=1000,
    image_size=(3, 512, 512),
    num_classes=3,
    transform=torchvision.transforms.ToTensor(),
)

# FakeData derives sample i from the seed `i + random_offset`. With the default
# offset of 0 in both datasets, the validation set would be exactly the first 100
# training samples; starting after the training range keeps the two sets disjoint.
fake_val_data = torchvision.datasets.FakeData(
    size=100,
    image_size=(3, 512, 512),
    num_classes=3,
    transform=torchvision.transforms.ToTensor(),
    random_offset=len(fake_train_data),
)

fake_train_data

In [None]:
# Training loader: reshuffled, batch size taken from the run configuration.
fake_train_dataloader = torch.utils.data.DataLoader(
    dataset=fake_train_data,
    batch_size=parameters['Training']['batch_size'],
    num_workers=4,
    shuffle=True,
)

# Validation loader: same batch size, no shuffling requested.
fake_val_dataloader = torch.utils.data.DataLoader(
    dataset=fake_val_data,
    batch_size=parameters['Training']['batch_size'],
    num_workers=4,
)

In [None]:
# Tracks `val/loss`, keeping the top-1 checkpoint plus the most recent one.
callback_model_checkpoint = L.pytorch.callbacks.ModelCheckpoint(
    monitor='val/loss',
    save_top_k=1,
    save_last=True,
    dirpath='checkpoints',
    filename='{epoch}',
)

# Stops training once `val/loss` has not improved for the configured patience.
callback_early_stopping = L.pytorch.callbacks.EarlyStopping(
    patience=parameters['Training']['early_stopping_patience'],
    monitor='val/loss',
)

# Freezes the module's `backbone` and unfreezes it at the configured epoch.
# `lambda_func` is called with an epoch number and returns a constant gain here.
callback_backbone_finetuning = L.pytorch.callbacks.BackboneFinetuning(
    backbone_initial_ratio_lr=parameters['Training']['gain_before_unfreeze'],
    unfreeze_backbone_at_epoch=parameters['Training']['epochs_before_unfreeze'],
    lambda_func=lambda epoch: parameters['Training']['gain_after_unfreeze'],
)

# Logs the learning rate of every optimizer parameter group at each step.
callback_learning_rate_monitor = L.pytorch.callbacks.LearningRateMonitor(
    logging_interval='step',
)

# Checkpointing and early stopping are constructed above but currently disabled.
callbacks = [
    # callback_model_checkpoint,
    # callback_early_stopping,
    callback_backbone_finetuning,
    callback_learning_rate_monitor,
]

In [None]:
# Build a fresh model for this run so re-running the cell does not reuse trained weights.
model = Backbone(user_parameters=parameters, n_classes=3)

# Trainer settings all come from the 'Training' section of the configuration.
trainer = L.Trainer(
    logger=wandb_logger,
    callbacks=callbacks,
    max_epochs=parameters['Training']['max_epochs'],
    limit_train_batches=parameters['Training']['limit_train_batches'],
    log_every_n_steps=parameters['Training']['log_every_n_steps'],
)

In [None]:
# Close the W&B run even when training raises or is interrupted; otherwise
# wandb.finish() is skipped and the run is left open in the kernel.
try:
    trainer.fit(model, fake_train_dataloader, fake_val_dataloader)
finally:
    wandb.finish()