<a href="https://colab.research.google.com/github/manasdeshpande125/da6401_assignment2-partA/blob/main/DL_ASG2_Q2_4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Importing Libraries and Setting the Device (GPU if available)**

In [None]:
# Imports
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder
import pytorch_lightning as pl
import torchvision

# Seed every RNG (Python `random`, NumPy, torch incl. CUDA, DataLoader workers) so the
# notebook reproduces the same results under Restart Kernel -> Run All.
SEED = 42
pl.seed_everything(SEED, workers=True)

# Set device: use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
# Wandb Login
# NOTE(review): an API key used to be hard-coded on this line and was committed with the
# notebook -- revoke/rotate it in the W&B user settings. wandb.login() reads the key from
# the WANDB_API_KEY environment variable (or a saved ~/.netrc entry) and otherwise
# prompts for it interactively, so no secret has to live in the notebook.
import wandb
wandb.login()

In [None]:
# Optional: Enable DataParallel if multiple GPUs
def prepare_model_for_device(model):
    """Place `model` on the global `device`.

    When more than one CUDA device is visible, the model is first wrapped in
    `nn.DataParallel` so batches are split across all GPUs.

    Args:
        model: the `nn.Module` to move.

    Returns:
        The (possibly DataParallel-wrapped) module, moved to `device`.
    """
    gpu_count = torch.cuda.device_count()
    if gpu_count > 1:
        print(f"Using {gpu_count} GPUs with DataParallel")
        model = nn.DataParallel(model)
    return model.to(device)

**Dataset Download**

In [None]:
# Download the nature_12K archive from the W&B dataset bucket, extract it quietly into the
# working directory (load_data / load_test_data expect an `inaturalist_12K/` folder there),
# then delete the zip to free disk space.
!wget https://storage.googleapis.com/wandb_datasets/nature_12K.zip -O nature_12K.zip
!unzip -q nature_12K.zip
!rm nature_12K.zip

**Load Data Function: splits `inaturalist_12K/train` into training (80%) and validation (20%) subsets**

