In [None]:
import sys
import os
import copy

import torch
import torch.nn as nn
from torchdiffeq import odeint

import numpy as np

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch.optim as optim
import torch.utils.data as data_utils

import svg
from svg.dx import SeqDx

sns.set(style='whitegrid', font_scale=1.75)

In [None]:
# sys.path.append('/home/fakoor/code/remote/shift_uncertainty_rl_lab/code')
# from misc import buffer
# import pickle as pkl

# root_dir = '/home/fakoor/code/remote/shift_uncertainty_rl_lab/code/ck'
# task = 'Hopper-v3'
# agent = 'sac'
# exp_name = 'test-tcap-500k'

# sac_update_res = int(5)
# sac_update_total = int(500)
# subsample_ratio = 1.

# ref_locs = [25, 300, 475]
# num_seeds = 1

# start = sac_update_res
# stop = sac_update_total + sac_update_res
# update_range = list(range(start, stop, sac_update_res))

# time_capsules = []
# for n_update in update_range:
#     tcap_file = '_'.join([
#         task.lower(),
#         agent,
#         exp_name
#     ])
#     tcap_path = os.path.join(root_dir, tcap_file, f'agent-tcap-{n_update}k.pkl')
#     try:
#         with open(tcap_path, 'rb') as f:
#             time_capsules.append(pkl.load(f))
#     except FileNotFoundError:
#         pass

    
# train_subseq_len = 4
# test_subseq_len = 4    

# (train_x, train_u), (test_x, test_u) = buffer.buffers_to_dataset(
#     [tcap['visit_buffer'] for tcap in time_capsules],
#     train_subseq_len,
#     test_subseq_len,
#     subsample_ratio=1.
# )

## Prepare data

In [None]:
# Task label; reused below as `env_name` in the model config and as the plot title.
task = 'Hopper-v3'

# Each split is a pair of arrays: states (x) and controls (u).
train_x, train_u = (
    np.load("../data/HopperFull-v0_cl20_xdata.npy"),
    np.load("../data/HopperFull-v0_cl20_udata.npy"),
)

test_x, test_u = (
    np.load("../data/HopperFull-v0_episodes_xdata.npy"),
    np.load("../data/HopperFull-v0_episodes_udata.npy"),
)

In [None]:
# All arrays are unpacked as rank-3: (number of sequences, sequence length, feature size).
(num_train, train_subseq_len, x_size), (_, _, u_size) = train_x.shape, train_u.shape

# Test sequences may have a different length than the training ones.
num_test, test_subseq_len, _ = test_x.shape

In [None]:
# normalize
# NOTE: train_x / train_u / test_x / test_u are overwritten in place, so re-running
# this cell without reloading the data normalises already-normalised arrays again.
def _flat_moments(arr, width):
    """Per-feature mean and std of `arr` viewed as (-1, width); the std is floored at 1e-6."""
    rows = arr.reshape(-1, width)
    return rows.mean(0), np.clip(rows.std(0), a_min=1e-6, a_max=None)


def _as_float_dataset(x_arr, u_arr):
    """Wrap a (states, controls) pair of arrays as a float32 TensorDataset."""
    return data_utils.TensorDataset(
        torch.tensor(x_arr).float(),
        torch.tensor(u_arr).float(),
    )


# Statistics come from the training split only and are applied to both splits.
x_mean, x_std = _flat_moments(train_x, x_size)
u_mean, u_std = _flat_moments(train_u, u_size)

train_x, test_x = ((split - x_mean) / x_std for split in (train_x, test_x))
train_u, test_u = ((split - u_mean) / u_std for split in (train_u, test_u))


train_dataset = _as_float_dataset(train_x, train_u)

test_dataset = _as_float_dataset(test_x, test_u)

