In [None]:
# Re-import edited modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# Move up one directory (presumably to the repository root — confirm).
%cd ..

In [None]:
# Work from cl-adaptation/; relative paths below (e.g. "linear_checkpoints") resolve against it.
%cd cl-adaptation/

In [None]:
import re
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import List, Optional

from matplotlib import pyplot as plt
import numpy as np
import torch


def parse_path(filename):
    """Extract ``(data_task, task)`` from a checkpoint file name.

    The name must contain ``dt:<int>_t:<int>_``; any text before it is
    ignored, and because the leading ``.*`` is greedy the *last* such
    occurrence wins.

    Raises:
        ValueError: if the pattern does not occur in *filename*.
    """
    found = re.match(r".*dt:(\d+)_t:(\d+)_", filename)
    if not found:
        raise ValueError("Filename format is incorrect")
    data_task_str, task_str = found.groups()
    return int(data_task_str), int(task_str)


@dataclass
class Weight:
    """Linear-layer parameters read from a single checkpoint file."""

    filename: str  # checkpoint file name only (no directory part)
    W: np.ndarray  # flattened "linear_layer.weight" as a NumPy array
    b: np.ndarray  # "linear_layer.bias" as a NumPy array
    varcov: Optional[np.ndarray]  # always None when built by from_path; intended contents TODO confirm
    task: int  # value of the "t:<n>" field in the file name
    data_task: int  # value of the "dt:<n>" field in the file name

    @classmethod
    def from_path(cls, path):
        """Load the linear layer stored in checkpoint *path*.

        Reads ``["state_dict"]["linear_layer.weight"]`` (flattened) and
        ``["state_dict"]["linear_layer.bias"]`` on the CPU and parses the
        task indices from the file name via ``parse_path``.
        """
        path = Path(path)

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        ckpt = torch.load(path, map_location="cpu")["state_dict"]

        W = ckpt["linear_layer.weight"].flatten().numpy()
        b = ckpt["linear_layer.bias"].numpy()

        data_task, task = parse_path(path.name)
        varcov = None

        return cls(path.name, W, b, varcov, task, data_task)

    def get_item(self, kind: str = "W") -> np.ndarray:
        """Return the weights (``kind="W"``) or the bias (``kind="b"``).

        Raises:
            ValueError: for any other *kind*. Previously any other value
                silently returned the bias.
        """
        if kind == "W":
            return self.W
        if kind == "b":
            return self.b
        raise ValueError(f"kind must be 'W' or 'b', got {kind!r}")


def get_weigths_and_sort(directory_path: Path) -> List[Weight]:
    """Load every ``*.ckpt`` file under *directory_path* (recursively) as a Weight.

    The result is ordered by file name using plain string comparison
    (not numeric order of any indices embedded in the name).
    """
    loaded = []
    for ckpt_path in directory_path.rglob("*.ckpt"):
        loaded.append(Weight.from_path(ckpt_path))
    loaded.sort(key=lambda item: item.filename)
    return loaded


def get_subsequents(weights):
    """Split *weights* by the gap between their data task and task.

    Returns ``(nth, prev_nth)``:
    - ``nth`` holds the items with ``data_task - task == 1``;
    - ``prev_nth`` holds the items with ``data_task - task == 0``.
    Both lists keep the input order.
    """
    nth = [w for w in weights if w.data_task - w.task == 1]
    prev_nth = [w for w in weights if w.data_task - w.task == 0]
    return nth, prev_nth


def get_zeros(weights):
    """Pair each task-0 item after the first with the first task-0 item.

    Returns ``(nth, prev_nth)``:
    - ``nth`` is every item with ``task == 0`` except the first, in input order;
    - ``prev_nth`` repeats that first task-0 item once per element of ``nth``.

    Raises:
        IndexError: if no item has ``task == 0``.
    """
    task_zero = [w for w in weights if w.task == 0]
    anchor = task_zero[0]
    later = task_zero[1:]
    return later, [anchor] * len(later)


