In [1]:
import torch
import pandas as pd
import ast
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def _has_ast_node(value):
    """Return True if ``value`` is an AST node or a list/tuple containing one."""
    if isinstance(value, ast.AST):
        return True
    if isinstance(value, (list, tuple)):
        return any(_has_ast_node(item) for item in value)
    return False


def str_node(node, info, level):
    """Render ``node`` as a one-line string and record its fields in ``info``.

    Args:
        node: An ``ast.AST`` instance, or any other value found in an AST field.
        info: Dict that is mutated in place. It receives ``"level"``, one entry
            per field of ``node`` and, for AST nodes, ``"type"`` (the class
            name of ``node``).
        level: Depth value stored under ``info["level"]``.

    Returns:
        ``"ClassName(field=..., ...)"`` for AST nodes, ``repr(node)`` otherwise.

    Notes:
        * Fields whose value is an AST node (or a list holding one) are stored
          as ``None``; every other field value is stored unchanged.
        * The same ``info`` dict is handed to the recursive calls, so the
          fields of directly nested child nodes are merged into it as well.
          ``info["type"]`` is written after the recursion and therefore always
          names ``node`` itself.
        * ``left`` / ``right`` fields are recorded in ``info`` but are left out
          of the rendered string and are not recursed into.
    """
    info["level"] = level
    if not isinstance(node, ast.AST):
        return repr(node)

    for name, val in ast.iter_fields(node):
        # Type-based check: a substring test on str(val) would also blank out
        # plain strings such as the identifier "last", and would depend on the
        # exact textual repr of AST objects.
        info[name] = None if _has_ast_node(val) else val

    # NOTE(review): recursing with the shared ``info`` merges child fields into
    # this node's record; kept as-is because the set of recorded keys depends
    # on it — confirm this flattening is intended.
    fields = [
        (name, str_node(val, info, level))
        for name, val in ast.iter_fields(node)
        if name not in ('left', 'right')
    ]
    info["type"] = node.__class__.__name__
    rendered = ', '.join('%s=%s' % field for field in fields)
    return '%s(%s)' % (node.__class__.__name__, rendered)

def ast_visit(node, state_types, level=0):
    """Walk ``node`` depth-first and append one info dict per visited AST node.

    Args:
        node: The ``ast.AST`` node to start from.
        state_types: List that is extended in place; each visited node adds the
            dict filled in by ``str_node``.
        level: Depth recorded for ``node``; children are visited with
            ``level + 1``.
    """
    record = {}
    str_node(node, record, level)
    state_types.append(record)

    child_level = level + 1
    for _field_name, child in ast.iter_fields(node):
        if isinstance(child, ast.AST):
            # Single child node stored directly in the field.
            ast_visit(child, state_types, level=child_level)
        elif isinstance(child, list):
            # Field holding a sequence; only AST entries are descended into.
            for element in child:
                if isinstance(element, ast.AST):
                    ast_visit(element, state_types, level=child_level)

In [25]:
# state = { type ,  valueType, }

