In [1]:
# Enable IPython's autoreload extension in mode 2, so edits to imported modules
# (e.g. the local `dacmdp` package) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import gym
import dacmdp
import dacmdp.envs as ce
# import wandb as wandb_logger

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from munch import munchify 

# Experiment configuration, grouped into one section per pipeline stage.
# Wrapped with munchify so sections and settings can be read as attributes
# (e.g. config.envArgs.env_name).
config = munchify(dict(
    envArgs=dict(env_name='CartPole-cont-v1', seed=0),
    logArgs=dict(
        wandb_id="cartpole_online_test_1",
        wandb_entity="dacmdp",
        wandb_project="dacmdp_online_test_v0",
        no_wandb_logging=True,
    ),
    dataArgs=dict(
        buffer_name='random',
        buffer_size=50000,
        load_buffer=False,
        buffer_device='gpu',
        data_dir="",
    ),
    reprModelArgs=dict(
        repr_model_name='OracleDynamicsRepr',
        s_multiplyer=1,
        a_multiplyer=10,
        repr_dim=4,
    ),
    actionModelArgs=dict(
        action_model_name='NNActionModelCuda',
        nn_engine="torch_pykeops",
    ),
    mdpBuildArgs=dict(
        n_tran_types=10,
        n_tran_targets=5,
        penalty_beta=1.0,
        penalty_type='linear',
        rebuild_mdpfcache=False,
        save_mdp2cache=False,
        # NOTE(review): absolute, machine-specific path; only relevant if MDP caching is enabled.
        save_folder='/nfs/hpc/share/shrestaa/storage/dac_storage_22_Q4/mdp_dumps/random_hash',
    ),
    mdpSolveArgs=dict(
        device='cuda',
        max_n_backups=5000,
        gamma=0.99,
        epsilon=0.0001,
        penalty_beta=1,
        operator="simple_backup",
    ),
    evalArgs=dict(eval_episode_count=50, skip_eval=True, skip_dist_log=True),
))

def flat_args(config):
    """Flatten a two-level config mapping into a single-level dict.

    Every ``config[section][key]`` entry becomes ``"section::key" -> value``.

    A top-level ``"flat_args"`` entry is skipped *before* its value is touched,
    so it may hold anything (including a non-mapping) without breaking the
    flattening.

    Args:
        config: Mapping of section name -> mapping of setting name -> value.

    Returns:
        dict mapping ``"section::key"`` strings to the corresponding values.
    """
    return {
        f"{section}::{key}": value
        for section in config
        if section != "flat_args"  # filter on the outer loop, before .items() is called
        for key, value in config[section].items()
    }

In [4]:
# Start (or resume, via resume="allow") a Weights & Biases run unless logging
# is disabled in the config.
if not config.logArgs.no_wandb_logging:
    # The notebook-level `import wandb as wandb_logger` is commented out, so the
    # name would be undefined here. Import lazily so that enabling logging works
    # while wandb stays optional when logging is off.
    import wandb as wandb_logger

    wandb_logger.init(
        id=config.logArgs.wandb_id,
        entity=config.logArgs.wandb_entity,
        project=config.logArgs.wandb_project,
        config=flat_args(config),  # flattened "section::key" view of the config
        resume="allow",
    )

In [5]:
# Create the gym environment named in the run configuration.
env = gym.make(config["envArgs"]["env_name"])

In [7]:
from dacmdp.core.models_action import NNActionModel, GlobalClusterActionModel, EnsembleActionModel
from dacmdp.core.models_sa_repr import OracleDynamicsRepr, DeltaPredictonRepr

# ----------------------------------------------------------------------------
# Seed buffer, candidate-action models and state-action representation model
# ----------------------------------------------------------------------------

# Transition data used to fit every model below.
seed_buffer = dacmdp.utils_buffer.generate_or_load_buffer(config, env)

# Action model built from the whole buffer with a fixed number of candidate actions.
cluster_action_count = 10
cluster_action_model = GlobalClusterActionModel(
    action_space=env.action_space,
    n_actions=cluster_action_count,
    data_buffer=seed_buffer,
)

# Nearest-neighbour action model; states are used as-is (identity projection).
nn_action_model = NNActionModel(
    action_space=env.action_space,
    n_actions=5,
    data_buffer=seed_buffer,
    nn_engine=config.actionModelArgs.nn_engine,
    projection_fxn=lambda s: s,
)

# Combine both candidate-action sources into a single action model.
action_model = EnsembleActionModel(
    env.action_space,
    [nn_action_model, cluster_action_model],
)

# Alternative representation, kept for reference:
# sa_repr_model = OracleDynamicsRepr(env_name=config.envArgs.env_name)

# State-action representation fitted on the seed buffer.
sa_repr_model = DeltaPredictonRepr(
    s_multiplyer=2,
    a_multiplyer=1,
    buffer=seed_buffer,
    nn_engine="torch_pykeops",
)
# ----------------------------------------------------------------------------

Collecting buffer!
Average Reward of collected trajectories:16.733
Collected buffer!
K-means for the Euclidean metric with 50,000 points in dimension 1, K = 10:
Timing for 50 iterations: 0.02949s = 50 x 0.00059s



Caculating State Representations: : 196it [00:00, 99249.50it/s]


In [8]:
# Sanity check on the first 100 buffered transitions: shape of the candidate
# actions per state, and shape of the encoded state-action pairs.
(
    cluster_action_model.cand_actions_for_states(
        torch.FloatTensor(seed_buffer.state[:100]).cuda()
    ).shape,
    sa_repr_model.encode_state_action_pairs(
        torch.FloatTensor(seed_buffer.state[:100]),
        torch.FloatTensor(seed_buffer.action[:100]),
    ).shape,
)

