In [1]:
!nvidia-smi

Mon Dec 16 16:35:29 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.57.02    Driver Version: 470.57.02    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Quadro RTX 6000     Off  | 00000000:AF:00.0 Off |                    0 |
| N/A   35C    P0    62W / 250W |    837MiB / 22698MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
import os
from dotenv import load_dotenv

# Read key/value pairs from a local .env file into the process environment so
# the Hugging Face token does not have to be hard-coded in the notebook.
load_dotenv()
# os.getenv returns None when HUGGINGFACE_TOKEN is not set.
hf_token = os.getenv('HUGGINGFACE_TOKEN')

# Authenticate this session against the Hugging Face Hub.
# NOTE(review): this import sits after load_dotenv(); keep that order if the
# .env file also carries Hub settings that are read at import time - TODO confirm.
from huggingface_hub import login
login(token=hf_token)

In [3]:
from pathlib import Path
import torch
import psutil
import numpy as np
import time
import logging
from dataclasses import dataclass
from typing import Dict, Optional, List, Union, Tuple, Any

from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer,
    PretrainedConfig,
    Cache,
    DynamicCache, 
    OffloadedCache,
    QuantizedCache,
    QuantizedCacheConfig
)

# Configure logging: the root logger is set to INFO, so logger.debug(...) calls
# made through the module logger below are not emitted unless the level is lowered.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class EnhancedDocumentLoader:
    """Serve text chunks from a plain-text document, reading the file only once.

    On first use the file is read as UTF-8, everything before the first
    occurrence of the title marker is discarded (Project Gutenberg header) and
    the remainder is kept in ``text_cache`` for all later calls.
    """

    def __init__(self, file_path: str = "data/crimeandpunishment.txt"):
        self.file_path = Path(file_path)
        self.text_cache = None

    def _ensure_loaded(self) -> str:
        """Read and cache the usable text on first use; return the cached text.

        Raises:
            ValueError: If the title marker does not occur in the file.
        """
        if self.text_cache is None:
            with open(self.file_path, 'r', encoding='utf-8') as f:
                text = f.read()

            # Skip Project Gutenberg header
            start_marker = "CRIME AND PUNISHMENT"
            narrative_start = text.find(start_marker)
            if narrative_start == -1:
                raise ValueError(f"Start marker '{start_marker}' not found in the document.")
            self.text_cache = text[narrative_start:]
        return self.text_cache

    def load_chunk(self, start_pos: int = 1000, chunk_size: int = 1000) -> str:
        """Return up to ``chunk_size`` characters starting near ``start_pos``.

        Args:
            start_pos: Requested character offset into the usable text. Values
                outside the text are clamped so a chunk is always returned.
            chunk_size: Maximum number of characters to take before trimming.

        Returns:
            The chunk, with any leading partial sentence removed (text up to
            and including the first ". " is dropped when one is present).

        Raises:
            ValueError: If the title marker does not occur in the file.
        """
        text = self._ensure_loaded()

        # Clamp the start into [0, len(text) - chunk_size]. Without the lower
        # bound a text shorter than chunk_size produced a negative index,
        # which Python treats as "from the end" and silently returned only the tail.
        last_valid_start = max(0, len(text) - chunk_size)
        chunk_start = min(max(0, start_pos), last_valid_start)
        chunk = text[chunk_start:chunk_start + chunk_size]

        # Adjust to sentence boundary
        sentence_break = chunk.find(". ")
        if sentence_break != -1:
            chunk = chunk[sentence_break + 2:]

        return chunk

    def get_total_length(self) -> int:
        """Return the number of characters of usable text (after the header)."""
        return len(self._ensure_loaded())


class SimpleQuantizedCache(QuantizedCache):
    """Quantized cache whose two hooks use plain min-max (affine) quantization.

    Only ``_quantize`` and ``_dequantize`` are defined here; everything else
    comes from ``QuantizedCache``.
    """

    def _quantize(self, tensor: torch.Tensor, axis: int) -> torch.Tensor:
        """Map ``tensor`` onto the integer levels ``0 .. 2**nbits - 1``.

        Args:
            tensor: Values to quantize.
            axis: ``0`` reduces min/max over the first dimension; any other
                value reduces over the last dimension.

        Returns:
            A tensor of rounded level indices. The step size and offset needed
            to invert the mapping are attached to it as ``scale`` and
            ``zero_point``.
        """
        top_level = 2**self.nbits - 1

        with torch.no_grad():
            if axis == 0:
                lowest = tensor.min(dim=0).values
                highest = tensor.max(dim=0).values
            else:  # treated as the last dimension
                lowest = tensor.min(dim=-1, keepdim=True).values
                highest = tensor.max(dim=-1, keepdim=True).values

            # One level step per slice; the floor keeps constant slices from
            # producing a zero divisor.
            step = torch.clamp((highest - lowest) / top_level, min=1e-6)

            levels = ((tensor - lowest) / step).round().clamp(0, top_level)

            # NOTE(review): these are plain Python attributes on this tensor
            # object; a tensor derived from it by a copying operation is not
            # expected to carry them - confirm wherever the result is moved.
            levels.scale = step
            levels.zero_point = lowest

            return levels

    def _dequantize(self, qtensor: torch.Tensor) -> torch.Tensor:
        """Invert ``_quantize`` using the ``scale``/``zero_point`` stored on ``qtensor``."""
        with torch.no_grad():
            step, offset = qtensor.scale, qtensor.zero_point
            return qtensor * step + offset

