In [1]:
import os
import random

import numpy as np
import torch
import torch.optim as optim
import wandb
import yaml
from jinja2 import Environment, FileSystemLoader

from training.create_dataset import *
from training.create_network import *
from training.utils import create_task_flags, TaskMetric, eval
from utils import torch_save, get_data_loaders, initialize_wandb

# Authenticate this session with Weights & Biases before any run is created.
# NOTE(review): the captured output of this cell shows an interactive API-key
# prompt, so a non-interactive run presumably needs the key supplied another
# way — confirm before automating this notebook.
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/runai-home/.netrc


True

In [41]:
# Load the training options: render the Jinja2 template (no template variables
# are passed) and parse the resulting YAML text into the `config` dict.
env = Environment(loader=FileSystemLoader('.'))
rendered_yaml = env.get_template('config/mtl.yaml.j2').render()
config = yaml.safe_load(rendered_yaml)

# Ensure the "logs" directory exists; exist_ok avoids an error on re-runs.
os.makedirs("logs", exist_ok=True)

In [42]:
# Start the W&B run. Project and mode come from the `wandb` section of the
# config; the run is grouped by the configured network name.
training_params = config["training_params"]
wandb_options = config["wandb"]

# Hyperparameters recorded with the run: W&B config key -> key in training_params.
tracked_params = {
    "task": "task",
    "network": "network",
    "dataset": "dataset",
    "epochs": "total_epochs",
    "lr": "lr",
    "batch_size": "batch_size",
    "seed": "seed",
}

# Kept as the last expression of the cell so the helper's return value is displayed.
initialize_wandb(
    project=wandb_options["project"],
    group=f"{config['training_params']['network']}",
    job_type="task_specific",
    mode=wandb_options["mode"],
    config={name: training_params[key] for name, key in tracked_params.items()},
)

<module 'wandb' from '/opt/conda/lib/python3.10/site-packages/wandb/__init__.py'>

In [43]:
# Seed the three RNG sources used by this notebook (PyTorch, NumPy and the
# stdlib `random` module) from the single seed stored in the config.
seed = config["training_params"]["seed"]
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Run on the default CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [44]:
# Build the data loaders from the loaded config; the helper returns three
# loaders, bound here as train / val / test.
# NOTE(review): val_loader is not used by any cell visible in this notebook.
train_loader, val_loader, test_loader = get_data_loaders(config)

In [45]:
from models.dinov2.mtl.multitasker import MTLDinoV2

# Resolve the task flags for the configured task / dataset pair.
training_params = config["training_params"]
train_tasks = create_task_flags(training_params["task"], training_params["dataset"])
print(f"Training Task: {config['training_params']['dataset'].title()} - {config['training_params']['task'].title()} in Single Task Learning Mode with {config['training_params']['network'].upper()}")

# Build the model with the "vit_base" architecture name, passing the task
# flags as `head_tasks`.
model = MTLDinoV2(arch_name="vit_base", head_tasks=train_tasks)

# Count only the parameters that currently require gradients.
num_params = 0
for param in model.parameters():
    if param.requires_grad:
        num_params += param.numel()
print(f"Model: {config['training_params']['network'].title()} | Number of Trainable Parameters: {num_params/1e6:.2f}M")

Training Task: Nyuv2 - Seg in Single Task Learning Mode with SPLIT
Model: Split | Number of Trainable Parameters: 86.63M




In [46]:
# Stage 1: call freeze_shared_layers() with its default arguments
# (presumably this disables gradients on the shared layers — confirm in MTLDinoV2).
model.freeze_shared_layers()

# AdamW over all model parameters, using the configured learning rate.
base_lr = config["training_params"]["lr"]
optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=1e-4)

# One-cycle schedule spanning len(train_loader) * total_epochs optimizer steps,
# peaking at base_lr, with pct_start=0.1.
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=base_lr,
    steps_per_epoch=len(train_loader),
    epochs=config["training_params"]["total_epochs"],
    pct_start=0.1,
)

In [47]:
# Number of batches in one pass over each loader.
train_batch = len(train_loader)
test_batch = len(test_loader)

# The train and test trackers are constructed with identical arguments.
training_params = config["training_params"]
metric_args = (
    train_tasks,
    train_tasks,
    training_params["batch_size"],
    training_params["total_epochs"],
    training_params["dataset"],
)
train_metric = TaskMetric(*metric_args)
test_metric = TaskMetric(*metric_args)

