In [1]:
import os
import shutil
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import imageio.v3 as iio

from sklearn.preprocessing import LabelEncoder

import torch
import torchdata.datapipes as dp
from torch.utils.data import DataLoader

import torchvision
torchvision.disable_beta_transforms_warning();

import torchvision.transforms.v2 as t
from torchvision.models import AlexNet

import torchmetrics
import pytorch_lightning as pl

from streaming import StreamingDataset, MDSWriter
from streaming.base.util import clean_stale_shared_memory

from tqdm import tqdm
from typing import Callable, Any

from dotenv import load_dotenv
load_dotenv();

from hyperparameters import Hyperparameters
from datamodules import ImagenetteDataLoader, viz_batch 

In [10]:
# Root of the local ImageNet tree (expects ILSVRC/, LOC_synset_mapping.txt, LOC_val_solution.csv).
# Read it from the environment (load_dotenv() in the import cell also picks up a .env file)
# so the notebook is not tied to one machine; the fallback is the original mount point.
# Other locations used previously:
#   /mnt/c/Users/SambhavChandra/datasets/imagenet/
#   IMAGENET_SHARDS = Path.home() / "datasets" / "imagenet"
IMAGENET = Path(os.environ.get(
    "IMAGENET_ROOT",
    "/run/media/sambhav/2A2E24A52E246BCF/Users/SambhavChandra/datasets/imagenet/",
))

In [3]:
def viz_batch(batch: tuple[torch.Tensor, torch.Tensor], le: LabelEncoder) -> None:
    """Show a batch of images in a grid of at most 8 columns, captioned with class names.

    Args:
        batch: ``(images, targets)``. ``images[i]`` is permuted ``(1, 2, 0)`` before
            ``imshow``, so each image is expected channel-first (C, H, W);
            ``targets`` holds the label-encoded class of each image.
        le: fitted ``LabelEncoder`` used to decode ``targets`` into class names.

    Each subplot's x-label reads ``"<class name>(<encoded target>)"``.
    An empty batch draws nothing.
    """
    images, targets = batch
    num_images = images.shape[0]
    # Validate before decoding so a mismatched batch fails with the intended message.
    assert num_images == targets.shape[0], "#images != #targets"
    if num_images == 0:
        return
    labels = le.inverse_transform(targets.ravel())

    # At most 8 images per row; 8 or fewer images collapse to a single row.
    ncols = min(num_images, 8)
    nrows = int(np.ceil(num_images / ncols))

    figsize = 20
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so .ravel() always works.
    _, axes = plt.subplots(nrows=nrows,
                           ncols=ncols,
                           squeeze=False,
                           figsize=(figsize, figsize * nrows / ncols))
    for idx, ax in enumerate(axes.ravel()):
        if idx >= num_images:
            # The last row may be only partly filled (e.g. 10 images -> 2x8 grid):
            # hide the leftover cells instead of indexing past the end of the batch.
            ax.set_axis_off()
            continue
        ax.imshow(images[idx].permute(1, 2, 0))
        ax.tick_params(axis="both", which="both",
                       bottom=False, top=False,
                       left=False, right=False,
                       labeltop=False, labelbottom=False,
                       labelleft=False, labelright=False)
        ax.set_xlabel(f"{labels[idx]}({targets[idx].item()})")

