In [1]:
from varyingsim.osi import OSI
from varyingsim.box import BoxEnv
from varyingsim.util.buffers import TrajBuffer

import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn

## First we will train the OSI and prediction models separately

In [2]:
# Experiment configuration shared by the cells below.
seed = 0     # fixed seed so a Restart & Run All is repeatable
n_traj = 20  # number of training trajectories to collect
n_val = 5    # number of validation trajectories to collect
T = 2000     # environment steps per trajectory
h = 16       # history length handed to get_traj_batch / the OSI model
shapes = [(2,)] * 5 + [()] # qpos, qvel, prevqpos, prevqvel, action, friction
# Use the GPU when there is one, otherwise fall back to the CPU instead of
# crashing on the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed the global NumPy and PyTorch generators.
# NOTE(review): whether BoxEnv draws from these generators is not visible
# here -- confirm if exact trajectory reproducibility is required.
np.random.seed(seed)
torch.manual_seed(seed)


In [3]:
def obs_to_datum(obs, prev_obs, a):
    """
    Turn two consecutive observations and the action into one buffer datum.

    According to the original note, an observation is laid out as
    qpos, qvel, box fric, floor fric, mass, gear. This function reads
    entries 0:2 (xy), 3:5 (xy velocity) and, from prev_obs only, entry 7
    (friction).

    Returns the tuple (xy, xy_vel, xy_prev, xy_vel_prev, act, friction),
    where act is `a` converted to a numpy array.
    """
    pos_idx = slice(0, 2)
    vel_idx = slice(3, 5)
    friction_idx = 7

    current = (obs[pos_idx], obs[vel_idx])
    previous = (prev_obs[pos_idx], prev_obs[vel_idx])
    friction = prev_obs[friction_idx] # TODO: should this be previous?
    act = np.array(a)

    return (*current, *previous, act, friction)

def get_data(env, buffer, T):
    """
    Roll out one trajectory of T steps in env and append it to buffer.

    The env is reset, a new trajectory is opened in the buffer, and at every
    step an open-loop action [sin(t / 100), cos(t / 100)] is applied. Each
    transition is converted with obs_to_datum and stored with add_datum.
    The reward, done flag and info returned by env.step are ignored.
    """
    scale = 100
    current_obs = env.reset()
    buffer.set_new_traj()

    step = 0
    while step < T:
        # we should do something more sophisticated
        phase = step / scale
        a = [np.sin(phase), np.cos(phase)]
        next_obs, _rew, _done, _info = env.step(a)
        buffer.add_datum(obs_to_datum(next_obs, current_obs, a))
        current_obs = next_obs
        step += 1

def buffer_to_osi_torch(batch):
    """
        Convert a trajectory batch from a TrajBuffer into OSI training tensors.

        batch[0] and batch[1] (history of xy and xy_vel) are joined along the
        last axis to form the input of size N x h x d_in; the final time step
        of batch[-1] (friction) becomes the N x 1 target. Both are returned
        as float32 tensors.
    """
    pos_hist, vel_hist, fric_hist = batch[0], batch[1], batch[-1]
    # Values are unused; the unpack only succeeds for a 3-D (N, h, d) history.
    _n_samples, _hist_len, _ = pos_hist.shape

    history = np.concatenate((pos_hist, vel_hist), axis=-1)
    latest_friction = fric_hist[:, -1]

    inputs = torch.from_numpy(history).float()
    targets = torch.from_numpy(latest_friction).float().unsqueeze(-1)
    return inputs, targets

def set_friction_sin(env, t, scale=213):
    """
    Set the env's floor friction to a sinusoid of t.

    The value is sin(t / scale) * 0.15 + 1.15, i.e. it oscillates around
    1.15 with amplitude 0.15.
    """
    amplitude = 0.15
    offset = 1.15
    phase = t / scale
    env.set_floor_friction(np.sin(phase) * amplitude + offset)