class EnvManager():
    """Environment that steps an agent through the AST nodes of a code sample.

    Episodes are rows of a JSON table. The code in this class reads the
    columns ``"code"`` (parsed with ``ast.parse``), ``"label"`` and
    ``"critical"`` (tested with ``index in ...`` against node positions).
    The agent picks one action per node; the expected action is the episode's
    label for critical nodes and ``"Ok"`` for all others.

    NOTE(review): nothing in this class assigns ``current_state`` after
    ``__init__`` sets it to ``None``; ``get_state`` needs it to be set to
    something ``torch.tensor`` accepts — confirm where that is meant to happen.
    """

    def __init__(self, device, data_path="./data/data.json"):
        """Load the episodes and derive the action and state spaces.

        Args:
            device: Torch device used for the tensors this environment returns.
            data_path: Location of the episodes file passed to
                ``pandas.read_json``. Defaults to the original hard-coded path.
        """
        self.device = device
        ## Load all the episodes
        self.episodes = pd.read_json(data_path)
        ## Extract all possible labels from the episodes
        ## They represent the output of the DQN-Agent
        self.action_space = self.episodes["label"].unique().tolist()
        # Add the "nothing wrong" action once, even if the data already uses it,
        # so the action space never contains the same label twice.
        if "Ok" not in self.action_space:
            self.action_space.append("Ok")

        ## Get all possible states from the episodes
        states = []
        for _, episode in self.episodes.iterrows():
            ast_visit(ast.parse(episode["code"]), states)
        states = pd.DataFrame(states)

        # One feature per distinct key recorded across all visited nodes.
        self.state_space_shape = len(states.columns)

        # Replace each recorded value by its Python type in one vectorized pass
        # (no per-cell .loc writes); skipped when no node produced a "value" key.
        # NOTE(review): `states` is not used after this point in the visible code.
        if "value" in states.columns:
            states["value"] = states["value"].map(type)

        ## Init the env
        self.done = False  # Set to True once the agent has traversed the whole code
        self.current_episode = None  # One-row DataFrame holding the current episode
        self.current_state = None  # Current state
        self.current_state_index = None  # Position of the current node in self.tree
        self.right_actions = []  # Expected action for each node
        self.critical_parts = []  # Indices of the nodes with problems in the tree
        self.tree = []  # All the nodes of the current episode

    ## Reset the env
    def reset(self):
        """Start a new, randomly sampled episode.

        Rebuilds ``tree``, ``critical_parts`` and ``right_actions`` and rewinds
        ``current_state_index`` to 0. ``current_state`` is left untouched.
        """
        self.tree = []
        self.critical_parts = []
        self.right_actions = []
        self.current_state_index = 0
        self.current_episode = None
        self.done = False

        self.current_episode = self.episodes.sample()  # Get a random episode
        # Traverse the code and store the nodes in a list
        ast_visit(ast.parse(self.current_episode["code"].iloc[0]), self.tree)
        self.critical_parts = self.current_episode["critical"].iloc[0]  # Nodes with problems
        self.create_right_actions_table()  # Expected action for each node

    ## Create a list with the right action to take for each node
    def create_right_actions_table(self):
        """Fill ``right_actions``: the episode label for critical nodes, else "Ok"."""
        self.right_actions = ["Ok"] * len(self.tree)
        for index in range(len(self.tree)):
            if index in self.critical_parts:
                self.right_actions[index] = self.current_episode["label"].iloc[0]

    # Map action state to a reward
    def take_action(self, action):
        """Score ``action`` for the current node and advance to the next one.

        Args:
            action: Tensor whose ``.item()`` is an index into ``action_space``.

        Returns:
            One-element tensor on ``self.device``: +1 if the selected action
            matches the expected one for the current node, -1 otherwise.
        """
        # Get the string representation of the tensor
        selected_action = self.action_space[action.item()]
        ## Compare the selected action with the expected action
        ## Reward +1 if it is correct
        ## Reward -1 if it is wrong
        if self.right_actions[self.current_state_index] == selected_action:
            reward = +1
        else:
            reward = -1
        # On the last node the episode ends and the index stays where it is.
        if self.current_state_index == len(self.tree) - 1:
            self.done = True
        else:
            self.current_state_index = self.current_state_index + 1

        return torch.tensor([reward], device=self.device)

    def get_state(self):
        """Return ``current_state`` as a float tensor (all zeros once done)."""
        state = torch.tensor(self.current_state, device=self.device).float()
        if self.done:
            return torch.zeros_like(state)
        return state

    # Get the number of actions available for the agent
    def num_actions_available(self):
        """Return the number of entries in ``action_space``."""
        return len(self.action_space)

    def num_state_features(self):
        """Return the number of state features computed in ``__init__``."""
        return self.state_space_shape


# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the environment, start a first episode and display the feature count.
env = EnvManager(device)
env.reset()
env.num_state_features()


13