# MAML enhanced model

In [None]:
from typing import Dict

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn, optim
from torchvision.models.resnet import resnet50
from tqdm import tqdm

from l5kit.configs import load_config_data
from l5kit.data import LocalDataManager, ChunkedDataset
from l5kit.dataset import AgentDataset, EgoDataset
from l5kit.rasterization import build_rasterizer
from l5kit.evaluation import write_pred_csv, compute_metrics_csv, read_gt_csv, create_chopped_dataset
from l5kit.evaluation.chop_dataset import MIN_FUTURE_STEPS
from l5kit.evaluation.metrics import neg_multi_log_likelihood, time_displace
from l5kit.geometry import transform_points
from l5kit.visualization import draw_trajectory, PREDICTED_POINTS_COLOR, TARGET_POINTS_COLOR
from pathlib import Path

import time
import os

Prepare the data path and load the config

In [None]:
# Set the l5kit data-folder environment variable before creating the data manager.
# NOTE(review): LocalDataManager(None) presumably falls back to L5KIT_DATA_FOLDER
# when no folder is passed explicitly - confirm against the l5kit docs.
os.environ["L5KIT_DATA_FOLDER"] = "lyft-motion-prediction-autonomous-vehicles"
dm = LocalDataManager(None)
# Load the experiment configuration from the YAML file and echo it so the
# notebook records the exact settings used for this run.
cfg = load_config_data("./agent_motion_config.yaml")
print(cfg)

