In [None]:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from torch.optim import Adam
import matplotlib.pyplot as plt
import time

In [None]:
def KL_divergence_gaussian(mu_x, log_sigma_x, mu_y, log_sigma_y):
    """Closed-form KL(X || Y) for vectors with independent Gaussian coordinates.

    X_i ~ N(mu_x_i, sigma_x_i^2) and Y_i ~ N(mu_y_i, sigma_y_i^2), where the
    standard deviations are passed as logs (sigma = exp(log_sigma)).

    Args:
        mu_x, log_sigma_x: location and log standard deviation of X.
        mu_y, log_sigma_y: location and log standard deviation of Y.

    Returns:
        The element-wise KL terms summed over every axis except the first,
        i.e. one value per row for 2-D inputs.
    """
    kl = log_sigma_y - log_sigma_x \
        + 0.5 * (torch.exp(2 * log_sigma_x) + (mu_x - mu_y) ** 2) / torch.exp(2 * log_sigma_y) \
        - 0.5

    # NOTE(review): for 1-D inputs this list is empty and the reduction is left
    # to torch.sum's handling of an empty `dim`; callers in this notebook apply
    # a further .sum() to the result.
    axes_to_reduce = list(range(1, len(mu_x.shape)))
    return torch.sum(kl, axes_to_reduce)


def KL_divergence_laplace(mu_x, log_b_x, mu_y, log_b_y):
    """Closed-form KL(X || Y) for vectors with independent Laplace coordinates.

    X_i ~ Lap(mu_x_i, b_x_i) and Y_i ~ Lap(mu_y_i, b_y_i), where the scales are
    passed as logs (b = exp(log_b)).

    Args:
        mu_x, log_b_x: location and log scale of X.
        mu_y, log_b_y: location and log scale of Y.

    Returns:
        The element-wise KL terms summed over every axis except the first.
    """
    abs_diff = torch.abs(mu_x - mu_y)
    kl = (torch.exp(log_b_x - abs_diff / torch.exp(log_b_x)) + abs_diff) / torch.exp(log_b_y) \
        + log_b_y - log_b_x - 1.0

    axes_to_reduce = list(range(1, len(mu_x.shape)))
    return torch.sum(kl, axes_to_reduce)


# Both closed forms used to be defined under the single name `KL_divergence`,
# so the Laplace version silently replaced the Gaussian one. Every active code
# path in this notebook samples with `sample_gaussian`, so the Gaussian form is
# the one that must be bound to `KL_divergence`. Rebind to
# `KL_divergence_laplace` only together with the Laplace sampling lines.
KL_divergence = KL_divergence_gaussian

def sample_gaussian(mu, log_sigma):
    """Draw one reparameterised sample mu + exp(log_sigma) * eps with eps ~ N(0, 1).

    Args:
        mu: tensor of means.
        log_sigma: tensor of log standard deviations, broadcastable with `mu`.

    Returns:
        A tensor shaped like `mu`; gradients flow to `mu` and `log_sigma`.
    """
    # randn_like draws the noise directly with mu's shape, dtype and device,
    # instead of allocating on the CPU and copying to the device on every call.
    return mu + torch.exp(log_sigma) * torch.randn_like(mu)

def sample_laplace(mu, log_b):
    """Draw one reparameterised sample mu + exp(log_b) * x with x ~ Lap(0, 1).

    The standard Laplace variate is built as log(U1) - log(U2) for independent
    uniforms U1, U2, i.e. the difference of two Exp(1) variables.

    Args:
        mu: floating-point tensor of locations.
        log_b: tensor of log scales, broadcastable with `mu`.

    Returns:
        A tensor shaped like `mu`.
    """
    # The previous body used the TensorFlow API (`get_shape`, `tf.float32`) and
    # raised on first use. torch.rand* samples from [0, 1), so clamp away from 0
    # to keep both logs finite.
    tiny = torch.finfo(mu.dtype).tiny
    u1 = torch.rand_like(mu).clamp_min(tiny)
    u2 = torch.rand_like(mu).clamp_min(tiny)
    x = torch.log(u1) - torch.log(u2)
    return mu + torch.exp(log_b) * x