In [4]:
class CPUOffloadedCache(DynamicCache):
    """Dynamic KV cache that parks idle layers on CPU.

    Every layer remembers the device its states first arrived on (its "home"
    device). Each ``update`` makes the updated layer resident on its home
    device, prefetches the following layer there and moves the preceding
    layer to CPU.

    Attributes:
        current_layer_idx: Layer index most recently handled by ``_manage_devices``.
        device_tracking: ``str(device)`` on which each layer's tensors currently live.
        original_devices: Home device of each layer, in layer order.
    """

    def __init__(self):
        super().__init__()
        self.current_layer_idx = 0
        self.device_tracking = []  # Track current device for each layer
        self.original_devices = []  # Store original device for each layer
        # Side stream for prefetch copies; None when CUDA is unavailable.
        self._prefetch_stream = (
            torch.cuda.Stream() if torch.cuda.is_available() else None
        )

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Append new key/value states for ``layer_idx`` and rebalance devices.

        Args:
            key_states: New key states; concatenated along dim ``-2``.
            value_states: New value states; concatenated along dim ``-2``.
            layer_idx: Index of the layer being updated.
            cache_kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The full cached ``(keys, values)`` for ``layer_idx``.
        """
        # Track the original device when first seeing a layer
        if layer_idx >= len(self.original_devices):
            self.original_devices.append(key_states.device)
            # The states are stored untouched below, so record where they
            # really are instead of assuming "cpu".
            self.device_tracking.append(str(key_states.device))

        # The token counter advances once per step, on layer 0.
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        if len(self.key_cache) <= layer_idx:
            # New layer - store on original device initially
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            device = self.original_devices[layer_idx]

            # Make sure the cached tensors are resident on the target device
            # before concatenating: wait for any in-flight prefetch, then
            # move the layer if it is still parked elsewhere.
            self._sync_prefetch()
            self._move_to_device(layer_idx, device)

            key_states = key_states.to(device)
            value_states = value_states.to(device)

            self.key_cache[layer_idx] = torch.cat(
                [self.key_cache[layer_idx], key_states], dim=-2
            )
            self.value_cache[layer_idx] = torch.cat(
                [self.value_cache[layer_idx], value_states], dim=-2
            )

        # Handle device management
        self._manage_devices(layer_idx)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def _sync_prefetch(self):
        """Order the current stream after everything queued on the prefetch stream."""
        if self._prefetch_stream is not None:
            torch.cuda.current_stream().wait_stream(self._prefetch_stream)

    def _manage_devices(self, current_layer_idx: int):
        """Keep the current layer resident, prefetch the next, offload the previous.

        Does nothing when CUDA is unavailable.
        """
        if not torch.cuda.is_available():
            return

        self.current_layer_idx = current_layer_idx

        # Move current layer to its home device if needed (no-op when already there)
        self._move_to_device(
            current_layer_idx,
            self.original_devices[current_layer_idx]
        )

        layer_count = len(self.key_cache)

        # Start prefetching next layer (wraps around to layer 0 after the last one)
        next_layer_idx = (current_layer_idx + 1) % layer_count
        self._prefetch_layer(next_layer_idx)

        # Move previous layer to CPU, unless it is the layer just prefetched
        # (happens while only one or two layers exist).
        prev_layer_idx = (current_layer_idx - 1) % layer_count
        if prev_layer_idx != next_layer_idx:
            self._offload_layer(prev_layer_idx)

    def _move_to_device(self, layer_idx: int, device: torch.device):
        """Move a layer's key/value tensors to ``device`` and record the new placement."""
        if self.device_tracking[layer_idx] == str(device):
            return

        self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device)
        self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device)
        self.device_tracking[layer_idx] = str(device)

    def _prefetch_layer(self, layer_idx: int):
        """Bring a layer back to its home device on the prefetch stream."""
        if self._prefetch_stream is None:
            return

        home_device = self.original_devices[layer_idx]
        if self.device_tracking[layer_idx] == str(home_device):
            return  # already resident, nothing to copy

        # The copy runs on a side stream, so it must first be ordered after the
        # work already queued on the compute stream; otherwise memory released
        # by that work could be reused by this copy while still being read.
        self._prefetch_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self._prefetch_stream):
            self._move_to_device(layer_idx, home_device)

    def _offload_layer(self, layer_idx: int):
        """Move a layer's key/value tensors to CPU."""
        if self.device_tracking[layer_idx] == "cpu":
            return

        self.key_cache[layer_idx] = self.key_cache[layer_idx].to("cpu")
        self.value_cache[layer_idx] = self.value_cache[layer_idx].to("cpu")
        self.device_tracking[layer_idx] = "cpu"

    def get_current_device(self, layer_idx: int) -> str:
        """Return the tracked device string for a layer, or "undefined" if unseen."""
        if layer_idx >= len(self.device_tracking):
            return "undefined"
        return self.device_tracking[layer_idx]

    def get_device_metrics(self) -> Dict[str, int]:
        """Count layers by tracked placement; anything other than "cpu" counts as GPU."""
        return {
            'layers_on_gpu': sum(1 for d in self.device_tracking if d != "cpu"),
            'layers_on_cpu': sum(1 for d in self.device_tracking if d == "cpu"),
            'total_layers': len(self.device_tracking)
        }

