# Federated Learning MNIST Swarm Deployment

This notebook demonstrates how to deploy a federated learning swarm for MNIST classification on the Manta platform.

## What You'll Learn

1. **Setting up the connection** to the Manta platform
2. **Defining modules and federated learning swarm** with workers, aggregator, and scheduler
3. **Deploying the swarm** to a cluster
4. **Monitoring results** from training and evaluation
5. **Streaming logs** for debugging and tracking progress

## Prerequisites

Before running this notebook, ensure you have completed the following steps:

### 1. Create an Account
- Visit [dashboard.manta-tech.io](https://dashboard.manta-tech.io) and create an account
- Create a new cluster and start it

### 2. Install Manta SDK
Install the Manta SDK in your Python environment:
```bash
pip install manta-sdk
```

### 3. Partition the Dataset
Run the data preparation script to partition the MNIST dataset for multiple nodes:
```bash
python prepare_data.py -n <number_of_nodes>
```
This will create partitioned MNIST data in `temp/partitioned/node_0/`, `temp/partitioned/node_1/`, etc.
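The internals of `prepare_data.py` are not shown in this notebook. To give a rough idea of what partitioning involves, here is a minimal sketch of an IID split. It assumes random, near-equal, disjoint shards and Keras-style array names (`x_train`, `y_train`, `x_test`, `y_test`). The real script may split or name things differently, and `partition_indices` / `save_partitions` are hypothetical helpers.

```python
from pathlib import Path

import numpy as np


def partition_indices(num_samples: int, num_nodes: int, seed: int = 0) -> list[np.ndarray]:
    """Shuffle sample indices and split them into num_nodes near-equal, disjoint shards."""
    rng = np.random.default_rng(seed)
    return np.array_split(rng.permutation(num_samples), num_nodes)


def save_partitions(x_train, y_train, x_test, y_test, num_nodes: int, root: Path) -> list[Path]:
    """Write root/node_<i>/mnist.npz for each node (array key names are an assumption)."""
    train_shards = partition_indices(len(x_train), num_nodes, seed=0)
    test_shards = partition_indices(len(x_test), num_nodes, seed=1)
    paths = []
    for i, (train_idx, test_idx) in enumerate(zip(train_shards, test_shards)):
        node_dir = root / f"node_{i}"
        node_dir.mkdir(parents=True, exist_ok=True)
        path = node_dir / "mnist.npz"
        # Each node receives its own slice of both the training and the test data
        np.savez(
            path,
            x_train=x_train[train_idx],
            y_train=y_train[train_idx],
            x_test=x_test[test_idx],
            y_test=y_test[test_idx],
        )
        paths.append(path)
    return paths
```

With MNIST already loaded as NumPy arrays, `save_partitions(x_train, y_train, x_test, y_test, 3, Path("temp/partitioned"))` would produce the `node_0/` to `node_2/` layout described above.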

### 4. Install and Configure Manta Nodes
Install manta-node on each device that will participate in training:
```bash
pip install manta-node
```

Download the node configuration from the Manta dashboard (use the "Configure New Node" button on your cluster page) and save it as `~/.manta/nodes/<node_name>.toml`.

**Important**: Configure each node's dataset path to point to its partition:
- Node 0: dataset path = `/path/to/temp/partitioned/node_0/mnist.npz`
- Node 1: dataset path = `/path/to/temp/partitioned/node_1/mnist.npz`
- etc.
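Before editing each node's configuration, it can help to confirm that every partition file is readable and to print the absolute path to paste into the config. A small sketch, assuming the same hypothetical array names as above (`x_train`, `y_train`, `x_test`, `y_test`); `check_partitions` is an illustrative helper, not part of `manta-node`.

```python
from pathlib import Path

import numpy as np

# Assumed array names; adjust to whatever prepare_data.py actually writes.
EXPECTED_KEYS = {"x_train", "y_train", "x_test", "y_test"}


def check_partitions(root: Path) -> dict[str, Path]:
    """Validate every node_*/mnist.npz under root and return {node_dir_name: absolute_path}."""
    found = {}
    for npz_path in sorted(root.glob("node_*/mnist.npz")):
        with np.load(npz_path) as data:
            missing = EXPECTED_KEYS - set(data.files)
            if missing:
                raise ValueError(f"{npz_path} is missing arrays: {sorted(missing)}")
            if len(data["x_train"]) != len(data["y_train"]):
                raise ValueError(f"{npz_path}: x_train and y_train lengths differ")
        # The absolute path is what goes into the node's dataset path setting
        found[npz_path.parent.name] = npz_path.resolve()
    return found
```

For example, looping over `check_partitions(Path("temp/partitioned")).items()` and printing each pair gives one `node_<i>: /absolute/path/mnist.npz` line per node.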

### 5. Start Manta Nodes
Launch each Manta node with its configuration:
```bash
manta node start <node_name>
```

Verify nodes are connected by checking your cluster dashboard or running:
```bash
manta node status
```

### 6. Docker Image
Ensure the Docker image `manta_light:pytorch` is available on your nodes, or modify the `image` variable in this notebook to use a different PyTorch-compatible image.

## Step 1: Import Libraries and Configure Authentication

First, import the necessary libraries and configure your authentication credentials.

**Replace the credentials below with your actual account credentials from dashboard.manta-tech.io**

In [None]:
from manta.apis.async_user_api import AsyncUserAPI
from manta.common.conversions import bytes_to_dict
from pathlib import Path

from manta import Module, Task, Swarm
from manta.light.utils import numpy_to_bytes

from modules.worker.model import MLP

# Replace with your credentials from dashboard.manta-tech.io
USERNAME = "your-email@example.com"
PASSWORD = "your-password"

api = await AsyncUserAPI.sign_in(
    USERNAME,
    PASSWORD,
    host="api.manta-tech.io",
    port=443,
)

## Step 2: Connect to Manta Platform and Find Active Cluster

This section verifies the connection to the Manta platform and locates an active cluster for deployment:

1. **Check availability**: Confirms that the `api` object signed in during Step 1 can reach the Manta manager service
2. **Find active cluster**: Streams your clusters and selects the first one whose status is RUNNING

The same `api` object, together with the selected `active_cluster_id`, is used for all subsequent operations, including swarm deployment, monitoring, and log streaming.


In [None]:
availability_message = await api.is_available()
print(f"UserAPI availability: {availability_message}")

# Find an active (RUNNING) cluster
print("\nSearching for active cluster...")
async for cluster in api.stream_clusters():
    # Status 1 = RUNNING, 0 = CREATED, 2 = INACTIVE
    if cluster["status"] == 1:
        print("===================Active Cluster Found===================")
        print(f"Cluster ID: {cluster['cluster_id']}")
        print(f"Cluster Name: {cluster['name']}")
        print("Status: RUNNING")
        active_cluster_id = cluster["cluster_id"]
        break
else:
    print("No running cluster found. Please start a cluster from the dashboard.")
    raise RuntimeError("No active cluster available")
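The comment in the search loop documents the numeric status codes (1 = RUNNING, 0 = CREATED, 2 = INACTIVE). The sketch below names those codes and wraps the search in a reusable coroutine. `ClusterStatus` and `first_running_cluster` are hypothetical helpers, not part of the Manta SDK; the only assumption about `api.stream_clusters()` is the one the loop above already makes, namely that it asynchronously yields dicts with a `status` key.

```python
from enum import IntEnum
from typing import AsyncIterable, Optional


class ClusterStatus(IntEnum):
    """Cluster status codes, as listed in the comment in the search loop."""
    CREATED = 0
    RUNNING = 1
    INACTIVE = 2


async def first_running_cluster(clusters: AsyncIterable[dict]) -> Optional[dict]:
    """Return the first cluster dict whose status is RUNNING, or None if none is found."""
    async for cluster in clusters:
        # IntEnum members compare equal to plain ints, so status == 1 matches RUNNING
        if cluster.get("status") == ClusterStatus.RUNNING:
            return cluster
    return None
```

In the notebook this would be used as `cluster = await first_running_cluster(api.stream_clusters())`, followed by a `None` check in place of the `for ... else` branch.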

## Step 3: Understand the Federated Learning Swarm

The `FLSwarm` class defines the complete federated learning workflow with four main components:

### Task Components:

1. **Aggregator Task**: 
   - Combines model weights from workers using federated averaging
   - Runs on any available node (`method="any"`)
   - Limited to 1 instance (`maximum=1`)

2. **Worker Train Task**: 
   - Trains local models on distributed MNIST data
   - Runs on all available nodes with data (`method="all"`)
   - Unlimited instances (`maximum=-1`)
   - Requires MNIST dataset

3. **Scheduler Task**:
   - Coordinates training rounds and checks convergence
   - Decides whether to continue training or stop
   - Runs on any available node (`method="any"`)
   - Limited to 1 instance (`maximum=1`)

4. **Worker Test Task**: 
   - Evaluates the global model on test data
   - Runs on all nodes with test data
   - Provides validation metrics

### Execution Flow:
The `execute()` method defines the task graph:
```
Worker → Aggregator → Test → Scheduler → (loop back to Worker or END)
```
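To make the loop concrete, the following pure-Python simulation mirrors the control flow of the graph above. It assumes the scheduler stops once validation accuracy reaches `val_acc_threshold` (the hyperparameter set to 0.99 in Step 5). The real scheduler logic is in `modules/scheduler.py` and may differ. `train_and_evaluate` stands in for one Worker → Aggregator → Test pass, and `max_rounds` is a hypothetical safety cap that does not appear in the notebook's hyperparameters.

```python
from typing import Callable


def run_rounds(
    train_and_evaluate: Callable[[int], float],
    val_acc_threshold: float = 0.99,
    max_rounds: int = 100,
) -> tuple[int, float, bool]:
    """Repeat training rounds until converged; return (last_iteration, val_acc, converged)."""
    if max_rounds < 1:
        raise ValueError("max_rounds must be at least 1")
    val_acc = 0.0
    for iteration in range(max_rounds):
        val_acc = train_and_evaluate(iteration)  # Worker -> Aggregator -> Test
        if val_acc >= val_acc_threshold:  # Scheduler: has_converged
            return iteration, val_acc, True  # END PROGRAM
    return max_rounds - 1, val_acc, False  # loop back exhausted without converging
```

For example, if successive rounds reach accuracies of 0.90, 0.95, and 0.991, the loop ends after iteration 2.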

### Configuration:
- **Hyperparameters**: Learning rate, batch size, optimizer settings
- **Global Model**: Initial model weights shared across all workers
- **Docker Image**: Specifies the container image with manta-light and PyTorch


## Step 4: Define Modules

Next, define all the modules that will be used in the federated learning swarm. Keeping the module definitions separate makes it easier to understand and modify each component independently.

### Module Definitions

Each module represents a specific task in the federated learning workflow:
- **Aggregator Module**: Combines model weights from workers using federated averaging
- **Worker Module**: Trains local models on distributed MNIST data
- **Scheduler Module**: Coordinates training rounds and checks convergence  
- **Worker Test Module**: Evaluates the global model on test data


In [4]:
# Define the modules that will be used in the swarm
root_path = Path().resolve()
image = "manta_light:pytorch"
gpu = False  # Set to True to use GPU

# Aggregator Module
aggregator_module = Module(
    root_path / "modules" / "aggregator.py",
    image,
    datasets=[],
)

# Worker Train Module
worker_train_module = Module(
    root_path / "modules" / "worker",
    image,
    datasets=["mnist"],
)

# Scheduler Module
scheduler_module = Module(
    root_path / "modules" / "scheduler.py",
    image,
    datasets=[],
)

# Worker Test Module
worker_test_module = Module(
    root_path / "modules" / "worker_test",
    image,
    datasets=["mnist"],
)
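`Module` receives a file (`aggregator.py`, `scheduler.py`) or a directory (`worker`, `worker_test`) under `modules/`. A quick existence check before deployment can catch a wrong working directory early. This is a sketch; `MODULE_PATHS` and `missing_module_paths` are hypothetical helpers that simply mirror the paths used in the cell above.

```python
from pathlib import Path

# Paths used by the Module definitions above, relative to the notebook directory
MODULE_PATHS = {
    "aggregator": Path("modules") / "aggregator.py",
    "worker_train": Path("modules") / "worker",
    "scheduler": Path("modules") / "scheduler.py",
    "worker_test": Path("modules") / "worker_test",
}


def missing_module_paths(root: Path) -> list[str]:
    """Return the names of modules whose file or directory does not exist under root."""
    return [name for name, rel in MODULE_PATHS.items() if not (root / rel).exists()]
```

In the notebook, `assert not missing_module_paths(root_path)` would stop execution before deployment if any module is missing.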

## Step 5: Define the Federated Learning Swarm

Now we'll create the FLSwarm class that uses the modules defined above. This approach separates the module definitions from the swarm logic, making the code more modular and easier to understand.


In [5]:
class FLSwarm(Swarm):
    def __init__(
        self,
        aggregator_module: Module,
        worker_train_module: Module,
        scheduler_module: Module,
        worker_test_module: Module,
        gpu: bool = False,
    ):
        super().__init__()

        # Store modules
        self.aggregator_module = aggregator_module
        self.worker_train_module = worker_train_module
        self.scheduler_module = scheduler_module
        self.worker_test_module = worker_test_module
        self.gpu = gpu

        # Set hyperparameters
        self.set_global(
            "hyperparameters",
            {
                "epochs": 1,
                "batch_size": 32,
                "loss": "CrossEntropyLoss",
                "loss_params": {},
                "optimizer": "SGD",
                "optimizer_params": {"lr": 0.01, "momentum": 0.9},
                "val_acc_threshold": 0.99,
            },
        )

        # Set global model parameters
        self.set_global("global_model_params", numpy_to_bytes(MLP().get_weights()))

    def execute(self):
        """
        Build the task graph.

        +--------+     +------------+     +------+     +-----------+ if has_converged
        | Worker | --> | Aggregator | --> | Test | --> | Scheduler | ----------------> END PROGRAM
        +--------+     +------------+     +------+     +-----------+
            |                                                       | else
            +--<<<----------<<<-------------<<<------------<<<------+
        """
        m = Task(
            self.worker_train_module,
            gpu=self.gpu,
        )()
        m = Task(
            self.aggregator_module,
            method="any",
            fixed=True,
            maximum=1,
        )(m)
        m = Task(
            self.worker_test_module,
            gpu=self.gpu,
        )(m)
        return Task(
            self.scheduler_module,
            method="any",
            fixed=True,
            maximum=1,
        )(m)
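`set_global("global_model_params", ...)` stores the initial model as bytes, produced by `numpy_to_bytes(MLP().get_weights())`. The exact byte format used by `manta.light.utils.numpy_to_bytes` is not shown in this notebook. The stand-in below only illustrates the idea of packing an ordered list of weight arrays into one byte string and recovering it. `weights_to_bytes` and `bytes_to_weights` are hypothetical names and are not interchangeable with Manta's helpers.

```python
import io

import numpy as np


def weights_to_bytes(weights: list[np.ndarray]) -> bytes:
    """Pack an ordered list of arrays into a single .npz byte string."""
    buffer = io.BytesIO()
    # Positional arrays are stored as arr_0, arr_1, ... which preserves layer order
    np.savez(buffer, *weights)
    return buffer.getvalue()


def bytes_to_weights(blob: bytes) -> list[np.ndarray]:
    """Inverse of weights_to_bytes: rebuild the list in its original order."""
    with np.load(io.BytesIO(blob)) as data:
        return [data[f"arr_{i}"] for i in range(len(data.files))]
```

Indexing by `arr_{i}` rather than iterating over `data.files` keeps the order correct even with ten or more layers.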

In [6]:
# Create the swarm instance using the pre-defined modules
swarm = FLSwarm(
    aggregator_module=aggregator_module,
    worker_train_module=worker_train_module,
    scheduler_module=scheduler_module,
    worker_test_module=worker_test_module,
    gpu=gpu,
)

print("Federated Learning Swarm created successfully!")
print(f"Using GPU: {gpu}")
print(f"Image: {image}")

Federated Learning Swarm created successfully!
Using GPU: False
Image: manta_light:pytorch


## Step 6: Deploy the Swarm to the Cluster

This section deploys the federated learning swarm to the active cluster:

### Deployment Process:
1. **Use the Swarm Instance**: The `swarm` object created in Step 5 carries the modules, task graph, and global values
2. **Deploy to Cluster**: Call `api.deploy_swarm(active_cluster_id, swarm)` to submit the swarm
3. **Get Deployment Overview**: The call returns a dictionary with confirmation and metadata about the deployment

### Important Deployment Information:
- **swarm_id**: Unique identifier for tracking and monitoring the swarm
- **cluster_id**: The cluster the swarm was deployed to
- **status**: Current state of the swarm (`ACTIVE` immediately after deployment in the output below)
- **iteration** / **circular**: Training round and cycle counters (both 0 right after deployment)
- **authorization**: Storage and concurrency quotas granted to the swarm

The swarm will automatically start executing once suitable nodes are available and the required Docker images are pulled.


In [7]:
print("\nDeploying swarm...")
swarm_overview = await api.deploy_swarm(active_cluster_id, swarm)
print("Swarm Deployment Overview:")
for key, value in swarm_overview.items():
    print(f"  {key}: {value}")

swarm_id = swarm_overview["swarm_id"]


Deploying swarm...
Swarm Deployment Overview:
  swarm_id: a41d4d1b01ff44f58ce8207d3c7273ca
  cluster_id: 90700570a2ce4b1891b32341a192fc1a
  owner_id: a6afeb73446e409aafc8f8af10eadb0f
  name: Swarm
  created_at: 2025-10-16 08:26:18.097000+00:00
  last_start: 1970-01-01 00:00:00+00:00
  last_stop: 1970-01-01 00:00:00+00:00
  status: ACTIVE
  iteration: 0
  circular: 0
  authorization: {'swarm_id': 'a41d4d1b01ff44f58ce8207d3c7273ca', 'quotas': {'max_storage_gb': 10.0, 'current_storage_gb': 3.814697265625e-06, 'max_concurrent_tasks': 10, 'current_concurrent_tasks': 2}, 'permissions': [], 'permissions_updated_at': datetime.datetime(2025, 10, 16, 8, 26, 18, 97000, tzinfo=datetime.timezone.utc)}
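The `quotas` entry inside `authorization` pairs each `max_*` limit with a `current_*` usage value. A small helper can turn these into readable `current / max` strings. This sketch assumes `authorization` is a plain dict, as its printed form above suggests; `summarize_quotas` is a hypothetical helper.

```python
def summarize_quotas(overview: dict) -> dict[str, str]:
    """Map each quota resource (e.g. 'concurrent_tasks') to a 'current / max' string."""
    quotas = (overview.get("authorization") or {}).get("quotas") or {}
    summary = {}
    for key, max_value in quotas.items():
        if key.startswith("max_"):
            resource = key[len("max_"):]
            current = quotas.get(f"current_{resource}")
            summary[resource] = f"{current} / {max_value}"
    return summary
```

Calling `print(summarize_quotas(swarm_overview))` after deployment would show, for the output above, that 2 of the 10 allowed concurrent tasks are in use.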


## Step 7: Monitor Training Results

This section shows how to monitor the training progress by streaming results from the swarm:

### Result Monitoring:
- **Stream Results**: Use `api.stream_results(swarm_id, "metrics")` to receive real-time updates
- **Filter by Tag**: Specify "metrics" to get training/validation metrics
- **Real-time Updates**: Results arrive as tasks complete each training iteration

### Key Metrics Available:
- **val_loss**: Validation loss after each training round
- **val_acc**: Validation accuracy after each training round
- **node_id**: Which node generated the result
- **task_id**: Which task (worker/aggregator) produced the metric
- **iteration**: Current training round number

### Understanding the Output:
Each result contains metadata about the task execution and the actual training metrics. The federated learning process will show results from:
- **Workers**: Local training metrics from each participating node
- **Aggregator**: Global model performance after weight aggregation
- **Test Workers**: Validation results on test data

💡 **Tip**: The streaming continues until the swarm completes or convergence is reached.


In [None]:
async for result in api.stream_results(swarm_id, "metrics"):
    for key, value in result.items():
        if key == "data":
            print(f"  {key}: {bytes_to_dict(value)}")
        else:
            print(f"  {key}: {value}")

    print("-" * 100)
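To track progress over rounds instead of reading raw output, decoded metrics can be averaged per iteration across the test workers. This is a sketch under assumptions: each record is a flat dict holding an `iteration` number and, when validation ran, a `val_acc` value. The record could be built inside the loop above as `{"iteration": result["iteration"], **bytes_to_dict(result["data"])}`, provided the result carries an `iteration` key as the metric list above suggests. `accuracy_by_iteration` is a hypothetical helper.

```python
from collections import defaultdict
from typing import Iterable


def accuracy_by_iteration(records: Iterable[dict]) -> dict[int, float]:
    """Average val_acc over all records that reported it, grouped by iteration."""
    per_iteration: dict[int, list[float]] = defaultdict(list)
    for record in records:
        if "val_acc" in record:  # skip records carrying only loss or other metrics
            per_iteration[record["iteration"]].append(record["val_acc"])
    return {it: sum(accs) / len(accs) for it, accs in sorted(per_iteration.items())}
```

Comparing the latest value with the `val_acc_threshold` hyperparameter (0.99) gives a rough sense of how close the swarm is to stopping.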

## Step 8: Monitor Execution Logs

This section demonstrates how to stream logs from the swarm execution for debugging and progress tracking:

### Log Streaming:
- **Stream Logs**: Use `api.stream_logs(swarm_id)` to receive real-time log output
- **All Tasks**: Logs from all task types (workers, aggregator, scheduler)
- **Debugging Information**: Detailed execution traces and error messages

### Log Information Includes:
- **node_id**: Which node is executing the task
- **task_id**: Specific task instance generating the log
- **iteration/circular**: Training round and cycle information
- **timestamp**: When the log entry was generated
- **severity**: Log level (INFO, WARNING, ERROR, COMPLETED)
- **content**: Detailed log messages from the task execution

### Log Content Examples:
- **Configuration**: Task setup and parameter initialization
- **Data Loading**: Dataset access and preparation
- **Training Progress**: Model training and validation steps
- **Communication**: Inter-task messaging and coordination
- **Completion**: Task completion and cleanup

### Using Logs for Debugging:
- Monitor for ERROR severity messages to identify issues
- Track COMPLETED messages to see task progression
- Use timestamps to understand execution timing
- Check node_id to identify which nodes are having problems

🔍 **Debugging Tip**: If training seems stuck, check the logs for connection issues, data loading problems, or resource constraints.


In [None]:
# Stream logs for the swarm deployed above
async for log in api.stream_logs(swarm_id):
    for key, value in log.items():
        print(f"  {key}: {value}")
    print("-" * 100)
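Following the debugging advice above, it is often enough to gather only the ERROR entries, grouped by node. Because the log stream may stay open while the swarm runs, the sketch below stops after a timeout using `asyncio.wait_for`. It assumes each log is a dict whose `severity` is the plain string `"ERROR"` and which carries `node_id` and `content` keys, as listed above. `errors_by_node` is a hypothetical helper.

```python
import asyncio
from collections import defaultdict
from typing import AsyncIterable


async def errors_by_node(logs: AsyncIterable[dict], timeout: float = 30.0) -> dict[str, list[str]]:
    """Collect ERROR log contents per node_id until the stream ends or the timeout expires."""
    errors: dict[str, list[str]] = defaultdict(list)

    async def consume() -> None:
        async for log in logs:
            if log.get("severity") == "ERROR":
                errors[log.get("node_id", "unknown")].append(log.get("content", ""))

    try:
        await asyncio.wait_for(consume(), timeout)
    except asyncio.TimeoutError:
        pass  # the stream may remain open while the swarm runs; keep what was collected
    return dict(errors)
```

In the notebook, `errors = await errors_by_node(api.stream_logs(swarm_id), timeout=60)` would return something like `{node_id: [messages, ...]}`, which points directly at the nodes that need attention.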

In [None]:
# Optional: compact, one-line-per-entry log view (swarm_id is set by the deployment cell above)
# Uncomment the lines below to stream logs
# async for log in api.stream_logs(swarm_id):
#     print(f"[{log.get('severity', 'INFO')}] Node {log.get('node_id', 'unknown')[:8]}: {log.get('content', '')}")
#     print("-" * 80)