In [13]:
class ImageDataLoader(dp.iter.IterDataPipe):
    """Datapipe turning ``(image_path, label)`` pairs into ``(image, encoded_label)`` tensors.

    Each image is read from disk with imageio's Pillow plugin, converted to a
    channel-first (3, H, W) ``uint8`` tensor and passed through ``transform``;
    each label is mapped to its integer class index by ``label_encoder``.
    """

    # Built once instead of on every sample. ConvertImageDtype scales integer input
    # into [0, 1] itself, so the image is not divided by 255 by hand beforehand.
    _DEFAULT_PIPELINE = t.Compose([
        t.ConvertImageDtype(torch.float32),
        t.Resize((256, 256), antialias=True),
    ])

    def __init__(self,
                 src_dp: dp.iter.IterDataPipe,
                 label_encoder: LabelEncoder,
                 transform: Callable | None = None):
        """
        Args:
            src_dp: datapipe yielding ``(path, label)`` pairs.
            label_encoder: fitted encoder whose ``transform`` maps a raw label to its class index.
            transform: applied to each (3, H, W) ``uint8`` image tensor; defaults to
                ``_default_transform``.
        """
        self.src_dp = src_dp
        assert label_encoder is not None, "label_encoder is required"
        self.le = label_encoder
        self.transform = transform if transform is not None else self._default_transform

    def __iter__(self):
        for path, label in self.src_dp:
            yield self.transform(self._load_image(path)), self._encode_label(label)

    def _load_image(self, image_path: Path) -> torch.Tensor:
        """Read an image file as a (3, H, W) ``uint8`` tensor.

        ``mode="RGB"`` makes Pillow convert every input to 3-channel RGB, so grayscale,
        CMYK or RGBA files no longer trip the channel assertion below.
        """
        image = (iio.imread(uri=image_path,
                            plugin="pillow",
                            extension=".jpeg",
                            mode="RGB")
                 .squeeze())
        # Defensive fallback: duplicate a single grayscale plane into three channels.
        if image.ndim == 2:
            image = np.stack((image,) * 3, axis=-1)
        assert image.shape[-1] == 3, f"expected 3 channels, got shape {image.shape}"
        return torch.from_numpy(image.transpose(2, 0, 1))

    def _encode_label(self, label) -> torch.Tensor:
        """Encode one raw label into a 0-d tensor holding its class index."""
        return torch.from_numpy(
            self.le.transform([label])
        ).squeeze()

    def _default_transform(self, image: torch.Tensor) -> torch.Tensor:
        """Scale a ``uint8`` image to float32 in [0, 1] and resize it to 256x256."""
        return self._DEFAULT_PIPELINE(image)

In [14]:
class ImagenetRemoteDataset(StreamingDataset):
    """``StreamingDataset`` over ImageNet shards yielding ``(image, class_index)`` pairs.

    Each raw sample is a mapping with encoded image bytes under ``"image"`` and the
    class under ``"label"``. The image is decoded to a (3, H, W) ``uint8`` tensor and
    passed through ``transform``.
    """

    # Built once instead of on every sample. ConvertImageDtype scales integer input
    # into [0, 1] itself, so the image is not divided by 255 by hand beforehand.
    _DEFAULT_PIPELINE = t.Compose([
        t.ConvertImageDtype(torch.float32),
        t.Resize((256, 256), antialias=True),
    ])

    def __init__(self,
                 remote: str,
                 local: str,
                 shuffle: bool,
                 batch_size: int,
                 cache_limit: int|str,
                 shuffle_seed: int = 420,
                 transform: Callable | None = None
                ) -> None:
        """
        Args:
            remote, local, shuffle, batch_size, cache_limit, shuffle_seed:
                forwarded unchanged to ``StreamingDataset``.
            transform: applied to each decoded (3, H, W) ``uint8`` image; defaults to
                ``_default_transform``.
        """
        super().__init__(local=local, remote=remote,
                         shuffle=shuffle, shuffle_seed=shuffle_seed,
                         batch_size=batch_size, cache_limit=cache_limit
                         )
        self.transform = transform if transform is not None else self._default_transform

    def __getitem__(self, idx: int) -> Any:
        sample = super().__getitem__(idx)
        image = self._load_image(sample["image"])
        # Class-index targets must be integral: CrossEntropyLoss only accepts float
        # targets as per-class probabilities shaped like the logits, and multiclass
        # Accuracy rejects float targets.
        # NOTE(review): assumes the shard's "label" field is an integer class index.
        label = torch.tensor(sample["label"], dtype=torch.long)
        return self.transform(image), label

    def _load_image(self, bytestream: bytes) -> torch.Tensor:
        """Decode encoded image bytes into a (3, H, W) ``uint8`` tensor.

        ``mode="RGB"`` makes Pillow convert every input to 3-channel RGB, so grayscale,
        CMYK or RGBA images no longer trip the channel assertion below.
        """
        image = (iio.imread(uri=bytestream,
                            plugin="pillow",
                            extension=".jpg",
                            mode="RGB")
                 .squeeze())
        # Defensive fallback: duplicate a single grayscale plane into three channels.
        if image.ndim == 2:
            image = np.stack((image,) * 3, axis=-1)
        assert image.shape[-1] == 3, f"expected 3 channels, got shape {image.shape}"
        return torch.from_numpy(image.transpose(2, 0, 1))

    def _default_transform(self, image: torch.Tensor) -> torch.Tensor:
        """Scale a ``uint8`` image to float32 in [0, 1] and resize it to 256x256."""
        return self._DEFAULT_PIPELINE(image)