In [5]:
class QuantizedOffloadedCache(SimpleQuantizedCache):
    """Cache implementation combining quantization with CPU offloading.

    Quantized key/value tensors are stored per layer and shuttled between CPU
    and the layer's original ("home") device; a full-precision residual of
    recent tokens is kept alongside them until it is folded back in by
    re-quantization.

    Attributes:
        current_layer_idx: Layer index most recently handled by ``_manage_devices``.
        device_tracking: ``str(device)`` on which each layer's quantized data lives.
        original_devices: Home device of each layer, in layer order.
    """

    # Attributes that ``_quantize`` attaches to its result and ``_dequantize``
    # reads back; they must accompany the tensor on every device transfer.
    _QUANT_META_ATTRS = ("scale", "zero_point")

    def __init__(self, cache_config: QuantizedCacheConfig):
        super().__init__(cache_config)
        self.current_layer_idx = 0
        self.device_tracking = []  # Track current device for each layer
        self.original_devices = []  # Store original device for each layer
        # Side stream for prefetch copies; None when CUDA is unavailable.
        self._prefetch_stream = (
            torch.cuda.Stream() if torch.cuda.is_available() else None
        )

    @classmethod
    def _transfer_quantized(cls, qtensor: torch.Tensor, device) -> torch.Tensor:
        """Move a quantized tensor to ``device`` together with its metadata.

        ``Tensor.to`` can return a new tensor object, which would not carry the
        Python-level ``scale``/``zero_point`` attributes; they are re-attached
        here (and moved to the same device) so ``_dequantize`` keeps working.
        """
        moved = qtensor.to(device)
        for attr in cls._QUANT_META_ATTRS:
            meta = getattr(qtensor, attr, None)
            if meta is not None:
                setattr(moved, attr, meta.to(device))
        return moved

    def _sync_prefetch(self):
        """Order the current stream after everything queued on the prefetch stream."""
        if self._prefetch_stream is not None:
            torch.cuda.current_stream().wait_stream(self._prefetch_stream)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update cache with new states, managing both quantization and device placement.

        Args:
            key_states: New key states; concatenated along dim ``-2``.
            value_states: New value states; concatenated along dim ``-2``.
            layer_idx: Index of the layer being updated.
            cache_kwargs: Accepted for interface compatibility; not used here.

        Returns:
            ``(keys, values)`` for ``layer_idx``: the incoming states on first
            sight of a layer, otherwise dequantized history + residual + new states.
        """
        # Track the original device when first seeing a layer
        if layer_idx >= len(self.original_devices):
            self.original_devices.append(key_states.device)
            # Quantized data for a new layer is placed on CPU below.
            self.device_tracking.append("cpu")

        # The token counter advances once per step, on layer 0.
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update quantized cache
        if len(self._quantized_key_cache) <= layer_idx:
            # New layer - quantize and store
            q_key = self._quantize(key_states.contiguous(), axis=self.axis_key)
            q_value = self._quantize(value_states.contiguous(), axis=self.axis_value)

            # Store quantized data on CPU initially (metadata travels with it)
            self._quantized_key_cache.append(self._transfer_quantized(q_key, "cpu"))
            self._quantized_value_cache.append(self._transfer_quantized(q_value, "cpu"))

            # Initialize residual cache (empty 1-D placeholder)
            self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
            self.value_cache.append(torch.zeros(0, dtype=value_states.dtype, device=value_states.device))

            keys_to_return, values_to_return = key_states, value_states
        else:
            # Handle existing layer
            device = self.original_devices[layer_idx]

            # Wait for any in-flight prefetch, then make sure the quantized
            # data is resident on the layer's home device.
            self._sync_prefetch()
            self._move_to_device(layer_idx, device)

            # Align the incoming states with the layer's home device, as
            # CPUOffloadedCache.update does, so the concatenations below never
            # mix devices.
            key_states = key_states.to(device)
            value_states = value_states.to(device)

            # Dequantize and combine with residual
            dequant_key = self._dequantize(self._quantized_key_cache[layer_idx])
            dequant_value = self._dequantize(self._quantized_value_cache[layer_idx])

            keys_to_return = torch.cat(
                [dequant_key, self.key_cache[layer_idx], key_states], dim=-2
            )
            values_to_return = torch.cat(
                [dequant_value, self.value_cache[layer_idx], value_states], dim=-2
            )

            # Check if we need to requantize and update residual
            if (
                self.key_cache[layer_idx].dim() == 4
                and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length
            ):
                # Requantize full sequence
                self._quantized_key_cache[layer_idx] = self._quantize(
                    keys_to_return.contiguous(), axis=self.axis_key
                )
                self._quantized_value_cache[layer_idx] = self._quantize(
                    values_to_return.contiguous(), axis=self.axis_value
                )

                # Move to CPU if not current layer.
                # NOTE(review): current_layer_idx is only refreshed inside
                # _manage_devices (called below), so at this point it still
                # names the previously handled layer and this branch runs for
                # most layers before _manage_devices brings the data back.
                if layer_idx != self.current_layer_idx:
                    self._offload_layer(layer_idx)

                # Reset residual
                self.key_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
                self.value_cache[layer_idx] = torch.zeros(0, dtype=value_states.dtype, device=value_states.device)
            else:
                # Update residual
                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        # Handle device management
        self._manage_devices(layer_idx)

        return keys_to_return, values_to_return

    def _manage_devices(self, current_layer_idx: int):
        """Keep the current layer resident, prefetch the next, offload the previous.

        Does nothing when CUDA is unavailable.
        """
        if not torch.cuda.is_available():
            return

        self.current_layer_idx = current_layer_idx

        # Move current layer to its home device if needed (no-op when already there)
        self._move_to_device(
            current_layer_idx,
            self.original_devices[current_layer_idx]
        )

        layer_count = len(self._quantized_key_cache)

        # Start prefetching next layer (wraps around to layer 0 after the last one)
        next_layer_idx = (current_layer_idx + 1) % layer_count
        self._prefetch_layer(next_layer_idx)

        # Move previous layer to CPU, unless it is the layer just prefetched
        prev_layer_idx = (current_layer_idx - 1) % layer_count
        if prev_layer_idx != next_layer_idx:
            self._offload_layer(prev_layer_idx)

    def _move_to_device(self, layer_idx: int, device: torch.device):
        """Move a layer's quantized tensors to ``device`` and record the new placement."""
        if self.device_tracking[layer_idx] == str(device):
            return

        self._quantized_key_cache[layer_idx] = self._transfer_quantized(
            self._quantized_key_cache[layer_idx], device
        )
        self._quantized_value_cache[layer_idx] = self._transfer_quantized(
            self._quantized_value_cache[layer_idx], device
        )
        self.device_tracking[layer_idx] = str(device)

    def _prefetch_layer(self, layer_idx: int):
        """Bring a layer's quantized data back to its home device on the prefetch stream."""
        if self._prefetch_stream is None:
            return

        home_device = self.original_devices[layer_idx]
        if self.device_tracking[layer_idx] == str(home_device):
            return  # already resident, nothing to copy

        # The copy runs on a side stream, so it must first be ordered after the
        # work already queued on the compute stream; otherwise memory released
        # by that work could be reused by this copy while still being read.
        self._prefetch_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self._prefetch_stream):
            self._move_to_device(layer_idx, home_device)

    def _offload_layer(self, layer_idx: int):
        """Move a layer's quantized tensors to CPU."""
        if self.device_tracking[layer_idx] == "cpu":
            return

        self._quantized_key_cache[layer_idx] = self._transfer_quantized(
            self._quantized_key_cache[layer_idx], "cpu"
        )
        self._quantized_value_cache[layer_idx] = self._transfer_quantized(
            self._quantized_value_cache[layer_idx], "cpu"
        )
        self.device_tracking[layer_idx] = "cpu"

    def get_device_metrics(self) -> Dict[str, Any]:
        """Get metrics about cache device placement and quantization.

        ``compression_ratio`` is the nominal ``32 / nbits`` figure; it is
        computed from the configured bit width only, not measured from the
        stored tensors.
        """
        base_metrics = {
            'layers_on_gpu': sum(1 for d in self.device_tracking if d != "cpu"),
            'layers_on_cpu': sum(1 for d in self.device_tracking if d == "cpu"),
            'total_layers': len(self.device_tracking)
        }

        # Add quantization metrics
        base_metrics.update({
            'quantization_bits': self.nbits,
            'residual_length': self.residual_length,
            'compression_ratio': 32.0 / self.nbits  # Compared to FP32
        })

        return base_metrics

In [6]:
@dataclass
class CacheConfig:
    """Settings that select and parameterise a KV-cache strategy."""
    strategy: str  # "dynamic", "quantized", "cpu_offload", or "quantized_offload"
    decode_on_cpu: bool = False
    quantization: Optional[Dict] = None  # kwargs for QuantizedCacheConfig; filled with defaults on demand
    prefetch_size: int = 2  # Number of layers to prefetch for CPU offload (not read by get_cache)

    @staticmethod
    def get_cache(config: "CacheConfig", model_config) -> Optional[Cache]:
        """Build the cache object that matches ``config.strategy``.

        Args:
            config: Strategy selection. For the two quantized strategies a
                missing ``quantization`` dict is filled in on ``config`` itself
                with the 4-bit defaults before it is used.
            model_config: Accepted but not consulted by this method.

        Returns:
            A new cache instance, or ``None`` when the strategy name is not
            one of the four recognised values.
        """
        strategy = config.strategy

        if strategy == "dynamic":
            return DynamicCache()

        if strategy == "cpu_offload":
            return CPUOffloadedCache()

        if strategy in ("quantized", "quantized_offload"):
            # Both quantized strategies share the same default settings.
            if config.quantization is None:
                config.quantization = {
                    'nbits': 4,
                    'residual_length': 128,
                    'compute_dtype': torch.float16
                }
            quant_config = QuantizedCacheConfig(**config.quantization)
            if strategy == "quantized":
                return SimpleQuantizedCache(cache_config=quant_config)
            return QuantizedOffloadedCache(cache_config=quant_config)

        return None

# Metrics