def bernoulli_log_likelihood(x, mu):
    """Log-likelihood of `x` under independent Bernoulli(mu) coordinates.

    Probabilities are clamped to [1e-9, 1] before the log so that mu == 0 or
    mu == 1 does not produce -inf. The per-coordinate terms are summed over
    every axis except the first.
    """
    eps = 1e-9
    log_p_one = torch.clamp(mu, eps, 1.0).log()
    log_p_zero = torch.clamp(1.0 - mu, eps, 1.0).log()

    per_coordinate = x * log_p_one + (1.0 - x) * log_p_zero
    return per_coordinate.sum(dim=list(range(1, x.dim())))

def normal_log_likelihood(x, mu, log_sigma):
    """Log-density of `x` under independent N(mu, exp(log_sigma)^2) coordinates.

    The per-coordinate log-densities are summed over every axis except the first.
    """
    log_norm_const = -0.5 * np.log(2 * np.pi)
    standardised = (x - mu) / torch.exp(log_sigma)

    per_coordinate = log_norm_const - log_sigma - 0.5 * standardised ** 2
    return per_coordinate.sum(dim=list(range(1, x.dim())))

def laplace_log_likelihood(x, mu, log_sigma):
    """Log-density of `x` under independent Lap(mu, exp(log_sigma)) coordinates.

    `log_sigma` is the log of the Laplace scale. The per-coordinate
    log-densities are summed over every axis except the first.
    """
    scaled_abs_dev = torch.abs(x - mu) / torch.exp(log_sigma)

    per_coordinate = -log_sigma - scaled_abs_dev - np.log(2.0)
    return per_coordinate.sum(dim=list(range(1, x.dim())))

In [None]:
class RandomisedLinearModule(nn.Module):
    """Linear layer whose weights and biases are random variables.

    Each weight/bias has a learnable mean (`mu_*`) and log scale
    (`log_sigma_*`), plus a fixed prior mean and log scale (`*_prior`) that is
    only changed by `update_prior`. The priors start at mean 0 and log scale 0.

    Args:
        in_dim: number of input features.
        out_dim: number of output features.
        activation: callable applied to the affine output.
    """

    # Initial (and reset) value of the posterior log scales.
    INIT_LOG_SIGMA = -6.0

    def __init__(self, in_dim, out_dim, activation):
        super().__init__()

        # Glorot/Xavier uniform bound for the weight means.
        scale = np.sqrt(6.0/(in_dim + out_dim))

        self.mu_W = nn.Parameter(torch.rand(out_dim, in_dim) * 2 * scale - scale)
        self.log_sigma_W = nn.Parameter(
            torch.full((out_dim, in_dim), self.INIT_LOG_SIGMA, dtype=torch.float32))
        self.mu_b = nn.Parameter(torch.zeros(out_dim, dtype=torch.float32))
        self.log_sigma_b = nn.Parameter(
            torch.full((out_dim,), self.INIT_LOG_SIGMA, dtype=torch.float32))

        # The priors are state but not trainable, so they are buffers: they
        # follow .to(device)/.cpu()/.double() together with the parameters
        # (previously they were pinned to "cuda:0", which failed on CPU-only
        # machines) and they are saved in state_dict.
        self.register_buffer("mu_W_prior", torch.zeros(out_dim, in_dim, dtype=torch.float32))
        self.register_buffer("log_sigma_W_prior", torch.zeros(out_dim, in_dim, dtype=torch.float32))
        self.register_buffer("mu_b_prior", torch.zeros(out_dim, dtype=torch.float32))
        self.register_buffer("log_sigma_b_prior", torch.zeros(out_dim, dtype=torch.float32))

        self.activation = activation

    def forward(self, x, sampling_mode=True):
        """Apply the layer; sample W and b when `sampling_mode`, else use their means."""
        if sampling_mode:
            #W = sample_laplace(self.mu_W, self.log_sigma_W)
            #b = sample_laplace(self.mu_b, self.log_sigma_b)
            W = sample_gaussian(self.mu_W, self.log_sigma_W)
            b = sample_gaussian(self.mu_b, self.log_sigma_b)
        else:
            W = self.mu_W
            b = self.mu_b

        return self.activation(F.linear(x, W, b))

    def KL_div(self):
        """Total KL_divergence between the current posterior and the stored prior."""
        return KL_divergence(self.mu_W, self.log_sigma_W, self.mu_W_prior, self.log_sigma_W_prior).sum() + \
               KL_divergence(self.mu_b, self.log_sigma_b, self.mu_b_prior, self.log_sigma_b_prior).sum()

    def update_prior(self):
        """Overwrite the prior with the current posterior parameters."""
        with torch.no_grad():
            self.mu_W_prior.copy_(self.mu_W)
            self.log_sigma_W_prior.copy_(self.log_sigma_W)
            self.mu_b_prior.copy_(self.mu_b)
            self.log_sigma_b_prior.copy_(self.log_sigma_b)

    def reset_log_sigmas(self):
        """Reset the posterior log scales to their initial value, keeping the means."""
        with torch.no_grad():
            self.log_sigma_W.fill_(self.INIT_LOG_SIGMA)
            self.log_sigma_b.fill_(self.INIT_LOG_SIGMA)

