In [3]:
import os
import sys

# Candidate project roots (a Windows and a Linux checkout); adding all of them
# lets the same notebook import project modules on either machine.
PROJECT_ROOTS = [
    r'D:\PyCharmProjects\VFPUMC02',  # raw string: '\P' and '\V' are invalid escapes otherwise
    r'C:\Users\Administrator\PycharmProjects\VFPUMC02',
    r'/root/VFPUMC02',
]
for _root in PROJECT_ROOTS:
    # Skip roots already present so re-running this cell does not keep growing sys.path.
    if _root not in sys.path:
        sys.path.append(_root)

# Directory holding the breast_hetero_*.csv files; can be overridden with the
# DATA_BASE_PATH environment variable without editing the notebook.
DATA_BASE_PATH = os.environ.get('DATA_BASE_PATH', r'/root/VFPUMC02/datasets')
# Automatically reload imported modules before executing each cell, so edits to
# the project code on sys.path are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
import os
import pandas as pd
from fate.arch.dataframe import PandasReader
from fate.ml.ensemble.algo.secureboost.hetero.guest import HeteroSecureBoostGuest
from fate.ml.ensemble.algo.secureboost.hetero.host import HeteroSecureBoostHost
from fate.arch import Context
from fate.arch.dataframe import DataFrame
from datetime import datetime
from fate.arch.context import create_context

In [5]:
def create_ctx(party, session_id='test_fate'):
    """Create a FATE context for one side of the two-party guest/host setup.

    Args:
        party: Role of the local process, ``"guest"`` or ``"host"``.
        session_id: Passed as ``federation_session_id``; presumably it must be
            the same on both parties — confirm against ``create_context``.

    Returns:
        The context returned by ``create_context`` for ``party``.

    Raises:
        ValueError: If ``party`` is neither ``"guest"`` nor ``"host"``.
    """
    parties = [("guest", "9999"), ("host", "10000")]
    # Look the local party up by role instead of treating every non-"guest"
    # value as the host, so a typo such as "gust" fails loudly rather than
    # silently joining the federation as the host.
    party_ids = dict(parties)
    if party not in party_ids:
        raise ValueError(f"party must be 'guest' or 'host', got {party!r}")
    local_party = (party, party_ids[party])
    return create_context(local_party, parties=parties, federation_session_id=session_id)

In [6]:
def train(ctx: Context, data: DataFrame, num_trees: int = 3, objective: str = 'binary:bce', max_depth: int = 3, learning_rate: float = 0.3):
    """Fit a hetero SecureBoost model for whichever party ``ctx`` represents.

    The guest builds a ``HeteroSecureBoostGuest`` with every hyper-parameter.
    The host builds a ``HeteroSecureBoostHost`` that receives only
    ``num_trees`` and ``max_depth``; ``objective`` and ``learning_rate`` are
    not passed to it.

    Returns:
        The fitted booster.
    """
    shared_params = dict(num_trees=num_trees, max_depth=max_depth)
    if ctx.is_on_guest:
        booster = HeteroSecureBoostGuest(objective=objective, learning_rate=learning_rate, **shared_params)
    else:
        booster = HeteroSecureBoostHost(**shared_params)
    booster.fit(ctx, data)
    return booster

In [7]:
def predict(ctx: Context, data: DataFrame, model_dict: dict):
    """Restore a booster from ``model_dict`` and predict on ``data``.

    Everything runs under a ``'predict'`` sub-context of ``ctx``. The guest
    restores a ``HeteroSecureBoostGuest`` and the host a
    ``HeteroSecureBoostHost``.

    Returns:
        Whatever the restored booster's ``predict`` returns.
    """
    predict_ctx = ctx.sub_ctx('predict')
    booster_cls = HeteroSecureBoostGuest if predict_ctx.is_on_guest else HeteroSecureBoostHost
    booster = booster_cls()
    booster.from_model(model_dict)
    return booster.predict(predict_ctx, data)

In [8]:
def csv_to_df(ctx, file_path, has_label=True, match_id_name="id", label_name="y"):
    """Read a CSV file into a FATE ``DataFrame``.

    A ``sample_id`` column numbering the rows 0..n-1 is appended to the pandas
    frame and used as the sample id. All columns are read with ``dtype="float32"``.

    Args:
        ctx: FATE context the frame is created in.
        file_path: Path of the CSV file.
        has_label: Whether the CSV contains a label column (in this notebook
            the guest file does and the host file does not).
        match_id_name: Column passed to ``PandasReader`` as ``match_id_name``
            (presumably used to align rows across parties). Defaults to ``"id"``.
        label_name: Label column name; only used when ``has_label`` is true.
            Defaults to ``"y"``.

    Returns:
        The FATE ``DataFrame`` produced by ``PandasReader.to_frame``.
    """
    df = pd.read_csv(file_path)
    df["sample_id"] = range(len(df))
    reader_kwargs = dict(sample_id_name="sample_id", match_id_name=match_id_name, dtype="float32")
    if has_label:
        # Only pass label_name when a label exists, matching the original
        # label-free constructor call for unlabelled data.
        reader_kwargs["label_name"] = label_name
    reader = PandasReader(**reader_kwargs)
    return reader.to_frame(ctx, df)

