In [1]:
# Render matplotlib figures inline, directly below the cell that draws them.
%matplotlib inline

In [2]:
from matplotlib.pylab import *
# style.use("style_sheet.mplstyle")

# Notebook-wide plotting defaults, set in one place:
#   * every new figure is 6 x 4 inches,
#   * saved figures are cropped to their content ("tight" bounding box),
#   * the base font size is 18 pt.
plt.rcParams.update({
    'figure.figsize': (6, 4),
    'savefig.bbox': 'tight',
    'font.size': 18,
})

In [3]:
import torch
import numpy as np
from models import LeNet, ResNet18
import glob


In [None]:
# LeNet: per-filter norms of the trained checkpoint, before and after the
# project's `model.norm()` normalization.
LENET_CKPT = "checkpoints/mnist/lenet/run_ms_0/run0/best_model.pth.tar"


def _lenet_filter_norms(net):
    """Collect one norm per output unit/filter of every parametrised child.

    Children without a ``weight`` (Dropout, Dropout2d, activations, pooling,
    ...) are skipped. For each output row ``i`` of a layer's weight the value
    is ``||W[i]||_2 + |b[i]|``; a layer without a bias contributes ``|b| = 0``.

    Args:
        net: module whose direct children are inspected (``net.children()``).

    Returns:
        ``[layer_ids, norms]`` -- two parallel lists; ``layer_ids[j]`` is the
        running index (0, 1, ...) of the parametrised layer that produced
        ``norms[j]``.
    """
    layer_ids, norms = [], []
    layer_idx = 0
    for child in net.children():
        weight = getattr(child, "weight", None)
        if weight is None:
            continue
        n_out = weight.shape[0]
        row_norms = weight.detach().reshape(n_out, -1).norm(dim=1)
        bias = getattr(child, "bias", None)
        if bias is not None:
            row_norms = row_norms + bias.detach().abs()
        layer_ids += [layer_idx] * n_out
        norms += list(row_norms.cpu().numpy())
        layer_idx += 1
    return [layer_ids, norms]


fig, ax = plt.subplots()

model = LeNet(10)
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load(LENET_CKPT, map_location='cpu')['model'])
# Measure the trained weights first: `model.norm()` below rewrites them, so a
# single checkpoint load serves both measurements.
norms = _lenet_filter_norms(model)
model.norm()
re_norms = _lenet_filter_norms(model)

ax.scatter(norms[0], norms[1], label="original model")
ax.scatter(re_norms[0], re_norms[1], label="normalized model")
ax.set_xlabel("Layer number")
ax.set_ylabel("Layer norm")
ax.set_title("Norm of filters at each layer")
ax.legend(fancybox=True, shadow=True, frameon=True, loc=0, handletextpad=0.1)
fig.savefig("LeNet.png")

In [None]:
from utils import get_loader, class_model_run
import argparse
import torch 
from tqdm.notebook import tqdm 

