#Imports

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torch.nn as nn
import pickle
import gzip
import copy
import wandb

In [None]:
# Global compute device: the GPU when CUDA is available, otherwise the CPU.
# MultiGenerativeBayesianFC reads this module-level name directly.
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")

#Misc

In [None]:
def KL_DIV(mu0,rho0,mu1,rho1):
    """Summed KL divergence KL(N(mu0, s0^2) || N(mu1, s1^2)) of factorised Gaussians.

    ``rho0`` / ``rho1`` are log-standard-deviations (``s = exp(rho)``).  The
    ``1e-12`` terms only guard the divisions against a zero denominator.
    Returns a scalar tensor (sum over all elements).
    """
    sigma_q = torch.exp(rho0)
    sigma_p = torch.exp(rho1)
    variance_ratio = (sigma_q/(1e-12+sigma_p))**2
    mean_shift = (mu0-mu1)**2/(1e-12+sigma_p**2)
    log_ratio = 2*torch.log(sigma_p/(1e-12+sigma_q))
    elementwise = 0.5*(variance_ratio+mean_shift-1+log_ratio)
    return elementwise.sum()


In [None]:
class BayesianLinear(nn.Module):
    """Linear layer with a factorised posterior over its weights and bias.

    Every weight/bias has a trainable posterior mean (``post_*_mu``) and a
    trainable posterior log-standard-deviation (``post_*_rho``; the standard
    deviation is ``exp(rho)``).  The matching prior tensors (``prior_*``) are
    not trainable.  They are registered as non-persistent buffers so that the
    inherited ``Module.to`` / ``Module.cuda`` (also when called on a parent
    module) move them together with the parameters, while ``state_dict()``
    keeps containing only the posterior parameters.

    Args:
        input_size: number of input features.
        output_size: number of output features.
        device: device on which the prior tensors are created.
        init_rho: initial value of every posterior ``rho`` (int or float).
        prior_rho: value of every prior ``rho`` (int or float).
        dist: posterior family used when sampling, ``"gaussian"`` or
            ``"radial"``.

    Raises:
        ValueError: if ``dist`` is not one of the supported families.
    """

    _SUPPORTED_DISTS = ("gaussian", "radial")

    def __init__(self, input_size, output_size,device,init_rho,prior_rho,dist="gaussian"):
        super(BayesianLinear, self).__init__()
        if dist not in self._SUPPORTED_DISTS:
            raise ValueError(
                f"dist must be one of {self._SUPPORTED_DISTS}, got {dist!r}")
        self.input_size = input_size
        self.output_size = output_size
        self.device=device
        self.dist=dist

        # torch.full infers an integer dtype from an integer fill value, and
        # integer tensors cannot require gradients: force floats.
        init_rho = float(init_rho)
        prior_rho = float(prior_rho)

        self.post_w_mu = nn.Parameter(torch.empty(output_size, input_size))
        self.post_b_mu = nn.Parameter(torch.empty(output_size))

        self.post_w_rho = nn.Parameter(torch.full((output_size, input_size), init_rho))
        self.post_b_rho = nn.Parameter(torch.full((output_size,), init_rho))

        self.register_buffer(
            "prior_w_mu", torch.zeros(output_size, input_size, device=device),
            persistent=False)
        self.register_buffer(
            "prior_b_mu", torch.zeros(output_size, device=device),
            persistent=False)
        self.register_buffer(
            "prior_w_rho",
            torch.full((output_size, input_size), prior_rho, device=device),
            persistent=False)
        self.register_buffer(
            "prior_b_rho", torch.full((output_size,), prior_rho, device=device),
            persistent=False)

        self.init_parameters()

    def init_parameters(self):
        """Draw posterior (and prior) means from a truncated normal, std 0.1."""
        nn.init.trunc_normal_(self.post_w_mu,std=0.1)
        nn.init.trunc_normal_(self.post_b_mu,std=0.1)
        # NOTE(review): the prior means were just created as zeros and are
        # randomised here as well -- kept as in the original, confirm intent.
        nn.init.trunc_normal_(self.prior_w_mu,std=0.1)
        nn.init.trunc_normal_(self.prior_b_mu,std=0.1)

    def forward(self,x):
        """Return one stochastic evaluation ``x @ W^T + b`` of the layer."""
        a = self.sample_activation(x)
        b = self.sample_b(x)
        return a+b

    def _sample_radial(self, mu, rho, n_samples=1):
        """Draw ``n_samples`` radial samples around ``mu``.

        Each sample is ``mu + r * exp(rho) * eps / ||eps||`` with
        ``eps ~ N(0, I)`` and an independent scalar ``r ~ N(0, 1)`` per
        sample, both taken from the torch RNG.  Returns a tensor of shape
        ``(n_samples, *mu.shape)``.
        """
        trailing = (1,) * mu.dim()
        eps = torch.randn((n_samples, *mu.shape), device=mu.device, dtype=mu.dtype)
        eps = eps / torch.linalg.norm(eps.flatten(1), dim=1).view(-1, *trailing)
        radius = torch.randn((n_samples, *trailing), device=mu.device, dtype=mu.dtype)
        return mu.unsqueeze(0) + radius * eps * torch.exp(rho)

    def sample_activation(self,x):
        """Sample the bias-free pre-activation ``x @ W^T``.

        Gaussian: the activation is sampled directly from its mean and
        standard deviation.  The noise has shape
        ``(x.shape[0], 1, output_size)``, so ``x`` is expected to be 3-D and
        the noise is shared along dimension 1.
        Radial: one weight matrix is sampled and applied to the whole input.
        """
        if self.dist=="gaussian":
            act_mean = F.linear(x,self.post_w_mu)
            std = torch.exp(self.post_w_rho) #(output_size,input_size)
            act_std = torch.sqrt(1e-8+F.linear((x**2),(std**2)))
            act_eps = torch.randn(x.shape[0],1,self.output_size,device=act_mean.device)
            return act_mean+act_std*act_eps
        # radial: use exp(rho) as the scale, like every other sampling path
        weights = self._sample_radial(self.post_w_mu, self.post_w_rho)[0]
        return F.linear(x,weights)

    def sample_weights_radial(self,n_samples):
        """Return ``n_samples`` radial weight samples, shape ``(n_samples, out, in)``."""
        return self._sample_radial(self.post_w_mu, self.post_w_rho, n_samples)

    def Lcrossentropy(self,n_samples):
        """Monte-Carlo estimate of ``-0.5 * E||(w - prior_mu) / prior_std||^2``
        over ``n_samples`` radial weight samples."""
        scaled = (self.sample_weights_radial(n_samples)-self.prior_w_mu.unsqueeze(0))/torch.exp(self.prior_w_rho.unsqueeze(0))
        norms = torch.linalg.norm(scaled,dim=(1,2))**2
        return -0.5*torch.mean(norms)

    def sample_b(self,x):
        """Sample the bias.

        Gaussian: shape ``(x.shape[0], 1, output_size)``.
        Radial: one bias vector of shape ``(output_size,)``.
        """
        if self.dist=="gaussian":
            b_eps = torch.randn(x.shape[0],1,self.output_size,device=self.post_b_mu.device)
            return self.post_b_mu.view(1,1,-1)+torch.exp(self.post_b_rho.view(1,1,-1))*b_eps
        return self._sample_radial(self.post_b_mu, self.post_b_rho)[0]

    def update_prior(self):
        """Replace the prior by a detached copy of the current posterior."""
        self.prior_w_mu=self.post_w_mu.detach().clone()
        self.prior_w_rho=self.post_w_rho.detach().clone()
        self.prior_b_mu=self.post_b_mu.detach().clone()
        self.prior_b_rho=self.post_b_rho.detach().clone()

    def kl_div(self):
        """KL(posterior || prior), summed over weights and bias."""
        return(KL_DIV(self.post_w_mu,self.post_w_rho,self.prior_w_mu,self.prior_w_rho)+KL_DIV(self.post_b_mu,self.post_b_rho,self.prior_b_mu,self.prior_b_rho))


