# Install dependencies

In [None]:
!pip install torch torchvision wandb timm

# Imports and configuration

In [None]:
import os
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as T
import torchvision.models as models
from PIL import Image
from torch.utils.data import Dataset

# Experiment-wide settings shared by the cells below.
PROJECT_NAME = "pet-breed-classifier"  # presumably the wandb project name — confirm at use site
NUM_CLASSES = 37  # number of target classes (dataset labels are 0..36)
IMG_SIZE = 224  # presumably the square input resolution — confirm in the transforms
BATCH_SIZE = 32
EPOCHS = 5  # adjust to train for more or fewer epochs

# Pick the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
# Kept as a plain string ("cuda"/"cpu") so later cells can compare or print it.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

device  # display the selected device as the notebook cell output


# Dataset class

In [None]:
class OxfordPetsDataset(Dataset):
    """
    Oxford-IIIT Pet Dataset loader using the trainval/test split files.

    Each item is ``(image, label)``: ``image`` is the RGB PIL image (or
    whatever ``transform`` returns for it) and ``label`` is the zero-based
    class index parsed from the split file.
    """

    def __init__(self, root, split="trainval", transform=None):
        """
        Args:
            root: path to dataset root containing 'images/' and 'annotations/'.
            split: name of the split file inside 'annotations/' without the
                '.txt' suffix, e.g. 'trainval' or 'test'.
            transform: optional callable applied to each loaded RGB image.

        Raises:
            FileNotFoundError: if 'annotations/<split>.txt' does not exist.
            ValueError: if a non-blank, non-comment line has fewer than two
                fields or a class id that is not an integer.
        """
        self.root = root
        self.transform = transform

        annot_file = os.path.join(root, "annotations", f"{split}.txt")
        self.samples = []  # list of (image path, zero-based label)

        with open(annot_file, "r", encoding="utf-8") as f:
            for line_no, line in enumerate(f, start=1):
                parts = line.split()
                # Skip blank lines (e.g. a stray empty line) and '#'-prefixed
                # comment/header lines instead of crashing on them.
                if not parts or parts[0].startswith("#"):
                    continue
                if len(parts) < 2:
                    raise ValueError(
                        f"{annot_file}:{line_no}: expected "
                        f"'<image_id> <class_id> ...', got {line.strip()!r}"
                    )

                img_id = parts[0]           # e.g. 'Abyssinian_1'
                class_id = int(parts[1])    # 1..37 in the annotation file

                img_path = os.path.join(root, "images", img_id + ".jpg")
                label = class_id - 1        # convert to 0..36
                self.samples.append((img_path, label))

    def __len__(self):
        """Return the number of samples listed in the split file."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``."""
        img_path, label = self.samples[idx]
        # Open inside a context manager so the file handle is released
        # deterministically; convert() returns a new, independent image.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, label
