In [1]:
import time

import argparse
from bigmdp.data.env_gym import SimpleNormalizeEnv
from bigmdp.data.dataset import SimpleReplayBuffer, PrioritizedReplayBuffer
from bigmdp.utils.utils_log import *
from bigmdp.utils.utils_video import *
from bigmdp.utils.tmp_vi_helper import *
from bigmdp.utils.image_wrappers import *
from bigmdp.hyper_params import HYPERPARAMS
from bigmdp.utils.utils_directory import *
# from async_vi.MDP import *
import numpy as np
from copy import deepcopy as cpy

from bigmdp.mdp.MDP_GPU import FullMDP
from model import FC_MDP_Predictor_discrete
import pycuda.autoinit

# Command-line style configuration for the experiment.
# The notebook cannot receive real CLI arguments, so the argument string is
# hard-coded below and parsed exactly as if it had been typed in a shell.
parser = argparse.ArgumentParser()
parser.add_argument("--bottle_neck_size", help="size", type=int, default=32)
parser.add_argument("--discrete_bn", help="size", type=int, default=0)
parser.add_argument("--multiplyer", help="multiplyer of feature space", type=int, default=10)
parser.add_argument("--env", help="environment name", type=str, default="CartPole")
parser.add_argument("--name", help="Experiment name", type=str, default="CartPoleR1")
parser.add_argument("--load", help="Load the previous MDP ?", type=int, default=0)
parser.add_argument("--symbolic", help="Use Symbolic env if 1 else use image based env", type=int, default=1)
parser.add_argument("--steps_to_train", help="Number of steps to train the whole pipeline", type=int, default=0)
parser.add_argument("--rmax", help="Use rmax exploration?", type=int, default=0)
parser.add_argument("--strict_rmax", help="Use rmax exploration and rmax exploration alone?", type=int, default=0)
parser.add_argument("--video_every", help="get a rollout video every", type=int, default=999999999)
parser.add_argument("--backup_every", help="do a bellman backup every k frames", type=int, default=10)
parser.add_argument("--device", help="do backups on ?", type=str, default="GPU")
# The help text used to be a copy of the --device help; it now names this flag.
parser.add_argument("--save_transitions", help="Save transitions?", type=int, default=0)
parser.add_argument("--shaped_reward", help="Use shaped rewards?", type=int, default=0)
parser.add_argument("--load_time_string", help="timestep string", type=str, default="default_time")
parser.add_argument("--use_priority_buffer", help="Use priority buffer ?", type=int, default=0)
parser.add_argument("--render_every", help="render a episode every kth episode", type=int, default=0)
parser.add_argument("--internal_count_override", help="override internal count", type=int, default=0)

# args = parser.parse_args()
# str.split() without an argument collapses repeated whitespace, so an
# accidental double space cannot inject an empty token into the parser.
args = parser.parse_args("--video_every 50 --env CartPole --multiplyer 5 --rmax 0 --backup_every 100 --load_time_string decimal_test".split())
# args = parser.parse_args("--video_every 50 --env Acrobot --multiplyer 5 --rmax 0 --backup_every 5 --load_time_string without_prior".split(" "))


# Suffix that encodes the run configuration; it is appended to the experiment
# name for the result directory and passed to the logger as EXP_PARAMS.
run_params = "_bn-{}_sym-{}_rmax-{}_strict_rmax-{}_mult-{}_bkp_f-{}_device-{}_priority-{}".format(
    args.bottle_neck_size,
    args.symbolic,
    bool(args.rmax),
    bool(args.strict_rmax),
    args.multiplyer,
    args.backup_every,
    args.device,
    args.use_priority_buffer,
)

# Result directory for this environment / experiment / configuration.
base_file_path = "./result_dump/{}/{}".format(args.env, "{}{}".format(args.name, run_params))
create_hierarchy(base_file_path)

train_epochs = 20

# Log directories and loggers (including the two tensorboard loggers used later).
log_dirs_dict, loggers_dict = get_advanced_log_dir_and_logger(
    ROOT_FOLDER="Symbolic" if args.symbolic else "Image",
    EXP_ID=args.env + "-Dec",  # todo remove for general
    EXP_PARAMS=run_params,
    load_time_string=args.load_time_string,
    tb_log_keys=["tb_train_logger", "tb_valid_logger"],
)

