In [2]:
# Root directory holding all training runs (machine-specific absolute path).
base_path = "/lhome/petbau/master_thesis/runs/"

# One run directory per model. The trailing slash matters: later cells append
# file names directly to these strings.
kp_model = f"{base_path}run_2022-10-13T18-58-08-634192_SimpleLiftingModel_2D_labels_supervised/"
pc_model = f"{base_path}run_2022-10-13T19-32-59-200129_PointNet/"
fusion_model = f"{base_path}run_2022-10-13T20-15-23-337959_Lidar2dKeypointFusion_waymo_2d_labels_supervised/"

# Model name -> run directory. The loading cell later replaces each path with the loaded model.
models = dict(
    SimpleLiftingModel=kp_model,
    PointNet=pc_model,
    Lidar2dKeypointFusion=fusion_model,
)



In [3]:
# some random import to make gin files work 
from input_pipeline.dataset import load_main_3D_data
from train_supervised import SupervisedTrainer
from train_unsupervised import SelfSupervisedTrainer

from configs.constants import JOINT_KEYS, VIS_COMPLETE_IMG_HTML_STATIC, IMAGES_TO_VIS, JOINT_NAMES, JOINt_COLORS_DICT
# Inverse of JOINT_KEYS (value -> key). Later cells index it with a joint
# position to obtain the keypoint type.
REVERSE_JOINT_KEYS = dict(zip(JOINT_KEYS.values(), JOINT_KEYS.keys()))


In [4]:
from models.supervised.lifting_networks.simple_lifting_model import SimpleLiftingModel
from models.supervised.point_networks.pointnet import PointNet
from models.supervised.fusion.lidar_2dkeypoint import Lidar2dKeypointFusionmodel

from models.weakly_supervised.discriminator import Discriminator
from models.weakly_supervised.generator import Generator

import torch.nn as nn


def get_model(model_name, supervised=False):
    """Instantiate an (untrained) model by name.

    Args:
        model_name: Model identifier, matched case-insensitively against
            "SimpleLiftingModel", "PointNet" and "Lidar2dKeypointFusion".
            Only consulted when ``supervised`` is True.
        supervised: If True, build one of the supervised models above;
            otherwise build a weakly-supervised generator/discriminator pair.

    Returns:
        The supervised model, or a ``(generator, discriminator)`` tuple.

    Raises:
        ValueError: If ``supervised`` is True and ``model_name`` is unknown
            (previously this surfaced as an opaque UnboundLocalError).
    """
    if supervised:
        # Builders are lambdas so that only the requested model is constructed.
        builders = {
            "simpleliftingmodel": lambda: SimpleLiftingModel(latent_dim=512),
            "pointnet": lambda: PointNet(dropout=0.2),
            "lidar2dkeypointfusion": lambda: Lidar2dKeypointFusionmodel(),
        }
        try:
            build = builders[model_name.lower()]
        except KeyError:
            raise ValueError(
                f"Unknown supervised model {model_name!r}; expected one of "
                "'SimpleLiftingModel', 'PointNet', 'Lidar2dKeypointFusion'."
            ) from None
        return build()

    # NOTE(review): `weights_init` is not defined or imported in the visible
    # notebook cells; this branch raises NameError unless it is provided elsewhere.
    discriminator = Discriminator(activation=nn.LeakyReLU)
    discriminator.apply(weights_init)
    generator = Generator(activation=nn.LeakyReLU)
    generator.apply(weights_init)
    return generator, discriminator
    


In [8]:
import os

import torch
import gin

# Load every model: parse the run's operative gin config, build the model,
# restore its best checkpoint and switch it to inference mode.
# NOTE: this replaces the run-directory paths in `models` with the loaded
# modules, so re-running this cell requires re-running the config cell first.
for model_name, model_path in models.items():
    # os.path.join works whether or not the run path ends with a slash.
    gin.parse_config_file(os.path.join(model_path, "config_operative.gin"))
    model = get_model(model_name, supervised=True)
    print(f"Loading {model_name} model... ")
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine;
    # the parameters are copied into the (CPU) model either way.
    state_dict = torch.load(os.path.join(model_path, "ckpts", "best_model"), map_location="cpu")
    model.load_state_dict(state_dict)
    # Inference only: eval() disables dropout (PointNet is built with dropout=0.2),
    # otherwise predictions would be stochastic.
    model.eval()
    models[model_name] = model
    

