In [1]:
import pickle as pkl
import numpy as np
import torch
import torch.nn as nn
from torch.nn import Softmax
from torch.utils.data import DataLoader
from torch.nn.functional import softmax

In [2]:
with open("mnist.pkl",'rb') as f:
    train_set, valid_set, test_set = pkl.load(f,encoding='bytes')
    f.close()

In [3]:
def get_data(number):   #filters through the training/test data to find ones with specified label
    new_train = []
    new_test = []
    for i,label in enumerate(train_set[1]): #find data with label "number" in training set
        if label == number:
            new_train.append(train_set[0][i])
    for i,label in enumerate(valid_set[1]): #find data with label "number" in validation set and add to training set (since no need for validation in VCL)
        if label == number:
            new_train.append(train_set[0][i])
    for i,label in enumerate(test_set[1]):#find data with label "number" in test set for testing
        if label == number:
            new_test.append(train_set[0][i])
    return new_train,new_test

def generate_task_data(tasknumber):
    data_0 = get_data(tasknumber)
    data_1 = get_data(tasknumber + 1)  
             
    trainx = np.vstack((data_0[0],data_1[0]))
    zeros =  np.zeros((len(data_1[0]), 1))
    ones =  np.ones((len(data_0[0]), 1))
    trainy = np.vstack((ones,zeros))# create labels for classes
    trainy = np.hstack((trainy,1-trainy))# format for cross entropy loss (-1*likelihood) downstream

    testx = np.vstack((data_0[1],data_1[1]))
    zeros =  np.zeros((len(data_1[1]), 1))
    ones =  np.ones((len(data_0[1]), 1))
    testy = np.vstack((ones,zeros))
    testy = np.hstack((testy,1-testy))
    return trainx,trainy,testx,testy


In [4]:
#First train a regular NN to find appropriate initialisation values for bayesian network
class MLP(nn.Module):
  def __init__(self, n_inputs: int):
    super(MLP, self).__init__()
    self.layers = nn.Sequential(
        nn.Linear(n_inputs,256),
        nn.ReLU(),
        nn.Linear(256,256),
        #nn.Softmax(),
        nn.ReLU(),
        nn.Linear(256,2),
        
        #nn.LogSoftmax()

    )
  def forward(self, x):
    return self.layers(x)

def ce_loss(logits,targets):
  logprobs = torch.nn.functional.log_softmax(logits)
  return -(targets*logprobs).sum()
# ###checking###
def train_model(x, y, model, lr: float=0.005, n_epochs: int=100):
  bs = 500
  optimiser = torch.optim.Adam(model.parameters(), lr=lr)#,weight_decay=5)
  dloader = DataLoader(list(zip(x, y)), shuffle=True, batch_size=bs)
  for epoch in range(n_epochs):
    train_loss = 0
    for batch_x, batch_y in dloader:
      optimiser.zero_grad()
      output = model(batch_x)
      loss_val = torch.binary_cross_entropy_with_logits(output,batch_y).sum()
      loss_val.backward()
      optimiser.step()

      train_loss += loss_val.item() * len(batch_x)

    train_loss /= len(x)

    if epoch % 2  == 0:
      print(f'Epoch: {epoch+1}/{n_epochs}, loss: {train_loss:.4f}')

  return model
data = generate_task_data(0)
model = MLP(n_inputs=784)
initial_model = train_model(torch.tensor(data[0],dtype=torch.float32),torch.tensor(data[1],dtype=torch.float32),model)
print(initial_model)
ans = torch.tensor(data[3].argmax(axis=1))
succ = 0
preds = torch.softmax(initial_model(torch.tensor(data[2])),dim=1).argmax(dim=1)


for i in range(len(preds)):
    if preds[i] == ans[i]:
        succ +=1
print(succ)
print("/")
print(len(ans))


