## Evaluate NCES

In [1]:
import os, random
from utils.syntax_checker import SyntaxChecker
from utils.evaluator import Evaluator
from ontolearn.knowledge_base import KnowledgeBase
from nces import BaseConceptSynthesis
from nces.synthesizer import ConceptSynthesizer
from utils.data import Data
from owlapy.parser import DLSyntaxParser
from dataloader import CSDataLoader
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm

In [2]:
from argparse import Namespace
import json
import torch, pandas as pd
# Read the run configuration from settings.json and expose its keys as
# attributes of `args` (e.g. args.batch_size).
with open("settings.json") as setting:
    args = Namespace(**json.load(setting))

In [3]:
import numpy as np, time
from collections import defaultdict

In [4]:
def before_pad(arg):
    """Return the items of ``arg`` that precede the first 'PAD' item.

    Iteration stops at the first item equal to 'PAD'; that item and
    everything after it are dropped. The result is always a new list.
    """
    kept = []
    for token in arg:
        if token == 'PAD':
            return kept
        kept.append(token)
    return kept

In [5]:
def compute_accuracy(prediction, target):
    """Compute the soft and hard syntactic accuracy of a batch of predictions.

    ``prediction`` and ``target`` are parallel collections. Each item is either
    a string (split into tokens with ``BaseConceptSynthesis.decompose``) or an
    already tokenised sequence; in both cases everything from the first 'PAD'
    token on is ignored.

    * soft accuracy: size of the intersection of the two token sets divided by
      the size of their union, in percent.
    * hard accuracy: number of positions holding the same token divided by the
      length of the longer sequence, in percent.

    Two sequences that are both empty after PAD stripping are identical and
    score 100.0 instead of raising ZeroDivisionError.

    Returns:
        (soft_acc, hard_acc) averaged over ``len(target)`` pairs, or
        (0.0, 0.0) when ``target`` is empty.
    """
    def tokens_of(expr):
        # Shared normalisation: tokenise strings, then cut at the first PAD.
        if isinstance(expr, str):
            expr = BaseConceptSynthesis.decompose(expr)
        return before_pad(expr)

    def soft(arg1, arg2):
        set1, set2 = set(tokens_of(arg1)), set(tokens_of(arg2))
        union = set1 | set2
        if not union:
            return 100.0
        return 100*float(len(set1 & set2))/len(union)

    def hard(arg1, arg2):
        seq1, seq2 = tokens_of(arg1), tokens_of(arg2)
        longest = max(len(seq1), len(seq2))
        if longest == 0:
            return 100.0
        # zip stops at the shorter sequence; the missing positions count as
        # mismatches because the denominator is the longer length.
        return 100*float(sum(x == y for x, y in zip(seq1, seq2)))/longest

    # len() rather than truthiness: target may be a numpy array.
    if len(target) == 0:
        return 0.0, 0.0
    soft_acc = sum(map(soft, prediction, target))/len(target)
    hard_acc = sum(map(hard, prediction, target))/len(target)
    return soft_acc, hard_acc

In [6]:
def sample_examples(pos, neg, num_examples):
    """Randomly draw about ``num_examples`` examples, split between pos and neg.

    When both pools hold at least ``num_examples//2`` items the split is
    balanced (the larger pool receives the extra item for odd totals).
    Otherwise the smaller pool is used entirely and the larger pool fills the
    remainder. Counts never exceed what a pool actually contains, so fewer
    than ``num_examples`` items are returned when the pools are too small.

    Args:
        pos: collection of positive examples (sequence or set).
        neg: collection of negative examples (sequence or set).
        num_examples: total number of examples requested.

    Returns:
        (positive, negative): two lists sampled without replacement.
    """
    # random.sample needs a sequence; sets are rejected from Python 3.11 on.
    if isinstance(pos, (set, frozenset)):
        pos = list(pos)
    if isinstance(neg, (set, frozenset)):
        neg = list(neg)

    half = num_examples//2
    if min(len(neg), len(pos)) >= half:
        if len(pos) > len(neg):
            num_neg_ex = half
            num_pos_ex = num_examples-num_neg_ex
        else:
            num_pos_ex = half
            num_neg_ex = num_examples-num_pos_ex
    elif len(pos) > len(neg):
        num_neg_ex = len(neg)
        num_pos_ex = num_examples-num_neg_ex
    elif len(pos) < len(neg):
        num_pos_ex = len(pos)
        num_neg_ex = num_examples-num_pos_ex
    else:
        # Equally sized pools that are both too small: previously neither count
        # was assigned (UnboundLocalError). Use everything that is available.
        num_pos_ex = len(pos)
        num_neg_ex = len(neg)

    # Never request more items than a pool holds (random.sample would raise).
    num_pos_ex = min(num_pos_ex, len(pos))
    num_neg_ex = min(num_neg_ex, len(neg))
    positive = random.sample(pos, num_pos_ex)
    negative = random.sample(neg, num_neg_ex)
    return positive, negative

