In [44]:
import secretflow as sf

In [45]:
edge_parties_number = 2
edge_party_name = 'edge_party_{i}'
edge_parties = [edge_party_name.format(i=i) for i in range(edge_parties_number)]
server_party_name = 'server_party'
server_party = [server_party_name]
all_parties = edge_parties + server_party

In [46]:
edge_parties

['edge_party_0', 'edge_party_1']

In [47]:
sf.init(parties=all_parties, address='local')

RuntimeError: Maybe you called ray.init twice by accident? This error can be suppressed by passing in 'ignore_reinit_error=True' or by calling 'ray.shutdown()' prior to 'ray.init()'.

In [None]:
edge_devices = [sf.PYU(edge_party_name.format(i=i)) for i in range(edge_parties_number)]
# use pyu to simulate teeu
edge_tees = [sf.PYU(edge_party_name.format(i=i)) for i in range(edge_parties_number)]

server_device = sf.PYU(server_party_name)
server_tee = sf.PYU(server_party_name)

In [None]:
# custom parameters
i = 0
j = 1
m = 100
kappa = 32
u_low = 0.0
u_high = 2.0
# k is the ring size 2^k. usually take k = 32, 64. like size of int
k = 64  # implies use uint64
fxp = 26  # fixed point precision

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from typing import Tuple
import hashlib

jax.config.update("jax_enable_x64", True)


def bytes_to_jax_random_key(byte_key):
    # Hash the bytes key using SHA-256
    hash_digest = hashlib.sha256(byte_key).digest()

    # Convert the first 4 bytes of the hash digest to a 64-bit integer
    seed = int.from_bytes(hash_digest[:4], 'big')

    # Create a JAX random key with this seed
    jax_key = jax.random.PRNGKey(seed)

    return jax_key


# Function to encrypt a jnp array using AES-GCM
def encrypt_jnp_array_gcm(jnp_array, key) -> Tuple[bytes, bytes, bytes]:

    # Convert numpy array to bytes
    array_bytes = jnp_array.tobytes()

    # Create AES cipher in GCM mode
    cipher = AES.new(key, AES.MODE_GCM)

    # Encrypt data
    ciphertext, tag = cipher.encrypt_and_digest(array_bytes)

    # Return the ciphertext, tag, and nonce
    return ciphertext, tag, cipher.nonce


# Function to decrypt a jnp array using AES-GCM
def decrypt_to_jnp_array_gcm(ciphertext, tag, nonce, key, dtype, shape):
    # Create AES cipher in GCM mode with the same parameters
    cipher = AES.new(key, AES.MODE_GCM, nonce=nonce)

    # Decrypt data
    decrypted_data = cipher.decrypt_and_verify(ciphertext, tag)

    # Convert bytes back to numpy array
    decrypted_jnp_array = jnp.frombuffer(decrypted_data, dtype=dtype).reshape(shape)

    return decrypted_jnp_array


# Example usage
if __name__ == "__main__":
    # Generate a random key for AES-256 (32 bytes)
    key = get_random_bytes(32)

    # Create a jnp array
    original_jnp_array = jnp.uint64(
        jnp.array(np.random.uniform(u_low, u_high, (m,)), dtype=jnp.float64)
    )

    # Encrypt the jnp array using AES-GCM
    ciphertext, tag, nonce = encrypt_jnp_array_gcm(original_jnp_array, key)

    # Decrypt to a jnp array
    decrypted_jnp_array = decrypt_to_jnp_array_gcm(
        ciphertext, tag, nonce, key, dtype=jnp.uint64, shape=original_jnp_array.shape
    )

    # Check if the original and decrypted arrays are the same
    print("Original JAX Array:")
    print(original_jnp_array)
    print(original_jnp_array.dtype)
    print("\nDecrypted JAX Array:")
    print(decrypted_jnp_array)
    print(decrypted_jnp_array.dtype)
    print(
        "\nArrays are equal:", jnp.array_equal(original_jnp_array, decrypted_jnp_array)
    )

Original JAX Array:
[0 1 1 1 1 0 1 0 0 1 1 1 1 0 0 0 0 1 1 1 1 0 0 0 1 0 1 1 0 0 0 1 1 1 1 1 0
 0 1 1 0 0 1 1 1 1 0 1 0 0 0 1 1 1 1 0 0 1 1 1 1 0 1 0 1 1 1 1 0 1 0 1 0 0
 0 1 0 0 1 0 1 0 1 0 1 1 0 1 0 1 1 1 1 0 1 0 0 0 1 1]
