In [3]:
import os
from pathlib import Path
import numpy as np
import pandas as pd
import imageio.v3 as iio
import matplotlib.pyplot as plt

from sklearn.preprocessing import LabelEncoder

import torch
import torchvision
import torchdata.datapipes as dp
import pytorch_lightning as pl

from hyperparameters import Hyperparameters
from tqdm import tqdm

In [3]:
# Every dataset lives in its own folder under ~/datasets.
DATA_ROOT = Path.home() / "datasets"
RESISC, IMAGENET, CIFAR10, MLRSNET = (
    DATA_ROOT / folder_name
    for folder_name in ("resisc-45", "tiny-imagenet", "cifar10", "mlrs-net")
)

In [4]:
#@dp.functional_datapipe("load_datapoint")
class DataPointLoader(dp.iter.IterDataPipe):
    """Turn ``(file_name, file_stream)`` pairs into ``(image, label)`` datapoints.

    The image is decoded from the stream into a float32 tensor of shape (C, H, W);
    the label is the name of the file's parent directory, passed through ``encoder``.

    Args:
        source_datapipe: iterable of ``(file_name, file_stream)`` pairs, e.g. the
            output of ``open_files_by_fsspec``.
        encoder: fitted encoder exposing ``transform`` (e.g. a sklearn ``LabelEncoder``).
        extension: format hint forwarded to ``imageio`` when decoding a stream.
            Defaults to ``".jpg"``.
    """

    def __init__(self, source_datapipe, encoder, extension: str = ".jpg") -> None:
        self.source_datapipe = source_datapipe
        self.encoder = encoder
        self.extension = extension

    def __iter__(self):
        for file_name, file_stream in self.source_datapipe:
            yield (self._load_image(file_stream),
                   self._encode_label(file_name))

    def _load_image(self, image_stream) -> torch.Tensor:
        """Decode an open image stream into a float32 tensor of shape (C, H, W)."""
        image = iio.imread(uri = image_stream,
                           plugin = "pillow",
                           extension = self.extension)
        # Single-channel images decode to (H, W); add a trailing channel axis so
        # the HWC -> CHW transpose below works for them as well.
        if image.ndim == 2:
            image = image[..., np.newaxis]
        return torch.from_numpy(image.astype(np.float32).transpose(2, 0, 1))

    def _encode_label(self, file_name) -> torch.Tensor:
        """Encode the parent-directory name of ``file_name`` via ``self.encoder``.

        With a sklearn ``LabelEncoder`` this is a length-1 integer tensor.
        """
        # The class is the second-to-last "/"-separated component, e.g.
        # "s3://resisc-45/airplane/airplane_001.jpg" -> "airplane".
        class_name = file_name.split("/")[-2]
        return torch.from_numpy(self.encoder.transform([class_name]))

In [6]:
import s3fs

In [58]:
# Stream RESISC-45 from object storage: s3://resisc-45/<class>/<image>.jpg
connection_kwargs = {"endpoint_url": "https://usc1.contabostorage.com"}
bucket_url = "s3://resisc-45"

# One listing level below the bucket yields the class directories. Fit the label
# encoder from them here, so this cell does not depend on a `le` defined in a later
# (or commented-out) cell.
class_dirs = dp.iter.IterableWrapper([bucket_url]).list_files_by_fsspec(**connection_kwargs)
le = LabelEncoder().fit([path.rstrip("/").split("/")[-1] for path in class_dirs])

pipe = (dp.iter.IterableWrapper([bucket_url])
               .list_files_by_fsspec(**connection_kwargs)   # -> class directories
               .list_files_by_fsspec(**connection_kwargs)   # -> image files in each class
               .open_files_by_fsspec(mode = 'rb', **connection_kwargs)
               )
pipe = DataPointLoader(pipe, le)
next(iter(pipe))

('s3://resisc-45/airplane/airplane_001.jpg',
 StreamWrapper<<File-like object S3FileSystem, resisc-45/airplane/airplane_001.jpg>>)

In [49]:
# Count datapoints one at a time: `len(list(pipe))` would hold every decoded float32
# image of the dataset in memory at once. (Each image is still downloaded and decoded.)
sum(1 for _ in pipe)

31500

In [None]:
#le = LabelEncoder().fit([x.name for x in RESISC.iterdir()])
#pipe = dp.iter.FileLister(RESISC.as_posix(), recursive=True) 
#pipe = dp.iter.FileOpener(pipe, mode = 'b')
#pipe = dp.utils.StreamWrapper(pipe)
#pipe = DataPointLoader(pipe, le)
#image, label = next(iter(pipe))

In [None]:
# Index MLRSNet image files by file name and join them onto the label rows.
image_df = pd.DataFrame({"image_path": sorted((MLRSNET / "Images").rglob("*.jpg"))})
image_df["image"] = image_df["image_path"].map(lambda path: path.name)
image_df = image_df.set_index("image")