In [7]:
def map_to_token(model, idx_array):
    """Look up ``idx_array`` in the model's inverse vocabulary ``model.inv_vocab``."""
    inverse_vocabulary = model.inv_vocab
    return inverse_vocabulary[idx_array]

In [8]:
def collate_batch(batch):
    """Collate ``(pos_emb, neg_emb, label)`` samples into padded batch tensors.

    Args:
        batch: list of 3-tuples of tensors, one tuple per sample.

    Returns:
        (pos_emb_list, neg_emb_list, target_labels): batch-first tensors.
        The two embedding tensors are padded with 0 and the labels with -100.
    """
    # Transpose the list of samples into three per-field lists.
    pos_emb_list, neg_emb_list, target_labels = (list(field) for field in zip(*batch))
    pos_emb_list = pad_sequence(pos_emb_list, batch_first=True, padding_value=0)
    neg_emb_list = pad_sequence(neg_emb_list, batch_first=True, padding_value=0)
    target_labels = pad_sequence(target_labels, batch_first=True, padding_value=-100)
    return pos_emb_list, neg_emb_list, target_labels

In [9]:
def get_data(kb, embeddings, kwargs):
    """Build the test DataLoader for knowledge base ``kb``.

    Reads ``datasets/{kb}/Test_data/Data.json``, wraps its (key, value) items
    in a ``CSDataLoader`` together with ``embeddings`` and ``kwargs``, prints
    the number of learning problems and returns an unshuffled ``DataLoader``
    that uses ``kwargs.batch_size``, ``kwargs.num_workers`` and
    ``collate_batch``.
    """
    with open(f"datasets/{kb}/Test_data/Data.json", "r") as file:
        learning_problems = list(json.load(file).items())
    test_dataset = CSDataLoader(learning_problems, embeddings, kwargs)
    print("Number of learning problems: ", len(test_dataset))
    return DataLoader(
        test_dataset,
        batch_size=kwargs.batch_size,
        num_workers=kwargs.num_workers,
        collate_fn=collate_batch,
        shuffle=False,
    )

In [10]:
def predict_class_expressions(model_name, kb, args):
    """Run a trained model on the test learning problems of ``kb``.

    Loads ``embeddings/{kb}/ConEx_entity_embeddings.csv`` and the model stored
    at ``datasets/{kb}/Model_weights/{model_name}.pt`` (mapped to CPU), predicts
    every batch of the test dataloader and prints the average soft/hard
    syntactic accuracy.

    Side effect: sets ``args.knowledge_base_path`` to the KB's .owl file.

    Returns:
        (predictions, targets): the per-batch predicted and target token
        sequences, each concatenated along axis 0.
    """
    args.knowledge_base_path = "datasets/"+f"{kb}/{kb}.owl"
    embeddings = pd.read_csv(f"embeddings/{kb}/ConEx_entity_embeddings.csv").set_index('Unnamed: 0')
    dataloader = get_data(kb, embeddings, args)
    # NOTE(review): torch.load unpickles a complete model object; only load
    # model files from a trusted source.
    model = torch.load(f"datasets/{kb}/Model_weights/{model_name}.pt", map_location=torch.device('cpu'))
    model.eval()
    soft_total, hard_total, num_samples = 0.0, 0.0, 0
    preds = []
    targets = []
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for x1, x2, labels in tqdm(dataloader):
            target_sequence = map_to_token(model, labels)
            pred_sequence, _ = model(x1, x2)
            preds.append(pred_sequence)
            targets.append(target_sequence)
            s_acc, h_acc = compute_accuracy(pred_sequence, target_sequence)
            # Weight each batch by its size so that a smaller final batch does
            # not skew the reported per-sample average.
            batch_len = len(target_sequence)
            soft_total += s_acc*batch_len
            hard_total += h_acc*batch_len
            num_samples += batch_len
    denominator = max(num_samples, 1)
    print(f"Average syntactic accuracy, Soft: {soft_total/denominator}%, Hard: {hard_total/denominator}%")
    return np.concatenate(preds, 0), np.concatenate(targets, 0)