def get_args(*args):
    """Build the option namespace consumed by ``get_loader`` (MNIST defaults).

    Positional arguments are forwarded unchanged to
    ``ArgumentParser.parse_args``, so ``get_args([])`` and
    ``get_args(['--bs', '128'])`` behave as before. When called with *no*
    arguments an empty list is parsed instead of falling back to
    ``sys.argv``. Inside a notebook kernel ``sys.argv`` may hold the kernel
    launcher's own flags, which this parser would reject.

    Returns:
        argparse.Namespace with ``dir`` (default ``'.'``), ``dtype`` (default
        ``"mnist"``), ``bs`` (default 64), plus the derived
        ``data_dir = f"{dir}/data/{dtype}"`` and
        ``use_cuda = torch.cuda.is_available()``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--dir', type=str, default='.')
    parser.add_argument('--dtype', type=str, default="mnist", help='Data type')
    parser.add_argument('--bs', type=int, default=64, help='batch size')

    # Never implicitly read sys.argv: it belongs to the kernel, not to us.
    parse_call_args = args if args else ([],)
    opts = parser.parse_args(*parse_call_args)

    opts.data_dir = f"{opts.dir}/data/{opts.dtype}"
    opts.use_cuda = torch.cuda.is_available()

    return opts

# Mean training-set cross-entropy of the LeNet checkpoint after `model.norm()`.
# NOTE(review): `training=True` may enable shuffling/augmentation inside
# get_loader -- confirm that is wanted for a loss evaluation.
args = get_args([])
dset_loaders = get_loader(args, training=True)
device = torch.device('cuda' if args.use_cuda else 'cpu')
criterion = torch.nn.CrossEntropyLoss()  # default reduction: mean over batch

model = LeNet(10)
model.load_state_dict(torch.load("checkpoints/mnist/lenet/run_ms_0/run0/best_model.pth.tar", map_location='cpu')['model'])
model.norm()
model.to(device)
model.eval()

# Weight each batch mean by its size so a short final batch is not
# over-counted (a plain mean of batch means is biased when sizes differ).
loss_sum, n_samples = 0.0, 0
with torch.no_grad():  # evaluation only: no autograd graph needed
    for inputs, targets in tqdm(dset_loaders['train']):
        inputs, targets = inputs.to(device), targets.to(device)
        batch_loss = criterion(model(inputs), targets)
        loss_sum += batch_loss.item() * targets.size(0)
        n_samples += targets.size(0)

mean_train_loss = loss_sum / n_samples if n_samples else float('nan')
print(mean_train_loss)

In [None]:
# ResNet18: per-channel norms of the trained checkpoint, before and after the
# project's `model.norm()` normalization, grouped by owning module.
RESNET_CKPT = "checkpoints/cifar10/resnet/run0/best_model.pth.tar"


def _resnet_filter_norms(net):
    """Collect one norm per output channel, grouped by parameter-owning module.

    Parameters from ``net.named_parameters()`` are grouped by their dotted
    module path (the name minus its last component), in registration order;
    the group's position is its x-coordinate. Per group:

    * one tensor (e.g. a bias-free conv): L2 norm of each output filter;
    * two tensors, module not named ``linear``: ``|w[c]| + |b[c]|`` per
      channel (presumably BatchNorm affine params -- TODO confirm in models.py);
    * ``linear``: skipped, but it still consumes a group index;
    * any other size: reported via ``print`` and skipped.

    Returns:
        ``[group_ids, norms]`` -- two parallel lists.
    """
    groups = {}
    for name, param in net.named_parameters():
        # Keep the dots: joining with '' would turn "layer1.0" into "layer10"
        # and could merge two distinct modules into one group.
        owner = '.'.join(name.split('.')[:-1])
        groups.setdefault(owner, []).append(param)

    group_ids, norms = [], []
    for idx, (owner, tensors) in enumerate(groups.items()):
        if len(tensors) == 1:
            weight = tensors[0].detach()
            n_out = weight.shape[0]
            group_ids += [idx] * n_out
            norms += list(weight.reshape(n_out, -1).norm(dim=1).numpy())
        elif len(tensors) == 2:
            if owner == 'linear':
                continue
            weight, bias = (t.detach() for t in tensors)
            group_ids += [idx] * weight.shape[0]
            norms += list((weight.abs() + bias.abs()).numpy())
        else:
            print(f"Problem in {owner}")
    return [group_ids, norms]


fig, ax = plt.subplots()

model = ResNet18()
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load(RESNET_CKPT, map_location='cpu')['model'])
# Measure the trained weights first: `model.norm()` below rewrites them, so a
# single checkpoint load serves both measurements.
norms = _resnet_filter_norms(model)
model.norm()
re_norms = _resnet_filter_norms(model)

ax.scatter(norms[0], norms[1], marker='s', label="original model")
ax.scatter(re_norms[0], re_norms[1], marker='*', label="normalized model")
ax.set_xlabel("Layer number")
ax.set_ylabel("Layer norm")
ax.set_title("Norm of filters at each layer")
ax.legend(fancybox=True, shadow=True, frameon=True, loc=0, handletextpad=0.1)
fig.savefig("resnet.png")

In [4]:
from utils import get_loader, class_model_run
import argparse
import torch 
from tqdm.notebook import tqdm 

def get_args(*args):
    """Build the option namespace consumed by ``get_loader`` (CIFAR-10 defaults).

    Positional arguments are forwarded unchanged to
    ``ArgumentParser.parse_args``, so ``get_args([])`` and
    ``get_args(['--bs', '128'])`` behave as before. When called with *no*
    arguments an empty list is parsed instead of falling back to
    ``sys.argv``. Inside a notebook kernel ``sys.argv`` may hold the kernel
    launcher's own flags, which this parser would reject.

    Returns:
        argparse.Namespace with ``dir`` (default ``'.'``), ``dtype`` (default
        ``"cifar10"``), ``bs`` (default 64), plus the derived
        ``data_dir = f"{dir}/data/{dtype}"`` and
        ``use_cuda = torch.cuda.is_available()``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--dir', type=str, default='.')
    parser.add_argument('--dtype', type=str, default="cifar10", help='Data type')
    parser.add_argument('--bs', type=int, default=64, help='batch size')

    # Never implicitly read sys.argv: it belongs to the kernel, not to us.
    parse_call_args = args if args else ([],)
    opts = parser.parse_args(*parse_call_args)

    opts.data_dir = f"{opts.dir}/data/{opts.dtype}"
    opts.use_cuda = torch.cuda.is_available()

    return opts

# Mean training-set cross-entropy of the ResNet18 checkpoint.
APPLY_NORM = False  # set True to evaluate the model after `model.norm()`
# NOTE(review): `training=True` may enable shuffling/augmentation inside
# get_loader -- confirm that is wanted for a loss evaluation.
args = get_args([])
dset_loaders = get_loader(args, training=True)
device = torch.device('cuda' if args.use_cuda else 'cpu')
criterion = torch.nn.CrossEntropyLoss()  # default reduction: mean over batch

model = ResNet18()
model.load_state_dict(torch.load("checkpoints/cifar10/resnet/run0/best_model.pth.tar", map_location='cpu')['model'])
if APPLY_NORM:
    model.norm()
model.to(device)
model.eval()

# Weight each batch mean by its size so a short final batch is not
# over-counted (a plain mean of batch means is biased when sizes differ).
loss_sum, n_samples = 0.0, 0
with torch.no_grad():  # evaluation only: no autograd graph needed
    for inputs, targets in tqdm(dset_loaders['train']):
        inputs, targets = inputs.to(device), targets.to(device)
        batch_loss = criterion(model(inputs), targets)
        loss_sum += batch_loss.item() * targets.size(0)
        n_samples += targets.size(0)

mean_train_loss = loss_sum / n_samples if n_samples else float('nan')
print(mean_train_loss)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar10/cifar-10-python.tar.gz


0it [00:00, ?it/s]

Extracting ./data/cifar10/cifar-10-python.tar.gz to ./data/cifar10
Files already downloaded and verified


  0%|          | 0/782 [00:00<?, ?it/s]

KeyboardInterrupt: 