(torch.Size([100, 10, 1]), torch.Size([100, 4]))

In [9]:
import time
import matplotlib.pyplot as plt 
import numpy as np 
from dacmdp.core.utils_misc import plot_distributions_as_rgb_array
from dacmdp.eval.utils_eval import evaluate_on_env
from dacmdp.data.utils_buffer import StandardBuffer
from dacmdp.core.dac_core import DACTransitionBatch
from dacmdp.core.dac_build import DACBuildWithActionNames
from dacmdp.core.utils_knn import THelper

# The build step below reads its transitions from `data_buffer`.
data_buffer = seed_buffer

# Keep the build settings in sync with the models created earlier:
# one transition type per candidate action, and a 4-dimensional representation.
config["mdpBuildArgs"]["n_tran_types"] = action_model.n_actions
config["mdpBuildArgs"]["repr_dim"] = 4

# Instantiate the elastic DAC-MDP agent.
elasticAgent = DACBuildWithActionNames(
    config=config,
    action_space=env.action_space,
    action_model=action_model,  # TODO (from original author): update this later.
    repr_model=sa_repr_model,
    effective_batch_size=1000,
    batch_calc_knn_ret_flat_engine=THelper.batch_calc_knn_ret_flat_pykeops,
)

dacmdp_core_defined
Using pre-initialized Action Model BaseActionModel
Using pre-initialized Action Model <dacmdp.core.models_sa_repr.DeltaPredictonRepr object at 0x2afbfd552af0>


In [21]:
######### TT 3: DACMDP Elastic Build   ###########################################################################
# Pack the whole buffer into one transition batch. Rewards are flattened to 1-D,
# and the terminal flag is the complement of the buffer's `not_done` field.
transitions = DACTransitionBatch(
    torch.FloatTensor(data_buffer.state).clone().detach(),
    torch.FloatTensor(data_buffer.action).clone().detach(),
    torch.FloatTensor(data_buffer.next_state).clone().detach(),
    torch.FloatTensor(data_buffer.reward.reshape(-1)).clone().detach(),
    torch.LongTensor((1 - data_buffer.not_done).reshape(-1)).clone().detach(),
)

# perf_counter is monotonic, so the elapsed time cannot be skewed by system
# clock adjustments the way time.time() can.
st = time.perf_counter()
elasticAgent.consume_transitions(transitions, verbose=True, batch_size=1000)

# All solver settings come from config.mdpSolveArgs, including the operator,
# which was previously hard-coded to the same value.
elasticAgent.dacmdp_core.solve(
    max_n_backups=config.mdpSolveArgs.max_n_backups,
    penalty_beta=config.mdpSolveArgs.penalty_beta,
    epsilon=config.mdpSolveArgs.epsilon,
    gamma=config.mdpSolveArgs.gamma,
    operator=config.mdpSolveArgs.operator,
    bellman_backup_batch_size=500,
    reset_values=True,
)

print(f"Graph built and solved in {time.perf_counter()-st:.2f} Seconds")
######################################################################################################

batch_next_states.shape torch.Size([50000, 4])
replace indices :  False torch.Size([50000, 4])
Instantiated DACMDP for transition Batch
(150000, 15, 5)
nn after consumption,  150000


Calculate Candidate Actions: : 50it [00:00, 235.67it/s]
Calculate/Update Datsaet SA Representation: : 50it [00:00, 286.49it/s]
Calculate/Update Candidate Transition SA Representation: : 50it [00:00, 211.94it/s]
Update Transition model of core dacmdp: : 50it [00:00, 80.45it/s]


0 tensor(1.)
500 tensor(0.0066)
1000 tensor(0.0001)
1500 tensor(2.2888e-05)
Solved MDP in 1500 Backups
Graph built and solved in 2.16 Seconds


In [24]:
# Solve the already-built MDP again with a larger discount factor (0.999 vs the
# configured 0.99) and a larger backup budget. `reset_values` is not passed here,
# so the solver's default applies.
elasticAgent.dacmdp_core.solve(
    operator="simple_backup",
    gamma=0.999,
    penalty_beta=1,
    epsilon=config.mdpSolveArgs.epsilon,
    max_n_backups=10000,
    bellman_backup_batch_size=500,
)

0 tensor(0.7966)
500 tensor(0.4833)
1000 tensor(0.2940)
1500 tensor(0.1791)
2000 tensor(0.1089)
2500 tensor(0.0666)
3000 tensor(0.0405)
3500 tensor(0.0255)
4000 tensor(0.0163)
4500 tensor(0.0099)
5000 tensor(0.0063)
5500 tensor(0.0039)
6000 tensor(0.0026)
6500 tensor(0.0018)
7000 tensor(0.0013)
7500 tensor(0.0008)
8000 tensor(0.0007)
8500 tensor(0.0005)
9000 tensor(0.0004)
9500 tensor(0.0003)
Solved MDP in 9999 Backups


In [25]:
# Evaluate the agent's policy on the environment for a fixed number of episodes
# and report the average reward.
config["evalArgs"]["eval_episode_count"] = 50
avg_rewards, info = evaluate_on_env(
    env,
    elasticAgent.dummy_lifted_policy,
    eps_count=config["evalArgs"]["eval_episode_count"],
)
print(avg_rewards)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:05<00:00,  8.57it/s]

435.38



