## Imports


In [1]:
import copy
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional

import open_clip
import wandb

import hydra
import omegaconf
import pytorch_lightning as pl
import torch
from hydra import compose, initialize
from hydra.utils import instantiate
from lightning.pytorch import Callback
from omegaconf import DictConfig, ListConfig, OmegaConf
from torch.nn.utils import parameters_to_vector, vector_to_parameters

from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
from nn_core.common.utils import enforce_tags, seed_index_everything
from nn_core.model_logging import NNLogger
from nn_core.serialization import NNCheckpointIO
from mass.pl_module.image_classifier import ImageClassifier

# Force the execution of __init__.py if this file is executed directly.
import mass  # noqa
from mass.data.datasets.registry import get_dataset
from mass.modules.encoder import ClassificationHead, ImageEncoder
from mass.modules.projection_router import ProjectionRouter
from mass.modules.nn_router import NNRouter
from mass.modules.heads import get_classification_head
from mass.modules.router import AbstractRouter
from mass.pl_module.encoder import EncoderWrapper
from mass.utils.io_utils import load_model_from_disk
from mass.utils.plots import plot_interactive_radar_chart
from mass.utils.utils import (
    compute_task_dict, 
    apply_dict_to_model,
    build_callbacks,
    get_finetuning_accuracies,
    add_normalized_accuracy,
    compute_avg_accuracy,
    print_memory,
    get_routing_weights,
    svd_key_from_layer
)
from mass.task_vectors.task_singular_vectors import *
import json
import os

pylogger = logging.getLogger(__name__)

torch.set_float32_matmul_precision("high")

  from .autonotebook import tqdm as notebook_tqdm
Project not installed in the current env, activate the correct env or install it with:
	pip install -e .
The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  @hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default")


In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
def boilerplate(cfg):
    cfg.core.tags = enforce_tags(cfg.core.get("tags", None))

    num_tasks = len(cfg.eval_datasets)
    cfg.core.tags.append(f"n{num_tasks}")
    cfg.core.tags.append(f'{cfg.nn.module.encoder.model_name}')
    cfg.core.tags.append(f'mnist_notebook')

    template_core = NNTemplateCore(
        restore_cfg=cfg.train.get("restore", None),
    )
    logger: NNLogger = NNLogger(
        logging_cfg=cfg.train.logging, cfg=cfg, resume_id=template_core.resume_id
    )

    logger.upload_source()

    return logger, template_core


def get_classification_heads(cfg: DictConfig):
    classification_heads = []

    for dataset_name in cfg.eval_datasets:

        classification_head = get_classification_head(
            cfg.nn.module.encoder.model_name,
            dataset_name,
            cfg.nn.data.data_path,
            cfg.misc.ckpt_path,
            cache_dir=cfg.misc.cache_dir,
            openclip_cachedir=cfg.misc.openclip_cachedir,
        )

        classification_heads.append(classification_head)

    return classification_heads

def is_supported_layer(layer_key: str) -> bool:
    """
    Check if layer_key contains 'mlp' or 'attn' and 'resblocks.'
    """

    return (
        ("resblocks." in layer_key)
        and (("attn" in layer_key) or ("mlp" in layer_key))
        and not ("ln" in layer_key)
        and not ("gelu" in layer_key)
        and not ("c_proj" in layer_key)
        and not ("c_fc" in layer_key)
    )


## Imports

In [4]:
import hydra
from hydra import initialize, compose
from typing import Dict, List

hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(version_base=None, config_path=str("../conf"), job_name="dev")
cfg = compose(config_name="optimisation", overrides=["nn/benchmark=debug"])

In [5]:
seed_index_everything(cfg)

logger, template_core = boilerplate(cfg)

Global seed set to 1608637542


