In [None]:
# Mount Google Drive at /content/drive so the project files stored under
# MyDrive are reachable from this Colab runtime.
import google.colab.drive as drive

drive.mount('/content/drive')

In [None]:
# Sanity check: listing the Tiny-ImageNet train folder should print its class
# sub-folders. If this fails, check that Drive is mounted (previous cell) and
# that the path below is correct.
!ls /content/drive/MyDrive/18794-Diffusion-Project/tiny-imagenet-200/train

In [None]:
#test
import os

base_path = "/content/drive/MyDrive/18794-Diffusion-Project/tiny-imagenet-200"
print("Base path exists:", os.path.exists(base_path))
print("Train path exists:", os.path.exists(os.path.join(base_path, "train")))
!ls -d /content/drive/MyDrive/18794-Diffusion-Project/tiny-imagenet-200/train/* | head

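Data loader: build a dataset restricted to the selected Tiny-ImageNet classes and load one batch as a sanity check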

In [None]:
import os, glob
from collections import Counter
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Dataset location. Set the TINY_IMAGENET_DIR environment variable to point at a
# different copy; otherwise the project's Drive folder is used.
DATASET_DIR = os.environ.get(
    "TINY_IMAGENET_DIR",
    "/content/drive/MyDrive/18794-Diffusion-Project/tiny-imagenet-200",
)
TRAIN_ROOT = os.path.join(DATASET_DIR, "train")

# WordNet ids (wnids) of the Tiny-ImageNet classes to use. The order of this
# list is the order in which the dataset assigns class labels.
SELECTED_CLASSES = [
    'n02123045',  # tabby cat
    'n02504458',  # African elephant
    'n01641577',  # bullfrog
    'n01443537',  # goldfish
    'n01629819',  # European fire salamander (an amphibian, not a lizard)
    'n01742172',  # boa constrictor
    'n01855672',  # goose
    'n01910747',  # jellyfish
    'n01944390',  # snail
]
# Labels are keyed by wnid, so a repeated entry is almost certainly a typo.
assert len(set(SELECTED_CLASSES)) == len(SELECTED_CLASSES), "duplicate wnid in SELECTED_CLASSES"

# Tiny-ImageNet images are natively 64x64. Resize(64) scales the shorter side to
# 64 and CenterCrop(64) then cuts a centered 64x64 square, so native images stay
# 64x64 and any stray image of another size is brought to the same shape.
# ToTensor turns the PIL image into a float C x H x W tensor scaled to [0, 1].
train_transform = transforms.Compose(
    [transforms.Resize(64), transforms.CenterCrop(64), transforms.ToTensor()]
)

# Recognised image file extensions: lowercase, without the leading dot.
EXTS = "jpg jpeg png bmp tif tiff webp ppm".split()

def collect_paths_for_class(train_root, wnid, exts=None):
    """
    Return a sorted list of image paths for a given wnid.
    Handles both: train/<wnid>/*.JPEG  and train/<wnid>/images/*.JPEG

    Args:
        train_root: Directory containing one sub-folder per wnid.
        wnid: Name of the class folder to scan.
        exts: Optional iterable of allowed file extensions (with or without the
            leading dot). Defaults to the module-level EXTS. Matching is
            case-insensitive, so "jpeg" also accepts ".JPEG" and ".Jpeg".

    Returns:
        Sorted list of regular-file paths. The list is empty if neither folder
        exists or nothing matches.
    """
    allowed = {e.lower().lstrip(".") for e in (EXTS if exts is None else exts)}
    paths = []
    for folder in (os.path.join(train_root, wnid), os.path.join(train_root, wnid, "images")):
        if not os.path.isdir(folder):
            continue
        # One directory listing per folder replaces one glob per extension and
        # case combination, which matters on a Drive FUSE mount. Comparing the
        # lower-cased suffix also cannot list a file twice and does not miss
        # mixed-case names.
        with os.scandir(folder) as entries:
            for entry in entries:
                # glob's "*" never matched dot-files (e.g. "._x.JPEG" stubs); keep that.
                if entry.name.startswith("."):
                    continue
                suffix = os.path.splitext(entry.name)[1][1:].lower()
                # is_file() excludes directories whose name merely ends in ".jpg".
                if suffix in allowed and entry.is_file():
                    paths.append(entry.path)
    return sorted(paths)

class TinyImageNetSubset(Dataset):
    """Tiny-ImageNet training images restricted to a chosen list of class wnids.

    Each item is ``(image, label)``. ``label`` is a contiguous id in ``0..K-1``,
    assigned in the order the wnids appear in ``selected_classes`` and counting
    only classes whose folder exists under ``train_root`` (K == len(self.selected)).

    Args:
        train_root: Directory containing one sub-folder per wnid.
        selected_classes: Iterable of wnids to keep. Repeated wnids are ignored.
        transform: Optional callable applied to the RGB PIL image.

    Attributes:
        selected: Requested wnids (in request order) whose folder was found.
        missing: Requested wnids with no folder under ``train_root``.
        class_to_new / new_to_class: wnid <-> contiguous label maps.
        samples: List of ``(path, label)`` pairs.

    Raises:
        FileNotFoundError: If ``train_root`` is not a directory, or no images
            were found for any selected class.
    """

    def __init__(self, train_root, selected_classes, transform=None):
        if not os.path.isdir(train_root):
            raise FileNotFoundError(f"train_root not found: {train_root}")

        self.transform = transform

        # dict.fromkeys de-duplicates while keeping the caller's order, so a wnid
        # listed twice is neither loaded twice nor given two labels.
        requested = list(dict.fromkeys(selected_classes))
        self.selected = [w for w in requested if os.path.isdir(os.path.join(train_root, w))]
        found = set(self.selected)
        self.missing = [w for w in requested if w not in found]

        # Number labels over the classes actually present (not their position in
        # selected_classes) so they stay within 0..K-1 when a folder is missing.
        self.class_to_new, self.new_to_class = self._build_label_maps(self.selected)

        # Collect (path, label) pairs only from the selected classes.
        self.samples = []
        for wnid in self.selected:
            files = collect_paths_for_class(train_root, wnid)
            if not files:
                print(f"[warn] No images found for {wnid} under {train_root}/{wnid} or {train_root}/{wnid}/images")
            label = self.class_to_new[wnid]
            self.samples.extend((p, label) for p in files)

        if not self.samples:
            raise FileNotFoundError(
                "No images found for the selected classes. "
                "Check folder names, structure, or permissions."
            )

        self._print_summary()

    @staticmethod
    def _build_label_maps(wnids):
        """Return (wnid -> 0..len(wnids)-1, inverse) maps in the given order."""
        class_to_new = {wnid: i for i, wnid in enumerate(wnids)}
        new_to_class = {i: wnid for wnid, i in class_to_new.items()}
        return class_to_new, new_to_class

    def _print_summary(self):
        """Print missing classes and per-label image counts (debug aid)."""
        counts = Counter(label for _, label in self.samples)
        print("\n--- Subset summary ---")
        if self.missing:
            print("Missing classes:", self.missing)
        else:
            print("Missing classes: (none)")
        for new_id in sorted(counts):
            print(f"{new_id:2d} -> {self.new_to_class[new_id]}: {counts[new_id]} images")
        print(f"TOTAL images in subset: {len(self.samples)}\n")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, target = self.samples[idx]
        # The context manager closes the underlying file. convert() returns a
        # fully decoded copy, so the returned image no longer needs that file.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, target


# Build the class subset and a shuffling loader over it.
subset = TinyImageNetSubset(TRAIN_ROOT, SELECTED_CLASSES, transform=train_transform)

# Seed the loader's shuffle so the sampled batch (and the preview grid drawn
# from it) is reproducible across Restart & Run All.
SEED = 0
loader = DataLoader(
    subset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
    # Pinned host memory only speeds up copies to an accelerator; skip it on
    # CPU-only runtimes, where PyTorch warns that it has no effect.
    pin_memory=torch.cuda.is_available(),
    generator=torch.Generator().manual_seed(SEED),
)

# Debug: pull one batch to check tensor shapes end to end.
imgs, labels = next(iter(loader))
print("Batch:", imgs.shape, labels.shape)


Show a grid of the loaded images

In [None]:
# Preview up to 16 images from the debug batch, tiled in rows of 8.
import matplotlib.pyplot as plt
import torchvision

grid = torchvision.utils.make_grid(imgs[:16], nrow=8)
fig, ax = plt.subplots(figsize=(8, 4))
# make_grid returns C x H x W; imshow expects H x W x C.
ax.imshow(grid.permute(1, 2, 0).cpu().numpy())
ax.set_axis_off()
plt.show()
