In [1]:
import torch.optim as optim
import torch.utils.data.sampler as sampler
import yaml

from auto_lambda import AutoLambda
from create_network import *
from create_dataset import *
from utils import *

## Multi-Task Learning

In [92]:
# Options for training: every knob used below comes from config/mtl.yaml (parsed with safe_load).
with open('config/mtl.yaml', 'r') as config_file:
    mtl_config = yaml.safe_load(config_file)

# training_params["network"] -> model class to instantiate.
model_classes = {"split": MTLDeepLabv3, "mtan": MTANDeepLabv3}

In [93]:
# Seed torch, numpy's global RNG and Python's `random` (in that order) with the configured seed.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(mtl_config["training_params"]["seed"])

# Use the default CUDA device when available, otherwise fall back to the CPU.
# NOTE(review): selecting a specific GPU via training_params["gpu"] was sketched before but is not wired up.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [94]:
# Create logging folder to store training weights and losses
os.makedirs('logging', exist_ok=True)

# Task flags: train on all tasks of the dataset; `task` selects the primary task(s).
train_tasks = create_task_flags('all', mtl_config["training_params"]["dataset"], with_noise=mtl_config["training_params"]["with_noise"])
pri_tasks = create_task_flags(mtl_config["training_params"]["task"], mtl_config["training_params"]["dataset"], with_noise=False)

train_tasks_str = ' + '.join(task.title() for task in train_tasks.keys())
pri_tasks_str = ' + '.join(task.title() for task in pri_tasks.keys())
print(f"Dataset: {mtl_config['training_params']['dataset'].title()} | Training Task: {train_tasks_str} | Primary Task: {pri_tasks_str} in Multi-task / Auxiliary Learning Mode with {mtl_config['training_params']['network'].upper()}")
print(f"Applying Multi-task Methods | Weighting-based: {mtl_config['training_params']['weight'].title()} + Gradient-based: {mtl_config['training_params']['grad_method'].upper()}")

# Initialize model; reject unknown architectures with a readable error rather than a bare KeyError.
network_name = mtl_config["training_params"]["network"]
if network_name not in model_classes:
    raise ValueError(f"Unsupported network: {network_name} (expected one of {sorted(model_classes)})")
model = model_classes[network_name](train_tasks).to(device)
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model: {mtl_config['training_params']['network'].title()} | Number of Trainable Parameters: {num_params/1e6:.2f}M")

# Choose task weighting. Each branch creates the state the training loop reads for that method.
weight_method = mtl_config["training_params"]["weight"]
params = model.parameters()
if weight_method == "uncert":
    # Learnable per-task `logsigma` (initialised to -0.7), optimised jointly with the network.
    logsigma = torch.tensor([-0.7] * len(train_tasks), requires_grad=True, device=device)
    params = list(params) + [logsigma]
    logsigma_ls = np.zeros((mtl_config["training_params"]["total_epochs"], len(train_tasks)), dtype=np.float32)

elif weight_method in ["dwa", "equal"]:
    T = 2.0  # Temperature used in DWA
    # Per-epoch task weights; stays all-ones for "equal", updated each epoch for "dwa".
    lambda_weight = np.ones((mtl_config["training_params"]["total_epochs"], len(train_tasks)))

elif weight_method == 'autol':
    autol = AutoLambda(model, device, train_tasks, pri_tasks, mtl_config["training_params"]["autol_init"])
    meta_weight_ls = np.zeros((mtl_config["training_params"]["total_epochs"], len(train_tasks)), dtype=np.float32)
    meta_optimizer = optim.Adam([autol.meta_weights], lr=mtl_config["training_params"]["autol_lr"])

else:
    # Previously an unknown method fell through and failed later with an obscure NameError.
    raise ValueError(f"Unsupported weighting method: {weight_method}")

