# Imports, global constants, functions

In [1]:
import torch
import torchvision
from models.model_conv import ConvNet
import torchvision.transforms as transforms
import math
import copy
import time
import pickle
from datetime import datetime
import os
from matplotlib import pyplot as plt
import matplotlib as mpl
import numpy as np
from grad_utils import *

In [2]:
# Settings consumed by the DataLoader built further down.
BATCH_SIZE = 128
NUM_WORKERS = 32
PIN_MEMORY = True
NUM_EPOCHS = 100000
# Length of the flattened parameter vector. This literal is only a starting
# value: a later cell recomputes GRAD_DIM from the loaded model's parameters.
GRAD_DIM = 247434
# Each run writes its artifacts into its own directory named after the
# current timestamp (second resolution).
PATH = "./generated_data/" + datetime.today().strftime("%Y%m%d%H%M%S")
# makedirs instead of mkdir: a missing ./generated_data parent is created too,
# and exist_ok keeps a re-run within the same second from raising.
os.makedirs(PATH, exist_ok = True)

In [3]:
# Preprocessing applied to every image: convert to a tensor, then subtract 0.5
# and divide by 0.5 on each of the three channels.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5)),
])

In [4]:
# CIFAR-10 split selected by train = False, stored under ./data and downloaded
# when absent; every image goes through `transform`.
test_set = torchvision.datasets.CIFAR10(
    root = "./data", train = False, download = True, transform = transform,
)

# Unshuffled loader over that split, configured by the notebook constants.
test_loader = torch.utils.data.DataLoader(
    dataset = test_set,
    batch_size = BATCH_SIZE,
    shuffle = False,
    num_workers = NUM_WORKERS,
    pin_memory = PIN_MEMORY,
)

Files already downloaded and verified


In [5]:
def load_model_conv(device_str, model_num):
    """Build a ConvNet in eval mode on `device_str` and load its saved weights.

    Args:
        device_str: Torch device string (e.g. "cuda:6") the model is moved to.
        model_num: Accepted for call compatibility but not used; the checkpoint
            path below is fixed.

    Returns:
        The ConvNet holding the weights from ./models/model_conv.pt, in eval
        mode, with requires_grad set to True on every parameter.
    """
    model = ConvNet().eval().to(device_str)
    for param in model.parameters():
        param.requires_grad = True
    # map_location deserializes the checkpoint straight onto the target device.
    # Without it torch.load restores each tensor onto the device it was saved
    # from, which fails when that device is not available on this machine.
    state_dict = torch.load("./models/model_conv.pt", map_location = device_str)
    model.load_state_dict(state_dict)
    return model

# Gradients

In [6]:
# Load the ConvNet onto device "cuda:6". The second argument (16) is passed as
# model_num, which load_model_conv accepts but does not read.
model = load_model_conv("cuda:6", 16)

In [7]:
# Number of parameter tensors the model exposes (tensors, not scalar weights).
len(list(model.parameters()))

12

In [8]:
# get_data is provided by grad_utils.
# NOTE(review): presumably samples 2000 inputs out of the 10000 test images
# and records them under PATH — confirm against grad_utils.get_data.
random_data = get_data(2000, test_loader, 10000, PATH)
# Stack the mapping's values row-wise into a single tensor; keys are dropped.
random_batch = torch.vstack(list(random_data.values()))

In [9]:
# Materialize the model's parameter tensors as a list, in iteration order.
paramlist = list(model.parameters())

In [10]:
# Forward pass over the whole stacked batch on cuda:6.
y = model(random_batch.to("cuda:6"))
# model.h is read right after the forward pass (presumably populated by it —
# confirm in ConvNet); a CPU copy is written into the run directory.
torch.save(model.h.cpu(), f"{PATH}/h_values.pt")
# Pickle the current parameter tensors alongside the h values.
paramlist = list(model.parameters())
with open(f"{PATH}/weights.pickle", "wb") as weights_file:
    pickle.dump(paramlist, weights_file)

In [None]:
# Gradients computed by grad_utils.get_grads_per_layer from the model outputs
# `y`, the model itself and the run directory.
# NOTE(review): judging by the logged output, the helper reports "done with N"
# once per input followed by what looks like an elapsed time of ~0.7 s, so the
# full run is long — confirm in grad_utils.
grads = get_grads_per_layer(y, model, PATH)

