In [1]:
import secretflow as sf

In [2]:
edge_parties_number = 2
edge_party_name = 'edge_party_{i}'
server_party_name = 'server_party'

# Party identifiers: `edge_parties_number` edge parties plus a single server party.
edge_parties = list(map(lambda idx: edge_party_name.format(i=idx), range(edge_parties_number)))
server_party = [server_party_name]
all_parties = [*edge_parties, *server_party]

In [3]:
# Display the generated edge-party names (the cell's last expression is shown as its output).
edge_parties

['edge_party_0', 'edge_party_1']

In [4]:
# Start SecretFlow in single-machine simulation mode: every party listed in
# `all_parties` is hosted by one local Ray instance.
sf.init(address='local', parties=all_parties)

  from .autonotebook import tqdm as notebook_tqdm
2024-12-13 17:10:31,087	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2024-12-13 17:10:31,257	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
  self.pid = _posixsubprocess.fork_exec(
2024-12-13 17:10:33,184	INFO worker.py:1724 -- Started a local Ray instance.


In [5]:
_edge_names = [edge_party_name.format(i=n) for n in range(edge_parties_number)]

# Regular (untrusted) device of each edge party.
edge_devices = [sf.PYU(name) for name in _edge_names]
# use pyu to simulate teeu: each party's "TEE" is just another PYU on the same party.
edge_tees = [sf.PYU(name) for name in _edge_names]

# The server likewise gets a regular device and a simulated TEE.
server_device = sf.PYU(server_party_name)
server_tee = sf.PYU(server_party_name)

In [6]:
# --- custom parameters ---
# Indices (into edge_devices / edge_tees) of the two edge parties being compared.
i, j = 0, 1
m = 100  # length of each party's input vector
kappa = 32  # handle length in bytes; handles double as AES-256 keys
# Inputs are drawn uniformly from [u_low, u_high).
u_low, u_high = 0.0, 2.0
# Arithmetic happens in the ring Z_{2^k}; typical choices are k = 32 or 64,
# matching native integer widths.
k = 64  # implies use uint64
fxp = 26  # fixed point precision: a real x is encoded as x * 2^fxp, truncated

In [7]:
import jax
import jax.numpy as jnp
import numpy as np
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from typing import Tuple
import hashlib

jax.config.update("jax_enable_x64", True)


def bytes_to_jax_random_key(byte_key):
    """Derive a JAX PRNG key from a byte-string handle.

    Only the first 4 bytes of ``byte_key`` are read. They are interpreted as a
    big-endian unsigned integer and used as the PRNG seed, so at most 32 bits of
    the handle's entropy reach the generator. Two holders of the same handle
    obtain the same key.
    """
    head = byte_key[:4]
    return jax.random.PRNGKey(int.from_bytes(head, byteorder='big'))


# Function to encrypt a jnp array using AES-GCM
def encrypt_jnp_array_gcm(jnp_array, key) -> Tuple[bytes, bytes, bytes]:
    """Encrypt the raw buffer of an array with AES in GCM mode.

    Only the array's bytes are encrypted. The dtype and shape are not part of the
    ciphertext, so the receiver must supply them when decrypting. ``AES.new`` is
    called without an explicit nonce, so the library picks one for every call.

    Args:
        jnp_array: array exposing ``tobytes()``.
        key: AES key bytes.

    Returns:
        ``(ciphertext, tag, nonce)``.
    """
    payload = jnp_array.tobytes()
    gcm = AES.new(key, AES.MODE_GCM)
    sealed, auth_tag = gcm.encrypt_and_digest(payload)
    return sealed, auth_tag, gcm.nonce


# Function to decrypt a jnp array using AES-GCM
def decrypt_to_jnp_array_gcm(ciphertext, tag, nonce, key, dtype, shape):
    """Inverse of ``encrypt_jnp_array_gcm``: verify, decrypt and rebuild the array.

    Args:
        ciphertext, tag, nonce: outputs of ``encrypt_jnp_array_gcm``.
        key: the same AES key used for encryption.
        dtype: element type of the original array.
        shape: shape of the original array.

    Returns:
        A jnp array with the given dtype and shape.

    Raises:
        Whatever ``decrypt_and_verify`` raises when the tag does not authenticate
        (pycryptodome documents ``ValueError``).
    """
    plaintext = AES.new(key, AES.MODE_GCM, nonce=nonce).decrypt_and_verify(ciphertext, tag)
    flat = jnp.frombuffer(plaintext, dtype=dtype)
    return flat.reshape(shape)


# Example usage: round-trip a random uint64 array through AES-GCM and check it survives.
if __name__ == "__main__":
    # AES-256 needs a 32-byte key
    key = get_random_bytes(32)

    # m draws from [u_low, u_high), truncated to uint64 (so the entries are small integers)
    _samples = jnp.array(np.random.uniform(u_low, u_high, (m,)), dtype=jnp.float64)
    original_jnp_array = jnp.uint64(_samples)

    ciphertext, tag, nonce = encrypt_jnp_array_gcm(original_jnp_array, key)
    decrypted_jnp_array = decrypt_to_jnp_array_gcm(
        ciphertext,
        tag,
        nonce,
        key,
        dtype=jnp.uint64,
        shape=original_jnp_array.shape,
    )

    # Show both arrays with their dtypes, then confirm they match.
    for _title, _arr in (
        ("Original JAX Array:", original_jnp_array),
        ("\nDecrypted JAX Array:", decrypted_jnp_array),
    ):
        print(_title)
        print(_arr)
        print(_arr.dtype)
    print("\nArrays are equal:", jnp.array_equal(original_jnp_array, decrypted_jnp_array))



Original JAX Array:
[0 1 0 0 0 0 1 0 1 1 1 0 1 1 0 0 1 0 1 0 1 1 1 1 0 1 0 0 1 1 1 0 1 1 0 1 0
 1 0 1 0 0 0 0 1 0 1 1 0 1 1 0 1 1 0 0 0 1 0 1 0 0 1 0 1 1 1 1 1 1 0 1 0 1
 1 0 0 1 0 0 1 0 0 1 0 1 1 1 1 1 0 1 0 1 0 1 0 1 0 1]
uint64

Decrypted JAX Array:
[0 1 0 0 0 0 1 0 1 1 1 0 1 1 0 0 1 0 1 0 1 1 1 1 0 1 0 0 1 1 1 0 1 1 0 1 0
 1 0 1 0 0 0 0 1 0 1 1 0 1 1 0 1 1 0 0 0 1 0 1 0 0 1 0 1 1 1 1 1 1 0 1 0 1
 1 0 0 1 0 0 1 0 0 1 0 1 1 1 1 1 0 1 0 1 0 1 0 1 0 1]
uint64

Arrays are equal: True


In [8]:
def fxp_mul(a, b):
    """Multiply two fixed-point encodings *without* rescaling.

    Each factor carries ``fxp`` fractional bits, so the product carries ``2*fxp``.
    No truncation happens here; callers divide by ``2**(2*fxp)`` when decoding.
    For unsigned ring types (e.g. uint64 arrays) the product wraps around.
    """
    product = a * b
    return product


# Fixed-point demo: why raw uint64 products of fixed-point encodings need care.
# Reuses the precision `fxp` from the parameter cell instead of re-assigning the
# global here, so changing `fxp` there is reflected in this demo.
fxp_type = jnp.uint64
a = 1000
b = 1000
print("float mul a * b", a * b)
a_fxp = fxp_type(a * 2.0**fxp)
b_fxp = fxp_type(b * 2.0**fxp)

print(a_fxp)
print(b_fxp)
# For a = b = 1000 and fxp = 26 the exact product is about 2^72, which exceeds 64 bits,
# so the uint64 product wraps modulo 2^64.
print("fxp mul without scaling: ", a_fxp * b_fxp)

# The same product in float64: no wrap-around, but it still carries 2*fxp fractional bits
# (it has not been rescaled).
print("fxp mul as float (no wrap, not rescaled): ", a_fxp * 1.0 * (b_fxp * 1.0))

# Rescaling the *wrapped* product by 2^fxp does not recover a * b * 2^fxp: the
# overflow has already discarded the high bits.
print("fxp mul rescaled by 2^fxp (after wrap): ", fxp_mul(a_fxp, b_fxp) / 2.0 ** (fxp))

float mul a * b 1000000
67108864000
67108864000
fxp mul without scaling:  2594073385365405696
fxp mul with scaling but float:  4.503599627370496e+21
38654705664.0


In [9]:
import numpy as np
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)

