# ATLAS: Publication-Quality Experiments for IEEE

**Multi-Task Federated Learning with Heterogeneous Devices**  
**Statistical Rigor + Architecture Improvements**

---

## ⚠️ **CRITICAL ISSUES IDENTIFIED**

### **Problem 1: Low Overall Improvement (~6 points)**
- ATLAS: 0.7398 → 0.8039 (+6.4 accuracy points in 10 rounds)
- FedAvg: 0.7389 → 0.8036 (+6.5 accuracy points in 10 rounds)
- **Both plateau early at ~0.805**

**Possible Causes**:
1. Initial model already ~74% accurate (high baseline)
2. Learning rate too low (2e-5 is very conservative)
3. Only 3 local epochs (insufficient local optimization)
4. Small tasks (MRPC=3668 samples, CoLA=8516 samples)
5. LoRA rank too low (limiting model capacity)

---

### **Problem 2: ATLAS Time = FedAvg Time (NO SPEEDUP!)**
**Expected**: 99% parameter reduction → faster training  
**Reality**: Both take ~2.5 hours for 15 rounds

**Why?**
- ✅ ATLAS reduces **communication** (upload/download bytes) by ~97% (245 MB vs 8,940 MB)
- ❌ ATLAS does NOT reduce **computation** (forward/backward passes)
- **Bottleneck = Local training time, NOT communication**

**GPU T4 Reality**:
- Forward/backward passes: ~580 s (~9.7 min) per round across all clients
- Communication: ~10-15 seconds per round (negligible on GPU)
- **Computation >> Communication** (roughly 40-60:1)

**ATLAS benefit only matters when**:
- Slow networks (mobile, edge devices)
- Many clients (communication becomes bottleneck)
- CPU-only devices (where communication is more expensive)

**For the paper**: emphasize **communication efficiency**, not training speed.
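
To make the bottleneck argument concrete, the sketch below estimates per-round wall time as compute time plus transfer time at a given link speed. It reuses figures from this notebook (245 MB and 8,940 MB in total over 15 rounds, roughly 580 s of local training per round). The function name `round_time_s` is hypothetical, and the model assumes a single link running at its nominal bandwidth with no overlap between transfer and compute, which is a simplification rather than a measurement.

```python
def round_time_s(payload_mb: float, bandwidth_mbps: float, compute_s: float = 580.0) -> float:
    """Compute time plus the time to move payload_mb megabytes at bandwidth_mbps megabits/s."""
    transfer_s = payload_mb * 8.0 / bandwidth_mbps  # MB -> megabits, then divide by Mbit/s
    return compute_s + transfer_s


# Per-round payloads derived from the 15-round totals reported in this notebook
ATLAS_MB = 245 / 15
FEDAVG_MB = 8940 / 15

for mbps in (1000, 100, 10, 1):
    atlas = round_time_s(ATLAS_MB, mbps)
    fedavg = round_time_s(FEDAVG_MB, mbps)
    print(f"{mbps:>5} Mbit/s: ATLAS {atlas:8.0f} s, FedAvg {fedavg:8.0f} s, "
          f"speedup {fedavg / atlas:4.1f}x")
```

Under these assumptions the two methods are nearly indistinguishable on a fast link and only diverge as the link slows, which is the regime the paper should target.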

---

### **Problem 3: Time Constraints** (Realistic Planning)
- 1 run = ~2.5 hours (15 rounds, DistilBERT, 5000 samples)
- 3 seeds × 3 configs = **9 runs = ~22.5 hours**
- Lambda sweep (5 values) = **12.5 hours**
- Model comparison (BERT/RoBERTa) = **5+ hours each**

**Total realistic workload**: ~40-50 hours compute time
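
The budget above is simple arithmetic, and a small tally makes it easy to re-check when a run count changes. This sketch uses the 2.5 h per run figure from this plan and counts two models for the comparison entry (BERT and RoBERTa at about 5 h each); the dictionary labels are only for display.

```python
HOURS_PER_RUN = 2.5  # one 15-round DistilBERT run on a T4, per the estimate above

workload_hours = {
    "multi-seed (3 seeds x 3 configs)": 3 * 3 * HOURS_PER_RUN,
    "lambda sweep (5 values)": 5 * HOURS_PER_RUN,
    "model comparison (2 models, ~5 h each)": 2 * 5.0,
}

total = sum(workload_hours.values())
for name, hours in workload_hours.items():
    print(f"{name:<42} {hours:5.1f} h")
print(f"{'total':<42} {total:5.1f} h")
```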

---

## 🎯 **REVISED PRIORITY EXPERIMENTS** (Realistic)

### **TIMING: Each run = 2.5 hours on T4 GPU**

---

### **1. Multi-Seed Statistical Experiments** (22.5 hrs)
```bash
python experiments/run_statistical_experiments.py --seeds 3 --rounds 15
```
**Time**: 9 runs × 2.5h = 22.5 hours (run overnight + next day)  
**Output**: Mean±std, t-tests, p-values (3 seeds is minimal for stats)

---

### **2. Hyperparameter Tuning - Fix Low Improvement** (7.5 hrs)
Test configurations that may break the 0.806 ceiling:

```bash
# Higher learning rate (2.5x increase over 2e-5)
python experiments/atlas_integrated.py --ablation atlas --lr 5e-5 --rounds 15 --seed 42

# More local epochs (67% increase)
python experiments/atlas_integrated.py --ablation atlas --local-epochs 5 --rounds 15 --seed 42

# Both combined
python experiments/atlas_integrated.py --ablation atlas --lr 5e-5 --local-epochs 5 --rounds 15 --seed 42
```
**Time**: 3 runs × 2.5h = 7.5 hours  
**Goal**: Push past 0.806 plateau

---

### **3. Lambda Sweep - Regularization Impact** (12.5 hrs)
```bash
for eta in 0.0 0.01 0.05 0.1 0.2; do
    python experiments/atlas_integrated.py --ablation atlas --eta $eta --rounds 15 --seed 42
done
```
**Time**: 5 runs × 2.5h = 12.5 hours  
**Goal**: Find optimal λ (set with the `--eta` flag; the current 0.1 may be too strong)

---

### **4. Architecture Improvements** (per experiment: 2.5-3 hrs)

**A. Higher LoRA Ranks** (may increase capacity):
```bash
# Default uses ranks [4,8,16,32,64] - try doubling
python experiments/atlas_integrated.py --ablation atlas --rank-multiplier 2 --rounds 15 --seed 42
```
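
For a sense of what doubling the rank costs, recall that a LoRA adapter on a `d_in x d_out` weight adds `r * (d_in + d_out)` trainable parameters. The sketch below applies that formula to the default ranks listed above. The hidden size of 768 matches DistilBERT; which matrices ATLAS actually adapts is not stated here, so adapting the query and value projections in each of the six layers (12 matrices) is an assumption, as is the helper name `lora_param_count`.

```python
def lora_param_count(rank: int, d_in: int = 768, d_out: int = 768, n_matrices: int = 12) -> int:
    """Trainable parameters added by LoRA: A is (rank x d_in), B is (d_out x rank)."""
    return rank * (d_in + d_out) * n_matrices


for rank in (4, 8, 16, 32, 64):
    base = lora_param_count(rank)
    doubled = lora_param_count(2 * rank)
    print(f"rank {rank:>2}: {base:>9,} params -> rank {2 * rank:>3}: {doubled:>9,} params")
```

Adapter size grows linearly with rank, so if only adapters are exchanged, `--rank-multiplier 2` should roughly double the per-round upload as well as the capacity.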

**B. More Tasks** (reduce overfitting through diversity):
```bash
python experiments/atlas_integrated.py --tasks sst2 mrpc cola qnli qqp --rounds 15 --seed 42
```

**C. Different Model** (BERT-base has more capacity):
```bash
python experiments/run_statistical_experiments.py --seeds 3 --model bert-base-uncased --rounds 12
```
**Note**: BERT-base will take ~3-3.5h per run (larger model)

---

## 🔧 **Addressing the Core Problems**

### **Fix 1: Boost Initial Learning** 
Current settings are too conservative:
- LR=2e-5 (standard BERT fine-tuning, but may be slow for FL)
- Local epochs=3 (may need 5-7 for proper local convergence)
- Batch size=16 (small, increases noise)

**Try**:
```python
# In experiments/atlas_integrated.py config
learning_rate = 5e-5  # 2.5x increase
local_epochs = 5      # 67% increase  
batch_size = 32       # 2x increase (if memory allows)
```

---

### **Fix 2: Communication vs Computation Metrics**

**For paper, report BOTH**:

| Metric | ATLAS | FedAvg | Improvement |
|--------|-------|--------|-------------|
| **Communication** | 245 MB | 8,940 MB | **97.3% reduction** ✓ |
| **Training Time** | 2.5 hrs | 2.5 hrs | **0% reduction** (expected) |
| **Accuracy** | 0.8062 | 0.8054 | **+0.1%** (marginal) |

**Key message**: ATLAS is for **bandwidth-constrained** scenarios (mobile, edge), NOT for speeding up GPU training.

---

### **Fix 3: Better Baseline Comparison**

Your baseline (74% initial) is already quite good. Try:
1. Start from **random classifier** (not pretrained head)
2. Use **harder tasks** (MNLI, QQP are more challenging)
3. Use **less data per client** (2000 samples instead of 5000)

This will show larger improvements (e.g., 50% → 80% is a 30-point gain, whereas 74% → 80% is only a 6-point gain).
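
Note that the gains quoted here are differences in accuracy (percentage points), not relative improvements. A small helper keeps the two apart when reporting; the name `gain_summary` is hypothetical and the inputs are the example figures from the sentence above.

```python
def gain_summary(start: float, end: float) -> tuple[float, float]:
    """Return (absolute gain in percentage points, relative gain in percent)."""
    points = (end - start) * 100.0
    relative = (end - start) / start * 100.0
    return points, relative


for start, end in [(0.50, 0.80), (0.74, 0.80)]:
    points, relative = gain_summary(start, end)
    print(f"{start:.2f} -> {end:.2f}: +{points:.1f} points, +{relative:.1f}% relative")
```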

---

## 📊 **Realistic 1-Week Schedule**

**Monday-Tuesday** (22.5 hrs): Multi-seed experiments (3 seeds × 3 configs)
```bash
python experiments/run_statistical_experiments.py --seeds 3 --rounds 15
```

**Wednesday** (7.5 hrs): Hyperparameter search
```bash
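python experiments/atlas_integrated.py --ablation atlas --lr 5e-5 --rounds 15 --seed 42
python experiments/atlas_integrated.py --ablation atlas --local-epochs 5 --rounds 15 --seed 42
python experiments/atlas_integrated.py --ablation atlas --lr 5e-5 --local-epochs 5 --rounds 15 --seed 42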
```

**Thursday** (12.5 hrs): Lambda sweep
```bash
for eta in 0.0 0.01 0.05 0.1 0.2; do python experiments/atlas_integrated.py --ablation atlas --eta "$eta" --rounds 15 --seed 42; done
```

**Friday** (7.5 hrs): Architecture tests
```bash
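python experiments/atlas_integrated.py --tasks sst2 mrpc cola qnli qqp --rounds 15 --seed 42
python experiments/atlas_integrated.py --ablation atlas --rank-multiplier 2 --rounds 15 --seed 42
python experiments/atlas_integrated.py --ablation atlas --model bert-base-uncased --rounds 12 --seed 42  # single-seed test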
```

**Total**: ~50 hours compute (feasible in 1 week with overnight runs)

---

## 💡 **Key Insights for Paper**

### **What ATLAS Actually Achieves**:
✅ 97% communication reduction (245 MB vs 8.9 GB)  
✅ Enables heterogeneous devices (2GB to 16GB)  
✅ Maintains accuracy parity with FedAvg  
✅ Better personalization (lower std dev across clients)

### **What ATLAS Does NOT Achieve**:
❌ Training time reduction (computation-bound, not communication-bound)  
❌ Large accuracy gains over FedAvg (+0.1% only)  
❌ Breaking through performance plateaus

### **Research Contributions**:
1. **Novel**: LoRA + Split Learning + Laplacian regularization 
2. **Practical**: Works on 2GB devices (previously impossible)
3. **Efficient**: 97% less data transfer (critical for mobile/edge)
4. **Fair**: Better client-level personalization

**Honest framing**: ATLAS is about **enabling** FL on resource-constrained devices, not about beating centralized/high-resource FL.

---

## ⚡ **START HERE** (Most Critical)

```bash
# Run this overnight (22.5 hours)
python experiments/run_statistical_experiments.py --seeds 3 --rounds 15
```

While it runs, prepare hyperparameter experiments for tomorrow.

---

**Updated timing**: All estimates now realistic for T4 GPU (2.5h per 15-round run)

## Setup

In [None]:
# Install dependencies
!pip install -q torch transformers datasets peft scikit-learn scipy numpy pandas matplotlib

# Check GPU
import torch
print(f"GPU Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

GPU Available: True
GPU: Tesla T4
VRAM: 14.7 GB


In [None]:
# Clone or update repository
import os
from pathlib import Path

if Path('ATLAS').exists():
    %cd ATLAS
    !git pull origin main
    print("Repository updated")
else:
    !git clone https://github.com/mahmoudmayaleh/ATLAS.git
    %cd ATLAS
    print("Repository cloned")

Cloning into 'ATLAS'...
remote: Enumerating objects: 481, done.
remote: Counting objects: 100% (45/45), done.
remote: Compressing objects: 100% (32/32), done.
remote: Total 481 (delta 17), reused 38 (delta 12), pack-reused 436 (from 1)
Receiving objects: 100% (481/481), 49.63 MiB | 17.60 MiB/s, done.
Resolving deltas: 100% (253/253), done.
/content/ATLAS
Repository cloned


In [None]:
# Mount Google Drive for automatic checkpoint backup
from google.colab import drive
drive.mount('/content/drive')

# Create backup directory
import os
import shutil
from pathlib import Path

backup_dir = '/content/drive/MyDrive/ATLAS_Checkpoints'
os.makedirs(backup_dir, exist_ok=True)
print(f"[OK] Checkpoints will be backed up to: {backup_dir}")

# Auto-backup helper functions
def backup_to_drive(source_dir='results', backup_subdir='ATLAS_Results'):
    """Automatically backup results and checkpoints to Google Drive"""
    drive_backup = f'/content/drive/MyDrive/{backup_subdir}'
    os.makedirs(drive_backup, exist_ok=True)
    
    # Backup results
    if Path(source_dir).exists():
        for item in Path(source_dir).glob('*'):
            if item.is_file():
                shutil.copy2(str(item), f"{drive_backup}/{item.name}")
        print(f"[BACKUP] Results -> Drive/{backup_subdir}")
    
    # Backup checkpoints (only the latest one)
    if Path('checkpoints').exists():
        checkpoints = sorted(Path('checkpoints').glob('*.pkl'), key=lambda x: x.stat().st_mtime)
        if checkpoints:
            latest = checkpoints[-1]  # Only backup the most recent checkpoint
            shutil.copy2(str(latest), f"{backup_dir}/{latest.name}")
            print(f"[BACKUP] Latest checkpoint: {latest.name} -> Drive")
    
    print(f"[OK] Auto-backup complete!")

print("[OK] Auto-backup functions loaded")


Mounted at /content/drive
[OK] Checkpoints will be backed up to: /content/drive/MyDrive/ATLAS_Checkpoints


---

## Experiment 1: ATLAS Full Pipeline (30 Rounds)

**Session-Based Training**:
- Session 1: Rounds 1-15 (~2-3 hours)
- Session 2: Rounds 16-30 (~2-3 hours)

**Note**: Only the final checkpoint is saved. All results and checkpoints are automatically backed up to Google Drive after each run.
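
Because a new Colab session starts with an empty disk, the second session needs the backed-up checkpoint copied back before resuming. The helper below is a sketch of that step: it mirrors the backup logic by picking the most recently modified `.pkl` in the Drive folder and copying it into `checkpoints/`. The function name `restore_latest_checkpoint` is hypothetical, and it assumes that `--resume` looks for checkpoints in the local `checkpoints/` directory, which should be checked against the script.

```python
import shutil
from pathlib import Path
from typing import Optional


def restore_latest_checkpoint(drive_dir: str = '/content/drive/MyDrive/ATLAS_Checkpoints',
                              local_dir: str = 'checkpoints') -> Optional[Path]:
    """Copy the newest *.pkl from drive_dir into local_dir; return its new path, or None."""
    candidates = sorted(Path(drive_dir).glob('*.pkl'), key=lambda p: p.stat().st_mtime)
    if not candidates:
        print(f"[RESTORE] No checkpoint found in {drive_dir}")
        return None
    latest = candidates[-1]  # most recently modified checkpoint
    Path(local_dir).mkdir(parents=True, exist_ok=True)
    target = Path(local_dir) / latest.name
    shutil.copy2(latest, target)
    print(f"[RESTORE] {latest.name} -> {target}")
    return target
```

Call it once at the top of Session 2, after mounting Drive and before launching the resumed run.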

### Session 1: Rounds 1-15

In [None]:
# QUICK TEST: Try breaking the 0.806 plateau with better hyperparameters
# Session 1: Higher LR + More local epochs (2.5 hours)

!python experiments/atlas_integrated.py \
    --mode full \
    --rounds 15 \
    --ablation atlas \
    --model distilbert-base-uncased \
    --tasks sst2 mrpc cola \
    --clients-per-task 3 \
    --samples 5000 \
    --local-epochs 5 \
    --lr 5e-5 \
    --seed 42

print("\n" + "=" * 80)
print("QUICK TEST COMPLETE")
print("=" * 80)
print("Compare with baseline:")
print("  Baseline (lr=2e-5, epochs=3): 0.8062 at round 15")
print("  New (lr=5e-5, epochs=5):      [check above]")
print("\nIf > 0.815: Breakthrough! Continue with these settings")
print("If 0.806-0.815: Modest gain, try more aggressive settings")
print("If < 0.806: Too aggressive, dial back")
print("=" * 80)

# Auto-backup to Google Drive
backup_to_drive()
print("\n[OK] Results automatically backed up to Drive")

[MODE] Full experiment (2-4 hours per run on T4 GPU)
         For 30+ rounds, split into sessions: 15+15 with --resume
[SESSION] Limiting this session to 15 rounds (use --resume to continue)
config.json: 100% 483/483 [00:00<00:00, 2.17MB/s]
tokenizer_config.json: 100% 48.0/48.0 [00:00<00:00, 235kB/s]
vocab.txt: 100% 232k/232k [00:00<00:00, 1.97MB/s]
tokenizer.json: 100% 466k/466k [00:00<00:00, 16.7MB/s]

[SETUP] Creating multi-task federated learning setup...
  Loading task: sst2
  [CLEAN] Loading pre-cleaned sst2 from disk
Map: 100% 66978/66978 [00:08<00:00, 7975.44 examples/s]
Map: 100% 872/872 [00:00<00:00, 7061.31 examples/s]
    Client 0: sst2, cpu_2gb, 5000 samples
    Client 1: sst2, cpu_2gb, 5000 samples
    Client 2: sst2, tablet_4gb, 5000 samples
  Loading task: mrpc
  [CLEAN] Loading pre-cleaned mrpc from disk
Map: 100% 3668/3668 [00:01<00:00, 2657.43 examples/s]
Map: 100% 408/408 [00:00<00:00, 4530.64 examples/s]
    Client 3: mrpc, tablet_4gb, 1222 samples
    Client 4: mr

---

## 🧪 Single-Seed Hyperparameter Tests (One per Colab Session)

**Strategy**: Test each variation individually (2.5h each) before committing to multi-seed runs

**Baseline Performance**: 
- lr=2e-5, epochs=3 → **0.8062** at round 15
- Plateau at round 10 (0.8039)

**Goal**: Break through 0.815+ to justify multi-seed runs
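
Each test below runs the same script with `--ablation atlas`, and the comparison cell later in this notebook reads `results/atlas_integrated_full_atlas.json` for every test, so a later run would overwrite an earlier one's results. One way to keep them apart is to copy the file to a per-test name right after each run. The sketch below does that; the helper name `archive_result`, the `results/by_test/` folder, and the tag strings are all hypothetical, and the source filename is assumed from the comparison cell rather than confirmed from the script.

```python
import shutil
from pathlib import Path
from typing import Optional


def archive_result(tag: str,
                   source: str = 'results/atlas_integrated_full_atlas.json',
                   dest_dir: str = 'results/by_test') -> Optional[Path]:
    """Copy the latest result file to dest_dir/<tag>.json so the next run cannot overwrite it."""
    src = Path(source)
    if not src.exists():
        print(f"[ARCHIVE] {src} not found; nothing copied")
        return None
    dest = Path(dest_dir) / f"{tag}.json"
    dest.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(src, dest)
    print(f"[ARCHIVE] {src} -> {dest}")
    return dest


# Example: call once after each test cell finishes, e.g.
# archive_result('test1_lr5e-5')
```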

### Test 1: Higher Learning Rate Only (Session 1)

In [None]:
# Test: lr=5e-5 (2.5x increase), keep epochs=3
!python experiments/atlas_integrated.py \
    --mode full \
    --rounds 15 \
    --ablation atlas \
    --lr 5e-5 \
    --seed 42

print("\n✓ Test 1 complete: Higher LR (5e-5) only")
backup_to_drive()
print("[OK] Backed up to Drive")

### Test 2: More Local Epochs Only (Session 2)

In [None]:
# Test: local_epochs=5 (67% increase), keep lr=2e-5
!python experiments/atlas_integrated.py \
    --mode full \
    --rounds 15 \
    --ablation atlas \
    --local-epochs 5 \
    --seed 42

print("\n✓ Test 2 complete: More local epochs (5) only")
backup_to_drive()
print("[OK] Backed up to Drive")

### Test 3: Combined (Higher LR + More Epochs) (Session 3)

**Most promising** - combines both improvements

In [None]:
# Test: lr=5e-5 + local_epochs=5 (both improvements)
!python experiments/atlas_integrated.py \
    --mode full \
    --rounds 15 \
    --ablation atlas \
    --lr 5e-5 \
    --local-epochs 5 \
    --seed 42

print("\n✓ Test 3 complete: Combined (lr=5e-5, epochs=5)")
backup_to_drive()
print("[OK] Backed up to Drive")

### Test 4: Weaker Regularization (Session 4)

Current η=0.1 may be too restrictive. Try η=0.01 for less personalization, more convergence.

In [None]:
# Test: eta=0.01 (weaker regularization), keep lr=5e-5, epochs=5
!python experiments/atlas_integrated.py \
    --mode full \
    --rounds 15 \
    --ablation atlas \
    --lr 5e-5 \
    --local-epochs 5 \
    --eta 0.01 \
    --seed 42

print("\n✓ Test 4 complete: Weaker regularization (eta=0.01)")
backup_to_drive()
print("[OK] Backed up to Drive")

---

## 📊 Compare All Single-Seed Tests

After running tests 1-4, use this to compare and decide next steps.

In [None]:
import json
import pandas as pd
from pathlib import Path

# Compare all test results
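results = {
    'Baseline (lr=2e-5, epochs=3)': 0.8062,  # Your existing result
}

# Try to load new test results.
# NOTE: every test appears to write results/atlas_integrated_full_atlas.json, so
# pointing all four entries at that file would report the last run four times.
# Copy the file to a per-test name after each run; the paths below are placeholder
# names for those copies, not files the script creates itself.
test_configs = [
    ('Test 1: lr=5e-5 only', 'results/by_test/test1_lr5e-5.json'),
    ('Test 2: epochs=5 only', 'results/by_test/test2_epochs5.json'),
    ('Test 3: lr=5e-5 + epochs=5', 'results/by_test/test3_lr5e-5_epochs5.json'),
    ('Test 4: lr=5e-5 + epochs=5 + eta=0.01', 'results/by_test/test4_eta0.01.json'),
]

print("=" * 80)
print("HYPERPARAMETER TEST COMPARISON")
print("=" * 80)

for test_name, result_file in test_configs:
    if not Path(result_file).exists():
        print(f"[SKIP] {test_name}: {result_file} not found")
        continue
    with open(result_file, 'r') as f:
        data = json.load(f)
    # Prefer the stored average; otherwise average the per-item accuracies
    final_acc = data.get('final_avg_accuracy')
    if final_acc is None:
        per_item = data.get('final_accuracies', {})
        final_acc = sum(per_item.values()) / len(per_item) if per_item else 0.0
    results[test_name] = final_acc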

# Create comparison table
df = pd.DataFrame(list(results.items()), columns=['Configuration', 'Final Accuracy'])
df['Improvement vs Baseline'] = df['Final Accuracy'] - 0.8062
df['% Gain'] = (df['Improvement vs Baseline'] / 0.8062) * 100

print("\n" + df.to_string(index=False))

# Find best
best_idx = df['Final Accuracy'].idxmax()
best_config = df.iloc[best_idx]

print("\n" + "=" * 80)
print("RECOMMENDATION")
print("=" * 80)

if best_config['Final Accuracy'] > 0.815:
    print(f"✓ BREAKTHROUGH! {best_config['Configuration']}")
    print(f"  Accuracy: {best_config['Final Accuracy']:.4f} (+{best_config['Improvement vs Baseline']:.4f})")
    print(f"\n→ USE THIS CONFIG FOR MULTI-SEED RUNS (3 seeds × 3 configs = 22.5 hours)")
    
elif best_config['Final Accuracy'] > 0.810:
    print(f"✓ Modest Improvement: {best_config['Configuration']}")
    print(f"  Accuracy: {best_config['Final Accuracy']:.4f} (+{best_config['Improvement vs Baseline']:.4f})")
    print(f"\n→ Try one more aggressive setting before committing to multi-seed")
    print(f"  Suggestions: lr=7e-5, epochs=7, or eta=0.005")
    
elif best_config['Final Accuracy'] > 0.8062:
    print(f"✓ Marginal Improvement: {best_config['Configuration']}")
    print(f"  Accuracy: {best_config['Final Accuracy']:.4f} (+{best_config['Improvement vs Baseline']:.4f})")
    print(f"\n→ Gains too small. Try different approach:")
    print(f"  - Add more tasks (QNLI, QQP, MNLI)")
    print(f"  - Use BERT-base (more capacity)")
    print(f"  - Reduce samples to 3000 (less overfitting)")
    
else:
    print(f"✗ Worse than baseline: {best_config['Configuration']}")
    print(f"  Accuracy: {best_config['Final Accuracy']:.4f} ({best_config['Improvement vs Baseline']:.4f})")
    print(f"\n→ Settings too aggressive. Try:")
    print(f"  - lr=3e-5 (between 2e-5 and 5e-5)")
    print(f"  - epochs=4 (between 3 and 5)")

print("\n" + "=" * 80)

---

## 📊 Convergence Analysis - Your Current Results

Visualize the early plateau issue and identify optimal stopping point.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Your actual convergence data
atlas_accs = [0.7398, 0.7673, 0.7740, 0.7811, 0.7894, 0.7932, 0.7923, 
              0.7966, 0.8017, 0.8039, 0.8036, 0.8044, 0.8054, 0.8065, 0.8062]

fedavg_accs = [0.7389, 0.7640, 0.7742, 0.7811, 0.7874, 0.7922, 0.7952,
               0.7981, 0.7978, 0.8036, 0.8053, 0.8027, 0.8041, 0.8050, 0.8054]

rounds = list(range(1, 16))

# Plot convergence
fig, axes = plt.subplots(1, 2, figsize=(16, 5))

# Plot 1: Full convergence curves
axes[0].plot(rounds, atlas_accs, 'o-', linewidth=3, markersize=8, 
             label='ATLAS', color='#2ecc71')
axes[0].plot(rounds, fedavg_accs, 's-', linewidth=3, markersize=8,
             label='FedAvg/Cluster', color='#3498db')

# Mark convergence point
axes[0].axvline(x=10, color='red', linestyle='--', linewidth=2, alpha=0.7, 
                label='Convergence (~R10)')
axes[0].axhspan(0.804, 0.807, alpha=0.2, color='yellow', 
                label='Plateau Zone')

axes[0].set_xlabel('Round', fontsize=14, fontweight='bold')
axes[0].set_ylabel('Average Accuracy', fontsize=14, fontweight='bold')
axes[0].set_title('Convergence Pattern: Early Plateau at Round 10', 
                  fontsize=16, fontweight='bold')
axes[0].legend(fontsize=12, loc='lower right')
axes[0].grid(True, alpha=0.3, linestyle='--')
axes[0].set_ylim([0.73, 0.81])

# Plot 2: Improvement per round (gradient)
atlas_improvements = [0] + [atlas_accs[i] - atlas_accs[i-1] for i in range(1, len(atlas_accs))]
fedavg_improvements = [0] + [fedavg_accs[i] - fedavg_accs[i-1] for i in range(1, len(fedavg_accs))]

axes[1].bar(np.array(rounds)-0.2, atlas_improvements, width=0.4, 
            label='ATLAS', color='#2ecc71', alpha=0.8)
axes[1].bar(np.array(rounds)+0.2, fedavg_improvements, width=0.4,
            label='FedAvg/Cluster', color='#3498db', alpha=0.8)

axes[1].axhline(y=0.001, color='red', linestyle='--', linewidth=2, 
                label='Negligible Gain (<0.1%)')
axes[1].axvline(x=10.5, color='orange', linestyle='--', linewidth=2, alpha=0.7,
                label='Diminishing Returns →')

axes[1].set_xlabel('Round', fontsize=14, fontweight='bold')
axes[1].set_ylabel('Accuracy Gain vs Previous Round', fontsize=14, fontweight='bold')
axes[1].set_title('Per-Round Improvement: Diminishing After R10', 
                  fontsize=16, fontweight='bold')
axes[1].legend(fontsize=11)
axes[1].grid(axis='y', alpha=0.3)
axes[1].set_ylim([-0.005, 0.035])

plt.tight_layout()
plt.savefig('results/convergence_analysis.png', dpi=300, bbox_inches='tight')
plt.show()

# Print analysis
print("=" * 80)
print("CONVERGENCE ANALYSIS")
print("=" * 80)

atlas_gain_1_10 = atlas_accs[9] - atlas_accs[0]
atlas_gain_10_15 = atlas_accs[14] - atlas_accs[9]
fedavg_gain_1_10 = fedavg_accs[9] - fedavg_accs[0]
fedavg_gain_10_15 = fedavg_accs[14] - fedavg_accs[9]

print(f"\nATLAS:")
print(f"  Round 1→10:  +{atlas_gain_1_10:.4f} (+{atlas_gain_1_10*100:.2f}%)")
print(f"  Round 10→15: +{atlas_gain_10_15:.4f} (+{atlas_gain_10_15*100:.2f}%)")
print(f"  → {(atlas_gain_10_15/atlas_gain_1_10)*100:.1f}% as effective after R10")

print(f"\nFedAvg:")
print(f"  Round 1→10:  +{fedavg_gain_1_10:.4f} (+{fedavg_gain_1_10*100:.2f}%)")
print(f"  Round 10→15: +{fedavg_gain_10_15:.4f} (+{fedavg_gain_10_15*100:.2f}%)")
print(f"  → {(fedavg_gain_10_15/fedavg_gain_1_10)*100:.1f}% as effective after R10")

print(f"\nFinal Gap (ATLAS - FedAvg): +{atlas_accs[-1] - fedavg_accs[-1]:.4f} (+{(atlas_accs[-1] - fedavg_accs[-1])*100:.2f}%)")

print("\n" + "=" * 80)
print("CONCLUSION")
print("=" * 80)
print("✓ Both methods converge by Round 10-12")
print("✓ Rounds 11-15 show minimal improvement (<0.3%)")
print("✓ 15 rounds is optimal balance (no overfitting detected yet)")
print("✗ Going to 30 rounds would likely overfit with no gain")
print("\nRECOMMENDATION: Focus on hyperparameter tuning to break 0.806 ceiling!")
print("  - Try eta=0.01-0.05 (weaker regularization)")
print("  - Try lr=5e-5 (faster convergence)")
print("  - Try local_epochs=5 (deeper local updates)")
print("=" * 80)
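
The red line at round 10 in the plot is placed by eye. A simple rule gives the same answer programmatically: take the last round whose gain over the previous round is at least some tolerance. The sketch below applies that rule to the two accuracy series above; the function name `plateau_round` and the 0.002 tolerance are illustrative choices, not part of the ATLAS code.

```python
def plateau_round(accs, tol=0.002):
    """Return the last 1-indexed round whose gain over the previous round is >= tol.

    Returns 1 if no round improves by at least tol.
    """
    plateau = 1
    for rnd, (prev, cur) in enumerate(zip(accs, accs[1:]), start=2):
        if cur - prev >= tol:
            plateau = rnd
    return plateau


atlas_accs = [0.7398, 0.7673, 0.7740, 0.7811, 0.7894, 0.7932, 0.7923,
              0.7966, 0.8017, 0.8039, 0.8036, 0.8044, 0.8054, 0.8065, 0.8062]
fedavg_accs = [0.7389, 0.7640, 0.7742, 0.7811, 0.7874, 0.7922, 0.7952,
               0.7981, 0.7978, 0.8036, 0.8053, 0.8027, 0.8041, 0.8050, 0.8054]

print("ATLAS plateau round: ", plateau_round(atlas_accs))
print("FedAvg plateau round:", plateau_round(fedavg_accs))
```

With a 0.002 tolerance both series return round 10, matching the plot annotation; a tighter tolerance moves the answer later.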

---

## 🔍 Communication vs Computation Analysis

**Why ATLAS saves bandwidth but not time**

In [None]:
# Time breakdown analysis
import matplotlib.pyplot as plt
import numpy as np

# Measured/estimated times per round (in seconds)
# Based on your 2.5 hours / 15 rounds = 600 seconds per round

# Time breakdown for ATLAS
atlas_times = {
    'Local Training\n(Forward/Backward)': 580,  # ~97% of time
    'Aggregation': 5,                           # ~1%
    'Communication\n(Upload/Download)': 10,     # ~2%
    'Laplacian Reg': 5                          # ~1%
}

# Time breakdown for FedAvg
fedavg_times = {
    'Local Training\n(Forward/Backward)': 580,  # Same computation
    'Aggregation': 5,                           
    'Communication\n(Upload/Download)': 15,     # Slightly more (full model)
}

# Communication size (MB)
atlas_comm = 245 / 15  # ~16 MB per round
fedavg_comm = 8940 / 15  # ~596 MB per round

# Create figure
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Plot 1: Time breakdown comparison
methods = ['ATLAS', 'FedAvg']
local_train = [580, 580]
communication = [10, 15]
other = [10, 5]

x = np.arange(len(methods))
width = 0.6

bars1 = axes[0].bar(x, local_train, width, label='Local Training', color='#e74c3c')
bars2 = axes[0].bar(x, communication, width, bottom=local_train, 
                    label='Communication', color='#3498db')
bars3 = axes[0].bar(x, other, width, 
                    bottom=np.array(local_train) + np.array(communication),
                    label='Aggregation/Other', color='#95a5a6')

axes[0].set_ylabel('Time per Round (seconds)', fontsize=14, fontweight='bold')
axes[0].set_title('Time Breakdown: Computation >> Communication', fontsize=16, fontweight='bold')
axes[0].set_xticks(x)
axes[0].set_xticklabels(methods, fontsize=12, fontweight='bold')
axes[0].legend(fontsize=11)
axes[0].set_ylim([0, 650])

# Add percentage labels
for i, (total, comm) in enumerate(zip([600, 600], [10, 15])):
    axes[0].text(i, total + 20, f'{total}s total', ha='center', fontweight='bold', fontsize=11)
    comm_pct = (comm/total)*100
    axes[0].text(i, local_train[i] + comm/2, f'{comm_pct:.1f}%', ha='center', 
                fontsize=10, color='white', fontweight='bold')

# Plot 2: Communication size comparison
bars = axes[1].bar(methods, [atlas_comm, fedavg_comm], 
                   color=['#2ecc71', '#e74c3c'], edgecolor='black', linewidth=2)
axes[1].set_ylabel('Communication per Round (MB)', fontsize=14, fontweight='bold')
axes[1].set_title('Communication Cost: ATLAS 97% Reduction', fontsize=16, fontweight='bold')
axes[1].set_ylim([0, 650])

# Add value labels and reduction
for i, (bar, val) in enumerate(zip(bars, [atlas_comm, fedavg_comm])):
    axes[1].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 20,
                f'{val:.1f} MB', ha='center', va='bottom', fontsize=12, fontweight='bold')

reduction = ((fedavg_comm - atlas_comm) / fedavg_comm) * 100
axes[1].text(0.5, 500, f'↓ {reduction:.1f}% reduction', ha='center', 
            fontsize=14, fontweight='bold', color='#2ecc71',
            bbox=dict(boxstyle='round', facecolor='white', edgecolor='#2ecc71', linewidth=2))

axes[1].grid(axis='y', alpha=0.3)

# Plot 3: Bottleneck analysis
categories = ['Computation\n(GPU bound)', 'Communication\n(Network bound)']
atlas_bottleneck = [97, 2]  # Percentages
fedavg_bottleneck = [97, 2.5]

x = np.arange(len(categories))
width = 0.35

bars1 = axes[2].bar(x - width/2, atlas_bottleneck, width, label='ATLAS', 
                    color='#2ecc71', edgecolor='black', linewidth=1.5)
bars2 = axes[2].bar(x + width/2, fedavg_bottleneck, width, label='FedAvg',
                    color='#e74c3c', edgecolor='black', linewidth=1.5)

axes[2].set_ylabel('% of Total Time', fontsize=14, fontweight='bold')
axes[2].set_title('Bottleneck: Computation Dominates', fontsize=16, fontweight='bold')
axes[2].set_xticks(x)
axes[2].set_xticklabels(categories, fontsize=12, fontweight='bold')
axes[2].legend(fontsize=12)
axes[2].set_ylim([0, 110])
axes[2].grid(axis='y', alpha=0.3)

# Add percentage labels
for bars in [bars1, bars2]:
    for bar in bars:
        height = bar.get_height()
        if height > 5:
            axes[2].text(bar.get_x() + bar.get_width()/2, height/2,
                        f'{height:.1f}%', ha='center', va='center',
                        fontsize=11, fontweight='bold', color='white')

plt.tight_layout()
plt.savefig('results/communication_vs_computation.png', dpi=300, bbox_inches='tight')
plt.show()

print("=" * 80)
print("COMMUNICATION vs COMPUTATION ANALYSIS")
print("=" * 80)

print("\n📊 TIME BREAKDOWN (per round):")
print(f"  Local Training:    ~580s (~97%)")
print(f"  Communication:     ~10-15s (~2%)")  
print(f"  Other (agg, etc):  ~5-10s (~1%)")
print(f"  TOTAL:             ~600s")

print("\n💾 COMMUNICATION SIZE (per round):")
print(f"  ATLAS:    {atlas_comm:.1f} MB")
print(f"  FedAvg:   {fedavg_comm:.1f} MB")
print(f"  Savings:  {reduction:.1f}% reduction")

print("\n⚠️ WHY NO TIME SPEEDUP:")
print("  1. GPU training is COMPUTATION-BOUND (not communication-bound)")
print("  2. Forward/backward passes take 580s, communication only 10-15s")
print("  3. Reducing communication from 15s→10s saves only ~1% total time")
print("  4. ATLAS's 97% communication reduction = 5s saved = NEGLIGIBLE")

print("\n✅ WHEN ATLAS TIME SAVINGS MATTER:")
print("  - Slow networks (<1 Mbps): Transfer time can rival or exceed local training")
print("  - CPU-only devices: Computation slower, communication more expensive")
print("  - Many clients (>100): Communication overhead scales linearly")
print("  - Mobile/Edge: Network costs (bandwidth, latency, $) are critical")

print("\n📝 FOR PAPER:")
print("  - Emphasize: 'Communication-efficient FL for resource-constrained devices'")
print("  - NOT: 'Faster training' (misleading on GPUs)")
print("  - Show: Bandwidth savings, memory reduction, device heterogeneity")
print("  - Honest framing: Enables FL where it was previously impossible")

print("\n" + "=" * 80)

---

## Experiment 2: Ablation Studies

**Compare**:
1. ATLAS Full (all 4 phases)
2. FedAvg per Cluster (no Laplacian)
3. Local Only (no aggregation)

Each runs 30 rounds for rigorous comparison.

### 2.1: FedAvg per Cluster (Ablation)

In [None]:
# Ablation: FedAvg within clusters (Phase 1-3 only)
# Session 1: Rounds 1-15

!python experiments/atlas_integrated.py \
    --mode full \
    --rounds 15 \
    --max-rounds 15 \
    --ablation fedavg_cluster \
    --samples 5000 \
    --local-epochs 3

print(f"\nFedAvg ablation - Session 1 complete")

# Auto-backup to Google Drive
backup_to_drive()
print("\n[OK] Results and checkpoint automatically backed up to Drive")

[MODE] Full experiment (2-4 hours per run on T4 GPU)
         For 30+ rounds, split into sessions: 15+15 with --resume
[SESSION] Limiting this session to 15 rounds (use --resume to continue)

[SETUP] Creating multi-task federated learning setup...
  Loading task: sst2
  [DEDUP] Removed 371 duplicates from sst2 train
Map: 100% 66978/66978 [00:10<00:00, 6247.06 examples/s]
Map: 100% 872/872 [00:00<00:00, 4364.31 examples/s]
    Client 0: sst2, cpu_2gb, 5000 samples
    Client 1: sst2, cpu_2gb, 5000 samples
    Client 2: sst2, tablet_4gb, 5000 samples
  Loading task: mrpc
Map: 100% 3668/3668 [00:00<00:00, 4633.37 examples/s]
Map: 100% 408/408 [00:00<00:00, 3889.03 examples/s]
    Client 3: mrpc, tablet_4gb, 1222 samples
    Client 4: mrpc, tablet_4gb, 1222 samples
    Client 5: mrpc, laptop_8gb, 1224 samples
  Loading task: cola
  [DEDUP] Removing 16 train↔val overlaps from cola
  [DEDUP] Removed 35 duplicates from cola train
  [DEDUP] Removed 4 duplicates from cola val
Map: 100% 8516/8

### 2.2: Local Only Baseline

In [None]:
# Baseline: Local training only (no aggregation)
# Skips communication (only ~2% of round time in the breakdown above), so expect little speedup and lower accuracy

!python experiments/atlas_integrated.py \
    --mode full \
    --rounds 15 \
    --ablation local_only \
    --samples 5000 \
    --local-epochs 3

print(f"\nLocal Only baseline complete")

# Auto-backup to Google Drive
backup_to_drive()
print("\n[OK] Results and checkpoint automatically backed up to Drive")

[MODE] Full experiment (2-4 hours per run on T4 GPU)
         For 30+ rounds, split into sessions: 15+15 with --resume
config.json: 100% 483/483 [00:00<00:00, 2.44MB/s]
tokenizer_config.json: 100% 48.0/48.0 [00:00<00:00, 258kB/s]
vocab.txt: 100% 232k/232k [00:00<00:00, 4.35MB/s]
tokenizer.json: 100% 466k/466k [00:00<00:00, 21.5MB/s]

[SETUP] Creating multi-task federated learning setup...
  Loading task: sst2
README.md: 5.27kB [00:00, 15.9MB/s]
data/train-00000-of-00001.parquet: 100% 3.11M/3.11M [00:00<00:00, 7.65MB/s]
data/validation-00000-of-00001.parquet: 100% 72.8k/72.8k [00:00<00:00, 393kB/s]
data/test-00000-of-00001.parquet: 100% 148k/148k [00:00<00:00, 704kB/s]  
Generating train split: 100% 67349/67349 [00:00<00:00, 744916.92 examples/s]
Generating validation split: 100% 872/872 [00:00<00:00, 348227.47 examples/s]
Generating test split: 100% 1821/1821 [00:00<00:00, 595820.86 examples/s]
  [DEDUP] Removed 371 duplicates from sst2 train
Map: 100% 66978/66978 [00:12<00:00, 5279.58

---

## Experiment 3: Lambda (η) Sweep

Test different Laplacian regularization strengths: {0.0, 0.01, 0.1, 0.5, 1.0}

This helps find optimal personalization vs convergence tradeoff.
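
As a reminder of what the swept coefficient controls, a graph-Laplacian regularizer adds a penalty proportional to the weighted squared distance between neighbouring clients' parameters, so a larger η pulls related clients closer together and η = 0 removes the coupling entirely. The sketch below computes that generic penalty, (η/2) Σ_ij w_ij ‖θ_i − θ_j‖², in plain Python. It is a textbook form used for illustration; the exact objective, graph weights, and normalization in ATLAS may differ, and `laplacian_penalty` is a hypothetical name.

```python
def laplacian_penalty(thetas, weights, eta):
    """(eta / 2) * sum over ordered pairs (i, j) of w_ij * ||theta_i - theta_j||^2."""
    total = 0.0
    for i, theta_i in enumerate(thetas):
        for j, theta_j in enumerate(thetas):
            sq_dist = sum((a - b) ** 2 for a, b in zip(theta_i, theta_j))
            total += weights[i][j] * sq_dist
    return 0.5 * eta * total


# Three toy clients with 2-D parameters; clients 0 and 1 are neighbours, client 2 is isolated.
thetas = [[1.0, 0.0], [0.0, 0.0], [5.0, 5.0]]
weights = [[0, 1, 0],
           [1, 0, 0],
           [0, 0, 0]]

for eta in (0.0, 0.01, 0.1, 0.5, 1.0):
    print(f"eta={eta:<4}: penalty = {laplacian_penalty(thetas, weights, eta):.3f}")
```

Only the connected pair contributes, so the isolated client is unaffected no matter how large η is; the penalty scales linearly with η.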

In [None]:
# Lambda (eta) sweep: this cell runs ONE value per session.
# Re-run it with --eta 0.01, 0.1, 0.5, and 1.0 to complete the sweep.
!python experiments/atlas_integrated.py \
    --mode full \
    --rounds 15 \
    --eta 0.0 \
    --samples 5000 \
    --local-epochs 3

print("\nLambda sweep run complete (eta=0.0)")
print("Results saved: results/lambda_sweep_full_atlas.json")

# Auto-backup to Google Drive
backup_to_drive()
print("\n[OK] Results automatically backed up to Drive")

[MODE] Full experiment (2-4 hours per run on T4 GPU)
         For 30+ rounds, split into sessions: 15+15 with --resume

[SETUP] Creating multi-task federated learning setup...
  Loading task: sst2
  [DEDUP] Removed 371 duplicates from sst2 train
Map: 100% 66978/66978 [00:11<00:00, 5977.98 examples/s]
Map: 100% 872/872 [00:00<00:00, 7351.62 examples/s]
    Client 0: sst2, cpu_2gb, 5000 samples
    Client 1: sst2, cpu_2gb, 5000 samples
    Client 2: sst2, tablet_4gb, 5000 samples
  Loading task: mrpc
Map: 100% 3668/3668 [00:00<00:00, 4472.23 examples/s]
Map: 100% 408/408 [00:00<00:00, 4323.75 examples/s]
    Client 3: mrpc, tablet_4gb, 1222 samples
    Client 4: mrpc, tablet_4gb, 1222 samples
    Client 5: mrpc, laptop_8gb, 1224 samples
  Loading task: cola
  [DEDUP] Removing 16 train↔val overlaps from cola
  [DEDUP] Removed 35 duplicates from cola train
  [DEDUP] Removed 4 duplicates from cola val
Map: 100% 8516/8516 [00:01<00:00, 7672.49 examples/s]
Map: 100% 1039/1039 [00:00<00:00, 

---

## Experiment 3.5: Multi-Seed Statistical Experiments (PUBLICATION QUALITY)

**Run 3 seeds per configuration for statistical rigor:**
- Mean ± std dev computation
- Paired t-tests and Wilcoxon tests
- Cohen's d effect sizes
- LaTeX tables for paper

**This is essential for IEEE/NeurIPS/ICML publications!**
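
The statistics listed above can be sketched with the standard library alone. This is an illustration, not the code in `run_statistical_experiments.py`, which may compute them differently. The function names are hypothetical, the per-seed accuracies are placeholders rather than measured results, and Cohen's d is computed in its paired form (mean difference divided by the standard deviation of the differences).

```python
from itertools import product
from math import sqrt
from statistics import mean, stdev


def paired_summary(a, b):
    """Paired t statistic and Cohen's d (paired form) for per-seed scores a vs. b."""
    diffs = [x - y for x, y in zip(a, b)]
    d_mean, d_sd = mean(diffs), stdev(diffs)  # stdev uses the n - 1 denominator
    t_stat = d_mean / (d_sd / sqrt(len(diffs)))
    cohen_dz = d_mean / d_sd
    return t_stat, cohen_dz


def wilcoxon_exact_p(a, b):
    """Two-sided exact Wilcoxon signed-rank p-value.

    Assumes no zero differences and no ties among the absolute differences.
    """
    diffs = [x - y for x, y in zip(a, b)]
    order = sorted(range(len(diffs)), key=lambda i: abs(diffs[i]))
    ranks = [0] * len(diffs)
    for rank, i in enumerate(order, start=1):
        ranks[i] = rank
    total = sum(ranks)
    w_plus = sum(r for r, d in zip(ranks, diffs) if d > 0)
    observed = min(w_plus, total - w_plus)
    # Enumerate every sign assignment that is equally likely under the null.
    count = 0
    for signs in product((0, 1), repeat=len(ranks)):
        w = sum(r for r, s in zip(ranks, signs) if s)
        if min(w, total - w) <= observed:
            count += 1
    return count / 2 ** len(ranks)


# Placeholder per-seed accuracies, NOT measured results.
method_a = [0.81, 0.80, 0.83]
method_b = [0.79, 0.795, 0.80]

t_stat, dz = paired_summary(method_a, method_b)
p_w = wilcoxon_exact_p(method_a, method_b)
print(f"t={t_stat:.2f}, d_z={dz:.2f}, wilcoxon p={p_w:.2f}")
```

If the pairing is over seeds, the smallest two-sided p-value an exact Wilcoxon signed-rank test can return with n pairs is 2/2^n: 0.25 for three seeds and 0.0625 for five. At least six paired observations are needed before that test can fall below 0.05.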

In [None]:
# Run multi-seed experiments (3 seeds × 3 configs = 9 runs)
# This will take 8-12 hours total
# Results include statistical tests and publication-ready tables

!python experiments/run_statistical_experiments.py \
    --seeds 3 \
    --configs atlas fedavg_cluster local_only \
    --model distilbert-base-uncased \
    --tasks sst2 mrpc cola \
    --rounds 15 \
    --samples 3000 \
    --local-epochs 3

print("\n[OK] Statistical experiments complete!")
print("Results saved to: results/statistical/")

# Auto-backup to Google Drive
backup_to_drive('results/statistical', 'ATLAS_Statistical_Results')
print("\n[OK] Statistical results automatically backed up to Drive")

In [None]:
# Visualize statistical results
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

# Load statistical summary
stats_dir = Path('results/statistical')
summary = pd.read_csv(stats_dir / 'statistical_summary.csv')
tests = pd.read_csv(stats_dir / 'statistical_tests.csv')

# Display summary table
print("=" * 80)
print("STATISTICAL SUMMARY (Mean ± Std Dev)")
print("=" * 80)
print(summary.to_string(index=False))

# Display significance tests
print("\n" + "=" * 80)
print("STATISTICAL SIGNIFICANCE TESTS")
print("=" * 80)
sig_tests = tests[tests['significant'] == True]
print(f"Found {len(sig_tests)} significant comparisons (p < 0.05)")
print(sig_tests[['comparison', 'p_value_ttest', 'p_value_wilcoxon', 'cohen_d']].to_string(index=False))

# Plot: Final accuracy with error bars
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Plot 1: Final accuracy
configs = summary['config'].values
means = summary['final_acc_mean'].values
stds = summary['final_acc_std'].values

bars = axes[0].bar(range(len(configs)), means, 
                   color=['#2ecc71', '#3498db', '#e74c3c'],
                   edgecolor='black', linewidth=1.5)
axes[0].errorbar(range(len(configs)), means, yerr=stds,
                fmt='none', color='black', capsize=10, capthick=2, linewidth=2)
axes[0].set_xticks(range(len(configs)))
axes[0].set_xticklabels([c.replace('_', ' ').title() for c in configs])
axes[0].set_ylabel('Final Accuracy', fontsize=14, fontweight='bold')
axes[0].set_title('Final Accuracy (Mean ± Std, 3 seeds)', fontsize=16, fontweight='bold')
axes[0].grid(axis='y', alpha=0.3)
axes[0].set_ylim([0.75, 0.85])

# Add value labels
for i, (bar, mean, std) in enumerate(zip(bars, means, stds)):
    axes[0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + std + 0.005,
                f'{mean:.4f}±{std:.4f}', ha='center', va='bottom', 
                fontsize=10, fontweight='bold')

# Plot 2: Personalization (std dev across clients)
pers_means = summary['personalization_mean'].values
pers_stds = summary['personalization_std'].values

bars2 = axes[1].bar(range(len(configs)), pers_means,
                    color=['#2ecc71', '#3498db', '#e74c3c'],
                    edgecolor='black', linewidth=1.5)
axes[1].errorbar(range(len(configs)), pers_means, yerr=pers_stds,
                 fmt='none', color='black', capsize=10, capthick=2, linewidth=2)
axes[1].set_xticks(range(len(configs)))
axes[1].set_xticklabels([c.replace('_', ' ').title() for c in configs])
axes[1].set_ylabel('Personalization (Std Dev)', fontsize=14, fontweight='bold')
axes[1].set_title('Personalization Quality', fontsize=16, fontweight='bold')
axes[1].grid(axis='y', alpha=0.3)

# Plot 3: Communication cost
comm_means = summary['comm_mean_mb'].values
comm_stds = summary['comm_std_mb'].values

bars3 = axes[2].bar(range(len(configs)), comm_means,
                    color=['#2ecc71', '#3498db', '#e74c3c'],
                    edgecolor='black', linewidth=1.5)
axes[2].errorbar(range(len(configs)), comm_means, yerr=comm_stds,
                 fmt='none', color='black', capsize=10, capthick=2, linewidth=2)
axes[2].set_xticks(range(len(configs)))
axes[2].set_xticklabels([c.replace('_', ' ').title() for c in configs])
axes[2].set_ylabel('Total Communication (MB)', fontsize=14, fontweight='bold')
axes[2].set_title('Communication Overhead', fontsize=16, fontweight='bold')
axes[2].grid(axis='y', alpha=0.3)

plt.tight_layout()
plot_file = stats_dir / 'statistical_comparison.png'
plt.savefig(plot_file, dpi=300, bbox_inches='tight')
plt.show()

print(f"\n[OK] Plot saved: {plot_file} (300 DPI - publication quality)")

# Print LaTeX table
print("\n" + "=" * 80)
print("LATEX TABLE (Copy-paste to paper)")
print("=" * 80)
print("\\begin{table}[t]")
print("\\centering")
print("\\caption{Statistical Comparison (Mean $\\pm$ Std Dev, 3 seeds)}")
print("\\begin{tabular}{lccc}")
print("\\toprule")
print("Method & Final Acc & Personalization & Comm (MB) \\\\")
print("\\midrule")
for _, row in summary.iterrows():
    config_name = row['config'].replace('_', ' ').title()
    print(f"{config_name} & "
          f"${row['final_acc_mean']:.3f} \\pm {row['final_acc_std']:.3f}$ & "
          f"${row['personalization_mean']:.3f} \\pm {row['personalization_std']:.3f}$ & "
          f"${row['comm_mean_mb']:.1f} \\pm {row['comm_std_mb']:.1f}$ \\\\")
print("\\bottomrule")
print("\\end{tabular}")
print("\\end{table}")

---

## Experiment 3.6: Expanded GLUE Tasks (Multi-Task Heterogeneity)

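Test ATLAS on a more diverse set of NLP tasks:
- **QNLI**: Question-answering NLI (does the sentence answer the question?)
- **QQP**: Quora Question Pairs (duplicate detection)
- **MNLI**: Multi-Genre NLI (three-way classification, unlike the binary tasks used so far)

The command below adds QNLI and MNLI to the three existing tasks; QQP is listed as a candidate but is not included in it.

This tests whether ATLAS generalizes across diverse task types.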
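
These tasks differ in their input fields and label counts, which any shared preprocessing has to account for. The sketch below collects that information in a lookup table, using the column names of the Hugging Face `glue` dataset. How `atlas_integrated.py` handles this is not shown here; `GLUE_TASKS` and `describe_task` are hypothetical names.

```python
# Input columns and label counts for the GLUE tasks named in this notebook,
# following the column names of the Hugging Face `glue` dataset.
GLUE_TASKS = {
    "sst2": {"keys": ("sentence", None), "num_labels": 2},
    "mrpc": {"keys": ("sentence1", "sentence2"), "num_labels": 2},
    "cola": {"keys": ("sentence", None), "num_labels": 2},
    "qnli": {"keys": ("question", "sentence"), "num_labels": 2},
    "qqp": {"keys": ("question1", "question2"), "num_labels": 2},
    "mnli": {"keys": ("premise", "hypothesis"), "num_labels": 3},
}


def describe_task(name):
    """Summarize whether a task takes one or two text inputs and how many labels it has."""
    spec = GLUE_TASKS[name]
    _, second = spec["keys"]
    kind = "sentence pair" if second else "single sentence"
    return f"{name}: {kind}, {spec['num_labels']} labels"


for task in ("sst2", "mrpc", "cola", "qnli", "mnli"):
    print(describe_task(task))
```

MNLI also ships two validation splits (`validation_matched` and `validation_mismatched`), so the evaluation split has to be chosen explicitly.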

In [None]:
# Run ATLAS with 5 diverse GLUE tasks
!python experiments/atlas_integrated.py \
    --mode full \
    --rounds 15 \
    --ablation atlas \
    --model distilbert-base-uncased \
    --tasks sst2 mrpc cola qnli mnli \
    --clients-per-task 3 \
    --samples 3000 \
    --local-epochs 3 \
    --seed 42

print("\n[OK] Expanded GLUE task experiment complete")

# Auto-backup
backup_to_drive()
print("\n[OK] Results automatically backed up to Drive")

---

## Experiment 4: Different Models

Test ATLAS with different backbone models.
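
Switching backbones changes the names of the attention projection layers that LoRA adapters attach to. The sketch below lists the query/value projection names used in the Hugging Face Transformers implementations of these models. Whether ATLAS selects its LoRA target modules this way is an assumption, and `lora_targets_for` is a hypothetical helper.

```python
# Attention projection module names per backbone (Hugging Face Transformers naming).
LORA_TARGET_MODULES = {
    "distilbert-base-uncased": ["q_lin", "v_lin"],
    "bert-base-uncased": ["query", "value"],
    "roberta-base": ["query", "value"],
    "gpt2": ["c_attn"],  # fused query/key/value projection (a Conv1D layer)
}


def lora_targets_for(model_name):
    """Look up the LoRA target modules for a backbone, failing loudly if unknown."""
    try:
        return LORA_TARGET_MODULES[model_name]
    except KeyError:
        raise ValueError(
            f"No LoRA target modules registered for {model_name!r}"
        ) from None


for name in LORA_TARGET_MODULES:
    print(f"{name:<24} -> {lora_targets_for(name)}")
```

Failing on an unknown model name is deliberate: silently falling back to another model's module names would attach no adapters at all.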

### 4.1: BERT-base (110M params)

In [None]:
# BERT-base (more parameters than DistilBERT)
!python experiments/atlas_integrated.py \
    --mode full \
    --rounds 30 \
    --max-rounds 15 \
    --ablation atlas \
    --model bert-base-uncased \
    --tasks sst2 mrpc cola \
    --samples 5000 \
    --local-epochs 3

print("\nBERT-base Session 1 complete")

# Auto-backup to Google Drive
backup_to_drive()
print("\n[OK] Results and checkpoint automatically backed up to Drive")

### 4.2: RoBERTa-base

In [None]:
# RoBERTa-base (better performance on many tasks)
!python experiments/atlas_integrated.py \
    --mode full \
    --rounds 30 \
    --max-rounds 15 \
    --ablation atlas \
    --model roberta-base \
    --tasks sst2 mrpc cola \
    --samples 5000 \
    --local-epochs 3

print("\nRoBERTa-base Session 1 complete")

# Auto-backup to Google Drive
backup_to_drive()
print("\n[OK] Results and checkpoint automatically backed up to Drive")

### 4.3: GPT-2

In [None]:
# GPT-2 (decoder-only, good for generation tasks)
!python experiments/atlas_integrated.py \
    --mode full \
    --rounds 30 \
    --max-rounds 15 \
    --ablation atlas \
    --model gpt2 \
    --tasks sst2 mrpc cola \
    --samples 5000 \
    --local-epochs 3

print("\nGPT-2 Session 1 complete")

# Auto-backup to Google Drive
backup_to_drive()
print("\n[OK] Results and checkpoint automatically backed up to Drive")
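
GPT-2's tokenizer defines no padding token, so batched classification needs one assigned before training. The sketch below reuses the end-of-sequence token as padding, a common workaround. It is written against small stand-in objects so it runs without downloading a model; `ensure_pad_token` and `FakeTokenizer` are hypothetical, and how `atlas_integrated.py` handles this is not shown in the notebook.

```python
from types import SimpleNamespace


class FakeTokenizer:
    """Stand-in mimicking the few GPT-2 tokenizer attributes used below."""

    def __init__(self):
        self.eos_token, self.eos_token_id = "<|endoftext|>", 50256
        self.pad_token = None  # GPT-2 ships without a pad token

    @property
    def pad_token_id(self):
        return self.eos_token_id if self.pad_token == self.eos_token else None


def ensure_pad_token(tokenizer, model_config):
    """Reuse the EOS token as padding when the tokenizer has no pad token."""
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # The classification head needs the pad id to locate each sequence's last real token.
    model_config.pad_token_id = tokenizer.pad_token_id
    return tokenizer, model_config


tokenizer = FakeTokenizer()
config = SimpleNamespace(pad_token_id=None)
ensure_pad_token(tokenizer, config)
print(tokenizer.pad_token, config.pad_token_id)
```

With a real model, the same function would be called with the loaded tokenizer and `model.config` before building batches.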

---

## Results Analysis & Visualization

In [None]:
import json
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

# Load all experiment results
results_dir = Path('results')
experiments = {}

# Define experiment files
result_files = {
    'ATLAS (DistilBERT)': 'atlas_integrated_full_atlas.json',
    'FedAvg Cluster': 'atlas_integrated_full_fedavg_cluster.json',
    'Local Only': 'atlas_integrated_full_local_only.json',
}

print("=" * 80)
print("LOADING EXPERIMENTAL RESULTS")
print("=" * 80)

for name, filename in result_files.items():
    filepath = results_dir / filename
    if filepath.exists():
        with open(filepath, 'r') as f:
            experiments[name] = json.load(f)
        print(f"[OK] Loaded: {name}")
    else:
        print(f"[MISSING] Missing: {name} ({filename})")

if not experiments:
    print("\n[ERROR] No results found! Run experiments first.")
else:
    print(f"\n[OK] Loaded {len(experiments)} experiments")

LOADING EXPERIMENTAL RESULTS
[MISSING] Missing: ATLAS (DistilBERT) (atlas_integrated_full_atlas.json)
[MISSING] Missing: FedAvg Cluster (atlas_integrated_full_fedavg_cluster.json)
[MISSING] Missing: Local Only (atlas_integrated_full_local_only.json)

[ERROR] No results found! Run experiments first.


In [None]:
# Create comprehensive comparison table
comparison_data = []

for exp_name, results in experiments.items():
    final_accs = results.get('final_accuracies', {})
    round_metrics = results.get('round_metrics', [])

    if not final_accs or not round_metrics:
        continue

    # Calculate metrics
    client_accs = list(final_accs.values())
    avg_acc = np.mean(client_accs)
    std_acc = np.std(client_accs)
    min_acc = min(client_accs)
    max_acc = max(client_accs)

    # Communication cost
    total_comm = 0
    for rm in round_metrics:
        up = rm.get('comm_upload_bytes', {})
        down = rm.get('comm_download_bytes', {})
        if isinstance(up, dict):
            total_comm += sum(up.values()) + sum(down.values())
        else:
            total_comm += up + down
    total_comm_mb = total_comm / (1024**2)

    # Time
    total_time_min = sum(rm.get('time_seconds', 0) for rm in round_metrics) / 60

    comparison_data.append({
        'Experiment': exp_name,
        'Rounds': len(round_metrics),
        'Avg Accuracy': f'{avg_acc:.4f}',
        'Std Dev': f'{std_acc:.4f}',
        'Min Acc': f'{min_acc:.4f}',
        'Max Acc': f'{max_acc:.4f}',
        'Comm (MB)': f'{total_comm_mb:.1f}',
        'Time (min)': f'{total_time_min:.1f}'
    })

df = pd.DataFrame(comparison_data)
print("\n" + "=" * 80)
print("COMPREHENSIVE COMPARISON TABLE")
print("=" * 80)
print(df.to_string(index=False))

# Save to CSV
csv_path = results_dir / 'publication_comparison.csv'
df.to_csv(csv_path, index=False)
print(f"\nSaved to: {csv_path}")
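
The communication-cost accumulation appears in both the comparison table and the plotting cell. A minimal sketch of factoring it into one helper follows. The field names come from the cells in this notebook; `total_comm_mb` is a hypothetical name and the toy records are made up.

```python
def total_comm_mb(round_metrics):
    """Sum upload and download bytes over all rounds and convert to MB (1024**2 bytes)."""
    total_bytes = 0
    for rm in round_metrics:
        for key in ("comm_upload_bytes", "comm_download_bytes"):
            value = rm.get(key, 0)
            # Accept either a per-client dict or a plain per-round total.
            total_bytes += sum(value.values()) if isinstance(value, dict) else value
    return total_bytes / (1024 ** 2)


# Made-up records covering both formats.
toy_rounds = [
    {"comm_upload_bytes": {"0": 1048576, "1": 1048576},
     "comm_download_bytes": {"0": 524288, "1": 524288}},
    {"comm_upload_bytes": 2097152, "comm_download_bytes": 1048576},
]
print(f"{total_comm_mb(toy_rounds):.1f} MB")  # prints 6.0 MB
```

Checking upload and download separately also covers the case where one field is a dict and the other a plain number.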

In [None]:
# Publication-quality plots
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# 1. Convergence curves
for exp_name, results in experiments.items():
    rounds = [rm['round'] for rm in results['round_metrics']]
    accs = [rm['avg_accuracy'] for rm in results['round_metrics']]
    axes[0, 0].plot(rounds, accs, 'o-', label=exp_name, linewidth=2.5, markersize=6)

axes[0, 0].set_xlabel('Round', fontsize=14, fontweight='bold')
axes[0, 0].set_ylabel('Average Accuracy', fontsize=14, fontweight='bold')
axes[0, 0].set_title('Convergence Comparison (30 Rounds)', fontsize=16, fontweight='bold')
axes[0, 0].legend(fontsize=11, loc='lower right')
axes[0, 0].grid(True, alpha=0.3, linestyle='--')
axes[0, 0].set_ylim([0.5, 1.0])

# 2. Final accuracy with error bars
exp_names = []
means = []
stds = []

for exp_name, results in experiments.items():
    accs = list(results['final_accuracies'].values())
    exp_names.append(exp_name)
    means.append(np.mean(accs))
    stds.append(np.std(accs))

bars = axes[0, 1].bar(range(len(exp_names)), means,
                       color=['#2ecc71', '#3498db', '#e74c3c'][:len(exp_names)],
                       edgecolor='black', linewidth=1.5)
axes[0, 1].errorbar(range(len(exp_names)), means, yerr=stds,
                    fmt='none', color='black', capsize=8, capthick=2)
axes[0, 1].set_xticks(range(len(exp_names)))
axes[0, 1].set_xticklabels(exp_names, rotation=15, ha='right')
axes[0, 1].set_ylabel('Average Accuracy', fontsize=14, fontweight='bold')
axes[0, 1].set_title('Final Accuracy (Mean ± Std)', fontsize=16, fontweight='bold')
axes[0, 1].grid(axis='y', alpha=0.3, linestyle='--')
axes[0, 1].set_ylim([0.5, 1.0])

# Add value labels
for i, (bar, mean) in enumerate(zip(bars, means)):
    axes[0, 1].text(bar.get_x() + bar.get_width()/2, bar.get_height() + stds[i] + 0.02,
                    f'{mean:.3f}', ha='center', va='bottom', fontsize=11, fontweight='bold')

# 3. Per-client accuracy (personalization)
for exp_name, results in experiments.items():
    client_ids = sorted(results['final_accuracies'].keys(), key=lambda x: int(x))
    accs = [results['final_accuracies'][cid] for cid in client_ids]
    axes[1, 0].plot(range(len(client_ids)), accs, 'o-', label=exp_name,
                    linewidth=2.5, markersize=7)

axes[1, 0].set_xlabel('Client ID', fontsize=14, fontweight='bold')
axes[1, 0].set_ylabel('Accuracy', fontsize=14, fontweight='bold')
axes[1, 0].set_title('Per-Client Personalization', fontsize=16, fontweight='bold')
axes[1, 0].legend(fontsize=11)
axes[1, 0].grid(True, alpha=0.3, linestyle='--')
axes[1, 0].set_ylim([0.5, 1.0])

# 4. Communication cost
comm_costs = []
for exp_name, results in experiments.items():
    total = 0
    for rm in results['round_metrics']:
        up = rm.get('comm_upload_bytes', {})
        down = rm.get('comm_download_bytes', {})
        if isinstance(up, dict):
            total += sum(up.values()) + sum(down.values())
        else:
            total += up + down
    comm_costs.append(total / (1024**2))

axes[1, 1].bar(range(len(exp_names)), comm_costs,
               color=['#2ecc71', '#3498db', '#e74c3c'][:len(exp_names)],
               edgecolor='black', linewidth=1.5)
axes[1, 1].set_xticks(range(len(exp_names)))
axes[1, 1].set_xticklabels(exp_names, rotation=15, ha='right')
axes[1, 1].set_ylabel('Total Communication (MB)', fontsize=14, fontweight='bold')
axes[1, 1].set_title('Communication Overhead', fontsize=16, fontweight='bold')
axes[1, 1].grid(axis='y', alpha=0.3, linestyle='--')

plt.tight_layout()
plot_path = results_dir / 'publication_results.png'
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"\n[OK] High-resolution plot saved: {plot_path}")
print("  (300 DPI - suitable for IEEE publications)")

---

## Download Results for Publication

Package all results and figures for offline analysis.

In [None]:
# Create publication package
!zip -r atlas_publication_results.zip results/ figures/ checkpoints/ \
    -x "*.pyc" "*__pycache__*"

print("[OK] Results packaged: atlas_publication_results.zip")
print("\nContents:")
print("  - results/*.json (all experimental data)")
print("  - results/publication_comparison.csv")
print("  - results/publication_results.png (300 DPI)")
print("  - checkpoints/*.pkl (for resuming)")

# Download (in Colab)
from google.colab import files
files.download('atlas_publication_results.zip')

---

## Citation & IEEE Formatting

**Suggested IEEE Paper Structure**:

1. **Abstract**: Multi-task FL with heterogeneous devices + LoRA + Laplacian regularization
2. **Introduction**: Challenges of FL for LLMs on edge devices
3. **Related Work**: FedAvg, LoRA, Split Learning, MIRA, HSplitLoRA
4. **Methodology**:
   - Phase 1: Gradient-based clustering
   - Phase 2: Heterogeneous rank allocation
   - Phase 3: Split federated learning
   - Phase 4: Graph-based personalization
5. **Experiments**:
   - Setup: 9 clients, 3 tasks, 30 rounds, 5000 samples
   - Baselines: Local Only, FedAvg per Cluster
   - Results: Convergence, accuracy, communication, ablations
6. **Results & Discussion**: Show plots from above
7. **Conclusion**: ATLAS enables personalized LLM fine-tuning on heterogeneous edge devices

**Key Metrics for IEEE Paper**:
- Convergence rate (rounds to 90% of final accuracy)
- Final accuracy (mean ± std across clients)
- Communication cost (MB per round, total MB)
- Personalization quality (variance, per-task accuracy)
- Ablation study results (with/without each phase)
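
The convergence-rate metric above (rounds to 90% of final accuracy) can be computed directly from the per-round accuracies. This is a minimal sketch: `rounds_to_fraction` is a hypothetical helper and the curve is an illustrative placeholder, not a measured result.

```python
def rounds_to_fraction(accuracies, fraction=0.9):
    """Return the 1-based round at which accuracy first reaches fraction * final accuracy."""
    if not accuracies:
        raise ValueError("accuracies must be non-empty")
    target = fraction * accuracies[-1]
    for round_idx, acc in enumerate(accuracies, start=1):
        if acc >= target:
            return round_idx
    return len(accuracies)


# Illustrative accuracy curve, NOT a measured result.
curve = [0.50, 0.62, 0.70, 0.74, 0.78, 0.80]
print(rounds_to_fraction(curve))  # prints 4
```

With the accuracies reported earlier in this document (0.7398 at the start, 0.8039 at the end), 90% of the final value is about 0.7235, which is below the starting accuracy, so this metric would be met in the first round. A threshold on the gain over the initial accuracy may be more informative for these runs.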