In [None]:
# Render matplotlib figures (e.g. from vis_rollout) inline in the notebook.
%matplotlib inline
# Automatically reload edited Python modules (such as the project `lib` package)
# before each cell executes, so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import time

from lib import dpf, panda_models, panda_datasets, panda_training
from lib.utils import file_utils, torch_utils, misc_utils

print(torch.__version__, np.__version__)

In [None]:
# Experiment configuration.
# `experiment_name` identifies this run to the TrainingBuddy (logs/checkpoints);
# `dataset_args` holds the sensor options forwarded via **dataset_args to every
# dataset / trajectory loader in this notebook.
experiment_name = "dpf_all_sensors"
dataset_args = dict(
    use_proprioception=True,
    use_vision=True,
    vision_interval=10,
)

In [None]:
# Create models & training buddy.
# The particle-filter network is composed from a dynamics sub-model and a
# measurement sub-model; the TrainingBuddy is given the names of the optimizers
# used by the training cells below ("e2e", "dynamics", "measurement") plus the
# log and checkpoint directories.
# NOTE(review): state_noise is a plain float here. If a per-dimension tuple was
# intended, it would need to be written (0.05,) -- confirm against
# PandaSimpleDynamicsModel.
dynamics_model = panda_models.PandaSimpleDynamicsModel(state_noise=0.05)
measurement_model = panda_models.PandaMeasurementModel(units=32)
pf_model = dpf.ParticleFilterNetwork(dynamics_model, measurement_model)

buddy_optimizer_names = ["e2e", "dynamics", "measurement"]
buddy = torch_utils.TrainingBuddy(
    experiment_name,
    pf_model,
    optimizer_names=buddy_optimizer_names,
    log_dir="logs/pf",
    checkpoint_dir="checkpoints/pf",
)

# Dynamics Model Pre-Training (currently disabled)

In [None]:
# Dynamics pre-training dataset -- currently DISABLED (the whole cell is commented out).
# Uncomment it together with the training cell below to pre-train the dynamics model.
# NOTE(review): unlike the measurement / end-to-end datasets, the paths here are
# passed without a per-file count -- confirm PandaDynamicsDataset's signature
# before re-enabling.
# dynamics_trainset = panda_datasets.PandaDynamicsDataset(
# #     "data/pull-test-tiny.hdf5",
#     "data/pull-test.hdf5",
#     "data/push-test.hdf5",
#     **dataset_args
# )

In [None]:
# Dynamics pre-training loop -- DISABLED twice over: the cell is commented out, and
# even if uncommented it iterates range(0), i.e. runs zero epochs.
# Requires `dynamics_trainset` from the previous cell.
# dataloader = torch.utils.data.DataLoader(dynamics_trainset, batch_size=256, shuffle=True, num_workers=2)

# for i in range(0):
#     print("Training epoch", i)
#     panda_training.train_dynamics(buddy, pf_model, dataloader, log_interval=10)
#     print()


# Measurement Model Pre-Training

In [None]:
# Measurement-model pre-training data drawn from both the pull and push recordings.
# The integer paired with each path is presumably a per-file sample cap (compare
# `max_count` in the evaluation cell) -- TODO confirm in panda_datasets.
measurement_sources = (
    ("data/pull-test.hdf5", 100000),
    ("data/push-test.hdf5", 100000),
)
measurement_trainset = panda_datasets.PandaMeasurementDataset(
    *measurement_sources, **dataset_args
)

In [None]:
# Measurement-model pre-training curriculum: a fixed number of epochs at each
# batch size, moving through progressively larger batches. When the loop finishes,
# `measurement_trainset_loader` is the loader for the last (largest) batch size.
measurement_batch_sizes = (32, 64, 256)
measurement_epochs_per_stage = 5

for batch_size in measurement_batch_sizes:
    measurement_trainset_loader = torch.utils.data.DataLoader(
        measurement_trainset, batch_size=batch_size, shuffle=True, num_workers=2
    )
    for i in range(measurement_epochs_per_stage):
        # Epoch numbering restarts for every stage, so include the batch size to
        # keep the log unambiguous.
        print("Training epoch {} (batch size {})".format(i, batch_size))
        panda_training.train_measurement(
            buddy, pf_model, measurement_trainset_loader, log_interval=100
        )
        print()



In [None]:
# Snapshot the pre-trained model under a label. The "Model without end-to-end
# training" evaluation cell reloads this label to compare against the
# end-to-end-trained model.
buddy.save_checkpoint(label="before_e2e_training")

# End-to-end Training

In [None]:
# Create end-to-end dataset from both recordings.
# subsequence_length / particle_count / particle_variances are forwarded to
# PandaParticleFilterDataset; their exact semantics live in panda_datasets.
# Note particle_variances is a one-element tuple.
e2e_sources = (
    ("data/pull-test.hdf5", 10000),
    ("data/push-test.hdf5", 10000),
)
e2e_trainset = panda_datasets.PandaParticleFilterDataset(
    *e2e_sources,
    subsequence_length=4,
    particle_count=50,
    particle_variances=(0.2,),
    **dataset_args
)