uint64

Decrypted JAX Array:
[0 1 1 1 1 0 1 0 0 1 1 1 1 0 0 0 0 1 1 1 1 0 0 0 1 0 1 1 0 0 0 1 1 1 1 1 0
 0 1 1 0 0 1 1 1 1 0 1 0 0 0 1 1 1 1 0 0 1 1 1 1 0 1 0 1 1 1 1 0 1 0 1 0 0
 0 1 0 0 1 0 1 0 1 0 1 1 0 1 0 1 1 1 1 0 1 0 0 0 1 1]
uint64

Arrays are equal: True


In [None]:
a = jnp.uint64(2)
b = a**64 - 1
# check natural overflow
print(b * b)

1


In [None]:
import numpy as np
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


edge_devices[i](lambda: jax.config.update("jax_enable_x64", True))()
edge_devices[j](lambda: jax.config.update("jax_enable_x64", True))()
server_device(lambda: jax.config.update("jax_enable_x64", True))()

edge_tees[i](lambda: jax.config.update("jax_enable_x64", True))()
edge_tees[j](lambda: jax.config.update("jax_enable_x64", True))()
server_tee(lambda: jax.config.update("jax_enable_x64", True))()

# Simulate handles

handle_i_j = edge_tees[i](lambda x: get_random_bytes(x))(kappa)
handle_i_s = edge_tees[i](lambda x: get_random_bytes(x))(kappa)

# note that the establishment is not simplified.
handle_j_i = handle_i_j.to(edge_tees[j])
handle_j_s = edge_tees[j](lambda x: get_random_bytes(x))(kappa)

server_handle_i_s = handle_i_s.to(server_tee)
server_handle_j_s = handle_j_s.to(server_tee)


# P_i holds u_i
u_i = edge_devices[i](lambda x: x)(jnp.array(np.random.uniform(u_low, u_high, (m,))))

# P_j holds u_j
u_j = edge_devices[j](lambda x: x)(jnp.array(np.random.uniform(u_low, u_high, (m,))))


# corr preprocessing
def corr(k, m, dev1, key1, dev2, key2, return_zero_sharing=False):
    """Correlation function

    Args:
        k (int): ring size will be 2^k. Support k = 64 or 128 for now
        m (int): size of array to be correlated
        dev1 (Device): device 1
        key1 (Key): key for device 1
        dev2 (Device): device 2
        key2 (Key): key for device 2, key2 must be the same as key1 yet hold by different device
    """
    assert k == 64, "Only support k = 64 for now"
    dtype = jnp.uint64
    corr_dev1 = dev1(
        lambda key, shape, dtype: jax.random.bits(
            bytes_to_jax_random_key(key), shape, dtype
        )
    )(key1, (m,), dtype)
    if not return_zero_sharing:
        corr_dev2 = dev2(
            lambda key, shape, dtype: jax.random.bits(
                bytes_to_jax_random_key(key), shape, dtype
            )
        )(key2, (m,), dtype)
    else:
        corr_dev2 = dev2(
            lambda key, shape, dtype: -jax.random.bits(
                bytes_to_jax_random_key(key), shape, dtype
            )
        )(key2, (m,), dtype)
    return corr_dev1, corr_dev2