In [48]:
# Stage-1 training loop: each epoch draws train_batch batches, logs to W&B,
# evaluates on the test loader and saves a checkpoint.
model.to(device)
total_epochs = config["training_params"]["total_epochs"]
for epoch in range(total_epochs):
    # training
    model.train()
    batch_iter = iter(train_loader)
    for _ in range(train_batch):
        train_data, train_target = next(batch_iter)
        train_data = train_data.to(device)
        # Keep the targets of the tasks listed in model.head_tasks, moved to the device.
        train_target = {task_id: train_target[task_id].to(device) for task_id in model.head_tasks}

        train_res = model(train_data, None, img_gt=train_target, return_loss=True)

        optimizer.zero_grad()
        train_res["total_loss"].backward()
        optimizer.step()
        scheduler.step()  # the schedule advances once per batch, not once per epoch

        train_metric.update_metric(train_res, train_target)

    train_str = train_metric.compute_metric()

    # The per-task loss values are read from train_res, i.e. from the last
    # batch of the epoch; the metric values come from the epoch-level tracker.
    loss_log = {f"train/loss/{task_id}": train_res[task_id]["total_loss"] for task_id in model.head_tasks}
    metric_log = {f"train/metric/{task_id}": train_metric.get_metric(task_id) for task_id in model.head_tasks}
    wandb.log({**loss_log, **metric_log})
    train_metric.reset()

    # evaluating
    test_str = eval(epoch, model, test_loader, test_metric)

    print(f"Epoch {epoch:04d} | TRAIN:{train_str} || TEST:{test_str} | Best: {config['training_params']['task'].title()} {test_metric.get_best_performance(config['training_params']['task']):.4f}")

    # The same checkpoint path is written after every epoch.
    torch_save(model, "checkpoints/dinov2/linear_probing/seg_head_model.pt")

Epoch 0000 | TRAIN: Seg 1.5399 0.2619 || TEST: Seg 0.7968 0.5455 | Best: Seg 0.5455
Epoch 0001 | TRAIN: Seg 0.6280 0.6124 || TEST: Seg 0.6252 0.5999 | Best: Seg 0.5999
Epoch 0002 | TRAIN: Seg 0.5146 0.6440 || TEST: Seg 0.5694 0.6046 | Best: Seg 0.6046
Epoch 0003 | TRAIN: Seg 0.4723 0.6625 || TEST: Seg 0.5465 0.6179 | Best: Seg 0.6179
Epoch 0004 | TRAIN: Seg 0.4556 0.6649 || TEST: Seg 0.5310 0.6200 | Best: Seg 0.6200
Epoch 0005 | TRAIN: Seg 0.4267 0.6741 || TEST: Seg 0.5316 0.6269 | Best: Seg 0.6269
Epoch 0006 | TRAIN: Seg 0.4192 0.6884 || TEST: Seg 0.5183 0.6252 | Best: Seg 0.6269
Epoch 0007 | TRAIN: Seg 0.4107 0.6901 || TEST: Seg 0.5139 0.6276 | Best: Seg 0.6276
Epoch 0008 | TRAIN: Seg 0.4082 0.6989 || TEST: Seg 0.5141 0.6304 | Best: Seg 0.6304
Epoch 0009 | TRAIN: Seg 0.4103 0.7114 || TEST: Seg 0.5117 0.6307 | Best: Seg 0.6307


In [49]:
# Render and parse the config template again, replacing the `config` dict
# used by the cells that follow.
env = Environment(loader=FileSystemLoader('.'))
rendered_yaml = env.get_template('config/mtl.yaml.j2').render()
config = yaml.safe_load(rendered_yaml)

In [50]:
# Rebuild the three data loaders from the freshly re-read config.
# NOTE(review): val_loader is not used by any cell visible in this notebook.
train_loader, val_loader, test_loader = get_data_loaders(config)

In [51]:
# Number of batches in one pass over each (rebuilt) loader.
train_batch = len(train_loader)
test_batch = len(test_loader)

# Fresh trackers for the second stage, which runs for 3x the configured
# epoch count; train and test trackers get identical arguments.
training_params = config["training_params"]
metric_args = (
    train_tasks,
    train_tasks,
    training_params["batch_size"],
    3*training_params["total_epochs"],
    training_params["dataset"],
)
train_metric = TaskMetric(*metric_args)
test_metric = TaskMetric(*metric_args)

In [52]:
# Stage 2: call freeze_shared_layers with requires_grad=True
# (presumably this re-enables gradients on the shared layers — confirm in MTLDinoV2).
model.freeze_shared_layers(requires_grad=True)

