In [1]:
#import all relevant packages
import torch, torchvision
import torchvision.transforms as transforms
import numpy as np
from backpack import backpack, extend
from backpack.extensions import DiagHessian
import matplotlib.pyplot as plt
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.normal import Normal
from torch.nn import functional as F

# Seed NumPy and PyTorch (CPU and CUDA) with the same value so reruns are repeatable.
# `s` is also embedded in the weight-file name and in plot titles further down.
s=9
np.random.seed(s)
torch.manual_seed(s)
torch.cuda.manual_seed(s)

# Request deterministic cuDNN kernels and switch off cuDNN's benchmark mode.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Data wrangling

In [None]:
# data wrangling: load the 'mnist' split of EMNIST (not Fashion-MNIST)
# The only preprocessing step is ToTensor(); no further normalisation is applied here.
EMNIST_transform = transforms.Compose([
    transforms.ToTensor(),
])


# Training split, stored under ~/data/emnist; download=True allows fetching the archive.
EMNIST_train = torchvision.datasets.EMNIST(
        '~/data/emnist',
        train=True,
        download=True,
        transform=EMNIST_transform,
        split = 'mnist')



# Test split from the same directory; download=False, so the files must already
# be present (e.g. from the training-split call above).
EMNIST_test = torchvision.datasets.EMNIST(
        '~/data/emnist',
        train=False,
        download=False,
        transform=EMNIST_transform,
        split = 'mnist')


0it [00:00, ?it/s]

Downloading and extracting zip archive
Downloading http://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip to /Users/moreez/data/emnist/EMNIST/raw/emnist.zip


100%|█████████▉| 561463296/561753746 [01:56<00:00, 5237881.55it/s]

Extracting /Users/moreez/data/emnist/EMNIST/raw/emnist.zip to /Users/moreez/data/emnist/EMNIST/raw


561758208it [02:10, 5237881.55it/s]                               

Processing byclass




Processing bymerge


In [None]:
# have a look at the data to verify
def imshow(img):
    """Show a channel-first image tensor with matplotlib.

    img: tensor with axes ordered (channels, height, width). It is converted
    to a NumPy array and its axes are reordered to (height, width, channels),
    which is the layout `plt.imshow` is given here.
    """
    channels_last = img.numpy().transpose(1, 2, 0)
    plt.imshow(channels_last)
    plt.show()

# Take the first ten entries of the dataset's `.data` tensor, add a channel axis
# (10, 1, 28, 28) and show them as a grid with five images per row.
# NOTE(review): `.data` is read directly, so EMNIST_transform is presumably not applied here — confirm.
images = EMNIST_train.data[:10].view(10, 1, 28, 28)
imshow(torchvision.utils.make_grid(images, nrow=5))


In [None]:
# Mini-batch loaders with batch size 128.
# Training batches are shuffled; the test loader keeps the dataset order.
emnist_train_loader = torch.utils.data.dataloader.DataLoader(
    EMNIST_train,
    batch_size=128,
    shuffle=True
)

emnist_test_loader = torch.utils.data.dataloader.DataLoader(
    EMNIST_test,
    batch_size=128,
    shuffle=False,
)


# Training routine

In [None]:
#set up the network
def NN(num_classes=10):
    """Build the small convolutional classifier used in this notebook.

    Architecture: two (5x5 conv with 32 output channels -> ReLU -> 2x2 max-pool)
    stages followed by a flatten and one linear read-out. The linear layer
    expects 4 * 4 * 32 inputs, which is what a single-channel 28x28 image is
    reduced to by the two conv/pool stages.

    num_classes: number of output logits.
    Returns a `torch.nn.Sequential`; its parameters are named by layer index
    ('0', '3' and '7' carry weights and biases).
    """
    conv_stage_one = [
        torch.nn.Conv2d(1, 32, 5),
        torch.nn.ReLU(),
        torch.nn.MaxPool2d(2, 2),
    ]
    conv_stage_two = [
        torch.nn.Conv2d(32, 32, 5),
        torch.nn.ReLU(),
        torch.nn.MaxPool2d(2, 2),
    ]
    readout = [
        torch.nn.Flatten(),
        torch.nn.Linear(4 * 4 * 32, num_classes),
    ]
    return torch.nn.Sequential(*conv_stage_one, *conv_stage_two, *readout)


In [None]:
#set up the training routine
# Model, loss and optimizer used by the training routine below.
emnist_model = NN(num_classes=10)
loss_function = torch.nn.CrossEntropyLoss()

