In [65]:
from utils import get_trans_tuple, convert_frame
import gym
from gym_minigrid.register import env_list
from gym_minigrid.minigrid import Grid
import numpy as np
from collections import namedtuple
import copy

In [91]:
def get_trans_tuple():
    """Create the ``Transition`` namedtuple type used to store one rollout.

    The fields are:
      * ``xs``, ``actions``, ``rs``: observations, actions taken, and rewards
        received.
      * ``x_coords``, ``y_coords``, ``directions``: the agent's pose at each
        observed state, as filled in by ``DataCollector``.

    Note: this definition shadows the ``get_trans_tuple`` imported from
    ``utils`` in the first cell.

    Returns:
        A newly created namedtuple class named ``Transition``. Each call
        builds a new class object.
    """
    step_fields = ('xs', 'actions', 'rs')
    pose_fields = ('x_coords', 'y_coords', 'directions')
    return namedtuple("Transition", step_fields + pose_fields)

In [92]:
class DataCollector(object):
    """Roll out a policy in a gym environment and record what it observes.

    Each call to ``collect_transition_per_the_policy`` performs
    ``datapoints_per_trans`` policy steps. The results go into a single
    ``Transition`` namedtuple built by ``get_trans_tuple``:
      * the converted frame and agent pose before every step;
      * one extra frame and pose after the last step;
      * the action and reward of every step.
    """

    # Environment id used when the caller does not pass an ``env``.
    DEFAULT_ENV_ID = "MiniGrid-Empty-6x6-v0"

    def __init__(self, policy=lambda x0: np.random.choice(3),
                        env=None,
                        convert_fxn=convert_frame,
                        datapoints_per_trans=1):
        """
        Args:
            policy: callable mapping a converted frame to an action. The
                default ignores its input and draws uniformly from {0, 1, 2}
                using numpy's global RNG. No seed is set here. The actions are
                presumably MiniGrid's left/right/forward; confirm against the
                env's action enum.
            env: gym environment to step. If None (the default), a fresh
                ``DEFAULT_ENV_ID`` environment is created for this instance.
            convert_fxn: callable applied to rendered RGB frames. It is also
                called with ``to_tensor=True`` before querying the policy.
            datapoints_per_trans: number of policy steps per transition.
        """
        if env is None:
            # Built per instance on purpose. A `gym.make(...)` default argument
            # is evaluated once, at class definition, so every
            # default-constructed collector would share (and step) one env.
            env = gym.make(self.DEFAULT_ENV_ID)
        self.convert_fxn = convert_fxn
        self.policy = policy
        self.env = env
        self.datapoints_per_trans = datapoints_per_trans

    def _collect_datapoint(self):
        """Render the current frame and read the agent's pose.

        Returns:
            ``(x, x_coord, y_coord, direction)``, where:
              * ``x`` is ``convert_fxn`` applied to ``env.render("rgb_array")``;
              * ``x_coord`` and ``y_coord`` are ``env.agent_pos[0]`` and
                ``env.agent_pos[1]`` cast to ``int``;
              * ``direction`` is ``env.agent_dir``.
        """
        x = self.convert_fxn(self.env.render("rgb_array"))
        x_coord, y_coord = int(self.env.agent_pos[0]), int(self.env.agent_pos[1])
        direction = self.env.agent_dir
        return x, x_coord, y_coord, direction

    def append_to_trans_ar(self, trans, action, reward):
        """Append deep copies of one step's action and reward to ``trans``."""
        trans.actions.append(copy.deepcopy(action))
        trans.rs.append(copy.deepcopy(reward))

    def append_to_trans_state(self, trans, x, x_coord, y_coord, direction):
        """Append deep copies of one observed state (frame + pose) to ``trans``."""
        trans.xs.append(copy.deepcopy(x))
        trans.x_coords.append(copy.deepcopy(x_coord))
        trans.y_coords.append(copy.deepcopy(y_coord))
        trans.directions.append(copy.deepcopy(direction))

    def collect_transition_per_the_policy(self):
        """Step the policy ``datapoints_per_trans`` times from the env's current state.

        The env is not reset here; collection continues from wherever the
        env currently is.

        Returns:
            A ``Transition`` with:
              * one entry per step in ``actions`` and ``rs``;
              * one more entry than that in ``xs``, ``x_coords``, ``y_coords``
                and ``directions``. The extra entry is the state observed
                after the last step.
        """
        Transition = get_trans_tuple()
        trans = Transition(*([] for _ in Transition._fields))
        for _ in range(self.datapoints_per_trans):
            x, x_coord, y_coord, direction = self._collect_datapoint()
            self.append_to_trans_state(trans, x, x_coord, y_coord, direction)

            # to_tensor=True just in case the policy is a neural network.
            # NOTE(review): `x` was already passed through convert_fxn in
            # _collect_datapoint, so it is converted a second time here.
            # Confirm convert_fxn is safe to apply to its own output.
            action = self.policy(self.convert_fxn(x, to_tensor=True))
            # NOTE(review): `done` is not acted on. If the episode ends
            # mid-rollout, the env keeps being stepped without a reset.
            _, reward, done, _ = self.env.step(action)

            self.append_to_trans_ar(trans, action, reward)

        # Record the state reached after the final step.
        x, x_coord, y_coord, direction = self._collect_datapoint()
        self.append_to_trans_state(trans, x, x_coord, y_coord, direction)
        return trans
    
#     def collect_specific_datapoint(self,coords, direction, action):
#         self.env.agent_pos = np.asarray(coords)
#         self._get_desired_direction(direction)
#         x0 = self.env.render("rgb_array")
#         trans_obj  = self._collect_datapoint(x0, action)
#         return trans_obj
#     def _get_desired_direction(self,desired_direction):
#         true_direction = self.env.agent_dir
#         while not np.allclose(true_direction,desired_direction):
#             _ = self.env.step(0)
#             true_direction = self.env.agent_dir

In [108]:
if __name__ == "__main__":
    # Number of policy steps recorded in the demo rollout.
    DATAPOINTS_PER_TRANS = 5
    dc = DataCollector(datapoints_per_trans=DATAPOINTS_PER_TRANS)

    trans = dc.collect_transition_per_the_policy()

    # Sanity checks. There is one action/reward per step. There is one
    # state per step, plus the state reached after the last step.
    assert len(trans.actions) == len(trans.rs) == DATAPOINTS_PER_TRANS
    assert all(
        len(states) == DATAPOINTS_PER_TRANS + 1
        for states in (trans.xs, trans.x_coords, trans.y_coords, trans.directions)
    )