In [19]:
import networkx as nx
from gym_kidney import _solver
from gym.utils import seeding
from gym_kidney import models
import numpy as np

In [20]:
%matplotlib inline
import matplotlib.pyplot as plt

First, generate random graphs according to the arrival models in `gym_kidney.models`.

One simple idea: simulate some graphs for many ticks and others for only a few ticks.

In [21]:
def _default_model(m=128, k=580, p=0.05, p_a=0.05, len_factor=3):
    """Build the homogeneous arrival model used to generate training graphs.

    The arguments are forwarded positionally to ``models.HomogeneousModel`` as
    ``(M, K, P, P_A, LEN)`` with ``LEN = len_factor * k``. The defaults
    reproduce the notebook's original configuration (M=128, K=580,
    P=P_A=0.05, LEN=3*K). What each argument means is defined by
    ``gym_kidney.models`` -- TODO confirm there before tuning.

    Returns:
        A ``models.HomogeneousModel`` instance.
    """
    # The previous version assigned K = 1024 and immediately overwrote it with
    # 580; only 580 was ever used, so that dead assignment is gone.
    return models.HomogeneousModel(m, k, p, p_a, len_factor * k)

# Arrival model shared by the graph-generation helpers, plus a reproducibly
# seeded RNG (seed 1) from gym's seeding utility.
DEFAULT_MODEL = _default_model()

rng, seed = seeding.np_random(seed=1)


In [36]:
def gen_random_graph(model, rng, n_steps=300):
    """Simulate ``n_steps`` arrival ticks of ``model`` from an empty digraph.

    On each tick the current graph and ``rng`` are handed to ``model.arrive``,
    and whatever graph it returns becomes the current graph.
    """
    graph = nx.DiGraph()
    for _ in range(n_steps):
        graph = model.arrive(graph, rng)
    return graph

In [23]:
def relabel(G):
    """Assign dense 0-based indices to non-NDD and NDD nodes, separately.

    Nodes are visited in ``G.nodes`` iteration order. Each node's ``"ndd"``
    attribute decides which group (and therefore which index sequence) it
    joins.

    Args:
        G: graph whose every node carries an ``"ndd"`` attribute; a node
            without one raises ``KeyError``.

    Returns:
        ``(n_dd, n_ndd, d_dd, d_ndd)``: the number of non-NDD and NDD nodes,
        and dicts mapping each original node id to its index within its group.
    """
    d_dd, d_ndd = {}, {}
    # ``G.nodes(data=True)`` replaces ``G.node[u]``, which networkx 2.4 removed.
    for u, attrs in G.nodes(data=True):
        group = d_ndd if attrs["ndd"] else d_dd
        # Node ids are unique, so the current group size is the next index.
        group[u] = len(group)
    return len(d_dd), len(d_ndd), d_dd, d_ndd

def nx_to_ks(G):
    """Convert a networkx kidney-exchange graph into kidney_solver structures.

    Edges leaving a non-NDD node become arcs of a ``_solver.Digraph`` over the
    non-NDD nodes. Edges leaving an NDD node become
    ``_solver.kidney_ndds.NddEdge`` objects attached to that NDD's ``Ndd``.
    An edge without a ``"weight"`` attribute gets weight 1.0.

    Args:
        G: graph whose nodes carry an ``"ndd"`` attribute (read via
            ``relabel``). Every edge target is looked up in the non-NDD
            index, so an edge pointing *into* an NDD raises ``KeyError``.

    Returns:
        ``(dd, ndds)``: the non-NDD digraph and the list of ``Ndd`` objects,
        indexed as assigned by ``relabel``.
    """
    n_dd, n_ndd, d_dd, d_ndd = relabel(G)

    dd = _solver.Digraph(n_dd)
    ndds = [_solver.kidney_ndds.Ndd() for _ in range(n_ndd)]

    # One pass over the edges. Membership in relabel's NDD index replaces the
    # ``G.node[u]`` lookup (that attribute was removed in networkx 2.4). Arcs and
    # NDD edges go to separate structures, so each keeps the edge order of the
    # previous two-pass version.
    for u, v, d in G.edges(data=True):
        weight = d.get("weight", 1.0)
        target = dd.vs[d_dd[v]]
        if u in d_ndd:
            ndds[d_ndd[u]].add_edge(_solver.kidney_ndds.NddEdge(target, weight))
        else:
            dd.add_edge(weight, dd.vs[d_dd[u]], target)

    return dd, ndds


In [24]:
def solve_graph(G, cycle_cap=3, chain_cap=3, formulation="picef"):
    """Solve the kidney-exchange IP on ``G`` and return the solution's size.

    Args:
        G: graph in the format accepted by ``nx_to_ks``.
        cycle_cap: cycle-length cap forwarded to ``_solver.kidney_ip.OptConfig``.
        chain_cap: chain-length cap forwarded to ``_solver.kidney_ip.OptConfig``.
        formulation: formulation name forwarded to ``_solver.solve_kep``.
            The default ``"picef"`` is the only value this notebook uses.

    Returns:
        The sum of ``len(cycle)`` over the selected cycles plus the sum of
        ``len(chain.vtx_indices)`` over the selected chains.
    """
    dd, ndds = nx_to_ks(G)
    cfg = _solver.kidney_ip.OptConfig(dd, ndds, cycle_cap, chain_cap)
    soln = _solver.solve_kep(cfg, formulation)
    rew_cycles = sum(len(cycle) for cycle in soln.cycles)
    rew_chains = sum(len(chain.vtx_indices) for chain in soln.chains)
    return rew_cycles + rew_chains

