In [None]:
import os
from pprint import pprint

import hydra
from hydra import compose, initialize
from omegaconf import DictConfig, OmegaConf

from src.eval.aggregation import create_task_vector
from src.eval.eval_utils import perform_eval_with_merged_vector
from src.utils.variables_and_paths import ALL_DATASETS

In [None]:
# Compose the Hydra config named "config" from the `config` directory. The
# context manager limits Hydra's global initialisation to this block, so the
# cell can be re-run without restarting the kernel.
with hydra.initialize(version_base="1.3", config_path="config", job_name="test_app"):
    cfg = hydra.compose(config_name="config")

# Override the configured method name for this notebook.
cfg.method.name = "SVD"

# An empty DATASETS string selects the first `num_tasks` entries of ALL_DATASETS;
# otherwise the configured list is used and `num_tasks` is derived from it.
if cfg.DATASETS == "":
    cfg.DATASETS = ALL_DATASETS[: cfg.num_tasks]
else:
    cfg.num_tasks = len(cfg.DATASETS)
cfg.DATASETS_VAL = [dataset + "Val" for dataset in cfg.DATASETS]
cfg.data_location = os.path.expanduser(cfg.data_location)

# Set the struct flag once, after every field above has been written. The
# original cell repeated this call after printing; nothing in between can clear
# the flag, so a single call is sufficient.
OmegaConf.set_struct(cfg, True)

print(cfg.method.full_name)
print()
print(OmegaConf.to_yaml(cfg))

In [None]:
import torch

from src.models.task_vectors import ImageEncoder, NonLinearTaskVector
from src.utils.utils import (
    check_parameterNamesMatch,
)
from src.utils.svd_utils import (
    compute_svd_dict,
    sum_svd_dict,
    compute_and_sum_svd_mem_reduction,
)
from src.eval.aggregation import get_all_checkpoints

# Fetch the checkpoints selected by the config. Judging by the names these are
# the fine-tuned checkpoints and the pretrained one -- confirm against
# get_all_checkpoints. Their parameter names are then checked for agreement.
ft_checks, ptm_check = get_all_checkpoints(cfg)
check_parameterNamesMatch(ft_checks + [ptm_check])

remove_keys = []

print("Flattening out Checkpoints")
# One task vector per entry of ft_checks, each built against ptm_check.
task_vectors = [
    NonLinearTaskVector(cfg.model, ptm_check, ft_check)
    for ft_check in ft_checks
]

print("MODEL: {}, METHOD {}".format(cfg.model, cfg.method.name))

print("=== Using SVD ===")
# Indexed below as svd_dict[dataset][key]["u" | "s" | "v"].
svd_dict = compute_svd_dict(task_vectors, cfg)

In [None]:
def _interference_report(sum_u, sum_s, sum_v, exponent=2, eps=1e-12):
    """Compute the per-layer quantities that the plotting cell visualises.

    Args:
        sum_u: Sum over datasets of the "u" entries for one parameter key.
        sum_s: Sum over datasets of the "s" entries; it is passed to
            ``torch.diag`` and used as a matrix factor, so it is expected to be
            1-D.
        sum_v: Sum over datasets of the "v" entries.
        exponent: Power applied element-wise to the two weighting matrices.
        eps: Constant added to the normalisers to avoid division by zero.

    Returns:
        A dict with the keys "u1_u2", "s1+s2", "tilde_s", "v1_v2", "interf"
        and "no_interf".
    """
    diag_s = torch.diag(sum_s)

    # Element-wise powers of (sum_u^T sum_u diag(s)) and (diag(s) sum_v sum_v^T).
    var_u = torch.pow(torch.linalg.multi_dot((sum_u.mT, sum_u, diag_s)), exponent)
    var_v = torch.pow(torch.linalg.multi_dot((diag_s, sum_v, sum_v.mT)), exponent)

    # var_u is normalised per column (the dim=0 sum broadcasts over the last
    # axis); var_v is normalised per row (keepdim=True).
    var_u = var_u / (torch.sum(torch.abs(var_u), dim=0) + eps)
    var_v = var_v / (torch.sum(torch.abs(var_v), dim=1, keepdim=True) + eps)

    # Re-weighted sigma; computed once and reused for both "tilde_s" and the
    # reconstructed matrix.
    rescaled_sigma = diag_s @ (var_u * var_v)
    new_vector = torch.linalg.multi_dot((sum_u, rescaled_sigma, sum_v))

    return {
        "u1_u2": sum_u.mT @ sum_u,
        "s1+s2": sum_s,
        "tilde_s": torch.diagonal(rescaled_sigma),
        "v1_v2": sum_v @ sum_v.mT,
        "interf": sum_u.mT @ (sum_u @ diag_s @ sum_v) @ sum_v.mT,
        "no_interf": sum_u.mT @ new_vector @ sum_v.mT,
    }


# Named `exponent` rather than `pow` so the builtin pow() stays usable in later cells.
exponent = 2

# The first dataset's entry supplies the parameter keys to iterate over.
reference_svd = svd_dict[cfg.DATASETS[0]]