# Build (or reload) the symbolic environment and pick the matching hyper-parameters.
if args.symbolic:
    params = HYPERPARAMS[args.env + "-sym"]
    env_cache_path = "./" + str(args.env) + "_env.pk"
    if os.path.exists(env_cache_path):
        print(" Environment  Loaded")
        # NOTE: torch.load unpickles arbitrary objects; only load cache files
        # that were written by this notebook.
        env = torch.load(env_cache_path)
    else:
        env = SimpleNormalizeEnv(params["env_name"], max_episode_length=params["max_episode_length"])
        # Cache the freshly built environment so later runs reuse the same object.
        torch.save(env, env_cache_path)
else:
    # Raise explicitly instead of `assert False`: asserts are removed under
    # `python -O`, which would let execution continue with `env` undefined.
    raise NotImplementedError("Image based environments are not implemented yet")  # Todo

# Independent copy of the environment, used only for policy evaluation.
test_env = cpy(env)


# K_s: frame budget for the initial random-policy data collection.
# K_i: per-policy internal step count, unless overridden from the arguments.
K_s = params["replay_initial"]
K_i = args.internal_count_override or params['internal_step_count_per_policy']
print(K_i, K_s)

if not args.load:
    mdp = FullMDP(
        A=env.get_list_of_actions(),
        ur=params["unknown_transition_reward"],
        vi_params={
            "gamma": params["gamma"],
            "slip_prob": params["slip_probability"],
            "rmax_reward": params["rmax_reward"],
            "rmax_thres": 2,
            "balanced_explr": True,
            "rmin": params["rmin"],
        },
        policy_params={"unhash_array_len": env._env.observation_space.shape[0]},
        MAX_S_COUNT=int(1e6),
        weight_transitions=False,
        default_mode=args.device,
    )
else:
    # NOTE(review): no path separator is inserted here, so this resolves to
    # "<base_file_path>mdp_class.pth" -- confirm it matches where the MDP is saved.
    mdp = torch.load(base_file_path + "mdp_class.pth")

# Bookkeeping containers.
all_rewards = []
eval_rewards = []
safe_eval_rewards = []
policy_fetch_time = []
bellman_backup_time = [9999]

# Replay buffers for training and validation data.
# NOTE(review): the validation capacity is 1e5 for the prioritized buffer but
# 5e5 for the simple buffer -- confirm whether that asymmetry is intended.
if args.use_priority_buffer:
    train_buffer = PrioritizedReplayBuffer(int(5e5))
    valid_buffer = PrioritizedReplayBuffer(int(1e5))
else:
    train_buffer = SimpleReplayBuffer(int(5e5))
    valid_buffer = SimpleReplayBuffer(int(5e5))

eps_tracker = EpsilonTracker(params)

frame_count = 0
warmup_eps = 10
eval_reward = 0

omit_list = ["end_state", "unknown_state"]

from bigmdp.data.dataset import gather_data_in_buffer
from learnt_mdp_helper import *

# Fill both buffers with random-policy experience: K_s frames for training
# and a fifth of that for validation. Every transition is padded with qval 0.
def random_policy(s):
    return env.sample_random_action()


train_buffer, info = gather_data_in_buffer(
    train_buffer,
    env,
    episodes=9999,
    render=False,
    policy=random_policy,
    frame_count=K_s,
    pad_attribute_fxn={"qval": lambda s: 0},
)
valid_buffer, info = gather_data_in_buffer(
    valid_buffer,
    env,
    episodes=9999,
    render=False,
    policy=random_policy,
    frame_count=int(K_s / 5),
    pad_attribute_fxn={"qval": lambda s: 0},
)

# recon_loss_wts = {h:0 for h in training_net.head_ids}
# recon_loss_wts["recon"] = 1
# training_net.default_loss_wts  = recon_loss_wts
# training_net.fit(train_buffer,valid_buffer, epochs=50)
frame_count = 0
print(frame_count)

from tqdm import tqdm

args.model_update_every = 500

# The outer frame counter starts at the number of transitions already collected.
outer_loop_frame_count = len(train_buffer)

