# Evaluation Code

Comprises two main functions:
- get_groupings()
- get_results()

The first is used inside the get_results function: it extracts the relational groupings from an LF sequence, which get_results then uses to check that the predicted groupings match the target groupings. The get_results function evaluates the results on both the generalisation set and the test set at the same time. It first extracts the list of elements from each predicted LF and compares it to the target; it then extracts the relational groupings (using get_groupings) and compares them to the target. If both the element list and the groupings match the target, the prediction is counted as correct. All correct predictions are counted, as well as the correct predictions per generalisation case. The percentages of correct cases on the generalisation set, on the test set, and per generalisation case are then calculated and written to a file.



In [4]:

def get_groupings(lf):
    """
    Extract the relational groupings from an LF structure.

    The reasoning behind this is to check correct indexing and relational
    mapping of indices.

    lf: list of LF items (strings), either predicted or target, e.g. the
        result of splitting an LF string on ';' and 'AND'. The list is left
        untouched.

    Items containing a comma are treated as relations, all other items as
    elements. Each relation yields one group made of, in order:
      - the role (the text before the first " ("), split into its first two
        parts when it contains a '.';
      - for relations with exactly one index: the first alphabetic word found
        after the first comma, or the text before that comma when there is
        no such word;
      - the text before " (" of every element item that contains one of the
        relation's first two indices surrounded by spaces.

    Relations that contain " ( ( ", lack " (", or contain no digits produce
    no group.

    returns: list of groups (each a list of strings), in the order the
             relations appear in lf.
    """

    import re

    # Partition into relations and elements without modifying the caller's list.
    relations = [item for item in lf if "," in item]
    elements = [item for item in lf if "," not in item]

    rel_groups = []

    for relation in relations:
        indxs = re.findall(r'\d+', relation)
        if " ( ( " in relation or " (" not in relation or not indxs:
            continue

        parts = relation.split(" (")
        role = parts[0]
        group = role.split('.')[:2] if '.' in role else [role]

        if len(indxs) == 1:
            # The second argument is a name rather than an index.
            args = parts[1].split(',')
            # A malformed relation may have no comma inside its first
            # parenthesised segment; fall back to the raw first argument
            # instead of raising IndexError.
            words = re.findall("[a-zA-Z]+", args[1]) if len(args) > 1 else []
            group.append(words[0] if words else args[0])
            wanted = (f' {indxs[0]} ',)
        else:
            # Only the first two indices of a relation are considered.
            wanted = (f' {indxs[0]} ', f' {indxs[1]} ')

        # The surrounding spaces keep index 1 from matching inside 11, 21, ...
        group.extend(
            item.split(" (")[0]
            for item in elements
            if any(marker in item for marker in wanted)
        )

        rel_groups.append(group)

    return rel_groups

In [5]:

# Generalisation cases reported individually by get_results.
_GEN_CATEGORIES = (
    "active_to_passive",
    "obj_to_subj_common",
    "only_seen_as_unacc_subj_as_obj_omitted_transitive_subj",
    "pp_recursion",
    "obj_pp_to_subj_pp",
)


def _read_prediction_rows(path):
    """Return the data rows of a tab-separated predictions file as lists of columns."""
    with open(path, 'r') as infile:
        rows = infile.read().split('\n')
    # The first row is skipped as the header. Blank rows (such as the one
    # produced by a trailing newline) are dropped, so the last example is
    # kept even when the file does not end with a newline.
    return [row.split('\t') for row in rows[1:] if row.strip()]


def _is_exact_match(pred, target):
    """Return True when the predicted LF string matches the target LF string."""
    import re

    pred_items = re.split(';|AND', pred)
    target_items = re.split(';|AND', target)

    # Element lists: the text before the first "(" of every item.
    # They are built before get_groupings is called on the item lists.
    pred_list = sorted(item.split("(")[0] for item in pred_items)
    target_list = sorted(item.split("(")[0] for item in target_items)
    if pred_list != target_list:
        return False

    # Compare the full sorted grouping lists. Pairing them with zip would stop
    # at the shorter list and let missing or extra groups go unnoticed.
    return sorted(get_groupings(pred_items)) == sorted(get_groupings(target_items))


def _percentage(correct, total):
    """Return 100 * correct / total, or 0.0 when total is zero."""
    return 100 * correct / total if total else 0.0


