In [1]:
!pip install git+https://github.com/openai/CLIP.git

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-3w0pcvm6
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-3w0pcvm6
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting ftfy (from clip==1.0)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->clip==1.0)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->clip==1.0)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->clip==1.0)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting 

In [2]:
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision

from torch.utils.data import Dataset, DataLoader, BatchSampler, random_split
from torchvision import transforms
from PIL import Image

import clip

In [3]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Training on:", device)

Training on: cuda


In [4]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
import zipfile
import os


DATA_FOLDER = '/content/drive/MyDrive/NNDL Project/Data'
TRAIN_ZIP_PATH = os.path.join(DATA_FOLDER, 'train_images.zip')
TEST_ZIP_PATH = os.path.join(DATA_FOLDER, 'test_images.zip')
EXTRACT_PATH = '/content/data'



with zipfile.ZipFile(TRAIN_ZIP_PATH, 'r') as zip_ref:
  zip_ref.extractall(EXTRACT_PATH)
with zipfile.ZipFile(TEST_ZIP_PATH, 'r') as zip_ref:
  zip_ref.extractall(EXTRACT_PATH)


train_ann_df = pd.read_csv(os.path.join(DATA_FOLDER, 'train_data.csv'))
super_map_df = pd.read_csv(os.path.join(DATA_FOLDER, 'superclass_mapping.csv'))
sub_map_df = pd.read_csv(os.path.join(DATA_FOLDER, 'subclass_mapping.csv'))

train_img_dir = os.path.join(EXTRACT_PATH, 'train_images')
test_img_dir = os.path.join(EXTRACT_PATH, 'test_images')

In [6]:
# Create Dataset class for multilabel classification
class MultiClassImageDataset(Dataset):
    def __init__(self, ann_df, super_map_df, sub_map_df, img_dir, transform=None):
        self.ann_df = ann_df
        self.super_map_df = super_map_df
        self.sub_map_df = sub_map_df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.ann_df)

    def __getitem__(self, idx):
        img_name = self.ann_df['image'][idx]
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('RGB')

        super_idx = self.ann_df['superclass_index'][idx]
        super_label = self.super_map_df['class'][super_idx]

        sub_idx = self.ann_df['subclass_index'][idx]
        sub_label = self.sub_map_df['class'][sub_idx]

        if self.transform:
            image = self.transform(image)

        return image, super_idx, super_label, sub_idx, sub_label


class MultiClassImageTestDataset(Dataset):
    def __init__(self, super_map_df, sub_map_df, img_dir, transform=None):
        self.super_map_df = super_map_df
        self.sub_map_df = sub_map_df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):  # Count files in img_dir
        return len([fname for fname in os.listdir(self.img_dir)])

    def __getitem__(self, idx):
        img_name = str(idx) + '.jpg'
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, img_name

In [7]:
image_preprocessing = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
        ),
    ]
)

# Create train and val split
train_dataset = MultiClassImageDataset(
    train_ann_df, super_map_df, sub_map_df, train_img_dir, transform=image_preprocessing
)
train_dataset, val_dataset = random_split(train_dataset, [0.9, 0.1])

# Create test dataset
test_dataset = MultiClassImageTestDataset(
    super_map_df, sub_map_df, test_img_dir, transform=image_preprocessing
)

# Create dataloaders
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

test_loader = DataLoader(test_dataset, batch_size=50, shuffle=False)

**Jason's Experiments**

In [8]:
class Trainer:
    def __init__(
        self, model, criterion, optimizer, train_loader, val_loader, test_loader=None, device='cuda'
    ):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader

    def train_epoch(self):
        running_loss = 0.0
        for i, data in enumerate(self.train_loader):
            inputs, super_labels, sub_labels = (
                data[0].to(device),
                data[1].to(device),
                data[3].to(device),
            )

            self.optimizer.zero_grad()
            super_outputs, sub_outputs = self.model(inputs)
            loss = self.criterion(super_outputs, super_labels) + self.criterion(
                sub_outputs, sub_labels
            )
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f'Training loss: {running_loss/i:.3f}')

    def validate_epoch(self):
        super_correct = 0
        sub_correct = 0
        total = 0
        running_loss = 0.0
        with torch.no_grad():
            for i, data in enumerate(self.val_loader):
                inputs, super_labels, sub_labels = (
                    data[0].to(device),
                    data[1].to(device),
                    data[3].to(device),
                )

                super_outputs, sub_outputs = self.model(inputs)
                loss = self.criterion(super_outputs, super_labels) + self.criterion(
                    sub_outputs, sub_labels
                )
                _, super_predicted = torch.max(super_outputs.data, 1)
                _, sub_predicted = torch.max(sub_outputs.data, 1)

                total += super_labels.size(0)
                super_correct += (super_predicted == super_labels).sum().item()
                sub_correct += (sub_predicted == sub_labels).sum().item()
                running_loss += loss.item()

        print(f'Validation loss: {running_loss/i:.3f}')
        print(f'Validation superclass acc: {100 * super_correct / total:.2f} %')
        print(f'Validation subclass acc: {100 * sub_correct / total:.2f} %')

    def test(self):
      if not self.test_loader:
          raise NotImplementedError('test_loader not specified')

      # Evaluate on test set, in this simple demo no special care is taken for novel/unseen classes
      test_predictions = {'image': [], 'superclass_index': [], 'subclass_index': []}
      total_super_unseen = 0
      total_sub_unseen = 0
      with torch.no_grad():
          for i, data in enumerate(self.test_loader):
              inputs, img_names = data[0].to(device), data[1]
              batch_size = inputs.size(0)

              super_outputs, sub_outputs = self.model(inputs)

              # We convert with softmax to apply the threshold to probabilities, not logits.
              super_probs = F.softmax(super_outputs, dim=1)
              sub_probs = F.softmax(sub_outputs, dim=1)

              super_max, super_pred = torch.max(super_probs.data, 1)
              sub_max, sub_pred = torch.max(sub_probs.data, 1)

              # Handle batched predictions
              super_pred_labels = torch.where(super_max > 0.9, super_pred,
                                            torch.ones_like(super_pred) * 3).cpu().numpy()
              sub_pred_labels = torch.where(sub_max > 0.5, sub_pred,
                                          torch.ones_like(sub_pred) * 87).cpu().numpy()

              # Update statistics
              total_super_unseen += (super_max <= 0.9).sum().item()
              total_sub_unseen += (sub_max <= 0.5).sum().item()

              # Add predictions to our results
              for j in range(batch_size):
                  test_predictions['image'].append(img_names[j])
                  test_predictions['superclass_index'].append(int(super_pred_labels[j]))
                  test_predictions['subclass_index'].append(int(sub_pred_labels[j]))

      print(f'Total superclasses unseen: {total_super_unseen}')
      print(f'Total subclasses unseen: {total_sub_unseen}')

      return pd.DataFrame(data=test_predictions)

In [9]:
class CLIPMultiLabelClassifier(nn.Module):
    def __init__(self, device, num_subclasses):
        super().__init__()
        self.clip_model, _ = clip.load('ViT-B/32', device=device)
        self.clip_model.eval()
        for param in self.clip_model.parameters():
            param.requires_grad = False

        self.superclass_head = nn.Linear(self.clip_model.visual.output_dim, 4)
        self.subclass_head = nn.Linear(self.clip_model.visual.output_dim, num_subclasses + 1)

    def forward(self, images):
        with torch.no_grad():
            features = self.clip_model.encode_image(images).float()

        superclass_logits = self.superclass_head(features)
        subclass_logits = self.subclass_head(features)
        return superclass_logits, subclass_logits

In [10]:
model = CLIPMultiLabelClassifier(device=device, num_subclasses=87).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
trainer = Trainer(model, criterion, optimizer, train_loader, val_loader, test_loader, device)

for epoch in range(1):
    print(f'Epoch {epoch+1}')
    trainer.train_epoch()
    trainer.validate_epoch()
    print('')
print('Finished training.')

test_predictions = trainer.test()
test_predictions.to_csv('test_predictions_clip.csv', index=False)

100%|███████████████████████████████████████| 338M/338M [00:22<00:00, 15.9MiB/s]


Epoch 1
Training loss: 3.502
Validation loss: 2.469
Validation superclass acc: 99.84 %
Validation subclass acc: 69.27 %

Finished training.
Total superclasses unseen: 4503
Total subclasses unseen: 11180


***Below are Nico's Experiments.***

In [11]:

import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
# import torch.nn.functional as F # Not strictly needed in this version
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import clip
from sklearn.model_selection import train_test_split # For stratified splitting
import pandas as pd # Assuming train_ann_df and super_map_df are pandas DFs
import numpy as np # For a few operations

# --- 0. Configuration & Presumed Pre-loaded Data ---
# Ensure these are defined before running this script:
# train_ann_df = pd.read_csv(...) # Load your annotations
# super_map_df = pd.read_csv(...) # Load your superclass mapping
# train_img_dir = "path/to/your/training/images"
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Example placeholders if you want to run this snippet directly (replace with actual loading)
if 'train_ann_df' not in globals():
    print("Placeholder: Creating dummy train_ann_df and super_map_df")
    # Dummy super_map_df
    super_map_data = {'index': [0, 1, 2], 'class': ['bird', 'reptile', 'dog']}
    super_map_df = pd.DataFrame(super_map_data).set_index('index')

    # Dummy train_ann_df
    num_samples = 300
    image_names = [f"img_{i}.jpg" for i in range(num_samples)]
    superclass_indices = random.choices([0, 1, 2], k=num_samples) # 0:bird, 1:reptile, 2:dog
    train_ann_df = pd.DataFrame({'image': image_names, 'superclass_index': superclass_indices})
    # Create dummy image files if train_img_dir is also placeholder
    # train_img_dir = "./dummy_train_images"
    # if not os.path.exists(train_img_dir):
    #     os.makedirs(train_img_dir)
    #     for img_name in image_names:
    #         try:
    #             Image.new('RGB', (64,64)).save(os.path.join(train_img_dir, img_name))
    #         except Exception as e:
    #             print(f"Could not create dummy image {img_name}: {e}")