In [None]:
class RecurrentNetwork(nn.Module):
    """Encoder -> GRU -> decoder stack that is advanced one time step per call.

    The encoder and decoder are built with `svg.utils.mlp` (argument order is
    presumably (in, hidden, out, depth) -- confirm against that helper).

    Args:
        input_size: width of the 2-D input fed to the encoder.
        output_size: width of the decoder output.
        enc_hidden_size, enc_depth: forwarded to the encoder `mlp` call.
        rec_hidden_size: GRU input and hidden width (also the encoder output width).
        rec_depth: number of stacked GRU layers.
        dec_hidden_size, dec_depth: forwarded to the decoder `mlp` call.
    """

    def __init__(self, input_size, output_size, enc_hidden_size, rec_hidden_size, dec_hidden_size,
                 enc_depth, rec_depth, dec_depth):
        super().__init__()
        self.encoder = svg.utils.mlp(input_size, enc_hidden_size, rec_hidden_size, enc_depth)
        self.recurrent = nn.GRU(rec_hidden_size, rec_hidden_size, num_layers=rec_depth)
        self.decoder = svg.utils.mlp(rec_hidden_size, dec_hidden_size, output_size, dec_depth)
        self.rec_hidden_size = rec_hidden_size
        self.rec_depth = rec_depth

    def init_hidden_state(self, inputs):
        """Return an all-zero GRU state of shape (rec_depth, batch, rec_hidden_size).

        Only the batch size, dtype and device of `inputs` (which must be 2-D) are used.
        """
        assert inputs.dim() == 2
        n_batch = inputs.size(0)
        h = torch.zeros(self.rec_depth, n_batch, self.rec_hidden_size).to(inputs)
        return h

    def forward(self, inputs, hidden_state, return_hidden=False):
        """Advance the network by a single time step.

        Args:
            inputs: 2-D tensor (batch, input_size).
            hidden_state: GRU state as produced by `init_hidden_state` or by a
                previous call made with `return_hidden=True`.
            return_hidden: when False (default, the historical behaviour) only the
                decoded output is returned and the updated GRU state is discarded;
                when True the tuple `(outputs, new_hidden_state)` is returned so the
                caller can carry the recurrence to the next step.

        Returns:
            `outputs`, or `(outputs, new_hidden_state)` if `return_hidden` is True.
        """
        assert inputs.dim() == 2
        # The GRU expects a leading sequence axis; each call is a length-1 sequence.
        inputs_emb = self.encoder(inputs).unsqueeze(0)
        outputs_emb, hidden_state = self.recurrent(inputs_emb, hidden_state)
        outputs = self.decoder(outputs_emb.squeeze(0))
        if return_hidden:
            return outputs, hidden_state
        return outputs
    

class NN(nn.Module):
    """MLP vector field dz/dt = net([z, u(t)]) integrated with `torchdiffeq.odeint`.

    Args:
        s_dim: state width (also the output width of the MLP).
        a_dim: control width.
        hidden_size: width of every hidden layer.
        num_layers: number of (Linear, Tanh) pairs before the output layer.
        **kwargs: forwarded unchanged to `nn.Module.__init__`.
    """

    def __init__(
        self,
        s_dim = 12,
        a_dim = 3,
        hidden_size = 256,
        num_layers = 2,
        **kwargs
    ):
        super().__init__(**kwargs)

        layers = []
        in_features = s_dim + a_dim
        for _ in range(num_layers):
            layers += [nn.Linear(in_features, hidden_size), nn.Tanh()]
            in_features = hidden_size
        layers.append(nn.Linear(in_features, s_dim))
        self.net = nn.Sequential(*layers)

    def dx(self, t, z):
        """Vector field evaluated at time `t` and state `z`.

        The control is piecewise constant: the one stored for the closest knot
        in `self.ts` that is not after `t` is used.
        """
        diff = (self.ts - t).pow(2)[t >= self.ts]
        neigh = self.ts[t >= self.ts][diff.argmin()]
        u_t = self.u[neigh.item()]
        dz_dt = self.net(torch.cat([z, u_t], dim=-1))
        return dz_dt

    def integrate(self, z0, u, ts, tol=1e-4, method="euler"):
        """Roll the state forward from `z0` over the knots `ts`.

        Args:
            z0: initial state; flattened to (batch, -1).
            u: controls indexed as (batch, step, a_dim); step `i` is applied
                from `ts[i]` on.
            ts: 1-D tensor of time knots passed to `odeint`.
            tol: relative tolerance handed to `odeint` as `rtol`.
            method: `odeint` solver name.

        Returns:
            States at `ts[1:]`, laid out (batch, len(ts) - 1, state).
        """
        controls = u.permute(1, 0, 2)  # N x T x A -> T x N x A
        last = len(controls) - 1
        # Every knot gets an entry. Knots beyond the last supplied control (usually
        # just the final one, since there is one more knot than controls) hold the
        # last control, so a solver that evaluates the field at or past ts[-1] no
        # longer hits a missing key.
        self.u = {t.item(): controls[min(i, last)] for i, t in enumerate(ts)}
        self.ts = ts

        bs = z0.shape[0]
