In [1]:
import sys
import onnxruntime as ort
import numpy as np
from IPython.display import Image
from pathlib import Path
import matplotlib.pyplot as plt
import wandb
import os

# Restrict CUDA-aware libraries to the device with index 1.
# NOTE(review): this runs after `import onnxruntime`; presumably that is fine
# because CUDA is only initialised later — confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# Put the parent directory on the import path so the project-local
# `dataloading` and `scripts` packages imported below can be found.
sys.path.append('../')
from dataloading.nvidia import NvidiaValidationDataset
from scripts.pt_to_onnx import convert_pt_to_onnx

In [2]:
def download_model(run_id, entity='nikebless', project='ebm-driving', config=None, return_pt_path=False):
    """Download a run's ``best.pt`` checkpoint from W&B and convert it to ONNX.

    The files of run ``{entity}/{project}/{run_id}`` are scanned in order; the
    first one whose name ends with ``best.pt`` is downloaded (replacing any
    existing copy) and handed to ``convert_pt_to_onnx`` together with a model
    config derived from the run's own W&B config.

    Args:
        run_id: W&B run identifier.
        entity: W&B entity that owns the project.
        project: W&B project name.
        config: Optional dict of extra model-config entries. They are merged
            last, so they override the values taken from the run config.
            ``None`` (the default) means no overrides.
        return_pt_path: When true, also return the name of the downloaded
            ``.pt`` file.

    Returns:
        Whatever ``convert_pt_to_onnx`` returns (the callers in this notebook
        use it as the path of the ONNX file), or ``None`` if the run has no
        file ending in ``best.pt``. With ``return_pt_path=True`` the result is
        an ``(onnx_path, pt_path)`` tuple, ``(None, None)`` if nothing matched.
    """
    # Initialise both results up front: without this a run that has no
    # checkpoint raised UnboundLocalError at the return statement, so the
    # callers' `is not None` checks could never trigger.
    path_to_model_onnx = None
    path_to_model_pt = None
    overrides = config or {}

    api = wandb.Api()
    run = api.run(f'{entity}/{project}/{run_id}')
    for file in run.files():
        if not file.name.endswith('best.pt'):
            continue

        file.download(replace=True)
        path_to_model_pt = file.name

        # Conversion parameters come from the training run; caller-supplied
        # overrides are spread last so they win on key collisions.
        model_config = {
            'n_samples': run.config.get('ebm_train_samples'),
            'steering_bound': run.config.get('steering_bound'),
            'n_gaussians': run.config.get('mdn_n_components'),
            **overrides,
        }
        print('model config:', model_config)
        path_to_model_onnx = convert_pt_to_onnx(path_to_model_pt, run.config.get('model_type'), config=model_config)
        # Only the first matching checkpoint is used.
        break

    if return_pt_path:
        return path_to_model_onnx, path_to_model_pt
    return path_to_model_onnx

### Experiment 1: Number of bins (EBM vs Classification)

In [3]:
# Experiment 1 runs as (W&B run id, export name); the export names end in
# 512, 256 or 128.
models = [
    ('3ftnqxcb', 'ebm-512'),
    ('pg0eweml', 'ebm-256'),
    ('iddqahiv', 'ebm-128'),
    ('3g3wwx73', 'classifier-512'),
    ('1t7d1afv', 'classifier-256'),
    ('2faagb62', 'classifier-128'),
]

# Download each checkpoint, convert it to ONNX and move the result into
# ../_models/ under its export name.
for run_id, model_name in models:
    config = dict(normalize_inputs=False)
    path_to_model = download_model(run_id, config=config)
    destination = f'../_models/{model_name}.onnx'
    os.rename(path_to_model, destination)



/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-11-08-59_e2e_rec_ss20_elva_back: length=33281, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-20-15-11-29_e2e_rec_vastse_ss13_17_back: length=26763, filtered=0
/data/Bolt/end-to-end/rally-estoni

  batch_size = num_samples // self.inference_samples


/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-11-08-59_e2e_rec_ss20_elva_back: length=33281, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-20-15-11-29_e2e_rec_vastse_ss13_17_back: length=26763, filtered=0
/data/Bolt/end-to-end/rally-estoni

  batch_size = num_samples // self.inference_samples


