In [1]:
import secretflow as sf

In [2]:
edge_parties_number = 2
edge_party_name = 'edge_party_{i}'
edge_parties = [edge_party_name.format(i=i) for i in range(edge_parties_number)]
server_party_name = 'server_party'
server_party = [server_party_name]
all_parties = edge_parties + server_party

In [3]:
edge_parties

['edge_party_0', 'edge_party_1']

In [4]:
sf.init(parties=all_parties, address='local')

  from .autonotebook import tqdm as notebook_tqdm
2024-11-27 15:26:25,302	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2024-11-27 15:26:25,471	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
  self.pid = _posixsubprocess.fork_exec(
2024-11-27 15:26:27,376	INFO worker.py:1724 -- Started a local Ray instance.


In [5]:
edge_devices = [sf.PYU(edge_party_name.format(i=i)) for i in range(edge_parties_number)]
# use pyu to simulate teeu
edge_tees = [sf.PYU(edge_party_name.format(i=i)) for i in range(edge_parties_number)]

server_device = sf.PYU(server_party_name)
server_tee = sf.PYU(server_party_name)

In [6]:
# custom parameters
i = 0
j = 1
m = 100
kappa = 32
u_low = 0.0
u_high = 2.0
# k is the ring size 2^k. usually take k = 32, 64. like size of int
k = 64  # implies use uint64
fxp = 26  # fixed point precision

In [7]:
import jax
import jax.numpy as jnp
import numpy as np
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from typing import Tuple
import hashlib

jax.config.update("jax_enable_x64", True)


def bytes_to_jax_random_key(byte_key):
    # Hash the bytes key using SHA-256
    hash_digest = hashlib.sha256(byte_key).digest()

    # Convert the first 4 bytes of the hash digest to a 64-bit integer
    seed = int.from_bytes(hash_digest[:4], 'big')

    # Create a JAX random key with this seed
    jax_key = jax.random.PRNGKey(seed)

    return jax_key


# Function to encrypt a jnp array using AES-GCM
def encrypt_jnp_array_gcm(jnp_array, key) -> Tuple[bytes, bytes, bytes]:

    # Convert numpy array to bytes
    array_bytes = jnp_array.tobytes()

    # Create AES cipher in GCM mode
    cipher = AES.new(key, AES.MODE_GCM)

    # Encrypt data
    ciphertext, tag = cipher.encrypt_and_digest(array_bytes)

    # Return the ciphertext, tag, and nonce
    return ciphertext, tag, cipher.nonce


# Function to decrypt a jnp array using AES-GCM
def decrypt_to_jnp_array_gcm(ciphertext, tag, nonce, key, dtype, shape):
    # Create AES cipher in GCM mode with the same parameters
    cipher = AES.new(key, AES.MODE_GCM, nonce=nonce)

    # Decrypt data
    decrypted_data = cipher.decrypt_and_verify(ciphertext, tag)

    # Convert bytes back to numpy array
    decrypted_jnp_array = jnp.frombuffer(decrypted_data, dtype=dtype).reshape(shape)

    return decrypted_jnp_array


# Example usage
if __name__ == "__main__":
    # Generate a random key for AES-256 (32 bytes)
    key = get_random_bytes(32)

    # Create a jnp array
    original_jnp_array = jnp.uint64(
        jnp.array(np.random.uniform(u_low, u_high, (m,)), dtype=jnp.float64)
    )

    # Encrypt the jnp array using AES-GCM
    ciphertext, tag, nonce = encrypt_jnp_array_gcm(original_jnp_array, key)

    # Decrypt to a jnp array
    decrypted_jnp_array = decrypt_to_jnp_array_gcm(
        ciphertext, tag, nonce, key, dtype=jnp.uint64, shape=original_jnp_array.shape
    )

    # Check if the original and decrypted arrays are the same
    print("Original JAX Array:")
    print(original_jnp_array)
    print(original_jnp_array.dtype)
    print("\nDecrypted JAX Array:")
    print(decrypted_jnp_array)
    print(decrypted_jnp_array.dtype)
    print(
        "\nArrays are equal:", jnp.array_equal(original_jnp_array, decrypted_jnp_array)
    )



Original JAX Array:
[0 1 1 0 1 0 0 1 0 0 1 0 1 1 0 1 0 1 1 1 1 1 1 0 0 0 0 1 1 0 1 1 0 1 1 1 0
 1 1 0 0 1 1 0 0 0 0 0 0 0 0 0 1 0 0 0 1 1 0 0 0 0 1 0 1 1 0 0 1 1 0 0 0 0
 0 1 1 0 0 0 1 0 0 0 0 0 1 1 1 1 1 1 1 0 1 0 0 0 1 1]
uint64

Decrypted JAX Array:
[0 1 1 0 1 0 0 1 0 0 1 0 1 1 0 1 0 1 1 1 1 1 1 0 0 0 0 1 1 0 1 1 0 1 1 1 0
 1 1 0 0 1 1 0 0 0 0 0 0 0 0 0 1 0 0 0 1 1 0 0 0 0 1 0 1 1 0 0 1 1 0 0 0 0
 0 1 1 0 0 0 1 0 0 0 0 0 1 1 1 1 1 1 1 0 1 0 0 0 1 1]
uint64

Arrays are equal: True


In [8]:

def fxp_mul(a, b, fxp, fxp_type):
    return fxp_type(a * (b / 2.**fxp))

fxp_type = jnp.uint64
a = 11923.1
b = 34712.9
fxp = 26
print(a * b)
a_fxp = fxp_type(a * 2. ** fxp)
b_fxp = fxp_type(b * 2. ** fxp)

print(fxp_mul(a_fxp, b_fxp, fxp, fxp_type) / 2. ** fxp)



413885377.99
413885377.9896865


In [29]:
a = jnp.uint64(0)
b = a - 1
# check natural overflow
print(b)
print(-b)
print(b - b)

18446744073709551615
1
0


In [None]:
import numpy as np
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


set_ups = [
    edge_devices[i](lambda: jax.config.update("jax_enable_x64", True))(),
    edge_devices[j](lambda: jax.config.update("jax_enable_x64", True))(),
    server_device(lambda: jax.config.update("jax_enable_x64", True))(),
    edge_tees[i](lambda: jax.config.update("jax_enable_x64", True))(),
    edge_tees[j](lambda: jax.config.update("jax_enable_x64", True))(),
    server_tee(lambda: jax.config.update("jax_enable_x64", True))(),
]
sf.wait(set_ups)

# Simulate handles

handle_i_j = edge_tees[i](lambda x: get_random_bytes(x))(kappa)
handle_i_s = edge_tees[i](lambda x: get_random_bytes(x))(kappa)

# note that the establishment is not simplified.
handle_j_i = handle_i_j.to(edge_tees[j])
handle_j_s = edge_tees[j](lambda x: get_random_bytes(x))(kappa)

server_handle_i_s = handle_i_s.to(server_tee)
server_handle_j_s = handle_j_s.to(server_tee)


# P_i holds u_i
u_i = edge_devices[i](lambda x: x)(jnp.array(np.random.uniform(u_low, u_high, (m,))))

# P_j holds u_j
u_j = edge_devices[j](lambda x: x)(jnp.array(np.random.uniform(u_low, u_high, (m,))))


# corr preprocessing
def corr(k, m, dev1, key1, dev2, key2, return_zero_sharing=False):
    """Correlation function

    Args:
        k (int): ring size will be 2^k. Support k = 64 or 128 for now
        m (int): size of array to be correlated
        dev1 (Device): device 1
        key1 (Key): key for device 1
        dev2 (Device): device 2
        key2 (Key): key for device 2, key2 must be the same as key1 yet hold by different device
    """
    assert k == 64, "Only support k = 64 for now"
    dtype = jnp.uint64
    corr_dev1 = dev1(
        lambda key, shape, dtype: jax.random.bits(
            bytes_to_jax_random_key(key), shape, dtype
        )
    )(key1, (m,), dtype)
    if not return_zero_sharing:
        corr_dev2 = dev2(
            lambda key, shape, dtype: jax.random.bits(
                bytes_to_jax_random_key(key), shape, dtype
            )
        )(key2, (m,), dtype)
    else:
        corr_dev2 = dev2(
            lambda key, shape, dtype: -jax.random.bits(
                bytes_to_jax_random_key(key), shape, dtype
            )
        )(key2, (m,), dtype)
    return corr_dev1, corr_dev2


def cos_sim(u_i, u_j):
    # proramming details not related to protocol
    fixed_point_type = jnp.uint64
    shape_ref_i = edge_tees[i](lambda x: x.shape)(u_i.to(edge_tees[i]))
    shape_ref_j = edge_tees[j](lambda x: x.shape)(u_j.to(edge_tees[j]))
    # let's suppose that the reference type and shape are the same for u_i and u_j and it is ok to share to server
    shape_ref_server = shape_ref_i.to(server_tee)

    # preprocessing
    server_a, edge_tee_i_a = corr(
        k, m, server_tee, server_handle_i_s, edge_tees[i], handle_i_s
    )
    server_b, edge_tee_j_b = corr(
        k, m, server_tee, server_handle_j_s, edge_tees[j], handle_j_s
    )
    c = server_tee(lambda a, b: a * b)(server_a, server_b)

    # normalize u_i and u_j
    u_i_normalized = edge_tees[i](
        lambda x: jnp.array(x / jnp.linalg.norm(x) * (2.0**fxp), dtype=fixed_point_type)
    )(u_i.to(edge_tees[i]))
    u_j_normalized = edge_tees[j](
        lambda x: fixed_point_type(x / jnp.linalg.norm(x) * (2.0**fxp))
    )(u_j.to(edge_tees[j]))
    

    # E_i encrypts e = u_i_normalized - a, sends to P_j via P_i
    e = edge_tees[i](lambda x, y: x - y)(u_i_normalized, edge_tee_i_a)
    c_e, c_e_tag, c_e_nouce = edge_tees[i](encrypt_jnp_array_gcm, num_returns=3)(
        e, handle_i_j
    )
    c_e_j = c_e.to(edge_devices[i]).to(edge_devices[j])
    c_e_tag_j = c_e_tag.to(edge_devices[i]).to(edge_devices[j])
    c_e_nouce_j = c_e_nouce.to(edge_devices[i]).to(edge_devices[j])

    # E_j encrypts f = u_j_normalized - b, sends to P_i via P_j
    f = edge_tees[j](lambda x, y: x - y)(u_j_normalized, edge_tee_j_b)
    c_f, c_f_tag, c_f_nouce = edge_tees[j](encrypt_jnp_array_gcm, num_returns=3)(
        f, handle_j_i
    )

    c_f_i = c_f.to(edge_devices[j]).to(edge_devices[i])
    c_f_tag_i = c_f_tag.to(edge_devices[j]).to(edge_devices[i])
    c_f_nouce_i = c_f_nouce.to(edge_devices[j]).to(edge_devices[i])

    # P_i decrypts f in E_i
    try:
        f_dec = edge_tees[i](decrypt_to_jnp_array_gcm)(
            c_f_i.to(edge_tees[i]),
            c_f_tag_i.to(edge_tees[i]),
            c_f_nouce_i.to(edge_tees[i]),
            handle_i_j,
            fixed_point_type,
            shape_ref_i,
        )
        sf.wait(f_dec)
    except:
        raise RuntimeError("Error in decrypting f, abort")

    
    # P_i decrypts c in E_j
    try:
        e_dec = edge_tees[j](decrypt_to_jnp_array_gcm)(
            c_e_j.to(edge_tees[j]),
            c_e_tag_j.to(edge_tees[j]),
            c_e_nouce_j.to(edge_tees[j]),
            handle_j_i,
            fixed_point_type,
            shape_ref_j,
        )
        sf.wait(e_dec)
    except:
        raise RuntimeError("Error in decrypting c, abort")

    edge_tee_i_u_i_bracket_1, edge_tee_j_u_i_bracket_1 = corr(
        k, m, edge_tees[i], handle_i_j, edge_tees[j], handle_j_i
    )
    edge_tee_i_u_j_bracket_0, edge_tee_j_u_j_bracket_0 = corr(
        k, m, edge_tees[i], handle_i_j, edge_tees[j], handle_j_i
    )
    edge_tee_i_d, edge_tee_j_d = corr(
        k, m, edge_tees[i], handle_i_j, edge_tees[j], handle_j_i, True
    )
    
    # lots of question here
    u_i_bracket_0 = edge_tees[i](lambda x, y: x - y)(u_i_normalized, edge_tee_i_u_i_bracket_1)
    u_j_bracket_1 = edge_tees[j](lambda x, y: x - y)(u_j_normalized, edge_tee_j_u_j_bracket_0)
    
    print("[u_i]_0", sf.reveal(u_i_bracket_0))
    
    # E_i computes:
    z_bracket_0 = edge_tees[i](
        lambda x1, x2, x3, x4, x5: fixed_point_type(
            fxp_mul(x1, x2, fxp, fixed_point_type)
            + fxp_mul(x3, x4, fxp, fixed_point_type)
            + fxp_mul(x1, x3, fxp, fixed_point_type)
            + x5
        )
    )(e, edge_tee_i_u_j_bracket_0, f_dec, u_i_bracket_0, edge_tee_i_d)

    # E_i encrypts z_bracket_0, sends to server tee via server
    c_z_bracket_0, c_z_bracket_0_tag, c_z_bracket_0_nouce = edge_tees[i](
        encrypt_jnp_array_gcm, num_returns=3
    )(z_bracket_0, handle_i_s)
    c_z_bracket_0_server = c_z_bracket_0.to(server_device).to(server_tee)
    c_z_bracket_0_tag_server = c_z_bracket_0_tag.to(server_device).to(server_tee)
    c_z_bracket_0_nouce_server = c_z_bracket_0_nouce.to(server_device).to(server_tee)

    # E_j computes:
    z_bracket_1 = edge_tees[j](
        lambda x1, x2, x3, x4, x5: fixed_point_type(
            fxp_mul(x1, x2, fxp, fixed_point_type)
            + fxp_mul(x3, x4, fxp, fixed_point_type)
            + x5
        )
    )(e_dec, u_j_bracket_1, f, edge_tee_j_u_i_bracket_1, edge_tee_j_d)

    # E_j encrypts z_bracket_1, sends to server tee via server
    c_z_bracket_1, c_z_bracket_1_tag, c_z_bracket_1_nouce = edge_tees[j](
        encrypt_jnp_array_gcm, num_returns=3
    )(z_bracket_1, handle_j_s)
    c_z_bracket_1_server = c_z_bracket_1.to(server_device).to(server_tee)
    c_z_bracket_1_tag_server = c_z_bracket_1_tag.to(server_device).to(server_tee)
    c_z_bracket_1_nouce_server = c_z_bracket_1_nouce.to(server_device).to(server_tee)

    # server tries to decrypt
    try:
        z_bracket_0_dec = server_tee(decrypt_to_jnp_array_gcm)(
            c_z_bracket_0_server,
            c_z_bracket_0_tag_server,
            c_z_bracket_0_nouce_server,
            server_handle_i_s,
            fixed_point_type,
            shape_ref_server,
        )
        z_bracket_1_dec = server_tee(decrypt_to_jnp_array_gcm)(
            c_z_bracket_1_server,
            c_z_bracket_1_tag_server,
            c_z_bracket_1_nouce_server,
            server_handle_j_s,
            fixed_point_type,
            shape_ref_server,
        )
        sf.wait([z_bracket_0_dec, z_bracket_1_dec])
    except:
        raise Exception("Decryption z values failed")

    z = server_tee(lambda x, y: x + y)(z_bracket_0_dec, z_bracket_1_dec)
    cos_sim_val = server_tee(
        lambda x, y: jnp.sum(fixed_point_type(x + y) / 2. ** fxp)
    )(z, c)
    return cos_sim_val

In [26]:
sf.reveal(cos_sim(u_i, u_j))

[u_i]_0 [ 5169062224848403315 17978026289647267792 13672065781242901756
  4087741131154505715  1183695357040722497 14027211169129552561
 12284502308696450303  7235411937498170472  7735373542400950854
   581642800369894792  2924137005316900573  3864873446383321525
  4487447608251254885 14703952963403251261 15824495735675028973
 17522413515950285156 11014097426761228981 17299480072492948932
 13103955814677909869 16917562960920794702  4685476342624249018
 18281673430761101676 17173322556582609333 16775029077704773129
  6595274472729707751 15145328883993275446 10426321057361546894
  4103705237712594769  2902796022298996863  7550976014758218747
  7397777925427507967  2560862873114356546  8854448826171144645
 15517768026696712402  1981733484964007743   452921484743560381
  8999198461180260165  8311169620528965681  7294980128296928931
 16550117844529837303  8707923708895175307  1487420294768947610
 17385792141263136939  8244111715935836676 14370723539669920333
  3583586949942830324 1386217182

Array(6.4101405e+10, dtype=float64, weak_type=True)

In [12]:
# import required libraries
import numpy as np
from numpy.linalg import norm

# define two lists or array
A = sf.reveal(u_i)
B = sf.reveal(u_j)

print("A:", A)
print("B:", B)

# compute cosine similarity
cosine = np.dot(A, B) / (norm(A) * norm(B))
print("Cosine Similarity:", cosine)

A: [0.32715419 0.09707942 0.24309819 0.7829849  1.82747603 1.69398645
 0.53867555 1.17196586 0.29603094 1.24683992 0.94050465 1.77320876
 0.44908843 0.69587353 1.22740045 0.40731205 0.52691574 0.26812958
 1.24186123 0.98127649 1.58739671 0.62918131 0.57841722 1.91449038
 1.94691465 0.53999149 1.99981756 1.79291908 1.82243655 1.37788273
 0.63233276 0.57559107 1.86854255 1.49110901 1.3441809  0.90413442
 0.78054719 0.72820348 1.28452934 0.12526138 0.22261799 0.93243254
 1.4390098  1.5238524  0.22964797 1.12073907 0.22043995 0.77728024
 0.15637854 0.53509414 0.46638797 0.6048205  1.24583923 0.9890621
 1.25349018 0.82017171 0.18832625 0.64362864 0.57595841 0.7377862
 1.21810718 0.0383983  0.10570283 1.09402011 0.95090879 1.47702946
 0.57646542 1.74285391 0.63546611 0.40176661 1.16982702 0.17611763
 0.95425472 0.64177308 1.76760909 0.78191339 0.4040706  1.02459876
 1.35999207 1.96281719 1.61178474 1.69953212 0.31817473 0.03891332
 1.38275013 1.58419925 0.38820253 0.56136914 0.61197704 1.895