# Implementing a concept bottleneck model (CBM) on the CUB dataset. Inspired by: https://arxiv.org/pdf/2007.04612

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np

import os
import random

import pandas as pd
import requests
from tqdm import tqdm
import tarfile
from torchvision import models


from PIL import Image

import itertools




In [None]:
# Install gdown if you haven't already
!pip install gdown

# Resnet 18
!gdown 1me7X6jSSAZV0xaK_slQpxGKgF5nJTo8e

Downloading...
From: https://drive.google.com/uc?id=1me7X6jSSAZV0xaK_slQpxGKgF5nJTo8e
To: /content/resnet18-5c106cde.pth
100% 46.8M/46.8M [00:00<00:00, 62.7MB/s]


In [None]:
import os
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# Create a path to save your file in your Google Drive
save_dir = '/content/drive/My Drive/ConceptBottleneckBirds/'

file = 'your_filename.txt'  # Change filename as needed

# Specify the directory path you want to create

# Create directory if it doesn't exist
if not os.path.exists(save_dir):
    os.makedirs(save_dir)
    print(f'Directory created at: {save_dir}')
else:
    print(f'Directory already exists at: {save_dir}')

# Example: Now you can save files in this directory
save_path = os.path.join(save_dir, file)
with open(save_path, 'w') as f:
    f.write('Hello from the birds using Colab!')



Mounted at /content/drive
Directory already exists at: /content/drive/My Drive/ConceptBottleneckBirds/


In [None]:
def set_seed(seed):
   torch.manual_seed(seed)
   torch.cuda.manual_seed(seed)
   torch.cuda.manual_seed_all(seed)
   np.random.seed(seed)
   random.seed(seed)
   torch.backends.cudnn.deterministic = True
   torch.backends.cudnn.benchmark = False

set_seed(42)  # Replace with your seed number

In [None]:
class BirdsDataset(Dataset):
    def __init__(self, images, concepts, labels, transform=None):
        assert type(concepts) == type(labels) == type(images) == list, (
          "concepts, labels, and images must be of the same type, list. \nGot: %s, %s, %s" % (type(concepts), type(labels), type(images)))
        self.images = images
        self.concepts = concepts
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.open(self.images[idx]).convert("RGB")
        if self.transform:
            image = self.transform(image)
        concepts = self._convert_concepts_to_tensor(self.concepts[idx])
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return image, concepts, label


    def _convert_concepts_to_tensor(self, concept_list):
        """
        Convert list of concept dictionaries to binary tensor.
        We use is_present field to create a binary vector.
        """
        # Create tensor of zeros
        concept_tensor = torch.zeros(312)

        # Fill in the binary values from is_present
        for i, concept_dict in enumerate(concept_list):
            concept_tensor[i] = 1.0 if concept_dict['is_present'] == 1.0 else 0.0

        return concept_tensor


class CUBDataset(Dataset):
    """
    Create a PyTorch dataset from a list of image paths.

    Args:
        image_paths: List of paths to image files
        transform: Optional transform to be applied on images
                  (if None, will convert to tensor and normalize)
    """

    def __init__(self, image_paths, concepts, labels, transform=None):
      self.concepts = []
      self.labels = []
      self.images = []

      assert type(concepts) == type(labels) == type(image_paths) == list, (
        "concepts, labels, and image_paths must be of the same type, list. \nGot: %s, %s, %s" % (type(concepts), type(labels), type(image_paths)))

      assert len(image_paths) == len(concepts) == len(labels), (
        "Number of images, concepts, and labels must match")

      base_transforms = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            # Removed the crop from here. To allow it be applied dynamically at get time.
        ])

      # Default transform if none provided
      self.transform = transform if transform is not None else transforms.Compose([])

      for image_path, concept, label in zip(image_paths, concepts, labels):
        try:
          image = Image.open(image_path).convert('RGB')
        except Exception as e:
          print(f"Error loading image {image_path}: {str(e)}")
        # Apply base transforms
        image = base_transforms(image)
        self.images.append(image)

        self.concepts.append(self._convert_concepts_to_tensor(concept))
        self.labels.append(torch.tensor(label, dtype=torch.long))

    def _convert_concepts_to_tensor(self, concept_list):
        """
        Convert list of concept dictionaries to binary tensor.
        We use is_present field to create a binary vector.
        """
        # Create tensor of zeros
        concept_tensor = torch.zeros(312)

        # Fill in the binary values from is_present
        for i, concept_dict in enumerate(concept_list):
            concept_tensor[i] = 1.0 if concept_dict['is_present'] == 1.0 else 0.0

        return concept_tensor

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Load image
        image = self.transform(self.images[idx])
        label = self.labels[idx]
        concept = self.concepts[idx]

        return image, concept, label