def set_friction_step(env, t, scale=50):
    """
    Set the env's floor friction to a square wave of t.

    Time is cut into windows of `scale` steps; even-numbered windows use
    friction 1.0 and odd-numbered windows use 1.3.
    """
    window = t // scale
    friction = 1.0 if window % 2 == 0 else 1.3
    env.set_floor_friction(friction)

def visual_eval(buffer, model, h, device='cuda'):
    """
    Print the model's friction estimates next to the ground truth.

    Draws 4 length-h trajectory windows from `buffer`, converts them with
    buffer_to_osi_torch, runs `model` on `device` and prints the prediction
    tensor followed by the target tensor.
    """
    # Sample from the buffer that was passed in, not from a notebook global.
    batch = buffer.get_traj_batch(4, h)
    x, y = buffer_to_osi_torch(batch)
    x = x.to(device)
    y = y.to(device)
    # Inspection only: no autograd graph is needed.
    with torch.no_grad():
        y_hat = model(x)
    print(y_hat)
    print(y)

def eval(buffer, model, h, device='cuda'):
    """
    Return the MSE of `model`'s friction estimate on windows from `buffer`.

    Draws len(buffer) length-h trajectory windows from `buffer`, converts
    them with buffer_to_osi_torch, runs `model` on `device` and returns the
    mean squared error against the target as a Python float.

    Note: the name shadows the builtin `eval`; it is kept so existing calls
    keep working.
    """
    # TODO: should have a method that just returns all trajectories concated
    # Sample from the buffer that was passed in, not from a notebook global.
    batch = buffer.get_traj_batch(len(buffer), h)
    x, y = buffer_to_osi_torch(batch)
    x = x.to(device)
    y = y.to(device)
    # Evaluation only: skip building the autograd graph.
    with torch.no_grad():
        y_hat = model(x)
        loss = F.mse_loss(y_hat, y)
    return loss.item()

def train_osi(env, train_buffer, val_buffer, h, *, lr=1e-3, batch_size=64,
              n_iters=6000, print_iter=50, val_batch_size=128):
    """
    Train an OSI model to regress the latest friction from an h-step history.

    Parameters
    ----------
    env : environment whose `model.nq` / `model.nu` size the network input.
    train_buffer, val_buffer : buffers providing `get_traj_batch(n, h)`.
    h : history length passed to the OSI model and to `get_traj_batch`.
    lr, batch_size, n_iters : Adam learning rate, training batch size and
        number of gradient steps.
    print_iter : a validation batch is evaluated and
        `iteration, train loss, val loss` is printed every `print_iter` steps.
    val_batch_size : number of windows in each validation batch.

    Returns the trained model, which lives on the notebook-level `device`.
    """
    # Must equal the last-axis width produced by buffer_to_osi_torch, which
    # concatenates the xy and xy_vel histories.
    # NOTE(review): confirm that nq - 1 + nu matches that width for this env.
    d_in = env.model.nq - 1 + env.model.nu
    d_param = 1 # just friction for now
    d_hidden_shared = 64
    d_hidden_osi = 256

    model = OSI(h, d_in, d_param, d_hidden_shared, d_hidden_osi)
    model = model.to(device)

    optim = torch.optim.Adam(model.parameters(), lr=lr)

    for i in range(n_iters):
        batch = train_buffer.get_traj_batch(batch_size, h)
        x, y = buffer_to_osi_torch(batch)
        x, y = x.to(device), y.to(device)
        optim.zero_grad()
        y_hat = model(x)

        loss = F.mse_loss(y_hat, y)
        loss.backward()
        optim.step()

        if i % print_iter == 0:
            val_batch = val_buffer.get_traj_batch(val_batch_size, h)
            x_val, y_val = buffer_to_osi_torch(val_batch)
            x_val, y_val = x_val.to(device), y_val.to(device)
            # Monitoring only: no gradients are needed for the validation pass.
            with torch.no_grad():
                y_hat_val = model(x_val)
                val_loss = F.mse_loss(y_hat_val, y_val)
            print(i, loss.item(), val_loss.item())
    return model

