In [2]:
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
# NOTE(review): the star import also supplies project helpers used below (e.g. `get_subset`, `CLBruno`);
# prefer importing those names explicitly. It stays above the remaining imports so that, as before,
# those imports take precedence over any same-named star exports.
from CL_BRUNO_TMLR import *
import pandas as pd
from torch.utils.data import Subset
import torchvision
import torchvision.transforms as transforms
from tiny_imagenet_torch import TinyImageNet
# Run on the GPU when one is available, otherwise fall back to the CPU so the notebook still executes.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed the torch and numpy global RNGs (torch.randperm is used for the task shuffles below).
SEED = 314159
torch.manual_seed(SEED)
np.random.seed(SEED)

In [2]:
# Per-channel normalisation shared by the train and test pipelines (defined once instead of twice).
# NOTE(review): these values look like the commonly used CIFAR-100 statistics rather than Tiny ImageNet's;
# confirm they are intended. Changing them would invalidate previously cached feature CSVs.
NORM_MEAN = [0.5070, 0.4865, 0.4409]
NORM_STD = [0.2673, 0.2564, 0.2762]
TINYIMG_ROOT = './tinyimagenet'

# Training pipeline: random horizontal/vertical flips before tensor conversion and normalisation.
# NOTE(review): tinyimg_train is also what the feature cache is built from, so cached train features
# include random flips; confirm that is intended.
preprocess_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Test pipeline: deterministic, no augmentation.
preprocess_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# download=False: the dataset must already be present under TINYIMG_ROOT.
tinyimg_train = TinyImageNet(root=TINYIMG_ROOT, train=True, download=False, transform=preprocess_train)
tinyimg_test = TinyImageNet(root=TINYIMG_ROOT, train=False, download=False, transform=preprocess_test)

In [3]:
# Loaders over the first 20 classes (labels 0-19), the data `test_model.init_extractor(my_loader, ...)`
# would use to pretrain the feature extractor.
tinyimg_01_train = get_subset(list(range(20)), tinyimg_train)
tinyimg_01_test = get_subset(list(range(20)), tinyimg_test)
tinyimg_01_train_loader = DataLoader(tinyimg_01_train, batch_size=128, shuffle=True)
# Validation needs no shuffling; a fixed order keeps evaluation passes deterministic.
tinyimg_01_test_loader = DataLoader(tinyimg_01_test, batch_size=128, shuffle=False)
dataset_sizes = {'train': len(tinyimg_01_train), 'val': len(tinyimg_01_test)}
my_loader = {'train': tinyimg_01_train_loader, 'val': tinyimg_01_test_loader}

  idx = torch.tensor(dataset.targets) == i


In [3]:
# Build the CLBruno model from its hyperparameters (see CL_BRUNO_TMLR for their meaning)
# and move it onto the selected device.
bruno_kwargs = dict(
    x_dim=512, y_dim=128, task_dim=1, cond_dim=129,
    conv=False, task_num=1, y_cat_num=[20], single_task=True,
    n_dense_block=6, n_hidden_dense=128, activation=nn.Tanh(),
    mu_init=0., var_init=1., corr_init=0.1,
    extractor=True, init_out=20, device=device,
)
test_model = CLBruno(**bruno_kwargs).to(device)
# Extractor pretraining is currently disabled:
# test_model.init_extractor(my_loader, 50)

In [4]:
# Features for every Tiny ImageNet train/test image, cached as CSV so the expensive extractor
# pass only runs when a cache file is missing.
tinyimg_loader_train = DataLoader(tinyimg_train, batch_size=2000)
tinyimg_loader_test = DataLoader(tinyimg_test, batch_size=2000)

FEATURE_CSV_TRAIN = 'tinyimg_pretrained_feature_train.csv'
FEATURE_CSV_TEST = 'tinyimg_pretrained_feature_test.csv'


def extract_features(model, loader):
    """Apply ``model.my_extractor`` to every batch of ``loader`` and stack the outputs on the CPU.

    The loaders above are not shuffled, so row k of the result comes from sample k of the dataset.
    Pairing the cached rows with ``dataset.targets`` below relies on that ordering.
    """
    # NOTE(review): the extractor runs in whatever train/eval mode `model` is currently in; confirm
    # eval mode is not required for caching (e.g. if the extractor contains BatchNorm/Dropout).
    with torch.no_grad():
        # Move each batch back to the CPU immediately so GPU memory does not grow with the dataset.
        return torch.vstack([model.my_extractor(x.to(device)).cpu() for x, _ in loader])