# Adam with learning rate 1e-3 and weight decay 5e-4.
emnist_train_optimizer = torch.optim.Adam(emnist_model.parameters(), lr=1e-3, weight_decay=5e-4)
# don't use SGD, it is way worse than Adam here
# File the trained weights are written to; the seed is part of the name.
EMNIST_PATH = "EMNIST_weights_seed={}.pth".format(s)
#print(EMNIST_PATH)

In [None]:
# helper function to get accuracy
def get_accuracy(output, targets):
    """Return the fraction of samples whose arg-max prediction equals the target.

    output: score tensor with classes along dimension 1.
    targets: tensor of class indices; the predictions are reshaped to match it.
    Returns a Python float between 0 and 1 (nothing is printed).
    """
    predicted = output.argmax(dim=1, keepdim=True).view_as(targets)
    hits = (predicted == targets).float()
    return hits.mean().item()


In [None]:
# Write the training routine and save the model at EMNIST_PATH

def train(verbose=False, num_iter=5):
    """Train the global `emnist_model` and save its weights to `EMNIST_PATH`.

    Uses the notebook-level `emnist_train_loader`, `loss_function` and
    `emnist_train_optimizer`.

    verbose: if True, print loss and accuracy of every 10th mini-batch.
    num_iter: number of full passes over the training loader.
    """
    max_len = len(emnist_train_loader)
    emnist_model.train()
    # `epoch` rather than `iter`, so the builtin `iter` is not shadowed.
    for epoch in range(num_iter):
        for batch_idx, (x, y) in enumerate(emnist_train_loader):
            # Clear gradients first so nothing left over from earlier cells
            # leaks into this update.
            emnist_train_optimizer.zero_grad()

            output = emnist_model(x)
            loss = loss_function(output, y)
            loss.backward()
            emnist_train_optimizer.step()

            # Accuracy is only needed for the progress line, so compute it lazily.
            if verbose and batch_idx % 10 == 0:
                accuracy = get_accuracy(output, y)
                print(
                    "Iteration {}; {}/{} \t".format(epoch, batch_idx, max_len) +
                    "Minibatch Loss %.3f  " % loss.item() +
                    "Accuracy %.0f" % (accuracy * 100) + "%"
                )

    print("saving model at: {}".format(EMNIST_PATH))
    torch.save(emnist_model.state_dict(), EMNIST_PATH)

In [None]:
#after training it once, comment this out to save time if you rerun the entire script
#train(verbose=True, num_iter=5)


In [None]:
#predict in distribution
EMNIST_PATH = "EMNIST_weights_seed={}.pth".format(s)

# Rebuild the architecture and restore the weights written by train().
emnist_model = NN(num_classes=10)
print("loading model from: {}".format(EMNIST_PATH))
emnist_model.load_state_dict(torch.load(EMNIST_PATH))
emnist_model.eval()

acc = []          # accuracy of each test batch
batch_sizes = []  # number of samples in each test batch (weights for the average)

max_len = len(emnist_test_loader)
# Pure inference: no autograd graph is needed for the forward passes.
with torch.no_grad():
    for batch_idx, (x, y) in enumerate(emnist_test_loader):
        output = emnist_model(x)
        accuracy = get_accuracy(output, y)
        if batch_idx % 10 == 0:
            print(
                "Batch {}/{} \t".format(batch_idx, max_len) + 
                "Accuracy %.0f" % (accuracy * 100) + "%"
            )
        acc.append(accuracy)
        batch_sizes.append(len(y))

# Weight every batch by its size: the last batch can be smaller than 128, and a
# plain mean over batches would over-weight its samples.
avg_acc = np.average(acc, weights=batch_sizes)
print('overall test accuracy on EMNIST: {:.02f} %'.format(avg_acc * 100))

# Laplace approximation of the weights
* we use the BackPACK package to approximate the Hessian of the parameters. Especially look at the DiagHessian() method.
* we do one iteration over the entire training set and use the mean of the Hessian of the mini-batches as the best approximation of the Hessian.
* we add a prior variance to our Hessian. The precision is 1 over the variance. we use a prior precision of 10, 20, and 50 (or variance of 1/10, 1/20, 1/50).
    * edit: I will use precisions of 10, 60, 120, and 1000 in the following