In [4]:
# set_friction_sin is handed to the env as its parameter-setting callback.
# NOTE(review): when BoxEnv invokes it is not visible here -- confirm.
env = BoxEnv(set_param_fn=set_friction_sin, rand_reset=True)

# Separate buffers for training and validation trajectories.
train_buffer = TrajBuffer(-1, shapes)
val_buffer = TrajBuffer(-1, shapes)

# Fill the training buffer first, then the validation buffer.
for buf, n_rollouts in ((train_buffer, n_traj), (val_buffer, n_val)):
    for i in range(n_rollouts):
        get_data(env, buf, T)

In [5]:
# Train the OSI model, then report its MSE on the validation buffer.
osi_model = train_osi(env, train_buffer, val_buffer, h)
# train_osi places the model on the notebook-level `device`; evaluate on the
# same device rather than relying on eval's hard-coded default.
eval_loss = eval(val_buffer, osi_model, h, device=device)
print(eval_loss)

0 1.3349215984344482 0.6299585103988647
50 0.006848664488643408 0.00786968320608139
100 0.004041353706270456 0.003887791885063052
150 0.0028496631421148777 0.0027191087137907743
200 0.0011799049098044634 0.0013772605452686548
250 0.0008055281941778958 0.0007930191932246089
300 0.0008958268444985151 0.0006273123435676098
350 0.00042417208896949887 0.0005331055726855993
400 0.0005236775032244623 0.0009426895412616432
450 0.0005248639499768615 0.000687066582031548
500 0.00041864762897603214 0.0004596546641550958
550 0.0004536719061434269 0.000613196287304163
600 0.00042437613592483103 0.0006377086974680424
650 0.0003540627076290548 0.0004107936401851475
700 0.00037777257966808975 0.0004533175379037857
750 0.0004513421154115349 0.0003319201641716063
800 0.0003512442926876247 0.0008016062201932073
850 0.0004514720058068633 0.00032077389187179506
900 0.00026805768720805645 0.00030641138437204063
950 0.00039958395063877106 0.0005286566447466612
1000 0.00033376069040969014 0.000557296443730592

In [6]:
def traj_buffer_to_pred_torch(batch):
    """
    Build prediction-model tensors from the last step of a trajectory batch.

    batch is indexed as qpos, qvel, prevqpos, prevqvel, action, friction.
    Only the final time step ([:, -1]) of each entry is used.

    Returns float32 tensors (x, y, x_prev, friction):
      x        = [prev_qvel, action]
      y        = [qpos, qvel]
      x_prev   = [prev_qpos, prev_qvel]
      friction = friction with a trailing unit axis
    """
    qpos, qvel, prev_qpos, prev_qvel, act, friction = (batch[k] for k in range(6))

    def last_step(arr):
        return arr[:, -1]

    def joined(first, second):
        merged = np.concatenate([last_step(first), last_step(second)], axis=-1)
        return torch.from_numpy(merged).float()

    x_torch = joined(prev_qvel, act)
    y_torch = joined(qpos, qvel)
    x_prev_torch = joined(prev_qpos, prev_qvel)
    friction_torch = torch.from_numpy(last_step(friction)).unsqueeze(-1).float()
    return x_torch, y_torch, x_prev_torch, friction_torch

def batch_buffer_to_pred_torch(batch):
    """
    Build prediction-model tensors from a flat (per-transition) batch.

    batch is indexed as qpos, qvel, prevqpos, prevqvel, action, friction and
    every entry is used as-is (no time axis is sliced).

    Returns float32 tensors (x, y, x_prev, friction):
      x        = [prev_qvel, action]
      y        = [qpos, qvel]
      x_prev   = [prev_qpos, prev_qvel]
      friction = friction with a trailing unit axis
    """
    qpos, qvel, prev_qpos, prev_qvel, act, friction = (batch[k] for k in range(6))

    def joined(first, second):
        merged = np.concatenate([first, second], axis=-1)
        return torch.from_numpy(merged).float()

    x_torch = joined(prev_qvel, act)
    y_torch = joined(qpos, qvel)
    x_prev_torch = joined(prev_qpos, prev_qvel)
    friction_torch = torch.from_numpy(friction).unsqueeze(-1).float()
    return x_torch, y_torch, x_prev_torch, friction_torch

