# LIBRARIES

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
from torch import optim
import torchmetrics

import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Weights & Biases client, used below for metric logging and hyperparameter sweeps.
import wandb
# Authenticate this session with W&B (may prompt for an API key if none is configured).
wandb.login()

In [None]:
# Run on the current CUDA device when PyTorch can see a GPU; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# DATASET

In [None]:

# FashionMNIST train/test splits. `ToTensor()` is the only transform, so pixel
# values are NOT mean/std-normalized. Files are downloaded into ./data if missing.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

# Reshuffle the training set every epoch so the model does not see the same
# fixed batch order each time; evaluation order is irrelevant, so the test
# loader stays sequential.
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64)


# VISUALIZE TRAIN IMAGES

In [None]:
import torchvision

def imshow(inp, title=None, ax=None, figsize=(5, 5), mean=None, std=None):
    """Display an image tensor of shape (C, H, W) on a Matplotlib axis.

    The tensor is detached, moved to the CPU, converted to (H, W, C) and clipped
    to [0, 1] before drawing. If the images were normalized, pass the per-channel
    `std` and `mean` that were used so the normalization is undone first
    (e.g. ImageNet's mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)).
    Tensors produced by plain `ToTensor()` are already in [0, 1] and need neither;
    un-normalizing them (as this function previously always did) distorts them.

    Args:
        inp: image tensor of shape (C, H, W), with C == 1 or C == 3.
        title: optional axis title.
        ax: Matplotlib axis to draw on; a new figure is created when None.
        figsize: figure size, used only when `ax` is None.
        mean: per-channel mean added back after scaling by `std`, or None.
        std: per-channel standard deviation to scale by, or None.
    """
    img = inp.detach().cpu().numpy().transpose((1, 2, 0))
    if std is not None:
        img = img * np.asarray(std)
    if mean is not None:
        img = img + np.asarray(mean)
    img = np.clip(img, 0, 1)

    # Matplotlib wants (H, W) rather than (H, W, 1) for single-channel images.
    cmap = None
    if img.shape[-1] == 1:
        img = img[..., 0]
        cmap = "gray"

    if ax is None:
        _, ax = plt.subplots(1, figsize=figsize)
    ax.imshow(img, cmap=cmap)
    ax.set_xticks([])
    ax.set_yticks([])
    if title is not None:
        ax.set_title(title)

# Pull one mini-batch of images (and their labels) from the training loader.
inputs, classes = next(iter(train_dataloader))

# Tile the whole batch into a single image, four thumbnails per row.
out = torchvision.utils.make_grid(inputs, nrow=4)

# Draw the tiled batch on one large, tick-free axis.
fig, ax = plt.subplots(1, figsize=(10, 10))
imshow(out, ax=ax)

# LINEAR NEURAL NET

