In [None]:
import importlib.util
import subprocess
import sys

# pip distribution names this notebook needs; each one is checked and installed if absent
packages = [
    'transformers',
    'timm',
    'torch',
    'torchvision',
    'pandas',
    'numpy',
    'scikit-learn',
    'Pillow',
]

def is_package_installed(package):
    """Check whether a package is available in the current environment.

    ``package`` is first looked up as an installed *distribution* (the name used
    with ``pip install``). This matters because a distribution name can differ
    from the module it provides: ``scikit-learn`` installs ``sklearn`` and
    ``Pillow`` installs ``PIL``, so an import-name lookup alone reports them as
    missing on every run. If no distribution of that name exists, the name is
    tried as an importable module so that plain module names keep working.

    Args:
        package: Distribution name or importable module name.

    Returns:
        True if a matching distribution or module is found, otherwise False.
    """
    # Imported here so the cell's top-level import list does not need to change.
    from importlib.metadata import PackageNotFoundError, version

    try:
        version(package)
        return True
    except (PackageNotFoundError, ValueError):
        # Not a known distribution name (or an empty name); try it as a module.
        pass

    try:
        return importlib.util.find_spec(package) is not None
    except (ImportError, ValueError):
        # find_spec raises for a missing parent package or an empty name.
        return False

def install_package(package):
    """Run ``pip install <package>`` using the interpreter that runs this kernel.

    Raises ``subprocess.CalledProcessError`` if pip exits with a non-zero status.
    """
    pip_command = [sys.executable, '-m', 'pip', 'install', package]
    subprocess.check_call(pip_command)

# Walk the requirement list: report what is already present, install the rest
for package in packages:
    if is_package_installed(package):
        print(f"{package} is already installed.")
        continue
    print(f"Installing {package}...")
    install_package(package)

print("All required packages are checked and installed if needed.")

transformers is already installed.
timm is already installed.
torch is already installed.
torchvision is already installed.
pandas is already installed.
numpy is already installed.
Installing scikit-learn...
Installing Pillow...
All required packages are checked and installed if needed.


In [None]:
import torch
import torchvision
import transformers
import timm
import pandas as pd
import numpy as np
from sklearn.model_selection import KFold
from PIL import Image

# Report the version of each core library so runs can be compared across environments
for _title, _library in (
    ("PyTorch", torch),
    ("Torchvision", torchvision),
    ("Transformers", transformers),
    ("TIMM", timm),
    ("Pandas", pd),
    ("NumPy", np),
):
    print(f"{_title} version:", _library.__version__)

PyTorch version: 2.6.0+cu124
Torchvision version: 0.21.0+cu124
Transformers version: 4.51.3
TIMM version: 1.0.15
Pandas version: 2.2.2
NumPy version: 2.0.2


In [None]:
import os
import requests
import zipfile
from pathlib import Path

# Directory the archive is extracted into (created up front so it can be listed later)
data_dir = Path('./ePillID_data')
data_dir.mkdir(exist_ok=True)

# Local archive path and the release asset it is downloaded from
zip_path = Path('./ePillID_data.zip')
url = 'https://github.com/usuyama/ePillID-benchmark/releases/download/ePillID_data_v1.0/ePillID_data.zip'

# Only download if the file doesn't exist
if not zip_path.exists():
    print("Downloading dataset...")
    # Stream into a ".part" file and rename it only after the last byte is written.
    # An interrupted download therefore never leaves a truncated archive at
    # zip_path, which the "already exists" check would otherwise accept next run.
    partial_path = zip_path.with_name(zip_path.name + '.part')
    try:
        # timeout bounds the connect and each read so a stalled server cannot hang the cell
        with requests.get(url, stream=True, timeout=60) as response:
            if response.status_code == 200:
                with open(partial_path, 'wb') as f:
                    # Write chunk by chunk instead of buffering the whole archive in memory
                    for chunk in response.iter_content(chunk_size=1024 * 1024):
                        f.write(chunk)
                partial_path.replace(zip_path)
                print(f"Download complete. File size: {zip_path.stat().st_size / (1024*1024):.2f} MB")
            else:
                print(f"ERROR: Failed to download file. Status code: {response.status_code}")
    finally:
        # Remove leftovers from a failed transfer; a no-op after a successful rename
        partial_path.unlink(missing_ok=True)
else:
    print(f"File already exists at {zip_path}. Skipping download.")

# Check if file exists and has reasonable size before extracting
if zip_path.exists():
    if zip_path.stat().st_size > 1000000:  # More than 1MB
        # Only extract if the data directory is missing or empty
        data_dir_is_empty = not data_dir.exists() or not os.listdir(data_dir)
        if data_dir_is_empty:
            print("Extracting ZIP file...")
            try:
                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                    zip_ref.extractall(data_dir)
                print("Extraction complete!")
            except zipfile.BadZipFile:
                print("ERROR: The file is not a valid ZIP file.")
        else:
            print(f"Data directory {data_dir} already has files. Skipping extraction.")
    else:
        print("ERROR: The file seems too small to be valid.")
else:
    print("ZIP file not found. Cannot extract.")

# Check the contents
if data_dir.exists() and os.listdir(data_dir):
    print(f"Files in data directory: {os.listdir(data_dir)}")
else:
    print("Data directory is empty or not found")

File already exists at ePillID_data.zip. Skipping download.
Data directory ePillID_data already has files. Skipping extraction.
Files in data directory: ['ePillID_data']


In [None]:
import os
import pandas as pd

# The archive unpacks into a nested folder of the same name, so point at the inner one
data_path = "./ePillID_data/ePillID_data"
metadata_file = os.path.join(data_path, "all_labels.csv")
metadata = pd.read_csv(metadata_file)

# Preview the first rows, then the column index
for _preview in (metadata.head(), metadata.columns):
    print(_preview)

     images             pilltype_id  label_code_id  prod_code_id  is_ref  \
0     0.jpg  51285-0092-87_BE305F72          51285            92   False   
1    10.jpg  00093-0148-01_4629A34D             93           148   False   
2   100.jpg  00093-7248-06_7829BC3D             93          7248   False   
3  1003.jpg  00093-0928-06_6926B4E5             93           928   False   
4  1004.jpg  50111-0459-01_1C300E70          50111           459   False   

   is_front  is_new                      image_path                   label  
