In [117]:
import pandas as pd
import networkx as nx
import pickle

In [118]:
G = pickle.load(open('yeast_G.pickle', 'rb'))

print(f'# Nodes: {G.number_of_nodes()} \n# Edges: {G.number_of_edges()}')

# Nodes: 373 
# Edges: 6157


## Add node features

In [119]:
import cobra

cobra_model = cobra.io.load_json_model('redYeast_ST8943_fdp1.json')

In [120]:
# Add number of metabolites in every reaction as a node feature
for node, data in G.nodes(data=True):
    if "rev?" in node: rxn_name = node.split("?")[1]
    else: rxn_name = node
    
    num_metabolites = len(cobra_model.reactions.get_by_id(rxn_name).metabolites)
    data['x'] = num_metabolites

## Read ORACLE's data

In [121]:
sigma = pd.read_csv('saturations.csv', index_col=0)
gamma = pd.read_csv('gamma.csv', index_col=0)
vmax = pd.read_csv('Vmax_matrix.csv', index_col=0)

gamma = gamma.head(1)
sigma = sigma.head(1)
vmax = vmax.head(1)

In [122]:
# get the reactions that are the reversible 
rev_rxn = []
for node in list(G.nodes()):
    if node.split("?")[0] == 'rev': rev_rxn.append(node.split("?")[1])

# rename the reactions of gamma; if it's the reversible one add rev? to the column name
for col in gamma.columns:
    if col in rev_rxn: gamma.rename(columns={col:'rev?'+col}, inplace=True)

In [123]:
listA = list(G.nodes())
listB = gamma.columns.values

print('In Graph but not in gamma:', [item for item in listA if item not in listB])
print()
print('In gamma but not in Graph:', [item for item in listB if item not in listA])

In Graph but not in gamma: ['EX_lac__D_e', 'EX_mal__L_e', 'EX_akg_e', 'EX_2phetoh_e', 'EX_acald_e', 'EX_ac_e', 'EX_gam6p_e', 'EX_co2_e', 'EX_cit_e', 'EX_etoh_e', 'EX_fum_e', 'EX_gly_e', 'EX_gcald_e', 'EX_glx_e', 'EX_id3acald_e', 'EX_ala__L_e', 'EX_asn__L_e', 'EX_asp__L_e', 'EX_cys__L_e', 'EX_glu__L_e', 'EX_gln__L_e', 'EX_phe__L_e', 'EX_ser__L_e', 'EX_trp__L_e', 'EX_tyr__L_e', 'EX_oaa_e', 'EX_pacald_e', 'EX_pyr_e', 'EX_succ_e', 'EX_ind3eth_e', 'EX_h2o_e', 'EX_g6p_e', 'EX_g1p_e', 'EX_2pg_e', 'EX_pser__L_e', 'EX_ppi_e', 'EX_pep_e', 'EX_cbp_e', 'EX_6pgc_e', 'EX_3pg_e', 'EX_cmp_e', 'GROWTH', 'EX_ccm_e', 'EX_pca_e', 'rev?EX_nh4_e', 'rev?EX_glc__D_e', 'rev?EX_h_e', 'rev?EX_fe2_e', 'rev?EX_o2_e', 'rev?EX_pi_e', 'rev?EX_k_e', 'rev?EX_na1_e', 'rev?EX_so4_e', 'rev?EX_cl_e', 'rev?EX_cu2_e', 'rev?EX_mn2_e', 'rev?EX_zn2_e', 'rev?EX_mg2_e', 'rev?EX_ca2_e']

In gamma but not in Graph: []


In [124]:
# Drop nodes without gamma
G.remove_nodes_from([item for item in listA if item not in listB])
print(f'# Nodes: {G.number_of_nodes()} \n# Edges: {G.number_of_edges()}')

# Nodes: 314 
# Edges: 6090


In [125]:
# reactions with gamma > 1
rxn_bad_gamma = gamma.columns[(gamma > 1).any()].tolist()
gamma.drop(columns=rxn_bad_gamma, inplace=True)
print(gamma.shape)

# Drop nodes with gamma > 0
G.remove_nodes_from(rxn_bad_gamma)
print(f'# Nodes: {G.number_of_nodes()} \n# Edges: {G.number_of_edges()}')

(1, 243)
# Nodes: 243 
# Edges: 3793


#### Add `gamma` values as Graph node features

In [126]:
for node in gamma.columns:
    try:
        G.nodes[node]['y'] =  gamma[node].values[0]
    except KeyError:
        pass

no_gamma_nodes = [node for node, data in G.nodes(data=True) if not data]

for node in no_gamma_nodes: G.nodes[node]['y'] = np.nan

#### Maybe, the Graph is ready afterall...

## Networkx to Torch Geometric

In [156]:
import torch
import torch.nn as nn
from torch_geometric.utils.convert import from_networkx

device = 'cuda' if torch.cuda.is_available() else 'cpu'

data = from_networkx(G, group_edge_attrs=all)
data.x = data.x.view(-1,1).float()
data.y = data.y.float()

print(data)
print()
print(data.num_nodes ,data.num_edges)

Data(x=[243, 1], edge_index=[2, 3793], y=[243], edge_attr=[3793, 1])

243 3793


In [159]:
train_size = int(len(data.x) * 0.8)
val_size = int(len(data.x) * 0.1)
test_size = int(len(data.x) * 0.1)

