In [None]:
import sys
import os
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
from IPython.display import display, HTML
sys.path.append(os.path.abspath(os.path.join(os.path.dirname("__file__"), '..')))
from datasets.Waymo import WaymoDataset, waymo_collate_fn
from model import OccupancyFlowNetwork
from visualize import render_observed_scene_state, render_ground_truth_occupancy

In [None]:
# Batch size handed to the DataLoader, i.e. how many scenes are collated together.
NUM_SCENES = 25
# Upper bound on how many entries of `scenes` the rendering cells visualize.
MAX_SCENES_TO_RENDER = 1

# Validation split of the Waymo motion dataset (v1.1) and its matching index directory.
idx_path = '../../data1/waymo_dataset/v1.1/idx/validation'
tfrecord_path = '../../data1/waymo_dataset/v1.1/waymo_open_dataset_motion_v_1_1_0/uncompressed/tf_example/validation'

dataset = WaymoDataset(tfrecord_path, idx_path)
dataloader = DataLoader(
    dataset,
    batch_size=NUM_SCENES,
    collate_fn=waymo_collate_fn,
)

# Only the first collated batch is fetched, so `scenes` is a one-element list
# whose single entry holds NUM_SCENES scenes along its batch dimension.
scenes = [next(iter(dataloader))]

In [None]:
# Render the observed state (road map plus agent trajectories) for at most
# MAX_SCENES_TO_RENDER entries of `scenes`.
#
# The render call lives inside the loop so that the limit is actually honoured:
# previously it ran once after the loop, using whichever scene the loop touched
# last (the one *after* the limit when more scenes were available), and raised
# NameError when `scenes` was empty.
for scene_index, scene in enumerate(scenes):
    if scene_index >= MAX_SCENES_TO_RENDER:
        break

    road_map = scene.observed_state.road_map
    agent_trajectories = scene.observed_state.agent_trajectories

    # Index 0 picks the first element along the leading (batch) dimension.
    rendered = render_observed_scene_state(road_map[0], agent_trajectories[0])

    # A bare call is only auto-displayed when it is the last top-level
    # expression of a cell, so show a non-None return value explicitly.
    if rendered is not None:
        display(rendered)

In [None]:
# Animate the ground-truth unoccluded occupancy on top of the road map.
for scene_number, scene in enumerate(scenes, start=1):
    occupancy_grid_occupancies = scene.occupancy_grid.unoccluded_occupancies
    road_map = scene.observed_state.road_map

    # Stop once MAX_SCENES_TO_RENDER entries have been shown.
    if scene_number > MAX_SCENES_TO_RENDER:
        break

    # Index 0 picks the first element along the leading dimension of each tensor.
    anim = render_ground_truth_occupancy(
        road_map[0],
        occupancy_grid_occupancies[0],
    )
    display(HTML(anim.to_jshtml()))

In [None]:
# Animate the ground-truth occluded occupancy on top of the road map.
for scene_number, scene in enumerate(scenes, start=1):
    occupancy_grid_occluded_occupancies = scene.occupancy_grid.occluded_occupancies
    road_map = scene.observed_state.road_map

    # Stop once MAX_SCENES_TO_RENDER entries have been shown.
    if scene_number > MAX_SCENES_TO_RENDER:
        break

    # Index 0 picks the first element along the leading dimension of each tensor.
    anim = render_ground_truth_occupancy(
        road_map[0],
        occupancy_grid_occluded_occupancies[0],
        occluded=True,
    )
    display(HTML(anim.to_jshtml()))

In [None]:
# Select the compute device: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

In [None]:
# Preview of the network's estimated unoccluded occupancy.
#
# This cell sits *before* the cells that train the network and build
# `estimated_occupancies`, so on a fresh kernel (Restart & Run All) the name
# does not exist yet. Instead of raising NameError, the cell renders only when
# the stacked estimates are available and otherwise prints a hint.
stacked_estimates = globals().get('estimated_occupancies')