In [None]:

def download_cub200_2011():
    """
    Downloads the CUB-200-2011 dataset and extracts it.
    Returns the path to the extracted dataset.
    """
    # Create a directory for the dataset
    base_dir = '/content/CUB_200_2011'
    dataset_dir = os.path.join(base_dir, 'CUB_200_2011')

    # Check if dataset already exists
    if os.path.exists(dataset_dir) and os.path.exists(os.path.join(dataset_dir, 'images.txt')):
        print("Dataset already downloaded and extracted.")
        return dataset_dir

    os.makedirs(base_dir, exist_ok=True)

    # URL for the dataset
    url = 'https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz'
    tgz_path = os.path.join(base_dir, 'CUB_200_2011.tgz')

    # Download only if not already downloaded
    if not os.path.exists(tgz_path):
        print("Downloading CUB-200-2011 dataset...")
        response = requests.get(url, stream=True)
        total_size = int(response.headers.get('content-length', 0))

        with open(tgz_path, 'wb') as f:
            for data in tqdm(response.iter_content(chunk_size=1024),
                            total=total_size//1024,
                            unit='KB'):
                f.write(data)

    # Extract only if not already extracted
    if not os.path.exists(dataset_dir):
        print("\nExtracting dataset...")
        with tarfile.open(tgz_path, 'r:gz') as tar:
            tar.extractall(base_dir)

    # Remove the downloaded tar file to save space
    if os.path.exists(tgz_path):
        os.remove(tgz_path)

    return dataset_dir

def read_txt_file(filepath, num_cols):
    """
    Safely read space-separated text files with a specific number of columns.
    """
    data = []
    with open(filepath, 'r') as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) >= num_cols:
                data.append(parts[:num_cols])
    return pd.DataFrame(data)

def load_cub_data(data_dir):
    """
    Loads and organizes the CUB dataset metadata.
    Returns dictionaries for image paths, labels, and attribute data.
    """
    # Load image paths and labels using the safe reader
    images_df = read_txt_file(os.path.join(data_dir, 'images.txt'), 2)
    images_df.columns = ['image_id', 'image_path']
    images_df['image_id'] = images_df['image_id'].astype(int)

    labels_df = read_txt_file(os.path.join(data_dir, 'image_class_labels.txt'), 2)
    labels_df.columns = ['image_id', 'class_id']
    labels_df['image_id'] = labels_df['image_id'].astype(int)
    labels_df['class_id'] = labels_df['class_id'].astype(int)

    # Load train/test split
    train_test_df = read_txt_file(os.path.join(data_dir, 'train_test_split.txt'), 2)
    train_test_df.columns = ['image_id', 'is_training']
    train_test_df['image_id'] = train_test_df['image_id'].astype(int)
    train_test_df['is_training'] = train_test_df['is_training'].astype(int)

    # Load attributes using the safe reader
    attr_df = read_txt_file(os.path.join(data_dir, 'attributes/image_attribute_labels.txt'), 5)
    attr_df.columns = ['image_id', 'attribute_id', 'is_present', 'certainty', 'time']
    attr_df = attr_df.astype({
        'image_id': int,
        'attribute_id': int,
        'is_present': int,
        'certainty': int,
        'time': float
    })

    print("Merging")

    # Merge dataframes
    data = images_df.merge(labels_df, on='image_id')
    data = data.merge(train_test_df, on='image_id')

    print("Creating Dictionaries")
    # Create dictionaries
    image_paths = {row['image_id']: os.path.join(data_dir, 'images', row['image_path'])
                  for _, row in data.iterrows()}

    labels = {row['image_id']: row['class_id'] - 1  # Convert to 0-based indexing
             for _, row in data.iterrows()}

    train_test = {row['image_id']: row['is_training']
                  for _, row in data.iterrows()}

    # Organize attributes
    print("Organizing Attributes")
    # This is the slow part. Optimize...
    attributes = {}
    for _, row in attr_df.iterrows():
        image_id = row['image_id']
        if image_id not in attributes:
            attributes[image_id] = []
        attributes[image_id].append({
            'attribute_id': row['attribute_id'],
            'is_present': row['is_present'],
            'certainty': row['certainty']
        })

    return {
        'image_paths': image_paths,
        'labels': labels,
        'train_test_split': train_test,
        'attributes': attributes
    }


