In [1]:
# Persist data on google drive
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!cd /content/drive/MyDrive/COMS_4995/final_project/
!pip install git+https://github.com/openai/CLIP.git

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-c7q2uf4v
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-c7q2uf4v
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone


# Loading Required Data

In [14]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
import zipfile
import os

base_dir = '/content/drive/MyDrive/COMS_4995/final_project/'
image_dir = base_dir + 'data/train_images/'
zip_path = base_dir + "data/train_images.zip"

train_zip_path = base_dir + "data/train_images.zip"
train_extract_dir = base_dir + "data/"
test_zip_path = base_dir + "data/test_images.zip"
test_extract_dir = base_dir + "data/"

# === Unzip train_images if not already extracted ===
if not os.path.exists(train_extract_dir):
    print("Unzipping train_images.zip...")
    with zipfile.ZipFile(train_zip_path, 'r') as zip_ref:
        zip_ref.extractall(train_extract_dir)
    print("Unzipped to:", train_extract_dir)
else:
    print("train_images/ folder already exists. Skipping unzip.")

# === Unzip test_images if not already extracted ===
if not os.path.exists(test_extract_dir):
    print("Unzipping test_images.zip...")
    with zipfile.ZipFile(test_zip_path, 'r') as zip_ref:
        zip_ref.extractall(test_extract_dir)
    print("Unzipped to:", test_extract_dir)
else:
    print("test_images/ folder already exists. Skipping unzip.")

# Load CSVs
train_df = pd.read_csv(base_dir + "data/train_data.csv")
superclass_map = pd.read_csv(base_dir + "data/superclass_mapping.csv")
subclass_map = pd.read_csv(base_dir + "data/subclass_mapping.csv")

# Merge labels for readability
train_df = train_df.merge(superclass_map, left_on="superclass_index", right_index=True, suffixes=('', '_super'))
train_df = train_df.merge(subclass_map, left_on="subclass_index", right_index=True, suffixes=('', '_sub'))

# Rename for clarity
train_df.rename(columns={"class": "superclass_name", "class_sub": "subclass_name"}, inplace=True)
train_df = train_df.drop(columns=["index", "index_sub"])

train_images/ folder already exists. Skipping unzip.
test_images/ folder already exists. Skipping unzip.


# Explore/Visualize Training Data

In [187]:
import numpy as np

total_images = len(train_df)
print("Total training images:", total_images)

num_superclasses = train_df['superclass_name'].nunique()
num_subclasses = train_df['subclass_name'].nunique()
print("Unique superclasses:", num_superclasses)
print("Unique subclasses:", num_subclasses)

# === Superclass Counts with Percentages ===
superclass_counts = train_df['superclass_name'].value_counts().reset_index()
superclass_counts.columns = ['Superclass', 'Image Count']
superclass_counts['Percent of Total'] = (superclass_counts['Image Count'] / total_images * 100).round(2)

print("\nImages per Superclass:")
print(superclass_counts)

# === Group by (superclass, subclass) ===
grouped = train_df.groupby(['superclass_name', 'subclass_name']).size().reset_index(name='Image Count')

# Total per superclass (for per-superclass percentages)
super_totals = train_df.groupby('superclass_name').size().to_dict()
total_images = len(train_df)

# Compute percentages
grouped['% of Superclass'] = grouped.apply(
    lambda row: (row['Image Count'] / super_totals[row['superclass_name']] * 100), axis=1
).round(2)

grouped['% of Total'] = (grouped['Image Count'] / total_images * 100).round(2)

# Sort and display
grouped = grouped.sort_values(by=['superclass_name', 'Image Count'], ascending=[True, False])

# === Global (across all data) stats ===
global_percents = grouped['% of Total'].values
print("\n Stats on % of Total Dataset per (superclass, subclass):")
print(f"  Min:    {global_percents.min():.2f}%")
print(f"  Max:    {global_percents.max():.2f}%")
print(f"  Mean:   {global_percents.mean():.2f}%")
print(f"  Median: {np.median(global_percents):.2f}%")
print(f"  Std:    {global_percents.std():.2f}%")

# === Per-superclass relative stats ===
relative_percents = grouped['% of Superclass'].values
print("\n Stats on % of Superclass per subclass:")
print(f"  Min:    {relative_percents.min():.2f}%")
print(f"  Max:    {relative_percents.max():.2f}%")
print(f"  Mean:   {relative_percents.mean():.2f}%")
print(f"  Median: {np.median(relative_percents):.2f}%")
print(f"  Std:    {relative_percents.std():.2f}%")


superclasses = grouped['superclass_name'].unique()

for sc in superclasses:
    print(f"\n Stats for subclasses within superclass: {sc}")

    subset = grouped[grouped['superclass_name'] == sc]['% of Superclass'].values
    print(f"  Number of subclasses: {len(subset)}")
    print(f"  Min:    {subset.min():.2f}%")
    print(f"  Max:    {subset.max():.2f}%")
    print(f"  Mean:   {subset.mean():.2f}%")
    print(f"  Median: {np.median(subset):.2f}%")
    print(f"  Std:    {subset.std():.2f}%")

