<a href="https://colab.research.google.com/github/kbrezinski/GAT-Malware/blob/main/intr_to_lax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [8]:
!pip install --upgrade -q git+https://github.com/google/flax.git

[K     |████████████████████████████████| 136 kB 8.3 MB/s 
[K     |████████████████████████████████| 65 kB 2.3 MB/s 
[?25h  Building wheel for flax (setup.py) ... [?25l[?25hdone


In [26]:
import jax
import flax
import optax
import numpy as np
import jax.numpy as jnp

from flax import linen as nn
from flax.core import freeze, unfreeze
from jax import jit, vmap, pmap, grad, value_and_grad

from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

In [36]:
def init_model(layer_widths, parent_key):

  def random_layer_params(n_in, n_out, key, scale=1e-1):
    k1, k2 = jax.random.split(key)
    return (scale * jax.random.normal(k1, (n_out, n_in)),
            scale * jax.random.normal(k2, (n_out, )))
    
  # split keys for each layer
  keys = jax.random.split(parent_key, num=len(layer_widths) - 1)

  return [random_layer_params(n_in, n_out, key)
        for n_in, n_out, key in zip(layer_widths[:-1], layer_widths[1:], keys)]

def predict(params, x):
  activations = x

  for w, b in params[:-1]:
    activations = jax.nn.relu(jnp.dot(w, activations) + b)

  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return jax.nn.sigmoid(logits)

def accuracy(params, x, y):
  preds = batched_predict(params, x)
  predicted_class = jnp.round(preds)
  return jnp.mean(predicted_class == y)

def loss(params, x, y, eps=1e-14):
  preds = batched_predict(params, x)
  preds = jnp.clip(preds, eps, 1 - eps)  # bound the probabilities to avoid log(0)
  return -jnp.mean(y * jnp.log(preds) + (1 - y) * jnp.log(1 - preds))

@jit
def update(params, x, y, lr=1e-5):
  curr_loss, grads = value_and_grad(loss)(params, x, y)
  return curr_loss, jax.tree_multimap(lambda p, g: p - lr*g, params, grads)

# make bce_loss for 
def make_bce_loss(xs, ys):

  def bce_loss(params):
    def cross_entropy(x, y):
      preds = model.apply(params, x)
      return y * jnp.log(preds) + (1 - y) * jnp.log(1 - preds)
    return jnp.mean(jax.vmap(cross_entropy)(xs, ys), axis=0)

  return jax.jit(bce_loss)

batched_predict = vmap(predict, in_axes=(None, 0))

In [37]:
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt

# constants
seed = 2021
feature_dim = 3
output_dim = 1
parent_key = jax.jax.random.PRNGKey(1001)

# Generate fake data based on blobs with 2 features
centers = [[1]*feature_dim, [2]*feature_dim]
X, y = make_blobs(n_samples=[500, 500], centers=centers, cluster_std=1.)

In [41]:
model = nn.Dense(features=output_dim)
params = model.init(parent_key, X)
bce_loss_fn = make_bce_loss(X, y)
curr_loss, grads = jax.value_and_grad(bce_loss_fn)(params)

TypeError: ignored

In [None]:
# init testing/training regime
use_lax = False
num_epochs = 1
test_split = 0.2
shuffle_idx = jax.random.permutation(parent_key, x=jnp.arange(len(X)))
test_index = int(len(X) * (1 - test_split))

# init model paramters depending on if using jax or lax
if not use_lax:
  params = init_model([feature_dim, 1, 1], parent_key)
  print(jax.tree_map(lambda x: x.shape, params))
else:
  model = nn.Dense(features=output_dim)
  params = model.init(parent_key, X)
  optimizer = optax.adam(learning_rate=1e-2)
  opt_state = optimizer.init(params)  # state handled externally

## training loop
for epoch in range(num_epochs):
  
  if not use_lax:
    curr_loss, params = update(params, X[shuffle_idx[:test_index]], y[shuffle_idx[:test_index]])
  else:
    curr_loss, grads = value_and_grad_fn(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

  if epoch % (num_epochs // 5) == 0:
    print(f"\nEpoch: {epoch + 1}")
    print(f"Loss: {curr_loss:.4f}")
    print(f"Train Accuracy: {accuracy(params, X[shuffle_idx[:test_index]], y[shuffle_idx[:test_index]]):.4f}")
    print(f"Test Accuracy: {accuracy(params, X[shuffle_idx[test_index:]], y[shuffle_idx[test_index:]]):.4f}")

[((1, 3), (1,)), ((1, 1), (1,))]

Epoch: 1
Loss: 0.6942
Train Accuracy: 0.4980
Test Accuracy: 0.5091

Epoch: 11
Loss: 0.6942
Train Accuracy: 0.4980
Test Accuracy: 0.5091

Epoch: 21
Loss: 0.6942
Train Accuracy: 0.4980
Test Accuracy: 0.5091

Epoch: 31
Loss: 0.6942
Train Accuracy: 0.4980
Test Accuracy: 0.5091

Epoch: 41
Loss: 0.6942
Train Accuracy: 0.4980
Test Accuracy: 0.5091


In [None]:
init = 505
tmp = [i for i in range(init, init + 5)]
parent_key = jax.jax.random.PRNGKey(509)
params = init_model([feature_dim, 5, 1], parent_key)

print(y[tmp])
print(batched_predict(params, X[tmp]))
print(loss(params, X[tmp], y[tmp]))
print(accuracy(params, X[tmp], y[tmp]))
params

[0 0 1 1 1]
[[0.5321062 ]
 [0.5346689 ]
 [0.53111476]
 [0.5281173 ]
 [0.5325882 ]]
0.6824679
0.59999996


[(DeviceArray([[-0.17474912, -0.07796253, -0.03069317],
               [-0.12151062,  0.03844729, -0.01764875],
               [ 0.01983018, -0.0688489 ,  0.07098856],
               [-0.17429875,  0.04671511,  0.01443747],
               [ 0.04852775,  0.10632309, -0.02984036]], dtype=float32),
  DeviceArray([-0.03619826, -0.00648518, -0.24387108, -0.14063235,
               -0.00437354], dtype=float32)),
 (DeviceArray([[-0.0825086 , -0.01761603, -0.14344466, -0.03237919,
                -0.06099894]], dtype=float32),
  DeviceArray([0.14020397], dtype=float32))]