# Imports, global constants, functions

In [None]:
import torch
import torchvision
from models.model_conv import ConvNet
import torchvision.transforms as transforms
import math
import copy
import time
import pickle
from datetime import datetime
import os
from matplotlib import pyplot as plt
import matplotlib as mpl
import numpy as np
from grad_utils import *

In [None]:
# Data-loading settings used by the CIFAR-10 test loader.
BATCH_SIZE = 128
NUM_WORKERS = 32
PIN_MEMORY = True
NUM_EPOCHS = 100000
# Initial value only; it is recomputed from the model's parameters later on.
GRAD_DIM = 247434
# Per-run output directory, named after the launch timestamp (YYYYmmddHHMMSS).
PATH = "./generated_data/" + datetime.today().strftime("%Y%m%d%H%M%S")
# makedirs also creates the "./generated_data" parent when it is missing, and
# exist_ok keeps a re-run within the same second from raising FileExistsError.
os.makedirs(PATH, exist_ok=True)

In [None]:
# Preprocessing: convert images to tensors, then normalise each of the three
# channels with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

In [None]:
# CIFAR-10 test split (train=False), downloaded into ./data when absent.
test_set = torchvision.datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transform,
)

# Unshuffled loader over the test split, so iteration order is fixed.
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

In [None]:
def load_model_conv(device_str, model_num):
    """Build a ConvNet in eval mode on ``device_str`` and load its saved weights.

    Args:
        device_str: Torch device string (e.g. ``"cuda:0"`` or ``"cpu"``) that
            the model is moved to and the checkpoint is mapped onto.
        model_num: Currently unused; kept so existing callers keep working.
            The checkpoint path is fixed to ``./models/model_conv.pt``.

    Returns:
        The ConvNet with the checkpoint weights loaded and every parameter
        flagged as requiring gradients.
    """
    model = ConvNet().eval().to(device_str)
    # eval() does not switch autograd off; keep gradients explicitly enabled
    # because the parameters' gradients are needed downstream.
    for param in model.parameters():
        param.requires_grad_(True)
    # map_location deserialises the checkpoint straight onto the requested
    # device instead of the device it was saved from, which may be busy or
    # not present on this machine.
    state_dict = torch.load("./models/model_conv.pt", map_location=device_str)
    model.load_state_dict(state_dict)
    return model

# Gradients

In [None]:
# Load the trained ConvNet onto GPU 6.
model = load_model_conv(device_str="cuda:6", model_num=16)

In [None]:
# Number of parameter tensors in the model (not the number of scalar weights).
len(list(model.parameters()))

In [None]:
# Draw data from the test loader via get_data (helper not defined in this
# file; presumably 2000 samples out of 10000 test images — TODO confirm).
random_data = get_data(2000, test_loader, 10000, PATH)
# Stack every value of the returned mapping along the first dimension.
random_batch = torch.vstack([batch for _, batch in random_data.items()])

In [None]:
# Materialise the model's parameter tensors in iteration order.
paramlist = list(model.parameters())

In [None]:
# Forward pass on the GPU the model was loaded onto.
y = model(random_batch.to("cuda:6"))
# model.h is read after the forward pass and stored on the CPU in the run directory.
torch.save(model.h.cpu(), PATH + "/h_values.pt")
# Snapshot the parameter tensors and pickle them next to the h values.
paramlist = list(model.parameters())
with open(PATH + "/weights.pickle", "wb") as f:
    pickle.dump(paramlist, f)

In [None]:
# Gradients computed from the outputs y and the model by get_grads_per_layer
# (helper not defined in this file). PATH is passed as well, presumably so the
# helper can write results into the run directory — TODO confirm.
grads = get_grads_per_layer(y, model, PATH)

In [None]:
# Total number of scalar parameters; replaces the hard-coded constant set earlier.
GRAD_DIM = sum(layer.numel() for layer in paramlist)
GRAD_DIM

In [None]:
# Collapse the gradients with get_flattened_summed_grads (helper not defined in
# this file; judging by its name it flattens and sums them — TODO confirm the
# exact reduction and the resulting shape).
flattened_grads = get_flattened_summed_grads(grads)

In [None]:
# Independent copy of the flattened gradients; every normalised or sparsified
# variant below is derived from this copy.
unnormed_grads = flattened_grads.clone()

In [None]:
# One flattened (1-D) tensor per parameter tensor, in model order; passed as
# `weights` to calculate_inner_products below.
blocklist = [torch.flatten(layer) for layer in paramlist]