def cos_sim(u_i, u_j):
    # proramming details not related to protocol
    fixed_point_type = jnp.uint64
    shape_ref_i = edge_tees[i](lambda x: x.shape)(u_i.to(edge_tees[i]))
    shape_ref_j = edge_tees[j](lambda x: x.shape)(u_j.to(edge_tees[j]))
    # let's suppose that the reference type and shape are the same for u_i and u_j and it is ok to share to server
    shape_ref_server = shape_ref_i.to(server_tee)

    # preprocessing
    server_a, edge_tee_i_a = corr(
        k, m, server_tee, server_handle_i_s, edge_tees[i], handle_i_s
    )
    server_b, edge_tee_j_b = corr(
        k, m, server_tee, server_handle_j_s, edge_tees[j], handle_j_s
    )
    c = server_tee(lambda a, b: a * b)(server_a, server_b)

    # normalize u_i and u_j
    u_i_normalized = edge_tees[i](
        lambda x: jnp.array(x / jnp.linalg.norm(x) * (2.0**fxp), dtype=fixed_point_type)
    )(u_i.to(edge_tees[i]))
    u_j_normalized = edge_tees[j](
        lambda x: fixed_point_type(x / jnp.linalg.norm(x) * (2.0**fxp))
    )(u_j.to(edge_tees[j]))

    # E_i encrypts e = u_i_normalized - a, sends to P_j via P_i
    e = edge_tees[i](lambda x, y: x - y)(u_i_normalized, edge_tee_i_a)
    c_e, c_e_tag, c_e_nouce = edge_tees[i](encrypt_jnp_array_gcm, num_returns=3)(
        e, handle_i_j
    )
    c_e_j = c_e.to(edge_devices[i]).to(edge_devices[j])
    c_e_tag_j = c_e_tag.to(edge_devices[i]).to(edge_devices[j])
    c_e_nouce_j = c_e_nouce.to(edge_devices[i]).to(edge_devices[j])

    # E_j encrypts f = u_j_normalized - b, sends to P_i via P_j
    f = edge_tees[j](lambda x, y: x - y)(u_j_normalized, edge_tee_j_b)
    c_f, c_f_tag, c_f_nouce = edge_tees[j](encrypt_jnp_array_gcm, num_returns=3)(
        f, handle_j_i
    )

    c_f_i = c_f.to(edge_devices[j]).to(edge_devices[i])
    c_f_tag_i = c_f_tag.to(edge_devices[j]).to(edge_devices[i])
    c_f_nouce_i = c_f_nouce.to(edge_devices[j]).to(edge_devices[i])

    # P_i decrypts f in E_i
    try:
        f_dec = edge_tees[i](decrypt_to_jnp_array_gcm)(
            c_f_i.to(edge_tees[i]),
            c_f_tag_i.to(edge_tees[i]),
            c_f_nouce_i.to(edge_tees[i]),
            handle_i_j,
            fixed_point_type,
            shape_ref_i,
        )
        sf.wait(f_dec)
    except:
        raise RuntimeError("Error in decrypting f, abort")

    # P_i decrypts c in E_j
    try:
        e_dec = edge_tees[j](decrypt_to_jnp_array_gcm)(
            c_e_j.to(edge_tees[j]),
            c_e_tag_j.to(edge_tees[j]),
            c_e_nouce_j.to(edge_tees[j]),
            handle_j_i,
            fixed_point_type,
            shape_ref_j,
        )
        sf.wait(e_dec)
    except:
        raise RuntimeError("Error in decrypting c, abort")

    edge_tee_i_u_i_bracket_1, edge_tee_j_u_i_bracket_1 = corr(
        k, m, edge_tees[i], handle_i_j, edge_tees[j], handle_j_i
    )
    edge_tee_i_u_j_bracket_0, edge_tee_j_u_j_bracket_0 = corr(
        k, m, edge_tees[i], handle_i_j, edge_tees[j], handle_j_i
    )
    edge_tee_i_d, edge_tee_j_d = corr(
        k, m, edge_tees[i], handle_i_j, edge_tees[j], handle_j_i, True
    )

    u_i_bracket_0 = edge_tees[i](lambda x, y: x - y)(u_i, edge_tee_i_u_i_bracket_1)
    u_j_bracket_1 = edge_tees[j](lambda x, y: x - y)(u_j, edge_tee_j_u_j_bracket_0)

    # E_i computes:
    z_bracket_0 = edge_tees[i](
        lambda x1, x2, x3, x4, x5: fixed_point_type(
            fixed_point_type(x1 * (x2 / (2.0**fxp)))
            + fixed_point_type(x3 * (x4 / (2.0**fxp)))
            + fixed_point_type(x1 * (x3 / (2.0**fxp)))
            + x5
        )
    )(e, edge_tee_i_u_j_bracket_0, f_dec, u_i_bracket_0, edge_tee_i_d)

    # E_i encrypts z_bracket_0, sends to server tee via server
    c_z_bracket_0, c_z_bracket_0_tag, c_z_bracket_0_nouce = edge_tees[i](
        encrypt_jnp_array_gcm, num_returns=3
    )(z_bracket_0, handle_i_s)
    c_z_bracket_0_server = c_z_bracket_0.to(server_device).to(server_tee)
    c_z_bracket_0_tag_server = c_z_bracket_0_tag.to(server_device).to(server_tee)
    c_z_bracket_0_nouce_server = c_z_bracket_0_nouce.to(server_device).to(server_tee)

    # E_j computes:
    z_bracket_1 = edge_tees[j](
        lambda x1, x2, x3, x4, x5: fixed_point_type(
            fixed_point_type(x1 * (x2 / (2.0**fxp)))
            + fixed_point_type(x3 * (x4 / (2.0**fxp)))
            + x5
        )
    )(e_dec, u_j_bracket_1, f, edge_tee_j_u_i_bracket_1, edge_tee_j_d)

    # E_j encrypts z_bracket_1, sends to server tee via server
    c_z_bracket_1, c_z_bracket_1_tag, c_z_bracket_1_nouce = edge_tees[j](
        encrypt_jnp_array_gcm, num_returns=3
    )(z_bracket_1, handle_j_s)
    c_z_bracket_1_server = c_z_bracket_1.to(server_device).to(server_tee)
    c_z_bracket_1_tag_server = c_z_bracket_1_tag.to(server_device).to(server_tee)
    c_z_bracket_1_nouce_server = c_z_bracket_1_nouce.to(server_device).to(server_tee)

    # server tries to decrypt
    try:
        z_bracket_0_dec = server_tee(decrypt_to_jnp_array_gcm)(
            c_z_bracket_0_server,
            c_z_bracket_0_tag_server,
            c_z_bracket_0_nouce_server,
            server_handle_i_s,
            fixed_point_type,
            shape_ref_server,
        )
        z_bracket_1_dec = server_tee(decrypt_to_jnp_array_gcm)(
            c_z_bracket_1_server,
            c_z_bracket_1_tag_server,
            c_z_bracket_1_nouce_server,
            server_handle_j_s,
            fixed_point_type,
            shape_ref_server,
        )
        sf.wait([z_bracket_0_dec, z_bracket_1_dec])
    except:
        raise Exception("Decryption z values failed")

    z = server_tee(lambda x, y: x + y)(z_bracket_0_dec, z_bracket_1_dec)
    cos_sim_val = server_tee(
        lambda x, y: jnp.sum(fixed_point_type(x + y) / (2.0**fxp))
    )(z, c)
    return cos_sim_val

