# Torch playground

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm, trange

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

from extra.utils import *

## NN params

In [None]:
# Quick sanity check: ask PyTorch whether a CUDA device is usable.
# As the only expression in the cell, its boolean result is what the cell evaluates to.
torch.cuda.is_available()

In [None]:
# Mini-batch size; passed to both the train and the test DataLoader below.
# NOTE(review): the commented value is presumably a smaller setting for a
# GTX 960 with less memory -- confirm before relying on it.
# batch_size = 54 # gtx960
batch_size = 256

## CIFAR preload

In [None]:
# Evaluation pipeline: tensor conversion followed by per-channel normalisation.
# NOTE(review): these statistics look like the widely used ImageNet values; the
# previously commented-out alternative was
#     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
# (presumably CIFAR-10 statistics -- confirm which one is intended).
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Training pipeline: same normalisation, preceded by random-crop and
# horizontal-flip augmentation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# CIFAR-10 training split (downloaded into ./data when missing), shuffled.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_train,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

# CIFAR-10 test split, iterated in a fixed order.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

# Human-readable class names, indexed by integer label.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

## MobileNetV2

In [None]:
from models.mobilenetv2 import *

## Prep

In [None]:
# Select the compute device: the GPU when PyTorch can use one, else the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# MobileNetV2 is not defined in this cell; presumably it is provided by the
# `models.mobilenetv2` star import above.
net = MobileNetV2().to(device)

if device == 'cuda':
    # Wrap the model for multi-GPU data parallelism and turn on cuDNN benchmarking.
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

criterion = nn.CrossEntropyLoss()

# Adam is the optimizer in use; the earlier SGD configuration is kept for reference:
#     optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)
optimizer = optim.Adam(net.parameters(), lr=0.001)

# Cosine learning-rate schedule configured with T_max=200 scheduler steps.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

## Training

In [None]:
def train_nn(epoch):
    """Run one optimisation pass over ``trainloader`` and plot per-batch curves.

    Uses the notebook-level ``net``, ``criterion``, ``optimizer``, ``device``
    and ``trainloader``. After the pass, the per-batch loss and accuracy
    histories are handed to ``plot``.

    Args:
        epoch: Epoch index; used only as the progress-bar label.
    """
    net.train()
    losses, accuracies = [], []
    with tqdm(trainloader, unit="batch") as tepoch:
        # The label does not change during the pass, so set it once.
        tepoch.set_description(f"Epoch {epoch}")
        for data, target in tepoch:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = net(data)
            _, predictions = output.max(1)
            loss = criterion(output, target)

            loss.backward()
            optimizer.step()

            # Convert to plain Python floats once: this keeps the histories free
            # of device tensors / autograd state and makes them directly plottable.
            batch_loss = loss.item()
            batch_accuracy = (predictions == target).float().mean().item()

            tepoch.set_postfix(loss=batch_loss, accuracy=100. * batch_accuracy)

            losses.append(batch_loss)
            accuracies.append(batch_accuracy)

    plt.ylim(-0.1, 1.1)
    # NOTE(review): `plot` is not defined in this cell; presumably it comes from
    # the `extra.utils` star import -- confirm it accepts a list of floats.
    plot(losses)
    plot(accuracies)

In [None]:
def test_nn(epoch):
    """Evaluate ``net`` on ``testloader`` and show running test-set metrics.

    Runs in eval mode with gradients disabled. The progress bar reports the
    sample-weighted mean loss and the accuracy over all batches seen so far,
    so the final readout describes the whole test set rather than only the
    last batch.

    Args:
        epoch: Epoch index; used only as the progress-bar label.
    """
    # Best-accuracy tracking is not implemented here (the original
    # `global best_acc` line was commented out).
    net.eval()
    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    with torch.no_grad():
        with tqdm(testloader, unit="batch") as tepoch:
            # The label does not change during the pass, so set it once.
            tepoch.set_description(f"Epoch {epoch}")
            for data, target in tepoch:
                data, target = data.to(device), target.to(device)
                output = net(data)
                _, predictions = output.max(1)
                loss = criterion(output, target)

                # `criterion` yields a per-batch mean (CrossEntropyLoss default),
                # so weight it by the batch size before accumulating.
                n_samples = target.size(0)
                total_loss += loss.item() * n_samples
                total_correct += predictions.eq(target).sum().item()
                total_seen += n_samples

                tepoch.set_postfix(
                    loss=total_loss / total_seen,
                    accuracy=100. * total_correct / total_seen,
                )

In [None]:
# Ten epochs of train -> evaluate, advancing the LR schedule once per epoch.
# (Range noted in the original for a full run: start_epoch, start_epoch+200.)
for epoch in trange(10):
    train_nn(epoch)
    test_nn(epoch)
    scheduler.step()

In [None]:
def imshow(img, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Display a normalised channel-first image tensor with matplotlib.

    The image is un-normalised with ``img * std + mean`` per channel before
    display. The defaults match the ``transforms.Normalize`` statistics used by
    this notebook's data pipeline; the previous ``img / 2 + 0.5`` only inverts
    a normalisation with mean 0.5 and std 0.5.

    Args:
        img: Image tensor laid out as (channels, height, width); it must
            support ``.numpy()``.
        mean: Per-channel mean that was subtracted during normalisation.
        std: Per-channel standard deviation that was divided out.
    """
    # Channel-first -> channel-last, so the statistics broadcast over the last axis.
    npimg = np.transpose(img.numpy(), (1, 2, 0))
    npimg = npimg * np.asarray(std, dtype=npimg.dtype) + np.asarray(mean, dtype=npimg.dtype)
    # Clamp rounding overshoot so the values stay in the [0, 1] float image range.
    plt.imshow(np.clip(npimg, 0.0, 1.0))
    plt.show()

# Fetch one batch from the test loader. Use the built-in next(): the
# Python 2-style `.next()` method is not provided by current DataLoader iterators.
dataiter = iter(testloader)
images, labels = next(dataiter)

# Show the first four images as a grid, then print their ground-truth class names.
imshow(torchvision.utils.make_grid(images[0:4]))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))