In [1]:
from autograd import grad

import autograd.numpy as np
import pickle as pkl
import json
import pandas as pd
import matplotlib.pyplot as plt

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
# Open the data file that contains the 
df = pd.read_csv('../data/hiv_data/hiv-protease-data-expanded.csv', index_col=0)
df = df.dropna(subset=['FPV'])
df

Unnamed: 0,ATV,DRV,FPV,IDV,LPV,NFV,SQV,SeqID,TPV,seqid,sequence,sequence_object,weight
0,,,2.5,16.3,,38.6,16.1,2996,,2996-0,PQITLWQRPIVTIKIGGQLKEALLDTGADDTVLEDVNLPGRWKPKM...,ID: 2996-0\nName: <unknown name>\nDescription:...,0.50000
1,,,2.5,16.3,,38.6,16.1,2996,,2996-1,PQITLWQRPIVTIKIGGQLKEALLDTGADDTVLEDVNLPGRWKPKM...,ID: 2996-1\nName: <unknown name>\nDescription:...,0.50000
2,,,0.7,0.8,,0.8,1.1,4387,,4387-0,PQITLWQRPLVTIKVGGQLKEALLDTGADDTVLEDMELPGRWKPKM...,ID: 4387-0\nName: <unknown name>\nDescription:...,0.25000
3,,,0.7,0.8,,0.8,1.1,4387,,4387-1,PQITLWQRPLVTIKVGGQLKEALLDTGADDTVLEDMELPGRWKPKM...,ID: 4387-1\nName: <unknown name>\nDescription:...,0.25000
4,,,0.7,0.8,,0.8,1.1,4387,,4387-2,PQITLWQRPLVTIKVGGQLKEALLDTGADDTVLEDMELPGRWKPKM...,ID: 4387-2\nName: <unknown name>\nDescription:...,0.25000
5,,,0.7,0.8,,0.8,1.1,4387,,4387-3,PQITLWQRPLVTIKVGGQLKEALLDTGADDTVLEDMELPGRWKPKM...,ID: 4387-3\nName: <unknown name>\nDescription:...,0.25000
6,32.0,,3.0,35.0,32.0,29.0,164.0,4426,,4426-0,PQITLWQRPIVTIKIGGQLKEALLDTGADDTVLEEMNLPGKWKPKM...,ID: 4426-0\nName: <unknown name>\nDescription:...,1.00000
7,,,1.5,1.0,,2.2,1.1,4432,,4432-0,PQITLWQRPLVTVKIGGQLKEALLDTGADDTVLEEMNLPGRWKPKM...,ID: 4432-0\nName: <unknown name>\nDescription:...,1.00000
8,,,3.9,20.2,,21.6,9.2,4482,,4482-0,PQITLWQRPVVTIKIGGQLKEALLDTGADDTVLEEINLPGRWKPKL...,ID: 4482-0\nName: <unknown name>\nDescription:...,0.50000
9,,,3.9,20.2,,21.6,9.2,4482,,4482-1,PQITLWQRPVVTIKIGGQLKEALLDTGADDTVLEDINLPGRWKPKL...,ID: 4482-1\nName: <unknown name>\nDescription:...,0.50000


In [75]:
# Open the numpy array of all graphs' data.
graph_arr = np.load('../data/feat_array.npy')

In [76]:
# Open the pickles that contain the graph information and node-nbr information.
def unpickle_data(path):
    with open(path, 'rb') as f:
        data = pkl.load(f)
    return data

graph_idxs = unpickle_data('../data/graph_idxs.pkl')
graph_nodes = unpickle_data('../data/graph_nodes.pkl')
nodes_nbrs = unpickle_data('../data/nodes_nbrs.pkl')

In [77]:
list(graph_idxs.keys())[0:5]
# len(graph_idxs.keys())

['81010-0', '122867-0', '112451-3', '235691-7', '81905-3']

In [78]:
list(graph_nodes.items())[0]

