In [1]:
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.preprocessing import LabelEncoder

import torch
from torch.utils.data import DataLoader
import torchvision
import imageio.v3 as iio
import torchdata.datapipes as dp
import pytorch_lightning as pl

from hyperparameters import Hyperparameters
from tqdm import tqdm

%load_ext dotenv
%dotenv -o

In [2]:
# All datasets live under ~/datasets, one sub-directory per dataset.
DATA_ROOT = Path.home().joinpath("datasets")
RESISC = DATA_ROOT.joinpath("resisc-45")
IMAGENET = DATA_ROOT.joinpath("tiny-imagenet")
CIFAR10 = DATA_ROOT.joinpath("cifar10")
MLRSNET = DATA_ROOT.joinpath("mlrs-net")

In [3]:
@dp.functional_datapipe("load_image_label_pair")
class ImageLabelPairLoader(dp.iter.IterDataPipe):
    """Decode ``(file_name, file_stream)`` pairs into ``(image, label)`` tensors.

    Registered as the functional form ``.load_image_label_pair(label_encoder)``.

    Args:
        source_datapipe: Yields ``(file_name, file_stream)`` tuples, e.g. the
            output of ``open_files`` / ``open_files_by_fsspec``.
        label_encoder: Fitted encoder whose ``transform`` maps a list of class
            names to integer indices. The class name of a file is the name of
            its parent directory.

    Yields:
        ``(image, label)`` where ``image`` is a float32 ``(C, H, W)`` tensor of
        raw (unnormalised) pixel values and ``label`` is the encoder output for
        a single name, i.e. a tensor of shape ``(1,)``. Collated label batches
        are therefore ``(B, 1)``; flatten them for losses expecting ``(B,)``.
    """

    def __init__(self, source_datapipe, label_encoder) -> None:
        self.source_datapipe = source_datapipe
        self.label_encoder = label_encoder

    def __iter__(self):
        for file_name, file_stream in self.source_datapipe:
            try:
                image = self._load_image(file_stream)
            finally:
                # Release the handle as soon as the bytes are decoded instead
                # of waiting for the wrapper to be garbage-collected.
                file_stream.close()
            yield (image, self._encode_label(file_name))

    def _load_image(self, file_stream) -> torch.Tensor:
        """Decode a JPEG stream into a float32 ``(3, H, W)`` tensor."""
        # mode="RGB" guarantees a 3-channel (H, W, 3) array, so grayscale or
        # RGBA files do not break the channel-first transpose below.
        pixels = iio.imread(uri = file_stream,
                            plugin = "pillow",
                            extension = ".jpg",
                            mode = "RGB")
        return torch.from_numpy(pixels.astype(np.float32).transpose(2, 0, 1))

    def _encode_label(self, file_name) -> torch.Tensor:
        """Encode the parent-directory name of ``file_name`` as a class index."""
        # Path handles "/" (and "\" on Windows) as well as "s3://bucket/..." URLs.
        class_name = Path(file_name).parent.name
        return torch.from_numpy(self.label_encoder.transform([class_name]))

In [5]:
#connection_kwargs = {"endpoint_url": "https://usc1.contabostorage.com"}
#pipe = dp.iter.FSSpecFileLister("s3://resisc-45", **connection_kwargs)
 ## type: ignore
#class_names = sorted([Path(path).stem for path in iter(pipe)])
#label_encoder = LabelEncoder().fit(class_names)

#train_pipe, test_pipe = (dp.iter.FSSpecFileLister(pipe, **connection_kwargs)
                                #.demux(2, train_test_split, buffer_size = 700 * 45))

#train_pipe = (train_pipe.shuffle(buffer_size = 545 * 45)
                        #.open_files_by_fsspec("rb", **connection_kwargs) # type: ignore
                        #.load_image_label_pair(label_encoder)
             #)
##train_pipe = ImageLabelPairLoader(train_pipe, label_encoder)

#test_pipe = (test_pipe.open_files_by_fsspec("rb", **connection_kwargs) # type: ignore
                      #.load_image_label_pair(label_encoder)
            #)