# The stacked estimates are a 5-D tensor (batch, length, width, time, 1); the
# same name is also used transiently for other objects (a Python list while
# the estimates are being collected, a 3-D tensor inside the training loop),
# which must not be rendered.
if not (torch.is_tensor(stacked_estimates) and stacked_estimates.dim() == 5):
    print('estimated_occupancies is not available yet; run the training and inference cells first.')
else:
    for scene_index, scene in enumerate(scenes):
        if scene_index >= MAX_SCENES_TO_RENDER:
            break

        road_map = scene.observed_state.road_map

        # Index 0 picks the first element along the leading (batch) dimension.
        anim = render_ground_truth_occupancy(road_map[0], stacked_estimates[0].cpu(), occluded=False)
        display(HTML(anim.to_jshtml()))

In [None]:
# Train the baseline occupancy-flow network (flow_field_fourier_features=0) on
# the cached `scenes`, taking one optimizer step per occupancy-grid time index.
occupancy_flow_network = OccupancyFlowNetwork(road_map_image_size=256, road_map_window_size=8,
                                              trajectory_feature_dim=10,
                                              embedding_dim=256,
                                              flow_field_hidden_dim=256, flow_field_fourier_features=0).to(device)
occupancy_flow_network.train()

optim = torch.optim.Adam(occupancy_flow_network.parameters(), lr=1e-4, weight_decay=0)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.999)
criterion = nn.BCELoss()

EPOCHS = 10000
for epoch in range(EPOCHS):
    # Sum of detached per-step losses and the number of optimizer steps, so the
    # reported value is the mean loss over the whole epoch and no autograd
    # graph is kept alive by the accumulator.
    epoch_loss = 0.0
    num_steps = 0

    for scene in scenes:
        road_map = scene.observed_state.road_map.to(device)
        agent_trajectories = scene.observed_state.agent_trajectories.to(device)
        agent_mask = scene.observed_state.agent_mask.to(device)
        occupancy_grid_positions = scene.occupancy_grid.positions.to(device)
        occupancy_grid_times = scene.occupancy_grid.times.to(device)
        occupancy_grid_unoccluded_occupancies = scene.occupancy_grid.unoccluded_occupancies.to(device)
        occupancy_grid_occluded_occupancies = scene.occupancy_grid.occluded_occupancies.to(device)

        # Take the grid dimensions from a tensor loaded in this cell rather
        # than from a variable left behind by another cell.
        batch, length, width, num_timesteps, _ = occupancy_grid_unoccluded_occupancies.shape

        # Flatten the spatial grid so every cell becomes one query position.
        positions = occupancy_grid_positions.reshape(batch, length * width, 2)

        for t in range(num_timesteps):
            # Broadcast the time of index t to every query position of each sample.
            times = occupancy_grid_times[:, t].view(batch, 1, 1).expand(batch, length * width, 1)

            predictions, _ = occupancy_flow_network.estimate_occupancy(times, positions, road_map, agent_trajectories, agent_mask)

            # Channel 0 is compared against the unoccluded ground truth,
            # channel 1 against the occluded ground truth.
            estimated_unoccluded_occupancies = predictions[:, :, 0].reshape(batch, length, width, 1)
            estimated_occluded_occupancies = predictions[:, :, 1].reshape(batch, length, width, 1)

            ground_truth_unoccluded_occupancies = occupancy_grid_unoccluded_occupancies[:, :, :, t, :]
            ground_truth_occluded_occupancies = occupancy_grid_occluded_occupancies[:, :, :, t, :]

            unoccluded_loss = criterion(estimated_unoccluded_occupancies, ground_truth_unoccluded_occupancies)
            # The occluded term must use the occluded estimate and target;
            # it previously re-used the unoccluded pair, so channel 1 was
            # never supervised.
            occluded_loss = criterion(estimated_occluded_occupancies, ground_truth_occluded_occupancies)
            loss = 0.5 * unoccluded_loss + 0.5 * occluded_loss

            optim.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(occupancy_flow_network.parameters(), max_norm=1.0)
            optim.step()

            # detach() drops the graph; the value is only read once per epoch.
            epoch_loss = epoch_loss + loss.detach()
            num_steps += 1

    scheduler.step()

    # max(..., 1) keeps the report well-defined when no step was taken.
    mean_epoch_loss = float(epoch_loss) / max(num_steps, 1)
    print(f'epoch {epoch+1} loss: {mean_epoch_loss}')

