In [1]:
# Predict multiple modes with confidences: as described on the competition's evaluation-metric page, we can predict 3 modes of the motion trajectory.
# The training loss is the competition evaluation metric.
# Training uses the abstraction libraries pytorch-ignite and pytorch-pfn-extras.

In [None]:
#https://github.com/lyft/l5kit/blob/20ab033c01610d711c3d36e1963ecec86e8b85b6/l5kit/l5kit/evaluation/csv_utils.py#L140

In [None]:
#https://github.com/lyft/l5kit/blob/master/competition.md

In [1]:
import gc
import os
from pathlib import Path
import random
import sys

from tqdm import tqdm
import numpy as np
import pandas as pd
import scipy as sp


import matplotlib.pyplot as plt
import seaborn as sns

from IPython.core.display import display, HTML

# --- plotly ---
from plotly import tools, subplots
import plotly.offline as py
py.init_notebook_mode(connected=True)
import plotly.graph_objs as go
import plotly.express as px
import plotly.figure_factory as ff
import plotly.io as pio
pio.templates.default = "plotly_dark"

# --- models ---
from sklearn import preprocessing
from sklearn.model_selection import KFold
import lightgbm as lgb
import xgboost as xgb
import catboost as cb

# --- setup ---
# Use the fully qualified option name: the bare 'max_columns' pattern is ambiguous in
# newer pandas (it also matches 'styler.render.max_columns') and raises OptionError.
pd.set_option('display.max_columns', 50)


Starting from version 2.2.1, the library file in distribution wheels for macOS is built by the Apple Clang (Xcode_9.4.1) compiler.
This means that in case of installing LightGBM from PyPI via the ``pip install lightgbm`` command, you don't need to install the gcc compiler anymore.
Instead of that, you need to install the OpenMP library, which is required for running LightGBM on the system with the Apple Clang compiler.
You can install the OpenMP library by the following command: ``brew install libomp``.



In [2]:
import torch
from pathlib import Path

import pytorch_pfn_extras as ppe
from math import ceil
from pytorch_pfn_extras.training import IgniteExtensionsManager
from pytorch_pfn_extras.training.triggers import MinValueTrigger

from torch import nn, optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Subset
import pytorch_pfn_extras.training.extensions as E

In [3]:
# --- Dataset utils ---
from typing import Callable

from torch.utils.data.dataset import Dataset


class TransformDataset(Dataset):
    """Dataset wrapper that applies ``transform`` to every item it serves.

    Args:
        dataset: Underlying dataset; it is indexed to fetch the raw item.
        transform: Callable invoked on each raw item. Its return value is
            what ``__getitem__`` hands back.
    """

    def __init__(self, dataset: Dataset, transform: Callable):
        self.transform = transform
        self.dataset = dataset

    def __len__(self):
        # The wrapper exposes exactly as many items as the wrapped dataset.
        return len(self.dataset)

    def __getitem__(self, index):
        raw_item = self.dataset[index]
        transformed = self.transform(raw_item)
        return transformed

In [4]:
import zarr

import l5kit
from l5kit.data import ChunkedDataset, LocalDataManager
from l5kit.dataset import EgoDataset, AgentDataset

from l5kit.rasterization import build_rasterizer
from l5kit.configs import load_config_data
from l5kit.visualization import draw_trajectory, TARGET_POINTS_COLOR
from l5kit.geometry import transform_points
from tqdm import tqdm
from collections import Counter
from l5kit.data import PERCEPTION_LABELS
from prettytable import PrettyTable
from l5kit.evaluation import write_pred_csv

from matplotlib import animation, rc
from IPython.display import HTML

# Set matplotlib's `animation.html` rc parameter to 'jshtml'.
rc('animation', html='jshtml')
# Print the installed l5kit version so the run records which release was used.
print("l5kit version:", l5kit.__version__)

l5kit version: 1.1.0


In [5]:
# Function
# Define a loss function that computes the competition evaluation metric over a batch.

In [6]:
# --- Function utils ---
# Original code from https://github.com/lyft/l5kit/blob/20ab033c01610d711c3d36e1963ecec86e8b85b6/l5kit/l5kit/evaluation/metrics.py
import numpy as np

import torch
from torch import Tensor


def pytorch_neg_multi_log_likelihood_batch(
    gt: Tensor, pred: Tensor, confidences: Tensor, avails: Tensor
) -> Tensor:
    """
    Compute the negative log-likelihood for the multi-modal scenario, averaged over the batch.
    log-sum-exp trick is used here to avoid underflow and overflow, For more information about it see:
    https://en.wikipedia.org/wiki/LogSumExp#log-sum-exp_trick_for_log-domain_calculations
    https://timvieira.github.io/blog/post/2014/02/11/exp-normalize-trick/
    https://leimao.github.io/blog/LogSumExp/
    Args:
        gt (Tensor): array of shape (bs)x(time)x(2D coords)
        pred (Tensor): array of shape (bs)x(modes)x(time)x(2D coords)
        confidences (Tensor): array of shape (bs)x(modes) with a confidence for each mode in each sample
        avails (Tensor): array of shape (bs)x(time) with the availability for each gt timestep
    Returns:
        Tensor: scalar tensor holding the mean negative log-likelihood over the batch
    Raises:
        AssertionError: if a shape does not match the layout above, the confidences of a
            sample do not sum to 1, or any input contains a non-finite value
    """
    assert len(pred.shape) == 4, f"expected 4D (BxMxTxC) array for pred, got {pred.shape}"
    batch_size, num_modes, future_len, num_coords = pred.shape

    assert gt.shape == (batch_size, future_len, num_coords), f"expected 3D (Batch x Time x Coords) array for gt, got {gt.shape}"
    assert confidences.shape == (batch_size, num_modes), f"expected 2D (Batch x Modes) array for confidences, got {confidences.shape}"
    assert torch.allclose(torch.sum(confidences, dim=1), confidences.new_ones((batch_size,))), "confidences should sum to 1"
    assert avails.shape == (batch_size, future_len), f"expected 2D (Batch x Time) array for avails, got {avails.shape}"
    # assert all data are valid
    assert torch.isfinite(pred).all(), "invalid value found in pred"
    assert torch.isfinite(gt).all(), "invalid value found in gt"
    assert torch.isfinite(confidences).all(), "invalid value found in confidences"
    assert torch.isfinite(avails).all(), "invalid value found in avails"

    # convert to (batch_size, num_modes, future_len, num_coords)
    gt = torch.unsqueeze(gt, 1)  # add modes
    avails = avails[:, None, :, None]  # add modes and cords

    # error (batch_size, num_modes, future_len)
    error = torch.sum(((gt - pred) * avails) ** 2, dim=-1)  # reduce coords and use availability

    # When a confidence is 0, log goes to -inf, which is fine here: that mode then
    # contributes exp(-inf) = 0 to the sum below. (np.errstate was dropped because it
    # only affects NumPy operations, not torch ones.)
    # error (batch_size, num_modes)
    error = torch.log(confidences) - 0.5 * torch.sum(error, dim=-1)  # reduce time

    # use max aggregator on modes for numerical stability
    # error (batch_size, num_modes)
    max_value, _ = error.max(dim=1, keepdim=True)  # error are negative at this point, so max() gives the minimum one
    error = -torch.log(torch.sum(torch.exp(error - max_value), dim=-1, keepdim=True)) - max_value  # reduce modes
    return torch.mean(error)


def pytorch_neg_multi_log_likelihood_single(
    gt: Tensor, pred: Tensor, avails: Tensor
) -> Tensor:
    """
    Negative log-likelihood for a single-trajectory prediction, computed by treating
    the prediction as one mode with confidence 1 and delegating to the batch function.

    Args:
        gt (Tensor): array of shape (bs)x(time)x(2D coords)
        pred (Tensor): array of shape (bs)x(time)x(2D coords)
        avails (Tensor): array of shape (bs)x(time) with the availability for each gt timestep
    Returns:
        Tensor: whatever pytorch_neg_multi_log_likelihood_batch returns for the one-mode input
    """
    # Unpacking into three names doubles as a rank check: pred must be exactly 3D.
    n_samples, _, _ = pred.shape
    # confidence (bs)x(mode=1), all ones, created on pred's dtype/device
    unit_confidence = pred.new_ones((n_samples, 1))
    # pred (bs)x(time)x(2D coords) --> (bs)x(mode=1)x(time)x(2D coords)
    one_mode_pred = pred.unsqueeze(1)
    return pytorch_neg_multi_log_likelihood_batch(gt, one_mode_pred, unit_confidence, avails)

In [7]:
# Model
# PyTorch model definition. The model outputs both the multi-mode trajectory predictions and a confidence for each trajectory.

In [8]:
# --- Model utils ---
import torch
from torchvision.models import resnet34
from torch import nn
from typing import Dict


class LyftMultiModel(nn.Module):
    """ResNet-34 based predictor that emits several trajectories plus a confidence per trajectory.

    Args:
        cfg: Config dict. ``cfg["model_params"]["history_num_frames"]`` sizes the input
            channels and ``cfg["model_params"]["future_num_frames"]`` sizes the output.
        num_modes: Number of trajectories (modes) predicted per sample.
    """

    def __init__(self, cfg: Dict, num_modes=3):
        super().__init__()

        model_params = cfg["model_params"]

        # TODO: support backbones other than resnet34?
        self.backbone = resnet34(pretrained=True, progress=True)

        # 3 base channels plus 2 channels for each of (history frames + current frame).
        num_in_channels = 3 + 2 * (model_params["history_num_frames"] + 1)

        # Swap the stem convolution for one that accepts `num_in_channels` inputs,
        # keeping every other hyper-parameter of the original stem.
        original_stem = self.backbone.conv1
        self.backbone.conv1 = nn.Conv2d(
            num_in_channels,
            original_stem.out_channels,
            kernel_size=original_stem.kernel_size,
            stride=original_stem.stride,
            padding=original_stem.padding,
            bias=False,
        )

        # 512 for resnet18 and resnet34; 2048 for the other resnets.
        backbone_out_features = 512

        # X, Y coords for each future position of a single mode.
        self.future_len = model_params["future_num_frames"]
        num_targets = 2 * self.future_len

        # More layers can be added to the head here.
        self.head = nn.Sequential(
            # nn.Dropout(0.2),
            nn.Linear(in_features=backbone_out_features, out_features=4096),
        )

        self.num_preds = num_targets * num_modes
        self.num_modes = num_modes

        # One output per predicted coordinate, plus one confidence logit per mode.
        self.logit = nn.Linear(4096, out_features=self.num_preds + num_modes)

    def forward(self, x):
        """Return ``(pred, confidences)``.

        ``pred`` is viewed as (bs)x(modes)x(time)x(2D coords); ``confidences`` is
        (bs)x(modes) after a softmax over the mode dimension.
        """
        net = self.backbone
        # Run the backbone stage by stage, stopping before its classification layer.
        stages = (
            net.conv1, net.bn1, net.relu, net.maxpool,
            net.layer1, net.layer2, net.layer3, net.layer4,
            net.avgpool,
        )
        for stage in stages:
            x = stage(x)
        x = torch.flatten(x, 1)

        x = self.logit(self.head(x))

        batch_size, _ = x.shape
        # The first `num_preds` columns are coordinates, the remainder are mode logits.
        pred, confidences = x.split(self.num_preds, dim=1)
        pred = pred.view(batch_size, self.num_modes, self.future_len, 2)
        assert confidences.shape == (batch_size, self.num_modes)
        confidences = torch.softmax(confidences, dim=1)
        return pred, confidences