Loading SimpleLiftingModel model... 
Loading PointNet model... 
Loading Lidar2dKeypointFusion model... 


In [9]:
# Load the "waymo_2d_labels_supervised" dataset. dataset.load returns three
# objects; only the middle one is kept and used below as the test loader (ds_test).
from input_pipeline import dataset 
_, ds_test, _ = dataset.load(name="waymo_2d_labels_supervised")

In [10]:
# Take the first batch of the test loader; `sample` selects one element of that
# batch for all plots below (assumes the batch holds at least 12 elements).
data = next(iter(ds_test))
sample = 11

In [11]:
# Full record (and image) for the selected sample, fetched by its dataset index.
sample_idx = int(data['idx'][sample])
complete_sample, image = ds_test.dataset.dataset.get_complete_sample(sample_idx)

# Batched tensors from the first test batch.
keypoints_2D, occlusions_2D = data['keypoints_2D'], data['occlusions_2D']
keypoints_3D, occlusions_3D, mask_3D = data['keypoints_3D'], data['occlusions_3D'], data['mask_3D']
origin = data['root']

# Point cloud with dimensions 1 and 2 swapped before being fed to the point-based models
# (presumably to channels-first layout — TODO confirm).
pc = data['pc'].transpose(2, 1)

In [12]:
# Build the ground-truth 3D wireframe of the selected sample (joints whose
# mask entry is falsy are skipped).
from waymo_open_dataset.protos import keypoint_pb2
from waymo_open_dataset.utils import keypoint_draw, frame_utils

keypoints_3D_gt = []
for joint_idx, keypoint in enumerate(keypoints_3D[sample]):
    if not mask_3D[sample][joint_idx][0]:
        continue  # no ground-truth label for this joint
    laser_keypoint = keypoint_pb2.LaserKeypoint()
    laser_keypoint.type = REVERSE_JOINT_KEYS[joint_idx]
    location = laser_keypoint.keypoint_3d.location_m
    location.x = keypoint[0]
    location.y = keypoint[1]
    location.z = keypoint[2]
    # An occlusion flag of 1 means "visible"; any other value marks the joint occluded.
    laser_keypoint.keypoint_3d.visibility.is_occluded = int(occlusions_3D[sample][joint_idx]) != 1
    keypoints_3D_gt.append(laser_keypoint)
laser_wireframe_gt = keypoint_draw.build_laser_wireframe(keypoints_3D_gt)



In [None]:
from waymo_open_dataset.utils.keypoint_draw import Wireframe
import plotly.graph_objects as go
import numpy as np

def build_laser_wireframe(predictions, occlusions_3D):
    """Convert one sample's predicted 3D keypoints into a Waymo laser wireframe.

    Args:
        predictions: Per-joint coordinates indexable as ``[j][0..2]``; joint ``j``
            gets type ``REVERSE_JOINT_KEYS[j]``.
        occlusions_3D: Per-joint occlusion flags; 1 marks the joint visible,
            any other value occluded. Pairing stops at the shorter input.

    Returns:
        The result of ``keypoint_draw.build_laser_wireframe`` for these keypoints.
    """
    keypoints_3D_pred = []
    for joint_idx, (keypoint, occlusion) in enumerate(zip(predictions, occlusions_3D)):
        laser_keypoint = keypoint_pb2.LaserKeypoint()
        laser_keypoint.type = REVERSE_JOINT_KEYS[joint_idx]
        location = laser_keypoint.keypoint_3d.location_m
        location.x = keypoint[0]
        location.y = keypoint[1]
        location.z = keypoint[2]
        laser_keypoint.keypoint_3d.visibility.is_occluded = int(occlusion) != 1
        keypoints_3D_pred.append(laser_keypoint)
    return keypoint_draw.build_laser_wireframe(keypoints_3D_pred)
    