{'format_version': 4, 'model_params': {'model_architecture': 'resnet50', 'history_num_frames': 0, 'future_num_frames': 20, 'step_time': 0.1, 'render_ego_history': True}, 'raster_params': {'raster_size': [224, 224], 'pixel_size': [0.5, 0.5], 'ego_center': [0.25, 0.5], 'map_type': 'py_semantic', 'satellite_map_key': 'aerial_map/aerial_map.png', 'semantic_map_key': 'semantic_map/semantic_map.pb', 'dataset_meta_key': 'meta.json', 'filter_agents_threshold': 0.5, 'disable_traffic_light_faces': False, 'set_origin_to_bottom': True}, 'train_data_loader': {'key': 'scenes/train.zarr', 'batch_size': 12, 'shuffle': True, 'num_workers': 0}, 'val_data_loader': {'key': 'scenes/validate.zarr', 'batch_size': 12, 'shuffle': False, 'num_workers': 0}, 'test_data_loader': {'key': 'scenes/test.zarr', 'batch_size': 12, 'shuffle': False, 'num_workers': 0}, 'sample_data_loader': {'key': 'scenes/sample.zarr', 'batch_size': 12, 'shuffle': False, 'num_workers': 0}, 'train_params': {'checkpoint_every_n_steps': 1000

### Model


In [None]:
def build_model(cfg: Dict) -> torch.nn.Module:
    """Create a ResNet-50 whose first and last layers fit the raster inputs and trajectory targets.

    Args:
        cfg: Configuration dictionary. Only ``cfg["model_params"]`` is read
            (``history_num_frames`` and ``future_num_frames``).

    Returns:
        A ``resnet50`` whose first convolution takes
        ``3 + 2 * (history_num_frames + 1)`` input channels and whose final
        linear layer emits ``2 * future_num_frames`` values.
    """
    params = cfg["model_params"]
    backbone = resnet50()

    # Input channels must match the rasterizer's output: two per rendered
    # frame (current + history) plus three more (presumably the RGB map).
    in_channels = 2 * (params["history_num_frames"] + 1) + 3
    stem = backbone.conv1
    backbone.conv1 = nn.Conv2d(
        in_channels,
        stem.out_channels,
        kernel_size=stem.kernel_size,
        stride=stem.stride,
        padding=stem.padding,
        bias=False,
    )

    # Regression head: one (X, Y) pair per future state.
    out_features = 2 * params["future_num_frames"]
    backbone.fc = nn.Linear(in_features=2048, out_features=out_features)

    return backbone

In [None]:
def forward(data, model, device, criterion):
    """Run one batch through ``model`` and return the availability-masked loss.

    Args:
        data: Batch mapping that provides the ``"image"``,
            ``"target_positions"`` and ``"target_availabilities"`` tensors.
        model: Module whose output can be reshaped to the shape of
            ``data["target_positions"]``.
        device: Device the batch tensors are moved to before the forward pass.
        criterion: Loss callable applied to ``(outputs, targets)``; an
            element-wise loss such as ``nn.MSELoss(reduction="none")`` lets the
            availability mask zero out individual steps.

    Returns:
        Tuple ``(loss, outputs)``: the mean of the masked loss and the model
        outputs reshaped like the targets.
    """
    inputs = data["image"].to(device)

    # Min-max scaling over the whole batch. A constant-valued batch has a zero
    # range, which would turn every input into NaN (0 / 0); dividing by 1 in
    # that case leaves the (all-zero) shifted inputs finite.
    inputs_min = torch.min(inputs)
    value_range = torch.max(inputs) - inputs_min
    safe_range = torch.where(value_range > 0, value_range, torch.ones_like(value_range))
    inputs = (inputs - inputs_min) / safe_range

    # Trailing axis added so the mask broadcasts over the coordinate axis.
    target_availabilities = data["target_availabilities"].unsqueeze(-1).to(device)
    targets = data["target_positions"].to(device)

    # Forward pass
    outputs = model(inputs).reshape(targets.shape)
    loss = criterion(outputs, targets)
    # Not all the output steps are valid; availabilities zero them out of the loss.
    loss = (loss * target_availabilities).mean()
    return loss, outputs

### Load the Train Data
- loading the `zarr` into a `ChunkedDataset` object. This object holds references to the different arrays in the zarr (e.g. agents and traffic lights);
- wrapping the `ChunkedDataset` into an `AgentDataset`, which inherits from torch `Dataset` class;
- passing the `AgentDataset` into a torch `DataLoader`

In [None]:
# ===== INIT DATASET
from torch.utils.data import DataLoader

train_cfg = cfg["train_data_loader"]
rasterizer = build_rasterizer(cfg, dm)

# Open the training zarr, wrap it in an AgentDataset and feed that to a
# DataLoader configured from the "train_data_loader" section of the config.
train_zarr = ChunkedDataset(dm.require(train_cfg["key"])).open()
train_dataset = AgentDataset(cfg, train_zarr, rasterizer)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=train_cfg["batch_size"],
    shuffle=train_cfg["shuffle"],
    num_workers=train_cfg["num_workers"],
)
print(train_dataset)

In [None]:
# Display the sample at index 1 to inspect what a dataset item contains.
train_dataset[1]

In [None]:
# ==== INIT MODEL
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model = build_model(cfg).to(device)
# reduction="none" keeps per-element losses so forward() can mask them
# with the target availabilities before averaging.
criterion = nn.MSELoss(reduction="none")
optimizer = optim.SGD(model.parameters(), lr=0.003)

### Training


In [None]:
import torch
import learn2learn as l2l

# Build the network on the training device: forward() moves every batch to
# `device`, so a model left on the CPU fails as soon as CUDA is available.
model = build_model(cfg).to(device)
# MAML wrapper around the model; `lr` is the inner-loop (adaptation) step size
# and first_order=False keeps the second-order terms in the meta-gradient.
maml = l2l.algorithms.MAML(model, lr=0.0008, first_order=False)
# Outer-loop (meta) optimizer acting on the wrapped model's parameters.
optimizer = torch.optim.SGD(maml.parameters(), lr=0.001)

tr_it = iter(train_dataloader)
progress_bar = tqdm(range(20))
losses_train_m_avg = []
losses_train_m = []

adaptation_steps = 2

# Multiply the meta learning rate by 0.1 every 5 iterations.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# Loop-invariant setup: training mode, and gradients re-enabled in case an
# evaluation cell switched them off earlier in the session.
model.train()
torch.set_grad_enabled(True)

training_start_time = time.time()
for _ in progress_bar:
    # Restart the iterator once the dataloader is exhausted.
    try:
        data = next(tr_it)
    except StopIteration:
        tr_it = iter(train_dataloader)
        data = next(tr_it)

    # Fast adaptation on a differentiable clone of the meta-model.
    learner = maml.clone()
    for _step in range(adaptation_steps):
        # Reuse forward() so the inner loop applies the same input scaling and
        # the same availability masking as the outer loss; the unmasked loss
        # would also fit the padded, unavailable future steps.
        adaptation_loss, _ = forward(data, learner, device, criterion)
        learner.adapt(adaptation_loss)

    # Meta-update: loss of the adapted learner, back-propagated to the meta-model.
    loss, _ = forward(data, learner, device, criterion)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # Update the learning rate

    # Record the current loss BEFORE computing the running mean; otherwise the
    # first average is the mean of an empty list (NaN) and all others lag by one.
    losses_train_m.append(loss.item())
    losses_train_m_avg.append(np.mean(losses_train_m))
    progress_bar.set_description(f"loss: {loss.item()} loss(avg): {np.mean(losses_train_m)}")

print('Training finished, took {:.2f}s'.format(time.time() - training_start_time))


 ### Plot Loss Curve
We can plot the train loss against the iterations (batch-wise)

In [None]:
# Dump the raw per-iteration training losses collected by the training loop.
print(losses_train_m)

In [None]:
# Per-iteration training loss; with no explicit x values matplotlib plots
# against the sample index 0..N-1.
plt.plot(losses_train_m, label="train loss")
plt.ylim(bottom=0, top=45)  # Clip the y-axis so early spikes do not flatten the curve
plt.legend()
plt.show()

In [None]:
# Running mean of the training loss. The legend label names the averaged curve
# so it is not mistaken for the raw per-iteration loss plotted separately.
plt.plot(np.arange(len(losses_train_m_avg)), losses_train_m_avg, label="train loss (running avg)")
plt.ylim(0, 45)  # Set the y-axis limits
plt.legend()
plt.show()

### Evaluation


In [None]:
# ===== GENERATE AND LOAD CHOPPED DATASET
num_frames_to_chop = 50
eval_cfg = cfg["sample_data_loader"]
# Create the chopped copy of the sample zarr; the returned base path is used
# afterwards to locate the chopped zarr, the agents mask and the ground truth.
eval_base_path = create_chopped_dataset(
    dm.require(eval_cfg["key"]),
    cfg["raster_params"]["filter_agents_threshold"],
    num_frames_to_chop,
    cfg["model_params"]["future_num_frames"],
    MIN_FUTURE_STEPS,
)
# To reuse an already chopped dataset instead of regenerating it:
# eval_base_path = "lyft-motion-prediction-autonomous-vehicles/scenes/sample_chopped_50"

In [None]:
rasterizer = build_rasterizer(cfg, dm)

# Artefacts located under the chopped-dataset folder.
eval_dir = Path(eval_base_path)
eval_zarr_path = str(eval_dir / "sample.zarr")
eval_mask_path = str(eval_dir / "mask.npz")
eval_gt_path = str(eval_dir / "gt.csv")

eval_zarr = ChunkedDataset(eval_zarr_path).open()
eval_mask = np.load(eval_mask_path)["arr_0"]
# ===== INIT DATASET AND LOAD MASK
eval_dataset = AgentDataset(cfg, eval_zarr, rasterizer, agents_mask=eval_mask)
eval_dataloader = DataLoader(
    eval_dataset,
    batch_size=eval_cfg["batch_size"],
    shuffle=eval_cfg["shuffle"],
    num_workers=eval_cfg["num_workers"],
)
print(eval_dataset)

### Storing Predictions
There is a small catch to be aware of when saving the model predictions. The outputs of the model are coordinates in `agent` space, and we need to convert them into displacements in `world` space.

To do so, we first convert them back into the `world` space and we then subtract the centroid coordinates.

In [None]:
# ==== EVAL LOOP
# Inference only: evaluation mode for the network, no gradient tracking.
torch.set_grad_enabled(False)
model.eval()

# Per-batch results, concatenated later when the predictions are written out.
future_coords_offsets_pd, timestamps, agent_ids = [], [], []

progress_bar = tqdm(eval_dataloader)
for data in progress_bar:
    _, outputs = forward(data, model, device, criterion)

    # Agent-space predictions -> world space, then subtract each agent's
    # centroid so that world-space offsets are stored.
    predictions_agent = outputs.cpu().numpy()
    predictions_world = transform_points(predictions_agent, data["world_from_agent"].numpy())
    offsets_world = predictions_world - data["centroid"].numpy()[:, None, :2]

    future_coords_offsets_pd.append(np.stack(offsets_world))
    timestamps.append(data["timestamp"].numpy().copy())
    agent_ids.append(data["track_id"].numpy().copy())


### Save results

In [None]:
pred_path = "lyft-motion-prediction-autonomous-vehicles/pred.csv"

# Flatten the per-batch lists into single arrays before writing the CSV.
all_timestamps = np.concatenate(timestamps)
all_track_ids = np.concatenate(agent_ids)
all_coords = np.concatenate(future_coords_offsets_pd)
write_pred_csv(pred_path, timestamps=all_timestamps, track_ids=all_track_ids, coords=all_coords)

### Perform Evaluation

In [12]:
from l5kit.evaluation.metrics import average_displacement_error_mean

# Metrics computed between the ground-truth CSV and the prediction CSV.
metric_functions = [neg_multi_log_likelihood, time_displace, average_displacement_error_mean]
metrics = compute_metrics_csv(eval_gt_path, pred_path, metric_functions)
for metric_name in metrics:
    print(metric_name, metrics[metric_name])


NameError: name 'eval_gt_path' is not defined

### Visualise Results
We can also visualise some results from the ego (AV) point of view for those frames of interest (the 100th of each scene).


In [None]:
model.eval()
torch.set_grad_enabled(False)

# Build a dict to retrieve future trajectories from the ground truth. Keys are
# (track_id, timestamp) tuples of strings: concatenating the two strings can
# map different pairs to the same key (e.g. "1" + "23" and "12" + "3").
gt_rows = {}
for row in read_gt_csv(eval_gt_path):
    gt_rows[(row["track_id"], row["timestamp"])] = row["coord"]

eval_ego_dataset = EgoDataset(cfg, eval_dataset.dataset, rasterizer)


for frame_number in range(99, len(eval_zarr.frames), 100):  # start from last frame of scene_0 and increase by 100
    agent_indices = eval_dataset.get_frame_indices(frame_number)
    # Explicit length check: the truth value of an array-like is ambiguous.
    if len(agent_indices) == 0:
        continue

    # get AV point-of-view frame
    data_ego = eval_ego_dataset[frame_number]
    im_ego = rasterizer.to_rgb(data_ego["image"].transpose(1, 2, 0))

    predicted_positions = []
    target_positions = []

    for v_index in agent_indices:
        data_agent = eval_dataset[v_index]

        # Apply the same min-max scaling forward() applies during training and
        # evaluation, so the network sees inputs in the range it was fitted on.
        # A constant image is divided by 1 rather than 0.
        image = torch.from_numpy(data_agent["image"]).unsqueeze(0).to(device)
        image_min = image.min()
        image_range = image.max() - image_min
        image_range = torch.where(image_range > 0, image_range, torch.ones_like(image_range))
        image = (image - image_min) / image_range

        out_net = model(image)
        out_pos = out_net[0].reshape(-1, 2).detach().cpu().numpy()
        # store absolute world coordinates
        predicted_positions.append(transform_points(out_pos, data_agent["world_from_agent"]))
        # retrieve target positions from the GT and store as absolute coordinates
        track_id, timestamp = data_agent["track_id"], data_agent["timestamp"]
        target_positions.append(gt_rows[(str(track_id), str(timestamp))] + data_agent["centroid"][:2])

    # convert coordinates to AV point-of-view so we can draw them
    predicted_positions = transform_points(np.concatenate(predicted_positions), data_ego["raster_from_world"])
    target_positions = transform_points(np.concatenate(target_positions), data_ego["raster_from_world"])

    draw_trajectory(im_ego, predicted_positions, PREDICTED_POINTS_COLOR)
    draw_trajectory(im_ego, target_positions, TARGET_POINTS_COLOR)

    plt.imshow(im_ego)
    plt.show()