# Federated Learning CIFAR10 Swarm Deployment

This notebook demonstrates how to deploy a federated learning swarm for CIFAR10 image classification on the Manta platform.

## What You'll Learn

1. **Setting up the connection** to the Manta platform
2. **Defining modules and federated learning swarm** with workers, aggregator, and scheduler
3. **Deploying the swarm** to a cluster
4. **Monitoring results** from training and evaluation

## Prerequisites

Before running this notebook, ensure you have completed the following steps:

### 1. Create an Account
- Visit [dashboard.manta-tech.io](https://dashboard.manta-tech.io) and create an account
- Create a new cluster and start it

### 2. Install Manta SDK
Install the Manta SDK in your Python environment:
```bash
pip install manta-sdk
```

### 3. Partition the Dataset
Run the data preparation script to partition the CIFAR10 dataset for multiple nodes:
```bash
python prepare_data.py -n <number_of_nodes>
```
This will create partitioned CIFAR10 data in `temp/partitioned/node_0/`, `temp/partitioned/node_1/`, etc.
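
`prepare_data.py` itself is not shown in this notebook, so the following is only a minimal sketch of what an IID split across nodes could look like. The function name `partition_indices` and the fixed seed are assumptions made for illustration; the real script may partition differently (for example, non-IID by class). The sample count of 50,000 is the size of the CIFAR10 training set.

```python
import random


def partition_indices(num_samples: int, num_nodes: int, seed: int = 0) -> list[list[int]]:
    """Shuffle sample indices and deal them into `num_nodes` near-equal shards."""
    if num_nodes < 1:
        raise ValueError("num_nodes must be at least 1")
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # fixed seed keeps the split reproducible
    # Striding assigns every num_nodes-th index to the same shard,
    # so shard sizes differ by at most one sample.
    return [indices[node::num_nodes] for node in range(num_nodes)]


# Hypothetical usage: split the 50,000 CIFAR10 training images across 4 nodes.
shards = partition_indices(num_samples=50_000, num_nodes=4)
for node, shard in enumerate(shards):
    print(f"node_{node}: {len(shard)} samples")
```

Each shard of indices would then be used to slice the image and label arrays before saving one file per node.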

### 4. Install and Configure Manta Nodes
Install manta-node on each device that will participate in training:
```bash
pip install manta-node
```

Download the node configuration from the Manta dashboard (use the "Configure New Node" button on your cluster page) and save it as `~/.manta/nodes/<node_name>.toml`.

**Important**: Configure each node's dataset path to point to its partition:
- Node 0: dataset path = `/path/to/temp/partitioned/node_0/cifar10.npz`
- Node 1: dataset path = `/path/to/temp/partitioned/node_1/cifar10.npz`
- etc.
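
Because the per-node paths follow a regular pattern, they can be generated rather than typed by hand. This is a small sketch; the helper name `partition_path` is hypothetical, and `/path/to` stands in for wherever you ran `prepare_data.py`.

```python
from pathlib import PurePosixPath


def partition_path(base_dir: str, node_index: int) -> PurePosixPath:
    """Build the dataset path for one node, following the layout described above."""
    return PurePosixPath(base_dir) / "temp" / "partitioned" / f"node_{node_index}" / "cifar10.npz"


# Print the path to paste into each node's configuration file.
for node_index in range(3):
    print(f"Node {node_index}: {partition_path('/path/to', node_index)}")
```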

### 5. Start Manta Nodes
Launch each Manta node with its configuration:
```bash
manta node start <node_name>
```

Verify nodes are connected by checking your cluster dashboard.

### 6. Docker Image
Ensure the Docker image `manta_light:pytorch` is available on your nodes, or modify the `image` variable in this notebook.

## Step 1: Import Libraries and Configure Authentication

First, import the necessary libraries and configure your authentication credentials.

**Replace the credentials below with your actual account credentials from dashboard.manta-tech.io**

In [None]:
from manta.apis.async_user_api import AsyncUserAPI
from pathlib import Path

from manta import Module, Swarm, Task
from modules.worker import CNN

# Replace with your credentials from dashboard.manta-tech.io
USERNAME = "your-email@example.com"
PASSWORD = "your-password"

api = await AsyncUserAPI.sign_in(
    USERNAME,
    PASSWORD,
    host="api.manta-tech.io",
    port=443,
)
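
If you would rather not keep credentials in the notebook source, one option is to read them from environment variables. The sketch below assumes two variable names, `MANTA_USERNAME` and `MANTA_PASSWORD`, which are chosen here for illustration and are not defined by the Manta SDK.

```python
import os


def load_credentials(env=None) -> tuple[str, str]:
    """Read credentials from a mapping (os.environ by default) instead of notebook source."""
    env = os.environ if env is None else env
    try:
        # Variable names are assumptions for this sketch, not SDK conventions.
        return env["MANTA_USERNAME"], env["MANTA_PASSWORD"]
    except KeyError as missing:
        raise RuntimeError(f"Set {missing.args[0]} before running this notebook") from None


# Example with an explicit mapping; omit the argument to read os.environ.
USERNAME, PASSWORD = load_credentials(
    {"MANTA_USERNAME": "your-email@example.com", "MANTA_PASSWORD": "your-password"}
)
```

The resulting `USERNAME` and `PASSWORD` can be passed to `AsyncUserAPI.sign_in` exactly as in the cell above.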

## Step 2: Connect to Manta Platform and Find Active Cluster

This section checks the connection to the Manta platform and locates an active cluster for deployment:

1. **Reuse the UserAPI**: Uses the `api` connection to the Manta manager service created in Step 1
2. **Check availability**: Verifies the connection is working
3. **Find active cluster**: Searches for a running cluster to deploy the swarm and stores its ID in `active_cluster_id`

The cluster API for that cluster is used for all subsequent operations, including swarm deployment, monitoring, and log streaming.


In [None]:
availability_message = await api.is_available()
print(f"UserAPI availability: {availability_message}")

# Find an active (RUNNING) cluster
print("\nSearching for active cluster...")
async for cluster in api.stream_clusters():
    # Status 1 = RUNNING, 0 = CREATED, 2 = INACTIVE
    if cluster["status"] == 1:
        print("===================Active Cluster Found===================")
        print(f"Cluster ID: {cluster['cluster_id']}")
        print(f"Cluster Name: {cluster['name']}")
        print("Status: RUNNING")
        active_cluster_id = cluster["cluster_id"]
        break
else:
    print("No running cluster found. Please start a cluster from the dashboard.")
    raise RuntimeError("No active cluster available")
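
The numeric status codes in the cell above are easier to read with names attached. The sketch below mirrors the values given in the comment (0 = CREATED, 1 = RUNNING, 2 = INACTIVE); the `ClusterStatus` enum and the `first_running` helper are hypothetical names for illustration, and the SDK may expose its own equivalent.

```python
from enum import IntEnum


class ClusterStatus(IntEnum):
    """Named versions of the status codes listed in the comment above."""
    CREATED = 0
    RUNNING = 1
    INACTIVE = 2


def first_running(clusters):
    """Return the first cluster dict whose status is RUNNING, or None."""
    return next((c for c in clusters if c["status"] == ClusterStatus.RUNNING), None)


# Illustrative cluster records shaped like those used in the loop above.
sample = [
    {"cluster_id": "a", "name": "idle", "status": 0},
    {"cluster_id": "b", "name": "live", "status": 1},
]
print(first_running(sample))
```

Since `IntEnum` members compare equal to plain integers, `cluster["status"] == ClusterStatus.RUNNING` behaves the same as the `== 1` check.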

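## Step 3: Define the Modules

Each module pairs a Python script with the Docker image it runs in and the datasets it needs.

In [None]: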
# Define the modules that will be used in the swarm
root_path = Path().resolve()
image = "manta_light:pytorch"
gpu = False  # Set to True to use GPU for better performance with CNN

# Aggregator Module
aggregator_module = Module(
    root_path / "modules" / "aggregator.py",
    image,
    datasets=[],
)

# Worker Module (CNN training on CIFAR10)
worker_module = Module(
    root_path / "modules" / "worker.py",
    image,
    datasets=["cifar10"],
)

# Scheduler Module
scheduler_module = Module(
    root_path / "modules" / "scheduler.py",
    image,
    datasets=[],
)

## Step 4: Define the Federated Learning Swarm

Now we'll create the FLSwarm class that uses the modules defined above. This approach separates the module definitions from the swarm logic, making the code more modular and easier to understand.

### CIFAR10 FL Architecture

This swarm uses a simpler workflow than the MNIST example and focuses on the core FL loop:
```
Worker (CNN Training) → Aggregator (FedAvg) → Scheduler (Convergence Check) → Loop/End
```
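
To make the aggregation step concrete, here is a minimal sketch of FedAvg on flat weight vectors. It assumes the sample-count-weighted form of FedAvg; `modules/aggregator.py` is not shown in this notebook, so it may weight or structure the model parameters differently. The function name `fed_avg` is hypothetical.

```python
def fed_avg(updates):
    """Average worker weight vectors, weighted by each worker's sample count.

    `updates` is a list of (weights, num_samples) pairs, where `weights`
    is a flat list of floats with the same length for every worker.
    """
    total = sum(n for _, n in updates)
    if total == 0:
        raise ValueError("no samples to average over")
    size = len(updates[0][0])
    return [
        sum(weights[i] * n for weights, n in updates) / total
        for i in range(size)
    ]


# Two workers: the second holds three times as much data, so it dominates.
global_weights = fed_avg([([0.0, 4.0], 100), ([4.0, 0.0], 300)])
print(global_weights)  # [3.0, 1.0]
```

In a real model the same averaging would be applied to every parameter tensor rather than to a single flat list.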


In [None]:
class FLSwarm(Swarm):
    def __init__(
        self,
        aggregator_module: Module,
        worker_module: Module,
        scheduler_module: Module,
        gpu: bool = False,
    ):
        super().__init__()

        # Create tasks from the provided modules
        self.aggregator = Task(
            aggregator_module,
            method="any",
            fixed=False,
            maximum=1,
            gpu=False,
        )

        self.worker = Task(
            worker_module,
            method="all",
            fixed=False,
            maximum=-1,
            gpu=gpu,
        )

        self.scheduler = Task(
            scheduler_module,
            method="any",
            fixed=False,
            maximum=1,
            gpu=False,
        )

        # Set hyperparameters optimized for CIFAR10 CNN training
        self.set_global(
            "hyperparameters",
            {
                "epochs": 1,
                "batch_size": 64,
                "loss": "CrossEntropyLoss",
                "loss_params": {},
                "optimizer": "SGD",
                "optimizer_params": {
                    "lr": 0.001,
                    "momentum": 0.9,
                    "weight_decay": 5e-4,
                },
                "val_acc_threshold": 0.80,  # Lower threshold for CIFAR10
            },
        )
        # Set global model parameters (CNN weights)
        self.set_global("global_model_params", CNN().get_weights())

    def execute(self):
        """
        Generation of the task graph

        +--------+     +------------+     +-----------+ if has_converged
        | Worker | --> | Aggregator | --> | Scheduler | ----------------> END PROGRAM
        +--------+     +------------+     +-----------+
            |                                   | else
            +--<<<----------<<<----------<<<----+
        """
        m = self.worker()
        m = self.aggregator(m)
        return self.scheduler(m)
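
The scheduler's branch in the task graph depends on the `val_acc_threshold` hyperparameter. As a rough sketch of that decision, assuming the scheduler compares validation accuracy against the threshold with `>=` (the actual logic lives in `modules/scheduler.py`, which is not shown here), it might look like the following. The function name `next_step` and the accuracy values are illustrative, not measured results.

```python
def next_step(val_acc: float, hyperparameters: dict) -> str:
    """Decide whether to stop or run another round, based on validation accuracy."""
    if val_acc >= hyperparameters["val_acc_threshold"]:
        return "end"
    return "loop"


hyperparameters = {"val_acc_threshold": 0.80}

# Illustrative accuracy values only, not results from a real run.
for round_index, val_acc in enumerate([0.42, 0.67, 0.81], start=1):
    print(f"round {round_index}: val_acc={val_acc:.2f} -> {next_step(val_acc, hyperparameters)}")
```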

In [None]:
# Create the swarm instance using the pre-defined modules
swarm = FLSwarm(
    aggregator_module=aggregator_module,
    worker_module=worker_module,
    scheduler_module=scheduler_module,
    gpu=gpu,
)

print("CIFAR10 Federated Learning Swarm created successfully!")
print(f"Using GPU: {gpu}")
print(f"Image: {image}")
print("Ready for deployment to cluster!")

## Step 5: Deploy the Swarm

Now we can deploy the swarm to the active cluster and monitor its execution.


In [None]:
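# Deploy the swarm to the cluster. `cluster_api` is assumed to be the cluster API
# handle for `active_cluster_id`; obtain it from the SDK before running this cell.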
swarm_overview = await cluster_api.send_swarm(swarm)
print("Swarm deployed successfully!")
print(f"Swarm ID: {swarm_overview['swarm_id']}")
print(f"Status: {swarm_overview['status']}")
print(f"Created at: {swarm_overview['datetime']}")

# Start the swarm execution
start_response = await cluster_api.start_swarm(swarm_overview["swarm_id"])
print("\nSwarm execution started!")
print(f"Start response: {start_response}")
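
Once the swarm is running, you will usually want to wait for it to finish before reading results. The sketch below is a generic polling helper and does not call the Manta SDK: `fetch_status` is a hypothetical coroutine that you would implement with whatever status call your SDK version provides, and the status strings are placeholders.

```python
import asyncio


async def wait_until(fetch_status, is_done, interval: float = 5.0, max_polls: int = 120):
    """Poll `fetch_status` until `is_done(status)` is true or the polls run out."""
    for _ in range(max_polls):
        status = await fetch_status()
        if is_done(status):
            return status
        await asyncio.sleep(interval)
    raise TimeoutError(f"not done after {max_polls} polls")


async def _demo():
    # Placeholder status values standing in for real swarm states.
    statuses = iter(["pending", "running", "finished"])

    async def fake_fetch():
        return next(statuses)

    return await wait_until(fake_fetch, lambda s: s == "finished", interval=0.0)


# In a plain script use asyncio.run; inside a notebook cell, `await _demo()` instead.
final_status = asyncio.run(_demo())
print(final_status)
```

Capping the number of polls keeps a stalled swarm from blocking the notebook indefinitely.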