In [33]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
from collections import defaultdict
from datetime import date
from functools import partial
import os
from pathlib import Path
from pprint import pprint
from typing import Any
import torch

from src.sign_visualizer.core import (
    choose_task,
    get_activations,
    load_and_clean,
    load_extractor,
)
from src.sign_visualizer.dataset_interface import (
    ContinualDataset,
    ContinualDatasetConfig,
)


def is_sign_changed(output_1: torch.Tensor, output_2: torch.Tensor):
    return (output_1.mean(0) * output_2.mean(0)) < 0


def is_negative(output_1: torch.Tensor):
    return output_1.mean(0) < 0


def save(container: Any, path: str) -> None:
    torch.save(container, path)


cifar_10_paths = [
    "results/2024/03.14/cifar10_fixed_finetuning_fc$:var_0.64:cov_12.8",
    "results/2024/03.14/cifar10_fixed_finetuning_noreg",
    "results/2024/03.21/cifar10_fixed_finetuning",
    "results/2024/03.22/cifar10_fixed_lwf",
    "results/2024/03.14/cifar10_fixed_finetuning_.*after_relu|fc$:var_0.64:cov_12.8",
]
cifar_100_paths = [
    # "results/2024/03.22/cifar100_fixed_finetuning",
    # "results/2024/03.21/cifar100_fixed_finetuning",
    # "results/2024/03.21/cifar100_fixed_lwf",
    # # "results/2024/03.07/cifar100_fixed_finetuning_fc",
    # "results/2024/03.21/cifar100_fixed_ewc",
    "results/2024/03.26/cifar100_fixed_finetuning_reg_deeper",
    "results/2024/03.26/cifar100_fixed_finetuning_reg_fc",
]

In [3]:
datasets = {
    "cifar100": {"num_tasks": 5, "num_classes": 20, "paths": cifar_100_paths},
    "cifar10": {"num_tasks": 5, "num_classes": 2, "paths": cifar_10_paths},
}
DATASET = "cifar100"
NUM_TASKS = datasets[DATASET]["num_tasks"]
NUM_CLASSES = datasets[DATASET]["num_classes"]
DEVICE = "cuda"
PATHS = datasets[DATASET]["paths"]

for path in PATHS:
    match = list(Path(path).rglob("args*.txt"))
    assert len(match) == 1, "Number of matched files is not equal to 1"

    # Open the file and grep for "no_last_relu"
    file_path = match[0]
    with open(file_path, "r") as file:
        assert "no_last_relu" in file.read(), file_path


for path in PATHS:
    model = partial(load_extractor, path=path, load_and_clean=load_and_clean)
    config = ContinualDatasetConfig(f"{DATASET}_fixed", NUM_TASKS)
    cl_dataset = ContinualDataset(config)
    dataloader = partial(
        torch.utils.data.DataLoader,
        batch_size=128,
        num_workers=1,
        drop_last=True,
        pin_memory=True,
    )

    results = {}
    for task_id in range(NUM_TASKS):
        weights = torch.load(choose_task(path, task_id))[f"heads.{task_id}.weight"]
        test_0 = cl_dataset[task_id, "test"]
        test_0_dataloader = dataloader(test_0)
        activations_0 = get_activations(
            model(task=task_id), test_0_dataloader, device=DEVICE
        )

        changes = []

        for i in range(task_id + 1, NUM_TASKS):
            activations_1 = get_activations(
                model(task=i), test_0_dataloader, device=DEVICE
            )
            # torch.load(path)
            labels_signs = {}
            for label in activations_0.keys():
                d_changed = is_sign_changed(activations_0[label], activations_1[label])
                labels_signs[label] = d_changed

            sorted_signs = torch.stack(
                [value for key, value in sorted(labels_signs.items())]
            )

            changed_mass = (
                sorted_signs * weights.abs() / weights.abs().sum(1).unsqueeze(1)
            ).sum(1)
            changes.append(changed_mass.detach().cpu())

        results[task_id] = changes

    save_path = Path("sign_mass_results", *Path(path).parts[1:])
    os.makedirs(save_path, exist_ok=True)

    save(
        results,
        os.path.join(save_path, "sign.pkl"),
    )
    pprint(results)

