## Imports

In [1]:
import copy
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional

from model_merging.data.dataset import HFImageClassification
from model_merging.model.image_classifier import ImageClassifier
import open_clip
import wandb

import hydra
import omegaconf
import pytorch_lightning as pl
import torch
from hydra import compose, initialize
from hydra.utils import instantiate
from lightning.pytorch import Callback
from omegaconf import DictConfig, ListConfig, OmegaConf
from torch.nn.utils import parameters_to_vector, vector_to_parameters

from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
from nn_core.common.utils import enforce_tags, seed_index_everything
from nn_core.model_logging import NNLogger
from nn_core.serialization import NNCheckpointIO

# Force the execution of __init__.py if this file is executed directly.
import model_merging  # noqa
from model_merging.model.encoder import ClassificationHead, ImageEncoder
from model_merging.model.heads import (
    get_classification_head,
)
from model_merging.utils.io_utils import (
    boilerplate,
    load_model_from_hf,
)
from model_merging.utils.plots import plot_interactive_radar_chart
from model_merging.utils.utils import (
    build_callbacks,
    get_finetuning_accuracies,
    compute_avg_accuracy,
    print_memory,
)
import json
import os

  from .autonotebook import tqdm as notebook_tqdm


  import pkg_resources
The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  @hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default")


In [6]:
import hydra
from hydra import initialize, compose
from typing import Dict, List

# Clear Hydra's global singleton first so this cell can be re-run in the same
# kernel, then initialise Hydra against ../conf and compose the "multitask"
# config with the `benchmark=hard` override.
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(
    version_base=None,
    config_path="../conf",
    job_name="layer_analysis",
)
cfg = compose(
    config_name="multitask",
    overrides=["benchmark=hard"],
)

'hydra/launcher/basic' is validated against ConfigStore schema with the same name.
This behavior is deprecated in Hydra 1.1 and will be removed in Hydra 1.2.
See https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/automatic_schema_matching for migration instructions.
  coro.send(None)


## Boilerplate

In [8]:
# Logger used by the cells below; it is named after the current module.
pylogger = logging.getLogger(name=__name__)

In [7]:
# Project-level setup helpers (defined outside this notebook): seeding from the
# config, then the logger / template-core pair used later by the Trainer.
seed_index_everything(cfg)

logger, template_core = boilerplate(cfg)

num_tasks = len(cfg.benchmark.datasets)

# `cfg` is in struct mode, so new keys cannot be added directly. `open_dict`
# lifts struct mode only inside the `with` block and restores the previous
# flag on exit -- even if the assignment raises -- instead of manually
# toggling `set_struct(False)` / `set_struct(True)`.
with omegaconf.open_dict(cfg):
    cfg.num_tasks = num_tasks