In [6]:
@dp.functional_datapipe("apply_image_transforms")
class ImageTransformer(dp.iter.IterDataPipe):
    """Apply an image transform to ``(image, annotation)`` pairs.

    Registered as the functional form ``.apply_image_transforms(transforms)``.

    Args:
        source_datapipe: Yields ``(image, annotation)`` tuples; annotations are
            passed through unchanged.
        transforms: Callable applied to every image. When ``None``, images are
            divided by 255 and resized with ``Resize(256, antialias=True)``.
    """

    def __init__(self, source_datapipe, transforms = None) -> None:
        self.source_datapipe = source_datapipe
        self.transforms = transforms
        # Build the default resize once instead of once per image.
        self._resize = torchvision.transforms.Resize(256, antialias=True) # type: ignore

    def __iter__(self):
        # The choice of transform does not change during iteration, so make it once.
        transform = self.transforms if self.transforms is not None else self._standard_transforms
        for image, annotation in self.source_datapipe:
            yield (transform(image), annotation)

    def _standard_transforms(self, image: torch.Tensor):
        """Scale pixel values by 1/255, then resize so the shorter edge is 256."""
        return self._resize(image / 255.0)

In [44]:
class ResiscDataModule(pl.LightningDataModule):
    """LightningDataModule serving RESISC-45 through torchdata DataPipes.

    ``root`` must contain one sub-directory per class whose images are named
    ``<class_name>_<idx>.jpg``. Images with
    ``idx <= int(IMAGES_PER_CLASS * params.test_split)`` form the test split,
    all others the train split.

    Args:
        root: Local dataset directory, or a bucket URL (e.g. ``"s3://resisc-45"``)
            when ``is_s3_bucket`` is True.
        is_s3_bucket: Read through fsspec from ``CONNECTION_KWARGS`` instead of
            the local filesystem.
        params: Experiment hyperparameters; this module reads ``num_classes``,
            ``test_split``, ``batch_size`` and ``num_workers``.
    """

    # The split and length arithmetic assumes every class holds exactly this
    # many images, indexed 1..IMAGES_PER_CLASS.
    IMAGES_PER_CLASS = 700
    # fsspec connection options for the S3-compatible object store.
    CONNECTION_KWARGS = {"endpoint_url": "https://usc1.contabostorage.com"}

    def __init__(self, root: Path|str, is_s3_bucket: bool, params: Hyperparameters):
        super().__init__()
        self.root = root
        self.params = params
        self.is_s3_bucket = is_s3_bucket

    def setup(self, stage):
        """Build the datapipe graph for ``stage`` ("fit" or "test")."""
        if self.is_s3_bucket:
            self._remote_datapipe()
        else:
            self._local_datapipe()

        num_classes = self.params.num_classes
        num_test = self._num_test_per_class()

        if stage == "fit":
            #self.train_dp, self.val_dp = (
                    #self.train_dp.random_split(weights={"train": (1-self.params.val_split), "valid": self.params.val_split}, 
                                               #total_length = 700 * self.params.num_classes,
                                               #seed=self.params.random_seed)
            #)

            # Shuffle and shard *paths*, and only then open the files: a
            # Shuffler after the opener would keep up to buffer_size file
            # handles open, and without a ShardingFilter every DataLoader
            # worker would iterate over the whole dataset.
            self.train_dp = dp.iter.Shuffler(self.train_dp, buffer_size = self.IMAGES_PER_CLASS * num_classes) #type: ignore
            self.train_dp = dp.iter.ShardingFilter(self.train_dp)
            self.train_dp = self._open_files(self.train_dp)
            self.train_dp = ImageLabelPairLoader(self.train_dp, self.label_encoder)
            self.train_dp = ImageTransformer(self.train_dp)
            self.train_dp = self.train_dp.set_length((self.IMAGES_PER_CLASS - num_test) * num_classes)
            #TODO: multiply len by (1-self.params.val_split) ?

            #self.val_dp = ImageLabelPairLoader(self.val_dp, self.label_encoder)
            #self.val_dp = ImageTransformer(self.val_dp)

        if stage == "test":
            self.test_dp = dp.iter.ShardingFilter(self.test_dp)
            self.test_dp = self._open_files(self.test_dp)
            self.test_dp = ImageLabelPairLoader(self.test_dp, self.label_encoder)
            self.test_dp = ImageTransformer(self.test_dp)
            self.test_dp = self.test_dp.set_length(num_test * num_classes)

    def train_dataloader(self):
        return DataLoader(dataset = self.train_dp, 
                          batch_size = self.params.batch_size,
                          num_workers = self.params.num_workers,
                          shuffle = True)
    
    #def val_dataloader(self):
        #return DataLoader(dataset = self.val_dp, 
                          #batch_size = self.params.batch_size,
                          #num_workers = self.params.num_workers)

    def test_dataloader(self):
        return DataLoader(dataset = self.test_dp, 
                          batch_size = self.params.batch_size,
                          num_workers = self.params.num_workers)

    def _set_label_encoder(self, dir_level_iter):
        """Fit ``self.label_encoder`` on the sorted stems of the class directories."""
        class_names = sorted([Path(path).stem for path in dir_level_iter])
        self.label_encoder = LabelEncoder().fit(class_names)

    def _num_test_per_class(self) -> int:
        """Number of images per class that go to the test split."""
        return int(self.IMAGES_PER_CLASS * self.params.test_split)

    def _train_test_split(self, path: str) -> int:
        """demux classifier: 1 (test) if the file index is within the test quota, else 0 (train)."""
        # "<class_name>_<idx>.jpg" -> idx; class names may themselves contain "_".
        idx = int(Path(path).stem.rsplit("_", 1)[-1])
        return int(idx <= self._num_test_per_class())

    def _open_files(self, pipe):
        """Open every path yielded by ``pipe`` as a binary stream (local or fsspec)."""
        if self.is_s3_bucket:
            return pipe.open_files_by_fsspec("rb", **self.CONNECTION_KWARGS) # type: ignore
        return pipe.open_files("b") # type: ignore

    def _remote_datapipe(self):
        """List the bucket, fit the label encoder and split image paths into train/test pipes."""
        class_dirs = dp.iter.FSSpecFileLister(self.root, **self.CONNECTION_KWARGS) #type: ignore
        self._set_label_encoder(iter(class_dirs))

        pipe = dp.iter.FSSpecFileLister(class_dirs, masks = "*.jpg", **self.CONNECTION_KWARGS)
        self.train_dp, self.test_dp = pipe.demux(num_instances=2, 
                                                 classifier_fn=self._train_test_split,
                                                 buffer_size=self.IMAGES_PER_CLASS*self.params.num_classes) 
    
    def _local_datapipe(self):
        """List ``root`` on disk, fit the label encoder and split image paths into train/test pipes."""
        # root may be given as str; only directories are classes.
        root = Path(self.root)
        self._set_label_encoder(entry for entry in root.iterdir() if entry.is_dir())

        # Only list images so stray files cannot break the index parsing.
        pipe = dp.iter.FileLister(root.as_posix(), masks = "*.jpg", recursive=True) #type: ignore
        self.train_dp, self.test_dp = pipe.demux(num_instances=2, 
                                                 classifier_fn=self._train_test_split,
                                                 buffer_size=self.IMAGES_PER_CLASS*self.params.num_classes) 