def train_pred_model(train_buffer, val_buffer, include_friction=True, *,
                     lr=1e-3, batch_size=64, n_iters=6000, print_iter=50,
                     val_batch_size=128):
    """
    Train an MLP that predicts the next state as a residual on the previous one.

    The network input is [prev_qvel, action] (plus the friction when
    `include_friction` is true) and the prediction is
    `x_prev + model(x)`, compared to [qpos, qvel] with an MSE loss.

    Parameters
    ----------
    train_buffer, val_buffer : buffers providing `get_batch(n)`.
    include_friction : append the ground-truth friction to the input.
    lr, batch_size, n_iters : Adam learning rate, training batch size and
        number of gradient steps.
    print_iter : a validation batch is evaluated and
        `iteration, train loss, val loss` is printed every `print_iter` steps.
    val_batch_size : number of samples in each validation batch.

    Returns the trained model, which lives on the notebook-level `device`
    (the same one the evaluation helpers move their inputs to).
    """
    d_in = 4 + int(include_friction) # 2 for qvel, 2 for action, 1 for friction
    d_out = 4 # qpos, qvel
    d_hidden = 256

    model = nn.Sequential(
        nn.Linear(d_in, d_hidden),
        nn.ReLU(),
        nn.Linear(d_hidden, d_hidden),
        nn.ReLU(),
        nn.Linear(d_hidden, d_hidden),
        nn.ReLU(),
        nn.Linear(d_hidden, d_out)
    )

    model = model.to(device)

    optim = torch.optim.Adam(model.parameters(), lr=lr)

    for i in range(n_iters):
        batch = train_buffer.get_batch(batch_size)
        x, y, x_prev, friction = batch_buffer_to_pred_torch(batch)
        if include_friction:
            x = torch.cat([x, friction], dim=1)
        x, y = x.to(device), y.to(device)
        x_prev = x_prev.to(device)
        optim.zero_grad()
        y_hat = x_prev + model(x)

        loss = F.mse_loss(y_hat, y)
        loss.backward()
        optim.step()

        if i % print_iter == 0:
            val_batch = val_buffer.get_batch(val_batch_size)
            x_val, y_val, x_prev_val, fric_val = batch_buffer_to_pred_torch(val_batch)
            if include_friction:
                x_val = torch.cat([x_val, fric_val], dim=1)
            x_val, y_val = x_val.to(device), y_val.to(device)
            x_prev_val = x_prev_val.to(device)
            # Monitoring only: no gradients are needed for the validation pass.
            with torch.no_grad():
                y_hat_val = x_prev_val + model(x_val)
                val_loss = F.mse_loss(y_hat_val, y_val)
            print(i, loss.item(), val_loss.item())
    return model

In [8]:
# Two predictors trained by the same routine: one receives the ground-truth
# friction as an extra input, the other does not.
pred_model_gt = train_pred_model(train_buffer, val_buffer, include_friction=True)
pred_model_no_fric = train_pred_model(
    train_buffer,
    val_buffer,
    include_friction=False,
)