In [None]:
# All parameters concatenated into a single 1-D tensor.
flattened_params = torch.cat([torch.flatten(layer) for layer in paramlist])

In [None]:
# Reload the model.h values written to this run's output directory earlier.
h_values_file = PATH + "/h_values.pt"
h_values = torch.load(h_values_file)

In [None]:
# Min-max scale each row of h_values independently (an earlier variant divided
# each row by its maximum instead). A constant row yields NaN because its
# range is zero.
normed_h_values = torch.stack(
    [(row - row.min()) / (row.max() - row.min()) for row in h_values]
)
# Inner products on the scaled h values; note the empty metric string here,
# unlike the "block" metric used for the gradient variants.
in_h_maxnorm, out_h_maxnorm = calculate_inner_products(
    normed_h_values,
    GRAD_DIM,
    weights=blocklist,
    metric="",
    to_norm=False,
    device="cuda:7",
)
# Only the first of the three values returned by calculate_gap is kept.
gap, _, _ = calculate_gap(in_h_maxnorm, out_h_maxnorm)

In [None]:
# Min-max scale each gradient row independently. Earlier variants divided each
# row by its own maximum, or the whole tensor by its global maximum.
# A constant row yields NaN because its range is zero.
normed_grads = torch.stack(
    [(row - row.min()) / (row.max() - row.min()) for row in unnormed_grads]
)
# Inner products on the scaled gradients using the "block" metric.
in_full_maxnorm, out_full_maxnorm = calculate_inner_products(
    normed_grads,
    GRAD_DIM,
    weights=blocklist,
    metric="block",
    to_norm=False,
    device="cuda:7",
)
gap2, v1, v2 = calculate_gap(in_full_maxnorm, out_full_maxnorm)

In [None]:
# Sparsify the raw gradients with sparsify_v3 (helper not defined in this file;
# the meaning of threshold=1.15 should be confirmed against it).
sparsified_block_v3 = sparsify_v3(
    unnormed_grads,
    "cuda:1",
    to_norm_output=False,
    threshold=1.15,
)
# Min-max scale each sparsified row independently (an earlier variant divided
# each row by its maximum). A constant row yields NaN because its range is zero.
normed_sparsed = torch.stack(
    [(row - row.min()) / (row.max() - row.min()) for row in sparsified_block_v3]
)
# Inner products on the sparsified gradients, again with the "block" metric.
asdin, asdout = calculate_inner_products(
    normed_sparsed,
    GRAD_DIM,
    weights=blocklist,
    metric="block",
    to_norm=False,
    device="cuda:1",
)
gap4, v3, v4 = calculate_gap(asdin, asdout)

In [None]:
# Gaps from get_gap_for_each_input_v2 (helper not defined in this file;
# presumably one gap per input — TODO confirm), computed for the full
# normalised gradients and for the sparsified ones.
inputgaps_v2 = get_gap_for_each_input_v2(in_full_maxnorm, out_full_maxnorm)
sparsed_gaps_v2 = get_gap_for_each_input_v2(asdin, asdout)

In [None]:
# Cast both gap tensors to 32-bit integers and hand them to NumPy for plotting.
# If the gaps are floating point, the cast drops their fractional part.
inputgaps_v2 = inputgaps_v2.int().numpy()
sparsed_gaps_v2 = sparsed_gaps_v2.int().numpy()

In [None]:
# Single-axes figure overlaying the absolute gaps of both variants as outline
# histograms.
fig, axs = plt.subplots(1, 1, figsize=(12, 6), dpi=120, sharex=True, sharey=True)
axs.hist(
    np.abs(inputgaps_v2),
    bins=200,
    histtype='step',
    label='block-diagonal gap',
)
axs.hist(
    np.abs(sparsed_gaps_v2),
    bins=200,
    histtype='step',
    label='elementwise sparse gap',
)

# Hide the tick marks on both axes; only the shape of the histograms is shown.
axs.set_xticks([])
axs.set_yticks([])
# The negative pad pulls the title down inside the plotting area.
axs.set_title("Small CNN on CIFAR-10", y=1.0, color='black', pad=-20, fontsize=16)
axs.legend(labelcolor='black', fontsize=12, loc='upper right')
# The figure is written to the current working directory, not to PATH.
plt.savefig("ct_cifar10_final.svg", dpi=300)
plt.show()