<a href="https://colab.research.google.com/github/jvallikivi/mlmi4-vcl/blob/main/v2_get_test_ll.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.distributions import kl_divergence, Normal
from collections import OrderedDict
from torchvision import datasets
from torchvision import transforms

In [None]:
# Use the GPU when one is available. `device` is a notebook-level name that the
# model classes and helper functions in later cells read directly.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Number of test examples per evaluation batch (read by the evaluation cells below).
batch_size = 50

In [None]:
# I used their evaluation function and adjusted for torch for consistency
def gaussian_prob(x, loc=0., log_scale=0.):
    """Log-density of a diagonal Gaussian, summed over the feature axis.

    Args:
        x: 2-D tensor (batch, features) at which the density is evaluated.
        loc: mean of the Gaussian; a float or a tensor broadcastable to ``x``.
        log_scale: log standard deviation; a float or a tensor broadcastable
            to ``x``. The default of 0 gives unit variance.

    Returns:
        1-D tensor (batch,) with ``sum_d log N(x_d | loc_d, exp(log_scale_d)**2)``.
    """
    # Materialise the log-scale on x's own device/dtype so the default works on
    # both CPU and GPU without depending on a module-level `device`.
    log_scale = torch.as_tensor(log_scale, dtype=x.dtype, device=x.device)
    half_log_2pi = 0.5 * float(np.log(2.0 * np.pi))
    log_normaliser = -half_log_2pi - log_scale
    # The standardised residual must be SQUARED (the previous `* 2` merely
    # doubled it, which is not a Gaussian log-density).
    quadratic = -0.5 * ((x - loc) / torch.exp(log_scale)) ** 2
    return (log_normaliser + quadratic).sum(1)

def bernoulli_prob(x, p=torch.Tensor([0.5]).to(device)):
    """Bernoulli log-likelihood of ``x`` under probabilities ``p``, summed over axis 1.

    Both ``p`` and ``1 - p`` are clamped to ``[1e-9, 1]`` before the logarithm
    so that probabilities of exactly 0 or 1 do not produce ``-inf``.

    Args:
        x: 2-D tensor of targets (batch, features).
        p: success probabilities, broadcastable against ``x``.

    Returns:
        1-D tensor (batch,) of summed log-probabilities.
    """
    log_on = torch.clamp(p, 1e-9, 1.0).log()
    log_off = torch.clamp(1.0 - p, 1e-9, 1.0).log()
    per_feature = x * log_on + (1 - x) * log_off
    return per_feature.sum(1)

def evaluate(model, test_batch, n_samples=100):
    """Importance-sampled estimate of the test log-likelihood for one batch.

    The batch is tiled ``n_samples`` times, one latent ``z`` is drawn per tiled
    row from the encoder's Gaussian, and the per-example log-likelihood is
    estimated as ``log mean_k exp(log p(x|z_k) + log p(z_k) - log q(z_k|x))``.

    Args:
        model: object exposing ``encoder_forward(x) -> (loc, log_scale)`` and
            ``decoder_forward(z, sampling=True) -> probabilities``.
        test_batch: 2-D tensor (batch, features); any batch length is accepted.
        n_samples: number of latent samples per example (default 100).

    Returns:
        Tuple ``(mean, var)``: mean and (biased) variance of the per-example
        log-likelihood estimates across the batch, as 0-dim tensors.
    """
    # Use the batch's own length rather than a module-level batch size, so a
    # shorter (e.g. final) batch reshapes correctly.
    n_points = test_batch.shape[0]
    with torch.no_grad():  # pure evaluation: no autograd graph needed
        tiled = test_batch.repeat(n_samples, 1)
        loc, log_scale = model.encoder_forward(tiled)
        z = torch.randn_like(loc) * torch.exp(log_scale) + loc
        gen = model.decoder_forward(z, sampling=True)
        # log q(z|x) - log p(z) for each sampled z
        log_q_minus_log_p = gaussian_prob(z, loc, log_scale) - gaussian_prob(z)
        log_lik = bernoulli_prob(tiled, gen)
        # Row k holds sample k for every example in the batch.
        log_weights = (log_lik - log_q_minus_log_p).reshape((n_samples, n_points))
        # Numerically stable log-mean-exp over the sample axis.
        test_ll = torch.logsumexp(log_weights, dim=0) - float(np.log(n_samples))
        mean = test_ll.mean()
        var = ((test_ll - mean) ** 2).mean()
    return mean, var