/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-11-08-59_e2e_rec_ss20_elva_back: length=33281, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-20-15-11-29_e2e_rec_vastse_ss13_17_back: length=26763, filtered=0
/data/Bolt/end-to-end/rally-estoni

  batch_size = num_samples // self.inference_samples


/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-11-08-59_e2e_rec_ss20_elva_back: length=33281, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-20-15-11-29_e2e_rec_vastse_ss13_17_back: length=26763, filtered=0
/data/Bolt/end-to-end/rally-estoni

### Experiment 2: Variability due to initialization (EBM vs MAE)

In [3]:
# Three runs per model family; the export names carry the suffixes -s1..-s3.
ebm_seed_runs = [
    ('3ftnqxcb', 'ebm-512-s1'),
    ('3dke382l', 'ebm-512-s2'),
    ('kozacohh', 'ebm-512-s3'),
]
mae_seed_runs = [
    ('13nkhdy1', 'mae-s1'),
    ('bxd5wtqk', 'mae-s2'),
    ('gwgmsh0e', 'mae-s3'),
]
models = ebm_seed_runs + mae_seed_runs

# Convert every run's checkpoint to ONNX and file it under ../_models/.
for run_id, model_name in models:
    config = {'normalize_inputs': False}
    path_to_model = download_model(run_id=run_id, config=config)
    target = f'../_models/{model_name}.onnx'
    os.rename(path_to_model, target)



/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-11-08-59_e2e_rec_ss20_elva_back: length=33281, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-20-15-11-29_e2e_rec_vastse_ss13_17_back: length=26763, filtered=0
/data/Bolt/end-to-end/rally-estoni

  batch_size = num_samples // self.inference_samples


/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-11-08-59_e2e_rec_ss20_elva_back: length=33281, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-20-15-11-29_e2e_rec_vastse_ss13_17_back: length=26763, filtered=0
/data/Bolt/end-to-end/rally-estoni

  batch_size = num_samples // self.inference_samples


/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-11-08-59_e2e_rec_ss20_elva_back: length=33281, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-20-15-11-29_e2e_rec_vastse_ss13_17_back: length=26763, filtered=0
/data/Bolt/end-to-end/rally-estoni

  batch_size = num_samples // self.inference_samples


/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-11-08-59_e2e_rec_ss20_elva_back: length=33281, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-20-15-11-29_e2e_rec_vastse_ss13_17_back: length=26763, filtered=0
/data/Bolt/end-to-end/rally-estoni

### Experiment 3: MDN — number of components

In [7]:
# W&B run ids grouped by the component count used in the export name; the
# position inside each tuple gives the -s1..-s3 suffix.
mdn_run_ids = {
    1: ('1eq2jiva', '32tkml84', '17obb61u'),
    3: ('3s5tapgz', '1bdlpf8d', '33smkc56'),
    5: ('1hbbr6dm', '27dfw40i', '2zz5m35m'),
}

# Flatten into the (run id, export name) pairs used by the download loop,
# e.g. ('1eq2jiva', 'mdn-1-s1').
models = [
    (wandb_id, f'mdn-{n_components}-s{seed}')
    for n_components, wandb_ids in mdn_run_ids.items()
    for seed, wandb_id in enumerate(wandb_ids, start=1)
]

# Convert each checkpoint to ONNX and move it into ../_models/.
for run_id, model_name in models:
    config = {'normalize_inputs': False}
    path_to_model = download_model(run_id=run_id, config=config)
    os.rename(path_to_model, f'../_models/{model_name}.onnx')

model config: {'n_samples': 128, 'steering_bound': 4.5, 'n_gaussians': 1, 'normalize_inputs': False}
/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-11-08-59_e2e_rec_ss20_elva_back: length=33281, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10

  if return_variances:


model config: {'n_samples': 128, 'steering_bound': 4.5, 'n_gaussians': 1, 'normalize_inputs': False}
/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-11-08-59_e2e_rec_ss20_elva_back: length=33281, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10

  if return_variances:


model config: {'n_samples': 128, 'steering_bound': 4.5, 'n_gaussians': 1, 'normalize_inputs': False}
/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-11-08-59_e2e_rec_ss20_elva_back: length=33281, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10

  if return_variances:


model config: {'n_samples': 128, 'steering_bound': 4.5, 'n_gaussians': 3, 'normalize_inputs': False}
/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-11-08-59_e2e_rec_ss20_elva_back: length=33281, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10

  if return_variances:


model config: {'n_samples': 128, 'steering_bound': 4.5, 'n_gaussians': 3, 'normalize_inputs': False}
/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-11-08-59_e2e_rec_ss20_elva_back: length=33281, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10

  if return_variances:


model config: {'n_samples': 128, 'steering_bound': 4.5, 'n_gaussians': 3, 'normalize_inputs': False}
/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-11-08-59_e2e_rec_ss20_elva_back: length=33281, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10

  if return_variances:


model config: {'n_samples': 128, 'steering_bound': 4.5, 'n_gaussians': 5, 'normalize_inputs': False}
/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-11-08-59_e2e_rec_ss20_elva_back: length=33281, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10

  if return_variances:


model config: {'n_samples': 128, 'steering_bound': 4.5, 'n_gaussians': 5, 'normalize_inputs': False}
/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-11-08-59_e2e_rec_ss20_elva_back: length=33281, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10

  if return_variances:


model config: {'n_samples': 128, 'steering_bound': 4.5, 'n_gaussians': 5, 'normalize_inputs': False}
/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-11-08-59_e2e_rec_ss20_elva_back: length=33281, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10

  if return_variances:


### Experiment 4: Varying data amount (EBM vs Classification)

In [4]:
# Experiment 4 runs, listed as (W&B run URL, export name).
models = [
    # ebm-512-s1 for 100%
    ('https://wandb.ai/nikebless/ebm-driving/runs/228inmtq', 'ebm-50percent-s1'),
    ('https://wandb.ai/nikebless/ebm-driving/runs/2yofzwf0', 'ebm-50percent-s2'),
    ('https://wandb.ai/nikebless/ebm-driving/runs/1owxpr1x', 'ebm-20percent-s1'),
    ('https://wandb.ai/nikebless/ebm-driving/runs/zie8x2vl', 'ebm-20percent-s2'),
    # classifier-512 for 100%
    ('https://wandb.ai/nikebless/ebm-driving/runs/1t56cx3v', 'classifier-50percent-s1'),
    ('https://wandb.ai/nikebless/ebm-driving/runs/d53rmg8e', 'classifier-50percent-s2'),
    ('https://wandb.ai/nikebless/ebm-driving/runs/2e3y0oce', 'classifier-20percent-s1'),
    ('https://wandb.ai/nikebless/ebm-driving/runs/2ldsn0i0', 'classifier-20percent-s2'),
]

for run_url, model_name in models:
    # download_model() builds the run path as f'{entity}/{project}/{run_id}',
    # so hand it the bare run id (the last URL segment) instead of the whole
    # URL, which would otherwise be embedded inside that path.
    run_id = run_url.rstrip('/').rsplit('/', 1)[-1]
    # NOTE(review): unlike the other experiment cells, no
    # {'normalize_inputs': False} override is passed here — confirm this is
    # intended for these runs.
    path_to_model = download_model(run_id=run_id)
    os.rename(path_to_model, f'../_models/{model_name}.onnx')

model config: {'n_samples': 512, 'steering_bound': 4.5, 'n_gaussians': 3}
/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-11-08-59_e2e_rec_ss20_elva_back: length=33281, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-20-15-11-29_e2e_rec_vastse

  batch_size = num_samples // self.inference_samples


model config: {'n_samples': 512, 'steering_bound': 4.5, 'n_gaussians': 3}
/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-11-08-59_e2e_rec_ss20_elva_back: length=33281, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-20-15-11-29_e2e_rec_vastse

  batch_size = num_samples // self.inference_samples


model config: {'n_samples': 512, 'steering_bound': 4.5, 'n_gaussians': 3}
/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-11-08-59_e2e_rec_ss20_elva_back: length=33281, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-20-15-11-29_e2e_rec_vastse

  batch_size = num_samples // self.inference_samples


model config: {'n_samples': 512, 'steering_bound': 4.5, 'n_gaussians': 3}
/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-11-08-59_e2e_rec_ss20_elva_back: length=33281, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-20-15-11-29_e2e_rec_vastse

  batch_size = num_samples // self.inference_samples