[34m[1mwandb[0m: Currently logged in as: [33mzirilli-1967394[0m ([33mgladia[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
# upperbound accuracies, used for logging the normalized accuracy
finetuned_accuracies = get_finetuning_accuracies(cfg.misc.finetuned_accuracy_path)

In [8]:
# only has vision encoder, no text transformer
zeroshot_encoder_statedict = load_model_from_disk(cfg.misc.pretrained_checkpoint)

zeroshot_encoder: ImageEncoder = instantiate(
    cfg.nn.module.encoder
)  # the second pass backbone

zeroshot_encoder.load_state_dict(zeroshot_encoder_statedict, strict=False)

<All keys matched successfully>

In [9]:
finetuned_name = (
    lambda name: Path(cfg.misc.ckpt_path) / f"{name}Val" / "nonlinear_finetuned.pt"
)
finetuned_models = {
    dataset: load_model_from_disk(finetuned_name(dataset))
    for dataset in cfg.task_vectors.to_apply
}

num_tasks = len(cfg.eval_datasets)

pylogger.info(f"Number of tasks: {len(cfg.eval_datasets)}")
pylogger.info(f"Finetuned models: {list(finetuned_models.keys())}")



In [10]:
from torch.utils.data import Dataset, TensorDataset
from collections import defaultdict   

from mass.data.datasets.registry import get_task_evaluation_dataset
from mass.data.datasets.templates import get_dataset_to_label, get_dataset_label

datasets = {dataset_name: get_dataset(
            dataset_name,
            preprocess_fn=zeroshot_encoder.val_preprocess,
            location=cfg.nn.data.data_path,
            batch_size=cfg.nn.data.batch_size.train,
        ) for dataset_name in cfg.task_vectors.to_apply}

Loading training dataset from ../data/fer-2013/train...
Loading test dataset from ../data/fer-2013/test...


In [11]:
class LayerHook:

    def __init__(self, model: torch.nn.Module):
        self.middle_features: Dict[str, List[torch.Tensor]] = defaultdict(list)
        self.hooks = []

        pylogger.info(f"Registering hooks...")
        for name, module in model.named_modules():
            if not is_supported_layer(name):
                continue
            handle = module.register_forward_hook(self._hook_fn(name))
            self.hooks.append(handle)

    def _hook_fn(self, name: str):
        def hook(module, inputs, outputs):
            data = inputs[0] if isinstance(inputs, tuple) else inputs
            if isinstance(data, torch.Tensor):
                self.middle_features[name].append(data.detach().cpu())
            else:
                pylogger.warning(f"Unexpected input type {type(data)} at layer '{name}'")
        return hook

    def remove_hooks(self):
        for handle in self.hooks:
            handle.remove()
        self.hooks.clear()

In [12]:
class EmbeddingsDataset(Dataset):
    def __init__(
        self,
        finetuned_models: Dict[str, torch.nn.Module],
        datasets: Dict[str, pl.LightningDataModule],
        n_batches,
        cfg: dict,
        callbacks: List = None
    ):
        super().__init__()
        self.finetuned_models = finetuned_models
        self.datasets = datasets
        self.cfg = cfg
        self.n_batches = n_batches
        self.callbacks = callbacks or []

        self.loggers: Dict[str, LayerHook] = {}
        self.layer_datasets: Dict[str, TensorDataset] = {}

    def generate_layer_datasets(self) -> Dict[str, TensorDataset]:
        temp_feats: Dict[str, List[torch.Tensor]] = defaultdict(list)
        temp_labels: Dict[str, List[torch.Tensor]] = defaultdict(list)

        for task, model in self.finetuned_models.items():
            
            pylogger.info(f"Instantiating finetuned model for task: '{task}'")
            finetuned_encoder: ImageEncoder = instantiate(
                cfg.nn.module.encoder
            ) 

            finetuned_encoder.load_state_dict(model, strict=False)
                

            hook = LayerHook(finetuned_encoder)
            self.loggers[task] = hook

            lt_encoder: EncoderWrapper = instantiate(
                cfg.nn.module,
                encoder = finetuned_encoder,
                _recursive_=False,
            )
            
            label = get_dataset_label(task)

            trainer = pl.Trainer(
                default_root_dir=cfg.core.storage_dir,
                plugins=[NNCheckpointIO(jailing_dir=logger.run_dir)],
                logger=logger,
                callbacks=self.callbacks,
                limit_test_batches=self.n_batches,
                **cfg.train.trainer,
            )

            dataloader = self.datasets[task].train_loader
            pylogger.info(f"Generating embeddings for task '{task}' with label {label}")
            trainer.test(model=lt_encoder, dataloaders=dataloader)


            hook.remove_hooks()
    
            for layer_name, feats in hook.middle_features.items():
                for batch_feats in feats:
                    batch_size = batch_feats.size(0)
                    temp_feats[layer_name].append(batch_feats)
                
                    temp_labels[layer_name].append(
                        torch.full((batch_size,), label, dtype=torch.long)
                    )
            del hook

        for layer_name in temp_feats:
            all_feats = torch.cat(temp_feats[layer_name], dim=0)
            all_labels = torch.cat(temp_labels[layer_name], dim=0)
            self.layer_datasets[layer_name] = TensorDataset(all_feats, all_labels)

        return self.layer_datasets

In [13]:
embed_dt = EmbeddingsDataset(finetuned_models, datasets, cfg.number_of_train_batches, cfg)

## Generate the datasets

In [14]:
data = embed_dt.generate_layer_datasets()

InterpolationKeyError: Interpolation key 'ntasks' not found

In [33]:
task_dicts = {}
for dataset in cfg.task_vectors.to_apply:
    task_dicts[dataset] = compute_task_dict(
        zeroshot_encoder_statedict, finetuned_models[dataset]
    )

In [34]:
task_dicts.keys()

dict_keys(['FER2013', 'RESISC45', 'MNIST'])

In [35]:
svd_dicts = get_svd_dict(
    task_dicts, cfg.eval_datasets, cfg.misc.svd_path, cfg.svd_compress_factor
)

Computing and compressing SVD: 100%|██████████| 3/3 [00:07<00:00,  2.61s/it]


## Optimisation

In [None]:
class layer_optim_problem():
    def __init__(cfg, loss, data, params, alg="sgd"):
        self.loss_fn = loss
        self.dataloader = data

        # TODO: maybe I have to format the params as grouped matrices

        self.optim = instantiate(cfg.optim, params)

        self.tasks = 


    def fit():   
        # here we fit the problem, maybe best option to always use all the matrices on the fly and then create a mask to simulate the I function.


## Optimization

In [13]:
import geoopt
from geoopt import ManifoldParameter
from torch import nn 

n_epochs = 5
alpha=1.0
gamma=0.1
ref_state_dict = copy.deepcopy(zeroshot_encoder.state_dict())

In [None]:
def build_manifold_params_from_svd_dicts(ref_state_dict, svd_dicts, device="cuda"):
    """
    Go through each layer in ref_state_dict. If the layer is 2D, create:
      - U, V as geoopt.ManifoldParameter on the Stiefel manifold
      - log_sigma_diag as a standard nn.Parameter (size = k for some rank)
    We'll pick k = the sum of ranks across all tasks, or a smaller one, etc.

    Return a dict: layer_params[layer_name] = {
        "U": <ManifoldParameter>,
        "log_sigma": <nn.Parameter>,
        "V": <ManifoldParameter>,
        "rank": <int>  # the dimension k
    }
    """
    layer_params = {}
    all_datasets = list(svd_dicts.keys())

    for layer_name in ref_state_dict:
        weight = ref_state_dict[layer_name]
        # Skip e.g. "text_projection" or 1D weights:
        if "text_projection" in layer_name or weight.dim() != 2:
            continue

        # We'll define k = sum of ranks across tasks, or a smaller fraction
        # Let's do sum of ranks from all tasks for illustration:
        total_rank = 0
        for ds in all_datasets:
            if layer_name.replace(".transformer", "") not in svd_dicts[ds]:
                # e.g. if that layer doesn't exist for some reason
                continue
            
            s_ = svd_dicts[ds][layer_name.replace(".transformer", "")].get("s", None)
            if s_ is not None:
                total_rank += s_.shape[0]

        if total_rank == 0:
            # no SVD info found; skip
            continue

        # Create manifold parameters
        # We'll create random U of shape (m, total_rank) stiefel => but if total_rank > m, it's invalid
        # so we clamp it to min(m, n)?  We'll do a clamp for safety:
        m, n = weight.shape
        k = min(total_rank, m, n)

        # Stiefel manifold for U in R^{m x k} and V in R^{n x k}
        stiefel_manifold = geoopt.manifolds.Stiefel()
        U_stiefel = geoopt.ManifoldParameter(
            torch.randn(m, k, device=device),
            manifold=stiefel_manifold
        )
        V_stiefel = geoopt.ManifoldParameter(
            torch.randn(n, k, device=device),
            manifold=stiefel_manifold
        )

        # log_sigma for a diagonal of length k
        # (some tasks might effectively only need part of that rank, but we'll keep it flexible)
        log_sigma = nn.Parameter(torch.zeros(k, device=device))

        layer_params[layer_name] = {
            "U": U_stiefel,
            "log_sigma": log_sigma,
            "V": V_stiefel,
            "rank": k
        }

    return layer_params

In [15]:
ref_state_dict = copy.deepcopy(zeroshot_encoder.state_dict())

# 1) Build manifold params
layer_params = build_manifold_params_from_svd_dicts(
    ref_state_dict, svd_dicts, device='cuda'
)


Aborted!


KeyboardInterrupt: 

In [16]:
def reconstruct_deltas_from_svd_dicts(svd_dicts, device="cuda"):
    """
    Precompute each dataset's per-layer matrix Delta_t = U * diag(s) * V^T
    Return a dict: full_deltas[dataset][layer_name] = (m x n) matrix
    """
    full_deltas = {}
    for ds_name, layers_dict in svd_dicts.items():
        full_deltas[ds_name] = {}
        for layer_key, comp in layers_dict.items():
            if "dim1" in comp:
                # It's not a 2D matrix (like biases, or 1D weights). We skip for merging.
                continue
            u = comp["u"].to(device)
            s = comp["s"].to(device)
            v = comp["v"].to(device)
            # Reconstruct
            Delta_t = (u * s.unsqueeze(0)) @ v  # shape (m, n) if s is rank
            full_deltas[ds_name][layer_key] = Delta_t

    return full_deltas


def stiefel_merge_loss(
    layer_params,  # dict of {layer_name: {"U", "V", "log_sigma", "rank"}}
    full_deltas,   # from reconstruct_deltas_from_svd_dicts
    batch,         # list/tuple => (imgs, ds_names)
    alpha=1.0,
    gamma=0.1,
    device="cuda"
):
    """
    Compute total loss = sum of reconstruction for each (dataset, layer)
                        + gamma * interference penalty on the batch
    """
    (imgs, ds_names) = batch  # ds_names: list of dataset labels
    # We'll do a single pass over all layers to accumulate the total.

    total_loss = torch.tensor(0.0, device=device)

    # 1) Reconstruction penalty
    for ds_name, layer_dict in full_deltas.items():
        for layer_name, delta_matrix in layer_dict.items():
            if layer_name not in layer_params:
                continue  # e.g. text_projection or 1D

            U = layer_params[layer_name]["U"]
            log_sigma = layer_params[layer_name]["log_sigma"]
            V = layer_params[layer_name]["V"]

            Sigma_vec = torch.exp(log_sigma)  # shape (k,)
            # Reconstruct the merged version
            merged = (U * Sigma_vec.unsqueeze(0)) @ V.transpose(0,1)
            diff = merged - delta_matrix
            total_loss += alpha * diff.pow(2).sum()

    # 2) Interference penalty
    # For demonstration, let's do a simple approach: if ds_name != some anchor, penalize.
    # We'll iterate over layers, do a transform on each image, and penalize big norms if ds_name doesn't "match."
    # In real code, you'd define "which layer does classification?" or so. This is a placeholder.

    # Flatten out the batch data
    B = len(ds_names)
    # Suppose each image is shape (C,H,W). We'll flatten to (n,) for matmul with V.
    # But note in reality you'd also want a real forward pass, so treat it with caution.

    # We'll do just a naive flatten. Or pick the shape of the relevant layer. We'll do a trivial example:
    # Step 1: flatten images
    imgs_flat = imgs.view(B, -1)        # shape (B, n_?)

    # For each layer, penalize the norm if sample's ds_name doesn't match
    for layer_name, params_dict in layer_params.items():
        U = params_dict["U"]
        V = params_dict["V"]
        Sigma_vec = torch.exp(params_dict["log_sigma"])
        # shape checks:
        #   V: (n, k), so we want input x of shape (B, n)
        if imgs_flat.shape[1] != V.shape[0]:
            # mismatch => skip or clamp. In a real system, you'd unify dimension or skip the mismatch.
            continue

        # transform each image:  x -> (U * diag(Sigma) * V^T) x
        # We'll do step by step:
        Vt_x = torch.matmul(imgs_flat, V)                 # shape (B, k)
        scaled = Vt_x * Sigma_vec.unsqueeze(0)            # shape (B, k)
        transform = torch.matmul(scaled, U.transpose(0,1))# shape (B, m)

        # Now define a penalty: if ds_name != any associated with layer??? Let's guess each layer_name is "shared."
        # We'll say we penalize the L2 norm if ds_name doesn't appear in the original tasks used by that layer.
        # Or simpler: penalize for all tasks that do not match ds_names in the batch.
        # We can check if ds_name is in full_deltas?
        # We'll do a small example: if ds_name is not in full_deltas => penalty. Or we skip.
        # Actually, let's do a simple "task mismatch" check:

        # For each sample i in [0..B-1], check ds_names[i].
        # If ds_names[i] is not relevant for this layer, penalize transform^2.
        # We'll define relevant_datasets = any ds that we used in building the layer_name. 
        # (So if the layer is reconstructed from tasks [Cars, DTD, ...], then it's relevant to those tasks only.)

        # We'll build a set of datasets that contributed to this layer_name
        relevant_datasets = set()
        for ds_ in full_deltas:
            if layer_name in full_deltas[ds_]:
                relevant_datasets.add(ds_)

        # Now for each sample, if ds_name not in relevant_datasets => penalize
        penalty_mask = []
        for ds_i in ds_names:
            if ds_i not in relevant_datasets:
                penalty_mask.append(1.0)
            else:
                penalty_mask.append(0.0)
        penalty_mask = torch.tensor(penalty_mask, device=device).unsqueeze(-1)  # shape (B,1)

        # L2 norm: transform^2 summed over dim=1
        sample_norms = transform.pow(2).sum(dim=1, keepdim=True)
        interference_term = (penalty_mask * sample_norms).mean()
        total_loss += gamma * interference_term

    return total_loss


In [17]:
def stiefel_merge_training(
    ref_state_dict,
    svd_dicts,
    union_dataloader,
    device="cuda",
    alpha=1.0,
    gamma=0.1,
    n_epochs=10,
):
    """
    Demonstration of how to:
    1) Build stiefel manifold parameters for each 2D layer
    2) Reconstruct original Delta_t from svd_dicts
    3) Train an objective that merges them
    """

    # 1) Build manifold params
    layer_params = build_manifold_params_from_svd_dicts(
        ref_state_dict, svd_dicts, device=device
    )

    # 2) Pre-reconstruct Delta_t for each dataset/layer
    full_deltas = reconstruct_deltas_from_svd_dicts(svd_dicts, device=device)

    # 4) Create optimizers
    #    We need to separate manifold (U, V) vs. Euclidean (log_sigma)
    stiefel_params_list = []
    euclid_params_list = []
    for lyr, pdict in layer_params.items():
        stiefel_params_list.append({"params": pdict["U"], "lr": 1e-2})
        stiefel_params_list.append({"params": pdict["V"], "lr": 1e-2})
        euclid_params_list.append({"params": pdict["log_sigma"], "lr": 1e-2})

    opt_stiefel = geoopt.optim.RiemannianAdam(stiefel_params_list)
    opt_euclid = torch.optim.Adam(euclid_params_list)

    # 5) Training Loop
    for epoch in range(n_epochs):
        total_epoch_loss = 0.0
        num_batches = 0

        for batch_data in union_dataloader:
            # batch_data = (imgs, task_label_list)
            # Convert ds_names from list of strings to a tuple for the penalty function
            imgs, ds_names = batch_data
            # Move images to device if needed
            # ds_names is a list of strings => keep on CPU, we'll handle logic in the loss

            opt_stiefel.zero_grad()
            opt_euclid.zero_grad()

            loss_val = stiefel_merge_loss(
                layer_params=layer_params,
                full_deltas=full_deltas,
                batch=(imgs, ds_names),
                alpha=alpha,
                gamma=gamma,
                device=device
            )
            loss_val.backward()

            opt_stiefel.step()
            opt_euclid.step()

            total_epoch_loss += loss_val.item()
            num_batches += 1

        avg_loss = total_epoch_loss / max(num_batches, 1)
        print(f"Epoch {epoch+1}/{n_epochs}: loss={avg_loss:.4f}")

    # After training, each layer in layer_params has final stiefel merges
    return layer_params

In [23]:
# 2) Pre-reconstruct Delta_t for each dataset/layer
full_deltas = reconstruct_deltas_from_svd_dicts(svd_dicts, device='cuda')

In [25]:
# 4) Create optimizers
#    We need to separate manifold (U, V) vs. Euclidean (log_sigma)
stiefel_params_list = []
euclid_params_list = []
for lyr, pdict in layer_params.items():
    stiefel_params_list.append({"params": pdict["U"], "lr": 1e-2})
    stiefel_params_list.append({"params": pdict["V"], "lr": 1e-2})
    euclid_params_list.append({"params": pdict["log_sigma"], "lr": 1e-2})

opt_stiefel = geoopt.optim.RiemannianAdam(stiefel_params_list)
opt_euclid = torch.optim.Adam(euclid_params_list)


In [28]:
# 5) Training Loop
for epoch in range(n_epochs):
    total_epoch_loss = 0.0
    num_batches = 0

    for batch_data in union_dataset.test_loader:
        # batch_data = (imgs, task_label_list)
        # Convert ds_names from list of strings to a tuple for the penalty function
        imgs, ds_names = batch_data
        # Move images to device if needed
        # ds_names is a list of strings => keep on CPU, we'll handle logic in the loss

        opt_stiefel.zero_grad()
        opt_euclid.zero_grad()

        loss_val = stiefel_merge_loss(
            layer_params=layer_params,
            full_deltas=full_deltas,
            batch=(imgs, ds_names),
            alpha=alpha,
            gamma=gamma,
            device='cuda'
        )
        loss_val.backward()

        opt_stiefel.step()
        opt_euclid.step()

        total_epoch_loss += loss_val.item()
        num_batches += 1

    avg_loss = total_epoch_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1}/{n_epochs}: loss={avg_loss:.4f}")

Epoch 1/5: loss=39291330560.0000
Epoch 2/5: loss=39291330560.0000
Epoch 3/5: loss=39291330560.0000
Epoch 4/5: loss=39291330560.0000
Epoch 5/5: loss=39291330560.0000


In [21]:
# merged_params = stiefel_merge_training(
#         copy.deepcopy(zeroshot_encoder_statedict), svd_dicts,
#         device='cuda',
#         union_dataloader=union_dataset.train_loader,
#         alpha=1.0,
#         gamma=0.1,
#         n_epochs=5,
#     )

TypeError: stack(): argument 'tensors' (position 1) must be tuple of Tensors, not Tensor

In [29]:
layer_params 

{'model.positional_embedding': {'U': Parameter on Stiefel(canonical) manifold containing:
  Parameter(ManifoldParameter([[ 1.2395,  1.1817, -1.3416,  ..., -0.4088,  0.4562,
                      -1.3234],
                     [-1.2955,  0.9866,  0.6011,  ..., -1.8194, -0.3278,
                       0.9454],
                     [-1.0240, -0.3207, -0.7043,  ..., -0.8137,  1.0631,
                       0.3863],
                     ...,
                     [ 0.7060,  0.9913, -0.7432,  ...,  0.3554, -0.3818,
                       0.1962],
                     [ 0.1101, -0.5613,  1.1466,  ...,  1.0450,  0.6369,
                      -0.0896],
                     [-0.9872,  0.8011, -1.1756,  ...,  2.0918, -1.2521,
                      -1.1824]], device='cuda:0', requires_grad=True)),
  'log_sigma': Parameter containing:
  tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,

## Apply to pretrained

In [None]:
def apply_merged_params_to_model(model, merged_params, device="cuda"):
    """
    Given a model with a certain state_dict and a dictionary 'merged_params'
    (where each layer_name maps to a dict of {U, log_sigma, V, rank}),
    reconstruct the merged weight and copy it into the model's state dict.

    Args:
        model:      A PyTorch module whose weights we want to overwrite.
        merged_params: dict[str -> dict with keys {U, log_sigma, V, rank}],
                      as returned by your stiefel-based merge training.
        device:     "cuda" or "cpu" device to ensure consistent location of tensors.

    Returns:
        model: the model with updated weights in-place.
    """

    # 1) Extract the current state dict
    state_dict = model.state_dict()

    # 2) For each layer name in merged_params, reconstruct the merged matrix
    for layer_name, param_dict in merged_params.items():
        if layer_name not in state_dict:
            # Possibly "layer_name" is something like "layer1.weight"
            # but your model might store it differently. Make sure keys match up.
            print(f"Warning: '{layer_name}' not found in model state_dict. Skipping.")
            continue

        U = param_dict["U"].to(device)
        log_sigma = param_dict["log_sigma"].to(device)
        V = param_dict["V"].to(device)

        # Diagonal of Sigma = exp(log_sigma)
        sigma_vec = torch.exp(log_sigma)  # shape (k,)
        # shape: merged -> (m, n)
        #   U: (m, k)
        #   sigma_vec: (k,) => broadcast in dim=0 => (1, k)
        #   V^T: (k, n) => so we do V.transpose(0, 1)
        merged = (U * sigma_vec.unsqueeze(0)) @ V.transpose(0, 1)

        # 3) Overwrite in the state_dict
        # Ensure shape matches exactly what the model expects
        if merged.shape != state_dict[layer_name].shape:
            print(
                f"Shape mismatch on layer '{layer_name}': "
                f"merged={merged.shape}, model={state_dict[layer_name].shape}. Skipping."
            )
            continue

        # Copy the merged matrix into the model's buffer
        state_dict[layer_name].copy_(merged)

    # 4) Load the modified state dict back into model
    model.load_state_dict(state_dict)

    return model

In [38]:
model = apply_merged_params_to_model(model=zeroshot_encoder, merged_params=layer_params, device='cuda')

RuntimeError: mat1 and mat2 shapes cannot be multiplied (77x75 and 512x75)

## Evaluation

In [34]:
def evaluate(model, dataset_name, preprocess_fn):

    dataset = get_dataset(
        dataset_name,
        preprocess_fn=preprocess_fn,
        location=cfg.nn.data.data_path,
        batch_size=cfg.nn.data.batch_size.train,
    )

    trainer = pl.Trainer(
        **cfg.train.trainer,
    )

    pylogger.error("For now evaluation supported only on val-set")

    pylogger.info(f"Evaluating on the {dataset_name} test set!")
    test_results = trainer.test(model=model, dataloaders=dataset.test_loader)

    return test_results[0]["acc/test"]

In [36]:
for dataset in cfg.task_vectors.to_apply:

    model = ImageClassifier(
            encoder=encoder,
            x_key='x',
            y_key='y',
            classifier=get_classification_head(
                cfg.nn.module.encoder.model_name,
                dataset,
                cfg.nn.data.data_path,
                cfg.misc.ckpt_path,
                cache_dir=cfg.misc.cache_dir,
                openclip_cachedir=cfg.misc.openclip_cachedir,
            ),
        )
    
    evaluate(model, dataset, zeroshot_encoder.val_preprocess)


Loading classification head from /media/donato/Extra-storage/Code/model-merging/mass/checkpoints//ViT-B-32/head_FER2013.pt


Loading training dataset from /media/donato/Extra-storage/Code/model-merging/mass/data/fer-2013/train...
Loading test dataset from /media/donato/Extra-storage/Code/model-merging/mass/data/fer-2013/test...
Building classification head.


100%|██████████| 7/7 [00:00<00:00, 136.55it/s]

Saving classification head to /media/donato/Extra-storage/Code/model-merging/mass/checkpoints//ViT-B-32/head_FER2013.pt



  rank_zero_warn(
  rank_zero_warn(
INFO: GPU available: True (cuda), used: True


Loading training dataset from /media/donato/Extra-storage/Code/model-merging/mass/data/fer-2013/train...
Loading test dataset from /media/donato/Extra-storage/Code/model-merging/mass/data/fer-2013/test...


INFO: TPU available: False, using: 0 TPU cores


INFO: IPU available: False, using: 0 IPUs


INFO: HPU available: False, using: 0 HPUs


INFO: `Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


Missing logger folder: /media/donato/Extra-storage/Code/model-merging/mass/notebooks/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Testing DataLoader 0: 100%|██████████| 225/225 [00:06<00:00, 34.30it/s]


Loading classification head from /media/donato/Extra-storage/Code/model-merging/mass/checkpoints//ViT-B-32/head_RESISC45.pt


Building classification head.


100%|██████████| 45/45 [00:00<00:00, 81.90it/s]
  rank_zero_warn(
  rank_zero_warn(
INFO: GPU available: True (cuda), used: True


Saving classification head to /media/donato/Extra-storage/Code/model-merging/mass/checkpoints//ViT-B-32/head_RESISC45.pt


INFO: TPU available: False, using: 0 TPU cores


INFO: IPU available: False, using: 0 IPUs


INFO: HPU available: False, using: 0 HPUs


INFO: `Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 197/197 [00:09<00:00, 21.81it/s]


Loading classification head from /media/donato/Extra-storage/Code/model-merging/mass/checkpoints//ViT-B-32/head_MNIST.pt


Building classification head.


100%|██████████| 10/10 [00:00<00:00, 254.72it/s]
INFO: GPU available: True (cuda), used: True


Saving classification head to /media/donato/Extra-storage/Code/model-merging/mass/checkpoints//ViT-B-32/head_MNIST.pt


INFO: TPU available: False, using: 0 TPU cores


INFO: IPU available: False, using: 0 IPUs


INFO: HPU available: False, using: 0 HPUs


INFO: `Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 313/313 [00:07<00:00, 39.29it/s]
