In [617]:
! which python3
! python3 -V

/Users/nick/Dev/magisterka/venv/bin/python3
Python 3.9.6


In [618]:
import typing as t
import jraph
import jax.numpy as jnp
import networkx as nx
import metadata
import haiku as hk
import utils
import jax
import optax

In [619]:
WINDOW = int(31*24*60/5)
VIS = False

In [620]:
def create_graph_tuple(senders, receivers, edges) -> jraph.GraphsTuple:
    return jraph.GraphsTuple(
        nodes=None,
        edges=edges,
        senders=senders,
        receivers=receivers,
        n_node=jnp.asarray([metadata.NUM_NODES]),
        n_edge=jnp.asarray([metadata.NUM_EDGES]),
        globals=None)

In [621]:
_senders = []
_receivers = []

for x, row in enumerate(metadata.ADJACENCY_MATRIX):
    for y, _ in enumerate(row):
        if metadata.ADJACENCY_MATRIX[x][y]:
            _senders.append(x)
            _receivers.append(y)

senders = jnp.array(_senders)
receivers = jnp.array(_receivers)
nodes = jnp.array([1] * metadata.NUM_NODES)
edges = jnp.array([1] * metadata.NUM_EDGES)

graph_features = jnp.array([[1]])
graph = create_graph_tuple(senders, receivers, edges)

if VIS:
    utils.draw_jraph_graph_structure(graph)

### Load data

In [622]:
import pandas as pd

In [623]:
df = pd.read_pickle("../data/samples_5m_subset_v1.pkl")

In [624]:
def get_data_for_link(src: str, dst: str, n: int) -> jnp.array:
    ss: pd.Series = df[(df["src_host"] == src) & (df["dst_host"] == dst)]
    return ss.incoming_rate_avg.to_numpy()[:n]

In [625]:
_map = metadata.NODE_IDS_TO_LABELS_MAPPING

_senders = []
_receivers = []
_edges = jnp.array([[]])

for x, row in enumerate(metadata.ADJACENCY_MATRIX):
    for y, _ in enumerate(row):
        if metadata.ADJACENCY_MATRIX[x][y]:
            _edges = jnp.append(_edges, 
                get_data_for_link(_map[x], _map[y], WINDOW))
                
            _senders.append(x)
            _receivers.append(y)

senders = jnp.array(_senders)
receivers = jnp.array(_receivers)
edges = jnp.array(_edges.reshape((metadata.NUM_EDGES, WINDOW)))

In [626]:
graph = create_graph_tuple(senders, receivers, edges)

In [627]:
if VIS:
    utils.draw_jraph_graph_structure(graph)

### Model

In [628]:
HIDDEN_SIZE = 64
LEARNING_RATE = 1e-4
SPLIT = 0.9
BLOCK_SIZE = int(24*60/5)  # 1 day
BATCH_SIZE = 32
SEED = 1237
NUM_TRAINING_STEPS = 500

In [629]:
n = int(SPLIT * edges.shape[1])
d_train = edges[0][:n]
d_val = edges[0][n:]

In [630]:
d_train = jnp.log(d_train)
d_val = jnp.log(d_val)

In [631]:
d_val.shape, d_train.shape, edges.shape

((893,), (8035,), (15, 8928))

In [632]:
class TrainingState(t.NamedTuple):
    params: hk.Params
    opt_state: optax.OptState

Batch = t.Mapping[str, jnp.ndarray]

In [633]:
the_seed = hk.PRNGSequence(SEED)

def get_batch(split: str) -> Batch:
    data = d_train if split == "train" else d_val
    ixs = jax.random.randint(next(the_seed), (BATCH_SIZE, ), 0, len(data) - BLOCK_SIZE)
    x = jnp.stack([data[i:i+BLOCK_SIZE] for i in ixs])
    y = jnp.stack([data[i+1:i+BLOCK_SIZE+1] for i in ixs])
    return {'input': x, 'target': y}

In [634]:
expl_batch = get_batch("train")

In [635]:
def make_network() -> hk.RNNCore:
    """Defines the network architecture."""
    model = hk.DeepRNN([
        hk.Linear(HIDDEN_SIZE, name="linear"),
        jax.nn.relu,
        hk.LSTM(HIDDEN_SIZE, name="lstm1"),
        jax.nn.relu,
        hk.LSTM(HIDDEN_SIZE, name="lstm2"),
        hk.Linear(1),
    ])
    return model

def make_optimizer() -> optax.GradientTransformation:
    """Defines the optimizer."""
    return optax.adam(LEARNING_RATE)

In [636]:
def sequence_loss(batch: Batch) -> jnp.ndarray:
    """Unrolls the network over a sequence of inputs & targets, gets loss."""
    core = make_network()
    batch_size, sequence_length = batch['input'].shape  # (B, T) 

    initial_state = core.initial_state(batch_size)
    _input = jnp.expand_dims(batch['input'], -1) # (B, T, 1) 
    
    logits, _ = hk.dynamic_unroll(core, _input, initial_state, time_major=False)
    # logits -> (B, T, 1)
    loss = jnp.mean(jnp.square(logits[:, :, -1] - batch['target']))
    return loss

In [637]:
@jax.jit
def update(state: TrainingState, batch: Batch) -> TrainingState:
    """Does a step of SGD given inputs & targets."""
    _, optimizer = make_optimizer()
    _, loss_fn = hk.without_apply_rng(hk.transform(sequence_loss))
    gradients = jax.grad(loss_fn)(state.params, batch)
    updates, new_opt_state = optimizer(gradients, state.opt_state)
    new_params = optax.apply_updates(state.params, updates)
    return TrainingState(params=new_params, opt_state=new_opt_state)

In [638]:
rng = hk.PRNGSequence(SEED)

init_params_fn, loss_fn = hk.without_apply_rng(hk.transform(sequence_loss))
initial_params = init_params_fn(next(rng), expl_batch)
opt_init, _ = make_optimizer()
initial_opt_state = opt_init(initial_params)

# de facto initial state
state = TrainingState(params=initial_params, opt_state=initial_opt_state)

In [639]:
loss_fn = jax.jit(loss_fn)

In [640]:
state = update(state, expl_batch)

for step in range(NUM_TRAINING_STEPS):
    train_batch = get_batch("train")
    state = update(state, train_batch)

    if step % 50 == 0:
        eval_batch = get_batch("eval")
        loss = loss_fn(state.params, eval_batch)
        print({
            'step': step,
            'loss': float(loss),
        })

{'step': 0, 'loss': 299.6245422363281}
{'step': 50, 'loss': 256.5135192871094}
{'step': 100, 'loss': 204.96331787109375}
{'step': 150, 'loss': 169.24688720703125}
{'step': 200, 'loss': 141.93331909179688}
{'step': 250, 'loss': 122.60469818115234}
{'step': 300, 'loss': 109.05988311767578}
{'step': 350, 'loss': 97.44053649902344}
{'step': 400, 'loss': 88.46305084228516}
{'step': 450, 'loss': 81.1367416381836}
