In [None]:
import os
import shutil
import requests
import tarfile

from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import imageio.v3 as iio

from sklearn.preprocessing import LabelEncoder
from dataclasses import dataclass, asdict

import torch
import torchvision
import torchmetrics

import torchdata.datapipes as dp
from torch.utils.data import DataLoader

torchvision.disable_beta_transforms_warning();
from torchvision.models import alexnet, AlexNet_Weights
import torchvision.transforms.v2 as t

from lightning import LightningModule, LightningDataModule, Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger, WandbLogger

import wandb

from tqdm import tqdm
from typing import Callable, Any
from dotenv import load_dotenv
load_dotenv();

from hyperparameters import Hyperparameters

import logging
logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)

In [None]:
# Dataset root, kept under the current user's home directory.
IMAGENETTE = Path.home().joinpath("datasets", "imagenette")

# Training artefacts live next to the notebook (current working directory).
CHECKPOINTS_DIR = Path.cwd().joinpath("checkpoints")
LOGS_DIR = Path.cwd().joinpath("logs")

# `exist_ok` keeps this cell safe to re-run.
CHECKPOINTS_DIR.mkdir(exist_ok=True)
LOGS_DIR.mkdir(exist_ok=True)

In [None]:
def viz_batch(batch: tuple[torch.Tensor, torch.Tensor], le: LabelEncoder, df: pd.DataFrame) -> None:
    """Plot a batch of images in a grid, captioning each with its label.

    Images are laid out in rows of at most eight. Each caption has the form
    ``<display name>(<encoded id>)``.

    Args:
        batch: ``(images, targets)``. ``images`` is indexed along its first
            axis and every item is permuted with ``(1, 2, 0)`` before being
            drawn, i.e. items are expected to be channel-first.
            ``targets`` holds one encoded class id per image.
        le: Fitted encoder used to map encoded ids back to class keys.
        df: Frame indexed by class key with a ``label`` column that holds
            the display name.

    Raises:
        AssertionError: If the number of images and targets differ.
    """
    images, targets = batch
    num_images = images.shape[0]
    # Validate before any of the inputs are used.
    assert num_images == targets.shape[0], "#images != #targets"
    if num_images == 0:
        # Nothing to draw; a zero-column grid is not a valid subplot layout.
        return

    flat_targets = targets.detach().cpu().ravel()
    class_keys = le.inverse_transform(flat_targets.numpy())
    labels = [df.loc[key].label for key in class_keys]

    max_cols = 8
    ncols = min(num_images, max_cols)
    nrows = int(np.ceil(num_images / max_cols))

    figsize = 20
    figsize_factor = nrows / ncols
    # squeeze=False always yields a 2-D array of axes, so a single-image
    # batch (1x1 grid) can be iterated the same way as any other.
    _, axes = plt.subplots(nrows = nrows,
                           ncols = ncols,
                           figsize = (figsize, figsize * figsize_factor),
                           squeeze = False)
    for idx, ax in enumerate(axes.ravel()):
        if idx >= num_images:
            # Trailing cells of a partially filled last row stay blank.
            ax.set_axis_off()
            continue
        ax.imshow(images[idx].detach().cpu().permute(1, 2, 0))
        ax.tick_params(axis = "both", which = "both",
                       bottom = False, top = False,
                       left = False, right = False,
                       labeltop = False, labelbottom = False,
                       labelleft = False, labelright = False)
        ax.set_xlabel(f"{labels[idx]}({flat_targets[idx].item()})")

In [None]:
@dataclass(frozen = True, repr = True)
class Hyperparameters:
    task: str
    random_seed: int
    num_classes: int
    metrics: list[str]

    criterion: Callable | str
    optimizer: Callable | str  
    learning_rate: float
    momentum: float
    weight_decay: float

    batch_size: int
    grad_accum: int
    test_split: float
    transform : list[str]

    num_workers: int

    def get_dict(self) -> dict:
        return asdict(self)

    def get_datamodule_dict(self) -> dict:    
        return {
            "batch_size": self.batch_size // self.grad_accum,
            "transform": self.transform,
            "test_split": self.test_split,

            "num_workers": self.num_workers,
        }

    def get_litmodule_hparams(self) -> dict:
        return {
            "task": self.task,
            "num_classes": self.num_classes,
            "metrics": self.metrics,

            "criterion": self.criterion,
            "optimizer": self.optimizer,
            "learning_rate": self.learning_rate,
            "momentum": self.momentum,
            "weight_decay": self.weight_decay,

            "num_workers": self.num_workers,
        }