def create_plotly_figure(title="", width=500, height=750) -> go.Figure:
    """Return an empty plotly Figure configured for 3D wireframe display.

    The scene uses ``aspectmode='data'`` so x, y and z share one scale. The
    legend is hidden, and every axis shows its background, axis label and tick
    labels, but no grid, zero line or axis line.
    """
    fig = go.Figure()
    axis_style = {
        "showgrid": False,
        "zeroline": False,
        "showline": False,
        "showbackground": True,
        "showaxeslabels": True,
        "showticklabels": True,
    }
    scene = {
        "aspectmode": "data",  # force xyz to have the same scale
        "xaxis": axis_style,
        "yaxis": axis_style,
        "zaxis": axis_style,
    }
    fig.update_layout(
        title_text=f"{title}",
        width=width,
        height=height,
        showlegend=False,
        scene=scene,
    )
    return fig

def draw_laser_wireframe(fig: go.Figure, wireframe: Wireframe, grey=False, color=None) -> None:
    """Draws a laser wireframe onto the plotly Figure.

    Adds one 'lines' trace per wireframe line and a single 'markers' trace for
    all dots (the 'NOSE' dot is drawn 3x larger).

    Args:
        fig: Figure the traces are added to (modified in place).
        wireframe: Wireframe whose ``lines`` and ``dots`` are drawn.
        grey: If True, draw lines and dots in grey ('#666b6b') while keeping each
            dot's ``actual_border_color``.
        color: If given, draw lines, dots and dot borders in this single colour
            (takes precedence over ``grey``). Accepted for consistency with the
            ``color=`` calls used elsewhere in the notebook.
    """
    OCCLUDED_BORDER_WIDTH = 3
    GREY = '#666b6b'
    dots = list(wireframe.dots)

    # Colours are decided once, up front, rather than per line.
    if color is not None:
        line_color_override = color
        scatter_color = [color] * len(dots)
        border_color = [color] * len(dots)
    elif grey:
        line_color_override = GREY
        scatter_color = [GREY] * len(dots)
        border_color = [d.actual_border_color for d in dots]
    else:
        line_color_override = None
        scatter_color = [d.color for d in dots]
        border_color = [d.actual_border_color for d in dots]

    for line in wireframe.lines:
        points = np.stack([line.start, line.end], axis=0)
        line_color = line.color if line_color_override is None else line_color_override
        fig.add_trace(
            go.Scatter3d(
                mode='lines',
                x=points[:, 0],
                y=points[:, 1],
                z=points[:, 2],
                line=dict(color=line_color, width=line.width)))

    if not dots:
        return  # np.stack of an empty list would raise; nothing left to draw
    dot_coords = np.stack([d.location for d in dots], axis=0)
    fig.add_trace(
        go.Scatter3d(
            text=[d.name for d in dots],
            mode='markers',
            x=dot_coords[:, 0],
            y=dot_coords[:, 1],
            z=dot_coords[:, 2],
            marker=dict(
                color=scatter_color,
                size=[d.size * 3 if d.name == 'NOSE' else d.size for d in dots],
                line=dict(
                    width=OCCLUDED_BORDER_WIDTH,
                    color=border_color))))


In [40]:
from waymo_open_dataset.utils.keypoint_draw import Wireframe
import plotly.graph_objects as go
import numpy as np

def build_laser_wireframe(predictions, occlusions_3D):
    """Convert one sample's predicted 3D keypoints into a Waymo laser wireframe.

    Args:
        predictions: Per-joint coordinates indexable as ``[j][0..2]``; joint ``j``
            gets type ``REVERSE_JOINT_KEYS[j]``.
        occlusions_3D: Per-joint occlusion flags; 1 marks the joint visible,
            any other value occluded. Pairing stops at the shorter input.

    Returns:
        The result of ``keypoint_draw.build_laser_wireframe`` for these keypoints.
    """
    def to_laser_keypoint(joint_idx, keypoint, occlusion):
        laser_keypoint = keypoint_pb2.LaserKeypoint()
        laser_keypoint.type = REVERSE_JOINT_KEYS[joint_idx]
        location = laser_keypoint.keypoint_3d.location_m
        location.x = keypoint[0]
        location.y = keypoint[1]
        location.z = keypoint[2]
        laser_keypoint.keypoint_3d.visibility.is_occluded = int(occlusion) != 1
        return laser_keypoint

    keypoints_3D_pred = [
        to_laser_keypoint(joint_idx, keypoint, occlusion)
        for joint_idx, (keypoint, occlusion) in enumerate(zip(predictions, occlusions_3D))
    ]
    return keypoint_draw.build_laser_wireframe(keypoints_3D_pred)
    