In [11]:
#t0 = time.time()
#Preds, Targets = predict_class_expressions("SetTransformer", "carcinogenesis", args)
#t1 = time.time()
#print("Duration: ", t1-t0)

In [32]:
def _parse_prediction(pred, dl_parser, syntax_checker):
    """Turn a predicted token array into a class expression, or None on failure.

    Steps:
      1. If the parentheses are unbalanced, try inserting one missing
         parenthesis at every candidate position until the expression parses
         (a missing ')' is tried from the end backwards, a missing '(' from
         the start forwards).
      2. If parsing still fails, fall back to the syntax checker's correction
         and use its last suggestion.
    """
    tokens = pred.tolist()
    num_open, num_close = tokens.count('('), tokens.count(')')
    if num_open > num_close:
        candidates = (tokens[:k] + [')'] + tokens[k:] for k in range(len(tokens), -1, -1))
    elif num_close > num_open:
        candidates = (tokens[:k] + ['('] + tokens[k:] for k in range(len(tokens)))
    else:
        candidates = (tokens,)
    for candidate in candidates:
        try:
            return dl_parser.parse_expression("".join(candidate))
        except Exception:
            # The parser's exception type is not visible here; any failure
            # simply means "try the next candidate".
            continue
    try:
        corrected = syntax_checker.correct("".join(tokens))
        suggestion = list(syntax_checker.get_suggestions(corrected))[-1]
        return syntax_checker.get_concept(suggestion)
    except Exception:
        print(f"Could not understand expression {pred}")
        return None