done with 1
-- 0.7835016250610352 --
done with 2
-- 0.7174558639526367 --
done with 3
-- 0.7132217884063721 --
done with 4
-- 0.7186088562011719 --
done with 5
-- 0.7206039428710938 --
done with 6
-- 0.7134828567504883 --
done with 7
-- 0.7112386226654053 --
done with 8
-- 0.7129058837890625 --
done with 9
-- 0.710667610168457 --
done with 10
-- 0.7100358009338379 --
done with 11
-- 0.7141132354736328 --
done with 12
-- 0.7113535404205322 --
done with 13
-- 0.7150840759277344 --
done with 14
-- 0.7127516269683838 --
done with 15
-- 0.714486837387085 --
done with 16
-- 0.7097089290618896 --
done with 17
-- 0.715679407119751 --
done with 18
-- 0.7098386287689209 --
done with 19
-- 0.7154951095581055 --
done with 20
-- 0.7179770469665527 --
done with 21
-- 0.7134251594543457 --
done with 22
-- 0.7109014987945557 --
done with 23
-- 0.7101082801818848 --
done with 24
-- 0.7176878452301025 --
done with 25
-- 0.7119271755218506 --
done with 26
-- 0.71407151222229 --
done with 27
-- 0.71368551

done with 215
-- 0.7148234844207764 --
done with 216
-- 0.7120881080627441 --
done with 217
-- 0.7153773307800293 --
done with 218
-- 0.7110800743103027 --
done with 219
-- 0.7116124629974365 --
done with 220
-- 0.7143609523773193 --
done with 221
-- 0.7132034301757812 --
done with 222
-- 0.7145724296569824 --
done with 223
-- 0.7153675556182861 --
done with 224
-- 0.713798999786377 --
done with 225
-- 0.7136082649230957 --
done with 226
-- 0.7133529186248779 --
done with 227
-- 0.7176415920257568 --
done with 228
-- 0.7190403938293457 --
done with 229
-- 0.7201931476593018 --
done with 230
-- 0.7234780788421631 --
done with 231
-- 0.7190942764282227 --
done with 232
-- 0.720391035079956 --
done with 233
-- 0.7221865653991699 --
done with 234
-- 0.7212393283843994 --
done with 235
-- 0.7196524143218994 --
done with 236
-- 0.7200727462768555 --
done with 237
-- 0.7217080593109131 --
done with 238
-- 0.7167575359344482 --
done with 239
-- 0.7204041481018066 --
done with 240
-- 0.72003030

done with 426
-- 0.7169063091278076 --
done with 427
-- 0.7171390056610107 --
done with 428
-- 0.7181508541107178 --
done with 429
-- 0.7157735824584961 --
done with 430
-- 0.7198221683502197 --
done with 431
-- 0.716484785079956 --
done with 432
-- 0.7178316116333008 --
done with 433
-- 0.7195894718170166 --
done with 434
-- 0.716310977935791 --
done with 435
-- 0.7169630527496338 --
done with 436
-- 0.7217206954956055 --
done with 437
-- 0.7159347534179688 --
done with 438
-- 0.7170190811157227 --
done with 439
-- 0.7162339687347412 --
done with 440
-- 0.7159972190856934 --
done with 441
-- 0.7149860858917236 --
done with 442
-- 0.7173833847045898 --
done with 443
-- 0.71897292137146 --
done with 444
-- 0.7167794704437256 --
done with 445
-- 0.7142825126647949 --
done with 446
-- 0.717332124710083 --
done with 447
-- 0.7160396575927734 --
done with 448
-- 0.7163174152374268 --
done with 449
-- 0.717083215713501 --
done with 450
-- 0.7164874076843262 --
done with 451
-- 0.718695640563