def get_diffs_W(weights, func):
    """Return row-wise differences of flattened weight matrices.

    *func* splits *weights* into ``(nth, prev_nth)`` (e.g. ``get_subsequents``
    or ``get_zeros``). Row ``i`` of the returned array is
    ``nth[i].W.flatten() - prev_nth[i].W.flatten()``.
    """
    nth, prev_nth = func(weights)

    current = np.asarray([x.W.flatten() for x in nth])
    previous = np.asarray([x.W.flatten() for x in prev_nth])
    # The subtraction must be part of the assigned expression; a leading
    # "-np.asarray(...)" on its own line is a discarded statement.
    wds = current - previous
    return wds


def get_diffs_b(weights, func):
    """Return row-wise differences of bias vectors.

    *func* splits *weights* into ``(nth, prev_nth)`` (e.g. ``get_subsequents``
    or ``get_zeros``). Row ``i`` of the returned array is
    ``nth[i].b - prev_nth[i].b``.
    """
    nth, prev_nth = func(weights)

    current = np.asarray([x.b for x in nth])
    previous = np.asarray([x.b for x in prev_nth])
    # Keep the subtraction inside the assigned expression; a "-np.asarray(...)"
    # line on its own would be evaluated and thrown away.
    wds = current - previous
    return wds


def plot(reg, noreg, name, *args, **kwargs):
    """Overlay "noreg" and "reg" density histograms on a 2x2 grid of panels.

    Parameters
    ----------
    reg, noreg : sequences indexable by 0..3
        Item ``i`` is the data for panel ``i`` (titled "After task i+1").
    name : str
        Figure title.
    *args :
        Unused; accepted for call compatibility.
    **kwargs :
        ``bins`` is forwarded to ``Axes.hist``. When omitted, matplotlib's
        default (``rcParams["hist.bins"]``) is used instead of raising KeyError.
    """
    bins = kwargs.get("bins")
    fig, axs = plt.subplots(2, 2, figsize=(10, 8), sharex=True, sharey=True)
    # axs.flat walks the grid row-major: (0,0), (0,1), (1,0), (1,1).
    for i, ax in enumerate(axs.flat):
        ax.hist(noreg[i], bins=bins, alpha=0.5, label="noreg", density=True)
        ax.hist(reg[i], bins=bins, alpha=0.5, label="reg", density=True)
        ax.set_title(f"After task {i+1}")
        ax.legend()
        # plt.grid() with no args *toggles* the grid on the current (last)
        # axes only, so calling it per panel left that one panel grid-less
        # and never touched the others; enable it on each axes explicitly.
        ax.grid(True)

    fig.suptitle(name)
    fig.tight_layout()
    plt.show()

In [None]:
# Load every *.ckpt under linear_checkpoints/ (recursively), sorted by file name.
# Later cells select subsets of this list by substring match on the file name.
ckpt_dir = Path("linear_checkpoints")
weights_list = get_weigths_and_sort(ckpt_dir)

In [None]:
# Top-level entries of ckpt_dir (non-recursive); not referenced by the cells in this section.
paths = list(ckpt_dir.iterdir())

In [None]:
def get_diffs(what, weights=None, n_diffs=4, kind="W"):
    """Differences between consecutive checkpoints whose file name contains *what*.

    Let ``w`` be *weights* (default: the notebook-global ``weights_list``)
    filtered to items whose ``filename`` contains *what*, in their existing
    order. Returns a list of ``n_diffs`` arrays where element ``i`` is
    ``w[i].get_item(kind) - w[i + 1].get_item(kind)`` (earlier minus later).

    Raises
    ------
    ValueError
        If fewer than ``n_diffs + 1`` items match. This replaces an ``assert``
        that only checked for a non-empty match and is stripped under ``-O``.
    """
    source = weights_list if weights is None else weights
    matching = [w for w in source if what in w.filename]
    if len(matching) < n_diffs + 1:
        raise ValueError(
            f"need at least {n_diffs + 1} checkpoints matching {what!r}, "
            f"found {len(matching)}"
        )
    return [
        matching[i].get_item(kind) - matching[i + 1].get_item(kind)
        for i in range(n_diffs)
    ]

