### Settings for Colab...

In [None]:
# Mount Google Drive at /content/drive so that project files stored there can be
# read from (and results written back to) this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Copy the project's `alg` package from the mounted Drive into the working directory
# so that `alg.vcl_net` and `alg.kcenter` can be imported by the cells below.
! cp -R "/content/drive/MyDrive/Advanced Machine Learning Project/GenerativeReplay/alg" ./

### Load Data and Packages

In [None]:
import torch
from torch import nn
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision import datasets
from torchvision import transforms
from alg.vcl_net import MultiHeadVCLSplitMNIST, Initialization
from alg.kcenter import KCenter
import pickle

In [None]:
# Load the MNIST test and train splits (downloading them into ./data when missing).
# Both splits share the same settings and differ only in the `train` flag; the test
# split is constructed first, then the train split.
ds_test, ds_train = (
    datasets.MNIST("./data", train=is_train_split, transform=transforms.ToTensor(), download=True)
    for is_train_split in (False, True)
)
def get_sMNIST(task_idx, device):
    """Build the train/test tensors of one Split-MNIST task.

    The ten digit labels are split into five consecutive pairs and task
    ``task_idx`` keeps only the samples whose label falls in its pair. Labels
    are shifted by ``task_idx * 2`` so that every task is a 0/1 problem.

    Side effect: seeds NumPy's global RNG with ``task_idx``.

    Args:
        task_idx: index of the task, selecting one of the five label pairs.
        device: torch device the returned tensors are moved to.

    Returns:
        ``(train_x, train_y, test_x, test_y)``; the ``*_x`` tensors hold one
        flattened image per row, the ``*_y`` tensors the shifted labels.
    """
    np.random.seed(task_idx)
    task_labels = np.array_split(range(0, 10), 5)[task_idx]
    label_shift = task_idx * 2

    def _select(dataset):
        # Keep only the (image, label) pairs that belong to this task.
        kept = [(image, target) for image, target in dataset if target in task_labels]
        # Each image is (1, 28, 28); concatenating on dim 0 stacks them into
        # (N, 28, 28) before flattening to one row per image.
        images = torch.cat([image for image, _ in kept])
        targets = torch.tensor([target - label_shift for _, target in kept])
        return nn.Flatten()(images), targets

    train_x, train_y = _select(ds_train)
    test_x, test_y = _select(ds_test)

    moved = tuple(t.to(device) for t in (train_x, train_y, test_x, test_y))
    return moved

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



### Define Generator

In [None]:
class Encoder(nn.Module):
    """Task-conditioned MLP encoder of the replay VAE.

    A one-hot encoding of the task id is concatenated to the input of every
    linear layer (six hidden layers of width 500 and both output heads).

    Args:
        z_dim: dimensionality of the latent code.
        task_num: number of tasks, i.e. the width of the one-hot task encoding.
    """

    def __init__(self, z_dim, task_num=5):
        super().__init__()
        self.task_num = task_num
        hidden_width = 500
        # The first layer consumes the 784-d flattened image, the other five the hidden state.
        input_widths = [784] + [hidden_width] * 5
        self.fcs = nn.ModuleList(
            [nn.Linear(width + task_num, hidden_width) for width in input_widths]
        )
        self.head1 = nn.Linear(hidden_width + task_num, z_dim)
        self.head2 = nn.Linear(hidden_width + task_num, z_dim)

    def forward(self, x, task_ids):
        """Return ``(loc, scale)`` of the latent Gaussian for inputs ``x``.

        ``task_ids`` holds one task index per row of ``x`` (cast to long here).
        ``scale`` is the exponential of the second head, hence strictly positive.
        """
        task_code = nn.functional.one_hot(task_ids.long(), num_classes=self.task_num)
        hidden = x
        for layer in self.fcs:
            hidden = torch.relu(layer(torch.cat([hidden, task_code], dim=-1)))
        conditioned = torch.cat([hidden, task_code], dim=-1)
        loc = self.head1(conditioned)
        scale = self.head2(conditioned).exp()
        return loc, scale