#         odeint_options = dict(step_size=1e-3, interp='linear')
        odeint_options = {}
        zt = odeint(self.dx, z0.reshape(bs, -1), ts, rtol=tol, method=method, options=odeint_options)
        zt = zt.permute(1, 0, 2)  # T x N x D -> N x T x D
        # Drop the initial condition; only predicted states are returned.
        return zt[:, 1:]

    def forward(self, z0, u, ts):
        """Alias for `integrate` with the default tolerance and solver."""
        return self.integrate(z0, u, ts)

In [None]:
class RecNODE(SeqDx):
    """`SeqDx` variant whose one-step model is used as an ODE vector field.

    Relies on attributes that are not defined in this notebook and are presumably
    provided by `SeqDx`: `init_hidden_state`, `xu_enc`, `rec`, `x_dec` and
    `rec_num_layers` -- confirm against `svg.dx`.
    """

    def integrate(self, z0, u, ts, tol=1e-4, method="euler"):
        """Roll the state forward from `z0` over the knots `ts`.

        Args:
            z0: initial state; flattened to (batch, -1).
            u: controls indexed as (batch, step, action); step `i` is applied
                from `ts[i]` on.
            ts: 1-D tensor of time knots passed to `odeint`.
            tol: relative tolerance handed to `odeint` as `rtol`.
            method: `odeint` solver name.

        Returns:
            States at `ts[1:]`, laid out (batch, len(ts) - 1, state).
        """
        # The recurrent state lives on the module and is reset for every roll-out.
        self.h = self.init_hidden_state(z0)
        controls = u.permute(1, 0, 2)  # N x T x A -> T x N x A
        last = len(controls) - 1
        # Every knot gets an entry. Knots beyond the last supplied control (usually
        # just the final one, since there is one more knot than controls) hold the
        # last control, so a solver that evaluates the field at or past ts[-1] no
        # longer hits a missing key.
        self.u = {t.item(): controls[min(i, last)] for i, t in enumerate(ts)}
        self.ts = ts

        bs = z0.shape[0]
#         odeint_options = dict(step_size=1e-3, interp='linear')
        odeint_options = {}
        zt = odeint(self.dx, z0.reshape(bs, -1), ts, rtol=tol, method=method, options=odeint_options)
        zt = zt.permute(1, 0, 2)  # T x N x D -> N x T x D
        # Drop the initial condition; only predicted states are returned.
        return zt[:, 1:]

    def dx(self, t, x):
        """Vector field evaluated at time `t` and state `x`.

        The control is piecewise constant: the one stored for the closest knot
        in `self.ts` that is not after `t` is used. When the recurrent core is
        enabled, `self.h` is advanced on every call, i.e. once per solver
        function evaluation.
        """
        diff = (self.ts - t).pow(2)[t >= self.ts]
        neigh = self.ts[t >= self.ts][diff.argmin()]
        u_t = self.u[neigh.item()]

        x_u = torch.cat([x, u_t], dim=-1)
        x_u_emb = self.xu_enc(x_u).unsqueeze(0)
        if self.rec_num_layers > 0:
            dx_dt_emb, self.h = self.rec(x_u_emb, self.h)
        else:
            # Fixed: this branch referenced the undefined name `xu_emb`.
            dx_dt_emb = x_u_emb
        dx_dt_emb = dx_dt_emb.squeeze(0)

        dx_dt = self.x_dec(dx_dt_emb)
        return dx_dt

    def forward(self, z0, u, ts):
        """Alias for `integrate` with the default tolerance and solver."""
        return self.integrate(z0, u, ts)

    def unroll(self, x, us, detach_xt=False):
        """Not supported for the ODE variant; use `integrate` instead."""
        raise NotImplementedError

    @property
    def param_groups(self):
        """Single optimizer parameter group covering all parameters, without weight decay."""
        return [{'params': self.parameters(), 'weight_decay': 0.}]