# PYU functions presumably run in Ray worker processes rather than in this kernel
# (see the pid-tagged `pyu_fn` logs in the cos_sim cell output), so 64-bit mode is
# switched on remotely for every device as well.
_x64_devices = (
    edge_devices[i],
    edge_devices[j],
    server_device,
    edge_tees[i],
    edge_tees[j],
    server_tee,
)
set_ups = [dev(lambda: jax.config.update("jax_enable_x64", True))() for dev in _x64_devices]
sf.wait(set_ups)

# Simulated handle establishment (simplified): E_i samples the pairwise handles it
# originates, and the i<->j handle is simply copied to E_j.
handle_i_j = edge_tees[i](get_random_bytes)(kappa)
handle_i_s = edge_tees[i](get_random_bytes)(kappa)

handle_j_i = handle_i_j.to(edge_tees[j])
handle_j_s = edge_tees[j](get_random_bytes)(kappa)

# The server TEE receives copies of the edge<->server handles.
server_handle_i_s = handle_i_s.to(server_tee)
server_handle_j_s = handle_j_s.to(server_tee)

In [10]:
# Private inputs. A fixed seed makes "Restart & Run All" reproduce the outputs of
# the later cells (the secure result and the plaintext check both use u_i, u_j).
input_rng = np.random.default_rng(0)

