# Computer Vision Project Template Notebook

In [None]:
import pandas as pd
import numpy as np
import wandb
import os
import torch
from torch import optim, utils, Tensor
from torch import nn
import torch.nn.functional as F
import torchvision.transforms.v2 as T
import lightning.pytorch as pl
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from torch.utils.data import Dataset as TorchDataset, DataLoader
from dataclasses import dataclass
from torchvision import models
from sklearn.model_selection import train_test_split
from torch import optim as TorchOptimizers
from PIL import Image
from pathlib import Path
import torchinfo
from torchmetrics import MeanMetric
from PIL import Image

In [None]:
@dataclass(frozen=True)
class DataConfig:
    """Immutable description of where the project's data lives on disk.

    All values are plain strings that end with ``'/'`` so they can be used as
    path prefixes. The attributes are annotated so that ``@dataclass`` treats
    them as real fields (they show up in ``repr``/``==`` and can be overridden
    through the constructor).

    Note: the defaults of ``data_folder`` and ``image_folder`` are computed
    once, when the class is created, from the *default* ``home_folder``.
    Passing a different ``home_folder`` to the constructor does not re-derive
    them; pass the derived folders explicitly as well in that case.
    """

    # Root prefix of the project; '' makes every path relative to the CWD.
    home_folder: str = ''
    # Folder that the training CSV is read from.
    data_folder: str = home_folder + 'data/'
    # Folder that image file names are resolved against.
    image_folder: str = data_folder + 'images/'

@dataclass(frozen=True)
class TrainConfig:
    """Immutable training settings shared across the notebook.

    The attributes are annotated so that ``@dataclass`` treats them as real
    fields (visible in ``repr``/``==`` and overridable via the constructor).
    """

    # Identifier of the backbone architecture.
    # NOTE(review): not read by any code visible in this notebook.
    model_name: str = 'efficientnet_b3'
    # Numeric precision handed to the Lightning ``Trainer``.
    precision: int = 32
    # Upper bound on the number of training epochs handed to the Lightning
    # ``Trainer``. The value is a template default; tune it per project.
    max_epochs: int = 100

# Shared configuration singletons read by the dataset, data module and trainer cells.
data_config, train_config = DataConfig(), TrainConfig()

In [None]:
class TemplateDataset(TorchDataset):
    """Map-style dataset yielding ``(transformed_image, target)`` pairs.

    Image files are resolved relative to ``data_config.image_folder`` and
    checked for existence once, at construction time, so that a missing file
    fails immediately instead of part-way through an epoch.

    Args:
        image_paths: Image file names/paths relative to the image folder.
        targets: One target per image; stored as a ``float32`` tensor.
        image_transforms: Callable applied to every RGB PIL image.

    Raises:
        ValueError: If ``image_paths`` and ``targets`` differ in length.
        FileNotFoundError: If any resolved image path does not exist.
    """

    def __init__(
            self,
            image_paths: pd.Series,
            targets: pd.Series,
            image_transforms: T.Compose,
    ):
        # Raise explicitly instead of asserting: asserts are stripped under
        # ``python -O`` and would let mismatched inputs through silently.
        if len(image_paths) != len(targets):
            raise ValueError(
                "image_paths and targets must have the same length, "
                f"got {len(image_paths)} and {len(targets)}"
            )
        self.image_paths = image_paths
        self.targets = torch.tensor(targets.to_numpy(dtype='float32'))
        self.transforms = image_transforms

        # Join with pathlib so the result does not depend on the configured
        # folder ending in a path separator.
        image_root = Path(data_config.image_folder)
        self._paths = []
        for im_path in self.image_paths:
            path = image_root / str(im_path)
            if not path.exists():
                raise FileNotFoundError(f"Image not found: {path}")
            self._paths.append(path)

    def __len__(self):
        """Return the number of samples."""
        return len(self._paths)

    def __getitem__(self, idx: int):
        """Load image ``idx`` as RGB, transform it and return it with its target."""
        # ``convert`` returns a new, fully loaded image, so the file handle
        # can be released as soon as the ``with`` block exits.
        with Image.open(self._paths[idx]) as image:
            rgb_image = image.convert('RGB')
        return self.transforms(rgb_image), self.targets[idx]

