# Notebook settings

In [None]:
# When True, the training cell builds a fresh LightningWrapper and fits it;
# when False, it restores the wrapper from the last saved checkpoint instead.
do_training = True

# Imports

In [None]:
from zipfile import ZipFile

from torchvision.datasets import ImageFolder
from torch.utils.data import random_split, DataLoader
import torchvision.transforms as transforms

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision.models import ResNet50_Weights, ResNet18_Weights
# TODO: check whether these helpers should come from torch or torchvision
from torchvision.utils import _log_api_usage_once
from torchvision.models._api import WeightsEnum, Weights
from torchvision.models._meta import _IMAGENET_CATEGORIES
from torchvision.models._utils import handle_legacy_interface, _ovewrite_named_param

import pytorch_lightning as pl

from typing import Any, Dict, Callable, Tuple, Optional, Type, Union, List

from pytorch_lightning.callbacks import RichProgressBar, ModelCheckpoint
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme

# from captum.attr import DeepLift, DeepLiftShap, GradientShap, Saliency
import captum.attr as explain

from os.path import isdir

from src.models.resnet import resnet50, resnet18

# Extract and load data

## Extract and load train/test sets

In [None]:
# Location of the zipped Fruits-360 dataset and of the directory it is extracted into
archive_path = '../data/fruits_vegetables_360.zip'
dataset_dir = '../data/fruits_vegetables_360/'

# Extract the archive only when the target directory does not exist yet
if not isdir(dataset_dir):
  # The handle is named `archive` (not `zip`) so the builtin zip() is not
  # rebound at notebook scope for every cell that runs afterwards.
  with ZipFile(archive_path, mode='r') as archive:
    archive.extractall(dataset_dir)

# Each sub-folder of Training/ and Test/ is one class; images are converted to tensors
train_set = ImageFolder(dataset_dir + 'fruits-360_dataset/fruits-360/Training/', transform=transforms.ToTensor())
test_set = ImageFolder(dataset_dir + 'fruits-360_dataset/fruits-360/Test/', transform=transforms.ToTensor())

## Split training set into training and validation sets

In [None]:
# Fraction of the original training set kept for training; the rest is used for validation
train_ratio = 0.9
total_size = len(train_set)

train_size = int(train_ratio * total_size)
# Computed by subtraction so the two sizes always add up to total_size
valid_size = total_size - train_size

# Seed the split so it is identical across kernel restarts. Without a fixed
# generator, a model restored from a checkpoint (do_training = False) would be
# evaluated against a different train/validation partition than it was fitted on.
split_generator = torch.Generator().manual_seed(42)
# NOTE: this rebinds train_set, so re-running the cell splits the already-reduced set again.
train_set, valid_set = random_split(train_set, [train_size, valid_size], generator=split_generator)

Check the sizes of the training, validation and test sets

In [None]:
# Sample counts, in order: training, validation, test
tuple(len(subset) for subset in (train_set, valid_set, test_set))

# ResNet
ResNet is simple to use, and explainability methods are simple to apply to it.

## Load and freeze weights for finetuning

We freeze all of the model's weights and biases except those of the classifier, which stays trainable.  
We want to keep the pretrained feature extractor unchanged.

In [None]:
# Start from the ImageNet-pretrained ResNet-18
model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

# Swap the classification head for a new 131-output linear layer
model.fc = nn.Linear(in_features=512, out_features=131)

# Freeze every parameter whose name does not contain 'fc', leaving only the new head trainable
for param_name, param in model.named_parameters():
    if 'fc' not in param_name:
        param.requires_grad_(False)

# PyTorch Lightning wrapper

