In [None]:
import os
import yaml
import wandb

from jinja2 import Environment, FileSystemLoader

from model_merging.aggregator import aggregate_task_vectors
from model_merging.eval_utils import perform_eval_with_merged_vector
from model_merging.task_vectors import MTLTaskVector
from training.create_network import *
from utils import initialize_wandb

# Log in to Weights & Biases; later cells initialize and finish a wandb run.
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


In [2]:
# Render the Jinja2 template (resolved relative to the current directory) and
# parse the resulting YAML text into the model-merging configuration.
jinja_env = Environment(loader=FileSystemLoader('.'))
config_text = jinja_env.get_template('config/mtl.yaml.j2').render()
mm_config = yaml.safe_load(config_text)

# Ensure the "logs" folder exists; it is meant to hold training weights and losses.
os.makedirs("logs", exist_ok=True)

# Registry: network name used in the config -> model class to instantiate.
model_classes = dict(split=MTLDeepLabv3, mtan=MTANDeepLabv3)

In [3]:
# Config sections read by this cell.
wandb_cfg = mm_config["wandb"]
train_cfg = mm_config["training_params"]
merge_cfg = mm_config["model_merging"]

# Values attached to the run as its config.
run_config = {
    "network": merge_cfg["network"],
    "dataset": merge_cfg["dataset"],
    "batch_size": train_cfg["batch_size"],
    "ft_model_files": merge_cfg["ft_model_files"],
    "method": merge_cfg["method"],
    "seed": train_cfg["seed"],
}

# NOTE(review): the group name is read from training_params.network while the
# logged "network" value comes from model_merging.network — confirm that the two
# are intended to come from different sections.
initialize_wandb(
    project=wandb_cfg["project"],
    group=f"{train_cfg['network']}",
    job_type="model_merging",
    mode=wandb_cfg["mode"],
    config=run_config,
)

<module 'wandb' from '/opt/anaconda3/lib/python3.11/site-packages/wandb/__init__.py'>

In [4]:
# Pick the model class named in the config and instantiate it.
# The dict is presumably the per-task output size (13 for 'seg', 1 for 'depth') —
# TODO confirm against the model class constructor.
network_cls = model_classes[mm_config["model_merging"]["network"]]
pt_model = network_cls({'seg': 13, 'depth': 1})
# Disabled alternative: pt_model = torch_load(mm_config["model_merging"]["pt_model_file"])

# Build one MTLTaskVector per fine-tuned model file, each paired with pt_model.
task_vectors = []
for ft_file in mm_config["model_merging"]["ft_model_files"]:
    task_vectors.append(MTLTaskVector(pt_model, ft_file))

In [5]:
# Aggregate the per-file task vectors into the single vector that is evaluated below;
# the aggregation settings are taken from mm_config (exact keys used are not visible here).
mtl_task_vector = aggregate_task_vectors(task_vectors, mm_config)

Norm of shared task vector:  tensor(627.8261)


In [6]:
# Title-case each task name held by the aggregated vector and join them with " + ".
task_titles = [task.title() for task in mtl_task_vector.tasks.keys()]
train_tasks_str = ' + '.join(task_titles)

# One-line summary of the dataset and the tasks being merged.
dataset_title = mm_config['model_merging']['dataset'].title()
print(f"Dataset: {dataset_title} | Training Task: {train_tasks_str}")

Dataset: Nyuv2 | Training Task: Seg + Depth


In [7]:
# Run the evaluation for pt_model with the aggregated task vector, configured by mm_config.
# NOTE(review): the saved output of this cell reports "seg metric: nan" and "Delta MTL: nan"
# and ends with a KeyboardInterrupt — re-run and investigate before relying on these numbers.
perform_eval_with_merged_vector(pt_model, mtl_task_vector, mm_config)






Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-1.0..0.8901961].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-1.0..0.8901961].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-1.0..1.0].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-1.0..1.0].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-1.0..0.8901961].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-1.0..0.8901961].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-1.0..1.0].
Clipping input data to the valid range for im

Total evaluation time: 27.23s
seg metric: nan | depth metric: 0.8344
Delta MTL: nan





KeyboardInterrupt: 

In [8]:
# Finish the wandb run; quiet=True is passed, presumably to keep the closing output minimal.
wandb.finish(quiet=True)

VBox(children=(Label(value='1.132 MB of 1.132 MB uploaded (0.055 MB deduped)\r'), FloatProgress(value=1.0, max…

#### Task Arithmetic

Seg + Depth

=========================================== alpha = 0.00 ===========================================
Total evaluation time: 101.41s
seg metric: 0.2166 | depth metric: 0.8741
Delta MTL: -0.39

=========================================== alpha = 0.10 ===========================================
Total evaluation time: 119.07s
seg metric: 0.2545 | depth metric: 0.8934
Delta MTL: -0.37

=========================================== alpha = 0.50 ===========================================
Total evaluation time: 111.14s
seg metric: 0.2004 | depth metric: 1.052
Delta MTL: -0.52

...
It just gets worse for larger alpha.



Seg Only

=========================================== alpha = 0.00 ===========================================
Total evaluation time: 80.16s
seg metric: 0.2166
Delta MTL: -0.17

=========================================== alpha = 0.50 ===========================================
Total evaluation time: 75.67s
seg metric: 0.4218
Delta MTL: -0.01

=========================================== alpha = 1.00 ===========================================
Total evaluation time: 80.37s
seg metric: 0.4601
Delta MTL: 0.02


Depth Only

=========================================== alpha = 0.00 ===========================================
Total evaluation time: 79.66s
depth metric: 0.8741
Delta MTL: -0.22

=========================================== alpha = 0.50 ===========================================
Total evaluation time: 80.56s
depth metric: 1.0947
Delta MTL: -0.37

=========================================== alpha = 1.00 ===========================================
Total evaluation time: 94.97s
depth metric: 0.6197
Delta MTL: -0.06

In [8]:
# torch_save(model, 'logging/pt_model.pt')
# torch_load('logging/model_test.pt')