def evaluate_nces(kb_name, models, args):
    """Evaluate trained NCES models on the test learning problems of a KB.

    For every model in ``models`` the predicted class expressions are parsed
    (with parenthesis repair and syntax correction as fallbacks), evaluated
    against the instances of the target expression, and accuracy, F1, the
    rendered prediction and the per-problem prediction time are recorded.
    Predictions that cannot be parsed are reported and skipped.

    After each model the metrics gathered so far are written to
    ``datasets/{kb_name}/Results/NCES.json``.

    Args:
        kb_name: name of the dataset folder / ontology file under ``datasets``.
        models: iterable of model names understood by
            ``predict_class_expressions``.
        args: settings namespace forwarded to ``predict_class_expressions``.

    Returns:
        dict mapping model name -> metric name -> {'values', 'mean', 'std'}.
    """
    print('#'*50)
    print('NCES evaluation on {} KB:'.format(kb_name))
    print('#'*50)
    All_metrics = {m: defaultdict(lambda: defaultdict(list)) for m in models}
    print()
    kb = KnowledgeBase(path=f"datasets/{kb_name}/{kb_name}.owl")
    namespace = kb.ontology()._onto.base_iri
    print("KB namespace: ", namespace)
    print()
    syntax_checker = SyntaxChecker(kb)
    evaluator = Evaluator(kb)
    dl_parser = DLSyntaxParser(namespace = namespace)
    All_individuals = set(kb.individuals())
    for model_name in models:
        t0 = time.time()
        predictions, targets = predict_class_expressions(model_name, kb_name, args)
        t1 = time.time()
        duration = (t1-t0)/len(predictions)  # average prediction time per problem
        print()
        print(f"##{model_name}##")
        print()
        metrics = All_metrics[model_name]
        for i, target_tokens in enumerate(targets):
            pb_str = "".join(before_pad(target_tokens))
            try:
                end_idx = np.where(predictions[i] == 'PAD')[0][0] # remove padding token
            except IndexError:
                # NOTE(review): without any PAD token only the first predicted
                # token is kept - confirm this truncation is intended.
                end_idx = 1
            prediction = _parse_prediction(predictions[i][:end_idx], dl_parser, syntax_checker)
            if prediction is None:
                continue
            target_expression = dl_parser.parse_expression(pb_str) # The target class expression
            positive_examples = {ind.get_iri().as_str().split("/")[-1] for ind in kb.individuals(target_expression)}
            # NOTE(review): All_individuals holds the objects returned by
            # kb.individuals() while positive_examples holds strings - verify
            # that this set difference removes what Evaluator expects.
            negative_examples = All_individuals-positive_examples
            acc, f1 = evaluator.evaluate(prediction, positive_examples, negative_examples)
            rendered = syntax_checker.renderer.render(prediction)
            print(f'Problem {i}, Target: {pb_str}, Prediction: {rendered}, Acc: {acc}, F1: {f1}')
            print()
            metrics['acc']['values'].append(acc)
            metrics['prediction']['values'].append(rendered)
            metrics['f1']['values'].append(f1)
            metrics['time']['values'].append(duration)

        for metric in metrics:
            if metric != 'prediction':
                metrics[metric]['mean'] = [np.mean(metrics[metric]['values'])]
                metrics[metric]['std'] = [np.std(metrics[metric]['values'])]

        # Membership test (not indexing) so the defaultdict gains no empty keys.
        if 'acc' in metrics:
            print(model_name+' Speed: {}s +- {} / lp'.format(round(metrics['time']['mean'][0], 2),
                                                              round(metrics['time']['std'][0], 2)))
            print(model_name+' Avg Acc: {}% +- {} / lp'.format(round(metrics['acc']['mean'][0], 2),
                                                                round(metrics['acc']['std'][0], 2)))
            print(model_name+' Avg F1: {}% +- {} / lp'.format(round(metrics['f1']['mean'][0], 2),
                                                               round(metrics['f1']['std'][0], 2)))
        else:
            print(model_name+': no prediction could be evaluated')
        print()

        # Rendered predictions contain non-ASCII symbols and ensure_ascii is
        # False, so the encoding is fixed instead of left to the platform.
        with open("datasets/"+kb_name+"/Results/NCES.json", "w", encoding="utf-8") as file:
            json.dump(All_metrics, file, indent=3, ensure_ascii=False)
    return All_metrics

In [33]:
#args

In [34]:
#F1_semb = evaluate_nces("semantic_bible", ["SetTransformer", "GRU", "LSTM"], args)

In [35]:
#F1_semb['SetTransformer']['f1']['values']

In [36]:
# Evaluate the SetTransformer, GRU and LSTM models on the carcinogenesis KB;
# the returned metrics dictionary is kept in F1_car.
F1_car = evaluate_nces("carcinogenesis", ["SetTransformer", "GRU", "LSTM"], args)

##################################################
NCES evaluation on carcinogenesis KB:
##################################################

KB namespace:  http://dl-learner.org/carcinogenesis#

Number of learning problems:  98


100%|██████████| 1/1 [00:02<00:00,  2.55s/it]


Average syntactic accuracy, Soft: 85.77870542156258%, Hard: 90.06590514993873%

##SetTransformer##

Problem 0, Target: Bond-2 ⊔ Di8, Prediction: Bond-2 ⊔ Di8, Acc: 100.0, F1: 100.0

Problem 1, Target: Iodine ⊔ (∃ inBond.(Carbon-17 ⊔ Fluorine)), Prediction: Iodine ⊔ (∃ inBond.(Carbon-17 ⊔ Fluorine)), Acc: 100.0, F1: 100.0

Problem 2, Target: Carbon-15 ⊔ Carbon-26, Prediction: Carbon-15 ⊔ Carbon-26, Acc: 100.0, F1: 100.0

Problem 3, Target: Carbon-193, Prediction: Carbon-193 ⊔ Nitrogen-33, Acc: 99.991, F1: 85.714

Problem 4, Target: Di281 ⊔ Hydrogen-1, Prediction: Di281 ⊔ Hydrogen-1, Acc: 100.0, F1: 100.0