In [None]:
sf.reveal(cos_sim(u_i, u_j))

Array(1.39007595e+13, dtype=float64)

In [None]:
# import required libraries
import numpy as np
from numpy.linalg import norm

# define two lists or array
A = sf.reveal(u_i)
B = sf.reveal(u_j)

print("A:", A)
print("B:", B)

# compute cosine similarity
cosine = np.dot(A, B) / (norm(A) * norm(B))
print("Cosine Similarity:", cosine)

A: [0.8623193  0.17956612 1.79077977 1.7438163  1.1873991  1.96022877
 1.51734483 0.22431422 1.82826904 0.51600967 1.6775534  0.29838605
 0.93400558 1.47112485 1.86854324 0.60968533 0.53165953 1.52955898
 1.25030583 1.83944766 1.25028483 0.03067724 0.26707032 1.72205597
 0.23796422 1.50362566 1.54174311 0.83910885 0.26748663 0.14106073
 1.7788103  1.13274733 1.9984689  0.44985267 1.19070499 1.37773989
 1.53284534 0.17171045 0.78381154 0.88692749 0.89875841 1.46863582
 1.47288639 0.58524679 0.02368717 1.88614066 1.89577746 1.90956089
 1.16216615 1.08919253 0.2464347  1.72376545 0.72241741 0.65852687
 1.72958954 0.24340316 1.87089398 0.28033962 0.80961154 1.15908833
 0.45361611 0.70909795 1.90936868 0.98429659 0.61035651 0.18600628
 0.48222722 0.51167043 1.43623533 0.97993358 0.16841544 1.46028435
 1.72604138 1.51471112 0.77945433 1.3880641  0.44306514 0.37067447
 1.90009089 1.00754212 0.17629495 1.23012184 1.63887347 1.18048356
 1.43498631 1.57030576 1.42996397 1.58063728 1.27593223 1.0