In [None]:
cd ..

In [None]:
import pathlib
import json
import numpy as np
import pandas as pd
import torch
import gym
import matplotlib.pyplot as plt

import seqnn
from seqnn import SeqNN, SeqNNConfig
from seqnn.gymutils.logger import Logger
from seqnn.utils import get_data_sample
from seqnn.analysis import plot_prediction

%load_ext autoreload
%autoreload 2


# Load some data

In [None]:
# Environment whose logged episodes are used (alternatives kept for quick switching).
#envname = 'CartPole-v1'
envname = 'LunarLander-v2'
#envname = 'Acrobot-v1'
env = gym.make(envname)

# One DataFrame per logged episode file found under data/gym/<envname>/.
dfs = [
    Logger.load_episode_as_df(episode_path)
    for episode_path in Logger.find_all_files(f'data/gym/{envname}/', '.json')
]

# Hold out five distinct episodes for validation. The seed is fixed so the
# same indices are drawn on every run (the episodes they point at still
# depend on the order in which the files are listed).
np.random.seed(3429)
valid_idx = np.random.choice(len(dfs), 5, replace=False)

held_out = set(valid_idx.tolist())
dfs_train = [episode for pos, episode in enumerate(dfs) if pos not in held_out]
dfs_valid = [episode for pos, episode in enumerate(dfs) if pos in held_out]

In [None]:
# Sanity check: number of training episodes vs. held-out validation episodes.
len(dfs_train), len(dfs_valid)

# Train a model

In [None]:
# All model / optimiser settings in one mapping, expanded into SeqNNConfig below.
# NOTE(review): the tag lists name four observation columns and one action
# column -- confirm these match the columns of the episodes logged for `envname`.
config_kwargs = {
    "targets": {"obs": ["obs0", "obs1", "obs2", "obs3"]},
    "controls": {"act": ["act0"]},
    "horizon_past": 3,
    "horizon_future": 5,
    "optimizer": "SGD",
    "optimizer_args": {"lr": 0.001, "momentum": 0.9},
    "lr_scheduler_args": {"gamma": 0.5, "step_size": 2000},
    "max_grad_norm": 30,
}
config = SeqNNConfig(**config_kwargs)

# Fresh (untrained) model built from the configuration above.
model = SeqNN(config)



In [None]:
# Train on the training episodes with max_epochs=300 and the progress bar off.
# The held-out episodes are passed as the second argument (presumably used
# for validation during training -- confirm against SeqNN.train).
model.train(dfs_train, dfs_valid, max_epochs=300, progressbar=False)

In [None]:
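# Optional: uncomment to save the model trained above.
# NOTE(review): this path does not follow the `models/gym/{envname}/...` layout
# used by the load cells below -- confirm the intended location.
#model.save('models/cartpole/model1')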

# Assess fitted model 

## Load model

In [None]:
# Load a previously saved model from disk. This rebinds `model`, so the model
# trained in the section above is no longer reachable under that name.

#path = f'models/gym/{envname}/model_old1'
path = f'models/gym/{envname}/model4'
model = seqnn.load(path)


# Build the validation dataset from the held-out episodes using the loaded model.
validset = model.get_dataset(dfs_valid)

## Plot prediction

In [None]:
# Per-observation plot limits: the environment's lower/upper bounds, each
# multiplied by 1.05.
# NOTE(review): a multiplicative margin only widens the range when
# low < 0 < high. `lims` is currently referenced only by the commented-out
# `lims=lims` arguments in the plotting cells.
obs_space = env.observation_space
lims = {
    f'obs{k}': (1.05 * lo, 1.05 * hi)
    for k, (lo, hi) in enumerate(zip(obs_space.low, obs_space.high))
}

In [None]:

# Draw a random sample index from the validation set (or pin one for debugging).
i = np.random.choice(len(validset))
#i = 9

past, future = get_data_sample(validset, indices=i)

# Plot the tags belonging to the first target group of the loaded model's task.
task_cfg = model.config.task
first_target_group = task_cfg.targets[0]
plot_prediction(
    model,
    past,
    future,
    tags_pred=task_cfg.grouping[first_target_group],
    tags_control=['act0'],
    #lims=lims,
    ncols=3,
)

# Plan optimization

In [None]:
from seqnn.control.loss import Setpoint
from seqnn.control.planner import CategoricalCEMPlanner

In [None]:

# Fixed sample index (switch to the random draw to explore other samples).
#i = np.random.choice(len(validset))
i = 10
past, future = get_data_sample(validset, indices=i)

task = model.config.task
tags_to_optimize = ['act0']

# Initial plan: one all-zero tensor of shape (1, 5, <tags in group>) for every
# entry of task.controls_cat.
# NOTE(review): the 5 presumably mirrors the model's future horizon -- confirm.
plan = {}
for group in task.controls_cat:
    plan[group] = torch.zeros(1, 5, len(task.grouping[group]))
print(plan)

# Number of categories for each tag being optimised.
num_categ = {}
for tag in tags_to_optimize:
    num_categ[tag] = task.num_categories[tag]

# Setpoint loss with a zero reference for obs0..obs3; weights are given only
# for obs0 and obs1, and end_only=True is passed.
plan_loss = Setpoint(
    reference={"obs0": 0.0, "obs1": 0.0, "obs2": 0.0, "obs3": 0.0},
    weights={"obs0": 0.1, "obs1": 1.0},
    end_only=True,
)
planner = CategoricalCEMPlanner(model, plan_loss, past, plan, num_categ)

# Run a fixed number of planner iterations. The same `plan` object is printed
# again afterwards, so the planner presumably updates it in place -- confirm.
n_steps = 15
for _ in range(n_steps):
    planner.step()
print(plan)

# Write the optimised actions into the future window and plot the prediction.
# NOTE(review): `plan` is keyed by the entries of task.controls_cat but read
# back with the key 'act0'; this only works if 'act0' is one of those entries.
future['act0'][:] = plan['act0']
plot_prediction(
    model,
    past,
    future,
    tags_pred=model.config.task.grouping[model.config.task.targets[0]],
    tags_control=['act0'],
    #lims=lims,
    ncols=3,
)

# Test scaler 

In [None]:
# Take the first validation sample and push it through the model's scaling
# round trip: native -> scaled -> native.
past, future = get_data_sample(validset, indices=0)

inner_model = model.model
scaled_pair = inner_model.to_scaled(past, future)
past_scaled, future_scaled = scaled_pair

restored_pair = inner_model.to_native(past_scaled, future_scaled)
past_recons, future_recons = restored_pair

In [None]:
# Past window as returned by get_data_sample (the input given to to_scaled).
past

In [None]:
# Past window after the to_scaled -> to_native round trip; it should match
# `past` if the two transforms are inverses of each other.
past_recons

In [None]:
# Future window as returned by get_data_sample (the input given to to_scaled).
future

In [None]:
# Future window after the to_scaled -> to_native round trip; compare with `future`.
future_recons

In [None]:
# Run the model's prediction on the same (past, future) sample and display the result.
model.predict(past, future)

In [None]:
# Scaled version of the past window, as produced by model.model.to_scaled.
past_scaled

In [None]:
# Future window again, shown next to its scaled counterpart in the following cell.
future

In [None]:
# Scaled version of the future window, as produced by model.model.to_scaled.
future_scaled

# Experiment with scaler

In [None]:
from seqnn.data.scalers import PastFutureScaler

In [None]:
# Reload the saved model so this section does not depend on the state left by
# the cells above.

path = f'models/gym/{envname}/model4'
model = seqnn.load(path)


# Build datasets for both the training and the held-out episodes.
trainset = model.get_dataset(dfs_train)
validset = model.get_dataset(dfs_valid)

In [None]:
# Shape of the 'obs' entry in the first component of the first training item
# (presumably the past window, given that the loader yields (past, future) -- confirm).
trainset[0][0]['obs'].shape

In [None]:
# Standalone scaler to experiment with.
# NOTE(review): 4 is presumably the number of 'obs' features (cf. the four obs
# tags in the config) -- confirm against PastFutureScaler's signature.
scaler = PastFutureScaler(4)

In [None]:
# Build a loader over the training set and take its first batch only.
loader = model.data_to_loader(trainset, train=True)

# next(iter(...)) raises StopIteration if the loader yields nothing. A
# `for ...: break` loop would instead fall through silently and leave
# `past` / `future` bound to the sample from an earlier cell.
past, future = next(iter(loader))


In [None]:
# Pull the 'obs' group out of the sampled batch.
obs_past = past['obs']
obs_future = future['obs']

# Fit: update the scaler's statistics from this batch.
# NOTE(review): re-running this cell calls update_stats again on the same
# batch -- confirm whether the statistics accumulate across calls.
scaler.update_stats(obs_past, obs_future)

# Apply: transform both windows into scaled space.
scaled_obs = scaler.to_scaled(obs_past, obs_future)
obs_past_scaled, obs_future_scaled = scaled_obs

In [None]:
# Map the scaled tensors back with to_native; the next cells compare the
# result against the original tensors.
obs_past2, obs_future2 = scaler.to_native(obs_past_scaled, obs_future_scaled)

In [None]:
# Largest absolute round-trip error on the past window (expected to be ~0 if
# to_native inverts to_scaled).
(obs_past - obs_past2).abs().max()

In [None]:
# Largest absolute round-trip error on the future window (expected to be ~0 if
# to_native inverts to_scaled).
(obs_future - obs_future2).abs().max()

In [None]:
# Index along the last axis to inspect.
j = 3

# First three entries of the scaled past window at index j, with the first two
# axes swapped before plotting so each selected entry is drawn as its own line
# (assuming the obs tensors are (batch, time, feature) -- TODO confirm).
# To inspect unscaled values, plot obs_past[:5, :, j] the same way instead.
past_slice = obs_past_scaled[:3, :, j]
past_curves = past_slice.detach().transpose(0, 1)
plt.plot(past_curves, '.-');

In [None]:
# Index along the last axis to inspect.
j = 1

# First three entries of the scaled future window at index j, with the first
# two axes swapped before plotting so each selected entry is drawn as its own
# line (assuming the obs tensors are (batch, time, feature) -- TODO confirm).
# To compare native vs. reconstructed values, plot obs_future[:3, :, j] with
# '.-' and obs_future2[:3, :, j] with '--' in the same way.
future_slice = obs_future_scaled[:3, :, j]
future_curves = future_slice.detach().transpose(0, 1)
plt.plot(future_curves, '.-');