In [4]:
import wandb
import os

WANDB_USER = 'nikebless-thesis'
WANDB_PROJECT = 'ibc'
SAVEDIR = '../_models'

os.makedirs(SAVEDIR, exist_ok=True)

models = { 
    'ebm-unregularized-v1': '2fszoqji',
    'ebm-unregularized-v2': '2rnlr80k',
    'ebm-unregularized-v3': '2h2223nm',
    'mae-v1': '13vssg6f',
    'mae-v2': '1369ywfp',
    'mae-v3': '17q16jkl',
    'ebm-regularized-v1': '2x8qclz8',
    'ebm-regularized-v2': '31ua9ivf',
    'ebm-regularized-v3': '9bwouhzr',
    # 'ebm-12-5%-v1': '1qs4417o',
    # 'ebm-12-5%-v2': '',
    # 'ebm-12-5%-v3': '',
    # 'ebm-1-25%-v1': '2vxae1ae',
    # 'ebm-1-25%-v2': '',
    # 'ebm-1-25%-v3': '',
    # 'mae-12-5%-v1': '2ing0vso',
    # 'mae-12-5%-v2': '',
    # 'mae-12-5%-v3': '',
    # 'mae-1-25%-v1': '2ubymmju',
    # 'mae-1-25%-v2': '',
    # 'mae-1-25%-v3': '',
    'ebm-50%-v1': '3dbratb5', # TODO: download
    'ebm-10%-v1': '3dabg6cd',
    'mae-50%-v1': '2keinev2',
    'mae-10%-v1': '1i1dadsh',
}

### Download models from Wandb

In [5]:
api = wandb.Api()

for model_name, run_hash in models.items():

    run = api.run(f"{WANDB_USER}/{WANDB_PROJECT}/{run_hash}")

    output_path = os.path.join(SAVEDIR, f'{model_name}.pt')
    if not os.path.exists(output_path):
        last_pt_model = [file for file in run.files() if file.name.endswith('last.pt')][0]
        last_pt_model.download()
        os.rename(last_pt_model.name, output_path)
        print(f'Saved {model_name} to {output_path}.')
    else:
        print(f'Skipping {model_name} because it\'s already downloaded.')

    last_onnx_models = [file for file in run.files() if file.name.endswith('last.onnx')]
    output_path = os.path.join(SAVEDIR, f'{model_name}.onnx')
    if len(last_onnx_models) and not os.path.exists(output_path):
        last_onnx_model = last_onnx_models[0]
        last_onnx_model.download()
        os.rename(last_onnx_model.name, output_path)
        print(f'Saved {model_name} to {output_path}.')
    else:
        print(f'Skipping {model_name} because it\'s already downloaded.')

Skipping ebm-unregularized-v1 because it's already downloaded.
Skipping ebm-unregularized-v1 because it's already downloaded.
Skipping ebm-unregularized-v2 because it's already downloaded.
Skipping ebm-unregularized-v2 because it's already downloaded.
Skipping ebm-unregularized-v3 because it's already downloaded.
Skipping ebm-unregularized-v3 because it's already downloaded.
Skipping mae-v1 because it's already downloaded.
Skipping mae-v1 because it's already downloaded.
Skipping mae-v2 because it's already downloaded.
Skipping mae-v2 because it's already downloaded.
Skipping mae-v3 because it's already downloaded.
Skipping mae-v3 because it's already downloaded.
Skipping ebm-regularized-v1 because it's already downloaded.
Skipping ebm-regularized-v1 because it's already downloaded.
Skipping ebm-regularized-v2 because it's already downloaded.
Skipping ebm-regularized-v2 because it's already downloaded.
Skipping ebm-regularized-v3 because it's already downloaded.
Skipping ebm-regularize

### Convert models to ONNX

In [6]:
import sys
import os
sys.path.append('/home/nikita/e2e-driving/')

from scripts.pt_to_onnx import convert_pt_to_onnx

conf = {
    'steering_bound': 4.5,
    'use_constant_samples': True,
}

for model_name in models.keys():
    if 'ebm' in model_name:
        pt_model_path = os.path.join(SAVEDIR, f'{model_name}.pt')
        onnx_steering_model_path = pt_model_path.replace('.pt', '.onnx')
        onnx_energy_model_path = pt_model_path.replace('.pt', '-energy.onnx')

        if not os.path.exists(onnx_steering_model_path):
            steering_model_path = convert_pt_to_onnx(pt_model_path, 1, onnx_steering_model_path, with_choice=True, n_samples=1024, iters=0, args=conf)
            print(f'Converted steering model {model_name} to {onnx_steering_model_path}.')
        else:
            print(f'Skipping steering model {model_name} because it\'s already downloaded.')

        if not os.path.exists(onnx_energy_model_path):
            energy_model_path = convert_pt_to_onnx(pt_model_path, 256, onnx_energy_model_path, with_choice=False, n_samples=1024, iters=0, args=conf)
            print(f'Converted energy model {model_name} to {onnx_energy_model_path}.')
        else:
            print(f'Skipping energy model {model_name} because it\'s already downloaded.')


Skipping steering model ebm-unregularized-v1 because it's already downloaded.
Skipping energy model ebm-unregularized-v1 because it's already downloaded.
Skipping steering model ebm-unregularized-v2 because it's already downloaded.
Skipping energy model ebm-unregularized-v2 because it's already downloaded.
Skipping steering model ebm-unregularized-v3 because it's already downloaded.
Skipping energy model ebm-unregularized-v3 because it's already downloaded.
Skipping steering model ebm-regularized-v1 because it's already downloaded.
Skipping energy model ebm-regularized-v1 because it's already downloaded.
Skipping steering model ebm-regularized-v2 because it's already downloaded.
Skipping energy model ebm-regularized-v2 because it's already downloaded.
Skipping steering model ebm-regularized-v3 because it's already downloaded.
Skipping energy model ebm-regularized-v3 because it's already downloaded.
/data/Bolt/dataset-new-small/summer2021/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10



/data/Bolt/dataset-new-small/summer2021/2021-10-14-13-08-51_e2e_rec_vahi_backwards: length=13442, filtered=0


  batch_size = num_samples // self.inference_samples


Converted steering model ebm-10%-v1 to ../_models/ebm-10%-v1.onnx.
/data/Bolt/dataset-new-small/summer2021/2021-05-28-15-19-48_e2e_sulaoja_20_30: length=10708, filtered=0
/data/Bolt/dataset-new-small/summer2021/2021-06-07-14-20-07_e2e_rec_ss6: length=25836, filtered=1
/data/Bolt/dataset-new-small/summer2021/2021-06-07-14-06-31_e2e_rec_ss6: length=3003, filtered=0
/data/Bolt/dataset-new-small/summer2021/2021-06-07-14-09-18_e2e_rec_ss6: length=4551, filtered=1
/data/Bolt/dataset-new-small/summer2021/2021-06-07-14-36-16_e2e_rec_ss6: length=25368, filtered=1
/data/Bolt/dataset-new-small/summer2021/2021-09-24-14-03-45_e2e_rec_ss11_backwards: length=25172, filtered=0
/data/Bolt/dataset-new-small/summer2021/2021-10-26-10-49-06_e2e_rec_ss20_elva: length=33045, filtered=0
/data/Bolt/dataset-new-small/summer2021/2021-10-26-11-08-59_e2e_rec_ss20_elva_back: length=33281, filtered=0
/data/Bolt/dataset-new-small/summer2021/2021-10-20-15-11-29_e2e_rec_vastse_ss13_17_back: length=26763, filtered=0
/da

### Now convert the MAE models to dynamic batch size on HPC