In [10]:
import ast
import json
import os
from pathlib import Path
import re
import shutil
from typing import Any, List
import warnings

import numpy as np
import pandas as pd


# All locations are derived from the parent of the current working directory.
# NOTE(review): this presumably expects the kernel to be started from a
# sub-folder of the repository — confirm.
repo_dir = Path.cwd().parent
dss_dir = repo_dir.joinpath("resources")
eval_dir = repo_dir.joinpath("eval")

def process_num(value: Any):
    """Recursively round every float in a (possibly nested) collection.

    Lists, sets and tuples are turned into lists whose items are processed
    the same way. A float is rounded to 8 decimals first, to strip numerical
    noise, and then to 4 decimals.

    Raises:
        RuntimeError: if ``value`` (or any nested item) is neither a float
            nor a list/set/tuple.
    """
    if isinstance(value, (list, set, tuple)):
        return list(map(process_num, value))
    if isinstance(value, float):
        denoised = round(value, 8)
        return round(denoised, 4)
    raise RuntimeError()

def process(value: str, metadata: str, label: bool):
    """Normalise a raw answer string and, for labels, derive its match pattern.

    The string is parsed as a Python literal and passed through
    ``process_num``; non-empty lists are squeezed with NumPy and have negative
    zeros replaced by ``0.0``. If ``metadata`` contains ``'sorted'`` the value
    is sorted. Whenever any of these steps fails, the value reached so far is
    simply converted with ``str()``.

    Args:
        value: Raw text of a prediction or of a reference solution.
        metadata: Free-form metadata for the question; only the substring
            ``'sorted'`` is inspected here.
        label: ``True`` when ``value`` is a reference solution, in which case
            a regular-expression pattern is returned alongside it.

    Returns:
        A ``(value, re_value)`` tuple. ``value`` is the normalised string.
        ``re_value`` is ``None`` when ``label`` is false; otherwise it is the
        text after a leading ``"re"`` marker (used verbatim as a regex), or
        the ``re.escape``-d value when there is no such marker.
    """
    try:
        if value != "":
            value = ast.literal_eval(value)
            value = process_num(value)
            if isinstance(value, list) and len(value) > 0:
                value = np.array(value).squeeze()
                value[value == 0.] = 0.  # turns -0.0 into 0.0
                value = value.tolist()
                # A one-element list collapses to a bare scalar above; len()
                # of it raises and the handler below stringifies the scalar.
                if len(value) == 1:
                    value = value[0]
            if 'sorted' in metadata:
                value = sorted(value)
            value = str(value)
    except Exception:
        # Best effort: anything that is not a clean float structure (free text,
        # ints, regex labels, ...) is kept as its plain string form.
        value = str(value)

    if not label:
        return value, None

    marker = "re"
    if value.startswith(marker):
        # Strip only the leading marker; splitting on every "re" would cut
        # patterns that contain it again, e.g. "re(red|green)".
        re_value = value[len(marker):]
    else:
        re_value = re.escape(value)

    return value, re_value

