In [None]:
import os
import random

import numpy as np
import torch
import torch.optim as optim
import wandb
import yaml
from jinja2 import Environment, FileSystemLoader

from training.create_dataset import *
from training.create_network import *
from models.dinov2.mtl.multitasker import MTLDinoV2
from training.utils import create_task_flags, TaskMetric, eval
from utils import torch_save, get_data_loaders, initialize_wandb

# Authenticate this session with Weights & Biases before any run is started.
wandb.login()

In [2]:
# Training options: the YAML config is stored as a Jinja2 template, so it is
# rendered first (with no template variables) and then parsed.
CONFIG_TEMPLATE = 'config/mtl.yaml.j2'
env = Environment(loader=FileSystemLoader('.'))
rendered_yaml = env.get_template(CONFIG_TEMPLATE).render()
config = yaml.safe_load(rendered_yaml)

In [None]:
# initialize_wandb(
#   project=config["wandb"]["project"], 
#   group=f"{config['training_params']['network']}", 
#   job_type="task_specific", 
#   mode=config["wandb"]["mode"], 
#   config={
#     "task": config['training_params']['task'],
#     "network": config['training_params']['network'],
#     "dataset": config['training_params']['dataset'],
#     "epochs": config['training_params']['total_epochs'],
#     "lr_head": config['training_params']['lr_head'],
#     "batch_size": config['training_params']['batch_size'],
#     "seed": config['training_params']['seed'],
#   }
# )

In [3]:
# Seed every RNG used in this notebook from the single configured seed so a
# fresh "Restart & Run All" is repeatable. `random` and `np` are imported
# explicitly in the imports cell instead of relying on whatever names the
# `training.*` star imports happen to re-export.
seed = config["training_params"]["seed"]
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Use the default CUDA device when one is available, otherwise fall back to CPU.
# To pin a specific GPU instead, build the device string from the configured
# index, e.g. torch.device(f"cuda:{config['training_params']['gpu']}")
# (assumes the config defines a `gpu` entry -- TODO confirm).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Build the three data loaders (train / val / test) from the loaded config.
train_loader, val_loader, test_loader = get_data_loaders(config)

In [None]:
# Model setup. The task is pinned to "depth" here, overwriting whatever value
# the rendered config carried; later cells read it back from `config`.
training_params = config["training_params"]
training_params["task"] = "depth"
train_tasks = create_task_flags(training_params["task"], training_params["dataset"])
print(f"Training Task: {config['training_params']['dataset'].title()} - {config['training_params']['task'].title()} in Single Task Learning Mode with {config['training_params']['network'].upper()}")

# Constructor arguments are gathered in one place so they are easy to tweak.
model_kwargs = dict(
    arch_name="vit_base",
    head_tasks=train_tasks,
    head_archs="dpt",
    out_index=[5, 7, 9, 11],
    cls_token=True,
)
model = MTLDinoV2(**model_kwargs)

# Count only parameters that currently require gradients.
trainable_sizes = [p.numel() for p in model.parameters() if p.requires_grad]
num_params = sum(trainable_sizes)
print(f"Model: {config['training_params']['network'].title()} | Number of Trainable Parameters: {num_params/1e6:.2f}M")

In [6]:
# First run: call freeze_shared_layers() with its defaults (presumably this
# disables gradients on the shared layers -- TODO confirm against the model),
# then build a fresh optimizer and one-cycle schedule at the head learning rate.
model.freeze_shared_layers()

head_lr = config["training_params"]["lr_head"]
optimizer = optim.AdamW(model.parameters(), lr=head_lr, weight_decay=1e-4)
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=head_lr,
    steps_per_epoch=len(train_loader),
    epochs=config["training_params"]["total_epochs"],
    pct_start=0.0,  # zero fraction of the cycle is spent ramping the LR up
)

In [7]:
# Batch counts per split, plus one independent metric tracker for each of
# train and test, both built from the same arguments.
train_batch, test_batch = len(train_loader), len(test_loader)

