In [1]:
import sys, os
import time
import numpy as np
import gym
import torch
import torch.nn as nn
from torch import Tensor
import matplotlib.pyplot as plt
import configargparse
from tabulate import tabulate
import torch.multiprocessing as mp

import foundation as fd
from foundation import util
from foundation import models
from foundation import rl
from foundation import envs
from foundation import train

from rlhw_backend import *

In [2]:
parser = train.setup_rl_options()
args = parser.parse_args(['--config', '../config/ppo.yaml'])
print(args.__dict__.keys())

dict_keys(['config', 'name', 'save_root', 'log_date', 'log_tb', 'log_txt', 'save_freq', 'agent', 'clip', 'policy', 'model', 'baseline', 'env', 'device', 'seed', 'budget_steps', 'steps_per_itr', 'tau', 'epochs', 'batch_size', 'norm_adv', 'optim_type', 'lr', 'weight_decay', 'momentum', 'step_size', 'discount', 'subsample', 'gae_lambda', 'nonlin', 'hidden', 'min_log_std', 'b_hidden', 'b_scale_max', 'b_epochs', 'b_batch_size', 'b_nonlin', 'b_optim_type', 'b_lr', 'b_weight_decay', 'b_momentum', 'b_nesterov', 'b_time_order', 'b_obs_order'])


In [3]:
# manually changing args

args.name = 'test-ppo-nb'

args.log_tb = False #True


In [4]:
now = time.strftime("%y-%m-%d-%H%M%S")
if args.log_date:
    args.name = os.path.join(args.name, now)
args.save_dir = os.path.join(args.save_root, args.name)
print('Save dir: {}'.format(args.save_dir))
if args.log_tb or args.log_txt or args.save_freq is not None:
    util.create_dir(args.save_dir)
    print('Logging/Saving in {} (tb={},txt={})'.format(args.save_dir, args.log_tb, args.log_txt))
logger = util.Logger(args.save_dir, tensorboard=args.log_tb, txt=args.log_txt)

if args.seed is None:
    args.seed = util.get_random_seed()
    print('Generating random seed: {}'.format(args.seed))

Save dir: results/test-ppo-nb\19-03-09-013118
Logging/Saving in results/test-ppo-nb\19-03-09-013118 (tb=False,txt=False)
Generating random seed: -1793008202


In [5]:
torch.manual_seed(args.seed)
print('Using {}'.format(args.device))

env = envs.Pytorch_Gym_Env(args.env, device=args.device)
env.seed(args.seed)

args.state_dim, args.action_dim = len(env.observation_space.low), len(env.action_space.low)
print('Env name={} (obs={}, act={})'.format(env._env.env.spec.id, args.state_dim, args.action_dim))

n_batch = args.budget_steps / args.steps_per_itr

Using cpu
WARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.
WARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.
Env name=InvertedPendulum-v2 (obs=4, act=1)


In [6]:
if 'mlp' in args.baseline:

    baseline_model = NormalizedMLP(args.state_dim, 1, norm='norm' in args.baseline,
                                   hidden_dims=args.b_hidden, nonlin=args.b_nonlin)

    baseline_model.optim = util.get_optimizer(args.b_optim_type, baseline_model.parameters(), lr=args.b_lr, weight_decay=args.b_weight_decay)
    baseline_model.scheduler = torch.optim.lr_scheduler.LambdaLR(
                baseline_model.optim, lambda x: (n_batch - x) / n_batch, -1)

    # print(baseline_model.optim)
    # quit()

    #assert args.baseline == 'norm-mlp'
    baseline = rl.Deep_Baseline(baseline_model, scale_max=args.b_scale_max,
                        batch_size=args.b_batch_size, epochs_per_step=args.b_epochs, )

elif args.baseline == 'lin':
    baseline = rl.Linear_Baseline(state_dim=args.state_dim, value_dim=1)
else:
    raise Exception('unknown baseline: {}'.format(args.baseline))

