# Goal Recognition as a Deep Learning Task: the GRNet Approach

## Imports 

In [None]:
import numpy as np
from tensorflow.keras.models import load_model
from os.path import join
import pickle
import time
import os
from sklearn.metrics import classification_report, accuracy_score
import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
import tensorflow.keras.backend as K
from typing import Union
from keras.engine.topology import Layer
from keras import initializers, regularizers, constraints
from keras.initializers import Constant
from keras.losses import BinaryCrossentropy

## Custom Classes

### Network classes

Code from the following sources:

- *Yang, Z.; Yang, D.; Dyer, C.; He, X.; Smola, A. J.; and Hovy, E. H.* 2016. **Hierarchical Attention Networks for Document Classification**
- https://github.com/philipperemy/keras-attention-mechanism

In [None]:
class AttentionWeights(Layer):
    '''
    Attention layer that returns one normalized weight per time step.

    For a 3D input x, every step's feature vector is projected onto the
    trainable vector W (one score per step). An optional per-step bias b
    is added and tanh is applied. The exponentiated scores are then
    divided by their sum over the step axis (plus K.epsilon()). If a mask
    is passed, masked steps are zeroed before the normalization. The
    output has one value per step, shape (batch, step_dim).

    Args:
        step_dim:
            Number of time steps of the input; the scores are reshaped
            to (-1, step_dim).
        W_regularizer, b_regularizer:
            Optional regularizers for W and b (anything accepted by
            keras.regularizers.get).
        W_constraint, b_constraint:
            Optional constraints for W and b (anything accepted by
            keras.constraints.get).
        bias:
            If True, a trainable bias with one entry per time step is
            added to the scores.
    '''

    def __init__(self, step_dim,
                 W_regularizer=None, b_regularizer=None,
                 W_constraint=None, b_constraint=None,
                 bias=True, **kwargs):
        self.supports_masking = True
        self.init = initializers.get('glorot_uniform')

        self.W_regularizer = regularizers.get(W_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)

        self.W_constraint = constraints.get(W_constraint)
        self.b_constraint = constraints.get(b_constraint)

        self.bias = bias
        self.step_dim = step_dim
        # Size of the input's last axis; set in build().
        self.features_dim = 0
        super(AttentionWeights, self).__init__(**kwargs)

    def build(self, input_shape):
        # Only 3D inputs are supported. Presumably (batch, steps, features):
        # the last axis is projected onto W and axis 1 sizes the bias.
        assert len(input_shape) == 3

        self.W = self.add_weight(shape=(input_shape[-1],),
                                 initializer=self.init,
                                 name='{}_W'.format(self.name),
                                 regularizer=self.W_regularizer,
                                 constraint=self.W_constraint)
        self.features_dim = input_shape[-1]

        if self.bias:
            self.b = self.add_weight(shape=(input_shape[1],),
                                     initializer='zero',
                                     name='{}_b'.format(self.name),
                                     regularizer=self.b_regularizer,
                                     constraint=self.b_constraint)
        else:
            self.b = None

        self.built = True

    def compute_mask(self, input, input_mask=None):
        # The output is already a distribution over steps; the mask is
        # consumed in call() and not propagated further.
        return None

    def call(self, x, mask=None):
        features_dim = self.features_dim
        step_dim = self.step_dim

        # Flatten steps, project each feature vector onto W, then restore
        # one score per step: shape (-1, step_dim).
        eij = K.reshape(K.dot(K.reshape(x, (-1, features_dim)),
                        K.reshape(self.W, (features_dim, 1))), (-1, step_dim))

        if self.bias:
            eij += self.b

        eij = K.tanh(eij)

        # tanh bounds the scores to [-1, 1], so exp cannot overflow here.
        a = K.exp(eij)

        if mask is not None:
            a *= K.cast(mask, K.floatx())

        # epsilon avoids a division by zero when every step is masked.
        a /= K.cast(K.sum(a, axis=1, keepdims=True) + K.epsilon(), K.floatx())

        return a

    def compute_output_shape(self, input_shape):
        # call() returns one weight per step, i.e. (batch, step_dim).
        # Reporting features_dim here (as before) was only correct when
        # features_dim == step_dim.
        return input_shape[0], self.step_dim

    def get_config(self):
        # Serialize every constructor argument so a reloaded layer is
        # rebuilt identically (bias=False was previously lost on reload).
        config = {
            'step_dim': self.step_dim,
            'bias': self.bias,
            'W_regularizer': regularizers.serialize(self.W_regularizer),
            'b_regularizer': regularizers.serialize(self.b_regularizer),
            'W_constraint': constraints.serialize(self.W_constraint),
            'b_constraint': constraints.serialize(self.b_constraint),
        }
        base_config = super(AttentionWeights, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
    




class ContextVector(Layer):
    '''
    Collapses a sequence into one vector: the weighted sum of its steps.

    The layer takes a list of exactly two inputs [h, a]. h holds the
    per-step features and a the per-step weights. The weights get a
    trailing axis so they broadcast over h's last axis, and the weighted
    features are summed over axis 1.
    '''

    def __init__(self, **kwargs):
        super(ContextVector, self).__init__(**kwargs)
        # Size of h's last axis; filled in by build().
        self.features_dim = 0

    def build(self, input_shape):
        # Two input shapes are expected: one for h and one for a.
        assert len(input_shape) == 2
        hidden_shape = input_shape[0]
        self.features_dim = hidden_shape[-1]
        self.built = True

    def call(self, x, **kwargs):
        assert len(x) == 2
        hidden, weights = x[0], x[1]
        # Presumably a is (batch, steps) and h is (batch, steps, features);
        # expand_dims adds the trailing axis needed to broadcast a over h.
        return K.sum(hidden * K.expand_dims(weights), axis=1)

    def compute_output_shape(self, input_shape):
        batch_dim = input_shape[0][0]
        return batch_dim, self.features_dim

    def get_config(self):
        # No extra constructor arguments: the base config is sufficient.
        return dict(super(ContextVector, self).get_config())

### Constants class

In [None]:
class C:
    '''
    Constants class.

    Groups the integer codes and settings shared by the notebook:
    parse_file content types, domain identifiers, get_domain_related
    element selectors, model variants, directories, the per-domain model
    slots that init_models fills and get_model reads, and the custom
    objects passed to load_model.
    '''
    # Content types accepted by parse_file.
    OBSERVATIONS = 0
    CORRECT_GOAL = 1
    POSSIBLE_GOALS = 2 
    
    # Domain identifiers (see parse_domain).
    SATELLITE = 0
    LOGISTICS = 1
    ZENOTRAVEL = 2
    BLOCKSWORLD = 3
    DRIVERLOG = 4
    DEPOTS = 5
    
    # `element` selectors for get_domain_related.
    MAX_PLAN_LENGTH = 0
    MODEL_FILE = 1
    DICTIONARIES_DICT = 2
    
    # Model variants (`model_type`) that select the model file name.
    SMALL = 0
    COMPLETE = 1
    PERCENTAGE = 2
    
    # Locations of the saved models and of the per-domain dictionaries.
    MODELS_DIR = './models/'
    DICTIONARIES_DIR = './dictionaries/'
    #MODELS_DIR = './incremental_models/'
    
    # Loaded models, one per domain; None until init_models() runs.
    MODEL_LOGISTICS = None
    MODEL_SATELLITE = None
    MODEL_ZENOTRAVEL = None
    # NOTE: "BLOCKSWORLS" is a misspelling, but get_model and init_models
    # reference this exact name, so it must be renamed everywhere at once.
    MODEL_BLOCKSWORLS = None
    MODEL_DRIVERLOG = None
    MODEL_DEPOTS = None

    # Fraction of each domain's maximum plan length used as the model
    # input length (applied in get_domain_related).
    MAX_PLAN_PERCENTAGE = 0.7

    # Presumably column headers for a results table; not used in the
    # code visible here.
    TABLE_HEADERS = ['', 'Pereira', 'Our', 'Support']
    
    # custom_objects for keras load_model.
    # NOTE(review): 'custom_multilabel_loss_v3' is mapped to the
    # BinaryCrossentropy *class*, not to a loss function; presumably this
    # only lets saved models deserialize for inference. Confirm before
    # compiling or training a loaded model.
    CUSTOM_OBJECTS = {'AttentionWeights': AttentionWeights,
                   'ContextVector' : ContextVector,
                   'custom_multilabel_loss_v3' : BinaryCrossentropy}


### Exceptions

In [None]:
class PlanLengthError(Exception):
    '''Error related to the length of a plan (not raised in the code shown here).'''


class FileFormatError(Exception):
    '''Raised when an input file, or a line in it, is empty or malformed.'''


class UnknownIndexError(Exception):
    '''Raised when a numeric selector (e.g. a parse_file content type) is unknown.'''

## Custom Methods

### Unpack files methods

In [None]:
def _reset_dir(target_dir: str) -> None:
    '''
    Make target_dir an empty directory. If it already exists, it is
    deleted first, including any nested subdirectories.
    '''
    import shutil
    if os.path.exists(target_dir):
        # rmtree, unlike os.remove on each entry, also handles
        # subdirectories left behind by a previous extraction.
        shutil.rmtree(target_dir)
    os.makedirs(target_dir)


def unzip_file(file_path: str, target_dir: str) -> None:
    '''
    Unzip a file in an empty directory. The directory is 
    emptied before the execution.
    
    Args:
        file_path:
            A string that contains the path
            to the .zip file.
        
        target_dir:
            A string that contains the path 
            to the target directory. This 
            directory is created if it doesn't
            exist and it is emptied if it exists.

    Raises:
        FileNotFoundError:
            file_path does not exist.
        zipfile.BadZipFile:
            file_path is not a valid zip archive.
    '''
    import zipfile
    _reset_dir(target_dir)
    # zipfile needs no external `unzip` binary and spawns no shell. The
    # old os.system call broke on paths with spaces or shell
    # metacharacters, and silently ignored extraction failures.
    with zipfile.ZipFile(file_path) as archive:
        archive.extractall(target_dir)


def unpack_bz2(file_path: str, target_dir: str) -> None:
    '''
    Unpack a (bzip2-compressed) tar archive in an empty directory. The
    directory is emptied before the execution.
    
    Args:
        file_path:
            A string that contains the path
            to the .bz2 file.
        
        target_dir:
            A string that contains the path 
            to the target directory. This 
            directory is created if it doesn't
            exist and it is emptied if it exists.

    Raises:
        FileNotFoundError:
            file_path does not exist.
        tarfile.ReadError:
            file_path is not a readable tar archive.
    '''
    import tarfile
    _reset_dir(target_dir)
    # 'r:*' auto-detects the compression, as `tar -xf` did.
    with tarfile.open(file_path, 'r:*') as archive:
        if hasattr(tarfile, 'data_filter'):
            # Python 3.12+ (and security backports): reject absolute
            # paths, '..' components and other unsafe archive members.
            archive.extractall(target_dir, filter='data')
        else:
            archive.extractall(target_dir)

### Input parse methods

In [None]:
def load_file(file: str, binary: bool = False, use_pickle: bool = False):
    '''
    Read a file and return its content.

    Args:
        file:
            Path of the file to read.
        binary:
            Optional. If True the file is opened in binary mode ('rb'),
            otherwise in text mode ('r').
        use_pickle:
            Optional. If True the content is deserialized with
            pickle.load; otherwise the list of lines is returned.

    Returns:
        The unpickled object, or the list of lines of the file.

    Raises:
        FileNotFoundError:
            The file does not exist.
    '''
    mode = 'rb' if binary else 'r'
    # The with-statement closes the handle on every path.
    with open(file, mode) as handle:
        return pickle.load(handle) if use_pickle else handle.readlines()
        
        

In [None]:
def parse_file(read_file: str, content_type: int, dictionary: dict = None):
    '''
    Read a file and parse it according to its content type.

    Args:
        read_file:
            Path of the file to parse.
        content_type:
            Which parser to apply:
                C.OBSERVATIONS (0): observations file,
                C.CORRECT_GOAL (1): correct goal file (first line only),
                C.POSSIBLE_GOALS (2): possible goals file.
        dictionary:
            Optional mapping passed to the parser to translate each
            element into its identifier.

    Returns:
        The list of parsed elements.

    Raises:
        FileFormatError:
            The file is empty, or parsing produced no elements.
        UnknownIndexError:
            content_type is not one of the values above.
    '''
    empty_msg = f'File {read_file} is empty.'
    unknown_msg = f'Content type {content_type} is unknown.' 

    lines = load_file(read_file)
    # The emptiness check happens before the content type is validated.
    if not lines:
        raise FileFormatError(empty_msg)

    if content_type == C.OBSERVATIONS:
        parsed = parse_observations(lines, dictionary)
    elif content_type == C.POSSIBLE_GOALS:
        parsed = parse_possible_goals(lines, dictionary)
    elif content_type == C.CORRECT_GOAL:
        parsed = parse_correct_goal(lines[0], dictionary)
    else:
        raise UnknownIndexError(unknown_msg)

    if not parsed:
        raise FileFormatError(empty_msg)
    return parsed
        

def remove_parentheses(line: str) -> str:
    '''
    Strip the surrounding parentheses (and whitespace) from a line.

    Args:
        line: a string enclosed in parentheses, e.g. "(string example)".

    Returns:
        The inner text, stripped. None if the line is blank or the
        parentheses enclose only whitespace.

    Raises:
        FileFormatError: the non-blank line is not enclosed in parentheses.
    '''
    stripped = line.strip()
    if not stripped:
        return None
    if stripped[0] == '(' and stripped[-1] == ')':
        inner = stripped[1:-1].strip()
        return inner if inner else None
    # The message quotes the line exactly as received (before stripping).
    raise FileFormatError('Error while parsing a line. Expected "(custom '
                          f'text)" but found "{line}"')
        
def retrieve_from_dict(key: str, dictionary: dict):
    '''
    Return the dictionary value for the upper-cased key.

    If the key is missing, a warning is printed and a pseudo-random
    fallback index in [0, len(dictionary)) is returned instead. The
    fallback comes from a private generator seeded with 47, so it is
    reproducible and leaves NumPy's global random state untouched.

    Args:
        key:
            A string; it is upper-cased before the lookup.
        dictionary:
            A dict whose keys are upper-case strings.

    Returns:
        The value for key.upper(), or the fallback index when the key
        is missing.

    Raises:
        KeyError:
            The key is missing and the dictionary is empty, so no
            fallback index can be drawn.
    '''
    upper_key = key.upper()
    try:
        return dictionary[upper_key]
    except KeyError:
        print(f'Key {upper_key} is not in the dictionary')
        if len(dictionary) == 0:
            raise
        # A local RandomState(47) yields the same value that the old
        # np.random.seed(47) + np.random.randint produced. It does not
        # reset the global RNG, which get_result uses for tie-breaking.
        return np.random.RandomState(47).randint(0, len(dictionary))

def parse_correct_goal(line: str, goals_dict: dict = None) -> list:
    '''
    Split a goal line into its fluents.

    Args:
        line:
            Comma-separated fluents, each enclosed in parentheses,
            e.g. "(fluent1), (fluent2),  (fluent3)".
        goals_dict:
            Optional. Maps each fluent (upper-cased) to its identifier.
            When given, identifiers are returned instead of strings.

    Returns:
        The fluents without parentheses (or their identifiers), in order.
        Blank entries are skipped.

    Raises:
        FileFormatError:
            A fluent is malformed, or the line contains no fluents.
    '''
    # Lazy generator: each chunk is cleaned and then (optionally) looked
    # up before the next one is processed.
    cleaned = (remove_parentheses(chunk) for chunk in line.strip().split(','))
    goal = [
        fluent if goals_dict is None else retrieve_from_dict(fluent, goals_dict)
        for fluent in cleaned
        if fluent is not None
    ]
    if not goal:
        raise FileFormatError('Parsed goal is empty.')
    return goal
    

        
def parse_observations(lines: list, obs_dict: dict = None) -> list:
    '''
    Turn the raw lines of an observations file into a list of actions.

    Args:
        lines:
            Strings with one observation each, enclosed in parentheses.
            Blank lines are allowed, e.g.
            ['(observation1)', '', '(observation2)'].
        obs_dict:
            Optional. Maps each observation (upper-cased) to its
            identifier. When given, identifiers are returned.

    Returns:
        The observations without parentheses (or their identifiers),
        with blank entries removed.

    Raises:
        FileFormatError:
            A line is malformed, or no observation is found.
    '''
    observations = []
    for raw_line in lines:
        observation = remove_parentheses(raw_line)
        if observation is None:
            continue
        if obs_dict is not None:
            observation = retrieve_from_dict(observation, obs_dict)
        observations.append(observation)

    if not observations:
        raise FileFormatError('Observations list is empty.')
    return observations

def parse_possible_goals(lines: list, goals_dict: dict = None) -> list:
    '''
    Parse every non-blank line as a goal (see parse_correct_goal).

    Args:
        lines:
            Strings, one possible goal per line.
        goals_dict:
            Optional. Maps each fluent (upper-cased) to its identifier.

    Returns:
        A list of goals; each goal is the list of its fluents (or their
        identifiers).

    Raises:
        FileFormatError:
            A goal is malformed, or no goal is found.
    '''
    trimmed = (text.strip() for text in lines)
    goals = [parse_correct_goal(text, goals_dict) for text in trimmed if text]
    if not goals:
        raise FileFormatError('Possible goals list is empty.')
    return goals
            
            

### Model related methods

In [None]:
def parse_domain(domain: Union[str, int]) -> int:
    '''
    Normalize a domain given by name or by numeric code.

    Args:
        domain:
            A domain name (case-insensitive, surrounding whitespace
            ignored) or its integer code, as an int or a digit string.

    Returns:
        The domain's integer code (one of the C domain constants).

    Raises:
        KeyError:
            The value matches no supported domain.
    '''
    msg = (f'Provided domain {domain} is not supported. '+
           f'Supported domains are: {C.SATELLITE} : satellite, ' +
           f'{C.LOGISTICS} : logistics, {C.BLOCKSWORLD} : blocksworld, ' +
           f'{C.ZENOTRAVEL} : zenotravel, {C.DRIVERLOG}: driverlog,' + 
           f'{C.DEPOTS}: depots.')

    codes_by_name = {
        'satellite': C.SATELLITE,
        'logistics': C.LOGISTICS,
        'blocksworld': C.BLOCKSWORLD,
        'zenotravel': C.ZENOTRAVEL,
        'driverlog': C.DRIVERLOG,
        'depots': C.DEPOTS,
    }

    text = str(domain)
    # Numeric form: accepted only if it is one of the known codes.
    if text.isdigit() and int(text) in codes_by_name.values():
        return int(text)

    name = text.lower().strip()
    if name in codes_by_name:
        return codes_by_name[name]
    raise KeyError(msg)

In [None]:
def get_model(domain: int):
    '''
    Return the in-memory model for a domain.

    Args:
        domain:
            The domain's integer code (a C domain constant).

    Returns:
        The model stored on C for that domain, or None if init_models
        has not loaded it.

    Raises:
        KeyError:
            The code matches no supported domain.
    '''
    msg = (f'Provided domain {domain} is not supported. '+
       f'Supported domains are: {C.SATELLITE} : satellite, ' +
       f'{C.LOGISTICS} : logistics, {C.BLOCKSWORLD} : blocksworld, ' +
       f'{C.ZENOTRAVEL} : zenotravel, {C.DRIVERLOG}: driverlog,' + 
       f'{C.DEPOTS}: depots.')

    # Compared with == in this fixed order (no hashing of `domain`).
    slots = (
        (C.LOGISTICS, C.MODEL_LOGISTICS),
        (C.SATELLITE, C.MODEL_SATELLITE),
        (C.DEPOTS, C.MODEL_DEPOTS),
        (C.BLOCKSWORLD, C.MODEL_BLOCKSWORLS),
        (C.DRIVERLOG, C.MODEL_DRIVERLOG),
        (C.ZENOTRAVEL, C.MODEL_ZENOTRAVEL),
    )
    for code, model in slots:
        if domain == code:
            return model
    raise KeyError(msg)

In [None]:
def get_domain_related(domain: int, element: int, model_type: int = C.SMALL, 
                       percentage: float = 0) -> Union[int, str]:
    '''
    Returns domain related information.
    
    Args:
        domain: 
            an integer associated to a specific 
            domain.
        
        element:
            an integer associated to a specific
            piece of information to retrieve.
        
        model_type:
            an integer associated to the type
            of RNN model in use.
        
        percentage:
            a float that represents the model
            percentage to use. Use only with
            model_type = C.PERCENTAGE.
    
    Returns: 
        element == C.MAX_PLAN_LENGTH: the model input length, i.e. the
            domain's maximum plan length times C.MAX_PLAN_PERCENTAGE,
            truncated to int.
        element == C.MODEL_FILE: the model file name for model_type.
        element == C.DICTIONARIES_DICT: the domain's dictionaries directory.

    Raises:
        KeyError:
            Unsupported domain.
        UnknownIndexError:
            Unknown element or model_type (previously None was
            returned silently).
    '''
    msg = (f'Provided domain {domain} is not supported. ' +
           f'Supported domains are: {C.SATELLITE} : satellite, ' +
           f'{C.LOGISTICS} : logistics, {C.BLOCKSWORLD} : blocksworld, ' +
           f'{C.ZENOTRAVEL} : zenotravel, {C.DRIVERLOG} : driverlog, ' +
           f'{C.DEPOTS} : depots.')

    # domain code -> (maximum plan length, name used for file/dir names)
    domains = {
        C.LOGISTICS: (50, 'logistics'),
        C.SATELLITE: (40, 'satellite'),
        C.ZENOTRAVEL: (40, 'zenotravel'),
        C.BLOCKSWORLD: (75, 'blocksworld'),
        C.DRIVERLOG: (70, 'driverlog'),
        C.DEPOTS: (64, 'depots'),
    }
    try:
        max_plan_len, name = domains[domain]
    except (KeyError, TypeError):
        raise KeyError(msg) from None

    if element == C.MAX_PLAN_LENGTH:
        # Kept as int() truncation: changing the rounding would change the
        # input length fed to the already-trained models.
        return int(max_plan_len * C.MAX_PLAN_PERCENTAGE)

    if element == C.MODEL_FILE:
        if model_type == C.COMPLETE:
            return f'{name}.h5'
        if model_type == C.SMALL:
            return f'{name}_small.h5'
        if model_type == C.PERCENTAGE:
            # round() rather than int(): int(0.29 * 100) == 28 because
            # 0.29 * 100 == 28.999999999999996.
            return f'{name}_{round(percentage * 100)}perc.h5'
        raise UnknownIndexError(f'Model type {model_type} is unknown.')

    if element == C.DICTIONARIES_DICT:
        return join(C.DICTIONARIES_DIR, name)

    raise UnknownIndexError(f'Element {element} is unknown.')


### Domain component methods

In [None]:
def get_observations_array(observations: list, max_plan_length: int) -> np.ndarray:
    '''
    Build the model input array from a list of observation indexes.
    
    Args:
        observations: 
            A list of observation indexes (anything int() accepts).
            
        max_plan_length:
            An integer that contains the maximum size of
            the list that will be considered.
    
    Returns:
        A float array of shape (1, max_plan_length). The first
        max_plan_length observations are placed at the start and the
        remaining positions stay 0.
    '''
    
    WARNING_MSG = (f'The action trace is too long. Only the first {max_plan_length} '+
                 f'actions will be considered.')
    
    observations_array = np.zeros((1, max_plan_length))
    if len(observations) > max_plan_length:
        print(WARNING_MSG)
    # Longer traces are truncated; shorter ones are zero-padded.
    kept = [int(observation) for observation in observations[:max_plan_length]]
    observations_array[0, :len(kept)] = kept
    return observations_array
        

def get_predictions(observations: list, 
                    max_plan_length: int, 
                    domain: int) -> np.ndarray:
    '''
    Run the domain's model on a sequence of observations.
    
    Args:
        observations:
            A list of observation indexes (presumably produced by
            parse_file with an actions dictionary).
        
        max_plan_length:
            An integer that contains the maximum size of
            the list that will be considered.
        
        domain:
            An integer associated to a specific domain.
    
    Returns:
        The output of model.predict on the (1, max_plan_length) input.

    Raises:
        RuntimeError:
            No model is loaded for the domain (init_models not called).
    '''
    model = get_model(domain)
    if model is None:
        # Fail with a clear hint instead of an AttributeError on None.
        raise RuntimeError(
            f'No model is loaded for domain {domain}; call init_models() first.')

    model_input = tf.convert_to_tensor(get_observations_array(observations, max_plan_length))
    return model.predict(model_input)



### GR Instance component methods

In [None]:
def get_score(prediction: np.ndarray, possible_goal: list) -> float:
    '''
    Score a goal as the sum of the predicted values of its fluents.

    Args:
        prediction:
            The model output; row 0 is indexed by fluent identifier.
        possible_goal:
            The fluent identifiers of the goal (anything int() accepts).

    Returns:
        The sum of prediction[0][i] over the goal's fluents (0 for an
        empty goal).
    '''
    row = prediction[0]
    return sum(row[int(fluent)] for fluent in possible_goal)

def get_scores(prediction: np.ndarray, possible_goals: list) -> np.ndarray:
    '''
    Score every possible goal with get_score.

    Args:
        prediction:
            The model output.
        possible_goals:
            A list of goals, each a list of fluent identifiers.

    Returns:
        A 1D float array with one score per goal, in input order.
    '''
    return np.array(
        [get_score(prediction, goal) for goal in possible_goals],
        dtype=float,
    )
        

def get_max(scores: np.ndarray) -> list:
    '''
    Returns a list with the index (or indexes) of the highest scores.
    
    Args:
        scores:
            A sequence of float scores. NaN entries are ignored.
    
    Returns:
        The indexes (in increasing order) of every score equal to the
        maximum. Empty only if scores is empty or all NaN.
    '''
    # -inf rather than -1: the old sentinel returned [] whenever every
    # score was below -1, and callers index the result with [0].
    best = float('-inf')
    winners = []
    for position, value in enumerate(scores):
        if value > best:
            best = value
            winners = [position]
        elif value == best:
            winners.append(position)
    return winners
    
def get_result(scores: np.ndarray, correct_goal: int) -> bool:
    '''
    Tell whether the best-scoring goal is the correct one.

    Ties are broken by drawing one of the tied indexes with NumPy's
    global random generator; a message is printed when that happens.

    Args:
        scores:
            One score per possible goal.
        correct_goal:
            Index of the correct goal (None if it is not among the
            possible goals, which always yields False).

    Returns:
        True if the chosen index equals correct_goal, False otherwise.
    '''
    candidates = get_max(scores)
    if len(candidates) != 1:
        print(f'Algorithm chose randomly one of {len(candidates)} equals candidates.')
        chosen = candidates[np.random.randint(0, len(candidates))]
    else:
        chosen = candidates[0]
    return True if chosen == correct_goal else False
    
def get_correct_goal_idx(correct_goal: list, possible_goals: list) -> int:
    '''
    Computes the correct goal index.
    
    Two goals match when they contain the same fluents, regardless of
    order.

    Args:
        correct_goal:
            The fluents (or identifiers) of the correct goal.
        possible_goals:
            A list of possible goals; each possible goal is a list.
    
    Returns:
        The index of the first matching goal in possible_goals.
        None if no possible goal matches.
    '''
    # Sort once, outside the loop. Comparing sorted lists (rather than
    # np.all(a == b)) handles goals of different sizes: NumPy either
    # raises on the shape mismatch or broadcasts length-1 goals into
    # false matches.
    target = sorted(correct_goal)
    for index, possible_goal in enumerate(possible_goals):
        if sorted(possible_goal) == target:
            return index
    return None

### GRNet execution methods

In [None]:
def init_models(model_type: int, percentage: float) -> None:
    '''
    Load every domain's model and store it on the matching C attribute.

    Args:
        model_type:
            The model variant (C.SMALL, C.COMPLETE or C.PERCENTAGE),
            used to pick the file name.
        percentage:
            Model percentage; only used when model_type is C.PERCENTAGE.

    Returns:
        None
    '''
    # (domain code, C attribute that receives its model), in load order.
    slots = (
        (C.LOGISTICS, 'MODEL_LOGISTICS'),
        (C.SATELLITE, 'MODEL_SATELLITE'),
        (C.ZENOTRAVEL, 'MODEL_ZENOTRAVEL'),
        (C.DEPOTS, 'MODEL_DEPOTS'),
        (C.DRIVERLOG, 'MODEL_DRIVERLOG'),
        (C.BLOCKSWORLD, 'MODEL_BLOCKSWORLS'),
    )
    for domain_code, attribute in slots:
        file_name = get_domain_related(domain_code, C.MODEL_FILE,
                                       model_type=model_type,
                                       percentage=percentage)
        setattr(C, attribute, load_model(join(C.MODELS_DIR, file_name),
                                         custom_objects=C.CUSTOM_OBJECTS))

In [None]:
def run_experiment(obs_file: str, 
            goals_dict_file: Union[str, None],
            actions_dict_file: Union[str, None],
            possible_goals_file: str, 
            correct_goal_file: str, 
            domain: Union[str, int], 
            verbose: int = 0) -> list:
    '''
    Run the goal recognition experiment

    Args:
        obs_file:
            Path of the file that contains the
            observations (plan)

        goals_dict_file:
            Path of the file that contains the
            goals dictionaries. If None it is
            retrieved from its default location.

        actions_dict_file:
            Path of the file that contains the
            actions dictionaries. If None it is
            retrieved from its default location.

        possible_goals_file:
            Path of the file that contains the
            possible goals.

        correct_goal_file:
            Path of the file that contains the
            correct goal.

        domain:
            String that contains the name of the
            domain or integer that corresponds to
            a domain.

        verbose:
            Integer that corresponds to how much
            information is printed. 0 = no info,
            2 = max info

    Returns:
         A list [result, correct goal index, predicted goal index].
         The result is True exactly when the predicted index equals the
         correct one. The predicted index is the goal actually chosen,
         including the random tie-break.
    '''

    domain = parse_domain(domain)
    if goals_dict_file is None:
        goals_dict_file = join(get_domain_related(domain, C.DICTIONARIES_DICT), 'dizionario_goal')
    goals_dict = load_file(goals_dict_file, binary=True, use_pickle=True)
    if actions_dict_file is None:
        actions_dict_file = join(get_domain_related(domain, C.DICTIONARIES_DICT), 'dizionario')
    actions_dict = load_file(actions_dict_file, binary=True, use_pickle=True)
    observations = parse_file(obs_file, C.OBSERVATIONS, actions_dict)

    if verbose > 1:
        print('Observed actions:\n')
        for o in observations:
            print(o)
    possible_goals = parse_file(possible_goals_file, C.POSSIBLE_GOALS, goals_dict)

    max_plan_length = get_domain_related(domain, C.MAX_PLAN_LENGTH)
    predictions = get_predictions(observations, max_plan_length, domain)
    scores = get_scores(predictions, possible_goals)
    if verbose > 0:
        for index, goal in enumerate(possible_goals):
            print(f'{index} - {goal} : {scores[index]}')

    correct_goal = parse_file(correct_goal_file, C.CORRECT_GOAL, goals_dict) 
    correct_goal_idx = get_correct_goal_idx(correct_goal, possible_goals)

    # Choose the predicted goal once and derive `result` from that same
    # choice. Previously `result` used a random tie-break while the
    # returned index was always get_max(scores)[0], so they could
    # disagree on ties. The tie-break (message and single RNG draw) is
    # the same as in get_result.
    candidates = get_max(scores)
    if len(candidates) == 1:
        predicted_idx = candidates[0]
    else:
        print(f'Algorithm chose randomly one of {len(candidates)} equals candidates.')
        predicted_idx = candidates[np.random.randint(0, len(candidates))]
    result = bool(predicted_idx == correct_goal_idx)

    if verbose > 0:
        print(f'Predicted goal is {predicted_idx}')
        print(f'Correct goal is {correct_goal_idx} - {correct_goal}')
    return [result, correct_goal_idx, predicted_idx]



## GRNet execution

Do not change these values: they select which variant of the pretrained models `init_models` loads.

In [None]:
# Model variant loaded by init_models: C.SMALL selects the '<domain>_small.h5' files.
model_type=C.SMALL 
# Only used when model_type is C.PERCENTAGE (see get_domain_related).
percentage=0

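Change these values to fit your execution: the domain to evaluate, the dataset location, the temporary directory where each archive is extracted, and the verbosity level.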

In [None]:
# Domain to evaluate: use the dataset folder name of the domain (any
# name accepted by parse_domain). The dataset folder is derived from it,
# so the observations always match the model and dictionaries in use.
# Previously domain was BLOCKSWORLD while domain_dir pointed at satellite.
domain_name = 'blocksworld'
domain = parse_domain(domain_name)
dataset_root = '../goal_recognition/goal-plan-recognition-dataset/'
domain_dir = join(dataset_root, domain_name)
# Scratch directory: emptied and refilled for every archive.
source_dir = './files_temp_dir'
# Passed to run_experiment (0 = silent, 2 = most output).
verbose = 1

In [None]:
init_models(model_type=model_type, percentage=percentage)

# Observation percentages; each has its own sub-folder ('10', '30', ...).
perc_list = [0.1, 0.3, 0.5, 0.7, 1]
results = list()
times = list()
for perc in perc_list:
    # round() avoids float truncation when turning perc into a folder name.
    plans_dir = join(domain_dir, str(round(perc * 100)))
    # Sorted for a reproducible processing order.
    files = sorted(os.listdir(plans_dir))
    total = 0
    correct = 0
    # [correct goal indexes, predicted goal indexes], one entry per instance.
    results_file = [list(), list()]
    for f in files:
        print(f)
        if f.endswith('.zip'):
            unzip_file(join(plans_dir, f), source_dir)
        elif f.endswith('.bz2'):
            unpack_bz2(join(plans_dir, f), source_dir)
        else:
            # Not an archive: skip it rather than re-running the experiment
            # on whatever the previous archive left in source_dir.
            print(f'Skipping {f}: not a .zip or .bz2 archive.')
            continue
        start_time = time.perf_counter()
        result = run_experiment(obs_file=join(source_dir, 'obs.dat'),
                                goals_dict_file=None,
                                actions_dict_file=None,
                                possible_goals_file=join(source_dir, 'hyps.dat'),
                                correct_goal_file=join(source_dir, 'real_hyp.dat'),
                                domain=domain,
                                verbose=verbose)
        exec_time = time.perf_counter() - start_time
        if result[0]:
            correct += 1
        total += 1
        times.append(exec_time)
        print(exec_time)
        results_file[0].append(result[1])
        results_file[1].append(result[2])
    results.append(results_file)
    if total > 0:
        print(f'{round(perc * 100)}%: {correct}/{total} correct ({correct / total:.2%})')