# P_i holds u_i
u_i = edge_devices[i](lambda x: x)(jnp.array(input_rng.uniform(u_low, u_high, (m,))))

# P_j holds u_j
u_j = edge_devices[j](lambda x: x)(jnp.array(input_rng.uniform(u_low, u_high, (m,))))


# corr preprocessing
def corr(k, m, dev1, key1, dev2, key2, return_zero_sharing=False):
    """Generate correlated random vectors on two devices from a shared handle.

    Each device seeds a JAX PRNG from its copy of the handle (via
    ``bytes_to_jax_random_key``) and draws ``m`` random uint64 words, so both
    sides obtain the same vector ``r`` without communicating.

    Args:
        k (int): ring size will be 2^k. Only k = 64 is supported for now.
        m (int): length of the vectors to generate.
        dev1 (Device): first device.
        key1 (bytes): handle held by ``dev1``.
        dev2 (Device): second device.
        key2 (bytes): handle held by ``dev2``; must equal ``key1`` (same bytes,
            different holder).
        return_zero_sharing (bool): if True, ``dev2`` receives ``-r mod 2^64``
            instead of ``r``, so the two outputs sum to zero in the ring.

    Returns:
        ``(r on dev1, r or -r on dev2)`` as uint64 vectors of shape ``(m,)``.

    Raises:
        AssertionError: if ``k != 64``.
    """
    assert k == 64, "Only support k = 64 for now"
    dtype = jnp.uint64

    def _draw(key, shape, dtype):
        # Request the ring dtype explicitly, so the masks always cover all 64 bits
        # instead of depending on jax.random.bits' default width.
        return jax.random.bits(bytes_to_jax_random_key(key), shape, dtype=dtype)

    def _draw_negated(key, shape, dtype):
        # Unsigned negation wraps: r + (-r) == 0 mod 2^64.
        return -_draw(key, shape, dtype)

    corr_dev1 = dev1(_draw)(key1, (m,), dtype)
    corr_dev2 = dev2(_draw_negated if return_zero_sharing else _draw)(key2, (m,), dtype)
    return corr_dev1, corr_dev2


