# Single-Input Multi-Output Merging

This notebook explores single-input multi-output (SIMO) model merging for multi-task learning (MTL) with DINOv2 models. It covers task-vector merging, LoRA adaptation of the merged backbone, and training on NYUv2 for semantic segmentation, depth estimation, and surface normal prediction.

In [None]:
# Standard library imports
import math
import os
import random

# Third-party imports
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import yaml
import wandb
from jinja2 import Environment, FileSystemLoader

# Local imports
from model_merging.aggregator import aggregate_task_vectors
from model_merging.adamerging_utils import del_attr, set_attr, load_weights
from model_merging.eval_utils import perform_eval_with_merged_vector
from model_merging.task_vectors import MTLTaskVector
from model_merging.utils import compute_cosine_similarity_matrix
from models.dinov2.mtl.multitasker import MTLDinoV2
from training.create_network import *
from training.utils import TaskMetric, eval
from utils import get_data_loaders, initialize_wandb, torch_load, torch_save

# Log in to Weights & Biases.
# NOTE(review): the run created below uses mode="offline"; if no API key is
# configured this call may prompt interactively (blocking Run All) — confirm
# it is still needed.
wandb.login()

## Configuration

Load the experiment configuration by rendering a Jinja2 template (no template variables are supplied) and parsing the result as YAML. The configuration holds the training hyperparameters, the model-merging settings, and the dataset configuration used throughout the notebook.

In [None]:
# Render the experiment template (config/mtl.yaml.j2, resolved relative to the
# current working directory, with no template variables supplied) and parse
# the rendered text as YAML.
env = Environment(loader=FileSystemLoader("."))
template = env.get_template("config/mtl.yaml.j2")
rendered_yaml = template.render()
mm_config = yaml.safe_load(rendered_yaml)

# Architecture name -> multi-task model class.
model_classes = dict(dinov2=MTLDinoV2)

In [None]:
# Start an offline W&B run; its config records the merging setup together with
# the batch size and seed used for training.
merge_cfg = mm_config["model_merging"]
train_cfg = mm_config["training_params"]
initialize_wandb(
    project=mm_config["wandb"]["project"],
    group=f"{train_cfg['network']}",
    job_type="representation_surgery",
    mode="offline",
    config=dict(
        network=merge_cfg["network"],
        dataset=merge_cfg["dataset"],
        batch_size=train_cfg["batch_size"],
        ft_model_files=merge_cfg["ft_model_files"],
        method=merge_cfg["method"],
        seed=train_cfg["seed"],
    ),
)

In [None]:
# Seed the PyTorch, NumPy and Python RNGs from the configured seed so runs are repeatable.
seed_value = mm_config["training_params"]["seed"]
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(seed_value)

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Model Initialization

Defining the multi-task learning tasks for NYUv2 and initializing the pre-trained DINOv2 model with task-specific decoder heads.

In [None]:
# NYUv2 task heads and their settings, passed to MTLDinoV2 as ``head_tasks``.
train_tasks = dict(
    seg=dict(num_classes=13),
    depth=dict(num_classes=1, min_depth=0.00, max_depth=10.0),
    normal=dict(num_classes=3),
)

# Pre-trained DINOv2 ViT-Base model with a "dpt-add_small" decoder head per task.
# ``out_index`` selects the transformer layers whose outputs feed the heads and
# ``cls_token=True`` includes the class token.
pt_model = MTLDinoV2(
    arch_name="vit_base",
    head_tasks=train_tasks,
    head_archs="dpt-add_small",
    out_index=[2, 5, 8, 11],
    cls_token=True,
)

## Task Vectors

Load fine-tuned models and compute task vectors for each task combination.

In [None]:
# Load every fine-tuned checkpoint and turn it into a task vector relative to the
# pre-trained model. Both dicts are keyed by the task vector's head-task names
# joined with " + " (e.g. "seg" or "seg + depth").
ts_models = {}
task_vectors = {}
for checkpoint_path in mm_config["model_merging"]["ft_model_files"]:
    finetuned_model = torch_load(checkpoint_path)
    vector = MTLTaskVector(pt_model, finetuned_model)
    combo_key = " + ".join(vector.head_tasks.keys())
    ts_models[combo_key], task_vectors[combo_key] = finetuned_model, vector

In [None]:
# Display the norm of every task vector, labelled by its (title-cased) task combination.
[f"{combo.title()}: {vector.norm().item():.2f} " for combo, vector in task_vectors.items()]

In [None]:
# Compute the pairwise cosine-similarity matrix between the task vectors.
# As the cell's last expression, its return value (if any) is displayed as the output.
compute_cosine_similarity_matrix(task_vectors)

## Single-Input Multi-Output Merging

### Merging Preparation

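Prepare the model for parameter merging. First attach the task-specific decoder heads. Then remove the backbone parameters from the module so merged tensors can be injected in their place, making the backbone "functional". Finally collect the parameter sets to merge: the pre-trained backbone weights first, followed by the backbone part of each task vector.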

In [None]:
# Copy each task vector's ``tau`` entries into the pre-trained model. strict=False
# means keys absent from ``tau`` are left untouched. Then put every decoder head
# in eval mode and snapshot the resulting state dict.
for vector in task_vectors.values():
    pt_model.load_state_dict(vector.tau, strict=False)

for decoder_head in pt_model.decoders.values():
    decoder_head.eval()

pt_model_dict = pt_model.state_dict()

In [None]:
# Remove the parameters named in the task vectors' ``theta`` from the module tree
# (one dotted attribute path each) so that merged tensors can later be injected
# in their place with load_weights.
reference_vector = next(iter(task_vectors.values()))
names = list(reference_vector.theta.keys())
for param_path in names:
    del_attr(pt_model, param_path.split("."))

In [None]:
# Prepare parameter lists for merging: the pre-trained backbone first, then one
# tuple per task vector.
#
# Every tuple is ordered exactly like ``names`` (the paths deleted above), which is
# the order the merge zips them in and load_weights writes them back. Indexing by
# name guarantees that alignment and raises KeyError on a missing entry. Filtering
# the state dict by head-task substrings and relying on dict order could silently
# pair tensors with the wrong names.
#
# The tensors are plain detached copies. Gradients for learnable merging
# coefficients flow through the coefficients themselves, so the source weights do
# not need requires_grad. Setting it would make every backward pass also
# backpropagate into, and accumulate .grad on, all of these CPU copies.
paramslist = [
    tuple(pt_model_dict[name].detach().cpu() for name in names)  # Pretrained backbone
]
paramslist += [
    tuple(tv.theta[name].detach().cpu() for name in names)
    for tv in task_vectors.values()
]  # Task vectors theta (backbone)
torch.cuda.empty_cache()

### LoRA Adaptation

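Define Low-Rank Adaptation (LoRA) adaptors and attach one to every transformer block of the backbone. This allows parameter-efficient fine-tuning of the shared representation. Note that each adaptor's output is subtracted from its block's output.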

In [None]:
# Define LoRA adaptor for efficient fine-tuning
class LoRAAdaptor(nn.Module):
    """Low-rank bottleneck computing ``up_proj(relu(down_proj(x)))`` on the last dimension.

    ``down_proj`` maps ``input_dim -> rank`` and is Kaiming-uniform initialised;
    ``up_proj`` maps ``rank -> input_dim`` and starts at zero, so a freshly built
    adaptor outputs all zeros. Neither projection has a bias.
    """

    def __init__(self, input_dim, rank):
        super().__init__()
        # Creation order is kept (down, up, then re-init) so RNG consumption is unchanged.
        self.down_proj = nn.Linear(input_dim, rank, bias=False)
        self.up_proj = nn.Linear(rank, input_dim, bias=False)
        self.non_linear_func = nn.ReLU()

        nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
        nn.init.zeros_(self.up_proj.weight)

    def forward(self, x):
        bottleneck = self.non_linear_func(self.down_proj(x))
        return self.up_proj(bottleneck)


def modify_block_with_lora(block, rank=16):
    """Attach a LoRA adaptor to a transformer block and patch its forward pass.

    The adaptor width is ``block.norm2.normalized_shape[0]``. After patching,
    ``block(x)`` returns ``out - block.lora(out)``, where ``out`` is the block's
    original output: the adaptor's output is subtracted from it.
    """
    hidden_dim = block.norm2.normalized_shape[0]
    block.lora = LoRAAdaptor(hidden_dim, rank=rank)
    base_forward = block.forward

    def patched_forward(x):
        out = base_forward(x)
        correction = block.lora(out)
        return out - correction

    block.forward = patched_forward


def modify_decoder_with_lora(model, task, rank=16):
    """Attach a LoRA adaptor to ``model.decoders[task]`` and patch its forward pass.

    The adaptor width is ``model.backbone.embed_dim``. The patched forward returns
    ``out - decoder.lora(out)``, where ``out`` is the decoder's original output.

    NOTE(review): nn.Linear acts on the last dimension of ``out``. This therefore
    only works if the decoder output's last dimension equals the backbone embed
    dim, and only for decoders called with a single positional argument. Confirm
    before enabling it; it is not applied in this notebook.
    """
    head = model.decoders[task]
    head.lora = LoRAAdaptor(model.backbone.embed_dim, rank=rank)
    base_forward = head.forward

    def patched_forward(x):
        out = base_forward(x)
        return out - head.lora(out)

    head.forward = patched_forward


def add_lora_to_model_blocks(model, rank=16):
    """Attach a rank-``rank`` LoRA adaptor to every block in ``model.backbone.blocks``."""
    for transformer_block in model.backbone.blocks:
        modify_block_with_lora(transformer_block, rank=rank)


def add_lora_to_model_decoders(model, rank=16):
    """Attach a rank-``rank`` LoRA adaptor to every decoder head in ``model.decoders``."""
    for task_name in model.decoders.keys():
        modify_decoder_with_lora(model, task_name, rank=rank)


# LoRA configuration: rank of every adaptor, and whether the decoder heads are
# adapted as well as the backbone blocks.
LORA_RANK = 16
ADAPT_DECODERS = False  # set True to also attach LoRA to each decoder head

add_lora_to_model_blocks(pt_model, rank=LORA_RANK)
if ADAPT_DECODERS:
    add_lora_to_model_decoders(pt_model, rank=LORA_RANK)

### Merging Framework

Define the `LambdaWrapper` class, which merges parameters using per-task merging coefficients (lambdas), similar to AdaMerging. The coefficients can be learnable or, as in this notebook, fixed.

In [None]:
# Define LambdaWrapper for merging parameters with learnable lambdas
class LambdaWrapper(torch.nn.Module):
    """Merge backbone weights from several parameter sets and run the wrapped MTL model.

    ``paramslist[0]`` holds the pre-trained backbone tensors. Every further entry
    holds one task vector's tensors, and each tuple is ordered like ``names``. For
    position ``j`` the merged tensor is ``sum_i lambda_i * paramslist[i][j]``. The
    pre-trained coefficient is fixed to 1 and the task coefficients are clamped to
    [0, 1]. The merged tensors are written into ``mtl_model`` with ``load_weights``
    before it is returned or called.

    Args:
        mtl_model: multi-task model whose parameters listed in ``names`` were
            removed so that merged tensors can be injected in their place.
        paramslist: list of tuples of tensors. The pre-trained tuple comes first,
            then one tuple per task vector.
        names: dotted attribute paths, in the same order as every tuple.
        use_learnable_lambdas: a number (float or int) creates learnable task
            coefficients initialised to that value. A list/tuple gives fixed task
            coefficients, e.g. ``[[0.4, 0.4, 0.4]]``; a flat list is treated as
            a single row.
        device: device for the merged tensors. Defaults to ``cuda:0`` when CUDA
            is available and to the CPU otherwise.

    Raises:
        TypeError: if ``use_learnable_lambdas`` is neither a number nor a list/tuple.
        ValueError: if the number of task coefficients differs from the number
            of task vectors (``len(paramslist) - 1``).
    """

    def __init__(
        self, mtl_model, paramslist, names, use_learnable_lambdas=0.0, device=None
    ):
        super(LambdaWrapper, self).__init__()
        self.paramslist = paramslist
        self.mtl_model = mtl_model
        self.head_tasks = mtl_model.head_tasks
        self.names = names
        num_task_vectors = len(paramslist) - 1

        # An int such as ``0`` would otherwise match neither branch below and
        # leave ``lambdas_raw`` undefined; treat it as the float it stands for.
        # bool is excluded: it subclasses int but is not a meaningful coefficient.
        if isinstance(use_learnable_lambdas, int) and not isinstance(
            use_learnable_lambdas, bool
        ):
            use_learnable_lambdas = float(use_learnable_lambdas)
        self.use_learnable_lambdas = use_learnable_lambdas

        # task-wise lambdas
        if isinstance(use_learnable_lambdas, (list, tuple)):
            lambdas_raw = torch.Tensor(use_learnable_lambdas)  # fixed
            if lambdas_raw.dim() == 1:
                # lambdas() concatenates along dim 1, so a flat list needs a row axis.
                lambdas_raw = lambdas_raw.unsqueeze(0)
            self.lambdas_raw = lambdas_raw
        elif isinstance(use_learnable_lambdas, float):
            self.lambdas_raw = nn.Parameter(
                torch.ones(1, num_task_vectors) * use_learnable_lambdas,
                requires_grad=True,
            )  # learnable
        else:
            raise TypeError(
                "use_learnable_lambdas must be a number (learnable lambdas) or a "
                f"list/tuple (fixed lambdas), got {type(use_learnable_lambdas).__name__}"
            )

        # _merge_parameters zips coefficients with parameter sets, and zip()
        # silently truncates: a mismatch would drop task vectors unnoticed.
        if self.lambdas_raw.shape[-1] != num_task_vectors:
            raise ValueError(
                f"expected {num_task_vectors} task lambdas (one per task vector), "
                f"got {self.lambdas_raw.shape[-1]}"
            )

        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.merge_device = torch.device(device)

    def collect_trainable_params(self):
        """Return the LoRA weights on backbone blocks and decoders, plus the lambdas if learnable."""
        trainable_params = []

        # Collect LoRA parameters from backbone blocks
        for block in self.mtl_model.backbone.blocks:
            if hasattr(block, "lora") and block.lora is not None:
                trainable_params.append(block.lora.down_proj.weight)
                trainable_params.append(block.lora.up_proj.weight)

        # Collect LoRA parameters from decoders
        for decoder in self.mtl_model.decoders.values():
            if hasattr(decoder, "lora") and decoder.lora is not None:
                trainable_params.append(decoder.lora.down_proj.weight)
                trainable_params.append(decoder.lora.up_proj.weight)

        if isinstance(self.use_learnable_lambdas, float):
            trainable_params.append(self.lambdas_raw)

        return trainable_params

    def lambdas(self):
        """Return a (1, 1 + num_tasks) row: 1 for the pre-trained set, then the clamped task lambdas."""
        pretrain_lambda = torch.ones(
            1, 1, device=self.lambdas_raw.device, dtype=self.lambdas_raw.dtype
        )
        task_lambdas = torch.clamp(self.lambdas_raw, min=0.0, max=1.0)
        return torch.cat((pretrain_lambda, task_lambdas), dim=1)

    def _merge_parameters(self, _lambda):
        """Merge parameters based on lambdas.

        Only task-wise merging (a single row of coefficients) is supported.
        The merged tensors are placed on ``self.merge_device``.
        """
        if _lambda.size(0) != 1:
            raise NotImplementedError("Layer-wise merging not implemented with LoRA.")
        # Loop-invariant: copy the coefficient row to the CPU once, not per tensor.
        coefficients = _lambda[0].cpu()
        params = tuple(
            sum(pi * coeff for pi, coeff in zip(p, coefficients))
            for p in zip(*self.paramslist)
        )
        return tuple(p.to(self.merge_device) for p in params)

    def get_model(self):
        """Load the currently merged weights into the wrapped model and return it."""
        _lambda = self.lambdas()
        params = self._merge_parameters(_lambda)
        load_weights(self.mtl_model, self.names, params)
        return self.mtl_model

    def forward(self, img, img_metas, return_loss=True, **kwargs):
        """Merge weights, then run the wrapped model on ``img``."""
        return self.get_model()(img, img_metas, return_loss=return_loss, **kwargs)

In [None]:
# Fixed merging coefficients, one per task vector; the pre-trained backbone always
# gets weight 1. The row length must equal the number of fine-tuned models in
# mm_config["model_merging"]["ft_model_files"].
rlambda = [[0.4, 0.4, 0.4]]  # Task-specific lambdas for merging

# Pass the coefficients only once, by keyword. Also passing them positionally
# fills the same parameter and raises "TypeError: got multiple values for
# argument 'use_learnable_lambdas'".
simo_model = LambdaWrapper(pt_model, paramslist, names, use_learnable_lambdas=rlambda)

# The LoRA adaptors were allocated by nn.Linear on the default (CPU) device,
# while batches are sent to ``device``. Move the wrapped modules there as well.
simo_model = simo_model.to(device)

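## Training

Only the LoRA adaptors (and the merging coefficients, when they are learnable) are passed to the optimizer; the merged backbone weights and the decoder heads are not updated. Each epoch ends with an evaluation on the validation split and a checkpoint of the backbone's state dict.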

In [None]:
# Set up data loaders, metrics, optimizer, and scheduler
train_loader, val_loader, test_loader = get_data_loaders(mm_config)
train_batch = len(train_loader)  # optimisation steps per epoch

# NOTE: overrides the configured epoch count in place. The metrics, the scheduler
# and the training loop below all read this value.
mm_config["training_params"]["total_epochs"] = 1


def _make_task_metric():
    """Build a TaskMetric over ``train_tasks`` from the training settings in ``mm_config``."""
    training_params = mm_config["training_params"]
    return TaskMetric(
        train_tasks,
        train_tasks,
        training_params["batch_size"],
        training_params["total_epochs"],
        training_params["dataset"],
        include_mtl=True,
    )


train_metric = _make_task_metric()
val_metric = _make_task_metric()

lr = 1e-5  # peak learning rate of the one-cycle schedule
optimizer = optim.AdamW(
    simo_model.collect_trainable_params(),
    lr=lr,
    betas=(0.9, 0.999),
    weight_decay=0.0,
)

scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=lr,
    steps_per_epoch=train_batch,
    epochs=mm_config["training_params"]["total_epochs"],
    pct_start=0.1,
)

In [None]:
# Training loop: optimise the LoRA adaptors (and the merging coefficients when
# they are learnable) on the merged model, then evaluate and checkpoint each epoch.
CHECKPOINT_DIR = "representation_surgery/nyuv2/dinov2/dpt-add_small_head"
# torch.save does not create missing parent directories. Create the directory up
# front so a completed epoch is not lost to a FileNotFoundError at save time.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# The set of trainable tensors does not change during training; collect it once
# instead of on every step for gradient clipping.
trainable_params = simo_model.collect_trainable_params()

for epoch in range(mm_config["training_params"]["total_epochs"]):
    for train_data, train_target in train_loader:
        train_data = train_data.to(device)
        train_target = {
            task_id: train_target[task_id].to(device)
            for task_id in simo_model.head_tasks
        }

        # Forward pass through the merged (AdaMerging-style) model
        train_res = simo_model(train_data, None, img_gt=train_target, return_loss=True)
        optimizer.zero_grad()
        train_res["total_loss"].backward()
        torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=35.0, norm_type=2)
        optimizer.step()
        scheduler.step()

        train_metric.update_metric(train_res, train_target)

    train_str = train_metric.compute_metric()
    # NOTE: the per-task losses logged here come from the epoch's last batch only.
    wandb.log(
        {
            **{
                f"train/loss/{task_id}": train_res[task_id][f"loss_{task_id}"]
                for task_id in simo_model.head_tasks
            },
            **{
                f"train/metric/{task_id}": train_metric.get_metric(task_id)
                for task_id in simo_model.head_tasks
            },
        },
    )
    train_metric.reset()

    # Evaluate the merged model on the validation split
    eval_model = simo_model.get_model()
    test_str = eval(epoch, eval_model, val_loader, val_metric)

    print(
        f"Epoch {epoch:04d} | TRAIN:{train_str} || TEST:{test_str} | Best: {mm_config['training_params']['task'].title()} {val_metric.get_best_performance(mm_config['training_params']['task']):.4f}"
    )

    # Save the backbone state dict, which includes the LoRA adaptors registered on its blocks.
    torch.save(
        {
            "state_dict_lora": simo_model.mtl_model.backbone.state_dict(),
        },
        f"{CHECKPOINT_DIR}/simo_{epoch}.pt",
    )

In [None]:
# Finish the wandb run; quiet=True keeps the end-of-run console output minimal.
wandb.finish(quiet=True)