In [None]:
import collections
import os

import numpy as np


# Precision constant (not referenced by the functions in this cell).
PRECISION = 0.001

# System names; each one is used as a directory name by read_data.
SYSTEMS = [
    'popper_acc',
    'popper_accminsize',
    'popper_lexfn',
    'popper_lexfnsize',
    'popper_lexfp',
    'popper_lexfpsize',
    'popper_mdl',
]

# Trial indices; each one is used as a directory name by read_data.
TRIALS = list(range(2))

# Program-count sampling grid: every STEP_NPROGS programs, strictly below MAX_NPROGS.
STEP_NPROGS = 1000
MAX_NPROGS = 80_000


def std_err(lst):
    """Return the standard error of the mean of *lst*.

    The sample standard deviation (``ddof=1``) is divided by ``sqrt(n)``.
    With fewer than two values the sample standard deviation is undefined,
    so ``nan`` is returned directly instead of letting NumPy emit a
    "Degrees of freedom <= 0" RuntimeWarning.

    Args:
        lst: sequence of numbers.

    Returns:
        The standard error as a float (``nan`` when ``len(lst) < 2``).
    """
    data = np.asarray(lst, dtype=float)
    count = data.size
    if count < 2:
        return float('nan')
    return np.std(data, ddof=1) / np.sqrt(count)

def read_results(result_file):
    """Read one result snapshot file and return ``(accuracy, num_literals)``.

    Every line starting with ``accuracy:`` sets the accuracy (the last such
    line wins). Every other line that is not a ``%`` comment and ends with
    ``.`` is treated as a clause. It adds the number of ``(`` characters it
    contains to the literal count. This is an approximation: arity-0 atoms
    are not counted and nested terms are over-counted.

    Args:
        result_file: path to a ``<n>.pl`` snapshot file.

    Returns:
        ``(accuracy, num_literals)``. A file named ``0.pl`` always reports
        0 literals, and its accuracy is ``None`` if it has no accuracy line.

    Raises:
        ValueError: if a file other than ``0.pl`` has no accuracy line.
    """
    is_initial_snapshot = os.path.basename(result_file) == '0.pl'
    accuracy = None
    num_literals = 0

    with open(result_file, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if line.startswith('accuracy:'):
                # Checked before the clause test so that a value such as
                # "1." is not mistaken for a clause terminator.
                accuracy = float(line.partition(':')[2])
            elif line.startswith('%'):
                # Prolog line comment, not a clause.
                continue
            elif line.endswith('.'):
                # Counting over the whole line equals counting head and body
                # separately, and cannot fail when ':-' occurs more than once.
                num_literals += line.count('(')

    if is_initial_snapshot:
        # '0.pl' is assumed to contain no rules, hence no literals.
        return accuracy, 0
    if accuracy is None:
        raise ValueError(f"No accuracy found in file: {result_file}")
    return accuracy, num_literals


def _snapshot_numbers(output_path):
    """Return the sorted integers ``n`` of the ``<n>.pl`` files in *output_path*.

    Files that do not match ``<digits>.pl`` (hidden files, logs, backups)
    are ignored. Otherwise they would crash the ``int`` conversion or point
    the reader at a non-existent ``.pl`` file.

    Raises:
        ValueError: if *output_path* contains no ``<n>.pl`` file.
    """
    suffix = '.pl'
    numbers = sorted(
        int(name[:-len(suffix)])
        for name in os.listdir(output_path)
        if name.endswith(suffix) and name[:-len(suffix)].isdecimal()
    )
    if not numbers:
        raise ValueError(f"No '<n>.pl' snapshot files found in {output_path}")
    return numbers


def read_data(DOMAIN, TASKS, base_dir='/content/drive/MyDrive/popper/plot_timeout/experimental_data'):
    """Aggregate accuracy statistics per system and task over a program-count grid.

    For each system in ``SYSTEMS``, task in ``TASKS`` and trial in
    ``TRIALS``, the snapshots in
    ``<base_dir>/<DOMAIN>/<task>/<system>/<trial>/nprogs`` are sampled at
    ``k = 0, STEP_NPROGS, 2 * STEP_NPROGS, ...``. Sampling continues while
    ``k < MAX_NPROGS`` and ``k`` does not exceed the largest snapshot number.
    At each ``k`` the first snapshot whose number is ``>= k`` is read with
    ``read_results``.
    NOTE(review): if ``<n>.pl`` is the best program found after ``n``
    candidates, the latest snapshot ``<= k`` may be the intended choice --
    confirm.

    Args:
        DOMAIN: domain directory name, e.g. ``'zendo_noise'``.
        TASKS: task directory names.
        base_dir: root directory of the experimental data. Defaults to the
            path this function always used before.

    Returns:
        Nested mapping ``stats[system][task]`` with these keys:

        - ``'timeout'``: sorted sample points ``k``.
        - ``'acc_av'`` / ``'acc_sem'``: mean and standard error of accuracy
          across trials at each ``k``, aligned with ``'timeout'``.
        - ``'avg_literals'``: mean literal count over all sampled points and
          trials.

    Raises:
        ValueError: if a trial directory holds no ``<n>.pl`` snapshot, or
            propagated from ``read_results``.
    """
    stats = collections.defaultdict(lambda: collections.defaultdict(lambda: collections.defaultdict(dict)))

    for system in SYSTEMS:
        for task in TASKS:
            # Sample point k -> one value per trial, for this (system, task).
            acc_by_k = collections.defaultdict(list)
            literals_by_k = collections.defaultdict(list)

            for trial in TRIALS:
                print(system, task, trial)
                output_path = os.path.join(base_dir, DOMAIN, f'{task}', system, str(trial), "nprogs")
                n_progs = _snapshot_numbers(output_path)

                # Sample k below MAX_NPROGS and stop once k passes the last snapshot.
                stop = min(MAX_NPROGS, n_progs[-1] + 1)
                idx = 0
                for k in range(0, stop, STEP_NPROGS):
                    # Skip *every* snapshot below k, not just one per step, so
                    # dense snapshots cannot leave idx lagging behind k.
                    # Terminates because k <= n_progs[-1].
                    while n_progs[idx] < k:
                        idx += 1
                    acc, literals = read_results(os.path.join(output_path, f"{n_progs[idx]}.pl"))
                    acc_by_k[k].append(acc)
                    literals_by_k[k].append(literals)

            timeouts = sorted(acc_by_k)
            task_stats = stats[system][task]
            task_stats['timeout'] = timeouts
            # Derive both lists from the same sorted keys so they stay aligned with 'timeout'.
            task_stats['acc_av'] = [np.mean(acc_by_k[k]) for k in timeouts]
            task_stats['acc_sem'] = [std_err(acc_by_k[k]) for k in timeouts]
            task_stats['avg_literals'] = np.mean([n for k in timeouts for n in literals_by_k[k]])

    return stats


# Experiment to load: the domain directory and its task sub-directories.
DOMAIN = 'zendo_noise'
TASKS = [f"zendo{index}__0.1" for index in range(1, 5)]

data_0_literal = read_data(DOMAIN, TASKS)