<a href="https://colab.research.google.com/github/bhanup6663/COMP691_DL/blob/main/challenge2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#Cell 1
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from torchvision import models
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import confusion_matrix
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

In [2]:
# Cell 2: Prepare Training and Validation Data
def prepare_data(random_classes=True, num_classes=2, samples_per_class=25, val_samples_per_class=5, seed=None):
    """Build small train/validation loaders from a subset of CIFAR-10 classes.

    Args:
        random_classes: If True, pick ``num_classes`` distinct classes at random;
            otherwise use classes ``0 .. num_classes-1``.
        num_classes: Number of CIFAR-10 classes to keep.
        samples_per_class: Training images kept per class. ``None`` keeps every
            image of the class that is not reserved for validation.
        val_samples_per_class: Validation images held out per class.
        seed: If given, seeds NumPy's global RNG (class choice and per-class
            shuffles) and torch's global RNG (loader shuffling, augmentation).

    Returns:
        ``(train_loader, val_loader, selected_classes)``. Labels in both loaders
        are remapped to ``0 .. num_classes-1`` following the order of
        ``selected_classes``.
    """
    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)

    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    # Augmentation is for the training split only; validation must be deterministic
    # so its accuracy is comparable from one epoch to the next.
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        normalize,
    ])
    eval_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    # Load raw images once; each split applies its own transform in RemappedSubset.
    full_train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
    if random_classes:
        selected_classes = np.random.choice(range(10), num_classes, replace=False)
    else:
        selected_classes = np.arange(num_classes)

    # Original CIFAR-10 label -> contiguous label in [0, num_classes).
    label_map = {int(cls): new_label for new_label, cls in enumerate(selected_classes)}

    train_indices = []
    val_indices = []
    for cls in selected_classes:
        cls_indices = [i for i, target in enumerate(full_train_set.targets) if target == cls]
        np.random.shuffle(cls_indices)
        # First val_samples_per_class go to validation, the next samples_per_class to training.
        train_end = None if samples_per_class is None else val_samples_per_class + samples_per_class
        train_indices.extend(cls_indices[val_samples_per_class:train_end])
        val_indices.extend(cls_indices[:val_samples_per_class])

    class RemappedSubset(torch.utils.data.Dataset):
        """View of ``dataset`` restricted to ``indices`` with per-split image/label transforms."""

        def __init__(self, dataset, indices, transform=None, target_transform=None):
            self.dataset = dataset
            self.indices = indices
            self.transform = transform
            self.target_transform = target_transform

        def __getitem__(self, idx):
            img, target = self.dataset[self.indices[idx]]
            if self.transform is not None:
                img = self.transform(img)
            if self.target_transform is not None:
                target = self.target_transform(target)
            return img, target

        def __len__(self):
            return len(self.indices)

    target_transform = lambda x: label_map[x]
    train_subset = RemappedSubset(full_train_set, train_indices,
                                  transform=train_transform, target_transform=target_transform)
    val_subset = RemappedSubset(full_train_set, val_indices,
                                transform=eval_transform, target_transform=target_transform)

    train_loader = DataLoader(train_subset, batch_size=10, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=10, shuffle=False)

    return train_loader, val_loader, selected_classes

train_loader, val_loader, classes_used = prepare_data(seed=42)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:13<00:00, 13008644.98it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


In [3]:
#Cell 3
def load_model(num_classes=2):
    """Return an ImageNet-pretrained ResNet-18 with a new ``num_classes``-way head.

    Args:
        num_classes: Output size of the replacement ``fc`` layer. Must match the
            number of classes the data loaders were built with.

    Returns:
        The modified ``torchvision`` ResNet-18 model.
    """
    model = models.resnet18(weights="IMAGENET1K_V1")
    # Replace the ImageNet classifier; the backbone keeps its pretrained weights.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

# Size the head from the classes actually selected in Cell 2.
model = load_model(num_classes=len(classes_used))


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 107MB/s]


