In [1]:
import torch
import torch.optim as optim
import os
import yaml
# import wandb

from jinja2 import Environment, FileSystemLoader

from training.create_dataset import *
from training.create_network import *
from training.utils import create_task_flags, TaskMetric, compute_loss, eval
from utils import torch_save, get_data_loaders, initialize_wandb

# Login to wandb
# wandb.login()

In [2]:
# Options for training
env = Environment(loader=FileSystemLoader('.'))
template = env.get_template('config/mtl.yaml.j2')
rendered_yaml = template.render()
config = yaml.safe_load(rendered_yaml)

# Create logging folder to store training weights and losses
os.makedirs("logs", exist_ok=True)

model_classes = {
  "split": MTLDeepLabv3,
  "mtan": MTANDeepLabv3,
  # "dinov2": MTLDinoVisionTransformer,
}

In [3]:
torch.manual_seed(config["training_params"]["seed"])
np.random.seed(config["training_params"]["seed"])
random.seed(config["training_params"]["seed"])

# device = torch.device(f"cuda:{config["training_params"]['gpu']}" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
train_loader, val_loader, test_loader = get_data_loaders(config)

In [5]:
images = next(iter(train_loader))

In [6]:
from models.dinov2.depth import dinov2_vits14_ld

model = dinov2_vits14_ld()



In [7]:
images[0].shape

torch.Size([2, 3, 224, 224])

In [8]:
images[1]["depth"].shape

torch.Size([2, 1, 224, 224])

In [9]:
# img_metas (list[dict]): List of image info dict where each dict
#                 has: 'img_shape', 'scale_factor', 'flip', and may also contain
#                 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.

In [10]:
model.to(device)
optimizer = optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-4, momentum=0.9)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 2)

image = images[0].detach().to(device)
depth = images[1]["depth"].detach().to(device)

In [11]:
model.train()
optimizer.zero_grad()

In [12]:
model.backbone.patch_embed(image)

DDDD
FFF


tensor([[[ 2.6910e-02,  4.7631e-02,  8.8394e-02,  ...,  1.8235e-02,
          -2.3841e-03,  1.5969e-03],
         [ 4.6896e-02,  4.6478e-02,  5.7952e-02,  ..., -5.5455e-03,
           4.5902e-02,  1.5836e-03],
         [-8.8425e-02,  4.8295e-02,  5.0614e-02,  ...,  3.3028e-02,
           1.0579e-04,  1.1183e-02],
         ...,
         [ 7.9645e-03,  4.7552e-02,  5.1247e-02,  ...,  3.3025e-02,
           9.5586e-03,  6.7903e-03],
         [-5.5989e-02,  4.7468e-02,  1.4163e-02,  ...,  3.5305e-02,
          -3.5921e-03,  6.9401e-03],
         [-3.7302e-02,  4.7055e-02,  7.9847e-02,  ...,  7.4549e-02,
           4.0578e-02,  7.0111e-03]],

        [[-7.3589e-02,  4.8505e-02,  7.8780e-02,  ...,  4.0675e-02,
           1.2043e-02,  3.2469e-03],
         [-1.0785e-01,  4.9415e-02,  8.6202e-02,  ...,  3.7837e-02,
           1.0464e-02,  1.2803e-03],
         [-8.9431e-02,  5.0113e-02,  7.6992e-02,  ...,  4.9945e-02,
           6.6126e-03, -2.7373e-04],
         ...,
         [-3.4223e-02,  4

In [13]:
model.backbone(image)

CCC
torch.Size([2, 3, 224, 224])
DDDD
FFF
torch.Size([2, 256, 384])


((tensor([[[[-2.4361e-02,  4.4513e-03, -2.5568e-02,  ..., -1.7040e-02,
             -4.7646e-02, -3.3779e-02],
            [-3.8648e-03, -4.1586e-03,  2.0594e-03,  ..., -3.6277e-03,
             -1.2559e-02, -2.6673e-02],
            [-5.3425e-02, -2.8258e-02, -5.2189e-02,  ..., -2.4270e-02,
             -3.0013e-02, -1.9739e-02],
            ...,
            [-9.5665e-03, -2.0185e-02, -1.7889e-02,  ..., -2.3987e-03,
             -4.1057e-03, -6.2479e-03],
            [ 1.1738e-03, -2.7186e-02, -1.0759e-02,  ...,  1.9578e-02,
             -2.4879e-02, -2.8551e-02],
            [ 1.5511e-02, -4.3900e-02, -1.7616e-02,  ...,  2.3856e-02,
             -2.0908e-02, -8.3892e-03]],
  
           [[-4.3871e-02, -5.9876e-02, -7.2662e-02,  ..., -1.0514e-02,
              8.3143e-03,  1.1776e-02],
            [-3.3399e-02, -5.4731e-02, -6.7950e-02,  ...,  1.3197e-03,
              2.0916e-02,  2.6122e-02],
            [-2.6680e-02, -4.6264e-02, -5.7105e-02,  ...,  9.0713e-03,
              2.2693

In [14]:
output = model.forward(image, None, depth_gt=depth, return_loss=True)

A
CCC
torch.Size([2, 3, 224, 224])
DDDD
FFF
torch.Size([2, 256, 384])
B


TypeError: 'NoneType' object is not subscriptable