In [1]:
import json
import chemparse

In [2]:
def format_material_string(string, splitchar):
    """Split a raw material answer into a list of material names.

    Every occurrence of the ``'MATERIAL:'`` label is removed, all whitespace
    is stripped out (including whitespace inside a name), and the remaining
    text is split on ``splitchar``.

    Args:
        string: Raw answer text, e.g. ``'MATERIAL: Ga1 & Pd1S2'``.
        splitchar: Separator placed between materials (``'&'``, ``','`` ...).

    Returns:
        list[str]: One entry per separated field. Empty fields are kept, so
        an empty input yields ``['']``.
    """
    without_label = string.replace('MATERIAL:', '')
    compact = "".join(without_label.split())
    return compact.split(splitchar)
    
def format_temperature_string(temp_list):
    """Parse a comma-separated temperature answer into a list of floats.

    Every ``'K'`` character and all whitespace are removed before the text is
    split on commas. Empty fields are skipped, and a single trailing ``'.'``
    is dropped from each field before conversion.

    Args:
        temp_list: Raw answer text, e.g. ``'6.07 K, 27 K.'``.

    Returns:
        list[float]: One value per non-empty field. A field that ``float()``
        cannot parse (e.g. ``'29(2)'``) is recorded as ``-100.0``.
    """
    compact = "".join(temp_list.replace('K', '').split())
    parsed = []
    for token in compact.split(','):
        if not token:
            continue
        # Only one trailing period is removed, e.g. the end of a sentence.
        if token.endswith('.'):
            token = token[:-1]
        try:
            value = float(token)
        except ValueError:
            # Placeholder for values that are not plain numbers.
            value = -100.0
        parsed.append(value)
    return parsed

def standard_chem_formula(dictionary):
    """Build a canonical string from the keys of ``dictionary``.

    The keys are sorted, concatenated, and every character that is neither
    alphabetic nor numeric is dropped. The dictionary's values are not used.

    Args:
        dictionary: Mapping whose keys are strings.
            NOTE(review): presumably element symbols from a parsed formula;
            confirm against the caller.

    Returns:
        str: The filtered concatenation of the sorted keys.
    """
    merged = ''.join(sorted(dictionary.keys()))
    return ''.join(filter(lambda ch: ch.isalpha() or ch.isnumeric(), merged))


def process_model_output(model_output):
    """Normalise raw model answers into ``{key: (materials, temperatures)}``.

    For every entry except the ``'questions'`` key, field 0 is parsed as an
    ``'&'``-separated material list and field 1 as a comma-separated
    temperature list. Entries whose two lists differ in length are skipped,
    because materials and temperatures cannot be paired reliably.

    Duplicate materials are collapsed to their first occurrence, and each
    kept material retains the temperature found at that same position.
    Order of first appearance is preserved, so ``materials[i]`` always
    corresponds to ``temperatures[i]``.

    Args:
        model_output: Mapping from a key to a sequence whose first two items
            are the raw material string and the raw temperature string.

    Returns:
        dict: ``{key: (list[str], list[float])}`` for every usable entry.
    """
    processed_model_output = {}
    for key, answer in model_output.items():
        if key == 'questions':
            continue
        materials = format_material_string(answer[0], '&')
        temperatures = format_temperature_string(answer[1])
        if len(materials) != len(temperatures):
            continue
        # Remember where each material first appears. Going through a set
        # would lose the order and break the material/temperature pairing.
        first_index = {}
        for index, material in enumerate(materials):
            first_index.setdefault(material, index)
        unique_materials = list(first_index)
        unique_temperatures = [temperatures[i] for i in first_index.values()]
        processed_model_output[key] = (unique_materials, unique_temperatures)
    return processed_model_output

In [3]:
# Load the reference database and normalise each entry into a
# (materials, temperatures) pair keyed like the model outputs.
# NOTE(review): absolute, machine-specific path; consider a configurable data directory.
with open('/home/louis/research/pdf_processor/extraction/output/database.json') as f:
    database = json.load(f)