In [7]:
@dataclass
class InferencePhaseMetrics:
    """Detailed metrics for each inference phase (prefill/decode).

    A record is appended by ``EnhancedMetricsTracker.start_phase`` and
    completed in place by ``EnhancedMetricsTracker.end_phase``.
    """
    phase: str  # 'prefill' or 'decode'
    start_time: float  # time.perf_counter() reading taken when the phase started
    end_time: float  # 0 until end_phase() stores the closing perf_counter() reading
    tokens_processed: int  # 0 until end_phase() stores the caller-supplied count
    gpu_memory: Dict[str, float]  # GPU memory snapshot at phase start (values in MiB; empty without CUDA)
    cpu_memory: Dict[str, float]  # Process memory snapshot at phase start (rss/vms/shared in MiB, plus percent)
    gpu_util: float  # GPU utilisation sampled at phase start, as a 0-1 fraction
    # NOTE: despite the name, the tracker stores the elapsed time between its
    # CUDA start/end events here, in seconds (0 when CUDA is unavailable).
    memory_bandwidth: float
    cache_metrics: Dict[str, float]  # Cache-specific metrics, merged in by end_phase()

@dataclass
class TokenGenerationMetrics:
    """Per-token sample recorded by ``EnhancedMetricsTracker.sample_token``."""
    timestamp: float  # Seconds elapsed since the tracker was created
    token_index: int  # Caller-supplied index of the token within the run
    phase: str  # Caller-supplied phase label; get_summary() groups by 'prefill' and 'decode'
    latency: float  # Caller-supplied latency; unit is not fixed here (presumably seconds - TODO confirm)
    gpu_memory: Dict[str, float]  # GPU memory snapshot at sampling time (values in MiB; empty without CUDA)
    cpu_memory: Dict[str, float]  # Process memory snapshot at sampling time (rss/vms/shared in MiB, plus percent)
    gpu_util: float  # GPU utilisation at sampling time, as a 0-1 fraction
    cache_size: int  # Caller-supplied cache size; unit is defined by the caller
    cache_metrics: Optional[Dict[str, float]] = None  # Optional cache-specific metrics

class EnhancedMetricsTracker:
    """Collects per-phase and per-token metrics for an inference run.

    Works with or without CUDA: GPU-specific figures fall back to empty/zero
    values when no CUDA device is available.

    Attributes:
        sampling_rate: Stored as given; not read by any method of this class.
        start_time: ``time.perf_counter()`` reading taken at construction.
        current_phase: Name passed to the latest ``start_phase`` call.
        phase_metrics: One ``InferencePhaseMetrics`` per started phase.
        token_metrics: One ``TokenGenerationMetrics`` per ``sample_token`` call.
        start_event / end_event: CUDA timing events, or ``None`` without CUDA.
    """
    def __init__(self, sampling_rate: float = 0.1):
        self.sampling_rate = sampling_rate
        self.start_time = time.perf_counter()
        self.current_phase = None
        self.phase_metrics: List[InferencePhaseMetrics] = []
        self.token_metrics: List[TokenGenerationMetrics] = []

        # CUDA events time each phase on the device. They are always defined
        # (None without CUDA) so start_phase/end_phase also work on CPU-only hosts.
        self.start_event = None
        self.end_event = None
        if torch.cuda.is_available():
            self.start_event = torch.cuda.Event(enable_timing=True)
            self.end_event = torch.cuda.Event(enable_timing=True)

    def start_phase(self, phase: str):
        """Start tracking a new inference phase and snapshot current resource usage."""
        self.current_phase = phase
        if self.start_event is not None:
            self.start_event.record()

        self.phase_metrics.append(InferencePhaseMetrics(
            phase=phase,
            start_time=time.perf_counter(),
            end_time=0,
            tokens_processed=0,
            gpu_memory=self._get_gpu_memory_stats(),
            cpu_memory=self._get_cpu_memory_stats(),
            gpu_util=self._get_gpu_utilization(),
            memory_bandwidth=0,
            cache_metrics={}
        ))

    def end_phase(self, tokens_processed: int, cache_metrics: Optional[Dict[str, float]] = None):
        """Close the most recently started phase and fill in its final figures.

        Args:
            tokens_processed: Number of tokens handled during the phase.
            cache_metrics: Optional cache figures merged into the phase record.
        """
        if self.end_event is not None:
            self.end_event.record()
            # Block until the event has completed so the elapsed time can be read.
            self.end_event.synchronize()

        current_metrics = self.phase_metrics[-1]
        current_metrics.end_time = time.perf_counter()
        current_metrics.tokens_processed = tokens_processed
        current_metrics.memory_bandwidth = self._calculate_memory_bandwidth()
        if cache_metrics:
            current_metrics.cache_metrics.update(cache_metrics)

    def sample_token(self, token_index: int, phase: str, latency: float, cache_size: int,
                    cache_metrics: Optional[Dict[str, float]] = None):
        """Record metrics for a single token generation."""
        self.token_metrics.append(TokenGenerationMetrics(
            timestamp=time.perf_counter() - self.start_time,
            token_index=token_index,
            phase=phase,
            latency=latency,
            gpu_memory=self._get_gpu_memory_stats(),
            cpu_memory=self._get_cpu_memory_stats(),
            gpu_util=self._get_gpu_utilization(),
            cache_size=cache_size,
            cache_metrics=cache_metrics
        ))

    def _get_gpu_memory_stats(self) -> Dict[str, float]:
        """Return CUDA allocator statistics in MiB, or an empty dict without CUDA."""
        if not torch.cuda.is_available():
            return {}

        return {
            'allocated': torch.cuda.memory_allocated() / 1024**2,
            'reserved': torch.cuda.memory_reserved() / 1024**2,
            'max_allocated': torch.cuda.max_memory_allocated() / 1024**2,
            'max_reserved': torch.cuda.max_memory_reserved() / 1024**2,
            'fragmentation': self._calculate_memory_fragmentation()
        }

    def _get_cpu_memory_stats(self) -> Dict[str, float]:
        """Return memory figures for the current process (MiB, plus percent of RAM)."""
        process = psutil.Process()
        mem = process.memory_info()  # query once instead of once per field
        return {
            'rss': mem.rss / 1024**2,
            'vms': mem.vms / 1024**2,
            # 'shared' is not reported on every platform; report 0 where missing.
            'shared': getattr(mem, 'shared', 0) / 1024**2,
            'percent': process.memory_percent()
        }

    def _calculate_memory_fragmentation(self) -> float:
        """Return reserved-minus-allocated CUDA memory in MiB (0.0 without CUDA)."""
        if not torch.cuda.is_available():
            return 0.0

        allocated = torch.cuda.memory_allocated()
        reserved = torch.cuda.memory_reserved()

        if reserved == 0:
            return 0.0

        # Calculate fragmentation in MB
        return (reserved - allocated) / 1024**2

    def _get_gpu_utilization(self) -> float:
        """Return GPU utilisation as a 0-1 fraction; 0.0 when it cannot be determined."""
        if not torch.cuda.is_available():
            return 0.0

        import subprocess
        try:
            # Try using nvidia-smi through subprocess if available
            result = subprocess.check_output(
                ['nvidia-smi', '--query-gpu=utilization.gpu', '--format=csv,noheader,nounits'],
                encoding='utf-8'
            )
            return float(result.strip()) / 100.0
        except (OSError, subprocess.SubprocessError, ValueError):
            # nvidia-smi is missing, failed, or printed something that is not a
            # single number (e.g. one line per GPU): fall back to torch.
            pass

        try:
            return torch.cuda.utilization() / 100.0
        except Exception as exc:
            # Metrics are best-effort: do not abort a run because utilisation
            # cannot be queried.
            logger.debug(f"GPU utilization unavailable: {exc}")
            return 0.0

    def _calculate_memory_bandwidth(self) -> float:
        """Return the time between the start and end CUDA events, in seconds.

        NOTE: despite the name, no byte count is involved; the value is the
        elapsed event time (0.0 when CUDA events are not in use).
        """
        if self.start_event is None or self.end_event is None:
            return 0.0
        return self.start_event.elapsed_time(self.end_event) * 1e-3  # Convert to seconds

    def get_summary(self) -> Dict[str, Dict[str, float]]:
        """Aggregate the collected token samples.

        Returns:
            A dict with ``'phases'`` (per-phase aggregates for 'prefill' and
            'decode' when samples exist), ``'memory'`` and ``'performance'``.
            With no samples recorded, all numeric figures are ``0.0`` and
            ``'phases'`` is empty.
        """
        samples = self.token_metrics

        # Debug print
        logger.debug(f"Number of token metrics collected: {len(self.token_metrics)}")
        logger.debug(f"Phases recorded: {set(m.phase for m in self.token_metrics)}")

        if samples:
            latencies = [m.latency for m in samples]
            elapsed = samples[-1].timestamp
            memory = {
                'peak_gpu_allocated': max(m.gpu_memory.get('allocated', 0) for m in samples),
                'peak_gpu_reserved': max(m.gpu_memory.get('reserved', 0) for m in samples),
                'peak_cpu': max(m.cpu_memory.get('rss', 0) for m in samples),
                'mean_fragmentation': np.mean([m.gpu_memory.get('fragmentation', 0) for m in samples]),
            }
            performance = {
                'mean_latency': np.mean(latencies),
                'p90_latency': np.percentile(latencies, 90),
                'mean_gpu_util': np.mean([m.gpu_util for m in samples]),
                # Guard against a zero timestamp on the last sample.
                'tokens_per_second': len(samples) / elapsed if elapsed > 0 else 0.0
            }
        else:
            # max()/percentile() are undefined on empty input; report zeros.
            memory = {
                'peak_gpu_allocated': 0.0,
                'peak_gpu_reserved': 0.0,
                'peak_cpu': 0.0,
                'mean_fragmentation': 0.0,
            }
            performance = {
                'mean_latency': 0.0,
                'p90_latency': 0.0,
                'mean_gpu_util': 0.0,
                'tokens_per_second': 0.0
            }

        summary = {
            'phases': {},
            'memory': memory,
            'performance': performance
        }

        # Add phase-specific metrics
        for phase in ['prefill', 'decode']:
            phase_tokens = [m for m in samples if m.phase == phase]
            if phase_tokens:
                summary['phases'][phase] = {
                    'token_count': len(phase_tokens),
                    'mean_latency': np.mean([m.latency for m in phase_tokens]),
                    'peak_memory': max(m.gpu_memory.get('allocated', 0) for m in phase_tokens),
                    'mean_cache_size': np.mean([m.cache_size for m in phase_tokens])
                }

        return summary