model config: {'n_samples': 512, 'steering_bound': 4.5, 'n_gaussians': 3}
/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-11-08-59_e2e_rec_ss20_elva_back: length=33281, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-20-15-11-29_e2e_rec_vastse

### Final experiment: All models

In [21]:
from collections import Counter
api = wandb.Api()

# Only runs created since September 19th, oldest first.
# NOTE(review): the cutoff literal below looks malformed ('2022-09-181##') and
# does not obviously match the date in the comment above — confirm the
# intended value before relying on this filter.
runs = api.runs('nikebless/ebm-driving', filters={
        "$and": [{
        'created_at': {
            "$gt": '2022-09-181##'
            }
        }]
    }, 
    order='+created_at',
)

# Number of runs seen so far per '<model>-<loss>-<temp_reg>' combination; the
# running count becomes the -s<N> suffix of the export name.
model_names_count = Counter()
models = []

for run in runs:
    # Skip runs whose summary has no 'epoch' entry.
    if run.summary.get('epoch', None) is None: continue

    model_type = run.config.get('model_type')
    loss_variant = run.config.get('loss_variant')
    temp_reg = run.config.get('temporal_regularization')

    # A run whose config lacks either key cannot be named; report and skip it
    # instead of failing with AttributeError on None.replace(...).
    if model_type is None or loss_variant is None:
        print(f'Skipping {run.id}: missing model_type or loss_variant in run config.')
        continue

    model_type_str = model_type.replace('pilotnet-', '')
    loss_variant_str = loss_variant.replace('ce-proximity-aware', 'spatial').replace('default', 'normal')

    model_name = f'{model_type_str}-{loss_variant_str}-{temp_reg}'
    model_names_count[model_name] += 1
    seed = model_names_count[model_name]
    models.append((run.id, f'{model_name}-s{seed}'))

models


