In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from torchsummary import summary
import torch.nn.functional as F

from wideresnet import Wide_ResNet
from generator import Generator
from cifar10utils import getData, test

'''
generator loss:
@output : logits of the student
@output : logits of the teacher

for the KL div as said here https://discuss.pytorch.org/t/kl-divergence-produces-negative-values/16791/4
and here https://discuss.pytorch.org/t/kullback-leibler-divergence-loss-function-giving-negative-values/763/2
the inputs should be logprobs for the output(student) and probabilities for the targets(teacher)

this was very difficult to understand 

'''
def attention(x):
    """
    Taken from https://github.com/szagoruyko/attention-transfer
    :param x = activations
    """
    return F.normalize(x.pow(2).mean(1).view(x.size(0), -1))


def attention_diff(x, y):
    """
    Taken from https://github.com/szagoruyko/attention-transfer
    :param x = activations
    :param y = activations
    """
    return (attention(x) - attention(y)).pow(2).mean()


def divergence(student_logits, teacher_logits):
    divergence = F.kl_div(F.log_softmax(student_logits, dim=1), F.softmax(teacher_logits, dim=1))

    return divergence


def KT_loss_generator(student_logits, teacher_logits):

    divergence_loss = divergence(student_logits, teacher_logits)
    total_loss = - divergence_loss

    return total_loss


def KT_loss_student(student_logits, teacher_logits, student_activations, teacher_activations,beta):

    divergence_loss = divergence(student_logits, teacher_logits)
    if beta > 0:
        at_loss = 0
        for i in range(len(student_activations)):
            at_loss = at_loss + beta * attention_diff(student_activations[i], teacher_activations[i])
    else:
        at_loss = 0

    total_loss = divergence_loss + at_loss

    return total_loss



def main(n_batches,lr_gen,lr_stud,batch_size,test_batch_size,g_input_dim,ng,ns,test_freq,beta):
    
    device = 'cuda:0'
    
    # Get the data
    train_loader, val_loader, test_loader = getData(batch_size,test_batch_size,0.1)
    
    #test_loss = torch.nn.CrossEntropyLoss()
    
    teacher = Wide_ResNet(16,2,0,10)
    teacher = teacher.to(device)
    teacher.load_state_dict(torch.load('./pretrained_models/cifar_net_test.pth'))
    
    generator = Generator(z_dim=g_input_dim)
    generator = generator.to(device)
    generator.train()
    
    student = Wide_ResNet(16,1,0,10)
    student = student.to(device)
    
    generator_optim = torch.optim.Adam(generator.parameters(), lr=lr_gen)
    gen_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(generator_optim, n_batches)
    
    student_optim = torch.optim.Adam(student.parameters(), lr=lr_stud)
    stud_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(generator_optim, n_batches)
    
    print('Teacher net test:')
    test_loss, test_accuracy = test(test_loader,teacher,device)
    teacher.eval()
    print('\t Test loss: \t {:.6f}, \t Test accuracy \t {:.2f}'.format(test_loss, test_accuracy))
    
    print('Student net test:')
    test_loss, test_accuracy = test(test_loader,student,device)
    print('\t Test loss: \t {:.6f}, \t Test accuracy \t {:.2f}'.format(test_loss, test_accuracy))
    
    for i in range(n_batches):
        print('Batch ' + str(i))
        noise = torch.randn(batch_size,g_input_dim)
        noise = noise.to(device)
        
        gen_loss_print = 0
        
        for j in range(ng):
            gen_imgs = generator(noise)
            gen_imgs = gen_imgs.to(device)

            teacher_pred, *teacher_activations = teacher(gen_imgs)
            student_pred, *student_activations = student(gen_imgs)

            gen_loss = KT_loss_generator(student_pred,teacher_pred)
            generator_optim.zero_grad()
            gen_loss.backward()
            torch.nn.utils.clip_grad_norm_(generator.parameters(), 5)

            generator_optim.step()
            
            gen_loss_print += gen_loss.item()
        
        print('Gen loss :' + str(gen_loss_print/ng) )
        
        stud_loss_print = 0
        for j in range(ns):
            student.train()
            gen_imgs = generator(noise)
            teacher_pred, *teacher_activations = teacher(gen_imgs)
            student_pred, *student_activations = student(gen_imgs)
            
            stud_loss = KT_loss_student(student_pred,teacher_pred, student_activations,teacher_activations, beta )
            student_optim.zero_grad()
            stud_loss.backward()
            torch.nn.utils.clip_grad_norm_(student.parameters(), 5)
            student_optim.step()
            
            stud_loss_print += stud_loss.item()
        
        print('Stud loss :' + str(stud_loss_print/ns) )
            
        stud_scheduler.step()
        gen_scheduler.step()
        
        if(i % test_freq) == 0:
            print('Student net test:')
            test_loss, test_accuracy = test(test_loader,student,device)
            print('\t Test loss: \t {:.6f}, \t Test accuracy \t {:.2f}'.format(test_loss, test_accuracy))
    