In [None]:
# Train end-to-end.
# The measurement model is left trainable while the dynamics model is frozen
# (via the pf_model.freeze_* flags).
pf_model.freeze_measurement_model = False
pf_model.freeze_dynamics_model = True

# Curriculum: a fixed number of epochs at each batch size. When the loop
# finishes, `e2e_trainset_loader` is the loader for the last batch size.
e2e_batch_sizes = (32, 64)
e2e_epochs_per_stage = 5

for batch_size in e2e_batch_sizes:
    e2e_trainset_loader = torch.utils.data.DataLoader(
        e2e_trainset, batch_size=batch_size, shuffle=True, num_workers=2
    )
    for i in range(e2e_epochs_per_stage):
        # Epoch numbering restarts per stage; include the batch size for clarity.
        print("Training epoch {} (batch size {})".format(i, batch_size))
        panda_training.train_e2e(
            buddy, pf_model, e2e_trainset_loader, loss_type="mse", log_interval=100
        )
        print()

In [None]:
# Save the end-to-end-trained model without a label; the evaluation cells call
# buddy.load_checkpoint() with no label, which presumably restores this latest
# checkpoint -- confirm in torch_utils.TrainingBuddy.
buddy.save_checkpoint()

# Model Evaluation

In [None]:
# (label, trajectories) pairs evaluated by the cells below. This list is reset
# whenever this cell runs, so re-running the cell does not duplicate entries.
eval_trajectories_list = []


def load_trajectories(label, validation, include_pull, include_push, max_count=10):
    """Load one evaluation trajectory set and register it under `label`.

    Args:
        label: Display name stored alongside the loaded trajectories.
        validation: If true, read the "-small" validation files; otherwise read
            the training files.
        include_pull: Whether to include the pull recording.
        include_push: Whether to include the push recording.
        max_count: Integer paired with each file path when calling
            panda_datasets.load_trajectories (presumably a per-file cap).

    Side effect: appends (label, trajectories) to the module-level
    eval_trajectories_list. Returns None.

    NOTE(review): the validation files ("*-test-small.hdf5") may be subsets of
    the training files ("*-test.hdf5") -- confirm they are actually disjoint.
    """
    if validation:
        pull_path, push_path = "data/pull-test-small.hdf5", "data/push-test-small.hdf5"
    else:
        pull_path, push_path = "data/pull-test.hdf5", "data/push-test.hdf5"

    # Keep pull before push, matching the order the files are passed to the loader.
    selected_files = [
        (path, max_count)
        for path, wanted in ((pull_path, include_pull), (push_path, include_push))
        if wanted
    ]

    loaded = panda_datasets.load_trajectories(*selected_files, **dataset_args)
    eval_trajectories_list.append((label, loaded))


# (label, validation, include_pull, include_push) for every evaluation set.
for set_label, use_validation, with_pull, with_push in (
    ("Validation all", True, True, True),
    ("Validation pull", True, True, False),
    ("Validation push", True, False, True),
    ("Training all", False, True, True),
    ("Training pull", False, True, False),
    ("Training push", False, False, True),
):
    load_trajectories(
        set_label,
        validation=use_validation,
        include_pull=with_pull,
        include_push=with_push,
    )

### Final Model

In [None]:
# Load the latest version of the model (no label -- presumably the checkpoint
# saved after end-to-end training) and evaluate it on every registered
# trajectory set, printing a banner before each rollout visualization.
buddy.load_checkpoint()

banner_rule = "###############################"
for label, trajectories in eval_trajectories_list:
    print(banner_rule)
    print(banner_rule)
    print("##", label)
    print(banner_rule)
    print(banner_rule)
    pred, actual = panda_training.rollout(
        pf_model,
        trajectories,
        start_time=0,
        max_timesteps=1000,
        particle_count=200,
        noisy_dynamics=True,
    )
    panda_training.vis_rollout(pred, actual)

### Model without end-to-end training

In [None]:
# Evaluate the model as it was before end-to-end training, then restore the
# current model.

# Back up the current model so it can be restored afterwards.
buddy.save_checkpoint()

banner_rule = "###############################"
try:
    # Load the pre-end-to-end-training version of the model & evaluate.
    buddy.load_checkpoint(label="before_e2e_training")
    for label, trajectories in eval_trajectories_list:
        print(banner_rule)
        print(banner_rule)
        print("##", label)
        print(banner_rule)
        print(banner_rule)
        pred, actual = panda_training.rollout(
            pf_model,
            trajectories,
            start_time=0,
            max_timesteps=1000,
            particle_count=200,
            noisy_dynamics=True,
        )
        panda_training.vis_rollout(pred, actual)
finally:
    # Restore the backed-up model even if evaluation raises; otherwise pf_model
    # would silently keep the pre-e2e weights for any later cells.
    buddy.load_checkpoint()