In [None]:
class BayesLinear(nn.Module):
    """Linear layer whose weights and biases carry a mean and a log-scale.

    The four ``nn.Parameter`` tensors hold the means (``*_loc``) and log
    standard deviations (``log_*_scale``). ``forward`` does not read them
    directly: it uses the plain tensor attributes ``weight`` and ``bias``,
    which start as copies of the means and can be reassigned by the caller.
    """

    def __init__(self, in_features, out_features):
        super(BayesLinear, self).__init__()
        weight_shape = (in_features, out_features)

        # Learnable parameters, all zero-initialised. Registration order is
        # weight_loc, log_weight_scale, bias_loc, log_bias_scale.
        self.weight_loc = nn.Parameter(torch.zeros(weight_shape))
        self.log_weight_scale = nn.Parameter(torch.zeros(weight_shape))
        self.bias_loc = nn.Parameter(torch.zeros(out_features))
        self.log_bias_scale = nn.Parameter(torch.zeros(out_features))

        # Plain tensors (neither parameters nor buffers), so they are placed on
        # the target device explicitly here.
        self.weight = self.weight_loc.data.clone().to(device)
        self.bias = self.bias_loc.data.clone().to(device)

    def get_params(self):
        """Return ``(locs, log_scales)`` as two flat 1-D tensors.

        Each tensor is the flattened weight entries followed by the bias
        entries, which makes vectorised computations (e.g. a KL divergence)
        over the whole layer straightforward.
        """
        locs = torch.cat((self.weight_loc.reshape(-1), self.bias_loc.reshape(-1)))
        log_scales = torch.cat((self.log_weight_scale.reshape(-1),
                                self.log_bias_scale.reshape(-1)))
        return locs, log_scales

    def forward(self, x, activition_fn=None):
        """Compute ``x @ weight + bias`` and apply ``activition_fn`` if one is given."""
        pre_activation = x @ self.weight + self.bias
        return activition_fn(pre_activation) if activition_fn else pre_activation

     

class Encoder(nn.Module):
    """Container for the encoder's ``nn.Linear`` layers.

    ``x_to_h`` holds ``n_layers`` trunk layers (the first maps ``input_size``
    to ``in_features``, the rest map ``in_features`` to ``in_features``) and
    ``head`` holds two layers mapping ``in_features`` to ``out_features``.
    The class defines no ``forward``; callers use the layer lists directly.
    """

    def __init__(self,
                 input_size,
                 in_features,
                 out_features,
                 n_layers):
        super().__init__()

        # Trunk: input projection first, then the hidden-to-hidden layers.
        trunk = [self.get_layer(input_size, in_features, True)]
        for _ in range(n_layers - 1):
            trunk.append(self.get_layer(in_features, in_features, True))
        self.x_to_h = nn.ModuleList(trunk).to(device)

        # Two heads sharing the trunk output.
        heads = []
        for _ in range(2):
            heads.append(self.get_layer(in_features, out_features, True))
        self.head = nn.ModuleList(heads).to(device)

    def get_layer(self, in_feature, out_features, activation=False):
        """Build one ``nn.Linear``; the ``activation`` flag is accepted but not used."""
        return nn.Linear(in_feature, out_features)
     