metric_args = (
    train_tasks,
    train_tasks,
    config["training_params"]["batch_size"],
    config["training_params"]["total_epochs"],
    config["training_params"]["dataset"],
)
train_metric = TaskMetric(*metric_args)
test_metric = TaskMetric(*metric_args)

In [None]:
# First training run: `total_epochs` epochs with the optimizer/scheduler
# created above. Evaluation, wandb logging and checkpointing are currently
# disabled for this run (see the commented-out lines below).
model.to(device)
for epoch in range(config["training_params"]["total_epochs"]):
    # training
    model.train()
    # Iterate the loader directly rather than pairing iter()/next() with a
    # batch count computed in a different cell: the loop can no longer go out
    # of sync with the loader if that cell is stale or was skipped.
    for train_data, train_target in train_loader:
        train_data = train_data.to(device)
        # Move only the ground-truth entries for the tasks the model has heads for.
        train_target = {task_id: train_target[task_id].to(device) for task_id in model.head_tasks}

        train_res = model(train_data, None, img_gt=train_target, return_loss=True)

        optimizer.zero_grad()
        train_res["total_loss"].backward()
        # Clip the global L2 gradient norm before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=35, norm_type=2)
        optimizer.step()
        # The one-cycle schedule is stepped once per batch, not once per epoch.
        scheduler.step()

        train_metric.update_metric(train_res, train_target)

    train_str = train_metric.compute_metric()
    # print(f"Epoch {epoch:04d} | TRAIN: {train_res['total_loss'].item()}")

    # wandb.log({
    #     **{f"train/loss/{task_id}": train_res[task_id]["total_loss"] for task_id in model.head_tasks},
    #     **{f"train/metric/{task_id}": train_metric.get_metric(task_id) for task_id in model.head_tasks}
    # },) # step=epoch
    train_metric.reset()
    print(f"Epoch {epoch:04d} | TRAIN:{train_str}")

    # evaluating
    # test_str = eval(epoch, model, test_loader, test_metric)

    # print(f"Epoch {epoch:04d} | TRAIN:{train_str} || TEST:{test_str} | Best: {config['training_params']['task'].title()} {test_metric.get_best_performance(config['training_params']['task']):.4f}")

    # task_dict = {"train_loss": train_metric.metric, "test_loss": test_metric.metric}
    # np.save("logging/stl_{}_{}_{}_{}.npy".format(config["training_params"]["network"], config["training_params"]["dataset"], config["training_params"]["task"], config["training_params"]["seed"]), task_dict)
    # torch_save(model, f"{config['training_params']['out_dir']}/{config['training_params']['task']}_head_model.pt")

In [None]:
# Debug aid: show the gradient norm of every parameter that currently holds a
# gradient. Parameters whose .grad is None are skipped silently; add a print
# in the skip branch to list them as well.
for name, param in model.named_parameters():
    if param.grad is None:
        continue
    print(f"Parameter: {name} | Gradient: {param.grad.norm()}")

In [10]:
# The config loaded at the top of the notebook is reused as-is. To reload it
# from the template instead, re-enable the lines below:
# env = Environment(loader=FileSystemLoader('.'))
# template = env.get_template('config/mtl.yaml.j2')
# rendered_yaml = template.render()
# config = yaml.safe_load(rendered_yaml)

# Set (or override) the learning rate read by the optimizer and scheduler
# that are created further down.
config["training_params"]["lr"] = 0.0001

In [11]:
# Rebuild the train / val / test loaders from the current config
# (same call as earlier in the notebook).
train_loader, val_loader, test_loader = get_data_loaders(config)

In [12]:
# Refresh the batch counts for the rebuilt loaders and start new metric
# trackers sized for three times the configured epoch count.
train_batch, test_batch = len(train_loader), len(test_loader)

