In [4]:
import os
import json
import pickle

os.environ["OPENAI_API_KEY"] = "ADD_API_KEY"

from paper_comparison.types import Table
from paper_comparison.metrics_utils import JaccardAlignmentScorer, ExactMatchScorer, EditDistanceScorer, SentenceTransformerAlignmentScorer
from paper_comparison.metrics_utils import BaseFeaturizer, ValueFeaturizer, DecontextFeaturizer
from paper_comparison.metrics import SchemaRecallMetric

# Evaluation configuration.
# `subset` selects the dataset split that is scored; point it at another
# directory to evaluate a different split.
subset = '../../medium/'
model_names = ['gpt3.5', 'mixtral']
variant_names = ['ours']
# Threshold grid: 0.0 to 1.0 inclusive in steps of 0.05 (21 values).
threshold_values = [step / 20 for step in range(21)]

# Alignment scorers, built in a fixed order. Only the ones placed in `scorers`
# below take part in the sweep; the rest stay available for other runs.
emscorer, edscorer = ExactMatchScorer(), EditDistanceScorer()
jscorer, jscorer_nostop = (
    JaccardAlignmentScorer(remove_stopwords=drop_stopwords)
    for drop_stopwords in (False, True)
)
stscorer = SentenceTransformerAlignmentScorer()

# Featurizers; all three are evaluated, and each also keeps its own name.
featurizers = [
    BaseFeaturizer("name"),
    ValueFeaturizer("values"),
    DecontextFeaturizer("decontext"),
]
name_feat, value_feat, decontext_feat = featurizers

# Active scorer selection. Alternatives used for other runs:
#   scorers = [jscorer_nostop, emscorer, edscorer]
#   scorers = [jscorer]
# (original author note: "stscorer remaining")
scorers = [stscorer]

# Sweep every (featurizer, scorer, threshold, model, variant) combination.
# For each combination a fresh SchemaRecallMetric is fed every predicted table
# paired with its gold table, and the processed scores are stored under one key.
# One pickle is written per (featurizer, scorer) pair.
total_runs = len(featurizers) * len(scorers) * len(threshold_values) * len(model_names) * len(variant_names)
i = 0
for featurizer in featurizers:
    for scorer in scorers:
        metric_store = {}
        for threshold in threshold_values:
            for model in model_names:
                for variant in variant_names:
                    metric = SchemaRecallMetric(featurizer=featurizer, alignment_scorer=scorer, sim_threshold=threshold)
                    for file in os.listdir(subset):
                        # Only per-instance directories are evaluated; the
                        # "errors" entry and plain files are skipped.
                        if file == "errors":
                            continue
                        if not os.path.isdir(os.path.join(subset, file)):
                            continue

                        # Gold table for this instance. The context manager
                        # closes the handle, which the original left open.
                        with open(os.path.join("../../metric_validation_0", f'{file}_gold.json')) as gold_fh:
                            gold_table_input = json.load(gold_fh)
                        gold_table_instance = Table(gold_table_input["tabid"], list(gold_table_input["table"].keys()), gold_table_input["table"])
                        gold_table_instance.decontext_schema = gold_table_input["decontext_schema"]

                        # Instances without a prediction file for this
                        # model/variant are skipped.
                        pred_path = os.path.join(subset, file, model, f'{variant}_outputs', 'try_0_decontext.json')
                        if not os.path.exists(pred_path):
                            continue
                        with open(pred_path) as pred_fh:
                            pred_table_input = json.load(pred_fh)

                        pred_table_instances = []
                        for table in pred_table_input:
                            cur_table = Table(table["tabid"], table["schema"], table["table"])
                            cur_table.decontext_schema = table["decontext_schema"]
                            # Keep the constructed Table, not the raw dict it
                            # was built from (the original appended `table`).
                            pred_table_instances.append(cur_table)
                            metric.add(cur_table, gold_table_instance)
                    # NOTE(review): this key interpolates the featurizer and
                    # scorer objects themselves, while the pickle filename
                    # below uses their `.name`; confirm the objects render to
                    # a stable string, otherwise keys differ between runs.
                    metric_store[f'{featurizer}_{scorer}_{threshold}_{model}_{variant}'] = metric.process_scores()
                    i += 1
                    print(f'Completed {i}/{total_runs} metric computation runs')
        # Write and close the results for this (featurizer, scorer) pair.
        with open(f'medium_{featurizer.name}_{scorer.name}_exhaustive.pkl', 'wb') as out_fh:
            pickle.dump(metric_store, out_fh)


