In [1]:
import pymysql
import numpy as np
import pandas as pd


def connect_to_mysql_and_read_data(
    database_name, table_name, host="localhost", user="beng003", password="12341234"
):
    """Read every row of a MySQL table into a pandas DataFrame.

    Args:
        database_name: Database (schema) to connect to.
        table_name: Unquoted table name, optionally qualified as
            ``schema.table``. Each dot-separated part is quoted as a MySQL
            identifier before being placed in the query, so the value cannot
            smuggle extra SQL into the statement.
        host: MySQL server host.
        user: MySQL user name.
        password: Password for ``user``.

    Returns:
        A DataFrame whose column names come from the cursor description, or
        ``None`` when a ``pymysql.MySQLError`` occurs (the error is printed
        rather than raised, so callers must check for ``None``).

    NOTE(review): the default ``user`` / ``password`` are hardcoded
    credentials. They are kept so existing calls keep working, but callers
    should pass credentials explicitly (e.g. read from the environment).
    """
    # Identifiers cannot be bound as query parameters, so quote them instead:
    # wrap each part in backticks and double any embedded backtick.
    quoted_table = ".".join(
        "`" + part.replace("`", "``") + "`" for part in str(table_name).split(".")
    )

    connection = None
    try:
        connection = pymysql.connect(
            host=host,
            user=user,
            password=password,
            database=database_name,
        )

        with connection.cursor() as cursor:
            cursor.execute(f"SELECT * FROM {quoted_table};")
            rows = cursor.fetchall()
            # description holds one entry per result column; item 0 is its name.
            columns = [desc[0] for desc in cursor.description]

        return pd.DataFrame(rows, columns=columns)

    except pymysql.MySQLError as e:
        print("错误：", e)
        # Make the best-effort failure result explicit instead of implicit.
        return None
    finally:
        # Runs on success and on failure; connection is still None when
        # connect() itself raised.
        if connection is not None:
            connection.close()


import secretflow as sf

# Check the version of your SecretFlow
print("The version of SecretFlow: {}".format(sf.__version__))

# In case you have a running secretflow runtime already.
sf.shutdown()

# Start a local runtime with two parties, then build one SPU device spanning
# both parties and one PYU device per party.
sf.init(["alice", "bob"], address="local")
spu_config = sf.utils.testing.cluster_def(parties=["alice", "bob"])
spu_device = sf.SPU(spu_config)
alice, bob = sf.PYU("alice"), sf.PYU("bob")

# Run the MySQL loader on each party's PYU against that party's own database.
# NOTE(review): the loader returns None on a MySQL error, which would only
# surface later inside psi_df — confirm both tables load before continuing.
v_alice = alice(connect_to_mysql_and_read_data)(
    database_name="alice_database", table_name="alice_iris"
)
v_bob = bob(connect_to_mysql_and_read_data)(
    database_name="bob_database", table_name="bob_iris"
)

# Debugging aids (disabled): inspect the device object / its plaintext value.
# print(v_alice)
# print(sf.reveal(v_alice))

# Intersect the two tables on the "uid" column through the SPU device.
# The recorded output shows the result is a list with one PYUObject per party.
ab_psi = spu_device.psi_df(
    key="uid",
    dfs=[v_alice, v_bob],
    receiver="alice",
)
ab_psi