# Build one record per (dataset, sample, question) by pairing each question
# with its reference solution and, when present, the generated response.
eval_metrics = []
for ds_dir in sorted(dss_dir.iterdir()):
    if not ds_dir.is_dir():
        continue

    abc_out_dir = ds_dir
    abc_eval_dir = eval_dir / ds_dir.name

    for sample in sorted(abc_out_dir.iterdir()):
        # Only directories can hold a sample; skip stray files next to them.
        if not sample.is_dir():
            continue

        with open(sample / "input.txt") as f:
            text_batch = f.read().split("\n")

        solution_file = sample / "solution.txt"
        if not solution_file.is_file():
            warnings.warn("%s - No solution developed yet" % sample.name)
            # Write a placeholder with one empty line per question so the
            # length checks below hold and the sample is still listed.
            with open(solution_file, 'w') as f:
                f.write("\n" * (len(text_batch) - 1))

        metadata_file = sample / "metadata.txt"
        metadata_batch = [""] * len(text_batch)
        if metadata_file.is_file():
            with open(metadata_file) as f:
                metadata_subset = f.read().split("\n")
            # A trailing newline yields extra empty entries; drop them rather
            # than failing with an IndexError while copying.
            while len(metadata_subset) > len(text_batch) and metadata_subset[-1] == "":
                metadata_subset.pop()
            assert len(metadata_subset) <= len(text_batch), sample
            metadata_batch[:len(metadata_subset)] = metadata_subset

        with open(solution_file) as f:
            solution_batch = f.read().split("\n")

        assert len(solution_batch) == len(text_batch), sample
        assert len(solution_batch) == len(metadata_batch), sample

        for i, (text, solution, metadata) in enumerate(zip(text_batch, solution_batch, metadata_batch)):
            response_dir = abc_eval_dir / sample.name / str(i)

            eval_sample = {
                'ds': ds_dir.name,
                'id': sample.name,
                "question-idx": i,
                "question": text,
                'label': solution
            }

            if not response_dir.is_dir():
                warnings.warn("%s: %s - Not generated yet!" % (sample.name, i))

                eval_sample.update({
                    'failed': True,
                    'predicted': '',
                })

            else:
                # An exception.txt in the response directory marks a failed run.
                eval_sample['failed'] = (response_dir / "exception.txt").is_file()

                if not eval_sample['failed']:
                    with open(response_dir / "result.txt") as f:
                        eval_sample['predicted'] = f.read()
                else:
                    eval_sample['predicted'] = ""

                # Tuples are rewritten as lists so they normalise like list labels.
                try:
                    predicted = ast.literal_eval(eval_sample['predicted'])
                except Exception:
                    # Not a Python literal (free text, empty string, ...): keep as is.
                    pass
                else:
                    if isinstance(predicted, tuple):
                        eval_sample['predicted'] = str(list(predicted))

                eval_sample['predicted'], _ = process(eval_sample['predicted'], metadata=metadata, label=False)

            eval_sample['label'], eval_sample['re_label'] = process(eval_sample['label'], metadata=metadata, label=True)
            # if eval_sample['predicted'] != eval_sample['label']:
            #     shutil.rmtree(response_dir)

            eval_metrics.append(eval_sample)

df = pd.DataFrame(eval_metrics)
# The whole prediction must match the label pattern: a prefix match would let
# label "1" accept prediction "12". Trailing whitespace (e.g. a final newline
# read from result.txt) is ignored.
df['correct'] = (
    df.apply(lambda r: re.fullmatch(r['re_label'], r['predicted'].rstrip()) is not None, axis=1)
) & ~df['failed']
del df['re_label']

# Shorten long predictions for display. This runs after scoring so matching
# sees the full text, and on the frame itself so the truncation takes effect.
df['predicted'] = df['predicted'].map(lambda p: p[:100] + "..." if len(p) > 100 else p)

In [11]:
# df = pd.read_csv(resources_dir / "eval/eval.csv")
# Index the results by dataset, sample id and question number.
df_i = df.set_index(['ds', 'id', 'question-idx'])
print('correct:', df_i['correct'].sum())
# Rows whose label is exactly 'X' are left out of the question count.
print('questions:', (df_i['label'] != 'X').sum())
# df_i.to_csv(resources_dir / "eval/eval.csv")
pd.options.display.max_rows = 111
# A Styler only renders when it is the cell's final expression; building it
# mid-cell and discarding it leaves the column width unchanged.
df_i.style.set_properties(subset=['question'], **{'width': '600px'})

correct: 0
questions: 111


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,question,label,failed,predicted,correct
ds,id,question-idx,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
abc,00000002,0,How wide is the object?,0.0812,False,,False
abc,00000002,1,What is the total height of the object?,0.0301,False,,False
abc,00000002,2,How many holes are present around the object e...,re(19|20),False,,False
abc,00000002,3,what is the depth of the thru-hole visible on ...,0.046,False,,False
abc,00000002,4,what is the radius of the through-hole visible...,0.0032,False,,False
abc,00000002,5,What is the minimal distance between the cente...,0.0088,False,,False
abc,00000013,0,How many holes does the object have?,re(12|13),False,,False
abc,00000013,1,What different diameters do the holes have?,"[0.0043, 0.0063, 0.0127]",False,,False
abc,00000013,2,What is the diameter of the largest hole on th...,0.0063,False,,False
abc,00000013,3,What is the depth of the hole with the smalles...,0.0095,False,,False


In [12]:
# Number of correct answers per sample id. Selecting the column before
# aggregating avoids summing (i.e. concatenating) the text columns as well.
df_i.groupby('id')['correct'].sum().to_frame()

Unnamed: 0_level_0,correct
id,Unnamed: 1_level_1
00000002,0
00000013,0
00000024,0
00000057,0
00000115,0
00000124,0
00000179,0
00000180,0
00000181,0
00000182,0
