# 🔗 Adapter Fusion Basics

This notebook introduces the basics of Adapter Fusion, a parameter-efficient fine-tuning method that combines several task-specific adapters on top of a frozen base model.

## What is Adapter Fusion?

Adapter Fusion combines knowledge from multiple task-specific adapters:
- **Step 1**: Train an individual adapter for each task (knowledge extraction)
- **Step 2**: Freeze the adapters and train a fusion layer that combines their outputs (knowledge composition)
- **Step 3**: Use the fused model so that each task can draw on what the other adapters learned
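
Steps 1 and 2 above come down to which parameters are allowed to receive gradients at each stage. The following is a minimal sketch in plain PyTorch, not the notebook's `fusion` package: the linear layers are hypothetical stand-ins for real adapters and a real fusion layer, and the tiny `hidden_size` is chosen only for illustration.

```python
import torch
from torch import nn

hidden_size = 16  # small size chosen for illustration only

# Hypothetical stand-ins: one linear layer per task adapter, one fusion layer.
adapters = nn.ModuleDict({
    "task_a": nn.Linear(hidden_size, hidden_size),
    "task_b": nn.Linear(hidden_size, hidden_size),
})
fusion = nn.Linear(hidden_size, 1)  # would score each adapter output

# Stage 1: each adapter is trained on its own task.
# (Per-task training loops are omitted from this sketch.)

# Stage 2: freeze the adapters so that only the fusion layer is updated.
for param in adapters.parameters():
    param.requires_grad_(False)

all_params = list(adapters.parameters()) + list(fusion.parameters())
trainable = [p for p in all_params if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-4)

num_trainable = sum(p.numel() for p in trainable)
print(f"Trainable parameters in stage 2: {num_trainable}")
```

Passing only the trainable parameters to the optimizer makes the freeze explicit; the frozen adapters still take part in the forward pass.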

## Architecture

```
Base Model (Frozen)
    ↓
Task A Adapter → \
Task B Adapter → → Fusion Layer → Combined Output
Task C Adapter → /
```

## Benefits

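- **Knowledge Transfer**: A task can draw on representations learned for other tasks
- **Parameter Efficiency**: The base model stays frozen; only the small adapters and the fusion layer are trained
- **Modularity**: Adapters are trained independently, so a task can be added or removed without retraining the other adapters
- **No Catastrophic Forgetting**: Adapters are frozen during fusion training, so what each one learned for its task is preserved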

## 1. Setup Environment

In [None]:
# Install required packages (run once)
# !pip install -r ../requirements.txt

import sys
sys.path.append('..')

import torch
import numpy as np
from datasets import Dataset

# Import our fusion modules
from config import ModelConfig, FusionConfig, TrainingConfig, AdapterConfig
from fusion import FusionModel, AdapterManager, AttentionFusion, WeightedFusion
from adapters import BottleneckAdapter
from training import FusionTrainer

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 2. Understanding Fusion Mechanisms

In [None]:
# Create dummy adapter outputs to understand fusion
batch_size, seq_len, hidden_size = 2, 10, 768
num_adapters = 3

# Simulate outputs from 3 different adapters
adapter_outputs = [
    torch.randn(batch_size, seq_len, hidden_size),  # Sentiment adapter
    torch.randn(batch_size, seq_len, hidden_size),  # NLI adapter  
    torch.randn(batch_size, seq_len, hidden_size),  # QA adapter
]

print(f"Number of adapters: {len(adapter_outputs)}")
print(f"Each adapter output shape: {adapter_outputs[0].shape}")

### 2.1 Attention-based Fusion

In [None]:
# Create attention fusion layer
attention_fusion = AttentionFusion(
    hidden_size=hidden_size,
    num_adapters=num_adapters,
    num_attention_heads=8,
    dropout=0.1
)

print("Attention Fusion Architecture:")
print(attention_fusion)

# Test forward pass
fused_output = attention_fusion(adapter_outputs)
print(f"\nFused output shape: {fused_output.shape}")

# Count parameters
fusion_params = sum(p.numel() for p in attention_fusion.parameters())
print(f"Fusion layer parameters: {fusion_params:,}")
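
To make the idea behind attention-based fusion concrete, here is a minimal single-head sketch in plain PyTorch. It is not the `AttentionFusion` class imported above, whose internals are not shown in this notebook. `TinyAttentionFusion` is a hypothetical name, and building the query from the mean adapter output is an assumption made to keep the example self-contained (the AdapterFusion paper uses the transformer layer's own output as the query).

```python
import math

import torch
from torch import nn


class TinyAttentionFusion(nn.Module):
    """Hypothetical single-head fusion: attends over adapters at each token."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)

    def forward(self, adapter_outputs):
        # Stack the list into one tensor: (batch, seq, num_adapters, hidden).
        stacked = torch.stack(adapter_outputs, dim=2)
        # Assumption: the query is derived from the mean adapter output.
        q = self.query(stacked.mean(dim=2)).unsqueeze(2)  # (B, S, 1, H)
        k = self.key(stacked)                             # (B, S, N, H)
        v = self.value(stacked)                           # (B, S, N, H)
        # Scaled dot product between the query and each adapter's key.
        scores = (q * k).sum(dim=-1) / math.sqrt(k.size(-1))  # (B, S, N)
        weights = torch.softmax(scores, dim=-1)  # sums to 1 over adapters
        fused = (weights.unsqueeze(-1) * v).sum(dim=2)  # (B, S, H)
        return fused, weights


torch.manual_seed(0)
outputs = [torch.randn(2, 10, 32) for _ in range(3)]
fused, weights = TinyAttentionFusion(hidden_size=32)(outputs)
print(fused.shape, weights.shape)
```

Because the weights are computed per token, different positions in the same sequence can favor different adapters.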

### 2.2 Weighted Fusion

In [None]:
# Create weighted fusion layer
weighted_fusion = WeightedFusion(
    hidden_size=hidden_size,
    num_adapters=num_adapters,
    learnable_weights=True,
    weight_initialization="uniform"
)

print("Weighted Fusion Architecture:")
print(weighted_fusion)

# Test forward pass
fused_output = weighted_fusion(adapter_outputs)
print(f"\nFused output shape: {fused_output.shape}")

# Show the fusion weights (nothing has been trained yet, so these are the initial values)
weights = torch.softmax(weighted_fusion.fusion_weights, dim=0)
print(f"\nInitial fusion weights: {weights.detach().numpy()}")
print(f"Weight sum: {weights.sum().item():.3f}")

# Count parameters
fusion_params = sum(p.numel() for p in weighted_fusion.parameters())
print(f"Fusion layer parameters: {fusion_params:,}")
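
Weighted fusion can be sketched in a few lines of plain PyTorch. The class below is a hypothetical illustration, not the `WeightedFusion` implementation imported above: it assumes one learnable logit per adapter, normalized with a softmax as in the cell above, and initialized to zero so that every adapter starts with the same weight.

```python
import torch
from torch import nn


class TinyWeightedFusion(nn.Module):
    """Hypothetical weighted fusion: one learnable scalar per adapter."""

    def __init__(self, num_adapters: int):
        super().__init__()
        # Zero logits give a uniform softmax: weight 1/num_adapters each.
        self.logits = nn.Parameter(torch.zeros(num_adapters))

    def forward(self, adapter_outputs):
        stacked = torch.stack(adapter_outputs, dim=0)  # (N, B, S, H)
        weights = torch.softmax(self.logits, dim=0)    # (N,)
        # Broadcast each adapter's weight over batch, sequence, and hidden dims.
        return (weights.view(-1, 1, 1, 1) * stacked).sum(dim=0)  # (B, S, H)


torch.manual_seed(0)
outputs = [torch.randn(2, 10, 32) for _ in range(3)]
layer = TinyWeightedFusion(num_adapters=3)
fused = layer(outputs)

# With uniform initial weights the fused output is the plain average.
mean_output = torch.stack(outputs).mean(dim=0)
print(torch.allclose(fused, mean_output, atol=1e-6))
```

Unlike the attention variant, these weights are the same for every token and every input, which is why the parameter count is so small.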

## 3. Create Individual Adapters

In [None]:
# Create adapter configurations for different tasks
adapter_configs = {
    "sentiment": AdapterConfig(
        adapter_size=64,
        task_name="sentiment",
        task_type="classification"
    ),
    "nli": AdapterConfig(
        adapter_size=64,
        task_name="nli", 
        task_type="classification"
    ),
    "qa": AdapterConfig(
        adapter_size=128,  # Larger adapter for complex QA task
        task_name="qa",
        task_type="question_answering"
    )
}

# Create adapter manager
adapter_manager = AdapterManager(
    hidden_size=hidden_size,
    adapter_configs=adapter_configs,
    freeze_adapters=True  # Freeze for fusion training
)

print("Adapter Manager Information:")
adapter_manager.print_adapter_info()
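
The `adapter_size` values above set the width of each adapter's bottleneck. As a rough guide to what that costs, the sketch below defines a hypothetical bottleneck adapter (down-projection, nonlinearity, up-projection, residual connection) and counts its parameters. `TinyBottleneckAdapter` is an illustrative stand-in; the project's `BottleneckAdapter` may differ, for example by adding layer normalization.

```python
import torch
from torch import nn


class TinyBottleneckAdapter(nn.Module):
    """Hypothetical adapter: project down, apply GELU, project up, add residual."""

    def __init__(self, hidden_size: int, adapter_size: int):
        super().__init__()
        self.down = nn.Linear(hidden_size, adapter_size)
        self.up = nn.Linear(adapter_size, hidden_size)
        self.act = nn.GELU()

    def forward(self, hidden_states):
        return hidden_states + self.up(self.act(self.down(hidden_states)))


def count_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())


small = TinyBottleneckAdapter(hidden_size=768, adapter_size=64)
large = TinyBottleneckAdapter(hidden_size=768, adapter_size=128)

# Two weight matrices plus two bias vectors: 2*h*r + h + r parameters.
print(f"adapter_size=64:  {count_params(small):,}")
print(f"adapter_size=128: {count_params(large):,}")
```

Under these assumptions, doubling `adapter_size` roughly doubles the adapter's parameter count, since the two weight matrices dominate.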

## 4. Test Adapter Manager

In [None]:
# Test adapter manager with dummy input
dummy_input = torch.randn(batch_size, seq_len, hidden_size)

# Get outputs from all adapters
all_outputs = adapter_manager(dummy_input)
print(f"Adapter outputs: {list(all_outputs.keys())}")

# Get outputs from specific adapters
selected_outputs = adapter_manager(dummy_input, adapter_names=["sentiment", "nli"])
print(f"Selected adapter outputs: {list(selected_outputs.keys())}")

# Check output shapes
for name, output in all_outputs.items():
    print(f"{name} adapter output shape: {output.shape}")
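
The calling pattern used above (all adapters by default, or a named subset) can be reproduced with `nn.ModuleDict`. This is a hypothetical sketch of that pattern only; `TinyAdapterManager` is not the project's `AdapterManager`, and linear layers stand in for real adapters.

```python
from typing import Dict, List, Optional

import torch
from torch import nn


class TinyAdapterManager(nn.Module):
    """Hypothetical manager: runs the same input through named adapters."""

    def __init__(self, adapters: Dict[str, nn.Module]):
        super().__init__()
        # ModuleDict registers the adapters so their parameters are tracked.
        self.adapters = nn.ModuleDict(adapters)

    def forward(self, hidden_states, adapter_names: Optional[List[str]] = None):
        # Default to every registered adapter, in insertion order.
        names = adapter_names if adapter_names is not None else list(self.adapters.keys())
        return {name: self.adapters[name](hidden_states) for name in names}


manager = TinyAdapterManager({
    "sentiment": nn.Linear(32, 32),
    "nli": nn.Linear(32, 32),
    "qa": nn.Linear(32, 32),
})
x = torch.randn(2, 10, 32)
all_out = manager(x)
selected = manager(x, adapter_names=["sentiment", "nli"])
print(list(all_out.keys()), list(selected.keys()))
```

Returning a dictionary keyed by adapter name keeps the outputs traceable to their tasks before they are passed to a fusion layer as a list.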

## 5. Create Complete Fusion Model

In [None]:
# Configuration for fusion model
model_config = ModelConfig(
    model_name_or_path="distilbert-base-uncased",
    num_labels=2,
    max_length=128,
    multi_task=True,
    task_names=list(adapter_configs.keys())
)

fusion_config = FusionConfig(
    fusion_method="attention",
    num_attention_heads=8,
    fusion_dropout=0.1,
    freeze_adapters_during_fusion=True,
    adapter_names=list(adapter_configs.keys())
)

# Create fusion model
print("Creating fusion model...")
fusion_model = FusionModel(
    model_config=model_config,
    fusion_config=fusion_config,
    adapter_manager=adapter_manager
)

# Print model information
fusion_model.print_model_info()

## 6. Test Fusion Model Inference

In [None]:
# Load tokenizer
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)

# Test texts
test_texts = [
    "This movie is absolutely fantastic!",
    "The film was terrible and boring.",
    "An okay movie, nothing special."
]

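fusion_model.eval()

# Note: the adapters and fusion layer have not been trained yet, so the
# probabilities below only confirm that the forward pass runs.
print("Testing Fusion Model Inference:")
print("=" * 50)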

for text in test_texts:
    # Tokenize input
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=model_config.max_length
    )
    
    with torch.no_grad():
        # Test with different adapter combinations
        outputs_sentiment = fusion_model(**inputs, adapter_names=["sentiment"])
        outputs_all = fusion_model(**inputs, adapter_names=list(adapter_configs.keys()))
        
        # Get probabilities
        sentiment_probs = torch.softmax(outputs_sentiment.logits, dim=-1)
        all_probs = torch.softmax(outputs_all.logits, dim=-1)
        
        print(f"\nText: '{text}'")
        print(f"Sentiment only: {sentiment_probs[0].numpy()}")
        print(f"All adapters: {all_probs[0].numpy()}")
        
        # Prediction
        pred_sentiment = "Positive" if sentiment_probs[0][1] > 0.5 else "Negative"
        pred_all = "Positive" if all_probs[0][1] > 0.5 else "Negative"
        
        print(f"Prediction (sentiment): {pred_sentiment}")
        print(f"Prediction (fused): {pred_all}")
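
The loop above thresholds the positive-class probability at 0.5. For a batch of logits, the same decision can be written with a softmax followed by a maximum over the class axis. The helper below is a small sketch on dummy logits; `LABELS` and `logits_to_labels` are hypothetical names, and the label order (index 0 is negative) is an assumption carried over from the cell above.

```python
import torch

LABELS = ["Negative", "Positive"]  # assumed order: index 0 = negative


def logits_to_labels(logits: torch.Tensor):
    """Return a (label, confidence) pair for each row of logits."""
    probs = torch.softmax(logits, dim=-1)
    confidence, indices = probs.max(dim=-1)
    return [
        (LABELS[i], round(c, 3))
        for i, c in zip(indices.tolist(), confidence.tolist())
    ]


# Dummy logits standing in for model outputs: one row per input text.
dummy_logits = torch.tensor([[0.2, 1.5], [2.0, -1.0]])
predictions = logits_to_labels(dummy_logits)
print(predictions)
```

With two classes this matches the 0.5 threshold, and it extends unchanged to tasks with more labels.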

## 7. Fusion Efficiency Analysis

In [None]:
# Analyze fusion efficiency
def analyze_model_efficiency(model):
    """Analyze model parameter efficiency"""
    
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    # Get component-wise parameters
    base_params = sum(p.numel() for p in model.base_model.parameters())
    adapter_params = sum(p.numel() for p in model.get_adapter_parameters())
    fusion_params = sum(p.numel() for p in model.get_fusion_parameters())
    
    print("Model Efficiency Analysis:")
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Trainable percentage: {trainable_params/total_params*100:.2f}%")
    print()
    print("Component Breakdown:")
    print(f"Base model: {base_params:,} ({base_params/total_params*100:.1f}%)")
    print(f"Adapters: {adapter_params:,} ({adapter_params/total_params*100:.1f}%)")
    print(f"Fusion layer: {fusion_params:,} ({fusion_params/total_params*100:.1f}%)")
    
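    # Compare with full fine-tuning
    full_finetuning_params = base_params
    fusion_overhead = adapter_params + fusion_params
    # Guard against division by zero when no adapters or fusion layers are registered
    reduction_factor = full_finetuning_params / fusion_overhead if fusion_overhead else float("inf")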
    
    print()
    print("Efficiency Comparison:")
    print(f"Full fine-tuning would train: {full_finetuning_params:,} parameters")
    print(f"Fusion approach trains: {fusion_overhead:,} parameters")
    print(f"Parameter reduction: {reduction_factor:.1f}x fewer parameters")
    
    return {
        "total_params": total_params,
        "trainable_params": trainable_params,
        "base_params": base_params,
        "adapter_params": adapter_params,
        "fusion_params": fusion_params,
        "reduction_factor": reduction_factor
    }

# Analyze our fusion model
efficiency_stats = analyze_model_efficiency(fusion_model)
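
The function above reports adapters and fusion together as the cost of the approach. During the fusion stage itself, with the adapters frozen, only the fusion layer is trainable. The toy example below illustrates that distinction with plain PyTorch; the three linear layers are hypothetical stand-ins and the sizes are arbitrary.

```python
from torch import nn

base = nn.Linear(64, 64)    # stands in for the frozen base model
adapter = nn.Linear(64, 8)  # stands in for an already trained, frozen adapter
fusion = nn.Linear(8, 2)    # the only part trained in the fusion stage

# Freeze everything except the fusion layer.
for module in (base, adapter):
    for p in module.parameters():
        p.requires_grad_(False)

model = nn.Sequential(base, adapter, fusion)
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
share = trainable / total * 100

print(f"Total: {total:,}  Trainable: {trainable:,}  ({share:.2f}%)")
```

Checking `requires_grad` rather than module membership gives the number of parameters an optimizer would actually update.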

## 8. Fusion Method Comparison

In [None]:
# Compare different fusion methods
fusion_methods = {
    "attention": AttentionFusion(hidden_size, num_adapters, dropout=0.1),
    "weighted": WeightedFusion(hidden_size, num_adapters, dropout=0.1),
}

print("Fusion Method Comparison:")
print("=" * 40)

for method_name, fusion_layer in fusion_methods.items():
    # Count parameters
    params = sum(p.numel() for p in fusion_layer.parameters())
    
    # Test forward pass in eval mode so dropout does not perturb the statistics
    fusion_layer.eval()
    with torch.no_grad():
        output = fusion_layer(adapter_outputs)
    
    print(f"\n{method_name.capitalize()} Fusion:")
    print(f"  Parameters: {params:,}")
    print(f"  Output shape: {output.shape}")
    print(f"  Output mean: {output.mean().item():.4f}")
    print(f"  Output std: {output.std().item():.4f}")

# Memory usage comparison
print("\nMemory Usage (approximate):")
bytes_per_param = 4  # float32

for method_name, fusion_layer in fusion_methods.items():
    params = sum(p.numel() for p in fusion_layer.parameters())
    memory_mb = (params * bytes_per_param) / (1024 * 1024)
    print(f"{method_name.capitalize()}: {memory_mb:.2f} MB")
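
The estimate above assumes 4 bytes per parameter. A dtype-aware variant reads the size from each tensor instead, which matters if a layer is stored in half precision. The sketch below uses a plain linear layer as a stand-in for a fusion layer; `parameter_memory_mb` is a hypothetical helper, and it counts parameter storage only (no gradients, optimizer state, or activations).

```python
import torch
from torch import nn


def parameter_memory_mb(module: nn.Module) -> float:
    """Memory used by a module's parameters, in MiB, based on their dtypes."""
    num_bytes = sum(p.numel() * p.element_size() for p in module.parameters())
    return num_bytes / (1024 * 1024)


fp32_mb = parameter_memory_mb(nn.Linear(768, 768))
fp16_mb = parameter_memory_mb(nn.Linear(768, 768).half())

print(f"float32: {fp32_mb:.2f} MB")
print(f"float16: {fp16_mb:.2f} MB")
```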

## 9. Key Takeaways

From this notebook, you learned:

1. **Fusion Architecture**: How to combine multiple adapters using different mechanisms
2. **Attention Fusion**: Uses attention to learn how much each adapter should contribute
3. **Weighted Fusion**: One learnable weight per adapter, normalized with softmax
4. **Parameter Efficiency**: The base model stays frozen; only the adapters and the fusion layer are trained
5. **Modularity**: Easy to add/remove adapters for different tasks

## Next Steps

- Try training individual adapters on real tasks
- Experiment with different fusion methods
- Test multi-task learning scenarios
- Compare with traditional multi-task learning
- Explore hierarchical fusion strategies

## Resources

- [AdapterFusion Paper](https://arxiv.org/abs/2005.00247)
- [Adapter-Hub](https://adapterhub.ml/)
- [Compacter: Efficient Low-Rank Hypercomplex Adapter Layers](https://arxiv.org/abs/2106.04647)