In [None]:
def get_Hessian_NN(model, train_loader, prec0, device='cpu', verbose=True):
    """Estimate a per-parameter posterior standard deviation from the diagonal Hessian.

    For every mini-batch, BackPACK's `DiagHessian` extension provides the
    diagonal of the cross-entropy loss Hessian for each parameter
    (`param.diag_h`). The prior precision `prec0` is added, the result is
    turned into a standard deviation `sqrt(1 / (diag_h + prec0))`, and a
    running mean of these per-batch standard deviations is kept.

    model: network to analyse; it is wrapped in place with BackPACK's
        `extend`. It is not moved by this function, so it should already
        live on `device`.
    train_loader: iterable of `(x, y)` mini-batches.
    prec0: prior precision added to every diagonal Hessian entry.
    device: device of the accumulators; for non-CPU devices each batch is
        cast (`x` to float, `y` to long) and moved there.
    verbose: if True, print parameter sizes and per-batch progress.

    Returns a list with one tensor per entry of `model.parameters()`, in
    that order. Despite the name `Cov_diag`, the entries are standard
    deviations, not variances.
    """
    lossfunc = torch.nn.CrossEntropyLoss()

    # BackPACK needs both the loss and the model to be extended.
    extend(lossfunc, debug=False)
    extend(model, debug=False)

    Cov_diag = []
    for param in model.parameters():
        ps = param.size()
        if verbose:
            print("parameter size: ", ps)
        Cov_diag.append(torch.zeros(ps, device=device))

    max_len = len(train_loader)
    on_accelerator = torch.device(device).type != 'cpu'

    with backpack(DiagHessian()):

        for batch_idx, (x, y) in enumerate(train_loader):

            if on_accelerator:
                x, y = x.float().to(device), y.long().to(device)

            model.zero_grad()
            lossfunc(model(x), y).backward()

            with torch.no_grad():
                # Running-mean weight: 0 on the first batch, so the zero
                # initialisation of the accumulators never contributes.
                rho = 1 - 1 / (batch_idx + 1)

                for idx, param in enumerate(model.parameters()):
                    # Add the prior out of place: the scalar broadcasts on
                    # whatever device diag_h lives on, and BackPACK's own
                    # tensor is left unmodified.
                    H_ = param.diag_h + prec0
                    H_inv = torch.sqrt(1 / H_)  # <-- standard deviation
                    # NOTE(review): this averages per-batch standard deviations,
                    # not the Hessian itself before inverting — confirm this is intended.
                    Cov_diag[idx] = rho * Cov_diag[idx] + (1 - rho) * H_inv

            if verbose:
                print("Batch: {}/{}".format(batch_idx, max_len))

    return Cov_diag

In [None]:
#EMNIST_NN_Std_prec_00001 = get_Hessian_NN(model=emnist_model, train_loader=emnist_train_loader, prec0=0.0001,verbose=False)
#torch.save(EMNIST_NN_Std_prec_00001, 'Hessian_prec00001_EMNIST.pth')

In [None]:
# Load the cached result for prec0=0.0001. The file name matches the one written
# by this notebook's (commented-out) torch.save call, not a KMNIST file.
EMNIST_NN_Std_prec_00001 = torch.load('Hessian_prec00001_EMNIST.pth')

## Now we want to look at the single layers of our network, and how they behave w.r.t. the variance
* every tensor corresponds to one of the six parameter tensors of our network (the weight and the bias of each of its three layers)


## visualize the first layer of our networks in a heatmap
* to do so, we bring the tensor into the right shape by concatenating all of the arrays it contains and then reshaping the result


In [None]:
import seaborn as sns
def visualize(tensor):
    """Plot the first entry of `tensor` as a single seaborn heatmap.

    tensor: sequence whose first element is itself a sequence of equally
        shaped 3-D arrays whose last axis has length 5 — presumably the
        first-layer result of get_Hessian_NN; verify against the caller.

    The arrays are concatenated along their first axis, the last axis is
    moved to the front and the result is flattened to 5 rows.
    """
    kernels = np.concatenate([np.asarray(kernel) for kernel in tensor[0]])
    grid = kernels.transpose(2, 0, 1).reshape(5, -1)
    # Label every 20th column. The previous `plt.xticks = ...` assignment set no
    # ticks and replaced matplotlib's xticks function for the rest of the session.
    sns.heatmap(grid, xticklabels=20)
    plt.show()
   

In [None]:
#visualize(MNIST_NN_Hessian_diag_10)

In [None]:
def meancalc(Hessian_diag_x, model=None):
    """Print the mean of each per-parameter tensor next to the parameter's name.

    Hessian_diag_x: sequence of tensors, indexed in the same order as
        `model.named_parameters()`.
    model: network that supplies the parameter names; defaults to the
        notebook-level `emnist_model`.

    The values are printed with the label "mean variance"; whether they are
    variances or standard deviations depends on how `Hessian_diag_x` was built.
    """
    if model is None:
        model = emnist_model
    for idx, (name, _) in enumerate(model.named_parameters()):
        mean = torch.mean(Hessian_diag_x[idx])
        print("mean variance of layer {0:s}: {1:.4f}".format(name, mean.item()))
    