# torch.save((train_buffer.buffer, mdp, net.state_dict()), log_dirs_dict["py_log_dir"] + "/checkpoint" + ".pth")

 Environment  Loaded
20000 20000


2020-02-28 15:41:26,250:mylogger:Average Reward of collected trajectories:21.907
I0228 15:41:26.250660 140416558835520 dataset.py:292] Average Reward of collected trajectories:21.907
2020-02-28 15:41:26,353:mylogger:Average Reward of collected trajectories:23.585
I0228 15:41:26.353203 140416558835520 dataset.py:292] Average Reward of collected trajectories:23.585


Average Reward of collected trajectories:21.907
Average Reward of collected trajectories:23.585
0


In [2]:
# Notebook-level overrides of the loaded hyper-parameters.
# The main loop reads the backup cadence from params["bellman_backup_every"],
# not from args.backup_every.
params.update({
    "evaluate_every": 5000,
    "outer_iteration_count": 1,
    "bellman_backup_every": 100,
    "n_backups": 10,
})
K_i = int(2e6)

In [3]:
def img_2_disc_fxn(s, multiplier=None):
    """Scale a batch of states by the discretization multiplier.

    Args:
        s: A sequence of states. Callers in this notebook pass ``[state]``
            with a NumPy state and call ``.tolist()`` on the result.
        multiplier: Scale factor. When omitted, ``args.multiplyer`` is read
            at call time, which is the original behaviour.

    Returns:
        The single scaled state when ``s`` holds exactly one element,
        otherwise a list with one scaled state per element of ``s``.
    """
    # NOTE(review): no rounding happens here; presumably the hashing step that
    # follows turns the scaled values into discrete keys -- confirm.
    if multiplier is None:
        multiplier = args.multiplyer
    if len(s) == 1:
        return s[0] * multiplier
    return [s_ * multiplier for s_ in s]

In [4]:

outer_count = 0
# Update the buffer with new values
# For Disc: Skipped

# Train your Network
# For Disc: Skipped

# Build the tabular MDP from the collected transitions, then solve it.
mdp_T = make_new_mdp(train_buffer, img_2_disc_fxn, params, env)
print("Size of MDP,", len(mdp_T.tD))
mdp_T.solve(eps=1, mode=args.device)

# for mdp_name, M in {"Tabular": mdp_T, "Tabular_Insertion": mdp_TI}.items():


def _disc_hash(s):
    # Scale a single raw state and hash it into the key handed to mdp_T.
    return hAsh(img_2_disc_fxn([s]).tolist())


def opt_policy(s):
    return mdp_T.get_opt_action(_disc_hash(s), mode=args.device)


def safe_policy(s):
    return mdp_T.get_safe_action(_disc_hash(s), mode=args.device)


def explr_policy(s):
    return mdp_T.get_explr_action(_disc_hash(s), mode=args.device)


def random_policy(s):
    return env.sample_random_action()


eps_opt_policy = get_eps_policy(opt_policy, random_policy, epsilon=0.2)

# Candidate exploration policies, chosen by the rmax flags.
if args.strict_rmax:
    explore_policies = {"explr_policy": explr_policy}
elif args.rmax:
    explore_policies = {"explr_policy": explr_policy, "opt_policy": opt_policy}
else:
    explore_policies = {"opt_policy": opt_policy, "eps_opt_policy": eps_opt_policy}

# Initial state and counters for the interaction loop.
s = env.reset()
running_reward = 0
inner_loop_frame_count = 0
eps_count = 0

100%|██████████| 79/79 [00:05<00:00, 15.76it/s]


Size of MDP, 496
Time takedn to solve 2.43491792678833


In [None]:
# Render frames only while the episode counter is a multiple of 10000.
# eps_count starts at 0, so the very first episode is rendered as well.
args.render_every = 10000

In [None]:

# Main interaction loop: act with the epsilon-greedy policy, feed every
# transition into the tabular MDP, run Bellman backups on a fixed cadence and
# periodically evaluate the greedy policy on the separate test environment.
for i in tqdm(range(int(1e8))):
    inner_loop_frame_count += 1
    outer_loop_frame_count += 1

    if args.render_every and eps_count % args.render_every == 0:
        env.render()

    if outer_loop_frame_count % params['bellman_backup_every'] == 0:
        st = time.time()
        mdp_T.do_optimal_backup(mode=args.device, n_backups=params["n_backups"])
        if args.rmax:
            mdp_T.do_explr_backup(mode=args.device, n_backups=params["n_backups"])
        bellman_backup_time.append(time.time() - st)

    policy_name, policy = "a", eps_opt_policy # list(explore_policies.items())[eps_count%len(explore_policies)]
    a = policy(s)

    # The info dict gets its own name so it no longer shadows the loop index.
    ns, r, d, step_info = env.step(a)
    # An episode cut off by the length limit is not stored as terminal.
    _d = False if d and step_info["max_episode_length_exceeded"] == True else d
    running_reward += r

    # add to buffer
    # Not necessary for Decimal Discretization bur sure
    exp = [s.tolist(), [a], ns.tolist(), [r], [_d]]
    train_buffer.add(exp, padded_info={"qval": 0})
    # Roughly one transition in ten is also kept for validation.
    if np.random.randint(20) < 2:
        valid_buffer.add(exp, padded_info={"qval": 0})

    # Update MDP
    s_d, ns_d = img_2_disc_fxn([s]), img_2_disc_fxn([ns])
    hs_d, hns_d = hAsh(s_d.tolist()), hAsh(ns_d.tolist())
    mdp_T.consume_transition([hs_d, int(a), hns_d, float(r), int(_d)])
#     mdp_T.consume_transition((hAsh(hs.tolist()), int(a), hAsh(hns.tolist()), float(r), int(_d)))

    # prep for next loop
    s = ns

    # Housekeeping to omit while true loop
    if d:
        s = env.reset()
        running_reward = 0
        eps_count += 1

    if inner_loop_frame_count % params['evaluate_every'] == 0:
        # NOTE(review): this backup is pinned to "GPU" rather than args.device.
        mdp_T.do_optimal_backup(mode="GPU", n_backups=500)
        n_steps = outer_count * K_i + inner_loop_frame_count
        eval_reward = evaluate_on_env(test_env, opt_policy, eps_count=50, render=False)[0]

        col_header = ["Average Reward","MDP #States", "#MissingTrans/ #Trans","Missing %", "VI Error"]
        meta_data = [10]*len(col_header) # max len of the data in the table to be printed (per column)
        print(' | '.join([title.ljust(max(meta_data[col], len(title)), '.')
                          for col, title in enumerate(col_header)]))
        col_data = [eval_reward,
                    len(mdp_T.tD),
                    str(mdp_T.missing_state_action_count)+"/"+str(len(mdp_T.tD)*len(mdp_T.A)),
                    round(mdp_T.missing_state_action_count/mdp_T.total_state_action_count,4 ),
                    mdp_T.curr_vi_error       ]
        print(' | '.join([str(cell).ljust(max(meta_data[col], len(col_header[col])),  " ")
                          for col, cell in enumerate(col_data)]))

        tb_logger = loggers_dict["tb_train_logger"]
        # Fix: log the evaluation reward itself, with n_steps as the step
        # (previously n_steps was passed as the value and the reward was dropped).
        tb_logger.add_scalar("Optimal Avg Reward", eval_reward, n_steps)
        tb_logger.add_scalar("State Count", float(len(mdp_T.tD)), outer_loop_frame_count)
        # Fix: read the count from mdp_T, the MDP that is actually being updated,
        # not from the unused `mdp` object created during setup.
        tb_logger.add_scalar('Mising Transition Count', mdp_T.missing_state_action_count, outer_loop_frame_count)
        tb_logger.add_scalar("Opt VI Error",mdp_T.curr_vi_error, outer_loop_frame_count)
        tb_logger.add_scalar("Safe VI Error",mdp_T.s_curr_vi_error, outer_loop_frame_count)

    ####################################################

    # Not necessary for Decimal Experiments
    # if outer_loop_frame_count % args.checkpoint_every == 0:
    #     training_net._save_to_cache()
    #     torch.save((train_buffer.buffer, mdp, net.state_dict()),
    #                log_dirs_dict["py_log_dir"] + "/checkpoint" + ".pth")