max_count = grouped['Image Count'].max()
min_count = grouped['Image Count'].min()

print("\nSubclass with most images:")
print(f"  {max_count} images")

print("Subclass with fewest images:")
print(f"  {min_count} images")

Total training images: 6288
Unique superclasses: 3
Unique subclasses: 87

Images per Superclass:
  Superclass  Image Count  Percent of Total
0    reptile         2354             37.44
1        dog         2084             33.14
2       bird         1850             29.42

 Stats on % of Total Dataset per (superclass, subclass):
  Min:    0.78%
  Max:    1.62%
  Mean:   1.15%
  Median: 0.80%
  Std:    0.40%

 Stats on % of Superclass per subclass:
  Min:    2.08%
  Max:    5.41%
  Mean:   3.45%
  Median: 2.70%
  Std:    1.16%

 Stats for subclasses within superclass: bird
  Number of subclasses: 29
  Min:    2.70%
  Max:    5.41%
  Mean:   3.45%
  Median: 2.70%
  Std:    1.21%

 Stats for subclasses within superclass: dog
  Number of subclasses: 29
  Min:    2.35%
  Max:    4.80%
  Mean:   3.45%
  Median: 2.35%
  Std:    1.22%

 Stats for subclasses within superclass: reptile
  Number of subclasses: 29
  Min:    2.08%
  Max:    4.33%
  Mean:   3.45%
  Median: 4.25%
  Std:    1.06%

Sub

In [16]:
train_df.head()

Unnamed: 0,image,superclass_index,subclass_index,description,superclass_name,subclass_name
0,0.jpg,1,37,"nature photograph of a dog, specifically a Mal...",dog,"Maltese dog, Maltese terrier, Maltese"
1,1.jpg,0,42,"nature photograph of a bird, specifically a oy...",bird,"oystercatcher, oyster catcher"
2,2.jpg,1,62,"nature photograph of a dog, specifically a Afg...",dog,"Afghan hound, Afghan"
3,3.jpg,1,31,"nature photograph of a dog, specifically a Shi...",dog,Shih-Tzu
4,4.jpg,0,4,"nature photograph of a bird, specifically a gr...",bird,"great grey owl, great gray owl, Strix nebulosa"


# Dataloader

In [17]:
import os
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import torch

class MultiClassImageDataset(Dataset):
    def __init__(self, df, img_dir, transform=None):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.img_dir, row['image'])
        image = Image.open(img_path).convert('RGB')

        superclass_label = row['superclass_index']
        subclass_label = row['subclass_index']

        if self.transform:
            image = self.transform(image)

        return image, superclass_label, subclass_label

class MultiClassImageTestDataset(Dataset):
    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len([fname for fname in os.listdir(self.img_dir)])

    def __getitem__(self, idx):
        img_name = str(idx) + '.jpg'
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, img_name

# Incorporate Tiny ImageNet for example images of novel data

In [18]:
from datasets import load_dataset
!pip install -U datasets
ds_train = load_dataset("slegroux/tiny-imagenet-200-clean", split="train")



In [19]:
!wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
!unzip tiny-imagenet-200.zip tiny-imagenet-200/words.txt

URL transformed to HTTPS due to an HSTS policy
--2025-05-13 18:42:03--  https://cs231n.stanford.edu/tiny-imagenet-200.zip
Resolving cs231n.stanford.edu (cs231n.stanford.edu)... 171.64.64.64
Connecting to cs231n.stanford.edu (cs231n.stanford.edu)|171.64.64.64|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 248100043 (237M) [application/zip]
Saving to: ‘tiny-imagenet-200.zip.1’


2025-05-13 18:42:08 (51.2 MB/s) - ‘tiny-imagenet-200.zip.1’ saved [248100043/248100043]

Archive:  tiny-imagenet-200.zip
replace tiny-imagenet-200/words.txt? [y]es, [n]o, [A]ll, [N]one, [r]ename: n


In [20]:
# grab mapping of id to human readable label
wnid_to_label = {}
with open("tiny-imagenet-200/words.txt", "r") as f:
    for line in f:
        wnid, label = line.strip().split("\t")
        wnid_to_label[wnid] = label

# get the ids
synsets = ds_train.features["label"].names

# map id to human readable labels
idx_to_label = {i: wnid_to_label[wnid] for i, wnid in enumerate(synsets)}
unique_labels = set(label for id, label in idx_to_label.items())

print(f"Number of unique labels: {len(unique_labels)}")
print(unique_labels)

