## Imports

In [None]:
import logging
import os
from typing import List, Optional

import hydra
import omegaconf
import pytorch_lightning as pl
import torch
from lightning.pytorch import Callback
from omegaconf import DictConfig, ListConfig

from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
from nn_core.common.utils import enforce_tags, seed_index_everything
from nn_core.model_logging import NNLogger
from nn_core.serialization import NNCheckpointIO

# Force the execution of __init__.py if this file is executed directly.
import tvp  # noqa
from tvp.data.datamodule import MetaData
from tvp.data.datasets.registry import get_dataset
from tvp.task_vectors.task_vectors import TaskVector
from tvp.utils.io_utils import load_model_from_artifact
from tvp.utils.utils import build_callbacks
from torch.nn.utils import vector_to_parameters
from torch.nn.utils import parameters_to_vector

# Logger for messages emitted by this notebook.
pylogger = logging.getLogger(name=__name__)

# "high": let float32 matmuls trade some internal precision for speed where supported (see torch docs).
torch.set_float32_matmul_precision(precision="high")

## Configuration

In [None]:
%load_ext autoreload
%autoreload 2

import hydra
from hydra import initialize, compose
from typing import Dict, List

hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(version_base=None, config_path=str("../conf"), job_name="playground")

In [None]:
cfg = compose(overrides=[], config_name="task_vectors")  # top-level "task_vectors" config, no overrides

In [None]:
seed_index_everything(cfg)

# Re-assign the run tags after passing them through nn_core's enforce_tags
# (presumably validation/normalisation — see nn_core).
cfg.core.tags = enforce_tags(cfg.core.get("tags"))

template_core: NNTemplateCore = NNTemplateCore(restore_cfg=cfg.train.get("restore"))
logger: NNLogger = NNLogger(
    logging_cfg=cfg.train.logging,
    cfg=cfg,
    resume_id=template_core.resume_id,
)

## Load models

In [None]:
import copy


zeroshot_identifier = f"{cfg.nn.module.model.model_name}_pt"

zeroshot_model = load_model_from_artifact(artifact_path=f"{zeroshot_identifier}:latest", run=logger.experiment)


def finetuned_id_fn(dataset):
    """Return the artifact id of the model fine-tuned on ``dataset`` for the configured seed."""
    return f"{cfg.nn.module.model.model_name}_{dataset}_{cfg.seed_index}:latest"


# One fine-tuned encoder per task listed in cfg.task_vectors.to_apply.
finetuned_models = {
    task_name: load_model_from_artifact(artifact_path=finetuned_id_fn(task_name), run=logger.experiment)
    for task_name in cfg.task_vectors.to_apply
}

# Independent snapshot of the untouched zero-shot weights.
zeroshot_orig_weights = copy.deepcopy(zeroshot_model.state_dict())

## Task vectors

In [None]:
def flatten(model):
    """Concatenate all of ``model``'s parameters, in ``parameters()`` order, into one 1-D tensor."""
    return parameters_to_vector(model.parameters())


zeroshot_vec = flatten(zeroshot_model)

In [None]:
# NOTE(review): `task_vectors` is rebound to a stacked tensor of flattened deltas in the
# "Aggregate task vectors" cell, so this list of TaskVector objects is not read afterwards.
task_vectors = [
    TaskVector.from_models(zeroshot_model, finetuned_models[task_name])
    for task_name in cfg.task_vectors.to_apply
]

In [None]:
def apply_task_vector(model, task_vector, scaling_coef=1.0):
    """Add a (scaled) task vector to ``model``'s weights, in place.

    Args:
        model: module whose ``state_dict()`` entries are updated via ``load_state_dict``.
        task_vector: mapping containing every key of ``model.state_dict()``; each value is
            added to the matching entry. A missing key raises ``KeyError``.
        scaling_coef: multiplier applied to every delta before it is added (the task
            arithmetic coefficient). The default 1.0 is plain addition.
    """
    # The deltas may be nn.Parameters with requires_grad=True (e.g. from named_parameters());
    # no_grad keeps autograd from recording these additions.
    with torch.no_grad():
        new_state = {}
        for key, value in model.state_dict().items():
            delta = task_vector[key]
            if scaling_coef != 1.0:
                delta = scaling_coef * delta
            new_state[key] = value + delta
    model.load_state_dict(new_state)

### Aggregate task vectors

In [None]:
# Rebinds `task_vectors`: one row per task holding its flattened (finetuned - zeroshot) delta,
# i.e. a (num_tasks, num_params) tensor. no_grad so the flattened parameters carry no history.
with torch.no_grad():
    task_vectors = torch.stack(
        [flatten(finetuned_models[name]) - zeroshot_vec for name in cfg.task_vectors.to_apply],
        dim=0,
    )