In [10]:
# Party A is the host (features only); party B is the guest (features plus label "y").
A_host_path, B_guest_path = (
    os.path.join(DATA_BASE_PATH, file_name)
    for file_name in ('breast_hetero_host.csv', 'breast_hetero_guest.csv')
)

In [21]:
# Load the host party's unlabelled CSV and preview its first rows.
ctx = create_ctx(party='host')
A_data = csv_to_df(ctx, file_path=A_host_path, has_label=False)
print('host_data')
A_data.as_pd_df().head()

host_data


Unnamed: 0,sample_id,id,x0,x1,x2,x3,x4,x5,x6,x7,...,x10,x11,x12,x13,x14,x15,x16,x17,x18,x19
0,0,133.0,0.449512,-1.247226,0.413178,0.303781,-0.123848,-0.184227,-0.219076,0.268537,...,-0.33736,-0.728193,-0.442587,-0.272757,-0.608018,-0.577235,-0.501126,0.143371,-0.466431,-0.554102
1,5,274.0,1.080023,1.20783,0.956888,0.978402,-0.555822,-0.645696,-0.399365,-0.038153,...,0.057848,0.392164,-0.050027,0.120414,-0.532348,-0.770613,-0.519694,-0.531097,-0.769127,-0.394858
2,6,420.0,-0.726307,-0.058095,-0.73191,-0.697343,-0.775723,-0.513983,-0.426233,-0.893482,...,-0.428673,0.404865,-0.32675,-0.44085,0.07901,-0.279903,0.416992,-0.486165,-0.225484,-0.172446
3,7,76.0,-0.169639,-1.943019,-0.167192,-0.27215,2.329937,0.006804,-0.251467,0.429234,...,0.017786,-0.368046,-0.105966,-0.169129,2.11976,0.162743,-0.672216,-0.577002,0.626908,0.896114
4,8,315.0,-0.465014,-0.567723,-0.526371,-0.492852,-0.800631,-1.250816,-1.058714,-1.096145,...,-0.843011,-0.910353,-0.90049,-0.608283,-0.704355,-1.255622,-0.970629,-1.363557,-0.800607,-0.927058


In [22]:
# Load the guest party's labelled CSV and preview its first rows.
ctx = create_ctx(party='guest')
B_data = csv_to_df(ctx, file_path=B_guest_path, has_label=True)
print('guest_data')
B_data.as_pd_df().head()

guest_data


Unnamed: 0,sample_id,id,y,x0,x1,x2,x3,x4,x5,x6,x7,x8,x9
0,0,133.0,1,0.254879,-1.046633,0.209656,0.074214,-0.441366,-0.377645,-0.485934,0.347072,-0.28757,-0.733474
1,5,274.0,0,0.963102,1.467675,0.829202,0.772457,-0.038076,-0.468613,-0.307946,-0.015321,-0.641864,-0.247477
2,6,420.0,1,-0.662496,0.212149,-0.620475,-0.632995,-0.327392,-0.385278,-0.077665,-0.730362,0.217178,-0.06128
3,7,76.0,1,-0.453343,-2.147457,-0.473631,-0.483572,0.558093,-0.740244,-0.89617,-0.617229,-0.308601,-0.666975
4,8,315.0,1,-0.606584,-0.971725,-0.678558,-0.591332,-0.963013,-1.302401,-1.212855,-1.321154,-1.591501,-1.230554


In [24]:
from fate.arch.launchers.multiprocess_launcher import launch

def run(ctx, num_trees=3, max_depth=3):
    """Train and predict with hetero SecureBoost for the party ``ctx`` represents.

    Intended to be passed to ``launch``. Each party loads its own CSV, trains,
    exports the model dict and predicts with it; only the guest prints the
    predictions.

    Args:
        ctx: FATE context of the current party.
        num_trees: Number of boosting trees. Defaults to 3.
        max_depth: Maximum depth of each tree. Defaults to 3.
    """
    is_guest = ctx.is_on_guest
    # Read the same files the preview cells load (under DATA_BASE_PATH) instead
    # of paths relative to whatever the current working directory happens to be.
    data_path = B_guest_path if is_guest else A_host_path
    # Only the guest file has a label column.
    data = csv_to_df(ctx, data_path, has_label=is_guest)
    bst = train(ctx, data, num_trees=num_trees, max_depth=max_depth)
    model_dict = bst.get_model()
    pred = predict(ctx, data, model_dict)
    if is_guest:
        print(pred.as_pd_df())

In [None]:
# Execute `run` through FATE's multiprocess launcher. `run` branches on
# `ctx.is_on_guest`, so it is expected to be invoked once per party
# (presumably guest and host in separate processes — confirm against
# fate.arch.launchers.multiprocess_launcher).
launch(run)

In [None]:
# `bst` is a local variable inside `run`, which is executed via `launch`, so it
# is not defined in this kernel unless some other cell assigned it. Fail with an
# actionable message instead of a bare NameError on a fresh kernel.
if 'bst' not in globals():
    raise NameError(
        "`bst` is not defined in this kernel: it only exists inside `run` "
        "(executed by `launch`). Export the guest model dict from `run` "
        "(e.g. save `bst.get_model()`) and load it here before running this cell."
    )
model_dict = bst.get_model()
# Using the guest party as an example: rebuild a guest booster from the exported model dict.
bst_2 = HeteroSecureBoostGuest()
bst_2.from_model(model_dict)