class Decoder(nn.Module):
    """Task-conditioned MLP decoder of the replay VAE.

    A one-hot encoding of the task id is concatenated to the input of every
    linear layer (six hidden layers of width 500 and the 784-d output layer).

    Args:
        z_dim: dimensionality of the latent code.
        task_num: number of tasks, i.e. the width of the one-hot task encoding.
    """

    def __init__(self, z_dim, task_num=5):
        super().__init__()
        self.task_num = task_num
        hidden_width = 500
        # The first layer consumes the latent code, the other five the hidden state.
        input_widths = [z_dim] + [hidden_width] * 5
        self.fcs = nn.ModuleList(
            [nn.Linear(width + task_num, hidden_width) for width in input_widths]
        )
        self.fc6 = nn.Linear(hidden_width + task_num, 784)

    def forward(self, z, task_ids):
        """Return the 784 per-pixel logits (no sigmoid applied) for latents ``z``.

        ``task_ids`` holds one task index per row of ``z`` (cast to long here).
        """
        task_code = nn.functional.one_hot(task_ids.long(), num_classes=self.task_num)
        hidden = z
        for layer in self.fcs:
            hidden = torch.relu(layer(torch.cat([hidden, task_code], dim=-1)))
        return self.fc6(torch.cat([hidden, task_code], dim=-1))

### Training Function

In [None]:
def _sample_replay(decoder, n_samples, task_j, device, z_dim):
    """Generate ``n_samples`` pseudo-inputs for old task ``task_j`` with ``decoder``.

    Latent codes are drawn from a standard normal and the decoder logits are
    squashed through a sigmoid. Returns ``(generated_x, task_ids)`` where
    ``task_ids`` is a float vector filled with ``task_j``. Callers wrap this in
    ``torch.no_grad()``.
    """
    z_old = torch.randn(n_samples, z_dim, device=device)
    tasks_old = torch.ones(n_samples, device=device) * task_j
    return nn.Sigmoid()(decoder(z_old, tasks_old)), tasks_old


def _predict_labels(model, x, task_mask, batch_size, n_samples=100):
    """Return arg-max class predictions of ``model`` for ``x``, computed in mini-batches.

    Class probabilities are the softmax of the outputs of ``model.predict``
    averaged over the ``n_samples`` samples it returns.
    """
    predictions = []
    with torch.no_grad():
        for start in range(0, x.shape[0], batch_size):
            stop = start + batch_size
            samples = torch.stack(model.predict(x[start:stop], task_mask[start:stop], n_samples), 0)
            predictions.append(nn.Softmax(-1)(samples).mean(0).argmax(-1))
    return torch.cat(predictions)


def _evaluate(model, test_x_all, test_y_all, test_task_i_all, batch_size):
    """Return the test accuracy of ``model`` on every task seen so far, in task order."""
    return [
        (_predict_labels(model, x, mask, batch_size) == y).cpu().numpy().mean()
        for x, y, mask in zip(test_x_all, test_y_all, test_task_i_all)
    ]