Epoch: 1/100, loss: 274.4540
Epoch: 3/100, loss: 188.4021
Epoch: 5/100, loss: 181.4164
Epoch: 7/100, loss: 170.6639
Epoch: 9/100, loss: 158.9686
Epoch: 11/100, loss: 146.1150
Epoch: 13/100, loss: 134.3355
Epoch: 15/100, loss: 121.9030
Epoch: 17/100, loss: 108.8356
Epoch: 19/100, loss: 99.3294
Epoch: 21/100, loss: 96.7782
Epoch: 23/100, loss: 87.8035
Epoch: 25/100, loss: 81.4479
Epoch: 27/100, loss: 75.4425
Epoch: 29/100, loss: 68.6143
Epoch: 31/100, loss: 69.5818
Epoch: 33/100, loss: 68.5809
Epoch: 35/100, loss: 65.3790
Epoch: 37/100, loss: 71.1018
Epoch: 39/100, loss: 59.3254
Epoch: 41/100, loss: 53.1488
Epoch: 43/100, loss: 53.5741
Epoch: 45/100, loss: 55.4159
Epoch: 47/100, loss: 53.6616
Epoch: 49/100, loss: 54.4257
Epoch: 51/100, loss: 49.5054
Epoch: 53/100, loss: 56.0544
Epoch: 55/100, loss: 50.9661
Epoch: 57/100, loss: 45.0319
Epoch: 59/100, loss: 42.3179
Epoch: 61/100, loss: 43.3004
Epoch: 63/100, loss: 43.6162
Epoch: 65/100, loss: 51.3333
Epoch: 67/100, loss: 53.9323
Epoch: 69/

In [5]:
# creates first prior:
task_weights_means = []   #To hold task (n) specific head means in position (n)
task_weights_variances = []  #for each entry : value i,j is the variance that would be the diagonal entry on the cov matrix for the associated mean field approximation
task_bias_means = []
task_bias_variances = []

shared_weights_mean = []
shared_weights_variances = []
shared_bias_mean = []
shared_bias_variances = []

#first shared layer weights and biases
shared_weights_mean.append(initial_model.layers[0].weight.clone().detach())#Extract shared means from regular neural network for initialisation of shared hidden layers
shared_bias_mean.append(initial_model.layers[0].bias.clone().detach())  #Also use .clone.detach to instantiate it as tensor for autograd downstream and remove from previous computation graph
shared_weights_variances.append(torch.full(shared_weights_mean[0].shape,-6,requires_grad=False,dtype=torch.float32))  #initialise variances as 10^-6 as per paper
shared_bias_variances.append(torch.full(shared_bias_mean[0].shape,-6,requires_grad=False,dtype=torch.float32))        #we are storing variances as log variances to prevent negative variance

#Second layer shared weights and biases
shared_weights_mean.append(initial_model.layers[2].weight.clone().detach())
shared_bias_mean.append(initial_model.layers[2].bias.clone().detach())
shared_weights_variances.append(torch.full(shared_weights_mean[1].shape,-6,requires_grad=False,dtype=torch.float32))
shared_bias_variances.append(torch.full(shared_bias_mean[1].shape,-6,requires_grad=False,dtype=torch.float32))

#initialise first task head weights/variances
task_weights_means.append((initial_model.layers[4].weight).clone().detach()) #Extract task1 specific means for initialisation of first head task
task_bias_means.append((initial_model.layers[4].bias).clone().detach())

#task1_initial_weight_variance = torch.full_like(initial_model.layers[4].weight,-6,requires_grad=False,dtype=torch.float32)#initialise variances as 10^-6 as per paper
#task1_initial_bias_variance = torch.full_like(initial_model.layers[4].bias,-6,requires_grad=False,dtype=torch.float32)

task_weights_variances.append(torch.full_like(initial_model.layers[4].weight,-6,requires_grad=False,dtype=torch.float32))# store task specific variances 
task_bias_variances.append(torch.full_like(initial_model.layers[4].bias,-6,requires_grad=False,dtype=torch.float32))