In [None]:
def _recurrent_step(model, inputs, hidden_state):
    """Advance `model` by one step and return `(outputs, new_hidden_state)`.

    `model` must expose `encoder`, `recurrent` and `decoder` sub-modules, as
    `RecurrentNetwork` does. The sub-modules are called directly because
    `RecurrentNetwork.forward` returns only the decoded output, which makes it
    impossible to carry the GRU state from one step to the next.
    """
    emb = model.encoder(inputs).unsqueeze(0)  # add the length-1 sequence axis
    out_emb, hidden_state = model.recurrent(emb, hidden_state)
    return model.decoder(out_emb.squeeze(0)), hidden_state


class RPPNet(nn.Module):
    """Average of two recurrent roll-outs: direct next-state and residual (delta).

    Both branches are `RecurrentNetwork` instances built from the same
    constructor arguments. Each branch is rolled out on its own predictions;
    only the reported trajectory is the mean of the two.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.seq_model = RecurrentNetwork(*args, **kwargs)
        self.dx_model = RecurrentNetwork(*args, **kwargs)

    def forward(self, x_0, u):
        """Predict one state per control step.

        Args:
            x_0: initial state, (batch, state).
            u: controls iterated along the first axis, i.e. (step, batch, action).

        Returns:
            Tensor (step, batch, state) of averaged predictions.
        """
        seq_hidden_state = self.seq_model.init_hidden_state(x_0)
        dx_hidden_state = self.dx_model.init_hidden_state(x_0)
        x_seq = x_dx = x_0
        pred_x = []
        for u_t in u:
            # Fixed: the hidden states are now carried across steps; previously the
            # zero initial state was fed to the GRU at every step.
            seq_input = torch.cat((x_seq, u_t), dim=-1)
            x_seq, seq_hidden_state = _recurrent_step(self.seq_model, seq_input, seq_hidden_state)

            dx_input = torch.cat((x_dx, u_t), dim=-1)
            delta, dx_hidden_state = _recurrent_step(self.dx_model, dx_input, dx_hidden_state)
            x_dx = x_dx + delta

            pred_x.append(torch.stack((x_seq, x_dx)).mean(0))

        return torch.stack(pred_x)

    @property
    def param_groups(self):
        """Optimizer parameter groups: one per branch, both without weight decay."""
        groups = [
            {'params': self.seq_model.parameters(), 'weight_decay': 0.},
            {'params': self.dx_model.parameters(), 'weight_decay': 0.},
        ]
        return groups


class AutoregNet(nn.Module):
    """Single `RecurrentNetwork` rolled out autoregressively on its own predictions."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.seq_model = RecurrentNetwork(*args, **kwargs)

    def forward(self, x_0, u):
        """Predict one state per control step.

        Args:
            x_0: initial state, (batch, state).
            u: controls iterated along the first axis, i.e. (step, batch, action).

        Returns:
            Tensor (step, batch, state) of predictions.
        """
        hidden_state = self.seq_model.init_hidden_state(x_0)
        x_t = x_0
        pred_x = []
        for u_t in u:
            x_u = torch.cat((x_t, u_t), dim=-1)
            # Fixed: carry the GRU state instead of reusing the zero initial state.
            x_t, hidden_state = _recurrent_step(self.seq_model, x_u, hidden_state)
            pred_x.append(x_t)
        return torch.stack(pred_x)

    @property
    def param_groups(self):
        """Single optimizer parameter group covering all parameters, without weight decay."""
        return [{'params': self.parameters(), 'weight_decay': 0.}]

In [None]:
# ts = (1. / train_x.shape[1]) * torch.arange(train_x.shape[1])
time_step = 1.