Problem 5, Target: Iodine ⊔ (∃ inBond.(Hydrogen-8 ⊔ Oxygen-41)), Prediction: Iodine ⊔ (∃ inBond.(Hydrogen-8 ⊔ Hydrogen-8)), Acc: 99.982, F1: 98.30499999999999

Problem 6, Target: Hydrogen-1 ⊔ Phosphorus, Prediction: Hydrogen-1 ⊔ Phosphorus, Acc: 100.0, F1: 100.0

Problem 7, Target: Iodine ⊔ (∃ inBond.(Copper-96 ⊔ Fluorine-92)), Prediction: Iodine ⊔ (∃ inBond.(Copper-96 ⊔ Nitrogen-34)),

100%|██████████| 1/1 [00:07<00:00,  7.15s/it]


Average syntactic accuracy, Soft: 83.45856524427954%, Hard: 88.32843743558026%

##GRU##

Problem 0, Target: Bond-2 ⊔ Di8, Prediction: Bond-2 ⊔ Oxygen-52, Acc: 99.952, F1: 98.82600000000001

Problem 1, Target: Iodine ⊔ (∃ inBond.(Carbon-17 ⊔ Fluorine)), Prediction: Iodine ⊔ (∃ inBond.(Carbon-17 ⊔ Carbon-232)), Acc: 99.98700000000001, F1: 90.90899999999999

Problem 2, Target: Carbon-15 ⊔ Carbon-26, Prediction: Carbon-26 ⊔ Carbon-26, Acc: 99.996, F1: 96.97

Problem 3, Target: Carbon-193, Prediction: Carbon-193, Acc: 100.0, F1: 100.0

Problem 4, Target: Di281 ⊔ Hydrogen-1, Prediction: Di281 ⊔ Hydrogen-1, Acc: 100.0, F1: 100.0

Problem 5, Target: Iodine ⊔ (∃ inBond.(Hydrogen-8 ⊔ Oxygen-41)), Prediction: Iodine ⊔ (∃ inBond.(Hydrogen-8 ⊔ Tin-113)), Acc: 99.956, F1: 95.868

Problem 6, Target: Hydrogen-1 ⊔ Phosphorus, Prediction: Hydrogen-1 ⊔ Phosphorus, Acc: 100.0, F1: 100.0