In [None]:
"""
# Download and extract the dataset
data_dir = download_cub200_2011()
print(f"\nDataset directory: {data_dir}")

# Load the dataset metadata
print("\nLoading dataset metadata...")
data = load_cub_data(data_dir)

num_classes = len(set(data['labels'].values()))

first_image_id = list(data['image_paths'].keys())[0]
num_concepts = len(data['attributes'][first_image_id])

# Print some statistics
print("\nDataset statistics:")
print(f"Total number of images: {len(data['image_paths'])}")
print(f"Number of training images: {sum(data['train_test_split'].values())}")

# a map of int id to class label 0 train, 1 test
print(f"Number of test images: {len(data['train_test_split']) - sum(data['train_test_split'].values())}")
print(f"Number of classes: {num_classes}")

# Example of accessing data for first image
print(f"\nExample data for image {first_image_id}:")
print(f"Image path: {data['image_paths'][first_image_id]}")
print(f"Class label: {data['labels'][first_image_id]}")
print(f"Is training: {data['train_test_split'][first_image_id]}")
print(f"Number of concepts: {num_concepts}")

"""

'\n# Download and extract the dataset\ndata_dir = download_cub200_2011()\nprint(f"\nDataset directory: {data_dir}")\n\n# Load the dataset metadata\nprint("\nLoading dataset metadata...")\ndata = load_cub_data(data_dir)\n\nnum_classes = len(set(data[\'labels\'].values()))\n\nfirst_image_id = list(data[\'image_paths\'].keys())[0]\nnum_concepts = len(data[\'attributes\'][first_image_id])\n\n# Print some statistics\nprint("\nDataset statistics:")\nprint(f"Total number of images: {len(data[\'image_paths\'])}")\nprint(f"Number of training images: {sum(data[\'train_test_split\'].values())}")\n\n# a map of int id to class label 0 train, 1 test\nprint(f"Number of test images: {len(data[\'train_test_split\']) - sum(data[\'train_test_split\'].values())}")\nprint(f"Number of classes: {num_classes}")\n\n# Example of accessing data for first image\nprint(f"\nExample data for image {first_image_id}:")\nprint(f"Image path: {data[\'image_paths\'][first_image_id]}")\nprint(f"Class label: {data[\'labels\

In [26]:
import copy
from torch.optim.lr_scheduler import StepLR


class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x

