## Evaluate NCES

In [1]:
#from sklearn.model_selection import train_test_split
#import os
#def split_train_test(kb):
#    with open(f'Method/Datasets/{kb}/Data/Data.json') as file:
#        data = json.load(file)
#    data = list(data.items())
#    data_train, data_test = train_test_split(data, test_size=0.01, random_state=42)
#    os.makedirs(f'Method/Datasets/{kb}/Test_data/', exist_ok=True)
#    with open(f'Method/Datasets/{kb}/Test_data/Data.json', 'w') as file_test:
#        json.dump(dict(data_test), file_test, indent=3, ensure_ascii=False)
#        
#    with open(f'Method/Datasets/{kb}/Train_data/Data.json', 'w') as file_train:
#        json.dump(dict(data_train), file_train, indent=3, ensure_ascii=False)
#        
#    print('Done !')

In [53]:
#import os
#import shutil
#for kb in ['family-benchmark', 'semantic_bible']:
#    for rep in os.listdir(f"Method/Datasets/{kb}"):
#        if not rep in ["Train_data", "Test_data", f"{kb}.owl"]:
#            if rep == ".gitattributes":
#                os.remove(f"Method/Datasets/{kb}/{rep}")
#            else:
#                shutil.rmtree(f"Method/Datasets/{kb}/{rep}")
#    #!rm -r f"Method/Datasets/{kb}/Embeddings/" 

In [2]:
import sys
sys.path.append('./Method/')
from collections import defaultdict
import os, json
from tqdm import tqdm
from datasets import Dataset

def get_data(data_path):
    """Load learning problems from JSON and wrap them as a translation-style Dataset.

    Args:
        data_path: Path to a JSON file, or to a directory that contains
            ``Data.json``. The JSON maps each class-expression string to a
            dict with ``'positive examples'`` and ``'negative examples'``
            lists of strings.

    Returns:
        The result of ``Dataset.from_dict`` with two columns:
        ``id`` (running index of the entry) and ``translation``
        (a dict whose ``lang1`` is the flattened example string
        ``"StartPositive ... EndPositive StartNegative ... EndNegative"``
        and whose ``lang2`` is the class expression itself).
    """
    json_path = data_path if os.path.isfile(data_path) else f"{data_path}/Data.json"
    # Read explicitly as UTF-8: the class expressions used as keys contain
    # non-ASCII description-logic symbols (e.g. ⊔, ∃), so relying on the
    # platform default encoding can fail or corrupt them.
    with open(json_path, encoding="utf-8") as file:
        data = json.load(file)

    new_data = defaultdict(list)
    # Passing total lets tqdm show a real progress bar instead of a bare counter.
    for i, concept in tqdm(enumerate(data), total=len(data)):
        positives = data[concept]['positive examples']
        negatives = data[concept]['negative examples']
        new_data["id"].append(i)
        new_data["translation"].append({
            "lang1": "StartPositive " + " ".join(positives) + " EndPositive "
                     + "StartNegative " + " ".join(negatives) + " EndNegative",
            "lang2": concept})
    return Dataset.from_dict(new_data)

