## Imports

In [1]:
import logging
import os
from typing import List, Optional

import hydra
import omegaconf
import pytorch_lightning as pl
import torch
from lightning.pytorch import Callback
from omegaconf import DictConfig, ListConfig

from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
from nn_core.common.utils import enforce_tags, seed_index_everything
from nn_core.model_logging import NNLogger
from nn_core.serialization import NNCheckpointIO

# Force the execution of __init__.py if this file is executed directly.
import tvp  # noqa
from tvp.data.datamodule import MetaData
from tvp.data.datasets.registry import get_dataset
from tvp.task_vectors.task_vectors import TaskVector
from tvp.utils.io_utils import load_model_from_artifact
from tvp.utils.utils import build_callbacks
from torch.nn.utils import vector_to_parameters
from torch.nn.utils import parameters_to_vector

pylogger = logging.getLogger(__name__)

torch.set_float32_matmul_precision("high")

  from .autonotebook import tqdm as notebook_tqdm
The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  @hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default")


## Configuration

In [2]:
%load_ext autoreload
%autoreload 2

import hydra
from hydra import initialize, compose
from typing import Dict, List

hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(version_base=None, config_path=str("../conf"), job_name="playground")

hydra.initialize()

In [3]:
cfg = compose(config_name="task_vectors", overrides=[])

In [4]:
seed_index_everything(cfg)

cfg.core.tags = enforce_tags(cfg.core.get("tags", None))

template_core: NNTemplateCore = NNTemplateCore(
    restore_cfg=cfg.train.get("restore", None),
)
logger: NNLogger = NNLogger(logging_cfg=cfg.train.logging, cfg=cfg, resume_id=template_core.resume_id)

Global seed set to 1608637542