Number of unique labels: 200
{'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor', 'brown bear, bruin, Ursus arctos', 'oboe, hautboy, hautbois', 'military uniform', 'brain coral', "potter's wheel", 'lifeboat', 'steel arch bridge', 'triumphal arch', 'vestment', 'kimono', 'suspension bridge', 'seashore, coast, seacoast, sea-coast', 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus', 'slug', 'cockroach, roach', 'mashed potato', 'fur coat', 'bathtub, bathing tub, bath, tub', 'trilobite', 'frying pan, frypan, skillet', 'gasmask, respirator, gas helmet', 'lawn mower, mower', 'altar', 'sandal', 'confectionery, confectionary, candy store', 'scoreboard', 'banana', 'goldfish, Carassius auratus', "spider web, spider's web", 'bison', 'organ, pipe organ', 'alp', 'grasshopper, hopper', 'maypole', 'barn', 'American alligator, Alligator mississipiensis', 'albatross, mollymawk', 'apron', 'obelisk', 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui',

In [21]:
# manually extracted novel classes from dataset
non_reptile_bird_dog_labels = {
    'thatch, thatched roof',
    'parking meter',
    'syringe',
    'gondola',
    'dumbbell',
    'altar',
    'drumstick',
    'centipede',
    'cannon',
    'limousine, limo',
    'stopwatch, stop watch',
    'CD player',
    'basketball',
    'meat loaf, meatloaf',
    'chain',
    'orangutan, orang, orangutang, Pongo pygmaeus',
    'brass, memorial tablet, plaque',
    'sunglasses, dark glasses, shades',
    'walking stick, walkingstick, stick insect',
    'sulphur butterfly, sulfur butterfly',
    'sea slug, nudibranch',
    'comic book',
    'bell pepper',
    'pomegranate',
    'convertible',
    'triumphal arch',
    'punching bag, punch bag, punching ball, punchball',
    "spider web, spider's web",
    'miniskirt, mini',
    'mushroom',
    'frying pan, frypan, skillet',
    'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus',
    'bikini, two-piece',
    'cockroach, roach',
    'sewing machine',
    'cliff, drop, drop-off',
    'orange',
    'military uniform',
    'refrigerator, icebox',
    'beer bottle',
    'cauliflower',
    'slug',
    'scoreboard',
    'poncho',
    'desk',
    'guacamole',
    'bison',
    'rocking chair, rocker',
    'Christmas stocking',
    'espresso',
    'obelisk',
    'tarantula',
    'candle, taper, wax light',
    'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens',
    'jellyfish',
    'vestment',
    'water tower',
    'pop bottle, soda bottle',
    'chimpanzee, chimp, Pan troglodytes',
    'wok',
    'pretzel',
    'lampshade, lamp shade',
    'turnstile',
    'potpie',
    'go-kart',
    'suspension bridge',
    'computer keyboard, keypad',
    'snorkel',
    'barbershop',
    'banana',
    'water jug',
    'hourglass',
    'ice lolly, lolly, lollipop, popsicle',
    'cliff dwelling',
    'hog, pig, grunter, squealer, Sus scrofa',
    'sea cucumber, holothurian',
    'king penguin, Aptenodytes patagonica',
    'abacus',
    'tabby, tabby cat',
    'dining table, board',
    'wooden spoon',
    'acorn',
    'bow tie, bow-tie, bowtie',
    'pay-phone, pay-station',
    'lion, king of beasts, Panthera leo',
    'gazelle',
    'chest',
    'beaker',
    'lawn mower, mower',
    'confectionery, confectionary, candy store',
    'lemon',
    'pole',
    'steel arch bridge',
    'bullet train, bullet',
    'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM',
    'mashed potato',
    'alp',
    'guinea pig, Cavia cobaya',
    'trilobite',
    'bathtub, bathing tub, bath, tub',
    'tractor',
    'sock',
    'bucket, pail',
    'jinrikisha, ricksha, rickshaw',
    'kimono',
    'binoculars, field glasses, opera glasses',
    'pill bottle',
    'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle',
    'snail',
    'dugong, Dugong dugon',
    'pizza, pizza pie',
    'barrel, cask',
    'brown bear, bruin, Ursus arctos',
    'viaduct',
    'gasmask, respirator, gas helmet',
    "plunger, plumber's helper",
    'moving van',
    'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus',
    'broom',
    'swimming trunks, bathing trunks',
    'projectile, missile',
    'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis',
    'brain coral',
    'freight car',
    'sports car, sport car',
    'dam, dike, dyke',
    'remote control, remote',
    'sandal',
    'school bus',
    'fountain',
    'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon',
    'mantis, mantid',
    'scorpion',
    'volleyball',
    'nail',
    'trolleybus, trolley coach, trackless trolley',
    'Persian cat',
    'lifeboat',
    'teapot',
    'crane',
    'umbrella',
    'lakeside, lakeshore',
    'barn',
    'organ, pipe organ',
    'ice cream, icecream',
    'Arabian camel, dromedary, Camelus dromedarius',
    'oboe, hautboy, hautbois',
    'reel',
    'apron',
    'beacon, lighthouse, beacon light, pharos',
    'sombrero',
    'flagpole, flagstaff',
    'Egyptian cat',
    'torch',
    'bee',
    'butcher shop, meat market',
    'plate',
    'fly',
    'cardigan',
    'ox',
    "potter's wheel",
    'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria',
    'seashore, coast, seacoast, sea-coast',
    'rugby ball',
    'iPod',
    'African elephant, Loxodonta africana',
    'teddy, teddy bear',
    "academic gown, academic robe, judge's robe",
    'birdhouse',
    'bannister, banister, balustrade, balusters, handrail',
    'magnetic compass',
    'grasshopper, hopper',
    'maypole',
    'picket fence, paling',
    'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish',
    'fur coat',
    'neck brace',
    'American lobster, Northern lobster, Maine lobster, Homarus americanus',
    'backpack, back pack, knapsack, packsack, rucksack, haversack',
    'black widow, Latrodectus mactans'
}

filtered_data = [data for data in ds_train if idx_to_label[data['label']] in non_reptile_bird_dog_labels]
print(len(filtered_data))


88244


In [22]:
from collections import Counter

num_samples = 2000
sampled_filtered_ds = ds_train.shuffle(seed=42).select(range(num_samples))
label_counts = Counter([data['label'] for data in sampled_filtered_ds])
counts = list(label_counts.values())

print(f"Total labels: {sum(counts)}")
print(f"Unique labels: {len(label_counts)}")
print(f"Max samples per label: {max(counts)}")
print(f"Min samples per label: {min(counts)}")
print(f"Average samples per label: {sum(counts) / len(counts):.2f}")

Total labels: 2000
Unique labels: 200
Max samples per label: 18
Min samples per label: 3
Average samples per label: 10.00


In [26]:
from tqdm import tqdm
# move these novel samples to train_image folder
oe_image_dir = base_dir + "data/train_images/"

oe_rows = []
for i, item in tqdm(enumerate(sampled_filtered_ds), total=len(sampled_filtered_ds)):
    image = item["image"]
    filename = f"novel_{i}.jpg"
    image_path = os.path.join(oe_image_dir, filename)
    image.save(image_path)

    oe_rows.append({
        "image": filename,
        "superclass_index": 3,
        "subclass_index": 87,
        "superclass_name": "novel",
        "subclass_name": "novel"
    })

novel_df = pd.DataFrame(oe_rows)

100%|██████████| 2000/2000 [01:12<00:00, 27.47it/s] 


# Utlity Helper Methods

In [27]:
import numpy as np
from collections import Counter

def analyze_prediction_distributions(prediction_distributions):
    # === Confidence Stats ===
    for key in ["super_confidence", "sub_confidence"]:
        values = np.array(prediction_distributions[key])
        print(f"\n Stats for {key.replace('_', ' ').title()}:")
        print(f"  Count:     {len(values)}")
        print(f"  Mean:      {values.mean():.4f}")
        print(f"  Std Dev:   {values.std():.4f}")
        print(f"  Min:       {values.min():.4f}")
        print(f"  Max:       {values.max():.4f}")
        print(f"  25th pct:  {np.percentile(values, 25):.4f}")
        print(f"  Median:    {np.median(values):.4f}")
        print(f"  75th pct:  {np.percentile(values, 75):.4f}")
        below_50 = (values < 0.5).sum()
        print(f"  Below 0.5: {below_50} samples ({below_50 / len(values) * 100:.2f}%)")

    # === Prediction Counts ===
    print("\n Raw Superclass Prediction Distribution:")
    super_counts = Counter(prediction_distributions["raw_superclass_pred"])
    for label, count in sorted(super_counts.items()):
        print(f"  Superclass {label}: {count} samples")

    print("\n Raw Subclass Prediction Distribution (Top 15):")
    subclass_counts = Counter(prediction_distributions["raw_subclass_pred"])
    for label, count in subclass_counts.most_common(15):
        print(f"  Subclass {label}: {count} samples")


    print("\n Raw Subclass Prediction Distribution (Bottom 15):")
    subclass_counts = Counter(prediction_distributions["raw_subclass_pred"])
    for label, count in sorted(subclass_counts.items(), key=lambda x: x[1])[:15]:
        print(f"  Subclass {label}: {count} samples")

In [None]:
import numpy as np

# visualize the predictions
def print_confidence_stats(confidences, label):
    confidences = np.array(confidences)
    print(f"\n Stats for {label} confidence:")
    print(f"  Mean:      {confidences.mean():.4f}")
    print(f"  Std Dev:   {confidences.std():.4f}")
    print(f"  Min:       {confidences.min():.4f}")
    print(f"  Max:       {confidences.max():.4f}")
    print(f"  25th pct:  {np.percentile(confidences, 25):.4f}")
    print(f"  50th pct:  {np.percentile(confidences, 50):.4f} (median)")
    print(f"  75th pct:  {np.percentile(confidences, 75):.4f}")
    print(f"  Below 0.5: {(confidences < 0.5).sum()} samples ({(confidences < 0.5).mean()*100:.2f}%)")

In [28]:
def get_dataset(transform, use_novel_data=True):
    """
    Given a transform and whether to use novel data in training,
    returns training and validation datasets
    """
    training_data = train_df if not use_novel_data else pd.concat([train_df, novel_df], ignore_index=True)

    full_dataset = MultiClassImageDataset(training_data, img_dir=image_dir, transform=transform)

    train_size = int(0.9 * len(full_dataset))
    val_size = len(full_dataset) - train_size

    train_dataset, val_dataset = random_split(
        full_dataset, [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )

    return train_dataset, val_dataset

In [133]:
from torchvision.transforms import RandAugment
def get_transform(data_augmentation=False, model="Resnet50"):
    if model == "Resnet50":
        base_transforms = [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                 std=(0.229, 0.224, 0.225))
        ]

        if data_augmentation:
            aug_transforms = [
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(degrees=15),
                transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
                transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0))
                # RandAugment(num_ops=2, magnitude=9),
            ]
            return transforms.Compose(aug_transforms + base_transforms)
        else:
            return transforms.Compose(base_transforms)

    else:
        raise ValueError(f"No transform configured for model '{model}'")