def get_results(gen_path, test_path):
    """
    function to get results for each run
    specify the gen predictions file path and test predictions file path

    calculates test accuracy, gen accuracy, and accuracy per category

    gen_path:  path to the generalisation predictions file. Its file name is
               split on '_' and must provide the seed (2nd part), the model
               (4th part) and the dataset (5th part), for example
               PRED_42_ende_transformer_recogs_v1_cogs.tsv. Rows are
               tab-separated with the prediction in the 2nd column, the target
               in the 3rd and the generalisation category in the 4th.
    test_path: path to the test predictions file (prediction in the 2nd
               column, target in the 3rd).

    A prediction counts as correct when both its sorted element list and its
    sorted relational groupings (see get_groupings) equal those of the target.
    Each per-category accuracy is computed over the examples of that category
    found in the gen file; a category with no examples scores 0.0.

    The results are printed, written to RESULTS_<dataset>_<model>_<seed>.txt
    next to the gen file, and appended to ALL_RESULTS.txt two directory levels
    above the folder holding the gen file. Nothing is returned.
    """
    import os
    import pprint
    from collections import Counter

    # File name carries the run details (seed, model, data info), e.g.
    # ['PRED', '42', 'ende', 'transformer', 'recogs', 'v1', 'cogs.tsv']
    run = os.path.basename(gen_path).split('_')
    seed = run[1]
    model = run[3]
    dataset = run[4]

    test_rows = _read_prediction_rows(test_path)
    test_correct = sum(1 for row in test_rows if _is_exact_match(row[1], row[2]))
    test_acc = _percentage(test_correct, len(test_rows))

    gen_rows = _read_prediction_rows(gen_path)
    cat_totals = Counter()
    cat_correct = Counter()
    for row in gen_rows:
        # Surrounding whitespace is stripped so ' active_to_passive' and
        # 'active_to_passive' name the same category.
        cat = row[3].strip()
        cat_totals[cat] += 1
        if _is_exact_match(row[1], row[2]):
            cat_correct[cat] += 1
    gen_correct = sum(cat_correct.values())

    scores = {
        cat: _percentage(cat_correct[cat], cat_totals[cat])
        for cat in _GEN_CATEGORIES
    }
    scores["overall_acc"] = _percentage(gen_correct, len(gen_rows))
    scores["test_acc"] = test_acc
    results = {f"{dataset}_{model}_{seed}": scores}

    pprint.pprint(results)
    print('\n')

    # Per-run results go in the folder that holds the gen predictions file.
    run_dir = os.path.dirname(gen_path)
    results_path = os.path.join(run_dir, f'RESULTS_{dataset}_{model}_{seed}.txt')
    with open(results_path, "w") as outfile:
        pprint.pprint(results, outfile)

    # The shared log sits two levels above the run folder.
    root_dir = os.path.dirname(os.path.dirname(run_dir))
    overall_results = os.path.join(root_dir, 'ALL_RESULTS.txt')
    with open(overall_results, "a") as outfile:
        pprint.pprint(results, outfile)
        outfile.write('\n\n')


# Run Code

The following cells were used to automate running the get_results() function on all the model output files in the results folder.

In [81]:
import os

# Root folder of the results; the loops below expect the layout
# <path>/<data_model>/<seed>/<files>.
path = '/Users/marina/Desktop/RESULTS'

# Entries of the root folder that are not data/model folders. Filtering
# (rather than list.remove) does not fail when one of them is absent.
data_model = [
    entry for entry in os.listdir(path)
    if entry not in ('.DS_Store', '0.initialruns')
]

# One [gen_path, test_path] pair per seed folder.
gen_test_paths = []

for dm in data_model:
    # Skip text files in the root folder and anything else that is not a folder.
    if dm.endswith('.txt') or not os.path.isdir(f'{path}/{dm}'):
        continue
    seeds = [s for s in os.listdir(f'{path}/{dm}') if s != '.DS_Store']
    for s in seeds:
        if not os.path.isdir(f'{path}/{dm}/{s}'):
            continue
        files = os.listdir(f'{path}/{dm}/{s}')
        # Reset for every seed folder so a folder missing one of the files
        # cannot silently reuse the path found in the previous folder.
        gen_path = None
        test_path = None
        for f in files:
            if 'PRED' in f:
                gen_path = f'{path}/{dm}/{s}/{f}'
            elif 'TEST' in f:
                test_path = f'{path}/{dm}/{s}/{f}'
        if gen_path is None or test_path is None:
            print(f'Skipping {path}/{dm}/{s}: PRED or TEST file not found')
            continue
        pair = [gen_path, test_path]
        gen_test_paths.append(pair)


In [82]:
# Evaluate every run. get_results takes exactly (gen_path, test_path);
# passing a further keyword argument raises TypeError.
for pair in gen_test_paths:
    get_results(pair[0], pair[1])

{'recogspos_transformer_77': {'active_to_passive': 99.8,
                              'obj_pp_to_subj_pp': 0.0,
                              'obj_to_subj_common': 99.9,
                              'only_seen_as_unacc_subj_as_obj_omitted_transitive_subj': 100.0,
                              'overall_acc': 60.18,
                              'pp_recursion': 1.2,
                              'test_acc': 99.6}}


{'recogspos_transformer_88': {'active_to_passive': 99.9,
                              'obj_pp_to_subj_pp': 0.0,
                              'obj_to_subj_common': 99.5,
                              'only_seen_as_unacc_subj_as_obj_omitted_transitive_subj': 99.8,
                              'overall_acc': 59.96,
                              'pp_recursion': 0.6,
                              'test_acc': 99.7}}


{'recogspos_transformer_42': {'active_to_passive': 99.9,
                              'obj_pp_to_subj_pp': 0.0,
                              'obj_to_subj_commo