## Imports


In [1]:
import copy
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import open_clip
import wandb
import torch.nn as nn

import hydra
import omegaconf
import pytorch_lightning as pl
import tqdm
import torch
from hydra import compose, initialize
from hydra.utils import instantiate
from lightning.pytorch import Callback
from omegaconf import DictConfig, ListConfig, OmegaConf
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from torch.utils.data import DataLoader

from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
from nn_core.common.utils import enforce_tags, seed_index_everything
from nn_core.model_logging import NNLogger
from nn_core.serialization import NNCheckpointIO
from mass.pl_module.image_classifier import ImageClassifier

# Force the execution of __init__.py if this file is executed directly.
import mass  # noqa
from mass.data.datasets.registry import get_dataset
from mass.modules.encoder import ClassificationHead, ImageEncoder
from mass.modules.projection_router import ProjectionRouter
from mass.modules.nn_router import NNRouter
from mass.modules.heads import get_classification_head
from mass.modules.router import AbstractRouter
from mass.pl_module.encoder import EncoderWrapper
from mass.utils.io_utils import load_model_from_disk
from mass.utils.plots import plot_interactive_radar_chart
from mass.utils.utils import (
    compute_task_dict, 
    apply_dict_to_model,
    build_callbacks,
    get_finetuning_accuracies,
    add_normalized_accuracy,
    compute_avg_accuracy,
    print_memory,
    get_routing_weights,
    svd_key_from_layer
)
from mass.task_vectors.task_singular_vectors import *
import json
import os

pylogger = logging.getLogger(__name__)

torch.set_float32_matmul_precision("high")

  from .autonotebook import tqdm as notebook_tqdm
Project not installed in the current env, activate the correct env or install it with:
	pip install -e .
The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  @hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default")


In [2]:
# Enable IPython's autoreload extension (mode 2) so edits to imported project modules are picked up.
%load_ext autoreload
%autoreload 2

In [34]:
def boilerplate(cfg):
    """Tag the run on ``cfg`` and create the template core and run logger.

    Adds ``n<number of eval datasets>``, the encoder model name and
    ``optim_notebook`` to ``cfg.core.tags``. Tags that are already present are
    skipped, so re-running the cell does not duplicate them. Then builds an
    ``NNTemplateCore`` from ``cfg.train.restore`` and an ``NNLogger`` that
    reuses the core's ``resume_id``, and calls ``logger.upload_source()``.

    Args:
        cfg: Hydra config; ``cfg.core.tags`` is modified in place.

    Returns:
        ``(logger, template_core)``.
    """
    cfg.core.tags = enforce_tags(cfg.core.get("tags", None))

    num_tasks = len(cfg.eval_datasets)
    run_tags = (
        f"n{num_tasks}",
        f"{cfg.nn.module.encoder.model_name}",
        "optim_notebook",
    )
    for tag in run_tags:
        # cfg is mutated in place: without this check every re-run appends duplicates.
        if tag not in cfg.core.tags:
            cfg.core.tags.append(tag)

    template_core = NNTemplateCore(
        restore_cfg=cfg.train.get("restore", None),
    )
    logger: NNLogger = NNLogger(
        logging_cfg=cfg.train.logging, cfg=cfg, resume_id=template_core.resume_id
    )

    logger.upload_source()

    return logger, template_core


def get_classification_heads(cfg: DictConfig):
    """Return one classification head per entry of ``cfg.eval_datasets``, in the same order.

    Each head is obtained from ``get_classification_head`` using the encoder's
    model name, the data path, the checkpoint path and both cache directories
    from ``cfg``.
    """
    return [
        get_classification_head(
            cfg.nn.module.encoder.model_name,
            eval_name,
            cfg.nn.data.data_path,
            cfg.misc.ckpt_path,
            cache_dir=cfg.misc.cache_dir,
            openclip_cachedir=cfg.misc.openclip_cachedir,
        )
        for eval_name in cfg.eval_datasets
    ]

def is_supported_layer(layer_key: str) -> bool:
    """
    Select the module names whose inputs LayerHook records.

    A key qualifies when it contains ``resblocks.`` and mentions ``attn`` or
    ``mlp``, and contains none of ``ln``, ``gelu``, ``bias``, ``c_proj`` or
    ``c_fc``. All checks are plain substring tests.
    """
    excluded_parts = ("ln", "gelu", "bias", "c_proj", "c_fc")
    inside_block = "resblocks." in layer_key and any(
        part in layer_key for part in ("attn", "mlp")
    )
    return inside_block and all(part not in layer_key for part in excluded_parts)

def is_supported_layer_svd(layer_key: str) -> bool:
    """
    Select the SVD-dict keys whose factors are optimised.

    A key qualifies when it contains ``resblocks.`` and mentions ``attn`` or
    ``mlp``, and contains none of ``ln``, ``gelu``, ``bias``, ``c_proj`` or
    ``out_proj``. All checks are plain substring tests.
    """
    excluded_parts = ("ln", "gelu", "bias", "c_proj", "out_proj")
    inside_block = "resblocks." in layer_key and any(
        part in layer_key for part in ("attn", "mlp")
    )
    return inside_block and all(part not in layer_key for part in excluded_parts)

