In [23]:
import sys, os, time
%load_ext autoreload
%autoreload 2
import torch
import torch.nn.functional as F
from torch.autograd import Variable, grad
import torch.distributions as distrib
import torch.nn as nn
import torch.utils.data as td
from torch.utils.data import Dataset, DataLoader
import gym
import numpy as np
%matplotlib notebook
#%matplotlib tk
from torchdiffeq import odeint

import matplotlib.pyplot as plt
#plt.switch_backend('Qt5Agg') #('Qt5Agg')
from scipy.sparse import coo_matrix
import foundation as fd
from foundation import nets
from foundation import util
from foundation import train
import torch.multiprocessing as mp
np.set_printoptions(linewidth=120)

from nb_backend import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [24]:
args = util.NS()
args.name = 'test-node-policy'

args.txtlog = False
args.tblog = True
args.save_root = '../trained_nets/'
args.num_workers = 2

args.def_type = 'torch.FloatTensor'

#args.env = 'cartpole'
args.env = 'Pendulum-v0'
args.base_seed = 12

args.optim_type = 'rmsprop'
args.lr = 1e-3
args.weight_decay = 1e-4
args.momentum = 0.9

args.nonlin = 'prelu'
args.pi_hidden = [8]
args.dyn_hidden = [8, 8, 8]

args.delta = 0.01
args.gamma = 0.99

args.num_iter = 100
args.num_traj = 20
args.num_eval = 5

print('Name: {}'.format(args.name))

Name: test-node-policy


In [25]:
now = time.strftime("%y-%m-%d-%H%M%S")
args.save_dir = os.path.join(args.save_root, args.name, now)
#args.save_dir = os.path.join(args.save_root, args.name)
print('Save dir: {}'.format(args.save_dir))
if args.tblog or args.txtlog:
    util.create_dir(args.save_dir)
    print('Logging in {}'.format(args.save_dir))
logger = util.Logger(args.save_dir, tensorboard=args.tblog, txt=args.txtlog)
env = util.Numpy_Env_Wrapper(gym.make(args.env))
args.state_dim = env.observation_space.shape[0] #env.spec.obs_space.size
args.ctrl_dim = env.action_space.shape[0] #env.spec.action_space.choices if args.env == 'cartpole' else env.spec.action_space.size
print('Env: {} (obs={},act={})'.format(args.env, args.state_dim, args.ctrl_dim))

Save dir: ../trained_nets/test-node-policy/19-01-20-023310
Logging in ../trained_nets/test-node-policy/19-01-20-023310
Env: Pendulum-v0 (obs=3,act=1)


In [26]:
policy = Policy(state_dim=args.state_dim, ctrl_dim=args.ctrl_dim, hidden_dims=args.pi_hidden, nonlin=args.nonlin)
dynamics = Dynamics(state_dim=args.state_dim, ctrl_dim=args.ctrl_dim, hidden_dims=args.dyn_hidden, nonlin=args.nonlin)
func = Integrator(policy, dynamics)
print(func)

Integrator(
  (policy): Policy(
    (net): Sequential(
      (0): Linear(in_features=3, out_features=8, bias=True)
      (1): PReLU(num_parameters=1)
      (2): Linear(in_features=8, out_features=1, bias=True)
      (3): Tanh()
    )
  )
  (dynamics): Dynamics(
    (net): Sequential(
      (0): Linear(in_features=4, out_features=8, bias=True)
      (1): PReLU(num_parameters=1)
      (2): Linear(in_features=8, out_features=8, bias=True)
      (3): PReLU(num_parameters=1)
      (4): Linear(in_features=8, out_features=8, bias=True)
      (5): PReLU(num_parameters=1)
      (6): Linear(in_features=8, out_features=3, bias=True)
    )
  )
)


In [27]:
states, actions, rewards = generate_rollouts(policy, env=env)
states.shape, actions.shape, rewards.shape

(torch.Size([201, 3]), torch.Size([200, 1]), torch.Size([200]))

In [28]:
goal = torch.tensor([1,0,0]).float()
Q = torch.eye(args.state_dim)
R = torch.eye(args.ctrl_dim)
def cost(x, u=None):
    q = x-goal.unsqueeze(0)
    cost = q.unsqueeze(-2)@Q@q.unsqueeze(-1)
    if u is not None:
        cost += u.unsqueeze(-2)@R@u.unsqueeze(-1)
    return cost
