In [18]:
%reload_ext autoreload
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Imports

In [19]:
# External libraries
import os as so
import sys as s
import pathlib as pl
import torch
import torch.nn as nn
import torch.optim as om
from torch import Tensor
from torch.utils.data import random_split
from torch.utils.data import DataLoader, ConcatDataset
import torcheval
from torcheval.metrics import MulticlassF1Score, Mean
import optuna as opt
import torchvision as tn
import sklearn as sn
from sklearn.metrics import f1_score
import pandas as ps
import numpy as ny
import typing as t
import pathlib as pl
import matplotlib.pyplot as pt
import random as rng
from tqdm import tqdm
import tqdm as tm
from pprint import pprint
from git import Repo
import lightning as tl
from lightning.pytorch.loggers import WandbLogger

In [20]:
# Add local package to path
if (p := pl.Path(so.getcwd(), '..').absolute().as_posix()) not in s.path:
    s.path.append(p)

# Local imports
from gic import *
from gic.models.resnet import ResCNN
from gic.models.densenet import DenseCNN
from gic.models.convnext import ConvNextNet
from gic.models.autoencoder import AutoEncoder, UNet
from gic.data import load_data, load_batched_data, GICPreprocess, GICPerturb, GICDatasetModule

### Data Loading

In [21]:
train_dl, valid_dl, test_dl = load_batched_data(DATA_PATH, 'disjoint', gen_torch, pin_memory=pin_memory, num_workers=num_workers, prefetch_factor=prefetch_factor, batch_size=32)
tr_perturb = GICPerturb(gen_torch, mask=True, normalize=False)
tr_preprocess = GICPreprocess(augment=False, normalize=False)
tr_augment = GICPreprocess(augment=True, normalize=False)

### Model

In [22]:
from typing import Any
from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler


class DenseNetModule(tl.LightningModule):
    def __init__(self, **kwargs):
        super(DenseNetModule, self).__init__()

        # Model
        unet = AutoEncoder(64, 256, 'SiLU')
        unet.load_state_dict(torch.load('../ckpt/unet.pt'))
        unet.eval().requires_grad_(False)
        self.enc = unet.unet.encoder
        f_drop = 0.2

        self.net_densecnn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),

            nn.Dropout1d(f_drop),
            nn.Linear(in_features=1024, out_features=512),
            nn.BatchNorm1d(512),
            nn.SiLU(),

            nn.Dropout1d(f_drop),
            nn.Linear(in_features=512, out_features=512),
            nn.BatchNorm1d(512),
            nn.SiLU(),

            nn.Linear(in_features=512, out_features=100)
        )
        # self.net_densecnn= ConvNextNet(64, 4, conv_dropout=0.1, dense_dropout=0.3, conv_layers=2, dense_features=256, dense_layers=1, patch_reduce=False)
        # self.net_densecnn = DenseCNN(**kwargs)
        self.loss_fn = nn.CrossEntropyLoss()

        # Metrics
        self.metric_train_f1_score = MulticlassF1Score(num_classes=CONST_NUM_CLASS, average='macro', device=self.device)
        self.metric_valid_f1_score = MulticlassF1Score(num_classes=CONST_NUM_CLASS, average='macro', device=self.device)
        self.metric_train_loss = Mean(device=self.device)
        self.metric_valid_loss = Mean(device=self.device)
        self.save_hyperparameters(kwargs)

    def forward(self, x: Tensor) -> Tensor:
        print(x.shape)
        _, e = self.enc(x)
        return self.net_densecnn(e)

    def on_train_start(self) -> None:
        self.metric_train_f1_score.to(self.device)
        self.metric_train_loss.to(self.device)

    def on_train_epoch_start(self) -> None:
        self.metric_train_f1_score.reset()
        self.metric_train_loss.reset()

    def training_step(self, batch: t.Tuple[Tensor, Tensor], _: t.Any) -> STEP_OUTPUT:
        X, y_true = batch
        logits: Tensor = self(X)
        loss: Tensor = self.loss_fn(logits, y_true)

        self.metric_train_loss.update(loss.detach())
        self.metric_train_f1_score.update(logits.detach(), y_true)
        return loss

    def on_train_epoch_end(self) -> None:
        self.log('train_f1_score', self.metric_train_f1_score.compute().item())
        self.log('train_loss', self.metric_train_loss.compute().item())

    def on_validation_start(self) -> None:
        self.metric_valid_loss.to(self.device)
        self.metric_valid_f1_score.to(self.device)

    def on_validation_epoch_start(self) -> None:
        self.metric_valid_f1_score.reset()
        self.metric_valid_loss.reset()

    def validation_step(self, batch: t.Tuple[Tensor, Tensor], _: t.Any) -> STEP_OUTPUT:
        X, y_true = batch
        logits: Tensor = self(X)
        loss: Tensor = self.loss_fn(logits, y_true)

        self.metric_valid_loss.update(loss)
        self.metric_valid_f1_score.update(logits, y_true)
        return loss

    def on_validation_epoch_end(self) -> None:
        self.log('valid_f1_score', self.metric_valid_f1_score.compute().item())
        self.log('valid_loss', self.metric_valid_loss.compute().item())

    def predict_step(self, batch: Tensor, _: t.Any) -> Any:
        return torch.argmax(self(batch), dim=-1)

    def configure_optimizers(self) -> OptimizerLRScheduler:
        optim = om.AdamW(self.parameters(), betas=(0.9, 0.999), lr=6e-4)
        scheduler = om.lr_scheduler.ReduceLROnPlateau(optim, 'max', 0.75, 10, min_lr=2e-4, cooldown=5)
        return {
            'optimizer': optim,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'train_f1_score',
            },
        }

In [23]:
wb_logger = WandbLogger(project='Generated Image Classification', name='TorchLighting', save_dir=LOG_PATH)

### Logging

In [24]:
gic_data = GICDatasetModule(DATA_PATH, False, 32, num_workers, prefetch_factor, pin_memory, True, gen_torch)
trainer = tl.Trainer(max_epochs=150, enable_checkpointing=True, logger=wb_logger)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


### Training & Validation Loop

In [25]:
net_densecnn = DenseNetModule()
trainer.fit(net_densecnn, datamodule=gic_data)

/home/invokariman/.cache/pypoetry/virtualenvs/gic-BfYgXNhZ-py3.11/lib/python3.11/site-packages/lightning/pytorch/loggers/wandb.py:389: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type             | Params
--------------------------------------------------
0 | enc          | Encoder          | 6.6 M 
1 | net_densecnn | Sequential       | 840 K 
2 | loss_fn      | CrossEntropyLoss | 0     
--------------------------------------------------
840 K     Trainable params
6.6 M     Non-trainable params
7.4 M     Total params
29.642    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

AttributeError: 'list' object has no attribute 'shape'

### Full Training Loop

In [None]:
gic_data = GICDatasetModule(DATA_PATH, True, 32, num_workers, prefetch_factor, pin_memory, False, gen_torch)
trainer = tl.Trainer(max_epochs=345, enable_checkpointing=False, logger=wb_logger)
net_densecnn = DenseNetModule()
trainer.fit(net_densecnn, datamodule=gic_data)

### Testing Loop

In [None]:
y_hat: t.List[Tensor] = t.cast(t.List[Tensor], trainer.predict(net_densecnn, datamodule=gic_data, return_predictions=True))
preds = torch.cat(y_hat, dim=0)

### Submission

In [None]:
data = test_dl.dataset._GICDataset__data
data['Class'] = preds
data.to_csv(SUBMISSION_PATH, index=False)