In [1]:
from generative.dataset import ScenarioDataset, collate_scenario_batch, split_dataset
from generative.components import CondionalEncoder, HeatmapEncoder, HeatmapDecoder
from generative.losses import FocalLoss, MaskedL1Loss, MaskedCrossEntropyLoss
from dataset_utils.scenario_utils import visualize_sceanrio_2d,visualize_all_cameras, scenario_to_heatmap, visualize_heatmap

import matplotlib.pyplot as plt
import numpy as np
from dataset_utils.constants import id_to_category,MAP_DIR
from torch.utils.data import DataLoader

# autoreload
%load_ext autoreload
%autoreload 2

In [2]:
# Path to the scenarios JSON file (an absolute path on the author's machine).
scenarios_path = '/home/stud/komo/data/scenarios/nuscenes/scenarios.json'
# Split the scenarios into two parts; 0.8 is the second argument of split_dataset
# (presumably the training fraction -- TODO confirm against split_dataset).
train_set, val_set = split_dataset(scenarios_path, 0.8)
# Only the first split is wrapped in a dataset here; val_set is not used in this cell.
dataset = ScenarioDataset(train_set)
# Shuffled mini-batches of 4 items, assembled by the project's custom collate function.
data_loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_scenario_batch)

# Last expression of the cell: displays the number of classes reported by the dataset.
dataset.num_classes

33

In [3]:
def get_number_of_learnable_paramters(model):
    """Return the number of trainable parameters of ``model``, in millions.

    Every parameter yielded by ``model.parameters()`` whose ``requires_grad``
    flag is set contributes its ``numel()`` to the total; parameters with the
    flag unset are skipped.
    """
    trainable_elements = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_elements += param.numel()
    # Scale the raw element count to millions.
    return trainable_elements / 1e6

In [4]:
# Pull a single batch from the loader; the cells below reuse it as a fixed smoke-test input.
sample_batch = next(iter(data_loader))

# Instantiate the conditional encoder with its default arguments and report its size.
condional_encoder = CondionalEncoder()
print("Number of learnable paramters in millions: ", get_number_of_learnable_paramters(condional_encoder))
# Encode the 'map_information' entry of the batch's second element. The result is
# passed as the second argument to the heatmap encoder and decoder in later cells.
latent = condional_encoder( sample_batch[1]['map_information'])
print(latent.shape)

Number of learnable paramters in millions:  0.012288
torch.Size([4, 4, 32, 32])


In [5]:
# Channel count of the heatmap tensor consumed by the encoder and produced by the
# decoder. It is read from the sample batch instead of being hard-coded, so the
# models always match what the dataset actually emits.
IN_CHANNELS = sample_batch[0].shape[1]
# Channel layout used by the loss cell: 7 leading channels (road, heatmap,
# 2 offset, 3 size) followed by one channel per class.
assert IN_CHANNELS == 7 + dataset.num_classes, (
    f"unexpected channel count {IN_CHANNELS}; expected 7 + {dataset.num_classes} classes"
)

heatmap_encoder = HeatmapEncoder(IN_CHANNELS, first_layer_channels=32)
print("Number of learnable paramters in millions: ", get_number_of_learnable_paramters(heatmap_encoder))
print("Input shape: ", sample_batch[0].shape)
# The encoder takes the heatmap tensor together with the map latent and returns two
# tensors, which are unpacked as mu and sigmas.
mu,sigmas = heatmap_encoder(sample_batch[0], latent)
print(mu.shape)
print(sigmas.shape)

Number of learnable paramters in millions:  2.792854
Input shape:  torch.Size([4, 40, 256, 256])
torch.Size([4, 32, 32, 32])
torch.Size([4, 32, 32, 32])


In [6]:
# Build the decoder with the same channel count and first-layer width as the encoder.
heatmap_decoder = HeatmapDecoder(IN_CHANNELS, first_layer_channels=32)
print("Number of learnable paramters in millions: ", get_number_of_learnable_paramters(heatmap_decoder))
# Decode directly from mu (sigmas is not used here) together with the map latent.
heatmap = heatmap_decoder(mu, latent)
print(heatmap.shape)

Number of learnable paramters in millions:  1.602184
torch.Size([4, 40, 256, 256])


In [None]:
import torch

# Smoke test: evaluate every loss term once on the (untrained) decoder output
# `heatmap` against the ground-truth tensor `sample_batch[0]`.
#
# Channel layout of both tensors, as sliced below:
#   0    road
#   1    object heatmap
#   2:4  offset (2 channels)
#   4:7  size (3 channels)
#   7:   per-class channels
focal_loss = FocalLoss()
masked_l1_loss = MaskedL1Loss()
BCE = torch.nn.BCELoss()  # instantiated but not used in this cell
masked_cross_entropy = MaskedCrossEntropyLoss()
regular_cross_entropy = torch.nn.CrossEntropyLoss()