def run(verbose=True, use_coreset=False, n_epochs=120, n_epochs_generator=300):
    """Run one Split-MNIST continual-learning experiment with generative replay.

    For each of the five tasks, in order:
      1. (optional) move ``coreset_size`` training samples into the coreset;
      2. train a fresh task-conditioned VAE on the current data plus samples
         replayed from the previous task's decoder;
      3. train the multi-head VCL solver on the current data plus replayed
         samples labelled by the previous solver;
      4. evaluate on the test sets of all tasks seen so far;
      5. (coreset only) fine-tune a copy of the solver on the coreset and
         evaluate that copy instead.

    Args:
        verbose: show solver progress bars, loss curves, replayed images and
            per-task accuracies. The generator progress bar is always shown.
        use_coreset: enable the coreset variant (random coreset of 40 samples per task).
        n_epochs: epochs for the solver (and for the coreset fine-tuning).
        n_epochs_generator: epochs for each VAE generator.

    Returns:
        ``(accuracies, generator_models, solvers)``, three dicts keyed by task index:
        ``accuracies[t]`` is the list of test accuracies on tasks ``0..t`` after
        training task ``t``; ``generator_models[t]`` is ``[decoder, encoder]``;
        ``solvers[t]`` is the solver (the coreset-tuned copy when ``use_coreset``).
        Nothing is written to disk here; persisting results is left to the caller.
    """
    generator_models = {}
    solvers = {}
    accuracies = {}

    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    test_x_all = []
    test_y_all = []
    test_task_i_all = []

    # Coreset state; only touched when use_coreset is True.
    coreset_x = None
    coreset_y = None
    coreset_task_mask = None
    coreset_size = 40
    random_coreset = True

    batch_size = 1000    # solver training / evaluation mini-batch size
    batch_size_g = 256   # generator mini-batch size
    z_dim = 50           # latent dimensionality of the VAE

    previous_model = None
    old_decoder = None

    for task_i in range(5):
        train_x, train_y, test_x, test_y = get_sMNIST(task_i, device)
        test_x_all.append(test_x)
        test_y_all.append(test_y)
        # Task masks are kept on the CPU, matching how they are passed to calculate_ELBO below.
        test_task_i_all.append(torch.ones((test_x.shape[0]), dtype=int) * task_i)

        # define current model
        if task_i == 0:
            current_model = MultiHeadVCLSplitMNIST(num_heads=1, initialization=Initialization.RANDOM).to(device)
            current_model.set_prior(MultiHeadVCLSplitMNIST(num_heads=1, initialization=Initialization.DEFAULT).to(device))
        else:
            current_model = MultiHeadVCLSplitMNIST.new_from_prior(previous_model)
            current_model.add_head(initialization=Initialization.RANDOM)

        assert len(current_model.heads) == task_i + 1
        current_opt = torch.optim.Adam(current_model.parameters(), lr=0.001)

        if use_coreset:
            # "For all the algorithms with coresets, we choose 40 examples from each
            # task to include into the coresets"
            if random_coreset:
                coreset_idx = np.random.choice(train_x.shape[0], coreset_size, False)
            else:
                coreset_idx = np.array(KCenter(coreset_size).fit_transform(train_x.cpu().detach().numpy()))
            train_idx = np.delete(np.arange(train_x.shape[0]), coreset_idx)
            new_coreset_x = train_x[coreset_idx]
            new_coreset_y = train_y[coreset_idx]
            new_coreset_task_mask = torch.ones((new_coreset_x.shape[0]), dtype=int) * task_i
            # Coreset samples are withheld from the regular training set.
            train_x = train_x[train_idx]
            train_y = train_y[train_idx]

            if coreset_x is None:
                coreset_x = new_coreset_x
                coreset_y = new_coreset_y
                coreset_task_mask = new_coreset_task_mask
            else:
                coreset_x = torch.cat([new_coreset_x, coreset_x])
                coreset_y = torch.cat([new_coreset_y, coreset_y])
                coreset_task_mask = torch.cat([new_coreset_task_mask, coreset_task_mask])

        ########################################################################
        ####################### train the generator first ######################
        new_decoder = Decoder(z_dim).to(device)
        new_encoder = Encoder(z_dim).to(device)

        optimizer = torch.optim.Adam(list(new_decoder.parameters()) + list(new_encoder.parameters()), lr=0.001)
        for e in tqdm(range(n_epochs_generator)):
            for start in range(0, train_x.shape[0], batch_size_g):
                batch_x = train_x[start: start + batch_size_g]
                # Size of the real mini-batch, captured BEFORE any replay data is
                # appended so that every old task contributes the same number of
                # samples (otherwise the replay size doubles with each old task).
                n_new = batch_x.shape[0]
                tasks_x = torch.ones(n_new, device=batch_x.device) * task_i
                # if task != 0: also generate some dataset from the old generator for training
                if task_i != 0:
                    with torch.no_grad():
                        for task_j in range(task_i):
                            batch_x_old, tasks_old = _sample_replay(old_decoder, n_new, task_j, device, z_dim)
                            batch_x = torch.cat([batch_x, batch_x_old], dim=0)
                            tasks_x = torch.cat([tasks_x, tasks_old], dim=0)
                            if verbose and start == 0 and e == 0:
                                print("Generating from old distributions...")
                                show_x = batch_x_old[:16].cpu().numpy().reshape(16, 28, 28)
                                show_x = np.hstack([show_x[i] for i in range(16)])

                                plt.imshow(show_x)
                                plt.title("Generated Old Images from Task %d"%task_j)
                                plt.show()
                z_loc, z_scale = new_encoder(batch_x, tasks_x)
                # Reparameterised sample from N(z_loc, z_scale^2).
                z = z_loc + torch.randn_like(z_scale) * z_scale
                outputs = new_decoder(z, tasks_x)
                # Negative ELBO per sample: reconstruction and KL terms are both
                # normalised by the mini-batch size so they stay on the same scale.
                # The logits form of BCE avoids log(0) on saturated sigmoids.
                reconstruction = nn.functional.binary_cross_entropy_with_logits(outputs, batch_x, reduction="sum")
                kl = torch.distributions.kl_divergence(
                    torch.distributions.Normal(z_loc, z_scale),
                    torch.distributions.Normal(0, 1),
                ).sum()
                loss = (reconstruction + kl) / batch_x.shape[0]
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        ########################################################################

        ############################ train the solver ##########################
        ELBO = []
        for epoch in (tqdm if verbose else iter)(range(n_epochs)):
            ELBO_batch = []
            for start in range(0, train_x.shape[0], batch_size):
                current_opt.zero_grad()
                batch_x = train_x[start: start + batch_size]
                batch_y = train_y[start: start + batch_size]
                # Real mini-batch size, fixed before replay samples are prepended
                # (same reasoning as in the generator loop above).
                n_new = batch_x.shape[0]
                batch_task = torch.ones(n_new, device=batch_x.device) * task_i

                # generate samples following old distributions
                if task_i != 0:
                    for task_j in range(task_i):
                        with torch.no_grad():
                            old_x, tasks_old = _sample_replay(old_decoder, n_new, task_j, device, z_dim)
                            # Pseudo-labels: arg-max of the previous solver's averaged predictions.
                            old_y = nn.Softmax(-1)(torch.stack(previous_model.predict(old_x, task_j, 100), 0)).mean(0)
                            old_y = old_y.argmax(-1)
                        batch_x = torch.cat([old_x, batch_x], dim=0)
                        batch_y = torch.cat([old_y, batch_y], dim=0)
                        batch_task = torch.cat([tasks_old, batch_task])

                # NOTE(review): the returned value is minimised, so it is presumably
                # the negative ELBO - confirm against alg.vcl_net.
                elbo = current_model.calculate_ELBO(x=batch_x,
                                                    y=batch_y,
                                                    n_particles=1,
                                                    task_i_mask=batch_task.long().cpu(),
                                                    dataset_size=train_x.shape[0] * (task_i + 1))
                elbo.backward()
                nn.utils.clip_grad_value_(current_model.parameters(), 5)
                current_opt.step()
                ELBO_batch.append(elbo.item())
            ELBO.append(np.mean(ELBO_batch))
        if verbose:
            plt.plot(ELBO)
            plt.yscale("log")
            plt.show()

        acc = _evaluate(current_model, test_x_all, test_y_all, test_task_i_all, batch_size)
        if verbose:
            if use_coreset:
                print("Accuracy by the propagation model", acc)
            else:
                print("Task {:d}, Accuracy: ".format(task_i), acc)
        if not use_coreset:
            accuracies[task_i] = acc

        if use_coreset:
            # calculate prediction model: a copy of the solver fine-tuned on the coreset only
            pred_model = MultiHeadVCLSplitMNIST.new_from_prior(current_model)
            pred_opt = torch.optim.Adam(pred_model.parameters(), lr=0.001)

            ELBO = []
            for epoch in (tqdm if verbose else iter)(range(n_epochs)):
                ELBO_batch = []
                for start in range(0, coreset_x.shape[0], batch_size):
                    stop = start + batch_size
                    pred_opt.zero_grad()
                    elbo = pred_model.calculate_ELBO(x=coreset_x[start:stop],
                                                     y=coreset_y[start:stop],
                                                     task_i_mask=coreset_task_mask[start:stop],
                                                     n_particles=1,
                                                     dataset_size=coreset_x.shape[0])
                    elbo.backward()
                    nn.utils.clip_grad_value_(pred_model.parameters(), 5)
                    pred_opt.step()
                    ELBO_batch.append(elbo.item())
                ELBO.append(np.mean(ELBO_batch))
            if verbose:
                plt.plot(ELBO)
                plt.show()

            acc = _evaluate(pred_model, test_x_all, test_y_all, test_task_i_all, batch_size)
            accuracies[task_i] = acc
            if verbose:
                print("Task {:d}, Accuracy: ".format(task_i), acc)

        # The models just trained become the replay sources for the next task.
        previous_model = current_model
        old_decoder = new_decoder
        generator_models[task_i] = [new_decoder, new_encoder]
        solvers[task_i] = current_model if not use_coreset else pred_model

    return accuracies, generator_models, solvers