0 0.005612609442323446 0.0019001023611053824
50 0.0003082512994296849 0.00023471939493902028
100 0.0002265569637529552 0.00016066402895376086
150 0.0003321518306620419 0.00022417775471694767
200 0.00018423606525175273 0.00022708275355398655
250 0.00022153067402541637 0.00023813624284230173
300 0.00017115396622102708 0.00018756510689854622
350 0.000176668370841071 0.00020773083087988198
400 0.00022416323190554976 0.00021662807557731867
450 0.00026084529235959053 0.00023394257004838437
500 0.00022215713397599757 0.0001468068512622267
550 0.00019189860904589295 0.00022631704632658511
600 0.0001440146006643772 0.00019876024452969432
650 0.0003493807162158191 0.0002451499458402395
700 0.00018401470151729882 0.00020660311565734446
750 0.00013711294741369784 0.0002692607813514769
800 0.0001636888482607901 0.0001646681921556592
850 0.00024098370340652764 0.0001810419635148719
900 0.00020022479293402284 0.0002488477330189198
950 0.00019497897301334888 0.00020818927441723645
1000 0.0002295061422

In [9]:
def eval_osi_pred(osi_model, pred_model, buffer, h, val_size=5000):
    """
    Evaluate the prediction model when fed the OSI model's friction estimate.

    Draws `val_size` length-h trajectory windows from `buffer`. The OSI
    model estimates friction from the history, that estimate is appended to
    [prev_qvel, action] of the last step, and the prediction
    `x_prev + pred_model(x)` is compared to [qpos, qvel].

    Returns (pos_mse, fric_mse) as Python floats: the state-prediction MSE
    and the MSE of the friction estimate against the ground truth.
    """
    batch = buffer.get_traj_batch(val_size, h)
    # Only the history is needed from the OSI conversion; the ground-truth
    # friction comes from the prediction conversion below.
    x_hist, _ = buffer_to_osi_torch(batch)
    x, y, x_prev, gt_fric = traj_buffer_to_pred_torch(batch)

    x_hist = x_hist.to(device)
    gt_fric = gt_fric.to(device)
    x = x.to(device)
    x_prev = x_prev.to(device)
    y = y.to(device)

    # Evaluation only: skip building the autograd graph for the whole batch.
    with torch.no_grad():
        estimated_fric = osi_model(x_hist)
        x = torch.cat([x, estimated_fric], dim=1)

        y_hat = x_prev + pred_model(x)

        pos_mse = F.mse_loss(y_hat, y)
        fric_mse = F.mse_loss(estimated_fric, gt_fric)
    return pos_mse.item(), fric_mse.item()

def eval_pred_gt(pred_model, buffer, val_size=5000, include_friction=True):
    """
    Evaluate a prediction model using ground-truth friction (or none at all).

    Draws `val_size` samples from `buffer`, optionally appends the
    ground-truth friction to [prev_qvel, action], and compares
    `x_prev + pred_model(x)` to [qpos, qvel].

    `include_friction` must match the setting the model was trained with,
    since it decides the width of the model input.

    Returns the state-prediction MSE as a Python float.
    """
    batch = buffer.get_batch(val_size)
    x, y, x_prev, gt_fric = batch_buffer_to_pred_torch(batch)

    x = x.to(device)
    x_prev = x_prev.to(device)
    y = y.to(device)

    if include_friction:
        # Move the friction to the device only when it is actually used.
        x = torch.cat([x, gt_fric.to(device)], dim=1)

    # Evaluation only: skip building the autograd graph for the whole batch.
    with torch.no_grad():
        y_hat = x_prev + pred_model(x)
        pos_mse = F.mse_loss(y_hat, y)
    return pos_mse.item()


In [11]:
# State-prediction MSE on the validation buffer for three settings:
#   1. friction-aware predictor fed the OSI estimate of friction,
#   2. friction-aware predictor fed the ground-truth friction,
#   3. predictor trained without any friction input.
pos_mse_osi, fric_mse = eval_osi_pred(osi_model, pred_model_gt, val_buffer, h)
pos_gt_mse = eval_pred_gt(pred_model_gt, val_buffer, include_friction=True)
pos_no_fric_mse = eval_pred_gt(
    pred_model_no_fric,
    val_buffer,
    include_friction=False,
)

print(pos_mse_osi, pos_gt_mse, pos_no_fric_mse)

0.00019378271827008575 0.0001789461966836825 0.00017983032739721239