In [None]:
# Cell 4: Training the Model with Validation Reporting
def train_model_with_validation(model, train_loader, val_loader, epochs=10, print_interval=2,
                                lr=0.001, momentum=0.9, step_size=3, gamma=0.1, restore_best=True):
    """Fine-tune ``model`` with SGD + StepLR, reporting validation metrics every epoch.

    Args:
        model: Classifier returning logits of shape (batch, num_classes).
        train_loader: Loader yielding ``(inputs, labels)`` for training.
        val_loader: Loader yielding ``(inputs, labels)`` for validation.
        epochs: Number of passes over ``train_loader``.
        print_interval: Accepted for backward compatibility; currently unused
            (batch loss is printed for the first and last batch of each epoch).
        lr, momentum: SGD hyper-parameters.
        step_size, gamma: StepLR schedule (lr *= gamma every ``step_size`` epochs).
        restore_best: If True, load the weights from the epoch with the highest
            validation accuracy before returning, i.e. the model announced as
            "New best model" in the log.

    Returns:
        The same ``model`` object, moved to the selected device.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

    best_val_acc = 0.0
    best_state = None
    for epoch in range(epochs):
        model.train()
        train_loss, train_correct, train_total = 0.0, 0, 0
        num_batches = len(train_loader)

        for batch, (inputs, labels) in enumerate(train_loader, 1):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            predicted = outputs.detach().argmax(dim=1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

            # Print for the first and last batch of each epoch only
            if batch == 1 or batch == num_batches:
                print(f'Epoch {epoch+1}, Batch {batch}, Loss: {loss.item():.4f}')

        # max(..., 1) guards against empty loaders.
        train_loss /= max(num_batches, 1)
        train_acc = train_correct / max(train_total, 1)

        # Validation phase
        val_loss, val_correct, val_total = 0.0, 0, 0
        model.eval()
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                val_loss += criterion(outputs, labels).item()
                predicted = outputs.argmax(dim=1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_loss /= max(len(val_loader), 1)
        val_acc = val_correct / max(val_total, 1)
        print(f'Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc*100:.2f}%, Val Loss: {val_loss:.4f}, Val Acc: {val_acc*100:.2f}%')

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # Clone tensors: state_dict() returns references that later steps would overwrite.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            print(f"New best model found at Epoch {epoch+1} with Validation Accuracy {best_val_acc*100:.2f}%.")

        scheduler.step()

    if restore_best and best_state is not None:
        model.load_state_dict(best_state)
        print(f"Restored weights with best Validation Accuracy {best_val_acc*100:.2f}%.")

    return model

model = train_model_with_validation(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=10,
)


Epoch 1, Batch 1, Loss: 0.7634


In [None]:
# Cell 5: Prepare Filtered Test Data with Corrected Class Mapping
def prepare_filtered_test_data(classes_used, seed=None):
    """Return a DataLoader over the CIFAR-10 test images whose class is in ``classes_used``.

    Each label is replaced by its position within ``classes_used`` (or -1 if it
    is not found), so labels line up with a model trained on the same ordering.

    Args:
        classes_used: Sequence or NumPy array of original CIFAR-10 class ids.
        seed: If given, seeds NumPy's global RNG (nothing here draws from it).

    Returns:
        A non-shuffled DataLoader with batch size 10.
    """
    if seed is not None:
        np.random.seed(seed)

    # Deterministic preprocessing: no augmentation at test time.
    eval_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=eval_transform)

    keep = [idx for idx, target in enumerate(test_set.targets) if target in classes_used]

    class RemappedSubset(torch.utils.data.Dataset):
        """Index-restricted view of ``dataset`` with an optional label transform."""

        def __init__(self, dataset, indices, target_transform=None):
            self.dataset = dataset
            self.indices = indices
            self.target_transform = target_transform

        def __len__(self):
            return len(self.indices)

        def __getitem__(self, idx):
            image, label = self.dataset[self.indices[idx]]
            if self.target_transform:
                label = self.target_transform(label)
            return image, label

    ordered_classes = classes_used.tolist() if isinstance(classes_used, np.ndarray) else classes_used

    def to_position(label):
        # Position within classes_used; -1 marks a class outside the selection.
        return ordered_classes.index(label) if label in ordered_classes else -1

    remapped = RemappedSubset(test_set, keep, target_transform=to_position)
    return DataLoader(remapped, batch_size=10, shuffle=False)

test_loader = prepare_filtered_test_data(classes_used)


In [None]:
# Cell 6: Evaluate Model with Confusion Matrix
def evaluate_model_with_confusion_matrix(model, test_loader, class_indices, class_names=None):
    """Print test accuracy and plot a confusion matrix for the selected classes.

    Args:
        model: Classifier whose output unit ``i`` corresponds to ``class_indices[i]``.
        test_loader: Loader yielding ``(images, labels)`` whose labels are positions
            ``0 .. len(class_indices)-1`` (not original CIFAR-10 ids).
        class_indices: Original CIFAR-10 ids of the classes, in model-output order.
        class_names: Optional names indexed by CIFAR-10 id; defaults to the
            standard CIFAR-10 class list.

    Returns:
        ``(accuracy_percent, confusion_matrix)``, or ``(0.0, None)`` if the loader is empty.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    if class_names is None:
        class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                       'dog', 'frog', 'horse', 'ship', 'truck']
    class_indices = list(class_indices)
    class_labels = [class_names[i] for i in class_indices]

    all_predictions = []
    all_labels = []
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            all_predictions.extend(predicted.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    total = len(all_labels)
    if total == 0:
        print('No test images to evaluate.')
        return 0.0, None
    correct = sum(p == t for p, t in zip(all_predictions, all_labels))
    accuracy = 100 * correct / total
    print(f'Accuracy of the model on the test images: {accuracy:.2f}%')

    # Labels/predictions are positions 0..k-1, so the matrix must be indexed by
    # position; passing the original CIFAR ids here would mis-bin rows and columns.
    cm = confusion_matrix(all_labels, all_predictions, labels=list(range(len(class_indices))))
    fig, ax = plt.subplots(figsize=(10, 7))
    sns.heatmap(cm, annot=True, fmt="d", xticklabels=class_labels, yticklabels=class_labels, ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title('Confusion Matrix')
    plt.show()
    return accuracy, cm

test_accuracy, test_cm = evaluate_model_with_confusion_matrix(model, test_loader, classes_used)


In [None]:
#Cell 7
def show_misclassified_images(model, test_loader, classes, class_indices=None, max_images=10):
    """Plot up to ``max_images`` test images that ``model`` misclassifies.

    Args:
        model: Classifier returning logits.
        test_loader: Loader yielding ``(images, labels)``; images are assumed to be
            normalized with mean = std = 0.5 per channel (undone for display).
        classes: Names indexed by original CIFAR-10 class id.
        class_indices: Original CIFAR-10 ids in model-output order. Pass it when the
            loader labels were remapped to positions ``0..k-1``; labels and predictions
            are then translated through it before looking up ``classes``. If None,
            labels are used directly as indices into ``classes``.
        max_images: Maximum number of images to plot (two per row).
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    def name_of(label):
        label = int(label)
        return classes[class_indices[label] if class_indices is not None else label]

    misclassified = []
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            wrong = predicted != labels
            if wrong.any():
                misclassified.extend(zip(images[wrong].cpu(), predicted[wrong].cpu(), labels[wrong].cpu()))
            if len(misclassified) >= max_images:
                break

    if not misclassified:
        print('No misclassified images found.')
        return

    # Inverse of Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)).
    mean = np.array([0.5, 0.5, 0.5])
    std = np.array([0.5, 0.5, 0.5])
    n_rows = max(1, (max_images + 1) // 2)

    plt.figure(figsize=(10, 10))
    for i, (img, pred, true) in enumerate(misclassified[:max_images]):
        img = np.clip(std * img.numpy().transpose((1, 2, 0)) + mean, 0, 1)
        plt.subplot(n_rows, 2, i + 1)
        plt.imshow(img)
        plt.title(f'True: {name_of(true)}, Pred: {name_of(pred)}')
        plt.xticks([])
        plt.yticks([])
    plt.tight_layout()
    plt.show()

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# Loader labels are positions within classes_used, so pass it to recover the real class names.
show_misclassified_images(model, test_loader, classes, class_indices=classes_used)