n_batches = 2001
lr_gen = 1e-3
lr_stud = 2e-3
batch_size = 128
test_batch_size = 128
g_input_dim = 100
ng = 1
ns = 10
test_freq = 20
beta = 250
    
main(n_batches,lr_gen,lr_stud,batch_size,test_batch_size,g_input_dim,ng,ns,test_freq,beta)
    

Files already downloaded and verified
Files already downloaded and verified
| Wide-Resnet 16x2
| Wide-Resnet 16x1
Teacher net test:
	 Test loss: 	 0.004332, 	 Test accuracy 	 81.45
Student net test:
	 Test loss: 	 0.018027, 	 Test accuracy 	 10.10
Batch 0




Gen loss :-0.19704370200634003
Stud loss :0.4913065046072006
Student net test:
	 Test loss: 	 0.029392, 	 Test accuracy 	 10.02
Batch 1
Gen loss :-0.11669502407312393
Stud loss :0.2732813894748688
Batch 2
Gen loss :-0.05323195457458496
Stud loss :0.18428010120987892
Batch 3
Gen loss :-0.02305629849433899
Stud loss :0.14837367609143257
Batch 4
Gen loss :-0.011543425731360912
Stud loss :0.13176514357328414
Batch 5
Gen loss :-0.006782103329896927
Stud loss :0.1135058969259262
Batch 6
Gen loss :-0.004704089369624853
Stud loss :0.10592681169509888
Batch 7
Gen loss :-0.0035686076153069735
Stud loss :0.10211708024144173
Batch 8
Gen loss :-0.0029252253007143736
Stud loss :0.09694579690694809
Batch 9
Gen loss :-0.0025428382214158773
Stud loss :0.09644677266478538
Batch 10
Gen loss :-0.0021936490666121244
Stud loss :0.08962583690881729
Batch 11
Gen loss :-0.001921864808537066
Stud loss :0.08321783281862735
Batch 12
Gen loss :-0.0016996437916532159
Stud loss :0.07874535173177719
Batch 13
Gen loss

Gen loss :-0.000896961078979075
Stud loss :0.008262773789465428
Batch 108
Gen loss :-0.0008375655161216855
Stud loss :0.008369325753301382
Batch 109
Gen loss :-0.0010189767926931381
Stud loss :0.008486206363886594
Batch 110
Gen loss :-0.0011325024534016848
Stud loss :0.008512733737006783
Batch 111
Gen loss :-0.000875613302923739
Stud loss :0.008723379112780094
Batch 112
Gen loss :-0.0010550027946010232
Stud loss :0.008357910485938192
Batch 113
Gen loss :-0.000863665365613997
Stud loss :0.008813877310603856
Batch 114
Gen loss :-0.0007741716108284891
Stud loss :0.008682869747281075
Batch 115
Gen loss :-0.0008438844233751297
Stud loss :0.008472861815243959
Batch 116
Gen loss :-0.0007698810077272356
Stud loss :0.00844981879927218
Batch 117
Gen loss :-0.0010637718951329589
Stud loss :0.009032769221812486
Batch 118
Gen loss :-0.0011014320189133286
Stud loss :0.00891452506184578
Batch 119
Gen loss :-0.000991043052636087
Stud loss :0.009128301497548818
Batch 120
Gen loss :-0.000964379112701863