# Train ResNet50

In [146]:
class CosineClassifier(nn.Module):
    def __init__(self, in_features, num_classes, scale=10.0):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_classes, in_features))
        nn.init.xavier_uniform_(self.weight)
        self.scale = scale

    def forward(self, x):
        x_norm = nn.functional.normalize(x, dim=1)
        w_norm = nn.functional.normalize(self.weight, dim=1)
        return self.scale * torch.matmul(x_norm, w_norm.T)

In [185]:
import torch
import torch.nn as nn
import torchvision.models as models

class ResNet50MultiHead(nn.Module):
    def __init__(
        self,
        num_superclasses=4,
        num_subclasses=88,
        use_dropout=False,
        dropout_p=0.5,
        use_nonlinear_head=False,
        use_cosine_classifier=False
    ):
        """
        ResNet-50 backbone with dual heads. Each head can be:
          - Linear (default), or
          - Non-linear MLP: Linear -> ReLU -> Dropout -> Linear

        Args:
            num_superclasses (int): Number of superclass categories.
            num_subclasses (int): Number of subclass categories.
            use_dropout (bool): Whether to apply dropout in head.
            dropout_p (float): Dropout probability.
            use_nonlinear_head (bool): If True, use 2-layer MLP head.
        """
        super().__init__()
        self.backbone = models.resnet50(pretrained=True)

        # Freeze most layers except last block
        for name, param in self.backbone.named_parameters():
            if "layer4" in name or "avgpool" in name:
                param.requires_grad = True
            else:
                param.requires_grad = False

        self.backbone.fc = nn.Identity()
        self.embedding_dim = 2048
        self.use_dropout = use_dropout
        self.use_nonlinear_head = use_nonlinear_head
        self.use_cosine_classifier = use_cosine_classifier

        # === Superclass Head ===
        if self.use_cosine_classifier:
            self.superclass_head = CosineClassifier(self.embedding_dim, num_superclasses)
        elif self.use_nonlinear_head:
            self.superclass_head = nn.Sequential(
                nn.Linear(self.embedding_dim, 512),
                nn.ReLU(),
                nn.Dropout(p=dropout_p) if use_dropout else nn.Identity(),
                nn.Linear(512, num_superclasses)
            )
        else:
            self.superclass_head = nn.Linear(self.embedding_dim, num_superclasses)

        # === Subclass Head ===
        if self.use_cosine_classifier:
            self.subclass_head = CosineClassifier(self.embedding_dim, num_subclasses)
        elif self.use_nonlinear_head:
            self.subclass_head = nn.Sequential(
                nn.Linear(self.embedding_dim, 512),
                nn.ReLU(),
                nn.Dropout(p=dropout_p) if use_dropout else nn.Identity(),
                nn.Linear(512, num_subclasses)
            )
        else:
            self.subclass_head = nn.Linear(self.embedding_dim, num_subclasses)

    def forward(self, x):
        features = self.backbone(x)
        super_logits = self.superclass_head(features)
        sub_logits = self.subclass_head(features)
        return super_logits, sub_logits

