In [1]:
import os
import random

import numpy as np
import torch
import torch.optim as optim
import wandb
import yaml
from jinja2 import Environment, FileSystemLoader

from training.create_dataset import *
from training.create_network import *
from training.utils import create_task_flags, TaskMetric, eval
from utils import torch_save, get_data_loaders, initialize_wandb
# Kept after the wildcard imports so this explicit name cannot be shadowed by them.
from models.dinov2.mtl.multitasker import MTLDinoV2

# Log in to Weights & Biases before any run is created.
# As the last expression of the cell, the call's return value is displayed.
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mjuan-garciagiraldo[0m ([33mjuagarci[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [6]:
# Training options: render the Jinja2 template found relative to the working
# directory, then parse the rendered text as YAML.
env = Environment(loader=FileSystemLoader('.'))
template = env.get_template('config/mtl.yaml.j2')
rendered_yaml = template.render()
config = yaml.safe_load(rendered_yaml)

# Folder intended for training weights and losses; created only if missing.
os.makedirs("logs", exist_ok=True)

# Lookup table from network name to model class.
model_classes = dict(
    split=MTLDeepLabv3,
    mtan=MTANDeepLabv3,
    # dinov2=MTLDinoVisionTransformer,
)

In [3]:
# Start the W&B run. The hyper-parameters recorded with the run are copied from
# the `training_params` section of the config; `epochs` is the only entry whose
# key differs from its config name (`total_epochs`).
training_params = config["training_params"]
wandb_settings = config["wandb"]

run_config = {name: training_params[name] for name in ("task", "network", "dataset")}
run_config["epochs"] = training_params["total_epochs"]
for name in ("lr", "batch_size", "seed"):
    run_config[name] = training_params[name]

# Left as the final expression so the cell still displays the call's return value.
initialize_wandb(
    project=wandb_settings["project"],
    group=f"{training_params['network']}",
    job_type="task_specific",
    mode=wandb_settings["mode"],
    config=run_config,
)

<module 'wandb' from '/opt/conda/lib/python3.10/site-packages/wandb/__init__.py'>

In [7]:
# Seed every RNG used here (PyTorch, NumPy, Python's `random`) from the single
# configured seed. `np` and `random` are imported explicitly in the import cell
# rather than relying on names leaked by the wildcard imports.
seed = config["training_params"]["seed"]
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Use the default CUDA device when available, otherwise fall back to the CPU.
# NOTE(review): an earlier, disabled variant selected a specific GPU index from
# the config; confirm whether per-GPU selection is still needed.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [8]:
# Build the three data loaders (train / val / test) from the parsed config.
train_loader, val_loader, test_loader = get_data_loaders(config)

In [9]:
# Initialize model (MTLDinoV2 is imported in the notebook's import cell).
train_tasks = create_task_flags(config["training_params"]["task"], config["training_params"]["dataset"])
print(f"Training Task: {config['training_params']['dataset'].title()} - {config['training_params']['task'].title()} in Single Task Learning Mode with {config['training_params']['network'].upper()}")

# NOTE(review): the model built here is always MTLDinoV2 with arch "vit_base";
# the configured `network` value only appears in the printed messages and does
# not select the architecture.
model = MTLDinoV2(
  arch_name="vit_base",
  head_tasks=train_tasks,
)

# Count only parameters that currently require gradients.
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model: {config['training_params']['network'].title()} | Number of Trainable Parameters: {num_params/1e6:.2f}M")



Training Task: Nyuv2 - Seg in Single Task Learning Mode with SPLIT




Model: Split | Number of Trainable Parameters: 86.63M


In [10]:
# Phase 1: call the model's freeze hook with its defaults, then create the
# optimizer and a one-cycle LR schedule that spans every optimizer step of the run.
model.freeze_shared_layers()

base_lr = config["training_params"]["lr"]
optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=1e-4)
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=base_lr,
    steps_per_epoch=len(train_loader),
    epochs=config["training_params"]["total_epochs"],
    pct_start=0.1,
)

In [11]:
# Number of batches in the train and test loaders.
train_batch, test_batch = len(train_loader), len(test_loader)

# The train and test metric trackers are built from identical arguments.
metric_args = (
    train_tasks,
    train_tasks,
    config["training_params"]["batch_size"],
    config["training_params"]["total_epochs"],
    config["training_params"]["dataset"],
)
train_metric = TaskMetric(*metric_args)
test_metric = TaskMetric(*metric_args)

In [12]:
#  Training loop
# Each epoch: one pass over train_loader with an optimizer and scheduler step per
# batch, then W&B logging, evaluation on test_loader, and a checkpoint write.

# NOTE(review): the file name says "depth" regardless of the configured task;
# confirm the intended checkpoint name.
checkpoint_path = "checkpoints/dinov2/linear_probing/depth_head_model.pt"
# The notebook only creates "logs"; make sure the checkpoint folder exists too.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

model.to(device)
for epoch in range(config["training_params"]["total_epochs"]):
    # training
    model.train()
    for train_data, train_target in train_loader:
        train_data = train_data.to(device)
        # Move only the targets of the tasks this model has heads for.
        train_target = {task_id: train_target[task_id].to(device) for task_id in model.head_tasks}

        train_res = model(train_data, None, img_gt=train_target, return_loss=True)

        optimizer.zero_grad()
        train_res["total_loss"].backward()
        optimizer.step()
        # The scheduler is stepped once per batch (it was sized with steps_per_epoch).
        scheduler.step()

        train_metric.update_metric(train_res, train_target)

    train_str = train_metric.compute_metric()

    # NOTE(review): the logged loss is taken from `train_res`, i.e. the last
    # batch of the epoch, not an epoch average.
    wandb.log({
        **{f"train/loss/{task_id}": train_res[task_id]["total_loss"] for task_id in model.head_tasks},
        **{f"train/metric/{task_id}": train_metric.get_metric(task_id) for task_id in model.head_tasks}
    },) # step=epoch
    train_metric.reset()

    # evaluating
    test_str = eval(epoch, model, test_loader, test_metric)

    print(f"Epoch {epoch:04d} | TRAIN:{train_str} || TEST:{test_str} | Best: {config['training_params']['task'].title()} {test_metric.get_best_performance(config['training_params']['task']):.4f}")

    # Written after every epoch to the same path.
    torch_save(model, checkpoint_path)

  return F.conv2d(input, weight, bias, self.stride,


Epoch 0000 | TRAIN: Seg 1.3317 0.3201 || TEST: Seg 0.7364 0.5711 | Best: Seg 0.5711
Epoch 0001 | TRAIN: Seg 0.5917 0.6254 || TEST: Seg 0.6129 0.6041 | Best: Seg 0.6041
Epoch 0002 | TRAIN: Seg 0.5039 0.6507 || TEST: Seg 0.5720 0.6051 | Best: Seg 0.6051
Epoch 0003 | TRAIN: Seg 0.4711 0.6662 || TEST: Seg 0.5534 0.6185 | Best: Seg 0.6185
Epoch 0004 | TRAIN: Seg 0.4668 0.6641 || TEST: Seg 0.5553 0.6185 | Best: Seg 0.6185


In [13]:
# Call get_data_loaders again with the same config and rebind the three loader
# names before the second training phase.
train_loader, val_loader, test_loader = get_data_loaders(config)

In [14]:
# Number of batches in the (re-created) train and test loaders.
train_batch, test_batch = len(train_loader), len(test_loader)

# The second phase runs for five times the configured number of epochs.
finetune_epochs = 5 * config["training_params"]["total_epochs"]

# Fresh train and test metric trackers, built from identical arguments.
metric_args = (
    train_tasks,
    train_tasks,
    config["training_params"]["batch_size"],
    finetune_epochs,
    config["training_params"]["dataset"],
)
train_metric = TaskMetric(*metric_args)
test_metric = TaskMetric(*metric_args)

In [15]:
# Phase 2: call the freeze hook with requires_grad=True (presumably re-enabling
# gradients for the shared layers; confirm against the model class), then build
# a fresh optimizer and one-cycle schedule for the longer run.
# NOTE(review): in the recorded run the test metric falls from 0.6185 at the end
# of phase 1 to below 0.09 in this phase. The same max LR is reused for every
# parameter here; consider whether the shared layers need a smaller LR.
model.freeze_shared_layers(requires_grad=True)

base_lr = config["training_params"]["lr"]
finetune_epochs = 5 * config["training_params"]["total_epochs"]

optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=1e-4)
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=base_lr,
    steps_per_epoch=len(train_loader),
    epochs=finetune_epochs,
    pct_start=0.05,
)

In [16]:
#  Training loop
# Same per-epoch procedure as the first phase, run for five times the configured
# epoch count: one pass over train_loader with an optimizer and scheduler step
# per batch, then W&B logging, evaluation on test_loader, and a checkpoint write.

# NOTE(review): the path says "linear_probing" and "depth" although this phase
# follows freeze_shared_layers(requires_grad=True) and the task comes from the
# config; confirm the intended checkpoint name.
checkpoint_path = "checkpoints/dinov2/linear_probing/depth_model.pt"
# The notebook only creates "logs"; make sure the checkpoint folder exists too.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

model.to(device)
for epoch in range(5*config["training_params"]["total_epochs"]):
    # training
    model.train()
    for train_data, train_target in train_loader:
        train_data = train_data.to(device)
        # Move only the targets of the tasks this model has heads for.
        train_target = {task_id: train_target[task_id].to(device) for task_id in model.head_tasks}

        train_res = model(train_data, None, img_gt=train_target, return_loss=True)

        optimizer.zero_grad()
        train_res["total_loss"].backward()
        optimizer.step()
        # The scheduler is stepped once per batch (it was sized with steps_per_epoch).
        scheduler.step()

        train_metric.update_metric(train_res, train_target)

    train_str = train_metric.compute_metric()

    # NOTE(review): the logged loss is taken from `train_res`, i.e. the last
    # batch of the epoch, not an epoch average.
    wandb.log({
        **{f"train/loss/{task_id}": train_res[task_id]["total_loss"] for task_id in model.head_tasks},
        **{f"train/metric/{task_id}": train_metric.get_metric(task_id) for task_id in model.head_tasks}
    },) # step=epoch
    train_metric.reset()

    # evaluating
    test_str = eval(epoch, model, test_loader, test_metric)

    print(f"Epoch {epoch:04d} | TRAIN:{train_str} || TEST:{test_str} | Best: {config['training_params']['task'].title()} {test_metric.get_best_performance(config['training_params']['task']):.4f}")

    # Written after every epoch to the same path.
    torch_save(model, checkpoint_path)

Epoch 0000 | TRAIN: Seg 1.4587 0.2263 || TEST: Seg 4.1883 0.0324 | Best: Seg 0.0324
Epoch 0001 | TRAIN: Seg 1.9054 0.0679 || TEST: Seg 1.7768 0.0664 | Best: Seg 0.0664
Epoch 0002 | TRAIN: Seg 1.7668 0.0832 || TEST: Seg 1.8019 0.0513 | Best: Seg 0.0664
Epoch 0003 | TRAIN: Seg 1.7591 0.0852 || TEST: Seg 1.8789 0.0808 | Best: Seg 0.0808
Epoch 0004 | TRAIN: Seg 1.7298 0.0885 || TEST: Seg 1.9900 0.0658 | Best: Seg 0.0808
Epoch 0005 | TRAIN: Seg 1.6877 0.0917 || TEST: Seg 1.8294 0.0719 | Best: Seg 0.0808
Epoch 0006 | TRAIN: Seg 1.6428 0.0978 || TEST: Seg 1.7732 0.0854 | Best: Seg 0.0854
Epoch 0007 | TRAIN: Seg 1.6471 0.0977 || TEST: Seg 2.0463 0.0515 | Best: Seg 0.0854
Epoch 0008 | TRAIN: Seg 1.5888 0.1008 || TEST: Seg 1.7278 0.0782 | Best: Seg 0.0854
Epoch 0009 | TRAIN: Seg 1.5584 0.1055 || TEST: Seg 1.9291 0.0693 | Best: Seg 0.0854
Epoch 0010 | TRAIN: Seg 1.5267 0.1158 || TEST: Seg 1.7627 0.0712 | Best: Seg 0.0854
Epoch 0011 | TRAIN: Seg 1.5166 0.1207 || TEST: Seg 1.9027 0.0598 | Best: Seg

KeyboardInterrupt: 

In [10]:
# Close the active W&B run; quiet=True is passed to reduce the end-of-run output.
wandb.finish(quiet=True)

In [11]:
# img_metas (list[dict]): List of image info dict where each dict
#                 has: 'img_shape', 'scale_factor', 'flip', and may also contain
#                 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.