# ATLAS: Publication-Quality Experiments for IEEE

**Multi-Task Federated Learning with Heterogeneous Devices**  
**Session-Based Training for Long Runs (30+ Rounds)**

---

## 📋 Experimental Configuration

### Publication Parameters
- **Rounds**: 30 (split into 15+15 sessions if needed)
- **Samples per client**: 3000-5000 (publication quality)
- **Local epochs**: 3
- **Checkpointing**: Every 5 rounds (resume later with `--resume`)
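
To get a rough sense of per-experiment workload, assume the 9-client setup used below (3 tasks × 3 clients per task), and assume every client makes one pass over all of its samples in each local epoch. At 5,000 samples per client:

$$
30\ \text{rounds} \times 3\ \text{local epochs} \times 9\ \text{clients} \times 5000\ \text{samples} = 4{,}050{,}000\ \text{sample passes}
$$

At 3,000 samples per client, the same product is 2,430,000. That is a 40% reduction in local training work.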

### Multi-Domain Tasks
- **NLP**: SST-2 (sentiment), MRPC (paraphrase), CoLA (grammatical acceptability), QNLI (question-answering NLI)
- **Vision** (optional): CIFAR-10, MNIST variants
- **Speech** (optional): Speech commands

### Models Supported
- DistilBERT (default)
- BERT-base
- RoBERTa
- GPT-2
- (Easily extensible)

---

## 🎯 Key Features

1. **Session-based training**: Split 30 rounds into multiple Colab sessions
2. **Checkpoint resume**: Continue from the last saved checkpoint with `--resume`
3. **Model-agnostic**: Test different architectures
4. **Multi-domain**: True heterogeneous multi-task FL
5. **IEEE-quality results**: Publication-scale settings (30 rounds, 3000-5000 samples per client, 3 local epochs)

## 🔧 Setup

In [None]:
# Install dependencies
!pip install -q torch transformers datasets peft scikit-learn scipy numpy pandas matplotlib

# Check GPU
import torch
print(f"✓ GPU Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"✓ GPU: {torch.cuda.get_device_name(0)}")
    print(f"✓ VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

In [None]:
# Clone or update repository
from pathlib import Path

if Path('ATLAS').exists():
    %cd ATLAS
    !git pull origin main
    print("✓ Repository updated")
else:
    !git clone https://github.com/mahmoudmayaleh/ATLAS.git
    %cd ATLAS
    print("✓ Repository cloned")

---

## 🚀 Experiment 1: ATLAS Full Pipeline (30 Rounds)

**Session-Based Training**:
- Session 1: Rounds 1-15 (~2-3 hours)
- Session 2: Rounds 16-30 (~2-3 hours)

Checkpoints saved every 5 rounds automatically.
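
The session split relies on a simple pattern: save a checkpoint every k rounds, then resume from that file. Below is a minimal sketch of the pattern. It is not the actual logic in `experiments/atlas_integrated.py`. `train_one_round` and `run_session` are hypothetical, and the state is a toy dict rather than model weights.

```python
import pickle
from pathlib import Path

CHECKPOINT_EVERY = 5  # mirrors the "every 5 rounds" schedule described above


def train_one_round(state, round_idx):
    """Hypothetical stand-in for one federated round: it only records the round."""
    state["history"].append(round_idx)
    return state


def run_session(total_rounds, rounds_this_session=None, resume=None,
                ckpt_dir="checkpoints", prefix="atlas_round_"):
    """Train until total_rounds or until this session's round budget runs out.

    Saves a checkpoint every CHECKPOINT_EVERY rounds and at the session's last
    round, so a later session can pick up exactly where this one stopped.
    """
    ckpt_dir = Path(ckpt_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    if resume is not None:
        with open(resume, "rb") as f:
            state = pickle.load(f)  # only unpickle files you created yourself
    else:
        state = {"last_round": 0, "history": []}

    start = state["last_round"] + 1
    stop = total_rounds
    if rounds_this_session is not None:
        stop = min(total_rounds, state["last_round"] + rounds_this_session)

    for r in range(start, stop + 1):
        state = train_one_round(state, r)
        state["last_round"] = r
        if r % CHECKPOINT_EVERY == 0 or r == stop:
            with open(ckpt_dir / f"{prefix}{r}.pkl", "wb") as f:
                pickle.dump(state, f)
    return state
```

The sketch saves at the session's last round as well as every 5th round. A session budget that is not a multiple of 5 therefore loses no work.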

### Session 1: Rounds 1-15

In [None]:
# Session 1: Train rounds 1-15
# This will save checkpoints at rounds 5, 10, 15

!python experiments/atlas_integrated.py \
    --mode full \
    --rounds 30 \
    --max-rounds 15 \
    --ablation atlas \
    --model distilbert-base-uncased \
    --tasks sst2 mrpc cola \
    --clients-per-task 3 \
    --samples 5000 \
    --local-epochs 3

print("\n✓ Session 1 complete (Rounds 1-15)")
print("✓ Checkpoint saved: checkpoints/atlas_round_15.pkl")
print("\n⏸️  Back up checkpoints/ (e.g. to Google Drive) before disconnecting;")
print("   Colab's local disk is reset when the runtime ends")
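
Colab's local disk does not survive a new runtime. The checkpoint must therefore be copied to persistent storage before you disconnect. Below is a minimal sketch that assumes Google Drive is mounted at `/content/drive`. The `backup_checkpoints` helper and the `atlas_checkpoints` folder name are hypothetical.

```python
import shutil
from pathlib import Path


def backup_checkpoints(src="checkpoints", dst="/content/drive/MyDrive/atlas_checkpoints"):
    """Copy every *.pkl checkpoint from src into dst, creating dst if needed.

    The default dst is a hypothetical Drive folder; it only persists if Google
    Drive has been mounted at /content/drive beforehand.
    """
    src, dst = Path(src), Path(dst)
    dst.mkdir(parents=True, exist_ok=True)
    copied = []
    for ckpt in sorted(src.glob("*.pkl")):
        shutil.copy2(ckpt, dst / ckpt.name)  # copy2 also preserves timestamps
        copied.append(ckpt.name)
    return copied
```

In Colab, run `from google.colab import drive; drive.mount('/content/drive')` first. In the next session, re-run the setup cells. The same helper then restores the files with `backup_checkpoints(src='/content/drive/MyDrive/atlas_checkpoints', dst='checkpoints')`.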

### Session 2: Rounds 16-30 (Resume from checkpoint)

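**Run this in a NEW Colab session** (or continue in the same one). In a new session, first re-run the setup cells above. Then restore `checkpoints/atlas_round_15.pkl` (for example from Google Drive), because Colab's local disk does not persist across runtimes: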

In [None]:
# Session 2: Resume from round 15, continue to round 30
# --resume loads the saved checkpoint and continues training from round 16

!python experiments/atlas_integrated.py \
    --mode full \
    --rounds 30 \
    --resume checkpoints/atlas_round_15.pkl \
    --ablation atlas \
    --model distilbert-base-uncased \
    --tasks sst2 mrpc cola \
    --clients-per-task 3 \
    --samples 5000 \
    --local-epochs 3

print("\n✓ Session 2 complete (Rounds 16-30)")
print("✓ Final results saved: results/atlas_integrated_full_atlas.json")

### Alternative: Single Session (if you have 4+ hours)

If your Colab session lasts long enough, run all 30 rounds at once:

In [None]:
# Run all 30 rounds in one go (~4-6 hours, i.e. two ~2-3 hour halves)
!python experiments/atlas_integrated.py \
    --mode full \
    --rounds 30 \
    --ablation atlas \
    --model distilbert-base-uncased \
    --tasks sst2 mrpc cola \
    --clients-per-task 3 \
    --samples 5000 \
    --local-epochs 3

---

## 📊 Experiment 2: Ablation Studies

**Compare**:
1. ATLAS Full (all 4 phases)
2. FedAvg per Cluster (no Laplacian)
3. Local Only (no aggregation)

Each runs 30 rounds for rigorous comparison.

### 2.1: FedAvg per Cluster (Ablation)

In [None]:
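# Ablation: FedAvg within clusters (Phases 1-3 only, no Laplacian)
# Session 1: Rounds 1-15
# Model/task flags match Experiment 1 so the comparison is like-for-like

!python experiments/atlas_integrated.py \
    --mode full \
    --rounds 30 \
    --max-rounds 15 \
    --ablation fedavg_cluster \
    --model distilbert-base-uncased \
    --tasks sst2 mrpc cola \
    --clients-per-task 3 \
    --samples 5000 \
    --local-epochs 3

# Checkpoint location: this is the same path Experiment 1 uses. If the script does not
# put the ablation name in the filename, finish or back up Experiment 1's checkpoint first
fedavg_checkpoint = "checkpoints/atlas_round_15.pkl"
print(f"\n✓ FedAvg ablation - Session 1 complete")
print(f"✓ Resume with: --resume {fedavg_checkpoint}")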

In [None]:
# Session 2: Continue rounds 16-30
!python experiments/atlas_integrated.py \
    --mode full \
    --rounds 30 \
    --resume checkpoints/atlas_round_15.pkl \
    --ablation fedavg_cluster \
    --model distilbert-base-uncased \
    --tasks sst2 mrpc cola \
    --clients-per-task 3 \
    --samples 5000 \
    --local-epochs 3

### 2.2: Local Only Baseline

In [None]:
# Baseline: Local training only (no aggregation)
# Expected to be faster (no communication), but typically less accurate

!python experiments/atlas_integrated.py \
    --mode full \
    --rounds 30 \
    --ablation local_only \
    --model distilbert-base-uncased \
    --tasks sst2 mrpc cola \
    --clients-per-task 3 \
    --samples 5000 \
    --local-epochs 3

---

## 🔬 Experiment 3: Lambda (η) Sweep

Test different Laplacian regularization strengths: {0.0, 0.01, 0.1, 0.5, 1.0}

This helps identify the best trade-off between personalization and convergence.
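
To see what η controls, consider a generic graph-Laplacian regularizer with quadratic local losses, which has a closed-form solution. This is an illustrative assumption, not necessarily ATLAS's exact Phase 4 update. `laplacian_smooth` and the toy similarity matrix are hypothetical.

```python
import numpy as np


def laplacian_smooth(U, A, eta):
    """Minimise  sum_i ||w_i - u_i||^2 + eta * sum_{i<j} a_ij * ||w_i - w_j||^2.

    U:   (n_clients, d) locally trained parameters, one row per client
    A:   (n_clients, n_clients) symmetric, non-negative similarity matrix
    eta: regularisation strength; 0 leaves the local models unchanged

    Setting the gradient to zero gives (I + eta * L) W = U with L = D - A.
    """
    L = np.diag(A.sum(axis=1)) - A  # graph Laplacian
    return np.linalg.solve(np.eye(len(A)) + eta * L, U)


rng = np.random.default_rng(0)
U = rng.normal(size=(3, 4))         # 3 clients, 4 toy parameters each
A = np.ones((3, 3)) - np.eye(3)     # one fully connected cluster
for eta in [0.0, 0.01, 0.1, 0.5, 1.0]:
    W = laplacian_smooth(U, A, eta)
    spread = np.linalg.norm(W - W.mean(axis=0))  # distance from the cluster mean
    print(f"eta={eta:<4}  spread={spread:.3f}")
```

Consider this fully connected 3-client cluster. Each client's deviation from the cluster mean shrinks by exactly 1/(1 + 3η), while the mean itself stays unchanged. At η = 0 the sketch reproduces the local models. Larger η pulls every client toward one shared model.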

In [None]:
# Lambda sweep across 5 values
# Note: this runs 5 experiments × 30 rounds = 150 rounds in total.
# At ~2-3 hours per 15 rounds that is roughly 20-30 hours, far longer than a
# typical Colab session: run it across multiple sessions or reduce --rounds

!python experiments/atlas_integrated.py \
    --mode full \
    --rounds 30 \
    --lambda-sweep \
    --samples 5000 \
    --local-epochs 3

print("\n✓ Lambda sweep complete")
print("✓ Results saved: results/lambda_sweep_full_atlas.json")

---

## 🎨 Experiment 4: Different Models

Test ATLAS with different backbone models.
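
The three cells below differ only in `--model`. The sketch below builds the same invocations from one set of shared flags, so backbone comparisons cannot drift apart. `build_command` and `COMMON_FLAGS` are hypothetical names. The flags are limited to those already used in this notebook.

```python
import shlex

# Flags shared by every backbone run, copied from the cells in this section,
# plus --clients-per-task 3 from Experiment 1 so all backbones see 9 clients
COMMON_FLAGS = {
    "--mode": "full",
    "--rounds": "30",
    "--max-rounds": "15",
    "--ablation": "atlas",
    "--clients-per-task": "3",
    "--samples": "5000",
    "--local-epochs": "3",
}
TASKS = ("sst2", "mrpc", "cola")


def build_command(model, tasks=TASKS, extra=None):
    """Return the argv list for one run of experiments/atlas_integrated.py.

    `extra` overrides COMMON_FLAGS; a value of None drops that flag entirely.
    """
    argv = ["python", "experiments/atlas_integrated.py",
            "--model", model, "--tasks", *tasks]
    for flag, value in {**COMMON_FLAGS, **(extra or {})}.items():
        if value is not None:
            argv += [flag, value]
    return argv


for model in ("bert-base-uncased", "roberta-base", "gpt2"):
    print(shlex.join(build_command(model)))
```

Each list can be passed to `subprocess.run(argv, check=True)`. For Session 2, `build_command(model, extra={'--max-rounds': None, '--resume': 'checkpoints/atlas_round_15.pkl'})` drops `--max-rounds` and adds the resume path.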

### 4.1: BERT-base (110M params)

In [None]:
# BERT-base (~110M parameters vs. ~66M for DistilBERT)
!python experiments/atlas_integrated.py \
    --mode full \
    --rounds 30 \
    --max-rounds 15 \
    --ablation atlas \
    --model bert-base-uncased \
    --tasks sst2 mrpc cola \
    --clients-per-task 3 \
    --samples 5000 \
    --local-epochs 3

print("\n✓ BERT-base Session 1 complete")
print("✓ Continue with: --resume checkpoints/atlas_round_15.pkl --model bert-base-uncased")

### 4.2: RoBERTa-base

In [None]:
# RoBERTa-base (often outperforms BERT-base on GLUE tasks)
!python experiments/atlas_integrated.py \
    --mode full \
    --rounds 30 \
    --max-rounds 15 \
    --ablation atlas \
    --model roberta-base \
    --tasks sst2 mrpc cola \
    --clients-per-task 3 \
    --samples 5000 \
    --local-epochs 3

### 4.3: GPT-2

In [None]:
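# GPT-2 (decoder-only; used here as a sequence classifier)
# GPT-2's tokenizer has no pad token by default. If the script does not set one,
# batching with padding fails (the usual fix is tokenizer.pad_token = tokenizer.eos_token)
!python experiments/atlas_integrated.py \
    --mode full \
    --rounds 30 \
    --max-rounds 15 \
    --ablation atlas \
    --model gpt2 \
    --tasks sst2 mrpc cola \
    --clients-per-task 3 \
    --samples 5000 \
    --local-epochs 3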

---

## 📈 Results Analysis & Visualization

In [None]:
import json
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

# Load all experiment results
results_dir = Path('results')
experiments = {}

# Define experiment files
result_files = {
    'ATLAS (DistilBERT)': 'atlas_integrated_full_atlas.json',
    'FedAvg Cluster': 'atlas_integrated_full_fedavg_cluster.json',
    'Local Only': 'atlas_integrated_full_local_only.json',
}

print("=" * 80)
print("üìä LOADING EXPERIMENTAL RESULTS")
print("=" * 80)

for name, filename in result_files.items():
    filepath = results_dir / filename
    if filepath.exists():
        with open(filepath, 'r') as f:
            experiments[name] = json.load(f)
        print(f"‚úì Loaded: {name}")
    else:
        print(f"‚ö†Ô∏è  Missing: {name} ({filename})")

if not experiments:
    print("\n‚ùå No results found! Run experiments first.")
else:
    print(f"\n‚úì Loaded {len(experiments)} experiments")

In [None]:
# Create comprehensive comparison table
comparison_data = []

for exp_name, results in experiments.items():
    final_accs = results.get('final_accuracies', {})
    round_metrics = results.get('round_metrics', [])
    
    if not final_accs or not round_metrics:
        continue
    
    # Calculate metrics
    client_accs = list(final_accs.values())
    avg_acc = np.mean(client_accs)
    std_acc = np.std(client_accs)
    min_acc = min(client_accs)
    max_acc = max(client_accs)
    
    # Communication cost: per-client dicts are summed, scalar totals used directly
    total_comm = 0
    for rm in round_metrics:
        for field in ('comm_upload_bytes', 'comm_download_bytes'):
            value = rm.get(field, 0)
            total_comm += sum(value.values()) if isinstance(value, dict) else value
    total_comm_mb = total_comm / (1024**2)
    
    # Time
    total_time_min = sum(rm.get('time_seconds', 0) for rm in round_metrics) / 60
    
    comparison_data.append({
        'Experiment': exp_name,
        'Rounds': len(round_metrics),
        'Avg Accuracy': f'{avg_acc:.4f}',
        'Std Dev': f'{std_acc:.4f}',
        'Min Acc': f'{min_acc:.4f}',
        'Max Acc': f'{max_acc:.4f}',
        'Comm (MB)': f'{total_comm_mb:.1f}',
        'Time (min)': f'{total_time_min:.1f}'
    })

df = pd.DataFrame(comparison_data)
print("\n" + "=" * 80)
print("üìä COMPREHENSIVE COMPARISON TABLE")
print("=" * 80)
print(df.to_string(index=False))

# Save to CSV
csv_path = results_dir / 'publication_comparison.csv'
df.to_csv(csv_path, index=False)
print(f"\nüíæ Saved to: {csv_path}")

In [None]:
# Publication-quality plots
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# 1. Convergence curves
for exp_name, results in experiments.items():
    rounds = [rm['round'] for rm in results['round_metrics']]
    accs = [rm['avg_accuracy'] for rm in results['round_metrics']]
    axes[0, 0].plot(rounds, accs, 'o-', label=exp_name, linewidth=2.5, markersize=6)

axes[0, 0].set_xlabel('Round', fontsize=14, fontweight='bold')
axes[0, 0].set_ylabel('Average Accuracy', fontsize=14, fontweight='bold')
axes[0, 0].set_title('Convergence Comparison (30 Rounds)', fontsize=16, fontweight='bold')
axes[0, 0].legend(fontsize=11, loc='lower right')
axes[0, 0].grid(True, alpha=0.3, linestyle='--')
axes[0, 0].set_ylim([0.5, 1.0])

# 2. Final accuracy with error bars
exp_names = []
means = []
stds = []

for exp_name, results in experiments.items():
    accs = list(results['final_accuracies'].values())
    exp_names.append(exp_name)
    means.append(np.mean(accs))
    stds.append(np.std(accs))

bars = axes[0, 1].bar(range(len(exp_names)), means, 
                       color=['#2ecc71', '#3498db', '#e74c3c'][:len(exp_names)],
                       edgecolor='black', linewidth=1.5)
axes[0, 1].errorbar(range(len(exp_names)), means, yerr=stds, 
                    fmt='none', color='black', capsize=8, capthick=2)
axes[0, 1].set_xticks(range(len(exp_names)))
axes[0, 1].set_xticklabels(exp_names, rotation=15, ha='right')
axes[0, 1].set_ylabel('Average Accuracy', fontsize=14, fontweight='bold')
axes[0, 1].set_title('Final Accuracy (Mean ± Std)', fontsize=16, fontweight='bold')
axes[0, 1].grid(axis='y', alpha=0.3, linestyle='--')
axes[0, 1].set_ylim([0.5, 1.0])

# Add value labels
for i, (bar, mean) in enumerate(zip(bars, means)):
    axes[0, 1].text(bar.get_x() + bar.get_width()/2, bar.get_height() + stds[i] + 0.02,
                    f'{mean:.3f}', ha='center', va='bottom', fontsize=11, fontweight='bold')

# 3. Per-client accuracy (personalization)
for exp_name, results in experiments.items():
    client_ids = sorted(results['final_accuracies'].keys(), key=lambda x: int(x))
    accs = [results['final_accuracies'][cid] for cid in client_ids]
    axes[1, 0].plot(range(len(client_ids)), accs, 'o-', label=exp_name, 
                    linewidth=2.5, markersize=7)

axes[1, 0].set_xlabel('Client ID', fontsize=14, fontweight='bold')
axes[1, 0].set_ylabel('Accuracy', fontsize=14, fontweight='bold')
axes[1, 0].set_title('Per-Client Personalization', fontsize=16, fontweight='bold')
axes[1, 0].legend(fontsize=11)
axes[1, 0].grid(True, alpha=0.3, linestyle='--')
axes[1, 0].set_ylim([0.5, 1.0])

# 4. Communication cost (upload + download; per-client dicts are summed)
comm_costs = []
for exp_name, results in experiments.items():
    total = 0
    for rm in results['round_metrics']:
        for field in ('comm_upload_bytes', 'comm_download_bytes'):
            value = rm.get(field, 0)
            total += sum(value.values()) if isinstance(value, dict) else value
    comm_costs.append(total / (1024**2))

axes[1, 1].bar(range(len(exp_names)), comm_costs,
               color=['#2ecc71', '#3498db', '#e74c3c'][:len(exp_names)],
               edgecolor='black', linewidth=1.5)
axes[1, 1].set_xticks(range(len(exp_names)))
axes[1, 1].set_xticklabels(exp_names, rotation=15, ha='right')
axes[1, 1].set_ylabel('Total Communication (MB)', fontsize=14, fontweight='bold')
axes[1, 1].set_title('Communication Overhead', fontsize=16, fontweight='bold')
axes[1, 1].grid(axis='y', alpha=0.3, linestyle='--')

plt.tight_layout()
plot_path = results_dir / 'publication_results.png'
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"\n✓ High-resolution plot saved: {plot_path}")
print("  (300 DPI - suitable for IEEE publications)")

---

## 📦 Download Results for Publication

Package all results and figures for offline analysis.

In [None]:
# Create publication package
!zip -r atlas_publication_results.zip results/ figures/ checkpoints/ \
    -x "*.pyc" "*__pycache__*"

print("✓ Results packaged: atlas_publication_results.zip")
print("\nContents:")
print("  - results/*.json (all experimental data)")
print("  - results/publication_comparison.csv")
print("  - results/publication_results.png (300 DPI)")
print("  - checkpoints/*.pkl (for resuming)")

# Download (in Colab)
from google.colab import files
files.download('atlas_publication_results.zip')

---

## 🎓 Citation & IEEE Formatting

**Suggested IEEE Paper Structure**:

1. **Abstract**: Multi-task FL with heterogeneous devices + LoRA + Laplacian regularization
2. **Introduction**: Challenges of FL for LLMs on edge devices
3. **Related Work**: FedAvg, LoRA, Split Learning, MIRA, HSplitLoRA
4. **Methodology**:
   - Phase 1: Gradient-based clustering
   - Phase 2: Heterogeneous rank allocation
   - Phase 3: Split federated learning
   - Phase 4: Graph-based personalization
5. **Experiments**: 
   - Setup: 9 clients (3 per task), 3 tasks, 30 rounds, 5000 samples per client, 3 local epochs
   - Baselines: Local Only, FedAvg per Cluster
   - Results: Convergence, accuracy, communication, ablations
6. **Results & Discussion**: Present the comparison table and plots generated above
7. **Conclusion**: ATLAS enables personalized LLM fine-tuning on heterogeneous edge devices

**Key Metrics for IEEE Paper**:
- Convergence rate (rounds to 90% of final accuracy)
- Final accuracy (mean ± std across clients)
- Communication cost (MB per round, total MB)
- Personalization quality (variance, per-task accuracy)
- Ablation study results (with/without each phase)
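
The convergence-rate and per-round communication metrics can be computed from the `round_metrics` already loaded in the analysis cells. Below is a minimal sketch. It assumes each entry carries `round`, `avg_accuracy` and the `comm_*_bytes` fields used above, and it takes "final accuracy" to mean the last round's `avg_accuracy`. Both helpers are hypothetical.

```python
def rounds_to_fraction(round_metrics, frac=0.9, key="avg_accuracy"):
    """First round whose accuracy reaches `frac` times the final round's accuracy.

    Returns the final round as a fallback.
    """
    target = frac * round_metrics[-1][key]
    for rm in round_metrics:
        if rm[key] >= target:
            return rm["round"]
    return round_metrics[-1]["round"]


def comm_mb_per_round(round_metrics):
    """Upload + download traffic per round in MB (1 MB = 1024**2 bytes)."""
    per_round = []
    for rm in round_metrics:
        total = 0
        for field in ("comm_upload_bytes", "comm_download_bytes"):
            value = rm.get(field, 0)
            # Per-client dicts are summed; scalar totals are used directly
            total += sum(value.values()) if isinstance(value, dict) else value
        per_round.append(total / 1024**2)
    return per_round
```

For example, `rounds_to_fraction(experiments['ATLAS (DistilBERT)']['round_metrics'])` gives the round at which ATLAS first reaches 90% of its final accuracy.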

---

## 📝 Session Management Cheat Sheet

### Start new experiment:
```bash
python experiments/atlas_integrated.py --mode full --rounds 30 --max-rounds 15
```

### Resume from checkpoint:
```bash
python experiments/atlas_integrated.py --mode full --rounds 30 \
    --resume checkpoints/atlas_round_15.pkl
```

### Change model:
```bash
python experiments/atlas_integrated.py --model bert-base-uncased
```

### Change tasks:
```bash
python experiments/atlas_integrated.py --tasks sst2 mrpc qnli mnli
```

### Override parameters:
```bash
python experiments/atlas_integrated.py --samples 3000 --local-epochs 2
```

### Check available checkpoints:
```bash
ls -lh checkpoints/
```
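
`ls` compares names as text, so `atlas_round_5.pkl` sorts after `atlas_round_15.pkl`. The small helper below picks the checkpoint with the highest round number instead. `latest_checkpoint` is hypothetical and assumes the `atlas_round_<N>.pkl` naming shown above.

```python
import re
from pathlib import Path

# Matches the checkpoint naming used in this notebook, e.g. atlas_round_15.pkl
CKPT_PATTERN = re.compile(r"atlas_round_(\d+)\.pkl")


def latest_checkpoint(ckpt_dir="checkpoints"):
    """Return the checkpoint with the highest round number, or None if there is none."""
    ckpt_dir = Path(ckpt_dir)
    if not ckpt_dir.is_dir():
        return None
    best_round, best_path = -1, None
    for path in ckpt_dir.glob("*.pkl"):
        match = CKPT_PATTERN.fullmatch(path.name)
        if match and int(match.group(1)) > best_round:
            best_round, best_path = int(match.group(1)), path
    return best_path


ckpt = latest_checkpoint()
print(f"--resume {ckpt}" if ckpt else "No checkpoint found; start a new run")
```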