In [1]:
import os
import random

import numpy as np
import torch
import torch.optim as optim
import wandb
import yaml
from jinja2 import Environment, FileSystemLoader

from training.create_dataset import *
from training.create_network import *
from training.utils import create_task_flags, TaskMetric, compute_loss, get_weight_str, eval
from utils import torch_save, get_data_loaders, initialize_wandb

# Login to wandb.
# NOTE(review): wandb.login() may prompt interactively for an API key when none is configured
# (e.g. via the WANDB_API_KEY environment variable), which blocks the cell; the KeyboardInterrupt
# recorded in this cell's output may come from that -- confirm before running headless.
wandb.login()

KeyboardInterrupt: 

In [2]:
# Training options live in a Jinja2 template; render it (no template variables are supplied),
# resolving the path relative to the current working directory, then parse the result as YAML.
env = Environment(loader=FileSystemLoader(searchpath='.'))
template = env.get_template(name='config/mtl.yaml.j2')
rendered_yaml = template.render()
mtl_config = yaml.safe_load(stream=rendered_yaml)

# Folder that receives the per-run training logs (weights / losses).
os.makedirs(name="logs", exist_ok=True)

# Network constructors, selected by the `network` option of the config.
model_classes = {
    "split": MTLDeepLabv3,
    "mtan": MTANDeepLabv3,
}

In [8]:
# Start the wandb run for this MTL experiment. The run is grouped by network and its config mirrors
# the key training options (the config's `total_epochs` is reported under the name `epochs`).
training_params = mtl_config["training_params"]
initialize_wandb(
    project=mtl_config["wandb"]["project"],
    group=f"{training_params['network']}",
    job_type="mtl",
    mode=mtl_config["wandb"]["mode"],
    config={
        run_key: training_params[config_key]
        for run_key, config_key in (
            ("task", "task"),
            ("network", "network"),
            ("dataset", "dataset"),
            ("weight", "weight"),
            ("epochs", "total_epochs"),
            ("lr", "lr"),
            ("batch_size", "batch_size"),
            ("seed", "seed"),
        )
    },
)

In [3]:
# Seed every RNG the notebook draws from (torch.manual_seed covers all torch devices) so runs with
# the same config are reproducible.
seed = mtl_config["training_params"]["seed"]
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Honour an optional `gpu` index from the config (e.g. gpu: 1 -> "cuda:1"). Without it, use the
# current CUDA device exactly as before, and fall back to CPU when CUDA is unavailable.
gpu_index = mtl_config["training_params"].get("gpu")
if torch.cuda.is_available():
    device = torch.device(f"cuda:{gpu_index}" if gpu_index is not None else "cuda")
else:
    device = torch.device("cpu")

In [None]:
# Resolve the task sets, build the network, and set up task weighting plus the optimiser.
training_params = mtl_config["training_params"]

# Fail fast on options the rest of the notebook cannot handle. Otherwise an unknown weighting leaves
# `lambda_weight` undefined and the training loss ends up as the int 0, and an unknown network
# surfaces as a bare KeyError from the `model_classes` lookup.
SUPPORTED_WEIGHTS = ("equal", "dwa", "uncert")
if training_params["weight"] not in SUPPORTED_WEIGHTS:
    raise ValueError(f"Unsupported weight '{training_params['weight']}'; expected one of {SUPPORTED_WEIGHTS}.")
if training_params["network"] not in model_classes:
    raise ValueError(f"Unsupported network '{training_params['network']}'; expected one of {sorted(model_classes)}.")

# train_tasks: the task set requested with 'all' (trained jointly);
# pri_tasks: the task set for the configured primary `task`.
train_tasks = create_task_flags('all', training_params["dataset"], with_noise=training_params["with_noise"])
pri_tasks = create_task_flags(training_params["task"], training_params["dataset"], with_noise=False)

train_tasks_str = ' + '.join(task.title() for task in train_tasks.keys())
pri_tasks_str = ' + '.join(task.title() for task in pri_tasks.keys())
print(f"Dataset: {training_params['dataset'].title()} | Training Task: {train_tasks_str} | Primary Task: {pri_tasks_str} in Multi-task / Auxiliary Learning Mode with {training_params['network'].upper()}")
print(f"Applying Multi-task Methods | Weighting-based: {training_params['weight'].title()} + Gradient-based: {training_params['grad_method'].upper()}")