In [None]:
# Run the trained baseline network over every occupancy-grid time index and
# collect the unoccluded estimates (output channel 0) into one tensor of shape
# (batch, length, width, time, 1).

# Switch to evaluation mode so layers that behave differently during training
# (e.g. dropout or batch norm, if the model has any) are deterministic here.
occupancy_flow_network.eval()

per_scene_estimates = []

for scene in scenes:
    road_map = scene.observed_state.road_map.to(device)
    agent_trajectories = scene.observed_state.agent_trajectories.to(device)
    agent_mask = scene.observed_state.agent_mask.to(device)
    occupancy_grid_positions = scene.occupancy_grid.positions.to(device)
    occupancy_grid_times = scene.occupancy_grid.times.to(device)
    occupancy_grid_occupancies = scene.occupancy_grid.unoccluded_occupancies.to(device)

    batch, length, width, num_timesteps, _ = occupancy_grid_occupancies.shape

    # Flatten the spatial grid so every cell becomes one query position.
    positions = occupancy_grid_positions.reshape(batch, length * width, 2)

    timestep_estimates = []
    with torch.no_grad():
        for t in range(num_timesteps):
            # Broadcast the time of index t to every query position of each sample.
            times = occupancy_grid_times[:, t].view(batch, 1, 1).expand(batch, length * width, 1)
            estimated_occupancy, _ = occupancy_flow_network.estimate_occupancy(times, positions, road_map, agent_trajectories, agent_mask)
            timestep_estimates.append(estimated_occupancy[:, :, 0].reshape(batch, length, width, 1))

    # Stack this entry's time steps along a new time axis (dim 3).
    per_scene_estimates.append(torch.stack(timestep_estimates, dim=3))

# Entries of `scenes` are joined along the batch axis; stacking all of them
# into a single time axis would interleave different batches as soon as
# `scenes` holds more than one entry.
estimated_occupancies = torch.cat(per_scene_estimates, dim=0)

In [None]:
# Animate the baseline network's estimated unoccluded occupancy over the road map.
for scene_index, scene in enumerate(scenes):
    # Stop once MAX_SCENES_TO_RENDER entries have been shown.
    if scene_index >= MAX_SCENES_TO_RENDER:
        break

    road_map = scene.observed_state.road_map

    # Index 0 picks the first element along the leading dimension; the
    # estimates are moved to the CPU before being handed to the renderer.
    first_estimate = estimated_occupancies[0].cpu()
    anim = render_ground_truth_occupancy(road_map[0], first_estimate, occluded=False)
    display(HTML(anim.to_jshtml()))

In [None]:
# Recompute the baseline network's unoccluded estimates (output channel 0) for
# every occupancy-grid time index; the result has shape
# (batch, length, width, time, 1).
#
# Scene fields are read through `scene.observed_state` / `scene.occupancy_grid`
# like in the rest of the notebook (this cell used to unpack each scene as a
# 12-tuple), the time broadcast uses the real batch size instead of a
# hard-coded 1, and the unoccluded channel is selected before reshaping.

# Evaluation mode keeps train-only layer behaviour (if any) out of inference.
occupancy_flow_network.eval()

per_scene_estimates = []