# Initialize optimizer and scheduler
optimizer = optim.SGD(params, lr=0.1, weight_decay=1e-4, momentum=0.9)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, mtl_config["training_params"]["total_epochs"])

Dataset: Nyuv2 | Training Task: Seg + Depth + Normal | Primary Task: Seg + Depth + Normal in Multi-task / Auxiliary Learning Mode with SPLIT
Applying Multi-task Methods | Weighting-based: Equal + Gradient-based: NONE
Model: Split | Number of Trainable Parameters: 71.89M


In [95]:
# Resolve the dataset root from the config; fail fast for datasets without a configured path.
dataset_name = mtl_config["training_params"]["dataset"]
if dataset_name not in mtl_config["dataset_paths"]:
    raise ValueError(f"Unsupported dataset: {mtl_config['training_params']['dataset']}")
dataset_path = mtl_config["dataset_paths"][dataset_name]

# Initialize datasets (augmentation is enabled for the training split only)
if dataset_name == 'nyuv2':
    train_set = NYUv2(root=dataset_path, train=True, augmentation=True)
    test_set = NYUv2(root=dataset_path, train=False)

elif dataset_name == 'cityscapes':
    train_set = CityScapes(root=dataset_path, train=True, augmentation=True)
    test_set = CityScapes(root=dataset_path, train=False)

else:
    # A path is configured but no dataset class is wired up here. Previously this fell
    # through and left train_set/test_set undefined (or stale from an earlier run).
    raise ValueError(f"Unsupported dataset: {dataset_name}")

# Initialize data loaders
train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=mtl_config["training_params"]["batch_size"],
    shuffle=True,
    num_workers=4
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=mtl_config["training_params"]["batch_size"],
    shuffle=False
)

# A copy of train_loader with different data order, used for Auto-Lambda meta-update
if mtl_config["training_params"]["weight"] == "autol":
    val_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=mtl_config["training_params"]["batch_size"],
        shuffle=True,
        num_workers=4
    )

In [None]:
# Train and evaluate the multi-task network with the configured task weighting and gradient method.
training_params = mtl_config["training_params"]
weight_method = training_params["weight"]
grad_method = training_params["grad_method"]
total_epochs = training_params["total_epochs"]
num_tasks = len(train_tasks)

# Validate up front. Previously an unknown grad_method skipped backward()/step() entirely,
# so the model silently never trained; an unknown weight left the loss as the int 0.
if weight_method not in ("equal", "dwa", "uncert", "autol"):
    raise ValueError(f"Unsupported weighting method: {weight_method}")
if grad_method not in ("none", "graddrop", "pcgrad", "cagrad"):
    raise ValueError(f"Unsupported gradient method: {grad_method}")

# apply gradient methods: they operate on one flattened shared-parameter gradient per task
if grad_method != 'none':
    # Seeded (it was unseeded before) so runs are reproducible like the rest of the
    # notebook; passed to pcgrad below.
    rng = np.random.default_rng(training_params["seed"])
    # Element count of every shared parameter, in model.shared_modules() order.
    grad_dims = [param.data.numel() for mm in model.shared_modules() for param in mm.parameters()]
    # One column per task, filled by grad2vec. Zero-initialised instead of the
    # uninitialised torch.Tensor(...) so no garbage memory can leak into the gradients.
    grads = torch.zeros(sum(grad_dims), num_tasks, device=device)

train_batch = len(train_loader)
test_batch = len(test_loader)
train_metric = TaskMetric(train_tasks, pri_tasks, training_params["batch_size"], total_epochs, training_params["dataset"])
test_metric = TaskMetric(train_tasks, pri_tasks, training_params["batch_size"], total_epochs, training_params["dataset"], include_mtl=True)

log_path = 'logging/mtl_dense_{}_{}_{}_{}_{}_{}_.npy'.format(
    training_params["network"], training_params["dataset"], training_params["task"],
    weight_method, grad_method, training_params["seed"])