In [None]:
class MultiGenerativeBayesianFC(nn.Module):
    """Multi-task generator with per-task encoders/heads and shared Bayesian layers.

    Each task owns a deterministic encoder (stack of ``nn.Linear``) and a
    stack of ``BayesianLinear`` generator-head layers; all tasks share the
    final ``BayesianLinear`` layers in ``generator_shared``.  The shared layers
    are the only ones covered by :meth:`kl_div` and by the Fisher / EWC
    bookkeeping (parameters are selected by the ``generator_shared`` name
    prefix).

    Tensor placement relies on the module-level ``device``.

    Args:
        input_size: flattened input dimension.
        hidden_size: width of every hidden layer.
        n_hidden: number of head layers and number of shared layers; the
            encoder has ``2*n_hidden`` linear layers.
        z_size: latent dimension (the encoder outputs ``2*z_size`` values).
        init_rho: initial posterior ``rho`` of the Bayesian layers.
        prior_rho: prior ``rho`` of the Bayesian layers.
        n_tasks: number of encoders / generator heads.
        n_fisher_samples: maximum number of batches used by
            :meth:`calculate_fisher`.
    """

    def __init__(self,input_size=784,hidden_size=500,n_hidden=2,z_size=50,init_rho=0,prior_rho=0,n_tasks=5,n_fisher_samples=3000):
        super(MultiGenerativeBayesianFC,self).__init__()
        self.device=device
        self.encoders=nn.ModuleList([nn.ModuleList() for i in range (n_tasks)])
        dims_encoder=[input_size]+(n_hidden*2-1)*[hidden_size]+[2*z_size]

        for d_in,d_out in zip(dims_encoder[:-1],dims_encoder[1:]):
            for encoder in self.encoders:
                encoder.append(nn.Linear(d_in,d_out,device=self.device))

        self.generator_heads=nn.ModuleList([nn.ModuleList() for i in range (n_tasks)])
        dims_genheads=[z_size]+n_hidden*[hidden_size]

        for d_in,d_out in zip(dims_genheads[:-1],dims_genheads[1:]):
            for head in self.generator_heads:
                head.append(BayesianLinear(input_size=d_in,output_size=d_out,init_rho=init_rho,prior_rho=prior_rho,device=self.device))

        self.generator_shared=nn.ModuleList()
        dims_genshared=n_hidden*[hidden_size]+[input_size]

        for d_in,d_out in zip(dims_genshared[:-1],dims_genshared[1:]):
            self.generator_shared.append(BayesianLinear(input_size=d_in,output_size=d_out,init_rho=init_rho,prior_rho=prior_rho,device=self.device))

        self.n_fishersamples=n_fisher_samples
        self.fisher_dicts=[]    # one {param name: Fisher estimate} dict per finished task
        self.param_history=[]   # one {param name: parameter snapshot} dict per finished task
        self.SI=None            # set by train_split when obj == "si"

    def sample_gaussian(self,mu,rho):
        """Reparameterised sample ``mu + exp(rho) * eps`` with ``eps ~ N(0, I)``."""
        eps = torch.randn(mu.shape,device=mu.device)
        return mu+torch.exp(rho)*eps

    def encode(self,x,id_task):
        """Run the encoder of task ``id_task`` on a 3-D input.

        Returns the two halves of the last dimension of the encoder output as
        ``(mu, logvar)``.  Despite its name the second half is consumed as a
        log-standard-deviation by :meth:`sample_gaussian` and ``KL_DIV``
        (both use ``exp(rho)``).
        """
        encoder=self.encoders[id_task]
        last=len(encoder)-1
        for i,layer in enumerate(encoder):
            x=layer(x)
            if i<last:
                x=F.leaky_relu(x)
        half=x.shape[2]//2
        return x[:,:,:half],x[:,:,half:]

    def decode(self,mu,logvar,id_task):
        """Sample a latent from ``(mu, logvar)`` and decode it.

        Returns ``(reconstruction, latent)``; the reconstruction goes through
        a sigmoid, all earlier generator layers through ReLU.
        """
        z_int=self.sample_gaussian(mu,logvar)
        z=z_int
        for layer in self.generator_heads[id_task]:
            z=F.relu(layer(z))
        last=len(self.generator_shared)-1
        for i,layer in enumerate(self.generator_shared):
            z=layer(z)
            z=F.relu(z) if i<last else torch.sigmoid(z)
        return z,z_int

    def forward(self,x,id_task,n_samples=1):
        """Reconstruct ``x`` with the encoder/head of task ``id_task``.

        A 2-D input is tiled to ``(n_samples, batch, features)``; an input
        that is already 3-D is used as is (``n_samples`` is then ignored).
        """
        if len(x.shape)==2:
            x=x.unsqueeze(0).tile((n_samples,1,1))
        mu,logvar=self.encode(x,id_task)
        xout,z_int=self.decode(mu,logvar,id_task)
        return xout

    def kl_div(self,n_data,id_task):
        """KL(posterior || prior) of the shared generator layers divided by ``n_data``.

        NOTE(review): the task-specific head layers are not included and
        ``id_task`` is unused -- confirm this is intended.
        """
        total_kl = 0
        for layer in self.generator_shared:
            total_kl += layer.kl_div()
        return total_kl/n_data

    def logprob(self,x,id_task,n_samples,n_data,epsilon=1e-8):
        """Per-sample Bernoulli log-likelihood minus the latent KL term.

        ``x`` is a 2-D batch; the encoder output is tiled ``n_samples`` times
        before decoding.  The result is averaged over
        ``batch_size * n_samples``.  ``n_data`` is accepted for interface
        symmetry but not used here.
        """
        batch_size=x.shape[0]
        x=x.unsqueeze(0)
        mu,logvar=self.encode(x,id_task)
        mu=mu.tile((n_samples,1,1))
        logvar=logvar.tile((n_samples,1,1))
        xout,z_int=self.decode(mu,logvar,id_task)
        KL=KL_DIV(mu,logvar,torch.zeros_like(mu),torch.zeros_like(logvar))
        prob=torch.mul(torch.log(torch.clip(xout,min=1e-9,max=1)),x)
        inv_prob=torch.mul(torch.log(torch.clip(1-xout,min=1e-9,max=1)),1-x)
        # replace NaNs (x != x) by epsilon
        inv_prob[torch.isnan(inv_prob)] = epsilon
        return (torch.sum(torch.add(prob,inv_prob))-KL)/(batch_size*n_samples)

    def elbo(self,x,id_task,n_data,n_samples=10,epsilon=1e-8):
        """Return ``(nll + kl, kl)`` where ``kl`` is :meth:`kl_div`."""
        nll=-self.logprob(x,id_task,n_samples,n_data,epsilon)
        kl=self.kl_div(n_data,id_task)
        return nll+kl,kl

    def nll(self,x,id_task,n_data,n_samples=10,epsilon=1e-8):
        """Negative of :meth:`logprob`."""
        return -self.logprob(x,id_task,n_samples,n_data,epsilon)

    def calculate_fisher(self,train_loader1,task_id):
        """Estimate the diagonal Fisher information of the shared layers.

        Averages the squared gradient of :meth:`nll` over at most
        ``n_fishersamples`` batches of ``train_loader1`` and stores the result
        in ``fisher_dicts`` together with a snapshot of the shared parameters
        in ``param_history``.  The average divides by the number of batches
        actually processed (the original divided by the last batch index,
        which is one too small when the loader is exhausted and zero for a
        single batch).  Leaves the module in training mode.
        """
        n_data=get_total_elements(train_loader1)
        shared=[(n,p) for n,p in self.named_parameters() if n.startswith("generator_shared")]
        fisher_dic={n: p.detach().clone().zero_() for n,p in shared}
        self.eval()
        n_batches=0
        for x, y in train_loader1:
            if n_batches >= self.n_fishersamples:
                break
            x=x.to(self.device)
            loss = self.nll(x,task_id,n_data)
            self.zero_grad()
            loss.backward()
            for n, p in shared:
                if p.requires_grad and p.grad is not None:
                    fisher_dic[n] += p.grad.detach() ** 2
            n_batches+=1

        if n_batches:
            fisher_dic = {n: f/n_batches for n,f in fisher_dic.items()}
        self.fisher_dicts.append(fisher_dic)
        self.param_history.append({n: p.detach().clone() for n,p in shared})
        self.train()

    def calculate_ewc(self):
        """EWC penalty ``0.5 * sum_t sum_i F_t,i * (theta_i - theta_t,i)^2`` over
        the shared layers and all stored tasks (``0.0`` when nothing is stored)."""
        Lloss = []
        for fisher,snapshot in zip(self.fisher_dicts,self.param_history):
            for n, p in self.named_parameters():
                if n.startswith("generator_shared") and p.requires_grad:
                    Lloss.append(torch.sum(fisher[n]*(p-snapshot[n])**2))

        return (1./2)*sum(Lloss)

    def update_prior(self):
        """Set the prior of every Bayesian layer to its current posterior."""
        for head in self.generator_heads:
            for layer in head:
                layer.update_prior()
        for layer in self.generator_shared:
            layer.update_prior()