### Standard task vectors (element-wise mean of the task vectors)

In [None]:
task_vectors_sum = task_vectors.sum(dim=0)  # element-wise sum over tasks

In [None]:
multi_task_vector = task_vectors_sum / task_vectors.shape[0]  # mean task vector

In [None]:
# Unflatten the mean delta into a model with the zero-shot architecture: each parameter of the
# copy is replaced by its slice of multi_task_vector.
delta_model = copy.deepcopy(zeroshot_model)
vector_to_parameters(vec=multi_task_vector, parameters=delta_model.parameters())

In [None]:
task_equipped_model = copy.deepcopy(zeroshot_model)
# NOTE(review): apply_task_vector adds *every* state_dict entry, but only delta_model's parameters
# were overwritten above; any persistent buffers it inherited from zeroshot_model would be added
# to themselves here — confirm the encoder has none.
apply_task_vector(task_equipped_model, task_vector=delta_model.state_dict())

In [None]:
classification_head_identifier = f"{cfg.nn.module.model.model_name}_{cfg.nn.data.dataset.dataset_name}_head"
classification_head = load_model_from_artifact(
    run=logger.experiment,
    artifact_path=f"{classification_head_identifier}:latest",
)

# Assemble encoder + head from the Hydra module config; _recursive_=False stops Hydra from
# instantiating nested configs on its own.
model = hydra.utils.instantiate(
    cfg.nn.module,
    _recursive_=False,
    encoder=task_equipped_model,
    classifier=classification_head,
)

## Task Singular Vectors

### Compute SVDs

In [None]:
dataset_names = cfg["task_vectors"]["to_apply"]  # tasks whose deltas are merged

In [None]:
# Arbitrary dataset used only as a key: every entry of layer_task_tensors is built from a copy of
# zeroshot_model, so all tasks expose the same layer names and shapes.
ref_dataset = dataset_names[0]

#### Get per-layer task tensors, i.e. each task's delta reshaped back into the per-layer parameter tensors

In [None]:
layer_task_tensors = {}

# Unflatten each task vector into its own copy of the zero-shot model, so the deltas can be read
# per layer: {dataset -> {parameter name -> nn.Parameter with that parameter's shape}}.
for dataset, task_vector in zip(dataset_names, task_vectors):
    delta_model = copy.deepcopy(zeroshot_model)
    vector_to_parameters(task_vector, delta_model.parameters())
    layer_task_tensors[dataset] = {name: param for name, param in delta_model.named_parameters()}

In [None]:
def is_matrix(x):
    """Return True when ``x`` has exactly two dimensions (only these layers are SVD-merged)."""
    return len(x.shape) == 2


# 2-D layers excluded from the SVD-based merge; they are averaged like non-matrix parameters.
layers_to_ignore = {"model.text_projection"}

In [None]:
svd_results = {}

# The per-layer deltas are nn.Parameters (requires_grad=True); without no_grad every SVD would be
# recorded by autograd and each stored factor would keep a backward graph alive.
with torch.no_grad():
    for dataset_name in dataset_names:
        svd_results[dataset_name] = {}

        for layer_name, layer_task_tensor in layer_task_tensors[dataset_name].items():
            if is_matrix(layer_task_tensor) and layer_name not in layers_to_ignore:
                # torch.svd is deprecated; linalg.svd(full_matrices=False) is its documented
                # replacement (reduced SVD). It returns Vh, so Vh.mH recovers the V stored before.
                U, S, Vh = torch.linalg.svd(layer_task_tensor, full_matrices=False)

                svd_results[dataset_name][layer_name] = {"U": U, "S": S, "V": Vh.mH}

In [None]:
def select_task_components(x, dim, perc_comps_for_task):
    """Keep the leading fraction of components of ``x`` along axis ``dim``.

    Args:
        x: tensor of any rank (here: SVD factors — the U/V matrices or the 1-D S vector).
        dim: axis to truncate; negative values count from the end.
        perc_comps_for_task: fraction of ``x.shape[dim]`` to keep; the count is
            ``int(x.shape[dim] * perc_comps_for_task)`` (truncated toward zero).

    Returns:
        ``x`` sliced to its first ``num_comps`` entries along ``dim`` (a view, like basic
        slicing); all other axes are untouched.
    """
    num_comps = int(x.shape[dim] * perc_comps_for_task)

    # Equivalent to x[..., :num_comps, ...] at position `dim`, for any rank and any dim.
    index = [slice(None)] * x.dim()
    index[dim] = slice(None, num_comps)
    return x[tuple(index)]

