In [1]:
from bastionlab import Connection, Identity
from bastionlab.linfa.trainers import LogisticRegression
import polars as pl

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Create a local BastionLab identity named "data_owner".
# NOTE(review): `data_owner` is never passed to `Connection` nor used in any
# visible cell — confirm whether it is meant to authenticate this connection.
data_owner = Identity.create("data_owner")
# Connect to the BastionLab instance at "localhost" and keep only its `.client`;
# the `Connection` object itself is not stored in a variable.
client = Connection("localhost").client

In [8]:
# Class names; the integer `target` column is an index into this list.
target_names = ["setosa", "versicolor", "virginica"]

# Load the headerless iris CSV: four measurement columns plus `target`.
df = pl.read_csv(
    "iris.csv",
    has_header=False,
    new_columns=[
        "sepal length (cm)",
        "sepal width (cm)",
        "petal length (cm)",
        "petal width (cm)",
        "target",
    ],
)

# `target` must be a valid index into `target_names`; fail loudly otherwise
# rather than silently mislabelling rows below.
assert 0 <= df["target"].min() and df["target"].max() < len(target_names), (
    "`target` values must be valid indices into target_names"
)

# Binary label for the logistic regression: is the sample a virginica?
# A native vectorised comparison replaces the per-element Python lambda,
# which was slower and needlessly wrapped a boolean in `True if ... else False`.
df = df.with_column(
    (pl.col("target") == target_names.index("virginica")).alias("target_class")
)

# Lookup table mapping each class index (`id`, row order) to its name.
target_class_df = pl.DataFrame({"classes": target_names})
target_class_df = target_class_df.with_row_count("id")

shape: (150, 5)
┌───────────────────┬──────────────────┬───────────────────┬──────────────────┬────────┐
│ sepal length (cm) ┆ sepal width (cm) ┆ petal length (cm) ┆ petal width (cm) ┆ target │
│ ---               ┆ ---              ┆ ---               ┆ ---              ┆ ---    │
│ f64               ┆ f64              ┆ f64               ┆ f64              ┆ i64    │
╞═══════════════════╪══════════════════╪═══════════════════╪══════════════════╪════════╡
│ 5.1               ┆ 3.5              ┆ 1.4               ┆ 0.2              ┆ 0      │
├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 4.9               ┆ 3.0              ┆ 1.4               ┆ 0.2              ┆ 0      │
├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 4.7               ┆ 3.2              ┆ 1.3               ┆ 0.2              ┆ 0      │
├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 4.6

In [9]:
# Upload the local frames to BastionLab; what comes back are remote references
# (the selections below print as `FetchableLazyFrame`), not the data itself.
rdf = client.polars.send_df(df)
# NOTE(review): `classes_rdf` is not used in the visible cells — confirm it is needed.
classes_rdf = client.polars.send_df(target_class_df)

# Features (X): the two petal measurements. Labels (y): the binary `target_class`.
# These are column selections over the full dataset, NOT a train/test split.
features_rdf = rdf.select(pl.col(["petal length (cm)", "petal width (cm)"])).collect()
labels_rdf = rdf.select(pl.col("target_class")).collect()
# Backward-compatible aliases: later cells still refer to these names.
train_rdf, test_rdf = features_rdf, labels_rdf
print(features_rdf, labels_rdf)

FetchableLazyFrame(identifier=2e965adf-68ca-496c-a416-73959fa92427) FetchableLazyFrame(identifier=6f3b146a-05fa-4f72-8527-9f621a62a3d1)


In [10]:
# Train a logistic regression via the BastionLab client on the petal features
# (`train_rdf`) and the binary `target_class` labels (`test_rdf`).
# `LogisticRegression` is imported in the top import cell.
trainer = client.linfa.train(train_rdf, test_rdf, LogisticRegression())

In [11]:
# Predict the class of a single sample.
# NOTE(review): values are presumably in training-column order
# [petal length (cm), petal width (cm)] — confirm against the linfa API.
res = client.linfa.predict(trainer, [5.5, 4.2])
# Bare last expression: Jupyter renders the result frame with rich display.
res

shape: (1, 1)
┌────────────┐
│ prediction │
│ ---        │
│ u64        │
╞════════════╡
│ 1          │
└────────────┘


In [12]:
# Predicted probability for the same single sample as the previous cell.
# NOTE(review): presumably the probability of the positive class
# (`target_class` == True) — confirm against the linfa/BastionLab docs.
res = client.linfa.predict_proba(trainer, [5.5, 4.2])
# Bare last expression: Jupyter renders the result frame with rich display.
res

shape: (1, 1)
┌────────────┐
│ prediction │
│ ---        │
│ f64        │
╞════════════╡
│ 0.999569   │
└────────────┘


In [13]:
# Cross-validate the trained model via the BastionLab client; the result is
# stored in `res`, replacing the predict_proba result from the previous cell.
res = client.linfa.cross_validate(trainer)