Files already downloaded and verified
Files already downloaded and verified
_IncompatibleKeys(missing_keys=[], unexpected_keys=['heads.0.weight', 'heads.0.bias'])


100%|██████████| 15/15 [00:08<00:00,  1.72it/s]


_IncompatibleKeys(missing_keys=[], unexpected_keys=['heads.0.weight', 'heads.0.bias', 'heads.1.weight', 'heads.1.bias'])


100%|██████████| 15/15 [00:01<00:00,  7.87it/s]


_IncompatibleKeys(missing_keys=[], unexpected_keys=['heads.0.weight', 'heads.0.bias', 'heads.1.weight', 'heads.1.bias', 'heads.2.weight', 'heads.2.bias'])


100%|██████████| 15/15 [00:01<00:00,  8.10it/s]


_IncompatibleKeys(missing_keys=[], unexpected_keys=['heads.0.weight', 'heads.0.bias', 'heads.1.weight', 'heads.1.bias', 'heads.2.weight', 'heads.2.bias', 'heads.3.weight', 'heads.3.bias'])


100%|██████████| 15/15 [00:01<00:00,  8.33it/s]


_IncompatibleKeys(missing_keys=[], unexpected_keys=['heads.0.weight', 'heads.0.bias', 'heads.1.weight', 'heads.1.bias', 'heads.2.weight', 'heads.2.bias', 'heads.3.weight', 'heads.3.bias', 'heads.4.weight', 'heads.4.bias'])


100%|██████████| 15/15 [00:01<00:00,  7.81it/s]


_IncompatibleKeys(missing_keys=[], unexpected_keys=['heads.0.weight', 'heads.0.bias', 'heads.1.weight', 'heads.1.bias'])


100%|██████████| 15/15 [00:01<00:00,  8.66it/s]


_IncompatibleKeys(missing_keys=[], unexpected_keys=['heads.0.weight', 'heads.0.bias', 'heads.1.weight', 'heads.1.bias', 'heads.2.weight', 'heads.2.bias'])


100%|██████████| 15/15 [00:01<00:00,  8.98it/s]


_IncompatibleKeys(missing_keys=[], unexpected_keys=['heads.0.weight', 'heads.0.bias', 'heads.1.weight', 'heads.1.bias', 'heads.2.weight', 'heads.2.bias', 'heads.3.weight', 'heads.3.bias'])


100%|██████████| 15/15 [00:01<00:00,  9.11it/s]


_IncompatibleKeys(missing_keys=[], unexpected_keys=['heads.0.weight', 'heads.0.bias', 'heads.1.weight', 'heads.1.bias', 'heads.2.weight', 'heads.2.bias', 'heads.3.weight', 'heads.3.bias', 'heads.4.weight', 'heads.4.bias'])


100%|██████████| 15/15 [00:02<00:00,  6.62it/s]


_IncompatibleKeys(missing_keys=[], unexpected_keys=['heads.0.weight', 'heads.0.bias', 'heads.1.weight', 'heads.1.bias', 'heads.2.weight', 'heads.2.bias'])


100%|██████████| 15/15 [00:02<00:00,  6.63it/s]


_IncompatibleKeys(missing_keys=[], unexpected_keys=['heads.0.weight', 'heads.0.bias', 'heads.1.weight', 'heads.1.bias', 'heads.2.weight', 'heads.2.bias', 'heads.3.weight', 'heads.3.bias'])


100%|██████████| 15/15 [00:01<00:00,  8.25it/s]


_IncompatibleKeys(missing_keys=[], unexpected_keys=['heads.0.weight', 'heads.0.bias', 'heads.1.weight', 'heads.1.bias', 'heads.2.weight', 'heads.2.bias', 'heads.3.weight', 'heads.3.bias', 'heads.4.weight', 'heads.4.bias'])