# Read only the CSV files, in a deterministic order.
label_files = sorted(path for path in (MLRSNET / "Labels").iterdir()
                     if path.suffix.lower() == ".csv")
label_df = pd.concat([pd.read_csv(path) for path in label_files])
label_df = label_df.set_index("image")

# Equal lengths alone do not make the join correct: `join` is a left join, so a
# label row with no matching file gets a NaN image_path, and duplicate names
# multiply rows. Require unique names that match exactly.
assert image_df.index.is_unique, "duplicate image file names"
assert label_df.index.is_unique, "duplicate image entries in label files"
assert len(label_df) == len(image_df), "#images != #labels"
assert set(label_df.index) == set(image_df.index), "image names differ between Images/ and Labels/"

df = label_df.join(image_df, sort = True)

In [None]:
# TODO: train/test split strategy. Should each class get an equal number of test
# samples (500?) or a number proportional to its size? Current instinct: equal.
# Column totals of label_df, smallest first, to inspect class balance.
label_df.sum().sort_values(ascending=True)

In [None]:
# Column totals of label_df as a bar chart; the trailing `;` suppresses the repr output.
ax = label_df.sum(axis = 0).plot(kind = "bar", rot = 90, figsize = (10, 10))
ax.set(title = "MLRSNet: total per label column", xlabel = "label", ylabel = "count");

In [None]:
# Two parallel datapipes over `df`, in the same row order: one yields image paths,
# the other the label columns of each row. Dropping `image_path` by name (rather
# than slicing off the last column) keeps this correct if the column order changes.
image_dp = dp.iter.IterableWrapper(df["image_path"])
label_dp = dp.iter.IterableWrapper(df.drop(columns = "image_path").values)

In [None]:
# Peek at the first element yielded by label_dp (the label-column values of df's first row).
next(iter(label_dp))

In [None]:
def _subplot_grid(n_images: int, max_cols: int = 8) -> tuple[int, int]:
    """Return the (rows, cols) grid for ``n_images`` panels, at most ``max_cols`` wide."""
    if n_images <= max_cols:
        return 1, n_images
    return int(np.ceil(n_images / max_cols)), max_cols


def viz_batch(batch: tuple[torch.Tensor, torch.Tensor], le: "LabelEncoder",
              max_cols: int = 8, figsize: float = 20) -> None:
    """Plot every image of a batch in a grid, captioned with its decoded label.

    Args:
        batch: ``(images, targets)``. Each ``images[i]`` is expected channel-first
            (C, H, W), since it is permuted to (H, W, C) for ``imshow``.
        le: fitted encoder whose ``inverse_transform`` maps ``targets`` to class names.
        max_cols: maximum number of images per grid row (default 8).
        figsize: figure width in inches; the height scales with rows / cols.

    Raises:
        AssertionError: if ``images`` and ``targets`` disagree on batch size.
        ValueError: if the batch is empty.
    """
    images, targets = batch
    n_images = images.shape[0]
    assert n_images == targets.shape[0], "#images != #targets"
    if n_images == 0:
        raise ValueError("cannot visualise an empty batch")
    labels = le.inverse_transform(targets)

    n_rows, n_cols = _subplot_grid(n_images, max_cols)
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so .ravel() always works.
    _, axes = plt.subplots(nrows = n_rows,
                           ncols = n_cols,
                           figsize = (figsize, figsize * n_rows / n_cols),
                           squeeze = False)
    flat_axes = axes.ravel()
    for idx, ax in enumerate(flat_axes[:n_images]):
        ax.imshow(images[idx].permute(1, 2, 0))
        ax.tick_params(axis = "both", which = "both",
                       bottom = False, top = False,
                       left = False, right = False,
                       labeltop = False, labelbottom = False,
                       labelleft = False, labelright = False)
        ax.set_xlabel(f"{labels[idx]}({targets[idx]})")
    # The last row may be only partly filled; hide its unused panels.
    for ax in flat_axes[n_images:]:
        ax.set_axis_off()

In [None]:


# Training configuration for the RESISC-45 multiclass run.
# NOTE(review): ResiscDataModule is neither defined nor imported anywhere in this
# notebook, so this cell raises NameError on a fresh kernel. Import it (presumably
# from a project module) in the imports cell.
experiment = Hyperparameters(
    task="multiclass-classification",
    num_classes=45,
    metrics=["accuracy", "f1score"],
    learning_rate=1e-5,
    batch_size=64,
    num_workers=16,
    optimizer=torch.optim.Adam,
    criterion=torch.nn.CrossEntropyLoss(),
)

resisc_dm = ResiscDataModule(RESISC, experiment)

In [None]:
# Grab the encoder, build the training split and pull one batch to check tensor shapes.
le = resisc_dm.label_encoder
resisc_dm.setup("fit")
batch = images, targets = next(iter(resisc_dm.train_dataloader()))
print(images.shape, targets.shape, sep="\n")

In [None]:
# Render the sampled training batch with its decoded class names.
viz_batch(batch=batch, le=le)