In [None]:
# NOTE(review): this prints the class object itself, not the layers of an instantiated model.
print(LyftMultiModel)

In [9]:
class LyftMultiRegressor(nn.Module):
    """Multi-mode trajectory regressor: runs a predictor and computes its training loss.

    Args:
        predictor: Module called as ``predictor(image)`` that returns ``(pred, confidences)``.
        lossfun: Callable ``lossfun(targets, pred, confidences, target_availabilities)``
            returning the loss tensor to optimise.
    """

    def __init__(self, predictor, lossfun=pytorch_neg_multi_log_likelihood_batch):
        super().__init__()
        self.predictor = predictor
        self.lossfun = lossfun

    def forward(self, image, targets, target_availabilities):
        """Compute the loss for one batch and report the metrics.

        Returns:
            tuple: ``(loss, metrics)`` where ``loss`` is the tensor returned by ``lossfun``
            and ``metrics`` is a dict with float entries ``"loss"`` and ``"nll"``.
        """
        pred, confidences = self.predictor(image)
        loss = self.lossfun(targets, pred, confidences, target_availabilities)
        loss_value = loss.item()
        if self.lossfun is pytorch_neg_multi_log_likelihood_batch:
            # The loss already is the NLL, so reuse it rather than computing it twice.
            nll_value = loss_value
        else:
            # The NLL is only reported here, so it needs no autograd graph.
            with torch.no_grad():
                nll_value = pytorch_neg_multi_log_likelihood_batch(
                    targets, pred, confidences, target_availabilities
                ).item()
        metrics = {
            "loss": loss_value,
            "nll": nll_value,
        }
        ppe.reporting.report(metrics, self)
        return loss, metrics

In [10]:
# Training with Ignite
# pytorch-ignite is used as the training abstraction.
# The Engine defines one iteration of the training update.

In [11]:
# --- Training utils ---
from ignite.engine import Engine


def create_trainer(model, optimizer, device) -> Engine:
    """Build an ignite ``Engine`` whose step performs one optimisation update.

    Args:
        model: Module called with the batch elements; must return ``(loss, metrics)``.
        optimizer: Optimizer stepped after each backward pass.
        device: Device the model and every batch element are moved to.

    Returns:
        Engine: engine whose per-iteration output is the ``metrics`` returned by the model.
    """
    model.to(device)

    def _step(engine, batch):
        model.train()
        optimizer.zero_grad()
        inputs = [elem.to(device) for elem in batch]
        loss, metrics = model(*inputs)
        loss.backward()
        optimizer.step()
        return metrics

    return Engine(_step)

In [12]:
# Code referenced from https://github.com/grafi-tt/chaineripy/blob/master/chaineripy/extensions/print_report.py by @grafi-tt
# Modified to work with pytorch_pfn_extras

import os
import sys
from copy import deepcopy

from IPython.core.display import display
from ipywidgets import HTML

from pytorch_pfn_extras.training.extensions.print_report import PrintReport

from pytorch_pfn_extras.training import extension
from pytorch_pfn_extras.training.extensions import log_report \
    as log_report_module
from pytorch_pfn_extras.training.extensions import util


class PrintReportNotebook(PrintReport):

    """Notebook variant of ``PrintReport`` that shows the log as an HTML table widget.

    The log accumulated by a :class:`LogReport` extension is converted to a
    DataFrame and the selected entries are rendered into an ipywidgets ``HTML``
    widget.

    Args:
        entries (list of str or None): List of keys of observations to print.
            If `None` is passed, automatically infer keys from reported dict.
        log_report (str or LogReport): Log report to accumulate the
            observations. This is either the name of a LogReport extensions
            registered to the manager, or a LogReport instance to use
            internally.
        out: Stream forwarded to ``PrintReport``. Standard output is used by default.

    """

    def __init__(self, entries=None, log_report='LogReport', out=sys.stdout):
        super().__init__(entries=entries, log_report=log_report, out=out)
        # Widget whose `value` is overwritten with the rendered table on every call.
        self._widget = HTML()

    @property
    def widget(self):
        return self._widget

    def initialize(self, trainer):
        # Show the (initially empty) widget.
        display(self._widget)

    def __call__(self, manager):
        report = self.get_log_report(manager)
        frame = report.to_dataframe()
        if self._infer_entries:
            # --- update entries ---
            self._update_entries(report)
        selected = frame[self._entries]
        self._widget.value = selected.to_html(index=False, na_rep='')

In [13]:
# Code referenced from https://github.com/grafi-tt/chaineripy/blob/master/chaineripy/extensions/progress_bar.py by @grafi-tt
# Modified to work with pytorch_pfn_extras

from pytorch_pfn_extras.training import extension, trigger
import datetime
import time

from IPython.core.display import display
from ipywidgets import FloatProgress, HBox, HTML, VBox


class ProgressBarNotebook(extension.Extension):

    """Trainer extension to show a progress bar and recent training status.
    This extension updates ipywidgets progress bars. It watches the current
    iteration and epoch to update them.
    Args:
        training_length (tuple): Length of whole training. It consists of an
            integer and either ``'epoch'`` or ``'iteration'``. If this value is
            omitted and the stop trigger of the trainer is
            :class:`IntervalTrigger`, this extension uses its attributes to
            determine the length of the training.
        update_interval (int): Number of iterations to skip between updates
            of the progress bar.
        bar_length (int): Accepted for signature compatibility; not used by
            this widget-based implementation.
    """

    def __init__(self, training_length=None, update_interval=100,
                 bar_length=50):
        self._training_length = training_length
        if training_length is not None:
            self._init_status_template()
        self._update_interval = update_interval
        # Sliding window of (iteration, epoch_detail, timestamp) used for speed estimates.
        self._recent_timing = []

        self._total_bar = FloatProgress(description='total',
                                        min=0, max=1, value=0,
                                        bar_style='info')
        self._total_html = HTML()
        self._epoch_bar = FloatProgress(description='this epoch',
                                        min=0, max=1, value=0,
                                        bar_style='info')
        self._epoch_html = HTML()
        self._status_html = HTML()

        self._widget = VBox([HBox([self._total_bar, self._total_html]),
                             HBox([self._epoch_bar, self._epoch_html]),
                             self._status_html])

    def initialize(self, manager):
        """Resolve the training length if needed, draw the initial state and show the widget.

        Raises:
            TypeError: if no training length was given and the manager's stop
                trigger is not an ``IntervalTrigger``.
        """
        if self._training_length is None:
            t = manager._stop_trigger
            if not isinstance(t, trigger.IntervalTrigger):
                raise TypeError(
                    'cannot retrieve the training length from %s' % type(t))
            self._training_length = t.period, t.unit
            self._init_status_template()

        updater = manager.updater
        self.update(updater.iteration, updater.epoch_detail)
        display(self._widget)

    def __call__(self, manager):
        length, unit = self._training_length

        updater = manager.updater
        iteration, epoch_detail = updater.iteration, updater.epoch_detail

        if unit == 'iteration':
            is_finished = iteration == length
        else:
            is_finished = epoch_detail == length

        # Refresh every `update_interval` iterations, and always on the final one.
        if iteration % self._update_interval == 0 or is_finished:
            self.update(iteration, epoch_detail)

    def finalize(self, manager=None):
        """Mark the bars with the 'warning' style if training stopped before completion.

        ``manager`` is optional so the method works whether or not the caller
        passes the manager.
        """
        if self._total_bar.value != 1:
            self._total_bar.bar_style = 'warning'
            self._epoch_bar.bar_style = 'warning'

    @property
    def widget(self):
        return self._widget

    def update(self, iteration, epoch_detail):
        """Refresh both bars and the status line for the given training position."""
        length, unit = self._training_length

        recent_timing = self._recent_timing
        now = time.time()

        recent_timing.append((iteration, epoch_detail, now))

        if unit == 'iteration':
            rate = iteration / length
        else:
            rate = epoch_detail / length
        self._total_bar.value = rate
        self._total_html.value = "{:6.2%}".format(rate)

        # The fractional part of epoch_detail is the progress within the current epoch.
        epoch_rate = epoch_detail - int(epoch_detail)
        self._epoch_bar.value = epoch_rate
        self._epoch_html.value = "{:6.2%}".format(epoch_rate)

        status = self._status_template.format(iteration=iteration,
                                              epoch=int(epoch_detail))

        if rate == 1:
            self._total_bar.bar_style = 'success'
            self._epoch_bar.bar_style = 'success'

        # Speed is measured against the oldest sample kept in the window.
        old_t, old_e, old_sec = recent_timing[0]
        span = now - old_sec
        if span != 0:
            speed_t = (iteration - old_t) / span
            speed_e = (epoch_detail - old_e) / span
        else:
            speed_t = float('inf')
            speed_e = float('inf')

        if unit == 'iteration':
            remaining, speed = length - iteration, speed_t
        else:
            remaining, speed = length - epoch_detail, speed_e
        if speed != 0:
            estimated_time = datetime.timedelta(seconds=remaining / speed)
        else:
            # No progress since the oldest sample: dividing would raise
            # ZeroDivisionError, so report the estimate as unknown instead.
            estimated_time = 'unknown'
        estimate = ('{:10.5g} iters/sec. Estimated time to finish: {}.'
                    .format(speed_t, estimated_time))

        self._status_html.value = status + estimate

        # Keep the window bounded to the most recent samples.
        if len(recent_timing) > 100:
            del recent_timing[0]

    def _init_status_template(self):
        # `_training_length` is a (length, unit) tuple that fills the two %s slots.
        self._status_template = (
            '{iteration:10} iter, {epoch} epoch / %s %ss<br />' %
            self._training_length)

In [14]:
# --- Utils ---
import yaml


def save_yaml(filepath, content, width=120):
    """Dump ``content`` as YAML to ``filepath``, overwriting any existing file.

    Args:
        filepath: Destination path, opened in text write mode.
        content: Object handed to ``yaml.dump``.
        width: Line width passed through to ``yaml.dump``.
    """
    with open(filepath, 'w') as stream:
        yaml.dump(content, stream=stream, width=width)