In [None]:
class TemplateDataModule(pl.LightningDataModule):
    """Lightning data module that builds the train/validation datasets and loaders.

    Args:
        batch_size: Number of samples per batch for every dataloader.
        num_workers: Number of worker processes per dataloader.
        image_size: Size passed to both the resize and the centre-crop
            transform.
    """

    def __init__(
            self,
            batch_size=32,
            num_workers=0,
            image_size=300,
    ):
        super().__init__()

        self.batch_size = batch_size
        self.num_workers = num_workers

        imagenet_mean = [0.485, 0.456, 0.406]
        imagenet_std = [0.229, 0.224, 0.225]

        # Create image transforms (no augmentation; shared by train and validation)
        self.transforms_no_aug = T.Compose([
            T.Resize(image_size, antialias=True),
            T.CenterCrop(image_size),
            # ``T.ToTensor()`` is deprecated in transforms.v2; ToImage + ToDtype
            # (with scaling to [0, 1]) is its documented replacement.
            T.ToImage(),
            T.ToDtype(torch.float32, scale=True),
            T.Normalize(imagenet_mean, imagenet_std),
        ])

    # noinspection PyAttributeOutsideInit
    def setup(self, stage):
        """Create the datasets needed for ``stage``.

        For ``'fit'`` the file ``train_final.csv`` is read from the configured
        data folder, and its ``image_path`` / ``targets`` columns are split
        80/20 into train and validation sets, stratified by target, with a
        fixed random seed. The ``'test'`` and ``'predict'`` branches are
        template placeholders.
        """
        if stage == 'fit':
            train_df = pd.read_csv(data_config.data_folder + 'train_final.csv')
            x_cols = train_df['image_path']
            y_cols = train_df['targets']
            x_cols_train, x_cols_val, y_train, y_val = train_test_split(
                    x_cols, y_cols, random_state=42, test_size=0.2, stratify=y_cols
            )
            self.train_ds = TemplateDataset(image_paths=x_cols_train, targets=y_train, image_transforms=self.transforms_no_aug)
            self.val_ds = TemplateDataset(image_paths=x_cols_val, targets=y_val, image_transforms=self.transforms_no_aug)
        if stage == 'test':
            # TODO: create and assign the test dataset (placeholder for now)
            self.test_ds = ...
        if stage == 'predict':
            # TODO: create and assign the predict dataset (placeholder for now)
            self.predict_ds = ...

    def train_dataloader(self):
        """Return the shuffled dataloader over the training split."""
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        """Return the unshuffled dataloader over the validation split."""
        return DataLoader(self.val_ds, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

In [None]:
class TemplateLightningModule(pl.LightningModule):
    """Fine-tunes a pretrained torchvision EfficientNet-B3 as a classifier.

    All backbone parameters are frozen, the feature blocks from
    ``fine_tune_start`` to the last one are unfrozen again, and the
    classifier head is replaced by a fresh dropout + linear layer.

    Args:
        fine_tune_start: Index of the first ``model.features`` block to unfreeze.
        starting_learning_rate: Learning rate handed to the Adam optimizer.
        num_classes: Number of output logits of the new classifier head.
    """

    def __init__(
            self,
            fine_tune_start=8,
            starting_learning_rate=4e-4,
            num_classes=3,
    ):
        super().__init__()

        self.starting_learning_rate = starting_learning_rate

        self.model = models.efficientnet_b3(weights='DEFAULT')

        # Freeze everything first ...
        for param in self.model.parameters():
            param.requires_grad = False

        # ... then unfreeze the tail of the feature extractor. Using the real
        # block count avoids hard-coding the depth of this architecture.
        for layer_to_unfreeze in range(fine_tune_start, len(self.model.features)):
            for param in self.model.features[layer_to_unfreeze].parameters():
                param.requires_grad = True

        # The replacement head is created after freezing, so it stays trainable.
        in_features = self.model.classifier[-1].in_features
        self.model.classifier = nn.Sequential(
                nn.Dropout(p=0.3, inplace=True),
                nn.Linear(in_features, out_features=num_classes),
        )
        self.mean_train_loss = MeanMetric()
        self.mean_valid_loss = MeanMetric()

    def forward(self, x):
        """Return the class logits for a batch of images."""
        return self.model(x)

    def _compute_loss(self, batch):
        """Run the model on ``batch`` and return the cross-entropy loss."""
        data, target = batch
        output = self(data)
        # Class-index targets (one value per sample) must be int64 for
        # ``F.cross_entropy``; float indices raise at runtime. Targets that
        # already have the logits' rank (class probabilities) are left as is.
        if target.dim() < output.dim():
            target = target.long()
        return F.cross_entropy(output, target)

    def training_step(self, batch, *args, **kwargs):
        """Compute the training loss for one batch and log its running mean."""
        loss = self._compute_loss(batch)
        self.mean_train_loss(loss)
        self.log("train/batch_loss", self.mean_train_loss, prog_bar=True, logger=True)
        return loss

    def on_train_epoch_end(self):
        # Calculate epoch level metrics for the train set
        self.log("train/loss", self.mean_train_loss, prog_bar=True, logger=True)
        # Logged as a float: Lightning reduces logged values as floating point
        # and warns when it has to convert an integer itself.
        self.log("step", float(self.current_epoch), logger=True)

    def validation_step(self, batch, *args, **kwargs):
        """Accumulate the validation loss for one batch."""
        loss = self._compute_loss(batch)
        self.mean_valid_loss(loss)

    def on_validation_epoch_end(self):
        # Calculate epoch level metrics for the validation set
        self.log("valid/loss", self.mean_valid_loss, prog_bar=True, logger=True)
        self.log("step", float(self.current_epoch), logger=True)

    def configure_optimizers(self):
        """Return an Adam optimizer using the configured starting learning rate."""
        optimizer = TorchOptimizers.Adam(self.parameters(), lr=self.starting_learning_rate)
        # example of using scheduler
        # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
        #
        # return {
        #     "optimizer": optimizer,
        #     "lr_scheduler": scheduler,
        #     "monitor": "valid/loss",  # optional for ReduceLROnPlateau
        # }
        return optimizer

In [None]:
# Seed Python, NumPy and torch (and dataloader workers) for reproducibility.
pl.seed_everything(42, workers=True)

data_module = TemplateDataModule()
model = TemplateLightningModule()

# create callbacks
# Stop when the validation loss has not improved for 20 validation runs.
early_stopping_callback = EarlyStopping(monitor="valid/loss", mode="min", patience=20)

# Keep the checkpoint with the LOWEST validation loss: the monitored value is
# a loss, so "min" is the correct direction ("max" would keep the worst one).
checkpoint_callback = ModelCheckpoint(
        monitor='valid/loss',
        mode="min",
        filename='efficientnet_b3-fine-tuning-epoch-{epoch:02d}',
        auto_insert_metric_name=False,
        save_weights_only=True,
)

trainer = pl.Trainer(
    accelerator="auto",
    devices="auto",
    strategy="auto",
    max_epochs=train_config.max_epochs,
    precision=train_config.precision,
    callbacks=[
        early_stopping_callback,
        checkpoint_callback,
    ]
)

trainer.fit(model, data_module)