In [39]:
import pandas as pd
import numpy as np
from os.path import join
import os
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import MACCSkeys
from ete3 import NCBITaxa
import random
random.seed(10)
import torch
import esm
from bioservices import *
from functions_and_dicts_data_preprocessing_GNN import *
from build_GNN import *
from data_preprocessing import *
import warnings
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
warnings.filterwarnings('ignore')
datasets_dir = "../../data"

CURRENT_DIR = os.getcwd()

## 1. Loading in Sabio data

#### Loading Sabio data

In [2]:
organism = "Seed plants"

df_Sabio = pd.read_table(join(datasets_dir, "kcat_model_" + organism + ".tsv"))

df_Sabio["kcat"] = df_Sabio["kcat"].astype('float')
df_Sabio["PMID"] = df_Sabio["PMID"].astype('Int64')

df_Sabio["substrate_IDs"] = df_Sabio["substrate_IDs"].str.split('#').apply(set)
df_Sabio["product_IDs"] = df_Sabio["product_IDs"].str.split('#').apply(set)

df_Sabio["Type"][df_Sabio['Type'].str.contains("wildtype")] = "wildtype"
df_Sabio["Type"][df_Sabio['Type'].str.contains("mutant")] = "mutant"

print("Number of data points: %s" % len(df_Sabio))
print("Number of UniProt IDs: %s" % len(set(df_Sabio["Uniprot IDs"])))

df_kcat = df_Sabio

Number of data points: 1344
Number of UniProt IDs: 370


#### Removing duplicates

In [3]:
droplist = []

for ind in df_kcat.index:
    UID, kcat = df_kcat["Uniprot IDs"][ind], df_kcat["kcat"][ind]
    help_df = df_kcat.loc[df_kcat["Uniprot IDs"] == UID].loc[df_kcat["kcat"] == kcat]
    
    if len(help_df) > 1:
        droplist = droplist + list(help_df.index)[1:]
        

In [4]:
df_kcat.drop(list(set(droplist)), inplace = True)
print("Dropping %s data points, because they are duplicated." % len(set(droplist)))
df_kcat.reset_index(inplace = True, drop = True)
df_kcat

Dropping 104 data points, because they are duplicated.