for csv_path, loader in ((FEATURE_CSV_TRAIN, tinyimg_loader_train), (FEATURE_CSV_TEST, tinyimg_loader_test)):
    if not os.path.exists(csv_path):
        pd.DataFrame(extract_features(test_model, loader).numpy()).to_csv(csv_path, index=False)

transformed_x_train = torch.tensor(pd.read_csv(FEATURE_CSV_TRAIN).to_numpy(), dtype=torch.float)
transformed_x_test = torch.tensor(pd.read_csv(FEATURE_CSV_TEST).to_numpy(), dtype=torch.float)

# `targets` may already be a tensor, and torch.tensor(tensor) warns about copy-construction;
# torch.as_tensor accepts lists and tensors alike without that warning.
transformed_y_test = torch.as_tensor(tinyimg_test.targets, dtype=torch.long)
transformed_y_train = torch.as_tensor(tinyimg_train.targets, dtype=torch.long)

  transformed_y_test = torch.tensor(tinyimg_test.targets, dtype=torch.long)
  transformed_y_train = torch.tensor(tinyimg_train.targets, dtype=torch.long)


In [5]:
# Class-incremental split into 10 tasks: task t holds labels 20*t .. 20*t+19.
# The features are already flat vectors, so no layout change happens here. Each task's train set
# is shuffled once; its test set keeps dataset order. All tensors are moved to `device`.
task_split = [list(range(i * 20, (i + 1) * 20)) for i in range(10)]
CIL_tinyimg_train = {}
CIL_tinyimg_test = {}
y_train_np = transformed_y_train.numpy()
y_test_np = transformed_y_test.numpy()
for t, classes in enumerate(task_split):
    # Vectorised membership test instead of a per-sample Python loop. flatnonzero returns the
    # indices in ascending order and always as an integer array, even when no sample matches.
    train_idx = np.flatnonzero(np.isin(y_train_np, classes))
    perm = torch.randperm(len(train_idx))
    CIL_tinyimg_train['X_{}'.format(t)] = transformed_x_train[train_idx][perm].to(device)
    CIL_tinyimg_train['y_{}'.format(t)] = transformed_y_train[train_idx][perm].to(device)
    test_idx = np.flatnonzero(np.isin(y_test_np, classes))
    CIL_tinyimg_test['X_{}'.format(t)] = transformed_x_test[test_idx].to(device)
    CIL_tinyimg_test['y_{}'.format(t)] = transformed_y_test[test_idx].to(device)

# With functional regularization (alignment regularizer left at its default)

In [3]:
# Train on the first task (labels 0-19). The third argument is an all-zeros long tensor,
# presumably the per-sample task ids of this single-task setup.
train_loss, test_loss = test_model.train_init(CIL_tinyimg_train['X_0'], CIL_tinyimg_train['y_0'],
                                              torch.zeros(CIL_tinyimg_train['y_0'].shape[0], dtype=torch.long).to(device),
                                              batch_size=128, epoch=30, weight_decay=0., lr=1e-3, embedding_reg=0.1)

# Misclassification rate on the first task's test split. `prediction` is called one sample at a
# time with task id 0; its first return value is argmax-ed over the 20 classes seen so far.
X_test_0 = CIL_tinyimg_test['X_0']
y_test_0 = CIL_tinyimg_test['y_0']
N = len(y_test_0)
q = torch.zeros((N, 20))
for i in range(N):
    q_i, _ = test_model.prediction(X_test_0[i], 0)
    q[i] = q_i.cpu()
print(torch.sum(q.argmax(1) != y_test_0.cpu()) / N)


tensor(0.3530)