Problem 7, Target: Iodine ⊔ (∃ inBond.(Copper-96 ⊔ Fluorine-92)), Prediction: Iodine ⊔ (∃ inBond.(Carbon-232 ⊔ Copper-9

100%|██████████| 1/1 [00:06<00:00,  6.84s/it]


Average syntactic accuracy, Soft: 74.10482374768091%, Hard: 81.70983544933121%

##LSTM##

Problem 0, Target: Bond-2 ⊔ Di8, Prediction: Bond-2 ⊔ Tellurium, Acc: 99.97399999999999, F1: 99.356

Problem 1, Target: Iodine ⊔ (∃ inBond.(Carbon-17 ⊔ Fluorine)), Prediction: Iodine ⊔ (∃ inBond.(Fluorine ⊔ Carbon-232)), Acc: 99.969, F1: 75.862

Problem 2, Target: Carbon-15 ⊔ Carbon-26, Prediction: Calcium-84 ⊔ Carbon-26, Acc: 99.991, F1: 94.118

Problem 3, Target: Carbon-193, Prediction: Carbon-193 ⊔ Cyanate, Acc: 99.991, F1: 85.714

Problem 4, Target: Di281 ⊔ Hydrogen-1, Prediction: Hydrogen-1 ⊔ Hydrogen-1, Acc: 99.991, F1: 99.71900000000001

Problem 5, Target: Iodine ⊔ (∃ inBond.(Hydrogen-8 ⊔ Oxygen-41)), Prediction: Iodine ⊔ (∃ inBond.(Hydrogen-8 ⊔ Tin-113)), Acc: 99.956, F1: 95.868

Problem 6, Target: Hydrogen-1 ⊔ Phosphorus, Prediction: Hydrogen-1 ⊔ Phosphorus, Acc: 100.0, F1: 100.0

Problem 7, Target: Iodine ⊔ (∃ inBond.(Copper-96 ⊔ Fluorine-92)), Prediction: Iodine ⊔ (∃ inBond.(Carbon-232 

## Shuffle positive examples

## Statistical tests

In [20]:
from Method.helper_functions import wilcoxon_statistical_test
import warnings
warnings.filterwarnings('ignore')
import json

def celoe_vs_nces_stat_tests():
    """Compare the NCES LSTM results with CELOE on each benchmark KB.

    For every knowledge base the stored NCES results
    (``datasets/<kb>/Results/NCES.json``, model 'LSTM') and CELOE results
    (``datasets/<kb>/Results/concept_learning_results_celoe.json``) are loaded
    and ``wilcoxon_statistical_test`` is run on accuracy, F1 and runtime.
    The p-value of each test is printed with a verdict at the 0.05 level.
    """
    knowledge_bases = [
        ('semantic_bible', 'Semantic Bible'),
        ('family-benchmark', 'Family Benchmark'),
        ('mutagenesis', 'Mutagenesis'),
        ('carcinogenesis', 'Carcinogenesis'),
        ('vicodi', 'Vicodi'),
    ]
    for kb_dir, kb_label in knowledge_bases:
        with open(f'datasets/{kb_dir}/Results/NCES.json') as file:
            nces = json.load(file)
        f1_nces = nces['LSTM']['f1']['values']
        acc_nces = nces['LSTM']['acc']['values']
        time_nces = nces['LSTM']['time']['values']
        with open(f'datasets/{kb_dir}/Results/concept_learning_results_celoe.json') as celoe_file:
            celoe = json.load(celoe_file)
        f1_celoe = celoe['F-measure']
        acc_celoe = celoe['Accuracy']
        time_celoe = celoe['Runtime']
        # Run all three tests before printing, as a failure should not leave a
        # partially printed report for this KB.
        _, p1 = wilcoxon_statistical_test(acc_nces, acc_celoe)
        _, p2 = wilcoxon_statistical_test(f1_nces, f1_celoe)
        _, p3 = wilcoxon_statistical_test(time_nces, time_celoe)
        for metric, p in zip(['Accuracy', 'F1', 'RunTime'], [p1, p2, p3]):
            print(f'On {metric} of {kb_label} KB, p_value = ', p)
            if p<=0.05:
                print('Probably different distributions')
            else:
                print('Probably the same distribution')
            print()

In [10]:
#celoe_vs_nces_stat_tests()

## Training curves

In [None]:
import json
def _load_plot_data(kb_name):
    """Load the recorded training curves of knowledge base ``kb_name``.

    Reads ``datasets/{kb_name}/Plot_data/plot_data.json`` and returns the
    decoded JSON object.
    """
    with open(f"datasets/{kb_name}/Plot_data/plot_data.json", "r") as plt_file:
        return json.load(plt_file)


# One plot-data object per knowledge base, numbered in the order in which the
# plotting cells below consume them.
plt_data1 = _load_plot_data("semantic_bible")
plt_data2 = _load_plot_data("family-benchmark")
plt_data3 = _load_plot_data("carcinogenesis")
plt_data4 = _load_plot_data("mutagenesis")
plt_data5 = _load_plot_data("vicodi")

import matplotlib.pyplot as plt

In [None]:
# Inspect which series are stored in the Semantic Bible plot data.
plt_data1.keys()

In [None]:
def plot_curves(plt_data1, plt_data2, name1, name2):
    """Plot loss and hard-accuracy training curves of two knowledge bases.

    Draws a 2x2 grid: the top row shows ``plt_data['loss']`` and the bottom
    row ``plt_data['hard acc']``, with one column per knowledge base. Every
    entry of those lists is drawn as one curve. The figure is saved as
    ``'training-curves' + name1 + '_and_' + name2 + '.pdf'`` and then shown.

    Args:
        plt_data1, plt_data2: mappings with the keys 'loss' and 'hard acc'.
        name1, name2: titles of the left and right column.
    """
    Markers = ['--', ':', '-']
    Colors = ['g', 'b', 'm']

    def draw(ax, curves):
        # The style index restarts for every panel so that the k-th curve
        # always gets the k-th style and therefore matches the k-th legend
        # label, even if a panel does not hold exactly three curves.
        for idx, crv in enumerate(curves):
            ax.plot(crv, Markers[idx % len(Markers)], markersize=5, color=Colors[idx % len(Colors)])
        # NOTE(review): the legend labels are hard-coded; confirm that they
        # match the order of the curves stored in the plot data.
        ax.legend(('GRU', 'LSTM', 'CNN'))

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(10,7))

    draw(ax1, plt_data1['loss'])
    ax1.set_title(name1)
    ax1.set(ylabel='Loss')

    draw(ax2, plt_data2['loss'])
    ax2.set_title(name2)

    draw(ax3, plt_data1['hard acc'])
    ax3.set(ylabel='Hard Accuracy', xlabel='Epochs')

    draw(ax4, plt_data2['hard acc'])
    ax4.set(xlabel='Epochs')

    # Keep axis labels only on the outer edge of the grid.
    for ax in fig.get_axes():
        ax.label_outer()
    fig.savefig('training-curves'+name1+'_and_'+name2+'.pdf')
    fig.show()

In [None]:
# Loss / hard-accuracy curves for the first two knowledge bases.
plot_curves(plt_data1, plt_data2, name1='Semantic Bible', name2='Family Benchmark')

In [None]:
# Loss / hard-accuracy curves for the third and fourth knowledge bases.
plot_curves(plt_data3, plt_data4, name1='Carcinogenesis', name2='Mutagenesis')

In [None]:
def plot_acc_curves(plt_data1, plt_data2, plt_data3, plt_data4, plt_data5, name1, name2, name3, name4, name5, mode='hard'):
    """Plot the accuracy training curves of five knowledge bases side by side.

    Draws one panel per knowledge base in a single row with a shared y axis.
    Every entry of ``plt_data[f'{mode} acc']`` is drawn as one curve. The
    figure is saved as ``accuracy-curves-all-KBs_{mode}.pdf`` and then shown.

    Args:
        plt_data1..plt_data5: mappings holding the key ``f'{mode} acc'``.
        name1..name5: panel titles, in the same order as the data arguments.
        mode: prefix of the accuracy key to plot ('hard' by default); it is
            also used, capitalised, in the y-axis label.
    """
    Markers = ['--', ':', '-']
    Colors = ['g', 'b', 'm']
    fig, axes = plt.subplots(1, 5, figsize=(30,6), sharey=True)

    panels = zip(
        axes,
        (plt_data1, plt_data2, plt_data3, plt_data4, plt_data5),
        (name1, name2, name3, name4, name5),
        (6, 6, 10, 10, 10),  # per-panel markersize, kept from the original calls
    )
    for ax, plt_data, name, markersize in panels:
        # The style index restarts for every panel so that the k-th curve
        # always gets the k-th style and therefore matches the k-th legend
        # label, even if a panel does not hold exactly three curves.
        for idx, crv in enumerate(plt_data[f'{mode} acc']):
            ax.plot(crv, Markers[idx % len(Markers)], markersize=markersize, color=Colors[idx % len(Colors)])
        # NOTE(review): the legend labels are hard-coded; confirm that they
        # match the order of the curves stored in the plot data.
        leg = ax.legend(('GRU', 'LSTM', 'CNN'), prop={'size': 20})
        for line in leg.get_lines():
            line.set_linewidth(4.0)
        ax.set_title(name, fontsize=30, fontweight="bold")
        ax.set_xlabel('Epochs', fontsize=25)
        ax.tick_params(axis='both', which='major', labelsize=20)

    # Only the leftmost panel carries the shared y-axis label.
    axes[0].set_ylabel(mode.capitalize()+' Accuracy', fontsize=25)

    for ax in fig.get_axes():
        ax.label_outer()
    fig.savefig(f'accuracy-curves-all-KBs_{mode}.pdf', bbox_inches='tight')
    fig.show()

In [None]:
# Panel titles, in the same order as plt_data1..plt_data5; plot the soft accuracy curves.
name1, name2, name3, name4, name5 = 'Semantic Bible', 'Family Benchmark', 'Carcinogenesis', 'Mutagenesis', 'Vicodi'
plot_acc_curves(plt_data1, plt_data2, plt_data3, plt_data4, plt_data5, name1, name2, name3, name4, name5, mode='soft')