print(goal)
print(Q)
print(R)
times = torch.arange(env.spec.timestep_limit).float()+1.
times /= 100

tensor([1., 0., 0.])
tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])
tensor([[1.]])


In [29]:
criterion = nn.MSELoss()
dyn_optim = nets.get_optimizer(args.optim_type, dynamics.parameters(), lr=args.lr, weight_decay=args.weight_decay)
pi_optim = nets.get_optimizer(args.optim_type, policy.parameters(), lr=args.lr, weight_decay=args.weight_decay)

In [30]:
itr = 0
num = 1000

In [31]:
stats = util.StatsMeter('pi', 'dyn', 'total', tau=0.05)

In [34]:
print_freq = max(1, num // 100)

for _ in range(num):
    
    #states, actions, rewards = generate_rollouts(policy, env=env)
    
    x0, xs = states[:1], states[1:]
    
    pred = odeint(func, x0, times).squeeze()
    
    dyn_loss = criterion(pred, xs)
    pi_loss = cost(pred).mean()
    
    stats.update('pi', pi_loss.detach())
    stats.update('dyn', dyn_loss.detach())
    
    loss = dyn_loss + pi_loss
    
    stats.update('total', loss.detach())
    
    dyn_optim.zero_grad()
    pi_optim.zero_grad()
    
    loss.backward()
    
    dyn_optim.step()
    pi_optim.step()
    
    if itr % print_freq == 0:
        print('Itr {}: dyn={dyn.val:.4f} ({dyn.smooth:.4f}), pi={pi.val:.4f} ({pi.smooth:.4f}), total={total.val:.4f} ({total.smooth:.4f})'.format(itr+1, 
                        dyn=stats['dyn'], pi=stats['pi'], total=stats['total']))
        logger.update(stats.smooths(), itr)
    
    itr += 1

Itr 1001: dyn=1.3617 (5.9617), pi=2.6219 (0.7789), total=3.9835 (6.7406)
Itr 1011: dyn=1.3646 (4.1166), pi=2.5755 (1.5064), total=3.9401 (5.6231)
Itr 1021: dyn=1.3666 (3.0129), pi=2.5463 (1.9277), total=3.9128 (4.9405)
Itr 1031: dyn=1.3688 (2.3528), pi=2.5150 (2.1684), total=3.8837 (4.5212)
Itr 1041: dyn=1.3711 (1.9585), pi=2.4841 (2.3001), total=3.8552 (4.2586)
Itr 1051: dyn=1.3737 (1.7234), pi=2.4667 (2.3700), total=3.8404 (4.0935)
Itr 1061: dyn=1.3758 (1.5836), pi=2.4393 (2.4022), total=3.8151 (3.9858)
Itr 1071: dyn=1.3781 (1.5008), pi=2.4109 (2.4103), total=3.7890 (3.9110)
Itr 1081: dyn=1.3806 (1.4522), pi=2.3802 (2.4027), total=3.7608 (3.8549)
Itr 1091: dyn=1.3828 (1.4240), pi=2.3540 (2.3875), total=3.7368 (3.8115)
Itr 1101: dyn=1.3855 (1.4081), pi=2.3254 (2.3675), total=3.7109 (3.7755)
Itr 1111: dyn=1.3882 (1.3996), pi=2.2969 (2.3438), total=3.6851 (3.7435)
Itr 1121: dyn=1.3912 (1.3957), pi=2.2671 (2.3180), total=3.6582 (3.7137)
Itr 1131: dyn=1.3945 (1.3947), pi=2.2359 (2.2901), 

In [47]:
#pred = odeint(func, states[:1], times)

In [13]:
pred.shape

torch.Size([200, 3])

In [36]:
generate_rollouts(policy, env=env, render=True)
pass

In [33]:
states, actions, rewards = generate_rollouts(policy, env=env, render=True)
states.shape, actions.shape, rewards.shape

(torch.Size([201, 3]), torch.Size([200, 1]), torch.Size([200]))

In [9]:
q = states[:2]-goal
#q = q.unsqueeze(-2)
q.unsqueeze(-2)@Q@q.unsqueeze(-1)

tensor([[[3.7948]],

        [[3.0070]]])

In [14]:
rewards.mean()

tensor(-9.0670)