In [1]:
import torch
import plotly.graph_objects as go
import plotly.express as px
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import tqdm

import os
import sys
module_paths =  [
    os.path.abspath(os.path.join('ronin/source'))  # RoNIN
]
for module_path in module_paths:
    if module_path not in sys.path:
        sys.path.append(module_path)

import data_glob_speed
import cnn_vae_model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
def plot_np_multi(xy_dict, xlabel='time (s)'):
    fig = go.Figure()
    for name, (x, y) in xy_dict.items():
        fig.add_trace(go.Scatter(x=x, y=y, name=name))
    fig.update_layout(xaxis_title=xlabel)
    fig.show()

def save_model_by_name(model, epoch):
    save_dir = os.path.join('checkpoints', model.name)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    file_path = os.path.join(save_dir, 'model-{:05d}.pt'.format(epoch))
    state = model.state_dict()
    torch.save(state, file_path)
    print('Saved to {}'.format(file_path))

# Data loading

In [3]:
DATA_ROOT_DIR = 'datasets/Ronin/train_dataset_1'
with open('ronin/train_list.txt') as f:
    data_list = [s.strip().split(',' or ' ')[0] for s in f.readlines() if len(s) > 0 and s[0] != '#']

# Each item in the dataset is a (feature, target, seq_id, frame_id) tuple.
# Each feature is a 6x200 array. Rows 0-2 are gyro, and rows 3-5 are accel (non gravity subtracted).
# Both gyro and accels are in a gravity-aligned world frame (arbitrary yaw, but consistent throughout
# the 200 frames)
dataset = data_glob_speed.StridedSequenceDataset(data_glob_speed.GlobSpeedSequence,
                                                 DATA_ROOT_DIR,
                                                 data_list,
                                                 cache_path='ronin/train_data_cache')
train_loader = DataLoader(dataset, batch_size=128, shuffle=True)

In [4]:
model = cnn_vae_model.CnnVae(feature_dim=6, latent_dim=64,
                                    first_channel_size=64, last_channel_size=512,
                                    fc_dim=256).to(device)

# Train

In [None]:
EPOCHS=100
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
with tqdm.tqdm(total=EPOCHS) as pbar:
    for epoch in range(EPOCHS):
        for batch_id, (feat, _, _, _) in enumerate(train_loader):
            feat = feat.to(device)
            optimizer.zero_grad()
            nelbo, kl, rec = model.negative_elbo_bound(feat)
            nelbo.backward()
            optimizer.step()
            pbar.set_postfix(
                nelbo='{:.2e}'.format(nelbo),
                kl='{:.2e}'.format(kl),
                rec='{:.2e}'.format(rec)
            )
        pbar.update(1)
    
        save_model_by_name(model, epoch)

  1%|▌                                                   | 1/100 [02:20<3:52:10, 140.71s/it, kl=4.92e+01, nelbo=2.46e+03, rec=2.41e+03]

Saved to checkpoints/CnnVae/model-00000.pt


  2%|█                                                   | 2/100 [04:44<3:52:25, 142.30s/it, kl=6.37e+01, nelbo=1.80e+03, rec=1.74e+03]

Saved to checkpoints/CnnVae/model-00001.pt


  2%|█                                                   | 2/100 [05:43<3:52:25, 142.30s/it, kl=6.58e+01, nelbo=1.66e+03, rec=1.59e+03]