In [1]:
import os
import random

import numpy as np
import torch
import torch.optim as optim
import yaml
# import wandb
from jinja2 import Environment, FileSystemLoader

from training.create_dataset import *
from training.create_network import *
from training.utils import create_task_flags, TaskMetric, compute_loss, eval
from utils import torch_save, get_data_loaders, initialize_wandb
from models.dinov2.mtl.multitasker import MTLDinoV2

# Login to wandb
# wandb.login()

In [2]:
# Training options: the config file is a Jinja2 template that renders to YAML,
# so it is rendered to text first and only then parsed into `config`.
env = Environment(loader=FileSystemLoader('.'))
template = env.get_template('config/mtl.yaml.j2')
rendered_yaml = template.render()
config = yaml.safe_load(rendered_yaml)

# Folder for training weights and losses; a no-op when it already exists.
os.makedirs("logs", exist_ok=True)

# Lookup from architecture name to model class.
# The "dinov2" -> MTLDinoVisionTransformer entry is currently disabled.
model_classes = dict(
  split=MTLDeepLabv3,
  mtan=MTANDeepLabv3,
)

In [3]:
# Seed the `random`, NumPy and PyTorch generators from a single config value
# so that a Restart & Run All reproduces the same draws.
SEED = config["training_params"]["seed"]
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Run on the default CUDA device when one is available, otherwise on the CPU.
# To pin a specific GPU from the config instead (assumes a 'gpu' entry exists
# under training_params -- TODO confirm). The inner quotes must differ from the
# f-string's own quotes on Python < 3.12:
# device = torch.device(f"cuda:{config['training_params']['gpu']}" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Build the train / validation / test loaders from the parsed config.
# get_data_loaders must return exactly three values for this unpacking to work.
train_loader, val_loader, test_loader = get_data_loaders(config)

In [5]:
# Pull the first batch from the train loader.
# NOTE: despite its name, `images` holds the whole batch, not just the inputs:
# a later cell reads the input tensor from index 0 and the "seg" / "depth"
# ground truth from the dict at index 1.
images = next(iter(train_loader))

In [6]:
# Per-task head settings handed to MTLDinoV2 as `head_tasks`, keyed by task name.
# NOTE(review): 13 is presumably the number of segmentation labels in the
# dataset, and min_depth / max_depth presumably bound the predicted depth in
# the dataset's depth units -- confirm against the data config.
HEAD_TASKS = {
  "seg": {"num_classes": 13},
  "depth": {
    "num_classes": 1,
    "min_depth": 0.001,
    "max_depth": 10.0,
  },
}

# Multi-task model using the "vit_small" architecture with the heads above.
model = MTLDinoV2(arch_name="vit_small", head_tasks=HEAD_TASKS)



In [7]:
# Show the names of all entries in the model's state dict (e.g. 'backbone.*').
# NOTE(review): this displays every key, which is a very long output; consider
# showing only a slice or a count before sharing the notebook.
model.state_dict().keys()

odict_keys(['backbone.cls_token', 'backbone.pos_embed', 'backbone.mask_token', 'backbone.patch_embed.proj.weight', 'backbone.patch_embed.proj.bias', 'backbone.blocks.0.norm1.weight', 'backbone.blocks.0.norm1.bias', 'backbone.blocks.0.attn.qkv.weight', 'backbone.blocks.0.attn.qkv.bias', 'backbone.blocks.0.attn.proj.weight', 'backbone.blocks.0.attn.proj.bias', 'backbone.blocks.0.ls1.gamma', 'backbone.blocks.0.norm2.weight', 'backbone.blocks.0.norm2.bias', 'backbone.blocks.0.mlp.fc1.weight', 'backbone.blocks.0.mlp.fc1.bias', 'backbone.blocks.0.mlp.fc2.weight', 'backbone.blocks.0.mlp.fc2.bias', 'backbone.blocks.0.ls2.gamma', 'backbone.blocks.1.norm1.weight', 'backbone.blocks.1.norm1.bias', 'backbone.blocks.1.attn.qkv.weight', 'backbone.blocks.1.attn.qkv.bias', 'backbone.blocks.1.attn.proj.weight', 'backbone.blocks.1.attn.proj.bias', 'backbone.blocks.1.ls1.gamma', 'backbone.blocks.1.norm2.weight', 'backbone.blocks.1.norm2.bias', 'backbone.blocks.1.mlp.fc1.weight', 'backbone.blocks.1.mlp.fc1

In [8]:
# Place the model on the selected device and switch it to training mode.
model.to(device)
model.train()

# `images` is one batch: index 0 is the input tensor, index 1 is a dict of
# per-task ground truth. Everything is detached and moved to the same device.
image = images[0].detach().to(device)

# Ground truth keyed by task name, in the form passed to the model as `img_gt`.
train_target = {
  task: images[1][task].detach().to(device)
  for task in ("seg", "depth")
}

# Keep the individual targets available under their own names as well.
seg = train_target["seg"]
depth = train_target["depth"]

In [9]:
# Run one forward pass with the ground truth attached and losses requested.
# The module is called directly instead of via `.forward(...)` so that any
# registered forward hooks run; a bare `.forward` call silently skips them.
# The second positional argument is passed as None.
output = model(image, None, img_gt=train_target, return_loss=True)

In [14]:
# Display the forward-pass result; the recorded output shows a dict keyed by
# task (e.g. 'seg') whose entries hold a loss ('loss_seg') and a 'pred' tensor.
# NOTE(review): this dumps full tensors; shapes plus the scalar losses would be
# easier to read. The execution count is also out of sequence (In [14] appears
# before In [11]), so re-check with Restart Kernel -> Run All.
output

{'seg': {'loss_seg': tensor(2.4415, grad_fn=<MulBackward0>),
  'pred': tensor([[[[ 0.0494,  0.6911,  0.6620,  ...,  0.0753, -0.0788,  0.1555],
            [-0.1478,  0.0450,  0.8975,  ..., -0.3039, -0.4546, -0.5659],
            [ 0.1246,  0.7141,  0.6104,  ...,  0.0733, -0.9006, -0.4133],
            ...,
            [-0.0076, -0.0243, -0.0885,  ..., -0.3157, -0.4633, -0.1530],
            [-0.3220,  0.2048,  0.0620,  ..., -0.7974, -0.5872, -0.1879],
            [-0.8817, -0.1928, -0.5731,  ..., -0.1059, -0.3573, -0.1107]],
  
           [[ 0.5629,  0.5142,  0.4831,  ...,  0.4469,  0.2807,  0.4072],
            [ 0.7329,  0.3521,  0.0437,  ...,  0.3831,  0.8722,  1.0950],
            [ 0.3497,  0.4082,  0.4747,  ..., -0.0124, -0.1584,  0.4988],
            ...,
            [ 0.8423, -0.6791,  0.7651,  ...,  0.9749,  0.2351, -0.1885],
            [-0.3156, -0.9350,  0.3687,  ..., -0.1062, -0.4299,  0.2010],
            [ 0.4749, -0.9616,  0.2249,  ..., -0.7806, -1.0308,  0.3255]],
  
 

In [11]:
# img_metas (list[dict]): List of image info dicts, where each dict
#                 has 'img_shape', 'scale_factor', and 'flip', and may also contain
#                 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.