# Training loop
for index in range(total_epochs):
    # Dynamic Weight Average: weights come from the ratio of each task's column-0 training
    # metric (presumably the loss) over the two previous epochs, softmaxed with temperature T.
    if weight_method == "dwa":
        if index < 2:
            lambda_weight[index, :] = 1.0
        else:
            w = [train_metric.metric[t][index - 1, 0] / train_metric.metric[t][index - 2, 0] for t in train_tasks]
            w = torch.softmax(torch.tensor(w) / T, dim=0)
            lambda_weight[index] = num_tasks * w.numpy()

    # iteration for all batches
    model.train()
    if weight_method == "autol":
        val_dataset = iter(val_loader)

    for train_data, train_target in train_loader:
        train_data = train_data.to(device)
        train_target = {task_id: train_target[task_id].to(device) for task_id in train_tasks.keys()}

        # update meta-weights with Auto-Lambda
        if weight_method == "autol":
            val_data, val_target = next(val_dataset)
            val_data = val_data.to(device)
            val_target = {task_id: val_target[task_id].to(device) for task_id in train_tasks.keys()}

            meta_optimizer.zero_grad()
            autol.unrolled_backward(train_data, train_target, val_data, val_target,
                                    scheduler.get_last_lr()[0], optimizer)
            meta_optimizer.step()

        # update multi-task network parameters with task weights
        optimizer.zero_grad()
        train_pred = model(train_data)
        train_loss = [compute_loss(train_pred[i], train_target[task_id], task_id) for i, task_id in enumerate(train_tasks)]

        if weight_method in ["equal", "dwa"]:
            train_loss_tmp = [w * train_loss[i] for i, w in enumerate(lambda_weight[index])]
        elif weight_method == "uncert":
            train_loss_tmp = [1 / (2 * torch.exp(w)) * train_loss[i] + w / 2 for i, w in enumerate(logsigma)]
        else:  # "autol"
            train_loss_tmp = [w * train_loss[i] for i, w in enumerate(autol.meta_weights)]

        loss = sum(train_loss_tmp)

        if grad_method == "none":
            loss.backward()
        else:
            # Collect each task's gradient on the shared modules separately, combine them
            # with the chosen method, then write the result back before stepping.
            for i in range(num_tasks):
                train_loss_tmp[i].backward(retain_graph=True)
                grad2vec(model, grads, grad_dims, i)
                model.zero_grad_shared_modules()
            if grad_method == "graddrop":
                g = graddrop(grads)
            elif grad_method == "pcgrad":
                g = pcgrad(grads, rng, num_tasks)
            else:  # "cagrad"; 0.4 is the fixed coefficient passed to cagrad
                g = cagrad(grads, num_tasks, 0.4, rescale=1)
            overwrite_grad(model, g, grad_dims, num_tasks)
        optimizer.step()

        train_metric.update_metric(train_pred, train_target, train_loss)

    train_str = train_metric.compute_metric()
    train_metric.reset()

    # evaluating test data
    model.eval()
    with torch.no_grad():
        for test_data, test_target in test_loader:
            test_data = test_data.to(device)
            test_target = {task_id: test_target[task_id].to(device) for task_id in train_tasks.keys()}

            test_pred = model(test_data)
            test_loss = [compute_loss(test_pred[i], test_target[task_id], task_id) for i, task_id in enumerate(train_tasks)]

            test_metric.update_metric(test_pred, test_target, test_loss)

    test_str = test_metric.compute_metric()
    test_metric.reset()

    scheduler.step()

    print(f"Epoch {index:04d} | TRAIN:{train_str} || TEST:{test_str} | Best: {training_params['task'].title()} {test_metric.get_best_performance(training_params['task']):.4f}")

    # Record this epoch's task weights and print them.
    if weight_method == "autol":
        meta_weight_ls[index] = autol.meta_weights.detach().cpu()
        weight_log = meta_weight_ls
        print(get_weight_str(meta_weight_ls[index], train_tasks))
    elif weight_method in ["dwa", "equal"]:
        weight_log = lambda_weight
        print(get_weight_str(lambda_weight[index], train_tasks))
    else:  # "uncert"
        logsigma_ls[index] = logsigma.detach().cpu()
        weight_log = logsigma_ls
        print(get_weight_str(1 / (2 * np.exp(logsigma_ls[index])), train_tasks))

    # Named `training_log` (was `dict`) so the builtin `dict` is no longer shadowed for later cells.
    training_log = {"train_loss": train_metric.metric, "test_loss": test_metric.metric, "weight": weight_log}
    np.save(log_path, training_log)