In [13]:
l1weightmeans,l1weightvariances,l1biasmeans,l1biasvariances,l2weightmeans,l2weightvariances,l2biasmeans,l2biasvariances = [],[],[],[],[],[],[],[] #bayesian model params
opt_params = []
for i in range(len(shared_weights_mean)):
    l1weightmeans.append(shared_weights_mean[i].detach().clone())   #copy weights and biases from mlp to be initialisation for means in bayesian model
    l1weightvariances.append(shared_weights_variances[i].detach().clone())
    l1biasmeans.append(shared_bias_mean[i].detach().clone())
    l1biasvariances.append(shared_bias_variances[i].detach().clone())
    l1weightmeans[i].requires_grad = True   #Set requires grad to True on bayesian model params to be updated in autodiff
    l1weightvariances[i].requires_grad = True
    l1biasmeans[i].requires_grad = True
    l1biasvariances[i].requires_grad = True
    opt_params.append(l1weightmeans[i])
    opt_params.append(l1weightvariances[i])
    opt_params.append(l1biasmeans[i])
    opt_params.append(l1biasvariances[i])
    

for i in range(len(task_weights_means)):   #Ensure models parameters will be updated through autograd for every task specific parameter
    l2weightmeans.append(task_weights_means[i].detach().clone())
    l2weightvariances.append(task_weights_variances[i].detach().clone())
    l2biasmeans.append(task_bias_means[i].detach().clone())
    l2biasvariances.append(task_bias_variances[i].detach().clone())
    l2weightmeans[i].requires_grad = True
    l2weightvariances[i].requires_grad = True
    l2biasmeans[i].requires_grad = True
    l2biasvariances[i].requires_grad = True
    opt_params.append(l2weightmeans[i])
    opt_params.append(l2weightvariances[i])
    opt_params.append(l2biasmeans[i])
    opt_params.append(l2biasvariances[i])

In [15]:
#def KL_divergence():#l1weightmeans,l1weightvariances,l1biasmeans,l1biasvariances,l2weightmeans,l2weightvariances,l2biasmeans,l2biasvariances,shared_weights_mean,shared_weights_variances,shared_bias_mean,shared_bias_variances,task_weights_means,task_weights_variances,task_bias_means,task_bias_variances):
def KL_divergence(l1weightmeans,l1weightvariances,l1biasmeans,l1biasvariances,l2weightmeans,l2weightvariances,l2biasmeans,l2biasvariances,shared_weights_mean,shared_weights_variances,shared_bias_mean,shared_bias_variances,task_weights_means,task_weights_variances,task_bias_means,task_bias_variances):
    KL = 0 #ignore constant term as it does not contribute to optimisation process
    mu_term = 0
    variance_term = 0
    log_var_term = 0

    for i in range(len(shared_weights_mean)):
        mu_term += (1/2) * torch.sum(torch.mul(torch.square(shared_weights_mean[i]-l1weightmeans[i]),(torch.pow(torch.exp(shared_weights_variances[i]),-2))))  #mu term for first layer weights
        variance_term += (1/2) * torch.sum(torch.mul(torch.pow(torch.exp(shared_weights_variances[i]),-2),torch.square(torch.exp(l1weightvariances[i]))))  
        log_var_term += torch.sum(torch.subtract(shared_weights_variances[i],l1weightvariances[i])) # only need to subtract and ignore prefactor since we are dealing with log variances to start with
        #print(mu_term,variance_term,log_var_term)                                                  # since 0.5 * log(a^2/b^2) =(log a - log b )
        mu_term += (1/2) * torch.sum(torch.mul(torch.square(shared_bias_mean[i]-l1biasmeans[i]),(torch.pow(torch.exp(shared_bias_variances[i]),-2)))) # add contributions from first layer biases 
        variance_term += (1/2) * torch.sum(torch.mul(torch.square(torch.exp(l1biasvariances[i])),torch.pow(torch.exp(shared_bias_variances[i]),-2)))
        log_var_term += torch.sum(torch.subtract(shared_bias_variances[i],l1biasvariances[i])) 

    for task in range(len(l2weightmeans)): # contribution for each task weights 
        mu_term += 1/2 * torch.sum(torch.mul(torch.square(task_weights_means[task] - l2weightmeans[task]),(torch.pow(torch.exp(task_weights_variances[task]),-2))))
        variance_term += 1/2 * torch.sum(torch.mul(torch.pow(torch.exp(task_weights_variances[task]),-2),torch.square(torch.exp(l2weightvariances[task]))))
        log_var_term += torch.sum(torch.subtract(task_weights_variances[task],l2weightvariances[task]))     #torch.sum(torch.log(torch.square(torch.div(task_weights_variances[task][i],l2weightvariances[task][i]))))

      #contribution from task biases
    for task in range(len(l2biasmeans)):#for each task  
        #for i in range(len(task_bias_means[task])):     
        mu_term += 1/2 * torch.sum(torch.mul(torch.square(task_bias_means[task]-l2biasmeans[task]),(torch.pow(torch.exp(task_bias_variances[task]),-2))))
        variance_term += 1/2 * torch.sum(torch.mul(torch.pow(torch.exp(task_bias_variances[task]),-2),torch.square(torch.exp(l2biasvariances[task]))))
        log_var_term += torch.sum(torch.subtract(task_bias_variances[task],l2biasvariances[task]))   # torch.sum(torch.log(torch.square(torch.div(task_bias_variances[task],l2biasvariances[task]))))
    KL = mu_term + variance_term + log_var_term
    print(f"calculated KL divergence {KL - 133633}") #include constant term for completeness
    return (KL - 133633)