0     False   False     fcn_mix_weight/dc_224/0.jpg  51285-0092-87_BE305F72  
1     False   False    fcn_mix_weight/dc_224/10.jpg  00093-0148-01_4629A34D  
2      True   False   fcn_mix_weight/dc_224/100.jpg  00093-7248-06_7829BC3D  
3     False   False  fcn_mix_weight/dc_224/1003.jpg  00093-0928-06_6926B4E5  
4      True   False  fcn_mix_weight/dc_224/1004.jpg  50111-0459-01_1C300E70  
Index(['images', 'pilltype_id', 'label_code_id', 'prod_code_id', 'is_ref',


In [None]:
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
from collections import defaultdict
import os
from pathlib import Path

# Preprocessing for the ViT-style backbone: fixed 224x224 input, tensor conversion,
# then per-channel normalization (presumably the standard ImageNet statistics - confirm)
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Custom Dataset
class EPillIDDataset(Dataset):
    """Map-style dataset over a metadata table of pill images.

    Item ``idx`` is the pair ``(image, label)``: the image is opened from the
    row's ``image_path`` value and converted to RGB, and the label is the row's
    ``pilltype_id`` value. Rows are addressed by position (``.iloc``).
    """

    def __init__(self, metadata, transform=None):
        # transform: optional callable applied to the loaded PIL image
        self.transform = transform
        self.metadata = metadata

    def __len__(self):
        # One item per metadata row
        return len(self.metadata)

    def __getitem__(self, idx):
        row = self.metadata.iloc[idx]
        image = Image.open(row["image_path"]).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, row["pilltype_id"]

# Partition the metadata on the is_ref flag and wrap each part in a dataset
reference_metadata = metadata.loc[metadata["is_ref"] == True]
consumer_metadata = metadata.loc[metadata["is_ref"] == False]
reference_dataset = EPillIDDataset(reference_metadata, transform=transform)
consumer_dataset = EPillIDDataset(consumer_metadata, transform=transform)
print(f"Reference images: {len(reference_dataset)}")
print(f"Consumer images: {len(consumer_dataset)}")

# Tally reference images per parent directory of image_path
print("\nReference images by directory:")
ref_dir_counts = defaultdict(int)
for img_path in reference_metadata["image_path"]:
    ref_dir_counts[os.path.dirname(img_path)] += 1

# Largest directories first (negated count keeps the sort stable for ties)
for dir_path, count in sorted(ref_dir_counts.items(), key=lambda entry: -entry[1]):
    print(f"- {dir_path} = {count} images")

# Same tally for the consumer images
print("\nConsumer images by directory:")
consumer_dir_counts = defaultdict(int)
for img_path in consumer_metadata["image_path"]:
    consumer_dir_counts[os.path.dirname(img_path)] += 1

for dir_path, count in sorted(consumer_dir_counts.items(), key=lambda entry: -entry[1]):
    print(f"- {dir_path} = {count} images")

Reference images: 9804
Consumer images: 3728

Reference images by directory:
- segmented_nih_pills_224 = 7804 images
- fcn_mix_weight/dr_224 = 2000 images

Consumer images by directory:
- fcn_mix_weight/dc_224 = 3728 images


In [None]:
# Cell 6 - update image paths in meta data to point to correct directories
# Update image paths in metadata
def update_image_paths(metadata, data_path="./ePillID_data/ePillID_data"):
    """Return a copy of ``metadata`` whose ``image_path`` points at files on disk.

    Each row's file name (``images`` column) is looked up under
    ``<data_path>/classification_data``:

    * reference rows (``is_ref`` truthy): ``fcn_mix_weight/dr_224`` first, then
      ``segmented_nih_pills_224``;
    * consumer rows: ``fcn_mix_weight/dc_224`` only.

    Rows whose file is not found in any candidate directory are dropped. The
    input frame is not modified. Counts of the surviving rows are printed.

    Args:
        metadata: DataFrame with at least the ``images`` and ``is_ref`` columns.
        data_path: Root of the extracted dataset. Defaults to the location used
            by this notebook.

    Returns:
        A new DataFrame containing only rows with a resolved ``image_path``.
    """
    classification_dir = os.path.join(data_path, "classification_data")
    dr_224_dir = os.path.join(classification_dir, "fcn_mix_weight", "dr_224")
    segmented_dir = os.path.join(classification_dir, "segmented_nih_pills_224")
    dc_224_dir = os.path.join(classification_dir, "fcn_mix_weight", "dc_224")

    # Search order per image kind; the first directory containing the file wins
    reference_dirs = (dr_224_dir, segmented_dir)
    consumer_dirs = (dc_224_dir,)

    def _resolve(img_name, is_ref):
        """Return the first existing candidate path for img_name, or None."""
        for directory in (reference_dirs if is_ref else consumer_dirs):
            candidate = os.path.join(directory, img_name)
            if os.path.exists(candidate):
                return candidate
        return None

    updated_metadata = metadata.copy()
    # Build the whole column in one pass and assign it once, instead of mutating
    # the frame cell by cell inside an iterrows() loop
    updated_metadata['image_path'] = [
        _resolve(img_name, is_ref)
        for img_name, is_ref in zip(updated_metadata['images'], updated_metadata['is_ref'])
    ]

    # Filter out images with no valid path
    valid_metadata = updated_metadata[updated_metadata['image_path'].notna()].copy()

    print(f"Total valid images: {len(valid_metadata)}")
    print(f"Valid reference images: {(valid_metadata['is_ref'] == True).sum()}")
    print(f"Valid consumer images: {(valid_metadata['is_ref'] == False).sum()}")

    return valid_metadata

# Resolve every row's image_path against the extracted dataset; update_image_paths
# returns a filtered copy, so rows whose image file is missing are absent from fixed_metadata
fixed_metadata = update_image_paths(metadata)

Total valid images: 13532
Valid reference images: 9804
Valid consumer images: 3728


In [None]:
# Cell 7: Label encoding with the combined dataset
from sklearn.preprocessing import LabelEncoder

# Map each pilltype_id string to an integer class index
label_encoder = LabelEncoder()
label_encoder.fit(fixed_metadata['pilltype_id'])
fixed_metadata['label_encoded'] = label_encoder.transform(fixed_metadata['pilltype_id'])
num_classes = len(fixed_metadata['pilltype_id'].unique())
print(f"Number of unique pill types: {num_classes}")

# Re-create the reference / consumer views so they carry the new label column
reference_metadata = fixed_metadata.loc[fixed_metadata["is_ref"] == True]
consumer_metadata = fixed_metadata.loc[fixed_metadata["is_ref"] == False]
print(f"Valid reference images: {len(reference_metadata)}")
print(f"Valid consumer images: {len(consumer_metadata)}")

# How many distinct labels each side has, and how many they share
print("\nLabel distribution:")
ref_labels = set(reference_metadata['label_encoded'])
con_labels = set(consumer_metadata['label_encoded'])
common_labels = ref_labels & con_labels
print(f"Unique labels in reference set: {len(ref_labels)}")
print(f"Unique labels in consumer set: {len(con_labels)}")
print(f"Labels in both sets: {len(common_labels)}")

# Show the first row of each side as a sanity check of the encoding
print("\nSample reference pill:")
print(reference_metadata[['pilltype_id', 'label_encoded']].iloc[:1].to_string(index=False))
print("\nSample consumer pill:")
print(consumer_metadata[['pilltype_id', 'label_encoded']].iloc[:1].to_string(index=False))

Number of unique pill types: 4902
Valid reference images: 9804
Valid consumer images: 3728

Label distribution:
Unique labels in reference set: 4902
Unique labels in consumer set: 960
Labels in both sets: 960

Sample reference pill:
           pilltype_id  label_encoded
00002-3228-30_391E1C80              0

Sample consumer pill:
           pilltype_id  label_encoded
51285-0092-87_BE305F72           2089


In [None]:
# Cell 8: Data split for training and validation
from sklearn.model_selection import train_test_split
import pandas as pd

# Split design used here:
# - consumer images are divided into training and validation sets
# - reference images are kept whole as the retrieval "gallery"

# Samples-per-class statistics for the consumer images
class_counts = consumer_metadata['label_encoded'].value_counts()
print("Distribution of samples per class:")
print(f"- Minimum samples per class: {class_counts.min()}")
print(f"- Maximum samples per class: {class_counts.max()}")
print(f"- Average samples per class: {class_counts.mean():.2f}")
print(f"- Classes with only 1 sample: {(class_counts == 1).sum()}")
print(f"- Total unique classes: {len(class_counts)}")

# Plain random 80/20 split with a fixed seed; no stratify argument is passed
# because some classes have a single sample
train_data, val_data = train_test_split(
    consumer_metadata,
    test_size=0.2,
    random_state=42,
)

print(f"\nTraining samples (consumer): {len(train_data)}")
print(f"Validation samples (consumer): {len(val_data)}")
print(f"Reference samples (gallery): {len(reference_metadata)}")

# Distinct encoded labels present in each partition
train_classes = set(train_data['label_encoded'])
val_classes = set(val_data['label_encoded'])
ref_classes = set(reference_metadata['label_encoded'])

print(f"\nUnique classes in training: {len(train_classes)}")
print(f"Unique classes in validation: {len(val_classes)}")
print(f"Unique classes in reference: {len(ref_classes)}")

# Pairwise class overlap between the partitions
train_val_overlap = train_classes & val_classes
train_ref_overlap = train_classes & ref_classes
val_ref_overlap = val_classes & ref_classes

print(f"\nClasses in both training and validation: {len(train_val_overlap)}")
print(f"Classes in both training and reference: {len(train_ref_overlap)}")
print(f"Classes in both validation and reference: {len(val_ref_overlap)}")

# Classes seen on only one side (consumer vs. reference)
consumer_only = (train_classes | val_classes) - ref_classes
reference_only = ref_classes - (train_classes | val_classes)

print(f"\nClasses that appear only in consumer images: {len(consumer_only)}")
print(f"Classes that appear only in reference images: {len(reference_only)}")

# Classes that have both a reference image and at least one consumer image
print(f"\nClasses with both reference and consumer images: {len(train_ref_overlap | val_ref_overlap)}")

# Classes the model is validated on without ever seeing a training sample
val_only_classes = val_classes - train_classes
if val_only_classes:
    print(f"\nWarning: {len(val_only_classes)} classes appear in validation but not training")

    # How many of those still have a reference image
    val_only_with_ref = val_only_classes & ref_classes
    print(f"- Of these, {len(val_only_with_ref)} have reference images")

    # Number of validation rows that belong to those classes
    val_only_samples = val_data[val_data['label_encoded'].isin(val_only_classes)]
    print(f"- Total samples for validation-only classes: {len(val_only_samples)}")

    # These rows are deliberately left in validation (not moved to training)

Distribution of samples per class:
- Minimum samples per class: 1
- Maximum samples per class: 5
- Average samples per class: 3.88
- Classes with only 1 sample: 54
- Total unique classes: 960

Training samples (consumer): 2982
Validation samples (consumer): 746
Reference samples (gallery): 9804

Unique classes in training: 941
Unique classes in validation: 541
Unique classes in reference: 4902

Classes in both training and validation: 522
Classes in both training and reference: 941
Classes in both validation and reference: 541

Classes that appear only in consumer images: 0
Classes that appear only in reference images: 3942

Classes with both reference and consumer images: 960

- Of these, 19 have reference images
- Total samples for validation-only classes: 24


In [None]:
# Cell 9: Dataset and DataLoader setup with optimized dataset class
from torch.utils.data import Dataset, DataLoader
import torch
from torchvision import transforms
from PIL import Image
import os

# Training pipeline: resize, then light augmentation (horizontal flip, rotation of
# up to 10 degrees, mild colour jitter) before tensor conversion and normalization
train_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Validation / reference pipeline: deterministic, resize and normalization only
val_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

class OptimizedEPillIDDataset(Dataset):
    """Dataset that resolves valid image paths and labels once, at construction.

    ``__init__`` scans the ``image_path`` and ``label_encoded`` columns of
    ``metadata`` and keeps only rows whose path is a real path value pointing
    at an existing file. Each item is ``(image, label, metadata_idx)`` where
    ``metadata_idx`` is the row's *position* in ``metadata`` (not its index
    label).

    Args:
        metadata: DataFrame with ``image_path`` and ``label_encoded`` columns.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, metadata, transform=None):
        self.metadata = metadata
        self.transform = transform
        self.valid_indices = []  # positions in metadata of the kept rows
        self.image_paths = []    # path of each kept row
        self.labels = []         # encoded label of each kept row

        # Pull both columns out once instead of materialising a row Series
        # (twice) for every entry via .iloc
        paths = self.metadata["image_path"].tolist()
        labels = self.metadata["label_encoded"].tolist()
        for idx, (img_path, label) in enumerate(zip(paths, labels)):
            # Missing paths (None / NaN) are skipped rather than being passed to
            # os.path.exists, which raises TypeError for non-path values
            if isinstance(img_path, (str, os.PathLike)) and os.path.exists(img_path):
                self.valid_indices.append(idx)
                self.image_paths.append(img_path)
                self.labels.append(label)

        print(f"Dataset has {len(self.valid_indices)} valid images out of {len(self.metadata)} entries")

    def __len__(self):
        # Only rows whose file exists count as items
        return len(self.valid_indices)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        metadata_idx = self.valid_indices[idx]

        try:
            # The context manager closes the file handle as soon as the pixel
            # data has been copied by convert(), instead of leaving it open
            # until garbage collection
            with Image.open(img_path) as raw_image:
                image = raw_image.convert("RGB")
            if self.transform:
                image = self.transform(image)

            return image, label, metadata_idx
        except Exception as e:
            # Best effort: report the failure and return an all-zero image so a
            # single unreadable file does not abort the whole epoch.
            # NOTE(review): the placeholder still carries the real label.
            print(f"Error loading {img_path}: {e}")
            placeholder = torch.zeros(3, 224, 224)
            return placeholder, label, metadata_idx

# Wrap each split: augmentation for training only; validation and the reference
# gallery share the deterministic pipeline
train_dataset = OptimizedEPillIDDataset(metadata=train_data, transform=train_transform)
val_dataset = OptimizedEPillIDDataset(metadata=val_data, transform=val_transform)
reference_dataset = OptimizedEPillIDDataset(metadata=reference_metadata, transform=val_transform)

# One summary line per split
for _split_name, _split in (
    ("Training", train_dataset),
    ("Validation", val_dataset),
    ("Reference", reference_dataset),
):
    print(f"{_split_name} dataset: {len(_split)} images")

Dataset has 2982 valid images out of 2982 entries
Dataset has 746 valid images out of 746 entries
Dataset has 9804 valid images out of 9804 entries
Training dataset: 2982 images
Validation dataset: 746 images
Reference dataset: 9804 images


In [None]:
import torch
from torch.utils.data import DataLoader
import os
import psutil
import time

# Print system information for diagnostics
print(f"CPU count: {os.cpu_count()} logical processors")
print(f"Available memory: {psutil.virtual_memory().available / (1024**3):.2f} GB")

# Configure performance parameters - optimized for Colab
batch_size = 32  # Reduced for memory efficiency
# Never request more workers than there are logical CPUs: oversubscribing makes
# PyTorch emit a warning and can slow loading down (os.cpu_count() may be None)
num_workers = min(4, os.cpu_count() or 1)
prefetch_factor = 2  # Reduced to minimize memory usage
persistent_workers = True


def _make_loader(dataset, shuffle):
    """Build a DataLoader for ``dataset`` using the shared settings above.

    ``prefetch_factor`` and ``persistent_workers`` only apply to worker
    processes, so both are switched off when ``num_workers`` is 0.
    """
    use_workers = num_workers > 0
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        prefetch_factor=prefetch_factor if use_workers else None,
        persistent_workers=persistent_workers and use_workers,
        drop_last=False,
    )


# Only the training loader shuffles; validation and reference keep dataset order
train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)
reference_loader = _make_loader(reference_dataset, shuffle=False)

# Print information about the data loaders
print(f"Training batches: {len(train_loader)} (batch size: {batch_size})")
print(f"Validation batches: {len(val_loader)} (batch size: {batch_size})")
print(f"Reference batches: {len(reference_loader)} (batch size: {batch_size})")

# Check disk I/O performance
# NOTE(review): the file is read back right after being written, so the figure
# most likely reflects the OS page cache rather than the physical disk.
test_file = "temp_test_file.bin"
try:
    print("\nTesting disk read speed (this takes a few seconds)...")
    test_size_mb = 100
    with open(test_file, "wb") as f:
        f.write(b'0' * 1024 * 1024 * test_size_mb)

    with open(test_file, "rb") as f:
        start_read = time.perf_counter()
        f.read()
        end_read = time.perf_counter()

    read_speed = test_size_mb / (end_read - start_read)
    print(f"Disk read speed: {read_speed:.2f} MB/s")
except Exception as e:
    print(f"Couldn't test disk speed: {e}")
finally:
    # Always clean up, including when the write or read above failed part-way
    if os.path.exists(test_file):
        os.remove(test_file)

# Check first batch timing (includes worker start-up cost)
print("\nTiming first batch loading...")
start_time = time.time()
dataloader_iterator = iter(train_loader)
first_batch = next(dataloader_iterator)
end_time = time.time()
images, labels, indices = first_batch

print(f"Time to load first batch: {end_time - start_time:.2f} seconds")
print(f"Batch shape: {images.shape}")
print(f"Labels shape: {labels.shape}")

CPU count: 2 logical processors
Available memory: 10.64 GB
Training batches: 94 (batch size: 32)
Validation batches: 24 (batch size: 32)
Reference batches: 307 (batch size: 32)

Testing disk read speed (this takes a few seconds)...
Disk read speed: 1722.51 MB/s





Timing first batch loading...
Time to load first batch: 1.09 seconds
Batch shape: torch.Size([32, 3, 224, 224])
Labels shape: torch.Size([32])


CELL 11: THIS IS TO DEFINE THE MODEL:


In [None]:
import torch
import torch.nn as nn
import timm

# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Define the Vision Transformer model (Swin Transformer)
model_name = "swin_base_patch4_window7_224"
num_classes = len(fixed_metadata['pilltype_id'].unique())  # one output per pill type

# Initialize the pretrained backbone with a fresh classification head and average pooling
model = timm.create_model(
    model_name,
    pretrained=True,
    num_classes=num_classes,
    global_pool='avg'  # Add average pooling to flatten features
)

# Move model to device
model = model.to(device)
print(f"Loaded {model_name} with {num_classes} output classes")

# Verify model output shape
dummy_input = torch.randn(1, 3, 224, 224).to(device)
with torch.no_grad():
    output = model(dummy_input)
    print(f"Model output shape: {output.shape}")  # Should be [1, num_classes]

# Initialize head weights to reduce initial logit magnitudes.
# In timm, a model's `head` attribute can be a wrapper module rather than the
# nn.Linear itself, in which case an isinstance check on `model.head` silently
# skips the initialization. Ask the model for its classifier layer instead and
# fall back to `head` for models without get_classifier().
if hasattr(model, 'get_classifier'):
    classifier = model.get_classifier()
else:
    classifier = getattr(model, 'head', None)

with torch.no_grad():
    if isinstance(classifier, nn.Linear):
        classifier.weight.normal_(mean=0.0, std=0.01)
        if classifier.bias is not None:
            classifier.bias.zero_()
    else:
        # Make the skip visible instead of silently leaving the default init
        print(f"Warning: classifier is {type(classifier).__name__}, not nn.Linear; head initialization skipped")

Using device: cuda


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loaded swin_base_patch4_window7_224 with 4902 output classes
Model output shape: torch.Size([1, 4902])


CELL 12: THIS IS TO DEFINE LOSS FUNCTION AND OPTIMIZER

1.   CrossEntropyLoss is used for classification pre-training.
2.   AdamW is suitable for transformer models, with a low learning rate for fine-tuning.
3.   The cosine annealing scheduler helps with convergence.

In [None]:
import torch.optim as optim
import torch.nn as nn

# Classification objective over the encoded pill-type labels
criterion = nn.CrossEntropyLoss()

# AdamW over every model parameter
optimizer = optim.AdamW(params=model.parameters(), lr=1e-4, weight_decay=0.01)

# Cosine learning-rate schedule with a half-period (T_max) of 10 scheduler steps
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=10)

print("Loss function and optimizer initialized")

Loss function and optimizer initialized


CELL 13: THIS IS FOR THE TRAINING LOOP:



*   This loop trains for classification, which helps the model learn discriminative features.
*   The model is saved when validation accuracy improves.
*   Adjust num_epochs based on your computational resources (start with 5–10 epochs).



In [None]:
from tqdm import tqdm
import torch
from torch.amp import GradScaler, autocast
import gc
import torch.nn.utils as nn_utils

def train_epoch(model, loader, criterion, optimizer, scaler, device, accum_steps=4):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    optimizer.zero_grad()

    for i, (images, labels, _) in enumerate(tqdm(loader, desc="Training")):
        images, labels = images.to(device), labels.to(device, dtype=torch.long)

        # Debug shapes
        if i == 0:
            print(f"Batch {i+1}: Images shape: {images.shape}, Labels shape: {labels.shape}")

        with autocast('cuda'):
            outputs = model(images)
            if i == 0:
                print(f"Batch {i+1}: Outputs shape: {outputs.shape}")
            if len(outputs.shape) != 2 or outputs.shape[1] != 4902:
                raise ValueError(f"Unexpected output shape: {outputs.shape}, expected [batch_size, 4902]")
            loss = criterion(outputs, labels) / accum_steps

        # Check for inf/NaN loss
        if not torch.isfinite(loss):
            print(f"Warning: Skipping batch {i+1} due to inf/NaN loss: {loss.item()}")
            optimizer.zero_grad()  # Clear gradients
            continue

        # Debug loss value
        if i == 0:
            print(f"Batch {i+1}: Loss value: {loss.item() * accum_steps}")

        scaler.scale(loss).backward()

        if (i + 1) % accum_steps == 0 or (i + 1) == len(loader):
            # Unscale gradients
            scaler.unscale_(optimizer)

            # Gradient clipping with stronger norm
            grad_norm = nn_utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
            if i % 10 == 0 or (i + 1) == len(loader):
                print(f"Batch {i+1}: Gradient norm: {grad_norm:.4f}")

            # Check for inf/NaN gradients
            has_inf_gradients = False
            for param in model.parameters():
                if param.grad is not None and not torch.isfinite(param.grad).all():
                    has_inf_gradients = True
                    break
            if has_inf_gradients:
                print(f"Warning: Inf/NaN gradients in batch {i+1}, skipping optimizer step")
                scaler.update()
                optimizer.zero_grad()
                continue

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        running_loss += loss.item() * accum_steps
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / len(loader) if total > 0 else float('inf')
    epoch_acc = 100 * correct / total if total > 0 else 0.0
    return epoch_loss, epoch_acc

def validate_epoch(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for i, (images, labels, _) in enumerate(tqdm(loader, desc="Validation")):
            images, labels = images.to(device), labels.to(device, dtype=torch.long)
            with autocast('cuda'):
                outputs = model(images)
                if len(outputs.shape) != 2 or outputs.shape[1] != 4902:
                    raise ValueError(f"Unexpected output shape: {outputs.shape}, expected [batch_size, 4902]")
                loss = criterion(outputs, labels)

            if not torch.isfinite(loss):
                print(f"Warning: Skipping validation batch {i+1} due to inf/NaN loss")
                continue

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / len(loader) if total > 0 else float('inf')
    epoch_acc = 100 * correct / total if total > 0 else 0.0
    return epoch_loss, epoch_acc

# Free Python garbage and cached CUDA blocks before training starts
gc.collect()
torch.cuda.empty_cache()

# Loss scaler for mixed-precision training
scaler = GradScaler('cuda')

# Run configuration: epoch count, best validation accuracy so far, checkpoint path
num_epochs = 10
best_val_acc = 0.0
save_path = "./best_swin_model.pth"

for epoch in range(num_epochs):
    # One training pass, one validation pass, then advance the LR schedule
    train_loss, train_acc = train_epoch(
        model, train_loader, criterion, optimizer, scaler, device, accum_steps=4
    )
    val_loss, val_acc = validate_epoch(model, val_loader, criterion, device)
    scheduler.step()

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

    # Checkpoint only on a strict improvement in validation accuracy
    if not (val_acc > best_val_acc):
        continue
    best_val_acc = val_acc
    torch.save(model.state_dict(), save_path)
    print(f"Saved best model with Val Acc: {best_val_acc:.2f}%")

print("Training complete")

Training:   0%|          | 0/94 [00:00<?, ?it/s]

Batch 1: Images shape: torch.Size([32, 3, 224, 224]), Labels shape: torch.Size([32])
Batch 1: Outputs shape: torch.Size([32, 4902])
Batch 1: Loss value: 8.5396728515625


Training: 100%|██████████| 94/94 [00:40<00:00,  2.29it/s]


Batch 94: Gradient norm: 6.8740


Validation: 100%|██████████| 24/24 [00:04<00:00,  5.63it/s]


Epoch 1/10
Train Loss: 8.2233, Train Acc: 0.17%
Val Loss: 7.5748, Val Acc: 0.67%
Saved best model with Val Acc: 0.67%


Training:   0%|          | 0/94 [00:00<?, ?it/s]

Batch 1: Images shape: torch.Size([32, 3, 224, 224]), Labels shape: torch.Size([32])
Batch 1: Outputs shape: torch.Size([32, 4902])
Batch 1: Loss value: 6.8516845703125


Training: 100%|██████████| 94/94 [00:42<00:00,  2.21it/s]


Batch 94: Gradient norm: 8.8354


Validation: 100%|██████████| 24/24 [00:03<00:00,  6.10it/s]


Epoch 2/10
Train Loss: 6.3940, Train Acc: 3.55%
Val Loss: 6.4321, Val Acc: 3.08%
Saved best model with Val Acc: 3.08%


Training:   0%|          | 0/94 [00:00<?, ?it/s]

Batch 1: Images shape: torch.Size([32, 3, 224, 224]), Labels shape: torch.Size([32])
Batch 1: Outputs shape: torch.Size([32, 4902])
Batch 1: Loss value: 5.01336669921875


Training:  13%|█▎        | 12/94 [00:06<00:35,  2.32it/s]



Training: 100%|██████████| 94/94 [00:41<00:00,  2.26it/s]


Batch 94: Gradient norm: 9.6763


Validation: 100%|██████████| 24/24 [00:03<00:00,  6.12it/s]


Epoch 3/10
Train Loss: 4.8531, Train Acc: 15.93%
Val Loss: 5.5766, Val Acc: 8.71%
Saved best model with Val Acc: 8.71%


Training:   0%|          | 0/94 [00:00<?, ?it/s]

Batch 1: Images shape: torch.Size([32, 3, 224, 224]), Labels shape: torch.Size([32])
Batch 1: Outputs shape: torch.Size([32, 4902])
Batch 1: Loss value: 3.76312255859375


Training:  51%|█████     | 48/94 [00:22<00:20,  2.29it/s]



Training: 100%|██████████| 94/94 [00:42<00:00,  2.21it/s]


Batch 94: Gradient norm: 12.9797


Validation: 100%|██████████| 24/24 [00:03<00:00,  6.19it/s]


Epoch 4/10
Train Loss: 3.6484, Train Acc: 35.12%
Val Loss: 4.8710, Val Acc: 14.08%
Saved best model with Val Acc: 14.08%


Training:   0%|          | 0/94 [00:00<?, ?it/s]

Batch 1: Images shape: torch.Size([32, 3, 224, 224]), Labels shape: torch.Size([32])
Batch 1: Outputs shape: torch.Size([32, 4902])
Batch 1: Loss value: 3.1591529846191406


Training: 100%|██████████| 94/94 [00:42<00:00,  2.22it/s]


Batch 94: Gradient norm: 9.0046


Validation: 100%|██████████| 24/24 [00:03<00:00,  6.18it/s]


Epoch 5/10
Train Loss: 2.7001, Train Acc: 54.39%
Val Loss: 4.2958, Val Acc: 20.11%
Saved best model with Val Acc: 20.11%


Training:   0%|          | 0/94 [00:00<?, ?it/s]

Batch 1: Images shape: torch.Size([32, 3, 224, 224]), Labels shape: torch.Size([32])
Batch 1: Outputs shape: torch.Size([32, 4902])
Batch 1: Loss value: 1.5195465087890625


Training: 100%|██████████| 94/94 [00:41<00:00,  2.25it/s]


Batch 94: Gradient norm: 8.4424


Validation: 100%|██████████| 24/24 [00:03<00:00,  6.06it/s]


Epoch 6/10
Train Loss: 2.0049, Train Acc: 68.24%
Val Loss: 3.8728, Val Acc: 27.61%
Saved best model with Val Acc: 27.61%


Training:   0%|          | 0/94 [00:00<?, ?it/s]

Batch 1: Images shape: torch.Size([32, 3, 224, 224]), Labels shape: torch.Size([32])
Batch 1: Outputs shape: torch.Size([32, 4902])
Batch 1: Loss value: 1.4219207763671875


Training: 100%|██████████| 94/94 [00:41<00:00,  2.25it/s]


Batch 94: Gradient norm: 7.3370


Validation: 100%|██████████| 24/24 [00:03<00:00,  6.10it/s]


Epoch 7/10
Train Loss: 1.5091, Train Acc: 78.30%
Val Loss: 3.5670, Val Acc: 32.84%
Saved best model with Val Acc: 32.84%


Training:   0%|          | 0/94 [00:00<?, ?it/s]

Batch 1: Images shape: torch.Size([32, 3, 224, 224]), Labels shape: torch.Size([32])
Batch 1: Outputs shape: torch.Size([32, 4902])
Batch 1: Loss value: 1.2771577835083008


Training: 100%|██████████| 94/94 [00:41<00:00,  2.25it/s]


Batch 94: Gradient norm: 10.3836


Validation: 100%|██████████| 24/24 [00:03<00:00,  6.10it/s]


Epoch 8/10
Train Loss: 1.1924, Train Acc: 84.91%
Val Loss: 3.4411, Val Acc: 36.06%
Saved best model with Val Acc: 36.06%


Training:   0%|          | 0/94 [00:00<?, ?it/s]

Batch 1: Images shape: torch.Size([32, 3, 224, 224]), Labels shape: torch.Size([32])
Batch 1: Outputs shape: torch.Size([32, 4902])
Batch 1: Loss value: 1.1783485412597656


Training: 100%|██████████| 94/94 [00:42<00:00,  2.23it/s]


Batch 94: Gradient norm: 9.7003


Validation: 100%|██████████| 24/24 [00:03<00:00,  6.35it/s]


Epoch 9/10
Train Loss: 1.0584, Train Acc: 87.96%
Val Loss: 3.3672, Val Acc: 37.80%
Saved best model with Val Acc: 37.80%


Training:   0%|          | 0/94 [00:00<?, ?it/s]

Batch 1: Images shape: torch.Size([32, 3, 224, 224]), Labels shape: torch.Size([32])
Batch 1: Outputs shape: torch.Size([32, 4902])
Batch 1: Loss value: 1.4635391235351562


Training: 100%|██████████| 94/94 [00:41<00:00,  2.24it/s]


Batch 94: Gradient norm: 7.9621


Validation: 100%|██████████| 24/24 [00:03<00:00,  6.13it/s]

Epoch 10/10
Train Loss: 0.9679, Train Acc: 89.70%
Val Loss: 3.3498, Val Acc: 37.53%
Training complete





CELL 14: THIS IS TO EXTRACT EMBEDDINGS TO COMPUTE SIMILARITY

In [None]:
import torch
import torch.nn.functional as F
from tqdm import tqdm

def extract_embeddings(model, loader, device):
    """Run ``model.forward_features`` over every batch and collect flat embeddings.

    The model is put in eval mode and gradients are disabled for the whole pass.
    Each batch's feature map is flattened to one row per sample (for the Swin
    backbone used here, a [batch, 7, 7, 1024] map becomes [batch, 7*7*1024]).

    Args:
        model: Network exposing ``forward_features`` (features before the classifier head).
        loader: Iterable yielding ``(images, labels, indices)`` tensor batches.
        device: Device the images are moved to before the forward pass.

    Returns:
        Tuple ``(embeddings, labels, indices)``; each is the per-batch tensors
        concatenated along dim 0. Embeddings are moved to the CPU.
    """
    model.eval()
    feature_chunks, label_chunks, index_chunks = [], [], []

    with torch.no_grad():
        for batch_images, batch_labels, batch_indices in tqdm(loader, desc="Extracting embeddings"):
            feature_map = model.forward_features(batch_images.to(device))
            # One row per sample: collapse every non-batch dimension
            flat_features = feature_map.reshape(feature_map.size(0), -1)

            feature_chunks.append(flat_features.cpu())
            label_chunks.append(batch_labels)
            index_chunks.append(batch_indices)

    return (
        torch.cat(feature_chunks, dim=0),
        torch.cat(label_chunks, dim=0),
        torch.cat(index_chunks, dim=0),
    )

# Extract embeddings for reference and validation sets
reference_embeddings, reference_labels, reference_indices = extract_embeddings(model, reference_loader, device)
val_embeddings, val_labels, val_indices = extract_embeddings(model, val_loader, device)

# Fail fast on malformed embeddings: the retrieval metrics computed from these tensors
# multiply them as [num_samples, feature_dim] matrices with one label per row.
assert reference_embeddings.dim() == 2, f"reference embeddings must be 2-D, got {tuple(reference_embeddings.shape)}"
assert val_embeddings.dim() == 2, f"validation embeddings must be 2-D, got {tuple(val_embeddings.shape)}"
assert reference_embeddings.size(1) == val_embeddings.size(1), "reference/validation feature sizes differ"
assert reference_embeddings.size(0) == reference_labels.size(0), "reference embeddings/labels length mismatch"
assert val_embeddings.size(0) == val_labels.size(0), "validation embeddings/labels length mismatch"

print(f"Reference embeddings shape: {reference_embeddings.shape}")
print(f"Validation embeddings shape: {val_embeddings.shape}")

Extracting embeddings: 100%|██████████| 307/307 [02:01<00:00,  2.53it/s]
Extracting embeddings: 100%|██████████| 24/24 [00:09<00:00,  2.52it/s]

Reference embeddings shape: torch.Size([9804, 7, 7, 1024])
Validation embeddings shape: torch.Size([746, 7, 7, 1024])





CELL 15: THIS IS TO COMPUTE METRICS

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm

def compute_cosine_similarity(query_emb, gallery_emb):
    """Return the pairwise cosine similarity between query and gallery embeddings.

    Args:
        query_emb: Tensor with one query per row along dim 0. Any trailing feature
            dimensions (e.g. an unpooled [N, 7, 7, 1024] feature map) are flattened
            into a single feature vector per sample.
        gallery_emb: Tensor with one gallery item per row along dim 0, holding the
            same number of features per sample as ``query_emb``.

    Returns:
        Tensor of shape [num_queries, num_gallery]; entry (i, j) is the cosine
        similarity between query i and gallery item j.
    """
    # Collapse everything after the batch dim. On inputs with more than two dims the
    # previous code normalized only the last axis and then used `.T`, which reverses
    # *all* dims, so the matmul failed with a shape-mismatch RuntimeError.
    query_emb = query_emb.flatten(start_dim=1)
    gallery_emb = gallery_emb.flatten(start_dim=1)

    # Normalize embeddings for cosine similarity
    query_emb = F.normalize(query_emb, p=2, dim=1)
    gallery_emb = F.normalize(gallery_emb, p=2, dim=1)

    # Dot product between unit vectors == cosine similarity
    return torch.matmul(query_emb, gallery_emb.t())

def compute_retrieval_metrics(query_emb, query_labels, gallery_emb, gallery_labels,
                              k_values=(2, 3, 4, 5), confusion_k=5):
    """Score label retrieval of every query against a labelled gallery.

    Gallery items are ranked per query by cosine similarity (highest first). A query
    counts as a top-k hit when at least one of its k best-ranked gallery items carries
    the query's label.

    Args:
        query_emb: Query embeddings, one row per query.
        query_labels: Tensor holding one label per query.
        gallery_emb: Gallery embeddings, one row per gallery item.
        gallery_labels: Tensor holding one label per gallery item.
        k_values: Cut-offs (in addition to top-1) for which hit rates are reported.
        confusion_k: Cut-off used for the TP/FP/TN/FN counts (default 5, the value
            that was previously hard-coded).

    Returns:
        Tuple ``(top_1_accuracy, top_k_accuracy, confusion_metrics)``:
        a float, a dict mapping each k in ``k_values`` to its hit rate, and a dict
        with the keys 'TP', 'FP', 'TN', 'FN' summed over all queries (plain ints).
        With zero queries all accuracies are 0.0.
    """
    similarity = compute_cosine_similarity(query_emb, gallery_emb)
    num_queries = query_emb.size(0)

    # Loop-invariant conversions, done once instead of once (or three times) per query
    scores_all = similarity.cpu().numpy()
    gallery_labels_np = gallery_labels.cpu().numpy()
    gallery_size = gallery_labels_np.shape[0]
    # No more than the whole gallery can be retrieved, even if confusion_k is larger
    num_retrieved = min(confusion_k, gallery_size)

    top_1_correct = 0
    top_k_correct = {k: 0 for k in k_values}
    total_tp, total_fp, total_tn, total_fn = 0, 0, 0, 0

    for i in tqdm(range(num_queries), desc="Computing metrics"):
        query_label = query_labels[i].item()
        ranking = np.argsort(scores_all[i])[::-1]  # Descending similarity

        # hit_mask[r] is True when the gallery item at rank r has the query's label
        hit_mask = gallery_labels_np[ranking] == query_label

        # Top-1 Accuracy (slice instead of [0] so an empty gallery does not raise)
        if hit_mask[:1].any():
            top_1_correct += 1

        # Top-k Accuracy
        for k in k_values:
            if hit_mask[:k].any():
                top_k_correct[k] += 1

        # Confusion counts at the confusion_k cut-off
        tp = int(hit_mask[:num_retrieved].sum())   # Relevant and retrieved
        fp = num_retrieved - tp                    # Retrieved but not relevant
        num_relevant = int(hit_mask.sum())         # Gallery items sharing the label
        fn = num_relevant - tp                     # Relevant but not retrieved
        # Exact count of non-relevant items left out of the top-k: every
        # non-relevant item is either retrieved (an FP) or not (a TN).
        tn = (gallery_size - num_relevant) - fp

        total_tp += tp
        total_fp += fp
        total_tn += tn
        total_fn += fn

    def _rate(correct):
        # Guard against an empty query set instead of raising ZeroDivisionError
        return correct / num_queries if num_queries else 0.0

    top_1_accuracy = _rate(top_1_correct)
    top_k_accuracy = {k: _rate(correct) for k, correct in top_k_correct.items()}

    # Aggregate confusion matrix metrics
    confusion_metrics = {
        'TP': total_tp,
        'FP': total_fp,
        'TN': total_tn,
        'FN': total_fn
    }

    return top_1_accuracy, top_k_accuracy, confusion_metrics

# Show the shapes that go into the similarity computation
print(f"Checking reference embeddings shape: {reference_embeddings.shape}")
print(f"Checking validation embeddings shape: {val_embeddings.shape}")

# Validation images are the queries; reference images form the gallery
top_1_accuracy, top_k_accuracy, confusion_metrics = compute_retrieval_metrics(
    query_emb=val_embeddings,
    query_labels=val_labels,
    gallery_emb=reference_embeddings,
    gallery_labels=reference_labels,
    k_values=[2, 3, 4, 5],
)

# Report hit rates first, then the aggregated confusion counts
print(f"Top-1 Accuracy: {top_1_accuracy:.4f}")
for k, acc in top_k_accuracy.items():
    print(f"Top-{k} Accuracy: {acc:.4f}")
print("\nConfusion Matrix Metrics (k=5):")
print(f"True Positives: {confusion_metrics['TP']}")
print(f"False Positives: {confusion_metrics['FP']}")
print(f"True Negatives: {confusion_metrics['TN']}")
print(f"False Negatives: {confusion_metrics['FN']}")

  similarity = torch.matmul(query_emb, gallery_emb.T)


RuntimeError: The size of tensor a (746) must match the size of tensor b (1024) at non-singleton dimension 0

CELL 16: THIS IS SHAPE-BASED PERFORMANCE ANALYSIS

In [None]:
import pandas as pd
import numpy as np

# Placeholder shape column (replace with actual shape data if available).
# WARNING: when the metadata has no 'shape' column, shapes are assigned at random below,
# so the per-shape numbers only exercise the pipeline and are not a real analysis.
shape_types = ['round', 'capsule', 'oval', 'square']  # Example shapes
if 'shape' not in fixed_metadata.columns:
    np.random.seed(42)
    fixed_metadata['shape'] = np.random.choice(shape_types, size=len(fixed_metadata))

# Filter validation metadata
val_metadata = fixed_metadata[fixed_metadata.index.isin(val_data.index)]

# Cut-offs reported per shape, defined once so the metrics call, the result keys
# and the printout cannot drift apart.
shape_k_values = [2, 3, 4, 5]
# Loop-invariant: convert the validation metadata indices to NumPy a single time
val_indices_np = val_indices.cpu().numpy()

# Compute metrics by shape
shape_results = {}
for shape in shape_types:
    # Select the validation embeddings and labels whose metadata row has this shape
    shape_indices = val_metadata[val_metadata['shape'] == shape].index
    valid_mask = np.isin(val_indices_np, shape_indices)
    # Vectorized count, computed once (the built-in sum() iterates the mask in Python)
    num_samples = int(valid_mask.sum())

    if num_samples == 0:
        print(f"No validation samples for shape: {shape}")
        continue

    shape_val_emb = val_embeddings[valid_mask]
    shape_val_labels = val_labels[valid_mask]

    # Compute metrics using the function from Cell 15
    top_1_acc, top_k_acc, conf_metrics = compute_retrieval_metrics(
        shape_val_emb, shape_val_labels,
        reference_embeddings, reference_labels,
        k_values=shape_k_values
    )

    shape_results[shape] = {
        'Top-1 Accuracy': top_1_acc,
        # Expands to 'Top-2 Accuracy' ... 'Top-5 Accuracy' in cut-off order
        **{f'Top-{k} Accuracy': top_k_acc[k] for k in shape_k_values},
        'True Positives': conf_metrics['TP'],
        'False Positives': conf_metrics['FP'],
        'True Negatives': conf_metrics['TN'],
        'False Negatives': conf_metrics['FN'],
        'Num Samples': num_samples
    }

    print(f"Shape: {shape} ({num_samples} samples)")
    print(f"  Top-1 Accuracy: {top_1_acc:.4f}")
    for k in shape_k_values:
        print(f"  Top-{k} Accuracy: {top_k_acc[k]:.4f}")
    print(f"  Confusion Matrix (k=5): TP={conf_metrics['TP']}, FP={conf_metrics['FP']}, TN={conf_metrics['TN']}, FN={conf_metrics['FN']}")

# Save results to CSV (one row per shape)
results_df = pd.DataFrame.from_dict(shape_results, orient='index')
results_df.to_csv("shape_based_results.csv")
print("Shape-based results saved to shape_based_results.csv")

CELL 17: THIS IS TO COMPILE AND SAVE METRICS TO A JSON FILE

In [None]:
import json
import torch

def _shape_report_entry(shape_metrics):
    """Convert one per-shape metrics dict into built-in floats/ints, nested like 'overall'."""
    return {
        'Top-1 Accuracy': float(shape_metrics['Top-1 Accuracy']),
        'Top-k Accuracy': {
            '2': float(shape_metrics['Top-2 Accuracy']),
            '3': float(shape_metrics['Top-3 Accuracy']),
            '4': float(shape_metrics['Top-4 Accuracy']),
            '5': float(shape_metrics['Top-5 Accuracy'])
        },
        'Confusion Metrics': {
            'True Positives': int(shape_metrics['True Positives']),
            'False Positives': int(shape_metrics['False Positives']),
            'True Negatives': int(shape_metrics['True Negatives']),
            'False Negatives': int(shape_metrics['False Negatives'])
        },
        'Num Samples': int(shape_metrics['Num Samples'])
    }

# Gather the overall and per-shape metrics, cast to built-in float/int before writing JSON
results = {
    'overall': {
        'Top-1 Accuracy': float(top_1_accuracy),
        'Top-k Accuracy': {str(cutoff): float(rate) for cutoff, rate in top_k_accuracy.items()},
        'Confusion Metrics': {
            'True Positives': int(confusion_metrics['TP']),
            'False Positives': int(confusion_metrics['FP']),
            'True Negatives': int(confusion_metrics['TN']),
            'False Negatives': int(confusion_metrics['FN'])
        },
        'Num Validation Samples': len(val_dataset)
    },
    'by_shape': {
        shape_name: _shape_report_entry(shape_metrics)
        for shape_name, shape_metrics in shape_results.items()
    }
}

# Write the report to disk
with open('evaluation_results.json', 'w') as results_file:
    json.dump(results, results_file, indent=4)
print("Results saved to evaluation_results.json")

# Persist the weights of the model held in `model`
final_model_path = "./final_swin_model.pth"
torch.save(model.state_dict(), final_model_path)
print(f"Final model saved to {final_model_path}")

CELL 18: THIS IS TO VISUALIZE RESULTS

In [None]:
import matplotlib.pyplot as plt
from PIL import Image

def visualize_retrieval(query_idx, val_dataset, reference_dataset, similarity, top_k=5):
    """Plot one validation query next to its ``top_k`` most similar reference images.

    The figure is saved as ``retrieval_example_{query_idx}.png``, shown, and closed.
    Pill ids for the titles are looked up in the notebook-level ``fixed_metadata``
    frame by the metadata index each dataset item returns.

    Args:
        query_idx: Position of the query in ``val_dataset`` and row of ``similarity``.
        val_dataset: Dataset returning ``(image, label, metadata_index)`` per item.
        reference_dataset: Dataset with the same item layout; indexed by the
            column positions of ``similarity``.
        similarity: [num_queries, num_reference] similarity tensor.
            NOTE(review): assumes rows/columns follow the dataset order, i.e. the
            loaders used for embedding extraction were not shuffled - confirm.
        top_k: Number of best matches to display.
    """
    query_img, _, query_meta_idx = val_dataset[query_idx]
    query_pill_id = fixed_metadata.iloc[query_meta_idx]['pilltype_id']

    # Get top-k reference indices, best match first
    scores = similarity[query_idx].cpu().numpy()
    top_indices = np.argsort(scores)[-top_k:][::-1]

    # Per-channel scale/offset applied before display; presumably the inverse of the
    # normalization in the dataset transform - TODO confirm against that transform.
    channel_std = np.array([0.229, 0.224, 0.225])
    channel_mean = np.array([0.485, 0.456, 0.406])

    def to_displayable(img_tensor):
        # CHW tensor -> HWC array, undo the normalization, then clip to the [0, 1]
        # range imshow accepts for float RGB (it would otherwise clip with a warning)
        img = img_tensor.permute(1, 2, 0).numpy() * channel_std + channel_mean
        return np.clip(img, 0.0, 1.0)

    # Plot: query in the first panel, ranked matches to its right
    fig, axes = plt.subplots(1, top_k + 1, figsize=(15, 3))
    axes[0].imshow(to_displayable(query_img))
    axes[0].set_title(f"Query: {query_pill_id}")
    axes[0].axis('off')

    for i, ref_idx in enumerate(top_indices):
        ref_img, _, ref_meta_idx = reference_dataset[ref_idx]
        ref_pill_id = fixed_metadata.iloc[ref_meta_idx]['pilltype_id']
        axes[i + 1].imshow(to_displayable(ref_img))
        axes[i + 1].set_title(f"Rank {i+1}: {ref_pill_id}\nScore: {scores[ref_idx]:.4f}")
        axes[i + 1].axis('off')

    # Hide panels left empty when fewer than top_k reference images exist
    for unused_ax in axes[len(top_indices) + 1:]:
        unused_ax.axis('off')

    plt.tight_layout()
    plt.savefig(f"retrieval_example_{query_idx}.png")
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures
    plt.close(fig)

# Visualize a few examples
similarity = compute_cosine_similarity(val_embeddings, reference_embeddings)
# Show up to 3 examples, but never index past the last available query
num_examples = min(3, similarity.size(0))
for i in range(num_examples):
    visualize_retrieval(i, val_dataset, reference_dataset, similarity)