DOG_CLASS_NAME = 'dog' # The class to be treated as "novel" in this phase

# --- 1. Data Preparation & Splitting (Revised) ---

# Find the original integer label for the class to be treated as novel
try:
    dog_original_label_idx = int(super_map_df[super_map_df['class'] == DOG_CLASS_NAME].index[0])
except IndexError:
    raise ValueError(f"Class '{DOG_CLASS_NAME}' not found in super_map_df.")
except KeyError: # If 'index' is not the index name but a column
    dog_original_label_idx = int(super_map_df[super_map_df['class'] == DOG_CLASS_NAME]['index'].iloc[0])


# Identify all original labels and known original labels for this phase
all_original_labels = sorted(train_ann_df['superclass_index'].unique())
known_original_labels_for_phase1a = [l for l in all_original_labels if l != dog_original_label_idx]

if not known_original_labels_for_phase1a:
    raise ValueError("No 'known' classes remain after designating one as novel. Check your class setup.")

# Create a new contiguous mapping for labels for Phase 1a:
# Known classes get 0, 1, ...
# The "novel" proxy class (dog) gets the next available index.
phase1a_label_map = {original_label: new_label for new_label, original_label in enumerate(known_original_labels_for_phase1a)}
phase1a_novel_target_idx = len(known_original_labels_for_phase1a) # e.g., 2 if bird=0, reptile=1
phase1a_label_map[dog_original_label_idx] = phase1a_novel_target_idx

num_phase1a_outputs = phase1a_novel_target_idx + 1 # Total distinct target labels for this phase

print(f"Phase 1a: '{DOG_CLASS_NAME}' (original label {dog_original_label_idx}) will be mapped to 'novel' target label: {phase1a_novel_target_idx}")
print(f"Other known classes mapped to: { {k:v for k,v in phase1a_label_map.items() if v != phase1a_novel_target_idx} }")
print(f"Total output neurons for Phase 1a model: {num_phase1a_outputs}")

# Create a new column in train_ann_df for these Phase 1a target labels
train_ann_df['phase1a_target_label'] = train_ann_df['superclass_index'].map(phase1a_label_map)

# Split the main dataframe into training and validation sets
# Stratify by the 'phase1a_target_label' to ensure both sets see all (remapped) classes
try:
    train_indices, val_indices = train_test_split(
        train_ann_df.index,
        test_size=0.25,  # e.g., 25% for validation
        stratify=train_ann_df['phase1a_target_label'],
        random_state=42  # For reproducibility
    )
except ValueError as e:
    print(f"Warning: Stratified split failed with error: {e}. Falling back to non-stratified split.")
    print("This can happen if a class has too few samples for stratification.")
    # Check class counts for stratification
    class_counts = train_ann_df['phase1a_target_label'].value_counts()
    print(f"Class counts for stratification: \n{class_counts}")
    min_samples_for_stratify = 2 # sklearn usually needs at least 2 per class for stratification
    if any(class_counts < min_samples_for_stratify):
        print(f"At least one class has fewer than {min_samples_for_stratify} samples, which causes stratification issues.")
    train_indices, val_indices = train_test_split(
        train_ann_df.index,
        test_size=0.25,
        random_state=42
    )


train_df = train_ann_df.loc[train_indices].reset_index(drop=True)
val_df = train_ann_df.loc[val_indices].reset_index(drop=True)

print(f"Training set size: {len(train_df)}")
print(f"Validation set size: {len(val_df)}")
print(f"Train df 'phase1a_target_label' counts:\n{train_df['phase1a_target_label'].value_counts()}")
print(f"Val df 'phase1a_target_label' counts:\n{val_df['phase1a_target_label'].value_counts()}")


# --- 2. Transforms (Targeting 224x224 for CLIP) ---
CLIP_INPUT_SIZE = 224
# CLIP's official normalization values
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(CLIP_INPUT_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(CLIP_MEAN, CLIP_STD),
])

val_transform = transforms.Compose([
    transforms.Resize((CLIP_INPUT_SIZE, CLIP_INPUT_SIZE)), # Resize without cropping for validation
    transforms.ToTensor(),
    transforms.Normalize(CLIP_MEAN, CLIP_STD),
])

# --- 3. Updated Dataset Class for Phase 1a ---
class SuperClassPhase1aDataset(Dataset):
    def __init__(self, dataframe, img_base_dir, transform_fn):
        self.df = dataframe
        self.img_base_dir = img_base_dir
        self.transform_fn = transform_fn
        # The 'phase1a_target_label' column in df already has the correct 0, 1, ... N-1 mapping

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.img_base_dir, row['image'])
        try:
            img = Image.open(img_path).convert('RGB')
        except FileNotFoundError:
            print(f"ERROR: Image not found at {img_path}")
            # Return a dummy image and label or raise error
            # For now, creating a placeholder black image
            img = Image.new('RGB', (CLIP_INPUT_SIZE, CLIP_INPUT_SIZE), color='black')
            # It's better to ensure all images exist before this point
        target_label = int(row['phase1a_target_label'])

        if self.transform_fn:
            img = self.transform_fn(img)

        return img, target_label

# --- 4. DataLoaders ---
BATCH_SIZE = 32 # Reduced from 64 for potentially smaller datasets / memory constraints
NUM_WORKERS = 2 # Adjusted from 4, common default

train_ds = SuperClassPhase1aDataset(train_df, train_img_dir, train_transform)
val_ds = SuperClassPhase1aDataset(val_df, train_img_dir, val_transform)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True if device.type == 'cuda' else False)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True if device.type == 'cuda' else False)

# --- 5. Model Definition (CLIP Backbone + Neck + Head for Phase 1a) ---
def build_neck_super_phase1a(device_obj, num_phase1a_classes, neck_dim=512, drop_p=0.2): # Increased dropout slightly
    class CLIPNeckSuperPhase1a(nn.Module):
        def __init__(self):
            super().__init__()
            # Load CLIP model. It will be on the specified device.
            self.clip_model, _ = clip.load('ViT-B/32', device=device_obj)
            # Freeze CLIP parameters
            for param in self.clip_model.parameters():
                param.requires_grad = False

            clip_output_dim = self.clip_model.visual.output_dim
            self.neck = nn.Sequential(
                nn.Linear(clip_output_dim, neck_dim),
                nn.LayerNorm(neck_dim), # LayerNorm is often good here
                nn.GELU(),
                nn.Dropout(drop_p),
                nn.Linear(neck_dim, clip_output_dim), # Project back or to another dimension
            )
            # Head for Phase 1a classes (e.g., bird, reptile, novel_dog_proxy)
            self.super_head = nn.Linear(clip_output_dim, num_phase1a_classes)

        def forward(self, images):
            # CLIP model expects images directly, no need for torch.no_grad here if already frozen
            # but if you want to be extra sure for the encoding part:
            with torch.no_grad():
                image_features = self.clip_model.encode_image(images).float()

            neck_features = self.neck(image_features)
            logits = self.super_head(neck_features)
            return logits
    # Instantiate and move to device
    model = CLIPNeckSuperPhase1a().to(device_obj)
    return model

# --- 6. Model Instantiation, Criterion, Optimizer ---
model = build_neck_super_phase1a(device, num_phase1a_outputs)
criterion = nn.CrossEntropyLoss()
# Only pass parameters of the neck and head to the optimizer
optimizer = optim.Adam(
    [param for param in model.neck.parameters() if param.requires_grad] +
    [param for param in model.super_head.parameters() if param.requires_grad],
    lr=1e-4, # Adjusted learning rate
    weight_decay=1e-4 # Added weight decay
)

# Learning rate scheduler (optional, but good practice)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=True)

# --- 7. Training and Validation Loop ---
NUM_EPOCHS = 1 # As originally planned for initial run

print(f"\nStarting Phase 1a training for {NUM_EPOCHS} epochs...")
for epoch in range(1, NUM_EPOCHS + 1):
    # Training Phase
    model.train()
    total_train_loss = 0.0
    train_correct_preds = 0
    train_total_samples = 0

    for batch_idx, (images, target_labels) in enumerate(train_loader):
        images = images.to(device)
        target_labels = target_labels.to(device)

        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, target_labels)
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()
        preds = logits.argmax(dim=1)
        train_correct_preds += (preds == target_labels).sum().item()
        train_total_samples += target_labels.size(0)

        if batch_idx % 50 == 0: # Print training progress every 50 batches
             print(f"  Epoch {epoch} [{(batch_idx+1)*BATCH_SIZE:>5}/{len(train_ds):>5} ({100.*(batch_idx+1)/len(train_loader):.0f}%)]\tLoss: {loss.item():.4f}")


    avg_train_loss = total_train_loss / len(train_loader)
    train_accuracy = 100.0 * train_correct_preds / train_total_samples if train_total_samples > 0 else 0.0
    print(f"Epoch {epoch} Summary — Train Loss: {avg_train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%")

    # Validation Phase
    model.eval()
    total_val_loss = 0.0
    val_correct_all_preds = 0
    val_correct_unseen_proxy_preds = 0
    val_total_samples = 0
    val_total_unseen_proxy_gt = 0

    with torch.no_grad():
        for images, target_labels in val_loader:
            images = images.to(device)
            target_labels_gpu = target_labels.to(device) # For loss and comparison with GPU preds

            logits = model(images)
            loss = criterion(logits, target_labels_gpu)
            total_val_loss += loss.item()

            preds_gpu = logits.argmax(dim=1)

            val_total_samples += target_labels.size(0) # labels is on CPU here, .size(0) is fine
            val_correct_all_preds += (preds_gpu == target_labels_gpu).sum().item()

            # Identify unseen proxy samples (those whose target is phase1a_novel_target_idx)
            # target_labels is on CPU, so use it for masking
            unseen_proxy_mask_cpu = (target_labels == phase1a_novel_target_idx)
            val_total_unseen_proxy_gt += unseen_proxy_mask_cpu.sum().item()

            # Check predictions for these unseen proxy samples
            # Apply the mask (moved to GPU) to the predictions (on GPU)
            if unseen_proxy_mask_cpu.any(): # Ensure there are unseen samples in this batch
                 val_correct_unseen_proxy_preds += (preds_gpu[unseen_proxy_mask_cpu.to(device)] == phase1a_novel_target_idx).sum().item()


    avg_val_loss = total_val_loss / len(val_loader)
    overall_val_accuracy = 100.0 * val_correct_all_preds / val_total_samples if val_total_samples > 0 else 0.0
    unseen_proxy_val_accuracy = 100.0 * val_correct_unseen_proxy_preds / val_total_unseen_proxy_gt if val_total_unseen_proxy_gt > 0 else 0.0 # Correct denominator

    print(f"  Validation — Avg Loss: {avg_val_loss:.4f}, Overall Acc: {overall_val_accuracy:.2f}%, "
          f"Unseen ('{DOG_CLASS_NAME}' as novel) Acc: {unseen_proxy_val_accuracy:.2f}%")
    print("-" * 50)

    # Step the scheduler based on validation loss
    scheduler.step(avg_val_loss)

