# MTL Task Relationships Analysis

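This notebook analyzes the relationships between different tasks in a multi-task DINOv2 model. It evaluates how task vectors perform when scaled and merged: each task vector is scaled over a range of coefficients and the resulting model is evaluated on every task, providing insights into task interference and synergies.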

In [None]:
import sys
import os

# Make the parent of the current working directory importable so the local
# `model_merging`, `models`, `training` and `utils` packages resolve.
# The membership check keeps the cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path.
project_root = os.path.abspath("..")
if project_root not in sys.path:
    sys.path.append(project_root)

In [None]:
import itertools
from itertools import chain

# Third-party imports
from jinja2 import Environment, FileSystemLoader
import matplotlib.pyplot as plt
import numpy as np
import wandb
import yaml

# Local imports
from model_merging.aggregator import aggregate_task_vectors
from model_merging.eval_utils import (
    evaluate_task_vector_at_coef,
)
from model_merging.task_vectors import MTLTaskVector
from model_merging.utils import compute_cosine_similarity_matrix
from models.dinov2.mtl.multitasker import MTLDinoV2
from training.create_network import *
from utils import initialize_wandb

# Login to wandb
# Authenticates with Weights & Biases before any run is started.
# NOTE(review): the run created later in this notebook is initialized with
# mode="offline"; an interactive login may be unnecessary for that — confirm.
wandb.login()

## Configuration

Load the configuration template and set up model classes.

In [None]:
# Render the Jinja2 template config/mtl.yaml.j2 (resolved relative to the
# current working directory) without any template variables, then parse the
# rendered text as YAML. Later cells index the result as a nested mapping.
config_env = Environment(loader=FileSystemLoader("."))
mm_config = yaml.safe_load(config_env.get_template("config/mtl.yaml.j2").render())

# Registry from network name to the model class implementing it.
# NOTE(review): not referenced by any later cell in this notebook.
model_classes = dict(dinov2=MTLDinoV2)

## Wandb Initialization

Initialize Weights & Biases for logging experiments.

In [None]:
# Start a W&B run for this analysis: grouped by the training network, tagged
# with job_type "task_relationships", and carrying the merging settings in its
# run config. NOTE(review): the group uses training_params.network while the
# logged config uses model_merging.network — confirm these are meant to differ.
project_name = mm_config["wandb"]["project"]
group_name = f"{mm_config['training_params']['network']}"

merging_cfg = mm_config["model_merging"]
training_cfg = mm_config["training_params"]
run_config = {
    "network": merging_cfg["network"],
    "dataset": merging_cfg["dataset"],
    "batch_size": training_cfg["batch_size"],
    "ft_model_files": merging_cfg["ft_model_files"],
    "method": merging_cfg["method"],
    "seed": training_cfg["seed"],
}

initialize_wandb(
    config=run_config,
    group=group_name,
    job_type="task_relationships",
    mode="offline",
    project=project_name,
)

## Model Setup

Define the tasks and initialize the pre-trained model.

In [None]:
# Per-task head configuration for the multi-task model: segmentation with 13
# classes, single-channel depth bounded to [0.0, 10.0], and 3-channel normals.
train_tasks = dict(
    seg=dict(num_classes=13),
    depth=dict(num_classes=1, min_depth=0.0, max_depth=10.0),
    normal=dict(num_classes=3),
)

# Pre-trained ViT-Base DINOv2 backbone with one "dpt-add_small" decoder head
# per task. Features are taken from transformer layers 2, 5, 8 and 11, and the
# class token is included.
pt_model = MTLDinoV2(
    cls_token=True,
    out_index=[2, 5, 8, 11],
    head_archs="dpt-add_small",
    head_tasks=train_tasks,
    arch_name="vit_base",
)

## Task Vectors

Load each fine-tuned checkpoint and compute its task vector, keyed by the task (or `+`-joined combination of tasks) the checkpoint was fine-tuned on.

In [None]:
# Build one task vector per fine-tuned checkpoint, keyed by the checkpoint's
# head task names joined with " + " (e.g. "seg" or "seg + depth").
task_vectors = {}
for ft_file in mm_config["model_merging"]["ft_model_files"]:
    task_vector = MTLTaskVector(pt_model, ft_file)
    task_key = " + ".join(task_vector.head_tasks.keys())
    # Two checkpoints with the same head tasks would otherwise silently
    # overwrite each other, and every later cell would analyse only the last.
    if task_key in task_vectors:
        raise ValueError(
            f"Duplicate task vector key {task_key!r} (from {ft_file!r}); "
            "each fine-tuned checkpoint must cover a distinct set of tasks."
        )
    task_vectors[task_key] = task_vector

In [None]:
# Show "<Task>: <norm> " for every task vector, with the value returned by
# norm() formatted to two decimals. The list is the cell's displayed output.
[f"{name.title()}: {vector.norm().item():.2f} " for name, vector in task_vectors.items()]

In [None]:
# Pairwise cosine similarity between all task vectors. Bind the result, then
# leave it as the last expression so Jupyter displays it.
similarity_matrix = compute_cosine_similarity_matrix(task_vectors)
similarity_matrix

## Evaluation of Task Relationships

Evaluate each task vector on all tasks at 16 scaling coefficients from 0.0 to 1.5 (step 0.1).

In [None]:
# For each task vector taken on its own (the "primary task"):
#   1. aggregate it alone,
#   2. attach the tau entries and head_tasks of *all* task vectors (later
#      vectors win on key collisions) so every task can be scored,
#   3. evaluate at each scaling coefficient.
# NOTE(review): presumably tau carries the task-specific head weights, so this
# override keeps every head available — confirm against MTLTaskVector.
# Result layout: task_relationships[(task,)][coef] = evaluate_task_vector_at_coef(...)

# 16 evenly spaced coefficients from 0.0 to 1.5 (step 0.1).
scaling_coef_range = np.linspace(0.0, 1.5, 16)

task_relationships = {}
for primary_name in task_vectors:
    pri_task = (primary_name,)
    print(f"\nMain task: {pri_task}")

    mtl_tv, _ = aggregate_task_vectors(
        {primary_name: task_vectors[primary_name]}, mm_config
    )
    mtl_tv.tau = {
        key: value
        for vector in task_vectors.values()
        for key, value in vector.tau.items()
    }
    mtl_tv.head_tasks = {
        key: value
        for vector in task_vectors.values()
        for key, value in vector.head_tasks.items()
    }

    results_by_coef = {}
    for coef in scaling_coef_range:
        print(f"Evaluating model with task coefficient: {coef}")
        results_by_coef[coef] = evaluate_task_vector_at_coef(
            pt_model,
            mtl_tv,
            mm_config,
            coef,
            use_val_dataset=False,
            eval_masks=None,
        )
    task_relationships[pri_task] = results_by_coef

## Baseline Metrics

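Define per-task baseline metrics used to normalize the results. Segmentation is higher-is-better, while depth and surface-normal errors are lower-is-better.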

In [None]:
# Reference metric per task, used as the denominator/numerator when
# normalizing the sweep results.
# NOTE(review): presumably the scores of the individually fine-tuned
# single-task models — confirm where these numbers come from.
ts_metrics = dict(
    seg=0.7079,
    depth=0.2314,
    normal=18.02,
)

## Plotting Results

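Visualize the normalized metrics for each primary task across scaling coefficients. Metrics are normalized so that 1.0 equals the baseline, and values above 1.0 are better than the baseline for every task.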

In [None]:
# Plot, for each primary task, how every task's normalized metric changes as
# the primary task vector's scaling coefficient is swept.

# Colour and legend label for each task's curve.
task_colors = {"seg": "blue", "depth": "green", "normal": "red"}
task_labels = {"seg": "Segmentation", "depth": "Depth", "normal": "Surface Normal"}

# Higher-is-better tasks are normalized as metric / baseline. All others
# (depth, normal: lower-is-better errors) are normalized as baseline / metric.
# This way a value above 1.0 means "better than baseline" on every curve.
higher_is_better = {"seg"}

num_tasks = len(task_vectors)
# squeeze=False always returns a 2-D array of Axes, so the single-subplot case
# needs no special handling.
fig, axs = plt.subplots(1, num_tasks, figsize=(6 * num_tasks, 6), squeeze=False)

for ax, task in zip(axs[0], task_vectors):
    info = task_relationships[(task,)]
    x_values = sorted(info.keys())  # coefficients in ascending order

    for t in ["seg", "depth", "normal"]:
        # NOTE(review): assumes info[coef][t]["metric"][0] is the task's
        # headline metric — confirm against evaluate_task_vector_at_coef.
        raw = [info[key][t]["metric"][0] for key in x_values]
        if t in higher_is_better:
            normalized = [value / ts_metrics[t] for value in raw]
        else:
            normalized = [ts_metrics[t] / value for value in raw]
        ax.plot(
            x_values,
            normalized,
            label=task_labels[t],
            marker="o",
            color=task_colors[t],
            linewidth=2,
        )

    # A normalized value of 1.0 matches the baseline metric.
    ax.axhline(y=1.0, color="black", linestyle="--", linewidth=1, label="Baseline")

    ax.set_title(f"Primary Task: {task}", fontsize=18)
    ax.set_xlabel("Task Arithmetic Coefficient", fontsize=16)
    ax.set_ylabel("Normalized Metric", fontsize=16)
    ax.set_ylim(0.0, 1.5)  # normalized values outside [0, 1.5] fall off-axis
    ax.tick_params(axis="both", labelsize=14)
    ax.legend(fontsize=12)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Cleanup

Finish logging and clean up resources.

In [None]:
# Finish Wandb logging
# Closes the run started by initialize_wandb above; quiet=True suppresses the
# end-of-run summary output.
wandb.finish(quiet=True)