In [3]:
def removeInvalidParentheses(expression):
    """Return ``expression`` with the minimum number of parentheses removed so it is balanced.

    The number of unmatched '(' and ')' is counted first; a backtracking
    search then decides, for each parenthesis, whether to drop or keep it.
    Dropping is tried before keeping, and the first balanced string found in
    that order is returned. All non-parenthesis characters are kept as-is.

    Args:
        expression: The string to repair (may be empty).

    Returns:
        A string in which every ')' closes an earlier '(' and no '(' is left open.
    """
    left = 0   # '(' that never get closed
    right = 0  # ')' that have no '(' to close
    for char in expression:
        if char == '(':
            left += 1
        elif char == ')':
            if left == 0:
                # No open '(' is available, so this ')' is misplaced.
                right += 1
            else:
                # This ')' can match a pending '('.
                left -= 1

    expr = []  # characters kept on the current search path

    def recurse(index, left_count, right_count, left_rem, right_rem):
        """Extend ``expr`` from ``index``; return True once a full solution is built."""
        if index == len(expression):
            # A solution must have removed exactly the planned number of parentheses.
            return left_rem == 0 and right_rem == 0

        char = expression[index]

        # Discard case, pruned when nothing of that kind is left to remove.
        if char == '(' and left_rem > 0:
            if recurse(index + 1, left_count, right_count, left_rem - 1, right_rem):
                return True
        elif char == ')' and right_rem > 0:
            if recurse(index + 1, left_count, right_count, left_rem, right_rem - 1):
                return True

        # Keep case.
        expr.append(char)
        if char == '(':
            found = recurse(index + 1, left_count + 1, right_count, left_rem, right_rem)
        elif char == ')':
            # A ')' may only be kept if it closes an already kept '('.
            found = left_count > right_count and recurse(
                index + 1, left_count, right_count + 1, left_rem, right_rem)
        else:
            found = recurse(index + 1, left_count, right_count, left_rem, right_rem)
        if found:
            # Leave expr untouched: it now holds the answer.
            return True

        # Pop for backtracking.
        expr.pop()
        return False

    # Only the first solution is ever returned, so the search stops as soon as
    # one is found instead of enumerating every alternative. Removing exactly
    # the unmatched parentheses counted above is always a valid solution, so
    # the search cannot come back empty-handed.
    recurse(0, 0, 0, left, right)
    return "".join(expr)

def get_predictions(kb_name='carcinogenesis', model_name='t5_small', batch_size=8):
    """Generate class expressions for the test learning problems of a knowledge base.

    Loads ``Method/Datasets/{kb_name}/Test_data/`` through ``get_data``, loads
    the fine-tuned seq2seq model and tokenizer from
    ``Method/transformers/results_{kb_name}_{model}/model/``, and decodes one
    prediction per test entry. Every vocabulary entry that is not a
    description-logic symbol, a rendered atomic concept name or an object
    property name of the ontology is passed to ``generate`` as a bad word.

    Args:
        kb_name: Name of the dataset folder and of its ``.owl`` file.
        model_name: Model identifier; '-' is replaced by '_' and only the part
            after the last '/' is used to build the results folder name.
        batch_size: Number of test entries passed to ``generate`` at once.

    Returns:
        Tuple ``(tokenized_data, tokenizer, kb, class_expressions)`` where
        ``class_expressions`` holds one decoded string per test entry, each
        already passed through ``removeInvalidParentheses``.
    """
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
    from tqdm import tqdm
    #sys.path.append('./Method/')
    from ontolearn import KnowledgeBase
    from owlapy.render import DLSyntaxObjectRenderer

    data = get_data(f"Method/Datasets/{kb_name}/Test_data/")
    model_path = "Method/transformers/results_"+kb_name + "_" + model_name.replace("-", "_").split("/")[-1]
    tokenizer = AutoTokenizer.from_pretrained(f"{model_path}/model/")
    model = AutoModelForSeq2SeqLM.from_pretrained(f"{model_path}/model/")

    kb = KnowledgeBase(path=f"Method/Datasets/{kb_name}/{kb_name}.owl")
    dl_syntax_renderer = DLSyntaxObjectRenderer()
    atomic_concepts = list(kb.ontology().classes_in_signature())
    atomic_concepts = [dl_syntax_renderer.render(a) for a in atomic_concepts]
    properties = [rel.get_iri().get_remainder() for rel in kb.ontology().object_properties_in_signature()]
    # Everything in the vocabulary that is not part of the target syntax is banned at generation time.
    bad_tokens = list(set(tokenizer.vocab.keys())-set(['⊔', '⊓', '∃', '∀', '¬', '⊤', '⊥', ')', '(', '.'] + atomic_concepts + properties))
    bad_tokens_ids = tokenizer(bad_tokens, add_special_tokens=False).input_ids

    source_lang = "lang1"
    target_lang = "lang2"
    prefix = ""#"translate instance to class expression: "
    def preprocess_function(examples):
        """Tokenize the example strings as inputs and the target expressions as labels."""
        inputs = [prefix + example[source_lang] for example in examples["translation"]]
        targets = [example[target_lang] for example in examples["translation"]]
        model_inputs = tokenizer(inputs, max_length=1024, truncation=True)

        with tokenizer.as_target_tokenizer():
            labels = tokenizer(targets, max_length=64, truncation=True)

        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    tokenized_data = data.map(preprocess_function, batched=True)
    print(f'\nTest size: {len(tokenized_data)}\n')
    class_expressions = []
    for i in tqdm(range(0, len(tokenized_data), batch_size)):
        data_batch = tokenized_data[i:i+batch_size]
        # The inputs were tokenized without padding, so sequences in a batch can
        # differ in length; building a tensor directly from the raw id lists only
        # works when they all happen to be equally long (e.g. all truncated to
        # 1024). Pad each batch to its longest member before generating.
        padded_batch = tokenizer.pad(
            {'input_ids': data_batch['input_ids'],
             'attention_mask': data_batch['attention_mask']},
            padding=True,
            return_tensors='pt')
        output_sequences = model.generate(
            input_ids=padded_batch['input_ids'],
            attention_mask=padded_batch['attention_mask'],
            #max_length = 12,
            no_repeat_ngram_size = 2,
            bad_words_ids=bad_tokens_ids,
            do_sample=False)
        predictions = list(map(removeInvalidParentheses, tokenizer.batch_decode(output_sequences, skip_special_tokens=True)))
        class_expressions.extend(predictions)

    return tokenized_data, tokenizer, kb, class_expressions