print("Phase 1a training complete.")


Phase 1a: 'dog' (original label 1) will be mapped to 'novel' target label: 2
Other known classes mapped to: {np.int64(0): 0, np.int64(2): 1}
Total output neurons for Phase 1a model: 3
Training set size: 4716
Validation set size: 1572
Train df 'phase1a_target_label' counts:
phase1a_target_label
1    1766
2    1563
0    1387
Name: count, dtype: int64
Val df 'phase1a_target_label' counts:
phase1a_target_label
1    588
2    521
0    463
Name: count, dtype: int64

Starting Phase 1a training for 1 epochs...




Epoch 1 Summary — Train Loss: 0.1250, Train Accuracy: 97.67%
  Validation — Avg Loss: 0.0136, Overall Acc: 99.94%, Unseen ('dog' as novel) Acc: 100.00%
--------------------------------------------------
Phase 1a training complete.


In [12]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pandas as pd
import clip

# Assuming train_ann_df, super_map_df, train_img_dir, and device are already defined
# from the previous Phase 1a code execution.
# Also assuming val_ann_df was created during the train_test_split in Phase 1a.

# Get the mapping of superclass names to their original integer indices
superclass_to_index = dict(zip(super_map_df['class'], super_map_df.index))

# --- transforms (now target 224x224 for CLIP) ---
train_tf = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.3, 0.3, 0.3, 0.1),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(
        (0.48145466, 0.4578275, 0.40821073),
        (0.26862954, 0.26130258, 0.27577711)
    ),
])
val_tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        (0.48145466, 0.4578275, 0.40821073),
        (0.26862954, 0.26130258, 0.27577711)
    ),
])

# --- Updated Dataset for super-class with 4 output neurons (including 'novel') ---
class SuperClassDatasetPhase1b(Dataset):
    def __init__(self, df, img_dir, transform):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform
        self.class_to_label = {
            superclass_to_index['bird']: 0,
            superclass_to_index['dog']: 1,
            superclass_to_index['reptile']: 2,
        }

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        img = Image.open(os.path.join(self.img_dir, row['image'])).convert('RGB')
        original_label = int(row['superclass_index'])
        target_label = self.class_to_label[original_label]  # Map to 0, 1, 2

        if self.transform:
            img = self.transform(img)
        return img, target_label

# Create DataLoaders for Phase 1b
train_ds_phase1b = SuperClassDatasetPhase1b(train_df, train_img_dir, train_tf)
val_ds_phase1b = SuperClassDatasetPhase1b(val_df, train_img_dir, val_tf) # Using the val_df created in Phase 1a

train_loader_phase1b = DataLoader(train_ds_phase1b, batch_size=64, shuffle=True, num_workers=4)
val_loader_phase1b = DataLoader(val_ds_phase1b, batch_size=64, shuffle=False, num_workers=4)

# --- build_neck with a 'novel' output neuron (4-way total) ---
def build_neck_super_phase1b(device, neck_dim=512, drop_p=0.1):
    class CLIPNeckSuperPhase1b(nn.Module):
        def __init__(self):
            super().__init__()
            self.clip, _ = clip.load('ViT-B/32', device=device)
            for p in self.clip.parameters():
                p.requires_grad = False

            D = self.clip.visual.output_dim
            self.neck = nn.Sequential(
                nn.Linear(D, neck_dim),
                nn.LayerNorm(neck_dim),
                nn.GELU(),
                nn.Dropout(drop_p),
                nn.Linear(neck_dim, D),
            )
            # 4-way: bird, dog, reptile, novel
            self.super_head = nn.Linear(D, 4)

        def forward(self, x):
            with torch.no_grad():
                f = self.clip.encode_image(x).float()
            f = self.neck(f)
            return self.super_head(f) # logits over [bird, dog, reptile, novel]
    return CLIPNeckSuperPhase1b().to(device)

device = "cuda" if torch.cuda.is_available() else "cpu"
model_phase1b = build_neck_super_phase1b(device)
criterion_phase1b = nn.CrossEntropyLoss()
optimizer_phase1b = optim.Adam(filter(lambda p: p.requires_grad, model_phase1b.parameters()), lr=1e-3)

# --- Training loop for 5 epochs (Phase 1b) ---
for epoch in range(1, 10):
    model_phase1b.train()
    total_loss = 0
    for imgs, labels in train_loader_phase1b:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer_phase1b.zero_grad()
        logits = model_phase1b(imgs)
        loss = criterion_phase1b(logits[:, :3], labels) # Only calculate loss on the first 3 outputs
        loss.backward()
        optimizer_phase1b.step()
        total_loss += loss.item()
    print(f"Epoch {epoch} — Train loss (Phase 1b): {total_loss/len(train_loader_phase1b):.3f}")

    # Validate (Phase 1b) with confidence thresholding for "novel"
    model_phase1b.eval()
    correct_known = total = novel_predictions = 0
    confidence_threshold = 0.9 # You'll need to tune this

    with torch.no_grad():
        for imgs, true_labels in val_loader_phase1b:
            imgs = imgs.to(device)
            logits = model_phase1b(imgs)
            probs = torch.softmax(logits, dim=1)
            max_prob_known = torch.max(probs[:, :3], dim=1).values
            predicted_labels = torch.argmax(probs[:, :3], dim=1) # Predictions for known classes

            # Identify as "novel" if confidence in known classes is low
            novel_mask = (max_prob_known < confidence_threshold)
            predicted_super_labels = torch.where(novel_mask, torch.tensor([3] * len(true_labels)).to(device), predicted_labels)

            total += len(true_labels)
            correct_known += (predicted_super_labels[~novel_mask].cpu() == true_labels[~novel_mask.cpu()]).sum().item()
            novel_predictions += novel_mask.sum().item()

    accuracy_known = 100 * correct_known / total if total > 0 else 0
    novel_ratio = 100 * novel_predictions / total if total > 0 else 0
    print(f" Val Accuracy (Known Classes, Phase 1b): {accuracy_known:.2f}%")
    print(f" Val Ratio Predicted as Novel (Threshold = {confidence_threshold}): {novel_ratio:.2f}%")
    print("--------------------------------------------------")

print("Phase 1b training complete.")



Epoch 1 — Train loss (Phase 1b): 0.067
 Val Accuracy (Known Classes, Phase 1b): 99.75%
 Val Ratio Predicted as Novel (Threshold = 0.9): 0.19%
--------------------------------------------------
Epoch 2 — Train loss (Phase 1b): 0.012
 Val Accuracy (Known Classes, Phase 1b): 99.49%
 Val Ratio Predicted as Novel (Threshold = 0.9): 0.38%
--------------------------------------------------
Epoch 3 — Train loss (Phase 1b): 0.009
 Val Accuracy (Known Classes, Phase 1b): 98.92%
 Val Ratio Predicted as Novel (Threshold = 0.9): 0.95%
--------------------------------------------------
Epoch 4 — Train loss (Phase 1b): 0.009
 Val Accuracy (Known Classes, Phase 1b): 99.87%
 Val Ratio Predicted as Novel (Threshold = 0.9): 0.06%
--------------------------------------------------
Epoch 5 — Train loss (Phase 1b): 0.013
 Val Accuracy (Known Classes, Phase 1b): 99.62%
 Val Ratio Predicted as Novel (Threshold = 0.9): 0.32%
--------------------------------------------------
Epoch 6 — Train loss (Phase 1b): 0.

Getting the Dog Sub Classes

In [13]:
# Assuming super_map_df and train_ann_df are already loaded

DOG_SUPERCLASS_NAME = 'dog'

# 1. Identify the super-class index for "dog"
try:
    dog_superclass_index = super_map_df[super_map_df['class'] == DOG_SUPERCLASS_NAME].index[0]
except IndexError:
    raise ValueError(f"Super-class '{DOG_SUPERCLASS_NAME}' not found in super_map_df.")

print(f"The super-class index for '{DOG_SUPERCLASS_NAME}' is: {dog_superclass_index}")

# 2. Filter train_ann_df for "dog" examples
dog_train_df = train_ann_df[train_ann_df['superclass_index'] == dog_superclass_index].reset_index(drop=True)

print(f"Number of training examples for '{DOG_SUPERCLASS_NAME}': {len(dog_train_df)}")

