<a href="https://colab.research.google.com/github/kbrezinski/GAT-Malware/blob/main/complexity_MLP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import jax.numpy as jnp
from jax.scipy.special import logsumexp
import jax
from jax import jit, vmap, pmap, grad, value_and_grad

from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

In [None]:
# Notebook-wide configuration.
seed = 2022  # integer seed for the JAX PRNG key
feature_dim, output_dim = 2, 1  # network input width, network output width

def init_model(layer_widths, parent_key):
  """Create randomly initialised parameters for a fully connected network.

  Args:
    layer_widths: sequence of layer sizes, input width first and output
      width last. Consecutive entries define one dense layer.
    parent_key: JAX PRNG key; it is split into one sub-key per layer.

  Returns:
    A list with one `(weights, biases)` tuple per layer, where `weights`
    has shape `(n_out, n_in)` and `biases` has shape `(n_out,)`. Both are
    drawn from a standard normal distribution scaled by 1e-2.
  """
  scale = 1e-2
  num_layers = len(layer_widths) - 1

  # One independent key per layer; each is split again for weights/biases.
  layer_keys = jax.random.split(parent_key, num=num_layers)

  layers = []
  for fan_in, fan_out, layer_key in zip(layer_widths[:-1], layer_widths[1:], layer_keys):
    weight_key, bias_key = jax.random.split(layer_key)
    weights = scale * jax.random.normal(weight_key, (fan_out, fan_in))
    biases = scale * jax.random.normal(bias_key, (fan_out, ))
    layers.append((weights, biases))

  return layers

# Root PRNG key for the notebook; `jax.random` is the public namespace
# (the previous `jax.jax.random` spelling was a typo).
parent_key = jax.random.PRNGKey(seed)
params = init_model([feature_dim, 16, output_dim], parent_key)

# Show the shape of every parameter array. `jax.tree_util.tree_map` is the
# supported spelling; the top-level `jax.tree_map` alias has been removed.
jax.tree_util.tree_map(lambda x: x.shape, params)

In [None]:
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt

# Generate fake data based on blobs with 2 features: an imbalanced
# two-class problem with 800 samples around (1, 1) and 200 around (-1, -1).
# `random_state=seed` makes the generated data reproducible across runs.
centers = [[1, 1], [-1, -1]]
X, y = make_blobs(n_samples=[800, 200], centers=centers, cluster_std=1,
                  random_state=seed)

In [None]:
def predict(params, x):
  """Run a single example through the network.

  Args:
    params: list of `(weights, biases)` tuples, one per layer.
    x: feature vector for one example.

  Returns:
    The sigmoid of the last layer's output. Every layer before the last
    is followed by a ReLU.
  """
  hidden_layers, output_layer = params[:-1], params[-1]

  hidden = x
  for weight, bias in hidden_layers:
    pre_activation = jnp.dot(weight, hidden) + bias
    hidden = jax.nn.relu(pre_activation)

  out_weight, out_bias = output_layer
  return jax.nn.sigmoid(jnp.dot(out_weight, hidden) + out_bias)

def accuracy(params, x, y):
  """Return the fraction of examples whose rounded prediction equals `y`.

  Args:
    params: list of `(weights, biases)` tuples, one per layer.
    x: batch of feature vectors, one example per row.
    y: binary labels, one per example.

  Returns:
    Scalar mean of the element-wise matches between the rounded
    predictions and `y`.
  """
  y = jnp.asarray(y)
  # The batched predictions come back with a trailing output axis, i.e.
  # shape (N, 1). Comparing that with labels of shape (N,) would broadcast
  # to an (N, N) matrix, so align the predictions with the labels first.
  preds = batched_predict(params, x).reshape(y.shape)
  predicted_class = jnp.round(preds)
  return jnp.mean(predicted_class == y)

def loss(params, x, y, lmbda=1e-1):
  """Mean binary cross-entropy between the predictions and labels `y`.

  Args:
    params: list of `(weights, biases)` tuples, one per layer.
    x: batch of feature vectors, one example per row.
    y: binary labels, one per example.
    lmbda: accepted for interface compatibility; not used by the current
      implementation (no regularisation term is added).

  Returns:
    Scalar mean cross-entropy over the batch.
  """
  y = jnp.asarray(y)
  # The batched predictions come back with a trailing output axis, i.e.
  # shape (N, 1). Multiplying that with labels of shape (N,) would
  # broadcast to (N, N) and pair every prediction with every label, so
  # align the predictions with the labels first.
  preds = batched_predict(params, x).reshape(y.shape)

  # Bound the probabilities to avoid log(0). The margin is the machine
  # epsilon of the prediction dtype: a fixed 1e-14 is below float32
  # resolution, where `1 - 1e-14` rounds to exactly 1.0.
  eps = jnp.finfo(preds.dtype).eps
  pred = jnp.clip(preds, eps, 1 - eps)
  return -jnp.mean(y * jnp.log(pred) + (1 - y) * jnp.log(1 - pred))

@jit
def update(params, x, y, lr=5e-2):
  """Perform one gradient-descent step on `loss`.

  Args:
    params: list of `(weights, biases)` tuples, one per layer.
    x: batch of feature vectors, one example per row.
    y: binary labels, one per example.
    lr: learning rate applied to every gradient.

  Returns:
    `(curr_loss, new_params)`, where `curr_loss` is the loss evaluated at
    the incoming `params` (before the step) and `new_params` has the same
    structure as `params`.
  """
  curr_loss, grads = value_and_grad(loss)(params, x, y)
  # `jax.tree_util.tree_map` accepts several trees; the older
  # `jax.tree_multimap` has been removed from JAX.
  new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
  return curr_loss, new_params

# Vectorise `predict` over the leading axis of its second argument (the
# inputs); the first argument (the parameters) is shared by every example.
batched_predict = vmap(predict, (None, 0))

In [None]:
params = init_model([feature_dim, 8, output_dim], parent_key)

num_epochs = 5000
# Report five times per run; clamp to 1 so the modulus below never divides
# by zero when num_epochs is smaller than 5.
log_every = max(1, num_epochs // 5)

# Split once instead of re-slicing on every iteration: the first 800
# samples are used for training, the remainder for testing.
n_train = 800
X_train, y_train = X[:n_train], y[:n_train]
X_test, y_test = X[n_train:], y[n_train:]

## training loop
for epoch in range(num_epochs):

  # curr_loss is the loss before this step; params are the updated weights.
  curr_loss, params = update(params, X_train, y_train)

  if epoch % log_every == 0:
    print(f"\nEpoch: {epoch + 1}")
    print(f"Loss: {curr_loss:.4f}")
    print(f"Train Accuracy: {accuracy(params, X_train, y_train):.4f}")
    print(f"Test Accuracy: {accuracy(params, X_test, y_test):.4f}")

In [None]:
# Spot-check: print the labels of samples 800-804 and evaluate the
# accuracy on those same five samples.
sample = slice(800, 805)
print(y[sample])
accuracy(params, X[sample], y[sample])

In [None]:
# `jax.random.shuffle` has been removed from JAX; `jax.random.permutation`
# is its replacement and returns a copy of X permuted along the first axis.
# NOTE: only X is permuted here; the labels in y are not reordered with it.
jax.random.permutation(key=parent_key, x=X)