## Installation



In [None]:
pip install wandb numpy pandas matplotlib torch torchvision

## Dataset

In [2]:
# Mount Google Drive at /content/drive (Colab-only helper module).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# !unzip /content/drive/MyDrive/DL_Assignment2/Dataset/nature_12K.zip -d /content/drive/MyDrive/DL_Assignment2/Dataset

In [3]:
import os
# Switch the working directory to the assignment folder on the mounted drive,
# so relative paths used by later cells resolve against it.
os.chdir("/content/drive/MyDrive/DL_Assignment2")

## Libraries

In [21]:
%%writefile libraries.py
import torch
import os
from torchvision import datasets, transforms
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit
from torch.utils.data import Subset, DataLoader
import numpy as np
import torch.nn as nn
import torch.optim as optim
import wandb
import tqdm
import gc
import matplotlib.pyplot as plt

Overwriting libraries.py


## Dataset loader

In [22]:
%%writefile data_loader.py
import numpy as np
import torch
from torch.utils.data import Subset, DataLoader
from torchvision import datasets, transforms
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit
import os


def validationDataSplit(train_dataset):
  """Carve a stratified 80/20 train/validation split out of a dataset.

  Args:
    train_dataset: dataset exposing a ``samples`` list of (path, label)
      pairs, as torchvision's ImageFolder does.

  Returns:
    (train_subset, val_subset, num_classes): two ``Subset`` views over
    ``train_dataset`` and the number of distinct labels found in it.
  """
  targets = [sample[1] for sample in train_dataset.samples]

  # One stratified shuffle with a fixed seed keeps the split reproducible
  # and preserves the per-class proportions in both parts.
  splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
  fit_idx, holdout_idx = next(splitter.split(train_dataset.samples, targets))

  num_classes = len(np.unique(targets))
  return Subset(train_dataset, fit_idx), Subset(train_dataset, holdout_idx), num_classes