In [None]:
# Print the per-parameter mean of the loaded tensors.
# NOTE(review): the printed label says "variance", while the variable name suggests standard deviations — confirm.
meancalc(Hessian_diag_x=EMNIST_NN_Std_prec_00001)

In [None]:
# Heatmap of entry 4 of the loaded list — presumably the weight of the final linear
# layer ('7.weight') if the list follows model.parameters() order; verify.
# The x-axis is stretched over [0, 512] and the y-axis over [0, 1].
plt.imshow(EMNIST_NN_Std_prec_00001[4], cmap='gist_stern',extent=[0,512,0,1],  aspect='auto')
ax = plt.gca()
# 16 ticks, one every 32 columns, labelled 1, 3, ..., 31.
ax.set_xticks(np.arange(1, 512, 32));
ax.set_xticklabels(np.arange(1, 32, 2));
ax.set_title('seed {}'.format(s))
plt.colorbar()
plt.tight_layout()
plt.show()
#plt.savefig('linear_seed1000={}'.format(s))

# NOTE(review): the whole 2-D entry is passed to hist; matplotlib treats each column
# of a 2-D array as a separate dataset — confirm a flattened histogram was not intended.
plt.hist(EMNIST_NN_Std_prec_00001[4])

In [None]:
"""
mnist_number = 9
linear_layer_index = 4 #linear layer has index 4
linear_layer = MNIST_NN_Hessian_diag_120[linear_layer_index][mnist_number]

#reshape the flattened array to 32* (4x4)
layer_split = np.array_split(np.array(linear_layer), 32)
for i in range(len(layer_split)-1):
    layer_split[i] = np.reshape(layer_split[i], (4, -1))

#plot setup
fig, axs = plt.subplots(4,8, figsize=(20, 15))
fig.subplots_adjust(hspace = .001, wspace=.001)
axs = axs.ravel()

#iterate through the features and plot them
for i in range(len(layer_split)):
    layer_split[i] = np.reshape(layer_split[i], (4, -1))
    axs[i].imshow(layer_split[i])
    axs[i].set_title('feature '+str(i+1))
"""

In [None]:
# Pick the parameter named '7.weight' (the weight of the final Linear layer in NN).
# If no parameter has that name, `a` is not assigned by this loop.
for name, parameters in emnist_model.named_parameters():
    if name == '7.weight':
        a = parameters

# NumPy copy-free view of the weights, detached from autograd; `a` and `b` are
# reused by the following cell.
b = a.detach().numpy()

# Same heatmap layout as for the loaded tensors above: x-axis stretched over [0, 512],
# 16 ticks every 32 columns labelled 1, 3, ..., 31.
plt.imshow(b, cmap='prism',extent=[0,512,0,1],  aspect='auto')
ax = plt.gca()
ax.set_xticks(np.arange(1, 512, 32));
ax.set_xticklabels(np.arange(1, 32, 2));
ax.set_title('weight {}'.format(s))
plt.colorbar()
plt.tight_layout()
# `im` only stores the return value of plt.show().
im = plt.show()
# NOTE(review): `b` is 2-D, so hist treats each column as its own dataset — confirm this is intended.
plt.hist(b)

In [None]:
# For each of the first 10 rows of `a`, split the row into 32 equal chunks and
# keep the chunk belonging to `target_feature` (1-based), then histogram the
# ten chunks side by side.
# Presumably each chunk holds the 4*4 inputs coming from one of the 32 feature maps — verify.
to_hist = []
target_feature = 9
for i in range(10):
    weight = a[i].detach().numpy()
    # after this line `weight` is a list of 32 arrays, no longer a single array
    weight = np.array_split(np.array(weight), 32)
    to_hist.append(weight[target_feature -1])
plt.hist(to_hist)

In [None]:
"""
observe = [1, 8, 13, 14, 17, 24]
f,c = plt.subplots(3,2)
plt.figure(figsize=(20,20))
c = c.ravel()
for i in range(10):
    test = np.array_split(a[i].detach().numpy(), 32)
for idx, ax in enumerate(c):
    ax.set_title(str(observe[idx] +1))
    ax.hist(test[observe[idx]])
plt.tight_layout()
        
    
#for idx,ax in enumerate(a):
 #   ax.hist(test[observe[idx]])
 """