In [None]:
from functools import partial

# Element-wise exponent for the interference matrices. Named `power` so this cell does not shadow
# Python's builtin pow() in the notebook namespace.
power = 2
# Every task keeps the same fraction of its singular components.
perc_comps_for_task = 1 / len(dataset_names)

select_comps = partial(select_task_components, perc_comps_for_task=perc_comps_for_task)


def _task_singular_stats(U, S, Vt, power):
    """Reweight the concatenated task singular components of one layer.

    Args:
        U: (N, r) kept left singular vectors of all tasks, concatenated column-wise.
        S: (r,) matching singular values, concatenated.
        Vt: (r, M) matching right singular vectors, stacked row-wise.
        power: element-wise exponent applied to the two interference matrices.

    Returns:
        Dict with the r x r Gram matrices U^T U ("u1_u2") and Vt Vt^T ("v1_v2"), the concatenated
        and reweighted singular values ("s1+s2", "tilde_s"), the r x r interference before/after
        reweighting ("interf", "no_interf") and the reconstructed (N, M) delta ("task_sing_vec").
    """
    # r x r Gram matrices: overlaps between singular vectors of (possibly different) tasks.
    # "v1_v2" is Vt Vt^T — the counterpart of U^T U and the product var_v is built from.
    gram_u = U.mT @ U
    gram_v = Vt @ Vt.mT

    var_u = torch.pow(torch.linalg.multi_dot((U.mT, U, torch.diag(S))), power)
    var_v = torch.pow(torch.linalg.multi_dot((torch.diag(S), Vt, Vt.mT)), power)

    # Normalise var_u per column and var_v per row; 1e-12 guards against division by zero.
    var_u = var_u / (torch.sum(torch.abs(var_u), dim=0) + 1e-12)
    var_v = var_v / (torch.sum(torch.abs(var_v), dim=1, keepdim=True) + 1e-12)

    mixing = var_u * var_v

    S_tilde = torch.diagonal(torch.diag(S) @ mixing)
    assert S_tilde.shape == S.shape

    # U^T (U diag(S) Vt) Vt^T and its reweighted counterpart, evaluated in rank space so the
    # (N, M) products are never materialised only to be projected back to r x r.
    interf = gram_u @ torch.diag(S) @ gram_v
    no_interf = gram_u @ (torch.diag(S_tilde) @ mixing) @ gram_v

    task_sing_vec = U @ torch.diag(S_tilde) @ mixing @ Vt

    return {
        "u1_u2": gram_u,
        "s1+s2": S,  # concatenation of all tasks' kept singular values
        "tilde_s": S_tilde,
        "v1_v2": gram_v,
        "interf": interf,
        "no_interf": no_interf,
        "task_sing_vec": task_sing_vec,
    }


task_sing_vectors = {}

for layer_name, layer_task_tensor in layer_task_tensors[ref_dataset].items():
    if not is_matrix(layer_task_tensor) or layer_name in layers_to_ignore:
        continue

    # Truncate each task's SVD to its share of components and concatenate across tasks.
    U = torch.concat(
        [select_comps(svd_results[dataset_name][layer_name]["U"], dim=1) for dataset_name in dataset_names], dim=1
    ).detach()
    S = torch.concat(
        [select_comps(svd_results[dataset_name][layer_name]["S"], dim=0) for dataset_name in dataset_names]
    ).detach()
    Vt = torch.concat(
        [select_comps(svd_results[dataset_name][layer_name]["V"], dim=1).T for dataset_name in dataset_names], dim=0
    ).detach()

    assert U.shape[1] == S.shape[0] == Vt.shape[0]  # rank
    assert U.shape[0] == layer_task_tensor.shape[0] and Vt.shape[1] == layer_task_tensor.shape[1]  # N, M

    task_sing_vectors[layer_name] = _task_singular_stats(U, S, Vt, power)

    assert task_sing_vectors[layer_name]["task_sing_vec"].shape == layer_task_tensor.shape

### Plots

In [None]:
# import numpy as np
# import matplotlib.pyplot as plt
# import matplotlib.ticker as mticker
# from matplotlib.colors import LinearSegmentedColormap

# num_buckets = 11

# color_values = np.linspace(-0.5, 0.5, num_buckets + 1)
# color_list = []

# for i in range(num_buckets):
#     color = plt.cm.RdBu((i + 0.5) / num_buckets)  # Adjusting 0.5 to center colors
#     color_list.append(color)

# # Create a LinearSegmentedColormap with your custom colors
# custom_cmap = LinearSegmentedColormap.from_list("custom_cmap", color_list, num_buckets)