class VCL_Generator(nn.Module):
    """Encoder/decoder generative model with Bayesian decoder layers.

    The encoder (``Encoder``, plain ``nn.Linear`` layers) maps ``x`` to the
    mean and log-scale of a Gaussian over ``z``. The decoder is a stack of
    ``BayesLinear`` layers, ``decoder_head`` followed by ``h_to_x``, with a
    sigmoid after every layer.

    Args:
        previous_model: a ``VCL_Generator`` whose decoder parameters are copied
            into this model and whose ``h_to_x`` parameters serve as the prior
            in ``calculate_loss``; ``None`` to keep freshly created parameters.
        nencoder_layers: number of layers in the encoder trunk.
        ndecoder_layers: number of ``h_dim -> h_dim`` layers that follow the
            first ``z_dim -> h_dim`` layer in ``decoder_head``.
        z_dim: latent dimensionality.
        h_dim: hidden width.
        x_dim: data dimensionality.
        random_initialize: if true, set every decoder log-scale to -6 and
            Xavier-initialise the decoder weight means.
    """

    def __init__(self, previous_model,
                 nencoder_layers=3, ndecoder_layers=1,
                 z_dim=50, h_dim=500, x_dim=784, random_initialize=False):
        super().__init__()

        self.ndecoder_layers = ndecoder_layers
        self.z_dim = z_dim

        # ---- Encoder -------------------------------------------------------
        self.encoder = Encoder(input_size=x_dim,
                               in_features=h_dim,
                               out_features=z_dim,
                               n_layers=nencoder_layers).to(device)
        # Aliases of the encoder's layer lists. They are registered a second
        # time under these names, so they also appear in the state dict.
        self.x_to_h = self.encoder.x_to_h
        self.encoder_head = self.encoder.head

        # ---- Decoder / generator -------------------------------------------
        self.decoder_head = nn.ModuleList([BayesLinear(z_dim, h_dim)] +
                                          [BayesLinear(h_dim, h_dim) for _ in range(self.ndecoder_layers)]).to(device)
        self.h_to_x = nn.ModuleList([BayesLinear(h_dim, x_dim)]).to(device)

        # Name -> layer lookup used to copy parameters between models.
        self.layer_dict = OrderedDict()
        for ix, layer in enumerate(self.x_to_h):
            self.layer_dict[f'x_to_h_{ix}'] = layer

        for ix, layer in enumerate(self.h_to_x):
            self.layer_dict[f'h_to_x_{ix}'] = layer

        for ix, layer in enumerate(self.decoder_head):
            self.layer_dict[f'decoder_head_{ix}'] = layer

        self.layer_dict["encoder_head"] = self.encoder_head

        # Sanity check: the dict stores the layer objects themselves, not copies.
        assert self.layer_dict["x_to_h_0"] is self.x_to_h[0]

        with torch.no_grad():
            if previous_model is not None:
                for key in self.layer_dict:
                    if "decoder_head" in key or "h_to_x" in key:
                        # Bayesian layers: start from the previous model's values.
                        source = previous_model.layer_dict[key]
                        target = self.layer_dict[key]
                        target.weight_loc.data = source.weight_loc.data.clone()
                        target.bias_loc.data = source.bias_loc.data.clone()
                        target.log_weight_scale.data = source.log_weight_scale.data.clone()
                        target.log_bias_scale.data = source.log_bias_scale.data.clone()
                    else:
                        # NOTE(review): this only rebinds the dictionary entry;
                        # self.x_to_h / self.encoder_head keep pointing at the
                        # layers created above, so encoder_forward still uses
                        # this model's own encoder.
                        self.layer_dict[key] = previous_model.layer_dict[key]
            if random_initialize:
                # Unpack into a list instead of adding ModuleLists; same layer
                # objects, same order (h_to_x first, then decoder_head).
                for layer in [*self.h_to_x, *self.decoder_head]:
                    layer.log_weight_scale = nn.Parameter(torch.ones(layer.log_weight_scale.shape, device=device) * -6)
                    layer.log_bias_scale = nn.Parameter(torch.ones(layer.log_bias_scale.shape, device=device) * -6)
                    torch.nn.init.xavier_uniform_(layer.weight_loc)

        # Parameters of the previous model that act as the prior in the loss.
        if previous_model is not None:
            previous_locs, previous_logscales = previous_model.get_params()
            self.previous_model_locs = previous_locs
            self.previous_model_log_scales = previous_logscales
        else:
            self.previous_model_locs = None
            self.previous_model_log_scales = None

    def encoder_forward(self, x):
        """Return ``(loc, log_scale)`` of the Gaussian over ``z`` for inputs ``x``."""
        for layer in self.x_to_h:
            x = torch.relu(layer(x))
        loc = self.encoder_head[0](x)
        log_scale = self.encoder_head[1](x)
        return loc, log_scale

    def decoder_forward(self, z, sampling=False):
        """Map latents ``z`` to outputs in (0, 1) via the decoder stack.

        Args:
            z: latent tensor (batch, z_dim).
            sampling: when true, draw fresh weights and biases for every layer
                from its Gaussian before applying it; when false, reuse
                whatever ``weight`` / ``bias`` each layer currently holds.
        """
        for layer in [*self.decoder_head, *self.h_to_x]:
            if sampling:
                layer.weight = torch.randn_like(layer.weight_loc) * torch.exp(layer.log_weight_scale) + layer.weight_loc
                layer.bias = torch.randn_like(layer.bias_loc) * torch.exp(layer.log_bias_scale) + layer.bias_loc
            # A sigmoid follows every layer, hidden ones included.
            z = layer(z, torch.sigmoid)
        return z

    def get_params(self):
        """Return ``(locs, logscales)``: per-layer flat tensors for the ``h_to_x`` layers only."""
        locs = []
        logscales = []
        for layer in self.h_to_x:
            loc, scale = layer.get_params()
            locs.append(loc)
            logscales.append(scale)
        return locs, logscales

    def calculate_loss(self, x, y, n_particles=10, dataset_size=6000):
        """Training loss: reconstruction + latent KL + weight KL to the previous model.

        Args:
            x: input batch (batch, x_dim) with values in [0, 1].
            y: unused; kept so existing call sites keep working.
            n_particles: number of latent/weight samples averaged for the
                reconstruction term.
            dataset_size: divisor applied to the weight-KL term.

        Returns:
            0-dim tensor with the loss.

        Raises:
            ValueError: if the model was built without ``previous_model``, since
                the weight-KL term then has no prior to compare against.
        """
        if self.previous_model_locs is None or self.previous_model_log_scales is None:
            raise ValueError(
                "calculate_loss needs the previous model's parameters; "
                "construct VCL_Generator with previous_model set."
            )

        locs, logscales = self.get_params()
        # KL between the current h_to_x posterior and the previous model's.
        kl = 0
        for loc_q, log_scale_q, loc_p, log_scale_p in zip(
                locs, logscales, self.previous_model_locs, self.previous_model_log_scales):
            KL = kl_divergence(Normal(loc=loc_q, scale=torch.exp(log_scale_q)),
                               Normal(loc=loc_p, scale=torch.exp(log_scale_p)))
            kl += KL.sum() / dataset_size

        loc, log_scale = self.encoder_forward(x)
        # KL from q(z|x) to a standard normal, built on loc's device/dtype.
        kl_z = kl_divergence(Normal(loc, torch.exp(log_scale)),
                             Normal(torch.zeros_like(loc), torch.ones_like(loc)))
        bce = 0.
        for _ in range(n_particles):
            z = torch.randn_like(loc) * torch.exp(log_scale) + loc
            gen = self.decoder_forward(z, sampling=True)
            bce = bce + torch.nn.functional.binary_cross_entropy(input=gen, target=x, reduction="sum")

        loss = (kl_z.sum() + bce / n_particles) / x.shape[0] + kl
        return loss

