In [1]:
import torch
import pickle
import types

# Wrapper function to enforce CPU loading
def load_on_cpu(func):
    def wrapper(*args, **kwargs):
        kwargs['map_location'] = torch.device('cpu')
        return func(*args, **kwargs)
    return wrapper

# Backup original torch.load
original_torch_load = torch.load

# Wrap torch.load with the load_on_cpu function
torch.load = load_on_cpu(torch.load)

try:
    with open('prototype.pickle', 'rb') as f:
        prototype_vectors = pickle.load(f)
finally:
    # Restore original torch.load
    torch.load = original_torch_load

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
prototype_vectors

{0: tensor([3.3725e-07, 8.4418e-03, 0.0000e+00, 1.1559e-03, 1.3602e-04, 9.1571e-03,
         0.0000e+00, 7.9300e-02, 7.8128e-06, 1.0846e-04, 5.0647e-04, 1.5918e-03,
         1.4557e-03, 6.7355e-06, 2.6962e-04, 2.8863e-05, 3.9983e-02, 4.1884e-02,
         7.8964e-04, 1.3372e-02, 7.3576e-03, 1.4010e-01, 1.3526e-03, 1.6335e-04,
         7.3890e-04, 1.2434e-06, 3.4584e-02, 9.9764e-03, 6.8637e-02, 8.4836e-04,
         1.1656e-01, 5.5938e-03, 6.8656e-03, 1.0524e-03, 1.3841e-01, 4.4056e-04,
         1.9179e-02, 7.0943e-02, 3.0554e-02, 8.5220e-02, 1.4570e-03, 1.7331e-03,
         1.0570e-05, 5.2899e-05, 1.3759e-05, 8.1781e-03, 5.9622e-02, 6.5664e-04,
         4.3187e-03, 0.0000e+00, 1.9021e-01, 3.6053e-02, 9.1410e-04, 2.8059e-03,
         1.3967e-03, 1.3722e-05, 1.1453e-05, 8.4138e-03, 3.3189e-03, 2.2886e-03,
         9.2720e-06, 8.2350e-05, 2.3157e-03, 1.9766e-04, 3.0168e-07, 6.9283e-03,
         0.0000e+00, 3.4030e-02, 0.0000e+00, 3.5192e-03, 0.0000e+00, 5.6836e-05,
         8.2820e-02, 1.06

In [3]:
orders = [i for i in range(10)]

In [4]:
from src.ours import *

In [5]:
class AttrDict(dict):
    def __getattr__(self, key):
        if key in self:
            return self[key]
        raise AttributeError(f"No such attribute: {key}")

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        if key in self:
            del self[key]
        else:
            raise AttributeError(f"No such attribute: {key}")

In [6]:
args = AttrDict({
    "dataset_name": "cifar10",
    "path_data": "./",
    "download_data": True,
    "model_name": "resnet18",
    "path_pretrained_model": "pretrained_model.pth",
    "path_init_params": "init_params.pth",
    "alpha_conv": 0.9,
    "num_tasks": 10,
    "num_classes": 10,
    "num_classes_first_task": 7,
    "num_iters": 3,
    "prune_batch_size": 1000,
    "batch_size": 128,
    "test_batch_size": 20,
    "train_epochs": 70,
    "retrain_epochs": 50,
    "optimizer": "adam",
    "lr_decay_type": "multistep",
    "lr": 1e-2,
    "decay_epochs_train": [20, 40, 60],
    "decay_epochs_retrain": [15, 25, 40],
    "gamma": 0.2,
    "wd": 5e-4,
    "seed": 0,
    "order_name": "default",
    "task_select_method": "max",
    "train": True,
})


In [7]:
task_labels = create_labels(
    args.num_classes, args.num_tasks, args.num_classes_first_task
)
train_dataset, test_dataset = task_construction(
            task_labels, args.dataset_name, None
)

net = init_model(args, "cpu")

net.class_task_map = create_class_task_map(
    args.num_classes_first_task, orders
)

# torch.save(net.state_dict(), args.path_init_params)

num_tasks = args.num_tasks

Files already downloaded and verified
Files already downloaded and verified


  splited_dataset.targets = torch.tensor(splited_dataset.targets)[idx]


In [8]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [9]:
start_task = 4
if start_task > 0:
    path_to_save = "/work/gu14/k36093/Ours/results/cifar10-resnet18/10_tasks/resnet18_task4_7classes_adam_0.9_it1_order_default.pth"
    net.load_state_dict(torch.load(path_to_save, map_location=device))

    net._load_masks(
        file_name="/work/gu14/k36093/Ours/results/cifar10-resnet18/10_tasks/resnet18_task4_masks_7classes_adam_0.9_it1_order_default.pth",
        num_tasks=start_task,
    )

    if start_task < num_tasks:
        net._add_mask(task_id=start_task)

In [10]:
task_id = 0
set_task(net, task_id)
train_loader, test_loader = get_loaders(
    train_dataset[task_id], test_dataset[task_id], args.batch_size
)

In [17]:
itr = iter(test_loader)
x, y = next(itr)

In [18]:
print(y[0])

tensor(3)


In [19]:
net.features(x[None, 0])

tensor([[0.5215, 0.8243, 0.4115, 0.6368, 0.5364, 0.4749, 0.7610, 0.4676, 0.6231,
         0.4378, 0.5614, 0.6735, 0.5816, 0.5610, 0.7128, 0.4367, 0.4512, 0.4831,
         0.5333, 0.4727, 0.5679, 0.4925, 0.6256, 0.5727, 0.5954, 0.5599, 0.4734,
         0.6598, 0.4406, 0.7040, 0.8637, 0.6461, 0.5910, 0.5310, 0.4815, 0.4652,
         0.4945, 0.6841, 0.5283, 0.6965, 0.4335, 0.6534, 0.6127, 0.6665, 0.5854,
         0.5670, 0.5714, 0.4699, 0.5312, 0.4674, 0.4980, 0.4335, 0.5132, 0.6577,
         0.4591, 0.5257, 0.4918, 0.8461, 0.7214, 0.8141, 0.7623, 0.5490, 0.8683,
         0.5061, 0.5396, 0.7221, 0.6151, 0.6230, 1.0068, 0.6668, 0.6623, 0.8237,
         0.5237, 0.5849, 0.6144, 0.5976, 0.4589, 0.5073, 0.4890, 0.7747, 0.5055,
         0.4277, 0.4612, 0.4757, 0.4512, 0.5087, 0.7417, 0.4852, 0.5446, 0.7542,
         0.8887, 0.7693, 0.7346, 0.5309, 0.5665, 0.5200, 0.5192, 0.6225, 0.7196,
         0.6766, 0.4728, 0.5517, 0.6123, 0.5372, 0.5972, 0.5785, 0.5543, 0.5105,
         0.5482, 0.7086, 0.5

In [33]:
start_task = 1
if start_task > 0:
    path_to_save = "/work/gu14/k36093/Ours/results/cifar10-resnet18/10_tasks/resnet18_task1_7classes_adam_0.9_it1_order_default.pth"
    net.load_state_dict(torch.load(path_to_save, map_location=device))

    net._load_masks(
        file_name="/work/gu14/k36093/Ours/results/cifar10-resnet18/10_tasks/resnet18_task1_masks_7classes_adam_0.9_it1_order_default.pth",
        num_tasks=start_task,
    )

    if start_task < num_tasks:
        net._add_mask(task_id=start_task)

set_task(net, 0)
print(net.predict(x[:10], prototype_vectors), y[:10])

DIST tensor([0.7172, 0.4216, 1.4392, 0.7304, 0.6988, 0.9630, 0.8729, 0.9410, 0.4016,
        0.8998], grad_fn=<SumBackward1>)
DIST tensor([0.6774, 0.5327, 1.6009, 0.8791, 0.4398, 1.4391, 1.0455, 0.2667, 1.1131,
        1.2546], grad_fn=<SumBackward1>)
DIST tensor([0.5869, 1.0883, 1.1496, 0.2875, 0.8486, 0.7139, 0.3835, 1.2361, 0.4687,
        0.4675], grad_fn=<SumBackward1>)
DIST tensor([0.5738, 1.4161, 1.4610, 0.3101, 1.0471, 1.2024, 0.2671, 1.3091, 0.9172,
        0.6778], grad_fn=<SumBackward1>)
DIST tensor([0.6820, 1.2832, 1.0858, 0.4079, 1.0295, 0.9700, 0.4913, 1.2992, 0.5907,
        0.3728], grad_fn=<SumBackward1>)
DIST tensor([0.7413, 1.6087, 1.8059, 0.4130, 1.2788, 1.4105, 0.3407, 1.4755, 0.9725,
        0.6960], grad_fn=<SumBackward1>)
DIST tensor([0.3561, 1.3650, 0.9874, 0.1833, 0.9749, 0.9056, 0.4049, 1.1934, 0.8895,
        0.6180], grad_fn=<SumBackward1>)
DIST tensor([35.3145, 25.1094, 40.6949, 14.9114,  7.1820, 40.3127, 23.0736, 25.3350,
        31.6882, 33.8653], grad_f

In [35]:
start_task = 2
if start_task > 0:
    path_to_save = "/work/gu14/k36093/Ours/results/cifar10-resnet18/10_tasks/resnet18_task2_7classes_adam_0.9_it1_order_default.pth"
    net.load_state_dict(torch.load(path_to_save, map_location=device))

    net._load_masks(
        file_name="/work/gu14/k36093/Ours/results/cifar10-resnet18/10_tasks/resnet18_task2_masks_7classes_adam_0.9_it1_order_default.pth",
        num_tasks=start_task,
    )

    if start_task < num_tasks:
        net._add_mask(task_id=start_task)

set_task(net, 0)
print(net.predict(x[:10], prototype_vectors), y[:10])

DIST tensor([0.7172, 0.4216, 1.4392, 0.7304, 0.6988, 0.9630, 0.8729, 0.9410, 0.4016,
        0.8998], grad_fn=<SumBackward1>)
DIST tensor([0.6774, 0.5327, 1.6009, 0.8791, 0.4398, 1.4391, 1.0455, 0.2667, 1.1131,
        1.2546], grad_fn=<SumBackward1>)
DIST tensor([0.5869, 1.0883, 1.1496, 0.2875, 0.8486, 0.7139, 0.3835, 1.2361, 0.4687,
        0.4675], grad_fn=<SumBackward1>)
DIST tensor([0.5738, 1.4161, 1.4610, 0.3101, 1.0471, 1.2024, 0.2671, 1.3091, 0.9172,
        0.6778], grad_fn=<SumBackward1>)
DIST tensor([0.6820, 1.2832, 1.0858, 0.4079, 1.0295, 0.9700, 0.4913, 1.2992, 0.5907,
        0.3728], grad_fn=<SumBackward1>)
DIST tensor([0.7413, 1.6087, 1.8059, 0.4130, 1.2788, 1.4105, 0.3407, 1.4755, 0.9725,
        0.6960], grad_fn=<SumBackward1>)
DIST tensor([0.3561, 1.3650, 0.9874, 0.1833, 0.9749, 0.9056, 0.4049, 1.1934, 0.8895,
        0.6180], grad_fn=<SumBackward1>)
DIST tensor([1.0004, 1.1869, 0.9998, 1.0017, 2.6145, 1.0002, 1.0043, 1.0060, 1.0023,
        0.9816], grad_fn=<SumBack

In [36]:
start_task = 3
if start_task > 0:
    path_to_save = "/work/gu14/k36093/Ours/results/cifar10-resnet18/10_tasks/resnet18_task3_7classes_adam_0.9_it1_order_default.pth"
    net.load_state_dict(torch.load(path_to_save, map_location=device))

    net._load_masks(
        file_name="/work/gu14/k36093/Ours/results/cifar10-resnet18/10_tasks/resnet18_task3_masks_7classes_adam_0.9_it1_order_default.pth",
        num_tasks=start_task,
    )

    if start_task < num_tasks:
        net._add_mask(task_id=start_task)

set_task(net, 0)
net.predict(x[:10], prototype_vectors), y[:10]

DIST tensor([0.7172, 0.4216, 1.4392, 0.7304, 0.6988, 0.9630, 0.8729, 0.9410, 0.4016,
        0.8998], grad_fn=<SumBackward1>)
DIST tensor([0.6774, 0.5327, 1.6009, 0.8791, 0.4398, 1.4391, 1.0455, 0.2667, 1.1131,
        1.2546], grad_fn=<SumBackward1>)
DIST tensor([0.5869, 1.0883, 1.1496, 0.2875, 0.8486, 0.7139, 0.3835, 1.2361, 0.4687,
        0.4675], grad_fn=<SumBackward1>)
DIST tensor([0.5738, 1.4161, 1.4610, 0.3101, 1.0471, 1.2024, 0.2671, 1.3091, 0.9172,
        0.6778], grad_fn=<SumBackward1>)
DIST tensor([0.6820, 1.2832, 1.0858, 0.4079, 1.0295, 0.9700, 0.4913, 1.2992, 0.5907,
        0.3728], grad_fn=<SumBackward1>)
DIST tensor([0.7413, 1.6087, 1.8059, 0.4130, 1.2788, 1.4105, 0.3407, 1.4755, 0.9725,
        0.6960], grad_fn=<SumBackward1>)
DIST tensor([0.3561, 1.3650, 0.9874, 0.1833, 0.9749, 0.9056, 0.4049, 1.1934, 0.8895,
        0.6180], grad_fn=<SumBackward1>)
DIST tensor([1.0004, 1.1869, 0.9998, 1.0017, 2.6145, 1.0002, 1.0043, 1.0060, 1.0023,
        0.9816], grad_fn=<SumBack

(tensor([6., 0., 8., 6., 1., 8., 3., 1., 0., 4.]),
 tensor([3, 0, 6, 6, 1, 6, 3, 1, 0, 5]))

In [37]:
start_task = 4
if start_task > 0:
    path_to_save = "/work/gu14/k36093/Ours/results/cifar10-resnet18/10_tasks/resnet18_task4_7classes_adam_0.9_it1_order_default.pth"
    net.load_state_dict(torch.load(path_to_save, map_location=device))

    net._load_masks(
        file_name="/work/gu14/k36093/Ours/results/cifar10-resnet18/10_tasks/resnet18_task4_masks_7classes_adam_0.9_it1_order_default.pth",
        num_tasks=start_task,
    )

    if start_task < num_tasks:
        net._add_mask(task_id=start_task)

set_task(net, 0)
net.predict(x[:10], prototype_vectors), y[:10]

DIST tensor([0.7172, 0.4216, 1.4392, 0.7304, 0.6988, 0.9630, 0.8729, 0.9410, 0.4016,
        0.8998], grad_fn=<SumBackward1>)
DIST tensor([0.6774, 0.5327, 1.6009, 0.8791, 0.4398, 1.4391, 1.0455, 0.2667, 1.1131,
        1.2546], grad_fn=<SumBackward1>)
DIST tensor([0.5869, 1.0883, 1.1496, 0.2875, 0.8486, 0.7139, 0.3835, 1.2361, 0.4687,
        0.4675], grad_fn=<SumBackward1>)
DIST tensor([0.5738, 1.4161, 1.4610, 0.3101, 1.0471, 1.2024, 0.2671, 1.3091, 0.9172,
        0.6778], grad_fn=<SumBackward1>)
DIST tensor([0.6820, 1.2832, 1.0858, 0.4079, 1.0295, 0.9700, 0.4913, 1.2992, 0.5907,
        0.3728], grad_fn=<SumBackward1>)
DIST tensor([0.7413, 1.6087, 1.8059, 0.4130, 1.2788, 1.4105, 0.3407, 1.4755, 0.9725,
        0.6960], grad_fn=<SumBackward1>)
DIST tensor([0.3561, 1.3650, 0.9874, 0.1833, 0.9749, 0.9056, 0.4049, 1.1934, 0.8895,
        0.6180], grad_fn=<SumBackward1>)
DIST tensor([1.0004, 1.1869, 0.9998, 1.0017, 2.6145, 1.0002, 1.0043, 1.0060, 1.0023,
        0.9816], grad_fn=<SumBack

(tensor([9., 9., 9., 6., 9., 9., 9., 9., 9., 9.]),
 tensor([3, 0, 6, 6, 1, 6, 3, 1, 0, 5]))

DIST tensor([0.7172, 0.4216, 1.4392, 0.7304, 0.6988, 0.9630, 0.8729, 0.9410, 0.4016,
        0.8998], grad_fn=<SumBackward1>)
DIST tensor([0.6774, 0.5327, 1.6009, 0.8791, 0.4398, 1.4391, 1.0455, 0.2667, 1.1131,
        1.2546], grad_fn=<SumBackward1>)
DIST tensor([0.5869, 1.0883, 1.1496, 0.2875, 0.8486, 0.7139, 0.3835, 1.2361, 0.4687,
        0.4675], grad_fn=<SumBackward1>)
DIST tensor([0.5738, 1.4161, 1.4610, 0.3101, 1.0471, 1.2024, 0.2671, 1.3091, 0.9172,
        0.6778], grad_fn=<SumBackward1>)
DIST tensor([0.6820, 1.2832, 1.0858, 0.4079, 1.0295, 0.9700, 0.4913, 1.2992, 0.5907,
        0.3728], grad_fn=<SumBackward1>)
DIST tensor([0.7413, 1.6087, 1.8059, 0.4130, 1.2788, 1.4105, 0.3407, 1.4755, 0.9725,
        0.6960], grad_fn=<SumBackward1>)
DIST tensor([0.3561, 1.3650, 0.9874, 0.1833, 0.9749, 0.9056, 0.4049, 1.1934, 0.8895,
        0.6180], grad_fn=<SumBackward1>)
DIST tensor([1.0004, 1.1869, 0.9998, 1.0017, 2.6145, 1.0002, 1.0043, 1.0060, 1.0023,
        0.9816], grad_fn=<SumBack

(tensor([6., 0., 8., 6., 1., 8., 3., 1., 0., 4.]),
 tensor([3, 0, 6, 6, 1, 6, 3, 1, 0, 5]))

In [29]:
for i in range(10):
    print(net.features(x[None, i]), "norm: \n", torch.norm(net.features(x[None, i])))
print(prototype_vectors[9])

tensor([[6.3251e-03, 3.8589e-02, 8.1147e-02, 4.3196e-01, 6.2829e-02, 5.1588e-02,
         1.0488e-01, 0.0000e+00, 1.5152e-02, 2.2970e-01, 8.3436e-02, 1.5990e-01,
         5.7155e-02, 1.4794e-01, 3.2300e-01, 5.4602e-01, 4.4847e-02, 8.1612e-02,
         3.3397e-01, 2.4109e-02, 2.5832e-02, 8.9007e-02, 2.7916e-02, 1.6444e-01,
         9.4616e-02, 4.4236e-01, 1.7487e-02, 5.0395e-01, 5.1731e-02, 4.8420e-02,
         2.3683e-02, 9.9113e-02, 8.8667e-02, 1.3848e-01, 0.0000e+00, 1.3501e-01,
         2.3755e-03, 1.8176e-02, 8.7913e-02, 0.0000e+00, 1.2024e-01, 6.5833e-03,
         9.8761e-02, 1.9460e-01, 1.1841e-01, 5.6594e-02, 8.6199e-03, 1.1701e-02,
         9.8845e-02, 9.8862e-02, 3.1300e-02, 3.4562e-03, 2.1338e-01, 2.2038e-01,
         9.3654e-02, 4.3666e-02, 4.8750e-01, 4.4956e-02, 1.5754e-01, 8.0282e-02,
         1.5849e-01, 1.4913e-01, 6.8854e-02, 1.5732e-01, 3.4811e-01, 3.9309e-02,
         5.8929e-02, 2.1516e-02, 3.8995e-02, 6.2342e-03, 1.3217e-01, 2.3498e-01,
         2.0701e-02, 2.0391e

In [22]:
start_task = 4
if start_task > 0:
    path_to_save = "/work/gu14/k36093/Ours/results/cifar10-resnet18/10_tasks/resnet18_task4_7classes_adam_0.9_it1_order_default.pth"
    net.load_state_dict(torch.load(path_to_save, map_location=device))

    net._load_masks(
        file_name="/work/gu14/k36093/Ours/results/cifar10-resnet18/10_tasks/resnet18_task4_masks_7classes_adam_0.9_it1_order_default.pth",
        num_tasks=start_task,
    )

    if start_task < num_tasks:
        net._add_mask(task_id=start_task)

set_task(net, 0)
print(net.predict(x[:10], prototype_vectors), y[:10])

set_task(net, net.class_task_map[4])
print(f"x_feature; mask{net.class_task_map[4]}\n", net.features(x[None, 1]))
print(f"prototype; 4\n", prototype_vectors[4])
print("dist;\n", torch.sum(net.features(x[None, 1]) - prototype_vectors[4])**2)

set_task(net, net.class_task_map[9])
print(f"x_feature; mask{net.class_task_map[9]}\n", net.features(x[None, 1]))
print(f"prototype; 9\n", prototype_vectors[9])
print("dist;\n", torch.sum(net.features(x[None, 1]) - prototype_vectors[9])**2)

DIST tensor([0.7172, 0.4216, 1.4392, 0.7304, 0.6988, 0.9630, 0.8729, 0.9410, 0.4016,
        0.8998], grad_fn=<SumBackward1>)
DIST tensor([0.6774, 0.5327, 1.6009, 0.8791, 0.4398, 1.4391, 1.0455, 0.2667, 1.1131,
        1.2546], grad_fn=<SumBackward1>)
DIST tensor([0.5869, 1.0883, 1.1496, 0.2875, 0.8486, 0.7139, 0.3835, 1.2361, 0.4687,
        0.4675], grad_fn=<SumBackward1>)
DIST tensor([0.5738, 1.4161, 1.4610, 0.3101, 1.0471, 1.2024, 0.2671, 1.3091, 0.9172,
        0.6778], grad_fn=<SumBackward1>)
DIST tensor([0.6820, 1.2832, 1.0858, 0.4079, 1.0295, 0.9700, 0.4913, 1.2992, 0.5907,
        0.3728], grad_fn=<SumBackward1>)
DIST tensor([0.7413, 1.6087, 1.8059, 0.4130, 1.2788, 1.4105, 0.3407, 1.4755, 0.9725,
        0.6960], grad_fn=<SumBackward1>)
DIST tensor([0.3561, 1.3650, 0.9874, 0.1833, 0.9749, 0.9056, 0.4049, 1.1934, 0.8895,
        0.6180], grad_fn=<SumBackward1>)
DIST tensor([1.0004, 1.1869, 0.9998, 1.0017, 2.6145, 1.0002, 1.0043, 1.0060, 1.0023,
        0.9816], grad_fn=<SumBack

In [137]:
for p in prototype_vectors.values():
    print(p.norm())

tensor(0.8429)
tensor(0.8614)
tensor(0.8686)
tensor(0.8994)
tensor(0.8969)
tensor(0.9520)
tensor(0.8835)
tensor(1.0147)
tensor(1.0231)
tensor(1.0507)