def load_data(base_dir, isDataAug, batch_size):
  """Build the train / test / validation DataLoaders for the dataset.

  Args:
    base_dir: directory containing the ``train`` and ``val`` image folders.
      The ``val`` folder is used as the held-out test set; the validation
      set is split off from ``train``.
    isDataAug: whether to augment the training images. Accepts a bool or
      the strings "True"/"False" (any case), since sweep configurations
      deliver the flag as a string.
    batch_size: batch size of the training loader (evaluation loaders use 64).

  Returns:
    (train_loader, test_loader, val_loader, num_classes)
  """
  # A plain `isDataAug == False` check is never true for the string "False",
  # which silently enabled augmentation for every string-valued flag.
  if isinstance(isDataAug, str):
    isDataAug = isDataAug.strip().lower() == "true"

  train_dir = os.path.join(base_dir, 'train')
  test_dir = os.path.join(base_dir, 'val')

  # Deterministic preprocessing used for every evaluation image.
  eval_transform = transforms.Compose([
      transforms.Resize(256),
      transforms.CenterCrop(224),
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

  if isDataAug:
    train_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
  else:
    train_transform = eval_transform

  train_dataset = datasets.ImageFolder(train_dir, transform=train_transform)
  test_dataset = datasets.ImageFolder(test_dir, transform=eval_transform)
  train_subset, val_subset, num_classes = validationDataSplit(train_dataset)

  if isDataAug:
    # The validation subset is a view over the training dataset and would
    # inherit its random augmentations. Re-open the same folder with the
    # deterministic transform and reuse the split indices instead.
    val_source = datasets.ImageFolder(train_dir, transform=eval_transform)
    val_dataset = Subset(val_source, val_subset.indices)
  else:
    val_dataset = val_subset

  train_loader = DataLoader(train_subset, shuffle=True, num_workers=2, batch_size=batch_size, pin_memory=True)
  # Evaluation metrics do not depend on sample order, so the evaluation
  # loaders are left unshuffled to keep runs reproducible.
  test_loader = DataLoader(test_dataset, shuffle=False, num_workers=2, batch_size=64, pin_memory=True)
  val_loader = DataLoader(val_dataset, shuffle=False, num_workers=2, batch_size=64, pin_memory=True)

  return train_loader, test_loader, val_loader, num_classes

# load_data("/content/drive/MyDrive/DL_Assignment2/Dataset/inaturalist_12K/", True)

Overwriting data_loader.py


## Training CNN

In [23]:
%%writefile neural_network.py
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch

class ConvolutionalNeuralNetwork(nn.Module):
  """Configurable CNN: stacked conv blocks followed by one dense layer.

  Each conv block is Conv2d -> (optional BatchNorm2d) -> activation ->
  MaxPool2d(2). The classifier head is Flatten -> Linear -> activation ->
  (optional Dropout) -> Linear(num_classes). The instance also owns its
  optimizer, exposed as ``self.optimizer``.
  """
  # Maps the activation name given by the caller to the module class.
  activationFunctionsMap = {"ReLU": nn.ReLU, "GELU": nn.GELU, "SiLU": nn.SiLU}

  def __init__(self, num_classes,
               num_filters, filter_sizes,
               activationFun, optimizer,
               n_neurons_denseLayer,
               isBatchNormalization, dropout,
               learning_rate=0.001,
               momentum=0.5, beta = 0.9,
               beta1=0.9, beta2=0.99,
               epsilon=1e-8, weight_decay=0.0001):
    """Build the network and its optimizer.

    Args:
      num_classes: size of the output layer.
      num_filters: output channels of each conv layer, one entry per layer.
      filter_sizes: kernel size of each conv layer, parallel to num_filters.
      activationFun: one of the keys of ``activationFunctionsMap``.
      optimizer: "sgd", "rmsprop" or "adam".
      n_neurons_denseLayer: width of the hidden dense layer.
      isBatchNormalization: bool, or the string "True"/"False" (any case).
      dropout: dropout probability before the output layer; 0 disables it.
      learning_rate, momentum, beta, beta1, beta2, epsilon, weight_decay:
        optimizer hyperparameters (``beta`` is RMSprop's alpha).

    Raises:
      KeyError: if ``activationFun`` is not a known activation name.
      ValueError: if ``optimizer`` is not a known optimizer name.
    """
    super(ConvolutionalNeuralNetwork, self).__init__()
    self.num_classes = num_classes
    self.num_filters = num_filters
    self.filter_sizes = filter_sizes
    self.activationFun = ConvolutionalNeuralNetwork.activationFunctionsMap[activationFun]

    self.n_neurons_denseLayer = n_neurons_denseLayer
    # Flags can arrive as the strings "True"/"False"; a non-empty string is
    # always truthy, so "False" must be parsed rather than tested directly.
    if isinstance(isBatchNormalization, str):
      isBatchNormalization = isBatchNormalization.strip().lower() == "true"
    self.isBatchNormalization = bool(isBatchNormalization)
    self.dropout = dropout

    self.lr = learning_rate
    self.momentum = momentum
    self.betas = (beta1, beta2)
    self.eps = epsilon
    self.alpha = beta
    self.weight_decay = weight_decay

    self.defineModel()

    if optimizer == "sgd":
      self.optimizer = optim.SGD(self.parameters(), lr=self.lr, momentum=self.momentum, weight_decay=self.weight_decay)
    elif optimizer == "rmsprop":
      self.optimizer = optim.RMSprop(self.parameters(), lr=self.lr, alpha=self.alpha, eps=self.eps, weight_decay=self.weight_decay)
    elif optimizer == "adam":
      self.optimizer = optim.Adam(self.parameters(), lr=self.lr, betas=self.betas, eps=self.eps, weight_decay=self.weight_decay)
    else:
      # Fail here instead of with an AttributeError on the first training step.
      raise ValueError(f"Unknown optimizer '{optimizer}'; expected 'sgd', 'rmsprop' or 'adam'")

  def defineModel(self):
    """Assemble ``self.model`` from the stored hyperparameters."""
    self.model = nn.Sequential()

    inChannels = 3     # RGB input images
    for n_filters, kernel_size in zip(self.num_filters, self.filter_sizes):
      self.model.append(nn.Conv2d(inChannels, n_filters, kernel_size, padding=kernel_size // 2))
      if self.isBatchNormalization:
        self.model.append(nn.BatchNorm2d(n_filters))
      self.model.append(self.activationFun())
      self.model.append(nn.MaxPool2d(kernel_size=2))
      inChannels = n_filters

    # Probe the conv stack with a dummy image to get the flattened size.
    # Done in eval mode so BatchNorm running statistics are not updated by
    # the all-zero probe; the previous mode is restored afterwards.
    input_shape = (3, 224, 224)
    was_training = self.model.training
    self.model.eval()
    with torch.no_grad():
      dummy_output = self.model(torch.zeros(1, *input_shape))
      flattened_size = dummy_output.view(dummy_output.size(0), -1).size(1)
    self.model.train(was_training)

    self.model.append(nn.Flatten())
    self.model.append(nn.Linear(flattened_size, self.n_neurons_denseLayer))
    self.model.append(self.activationFun())

    if self.dropout > 0:
      self.model.append(nn.Dropout(self.dropout))

    self.model.append(nn.Linear(self.n_neurons_denseLayer, self.num_classes))

  def forward(self, inputs):
    """Return the raw class scores (logits) for a batch of images."""
    return self.model(inputs)

  def backward(self, outputs, labels):
    """Compute the cross-entropy loss and backpropagate its gradients."""
    loss = nn.CrossEntropyLoss()(outputs, labels)
    loss.backward()

  def updateWeights(self):
    """Apply one optimizer step using the currently stored gradients."""
    self.optimizer.step()

Overwriting neural_network.py


## Accuracy calculation

In [24]:
%%writefile accuracy_calculation.py
import torch
import torch.nn as nn
from torch.utils.data import Subset, DataLoader

def findOutputs(device, cnn, inputDataLoader, isTestData=False):
  """Evaluate a model on a data loader without tracking gradients.

  Args:
    device: torch device the batches are moved to before the forward pass.
    cnn: the model; it is switched to evaluation mode.
    inputDataLoader: iterable yielding (inputs, labels) batches.
    isTestData: when True, top-5 and top-2 accuracy are computed as well.

  Returns:
    (outputs, accuracy, avg_loss, top5_accuracy, top2_accuracy) where
    ``outputs`` is the concatenation of all batch logits, accuracies are
    percentages, ``avg_loss`` is the per-sample mean cross-entropy, and the
    two top-k accuracies are None unless ``isTestData`` is True.
  """
  cnn.eval()  # evaluation mode
  criterion = nn.CrossEntropyLoss()
  outputs = []
  total_loss = 0.0
  n_correct = 0
  n_correct_top5 = 0
  n_correct_top2 = 0
  n_samples = 0

  with torch.no_grad():
    for x_batch, y_batch in inputDataLoader:
      x_batch, y_batch = x_batch.to(device), y_batch.to(device)
      batch_outputs = cnn(x_batch)

      # The criterion returns the batch mean; weight by batch size so the
      # final average is per-sample even when the last batch is smaller.
      loss = criterion(batch_outputs, y_batch)
      total_loss += loss.item() * x_batch.size(0)

      y_pred_batch = torch.argmax(batch_outputs, dim=1)
      n_correct += (y_pred_batch == y_batch).sum().item()
      n_samples += x_batch.size(0)

      if isTestData:
        # Clamp k to the number of classes: topk raises when k exceeds the
        # size of the dimension, e.g. on a model with fewer than 5 classes.
        n_outputs = batch_outputs.size(1)
        top_indices = torch.topk(batch_outputs, min(5, n_outputs), dim=1).indices
        hits = top_indices.eq(y_batch.view(-1, 1))
        n_correct_top5 += hits.sum().item()
        # topk returns indices sorted by score, so the first two columns
        # are exactly the top-2 predictions.
        n_correct_top2 += hits[:, :min(2, n_outputs)].sum().item()
      outputs.append(batch_outputs)

  outputs = torch.cat(outputs)
  accuracy = (n_correct * 100.0) / n_samples
  avg_loss = total_loss / n_samples

  top5_accuracy = None
  top2_accuracy = None
  if isTestData:
    top5_accuracy = (n_correct_top5 * 100.0) / n_samples
    top2_accuracy = (n_correct_top2 * 100.0) / n_samples
  return outputs, accuracy, avg_loss, top5_accuracy, top2_accuracy

Overwriting accuracy_calculation.py


## Training (Argparser included)

In [27]:
%%writefile train_local.py
import os
import gc
import wandb
import torch
from neural_network import *
from data_loader import *
from accuracy_calculation import *
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt

# Module-level device shared by the training loop below: GPU when available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
# NOTE(review): this is set after torch has been imported; presumably it must be in
# place before the first CUDA allocation to have any effect - confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

def trainNeuralNetwork_local(args):
  """Train and evaluate one CNN configured from parsed CLI arguments.

  Trains for ``args.epochs`` epochs, logging train/validation loss and
  accuracy to Weights & Biases after each epoch, then logs the best
  validation accuracy and the final test metrics (top-1, top-2, top-5).

  Args:
    args: namespace produced by the argument parser. The boolean flags
      ``isDataAug`` and ``isBatchNormalization`` are strings; "True"
      (any case) enables them.
  """
  isDataAug = str(args.isDataAug).strip().lower() == "true"
  isBatchNormalization = str(args.isBatchNormalization).strip().lower() == "true"

  activationFun = args.activation
  optimizer = args.optimizer
  learning_rate = args.learning_rate
  momentum = args.momentum
  beta = args.beta
  beta1 = args.beta1
  beta2 = args.beta2
  epsilon = args.epsilon
  weight_decay = args.weight_decay
  dropout = args.dropout
  num_filters = args.num_filters
  filter_sizes = args.filter_sizes
  n_neurons_denseLayer = args.n_neurons_denseLayer

  run_name = f"{activationFun}_{optimizer}_{dropout}_{n_neurons_denseLayer}_DataAug-{isDataAug}_BatchNorm-{isBatchNormalization}"

  wandb.login()
  # A single init call: a second wandb.init while a run is already active
  # does not reliably apply its project/entity, so the run could end up
  # outside the requested project. The hyperparameters are stored in the
  # run config so each run is self-describing.
  wandb.init(project=args.wandb_project, entity=args.wandb_entity,
             name=run_name, config=vars(args), mode="online")

  train_loader, test_loader, val_loader, num_classes = load_data(args.base_dir, isDataAug, args.batch_size)

  best_val_accuracy = 0.0
  best_accuracy_epoch = -1

  cnn = ConvolutionalNeuralNetwork(num_classes,
                                   num_filters, filter_sizes,
                                   activationFun, optimizer,
                                   n_neurons_denseLayer,
                                   isBatchNormalization, dropout,
                                   learning_rate,
                                   momentum, beta,
                                   beta1, beta2,
                                   epsilon, weight_decay)
  cnn.to(device)

  epochs = args.epochs
  for epochNum in range(epochs):
    print(f"Epoch {epochNum}:")
    # The evaluation helper leaves the model in eval mode after each epoch;
    # switch back so dropout and batch norm behave correctly while training.
    cnn.train()
    for batch_idx, (x_batch, y_batch) in enumerate(train_loader):
      if batch_idx % 40 == 0:
        print(f"Batch idx {batch_idx} running")
      x_batch, y_batch = x_batch.to(device), y_batch.to(device)
      cnn.optimizer.zero_grad()
      outputs = cnn(x_batch)
      cnn.backward(outputs, y_batch)
      cnn.updateWeights()
      # Drop references early to keep peak GPU memory down.
      del x_batch, y_batch, outputs

    # Validation accuracy
    val_outputs, val_accuracy, val_loss, _, _ = findOutputs(device, cnn, val_loader)
    print(f"validation: loss={val_loss}, accuracy={val_accuracy}")

    # Train accuracy
    train_outputs, train_accuracy, train_loss, _, _ = findOutputs(device, cnn, train_loader)
    print(f"training: loss={train_loss}, accuracy={train_accuracy}")

    if val_accuracy > best_val_accuracy:
      best_val_accuracy = val_accuracy
      best_accuracy_epoch = epochNum

    wandb.log({
        "epoch": epochNum + 1,
        "val_loss": val_loss,
        "val_accuracy": val_accuracy,
        "train_loss": train_loss,
        "train_accuracy": train_accuracy
        },commit=True)
    del val_outputs, train_outputs
    gc.collect()
    torch.cuda.empty_cache()

  wandb.log({
      "best_acc_epoch": best_accuracy_epoch,
      "best_val_accuracy": best_val_accuracy
  })

  test_outputs, test_accuracy, test_loss, test_top5_accuracy, test_top2_accuracy = findOutputs(device, cnn, test_loader, True)
  print(f"testing: loss={test_loss}, top1_accuracy={test_accuracy}, top5_accuracy = {test_top5_accuracy}, top2_accuracy = {test_top2_accuracy}")

  wandb.log({
      "test_loss": test_loss,
      "test_top1_accuracy": test_accuracy,
      "test_top5_accuracy": test_top5_accuracy,
      "test_top2_accuracy": test_top2_accuracy
  })
  del cnn,train_loader, test_loader, val_loader
  gc.collect()
  torch.cuda.empty_cache()

  wandb.finish()

Overwriting train_local.py


### ArgParser

In [26]:
%%writefile argument_parser.py
import argparse

def parse_arguments():
    """Parse the command-line options of the training script.

    Returns:
        argparse.Namespace with the W&B settings, dataset location and all
        model / optimizer hyperparameters. The boolean switches
        ``isDataAug`` and ``isBatchNormalization`` are returned as the
        strings "True" / "False".
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("-wp", "--wandb_project", type=str, default="DA6401_Assignment2",
                        help="Project name used to track experiments in Weights & Biases dashboard")
    parser.add_argument("-we", "--wandb_entity", type=str, default="nikhithaa-iit-madras",
                        help="Wandb Entity used to track experiments in the Weights & Biases dashboard.")
    parser.add_argument("-bd", "--base_dir", type=str, default="inaturalist_12K",
                        help="Base directory where dataset (train/val folders) are present")
    parser.add_argument("-e", "--epochs", type=int, default=10,
                        help="Number of epochs to train neural network")
    parser.add_argument("-b", "--batch_size", type=int, default=32,
                        help="Batch size used to train neural network")
    parser.add_argument("-o", "--optimizer", type=str, choices=["sgd", "rmsprop", "adam"], default="sgd",
                        help="Choose one among these optimizers: ['sgd', 'rmsprop', 'adam']")
    parser.add_argument("-lr", "--learning_rate", type=float, default=0.001,
                        help="Learning rate used to optimize model parameters")
    parser.add_argument("-m", "--momentum", type=float, default=0.9,
                        help="Momentum used by momentum and nag optimizers")
    parser.add_argument("-beta", "--beta", type=float, default=0.9,
                        help="Beta used by rmsprop optimizer")
    parser.add_argument("-beta1", "--beta1", type=float, default=0.9,
                        help="Beta1 used by adam and nadam optimizers")
    parser.add_argument("-beta2", "--beta2", type=float, default=0.999,
                        help="Beta2 used by adam and nadam optimizers")
    parser.add_argument("-eps", "--epsilon", type=float, default=0.00000001,
                        help="Epsilon used by optimizers")
    parser.add_argument("-w_d", "--weight_decay", type=float, default=0.0001,
                        help="Weight decay used by optimizers")
    parser.add_argument("-dp", "--dropout", type=float, default=0.0,
                        help="Dropout used in convolution neural network")
    parser.add_argument("-da", "--isDataAug", type=str, default="False",
                        help="Whether to use data augmentation or not")
    parser.add_argument("-bn", "--isBatchNormalization", type=str, default="False",
                        help="Whether to use batch normalization or not")
    # The defaults of the next two options used to be swapped (3 filters per
    # layer with kernel sizes up to 256); they now match their help texts.
    parser.add_argument("-nf", "--num_filters", type=int, nargs=5,  default=[32, 64, 64, 128, 256],
                        help="Number of filters used in each convolution layer")
    parser.add_argument("-fsz", "--filter_sizes", type=int, nargs=5,  default=[3, 3, 3, 3, 3],
                        help="Size of filters in each convolution layer")
    parser.add_argument("-a", "--activation", type=str, choices=["ReLU", "SiLU", "GELU"], default="SiLU",
                        help="Choose one among these activation functions: ['ReLU', 'SiLU', 'GELU']")
    parser.add_argument("-ndl", "--n_neurons_denseLayer", type=int, default=128,
                        help="Number of neurons in dense layer")

    return parser.parse_args()

Overwriting argument_parser.py


### Main File

In [28]:
%%writefile main.py
from train_local import *
from argument_parser import *
import libraries

# Script entry point: parse the CLI flags, then run one training job with them.
if __name__=="__main__":
  args = parse_arguments()
  trainNeuralNetwork_local(args)

Overwriting main.py


## Running main

In [None]:
# change epochs
!python3 main.py -wp DA6401_Assignment2 -we nikhithaa-iit-madras -b 128 -beta1 0.9 -beta2 0.999 -lr 0.001 -e 1 --base_dir Dataset/inaturalist_12K -o sgd -a SiLU -w_d 0 -nf 3 3 3 3 3 -fsz 32 64 64 128 256 -ndl 128 -dp 0.2 -bn False -da True

## Training(Sweep)

In [30]:
%%writefile train_sweep.py
# import libraries
from neural_network import *
from data_loader import *
from accuracy_calculation import *
import numpy as np
import torch.nn as nn
import wandb
import gc
import os
import torch

# Module-level device shared by the sweep training function below: GPU when available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
# NOTE(review): this is set after torch has been imported; presumably it must be in
# place before the first CUDA allocation to have any effect - confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

def trainNeuralNetwork_sweep():
  """Train and evaluate one CNN for the current W&B sweep run.

  Reads every hyperparameter from ``wandb.config``, trains for the
  configured number of epochs while logging train/validation metrics per
  epoch, then logs the best validation accuracy and the final test
  metrics (top-1, top-2, top-5).
  """
  wandb.init(mode="online")
  args = wandb.config

  # Sweep configurations may supply these flags as the strings
  # "True"/"False". A non-empty string is always truthy and never equals
  # False, so they are parsed explicitly; real booleans also pass through
  # str() unchanged in meaning.
  isDataAug = str(args["isDataAug"]).strip().lower() == "true"
  isBatchNormalization = str(args["isBatchNormalization"]).strip().lower() == "true"

  train_loader, test_loader, val_loader, num_classes = load_data(args["base_dir"], isDataAug, args["batch_size"])
  activationFun = args["activation"]
  optimizer = args["optimizer"]
  learning_rate = args["learning_rate"]
  momentum = args["momentum"]
  beta = args["beta"]
  beta1 = args["beta1"]
  beta2 = args["beta2"]
  epsilon = args["epsilon"]
  weight_decay = args["weight_decay"]
  dropout = args["dropout"]
  num_filters = args["num_filters"]
  filter_sizes = args["filter_sizes"]
  n_neurons_denseLayer = args["n_neurons_denseLayer"]

  wandb.run.name = f"{activationFun}_{optimizer}_{dropout}_{n_neurons_denseLayer}_DataAug-{isDataAug}_BatchNorm-{isBatchNormalization}"
  best_val_accuracy = 0.0
  best_accuracy_epoch = -1

  cnn = ConvolutionalNeuralNetwork(num_classes,
                                   num_filters, filter_sizes,
                                   activationFun, optimizer,
                                   n_neurons_denseLayer,
                                   isBatchNormalization, dropout,
                                   learning_rate,
                                   momentum, beta,
                                   beta1, beta2,
                                   epsilon, weight_decay)
  cnn.to(device)

  epochs = args["epochs"]
  for epochNum in range(epochs):
    print(f"Epoch {epochNum}:")
    # The evaluation helper leaves the model in eval mode after each epoch;
    # switch back so dropout and batch norm behave correctly while training.
    cnn.train()
    for batch_idx, (x_batch, y_batch) in enumerate(train_loader):
      if batch_idx % 40 == 0:
        print(f"Batch idx {batch_idx} running")
      x_batch, y_batch = x_batch.to(device), y_batch.to(device)
      cnn.optimizer.zero_grad()
      outputs = cnn(x_batch)
      cnn.backward(outputs, y_batch)
      cnn.updateWeights()
      # Drop references early to keep peak GPU memory down.
      del x_batch, y_batch, outputs

    # Validation accuracy
    val_outputs, val_accuracy, val_loss, _, _ = findOutputs(device, cnn, val_loader)
    print(f"validation: loss={val_loss}, accuracy={val_accuracy}")

    # Train accuracy
    train_outputs, train_accuracy, train_loss, _, _ = findOutputs(device, cnn, train_loader)
    print(f"training: loss={train_loss}, accuracy={train_accuracy}")

    if val_accuracy > best_val_accuracy:
      best_val_accuracy = val_accuracy
      best_accuracy_epoch = epochNum

    wandb.log({
        "epoch": epochNum + 1,
        "val_loss": val_loss,
        "val_accuracy": val_accuracy,
        "train_loss": train_loss,
        "train_accuracy": train_accuracy
        },commit=True)
    del val_outputs, train_outputs
    gc.collect()
    torch.cuda.empty_cache()

  wandb.log({
      "best_acc_epoch": best_accuracy_epoch,
      "best_val_accuracy": best_val_accuracy
  })

  test_outputs, test_accuracy, test_loss, test_top5_accuracy, test_top2_accuracy = findOutputs(device, cnn, test_loader, True)
  print(f"testing: loss={test_loss}, top1_accuracy={test_accuracy}, top5_accuracy = {test_top5_accuracy}, top2_accuracy = {test_top2_accuracy}")

  wandb.log({
      "test_loss": test_loss,
      "test_top1_accuracy": test_accuracy,
      "test_top5_accuracy": test_top5_accuracy,
      "test_top2_accuracy": test_top2_accuracy
  })
  del cnn,train_loader, test_loader, val_loader
  gc.collect()
  torch.cuda.empty_cache()

  wandb.finish()

Overwriting train_sweep.py


### Main Sweep File

In [31]:
%%writefile main_sweep.py
from train_sweep import *
# import libraries
import wandb

# best_acc_sweep_configuration = {
#     "method": "random",
#     "name" : "test_sweep1",
#     "parameters": {
#         "num_filters": {'values': [[256, 128, 64, 64, 32]]},
#         "filter_sizes": {'values': [[3, 3, 3, 3, 3]]},
#         "activation": {"values": ["SiLU"]},
#         "optimizer": {"values": ["sgd"]},
#         "learning_rate": {"values": [1e-3]},
#         "weight_decay": {"values": [0.0001]},
#         "momentum": {"values": [0.9]},
#         "beta": {"values": [0.9]},
#         "beta1": {"values":[0.9]},
#         "beta2": {"values": [0.999]},
#         "epsilon": {"values": [1e-8]},
#         # "base_dir": {"values":["/content/drive/MyDrive/DL_Assignment2/Dataset/inaturalist_12K/"]},
#         "base_dir": {"values": ["/kaggle/input/inaturalist/inaturalist_12K"]},
#         "isDataAug": {"values": ["False"]},
#         "isBatchNormalization": {"values": ["False"]},
#         "dropout": {"values": [0.3]},
#         "n_neurons_denseLayer": {"values": [128]},
#         "batch_size": {"values": [32]},
#         "epochs": {"values": [10]}
#     }
# }

# Random-search sweep over architecture and training hyperparameters.
# Each key under "parameters" is read by name from wandb.config in the
# training function handed to the agent below.
sweep_configuration = {
    "method": "random",
    "name" : "train_sweep_final2_v1",
    "parameters": {
        # One entry per convolution layer (five layers in every option).
        "num_filters": {'values': [[32, 32, 32, 32, 32], [32, 64, 64, 128, 256], [256, 128, 64, 64, 32]]},
        "filter_sizes": {'values': [[3, 3, 3, 3, 3], [5, 5, 5, 5, 5],[3,3,5,5,1]]},
        "activation": {"values": ["ReLU", "SiLU", "GELU"]},
        "optimizer": {"values": ["adam", "rmsprop", "sgd"]},
        # Single-valued entries are held fixed across the sweep.
        "learning_rate": {"values": [1e-3]},
        "weight_decay": {"values": [0.0001]},
        "momentum": {"values": [0.9]},
        "beta": {"values": [0.9]},
        "beta1": {"values":[0.9]},
        "beta2": {"values": [0.999]},
        "epsilon": {"values": [1e-8]},
        "base_dir": {"values":["/content/drive/MyDrive/DL_Assignment2/Dataset/inaturalist_12K/"]},
        # "base_dir": {"values": ["/kaggle/input/inaturalist/inaturalist_12K"]},
        # NOTE: the two flags below are the strings "True"/"False", not Python
        # booleans; whatever consumes them must convert them explicitly.
        "isDataAug": {"values": ["False", "True"]},
        "isBatchNormalization": {"values": ["True", "False"]},
        "dropout": {"values": [0.2, 0.3]},
        "n_neurons_denseLayer": {"values": [128, 256]},
        "batch_size": {"values": [32,64]},
        "epochs": {"values": [5,10]}
    }
}

# Entry point: register the sweep in the W&B project and start an agent that
# calls the training function once per sampled configuration.
if __name__=="__main__":
  wandb.login()
  wandb_id = wandb.sweep(sweep_configuration, project="DA6401_Assignment2")
  wandb.agent(wandb_id, function=trainNeuralNetwork_sweep)

Overwriting main_sweep.py


### Running

In [None]:
!python3 main_sweep.py

## 10 x 3 Grid Visualization (Haven't written into file)

> Add blockquote



In [None]:
# Sweep configuration with exactly one value per parameter, i.e. a single
# fixed hyperparameter set, used for the prediction-grid visualization run.
best_acc_sweep_configuration1 = {
    "method": "random",
    "name" : "visualize_sweep_final",
    "parameters": {
        # One entry per convolution layer.
        "num_filters": {'values': [[256, 128, 64, 64, 32]]},
        "filter_sizes": {'values': [[3, 3, 3, 3, 3]]},
        "activation": {"values": ["SiLU"]},
        "optimizer": {"values": ["sgd"]},
        "learning_rate": {"values": [1e-3]},
        "weight_decay": {"values": [0.0001]},
        "momentum": {"values": [0.9]},
        "beta": {"values": [0.9]},
        "beta1": {"values":[0.9]},
        "beta2": {"values": [0.999]},
        "epsilon": {"values": [1e-8]},
        # "base_dir": {"values":["/content/drive/MyDrive/DL_Assignment2/Dataset/inaturalist_12K/"]},
        "base_dir": {"values": ["/kaggle/input/inaturalist/inaturalist_12K"]},
        # NOTE: the two flags below are the strings "True"/"False", not Python
        # booleans; whatever consumes them must convert them explicitly.
        "isDataAug": {"values": ["False"]},
        "isBatchNormalization": {"values": ["False"]},
        "dropout": {"values": [0.3]},
        "n_neurons_denseLayer": {"values": [128]},
        "batch_size": {"values": [32]},
        "epochs": {"values": [10]}
    }
}

def unnormalize(img):
    """Undo the per-channel mean/std normalization of an image tensor.

    Args:
        img: channel-first image tensor whose channels were normalized with
            the mean/std constants below.

    Returns:
        A new tensor with each channel mapped back via ``x * std + mean``
        and clamped to [0, 1]. The input tensor is left untouched.
    """
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    # Work on a copy: the in-place ops below would otherwise modify the
    # caller's tensor (and any batch it is a view of).
    restored = img.clone()
    for channel, m, s in zip(restored, mean, std):
        channel.mul_(s).add_(m)
    return torch.clamp(restored, 0, 1)

def visualizeOutputs(cnn, testDataLoader):
  """Evaluate on the test set and log a grid of sample predictions.

  Computes loss, top-1 and top-2 accuracy over ``testDataLoader`` and, along
  the way, keeps the first three images seen for each of the ten classes.
  Those images are drawn in a 10 x 3 grid titled with true and predicted
  class names; the figure is logged to W&B under the key "Grid" and shown.

  Args:
    cnn: trained model; it is switched to evaluation mode.
    testDataLoader: iterable yielding (inputs, labels) batches.

  Returns:
    (accuracy, avg_loss, top2_accuracy), accuracies as percentages.
  """
  labelNameList = ["Amphibia", "Animalia", "Arachnida", "Aves", "Fungi", "Insecta", "Mammalia", "Mollusca", "Plantae", "Reptilia"]
  n_rows = len(labelNameList)
  n_cols = 3   # images shown per class

  cnn.eval()  # evaluation mode
  criterion = nn.CrossEntropyLoss()
  total_loss = 0.0
  n_correct = 0
  n_correct_top2 = 0
  n_samples = 0

  pred_output_perLabel = [[] for _ in range(n_rows)]
  x_values_perLabel = [[] for _ in range(n_rows)]

  with torch.no_grad():
    for x_batch, y_batch in testDataLoader:
      x_batch, y_batch = x_batch.to(device), y_batch.to(device)
      batch_outputs = cnn(x_batch)

      loss = criterion(batch_outputs, y_batch)
      total_loss += loss.item() * x_batch.size(0)

      y_pred_batch = torch.argmax(batch_outputs, dim=1)
      n_correct += (y_pred_batch == y_batch).sum().item()
      n_samples += x_batch.size(0)

      y_pred_batch_top2 = torch.topk(batch_outputs, 2, dim=1).indices
      n_correct_top2 += y_pred_batch_top2.eq(y_batch.view(-1, 1)).sum().item()

      # Collecting up to n_cols images per class
      for i in range(len(y_batch)):
        label = y_batch[i].item()
        if len(pred_output_perLabel[label]) < n_cols:
          # clone() detaches the image from the batch storage, so the batch
          # can be freed and later edits to the image cannot touch it.
          x_values_perLabel[label].append(x_batch[i].detach().cpu().clone())
          pred_output_perLabel[label].append(y_pred_batch[i].item())

  # Visualizing collected results
  fig, axes = plt.subplots(n_rows, n_cols, figsize=(10, 20))
  # A figure-level title: plt.title would only set the title of the last
  # axes, which the per-image titles below then overwrite.
  fig.suptitle("True Vs Pred results for each class label")
  for i in range(n_rows):
    for j in range(n_cols):
      ax = axes[i, j]
      ax.axis('off')
      if j >= len(pred_output_perLabel[i]):
        # Fewer than n_cols samples of this class were seen: leave blank.
        continue
      predLabel = pred_output_perLabel[i][j]
      trueLabel = i
      print(f"True = {trueLabel}, Pred = {predLabel}")
      image = unnormalize(x_values_perLabel[i][j])
      # imshow expects channel-last images.
      ax.imshow(image.permute(1, 2, 0))
      ax.set_title(f'TrueLabel: {labelNameList[trueLabel]}\nPredLabel: {labelNameList[predLabel]}')
  # Reserve a strip at the top so the figure title does not overlap the grid.
  plt.tight_layout(rect=[0, 0, 1, 0.98])
  wandb.log({"Grid": wandb.Image(fig)})
  plt.show()

  accuracy = (n_correct * 100.0) / n_samples
  avg_loss = total_loss / n_samples
  top2_accuracy = (n_correct_top2 * 100.0) / n_samples
  return accuracy, avg_loss, top2_accuracy

def visualization_sweep():
  """Train one CNN from the sweep config, then visualize test predictions.

  Reads every hyperparameter from ``wandb.config``, trains for the
  configured number of epochs (without per-epoch evaluation), runs the
  test-set visualization and logs the resulting test metrics to W&B.
  """
  wandb.init(mode="online")
  args = wandb.config

  # Sweep configurations may supply these flags as the strings
  # "True"/"False". A non-empty string is always truthy and never equals
  # False, so they are parsed explicitly; real booleans also pass through
  # str() unchanged in meaning.
  isDataAug = str(args["isDataAug"]).strip().lower() == "true"
  isBatchNormalization = str(args["isBatchNormalization"]).strip().lower() == "true"

  train_loader, test_loader, val_loader, num_classes = load_data(args["base_dir"], isDataAug, args["batch_size"])
  activationFun = args["activation"]
  optimizer = args["optimizer"]
  learning_rate = args["learning_rate"]
  momentum = args["momentum"]
  beta = args["beta"]
  beta1 = args["beta1"]
  beta2 = args["beta2"]
  epsilon = args["epsilon"]
  weight_decay = args["weight_decay"]
  dropout = args["dropout"]
  num_filters = args["num_filters"]
  filter_sizes = args["filter_sizes"]
  n_neurons_denseLayer = args["n_neurons_denseLayer"]

  wandb.run.name = f"{activationFun}_{optimizer}_{dropout}_{n_neurons_denseLayer}_DataAug-{isDataAug}_BatchNorm-{isBatchNormalization}"

  cnn = ConvolutionalNeuralNetwork(num_classes,
                                   num_filters, filter_sizes,
                                   activationFun, optimizer,
                                   n_neurons_denseLayer,
                                   isBatchNormalization, dropout,
                                   learning_rate,
                                   momentum, beta,
                                   beta1, beta2,
                                   epsilon, weight_decay)
  cnn.to(device)

  epochs = args["epochs"]
  for epochNum in range(epochs):
    print(f"Epoch {epochNum}:")
    for batch_idx, (x_batch, y_batch) in enumerate(train_loader):
      if batch_idx % 40 == 0:
        print(f"Batch idx {batch_idx} running")
      x_batch, y_batch = x_batch.to(device), y_batch.to(device)
      cnn.optimizer.zero_grad()
      outputs = cnn(x_batch)
      cnn.backward(outputs, y_batch)
      cnn.updateWeights()
      # Drop references early to keep peak GPU memory down.
      del x_batch, y_batch, outputs

  test_accuracy, test_loss, test_top2_accuracy = visualizeOutputs(cnn, test_loader)
  print(f"testing: loss={test_loss}, top1_accuracy={test_accuracy}, top2_accuracy = {test_top2_accuracy}")

  wandb.log({
      "test_loss": test_loss,
      "test_top1_accuracy": test_accuracy,
      "test_top2_accuracy": test_top2_accuracy
  })

  del cnn,train_loader, test_loader, val_loader
  gc.collect()
  torch.cuda.empty_cache()

  wandb.finish()

# NOTE(review): the API key is passed as a literal; "x" looks like a redacted
# placeholder - confirm, and avoid committing a real key in the notebook.
wandb.login(key="x")
# Register the single-configuration sweep and run the visualization function through an agent.
wandb_id = wandb.sweep(best_acc_sweep_configuration1, project="DA6401_Assignment2")
wandb.agent(wandb_id, function=visualization_sweep)