In [None]:
# MNIST test and train splits, downloaded into the current working directory if
# they are not already there; transforms.ToTensor() is applied to each image on access.
ds_test = datasets.MNIST("./", train=False, transform=transforms.ToTensor(), download=True)
ds_train = datasets.MNIST("./", train=True, transform=transforms.ToTensor(), download=True)

def _digit_subset(dataset, digit):
    """Return ``(flattened_images, labels)`` for the examples of ``dataset`` labelled ``digit``.

    The dataset is walked exactly once, collecting images and labels together.
    """
    images, labels = [], []
    for image, label in dataset:
        images.append(image)
        labels.append(label)
    labels = torch.tensor(labels)
    keep = labels == digit
    # Each image has a leading singleton dimension, so concatenating stacks the
    # examples along dim 0; flattening from dim 1 yields one row per image.
    pixels = torch.cat(images).flatten(1)
    return pixels[keep], labels[keep]


def get_digit(task_idx=0, conv=False):
    """Select the train and test examples of a single class.

    Args:
        task_idx: class label to keep.
        conv: unused; kept for backward compatibility.

    Returns:
        ``(train_x, train_y, test_x, test_y)`` moved to ``device``; the ``*_x``
        tensors hold one flattened image per row.
    """
    train_x, train_y = _digit_subset(ds_train, task_idx)
    test_x, test_y = _digit_subset(ds_test, task_idx)
    return train_x.to(device), train_y.to(device), test_x.to(device), test_y.to(device)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw



In [None]:
# NOTE(review): no cell visible in this notebook reads this name; the loaders
# below open files under "mnist/" instead. Confirm whether it is still needed.
checkpoints_dir = "models_mnist_final"

In [None]:
import copy
def calculate_digit_ll(digit_num):
    """Test log-likelihood of digit ``digit_num`` under successive checkpoints.

    The checkpoint for task ``digit_num`` is evaluated first. Then, for every
    later task ``i`` up to 9, the model's ``h_to_x`` layers are replaced by
    those of checkpoint ``i`` and the same test set is evaluated again.

    Only complete batches of the notebook-level ``batch_size`` are evaluated;
    trailing test examples that do not fill a batch are skipped.

    Args:
        digit_num: digit (task index) whose test examples are evaluated.

    Returns:
        List of floats: one mean log-likelihood per evaluated checkpoint.
    """
    _, _, test_x, _ = get_digit(digit_num)
    nbatches = test_x.shape[0] // batch_size

    def load_model(task):
        """Build a fresh model and load the checkpoint saved for ``task``."""
        loaded = VCL_Generator(previous_model=None)
        # map_location lets checkpoints saved on a GPU load on a CPU-only
        # runtime. torch.load unpickles the file: only open trusted checkpoints.
        state = torch.load(f"mnist/Copy of model_task{task}", map_location=test_x.device)
        loaded.load_state_dict(state)
        return loaded

    def mean_ll(current):
        """Average evaluate()'s batch mean over the ``nbatches`` full batches."""
        ll_total = 0.
        for batch in range(nbatches):
            start = batch * batch_size
            ll, _ = evaluate(current, test_x[start: start + batch_size])
            ll_total += ll / nbatches
        return ll_total.detach().item()

    model = load_model(digit_num)
    lls = [mean_ll(model)]
    for later_task in range(digit_num + 1, 10):
        # Swap in the h_to_x layers of the later checkpoint, keep the rest.
        model.h_to_x = copy.copy(load_model(later_task).h_to_x)
        lls.append(mean_ll(model))

    return lls

In [None]:
# Colab-only: mount Google Drive into the runtime's filesystem at /content/drive/.
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


In [None]:
%cd /content/drive/Shareddrives/MLMI4/checkpoints/generative/

/content/drive/Shareddrives/MLMI4/checkpoints/generative


In [None]:
# Run the evaluation for every digit 0-9 and keep each result keyed by digit.
from collections import defaultdict
test_ll_dict = defaultdict(list)
for i in tqdm(range(10)):
    # Plain assignment: the defaultdict's list factory is never triggered here.
    test_ll_dict[i] = calculate_digit_ll(i)

100%|██████████| 10/10 [03:20<00:00, 20.06s/it]


In [None]:
test_ll_dict