Gen loss :-0.013296491466462612
Stud loss :0.06292060576379299
Batch 215
Gen loss :-0.01769843138754368
Stud loss :0.06683884486556053
Batch 216
Gen loss :-0.012832269072532654
Stud loss :0.06103239506483078
Batch 217
Gen loss :-0.014526220969855785
Stud loss :0.06726521886885166
Batch 218
Gen loss :-0.014619332738220692
Stud loss :0.06932308673858642
Batch 219
Gen loss :-0.015457344241440296
Stud loss :0.07093526870012283
Batch 220
Gen loss :-0.01695954240858555
Stud loss :0.07640839889645576
Student net test:
	 Test loss: 	 0.032012, 	 Test accuracy 	 20.54
Batch 221
Gen loss :-0.017638621851801872
Stud loss :0.08135442323982715
Batch 222
Gen loss :-0.020806578919291496
Stud loss :0.09403396397829056
Batch 223
Gen loss :-0.017850151285529137
Stud loss :0.08773765712976456
Batch 224
Gen loss :-0.017965255305171013
Stud loss :0.097135541588068
Batch 225
Gen loss :-0.02894476056098938
Stud loss :0.11751035377383232
Batch 226
Gen loss :-0.03665146231651306
Stud loss :0.14653602987527847


Gen loss :-0.005674341227859259
Stud loss :0.02499695047736168
Batch 322
Gen loss :-0.006280349101871252
Stud loss :0.026569523103535175
Batch 323
Gen loss :-0.005274009890854359
Stud loss :0.02457724343985319
Batch 324
Gen loss :-0.0050397636368870735
Stud loss :0.02521888930350542
Batch 325
Gen loss :-0.005437081679701805
Stud loss :0.02426519487053156
Batch 326
Gen loss :-0.005512375850230455
Stud loss :0.02617575563490391
Batch 327
Gen loss :-0.004894312471151352
Stud loss :0.025589629635214807
Batch 328
Gen loss :-0.004598430823534727
Stud loss :0.024301376193761826
Batch 329
Gen loss :-0.005401795729994774
Stud loss :0.02346032653003931
Batch 330
Gen loss :-0.005309609696269035
Stud loss :0.02348516136407852
Batch 331
Gen loss :-0.005885121878236532
Stud loss :0.02485437449067831
Batch 332
Gen loss :-0.005510570947080851
Stud loss :0.025778743624687194
Batch 333
Gen loss :-0.005005206447094679
Stud loss :0.024989986792206764
Batch 334
Gen loss :-0.005162648856639862
Stud loss :0.

Stud loss :0.061992093548178674
Batch 429
Gen loss :-0.017310766503214836
Stud loss :0.06058026514947414
Batch 430
Gen loss :-0.0160395335406065
Stud loss :0.06265334747731685
Batch 431
Gen loss :-0.015707775950431824
Stud loss :0.06468667685985566
Batch 432
Gen loss :-0.01755817048251629
Stud loss :0.06796149872243404
Batch 433
Gen loss :-0.01925632916390896
Stud loss :0.06750903874635697
Batch 434
Gen loss :-0.017589304596185684
Stud loss :0.0655752331018448
Batch 435
Gen loss :-0.020715558901429176
Stud loss :0.07155461311340332
Batch 436
Gen loss :-0.01896822825074196
Stud loss :0.06966598182916642
Batch 437
Gen loss :-0.01764996163547039
Stud loss :0.0722113836556673
Batch 438
Gen loss :-0.018457448109984398
Stud loss :0.07161606587469578
Batch 439
Gen loss :-0.017858406528830528
Stud loss :0.07044531740248203
Batch 440
Gen loss :-0.01894241012632847
Stud loss :0.0673159345984459
Student net test:
	 Test loss: 	 0.024832, 	 Test accuracy 	 27.25
Batch 441
Gen loss :-0.019005078822

Gen loss :-0.012724495492875576
Stud loss :0.0667992316186428
Batch 538
Gen loss :-0.01578715816140175
Stud loss :0.06747915260493756
Batch 539
Gen loss :-0.013189414516091347
Stud loss :0.06488748826086521
Batch 540
Gen loss :-0.014302290976047516
Stud loss :0.06768500693142414
Student net test:
	 Test loss: 	 0.023061, 	 Test accuracy 	 31.19
