<a href="https://colab.research.google.com/github/jyanivaddi/ERA_V1/blob/master/session_6/S6.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchsummary import summary
from s6_model import Net, model_summary, model_train, model_test
from s6_utils import load_mnist_data, preview_batch_images, plot_statistics

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

# Build an instance of the network purely to print its architecture/parameter
# summary for single-channel 28x28 inputs.
mnist_model = Net().to(device)
model_summary(mnist_model, input_size=(1, 28, 28))

In [None]:
# Normalization constants (mean, std) applied to both splits; these are the
# values commonly quoted for the MNIST training set.
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)

# Training-time augmentation: occasionally (p=0.1) centre-crop to 22x22, always
# resize back to 28x28, then rotate by a random angle in [-15, 15] degrees
# (filling exposed pixels with 0) before tensor conversion and normalization.
train_transforms = transforms.Compose(
    [
        transforms.RandomApply([transforms.CenterCrop(22)], p=0.1),
        transforms.Resize((28, 28)),
        transforms.RandomRotation([-15.0, 15], fill=0),
        transforms.ToTensor(),
        transforms.Normalize(MNIST_MEAN, MNIST_STD),
    ]
)

# Evaluation uses no augmentation: only tensor conversion and the same normalization.
test_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(MNIST_MEAN, MNIST_STD),
    ]
)

In [None]:
# Obtain the MNIST train/test datasets wrapped with the transform pipelines above.
(train_data, test_data) = load_mnist_data(train_transforms, test_transforms)

In [None]:
# Fix the global RNG so shuffling/augmentation are repeatable between runs.
torch.manual_seed(1)
batch_size = 128

# Extra loader workers and pinned host memory only help when batches are
# copied to a GPU, so they are enabled only under CUDA.
kwargs = {'num_workers': 2, 'pin_memory': True} if use_cuda else {}

# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, **kwargs)
# The test set is only evaluated, so it is iterated in a fixed order; shuffling
# it adds nothing and makes per-batch inspection non-reproducible.
# NOTE(review): assumes model_test aggregates metrics over the whole set, so
# order does not affect the reported loss/accuracy — confirm in s6_model.
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, **kwargs)


In [None]:
# Visual sanity check of the training pipeline: presumably renders images from
# one batch of the (augmented) train loader — behavior lives in s6_utils.
preview_batch_images(train_loader)

In [None]:
# Fresh network for training, optimised with SGD plus momentum.
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# Multiply the learning rate by 0.1 every 15 scheduler steps; the loop below
# steps the scheduler once per epoch.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1, verbose=True)

# Metric histories handed to model_train / model_test, which presumably append
# to them in place (see s6_model).
train_losses, test_losses = [], []
train_acc, test_acc = [], []

num_epochs = 19  # epochs are numbered 1..19 inclusive
for epoch in range(1, num_epochs + 1):
    model_train(model, device, train_loader, optimizer, train_acc, train_losses)
    model_test(model, device, test_loader, test_acc, test_losses)
    scheduler.step()

In [None]:
# Plot the loss/accuracy histories filled in by the training loop.
# The test-accuracy history is the list `test_acc`; there is no `test_accuracy`
# variable, and referencing one raised NameError.
plot_statistics(train_losses, train_acc, test_losses, test_acc)