defaultdict(list,
            {0: [-114.6396713256836,
              -114.51583862304688,
              -114.04544067382812,
              -114.31660461425781,
              -114.18891906738281,
              -114.09271240234375,
              -113.90177154541016,
              -113.81189727783203,
              -113.77135467529297,
              -113.79896545410156],
             1: [-46.972354888916016,
              -46.680885314941406,
              -46.60992431640625,
              -46.56977844238281,
              -46.712310791015625,
              -46.956058502197266,
              -46.871124267578125,
              -46.73015594482422,
              -46.782352447509766],
             2: [-119.64279174804688,
              -119.72482299804688,
              -119.47166442871094,
              -119.98435974121094,
              -120.2338638305664,
              -119.90421295166016,
              -119.76033782958984,
              -119.7942123413086],
             3: [-103.504318237

In [None]:
def evaluate(model, test_batch, n_samples=100):
    """Importance-sampled estimate of the test log-likelihood for one batch.

    Draws ``n_samples`` latents per example one sample at a time (keeping
    memory proportional to a single copy of the batch) and estimates the
    per-example log-likelihood as
    ``log mean_k exp(log p(x|z_k) + log p(z_k) - log q(z_k|x))``.

    Args:
        model: object exposing ``encoder_forward(x) -> (loc, log_scale)`` and
            ``decoder_forward(z, sampling=True) -> probabilities``.
        test_batch: 2-D tensor (batch, features).
        n_samples: number of latent samples per example (default 100).

    Returns:
        Tuple ``(mean, var)``: mean and (biased) variance of the per-example
        log-likelihood estimates across the batch, as 0-dim tensors.
    """
    with torch.no_grad():
        # The encoder output does not change between samples, so compute it
        # once instead of on every iteration.
        # assumes model.encoder_forward is deterministic — TODO confirm for
        # models other than the VCL_Generator defined in this notebook
        loc, log_scale = model.encoder_forward(test_batch)
        scale = torch.exp(log_scale)
        log_weights = []
        for _ in range(n_samples):
            z = torch.randn_like(loc) * scale + loc
            gen = model.decoder_forward(z, sampling=True)
            # log q(z|x) - log p(z) for this sample
            kl = gaussian_prob(z, loc, log_scale) - gaussian_prob(z)
            bce = bernoulli_prob(test_batch, gen)
            log_weights.append(bce - kl)
        log_weights = torch.stack(log_weights)
        # Numerically stable log-mean-exp over the sample axis.
        test_ll = torch.logsumexp(log_weights, dim=0) - float(np.log(n_samples))
        mean = test_ll.mean()
        var = ((test_ll - mean) ** 2).mean()
    return mean, var

In [None]:
import copy
def calculate_digit_ll(digit_num, batch_size=10000000):
    """Test log-likelihood of digit ``digit_num`` under successive checkpoints.

    The checkpoint for task ``digit_num`` is evaluated first. Then, for every
    later task ``i`` up to 9, the model's ``h_to_x`` layers are replaced by
    those of checkpoint ``i`` and the same test set is evaluated again.

    Args:
        digit_num: digit (task index) whose test examples are evaluated.
        batch_size: maximum number of examples per ``evaluate`` call. The large
            default processes the whole test set as a single batch to save time.

    Returns:
        List of floats: one mean log-likelihood per evaluated checkpoint.
    """
    _, _, test_x, _ = get_digit(digit_num)
    n_points = test_x.shape[0]

    def load_model(task):
        """Build a fresh model and load the checkpoint saved for ``task``."""
        loaded = VCL_Generator(previous_model=None)
        # map_location lets checkpoints saved on a GPU load on a CPU-only
        # runtime. torch.load unpickles the file: only open trusted checkpoints.
        state = torch.load(f"mnist/Copy of model_task{task}", map_location=test_x.device)
        loaded.load_state_dict(state)
        return loaded

    def mean_ll(current):
        """Mean log-likelihood over all test examples, in chunks of ``batch_size``."""
        weighted_sum = 0.
        for start in range(0, n_points, batch_size):
            chunk = test_x[start: start + batch_size]
            ll, _ = evaluate(current, chunk)
            # Weight each batch mean by its size so a shorter final batch does
            # not bias the overall average.
            weighted_sum += ll.item() * chunk.shape[0]
        return weighted_sum / n_points

    model = load_model(digit_num)
    lls = [mean_ll(model)]
    for later_task in range(digit_num + 1, 10):
        # Swap in the h_to_x layers of the later checkpoint, keep the rest.
        model.h_to_x = copy.copy(load_model(later_task).h_to_x)
        lls.append(mean_ll(model))

    # Return the full list; the previous `return ll` handed back only the last
    # batch tensor.
    return lls

In [None]:
# Re-run the evaluation for every digit 0-9 with the definitions from the cells
# just above, storing whatever calculate_digit_ll returns under each digit.
from collections import defaultdict
test_ll_dict = defaultdict(list)
for i in tqdm(range(10)):
    # Plain assignment: the defaultdict's list factory is never triggered here.
    test_ll_dict[i] = calculate_digit_ll(i)

100%|██████████| 10/10 [03:13<00:00, 19.37s/it]


In [None]:
test_ll_dict

defaultdict(list,
            {0: tensor(-113.6500, device='cuda:0'),
             1: tensor(-46.5864, device='cuda:0'),
             2: tensor(-119.5419, device='cuda:0'),
             3: tensor(-104.1698, device='cuda:0'),
             4: tensor(-95.8306, device='cuda:0'),
             5: tensor(-109.2902, device='cuda:0'),
             6: tensor(-93.2820, device='cuda:0'),
             7: tensor(-77.5606, device='cuda:0'),
             8: tensor(-111.7574, device='cuda:0'),
             9: tensor(-85.8380, device='cuda:0')})