## Model Merging

### Task Vectors

In [2]:
import yaml
from jinja2 import Environment, FileSystemLoader

# Render the Jinja-templated config (no template variables are supplied), then parse the
# result as YAML. This rebinds `mtl_config`, replacing the config loaded from config/mtl.yaml.
template_env = Environment(loader=FileSystemLoader('.'))
rendered_yaml = template_env.get_template('config/mtl.yaml.j2').render()
mtl_config = yaml.safe_load(rendered_yaml)


In [7]:
# torch_save(model, 'logging/pt_model.pt')
# torch_load('logging/model_test.pt')

In [3]:
from model_merging.task_vectors import MTLTaskVector

# Every task vector is built from the same base checkpoint path plus one per-task checkpoint.
# NOTE(review): torch_load presumably unpickles; only load checkpoints you produced yourself.
base_checkpoint = 'logging/mtl_model.pt'
pt_model = torch_load(base_checkpoint).state_dict()
task_vectors = [
    MTLTaskVector(base_checkpoint, task_checkpoint)
    for task_checkpoint in ('logging/seg_model.pt', 'logging/depth_model.pt')
]
seg_task_vector, depth_task_vector = task_vectors

In [8]:
import copy
from collections import OrderedDict

def state_dict_to_vector(state_dict, remove_keys=()):
    """Flatten a state dict into a single 1-D tensor.

    Retained entries are concatenated in sorted key order, the same order
    ``vector_to_state_dict`` uses when unflattening.

    Args:
        state_dict: mapping of name -> tensor. It is only read, never modified.
        remove_keys: keys to leave out; keys missing from ``state_dict`` are ignored.

    Returns:
        A newly allocated 1-D tensor with every retained tensor flattened and concatenated.
    """
    # No deepcopy: the tensors are only read, and parameters_to_vector concatenates them
    # into fresh storage. Copying a whole model's state dict just to drop keys was wasteful.
    excluded = set(remove_keys)
    kept_keys = sorted(key for key in state_dict if key not in excluded)
    return torch.nn.utils.parameters_to_vector([state_dict[key].reshape(-1) for key in kept_keys])


def vector_to_state_dict(vector, state_dict, remove_keys=[]):
    """Unflatten ``vector`` into a state dict shaped like ``state_dict``.

    ``state_dict`` (minus ``remove_keys``) defines the layout: tensors are filled
    in sorted key order from consecutive slices of ``vector``. The input dict is
    deep-copied first, so it is left untouched.

    Returns:
        An OrderedDict with sorted keys. If the key "transformer.shared.weight" is
        present, every key in ``remove_keys`` is mapped back to that tensor.
    """
    # vector_to_parameters rebinds each tensor's .data, so work on a private copy.
    working = copy.deepcopy(state_dict)
    for dropped in remove_keys:
        working.pop(dropped, None)
    ordered = OrderedDict((name, working[name]) for name in sorted(working))

    torch.nn.utils.vector_to_parameters(vector, ordered.values())

    # Restore tied embedding entries that were stripped before flattening.
    shared_key = "transformer.shared.weight"
    if shared_key in ordered:
        for dropped in remove_keys:
            ordered[dropped] = ordered[shared_key]
    return ordered