# Inference

In [8]:
class Llama2Inference:
    """Loads a causal LM plus tokenizer and runs decode steps with a configurable KV cache.

    Attributes:
        model_name: Hugging Face model identifier to load.
        device: "cuda" when CUDA is available, otherwise "cpu".
        cache_config: Cache strategy settings (defaults to the "dynamic" strategy).
        metrics_tracker: Initialised to ``None``; not assigned elsewhere in this class.
    """

    def __init__(
        self,
        model_name: str = "meta-llama/Llama-2-7b-chat-hf",
        cache_config: Optional[CacheConfig] = None,
    ):
        self.model_name = model_name
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.cache_config = cache_config or CacheConfig(strategy="dynamic")
        self.metrics_tracker = None

    @staticmethod
    def _num_gpu_layers(num_layers: int) -> int:
        """Number of leading decoder layers kept on the accelerator for CPU decode.

        A quarter of the layers, capped at four. Shared by ``setup`` and
        ``_get_layer_device`` so both always agree on the split.
        """
        return min(4, num_layers // 4)

    def setup(self):
        """Load the model and tokenizer, optionally move layers to CPU, and create the cache."""
        logger.info(f"Loading model '{self.model_name}'")

        # Load model
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.float16,
            device_map="auto"
        )

        if self.cache_config.decode_on_cpu:
            # For Llama2, we need to access model.model.layers
            if hasattr(self.model, 'model') and hasattr(self.model.model, 'layers'):
                logger.info("Moving decode layers to CPU...")

                # Get number of layers
                num_layers = len(self.model.model.layers)

                # Keep the first few layers on GPU (a quarter of them, at most four)
                gpu_layers = self._num_gpu_layers(num_layers)

                # Move specific layers to CPU
                for i in range(gpu_layers, num_layers):
                    self.model.model.layers[i] = self.model.model.layers[i].to('cpu')

                logger.info(f"Moved {num_layers - gpu_layers} layers to CPU, kept {gpu_layers} on GPU")
            else:
                logger.warning("Could not find Llama2 layers for CPU offloading")

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        if self.tokenizer.pad_token is None:
            # No pad token defined: reuse the end-of-sequence token for padding.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Initialize appropriate cache
        self.cache = CacheConfig.get_cache(self.cache_config, self.model.config)

    def _prepare_inputs_for_cpu_decode(self, input_ids, attention_mask):
        """Place inputs for the decode phase; a no-op unless ``decode_on_cpu`` is set.

        Returns:
            ``(input_ids, attention_mask)``; ``input_ids`` is moved to CPU when
            CPU decode is enabled, ``attention_mask`` is returned unchanged.
        """
        # Ensure proper device placement for CPU decode
        if self.cache_config.decode_on_cpu:
            # Move inputs to CPU but keep attention mask on GPU for efficiency
            # NOTE(review): setup() only relocates decoder layers, not the rest
            # of the model - confirm the forward pass accepts CPU input_ids.
            input_ids = input_ids.cpu()
            # attention_mask stays on GPU as it's needed for attention computation

            # If using offloaded cache, let it prepare itself when it offers
            # such a hook. The cache classes defined in this notebook do not
            # define prepare_for_cpu_decode, so calling it unconditionally
            # raised AttributeError on every decode step.
            if isinstance(self.cache, (CPUOffloadedCache, QuantizedOffloadedCache)):
                prepare = getattr(self.cache, "prepare_for_cpu_decode", None)
                if callable(prepare):
                    prepare()

        return input_ids, attention_mask

    def _get_layer_device(self, layer_idx: int) -> str:
        """Get the device for a specific layer during decode"""
        if not self.cache_config.decode_on_cpu:
            return self.device

        # For Llama2 CPU offloading strategy:
        # - Keep first few layers on GPU
        # - Rest on CPU
        num_layers = len(self.model.model.layers)
        gpu_layers = self._num_gpu_layers(num_layers)

        return self.device if layer_idx < gpu_layers else "cpu"

    def _run_decode_step(self, input_ids, attention_mask, past_key_values):
        """Run one forward pass with caching enabled and return the model outputs.

        With ``decode_on_cpu`` the logits are moved to ``self.device`` before
        the outputs are returned.
        """
        # Prepare inputs
        input_ids, attention_mask = self._prepare_inputs_for_cpu_decode(
            input_ids, attention_mask
        )

        # Run forward pass
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True
        )

        # Move logits back to GPU for sampling
        if self.cache_config.decode_on_cpu:
            outputs.logits = outputs.logits.to(self.device)

        return outputs