def _time_grid(num_steps, dt=time_step, device='cuda:0'):
    """Float tensor of `num_steps + 1` evenly spaced knots from 0 to `dt * num_steps`."""
    knots = np.linspace(0, dt * num_steps, num_steps + 1)
    return torch.tensor(knots).float().to(device)


# One more knot than there are steps in a sub-sequence: the grid includes t = 0.
train_ts = _time_grid(train_subseq_len)
test_ts = _time_grid(test_subseq_len)

# s_dim, a_dim = x.shape[-1], u.shape[-1]

In [None]:
def fit_net(net, train_loader, test_loader, num_epochs, device='cuda:0', lr=1e-3):
    """Train `net` with Adam + cosine annealing and evaluate it after every epoch.

    Two kinds of model are supported, distinguished by the presence of an
    `integrate` attribute:
      * ODE models are called as `net(x0, u, ts)` with batch-first controls and
        the module-level time grids `train_ts` / `test_ts`;
      * step models are called as `net(x0, u)` with step-first controls, and
        their step-first output is permuted back to batch-first.

    Predictions are scored against `x[:, 1:]`; the last predicted step has no
    target in the batch and is dropped.

    Args:
        net: model exposing `param_groups` (optimizer parameter groups).
        train_loader, test_loader: iterables of `(x, u)` batches, batch-first.
        num_epochs: number of passes over `train_loader`; also the cosine period.
        device: device the batches (and time grids) are moved to.
        lr: initial Adam learning rate.

    Returns:
        A list with one dict per epoch holding
          `train_mse`   - mean training loss over the epoch's batches,
          `test_mse`    - mean squared error on the test loader,
          `test_med_se` - batch-averaged median of the per-time-step error
                          (the error is averaged over batch and features first,
                          so the median is taken across time steps).
    """
    optimizer = optim.Adam(net.param_groups, lr=lr)
    lr_sched = optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=1e-6, T_max=num_epochs)

    is_ode = hasattr(net, 'integrate')
    # The time grids are globals defined elsewhere in the notebook and are only
    # needed (and only looked up) for ODE models.
    train_grid = train_ts.to(device) if is_ode else None
    test_grid = test_ts.to(device) if is_ode else None

    def _predict(x, u, grid):
        """Roll the model out from the first state; result is batch-first."""
        if is_ode:
            return net(x[:, 0], u, grid)
        return net(x[:, 0], u.permute(1, 0, 2)).permute(1, 0, 2)

    records = []
    for _ in range(num_epochs):
        avg_loss = 0
        for x, u in train_loader:
            x, u = x.to(device), u.to(device)

            z_pred = _predict(x, u, train_grid)

            optimizer.zero_grad()
            loss = (z_pred[:, :-1] - x[:, 1:]).pow(2).mean()
            loss.backward()
            optimizer.step()

            avg_loss += loss.item() / len(train_loader)
        lr_sched.step()

        test_mse, test_med_se = 0., 0.
        for x, u in test_loader:
            x, u = x.to(device), u.to(device)

            with torch.no_grad():
                z_pred = _predict(x, u, test_grid)

            # Mean over batch, then over features: one error value per time step.
            test_se = (z_pred[:, :-1] - x[:, 1:]).pow(2).mean(0).mean(-1)
            test_mse += test_se.mean().item() / len(test_loader)
            test_med_se += test_se.median().item() / len(test_loader)

        records.append(dict(train_mse=avg_loss, test_mse=test_mse, test_med_se=test_med_se))
    return records

In [None]:
num_epochs = 100
batch_size = 200

# Shuffle the training set every epoch so mini-batches are not drawn in storage
# order; DataLoader does not shuffle by default. Evaluation order is irrelevant,
# so the test loader stays sequential.
train_loader = data_utils.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = data_utils.DataLoader(test_dataset, batch_size=batch_size)

In [None]:
# Constructor arguments forwarded to both RecurrentNetwork branches of RPPNet.
rpp_net_config = {
    'input_size': x_size + u_size,  # the model consumes [state, control]
    'output_size': x_size,
    'enc_hidden_size': 512,
    'rec_hidden_size': 512,
    'dec_hidden_size': 512,
    'enc_depth': 2,
    'rec_depth': 2,
    'dec_depth': 0,
}