def check_state_dicts_equal(state_dict1, state_dict2):
    """Return True iff both state dicts have the same key set and exactly equal tensors per key."""
    if set(state_dict1) != set(state_dict2):
        return False
    # all() short-circuits on the first differing tensor, like an early return would.
    return all(torch.equal(state_dict1[name], state_dict2[name]) for name in state_dict1)

def topk_values_mask(M, K=0.7, return_mask=False, reshape_mask=False):
    """Zero out all but the largest-magnitude entries in each row of ``M``.

    Args:
        M: 2-D tensor of shape (n, d), or a 1-D tensor of shape (d,) that is
            treated as a single row.
        K: fraction of entries to keep per row. Values >= 1 are read as a
            percentage (e.g. 20 -> 0.2); ``K == 100`` disables masking.
        return_mask: also return the boolean mask.
        reshape_mask: return the mask in the 2-D working shape (n, d).

    Returns:
        ``(masked, kept_fraction)`` or ``(masked, kept_fraction, mask)``.
        ``masked`` is ``M * mask`` in the 2-D working shape. ``kept_fraction``
        has shape (n,). For a 1-D ``M`` (d > 1) the mask is squeezed back to
        1-D unless ``reshape_mask`` is set. When ``K == 100`` the result is
        ``(M, ones_like(M))`` or ``(M, ones_like(M), None)``.

    Raises:
        ValueError: if ``K`` is outside [0, 100].

    Note:
        The threshold is inclusive (``|x| >= k-th smallest``), so slightly more
        than a ``K`` fraction can survive (one extra entry, plus ties).
    """
    if K == 100:
        # No masking requested.
        if return_mask:
            return M, torch.ones_like(M), None
        return M, torch.ones_like(M)

    if not 0 <= K <= 100:
        # Previously surfaced as an opaque kthvalue RuntimeError.
        raise ValueError(f"K must be a fraction in [0, 1) or a percentage in [1, 100], got {K}")
    if K >= 1:
        K /= 100  # percentage -> fraction

    original_shape = M.shape
    if M.dim() == 1:
        M = M.unsqueeze(0)

    _, d = M.shape
    # kthvalue is 1-based ascending: the k-th smallest magnitude is the cut-off
    # that leaves (roughly) the top K fraction of each row.
    k = d - int(d * K)
    kth_values, _ = M.abs().kthvalue(k, dim=1, keepdim=True)
    mask = M.abs() >= kth_values  # (n, d) boolean
    # Per-row keep fraction is taken from the 2-D mask. The squeezed 1-D mask below has
    # no dim 1, which used to crash 1-D inputs with an IndexError.
    kept_fraction = mask.float().mean(dim=1)

    final_mask = mask.squeeze() if original_shape == M.squeeze().shape else mask
    if reshape_mask:
        final_mask = final_mask.reshape(M.shape)

    if return_mask:
        return M * final_mask, kept_fraction, final_mask
    return M * final_mask, kept_fraction

In [27]:
# Stack one flattened task vector per row (rows follow `task_vectors` order).
print("Flattening out Checkpoints")
remove_keys = []  # no keys are excluded from the flattened vectors
flat_task_vectors = torch.vstack(
    [state_dict_to_vector(task_vector.vector, remove_keys) for task_vector in task_vectors]
)
flat_task_vectors.shape

Flattening out Checkpoints


torch.Size([2, 23561152])

In [11]:
# from model_merging.tallmask_utils import construct_consensus_mask, construct_tall_mask, load_tall_mask
# from model_merging.ties_utils import ties_merging

In [12]:
# Select the merging method and build the merged task vector `merged_tv`.
merge_method = mtl_config["model_merging"]["method"]
merge_config = mtl_config[merge_method]