In [150]:
  # === Prediction tracking for analysis ===
resnet_training_prediction_distributions = {
  "raw_superclass_pred": [],
  "raw_subclass_pred": [],
    "super_confidence": [],
    "sub_confidence": []
  }


# Code block to train Resnet
# === Code block to train ResNet-50 ===
def train_resnet50(use_novel_data=True, use_nonlinear_head=False, use_augmentation=False, use_cosine_classifier=False, use_dropout=False, dropout_p=0.5, num_epochs=20):
    """
    Trains a ResNet-50 model with two heads (superclass and subclass).

    Args:
        use_novel_data (bool): Whether to include novel data during training.
        use_dropout (bool): Whether to use dropout before classification heads.
        dropout_p (float): Dropout probability.
        num_epochs (int): Number of training epochs.

    Returns:
        Trained model (ResNet50MultiHead)
    """
    # === Image transforms ===
    image_transforms = get_transform(data_augmentation=use_augmentation, model="Resnet50")

    # === Load dataset ===
    train_dataset, val_dataset = get_dataset(transform=image_transforms, use_novel_data=use_novel_data)
    batch_size = 64
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    # === Model, Loss, Optimizer Setup ===
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    model = ResNet50MultiHead(use_nonlinear_head=use_nonlinear_head, use_cosine_classifier=use_cosine_classifier, use_dropout=use_dropout, dropout_p=dropout_p).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # === Training loop ===
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for images, super_labels, sub_labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            images = images.to(device)
            super_labels = super_labels.to(device)
            sub_labels = sub_labels.to(device)

            optimizer.zero_grad()
            super_logits, sub_logits = model(images)

            super_loss = criterion(super_logits, super_labels)
            sub_loss = criterion(sub_logits, sub_labels)
            loss = super_loss + sub_loss

            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        avg_loss = running_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {avg_loss:.4f}")

        # === Validation ===
        model.eval()
        super_correct = sub_correct = total = 0
        super_loss_total, sub_loss_total = 0.0, 0.0
        with torch.no_grad():
            for images, super_labels, sub_labels in val_loader:
                images = images.to(device)
                super_labels = super_labels.to(device)
                sub_labels = sub_labels.to(device)

                super_logits, sub_logits = model(images)

                # === Compute and accumulate validation loss
                super_loss = criterion(super_logits, super_labels)
                sub_loss = criterion(sub_logits, sub_labels)
                super_loss_total += super_loss.item()
                sub_loss_total += sub_loss.item()

                super_preds = torch.argmax(super_logits, dim=1)
                sub_preds = torch.argmax(sub_logits, dim=1)

                super_correct += (super_preds == super_labels).sum().item()
                sub_correct += (sub_preds == sub_labels).sum().item()
                total += images.size(0)

                # Confidence tracking
                super_probs = torch.softmax(super_logits, dim=1)
                sub_probs = torch.softmax(sub_logits, dim=1)
                super_conf, _ = torch.max(super_probs, dim=1)
                sub_conf, _ = torch.max(sub_probs, dim=1)

                resnet_training_prediction_distributions["raw_superclass_pred"].extend(super_preds.cpu().tolist())
                resnet_training_prediction_distributions["raw_subclass_pred"].extend(sub_preds.cpu().tolist())
                resnet_training_prediction_distributions["super_confidence"].extend(super_conf.cpu().tolist())
                resnet_training_prediction_distributions["sub_confidence"].extend(sub_conf.cpu().tolist())

        avg_super_loss = super_loss_total / len(val_loader)
        avg_sub_loss = sub_loss_total / len(val_loader)

        print(f"Validation Accuracy | Superclass: {super_correct / total:.4f} | Subclass: {sub_correct / total:.4f}")
        print(f"Validation Loss     | Superclass: {avg_super_loss:.4f} | Subclass: {avg_sub_loss:.4f}")


    return model

