### Settings for Colab...

In [1]:
# from google.colab import drive
# drive.mount('/content/drive')

In [2]:
# ! cp -R "/content/drive/MyDrive/Advanced Machine Learning Project/GenerativeReplay/alg" ./

### Load Data and Packages

In [2]:
def set_up_parent_imports():
    import sys, os
    module_path = os.path.abspath(os.path.join('..'))
    if module_path not in sys.path:
        sys.path.append(module_path)

set_up_parent_imports()

import torch
from torch import nn
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision import datasets
from torchvision import transforms
from alg.vcl_net import MultiHeadVCLSplitNotMNIST, Initialization
from alg.kcenter import KCenter
from datasets import MultiTaskDataset
import pickle

### Define Generator

In [3]:
class Encoder(nn.Module):
    def __init__(self, z_dim, task_num=5):
        super().__init__()
        self.task_num = task_num
        self.fcs = nn.ModuleList([nn.Linear(784+task_num if i == 0 else 500+task_num, 500) for i in range(6)])
        self.head1 = nn.Linear(500+task_num, z_dim)
        self.head2 = nn.Linear(500+task_num, z_dim)
    def forward(self, x, task_ids):
        t = nn.functional.one_hot(task_ids.long(), num_classes=self.task_num)
        h = x
        for i in self.fcs:
            h = nn.ReLU()(i(torch.cat([h, t], dim=-1)))
        loc = self.head1(torch.cat([h, t], dim=-1))
        scale = torch.exp(self.head2(torch.cat([h, t], dim=-1)))
        return loc, scale
class Decoder(nn.Module):
    def __init__(self, z_dim, task_num=5):
        super().__init__()
        self.task_num = task_num
        self.fcs = nn.ModuleList([nn.Linear(z_dim+task_num if i == 0 else 500+task_num, 500) for i in range(6)])
        self.fc6 = nn.Linear(500+task_num, 784)
    def forward(self, z, task_ids):
        t = nn.functional.one_hot(task_ids.long(), num_classes=self.task_num)
        h = z
        for i in self.fcs:
            h = nn.ReLU()(i(torch.cat([h, t], dim=-1)))
        logits = self.fc6(torch.cat([h, t], dim=-1))
        return logits

### Training Fcn