class Classifier(nn.Module):
    """Plain ReLU MLP that outputs log-probabilities (``log_softmax`` over dim 1).

    The network is ``input_size -> hidden_size`` followed by ``n_hidden-1``
    square hidden layers and a final ``hidden_size -> output_size`` layer, all
    created on ``device``.
    """

    def __init__(self,device,input_size=784,hidden_size=500,n_hidden=2,output_size=10):
        super(Classifier,self).__init__()
        # there is always at least one hidden block, even for n_hidden < 1
        widths=[input_size]+max(n_hidden,1)*[hidden_size]+[output_size]
        self.layers=nn.ModuleList()
        for fan_in,fan_out in zip(widths[:-1],widths[1:]):
            self.layers.append(nn.Linear(fan_in,fan_out,device=device))

    def forward(self,x):
        *hidden_layers,output_layer=self.layers
        for hidden in hidden_layers:
            x=F.relu(hidden(x))
        return F.log_softmax(output_layer(x),dim=1)

In [None]:
class SynapticIntelligence():
    """Synaptic-Intelligence bookkeeping for the ``generator_shared`` parameters.

    Only trainable parameters of ``model`` whose name starts with
    ``generator_shared`` are tracked.

    Attributes:
        prev_SI: parameter values at the end of the previous task.
        omega_SI: accumulated per-parameter importance.
        W: running path integral ``sum(-grad * delta)`` for the current task.
        p_old: parameter values after the previous optimisation step.

    ``si_c`` and ``gamma`` are stored but not used by any method of this
    class; ``epsilon`` is the damping term in :meth:`update_omega`.
    """

    _PREFIX = "generator_shared"

    def __init__(self,model,si_c=0.1,epsilon=0.1,gamma=0.9):
        self.prev_SI={}
        self.omega_SI={}
        self.W={}
        self.p_old={}
        self.model=model
        self.si_c = si_c
        self.epsilon = epsilon
        self.gamma=gamma

    def _tracked_parameters(self):
        """Yield ``(name, parameter)`` for every tracked trainable parameter."""
        for n, p in self.model.named_parameters():
            if n.startswith(self._PREFIX) and p.requires_grad:
                yield n, p

    def init_SI(self):
        """Record the current parameter values as the reference point."""
        for n, p in self._tracked_parameters():
            self.prev_SI[n] = p.data.clone()

    def prepare_w_P_SI(self):
        """Reset the path integral ``W`` and the last-step snapshot ``p_old``."""
        for n, p in self._tracked_parameters():
            self.W[n] = p.data.clone().zero_()
            self.p_old[n] = p.data.clone()

    def update_parameter_importance(self):
        """Accumulate ``-grad * (p - p_old)`` into ``W`` and refresh ``p_old``.

        Meant to be called right after an optimiser step, while ``p.grad``
        still holds the gradient used for that step.
        """
        for n, p in self._tracked_parameters():
            if p.grad is not None:
                self.W[n].add_(-p.grad*(p.detach()-self.p_old[n]))
            self.p_old[n] = p.detach().clone()

    def update_omega(self):
        """Fold the finished task into ``omega_SI``.

        Adds ``W / ((p - prev)^2 + epsilon)`` to the stored importance and
        moves the reference point ``prev_SI`` to the current parameters.
        """
        for n, p in self._tracked_parameters():
            p_current = p.detach().clone()
            p_change = p_current - self.prev_SI[n]
            omega_add = self.W[n]/(p_change**2 + self.epsilon)
            omega = self.omega_SI.get(n)
            if omega is None:
                omega = p.detach().clone().zero_()
            self.prev_SI[n] = p_current
            self.omega_SI[n] = omega + omega_add

    def divide_omega(self):
        """Placeholder for an optional decay of ``omega_SI`` (currently a no-op)."""
        pass

    def surrogate_l(self):
        """Return the SI penalty ``sum(omega * (p - prev)^2)``.

        When the reference values or importances of a tracked parameter are
        missing (``init_SI`` / ``update_omega`` not called yet) it prints
        ``KeyError`` and returns a zero tensor on that parameter's device.
        The original fallback called an undefined ``self._device()`` and
        raised ``AttributeError`` instead.
        """
        losses = []
        for n, p in self._tracked_parameters():
            if n not in self.prev_SI or n not in self.omega_SI:
                print("KeyError")
                return torch.tensor(0., device=p.device)
            losses.append((self.omega_SI[n] * (p-self.prev_SI[n])**2).sum())
        return sum(losses)