for scene in scenes:
    road_map = scene.observed_state.road_map.to(device)
    agent_trajectories = scene.observed_state.agent_trajectories.to(device)
    agent_mask = scene.observed_state.agent_mask.to(device)
    occupancy_grid_positions = scene.occupancy_grid.positions.to(device)
    occupancy_grid_times = scene.occupancy_grid.times.to(device)
    occupancy_grid_occupancies = scene.occupancy_grid.unoccluded_occupancies.to(device)

    batch, length, width, num_timesteps, _ = occupancy_grid_occupancies.shape

    # Flatten the spatial grid so every cell becomes one query position.
    positions = occupancy_grid_positions.reshape(batch, length * width, 2)

    timestep_estimates = []
    with torch.no_grad():
        for t in range(num_timesteps):
            # Broadcast the time of index t to every query position of each sample.
            times = occupancy_grid_times[:, t].view(batch, 1, 1).expand(batch, length * width, 1)
            estimated_occupancy, _ = occupancy_flow_network.estimate_occupancy(times, positions, road_map, agent_trajectories, agent_mask)
            timestep_estimates.append(estimated_occupancy[:, :, 0].reshape(batch, length, width, 1))

    # Stack this entry's time steps along a new time axis (dim 3).
    per_scene_estimates.append(torch.stack(timestep_estimates, dim=3))

# Join the entries of `scenes` along the batch axis.
estimated_occupancies = torch.cat(per_scene_estimates, dim=0)

In [None]:
# Animate the baseline network's estimated unoccluded occupancy over the road map.
#
# The road map is read through `scene.observed_state`, as the other cells of
# this notebook do, instead of unpacking each scene as a 12-tuple.
for scene_index, scene in enumerate(scenes):
    # Stop once MAX_SCENES_TO_RENDER entries have been shown.
    if scene_index >= MAX_SCENES_TO_RENDER:
        break

    road_map = scene.observed_state.road_map

    # Index 0 picks the first element along the leading (batch) dimension.
    anim = render_ground_truth_occupancy(road_map[0], estimated_occupancies[0].cpu(), occluded=False)
    display(HTML(anim.to_jshtml()))

In [None]:
# Train a second network with the same hyper-parameters as
# `occupancy_flow_network` except flow_field_fourier_features=128, taking one
# optimizer step per occupancy-grid time index.
#
# Scene fields are read through `scene.observed_state` / `scene.occupancy_grid`
# and the tensors are laid out exactly as in the baseline training cell (this
# cell used to unpack each scene as a 12-tuple and assumed a per-time position
# grid).
occupancy_flow_network2 = OccupancyFlowNetwork(road_map_image_size=256, road_map_window_size=8,
                                               trajectory_feature_dim=10,
                                               embedding_dim=256,
                                               flow_field_hidden_dim=256, flow_field_fourier_features=128).to(device)
occupancy_flow_network2.train()

optim = torch.optim.Adam(occupancy_flow_network2.parameters(), lr=1e-4, weight_decay=0)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.999)
criterion = nn.BCELoss()

EPOCHS = 1000
for epoch in range(EPOCHS):
    # Sum of detached per-step losses and the number of optimizer steps, so the
    # reported value is the mean loss over the whole epoch (dividing by
    # NUM_SCENES, the batch size, did not match the number of summed terms).
    epoch_loss = 0.0
    num_steps = 0

    for scene in scenes:
        road_map = scene.observed_state.road_map.to(device)
        agent_trajectories = scene.observed_state.agent_trajectories.to(device)
        agent_mask = scene.observed_state.agent_mask.to(device)
        occupancy_grid_positions = scene.occupancy_grid.positions.to(device)
        occupancy_grid_times = scene.occupancy_grid.times.to(device)
        occupancy_grid_occupancies = scene.occupancy_grid.unoccluded_occupancies.to(device)

        batch, length, width, num_timesteps, _ = occupancy_grid_occupancies.shape

        # Flatten the spatial grid so every cell becomes one query position.
        positions = occupancy_grid_positions.reshape(batch, length * width, 2)

        for t in range(num_timesteps):
            # Broadcast the time of index t to every query position of each sample.
            times = occupancy_grid_times[:, t].view(batch, 1, 1).expand(batch, length * width, 1)

            predictions, _ = occupancy_flow_network2.estimate_occupancy(times, positions, road_map, agent_trajectories, agent_mask)

            # Only the unoccluded channel (index 0) is supervised in this cell.
            estimated_unoccluded_occupancies = predictions[:, :, 0].reshape(batch, length, width, 1)
            ground_truth_occupancies = occupancy_grid_occupancies[:, :, :, t, :]

            loss = criterion(estimated_unoccluded_occupancies, ground_truth_occupancies)

            optim.zero_grad()
            loss.backward()
            # Clip the gradients of the network being trained here; clipping
            # `occupancy_flow_network` left this model's gradients unbounded.
            torch.nn.utils.clip_grad_norm_(occupancy_flow_network2.parameters(), max_norm=1.0)
            optim.step()

            # detach() drops the graph; the value is only read once per epoch.
            epoch_loss = epoch_loss + loss.detach()
            num_steps += 1

    scheduler.step()

    # max(..., 1) keeps the report well-defined when no step was taken.
    mean_epoch_loss = float(epoch_loss) / max(num_steps, 1)
    print(f'epoch {epoch+1} loss: {mean_epoch_loss}')