In [175]:
model = train_resnet50(
    use_novel_data=False,
    use_augmentation=False,
    use_nonlinear_head=False,
    use_dropout=False,
    dropout_p=0.5,
    use_cosine_classifier=False,
    num_epochs=20
    )

Using device: cuda


Epoch 1: 100%|██████████| 89/89 [00:22<00:00,  4.00it/s]


Epoch 1/20 | Train Loss: 4.8008
Validation Accuracy | Superclass: 0.9857 | Subclass: 0.4467
Validation Loss     | Superclass: 0.4566 | Subclass: 3.5122


Epoch 2: 100%|██████████| 89/89 [00:22<00:00,  4.04it/s]


Epoch 2/20 | Train Loss: 3.5006
Validation Accuracy | Superclass: 0.9936 | Subclass: 0.6216
Validation Loss     | Superclass: 0.2522 | Subclass: 2.7408


Epoch 3: 100%|██████████| 89/89 [00:22<00:00,  4.02it/s]


Epoch 3/20 | Train Loss: 2.6864
Validation Accuracy | Superclass: 0.9936 | Subclass: 0.7472
Validation Loss     | Superclass: 0.1686 | Subclass: 2.1549


Epoch 4: 100%|██████████| 89/89 [00:22<00:00,  4.01it/s]


Epoch 4/20 | Train Loss: 2.1325
Validation Accuracy | Superclass: 0.9936 | Subclass: 0.7742
Validation Loss     | Superclass: 0.1267 | Subclass: 1.7758


Epoch 5: 100%|██████████| 89/89 [00:22<00:00,  4.02it/s]


Epoch 5/20 | Train Loss: 1.7412
Validation Accuracy | Superclass: 0.9952 | Subclass: 0.8172
Validation Loss     | Superclass: 0.1027 | Subclass: 1.4808


Epoch 6: 100%|██████████| 89/89 [00:21<00:00,  4.05it/s]


Epoch 6/20 | Train Loss: 1.4712
Validation Accuracy | Superclass: 0.9952 | Subclass: 0.8458
Validation Loss     | Superclass: 0.0823 | Subclass: 1.2509


Epoch 7: 100%|██████████| 89/89 [00:22<00:00,  4.00it/s]


Epoch 7/20 | Train Loss: 1.2592
Validation Accuracy | Superclass: 0.9952 | Subclass: 0.8633
Validation Loss     | Superclass: 0.0717 | Subclass: 1.1131


Epoch 8: 100%|██████████| 89/89 [00:22<00:00,  4.00it/s]


Epoch 8/20 | Train Loss: 1.1064
Validation Accuracy | Superclass: 0.9952 | Subclass: 0.8601
Validation Loss     | Superclass: 0.0625 | Subclass: 1.0010


Epoch 9: 100%|██████████| 89/89 [00:21<00:00,  4.05it/s]


