# Lecture 24a: Gradient Descent Learning Rule
### Updating parameters once every epoch

## Load Packages

In [None]:
%matplotlib inline
import time
import torch
import numpy as np
import torchvision
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

# Show the installed PyTorch version for reference; this notebook was last
# updated for PyTorch 1.0.0.
print(torch.__version__)

## Load Data

In [None]:
# The only preprocessing step is converting each image to a tensor.
transform = transforms.Compose([transforms.ToTensor()])
BatchSize = 100

# Training split (stored under ./MNIST, fetched with download=True) and its
# shuffled loader.
trainset = datasets.MNIST(
    root='./MNIST',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=BatchSize,
    shuffle=True,
    num_workers=4,
)

# Test split and its loader; iteration order is kept fixed (shuffle=False).
testset = datasets.MNIST(
    root='./MNIST',
    train=False,
    download=True,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=BatchSize,
    shuffle=False,
    num_workers=4,
)

In [None]:
# Pick the compute device: CUDA when PyTorch reports one, otherwise the CPU.
use_gpu = torch.cuda.is_available()
device = "cuda" if use_gpu else "cpu"
print('GPU is available!' if use_gpu else 'GPU is not available!')

## Neural Network

In [None]:
class NeuralNet(nn.Module):
    """Fully connected classifier mapping 28*28 inputs to 10 output scores.

    ``Layer1`` holds two hidden linear layers (784 -> 400 -> 256), each
    followed by a ReLU; ``Layer2`` is the final 256 -> 10 linear layer.
    """

    def __init__(self):
        super().__init__()
        # Hidden stack: Linear -> ReLU -> Linear -> ReLU.
        hidden_stack = [
            nn.Linear(28 * 28, 400),
            nn.ReLU(),
            nn.Linear(400, 256),
            nn.ReLU(),
        ]
        self.Layer1 = nn.Sequential(*hidden_stack)
        # Output layer producing one raw score per class (no activation).
        self.Layer2 = nn.Sequential(nn.Linear(256, 10))

    def forward(self, x):
        """Return the 10 output scores for inputs whose last dimension is 784."""
        return self.Layer2(self.Layer1(x))

In [None]:
# Build the classifier and move its parameters to the selected device.
net = NeuralNet().to(device)

## Train Classifier

In [None]:
# Full-batch gradient descent: the parameters are updated exactly once per
# epoch, using the gradient of the mean cross-entropy loss over all training
# batches. After each update the model is evaluated on the test set.
iterations = 10       # number of epochs (= number of parameter updates)
learning_rate = 0.1
criterion = nn.CrossEntropyLoss()

Plotacc = []    # test accuracy (in %) recorded after every epoch
trainLoss = []  # mean training loss recorded for every epoch

num_batches = len(trainloader)
if num_batches == 0:
    raise ValueError('trainloader yields no batches; nothing to train on')

for epoch in range(iterations):  # loop over the dataset multiple times
    start = time.time()
    correct = 0
    total = 0

    net.train()  # For training
    net.zero_grad()  # start the epoch with empty gradient buffers
    totalLoss = torch.zeros((), device=device)  # graph-free running loss sum
    for inputs, labels in trainloader:
        inputs = inputs.view(-1, 28*28).to(device)
        labels = labels.to(device)

        outputs = net(inputs)  # forward
        loss = criterion(outputs, labels)  # calculate loss

        # Back-propagate each batch immediately, scaled by 1/num_batches.
        # Gradients accumulate in .grad, so after the last batch they equal
        # the gradient of the mean loss, while every batch's autograd graph
        # is released right away instead of being kept for the whole epoch.
        (loss / num_batches).backward()
        totalLoss += loss.detach()

    totalLoss = totalLoss / num_batches  # mean loss over the epoch

    # updating parameters once in every epoch
    with torch.no_grad():  # the update itself must not be tracked by autograd
        for f in net.parameters():
            f -= learning_rate * f.grad  # weight = weight - learning_rate * gradient

    trainLoss.append(totalLoss.item())

    net.eval()  # For testing [Affects batch-norm and dropout layers (if any)]
    with torch.no_grad():  # Gradient computation is not involved in inference
        for inputs, labels in testloader:
            inputs = inputs.view(-1, 28*28).to(device)
            labels = labels.to(device)
            total += labels.size(0)

            outputs = net(inputs)
            _, predicted = torch.max(outputs, 1)  # index of the highest score
            correct += (predicted == labels).sum().item()

    accuracy = float(correct)*100/float(total)
    Plotacc.append(accuracy)
    epochTimeEnd = time.time()-start
    print('At Epoch {:.0f}: Loss = {:.6f} , Acc = {:.4f}%'.format(epoch+1,totalLoss.item(),accuracy))
    print('Epoch completed in {:.0f}m {:.0f}s'.format(epochTimeEnd // 60, epochTimeEnd % 60))
    
# Plot the per-epoch training loss and test accuracy. The x-axis is derived
# from the recorded histories rather than from the training loop's leftover
# `epoch` variable, so this still works when no epoch ran or when the cell is
# re-run on its own.
fig = plt.figure()
plt.plot(range(len(trainLoss)), trainLoss, 'r-', label='Cross Entropy Loss')
plt.legend(loc='best')
plt.xlabel('Epochs')
plt.ylabel('Training Loss')

fig = plt.figure()
plt.plot(range(len(Plotacc)), Plotacc, 'g-', label='Accuracy')
plt.legend(loc='best')
plt.xlabel('Epochs')
plt.ylabel('Testing Accuracy')
print('Finished Training')
