# ATLAS Integrated Experiments on Colab GPU

This notebook runs the complete ATLAS pipeline:
- **Phase 1**: Gradient-based task clustering
- **Phase 2**: Heterogeneous LoRA rank allocation
- **Phase 3**: Split federated learning
- **Phase 4**: MIRA Laplacian regularization
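
Phase 1 is summarized above in a single line. A common way to cluster clients by their gradients is to treat each client's flattened gradient as a "fingerprint", normalize it, run k-means, and choose k by silhouette score (the metric this notebook reports). The sketch below assumes that approach. The helper name `cluster_gradient_fingerprints` and the candidate range of k are hypothetical and are not taken from `atlas_integrated.py`.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score


def cluster_gradient_fingerprints(fingerprints, k_range=range(2, 6), seed=0):
    """Cluster clients by gradient direction (hypothetical helper, not the ATLAS code).

    fingerprints: array of shape (num_clients, dim), one flattened gradient per client.
    Returns (best_k, labels, best_silhouette).
    """
    X = np.asarray(fingerprints, dtype=np.float64)
    # L2-normalize so clusters reflect gradient direction rather than magnitude
    X = X / np.maximum(np.linalg.norm(X, axis=1, keepdims=True), 1e-12)

    best = None
    for k in k_range:
        # silhouette_score requires 2 <= k <= num_clients - 1
        if k < 2 or k >= len(X):
            continue
        labels = KMeans(n_clusters=k, n_init=10, random_state=seed).fit_predict(X)
        score = silhouette_score(X, labels)
        if best is None or score > best[2]:
            best = (k, labels, score)
    if best is None:
        raise ValueError("need at least 3 clients to choose k by silhouette score")
    return best


# Synthetic demo: two groups of four clients with similar gradient directions
rng = np.random.default_rng(0)
base_a, base_b = rng.normal(size=64), rng.normal(size=64)
fingerprints = np.stack(
    [base_a + 0.05 * rng.normal(size=64) for _ in range(4)]
    + [base_b + 0.05 * rng.normal(size=64) for _ in range(4)]
)
best_k, labels, score = cluster_gradient_fingerprints(fingerprints)
print(f"k={best_k}, silhouette={score:.3f}, labels={labels.tolist()}")
```

Normalizing first means two clients whose gradients point the same way but differ in scale (for example, because of different dataset sizes) still land in the same cluster.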

**Modes:**
- `quick`: 2 tasks, 4 clients, 500 samples, 5 rounds (~15-20 min)
- `full`: 3 tasks, 9 clients, 2000 samples, 10 rounds (~2-3 hours)

**Device Simulation** (simulated client device profiles and their memory budgets):
- CPU 2GB
- Tablet 4GB
- Laptop 8GB
- GPU 16GB
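
The profiles above differ only in memory size, and each client's report later shows a `lora_rank`. The notebook does not show the Phase 2 allocation rule, so the sketch below uses a hypothetical one: scale the rank linearly with device memory, round down to a power of two, and clamp it to a range. It also counts adapter parameters using the standard LoRA property that a rank-r adapter on a d_out × d_in weight trains r(d_out + d_in) parameters. The rank range (4 to 64), the 16 GB reference, and the BERT-base-like layer shapes are illustrative assumptions.

```python
# Hypothetical memory budgets (GB) for the device profiles listed above
DEVICE_MEMORY_GB = {"cpu": 2, "tablet": 4, "laptop": 8, "gpu": 16}


def allocate_lora_rank(memory_gb, max_memory_gb=16, max_rank=64, min_rank=4):
    """Hypothetical rule: rank proportional to memory, floored to a power of two, clamped."""
    raw = int(max_rank * memory_gb / max_memory_gb)  # linear share of max_rank, floored
    if raw < 1:
        return min_rank
    rank = 1 << (raw.bit_length() - 1)  # largest power of two <= raw
    return max(min_rank, min(max_rank, rank))


def lora_trainable_params(rank, weight_shapes):
    """A rank-r LoRA adapter on a (d_out, d_in) weight trains r * (d_out + d_in) parameters."""
    return sum(rank * (d_out + d_in) for d_out, d_in in weight_shapes)


# Example assumption: adapt the query and value projections (768 x 768) in 12 layers
shapes = [(768, 768)] * 2 * 12
for device, mem in DEVICE_MEMORY_GB.items():
    r = allocate_lora_rank(mem)
    print(f"{device:>6}: {mem:>2} GB -> rank {r:>2}, {lora_trainable_params(r, shapes):,} trainable params")
```

Rounding to powers of two keeps the set of ranks small, which makes it easier to compare or aggregate adapters across clients with different ranks.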

## Setup

In [None]:
# Clone the repository into /content (skipped if it already exists)
%cd /content
!test -d ATLAS || git clone https://github.com/YOUR_USERNAME/ATLAS.git
%cd ATLAS

In [None]:
# Install dependencies
!pip install -q torch transformers datasets peft scikit-learn scipy numpy

In [None]:
# Check GPU
import torch
print(f"GPU Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

## Quick Mode (15-20 minutes)

Run a fast experiment with 2 tasks, 4 clients, 500 samples per client.

In [None]:
!python experiments/atlas_integrated.py --mode quick

## View Quick Mode Results

In [None]:
import json
import matplotlib.pyplot as plt
import numpy as np

# Load results
with open('results/atlas_integrated_quick.json', 'r') as f:
    results = json.load(f)

print("\n" + "="*60)
print("ATLAS QUICK MODE RESULTS")
print("="*60)

# Clustering results
print("\n📊 Phase 1: Task Clustering")
print(f"Num Clusters: {results['clustering']['num_clusters']}")
print(f"Cluster Assignments: {results['clustering']['assignments']}")
print(f"Silhouette Score: {results['clustering']['silhouette_score']:.3f}")

# Device profiles
print("\n💻 Phase 2: Heterogeneous Device Profiles")
for client_id, info in results['device_profiles'].items():
    print(f"{client_id}: {info['device_type']} (rank {info['lora_rank']})")

# Final accuracies (last recorded round, rather than a hard-coded index)
final_round = results['rounds'][-1]
print(f"\n🎯 Final Results (Round {final_round['round']})")
for client_id, acc in final_round['client_accuracies'].items():
    print(f"{client_id}: {acc:.4f}")

print(f"\nAverage Accuracy: {final_round['avg_accuracy']:.4f}")
print(f"Total Runtime: {results['total_time']:.1f} seconds ({results['total_time']/60:.1f} minutes)")
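
Because Phase 2 gives clients different device types and LoRA ranks, it can help to see whether lower-memory devices lag behind. The sketch below groups final-round accuracies by `(device_type, lora_rank)`, assuming the results JSON layout used in the cell above (`device_profiles[cid]['device_type']`, `device_profiles[cid]['lora_rank']`, and `rounds[-1]['client_accuracies']`). `accuracy_by_device` is a hypothetical helper, and the toy input only mirrors that layout.

```python
from collections import defaultdict
from statistics import mean


def accuracy_by_device(results):
    """Mean final-round accuracy per (device_type, lora_rank), ordered by rank."""
    final_accs = results['rounds'][-1]['client_accuracies']
    groups = defaultdict(list)
    for cid, info in results['device_profiles'].items():
        if cid in final_accs:  # skip clients without a final-round accuracy
            groups[(info['device_type'], info['lora_rank'])].append(final_accs[cid])
    ordered = sorted(groups.items(), key=lambda kv: (kv[0][1], kv[0][0]))
    return {key: mean(accs) for key, accs in ordered}


# Toy input with the same layout as the results JSON (values are made up for illustration)
toy_results = {
    'device_profiles': {
        'client_0': {'device_type': 'cpu', 'lora_rank': 8},
        'client_1': {'device_type': 'cpu', 'lora_rank': 8},
        'client_2': {'device_type': 'gpu', 'lora_rank': 64},
    },
    'rounds': [{'round': 1, 'client_accuracies': {'client_0': 0.5, 'client_1': 0.75, 'client_2': 0.875}}],
}
for (device, rank), acc in accuracy_by_device(toy_results).items():
    print(f"{device} (rank {rank}): {acc:.4f}")
```

In the notebook, pass the loaded `results` (or `results_full`) instead of `toy_results`.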

## Plot Training Progress

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Extract round-by-round metrics
rounds = [r['round'] for r in results['rounds']]
avg_accs = [r['avg_accuracy'] for r in results['rounds']]

# Per-client accuracy over time
client_ids = list(results['device_profiles'].keys())
client_accs = {cid: [] for cid in client_ids}
for r in results['rounds']:
    for cid in client_ids:
        client_accs[cid].append(r['client_accuracies'][cid])

# Plot
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Average accuracy
axes[0].plot(rounds, avg_accs, 'o-', linewidth=2, markersize=8, color='blue')
axes[0].set_xlabel('Round', fontsize=12)
axes[0].set_ylabel('Average Accuracy', fontsize=12)
axes[0].set_title('ATLAS Convergence (Quick Mode)', fontsize=14, fontweight='bold')
axes[0].grid(True, alpha=0.3)
axes[0].set_ylim([0.5, 1.0])

# Per-client accuracy
for cid in client_ids:
    axes[1].plot(rounds, client_accs[cid], 'o-', label=cid, linewidth=2, markersize=6)
axes[1].set_xlabel('Round', fontsize=12)
axes[1].set_ylabel('Accuracy', fontsize=12)
axes[1].set_title('Per-Client Personalized Models', fontsize=14, fontweight='bold')
axes[1].legend(loc='lower right', fontsize=9)
axes[1].grid(True, alpha=0.3)
axes[1].set_ylim([0.5, 1.0])

plt.tight_layout()
plt.savefig('results/atlas_quick_convergence.png', dpi=150, bbox_inches='tight')
plt.show()

print("✅ Saved plot to results/atlas_quick_convergence.png")

## Clustering Visualization

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Gradient fingerprints are not saved in the results JSON, so a PCA view of them
# would require the main script to export them. This cell plots cluster assignments only.
cluster_assignments = results['clustering']['assignments']

# Create simple visualization
fig, ax = plt.subplots(figsize=(10, 6))

# Group clients by cluster
cluster_to_clients = {}
for cid, cluster_id in cluster_assignments.items():
    if cluster_id not in cluster_to_clients:
        cluster_to_clients[cluster_id] = []
    cluster_to_clients[cluster_id].append(cid)

# Plot one row per cluster, cycling colors if there are more clusters than colors
colors = ['red', 'blue', 'green', 'orange', 'purple']
for cluster_id, clients in sorted(cluster_to_clients.items()):
    color = colors[cluster_id % len(colors)]
    for i, cid in enumerate(clients):
        ax.scatter(i, cluster_id, s=300, c=color, alpha=0.6, edgecolors='black', linewidth=2)
        ax.text(i, cluster_id, cid.split('_')[0], ha='center', va='center', fontsize=9, fontweight='bold')

ax.set_xlabel('Client Index within Cluster', fontsize=12)
ax.set_ylabel('Cluster ID', fontsize=12)
ax.set_title('ATLAS Phase 1: Task-Based Client Clustering', fontsize=14, fontweight='bold')
ax.set_yticks(sorted(cluster_to_clients))
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('results/atlas_clustering.png', dpi=150, bbox_inches='tight')
plt.show()

print("✅ Saved plot to results/atlas_clustering.png")
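
Plotting the gradient fingerprints themselves requires the main script to save them first. Assuming they were exported as a mapping from client ID to a flattened gradient vector (a hypothetical format; the results JSON used here has no such field), a 2-D PCA projection could be computed as below. `project_fingerprints` is a hypothetical helper, and the demo data is synthetic.

```python
import numpy as np
from sklearn.decomposition import PCA


def project_fingerprints(fingerprints, client_ids):
    """Project per-client gradient fingerprints to 2-D with PCA.

    fingerprints: dict mapping client_id -> 1-D gradient vector (hypothetical format).
    Returns (coords, explained_variance_ratio), where coords maps client_id -> (x, y).
    """
    X = np.stack([np.asarray(fingerprints[cid], dtype=np.float64) for cid in client_ids])
    # Normalize so the projection reflects gradient direction rather than magnitude
    X /= np.maximum(np.linalg.norm(X, axis=1, keepdims=True), 1e-12)
    pca = PCA(n_components=2)
    Z = pca.fit_transform(X)
    coords = {cid: (float(x), float(y)) for cid, (x, y) in zip(client_ids, Z)}
    return coords, pca.explained_variance_ratio_


# Synthetic demo: two groups of three clients with similar gradient directions
rng = np.random.default_rng(0)
a, b = rng.normal(size=32), rng.normal(size=32)
demo = {f"a_{i}": a + 0.05 * rng.normal(size=32) for i in range(3)}
demo.update({f"b_{i}": b + 0.05 * rng.normal(size=32) for i in range(3)})
coords, evr = project_fingerprints(demo, list(demo))
print({cid: (round(x, 3), round(y, 3)) for cid, (x, y) in coords.items()})
```

To plot the result, pass the coordinates to `ax.scatter` and color each point by `cluster_assignments[cid]`, so the Phase 1 clusters can be compared with the raw gradient geometry.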

---

## Full Mode (2-3 hours)

Run the complete ATLAS experiment with 3 tasks, 9 clients, 2000 samples per client.

⚠️ **This takes roughly 2-3 hours.** Colab sessions can disconnect during long runs; if that happens, use the resume cell below.

In [None]:
# Run full experiment
!python experiments/atlas_integrated.py --mode full

## Resume from Checkpoint (if interrupted)

In [None]:
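# Resume from the newest checkpoint (the glob can match several round files)
import glob, os
latest = max(glob.glob('checkpoints/atlas_full_round_*.pt'), key=os.path.getmtime)
!python experiments/atlas_integrated.py --mode full --resume {latest}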

## View Full Mode Results

In [None]:
# Load full-mode results (same report as quick mode; json is re-imported in case of a restart)
import json
with open('results/atlas_integrated_full.json', 'r') as f:
    results_full = json.load(f)

print("\n" + "="*60)
print("ATLAS FULL MODE RESULTS")
print("="*60)

# Clustering
print("\n📊 Phase 1: Task Clustering")
print(f"Num Clusters: {results_full['clustering']['num_clusters']}")
print(f"Cluster Assignments: {results_full['clustering']['assignments']}")
print(f"Silhouette Score: {results_full['clustering']['silhouette_score']:.3f}")

# Device profiles
print("\n💻 Phase 2: Heterogeneous Devices")
for client_id, info in results_full['device_profiles'].items():
    print(f"{client_id}: {info['device_type']} (rank {info['lora_rank']})")

# Final results (last recorded round, rather than a hard-coded index)
final_round = results_full['rounds'][-1]
print(f"\n🎯 Final Results (Round {final_round['round']})")
for client_id, acc in final_round['client_accuracies'].items():
    task = client_id.split('_')[0]
    print(f"{client_id} ({task}): {acc:.4f}")

print(f"\n📈 Average Accuracy: {final_round['avg_accuracy']:.4f}")
print(f"⏱️ Total Runtime: {results_full['total_time']/60:.1f} minutes")
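
The cell above prints each client's task (taken from the client ID prefix) next to the Phase 1 cluster assignments. One standard way to check whether clustering grouped same-task clients together is cluster purity: for each cluster, count the clients with its most common task, sum these counts, and divide by the number of clients. The sketch assumes, as the cell above does, that client IDs look like `<task>_<index>` and that `assignments` maps client ID to cluster ID. `cluster_purity` is a hypothetical helper.

```python
from collections import Counter, defaultdict


def cluster_purity(assignments):
    """Fraction of clients whose task matches the majority task of their cluster."""
    if not assignments:
        raise ValueError("assignments is empty")
    tasks_per_cluster = defaultdict(Counter)
    for cid, cluster_id in assignments.items():
        task = cid.split('_')[0]  # assumed naming: '<task>_<index>'
        tasks_per_cluster[cluster_id][task] += 1
    # Each cluster contributes the size of its most common task
    majority = sum(counts.most_common(1)[0][1] for counts in tasks_per_cluster.values())
    return majority / len(assignments)


# Toy assignments: one taskB client was grouped with the taskA clients
toy_assignments = {'taskA_0': 0, 'taskA_1': 0, 'taskB_0': 1, 'taskB_1': 0}
print(f"Toy purity: {cluster_purity(toy_assignments):.2f}")
```

On the real run, call `cluster_purity(results_full['clustering']['assignments'])`. Purity is trivially 1.0 when every client sits in its own cluster, so read it together with `num_clusters`.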

## Compare Baselines

In [None]:
# You can compare with your previous experiments:
# - Standard FL (exp1_standard_fl.json)
# - Homogeneous LoRA FL (exp2_lora_fl.json)
# - ATLAS Integrated (atlas_integrated_full.json)

print("\n" + "="*60)
print("COMPARISON: ATLAS vs Baselines")
print("="*60)
print("\n(Run exp1 and exp2 from real_training.py first for a fair comparison)")
print("\nExpected findings (hypotheses to verify against the baselines):")
print("- ATLAS should reach higher per-client (personalized) accuracy")
print("- ATLAS should train efficiently across heterogeneous devices")
print("- Task clustering should group clients with similar tasks together")
print("- Laplacian regularization should improve convergence")
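
Phase 4 and the last expected finding refer to Laplacian regularization, but the notebook does not define MIRA's exact objective. The sketch below therefore shows a generic graph-Laplacian regularizer instead. With an adjacency A over clients (here, hypothetically, A_ij = 1 when clients i and j share a Phase 1 cluster) and L = D − A, the penalty tr(WᵀLW) equals ½ Σ_ij A_ij ‖w_i − w_j‖². A gradient step on (λ/2)·tr(WᵀLW) moves each client's parameters toward those of its neighbors. The function names, λ, the learning rate, and the toy parameters are all assumptions.

```python
import numpy as np


def laplacian_from_clusters(cluster_ids):
    """Graph Laplacian L = D - A, where A[i, j] = 1 if clients i and j share a cluster."""
    c = np.asarray(cluster_ids)
    A = (c[:, None] == c[None, :]).astype(np.float64)
    np.fill_diagonal(A, 0.0)  # no self-loops
    return np.diag(A.sum(axis=1)) - A


def laplacian_penalty(W, L):
    """tr(W^T L W), equal to 0.5 * sum_ij A_ij * ||w_i - w_j||^2 for symmetric A."""
    return float(np.trace(W.T @ L @ W))


def laplacian_step(W, L, lam=0.1, lr=1.0):
    """One gradient step on (lam / 2) * tr(W^T L W):
    w_i <- w_i - lr * lam * sum_j A_ij * (w_i - w_j)."""
    return W - lr * lam * (L @ W)


# Toy example: 4 clients with 2 parameters each, in clusters [0, 0, 1, 1]
W = np.array([[1.0, 0.0], [0.0, 1.0], [5.0, 5.0], [7.0, 5.0]])
L = laplacian_from_clusters([0, 0, 1, 1])
W_next = laplacian_step(W, L, lam=0.1, lr=1.0)
print(laplacian_penalty(W, L), laplacian_penalty(W_next, L))
```

Because each column of L sums to zero, the step leaves every cluster's mean parameter vector unchanged and only shrinks the spread within each cluster. This is the kind of within-cluster smoothing that could plausibly help convergence in a personalized setting.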

## Download Results

In [None]:
# Zip results for download
!zip -r atlas_results.zip results/ checkpoints/
from google.colab import files
files.download('atlas_results.zip')