def from_router_to_svd_dict_key(key):
    """Map a router layer key to the matching SVD-dict weight key.

    Every ``model.encoder.`` occurrence is removed. Then ``.in_proj_weight`` is
    appended for keys mentioning ``attn``, or ``.c_fc.weight`` for keys
    mentioning ``mlp``; ``attn`` wins when both appear. Any other key yields
    ``None``.

    NOTE(review): unlike svd_key_to_router_key, this does not drop a
    ``transformer.`` segment. Confirm the router-key format it expects.
    """
    stripped = key.replace("model.encoder.", "")
    for marker, suffix in (("attn", ".in_proj_weight"), ("mlp", ".c_fc.weight")):
        if marker in stripped:
            return stripped + suffix
    return None

def svd_key_to_router_key(svd_key: str) -> str:
    """Map an SVD-dict weight key to the module name that LayerHook records.

    ``model.visual.resblocks.N.attn.in_proj_weight`` becomes
    ``model.visual.transformer.resblocks.N.attn``, and
    ``model.visual.resblocks.N.mlp.c_fc.weight`` becomes
    ``model.visual.transformer.resblocks.N.mlp``.

    Raises:
        ValueError: if the key has neither suffix, or does not start with ``model.visual.``.
    """
    for weight_suffix in (".in_proj_weight", ".c_fc.weight"):
        if svd_key.endswith(weight_suffix):
            base = svd_key[: -len(weight_suffix)]
            break
    else:
        raise ValueError(f"Invalid SVD format {svd_key!r}")

    visual_prefix = "model.visual."
    if not base.startswith(visual_prefix):
        raise ValueError(f"Not a valid prefix {base!r}")
    return "model.visual.transformer." + base[len(visual_prefix):]

def add_transformer_key(layer: str) -> str:
    """Map an SVD-dict key to the encoder state-dict key.

    ``model.visual.resblocks...`` becomes ``model.visual.transformer.resblocks...``
    (only the first ``model.visual.`` occurrence is rewritten).

    Raises:
        ValueError: if ``layer`` does not start with ``model.visual.``.
    """
    prefix = "model.visual."
    if not layer.startswith(prefix):
        # Report the offending key itself; `base` was undefined here and raised NameError.
        raise ValueError(f"Not a valid prefix {layer!r}")
    return layer.replace(prefix, "model.visual.transformer.", 1)


## Load the configuration

In [4]:
import hydra
from hydra import initialize, compose
from typing import Dict, List

# Drop any global Hydra state from a previous run before initialising again,
# which lets this cell be re-executed.
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(version_base=None, config_path="../conf", job_name="dev")

# Compose the optimisation config with the small debug benchmark.
cfg = compose(
    config_name="optimisation",
    overrides=["nn/benchmark=debug"],
)

In [5]:
# Set the global seed from the config (reported as "Global seed set to ..." below).
seed_index_everything(cfg)

# Tag the run and create the template core and logger; `logger` is read as a global by EmbeddingsDataset.
logger, template_core = boilerplate(cfg)

Global seed set to 1608637542