In [11]:
import gc
import logging
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import psutil
import torch

# Re-binds the module-level logger for this cell; it is obtained with the same
# name as the logger created in the earlier setup cell.
logger = logging.getLogger(__name__)

@dataclass
class BenchmarkConfig:
    """Configuration for benchmark runs"""
    # Context lengths to benchmark (presumably in tokens - TODO confirm).
    # Built per instance via default_factory so the default really is a list,
    # as annotated, and is never shared between configurations.
    context_lengths: List[int] = field(default_factory=lambda: [512, 2048, 4096])
    output_length: int = 50  # Length of the generated output (presumably in tokens - TODO confirm)
    num_runs: int = 3  # Number of runs per configuration
    warmup_runs: int = 1  # Number of warmup runs
    model_name: str = "meta-llama/Llama-2-7b-chat-hf"  # Hugging Face model identifier
    validate_outputs: bool = True
    stress_test: bool = False

class CacheBenchmark:
    """Benchmark suite comparing KV-cache strategies for ``Llama2Inference``.

    For each configured context length and each cache strategy the suite
    builds a model wrapper, performs warmup runs, optionally validates cache
    behaviour, times the benchmark runs, optionally runs a stress test, and
    stores the aggregated metrics in ``self.results``.
    """

    def __init__(self, config: BenchmarkConfig):
        """Store the configuration and create empty result containers."""
        self.config = config
        self.strategies = self._get_default_strategies()
        # results[context_length][strategy_name] -> aggregated metrics;
        # stress metrics are stored under "<strategy_name>_stress".
        self.results = {}
        # validation_results[context_length][strategy_name] -> {check: bool}
        self.validation_results = {}

    def _get_default_strategies(self) -> List[Dict]:
        """Return the cache strategies to test as ``{"name", "config"}`` dicts."""
        return [
            {
                "name": "dynamic",
                "config": CacheConfig(
                    strategy="dynamic",
                    decode_on_cpu=False
                )
            },
            {
                "name": "quantized",
                "config": CacheConfig(
                    strategy="quantized",
                    decode_on_cpu=False,
                    quantization={
                        'nbits': 4,
                        'residual_length': 128,
                        'compute_dtype': torch.float16
                    }
                )
            },
            {
                "name": "cpu_offload",
                "config": CacheConfig(
                    strategy="cpu_offload",
                    decode_on_cpu=True,
                    prefetch_size=2
                )
            },
            {
                "name": "quantized_offload",
                "config": CacheConfig(
                    strategy="quantized_offload",
                    decode_on_cpu=True,
                    quantization={
                        'nbits': 4,
                        'residual_length': 128,
                        'compute_dtype': torch.float16
                    },
                    prefetch_size=2
                )
            }
        ]

    @staticmethod
    def _forward_latencies(result: Dict) -> List[float]:
        """Extract the per-token ``forward_latency`` values of one run."""
        return [t['forward_latency'] for t in result['decode_timings']]

    @classmethod
    def _mean_forward_latency(cls, result: Dict) -> float:
        """Mean per-token latency of one run; NaN when no token was decoded."""
        latencies = cls._forward_latencies(result)
        return float(np.mean(latencies)) if latencies else float('nan')

    @classmethod
    def _aggregate_results(cls, strategy_results: List[Dict]) -> Dict:
        """Fold the per-run results of one strategy into summary metrics.

        ``strategy_results`` must be non-empty. A run with no decoded tokens
        (or zero total latency) contributes a throughput of 0.0 instead of
        raising ``ZeroDivisionError``.
        """
        per_run_throughput = []
        for run_result in strategy_results:
            latencies = cls._forward_latencies(run_result)
            total_latency = sum(latencies)
            per_run_throughput.append(
                len(latencies) / total_latency if total_latency > 0 else 0.0
            )

        return {
            'avg_token_time': np.mean([
                cls._mean_forward_latency(r) for r in strategy_results
            ]),
            'throughput': np.mean(per_run_throughput),
            'peak_gpu_memory': max([
                r['metrics']['memory']['peak_gpu_allocated']
                for r in strategy_results
            ]) if torch.cuda.is_available() else 0,
            'peak_cpu_memory': max([
                r['metrics']['memory']['peak_cpu']
                for r in strategy_results
            ]),
            # Cache metrics are taken from the last run only.
            'cache_metrics': strategy_results[-1]['metrics'].get(
                'cache_metrics', {}
            )
        }

    def _validate_cache_behavior(
        self,
        llm: Llama2Inference,
        strategy: Dict,
        context_length: int
    ) -> Dict[str, bool]:
        """Validate cache behavior for a given strategy.

        Runs two short generations on ``llm`` and returns four boolean
        checks: cache growth, device placement, GPU memory pattern and
        quantization. Checks that do not apply to the strategy (or to a
        machine without CUDA) are reported as passed. ``context_length`` is
        accepted for interface compatibility and is not used.
        """
        validation = {
            'cache_growth_correct': False,
            'device_placement_correct': False,
            'memory_pattern_valid': False,
            'quantization_active': False
        }
        strategy_kind = strategy['config'].strategy

        # Check cache growth
        initial_length = llm.cache.get_seq_length()
        _ = llm.run_inference("Test input", max_new_tokens=5)
        final_length = llm.cache.get_seq_length()
        validation['cache_growth_correct'] = final_length > initial_length

        # Check device placement for offloading strategies
        if strategy_kind in ['cpu_offload', 'quantized_offload']:
            if hasattr(llm.cache, 'get_device_metrics'):
                metrics = llm.cache.get_device_metrics()
                validation['device_placement_correct'] = (
                    metrics.get('layers_on_cpu', 0) > 0 and
                    metrics.get('layers_on_gpu', 0) > 0
                )
        else:
            validation['device_placement_correct'] = True

        # Check memory pattern
        if torch.cuda.is_available():
            # Reset the peak counter first: otherwise max_memory_allocated()
            # still reports the high-water mark of earlier (warmup) runs and
            # the check below says nothing about this generation.
            torch.cuda.reset_peak_memory_stats()
            initial_mem = torch.cuda.memory_allocated()
            _ = llm.run_inference("Another test", max_new_tokens=5)
            peak_mem = torch.cuda.max_memory_allocated()
            final_mem = torch.cuda.memory_allocated()

            # Memory should peak during generation but release after
            validation['memory_pattern_valid'] = (
                peak_mem > initial_mem and
                final_mem < peak_mem
            )
        else:
            validation['memory_pattern_valid'] = True

        # Check quantization
        if strategy_kind in ['quantized', 'quantized_offload']:
            if hasattr(llm.cache, 'get_device_metrics'):
                metrics = llm.cache.get_device_metrics()
                validation['quantization_active'] = (
                    'quantization_bits' in metrics and
                    'compression_ratio' in metrics
                )
        else:
            validation['quantization_active'] = True

        return validation

    def _run_stress_test(
        self,
        llm: Llama2Inference,
        context_length: int,
        num_iterations: int = 5
    ) -> Dict[str, float]:
        """Run stress test with continuous generation.

        Performs ``num_iterations`` short generations and returns the mean
        wall-clock throughput (tokens/s), the growth of allocated GPU memory
        in MB, and ``device_transitions``, which is the sum over iterations
        of the cache's reported ``layers_on_cpu`` value. ``context_length``
        is accepted for interface compatibility and is not used.
        """
        cuda_available = torch.cuda.is_available()
        initial_mem = torch.cuda.memory_allocated() if cuda_available else 0
        throughputs = []
        device_transitions = 0

        for i in range(num_iterations):
            start_time = time.perf_counter()
            result = llm.run_inference(
                f"Test input {i}",
                max_new_tokens=50
            )
            elapsed = time.perf_counter() - start_time

            # Calculate throughput (guard against a zero-length interval)
            tokens_generated = len(result['decode_timings'])
            throughputs.append(tokens_generated / elapsed if elapsed > 0 else 0.0)

            # Track device transitions for offloading strategies
            if hasattr(llm.cache, 'get_device_metrics'):
                metrics = llm.cache.get_device_metrics()
                device_transitions += metrics.get('layers_on_cpu', 0)

        final_mem = torch.cuda.memory_allocated() if cuda_available else 0

        return {
            # np.mean([]) would yield NaN with a warning; report 0.0 instead.
            'avg_throughput': np.mean(throughputs) if throughputs else 0.0,
            'memory_growth': (final_mem - initial_mem) / 1024**2,  # MB
            'device_transitions': device_transitions
        }

    def run_benchmark(self) -> Dict:
        """Run complete benchmark suite.

        Returns ``self.results``. A failure inside one strategy is logged
        (with traceback) and does not abort the remaining strategies; the
        model wrapper is cleaned up whether or not the strategy succeeded.
        """
        document_loader = EnhancedDocumentLoader()

        for context_length in self.config.context_lengths:
            logger.info(f"\nTesting context length: {context_length}")
            self.results[context_length] = {}
            self.validation_results[context_length] = {}

            # Load test input
            input_text = document_loader.load_chunk(chunk_size=context_length)

            for strategy in self.strategies:
                logger.info(f"\nTesting strategy: {strategy['name']}")
                strategy_results = []
                llm = None

                try:
                    # Warmup runs
                    logger.info("Performing warmup runs...")
                    llm = Llama2Inference(
                        model_name=self.config.model_name,
                        cache_config=strategy['config']
                    )

                    for _ in range(self.config.warmup_runs):
                        _ = llm.run_inference(
                            input_text=input_text,
                            max_new_tokens=self.config.output_length
                        )

                    # Validation if requested
                    if self.config.validate_outputs:
                        logger.info("Validating cache behavior...")
                        self.validation_results[context_length][strategy['name']] = (
                            self._validate_cache_behavior(llm, strategy, context_length)
                        )

                    # Benchmark runs
                    logger.info("Starting benchmark runs...")
                    for run in range(self.config.num_runs):
                        # Clear GPU memory
                        if torch.cuda.is_available():
                            torch.cuda.empty_cache()
                            torch.cuda.reset_peak_memory_stats()

                        result = llm.run_inference(
                            input_text=input_text,
                            max_new_tokens=self.config.output_length
                        )

                        strategy_results.append(result)

                        # Print progress metrics
                        avg_token_time = self._mean_forward_latency(result)
                        logger.info(
                            f"Run {run + 1}/{self.config.num_runs}: "
                            f"avg token time = {avg_token_time*1000:.2f}ms"
                        )

                    # Stress test if enabled
                    if self.config.stress_test:
                        logger.info("Running stress test...")
                        stress_metrics = self._run_stress_test(
                            llm,
                            context_length
                        )
                        self.results[context_length][f"{strategy['name']}_stress"] = (
                            stress_metrics
                        )

                    # Aggregate results (nothing to aggregate with zero runs)
                    if strategy_results:
                        self.results[context_length][strategy['name']] = (
                            self._aggregate_results(strategy_results)
                        )
                    else:
                        logger.warning(
                            f"No benchmark runs recorded for {strategy['name']}; "
                            "skipping aggregation"
                        )

                except Exception as e:
                    # Keep going with the next strategy, but keep the traceback.
                    logger.error(
                        f"Error testing {strategy['name']}: {str(e)}",
                        exc_info=True
                    )
                finally:
                    # Always release the model, including after a failure;
                    # otherwise each failed strategy leaves a model loaded.
                    if llm is not None:
                        try:
                            llm.cleanup()
                        except Exception as cleanup_error:
                            logger.error(
                                f"Error cleaning up {strategy['name']}: "
                                f"{str(cleanup_error)}",
                                exc_info=True
                            )
                    llm = None
                    gc.collect()

        return self.results

    def print_summary(self):
        """Print formatted summary of benchmark results.

        For each context length this prints the main metrics table, then the
        validation checks and stress-test metrics when they are available.
        """
        print("\nBenchmark Summary:")
        print("=" * 80)

        for context_length, strategies in self.results.items():
            print(f"\nContext Length: {context_length}")
            print("-" * 40)

            # Print main metrics
            headers = ["Strategy", "Throughput", "GPU Mem", "CPU Mem", "Avg Token"]
            row_format = "{:15} {:>10} {:>10} {:>10} {:>10}"
            print(row_format.format(*headers))
            print("-" * 60)

            for strategy, metrics in strategies.items():
                # Stress entries have a different schema and are printed below.
                if not strategy.endswith('_stress'):
                    print(row_format.format(
                        strategy,
                        f"{metrics['throughput']:.1f}",
                        f"{metrics['peak_gpu_memory']:.0f}MB",
                        f"{metrics['peak_cpu_memory']:.0f}MB",
                        f"{metrics['avg_token_time']*1000:.1f}ms"
                    ))

            # Print validation results if available
            if context_length in self.validation_results:
                print("\nValidation Results:")
                for strategy, validation in self.validation_results[context_length].items():
                    print(f"\n{strategy}:")
                    for check, passed in validation.items():
                        print(f"  {check}: {'✓' if passed else '✗'}")

            # Print stress test results if available
            stress_results = {
                k: v for k, v in strategies.items() if k.endswith('_stress')
            }
            if stress_results:
                print("\nStress Test Results:")
                for strategy, metrics in stress_results.items():
                    base_strategy = strategy.replace('_stress', '')
                    print(f"\n{base_strategy}:")
                    print(f"  Avg Throughput: {metrics['avg_throughput']:.1f} tokens/s")
                    print(f"  Memory Growth: {metrics['memory_growth']:.0f}MB")
                    print(f"  Device Transitions: {metrics['device_transitions']}")