In [45]:
# Seed python, numpy and torch RNGs so the datapipe shuffling is reproducible;
# the same value is recorded in the hyperparameters.
RANDOM_SEED = 69
pl.seed_everything(RANDOM_SEED)

experiment = Hyperparameters(
    task = "multiclass-classification",
    random_seed = RANDOM_SEED,
    num_classes = 45,
    test_split = .25,
    metrics = ["accuracy", "f1score"],
    learning_rate = 1e-5,
    batch_size = 64,
    num_workers = 16,
    optimizer = torch.optim.Adam,
    criterion = torch.nn.CrossEntropyLoss(),
)

resisc_dm = ResiscDataModule(root = RESISC, is_s3_bucket = False, params = experiment)

In [46]:
# Build the training pipeline and show the DataLoader's reported batch count.
resisc_dm.setup(stage = "fit")
resisc_dl = resisc_dm.train_dataloader()
len(resisc_dl)
# viz_batch is defined in a later cell; run that cell before uncommenting.
#viz_batch(next(iter(resisc_dl)), resisc_dm.label_encoder)

370

In [None]:
# Image paths are found recursively under Images/ and indexed by file name;
# label rows come from the CSV files in Labels/ and are indexed by their
# "image" column. The two are joined on that file name.
image_df = pd.DataFrame({"image_path": sorted((MLRSNET / "Images").rglob("*.jpg"))})
image_df["image"] = image_df["image_path"].map(lambda path: path.name)
image_df = image_df.set_index("image")