# Define the concept bottleneck model
class ConceptBottleneckModel(nn.Module):
    def __init__(self, num_concepts=10, num_classes=200):
        super(ConceptBottleneckModel, self).__init__()
        self.encoder_res = models.resnet18(weights=None)
        self.encoder_res.load_state_dict(
            torch.load("/content/resnet18-5c106cde.pth")
        )
        #n_features = self.encoder_res.fc.in_features
        self.encoder_res.fc = Identity()
        self.features = nn.Sequential(self.encoder_res)

        # Concept predictor
        self.concept_predictor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512, 1024), # this was (512, 256)
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, num_concepts),
            nn.Sigmoid()
        )


        # Class predictor
        """
        self.class_predictor = nn.Sequential(
            nn.Linear(num_concepts, 2048),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(2048, num_classes)
        )
        """


        self.class_predictor = nn.Sequential(
            nn.Linear(num_concepts, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(2048, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, num_classes)
        )


    def forward(self, x, return_concepts=False):
        features = self.features(x)
        concepts = self.concept_predictor(features)
        outputs = self.class_predictor(concepts)

        if return_concepts:
            return outputs, concepts
        return outputs

# Training function
def train_model(model,
                train_loader,
                val_loader,
                num_epochs=10,
                validation_interval=1,
                lr=0.001,
                lambda_=1.0,
                restore_best_model=True,
                device='cuda'):
    criterion_concepts = nn.BCELoss()
    criterion_classes = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5, verbose=True)
    model = model.to(device)
    best_val_acc = 0.0
    best_state_dict = None
    epoch_since_improvement_limit = 5
    #epoch_since_improvement_limit = 5
    epochs_since_improvement = 0

    for epoch in range(num_epochs):
        model.train()
        training_loss = 0.0
        train_correct = 0
        train_total = 0

        for i, (images, concepts, labels) in enumerate(train_loader):
            images = images.to(device)
            concepts = concepts.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            # Forward pass
            predicted_classes, predicted_concepts = model(images, return_concepts=True)

            # Calculate losses
            concept_loss = criterion_concepts(predicted_concepts, concepts)
            class_loss = criterion_classes(predicted_classes, labels)
            #print("predicted_classes = %s" % predicted_classes)
            #print("labels = %s" % labels)

            total_loss = class_loss + lambda_ * concept_loss
            training_loss += total_loss
            train_total += labels.size(0)
            _, predicted = torch.max(predicted_classes.data, 1)
            train_correct += (predicted == labels).sum().item()

            #print("Total Loss is: %s, Class Loss is: %s, Concept Loss: %s" % (total_loss.item(), class_loss.item(), concept_loss.item()))
            # Backward pass
            total_loss.backward()
            optimizer.step()

        # Print epoch-level metrics
        avg_train_loss = training_loss / len(train_loader)
        train_acc = 100 * train_correct / train_total
        print(f'Epoch {epoch+1}, Training Loss: {avg_train_loss:.4f}, Training Accuracy: {train_acc:.2f}%')

        if (epoch + 1) % validation_interval == 0:
          # Validation
          model.eval()
          val_acc = 0.0
          concept_acc = 0.0
          val_total = 0
          concept_total = 0
          validation_loss = 0


          with torch.no_grad():
              for images, concepts, labels in val_loader:
                  images = images.to(device)
                  concepts = concepts.to(device)
                  labels = labels.to(device)

                  predicted_labels, predicted_concepts = model(images, return_concepts=True)
                  """
                  # 1. Check your class predictions range before CrossEntropyLoss
                  print("predicted_labels min %s max %s" %(torch.min(predicted_labels), torch.max(predicted_labels)))
                  # Should be reasonable logit values, not extreme

                  # 2. Verify label format
                  print("labels min %s max %s" % (labels.min(), labels.max()))  # Should be 0 to num_classes-1

                  # 3. Check concept predictions before BCE
                  print("predicted_concepts min %s max %s" % (torch.min(predicted_concepts), torch.max(predicted_concepts)))
                  # Should be between 0-1 after sigmoid
                  """

                  _, predicted = torch.max(predicted_labels.data, 1)
                  val_total += labels.size(0)
                  val_acc += (predicted == labels).sum().item()
                  concept_total += concepts.size(0) * concepts.size(1)
                  preds = (predicted_concepts > 0.5).float()  # Need to threshold sigmoid outputs
                  concept_acc += (preds == concepts).float().sum()

                  concept_loss = criterion_concepts(predicted_concepts, concepts)
                  class_loss = criterion_classes(predicted_labels, labels)
                  validation_loss += class_loss + lambda_ * concept_loss
          val_acc = 100 * val_acc / val_total
          concept_acc = 100 * concept_acc / concept_total
          print(f'Validation Epoch {epoch+1}/{num_epochs}:')
          print(f'Validation Loss: {validation_loss/len(val_loader):.4f}')
          print(f'Validation Class Label Accuracy: {val_acc:.2f}%')
          print(f'Validation Concept Accuracy: {concept_acc:.2f}%')
          #scheduler.step(validation_loss)

          if val_acc > best_val_acc:
            print("New Best Model Validation Accuracy: %s" % val_acc)
            best_val_acc = val_acc
            best_state_dict = copy.deepcopy(model.state_dict())
            torch.save(best_state_dict, 'best_model.pth')
            epochs_since_improvement = 0
          else:
            epochs_since_improvement += 1
            print(f"Epochs since last improvement: {epochs_since_improvement}")
          if epochs_since_improvement >= epoch_since_improvement_limit:
            print("Soft reset of weights and reduce learning rate triggered")
            model.load_state_dict(best_state_dict)
            scheduler.step()
            epochs_since_improvement = 0
    if restore_best_model and best_state_dict:
        model.load_state_dict(best_state_dict)