#Datasets

In [None]:
def permute_MNIST(x_train,y_train,x_test,y_test,task_id):
    """Apply the task-specific column permutation to train and test inputs.

    The global NumPy RNG is seeded with ``task_id`` so every task always gets
    the same permutation; labels are returned unchanged.
    """
    np.random.seed(task_id)
    n_features = x_train.shape[1]
    column_order = np.random.permutation(n_features)
    shuffled_train = x_train[:, column_order]
    shuffled_test = x_test[:, column_order]
    return shuffled_train, y_train, shuffled_test, y_test

def split_MNIST(x_train,y_train,x_test,y_test,task_id):
    """Select the two digit classes of split-MNIST task ``task_id``.

    Task ``t`` contains the digits ``(2t, 2t+1)``.  Returns the matching
    train/test inputs together with binary labels that are ``1`` for the first
    digit of the pair and ``0`` for the second.

    The original derived the positive class from whichever digit happened to
    come first in the training data, so the label meaning depended on the data
    order and an empty selection raised ``IndexError``; the labels are now
    fixed by the task definition.
    """
    task_classes=((0,1),(2,3),(4,5),(6,7),(8,9))
    first_class,second_class=task_classes[task_id]
    idxs_train=np.where(np.logical_or(y_train==first_class,y_train==second_class))[0]
    idxs_test=np.where(np.logical_or(y_test==first_class,y_test==second_class))[0]
    labels_train=1*(y_train[idxs_train]==first_class)
    labels_test=1*(y_test[idxs_test]==first_class)
    return x_train[idxs_train],labels_train,x_test[idxs_test],labels_test