def cross_entropy(inputs,targets):
    result = torch.mul(torch.log(inputs),targets)
    return -1*torch.sum(result) #since cross entropy is negative log likelihood for binary classification
                                      

In [10]:
class bayesian_network():#l1weightmeans,l1weightvariances,l1biasmeans,l1biasvariances,l2weightmeans,l2weightvariances,l2biasmeans,l2biasvariances):

    #def forward(self,input,iterations,task,l1weightmeans,l1weightvariances,l1biasmeans,l1biasvariances,l2weightmeans,l2weightvariances,l2biasmeans,l2biasvariances):
    def forward(self,input,iterations,task,*args):
                                                  #number of samples for MC integration of likelihood is iterations
      print(f"task is {task}")
      l2wm,l2wv,l2bm,l2bv = l2weightmeans[task].tile([iterations,1,1]),l2weightvariances[task].tile([iterations,1,1]),l2biasmeans[task].tile([iterations,1]),l2biasvariances[task].tile([iterations,1])
      l2bv,l2bm = l2bv.unsqueeze(-1).expand(-1,-1,len(input.transpose(0,1))),l2bm.unsqueeze(-1).expand(-1,-1,len(input.transpose(0,1))) #tile to compute all normal samples simultaneousluy
      for i in range(len(l1weightmeans)):
        l1wm,l1wv,l1bm,l1bv = l1weightmeans[i].tile([iterations,1,1]),l1weightvariances[i].tile([iterations,1,1]),l1biasmeans[i].tile(([iterations,1])),l1biasvariances[i].tile([iterations,1])
        l1bv,l1bm = l1bv.unsqueeze(-1).expand(-1,-1,len(input.transpose(0,1))),l1bm.unsqueeze(-1).expand(-1,-1,len(input.transpose(0,1)))
        weightnormal = torch.normal(torch.zeros_like(l1wm),torch.ones_like(l1wv))#sample N(0,1) in order to realise weights/biases for layer 1
        biasnormal = torch.normal(torch.zeros_like(l1bm),torch.ones_like(l1bv))
        if i == 0:
          result = torch.add(l1wm,torch.mul(weightnormal,torch.exp(l1wv)))
          result = torch.add(torch.matmul(result,input),torch.add(l1bm,torch.mul(torch.exp(l1bv),biasnormal)))
        if i == 1 :
          result = torch.matmul(torch.add(l1wm,torch.mul(l1wv,weightnormal)),result)
        result = torch.relu(result)
        
      ###forward pass for task specific head 
      weightnormal = torch.normal(torch.zeros_like(l2wm),torch.ones_like(l2wv))#sample N(0,1) in order to realise weights/biases for layer 2
      biasnormal = torch.normal(torch.zeros_like(l2bm),torch.ones_like(l2bv))
      result = torch.matmul(torch.add(l2wm,torch.mul(torch.exp(l2wv),weightnormal)),result)
      result = torch.squeeze(result,-1)
      result = torch.add(result,torch.add(l2bm,torch.mul(torch.exp(l2bv),biasnormal)))
      result = torch.mean(result,dim=0)

      return result.transpose(0,1)
  
    def loss(self,prediction,label):#,label,l1weightmeans,l1weightvariances,l1biasmeans,l1biasvariances,l2weightmeans,l2weightvariances,l2biasmeans,l2biasvariances,shared_weights_mean,shared_weights_variances,shared_bias_mean,shared_bias_variances,task_weights_means,task_weights_variances,task_bias_means,task_bias_variances):   # we will minimise loss so we want a loss that is KL - log_lik = KL + cross entropy
      loss = torch.tensor(0,dtype=torch.float32)#,requires_grad=True)
      loss_fn = torch.nn.BCEWithLogitsLoss()
      loss = loss + loss_fn(prediction,label)
      #print(f"KL and likelihood are {KL_divergence().item(),}")
      #loss += torch.binary_cross_entropy_with_logits(prediction,label)
      loss = loss + KL_divergence(l1weightmeans,l1weightvariances,l1biasmeans,l1biasvariances,l2weightmeans,l2weightvariances,l2biasmeans,l2biasvariances,shared_weights_mean,shared_weights_variances,shared_bias_mean,shared_bias_variances,task_weights_means,task_weights_variances,task_bias_means,task_bias_variances)
      #l1weightmeans,l1weightvariances,l1biasmeans,l1biasvariances,l2weightmeans,l2weightvariances,l2biasmeans,l2biasvariances,shared_weights_mean,shared_weights_variances,shared_bias_mean,shared_bias_variances,task_weights_means,task_weights_variances,task_bias_means,task_bias_variances)
      
      return loss
