In [None]:
import copy
from collections import OrderedDict

import sys
import time
import os
import gc
import math
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.autograd import Variable

from torch import optim
import torch.nn.functional as F
import numpy as np

In [None]:
from models import conv_block, ConvNet

In [None]:
from utils import count_parameters, L1Grad, L2Grad, LPGrad, create_activation_gradients, create_weight_gradients

In [None]:
# CIFAR-10 per-channel normalization shared by the train and test pipelines.
# NOTE(review): these std values are the ones popularized by common CIFAR-10
# example repos; the per-pixel training-set std is often quoted as roughly
# (0.2470, 0.2435, 0.2616) — confirm which one is intended before changing.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)
DATA_ROOT = './data'
BATCH_SIZE = 128

# Training data: random 32x32 crops (after 4-pixel padding) and horizontal flips.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)
trainset = torchvision.datasets.CIFAR10(
    root=DATA_ROOT, train=True, download=True, transform=transform_train
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2
)

# Test data: no augmentation, only tensor conversion and normalization.
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)
testset = torchvision.datasets.CIFAR10(
    root=DATA_ROOT, train=False, download=True, transform=transform_test
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=BATCH_SIZE, shuffle=False
)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Seed torch's global RNG so runs are repeatable.
torch.manual_seed(43)

# Fall back to CPU when no GPU is available instead of failing in `net.to(...)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = ConvNet()
net = net.to(device)

print(count_parameters(net))

1212778


In [None]:
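# Training function: runs one epoch over `trainloader`, optionally adding a gradient-based regularization term (L1Grad/L2Grad) to the cross-entropy loss.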

In [None]:
def train(net, trainloader, criterion, regularizer, optimizer, epoch, train_loss_tracker, train_acc_tracker):
    """Train ``net`` for one epoch, optionally adding a gradient-based penalty.

    For every batch the unregularized ``criterion`` loss is computed, and the
    gradient dictionaries from ``create_activation_gradients`` /
    ``create_weight_gradients`` are built. If ``regularizer`` is an ``L1Grad``,
    ``L2Grad`` or ``LPGrad`` instance, it is applied to those gradients and the
    result is added to the loss before backpropagation. Any other truthy
    regularizer is ignored.

    Args:
        net: model whose forward pass returns ``(outputs, activations)``.
        trainloader: iterable of ``(inputs, targets)`` batches.
        criterion: callable ``criterion(outputs, targets) -> loss``.
        regularizer: gradient regularizer instance, or ``None``/falsy to disable.
        optimizer: optimizer over ``net``'s parameters.
        epoch: epoch index, used only in the progress line.
        train_loss_tracker: list; the combined loss of every batch is appended.
        train_acc_tracker: list; the epoch's training accuracy (percent) is appended.

    Note:
        Reads the module-level ``device`` to place each batch.
    """
    net.train()
    use_reg = bool(regularizer) and isinstance(regularizer, (L1Grad, L2Grad, LPGrad))
    train_loss = 0.0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        # NOTE(review): input gradients are enabled as before; presumably the
        # utils gradient helpers rely on it — confirm before removing.
        inputs.requires_grad = True
        optimizer.zero_grad()

        # forward pass
        outputs, activations = net(inputs)
        loss = criterion(outputs, targets)      # unregularized loss

        a_grad_dict = create_activation_gradients(activations, loss)
        w_grad_dict = create_weight_gradients(net, loss, model_type='convnet')

        if use_reg:
            # Out-of-place add: `combined_loss += ...` would mutate `loss` itself
            # (same tensor object) after it was used to build the gradient dicts.
            combined_loss = loss + regularizer(list(w_grad_dict.values()), list(a_grad_dict.values()))
        else:
            combined_loss = loss

        # backward pass + parameter update
        combined_loss.backward()
        optimizer.step()

        # running average of the combined loss
        batch_loss = combined_loss.item()
        train_loss += batch_loss
        train_loss_tracker.append(batch_loss)
        avg_loss = train_loss / (batch_idx + 1)

        # running accuracy
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        acc = 100. * correct / total

        sys.stdout.write(f'\rEpoch {epoch}: Train Loss: {avg_loss:.3f}' +
                         f'| Train Acc: {acc:.3f}')
        sys.stdout.flush()

    # Guard against an empty loader, where the loop body never ran.
    train_acc_tracker.append(100. * correct / total if total else 0.0)
    sys.stdout.flush()


In [None]:
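# Testing function: evaluates on `testloader` (the reported loss includes the regularization term) and writes ./checkpoint/ckpt.pth.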

In [None]:
def test(net, testloader, criterion, regularizer, epoch, test_loss_tracker, test_acc_tracker):
    """Evaluate ``net`` on ``testloader`` and checkpoint it when accuracy improves.

    The reported loss is the criterion loss plus, for ``L1Grad``/``L2Grad``/
    ``LPGrad`` regularizers, the regularizer applied to the gradient dicts from
    ``create_weight_gradients`` / ``create_activation_gradients``.

    Args:
        net: model whose forward pass returns ``(outputs, activations)``.
        testloader: iterable of ``(inputs, targets)`` batches.
        criterion: callable ``criterion(outputs, targets) -> loss``.
        regularizer: gradient regularizer instance, or ``None``/falsy to disable.
        epoch: epoch index, stored in the checkpoint.
        test_loss_tracker: list; the combined loss of every batch is appended.
        test_acc_tracker: list; this epoch's test accuracy (percent) is appended.
            Its previous contents define the best accuracy so far.

    Side effects:
        Writes ``./checkpoint/ckpt.pth`` when accuracy beats every earlier entry
        of ``test_acc_tracker``, and updates the global ``best_acc``. Reads the
        module-level ``device``.
    """
    global best_acc
    # Best accuracy of earlier calls. Taking it from the tracker means it resets
    # together with the trackers. The old `best_acc = 0` on every call made each
    # epoch overwrite the checkpoint.
    previous_best = max(test_acc_tracker, default=0)
    use_reg = bool(regularizer) and isinstance(regularizer, (L1Grad, L2Grad, LPGrad))
    net.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    for batch_idx, (inputs, targets) in enumerate(testloader):
        inputs, targets = inputs.to(device), targets.to(device)
        inputs.requires_grad = True

        # forward pass
        outputs, activations = net(inputs)
        loss = criterion(outputs, targets)      # unregularized loss

        a_grad_dict = create_activation_gradients(activations, loss)
        w_grad_dict = create_weight_gradients(net, loss, model_type='convnet')

        if use_reg:
            # Out-of-place add so `loss` itself is not mutated.
            combined_loss = loss + regularizer(list(w_grad_dict.values()), list(a_grad_dict.values()))
        else:
            combined_loss = loss

        batch_loss = combined_loss.item()
        test_loss += batch_loss
        test_loss_tracker.append(batch_loss)
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        num_batches = batch_idx + 1

    # Guard against an empty loader, where the loop body never ran.
    avg_loss = test_loss / num_batches if num_batches else 0.0
    acc = 100. * correct / total if total else 0.0
    sys.stdout.write(f' | Test Loss: {avg_loss:.3f} | Test Acc: {acc:.3f}\n')
    sys.stdout.flush()

    # Save checkpoint only when this epoch improves on every earlier one.
    test_acc_tracker.append(acc)
    if acc > previous_best:
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        os.makedirs('checkpoint', exist_ok=True)
        torch.save(state, './checkpoint/ckpt.pth')
    best_acc = max(previous_best, acc)

In [None]:
# Metric histories, appended to by the train/test functions.
train_loss_tracker = []
train_acc_tracker = []
test_loss_tracker = []
test_acc_tracker = []

# Hyperparameters.
lr = 0.1        # SGD learning rate
lmbda = 0.01    # strength passed to the gradient regularizers
epochs = 5      # number of train/test epochs in the loop

optimizer = torch.optim.SGD(
    net.parameters(),
    lr=lr,
    momentum=0.9,
    weight_decay=5e-4,
)

In [None]:
# Standard cross-entropy with mean reduction over the batch (CustomCE() was an alternative).
criterion = nn.CrossEntropyLoss(reduction='mean')

In [None]:
# Gradient regularizers, all sharing the same strength `lmbda`.
# Only `regula` (LPGrad with P=1) is passed to the training loop below.
regularizer = L1Grad(lmbda)
regularizer2 = L2Grad(lmbda)
regula = LPGrad(lmbda, P=1)

In [None]:
print(f'Training for {epochs} epochs, with learning rate {lr} and lambda {lmbda}')

# NOTE(review): train_alt / test_alt are defined in a LATER cell of this notebook.
# On a fresh kernel, run that cell first or this loop raises NameError.
start_time = time.time()
for epoch in range(epochs):
    ep_start_time = time.time()

    # One epoch of training, then evaluation; both use the LPGrad regularizer `regula`.
    train_alt(net, trainloader, criterion, regula, optimizer, epoch,
              train_loss_tracker, train_acc_tracker)
    test_alt(net, testloader, criterion, regula, epoch,
             test_loss_tracker, test_acc_tracker)

    ep_end_time = time.time()
    epoch_time = ep_end_time - ep_start_time
    print(f"Training time: {epoch_time} seconds")  # includes the test pass

total_time = time.time() - start_time
print('Total training time: {} seconds'.format(total_time))

Training for 5 epochs, with learning rate 0.1 and lambda 0.01
Epoch 0: Train Loss: 2.536| Train Acc: 11.7781 | Test Loss: 2.238 | Test Acc: 13.200
Training time: 180.56733226776123 seconds
Epoch 1: Train Loss: 2.266| Train Acc: 14.314 | Test Loss: 2.123 | Test Acc: 16.560
Training time: 179.4168622493744 seconds
Epoch 2: Train Loss: 2.178| Train Acc: 17.607

KeyboardInterrupt: 

In [None]:
def train_alt(net, trainloader, criterion, regularizer, optimizer, epoch, train_loss_tracker, train_acc_tracker):
    """Train ``net`` for one epoch with a double-backward gradient regularizer.

    Per batch:
    1. A first ``loss.backward`` fills ``.grad`` on the parameters and on the
       returned activations (kept via ``retain_grad``).
    2. If ``regularizer`` is an ``L1Grad``/``L2Grad``/``LPGrad`` instance, that
       backward uses ``create_graph=True``. The regularizer is then applied to
       the Conv2d/Linear weight gradients and the activation gradients.
    3. The regularizer's result is backpropagated a second time, so it
       accumulates into the parameter gradients before ``optimizer.step()``.

    Args:
        net: model whose forward pass returns ``(outputs, activations)``.
        trainloader: iterable of ``(inputs, targets)`` batches.
        criterion: callable ``criterion(outputs, targets) -> loss``.
        regularizer: gradient regularizer instance, or ``None``/falsy to disable.
        optimizer: optimizer over ``net``'s parameters.
        epoch: epoch index, used only in the progress line.
        train_loss_tracker: list; the loss + regularizer of every batch is appended.
        train_acc_tracker: list; the epoch's training accuracy (percent) is appended.

    Note:
        Reads the module-level ``device`` to place each batch.
    """
    net.train()
    # Only build the (expensive) double-backward graph when a supported
    # regularizer will actually consume it.
    use_reg = bool(regularizer) and isinstance(regularizer, (L1Grad, L2Grad, LPGrad))
    train_loss = 0.0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()

        # forward pass
        outputs, activations = net(inputs)
        loss = criterion(outputs, targets)      # unregularized loss
        for a in activations:
            a.retain_grad()

        # First backward; create_graph makes the .grad tensors differentiable.
        loss.backward(create_graph=use_reg, retain_graph=use_reg)

        reg_term = 0
        if use_reg:
            w_grad_list = [m.weight.grad for m in net.modules()
                           if isinstance(m, (nn.Conv2d, nn.Linear))]
            a_grad_list = [a.grad for a in activations]
            reg_term = regularizer(w_grad_list, a_grad_list)
            # Second backward: accumulates d(reg)/d(params) into the existing .grad.
            reg_term.backward()

        optimizer.step()

        # Logging only; no_grad avoids building graph nodes for this sum.
        with torch.no_grad():
            batch_loss = (loss + reg_term).item()
        train_loss += batch_loss
        train_loss_tracker.append(batch_loss)
        loss_value = train_loss / (batch_idx + 1)

        # backward(create_graph=True) makes each param.grad reference a graph that
        # references the parameter. PyTorch warns this reference cycle can leak
        # memory; setting the gradients to None breaks it. The activation grads
        # are dropped too, since they were produced by the same backward.
        optimizer.zero_grad(set_to_none=True)
        for a in activations:
            a.grad = None

        # running accuracy
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        acc = 100. * correct / total

        sys.stdout.write(f'\rEpoch {epoch}: Train Loss: {loss_value:.3f}' +
                         f'| Train Acc: {acc:.3f}')
        sys.stdout.flush()

    # Guard against an empty loader, where the loop body never ran.
    train_acc_tracker.append(100. * correct / total if total else 0.0)
    sys.stdout.flush()
    
def test_alt(net, testloader, criterion, regularizer, epoch, test_loss_tracker, test_acc_tracker):
    """Evaluate ``net`` on ``testloader`` and checkpoint it when accuracy improves.

    The reported loss includes a regularization term in one case only: when
    ``regularizer`` is an ``L1Grad``/``L2Grad``/``LPGrad`` instance. Then the
    criterion loss is backpropagated once, without a higher-order graph, and
    the regularizer is applied to the Conv2d/Linear weight gradients and the
    activation gradients.

    Args:
        net: model whose forward pass returns ``(outputs, activations)``.
        testloader: iterable of ``(inputs, targets)`` batches.
        criterion: callable ``criterion(outputs, targets) -> loss``.
        regularizer: gradient regularizer instance, or ``None``/falsy to disable.
        epoch: epoch index, stored in the checkpoint.
        test_loss_tracker: list; the loss + regularizer of every batch is appended.
        test_acc_tracker: list; this epoch's test accuracy (percent) is appended.
            Its previous contents define the best accuracy so far.

    Side effects:
        Writes ``./checkpoint/ckpt.pth`` when accuracy beats every earlier entry
        of ``test_acc_tracker``, and updates the global ``best_acc``. Leaves
        ``net``'s parameter gradients set to ``None``. Reads the module-level
        ``device``.
    """
    global best_acc
    # Best accuracy of earlier calls. Taking it from the tracker means it resets
    # together with the trackers. The old `best_acc = 0` on every call made each
    # epoch overwrite the checkpoint.
    previous_best = max(test_acc_tracker, default=0)
    use_reg = bool(regularizer) and isinstance(regularizer, (L1Grad, L2Grad, LPGrad))
    net.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    for batch_idx, (inputs, targets) in enumerate(testloader):
        inputs, targets = inputs.to(device), targets.to(device)
        # Clear grads on net's own parameters. The original called the global
        # `optimizer`, which this function does not receive.
        net.zero_grad(set_to_none=True)

        # forward pass
        outputs, activations = net(inputs)
        loss = criterion(outputs, targets)      # unregularized loss

        reg_term = 0
        if use_reg:
            for a in activations:
                a.retain_grad()
            loss.backward()     # first-order gradients only
            w_grad_list = [m.weight.grad for m in net.modules()
                           if isinstance(m, (nn.Conv2d, nn.Linear))]
            a_grad_list = [a.grad for a in activations]
            reg_term = regularizer(w_grad_list, a_grad_list)

        with torch.no_grad():
            batch_loss = (loss + reg_term).item()
        test_loss += batch_loss
        test_loss_tracker.append(batch_loss)
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        num_batches = batch_idx + 1

        # Drop the gradients produced for this batch.
        net.zero_grad(set_to_none=True)

    # Report the averaged loss. The old code printed the last batch's raw loss
    # tensor here. Also guard against an empty loader.
    loss_value = test_loss / num_batches if num_batches else 0.0
    acc = 100. * correct / total if total else 0.0
    sys.stdout.write(f' | Test Loss: {loss_value:.3f} | Test Acc: {acc:.3f}\n')
    sys.stdout.flush()

    # Save checkpoint only when this epoch improves on every earlier one.
    test_acc_tracker.append(acc)
    if acc > previous_best:
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        os.makedirs('checkpoint', exist_ok=True)
        torch.save(state, './checkpoint/ckpt.pth')
    best_acc = max(previous_best, acc)

In [None]:
def count_gc_objects():
    """Count gc-tracked objects that are tensors or expose a tensor ``.data``.

    Presumably meant as a cheap leak indicator (see the commented-out
    ``Num_GC`` field in ``train_alt``'s status line). A count that keeps
    growing across batches suggests tensors are being kept alive.

    Returns:
        int: number of matching objects currently tracked by ``gc``.
    """
    count = 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                count += 1
        except Exception:
            # Some gc-tracked objects may raise on attribute access; they are not
            # counted. Narrowed from a bare `except:` so KeyboardInterrupt and
            # SystemExit still propagate.
            continue
    return count

In [None]:
# Debug listing: print the type and size of every gc-tracked tensor
# (or object exposing a tensor `.data`), to spot tensors kept alive.
for obj in gc.get_objects():
    try:
        if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
            print(type(obj), obj.size())
    except Exception:
        # Some gc-tracked objects may raise on attribute access; skip them, but let
        # KeyboardInterrupt through so a long listing can be interrupted.
        continue

In [None]:
# Same debug listing as the previous cell; re-run it at a later point to compare
# which tensors are still alive.
for obj in gc.get_objects():
    try:
        if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
            print(type(obj), obj.size())
    except Exception:
        # Some gc-tracked objects may raise on attribute access; skip them, but let
        # KeyboardInterrupt through so a long listing can be interrupted.
        continue