100%|██████████| 15/15 [00:01<00:00,  7.92it/s]


_IncompatibleKeys(missing_keys=[], unexpected_keys=['heads.0.weight', 'heads.0.bias', 'heads.1.weight', 'heads.1.bias', 'heads.2.weight', 'heads.2.bias', 'heads.3.weight', 'heads.3.bias'])


100%|██████████| 15/15 [00:01<00:00,  8.29it/s]


_IncompatibleKeys(missing_keys=[], unexpected_keys=['heads.0.weight', 'heads.0.bias', 'heads.1.weight', 'heads.1.bias', 'heads.2.weight', 'heads.2.bias', 'heads.3.weight', 'heads.3.bias', 'heads.4.weight', 'heads.4.bias'])


100%|██████████| 15/15 [00:01<00:00, 10.77it/s]


_IncompatibleKeys(missing_keys=[], unexpected_keys=['heads.0.weight', 'heads.0.bias', 'heads.1.weight', 'heads.1.bias', 'heads.2.weight', 'heads.2.bias', 'heads.3.weight', 'heads.3.bias', 'heads.4.weight', 'heads.4.bias'])


100%|██████████| 15/15 [00:01<00:00,  8.12it/s]


{0: [tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0085,
        0.0000, 0.0000]),
     tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
     tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
     tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])],
 1: [tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
     tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
     tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0025, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000])],
 2: [tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0044, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0038, 0.0000, 0.0053, 

100%|██████████| 15/15 [00:01<00:00,  8.06it/s]


_IncompatibleKeys(missing_keys=[], unexpected_keys=['heads.0.weight', 'heads.0.bias', 'heads.1.weight', 'heads.1.bias'])


100%|██████████| 15/15 [00:01<00:00,  8.26it/s]


_IncompatibleKeys(missing_keys=[], unexpected_keys=['heads.0.weight', 'heads.0.bias', 'heads.1.weight', 'heads.1.bias', 'heads.2.weight', 'heads.2.bias'])


100%|██████████| 15/15 [00:01<00:00, 10.18it/s]


_IncompatibleKeys(missing_keys=[], unexpected_keys=['heads.0.weight', 'heads.0.bias', 'heads.1.weight', 'heads.1.bias', 'heads.2.weight', 'heads.2.bias', 'heads.3.weight', 'heads.3.bias'])


100%|██████████| 15/15 [00:01<00:00,  7.89it/s]


_IncompatibleKeys(missing_keys=[], unexpected_keys=['heads.0.weight', 'heads.0.bias', 'heads.1.weight', 'heads.1.bias', 'heads.2.weight', 'heads.2.bias', 'heads.3.weight', 'heads.3.bias', 'heads.4.weight', 'heads.4.bias'])


100%|██████████| 15/15 [00:01<00:00,  7.99it/s]


_IncompatibleKeys(missing_keys=[], unexpected_keys=['heads.0.weight', 'heads.0.bias', 'heads.1.weight', 'heads.1.bias'])


100%|██████████| 15/15 [00:01<00:00,  8.10it/s]


_IncompatibleKeys(missing_keys=[], unexpected_keys=['heads.0.weight', 'heads.0.bias', 'heads.1.weight', 'heads.1.bias', 'heads.2.weight', 'heads.2.bias'])


100%|██████████| 15/15 [00:01<00:00,  7.96it/s]


_IncompatibleKeys(missing_keys=[], unexpected_keys=['heads.0.weight', 'heads.0.bias', 'heads.1.weight', 'heads.1.bias', 'heads.2.weight', 'heads.2.bias', 'heads.3.weight', 'heads.3.bias'])


100%|██████████| 15/15 [00:01<00:00,  9.22it/s]


_IncompatibleKeys(missing_keys=[], unexpected_keys=['heads.0.weight', 'heads.0.bias', 'heads.1.weight', 'heads.1.bias', 'heads.2.weight', 'heads.2.bias', 'heads.3.weight', 'heads.3.bias', 'heads.4.weight', 'heads.4.bias'])


100%|██████████| 15/15 [00:01<00:00, 11.58it/s]


_IncompatibleKeys(missing_keys=[], unexpected_keys=['heads.0.weight', 'heads.0.bias', 'heads.1.weight', 'heads.1.bias', 'heads.2.weight', 'heads.2.bias'])


100%|██████████| 15/15 [00:01<00:00,  7.84it/s]


_IncompatibleKeys(missing_keys=[], unexpected_keys=['heads.0.weight', 'heads.0.bias', 'heads.1.weight', 'heads.1.bias', 'heads.2.weight', 'heads.2.bias', 'heads.3.weight', 'heads.3.bias'])


100%|██████████| 15/15 [00:01<00:00,  8.06it/s]


_IncompatibleKeys(missing_keys=[], unexpected_keys=['heads.0.weight', 'heads.0.bias', 'heads.1.weight', 'heads.1.bias', 'heads.2.weight', 'heads.2.bias', 'heads.3.weight', 'heads.3.bias', 'heads.4.weight', 'heads.4.bias'])


100%|██████████| 15/15 [00:01<00:00,  8.17it/s]


_IncompatibleKeys(missing_keys=[], unexpected_keys=['heads.0.weight', 'heads.0.bias', 'heads.1.weight', 'heads.1.bias', 'heads.2.weight', 'heads.2.bias', 'heads.3.weight', 'heads.3.bias'])


100%|██████████| 15/15 [00:01<00:00,  7.75it/s]


_IncompatibleKeys(missing_keys=[], unexpected_keys=['heads.0.weight', 'heads.0.bias', 'heads.1.weight', 'heads.1.bias', 'heads.2.weight', 'heads.2.bias', 'heads.3.weight', 'heads.3.bias', 'heads.4.weight', 'heads.4.bias'])


100%|██████████| 15/15 [00:02<00:00,  6.74it/s]


_IncompatibleKeys(missing_keys=[], unexpected_keys=['heads.0.weight', 'heads.0.bias', 'heads.1.weight', 'heads.1.bias', 'heads.2.weight', 'heads.2.bias', 'heads.3.weight', 'heads.3.bias', 'heads.4.weight', 'heads.4.bias'])


100%|██████████| 15/15 [00:01<00:00,  9.60it/s]

{0: [tensor([0.0809, 0.0067, 0.0041, 0.0000, 0.0012, 0.0078, 0.0059, 0.0067, 0.0524,
        0.0129, 0.0020, 0.0520, 0.0000, 0.0021, 0.0017, 0.0000, 0.0111, 0.0437,
        0.0000, 0.0000]),
     tensor([0.0864, 0.0132, 0.0134, 0.0077, 0.0051, 0.0036, 0.0117, 0.0065, 0.0646,
        0.0127, 0.0009, 0.0502, 0.0123, 0.0032, 0.0035, 0.0000, 0.0022, 0.0289,
        0.0037, 0.0122]),
     tensor([0.0803, 0.0067, 0.0005, 0.0055, 0.0257, 0.0036, 0.0256, 0.0162, 0.0469,
        0.0100, 0.0038, 0.0000, 0.0000, 0.0060, 0.0017, 0.0010, 0.0011, 0.0395,
        0.0033, 0.0041]),
     tensor([0.1452, 0.0144, 0.0035, 0.0017, 0.0091, 0.0113, 0.0204, 0.0084, 0.0543,
        0.0116, 0.0104, 0.0272, 0.0023, 0.0041, 0.0084, 0.0034, 0.0011, 0.0631,
        0.0068, 0.0030])],
 1: [tensor([0.1929, 0.0301, 0.0728, 0.1087, 0.0424, 0.0101, 0.0088, 0.0163, 0.0266,
        0.0937, 0.0316, 0.0185, 0.0763, 0.0017, 0.0124, 0.0576, 0.0908, 0.0210,
        0.0087, 0.0234]),
     tensor([0.2114, 0.0523, 0.0675, 0.1182,


