In [None]:
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision

from torch.utils.data import Dataset, DataLoader, BatchSampler, random_split
from torchvision import transforms
from PIL import Image

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
cd /content/drive/MyDrive/CS4995/

/content/drive/MyDrive/CS4995


In [None]:
ls

example_test_predictions.csv  superclass_mapping.csv  train_data.csv
simple_cnn_demo.ipynb         [0m[01;34mtest_images[0m/            [01;34mtrain_images[0m/
subclass_mapping.csv          test_images.zip         train_images.zip


In [None]:
super_class_train = pd.read_csv('train_data.csv')
# Load dataset
super_class_train = pd.read_csv('train_data.csv')
super_class_train['image'] = super_class_train['image'].apply(
    lambda x: x if x.startswith('./train_images/train_images/') else f'./train_images/train_images/{x}'
)

In [None]:
# Add other dataset paths and labels
add_dirs = {
    './train_images/other_animals/': 3,
    './train_images/bird_sub_novel/': 0,
    './train_images/dog_sub_novel/': 1,
    './train_images/reptile_novel/': 2
}
for root_dir, label in add_dirs.items():
    img_paths = []
    for root, dirs, files in os.walk(root_dir):
        for file in files:
            full_path = os.path.join(root, file)
            img_paths.append(full_path.replace("\\", "/"))
    df_new = pd.DataFrame({
        'image': img_paths,
        'superclass_index': label,
        'description': ["no description"] * len(img_paths)
    })
    super_class_train = pd.concat([super_class_train, df_new], ignore_index=True)

# Dataset and Dataloaders


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import pandas as pd
import os
from tqdm import tqdm

# ---------------- Dataset ----------------
class SuperClassDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_path = self.df.loc[idx, 'image']
        label = self.df.loc[idx, 'superclass_index']
        img = Image.open(img_path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label

    def set_transform(self, transform):
        self.transform = transform

# ---------------- Model ----------------
class ViTBlock(nn.Module):
    def __init__(self, dim, heads=4, mlp_dim=256, dropout=0.1):
        super(ViTBlock, self).__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim),
        )

    def forward(self, x):
        x = self.norm1(x)
        attn_output, _ = self.attn(x, x, x)
        x = x + attn_output
        x = self.norm2(x)
        x = x + self.mlp(x)
        return x

class SuperClassCNNViT(nn.Module):
    def __init__(self, output_dim=4):
        super(SuperClassCNNViT, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(2), nn.Dropout2d(0.1),
            nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.patch_dim = 256
        self.num_patches = 7 * 7
        self.vit = nn.Sequential(
            ViTBlock(dim=self.patch_dim, heads=4, mlp_dim=512),
            ViTBlock(dim=self.patch_dim, heads=4, mlp_dim=512),
        )
        self.classifier = nn.Sequential(
            nn.Linear(self.patch_dim * self.num_patches, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, output_dim)
        )

    def forward(self, x):
        x = self.conv_layers(x)              # (B, 256, 7, 7)
        x = x.flatten(2).transpose(1, 2)     # (B, 49, 256)
        x = self.vit(x)                      # (B, 49, 256)
        x = x.flatten(1)                     # (B, 49*256)
        x = self.classifier(x)               # (B, output_dim)
        return x

# ---------------- Data Loading ----------------
# Load CSV
df = pd.read_csv('train_data.csv')
df['image'] = df['image'].apply(lambda x: x if x.startswith('./train_images/train_images/') else f'./train_images/train_images/{x}')

# Append additional directories
extra_dirs = {
    './train_images/other_animals/': 3,
    './train_images/bird_sub_novel/': 0,
    './train_images/dog_sub_novel/': 1,
    './train_images/reptile_novel/': 2,
}
valid_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tiff')

for root_dir, label in extra_dirs.items():
    img_paths = []
    for root, _, files in os.walk(root_dir):
        for file in files:
            if file.lower().endswith(valid_extensions):  # Only image files
                full_path = os.path.join(root, file).replace("\\", "/")
                img_paths.append(full_path)
    df_extra = pd.DataFrame({
        'image': img_paths,
        'superclass_index': label,
        'description': ["no description"] * len(img_paths)
    })
    df = pd.concat([df, df_extra], ignore_index=True)


# Transform (same for train & val)
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3),
])

# Dataset & split
dataset = SuperClassDataset(df, transform=transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Apply the same transform to both
train_dataset.dataset.set_transform(transform)
val_dataset.dataset.set_transform(transform)

# Loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)

# ---------------- Trainer ----------------
class Trainer:
    def __init__(self, model, criterion, optimizer, train_loader, val_loader, device='cuda'):
        self.model = model.to(device)
        self.criterion = criterion
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.best_acc = 0.0  # Track best validation accuracy

    def train_epoch(self):
        self.model.train()
        total_loss, correct, total = 0.0, 0, 0
        for inputs, labels in tqdm(self.train_loader, desc="Training", leave=False):
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            correct += (outputs.argmax(1) == labels).sum().item()
            total += labels.size(0)
        print(f"Training Loss: {total_loss / len(self.train_loader):.4f} | Accuracy: {correct / total * 100:.2f}%")

    def validate_epoch(self):
        self.model.eval()
        total_loss, correct, total = 0.0, 0, 0
        with torch.no_grad():
            for inputs, labels in tqdm(self.val_loader, desc="Validating", leave=False):
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                total_loss += loss.item()
                correct += (outputs.argmax(1) == labels).sum().item()
                total += labels.size(0)
        print(f"Validation Loss: {total_loss / len(self.val_loader):.4f} | Accuracy: {correct / total * 100:.2f}%")
        if correct / total * 100 > self.best_acc:
            self.best_acc = correct / total * 100
            torch.save(self.model.state_dict(), 'best_model.pth')
            print("✅ Best model saved!")


# ---------------- Run Training ----------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SuperClassCNNViT(output_dim=4)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

trainer = Trainer(model, criterion, optimizer, train_loader, val_loader, device)
for epoch in range(20):
    print(f"\nEpoch {epoch + 1}/10")
    trainer.train_epoch()
    trainer.validate_epoch()




Epoch 1/10




Training Loss: 0.8383 | Accuracy: 66.35%




Validation Loss: 0.5843 | Accuracy: 76.97%
✅ Best model saved!

Epoch 2/10




Training Loss: 0.6119 | Accuracy: 76.41%




Validation Loss: 0.4882 | Accuracy: 81.70%
✅ Best model saved!

Epoch 3/10




Training Loss: 0.5157 | Accuracy: 80.02%




Validation Loss: 0.4363 | Accuracy: 83.44%
✅ Best model saved!

Epoch 4/10




Training Loss: 0.4694 | Accuracy: 81.56%




Validation Loss: 0.3744 | Accuracy: 86.47%
✅ Best model saved!

Epoch 5/10




Training Loss: 0.4303 | Accuracy: 83.58%




Validation Loss: 0.3521 | Accuracy: 86.64%
✅ Best model saved!

Epoch 6/10




Training Loss: 0.3833 | Accuracy: 85.55%




Validation Loss: 0.3672 | Accuracy: 86.27%

Epoch 7/10




Training Loss: 0.3670 | Accuracy: 85.87%




Validation Loss: 0.3320 | Accuracy: 86.89%
✅ Best model saved!

Epoch 8/10




Training Loss: 0.3403 | Accuracy: 87.04%




Validation Loss: 0.3539 | Accuracy: 86.27%

Epoch 9/10




Training Loss: 0.3238 | Accuracy: 87.82%




Validation Loss: 0.3059 | Accuracy: 88.34%
✅ Best model saved!

Epoch 10/10




Training Loss: 0.3068 | Accuracy: 88.17%




Validation Loss: 0.2965 | Accuracy: 88.51%
✅ Best model saved!

Epoch 11/10




Training Loss: 0.2897 | Accuracy: 89.42%




Validation Loss: 0.2977 | Accuracy: 88.26%

Epoch 12/10




Training Loss: 0.2652 | Accuracy: 90.31%




Validation Loss: 0.2661 | Accuracy: 89.88%
✅ Best model saved!

Epoch 13/10




Training Loss: 0.2583 | Accuracy: 90.56%




Validation Loss: 0.2813 | Accuracy: 89.42%

Epoch 14/10




Training Loss: 0.2379 | Accuracy: 90.86%




Validation Loss: 0.2902 | Accuracy: 89.29%

Epoch 15/10




Training Loss: 0.2317 | Accuracy: 91.29%




Validation Loss: 0.2517 | Accuracy: 90.21%
✅ Best model saved!

Epoch 16/10




Training Loss: 0.2180 | Accuracy: 91.57%




Validation Loss: 0.2575 | Accuracy: 90.62%
✅ Best model saved!

Epoch 17/10




Training Loss: 0.2014 | Accuracy: 92.37%




Validation Loss: 0.2545 | Accuracy: 90.21%

Epoch 18/10




Training Loss: 0.1995 | Accuracy: 92.66%




Validation Loss: 0.2887 | Accuracy: 89.42%

Epoch 19/10




Training Loss: 0.1824 | Accuracy: 93.12%




Validation Loss: 0.2503 | Accuracy: 90.54%

Epoch 20/10




Training Loss: 0.1736 | Accuracy: 93.33%




Validation Loss: 0.2413 | Accuracy: 91.00%
✅ Best model saved!


In [None]:
model.load_state_dict(torch.load("best_model.pth"))

<All keys matched successfully>

In [None]:
torch.save(model.state_dict(), "best_superclass_model.pth")
print("✅ Model weights loaded and saved as 'reloaded_best_model.pth'")


✅ Model weights loaded and saved as 'reloaded_best_model.pth'