result = {}
for key in reference_svd:
    # Only keys that carry a "u" entry and whose task-vector entry is 2-D are analysed.
    if "u" not in reference_svd[key] or len(task_vectors[0].vector[key].shape) != 2:
        continue

    sum_u = sum(svd_dict[dataset][key]["u"] for dataset in cfg.DATASETS)
    sum_s = sum(svd_dict[dataset][key]["s"] for dataset in cfg.DATASETS)
    sum_v = sum(svd_dict[dataset][key]["v"] for dataset in cfg.DATASETS)

    result[key] = _interference_report(sum_u, sum_s, sum_v, exponent=exponent)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
from matplotlib.colors import LinearSegmentedColormap

num_buckets = 11

# num_buckets + 1 evenly spaced values over [-0.5, 0.5]; not referenced again in this cell.
color_values = np.linspace(-0.5, 0.5, num_buckets + 1)

# Sample RdBu at the midpoint of each of the num_buckets equal slices of [0, 1].
color_list = [plt.cm.RdBu((bucket + 0.5) / num_buckets) for bucket in range(num_buckets)]

# Colormap built from the sampled colours and quantised to num_buckets levels.
custom_cmap = LinearSegmentedColormap.from_list("custom_cmap", color_list, N=num_buckets)

# Keys excluded from plotting even though they pass the 2-D filter.
SKIPPED_KEYS = ("model.token_embedding.weight", "model.positional_embedding")


def _as_numpy(values):
    """Return `values` as a NumPy array, detaching torch tensors and moving them to the CPU."""
    if isinstance(values, torch.Tensor):
        return values.detach().cpu().numpy()
    return np.asarray(values)


def _plot_heatmap(fig, ax, matrix, title, limit, decimals):
    """Draw `matrix` on `ax` with colour limits [-limit, limit] and its own colorbar.

    Args:
        fig: Figure that owns `ax`; the colorbar is attached to it.
        ax: Axes to draw the image on.
        matrix: 2-D values to show (torch tensor or array-like).
        title: Axes title.
        limit: Symmetric colour limit; values beyond it fall into the
            "extend" arrows of the colorbar.
        decimals: Number of decimals kept in the tick labels. It must be fine
            enough to tell neighbouring ticks apart (ticks are spaced
            2 * limit / (num_buckets - 1) apart).
    """
    image = ax.imshow(
        _as_numpy(matrix),
        vmin=-limit,
        vmax=limit,
        cmap=custom_cmap,
        aspect="auto",
    )
    ax.set_title(title)

    # Ticks span exactly the colour limits so that every tick is drawn.
    ticks = np.linspace(-limit, limit, num_buckets)
    # `+ 0.0` turns a rounded "-0.0" into "0.0".
    labels = [str(round(float(tick), decimals) + 0.0) for tick in ticks]
    fig.colorbar(
        image,
        ax=ax,
        ticks=ticks,
        format=mticker.FixedFormatter(labels),
        extend="both",
    )


def _plot_spectrum(values, title):
    """Show `values` and their normalised cumulative sum side by side."""
    spectrum = _as_numpy(values)
    fig, (ax_values, ax_cumulative) = plt.subplots(
        nrows=1, ncols=2, figsize=(20, 10), sharey=False, sharex=False
    )

    ax_values.plot(spectrum)
    ax_values.set_title(title)
    ax_values.set_xlabel("Singular value index")
    ax_values.set_ylabel("Singular value")

    # Fraction of the total captured by the first r values.
    ax_cumulative.plot(np.cumsum(spectrum) / np.sum(spectrum))
    ax_cumulative.set_title("Cumulative sum of the singular values")
    ax_cumulative.set_xlabel("Singular value index")
    ax_cumulative.set_ylabel("Cumulative sum")

    plt.show()
    plt.close(fig)


reference_svd = svd_dict[cfg.DATASETS[0]]
for key in reference_svd:
    # Same filter as the cell that fills `result`: a "u" entry and a 2-D task-vector entry.
    if "u" not in reference_svd[key] or len(task_vectors[0].vector[key].shape) != 2:
        continue
    if key in SKIPPED_KEYS:
        continue

    print(f"Plotting  {key}")
    stats = result[key]

    # Gram matrices of the summed factors.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
    _plot_heatmap(fig, ax1, stats["u1_u2"], f" {key}: sum_u.T @ sum_u", limit=0.5, decimals=2)
    _plot_heatmap(fig, ax2, stats["v1_v2"], f"{key}: sum_v @ sum_v.T", limit=0.5, decimals=2)
    plt.show()
    plt.close(fig)

    # Summed values, then the re-weighted diagonal.
    _plot_spectrum(stats["s1+s2"], "Singular values of the  datasets in concatenation")
    _plot_spectrum(stats["tilde_s"], "Singular values of the  datasets in new Sigma")

    # Projection before and after the re-weighting. Ticks are 0.002 and 0.001
    # apart respectively, so both need three decimals to stay distinguishable.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
    _plot_heatmap(fig, ax1, stats["interf"], f" {key}: interference", limit=0.01, decimals=3)
    _plot_heatmap(fig, ax2, stats["no_interf"], f"{key}: reduced interference", limit=0.005, decimals=3)
    plt.show()
    plt.close(fig)