class SharedModule(nn.Module):
    """Stack of randomised layers mapping a dim_h hidden vector to a dim_x output.

    It consists of `n_layers - 1` ReLU layers of width dim_h followed by one
    sigmoid layer producing dim_x outputs.
    """

    def __init__(self, dim_x, dim_h, n_layers):
        super().__init__()

        self.n_layers = n_layers

        hidden_layers = [RandomisedLinearModule(dim_h, dim_h, F.relu) for _ in range(n_layers-1)]
        output_layer = RandomisedLinearModule(dim_h, dim_x, F.sigmoid)
        self.layers_list = nn.ModuleList(hidden_layers + [output_layer])

    def forward(self, x, sampling_mode=True):
        """Pass `x` through every layer in order, forwarding `sampling_mode`."""
        out = x
        for module in self.layers_list:
            out = module(out, sampling_mode)
        return out

    def KL_div(self):
        """Sum of the per-layer KL terms."""
        return sum((module.KL_div() for module in self.layers_list), 0.0)

    def update_prior(self):
        """Call `update_prior` on every layer."""
        for module in self.layers_list:
            module.update_prior()

    def reset_log_sigmas(self):
        """Call `reset_log_sigmas` on every layer."""
        for module in self.layers_list:
            module.reset_log_sigmas()


class TaskSpecificModule(nn.Module):
    """Per-task head: `n_layers` ReLU randomised layers mapping dim_z to dim_h."""

    def __init__(self, dim_z, dim_h, n_layers):
        super().__init__()

        input_layer = RandomisedLinearModule(dim_z, dim_h, F.relu)
        hidden_layers = [RandomisedLinearModule(dim_h, dim_h, F.relu) for _ in range(n_layers-1)]
        self.layers_list = nn.ModuleList([input_layer] + hidden_layers)

    def forward(self, x, sampling_mode=True):
        """Pass `x` through every layer in order, forwarding `sampling_mode`."""
        out = x
        for module in self.layers_list:
            out = module(out, sampling_mode)
        return out

class GeneratorModule(nn.Module):
    """Generator made of one task-specific head per task and one shared tail.

    A latent `z` is first mapped by the head selected with `task_ind`, and the
    result is then mapped by the shared module.
    """

    def __init__(self, dim_z, dim_x, dim_h, n_tasks, n_layers_shared, n_layers_taskspec):
        super().__init__()

        self.shared_module = SharedModule(dim_x, dim_h, n_layers_shared)

        heads = [TaskSpecificModule(dim_z, dim_h, n_layers_taskspec) for _ in range(n_tasks)]
        self.taskspec_modules = nn.ModuleList(heads)

    def forward(self, z, task_ind, sampling_mode=True):
        """Run the head for `task_ind`, then the shared module, on `z`."""
        head = self.taskspec_modules[task_ind]
        hidden = head(z, sampling_mode)
        return self.shared_module(hidden, sampling_mode)

    def KL_div_shared_prior_post(self):
        """KL term of the shared module (delegates to `SharedModule.KL_div`)."""
        return self.shared_module.KL_div()

    def update_shared_params_prior(self):
        """Delegate to `SharedModule.update_prior`."""
        self.shared_module.update_prior()

    def reset_shared_params_log_sigmas(self):
        """Delegate to `SharedModule.reset_log_sigmas`."""
        self.shared_module.reset_log_sigmas()