def gen_mnist(x_train,y_train,x_test,y_test,task_id):
    """Keep only the samples whose label equals ``task_id`` (one digit per task)."""
    train_rows = np.nonzero(y_train == task_id)[0]
    test_rows = np.nonzero(y_test == task_id)[0]
    train_part = (x_train[train_rows], y_train[train_rows])
    test_part = (x_test[test_rows], y_test[test_rows])
    return train_part + test_part


''' Coreset methods '''

def random_coreset(x_train, y_train, x_coreset, y_coreset, nb_coreset,dataset="split"):
    """Move ``nb_coreset`` uniformly drawn samples from the training set to the coreset.

    The drawn samples are appended (as one array each) to the ``x_coreset`` /
    ``y_coreset`` lists, which are modified in place and returned alongside
    the reduced training arrays.  ``dataset`` is not used.
    """
    picked = np.random.choice(x_train.shape[0], nb_coreset, False)

    x_coreset.append(x_train[picked,:])
    y_coreset.append(y_train[picked])

    remaining_x = np.delete(x_train, picked, axis=0)
    remaining_y = np.delete(y_train, picked, axis=0)
    return remaining_x, remaining_y, x_coreset, y_coreset

def k_center(x_train, y_train, x_coreset, y_coreset, nb_coreset,dataset="split"):
    """Pick a coreset by greedy farthest-point (k-center) traversal.

    The first centre is the sample farthest from row 0; every following
    centre is the sample whose distance to its *nearest* already chosen centre
    is largest.  (The original only looked at the distance to the most
    recently chosen sample, so selections bounced between a few extremes.)
    At most ``x_train.shape[0]`` samples are chosen and no sample is chosen
    twice.

    The chosen samples are appended to ``x_coreset`` / ``y_coreset`` (modified
    in place).

    Returns:
        The list of chosen row indices when ``dataset == "permuted"``;
        otherwise ``(x_train, y_train, x_coreset, y_coreset)`` with the chosen
        rows removed from the training arrays.
    """
    n_select = min(nb_coreset, x_train.shape[0])
    ids = []
    if n_select > 0:
        # distance of every sample to its nearest centre; seeded with row 0
        nearest = dists(x_train, 0)
        for _ in range(n_select):
            centre = int(np.argmax(nearest))
            to_centre = dists(x_train, centre)
            # the seed row is only a reference point, not a centre: drop it
            # from the running minimum once the first real centre exists
            nearest = np.minimum(nearest, to_centre) if ids else to_centre
            nearest[centre] = -np.inf   # never pick the same sample again
            ids.append(centre)

    x_coreset.append(x_train[ids,:])
    y_coreset.append(y_train[ids])

    x_train = np.delete(x_train, ids, axis=0)
    y_train = np.delete(y_train, ids, axis=0)

    if dataset=="permuted":
        return ids
    return x_train, y_train, x_coreset, y_coreset

def dists(x_train, id):
    """Euclidean distance of every row of ``x_train`` to row ``id`` (float array)."""
    return np.linalg.norm(x_train-x_train[id,:],axis=1).astype(float, copy=False)

def get_total_elements(loader):
    """Count the samples in ``loader`` by summing the first dimension of every batch."""
    return sum(batch.shape[0] for batch, _ in loader)


In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset


def dataloader(X, Y, batch_size=128, shuffle=True, num_workers=0, pin_memory=False):
    """Build one ``DataLoader`` per ``(x, y)`` pair of NumPy arrays.

    Inputs are converted to ``float`` tensors and labels to ``long`` tensors;
    the loaders are returned in the order of ``X`` / ``Y``.
    """
    loader_options = dict(batch_size=batch_size, shuffle=shuffle,
                          num_workers=num_workers, pin_memory=pin_memory)
    built = []
    for features, labels in zip(X, Y):
        assert isinstance(features, np.ndarray) and isinstance(labels, np.ndarray), "X and Y must be NumPy arrays"
        feature_tensor = torch.from_numpy(features).float()
        label_tensor = torch.from_numpy(labels).long()
        built.append(DataLoader(TensorDataset(feature_tensor, label_tensor), **loader_options))
    return built