In [27]:
def get_data_dict():
  data_dir = download_cub200_2011()
  data = load_cub_data(data_dir)
  return data


def get_train_val_test_datasets(data):

  # Initialize the split dictionary
  splits = {}

  # Get indices where value in dict is 1 (training)
  train_indices = [k for k, v in data['train_test_split'].items() if v == 1]

  # Randomly shuffle these indices
  shuffled_indices = np.random.permutation(train_indices)

  # Calculate split point for 80/20 split of training data
  n_train = int(len(train_indices) * 0.8)

  # First set all indices in original dict to 'test'
  for idx in data['train_test_split'].keys():
      splits[idx] = 'test'

  # Update training indices
  for idx in shuffled_indices[:n_train]:
      splits[idx] = 'train'

  # Update validation indices
  for idx in shuffled_indices[n_train:]:
      splits[idx] = 'val'

  data['split'] = splits


  # First get sorted IDs for train and test
  train_ids = sorted([id for id, split in data['split'].items() if split == "train"])
  val_ids = sorted([id for id, split in data['split'].items() if split == "val"])
  test_ids = sorted([id for id, split in data['split'].items() if split == "test"])

  print(len(train_ids))
  print(len(val_ids))
  print(len(test_ids))

  # Following the transformations from CBM paper
  resol = 299


  train_transforms = transforms.Compose(
    [
    transforms.RandomRotation(15),      # Rotate first while we have full image
    transforms.RandomResizedCrop(224),  # Then crop (avoiding empty corners)
    transforms.RandomHorizontalFlip(),
    ]
  )

  val_transforms = transforms.Compose([
      transforms.Resize(size=(224, 224)),
  ])

  test_transforms = transforms.Compose([
      transforms.Resize(size=(224, 224)),
      #transforms.CenterCrop(resol)
  ])


  # Create training dataset using the sorted train IDs
  train_dataset = CUBDataset(
      image_paths=[data['image_paths'][id] for id in train_ids],
      concepts=[data['attributes'][id] for id in train_ids],
      labels=[data['labels'][id] for id in train_ids],
      transform=train_transforms
  )

  val_dataset = CUBDataset(
      image_paths=[data['image_paths'][id] for id in val_ids],
      concepts=[data['attributes'][id] for id in val_ids],
      labels=[data['labels'][id] for id in val_ids],
      transform=val_transforms
  )

  # Create validation dataset using the sorted test IDs
  test_dataset = CUBDataset(
      image_paths=[data['image_paths'][id] for id in test_ids],
      concepts=[data['attributes'][id] for id in test_ids],
      labels=[data['labels'][id] for id in test_ids],
      transform=test_transforms
  )
   # Verify the split
  print(f"Training samples: {len(train_dataset)}")
  print(f"Validation samples: {len(val_dataset)}")
  print(f"Test samples: {len(test_dataset)}")

  return train_dataset, val_dataset, test_dataset



