##Imports

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torch.nn as nn
import pickle
import gzip
import copy
import wandb

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

#Misc

In [3]:
def KL_DIV(mu0, rho0, mu1, rho1):
    """Return KL( N(mu0, exp(rho0)) || N(mu1, exp(rho1)) ) summed over all elements.

    ``rho`` is a log-variance, i.e. the standard deviation is ``exp(0.5 * rho)``.
    All four arguments must broadcast against each other.

    Every term is formed directly from the ``rho`` values instead of computing
    ``log(exp(.) / exp(.))``. This keeps the result finite when one of the
    exponentials would under- or overflow for extreme ``rho``.
    """
    var_ratio = torch.exp(rho0 - rho1)                 # sigma0^2 / sigma1^2
    mahalanobis = (mu0 - mu1) ** 2 * torch.exp(-rho1)  # (mu0 - mu1)^2 / sigma1^2
    log_var_ratio = rho1 - rho0                        # 2 * log(sigma1 / sigma0)
    return (0.5 * (var_ratio + mahalanobis - 1 + log_var_ratio)).sum()


#Models

##Layers

In [53]:
class BayesianLinear(nn.Module):
    """Linear layer with a factorised variational posterior over its weights and bias.

    Each weight and bias element has a posterior location ``post_*_mu`` and a
    log-scale parameter ``post_*_rho``; the scale is ``exp(0.5 * rho)``. The
    prior uses the same parameterisation (``prior_*_mu`` / ``prior_*_rho``), and
    ``update_prior`` overwrites it with a detached copy of the current posterior.

    Args:
        input_size: number of input features.
        output_size: number of output features.
        device: device the prior tensors are created on.
        init_rho: initial value of every posterior ``rho``.
        prior_rho: initial value of every prior ``rho``.
        dist: posterior family used for sampling and for the KL term:
            ``"gaussian"``, ``"radial"`` or ``"laplace"``.

    ``forward`` expects inputs with a leading sample dimension,
    ``(T, batch_size, input_size)``, and returns ``(T, batch_size, output_size)``.
    """

    def __init__(self, input_size, output_size, device, init_rho, prior_rho, dist="gaussian"):
        super(BayesianLinear, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.device = device
        self.dist = dist

        # float() guarantees floating-point tensors even for an int rho
        # (torch.full would otherwise infer an integer dtype, which cannot require grad).
        init_rho = float(init_rho)
        prior_rho = float(prior_rho)

        self.post_w_mu = nn.Parameter(torch.zeros(output_size, input_size))
        self.post_b_mu = nn.Parameter(torch.zeros(output_size))
        self.post_w_rho = nn.Parameter(torch.full((output_size, input_size), init_rho))
        self.post_b_rho = nn.Parameter(torch.full((output_size,), init_rho))

        # Priors are non-persistent buffers. nn.Module.to()/cuda()/deepcopy then move
        # or copy them together with the parameters, and state_dict() keys are unchanged.
        self.register_buffer("prior_w_mu", torch.zeros(output_size, input_size, device=device), persistent=False)
        self.register_buffer("prior_b_mu", torch.zeros(output_size, device=device), persistent=False)
        self.register_buffer("prior_w_rho", torch.full((output_size, input_size), prior_rho, device=device),
                             persistent=False)
        self.register_buffer("prior_b_rho", torch.full((output_size,), prior_rho, device=device),
                             persistent=False)

        # Kept for backward compatibility; the samplers below call Tensor.exponential_() directly.
        self.exp_dist = torch.distributions.Exponential(rate=torch.tensor(1.0, device=device))
        self.init_parameters()

    def to(self, *args, **kwargs):
        """Move/cast the parameters and prior buffers (``nn.Module.to`` semantics) and return ``self``."""
        module = super(BayesianLinear, self).to(*args, **kwargs)
        self.device = self.post_w_mu.device
        return module

    def init_parameters(self):
        """Initialise the posterior and prior means from a truncated normal with std 0.1."""
        # NOTE(review): this also replaces the zero prior means with random values drawn
        # independently of the posterior means. Confirm a non-zero prior mean is intended.
        for tensor in (self.post_w_mu, self.post_b_mu, self.prior_w_mu, self.prior_b_mu):
            nn.init.trunc_normal_(tensor, std=0.1)

    def forward(self, x, T):
        """Return one stochastic pre-activation per leading sample of ``x``.

        Args:
            x: input of shape ``(T, batch_size, input_size)``.
            T: number of samples. It is unused here; ``x.shape[0]`` is the sample count.
        Returns:
            Tensor of shape ``(T, batch_size, output_size)``.
        """
        act = self.sample_activation(x)  # (T, batch_size, output_size)
        bias = self.sample_b(x=x)        # (T, 1, output_size): one bias per sample, shared by the batch
        return act + bias

    def sample_activation(self, x):
        """Sample ``x @ W^T`` with ``W ~ q(W)``. Returns ``(T, batch_size, output_size)``."""
        if self.dist == "gaussian":
            # Local reparameterisation: sample pre-activations directly, with independent
            # noise for every (sample, example, unit).
            act_mean = F.linear(x, self.post_w_mu)
            act_var = F.linear(x ** 2, torch.exp(self.post_w_rho))  # sigma_w^2 = exp(rho)
            act_std = torch.sqrt(1e-8 + act_var)
            return act_mean + act_std * torch.randn_like(act_mean)
        if self.dist == "radial":
            weights = self.sample_weights_radial(x.shape[0])  # (T, output_size, input_size)
        elif self.dist == "laplace":
            weights = self._sample_laplace(self.post_w_mu, self.post_w_rho, x.shape[0])
        else:
            raise ValueError(f"unknown dist {self.dist!r}")
        # (T, 1, out, in) @ (T, batch, in, 1) -> (T, batch, out, 1)
        return (weights.unsqueeze(1) @ x.unsqueeze(-1)).squeeze(-1)

    @staticmethod
    def _sample_laplace(mu, rho, n_samples):
        """Draw ``n_samples`` samples of Laplace(mu, exp(0.5 * rho)), stacked on a new dim 0."""
        shape = (n_samples, *rho.shape)
        sign = 2 * torch.bernoulli(torch.full(shape, 0.5, device=rho.device, dtype=rho.dtype)) - 1
        magnitude = torch.empty(shape, device=rho.device, dtype=rho.dtype).exponential_()  # Exp(1)
        return mu.unsqueeze(0) + sign * magnitude * torch.exp(0.5 * rho).unsqueeze(0)

    def sample_weights_radial(self, n_samples):
        """Draw ``n_samples`` radial-posterior weight matrices, ``(n_samples, output_size, input_size)``."""
        dev = self.post_w_mu.device
        eps = torch.randn((n_samples, *self.post_w_rho.shape), device=dev)
        eps = eps / torch.linalg.norm(eps, dim=(1, 2)).view(-1, 1, 1)  # direction on the unit sphere
        r = torch.randn((n_samples, 1, 1), device=dev)                  # signed radius
        return self.post_w_mu.unsqueeze(0) + r * eps * torch.exp(0.5 * self.post_w_rho).unsqueeze(0)

    def Lcrossentropy(self, n_samples):
        """Monte-Carlo estimate of the prior log-density term used with the radial posterior.

        For ``s = (w - prior_mu) / prior_sigma`` it averages
        ``-0.5 * ||s||^2 - (D - 1) * log ||s||``. This appears to be the log-density
        of a radial prior up to an additive constant. Weights and bias are handled
        separately and summed.
        """
        w = self.sample_weights_radial(n_samples)
        norms = torch.linalg.norm(
            (w - self.prior_w_mu.unsqueeze(0)) / torch.exp(0.5 * self.prior_w_rho.unsqueeze(0)), dim=(1, 2))
        L = -0.5 * torch.mean(norms ** 2) - (self.prior_w_mu.numel() - 1) * torch.mean(torch.log(norms))

        b = self._sample_bias(n_samples).squeeze(1)
        normsb = torch.linalg.norm(
            (b - self.prior_b_mu.unsqueeze(0)) / torch.exp(0.5 * self.prior_b_rho.unsqueeze(0)), dim=1)
        Lb = -0.5 * torch.mean(normsb ** 2) - (self.prior_b_mu.numel() - 1) * torch.mean(torch.log(normsb))
        return L + Lb

    def Lentropy(self, n_samples):
        """Return ``-sum(log sigma)`` over all posterior weights and biases.

        This is the analytic form, so no sampling happens. ``n_samples`` is kept
        only for interface compatibility.
        """
        return -0.5 * (self.post_w_rho.sum() + self.post_b_rho.sum())

    def _sample_bias(self, n_samples):
        """Draw ``n_samples`` bias vectors from q(b). Returns ``(n_samples, 1, output_size)``."""
        dev = self.post_b_mu.device
        mu = self.post_b_mu.unsqueeze(0)
        scale = torch.exp(0.5 * self.post_b_rho).unsqueeze(0)
        if self.dist == "gaussian":
            b = mu + scale * torch.randn(n_samples, self.output_size, device=dev)
        elif self.dist == "radial":
            eps = torch.randn(n_samples, self.output_size, device=dev)
            eps = eps / torch.linalg.norm(eps, dim=1).view(-1, 1)
            b = mu + torch.randn((n_samples, 1), device=dev) * scale * eps
        elif self.dist == "laplace":
            b = self._sample_laplace(self.post_b_mu, self.post_b_rho, n_samples)
        else:
            raise ValueError(f"unknown dist {self.dist!r}")
        return b.unsqueeze(1)

    def sample_b(self, x):
        """Draw one bias per leading sample of ``x``. Returns ``(x.shape[0], 1, output_size)``."""
        return self._sample_bias(x.shape[0])

    def update_prior(self):
        """Replace the prior with a detached copy of the current posterior."""
        self.prior_w_mu = self.post_w_mu.detach().clone()
        self.prior_w_rho = self.post_w_rho.detach().clone()
        self.prior_b_mu = self.post_b_mu.detach().clone()
        self.prior_b_rho = self.post_b_rho.detach().clone()

    @staticmethod
    def _laplace_kl(mu_q, rho_q, mu_p, rho_p):
        """Return KL( Laplace(mu_q, b_q) || Laplace(mu_p, b_p) ) summed, with b = exp(0.5 * rho)."""
        b_q = torch.exp(0.5 * rho_q)
        b_p = torch.exp(0.5 * rho_p)
        diff = torch.abs(mu_q - mu_p)
        # log(b_p / b_q) + (b_q * exp(-|dmu| / b_q) + |dmu|) / b_p - 1
        return (0.5 * (rho_p - rho_q) + (b_q * torch.exp(-diff / b_q) + diff) / b_p - 1.0).sum()

    def kl_div(self):
        """Return KL(posterior || prior) for the weights plus the bias.

        Raises:
            NotImplementedError: for ``dist == "radial"``, which has no closed form here
                (see ``Lentropy`` / ``Lcrossentropy``).
        """
        if self.dist == "gaussian":
            return (KL_DIV(self.post_w_mu, self.post_w_rho, self.prior_w_mu, self.prior_w_rho)
                    + KL_DIV(self.post_b_mu, self.post_b_rho, self.prior_b_mu, self.prior_b_rho))
        if self.dist == "laplace":
            return (self._laplace_kl(self.post_w_mu, self.post_w_rho, self.prior_w_mu, self.prior_w_rho)
                    + self._laplace_kl(self.post_b_mu, self.post_b_rho, self.prior_b_mu, self.prior_b_rho))
        raise NotImplementedError(f"no closed-form KL for dist {self.dist!r}")

class NonBayesianLinear(nn.Module):
    """Deterministic linear layer with the same call interface as ``BayesianLinear``.

    Args:
        input_size: number of input features.
        output_size: number of output features.
        device: stored for interface compatibility with ``BayesianLinear``.
    """

    def __init__(self, input_size, output_size, device):
        super(NonBayesianLinear, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.device = device

        self.w = nn.Parameter(torch.empty(output_size, input_size))
        self.b = nn.Parameter(torch.empty(output_size))

        self.init_parameters()

    def to(self, *args, **kwargs):
        """Move/cast the parameters (``nn.Module.to`` semantics) and return ``self``.

        ``nn.Module.to`` updates the Parameters in place, so no re-assignment of
        ``w``/``b`` is needed.
        """
        module = super(NonBayesianLinear, self).to(*args, **kwargs)
        self.device = self.w.device
        return module

    def init_parameters(self):
        """Initialise weight and bias from a truncated normal with std 0.1."""
        nn.init.trunc_normal_(self.w, std=0.1)
        nn.init.trunc_normal_(self.b, std=0.1)

    def forward(self, x, T):
        """Apply ``x @ w^T + b``. ``T`` is ignored; it exists for interface parity."""
        return F.linear(x, self.w, self.b)

    def update_prior(self):
        """No-op: a deterministic layer has no prior."""
        pass

##Models

In [71]:
class BayesianFC(nn.Module):
    """Single-head MLP built from ``BayesianLinear`` (or ``NonBayesianLinear``) layers.

    Layout: ``input_size -> hidden_size``, then ``n_hidden - 1`` layers of
    ``hidden_size -> hidden_size``, then ``hidden_size -> output_size``, with ReLU
    between consecutive layers. The class also keeps the state used by the EWC
    (``fisher_dicts`` / ``param_history``) and SI (``SI``) baselines.

    Args:
        input_size, output_size, hidden_size, n_hidden: architecture (see above).
        device: passed to the layers and used to move data in ``calculate_fim``.
        init_rho, prior_rho: initial posterior / prior ``rho`` of the Bayesian layers.
        n_fisher_samples: maximum number of batches used by ``calculate_fim``.
        dist: posterior family of the Bayesian layers.
        output: ``"categorical"`` (log-softmax outputs) or ``"gaussian"``.
        bayesian: if False, deterministic ``NonBayesianLinear`` layers are used.
    """

    def __init__(self, input_size, output_size, hidden_size, n_hidden, device, init_rho, prior_rho,
                 n_fisher_samples=3000, dist="gaussian", output="categorical", bayesian=True):
        super(BayesianFC, self).__init__()
        self.device = device
        # Stored so forward() does not depend on a notebook-global ``bayesian``.
        self.bayesian = bayesian
        self.layers = nn.ModuleList()

        def make_layer(fan_in, fan_out):
            if bayesian:
                return BayesianLinear(fan_in, fan_out, device, init_rho, prior_rho, dist=dist)
            return NonBayesianLinear(fan_in, fan_out, device)

        if not bayesian:
            print("Non Bayesian network")
        self.layers.append(make_layer(input_size, hidden_size))
        for _ in range(n_hidden - 1):
            self.layers.append(make_layer(hidden_size, hidden_size))
        self.layers.append(make_layer(hidden_size, output_size))

        # NOTE(review): not used by forward, but its parameters appear in
        # named_parameters()/state_dict(), so it is kept for compatibility.
        self.bn = nn.BatchNorm1d(100)
        self.n_fishersamples = n_fisher_samples
        self.fisher_dicts = []
        self.param_history = []
        self.SI = None
        self.dist = dist
        self.output = output
        self.grad = {}

    def transfer_weights(self, nonbayesian_model):
        """Copy ``w``/``b`` of a deterministic model into the posterior means, layer by layer."""
        for i, layer in enumerate(self.layers):
            source = nonbayesian_model.layers[i]
            # clone() so later optimisation of this model does not mutate the source model
            layer.post_w_mu.data = source.w.data.clone()
            layer.post_b_mu.data = source.b.data.clone()
            layer.to(self.device)

    def forward(self, x, T):
        """Run the network on ``x`` of shape ``(batch_size, input_size)``.

        Bayesian networks replicate the input ``T`` times along a new leading
        dimension. Deterministic networks use a leading dimension of 1.

        Returns:
            ``"categorical"``: log-probabilities of shape ``(T, batch_size, output_size)``.
            ``"gaussian"``: ``(mu, rho)``, where ``mu`` is the first 10 outputs and
            ``rho`` is a constant -20 (fixed log-variance) for the remaining ones.
        """
        if self.bayesian:
            x = x.unsqueeze(0).tile((T, 1, 1))
        else:
            x = x.unsqueeze(0)
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x = layer(x, T)
            if i < last:
                x = F.relu(x)
        if self.output == "categorical":
            return F.log_softmax(x, dim=-1)
        if self.output == "gaussian":
            return x[:, :, :10], torch.full_like(x[:, :, 10:], -20.0)
        raise ValueError(f"unknown output type {self.output!r}")

    def kl_div(self, n_data, n_samples=100):
        """Return the KL term of the ELBO divided by the dataset size ``n_data``.

        For ``"radial"`` the KL is estimated with ``n_samples`` Monte-Carlo samples per layer.
        """
        if self.dist in ("gaussian", "laplace"):
            total_kl = 0
            for layer in self.layers:
                total_kl += layer.kl_div()
            return total_kl / n_data
        if self.dist == "radial":
            l_entropy = 0
            l_cross_entropy = 0
            for layer in self.layers:
                l_entropy += layer.Lentropy(n_samples)
                l_cross_entropy += layer.Lcrossentropy(n_samples)
            return (l_entropy - l_cross_entropy) / n_data
        raise ValueError(f"unknown dist {self.dist!r}")

    def calculate_fim(self, train_loader1):
        """Estimate the diagonal empirical Fisher and snapshot the parameters (for EWC).

        Squared gradients of the NLL are averaged over at most ``n_fishersamples``
        batches of ``train_loader1``. The result is appended to ``fisher_dicts`` and
        a parameter copy to ``param_history``.
        """
        fisher_dic = {n: torch.zeros_like(p) for n, p in self.named_parameters()}

        self.eval()
        n_batches = 0
        for x, y in train_loader1:
            if n_batches >= self.n_fishersamples:
                break
            x, y = x.to(self.device), y.to(self.device)
            loss = self.nll(x, y)
            self.zero_grad()
            loss.backward()
            for n, p in self.named_parameters():
                if p.requires_grad and p.grad is not None:
                    fisher_dic[n] += p.grad.detach() ** 2
            n_batches += 1

        # Average over the batches actually used; max() guards an empty loader.
        denom = max(n_batches, 1)
        self.fisher_dicts.append({n: f / denom for n, f in fisher_dic.items()})
        self.param_history.append({n: p.detach().clone() for n, p in self.named_parameters()})
        self.train()

    def record_grads(self):
        """Append a copy of every current parameter gradient to ``self.grad[name]``."""
        for n, p in self.named_parameters():
            if p.requires_grad and p.grad is not None:
                # clone(): the grad tensor may be reused in place by the next backward pass
                self.grad.setdefault(n, []).append(p.grad.detach().clone())

    def std(self):
        """Return the std of each recorded gradient entry across steps, summed over all entries.

        The recorded history is cleared afterwards. Returns 0 if nothing was recorded.
        """
        sum_std = 0
        for history in self.grad.values():
            grads = torch.stack(history).reshape(len(history), -1)
            sum_std = sum_std + grads.std(dim=0).sum()
        self.grad = {}
        return sum_std

    def calculate_ewc(self):
        """Return the EWC penalty 0.5 * sum_tasks sum_params F * (p - p_task)^2."""
        Lloss = []
        for task in range(len(self.fisher_dicts)):
            for n, p in self.named_parameters():
                if p.requires_grad:
                    Lloss.append(torch.sum(self.fisher_dicts[task][n] * (p - self.param_history[task][n]) ** 2))
        return (1. / 2) * sum(Lloss)

    def update_prior(self):
        """Set every layer's prior to its current posterior."""
        for layer in self.layers:
            layer.update_prior()

    def elbo(self, x, y, n_data, T=10):
        """Return the negative ELBO per example: NLL plus KL / ``n_data``."""
        nll = self.nll(x, y, T=T)
        kl = self.kl_div(n_data)
        return nll + kl

    def nll(self, x, y, T=10):
        """Return the mean negative log-likelihood of labels ``y`` using ``T`` samples."""
        if self.output == "categorical":
            log_probs = self.forward(x, T=T).mean(dim=0)  # average log-probs over samples
            return F.nll_loss(log_probs, y, reduction="mean")
        if self.output == "gaussian":
            mu, rho = self.forward(x, T=T)
            sigma = torch.exp(0.5 * rho)
            y_one_hot = F.one_hot(y, num_classes=10).unsqueeze(0)
            log_term = torch.sum(-torch.log(sigma), dim=-1)
            square_term = torch.sum(-(1 / (2 * (sigma ** 2))) * (y_one_hot - mu) ** 2, dim=-1)
            return (-log_term - square_term).mean()

    def rmse_gaussian(self, x, y, T=10):
        """Return the RMSE of the Gaussian output against one-hot labels, including predictive variance.

        Raises:
            ValueError: if the model's output type is not ``"gaussian"``.
        """
        if self.output == "gaussian":
            mu, rho = self.forward(x, T)
            sigma = torch.exp(0.5 * rho)
            y_one_hot = F.one_hot(y, num_classes=10).unsqueeze(0)
            return torch.sqrt(torch.mean(torch.sum(sigma ** 2 + (mu - y_one_hot) ** 2, dim=-1)))
        raise ValueError("function and output type not matching")

class MultiBayesianFC(nn.Module):
    """Multi-head Bayesian MLP: a shared ``BayesianLinear`` trunk plus one head per task.

    The trunk (``layers``) has ``n_hidden`` layers, each followed by ReLU. Head
    ``t`` (``last_layers[t]``) maps ``hidden_size -> n_classes``. Fisher/EWC
    bookkeeping covers only parameters whose name starts with ``"layers"``,
    i.e. the shared trunk.

    Args:
        input_size, hidden_size, n_hidden: trunk architecture.
        device: passed to the layers and used to move data in ``calculate_fim``.
        n_tasks: number of heads.
        n_fisher_samples: maximum number of batches used by ``calculate_fim``.
        init_rho, prior_rho: initial posterior / prior ``rho`` of every layer
            (``BayesianLinear`` requires both).
        dist: posterior family of every layer.
        n_classes: output size of each head.
    """

    def __init__(self, input_size, hidden_size, n_hidden, device, n_tasks=5, n_fisher_samples=3000,
                 init_rho=-4.0, prior_rho=0.0, dist="gaussian", n_classes=2):
        super(MultiBayesianFC, self).__init__()
        self.device = device
        self.dist = dist
        # test() and run_vcl() read ``model.output``; this model always emits class log-probs.
        self.output = "categorical"

        def make_layer(fan_in, fan_out):
            return BayesianLinear(fan_in, fan_out, device, init_rho, prior_rho, dist=dist)

        self.layers = nn.ModuleList()
        self.layers.append(make_layer(input_size, hidden_size))
        for _ in range(n_hidden - 1):
            self.layers.append(make_layer(hidden_size, hidden_size))
        self.last_layers = nn.ModuleList()
        for _ in range(n_tasks):
            self.last_layers.append(make_layer(hidden_size, n_classes))
        # NOTE(review): unused by forward but part of named_parameters()/state_dict().
        self.bn = nn.BatchNorm1d(100)
        self.n_fishersamples = n_fisher_samples
        self.fisher_dicts = []
        self.param_history = []
        self.SI = None

    def transfer_weights(self, nonbayesian_model):
        """Copy ``w``/``b`` from a deterministic model into the posterior means.

        NOTE(review): assumes ``nonbayesian_model`` has matching ``layers`` and
        ``last_layers`` lists whose modules expose ``w`` and ``b``.
        """
        pairs = [(layer, nonbayesian_model.layers[i]) for i, layer in enumerate(self.layers)]
        pairs += [(layer, nonbayesian_model.last_layers[i]) for i, layer in enumerate(self.last_layers)]
        for layer, source in pairs:
            # clone() so optimising this model does not mutate the source model
            layer.post_w_mu.data = source.w.data.clone()
            layer.post_b_mu.data = source.b.data.clone()
            layer.to(self.device)

    def forward(self, x, id_task, n_samples=1):
        """Return log-probabilities of head ``id_task``, averaged over ``n_samples`` samples.

        Args:
            x: input batch ``(batch_size, input_size)``.
            id_task: index of the head to use.
            n_samples: number of Monte-Carlo samples.
        Returns:
            ``(batch_size, n_classes)``: the mean of the per-sample log-softmax outputs.
        """
        h = x.unsqueeze(0).tile((n_samples, 1, 1))  # (n_samples, batch_size, input_size)
        for layer in self.layers:
            h = F.relu(layer(h, n_samples))
        h = self.last_layers[id_task](h, n_samples)
        return F.log_softmax(h, dim=-1).mean(dim=0)

    def kl_div(self, n_data, task_id=None, n_samples=100):
        """Return the KL term divided by ``n_data``.

        The shared trunk is always included. The head of ``task_id`` is added
        when ``task_id`` is given.
        """
        layers = list(self.layers)
        if task_id is not None:
            layers.append(self.last_layers[task_id])
        total_kl = 0
        for layer in layers:
            if self.dist == "radial":
                total_kl += layer.Lentropy(n_samples) - layer.Lcrossentropy(n_samples)
            else:
                total_kl += layer.kl_div()
        return total_kl / n_data

    def calculate_fim(self, train_loader1, task_id):
        """Estimate the diagonal empirical Fisher of the trunk on ``task_id`` and snapshot the trunk."""
        fisher_dic = {n: torch.zeros_like(p) for n, p in self.named_parameters() if n.startswith("layers")}

        self.eval()
        n_batches = 0
        for x, y in train_loader1:
            if n_batches >= self.n_fishersamples:
                break
            x, y = x.to(self.device), y.to(self.device)
            loss = self.nll(x, y, task_id)
            self.zero_grad()
            loss.backward()
            for n, p in self.named_parameters():
                if n.startswith("layers") and p.requires_grad and p.grad is not None:
                    fisher_dic[n] += p.grad.detach() ** 2
            n_batches += 1

        # Average over the batches actually used; max() guards an empty loader.
        denom = max(n_batches, 1)
        self.fisher_dicts.append({n: f / denom for n, f in fisher_dic.items()})
        self.param_history.append(
            {n: p.detach().clone() for n, p in self.named_parameters() if n.startswith("layers")})
        self.train()

    def calculate_ewc(self):
        """Return the EWC penalty on the shared trunk, summed over all stored tasks."""
        Lloss = []
        for task in range(len(self.fisher_dicts)):
            for n, p in self.named_parameters():
                if n.startswith("layers") and p.requires_grad:
                    Lloss.append(torch.sum(self.fisher_dicts[task][n] * (p - self.param_history[task][n]) ** 2))
        return (1. / 2) * sum(Lloss)

    def update_prior(self):
        """Set the prior of every trunk layer and head to its current posterior."""
        for layer in self.layers:
            layer.update_prior()
        for layer in self.last_layers:
            layer.update_prior()

    def elbo(self, x, y, n_data, task_id, T=10):
        """Return NLL plus KL / ``n_data`` for head ``task_id``, using ``T`` samples per example."""
        x_tiled = x.tile((T, 1))  # (T * batch_size, input_size): one sample per tiled row
        y_tiled = y.tile((T,))
        nll = self.nll(x_tiled, y_tiled, task_id)
        kl = self.kl_div(n_data, task_id=task_id)
        return nll + kl

    def nll(self, x, y, task_id):
        """Return the mean negative log-likelihood of ``y`` under head ``task_id``."""
        log_probs = self.forward(x, task_id)
        return F.nll_loss(log_probs, y, reduction="mean")

##SI-specific class

In [6]:
class SynapticIntelligence():
    """Synaptic Intelligence (SI) regulariser for the parameters under ``model.layers``.

    Only trainable parameters whose name starts with ``"layers"`` are tracked,
    so task-specific heads such as ``last_layers`` are excluded.
    """

    def __init__(self, model, si_c=0.1, epsilon=0.1, gamma=0.9):
        self.prev_SI = {}    # parameter values at the end of the previous task
        self.omega_SI = {}   # accumulated per-parameter importance
        self.W = {}          # path integral accumulated over the current task
        self.p_old = {}      # parameter values after the previous optimiser step
        self.model = model
        self.si_c = si_c     # NOTE(review): stored but not used by this class
        self.epsilon = epsilon  # damping term in the omega denominator
        self.gamma = gamma   # only referenced by the disabled decay in divide_omega

    def _tracked(self):
        """Yield ``(name, param)`` for trainable parameters whose name starts with "layers"."""
        for n, p in self.model.named_parameters():
            if n.startswith("layers") and p.requires_grad:
                yield n, p

    def _device(self):
        """Return the device of the model's first parameter, or the CPU if it has none."""
        for p in self.model.parameters():
            return p.device
        return torch.device("cpu")

    def init_SI(self):
        """Snapshot the current parameter values as the reference point."""
        for n, p in self._tracked():
            self.prev_SI[n] = p.data.clone()

    def prepare_w_P_SI(self):
        """Reset the path integral ``W`` and record the starting values for a new task."""
        for n, p in self._tracked():
            self.W[n] = p.data.clone().zero_()
            self.p_old[n] = p.data.clone()

    def update_parameter_importance(self):
        """Accumulate ``-grad * delta_p`` for the last optimiser step; call after ``optimizer.step()``."""
        for n, p in self._tracked():
            if p.grad is not None:
                self.W[n].add_(-p.grad * (p.detach() - self.p_old[n]))
            self.p_old[n] = p.detach().clone()

    def update_omega(self):
        """Fold this task's path integral into ``omega`` and move the reference point to the current values."""
        for n, p in self._tracked():
            p_current = p.detach().clone()
            p_change = p_current - self.prev_SI[n]
            omega_add = self.W[n] / (p_change ** 2 + self.epsilon)
            omega = self.omega_SI.get(n)
            if omega is None:
                omega = p.detach().clone().zero_()
            self.prev_SI[n] = p_current
            self.omega_SI[n] = omega + omega_add

    def divide_omega(self):
        """Intentionally a no-op.

        A decay ``omega *= gamma`` of the accumulated importances is currently disabled.
        """
        pass

    def surrogate_l(self):
        """Return the SI penalty ``sum omega * (p - p_prev)^2``.

        Returns a zero tensor on the model's device if any tracked parameter
        has no reference value or importance yet.
        """
        tracked = list(self._tracked())
        if any(n not in self.prev_SI or n not in self.omega_SI for n, _ in tracked):
            return torch.tensor(0., device=self._device())
        return sum((self.omega_SI[n] * (p - self.prev_SI[n]) ** 2).sum() for n, p in tracked)


#Utils

In [7]:
def permute_MNIST(x_train, y_train, x_test, y_test, task_id):
    """Apply the task-specific pixel permutation to the train and test inputs.

    The permutation comes from ``np.random.RandomState(task_id)``. That is the
    same permutation ``np.random.seed(task_id); np.random.permutation(...)``
    produces, but NumPy's global RNG is left untouched, so later ``np.random``
    calls (e.g. coreset selection) do not depend on the last task id.

    Returns:
        ``(x_train[:, perm], y_train, x_test[:, perm], y_test)``.
    """
    permutation = np.random.RandomState(task_id).permutation(x_train.shape[1])
    return x_train[:, permutation], y_train, x_test[:, permutation], y_test

def split_MNIST(x_train, y_train, x_test, y_test, task_id, tasks=None):
    """Keep only the two digit classes of ``task_id`` and relabel them as 0/1.

    Args:
        x_train, y_train, x_test, y_test: full dataset arrays (labels are digit ids).
        task_id: index into ``tasks``.
        tasks: list of class pairs; defaults to ``[[0,1],[2,3],[4,5],[6,7],[8,9]]``.
    Returns:
        ``(x_train, y_train, x_test, y_test)`` restricted to the task's classes.
        The first class of the pair is labelled 0 and the second 1, consistently
        for train and test.
    """
    if tasks is None:
        tasks = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
    first, second = tasks[task_id]
    train_mask = np.isin(y_train, (first, second))
    test_mask = np.isin(y_test, (first, second))
    return (x_train[train_mask], (y_train[train_mask] == second).astype(int),
            x_test[test_mask], (y_test[test_mask] == second).astype(int))


# --- Coreset methods ---

def random_coreset(x_train, y_train, x_coreset, y_coreset, nb_coreset, dataset="split"):
    """Move ``nb_coreset`` uniformly chosen rows (no replacement) from the training set to the coreset.

    The selected rows are appended as a new entry to ``x_coreset`` / ``y_coreset``;
    those lists are mutated in place. ``dataset`` is accepted for interface parity
    with ``k_center`` and ignored.

    Returns:
        ``(remaining_x_train, remaining_y_train, x_coreset, y_coreset)``.
    """
    chosen = np.random.choice(x_train.shape[0], nb_coreset, replace=False)
    x_coreset.append(x_train[chosen, :])
    y_coreset.append(y_train[chosen])
    remaining_x = np.delete(x_train, chosen, axis=0)
    remaining_y = np.delete(y_train, chosen, axis=0)
    return remaining_x, remaining_y, x_coreset, y_coreset

def k_center(x_train, y_train, x_coreset, y_coreset, nb_coreset, dataset="split"):
    """Select ``nb_coreset`` rows with the greedy k-center heuristic.

    Row 0 is the first center. Each following center is the point with the
    largest Euclidean distance to its *nearest* already-selected center. The
    selected rows are appended to ``x_coreset`` / ``y_coreset`` (mutated in place)
    and removed from the training arrays.

    Returns:
        The list of selected row indices if ``dataset == "permuted"``; otherwise
        ``(remaining_x_train, remaining_y_train, x_coreset, y_coreset)``.
    Raises:
        ValueError: if ``nb_coreset`` exceeds the number of training rows.
    """
    n_points = x_train.shape[0]
    if nb_coreset > n_points:
        raise ValueError(f"nb_coreset={nb_coreset} exceeds the {n_points} available points")

    ids = []
    nearest = np.full(n_points, np.inf)  # distance of every point to its closest selected center
    center = 0
    for _ in range(nb_coreset):
        ids.append(center)
        nearest = np.minimum(nearest, np.linalg.norm(x_train - x_train[center, :], axis=1))
        nearest[center] = -np.inf  # never select the same row twice (even with duplicate points)
        center = int(np.argmax(nearest))

    x_coreset.append(x_train[ids, :])
    y_coreset.append(y_train[ids])

    x_train = np.delete(x_train, ids, axis=0)
    y_train = np.delete(y_train, ids, axis=0)

    if dataset == "permuted":
        return ids
    return x_train, y_train, x_coreset, y_coreset

def dists(x_train, id):
    """Return the Euclidean distance from every row of ``x_train`` to row ``id``."""
    reference = x_train[id, :]
    offsets = x_train - reference
    return np.linalg.norm(offsets, axis=1)

def get_total_elements(loader):
    """Count the samples yielded by ``loader`` (the sum of each input batch's leading dimension).

    The loader is iterated once in full.
    """
    return sum(batch_inputs.shape[0] for batch_inputs, _ in loader)


#Training functions

##Loaders

In [8]:
import torch
from torch.utils.data import DataLoader, TensorDataset
def dataloader(X, Y, batch_size=128, shuffle=True, num_workers=0, pin_memory=False):
    """Build one ``DataLoader`` per ``(X[i], Y[i])`` pair of NumPy arrays.

    Inputs are converted to float32 tensors and labels to int64 tensors.
    Raises AssertionError if an element is not a NumPy array.
    """
    def make_loader(x, y):
        assert isinstance(x, np.ndarray) and isinstance(y, np.ndarray), "X and Y must be NumPy arrays"
        tensors = TensorDataset(torch.from_numpy(x).float(), torch.from_numpy(y).long())
        return DataLoader(tensors, batch_size=batch_size, shuffle=shuffle,
                          num_workers=num_workers, pin_memory=pin_memory)

    return [make_loader(x, y) for x, y in zip(X, Y)]

def aggdataloader(X, Y, batch_size=128, shuffle=True, num_workers=0, pin_memory=False):
    """Build cumulative DataLoaders: loader ``i`` serves the concatenation of ``X[0..i]`` / ``Y[0..i]``.

    Inputs become float32 tensors and labels int64 tensors. Raises
    AssertionError if an element is not a NumPy array.
    """
    result = []
    acc_x = acc_y = None
    for idx, (x, y) in enumerate(zip(X, Y)):
        if idx:
            acc_x = np.concatenate((acc_x, x), axis=0)
            acc_y = np.concatenate((acc_y, y), axis=0)
        else:
            acc_x, acc_y = x, y
        assert isinstance(x, np.ndarray) and isinstance(y, np.ndarray), "X and Y must be NumPy arrays"
        merged = TensorDataset(torch.from_numpy(acc_x).float(), torch.from_numpy(acc_y).long())
        result.append(DataLoader(merged, batch_size=batch_size, shuffle=shuffle,
                                 num_workers=num_workers, pin_memory=pin_memory))
    return result



def loaders(nb_tasks, coreset_size, batch_size=128, dataset="permuted", path='sample_data/mnist.pkl.gz'):
    """Build the train/test/coreset DataLoaders for permuted or split MNIST.

    For every task two training variants are produced: one with a random
    coreset removed (``train_loaders``) and one with a k-center coreset removed
    (``train_loadersk``). For ``"permuted"`` the coreset loaders are cumulative
    (loader ``i`` holds coresets ``0..i``); for ``"split"`` they are per task.

    Args:
        nb_tasks: number of tasks to build.
        coreset_size: rows per task moved into each coreset.
        batch_size: batch size of the non-Fisher loaders (the ``*1`` loaders use 1).
        dataset: ``"permuted"`` or ``"split"``.
        path: location of ``mnist.pkl.gz``.
    Returns:
        ``(train_loaders, test_loaders, coreset_loaders, train_loaders1,
        train_loadersk, kcoreset_loaders, train_loadersk1)``.
    Raises:
        ValueError: for an unknown ``dataset``.
    """
    if dataset not in ("permuted", "split"):
        raise ValueError(f"unknown dataset {dataset!r}")
    # NOTE: pickle.load can execute arbitrary code; only load a trusted copy of this file.
    with gzip.open(path, 'rb') as f:
        train_set, valid_set, test_set = pickle.load(f, encoding='latin1')

    make_task = permute_MNIST if dataset == "permuted" else split_MNIST
    X_trains, Y_trains, X_tests, Y_tests = [], [], [], []
    for i in range(nb_tasks):
        x_tr, y_tr, x_te, y_te = make_task(train_set[0], train_set[1], test_set[0], test_set[1], i)
        X_trains.append(x_tr)
        Y_trains.append(y_tr)
        X_tests.append(x_te)
        Y_tests.append(y_te)
    Xk_trains, Yk_trains = list(X_trains), list(Y_trains)

    X_coresets, Y_coresets, Xk_coresets, Yk_coresets = [], [], [], []
    if dataset == "split":
        for i in range(nb_tasks):
            Xk_trains[i], Yk_trains[i], Xk_coresets, Yk_coresets = k_center(
                Xk_trains[i], Yk_trains[i], Xk_coresets, Yk_coresets, coreset_size)
            X_trains[i], Y_trains[i], X_coresets, Y_coresets = random_coreset(
                X_trains[i], Y_trains[i], X_coresets, Y_coresets, coreset_size)
    elif nb_tasks > 0:
        # A column permutation leaves Euclidean distances unchanged, so the k-center
        # ids are the same for every task and are computed once. Throwaway lists keep
        # k_center from appending into the random-coreset lists.
        ids = k_center(Xk_trains[0], Yk_trains[0], [], [], coreset_size, dataset="permuted")
        for i in range(nb_tasks):
            Xk_coresets.append(Xk_trains[i][ids, :])
            Yk_coresets.append(Yk_trains[i][ids])
            Xk_trains[i] = np.delete(Xk_trains[i], ids, axis=0)
            Yk_trains[i] = np.delete(Yk_trains[i], ids, axis=0)
            X_trains[i], Y_trains[i], X_coresets, Y_coresets = random_coreset(
                X_trains[i], Y_trains[i], X_coresets, Y_coresets, coreset_size)

    train_loaders = dataloader(X_trains, Y_trains, batch_size)
    train_loaders1 = dataloader(X_trains, Y_trains, 1)
    train_loadersk = dataloader(Xk_trains, Yk_trains, batch_size)
    train_loadersk1 = dataloader(Xk_trains, Yk_trains, 1)
    test_loaders = dataloader(X_tests, Y_tests, batch_size)
    if dataset == "permuted":
        coreset_loaders = aggdataloader(X_coresets, Y_coresets, batch_size)
        kcoreset_loaders = aggdataloader(Xk_coresets, Yk_coresets, batch_size)
    else:
        coreset_loaders = dataloader(X_coresets, Y_coresets, batch_size)
        kcoreset_loaders = dataloader(Xk_coresets, Yk_coresets, batch_size)
    return (train_loaders, test_loaders, coreset_loaders, train_loaders1,
            train_loadersk, kcoreset_loaders, train_loadersk1)

##Train and test

In [142]:
def train(model, loader, optimizer, num_epochs, device, clip_value, task_id, obj="nll", is_print=True,
          lambda_ewc=50., lambda_si=50000000.):
    """Train a single-head model on one task.

    Args:
        model: model exposing ``nll``/``elbo``/``calculate_ewc`` and an ``SI`` slot.
        loader: DataLoader of ``(inputs, targets)`` for the current task.
        optimizer: optimiser over ``model.parameters()``.
        num_epochs: number of passes over ``loader``.
        device: device the batches are moved to.
        clip_value: max global L2 norm of the gradients.
        task_id: index of the current task (SI state is created when it is 0).
        obj: ``"nll"``, ``"elbo"``, ``"ewc"`` or ``"si"``.
        is_print: print a per-epoch loss line.
        lambda_ewc, lambda_si: weights of the EWC / SI penalties.
    Raises:
        ValueError: for an unknown ``obj``.
    """
    if obj not in ("nll", "elbo", "ewc", "si"):
        raise ValueError(f"unknown objective {obj!r}")
    model.train()
    n_data = get_total_elements(loader)
    if obj == "si":
        if task_id == 0:
            model.SI = SynapticIntelligence(model)
            model.SI.init_SI()
        model.SI.prepare_w_P_SI()

    for epoch in range(num_epochs):
        running_loss = 0.
        running_kldiv = 0.  # NOTE(review): never accumulated, so the printed KL-Div is always 0
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            if obj == "nll":
                loss = model.nll(inputs, targets)
            elif obj == "elbo":
                loss = model.elbo(inputs, targets, n_data)
            elif obj == "ewc":
                loss = model.nll(inputs, targets) + lambda_ewc * model.calculate_ewc()
            elif task_id == 0:  # obj == "si": no surrogate penalty on the first task
                loss = model.nll(inputs, targets)
            else:
                loss = model.nll(inputs, targets) + lambda_si * model.SI.surrogate_l()
            optimizer.zero_grad()
            loss.backward()
            # Honour the caller's clip value (train_split does the same).
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
            optimizer.step()
            if obj == "si":
                # clip_grad_norm_ rescales grads in place, so SI sees the clipped gradients
                model.SI.update_parameter_importance()
            running_loss += loss.item()
        if is_print:
            print(f"Epoch {epoch}: Loss: {running_loss/n_data}, KL-Div : {running_kldiv/n_data}")

def train_split(model, loader, optimizer, num_epochs, device, clip_value, task_id, obj="nll", is_print=True,
                lambda_ewc=20, lambda_si=0.2):
    """Train a multi-head model on task ``task_id`` (heads are selected via ``task_id``).

    Args:
        model: model exposing ``nll(x, y, task_id)``, ``elbo(x, y, n_data, task_id)``,
            ``calculate_ewc()`` and an ``SI`` slot.
        loader: DataLoader of ``(inputs, targets)`` for the current task.
        optimizer: optimiser over ``model.parameters()``.
        num_epochs: number of passes over ``loader``.
        device: device the batches are moved to.
        clip_value: max global L2 norm of the gradients.
        task_id: index of the current task / head.
        obj: ``"nll"``, ``"elbo"``, ``"ewc"`` or ``"si"``.
        is_print: print a per-epoch loss line.
        lambda_ewc, lambda_si: weights of the EWC / SI penalties.
    """
    model.train()
    n_data = get_total_elements(loader)
    uses_si = obj == "si"
    if uses_si:
        if task_id == 0:
            model.SI = SynapticIntelligence(model)
            model.SI.init_SI()
        model.SI.prepare_w_P_SI()

    for epoch in range(num_epochs):
        running_loss = 0.
        running_kldiv = 0.  # never accumulated; printed as 0
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            if obj == "elbo":
                loss = model.elbo(inputs, targets, n_data, task_id)
            elif obj in ("nll", "ewc", "si"):
                loss = model.nll(inputs, targets, task_id)
                if obj == "ewc":
                    loss = loss + lambda_ewc * model.calculate_ewc()
                elif uses_si and task_id != 0:
                    loss = loss + lambda_si * model.SI.surrogate_l()
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
            optimizer.step()
            if uses_si:
                model.SI.update_parameter_importance()
            running_loss += loss.item()
        if is_print:
            print(f"Epoch {epoch}: Loss: {running_loss/n_data}, KL-Div : {running_kldiv/n_data}")

def test(model, loader, device, task_id=None, n_samples=100, dataset='split'):
    """Evaluate ``model`` on ``loader``.

    Args:
        model: a single-head model (called as ``model(x, T=...)``) or, when
            ``task_id`` is given, a multi-head model (``model(x, task_id, n_samples)``).
        loader: DataLoader of ``(inputs, targets)``.
        device: device the model and batches are moved to.
        task_id: head to evaluate for multi-head models; None for single-head models.
        n_samples: Monte-Carlo samples for multi-head models and the Gaussian output.
        dataset: unused; kept for interface compatibility.
    Returns:
        Accuracy for categorical outputs; ``(mean RMSE, accuracy)`` for Gaussian outputs.
    Raises:
        ValueError: for an unknown output type.
    """
    model.to(device)
    model.eval()
    # Models without an ``output`` attribute are treated as categorical.
    output = getattr(model, "output", "categorical")
    if output not in ("categorical", "gaussian"):
        raise ValueError(f"unknown output type {output!r}")
    n_data = get_total_elements(loader)
    correct = 0
    rmses = []
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            if output == "categorical":
                if task_id is not None:
                    outputs = model(inputs, task_id, n_samples)
                else:
                    # NOTE(review): single-head models use a fixed 10 samples; n_samples is ignored here.
                    outputs = model(inputs, T=10)
                if outputs.dim() == 3:  # (T, batch, classes): average over samples
                    outputs = outputs.mean(dim=0)
                predicted = outputs.argmax(dim=1)
            else:
                rmses.append(model.rmse_gaussian(inputs, targets).item())
                mu, _ = model(inputs, n_samples)
                predicted = mu.mean(dim=0).argmax(dim=1)
            correct += (predicted == targets).sum().item()
    if output == "categorical":
        return correct / n_data
    return np.mean(rmses), correct / n_data

##Core VCL

In [100]:
def run_vcl(model, train_loaders, coreset_loaders, test_loaders, train_loaders1, epochs,
            epochs_coresets, device, lr=0.001, dataset="permuted", obj="elbo", clip_value=10.,
            lambda_ewc=200.):
    """Run the continual-learning loop over all tasks.

    For each task: train ``model`` with a fresh Adam optimiser, set its prior to
    the posterior, fine-tune a deep copy on the coreset(s), evaluate the copy on
    every task seen so far, log the average accuracy to W&B, and finally update
    the EWC Fisher / SI importance state of ``model``.

    Returns:
        ``(ACC, RMSE, model)``, where ``ACC[i][k]`` / ``RMSE[i][k]`` are the metrics
        on task ``k`` after training on task ``i``.
    """
    n_tasks = len(train_loaders)
    ACC = np.zeros((n_tasks, n_tasks))
    RMSE = np.zeros((n_tasks, n_tasks))
    model.to(device)
    # Coreset fine-tuning keeps the ELBO only when it is the main objective.
    coreset_obj = "elbo" if obj == "elbo" else "nll"
    is_permuted = dataset == "permuted"
    is_split = dataset == "split"

    for task in range(n_tasks):
        # 1) train on the current task
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        if is_permuted:
            train(model, train_loaders[task], optimizer, epochs, device, task_id=task, obj=obj,
                  clip_value=clip_value, lambda_ewc=lambda_ewc)
        elif is_split:
            train_split(model, train_loaders[task], optimizer, epochs, device, task_id=task, obj=obj,
                        clip_value=clip_value, lambda_ewc=lambda_ewc)
        model.update_prior()

        # 2) fine-tune a copy on the coreset(s); only the copy is evaluated
        eval_model = copy.deepcopy(model)
        optimizer = torch.optim.Adam(eval_model.parameters(), lr=lr)
        if is_permuted:
            train(eval_model, coreset_loaders[task], optimizer, epochs_coresets, device, task_id=task,
                  obj=coreset_obj, is_print=False, clip_value=clip_value)
        elif is_split:
            for k in range(task + 1):
                train_split(eval_model, coreset_loaders[k], optimizer, epochs_coresets, device, task_id=k,
                            obj=coreset_obj, is_print=False, clip_value=clip_value)

        # 3) evaluate on every task seen so far
        if is_permuted:
            if model.output == "gaussian":
                for k in range(task + 1):
                    RMSE[task][k], ACC[task][k] = test(eval_model, test_loaders[k], device)
                print("RMSE", RMSE[task])
                print("ACC", ACC[task])
            elif model.output == "categorical":
                for k in range(task + 1):
                    ACC[task][k] = test(eval_model, test_loaders[k], device)
                print("ACC", ACC[task])
        elif is_split:
            for k in range(task + 1):
                ACC[task][k] = test(eval_model, test_loaders[k], device, k)
        wandb.log({"average_accuracy": ACC[task][0:task + 1].mean()})

        # 4) consolidate regulariser state before the next task
        if obj == "ewc":
            if is_permuted:
                model.calculate_fim(train_loaders1[task])
            if is_split:
                model.calculate_fim(train_loaders1[task], task_id=task)
        elif obj == "si":
            model.SI.update_omega()
            model.SI.divide_omega()

    return ACC, RMSE, model


#Experiments

In [None]:
# Seed both RNGs used by this notebook for reproducibility, then authenticate with Weights & Biases.
np.random.seed(0)
torch.manual_seed(0)
wandb.login()

In [None]:
# Expects mnist.pkl.gz under sample_data/ (relative to the working directory); it is available at
# https://github.com/mnielsen/neural-networks-and-deep-learning/blob/master/data/mnist.pkl.gz
# NOTE: pickle.load can execute arbitrary code -- only use a trusted copy of this file.
MNIST_PATH = 'sample_data/mnist.pkl.gz'
with gzip.open(MNIST_PATH, 'rb') as f:
    train_set, valid_set, test_set = pickle.load(f, encoding='latin1')

In [145]:
# --- optimisation ---
epochs = 30
batch_size = 512
lr = 0.002

# --- variational parameters (rho is a log-variance: sigma = exp(0.5 * rho)) ---
init_rho = -4.0
prior_rho = -0.0

# --- continual-learning setup ---
n_tasks = 10
dataset = "permuted"
obj = "si"
output = "gaussian"
bayesian = True

# --- coresets and regularisers ---
epochs_coresets = 0
k_center_bool = False
n_fisher_samples = 3000

In [14]:
# Finish the currently active W&B run, if any (e.g. one left open by an earlier execution of a cell).
wandb.finish()

In [28]:
coreset_size = 200
# Hyperparameters and metadata recorded with the W&B run.
config = dict(
    epochs=epochs,
    init_rho=init_rho,
    prior_rho=prior_rho,
    batch_size=batch_size,
    n_tasks=n_tasks,
    lr=lr,
    coreset_size=coreset_size,
    epochs_coresets=epochs_coresets,
    n_fisher_samples=n_fisher_samples,
    obj=obj,
    k_center_bool=k_center_bool,
)

(train_loaders, test_loaders, coreset_loaders, train_loaders1,
 train_loadersk, kcoreset_loaders, train_loadersk1) = loaders(
    n_tasks, coreset_size, batch_size=batch_size, dataset=dataset)

In [None]:
torch.autograd.set_detect_anomaly(False)
run = wandb.init(
    project="new_rho_influence",  # W&B project to log to
    config=config,
)
for a in [-5.]:
    print("init_rho", a, "prior_rho", prior_rho)
    print("__________________-")
    # NOTE(review): the model uses init_rho=a, while config["init_rho"] logs the global init_rho.
    model = BayesianFC(784, 20, 100, 2, device=device, init_rho=a, prior_rho=prior_rho,
                       n_fisher_samples=n_fisher_samples, dist="gaussian", output=output, bayesian=bayesian)

    # run_vcl returns (ACC, RMSE, model); unpack it on both branches so ACC2/RMSE are always set.
    if k_center_bool:
        ACC2, RMSE, model = run_vcl(
            model, train_loadersk, kcoreset_loaders, test_loaders, train_loadersk1, epochs, epochs_coresets,
            device, lr=lr, obj=obj
        )
    else:
        ACC2, RMSE, model = run_vcl(
            model, train_loaders, coreset_loaders, test_loaders, train_loaders1, epochs, epochs_coresets,
            device, lr=lr, obj=obj
        )
print(ACC2)
print(RMSE)
wandb.finish()