def create_plotly_figure(title="", width=500, height=750) -> go.Figure:
    """Return an empty plotly Figure configured for 3D wireframe display.

    The scene uses ``aspectmode='data'`` so x, y and z share one scale. The
    legend is hidden, and every axis shows its background, axis label and tick
    labels, but no grid, zero line or axis line.
    """
    fig = go.Figure()
    axis_style = {
        "showgrid": False,
        "zeroline": False,
        "showline": False,
        "showbackground": True,
        "showaxeslabels": True,
        "showticklabels": True,
    }
    scene = {
        "aspectmode": "data",  # force xyz to have the same scale
        "xaxis": axis_style,
        "yaxis": axis_style,
        "zaxis": axis_style,
    }
    fig.update_layout(
        title_text=f"{title}",
        width=width,
        height=height,
        showlegend=False,
        scene=scene,
    )
    return fig

def draw_laser_wireframe(fig: go.Figure, wireframe: Wireframe, color=None) -> None:
    """Draws a laser wireframe onto the plotly Figure.

    Adds one 'lines' trace per wireframe line and a single 'markers' trace for
    all dots (the 'NOSE' dot is drawn 3x larger).

    Args:
        fig: Figure the traces are added to (modified in place).
        wireframe: Wireframe whose ``lines`` and ``dots`` are drawn.
        color: If given, lines, dots and dot borders are all drawn in this single
            colour; otherwise each line/dot keeps its own ``color`` and each dot
            its ``actual_border_color``.
    """
    OCCLUDED_BORDER_WIDTH = 3
    dots = list(wireframe.dots)

    # Dot colours do not depend on the line being drawn, so compute them once.
    # This also keeps them defined when the wireframe has no lines.
    if color is not None:
        scatter_color = [color] * len(dots)
        border_color = [color] * len(dots)
    else:
        scatter_color = [d.color for d in dots]
        border_color = [d.actual_border_color for d in dots]

    for line in wireframe.lines:
        line_style = dict(color=line.color if color is None else color, width=line.width)
        points = np.stack([line.start, line.end], axis=0)
        fig.add_trace(
            go.Scatter3d(
                mode='lines',
                x=points[:, 0],
                y=points[:, 1],
                z=points[:, 2],
                line=line_style))

    if not dots:
        return  # np.stack of an empty list would raise; nothing left to draw
    dot_coords = np.stack([d.location for d in dots], axis=0)
    fig.add_trace(
        go.Scatter3d(
            text=[d.name for d in dots],
            mode='markers',
            x=dot_coords[:, 0],
            y=dot_coords[:, 1],
            z=dot_coords[:, 2],
            marker=dict(
                color=scatter_color,
                size=[d.size * 3 if d.name == 'NOSE' else d.size for d in dots],
                line=dict(
                    width=OCCLUDED_BORDER_WIDTH,
                    color=border_color))))


In [18]:
import torch

# Run every loaded model on the same test batch. Each model takes a different
# input: the lifting model uses 2D keypoints, PointNet the point cloud, and the
# fusion model both.
predictions = {"SimpleLiftingModel": None, "PointNet": None, "Lidar2dKeypointFusion": None}

with torch.no_grad():  # inference only: no autograd graph is needed
    for name, model in models.items():
        print(f"Predicting using {name}...")
        # eval() disables dropout so predictions are deterministic (idempotent if already set).
        model.eval()
        key = name.lower()  # match every branch case-insensitively
        if key == "pointnet":
            preds, _ = model(pc)
        elif key == "simpleliftingmodel":
            preds = model(keypoints_2D)
        elif key == "lidar2dkeypointfusion":
            # NOTE(review): the fusion forward pass is also given the 3D ground truth
            # and mask — confirm these are not used to produce the predictions.
            preds, _, _ = model(pc, keypoints_2D, gt=(keypoints_3D, data['mask_3D']))
        else:
            # Fail loudly instead of leaving predictions[name] silently as None.
            raise ValueError(f"Unknown model name {name!r}")
        predictions[name] = preds
    
    
    