In [25]:
def make_graph_score_pair(rng, model=None, n_steps=300):
    """Generate one random graph together with its solver score.

    Args:
        rng: random generator forwarded to ``gen_random_graph``.
        model: arrival model. ``None`` means the module-level
            ``DEFAULT_MODEL``, looked up at call time so that reassigning the
            global is honoured.
        n_steps: arrival ticks to simulate. 300 matches
            ``gen_random_graph``'s default; vary it to get graphs of
            different sizes.

    Returns:
        ``(graph, score)`` where ``score = solve_graph(graph)``.
    """
    if model is None:
        model = DEFAULT_MODEL
    graph = gen_random_graph(model, rng, n_steps=n_steps)
    return graph, solve_graph(graph)

In [66]:
from tqdm import tqdm

In [71]:
# Dataset sizes. In the recorded run each (graph, score) pair took ~10 ms
# (one simulation plus one IP solve), so 10000 + 1000 pairs take ~2 minutes.
N_TRAIN = 10000
N_VAL = 1000

dataset = [make_graph_score_pair(rng) for _ in tqdm(range(N_TRAIN), desc="train")]
validation_set = [make_graph_score_pair(rng) for _ in tqdm(range(N_VAL), desc="validation")]

100%|██████████| 10000/10000 [01:43<00:00, 96.31it/s]
100%|██████████| 1000/1000 [00:10<00:00, 92.01it/s]


In [27]:
def adjmat(gr):
    """Return ``gr``'s adjacency matrix as a dense float32 numpy array.

    Row/column order follows networkx's default node list (``gr.nodes()``).
    Entries come from ``nx.adjacency_matrix``'s default edge-weight handling.
    """
    sparse_adj = nx.adjacency_matrix(gr)
    dense_adj = sparse_adj.toarray()
    return dense_adj.astype('float32')

In [44]:
def zero_padded_adjmat(graph, size):
    """Place ``graph``'s adjacency matrix in the top-left of a zero ``size x size`` canvas.

    A trailing channel axis is added, giving shape ``(size, size, 1)``, which
    matches a Keras ``Input(shape=(size, size, 1))``.

    Args:
        graph: graph accepted by ``adjmat``.
        size: side length of the padded matrix.

    Returns:
        ``np.ndarray`` of shape ``(size, size, 1)`` with ``adjmat``'s dtype
        (float32). The canvas previously defaulted to float64, which undid
        that cast and doubled the memory of the stacked training tensor.

    Raises:
        ValueError: if ``graph`` has more than ``size`` nodes. Previously this
            surfaced as an opaque numpy broadcast error.
    """
    unpadded = adjmat(graph)
    n_nodes = unpadded.shape[0]  # adjacency matrices are square
    if n_nodes > size:
        raise ValueError(
            "graph has {} nodes, which exceeds the padding size {}".format(n_nodes, size))
    padded = np.zeros((size, size, 1), dtype=unpadded.dtype)
    padded[:n_nodes, :n_nodes, 0] = unpadded
    return padded
    

Next step: a very simple MLP model in TensorFlow (Keras).

In [29]:
import tensorflow as tf
import tensorflow.keras
from tensorflow.keras.layers import Input, Dense, Dropout, Flatten, MaxPooling2D, Conv2D
from tensorflow.keras.models import Model

In [46]:
def mlp_model(input_size=100, hidden_units=(100, 20)):
    """Build a fully-connected regressor over a padded adjacency matrix.

    The input keeps a trailing channel axis, ``(input_size, input_size, 1)``,
    so the same arrays can feed this model and the CNN.

    Args:
        input_size: side length of the padded adjacency matrix.
        hidden_units: widths of the ReLU hidden layers, in order. The default
            reproduces the original 100 -> 20 architecture.

    Returns:
        An uncompiled ``tf.keras`` ``Model`` with a single ReLU output unit.
    """
    input_im = Input(shape=(input_size, input_size, 1))
    layer = Flatten()(input_im)
    for units in hidden_units:
        layer = Dense(units, activation='relu')(layer)
    # NOTE(review): a ReLU output can get stuck at 0 if its pre-activation goes
    # negative for every input; a linear output is the usual regression choice.
    output = Dense(1, activation='relu')(layer)
    # Distinct local name: the original bound ``mlp_model`` here, shadowing the function.
    model = Model(input_im, output)
    return model