# Create train, validation, and test masks
train_mask = torch.zeros(len(data.x), dtype=torch.bool)
train_mask[:train_size] = 1

val_mask = torch.zeros(len(data.x), dtype=torch.bool)
val_mask[train_size:train_size + val_size] = 1

test_mask = torch.zeros(len(data.x), dtype=torch.bool)
test_mask[train_size + val_size:] = 1

data.train_mask = train_mask
data.val_mask = val_mask
data.test_mask = test_mask

data

Data(x=[243, 1], edge_index=[2, 3793], y=[243], edge_attr=[3793, 1], train_mask=[243], val_mask=[243], test_mask=[243])

## Create a GNN

In [170]:
import torch.nn.functional as F
from torch.optim import Optimizer
from torch_geometric.nn import GCNConv
from torch import Tensor
import torch.nn as nn
import torch

class GCN(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        
        self.conv = GCNConv(input_dim, hidden_dim)
        self.relu = nn.ReLU(inplace=True)
        self.linear = nn.Linear(hidden_dim, 1)

    def forward(self, x: Tensor, edge_index: Tensor) -> torch.Tensor:

        x = self.conv(x, edge_index)
        x = self.relu(x)
        x = self.linear(x)
        
        return x

In [175]:
SEED = 42
MAX_EPOCHS = 1000
LEARNING_RATE = .01

INPUT_DIM = data.num_features
HIDDEN_DIM = 128

model = GCN(INPUT_DIM, HIDDEN_DIM)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = torch.nn.MSELoss()

" ************ TRAIN FUNCTION **************"
def train():

    model.train() # Tells the model that we are in training mode
    optimizer.zero_grad() # Resets the gradient
    
    y_pred = model(data.x, data.edge_index) # predicted y
    y_true = data.y # True labels
    
    loss = criterion(y_pred[data.train_mask], data.y[data.train_mask])
    loss.backward()
    
    optimizer.step()
    return loss

" ************ TEST FUNCTION **************"
def test(mask):

    model.eval() # Tells the model that we are in testing mode

    y_pred = model(data.x, data.edge_index) # Preds for all data
    y_true = data.y
    
    mse = criterion(y_pred[mask], y_true[mask])

    return mse        

In [176]:
VAL_ACCURACY = []
TEST_ACCURACY = []

for epoch in range(MAX_EPOCHS):

    loss = train()

    val_acc = test(data.val_mask)
    test_acc = test(data.test_mask)

    VAL_ACCURACY.append(val_acc)
    TEST_ACCURACY.append(test_acc)

    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_acc:.4f}, \
        Test: {test_acc:.4f}')


Epoch: 000, Loss: 0.3274, Val: 0.6975,         Test: 1.0574
Epoch: 001, Loss: 0.6142, Val: 0.4700,         Test: 0.6652
Epoch: 002, Loss: 0.3222, Val: 0.3577,         Test: 0.4285
Epoch: 003, Loss: 0.3066, Val: 0.3601,         Test: 0.4006
Epoch: 004, Loss: 0.4226, Val: 0.3426,         Test: 0.3902
Epoch: 005, Loss: 0.3680, Val: 0.3352,         Test: 0.4191
Epoch: 006, Loss: 0.2709, Val: 0.3732,         Test: 0.5198
Epoch: 007, Loss: 0.2623, Val: 0.4219,         Test: 0.6237
Epoch: 008, Loss: 0.3131, Val: 0.4264,         Test: 0.6403
Epoch: 009, Loss: 0.3264, Val: 0.3829,         Test: 0.5656
Epoch: 010, Loss: 0.2855, Val: 0.3265,         Test: 0.4597
Epoch: 011, Loss: 0.2435, Val: 0.2876,         Test: 0.3770
Epoch: 012, Loss: 0.2406, Val: 0.2704,         Test: 0.3336
Epoch: 013, Loss: 0.2645, Val: 0.2622,         Test: 0.3169
Epoch: 014, Loss: 0.2752, Val: 0.2556,         Test: 0.3166
Epoch: 015, Loss: 0.2576, Val: 0.2552,         Test: 0.3358
Epoch: 016, Loss: 0.2317, Val: 0.2656,  

In [177]:
y_pred = model(data.x, data.edge_index) # Preds for all data
y_pred = y_pred[data.test_mask]
y_pred

tensor([[0.5263],
        [0.5264],
        [0.5263],
        [0.5255],
        [0.5263],
        [0.5261],
        [0.5263],
        [0.5261],
        [0.5263],
        [0.5264],
        [0.5264],
        [0.5264],
        [0.5264],
        [0.5264],
        [0.5262],
        [0.5264],
        [0.5264],
        [0.5264],
        [0.5264],
        [0.5261],
        [0.5264],
        [0.5264],
        [0.5264],
        [0.5264],
        [0.5264]], grad_fn=<IndexBackward0>)

In [178]:
y_true = data.y[data.test_mask]
y_true

tensor([9.9900e-01, 9.9877e-01, 9.9900e-01, 9.9900e-01, 9.9900e-01, 9.9900e-01,
        9.9900e-01, 7.4300e-08, 6.9395e-02, 9.9900e-01, 9.9900e-01, 7.6738e-01,
        9.9900e-01, 9.9900e-01, 5.1849e-01, 9.9900e-01, 9.9900e-01, 9.9900e-01,
        9.9900e-01, 9.9900e-01, 9.9900e-01, 9.9900e-01, 9.9900e-01, 9.9900e-01,
        9.9900e-01])