In [1]:
import torch
import plotly.graph_objects as go
import plotly.express as px
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import tqdm
import numpy as np
import wandb
import time

import utils

import os
import sys
# Directories with vendored sources that the later `import` lines expect to
# find on the interpreter search path.
module_paths = [os.path.abspath(os.path.join(rel_dir)) for rel_dir in ('ronin/source',)]  # RoNIN
for module_path in module_paths:
    # Skip directories that are already registered so that re-running this
    # cell does not pile up duplicate sys.path entries.
    if module_path in sys.path:
        continue
    sys.path.append(module_path)

import data_glob_speed
import data_ridi
import cnn_vae_model

# SECURITY: never paste a W&B API key into the notebook. The key that used to
# be recorded on this line should be treated as leaked and rotated. Supply
# credentials out-of-band instead (e.g. the WANDB_API_KEY environment variable
# or the interactive login prompt).
wandb.login()

# Use the GPU when torch reports that CUDA is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

[34m[1mwandb[0m: Currently logged in as: [33mansonw[0m. Use [1m`wandb login --relogin`[0m to force relogin


cuda


# Load RONIN dataset

In [2]:
DATA_ROOT_DIR = 'datasets'


def _read_sequence_names(list_path):
    """Read the sequence names listed in a dataset list file.

    Each non-empty line that is not a comment contributes its first field.
    Fields may be separated by a comma or by whitespace. Blank lines and
    lines whose first non-blank character is '#' are ignored.

    Args:
        list_path: Path of the text file to read.

    Returns:
        List of sequence-name strings, in file order.

    Raises:
        OSError: If the file cannot be opened.
    """
    names = []
    with open(list_path) as f:
        for line in f:
            entry = line.strip()
            # The original filter tested the raw line, which always holds at
            # least '\n', so blank lines leaked through as '' entries.
            if not entry or entry.startswith('#'):
                continue
            # `',' or ' '` evaluates to ',' only; split on the comma first and
            # then on whitespace so both separators are honoured.
            tokens = entry.split(',')[0].split()
            if tokens:
                names.append(tokens[0])
    return names


# Build the list path from DATA_ROOT_DIR so the dataset root is defined once.
ronin_data_list = _read_sequence_names(os.path.join(DATA_ROOT_DIR, 'self_sup_ronin_train_list.txt'))

# Mini-batch size used by the training loader (also logged as a hyperparameter).
batch_size = 128

# Dataset items are (feature, target, seq_id, frame_id) tuples.
# A feature is a 6x200 array: rows 0-2 hold gyro and rows 3-5 hold accel
# (gravity is NOT subtracted). Gyro and accel are both expressed in a
# gravity-aligned world frame whose yaw is arbitrary but stays consistent
# across the 200 frames.
ronin_train_dataset = data_glob_speed.StridedSequenceDataset(
    data_glob_speed.GlobSpeedSequence,
    DATA_ROOT_DIR,
    ronin_data_list,
    cache_path='datasets/cache',
)

# Shuffled loader over the training windows.
train_loader = DataLoader(dataset=ronin_train_dataset, batch_size=batch_size, shuffle=True)

# Define model

In [3]:
# --- Encoder hyperparameters ---
latent_dim = 64
first_chan_size = 64
last_chan_size = 512
encoder_fc_dim = 256

# --- Velocity head / optimisation hyperparameters ---
vel_fc_dims = [512]
dropout = 0.5
lr = 1e-4

# Encoder over 6-channel inputs, placed on the selected device.
encoder = cnn_vae_model.ResNetEncoder(
    feature_dim=6,
    latent_dim=latent_dim,
    first_channel_size=first_chan_size,
    last_channel_size=last_chan_size,
    fc_dim=encoder_fc_dim,
)
encoder = encoder.to(device)

# Regressor built around the encoder; it produces 2 outputs per sample.
vel_model = cnn_vae_model.VelocityRegressor(encoder, vel_fc_dims, num_outputs=2, dropout=dropout)
vel_model = vel_model.to(device)

# Supervised training

In [4]:
def get_model_name():
    """Build the run/checkpoint name from the current hyperparameters.

    Reads the module-level `lr` and `vel_fc_dims` and returns
    "baseline_lr_<lr>_FC" followed by "_<dim>" for every entry of
    `vel_fc_dims`, in order.
    """
    fc_suffix = "".join("_{}".format(width) for width in vel_fc_dims)
    return "baseline_lr_{}_FC".format(lr) + fc_suffix

# Start a W&B run for this training session. `project` selects where the run
# is logged; `config` records the hyperparameters alongside the run.
wandb_run = wandb.init(
    project="Baseline-RONIN-dataset-supervised",
    config=dict(
        vel_fc_dims=vel_fc_dims,
        dropout=dropout,
        lr=lr,
        latent_dim=latent_dim,
        encoder_first_chan_size=first_chan_size,
        encoder_last_chan_size=last_chan_size,
        encoder_fc_dim=encoder_fc_dim,
        batch_size=batch_size,
    ),
)

# Supervised training loop: MSE regression of the velocity targets with Adam,
# plus a plateau scheduler that lowers the learning rate when the epoch loss
# stops improving.
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(vel_model.parameters(), lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=10, verbose=True, eps=1e-12)

# `start_epoch` may already exist in the kernel (e.g. set before resuming);
# otherwise begin from epoch 0.
if 'start_epoch' not in locals():
    start_epoch = 0

max_epochs = 5000
best_loss = np.inf
train_losses_all = []
for epoch in range(start_epoch, max_epochs):
    start_t = time.time()
    vel_model.train()
    train_outs, train_targets = [], []
    for batch_id, (feat, targ, _, _) in enumerate(train_loader):
        feat, targ = feat.to(device), targ.to(device)
        optimizer.zero_grad()
        pred = vel_model(feat)
        # Keep detached copies so per-output losses can be computed over the
        # whole epoch once it finishes.
        train_outs.append(pred.detach().cpu().numpy())
        train_targets.append(targ.detach().cpu().numpy())
        # MSELoss with its default reduction already yields a scalar mean, so
        # no additional torch.mean() is needed.
        loss = criterion(pred, targ)
        loss.backward()
        optimizer.step()
    train_outs = np.concatenate(train_outs, axis=0)
    train_targets = np.concatenate(train_targets, axis=0)
    # Mean squared error per output column over every sample of the epoch.
    train_losses = np.average((train_outs - train_targets) ** 2, axis=0)

    end_t = time.time()
    print('-------------------------')
    print('Epoch {}, time usage: {:.3f}s, average loss: {}/{:.6f}'.format(
        epoch, end_t - start_t, train_losses, np.average(train_losses)))
    avg_loss = np.average(train_losses)
    train_losses_all.append(avg_loss)

    # Feed the epoch loss to the plateau scheduler. Without this call the
    # scheduler is never consulted and the learning rate never decays.
    scheduler.step(avg_loss)

    wandb_run.log({"vel_x_loss": train_losses[0],
                   "vel_y_loss": train_losses[1],
                   "avg_loss": avg_loss})

    # Checkpoint only when the epoch improves on the best loss seen so far.
    if avg_loss < best_loss:
        best_loss = avg_loss
        utils.save_states(get_model_name(), epoch, vel_model, optimizer)

-------------------------
Epoch 0, time usage: 106.973s, average loss: [0.04673219 0.04989984]/0.048316
Model saved to  checkpoints/baseline_lr_0.0001_FC_512/model-00000.pt
-------------------------
Epoch 1, time usage: 107.221s, average loss: [0.02204478 0.02442742]/0.023236
Model saved to  checkpoints/baseline_lr_0.0001_FC_512/model-00001.pt
-------------------------
Epoch 2, time usage: 107.683s, average loss: [0.01658193 0.01861861]/0.017600
Model saved to  checkpoints/baseline_lr_0.0001_FC_512/model-00002.pt
-------------------------
Epoch 3, time usage: 107.495s, average loss: [0.01386032 0.01561496]/0.014738
Model saved to  checkpoints/baseline_lr_0.0001_FC_512/model-00003.pt
-------------------------
Epoch 4, time usage: 107.766s, average loss: [0.01213281 0.0136878 ]/0.012910
Model saved to  checkpoints/baseline_lr_0.0001_FC_512/model-00004.pt
-------------------------
Epoch 5, time usage: 107.827s, average loss: [0.01094766 0.01230246]/0.011625
Model saved to  checkpoints/bas

KeyboardInterrupt: 

wandb: Network error (ReadTimeout), entering retry loop.