In [None]:
# Run the second network (`occupancy_flow_network2`) over every occupancy-grid
# time index and collect its unoccluded estimates (output channel 0) into one
# tensor of shape (batch, length, width, time, 1).
#
# Scene fields are read through `scene.observed_state` / `scene.occupancy_grid`
# like in the rest of the notebook (this cell used to unpack each scene as a
# 12-tuple and assumed a per-time position grid), and the unoccluded channel is
# selected before reshaping.

# Evaluation mode keeps train-only layer behaviour (if any) out of inference.
occupancy_flow_network2.eval()

per_scene_estimates2 = []

for scene in scenes:
    road_map = scene.observed_state.road_map.to(device)
    agent_trajectories = scene.observed_state.agent_trajectories.to(device)
    agent_mask = scene.observed_state.agent_mask.to(device)
    occupancy_grid_positions = scene.occupancy_grid.positions.to(device)
    occupancy_grid_times = scene.occupancy_grid.times.to(device)
    occupancy_grid_occupancies = scene.occupancy_grid.unoccluded_occupancies.to(device)

    batch, length, width, num_timesteps, _ = occupancy_grid_occupancies.shape

    # Flatten the spatial grid so every cell becomes one query position.
    positions = occupancy_grid_positions.reshape(batch, length * width, 2)

    timestep_estimates = []
    with torch.no_grad():
        for t in range(num_timesteps):
            # Broadcast the time of index t to every query position of each sample.
            times = occupancy_grid_times[:, t].view(batch, 1, 1).expand(batch, length * width, 1)
            estimated_occupancy, _ = occupancy_flow_network2.estimate_occupancy(times, positions, road_map, agent_trajectories, agent_mask)
            timestep_estimates.append(estimated_occupancy[:, :, 0].reshape(batch, length, width, 1))

    # Stack this entry's time steps along a new time axis (dim 3).
    per_scene_estimates2.append(torch.stack(timestep_estimates, dim=3))

# Join the entries of `scenes` along the batch axis.
estimated_occupancies2 = torch.cat(per_scene_estimates2, dim=0)

In [None]:
# Animate the second network's estimated unoccluded occupancy over the road map.
#
# The road map is read through `scene.observed_state`, as the other cells of
# this notebook do, instead of unpacking each scene as a 12-tuple.
for scene_index, scene in enumerate(scenes):
    # Stop once MAX_SCENES_TO_RENDER entries have been shown.
    if scene_index >= MAX_SCENES_TO_RENDER:
        break

    road_map = scene.observed_state.road_map

    # Index 0 picks the first element along the leading (batch) dimension.
    anim = render_ground_truth_occupancy(road_map[0], estimated_occupancies2[0].cpu(), occluded=False)
    display(HTML(anim.to_jshtml()))