# A new AdamW optimizer over all model parameters, using the configured learning rate.
base_lr = config["training_params"]["lr"]
optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=1e-4)

# One-cycle schedule spanning len(train_loader) * 3 * total_epochs optimizer
# steps, peaking at base_lr, with pct_start=0.05.
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=base_lr,
    steps_per_epoch=len(train_loader),
    epochs=3*config["training_params"]["total_epochs"],
    pct_start=0.05,
)

In [53]:
# Stage-2 training loop: same per-epoch procedure as before, run for 3x the
# configured epoch count and saved to a different checkpoint path.
model.to(device)
stage_epochs = 3*config["training_params"]["total_epochs"]
for epoch in range(stage_epochs):
    # training
    model.train()
    batch_iter = iter(train_loader)
    for _ in range(train_batch):
        train_data, train_target = next(batch_iter)
        train_data = train_data.to(device)
        # Keep the targets of the tasks listed in model.head_tasks, moved to the device.
        train_target = {task_id: train_target[task_id].to(device) for task_id in model.head_tasks}

        train_res = model(train_data, None, img_gt=train_target, return_loss=True)

        optimizer.zero_grad()
        train_res["total_loss"].backward()
        optimizer.step()
        scheduler.step()  # the schedule advances once per batch, not once per epoch

        train_metric.update_metric(train_res, train_target)

    train_str = train_metric.compute_metric()

    # The per-task loss values are read from train_res, i.e. from the last
    # batch of the epoch; the metric values come from the epoch-level tracker.
    loss_log = {f"train/loss/{task_id}": train_res[task_id]["total_loss"] for task_id in model.head_tasks}
    metric_log = {f"train/metric/{task_id}": train_metric.get_metric(task_id) for task_id in model.head_tasks}
    wandb.log({**loss_log, **metric_log})
    train_metric.reset()

    # evaluating
    test_str = eval(epoch, model, test_loader, test_metric)

    print(f"Epoch {epoch:04d} | TRAIN:{train_str} || TEST:{test_str} | Best: {config['training_params']['task'].title()} {test_metric.get_best_performance(config['training_params']['task']):.4f}")

    # The same checkpoint path is written after every epoch.
    torch_save(model, "checkpoints/dinov2/linear_probing/seg_model.pt")

Epoch 0000 | TRAIN: Seg 0.3939 0.7066 || TEST: Seg 0.5034 0.6356 | Best: Seg 0.6356
Epoch 0001 | TRAIN: Seg 0.3878 0.7121 || TEST: Seg 0.4897 0.6437 | Best: Seg 0.6437
Epoch 0002 | TRAIN: Seg 0.3700 0.7194 || TEST: Seg 0.4799 0.6493 | Best: Seg 0.6493
Epoch 0003 | TRAIN: Seg 0.3524 0.7335 || TEST: Seg 0.4730 0.6529 | Best: Seg 0.6529
Epoch 0004 | TRAIN: Seg 0.3494 0.7272 || TEST: Seg 0.4694 0.6560 | Best: Seg 0.6560
Epoch 0005 | TRAIN: Seg 0.3497 0.7336 || TEST: Seg 0.4635 0.6594 | Best: Seg 0.6594
Epoch 0006 | TRAIN: Seg 0.3381 0.7393 || TEST: Seg 0.4565 0.6621 | Best: Seg 0.6621
Epoch 0007 | TRAIN: Seg 0.3242 0.7505 || TEST: Seg 0.4585 0.6644 | Best: Seg 0.6644
Epoch 0008 | TRAIN: Seg 0.3347 0.7391 || TEST: Seg 0.4566 0.6630 | Best: Seg 0.6644
Epoch 0009 | TRAIN: Seg 0.3208 0.7478 || TEST: Seg 0.4519 0.6649 | Best: Seg 0.6649
Epoch 0010 | TRAIN: Seg 0.3118 0.7455 || TEST: Seg 0.4478 0.6674 | Best: Seg 0.6674
Epoch 0011 | TRAIN: Seg 0.3038 0.7643 || TEST: Seg 0.4462 0.6699 | Best: Seg

In [54]:
# Close the active W&B run now that both training stages are done.
# quiet=True is passed — presumably to suppress the end-of-run summary output; confirm against the wandb docs.
wandb.finish(quiet=True)

In [11]:
# img_metas (list[dict]): List of image info dict where each dict
#                 has: 'img_shape', 'scale_factor', 'flip', and may also contain
#                 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.