In [15]:
class ImagenetDataModule(pl.LightningDataModule):
    """LightningDataModule serving ImageNet from a local ILSVRC tree or from remote shards.

    Local mode (``remote is None``) reads, under ``root``:
      * ``LOC_synset_mapping.txt`` -- one ``"<wnid> <names>"`` line per class,
      * ``ILSVRC/Data/CLS-LOC/{train,val}`` -- the image files,
      * ``LOC_val_solution.csv`` -- validation ground truth, used only when no
        ``val.csv`` index is given,
    together with ``train.csv`` / ``val.csv`` index files in the current working directory.

    Remote mode streams ``{remote}/train`` and ``{remote}/val`` through
    ``ImagenetRemoteDataset`` into ``root/shards/{train,val}`` (both directories are
    wiped on construction) and takes the class list from ``labels.csv`` in the
    current working directory.
    """

    def __init__(self, root: Path, params: Hyperparameters, remote: str | None = None):
        """
        Args:
            root: ImageNet root directory (local mode) or parent of the shard cache (remote mode).
            params: experiment hyperparameters; ``batch_size``, ``num_workers``,
                ``random_seed`` and ``local_cache_limit`` are read from it.
            remote: base URI of the remote shards, or ``None`` to read from ``root``.
        """
        super().__init__()
        self.root = root
        self.remote = remote
        # Assigned before the setup helpers run so they can safely rely on it.
        self.params = params
        if remote is not None:
            self._setup_remote()
        else:
            self._setup_local()

    def setup(self, stage: str):
        """Create the dataset for ``stage``: "fit" -> training set, "test" -> validation set."""
        if stage == "fit":
            if self.remote is not None:
                self.train_dataset = self._prepare_remote_train()
            else:
                self.train_dataset = self._prepare_local_train()

        elif stage == "test":
            if self.remote is not None:
                self.val_dataset = self._prepare_remote_val()
            else:
                self.val_dataset = self._prepare_local_val()

    def train_dataloader(self):
        """DataLoader over the training set built by ``setup("fit")``."""
        # NOTE(review): training data is not shuffled -- the local datapipe yields rows in
        # dataframe order and the remote dataset is created with shuffle=False.
        return DataLoader(
            dataset=self.train_dataset,
            batch_size=self.params.batch_size,
            num_workers=self.params.num_workers,
        )

    def test_dataloader(self):
        """DataLoader over the validation set built by ``setup("test")``."""
        return DataLoader(
            dataset=self.val_dataset,
            batch_size=self.params.batch_size,
            num_workers=self.params.num_workers,
        )

    def _setup_local(self) -> None:
        """Fit the label encoder on the synset ids and load the train/val index dataframes."""
        labels = self.root / "LOC_synset_mapping.txt"
        self.labels_df = self._get_labels_df(labels)
        self._prepare_label_encoder(sorted(self.labels_df.index.tolist()))

        train_dir = self.root / "ILSVRC" / "Data" / "CLS-LOC" / "train"
        val_soln = self.root / "LOC_val_solution.csv"

        self.train_df = self._get_train_df(train_dir, Path.cwd() / "train.csv")
        self.val_df = self._get_val_df(val_soln, Path.cwd() / "val.csv")

    def _setup_remote(self) -> None:
        """Fit the label encoder from ``labels.csv`` and prepare empty local shard caches."""
        # TODO: how to handle label_encoder creation? hardcode class names? or put labels.csv on remote?
        self.labels_df = pd.read_csv(Path.cwd() / "labels.csv", index_col=0)
        self._prepare_label_encoder(sorted(self.labels_df.index.tolist()))

        self.local_shards_train: Path = self.root / "shards" / "train"
        self._reset_dir(self.local_shards_train)

        self.local_shards_val: Path = self.root / "shards" / "val"
        self._reset_dir(self.local_shards_val)

        self.remote_shards_train: str = f"{self.remote}/train"
        self.remote_shards_val: str = f"{self.remote}/val"

        clean_stale_shared_memory()

    def _prepare_local_train(self) -> Any:
        """Datapipe yielding ``(image, class_index)`` for every row of ``train_df``."""
        datapipe = self._datapipe_from_dataframe(self.train_df)
        datapipe = ImageDataLoader(datapipe, self.label_encoder)  # type: ignore
        return datapipe.set_length(len(self.train_df))

    def _prepare_local_val(self) -> Any:
        """Datapipe yielding ``(image, class_index)`` for every row of ``val_df``."""
        datapipe = self._datapipe_from_dataframe(self.val_df)
        datapipe = ImageDataLoader(datapipe, self.label_encoder)  # type: ignore
        return datapipe.set_length(len(self.val_df))

    def _prepare_remote_train(self) -> Any:
        """Streaming dataset over the remote training shards, cached in ``local_shards_train``."""
        clean_stale_shared_memory()
        return ImagenetRemoteDataset(
            remote=self.remote_shards_train,
            local=self.local_shards_train.as_posix(),
            shuffle=False,
            shuffle_seed=self.params.random_seed,
            batch_size=self.params.batch_size,
            cache_limit=self.params.local_cache_limit,
        )

    def _prepare_remote_val(self) -> Any:
        """Streaming dataset over the remote validation shards, cached in ``local_shards_val``."""
        clean_stale_shared_memory()
        return ImagenetRemoteDataset(
            remote=self.remote_shards_val,
            local=self.local_shards_val.as_posix(),
            shuffle=False,
            batch_size=self.params.batch_size,
            cache_limit=self.params.local_cache_limit,
        )

    def _get_labels_df(self, path: Path) -> pd.DataFrame:
        """Parse the synset mapping file into a dataframe indexed by wnid.

        Each line is ``"<wnid> <name>, <synonym>, ..."``. Only the first space is split
        on, so ``words`` keeps the full comma-separated name list and ``label`` is its
        first entry -- multi-word names such as "great white shark" stay intact.
        """
        df = pd.read_table(path, header=None)
        df = df[0].str.split(" ", n=1, expand=True)
        df.columns = ["wnid", "words"]
        df["label"] = df["words"].str.split(",").str[0].str.strip()
        return df.set_index("wnid")[["label", "words"]]

    def _get_train_df(self, train_dir: Path, path_to_csv: Path | None = None) -> pd.DataFrame:
        """Index of training images with columns ``path`` and ``label`` (wnid).

        With ``path_to_csv`` the cached index is loaded and its ``path`` entries are joined
        onto ``train_dir``; otherwise ``train_dir`` is scanned recursively for ``*.JPEG``
        files and each label is the name of the file's parent directory.
        """
        if path_to_csv:
            assert path_to_csv.exists() and path_to_csv.is_file(), "invalid path"  # type: ignore

            df = pd.read_csv(path_to_csv, index_col=0)
            df["path"] = df["path"].apply(lambda x: train_dir / x)
            return df[["path", "label"]]  # type: ignore

        df = pd.DataFrame({"path": list(train_dir.rglob("*.JPEG"))})
        # .name rather than .stem: a directory name must not lose a "suffix".
        df["label"] = df["path"].apply(lambda x: x.parent.name)
        return df

    def _get_val_df(self, val_soln_csv: Path | None = None, val_csv: Path | None = None) -> pd.DataFrame:
        """Index of validation images with columns ``path`` and ``label``.

        Prefers the cached ``val_csv`` index. Otherwise it is derived from the solution
        CSV (``ImageId``, ``PredictionString``), taking the first token of
        ``PredictionString`` as the label.

        Raises:
            ValueError: if neither ``val_csv`` nor ``val_soln_csv`` is given.
        """
        val_prefix: Path = self.root / "ILSVRC" / "Data" / "CLS-LOC" / "val"
        if val_csv:
            assert val_csv.exists() and val_csv.is_file(), "invalid path"
            df = pd.read_csv(val_csv, index_col=0)

            df["path"] = df["path"].apply(lambda x: val_prefix / x)
            return df[["path", "label"]]  # type: ignore

        if val_soln_csv:
            assert val_soln_csv.exists() and val_soln_csv.is_file(), "invalid soln path"

            df = pd.read_csv(val_soln_csv)
            df["path"] = df["ImageId"].apply(lambda x: val_prefix / f"{x}.JPEG")  # type: ignore
            df["label"] = df["PredictionString"].str.split(" ", n=1, expand=True).iloc[:, 0]
            return df[["path", "label"]]  # type: ignore

        raise ValueError("either val_csv or val_soln_csv must be provided")

    def _datapipe_from_dataframe(self, dataframe: pd.DataFrame):
        """Datapipe of ``(path, label)`` pairs from ``dataframe``, sharded across workers.

        ``sharding_filter`` lets DataLoader split the stream between its worker processes;
        without it each of the ``num_workers`` workers replays the whole dataframe and an
        epoch sees every sample ``num_workers`` times. It sits before ``ImageDataLoader``
        so each worker only decodes its own share of the images.
        """
        return dp.iter.Zipper(
            dp.iter.IterableWrapper(dataframe["path"]),
            dp.iter.IterableWrapper(dataframe["label"]),
        ).sharding_filter()

    def _prepare_label_encoder(self, class_names: list):
        """Fit ``self.label_encoder`` on ``class_names``."""
        self.label_encoder = LabelEncoder().fit(class_names)

    def _reset_dir(self, dir_path: Path) -> None:
        """Delete ``dir_path`` if it is an existing directory, then (re)create it empty."""
        if dir_path.exists() and dir_path.is_dir():
            shutil.rmtree(dir_path)
        dir_path.mkdir(parents=True, exist_ok=True)