def load_yaml(filepath):
    """Read ``filepath`` and return its content parsed with ``yaml.safe_load``."""
    with open(filepath, 'r') as stream:
        return yaml.safe_load(stream)


class DotDict(dict):
    """Dictionary whose items can also be read, written and deleted as attributes.

    Reading an attribute that is not a key returns ``None`` rather than raising,
    so a misspelled attribute name goes unnoticed. Deleting a missing attribute
    raises ``KeyError``.

    Refer: https://stackoverflow.com/questions/2352181/how-to-use-a-dot-to-access-members-of-dictionary/23689767#23689767
    """  # NOQA

    def __getattr__(self, key, default=None):
        # Only reached when normal attribute lookup fails; falls back to the dict items.
        return dict.get(self, key, default)

    def __setattr__(self, key, value):
        dict.__setitem__(self, key, value)

    def __delattr__(self, key):
        dict.__delitem__(self, key)

In [15]:
#Configs

In [16]:
# --- Lyft configs ---
# Configuration dict: model, rasterizer, data-loader and training parameters.
# It is passed to the rasterizer builder and to LyftMultiModel, and saved as cfg.yaml.
cfg = {
    'format_version': 4,
    'model_params': {
        'model_architecture': 'resnet34',
        # history_num_frames sizes the model's input channels
        'history_num_frames': 10,
        'history_step_size': 1,
        'history_delta_time': 0.1,
        # future_num_frames is the number of predicted positions per mode
        'future_num_frames': 50,
        'future_step_size': 1,
        'future_delta_time': 0.1
    },

    'raster_params': {
        'raster_size': [534,534], # alternatives: 224x224, 300x300, 350x350, 448x448; try 224x224 first
        'pixel_size': [0.25, 0.25],
        'ego_center': [0.25, 0.5],
        'map_type': 'py_semantic',
        'satellite_map_key': 'aerial_map/aerial_map.png',
        'semantic_map_key': 'semantic_map/semantic_map.pb',
        'dataset_meta_key': 'meta.json',
        'filter_agents_threshold': 0.5
    },

    'train_data_loader': {
        'key': 'scenes/train.zarr',
        'batch_size': 64, # alternatives: 16, 32; try 16 first
        'shuffle': True,
        'num_workers': 4
    },

    'valid_data_loader': {
        'key': 'scenes/validate.zarr',
        'batch_size': 64, # alternatives: 16, 32; try 16 first
        'shuffle': False,
        'num_workers': 4
    },

    'train_params': {
        'max_num_steps': 10000, # alternatives: 10000, 20000, 30000, 600k, 700k; try 1000 for a quick run
        'checkpoint_every_n_steps': 5000,

    }
}

In [17]:
# Run-level flags; wrapped in a DotDict below and saved to flags.yaml.
flags_dict = {
    "debug": False,
    # --- Data configs ---
    # NOTE(review): absolute, machine-specific path; change this for your environment.
    "l5kit_data_folder": "/Users/shuozhang/Downloads/lyft-motion-prediction-autonomous-vehicles/", # change this
    # --- Model configs ---
    "pred_mode": "multi",
    # --- Training configs ---
    "device": "cpu", # change this to 'cuda:0' when running on a GPU server
    "out_dir": "results/multi_train",
    "epoch": 2, 
    "snapshot_freq": 50,
}

In [18]:
# Wrap the flags for attribute access and create the output directory.
flags = DotDict(flags_dict)
out_dir = Path(flags.out_dir)
os.makedirs(str(out_dir), exist_ok=True)
print(f"flags: {flags_dict}")
# Persist the exact flags and config used for this run next to its results.
save_yaml(out_dir / 'flags.yaml', flags_dict)
save_yaml(out_dir / 'cfg.yaml', cfg)
# Was `flags.debu`: DotDict returns None for unknown attributes, so the typo
# silently set `debug` to None instead of the configured value.
debug = flags.debug

flags: {'debug': False, 'l5kit_data_folder': '/Users/shuozhang/Downloads/lyft-motion-prediction-autonomous-vehicles/', 'pred_mode': 'multi', 'device': 'cpu', 'out_dir': 'results/multi_train', 'epoch': 2, 'snapshot_freq': 50}


In [19]:
#Loading data

In [20]:
# # set env variable for data
# os.environ["L5KIT_DATA_FOLDER"] = flags.l5kit_data_folder
# dm = LocalDataManager(None)

# print("Load dataset...")
# train_cfg = cfg["train_data_loader"]

# # Rasterizer
# rasterizer = build_rasterizer(cfg, dm)

# # Train dataset/dataloader
# def transform(batch):
#     return batch["image"], batch["target_positions"], batch["target_availabilities"]


# train_path = "scenes/sample.zarr" if debug else train_cfg["key"]
# train_zarr = ChunkedDataset(dm.require(train_path)).open()
# print("train_zarr", type(train_zarr))
# train_agent_dataset = AgentDataset(cfg, train_zarr, rasterizer)
# train_dataset = TransformDataset(train_agent_dataset, transform)
# if debug:
#     # Only use 1000 dataset for fast check...
#     train_dataset = Subset(train_dataset, np.arange(1000))
# train_loader = DataLoader(train_dataset,
#                           shuffle=train_cfg["shuffle"],
#                           batch_size=train_cfg["batch_size"],
#                           num_workers=train_cfg["num_workers"])
# print(train_agent_dataset)
# print("# AgentDataset train:", len(train_agent_dataset))
# print("# ActualDataset train:", len(train_dataset))

Load dataset...
train_zarr <class 'l5kit.data.zarr_dataset.ChunkedDataset'>
+------------+------------+------------+---------------+-----------------+----------------------+----------------------+----------------------+---------------------+
| Num Scenes | Num Frames | Num Agents | Num TR lights | Total Time (hr) | Avg Frames per Scene | Avg Agents per Frame | Avg Scene Time (sec) | Avg Frame frequency |
+------------+------------+------------+---------------+-----------------+----------------------+----------------------+----------------------+---------------------+
|   16265    |  4039527   | 320124624  |    38735988   |      112.19     |        248.36        |        79.25         |        24.83         |        10.00        |
+------------+------------+------------+---------------+-----------------+----------------------+----------------------+----------------------+---------------------+
# AgentDataset train: 22496709
# ActualDataset train: 22496709


In [21]:
# from typing import Dict

# from tempfile import gettempdir
# import matplotlib.pyplot as plt
# import numpy as np
# import torch
# from torch import nn, optim
# from torch.utils.data import DataLoader
# from torchvision.models.resnet import resnet50
# from tqdm import tqdm

# from l5kit.configs import load_config_data
# from l5kit.data import LocalDataManager, ChunkedDataset
# from l5kit.dataset import AgentDataset, EgoDataset
# from l5kit.rasterization import build_rasterizer
# from l5kit.evaluation import write_pred_csv, compute_metrics_csv, read_gt_csv, create_chopped_dataset
# from l5kit.evaluation.chop_dataset import MIN_FUTURE_STEPS
# from l5kit.evaluation.metrics import neg_multi_log_likelihood, time_displace
# from l5kit.geometry import transform_points
# from l5kit.visualization import PREDICTED_POINTS_COLOR, TARGET_POINTS_COLOR, draw_trajectory
# from prettytable import PrettyTable
# from pathlib import Path

# import os

In [22]:
# # ===== GENERATE AND LOAD CHOPPED DATASET
# valid_cfg = cfg["valid_data_loader"]
# valid_path = "scenes/sample.zarr" if debug else valid_cfg["key"]
# num_frames_to_chop = 100
# valid_base_path = create_chopped_dataset(dm.require(valid_cfg["key"]), cfg["raster_params"]["filter_agents_threshold"], 
#                               num_frames_to_chop, cfg["model_params"]["future_num_frames"], MIN_FUTURE_STEPS)

copying: 100%|██████████| 16220/16220 [04:00<00:00, 67.30it/s]

you're running with a custom agents_mask

extracting GT: 100%|██████████| 94694/94694 [06:46<00:00, 233.01it/s]


In [23]:
# valid_zarr_path = str(Path(valid_base_path) / Path(dm.require(valid_cfg["key"])).name)
# valid_mask_path = str(Path(valid_base_path) / "mask.npz")
# valid_gt_path = str(Path(valid_base_path) / "gt.csv")

# valid_zarr = ChunkedDataset(valid_zarr_path).open()
# valid_mask = np.load(valid_mask_path)["arr_0"]
# # ===== INIT DATASET AND LOAD MASK
# valid_agent_dataset = AgentDataset(cfg, valid_zarr, rasterizer, agents_mask=valid_mask)
# valid_dataset = TransformDataset(valid_agent_dataset, transform)
# valid_loader = DataLoader(valid_dataset,
#                           shuffle=valid_cfg["shuffle"],
#                           batch_size=valid_cfg["batch_size"],
#                           num_workers=valid_cfg["num_workers"])

# print(valid_agent_dataset)
# print("# AgentDataset valid:", len(valid_agent_dataset))


you're running with a custom agents_mask



+------------+------------+------------+---------------+-----------------+----------------------+----------------------+----------------------+---------------------+
| Num Scenes | Num Frames | Num Agents | Num TR lights | Total Time (hr) | Avg Frames per Scene | Avg Agents per Frame | Avg Scene Time (sec) | Avg Frame frequency |
+------------+------------+------------+---------------+-----------------+----------------------+----------------------+----------------------+---------------------+
|   16220    |  1622000   | 125423254  |    11733321   |      45.06      |        100.00        |        77.33         |        10.00         |        10.00        |
+------------+------------+------------+---------------+-----------------+----------------------+----------------------+----------------------+---------------------+
# AgentDataset valid: 94694


In [30]:
# print("# ActualDataset valid:", len(valid_dataset))

# ActualDataset valid: 94694


In [None]:
#Prepare model & optimizer

In [19]:
# Resolve the target device from the flags.
device = torch.device(flags.device)

# Only the multi-mode predictor is implemented; reject anything else up front.
if flags.pred_mode != "multi":
    raise ValueError(f"[ERROR] Unexpected value flags.pred_mode={flags.pred_mode}")

predictor = LyftMultiModel(cfg)
model = LyftMultiRegressor(predictor)
model.to(device)

# Adam over every parameter of the wrapped model.
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [None]:
#Write training code
#pytorch-ignite & pytorch-pfn-extras are used here.

#pytorch/ignite: It provides abstraction for writing training loop.
#pfnet/pytorch-pfn-extras: It provides several "extensions" useful for training.
#Useful for logging, printing, evaluating, saving the model, scheduling the learning rate during training.

In [None]:
# Train setup
# trainer = create_trainer(model, optimizer, device)


# def eval_func(*batch):
#     loss, metrics = model(*[elem.to(device) for elem in batch])


