In [1]:
import pandas as pd
import math
import random
import numpy as np

%run ./ast_parser.ipynb

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
class EnvManager():
    """Episode-based environment built from labelled code samples.

    Every row of ``all_episodes_df`` is one state. Rows that share an
    ``episode_num`` form one episode and are visited in ascending ``level``
    order. An action is an index into ``action_space``; it is rewarded +1
    when it selects the label of the current row and -1 otherwise.
    """

    # Columns kept verbatim; every other column is category-encoded.
    _RAW_COLUMNS = ("level", "label", "episode_num")

    def __init__(self, device):
        """Load ``./data/data.json`` and build the encoded state table.

        Args:
            device: Stored on the instance; not used inside this class.
        """
        self.device = device
        self.episodes = pd.read_json("./data/data.json")

        # Collect the per-episode frames and concatenate once, instead of
        # re-copying the growing table on every iteration.
        frames = []
        for index, episode in self.episodes.iterrows():
            states = pd.DataFrame(
                code2states(episode["code"], episode["critical_positions"], episode["labels"])
            )
            states["episode_num"] = index
            frames.append(states)
        self.all_episodes_df = (
            pd.concat(frames, axis=0, ignore_index=True) if frames else pd.DataFrame()
        )

        self.action_space = self.all_episodes_df["label"].unique()
        self.action_space_shape = len(self.action_space)
        self.state_space_shape = len(self.all_episodes_df.columns)

        # Replace each feature value by the Python type of that value
        # (missing values become None) and store the categorical codes.
        for column in self.all_episodes_df.columns:
            if column in self._RAW_COLUMNS:
                continue
            encoded = pd.Series(
                [self._encode_feature(value) for value in self.all_episodes_df[column]],
                index=self.all_episodes_df.index,
                dtype=object,
            )
            self.all_episodes_df[column] = pd.Categorical(encoded).codes

        ## Init the env
        self.done = False  # Set to True when the agent has traversed the whole episode
        self.current_episode = None  # Rows (states) of the current episode
        self.current_state = None  # Feature vector of the current state
        self.current_state_index = None  # Position of the current state in the episode
        self.right_actions = []  # Store the right action for each node
        self.critical_parts = []  # Indices of the nodes with problems in the tree
        self.tree = []  # A list containing all the nodes in the current episode

    @staticmethod
    def _encode_feature(value):
        """Return None for a NaN value, otherwise the Python type of ``value``."""
        try:
            missing = math.isnan(value)
        except TypeError:
            # Strings, lists and other non-numeric objects cannot be NaN.
            missing = False
        return None if missing else type(value)

    def _state_at(self, position):
        """Return the feature vector (label column removed) of the row at ``position``."""
        features = self.current_episode.drop(["label"], axis=1)
        return features.iloc[position, :].to_numpy()

    ## Reset the env
    def reset(self):
        """Pick a random episode and position the environment on its first state."""
        self.right_actions = []
        self.done = False

        # Choose one of the available episode ids at random.
        episode_ids = self.all_episodes_df["episode_num"].unique()
        chosen_id = random.choice(episode_ids)

        # Keep only the rows of that episode, ordered by "level".
        in_episode = self.all_episodes_df["episode_num"] == chosen_id
        self.current_episode = self.all_episodes_df.loc[in_episode].sort_values("level")

        self.current_state_index = 0
        self.current_state = self._state_at(self.current_state_index)

    # Map action state to a reward
    def take_action(self, action):
        """Score ``action`` against the current row's label and advance.

        Args:
            action: Index into ``action_space``.

        Returns:
            Tuple ``(state, reward, done)``. ``reward`` is +1 when the chosen
            action equals the label of the current row and -1 otherwise.
            ``state`` is the feature vector of the next row, or an all-zero
            vector of the same shape once the last row has been scored.
        """
        chosen_label = self.action_space[action]
        expected_label = self.current_episode["label"].iloc[self.current_state_index]
        reward = +1 if expected_label == chosen_label else -1

        last_position = len(self.current_episode.index) - 1
        if self.current_state_index == last_position:
            self.done = True
        else:
            self.current_state_index = self.current_state_index + 1
            # Refresh the observation; without this the state set by reset()
            # would be returned for every step of the episode.
            self.current_state = self._state_at(self.current_state_index)

        return self.get_state(), reward, self.done

    def get_state(self):
        """Return the current state, or zeros of the same shape when done."""
        state = np.array(self.current_state)
        return np.zeros_like(state) if self.done else state

    # Get the number of actions available for the agent
    def num_actions_available(self):
        """Return the number of distinct labels in the action space."""
        return len(self.action_space)

    def num_state_features(self):
        """Return the number of columns in a state vector (all columns but the label)."""
        return self.state_space_shape - 1  # remove the label