# if __name__ == "__main__":
#     logging.basicConfig(level=logging.INFO)
    
#     # Configure benchmark
#     config = BenchmarkConfig(
#         context_lengths=[512, 2048, 4096],
#         output_length=50,
#         num_runs=3,
#         warmup_runs=1,
#         validate_outputs=True,
#         stress_test=True
#     )
    
#     # Run benchmark
#     benchmark = CacheBenchmark(config)
#     results = benchmark.run_benchmark()
    
#     # Print results
#     benchmark.print_summary()

In [None]:
# # Test with minimal configuration first
# config = BenchmarkConfig(
#     context_lengths=[512],  # Start small
#     output_length=20,
#     num_runs=1,
#     warmup_runs=1,
#     validate_outputs=True,
#     stress_test=False
# )

# # Initialize and run benchmark
# benchmark = CacheBenchmark(config)
# results = benchmark.run_benchmark()
# benchmark.print_summary()

In [12]:
import torch
import logging
from typing import Dict
import time

# NOTE: logging.basicConfig does nothing when the root logger already has
# handlers. An earlier cell of this notebook already calls it, so on a
# top-to-bottom run this line is effectively a no-op.
logging.basicConfig(level=logging.INFO)
# Same named logger object that the earlier cells use.
logger = logging.getLogger(__name__)

