# integrate.ai HFL Linear Inference Sample Notebook

In this notebook, we train an HFL (horizontal federated learning) session with the built-in package `iai_linear_inference`. The package trains a bundle of linear models for the target of interest against a specified list of predictors, obtains the coefficient and variance estimates, and calculates the p-values from the corresponding hypothesis tests.

## Set environment variables (or replace inline) with your IAI credentials
Generate and manage this token on the Tokens page of the UI.

In [None]:
import os

IAI_TOKEN = os.environ.get("IAI_TOKEN")
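
If `IAI_TOKEN` is not set, `os.environ.get` returns `None` without complaint, and the problem only shows up later at authentication. A minimal sketch of an earlier check is shown below. The helper name `require_env` is hypothetical and is not part of the integrate.ai SDK.

```python
import os


def require_env(name: str) -> str:
    """Return the value of environment variable `name`; raise if it is unset or empty."""
    value = os.environ.get(name)
    if not value:
        raise RuntimeError(f"Environment variable {name} is not set")
    return value


# Usage (hypothetical helper, not an SDK function):
# IAI_TOKEN = require_env("IAI_TOKEN")
```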

## Authenticate to the integrate.ai API client

In [None]:
from integrate_ai_sdk.api import connect

client = connect(token=IAI_TOKEN)

## Sample model config and data config
To be compatible with the `iai_linear_inference` package, the `model_config` must use the strategy `LogitRegInference` if the target of interest is binary, or `LinearRegInference` if it is continuous.

The `data_config` dictionary should include the following 3 fields, whose columns can be specified either as names (strings) or as indices (integers):
- `target`: the target column of interest;
- `shared_predictors`: predictor columns that should be included in all linear models (e.g., confounding factors such as age and gender in GWAS);
- `chunked_predictors`: predictor columns that should be included in the linear models one at a time (e.g., the gene expressions in GWAS).

With the example data config below, the session trains 4 logistic regression models, each with `y` as the target and `x1`, `x2` plus exactly one of `x0`, `x3`, `x10`, `x11` as predictors.

In [None]:
model_config_logit = {
    "strategy": {"name": "LogitRegInference", "params": {}},
    "seed": 23,  # for reproducibility
}

data_config_logit = {
    "target": "y",
    "shared_predictors": ["x1", "x2"],
    "chunked_predictors": ["x0", "x3", "x10", "x11"]
}
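
To make the bundle structure concrete, the sketch below enumerates the models implied by a data config of this shape. It is plain Python for illustration only: the function name `expand_bundle` is hypothetical, and the order of predictors within each model (shared first, then the chunked one) is an assumption rather than something the package documents here.

```python
def expand_bundle(data_config: dict) -> list:
    """List the target and predictors of each model implied by a data config."""
    shared = list(data_config["shared_predictors"])
    return [
        {"target": data_config["target"], "predictors": shared + [chunked]}
        for chunked in data_config["chunked_predictors"]
    ]


# Same values as the example data config, repeated so the sketch is self-contained.
example_config = {
    "target": "y",
    "shared_predictors": ["x1", "x2"],
    "chunked_predictors": ["x0", "x3", "x10", "x11"],
}

# Print one formula-style line per model in the bundle.
for spec in expand_bundle(example_config):
    print(spec["target"], "~", " + ".join(spec["predictors"]))
```

Each chunked predictor yields exactly one model, so the number of models equals the length of `chunked_predictors`.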

## Create a Training Session

The documentation for [creating a session](https://documentation.integrateai.net/#create-and-start-the-training-session) gives more context on the parameters used when creating a training session.<br />
For this session, we use 2 training clients and 5 rounds.

In [None]:
training_session_logit = client.create_fl_session(
    name="Testing linear inference session",
    description="I am testing linear inference session creation through a notebook",
    min_num_clients=2,
    num_rounds=5,
    package_name="iai_linear_inference",
    model_config=model_config_logit,
    data_config=data_config_logit,
).start()

training_session_logit.id

## Start a training session using iai client
Make sure that the sample data you [downloaded](https://documentation.integrateai.net/#review-the-sample-model-configuration) is saved to your `~/Downloads` directory; otherwise, update the `data_path` below to point to the sample data.

In [None]:
import os
data_path = os.path.expanduser("~/Downloads/synthetic")  # expand "~" up front so the path also works outside a shell

In [None]:
import subprocess


client_1 = subprocess.Popen(
    f"iai client train --token {IAI_TOKEN} --session {training_session_logit.id} --train-path {data_path}/train_silo0.parquet --test-path {data_path}/test.parquet --batch-size 1024 --client-name client-1-inference --remove-after-complete",
    shell=True,
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
    universal_newlines=True
)

client_2 = subprocess.Popen(
    f"iai client train --token {IAI_TOKEN} --session {training_session_logit.id} --train-path {data_path}/train_silo1.parquet --test-path {data_path}/test.parquet --batch-size 1024 --client-name client-2-inference --remove-after-complete",
    shell=True,
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
    universal_newlines=True
)
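
The two commands above differ only in the training file and the client name, and they rely on `shell=True` to parse the command string. A hedged alternative is to build the invocation as an argument list, which avoids shell quoting issues. The sketch below uses only the flags that appear in the commands above; the helper name `build_train_command` and the placeholder values are hypothetical.

```python
import os
import shlex


def build_train_command(token, session_id, train_path, test_path, client_name, batch_size=1024):
    """Assemble the `iai client train` invocation as an argument list."""
    return [
        "iai", "client", "train",
        "--token", token,
        "--session", session_id,
        # Without a shell, "~" is not expanded automatically, so expand it here.
        "--train-path", os.path.expanduser(train_path),
        "--test-path", os.path.expanduser(test_path),
        "--batch-size", str(batch_size),
        "--client-name", client_name,
        "--remove-after-complete",
    ]


# Placeholder values for illustration; substitute the real token and session ID.
cmd = build_train_command(
    token="<token>",
    session_id="<session-id>",
    train_path="~/Downloads/synthetic/train_silo0.parquet",
    test_path="~/Downloads/synthetic/test.parquet",
    client_name="client-1-inference",
)
print(shlex.join(cmd))  # shell-quoted rendering, useful for logging
```

A list built this way can be passed to `subprocess.Popen(cmd, ...)` without `shell=True`.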

## Poll for session status


In [None]:
import time

current_round = None
current_status = None
while client_1.poll() is None or client_2.poll() is None:
    output1 = client_1.stdout.readline().strip()
    output2 = client_2.stdout.readline().strip()
    if output1:
        print("client_1: ", output1)
    if output2:
        print("client_2: ", output2)

    # poll for status and round
    if current_status != training_session_logit.status:
        print("Session status: ", training_session_logit.status)
        current_status = training_session_logit.status
    if current_round != training_session_logit.round and training_session_logit.round > 0:
        print("Session round: ", training_session_logit.round)
        current_round = training_session_logit.round
    time.sleep(1)

output1, error1 = client_1.communicate()
output2, error2 = client_2.communicate()

print(
    "client_1 finished with return code: %d\noutput: %s\n  %s"
    % (client_1.returncode, output1, error1)
)
print(
    "client_2 finished with return code: %d\noutput: %s\n  %s"
    % (client_2.returncode, output2, error2)
)
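
In the loop above, each `readline()` call waits until that client emits a line or closes its output, so a quiet client can delay the other client's output and the status checks. One common alternative is to drain each process in its own thread. The sketch below illustrates that pattern with stand-in child processes (short Python one-liners) so that it runs without the `iai` CLI; it is not the notebook's own method.

```python
import subprocess
import sys
import threading


def stream_output(label, process, sink):
    """Read `process.stdout` line by line until EOF, recording each non-empty line."""
    for line in process.stdout:
        line = line.strip()
        if line:
            sink.append(f"{label}: {line}")


# Stand-in child processes; in the notebook these would be the two training clients.
procs = {
    name: subprocess.Popen(
        [sys.executable, "-c", f"print('hello from {name}')"],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr into stdout so no pipe is left unread
        text=True,
    )
    for name in ("client_1", "client_2")
}

lines = []  # list.append is atomic in CPython, so both threads can share this list
threads = [
    threading.Thread(target=stream_output, args=(name, proc, lines))
    for name, proc in procs.items()
]
for t in threads:
    t.start()
for t in threads:
    t.join()
for proc in procs.values():
    proc.wait()

print(sorted(lines))
```

With the readers running in background threads, the main thread remains free to poll the session status on its own schedule.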

## Session Complete!
Now we can view the training metrics and model details such as the model coefficients and p-values. Because the session trains a bundle of models, the metrics below are averaged across all models in the bundle.

In [None]:
training_session_logit.metrics().as_dict()

In [None]:
training_session_logit.metrics().plot()

### Trained models are accessible from the completed session

The `LinearInferenceModel` object can be retrieved using the model's `as_pytorch` method, and relevant information such as p-values can be accessed directly from the model object.

In [None]:
model_logit = training_session_logit.model().as_pytorch()

In [None]:
pv = model_logit.p_values()
pv

The `.summary` method fetches the coefficient, standard error and p-value of the model corresponding to the specified predictor.

In [None]:
summary_x0 = model_logit.summary("x0")
summary_x0
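
To show how a coefficient, its standard error, and a p-value relate, the sketch below computes a two-sided p-value from a Wald z-test using a normal approximation. This is an assumption for illustration: the notebook does not state which test `iai_linear_inference` uses, and linear regression in particular is often tested against a t-distribution instead. The function name `wald_p_value` and the input numbers are hypothetical.

```python
import math


def wald_p_value(coefficient: float, std_error: float) -> float:
    """Two-sided p-value for H0: coefficient == 0 under a normal approximation."""
    z = coefficient / std_error
    # For a standard normal Z, P(|Z| > |z|) equals erfc(|z| / sqrt(2)).
    return math.erfc(abs(z) / math.sqrt(2.0))


# Made-up numbers: a coefficient two standard errors away from zero.
print(wald_p_value(0.5, 0.25))
```

A coefficient that is small relative to its standard error gives a p-value near 1, and the sign of the coefficient does not affect the result.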

It is also possible to make predictions with the resulting bundle of models when the data is loaded with `ChunkedTabularDataset` from the `iai_linear_inference` package. The predictions have shape `(n_samples, n_chunked_predictors)`, where each column corresponds to one model from the bundle.

In [None]:
import torch
from torch.utils.data import DataLoader
from integrate_ai_sdk.packages.LinearInference.dataset import ChunkedTabularDataset


ds = ChunkedTabularDataset(path=f"{data_path}/test.parquet", **data_config_logit)
dl = DataLoader(ds, batch_size=len(ds), shuffle=False)
x = torch.tensor(ds.X)
y_pred = model_logit(x)
y_pred
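
The `(n_samples, n_chunked_predictors)` shape follows from evaluating every model in the bundle on every sample. The plain-Python sketch below reproduces that shape with toy numbers. It assumes each model is a logistic regression over the shared predictors plus one chunked predictor and that the output is a probability; the parameter layout, the function names, and all values are hypothetical and do not reflect the internals of `LinearInferenceModel`.

```python
import math


def sigmoid(t: float) -> float:
    """Logistic function mapping a real-valued score to (0, 1)."""
    return 1.0 / (1.0 + math.exp(-t))


def predict_bundle(shared_rows, chunked_rows, models):
    """Return one row of predictions per sample and one column per model.

    shared_rows:  n_samples rows of shared predictor values
    chunked_rows: n_samples rows with one value per chunked predictor
    models:       one dict per chunked predictor with the hypothetical keys
                  "intercept", "shared_coefs", and "chunk_coef"
    """
    predictions = []
    for shared, chunked in zip(shared_rows, chunked_rows):
        row = []
        for k, model in enumerate(models):
            score = model["intercept"]
            score += sum(c * v for c, v in zip(model["shared_coefs"], shared))
            score += model["chunk_coef"] * chunked[k]  # model k sees only chunked predictor k
            row.append(sigmoid(score))
        predictions.append(row)
    return predictions


# Toy data: 3 samples, 2 shared predictors, 4 chunked predictors.
shared_rows = [[0.1, 1.0], [0.5, 0.0], [-0.3, 1.0]]
chunked_rows = [[1.0, 0.0, 2.0, -1.0], [0.0, 1.0, 0.5, 0.5], [2.0, 2.0, 0.0, 1.0]]
models = [
    {"intercept": 0.0, "shared_coefs": [0.2, -0.1], "chunk_coef": 0.3}
    for _ in range(4)
]

y_hat = predict_bundle(shared_rows, chunked_rows, models)
print(len(y_hat), len(y_hat[0]))  # number of samples, number of models
```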