In [4]:
# Run inference with the defaults (kb_name='carcinogenesis', model_name='t5_small',
# batch_size=8) and keep the four returned objects for the cells below.
tokenized_data, tokenizer, kb, class_expressions = get_predictions()

98it [00:00, 23968.85it/s]


  0%|          | 0/1 [00:00<?, ?ba/s]


Test size: 98



100%|██████████| 13/13 [00:42<00:00,  3.25s/it]


In [5]:
# Display prediction 14. Entries of class_expressions were already passed through
# removeInvalidParentheses inside get_predictions; this applies it once more.
removeInvalidParentheses(class_expressions[14])

'Methyl ⊔ Krypton-83 ⊔ Six_ring ⊔ Sulfo   ⊔ Oxygen-41  ( ∃ hasAtom. Copper ).'

In [6]:
from helper_classes.syntax_checker import SyntaxChecker, Evaluator
#from Method.ontolearn import KnowledgeBase

In [7]:
# Ontology file of the carcinogenesis knowledge base (the same path get_predictions
# builds for its default kb_name).
kb_path = 'Method/Datasets/carcinogenesis/carcinogenesis.owl'

In [9]:
# Build the project helpers from the knowledge base and tokenizer returned by get_predictions.
syntax_checker = SyntaxChecker(kb, tokenizer)
evaluator = Evaluator(kb)

In [11]:
# Hand-written cases for syntax_checker.correct: each key is a raw (possibly malformed)
# expression string, each value is the token list correct() is expected to return for it.
test_cases = {'Alcohol ⊔ (Krypton-83 ⊓ Sulfide) ⊔ Sulfo ⊔ ( ∃ hasAtom. Copper . ) ': 
              ['Alcohol','⊔','(','Krypton-83','⊓','Sulfide',')','⊔','Sulfo','⊔','∃','hasAtom','.','Copper'], 
             'Alcohol ⊓ ∀': ['Alcohol'], 
             'Iodine ⊔ ∃inBond.': ['Iodine', '⊔', '∃', 'inBond', '.', '⊤'],
              '⊔': ['⊤'],
              'Alcohol ⊓ ∀∃inBond.': ['Alcohol', '⊓', '∀', 'inBond', '.', '⊤'],
              'Ar_halide ⊔ Oxygen-50 ⊔ Krypton-83 ⊔ Sulfo   ⊔ Non_ar_5c_ring .   ∃ inBond. Copper': 
              ['Ar_halide','⊔','Oxygen-50','⊔','Krypton-83','⊔','Sulfo','⊔','Non_ar_5c_ring','⊔','∃','inBond','.','Copper']
             }