if merge_config["name"] == "ties":
    # TIES Merging
    merge_func = "dis-mean"
    # merged_tv = ties_merging(flat_task_vectors, reset_thresh=merge_config["k"], merge_func=merge_func)
    # Fail loudly: without ties_merging, `merged_tv` would be undefined below, or would
    # silently reuse a stale value from an earlier run of this cell.
    raise NotImplementedError("TIES merging is not available: the ties_merging import is commented out.")
elif merge_config["name"] in ["sum", "zeroshot", "average"]:
    # "sum" corresponds to Task Arithmetic (TA)
    # TA, zeroshot, weight average all construct the task vector with sum, but use different scaling factors.
    # NOTE: rebinding flat_task_vectors means re-running this cell prunes the already-pruned vectors.
    flat_task_vectors, _ = topk_values_mask(flat_task_vectors, K=merge_config["k"], return_mask=False)
    merged_tv = flat_task_vectors.sum(dim=0)
# TODO:
# elif merge_config["name"] == "tall_mask":
#     # construct multi-task vector
#     if merge_config["use_ties"]:
#         print(f"Using TIES for constructing multi-task vector")
#         merged_tv = ties_merging(flat_task_vectors, reset_thresh=20, merge_func=f"dis-sum")
#     else:
#         print(f"Using Task Arithmetic for constructing multi-task vector")
#         flat_task_vectors, _ = topk_values_mask(flat_task_vectors, K=merge_config["k"], return_mask=False)
#         merged_tv = flat_task_vectors.sum(dim=0)
#     # get TALL masks
#     if merge_config["load_masks"]:
#         # load tall masks directly from storage
#         eval_masks = load_tall_mask(remove_keys, ptm_check, config)
#     else:
#         print(f"=== Constructing TALL Mask ===")
#         # construct tall masks
#         eval_masks = construct_tall_mask(
#             flat_task_vectors, flat_ft, flat_ptm, merged_tv, ptm_check, remove_keys, config
#         )
# elif merge_config["name"] == "consensus":  # consensus merging
#     # construct consensus mask (assuming the TALL masks have already been constructed)
#     consensus_mask = construct_consensus_mask(ptm_check, merge_config["prun_thre_k"], config, remove_keys)
#     # construct multi-task vector
#     if merge_config["use_ties"]:
#         merged_tv = ties_merging(flat_task_vectors, reset_thresh=20, merge_func="dis-sum")
#     else:
#         flat_task_vectors, _ = topk_values_mask(
#             flat_task_vectors, K=merge_config["k"], return_mask=False
#         )  # top-k mag filtering
#         merged_tv = flat_task_vectors.sum(dim=0)
#     # apply the consensus mask to filter multi-task vector
#     merged_tv = merged_tv * consensus_mask
# elif merge_config["name"] == "mag_masking":
#     # Magnitude masking baseline
#     print(f"=== Using Magnitude Masking ===")
#     merged_tv = flat_task_vectors.sum(dim=0)
#     _, _, eval_masks = topk_values_mask(flat_task_vectors, K=merge_config["k"], return_mask=True)
#     eval_masks = [vector_to_state_dict(mask, ptm_check, remove_keys=remove_keys) for mask in eval_masks]
#     eval_masks = {key: value for key, value in zip(mtl_config.DATASETS, eval_masks)}
else:
    # Report the value that was actually checked. The old message read
    # mtl_config['model_merging']['name'], which may not exist (the config is keyed by "method").
    raise ValueError(f"Method {merge_config['name']} not defined.")

# Unflatten using the first task vector's state dict as the layout reference.
merged_tv_state_dict = vector_to_state_dict(merged_tv, task_vectors[0].vector, remove_keys=remove_keys)





# task_vector = NonLinearTaskVector(model_name=mtl_config.model, vector=merged_tv_state_dict)
# print("Norm of task vector: ", task_vector.norm())

# if merge_config["name"] not in ["tall_mask", "mag_masking"]:
#     eval_masks = None

# return task_vector, eval_masks