('81010-0',
 {370637: 'A70LYS',
  370638: 'A19LEU',
  370639: 'B73GLY',
  370640: 'B49GLY',
  370641: 'A12THR',
  370642: 'B13VAL',
  370643: 'A24LEU',
  370644: 'B68GLY',
  370645: 'B4THR',
  370646: 'A88ASN',
  370647: 'A37ASN',
  370648: 'A63LEU',
  370649: 'B2GLN',
  370650: 'A29ASP',
  370651: 'B90LEU',
  370652: 'A6TRP',
  370653: 'A23LEU',
  370654: 'A3ILE',
  370655: 'B16GLY',
  370656: 'A36ILE',
  370657: 'B75VAL',
  370658: 'A42TRP',
  370659: 'A34GLU',
  370660: 'A65GLU',
  370661: 'A53PHE',
  370662: 'A85ILE',
  370663: 'B6TRP',
  370664: 'A94GLY',
  370665: 'A93ILE',
  370666: 'A52GLY',
  370667: 'B66ILE',
  370668: 'B63LEU',
  370669: 'A7GLN',
  370670: 'B52GLY',
  370671: 'B32VAL',
  370672: 'A2GLN',
  370673: 'B79PRO',
  370674: 'A22ALA',
  370675: 'B28ALA',
  370676: 'B94GLY',
  370677: 'B22ALA',
  370678: 'B20ILE',
  370679: 'B40GLY',
  370680: 'B76LEU',
  370681: 'B15ILE',
  370682: 'A28ALA',
  370683: 'B57LYS',
  370684: 'B67GLU',
  370685: 'A89MET',
  370686: 'A59T

In [79]:
list(nodes_nbrs.items())[0]

(0, [0, 23, 46])

In [209]:
# Keep track of only those that are in both the graph_idxs and in the df['seqid']
intersect = set(df['seqid'].values).intersection(graph_idxs.keys())
len(intersect)

3200

In [210]:
# Get a reduced list of graph_idxs.
graph_idxs_red = dict()
graph_nodes_red = dict()
for g in intersect:
    graph_idxs_red[g] = graph_idxs[g]
    graph_nodes_red[g] = graph_nodes[g]

In [211]:
graph_idxs_red['46213-0']

[535066,
 535067,
 535068,
 535069,
 535070,
 535071,
 535072,
 535073,
 535074,
 535075,
 535076,
 535077,
 535078,
 535079,
 535080,
 535081,
 535082,
 535083,
 535084,
 535085,
 535086,
 535087,
 535088,
 535089,
 535090,
 535091,
 535092,
 535093,
 535094,
 535095,
 535096,
 535097,
 535098,
 535099,
 535100,
 535101,
 535102,
 535103,
 535104,
 535105,
 535106,
 535107,
 535108,
 535109,
 535110,
 535111,
 535112,
 535113,
 535114,
 535115,
 535116,
 535117,
 535118,
 535119,
 535120,
 535121,
 535122,
 535123,
 535124,
 535125,
 535126,
 535127,
 535128,
 535129,
 535130,
 535131,
 535132,
 535133,
 535134,
 535135,
 535136,
 535137,
 535138,
 535139,
 535140,
 535141,
 535142,
 535143,
 535144,
 535145,
 535146,
 535147,
 535148,
 535149,
 535150,
 535151,
 535152,
 535153,
 535154,
 535155,
 535156,
 535157,
 535158,
 535159,
 535160,
 535161,
 535162,
 535163,
 535164,
 535165,
 535166,
 535167,
 535168,
 535169,
 535170,
 535171,
 535172,
 535173,
 535174,
 535175,
 535176,
 

In [212]:
idxs = np.concatenate([v for k, v in graph_idxs_red.items()])
idxs

array([370637, 370638, 370639, ..., 576544, 576545, 576546])

In [213]:
graph_arr.shape

(659895, 36)

In [214]:
# Create a reduced graph_array that is ordered correctly.
graph_arr_fin = np.zeros(shape=graph_arr[idxs].shape)
graph_arr_fin.shape

(622671, 36)

In [215]:
# Make one pass over the data to get the old/new index mapping, and
# make the final graph_array that gets passed in as an input.

# Here I try the Python version. Cython version a few cells below.
def reindex_data(graph_idxs_red, graph_arr_fin, graph_arr):

    # Create maps of graph indices from the old to the new.
    nodes_oldnew = dict()  # {old_idx: new_idx}.
    nodes_newold = dict()  # {new_idx: old_idx}

    curr_idx = 0
    for seqid, idxs in graph_idxs_red.items():
        for idx in idxs:
            nodes_oldnew[idx] = curr_idx
            nodes_newold[curr_idx] = idx
            graph_arr_fin[curr_idx] = graph_arr[idx]
            curr_idx += 1
    return graph_arr_fin, nodes_oldnew, nodes_newold

# %timeit reindex_data(graph_idxs_red, graph_arr_fin, graph_arr)
graph_arr_fin, nodes_oldnew, nodes_newold = reindex_data(graph_idxs_red, graph_arr_fin, graph_arr)

In [216]:
%timeit reindex_data(graph_idxs_red, graph_arr_fin, graph_arr)

1 loop, best of 3: 709 ms per loop


In [217]:
import cython

%load_ext cython

The cython extension is already loaded. To reload it, use:
  %reload_ext cython


In [218]:
%%cython

cimport numpy as np

def cython_reindex_data(dict graph_idxs_red, np.ndarray graph_arr_fin, np.ndarray graph_arr):
    return reindex_data_c(graph_idxs_red, graph_arr_fin, graph_arr)

cdef reindex_data_c(dict graph_idxs_red, np.ndarray graph_arr_fin, np.ndarray graph_arr):
    cdef nodes_newold = dict()
    cdef nodes_oldnew = dict()

    cdef int curr_idx = 0
    cdef str seqid
    cdef list idxs
    cdef int idx
    for seqid, idxs in graph_idxs_red.items():
        for idx in idxs:
            nodes_oldnew[idx] = curr_idx
            nodes_newold[curr_idx] = idx
            graph_arr_fin[curr_idx] = graph_arr[idx]
            curr_idx += 1
    return graph_arr_fin, nodes_oldnew, nodes_newold

In [219]:
%timeit cython_reindex_data(graph_idxs_red, graph_arr_fin, graph_arr)

1 loop, best of 3: 606 ms per loop


In [220]:
# Check a random sample of the indices to make sure that they are sampled correctly.
from random import sample

n_samples = 10000
rnd_idxs = sample([i for i in range(graph_arr_fin.shape[0])], n_samples)
for new_idx in tqdm(rnd_idxs):
    assert np.all(np.equal(graph_arr_fin[new_idx], graph_arr[nodes_newold[new_idx]]))

100%|██████████| 10000/10000 [00:00<00:00, 48780.51it/s]


In [221]:
graph_arr_fin.shape

(622671, 36)

In [222]:
# Finally, rework the nodes_nbrs, graph_idxs, and graph_nodes dictionaries with the corrected idxs.
# THIS IS THE KEY STEP! MUST ENCAPSULATE IN A FUNCTION!
from collections import defaultdict

def reindex_nodes_and_neighbors(nodes_nbrs, nodes_oldnew):
    """
    - nodes_nbrs: a dictionary of nodes and their neighbors.
    - nodes_oldnew: a dictionary mapping old node indices to their new node indices.
    """
    nodes_nbrs_fin = defaultdict(list)
    for node, nbrs in tqdm(nodes_nbrs.items()):
        if node in nodes_oldnew.keys():
            for nbr in nbrs:
                nodes_nbrs_fin[nodes_oldnew[node]].append(nodes_oldnew[nbr])
    return nodes_nbrs_fin

nodes_nbrs_fin = reindex_nodes_and_neighbors(nodes_nbrs, nodes_oldnew)

def reindex_graph_idxs(graph_idxs, nodes_oldnew):
    """
    - graph_idxs: a dictionary of graphs and their original indices.
    - nodes_oldnew: a dictionary mapping old node indices to their new node indices.
    """
    graph_idxs_fin = defaultdict(list)
    for seqid, nodes in tqdm(graph_idxs.items()):
        for node in nodes:
            if node in nodes_oldnew.keys():
                graph_idxs_fin[seqid].append(nodes_oldnew[node])
    return graph_idxs_fin

graph_idxs_fin = reindex_graph_idxs(graph_idxs, nodes_oldnew)


def reindex_graph_nodes(graph_nodes, nodes_oldnew):
    """
    - graph_nodes: a dictionary mapping graphs to their dictionary mapping indices to node names.
    - nodes_oldnew: a dictionary mapping old node indices to their new node indices.
    """    
    graph_nodes_fin = defaultdict(dict)
    for seqid, idx_node in tqdm(graph_nodes.items()):
        for old_idx, node_name in idx_node.items():
            if old_idx in nodes_oldnew.keys():
                graph_nodes_fin[seqid][nodes_oldnew[old_idx]] = node_name
    return graph_nodes_fin

graph_nodes_fin = reindex_graph_nodes(graph_nodes, nodes_oldnew)

100%|██████████| 659895/659895 [00:02<00:00, 248879.11it/s]
100%|██████████| 3392/3392 [00:00<00:00, 12648.98it/s]
100%|██████████| 3392/3392 [00:00<00:00, 11779.13it/s]


In [223]:
from graphfp.layers import FingerprintLayer, LinearRegressionLayer, GraphConvLayer
from graphfp.utils import initialize_network

layers = [# GraphConvLayer((36, 36)),
          GraphConvLayer((36, 36)),
          FingerprintLayer(shape=36),
          LinearRegressionLayer(shape=(36,1))]

wb = initialize_network(input_shape=graph_arr_fin.shape, layers_spec=layers)
# print(wb)


def predict(wb_struct, inputs, nodes_nbrs, graph_idxs, layers):
    curr_inputs = inputs
    for i, layer in enumerate(layers):
        wb = wb_struct['layer{0}_{1}'.format(i, layer)]
        curr_inputs = layer.forward_pass(wb, inputs, nodes_nbrs, graph_idxs)
    return curr_inputs

In [164]:
from time import time
start = time()
predict(wb, graph_arr_fin, nodes_nbrs_fin, graph_idxs_fin, layers)
end = time()
print(end - start)

14.908802032470703


In [None]:
def train_loss(wb_vect, unflattener):