processed_database_output = {}
for key in database:
    # Field 0 holds the comma-separated materials; field 1 holds the
    # temperatures, prefixed by a label that is removed before parsing.
    materials = format_material_string(database[key][0], ',')
    raw_temperatures = database[key][1].replace('CRITICAL TEMPERATURE:', '')
    temperatures = format_temperature_string(raw_temperatures)
    # Pad with -100.0 so every material has a temperature entry.
    missing = len(materials) - len(temperatures)
    if missing > 0:
        temperatures = temperatures + [-100.0] * missing
    processed_database_output[key] = (materials, temperatures)

In [4]:
def _load_model_run(path):
    """Load one model run and return ``(raw, processed)``.

    ``raw`` maps each key to at most the first two fields of its entry;
    ``processed`` is ``process_model_output(raw)``.
    """
    with open(path) as run_file:
        raw = json.load(run_file)
    raw = {key: raw[key][:2] for key in raw.keys()}
    return raw, process_model_output(raw)


def _report_prediction(key, material, expected, predicted, raw_answer, tolerance=0.01):
    """Print one comparison row and return whether the prediction matches.

    A prediction matches when it is within ``tolerance`` of ``expected``.
    Rows for mismatches are marked with a trailing ``*``.
    """
    is_correct = abs(expected - predicted) < tolerance
    row = "{: >20} {: >30} {: >10} {: >20}".format(key, material, expected, raw_answer)
    print(row if is_correct else row + " *")
    return is_correct


# NOTE(review): absolute, machine-specific paths; consider a configurable data directory.
model_output_1, pmodel_output_1 = _load_model_run(
    '/home/louis/research/pdf_processor/extraction/output/run_70B_15000chars.json')
model_output_2, pmodel_output_2 = _load_model_run(
    '/home/louis/research/pdf_processor/extraction/output/run15000chars.json')

# Only keys that both runs and the database could process are compared.
shared_keys = sorted(pmodel_output_2.keys() & pmodel_output_1.keys() & processed_database_output.keys())
pmodel_output_1 = {key: pmodel_output_1[key] for key in shared_keys}
pmodel_output_2 = {key: pmodel_output_2[key] for key in shared_keys}
processed_database_output = {key: processed_database_output[key] for key in shared_keys}

correct_num_materials_1 = 0
correct_num_materials_2 = 0

for key in shared_keys:
    # Only the first material/temperature of each entry is scored.
    material = processed_database_output[key][0][0]
    dbase_output = processed_database_output[key][1][0]
    model_prediction_1 = pmodel_output_1[key][1][0]
    model_prediction_2 = pmodel_output_2[key][1][0]

    # The raw (unparsed) model answer is shown so parse failures stay visible.
    if _report_prediction(key, material, dbase_output, model_prediction_1, model_output_1[key][1]):
        correct_num_materials_1 += 1
    if _report_prediction(key, material, dbase_output, model_prediction_2, model_output_2[key][1]):
        correct_num_materials_2 += 1

    print("------------------------------------------")

# With no shared keys there is nothing to score; report nan instead of
# raising ZeroDivisionError.
num_compared = len(shared_keys)
print("model_1 score:", correct_num_materials_1 / num_compared if num_compared else float('nan'))
print("model_2 score:", correct_num_materials_2 / num_compared if num_compared else float('nan'))


    physrevb.10.4572                            Ga1       6.07               6.07 K
    physrevb.10.4572                            Ga1       6.07               6.07 K
------------------------------------------
 physrevb.100.014503                      Eu1Fe2As2       27.0              29(2) K *
 physrevb.100.014503                      Eu1Fe2As2       27.0                 27 K
------------------------------------------
 physrevb.100.014507                          Pd1S2        2.0                8.0 K *
 physrevb.100.014507                          Pd1S2        2.0                8.0 K *
------------------------------------------
 physrevb.100.041109                            Rb1        2.1                  2 K *
 physrevb.100.041109                            Rb1        2.1                  2 K *
------------------------------------------
 physrevb.100.060103                         Au2Pb1       3.61                1.2 K *
 physrevb.100.060103                         Au2Pb1       3.