In [3]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from torchsummary import summary
import torch.nn.functional as F

from wideresnet import Wide_ResNet
from generator import Generator

'''
Function that loads the dataset and returns the data-loaders
'''
def getData(batch_size,test_batch_size,val_percentage):
    # Normalize the training set with data augmentation
    transform_train = transforms.Compose([ 
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        torchvision.transforms.RandomRotation(20),
        torchvision.transforms.ColorJitter(brightness=0.03, contrast=0.03, saturation=0.03, hue=0.03),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    
    # Normalize the test set same as training set without augmentation
    transform_test = transforms.Compose([ 
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    # Download/Load data
    full_training_data = torchvision.datasets.CIFAR10('./data',train = True,transform=transform_train,download=True)  
    test_data = torchvision.datasets.CIFAR10('./data',train = False,transform=transform_test,download=True)  

    # Create train and validation splits
    num_samples = len(full_training_data)
    training_samples = int((1-val_percentage)*num_samples+1)
    validation_samples = num_samples - training_samples
    training_data, validation_data = torch.utils.data.random_split(full_training_data, [training_samples, validation_samples])

    # Initialize dataloaders
    train_loader = torch.utils.data.DataLoader(training_data,batch_size=batch_size,shuffle=True)
    val_loader = torch.utils.data.DataLoader(validation_data,batch_size=batch_size,shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_data,batch_size=test_batch_size,shuffle=False,drop_last=False,num_workers=4)

    return train_loader, val_loader, test_loader

'''
Function to test that returns the loss per sample and the total accuracy
'''
def test(data_loader,net,cost_fun,device):
    net.eval()
    samples = 0.
    cumulative_loss = 0.
    cumulative_accuracy = 0.
    
    loss_funct = torch.nn.CrossEntropyLoss()

    for batch_idx, (inputs,targets) in enumerate(data_loader):

        inputs = inputs.to(device)
        targets = targets.to(device)

        outputs = net(inputs)[0]
        
        loss = loss_funct(outputs,targets)

        # Metrics computation
        samples+=inputs.shape[0]
        cumulative_loss += loss.item()
        _, predicted = outputs.max(1)
        cumulative_accuracy += predicted.eq(targets).sum().item()
    
    net.train()

    return cumulative_loss/samples, cumulative_accuracy/samples*100
'''
generator loss:
@output : logits of the student
@output : logits of the teacher

for the KL div as said here https://discuss.pytorch.org/t/kl-divergence-produces-negative-values/16791/4
and here https://discuss.pytorch.org/t/kullback-leibler-divergence-loss-function-giving-negative-values/763/2
the inputs should be logprobs for the output(student) and probabilities for the targets(teacher)

this was very difficult to undertand 

'''
def genLoss(output, target):
    student_pred = F.log_softmax(output)
    teacher_pred = F.softmax(target)
    
    loss = F.kl_div(student_pred,teacher_pred)
    minus_loss = -loss
    
    return minus_loss

def studentLoss(output,target):
    
    student_pred = F.log_softmax(output)
    teacher_pred = F.softmax(target)
    
    loss = F.kl_div(student_pred,teacher_pred)
    
    return loss




def main(n_batches,lr_gen,lr_stud,batch_size,test_batch_size,g_input_dim,ng,ns,test_freq):
    
    device = 'cuda:0'
    
    # Get the data
    train_loader, val_loader, test_loader = getData(batch_size,test_batch_size,0.1)
    
    test_loss = torch.nn.CrossEntropyLoss()
    
    teacher = Wide_ResNet(16,2,0,10)
    teacher = teacher.to(device)
    teacher.load_state_dict(torch.load('./pretrained_models/cifar_net_test.pth'))
    
    generator = Generator(z_dim=g_input_dim)
    generator = generator.to(device)
    generator.train()
    
    student = Wide_ResNet(16,1,0,10)
    student = student.to(device)
    
    generator_optim = torch.optim.Adam(generator.parameters(), lr=lr_gen)
    gen_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(generator_optim, n_batches)
    
    student_optim = torch.optim.Adam(student.parameters(), lr=lr_stud)
    stud_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(generator_optim, n_batches)
    
    print('Teacher net test:')
    test_loss, test_accuracy = test(test_loader,teacher,test_loss,device)
    print('\t Test loss: \t {:.6f}, \t Test accuracy \t {:.2f}'.format(test_loss, test_accuracy))
    
    print('Student net test:')
    test_loss, test_accuracy = test(test_loader,teacher,test_loss,device)
    print('\t Test loss: \t {:.6f}, \t Test accuracy \t {:.2f}'.format(test_loss, test_accuracy))
    
    for i in range(n_batches):
        print('Batch ' + str(i))
        noise = torch.randn(batch_size,g_input_dim)
        noise = noise.to(device)
        
        gen_loss_print = 0
        
        for j in range(ng):
            gen_imgs = generator(noise)
            gen_imgs = gen_imgs.to(device)

            teacher_pred, *teacher_activations = teacher(gen_imgs)
            student_pred, *student_activations = student(gen_imgs)

            gen_loss = genLoss(student_pred,teacher_pred)
            generator_optim.zero_grad()
            gen_loss.backward()

            generator_optim.step()
            
            gen_loss_print += gen_loss.item()
        
        print('Gen loss :' + str(gen_loss_print/ng) )
        
        stud_loss_print = 0
        for j in range(ns):
            student.train()
            gen_imgs = generator(noise)
            teacher_pred, *teacher_activations = teacher(gen_imgs)
            student_pred, *student_activations = student(gen_imgs)
            
            stud_loss = studentLoss(student_pred,teacher_pred)
            student_optim.zero_grad()
            stud_loss.backward()
            student_optim.step()
            
            stud_loss_print += stud_loss.item()
        
        print('Stud loss :' + str(stud_loss_print/ns) )
            
        stud_scheduler.step()
        gen_scheduler.step()
        
        if(i % test_freq) == 0:
            print('Student net test:')
            test_loss, test_accuracy = test(test_loader,teacher,test_loss,device)
            print('\t Test loss: \t {:.6f}, \t Test accuracy \t {:.2f}'.format(test_loss, test_accuracy))
    
n_batches = 100
lr_gen = 2e-3
lr_stud = 2e-3
batch_size = 128
test_batch_size = 128
g_input_dim = 100
ng = 1
ns = 10
test_freq = 5
    
main(n_batches,lr_gen,lr_stud,batch_size,test_batch_size,g_input_dim,ng,ns,test_freq)
    

Files already downloaded and verified
Files already downloaded and verified
| Wide-Resnet 16x2
| Wide-Resnet 16x1
Teacher net test:
	 Test loss: 	 0.004356, 	 Test accuracy 	 81.47
Student net test:
	 Test loss: 	 0.004356, 	 Test accuracy 	 81.47
Batch 0




Gen loss :-0.08278881758451462
Stud loss :0.09982872754335403
Student net test:
Files already downloaded and verified
Files already downloaded and verified
	 Test loss: 	 0.011429, 	 Test accuracy 	 59.36
Batch 1
Gen loss :-0.0761445164680481
Stud loss :0.08729197159409523
Batch 2
Gen loss :-0.08127801865339279
Stud loss :0.10628908537328244
Batch 3
Gen loss :-0.0946321040391922
Stud loss :0.11840420700609684
Batch 4
Gen loss :-0.08500321209430695
Stud loss :0.10360203720629216
Batch 5
Gen loss :-0.08993116766214371
Stud loss :0.11170846745371818
Student net test:
Files already downloaded and verified
Files already downloaded and verified
	 Test loss: 	 0.029878, 	 Test accuracy 	 30.74
Batch 6
Gen loss :-0.08656004816293716
Stud loss :0.10214017070829869
Batch 7
Gen loss :-0.09681618213653564
Stud loss :0.11374375633895398
Batch 8
Gen loss :-0.08284517377614975
Stud loss :0.09625374525785446
Batch 9
Gen loss :-0.08573966473340988
Stud loss :0.09663736969232559
Batch 10
Gen loss :-0.08

KeyboardInterrupt: 