In [None]:
import sys
import os
from collections import defaultdict
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from IPython.display import display, HTML
sys.path.append(os.path.abspath(os.path.join(os.path.dirname("__file__"), '..')))
from datasets.Waymo import WaymoDataset, waymo_collate_fn
from model import OccupancyFlowNetwork
from visualize import render_flow_at_spacetime

In [2]:
tfrecord_path = '../../data1/waymo_dataset/v1.1/waymo_open_dataset_motion_v_1_1_0/uncompressed/tf_example/validation'
idx_path = '../../data1/waymo_dataset/v1.1/idx/validation'
dataset = WaymoDataset(tfrecord_path, idx_path)
dataloader = DataLoader(dataset, batch_size=1, collate_fn=waymo_collate_fn)
road_map, agent_trajectories, \
flow_field_agent_ids, flow_field_positions, flow_field_times, flow_field_velocities, \
agent_mask, flow_field_mask = next(iter(dataloader))

In [None]:
print(f'road map: {road_map.shape}')
print(f'agent trajectories: {agent_trajectories.shape}')
print(f'flow field agent ids: {flow_field_agent_ids.shape}')
print(f'flow field positions: {flow_field_positions.shape}')
print(f'flow field times: {flow_field_times.shape}')
print(f'flow field velocities: {flow_field_velocities.shape}')
print(f'agent mask: {agent_mask.shape}')
print(f'flow field: {flow_field_mask.shape}')

anim = render_flow_at_spacetime(road_map[0], flow_field_times[0], flow_field_positions[0], flow_field_velocities[0])
display(HTML(anim.to_jshtml()))

road map: torch.Size([1, 256, 256, 3])
agent trajectories: torch.Size([1, 21, 11, 10])
flow field agent ids: torch.Size([1, 27100, 1])
flow field positions: torch.Size([1, 27100, 2])
flow field times: torch.Size([1, 27100, 1])
flow field velocities: torch.Size([1, 27100, 2])
agent mask: torch.Size([1, 21])
flow field: torch.Size([1, 27100])


In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


In [8]:
# can we optimize this with sliding windows?
def occupancy_warping(flow_field, scene_context,
                      agent_ids, positions, times):
    agent_groups = defaultdict(list)
    [agent_groups[round(val.item(), 1)].append(idx) for idx, val in enumerate(agent_ids)]
    for indices in agent_groups.values():
        agent_poistions = positions[indices]
        agent_times = times[indices]
        
        time_groups = defaultdict(list)
        [time_groups[round(val.item(), 1)].append(idx) for idx, val in enumerate(agent_times)]

        occupancy = []
        integration_times = []

        for time, indices in time_groups.items():
            integration_times.append(time)
            occupancy.append(agent_poistions[indices])

        initial_value = occupancy[0].unsqueeze(0)
        integration_times = torch.FloatTensor(integration_times).to(times.device)
        print('-----')
        print(initial_value.shape)
        print(integration_times.shape)
        print(scene_context.shape)
        estimated_occupancy = flow_field.warp_occupancy(initial_value, integration_times, scene_context)
        #print(len(estimated_occupancy))
    return 0

In [9]:
flow_field = OccupancyFlowNetwork(road_map_image_size=256, road_map_window_size=8, 
                                  trajectory_feature_dim=10, 
                                  embedding_dim=256, 
                                  flow_field_hidden_dim=256, flow_field_fourier_features=128).to(device)
flow_field.train()


optim = torch.optim.Adam(flow_field.parameters(), lr=1e-3, weight_decay=0)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.999)

road_map = road_map.to(device)
agent_trajectories = agent_trajectories.to(device)
p = flow_field_positions.to(device)
t = flow_field_times.to(device)
v = flow_field_velocities.to(device)
    
flow = flow_field(t, p, road_map, agent_trajectories)
scene_context = flow_field.scene_encoder(road_map, agent_trajectories)

flow_loss = F.mse_loss(flow, v)
occupancy_loss = occupancy_warping(flow_field, scene_context,
                                   flow_field_agent_ids[0], p[0], t[0])
loss = flow_loss + occupancy_loss

optim.zero_grad()
loss.backward()
optim.step()

scheduler.step()

print(loss)

-----
torch.Size([1, 18, 2])
torch.Size([81])
torch.Size([1, 256])
-----
torch.Size([1, 15, 2])
torch.Size([57])
torch.Size([1, 256])
-----
torch.Size([1, 15, 2])
torch.Size([64])
torch.Size([1, 256])
-----
torch.Size([1, 15, 2])
torch.Size([51])
torch.Size([1, 256])
-----
torch.Size([1, 18, 2])
torch.Size([15])
torch.Size([1, 256])
-----
torch.Size([1, 15, 2])
torch.Size([29])
torch.Size([1, 256])
-----
torch.Size([1, 18, 2])
torch.Size([50])
torch.Size([1, 256])
-----
torch.Size([1, 15, 2])
torch.Size([45])
torch.Size([1, 256])
-----
torch.Size([1, 15, 2])
torch.Size([5])
torch.Size([1, 256])
-----
torch.Size([1, 15, 2])
torch.Size([25])
torch.Size([1, 256])
-----
torch.Size([1, 18, 2])
torch.Size([78])
torch.Size([1, 256])
-----
torch.Size([1, 18, 2])
torch.Size([28])
torch.Size([1, 256])
-----
torch.Size([1, 15, 2])
torch.Size([81])
torch.Size([1, 256])
-----
torch.Size([1, 15, 2])
torch.Size([81])
torch.Size([1, 256])
-----
torch.Size([1, 15, 2])
torch.Size([17])
torch.Size([1, 25

In [None]:
agent_groups = defaultdict(list)
[agent_groups[round(val.item(), 1)].append(idx) for idx, val in enumerate(flow_field_agent_ids[0])]
for indices in agent_groups.values():
    agent_times = flow_field_times[0][indices]
    agent_poistions = flow_field_positions[0][indices]

    time_groups = defaultdict(list)
    [time_groups[round(val.item(), 1)].append(idx) for idx, val in enumerate(agent_times)]

    for key, indices in time_groups.items():
        print(f'{key}: {len(indices)}, {agent_poistions[indices].shape}')