Completed 1/126 metric computation runs
Completed 2/126 metric computation runs
Completed 3/126 metric computation runs
Completed 4/126 metric computation runs
Completed 5/126 metric computation runs
Completed 6/126 metric computation runs
Completed 7/126 metric computation runs
Completed 8/126 metric computation runs
Completed 9/126 metric computation runs
Completed 10/126 metric computation runs
Completed 11/126 metric computation runs
Completed 12/126 metric computation runs
Completed 13/126 metric computation runs
Completed 14/126 metric computation runs
Completed 15/126 metric computation runs
Completed 16/126 metric computation runs
Completed 17/126 metric computation runs
Completed 18/126 metric computation runs
Completed 19/126 metric computation runs
Completed 20/126 metric computation runs
Completed 21/126 metric computation runs
Completed 22/126 metric computation runs
Completed 23/126 metric computation runs
Completed 24/126 metric computation runs
Completed 25/126 metric c

In [15]:
import os
import shutil

# Back-fill missing mixtral outputs: for every instance directory in `folder2`
# that has no mixtral/ours_outputs/try_0.json, copy that file over from the
# same relative location in `folder1` when it exists there.
# NOTE(review): both paths are absolute and machine-specific; adjust before
# running elsewhere.
folder1 = '/Users/aakankshan/Downloads/metric_validation_0/'
folder2 = '/Users/aakankshan/Documents/TableGeneration/metric_validation_0/'

for file in os.listdir(folder2):
    if file.endswith('.json') or file == '.DS_Store' or file.endswith('.txt'):
        continue
    # Any other non-directory entry would make the directory creation below
    # fail, so skip it as well.
    if not os.path.isdir(os.path.join(folder2, file)):
        continue

    dest_dir = os.path.join(folder2, file, 'mixtral', 'ours_outputs')
    dest_file = os.path.join(dest_dir, 'try_0.json')
    if os.path.exists(dest_file):
        continue

    # makedirs also creates the intermediate 'mixtral' directory when it is
    # missing; a plain mkdir raises FileNotFoundError in that case.
    os.makedirs(dest_dir, exist_ok=True)

    src_file = os.path.join(folder1, file, 'mixtral', 'ours_outputs', 'try_0.json')
    if os.path.exists(src_file):
        print(f"Copying {file}")
        shutil.copyfile(src_file, dest_file)

Copying c99cb160-3de0-4bf3-ae95-a068cb816ec4
Copying d9f833ee-070d-4384-868b-dcd94efbcfa6
Copying ea1d67b0-07f4-4ae4-bc4d-39cd1d9000c2
Copying 75c961c9-0c71-4d9a-bfcd-0967578b16ab
Copying 8aacb0ee-f57e-4495-8687-7a0475a6af87
Copying 9ec07665-b4b2-4e9c-a954-cb3dfebce45c
Copying 89cc20fa-5e44-430d-a060-6c265adfd8bd
Copying 6f849f5f-5bc6-4e90-a441-a433a6a57f0d
Copying c6f98048-93c6-448d-ada8-12e06201c06a
Copying 06b7fb23-8434-42ff-9d10-b75761298bdf
Copying a45087b0-4383-4722-aa75-7ec8ff299cbf
Copying 77392b3a-cdb3-426f-bfdc-1ac206935487
Copying 817d7079-9921-4dcb-ac7d-f8d90a83fb2d
Copying 322916bc-27c1-4888-be5e-73a72f95f1e9
Copying 14fa2ac8-fcab-4259-a6c2-b0cd97a05b62
Copying 32e48ffb-7014-4ca7-a7b1-004411de6f1a
Copying 903a7c8a-67b4-4e13-9961-d39a3899adb4
Copying 0c85fa9d-2577-4cb8-91ff-ec533aa4df4f
Copying a69858d8-2fc3-4529-9755-1ca48e6ea9b4
Copying 06d2a888-9aad-45a3-acb0-2deafc16d049
Copying 146202a2-5683-4a8b-931a-b76ee594c21b
Copying fdbd559a-0711-4038-9ecf-a48371d0655f
Copying 5c