# Boolean masks of the pixels whose ground-truth value is exactly 1.
road_mask = sample_batch[0][:,0,:,:].eq(1)  # computed but not used in this cell
heat_map_mask = sample_batch[0][:,1,:,:].eq(1)
print(heat_map_mask.shape)

# Focal loss on the two single-channel maps.
road_loss = focal_loss(heatmap[:,0,:,:], sample_batch[0][:,0,:,:])
print("Road loss: ", road_loss)
heatmap_loss = focal_loss(heatmap[:,1,:,:], sample_batch[0][:,1,:,:])
print("Heatmap loss: ", heatmap_loss)

# L1 regression terms, both evaluated with the heatmap mask.
offset_l1 = masked_l1_loss(heatmap[:,2:4,:,:], sample_batch[0][:,2:4,:,:], heat_map_mask)
print("Offset loss: ", offset_l1)
size_l1 = masked_l1_loss(heatmap[:,4:7,:,:], sample_batch[0][:,4:7,:,:], heat_map_mask)
print("Size loss: ", size_l1)

# Collapse the ground-truth class channels into one class index per pixel.
target_classes = torch.argmax(sample_batch[0][:,7:,:,:], dim=1)
print("Target classes shape: ", target_classes.shape)
# Report the shape of the *predicted* class scores (the decoder output), which is
# what the label says; previously the ground-truth slice was printed here.
print("predicted classes shape: ", heatmap[:,7:,:,:].shape)

# First value: cross entropy evaluated with the heatmap mask.
classes_ce = masked_cross_entropy(heatmap[:,7:,:,:], target_classes, heat_map_mask)

print("Classes loss: ", classes_ce)
# Second value: plain cross entropy over every pixel, for comparison.
class_ce = regular_cross_entropy(heatmap[:,7:,:,:], target_classes)
print("Classes loss: ", class_ce)



torch.Size([4, 256, 256])
Road loss:  tensor(1.1369, grad_fn=<RsubBackward1>)
Heatmap loss:  tensor(8789.2842, grad_fn=<RsubBackward1>)
Offset loss:  tensor(0.)
Size loss:  tensor(13.4034, grad_fn=<DivBackward0>)
Target classes shape:  torch.Size([4, 256, 256])
predicted classes shape:  torch.Size([4, 33, 256, 256])
Classes loss:  tensor(3.4692, grad_fn=<DivBackward0>)
Classes loss:  tensor(3.4987, grad_fn=<NllLoss2DBackward0>)


In [29]:
# Sanity-check the masked cross-entropy loss on a hand-built, perfectly classified
# input: every pixel's logits put +10 on its target class and -10 on the others,
# so the loss is expected to be close to 0.

# Targets of shape (N, H, W) = (3, 3, 3): the class index equals the column index.
target_classes = torch.tensor([0, 1, 2])                      # (W,)
target_classes = target_classes.unsqueeze(0).repeat(3, 1)     # (H, W)
target_classes = target_classes.unsqueeze(0).repeat(3, 1, 1)  # (N, H, W)

# Logits of shape (N, C, H, W) = (3, 3, 3, 3), the same layout in which the loss is
# called on the decoder output (class scores along dim 1). The base matrix is
# indexed [class, column]; the row (H) axis is inserted *after* the class axis so
# that the scores vary along dim 1. Repeating along a new leading axis instead
# would make all classes score the same and yield a loss of ln(3).
logits = torch.tensor([[10.0, -10.0, -10.0],
                       [-10.0, 10.0, -10.0],
                       [-10.0, -10.0, 10.0]])                 # (C, W)
logits = logits.unsqueeze(1).repeat(1, 3, 1)                  # (C, H, W)
logits = logits.unsqueeze(0).repeat(3, 1, 1, 1)               # (N, C, H, W)

# Mask that keeps every pixel.
full_mask = torch.ones(3, 3, 3, dtype=torch.bool)

print("Logits shape: ", logits.shape)
print("Target classes shape: ", target_classes.shape)
print("Full mask shape: ", full_mask.shape)
# Unmasked reference computed by PyTorch on the same input; printed first so it is
# available even if the masked loss raises.
print("Reference loss: ", torch.nn.functional.cross_entropy(logits, target_classes))
print("Loss: ", masked_cross_entropy(logits, target_classes, full_mask))


Logits shape:  torch.Size([3, 3, 3, 3])
Target classes shape:  torch.Size([3, 3, 3])
Full mask shape:  torch.Size([3, 3, 3])


AttributeError: 'MaskedCrossEntropyLoss' object has no attribute 'loss'