def test_llama2_inference():
    """Simple test of Llama2 inference with CPU offloading.

    Builds a ``Llama2Inference`` with the ``cpu_offload`` strategy, logs the
    device of every decoder layer, runs one prefill step plus three greedy
    decode steps through ``_run_decode_step``, and logs the elapsed time.
    The model reference and cached GPU memory are released even when a step
    raises, so a failed run does not keep the model resident.
    """

    # Initialize with CPU offloading
    logger.info("Initializing Llama2 with CPU offloading...")
    llm = Llama2Inference(
        model_name="meta-llama/Llama-2-7b-chat-hf",
        cache_config=CacheConfig(strategy="cpu_offload", decode_on_cpu=True)
    )

    try:
        llm.setup()

        # Test input
        test_prompt = "Write a short story about a cat in 2-3 sentences."

        # Check layer devices
        logger.info("\nChecking layer device placement:")
        for i, layer in enumerate(llm.model.model.layers):
            device = next(layer.parameters()).device
            logger.info(f"Layer {i}: {device}")

        # Run inference
        logger.info("\nRunning inference...")
        start_time = time.perf_counter()

        # Prepare inputs
        inputs = llm.tokenizer(test_prompt, return_tensors="pt")
        inputs = {k: v.to(llm.device) for k, v in inputs.items()}
        attention_mask = inputs["attention_mask"]

        # Inference only: do not record an autograd graph for these passes.
        with torch.no_grad():
            # Prefill
            outputs = llm._run_decode_step(
                input_ids=inputs["input_ids"],
                attention_mask=attention_mask,
                past_key_values=None
            )

            # Test generation of a few tokens
            for _ in range(3):  # Generate 3 tokens for testing
                next_token = torch.argmax(outputs.logits[:, -1:], dim=-1)
                device = next_token.device
                logger.info(f"Generated token device: {device}")

                # With past_key_values the mask must cover the cached
                # positions plus the new token, so extend it each step
                # instead of passing a mask for the new token alone.
                attention_mask = torch.cat(
                    [
                        attention_mask,
                        attention_mask.new_ones(
                            (attention_mask.shape[0], next_token.shape[1])
                        ),
                    ],
                    dim=-1,
                )

                outputs = llm._run_decode_step(
                    input_ids=next_token,
                    attention_mask=attention_mask,
                    past_key_values=outputs.past_key_values
                )

        end_time = time.perf_counter()
        logger.info(f"Inference time: {end_time - start_time:.2f}s")
    finally:
        # Cleanup (also on failure). The model attribute may be missing if
        # setup() failed before creating it.
        if hasattr(llm, "model"):
            del llm.model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()

if __name__ == "__main__":
    # Run the smoke test; on failure, log the message together with the
    # traceback rather than letting the exception propagate out of the cell.
    try:
        test_llama2_inference()
    except Exception as e:
        logger.exception(f"Test failed: {str(e)}")

INFO:__main__:Initializing Llama2 with CPU offloading...
INFO:__main__:Loading model 'meta-llama/Llama-2-7b-chat-hf'


config.json:   0%|          | 0.00/614 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

INFO:accelerate.utils.modeling:We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/188 [00:00<?, ?B/s]

INFO:__main__:Moving decode layers to CPU...
INFO:__main__:Moved 28 layers to CPU, kept 4 on GPU


tokenizer_config.json:   0%|          | 0.00/1.62k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

INFO:__main__:
Checking layer device placement:
INFO:__main__:Layer 0: cuda:0
INFO:__main__:Layer 1: cuda:0
INFO:__main__:Layer 2: cuda:0
INFO:__main__:Layer 3: cuda:0
INFO:__main__:Layer 4: cpu
INFO:__main__:Layer 5: cpu
INFO:__main__:Layer 6: cpu
INFO:__main__:Layer 7: cpu
INFO:__main__:Layer 8: cpu
INFO:__main__:Layer 9: cpu
INFO:__main__:Layer 10: cpu
INFO:__main__:Layer 11: cpu
INFO:__main__:Layer 12: cpu
INFO:__main__:Layer 13: cpu
INFO:__main__:Layer 14: cpu
INFO:__main__:Layer 15: cpu
INFO:__main__:Layer 16: cpu
INFO:__main__:Layer 17: cpu
INFO:__main__:Layer 18: cpu
INFO:__main__:Layer 19: cpu
INFO:__main__:Layer 20: cpu
INFO:__main__:Layer 21: cpu
INFO:__main__:Layer 22: cpu
INFO:__main__:Layer 23: cpu
INFO:__main__:Layer 24: cpu
INFO:__main__:Layer 25: cpu
INFO:__main__:Layer 26: cpu
INFO:__main__:Layer 27: cpu
INFO:__main__:Layer 28: cpu
INFO:__main__:Layer 29: cpu
INFO:__main__:Layer 30: cpu
INFO:__main__:Layer 31: cpu
INFO:__main__:
Running inference...
ERROR:__main__:Tes