In [None]:
# Repeat the whole experiment ten times (quiet mode, no coreset) and store each
# repetition's accuracies, generators and solvers in its own pickle file on Drive.
for i in range(10):
    print("============================ Exp %d ============================"%i)
    accs, generator_models, solvers = run(verbose=False, use_coreset=False)
    save_path = "/content/drive/MyDrive/Advanced Machine Learning Project/GenerativeReplay/Split_MNIST_GR1_models_exp_%d.pkl"%i
    with open(save_path, "wb") as f:
        pickle.dump([accs, generator_models, solvers], f)
    print(accs)



100%|██████████| 300/300 [02:35<00:00,  1.93it/s]
100%|██████████| 300/300 [02:38<00:00,  1.90it/s]
100%|██████████| 300/300 [02:44<00:00,  1.82it/s]
100%|██████████| 300/300 [04:32<00:00,  1.10it/s]
100%|██████████| 300/300 [07:50<00:00,  1.57s/it]


{0: [0.9995271867612293], 1: [0.9995271867612293, 0.9926542605288933], 2: [0.9995271867612293, 0.990205680705191, 0.9994663820704376], 3: [0.9995271867612293, 0.9887365328109696, 0.9973319103521878, 0.9954682779456193], 4: [0.9995271867612293, 0.9867776689520078, 0.9946638207043756, 0.9944612286002014, 0.9813414019162885]}