# 3. Identify unique sub-class indices for "dog"
dog_subclass_indices = dog_train_df['subclass_index'].unique()
print(f"Unique sub-class indices for '{DOG_SUPERCLASS_NAME}': {dog_subclass_indices}")

# 4. Map sub-class indices to names
dog_subclasses = sub_map_df[sub_map_df.index.isin(dog_subclass_indices)]['class'].tolist()
print(f"Sub-classes for '{DOG_SUPERCLASS_NAME}': {dog_subclasses}")

num_dog_subclasses = len(dog_subclasses)
print(f"Number of sub-classes for '{DOG_SUPERCLASS_NAME}': {num_dog_subclasses}")

The super-class index for 'dog' is: 1
Number of training examples for 'dog': 2084
Unique sub-class indices for 'dog': [37 62 31 70 65 64 36 49 45  7 22 17 46 18  2 85 21 10 79 54 38  0 53 23
 32 25 12 77  9]
Sub-classes for 'dog': ['Scotch terrier, Scottish terrier, Scottie', 'standard schnauzer', 'Pekinese, Pekingese, Peke', 'Lhasa, Lhasa apso', 'Lakeland terrier', 'Tibetan terrier, chrysanthemum dog', 'cairn, cairn terrier', 'Blenheim spaniel', 'Chihuahua', 'Japanese spaniel', 'Dandie Dinmont, Dandie Dinmont terrier', 'Airedale, Airedale terrier', 'Shih-Tzu', 'giant schnauzer', 'basset, basset hound', 'Maltese dog, Maltese terrier, Maltese', 'Sealyham terrier, Sealyham', 'Australian terrier', 'papillon', 'bloodhound, sleuthhound', 'soft-coated wheaten terrier', 'West Highland white terrier', 'Afghan hound, Afghan', 'Rhodesian ridgeback', 'beagle', 'toy terrier', 'silky terrier, Sydney silky', 'Boston bull, Boston terrier', 'miniature schnauzer']
Number of sub-classes for 'dog': 29


In [14]:
class DogSubclassDataset(Dataset):
    def __init__(self, df, img_dir, transform, subclass_map_df):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform
        self.subclass_map_df = subclass_map_df
        self.dog_superclass_index = 1  # From the previous output
        self.dog_df = self.df[self.df['superclass_index'] == self.dog_superclass_index].reset_index(drop=True)
        self.dog_subclass_indices = self.dog_df['subclass_index'].unique().tolist()
        # Create a mapping from the original subclass index to a contiguous label (0 to num_dog_subclasses - 1)
        self.subclass_to_label = {index: label for label, index in enumerate(sorted(self.dog_subclass_indices))}

    def __len__(self):
        return len(self.dog_df)

    def __getitem__(self, idx):
        row = self.dog_df.iloc[idx]
        img_path = os.path.join(self.img_dir, row['image'])
        image = Image.open(img_path).convert('RGB')
        subclass_index = row['subclass_index']
        subclass_label = self.subclass_to_label[subclass_index]

        if self.transform:
            image = self.transform(image)

        return image, subclass_label

# Create the dataset for training the dog expert
train_dog_dataset = DogSubclassDataset(train_ann_df, train_img_dir, image_preprocessing, sub_map_df)

# Create a DataLoader for the dog training data
batch_size = 64
train_dog_loader = DataLoader(train_dog_dataset, batch_size=batch_size, shuffle=True)

# Let's also create a validation split for the dog data
from torch.utils.data import random_split

train_len = int(0.8 * len(train_dog_dataset))
val_len = len(train_dog_dataset) - train_len
train_dog_dataset, val_dog_dataset = random_split(train_dog_dataset, [train_len, val_len])

val_dog_loader = DataLoader(val_dog_dataset, batch_size=batch_size, shuffle=False)

print(f"Number of training samples for dog expert: {len(train_dog_dataset)}")
print(f"Number of validation samples for dog expert: {len(val_dog_dataset)}")
print(f"Number of dog sub-classes: {len(train_dog_dataset.dataset.subclass_to_label)}")


def build_dog_subclass_classifier(device, num_subclasses):
    class CLIPNeckSubclass(nn.Module):
        def __init__(self):
            super().__init__()
            self.clip, _ = clip.load('ViT-B/32', device=device)
            for p in self.clip.parameters():
                p.requires_grad = False

            D = self.clip.visual.output_dim
            self.neck = nn.Sequential(
                nn.Linear(D, 512),
                nn.LayerNorm(512),
                nn.GELU(),
                nn.Dropout(0.1),
                nn.Linear(512, D),
            )
            # Output layer for the dog sub-classes + 1 for 'novel'
            self.sub_head = nn.Linear(D, num_subclasses + 1)

        def forward(self, x):
            with torch.no_grad():
                f = self.clip.encode_image(x).float()
            f = self.neck(f)
            return self.sub_head(f)

    return CLIPNeckSubclass().to(device)

num_dog_subclasses = len(train_dog_dataset.dataset.subclass_to_label)
dog_expert_model = build_dog_subclass_classifier(device, num_dog_subclasses)
criterion_dog = nn.CrossEntropyLoss()
optimizer_dog = optim.Adam(filter(lambda p: p.requires_grad, dog_expert_model.parameters()), lr=1e-3)

Number of training samples for dog expert: 1667
Number of validation samples for dog expert: 417
Number of dog sub-classes: 29


In [15]:
def train_dog_expert(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=10):
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        correct_predictions = 0
        total_samples = 0
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            logits = model(imgs)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            _, predicted = torch.max(logits, 1)
            total_samples += labels.size(0)
            correct_predictions += (predicted == labels).sum().item()

        avg_loss = total_loss / len(train_loader)
        train_accuracy = 100 * correct_predictions / total_samples
        print(f"Epoch {epoch+1} — Train Loss (Dog Expert): {avg_loss:.3f}, Train Accuracy: {train_accuracy:.2f}%")

        # Validation
        model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for val_imgs, val_labels in val_loader:
                val_imgs, val_labels = val_imgs.to(device), val_labels.to(device)
                val_logits = model(val_imgs)
                loss = criterion(val_logits, val_labels)
                val_loss += loss.item()
                _, val_predicted = torch.max(val_logits, 1)
                val_total += val_labels.size(0)
                val_correct += (val_predicted == val_labels).sum().item()

        avg_val_loss = val_loss / len(val_loader)
        val_accuracy = 100 * val_correct / val_total
        print(f"Epoch {epoch+1} — Val Loss (Dog Expert): {avg_val_loss:.3f}, Val Accuracy: {val_accuracy:.2f}%")
        print("--------------------------------------------------")

# Train the dog expert
train_dog_expert(dog_expert_model, train_dog_loader, val_dog_loader, criterion_dog, optimizer_dog, device, num_epochs=10)

print("Dog expert training complete.")

Epoch 1 — Train Loss (Dog Expert): 1.459, Train Accuracy: 60.36%
Epoch 1 — Val Loss (Dog Expert): 0.392, Val Accuracy: 87.29%
--------------------------------------------------
Epoch 2 — Train Loss (Dog Expert): 0.304, Train Accuracy: 90.31%
Epoch 2 — Val Loss (Dog Expert): 0.174, Val Accuracy: 94.96%
--------------------------------------------------
Epoch 3 — Train Loss (Dog Expert): 0.217, Train Accuracy: 92.13%
Epoch 3 — Val Loss (Dog Expert): 0.234, Val Accuracy: 89.69%
--------------------------------------------------
Epoch 4 — Train Loss (Dog Expert): 0.191, Train Accuracy: 92.95%
Epoch 4 — Val Loss (Dog Expert): 0.133, Val Accuracy: 95.68%
--------------------------------------------------
Epoch 5 — Train Loss (Dog Expert): 0.128, Train Accuracy: 95.44%
Epoch 5 — Val Loss (Dog Expert): 0.137, Val Accuracy: 95.92%
--------------------------------------------------
Epoch 6 — Train Loss (Dog Expert): 0.109, Train Accuracy: 96.11%
Epoch 6 — Val Loss (Dog Expert): 0.103, Val Accura

**Getting the reptile sub classes**

In [16]:
# Assuming super_map_df and train_ann_df are already loaded

REPTILE_SUPERCLASS_NAME = 'reptile'

# 1. Identify the super-class index for "reptile"
try:
    reptile_superclass_index = super_map_df[super_map_df['class'] == REPTILE_SUPERCLASS_NAME].index[0]
except IndexError:
    raise ValueError(f"Super-class '{REPTILE_SUPERCLASS_NAME}' not found in super_map_df.")

print(f"The super-class index for '{REPTILE_SUPERCLASS_NAME}' is: {reptile_superclass_index}")

# 2. Filter train_ann_df for "reptile" examples
reptile_train_df = train_ann_df[train_ann_df['superclass_index'] == reptile_superclass_index].reset_index(drop=True)

print(f"Number of training examples for '{REPTILE_SUPERCLASS_NAME}': {len(reptile_train_df)}")

# 3. Identify unique sub-class indices for "reptile"
reptile_subclass_indices = reptile_train_df['subclass_index'].unique()
print(f"Unique sub-class indices for '{REPTILE_SUPERCLASS_NAME}': {reptile_subclass_indices}")

# 4. Map sub-class indices to names
reptile_subclasses = sub_map_df[sub_map_df.index.isin(reptile_subclass_indices)]['class'].tolist()
print(f"Sub-classes for '{REPTILE_SUPERCLASS_NAME}': {reptile_subclasses}")

num_reptile_subclasses = len(reptile_subclasses)
print(f"Number of sub-classes for '{REPTILE_SUPERCLASS_NAME}': {num_reptile_subclasses}")

The super-class index for 'reptile' is: 2
Number of training examples for 'reptile': 2354
Unique sub-class indices for 'reptile': [61 57  3 50 47 81 66 52 55 76 71 48 29  1 67 63 34 44 35 74 58 72 69 43
 13 15 68 39 33]