In [5]:
def run(verbose=True, use_coreset=False):

    generator_models = {}
    solvers = {}


    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    test_x_all = []
    test_y_all = []
    test_task_i_all = []

    if use_coreset:
        coreset_x = None  
        coreset_y = None
        coreset_size = 40
        random_coreset = True

    batch_size = 1000
    accuracies = {}
    n_epochs = 120
    n_epochs_generator = 300


    previous_model, coreset_task_mask = None, None

    dataset = MultiTaskDataset('split notMNIST', device)
    for task_i in range(5):
        train_x, train_y, test_x, test_y = dataset.get_task_dataset(task_i)
        test_x_all.append(test_x)
        test_y_all.append(test_y)
        test_task_i_all.append(torch.ones((test_x.shape[0]), dtype=int) * task_i)

        # define current model
        if task_i == 0:
            current_model = MultiHeadVCLSplitNotMNIST(num_heads=1, initialization=Initialization.RANDOM).to(device)
            current_model.set_prior(MultiHeadVCLSplitNotMNIST(num_heads=1, initialization=Initialization.DEFAULT).to(device))
        else:
            current_model = MultiHeadVCLSplitNotMNIST.new_from_prior(previous_model)
            current_model.add_head(initialization=Initialization.RANDOM)
            # Set last head the same as the first head (this was in the initial implementation in tf) TODO investigate
            # current_model.heads[-1].set_params(*previous_model.heads[0].get_params())

        assert len(current_model.heads) == task_i + 1
        current_opt = torch.optim.Adam(current_model.parameters(), lr=0.001)


        if use_coreset:
            if random_coreset:
                coreset_idx = np.random.choice(train_x.shape[0], coreset_size, False)
            else:
                coreset_idx = np.array(KCenter(coreset_size).fit_transform(train_x.cpu().detach().numpy()))
            train_idx = np.delete(np.arange(train_x.shape[0]), coreset_idx)
            new_coreset_x = train_x[coreset_idx]
            new_coreset_y = train_y[coreset_idx]
            new_coreset_task_mask = torch.ones((new_coreset_x.shape[0]), dtype=int) * task_i
            train_x = train_x[train_idx]
            train_y = train_y[train_idx]

            if coreset_x == None:
                coreset_x = new_coreset_x
                coreset_y = new_coreset_y
                coreset_task_mask = new_coreset_task_mask
            else:
                coreset_x = torch.cat([new_coreset_x, coreset_x])
                coreset_y = torch.cat([new_coreset_y, coreset_y])
                coreset_task_mask = torch.cat([new_coreset_task_mask, coreset_task_mask])
                # "For all the algorithms with coresets, we choose 40 examples from each task to include into the coresets"
        
        ########################################################################
        ####################### train the generator first ######################
        # train generator
        new_decoder = Decoder(50).to(device)
        new_encoder = Encoder(50).to(device)

        batch_size_g = 256
        ELBOs = []
        optimizer = torch.optim.Adam(list(new_decoder.parameters())+list(new_encoder.parameters()), lr=0.001)
        for e in tqdm(range(n_epochs_generator)):
            elbos = []
            for batch in range(int(np.ceil(train_x.shape[0]/batch_size_g))):
                b_idx0 = batch_size_g*batch
                b_idx1 = batch_size_g*batch+batch_size_g
                batch_x = train_x[b_idx0: b_idx1]
                tasks_x = torch.ones(batch_x.shape[0], device=batch_x.device) * task_i
                # if task != 0: also generate some dataset from the old generator for training
                if task_i != 0:
                    with torch.no_grad():
                        for task_j in range(task_i):
                            z_old = torch.randn(batch_x.shape[0], 50, device=device)
                            tasks_old = torch.ones(batch_x.shape[0], device=batch_x.device) * task_j
                            batch_x_old = nn.Sigmoid()(old_decoder(z_old, tasks_old))
                            batch_x = torch.cat([batch_x, batch_x_old], dim=0)
                            tasks_x = torch.cat([tasks_x, tasks_old], dim=0)
                            if verbose and batch == 0 and e == 0:
                                print("Generating from old distributions...")
                                show_x = batch_x_old[:16].cpu().numpy().reshape(16, 28, 28)
                                show_x = np.hstack([show_x[i] for i in range(16)])

                                plt.imshow(show_x)
                                plt.title("Generated Old Images from Task %d"%task_j)
                                plt.show()
                z_loc, z_scale = new_encoder(batch_x, tasks_x)
                z = z_loc + torch.randn_like(z_scale) * z_scale
                outputs = new_decoder(z, tasks_x)
                loss = nn.BCELoss(reduction="none")(nn.Sigmoid()(outputs), batch_x).sum() / batch_x.shape[0]
                loss = loss + torch.distributions.kl_divergence(torch.distributions.Normal(z_loc, z_scale), torch.distributions.Normal(0, 1)).sum() / train_x.shape[0]
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                elbos.append(loss.item())
            ELBOs.append(np.mean(elbos))
        ########################################################################

        ELBO = []
        for epoch in (tqdm if verbose else iter)(range(n_epochs)):
            ELBO_batch = []
            for batch in range(int(np.ceil(train_x.shape[0] / batch_size))):
                batch_idx0 = batch * batch_size
                batch_idx1 = batch * batch_size + batch_size
                
                current_opt.zero_grad()
                batch_x = train_x[batch_idx0: batch_idx1]
                batch_y = train_y[batch_idx0: batch_idx1]
                batch_task =  torch.ones(batch_x.shape[0], device=batch_x.device) * task_i

                # generate samples following old distributions 
                if task_i != 0:
                    for task_j in range(task_i):
                        with torch.no_grad():
                            tasks_old = torch.ones(batch_x.shape[0], device=batch_x.device) * task_j
                            old_z = torch.randn(batch_x.shape[0], 50, device=device)
                            old_x = nn.Sigmoid()(old_decoder(old_z, tasks_old))
                            # predict using old models
                            old_y = nn.Softmax(-1)(torch.stack(previous_model.predict(old_x, task_j, 100), 0)).mean(0)
                            old_y = old_y.argmax(-1)
                        batch_x = torch.cat([old_x, batch_x], dim=0)
                        batch_y = torch.cat([old_y, batch_y], dim=0)
                        batch_task = torch.cat([tasks_old, batch_task])


                elbo = current_model.calculate_ELBO(x=batch_x, 
                                                    y=batch_y, 
                                                    n_particles=1,
                                                    task_i_mask=batch_task.long().cpu(),
                                                    dataset_size=train_x.shape[0] * (task_i+1))
                elbo.backward()
                nn.utils.clip_grad_value_(current_model.parameters(), 5)
                current_opt.step()
                ELBO_batch.append(elbo.item())
            ELBO.append(np.mean(ELBO_batch))
        if verbose:
            plt.plot(ELBO)
            plt.yscale("log")
            plt.show()
        acc = []
        for idx in range(len(test_x_all)):
            test_x_tensor = test_x_all[idx]
            test_y_tensor = test_y_all[idx]
            test_task_i_mask_tensor = test_task_i_all[idx]
            pred_y = []
            with torch.no_grad():
                for batch in range(int(np.ceil(test_x_tensor.shape[0] / batch_size))):
                    batch_idx0 = batch * batch_size
                    batch_idx1 = batch * batch_size + batch_size
                    pred_logit_samples = nn.Softmax(-1)(torch.stack(current_model.predict(test_x_tensor[batch_idx0:batch_idx1], test_task_i_mask_tensor[batch_idx0:batch_idx1], 100), 0)).mean(0)
                    pred_y.append(pred_logit_samples.argmax(-1))
                pred_y = torch.cat(pred_y)
                _acc = (pred_y == test_y_tensor).cpu().numpy().mean()
                acc.append(_acc)
        if verbose:
            if use_coreset:
                print("Accuracy by the propagation model", acc)
            else:
                print("Task {:d}, Accuracy: ".format(task_i), acc)
        if use_coreset == False:
            accuracies[task_i] = acc

        if use_coreset:
            # calculate prediction model
            pred_model = MultiHeadVCLSplitNotMNIST.new_from_prior(current_model)
            pred_opt = torch.optim.Adam(pred_model.parameters(), lr=0.001)

            ELBO = []
            for epoch in (tqdm if verbose else iter)(range(n_epochs)):
                ELBO_batch = []
                for batch in range(int(np.ceil(coreset_x.shape[0] / batch_size))):
                    batch_idx0 = batch * batch_size
                    batch_idx1 = batch * batch_size + batch_size
                    pred_opt.zero_grad()
                    elbo = pred_model.calculate_ELBO(x=coreset_x[batch_idx0: batch_idx1], 
                                                    y=coreset_y[batch_idx0: batch_idx1], 
                                                    task_i_mask=coreset_task_mask[batch_idx0: batch_idx1],
                                                    n_particles=1,
                                                    dataset_size=coreset_x.shape[0])
                    elbo.backward()
                    nn.utils.clip_grad_value_(pred_model.parameters(), 5)
                    pred_opt.step()
                    ELBO_batch.append(elbo.item())
                ELBO.append(np.mean(ELBO_batch))
            if verbose:
                plt.plot(ELBO)
                plt.show()

            acc = []
            for idx in range(len(test_x_all)):
                test_x_tensor = test_x_all[idx]     
                test_y_tensor = test_y_all[idx]
                test_task_i_mask_tensor = test_task_i_all[idx]
                pred_y = []
                with torch.no_grad():
                    for batch in range(int(np.ceil(test_x_tensor.shape[0] / batch_size))):
                        batch_idx0 = batch * batch_size
                        batch_idx1 = batch * batch_size + batch_size
                        pred_logit_samples = nn.Softmax(-1)(torch.stack(pred_model.predict(test_x_tensor[batch_idx0:batch_idx1], test_task_i_mask_tensor[batch_idx0:batch_idx1], 100), 0)).mean(0)
                        pred_y.append(pred_logit_samples.argmax(-1))
                    pred_y = torch.cat(pred_y)
                    _acc = (pred_y == test_y_tensor).cpu().numpy().mean()
                    acc.append(_acc)
            accuracies[task_i] = acc
            if verbose:
                print("Task {:d}, Accuracy: ".format(task_i), acc)


        previous_model = current_model
        old_decoder = new_decoder
        old_encoder = new_encoder
        generator_models[task_i] = [new_decoder, new_encoder]
        solvers[task_i] = current_model if not use_coreset else pred_model


    # with open("/content/drive/MyDrive/Advanced Machine Learning Project/GenerativeReplay/Split_MNIST_GR1_models_w_Coreset.pkl", "wb") as f:
    #     pickle.dump(generator_models, f)
    #     pickle.dump(solvers, f)

    return accuracies, generator_models, solvers