def _relay_encrypted(src_tee, value, key, hops):
    """Encrypt ``value`` with AES-GCM inside ``src_tee`` and forward the result.

    Args:
        src_tee: device on which ``encrypt_jnp_array_gcm`` runs.
        value: device object to encrypt.
        key: AES key handle located on ``src_tee``.
        hops: devices the ciphertext, tag and nonce are moved through, in order.

    Returns:
        ``(ciphertext, tag, nonce)`` located on ``hops[-1]``.
    """
    relayed = []
    for part in src_tee(encrypt_jnp_array_gcm, num_returns=3)(value, key):
        for hop in hops:
            part = part.to(hop)
        relayed.append(part)
    return tuple(relayed)


def _decrypt_or_abort(tee, ciphertext, tag, nonce, key, dtype, shape_ref, error_msg):
    """Decrypt an AES-GCM payload inside ``tee`` and wait for the result.

    Any failure is re-raised as ``RuntimeError(error_msg)``, with the original
    exception chained as its cause.
    """
    try:
        plain = tee(decrypt_to_jnp_array_gcm)(ciphertext, tag, nonce, key, dtype, shape_ref)
        sf.wait(plain)
    except Exception as exc:
        raise RuntimeError(error_msg) from exc
    return plain


def _derive_handle(tee, handle, label):
    """Derive the sub-handle ``SHA-256(handle || label)`` inside ``tee``.

    Holders of the same ``handle`` using the same ``label`` obtain the same
    32 bytes, while different labels give unrelated values. This gives each
    correlation its own PRNG seed, instead of several correlations sharing one.
    """
    return tee(lambda h, lbl: hashlib.sha256(h + lbl).digest())(handle, label)