In [None]:
class NeuralNetwork(nn.Module):
    """Fully connected classifier for 28x28 single-channel images.

    Each sample is flattened to 784 features, passed through two hidden
    Linear(512) -> ReLU -> Dropout stages, and mapped by a final Linear layer
    to raw, unbounded class logits (suitable for `nn.CrossEntropyLoss`).

    Args:
        dropout: dropout probability used after each hidden layer.
        num_classes: number of output logits.
    """

    def __init__(self, dropout=0.25, num_classes=10):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            # No activation after the output layer: a trailing ReLU would clamp
            # every logit to >= 0, which cripples softmax/cross-entropy.
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for input `x`."""
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

# Build the MLP and move its parameters to `device` (Module.to returns the module itself).
model = NeuralNetwork().to(device)
model  # notebook cell output: show the module structure

# CONVOLUTIONAL NEURAL NETWORK

class ConvNet(nn.Module):
    """Three-stage convolutional classifier for small images such as FashionMNIST.

    Each stage is Conv2d (3x3, no padding) -> ReLU -> MaxPool2d(2, 2). The pooled
    feature map is flattened, then passed through dropout, a hidden fully
    connected layer with ReLU, dropout again, and a final linear layer that
    returns raw class logits.

    Args:
        input_shape: (channels, height, width) of one input image. Sets conv1's
            input channels and is used to size fc1.
        dropout: dropout probability applied before each fully connected layer.
        fc_layer_size: width of the hidden fully connected layer.
        num_classes: number of output logits.
    """

    def __init__(self, input_shape=(1, 28, 28), dropout=0.25, fc_layer_size=512, num_classes=10):
        super().__init__()

        # Convolution layers (3x3 kernels, stride 1, no padding).
        self.conv1 = nn.Conv2d(input_shape[0], 32, 3)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.conv3 = nn.Conv2d(64, 128, 3)

        # Pooling layer shared by all three stages; halves H and W (rounding down).
        self.pool = nn.MaxPool2d(2, 2)

        # Flattened conv-output size for one image, found by a dry run so fc1
        # stays correct if input_shape or the conv stack changes.
        n_size = self._get_conv_output(input_shape)

        # Fully connected classifier head.
        self.fc1 = nn.Linear(n_size, fc_layer_size)
        self.fc2 = nn.Linear(fc_layer_size, num_classes)

        # Dropout applied before fc1 and before fc2.
        self.dropout = nn.Dropout(p=dropout)

    def _get_conv_output(self, shape):
        """Return the number of features the conv stack yields for one (C, H, W) input."""
        # Only the output shape matters, so skip autograd and use a zero tensor.
        with torch.no_grad():
            dummy = torch.zeros(1, *shape)
            features = self._forward_features(dummy)
        return features.flatten(1).size(1)

    def _forward_features(self, X):
        """Apply the three conv -> ReLU -> max-pool stages."""
        X = self.pool(F.relu(self.conv1(X)))
        X = self.pool(F.relu(self.conv2(X)))
        X = self.pool(F.relu(self.conv3(X)))
        return X

    def forward(self, X):
        """Return logits of shape (batch, num_classes) for a (batch, C, H, W) input."""
        X = self._forward_features(X)
        X = torch.flatten(X, 1)
        X = self.dropout(X)
        X = F.relu(self.fc1(X))
        X = self.dropout(X)
        return self.fc2(X)
        
# Build the CNN (rebinding `model`, which previously held the MLP) and move it
# to `device`, which is CUDA only when a GPU is available.
model = ConvNet().to(device)
print(model)

# HYPERPARAMETERS

In [None]:
# Run configuration passed to wandb.init and logged with each run. It holds
# plain, serializable values (names rather than live objects). Building an
# optimizer here would also leave a second, unused optimizer over the model
# whose lr contradicts `learning_rate`. Keep these names in sync with the loss
# and optimizer actually constructed for training.
config = dict(
    learning_rate=1e-3,
    batch_size=64,
    epochs=10,
    loss="CrossEntropyLoss",
    optimizer="adagrad",
)

In [None]:
# Echo the run configuration, one "key: value" line per entry.
for key in config:
    print(f"{key}: {config[key]}")

# TRAINING

In [None]:
def train_loop(dataloader, model, loss_fn, optimizer):
    """Train `model` for one epoch over `dataloader`, then print and log epoch metrics.

    Every 100 batches the current batch loss is printed. After the epoch, the
    mean per-batch loss and the accuracy over all samples seen this epoch are
    printed and logged to W&B as "Train Loss" and "Train Accuracy" (Python
    floats, not tensors). Batches are moved to the module-level `device`.

    Accuracy is computed with `argmax` over the class dimension rather than
    `torchmetrics.functional.accuracy`, whose newer releases require a `task=`
    argument.

    Args:
        dataloader: DataLoader yielding (inputs, targets) batches.
        model: network returning one row of class scores per sample.
        loss_fn: callable `(pred, targets) -> scalar loss tensor`.
        optimizer: optimizer over `model`'s parameters.

    Returns:
        (mean_loss, accuracy) for the epoch.
    """
    size = len(dataloader.dataset)
    # Dropout must be active while training; the model may be in eval mode.
    model.train()

    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    num_batches = 0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate whole-epoch statistics instead of keeping only the last batch.
        batch_loss = loss.item()
        total_loss += batch_loss
        total_correct += (pred.argmax(1) == y).sum().item()
        total_seen += len(y)
        num_batches += 1

        if batch % 100 == 0:
            current = batch * len(X)
            print(f"train loss: {batch_loss:>3f}  [{current:>5d}/{size:>5d}]")

    # max(..., 1) avoids a ZeroDivisionError on an empty dataloader.
    mean_loss = total_loss / max(num_batches, 1)
    accuracy = total_correct / max(total_seen, 1)
    print(f"Train Metrics: \n Train Accuracy: {(100*accuracy):>0.1f}%, Train Loss: {mean_loss:>8f} \n")
    wandb.log({"Train Accuracy": accuracy, "Train Loss": mean_loss})
    return mean_loss, accuracy


def test_loop(dataloader, model, loss_fn):
    """Evaluate `model` on `dataloader`, then print and log the results to W&B.

    Computes the mean per-batch loss and the accuracy over every sample in
    `dataloader.dataset`. Dropout is disabled and no autograd graph is built.
    Results are logged as "Test Accuracy" and "Test Loss" after averaging. The
    model's previous train/eval mode is restored afterwards. Batches are moved
    to the module-level `device`.

    Args:
        dataloader: DataLoader yielding (inputs, targets) batches.
        model: network returning one row of class scores per sample.
        loss_fn: callable `(pred, targets) -> scalar loss tensor`.

    Returns:
        (mean_loss, accuracy) as Python floats.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0.0, 0

    was_training = model.training
    model.eval()  # evaluate without dropout
    try:
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(device), y.to(device)
                pred = model(X)
                test_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).sum().item()
    finally:
        model.train(was_training)

    # Average before logging; max(..., 1) guards against an empty dataloader.
    test_loss /= max(num_batches, 1)
    accuracy = correct / max(size, 1)
    print(f"Test Metrics: \n Test Accuracy: {(100*accuracy):>0.1f}%, Test Loss: {test_loss:>8f} \n")
    wandb.log({"Test Accuracy": accuracy, "Test Loss": test_loss})
    return test_loss, accuracy

