In [None]:
import os

import torch
from torchvision import transforms

import numpy as np
import random

from torchvision.datasets import CIFAR10
from brain import Agent

from tqdm import tqdm

# Prefer the GPU, but fall back to the CPU so that the notebook still runs
# (slowly) on machines without CUDA instead of failing on the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def set_seeds(seed):
    """Seed every random number generator used by this notebook.

    Args:
        seed: Integer seed handed unchanged to Python's ``random`` module,
            NumPy's global RNG and PyTorch (default generator plus all CUDA
            devices).
    """
    random.seed(seed)                 # Python standard-library RNG
    np.random.seed(seed=seed)         # NumPy global RNG
    torch.manual_seed(seed)           # PyTorch RNG
    torch.cuda.manual_seed_all(seed)  # PyTorch RNGs of all GPUs


# Fix the seed once, before any model or data loader is created.
set_seeds(1)


In [None]:
# Normalisation constants passed to `transforms.Normalize` as (mean, std);
# shared so the training and validation pipelines cannot drift apart.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random augmentation, then tensor conversion and normalisation.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Validation pipeline: no augmentation, only tensor conversion and normalisation.
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

batch_size = 192

trainset = CIFAR10("../data", train=True, transform=train_transform, download=True)
train_loader = torch.utils.data.DataLoader(
    trainset,
    shuffle=True,
    batch_size=batch_size,
    drop_last=True,  # keep every training batch the same size
    num_workers=3,
    pin_memory=True,
)

valset = CIFAR10("../data", train=False, transform=val_transform, download=True)
val_loader = torch.utils.data.DataLoader(
    valset,
    shuffle=False,  # fixed order: the first batch is identical on every pass
    batch_size=batch_size,
    # Keep the final, smaller batch: with drop_last=True the samples that do
    # not fill a whole batch were silently excluded from the validation metrics.
    drop_last=False,
    num_workers=3,
    pin_memory=True,
)


In [None]:
# `model` is the network that is optimised and checkpointed.
model = Agent().to(device)

# `model2` is a second instance that starts from the same weights as `model`.
model2 = Agent().to(device)
model2.load_state_dict(model.state_dict())

loss_fn = torch.nn.CrossEntropyLoss()

# Both optimizers share the same Adam hyper-parameters.
lr = 1e-3
adam_kwargs = dict(lr=lr, weight_decay=1e-4)
optimizer = torch.optim.Adam(model.parameters(), **adam_kwargs)
optimizer2 = torch.optim.Adam(model2.parameters(), **adam_kwargs)

# Learning-rate scheduling is currently disabled:
#scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
#                optimizer, factor=0.2, mode="max", verbose=True)

In [None]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard event files are written to the "tensorboard" directory.
writer = SummaryWriter("tensorboard")

# Trace the model once on a real training batch so its graph is recorded.
# The built-in `next()` is used because DataLoader iterators follow the
# standard iterator protocol; the old `.next()` method alias is not available
# on current PyTorch versions and raises AttributeError there.
dataiter = iter(train_loader)
images, labels = next(dataiter)
writer.add_graph(model, images.to(device))
writer.flush()


In [None]:

def train_loop(dataloader, model, loss_fn, optimizer, epoch, dropout_p):
    """Run one epoch of sign-gated training, then evaluate and log the metrics.

    For every training batch, the gradient of ``model`` is compared element-wise
    with the gradient that the shadow network ``model2`` obtains on a reference
    validation batch. Gradient entries whose product with the reference
    gradient is positive are kept; all other entries are divided by 100 before
    ``optimizer.step()``. Afterwards ``model2`` is re-synchronised with
    ``model``. At the end of the epoch ``model2`` is evaluated on the whole of
    ``val_loader`` and the results are printed and written to TensorBoard.

    Besides its arguments, this function relies on the module-level names
    ``model2``, ``optimizer2``, ``val_loader``, ``writer`` and ``device``.

    Args:
        dataloader: Iterable of ``(inputs, targets)`` training batches.
        model: Network being trained; called as ``model(inputs, dropout_p)``.
        loss_fn: Criterion called as ``loss_fn(predictions, targets)``. The
            reported averages assume it returns the batch mean, which is what
            the default ``torch.nn.CrossEntropyLoss()`` does.
        optimizer: Optimizer that owns the parameters of ``model``.
        epoch: Epoch index, used as the step for the logged scalars.
        dropout_p: Value forwarded as the second argument of ``model``.

    Returns:
        None. Results are printed and written through ``writer``.
    """
    # Nothing inside the batch loop switches modes, so set them once.
    model.train()
    model2.train()

    # Fetch the reference batch once per epoch instead of building a new
    # iterator on every step. This yields the same data as long as `val_loader`
    # is not shuffled, since its first batch is then identical on every pass.
    ref_inputs, ref_targets = next(iter(val_loader))
    ref_inputs, ref_targets = ref_inputs.to(device), ref_targets.to(device)

    train_loss_sum = 0.0
    train_correct = 0
    train_total = 0
    agree_count = 0  # gradient entries whose sign matched the reference
    grad_count = 0   # gradient entries that were compared

    for X, y in tqdm(dataloader):
        X, y = X.to(device), y.to(device)

        optimizer.zero_grad()
        optimizer2.zero_grad()

        logits = model(X, dropout_p)
        loss = loss_fn(logits, y)
        loss.backward()

        # Weight the batch-mean loss by the batch size so that dividing by the
        # number of samples below yields a true per-sample average.
        n_samples = y.size(0)
        train_loss_sum += loss.item() * n_samples
        train_total += n_samples
        train_correct += (logits.argmax(dim=1) == y).sum().item()

        # Reference gradient from the shadow network on the validation batch.
        # NOTE(review): model2 is called without `dropout_p`, so it uses
        # whatever default the model defines - confirm this is intended.
        ref_loss = loss_fn(model2(ref_inputs), ref_targets)
        ref_loss.backward()

        for p_train, p_ref in zip(model.parameters(), model2.parameters()):
            # Parameters that took no part in a backward pass have no gradient
            # to compare; leave them untouched instead of raising a TypeError.
            if p_train.grad is None or p_ref.grad is None:
                continue
            agree = p_train.grad * p_ref.grad > 0
            p_train.grad = torch.where(agree, p_train.grad, p_train.grad / 100)
            agree_count += agree.sum()
            grad_count += agree.numel()

        optimizer.step()
        # Keep the shadow network identical to the freshly updated model.
        model2.load_state_dict(model.state_dict())

    # Evaluation needs no gradients; no_grad avoids building autograd graphs.
    model2.eval()
    val_loss_sum = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for X, y in val_loader:
            X, y = X.to(device), y.to(device)
            logits = model2(X)
            n_samples = y.size(0)
            val_loss_sum += loss_fn(logits, y).item() * n_samples
            val_total += n_samples
            val_correct += (logits.argmax(dim=1) == y).sum().item()

    # max(..., 1) only guards against division by zero for an empty loader.
    train_loss = train_loss_sum / max(train_total, 1)
    val_loss = val_loss_sum / max(val_total, 1)
    train_acc = train_correct / max(train_total, 1)
    val_acc = val_correct / max(val_total, 1)
    agreement = float(agree_count) / grad_count if grad_count else 0.0

    print(f"Epoch: {epoch}; Train Loss: {train_loss}, Val Loss: {val_loss}")

    writer.add_scalars("Training vs Validation Accuracy", {
        "training": train_acc,
        "validation": val_acc
    }, epoch)

    writer.add_scalars("Training vs Validation loss", {
        "training": train_loss,
        "validation": val_loss
    }, epoch)

    # Epoch-level replacement for the per-batch agreement prints.
    writer.add_scalar("Gradient sign agreement", agreement, epoch)


In [None]:
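# Optional: resume from the checkpoint file "model" written by the training cell.
# Uncomment the lines below to restore the saved state before training.
# NOTE: the scheduler line only works if a `scheduler` is defined and saved
# (both are currently commented out), and `last_epoch` is not used yet because
# the training loop always starts at epoch 0.
#checkpoint = torch.load("model")
#model.load_state_dict(checkpoint['model_state_dict'])
#optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
#last_epoch = checkpoint["epoch"]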


In [None]:
# Run the full training schedule, writing a checkpoint after every epoch.
epochs = 1000
for t in range(epochs):
    dropout_p = 0.3  # forwarded to the model by train_loop
    train_loop(train_loader, model, loss_fn, optimizer, epoch=t, dropout_p=dropout_p)

    # The single file "model" is overwritten each epoch, so it always holds
    # the state after the most recently completed epoch.
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        #'scheduler_state_dict': scheduler.state_dict(),
        'epoch': t,
    }
    torch.save(checkpoint, "model")