In [16]:
class ClassificationModel(pl.LightningModule):
    """LightningModule that trains and tests ``model`` with the loss, optimizer and
    learning rate taken from ``params``, tracking micro-averaged multiclass accuracy."""

    def __init__(self, model, params: Hyperparameters):
        """
        Args:
            model: the network to train; presumably maps an image batch to
                ``(batch, num_classes)`` logits -- confirm for non-AlexNet models.
            params: experiment hyperparameters; ``num_classes``, ``criterion``,
                ``optimizer`` and ``learning_rate`` are read from it.
        """
        super().__init__()
        self.model = model
        self.params = params

        self._set_metrics()
        config = params.get_dict()
        # `criterion` is an nn.Module instance rather than a plain value, so it is kept
        # out of the saved hyperparameters.
        self.save_hyperparameters(
            {name: config[name] for name in config.keys() if name != "criterion"},
            ignore=["model"],
        )

    def forward(self, batch):
        """Run the model on the inputs of an ``(inputs, targets)`` batch; targets are ignored."""
        x, _ = batch
        return self.model(x)

    def _set_metrics(self):
        """Create separate accuracy trackers for training and testing."""
        self.train_metrics = torchmetrics.Accuracy("multiclass",
                                                   num_classes=self.params.num_classes,
                                                   average="micro")
        self.test_metrics = torchmetrics.Accuracy("multiclass",
                                                  num_classes=self.params.num_classes,
                                                  average="micro")

    def training_step(self, batch, batch_idx):
        loss = self._forward_pass(batch, self.train_metrics)
        self.log("train_loss", loss, on_step=True, on_epoch=True)
        # Logging the Metric object lets Lightning compute and reset it at epoch end.
        self.log("train_acc", self.train_metrics, on_step=True, on_epoch=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss = self._forward_pass(batch, self.test_metrics)
        self.log("test_loss", loss)
        self.log("test_acc", self.test_metrics)

    def configure_optimizers(self):
        """Instantiate ``params.optimizer`` over the model's parameters."""
        return self.params.optimizer(self.model.parameters(),
                                     lr=self.params.learning_rate)

    def _forward_pass(self, batch, metrics):
        """Return the loss for ``batch`` and update ``metrics`` with its predictions."""
        x, y = batch
        logits = self.model(x)
        # Calling the metric accumulates this batch into its running state.
        metrics(logits, y)
        return self.params.criterion(logits, y)

In [17]:
# Experiment configuration, grouped by concern. `local_cache_limit` is only used as
# the shard cache budget of the remote (streaming) datasets.
experiment = Hyperparameters(
    # task / data
    task="multiclass_classification",
    random_seed=42,
    num_classes=1000,
    test_split=0.3,
    metrics=["accuracy", "f1score"],
    # optimisation
    learning_rate=1e-6,
    batch_size=64,
    num_workers=16,
    optimizer=torch.optim.Adam,
    criterion=torch.nn.CrossEntropyLoss(),
    # streaming cache
    local_cache_limit="10gb",
)

pl.seed_everything(seed=experiment.random_seed);

Global seed set to 42


In [18]:
# Data module, backbone and Lightning wrapper, all driven by the same experiment config.
imagenet = ImagenetDataModule(root=IMAGENET, params=experiment)
alexnet = AlexNet(num_classes=experiment.num_classes)
classifier = ClassificationModel(model=alexnet, params=experiment)

In [12]:
# Smoke test: fast_dev_run pushes a single batch through fit(), with logging and
# checkpointing suppressed.
trainer = pl.Trainer(fast_dev_run=True)
trainer.fit(classifier, datamodule=imagenet)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.


AttributeError: 'ImageDataLoader' object has no attribute '_standard_transform