[34m[1mwandb[0m: Currently logged in as: [33mzirilli-1967394[0m ([33mgladia[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [6]:
# Upper-bound (finetuned) accuracies, meant for logging the normalized accuracy.
# NOTE(review): not referenced by any later cell shown here; TODO wire it into the evaluation.
finetuned_accuracies = get_finetuning_accuracies(cfg.misc.finetuned_accuracy_path)

In [7]:
# Zero-shot (pretrained) encoder. The checkpoint holds only the vision encoder,
# with no text transformer. It is loaded non-strictly into a fresh ImageEncoder
# (the second pass backbone).
zeroshot_encoder_statedict = load_model_from_disk(cfg.misc.pretrained_checkpoint)
zeroshot_encoder: ImageEncoder = instantiate(cfg.nn.module.encoder)
zeroshot_encoder.load_state_dict(zeroshot_encoder_statedict, strict=False)

<All keys matched successfully>

In [8]:
def finetuned_name(name: str) -> Path:
    """Return the finetuned checkpoint path for dataset ``name``: ``<ckpt_path>/<name>Val/nonlinear_finetuned.pt``."""
    return Path(cfg.misc.ckpt_path) / f"{name}Val" / "nonlinear_finetuned.pt"


# Finetuned state dicts keyed by dataset name, one per task to merge.
finetuned_models = {
    dataset: load_model_from_disk(finetuned_name(dataset))
    for dataset in cfg.task_vectors.to_apply
}

# NOTE(review): counts `cfg.eval_datasets`, while the models above come from
# `cfg.task_vectors.to_apply`. Confirm the two lists are meant to match.
num_tasks = len(cfg.eval_datasets)

pylogger.info(f"Number of tasks: {num_tasks}")
pylogger.info(f"Finetuned models: {list(finetuned_models.keys())}")



In [9]:
from torch.utils.data import Dataset, TensorDataset
from collections import defaultdict

from mass.data.datasets.registry import get_task_evaluation_dataset
from mass.data.datasets.templates import get_dataset_to_label, get_dataset_label

# One dataset object per task to merge. All of them use the zero-shot encoder's
# `val_preprocess` and the *train* batch size from the config.
datasets = {
    task_name: get_dataset(
        task_name,
        preprocess_fn=zeroshot_encoder.val_preprocess,
        location=cfg.nn.data.data_path,
        batch_size=cfg.nn.data.batch_size.train,
    )
    for task_name in cfg.task_vectors.to_apply
}

Loading training dataset from ../data/fer-2013/train...
Loading test dataset from ../data/fer-2013/test...


In [10]:
class LayerHook:
    """Capture the inputs of selected sub-modules of ``model`` through forward hooks.

    A forward hook is registered on every sub-module whose qualified name passes
    ``is_supported_layer``. Each time such a module runs, its first positional
    input is permuted with ``permute(1, 0, 2)``, detached, moved to CPU and
    appended to ``middle_features[name]`` (one tensor per forward call).

    Call ``remove_hooks()`` when done, or use the instance as a context manager
    (``with LayerHook(model) as hook: ...``) to remove the hooks on exit.
    """

    def __init__(self, model: torch.nn.Module):
        # layer name -> list of captured input batches (CPU tensors)
        self.middle_features: Dict[str, List[torch.Tensor]] = defaultdict(list)
        self.hooks = []

        pylogger.info("Registering hooks...")
        for name, module in model.named_modules():
            if is_supported_layer(name):
                self.hooks.append(module.register_forward_hook(self._hook_fn(name)))

    def _hook_fn(self, name: str):
        """Build the forward hook that records inputs under ``name``."""
        def hook(module, inputs, outputs):
            # Forward hooks receive only positional args. Guard the empty tuple
            # (kwargs-only call) instead of raising IndexError inside forward.
            if isinstance(inputs, tuple):
                data = inputs[0] if inputs else None
            else:
                data = inputs

            if isinstance(data, torch.Tensor) and data.dim() == 3:
                # Assumes (seq, batch, dim) activations, stored as (batch, seq, dim).
                # TODO confirm this layout for the open_clip version in use.
                self.middle_features[name].append(data.permute(1, 0, 2).detach().cpu())
            elif isinstance(data, torch.Tensor):
                # permute(1, 0, 2) would raise here and abort the whole forward pass.
                pylogger.warning(f"Unexpected input shape {tuple(data.shape)} at layer '{name}'")
            else:
                pylogger.warning(f"Unexpected input type {type(data)} at layer '{name}'")
        return hook

    def remove_hooks(self):
        """Detach every registered hook. Calling it again is harmless."""
        for handle in self.hooks:
            handle.remove()
        self.hooks.clear()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.remove_hooks()
        return False

In [11]:
class EmbeddingsDataset(Dataset):
    """Build per-layer ``(input activation, task label)`` datasets from finetuned encoders.

    For every ``task -> state dict`` entry in ``finetuned_models``:

    - an encoder is instantiated from ``cfg.nn.module.encoder``;
    - the state dict is loaded (``strict=False``);
    - a ``LayerHook`` records the inputs of the encoder's supported layers;
    - at most ``n_batches`` batches of ``datasets[task].train_loader`` are run
      through ``Trainer.test``.

    :meth:`generate_layer_datasets` then concatenates, per layer, the captured
    inputs of all tasks together with ``get_dataset_label(task)`` labels.

    NOTE(review): despite subclassing ``Dataset``, this class defines neither
    ``__len__`` nor ``__getitem__``. The usable datasets are the
    ``TensorDataset`` values returned by :meth:`generate_layer_datasets`.

    Args:
        finetuned_models: task name -> finetuned state dict.
        datasets: task name -> dataset object exposing ``train_loader``.
        n_batches: passed to each Trainer as ``limit_test_batches``.
        cfg: config used to instantiate the encoders/wrappers and the Trainer.
        callbacks: Lightning callbacks given to every Trainer.
        logger: run logger (must expose ``run_dir``). Defaults to the
            notebook-global ``logger`` for backward compatibility.
    """

    def __init__(
        self,
        finetuned_models: Dict[str, Dict[str, torch.Tensor]],
        datasets: Dict[str, Any],
        n_batches,
        cfg: DictConfig,
        callbacks: Optional[List] = None,
        logger=None,
    ):
        super().__init__()
        self.finetuned_models = finetuned_models
        self.datasets = datasets
        self.cfg = cfg
        self.n_batches = n_batches
        self.callbacks = callbacks or []
        self.logger = logger

        # task -> LayerHook used for that task; each hook keeps its captured features alive.
        self.loggers: Dict[str, LayerHook] = {}
        self.layer_datasets: Dict[str, TensorDataset] = {}

    def _collect_task_features(self, task: str, state_dict):
        """Run ``task``'s finetuned encoder on its data and return ``(filled hook, task label)``."""
        cfg = self.cfg
        # Fall back to the notebook-global run logger when none was passed in.
        run_logger = self.logger if self.logger is not None else logger

        pylogger.info(f"Instantiating finetuned model for task: '{task}'")
        finetuned_encoder: ImageEncoder = instantiate(cfg.nn.module.encoder)
        finetuned_encoder.load_state_dict(state_dict, strict=False)

        hook = LayerHook(finetuned_encoder)
        self.loggers[task] = hook

        lt_encoder: EncoderWrapper = instantiate(
            cfg.nn.module,
            encoder=finetuned_encoder,
            _recursive_=False,
        )

        label = get_dataset_label(task)

        trainer = pl.Trainer(
            default_root_dir=cfg.core.storage_dir,
            plugins=[NNCheckpointIO(jailing_dir=run_logger.run_dir)],
            logger=run_logger,
            callbacks=self.callbacks,
            limit_test_batches=self.n_batches,
            **cfg.train.trainer,
        )

        dataloader = self.datasets[task].train_loader
        pylogger.info(f"Generating embeddings for task '{task}' with label {label}")
        try:
            trainer.test(model=lt_encoder, dataloaders=dataloader)
        finally:
            # Detach the hooks even if testing fails, so the encoder is left clean.
            hook.remove_hooks()
        return hook, label

    def generate_layer_datasets(self) -> Dict[str, TensorDataset]:
        """Capture layer inputs for every task and build one ``TensorDataset`` per layer.

        Returns:
            ``self.layer_datasets``, mapping layer name -> ``TensorDataset(features, labels)``.
            Features are the captured inputs of all tasks concatenated along dim 0.
            Labels hold the matching task label as ``torch.long``.
        """
        temp_feats: Dict[str, List[torch.Tensor]] = defaultdict(list)
        temp_labels: Dict[str, List[torch.Tensor]] = defaultdict(list)

        for task, state_dict in self.finetuned_models.items():
            hook, label = self._collect_task_features(task, state_dict)
            for layer_name, feats in hook.middle_features.items():
                for batch_feats in feats:
                    temp_feats[layer_name].append(batch_feats)
                    temp_labels[layer_name].append(
                        torch.full((batch_feats.size(0),), label, dtype=torch.long)
                    )
            # NOTE(review): `self.loggers[task]` still references this hook, so its
            # per-batch features remain in memory after the datasets below are built.

        for layer_name, feats in temp_feats.items():
            self.layer_datasets[layer_name] = TensorDataset(
                torch.cat(feats, dim=0),
                torch.cat(temp_labels[layer_name], dim=0),
            )

        return self.layer_datasets

In [12]:
# At most `cfg.number_of_train_batches` batches per task are run (passed to each Trainer as `limit_test_batches`).
embed_dt = EmbeddingsDataset(finetuned_models, datasets, cfg.number_of_train_batches, cfg)

## Generate the datasets

In [13]:
# {layer name: TensorDataset(captured layer inputs, task label)} pooled over all tasks.
data = embed_dt.generate_layer_datasets()

  rank_zero_warn(
INFO: GPU available: True (cuda), used: True


INFO: TPU available: False, using: 0 TPU cores


INFO: IPU available: False, using: 0 IPUs


INFO: HPU available: False, using: 0 HPUs


INFO: `Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


  rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(
  rank_zero_warn(


Testing DataLoader 0: 100%|██████████| 10/10 [00:02<00:00,  3.70it/s]


  rank_zero_warn(
INFO: GPU available: True (cuda), used: True


INFO: TPU available: False, using: 0 TPU cores


INFO: IPU available: False, using: 0 IPUs


INFO: HPU available: False, using: 0 HPUs


INFO: `Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 12.38it/s]


INFO: GPU available: True (cuda), used: True


INFO: TPU available: False, using: 0 TPU cores


INFO: IPU available: False, using: 0 IPUs


INFO: HPU available: False, using: 0 HPUs


INFO: `Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 12.01it/s]


In [14]:
# pylogger INFO messages produce no output in this notebook, so also display
# the captured layer names as the cell result.
pylogger.info(data.keys())
list(data.keys())

In [15]:
# Task vector per dataset to merge (presumably finetuned minus zero-shot weights; see `compute_task_dict`).
task_dicts = {}
for dataset in cfg.task_vectors.to_apply:
    task_dicts[dataset] = compute_task_dict(
        zeroshot_encoder_statedict, finetuned_models[dataset]
    )

In [16]:
# Sanity check: one task vector per dataset in `cfg.task_vectors.to_apply`.
task_dicts.keys()

dict_keys(['FER2013', 'RESISC45', 'MNIST'])

In [17]:
# Per-task, per-layer SVD factors ('u', 's', 'v') of the task vectors, compressed by
# `cfg.svd_compress_factor`; `cfg.misc.svd_path` is presumably a cache location.
# NOTE(review): datasets come from `cfg.eval_datasets` here, but `task_dicts` was
# built from `cfg.task_vectors.to_apply`. Confirm the two lists match.
svd_dicts = get_svd_dict(
    task_dicts, cfg.eval_datasets, cfg.misc.svd_path, cfg.svd_compress_factor
)

In [18]:
# pylogger INFO messages produce no output in this notebook, so also display the task names.
pylogger.info(svd_dicts.keys())
list(svd_dicts.keys())

In [19]:
# All tasks share the same layer keys (format_parameters indexes every task with
# the first task's keys). Inspect the first task rather than hard-coding
# 'FER2013', which may be absent from other benchmarks.
example_task = next(iter(svd_dicts))
pylogger.info(svd_dicts[example_task].keys())
list(svd_dicts[example_task].keys())

In [20]:
# SVD-dict layers that the optimisation will touch (those passing is_supported_layer_svd).
# Uses the first task instead of a hard-coded 'FER2013'; all tasks share the same keys.
example_task = next(iter(svd_dicts))
supported_svd_keys = [key for key in svd_dicts[example_task] if is_supported_layer_svd(key)]
for key in supported_svd_keys:
    pylogger.info(key)
supported_svd_keys


## Optimisation

In [21]:
def format_parameters(
    parameters: Dict[str, Dict[str, Dict[str, torch.Tensor]]],
    train: bool,
    device: torch.device
) -> Tuple[Dict[str, Dict[str, torch.Tensor]], Dict[str, int]]:
    """Stack per-task SVD factors into one ``nn.Parameter`` per factor and layer.

    Args:
        parameters: ``{task: {layer: {"u": ..., "s": ..., "v": ...}}}``. Every task
            is expected to have the same layer keys as the first task.
        train: ``requires_grad`` flag of the created parameters.
        device: device the stacked factors are moved to.

    Returns:
        ``(state, idx_map)``:

        - ``state[layer][name]`` holds ``torch.stack`` of that factor over tasks
          (task axis first), for layers passing ``is_supported_layer_svd``.
        - ``idx_map`` maps ``get_dataset_label(task)`` to the task's index along
          that axis.

    Raises:
        ValueError: if ``parameters`` is empty.
    """
    if not parameters:
        raise ValueError("format_parameters needs at least one task in `parameters`")

    tasks = list(parameters.keys())
    layers = parameters[tasks[0]].keys()
    state: Dict[str, Dict[str, torch.Tensor]] = {}
    for layer in layers:
        if not is_supported_layer_svd(layer):
            continue
        state[layer] = {}
        for name in ("u", "s", "v"):
            stacked = torch.stack([parameters[task][layer][name] for task in tasks]).to(device)
            # clone().detach(): the parameter must not share storage or graph with the input SVD.
            state[layer][name] = nn.Parameter(stacked.clone().detach(), requires_grad=train)

    idx_map = {get_dataset_label(task): i for i, task in enumerate(tasks)}
    return state, idx_map

class LayerOptimProblem:
    """Jointly optimise the per-task SVD factors of every supported layer.

    ``params`` holds trainable copies of the stacked per-task factors
    (``u``, ``s``, ``v`` per layer, as produced by ``format_parameters``).
    ``original`` holds a frozen copy of the same values.

    ``fit`` visits the layers one at a time. For each batch of captured layer
    inputs ``x`` with task labels ``t_x`` it minimises ``l_diff + l_int + l_sig``:

    - ``l_diff`` keeps the factors close to the originals;
    - ``l_int`` penalises the response of the *other* tasks' deltas to ``x``;
    - ``l_sig`` keeps the response of the *matching* task's delta unchanged.

    Args:
        cfg: config providing ``cfg.optim``, instantiated with ``params=[...]``.
        parameters: ``{task: {layer: {"u", "s", "v"}}}`` SVD dictionaries.
        dataset: ``{router layer key: Dataset}`` yielding ``(x, task_label)``.
            Keys are derived from SVD keys with ``svd_key_to_router_key``.
        device: device holding the factors and batches.
    """

    def __init__(
        self,
        cfg,
        parameters: Dict[str, Dict[str, Dict[str, torch.Tensor]]],
        dataset,
        device,
    ) -> None:
        self.device = device
        # format_parameters already places every factor on `device`.
        self.params, self.idx_map = format_parameters(parameters, train=True, device=self.device)
        self.original, _ = format_parameters(parameters, train=False, device=self.device)
        self.dataset = dataset
        optim_params = [
            layer_state[name]
            for layer_state in self.params.values()
            for name in ("u", "s", "v")
        ]
        self.optimizer = instantiate(cfg.optim, params=optim_params)

    def _reconstruct(self, layer_params: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return per-task deltas ``U_t diag(S_t) V_t`` with shape ``[T, out, in]``.

        ``u`` is ``[T, out, R]``, ``s`` is ``[T, R]`` and ``v`` is ``[T, R, in]``.
        """
        u = layer_params['u']
        s = layer_params['s']
        v = layer_params['v']
        # Same result as diag_embed(s) with 'trr' (which re-extracts the diagonal), without the R x R tensor.
        return torch.einsum('tir,tr,trm->tim', u, s, v)

    def _task_indices(self, t_x: torch.Tensor) -> torch.Tensor:
        """Map each dataset label in ``t_x`` to its row index along the task axis."""
        return torch.tensor(
            [self.idx_map[int(label)] for label in t_x.tolist()], device=self.device
        )

    @staticmethod
    def _masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Return the mean of ``values[mask]``, or 0 when the mask selects nothing (mean of empty is NaN)."""
        selected = values[mask]
        if selected.numel() == 0:
            return values.new_zeros(())
        return selected.mean()

    def _recon_diff_loss(self, layer: str) -> torch.Tensor:
        r"""Distance of the current factors from the original ones.

        Corresponds to:
        \sum_t\lVert U_t^l-U_t^{l'}\rVert_2 + \lVert \Sigma_t^l-\Sigma_t^{l'}\rVert_2 + \lVert V_t^l-V_t^{l'}\rVert_2

        This term is used to avoid sampling out-of-distribution weights.
        """
        delta = self.params[layer]
        delta_prime = self.original[layer]
        return (
            torch.norm(delta['u'] - delta_prime['u']) +
            torch.norm(delta['s'] - delta_prime['s']) +
            torch.norm(delta['v'] - delta_prime['v'])
        )

    def _interference_loss(self, layer: str, x: torch.Tensor, t_x: torch.Tensor) -> torch.Tensor:
        r"""Response of the task deltas that do *not* match each sample's task.

        Corresponds to:
        \sum_{x^{l-1}\in \mathcal{D}', t_x \neq t}\lVert(U_t^{l'}\Sigma_t^{l'} (V_t^{l'})^T)x^{l-1} \rVert_2

        Minimising it shrinks the interference between tasks. For each
        (task, sample) pair, the L2 norm over output features is summed over
        the patch axis. The result is averaged over mismatched pairs, or is 0
        when there are none (e.g. a single task).
        """
        delta = self._reconstruct(self.params[layer])    # [T, out, in]
        y = torch.einsum('tod,bpd->tbpo', delta, x)      # x: [B, P, in] -> y: [T, B, P, out]
        norms = torch.norm(y, dim=-1).sum(dim=-1)        # [T, B]
        task_rows = torch.arange(norms.size(0), device=self.device).unsqueeze(1)
        mask = task_rows != self._task_indices(t_x).unsqueeze(0)
        return self._masked_mean(norms, mask)

    def _signal_loss(self, layer: str, x: torch.Tensor, t_x: torch.Tensor) -> torch.Tensor:
        r"""Change of the matching task's response relative to the original deltas.

        Corresponds to:
        \sum_{x^{l-1}\in \mathcal{D}', t_x = t} \lVert ((U_t^l\Sigma_t^l (V_t^l)^T)- (U_t^{l'}\Sigma_t^{l'} (V_t^{l'})^T)x^{l-1})\rVert_2

        This keeps the optimisation from also shrinking each finetuned task's own signal.
        """
        delta_prime = self._reconstruct(self.params[layer])
        delta = self._reconstruct(self.original[layer])
        y = torch.einsum('tod,bpd->tbpo', delta_prime - delta, x)
        norms = torch.norm(y, dim=-1).sum(dim=-1)        # [T, B]
        task_rows = torch.arange(norms.size(0), device=self.device).unsqueeze(1)
        mask = task_rows == self._task_indices(t_x).unsqueeze(0)
        return self._masked_mean(norms, mask)

    def fit(
        self,
        max_epochs: int = 1,
        tol: float = 1e-4,
        batch_size: int = 32,
    ) -> None:
        """Optimise all layers for up to ``max_epochs`` epochs.

        Each epoch visits every layer in ``self.params`` and takes one optimiser
        step per batch of that layer's captured inputs. Training stops early when
        the epoch's mean batch loss, averaged over all batches of all layers,
        drops below ``tol``.

        Args:
            max_epochs: maximum number of passes over all layers.
            tol: early-stopping threshold on the epoch's mean batch loss.
            batch_size: batch size of each layer's DataLoader.
        """
        # Imported locally: the notebook's top-level `import tqdm` binds the module, not the progress bar.
        from tqdm.auto import tqdm

        for epoch in range(1, max_epochs + 1):
            total_loss = 0.0
            n_batches = 0
            last = None  # (l_diff, l_int, l_sig, loss) of the most recent batch

            layer_bar = tqdm(
                self.params,
                desc=f"Epoch {epoch}",
                unit="layer",
                position=0,
                leave=False,
                dynamic_ncols=True
            )
            for layer in layer_bar:
                layer_bar.set_description(f"Epoch {epoch} | Layer {layer}")

                dataloader = DataLoader(
                    self.dataset[svd_key_to_router_key(layer)],
                    shuffle=True,
                    batch_size=batch_size,
                )

                # position=1 nests this bar under the layer bar instead of overwriting it.
                batch_bar = tqdm(
                    dataloader,
                    desc="  Batch",
                    unit="batch",
                    position=1,
                    leave=False,
                    dynamic_ncols=True
                )
                for x, t_x in batch_bar:
                    x, t_x = x.to(self.device), t_x.to(self.device)

                    # set_to_none: factors of other layers keep grad=None, so the optimiser skips them.
                    self.optimizer.zero_grad(set_to_none=True)

                    l_diff = self._recon_diff_loss(layer)
                    l_int = self._interference_loss(layer, x, t_x)
                    l_sig = self._signal_loss(layer, x, t_x)
                    batch_loss = l_diff + l_int + l_sig
                    batch_loss.backward()
                    self.optimizer.step()

                    total_loss += batch_loss.item()
                    n_batches += 1
                    last = (l_diff.item(), l_int.item(), l_sig.item(), batch_loss.item())

                    batch_bar.set_postfix({
                        "l_diff": f"{last[0]:.4f}",
                        "l_int":  f"{last[1]:.4f}",
                        "l_sig":  f"{last[2]:.4f}",
                        "loss":   f"{last[3]:.4f}"
                    })

                if last is not None:
                    layer_bar.set_postfix(last_loss=f"{last[3]:.4f}")

            if n_batches == 0:
                pylogger.warning(f"Epoch {epoch}: no batches for any layer, stopping.")
                return

            # Average over every batch of every layer, not just the last layer's loader length.
            avg_loss = total_loss / n_batches
            last_diff, last_int, last_sig, _ = last
            pylogger.info(
                f"Epoch {epoch:3d} — Avg Loss: {avg_loss:.4f} "
                f"-- l_recon: {last_diff:.4f} "
                f"-- l_sig:   {last_sig:.4f} "
                f"-- l_int:   {last_int:.4f}"
            )
            if avg_loss < tol:
                break

In [22]:
# Use the GPU when present instead of failing on a hard-coded 'cuda' on CPU-only hosts.
problem = LayerOptimProblem(
    cfg,
    svd_dicts,
    data,
    torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)

In [23]:
# Optimise the SVD factors of every supported layer (default max_epochs=1).
problem.fit()

  Batch:   0%|          | 0/30 [00:00<?, ?batch/s]roj_weight:   0%|          | 0/24 [00:00<?, ?layer/s]

  Batch: 100%|██████████| 30/30 [00:27<00:00,  1.11batch/s, l_diff=0.0918, l_int=30.7587, l_sig=0.4506, loss=31.3011]
  Batch: 100%|██████████| 30/30 [00:26<00:00,  1.12batch/s, l_diff=0.1545, l_int=75.9984, l_sig=5.0002, loss=81.1530]=31.3011]    
  Batch: 100%|██████████| 30/30 [00:26<00:00,  1.12batch/s, l_diff=0.0670, l_int=36.2059, l_sig=0.4009, loss=36.6738]loss=81.1530]
  Batch: 100%|██████████| 30/30 [00:26<00:00,  1.13batch/s, l_diff=0.0630, l_int=40.9813, l_sig=0.7233, loss=41.7675]=36.6738]    
  Batch: 100%|██████████| 30/30 [00:27<00:00,  1.09batch/s, l_diff=0.0636, l_int=39.0460, l_sig=0.6694, loss=39.7790]loss=41.7675]
  Batch: 100%|██████████| 30/30 [00:27<00:00,  1.09batch/s, l_diff=0.0607, l_int=39.7140, l_sig=0.5769, loss=40.3516]=39.7790]    
  Batch: 100%|██████████| 30/30 [00:26<00:00,  1.11batch/s, l_diff=0.0521, l_int=43.0162, l_sig=0.7610, loss=43.8293]loss=40.3516]
  Batch: 100%|██████████| 30/30 [00:27<00:00,  1.11batch/s, l_diff=0.0776, l_int=50.4042, l_sig=

## Apply to pretrained

In [30]:
# Optimised layers, keyed by SVD-dict names (without the `transformer.` segment of the encoder state dict).
problem.params.keys()

dict_keys(['model.visual.resblocks.0.attn.in_proj_weight', 'model.visual.resblocks.0.mlp.c_fc.weight', 'model.visual.resblocks.1.attn.in_proj_weight', 'model.visual.resblocks.1.mlp.c_fc.weight', 'model.visual.resblocks.2.attn.in_proj_weight', 'model.visual.resblocks.2.mlp.c_fc.weight', 'model.visual.resblocks.3.attn.in_proj_weight', 'model.visual.resblocks.3.mlp.c_fc.weight', 'model.visual.resblocks.4.attn.in_proj_weight', 'model.visual.resblocks.4.mlp.c_fc.weight', 'model.visual.resblocks.5.attn.in_proj_weight', 'model.visual.resblocks.5.mlp.c_fc.weight', 'model.visual.resblocks.6.attn.in_proj_weight', 'model.visual.resblocks.6.mlp.c_fc.weight', 'model.visual.resblocks.7.attn.in_proj_weight', 'model.visual.resblocks.7.mlp.c_fc.weight', 'model.visual.resblocks.8.attn.in_proj_weight', 'model.visual.resblocks.8.mlp.c_fc.weight', 'model.visual.resblocks.9.attn.in_proj_weight', 'model.visual.resblocks.9.mlp.c_fc.weight', 'model.visual.resblocks.10.attn.in_proj_weight', 'model.visual.resblo

In [39]:

def merge_parameters(zeroshot, problem: LayerOptimProblem) -> Dict[str, torch.Tensor]:
    """Return a copy of ``zeroshot``'s state dict with the optimised task deltas added in.

    For each layer in ``problem.params``:

    - the per-task deltas ``U_t diag(S_t) V_t`` are reconstructed and summed over tasks;
    - the sum is cast to the target weight's device and dtype;
    - it is added to the zero-shot weight whose key is ``add_transformer_key(layer)``.

    Layers whose mapped key is missing from the state dict are skipped with a warning.

    NOTE(review): only layers optimised by ``problem`` (those passing
    ``is_supported_layer_svd``) receive a delta. Other task-vector entries are
    not applied here; confirm this is intended.
    """
    def _reconstruct(layer_params: Dict[str, torch.Tensor]) -> torch.Tensor:
        u = layer_params['u']   # [T, O, R]
        s = layer_params['s']   # [T, R]
        v = layer_params['v']   # [T, R, I] (rank-major rows, as the einsum indexes it)
        # sum_r u[t, o, r] * s[t, r] * v[t, r, i] == U_t @ diag(S_t) @ V_t
        return torch.einsum('tor,tr,tri->toi', u, s, v)

    new_state = copy.deepcopy(zeroshot.state_dict())

    # The factors are nn.Parameters; no autograd graph is needed to build the merged weights.
    with torch.no_grad():
        for layer_key, layer_params in problem.params.items():
            target_key = add_transformer_key(layer_key)
            if target_key not in new_state:
                pylogger.warning(f"Skipping layer {layer_key}")
                continue  # previously fell through and raised KeyError below
            delta = _reconstruct(layer_params).sum(dim=0)   # [out, in]
            target = new_state[target_key]
            new_state[target_key] = target + delta.to(target.device).type_as(target)

    return new_state


merged_vector = merge_parameters(zeroshot_encoder, problem)

merged_encoder: ImageEncoder = instantiate(
    cfg.nn.module.encoder
)  # the second pass backbone

merged_encoder.load_state_dict(merged_vector, strict=False)


<All keys matched successfully>

## Evaluation

In [40]:
def evaluate(model, dataset_name, preprocess_fn, config=None, dataset=None):
    """Test ``model`` on the test loader of ``dataset_name`` and return its ``"acc/test"``.

    Args:
        model: LightningModule to evaluate (e.g. an ``ImageClassifier``).
        dataset_name: registry name passed to ``get_dataset``.
        preprocess_fn: image preprocessing, used when ``dataset`` must be loaded.
        config: Hydra config. Defaults to the notebook-global ``cfg`` for backward compatibility.
        dataset: optional already-loaded dataset exposing ``test_loader``. When given, it is not reloaded.

    Returns:
        The ``"acc/test"`` entry of the first result dict from ``trainer.test``.
    """
    config = cfg if config is None else config

    if dataset is None:
        dataset = get_dataset(
            dataset_name,
            preprocess_fn=preprocess_fn,
            location=config.nn.data.data_path,
            batch_size=config.nn.data.batch_size.train,
        )

    trainer = pl.Trainer(
        **config.train.trainer,
    )

    # NOTE(review): this message mentions the val-set, but `dataset.test_loader`
    # is what is evaluated below. Confirm which split is intended.
    pylogger.error("For now evaluation supported only on val-set")

    pylogger.info(f"Evaluating on the {dataset_name} test set!")
    test_results = trainer.test(model=model, dataloaders=dataset.test_loader)

    return test_results[0]["acc/test"]

In [41]:
# Evaluate the merged encoder on every merged task with that task's classification
# head, keeping the accuracy returned by `evaluate` so it is displayed below.
# TODO(review): `finetuned_accuracies` could be used here for normalized accuracy.
merged_accuracies = {}
for dataset in cfg.task_vectors.to_apply:

    model = ImageClassifier(
            encoder=merged_encoder,
            x_key='x',
            y_key='y',
            classifier=get_classification_head(
                cfg.nn.module.encoder.model_name,
                dataset,
                cfg.nn.data.data_path,
                cfg.misc.ckpt_path,
                cache_dir=cfg.misc.cache_dir,
                openclip_cachedir=cfg.misc.openclip_cachedir,
            ),
        )

    merged_accuracies[dataset] = evaluate(model, dataset, zeroshot_encoder.val_preprocess)

merged_accuracies


Loading classification head from ../checkpoints/ViT-B-32/head_FER2013.pt


  rank_zero_warn(
  rank_zero_warn(
INFO: GPU available: True (cuda), used: True


Loading training dataset from ../data/fer-2013/train...
Loading test dataset from ../data/fer-2013/test...


INFO: TPU available: False, using: 0 TPU cores


INFO: IPU available: False, using: 0 IPUs


INFO: HPU available: False, using: 0 HPUs


INFO: `Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


Missing logger folder: /root/mass/notebooks/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Testing DataLoader 0: 100%|██████████| 225/225 [00:11<00:00, 20.29it/s]


Loading classification head from ../checkpoints/ViT-B-32/head_RESISC45.pt


INFO: GPU available: True (cuda), used: True


INFO: TPU available: False, using: 0 TPU cores


INFO: IPU available: False, using: 0 IPUs


INFO: HPU available: False, using: 0 HPUs


INFO: `Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 197/197 [00:18<00:00, 10.57it/s]


Loading classification head from ../checkpoints/ViT-B-32/head_MNIST.pt


INFO: GPU available: True (cuda), used: True


INFO: TPU available: False, using: 0 TPU cores


INFO: IPU available: False, using: 0 IPUs


INFO: HPU available: False, using: 0 HPUs


INFO: `Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 313/313 [00:12<00:00, 25.60it/s]