done with 637
-- 0.716789722442627 --
done with 638
-- 0.7189521789550781 --
done with 639
-- 0.7193281650543213 --
done with 640
-- 0.7211246490478516 --
done with 641
-- 0.716963529586792 --
done with 642
-- 0.7198150157928467 --
done with 643
-- 0.7176468372344971 --
done with 644
-- 0.7182254791259766 --
done with 645
-- 0.717778205871582 --
done with 646
-- 0.720771074295044 --
done with 647
-- 0.7175252437591553 --
done with 648
-- 0.7159378528594971 --
done with 649
-- 0.7172775268554688 --
done with 650
-- 0.7174057960510254 --
done with 651
-- 0.7163779735565186 --
done with 652
-- 0.7171025276184082 --
done with 653
-- 0.7183587551116943 --
done with 654
-- 0.7177455425262451 --
done with 655
-- 0.7215356826782227 --
done with 656
-- 0.7169034481048584 --
done with 657
-- 0.7198522090911865 --
done with 658
-- 0.7208733558654785 --
done with 659
-- 0.7210631370544434 --
done with 660
-- 0.7181684970855713 --
done with 661
-- 0.7201836109161377 --
done with 662
-- 0.7198188304

done with 848
-- 0.7192707061767578 --
done with 849
-- 0.7170989513397217 --
done with 850
-- 0.7211499214172363 --
done with 851
-- 0.720170259475708 --
done with 852
-- 0.7195580005645752 --
done with 853
-- 0.7205007076263428 --
done with 854
-- 0.7220206260681152 --
done with 855
-- 0.7270617485046387 --
done with 856
-- 0.7181248664855957 --
done with 857
-- 0.7183818817138672 --
done with 858
-- 0.7209005355834961 --
done with 859
-- 0.7194476127624512 --
done with 860
-- 0.7187604904174805 --
done with 861
-- 0.7190859317779541 --
done with 862
-- 0.7206218242645264 --
done with 863
-- 0.7206635475158691 --
done with 864
-- 0.7139427661895752 --
done with 865
-- 0.7138819694519043 --
done with 866
-- 0.7150905132293701 --
done with 867
-- 0.7126812934875488 --
done with 868
-- 0.7170073986053467 --
done with 869
-- 0.7128074169158936 --
done with 870
-- 0.7152583599090576 --
done with 871
-- 0.7124621868133545 --
done with 872
-- 0.7136731147766113 --
done with 873
-- 0.7181503

done with 1058
-- 0.7249729633331299 --
done with 1059
-- 0.7271411418914795 --
done with 1060
-- 0.7298979759216309 --
done with 1061
-- 0.7223846912384033 --
done with 1062
-- 0.7242779731750488 --
done with 1063
-- 0.7228794097900391 --
done with 1064
-- 0.7243146896362305 --
done with 1065
-- 0.7247884273529053 --
done with 1066
-- 0.7270634174346924 --
done with 1067
-- 0.725008487701416 --
done with 1068
-- 0.7247259616851807 --
done with 1069
-- 0.7244012355804443 --
done with 1070
-- 0.7248220443725586 --
done with 1071
-- 0.726870059967041 --
done with 1072
-- 0.7240853309631348 --
done with 1073
-- 0.7257857322692871 --
done with 1074
-- 0.7262489795684814 --
done with 1075
-- 0.7251241207122803 --
done with 1076
-- 0.7233734130859375 --
done with 1077
-- 0.7247636318206787 --
done with 1078
-- 0.7261130809783936 --
done with 1079
-- 0.7274527549743652 --
done with 1080
-- 0.7242541313171387 --
done with 1081
-- 0.7241086959838867 --
done with 1082
-- 0.7225828170776367 --
do

done with 1264
-- 0.7256598472595215 --
done with 1265
-- 0.7236599922180176 --
done with 1266
-- 0.7245378494262695 --
done with 1267
-- 0.7230086326599121 --
done with 1268
-- 0.7265417575836182 --
done with 1269
-- 0.7243659496307373 --
done with 1270
-- 0.7248344421386719 --
done with 1271
-- 0.7264976501464844 --
done with 1272
-- 0.7239954471588135 --
done with 1273
-- 0.7271180152893066 --
done with 1274
-- 0.7224602699279785 --
done with 1275
-- 0.7254269123077393 --
done with 1276
-- 0.7263948917388916 --
done with 1277
-- 0.7260770797729492 --
done with 1278
-- 0.7239096164703369 --
done with 1279
-- 0.7218067646026611 --
done with 1280
-- 0.7263848781585693 --
done with 1281
-- 0.725299596786499 --
done with 1282
-- 0.7322494983673096 --
done with 1283
-- 0.7253456115722656 --
done with 1284
-- 0.7237148284912109 --
done with 1285
-- 0.7292282581329346 --
done with 1286
-- 0.7238564491271973 --
done with 1287
-- 0.7249109745025635 --
done with 1288
-- 0.7253694534301758 --
d