# for layer_name, layer_task_tensor in delta_model.named_parameters():

#     if is_matrix(layer_task_tensor):


#         if layer_name not in ["model.token_embedding.weight", "model.positional_embedding"]:
#             print(f"Plotting  {layer_name}")

#             fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
#             cax = ax1.imshow(
#                 result[layer_name]["u1_u2"],  # [:25, :25],
#                 vmin=-0.5,
#                 vmax=0.5,
#                 cmap=custom_cmap,
#                 aspect="auto",
#             )
#             ax1.set_title(f" {layer_name}: sum_u.T @ sum_u")
#             cbar = fig.colorbar(
#                 cax,
#                 # ticks=[-0.5, 0, 0.5],
#                 # format=mticker.FixedFormatter(["< -0.5", "0", "> 0.5"]),
#                 ticks=np.linspace(-1, 1, num_buckets),
#                 format=mticker.FixedFormatter(
#                     [round(x, 2) for x in np.linspace(-1, 1, num_buckets)]
#                 ),
#                 extend="both",
#             )

#             cax = ax2.imshow(
#                 result[layer_name]["v1_v2"],  # [:25, :25],
#                 vmin=-0.5,
#                 vmax=0.5,
#                 cmap=custom_cmap,
#                 aspect="auto",
#             )
#             ax2.set_title(f"{layer_name}: sum_v @ sum_v.T")
#             cbar = fig.colorbar(
#                 cax,
#                 # ticks=[-0.5, 0, 0.5],
#                 # format=mticker.FixedFormatter(["< -0.5", "0", "> 0.5"]),
#                 ticks=np.linspace(-1, 1, num_buckets),
#                 format=mticker.FixedFormatter(
#                     [round(x, 2) for x in np.linspace(-1, 1, num_buckets)]
#                 ),
#                 extend="both",
#             )
#             plt.show()

#             fig, axs = plt.subplots(
#                 nrows=1, ncols=2, figsize=(20, 10), sharey=False, sharex=False
#             )

#             axs[0].plot(result[layer_name]["s1+s2"])  # axs[0].semilogy(s_anchor1)
#             axs[0].set_title(f"Singular values of the  datasets in concatenation")
#             axs[0].set_xlabel("Singular value index")
#             axs[0].set_ylabel("Singular value")

#             # the fraction of the energy captured by the first r singular values
#             axs[1].plot(
#                 np.cumsum(result[layer_name]["s1+s2"]) / torch.sum(result[layer_name]["s1+s2"])
#             )
#             axs[1].set_title("Cumulative sum of the singular values")
#             axs[1].set_xlabel("Singular value index")
#             axs[1].set_ylabel("Cumulative sum")
#             plt.show()

#             fig, axs = plt.subplots(
#                 nrows=1, ncols=2, figsize=(20, 10), sharey=False, sharex=False
#             )

#             axs[0].plot(result[layer_name]["tilde_s"])  # axs[0].semilogy(s_anchor1)
#             axs[0].set_title(f"Singular values of the  datasets in new Sigma")
#             axs[0].set_xlabel("Singular value index")
#             axs[0].set_ylabel("Singular value")

#             # the fraction of the energy captured by the first r singular values
#             axs[1].plot(
#                 np.cumsum(result[layer_name]["tilde_s"]) / torch.sum(result[layer_name]["tilde_s"])
#             )
#             axs[1].set_title("Cumulative sum of the singular values")
#             axs[1].set_xlabel("Singular value index")
#             axs[1].set_ylabel("Cumulative sum")
#             plt.show()

#             fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
#             cax = ax1.imshow(
#                 result[layer_name]["interf"],  # [:25, :25],
#                 vmin=-0.01,
#                 vmax=0.01,
#                 cmap=custom_cmap,
#                 aspect="auto",
#             )
#             ax1.set_title(f" {layer_name}: interference")
#             cbar = fig.colorbar(
#                 cax,
#                 # ticks=[-0.5, 0, 0.5],
#                 # format=mticker.FixedFormatter(["< -0.5", "0", "> 0.5"]),
#                 ticks=np.linspace(-0.01, 0.01, num_buckets),
#                 format=mticker.FixedFormatter(
#                     [round(x, 2) for x in np.linspace(-0.01, 0.01, num_buckets)]
#                 ),
#                 extend="both",
#             )

