In [1]:
import sys
import onnxruntime as ort
import numpy as np
from IPython.display import Image
from pathlib import Path
import matplotlib.pyplot as plt
import wandb
import os

sys.path.append('../')
from dataloading.nvidia import NvidiaValidationDataset
from scripts.pt_to_onnx import convert_pt_to_onnx

In [2]:
def download_model(run_id, entity='nikebless', project='ebm-driving', config=None):
    """Download a W&B run's ``*best.pt`` checkpoint and convert it to ONNX.

    Scans the files of run ``{entity}/{project}/{run_id}`` for the first one whose
    name ends with ``best.pt``, downloads it with ``replace=True`` and passes its
    file name to ``convert_pt_to_onnx`` together with the run's ``model_type``.

    Args:
        run_id: W&B run id.
        entity: W&B entity that owns the project.
        project: W&B project name.
        config: Optional extra converter options. They are merged over the
            defaults taken from the run config (``n_samples`` <- the run's
            ``ebm_train_samples``, ``steering_bound`` <- the run's
            ``steering_bound``), so keys given here take precedence.
            ``None`` means no extra options.

    Returns:
        The value returned by ``convert_pt_to_onnx`` (the caller treats it as the
        path of the produced ONNX file), or ``None`` if the run has no file
        ending in ``best.pt``.
    """
    api = wandb.Api()
    run = api.run(f'{entity}/{project}/{run_id}')

    checkpoint = next((f for f in run.files() if f.name.endswith('best.pt')), None)
    if checkpoint is None:
        return None

    # File.download hands back an open handle to the local copy; we only need the
    # file on disk, so close the handle instead of leaking one per model.
    handle = checkpoint.download(replace=True)
    if handle is not None:
        handle.close()

    model_config = {
        'n_samples': run.config.get('ebm_train_samples'),
        'steering_bound': run.config.get('steering_bound'),
        **(config or {}),  # caller-supplied options override the run-derived ones
    }
    return convert_pt_to_onnx(checkpoint.name, run.config.get('model_type'), config=model_config)

### Experiment 1: Number of bins (EBM vs. classification)

In [3]:
# W&B run ids of the trained models, each paired with the file name its
# converted ONNX model is saved under.
models = [
    ('3ftnqxcb', 'ebm-512'),
    ('pg0eweml', 'ebm-256'),
    ('iddqahiv', 'ebm-128'),
    ('3g3wwx73', 'classifier-512'),
    ('1t7d1afv', 'classifier-256'),
    ('2faagb62', 'classifier-128'),
]

MODELS_DIR = Path('../_models')
# Every model in this experiment is exported without input normalization.
config = {'normalize_inputs': False}

MODELS_DIR.mkdir(parents=True, exist_ok=True)
for run_id, model_name in models:
    path_to_model = download_model(run_id=run_id, config=config)
    if path_to_model is None:
        # download_model returns None when the run has no '*best.pt' file; fail
        # with a clear message instead of a TypeError from os.replace(None, ...).
        raise FileNotFoundError(f"W&B run {run_id} ({model_name}) has no '*best.pt' checkpoint")
    # os.replace overwrites an existing target on every platform (os.rename does
    # not on Windows), so re-running this cell is safe.
    os.replace(path_to_model, MODELS_DIR / f'{model_name}.onnx')



/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-11-08-59_e2e_rec_ss20_elva_back: length=33281, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-20-15-11-29_e2e_rec_vastse_ss13_17_back: length=26763, filtered=0
/data/Bolt/end-to-end/rally-estoni

  batch_size = num_samples // self.inference_samples


/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-11-08-59_e2e_rec_ss20_elva_back: length=33281, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-20-15-11-29_e2e_rec_vastse_ss13_17_back: length=26763, filtered=0
/data/Bolt/end-to-end/rally-estoni

  batch_size = num_samples // self.inference_samples


/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-11-08-59_e2e_rec_ss20_elva_back: length=33281, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-20-15-11-29_e2e_rec_vastse_ss13_17_back: length=26763, filtered=0
/data/Bolt/end-to-end/rally-estoni

  batch_size = num_samples // self.inference_samples


/data/Bolt/end-to-end/rally-estonia-cropped/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/end-to-end/rally-estonia-cropped/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-26-11-08-59_e2e_rec_ss20_elva_back: length=33281, filtered=0
/data/Bolt/end-to-end/rally-estonia-cropped/2021-10-20-15-11-29_e2e_rec_vastse_ss13_17_back: length=26763, filtered=0
/data/Bolt/end-to-end/rally-estoni