In [1]:
# ! pip install bastionlab polars

In [2]:
# !wget -O diabetes_train.csv.gz https://github.com/scikit-learn/scikit-learn/raw/98cf537f5c538fdbc9d27b851cf03ce7611b8a48/sklearn/datasets/data/diabetes_data_raw.csv.gz
# !wget -O diabetes_target.csv.gz https://github.com/scikit-learn/scikit-learn/raw/98cf537f5c538fdbc9d27b851cf03ce7611b8a48/sklearn/datasets/data/diabetes_target.csv.gz
# !gzip -d diabetes_train.csv.gz
# !gzip -d diabetes_target.csv.gz

In [3]:
from bastionlab import Connection
import polars as pl

connection = Connection("localhost")

In [4]:
# load dataset
train_df = pl.read_csv(
    "diabetes_train.csv",
    sep=" ",
    has_header=False,
    new_columns=["age", "sex", "bmi", "bp", "s1", "s2", "s3", "s4", "s5", "s6"],
)
target_df = pl.read_csv(
    "diabetes_target.csv",
    has_header=False,
    new_columns=["target"],
)

In [5]:
from bastionlab.polars.policy import Policy, TrueRule, Log

policy = Policy(
    safe_zone=TrueRule(), unsafe_handling=Log(), savable=False
)
train_rdf = connection.client.polars.send_df(train_df, policy)
target_rdf = connection.client.polars.send_df(target_df, policy)

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Perform linear regression on `diabetes` dataset.
from bastionlab.linfa.trainers import LinearRegression
from bastionlab.polars import train_test_split

train_rdf = train_rdf.select(pl.all().cast(pl.Float64)).collect()

print(train_rdf.collect().fetch())

train_array = train_rdf.to_array()
target_array = target_rdf.to_array()

train_X, test_X, train_Y, test_Y = train_test_split(
    train_array, target_array, test_size=0.2, shuffle=True
)

lr = LinearRegression()

lr.fit(train_X, train_Y)

# y_pred = lr.predict(test_X)

# result = pl.DataFrame(
#     {"Actual": test_Y.collect().fetch().to_series(), "Predict": y_pred.to_series()}
# )
# result

shape: (442, 10)
┌──────┬─────┬──────┬───────┬─────┬──────┬──────┬────────┬───────┐
│ age  ┆ sex ┆ bmi  ┆ bp    ┆ ... ┆ s3   ┆ s4   ┆ s5     ┆ s6    │
│ ---  ┆ --- ┆ ---  ┆ ---   ┆     ┆ ---  ┆ ---  ┆ ---    ┆ ---   │
│ f64  ┆ f64 ┆ f64  ┆ f64   ┆     ┆ f64  ┆ f64  ┆ f64    ┆ f64   │
╞══════╪═════╪══════╪═══════╪═════╪══════╪══════╪════════╪═══════╡
│ 59.0 ┆ 2.0 ┆ 32.1 ┆ 101.0 ┆ ... ┆ 38.0 ┆ 4.0  ┆ 4.8598 ┆ 87.0  │
├╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 48.0 ┆ 1.0 ┆ 21.6 ┆ 87.0  ┆ ... ┆ 70.0 ┆ 3.0  ┆ 3.8918 ┆ 69.0  │
├╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 72.0 ┆ 2.0 ┆ 30.5 ┆ 93.0  ┆ ... ┆ 41.0 ┆ 4.0  ┆ 4.6728 ┆ 85.0  │
├╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 24.0 ┆ 1.0 ┆ 25.3 ┆ 84.0  ┆ ... ┆ 40.0 ┆ 5.0  ┆ 4.8903 ┆ 89.0  │
├╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ ...  ┆ ... ┆ ...  ┆ ...   ┆ ... ┆ ...  ┆ ...  ┆ ...    ┆ ...   │
├╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌

_InactiveRpcError: <_InactiveRpcError of RPC that terminated with:
	status = StatusCode.NOT_FOUND
	details = "Could not find dataframe: identifier=e0928d54-6f0e-4bf6-aa35-02ad419b5927"
	debug_error_string = "{"created":"@1676546211.789169287","description":"Error received from peer ipv4:127.0.0.1:50056","file":"src/core/lib/surface/call.cc","file_line":966,"grpc_message":"Could not find dataframe: identifier=e0928d54-6f0e-4bf6-aa35-02ad419b5927","grpc_status":5}"
>