def aggdataloader(X, Y, batch_size=128, shuffle=True, num_workers=0, pin_memory=False):
    """Build cumulative ``DataLoader``s: loader ``i`` covers pairs ``0..i`` concatenated.

    Inputs are converted to ``float`` tensors and labels to ``long`` tensors.
    """
    loader_options = dict(batch_size=batch_size, shuffle=shuffle,
                          num_workers=num_workers, pin_memory=pin_memory)
    cumulative = []
    pooled_x = pooled_y = None
    for position, (x, y) in enumerate(zip(X, Y)):
        if position:
            pooled_x = np.concatenate((pooled_x, x), axis=0)
            pooled_y = np.concatenate((pooled_y, y), axis=0)
        else:
            pooled_x, pooled_y = x, y
        assert isinstance(x, np.ndarray) and isinstance(y, np.ndarray), "X and Y must be NumPy arrays"
        pooled = TensorDataset(torch.from_numpy(pooled_x).float(),
                               torch.from_numpy(pooled_y).long())
        cumulative.append(DataLoader(pooled, **loader_options))
    return cumulative


def loaders(nb_tasks,coreset_size,batch_size=128,data_path='sample_data/mnist.pkl.gz'):
    """Build all per-task data loaders for the one-digit-per-task MNIST setup.

    For every task ``i`` the samples of digit ``i`` are selected with
    ``gen_mnist`` and ``coreset_size`` samples are split off twice,
    independently: once with ``k_center`` and once with ``random_coreset``.

    The original passed the random-coreset lists to ``k_center``, so both
    coreset collections were the same list holding interleaved k-center and
    random coresets; each method now fills its own lists.

    Args:
        nb_tasks: number of tasks (digits ``0 .. nb_tasks-1``).
        coreset_size: number of samples moved to each coreset per task.
        batch_size: batch size of the regular loaders.
        data_path: location of the gzip-pickled MNIST file.

    Returns:
        ``(train_loaders, test_loaders, coreset_loaders, train_loaders1,
        train_loadersk, kcoreset_loaders, train_loadersk1)`` -- lists with one
        loader per task; the ``*1`` variants use batch size 1 and the ``k``
        variants correspond to the k-center split.
    """
    # NOTE: pickle can execute arbitrary code while loading -- only point
    # data_path at a trusted copy of the dataset.
    with gzip.open(data_path, 'rb') as f:
        train_set, valid_set, test_set = pickle.load(f, encoding='latin1')
    X_trains,Y_trains,Xk_trains,Yk_trains,X_tests,Y_tests=[],[],[],[],[],[]
    for i in range(nb_tasks):
        i_x_train,i_y_train,i_x_test,i_y_test=gen_mnist(train_set[0],train_set[1],test_set[0],test_set[1],i)
        X_trains.append(i_x_train)
        Y_trains.append(i_y_train)
        Xk_trains.append(i_x_train)
        Yk_trains.append(i_y_train)
        X_tests.append(i_x_test)
        Y_tests.append(i_y_test)
    X_coresets,Xk_coresets,Y_coresets,Yk_coresets=[],[],[],[]
    for i in range(nb_tasks):
        Xk_trains[i],Yk_trains[i],Xk_coresets,Yk_coresets=k_center(Xk_trains[i],Yk_trains[i],Xk_coresets,Yk_coresets,coreset_size)
        X_trains[i],Y_trains[i],X_coresets,Y_coresets=random_coreset(X_trains[i],Y_trains[i],X_coresets,Y_coresets,coreset_size)
    train_loaders=dataloader(X_trains,Y_trains,batch_size)
    train_loaders1=dataloader(X_trains,Y_trains,1)
    train_loadersk=dataloader(Xk_trains,Yk_trains,batch_size)
    train_loadersk1=dataloader(Xk_trains,Yk_trains,1)
    test_loaders=dataloader(X_tests,Y_tests,batch_size)
    coreset_loaders=dataloader(X_coresets,Y_coresets,batch_size)
    kcoreset_loaders=dataloader(Xk_coresets,Yk_coresets,batch_size)
    return train_loaders,test_loaders,coreset_loaders,train_loaders1,train_loadersk,kcoreset_loaders,train_loadersk1

In [None]:
def train(model,loader,optimizer,num_epochs,device):
    """Train a log-probability classifier with the NLL loss.

    ``model(inputs)`` must return log-probabilities (as ``Classifier`` does).
    Every 10th epoch the mean batch loss of that epoch is printed; the
    ``KL-Div`` field of the message is kept for a uniform log format and is
    always ``0.0`` here because no KL term is computed.

    Changes from the original: ``reduction="mean"`` replaces the deprecated
    ``reduce="mean"`` (which only worked because the string is truthy), the
    logged loss is averaged over batches instead of dividing a sum of batch
    means by the number of samples, and the extra counting pass over the
    loader is gone.
    """
    model.train()

    for epoch in range(num_epochs):
        running_loss = 0.
        running_kldiv = 0.
        n_batches = 0
        for inputs, targets in loader:
            inputs,targets=inputs.to(device),targets.to(device)
            outputs=model(inputs)
            loss=F.nll_loss(outputs,targets,reduction="mean")
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n_batches += 1
        if epoch%10==0:
            denom = max(n_batches,1)   # an empty loader logs 0 instead of crashing
            print(f"Epoch {epoch}: Loss: {running_loss/denom}, KL-Div : {running_kldiv/denom}")