In [None]:
class LightningWrapper(pl.LightningModule):
  """Thin PyTorch Lightning wrapper around an ``nn.Module``.

  The wrapped model is called on the inputs of each batch, the loss function is
  applied to its outputs and the batch targets, and the resulting loss is logged
  as ``train_loss``, ``valid_loss`` or ``test_loss`` depending on the step.

  Args:
    model: Module to train/evaluate; stored as ``self.wrapped_model``.
    loss_function: Callable invoked as ``loss_function(outputs, targets)``.
    optimizer: Optimizer *class* (not an instance); it is instantiated in
      ``configure_optimizers`` with this module's parameters.
    optimizer_params: Keyword arguments forwarded to the optimizer constructor.
      ``None`` (the default) means ``{'lr': 0.001}``.
    **pl_module: Extra keyword arguments forwarded to ``pl.LightningModule``.
  """

  def __init__(self, model: nn.Module, loss_function: Callable=F.cross_entropy,
               optimizer: Type[torch.optim.Optimizer]=torch.optim.Adam,
               optimizer_params: Optional[Dict[str, Any]]=None, **pl_module) -> None:
    super().__init__(**pl_module)
    # A fresh dict per instance: a dict literal used as the default value would
    # be shared (and mutable) across every wrapper created without this argument.
    # Resolved before save_hyperparameters() so that call sees the effective value.
    if optimizer_params is None:
      optimizer_params = {'lr': 0.001}
    self.save_hyperparameters()
    self.wrapped_model = model
    self.loss_function = loss_function
    self.optimizer = optimizer
    self.optimizer_params = optimizer_params

  def configure_optimizers(self) -> torch.optim.Optimizer:
    """Instantiate the configured optimizer class over this module's parameters."""
    return self.optimizer(self.parameters(), **self.optimizer_params)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Return the wrapped model's output for ``x``."""
    return self.wrapped_model(x)

  def _shared_step(self, batch: Tuple[torch.Tensor, torch.Tensor], log_name: str) -> torch.Tensor:
    """Compute the loss for an ``(inputs, targets)`` batch, log it as ``log_name`` and return it."""
    inputs, targets = batch
    outputs = self.wrapped_model(inputs)
    loss = self.loss_function(outputs, targets)
    self.log(log_name, loss)
    return loss

  def training_step(self, train_batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
    """Return the loss of one training batch, logged as ``train_loss``."""
    return self._shared_step(train_batch, 'train_loss')

  def validation_step(self, valid_batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
    """Return the loss of one validation batch, logged as ``valid_loss``."""
    return self._shared_step(valid_batch, 'valid_loss')

  def test_step(self, valid_batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
    """Return the loss of one test batch, logged as ``test_loss``.

    The parameter keeps its original name (``valid_batch``) so the signature is unchanged.
    """
    return self._shared_step(valid_batch, 'test_loss')

  def predict_step(self, batch: Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor]],
                   batch_idx: int) -> torch.Tensor:
    """Return softmax probabilities over the last dimension of the model output.

    ``batch`` may be a bare input tensor or an ``(inputs, targets)`` pair;
    in the latter case only the first element is used.
    """
    inputs = batch[0] if isinstance(batch, (tuple, list)) else batch
    outputs = self.wrapped_model(inputs)
    # torch.softmax requires an explicit dimension; the last one is normalized here.
    return torch.softmax(outputs, dim=-1)

Configure the checkpoint callback for the Lightning Trainer

In [None]:
# Directory the checkpoint callback writes into
# (the folder name says "resnet50" although the model built above is a ResNet-18)
checkpoint_dir = '../data/lightning_logs/finetuning_resnet50_fv/'

# Checkpoint every epoch and additionally keep a copy of the last one
model_checkpoint = ModelCheckpoint(dirpath=checkpoint_dir, save_last=True, every_n_epochs=1)

# Training

In [None]:
# Model: wrap the freshly built model when training, otherwise restore the last saved checkpoint
if do_training:
  lightning_model = LightningWrapper(model)
else:
  lightning_model = LightningWrapper.load_from_checkpoint('../data/lightning_logs/finetuning_resnet50_fv/last.ckpt')

# Trainer: use the GPU accelerator when CUDA is available, CPU otherwise
device = 'gpu' if torch.cuda.is_available() else 'cpu'
trainer = pl.Trainer(accelerator=device, max_epochs=25, callbacks=[model_checkpoint])

# Dataloaders
# Only the training loader is shuffled, so each epoch sees differently composed
# batches; validation and test loaders keep a fixed order.
train_loader = DataLoader(train_set, batch_size=512, shuffle=True, num_workers=2)
valid_loader = DataLoader(valid_set, batch_size=512, num_workers=2)
test_loader = DataLoader(test_set, batch_size=512, num_workers=2)

if do_training:
  # Training
  trainer.fit(lightning_model, train_dataloaders=train_loader, val_dataloaders=valid_loader)

# Testing

In [None]:
# Evaluate the (trained or restored) model on the held-out test loader
trainer.test(model=lightning_model, dataloaders=test_loader)

# Explainability / Interpretability with Captum

In [None]:
# Take the first training sample as the example to explain
img, tgt = train_set[0]
# Add a leading batch dimension and enable gradient tracking on the input
img = img.unsqueeze(0).requires_grad_(True)

In [None]:
# Captum is applied to the raw wrapped nn.Module, not to the Lightning wrapper
explained_model = lightning_model.wrapped_model
# Attribute in inference mode so that train-only layer behavior (e.g. batch
# statistics or dropout, if the model has such layers) does not affect the result.
explained_model.eval()

deeplift = explain.DeepLift(explained_model)
# Attributions of the input image with respect to its ground-truth class index
attribution = deeplift.attribute(img, target=tgt)

In-place operations are a problem in the way our model is built (let's analyze the code) --> in-place mode is used for the activation function (ReLU).