In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
from torch.autograd import Variable
import torch.optim
import torch.optim.lr_scheduler as lr_scheduler


import numpy as np
import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt

import dataset
import model_store

# IPython shell escape (only valid inside a notebook cell): install the
# Weights & Biases client, quietly (-q) and upgrading to the latest release (-U).
!pip install wandb -qU
import wandb
# Authenticate this session with Weights & Biases before any run is started.
wandb.login()

# Candidate values for each hyper-parameter of the comparison experiments.
# Optimizer names understood by the training script.
OPTIMIZERS = [
    "SGD",
    "Momentum SGD",
    "Adam",
]
# Learning rates, one decade apart from 1e-5 up to 1e-1.
LR = [0.00001, 0.0001, 0.001, 0.01, 0.1]
# Mini-batch sizes: powers of two from 4 to 128.
BATCH_SIZE = [2 ** exponent for exponent in range(2, 8)]
# Momentum coefficients for SGD.
MOMENTUM = [0.0, 0.5, 0.9]
# Learning-rate schedules; "Vanilla" means a constant learning rate.
SCHEDULER = [
    "Vanilla",
    "LinearLR",
    "LambdaLR",
    "StepLR",
    "CosineAnnealingLR",
]

def _build_optimizer(model, hparams):
    """Create the optimizer selected by ``hparams["optimizer"]``.

    Any name containing "SGD" maps to ``torch.optim.SGD`` (with
    ``hparams["momentum"]``); "Adam" maps to ``torch.optim.Adam``. Both use
    ``hparams["lr"]`` as the learning rate.

    Raises:
        ValueError: if the optimizer name is not recognised, instead of
            silently reusing an optimizer left over from a previous run.
    """
    name = hparams["optimizer"]
    if "SGD" in name:
        return torch.optim.SGD(model.parameters(), lr=hparams["lr"], momentum=hparams["momentum"])
    if name == "Adam":
        return torch.optim.Adam(model.parameters(), lr=hparams["lr"])
    raise ValueError(f"Unsupported optimizer: {name!r}")


def _build_scheduler(optimizer, name):
    """Create the learning-rate scheduler called ``name`` for ``optimizer``.

    Returns ``None`` for "Vanilla" (constant learning rate).

    Raises:
        ValueError: if the scheduler name is not recognised.
    """
    if name == "Vanilla":
        return None
    if name == "LambdaLR":
        # Multiply the base learning rate by 0.65 for every scheduler step.
        return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 0.65 ** epoch)
    if name == "StepLR":
        return torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)
    # NOTE(review): the sweep below steps the scheduler once per epoch for
    # 5 epochs, so T_max=10 and total_iters=30 are never reached -- confirm
    # that truncated schedules are intended.
    if name == "CosineAnnealingLR":
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0)
    if name == "LinearLR":
        return torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.5, total_iters=30)
    raise ValueError(f"Unsupported scheduler: {name!r}")


def _evaluate(model, loader, device):
    """Return the top-1 accuracy (in percent) of ``model`` over ``loader``.

    The model is put in eval mode for the pass and restored to its previous
    mode afterwards, so layers that behave differently at train time do not
    distort the measurement (assumes ``model`` is an ``nn.Module``).
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for eval_images, eval_labels in loader:
            eval_images, eval_labels = eval_images.to(device), eval_labels.to(device)
            pred_label = model(eval_images).argmax(dim=1)
            total += eval_labels.size(0)
            correct += (pred_label == eval_labels).sum().item()
    if was_training:
        model.train()
    return 100 * correct / total


# Train one model per (scheduler, learning rate) pair and log each as its own
# Weights & Biases run.
for scheduler in SCHEDULER:
    for lr in LR:

        # store hyper parameters
        hparams = {
            "model": "CNN",
            # Key spelling kept as-is: it is read below to load the dataset and
            # the whole dict is handed to model_store.get_model.
            "detaset": "CIFAR-10",
            "optimizer": "SGD",
            "momentum": 0.,
            "epochs": 5,
            "train_batch_size": 8,
            "eval_batch_size": 32,
            "lr": lr,
            "checkpoint": 100,  # evaluate on the test set every this many batches
            "scheduler": scheduler
        }

        if hparams["optimizer"] == "Momentum SGD":
            hparams["momentum"] = 0.9

        train_bs = hparams["train_batch_size"]
        optim = hparams["optimizer"]
        run_id = f"{optim}_{lr}_{train_bs}_{scheduler}"

        # Using the run as a context manager finishes it when the block exits
        # (and marks it as failed on an exception), so consecutive sweep
        # configurations never share a run.
        with wandb.init(config=hparams,
                        project="CNN_lrscheduler_comparison_sgd__v2",
                        entity="dsa4212-project",
                        name=run_id,
                        ):

            # available GPU
            device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

            # get dataset
            train, test = dataset.get_dataset(hparams["detaset"])
            train_loader = torch.utils.data.DataLoader(
                train,
                batch_size=hparams["train_batch_size"],
                shuffle=True,
                num_workers=2)

            test_loader = torch.utils.data.DataLoader(
                test,
                batch_size=hparams["eval_batch_size"],
                shuffle=False,
                num_workers=2)

            hparams["n_classes"] = len(train.classes)
            hparams["input_shape"] = train[0][0].shape

            # get model
            model = model_store.get_model(hparams).to(device)

            # get optimizer and (optional) learning-rate scheduler
            optimizer = _build_optimizer(model, hparams)
            scheduler_inst = _build_scheduler(optimizer, hparams["scheduler"])

            criterion = nn.CrossEntropyLoss()

            print("Training Started")
            # training phase
            # Number of batches actually produced per epoch, so the fractional
            # "epoch" metric hits an integer exactly at each epoch boundary.
            steps_per_epoch = len(train_loader)
            last_batch = steps_per_epoch - 1
            steps = 0
            for epoch in range(hparams["epochs"]):

                for i, (images, labels) in enumerate(train_loader):
                    steps += 1
                    results = {
                        'step': steps,
                        'epoch': steps / steps_per_epoch,
                        # learning rate in effect for this update
                        'lr': optimizer.param_groups[0]['lr'],
                    }

                    images, labels = images.to(device), labels.to(device)

                    optimizer.zero_grad()
                    loss = criterion(model(images), labels)
                    # CrossEntropyLoss already averages over the batch
                    # (reduction='mean'), so no further division is needed.
                    results['loss'] = loss.item()
                    loss.backward()
                    optimizer.step()

                    # Periodic test-set evaluation, plus once at the end of
                    # every epoch.
                    if i % hparams["checkpoint"] == 0 or i == last_batch:
                        results['accuracy'] = _evaluate(model, test_loader, device)
                    wandb.log(results)

                # Schedulers are stepped once per epoch; "Vanilla" has none.
                if scheduler_inst is not None:
                    scheduler_inst.step()

            print("Training Done")