# valid_evaluator = E.Evaluator(
#     valid_loader,
#     model,
#     progress_bar=True,
#     eval_func=eval_func,
# )

# log_trigger = (10 if debug else 1000, "iteration")
# log_report = E.LogReport(trigger=log_trigger)


# extensions = [
#     log_report,  # Save `log` to file
#     valid_evaluator,  # Run evaluation for valid dataset in each epoch.
#     # E.FailOnNonNumber()  # Stop training when nan is detected.
# ]

# is_notebook = False  # Make it False when you run code in local machine using console.
# if is_notebook:
#     extensions.extend([
#         ProgressBarNotebook(update_interval=10 if debug else 100),  # Show progress bar during training
#         PrintReportNotebook(),  # Show "log" on jupyter notebook  
#     ])
# else:
#     extensions.extend([
#         E.ProgressBar(update_interval=10 if debug else 100),  # Show progress bar during training
#         E.PrintReport(),  # Print "log" to terminal
#     ])


# epoch = flags.epoch

# models = {"main": model}
# optimizers = {"main": optimizer}
# manager = IgniteExtensionsManager(
#     trainer,
#     models,
#     optimizers,
#     epoch,
#     extensions=extensions,
#     out_dir=str(out_dir),
# )
# # Save predictor.pt every epoch
# manager.extend(E.snapshot_object(predictor, "predictor.pt"),
#                trigger=(flags.snapshot_freq, "iteration"))
# # Check & Save best validation predictor.pt every epoch
# # manager.extend(E.snapshot_object(predictor, "best_predictor.pt"),
# #                trigger=MinValueTrigger("validation/main/nll", trigger=(flags.snapshot_freq, "iteration")))
# # --- lr scheduler ---
# # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
# #     optimizer, mode='min', factor=0.7, patience=5, min_lr=1e-10)
# scheduler = torch.optim.lr_scheduler.ExponentialLR(
#     optimizer, gamma=0.99999)
# manager.extend(lambda manager: scheduler.step(), trigger=(1, "iteration"))
# # Show "lr" column in log
# manager.extend(E.observe_lr(optimizer=optimizer), trigger=log_trigger)

# trainer.run(train_loader, max_epochs=epoch)

In [None]:
# df = log_report.to_dataframe()
# df.to_csv("log.csv", index=False)
# df[["epoch", "iteration", "main/loss", "main/nll", "validation/main/loss", "validation/main/nll", "lr", "elapsed_time"]]

In [None]:
# Extensions - the role of each:

# ProgressBar (ProgressBarNotebook): Shows training progress in a formatted style.
# LogReport: Logs the metrics reported through ppe.reporting.report
# (see LyftMultiRegressor for the reporting point) and saves them to a log file.
# It automatically collects the values reported in each iteration and saves their mean
# at a regular frequency (for example, every epoch).
# PrintReport (PrintReportNotebook): Prints the values that LogReport collected in a formatted style.
# Evaluator: Evaluates on the validation dataset.
# snapshot_object: Saves the object. Here the model is saved at the regular interval flags.snapshot_freq.
# Even if you quit training with Ctrl+C before all epochs finish,
# the intermediate trained model is saved and you can use it for inference.
# lambda function with scheduler.step(): You can invoke any function at a regular interval specified by a trigger.
# Here exponential decay of the learning rate is applied by calling scheduler.step() every iteration.
# observe_lr: LogReport records the optimizer's learning rate through this extension,
# so you can follow how the learning rate changes during training.

In [None]:
# debug result: 
# iteration	main/loss	main/nll	lr	    validation/main/loss	validation/main/nll	elapsed_time
# 10	   2566.690948	2566.690948	0.001000			                               10.377875
# 20	   1540.009851	1540.009851	0.001000			                               17.131299
# 30	   970.311111	970.311111	0.001000			                               25.948778
# 40	   1260.554537	1260.554537	0.001000			                               32.972381
# 50	   1403.084926	1403.084926	0.001000			                               41.940716
# 60	   1522.389905	1522.389905	0.000999			                               49.204889
# 70	   1106.671555	1106.671555	0.000999			                               57.684869
# 80	   856.877580	856.877580	0.000999			                               64.952417
# 90	   957.383397	957.383397	0.000999	4912.747314	      4912.747314	       79.925156
# 100	   911.428769	911.428769	0.000999			                               86.664308
# 110	   441.032510	441.032510	0.000999			                               95.298566
# 120	   332.250836	332.250836	0.000999			                              102.593534
# 130	   814.817209	814.817209	0.000999			                              111.937015
# 140	   1075.804512	1075.804512	0.000999			                              118.627421
# 150	   1094.629809	1094.629809	0.000999			                              127.598126
# 160	  627.256924	627.256924	0.000998			                              134.250497

In [21]:
# NOTE(review): presumably (train samples / batch size) * epochs, i.e. total training iterations; confirm.
22496709/64*2

703022.15625