Epoch 9/20 | Train Loss: 0.9853
Validation Accuracy | Superclass: 0.9952 | Subclass: 0.8696
Validation Loss     | Superclass: 0.0555 | Subclass: 0.9093


Epoch 10: 100%|██████████| 89/89 [00:22<00:00,  4.02it/s]


Epoch 10/20 | Train Loss: 0.8972
Validation Accuracy | Superclass: 0.9952 | Subclass: 0.8680
Validation Loss     | Superclass: 0.0501 | Subclass: 0.8421


Epoch 11: 100%|██████████| 89/89 [00:22<00:00,  4.02it/s]


Epoch 11/20 | Train Loss: 0.8078
Validation Accuracy | Superclass: 0.9952 | Subclass: 0.8744
Validation Loss     | Superclass: 0.0449 | Subclass: 0.7717


Epoch 12: 100%|██████████| 89/89 [00:22<00:00,  4.01it/s]


Epoch 12/20 | Train Loss: 0.7452
Validation Accuracy | Superclass: 0.9952 | Subclass: 0.8919
Validation Loss     | Superclass: 0.0424 | Subclass: 0.7177


Epoch 13: 100%|██████████| 89/89 [00:22<00:00,  4.02it/s]


Epoch 13/20 | Train Loss: 0.6864
Validation Accuracy | Superclass: 0.9952 | Subclass: 0.8903
Validation Loss     | Superclass: 0.0393 | Subclass: 0.6843


Epoch 14: 100%|██████████| 89/89 [00:22<00:00,  3.99it/s]


Epoch 14/20 | Train Loss: 0.6410
Validation Accuracy | Superclass: 0.9952 | Subclass: 0.8967
Validation Loss     | Superclass: 0.0364 | Subclass: 0.6438


Epoch 15: 100%|██████████| 89/89 [00:22<00:00,  4.01it/s]


Epoch 15/20 | Train Loss: 0.6005
Validation Accuracy | Superclass: 0.9952 | Subclass: 0.8871
Validation Loss     | Superclass: 0.0346 | Subclass: 0.6067


Epoch 16: 100%|██████████| 89/89 [00:22<00:00,  4.00it/s]


Epoch 16/20 | Train Loss: 0.5666
Validation Accuracy | Superclass: 0.9952 | Subclass: 0.8998
Validation Loss     | Superclass: 0.0315 | Subclass: 0.5778


Epoch 17: 100%|██████████| 89/89 [00:22<00:00,  3.98it/s]


Epoch 17/20 | Train Loss: 0.5301
Validation Accuracy | Superclass: 0.9952 | Subclass: 0.9030
Validation Loss     | Superclass: 0.0309 | Subclass: 0.5623


Epoch 18: 100%|██████████| 89/89 [00:22<00:00,  4.01it/s]


Epoch 18/20 | Train Loss: 0.5040
Validation Accuracy | Superclass: 0.9968 | Subclass: 0.9062
Validation Loss     | Superclass: 0.0305 | Subclass: 0.5330


Epoch 19: 100%|██████████| 89/89 [00:22<00:00,  4.02it/s]


Epoch 19/20 | Train Loss: 0.4807
Validation Accuracy | Superclass: 0.9952 | Subclass: 0.8935
Validation Loss     | Superclass: 0.0286 | Subclass: 0.5200


Epoch 20: 100%|██████████| 89/89 [00:22<00:00,  3.97it/s]


Epoch 20/20 | Train Loss: 0.4598
Validation Accuracy | Superclass: 0.9952 | Subclass: 0.9126
Validation Loss     | Superclass: 0.0257 | Subclass: 0.4938


In [152]:
# Check the prediction distributions
analyze_prediction_distributions(resnet_training_prediction_distributions)


 Stats for Super Confidence:
  Count:     16580
  Mean:      0.9723
  Std Dev:   0.0336
  Min:       0.4372
  Max:       0.9982
  25th pct:  0.9681
  Median:    0.9813
  75th pct:  0.9880
  Below 0.5: 7 samples (0.04%)

 Stats for Sub Confidence:
  Count:     16580
  Mean:      0.8364
  Std Dev:   0.2196
  Min:       0.0378
  Max:       0.9948
  25th pct:  0.8226
  Median:    0.9340
  75th pct:  0.9634
  Below 0.5: 1748 samples (10.54%)

 Raw Superclass Prediction Distribution:
  Superclass 0: 4014 samples
  Superclass 1: 4142 samples
  Superclass 2: 4833 samples
  Superclass 3: 3591 samples

 Raw Subclass Prediction Distribution (Top 15):
  Subclass 87: 3607 samples
  Subclass 69: 358 samples
  Subclass 76: 319 samples
  Subclass 21: 315 samples
  Subclass 30: 303 samples
  Subclass 75: 285 samples
  Subclass 44: 281 samples
  Subclass 6: 275 samples
  Subclass 27: 272 samples
  Subclass 63: 268 samples
  Subclass 70: 250 samples
  Subclass 37: 243 samples
  Subclass 52: 241 samples