In [None]:
class ImageDataLoader(dp.iter.IterDataPipe):
    """Datapipe turning ``(path, label)`` pairs into ``(image, target)`` tensors.

    For every pair from the source pipe the image is read from disk, scaled
    to ``[0, 1]`` with a per-image min-max, passed through ``transform`` and
    paired with the integer-encoded label.

    Args:
        src_dp: Source pipe yielding ``(path, label)`` tuples.
        label_encoder: Fitted encoder used to turn a label into an integer.
        transform: Callable applied to the scaled image tensor. When ``None``
            the image is resized to 256x256.
    """

    def __init__(self,
                 src_dp: dp.iter.IterDataPipe,
                 label_encoder: LabelEncoder,
                 transform: Callable | None = None):
        self.src_dp = src_dp
        self.le = label_encoder
        # Built once here rather than on every sample.
        self._resize = t.Compose([
            t.Resize((256, 256), antialias=True),
        ])
        # Identity check, not truthiness: a transform object may define its
        # own __len__/__bool__ and still be a perfectly valid callable.
        self.transform = transform if transform is not None else self._default_transform

    def __iter__(self):
        for path, label in self.src_dp:
            image = self._load_image(path)
            image = self._minmax_image(image)
            image = self.transform(image) #type: ignore
            label = self._encode_label(label)
            yield (image, label)

    def _load_image(self, image_path: Path) -> torch.Tensor:
        """Read an image as a float32 channel-first tensor with 3 channels."""
        image = (iio.imread(uri = image_path,
                           plugin = "pillow",
                           extension = ".jpg")
                    .squeeze())

        # Duplicate a single-channel (grayscale) image into three channels.
        if image.ndim == 2:
            image = np.stack((image,)*3, axis = -1)
        assert image.shape[-1] == 3, "Not A 3 Channel Image"

        # Channel-last (as read) -> channel-first.
        image = image.transpose(2, 0, 1)
        image = image.astype(np.float32)
        return torch.from_numpy(image)

    def _encode_label(self, label) -> torch.Tensor:
        """Encode one label as a 0-dim ``torch.long`` tensor."""
        return torch.tensor(
            self.le.transform([label])[0], #type: ignore
        dtype = torch.long)

    def _minmax_image(self, image: torch.Tensor) -> torch.Tensor:
        """Scale ``image`` to ``[0, 1]`` using its own min and max.

        A constant image (max == min) has no range to scale; it is mapped to
        all zeros instead of dividing by zero and producing NaNs.
        """
        lowest = image.min()
        value_range = image.max() - lowest
        if value_range == 0:
            return torch.zeros_like(image)
        return (image - lowest) / value_range

    def _default_transform(self, image: torch.Tensor) -> torch.Tensor:
        """Resize the image to 256x256."""
        return self._resize(image)

