In [None]:
import os
import random

import numpy as np
import torch
import torch.optim as optim
import wandb
import yaml
from jinja2 import Environment, FileSystemLoader

# NOTE: the star imports below run last, so any name they export still takes
# precedence over the explicit imports above, exactly as before.
from training.create_dataset import *
from training.create_network import *
# NOTE: `eval` imported here shadows Python's builtin eval() for the rest of the notebook.
from training.utils import create_task_flags, TaskMetric, compute_loss, eval
from utils import torch_save, get_data_loaders, initialize_wandb

# Authenticate this session with Weights & Biases before any run is created.
# This runs before the config is loaded, so it is called regardless of
# the W&B mode selected in the config.
wandb.login()

In [2]:
# Training options: render the Jinja2 template relative to the working
# directory, then parse the rendered text as YAML.
env = Environment(loader=FileSystemLoader("."))
rendered_yaml = env.get_template("config/mtl.yaml.j2").render()
config = yaml.safe_load(rendered_yaml)

# Make sure the folder for training weights and losses exists.
os.makedirs("logs", exist_ok=True)

# Lookup table from the configured network name to the class that builds it.
model_classes = dict(split=MTLDeepLabv3, mtan=MTANDeepLabv3)

In [None]:
# Start the W&B run. The hyperparameters stored with the run are copied
# straight out of the training section of the config.
training_params = config["training_params"]
wandb_options = config["wandb"]

# Name recorded in W&B -> key under config["training_params"].
tracked_keys = {
  "task": "task",
  "network": "network",
  "dataset": "dataset",
  "epochs": "total_epochs",
  "lr": "lr",
  "batch_size": "batch_size",
  "seed": "seed",
}
run_config = {name: training_params[key] for name, key in tracked_keys.items()}

initialize_wandb(
  project=wandb_options["project"],
  group=f"{training_params['network']}",
  job_type="task_specific",
  mode=wandb_options["mode"],
  config=run_config,
)

In [3]:
# Seed PyTorch, NumPy and Python's `random` (in that order) with the configured seed.
seed = config["training_params"]["seed"]
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(seed)

# Use the default CUDA device when one is available, otherwise the CPU.
# To pin a specific GPU instead (assumes a 'gpu' entry exists in the config — TODO confirm):
# device = torch.device(f"cuda:{config['training_params']['gpu']}" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [4]:
# Build the train / validation / test loaders from the parsed config.
# Note: val_loader is not referenced by any other cell shown in this notebook;
# evaluation below runs on test_loader.
train_loader, val_loader, test_loader = get_data_loaders(config)

In [None]:
training_params = config["training_params"]
network_name = training_params["network"]

# Task flags for the configured task/dataset pair; they drive the model
# construction here and the loss/metric bookkeeping in the training cell.
train_tasks = create_task_flags(training_params["task"], training_params["dataset"])
print(f"Training Task: {training_params['dataset'].title()} - {training_params['task'].title()} in Single Task Learning Mode with {network_name.upper()}")

# Initialize model: pick the class by network name and move it to the selected device.
model_cls = model_classes[network_name]
model = model_cls(train_tasks).to(device)

# Counted before any layers are frozen, so every parameter with requires_grad=True is included.
num_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f"Model: {network_name.title()} | Number of Trainable Parameters: {num_params/1e6:.2f}M")

In [6]:
# Freeze the shared layers first, then build the optimizer and scheduler.
# (What exactly counts as "shared" is defined by the model class — see freeze_shared_layers.)
model.freeze_shared_layers()

training_params = config["training_params"]

# SGD with momentum and weight decay; cosine annealing spans the full training run.
optimizer = optim.SGD(
    model.parameters(),
    lr=training_params["lr"],
    momentum=0.9,
    weight_decay=1e-4,
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_params["total_epochs"])

In [None]:
# Train and evaluate the network in single-task mode.
# Per epoch: one pass over train_loader, one evaluation on test_loader, a W&B
# log entry, a scheduler step, and a dump of the metric history to a .npy file.
training_params = config["training_params"]
total_epochs = training_params["total_epochs"]

# Batch counts are kept as notebook-level variables; the loop below iterates
# over the loader directly and does not need them.
train_batch = len(train_loader)
test_batch = len(test_loader)
train_metric = TaskMetric(train_tasks, train_tasks, training_params["batch_size"], total_epochs, training_params["dataset"])
test_metric = TaskMetric(train_tasks, train_tasks, training_params["batch_size"], total_epochs, training_params["dataset"])

# The history file lives under "logging/". That directory is not created
# anywhere else in the notebook and np.save does not create missing parent
# directories, so make sure it exists before the first epoch finishes.
log_path = "logging/stl_{}_{}_{}_{}.npy".format(training_params["network"], training_params["dataset"], training_params["task"], training_params["seed"])
os.makedirs(os.path.dirname(log_path), exist_ok=True)

#  Training loop
for epoch in range(total_epochs):

    # training
    model.train()
    # Losses of the most recent batch; stays empty if the loader yields nothing,
    # in which case no loss entries are logged for this epoch.
    train_loss = []
    for train_data, train_target in train_loader:
        train_data = train_data.to(device)
        train_target = {task_id: train_target[task_id].to(device) for task_id in train_tasks.keys()}

        optimizer.zero_grad()
        train_pred = model(train_data)

        train_loss = [compute_loss(train_pred[i], train_target[task_id], task_id) for i, task_id in enumerate(train_tasks)]
        # Only the first task's loss is back-propagated (single task learning).
        train_loss[0].backward()
        optimizer.step()

        train_metric.update_metric(train_pred, train_target, train_loss)

    train_str = train_metric.compute_metric()
    wandb.log({
        # Logged value is the loss of the last batch of the epoch. .item()
        # converts it to a plain Python float so W&B receives a scalar and no
        # reference to the autograd graph is handed over.
        **{f"train/loss/{task_id}": loss.item() for task_id, loss in zip(train_tasks, train_loss)},
        **{f"train/metric/{task_id}": train_metric.get_metric(task_id) for task_id in train_tasks}
    }, step=epoch)
    train_metric.reset()

    # evaluating (`eval` is the helper imported from training.utils, not the builtin)
    test_str = eval(epoch, model, test_loader, test_metric)

    scheduler.step()

    print(f"Epoch {epoch:04d} | TRAIN:{train_str} || TEST:{test_str} | Best: {training_params['task'].title()} {test_metric.get_best_performance(training_params['task']):.4f}")

    # Overwrite the history file every epoch so an interrupted run keeps its progress.
    task_dict = {"train_loss": train_metric.metric, "test_loss": test_metric.metric}
    np.save(log_path, task_dict)

In [7]:
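# Stage 2 (currently disabled): full fine-tuning after the frozen-layer stage above.
# The commented-out lines below would presumably unfreeze the shared layers and
# rebuild the optimizer and scheduler; a second training run would still be needed.
# model.freeze_shared_layers(requires_grad=True)
# optimizer = optim.SGD(model.parameters(), lr=config["training_params"]["lr"], weight_decay=1e-4, momentum=0.9)
# scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, config["training_params"]["total_epochs"])

# TODO: implement a proper two-stage schedule: linear probing first, then full fine-tuning.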

In [5]:
# Persist the trained model. The target directory is created first because
# nothing else in this notebook creates "models/"; whether torch_save does so
# itself is not visible from here, and makedirs is a no-op if it already exists.
model_path = "models/seg_model.pt"
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch_save(model, model_path)

In [None]:
# Close the active W&B run (quiet=True requests minimal console output).
wandb.finish(quiet=True)