In [None]:
class LinearModule(nn.Module):
    """Deterministic affine layer y = x W^T + b with no activation.

    W is initialised uniformly in [-s, s) with s = sqrt(6 / (in_dim + out_dim));
    b is initialised to zeros.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()

        bound = np.sqrt(6.0/(in_dim + out_dim))
        unit_uniform = torch.rand(out_dim, in_dim, dtype=torch.float32)

        self.W = nn.Parameter(unit_uniform * 2 * bound - bound)
        self.b = nn.Parameter(torch.zeros(out_dim, dtype=torch.float32))

    def forward(self, x):
        """Return the affine map of `x`."""
        return F.linear(input=x, weight=self.W, bias=self.b)

class Encoder(nn.Module):
    """MLP mapping x (dim_x) to the two halves of a 2*dim_z output.

    The network has one input layer, `n_layers_shared + n_layers_taskspec - 2`
    hidden layers of width dim_h, and one output layer of width 2*dim_z.
    """

    def __init__(self, dim_z, dim_x, dim_h, n_layers_shared, n_layers_taskspec):
        super().__init__()
        self.n_inner_layers = n_layers_shared + n_layers_taskspec

        input_layer = [LinearModule(dim_x, dim_h)]
        hidden_layers = [LinearModule(dim_h, dim_h) for _ in range(self.n_inner_layers - 2)]
        output_layer = [LinearModule(dim_h, 2*dim_z)]

        self.mlp_module = nn.ModuleList(input_layer + hidden_layers + output_layer)

    def forward(self, x):
        """Return `(mu, log_sigma)`: the first and second halves of the output columns."""
        relu_limit = self.n_inner_layers - 1

        for index, layer in enumerate(self.mlp_module):
            x = layer(x)
            # ReLU only on layers whose index is below n_inner_layers - 1
            if index < relu_limit:
                x = F.relu(x)

        half = x.shape[1] // 2
        return x[:, :half], x[:, half:]

In [None]:
# Model sizes and training settings read by the cells below.
hyperparams = dict(
    dim_z=50,                # latent dimension
    dim_h=500,               # hidden width
    dim_x=28 ** 2,           # flattened input size
    n_layers_shared=2,
    n_layers_taskspec=2,
    n_tasks=10,
    batch_size=50,
    n_epochs=200,
    lr=1e-4,
    sampling_mode=False,     # passed to the generator at evaluation time
)

In [None]:
# alter these when changing the dataset

# One task per label 0 .. n_tasks - 1.
task_labels = [label for label in range(hyperparams['n_tasks'])]

# Train split first, then test split; both are downloaded to ./data if missing.
trainset, testset = (
    datasets.MNIST(root='./data', train=is_train, download=True, transform=transforms.ToTensor())
    for is_train in (True, False)
)

In [None]:
def _make_task_loader(dataset, label):
    """Return a shuffled DataLoader over the samples of `dataset` whose target is `label`.

    The indices are found with one vectorised comparison over `dataset.targets`
    instead of one Python-level tensor comparison per sample, which previously
    had to be repeated for every label. `torch.as_tensor` also accepts datasets
    that store their targets as a plain list.
    """
    targets = torch.as_tensor(dataset.targets)
    indices = torch.nonzero(targets == label, as_tuple=True)[0].tolist()

    return DataLoader(
        Subset(dataset, indices),
        batch_size=hyperparams['batch_size'],
        shuffle=True,
        num_workers=2
    )


# label -> DataLoader restricted to that label, for the train and test splits
task_traindata_dict = {label: _make_task_loader(trainset, label) for label in task_labels}
task_testdata_dict = {label: _make_task_loader(testset, label) for label in task_labels}

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Constructor arguments are looked up in `hyperparams` by parameter name.
generator = GeneratorModule(
    dim_z=hyperparams['dim_z'],
    dim_x=hyperparams['dim_x'],
    dim_h=hyperparams['dim_h'],
    n_tasks=hyperparams['n_tasks'],
    n_layers_shared=hyperparams['n_layers_shared'],
    n_layers_taskspec=hyperparams['n_layers_taskspec'],
)
generator = generator.to(device)

encoder = Encoder(
    dim_z=hyperparams['dim_z'],
    dim_x=hyperparams['dim_x'],
    dim_h=hyperparams['dim_h'],
    n_layers_shared=hyperparams['n_layers_shared'],
    n_layers_taskspec=hyperparams['n_layers_taskspec'],
)
encoder = encoder.to(device)

In [None]:
def compute_loss(x_train, enc, gen, task, total_data, K=10):
    """Compute the training loss for one batch of one task.

    loss = mean KL(q(z|x) || p(z)) - mean E_q[log p(x|z)] + KL_shared / total_data

    Args:
        x_train: batch of inputs, as accepted by `enc`.
        enc: encoder returning `(z_post_mu, z_post_log_sigma)`.
        gen: generator called as `gen(z, task)`; must expose
            `KL_div_shared_prior_post()`.
        task: task index forwarded to `gen`.
        total_data: divisor applied to the shared-parameter KL term.
        K: number of samples used for the Monte Carlo approximation of the
            expectation.

    Returns:
        `(loss, kl_z, expected_ll, kl_shared)`: the loss tensor followed by the
        batch-mean latent KL, the batch-mean reconstruction log-likelihood and
        the shared-parameter KL as Python floats.
    """
    z_post_mu, z_post_log_sigma = enc(x_train)

    # compute KL(q_phi(z|x) || p(z))   knowing z has indep N(0,1) coordinates.
    # zeros_like builds the prior directly with the posterior's shape, dtype and
    # device, so this function no longer depends on a global `device`.
    prior_mu = torch.zeros_like(z_post_mu)
    prior_log_sigma = torch.zeros_like(z_post_mu)
    kl_z_prior_post = KL_divergence(z_post_mu, z_post_log_sigma, prior_mu, prior_log_sigma)

    # compute KL divergence between prior and posterior of shared params
    # KL(q_t(theta) || q_(t-1)(theta))
    kl_shared_param = gen.KL_div_shared_prior_post()

    # estimate E_q_phi_z [log(p_theta(x|z))]
    # we model x ~ Ber(gen(z))
    post_x_log_likelihood = 0.0

    for _ in range(K):
        #z = sample_laplace(z_post_mu, z_post_log_sigma)
        z = sample_gaussian(z_post_mu, z_post_log_sigma)
        x_mu = gen(z, task)
        post_x_log_likelihood += bernoulli_log_likelihood(x_train, x_mu) / K

    kl_z_mean = kl_z_prior_post.mean()
    log_likelihood_mean = post_x_log_likelihood.mean()

    loss = kl_z_mean - log_likelihood_mean + kl_shared_param / total_data

    return loss, kl_z_mean.item(), log_likelihood_mean.item(), kl_shared_param.item()

In [None]:
def eval_test_ll_on_task(task_testdata_dict, task, enc, gen, sampling_mode=True, K=100, device=None):
    """Estimate the test log-likelihood of one task with K importance samples.

    For every test batch, each input is repeated K times, a latent is sampled
    from the encoder posterior for every copy, and
    log p(x) ~= logsumexp_k(log p(x|z_k) + log p(z_k) - log q(z_k|x)) - log K
    is evaluated with a standard normal prior and a Gaussian posterior.

    Args:
        task_testdata_dict: mapping from task to an iterable of `(batch, _)`.
        task: key into `task_testdata_dict`, also forwarded to `gen`.
        enc: encoder returning `(z_post_mu, z_post_log_sigma)`.
        gen: generator called as `gen(z, task, sampling_mode)`.
        sampling_mode: forwarded to `gen`.
        K: number of importance samples per input.
        device: where batches are moved before encoding. Defaults to the device
            of the encoder's first parameter (batches are left in place if the
            encoder has no parameters).

    Returns:
        `(test_ll_mean, test_ll_std)`: the per-batch mean and per-batch standard
        error of the estimate, each averaged over batches.

    Side effects:
        Puts `enc` and `gen` in eval mode.
    """
    enc.eval(); gen.eval()

    if device is None:
        first_param = next(enc.parameters(), None)
        device = first_param.device if first_param is not None else None

    loader = task_testdata_dict[task]
    total_batches = len(loader)

    test_ll_mean = 0.0
    test_ll_std = 0.0

    # Evaluation never needs gradients; K * batch_size rows per step would
    # otherwise keep a large autograd graph alive if the caller forgot no_grad.
    with torch.no_grad():
        for data_batch, _ in loader:
            if device is not None:
                data_batch = data_batch.to(device)
            batch_size = data_batch.shape[0]

            # Flatten each sample, then stack K copies of the batch: row k * batch_size + b
            # holds the k-th copy of sample b.
            xs_stacked = torch.tile(data_batch.view(batch_size, -1), (K, 1))
            z_post_mu, z_post_log_sigma = enc(xs_stacked)
            #z = sample_laplace(z_post_mu, z_post_log_sigma)
            z = sample_gaussian(z_post_mu, z_post_log_sigma)

            #z_prior_log_likelihood = laplace_log_likelihood(z, torch.zeros_like(z), torch.zeros_like(z))
            z_prior_log_likelihood = normal_log_likelihood(z, torch.zeros_like(z), torch.zeros_like(z))
            #z_post_log_likelihood = laplace_log_likelihood(z, z_post_mu, z_post_log_sigma)
            z_post_log_likelihood = normal_log_likelihood(z, z_post_mu, z_post_log_sigma)
            kl_z_prior_post = z_post_log_likelihood - z_prior_log_likelihood

            xs_stacked_mu = gen(z, task, sampling_mode)
            x_post_log_likelihood = bernoulli_log_likelihood(xs_stacked, xs_stacked_mu)

            # One column per test sample, one row per importance sample.
            log_weights = (x_post_log_likelihood - kl_z_prior_post).reshape(K, batch_size)

            # log-mean-exp over the K samples, computed stably by logsumexp.
            test_ll = torch.logsumexp(log_weights, dim=0) - np.log(K)

            test_ll_mean += torch.mean(test_ll) / total_batches
            test_ll_std += torch.sqrt(torch.var(test_ll, correction=0) / test_ll.shape[0]) / total_batches

    return test_ll_mean, test_ll_std

In [None]:
def train(task_traindata_dict, task_testdata_dict, enc, gen, hyperparams, device):
    """Train `enc` and `gen` on the tasks one after another.

    For every task, in the iteration order of `task_traindata_dict`:
      1. run `hyperparams['n_epochs']` epochs of Adam on `compute_loss`;
      2. for that task and every earlier one, show one generated sample and
         evaluate the test log-likelihood;
      3. copy the shared-parameter posterior into the prior and reset the
         shared log scales.

    Args:
        task_traindata_dict: mapping to the per-task training loaders; the
            position of a loader in the mapping is used as its task index.
        task_testdata_dict: mapping to the per-task test loaders.
        enc: encoder module.
        gen: generator module.
        hyperparams: dict providing 'lr', 'n_epochs', 'dim_z', 'sampling_mode'.
        device: device that batches and sampled latents are moved to.

    Returns:
        A list with one entry per trained task; each entry is the list of
        `(test_ll_mean, test_ll_std)` pairs for all tasks seen so far.
    """
    adam = Adam(list(enc.parameters()) + list(gen.parameters()), lr=hyperparams['lr'])
    accuracies = []

    for task, data in enumerate(task_traindata_dict.values()):
        enc.train(); gen.train()
        tmp_accuracies = []
        print("Data for task", task, "arrives.")

        # Size of this task's training set; constant across epochs.
        total_data = len(data.sampler)

        for epoch in range(hyperparams['n_epochs']):
            kl1s = []; expecs = []; kl2s = []

            for data_batch, _ in data:
                data_batch = data_batch.to(device)
                adam.zero_grad()
                loss, kl1, expec, kl2 = compute_loss(
                    data_batch.view(-1, 28**2),
                    enc, gen, task, total_data
                )
                kl1s.append(kl1); kl2s.append(kl2); expecs.append(expec)
                loss.backward()
                adam.step()

            print('epoch:', epoch, 'with stats', sum(kl1s)/len(kl1s), sum(expecs)/len(expecs), sum(kl2s)/len(kl2s))

        with torch.no_grad():
            enc.eval(); gen.eval()

            for prev_task in range(task + 1):
                z = torch.randn(hyperparams['dim_z']).to(device)
                # Use the `gen` argument: this used to call the notebook-level
                # `generator`, which silently ignored the model passed in.
                x = gen(z, prev_task, sampling_mode=hyperparams['sampling_mode']).reshape(28, 28)
                plt.imshow(x.cpu().detach().numpy(), cmap='gray', vmin=0.0, vmax=1.0)
                plt.show()

                test_ll_mean, test_ll_std = eval_test_ll_on_task(
                    task_testdata_dict, prev_task, enc, gen,
                    sampling_mode=hyperparams['sampling_mode']
                )

                print("On task:", list(task_testdata_dict.keys())[prev_task],
                      "was achieved test_ll_mean:", test_ll_mean,
                      "with test_ll_std:", test_ll_std
                )

                tmp_accuracies.append((test_ll_mean, test_ll_std))

        # The posterior learned on this task becomes the prior for the next one.
        gen.update_shared_params_prior()
        gen.reset_shared_params_log_sigmas()

        accuracies.append(tmp_accuracies)

    return accuracies

In [None]:
# Train on all tasks in sequence and print the collected test log-likelihoods.
accuracies = train(
    task_traindata_dict=task_traindata_dict,
    task_testdata_dict=task_testdata_dict,
    enc=encoder,
    gen=generator,
    hyperparams=hyperparams,
    device=device,
)
print(accuracies)