def cos_sim(u_i, u_j, verbose=False):
    """Securely compute the cosine similarity of ``u_i`` (party i) and ``u_j`` (party j).

    Both vectors are normalized and fixed-point encoded (``fxp`` fractional bits)
    inside the parties' TEEs. Their inner product is then computed in the ring
    Z_{2^64}, using a Beaver-style triple ``(a, b, c = a * b)`` held by the
    server TEE. The server TEE receives the encrypted shares ``z_0`` and ``z_1``
    and returns ``sum(z_0 + z_1 + c)``.

    Relies on the notebook globals ``i``, ``j``, ``k``, ``m`` and ``fxp``, the
    device lists, and the pre-established handles.

    Args:
        u_i: PYU object holding party i's vector (length ``m``).
        u_j: PYU object holding party j's vector (length ``m``).
        verbose: if True, reveal and print every intermediate value. This is for
            debugging only and defeats the protocol's privacy.

    Returns:
        Server-TEE object holding a uint64 scalar equal to
        ``<u_i/|u_i|, u_j/|u_j|> * 2**(2*fxp)`` modulo 2**64. Divide by
        ``2.0**(2*fxp)`` to decode. If the cosine may be negative, reinterpret the
        value as int64 before dividing.

    Raises:
        RuntimeError: if any AES-GCM payload fails to decrypt or authenticate.
    """
    # programming details not related to protocol
    if verbose:
        print("input: ", sf.reveal(u_i), sf.reveal(u_j))

    fxp_type = jnp.uint64
    shape_ref_i = edge_tees[i](lambda x: x.shape)(u_i.to(edge_tees[i]))
    shape_ref_j = edge_tees[j](lambda x: x.shape)(u_j.to(edge_tees[j]))
    # assumes u_i and u_j share dtype and shape, and that revealing the shape to the server is acceptable
    shape_ref_server = shape_ref_i.to(server_tee)

    # --- preprocessing ---
    # Beaver-style triple: the server TEE holds a, b and c = a * b; E_i holds a and E_j holds b.
    # Every correlation in this function is seeded from its own derived sub-handle,
    # so no two masks come from the same PRNG stream.
    server_a, edge_tee_i_a = corr(
        k,
        m,
        server_tee,
        _derive_handle(server_tee, server_handle_i_s, b"corr/a"),
        edge_tees[i],
        _derive_handle(edge_tees[i], handle_i_s, b"corr/a"),
    )
    server_b, edge_tee_j_b = corr(
        k,
        m,
        server_tee,
        _derive_handle(server_tee, server_handle_j_s, b"corr/b"),
        edge_tees[j],
        _derive_handle(edge_tees[j], handle_j_s, b"corr/b"),
    )
    c = server_tee(lambda a, b: a * b)(server_a, server_b)
    if verbose:
        print()
        print("preprocessing: ")
        print("server_a: ", sf.reveal(server_a))
        print("server_b: ", sf.reveal(server_b))

        print("edge_tee_i_a: ", sf.reveal(edge_tee_i_a))
        print("edge_tee_j_b: ", sf.reveal(edge_tee_j_b))
        print("c: ", sf.reveal(c))

    # --- step 1: normalize and fixed-point encode inside each TEE ---
    def _encode(x):
        # Scale the unit vector by 2^fxp and truncate. Casting through int64 maps a
        # negative entry to its two's-complement ring element mod 2^64; a direct
        # float -> uint64 cast of a negative value is implementation-dependent.
        scaled = x / jnp.linalg.norm(x) * (2.0**fxp)
        return jnp.asarray(scaled).astype(jnp.int64).astype(fxp_type)

    u_i_normalized = edge_tees[i](_encode)(u_i.to(edge_tees[i]))
    u_j_normalized = edge_tees[j](_encode)(u_j.to(edge_tees[j]))

    if verbose:
        print()
        print("step 1:")
        print("fxp", fxp)
        print("u_i_normalized", sf.reveal(u_i_normalized))
        print("u_j_normalized", sf.reveal(u_j_normalized))

    # --- step 2: E_i masks e = u_i_normalized - a and sends it to E_j via P_i -> P_j ---
    e = edge_tees[i](lambda x, y: x - y)(u_i_normalized, edge_tee_i_a)

    if verbose:
        print()
        print("step 2:")
        print("a: ", sf.reveal(edge_tee_i_a))
        print("e = u_i_normalized - a: ", sf.reveal(e))

    c_e_j, c_e_tag_j, c_e_nonce_j = _relay_encrypted(
        edge_tees[i], e, handle_i_j, [edge_devices[i], edge_devices[j], edge_tees[j]]
    )

    # --- step 3: E_j masks f = u_j_normalized - b and sends it to E_i via P_j -> P_i ---
    f = edge_tees[j](lambda x, y: x - y)(u_j_normalized, edge_tee_j_b)

    if verbose:
        print()
        print("step 3:")
        print("u_j_normalized: ", sf.reveal(u_j_normalized))
        print("b: ", sf.reveal(edge_tee_j_b))
        print("f = u_j_normalized - b: ", sf.reveal(f))

    c_f_i, c_f_tag_i, c_f_nonce_i = _relay_encrypted(
        edge_tees[j], f, handle_j_i, [edge_devices[j], edge_devices[i], edge_tees[i]]
    )

    # --- steps 4-5: each TEE decrypts the other side's masked input; abort on failure ---
    # P_i decrypts f in E_i
    f_dec = _decrypt_or_abort(
        edge_tees[i], c_f_i, c_f_tag_i, c_f_nonce_i,
        handle_i_j, fxp_type, shape_ref_i,
        "Error in decrypting f, abort",
    )
    # P_j decrypts e in E_j
    e_dec = _decrypt_or_abort(
        edge_tees[j], c_e_j, c_e_tag_j, c_e_nonce_j,
        handle_j_i, fxp_type, shape_ref_j,
        "Error in decrypting e, abort",
    )

    # --- step 6: E_i/E_j split a = a_0 + a_1 and b = b_0 + b_1 and share a zero d_0 + d_1 = 0 ---
    # a_1 and b_0 are known to both edge TEEs, so E_i can form a_0 = a - a_1 and
    # E_j can form b_1 = b - b_0.
    edge_tee_i_a_1, edge_tee_j_a_1 = corr(
        k,
        m,
        edge_tees[i],
        _derive_handle(edge_tees[i], handle_i_j, b"corr/a_1"),
        edge_tees[j],
        _derive_handle(edge_tees[j], handle_j_i, b"corr/a_1"),
    )
    edge_tee_i_b_0, edge_tee_j_b_0 = corr(
        k,
        m,
        edge_tees[i],
        _derive_handle(edge_tees[i], handle_i_j, b"corr/b_0"),
        edge_tees[j],
        _derive_handle(edge_tees[j], handle_j_i, b"corr/b_0"),
    )
    edge_tee_i_d, edge_tee_j_d = corr(
        k,
        m,
        edge_tees[i],
        _derive_handle(edge_tees[i], handle_i_j, b"corr/d"),
        edge_tees[j],
        _derive_handle(edge_tees[j], handle_j_i, b"corr/d"),
        True,
    )

    edge_tee_i_a_0 = edge_tees[i](lambda x, y: fxp_type(x - y))(edge_tee_i_a, edge_tee_i_a_1)
    edge_tee_j_b_1 = edge_tees[j](lambda x, y: fxp_type(x - y))(edge_tee_j_b, edge_tee_j_b_0)

    if verbose:
        print()
        print("step 6:")
        print("a_0: ", sf.reveal(edge_tee_i_a_0))
        print("b_1: ", sf.reveal(edge_tee_j_b_1))
        print("a_1: ", sf.reveal(edge_tee_i_a_1))
        print("b_0: ", sf.reveal(edge_tee_j_b_0))
        print("d_0: ", sf.reveal(edge_tee_i_d))
        print("d_1: ", sf.reveal(edge_tee_j_d))

    # --- step 7: E_i computes its share z_0 = e*b_0 + f*a_0 + e*f + d_0 (mod 2^64) ---
    def _z_share_0(e_, b_0, f_, a_0, d_0):
        return (
            fxp_type(fxp_mul(e_, b_0))
            + fxp_type(fxp_mul(f_, a_0))
            + fxp_type(fxp_mul(e_, f_))
            + fxp_type(d_0)
        )

    z_bracket_0 = edge_tees[i](_z_share_0)(e, edge_tee_i_b_0, f_dec, edge_tee_i_a_0, edge_tee_i_d)

    if verbose:
        print()
        print("step 7:")
        print("e", sf.reveal(e))
        print("f_dec", sf.reveal(f_dec))
        print("edge_tee_i_d", sf.reveal(edge_tee_i_d))
        print("z_bracket_0", sf.reveal(z_bracket_0))

    # E_i encrypts z_0 and sends it to the server TEE via the server device.
    c_z_bracket_0_server, c_z_bracket_0_tag_server, c_z_bracket_0_nonce_server = _relay_encrypted(
        edge_tees[i], z_bracket_0, handle_i_s, [server_device, server_tee]
    )

    # --- step 8: E_j computes its share z_1 = e*b_1 + f*a_1 + d_1 (mod 2^64) ---
    def _z_share_1(e_, b_1, f_, a_1, d_1):
        return fxp_type(
            fxp_type(fxp_mul(e_, b_1))
            + fxp_type(fxp_mul(f_, a_1))
            + fxp_type(d_1)
        )

    z_bracket_1 = edge_tees[j](_z_share_1)(e_dec, edge_tee_j_b_1, f, edge_tee_j_a_1, edge_tee_j_d)

    # E_j encrypts z_1 and sends it to the server TEE via the server device.
    c_z_bracket_1_server, c_z_bracket_1_tag_server, c_z_bracket_1_nonce_server = _relay_encrypted(
        edge_tees[j], z_bracket_1, handle_j_s, [server_device, server_tee]
    )

    if verbose:
        print()
        print("step 8:")
        print("e_dec", sf.reveal(e_dec))
        print("f", sf.reveal(f))
        print("edge_tee_j_d", sf.reveal(edge_tee_j_d))
        print("z_bracket_1", sf.reveal(z_bracket_1))

    # --- step 9: the server TEE decrypts both shares ---
    z_bracket_0_dec = _decrypt_or_abort(
        server_tee, c_z_bracket_0_server, c_z_bracket_0_tag_server, c_z_bracket_0_nonce_server,
        server_handle_i_s, fxp_type, shape_ref_server,
        "Decryption z values failed",
    )
    z_bracket_1_dec = _decrypt_or_abort(
        server_tee, c_z_bracket_1_server, c_z_bracket_1_tag_server, c_z_bracket_1_nonce_server,
        server_handle_j_s, fxp_type, shape_ref_server,
        "Decryption z values failed",
    )

    # --- step 10: z = z_0 + z_1 = e*b + f*a + e*f, hence
    # z + c = (e + a) * (f + b) = u_i_normalized * u_j_normalized (elementwise, mod 2^64).
    z = server_tee(lambda x, y: fxp_type(x + y))(z_bracket_0_dec, z_bracket_1_dec)
    cos_sim_val = server_tee(
        lambda x, y: jnp.sum(fxp_type(x + y))
    )(z, c)

    if verbose:
        print()
        print("step 10:")
        print("z_bracket_0_dec", sf.reveal(z_bracket_0_dec))
        print("z_bracket_1_dec", sf.reveal(z_bracket_1_dec))
        print("z", sf.reveal(z))
        print("c", sf.reveal(c))
        print("cos_sim_val", sf.reveal(cos_sim_val))
        print("fxp", fxp)
        print("cos_sim_val / 2^(2*fxp)", sf.reveal(cos_sim_val) / (2.**(2*fxp)))
    return cos_sim_val