def get_train_val_test_loaders(train_dataset, val_dataset, test_dataset, batch_size):

  import multiprocessing as mp

  num_cpus = mp.cpu_count()
  num_workers = num_cpus - 2
  print(f"Number of CPUs: {num_cpus}")
  print(f"Number of workers: {num_workers}")

  train_loader = DataLoader(
      train_dataset,
      batch_size=batch_size,
      shuffle=True,
      num_workers=num_workers,
      pin_memory=True
  )

  val_loader = DataLoader(
      val_dataset,
      batch_size=batch_size,
      shuffle=False,
      num_workers=num_workers,
      pin_memory=True
  )

  test_loader = DataLoader(
      test_dataset,
      batch_size=batch_size,
      shuffle=False,  # No need to shuffle validation data
      num_workers=num_workers
  )
  return train_loader, val_loader, test_loader

In [None]:
%%time
data_dict = get_data_dict()
print("Creating Datasets")
train_dataset, val_dataset, test_dataset = get_train_val_test_datasets(data_dict)
print("Creating Dataloaders")
train_loader, val_loader, test_loader = get_train_val_test_loaders(train_dataset, val_dataset, test_dataset, batch_size=512)


Downloading CUB-200-2011 dataset...


1123619KB [00:43, 26014.92KB/s]                             



Extracting dataset...
Merging
Creating Dictionaries
Organizing Attributes
Creating Datasets
4795
1199
5794
Training samples: 4795
Validation samples: 1199
Test samples: 5794
Creating Dataloaders
Number of CPUs: 12
Number of workers: 10
CPU times: user 8min 28s, sys: 23.7 s, total: 8min 52s
Wall time: 4min 24s


In [28]:
%%time
num_concepts = 312
num_classes = 200
model = ConceptBottleneckModel(num_concepts=num_concepts, num_classes=num_classes)
#model = model.to(torch.bfloat16).to('cuda')

#model = model.to(torch.bfloat16)


# lambda_ controls the balance between the class loss and the concept loss. A small lambda shrinks the importance of the concept loss.


CPU times: user 257 ms, sys: 87.9 ms, total: 345 ms
Wall time: 291 ms


  torch.load("/content/resnet18-5c106cde.pth")


In [29]:
%%time

train_model(model, train_loader, val_loader, validation_interval=3, lr=0.0001, lambda_=1, num_epochs=300)

Epoch 1, Training Loss: 6.1614, Training Accuracy: 0.56%
Epoch 2, Training Loss: 6.0415, Training Accuracy: 0.83%
Epoch 3, Training Loss: 5.9136, Training Accuracy: 1.15%
Validation Epoch 3/300:
Validation Loss: 5.8607
Validation Class Label Accuracy: 2.00%
Validation Concept Accuracy: 66.62%
New Best Model Validation Accuracy: 2.0016680567139282
Epoch 4, Training Loss: 5.7862, Training Accuracy: 2.25%
Epoch 5, Training Loss: 5.6312, Training Accuracy: 3.77%
Epoch 6, Training Loss: 5.4873, Training Accuracy: 4.38%
Validation Epoch 6/300:
Validation Loss: 5.2679
Validation Class Label Accuracy: 11.51%
Validation Concept Accuracy: 75.88%
New Best Model Validation Accuracy: 11.509591326105088
Epoch 7, Training Loss: 5.3004, Training Accuracy: 6.65%
Epoch 8, Training Loss: 5.1425, Training Accuracy: 7.13%
Epoch 9, Training Loss: 4.9895, Training Accuracy: 8.74%
Validation Epoch 9/300:
Validation Loss: 4.6998
Validation Class Label Accuracy: 17.01%
Validation Concept Accuracy: 80.67%
New Be

In [30]:
from datetime import datetime

# Create timestamp string
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')  # Format: YYYYMMDD_HHMMSS


