In [1]:
import yaml

from jinja2 import Environment, FileSystemLoader

from create_network import *
from model_merging.aggregator import aggregate_task_vectors
from model_merging.task_vectors import MTLTaskVector
from model_merging.eval_utils import perform_eval_with_merged_vector

In [4]:
# Jinja2 environment whose templates are looked up relative to the current
# working directory ('.').
env = Environment(loader=FileSystemLoader('.'))

# Render the templated YAML (no template variables are supplied) and parse the
# resulting text with the safe YAML loader.
template = env.get_template('config/mtl.yaml.j2')
rendered_yaml = template.render()
mm_config = yaml.safe_load(rendered_yaml)

# Lookup table from the config's network identifier to the class that is
# instantiated for it.
model_classes = dict(
    split=MTLDeepLabv3,
    mtan=MTANDeepLabv3,
)

In [5]:
# Instantiate the network class selected by the config's "network" entry,
# passing it a per-task dict.
# NOTE(review): 13 and 1 are presumably the output channels for 'seg' and
# 'depth' - confirm against the constructors in create_network.
# Alternative kept for reference - load the model from disk instead:
#   pt_model = torch_load(mm_config["model_merging"]["pt_model_file"])
network_cls = model_classes[mm_config["model_merging"]["network"]]
pt_model = network_cls({'seg': 13, 'depth': 1})

# One MTLTaskVector per entry of "ft_model_files", each built against pt_model.
task_vectors = [
    MTLTaskVector(pt_model, ft_path)
    for ft_path in mm_config["model_merging"]["ft_model_files"]
]

In [6]:
# Combine the individual task vectors into a single merged task vector, driven
# by the loaded config. Per the recorded output, this call also prints the norm
# of the shared task vector.
# NOTE(review): the aggregation strategy is decided inside
# aggregate_task_vectors and is not visible from this notebook.
mtl_task_vector = aggregate_task_vectors(task_vectors, mm_config)

Norm of shared task vector:  tensor(1486.8450)


In [5]:
# Title-case every task name held by the merged vector and join them with
# ' + ' for display (e.g. "Seg + Depth").
task_titles = [task_name.title() for task_name in mtl_task_vector.tasks.keys()]
train_tasks_str = ' + '.join(task_titles)

# One-line run summary: dataset name (title-cased) and the tasks being merged.
dataset_title = mm_config['model_merging']['dataset'].title()
print(f"Dataset: {dataset_title} | Training Task: {train_tasks_str}")

Dataset: Nyuv2 | Training Task: Seg + Depth


In [6]:
# Evaluate pt_model with the merged task vector applied, as configured by
# mm_config. Per the recorded output, this runs several evaluation rounds
# (roughly 2-3 minutes each) and prints the per-task metrics and Delta MTL for
# each round.
# NOTE(review): one recorded round reported "depth metric: nan" before the run
# was interrupted - worth investigating before trusting a full sweep.
perform_eval_with_merged_vector(pt_model, mtl_task_vector, mm_config)




Total evaluation time: 123.41s
seg metric: 0.3443 | depth metric: 0.7975
Delta MTL: -0.24



Total evaluation time: 150.27s
seg metric: 0.457 | depth metric: 0.5825
Delta MTL: -0.02



Total evaluation time: 179.90s
seg metric: 0.5023 | depth metric: 0.4965
Delta MTL: 0.07



Total evaluation time: 146.48s
seg metric: 0.0022 | depth metric: nan
Delta MTL: nan





KeyboardInterrupt: 

#### Task Arithmetic

Seg + Depth

=========================================== alpha = 0.00 ===========================================
Total evaluation time: 101.41s
seg metric: 0.2166 | depth metric: 0.8741
Delta MTL: -0.39

=========================================== alpha = 0.10 ===========================================
Total evaluation time: 119.07s
seg metric: 0.2545 | depth metric: 0.8934
Delta MTL: -0.37

=========================================== alpha = 0.50 ===========================================
Total evaluation time: 111.14s
seg metric: 0.2004 | depth metric: 1.052
Delta MTL: -0.52

...
It just gets worse for larger alpha.



Seg Only

=========================================== alpha = 0.00 ===========================================
Total evaluation time: 80.16s
seg metric: 0.2166
Delta MTL: -0.17

=========================================== alpha = 0.50 ===========================================
Total evaluation time: 75.67s
seg metric: 0.4218
Delta MTL: -0.01

=========================================== alpha = 1.00 ===========================================
Total evaluation time: 80.37s
seg metric: 0.4601
Delta MTL: 0.02


Depth Only

=========================================== alpha = 0.00 ===========================================
Total evaluation time: 79.66s
depth metric: 0.8741
Delta MTL: -0.22

=========================================== alpha = 0.50 ===========================================
Total evaluation time: 80.56s
depth metric: 1.0947
Delta MTL: -0.37

=========================================== alpha = 1.00 ===========================================
Total evaluation time: 94.97s
depth metric: 0.6197
Delta MTL: -0.06

In [8]:
# torch_save(model, 'logging/pt_model.pt')
# torch_load('logging/model_test.pt')