In [None]:
import os
import random
import pandas as pd
import numpy as np
from PIL import Image
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import top_k_accuracy_score
from tqdm import tqdm

import torch
import torch.nn as nn
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader, random_split
import timm

# CONFIG
IMG_DIR = os.path.abspath("../../data_t1/train")
CSV_PATH = os.path.abspath("../../data_t1/train.csv")
VAL_LABELS_FILE = os.path.abspath('../../data_t1/val.csv')
VAL_DIR = os.path.abspath('../../data_t1/val')

# Seed every RNG this notebook touches so that "Restart Kernel -> Run All"
# reproduces the same shuffling order and the same freshly initialised layers.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Step 1: Load and filter dataset
train_df = pd.read_csv(CSV_PATH)

# Defaults to the number of distinct labels in the CSV; set a smaller value
# (e.g. 1000) to train on a subset of the classes.
NUM_LABELS = len(train_df.label.unique())
BATCH_SIZE = 128
IMG_SIZE = 224
EPOCHS = 25
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Architectures to train, keyed by the name used for head replacement and for
# checkpoint file names. The models are instantiated right here, at config time.
MODELS_TO_COMPARE = {
    "resnet50": models.resnet50(weights="ResNet50_Weights.DEFAULT"),
    # "efficientnet_b0": timm.create_model('efficientnet_b0', pretrained=True),
    # "densenet121": models.densenet121(weights="DenseNet121_Weights.DEFAULT")
}


# Keep only the rows whose label is among the first NUM_LABELS distinct labels
# (in order of appearance). With the default NUM_LABELS this keeps every row.
train_df = train_df[train_df['label'].isin(train_df['label'].unique()[:NUM_LABELS])]
train_df = train_df.reset_index(drop=True)

# Step 2: Encode labels
# `le` maps the raw labels to integer ids and is reused later to decode predictions.
le = LabelEncoder()
train_df['encoded_label'] = le.fit_transform(train_df['label'])