# Evaluate model
def evaluate_model(model, test_dataloader, device='cuda'):
    model.eval()  # Set to evaluation mode
    val_acc = 0.0
    concept_acc = 0.0
    val_total = 0
    concept_total = 0

    with torch.no_grad():
        for images, concepts, labels in test_dataloader:
            images = images.to(device)
            concepts = concepts.to(device)
            labels = labels.to(device)

            predicted_labels, predicted_concepts = model(images, return_concepts=True)
            _, predicted = torch.max(predicted_labels.data, 1)
            val_total += labels.size(0)
            val_acc += (predicted == labels).sum().item()
            concept_total += concepts.size(0) * concepts.size(1)
            preds = (predicted_concepts > 0.5).float()  # Need to threshold sigmoid outputs
            concept_acc += (preds == concepts).float().sum()

    val_acc = 100 * val_acc / val_total
    concept_acc = 100 * concept_acc / concept_total
    #print(f'Validation Loss: {running_loss/len(train_loader):.4f}')
    print(f'Test Class Label Accuracy: {val_acc:.2f}%')
    print(f'Test Concept Accuracy: {concept_acc:.2f}%')
    return val_acc, concept_acc


evaluate_model(model, test_loader)


best_state_dict = copy.deepcopy(model.state_dict())

# Save locally in this Colab
torch.save(best_state_dict, 'cub_model.pth')

# Save to Google Drive
torch.save(best_state_dict, os.path.join(save_dir,f'cub_model_{timestamp}.pth'))


Test Class Label Accuracy: 68.93%
Test Concept Accuracy: 90.15%


In [None]:
def check_data_ranges(dataloader, name=""):
    label_min = float('inf')
    label_max = float('-inf')
    concept_mins = None  # Changed from list to None
    concept_maxs = None

    for _, concepts, labels in dataloader:
        label_min = min(label_min, labels.min().item())
        label_max = max(label_max, labels.max().item())

        if concept_mins is None:  # First iteration
            concept_mins = concepts.min(dim=0)[0]
            concept_maxs = concepts.max(dim=0)[0]
        else:
            concept_mins = torch.minimum(concept_mins, concepts.min(dim=0)[0])
            concept_maxs = torch.maximum(concept_maxs, concepts.max(dim=0)[0])

    print(f"\n{name} Dataset Stats:")
    print(f"Label range: {label_min} to {label_max}")
    print(f"Number of unique concepts: {len(concept_mins)}")

# Check both datasets
check_data_ranges(train_loader, "Train")
check_data_ranges(val_loader, "Val")
check_data_ranges(test_loader, "Test")


Train Dataset Stats:
Label range: 0 to 199
Number of unique concepts: 312

Val Dataset Stats:
Label range: 0 to 199
Number of unique concepts: 312

Test Dataset Stats:
Label range: 0 to 199
Number of unique concepts: 312


In [None]:
def check_image_stats(dataloader, name=""):
    first_batch = next(iter(dataloader))
    images, concepts, labels = first_batch

    print(f"\n{name} Dataset Image Stats:")
    print(f"Image tensor shape: {images.shape}")  # Should be [batch_size, channels, height, width]
    print(f"Image dtype: {images.dtype}")
    print(f"Value range: min={images.min():.3f}, max={images.max():.3f}")
    print(f"Mean: {images.mean():.3f}")
    print(f"Std: {images.std():.3f}")

    # Check for NaN/Inf
    print(f"Contains NaN: {torch.isnan(images).any()}")
    print(f"Contains Inf: {torch.isinf(images).any()}")

# Check both datasets
check_image_stats(train_loader, "Train")
check_image_stats(val_loader, "Val")
check_image_stats(test_loader, "Test")


Train Dataset Image Stats:
Image tensor shape: torch.Size([512, 3, 224, 224])
Image dtype: torch.float32
Value range: min=-2.118, max=2.640
Mean: 0.103
Std: 1.056
Contains NaN: False
Contains Inf: False

Val Dataset Image Stats:
Image tensor shape: torch.Size([512, 3, 224, 224])
Image dtype: torch.float32
Value range: min=-2.118, max=2.640
Mean: 0.114
Std: 1.070
Contains NaN: False
Contains Inf: False

Test Dataset Image Stats:
Image tensor shape: torch.Size([512, 3, 224, 224])
Image dtype: torch.float32
Value range: min=-2.118, max=2.640
Mean: 0.118
Std: 1.048
Contains NaN: False
Contains Inf: False