In [None]:
# Loss function and optimizer used by `train()`.
# CrossEntropyLoss takes raw logits plus class-index targets.
loss_fn = nn.CrossEntropyLoss()

# Adagrad over the current `model`'s parameters, with the learning rate from `config`.
optimizer = torch.optim.Adagrad(
    model.parameters(),
    lr=config.get("learning_rate"),
)

print(loss_fn, optimizer)


In [None]:
%%time
def train():
    # train with wandb sweeps
    with wandb.init(project="demo_wandb_fashionmnsit_test", config=config):
        
        wandb.watch(model, criterion=loss_fn, log="all", log_freq=10)
        
        epochs = 10
        for t in range(epochs):

            print(f"Epoch {t+1}\n-------------------------------")
            train_loop(train_dataloader, model, loss_fn, optimizer)
            test_loop(test_dataloader, model, loss_fn)
        wandb.save(torch.save(model.state_dict(), f=f"./models/fashion_mnist_{t}.pt"))
    print("Done!")


# HYPERPARAMETER OPTIMIZATION
## SWEEPS

In [None]:
# Register a small demo grid sweep. W&B expects `metric` to be a mapping that
# names a logged metric ("Test Accuracy" is logged by test_loop) and a goal.
sweep_id = wandb.sweep({'name': 'my-awesome-sweep',
                        'metric': {'name': 'Test Accuracy', 'goal': 'maximize'},
                        'method': 'grid',
                        'parameters': {'epochs': {'values': [1, 2]}}},
                       project="demo_wandb_fashionmnsit_test")

# Run the sweep: each run calls `train()`. The swept values are exposed to it
# through `wandb.config`.
wandb.agent(sweep_id, function=train)

In [None]:
# Search space for the W&B hyperparameter sweep registered below.
sweep_config = {
    # 'grid' enumerates all 3*4*3*6*3*3 = 1944 combinations in a fixed order, so
    # an agent limited to a handful of runs only ever tries the first few.
    # 'random' samples across the whole space instead.
    'method': 'random',  # grid, random, bayes
    # A sweep optimises exactly one metric (a second 'metric' key would silently
    # replace the first). Its name must match a key passed to wandb.log;
    # test_loop logs "Test Accuracy".
    'metric': {
        'name': 'Test Accuracy',
        'goal': 'maximize',
    },
    'parameters': {
        'epochs': {
            'values': [2, 5, 10],
        },
        'batch_size': {
            'values': [64, 32, 16, 8],
        },
        'dropout': {
            'values': [0.1, 0.3, 0.5],
        },
        'learning_rate': {
            'values': [1e-2, 1e-3, 1e-4, 3e-4, 3e-5, 1e-5],
        },
        'fc_layer_size': {
            'values': [128, 256, 512],
        },
        'optimizer': {
            'values': ['adam', 'sgd', 'adagrad'],
        },
    },
}

In [None]:
# Register the sweep described by `sweep_config` with W&B; the returned id is
# what `wandb.agent` uses to fetch parameter sets for each run.
sweep_id = wandb.sweep(sweep_config, project="demo_wandb_fashionmnsit_test")

In [None]:
# Start an agent for this sweep that executes at most five runs, each calling `train()`.
wandb.agent(sweep_id, function=train, count=5)