In [41]:
from datasets import load_dataset
import numpy as np
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
import torch
import kornia as K
from PIL import Image

# Load the "train" split of the blanchon/UC_Merced dataset.
# A held-out subset (e.g. ucmerced_test) could be carved out with slice
# syntax such as split="train[80%:]"; it is not loaded here.
hf_dataset = load_dataset(
    "blanchon/UC_Merced",
    split="train",
)

In [42]:
class PreProcess(torch.nn.Module):
    """Turn a PIL image (or HxWxC array) into a float CxHxW tensor divided by 255.

    Used as the first stage of the transform pipelines, ahead of the Kornia
    augmentations, which operate on tensors rather than PIL images.
    """

    def __init__(self) -> None:
        super().__init__()

    @torch.no_grad()  # disable gradients for efficiency
    def forward(self, x: Image.Image) -> torch.Tensor:
        """Convert one image to a channel-first float tensor.

        Args:
            x: A PIL image. Anything ``np.array`` accepts (e.g. an existing
                HxWxC array) is passed through unchanged to the conversion.

        Returns:
            A float tensor in CxHxW layout whose values are the input values
            divided by 255 (i.e. in [0, 1] for 8-bit input).
        """
        # Grayscale, palette or RGBA PIL images would otherwise produce 1 or 4
        # channels (or palette indices), which does not match the 3-channel
        # mean/std passed to Normalize in the pipelines. Only PIL inputs are
        # converted; arrays are left untouched.
        if isinstance(x, Image.Image) and x.mode != "RGB":
            x = x.convert("RGB")
        x_tmp: np.ndarray = np.array(x)  # HxWxC
        x_out: torch.Tensor = K.image_to_tensor(x_tmp, keepdim=True)  # CxHxW
        return x_out.float() / 255.0

# Training-time pipeline, applied to one image at a time by preprocess_train.
# Stages run in the order listed.
train_transforms = torch.nn.Sequential(
    PreProcess(),
    K.augmentation.Resize(size=224, side="short"),
    K.augmentation.CenterCrop(size=224),
    K.augmentation.RandomHorizontalFlip(p=0.5),
    # NOTE(review): no jitter ranges are passed here; if the library defaults
    # are all zero this stage leaves the image unchanged -- confirm intent.
    K.augmentation.ColorJiggle(),
    # NOTE(review): these constants look like the commonly quoted CIFAR-10
    # channel statistics rather than values computed on this dataset -- confirm.
    K.augmentation.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
)

# Validation-time pipeline, applied to one image at a time by preprocess_val.
# Same resize / crop / normalize stages as train_transforms, without the
# random flip and color jitter stages.
val_transforms = torch.nn.Sequential(
    PreProcess(),
    K.augmentation.Resize(size=224, side="short"),
    K.augmentation.CenterCrop(size=224),
    # NOTE(review): these constants look like the commonly quoted CIFAR-10
    # channel statistics rather than values computed on this dataset -- confirm.
    K.augmentation.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
)

def preprocess_train(example_batch):
    """Apply train_transforms to every image in a batch.

    Args:
        example_batch: Mapping with an "image" entry holding an iterable of
            images accepted by train_transforms.

    Returns:
        The same mapping, mutated in place, with "image" replaced by a list
        of transformed tensors (one per input image).
    """
    # squeeze(0) drops only a leading size-1 axis (presumably the batch axis
    # the augmentation stages add); a bare squeeze() would also collapse any
    # other size-1 axis, e.g. a single-channel dimension.
    example_batch["image"] = [
        train_transforms(image).squeeze(0) for image in example_batch["image"]
    ]
    return example_batch

def preprocess_val(example_batch):
    """Apply val_transforms to every image in a batch.

    Args:
        example_batch: Mapping with an "image" entry holding an iterable of
            images accepted by val_transforms.

    Returns:
        The same mapping, mutated in place, with "image" replaced by a list
        of transformed tensors (one per input image).
    """
    # squeeze(0) drops only a leading size-1 axis (presumably the batch axis
    # the augmentation stages add); a bare squeeze() would also collapse any
    # other size-1 axis, e.g. a single-channel dimension.
    example_batch["image"] = [
        val_transforms(image).squeeze(0) for image in example_batch["image"]
    ]
    return example_batch

In [43]:
# Attach the training preprocessing to the dataset; this modifies hf_dataset
# in place, so later cells that index it receive transformed images.
hf_dataset.set_transform(preprocess_train)

In [None]:
# Fetch the first example to check what the registered transform produces.
hf_dataset[0]

In [None]:
# Wrap the transformed dataset in a shuffling loader that works in the main
# process (num_workers=0).
train_loader = DataLoader(
    hf_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=0,
)

# Smoke test: look at the first mini-batch only and report the shape of its
# stacked image tensor.
for batch in train_loader:
    image_batch = batch['image']
    print(image_batch.shape)
    break

In [None]:
# Class names as declared by the dataset's "label" feature.
labels = hf_dataset.features["label"].names
print("ALL LABELS")
print(labels)

# Two-way lookup between class name and its index; indices are stored as
# strings on both sides.
label2id = {name: str(idx) for idx, name in enumerate(labels)}
id2label = {str(idx): name for idx, name in enumerate(labels)}

print("\nALL LABELS to ID")
print(label2id)
print("\nALL ID to LABELS")
print(id2label)

In [None]:
# Inspect an example: shape of the first transformed image tensor.
# The notebook defines no `dataset` variable; the loaded data lives in
# `hf_dataset`, whose rows are addressed by column name ("image").
hf_dataset[0]["image"].shape