Unnamed: 0,ECs,Organism,Uniprot IDs,PMID,Type,kcat,Temperature,pH,Substrates,Products,substrate_IDs,product_IDs,Main Substrate,Sequence
0,1,Petunia hybrida,Q15GI3,16782809,wildtype,0.300000,28.0,6.5,Coniferyl acetate;NADPH,Acetate;NADP+;Isoeugenol,{InChI=1S/C21H30N7O17P3/c22-17-12-19(25-7-24-1...,{InChI=1S/C10H12O2/c1-3-4-8-5-6-9(11)10(7-8)12...,InChI=1S/C12H14O4/c1-9(13)16-7-3-4-10-5-6-11(1...,MTTGKGKILILGATGYLGKYMVKASISLGHPTYAYVMPLKKNSDDS...
1,1,Ocimum basilicum,Q15GI4,16782809,wildtype,0.700000,28.0,6.5,NADPH;Coniferyl acetate,Eugenol;NADP+;Acetate,{InChI=1S/C21H30N7O17P3/c22-17-12-19(25-7-24-1...,"{InChI=1S/C2H4O2/c1-2(3)4/h1H3,(H,3,4)/p-1, In...",InChI=1S/C12H14O4/c1-9(13)16-7-3-4-10-5-6-11(1...,MEENGMKSKILIFGGTGYIGNHMVKGSLKLGHPTYVFTRPNSSKTT...
2,1.1.1,Cochlearia officinalis,A7DY56,24583623,wildtype,1.010000,30.0,5.0,NADPH;3-Methylcyclohexanone;H+,NADP+;3-Methylcyclohexanol,{InChI=1S/C21H30N7O17P3/c22-17-12-19(25-7-24-1...,"{InChI=1S/C7H14O/c1-6-3-2-4-7(8)5-6/h6-8H,2-5H...",InChI=1S/C21H30N7O17P3/c22-17-12-19(25-7-24-17...,MANLRESSRDKSRWSLEGMTALVTGGSKGIGEAVVEELAMLGARVH...
3,1.1.1,Cochlearia officinalis,A7DY56,24583623,wildtype,11.800000,30.0,5.0,3-Methylcyclohexanone;H+;NADH,NAD+;3-Methylcyclohexanol,"{InChI=1S/p+1, InChI=1S/C7H12O/c1-6-3-2-4-7(8)...","{InChI=1S/C7H14O/c1-6-3-2-4-7(8)5-6/h6-8H,2-5H...",InChI=1S/C21H29N7O14P2/c22-17-12-19(25-7-24-17...,MANLRESSRDKSRWSLEGMTALVTGGSKGIGEAVVEELAMLGARVH...
4,1.1.1,Cochlearia officinalis,A7DY56,24583623,wildtype,0.160000,30.0,9.5,3-Methylcyclohexanol;NADP+,3-Methylcyclohexanone;H+;NADPH,"{InChI=1S/C7H14O/c1-6-3-2-4-7(8)5-6/h6-8H,2-5H...",{InChI=1S/C21H30N7O17P3/c22-17-12-19(25-7-24-1...,InChI=1S/C21H28N7O17P3/c22-17-12-19(25-7-24-17...,MANLRESSRDKSRWSLEGMTALVTGGSKGIGEAVVEELAMLGARVH...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1235,6.3.2.2,Arabidopsis thaliana,P46309,15180996,wildtype,0.101667,25.0,7.0,L-Glutamate;L-Cysteine;ATP,Phosphate;ADP;gamma-L-Glutamyl-L-cysteine,"{InChI=1S/C3H7NO2S/c4-2(1-7)3(5)6/h2,7H,1,4H2,...",{InChI=1S/C10H15N5O10P2/c11-8-5-9(13-2-12-8)15...,"InChI=1S/C5H9NO4/c6-3(5(9)10)1-2-4(7)8/h3H,1-2...",MALLSQAGGSYTVVPSGVCSKAGTKAVVSGGVRNLDVLRMKEAFGS...
1236,6.3.2.2,Arabidopsis thaliana,P46309,15180996,wildtype,0.113333,25.0,7.0,ATP;L-Glutamate;L-Cysteine,gamma-L-Glutamyl-L-cysteine;ADP;Phosphate,"{InChI=1S/C3H7NO2S/c4-2(1-7)3(5)6/h2,7H,1,4H2,...",{InChI=1S/C10H15N5O10P2/c11-8-5-9(13-2-12-8)15...,InChI=1S/C10H16N5O13P3/c11-8-5-9(13-2-12-8)15(...,MALLSQAGGSYTVVPSGVCSKAGTKAVVSGGVRNLDVLRMKEAFGS...
1237,6.3.2.52,Arabidopsis thaliana,Q8GZ29,29462792,wildtype,0.073333,-,-,(-)-Jasmonic acid;Glutamine;ATP,Diphosphate;Jasmonoyl-glutamine;AMP,"{InChI=1S/C5H10N2O3/c6-3(5(9)10)1-2-4(7)8/h3H,...","{InChI=1S/H4O7P2/c1-8(2,3)7-9(4,5)6/h(H2,1,2,3...",InChI=1S/C12H18O3/c1-2-3-4-5-10-9(8-12(14)15)6...,MLPKFDPTNQKACLSLLEDLTTNVKQIQDSVLEAILSRNAQTEYLR...
1238,6.3.2.52,Arabidopsis thaliana,Q8GZ29,29462792,wildtype,0.066667,-,-,Glutamine;ATP;(-)-Jasmonic acid,Diphosphate;AMP;Jasmonoyl-glutamine,"{InChI=1S/C5H10N2O3/c6-3(5(9)10)1-2-4(7)8/h3H,...","{InChI=1S/H4O7P2/c1-8(2,3)7-9(4,5)6/h(H2,1,2,3...","InChI=1S/C5H10N2O3/c6-3(5(9)10)1-2-4(7)8/h3H,1...",MLPKFDPTNQKACLSLLEDLTTNVKQIQDSVLEAILSRNAQTEYLR...


#### Removing top and bottom 3% of kcat values

In [5]:
def find_outliers_IQR(df):

   q1=df.quantile(0.25)

   q3=df.quantile(0.75)

   IQR=q3-q1

   outliers = df[((df<(q1-1.5*IQR)) | (df>(q3+1.5*IQR)))]

   return outliers

find_outliers_IQR(df_kcat["kcat"])

print(df_kcat['kcat'].quantile(0.03),  df_kcat['kcat'].quantile(0.97))

0.001 1649.7999999999956


In [6]:
print(len(df_kcat))
df_kcat = df_kcat[(df_kcat['kcat'] > df_kcat['kcat'].quantile(0.03)) & (df_kcat['kcat'] < df_kcat['kcat'].quantile(0.97))]
df_kcat.reset_index(inplace = True, drop = True)
print(len(df_kcat))

1240
1161


In [7]:
todrop= []

for ind in df_kcat.index:
    UID = df_kcat["Uniprot IDs"][ind]
    if len(UID.split(';')) > 1:
        todrop.append(ind)
        print(df_kcat["Uniprot IDs"][ind])
        print(todrop)
        
df_kcat.drop(todrop, inplace=True)
df_kcat.reset_index(inplace = True, drop = True)

Q41736;P00221
[281]
Q41736;P00221
[281, 282]
Q41736;P00221
[281, 282, 283]
P19866;P12860
[281, 282, 283, 297]
P19866;P12860
[281, 282, 283, 297, 298]
O04385;O23760
[281, 282, 283, 297, 298, 408]
O04385;O23760
[281, 282, 283, 297, 298, 408, 409]
P09342;P09114
[281, 282, 283, 297, 298, 408, 409, 436]
P09342;P09114
[281, 282, 283, 297, 298, 408, 409, 436, 437]
Q42588;P32260
[281, 282, 283, 297, 298, 408, 409, 436, 437, 478]
Q42588;P16703
[281, 282, 283, 297, 298, 408, 409, 436, 437, 478, 479]
A0A2U7XUE3;Q9FEY5
[281, 282, 283, 297, 298, 408, 409, 436, 437, 478, 479, 499]
A0A2U7XUE3;Q9FEY5
[281, 282, 283, 297, 298, 408, 409, 436, 437, 478, 479, 499, 500]
Q9SC13;P60038
[281, 282, 283, 297, 298, 408, 409, 436, 437, 478, 479, 499, 500, 551]
Q42588;P32260
[281, 282, 283, 297, 298, 408, 409, 436, 437, 478, 479, 499, 500, 551, 581]
P55241;Q947C0
[281, 282, 283, 297, 298, 408, 409, 436, 437, 478, 479, 499, 500, 551, 581, 630]
P23509;P55241;Q947C0
[281, 282, 283, 297, 298, 408, 409, 436, 437, 478, 

In [8]:
df_kcat["substrate_IDs"] = df_kcat["substrate_IDs"].apply(lambda x: (set(x)))
df_kcat["product_IDs"] = df_kcat["product_IDs"].apply(lambda x: (set(x)))

In [9]:
df_kcat.to_pickle(join(datasets_dir, "kcat_data_merged.pkl"))

## 2. Assigning IDs to every unique sequence and to every unique reaction in the dataset

#### Creating DataFrames for all sequences and for all reactions

In [10]:
df_reactions = pd.DataFrame({"substrates": df_kcat["substrate_IDs"],
                            "products" : df_kcat["product_IDs"]})

df_reactions = df_reactions.loc[df_reactions["substrates"] != set([])]
df_reactions = df_reactions.loc[df_reactions["products"] != set([])]


droplist = []
for ind in df_reactions.index:
    sub_IDs, pro_IDs = df_reactions["substrates"][ind], df_reactions["products"][ind]
    help_df = df_reactions.loc[df_reactions["substrates"] == sub_IDs].loc[df_reactions["products"] == pro_IDs]
    if len(help_df):
        for ind in list(help_df.index)[1:]:
            droplist.append(ind)
            
df_reactions.drop(list(set(droplist)), inplace = True)
df_reactions.reset_index(inplace = True, drop =True)

df_reactions["Reaction ID"] = ["Reaction_" + str(ind) for ind in df_reactions.index]

In [11]:
df_sequences = pd.DataFrame(data = {"Sequence" : df_kcat["Sequence"].unique()})
df_sequences = df_sequences.loc[~pd.isnull(df_sequences["Sequence"])]
df_sequences.reset_index(inplace = True, drop = True)
df_sequences["Sequence ID"] = ["Sequence_" + str(ind) for ind in df_sequences.index]

df_sequences

Unnamed: 0,Sequence,Sequence ID
0,MTTGKGKILILGATGYLGKYMVKASISLGHPTYAYVMPLKKNSDDS...,Sequence_0
1,MEENGMKSKILIFGGTGYIGNHMVKGSLKLGHPTYVFTRPNSSKTT...,Sequence_1
2,MANLRESSRDKSRWSLEGMTALVTGGSKGIGEAVVEELAMLGARVH...,Sequence_2
3,MAKEGGLGENSRWSLGGMTALVTGGSKGIGEAVVEELAMLGAKVHT...,Sequence_3
4,MAKAGENSRDKSRWSLEGMTALVTGGSKGLGEAVVEELAMLGARVH...,Sequence_4
...,...,...
455,MSSLADLINLDLSDSTDQIIAEYIWIGGSGLDMRSKARTLPGPVTD...,Sequence_455
456,MSSLADLINLDLSDSTDQIIAEYIWIGGSGLDMRSKARTLPGPVTD...,Sequence_456
457,MALLSQAGGSYTVVPSGVCSKAGTKAVVSGGVRNLDVLRMKEAFGS...,Sequence_457
458,MLPKFDPTNQKACLSLLEDLTTNVKQIQDSVLEAILSRNAQTEYLR...,Sequence_458


#### Calculating maximal kcat value for each reaction and sequence

In [12]:
df_reactions["max_kcat_for_RID"] = np.nan
for ind in df_reactions.index:
    df_reactions["max_kcat_for_RID"][ind] = max(df_kcat.loc[df_kcat["substrate_IDs"] == df_reactions["substrates"][ind]].loc[df_kcat["product_IDs"] == df_reactions["products"][ind]]["kcat"])

In [13]:
df_sequences["max_kcat_for_UID"] = np.nan
for ind in df_sequences.index:
    df_sequences["max_kcat_for_UID"][ind] = max(df_kcat.loc[df_kcat["Sequence"] == df_sequences['Sequence'][ind]]["kcat"])

#### Calculating the sum of the molecular weights of all substrates and of all products

In [14]:
df_reactions["MW_frac"] = np.nan

for ind in df_reactions.index:
    substrates = list(df_reactions["substrates"][ind])
    products = list(df_reactions["products"][ind])
    
    mw_subs = mw_mets(metabolites = substrates)
    mw_pros = mw_mets(metabolites = products)
    
    if mw_subs == np.nan or mw_pros == np.nan:
        df_reactions["MW_frac"][ind] = np.inf
    if mw_pros != 0:
        df_reactions["MW_frac"][ind] = mw_subs/mw_pros
    else:
        df_reactions["MW_frac"][ind] = np.inf
        
df_reactions

Unnamed: 0,substrates,products,Reaction ID,max_kcat_for_RID,MW_frac
0,{InChI=1S/C21H30N7O17P3/c22-17-12-19(25-7-24-1...,{InChI=1S/C10H12O2/c1-3-4-8-5-6-9(11)10(7-8)12...,Reaction_0,0.300000,1.001043
1,{InChI=1S/C21H30N7O17P3/c22-17-12-19(25-7-24-1...,"{InChI=1S/C2H4O2/c1-2(3)4/h1H3,(H,3,4)/p-1, In...",Reaction_1,0.700000,1.001043
2,{InChI=1S/C21H30N7O17P3/c22-17-12-19(25-7-24-1...,"{InChI=1S/C7H14O/c1-6-3-2-4-7(8)5-6/h6-8H,2-5H...",Reaction_2,2.340000,1.001175
3,"{InChI=1S/p+1, InChI=1S/C7H12O/c1-6-3-2-4-7(8)...","{InChI=1S/C7H14O/c1-6-3-2-4-7(8)5-6/h6-8H,2-5H...",Reaction_3,11.800000,1.000000
4,"{InChI=1S/C7H14O/c1-6-3-2-4-7(8)5-6/h6-8H,2-5H...",{InChI=1S/C21H30N7O17P3/c22-17-12-19(25-7-24-1...,Reaction_4,0.170000,0.998826
...,...,...,...,...,...
462,{InChI=1S/C10H16N5O13P3/c11-8-5-9(13-2-12-8)15...,"{InChI=1S/C32H46N7O20P3S/c1-32(2,27(44)30(45)3...",Reaction_462,0.120000,0.996638
463,"{InChI=1S/H3N/h1H3, InChI=1S/C10H16N5O13P3/c11...","{InChI=1S/C5H10N2O3/c6-3(5(9)10)1-2-4(7)8/h3H,...",Reaction_463,8.080000,0.992493
464,"{InChI=1S/C3H7NO2S/c4-2(1-7)3(5)6/h2,7H,1,4H2,...",{InChI=1S/C10H15N5O10P2/c11-8-5-9(13-2-12-8)15...,Reaction_464,0.113333,0.993501
465,"{InChI=1S/C5H10N2O3/c6-3(5(9)10)1-2-4(7)8/h3H,...","{InChI=1S/H4O7P2/c1-8(2,3)7-9(4,5)6/h(H2,1,2,3...",Reaction_465,0.073333,0.996494


#### Calculating enzyme, reaction and substrate features

In [15]:
# model, alphabet = torch.hub.load("facebookresearch/esm:main", "esm2_t33_650M_UR50D")

In [16]:
#creating model input:
df_sequences["model_input"] = [seq[:1022] for seq in df_sequences["Sequence"]]
model_input = [(df_sequences["Sequence ID"][ind], df_sequences["model_input"][ind]) for ind in df_sequences.index]
seqs = [model_input[i][1] for i in range(len(model_input))]
#loading ESM-2 model:
print(".....2(a) Loading ESM-2 model.")
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
#convert input into batches:

#Calculate ESM-2 representations
print(".....2(b) Calculating enzyme representations.")
df_sequences["Enzyme rep"] = ""

for ind in df_sequences.index:
    print(ind,"/",len(df_sequences))    
    batch_labels, batch_strs, batch_tokens = batch_converter([(df_sequences["Sequence ID"][ind], df_sequences["model_input"][ind])])
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[33])
    df_sequences["Enzyme rep"][ind] = results["representations"][33][0, 1 : len(df_sequences["model_input"][ind]) + 1].mean(0).numpy()
    
df_sequences.head(5)

.....2(a) Loading ESM-2 model.


Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50D.pt" to C:\Users\jearle/.cache\torch\hub\checkpoints\esm2_t33_650M_UR50D.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t33_650M_UR50D-contact-regression.pt" to C:\Users\jearle/.cache\torch\hub\checkpoints\esm2_t33_650M_UR50D-contact-regression.pt


.....2(b) Calculating enzyme representations.
0 / 460
1 / 460
2 / 460
3 / 460
4 / 460
5 / 460
6 / 460
7 / 460
8 / 460
9 / 460
10 / 460
11 / 460
12 / 460
13 / 460
14 / 460
15 / 460
16 / 460
17 / 460
18 / 460
19 / 460
20 / 460
21 / 460
22 / 460
23 / 460
24 / 460
25 / 460
26 / 460
27 / 460
28 / 460
29 / 460
30 / 460
31 / 460
32 / 460
33 / 460
34 / 460
35 / 460
36 / 460
37 / 460
38 / 460
39 / 460
40 / 460
41 / 460
42 / 460
43 / 460
44 / 460
45 / 460
46 / 460
47 / 460
48 / 460
49 / 460
50 / 460
51 / 460
52 / 460
53 / 460
54 / 460
55 / 460
56 / 460
57 / 460
58 / 460
59 / 460
60 / 460
61 / 460
62 / 460
63 / 460
64 / 460
65 / 460
66 / 460
67 / 460
68 / 460
69 / 460
70 / 460
71 / 460
72 / 460
73 / 460
74 / 460
75 / 460
76 / 460
77 / 460
78 / 460
79 / 460
80 / 460
81 / 460
82 / 460
83 / 460
84 / 460
85 / 460
86 / 460
87 / 460
88 / 460
89 / 460
90 / 460
91 / 460
92 / 460
93 / 460
94 / 460
95 / 460
96 / 460
97 / 460
98 / 460
99 / 460
100 / 460
101 / 460
102 / 460
103 / 460
104 / 460
105 / 460
106 

Unnamed: 0,Sequence,Sequence ID,max_kcat_for_UID,model_input,Enzyme rep
0,MTTGKGKILILGATGYLGKYMVKASISLGHPTYAYVMPLKKNSDDS...,Sequence_0,0.3,MTTGKGKILILGATGYLGKYMVKASISLGHPTYAYVMPLKKNSDDS...,"[-0.032205872, -0.031796478, -0.051493246, 0.0..."
1,MEENGMKSKILIFGGTGYIGNHMVKGSLKLGHPTYVFTRPNSSKTT...,Sequence_1,0.7,MEENGMKSKILIFGGTGYIGNHMVKGSLKLGHPTYVFTRPNSSKTT...,"[-0.016749375, -0.04821479, -0.049711417, 0.00..."
2,MANLRESSRDKSRWSLEGMTALVTGGSKGIGEAVVEELAMLGARVH...,Sequence_2,11.8,MANLRESSRDKSRWSLEGMTALVTGGSKGIGEAVVEELAMLGARVH...,"[-0.0007728002, -0.06124316, 0.041369695, 0.05..."
3,MAKEGGLGENSRWSLGGMTALVTGGSKGIGEAVVEELAMLGAKVHT...,Sequence_3,1.61,MAKEGGLGENSRWSLGGMTALVTGGSKGIGEAVVEELAMLGAKVHT...,"[-0.011831594, -0.06318853, 0.038726423, 0.021..."
4,MAKAGENSRDKSRWSLEGMTALVTGGSKGLGEAVVEELAMLGARVH...,Sequence_4,0.56,MAKAGENSRDKSRWSLEGMTALVTGGSKGLGEAVVEELAMLGARVH...,"[-0.0063330643, -0.07236032, 0.040615078, 0.02..."


In [17]:
def get_metabolite_type(met):
    if is_KEGG_ID(met):
        return("KEGG")
    elif is_InChI(met):
        return("InChI")
    else:
        return("invalid")

def get_reaction_site_smarts(metabolites):
    reaction_site = ""
    for met in metabolites:
        met_type = get_metabolite_type(met)
        if met_type == "KEGG":
            try:
                Smarts = Chem.MolToSmarts(Chem.MolFromMolFile(join("", "", "data", "mol-files",  met + ".mol")))
            except OSError:
                return(np.nan)
        elif met_type == "InChI":
            Smarts = Chem.MolToSmarts(Chem.inchi.MolFromInchi(met))
        else:
            Smarts = "invalid"
        reaction_site = reaction_site + "." + Smarts
    return(reaction_site[1:])


def is_KEGG_ID(met):
    #a valid KEGG ID starts with a "C" or "D" followed by a 5 digit number:
    if len(met) == 6 and met[0] in ["C", "D"]:
        try:
            int(met[1:])
            return(True)
        except: 
            pass
    return(False)

def is_InChI(met):
    m = Chem.inchi.MolFromInchi(met,sanitize=False)
    if m is None:
      return(False)
    else:
      try:
        Chem.SanitizeMol(m)
      except:
        print('.......Metabolite string "%s" is in InChI format but has invalid chemistry' % met)
        return(False)
    return(True)

def convert_fp_to_array(difference_fp_dict):
    fp = np.zeros(2048)
    for key in difference_fp_dict.keys():
        fp[key] = difference_fp_dict[key]
    return(fp)

In [18]:
df_reactions["difference_fp"], df_reactions["structural_fp"],  = "", ""
for ind in df_reactions.index:
    left_site = get_reaction_site_smarts(df_reactions["substrates"][ind])
    right_site = get_reaction_site_smarts(df_reactions["products"][ind])
    if not pd.isnull(left_site) and not pd.isnull(right_site):
        rxn_forward = AllChem.ReactionFromSmarts(left_site + ">>" + right_site)
        difference_fp = Chem.rdChemReactions.CreateDifferenceFingerprintForReaction(rxn_forward)
        difference_fp = convert_fp_to_array(difference_fp.GetNonzeroElements())
        df_reactions["difference_fp"][ind] = difference_fp
        df_reactions["structural_fp"][ind] = Chem.rdChemReactions.CreateStructuralFingerprintForReaction(rxn_forward).ToBitString()

df_reactions.head(5)

Unnamed: 0,substrates,products,Reaction ID,max_kcat_for_RID,MW_frac,difference_fp,structural_fp
0,{InChI=1S/C21H30N7O17P3/c22-17-12-19(25-7-24-1...,{InChI=1S/C10H12O2/c1-3-4-8-5-6-9(11)10(7-8)12...,Reaction_0,0.3,1.001043,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",1100111100000001001000110110010001001111111100...
1,{InChI=1S/C21H30N7O17P3/c22-17-12-19(25-7-24-1...,"{InChI=1S/C2H4O2/c1-2(3)4/h1H3,(H,3,4)/p-1, In...",Reaction_1,0.7,1.001043,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",1100111100000001001000110110010001001111111100...
2,{InChI=1S/C21H30N7O17P3/c22-17-12-19(25-7-24-1...,"{InChI=1S/C7H14O/c1-6-3-2-4-7(8)5-6/h6-8H,2-5H...",Reaction_2,2.34,1.001175,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",1100111100000001001000110110010001001111111100...
3,"{InChI=1S/p+1, InChI=1S/C7H12O/c1-6-3-2-4-7(8)...","{InChI=1S/C7H14O/c1-6-3-2-4-7(8)5-6/h6-8H,2-5H...",Reaction_3,11.8,1.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",1100111100000001001000110110010001001111111100...
4,"{InChI=1S/C7H14O/c1-6-3-2-4-7(8)5-6/h6-8H,2-5H...",{InChI=1S/C21H30N7O17P3/c22-17-12-19(25-7-24-1...,Reaction_4,0.17,0.998826,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",1100111100000001001000110110010001001101111100...


In [19]:
df_sequences.to_pickle(join(datasets_dir, "all_sequences_with_IDs.pkl"))
df_reactions.to_pickle(join(datasets_dir, "all_reactions_with_IDs.pkl"))

In [20]:
df_sequences["max_kcat_for_UID"] = np.nan
for ind in df_sequences.index:
    df_sequences["max_kcat_for_UID"][ind] = max(df_kcat.loc[df_kcat["Sequence"] == df_sequences['Sequence'][ind]]["kcat"])

#### Mapping Sequence and Reaction IDs to kcat_df

In [21]:
df_kcat = df_kcat.merge(df_sequences, on = "Sequence", how = "left")

In [22]:
df_reactions.rename(columns = {"substrates" : "substrate_IDs",
                              "products" : "product_IDs"}, inplace = True)

df_kcat["Reaction ID"] = np.nan
df_kcat["MW_frac"] = np.nan
df_kcat["max_kcat_for_RID"] = np.nan
df_kcat["difference_fp"] = ""
df_kcat["structural_fp"] = ""

for ind in df_kcat.index:
    sub_set, pro_set = df_kcat["substrate_IDs"][ind], df_kcat["product_IDs"][ind]
    
    help_df = df_reactions.loc[df_reactions["substrate_IDs"] == sub_set].loc[df_reactions["product_IDs"] == pro_set]
    if len(help_df) == 1:
        df_kcat["Reaction ID"][ind] = list(help_df["Reaction ID"])[0]
        df_kcat["max_kcat_for_RID"][ind] = list(help_df["max_kcat_for_RID"])[0]
        df_kcat["MW_frac"][ind] = list(help_df["MW_frac"])[0]
        df_kcat["difference_fp"][ind] = list(help_df["difference_fp"])[0]
        df_kcat["structural_fp"][ind] = list(help_df["structural_fp"])[0]
df_kcat.head(2)

Unnamed: 0,ECs,Organism,Uniprot IDs,PMID,Type,kcat,Temperature,pH,Substrates,Products,...,Sequence,Sequence ID,max_kcat_for_UID,model_input,Enzyme rep,Reaction ID,MW_frac,max_kcat_for_RID,difference_fp,structural_fp
0,1,Petunia hybrida,Q15GI3,16782809,wildtype,0.3,28.0,6.5,Coniferyl acetate;NADPH,Acetate;NADP+;Isoeugenol,...,MTTGKGKILILGATGYLGKYMVKASISLGHPTYAYVMPLKKNSDDS...,Sequence_0,0.3,MTTGKGKILILGATGYLGKYMVKASISLGHPTYAYVMPLKKNSDDS...,"[-0.032205872, -0.031796478, -0.051493246, 0.0...",Reaction_0,1.001043,0.3,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",1100111100000001001000110110010001001111111100...
1,1,Ocimum basilicum,Q15GI4,16782809,wildtype,0.7,28.0,6.5,NADPH;Coniferyl acetate,Eugenol;NADP+;Acetate,...,MEENGMKSKILIFGGTGYIGNHMVKGSLKLGHPTYVFTRPNSSKTT...,Sequence_1,0.7,MEENGMKSKILIFGGTGYIGNHMVKGSLKLGHPTYVFTRPNSSKTT...,"[-0.016749375, -0.04821479, -0.049711417, 0.00...",Reaction_1,1.001043,0.7,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",1100111100000001001000110110010001001111111100...


In [23]:
df_kcat["MACCS FP"] = ""

for ind in df_kcat.index:
    id = df_kcat["Main Substrate"][ind]
    if id[0] == "C":
        try:
            mol = Chem.MolFromMolFile(join(datasets_dir,"mol-files", id + '.mol'))
        except OSError:
            None
    else:
        try:
            mol = Chem.inchi.MolFromInchi(id,sanitize=False)
        except OSError:
            None
    if mol is not None:
        maccs_fp = MACCSkeys.GenMACCSKeys(mol).ToBitString()
        df_kcat["MACCS FP"][ind] = maccs_fp

#### Calculating the maximal kcat value for every EC number in the dataset

In [24]:
df_EC_kcat = pd.read_csv(join(datasets_dir, "max_EC_" + organism + ".tsv"), sep = "\t", header=0)

df_EC_kcat.head(5)
df_kcat["max_kcat_for_EC"] = np.nan

for ind in df_kcat.index:
    EC = df_kcat["ECs"][ind]
    max_kcat = 0
    try:
        print(EC)
        max_kcat = df_EC_kcat.loc[df_EC_kcat["EC"] == EC, "max_kcat"].iloc[0]
        print(max_kcat)
    except:
        pass
    if max_kcat != 0:
        df_kcat["max_kcat_for_EC"][ind] = max_kcat
df_kcat.to_pickle(join(datasets_dir, "merged_and_grouped_kcat_dataset2.pkl"))     

1
0.7
1
0.7
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.0
1.1.1
44.

## 3. Removing outliers

#### Removing non-optimally measured values

To ignore $kcat$ values that were obtained under non-optimal conditions, we exclude values lower than 0.1\% than the maximal $kcat$ value for the same enzyme, reaction or EC number.

In [25]:
df_kcat["frac_of_max_UID"] = np.nan
df_kcat["frac_of_max_RID"] = np.nan
df_kcat["frac_of_max_EC"] = np.nan

for ind in df_kcat.index:
    df_kcat["frac_of_max_UID"][ind] =  df_kcat["kcat"][ind]/df_kcat["max_kcat_for_UID"][ind]
    df_kcat["frac_of_max_RID"][ind] =  df_kcat["kcat"][ind]/df_kcat["max_kcat_for_RID"][ind]
    df_kcat["frac_of_max_EC"][ind] = df_kcat["kcat"][ind]/df_kcat["max_kcat_for_EC"][ind]

len(df_kcat)

1132

In [26]:
n = len(df_kcat)

df_kcat = df_kcat.loc[df_kcat["frac_of_max_UID"] >= 0.01]
print(len(df_kcat))
df_kcat = df_kcat.loc[df_kcat["frac_of_max_RID"] >= 0.01]
print(len(df_kcat))
df_kcat = df_kcat.loc[df_kcat["frac_of_max_EC"] <= 10]
print(len(df_kcat))
df_kcat = df_kcat.loc[df_kcat["frac_of_max_EC"] >= 0.01]
print(len(df_kcat))

1086
1004
1004
893


In [27]:
print("We remove %s data points, because we suspect that these kcat values were not measure for the natural reaction " \
    "of an enzyme or under non-optimal conditions." % (n-len(df_kcat)))

We remove 239 data points, because we suspect that these kcat values were not measure for the natural reaction of an enzyme or under non-optimal conditions.


#### Removing data points with reaction queations with uneven fraction of molecular weights

In [28]:
n = len(df_kcat)

df_kcat = df_kcat.loc[df_kcat["MW_frac"] < 3]
df_kcat = df_kcat.loc[df_kcat["MW_frac"] > 1/3]

print("We remove %s data points because the sum of molecular weights of substrates does not match the sum of molecular" \
      "weights of the products." % (n-len(df_kcat)))

We remove 30 data points because the sum of molecular weights of substrates does not match the sum of molecularweights of the products.


In [29]:
print("Size of final kcat dataset: %s" % len(df_kcat))
df_kcat.to_pickle(join(datasets_dir, "final_kcat_dataset_" + organism + ".pkl"))

Size of final kcat dataset: 863


## 4. Preparing dataset and splitting into train-test

In [30]:
df_kcat = pd.read_pickle(join(datasets_dir, "final_kcat_dataset_" + organism + ".pkl"))
df_kcat["log10_kcat"] = [np.log10(x) for x in df_kcat["kcat"]]

#### Making input for GNN

In [31]:
inchi_ids = {}
for i, element in enumerate(df_kcat["Main Substrate"]):
    if element[0] != 'C' and element not in inchi_ids.keys():
        inchi_ids[element] = str(i)
        mol = Chem.inchi.MolFromInchi(element)
        if not mol is None:
            calculate_atom_and_bond_feature_vectors(mol, str(i))
        Chem.rdmolfiles.MolToMolFile(Chem.inchi.MolFromInchi(element), join(datasets_dir,"mol-files", str(i) + ".mol")  )  

#### Splitting glucosinolates into validation dataset

Search UniProt for GO term related to glucosionalte metabolic process, download file as .tsv and filter dataset

In [32]:
glucosinolates = pd.read_table(join(datasets_dir,"glucosinolates.tsv"))["Entry"].tolist()
df_validation = df_kcat[df_kcat["Uniprot IDs"].isin(glucosinolates)]
df_validation.reset_index(inplace=True, drop = True)
df_kcat = df_kcat[~df_kcat["Uniprot IDs"].isin(glucosinolates)]
df_kcat.reset_index(inplace=True, drop = True)
split = "full"

If training-testing with only Arabidopsis data:

In [33]:
# df_kcat = df_kcat[df_kcat["Organism"] == 'Arabidopsis thaliana']
# df_kcat.reset_index(inplace=True, drop = True)
# split = "Arabidopsis"

If training-testing with only Brassicaceae data:

In [34]:
# ncbi = NCBITaxa()

# organisms = {}

# def is_brassicaceae(org):
#     try:
#         tax_id = ncbi.get_name_translator([org])[org][0]
#         lineage = ncbi.get_lineage(tax_id)
#         if 3700 not in lineage:
#             return(False)
#         else:
#             return(True)
#     except KeyError:
#         return(False)
    
# for org in df_kcat["Organism"].tolist():
#     if org not in organisms.keys():
#         organisms[org] = is_brassicaceae(org)

# df_kcat = df_kcat[df_kcat["Organism"].isin([key for key, value in organisms.items() if value is True])]
# df_kcat.reset_index(inplace=True, drop = True)
# split = "Brassicaceae"

If training-testing only with wildtype data:

In [35]:
# df_kcat = df_kcat[df_kcat["Type"].str.contains("wildtype")]
# df_kcat.reset_index(inplace=True, drop = True)
# split = "wildtype"

If training-testing only with secondary metabolite data:

In [36]:
# secondary = pd.read_table(join(datasets_dir,"secondary_metabolites.tsv"))["Entry"].tolist()
# df_kcat = df_kcat[df_kcat["Uniprot IDs"].isin(secondary)]
# df_kcat.reset_index(inplace=True, drop = True)
# split = "secondary"

In [37]:
# os.mkdir(join(datasets_dir, "splits", split))

#### Splitting into train-test

In [40]:
df = df_kcat.copy()
df = df.sample(frac = 1, random_state=123)
df.reset_index(drop= True, inplace = True)

train_df, test_df = split_dataframe_enzyme(frac = 5, df = df.copy())
print("Test set size: %s" % len(test_df))
print("Training set size: %s" % len(train_df))
print("Size of test set in percent: %s" % np.round(100*len(test_df)/ (len(test_df) + len(train_df))))

train_df.reset_index(inplace = True, drop = True)
test_df.reset_index(inplace = True, drop = True)

train_df.to_pickle(join(datasets_dir, "splits", split, "train_df_kcat_%s.pkl" %organism))
test_df.to_pickle(join(datasets_dir, "splits", split, "test_df_kcat_%s.pkl" %organism))

Test set size: 221
Training set size: 635
Size of test set in percent: 26.0


#### Splitting CV folds

In [41]:
data_train2 = train_df.copy()
data_train2["index"] = list(data_train2.index)

data_train2, df_fold = split_dataframe_enzyme(df = data_train2, frac=5)
indices_fold1 = list(df_fold["index"])
print(len(data_train2), len(indices_fold1))#

data_train2, df_fold = split_dataframe_enzyme(df = data_train2, frac=4)
indices_fold2 = list(df_fold["index"])
print(len(data_train2), len(indices_fold2))

data_train2, df_fold = split_dataframe_enzyme(df = data_train2, frac=3)
indices_fold3 = list(df_fold["index"])
print(len(data_train2), len(indices_fold3))

data_train2, df_fold = split_dataframe_enzyme(df = data_train2, frac=2)
indices_fold4 = list(df_fold["index"])
indices_fold5 = list(data_train2["index"])
print(len(data_train2), len(indices_fold4))


fold_indices = [indices_fold1, indices_fold2, indices_fold3, indices_fold4, indices_fold5]

CV_train_indices = [[], [], [], [], []]
CV_test_indices = [[], [], [], [], []]

for i in range(5):
    for j in range(5):
        if i != j:
            CV_train_indices[i] = CV_train_indices[i] + fold_indices[j]
    CV_test_indices[i] = fold_indices[i]
    
    
np.save(join(datasets_dir, "splits", split, "CV_train_indices_%s" %organism), CV_train_indices)
np.save(join(datasets_dir, "splits", split, "CV_test_indices_%s" %organism), CV_test_indices)

510 125
406 104
287 119
147 140


## 5. Building GNN for substrate representation

In [42]:
# os.mkdir(join(datasets_dir, "GNN_input_data", split))

for ind in train_df.index:
    calculate_and_save_input_matrixes(inchi_ids, sample_ID = "train_" + str(ind), df = train_df,
                                      save_folder = join(datasets_dir, "GNN_input_data", split))
    
for ind in test_df.index:
    calculate_and_save_input_matrixes(inchi_ids, sample_ID = "test_" + str(ind), df = test_df,
                                      save_folder = join(datasets_dir, "GNN_input_data", split))
    
for ind in df_validation.index:
    calculate_and_save_input_matrixes(inchi_ids, sample_ID = "val_" + str(ind), df = df_validation,
                                    save_folder = join(datasets_dir, "GNN_input_data", split))

Could not create input for substrate ID 320
Could not create input for substrate ID 583
Could not create input for substrate ID 321


In [43]:
train_indices = os.listdir(join(datasets_dir, "GNN_input_data", split))
train_indices = [index[:index.rfind("_")] for index in train_indices]
train_indices = list(set([index for index in train_indices if "train" in index]))

test_indices = os.listdir(join(datasets_dir, "GNN_input_data", split))
test_indices = [index[:index.rfind("_")] for index in test_indices]
test_indices = list(set([index for index in test_indices if "test" in index]))

#### Hyper-parameter optimization with CV

In [44]:
param_grid = {'batch_size': [32,64,96],
                'D': [50,100],
                'learning_rate': [0.01, 0.1],
                'epochs': [30,50,80],
                'l2_reg_fc' : [0.01, 0.1, 1],
                'l2_reg_conv': [0.01, 0.1, 1],
                'rho': [0.9, 0.95, 0.99]}

params_list = [(batch_size, D, learning_rate, epochs, l2_reg_fc, l2_reg_conv, rho) for batch_size in param_grid['batch_size'] for D in param_grid["D"] for learning_rate in param_grid['learning_rate']
                for epochs in param_grid['epochs'] for l2_reg_fc in param_grid['l2_reg_fc'] for l2_reg_conv in param_grid['l2_reg_conv'] for rho in param_grid["rho"]]

params_list = random.sample(params_list, 10)

In [None]:
count = 0
results=[]

for params in params_list:

    batch_size, D, learning_rate, epochs, l2_reg_fc, l2_reg_conv, rho = params
    count +=1
    MAE = []

    for i in range(5):
        train_index, test_index  = CV_train_indices[i], CV_test_indices[i]
        train_index = [ind for ind in train_indices if int(ind.split("_")[1]) in train_index]
        test_index = [ind for ind in train_indices if int(ind.split("_")[1]) in test_index]

        train_params = {'batch_size': batch_size,
                    'folder' :join(datasets_dir, "GNN_input_data/full"),
                    'list_IDs' : np.array(train_index),
                    'shuffle': True}

        test_params = {'batch_size': len(test_index),
                    'folder' : join(datasets_dir, "GNN_input_data/full"),
                    'list_IDs' : np.array(test_index),
                    'shuffle': False}

        training_generator = DataGenerator(**train_params)
        test_generator = DataGenerator(**test_params)


        model = DMPNN_without_extra_features(l2_reg_conv = l2_reg_conv, l2_reg_fc = l2_reg_fc, learning_rate = learning_rate,
                        D = D, N = N, F1 = F1, F2 = F2, F= F, drop_rate = 0.0, ada_rho = rho)
        model.fit(training_generator, epochs= epochs, shuffle = True, verbose = 1)

        #get test_y:
        test_indices_y = [int(ind.split("_")[1]) for ind in train_indices if ind in test_index]
        test_y = np.array([train_df["kcat"][ind] for ind in test_indices_y])

        pred_test = model.predict(test_generator)
        mae = np.median(abs(np.array([10**x for x in pred_test]) - np.reshape(test_y[:len(pred_test)], (-1,1))))
        print(mae)
        MAE.append(mae)

    results.append({"batch_size" : batch_size, "D" : D , "learning_rate" : learning_rate, "epochs" : epochs,
                    "l2_reg_fc" : l2_reg_fc, "l2_reg_conv" : l2_reg_conv, "rho" : rho, "cv_mae" : np.mean(MAE)})

params = min(results, key=lambda d: d['cv_mae'])
print(params)

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30
3.291889833807945
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30
5.920713093280792
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30


{'batch_size': 32, 'D': 50, 'learning_rate': 0.01, 'epochs': 30, 'l2_reg_fc': 0.1, 'l2_reg_conv': 1, 'rho': 0.9, 'cv_mae': 2.4853503725624084}

#### Training the model with the best set of hyperparmeters on the whole training set and validate it on the test set

In [55]:
batch_size = 32
D = 50
learning_rate = 0.01
epochs = 30
l2_reg_fc = 0.1
l2_reg_conv = 1
rho = 0.9

In [56]:
train_indices = os.listdir(join(datasets_dir, "GNN_input_data/full"))
train_indices = [index[:index.rfind("_")] for index in train_indices]
train_indices = list(set([index for index in train_indices if "train" in index]))

test_indices = os.listdir(join(datasets_dir, "GNN_input_data/full"))
test_indices = [index[:index.rfind("_")] for index in test_indices]
test_indices = list(set([index for index in test_indices if "test" in index]))

train_params = {'batch_size': batch_size,
              'folder' :join(datasets_dir, "GNN_input_data/full"),
              'list_IDs' : train_indices,
              'shuffle': True}

test_params = {'batch_size': batch_size,
              'folder' :join(datasets_dir, "GNN_input_data/full"),
              'list_IDs' : test_indices,
              'shuffle': False}

training_generator = DataGenerator(**train_params)
test_generator = DataGenerator(**test_params)

model = DMPNN_without_extra_features(l2_reg_conv = l2_reg_conv, l2_reg_fc = l2_reg_fc, learning_rate = learning_rate,
                  D = D, N = N, F1 = F1, F2 = F2, F= F, drop_rate = 0.0, ada_rho = rho)

model.fit(training_generator, epochs= epochs, shuffle = True, verbose = 1)
model.save_weights(join(datasets_dir, "model_weights", "saved_model_GNN_best_hyperparameters"))

pred_test = model.predict(test_generator)
test_indices_y = [int(ind.split("_")[1]) for ind in np.array(test_indices)]
test_y = np.array([test_df["kcat"][ind] for ind in test_indices_y])

#### Calculating substrate representation for every data point in training and test set

In [57]:
model = DMPNN_without_extra_features(l2_reg_conv = l2_reg_conv, l2_reg_fc = l2_reg_fc, learning_rate = learning_rate,
                  D = D, N = N, F1 = F1, F2 = F2, F= F, drop_rate = 0.0, ada_rho = rho)
model.load_weights(join(datasets_dir, "model_weights", "saved_model_GNN_best_hyperparameters"))

get_fingerprint_fct = K.function([model.layers[0].input, model.layers[26].input,
                                  model.layers[3].input],
                                  [model.layers[-10].output])

In [58]:
input_data_folder = join(datasets_dir, "GNN_input_data", split)   

def get_representation_input(cid_list):
    XE = ();
    X = ();
    A = ();
    # Generate data
    for cid in cid_list:
        try:
            X = X + (np.load(join(input_data_folder, cid + '_X.npy')), );
            XE = XE + (np.load(join(input_data_folder, cid + '_XE.npy')), );
            A = A + (np.load(join(input_data_folder, cid + '_A.npy')), );
        except FileNotFoundError: #return zero arrays:
            X = X + (np.zeros((N,32)), );
            XE = XE + (np.zeros((N,N,F)), );
            A = A + (np.zeros((N,N,1)), );
    return(XE, X, A)

input_data_folder = join(datasets_dir, "GNN_input_data", split)   
def get_substrate_representations(df, training_set, testing_set, get_fingerprint_fct):
    df["GNN FP"] = ""
    i = 0
    n = len(df)
    
    cid_all = list(df.index)
    if training_set == True:
        prefix = "train_"
    elif testing_set == True:
        prefix = "test_"
    else:
        prefix = "val_"
    cid_all = [prefix + str(cid) for cid in cid_all]
    
    while i*32 <= n:
        if (i+1)*32  <= n:
            XE, X, A = get_representation_input(cid_all[i*32:(i+1)*32])
            representations = get_fingerprint_fct([np.array(XE), np.array(X),np.array(A)])[0]
            df["GNN FP"][i*32:(i+1)*32] = list(representations[:, :52])
        else:
            print(i)
            XE, X, A = get_representation_input(cid_all[-min(32,n):])
            representations = get_fingerprint_fct([np.array(XE), np.array(X),np.array(A)])[0]
            df["GNN FP"][-min(32,n):] = list(representations[:, :52])
        i += 1
        
    ### set all GNN FP-entries with no input matrices to np.nan:
    all_X_matrices = os.listdir(input_data_folder)
    for ind in df.index:
        if prefix +str(ind) +"_X.npy" not in all_X_matrices:
            df["GNN FP"][ind] = np.nan
    return(df)

In [59]:
#Calculating the GNN representations
train_with_rep = get_substrate_representations(df = train_df, training_set = True, testing_set = False,
                                                      get_fingerprint_fct = get_fingerprint_fct)
test_with_rep = get_substrate_representations(df = test_df, training_set = False, testing_set = True,
                                                     get_fingerprint_fct = get_fingerprint_fct)
val_with_rep = get_substrate_representations(df = df_validation, training_set = False, testing_set = False,
                                                     get_fingerprint_fct = get_fingerprint_fct)

#Saving the DataFrames:
train_with_rep.to_pickle(join(datasets_dir, "splits", split, "training_data.pkl"))
test_with_rep.to_pickle(join(datasets_dir, "splits", split, "test_data.pkl"))
val_with_rep.to_pickle(join(datasets_dir, "splits", split, "val_data.pkl"))

0