In [11]:
# Decode: both inputs are encoded with 2^fxp, so each elementwise product (and the ring sum
# returned by cos_sim) carries a 2^(2*fxp) scale; dividing yields the cosine similarity as a float.
sf.reveal(cos_sim(u_i, u_j)) / (2.**(2*fxp))

[36m(pyu_fn pid=1844138)[0m 2024-12-13 17:10:35,447,447 INFO [xla_bridge.py:backends:863] Unable to initialize backend 'cuda': 
[36m(pyu_fn pid=1844138)[0m 2024-12-13 17:10:35,448,448 INFO [xla_bridge.py:backends:863] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
[36m(pyu_fn pid=1844138)[0m 2024-12-13 17:10:35,448,448 INFO [xla_bridge.py:backends:863] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


Array(0.73825558, dtype=float64, weak_type=True)

In [12]:
# import required libraries
import numpy as np
from numpy.linalg import norm

# Plaintext reference: cosine similarity computed directly on the revealed inputs,
# for comparison with the secure result above.
u_i_revealed = sf.reveal(u_i)
u_j_revealed = sf.reveal(u_j)
A = u_i_revealed / norm(u_i_revealed)
B = u_j_revealed / norm(u_j_revealed)

# A and B are unit vectors, so their plain dot product already equals the cosine
# similarity; the textbook formula dot / (|A| |B|) must agree with it.
cosine = np.dot(A, B) / (norm(A) * norm(B))
cosine2 = jnp.sum(A * B)
np.testing.assert_allclose(float(cosine2), float(cosine), rtol=1e-12)
print("Cosine Similarity:", cosine2)

Cosine Similarity: 0.7382557166568481