# Step 3: PyTorch Dataset
class CatDataset(Dataset):
    """Image dataset backed by a DataFrame of file names and encoded labels.

    Each item is the tuple ``(image, encoded_label, filename)`` built from one
    row of ``dataframe``; the row must provide the ``'filename'`` and
    ``'encoded_label'`` columns.

    Args:
        dataframe: Table with one row per image.
        img_dir: Directory that the ``'filename'`` values are relative to.
        transform: Optional callable applied to the RGB image before it is
            returned. ``None`` returns the PIL image unchanged.
    """

    def __init__(self, dataframe, img_dir, transform=None):
        self.data = dataframe
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows (images) in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load the image at positional index ``idx`` and return it with its label and name."""
        row = self.data.iloc[idx]
        img_path = os.path.join(self.img_dir, row['filename'])
        # Close the file handle deterministically instead of waiting for garbage
        # collection; convert() returns an independent, fully loaded copy.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')
        # Compare against None so a callable that happens to be falsy
        # (e.g. an empty container of transforms) is still applied.
        if self.transform is not None:
            image = self.transform(image)
        return image, row['encoded_label'], row['filename']

# Resize to a square IMG_SIZE x IMG_SIZE input, convert to a tensor and
# normalise each channel with mean 0.5 / std 0.5.
# NOTE(review): torchvision documents ImageNet statistics for its pretrained
# weights; 0.5/0.5 is kept so already-trained checkpoints stay valid - confirm.
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

dataset = CatDataset(train_df, IMG_DIR, transform)

# Train/validation split (80/20)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
# A dedicated, seeded generator makes the split identical on every run. Without
# it, a fresh kernel that loads a saved checkpoint would validate on a different
# split that overlaps the rows the checkpoint was trained on.
split_generator = torch.Generator().manual_seed(42)
train_set, val_set = random_split(dataset, [train_size, val_size], generator=split_generator)
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
# One sample per batch for validation.
val_loader = DataLoader(val_set, batch_size=1)

# Step 4: Train and evaluate models
def train_and_evaluate(model_name, model):
    """Fine-tune ``model`` on ``train_loader`` and checkpoint it after every epoch.

    The classification head is replaced in place by a new ``nn.Linear`` with
    ``NUM_LABELS`` outputs, the model is moved to ``DEVICE`` and trained for
    ``EPOCHS`` epochs with Adam (lr=1e-4) and cross-entropy loss. After each
    epoch the ``state_dict`` is written to
    ``trained_models/{model_name}_epoch{N}.pt``. Despite the name, no
    validation pass is run here.

    Args:
        model_name: Architecture name; must contain ``"resnet"``,
            ``"efficientnet"`` or ``"densenet"`` so the right head attribute
            (``fc`` or ``classifier``) is replaced. Also used in file names.
        model: The network to train. It is modified in place.

    Returns:
        The tuple ``(model_name, model)`` with the trained model.

    Raises:
        ValueError: If ``model_name`` matches none of the supported families.
    """
    print(f"\n🚀 Training {model_name}")

    # Replace the classifier BEFORE moving the model to DEVICE so the new
    # layer is moved together with the rest of the network.
    if "resnet" in model_name:
        model.fc = nn.Linear(model.fc.in_features, NUM_LABELS)
    elif "efficientnet" in model_name or "densenet" in model_name:
        model.classifier = nn.Linear(model.classifier.in_features, NUM_LABELS)
    else:
        # Training on with the original head would silently use the wrong
        # number of outputs; fail the same way load_trained_model does.
        raise ValueError("Unsupported model type")

    model.to(DEVICE)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # torch.save does not create missing directories.
    checkpoint_dir = "trained_models"
    os.makedirs(checkpoint_dir, exist_ok=True)

    num_batches = len(train_loader)
    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0
        correct = 0
        total = 0

        print(f"\n📘 Epoch {epoch+1}/{EPOCHS}")
        for batch_idx, (images, labels, _) in enumerate(train_loader):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            batch_correct = (predicted == labels).sum().item()
            correct += batch_correct
            total += labels.size(0)

            batch_acc = batch_correct / labels.size(0)
            print(f"[Batch {batch_idx+1}/{num_batches}] Loss: {loss.item():.4f} | Accuracy: {batch_acc*100:.2f}%")

        # Guard the divisions so an empty loader reports zeros instead of crashing.
        epoch_loss = total_loss / max(num_batches, 1)
        epoch_acc = correct / max(total, 1)
        print(f"🔹 Epoch {epoch+1} Summary — Loss: {epoch_loss:.4f} | Accuracy: {epoch_acc*100:.2f}%")

        # Report the path the file was really written to, directory included.
        save_path = os.path.join(checkpoint_dir, f"{model_name}_epoch{epoch+1}.pt")
        torch.save(model.state_dict(), save_path)
        print(f"💾 Saved model checkpoint to {save_path}")
    return model_name, model


In [None]:
# Train every configured architecture in turn; `model_name` / `model` end up
# holding the last one that was trained.
for name in MODELS_TO_COMPARE:
    model_name, model = train_and_evaluate(name, MODELS_TO_COMPARE[name])

In [4]:
def compute_top3_accuracy(df):
    """Return the top-3 accuracy in percent from a results frame.

    Args:
        df: DataFrame with a ``'flag'`` column holding 1 for rows counted as a
            hit and 0 otherwise.

    Returns:
        ``sum(flag) / len(df) * 100`` as a float, or ``0.0`` when ``df`` has no
        rows (instead of dividing by zero).
    """
    total_predictions = len(df)
    if total_predictions == 0:
        return 0.0
    correct_predictions = df['flag'].sum()
    return float(correct_predictions / total_predictions * 100)

def evaluate(model_name, model):
    """Run ``model`` over ``val_loader`` and score its top-3 predictions.

    Works for any validation batch size: every sample of every batch gets its
    own row in the result.

    Args:
        model_name: Label shown in the progress-bar description.
        model: Network to evaluate; it is switched to eval mode and is expected
            to already live on ``DEVICE``.

    Returns:
        ``(res_df, top3_accuracy)`` where ``res_df`` has the columns
        ``filepath``, ``true_label``, ``label_1``..``label_3`` (decoded with
        ``le``) and ``flag`` (1 when the true label is among the three
        predictions), and ``top3_accuracy`` is the value returned by
        ``compute_top3_accuracy(res_df)``.
    """
    model.eval()
    results = []

    with torch.no_grad():
        for images, labels, filenames in tqdm(val_loader, desc=f"{model_name} - Inference"):
            outputs = model(images.to(DEVICE))
            probs = torch.softmax(outputs, dim=1)
            # (batch, 3) class indices, highest probability first.
            top3_idx = torch.topk(probs, k=3, dim=1).indices.cpu().numpy()
            true_idx = labels.cpu().numpy().reshape(-1)

            # Decode the whole batch with one call each, then restore the shape.
            top3_names = le.inverse_transform(top3_idx.ravel()).reshape(top3_idx.shape)
            true_names = le.inverse_transform(true_idx)

            for fname, true_name, pred_names in zip(filenames, true_names.tolist(), top3_names.tolist()):
                results.append([fname, true_name] + pred_names)

    pred_cols = ["label_1", "label_2", "label_3"]
    res_df = pd.DataFrame(results, columns=["filepath", "true_label"] + pred_cols)
    # Vectorised membership test: is the true label equal to any of the three guesses?
    res_df["flag"] = res_df[pred_cols].eq(res_df["true_label"], axis=0).any(axis=1).astype(int)
    top3_accuracy = compute_top3_accuracy(res_df)
    print(f"🎯 Top-3 Accuracy: {top3_accuracy:.2f}%")

    # Optional: Save CSV if needed
    # res_df[['filepath', 'label_1', 'label_2', 'label_3']].to_csv(f"{model_name}_submission.csv", index=False)

    return res_df, top3_accuracy

In [10]:
model_name = 'resnet50'
# train_and_evaluate writes its checkpoints under "trained_models/", so the
# checkpoint has to be read back from that directory, not the working directory.
model_path = os.path.join("trained_models", "resnet50_epoch20.pt")
# Step 5: Load a trained model checkpoint
def load_trained_model(model_name, model_path, num_labels):
    """Rebuild an architecture, load saved weights into it and return it in eval mode.

    The network is created without pretrained weights, its head is replaced by
    an ``nn.Linear`` with ``num_labels`` outputs (mirroring what
    ``train_and_evaluate`` does before training) and the ``state_dict`` stored
    at ``model_path`` is loaded into it.

    Args:
        model_name: Architecture name. A name containing ``"resnet"`` builds
            ResNet-50, ``"efficientnet"`` builds timm's ``efficientnet_b0`` and
            ``"densenet"`` builds DenseNet-121.
        model_path: Path to a file written with ``torch.save(model.state_dict(), ...)``.
        num_labels: Number of output classes the checkpoint was trained with.

    Returns:
        The model on ``DEVICE``, switched to eval mode.

    Raises:
        ValueError: If ``model_name`` matches none of the supported families.
    """
    if "resnet" in model_name:
        model = models.resnet50()
        model.fc = nn.Linear(model.fc.in_features, num_labels)
    elif "efficientnet" in model_name:
        # Same construction as the training configuration, minus the pretrained weights.
        model = timm.create_model('efficientnet_b0', pretrained=False)
        model.classifier = nn.Linear(model.classifier.in_features, num_labels)
    elif "densenet" in model_name:
        model = models.densenet121()
        model.classifier = nn.Linear(model.classifier.in_features, num_labels)
    else:
        raise ValueError("Unsupported model type")
    # SECURITY NOTE: torch.load unpickles the file; only load checkpoints you
    # produced yourself (or pass weights_only=True where the installed torch supports it).
    state_dict = torch.load(model_path, map_location=DEVICE)
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()
    return model

# Restore the selected checkpoint, with a head sized for NUM_LABELS classes.
model = load_trained_model(
    model_name=model_name,
    model_path=model_path,
    num_labels=NUM_LABELS,
)

In [None]:
# Score the loaded model on the validation loader; `model_name` only labels the progress bar.
model_name = 'resnet50'
res_df, top3_accuracy = evaluate(model_name=model_name, model=model)