# Not necessary for Decimal Experiments
# training_net._save_to_cache()

  0%|          | 5007/100000000 [00:27<3085:17:30,  9.00it/s]

Average Reward | MDP #States | #MissingTrans/ #Trans | Missing %. | VI Error..
78.84          | 657         | 343/1314              | 0.261      | 0.39984130859375


  0%|          | 10007/100000000 [00:55<2869:20:35,  9.68it/s]

Average Reward | MDP #States | #MissingTrans/ #Trans | Missing %. | VI Error..
116.4          | 854         | 495/1708              | 0.2898     | 0.593780517578125


  0%|          | 15006/100000000 [01:22<3054:40:43,  9.09it/s]

Average Reward | MDP #States | #MissingTrans/ #Trans | Missing %. | VI Error..
176.52         | 974         | 571/1948              | 0.2931     | 0.554718017578125


  0%|          | 20005/100000000 [01:51<3336:50:19,  8.32it/s]

Average Reward | MDP #States | #MissingTrans/ #Trans | Missing %. | VI Error..
122.5          | 1092        | 645/2184              | 0.2953     | 1.31256103515625


  0%|          | 25000/100000000 [02:20<4457:39:47,  6.23it/s]

Average Reward | MDP #States | #MissingTrans/ #Trans | Missing %. | VI Error..
131.58         | 1162        | 681/2324              | 0.293      | 1.6876220703125


  0%|          | 30000/100000000 [02:50<4529:51:08,  6.13it/s]

Average Reward | MDP #States | #MissingTrans/ #Trans | Missing %. | VI Error..
88.3           | 1277        | 755/2554              | 0.2956     | 32.89703369140625


  0%|          | 35009/100000000 [03:17<2716:55:11, 10.22it/s]

Average Reward | MDP #States | #MissingTrans/ #Trans | Missing %. | VI Error..
163.58         | 1325        | 765/2650              | 0.2887     | 12.09625244140625


  0%|          | 40008/100000000 [03:48<3693:59:35,  7.52it/s]

Average Reward | MDP #States | #MissingTrans/ #Trans | Missing %. | VI Error..
250.24         | 1377        | 798/2754              | 0.2898     | 4.44677734375


  0%|          | 45008/100000000 [04:17<3323:52:00,  8.35it/s]

Average Reward | MDP #States | #MissingTrans/ #Trans | Missing %. | VI Error..
283.62         | 1450        | 847/2900              | 0.2921     | 1.63446044921875


  0%|          | 50006/100000000 [04:45<2958:56:29,  9.38it/s]

Average Reward | MDP #States | #MissingTrans/ #Trans | Missing %. | VI Error..
148.42         | 1513        | 886/3026              | 0.2928     | 11.619873046875


  0%|          | 55007/100000000 [05:14<3106:44:40,  8.94it/s]

Average Reward | MDP #States | #MissingTrans/ #Trans | Missing %. | VI Error..
232.22         | 1574        | 918/3148              | 0.2916     | 3.362060546875


  0%|          | 60007/100000000 [05:44<3599:48:12,  7.71it/s]

Average Reward | MDP #States | #MissingTrans/ #Trans | Missing %. | VI Error..
236.82         | 1642        | 959/3284              | 0.292      | 1.23675537109375


  0%|          | 65006/100000000 [06:12<3115:49:35,  8.91it/s]

Average Reward | MDP #States | #MissingTrans/ #Trans | Missing %. | VI Error..
223.2          | 1738        | 1024/3476             | 0.2946     | 0.57220458984375


  0%|          | 70005/100000000 [06:41<3152:14:46,  8.81it/s]

Average Reward | MDP #States | #MissingTrans/ #Trans | Missing %. | VI Error..
287.1          | 1825        | 1082/3650             | 0.2964     | 1.3369140625


  0%|          | 75000/100000000 [07:10<4473:34:34,  6.20it/s]

Average Reward | MDP #States | #MissingTrans/ #Trans | Missing %. | VI Error..
181.58         | 1853        | 1091/3706             | 0.2944     | 0.7327880859375


  0%|          | 78177/100000000 [07:22<96:14:21, 288.41it/s]