#def train_bayes_model(self,x, y, lr: float=1e-3, n_epochs: int=120):
    def train_bayes_model(self,x, y,task, lr: float=1e-3, n_epochs: int=50):
    
      optimiser = torch.optim.Adam(params=[*l1weightmeans,*l1weightvariances,*l1biasmeans,*l1biasvariances,*l2weightmeans,*l2weightvariances,*l2biasmeans,*l2biasvariances],lr= lr)#[l1weightmeans[0],l1weightmeans[1],l1weightvariances[0],l1weightvariances[1],l1biasmeans[0],l1biasmeans[1],l1biasvariances[0],l1biasvariances[1],l2weightmeans[task],l2weightvariances[task],l2biasmeans[task],l2biasvariances[task]], lr=lr)
      dloader = DataLoader(list(zip(x, y)), shuffle=True, batch_size=len(x))
      train_loss = torch.tensor(0,dtype=torch.float32)
      for epoch in range(n_epochs):       
        epoch_loss = torch.tensor(0,dtype=torch.float32)
        optimiser.zero_grad()
        for batch_x, batch_y in dloader:
          epoch_loss = epoch_loss +self.loss(self.forward(batch_x.transpose(0,1),10,task,l1weightmeans,l1weightvariances,l1biasmeans,l1biasvariances,l2weightmeans,l2weightvariances,l2biasmeans,l2biasvariances),batch_y)
        print(f"epoch loss is {epoch_loss.item()}")
        epoch_loss.backward()
        optimiser.step()
        train_loss += epoch_loss
        print(f"completed epoch {epoch}, stepping optimiser")  
        if epoch % 20  == 0:
          print(f'Epoch: {epoch+1}/{n_epochs}, loss: {train_loss:.4f}')

      return [l1weightmeans,l1weightvariances,l1biasmeans,l1biasvariances,l2weightmeans,l2weightvariances,l2biasmeans,l2biasvariances]


        
      

In [221]:
bn = bayesian_network()
data = generate_task_data(0)
trained_params = bn.train_bayes_model(data[0], data[1],task = 0 , lr = float(0.001),n_epochs=10)

task is 0
calculated KL divergence 0.0
epoch loss is 10.831725030486075
completed epoch 0, stepping optimiser
Epoch: 1/10, loss: 10.8317
task is 0
calculated KL divergence 7259.46875
epoch loss is 7266.2676557095665


KeyboardInterrupt: 

