In [15]:
from cem_optimizer_v2 import CEM_opt
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
from buffers import MultiEnvReplayBuffer


from IPython.display import clear_output
from queue import Queue
import matplotlib.pyplot as plt
import numpy as np
import torch


In [20]:
import json 



def json_load(path='config_parameters.json'):
    """Read and return the parsed contents of a JSON config file.

    Args:
        path: file to read; defaults to the notebook's ``config_parameters.json``.

    Returns:
        The object produced by ``json.load`` (a dict for this notebook's config).
    """
    # JSON text is UTF-8; don't rely on the platform's locale encoding.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def save_json_update(json_obj, path='config_parameters.json'):
    """Serialize ``json_obj`` to ``path``, overwriting the file.

    Args:
        json_obj: JSON-serializable object (the notebook's config dict).
        path: destination file; defaults to the notebook's ``config_parameters.json``.
    """
    # Write UTF-8 explicitly so the file round-trips with json_load on any platform.
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(json_obj, f)

config = json_load()

# The config file is nested: unpack its top-level sections in one go.
buff_config, planner_config, const_config, train_config = (
    config[section] for section in ('buffer', 'planner', 'const', 'train')
)

In [21]:
envs = {}
# Maps env name -> integer id; the same dict also holds the "first_idx_free"
# counter used to hand out the next unused id.
id_map = buff_config["correspondence_id2env"]

for name, env_cls in ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE.items():
    # Membership test rather than truthiness: an env already mapped to id 0
    # must keep that id instead of being re-registered under a new one.
    if name not in id_map:
        id_map[name] = id_map["first_idx_free"]
        id_map["first_idx_free"] += 1
    envs[id_map[name]] = env_cls()

# NOTE(review): ids assigned above only update the in-memory `config`; this cell
# never calls save_json_update, so they are not written back to config_parameters.json.


  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


'<PlateSlideSideV2GoalObservable instance>'

In [58]:
# Size of a single action vector; planned sequences are reshaped to rows of this width.
env_action_space_shape = 4
# Replay buffer shared by all environments (1_000 is presumably its capacity -- TODO confirm in buffers.py).
buffer = MultiEnvReplayBuffer(1_000)

In [50]:
class InformedPlanner:
    """Open-loop CEM planner that scores candidate action sequences by rolling them out in ``env``.

    ``env`` serves both as the simulator used to score candidates and (in the
    notebook loop) as the environment that executes the chosen actions.

    Args:
        env: environment exposing ``reset()`` and a 4-tuple ``step(action)``.
        horizon: number of actions in one planned sequence.
        num_sequence_action: forwarded to ``CEM_opt`` as ``num_action_seq``
            (presumably the number of candidate sequences per iteration).
        action_dim: size of one action; defaults to the notebook-level
            ``env_action_space_shape``.
        percent_elite: forwarded to ``CEM_opt``.
    """

    def __init__(self, env, horizon=40, num_sequence_action=100, action_dim=None,
                 percent_elite=0.1):
        self.env = env
        self.horizon = horizon
        self.num_sequence_action = num_sequence_action
        self.action_dim = env_action_space_shape if action_dim is None else action_dim
        self.cem = CEM_opt(num_action_seq=self.num_sequence_action,
                           action_seq_len=self.action_dim * self.horizon,
                           percent_elite=percent_elite)
        # FIFO of the actions of the current plan that have not been executed yet.
        self.action_seq_planned = Queue(maxsize=self.horizon)

    def plan(self, force_replan=False):
        """Return the next action to execute, running one CEM iteration if a new plan is needed.

        A new plan is made when no planned actions remain or ``force_replan`` is set:
        every sequence in ``cem.population`` is scored with ``eval_act_seq``, the scores
        go to ``cem.update``, and the actions of ``cem.solutions()`` are queued.
        """
        if force_replan:
            # Drop leftover actions: the queue is bounded by `horizon`, so pushing a
            # fresh plan on top of a non-empty queue would block forever in put().
            self.action_seq_planned = Queue(maxsize=self.horizon)

        if self.action_seq_planned.empty():
            action_sequences = self.cem.population
            rewards = np.array([self.eval_act_seq(seq) for seq in action_sequences],
                               dtype=float)
            self.cem.update(rewards)

            # The rollouts above reset and stepped self.env, leaving it wherever the
            # last candidate ended. Reset again so the plan is executed from the reset
            # state it was scored from.
            # NOTE(review): assumes reset() is deterministic (e.g. frozen task/goal) so the
            # observation the caller obtained from its own reset() still matches -- confirm.
            self.env.reset()

            for act in self.cem.solutions().reshape(-1, self.action_dim):
                self.action_seq_planned.put(act)
        return self.action_seq_planned.get()

    def eval_act_seq(self, sequence):
        """Reset ``self.env``, execute ``sequence`` and return the mean per-step reward.

        ``sequence`` is a flat array of ``n_steps * action_dim`` values, reshaped into
        one row per step. An empty sequence scores 0.0.
        """
        self.env.reset()
        steps = sequence.reshape((-1, self.action_dim))
        if len(steps) == 0:
            return 0.0
        total_reward = 0
        for act in steps:
            _, r, _, _ = self.env.step(act)
            total_reward += r
        # Average over executed steps; dividing by the flat len(sequence) would
        # under-report the mean by a factor of action_dim.
        return total_reward / len(steps)