# Generating CSV for Submitting to Leaderboard

In [182]:
# using these to check the distributions of predictions
test_prediction_distributions = {
    "raw_superclass_pred": [],
    "raw_subclass_pred": [],
    "super_confidence": [],
    "sub_confidence": [],
    "super_logit": [],
    "sub_logit": []
}

# These are better if trained on novel data
superclass_cutoff = 0.95 # if bleow this predict novel
subclass_cutoff = 0.6 # if below this predict novel

super_logit_percentile_cutoff = 35
sub_logit_percentile_cutoff = 50


# These are better if trained not on novel
# superclass_cutoff = 0.98 # if bleow this predict novel (better for when not trained with novel data)
# subclass_cutoff = 0.75 # if below this predict novel

# === Submission Function ===
def generate_submission_csv(model, test_dir, output_path, transform, device, include_cutoffs=True):
    model.eval()
    model.to(device)

    test_dataset = MultiClassImageTestDataset(img_dir=test_dir, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    predictions = {
        "image": [],
        "superclass_index": [],
        "subclass_index": []
    }

    with torch.no_grad():
        for images, img_names in tqdm(test_loader):
          images = images.to(device)
          super_logits, sub_logits = model(images)

          super_probs = torch.softmax(super_logits, dim=1)
          sub_probs = torch.softmax(sub_logits, dim=1)

          super_logit_val, _ = torch.max(super_logits, dim=1)
          sub_logit_val, _ = torch.max(sub_logits, dim=1)

          super_conf, super_preds = torch.max(super_probs, dim=1)
          sub_conf, sub_preds = torch.max(sub_probs, dim=1)

          # Record raw values for analysis
          test_prediction_distributions["raw_superclass_pred"].append(super_preds.item())
          test_prediction_distributions["raw_subclass_pred"].append(sub_preds.item())
          test_prediction_distributions["super_confidence"].append(super_conf.item())
          test_prediction_distributions["sub_confidence"].append(sub_conf.item())
          test_prediction_distributions["super_logit"].append(super_logit_val.item())
          test_prediction_distributions["sub_logit"].append(sub_logit_val.item())

          # === Logit Thresholds ===

          # Add novel class detection threshold
          if include_cutoffs:
            if super_conf.item() < superclass_cutoff:
                super_preds[0] = 3
            if super_preds[0] == 3 or sub_conf.item() < subclass_cutoff:
                sub_preds[0] = 87

          predictions["image"].append(img_names[0])
          predictions["superclass_index"].append(super_preds.item())
          predictions["subclass_index"].append(sub_preds.item())

    if include_cutoffs:
      # === Compute 25th percentile logit thresholds ===
      super_logit_array = np.array(test_prediction_distributions["super_logit"])
      sub_logit_array = np.array(test_prediction_distributions["sub_logit"])

      super_logit_thresh = np.percentile(super_logit_array, super_logit_percentile_cutoff)
      sub_logit_thresh = np.percentile(sub_logit_array, sub_logit_percentile_cutoff)

      print(f"Logit thresholds — Super: {super_logit_thresh:.4f}, Sub: {sub_logit_thresh:.4f}")

      # === Generate second CSV using raw logit thresholds ===
      logit_predictions = {
          "image": [],
          "superclass_index": [],
          "subclass_index": []
      }

      for i in range(len(super_logit_array)):
          super_pred = test_prediction_distributions["raw_superclass_pred"][i]
          sub_pred = test_prediction_distributions["raw_subclass_pred"][i]
          super_logit = super_logit_array[i]
          sub_logit = sub_logit_array[i]

          if super_logit < super_logit_thresh:
              super_pred = 3
          if super_pred == 3 or sub_logit < sub_logit_thresh:
              sub_pred = 87

          logit_predictions["image"].append(predictions["image"][i])
          logit_predictions["superclass_index"].append(super_pred)
          logit_predictions["subclass_index"].append(sub_pred)

      # === Save logit-based submission ===
      logit_output_path = output_path.replace(".csv", "_logit.csv")
      pd.DataFrame(logit_predictions).to_csv(logit_output_path, index=False)
      print(f"Saved logit-based submission: {logit_output_path}")

    # === Save softmax-based submission with suffix ===
    softmax_output_path = output_path.replace(".csv", "_softmax.csv")
    pd.DataFrame(predictions).to_csv(softmax_output_path, index=False)
    print(f"Saved softmax-based submission to: {softmax_output_path}")

In [186]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
generate_submission_csv(
    model=model,
    test_dir="/content/drive/MyDrive/COMS_4995/final_project/data/test_images",
    output_path="/content/drive/MyDrive/COMS_4995/final_project/results/test_predictions.csv",
    transform=get_transform(data_augmentation=False, model="Resnet50"), # this code is for resnet
    # transform=preprocess, # transorm needed for clip
    # transform=efficient_net_transforms, # transform needed for efficientnet
    device=device,
    include_cutoffs=True
)

Using device: cuda


  1%|          | 112/11180 [00:01<03:07, 58.96it/s]


KeyboardInterrupt: 

In [None]:
analyze_prediction_distributions(test_prediction_distributions)