[34m[1mwandb[0m: Currently logged in as: [33mcrisostomi[0m ([33mgladia[0m). Use [1m`wandb login --relogin`[0m to force relogin


## Load models

In [5]:
import copy


zeroshot_identifier = f"{cfg.nn.module.model.model_name}_pt"

zeroshot_model = load_model_from_artifact(artifact_path=f"{zeroshot_identifier}:latest", run=logger.experiment)

finetuned_id_fn = lambda dataset: f"{cfg.nn.module.model.model_name}_{dataset}_{cfg.seed_index}:latest"

finetuned_models = {
    dataset: load_model_from_artifact(artifact_path=finetuned_id_fn(dataset), run=logger.experiment)
    for dataset in cfg.task_vectors.to_apply
}

zeroshot_orig_weights = copy.deepcopy(zeroshot_model.state_dict())

[34m[1mwandb[0m: Downloading large artifact ViT-B-16_pt:latest, 426.51MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.9


Loading ViT-B-16 pre-trained weights.


[34m[1mwandb[0m: Downloading large artifact ViT-B-16_SVHN_0:latest, 426.51MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.8


Loading ViT-B-16 pre-trained weights.


[34m[1mwandb[0m: Downloading large artifact ViT-B-16_MNIST_0:latest, 426.51MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.7


Loading ViT-B-16 pre-trained weights.


[34m[1mwandb[0m: Downloading large artifact ViT-B-16_CIFAR100_0:latest, 426.51MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.8


Loading ViT-B-16 pre-trained weights.


## Spherical average

In [6]:
import numpy as np


def inner_product(point, tangent_vector_a, tangent_vector_b):
    return torch.tensordot(
        tangent_vector_a, tangent_vector_b, dims=tangent_vector_a.ndim
    )

def dist(point_a, point_b):
    inner = max(min(inner_product(point_a, point_a, point_b), 1), -1)
    return torch.arccos(inner)

def norm(point, tangent_vector):
    return torch.norm(tangent_vector)

def projection(point, vector):
    return vector - inner_product(point, point, vector) * point

def ell_q(q, p):
    vector = projection(q, p - q)
    distance = dist(q, p)
    epsilon = np.finfo(np.float64).eps
    factor = (distance + epsilon) / (norm(q, vector) + epsilon)
    return factor * vector

def exp_q(point, tangent_vector):
    norm_value = norm(point, tangent_vector)
    return point * torch.cos(norm_value) + tangent_vector * torch.sinc(norm_value / np.pi)


In [7]:
from tqdm import tqdm


def spherical_weighted_average(points, weights, tol=1e-8, max_iter=1000, dim=2, verbose=False):
    """
    Compute the spherical weighted average of points on a sphere with given weights using PyTorch.
    
    Args:
    - points (torch.Tensor): A tensor of shape (n, d+1) representing n points on the d-dimensional sphere S^d.
    - weights (torch.Tensor): A tensor of shape (n,) representing the non-negative weights with sum 1.
    - tol (float): Tolerance for the stopping criterion based on the norm of u.
    - max_iter (int): Maximum number of iterations for the main loop.

    Returns:
    - q (torch.Tensor): The spherical weighted average of the input points.
    """

    print(f'Points shape: {points.shape}')
    print(f'Weights shape: {weights.shape}')

    points = points.cuda()
    weights = weights.cuda()

    points_copy = copy.deepcopy(points.clone())

    points =  torch.nn.functional.normalize(points, p=2, dim=1)

    # (num_points, d+1)
    assert points.shape[-1] == dim+1, f"points.shape = {points.shape}, dim = {dim}"

    # points have shape (num_tasks, num_params)
    # weights have shape (num_tasks,)
    
    with torch.no_grad():
        # Ensure weights sum to 1
        weights = weights / weights.sum()
        
        # Initialization, q has shape (d+1,)
        q = (weights[:, None] * points).sum(dim=0)
        q = q / (torch.norm(q) )

        assert q.shape[0] == dim+1, f"q.shape = {q.shape}, dim = {dim}"

        for _ in tqdm(range(max_iter)):
            # Compute p_i^* for each point

            # (num_points, d+1)
            p_i_stars = torch.stack([ell_q(q, p) for p in points])
            u = (weights[:, None] * (p_i_stars )).sum(dim=0) 
            q = exp_q(q, u)

            q = q / (torch.norm(q) )

            # Check if u is sufficiently small
            if torch.norm(u) < tol:
                break
            
            if verbose:
                print(f"Norm: {torch.norm(u)}")

    # solve for alphas such that alphas * points = q and such that sum(alphas) = 1

    # (num_points, d+1) --> (2, 32762332 + 1)
    constraint_weight = 100
    weights_sum_to_one = torch.full((points.shape[0], ), fill_value=constraint_weight, device=points.device).unsqueeze(1)
    points_with_constraint = torch.cat([points, weights_sum_to_one], dim=1)
    q_with_constraint = torch.cat([q, torch.tensor([constraint_weight], device=points.device)] )

    # solve the system of linear equations Ax = B, where A ~ (num_eqs, num_variables), x ~ (num_variables), B ~ (num_eqs)
    # for us, each row of A is a point and the last row is the constraint
    # this means that for each param in the models, we have an equation that
    alphas = torch.linalg.lstsq(points_with_constraint.T, q_with_constraint).solution

    interpolated_vector = (alphas[:, None] * points_copy).sum(dim=0)

    print(f'Found spherical interpolation coefficients: {alphas}, summing to {alphas.sum()}')
    print(f'Reconstruction error: {(interpolated_vector - q).sum()}')

    return interpolated_vector

## SLERP

In [8]:
import torch
from typing import Union


def slerp(
    t: Union[float, torch.Tensor],
    v0: torch.Tensor,
    v1: torch.Tensor,
    DOT_THRESHOLD: float = 0.9995,
    eps: float = 1e-8,
):
    """
    Spherical linear interpolation using PyTorch

    Args:
        t (float/torch.Tensor): Float value between 0.0 and 1.0
        v0 (torch.Tensor): Starting vector
        v1 (torch.Tensor): Final vector
        DOT_THRESHOLD (float): Threshold for considering the two vectors as
                               colinear. Not recommended to alter this.
    Returns:
        v2 (torch.Tensor): Interpolation vector between v0 and v1
    """
    # Ensure inputs are tensors
    v0 = torch.tensor(v0, dtype=torch.float32).cuda()
    v1 = torch.tensor(v1, dtype=torch.float32).cuda()

    v0_copy = v0.clone()
    v1_copy = v1.clone()

    # Normalize the vectors to get the directions and angles
    v0 = v0 / (torch.norm(v0) + eps)
    v1 = v1 / (torch.norm(v1) + eps)

    # Dot product with the normalized vectors
    dot = torch.sum(v0 * v1)

    # If absolute value of dot product is almost 1, vectors are ~colinear, so use lerp
    if torch.abs(dot) > DOT_THRESHOLD:
        print("colinear vectors")
        return None 

    # Calculate initial angle between v0 and v1
    theta_0 = torch.acos(dot)
    sin_theta_0 = torch.sin(theta_0)

    # Angle at timestep t
    theta_t = theta_0 * t
    sin_theta_t = torch.sin(theta_t)

    # Finish the slerp algorithm
    s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
    s1 = sin_theta_t / sin_theta_0

    res = s0 * v0_copy + s1 * v1_copy

    print(f'Interpolation coefficients: {s0, s1}')

    return res

## Task vectors

In [9]:
flatten = lambda model: parameters_to_vector(model.parameters()) 

zeroshot_vec = flatten(zeroshot_model)

In [30]:
task_vectors = [TaskVector.from_models(zeroshot_model, finetuned_models[dataset]) for dataset in cfg.task_vectors.to_apply]

In [31]:

def apply_task_vector(model, task_vector):
    model.load_state_dict({k: v + task_vector[k] for k, v in model.state_dict().items()})

### Aggregate task vectors

In [32]:
with torch.no_grad():
    task_vectors = torch.stack([flatten(finetuned_models[dataset]) - zeroshot_vec for dataset in cfg.task_vectors.to_apply])

### Standard task vectors

In [34]:
task_vectors_sum = torch.sum(task_vectors, dim=0)

### Spherical

In [35]:
weights = torch.full(size=(len(task_vectors),), fill_value=1/len(task_vectors)).cuda()

task_vector_spherical = spherical_weighted_average(copy.deepcopy(task_vectors), weights, tol=1e-7, max_iter=200, dim=task_vectors.shape[-1]-1)
task_vector_spherical = task_vector_spherical.cpu()

Points shape: torch.Size([3, 111792129])
Weights shape: torch.Size([3])


  3%|▎         | 6/200 [00:00<00:09, 20.19it/s]


Found spherical interpolation coefficients: tensor([0.3382, 0.3375, 0.3243], device='cuda:0'), summing to 1.000023603439331
Reconstruction error: 10.372989654541016


In [15]:
task_vectors_slerp_torch = slerp(0.5, task_vectors[0], task_vectors[1]).cpu()
print((task_vector_spherical - task_vectors_slerp_torch).norm())

  v0 = torch.tensor(v0, dtype=torch.float32).cuda()
  v1 = torch.tensor(v1, dtype=torch.float32).cuda()


Interpolation coefficients: (tensor(0.6551, device='cuda:0'), tensor(0.6551, device='cuda:0'))
tensor(2.0776)


In [17]:
# alpha = 0.8

# multi_task_vector = task_vectors_sum / len(task_vectors)
multi_task_vector = task_vector_spherical 
# multi_task_vector = task_vectors_slerp_torch

In [18]:
print(task_vector_spherical.norm())
print(task_vectors.norm(keepdim=True, dim=1))

tensor(2.2642)
tensor([[3.7747],
        [3.1578],
        [4.0069]])


In [62]:
delta_model = copy.deepcopy(zeroshot_model) 
vector_to_parameters(multi_task_vector, delta_model.parameters())

In [63]:
task_equipped_model = copy.deepcopy(zeroshot_model)
apply_task_vector(task_equipped_model, delta_model.state_dict())

In [64]:
classification_head_identifier = f"{cfg.nn.module.model.model_name}_{cfg.nn.data.dataset.dataset_name}_head"
classification_head = load_model_from_artifact(
    artifact_path=f"{classification_head_identifier}:latest", run=logger.experiment
)

model = hydra.utils.instantiate(
    cfg.nn.module, encoder=task_equipped_model, classifier=classification_head, _recursive_=False
)

[34m[1mwandb[0m:   1 of 1 files downloaded.  
  rank_zero_warn(
  rank_zero_warn(


## Load dataset

In [65]:
seed_index_everything(cfg)

dataset = get_dataset(
    cfg.nn.data.train_dataset,
    preprocess_fn=model.encoder.train_preprocess,
    location=cfg.nn.data.data_path,
    batch_size=cfg.nn.data.batch_size.train,
)

callbacks: List[Callback] = build_callbacks(cfg.train.callbacks, template_core)

storage_dir: str = cfg.core.storage_dir

pylogger.info("Instantiating the <Trainer>")
trainer = pl.Trainer(
    default_root_dir=storage_dir,
    plugins=[NNCheckpointIO(jailing_dir=logger.run_dir)],
    logger=False,
    callbacks=callbacks,
    **cfg.train.trainer,
    
)

Global seed set to 1608637542


Files already downloaded and verified
Files already downloaded and verified


INFO: GPU available: True (cuda), used: True


INFO: TPU available: False, using: 0 TPU cores


INFO: IPU available: False, using: 0 IPUs


INFO: HPU available: False, using: 0 HPUs


INFO: `Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


## Evaluation

In [66]:
# pylogger.info("Evaluating on the training set")
# trainer.test(model=model, dataloaders=dataset.train_loader)

pylogger.info("Evaluating on the test set!")
trainer.test(model=model, dataloaders=dataset.test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 40/40 [00:10<00:00,  3.76it/s]


[{'acc/test': 0.7657999992370605, 'loss/test': 1.0084812641143799}]