Batch 541
Gen loss :-0.013648184947669506
Stud loss :0.06790213249623775
Batch 542
Gen loss :-0.015007617883384228
Stud loss :0.06888451017439365
Batch 543
Gen loss :-0.013899916782975197
Stud loss :0.07193171456456185
Batch 544
Gen loss :-0.015490707941353321
Stud loss :0.06886624731123447
Batch 545
Gen loss :-0.01291600801050663
Stud loss :0.06423595771193505
Batch 546
Gen loss :-0.013468257151544094
Stud loss :0.06615377403795719
Batch 547
Gen loss :-0.013246064074337482
Stud loss :0.06796102784574032
Batch 548
Gen loss :-0.017564764246344566
Stud loss :0.07312839776277542
Batch 549
Gen loss :-0.014015751890838146
Stud loss :0.067386832833290

Gen loss :-0.009551974944770336
Stud loss :0.04312584213912487
Batch 645
Gen loss :-0.008580345660448074
Stud loss :0.042691930383443835
Batch 646
Gen loss :-0.010682939551770687
Stud loss :0.044860322028398514
Batch 647
Gen loss :-0.008827977813780308
Stud loss :0.04421420246362686
Batch 648
Gen loss :-0.01091244351118803
Stud loss :0.04642755016684532
Batch 649
Gen loss :-0.012146525084972382
Stud loss :0.04593916870653629
Batch 650
Gen loss :-0.009357147850096226
Stud loss :0.044872748851776126
Batch 651
Gen loss :-0.010033870115876198
Stud loss :0.045906081795692444
Batch 652
Gen loss :-0.010995668359100819
Stud loss :0.0469194546341896
Batch 653
Gen loss :-0.01053842157125473
Stud loss :0.04718719683587551
Batch 654
Gen loss :-0.008625208400189877
Stud loss :0.045053808391094206
Batch 655
Gen loss :-0.012260599993169308
Stud loss :0.04776026606559754
Batch 656
Gen loss :-0.010556647554039955
Stud loss :0.04540090747177601
Batch 657
Gen loss :-0.0096736503764987
Stud loss :0.044739

Gen loss :-0.010812784545123577
Stud loss :0.05181880816817284
Batch 753
Gen loss :-0.01050406601279974
Stud loss :0.052983736619353294
Batch 754
Gen loss :-0.010474525392055511
Stud loss :0.05050730630755425
Batch 755
Gen loss :-0.012393190525472164
Stud loss :0.05236758813261986
Batch 756
Gen loss :-0.013643271289765835
Stud loss :0.053251459077000615
Batch 757
Gen loss :-0.017322083935141563
Stud loss :0.05518090277910233
Batch 758
Gen loss :-0.011673521250486374
Stud loss :0.053853008151054385
Batch 759
Gen loss :-0.0105039207264781
Stud loss :0.053537018969655036
Batch 760
Gen loss :-0.011148606427013874
Stud loss :0.05095738507807255
Student net test:
	 Test loss: 	 0.015065, 	 Test accuracy 	 42.21
Batch 761
Gen loss :-0.01106600183993578
Stud loss :0.05279076285660267
Batch 762
Gen loss :-0.011599804274737835
Stud loss :0.051365474611520766
Batch 763
Gen loss :-0.009379680268466473
Stud loss :0.05079679265618324
Batch 764
Gen loss :-0.011346030049026012
Stud loss :0.05252078138

Stud loss :0.07013584598898888
Student net test:
	 Test loss: 	 0.013782, 	 Test accuracy 	 49.23
Batch 861
Gen loss :-0.018406635150313377
Stud loss :0.06980278566479683
Batch 862
Gen loss :-0.017843617126345634
Stud loss :0.0703457809984684
Batch 863
Gen loss :-0.019205769523978233
Stud loss :0.06897399611771107
Batch 864
Gen loss :-0.019426319748163223
Stud loss :0.07875583320856094
Batch 865
Gen loss :-0.01612558774650097
Stud loss :0.07071887515485287
Batch 866
Gen loss :-0.021018078550696373
Stud loss :0.07007804363965989
Batch 867
Gen loss :-0.017764829099178314
Stud loss :0.06875250935554504
Batch 868
Gen loss :-0.020840153098106384
Stud loss :0.07320020087063313
Batch 869
Gen loss :-0.019828563556075096
Stud loss :0.07702579647302628
Batch 870
Gen loss :-0.015815867111086845
Stud loss :0.06776966378092766
Batch 871
Gen loss :-0.02031274512410164
Stud loss :0.07147916778922081
Batch 872
Gen loss :-0.01810913160443306
Stud loss :0.06772235073149205
Batch 873
Gen loss :-0.0161470

