In [1]:
# --- 1. SETUP ---
# Clone the repository to get the model architecture files
!git clone https://github.com/bearpaw/pytorch-classification.git


Cloning into 'pytorch-classification'...
remote: Enumerating objects: 287, done.[K
remote: Total 287 (delta 0), reused 0 (delta 0), pack-reused 287 (from 1)[K
Receiving objects: 100% (287/287), 440.37 KiB | 1.73 MiB/s, done.
Resolving deltas: 100% (167/167), done.


In [2]:
# Change directory
%cd pytorch-classification

/home/monil/Desktop/Work/FML/Project/pytorch-classification


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [3]:
# Sanity check that the clone contains the `models/` and `utils/` packages used below.
# NOTE: `ls` here is IPython's shell alias (resolved via automagic), not Python code.
ls

cifar.py  imagenet.py  LICENSE  [0m[01;34mmodels[0m/  README.md  TRAINING.md  [01;34mutils[0m/


In [4]:


import torch
import torch.nn.functional as F # <-- Added for softmax
from collections import OrderedDict
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np

# Import all the model architectures we need
from models.cifar import resnet
from models.cifar.densenet import densenet, Bottleneck
from models.cifar.wrn import WideResNet

# Setup device: use the GPU when PyTorch can see one, otherwise fall back to CPU.
# Every evaluation cell below reads this `device`.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"==> Using device: {device}")


# --- 2. LOAD CIFAR-100 TEST DATA ---
# NOTE(review): these are the CIFAR-10 per-channel statistics, not CIFAR-100's.
# Presumably chosen to mirror the training-time transform of the cloned repo's
# checkpoints. Verify against its cifar.py before reusing this loader for models
# trained elsewhere.
_EVAL_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_EVAL_NORM_STD = (0.2023, 0.1994, 0.2010)

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_EVAL_NORM_MEAN, _EVAL_NORM_STD),
])

# Full test split (train=False), evaluated in a fixed order.
test_dataset_full = datasets.CIFAR100(
    root='../data',
    train=False,
    download=True,
    transform=transform_test,
)
test_loader = DataLoader(
    test_dataset_full,
    batch_size=100,
    shuffle=False,
    num_workers=2,
)
print("CIFAR-100 test data loaded.")


# --- 3. REUSABLE VALIDATION FUNCTION ---
def validate_model(model, loader, device):
    """Evaluate a classifier, returning top-1 accuracy and mean prediction confidence.

    Args:
        model: ``torch.nn.Module`` whose output for a batch is a logits tensor with
            classes along dim 1 (softmax is taken over ``dim=1``).
        loader: Iterable of ``(inputs, labels)`` batches, e.g. a ``DataLoader``.
        device: Device each batch is moved to; should match the model's device.

    Returns:
        ``(accuracy, avg_confidence)``, both as percentages. ``avg_confidence`` is the
        mean over all samples of the highest softmax probability, i.e. the probability
        the model assigns to its own prediction, whether or not that prediction is
        correct.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()  # inference-mode behaviour for dropout / batch-norm layers
    correct = 0
    total = 0
    total_confidence = 0.0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)

            # Max softmax probability = confidence; its index = predicted class.
            probabilities = F.softmax(outputs, dim=1)
            confidences, predicted = probabilities.max(dim=1)
            total_confidence += confidences.sum().item()

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Fail with a clear message instead of a ZeroDivisionError below.
    if total == 0:
        raise ValueError("validate_model: loader yielded no samples")

    accuracy = 100 * correct / total
    avg_confidence = 100 * total_confidence / total
    return accuracy, avg_confidence




==> Using device: cpu


100%|██████████| 169M/169M [02:39<00:00, 1.06MB/s]


CIFAR-100 test data loaded.


In [6]:
!nvidia-smi

/bin/bash: line 1: nvidia-smi: command not found


In [None]:
# --- 4. LOAD AND VALIDATE ResNet-164 ---
print("\n--- Validating ResNet-164 ---")
checkpoint_path_resnet = '/home/monil/Desktop/Work/FML/Project/resnet164Cifar100/model_best.pth.tar' # Make sure this path is correct
model_resnet = resnet(depth=164, num_classes=100, block_name='Bottleneck').to(device)
try:
    # SECURITY: weights_only=False fully unpickles the file (arbitrary code execution
    # on a malicious checkpoint). Only load checkpoints you produced or trust.
    checkpoint = torch.load(checkpoint_path_resnet, map_location=device, weights_only=False)
except FileNotFoundError:
    print(f" Could not find ResNet checkpoint at: {checkpoint_path_resnet}")
else:
    state_dict = checkpoint['state_dict']
    # Keys carry a leading 'module.' (presumably saved from an nn.DataParallel
    # wrapper). Strip only that leading prefix, not every occurrence in the key.
    prefix = 'module.'
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )
    model_resnet.load_state_dict(new_state_dict)

    accuracy, avg_conf = validate_model(model_resnet, test_loader, device)
    print(f" ResNet-164 Accuracy: {accuracy:.2f}% | Avg. Confidence: {avg_conf:.2f}%")




In [None]:
# --- 5. LOAD AND VALIDATE DenseNet-BC-190 (k=40) ---
# depth=190 with growthRate=40, Bottleneck blocks and compressionRate=2 is the
# DenseNet-BC L=190, k=40 configuration. It matches the `densenet190Cifar100`
# checkpoint directory, so it is labelled as such (not "DenseNet-100").
print("\n--- Validating DenseNet-BC-190 (k=40) ---")
checkpoint_path_densenet = '/home/monil/Desktop/Work/FML/Project/densenet190Cifar100/model_best.pth.tar' # Make sure this path is correct
model_densenet = densenet(depth=190, num_classes=100, growthRate=40, compressionRate=2, block=Bottleneck).to(device)
try:
    # SECURITY: weights_only=False fully unpickles the file (arbitrary code execution
    # on a malicious checkpoint). Only load checkpoints you produced or trust.
    checkpoint = torch.load(checkpoint_path_densenet, map_location=device, weights_only=False)
except FileNotFoundError:
    print(f" Could not find DenseNet checkpoint at: {checkpoint_path_densenet}")
else:
    state_dict = checkpoint['state_dict']
    # Keys carry a leading 'module.' (presumably saved from an nn.DataParallel
    # wrapper). Strip only that leading prefix, not every occurrence in the key.
    prefix = 'module.'
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )
    model_densenet.load_state_dict(new_state_dict)

    accuracy, avg_conf = validate_model(model_densenet, test_loader, device)
    print(f" DenseNet-BC-190 Accuracy: {accuracy:.2f}% | Avg. Confidence: {avg_conf:.2f}%")




In [None]:
# --- 6. LOAD AND VALIDATE ResNet-56 from torch.hub ---
# The chenyaofo/pytorch-cifar-models CIFAR-100 weights were trained with CIFAR-100
# per-channel statistics. The shared `test_loader` uses CIFAR-10 statistics, which
# would understate this model's accuracy, so build a loader with the matching
# normalization. (Values from that repo's README — confirm if the hub model changes.)
print("\n--- Validating ResNet-56 (from torch.hub) ---")
HUB_CIFAR100_MEAN = (0.5070, 0.4865, 0.4409)
HUB_CIFAR100_STD = (0.2673, 0.2564, 0.2762)
hub_transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(HUB_CIFAR100_MEAN, HUB_CIFAR100_STD),
])
hub_test_loader = DataLoader(
    datasets.CIFAR100(root='../data', train=False, download=True, transform=hub_transform_test),
    batch_size=100,
    shuffle=False,
    num_workers=2,
)

# Only the download/load is best-effort (network, hub cache). Evaluation errors
# should surface rather than be reported as a load failure.
try:
    model_hub = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar100_resnet56", pretrained=True, trust_repo=True).to(device)
except Exception as e:
    model_hub = None
    print(f" Could not load torch.hub model. Error: {e}")

if model_hub is not None:
    accuracy, avg_conf = validate_model(model_hub, hub_test_loader, device)
    print(f" ResNet-56 Accuracy: {accuracy:.2f}% | Avg. Confidence: {avg_conf:.2f}%")



In [None]:
# --- 7. LOAD AND VALIDATE Wide ResNet-28-10 ---
print("\n--- Validating Wide ResNet-28-10 ---")
checkpoint_path_wrn = '/home/monil/Desktop/Work/FML/Project/WRNCifar100/model_best.pth.tar' # Make sure this path is correct
# dropRate only affects training; model.eval() inside validate_model disables dropout.
model_wrn = WideResNet(depth=28, num_classes=100, widen_factor=10, dropRate=0.3).to(device)
try:
    # SECURITY: weights_only=False fully unpickles the file (arbitrary code execution
    # on a malicious checkpoint). Only load checkpoints you produced or trust.
    checkpoint = torch.load(checkpoint_path_wrn, map_location=device, weights_only=False)
except FileNotFoundError:
    print(f" Could not find Wide ResNet checkpoint at: {checkpoint_path_wrn}")
else:
    state_dict = checkpoint['state_dict']
    # Keys carry a leading 'module.' (presumably saved from an nn.DataParallel
    # wrapper). Strip only that leading prefix, not every occurrence in the key.
    prefix = 'module.'
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )
    model_wrn.load_state_dict(new_state_dict)

    accuracy, avg_conf = validate_model(model_wrn, test_loader, device)
    print(f" Wide ResNet-28-10 Accuracy: {accuracy:.2f}% | Avg. Confidence: {avg_conf:.2f}%")