In [8]:
def create_new_prior_from_posterior():  #updates prior with current posterior ready to build next bead
    for i in range(len(l1weightmeans)):
        shared_weights_mean[i] = l1weightmeans[i].clone().detach()   #change prior to current models params but detach so they retian requires_grad = False
        shared_weights_variances[i] = l1weightmeans[i].clone().detach()
        shared_bias_mean[i] = l1biasmeans[i].clone().detach()
        shared_bias_variances[i] = l1biasvariances[i].clone().detach()
        
    for i in range(len(task_weights_means)):   #Ensure models parameters will be updated through autograd for every task specific parameter
        task_weights_means[i] = l2weightmeans[i].clone().detach()
        task_weights_variances[i] = l2weightvariances[i].clone().detach()
        task_bias_means[i] = l2biasmeans[i].clone().detach()
        task_bias_variances[i] = l2biasvariances[i].clone().detach()
    return None
def add_new_task_params(): #add model paramaters for the next task head
    l2weightmeans.append(torch.normal(mean = torch.zeros(l2weightmeans[0].shape))) #mean 0 std 1 with same shape as other output layers
    l2weightvariances.append(torch.full(l2weightvariances[0].shape,-6,dtype=torch.float32))  # variances initialsied 10^-6 as in paper
    l2biasmeans.append(torch.normal(mean = torch.zeros(l2biasmeans[0].shape)))
    l2biasvariances.append(torch.full(l2biasvariances[0].shape,-6,dtype=torch.float32))
    
    task_weights_means.append(l2weightmeans[-1].detach())#copy model parameters into prior and detach since they are the prior 
    task_weights_variances.append(l2weightvariances[-1].detach())
    task_bias_means.append(l2biasmeans[-1].detach())
    task_bias_variances.append(l2biasvariances[-1].detach())
    

    l2weightmeans[-1].requires_grad = True
    l2weightvariances[-1].requires_grad = True
    l2biasmeans[-1].requires_grad = True
    l2biasvariances[-1].requires_grad = True
        
    


In [19]:
#main training sequence

bn = bayesian_network()
data = generate_task_data(0)
testdata = data[2]
task_scores = [[],[],[],[],[]] 
bn.train_bayes_model(x = data[0],y =  data[1], lr = float(0.0001),n_epochs=300,task=0)
testlabels = torch.tensor(data[3]).argmax(dim =1)
testresult = bn.forward(torch.tensor(testdata).transpose(0,1),300,0)
testresult = testresult.argmax(dim = 1)
score = 0
for k in range(len(testresult)):         ##
    if testresult[i] == testlabels[i]:
        score += 1
task_scores[0].append(score/len(testresult))    #should just be able to repeat training here for other tasks 
create_new_prior_from_posterior() 
add_new_task_params()
data = generate_task_data(2)
testdata = data[2]
bn.train_bayes_model(x = data[0],y =  data[1], lr = float(0.001),n_epochs=10,task=1)
testlabels = torch.tensor(data[3]).argmax(dim =1)
testresult = bn.forward(torch.tensor(testdata).transpose(0,1),300,0)
testresult = testresult.argmax(dim = 1)
score = 0
for k in range(len(testresult)):
    if testresult[i] == testlabels[i]:
        score += 1
task_scores[1].append(score/len(testresult))





task is 0
calculated KL divergence 1458238.375
epoch loss is 1458239.5184938477
completed epoch 0, stepping optimiser
Epoch: 1/300, loss: 1458239.5000
task is 0
calculated KL divergence 1458211.125
epoch loss is 1458211.652197572
completed epoch 1, stepping optimiser
task is 0
calculated KL divergence 1458183.25
epoch loss is 1458184.2878089326
completed epoch 2, stepping optimiser
task is 0
calculated KL divergence 1458155.875
epoch loss is 1458156.802481062
completed epoch 3, stepping optimiser
task is 0
calculated KL divergence 1458129.125
epoch loss is 1458129.7479860482
completed epoch 4, stepping optimiser
task is 0
calculated KL divergence 1458102.375
epoch loss is 1458103.2561366821
completed epoch 5, stepping optimiser
task is 0
calculated KL divergence 1458075.375
epoch loss is 1458076.1461678136
completed epoch 6, stepping optimiser
task is 0
calculated KL divergence 1458048.75
epoch loss is 1458050.268219961
completed epoch 7, stepping optimiser
task is 0
calculated KL dive