In [21]:
def run_prediction(predictor, data_loader):
    """Run inference over ``data_loader`` and collect per-sample predictions.

    Each batch's image tensor is sent to the notebook-level ``device``, passed
    through ``predictor``, and every predicted trajectory is mapped with
    ``transform_points`` using that sample's ``world_from_agent`` matrix; the
    first two components of the sample's ``centroid`` are then subtracted, so
    the stored coordinates are offsets relative to the centroid.

    Relies on the globals ``device``, ``transform_points`` and ``tqdm`` being
    defined before this function is called.

    Args:
        predictor: Model called as ``predictor(image)`` that returns
            ``(preds, confidences)``; ``preds`` is indexed as
            ``[sample, mode, :, :]``. It is switched to eval mode here.
        data_loader: Iterable of batch dicts providing the keys ``"image"``,
            ``"world_from_agent"``, ``"centroid"``, ``"timestamp"`` and
            ``"track_id"``.

    Returns:
        Tuple ``(timestamps, track_ids, coords, confs)`` of numpy arrays, each
        the concatenation of the per-batch arrays along the first axis.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    predictor.eval()

    pred_coords_list = []
    confidences_list = []
    timestamps_list = []
    track_id_list = []

    with torch.no_grad():
        for data in tqdm(data_loader):
            image = data["image"].to(device)
            preds, confidences = predictor(image)

            # Work on a private numpy copy so the in-place writes below never
            # touch the tensor returned by the model.
            preds = preds.cpu().numpy().copy()
            world_from_agents = data["world_from_agent"].numpy()
            centroids = data["centroid"].numpy()

            # Convert every mode of every sample with its world_from_agent
            # matrix and store the result as an offset from the centroid.
            # The mode count comes from the prediction itself instead of
            # being hard-coded to 3.
            num_modes = preds.shape[1]
            for idx in range(len(preds)):
                for mode in range(num_modes):
                    preds[idx, mode, :, :] = (
                        transform_points(preds[idx, mode, :, :], world_from_agents[idx])
                        - centroids[idx][:2]
                    )

            # Copy the loader-provided arrays so the batch tensors can be
            # released once the iteration moves on.
            confidences_list.append(confidences.cpu().numpy().copy())
            timestamps_list.append(data["timestamp"].numpy().copy())
            track_id_list.append(data["track_id"].numpy().copy())
            # `preds` is already a fresh array for this batch; no second copy needed.
            pred_coords_list.append(preds)

    if not pred_coords_list:
        # np.concatenate would raise a less helpful ValueError on an empty list.
        raise ValueError("data_loader yielded no batches; nothing to predict")

    timestamps = np.concatenate(timestamps_list)
    track_ids = np.concatenate(track_id_list)
    coords = np.concatenate(pred_coords_list)
    confs = np.concatenate(confidences_list)
    return timestamps, track_ids, coords, confs

In [22]:
# Set the env variable that tells l5kit where the data lives
# (LocalDataManager(None) presumably resolves dataset keys against it — confirm).
# NOTE(review): absolute, machine-specific path; consider moving it to a config cell.
l5kit_data_folder = "/Users/shuozhang/Downloads/lyft-motion-prediction-autonomous-vehicles"
os.environ["L5KIT_DATA_FOLDER"] = l5kit_data_folder
dm = LocalDataManager(None)

print("Load dataset...")
# Loader settings used only when cfg has no "test_data_loader" entry.
default_test_cfg = {
    'key': 'scenes/test.zarr',
    'batch_size': 64,
    'shuffle': False,
    'num_workers': 4
}
test_cfg = cfg.get("test_data_loader", default_test_cfg)

# Rasterizer
rasterizer = build_rasterizer(cfg, dm)

test_path = test_cfg["key"]
print(f"Loading from {test_path}")
test_zarr = ChunkedDataset(dm.require(test_path)).open()
print("test_zarr", type(test_zarr))
# Read the agents mask inside a context manager so the .npz archive is
# closed as soon as the array has been extracted.
with np.load(f"{l5kit_data_folder}/scenes/mask.npz") as mask_archive:
    test_mask = mask_archive["arr_0"]
test_agent_dataset = AgentDataset(cfg, test_zarr, rasterizer, agents_mask=test_mask)
test_dataset = test_agent_dataset
if debug:
    # Only use the first 100 samples (or fewer, if the dataset is smaller)
    # for a fast check.
    test_dataset = Subset(test_dataset, np.arange(min(100, len(test_dataset))))
test_loader = DataLoader(
    test_dataset,
    shuffle=test_cfg["shuffle"],
    batch_size=test_cfg["batch_size"],
    num_workers=test_cfg["num_workers"],
    pin_memory=True,
)

print(test_agent_dataset)
print("# AgentDataset test:", len(test_agent_dataset))
print("# ActualDataset test:", len(test_dataset))

Load dataset...
Loading from scenes/test.zarr
test_zarr <class 'l5kit.data.zarr_dataset.ChunkedDataset'>
+------------+------------+------------+---------------+-----------------+----------------------+----------------------+----------------------+---------------------+
| Num Scenes | Num Frames | Num Agents | Num TR lights | Total Time (hr) | Avg Frames per Scene | Avg Agents per Frame | Avg Scene Time (sec) | Avg Frame frequency |
+------------+------------+------------+---------------+-----------------+----------------------+----------------------+----------------------+---------------------+
|   11314    |  1131400   |  88594921  |    7854144    |      31.43      |        100.00        |        78.31         |        10.00         |        10.00        |
+------------+------------+------------+---------------+-----------------+----------------------+----------------------+----------------------+---------------------+
# AgentDataset test: 71122
# ActualDataset test: 71122



you're running with a custom agents_mask



In [31]:
# Re-run

In [23]:
# Device requested in the run flags; run_prediction sends image batches here.
device = torch.device(flags.device)

# Guard clause: only the multi-mode predictor can be built in this notebook.
if flags.pred_mode != "multi":
    raise ValueError(f"[ERROR] Unexpected value flags.pred_mode={flags.pred_mode}")
predictor = LyftMultiModel(cfg)

# Checkpoint to evaluate; edit this path to score a different snapshot.
pt_path = "/Users/shuozhang/Downloads/predictor_resent34_534_batch64_31800.pt"
print(f"Loading from {pt_path}")
# Deserialize the weights onto the CPU first, then move the model to `device`.
predictor.load_state_dict(
    torch.load(pt_path, map_location=torch.device('cpu'))
)
predictor.to(device)

Loading from /Users/shuozhang/Downloads/predictor_resent34_534_batch64_31800.pt


LyftMultiModel(
  (backbone): ResNet(
    (conv1): Conv2d(25, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, 

In [24]:
# Run inference over the whole test loader with the currently loaded predictor.
# The recorded run of this cell took about 11 h 16 min for 1112 batches.
timestamps, track_ids, coords, confs = run_prediction(predictor, test_loader)


disable_traffic_light_faces not found in config, this will raise an error in the future


disable_traffic_light_faces not found in config, this will raise an error in the future


disable_traffic_light_faces not found in config, this will raise an error in the future


disable_traffic_light_faces not found in config, this will raise an error in the future

100%|██████████| 1112/1112 [11:16:05<00:00, 36.48s/it]  


In [25]:
# Write the arrays from the latest run_prediction call to a CSV file.
csv_path = "model8_2_64.csv"
write_pred_csv(csv_path, timestamps=timestamps, track_ids=track_ids, coords=coords, confs=confs)
print(f"Saved to {csv_path}")

Saved to model8_2_64.csv


In [40]:
# NOTE(review): this cell writes the same timestamps/track_ids/coords/confs as the
# preceding cell under a second filename; no run_prediction call sits between the
# two cells. Its recorded output ("Saved to model8_2.csv") and execution count do
# not match this code, so the cell appears to have been edited after it last ran.
# Confirm which predictions this file is meant to hold before re-running.
csv_path = "model8_2_32.csv"
write_pred_csv(
    csv_path,
    timestamps=timestamps,
    track_ids=track_ids,
    coords=coords,
    confs=confs)
print(f"Saved to {csv_path}" ) # 67.916 — presumably the score obtained with this file; TODO confirm

Saved to model8_2.csv


In [26]:
# Device requested in the run flags; run_prediction sends image batches here.
device = torch.device(flags.device)

# Guard clause: only the multi-mode predictor can be built in this notebook.
if flags.pred_mode != "multi":
    raise ValueError(f"[ERROR] Unexpected value flags.pred_mode={flags.pred_mode}")
predictor = LyftMultiModel(cfg)

# Checkpoint to evaluate; edit this path to score a different snapshot.
pt_path = "/Users/shuozhang/Downloads/predictor_resent34_534_batch64_66000.pt"
print(f"Loading from {pt_path}")
# Deserialize the weights onto the CPU first, then move the model to `device`.
predictor.load_state_dict(
    torch.load(pt_path, map_location=torch.device('cpu'))
)
predictor.to(device)

Loading from /Users/shuozhang/Downloads/predictor_resent34_534_batch64_66000.pt


LyftMultiModel(
  (backbone): ResNet(
    (conv1): Conv2d(25, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, 

In [27]:
# Run inference over the whole test loader with the currently loaded predictor.
# The recorded run of this cell took about 11 h 37 min for 1112 batches.
timestamps, track_ids, coords, confs = run_prediction(predictor, test_loader)


disable_traffic_light_faces not found in config, this will raise an error in the future


disable_traffic_light_faces not found in config, this will raise an error in the future


disable_traffic_light_faces not found in config, this will raise an error in the future


disable_traffic_light_faces not found in config, this will raise an error in the future

100%|██████████| 1112/1112 [11:36:45<00:00, 37.60s/it]  


In [28]:
# Write the arrays from the latest run_prediction call to a CSV file.
csv_path = "model8_4.csv"
write_pred_csv(csv_path, timestamps=timestamps, track_ids=track_ids, coords=coords, confs=confs)
print(f"Saved to {csv_path}")

Saved to model8_4.csv


In [29]:
# Device requested in the run flags; run_prediction sends image batches here.
device = torch.device(flags.device)

# Guard clause: only the multi-mode predictor can be built in this notebook.
if flags.pred_mode != "multi":
    raise ValueError(f"[ERROR] Unexpected value flags.pred_mode={flags.pred_mode}")
predictor = LyftMultiModel(cfg)

# Checkpoint to evaluate; edit this path to score a different snapshot.
pt_path = "/Users/shuozhang/Downloads/predictor_resent34_534_batch64_97000.pt"
print(f"Loading from {pt_path}")
# Deserialize the weights onto the CPU first, then move the model to `device`.
predictor.load_state_dict(
    torch.load(pt_path, map_location=torch.device('cpu'))
)
predictor.to(device)

Loading from /Users/shuozhang/Downloads/predictor_resent34_534_batch64_97000.pt


LyftMultiModel(
  (backbone): ResNet(
    (conv1): Conv2d(25, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, 

In [30]:
# Run inference over the whole test loader with the currently loaded predictor.
# The recorded run of this cell took about 12 h 10 min for 1112 batches.
timestamps, track_ids, coords, confs = run_prediction(predictor, test_loader)


disable_traffic_light_faces not found in config, this will raise an error in the future


disable_traffic_light_faces not found in config, this will raise an error in the future


disable_traffic_light_faces not found in config, this will raise an error in the future


disable_traffic_light_faces not found in config, this will raise an error in the future

100%|██████████| 1112/1112 [12:10:03<00:00, 39.39s/it]  


In [31]:
# Write the arrays from the latest run_prediction call to a CSV file.
csv_path = "model8_6.csv"
write_pred_csv(csv_path, timestamps=timestamps, track_ids=track_ids, coords=coords, confs=confs)
print(f"Saved to {csv_path}")

Saved to model8_6.csv


In [32]:
# Device requested in the run flags; run_prediction sends image batches here.
device = torch.device(flags.device)

# Guard clause: only the multi-mode predictor can be built in this notebook.
if flags.pred_mode != "multi":
    raise ValueError(f"[ERROR] Unexpected value flags.pred_mode={flags.pred_mode}")
predictor = LyftMultiModel(cfg)

# Checkpoint to evaluate; edit this path to score a different snapshot.
pt_path = "/Users/shuozhang/Downloads/predictor_resent34_534_batch64_112000.pt"
print(f"Loading from {pt_path}")
# Deserialize the weights onto the CPU first, then move the model to `device`.
predictor.load_state_dict(
    torch.load(pt_path, map_location=torch.device('cpu'))
)
predictor.to(device)

Loading from /Users/shuozhang/Downloads/predictor_resent34_534_batch64_112000.pt


LyftMultiModel(
  (backbone): ResNet(
    (conv1): Conv2d(25, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, 

In [33]:
# Run inference over the whole test loader with the currently loaded predictor.
# The recorded run of this cell took about 11 h 16 min for 1112 batches.
timestamps, track_ids, coords, confs = run_prediction(predictor, test_loader)


disable_traffic_light_faces not found in config, this will raise an error in the future


disable_traffic_light_faces not found in config, this will raise an error in the future


disable_traffic_light_faces not found in config, this will raise an error in the future


disable_traffic_light_faces not found in config, this will raise an error in the future

100%|██████████| 1112/1112 [11:15:48<00:00, 36.46s/it]  


In [34]:
# Write the arrays from the latest run_prediction call to a CSV file.
csv_path = "model8_7.csv"
write_pred_csv(csv_path, timestamps=timestamps, track_ids=track_ids, coords=coords, confs=confs)
print(f"Saved to {csv_path}")

Saved to model8_7.csv


In [35]:
# Device requested in the run flags; run_prediction sends image batches here.
device = torch.device(flags.device)

# Guard clause: only the multi-mode predictor can be built in this notebook.
if flags.pred_mode != "multi":
    raise ValueError(f"[ERROR] Unexpected value flags.pred_mode={flags.pred_mode}")
predictor = LyftMultiModel(cfg)

# Checkpoint to evaluate; edit this path to score a different snapshot.
pt_path = "/Users/shuozhang/Downloads/predictor_resent34_534_batch64_215000.pt"
print(f"Loading from {pt_path}")
# Deserialize the weights onto the CPU first, then move the model to `device`.
predictor.load_state_dict(
    torch.load(pt_path, map_location=torch.device('cpu'))
)
predictor.to(device)

Loading from /Users/shuozhang/Downloads/predictor_resent34_534_batch64_215000.pt


LyftMultiModel(
  (backbone): ResNet(
    (conv1): Conv2d(25, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, 

In [36]:
# Run inference over the whole test loader with the currently loaded predictor.
# The recorded run of this cell took about 11 h 42 min for 1112 batches.
timestamps, track_ids, coords, confs = run_prediction(predictor, test_loader)


disable_traffic_light_faces not found in config, this will raise an error in the future


disable_traffic_light_faces not found in config, this will raise an error in the future


disable_traffic_light_faces not found in config, this will raise an error in the future


disable_traffic_light_faces not found in config, this will raise an error in the future

100%|██████████| 1112/1112 [11:41:59<00:00, 37.88s/it]  


In [37]:
# Write the arrays from the latest run_prediction call to a CSV file.
csv_path = "model8_14.csv"
write_pred_csv(csv_path, timestamps=timestamps, track_ids=track_ids, coords=coords, confs=confs)
print(f"Saved to {csv_path}")

Saved to model8_14.csv


In [38]:
# Device requested in the run flags; run_prediction sends image batches here.
device = torch.device(flags.device)

# Guard clause: only the multi-mode predictor can be built in this notebook.
if flags.pred_mode != "multi":
    raise ValueError(f"[ERROR] Unexpected value flags.pred_mode={flags.pred_mode}")
predictor = LyftMultiModel(cfg)

# Checkpoint to evaluate; edit this path to score a different snapshot.
pt_path = "/Users/shuozhang/Downloads/predictor_resent34_534_batch64_239000.pt"
print(f"Loading from {pt_path}")
# Deserialize the weights onto the CPU first, then move the model to `device`.
predictor.load_state_dict(
    torch.load(pt_path, map_location=torch.device('cpu'))
)
predictor.to(device)

Loading from /Users/shuozhang/Downloads/predictor_resent34_534_batch64_239000.pt


LyftMultiModel(
  (backbone): ResNet(
    (conv1): Conv2d(25, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, 

In [39]:
# Run inference over the whole test loader with the currently loaded predictor.
# The recorded run of this cell took about 12 h 20 min for 1112 batches.
timestamps, track_ids, coords, confs = run_prediction(predictor, test_loader)


disable_traffic_light_faces not found in config, this will raise an error in the future


disable_traffic_light_faces not found in config, this will raise an error in the future


disable_traffic_light_faces not found in config, this will raise an error in the future


disable_traffic_light_faces not found in config, this will raise an error in the future

100%|██████████| 1112/1112 [12:19:51<00:00, 39.92s/it]  


In [40]:
# Write the arrays from the latest run_prediction call to a CSV file.
csv_path = "model8_16.csv"
write_pred_csv(csv_path, timestamps=timestamps, track_ids=track_ids, coords=coords, confs=confs)
print(f"Saved to {csv_path}")

Saved to model8_16.csv


In [41]:
# Device requested in the run flags; run_prediction sends image batches here.
device = torch.device(flags.device)

# Guard clause: only the multi-mode predictor can be built in this notebook.
if flags.pred_mode != "multi":
    raise ValueError(f"[ERROR] Unexpected value flags.pred_mode={flags.pred_mode}")
predictor = LyftMultiModel(cfg)

# Checkpoint to evaluate; edit this path to score a different snapshot.
pt_path = "/Users/shuozhang/Downloads/predictor_resent34_534_batch64_266000.pt"
print(f"Loading from {pt_path}")
# Deserialize the weights onto the CPU first, then move the model to `device`.
predictor.load_state_dict(
    torch.load(pt_path, map_location=torch.device('cpu'))
)
predictor.to(device)

Loading from /Users/shuozhang/Downloads/predictor_resent34_534_batch64_266000.pt


LyftMultiModel(
  (backbone): ResNet(
    (conv1): Conv2d(25, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, 

In [42]:
# Run inference over the whole test loader with the currently loaded predictor.
# The recorded run of this cell took about 11 h 46 min for 1112 batches.
timestamps, track_ids, coords, confs = run_prediction(predictor, test_loader)


disable_traffic_light_faces not found in config, this will raise an error in the future


disable_traffic_light_faces not found in config, this will raise an error in the future


disable_traffic_light_faces not found in config, this will raise an error in the future


disable_traffic_light_faces not found in config, this will raise an error in the future

100%|██████████| 1112/1112 [11:46:21<00:00, 38.11s/it]  


In [43]:
# Write the arrays from the latest run_prediction call to a CSV file.
csv_path = "model8_19.csv"
write_pred_csv(csv_path, timestamps=timestamps, track_ids=track_ids, coords=coords, confs=confs)
print(f"Saved to {csv_path}")

Saved to model8_19.csv


In [44]:
# Device requested in the run flags; run_prediction sends image batches here.
device = torch.device(flags.device)

# Guard clause: only the multi-mode predictor can be built in this notebook.
if flags.pred_mode != "multi":
    raise ValueError(f"[ERROR] Unexpected value flags.pred_mode={flags.pred_mode}")
predictor = LyftMultiModel(cfg)

# Checkpoint to evaluate; edit this path to score a different snapshot.
pt_path = "/Users/shuozhang/Downloads/predictor_resent34_534_batch64_281000.pt"
print(f"Loading from {pt_path}")
# Deserialize the weights onto the CPU first, then move the model to `device`.
predictor.load_state_dict(
    torch.load(pt_path, map_location=torch.device('cpu'))
)
predictor.to(device)

Loading from /Users/shuozhang/Downloads/predictor_resent34_534_batch64_281000.pt


LyftMultiModel(
  (backbone): ResNet(
    (conv1): Conv2d(25, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, 

In [45]:
# Run inference over the whole test loader with the currently loaded predictor.
# The recorded run of this cell took about 15 h 26 min for 1112 batches.
timestamps, track_ids, coords, confs = run_prediction(predictor, test_loader)


disable_traffic_light_faces not found in config, this will raise an error in the future


disable_traffic_light_faces not found in config, this will raise an error in the future


disable_traffic_light_faces not found in config, this will raise an error in the future


disable_traffic_light_faces not found in config, this will raise an error in the future

100%|██████████| 1112/1112 [15:26:04<00:00, 49.97s/it]   


In [46]:
# Write the arrays from the latest run_prediction call to a CSV file.
csv_path = "model8_20.csv"
write_pred_csv(csv_path, timestamps=timestamps, track_ids=track_ids, coords=coords, confs=confs)
print(f"Saved to {csv_path}")

Saved to model8_20.csv


In [None]:
# iteration   main/loss   main/nll    lr          elapsed_time
# 1000        246.608     246.608     0.00099006  2565.82       
# 2000        134.495     134.495     0.000980208  5040.86       
# 3000        124.541     124.541     0.000970455  7578.24       
# 4000        113.603     113.603     0.000960799  10112.7       
# 5000        102.806     102.806     0.000951239  12566         
# 6000        80.8249     80.8249     0.000941774  15043.2       
# 7000        69.7267     69.7267     0.000932403  17604.6       
# 8000        65.1087     65.1087     0.000923125  20250.5       
# 9000        56.577      56.577      0.00091394  23009.6       
# 10000       48.5363     48.5363     0.000904846  25476.3       
# 11000       46.8095     46.8095     0.000895843  27943.2       
# 12000       44.5082     44.5082     0.000886929  30435.1       
# 13000       40.8986     40.8986     0.000878104  32903.5       
# 14000       39.1917     39.1917     0.000869366  35379         
# 15000       38.335      38.335      0.000860716  37877.8       
# 16000       35.9245     35.9245     0.000852152  40568.5       
# 17000       34.0565     34.0565     0.000843673  43390.8       
# 18000       32.0918     32.0918     0.000835278  46231.3       
# 19000       30.5138     30.5138     0.000826967  49064.1       
# 20000       28.8933     28.8933     0.000818738  51734.6       
# 21000       29.4042     29.4042     0.000810592  54252.5       
# 22000       28.8084     28.8084     0.000802526  56748.2       
# 23000       26.66       26.66       0.000794541  59256.8       
# 24000       26.7328     26.7328     0.000786635  61746.2       
# 25000       26.9661     26.9661     0.000778808  64237.5       
# 26000       24.2706     24.2706     0.000771058  66725.8       
# 27000       23.7597     23.7597     0.000763386  69233.3       
# 28000       23.5241     23.5241     0.00075579  71717.9       
# 29000       23.5805     23.5805     0.00074827  74203.5       
# 30000       23.0245     23.0245     0.000740825  76675.5       
# 31000       22.6089     22.6089     0.000733453  79151.4       
# 32000       21.9162     21.9162     0.000726155  81638.5       
# 33000       22.5695     22.5695     0.00071893  84129.7       
# 34000       21.5089     21.5089     0.000711776  86610.3       
# 35000       20.805      20.805      0.000704694  89115.4       
# 36000       20.3881     20.3881     0.000697682  91619.5       
# 37000       20.6398     20.6398     0.00069074  94139.5       
# 38000       21.3689     21.3689     0.000683867  96663.5       
# 39000       20.374      20.374      0.000677062  99472.1       
# 40000       19.7404     19.7404     0.000670325  102378        
# 41000       19.2853     19.2853     0.000663656  105261        
# 42000       19.8043     19.8043     0.000657052  108158        
# 43000       19.3634     19.3634     0.000650514  111026        
# 44000       19.3609     19.3609     0.000644041  113891        
# 45000       19.0078     19.0078     0.000637633  116727        
# 46000       19.6359     19.6359     0.000631289  119591        
# 47000       18.2572     18.2572     0.000625007  122515        
# 48000       17.9629     17.9629     0.000618788  125078        
# 49000       18.8436     18.8436     0.000612631  127557        
# 50000       17.9263     17.9263     0.000606535  129969        
# 51000       17.6406     17.6406     0.0006005   132392        
# 52000       17.7651     17.7651     0.000594525  134789        
# 53000       17.915      17.915      0.000588609  137191        
# 54000       17.3214     17.3214     0.000582753  139649        
# 55000       16.7883     16.7883     0.000576954  142078        
# 56000       17.0955     17.0955     0.000571213  144482        
# 57000       16.9098     16.9098     0.000565529  146903        
# 58000       16.9515     16.9515     0.000559902  149316        
# 59000       15.9669     15.9669     0.000554331  151750        
# 60000       16.5633     16.5633     0.000548815  154175        
# 61000       15.828      15.828      0.000543355  156617        
# 62000       16.3236     16.3236     0.000537948  159033        
# 63000       16.8483     16.8483     0.000532595  161431        
# 64000       16.213      16.213      0.000527296  163837 
# 65000       15.9555     15.9555     0.000522049  166304        
# 66000       16.2599     16.2599     0.000516855  168717        
# 67000       15.2751     15.2751     0.000511712  171145        
# 68000       16.1139     16.1139     0.00050662  173633        
# 69000       16.1092     16.1092     0.000501579  176120        
# 70000       15.6411     15.6411     0.000496589  178625        
# 71000       15.8849     15.8849     0.000491647  181105        
# 72000       15.5516     15.5516     0.000486755  183523        
# 73000       15.0486     15.0486     0.000481912  186146        
# 74000       15.3272     15.3272     0.000477117  189065        
# 75000       14.8535     14.8535     0.00047237  191473        
# 76000       15.2367     15.2367     0.000467669  193909        
# 77000       15.3614     15.3614     0.000463016  196310        
# 78000       15.3999     15.3999     0.000458409  198705        
# 79000       15.2488     15.2488     0.000453848  201122        
# 80000       14.3557     14.3557     0.000449332  203519        
# 81000       14.6494     14.6494     0.000444861  206163        
# 82000       15.128      15.128      0.000440434  209075        
# 83000       15.0283     15.0283     0.000436052  211951        
# 84000       14.2413     14.2413     0.000431713  214847        
# 85000       14.6462     14.6462     0.000427417  217497        
# 86000       14.2119     14.2119     0.000423164  219899      
# 87000       14.5498     14.5498     0.000418954  222433        
# 88000       14.0853     14.0853     0.000414785  225279        
# 89000       13.7793     13.7793     0.000410658  228225        
# 90000       14.3565     14.3565     0.000406572  231161        
# 91000       14.5311     14.5311     0.000402526  234113        
# 92000       13.9809     13.9809     0.000398521  237093        
# 93000       14.332      14.332      0.000394556  240080        
# 94000       14.0843     14.0843     0.00039063  243038        
# 95000       13.8302     13.8302     0.000386743  246000        
# 96000       13.8213     13.8213     0.000382895  248955        
# 97000       13.5921     13.5921     0.000379085  251906        
# 98000       14.7321     14.7321     0.000375313  254925        
# 99000       13.7742     13.7742     0.000371579  257925        
# 100000      13.8036     13.8036     0.000367881  260884        
# 101000      13.9826     13.9826     0.000364221  263935        
# 102000      13.2042     13.2042     0.000360597  266938        
# 103000      13.0243     13.0243     0.000357009  269999        
# 104000      13.6645     13.6645     0.000353456  272968        
# 105000      12.9524     12.9524     0.000349939  276035        
# 106000      13.3976     13.3976     0.000346457  279091        
# 107000      12.7807     12.7807     0.00034301  282152        
# 108000      13.8308     13.8308     0.000339597  285194        
# 109000      12.8637     12.8637     0.000336218  288160        
# 110000      12.9654     12.9654     0.000332873  291140        
# 111000      12.86       12.86       0.00032956  294150        
# 112000      12.9481     12.9481     0.000326281  297231  
# 113000      12.5697     12.5697     0.000323035  300266        
# 114000      12.5993     12.5993     0.00031982  303365        
# 115000      12.6543     12.6543     0.000316638  306453        
# 116000      12.9969     12.9969     0.000313487  309586        
# 117000      12.8993     12.8993     0.000310368  312728        
# 118000      12.9749     12.9749     0.00030728  315897        
# 119000      12.8778     12.8778     0.000304222  319075        
# 120000      12.413      12.413      0.000301195  322196        
# 121000      12.7178     12.7178     0.000298198  325318        
# 122000      12.2412     12.2412     0.000295231  328456        
# 123000      12.3016     12.3016     0.000292294  331568        
# 124000      12.7699     12.7699     0.000289385  334672        
# 125000      12.1416     12.1416     0.000286506  337775        
# 126000      12.5922     12.5922     0.000283655  340911       
# 127000      12.6997     12.6997     0.000280833  344026        
# 128000      13.0767     13.0767     0.000278038  347111        
# 129000      12.0074     12.0074     0.000275272  350268        
# 130000      12.1284     12.1284     0.000272533  353354        
# 131000      12.4354     12.4354     0.000269821  356466        
# 132000      11.7967     11.7967     0.000267136  359618        
# 133000      12.3922     12.3922     0.000264478  362748        
# 134000      12.0027     12.0027     0.000261847  365844        
# 135000      12.2929     12.2929     0.000259241  369025   
# 136000      12.0449     12.0449     0.000256662  372200        
# 137000      12.106      12.106      0.000254108  375319        
# 138000      12.1313     12.1313     0.000251579  378445        
# 139000      11.7585     11.7585     0.000249076  381564        
# 140000      12.0848     12.0848     0.000246598  384732        
# 141000      12.3589     12.3589     0.000244144  387884        
# 142000      11.5295     11.5295     0.000241715  390992        
# 143000      11.7807     11.7807     0.00023931  394117        
# 144000      11.821      11.821      0.000236928  397230        
# 145000      11.8141     11.8141     0.000234571  400386        
# 146000      11.8843     11.8843     0.000232237  403551        
# 147000      11.7224     11.7224     0.000229926  406694        
# 148000      11.5571     11.5571     0.000227638  409881        
# 149000      12.103      12.103      0.000225373  413025        
# 150000      11.5525     11.5525     0.000223131  416148        
# 151000      11.5473     11.5473     0.000220911  419306        
# 152000      11.2453     11.2453     0.000218712  422418        
# 153000      10.951      10.951      0.000216536  425521        
# 154000      11.3978     11.3978     0.000214382  428639        
# 155000      11.4013     11.4013     0.000212248  431944  
# 156000      11.4369     11.4369     0.000210137  435293        
# 157000      11.6738     11.6738     0.000208046  438421        
# 158000      12.3936     12.3936     0.000205976  441590        
# 159000      11.7866     11.7866     0.000203926  444743        
# 160000      11.6324     11.6324     0.000201897  447928        
# 161000      11.1356     11.1356     0.000199888  451072        
# 162000      11.2384     11.2384     0.000197899  454257        
# 163000      11.1549     11.1549     0.00019593  457411        
# 164000      11.587      11.587      0.00019398  460540        
# 165000      11.021      11.021      0.00019205  463740        
# 166000      10.8986     10.8986     0.000190139  466931        
# 167000      11.691      11.691      0.000188247  470034        
# 168000      10.8743     10.8743     0.000186374  473184        
# 169000      11.1495     11.1495     0.00018452  476352        
# 170000      10.7639     10.7639     0.000182684  479473        
# 171000      11.2971     11.2971     0.000180866  482620        
# 172000      11.175      11.175      0.000179066  485741        
# 173000      10.9317     10.9317     0.000177285  488927        
# 174000      10.9018     10.9018     0.000175521  492094        
# 175000      10.9454     10.9454     0.000173774  495268        
# 176000      11.4809     11.4809     0.000172045  498419        
# 177000      10.9679     10.9679     0.000170333  501542        
# 178000      11.1397     11.1397     0.000168638  504716        
# 179000      10.7722     10.7722     0.00016696  507852        
# 180000      11.1619     11.1619     0.000165299  510970        
# 181000      11.1312     11.1312     0.000163654  514120        
# 182000      11.2299     11.2299     0.000162026  517266        
# 183000      10.8833     10.8833     0.000160414  520421        
# 184000      10.7197     10.7197     0.000158818  523603        
# 185000      11.1702     11.1702     0.000157237  526742        
# 186000      10.7632     10.7632     0.000155673  529886        
# 187000      10.6083     10.6083     0.000154124  533034        
# 188000      10.7333     10.7333     0.00015259  536177        
# 189000      10.9713     10.9713     0.000151072  539367        
# 190000      10.6769     10.6769     0.000149569  542483  
# 191000      10.6101     10.6101     0.00014808  545647        
# 192000      10.8125     10.8125     0.000146607  548788        
# 193000      10.7871     10.7871     0.000145148  551975        
# 194000      10.2772     10.2772     0.000143704  555135        
# 195000      10.5243     10.5243     0.000142274  558274        
# 196000      10.4488     10.4488     0.000140858  561391        
# 197000      10.6396     10.6396     0.000139457  564451        
# 198000      10.4315     10.4315     0.000138069  567077        
# 199000      11.0021     11.0021     0.000136695  569645        
# 200000      10.4966     10.4966     0.000135335  572219        
# 201000      10.7146     10.7146     0.000133989  574817        
# 202000      10.5162     10.5162     0.000132655  577386        
# 203000      10.4969     10.4969     0.000131336  579986        
# 204000      10.6306     10.6306     0.000130029  582575        
# 205000      10.357      10.357      0.000128735  585183        
# 206000      10.3308     10.3308     0.000127454  587766        
# 207000      10.2185     10.2185     0.000126186  590348        
# 208000      10.2892     10.2892     0.00012493  592931        
# 209000      10.3505     10.3505     0.000123687  595534        
# 210000      10.1028     10.1028     0.000122456  598151        
# 211000      10.2549     10.2549     0.000121238  600746        
# 212000      10.3988     10.3988     0.000120032  603322        
# 213000      10.4937     10.4937     0.000118837  605857 
# 214000      10.4532     10.4532     0.000117655  608398        
# 215000      10.5619     10.5619     0.000116484  610969        
# 216000      10.0156     10.0156     0.000115325  613551        
# 217000      10.0464     10.0464     0.000114178  616326        
# 218000      10.6529     10.6529     0.000113041  619344        
# 219000      10.0268     10.0268     0.000111917  622351        
# 220000      9.98399     9.98399     0.000110803  625478        
# 221000      10.2309     10.2309     0.000109701  628513        
# 222000      9.85557     9.85557     0.000108609  631587        
# 223000      10.2144     10.2144     0.000107528  634659        
# 224000      10.3628     10.3628     0.000106458  637672        
# 225000      9.93266     9.93266     0.000105399  640727        
# 226000      9.84652     9.84652     0.00010435  643872        
# 227000      10.3439     10.3439     0.000103312  646982        
# 228000      10.1158     10.1158     0.000102284  650123        
# 229000      9.86405     9.86405     0.000101266  653255        
# 230000      9.98591     9.98591     0.000100259  656425        
# 231000      9.89182     9.89182     9.92611e-05  659581        
# 232000      10.2688     10.2688     9.82734e-05  662713        
# 233000      9.75424     9.75424     9.72956e-05  665828        
# 234000      9.94533     9.94533     9.63275e-05  668946        
# 235000      10.0589     10.0589     9.5369e-05  672110        
# 236000      9.72761     9.72761     9.44201e-05  675280        
# 237000      10.1024     10.1024     9.34806e-05  678429        
# 238000      9.71764     9.71764     9.25504e-05  681556        
# 239000      10.0936     10.0936     9.16295e-05  684722        
# 240000      9.90142     9.90142     9.07178e-05  687896        
# 241000      10.3031     10.3031     8.98151e-05  691080   
# 242000      10.0567     10.0567     8.89214e-05  694276        
# 243000      9.51079     9.51079     8.80366e-05  697443        
# 244000      9.75928     9.75928     8.71607e-05  700571        
# 245000      10.1733     10.1733     8.62934e-05  703706        
# 246000      9.8334      9.8334      8.54348e-05  706838        
# 247000      9.95693     9.95693     8.45847e-05  710074        
# 248000      9.854       9.854       8.3743e-05  713326        
# 249000      9.9559      9.9559      8.29098e-05  716469        
# 250000      9.69821     9.69821     8.20848e-05  719600        
# 251000      10.1917     10.1917     8.1268e-05  722757        
# 252000      9.70314     9.70314     8.04594e-05  725916        
# 253000      9.75325     9.75325     7.96588e-05  729079        
# 254000      9.74271     9.74271     7.88662e-05  732236        
# 255000      9.78939     9.78939     7.80815e-05  735374        
# 256000      9.40064     9.40064     7.73045e-05  738521        
# 257000      9.79821     9.79821     7.65353e-05  741641        
# 258000      9.63232     9.63232     7.57738e-05  744778        
# 259000      9.88412     9.88412     7.50198e-05  747906        
# 260000      9.57372     9.57372     7.42734e-05  751015        
# 261000      9.86628     9.86628     7.35343e-05  754114        
# 262000      9.68375     9.68375     7.28026e-05  757231        
# 263000      9.78728     9.78728     7.20782e-05  760381        
# 264000      9.59483     9.59483     7.1361e-05  763538        
# 265000      9.32489     9.32489     7.0651e-05  766637        
# 266000      9.42444     9.42444     6.9948e-05  769818  
# 267000      9.67659     9.67659     6.9252e-05  772966        
# 268000      9.60055     9.60055     6.85629e-05  776143        
# 269000      9.62905     9.62905     6.78807e-05  779275        
# 270000      9.69203     9.69203     6.72053e-05  782426        
# 271000      9.60596     9.60596     6.65366e-05  785587  
# 272000      9.67285     9.67285     6.58745e-05  788747        
# 273000      9.80387     9.80387     6.52191e-05  791872        
# 274000      9.37456     9.37456     6.45701e-05  794991        
# 275000      9.90353     9.90353     6.39276e-05  798126        
# 276000      9.67749     9.67749     6.32915e-05  801298        
# 277000      9.69953     9.69953     6.26618e-05  804462        
# 278000      9.44851     9.44851     6.20383e-05  807566        
# 279000      9.47831     9.47831     6.1421e-05  810656        
# 280000      9.69661     9.69661     6.08098e-05  813729        
# 281000      9.47622     9.47622     6.02047e-05  816713        
# 282000      9.42972     9.42972     5.96057e-05  819671        
# 283000      9.13307     9.13307     5.90126e-05  822726        
# 284000      9.61275     9.61275     5.84254e-05  825754        
# 285000      9.71332     9.71332     5.78441e-05  828774        
# 286000      9.40811     9.40811     5.72685e-05  831757        
# 287000      9.4811      9.4811      5.66987e-05  834721        
# 288000      9.44672     9.44672     5.61345e-05  837692        
# 289000      9.6535      9.6535      5.5576e-05  840672        
# 290000      9.47068     9.47068     5.5023e-05  843697        
# 291000      9.61711     9.61711     5.44755e-05  846734        
# 292000      9.23319     9.23319     5.39334e-05  849746        
# 293000      9.54719     9.54719     5.33968e-05  852803        
# 294000      9.49306     9.49306     5.28655e-05  855774        
# 295000      9.54949     9.54949     5.23395e-05  858752        
# 296000      9.44547     9.44547     5.18187e-05  861697        
# 297000      9.5571      9.5571      5.13031e-05  864717        
# 298000      9.12158     9.12158     5.07926e-05  867863        
# 299000      9.42967     9.42967     5.02872e-05  870936        
# 300000      9.13116     9.13116     4.97868e-05  874007        
# 301000      9.33819     9.33819     4.92914e-05  877013        
# 302000      9.21345     9.21345     4.8801e-05  879987        
# 303000      9.55794     9.55794     4.83154e-05  882954        
# 304000      9.75606     9.75606     4.78346e-05  886005        
# 305000      9.78506     9.78506     4.73587e-05  888961        
# 306000      9.46734     9.46734     4.68874e-05  891936        
# 307000      9.34678     9.34678     4.64209e-05  894974        
# 308000      9.3525      9.3525      4.5959e-05  898007        
# 309000      9.37629     9.37629     4.55017e-05  900980        
# 310000      9.13258     9.13258     4.5049e-05  903970        
# 311000      9.04357     9.04357     4.46007e-05  907106        
# 312000      9.31958     9.31958     4.41569e-05  910189        
# 313000      9.28294     9.28294     4.37176e-05  913236        
# 314000      9.35659     9.35659     4.32826e-05  916338        
# 315000      8.91545     8.91545     4.28519e-05  919481        
# 316000      9.30117     9.30117     4.24255e-05  922605        
# 317000      9.1749      9.1749      4.20034e-05  925703        
# 318000      9.57114     9.57114     4.15854e-05  928784        
# 319000      9.27524     9.27524     4.11716e-05  931877        
# 320000      9.36833     9.36833     4.0762e-05  934977        
# 321000      9.13914     9.13914     4.03564e-05  938147        
# 322000      8.99555     8.99555     3.99548e-05  941233        
# 323000      9.32532     9.32532     3.95573e-05  944374        
# 324000      9.24886     9.24886     3.91637e-05  947492        
# 325000      9.1293      9.1293      3.8774e-05  950577        
# 326000      8.97721     8.97721     3.83882e-05  953751        
# 327000      9.35058     9.35058     3.80062e-05  956899        
# 328000      9.37362     9.37362     3.7628e-05  960167        
# 329000      9.32745     9.32745     3.72536e-05  963281        
# 330000      9.08412     9.08412     3.68829e-05  966418        
# 331000      9.31468     9.31468     3.65159e-05  969592        
# 332000      9.01207     9.01207     3.61526e-05  972758        
# 333000      9.48032     9.48032     3.57929e-05  975847        
# 334000      9.14195     9.14195     3.54367e-05  978914        
# 335000      9.25655     9.25655     3.50841e-05  982059        
# 336000      9.24903     9.24903     3.4735e-05  985184        
# 337000      9.29893     9.29893     3.43894e-05  988328        
# 338000      9.20102     9.20102     3.40472e-05  991416        

In [None]:
# raster 534, pixel 0.25, batch 64, resnet34
# 16100 87.364
# 31800 67.916
# 44000 52.826
# 66000 46.433
# 85600 37.023
# 97000 39.376
# 112000 30.721
# 126000 29.439
# 142k 25.491
# 154600 25.396
# 170600 25.682
# 181700 23.363
# 198k 23.886
# 215k 23.295
# 227k 23.340
# 239k 23.719
# 246k 23.769
# 258k 21.630
# 266k 22.841
# 281k 23.167
# 297k 21.912
# 314k 21.564
# 325k 21.623
# 338k 21.622
# 342k 22.135

In [None]:
# raster 267, pixel 0.5, batch 64, resnet34
# 52k 32.4
# 81k 24
# 111k 23.9
# 146k 23.243
# 176k 26.621
# 217k 27.006

In [39]:
# evaluation
from typing import Dict

from tempfile import gettempdir
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.models.resnet import resnet50
from tqdm import tqdm

from l5kit.configs import load_config_data
from l5kit.data import LocalDataManager, ChunkedDataset
from l5kit.dataset import AgentDataset, EgoDataset
from l5kit.rasterization import build_rasterizer
from l5kit.evaluation import write_pred_csv, compute_metrics_csv, read_gt_csv, create_chopped_dataset
from l5kit.evaluation.chop_dataset import MIN_FUTURE_STEPS
from l5kit.evaluation.metrics import neg_multi_log_likelihood, time_displace
from l5kit.geometry import transform_points
from l5kit.visualization import PREDICTED_POINTS_COLOR, TARGET_POINTS_COLOR, draw_trajectory
from prettytable import PrettyTable
from pathlib import Path

import os

In [40]:
# ===== GENERATE AND LOAD CHOPPED DATASET
# create_chopped_dataset() returns the directory holding the chopped copy of
# the validation zarr; the following cell reads the zarr, "mask.npz" and
# "gt.csv" from that directory.
num_frames_to_chop = 100
eval_cfg = cfg["valid_data_loader"]
eval_base_path = create_chopped_dataset(
    dm.require(eval_cfg["key"]),
    cfg["raster_params"]["filter_agents_threshold"],
    num_frames_to_chop,
    cfg["model_params"]["future_num_frames"],
    MIN_FUTURE_STEPS,
)

copying: 100%|██████████| 16220/16220 [04:01<00:00, 67.22it/s]

you're running with a custom agents_mask

extracting GT: 100%|██████████| 94694/94694 [06:48<00:00, 231.71it/s]


In [41]:
# ===== LOCATE THE CHOPPED DATASET ARTIFACTS
# Everything is resolved relative to `eval_base_path`, the directory returned
# by create_chopped_dataset(): the chopped zarr (same file name as the source
# zarr), the agents mask and the ground-truth CSV.
eval_zarr_path = str(Path(eval_base_path) / Path(dm.require(eval_cfg["key"])).name)
eval_mask_path = str(Path(eval_base_path) / "mask.npz")
eval_gt_path = str(Path(eval_base_path) / "gt.csv")

eval_zarr = ChunkedDataset(eval_zarr_path).open()

# np.load() on an .npz returns an NpzFile that keeps the archive open until it
# is closed; read the mask inside a context manager so the file handle is
# released as soon as the array has been loaded.
with np.load(eval_mask_path) as mask_archive:
    eval_mask = mask_archive["arr_0"]

# ===== INIT DATASET AND LOAD MASK
# The mask restricts the dataset to the agents selected when chopping.
eval_dataset = AgentDataset(cfg, eval_zarr, rasterizer, agents_mask=eval_mask)
eval_dataloader = DataLoader(
    eval_dataset,
    shuffle=eval_cfg["shuffle"],
    batch_size=eval_cfg["batch_size"],
    num_workers=eval_cfg["num_workers"],
)
print(eval_dataset)

+------------+------------+------------+---------------+-----------------+----------------------+----------------------+----------------------+---------------------+
| Num Scenes | Num Frames | Num Agents | Num TR lights | Total Time (hr) | Avg Frames per Scene | Avg Agents per Frame | Avg Scene Time (sec) | Avg Frame frequency |
+------------+------------+------------+---------------+-----------------+----------------------+----------------------+----------------------+---------------------+
|   16220    |  1622000   | 125423254  |    11733321   |      45.06      |        100.00        |        77.33         |        10.00         |        10.00        |
+------------+------------+------------+---------------+-----------------+----------------------+----------------------+----------------------+---------------------+



you're running with a custom agents_mask



In [42]:
# Run inference over the chopped evaluation dataloader. `run_prediction` and
# `predictor` are defined outside this cell; the four results are handed to
# write_pred_csv() in the next cell.
# NOTE: slow step - the recorded run took about 3h48m for 2960 batches.
timestamps, track_ids, coords, confs = run_prediction(predictor, eval_dataloader)


disable_traffic_light_faces not found in config, this will raise an error in the future


disable_traffic_light_faces not found in config, this will raise an error in the future


disable_traffic_light_faces not found in config, this will raise an error in the future


disable_traffic_light_faces not found in config, this will raise an error in the future

100%|██████████| 2960/2960 [3:47:43<00:00,  4.62s/it]  


In [43]:
# Persist the predictions as a CSV so they can be scored against the
# ground-truth CSV with compute_metrics_csv() in the next cell.
eval_path = "model6_eval5.csv"
write_pred_csv(eval_path, timestamps=timestamps, track_ids=track_ids, coords=coords, confs=confs)

In [44]:
# Score the prediction CSV against the chopped ground truth with both metrics
# and print one line per metric.
metrics = compute_metrics_csv(
    eval_gt_path,
    eval_path,
    [neg_multi_log_likelihood, time_displace],
)
for metric_name, metric_mean in metrics.items():
    # Author's reference note: 170600, 25.682
    print(metric_name, metric_mean)

neg_multi_log_likelihood 29.370679729605982
time_displace [0.04612678 0.06858268 0.08922398 0.10907556 0.12888243 0.14844974
 0.16786249 0.18659044 0.20341679 0.22138289 0.23912526 0.25730642
 0.2747145  0.29207841 0.30954033 0.32557833 0.34012221 0.35381431
 0.36744373 0.38052849 0.39259571 0.40478787 0.41634809 0.42728495
 0.43907737 0.44852004 0.45847813 0.46798121 0.47859874 0.48711819
 0.49692595 0.50606379 0.51587084 0.52436406 0.5348049  0.54402356
 0.55409431 0.56561527 0.57459945 0.58619397 0.59838096 0.61004991
 0.62327116 0.63740162 0.65243557 0.66863578 0.68587496 0.70347854
 0.72369505 0.74385898]
