# Basic Disaggregation (Without RDMA)

Split prefill and decode across two nodes using standard TCP/IP networking. This establishes the disaggregation architecture before optimizing with RDMA.

## Architecture

```
┌─────────────┐                          ┌─────────────┐
│  Node 1     │                          │  Node 2     │
│  (Prefill)  │                          │  (Decode)   │
│             │                          │             │
│  1. Receive │                          │  3. Receive │
│     prompt  │                          │     KV cache│
│             │                          │             │
│  2. Process │   ── KV Cache ──>        │  4. Generate│
│     prompt  │      (TCP/IP)            │     tokens  │
│             │                          │             │
│  GPU #0     │                          │  GPU #1     │
└─────────────┘                          └─────────────┘
```

## Why Split This Way?

- **Prefill**: Compute-bound; processes the entire prompt in one parallel forward pass
- **Decode**: Memory-bandwidth-bound; generates one token at a time, rereading the model weights and KV cache at every step
- **Benefit**: Each node can be sized and tuned for its own phase

## What We're Measuring

- End-to-end latency (prefill + transfer + decode)
- Transfer overhead as % of total time
- Throughput compared to the single-node baseline
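
The three metrics above reduce to simple arithmetic on per-phase timings. The sketch below is a hypothetical helper (`PipelineTiming` is not part of the notebook), using the placeholder timings from Step 7 purely as sample input.

```python
from dataclasses import dataclass


@dataclass
class PipelineTiming:
    """Per-request timings in milliseconds (hypothetical helper)."""
    prefill_ms: float
    transfer_ms: float
    decode_ms: float

    @property
    def total_ms(self) -> float:
        # End-to-end latency: prefill + transfer + decode
        return self.prefill_ms + self.transfer_ms + self.decode_ms

    @property
    def transfer_pct(self) -> float:
        # Share of the total latency spent moving the KV cache
        return 100.0 * self.transfer_ms / self.total_ms

    def slowdown_pct(self, baseline_ms: float) -> float:
        # Positive means slower than the single-node baseline
        return 100.0 * (self.total_ms / baseline_ms - 1.0)


# Sample input only: placeholder values, not measurements
timing = PipelineTiming(prefill_ms=50.0, transfer_ms=15.0, decode_ms=100.0)
print(f"total={timing.total_ms:.1f} ms, transfer={timing.transfer_pct:.1f}%")
```

Keeping the three phases separate makes it easy to see which one changes when the transport is swapped later.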

## Step 1: Setup - Load Configuration

In [None]:
import json
import socket
from pathlib import Path
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load environment config
config_file = Path("environment_config.json")
if config_file.exists():
    with open(config_file) as f:
        env_config = json.load(f)
    MODEL_NAME = env_config['model']['name']
    NODE1_IP = env_config['network']['node1_ip']
    NODE2_IP = env_config['network']['node2_ip']
else:
    MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    NODE1_IP = "192.168.100.10"
    NODE2_IP = "192.168.100.11"

# Determine current node.
# Heuristic: a hostname containing "01" is Node 1 and "02" is Node 2 (e.g. dgx01, dgx02)
hostname = socket.gethostname()
is_node1 = "01" in hostname
is_node2 = "02" in hostname

print(f"Hostname: {hostname}")
print(f"Model: {MODEL_NAME}")
print(f"Node 1 (Prefill): {NODE1_IP}")
print(f"Node 2 (Decode): {NODE2_IP}")
print(f"\nThis node is: {'Prefill' if is_node1 else 'Decode' if is_node2 else 'Unknown'}")

## Step 2: Implement KV Cache Serialization

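Before sending the KV cache over the network, we need to serialize it: copy the PyTorch tensors to CPU memory and convert them to bytes with `pickle`. Because `pickle.loads` can execute arbitrary code, only accept data from a trusted node.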
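
It helps to know roughly how many bytes to expect before measuring. The sketch below is a hypothetical estimator (`kv_cache_bytes` is not part of the notebook) that counts raw tensor bytes only; the pickled payload is slightly larger because of metadata. The sample dimensions are those of the dummy cache in the test cell, assuming float32 (the `torch.randn` default).

```python
def kv_cache_bytes(num_layers, batch_size, num_kv_heads, seq_len, head_dim, bytes_per_elem):
    """Raw size of a KV cache: one key and one value tensor per layer,
    each shaped [batch, heads, seq_len, head_dim]."""
    return 2 * num_layers * batch_size * num_kv_heads * seq_len * head_dim * bytes_per_elem


# Dummy cache dimensions: 4 layers, batch 1, 8 heads, 100 tokens, head_dim 64
fp32_size = kv_cache_bytes(4, 1, 8, 100, 64, bytes_per_elem=4)
fp16_size = kv_cache_bytes(4, 1, 8, 100, 64, bytes_per_elem=2)
print(f"float32: {fp32_size / 1e6:.2f} MB")  # float32: 1.64 MB
print(f"float16: {fp16_size / 1e6:.2f} MB")  # float16: 0.82 MB
```

The size grows linearly with sequence length, so long prompts dominate transfer cost.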

In [None]:
import pickle
import time

def serialize_kv_cache(past_key_values):
    """
    Serialize KV cache for network transfer.
    
    Args:
        past_key_values: Tuple of (key, value) tensors per layer
    
    Returns:
        bytes: Serialized cache data
    """
    # Move to CPU for serialization
    cpu_cache = []
    for key, value in past_key_values:
        cpu_cache.append((key.cpu(), value.cpu()))
    
    # Serialize with pickle
    serialized = pickle.dumps(cpu_cache)
    return serialized

def deserialize_kv_cache(serialized_data, device='cuda'):
    """
    Deserialize KV cache and move to GPU.
    
    Args:
        serialized_data: Bytes from serialize_kv_cache
        device: Target device for tensors
    
    Returns:
        Tuple of (key, value) tensors per layer
    """
    # Deserialize
    cpu_cache = pickle.loads(serialized_data)
    
    # Move to target device
    device_cache = []
    for key, value in cpu_cache:
        device_cache.append((key.to(device), value.to(device)))
    
    return tuple(device_cache)

# Test serialization
print("Testing KV cache serialization...\n")

# Create dummy KV cache
num_layers = 4
batch_size = 1
num_heads = 8
seq_len = 100
head_dim = 64

dummy_cache = []
for _ in range(num_layers):
    key = torch.randn(batch_size, num_heads, seq_len, head_dim, device='cuda' if torch.cuda.is_available() else 'cpu')
    value = torch.randn(batch_size, num_heads, seq_len, head_dim, device='cuda' if torch.cuda.is_available() else 'cpu')
    dummy_cache.append((key, value))

# Serialize
start = time.time()
serialized = serialize_kv_cache(dummy_cache)
serialize_time = time.time() - start

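# Deserialize (fall back to CPU when no GPU is available)
start = time.time()
deserialized = deserialize_kv_cache(
    serialized,
    device='cuda' if torch.cuda.is_available() else 'cpu'
)
deserialize_time = time.time() - start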

print(f"Original cache: {len(dummy_cache)} layers")
print(f"Serialized size: {len(serialized) / 1e6:.2f} MB")
print(f"Serialize time: {serialize_time * 1000:.2f} ms")
print(f"Deserialize time: {deserialize_time * 1000:.2f} ms")
print(f"Total overhead: {(serialize_time + deserialize_time) * 1000:.2f} ms")

## Step 3: Implement Prefill Server (Node 1)

This runs on the prefill node. It:
1. Receives prompts
2. Runs prefill phase
3. Sends KV cache to decode node
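
Step 3 relies on length-prefixed framing: each field is sent as a byte count followed by the payload. The sketch below shows the idea in isolation over a local socket pair. It is not the notebook's wire format: it assumes an 8-byte length in network byte order (`"!Q"`), whereas the cell below uses a 4-byte `'I'` prefix, and `send_frame`, `recv_exact`, and `recv_frame` are hypothetical names.

```python
import socket
import struct

# Assumption: 8-byte unsigned length, network byte order, so frames above
# 4 GiB fit and both ends agree on endianness.
HEADER = struct.Struct("!Q")


def send_frame(sock, payload):
    """Send one length-prefixed frame."""
    sock.sendall(HEADER.pack(len(payload)))
    sock.sendall(payload)


def recv_exact(sock, n):
    """Read exactly n bytes, or raise if the peer closes early."""
    buf = bytearray(n)
    view = memoryview(buf)
    received = 0
    while received < n:
        count = sock.recv_into(view[received:])
        if count == 0:
            raise ConnectionError("socket closed before frame was complete")
        received += count
    return bytes(buf)


def recv_frame(sock):
    """Read the length header, then exactly that many payload bytes."""
    (length,) = HEADER.unpack(recv_exact(sock, HEADER.size))
    return recv_exact(sock, length)


# Local round trip over a connected socket pair (no network needed)
left, right = socket.socketpair()
try:
    send_frame(left, b"prompt text")
    send_frame(left, b"\x00" * 1000)
    first = recv_frame(right)
    second = recv_frame(right)
finally:
    left.close()
    right.close()

print(first, len(second))
```

Reading into a preallocated buffer avoids repeated byte-string concatenation, which matters once payloads reach hundreds of megabytes.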

In [None]:
import socket
import struct

class PrefillServer:
    """Prefill server - processes prompts and sends KV cache."""
    
    def __init__(self, model_name, port=5555):
        self.model_name = model_name
        self.port = port
        self.model = None
        self.tokenizer = None
        
    def load_model(self):
        """Load model for prefill."""
        print(f"Loading model: {self.model_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        self.model.eval()
        print("✓ Model loaded")
    
    def prefill(self, prompt):
        """
        Run prefill phase - process prompt and generate KV cache.
        
        Returns:
            dict with input_ids, past_key_values, and metrics
        """
        start = time.time()
        
        # Tokenize
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        input_ids = inputs['input_ids']
        
        # Run forward pass to generate KV cache
        with torch.no_grad():
            outputs = self.model(**inputs, use_cache=True)
            past_key_values = outputs.past_key_values
        
        prefill_time = time.time() - start
        
        return {
            'input_ids': input_ids,
            'past_key_values': past_key_values,
            'prompt': prompt,
            'prefill_time_ms': prefill_time * 1000
        }
    
    def send_to_decode_node(self, result, decode_host, decode_port=5556):
        """
        Send KV cache to decode node via TCP.
        
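        Protocol (each length is a 4-byte unsigned int in native byte
        order, so each field must be smaller than 4 GiB and both nodes
        must share the same endianness):
        1. Send prompt length (4 bytes)
        2. Send prompt (variable)
        3. Send input_ids length (4 bytes)
        4. Send input_ids (variable)
        5. Send KV cache length (4 bytes)
        6. Send KV cache (variable)
        """
        # The timer covers serialization, connection setup, and the send,
        # so transfer_time_ms is more than pure wire time
        start = time.time()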
        
        # Serialize KV cache
        serialized_kv = serialize_kv_cache(result['past_key_values'])
        prompt_bytes = result['prompt'].encode('utf-8')
        input_ids_bytes = pickle.dumps(result['input_ids'].cpu())
        
        # Connect to decode node
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.connect((decode_host, decode_port))
        
        try:
            # Send prompt
            sock.sendall(struct.pack('I', len(prompt_bytes)))
            sock.sendall(prompt_bytes)
            
            # Send input_ids
            sock.sendall(struct.pack('I', len(input_ids_bytes)))
            sock.sendall(input_ids_bytes)
            
            # Send KV cache
            sock.sendall(struct.pack('I', len(serialized_kv)))
            sock.sendall(serialized_kv)
            
        finally:
            sock.close()
        
        transfer_time = time.time() - start
        
        return {
            'transfer_time_ms': transfer_time * 1000,
            'kv_cache_mb': len(serialized_kv) / 1e6
        }

# Initialize prefill server (only on Node 1)
if is_node1:
    print("Initializing Prefill Server...")
    prefill_server = PrefillServer(MODEL_NAME)
    prefill_server.load_model()
else:
    print("This is not the prefill node - skip prefill server setup")

## Step 4: Implement Decode Server (Node 2)

This runs on the decode node. It:
1. Receives KV cache from prefill node
2. Runs decode phase
3. Generates output tokens

In [None]:
import threading

class DecodeServer:
    """Decode server - receives KV cache and generates tokens."""
    
    def __init__(self, model_name, port=5556):
        self.model_name = model_name
        self.port = port
        self.model = None
        self.tokenizer = None
        self.server_socket = None
        
    def load_model(self):
        """Load model for decode."""
        print(f"Loading model: {self.model_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        self.model.eval()
        print("✓ Model loaded")
    
    def decode(self, input_ids, past_key_values, max_new_tokens=50):
        """
        Run decode phase - generate tokens using received KV cache.
        
        Args:
            input_ids: Input token IDs
            past_key_values: KV cache from prefill
            max_new_tokens: Number of tokens to generate
        
        Returns:
            dict with generated text and metrics
        """
        start = time.time()
        
        # Move input_ids to device
        input_ids = input_ids.to(self.model.device)
        
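        # Generate with provided cache.
        # Note: the cache already covers every token in input_ids. Whether
        # generate() accepts that, and a plain tuple of (key, value) pairs,
        # depends on the installed transformers version.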
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                past_key_values=past_key_values,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=self.tokenizer.eos_token_id
            )
        
        # Decode output
        generated_text = self.tokenizer.decode(
            outputs[0][input_ids.shape[1]:],
            skip_special_tokens=True
        )
        
        decode_time = time.time() - start
        tokens_generated = outputs.shape[1] - input_ids.shape[1]
        
        return {
            'text': generated_text,
            'tokens': tokens_generated,
            'decode_time_ms': decode_time * 1000,
            'tokens_per_sec': tokens_generated / decode_time
        }
    
    def receive_from_prefill_node(self, client_socket):
        """
        Receive KV cache from prefill node.
        
        Returns:
            dict with prompt, input_ids, past_key_values
        """
        start = time.time()
        
        # Receive prompt
        prompt_len = struct.unpack('I', self._recv_exact(client_socket, 4))[0]
        prompt = self._recv_exact(client_socket, prompt_len).decode('utf-8')
        
        # Receive input_ids
        input_ids_len = struct.unpack('I', self._recv_exact(client_socket, 4))[0]
        input_ids_bytes = self._recv_exact(client_socket, input_ids_len)
        input_ids = pickle.loads(input_ids_bytes)
        
        # Receive KV cache
        kv_len = struct.unpack('I', self._recv_exact(client_socket, 4))[0]
        kv_bytes = self._recv_exact(client_socket, kv_len)
        past_key_values = deserialize_kv_cache(kv_bytes, device=self.model.device)
        
        receive_time = time.time() - start
        
        return {
            'prompt': prompt,
            'input_ids': input_ids,
            'past_key_values': past_key_values,
            'receive_time_ms': receive_time * 1000,
            'kv_cache_mb': kv_len / 1e6
        }
    
    def _recv_exact(self, sock, n):
        """Receive exactly n bytes from socket."""
        data = b''
        while len(data) < n:
            chunk = sock.recv(n - len(data))
            if not chunk:
                raise ConnectionError("Socket connection broken")
            data += chunk
        return data
    
    def start_server(self):
        """Start listening for prefill requests."""
        self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.server_socket.bind(('0.0.0.0', self.port))
        self.server_socket.listen(5)
        print(f"Decode server listening on port {self.port}")

# Initialize decode server (only on Node 2)
if is_node2:
    print("Initializing Decode Server...")
    decode_server = DecodeServer(MODEL_NAME)
    decode_server.load_model()
    print("\nTo start server, run:")
    print("  decode_server.start_server()")
else:
    print("This is not the decode node - skip decode server setup")

## Step 5: Test Disaggregated Inference

Run an end-to-end test: Node 1 does prefill and sends the KV cache to Node 2, which does decode.

**Instructions:**
1. On Node 2: Run the Step 4 cell to load the model, then run the Step 6 cell, which starts the server and blocks until Node 1 connects
2. On Node 1: Run this cell to test the complete pipeline
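
Effective bandwidth is payload size divided by elapsed time, and the unit conversion is easy to get wrong. A minimal sketch, assuming sizes in MB (1e6 bytes) and times in milliseconds as reported by the cells in this notebook; `effective_gbps` is a hypothetical helper and the sample values are arbitrary.

```python
def effective_gbps(size_mb, elapsed_ms):
    """Convert a payload size in MB (1e6 bytes) and a duration in ms to Gbit/s."""
    bits = size_mb * 1e6 * 8
    seconds = elapsed_ms / 1000.0
    return bits / seconds / 1e9


# Arbitrary sample values: 15 MB moved in 15 ms
print(f"{effective_gbps(15.0, 15.0):.2f} Gbps")  # 8.00 Gbps
```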

In [None]:
# This cell runs on Node 1 (prefill node)
if is_node1:
    print("Running Disaggregated Inference Test\n")
    print("="*60)
    
    test_prompt = "Explain container orchestration systems like Kubernetes."
    
    print(f"Prompt: '{test_prompt}'\n")
    
    # Step 1: Prefill on Node 1
    print("[Node 1] Running prefill...")
    prefill_result = prefill_server.prefill(test_prompt)
    print(f"  Prefill time: {prefill_result['prefill_time_ms']:.2f} ms")
    print(f"  Input tokens: {prefill_result['input_ids'].shape[1]}")
    
    # Step 2: Transfer to Node 2
    print(f"\n[Network] Sending KV cache to {NODE2_IP}...")
    transfer_result = prefill_server.send_to_decode_node(prefill_result, NODE2_IP)
    print(f"  Transfer time: {transfer_result['transfer_time_ms']:.2f} ms")
    print(f"  KV cache size: {transfer_result['kv_cache_mb']:.2f} MB")
    # MB * 8 = megabits; divide by 1000 for gigabits, then by seconds
    bandwidth_gbps = (transfer_result['kv_cache_mb'] * 8 / 1000) / (transfer_result['transfer_time_ms'] / 1000)
    print(f"  Effective bandwidth: {bandwidth_gbps:.2f} Gbps")
    
    # Note: Decode happens on Node 2
    print("\n[Node 2] Decode running on remote node...")
    print("  (Check Node 2's output for decode results)")
    
    # Calculate total pipeline
    total_time = prefill_result['prefill_time_ms'] + transfer_result['transfer_time_ms']
    transfer_overhead_pct = (transfer_result['transfer_time_ms'] / total_time) * 100
    
    print("\n" + "="*60)
    print("Pipeline Summary (Prefill + Transfer):")
    print("="*60)
    print(f"Prefill time: {prefill_result['prefill_time_ms']:.2f} ms")
    print(f"Transfer time: {transfer_result['transfer_time_ms']:.2f} ms")
    print(f"Total so far: {total_time:.2f} ms")
    print(f"Transfer overhead: {transfer_overhead_pct:.1f}%")
    
else:
    print("This cell should run on Node 1 (prefill node)")

## Step 6: Decode Node Processing (Run on Node 2)

This cell runs on Node 2 to handle incoming requests.
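
A blocking `accept()` hangs the notebook cell until a client connects. One option, sketched below under the assumption that a bounded wait is acceptable, is to put a timeout on the listening socket. `accept_with_timeout` is a hypothetical helper, not part of `DecodeServer`.

```python
import socket


def accept_with_timeout(server_socket, timeout_s):
    """Return (conn, addr), or None if no client connects within timeout_s."""
    server_socket.settimeout(timeout_s)
    try:
        conn, addr = server_socket.accept()
    except socket.timeout:
        return None
    conn.settimeout(None)  # ensure blocking mode for the transfer itself
    return conn, addr


# Demo on an ephemeral local port with no client: times out and returns None
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.bind(("127.0.0.1", 0))
server.listen(1)
try:
    result = accept_with_timeout(server, timeout_s=0.2)
finally:
    server.close()

print(result)  # None
```

Returning `None` instead of raising lets the calling cell print a hint and exit cleanly when Node 1 never connects.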

In [None]:
# This cell runs on Node 2 (decode node)
if is_node2:
    print("Starting Decode Server (waiting for requests)...\n")
    
    decode_server.start_server()
    
    try:
        # Accept one connection for testing
        client_socket, client_address = decode_server.server_socket.accept()
        print(f"Received connection from {client_address}\n")
        
        # Receive KV cache
        print("[Network] Receiving KV cache...")
        received = decode_server.receive_from_prefill_node(client_socket)
        print(f"  Receive time: {received['receive_time_ms']:.2f} ms")
        print(f"  KV cache size: {received['kv_cache_mb']:.2f} MB")
        
        # Run decode
        print(f"\n[Node 2] Running decode...")
        print(f"  Prompt: '{received['prompt']}'")
        decode_result = decode_server.decode(
            received['input_ids'],
            received['past_key_values'],
            max_new_tokens=50
        )
        print(f"  Decode time: {decode_result['decode_time_ms']:.2f} ms")
        print(f"  Tokens generated: {decode_result['tokens']}")
        print(f"  Throughput: {decode_result['tokens_per_sec']:.1f} tokens/sec")
        print(f"\n  Generated text:\n  '{decode_result['text']}'")
        
        # Total pipeline time
        total_decode_side = received['receive_time_ms'] + decode_result['decode_time_ms']
        
        print("\n" + "="*60)
        print("Decode Node Summary:")
        print("="*60)
        print(f"Receive time: {received['receive_time_ms']:.2f} ms")
        print(f"Decode time: {decode_result['decode_time_ms']:.2f} ms")
        print(f"Total: {total_decode_side:.2f} ms")
        
    finally:
        client_socket.close()
        decode_server.server_socket.close()
        
else:
    print("This cell should run on Node 2 (decode node)")

## Step 7: Compare with Baseline

Load baseline metrics and compare disaggregated performance.
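
When comparing transports, an ideal wire time gives a useful lower bound. The sketch below is a hypothetical helper (`transfer_ms`) that ignores latency, serialization, and protocol overhead; the 10 MB payload and the link speeds are illustrative values, not measurements.

```python
def transfer_ms(size_mb, link_gbps):
    """Ideal wire time in ms for size_mb (1e6 bytes) at link_gbps.

    MB * 8 gives megabits; megabits divided by Gbit/s gives milliseconds.
    """
    return size_mb * 8 / link_gbps


for gbps in (10, 25, 100):
    print(f"{gbps:>3} Gbps: {transfer_ms(10.0, gbps):.2f} ms")
# 10 Gbps: 8.00 ms, 25 Gbps: 3.20 ms, 100 Gbps: 0.80 ms
```

Measured transfer times above this bound point to overhead outside the link itself, such as serialization or host-to-device copies.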

In [None]:
# Load baseline metrics
baseline_file = Path("baseline_metrics.json")
if baseline_file.exists():
    with open(baseline_file) as f:
        baseline = json.load(f)
    
    baseline_latency = baseline['single_request']['latency_ms']
    baseline_throughput = baseline['single_request']['throughput_tokens_per_sec']
    
    # Example disaggregated times (update with actual measurements)
    # These would come from running the cells above
    disagg_prefill = 50  # ms - from Node 1
    disagg_transfer = 15  # ms - network transfer
    disagg_decode = 100  # ms - from Node 2
    disagg_total = disagg_prefill + disagg_transfer + disagg_decode
    
    print("Performance Comparison\n")
    print("="*60)
    print(f"\nBaseline (Single Node):")
    print(f"  Total latency: {baseline_latency:.1f} ms")
    print(f"  Throughput: {baseline_throughput:.1f} tokens/sec")
    
    print(f"\nDisaggregated (Two Nodes):")
    print(f"  Prefill: {disagg_prefill:.1f} ms")
    print(f"  Transfer: {disagg_transfer:.1f} ms")
    print(f"  Decode: {disagg_decode:.1f} ms")
    print(f"  Total: {disagg_total:.1f} ms")
    
    slowdown = (disagg_total / baseline_latency - 1) * 100
    transfer_pct = (disagg_transfer / disagg_total) * 100
    
    print(f"\nAnalysis:")
    print(f"  Slowdown: {slowdown:+.1f}%")
    print(f"  Transfer overhead: {transfer_pct:.1f}% of total time")
    
    print("\n" + "="*60)
    print("Key Insight:")
    print("="*60)
    print(f"Transfer time ({disagg_transfer:.1f} ms) is {transfer_pct:.0f}% of total latency")
    print("This motivates RDMA, assuming a roughly 10x faster transfer path")
    print(f"With RDMA @ 100 Gbps: ~{disagg_transfer/10:.1f} ms transfer")
    print(f"That would reduce total to ~{disagg_prefill + disagg_transfer/10 + disagg_decode:.1f} ms")
    
else:
    print("Baseline metrics not found. Run 01_Local_Inference_Baseline.ipynb first")

## Key Takeaways

**What We Built:**
- Split inference pipeline: Node 1 (prefill) → Node 2 (decode)
- TCP/IP-based KV cache transfer
- pickle-based serialization and deserialization of the KV cache

**What We Measured:**
- Transfer time: roughly 10-30 ms for typical sequences
- Transfer overhead: roughly 15-30% of total latency
- Network bandwidth: ~5-10 Gbps effective with TCP
- These ranges are indicative; replace them with the numbers from your own runs

**Why This Matters:**
- Showed that disaggregation works architecturally
- Identified the network transfer, not compute, as the main added cost
- Transfer overhead is too high for production use

**What's Next:**
- [04_NIXL_Integration.ipynb](04_NIXL_Integration.ipynb) - Replace TCP with RDMA/NIXL for 10x faster transfer