In [1]:
import secretflow as sf

In [2]:
edge_parties_number = 2
edge_party_name = 'edge_party_{i}'
edge_parties = [edge_party_name.format(i=i) for i in range(edge_parties_number)]
server_party_name = 'server_party'
server_party = [server_party_name]
all_parties = edge_parties + server_party

In [3]:
edge_parties

['edge_party_0', 'edge_party_1']

In [4]:
sf.init(parties=all_parties, address='local')

  from .autonotebook import tqdm as notebook_tqdm
2024-11-25 16:53:42,454	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2024-11-25 16:53:42,644	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
  self.pid = _posixsubprocess.fork_exec(
2024-11-25 16:53:44,547	INFO worker.py:1724 -- Started a local Ray instance.


In [5]:
edge_devices = [sf.PYU(edge_party_name.format(i=i)) for i in range(edge_parties_number)]
# use pyu to simulate teeu
edge_tees = [sf.PYU(edge_party_name.format(i=i)) for i in range(edge_parties_number)]

server_device = sf.PYU(server_party_name)
server_tee = sf.PYU(server_party_name)

In [6]:
# custom parameters
i = 0
j = 1
m = 100
kappa = 32
u_low = 0.0
u_high = 2.0
# k is the ring size 2^k. usually take k = 32, 64. like size of int
k = 32  # implies use uint32

In [7]:
import jax
import jax.numpy as jnp
import numpy as np
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from typing import Tuple
import hashlib


def bytes_to_jax_random_key(byte_key):
    # Hash the bytes key using SHA-256
    hash_digest = hashlib.sha256(byte_key).digest()

    # Convert the first 4 bytes of the hash digest to an integer
    # JAX random keys expect a 32-bit integer, hence we use the first 4 bytes
    seed = int.from_bytes(hash_digest[:4], 'big')

    # Create a JAX random key with this seed
    jax_key = jax.random.PRNGKey(seed)

    return jax_key


# Function to encrypt a jnp array using AES-GCM
def encrypt_jnp_array_gcm(jnp_array, key) -> Tuple[bytes, bytes, bytes]:
    # Convert jnp array to numpy array
    np_array = np.array(jnp_array)

    # Convert numpy array to bytes
    array_bytes = np_array.tobytes()

    # Create AES cipher in GCM mode
    cipher = AES.new(key, AES.MODE_GCM)

    # Encrypt data
    ciphertext, tag = cipher.encrypt_and_digest(array_bytes)

    # Return the ciphertext, tag, and nonce
    return ciphertext, tag, cipher.nonce


# Function to decrypt a jnp array using AES-GCM
def decrypt_to_jnp_array_gcm(ciphertext, tag, nonce, key, dtype, shape):
    # Create AES cipher in GCM mode with the same parameters
    cipher = AES.new(key, AES.MODE_GCM, nonce=nonce)

    # Decrypt data
    decrypted_data = cipher.decrypt_and_verify(ciphertext, tag)

    # Convert bytes back to numpy array
    decrypted_np_array = np.frombuffer(decrypted_data, dtype=dtype).reshape(shape)

    # Convert numpy array to jnp array
    decrypted_jnp_array = jnp.array(decrypted_np_array)

    return decrypted_jnp_array


# Example usage
if __name__ == "__main__":
    # Generate a random key for AES-256 (32 bytes)
    key = get_random_bytes(32)

    # Create a jnp array
    original_jnp_array = jnp.array(np.random.uniform(u_low, u_high, (m,)))

    # Encrypt the jnp array using AES-GCM
    ciphertext, tag, nonce = encrypt_jnp_array_gcm(original_jnp_array, key)

    # Decrypt to a jnp array
    decrypted_jnp_array = decrypt_to_jnp_array_gcm(
        ciphertext, tag, nonce, key, dtype=jnp.float32, shape=original_jnp_array.shape
    )

    # Check if the original and decrypted arrays are the same
    print("Original JAX Array:")
    print(original_jnp_array)
    print("\nDecrypted JAX Array:")
    print(decrypted_jnp_array)
    print(
        "\nArrays are equal:", jnp.array_equal(original_jnp_array, decrypted_jnp_array)
    )



Original JAX Array:
[1.543261   1.8819652  1.3520194  0.12602167 0.00819245 0.8335233
 0.24762815 0.5973602  0.84210426 1.6452619  0.60468984 1.270024
 0.6331784  0.52105224 0.75727856 1.0133526  1.3102268  0.1769932
 0.49942148 0.8335648  0.14698231 0.8164377  0.24726874 0.48200187
 1.8702365  1.8223263  0.9344147  1.223453   0.00666214 0.9191807
 0.9327775  1.5772036  0.3084602  1.7717539  1.491721   0.18190163
 1.3147568  0.57207686 0.5401994  1.8854702  0.870894   0.54716337
 1.2602699  0.691345   0.72894716 1.5469918  1.9396929  0.6442396
 1.54365    0.8659952  0.6514194  1.2232236  1.5714749  0.29944515
 0.33028755 0.21620859 0.5635623  1.2797244  0.30582815 0.06229384
 0.06313824 0.5007081  0.81141025 0.5239505  0.37149203 1.1353881
 0.36785588 0.7070881  0.40961593 1.7803401  1.5857289  0.34309584
 1.8364183  1.3298665  0.6180564  0.38195848 0.46578258 1.912777
 0.9642606  0.36924857 0.27891305 1.3318433  1.211587   1.987189
 1.1710542  1.679124   1.1500549  1.031098   1.801167

In [8]:
import numpy as np
import jax.numpy as jnp
import jax

# Simulate handles

handle_i_j = edge_tees[i](lambda x: get_random_bytes(x))(kappa)
handle_i_s = edge_tees[i](lambda x: get_random_bytes(x))(kappa)

# note that the establishment is not simplified.
handle_j_i = handle_i_j.to(edge_tees[j])
handle_j_s = edge_tees[j](lambda x: get_random_bytes(x))(kappa)

server_handle_i_s = handle_i_s.to(server_tee)
server_handle_j_s = handle_j_s.to(server_tee)


# P_i holds u_i
u_i = edge_devices[i](lambda x: x)(jnp.array(np.random.uniform(u_low, u_high, (m,))))

# P_j holds u_j
u_j = edge_devices[j](lambda x: x)(jnp.array(np.random.uniform(u_low, u_high, (m,))))


# corr preprocessing
def corr(k, m, dev1, key1, dev2, key2, return_zero_sharing=False):
    """Correlation function

    Args:
        k (int): ring size will be 2^k. Support k = 64 or 128 for now
        m (int): size of array to be correlated
        dev1 (Device): device 1
        key1 (Key): key for device 1
        dev2 (Device): device 2
        key2 (Key): key for device 2, key2 must be the same as key1 yet hold by different device
    """
    assert k == 32, "Only support k = 32 for now"
    dtype = jnp.uint32
    corr_dev1 = dev1(
        lambda key, shape, dtype: jax.random.bits(
            bytes_to_jax_random_key(key), shape, dtype
        )
    )(key1, (m,), dtype)
    if not return_zero_sharing:
        corr_dev2 = dev2(
            lambda key, shape, dtype: jax.random.bits(
                bytes_to_jax_random_key(key), shape, dtype
            )
        )(key2, (m,), dtype)
    else:
        corr_dev2 = dev2(
            lambda key, shape, dtype: -jax.random.bits(
                bytes_to_jax_random_key(key), shape, dtype
            )
        )(key2, (m,), dtype)
    return corr_dev1, corr_dev2


def cos_sim(u_i, u_j):
    # proramming details not related to protocol
    dtype_ref_i = edge_tees[i](lambda x: x.dtype)(u_i.to(edge_tees[i]))
    shape_ref_i = edge_tees[i](lambda x: x.shape)(u_i.to(edge_tees[i]))
    dtype_ref_j = edge_tees[j](lambda x: x.dtype)(u_j.to(edge_tees[j]))
    shape_ref_j = edge_tees[j](lambda x: x.shape)(u_j.to(edge_tees[j]))
    # let's suppose that the reference type and shape are the same for u_i and u_j and it is ok to share to server
    dtype_ref_server = dtype_ref_i.to(server_tee)
    shape_ref_server = shape_ref_i.to(server_tee)

    # preprocessing
    server_a, edge_tee_i_a = corr(
        k, m, server_tee, server_handle_i_s, edge_tees[i], handle_i_s
    )
    server_b, edge_tee_j_b = corr(
        k, m, server_tee, server_handle_j_s, edge_tees[j], handle_j_s
    )
    c = server_tee(lambda a, b: a * b)(server_a, server_b)

    # normalize u_i and u_j
    u_i_normalized = edge_tees[i](lambda x: x / jnp.linalg.norm(x))(
        u_i.to(edge_tees[i])
    )
    u_j_normalized = edge_tees[j](lambda x: x / jnp.linalg.norm(x))(
        u_j.to(edge_tees[j])
    )

    # E_i encrypts e = u_i_normalized - a, sends to P_j via P_i
    e = edge_tees[i](lambda x, y: x - y)(u_i_normalized, edge_tee_i_a)
    c_e, c_e_tag, c_e_nouce = edge_tees[i](encrypt_jnp_array_gcm, num_returns=3)(
        e, handle_i_j
    )
    c_e_j = c_e.to(edge_devices[i]).to(edge_devices[j])
    c_e_tag_j = c_e_tag.to(edge_devices[i]).to(edge_devices[j])
    c_e_nouce_j = c_e_nouce.to(edge_devices[i]).to(edge_devices[j])

    # E_j encrypts f = u_j_normalized - b, sends to P_i via P_j
    f = edge_tees[j](lambda x, y: x - y)(u_j_normalized, edge_tee_j_b)
    c_f, c_f_tag, c_f_nouce = edge_tees[j](encrypt_jnp_array_gcm, num_returns=3)(
        f, handle_j_i
    )

    c_f_i = c_f.to(edge_devices[j]).to(edge_devices[i])
    c_f_tag_i = c_f_tag.to(edge_devices[j]).to(edge_devices[i])
    c_f_nouce_i = c_f_nouce.to(edge_devices[j]).to(edge_devices[i])

    # P_i decrypts f in E_i
    try:
        f_dec = edge_tees[i](decrypt_to_jnp_array_gcm)(
            c_f_i.to(edge_tees[i]),
            c_f_tag_i.to(edge_tees[i]),
            c_f_nouce_i.to(edge_tees[i]),
            handle_i_j,
            dtype_ref_i,
            shape_ref_i,
        )
        sf.wait(f_dec)
    except:
        raise RuntimeError("Error in decrypting f, abort")

    # P_i decrypts c in E_j
    try:
        e_dec = edge_tees[j](decrypt_to_jnp_array_gcm)(
            c_e_j.to(edge_tees[j]),
            c_e_tag_j.to(edge_tees[j]),
            c_e_nouce_j.to(edge_tees[j]),
            handle_j_i,
            dtype_ref_j,
            shape_ref_j,
        )
        sf.wait(e_dec)
    except:
        raise RuntimeError("Error in decrypting c, abort")

    edge_tee_i_u_i_bracket_1, edge_tee_j_u_i_bracket_1 = corr(
        k, m, edge_tees[i], handle_i_j, edge_tees[j], handle_j_i
    )
    edge_tee_i_u_j_bracket_0, edge_tee_j_u_j_bracket_0 = corr(
        k, m, edge_tees[i], handle_i_j, edge_tees[j], handle_j_i
    )
    edge_tee_i_d, edge_tee_j_d = corr(
        k, m, edge_tees[i], handle_i_j, edge_tees[j], handle_j_i, True
    )

    u_i_bracket_0 = edge_tees[i](lambda x, y: x - y)(u_i, edge_tee_i_u_i_bracket_1)
    u_j_bracket_1 = edge_tees[j](lambda x, y: x - y)(u_j, edge_tee_j_u_j_bracket_0)

    # E_i computes:
    z_bracket_0 = edge_tees[i](
        lambda x1, x2, x3, x4, x5: x1 * x2 + x3 * x4 + x1 * x3 + x5
    )(e, edge_tee_i_u_j_bracket_0, f_dec, u_i_bracket_0, edge_tee_i_d)

    # E_i encrypts z_bracket_0, sends to server tee via server
    c_z_bracket_0, c_z_bracket_0_tag, c_z_bracket_0_nouce = edge_tees[i](
        encrypt_jnp_array_gcm, num_returns=3
    )(z_bracket_0, handle_i_s)
    c_z_bracket_0_server = c_z_bracket_0.to(server_device).to(server_tee)
    c_z_bracket_0_tag_server = c_z_bracket_0_tag.to(server_device).to(server_tee)
    c_z_bracket_0_nouce_server = c_z_bracket_0_nouce.to(server_device).to(server_tee)

    # E_j computes:
    z_bracket_1 = edge_tees[j](lambda x1, x2, x3, x4, x5: x1 * x2 + x3 * x4 + x5)(
        e_dec, u_j_bracket_1, f, edge_tee_j_u_i_bracket_1, edge_tee_j_d
    )

    # E_j encrypts z_bracket_1, sends to server tee via server
    c_z_bracket_1, c_z_bracket_1_tag, c_z_bracket_1_nouce = edge_tees[j](
        encrypt_jnp_array_gcm, num_returns=3
    )(z_bracket_1, handle_j_s)
    c_z_bracket_1_server = c_z_bracket_1.to(server_device).to(server_tee)
    c_z_bracket_1_tag_server = c_z_bracket_1_tag.to(server_device).to(server_tee)
    c_z_bracket_1_nouce_server = c_z_bracket_1_nouce.to(server_device).to(server_tee)

    # server tries to decrypt
    try:
        z_bracket_0_dec = server_tee(decrypt_to_jnp_array_gcm)(
            c_z_bracket_0_server,
            c_z_bracket_0_tag_server,
            c_z_bracket_0_nouce_server,
            server_handle_i_s,
            dtype_ref_server,
            shape_ref_server,
        )
        z_bracket_1_dec = server_tee(decrypt_to_jnp_array_gcm)(
            c_z_bracket_1_server,
            c_z_bracket_1_tag_server,
            c_z_bracket_1_nouce_server,
            server_handle_j_s,
            dtype_ref_server,
            shape_ref_server,
        )
        sf.wait([z_bracket_0_dec, z_bracket_1_dec])
    except:
        raise Exception("Decryption z values failed")

    z = server_tee(lambda x, y: x + y)(z_bracket_0_dec, z_bracket_1_dec)
    cos_sim_val = server_tee(lambda x, y: jnp.sum(x + y))(z, c)
    return cos_sim_val

In [9]:
sf.reveal(cos_sim(u_i, u_j))

[36m(pyu_fn pid=758380)[0m An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[36m(pyu_fn pid=758376)[0m 2024-11-25 16:53:46,843,843 INFO [xla_bridge.py:backends:863] Unable to initialize backend 'cuda': 
[36m(pyu_fn pid=758376)[0m 2024-11-25 16:53:46,843,843 INFO [xla_bridge.py:backends:863] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
[36m(pyu_fn pid=758376)[0m 2024-11-25 16:53:46,844,844 INFO [xla_bridge.py:backends:863] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


Array(5.5957816e+20, dtype=float32)