In [12]:
def pass_test(func, test_cases):
    """Check ``func`` against a mapping of input -> expected output.

    Raises AssertionError on the first input whose result differs from the
    expected value; prints a confirmation line when every case matches.
    """
    for raw_input, expected in test_cases.items():
        produced = func(raw_input)
        assert produced == expected, f'Test failed for {raw_input}:{expected}'
    print('All test cases passed!!!')

In [13]:
# Verify syntax_checker.correct against every hand-written case.
pass_test(syntax_checker.correct, test_cases)

All test cases passed!!!


In [14]:
# Correct a single prediction (index 10).
# NOTE: the name `id` shadows the built-in id() from here on.
id = 10
ce = syntax_checker.correct(class_expressions[id])

In [43]:
# Score every prediction: correct it with syntax_checker.correct, print it next to the
# target expression, and evaluate it on the examples of the same learning problem.
# NOTE: the loop variable `id` shadows the built-in id().
for id in range(len(class_expressions)):
    pred = syntax_checker.correct(class_expressions[id])
    # lang1 has the form "StartPositive <pos ...> EndPositive StartNegative <neg ...> EndNegative"
    # (built in get_data), so a whitespace split yields the markers and the individuals.
    examples = tokenized_data['translation'][id]['lang1'].split()
    pos_end = examples.index('EndPositive')
    exact_solution = tokenized_data['translation'][id]['lang2']
    # Join the corrected tokens and re-insert spaces around the operators for display only.
    ce = ''.join(pred).replace('⊔', ' ⊔ ').replace('⊓', ' ⊓ ').replace('∀', '∀ ').replace('∃', '∃ ')
    print(f'id: {id}, prediction: {ce}, exact solution: {exact_solution}')
    # Skip the leading 'StartPositive' marker.
    positives = examples[1:pos_end]
    # Skip 'EndPositive' and 'StartNegative', and drop the trailing 'EndNegative' marker.
    negatives = examples[pos_end+2:-1]
    # presumably prints the "Accuracy" / "F1 score" lines seen in the recorded output —
    # confirm in helper_classes.syntax_checker.Evaluator
    evaluator.evaluate(pred, positives, negatives)

id: 0, prediction: Ar_halide ⊔ Oxygen-50 ⊔ Krypton-83 ⊔ Sulfo ⊔ Non_ar_5c_ring ⊔ ∃ inBond.Copper, exact solution: Bond-2 ⊔ Di8
Accuracy: 0.0%
F1 score: 0.0%
id: 1, prediction: Iodine ⊔ ∃ inBond.(Carbon-15 ⊔ Sulfur-74 ⊔ Sulfur-72), exact solution: Iodine ⊔ (∃ inBond.(Carbon-17 ⊔ Fluorine))
Accuracy: 5.556%
F1 score: 10.526%
id: 2, prediction: Carbon-19 ⊔ Oxygen-50 ⊔ Krypton-83 ⊔ Sulfo ⊔ Non_ar_hetero_6_ring, exact solution: Carbon-15 ⊔ Carbon-26
Accuracy: 0.0%
F1 score: 0.0%
id: 3, prediction: Carbon-19 ⊔ Krypton-83 ⊔ Sulfo ⊔ Hydrogen-2 ⊔ Oxygen-41, exact solution: Carbon-193
Accuracy: 0.0%
F1 score: 0.0%
id: 4, prediction: Carbon-19 ⊔ Oxygen-50 ⊔ Oxygen-41 ⊔ Krypton-83 ⊔ Sulfo ⊔ ∃ inBond.Copper, exact solution: Di281 ⊔ Hydrogen-1
Accuracy: 0.0%
F1 score: 0.0%
id: 5, prediction: Iodine ⊔ ∃ inBond.(Nitrogen-36 ⊔ Sulfur-74 ⊔ Krypton-83), exact solution: Iodine ⊔ (∃ inBond.(Hydrogen-8 ⊔ Oxygen-41))
Accuracy: 1.2269999999999999%
F1 score: 2.424%
id: 6, prediction: Carbon-19 ⊔ Krypton-83 ⊔ S