rpp_net = RPPNet(**rpp_net_config)
rpp_net = rpp_net.to('cuda')

# One record per epoch, turned into a frame for plotting.
rpp_records = fit_net(rpp_net, train_loader, test_loader, num_epochs)
rpp_df = pd.DataFrame(data=rpp_records)

In [None]:
# Constructor arguments forwarded to the RecurrentNetwork inside AutoregNet.
autoreg_net_config = {
    'input_size': x_size + u_size,  # the model consumes [state, control]
    'output_size': x_size,
    'enc_hidden_size': 512,
    'rec_hidden_size': 512,
    'dec_hidden_size': 512,
    'enc_depth': 2,
    'rec_depth': 2,
    'dec_depth': 0,
}

autoreg_net = AutoregNet(**autoreg_net_config)
autoreg_net = autoreg_net.to('cuda')

# One record per epoch, turned into a frame for plotting.
autoreg_records = fit_net(autoreg_net, train_loader, test_loader, num_epochs)
autoreg_df = pd.DataFrame(data=autoreg_records)

In [None]:
# Keyword arguments for the SeqDx-based models; the RecNODE cell below starts
# from a deep copy of this dict.
rec_net_config = {
    'env_name': task,
    'obs_dim': x_size,
    'action_dim': u_size,
    'action_range': None,
    'horizon': 4,
    'device': 'cuda',
    'detach_xt': False,
    'xu_enc_hidden_dim': 512,
    'xu_enc_hidden_depth': 2,
    'x_dec_hidden_dim': 512,
    'x_dec_hidden_depth': 0,
    'clip_grad_norm': 1.0,
    'rec_type': 'GRU',
    'rec_latent_dim': 512,
    'rec_num_layers': 2,
    'lr': 1e-3,
}
# rec_net = SeqDx(**rec_net_config).to('cuda')

# rec_net_records = fit_net(rec_net, train_loader, test_loader, num_epochs)
# rec_net_df = pd.DataFrame(rec_net_records)

In [None]:
# Work on a private copy so that edits to this config never leak back into rec_net_config.
rec_node_config = copy.deepcopy(rec_net_config)

rec_node = RecNODE(**rec_node_config)
rec_node = rec_node.to('cuda')

# One record per epoch, turned into a frame for plotting.
rec_node_records = fit_net(rec_node, train_loader, test_loader, num_epochs)
rec_node_df = pd.DataFrame(data=rec_node_records)

In [None]:
# mlp_node = NN(x_size, u_size, hidden_size=512, num_layers=4).to('cuda:0')
# mlp_node_records = fit_net(mlp_node, train_loader, test_loader, num_epochs)
# mlp_node_df = pd.DataFrame(mlp_node_records)

In [None]:
# Training curves, one line per model (drawn on the current axes).
# Disabled baselines: (rec_net_df, 'GRU-Delta'), (mlp_node_df, 'MLP-NODE (Euler)')
ax = plt.gca()
for curve_df, curve_label in (
    (rec_node_df, 'NODE (Euler)'),
    (autoreg_df, 'Seq.'),
    (rpp_df, 'RPP'),
):
    ax.plot(curve_df.train_mse, label=curve_label)
ax.legend()
ax.set_xlabel('epoch')
ax.set_ylabel('Train MSE')
ax.set_title(task)

In [None]:
# Test curves, one line per model (drawn on the current axes).
# Disabled baselines: (rec_net_df, 'GRU-Delta'), (mlp_node_df, 'MLP-NODE')
ax = plt.gca()
for curve_df, curve_label in (
    (rec_node_df, 'NODE'),
    (autoreg_df, 'Seq.'),
    (rpp_df, 'RPP'),
):
    ax.plot(curve_df.test_med_se, label=curve_label)
# ax.legend()
ax.set_xlabel('epoch')
ax.set_ylabel('Test Median SE')
ax.set_title(task)
# ax.set_yscale('log')
# Fixed vertical range: anything above 2 is clipped from view.
ax.set_ylim(0, 2)