In [1]:
import os

import torch
from torchvision import transforms

import pandas as pd
import numpy as np
import random

from torchvision.datasets import CIFAR10
from brain import Agent

from tqdm import tqdm

# Number of iterations the Agent produces predictions for; the model output's
# leading dimension has this size (see the shape comment in train_loop).
max_iterations = 6
# Device used for the model and every batch; swap in torch.device("cuda") for GPU.
device = "cpu"


In [2]:
def set_seeds(seed):
    """Seed every random number generator this notebook draws from.

    Seeds, in order: PyTorch's global generator, all CUDA generators,
    NumPy's legacy global RNG, and Python's ``random`` module.

    Args:
        seed: integer seed applied to each generator.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed)


set_seeds(123)


In [3]:
# Per-channel normalization constants shared by the train and validation pipelines.
# NOTE(review): this std triple is a widely copied CIFAR-10 variant; the per-channel
# pixel std is usually quoted as (0.2470, 0.2435, 0.2616). Kept unchanged so existing
# checkpoints stay compatible — confirm before changing.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training augmentation: random 32x32 crop after 4-pixel padding, random horizontal flip.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Validation: no augmentation, same normalization as training.
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

batch_size = 32

trainset = CIFAR10("../data", train=True, transform=train_transform, download=True)
train_loader = torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=batch_size)

valset = CIFAR10("../data", train=False, transform=val_transform, download=True)
# Evaluation metrics do not depend on sample order, so keep validation deterministic
# and avoid consuming the global RNG that drives training shuffles/augmentation.
val_loader = torch.utils.data.DataLoader(valset, shuffle=False, batch_size=batch_size)


Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Iterative classifier: per train_loop's shape comment, its forward pass returns one
# probability distribution per iteration.
model = Agent(max_iterations).to(device)

lr = 1e-4  # base learning rate; the training cell rescales it every epoch

# train_loop feeds this log-probabilities; the log_softmax inside CrossEntropyLoss
# leaves already-normalized log-probabilities unchanged, so it acts as an NLL loss.
loss_fn = torch.nn.CrossEntropyLoss()

# Adam's weight_decay adds an L2 term to the gradient (coupled decay, not AdamW).
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr, weight_decay=1e-3)


In [5]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard event files are written under ./tensorboard; train_loop logs its
# per-epoch scalars through this writer.
writer = SummaryWriter("tensorboard")

# Trace the model graph once, using a real training batch as the example input.
dataiter = iter(train_loader)
# Builtin next(): DataLoader iterators no longer provide a .next() method.
images, labels = next(dataiter)
writer.add_graph(model, images.to(device))
writer.flush()


  batch_size = int(c1.shape[0]/64)


In [6]:
def _iteration_weighted_loss(probs_each_iteration, y, loss_fn, decay=0.5, eps=1e-12):
    """Return sum over iterations i of ``decay**i * loss_fn(log(probs_i), y)``.

    ``probs_each_iteration`` is indexed [iteration, sample, class] (per the model's
    output shape comment — TODO confirm it holds normalized probabilities).
    Probabilities are clamped to ``eps`` before the log so an exact zero gives a
    large finite loss instead of inf/NaN gradients.
    """
    log_probs = torch.log(probs_each_iteration.clamp_min(eps))
    loss = 0
    for i in range(log_probs.shape[0]):
        # Geometrically decreasing weight for later iterations.
        loss = loss + (decay ** i) * loss_fn(log_probs[i], y)
    return loss


def _early_exit_predictions(probs_each_iteration, threshold=0.5):
    """Pick, for every sample, the iteration at which it exits and its predicted class.

    A sample exits at the first iteration whose maximum class probability exceeds
    ``threshold``; samples that never do fall back to the last iteration.
    Works for any batch size, including a short final batch.

    Returns:
        (exit_iteration, pred): two 1-D int64 tensors with one entry per sample.
    """
    n_iterations, n_samples = probs_each_iteration.shape[:2]
    iteration_max_prob, _ = probs_each_iteration.max(dim=2)  # [iteration, sample]
    confident = iteration_max_prob > threshold
    # argmax returns the first maximal index, i.e. the first confident iteration.
    first_confident = confident.int().argmax(dim=0)
    exit_iteration = torch.where(
        confident.any(dim=0),
        first_confident,
        torch.full_like(first_confident, n_iterations - 1),
    )
    sample_idx = torch.arange(n_samples, device=probs_each_iteration.device)
    pred = probs_each_iteration[exit_iteration, sample_idx].argmax(dim=1)
    return exit_iteration, pred


def _safe_div(numerator, denominator):
    """Return numerator / denominator, or NaN when there were no samples."""
    return numerator / denominator if denominator else float("nan")


def _run_pass(model, batches, loss_fn, optimizer=None, threshold=0.5):
    """Run ``model`` over ``batches`` once; train when ``optimizer`` is given, else evaluate.

    Batches are moved to the notebook-level ``device``. Gradients are enabled only
    while training.

    Returns:
        (loss_sum, correct, total, exit_iteration_sum). Each batch's loss is weighted
        by its size, so ``loss_sum / total`` is a per-sample average when ``loss_fn``
        averages over the batch (CrossEntropyLoss's default reduction).
    """
    training = optimizer is not None
    model.train(training)
    loss_sum, correct, total, exit_sum = 0.0, 0, 0, 0
    with torch.set_grad_enabled(training):
        for X, y in batches:
            X, y = X.to(device), y.to(device)
            probs_each_iteration = model(X)  # [iteration, sample, class]
            loss = _iteration_weighted_loss(probs_each_iteration, y, loss_fn)

            exit_iteration, pred = _early_exit_predictions(
                probs_each_iteration.detach(), threshold)
            n = y.shape[0]
            loss_sum += loss.item() * n
            correct += (pred == y).sum().item()
            total += n
            exit_sum += exit_iteration.sum().item()

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
    return loss_sum, correct, total, exit_sum


def train_loop(dataloader, model, loss_fn, optimizer, epoch,
               val_dataloader=None, summary_writer=None, threshold=0.5):
    """Train ``model`` for one epoch, evaluate it, and log the results.

    Prints per-sample average train/validation losses. Logs accuracy, loss, and
    compute cost (mean 0-based early-exit iteration) for both splits to
    TensorBoard at global step ``epoch``.

    Args:
        dataloader: training batches of (inputs, labels); wrapped in a tqdm bar.
        model: returns probabilities indexed [iteration, sample, class].
        loss_fn: applied to each iteration's log-probabilities and the labels.
        optimizer: stepped once per training batch.
        epoch: epoch index used in the printout and as the TensorBoard step.
        val_dataloader: validation batches; ``None`` uses the notebook's ``val_loader``.
        summary_writer: TensorBoard writer; ``None`` uses the notebook's ``writer``.
        threshold: max class probability a sample must exceed to exit early.
    """
    if val_dataloader is None:
        val_dataloader = val_loader
    if summary_writer is None:
        summary_writer = writer

    train_loss, train_correct, train_total, train_exit = _run_pass(
        model, tqdm(dataloader), loss_fn, optimizer, threshold)
    val_loss, val_correct, val_total, val_exit = _run_pass(
        model, val_dataloader, loss_fn, None, threshold)

    train_loss_avg = _safe_div(train_loss, train_total)
    val_loss_avg = _safe_div(val_loss, val_total)

    print(f"Epoch: {epoch}; Train Loss: {train_loss_avg}, "
          f"Val Loss: {val_loss_avg}")

    summary_writer.add_scalars("Training vs Validation Accuracy", {
        "training": _safe_div(train_correct, train_total),
        "validation": _safe_div(val_correct, val_total),
    }, epoch)

    summary_writer.add_scalars("Training vs Validation loss", {
        "training": train_loss_avg,
        "validation": val_loss_avg,
    }, epoch)

    summary_writer.add_scalars("Training vs Validation Compute Cost", {
        "training": _safe_div(train_exit, train_total),
        "validation": _safe_div(val_exit, val_total),
    }, epoch)


In [7]:
#checkpoint = torch.load("model")
#model.load_state_dict(checkpoint['model_state_dict'])
#optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#checkpoint["epoch"]

In [8]:
#for g in optimizer.param_groups:
#    g['weight_decay'] = 0.0011


In [9]:
epochs = 1000
for t in range(epochs):
    # Exponential per-epoch decay of the base rate, floored at half of it.
    epoch_lr = max(lr * 0.997 ** t, 0.5 * lr)
    for param_group in optimizer.param_groups:
        param_group['lr'] = epoch_lr

    train_loop(train_loader, model, loss_fn, optimizer, epoch=t)

    # Overwrite a single checkpoint file after every epoch so training can resume.
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': t,
    }
    torch.save(checkpoint, "model")


  0%|          | 2/1563 [00:54<11:51:38, 27.35s/it]