def train_split(model,loader,optimizer,num_epochs,device,clip_value,task_id,obj="nll",is_print=True,lambda_ewc=20,lambda_si=100):
    """Train ``model`` on one task with the chosen objective.

    Args:
        model: model exposing ``nll`` / ``elbo`` / ``calculate_ewc`` and an
            ``SI`` attribute (as ``MultiGenerativeBayesianFC`` does).
        loader: iterable of ``(inputs, targets)`` batches for this task.
        optimizer: optimiser over the model parameters.
        num_epochs: number of passes over ``loader``.
        device: device the batches are moved to.
        clip_value: maximum gradient norm (``clip_grad_norm_``).
        task_id: index of the task / head being trained.
        obj: ``"nll"``, ``"elbo"``, ``"ewc"`` (NLL + ``lambda_ewc`` * EWC
            penalty) or ``"si"`` (NLL + ``lambda_si`` * SI penalty after the
            first task).
        is_print: print the mean batch loss and KL term every 10th epoch.

    Raises:
        ValueError: for an unknown ``obj`` (the original failed later with a
            ``NameError`` on the undefined loss).

    The KL term is accumulated as a Python float (``.item()``) so the running
    sum no longer keeps every batch's autograd graph alive, and the logged
    values are averages over batches.
    """
    if obj not in ("nll","elbo","ewc","si"):
        raise ValueError(f"unknown objective {obj!r}; expected 'nll', 'elbo', 'ewc' or 'si'")
    model.train()
    n_data=get_total_elements(loader)
    if obj=="si":
        if task_id==0:
            model.SI=SynapticIntelligence(model)
            model.SI.init_SI()
        model.SI.prepare_w_P_SI()

    for epoch in range(num_epochs):
        running_loss = 0.
        running_kldiv = 0.
        n_batches = 0
        for inputs, targets in loader:
            n_batches += 1
            inputs, targets = inputs.to(device), targets.to(device)
            if obj=="elbo":
                loss,kl_div=model.elbo(inputs,task_id,n_data)
                running_kldiv += kl_div.item()
            else:
                loss=model.nll(inputs,task_id,n_data)
                if obj=="ewc":
                    loss=loss+lambda_ewc*model.calculate_ewc()
                elif obj=="si" and task_id!=0:
                    loss=loss+lambda_si*model.SI.surrogate_l()
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
            optimizer.step()
            if obj=="si":
                # uses the gradient of the step that was just taken
                model.SI.update_parameter_importance()
            running_loss += loss.item()
        if is_print and epoch%10==0:
            denom = max(n_batches,1)
            print(f"Epoch {epoch}: Loss: {running_loss/denom}, KL-Div : {running_kldiv/denom}")

def test(model,loader,device,classifier,n_plots=10,task_id=None,n_samples=100,dataset='split'):
    """Evaluate the generator of task ``task_id`` on ``loader``.

    For every batch the model reconstruction is scored by ``classifier``
    (NLL against the true labels) and the model's own ``nll`` is recorded.
    The first ``n_plots`` reconstructions of the first batch are plotted.
    ``n_samples`` and ``dataset`` are accepted but not used.

    Returns:
        ``(-mean(nll), mean(classifier NLL), img)`` where ``img`` is the first
        reconstruction of the last batch as a NumPy array (``None`` for an
        empty loader).

    Changes from the original: the local flag no longer shadows the builtin
    ``print`` and ``reduction="mean"`` replaces the deprecated
    ``reduce="mean"``.
    """
    model.to(device)
    model.eval()
    n_data=get_total_elements(loader)
    should_plot=True
    nll=[]
    uncertainty=[]
    img=None
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            xout = model(inputs,task_id,n_samples=1)
            reconstruction = xout.squeeze(0)
            probs=classifier(reconstruction)
            uncertainty.append(F.nll_loss(probs,targets,reduction="mean").item())
            if should_plot:
                plot_digits(reconstruction[0:n_plots])
                should_plot=False
            img=reconstruction[0].cpu().detach().numpy()
            nllval=model.nll(inputs,task_id,n_data)
            nll.append(nllval.item())
    return(-np.mean(nll),np.mean(uncertainty),img)


In [None]:
def run_vcl(
    model, train_loaders, coreset_loaders, test_loaders,train_loaders1, epochs,epochs_coresets, device,classifier,lr=0.001,dataset="permuted",obj="elbo",clip_value = 10.,
lambda_ewc=20.,lambda_si=100.):
    """Train ``model`` sequentially on every task and evaluate after each one.

    After training task ``i`` (with a fresh Adam optimiser) the priors are
    updated, the model is tested on tasks ``0..i`` and the objective-specific
    bookkeeping runs (Fisher estimate for ``"ewc"``, omega update for
    ``"si"``).  ``coreset_loaders``, ``epochs_coresets`` and ``dataset`` are
    accepted but not used.  ``lambda_si`` is new and is forwarded to
    ``train_split`` (it previously could not be set from here).

    Returns:
        ``(LL, U, IMGS)``: arrays indexed ``[trained task, tested task]`` with
        the test log-likelihood, the classifier NLL and one reconstructed
        image (784 values) respectively; untested entries stay zero.
    """
    n_tasks = len(train_loaders)
    LL = np.zeros((n_tasks,n_tasks))
    U=np.zeros((n_tasks,n_tasks))
    IMGS=np.zeros((n_tasks,n_tasks,784))
    model.to(device)

    for i in range(n_tasks):

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        train_split(model, train_loaders[i], optimizer, epochs, device, task_id=i,obj=obj,clip_value=clip_value,lambda_ewc=lambda_ewc,lambda_si=lambda_si)
        model.update_prior()
        for k in range(i+1):
            LL[i][k],U[i][k],IMGS[i][k]=test(model, test_loaders[k], device,classifier,task_id=k)
        if obj=="ewc":
            model.calculate_fisher(train_loaders1[i],task_id=i)
        elif obj=="si":
            model.SI.update_omega()

    return LL,U,IMGS