In [None]:
class ImagenetteDataModule(LightningDataModule):
    """Lightning data module for a local copy of Imagenette.

    ``prepare_data`` downloads and unpacks the archive when ``root`` is
    empty. ``setup`` reads ``noisy_imagenette.csv`` from ``root``, splits it
    on the ``is_valid`` column and builds datapipes that decode images via
    ``ImageDataLoader``.

    Args:
        root: Dataset directory; created if it does not exist.
        params: Experiment settings. The per-step batch size is
            ``params.batch_size // params.grad_accum``.
        transform: Transform applied to training images only; the validation
            pipe uses ``ImageDataLoader``'s default.
    """

    def __init__(self, root: Path, params: Hyperparameters, transform: Callable | None = None) -> None:
        super().__init__()
        self.root = root
        self.root.mkdir(parents = True, exist_ok = True)
        self.transform = transform
        self.batch_size = (params.batch_size // params.grad_accum)

        #TODO: Figure out how to automate getting num_workers
        #os.cpu_count or something like that
        self.num_workers = params.num_workers

        self.save_hyperparameters(params.get_datamodule_dict(),
            ignore = ["transform", "params"])

    def prepare_data(self) -> None:
        """Download and unpack the dataset when ``root`` is empty."""
        if self._is_empty_dir(self.root):
            url = "https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz"
            print("Root is Empty, Downloading Dataset")
            archive: Path = self.root / "archive.tgz"
            self._download_from_url(url, archive)
            print("Extracting Dataset")
            self._extract_tgz(archive, self.root)
            print("Deleting Archive")
            archive.unlink(missing_ok=True)
            print("Moving Items to Root")
            self._move_dir_up(self.root / "imagenette2")
            print("Done!")

    def setup(self, stage: str) -> None:
        """Build the datasets needed for ``stage``.

        ``"fit"`` builds the train and validation pipes. ``"validate"``,
        ``"test"`` and ``"predict"`` build the validation pipe only (it
        backs the test and predict loaders as well).
        """
        self._setup_local()
        if stage == "fit":
            self.train_dataset = self._prepare_local_train()
        if stage in ("fit", "validate", "test", "predict"):
            self.val_dataset = self._prepare_local_val()

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            dataset = self.train_dataset,
            batch_size = self.batch_size,
            num_workers = self.num_workers,
            #persistent_workers = True,
            pin_memory = True,
            shuffle = True
            )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            dataset = self.val_dataset,
            batch_size = self.batch_size,
            num_workers = self.num_workers,
            pin_memory = True
            )

    def test_dataloader(self) -> DataLoader:
        return DataLoader(
            dataset = self.val_dataset,
            batch_size = self.batch_size,
            num_workers = self.num_workers,
            )

    def predict_dataloader(self) -> DataLoader:
        return DataLoader(
            dataset = self.val_dataset,
            batch_size = self.batch_size,
            num_workers = self.num_workers,
            )

    def _setup_local(self) -> None:
        """Load the index CSV, split it and fit the label encoder."""
        df = pd.read_csv(self.root/"noisy_imagenette.csv")
        df["label"] = df["noisy_labels_0"]
        df["path"] = df["path"].apply(lambda x: self.root / x)

        columns = ["path", "label"]
        self.train_df = df.loc[df["is_valid"].eq(False), columns].reset_index(drop = True)
        self.val_df = df.loc[df["is_valid"].eq(True), columns].reset_index(drop = True)

        self._prepare_label_encoder(df["label"].unique().tolist())

    def _prepare_local_train(self) -> Any:
        """Shuffled, sharded training pipe with the training transform."""
        pipe = self._datapipe_from_dataframe(self.train_df)
        pipe = pipe.shuffle(buffer_size=len(self.train_df))
        # Shard after shuffling and before decoding: each DataLoader worker
        # then handles a disjoint slice instead of replaying the full
        # dataset (which would duplicate every sample once per worker).
        pipe = pipe.sharding_filter()
        pipe = ImageDataLoader(pipe, self.label_encoder, self.transform) #type:ignore
        pipe = pipe.prefetch(self.batch_size)
        pipe = pipe.set_length(len(self.train_df))
        return pipe

    def _prepare_local_val(self) -> Any:
        """Sharded validation pipe using the default image transform."""
        pipe = self._datapipe_from_dataframe(self.val_df)
        # Same reason as in the training pipe: avoid per-worker duplication.
        pipe = pipe.sharding_filter()
        pipe = ImageDataLoader(pipe, self.label_encoder) #type: ignore
        pipe = pipe.set_length(len(self.val_df))
        return pipe

    def _datapipe_from_dataframe(self, dataframe: pd.DataFrame) -> Any:
        """Zip the ``path`` and ``label`` columns into a pipe of pairs."""
        return dp.iter.Zipper(
            dp.iter.IterableWrapper(dataframe.path),
            dp.iter.IterableWrapper(dataframe.label)
            )

    def _prepare_label_encoder(self, class_names: list) -> None:
        self.label_encoder = LabelEncoder().fit(sorted(class_names))

    def _download_from_url(self, url: str, local_filename: Path) -> None:
        """Stream ``url`` to ``local_filename`` with a progress bar.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        timeout = 60  # seconds; applies to connect and to each read
        head = requests.head(url, allow_redirects=True, timeout=timeout)
        file_size = int(head.headers.get("Content-Length", 0))

        try:
            with requests.get(url, stream=True, timeout=timeout) as response:
                # Fail here rather than saving an error page as the archive.
                response.raise_for_status()
                with open(local_filename, "wb") as output_file, tqdm(
                    total=file_size or None, unit="B", unit_scale=True, unit_divisor=1024
                ) as progress_bar:
                    for data in response.iter_content(chunk_size=1024*1024):
                        output_file.write(data)
                        progress_bar.update(len(data))
        except BaseException:
            # A partial archive would make the root look non-empty, and
            # prepare_data would then skip the download on the next run.
            Path(local_filename).unlink(missing_ok=True)
            raise

    def _extract_tgz(self, tgz_file, out_dir) -> None:
        """Extract a gzip-compressed tar archive into ``out_dir``."""
        with tarfile.open(tgz_file, "r:gz") as tar:
            if hasattr(tarfile, "data_filter"):
                # Rejects members that would escape out_dir (absolute paths,
                # "..", unsafe links) on Python versions that support it.
                tar.extractall(out_dir, filter="data")
            else:
                tar.extractall(out_dir)

    def _is_empty_dir(self, path: Path) -> bool:
        # any() stops at the first entry instead of listing everything.
        return not any(path.iterdir())

    def _move_dir_up(self, source_dir: Path) -> None:
        """Move every entry of ``source_dir`` into its parent, then remove it."""
        for path in source_dir.iterdir():
            shutil.move(str(path), str(source_dir.parent / path.name))
        source_dir.rmdir()

In [None]:
class ClassificationModel(LightningModule):
    """Lightning module that trains ``model`` with the loss and optimizer
    taken from a ``Hyperparameters`` instance.

    Each step logs only the loss (``train_loss``, ``val_loss``,
    ``test_loss``); metric tracking is not wired up yet.
    """

    def __init__(self, model, params: Hyperparameters):
        super().__init__()
        self.model = model
        self.params = params
        #TODO : Add dicts for Metrics, Optimizers, Criterions
        #TODO : Remove Dependence on Hyperparameters Class, locally store stuff

        self._set_metrics()

        # Everything except the criterion is recorded as a hyperparameter.
        recorded = {
            name: value
            for name, value in params.get_litmodule_hparams().items()
            if name != 'criterion'
        }
        self.save_hyperparameters(recorded, ignore = ["model", "params"])

    def _set_metrics(self):
        """Placeholder; no metrics are created yet."""
        #TODO: one macro-averaged multiclass accuracy per stage, e.g.
        #self.train_metrics = torchmetrics.Accuracy("multiclass",
                                    #num_classes=self.params.num_classes,
                                    #average = "macro")
        #(likewise self.val_metrics and self.test_metrics)
        pass

    def forward(self, batch):
        """Run the model on the inputs of an ``(inputs, targets)`` batch."""
        inputs, _ = batch
        return self.model(inputs)

    def _forward_pass(self, batch, metrics : Callable | None = None):
        """Return the loss for ``batch``; feed ``metrics`` when one is given."""
        inputs, targets = batch
        outputs = self.model(inputs)
        if metrics:
            metrics(outputs, targets)
        return self.params.criterion(outputs, targets) #type: ignore

    def training_step(self, batch, batch_idx):
        train_loss = self._forward_pass(batch)
        self.log("train_loss", train_loss, on_step=False, on_epoch=True)
        #self.log("train_acc", self.train_metrics, on_step=False, on_epoch=True)
        return train_loss

    def validation_step(self, batch, batch_idx):
        val_loss = self._forward_pass(batch)#, self.val_metrics)
        self.log("val_loss", val_loss, on_step = False, on_epoch = True)
        #self.log("val_acc", self.val_metrics, on_epoch=True, on_step=False)

    def test_step(self, batch, batch_idx):
        test_loss = self._forward_pass(batch)#, self.test_metrics)
        self.log("test_loss", test_loss)
        #self.log("test_acc", self.test_metrics, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        """Build the optimizer; only the learning rate is passed through."""
        make_optimizer = self.params.optimizer
        return make_optimizer(self.model.parameters(), #type: ignore
                              lr = self.params.learning_rate,
                              #momentum = 0.9, weight_decay = 5e-4
                              )

In [None]:
# Keeps the best checkpoint by validation loss plus the most recent one.
local_checkpoint = ModelCheckpoint(
    dirpath=CHECKPOINTS_DIR,
    # ".2f" = two decimals. The previous "2f" was a minimum *width* of 2,
    # which left the default six decimals in every file name.
    filename="{epoch}-{train_loss:.2f}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_top_k=1,
    save_last=True,
)

# CSV logs under <cwd>/logs; the version is pinned to 1.
local_logger = CSVLogger(
    save_dir=Path.cwd(),
    name="logs",
    version=1,
)

#wandb.finish()
#wandb_logger = WandbLogger(
    #save_dir=LOGS_DIR,
    #project="ilsvrc-with-imagenette",
    #log_model=True,
    #version='2',
#)


In [None]:
# All settings for this run in one place.
experiment = Hyperparameters(
    task = "multiclass_classification",
    random_seed = 42,
    num_classes = 10,
    metrics = ["accuracy", "f1score"],

    criterion = torch.nn.CrossEntropyLoss(),
    optimizer = torch.optim.Adam,
    learning_rate = 1e-6,
    momentum = 0,
    weight_decay = 0,

    batch_size = 128,
    grad_accum = 4,
    test_split = .3,
    transform = ["random_crop_224"],
    num_workers = 8,

)

# Seed before the new head is created so its random initialisation is covered.
seed_everything(experiment.random_seed, workers = True);
alexnet_pretrained = alexnet(weights = "DEFAULT")
# Replace the final classifier layer with one sized for this dataset. The
# input width is read from the layer being replaced rather than hard-coded.
alexnet_pretrained.classifier[-1] = torch.nn.Linear(
                                        in_features=alexnet_pretrained.classifier[-1].in_features,
                                        out_features=experiment.num_classes,
                                        bias = True)
#alexnet_transform = AlexNet_Weights.IMAGENET1K_V1.transforms()
# Training-time transform: resize to 256x256, then take a random 224x224 crop.
alexnet_transform = t.Compose([
    t.Resize(size = (256, 256), antialias = True),
    t.RandomCrop(size = (224, 224), pad_if_needed = True),
    #t.RandomHorizontalFlip(p = .5)
])

classifier = ClassificationModel(alexnet_pretrained, experiment)
imagenette_dm = ImagenetteDataModule(
        root = IMAGENETTE,
        params = experiment,
        transform = alexnet_transform)

In [None]:
trainer = Trainer(
    callbacks = [local_checkpoint],
    logger = [local_logger],
    max_epochs = 20,
    # Gradient accumulation follows the experiment's setting.
    accumulate_grad_batches = experiment.grad_accum,
    # Validation runs only on every fifth epoch.
    check_val_every_n_epoch = 5,
    # Switches left disabled:
    #fast_dev_run=True,
    #deterministic=True,
    #benchmark=True,
    #enable_checkpointing=False,
)

In [None]:
# Resume from the most recent checkpoint when one exists; otherwise pass
# None so training starts from the current weights.
resume_candidate = CHECKPOINTS_DIR.joinpath("last.ckpt")
last_ckpt = resume_candidate if resume_candidate.is_file() else None

trainer.fit(
    classifier,
    datamodule = imagenette_dm,
    ckpt_path = last_ckpt, #type: ignore
)

In [None]:
# Lookup table from class key (index) to display name.
labels_df = pd.read_csv("labels.csv", index_col=0)
labels_df["label"] = labels_df["label"].str.strip(',')

# Number of validation samples collected for the preview batch.
N_PREVIEW_SAMPLES = 64

def prepare_preds(dm) -> Any:
    """Return a shuffled pipe of ``(image, target)`` validation samples.

    The shuffle is applied to the ``(path, label)`` pairs, before images are
    decoded. Shuffling after decoding would make the shuffle buffer read and
    decode images until it is full before yielding the first sample.
    """
    dm.setup("predict")
    pipe = dm._datapipe_from_dataframe(dm.val_df)
    pipe = pipe.shuffle()
    pipe = ImageDataLoader(pipe, dm.label_encoder) #type: ignore
    pipe = pipe.set_length(len(dm.val_df))
    return pipe

image_list = list()
label_list = list()
labels_str = list()

dataset = prepare_preds(imagenette_dm)
for idx, sample in enumerate(dataset):
    if idx >= N_PREVIEW_SAMPLES:
        break
    image, label = sample
    image_list.append(image.clip(0, 1))
    label_list.append(label)

    # .item() hands the encoder a plain Python int rather than a 0-dim tensor.
    class_key = imagenette_dm.label_encoder.inverse_transform([label.item()])[0]
    labels_str.append(labels_df.loc[class_key].label)

# Stack into batch tensors; the lists keep their own names so `images` and
# `labels` are tensors from the moment they are defined.
images = torch.stack(image_list)
labels = torch.stack(label_list)
sample_batch = (images, labels)

In [None]:
# Sanity check: image batch shape, label batch shape, number of label names.
print(images.shape, labels.shape, len(labels_str), sep = "\n")

In [None]:
# Re-pair the stacked tensors into per-sample tuples so a DataLoader can
# batch them again.
sample = list(zip(images, labels))
dl = DataLoader(sample, batch_size = 64) #type: ignore

predictions = trainer.predict(classifier, dataloaders = dl)
# First returned batch; the predicted class is the index of the largest
# output along dimension 1.
preds = predictions[0].argmax(dim = 1)

print(preds.shape)
preds_batch = (images, preds)

In [None]:
# Show the preview images captioned with the model's predicted classes.
viz_batch(batch = preds_batch, le = imagenette_dm.label_encoder, df = labels_df)