# Initialize model
model = model_classes[training_params["network"]](train_tasks).to(device)
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model: {training_params['network'].title()} | Number of Trainable Parameters: {num_params/1e6:.2f}M")

# Choose task weighting
params = model.parameters()
if training_params["weight"] == "uncert":
    # One learnable uncertainty parameter per task (initialised to -0.7), optimised jointly with the
    # model weights. logsigma_ls keeps a per-epoch record of its value (rows = epochs).
    logsigma = torch.tensor([-0.7] * len(train_tasks), requires_grad=True, device=device)
    params = list(params) + [logsigma]
    logsigma_ls = np.zeros((training_params["total_epochs"], len(train_tasks)), dtype=np.float32)
else:  # "dwa" or "equal" (validated above)
    T = 2.0  # Temperature used in DWA
    # Per-epoch task weights (rows = epochs, columns = tasks), initialised to 1.
    lambda_weight = np.ones((training_params["total_epochs"], len(train_tasks)))

# Initialize optimizer and scheduler. The scheduler is stepped once per epoch, so the cosine schedule
# spans the whole run.
optimizer = optim.SGD(params, lr=training_params["lr"], weight_decay=1e-4, momentum=0.9)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, training_params["total_epochs"])

In [5]:
# Build the train / validation / test loaders from the config.
# NOTE(review): val_loader is not referenced anywhere else in this notebook; evaluation runs on test_loader.
train_loader, val_loader, test_loader = get_data_loaders(mtl_config)

In [None]:
# Gradient-based multi-task methods. Pre-allocate one flattened-gradient column per task, covering
# the parameters of model.shared_modules().
if mtl_config["training_params"]["grad_method"] != 'none':
    import warnings

    # NOTE(review): the training loop in this notebook never reads `grads`, `grad_dims` or `rng`, so
    # the configured gradient method is not actually applied. Say so loudly instead of only printing
    # "Gradient-based: ..." as if it were active.
    warnings.warn(
        f"grad_method={mtl_config['training_params']['grad_method']!r} is configured, but the training "
        "loop does not apply any gradient-based method; only task weighting is used.",
        RuntimeWarning,
    )
    # Seeded from the config so any stochastic gradient method is reproducible.
    rng = np.random.default_rng(mtl_config["training_params"]["seed"])
    grad_dims = [param.data.numel() for mm in model.shared_modules() for param in mm.parameters()]
    # Zero-initialised; the legacy torch.Tensor(rows, cols) constructor returns uninitialised memory.
    grads = torch.zeros(sum(grad_dims), len(train_tasks), device=device)


# Train and evaluate the multi-task network.
training_params = mtl_config["training_params"]
weight_method = training_params["weight"]
total_epochs = training_params["total_epochs"]
num_tasks = len(train_tasks)

if weight_method not in ("equal", "dwa", "uncert"):
    # Without this check the loss below would be the int 0 and `loss.backward()` would fail obscurely.
    raise ValueError(f"Unsupported task weighting '{weight_method}'; expected 'equal', 'dwa' or 'uncert'.")

train_metric = TaskMetric(train_tasks, pri_tasks, training_params["batch_size"], total_epochs, training_params["dataset"])
test_metric = TaskMetric(train_tasks, pri_tasks, training_params["batch_size"], total_epochs, training_params["dataset"], include_mtl=True)

# Mean training loss per (epoch, task). This feeds DWA and the wandb loss curves.
epoch_train_loss = np.zeros((total_epochs, num_tasks), dtype=np.float64)

# Per-epoch losses and weights go to the "logs" folder (the one the config cell creates; the old
# "logging/" path was never created). Create it here too so this cell does not rely on that side effect.
log_path = os.path.join(
    "logs",
    "mtl_dense_{}_{}_{}_{}_{}_{}_.npy".format(
        training_params["network"], training_params["dataset"], training_params["task"],
        weight_method, training_params["grad_method"], training_params["seed"],
    ),
)
os.makedirs(os.path.dirname(log_path), exist_ok=True)