# upperbound accuracies, used for logging the normalized accuracy
# (looked up for the encoder selected in the config)
finetuned_accuracies: Dict[str, float] = get_finetuning_accuracies(
    cfg.misc.finetuned_accuracy_path
)[cfg.nn.encoder.model_name]

  rank_zero_warn(


## Load models

In [11]:
# Base checkpoint for the configured encoder
# (only has vision encoder, no text transformer).
base_model: ImageEncoder = load_model_from_hf(model_name=cfg.nn.encoder.model_name)

# One fine-tuned state dict per benchmark dataset, keyed by the dataset config
# entry. `.state_dict()` is chained onto the load call so no extra reference to
# the loaded model object is kept around.
finetuned_models = {}
for dataset in cfg.benchmark.datasets:
    finetuned_models[dataset] = load_model_from_hf(
        model_name=cfg.nn.encoder.model_name,
        dataset_name=dataset.name,
    ).state_dict()

pylogger.info(f"Number of tasks: {cfg.num_tasks}")
pylogger.info(f"Finetuned models: {list(finetuned_models.keys())}")


## Merging

In [12]:
# Merging hyper-parameters: the compression ratio is the reciprocal of the
# number of benchmark datasets; non-matrix parameters use the "mean" strategy.
compression_ratio = 1 / len(cfg.benchmark.datasets)
non_matrix_params_aggregation = "mean"

In [13]:
from model_merging.merging.structured import aggregate_decomposed_task_vectors, decompose_task_vectors, get_svd_dict
from model_merging.utils.utils import apply_dict_to_model, compute_task_dict

# Build one task dict per benchmark dataset from the base weights and the
# matching fine-tuned weights.
# NOTE: this cell consumes `finetuned_models`; re-running it without first
# re-running the loading cell raises KeyError.
task_dicts = {}

for dataset in cfg.benchmark.datasets:
    pretrained_state = base_model.state_dict()
    task_dicts[dataset] = compute_task_dict(
        pretrained_state, finetuned_models[dataset]
    )

    # Drop each fine-tuned state dict as soon as it has been used (one model
    # at a time) and ask CUDA to release its cached blocks.
    del finetuned_models[dataset]
    torch.cuda.empty_cache()

print_memory("after computing task dicts")

In [14]:
# Decompose every task dict at the configured compression ratio.
decomposed_task_vectors = decompose_task_vectors(task_dicts, compression_ratio)

# Hand the aggregation a private deep copy of the base weights so that the
# base model's own state dict object is never the one passed in.
reference_state = copy.deepcopy(base_model.state_dict())
multi_task_vector = aggregate_decomposed_task_vectors(
    ref_state_dict=reference_state,
    decomposed_task_vectors=decomposed_task_vectors,
    non_matrix_params_aggregation=non_matrix_params_aggregation,
)
# Only the call above needed this name; drop it so it does not pin the copy.
del reference_state

Computing and compressing SVD: 100%|██████████| 3/3 [00:12<00:00,  4.24s/it]
Summing SVD: 100%|██████████| 158/158 [00:06<00:00, 23.47it/s]


In [15]:
# Apply the multi-task vector to a deep copy of the base model, leaving
# `base_model` itself untouched.
merged_encoder: ImageEncoder = apply_dict_to_model(
    multi_task_vector,
    copy.deepcopy(base_model),
)

In [None]:
# Evaluate the merged encoder on every benchmark dataset. Each dataset gets its
# own classification head, ImageClassifier wrapper, callbacks and Trainer; the
# per-dataset test results are collected in `results`.
results = {}
print_memory("before eval")

for dataset_cfg in cfg.benchmark.datasets:
    eval_on_train = cfg.eval_on_train

    dataset = instantiate(dataset_cfg, preprocess_fn=base_model.val_preprocess)

    classification_head = get_classification_head(
        cfg.nn.encoder.model_name,
        dataset_cfg.name,
        ckpt_path=cfg.misc.ckpt_path,
        openclip_cachedir=cfg.misc.openclip_cachedir,
        device=cfg.device,
    )

    model = ImageClassifier(
        encoder=merged_encoder,
        classifier=classification_head,
        x_key=cfg.conventions.x_key,
        y_key=cfg.conventions.y_key,
    )

    # The reference fine-tuned accuracy is stored under "<name>Val" when
    # evaluating on the train loader, and under "<name>" otherwise.
    if eval_on_train:
        accuracy_key = dataset_cfg.name + "Val"
    else:
        accuracy_key = dataset_cfg.name

    model.set_metrics(len(dataset.classnames))
    model.set_task(dataset_cfg.name)
    model.set_finetuning_accuracy(finetuned_accuracies[accuracy_key])

    callbacks: List[Callback] = build_callbacks(cfg.train.callbacks, template_core)

    # `number_of_train_batches` is only read from the config when it is
    # actually needed, i.e. when evaluating on the train loader.
    if eval_on_train:
        test_batch_limit = cfg.number_of_train_batches
    else:
        test_batch_limit = None

    trainer = pl.Trainer(
        default_root_dir=cfg.core.storage_dir,
        plugins=[NNCheckpointIO(jailing_dir=logger.run_dir)],
        logger=logger,
        callbacks=callbacks,
        limit_test_batches=test_batch_limit,
        **cfg.train.trainer,
    )

    # Log which loader is used, pick it, then run a single test pass.
    if eval_on_train:
        pylogger.error("For now evaluation supported only on val-set")
        pylogger.info(f"Evaluating on {dataset_cfg.name} the training set")
        eval_loader = dataset.train_loader
    else:
        pylogger.info(f"Evaluating on the {dataset_cfg.name} test set!")
        eval_loader = dataset.test_loader

    test_results = trainer.test(model=model, dataloaders=eval_loader)

    results[dataset_cfg.name] = test_results

avg = compute_avg_accuracy(results)
# as a list for consistency due to lightning logging stuff this way
results["avg"] = [avg]

logger.experiment.log(avg)

pylogger.info(results)