In [6]:
for i in range(10):
    print("============================ Exp %d ============================"%i)
    accs, generator_models, solvers = run(False, True)
    with open("./Split_notMNIST_GR1_models_w_Coreset_exp_%d.pkl"%i, "wb") as f:
        pickle.dump([accs, generator_models, solvers], f)
    print(accs)



100%|██████████| 300/300 [00:18<00:00, 16.59it/s]
100%|██████████| 300/300 [00:18<00:00, 15.83it/s]
100%|██████████| 300/300 [00:20<00:00, 14.45it/s]
100%|██████████| 300/300 [00:29<00:00, 10.29it/s]
100%|██████████| 300/300 [00:48<00:00,  6.13it/s]


{0: [0.9813664596273292], 1: [0.9834368530020704, 0.9661016949152542], 2: [0.9834368530020704, 0.9585687382297552, 0.9721189591078067], 3: [0.9834368530020704, 0.9623352165725048, 0.9721189591078067, 0.9395085066162571], 4: [0.979296066252588, 0.9661016949152542, 0.9739776951672863, 0.947069943289225, 0.9379194630872483]}


100%|██████████| 300/300 [00:16<00:00, 17.79it/s]
100%|██████████| 300/300 [00:19<00:00, 15.74it/s]
100%|██████████| 300/300 [00:20<00:00, 14.37it/s]
100%|██████████| 300/300 [00:29<00:00, 10.19it/s]
100%|██████████| 300/300 [00:49<00:00,  6.12it/s]