In [63]:
def cnn_model(input_size=100):
    """Build a small convolutional regressor over a padded adjacency matrix.

    The model has three stages of Conv2D (3x3, ReLU, 'same') followed by
    MaxPooling2D (2x2, 'same'), with 32, 16 and 16 filters. These are followed
    by a 32-unit ReLU dense layer and a single ReLU output. Returns an
    uncompiled ``tf.keras`` ``Model``.
    """
    inputs = Input(shape=(input_size, input_size, 1))
    x = inputs
    for n_filters in (32, 16, 16):
        x = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(x)
        x = MaxPooling2D((2, 2), padding='same')(x)
    x = Flatten()(x)
    x = Dense(32, activation='relu')(x)
    predictions = Dense(1, activation='relu')(x)
    return Model(inputs, predictions)

In [64]:
# Show the layer-by-layer architecture of a freshly built CNN (default 100x100 input).
cnn_model(input_size=100).summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_9 (InputLayer)         (None, 100, 100, 1)       0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 100, 100, 32)      320       
_________________________________________________________________
max_pooling2d_9 (MaxPooling2 (None, 50, 50, 32)        0         
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 50, 50, 16)        4624      
_________________________________________________________________
max_pooling2d_10 (MaxPooling (None, 25, 25, 16)        0         
_________________________________________________________________
conv2d_11 (Conv2D)           (None, 25, 25, 16)        2320      
_________________________________________________________________
max_pooling2d_11 (MaxPooling (None, 13, 13, 16)        0         
__________

In [47]:
# Show the layer-by-layer architecture of a freshly built MLP (default 100x100 input).
mlp_model(input_size=100).summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_6 (InputLayer)         (None, 100, 100, 1)       0         
_________________________________________________________________
flatten_5 (Flatten)          (None, 10000)             0         
_________________________________________________________________
dense_13 (Dense)             (None, 100)               1000100   
_________________________________________________________________
dense_14 (Dense)             (None, 20)                2020      
_________________________________________________________________
dense_15 (Dense)             (None, 1)                 21        
Total params: 1,002,141
Trainable params: 1,002,141
Non-trainable params: 0
_________________________________________________________________


In [72]:
PAD_SIZE = 100  # side every adjacency matrix is padded to; must equal the models' input_size (default 100)


def to_training_arrays(pairs, size=PAD_SIZE):
    """Turn ``(graph, score)`` pairs into Keras-ready ``(X, y)`` arrays.

    Returns:
        ``X``: stacked ``zero_padded_adjmat`` outputs, shape ``(len(pairs), size, size, 1)``.
        ``y``: float32 scores, shape ``(len(pairs), 1)``.
    """
    mats = np.stack([zero_padded_adjmat(g, size) for g, _ in pairs])
    scores = np.asarray([s for _, s in pairs], dtype='float32').reshape(-1, 1)
    return mats, scores


graph_mats, graph_scores = to_training_arrays(dataset)
val_mats, val_scores = to_training_arrays(validation_set)

In [74]:
# Train the MLP baseline. The History is kept so the loss curves can be
# inspected later, instead of relying on this cell's Out[] value.
# NOTE(review): the default learning rate behind the 'adadelta' string differs
# across Keras versions (1.0 in legacy Keras, 0.001 in TF2 tf.keras); pin it
# explicitly if results must be reproducible.
mlp = mlp_model()
mlp.compile(optimizer='adadelta', loss='mse')
mlp_history = mlp.fit(
    graph_mats, graph_scores,
    epochs=50,
    batch_size=100,
    shuffle=True,
    validation_data=(val_mats, val_scores),
)

Train on 10000 samples, validate on 1000 samples
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50


<tensorflow.python.keras.callbacks.History at 0x7f9b7b08f940>

In [76]:
# Train the CNN with the same optimizer, loss and schedule as the MLP. The
# History is kept so the loss curves can be compared later, instead of relying
# on this cell's Out[] value.
# NOTE(review): the default learning rate behind the 'adadelta' string differs
# across Keras versions (1.0 in legacy Keras, 0.001 in TF2 tf.keras); pin it
# explicitly if results must be reproducible.
cnn = cnn_model()
cnn.compile(optimizer='adadelta', loss='mse')
cnn_history = cnn.fit(
    graph_mats, graph_scores,
    epochs=50,
    batch_size=100,
    shuffle=True,
    validation_data=(val_mats, val_scores),
)

Train on 10000 samples, validate on 1000 samples
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50


<tensorflow.python.keras.callbacks.History at 0x7f9b7b200f28>

In [62]:
# Inspect the trained CNN's architecture.
# NOTE(review): the saved output below came from an earlier execution (In [62],
# before `cnn` was rebuilt in In [76]), so its auto-generated layer names are
# stale; re-run this cell to refresh it.
cnn.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_8 (InputLayer)         (None, 100, 100, 1)       0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 100, 100, 32)      320       
_________________________________________________________________
max_pooling2d_6 (MaxPooling2 (None, 50, 50, 32)        0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 50, 50, 16)        4624      
_________________________________________________________________
max_pooling2d_7 (MaxPooling2 (None, 25, 25, 16)        0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 25, 25, 16)        2320      
_________________________________________________________________
max_pooling2d_8 (MaxPooling2 (None, 13, 13, 16)        0         
__________