In [1]:
from fastai.vision.all import *
from fastai.distributed import *
from fastai.data import load
from fastai.callback.tracker import SaveModelCallback
from fastprogress import fastprogress

from torchvision import datasets, transforms, models
import torch.optim as optim
from torch.optim import lr_scheduler
import time
import sys
import os
import copy
import torchvision.transforms as T
import torch

from PIL import Image
import requests

import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
from torch import nn

import argparse
from models.utils.joiner2 import Joiner
from models.utils.new_losses import *
from models.utils.metrics import Accuracy, Curating_Of_Attention_Loss
from models.utils.dataLoader import *
from models.utils.datasets import *
import webdataset as wds

In [2]:
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
print(device)

cuda:3


In [3]:
#PARAMETERS

#train_path  = "data/WebDataset-GramCifar/train/GramCifar-{0..4}.tar"
#valid_path = "data/WebDataset-GramCifar/valid/GramCifar-0.tar"

H = 32
W= 32
bs = 5
grid_l = 2
nclass = 10
backbone = False
epochs = 5

beta = 0.00005
gamma = 0.0005
sigma = 1.0

In [4]:
#Loading the DataSets
train_ds = datasets.CIFAR10(root='./data', train=True, download=True, transform=ds_transform())
valid_ds = datasets.CIFAR10(root='./data', train=False, download=True, transform=ds_transform())
#print(len(train_ds))
#print(len(valid_ds))

Files already downloaded and verified
Files already downloaded and verified


In [5]:
#Creating the dataloader
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=bs)
valid_loader = torch.utils.data.DataLoader(valid_ds, batch_size = bs, shuffle=False)

In [6]:
#Defining the Loss Functions
train_loss = SingleLabelCriticLoss()
valid_loss = SingleLabelCriticLoss()

In [7]:
#Building the model
model = Joiner(num_encoder_layers = 6, nhead=6, backbone = backbone, num_classes = nclass, bypass=False, pos_enc = "sin", hidden_dim=768, 
batch_size=bs, image_h=H, image_w=W, grid_l=grid_l,penalty_factor="1")
model = model.to(device)

In [8]:
#Defining the optimizer and the Learning Rate decay
optimizer = torch.optim.Adam(model.parameters(), lr = 2e-7)
model_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=9, gamma=0.1)

In [None]:
#THE TRAINING LOOP
# model outputs -> [x, sattn, pattn, inputs, x0]

running_loss_history = [] # training loss - to generate a plot
running_acc_history = [] # traning accuracy
#running_latt_history = [] # traning accuracy

val_running_loss_history = [] # validation loss
val_running_acc_history = [] # validation accuracy

for e in range(epochs):
  
    start_time = time.time()
    print('Epoch {}/{}'.format(e+1, epochs))
    print('-' * 10)

    running_loss = 0.0
    running_acc = 0.0
    running_latt = 0.0
    
    val_running_loss = 0.0
    val_running_acc = 0.0
    
    batch=0
    
    for inputs, labels in train_loader:
        
        batch+=1
        sys.stdout.write('\rBatch: %d' %batch)
        sys.stdout.flush()

        #TRAINING
        inputs = inputs.to(device) # allow gpu use
        labels = labels.to(device)
        #labels[0] = labels[0].to(device)
        #labels[1] = labels[1].to(device)# allow gpu use
        outputs = model(inputs) #gives the output of the last layer
        loss = train_loss(outputs, labels) # comparing outputs and labels using the criteria
        
        optimizer.zero_grad() #zero the grad
        loss.backward() #backpropagation
        optimizer.step() #optimize weights 

        #COMPUTING TRAINING METRICS
        acc = Accuracy(outputs,labels)
        #latt = Curating_Of_Attention_Loss(outputs,labels)
        
        running_loss += loss.item() # the sum of the loss of all itens
        running_acc += acc
        #running_latt += latt
        
        Typenone = 0
        zeros = 0
        normal = 0
        for name, param in model.named_parameters():
            if param.grad == None:
                Typenone +=1
            elif torch.sum(param.grad) == 0:
                zeros += 1
            else:
                normal += 1
        if Typenone >10:
            print("None parameters:",Typenone)
            for name, param in model.named_parameters():
                if param.grad == None:
                    print(name)
        if zeros > 5:
            print("Zero Grad Parameters:", zeros)
        #print("Normally computed Parameters:",normal)
                
                
    else:
        #VALIDATION
        with torch.no_grad(): # to save memory (temporalely set all the requires grad to be false)
            for val_inputs, val_labels in valid_loader:
                val_inputs = val_inputs.to(device) # allow gpu use
                val_labels = val_labels.to(device) # allow gpu use
                val_outputs = model(val_inputs) #passes the image through the network and get the output
                val_loss = valid_loss(val_outputs, val_labels) #compare output and labels to get the loss 

                val_acc = Accuracy(val_outputs,val_labels)
                val_running_loss += val_loss.item() #same as for training
                val_running_acc += val_acc
                
    #Adding one step to the optimizer            
    model_lr_scheduler.step()
    
    #TRAINING LOSS AND ACCURACY
    epoch_loss = running_loss/len(train_ds) # the sum of the loss of all itens divided by the number of itens
    epoch_acc = running_acc/len(train_loader) # the sum of correct predictions divided by the number of itens
    #epoch_latt = running_latt/len(train_loader)
    
    running_loss_history.append(epoch_loss) #append to respective list
    running_acc_history.append(epoch_acc) #append to respective list
    #running_latt_history.append(epoch_latt) #append to respective list

    #VALIDATION LOSS AND ACCURACY
    val_epoch_loss = val_running_loss/len(valid_ds)
    val_epoch_acc = val_acc/ len(valid_loader)
    
    val_running_loss_history.append(val_epoch_loss) #append to respective list
    val_running_acc_history.append(val_epoch_acc) #append to respective list
    
    
    epoch_time_elapsed = time.time() - start_time
    print('Epoch training complete in {:.0f}m {:.0f}s'.format(
            epoch_time_elapsed // 60, epoch_time_elapsed % 60))
    print('training loss: {:.4f}, acc {:.4f}'.format(epoch_loss, epoch_acc.item()))
    print('validation loss: {:.4f}, validation acc {:.4f} '.format(val_epoch_loss, val_epoch_acc.item()))

Epoch 1/5
----------
Batch: 10000Epoch training complete in 23m 55s
training loss: 0.3806, acc 0.1428
validation loss: 0.3745, validation acc 0.0002 
Epoch 2/5
----------
Batch: 10000Epoch training complete in 23m 56s
training loss: 0.3681, acc 0.1485
validation loss: 0.3635, validation acc 0.0002 
Epoch 3/5
----------
Batch: 10000Epoch training complete in 23m 49s
training loss: 0.3586, acc 0.1552
validation loss: 0.3534, validation acc 0.0000 
Epoch 4/5
----------
Batch: 6709