{0: [0.9813664596273292], 1: [0.9834368530020704, 0.9397363465160076], 2: [0.9834368530020704, 0.9472693032015066, 0.9553903345724907], 3: [0.9834368530020704, 0.9585687382297552, 0.9553903345724907, 0.947069943289225], 4: [0.979296066252588, 0.9623352165725048, 0.9628252788104089, 0.9508506616257089, 0.9312080536912751]}


100%|██████████| 300/300 [00:17<00:00, 17.62it/s]
100%|██████████| 300/300 [00:19<00:00, 15.14it/s]
100%|██████████| 300/300 [00:21<00:00, 13.76it/s]
100%|██████████| 300/300 [00:29<00:00, 10.07it/s]
100%|██████████| 300/300 [00:49<00:00,  6.06it/s]


{0: [0.9813664596273292], 1: [0.9813664596273292, 0.96045197740113], 2: [0.979296066252588, 0.9623352165725048, 0.9646840148698885], 3: [0.979296066252588, 0.9642184557438794, 0.9684014869888475, 0.9376181474480151], 4: [0.979296066252588, 0.9566854990583804, 0.9684014869888475, 0.9527410207939508, 0.9446308724832215]}


100%|██████████| 300/300 [00:17<00:00, 16.90it/s]
100%|██████████| 300/300 [00:19<00:00, 15.69it/s]
100%|██████████| 300/300 [00:22<00:00, 13.54it/s]
100%|██████████| 300/300 [00:29<00:00, 10.04it/s]
100%|██████████| 300/300 [00:49<00:00,  6.04it/s]


{0: [0.9834368530020704], 1: [0.9834368530020704, 0.9623352165725048], 2: [0.9834368530020704, 0.9642184557438794, 0.9684014869888475], 3: [0.9834368530020704, 0.967984934086629, 0.9702602230483272, 0.947069943289225], 4: [0.9834368530020704, 0.96045197740113, 0.9702602230483272, 0.9489603024574669, 0.959731543624161]}


100%|██████████| 300/300 [00:17<00:00, 16.84it/s]
100%|██████████| 300/300 [00:19<00:00, 15.33it/s]
100%|██████████| 300/300 [00:21<00:00, 14.18it/s]
100%|██████████| 300/300 [00:29<00:00, 10.12it/s]
100%|██████████| 300/300 [00:49<00:00,  6.07it/s]


{0: [0.9813664596273292], 1: [0.9855072463768116, 0.9642184557438794], 2: [0.9855072463768116, 0.9736346516007532, 0.9702602230483272], 3: [0.9875776397515528, 0.9661016949152542, 0.9684014869888475, 0.943289224952741], 4: [0.9855072463768116, 0.9623352165725048, 0.9739776951672863, 0.943289224952741, 0.9530201342281879]}


100%|██████████| 300/300 [00:17<00:00, 17.48it/s]
100%|██████████| 300/300 [00:19<00:00, 15.43it/s]
100%|██████████| 300/300 [00:21<00:00, 14.03it/s]
100%|██████████| 300/300 [00:29<00:00, 10.15it/s]
100%|██████████| 300/300 [00:49<00:00,  6.06it/s]


{0: [0.979296066252588], 1: [0.9855072463768116, 0.9661016949152542], 2: [0.9813664596273292, 0.96045197740113, 0.9684014869888475], 3: [0.979296066252588, 0.9623352165725048, 0.9739776951672863, 0.943289224952741], 4: [0.979296066252588, 0.9661016949152542, 0.9721189591078067, 0.943289224952741, 0.9681208053691275]}