Gen loss :-0.01650509051978588
Stud loss :0.06954846940934659
Batch 969
Gen loss :-0.018038397654891014
Stud loss :0.07085665874183178
Batch 970
Gen loss :-0.016478247940540314
Stud loss :0.06904644779860973
Batch 971
Gen loss :-0.014911608770489693
Stud loss :0.07061453759670258
Batch 972
Gen loss :-0.016346966847777367
Stud loss :0.06957094185054302
Batch 973
Gen loss :-0.01713079772889614
Stud loss :0.07271518111228943
Batch 974
Gen loss :-0.01659664325416088
Stud loss :0.07263014018535614
Batch 975
Gen loss :-0.018957043066620827
Stud loss :0.0738187287002802
Batch 976
Gen loss :-0.01701853610575199
Stud loss :0.06991328671574593
Batch 977
Gen loss :-0.01804967410862446
Stud loss :0.07184542641043663
Batch 978
Gen loss :-0.017708001658320427
Stud loss :0.07271864041686057
Batch 979
Gen loss :-0.015784909948706627
Stud loss :0.06935142688453197
Batch 980
Gen loss :-0.01815096102654934
Stud loss :0.0696093924343586
Student net test:
	 Test loss: 	 0.009494, 	 Test accuracy 	 60.22
Ba

Stud loss :0.0750728689134121
Batch 1076
Gen loss :-0.014989945106208324
Stud loss :0.07316587269306182
Batch 1077
Gen loss :-0.016561217606067657
Stud loss :0.071696712449193
Batch 1078
Gen loss :-0.016859715804457664
Stud loss :0.06959822438657284
Batch 1079
Gen loss :-0.018478747457265854
Stud loss :0.07356437109410763
Batch 1080
Gen loss :-0.01756061054766178
Stud loss :0.07138077579438687
Student net test:
	 Test loss: 	 0.008893, 	 Test accuracy 	 62.17
Batch 1081
Gen loss :-0.020658573135733604
Stud loss :0.07341547906398774
Batch 1082
Gen loss :-0.01764518767595291
Stud loss :0.07252166792750359
Batch 1083
Gen loss :-0.017481118440628052
Stud loss :0.07111995816230773
Batch 1084
Gen loss :-0.01766776107251644
Stud loss :0.0728283528238535
Batch 1085
Gen loss :-0.016149157658219337
Stud loss :0.06911525838077068
Batch 1086
Gen loss :-0.01830573007464409
Stud loss :0.07532687000930309
Batch 1087
Gen loss :-0.01817811094224453
Stud loss :0.07491683848202228
Batch 1088
Gen loss :-0

Gen loss :-0.01841883361339569
Stud loss :0.0741082090884447
Batch 1183
Gen loss :-0.01892874203622341
Stud loss :0.07557560727000237
Batch 1184
Gen loss :-0.01710687205195427
Stud loss :0.07236929684877395
Batch 1185
Gen loss :-0.017530953511595726
Stud loss :0.06939261928200721
Batch 1186
Gen loss :-0.019083520397543907
Stud loss :0.06972380094230175
Batch 1187
Gen loss :-0.01826414093375206
Stud loss :0.07175561077892781
Batch 1188
Gen loss :-0.019740935415029526
Stud loss :0.07231424935162067
Batch 1189
Gen loss :-0.016104508191347122
Stud loss :0.06593220196664333
Batch 1190
Gen loss :-0.01835625246167183
Stud loss :0.07011389695107936
Batch 1191
Gen loss :-0.017530998215079308
Stud loss :0.06835129708051682
Batch 1192
Gen loss :-0.017855441197752953
Stud loss :0.07284406200051308
Batch 1193
Gen loss :-0.019434845075011253
Stud loss :0.07032561711966992
Batch 1194
Gen loss :-0.015129977837204933
Stud loss :0.07022701650857925
Batch 1195
Gen loss :-0.01703673042356968
Stud loss :0.

KeyboardInterrupt: 