Sub-classes for 'reptile': ['African chameleon, Chamaeleo chamaeleon', 'terrapin', 'agama', 'mud turtle', 'spotted salamander, Ambystoma maculatum', 'common iguana, iguana, Iguana iguana', 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui', 'thunder snake, worm snake, Carphophis amoenus', 'frilled lizard, Chlamydosaurus kingi', 'American alligator, Alligator mississipiensis', 'hognose snake, puff adder, sand viper', 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis', 'American chameleon, anole, Anolis carolinensis', 'European fire salamander, Salamandra salamandra', 'triceratops', 'loggerhead, loggerhead turtle, Caretta caretta', 'alligator lizard', 'African crocodile, Nile crocodile, Crocodylus niloticus', 'common newt, Trit

In [17]:
class ReptileSubclassDataset(Dataset):
    def __init__(self, df, img_dir, transform, subclass_map_df):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform
        self.subclass_map_df = subclass_map_df
        self.reptile_superclass_index = 2  # From the previous output
        self.reptile_df = self.df[self.df['superclass_index'] == self.reptile_superclass_index].reset_index(drop=True)
        self.reptile_subclass_indices = self.reptile_df['subclass_index'].unique().tolist()
        # Create a mapping from the original subclass index to a contiguous label (0 to num_reptile_subclasses - 1)
        self.subclass_to_label = {index: label for label, index in enumerate(sorted(self.reptile_subclass_indices))}

    def __len__(self):
        return len(self.reptile_df)

    def __getitem__(self, idx):
        row = self.reptile_df.iloc[idx]
        img_path = os.path.join(self.img_dir, row['image'])
        image = Image.open(img_path).convert('RGB')
        subclass_index = row['subclass_index']
        subclass_label = self.subclass_to_label[subclass_index]

        if self.transform:
            image = self.transform(image)

        return image, subclass_label

# Create the dataset for training the reptile expert
train_reptile_dataset = ReptileSubclassDataset(train_ann_df, train_img_dir, image_preprocessing, sub_map_df)

# Create a DataLoader for the reptile training data
batch_size = 64
train_reptile_loader = DataLoader(train_reptile_dataset, batch_size=batch_size, shuffle=True)

# Let's also create a validation split for the reptile data
from torch.utils.data import random_split

train_len_reptile = int(0.8 * len(train_reptile_dataset))
val_len_reptile = len(train_reptile_dataset) - train_len_reptile
train_reptile_dataset, val_reptile_dataset = random_split(train_reptile_dataset, [train_len_reptile, val_len_reptile])

val_reptile_loader = DataLoader(val_reptile_dataset, batch_size=batch_size, shuffle=False)

print(f"Number of training samples for reptile expert: {len(train_reptile_dataset)}")
print(f"Number of validation samples for reptile expert: {len(val_reptile_dataset)}")
print(f"Number of reptile sub-classes: {len(train_reptile_dataset.dataset.subclass_to_label)}")

Number of training samples for reptile expert: 1883
Number of validation samples for reptile expert: 471
Number of reptile sub-classes: 29


In [18]:
def build_reptile_subclass_classifier(device, num_subclasses):
    class CLIPNeckSubclass(nn.Module):
        def __init__(self):
            super().__init__()
            self.clip, _ = clip.load('ViT-B/32', device=device)
            for p in self.clip.parameters():
                p.requires_grad = False

            D = self.clip.visual.output_dim
            self.neck = nn.Sequential(
                nn.Linear(D, 512),
                nn.LayerNorm(512),
                nn.GELU(),
                nn.Dropout(0.1),
                nn.Linear(512, D),
            )
            # Output layer for the reptile sub-classes + 1 for 'novel'
            self.sub_head = nn.Linear(D, num_subclasses + 1)

        def forward(self, x):
            with torch.no_grad():
                f = self.clip.encode_image(x).float()
            f = self.neck(f)
            return self.sub_head(f)

    return CLIPNeckSubclass().to(device)

num_reptile_subclasses = len(train_reptile_dataset.dataset.subclass_to_label)
reptile_expert_model = build_reptile_subclass_classifier(device, num_reptile_subclasses)
criterion_reptile = nn.CrossEntropyLoss()
optimizer_reptile = optim.Adam(filter(lambda p: p.requires_grad, reptile_expert_model.parameters()), lr=1e-3)

In [19]:
def train_reptile_expert(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=10):
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        correct_predictions = 0
        total_samples = 0
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            logits = model(imgs)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            _, predicted = torch.max(logits, 1)
            total_samples += labels.size(0)
            correct_predictions += (predicted == labels).sum().item()

        avg_loss = total_loss / len(train_loader)
        train_accuracy = 100 * correct_predictions / total_samples
        print(f"Epoch {epoch+1} — Train Loss (Reptile Expert): {avg_loss:.3f}, Train Accuracy: {train_accuracy:.2f}%")

        # Validation
        model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for val_imgs, val_labels in val_loader:
                val_imgs, val_labels = val_imgs.to(device), val_labels.to(device)
                val_logits = model(val_imgs)
                loss = criterion(val_logits, val_labels)
                val_loss += loss.item()
                _, val_predicted = torch.max(val_logits, 1)
                val_total += val_labels.size(0)
                val_correct += (val_predicted == val_labels).sum().item()

        avg_val_loss = val_loss / len(val_loader)
        val_accuracy = 100 * val_correct / val_total
        print(f"Epoch {epoch+1} — Val Loss (Reptile Expert): {avg_val_loss:.3f}, Val Accuracy: {val_accuracy:.2f}%")
        print("--------------------------------------------------")

# Train the reptile expert
train_reptile_expert(reptile_expert_model, train_reptile_loader, val_reptile_loader, criterion_reptile, optimizer_reptile, device, num_epochs=10)

print("Reptile expert training complete.")

Epoch 1 — Train Loss (Reptile Expert): 1.341, Train Accuracy: 60.54%
Epoch 1 — Val Loss (Reptile Expert): 0.376, Val Accuracy: 89.17%
--------------------------------------------------
Epoch 2 — Train Loss (Reptile Expert): 0.393, Train Accuracy: 86.19%
Epoch 2 — Val Loss (Reptile Expert): 0.260, Val Accuracy: 91.08%
--------------------------------------------------
Epoch 3 — Train Loss (Reptile Expert): 0.264, Train Accuracy: 90.31%
Epoch 3 — Val Loss (Reptile Expert): 0.217, Val Accuracy: 92.36%
--------------------------------------------------
Epoch 4 — Train Loss (Reptile Expert): 0.227, Train Accuracy: 91.76%
Epoch 4 — Val Loss (Reptile Expert): 0.141, Val Accuracy: 95.54%
--------------------------------------------------
Epoch 5 — Train Loss (Reptile Expert): 0.177, Train Accuracy: 93.50%
Epoch 5 — Val Loss (Reptile Expert): 0.091, Val Accuracy: 97.88%
--------------------------------------------------
Epoch 6 — Train Loss (Reptile Expert): 0.138, Train Accuracy: 95.20%
Epoch 

**Getting the reptile sub classes**

In [20]:
# Assuming super_map_df and train_ann_df are already loaded

BIRD_SUPERCLASS_NAME = 'bird'

# 1. Identify the super-class index for "bird"
try:
    bird_superclass_index = super_map_df[super_map_df['class'] == BIRD_SUPERCLASS_NAME].index[0]
except IndexError:
    raise ValueError(f"Super-class '{BIRD_SUPERCLASS_NAME}' not found in super_map_df.")

print(f"The super-class index for '{BIRD_SUPERCLASS_NAME}' is: {bird_superclass_index}")

# 2. Filter train_ann_df for "bird" examples
bird_train_df = train_ann_df[train_ann_df['superclass_index'] == bird_superclass_index].reset_index(drop=True)

print(f"Number of training examples for '{BIRD_SUPERCLASS_NAME}': {len(bird_train_df)}")

# 3. Identify unique sub-class indices for "bird"
bird_subclass_indices = bird_train_df['subclass_index'].unique()
print(f"Unique sub-class indices for '{BIRD_SUPERCLASS_NAME}': {bird_subclass_indices}")

# 4. Map sub-class indices to names
bird_subclasses = sub_map_df[sub_map_df.index.isin(bird_subclass_indices)]['class'].tolist()
print(f"Sub-classes for '{BIRD_SUPERCLASS_NAME}': {bird_subclasses}")

num_bird_subclasses = len(bird_subclasses)
print(f"Number of sub-classes for '{BIRD_SUPERCLASS_NAME}': {num_bird_subclasses}")

The super-class index for 'bird' is: 0
Number of training examples for 'bird': 1850
Unique sub-class indices for 'bird': [42  4 41 20 19 60 78 16  6 40 28 30 84 59 75 24  8 80 86 27 14 82 51 56
 26  5 11 73 83]
Sub-classes for 'bird': ['great grey owl, great gray owl, Strix nebulosa', 'bustard', 'ptarmigan', 'hen', 'pelican', 'junco, snowbird', 'cock', 'brambling, Fringilla montifringilla', 'king penguin, Aptenodytes patagonica', 'bald eagle, American eagle, Haliaeetus leucocephalus', 'albatross, mollymawk', 'water ouzel, dipper', 'black grouse', 'vulture', 'red-backed sandpiper, dunlin, Erolia alpina', 'redshank, Tringa totanus', 'oystercatcher, oyster catcher', 'ostrich, Struthio camelus', 'bulbul', 'house finch, linnet, Carpodacus mexicanus', 'goldfinch, Carduelis carduelis', 'dowitcher', 'chickadee', 'indigo bunting, indigo finch, indigo bird, Passerina cyanea', 'American coot, marsh hen, mud hen, water hen, Fulica americana', 'ruddy turnstone, Arenaria interpres', 'robin, American

In [21]:
class BirdSubclassDataset(Dataset):
    def __init__(self, df, img_dir, transform, subclass_map_df):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform
        self.subclass_map_df = subclass_map_df
        self.bird_superclass_index = 0  # From the previous output
        self.bird_df = self.df[self.df['superclass_index'] == self.bird_superclass_index].reset_index(drop=True)
        self.bird_subclass_indices = self.bird_df['subclass_index'].unique().tolist()
        # Create a mapping from the original subclass index to a contiguous label (0 to num_bird_subclasses - 1)
        self.subclass_to_label = {index: label for label, index in enumerate(sorted(self.bird_subclass_indices))}

    def __len__(self):
        return len(self.bird_df)

    def __getitem__(self, idx):
        row = self.bird_df.iloc[idx]
        img_path = os.path.join(self.img_dir, row['image'])
        image = Image.open(img_path).convert('RGB')
        subclass_index = row['subclass_index']
        subclass_label = self.subclass_to_label[subclass_index]

        if self.transform:
            image = self.transform(image)

        return image, subclass_label

# Create the dataset for training the bird expert
train_bird_dataset = BirdSubclassDataset(train_ann_df, train_img_dir, image_preprocessing, sub_map_df)

# Create a DataLoader for the bird training data
batch_size = 64
train_bird_loader = DataLoader(train_bird_dataset, batch_size=batch_size, shuffle=True)

# Let's also create a validation split for the bird data
from torch.utils.data import random_split

train_len_bird = int(0.8 * len(train_bird_dataset))
val_len_bird = len(train_bird_dataset) - train_len_bird
train_bird_dataset, val_bird_dataset = random_split(train_bird_dataset, [train_len_bird, val_len_bird])

val_bird_loader = DataLoader(val_bird_dataset, batch_size=batch_size, shuffle=False)

print(f"Number of training samples for bird expert: {len(train_bird_dataset)}")
print(f"Number of validation samples for bird expert: {len(val_bird_dataset)}")
print(f"Number of bird sub-classes: {len(train_bird_dataset.dataset.subclass_to_label)}")

def build_bird_subclass_classifier(device, num_subclasses):
    class CLIPNeckSubclass(nn.Module):
        def __init__(self):
            super().__init__()
            self.clip, _ = clip.load('ViT-B/32', device=device)
            for p in self.clip.parameters():
                p.requires_grad = False

            D = self.clip.visual.output_dim
            self.neck = nn.Sequential(
                nn.Linear(D, 512),
                nn.LayerNorm(512),
                nn.GELU(),
                nn.Dropout(0.1),
                nn.Linear(512, D),
            )
            # Output layer for the bird sub-classes + 1 for 'novel'
            self.sub_head = nn.Linear(D, num_subclasses + 1)

        def forward(self, x):
            with torch.no_grad():
                f = self.clip.encode_image(x).float()
            f = self.neck(f)
            return self.sub_head(f)

    return CLIPNeckSubclass().to(device)

num_bird_subclasses = len(train_bird_dataset.dataset.subclass_to_label)
bird_expert_model = build_bird_subclass_classifier(device, num_bird_subclasses)
criterion_bird = nn.CrossEntropyLoss()
optimizer_bird = optim.Adam(filter(lambda p: p.requires_grad, bird_expert_model.parameters()), lr=1e-3)

def train_bird_expert(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=10):
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        correct_predictions = 0
        total_samples = 0
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            logits = model(imgs)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            _, predicted = torch.max(logits, 1)
            total_samples += labels.size(0)
            correct_predictions += (predicted == labels).sum().item()

        avg_loss = total_loss / len(train_loader)
        train_accuracy = 100 * correct_predictions / total_samples
        print(f"Epoch {epoch+1} — Train Loss (Bird Expert): {avg_loss:.3f}, Train Accuracy: {train_accuracy:.2f}%")

        # Validation
        model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for val_imgs, val_labels in val_loader:
                val_imgs, val_labels = val_imgs.to(device), val_labels.to(device)
                val_logits = model(val_imgs)
                loss = criterion(val_logits, val_labels)
                val_loss += loss.item()
                _, val_predicted = torch.max(val_logits, 1)
                val_total += val_labels.size(0)
                val_correct += (val_predicted == val_labels).sum().item()

        avg_val_loss = val_loss / len(val_loader)
        val_accuracy = 100 * val_correct / val_total
        print(f"Epoch {epoch+1} — Val Loss (Bird Expert): {avg_val_loss:.3f}, Val Accuracy: {val_accuracy:.2f}%")
        print("--------------------------------------------------")

# Train the bird expert
train_bird_expert(bird_expert_model, train_bird_loader, val_bird_loader, criterion_bird, optimizer_bird, device, num_epochs=10)

print("Bird expert training complete.")

Number of training samples for bird expert: 1480
Number of validation samples for bird expert: 370
Number of bird sub-classes: 29
Epoch 1 — Train Loss (Bird Expert): 1.121, Train Accuracy: 74.54%
Epoch 1 — Val Loss (Bird Expert): 0.118, Val Accuracy: 97.84%
--------------------------------------------------
Epoch 2 — Train Loss (Bird Expert): 0.099, Train Accuracy: 97.19%
Epoch 2 — Val Loss (Bird Expert): 0.055, Val Accuracy: 98.65%
--------------------------------------------------
Epoch 3 — Train Loss (Bird Expert): 0.072, Train Accuracy: 97.41%
Epoch 3 — Val Loss (Bird Expert): 0.045, Val Accuracy: 98.92%
--------------------------------------------------
Epoch 4 — Train Loss (Bird Expert): 0.028, Train Accuracy: 99.41%
Epoch 4 — Val Loss (Bird Expert): 0.013, Val Accuracy: 99.73%
--------------------------------------------------
Epoch 5 — Train Loss (Bird Expert): 0.013, Train Accuracy: 99.73%
Epoch 5 — Val Loss (Bird Expert): 0.007, Val Accuracy: 100.00%
-------------------------

In [22]:
class HierarchicalValidationDataset(Dataset):
    def __init__(self, df, img_dir, transform, superclass_to_index):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform
        self.superclass_to_index = superclass_to_index

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        img_path = os.path.join(self.img_dir, row['image'])
        image = Image.open(img_path).convert('RGB')
        superclass_index = int(row['superclass_index'])
        subclass_index = int(row['subclass_index'])

        if self.transform:
            image = self.transform(image)

        return image, superclass_index, subclass_index

In [23]:
val_ds_hierarchical = HierarchicalValidationDataset(val_df, train_img_dir, val_tf, superclass_to_index)
val_loader_hierarchical = DataLoader(val_ds_hierarchical, batch_size=64, shuffle=False, num_workers=4)

In [24]:
import torch
import numpy as np
from torchvision import transforms
from PIL import Image
import os
import pandas as pd
from torch.utils.data import Dataset, DataLoader

def validate_hierarchical(super_class_model, dog_expert, reptile_expert, bird_expert, val_loader, device, super_map_df, sub_map_df, train_dog_dataset, train_reptile_dataset, train_bird_dataset, confidence_threshold_superclass=0.8, confidence_threshold_subclass=0.7):
    """
    Validates the hierarchical classifier on the validation dataset with sub-class labels.

    Args:
        super_class_model (nn.Module): Trained super-class classifier.
        dog_expert (nn.Module): Trained dog sub-class expert.
        reptile_expert (nn.Module): Trained reptile sub-class expert.
        bird_expert (nn.Module): Trained bird sub-class expert.
        val_loader (DataLoader): DataLoader for the validation set (super-class AND sub-class labels).
        device (str): 'cuda' or 'cpu'.
        super_map_df (pd.DataFrame): DataFrame mapping super-class indices to names.
        sub_map_df (pd.DataFrame): DataFrame mapping sub-class indices to names.
        train_dog_dataset (Dataset): Training dataset for the dog expert (to access label mapping).
        train_reptile_dataset (Dataset): Training dataset for the reptile expert (to access label mapping).
        train_bird_dataset (Dataset): Training dataset for the bird expert (to access label mapping).
        confidence_threshold_superclass (float): Threshold for super-class novelty.
        confidence_threshold_subclass (float): Threshold for sub-class novelty.

    Returns:
        dict: Dictionary of validation metrics (super-class accuracy, sub-class accuracy, etc.).
    """
    super_class_model.eval()
    dog_expert.eval()
    reptile_expert.eval()
    bird_expert.eval()

    correct_super = 0
    total_super = 0
    novel_super_predicted = 0
    novel_super_actual = 0

    correct_sub = 0
    total_sub = 0
    novel_sub_predicted = 0
    novel_sub_actual = 0

    with torch.no_grad():
        for imgs, super_labels, sub_labels in val_loader: # Unpack sub_labels
            imgs = imgs.to(device)
            super_labels = super_labels.to(device)
            sub_labels = sub_labels.to(device)

            # 1. Super-class prediction and evaluation
            super_logits = super_class_model(imgs)
            super_probs = torch.softmax(super_logits, dim=1)
            max_prob_superclass, predicted_superclass_index = torch.max(super_probs[:, :3], dim=1)
            predicted_super_list = [super_map_df.iloc[i]['class'] for i in predicted_superclass_index.cpu().numpy()]
            true_super_indices = super_labels.cpu().numpy()
            true_super_list = [super_map_df.iloc[i]['class'] for i in true_super_indices]
            is_novel_superclass = (max_prob_superclass < confidence_threshold_superclass).cpu().numpy()

            for i in range(len(super_labels)):
                if is_novel_superclass[i]:
                    if true_super_list[i] not in ['bird', 'dog', 'reptile']:
                        novel_super_actual += 1
                    novel_super_predicted += 1
                else:
                    if predicted_super_list[i] == true_super_list[i]:
                        correct_super += 1
                total_super += 1

            # 2. Sub-class prediction and evaluation
            for i in range(len(imgs)):
                img = imgs[i].unsqueeze(0)
                predicted_superclass = predicted_super_list[i]
                true_superclass = true_super_list[i]
                true_subclass_index = sub_labels[i].item()
                true_subclass_name = sub_map_df.iloc[true_subclass_index]['class']

                sub_logits = None
                index_to_subclass = None
                predicted_subclass_name = "novel"
                max_prob_subclass_val = 0.0

                if predicted_superclass == 'dog':
                    sub_logits = dog_expert(img)
                    index_to_subclass = {v: k for k, v in train_dog_dataset.dataset.subclass_to_label.items()}
                elif predicted_superclass == 'reptile':
                    sub_logits = reptile_expert(img)
                    index_to_subclass = {v: k for k, v in train_reptile_dataset.dataset.subclass_to_label.items()}
                elif predicted_superclass == 'bird':
                    sub_logits = bird_expert(img)
                    index_to_subclass = {v: k for k, v in train_bird_dataset.dataset.subclass_to_label.items()}

                if sub_logits is not None:
                    sub_probs = torch.softmax(sub_logits[:, :-1], dim=1)
                    max_prob_subclass, predicted_subclass_index_local = torch.max(sub_probs, dim=1)
                    max_prob_subclass_val = max_prob_subclass.item()

                    if max_prob_subclass_val >= confidence_threshold_subclass:
                        # FIXED: Use the values directly instead of sorted keys positions
                        original_subclass_indices = list(index_to_subclass.values())
                        original_subclass_index = original_subclass_indices[predicted_subclass_index_local.item()]
                        predicted_subclass_name = sub_map_df.iloc[original_subclass_index]['class']

                    if predicted_superclass == true_superclass:
                        total_sub += 1
                        if predicted_subclass_name == true_subclass_name:
                            correct_sub += 1
                        # --- Conceptual Novel Sub-class Detection (Needs Careful Definition) ---
                        # if max_prob_subclass_val < confidence_threshold_subclass and true_subclass_name is a novel sub-class:
                        #     novel_sub_actual += 1
                        # if predicted_subclass_name == "novel" and (true_subclass_name is a novel sub-class or max_prob_subclass_val < confidence_threshold_subclass):
                        #     novel_sub_predicted += 1

    super_accuracy = 100 * correct_super / total_super if total_super > 0 else 0
    novel_super_precision = (novel_super_predicted / (novel_super_predicted + (total_super - correct_super - novel_super_predicted))) * 100 if novel_super_predicted > 0 else 0
    novel_super_recall = (novel_super_predicted / novel_super_actual) * 100 if novel_super_actual > 0 else 0

    sub_accuracy = 100 * correct_sub / total_sub if total_sub > 0 else 0
    # novel_sub_precision = (novel_sub_predicted / (novel_sub_predicted + (total_sub - correct_sub - novel_sub_predicted))) * 100 if novel_sub_predicted > 0 else 0
    # novel_sub_recall = (novel_sub_predicted / novel_sub_actual) * 100 if novel_sub_actual > 0 else 0

    return {
        "super_class_accuracy": super_accuracy,
        "novel_super_precision": novel_super_precision,
        "novel_super_recall": novel_super_recall,
        "sub_class_accuracy": sub_accuracy,
        # "novel_sub_precision": novel_sub_precision,
        # "novel_sub_recall": novel_sub_recall,
    }

# --- Run Validation with Sub-class Labels ---
validation_metrics = validate_hierarchical(
    model_phase1b,
    dog_expert_model,
    reptile_expert_model,
    bird_expert_model,
    val_loader_hierarchical, # Use the new DataLoader
    device,
    super_map_df,
    sub_map_df,
    train_dog_dataset,
    train_reptile_dataset,
    train_bird_dataset
)

print("Hierarchical Classifier Validation Metrics (with Sub-class Evaluation):")
for metric, value in validation_metrics.items():
    print(f"{metric}: {value}")

Hierarchical Classifier Validation Metrics (with Sub-class Evaluation):
super_class_accuracy: 99.61832061068702
novel_super_precision: 66.66666666666666
novel_super_recall: 0
sub_class_accuracy: 96.29865985960434


In [25]:
# prompt: what is the list of unique superclass_index values in val_df and what is the list of unique subclass_index values in val_df?

print(f"Unique superclass_index values in val_df: {val_df['superclass_index'].unique().tolist()}")
print(f"Unique subclass_index values in val_df: {val_df['subclass_index'].unique().tolist()}")


Unique superclass_index values in val_df: [2, 1, 0]
Unique subclass_index values in val_df: [39, 49, 42, 4, 13, 54, 38, 66, 61, 7, 74, 32, 18, 11, 45, 29, 65, 48, 30, 23, 70, 6, 68, 57, 71, 76, 21, 9, 83, 19, 28, 15, 81, 77, 0, 86, 84, 24, 44, 22, 80, 62, 16, 52, 41, 17, 69, 43, 35, 3, 64, 47, 37, 1, 46, 50, 14, 12, 27, 31, 2, 36, 72, 75, 26, 67, 59, 56, 58, 40, 55, 79, 78, 85, 63, 51, 73, 60, 8, 25, 20, 53, 10, 5, 82, 34, 33]


**Build wrapper class for compatibility with trainer**

In [26]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import clip

class HierarchicalCLIPClassifier(nn.Module):
    """
    Wrapper model that combines superclass classifier and expert subclass classifiers
    into a single model interface compatible with the Trainer class.
    """
    def __init__(self, device, superclass_model, dog_expert, reptile_expert, bird_expert,
                 superclass_threshold=0.9, subclass_threshold=0.9):
        super().__init__()
        self.device = device
        self.superclass_model = superclass_model
        self.dog_expert = dog_expert
        self.reptile_expert = reptile_expert
        self.bird_expert = bird_expert

        self.superclass_threshold = superclass_threshold
        self.subclass_threshold = subclass_threshold

        # Maps superclass index to name
        self.superclass_idx_to_name = {0: 'bird', 1: 'dog', 2: 'reptile'}
        # Maps superclass name to expert model
        self.expert_models = {
            'bird': self.bird_expert,
            'dog': self.dog_expert,
            'reptile': self.reptile_expert
        }

        # Maps from expert outputs to original subclass indices (prepare during initialization)
        self.expert_mappings = {
            'bird': list(train_bird_dataset.dataset.subclass_to_label.values()),
            'dog': list(train_dog_dataset.dataset.subclass_to_label.values()),
            'reptile': list(train_reptile_dataset.dataset.subclass_to_label.values())
        }

        # Novel class indices
        self.NOVEL_SUPERCLASS_IDX = 3  # Index for novel superclass
        self.NOVEL_SUBCLASS_IDX = 87   # Index for novel subclass

    def forward(self, x):
        # Get superclass predictions
        superclass_logits = self.superclass_model(x)

        # Prepare subclass logits tensor (batch_size x num_subclasses)
        # Initialize with very negative values (will be ignored in softmax)
        batch_size = x.size(0)
        subclass_logits = torch.ones(batch_size, 88, device=self.device) * -100

        # Get superclass predictions to route to experts
        super_probs = F.softmax(superclass_logits, dim=1)
        max_prob_superclass, predicted_superclass_idx = torch.max(super_probs[:, :3], dim=1)

        # Process each sample in batch individually to route to appropriate expert
        for i in range(batch_size):
            # Skip if prediction confidence is below threshold (mark as novel)
            if max_prob_superclass[i] < self.superclass_threshold:
                subclass_logits[i, self.NOVEL_SUBCLASS_IDX] = 0  # Set novel class logit to 0 (others are -100)
                continue

            # Get predicted superclass and corresponding expert
            pred_superclass_idx = predicted_superclass_idx[i].item()
            pred_superclass_name = self.superclass_idx_to_name[pred_superclass_idx]
            expert_model = self.expert_models[pred_superclass_name]

            # Get expert predictions
            expert_output = expert_model(x[i:i+1])  # Process single sample
            sub_probs = F.softmax(expert_output[:, :-1], dim=1)  # Exclude novel output
            max_prob_subclass, pred_subclass_local_idx = torch.max(sub_probs, dim=1)

            # If confidence is below threshold, mark as novel
            if max_prob_subclass.item() < self.subclass_threshold:
                subclass_logits[i, self.NOVEL_SUBCLASS_IDX] = 0
                continue

            # Map to original subclass index
            original_subclass_indices = self.expert_mappings[pred_superclass_name]
            original_subclass_idx = original_subclass_indices[pred_subclass_local_idx.item()]

            # Set the corresponding logit to a high value (will dominate in softmax)
            subclass_logits[i, original_subclass_idx] = 10.0

        return superclass_logits, subclass_logits

In [27]:
"""
# Initialize the hierarchical model
hierarchical_model = HierarchicalCLIPClassifier(
    device=device,
    superclass_model=model_phase1b,
    dog_expert=dog_expert_model,
    reptile_expert=reptile_expert_model,
    bird_expert=bird_expert_model,
    superclass_threshold=0.7,  # Adjust as needed
    subclass_threshold=0.7     # Adjust as needed
)

# Set up training components
criterion = nn.CrossEntropyLoss()
# Only optimize the parameters that need training if models are already trained
optimizer = optim.Adam(hierarchical_model.parameters(), lr=1e-3)

# Create the trainer
trainer = Trainer(
    model=hierarchical_model,
    criterion=criterion,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    device=device
)

# Skip training if models are already trained
# for epoch in range(1):
#     print(f'Epoch {epoch+1}')
#     trainer.train_epoch()
#     trainer.validate_epoch()
#     print('')

# Generate test predictions
print('Generating test predictions...')
test_predictions = trainer.test()
test_predictions.to_csv('test_predictions_hierarchical.csv', index=False)
print('Test predictions saved to test_predictions_hierarchical.csv')
"""

"\n# Initialize the hierarchical model\nhierarchical_model = HierarchicalCLIPClassifier(\n    device=device,\n    superclass_model=model_phase1b,\n    dog_expert=dog_expert_model,\n    reptile_expert=reptile_expert_model,\n    bird_expert=bird_expert_model,\n    superclass_threshold=0.7,  # Adjust as needed\n    subclass_threshold=0.7     # Adjust as needed\n)\n\n# Set up training components\ncriterion = nn.CrossEntropyLoss()\n# Only optimize the parameters that need training if models are already trained\noptimizer = optim.Adam(hierarchical_model.parameters(), lr=1e-3)\n\n# Create the trainer\ntrainer = Trainer(\n    model=hierarchical_model,\n    criterion=criterion,\n    optimizer=optimizer,\n    train_loader=train_loader,\n    val_loader=val_loader,\n    test_loader=test_loader,\n    device=device\n)\n\n# Skip training if models are already trained\n# for epoch in range(1):\n#     print(f'Epoch {epoch+1}')\n#     trainer.train_epoch()\n#     trainer.validate_epoch()\n#     print(''

In [28]:
# Initialize the hierarchical model
hierarchical_model = HierarchicalCLIPClassifier(
    device=device,
    superclass_model=model_phase1b,
    dog_expert=dog_expert_model,
    reptile_expert=reptile_expert_model,
    bird_expert=bird_expert_model,
    superclass_threshold=0.65,  # Adjusted threshold
    subclass_threshold=0.5      # Adjusted threshold
)

# Set up training components
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(hierarchical_model.parameters(), lr=1e-3)

# Create the trainer
trainer = Trainer(
    model=hierarchical_model,
    criterion=criterion,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    device=device
)

# Custom validation function for hierarchical model on val_loader_hierarchical
def validate_with_hierarchical(model, val_loader_hierarchical, device):
    model.eval()
    correct_super = 0
    total_super = 0
    correct_sub = 0
    total_sub = 0
    novel_super_pred = 0
    novel_sub_pred = 0

    with torch.no_grad():
        for imgs, super_labels, sub_labels in val_loader_hierarchical:
            imgs = imgs.to(device)
            super_labels = super_labels.to(device)
            sub_labels = sub_labels.to(device)

            super_logits, sub_logits = model(imgs)

            # Handle superclass predictions
            super_probs = F.softmax(super_logits, dim=1)
            max_prob_super, pred_super = torch.max(super_probs[:, :3], dim=1)
            novel_super_mask = max_prob_super < model.superclass_threshold
            pred_super[novel_super_mask] = 3  # Set to novel class

            # Calculate superclass accuracy
            total_super += super_labels.size(0)
            correct_super += (pred_super == super_labels).sum().item()
            novel_super_pred += novel_super_mask.sum().item()

            # Handle subclass predictions and accuracy
            sub_probs = F.softmax(sub_logits, dim=1)
            max_prob_sub, pred_sub = torch.max(sub_probs, dim=1)
            novel_sub_mask = pred_sub == 87  # Identify novel subclass predictions

            # Only count subclass accuracy for samples with correct superclass
            correct_super_mask = (pred_super == super_labels)
            total_sub += correct_super_mask.sum().item()
            correct_sub_and_super = correct_super_mask & (pred_sub == sub_labels)
            correct_sub += correct_sub_and_super.sum().item()
            novel_sub_pred += novel_sub_mask.sum().item()

    # Calculate final metrics
    super_acc = 100 * correct_super / total_super if total_super > 0 else 0
    sub_acc = 100 * correct_sub / total_sub if total_sub > 0 else 0
    novel_super_percent = 100 * novel_super_pred / total_super if total_super > 0 else 0
    novel_sub_percent = 100 * novel_sub_pred / total_super if total_super > 0 else 0

    return {
        "Superclass Accuracy": super_acc,
        "Subclass Accuracy": sub_acc,
        "Novel Superclass %": novel_super_percent,
        "Novel Subclass %": novel_sub_percent
    }

# Run validation on hierarchical validation set
print("Validating hierarchical model on validation set...")
val_metrics = validate_with_hierarchical(hierarchical_model, val_loader_hierarchical, device)
print("Validation Metrics:")
for metric, value in val_metrics.items():
    print(f"{metric}: {value:.2f}%")

# Generate test predictions
print('\nGenerating test predictions...')
test_predictions = trainer.test()
test_predictions.to_csv('test_predictions_hierarchical.csv', index=False)
print('Test predictions saved to test_predictions_hierarchical.csv')

# Analyze test predictions
print("\nTest Predictions Distribution:")
print(f"Total test samples: {len(test_predictions)}")
print(f"Superclass distribution:\n{test_predictions['superclass_index'].value_counts()}")
print(f"Percentage of novel superclass predictions: {(test_predictions['superclass_index'] == 3).mean() * 100:.2f}%")
print(f"Percentage of novel subclass predictions: {(test_predictions['subclass_index'] == 87).mean() * 100:.2f}%")

Validating hierarchical model on validation set...
Validation Metrics:
Superclass Accuracy: 99.62%
Subclass Accuracy: 0.70%
Novel Superclass %: 0.19%
Novel Subclass %: 0.64%

Generating test predictions...
Total superclasses unseen: 217
Total subclasses unseen: 0
Test predictions saved to test_predictions_hierarchical.csv

Test Predictions Distribution:
Total test samples: 11180
Superclass distribution:
superclass_index
2    4563
0    3412
1    2988
3     217
Name: count, dtype: int64
Percentage of novel superclass predictions: 1.94%
Percentage of novel subclass predictions: 10.75%


In [29]:

# Now, let's run the test and generate predictions
print("\nRunning test predictions...")
test_predictions = trainer.test()

# Save the predictions to a CSV file
test_predictions.to_csv('test_predictions_hierarchical.csv', index=False)
print("Test predictions saved to test_predictions_hierarchical.csv")

# Let's examine the test predictions distribution
print("\nTest Predictions Distribution:")
print(f"Total test samples: {len(test_predictions)}")
print(f"Superclass distribution:\n{test_predictions['superclass_index'].value_counts()}")
print(f"Percentage of novel superclass predictions: {(test_predictions['superclass_index'] == 3).mean() * 100:.2f}%")
print(f"Percentage of novel subclass predictions: {(test_predictions['subclass_index'] == 87).mean() * 100:.2f}%")

# Check if our predictions match the format expected by the leaderboard
print("\nValidating prediction format:")
print(f"All superclass indices are within expected range (0-3): {test_predictions['superclass_index'].between(0, 3).all()}")
print(f"All subclass indices are within expected range (0-87): {test_predictions['subclass_index'].between(0, 87).all()}")
print(f"No missing values in predictions: {not test_predictions.isnull().any().any()}")

# Show a sample of the predictions
print("\nSample of test predictions:")
print(test_predictions.head(10))


Running test predictions...
Total superclasses unseen: 217
Total subclasses unseen: 0
Test predictions saved to test_predictions_hierarchical.csv

Test Predictions Distribution:
Total test samples: 11180
Superclass distribution:
superclass_index
2    4563
0    3412
1    2988
3     217
Name: count, dtype: int64
Percentage of novel superclass predictions: 1.94%
Percentage of novel subclass predictions: 10.75%

Validating prediction format:
All superclass indices are within expected range (0-3): True
All subclass indices are within expected range (0-87): True
No missing values in predictions: True

Sample of test predictions:
   image  superclass_index  subclass_index
0  0.jpg                 1              17
1  1.jpg                 0               6
2  2.jpg                 2               6
3  3.jpg                 2              15
4  4.jpg                 1              13
5  5.jpg                 2               3
6  6.jpg                 0               4
7  7.jpg                

In [30]:
# Load and analyze the predictions
test_predictions = pd.read_csv('test_predictions_hierarchical.csv')

# Basic statistics
print(f"Total test samples: {len(test_predictions)}")
print(f"Superclass distribution:")
print(test_predictions['superclass_index'].value_counts())
print(f"Percentage classified as novel superclass: {(test_predictions['superclass_index'] == 3).mean() * 100:.2f}%")

# Look at subclass distribution
print(f"\nSubclass distribution - top 10 most common:")
print(test_predictions['subclass_index'].value_counts().head(10))
print(f"Number of unique subclasses predicted: {test_predictions['subclass_index'].nunique()}")

# Check for patterns in superclass-subclass relationship
print("\nMost common subclasses for each superclass:")
for super_idx in range(4):  # 0, 1, 2, 3 (where 3 is novel)
    super_mask = test_predictions['superclass_index'] == super_idx
    if super_mask.sum() > 0:
        print(f"Superclass {super_idx}:")
        print(test_predictions.loc[super_mask, 'subclass_index'].value_counts().head(5))

Total test samples: 11180
Superclass distribution:
superclass_index
2    4563
0    3412
1    2988
3     217
Name: count, dtype: int64
Percentage classified as novel superclass: 1.94%

Subclass distribution - top 10 most common:
subclass_index
87    1202
23     985
22     846
15     820
19     619
8      526
6      509
13     459
10     433
18     425
Name: count, dtype: int64
Number of unique subclasses predicted: 30

Most common subclasses for each superclass:
Superclass 0:
subclass_index
6     363
23    312
87    307
4     303
13    286
Name: count, dtype: int64
Superclass 1:
subclass_index
19    418
18    366
87    242
23    209
2     207
Name: count, dtype: int64
Superclass 2:
subclass_index
22    786
15    658
87    556
23    443
8     266
Name: count, dtype: int64
Superclass 3:
subclass_index
87    97
22    31
15    26
23    21
25    10
Name: count, dtype: int64



=== Running experiment: baseline ===


KeyboardInterrupt: 