In [11]:
# Class-incremental learning over tasks 1..9 (20 new labels each). After each task, report the
# misclassification rate averaged over the test splits of all tasks seen so far. Every task is
# weighted equally: each task's error rate is summed, then the sum is divided by the number of tasks.
batch_sizes = [128] * 9  # training batch size for tasks 1..9
for batch_id in range(1, 10):
    # alignment_reg is left at its default here; pass alignment_reg=0. to turn the alignment regulariser off.
    train_loss1, test_loss1, reg_loss1 = test_model.train_continual_task(X_new=CIL_tinyimg_train['X_{}'.format(batch_id)],
                                                                         y_new=CIL_tinyimg_train['y_{}'.format(batch_id)],
                                                                         task_id=0, epoch=30, batch_size=int(batch_sizes[batch_id - 1]),
                                                                         weight_decay=0., lr=1e-3, n_pseudo=128,
                                                                         embedding_reg=0.1)

    n_classes_seen = (batch_id + 1) * 20
    err = 0.  # running sum of per-task misclassification rates
    for hist_id in range(batch_id + 1):
        X_hist = CIL_tinyimg_test['X_{}'.format(hist_id)]
        y_hist = CIL_tinyimg_test['y_{}'.format(hist_id)]
        N = len(y_hist)
        # One row of model output per test sample, over every class seen so far.
        q = torch.zeros((N, n_classes_seen))
        for i in range(N):
            q_i, _ = test_model.prediction(X_hist[i], 0)
            q[i] = q_i.cpu()
        err += torch.sum(q.argmax(1) != y_hist.cpu()) / N
    print('Batch {}'.format(batch_id))
    print('misclassification: {}'.format(np.round(err.numpy() / (batch_id + 1), 3)))



Batch 1
misclassification: 0.521
Batch 2
misclassification: 0.605
Batch 3
misclassification: 0.693
Batch 4
misclassification: 0.774
Batch 5
misclassification: 0.838
Batch 6
misclassification: 0.867
Batch 7
misclassification: 0.869
Batch 8
misclassification: 0.871
Batch 9
misclassification: 0.883


# Without functional regularization (`alignment_reg=0.`)

In [10]:
# Ablation: rebuild the model and repeat the whole pipeline (first task, then tasks 1..9) with
# the alignment regulariser switched off (alignment_reg=0.).
# NOTE(review): the torch RNG is not re-seeded before this rebuild, so this run may not start from
# the same random state as the regularised run.
test_model = CLBruno(x_dim=512, y_dim=128, task_dim=1, cond_dim=129, conv=False, task_num=1,
                     y_cat_num=[20], single_task=True, n_dense_block=6, n_hidden_dense=128,
                     activation=nn.Tanh(), mu_init=0., var_init=1., corr_init=0.1, extractor=True, init_out=20, device=device)
test_model = test_model.to(device)

# First task (labels 0-19); the all-zeros long tensor is presumably the per-sample task ids.
train_loss, test_loss = test_model.train_init(CIL_tinyimg_train['X_0'], CIL_tinyimg_train['y_0'],
                                              torch.zeros(CIL_tinyimg_train['y_0'].shape[0], dtype=torch.long).to(device),
                                              batch_size=128, epoch=30, weight_decay=0., lr=1e-3, embedding_reg=0.1)

batch_sizes = [128] * 9  # training batch size for tasks 1..9
for batch_id in range(1, 10):
    # alignment_reg=0. turns the alignment regulariser off.
    train_loss1, test_loss1, reg_loss1 = test_model.train_continual_task(X_new=CIL_tinyimg_train['X_{}'.format(batch_id)],
                                                                         y_new=CIL_tinyimg_train['y_{}'.format(batch_id)],
                                                                         task_id=0, epoch=30, batch_size=int(batch_sizes[batch_id - 1]),
                                                                         weight_decay=0., lr=1e-3, n_pseudo=128,
                                                                         embedding_reg=0.1, alignment_reg=0.)

    n_classes_seen = (batch_id + 1) * 20
    err = 0.  # running sum of per-task misclassification rates (tasks weighted equally)
    for hist_id in range(batch_id + 1):
        X_hist = CIL_tinyimg_test['X_{}'.format(hist_id)]
        y_hist = CIL_tinyimg_test['y_{}'.format(hist_id)]
        N = len(y_hist)
        # One row of model output per test sample, over every class seen so far.
        q = torch.zeros((N, n_classes_seen))
        for i in range(N):
            q_i, _ = test_model.prediction(X_hist[i], 0)
            q[i] = q_i.cpu()
        err += torch.sum(q.argmax(1) != y_hist.cpu()) / N
    print('Batch {}'.format(batch_id))
    print('misclassification: {}'.format(np.round(err.numpy() / (batch_id + 1), 3)))


Batch 1
misclassification: 0.558
Batch 2
misclassification: 0.665
Batch 3
misclassification: 0.743
Batch 4
misclassification: 0.824
Batch 5
misclassification: 0.888
Batch 6
misclassification: 0.892
Batch 7
misclassification: 0.913
Batch 8
misclassification: 0.934
Batch 9
misclassification: 0.944
