In [11]:
import sys
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
XLA_PYTHON_CLIENT_PREALLOCATE=False
# XLA_PYTHON_CLIENT_MEM_FRACTION=.50

from typing import NamedTuple
import h5py
import numpy as np
import jax
import jax.numpy as jnp
import optax
import haiku as hk
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler

np.random.seed(8)
print("JAX version {}".format(jax.__version__))
print("Haiku version {}".format(hk.__version__))

JAX version 0.4.8
Haiku version 0.0.9


In [15]:
jax.local_devices()

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]

In [13]:
x = jax.device_put(jnp.array([1,2,3,4,5,6,7,8]))

In [14]:
key = jax.random.PRNGKey(8)

## Read data

In [5]:
# read data
f = h5py.File('/clusterfs/ml4hep/mfong/transfer_learning/delphes_train.h5', 'r')
f2 = h5py.File('/clusterfs/ml4hep/mfong/transfer_learning/delphes_test.h5', 'r')

In [6]:
feature_keys = ['fjet_clus_eta', 'fjet_clus_phi', 'fjet_clus_pt']
for k in f.keys():
    print(k, f[k].shape)

fjet_clus_E (4000068, 200)
fjet_clus_eta (4000068, 200)
fjet_clus_phi (4000068, 200)
fjet_clus_pt (4000068, 200)
fjet_eta (4000068,)
fjet_m (4000068,)
fjet_phi (4000068,)
fjet_pt (4000068,)
labels (4000068,)


In [7]:
x = np.concatenate([f[k] for k in feature_keys], axis=1)
x.shape

(4000068, 600)

In [8]:
scaler = StandardScaler()
scaler.fit(x)

x = scaler.transform(x)

In [9]:
y = f["labels"][:]
y.shape

(4000068,)

## MLP starter code
https://www.kaggle.com/code/alembcke/titanic-multi-layer-perceptron-using-haiku-jax

In [67]:
class TrainingState(NamedTuple):
    params: hk.Params
    # avg_params: hk.Params
    opt_state: optax.OptState

In [68]:
def net_fn(x: jax.Array) -> jax.Array:
  """Standard MLP network."""
  mlp = hk.Sequential([
      hk.Flatten(),
      hk.Linear(64), jax.nn.relu,
      hk.Linear(8), jax.nn.relu,
      hk.Linear(1),
  ])
  return mlp(x)

In [77]:
network = hk.without_apply_rng(hk.transform(net_fn))
optimiser = optax.adam(1e-2)

In [78]:
def loss(params: hk.Params, features: jnp.ndarray, labels: jnp.ndarray):
    """Loss function, using Sigmoid Binary Cross Entropy loss."""
    logits = network.apply(params, features)
    return optax.sigmoid_binary_cross_entropy(logits, labels).sum(axis=-1).mean()

In [79]:
@jax.jit
def evaluate(params: hk.Params, features: jnp.ndarray, labels: jnp.ndarray):
    """Checks the accuracy of predictions compared to labels."""
    logits = network.apply(params, features)
    predictions = jnp.around(logits, 0)
    return jnp.mean(predictions == labels)

@jax.jit
def update(state: TrainingState, features: jnp.ndarray, labels: jnp.ndarray) -> TrainingState:
    """Learning rule (stochastic gradient descent)."""
    grads = jax.grad(loss)(state.params, features, labels)
    updates, opt_state = optimiser.update(grads, state.opt_state)
    params = optax.apply_updates(state.params, updates)
    # Compute avg_params, the exponential moving average of the "live" params.
    # We use this only for evaluation (cf. https://doi.org/10.1137/0330046).
    # avg_params = optax.incremental_update(params, state.avg_params, step_size=0.001)
    # return TrainingState(params, avg_params, opt_state)
    return TrainingState(params, opt_state)

In [80]:
initial_params = network.init(jax.random.PRNGKey(seed=8), x[0])
initial_opt_state = optimiser.init(initial_params)
state = TrainingState(initial_params, initial_opt_state)
# state = TrainingState(initial_params, initial_params, initial_opt_state)

In [81]:
print(hk.experimental.tabulate(network)(x[0]))

+----------------------------+--------------------------------------------------------------------------------+-----------------+----------+----------+---------------+---------------+
| Module                     | Config                                                                         | Module params   | Input    | Output   |   Param count |   Param bytes |
| sequential (Sequential)    | Sequential(                                                                    |                 | f32[600] | f32[1]   |        38,993 |     155.97 KB |
|                            |     layers=[Flatten(),                                                         |                 |          |          |               |               |
|                            |             Linear(output_size=64),                                            |                 |          |          |               |               |
|                            |             <jax._src.custom_derivatives.custom_j

In [84]:
for step in range(100):
    if step % 10 == 0:
        accuracy = np.array(evaluate(state.params, x[:10000], y[:10000])).item()
        print({"step": step, "accuracy": f"{accuracy:.3f}"})

    # Do SGD on training examples.
    state = update(state, x[:10000], y[:10000])

{'step': 0, 'accuracy': '0.501'}
{'step': 10, 'accuracy': '0.501'}
{'step': 20, 'accuracy': '0.501'}
{'step': 30, 'accuracy': '0.501'}
{'step': 40, 'accuracy': '0.501'}
{'step': 50, 'accuracy': '0.501'}
{'step': 60, 'accuracy': '0.501'}
{'step': 70, 'accuracy': '0.501'}
{'step': 80, 'accuracy': '0.501'}
{'step': 90, 'accuracy': '0.501'}