metric_args = (
    train_tasks,
    train_tasks,
    config["training_params"]["batch_size"],
    3 * config["training_params"]["total_epochs"],
    config["training_params"]["dataset"],
)
train_metric = TaskMetric(*metric_args)
test_metric = TaskMetric(*metric_args)

In [13]:
# Second run: call freeze_shared_layers(requires_grad=True) (presumably this
# re-enables gradients on the shared layers -- TODO confirm), then build a new
# optimizer and a one-cycle schedule spanning 3x the configured epoch count.
model.freeze_shared_layers(requires_grad=True)

finetune_lr = config["training_params"]["lr"]
optimizer = optim.AdamW(model.parameters(), lr=finetune_lr, weight_decay=1e-4)
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=finetune_lr,
    steps_per_epoch=len(train_loader),
    epochs=3 * config["training_params"]["total_epochs"],
    pct_start=0.05,  # first 5% of the steps ramp the LR up to max_lr
)

In [None]:
# Second training run: epochs are numbered from `total_epochs` up to
# 4 * total_epochs - 1, i.e. 3 * total_epochs epochs, matching the length of
# the one-cycle schedule built for this run. Each epoch trains, evaluates on
# the test loader, and overwrites the saved model file.
model.to(device)
for epoch in range(config["training_params"]["total_epochs"], 4*config["training_params"]["total_epochs"]):
    # training
    model.train()
    # Iterate the loader directly rather than pairing iter()/next() with a
    # batch count computed in a different cell: the loop can no longer go out
    # of sync with the loader if that cell is stale or was skipped.
    for train_data, train_target in train_loader:
        train_data = train_data.to(device)
        # Move only the ground-truth entries for the tasks the model has heads for.
        train_target = {task_id: train_target[task_id].to(device) for task_id in model.head_tasks}

        train_res = model(train_data, None, img_gt=train_target, return_loss=True)

        optimizer.zero_grad()
        train_res["total_loss"].backward()
        # NOTE(review): unlike the first training loop, no gradient clipping is
        # applied here -- confirm this difference is intentional.
        optimizer.step()
        # The one-cycle schedule is stepped once per batch, not once per epoch.
        scheduler.step()

        train_metric.update_metric(train_res, train_target)

    train_str = train_metric.compute_metric()

    # wandb.log() raises if no run has been initialised, and the cell that
    # calls initialize_wandb() is currently commented out. Log only when a run
    # is active so the notebook still runs end to end without one.
    if wandb.run is not None:
        # Note: the logged losses come from the last batch of the epoch only.
        wandb.log({
            **{f"train/loss/{task_id}": train_res[task_id]["total_loss"] for task_id in model.head_tasks},
            **{f"train/metric/{task_id}": train_metric.get_metric(task_id) for task_id in model.head_tasks}
        },) # step=epoch
    train_metric.reset()

    # evaluating
    # `eval` here is the helper imported from training.utils, not the builtin.
    test_str = eval(epoch, model, test_loader, test_metric)

    print(f"Epoch {epoch:04d} | TRAIN:{train_str} || TEST:{test_str} | Best: {config['training_params']['task'].title()} {test_metric.get_best_performance(config['training_params']['task']):.4f}")

    # task_dict = {"train_loss": train_metric.metric, "test_loss": test_metric.metric}
    # np.save("logging/stl_{}_{}_{}_{}.npy".format(config["training_params"]["network"], config["training_params"]["dataset"], config["training_params"]["task"], config["training_params"]["seed"]), task_dict)
    # Saved to the same path every epoch, so the file always holds the most
    # recent model rather than the best one.
    torch_save(model, f"{config['training_params']['out_dir']}/{config['training_params']['task']}_model.pt")

In [None]:
# Close the wandb run (if any) once training is done; quiet=True requests
# minimal console output from wandb while it shuts down.
wandb.finish(quiet=True)

In [11]:
# img_metas (list[dict]): List of image info dict where each dict
#                 has: 'img_shape', 'scale_factor', 'flip', and may also contain
#                 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.