# Neural Network with SPU

> Please read the lab [Logistic Regression On SPU](./lr_with_spu.ipynb) first if you have not already.

In the lab [Logistic Regression On SPU](./lr_with_spu.ipynb), we showed how to use SecretFlow/SPU to convert a plaintext JAX training program into a secure MPC training program.

The idea in this lab is similar, but this time we work with a neural network model.

We use the same dataset and the same settings as in the lab [Logistic Regression On SPU](./lr_with_spu.ipynb).

First, let's train the model in plaintext.

> The following code is for demonstration only. It is **NOT for production** due to system security concerns; please **DO NOT** use it directly in production.

> This tutorial needs more resources than 8c16g (8 CPU cores and 16 GB of memory), the minimum requirement of SecretFlow.

## Train a model with JAX/FLAX

### Load the Dataset

The data-loading code below is copied from the lab [Logistic Regression On SPU](./lr_with_spu.ipynb), so we won't explain it again.

In [15]:
import sys

!{sys.executable} -m pip install --upgrade flax


In [2]:
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split


def breast_cancer(party_id=None, train: bool = True) -> (np.ndarray, np.ndarray):
    x, y = load_breast_cancer(return_X_y=True)
    x = (x - np.min(x)) / (np.max(x) - np.min(x))
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.2, random_state=42
    )

    if train:
        if party_id:
            if party_id == 1:
                return x_train[:, :15], None  # party 1 holds features only, no labels
            else:
                return x_train[:, 15:], y_train
        else:
            return x_train, y_train
    else:
        return x_test, y_test

In [4]:
breast_cancer(1, True)

(array([[2.12247297e-03, 4.07381288e-03, 1.38199342e-02, ...,
         4.43112365e-04, 4.15373766e-03, 2.24471086e-06],
        [4.95768688e-03, 6.24588622e-03, 3.35448989e-02, ...,
         1.03761166e-03, 1.91490362e-02, 9.99764927e-07],
        [2.15632346e-03, 3.25811001e-03, 1.39163141e-02, ...,
         6.13070052e-04, 5.52891396e-03, 2.05406676e-06],
        ...,
        [3.35919135e-03, 3.95392572e-03, 2.12270804e-02, ...,
         1.98377997e-04, 2.53173484e-03, 8.20874471e-07],
        [3.28631876e-03, 4.61212976e-03, 2.14198402e-02, ...,
         3.76586742e-04, 4.43112365e-03, 1.24917724e-06],
        [2.86318759e-03, 4.82369535e-03, 1.81523272e-02, ...,
         2.78091208e-04, 3.45086977e-03, 1.19417019e-06]]),
 None)

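To make the vertical partition concrete, here is a minimal numpy sketch that applies the same global min-max scaling and column split as `breast_cancer()`. The random matrix is a stand-in assumption for the real training features (455 rows × 30 columns after the 80/20 split); only the shapes matter. Note that the scaling uses a single global minimum and maximum over the whole matrix, not per-column statistics.

```python
import numpy as np

# Synthetic stand-in (assumption) for the breast cancer training split: 455 rows x 30 features.
rng = np.random.default_rng(0)
x = rng.uniform(0.0, 2500.0, size=(455, 30))
y = rng.integers(0, 2, size=455)

# Same scaling as breast_cancer(): one global min and max over the whole matrix.
x = (x - np.min(x)) / (np.max(x) - np.min(x))

# Party 1 (alice) keeps the first 15 columns; party 2 (bob) keeps the rest plus the labels.
x_alice = x[:, :15]
x_bob, y_bob = x[:, 15:], y

# train_auto_grad later re-joins the two halves column-wise before training.
x_joined = np.concatenate((x_alice, x_bob), axis=1)

print(x_alice.shape, x_bob.shape, y_bob.shape)  # (455, 15) (455, 15) (455,)
print(np.array_equal(x_joined, x))  # True
print(x.min(), x.max())  # 0.0 1.0
```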
### Define the Model


We use a 4-layer [MLP](https://en.wikipedia.org/wiki/Multilayer_perceptron) model here: three hidden `Dense` layers with 30, 15, and 8 units and a [ReLU](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)) activation, followed by a linear `Dense` output layer with a single unit.

In [17]:
from typing import Sequence
import flax.linen as nn


FEATURES = [30, 15, 8, 1]


class MLP(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        for feat in self.features[:-1]:
            x = nn.relu(nn.Dense(feat)(x))
        x = nn.Dense(self.features[-1])(x)
        return x
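As a plain-numpy sketch of what `MLP(FEATURES)` computes (this is not Flax itself): each `nn.Dense(feat)` holds a `kernel` of shape `(in, feat)` and a `bias` of shape `(feat,)`, the first three layers apply ReLU, and the last layer is linear. The helper names `init_numpy_params` and `mlp_forward` are hypothetical, and the random initialization is an illustrative assumption that differs from Flax's default initializers; the nested dict only mirrors the `{'params': {'Dense_0': {...}}}` layout printed later in this lab.

```python
import numpy as np

FEATURES = [30, 15, 8, 1]
N_INPUTS = 30  # the breast cancer data has 30 features


def init_numpy_params(features, n_inputs, seed=0):
    """Hypothetical helper: random parameters laid out like Flax's {'params': {'Dense_i': ...}}."""
    rng = np.random.default_rng(seed)
    layers, fan_in = {}, n_inputs
    for i, feat in enumerate(features):
        layers[f"Dense_{i}"] = {
            "kernel": rng.normal(0.0, 1.0 / np.sqrt(fan_in), size=(fan_in, feat)),
            "bias": np.zeros(feat),
        }
        fan_in = feat
    return {"params": layers}


def mlp_forward(params, x):
    """ReLU after every Dense layer except the last, mirroring MLP.__call__."""
    layers = params["params"]
    n = len(layers)
    for i in range(n):
        layer = layers[f"Dense_{i}"]
        x = x @ layer["kernel"] + layer["bias"]
        if i < n - 1:
            x = np.maximum(x, 0.0)
    return x


np_params = init_numpy_params(FEATURES, N_INPUTS)
n_params = sum(p.size for layer in np_params["params"].values() for p in layer.values())
out = mlp_forward(np_params, np.ones((10, N_INPUTS)))
print(out.shape)  # (10, 1)
print(n_params)  # 1532
```

The count of 1,532 values (930 + 465 + 128 + 9 across the four layers) is a quick check that the layer sizes in `FEATURES` line up.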

Next, we define the prediction, loss, and training functions.

In [18]:
import jax.numpy as jnp
import jax


def predict(params, x):
    # TODO(junfeng): investigate why need to have a duplicated definition in notebook,
    # which is not the case in a normal python program.
    from typing import Sequence
    import flax.linen as nn

    FEATURES = [30, 15, 8, 1]

    class MLP(nn.Module):
        features: Sequence[int]

        @nn.compact
        def __call__(self, x):
            for feat in self.features[:-1]:
                x = nn.relu(nn.Dense(feat)(x))
            x = nn.Dense(self.features[-1])(x)
            return x

    return MLP(FEATURES).apply(params, x)


def loss_func(params, x, y):
    pred = predict(params, x)

    def mse(y, pred):
        def squared_error(y, y_pred):
            return jnp.multiply(y - y_pred, y - y_pred) / 2.0

        return jnp.mean(squared_error(y, pred))

    return mse(y, pred)


def train_auto_grad(x1, x2, y, params, n_batch=10, n_epochs=10, step_size=0.01):
    x = jnp.concatenate((x1, x2), axis=1)
    # n_batch acts as the (approximate) batch size: split into len(x) // n_batch chunks.
    xs = jnp.array_split(x, len(x) // n_batch, axis=0)
    ys = jnp.array_split(y, len(y) // n_batch, axis=0)

    def body_fun(_, loop_carry):
        params = loop_carry
        for x, y in zip(xs, ys):
            _, grads = jax.value_and_grad(loss_func)(params, x, y)
            params = jax.tree_util.tree_map(
                lambda p, g: p - step_size * g, params, grads
            )
        return params

    params = jax.lax.fori_loop(0, n_epochs, body_fun, params)
    return params


def model_init(n_batch=10):
    model = MLP(FEATURES)
    return model.init(jax.random.PRNGKey(1), jnp.ones((n_batch, FEATURES[0])))
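Two shape details in the training code are easy to miss, so here is a numpy-only sketch of both. First, `n_batch` behaves as an approximate batch *size*: the number of chunks is `len(x)` divided by `n_batch`, rounded down, so the 455 training rows become 45 chunks of 10 or 11 rows. Second, `y` is one-dimensional while `predict` returns shape `(batch, 1)`, so `y - y_pred` in `squared_error` broadcasts to a `(batch, batch)` matrix; flattening the prediction gives a per-sample error instead. The small `y` and `pred` arrays are made-up values, and this sketch only illustrates the shapes; the AUC values reported in this lab come from the loss exactly as written above.

```python
import numpy as np

# Batching as in train_auto_grad: n_batch is used as the approximate batch size.
n_rows, n_batch = 455, 10
x = np.zeros((n_rows, 30))
chunks = np.array_split(x, n_rows // n_batch, axis=0)
sizes = [len(c) for c in chunks]
print(len(chunks), sizes[:6])  # 45 [11, 11, 11, 11, 11, 10]

# Loss shapes: labels are 1-D, while the model output has a trailing axis of size 1.
y = np.array([0.0, 1.0, 1.0, 0.0])
pred = np.array([[0.2], [0.9], [0.6], [0.1]])  # shape (4, 1), like predict()

broadcast_err = (y - pred) ** 2 / 2.0  # (4,) - (4, 1) broadcasts to shape (4, 4)
per_sample_err = (y - pred[:, 0]) ** 2 / 2.0  # shape (4,), one error per sample
print(broadcast_err.shape, per_sample_err.shape)  # (4, 4) (4,)
```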

### Validate the Model

We use AUC (the area under the ROC curve) as the validation metric.

In [19]:
from sklearn.metrics import roc_auc_score


def validate_model(params, X_test, y_test):
    y_pred = predict(params, X_test)
    return roc_auc_score(y_test, y_pred)
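For binary labels, AUC equals the probability that a randomly chosen positive sample is scored higher than a randomly chosen negative one, with ties counted as one half. A minimal numpy sketch of that definition follows; `pairwise_auc` is a hypothetical helper, and `ravel()` handles the `(n, 1)` shape that `predict` returns. It compares every positive with every negative, so it only suits small test sets like this one.

```python
import numpy as np


def pairwise_auc(y_true, y_score):
    """AUC as P(score of a positive > score of a negative), ties counted as 0.5."""
    y_true = np.asarray(y_true).ravel()
    y_score = np.asarray(y_score).ravel()  # accepts (n, 1) predictions too
    pos = y_score[y_true == 1]
    neg = y_score[y_true == 0]
    diff = pos[:, None] - neg[None, :]  # every positive/negative pair
    return ((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size


y_true = np.array([0, 0, 1, 1])
y_score = np.array([[0.1], [0.4], [0.35], [0.8]])
print(pairwise_auc(y_true, y_score))  # 0.75
```

Three of the four positive/negative pairs are ordered correctly here, hence 0.75; on the lab's test set this pairwise definition should agree with `roc_auc_score`.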

### Build Together

Let's put everything together and train a plaintext NN model!

In [20]:
import jax

# Load the data
x1, _ = breast_cancer(party_id=1, train=True)
x2, y = breast_cancer(party_id=2, train=True)


# Hyperparameters
n_batch = 10
n_epochs = 10
step_size = 0.01


# Train the model
init_params = model_init(n_batch)
params = train_auto_grad(x1, x2, y, init_params, n_batch, n_epochs, step_size)

# Test the model
X_test, y_test = breast_cancer(train=False)
auc = validate_model(params, X_test, y_test)
print(f'auc={auc}')

auc=0.9927939731411726


Keep this AUC value in mind: next, we repeat the same training with SPU and compare the results. Let's do that magic!


## Train a Model with SPU

In [21]:
import secretflow as sf

# Check the version of your SecretFlow
print('The version of SecretFlow: {}'.format(sf.__version__))

# Shut down any SecretFlow runtime that is already running.
sf.shutdown()

sf.init(['alice', 'bob'], address='local')

alice, bob = sf.PYU('alice'), sf.PYU('bob')
spu = sf.SPU(sf.utils.testing.cluster_def(['alice', 'bob']))

x1, _ = alice(breast_cancer)(party_id=1, train=True)
x2, y = bob(breast_cancer)(party_id=2, train=True)
init_params = model_init(n_batch)


device = spu
x1_, x2_, y_ = x1.to(device), x2.to(device), y.to(device)
init_params_ = sf.to(alice, init_params).to(device)

params_spu = spu(train_auto_grad, static_argnames=['n_batch', 'n_epochs', 'step_size'])(
    x1_, x2_, y_, init_params_, n_batch=n_batch, n_epochs=n_epochs, step_size=step_size
)

The version of SecretFlow: 1.8.0b0


2024-07-28 18:42:00,793	INFO worker.py:1724 -- Started a local Ray instance.
[36m(_run pid=3552416)[0m INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda': 
[36m(_run pid=3552416)[0m INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
[36m(_run pid=3552416)[0m INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


[36m(_run pid=3552420)[0m INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda': 
[36m(_run pid=3552420)[0m INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
[36m(_run pid=3552420)[0m INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory

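The SPU call above passes `n_batch`, `n_epochs`, and `step_size` through `static_argnames`, so they are treated as compile-time constants. `n_batch` in particular has to be a concrete Python value when the program is traced, because `len(x) // n_batch` decides how many batches `array_split` produces and therefore the structure of the compiled program. The sketch below illustrates the analogous mechanism with plain `jax.jit` and its `static_argnames` parameter on the CPU, not on SPU; `batch_means` is a hypothetical function, not part of the lab.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical example: n_batch decides how many chunks exist, so it is marked static.
@partial(jax.jit, static_argnames=["n_batch"])
def batch_means(x, n_batch):
    # len(x) is known at trace time; n_batch is a plain Python int because it is static.
    chunks = jnp.array_split(x, len(x) // n_batch, axis=0)
    return jnp.stack([c.mean() for c in chunks])


x = jnp.arange(455.0)  # one value per training row
means = batch_means(x, n_batch=10)
print(means.shape)  # (45,)
```

Calling `batch_means` again with a different `n_batch` triggers a new trace and compilation, because static arguments are part of the compilation cache key.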

Let's reveal the parameters trained by the SPU program.

In [22]:
params = sf.reveal(params_spu)
print(params)

{'params': {'Dense_0': {'bias': array([ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, -8.4448457e-03,
        4.7277704e-02,  3.7601590e-04,  0.0000000e+00,  4.5651197e-03,
        0.0000000e+00, -3.4031868e-02, -8.4131658e-03,  0.0000000e+00,
        0.0000000e+00,  5.6682825e-02, -4.8434734e-03,  0.0000000e+00,
        3.5732180e-02,  6.3550323e-03,  2.9711276e-03,  3.2665536e-02,
        0.0000000e+00, -2.1323681e-02, -7.8184158e-03,  0.0000000e+00,
        2.8501868e-02,  0.0000000e+00, -3.0903667e-03,  3.8698316e-05,
        1.4437318e-02,  2.0847723e-02], dtype=float32), 'kernel': array([[-0.14871399, -0.23531966, -0.1493772 , -0.01558939, -0.13323145,
         0.1917589 , -0.03680335, -0.03745073, -0.14176767,  0.0323102 ,
         0.12652728, -0.40251398, -0.16895528,  0.21399277, -0.13845326,
         0.10585146, -0.11602809,  0.38624364,  0.059659  ,  0.06317471,
         0.07793002, -0.01319632, -0.28805006, -0.09602834,  0.11111304,
        -0.08544238,  0.07546453, -0.041

Lastly, let's validate the model.

In [23]:
X_test, y_test = breast_cancer(train=False)
auc = validate_model(params, X_test, y_test)
print(f'auc={auc}')

auc=0.9927939731411726


The model trained with SPU reaches the same AUC as the plaintext model. This is the end of the lab.