done with 1470
-- 0.7283318042755127 --
done with 1471
-- 0.7263240814208984 --
done with 1472
-- 0.7257423400878906 --
done with 1473
-- 0.7255785465240479 --
done with 1474
-- 0.7235655784606934 --
done with 1475
-- 0.7303385734558105 --
done with 1476
-- 0.7249391078948975 --
done with 1477
-- 0.7229619026184082 --
done with 1478
-- 0.7254128456115723 --
done with 1479
-- 0.7262024879455566 --
done with 1480
-- 0.7237939834594727 --
done with 1481
-- 0.7254805564880371 --
done with 1482
-- 0.7243263721466064 --
done with 1483
-- 0.7263412475585938 --
done with 1484
-- 0.7245917320251465 --
done with 1485
-- 0.7242422103881836 --
done with 1486
-- 0.7248342037200928 --
done with 1487
-- 0.7240426540374756 --
done with 1488
-- 0.7242574691772461 --
done with 1489
-- 0.7251884937286377 --
done with 1490
-- 0.7259521484375 --
done with 1491
-- 0.7245104312896729 --
done with 1492
-- 0.7253894805908203 --
done with 1493
-- 0.7267289161682129 --
done with 1494
-- 0.723750114440918 --
done

done with 1676
-- 0.727849006652832 --
done with 1677
-- 0.7310404777526855 --
done with 1678
-- 0.7303681373596191 --
done with 1679
-- 0.7291948795318604 --
done with 1680
-- 0.7295804023742676 --
done with 1681
-- 0.7288455963134766 --
done with 1682
-- 0.7264323234558105 --
done with 1683
-- 0.7250156402587891 --
done with 1684
-- 0.7283740043640137 --
done with 1685
-- 0.7271649837493896 --
done with 1686
-- 0.7240259647369385 --
done with 1687
-- 0.7261950969696045 --
done with 1688
-- 0.726243257522583 --
done with 1689
-- 0.7277922630310059 --
done with 1690
-- 0.7262821197509766 --
done with 1691
-- 0.728813648223877 --
done with 1692
-- 0.730017900466919 --
done with 1693
-- 0.7289097309112549 --
done with 1694
-- 0.7258014678955078 --
done with 1695
-- 0.726768970489502 --
done with 1696
-- 0.7320220470428467 --
done with 1697
-- 0.7294585704803467 --
done with 1698
-- 0.7280197143554688 --
done with 1699
-- 0.7273690700531006 --
done with 1700
-- 0.7302236557006836 --
done 

In [None]:
# Total number of scalar entries across all parameter tensors; this replaces
# the literal assigned to GRAD_DIM in the constants cell.
GRAD_DIM = sum(layer.numel() for layer in paramlist)
GRAD_DIM

In [None]:
# Result of grad_utils.get_flattened_summed_grads applied to the per-layer
# gradients.
# NOTE(review): presumably one flattened gradient vector per input (downstream
# cells iterate over its rows) — confirm in grad_utils.
flattened_grads = get_flattened_summed_grads(grads)

In [None]:
# Independent copy of the flattened gradients, used as the un-normalized input
# by the cells below.
unnormed_grads = torch.clone(flattened_grads)

In [None]:
# One flattened 1-D tensor per parameter tensor, in model order.
blocklist = [torch.flatten(layer) for layer in paramlist]

In [None]:
# All parameter tensors flattened and concatenated into a single 1-D tensor.
flattened_params = torch.cat([torch.flatten(layer) for layer in paramlist])

In [None]:
# Read back the h values that were saved into the run directory.
h_values = torch.load(f"{PATH}/h_values.pt")