In [59]:
# Number of evaluation episodes collected per environment.
N_EPISODES = 20

planners = {k: InformedPlanner(env) for k, env in envs.items()}


for env_k, planner in planners.items():
    env = envs[env_k]
    # Roll out exactly one plan's worth of steps per episode.
    horizon = planner.horizon

    # Start from +/-inf so the first reward always updates both extremes,
    # whatever its sign or magnitude.
    min_rew = float('inf')
    max_rew = float('-inf')

    for episode in range(N_EPISODES):
        state = env.reset()
        clear_output(wait=True)

        for horizon_step in range(horizon):
            action = planner.plan()
            s_prime, reward, done, _ = env.step(action)

            buffer.add(state, action, reward, s_prime, done, env_k)
            state = s_prime
            min_rew = min(min_rew, reward)
            max_rew = max(max_rew, reward)

            print(f'{env_k=}, {episode=}, {horizon_step=}, {reward=}, {done=}')

    # Append one summary line per environment to the run log.
    with open('log_cem_fake.txt', 'a+') as f:
        f.write(f'env {env_k} : envs : {env!s} : {min_rew=}, {max_rew=} (horizon={horizon})\n')
        

env_k=22, episode=19, horizon_step=0, reward=0.9360255775307187, done=False
env_k=22, episode=19, horizon_step=1, reward=0.9377559178867427, done=False
env_k=22, episode=19, horizon_step=2, reward=0.940641457766598, done=False
env_k=22, episode=19, horizon_step=3, reward=0.9444797164955022, done=False
env_k=22, episode=19, horizon_step=4, reward=0.9491589289275113, done=False
env_k=22, episode=19, horizon_step=5, reward=0.9543104360373054, done=False
env_k=22, episode=19, horizon_step=6, reward=0.9597786329688927, done=False
env_k=22, episode=19, horizon_step=7, reward=0.9652745176253463, done=False
env_k=22, episode=19, horizon_step=8, reward=0.9706125893449556, done=False
env_k=22, episode=19, horizon_step=9, reward=0.9756192665127338, done=False
env_k=22, episode=19, horizon_step=10, reward=0.9800919051040808, done=False
env_k=22, episode=19, horizon_step=11, reward=0.9839795160886217, done=False
env_k=22, episode=19, horizon_step=12, reward=0.9872949167115987, done=False
env_k=22, 

In [62]:
# Dump the collected transitions; write_buffer is defined in buffers.MultiEnvReplayBuffer
# (presumably it treats this as an output directory -- TODO confirm).
buffer_out_dir = 'buffer_stock/'
buffer.write_buffer(buffer_out_dir)