# In [None]:

import torch
import torch.nn as nn
import torch.optim as optim

import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import tqdm

from model import *

from tqdm import tqdm

import pickle



device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Fraction of the 50k CIFAR-10 training images used for training; the rest
# is held out for validation.
TRAIN_FRACTION = 0.7
# Fixed seed so the train/validation partition is identical on every run
# (an unseeded split reshuffles which images are held out each time).
SPLIT_SEED = 0
BATCH_SIZE = 100

# Per-channel values passed to Normalize as (mean, std); shared by all splits.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# No augmentation for evaluation: only tensor conversion + normalisation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Two views of the same training images: an augmented one for the training
# subset and a deterministic one for the validation subset. Carving the
# validation split out of the augmented dataset would measure validation
# loss/accuracy (and pick the saved checkpoint) on randomly cropped/flipped
# images.
augmented_trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
# Files were fetched by the call above, so no second download is needed.
plain_trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=False, transform=transform_test)

train_length = int(TRAIN_FRACTION * len(augmented_trainset))
val_length = len(augmented_trainset) - train_length

split_generator = torch.Generator().manual_seed(SPLIT_SEED)
trainset, _augmented_val_split = torch.utils.data.random_split(
    augmented_trainset, (train_length, val_length), generator=split_generator)
# Same held-out indices, but served from the non-augmented view.
val_set = torch.utils.data.Subset(plain_trainset, _augmented_val_split.indices)

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

# Evaluation only: shuffling buys nothing here.
validationloader = torch.utils.data.DataLoader(
    val_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)



# Model


def ResNet18():
    """Build the classifier network trained in this notebook.

    NOTE(review): despite the name, this is not the standard ResNet-18
    layout (which stacks [2, 2, 2, 2] basic blocks). It passes
    [2, 1, 1, 1] -- presumably the number of ``BasicBlock``s per stage --
    to ``ResNet``, likely to shrink the model; confirm before comparing
    results against ResNet-18. ``ResNet`` and ``BasicBlock`` come from the
    local ``model`` module.
    """
    blocks_per_stage = [2, 1, 1, 1]
    return ResNet(BasicBlock, blocks_per_stage)

net = ResNet18()

lr = .1
epochs = 200

net = net.to(device)
if device == 'cuda':
    # Replicate the model across all visible GPUs. The wrapper prefixes every
    # state_dict key with "module.".
    net = torch.nn.DataParallel(net)
    # Let cuDNN benchmark convolution algorithms; inputs are all 32x32 here.
    cudnn.benchmark = True

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=lr,
                      momentum=0.9, weight_decay=5e-4)
# The training loop steps the scheduler once per epoch, so T_max must equal
# `epochs`: a hard-coded value leaves the cosine schedule truncated (fewer
# epochs) or climbing back up (more epochs) whenever `epochs` is changed.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)


# Training
def train(epoch):
    """Run one optimisation pass over ``trainloader``.

    Puts ``net`` in training mode, updates its weights batch by batch with
    ``optimizer``, and prints a summary line.

    Args:
        epoch: epoch index, used only for the printed header.

    Returns:
        Tuple ``(mean per-batch loss, accuracy in percent)`` for the epoch.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    running_loss = 0
    n_correct = 0
    n_seen = 0
    for batch_x, batch_y in tqdm(trainloader):
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        logits = net(batch_x)
        batch_loss = criterion(logits, batch_y)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        predicted = logits.max(dim=1).indices
        n_seen += batch_y.size(0)
        n_correct += (predicted == batch_y).sum().item()

    mean_loss = running_loss / len(trainloader)
    accuracy = 100. * n_correct / n_seen
    print('Loss: %.3f | Acc: %.3f%% (%d/%d)' % (mean_loss, accuracy, n_correct, n_seen))
    return mean_loss, accuracy


def val(epoch):
    """Evaluate ``net`` on ``validationloader`` without touching its weights.

    Puts ``net`` in eval mode and disables autograd, then prints a summary
    line.

    Args:
        epoch: unused; kept so the signature mirrors ``train``.

    Returns:
        Tuple ``(mean per-batch loss, accuracy in percent)`` on the
        validation split.
    """
    net.eval()
    loss_sum = 0
    hits = 0
    seen = 0
    with torch.no_grad():
        for images, labels in validationloader:
            images = images.to(device)
            labels = labels.to(device)
            logits = net(images)
            loss_sum += criterion(logits, labels).item()
            seen += labels.size(0)
            hits += (logits.max(dim=1).indices == labels).sum().item()

    mean_loss = loss_sum / len(validationloader)
    accuracy = 100. * hits / seen
    print('ValLoss: %.3f | ValAcc: %.3f%% (%d/%d)' % (mean_loss, accuracy, hits, seen))
    return mean_loss, accuracy

    

# Lowest validation loss seen so far. Starting at +inf guarantees the first
# epoch is checkpointed; the former sentinel of 100 silently skipped saving
# whenever validation loss started (or diverged) above 100.
min_valid_loss = float('inf')
train_losses = []
val_losses = []
train_accs = []
val_accs = []
learning_rates = []

# Per-epoch histories, rewritten after every epoch so progress survives an
# interrupted run. The dict holds references to the lists above, so each
# dump sees the values appended so far.
history_files = {
    'train_losses.pickle': train_losses,
    'val_losses.pickle': val_losses,
    'train_accs.pickle': train_accs,
    'val_accs.pickle': val_accs,
    'learning_rates.pickle': learning_rates,
}

for epoch in range(epochs):
    # LR used during this epoch: read before scheduler.step() advances it.
    last_lr = scheduler.get_last_lr()[0]
    train_loss, train_acc = train(epoch)
    valid_loss, val_acc = val(epoch)

    learning_rates.append(last_lr)
    train_losses.append(train_loss)
    val_losses.append(valid_loss)
    train_accs.append(train_acc)
    val_accs.append(val_acc)

    scheduler.step()

    if valid_loss < min_valid_loss:
        print(f'Validation Loss Decreased({min_valid_loss:.6f}--->{valid_loss:.6f}) \t Saving The Model')
        min_valid_loss = valid_loss
        # NOTE(review): on CUDA `net` is wrapped in DataParallel, so these keys
        # carry a "module." prefix; the loader must wrap the model the same
        # way (or strip the prefix) before load_state_dict.
        torch.save(net.state_dict(), 'saved_model.pth')

    for filename, values in history_files.items():
        with open(filename, 'wb') as handle:
            pickle.dump(values, handle)
    