# integrate.ai API Sample Notebook to run training on Batch/Fargate

## Set environment variables (or replace inline) with your IAI credentials
### Generate and manage the IAI token in the UI, on the Tokens page
### Generate AWS session credentials or use the default profile

In [None]:
import os

# Read the integrate.ai token from the environment; the value is None when
# IAI_TOKEN is not set (replace inline with your token in that case).
IAI_TOKEN = os.getenv("IAI_TOKEN")

## Authenticate to the integrate.ai API client

In [None]:
from integrate_ai_sdk.api import connect

# Create the API client from the token read above; the later cells pass this
# `client` object to the session and task-builder calls.
client = connect(token=IAI_TOKEN)

## Get an existing session

In [None]:
# Look up a session by its ID. The ID is hard-coded here; substitute your own.
# NOTE: `training_session` is rebound by the "Create a Training Session" cell
# further down if that cell is run as well.
training_session = client.session("03e3e38795")
training_session.id  # bare expression so the notebook displays the session ID

## Sample model config and data schema
You can find the model config and data schema in the [integrate.ai end user tutorial](https://integrate-ai.gitbook.io/integrate.ai-user-documentation/tutorials/end-user-tutorials/model-training-with-a-sample-local-dataset)

In [None]:
# Model configuration, passed as `model_config` when the session is created.
model_config = {
    "experiment_name": "test_synthetic_tabular",
    "experiment_description": "test_synthetic_tabular",
    "strategy": {
        "name": "FedAvg",
        "params": {},
    },
    "model": {
        "params": {
            "input_size": 15,
            "hidden_layer_sizes": [6, 6, 6],
            "output_size": 2,
        },
    },
    "balance_train_datasets": False,
    "ml_task": {
        "type": "classification",
        "params": {"loss_weights": None},
    },
    "optimizer": {
        "name": "SGD",
        "params": {"learning_rate": 0.2, "momentum": 0.0},
    },
    "differential_privacy_params": {"epsilon": 4, "max_grad_norm": 7},
    "save_best_model": {
        # Set "metric" to None to disable this and save the model from the last round.
        "metric": "loss",
        "mode": "min",
    },
    "seed": 23,  # fixed for reproducibility
}

# Data schema, passed as `data_config` when the session is created:
# predictor columns x0 .. x14 (15 columns, matching "input_size") and target y.
data_schema = {
    "predictors": [f"x{index}" for index in range(15)],
    "target": "y",
}

## Create a Training Session

The documentation for [creating a session](https://integrate-ai.gitbook.io/integrate.ai-user-documentation/tutorials/end-user-tutorials/model-training-with-a-sample-local-dataset#create-and-start-the-session) gives more context on the parameters used when creating a training session.<br />
This session requires a minimum of one training client (`min_num_clients=1`) and runs for two rounds (`num_rounds=2`).

In [None]:
# Create a session from the model config and data schema defined above and
# start it in the same chained expression; `training_session` is bound to
# whatever `.start()` returns.
training_session = client.create_fl_session(
    name="Testing notebook",
    description="I am testing session creation through a notebook",
    min_num_clients=1,  # the cells below start a single HFL client
    num_rounds=2,
    package_name="iai_ffnet",
    model_config=model_config,
    data_config=data_schema,
    # NOTE(review): presumably "external" means the server/client tasks are
    # launched outside this call (as the Fargate/Batch cells below do) —
    # confirm against the SDK documentation.
    startup_mode="external",
).start()

training_session.id  # bare expression so the notebook displays the new session ID

### Specifying optional AWS Credentials, Cluster, Task Definition Name and Network Parameters

In [None]:
# AWS settings used by the task builders and tasks in the cells below.
# Replace these values with the resources of your own AWS account.

# -- Fargate task builder: ECS cluster and task definition ---------------------
task_definition = "iai-fl-server-fargate-job"
cluster = "iai-fl-server-ecs-cluster"

# -- Network parameters and model storage for the server task ------------------
# Public subnet (routed via IGW) is the active choice.
subnet_id = "subnet-0fa55725fdb875232"
# Alternative: private subnet (routed via NAT).
# subnet_id = "subnet-078a952ae6b700fdb"
security_group = "sg-099cff22904011b13"
model_storage = "s3://sandbox.integrate.ai"

# -- Batch task builder: job queue and job definition --------------------------
job_def = "iai-fl-client-batch-job"
job_queue = "iai-fl-client-batch-job-queue"

# -- Datasets read by the batch client jobs ------------------------------------
test_path = "s3://sandbox.integrate.ai/data/synthetic/test.parquet"
train_path1 = "s3://sandbox.integrate.ai/data/synthetic/train_silo0.parquet"
# NOTE(review): train_path2 is not referenced by the cells visible in this
# notebook; it is kept for a second client.
train_path2 = "s3://sandbox.integrate.ai/data/synthetic/train_silo1.parquet"


## Run Fargate server and Batch clients

### Create Fargate and Batch task builders

In [None]:
from integrate_ai_sdk.taskgroup.taskbuilder import aws as taskbuilder_aws

# Fargate task builder; the cells below call `tb.fls(...)` on it to build the server task.
tb = taskbuilder_aws.fargate(cluster=cluster, task_definition=task_definition)

# Batch task builder; the cells below call `tb_batch.hfl(...)` on it to build the client task.
tb_batch = taskbuilder_aws.batch(job_queue=job_queue, cpu_job_definition=job_def)

### Create and start HFL tasks manually

In [None]:
# Build the server task on the Fargate builder, attach it to the training
# session, then start it.
fls = tb.fls(subnet_id, security_group, storage_path=model_storage, client=client)
fls.set_session(training_session)
fls_server = fls.start()  # returned handle is queried with .status() in the next cell

In [None]:
# Display the current status of the started server task (re-run the cell to refresh).
fls_server.status()

In [None]:
# Build the HFL client task on the Batch builder using the first training silo
# and the shared test set, then attach it to the training session.
# Note that vcpus and memory are passed as strings here.
hfl = tb_batch.hfl(train_path=train_path1, test_path=test_path, vcpus="2", memory="16384", client=client)
hfl.set_session(training_session)

In [None]:
# Start the client task; the returned handle is queried with .status() in the next cell.
hfl_context = hfl.start()

In [None]:
# Display the current status of the started client task (re-run the cell to refresh).
hfl_context.status()

### ... or use taskgroup

In [None]:
from integrate_ai_sdk.taskgroup.base import SessionTaskGroup

# Alternative to the manual cells above: register the same server task and
# client task on a task group bound to the training session and start them
# with a single `.start()` call. The tasks are built with the same arguments
# as in the manual cells; no explicit `set_session` call is made here.
task_group_context = (
    SessionTaskGroup(training_session)
    .add_task(tb.fls(subnet_id, security_group, storage_path=model_storage, client=client))
    .add_task(tb_batch.hfl(train_path=train_path1, test_path=test_path, vcpus="2", memory="16384", client=client))
    .start()
)

In [None]:
# Wait on the started task group.
# NOTE(review): presumably 300 is a timeout in seconds and polling_interval is
# the number of seconds between status checks — confirm against the SDK documentation.
task_group_context.wait(300, polling_interval=5)