The version of SecretFlow: 1.8.0b0


  self.pid = _posixsubprocess.fork_exec(
2024-08-19 19:59:25,478	INFO worker.py:1724 -- Started a local Ray instance.


[<secretflow.device.device.pyu.PYUObject at 0x7f98dbfe1420>,
 <secretflow.device.device.pyu.PYUObject at 0x7f98dbfe14b0>]

[36m(SPURuntime(device_id=None, party=alice) pid=349336)[0m [2024-08-19 19:59:28.888] [info] [launch.cc:164] LEGACY PSI config: {"psi_type":"KKRT_PSI_2PC","broadcast_result":true,"input_params":{"path":"/tmp/tmpmzu663qk/psi-input.csv","select_fields":["uid"],"precheck":true},"output_params":{"path":"/tmp/tmpmzu663qk/psi-output.csv","need_sort":true},"curve_type":"CURVE_25519","bucket_size":1048576}
[36m(SPURuntime(device_id=None, party=alice) pid=349336)[0m [2024-08-19 19:59:28.888] [info] [bucket_psi.cc:400] bucket size set to 1048576
[36m(SPURuntime(device_id=None, party=alice) pid=349336)[0m [2024-08-19 19:59:28.888] [info] [bucket_psi.cc:252] Begin sanity check for input file: /tmp/tmpmzu663qk/psi-input.csv, precheck_switch:true
[36m(SPURuntime(device_id=None, party=alice) pid=349336)[0m [2024-08-19 19:59:28.890] [info] [csv_checker.cc:135] Executing duplicated scripts: LC_ALL=C sort --parallel=8 --buffer-size=1G --stable selected-keys.f0d08efd-e850-439b-bb52-6f5e4ed29ed3 |

In [5]:
# Inspect the SPU device's world size (2 in the recorded run; presumably the
# number of parties in the SPU cluster — confirm against the SecretFlow docs).
spu_device.world_size

2

In [39]:
# Check which device holds the first PSI result (alice's PYU in the recorded run).
ab_psi[0].device

PYURuntime(alice)

In [40]:
# Reveal the first PSI result as plaintext for inspection.
# NOTE(review): reveal discloses the party's rows to the driver — debugging only.
sf.reveal(ab_psi[0])

Unnamed: 0,sepal_length_cm,sepal_width_cm,petal_length_cm,uid
0,5.4,3.7,1.5,10
1,4.8,3.4,1.6,11
2,4.8,3.0,1.4,12
3,4.3,3.0,1.1,13
4,5.8,4.0,1.2,14
5,5.7,4.4,1.5,15
6,5.4,3.9,1.3,16
7,5.1,3.5,1.4,17
8,5.7,3.8,1.7,18
9,5.1,3.8,1.5,19


In [20]:
# Reveal the second PSI result as plaintext for inspection.
# NOTE(review): reveal discloses the party's rows to the driver — debugging only.
sf.reveal(ab_psi[1])

Unnamed: 0,petal_width_cm,target,uid
0,0.2,0,10
1,0.2,0,11
2,0.1,0,12
3,0.1,0,13
4,0.2,0,14
5,0.4,0,15
6,0.4,0,16
7,0.3,0,17
8,0.3,0,18
9,0.3,0,19


In [34]:
from secretflow.data.core import partition
from secretflow.data.core.io import read_csv_wrapper
from typing import Callable, Dict, List, Union
from secretflow.device import PYU, SPU, Device
from secretflow.utils.errors import InvalidArgumentError
from secretflow.utils.random import global_random
from secretflow.data.vertical.dataframe import VDataFrame


def get_keys(
    device: Device, x: Union[str, List[str], Dict[Device, List[str]]] = None
) -> List[str]:
    """Normalize a key specification into the list of key columns for ``device``.

    Args:
        device: Device whose keys are requested; only consulted when ``x`` is
            a dict.
        x: A single column name, a list of names shared by every device, or a
            mapping from device to its column name(s). Any falsy value means
            "no keys".

    Returns:
        The key column names. An empty list is returned when ``x`` is falsy or
        when ``x`` is a mapping that has no entry for ``device``.

    Raises:
        InvalidArgumentError: If ``x`` is truthy but is not a str, list or dict.
    """
    if not x:
        return []
    if isinstance(x, str):
        return [x]
    if isinstance(x, List):
        return x
    if isinstance(x, Dict):
        # A device missing from the mapping simply has no keys. Returning an
        # empty list (instead of falling through to an implicit None) honours
        # the declared return type and keeps callers that wrap the result in
        # set(...) from failing with a TypeError.
        if device not in x:
            return []
        device_keys = x[device]
        return [device_keys] if isinstance(device_keys, str) else device_keys
    raise InvalidArgumentError(f"Illegal type for keys,got {type(x)}")


# filepath_actual = output_path  (disabled; only referenced by commented-out code below)

# Note: the settings that do not have to be specified are all left at the
# values below (the original note was truncated).
# Per-device converter / dtype / column-selection mappings; None disables them.
converters = None
dtypes = None
usecols = None
# When True, the device name is passed as the auto-generated header prefix.
no_header = False
# Passed to partition() as both `backend` and `read_backend`.
backend = "pandas"
delimiter = ","
nrows: int = None
skip_rows_after_header: int = None
# Column(s) used as the PSI key.
keys = "uid"
# Key column(s) dropped from every partition before the VDataFrame is built;
# must be a subset of `keys` (asserted further down).
drop_keys = "uid"

# Wrap each party's PSI result (a PYUObject) in a Partition, keyed by its device.
partitions = {}
for device_pyu in ab_psi:
    # Pick this device's entry from each optional per-device mapping.
    converter = converters[device_pyu.device] if converters is not None else None
    dtype = dtypes[device_pyu.device] if dtypes is not None else None
    usecol = usecols[device_pyu.device] if usecols is not None else None

    # With dtypes given but no explicit column selection, select the dtype keys.
    if usecol is None and dtype is not None:
        usecol = dtype.keys()

    if no_header:
        assert usecol is None, "can not use usecol when no_header is True"

    # NOTE(review): partition() receives the in-memory PYUObject as `data` with
    # an empty filepath; presumably the CSV-reading arguments are then unused —
    # confirm against secretflow.data.core.partition.
    partitions[device_pyu.device] = partition(
        data=device_pyu,
        device=device_pyu.device,
        backend=backend,
        filepath="",
        auto_gen_header_prefix=str(device_pyu.device) if no_header else "",
        delimiter=delimiter,
        usecols=usecol,
        dtype=dtype,
        converters=converter,
        read_backend=backend,
        nrows=nrows,
        skip_rows_after_header=skip_rows_after_header,
    )
# Drop the requested key columns from every partition, after checking that
# they exist on the device and are a subset of that device's PSI keys.
if drop_keys:
    for device, part in partitions.items():
        device_drop_key = get_keys(device, drop_keys)
        device_psi_key = get_keys(device, keys)

        if device_drop_key is not None:
            columns_set = set(part.columns)
            device_drop_key_set = set(device_drop_key)
            # Both halves of the message must be f-strings; without the prefix
            # on the second one, "{device}" was printed literally.
            assert columns_set.issuperset(device_drop_key_set), (
                f"drop_keys = {device_drop_key_set.difference(columns_set)}"
                f" can not find on device {device}"
            )

            # Treat a missing PSI key list as empty so a bad configuration
            # fails with the assertion message below, not a TypeError.
            device_psi_key_set = set(device_psi_key or ())
            assert device_psi_key_set.issuperset(device_drop_key_set), (
                f"drop_keys = {device_drop_key_set.difference(device_psi_key_set)} "
                f"can not find on device_psi_key_set of device {device},"
                f" which are {device_psi_key_set}"
            )

            # Reassigning an existing key while iterating is safe: the dict's
            # size does not change.
            partitions[device] = part.drop(columns=device_drop_key)

# Column names seen so far across all devices (filled by the loop below).
unique_cols = set()

# Every device must hold the same number of rows.
if len(partitions):
    parties_length = {}
    for device, part in partitions.items():
        parties_length[device.party] = len(part)
    if len(set(parties_length.values())) > 1:
        raise AssertionError(
            f"number of samples must be equal across all devices, got {parties_length}, "
            # f"input uri {filepath_actual}"
        )

# Data columns must be unique across all devices.
for device, part in partitions.items():
    for col in part.columns:
        assert col not in unique_cols, f"col {col} duplicate in multiple devices"
        unique_cols.add(col)

# Combine the per-device partitions into one vertically partitioned DataFrame.
vdf = VDataFrame(partitions)



INFO:root:Create proxy actor <class 'secretflow.device.proxy.ActorPartitionAgent'> with party alice.
INFO:root:Create proxy actor <class 'secretflow.device.proxy.ActorPartitionAgent'> with party bob.


In [35]:
# Combined column names from both partitions; the dropped "uid" key is absent.
vdf.columns

['sepal_length_cm',
 'sepal_width_cm',
 'petal_length_cm',
 'petal_width_cm',
 'target']

In [36]:
# Per-column mean over the columns held by both parties.
vdf.mean()

sepal_length_cm    5.21
sepal_width_cm     3.65
petal_length_cm    1.42
petal_width_cm     0.25
target             0.00
dtype: float64

In [37]:
# Summarize the vertical DataFrame built from the PSI results.
print("**********Alice PSI")
print(vdf)
print(vdf.shape)
# NOTE(review): in the recorded run sf.reveal(vdf) printed the VDataFrame repr,
# not plaintext rows. The copy-pasted duplicate of this line was removed.
print(sf.reveal(vdf))

**********Alice PSI
VDataFrame(partitions={PYURuntime(alice): <secretflow.data.core.partition.Partition object at 0x7f276812e620>, PYURuntime(bob): <secretflow.data.core.partition.Partition object at 0x7f275421e170>}, aligned=True)
(10, 5)
VDataFrame(partitions={PYURuntime(alice): <secretflow.data.core.partition.Partition object at 0x7f276812e620>, PYURuntime(bob): <secretflow.data.core.partition.Partition object at 0x7f275421e170>}, aligned=True)
VDataFrame(partitions={PYURuntime(alice): <secretflow.data.core.partition.Partition object at 0x7f276812e620>, PYURuntime(bob): <secretflow.data.core.partition.Partition object at 0x7f275421e170>}, aligned=True)


In [25]:
# Shut down the SecretFlow runtime started by sf.init() once the analysis is done.
sf.shutdown()