In [None]:
import os, json
from tqdm import tqdm
from src.intension import Intension

In [None]:
# Models to evaluate. "model_name" is passed to Intension(model_name=...);
# "batch_size" is how many queries are sent per intension.chain.batch() call.
# NOTE(review): claude-3-opus uses batch_size 1, presumably to respect API
# rate limits — confirm.
MODELS = [
    { "model_name": "google/gemma-7b-it", "batch_size": 50 },
    { "model_name": "gpt-3.5-turbo", "batch_size": 50 },
    { "model_name": "gpt-4-0125-preview", "batch_size": 50 },
    { "model_name": "mistralai/Mistral-7B-Instruct-v0.2", "batch_size": 50 },
    { "model_name": "mistralai/Mixtral-8x7B-Instruct-v0.1", "batch_size": 50 },
    { "model_name": "claude-3-opus-20240229", "batch_size": 1 },
]

DATA_PATH = 'experiments/wikidata_statements.json'

# Statements to classify. The evaluation loop reads, per record:
# predicate.label / predicate.definition, arguments[*].label /
# arguments[*].description, and in_extension (the gold label).
# A `with` block closes the handle deterministically; UTF-8 is explicit so
# non-ASCII Wikidata labels load the same on every platform.
with open(DATA_PATH, 'r', encoding='utf-8') as statements_file:
    DATA = json.load(statements_file)

In [None]:
# Run every configured model over the Wikidata statements and cache its
# predictions in experiments/<model-short-name>-wikidata.json. Models whose
# output file already exists are skipped, so re-running this cell only fills
# in the missing ones.

# The prompt inputs depend only on DATA, not on the model, so build them once.
queries = [
    {
        "predicate": datum["predicate"]["label"],
        "arguments": ", ".join([arg["label"] for arg in datum["arguments"]]),
        # Space-join so the predicate definition does not run straight into
        # the first argument description.
        "world": " ".join(
            [datum["predicate"]["definition"]]
            + [arg["description"] for arg in datum["arguments"]]
        ),
        "actual": datum["in_extension"],
    }
    for datum in DATA
]

for model in MODELS:
    filename = f'experiments/{model["model_name"].split("/")[-1]}-wikidata.json'
    if os.path.isfile(filename):
        print(f'Skipping {model["model_name"]}')
        continue

    batch_size = model["batch_size"]
    batches = [queries[i:i + batch_size] for i in range(0, len(queries), batch_size)]
    intension = Intension(model_name=model["model_name"])

    results = []
    for batch in tqdm(batches, desc=f'{model["model_name"]:30}'):
        response = intension.chain.batch(batch)
        for result in response:
            # NOTE(review): assumes each result echoes the query inputs plus a
            # "text" dict with "rationale" and "answer" keys — confirm against
            # src.intension.
            text = result.pop("text")
            result["rationale"] = text["rationale"]
            result["predicted"] = text["answer"]
        results.extend(response)

    # Write to a temp file and atomically swap it in: a crash mid-write must not
    # leave a truncated file that the isfile() check above would then skip.
    tmp_filename = f"{filename}.tmp"
    with open(tmp_filename, "w") as out_file:
        json.dump(results, out_file)
    os.replace(tmp_filename, filename)