In [None]:
def get_diffs_zero(what, weights=None, n_diffs=4, kind="W"):
    """Differences between the first matching checkpoint and each later one.

    Let ``w`` be *weights* (default: the notebook-global ``weights_list``)
    filtered to items whose ``filename`` contains *what*, in their existing
    order. Returns a list of ``n_diffs`` arrays where element ``i`` is
    ``w[0].get_item(kind) - w[i + 1].get_item(kind)``.

    Raises
    ------
    ValueError
        If fewer than ``n_diffs + 1`` items match. This replaces an ``assert``
        that only checked for a non-empty match and is stripped under ``-O``.
    """
    source = weights_list if weights is None else weights
    matching = [w for w in source if what in w.filename]
    if len(matching) < n_diffs + 1:
        raise ValueError(
            f"need at least {n_diffs + 1} checkpoints matching {what!r}, "
            f"found {len(matching)}"
        )
    first = matching[0].get_item(kind)
    return [first - matching[i].get_item(kind) for i in range(1, n_diffs + 1)]

In [None]:
# Keyword arguments shared by every Axes.hist call in plot_methods.
hist_kw = {"bins": 100, "density": True, "alpha": 0.5}

# Method-name substrings; plot_methods appends "_noreg" / "_reg" to each when
# selecting checkpoints by file name.
methods = ["finetuning", "replay", "ewc", "lwf"]

# NOTE: `mpl` is only referenced by the commented-out LaTeX/font config below.
import matplotlib as mpl

# mpl.rcParams["text.usetex"] = False
# mpl.rcParams["text.latex.preamble"] = r"\usepackage{amsmath}"
# nice_fonts = {
#     "text.usetex": True,
#     "font.family": "serif",
#     "font.serif": "Times New Roman",
#     "font.size": 14,
#     "lines.linewidth": 3,
# }
# mpl.rcParams.update(nice_fonts)


def plot_methods(
    hist_kw, methods, diff_fn, *, n_panels=4, xlim=(-0.1, 0.1), ylim=(0, 60)
):
    """Draw one figure per method comparing reg vs noreg difference histograms.

    For each ``method`` in *methods*, ``diff_fn(f"{method}_noreg")`` and
    ``diff_fn(f"{method}_reg")`` must each return a sequence with at least
    *n_panels* array-likes. Panel ``idx`` overlays item ``idx`` of both.

    Parameters
    ----------
    hist_kw : dict
        Extra keyword arguments forwarded to every ``Axes.hist`` call
        (must not contain ``label``).
    methods : iterable of str
        Method-name prefixes used to build the ``diff_fn`` queries.
    diff_fn : callable(str) -> sequence of array-like
        E.g. ``get_diffs`` or ``get_diffs_zero``.
    n_panels : int
        Number of subplots per figure (default 4, as before).
    xlim, ylim : pair of float
        Axis limits applied to every panel. x ticks go at both ends and the
        midpoint, which reproduces the old ``[-0.1, 0, 0.1]`` for the default.
    """
    lo, hi = xlim
    xticks = [lo, (lo + hi) / 2, hi]

    for method in methods:
        # Compute each series once per method. The old code re-ran diff_fn
        # (which rebuilds every difference) for every panel and every series.
        noreg_diffs = diff_fn(f"{method}_noreg")
        reg_diffs = diff_fn(f"{method}_reg")

        # squeeze=False keeps axs 2-D so n_panels == 1 still yields an iterable row.
        fig, axs = plt.subplots(
            1, n_panels, figsize=(5 * n_panels, 5), squeeze=False
        )
        for idx, ax in enumerate(axs[0]):
            ax.hist(noreg_diffs[idx], **hist_kw, label="No regularization")
            ax.hist(reg_diffs[idx], **hist_kw, label="Regularization")
            ax.set_xlim(xlim)
            ax.set_xticks(xticks)
            ax.set_ylim(ylim)
            # Explicit True: a bare grid() toggles, which inverts under rc axes.grid=True.
            ax.grid(True)
            ax.legend()
        fig.suptitle(method)


# One figure per method: panel i shows (first checkpoint) - (checkpoint i+1) for noreg vs reg.
plot_methods(hist_kw, methods, get_diffs_zero)

In [None]:
# Sanity check: number of checkpoints loaded from ckpt_dir.
len(weights_list)