# Training loop
for index in range(total_epochs):
    # Dynamic Weight Average: from epoch 2 onwards, weight each task by num_tasks * softmax(r / T),
    # where r = loss(t-1) / loss(t-2). Epochs 0 and 1 keep the initial weights of 1.
    if weight_method == "dwa" and index >= 2:
        loss_ratio = epoch_train_loss[index - 1] / epoch_train_loss[index - 2]
        exp_ratio = np.exp((loss_ratio - loss_ratio.max()) / T)  # max-shift for numerical stability
        lambda_weight[index] = num_tasks * exp_ratio / exp_ratio.sum()

    # iteration for all batches
    model.train()
    loss_sum = np.zeros(num_tasks, dtype=np.float64)
    num_batches = 0
    for train_data, train_target in train_loader:
        train_data = train_data.to(device)
        train_target = {task_id: train_target[task_id].to(device) for task_id in train_tasks.keys()}

        # update multi-task network parameters with task weights
        optimizer.zero_grad()
        train_pred = model(train_data)
        train_loss = [compute_loss(train_pred[i], train_target[task_id], task_id) for i, task_id in enumerate(train_tasks)]

        if weight_method == "uncert":
            # Each task loss is scaled by 1 / (2 * exp(logsigma)), plus a logsigma / 2 regulariser.
            train_loss_tmp = [1 / (2 * torch.exp(w)) * train_loss[i] + w / 2 for i, w in enumerate(logsigma)]
        else:  # "equal" / "dwa": this epoch's row of lambda_weight
            train_loss_tmp = [float(w) * train_loss[i] for i, w in enumerate(lambda_weight[index])]

        loss = sum(train_loss_tmp)
        loss.backward()
        optimizer.step()

        train_metric.update_metric(train_pred, train_target, train_loss)
        loss_sum += [task_loss.item() for task_loss in train_loss]
        num_batches += 1

    if num_batches == 0:
        raise RuntimeError("train_loader produced no batches; check the dataset configuration.")
    epoch_train_loss[index] = loss_sum / num_batches

    train_str = train_metric.compute_metric()
    wandb.log({
        # Epoch-mean loss (previously the last mini-batch's graph-attached tensor was logged).
        **{f"train/loss/{task_id}": float(epoch_train_loss[index, i]) for i, task_id in enumerate(train_tasks)},
        **{f"train/metric/{task_id}": train_metric.get_metric(task_id) for task_id in train_tasks},
    }, step=index)
    train_metric.reset()

    # evaluating (`eval` is the helper imported from training.utils, not the builtin)
    test_str = eval(index, model, test_loader, test_metric)

    scheduler.step()

    print(f"Epoch {index:04d} | TRAIN:{train_str} || TEST:{test_str} | Best: {training_params['task'].title()} {test_metric.get_best_performance(training_params['task']):.4f}")

    if weight_method == "uncert":
        logsigma_ls[index] = logsigma.detach().cpu().numpy()
        task_weights = logsigma_ls
        print(get_weight_str(1 / (2 * np.exp(logsigma_ls[index])), train_tasks))
    else:
        task_weights = lambda_weight
        print(get_weight_str(lambda_weight[index], train_tasks))

    # Named log_dict so the `dict` builtin is not shadowed. np.save pickles the dict into a 0-d object
    # array; read it back with np.load(log_path, allow_pickle=True).item().
    log_dict = {"train_loss": train_metric.metric, "test_loss": test_metric.metric, "weight": task_weights}
    np.save(log_path, log_dict)


In [7]:
# Save the final-epoch model. This notebook only creates `logs/`, so make sure `models/` exists too.
# NOTE(review): torch_save's own handling of missing folders is not visible here; this is harmless
# if it already creates them.
os.makedirs("models", exist_ok=True)
torch_save(model, "models/mtl_model.pt")

In [11]:
# Close the wandb run so everything logged above is finalised; quiet=True asks wandb to keep its
# console output minimal.
wandb.finish(quiet=True)