Predicting unsing SimpleLiftingModel...
Predicting unsing PointNet...
Predicting unsing Lidar2dKeypointFusion...


In [55]:
# Build one 3D figure for the ground truth and one per model prediction of the
# selected sample, collected (ground truth first) in `figures` for the HTML export.
GT_COLOR = "#00ADEF"    # ground-truth wireframe colour
PRED_COLOR = "#CC6600"  # prediction wireframe colour

# (key in `predictions`, figure title), in export order.
PRED_PANELS = [
    ("SimpleLiftingModel", "SimpleLiftingModel"),
    ("PointNet", "PointNet"),
    ("Lidar2dKeypointFusion", "Fusion"),
]


def _hide_axis_decorations(fig):
    """Hide axis labels and tick labels on x/y/z, and the x/y background panes, of fig's 3D scene.

    The z-axis background pane is deliberately left visible.
    """
    for axis_name in ("xaxis", "yaxis", "zaxis"):
        axis = getattr(fig.layout.scene, axis_name)
        axis.showaxeslabels = False
        axis.showticklabels = False
        if axis_name != "zaxis":
            axis.showbackground = False


# Ground truth
laser_wireframe_gt = keypoint_draw.build_laser_wireframe(keypoints_3D_gt)
fig_gt = create_plotly_figure(title='GroundTruth')
draw_laser_wireframe(fig_gt, laser_wireframe_gt, color=GT_COLOR)
_hide_axis_decorations(fig_gt)

# Predictions (drawn with the same occlusion flags as the ground truth).
# `laser_wireframes` now actually stores each model's predicted wireframe.
laser_wireframes = {"SimpleLiftingModel": None, "PointNet": None, "Lidar2dKeypointFusion": None}
pred_figures = {}
for pred_key, pred_title in PRED_PANELS:
    laser_wireframe_pred = build_laser_wireframe(predictions[pred_key][sample], occlusions_3D[sample])
    laser_wireframes[pred_key] = laser_wireframe_pred
    fig_pred = create_plotly_figure(title=pred_title)
    draw_laser_wireframe(fig_pred, laser_wireframe_pred, color=PRED_COLOR)
    _hide_axis_decorations(fig_pred)
    pred_figures[pred_key] = fig_pred

# Named handles kept for later cells (e.g. `fig_fusion.show()`).
fig_lift = pred_figures["SimpleLiftingModel"]
fig_point = pred_figures["PointNet"]
fig_fusion = pred_figures["Lidar2dKeypointFusion"]
figures = [fig_gt, fig_lift, fig_point, fig_fusion]
    

In [53]:
# Interactive preview of the fusion model's figure.
# NOTE: this cell's execution count (53) precedes the figure-building cell (55);
# on a fresh kernel, run it after that cell.
fig_fusion.show()

In [56]:
# Static HTML page that places the four figures side by side in one row.
head_and_body = """<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<style>
* {
  box-sizing: border-box;
}

/* Four equal-width columns that float next to each other */
.column {
  float: left;
  padding: 10px;
  height: 300px; /* Should be removed. Only for demonstration */
}

.left, .middle, .right {
  width: 25%;
}

/* Clear floats after the columns */
.row:after {
  content: "";
  display: table;
  clear: both;
}
</style>
</head>
<body>
"""

end = """</body>
</html>
"""

# CSS class of each column, in the same order as `figures` (GT, lifting, PointNet, fusion).
# The middle columns previously used class "middle" while the CSS only defined
# ".middle_left", so they had no width; the last div was also missing its closing ">".
column_classes = ["left", "middle", "middle", "right"]

with open("presentation.html", "w", encoding="utf-8") as dashboard:
    dashboard.write(head_and_body)
    dashboard.write('<div class="row">')
    for idx, (css_class, figure) in enumerate(zip(column_classes, figures)):
        # Embed plotly.js only with the first figure; the later figures' scripts run
        # after it and reuse the already-loaded library (avoids 4 copies of the bundle).
        inner_html = figure.to_html(full_html=False, include_plotlyjs=(idx == 0))
        dashboard.write(f'<div class="column {css_class}">')
        dashboard.write('<center>')
        dashboard.write(inner_html)
        dashboard.write('</center>')
        dashboard.write('</div>')
    dashboard.write('</div>')
    dashboard.write(end)
    

: 