In [None]:
# Min-max rescale each row of h_values independently into [0, 1].
# (Alternative tried earlier: divide each row by its maximum.)
# NOTE(review): a constant row makes the denominator zero and produces NaNs.
normed_h_values = torch.stack([
    (row - row.min()) / (row.max() - row.min())
    for row in h_values
])
# calculate_inner_products comes from grad_utils; this call passes an empty
# metric string, disables to_norm and runs on cuda:7.
in_h_maxnorm, out_h_maxnorm = calculate_inner_products(
    normed_h_values,
    GRAD_DIM,
    weights = blocklist,
    metric = "",
    to_norm = False,
    device = "cuda:7",
)
# Only the first of calculate_gap's three return values is kept.
gap, _, _ = calculate_gap(in_h_maxnorm, out_h_maxnorm)

In [None]:
# Min-max rescale each row of the raw gradients independently into [0, 1].
# (Alternatives tried earlier: divide each row by its maximum, or divide the
# whole tensor by its global maximum.)
# NOTE(review): a constant row makes the denominator zero and produces NaNs.
normed_grads = torch.stack([
    (row - row.min()) / (row.max() - row.min())
    for row in unnormed_grads
])
# Same grad_utils helper as for the h values, here with metric "block" and the
# per-layer flattened weights, on cuda:7.
in_full_maxnorm, out_full_maxnorm = calculate_inner_products(
    normed_grads,
    GRAD_DIM,
    weights = blocklist,
    metric = "block",
    to_norm = False,
    device = "cuda:7",
)
gap2, v1, v2 = calculate_gap(in_full_maxnorm, out_full_maxnorm)

In [None]:
# Sparsify the raw gradients with grad_utils.sparsify_v3 (threshold 1.15, no
# output normalization) on cuda:1.
sparsified_block_v3 = sparsify_v3(
    unnormed_grads,
    "cuda:1",
    to_norm_output = False,
    threshold = 1.15,
)
# Min-max rescale each sparsified row independently into [0, 1].
# (Alternative tried earlier: divide each row by its maximum.)
# NOTE(review): if sparsify_v3 zeroes entries, this shift maps those zeros to
# a nonzero value whenever the row minimum is not 0 — confirm that is intended.
normed_sparsed = torch.stack([
    (row - row.min()) / (row.max() - row.min())
    for row in sparsified_block_v3
])
asdin, asdout = calculate_inner_products(
    normed_sparsed,
    GRAD_DIM,
    weights = blocklist,
    metric = "block",
    to_norm = False,
    device = "cuda:1",
)
gap4, v3, v4 = calculate_gap(asdin, asdout)

In [None]:
# Per-input gaps from grad_utils.get_gap_for_each_input_v2, computed once for
# the min-max-normalized full gradients and once for the sparsified ones.
# NOTE(review): presumably returns one gap value per input as a CPU tensor
# (the next cell calls .numpy() on it) — confirm in grad_utils.
inputgaps_v2 = get_gap_for_each_input_v2(in_full_maxnorm, out_full_maxnorm)
sparsed_gaps_v2 = get_gap_for_each_input_v2(asdin, asdout)

In [None]:
# Cast the per-input gaps to 32-bit integers and convert them to NumPy arrays
# for plotting. The names are overwritten in place, so on a second run of this
# cell they already hold ndarrays; torch.as_tensor accepts both tensors and
# ndarrays, which keeps the cell re-runnable instead of raising
# AttributeError on ndarray.type.
inputgaps_v2 = torch.as_tensor(inputgaps_v2).to(torch.int).numpy()
sparsed_gaps_v2 = torch.as_tensor(sparsed_gaps_v2).to(torch.int).numpy()

In [None]:
# Overlay the absolute per-input gaps of both variants as step histograms on a
# single axis.
fig, axs = plt.subplots(nrows = 1, ncols = 1, figsize = (12, 6), dpi = 120,
                        sharex = True, sharey = True)
axs.hist(np.abs(inputgaps_v2), bins = 200, histtype = 'step',
         label = 'block-diagonal gap')
axs.hist(np.abs(sparsed_gaps_v2), bins = 200, histtype = 'step',
         label = 'elementwise sparse gap')

# Remove the ticks from both axes.
axs.set_xticks([])
axs.set_yticks([])
axs.set_title("Small CNN on CIFAR-10", y = 1.0, color = 'black', pad = -20, fontsize = 16)
axs.legend(labelcolor = 'black', fontsize = 12, loc = 'upper right')
# Save the figure to the working directory before displaying it.
plt.savefig("ct_cifar10_final.svg", dpi = 300)
plt.show()