100%|██████████| 300/300 [02:25<00:00,  2.06it/s]
100%|██████████| 300/300 [02:33<00:00,  1.95it/s]
100%|██████████| 300/300 [02:38<00:00,  1.89it/s]
100%|██████████| 300/300 [04:28<00:00,  1.12it/s]
100%|██████████| 300/300 [07:45<00:00,  1.55s/it]


{0: [1.0], 1: [0.9995271867612293, 0.9906953966699314], 2: [0.9995271867612293, 0.9867776689520078, 1.0], 3: [0.9995271867612293, 0.9843290891283056, 0.9967982924226254, 0.9974823766364552], 4: [0.9990543735224586, 0.9813907933398629, 0.996264674493063, 0.9959718026183283, 0.9773071104387292]}


100%|██████████| 300/300 [02:17<00:00,  2.18it/s]
100%|██████████| 300/300 [02:28<00:00,  2.03it/s]
100%|██████████| 300/300 [02:45<00:00,  1.82it/s]
100%|██████████| 300/300 [04:31<00:00,  1.11it/s]
100%|██████████| 300/300 [07:46<00:00,  1.55s/it]


{0: [1.0], 1: [1.0, 0.9931439764936337], 2: [1.0, 0.9892262487757101, 0.9989327641408752], 3: [1.0, 0.9862879529872673, 0.9994663820704376, 0.9954682779456193], 4: [0.9990543735224586, 0.9784524975514202, 0.9983991462113126, 0.9949647532729103, 0.9808371154815936]}


100%|██████████| 300/300 [02:19<00:00,  2.15it/s]
100%|██████████| 300/300 [02:28<00:00,  2.02it/s]
100%|██████████| 300/300 [02:36<00:00,  1.91it/s]
100%|██████████| 300/300 [04:30<00:00,  1.11it/s]
100%|██████████| 300/300 [07:48<00:00,  1.56s/it]


{0: [0.9995271867612293], 1: [0.9995271867612293, 0.9926542605288933], 2: [0.9995271867612293, 0.990205680705191, 0.9978655282817502], 3: [0.9995271867612293, 0.9862879529872673, 0.9989327641408752, 0.9964753272910373], 4: [0.9995271867612293, 0.9853085210577864, 0.9973319103521878, 0.9949647532729103, 0.9778113968734241]}


100%|██████████| 300/300 [02:18<00:00,  2.17it/s]
100%|██████████| 300/300 [02:26<00:00,  2.05it/s]
100%|██████████| 300/300 [02:37<00:00,  1.90it/s]
100%|██████████| 300/300 [04:31<00:00,  1.11it/s]
100%|██████████| 300/300 [07:45<00:00,  1.55s/it]