100%|██████████| 300/300 [00:16<00:00, 17.66it/s]
100%|██████████| 300/300 [00:19<00:00, 15.50it/s]
100%|██████████| 300/300 [00:21<00:00, 13.78it/s]
100%|██████████| 300/300 [00:29<00:00, 10.18it/s]
100%|██████████| 300/300 [00:49<00:00,  6.08it/s]


{0: [0.9813664596273292], 1: [0.979296066252588, 0.967984934086629], 2: [0.979296066252588, 0.9566854990583804, 0.9646840148698885], 3: [0.979296066252588, 0.9585687382297552, 0.966542750929368, 0.947069943289225], 4: [0.9751552795031055, 0.9548022598870056, 0.9721189591078067, 0.941398865784499, 0.9429530201342282]}


100%|██████████| 300/300 [00:17<00:00, 17.60it/s]
100%|██████████| 300/300 [00:19<00:00, 15.61it/s]
100%|██████████| 300/300 [00:21<00:00, 14.26it/s]
100%|██████████| 300/300 [00:29<00:00, 10.17it/s]
100%|██████████| 300/300 [00:49<00:00,  6.09it/s]


{0: [0.9875776397515528], 1: [0.9834368530020704, 0.9642184557438794], 2: [0.9875776397515528, 0.9642184557438794, 0.9702602230483272], 3: [0.9834368530020704, 0.9642184557438794, 0.9739776951672863, 0.9565217391304348], 4: [0.9855072463768116, 0.96045197740113, 0.9721189591078067, 0.9508506616257089, 0.9530201342281879]}


100%|██████████| 300/300 [00:16<00:00, 17.66it/s]
100%|██████████| 300/300 [00:19<00:00, 15.62it/s]
100%|██████████| 300/300 [00:21<00:00, 14.20it/s]
100%|██████████| 300/300 [00:29<00:00, 10.19it/s]
100%|██████████| 300/300 [00:49<00:00,  6.11it/s]


{0: [0.9813664596273292], 1: [0.9834368530020704, 0.967984934086629], 2: [0.9855072463768116, 0.9642184557438794, 0.9776951672862454], 3: [0.9834368530020704, 0.9661016949152542, 0.9776951672862454, 0.9489603024574669], 4: [0.9772256728778468, 0.9717514124293786, 0.9795539033457249, 0.9489603024574669, 0.9312080536912751]}


100%|██████████| 300/300 [00:16<00:00, 17.83it/s]
100%|██████████| 300/300 [00:18<00:00, 15.99it/s]
100%|██████████| 300/300 [00:20<00:00, 14.33it/s]
100%|██████████| 300/300 [00:29<00:00, 10.18it/s]
100%|██████████| 300/300 [00:49<00:00,  6.10it/s]


{0: [0.979296066252588], 1: [0.9813664596273292, 0.96045197740113], 2: [0.979296066252588, 0.96045197740113, 0.9646840148698885], 3: [0.979296066252588, 0.9736346516007532, 0.9646840148698885, 0.9395085066162571], 4: [0.9813664596273292, 0.9642184557438794, 0.9721189591078067, 0.9527410207939508, 0.9580536912751678]}


In [6]:
# Load results
for i in range(10):
    with open("./Split_notMNIST_GR1_models_w_Coreset_exp_%d.pkl"%i, 'rb') as f:
        accs, generator_models, solvers = pickle.load(f)
        print(accs)

{0: [0.9813664596273292], 1: [0.9834368530020704, 0.9661016949152542], 2: [0.9834368530020704, 0.9585687382297552, 0.9721189591078067], 3: [0.9834368530020704, 0.9623352165725048, 0.9721189591078067, 0.9395085066162571], 4: [0.979296066252588, 0.9661016949152542, 0.9739776951672863, 0.947069943289225, 0.9379194630872483]}
{0: [0.9813664596273292], 1: [0.9834368530020704, 0.9397363465160076], 2: [0.9834368530020704, 0.9472693032015066, 0.9553903345724907], 3: [0.9834368530020704, 0.9585687382297552, 0.9553903345724907, 0.947069943289225], 4: [0.979296066252588, 0.9623352165725048, 0.9628252788104089, 0.9508506616257089, 0.9312080536912751]}
{0: [0.9813664596273292], 1: [0.9813664596273292, 0.96045197740113], 2: [0.979296066252588, 0.9623352165725048, 0.9646840148698885], 3: [0.979296066252588, 0.9642184557438794, 0.9684014869888475, 0.9376181474480151], 4: [0.979296066252588, 0.9566854990583804, 0.9684014869888475, 0.9527410207939508, 0.9446308724832215]}
{0: [0.9834368530020704], 1: [0