label_csvs = sorted((MLRSNET / "Labels").glob("*.csv"))
label_df = pd.concat([pd.read_csv(csv_path) for csv_path in label_csvs])
label_df = label_df.set_index("image")

assert len(label_df) == len(image_df), "#images != #labels"
# Equal counts are not enough: duplicate keys would multiply rows in the join.
assert image_df.index.is_unique, "duplicate image file names"
assert label_df.index.is_unique, "duplicate label rows"

df = label_df.join(image_df, sort = True)
assert df["image_path"].notna().all(), "some labelled images were not found on disk"

In [None]:
## How to split MLRSNet into train and test
# Hold out an equal number of test samples (500?) per class, or a number proportional to class size?
# Instinct says equal.
# Per-label column sums (per-label image counts if labels are 0/1), smallest first.
label_df.sum(axis = "index").sort_values(ascending = True)

In [None]:
# Label frequencies, so the class imbalance is visible before choosing a split.
label_counts = label_df.sum(axis = 0)
ax = label_counts.plot(kind = "bar", rot = 90, figsize = (10, 10))
ax.set(title = "MLRSNet label frequencies", xlabel = "label", ylabel = "count");

In [None]:
# Two datapipes over the same row order: image file paths, and one label
# vector per image (every column except image_path, selected by name rather
# than by position).
image_dp = dp.iter.IterableWrapper(df["image_path"])
label_dp = dp.iter.IterableWrapper(df.drop(columns = "image_path").to_numpy())

In [None]:
# Peek at the first label vector yielded by label_dp.
next(iter(label_dp))

In [35]:
def viz_batch(batch: tuple[torch.Tensor, torch.Tensor], le: LabelEncoder) -> None:
    """Plot a batch of images in a grid, captioned with decoded labels.

    Args:
        batch: ``(images, targets)`` pair. Each image is indexed as
            ``(C, H, W)`` and permuted to ``(H, W, C)`` for ``imshow``;
            ``targets`` holds one encoded class index per image.
        le: Fitted label encoder used to map indices back to class names.

    Images are laid out at most 8 per row. When the batch size is not a
    multiple of 8, the unused cells of the last row are hidden.

    Raises:
        ValueError: If the batch is empty.
    """
    images, targets = batch
    assert images.shape[0] == targets.shape[0], "#images != #targets"
    num_images = images.shape[0]
    if num_images == 0:
        raise ValueError("cannot visualise an empty batch")
    labels = le.inverse_transform(targets.ravel())

    ncols = min(num_images, 8)
    nrows = int(np.ceil(num_images / 8))

    figsize = 20
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
    _, axes = plt.subplots(nrows = nrows, 
                           ncols = ncols, 
                           squeeze = False,
                           figsize = (figsize, figsize * nrows / ncols))
    for idx, ax in enumerate(axes.ravel()):
        if idx >= num_images:
            # Empty cell in a partially filled last row.
            ax.set_axis_off()
            continue
        ax.imshow(images[idx].permute(1, 2, 0))
        ax.tick_params(axis = "both", which = "both", 
                       bottom = False, top = False, 
                       left = False, right = False,
                       labeltop = False, labelbottom = False, 
                       labelleft = False, labelright = False)
        ax.set_xlabel(f"{labels[idx]}({targets[idx].item()})")