In [7]:
assert args.policy == 'normal'
assert args.model == 'norm-mlp'
policy_model = NormalizedMLP(args.state_dim, 2 * args.action_dim, hidden_dims=args.hidden, nonlin=args.nonlin)
policy = rl.NormalPolicy(policy_model, )

assert args.agent == 'ppoclip'

agent = rl.PPOClip(policy=policy, baseline=baseline, clip=args.clip, normalize_adv=args.norm_adv,
            optim_type=args.optim_type, lr=args.lr, scheduler_lin=n_batch, weight_decay=args.b_weight_decay,
            batch_size=args.batch_size, epochs_per_step=args.epochs,
            ).to(args.device)

print(agent)
print('Agent has {} parameters'.format(util.count_parameters(agent)))

gen = fd.data.Generator(env, agent, step_limit=args.budget_steps,
                step_threshold=args.steps_per_itr, drop_last_state=True)



PPOClip(
  (policy): NormalPolicy(
    (model): NormalizedMLP(
      (criterion): MSELoss()
      (norm): RunningNormalization()
      (net): Sequential(
        (0): Linear(in_features=4, out_features=8, bias=True)
        (1): PReLU(num_parameters=1)
        (2): Linear(in_features=8, out_features=8, bias=True)
        (3): PReLU(num_parameters=1)
        (4): Linear(in_features=8, out_features=2, bias=True)
      )
    )
  )
  (baseline): Deep_Baseline(
    (model): NormalizedMLP(
      (criterion): MSELoss()
      (norm): RunningNormalization()
      (net): Sequential(
        (0): Linear(in_features=4, out_features=8, bias=True)
        (1): PReLU(num_parameters=1)
        (2): Linear(in_features=8, out_features=8, bias=True)
        (3): PReLU(num_parameters=1)
        (4): Linear(in_features=8, out_features=1, bias=True)
      )
    )
  )
)
Agent has 255 parameters


In [8]:
train.run_rl_training(gen, agent, args=args, logger=logger, save_freq=args.save_freq)

  self.val = torch.tensor(val).float()


[ 03-09-19 01:31:29 ] 2048/1000000 (ep=309) : last=4.000 max=27.000 - 6.584 
--- checkpoint saved at: results/test-ppo-nb\19-03-09-013118\checkpoint_0.pth.tar ---
[ 03-09-19 01:31:34 ] 4106/1000000 (ep=553) : last=12.000 max=35.000 - 8.580 
[ 03-09-19 01:31:38 ] 6161/1000000 (ep=753) : last=8.000 max=47.000 - 10.669 
[ 03-09-19 01:31:41 ] 8219/1000000 (ep=897) : last=12.000 max=47.000 - 15.373 
[ 03-09-19 01:31:45 ] 10269/1000000 (ep=996) : last=45.000 max=54.000 - 26.326 
[ 03-09-19 01:31:49 ] 12325/1000000 (ep=1054) : last=24.000 max=79.000 - 36.905 
[ 03-09-19 01:31:53 ] 14380/1000000 (ep=1093) : last=50.000 max=99.000 - 56.198 
[ 03-09-19 01:31:57 ] 16485/1000000 (ep=1123) : last=103.000 max=113.000 - 69.701 
[ 03-09-19 01:32:01 ] 18566/1000000 (ep=1150) : last=144.000 max=230.000 - 86.239 
[ 03-09-19 01:32:04 ] 20719/1000000 (ep=1167) : last=154.000 max=352.000 - 117.228 
[ 03-09-19 01:32:08 ] 23027/1000000 (ep=1174) : last=552.000 max=552.000 - 241.723 
--- checkpoint saved at: r

KeyboardInterrupt: 

In [None]:
path = save_checkpoint({
            'agent_state_dict': agent.state_dict(),
            'stats': stats,
            'args': args,
            'steps': gen.steps_generated(),
            'episodes': gen.episodes_generated(),
        }, args.save_dir, epoch=itr)