In [None]:
def plot_digits(x):
    """Show every image in ``x`` side by side.

    ``x`` is a tensor that can be reshaped to ``(-1, 28, 28)``.  One axis is
    created per image; the original always created exactly five axes, which
    raised ``IndexError`` for fewer than five images and silently dropped the
    rest otherwise.  Nothing is drawn for an empty input.
    """
    images=x.reshape(-1,28,28)
    n_plots=images.shape[0]
    if n_plots==0:
        return
    # squeeze=False keeps `axes` two-dimensional even for a single image
    fig, axes = plt.subplots(1, n_plots, figsize=(3*n_plots, 5), squeeze=False)
    for ax, image in zip(axes[0], images):
        ax.imshow(image.detach().cpu().numpy(), cmap="gray")
        ax.axis("off")
    plt.show()

#Experiments

In [None]:
# Make sure the file mnist.pkl.gz is available at sample_data/mnist.pkl.gz
# (download: https://github.com/mnielsen/neural-networks-and-deep-learning/blob/master/data/mnist.pkl.gz).
# NOTE: pickle can execute arbitrary code while loading -- only use a trusted copy of the file.
with gzip.open('sample_data/mnist.pkl.gz', 'rb') as f:
    train_set, valid_set, test_set = pickle.load(f, encoding='latin1')


In [None]:
# Experiment hyper-parameters (notebook globals).
epochs=200              # epochs per task in run_vcl and for the classifier training below
init_rho=-4.8           # reassigned to -5. before the models are built
prior_rho=0.            # prior rho passed to MultiGenerativeBayesianFC
batch_size=1024         # batch size handed to loaders()
n_tasks=10              # number of tasks handed to loaders()
lr=0.01                 # learning rate for the classifier and for run_vcl
coreset_size=200        # samples per task moved to each coreset by loaders()
epochs_coresets=0       # forwarded to run_vcl, which does not use it
n_fisher_samples=3000   # not passed to the model in this notebook (its own default applies)
obj="ewc"               # overwritten before every run_vcl call below
k_center_bool=False     # not referenced by the code visible in this notebook

In [None]:
# Build the per-task loaders once; the names follow the return order of loaders().
train_loaders,test_loaders,coreset_loaders,train_loaders1,train_loadersk,kcoreset_loaders,train_loadersk1=loaders(n_tasks,coreset_size,batch_size=batch_size)

In [None]:
# Train the auxiliary digit classifier on the full MNIST training split.
# It is later handed to run_vcl, where test() uses it to score the generated images.
classifier=Classifier(device)
# dataloader() expects sequences of arrays, hence the single-element wrapping.
loader=dataloader(np.array([train_set[0]]),np.array([train_set[1]]),batch_size=512)[0]
optimizer=torch.optim.Adam(classifier.parameters(),lr=lr)
train(classifier,loader,optimizer,epochs,device)


In [None]:
# Accuracy of the auxiliary classifier on the MNIST test split.
test_loader=dataloader(np.array([test_set[0]]),np.array([test_set[1]]),batch_size=512)[0]
correct=0
n_data=get_total_elements(test_loader)
with torch.no_grad():
  for inputs, targets in test_loader:
    inputs=inputs.to(device)
    targets=targets.to(device)
    xout = classifier(inputs)
    pred = torch.argmax(xout,dim=1)
    correct += (pred==targets).long().sum().item()
print(correct/n_data)

In [None]:
# Run the three continual-learning objectives on freshly initialised models.
# Each run returns the log-likelihood matrix, the classifier-NLL matrix and sample images.
init_rho=-5.
# 1) variational continual learning (ELBO objective)
model=MultiGenerativeBayesianFC(init_rho=init_rho,prior_rho=prior_rho,n_tasks=10)
obj='elbo'
LLvcl,Uvcl,IMGSvcl=run_vcl(
      model, train_loaders, coreset_loaders, test_loaders,train_loaders1, epochs,epochs_coresets, device,classifier,lr=lr,obj=obj
  )
# 2) elastic weight consolidation
model=MultiGenerativeBayesianFC(init_rho=init_rho,prior_rho=prior_rho,n_tasks=10)
obj="ewc"
LLewc,Uewc,IMGSewc=run_vcl(
      model, train_loaders, coreset_loaders, test_loaders,train_loaders1, epochs,epochs_coresets, device,classifier,lr=lr,obj=obj
  )
# 3) synaptic intelligence
model=MultiGenerativeBayesianFC(init_rho=init_rho,prior_rho=prior_rho,n_tasks=10)
obj="si"
LLsi,Usi,IMGSsi=run_vcl(
      model, train_loaders, coreset_loaders, test_loaders,train_loaders1, epochs,epochs_coresets, device,classifier,lr=lr,obj=obj
  )