#             cax = ax2.imshow(
#                 result[layer_name]["no_interf"],  # [:25, :25],
#                 vmin=-0.005,
#                 vmax=0.005,
#                 cmap=custom_cmap,
#                 aspect="auto",
#             )
#             ax2.set_title(f"{layer_name}: reduced interference")
#             cbar = fig.colorbar(
#                 cax,
#                 # ticks=[-0.5, 0, 0.5],
#                 # format=mticker.FixedFormatter(["< -0.5", "0", "> 0.5"]),
#                 ticks=np.linspace(-0.005, 0.005, num_buckets),
#                 format=mticker.FixedFormatter(
#                     [round(x, 3) for x in np.linspace(-0.005, 0.005, num_buckets)]
#                 ),
#                 extend="both",
#             )
#             plt.show()

### Get SVD multi-task vector

In [None]:
# Sanity check: container type of one task's per-layer tensors (built with dict(...) above).
type(layer_task_tensors[ref_dataset])

In [None]:
merged_layer_task_tensors = {}

# 2-D layers not in layers_to_ignore take their task-singular-vector merge; every other
# parameter falls back to the plain mean of the per-task deltas.
for layer_name, layer_tensor in layer_task_tensors[ref_dataset].items():
    if not is_matrix(layer_tensor) or layer_name in layers_to_ignore:
        merged_layer_task_tensors[layer_name] = sum(
            layer_task_tensors[name][layer_name] for name in dataset_names
        ) / len(dataset_names)
        continue

    merged_layer_task_tensors[layer_name] = task_sing_vectors[layer_name]["task_sing_vec"]

### Apply multi-task vector

In [None]:
# Fresh zero-shot copy, then add the merged per-layer deltas in place.
# NOTE(review): merged_layer_task_tensors is keyed by named_parameters(); apply_task_vector looks up
# every state_dict() key, so any persistent buffer in the encoder would raise KeyError here.
task_equipped_model = copy.deepcopy(zeroshot_model)
apply_task_vector(model=task_equipped_model, task_vector=merged_layer_task_tensors)

In [None]:
classification_head_identifier = f"{cfg.nn.module.model.model_name}_{cfg.nn.data.dataset.dataset_name}_head"
classification_head = load_model_from_artifact(
    run=logger.experiment,
    artifact_path=f"{classification_head_identifier}:latest",
)

# Assemble encoder + head from the Hydra module config; _recursive_=False stops Hydra from
# instantiating nested configs on its own.
model = hydra.utils.instantiate(
    cfg.nn.module,
    _recursive_=False,
    encoder=task_equipped_model,
    classifier=classification_head,
)

## Evaluation

In [None]:
seed_index_everything(cfg)

accuracies = {}

# Loop-invariant: root directory for the Trainer's artefacts.
storage_dir: str = cfg.core.storage_dir

for dataset_name in cfg.eval_datasets:
    classification_head_identifier = f"{cfg.nn.module.model.model_name}_{dataset_name}_head"
    classification_head = load_model_from_artifact(
        artifact_path=f"{classification_head_identifier}:latest", run=logger.experiment
    )

    # Same merged encoder for every dataset; only the classification head changes.
    model = hydra.utils.instantiate(
        cfg.nn.module, encoder=task_equipped_model, classifier=classification_head, _recursive_=False
    )

    # NOTE(review): the test split is preprocessed with train_preprocess; if that transform contains
    # random augmentation, test accuracy becomes stochastic — confirm this is intended.
    dataset = get_dataset(
        dataset_name,
        preprocess_fn=model.encoder.train_preprocess,
        location=cfg.nn.data.data_path,
        batch_size=cfg.nn.data.batch_size.train,
    )

    callbacks: List[Callback] = build_callbacks(cfg.train.callbacks, template_core)

    pylogger.info("Instantiating the <Trainer>")
    trainer = pl.Trainer(
        default_root_dir=storage_dir,
        plugins=[NNCheckpointIO(jailing_dir=logger.run_dir)],
        logger=False,
        callbacks=callbacks,
        **cfg.train.trainer,
    )

    pylogger.info("Evaluating on the test set!")
    dataset_results = trainer.test(model=model, dataloaders=dataset.test_loader)

    # Key by the dataset *name*; `dataset` is the object returned by get_dataset, and using it as
    # the key would index results by that object rather than by name.
    accuracies[dataset_name] = dataset_results[0]["acc/test"]

In [None]:
# Unweighted mean of the per-dataset test accuracies.
mean_acc = sum(acc for acc in accuracies.values()) / len(accuracies)
mean_acc

In [None]:
# # pylogger.info("Evaluating on the training set")
# # trainer.test(model=model, dataloaders=dataset.train_loader)

# pylogger.info("Evaluating on the test set!")
# trainer.test(model=model, dataloaders=dataset.test_loader)