[('1p73jc4m', 'ebm-spatial-0-s1'),
 ('3dp4sen8', 'ebm-normal-0.3-s1'),
 ('3jk7cnqa', 'ebm-normal-1-s1'),
 ('27xlvoan', 'ebm-spatial-0.3-s1'),
 ('1e9xgnmv', 'ebm-spatial-1-s1'),
 ('f5xw3kj9', 'classifier-normal-0.3-s1'),
 ('2v42l625', 'ebm-normal-0.3-s2'),
 ('2jvl4yhn', 'ebm-spatial-0-s2'),
 ('4001dzmm', 'classifier-normal-1-s1'),
 ('aw7baltg', 'classifier-spatial-0-s1'),
 ('p99caxnk', 'classifier-spatial-0.3-s1'),
 ('1qo45ro2', 'classifier-spatial-1-s1'),
 ('374wicdd', 'ebm-normal-0.3-s3'),
 ('242kpjyb', 'ebm-normal-1-s2'),
 ('1n3e8g3l', 'ebm-spatial-0-s3'),
 ('gy6gp4pa', 'ebm-spatial-0.3-s2'),
 ('29aaride', 'ebm-spatial-1-s2'),
 ('39s7poda', 'classifier-normal-0.3-s2'),
 ('ci8wz89c', 'classifier-normal-1-s2'),
 ('3avfhaf3', 'classifier-spatial-0-s2'),
 ('2dlkfgnz', 'classifier-spatial-0.3-s2'),
 ('12y11e6t', 'classifier-spatial-1-s2'),
 ('1kut0356', 'ebm-normal-0.3-s4'),
 ('3couqhw2', 'ebm-normal-1-s3'),
 ('56vrm532', 'ebm-spatial-0-s4'),
 ('2sn1bnrf', 'ebm-spatial-0.3-s3'),
 ('w7zzop

In [24]:
# Export only the third seed of the classifier-spatial models. Iterates the
# `models` list that is in scope when this cell runs.
for run_id, model_name in models:
    is_selected = 'classifier-spatial' in model_name and '-s3' in model_name
    if not is_selected:
        continue

    config = {'normalize_inputs': False}
    path_to_model = download_model(run_id=run_id, config=config)
    if path_to_model is None:
        # Nothing came back for this run; report it and move on.
        print(f'Could not download {run_id} {model_name}. Check the run has the model file.')
        continue

    os.rename(path_to_model, f'../_models/{model_name}.onnx')

model config: {'n_samples': 512, 'steering_bound': 4.5, 'n_gaussians': 3, 'normalize_inputs': False}
[NvidiaDataset] Using default transform: Compose(
    <dataloading.nvidia.Normalize object at 0x7f2a325294f0>
)
/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-

### Final experiments: Selected models

In [5]:
# Selected runs as (W&B run id, export name).
models = [
    ('3ftnqxcb', 'ebm-512-s1'),
    ('3jk7cnqa', 'ebm-normal-1-s1'),
    ('2jvl4yhn', 'ebm-spatial-0-s2'),
    ('bxd5wtqk', 'mae-s2'),
    ('3g3wwx73', 'classifier-512'),
    ('1hbbr6dm', 'mdn-5-s1'),
]

# For these runs both the ONNX export and the original .pt checkpoint are
# moved into ../_models/.
for run_id, model_name in models:
    config = {'normalize_inputs': False}
    path_to_model_onnx, path_to_model_pt = download_model(run_id=run_id, config=config, return_pt_path=True)

    if path_to_model_onnx is None or path_to_model_pt is None:
        print(f'Could not download {run_id} {model_name}. Check the run has the model file.')
        continue

    # ONNX first, then the checkpoint, each under the export name.
    for source_path, extension in ((path_to_model_onnx, 'onnx'), (path_to_model_pt, 'pt')):
        os.rename(source_path, f'../_models/{model_name}.{extension}')

model config: {'n_samples': 512, 'steering_bound': 4.5, 'n_gaussians': 3, 'normalize_inputs': False}
[NvidiaDataset] Using default transform: Compose(
    <dataloading.nvidia.Normalize object at 0x7f777b30b8b0>
)
/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-

  batch_size = num_samples // self.inference_samples


model config: {'n_samples': 512, 'steering_bound': 4.5, 'n_gaussians': 3, 'normalize_inputs': False}
[NvidiaDataset] Using default transform: Compose(
    <dataloading.nvidia.Normalize object at 0x7f7782457970>
)
/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-

  batch_size = num_samples // self.inference_samples


model config: {'n_samples': 512, 'steering_bound': 4.5, 'n_gaussians': 3, 'normalize_inputs': False}
[NvidiaDataset] Using default transform: Compose(
    <dataloading.nvidia.Normalize object at 0x7f7741580340>
)
/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-

  batch_size = num_samples // self.inference_samples


model config: {'n_samples': 512, 'steering_bound': 4.5, 'n_gaussians': 3, 'normalize_inputs': False}
[NvidiaDataset] Using default transform: Compose(
    <dataloading.nvidia.Normalize object at 0x7f775452cfa0>
)
/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-

  if return_variances:


In [3]:
# Single MDN checkpoint as a (wandb run id, local model name) pair.
# The local name becomes the file stem of the exported .onnx / .pt files.
models = [
    ('274bkloa', 'mdn-3-s1-90-0-90'),
]

# Make sure the destination exists before moving files into it; otherwise the
# move raises FileNotFoundError once the download and conversion have finished.
models_dir = '../_models'
os.makedirs(models_dir, exist_ok=True)

# Export override shared by every run in this cell.
config = {'normalize_inputs': False}

for run_id, model_name in models:
    try:
        path_to_model_onnx, path_to_model_pt = download_model(run_id=run_id, config=config, return_pt_path=True)
    except UnboundLocalError:
        # When the run has no '*best.pt' file, download_model reaches its
        # return statement with unassigned locals and raises instead of
        # returning None. Treat that as "nothing downloaded".
        path_to_model_onnx = path_to_model_pt = None

    if path_to_model_onnx is not None and path_to_model_pt is not None:
        # os.replace (unlike os.rename) overwrites an existing target on all
        # platforms, so the cell can be re-run without manual cleanup.
        os.replace(path_to_model_onnx, os.path.join(models_dir, f'{model_name}.onnx'))
        os.replace(path_to_model_pt, os.path.join(models_dir, f'{model_name}.pt'))
    else:
        print(f'Could not download {run_id} {model_name}. Check the run has the model file.')

model config: {'n_samples': 512, 'steering_bound': 4.5, 'n_gaussians': 3, 'normalize_inputs': False}
[NvidiaDataset] Using default transform: Compose(
    <dataloading.nvidia.Normalize object at 0x7f78f5c8ea30>
)
/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-

  if return_variances:
