# Traditional Model Merging Notebook

This notebook applies traditional model merging techniques to the vision encoder of the MTL model.

In [None]:
import os
import random

import numpy as np
import torch
import torch.optim as optim
import wandb
import yaml

# Import Jinja2 for template rendering
from jinja2 import Environment, FileSystemLoader

# Import custom modules for model merging
from model_merging.aggregator import aggregate_task_vectors
from model_merging.eval_utils import (
    evaluate_task_vector_at_coef,
    perform_eval_with_merged_vector,
)
from model_merging.utils import compute_cosine_similarity_matrix
from model_merging.task_vectors import MTLTaskVector

from models.dinov2.mtl.multitasker import MTLDinoV2
from training.create_network import *
from training.utils import TaskMetric, eval

from utils import get_data_loaders, initialize_wandb, torch_load, torch_save

# Authenticate this session with Weights & Biases before any run is created.
# NOTE(review): the run initialized later in this notebook uses mode="offline" —
# confirm that an interactive login is still required here.
wandb.login()

## Configuration

Load the configuration template and set up model classes.

In [None]:
# Render the Jinja2 config template (no template variables are supplied) and parse
# the rendered text as YAML. The loader searches ".", so the template path is
# resolved relative to the notebook's current working directory.
env = Environment(loader=FileSystemLoader("."))
template = env.get_template("config/mtl.yaml.j2")
rendered_yaml = template.render()
mm_config = yaml.safe_load(rendered_yaml)

# Seed the Python, NumPy and PyTorch RNGs from the configured seed. The same value
# is logged to W&B as "seed", so applying it here keeps the logged seed meaningful
# for anything in this notebook that draws random numbers.
SEED = mm_config["training_params"]["seed"]
if SEED is not None:  # torch.manual_seed rejects None, unlike random/np.random
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

# Registry mapping a network name to the model class that implements it.
model_classes = {
    "dinov2": MTLDinoV2,
}

## Wandb Initialization

Initialize Weights & Biases for logging experiments.

In [None]:
# Start an offline W&B run of job type "model_merging", recording the merging
# setup (network, dataset, fine-tuned checkpoints, method) plus batch size and seed.
merging_cfg = mm_config["model_merging"]
training_cfg = mm_config["training_params"]

wandb_run_config = {
    "network": merging_cfg["network"],
    "dataset": merging_cfg["dataset"],
    "batch_size": training_cfg["batch_size"],
    "ft_model_files": merging_cfg["ft_model_files"],
    "method": merging_cfg["method"],
    "seed": training_cfg["seed"],
}

# NOTE(review): the run group is taken from training_params.network, whereas the
# logged "network" entry comes from model_merging.network — confirm this is intended.
initialize_wandb(
    project=mm_config["wandb"]["project"],
    group=f"{training_cfg['network']}",
    job_type="model_merging",
    mode="offline",
    config=wandb_run_config,
)

## Model Setup

Define the tasks and initialize the pre-trained model.

In [None]:
# Per-task head specification, keyed by task name. Insertion order is kept as
# seg -> depth -> normal.
train_tasks = dict(
    # Semantic segmentation with 13 classes
    seg=dict(num_classes=13),
    # Depth estimation: single output channel, with depth bounds 0.0 and 10.0
    depth=dict(num_classes=1, min_depth=0.0, max_depth=10.0),
    # Surface normal estimation: 3 output channels
    normal=dict(num_classes=3),
)

# Constructor arguments for the pre-trained multi-task model.
pt_model_kwargs = {
    "arch_name": "vit_base",
    "head_tasks": train_tasks,  # task configurations defined above
    "head_archs": "dpt-add_small",  # decoder architecture used for the heads
    "out_index": [2, 5, 8, 11],  # output indices from transformer layers
    "cls_token": True,  # include the class token
}

# Pre-trained model that the task vectors and the merged vector are built against.
pt_model = MTLDinoV2(**pt_model_kwargs)

## Task Vectors

Load fine-tuned models and compute task vectors for each task combination.

In [None]:
# Build one MTLTaskVector per fine-tuned checkpoint listed in the config and index
# it by its head task names joined with " + ". A later checkpoint with the same
# set of head tasks replaces the earlier entry.
task_vectors = {}
for ft_file in mm_config["model_merging"]["ft_model_files"]:
    task_vector = MTLTaskVector(pt_model, ft_file)
    task_vectors[" + ".join(task_vector.head_tasks.keys())] = task_vector

In [None]:
# Show each task vector's norm (two decimals) next to its title-cased key; the list
# is the cell's final expression, so it is rendered as the cell output.
[f"{name.title()}: {vector.norm().item():.2f} " for name, vector in task_vectors.items()]

In [None]:
# Compute the cosine similarity matrix between the task vectors loaded above.
# The call is the cell's final expression, so any value it returns is rendered
# as the cell output.
compute_cosine_similarity_matrix(task_vectors)

## Aggregation

Aggregate the task vectors using the specified merging method to create a multi-task vector.

In [None]:
# Aggregate task vectors using the configured merging method (e.g. Task Arithmetic, TIES, ...).
# The call returns two values: the merged multi-task vector and `masks`, which is
# later forwarded to the evaluation as `eval_masks`.
# NOTE(review): presumably the method is selected from mm_config["model_merging"]["method"] —
# confirm in model_merging.aggregator.
mtl_task_vector, masks = aggregate_task_vectors(task_vectors, mm_config)

In [None]:
# Print a one-line summary: the dataset name and the merged vector's head tasks,
# each title-cased, with the tasks joined by " + ".
task_titles = [task.title() for task in mtl_task_vector.head_tasks.keys()]
train_tasks_str = " + ".join(task_titles)
dataset_name = mm_config["model_merging"]["dataset"].title()
print(f"Dataset: {dataset_name} | Training Task: {train_tasks_str}")

## Evaluation

Evaluate the performance of the merged model on the test dataset.

In [None]:
# Evaluate the merged task vector together with the pre-trained model, forwarding
# the masks returned by the aggregation step as `eval_masks`.
# NOTE(review): which coefficients and dataset splits are evaluated is decided inside
# perform_eval_with_merged_vector from mm_config — not visible in this notebook.
perform_eval_with_merged_vector(pt_model, mtl_task_vector, mm_config, eval_masks=masks)

In [None]:
# Additional evaluation of the merged task vector at one fixed coefficient, with
# use_val_dataset=False (presumably the test split — confirm in eval_utils).
# evaluate_task_vector_at_coef is imported in the notebook's top import cell.
# NOTE(review): this call passes eval_masks=None, whereas the main evaluation passes
# the masks returned by aggregation — confirm the difference is intentional.
EVAL_COEF = 0.4  # coefficient to evaluate at; named instead of an inline magic number

evaluate_task_vector_at_coef(
    pt_model, mtl_task_vector, mm_config, EVAL_COEF, use_val_dataset=False, eval_masks=None
)

## Cleanup

Finish the Weights & Biases run.

In [None]:
# Close the active Weights & Biases run; quiet=True asks wandb to minimize the
# log output it prints when finishing.
wandb.finish(quiet=True)