In [None]:
def load_data(batch_count, data_aug='n', train_dir='inaturalist_12K/train',
              val_fraction=0.2, seed=10, num_workers=2):
    """Split the ImageFolder at `train_dir` into training and validation DataLoaders.

    Augmentation (when `data_aug` is 'y') is applied to the training subset only; the
    validation subset always uses the deterministic resize + normalise pipeline, so
    validation metrics are not computed on randomly cropped/flipped/rotated images.

    Args:
        batch_count: batch size used by both loaders.
        data_aug: 'y' (case-insensitive) to augment the training subset.
        train_dir: ImageFolder root with one sub-directory per class.
        val_fraction: fraction of the images held out for validation.
        seed: seed of the split generator; the default (10) reproduces the original split.
        num_workers: number of DataLoader worker processes.

    Returns:
        (trainloader, validationloader, classes) where `classes` is ImageFolder's
        sorted list of class names.
    """
    from torch.utils.data import Subset

    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    eval_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])
    if data_aug.lower() == 'y':
        train_transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.RandomCrop((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(degrees=(0, 30)),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = eval_transform

    # random_split returns Subsets that share their parent's transform, so build two views
    # of the same directory (identical file order) that differ only in their transform.
    train_view = ImageFolder(root=train_dir, transform=train_transform)
    val_view = ImageFolder(root=train_dir, transform=eval_transform)

    val_size = round(val_fraction * len(train_view))
    train_size = len(train_view) - val_size
    train_ds, val_split = random_split(
        train_view, [train_size, val_size], generator=torch.Generator().manual_seed(seed)
    )
    # Reuse the validation indices on the un-augmented view: no image is in both subsets.
    val_ds = Subset(val_view, val_split.indices)

    trainloader = DataLoader(train_ds, batch_size=batch_count, shuffle=True, num_workers=num_workers)
    validationloader = DataLoader(val_ds, batch_size=batch_count, shuffle=False, num_workers=num_workers)

    classes = train_view.classes
    return trainloader, validationloader, classes


In [None]:
# Build the augmented training loader plus the validation loader, then show which classes
# ImageFolder discovered under the training directory.
trainloader, valloader, classes = load_data(batch_count=32, data_aug='y')
num_found = len(classes)
print("Number of classes:", num_found)
print("Classes:", classes)

**Building a 5-layer CNN as a PyTorch Lightning Module**

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger  #Used for logging into wandb


class LightningNet(pl.LightningModule):
    """Configurable CNN image classifier (five conv blocks by default) as a Lightning module.

    Each conv block is Conv2d -> [BatchNorm2d] -> activation -> MaxPool2d(2, 2) -> [Dropout].
    The flattened feature map then goes through [Dropout] -> Linear(., 84) -> activation ->
    Linear(84, num_classes), producing unnormalised logits.

    Args:
        input_shape: (channels, height, width) of a single input image.
        filters: number of output channels of each conv layer.
        kernel_size: kernel size per conv layer, or a single int used for every layer.
        activation: activation *class* (e.g. ``nn.ReLU``); one instance is shared by all layers.
        batch_size: recorded as a hyperparameter only; layer sizes do not depend on it.
        use_batch_norm: insert BatchNorm2d after every conv layer.
        use_dropout: apply Dropout(dropout_rate) after every conv block and before fc1.
        dropout_rate: dropout probability.
        learning_rate: optimizer learning rate.
        num_classes: number of output logits.
        optimizer_name: optimizer built by ``configure_optimizers``; one of 'sgd', 'momentum',
            'nesterov', 'rmsprop', 'adam', 'nadam' (case-insensitive).
        weight_decay: weight decay (L2 penalty) passed to the optimizer.

    Raises:
        ValueError: if ``filters`` and ``kernel_size`` have different lengths.
    """

    def __init__(self,
                 input_shape=(3, 224, 224),
                 filters=(4, 16, 32, 64, 128),
                 kernel_size=(3, 3, 3, 3, 3),
                 activation=nn.ReLU,
                 batch_size=32,
                 use_batch_norm=True,
                 use_dropout=True,
                 dropout_rate=0.25,
                 learning_rate=1e-3,
                 num_classes=10,
                 optimizer_name='adam',
                 weight_decay=0.0):   #default parameters

        super().__init__()
        self.save_hyperparameters()

        self.activation = activation()
        self.use_batch_norm = use_batch_norm
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(dropout_rate)
        self.pool = nn.MaxPool2d(2, 2)
        self.learning_rate = learning_rate
        self.optimizer_name = optimizer_name
        self.weight_decay = weight_decay
        self.criterion = nn.CrossEntropyLoss()

        # A single int means "same kernel size for every conv layer".
        if isinstance(kernel_size, int):
            kernel_size = [kernel_size] * len(filters)
        if len(filters) != len(kernel_size):
            raise ValueError("filters and kernel_sizes must be the same length")

        in_channels = input_shape[0]
        self.conv_layers = nn.ModuleList()
        self.bn_layers = nn.ModuleList()

        for out_channels, k_size in zip(filters, kernel_size):
            # padding = k // 2 keeps H and W unchanged for odd kernels; only pooling shrinks them.
            self.conv_layers.append(
                nn.Conv2d(in_channels, out_channels, kernel_size=k_size, padding=k_size // 2)
            )
            if use_batch_norm:
                self.bn_layers.append(nn.BatchNorm2d(out_channels))
            in_channels = out_channels

        # Flattened per-sample size after the conv stack; one dummy sample is enough.
        self.flattened_size = self._get_conv_output(input_shape)

        # Fully connected layers
        self.fc1 = nn.Linear(self.flattened_size, 84)
        self.fc2 = nn.Linear(84, num_classes)

    def _get_conv_output(self, shape, batch_size=1):
        """Return the flattened per-sample size of the conv stack's output for inputs of ``shape``.

        The probe runs under ``torch.no_grad()`` in eval mode, so it builds no autograd graph
        and does not update BatchNorm running statistics. The previous train/eval mode is
        restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                features = self._forward_features(torch.zeros(batch_size, *shape))
        finally:
            self.train(was_training)
        return features[0].numel()

    def _forward_features(self, x):
        """Run ``x`` through every conv block and return the final feature map."""
        for i, conv in enumerate(self.conv_layers):
            x = conv(x)
            if self.use_batch_norm:
                x = self.bn_layers[i](x)
            x = self.activation(x)
            x = self.pool(x)
            if self.use_dropout:
                x = self.dropout(x)
        return x

    def forward(self, x):
        """Return class logits for the input batch ``x``."""
        x = self._forward_features(x)
        x = torch.flatten(x, 1)
        # NOTE(review): with use_dropout the last conv block's output was already dropped out
        # inside _forward_features, so these features are dropped out twice -- confirm intent.
        if self.use_dropout:
            x = self.dropout(x)
        x = self.activation(self.fc1(x))
        x = self.fc2(x)
        return x

    def _shared_step(self, batch):
        """Compute (cross-entropy loss, accuracy) for one (inputs, labels) batch."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)
        self.log('train_loss', loss, prog_bar=True, on_epoch=True, on_step=False)
        self.log('train_acc', acc, prog_bar=True, on_epoch=True, on_step=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)
        self.log('val_loss', loss, prog_bar=True, on_epoch=True, on_step=False)
        self.log('val_acc', acc, prog_bar=True, on_epoch=True, on_step=False)

    def test_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)
        self.log('test_loss', loss, prog_bar=True, on_epoch=True, on_step=False)
        self.log('test_acc', acc, prog_bar=True, on_epoch=True, on_step=False)
        return {'test_loss': loss, 'test_acc': acc}

    # Epoch-level test averages come from self.log(on_epoch=True) above, so no
    # test_epoch_end hook is needed (it was removed in newer pytorch_lightning versions).

    def configure_optimizers(self, optimizer_type=None):
        """Build and return the optimizer Lightning steps during ``trainer.fit``.

        The optimizer must be *returned*: if this hook returns None, Lightning trains with
        no optimizer and the weights never change.

        Args:
            optimizer_type: optional override; defaults to ``self.optimizer_name``.

        Returns:
            torch.optim.Optimizer over all model parameters.

        Raises:
            ValueError: if the optimizer name is not supported.
        """
        name = (optimizer_type or self.optimizer_name).lower()
        params = self.parameters()
        lr = self.learning_rate
        wd = self.weight_decay

        if name == 'sgd':
            return torch.optim.SGD(params, lr=lr, weight_decay=wd)
        if name == 'momentum':
            return torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=wd)
        if name == 'nesterov':
            return torch.optim.SGD(params, lr=lr, momentum=0.9, nesterov=True, weight_decay=wd)
        if name == 'rmsprop':
            return torch.optim.RMSprop(params, lr=lr, momentum=0.9, weight_decay=wd)
        if name == 'adam':
            return torch.optim.Adam(params, lr=lr, weight_decay=wd)
        if name == 'nadam':
            return torch.optim.NAdam(params, lr=lr, weight_decay=wd)
        raise ValueError(
            f"Unsupported optimizer '{name}'; expected one of "
            "'sgd', 'momentum', 'nesterov', 'rmsprop', 'adam', 'nadam'."
        )


In [None]:
# import torch.nn as nn

# def val_eval(net, device, val_loader):
#     net.eval()
#     val_loss = 0.0
#     val_total = 0
#     val_correct = 0
#     criterion = nn.CrossEntropyLoss()

#     with torch.no_grad():
#         for inputs, labels in val_loader:
#             inputs, labels = inputs.to(device), labels.to(device)
#             outputs = net(inputs)
#             loss = criterion(outputs, labels)

#             val_loss += loss.item() * inputs.size(0)
#             _, preds = torch.max(outputs, 1)
#             val_total += labels.size(0)
#             val_correct += (preds == labels).sum().item()

#     avg_loss = val_loss / val_total
#     accuracy = round((val_correct / val_total) * 100, 2)
#     print(f"Validation Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")
#     return accuracy, avg_loss

In [None]:
# import torch.optim as optim
# from tqdm import tqdm

# def train_CNN(epochs, filter_mode, act_str, batch_count, data_aug,
#               batch_norm, drop, optimizer_name, lr_rate, drop_value):

#     # Load dataset
#     train_data, val_data, classes = load_data(batch_count, data_aug)

#     # Filter configurations
#     filter_map = {
#         'all_32':   [32, 32, 32, 32, 32],
#         'inc':      [16, 32, 64, 128, 256],
#         'dec':      [128, 64, 32, 16, 8],
#         'inc_dec':  [32, 64, 128, 64, 32],
#         'dec_inc':  [128, 64, 32, 64, 128],
#     }

#     filters = filter_map.get(filter_mode, [32, 64, 128, 64, 32])  # default fallback

#     # Activation function
#     activation_map = {
#         'relu': nn.ReLU(),
#         'sigmoid': nn.Sigmoid(),
#         'tanh': nn.Tanh(),
#         'gelu': nn.GELU(),
#         'silu': nn.SiLU()
#     }
#     act_fn = activation_map.get(act_str.lower(), nn.ReLU())  # default to ReLU


#     net = LightningNet(
#     input_shape=(3, 224, 224),
#     filters=[16, 32, 64, 128, 256],
#     kernel_size=3,
#     activation=nn.GELU,
#     use_batch_norm=True,
#     use_dropout=True,
#     dropout_rate=0.3,
#     learning_rate=1e-4,
#     num_classes=len(classes)  # from earlier data loader
#     )

#     trainer = pl.Trainer(max_epochs=10, accelerator='auto', devices=1)


#     # Loss function
#     criterion = nn.CrossEntropyLoss()

#     # Optimizer
#     optimizer_map = {
#         'SGD': optim.SGD(net.parameters(), lr=lr_rate, momentum=0.9),
#         'Adam': optim.Adam(net.parameters(), lr=lr_rate),
#         'RMSProp': optim.RMSprop(net.parameters(), lr=lr_rate, momentum=0.9)
#     }
#     optimizer = optimizer_map.get(optimizer_name, optim.Adam(net.parameters(), lr=lr_rate))

#     for epoch in range(1, epochs + 1):
#         # net.train()
#         trainer.fit(net, trainloader, valloader)
#         train_loss = 0.0
#         train_total = 0
#         train_correct = 0

#         for inputs, labels in tqdm(train_data, desc=f"Epoch {epoch}/{epochs}"):
#             inputs, labels = inputs.to(device), labels.to(device)
#             optimizer.zero_grad()

#             outputs = net(inputs)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()

#             train_loss += loss.item() * inputs.size(0)
#             _, preds = torch.max(outputs, 1)
#             train_total += labels.size(0)
#             train_correct += (preds == labels).sum().item()

#         train_avg_loss = train_loss / train_total
#         train_acc = round((train_correct / train_total) * 100, 2)

#         print(f"Epoch {epoch}: Train Loss = {train_avg_loss:.4f}, Accuracy = {train_acc:.2f}%", end=" | ")

#         val_acc, val_loss = val_eval(net, device, val_data)

#         # Example wandb logging
#         wandb.log({
#             "epoch": epoch,
#             "train_loss": train_avg_loss,
#             "train_acc": train_acc,
#             "val_loss": val_loss,
#             "val_acc": val_acc
#         })

#     return train_avg_loss, train_acc, val_acc

In [None]:
# train_CNN(
#     epochs=5,
#     filter_mode='inc_dec',
#     act_str='relu',
#     batch_count=32,
#     data_aug='y',
#     batch_norm='y',
#     drop='y',
#     optimizer_name='Adam',
#     lr_rate=1e-3,
#     drop_value=0.3
# )

**Function to compute validation loss and accuracy**

In [None]:
def val_eval_lightning(model, val_loader, device=None):
    """Evaluate `model` on every batch of `val_loader` without tracking gradients.

    Args:
        model: classifier mapping an input batch to per-class logits.
        val_loader: iterable of (inputs, labels) batches.
        device: where to run the evaluation. Defaults to the device of the model's first
            parameter (CPU for a parameter-less model), so inputs always land next to the
            weights.

    Returns:
        (accuracy, avg_loss): accuracy in percent rounded to 2 decimals, and the mean
        cross-entropy per sample.

    Raises:
        ValueError: if `val_loader` yields no samples.

    The model's train/eval mode is restored before returning.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                # criterion returns the batch mean; weight it by batch size for a per-sample mean.
                total_loss += loss.item() * inputs.size(0)
                _, preds = torch.max(outputs, 1)
                total_correct += (preds == labels).sum().item()
                total_samples += labels.size(0)
    finally:
        model.train(was_training)

    if total_samples == 0:
        raise ValueError("val_loader yielded no samples; cannot compute loss/accuracy.")

    avg_loss = total_loss / total_samples
    acc = round((total_correct / total_samples) * 100, 2)
    return acc, avg_loss


**Function that builds a LightningNet model and trains it with a manual PyTorch loop**

In [None]:

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

def train_CNN_lightning_manual(epochs, filter_mode, act_str, batch_count, data_aug,
                                batch_norm, drop, optimizer_name, lr_rate, drop_value):
    """Train a LightningNet with a hand-written PyTorch loop (no Lightning Trainer).

    Prints training and validation loss/accuracy after every epoch.

    Args:
        epochs: number of full passes over the training split.
        filter_mode: filter layout key ('all_32', 'inc', 'dec', 'inc_dec', 'dec_inc');
            unknown keys fall back to [32, 64, 128, 64, 32].
        act_str: activation name (case-insensitive); unknown names fall back to ReLU.
        batch_count: batch size.
        data_aug, batch_norm, drop: 'y' to enable augmentation / batch norm / dropout.
        optimizer_name: 'sgd' (SGD with momentum 0.9 here), 'momentum', 'nesterov', 'adam',
            'nadam' or 'rmsprop' (case-insensitive); unknown names fall back to Adam.
        lr_rate: learning rate.
        drop_value: dropout probability.

    Returns:
        The trained model, placed on `device`.
    """
    # 1. Load dataset
    train_loader, val_loader, classes = load_data(batch_count, data_aug)

    # 2. Filter configurations
    filter_map = {
        'all_32':   [32, 32, 32, 32, 32],
        'inc':      [16, 32, 64, 128, 256],
        'dec':      [128, 64, 32, 16, 8],
        'inc_dec':  [32, 64, 128, 64, 32],
        'dec_inc':  [128, 64, 32, 64, 128],
    }
    filters = filter_map.get(filter_mode, [32, 64, 128, 64, 32])

    # 3. Activation function
    activation_map = {
        'relu': nn.ReLU,
        'sigmoid': nn.Sigmoid,
        'tanh': nn.Tanh,
        'gelu': nn.GELU,
        'silu': nn.SiLU
    }
    activation_fn = activation_map.get(act_str.lower(), nn.ReLU)

    # 4. Instantiate model
    model = LightningNet(
        input_shape=(3, 224, 224),
        filters=filters,
        # LightningNet zips filters with kernel_size and checks their lengths, so pass one
        # kernel size per conv layer (a bare int has no len()).
        kernel_size=[3] * len(filters),
        activation=activation_fn,
        batch_size=batch_count,
        use_batch_norm=(batch_norm == 'y'),
        use_dropout=(drop == 'y'),
        dropout_rate=drop_value,
        learning_rate=lr_rate,
        num_classes=len(classes)
    ).to(device)

    # 5. Build only the requested optimizer (unknown names fall back to Adam).
    optimizer_factories = {
        'sgd': lambda params: optim.SGD(params, lr=lr_rate, momentum=0.9),
        'momentum': lambda params: optim.SGD(params, lr=lr_rate, momentum=0.9),
        'nesterov': lambda params: optim.SGD(params, lr=lr_rate, momentum=0.9, nesterov=True),
        'adam': lambda params: optim.Adam(params, lr=lr_rate),
        'nadam': lambda params: optim.NAdam(params, lr=lr_rate),
        'rmsprop': lambda params: optim.RMSprop(params, lr=lr_rate, momentum=0.9),
    }
    make_optimizer = optimizer_factories.get(optimizer_name.lower(), optimizer_factories['adam'])
    optimizer = make_optimizer(model.parameters())

    criterion = nn.CrossEntropyLoss()

    # 6. Training loop
    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0
        correct = 0
        total = 0

        for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}"):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)  # LightningNet.forward
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Accumulate per-sample totals so the epoch loss is a true per-sample mean.
            total_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

        train_loss = total_loss / total
        train_acc = round((correct / total) * 100, 2)

        # 7. Validation after each epoch
        val_acc, val_loss = val_eval_lightning(model, val_loader)

        print(f"Epoch {epoch}: "
              f"Train Loss = {train_loss:.4f}, Train Acc = {train_acc:.2f}%, "
              f"Val Loss = {val_loss:.4f}, Val Acc = {val_acc:.2f}%\n")

    return model  # return model if you want to save/checkpoint later


In [None]:
# Short 2-epoch run of the manual training loop with the 'inc' filter layout.
manual_run_args = dict(
    epochs=2,
    filter_mode='inc',
    act_str='relu',
    batch_count=32,
    data_aug='y',
    batch_norm='y',
    drop='y',
    optimizer_name='adam',
    lr_rate=1e-3,
    drop_value=0.25,
)
trained_model = train_CNN_lightning_manual(**manual_run_args)

**Specifying the Sweep Configuration**

In [None]:
# W&B sweep search space. The Bayesian search optimises `metric.name`, which must match a
# metric the runs actually log: LightningNet logs 'val_acc' (there is no 'val_Accuracy').
# NOTE(review): 'hidden_layer_size', 'weight_decay' and 'loss_type' are sampled, but
# sweep_train does not pass them to train_CNN_lightning, so they currently have no effect.
# A weight_decay of 0.5 would also be far too strong if it were wired in -- revisit.
sweep_configuration = {
    'method': 'bayes',
    'metric': {
        'name': 'val_acc',
        'goal': 'maximize'
    },
    'parameters': {
        'epochs': {'values': [10,20]},
        'hidden_layer_size': {'values': [84, 96, 128]},
        'learning_rate': {'values': [1e-3, 1e-4]},
        'weight_decay': {'values': [0, 0.0005, 0.5]},
        'optimizer_name': {
            'values': ['sgd', 'momentum', 'nesterov', 'rmsprop', 'adam', 'nadam']
        },
        'batch_size': {'values': [16, 32, 64]},
        'activation_type': {'values': ['sigmoid', 'tanh', 'ReLU', 'GELU', 'SiLU', 'Mish']},
        'loss_type': {'values': ['cross_entropy']},
        'filters': {
            'values': [
                'all_32',
                'inc',
                'dec',
                'inc_dec',
                'dec_inc',
            ]
        },
        'data_augmentation': {'values': [True, False]},
        'use_batch_norm': {'values': [True, False]},
        'use_dropout': {'values': [True, False]},
        'dropout_rate': {'values': [0.2, 0.3]},

        # Kernel size per conv layer (must have one entry per filter).
        'kernel_sizes': {
            'values': [
                [3, 3, 3, 3, 3],
                [3, 3, 5, 3, 3],
                [5, 3, 5, 3, 5],
                [3, 5, 3, 5, 3],
                [5 ,5 , 3, 3, 3]
            ]
        }
    }
}


**Function that trains (with early stopping) and optionally tests a model for one hyperparameter configuration**

In [None]:
import torch.optim as optim
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import WandbLogger
import pytorch_lightning as pl

def train_CNN_lightning(epochs, filter_mode, act_str, batch_count, data_aug,
                        batch_norm, drop, optimizer_name, lr_rate, drop_value, kernel_sizes, test_loader=None):
    """Train a LightningNet with the Lightning Trainer and optionally evaluate it on a test set.

    Training uses EarlyStopping (patience 3) and keeps the best checkpoint, both on val_loss,
    and logs to W&B through a WandbLogger.

    Args:
        epochs: maximum number of epochs (early stopping may end training sooner).
        filter_mode: filter layout key ('all_32', 'inc', 'dec', 'inc_dec', 'dec_inc');
            unknown keys fall back to [32, 64, 128, 64, 32].
        act_str: activation name (case-insensitive); unknown names fall back to ReLU.
        batch_count: batch size.
        data_aug, batch_norm, drop: 'y' to enable augmentation / batch norm / dropout.
        optimizer_name: optimizer name; lower-cased and stored on the model as `optimizer_name`.
        lr_rate: learning rate.
        drop_value: dropout probability.
        kernel_sizes: one kernel size per conv layer (same length as the filter layout).
        test_loader: optional DataLoader; when given, `trainer.test` is run on it after fit.

    Returns:
        The trained LightningNet.

    Raises:
        ValueError: if `kernel_sizes` does not have one entry per filter.
    """
    # 1. Load dataset
    train_loader, val_loader, classes = load_data(batch_count, data_aug)

    # 2. Define filter configurations
    filter_map = {
        'all_32':   [32, 32, 32, 32, 32],
        'inc':      [16, 32, 64, 128, 256],
        'dec':      [128, 64, 32, 16, 8],
        'inc_dec':  [32, 64, 128, 64, 32],
        'dec_inc':  [128, 64, 32, 64, 128],
    }
    filters = filter_map.get(filter_mode, [32, 64, 128, 64, 32])
    kernel_sizes = list(kernel_sizes)
    if len(kernel_sizes) != len(filters):
        raise ValueError("Length of kernel_sizes must match number of filters.")

    # 3. Activation functions
    activation_map = {
        'relu': nn.ReLU,
        'sigmoid': nn.Sigmoid,
        'tanh': nn.Tanh,
        'gelu': nn.GELU,
        'silu': nn.SiLU,
        'mish': nn.Mish
    }
    activation_fn = activation_map.get(act_str.lower(), nn.ReLU)

    # 4. Model initialization
    model = LightningNet(
        input_shape=(3, 224, 224),
        filters=filters,
        kernel_size=kernel_sizes,
        activation=activation_fn,
        batch_size=batch_count,
        use_batch_norm=(batch_norm == 'y'),
        use_dropout=(drop == 'y'),
        dropout_rate=drop_value,
        learning_rate=lr_rate,
        num_classes=len(classes)
    )
    # Record the requested optimizer on the model so configure_optimizers can use it.
    # NOTE(review): this only takes effect if LightningNet.configure_optimizers reads
    # `self.optimizer_name`; otherwise the choice is ignored.
    model.optimizer_name = optimizer_name.lower()

    # 5. WandB logger (reuses the active run when called inside a sweep / after wandb.init)
    wandb_logger = WandbLogger(project="DA6401-assignment-2", log_model=True)

    # 6. Callbacks: early stopping and best-checkpoint saving, both on val_loss
    callbacks = [
        EarlyStopping(monitor='val_loss', patience=3, mode='min'),
        ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1, filename='{epoch}-{val_loss:.2f}')
    ]

    # 7. Trainer setup. devices=1 means one GPU when accelerator='auto' finds CUDA and one
    # CPU process otherwise; devices=None is not a valid value for the CPU accelerator.
    trainer = Trainer(
        max_epochs=epochs,
        accelerator='auto',
        devices=1,
        callbacks=callbacks,
        logger=wandb_logger
    )

    # 8. Train the model
    trainer.fit(model, train_loader, val_loader)

    # 9. Test only when a loader is supplied: LightningNet defines no test_dataloader(), so
    # trainer.test without one (the sweep path) would raise after training finished.
    # NOTE(review): this evaluates the final-epoch weights; pass ckpt_path='best' to test the
    # checkpoint kept by ModelCheckpoint instead.
    if test_loader is not None:
        trainer.test(model, dataloaders=test_loader)
    return model


**Sweep Function**

In [None]:
import wandb
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

def sweep_train():
    """Run one W&B sweep trial: read the sampled hyperparameters and train one model.

    Meant to be passed as `function=` to `wandb.agent`, which provides the sampled values
    through `wandb.config`.
    """
    # Do not pass the sweep *definition* as `config=`: that stores 'method', 'metric' and
    # 'parameters' as spurious keys in every run's config.
    wandb.init(project="DA6401-Assignment-2", entity="cs24m024-iit-madras-org")
    config = wandb.config

    # Set up run name for easier tracking in W&B
    wandb.run.name = f"e_{config.epochs}_lr_{config.learning_rate}_wd_{config.weight_decay}_o_{config.optimizer_name}_bs_{config.batch_size}_ac_{config.activation_type}_los_{config.loss_type}"

    # Sweep booleans are mapped to the 'y'/'n' flags train_CNN_lightning expects.
    train_CNN_lightning(
        epochs=config.epochs,
        filter_mode=config.filters,
        act_str=config.activation_type,
        batch_count=config.batch_size,
        data_aug='y' if config.data_augmentation else 'n',
        batch_norm='y' if config.use_batch_norm else 'n',
        drop='y' if config.use_dropout else 'n',
        optimizer_name=config.optimizer_name,
        lr_rate=config.learning_rate,
        drop_value=config.dropout_rate,
        kernel_sizes=config.kernel_sizes
    )


In [None]:
# Register the sweep with W&B, then let a local agent run up to 100 trials of sweep_train.
sweep_project = 'DA6401-Assignment-2'
sweep_id = wandb.sweep(sweep_configuration, project=sweep_project)
wandb.agent(sweep_id, function=sweep_train, project=sweep_project, count=100)

**Finding the top 3 sweep configurations by validation accuracy**

In [None]:
entity = 'cs24m024-iit-madras'
project = 'DA6401-assignment-2'
sweep_ids = [
    '37haj8j1',   # Sweep 1 ID
    'u47hemx2',   # Sweep 2 ID
]

# Public W&B API client (read-only queries over finished runs)
api = wandb.Api()

# Keep only runs, from every listed sweep, that finished and logged a validation accuracy.
all_runs = [
    candidate
    for sid in sweep_ids
    for candidate in api.sweep(f"{entity}/{project}/{sid}").runs
    if candidate.state == 'finished' and 'val_acc' in candidate.summary
]

# Highest final val_acc first; keep the best three.
top_runs = sorted(all_runs, key=lambda r: r.summary['val_acc'], reverse=True)[:3]

print("Top 3 runs by validation accuracy (across both sweeps):")
for rank, best in enumerate(top_runs, start=1):
    print(f"\nModel #{rank}")
    print(f"Name        : {best.name}")
    print(f"Run ID      : {best.id}")
    print(f"Sweep ID    : {best.sweep.id}")
    print(f"Val Accuracy: {best.summary['val_acc']:.4f}")
    print(f"Config      : {best.config}")

**Testing the top 3 configurations (by validation accuracy) and showing predictions in an image grid**

In [None]:
# Start a W&B run named "testing" for the evaluation cells that follow.
wandb.init(project="DA6401-Assignment-2", name="testing")

In [None]:
import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from pytorch_lightning import Trainer

# 1. Load test dataset
def load_test_data(test_dir='inaturalist_12K/val', batch_size=32):
    """Build an un-shuffled DataLoader over the held-out images in `test_dir`.

    Images are resized to 224x224 and normalised with mean/std 0.5 per channel, with no
    augmentation.

    Returns:
        (test_loader, class_names), with class names in ImageFolder's label order.
    """
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    pipeline = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), normalize])
    dataset = ImageFolder(root=test_dir, transform=pipeline)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    return loader, dataset.classes

# 2. Visualization
def imshow(img):
    """Undo Normalize(mean=0.5, std=0.5) and return the 3-D tensor as a channels-last array.

    Despite the name, nothing is drawn: the result is meant to be passed to `plt.imshow`.
    """
    unnormalized = img * 0.5 + 0.5
    return unnormalized.numpy().transpose(1, 2, 0)

def show_predictions(model, dataloader, classes, rows=10, cols=3, panel_size=(5, 3)):
    """Plot a rows x cols grid of images titled with predicted and true class names.

    Args:
        model: classifier returning per-class logits. It runs on the device of its first
            parameter, so plain nn.Modules work as well as Lightning modules.
        dataloader: yields (images, labels) batches normalised with mean/std 0.5.
        classes: class names indexed by label id.
        rows, cols: grid shape; at most rows * cols images are shown.
        panel_size: (width, height) in inches of each grid cell. The default reproduces the
            original 15 x 30 inch figure for the default 10 x 3 grid.

    Side effects:
        Puts `model` in eval mode and displays the figure.
    """
    max_images = rows * cols
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    fig = plt.figure(figsize=(panel_size[0] * cols, panel_size[1] * rows))
    images_shown = 0
    with torch.no_grad():
        for images, labels in dataloader:
            if images_shown >= max_images:
                break
            outputs = model(images.to(model_device))
            # Bring predictions back once per batch; labels never need to leave the CPU.
            _, preds = torch.max(outputs, 1)
            preds = preds.cpu()
            for img, pred, label in zip(images, preds, labels):
                if images_shown >= max_images:
                    break
                ax = fig.add_subplot(rows, cols, images_shown + 1)
                ax.imshow(imshow(img))
                ax.set_title(f"Pred: {classes[pred]}\nTrue: {classes[label]}", fontsize=8)
                ax.axis('off')
                images_shown += 1
    fig.tight_layout()
    plt.show()

# 3. Load the held-out test split once and reuse it for every model.
test_loader, test_classes = load_test_data()

# 4. Retrain the three best sweep configurations. train_CNN_lightning also runs
#    trainer.test on test_loader; afterwards plot a grid of test predictions per model.
top_run_configs = [
    # Model 1 – Top Val Accuracy: 0.4095
    dict(epochs=20, filter_mode="inc_dec", act_str="RELU", batch_count=64, data_aug='y',
         batch_norm='y', drop='n', optimizer_name="adam", lr_rate=0.001, drop_value=0.2,
         kernel_sizes=[3, 3, 3, 3, 3]),
    # Model 2 – Val Accuracy: 0.4005
    dict(epochs=20, filter_mode="dec_inc", act_str="Mish", batch_count=64, data_aug='y',
         batch_norm='y', drop='n', optimizer_name="momentum", lr_rate=0.0001, drop_value=0.3,
         kernel_sizes=[5, 5, 3, 3, 3]),
    # Model 3 – Val Accuracy: 0.3960
    dict(epochs=13, filter_mode="dec_inc", act_str="GELU", batch_count=32, data_aug='y',
         batch_norm='y', drop='n', optimizer_name="nadam", lr_rate=0.0001, drop_value=0.2,
         kernel_sizes=[3, 3, 3, 3, 3]),
]

top_models = []
for run_config in top_run_configs:
    trained = train_CNN_lightning(**run_config, test_loader=test_loader)
    show_predictions(trained, test_loader, test_classes)
    top_models.append(trained)

model1, model2, model3 = top_models


**Plotting Confusion Matrix**

In [None]:
# Start a W&B run named "confusion_matrix1" to hold the confusion-matrix plot logged below.
wandb.init(project="DA6401-Assignment-2", name="confusion_matrix1")

In [None]:
import torch
import numpy as np
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Evaluate model1 on the test split and collect predictions for the confusion matrix.
# Lightning moves a module back to the CPU when trainer.fit / trainer.test tear down, so
# put the model on `device` explicitly before feeding it batches that are moved there.
model1.to(device)
model1.eval()

all_preds = []
all_labels = []

with torch.no_grad():
    for data, labels in test_loader:
        outputs = model1(data.to(device))
        _, predicted = torch.max(outputs, 1)
        # .tolist() yields plain Python ints, accepted by both sklearn and wandb.
        all_preds.extend(predicted.cpu().tolist())
        all_labels.extend(labels.tolist())

# Pin the label set so the matrix is always len(test_classes) x len(test_classes) and lines
# up with the tick labels below, even if some class never appears in labels or predictions.
cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(test_classes))))

# Log the interactive confusion matrix to wandb
wandb.log({
    "confusion_matrix": wandb.plot.confusion_matrix(
        probs=None,
        y_true=all_labels,
        preds=all_preds,
        class_names=test_classes
    )
})

# Plot the same confusion matrix locally
fig, ax = plt.subplots(figsize=(10,10))
sns.heatmap(cm, annot=True, fmt='d', ax=ax, cmap='Blues', xticklabels=test_classes, yticklabels=test_classes)
ax.set_xlabel('Predicted labels')
ax.set_ylabel('True labels')
ax.set_title('Confusion Matrix')
plt.show()