{0: [1.0], 1: [0.9995271867612293, 0.9911851126346719], 2: [0.9995271867612293, 0.9887365328109696, 0.9989327641408752], 3: [0.9995271867612293, 0.9872673849167483, 0.9978655282817502, 0.9984894259818731], 4: [0.9990543735224586, 0.9857982370225269, 0.9973319103521878, 0.9964753272910373, 0.9757942511346445]}


100%|██████████| 300/300 [02:21<00:00,  2.12it/s]
100%|██████████| 300/300 [02:29<00:00,  2.00it/s]
100%|██████████| 300/300 [02:38<00:00,  1.89it/s]
100%|██████████| 300/300 [04:31<00:00,  1.10it/s]
100%|██████████| 300/300 [07:46<00:00,  1.55s/it]


{0: [1.0], 1: [0.9995271867612293, 0.9906953966699314], 2: [0.9995271867612293, 0.9887365328109696, 0.9989327641408752], 3: [0.9990543735224586, 0.9887365328109696, 0.9973319103521878, 0.9959718026183283], 4: [0.9985815602836879, 0.9882468168462292, 0.996264674493063, 0.9939577039274925, 0.9793242561775088]}


100%|██████████| 300/300 [02:19<00:00,  2.15it/s]
100%|██████████| 300/300 [02:29<00:00,  2.01it/s]
100%|██████████| 300/300 [02:35<00:00,  1.93it/s]
100%|██████████| 300/300 [04:31<00:00,  1.11it/s]
100%|██████████| 300/300 [07:46<00:00,  1.56s/it]


{0: [0.9995271867612293], 1: [0.9995271867612293, 0.9906953966699314], 2: [0.9995271867612293, 0.9877571008814887, 0.9989327641408752], 3: [0.9995271867612293, 0.9872673849167483, 0.9989327641408752, 0.9959718026183283], 4: [0.9995271867612293, 0.9809010773751224, 0.9973319103521878, 0.9979859013091642, 0.9818456883509834]}


100%|██████████| 300/300 [02:18<00:00,  2.16it/s]
100%|██████████| 300/300 [02:29<00:00,  2.01it/s]
100%|██████████| 300/300 [02:37<00:00,  1.91it/s]
100%|██████████| 300/300 [04:31<00:00,  1.10it/s]
100%|██████████| 300/300 [07:47<00:00,  1.56s/it]


{0: [0.9995271867612293], 1: [0.9995271867612293, 0.9941234084231146], 2: [0.9995271867612293, 0.9906953966699314, 0.9994663820704376], 3: [0.9990543735224586, 0.9867776689520078, 0.9989327641408752, 0.9984894259818731], 4: [0.9971631205673759, 0.9833496571988247, 0.9994663820704376, 0.9974823766364552, 0.9798285426122038]}


100%|██████████| 300/300 [02:19<00:00,  2.16it/s]
100%|██████████| 300/300 [02:28<00:00,  2.03it/s]
100%|██████████| 300/300 [02:37<00:00,  1.91it/s]
100%|██████████| 300/300 [04:30<00:00,  1.11it/s]
100%|██████████| 300/300 [07:47<00:00,  1.56s/it]


{0: [0.9995271867612293], 1: [0.9995271867612293, 0.990205680705191], 2: [0.9995271867612293, 0.9882468168462292, 1.0], 3: [0.9995271867612293, 0.9867776689520078, 0.9989327641408752, 0.9959718026183283], 4: [0.9990543735224586, 0.9867776689520078, 0.9978655282817502, 0.9959718026183283, 0.9803328290468987]}


100%|██████████| 300/300 [02:17<00:00,  2.18it/s]
100%|██████████| 300/300 [02:26<00:00,  2.05it/s]
100%|██████████| 300/300 [02:34<00:00,  1.94it/s]
100%|██████████| 300/300 [04:29<00:00,  1.11it/s]
100%|██████████| 300/300 [07:46<00:00,  1.55s/it]


{0: [0.9995271867612293], 1: [1.0, 0.9906953966699314], 2: [1.0, 0.9906953966699314, 0.9994663820704376], 3: [0.9995271867612293, 0.9906953966699314, 0.9978655282817502, 0.9969788519637462], 4: [0.9995271867612293, 0.9887365328109696, 0.996264674493063, 0.9969788519637462, 0.9788199697428139]}
