In [1]:
import os
from dotenv import load_dotenv

# Load HF token from .env
load_dotenv()
hf_token = os.getenv('HUGGINGFACE_TOKEN')

# Login to HuggingFace
from huggingface_hub import login
login(token=hf_token)

In [31]:
import torch
import numpy as np
import time
import logging
import psutil
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, Optional, List, Union, Tuple, Any
from collections import defaultdict

from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer,
    AutoConfig,
    PretrainedConfig,
    Cache,
    DynamicCache, 
    OffloadedCache,
    QuantizedCache,
    QuantizedCacheConfig,
)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Document Loader

In [3]:
class EnhancedDocumentLoader:
    def __init__(self, file_path: str = "data/crimeandpunishment.txt"):
        self.file_path = Path(file_path)
        self.text_cache = None
        
    def load_chunk(self, start_pos: int = 1000, chunk_size: int = 1000) -> str:
        """Load a chunk from the document with caching"""
        # Cache the full text on first load
        if self.text_cache is None:
            with open(self.file_path, 'r', encoding='utf-8') as f:
                text = f.read()
            
            # Skip Project Gutenberg header
            start_marker = "CRIME AND PUNISHMENT"
            narrative_start = text.find(start_marker)
            if narrative_start == -1:
                raise ValueError(f"Start marker '{start_marker}' not found in the document.")
            self.text_cache = text[narrative_start:]
        
        # Get clean chunk
        chunk_start = min(start_pos, len(self.text_cache) - chunk_size)
        chunk = self.text_cache[chunk_start:chunk_start + chunk_size]
        
        # Adjust to sentence boundary
        first_period = chunk.find(". ") + 2
        if first_period > 1:
            chunk = chunk[first_period:]
            
        return chunk

    def get_total_length(self) -> int:
        """Get total length of usable text"""
        if self.text_cache is None:
            _ = self.load_chunk()  # This will load and cache the text
        return len(self.text_cache)

# Quant

In [4]:
class SimpleQuantizedCache(QuantizedCache):
    """Basic implementation of quantized cache using simple min-max quantization"""
    
    def _quantize(self, tensor: torch.Tensor, axis: int) -> torch.Tensor:
        # Simple min-max quantization
        with torch.no_grad():
            # Compute min and max along the specified axis
            if axis == 0:
                min_val = tensor.min(dim=0)[0]
                max_val = tensor.max(dim=0)[0]
            else:  # axis = -1
                min_val = tensor.min(dim=-1, keepdim=True)[0]
                max_val = tensor.max(dim=-1, keepdim=True)[0]
            
            # Scale to [0, 2^nbits - 1]
            scale = (max_val - min_val) / (2**self.nbits - 1)
            scale = torch.clamp(scale, min=1e-6)  # Prevent division by zero
            
            # Quantize
            qtensor = ((tensor - min_val) / scale).round().clamp(0, 2**self.nbits - 1)
            
            # Store scaling factors as attributes of the tensor
            qtensor.scale = scale
            qtensor.zero_point = min_val
            
            return qtensor
    
    def _dequantize(self, qtensor: torch.Tensor) -> torch.Tensor:
        # Dequantize using stored scale and zero point
        with torch.no_grad():
            return qtensor * qtensor.scale + qtensor.zero_point

# Cache

In [18]:
@dataclass
class CPUDecodeConfig:
    """Configuration for CPU decode settings"""
    enabled: bool = False
    # Which layers to place on CPU (None means auto-select)
    cpu_layers: Optional[List[int]] = None
    # Whether to use mixed precision on CPU
    use_fp32_cpu: bool = True

@dataclass
class CacheConfig:
    """Configuration for cache strategies with CPU decode support"""
    strategy: str
    decode_on_cpu: bool = False
    quantization: Optional[Dict] = None
    cpu_decode: Optional[CPUDecodeConfig] = None
    
    @staticmethod
    def get_cache(config: "CacheConfig", model_config) -> Optional[Cache]:
        """Initialize appropriate cache based on configuration"""
        if config.strategy == "dynamic":
            return DynamicCache()
        elif config.strategy == "quantized":
            if config.quantization is None:
                config.quantization = {
                    'nbits': 4,
                    'residual_length': 128,
                    'compute_dtype': torch.float16
                }
            quant_config = QuantizedCacheConfig(**config.quantization)
            return SimpleQuantizedCache(cache_config=quant_config)
        elif config.strategy == "offloaded":
            return OffloadedCache()
        return None

    def get_layer_mapping(self, num_layers: int) -> Dict[str, str]:
        """Get device mapping for model layers based on config"""
        if not self.cpu_decode or not self.cpu_decode.enabled:
            return None
            
        mapping = {
            # Keep embedding and other components on GPU
            "model.embed_tokens": "cuda:0",
            "model.norm": "cuda:0",
            "lm_head": "cuda:0",
            "model.layers": "cuda:0"  # Default for any unspecified layers
        }
        
        if self.cpu_decode.cpu_layers is None:
            # Default strategy: Put latter half of layers on CPU
            split_point = num_layers // 2
            cpu_layers = list(range(split_point, num_layers))
        else:
            cpu_layers = self.cpu_decode.cpu_layers
            
        for i in range(num_layers):
            if i in cpu_layers:
                mapping[f"model.layers.{i}"] = "cpu"
            else:
                mapping[f"model.layers.{i}"] = "cuda:0"
                
        return mapping

# Metrics

In [6]:
@dataclass
class InferencePhaseMetrics:
    """Detailed metrics for each inference phase (prefill/decode)"""
    phase: str  # 'prefill' or 'decode'
    start_time: float
    end_time: float
    tokens_processed: int
    gpu_memory: Dict[str, float]  # Detailed memory stats
    cpu_memory: Dict[str, float]
    gpu_util: float
    memory_bandwidth: float
    cache_metrics: Dict[str, float]  # Cache-specific metrics

@dataclass
class TokenGenerationMetrics:
    """Enhanced per-token metrics"""
    timestamp: float
    token_index: int
    phase: str
    latency: float
    gpu_memory: Dict[str, float]
    cpu_memory: Dict[str, float]
    gpu_util: float
    cache_size: int
    cache_metrics: Optional[Dict[str, float]] = None

class EnhancedMetricsTracker:
    """Enhanced metrics tracking with phase awareness and detailed memory stats"""
    def __init__(self, sampling_rate: float = 0.1):
        self.sampling_rate = sampling_rate
        self.start_time = time.perf_counter()
        self.current_phase = None
        self.phase_metrics: List[InferencePhaseMetrics] = []
        self.token_metrics: List[TokenGenerationMetrics] = []
        
        # Initialize CUDA events for memory bandwidth tracking
        if torch.cuda.is_available():
            self.start_event = torch.cuda.Event(enable_timing=True)
            self.end_event = torch.cuda.Event(enable_timing=True)
    
    def start_phase(self, phase: str):
        """Start tracking a new inference phase"""
        self.current_phase = phase
        self.start_event.record()
        
        self.phase_metrics.append(InferencePhaseMetrics(
            phase=phase,
            start_time=time.perf_counter(),
            end_time=0,
            tokens_processed=0,
            gpu_memory=self._get_gpu_memory_stats(),
            cpu_memory=self._get_cpu_memory_stats(),
            gpu_util=self._get_gpu_utilization(),
            memory_bandwidth=0,
            cache_metrics={}
        ))
    
    def end_phase(self, tokens_processed: int, cache_metrics: Optional[Dict[str, float]] = None):
        """End current phase and record final metrics"""
        self.end_event.record()
        self.end_event.synchronize()
        
        current_metrics = self.phase_metrics[-1]
        current_metrics.end_time = time.perf_counter()
        current_metrics.tokens_processed = tokens_processed
        current_metrics.memory_bandwidth = self._calculate_memory_bandwidth()
        if cache_metrics:
            current_metrics.cache_metrics.update(cache_metrics)
    
    def sample_token(self, token_index: int, phase: str, latency: float, cache_size: int,
                    cache_metrics: Optional[Dict[str, float]] = None):
        """Record metrics for a single token generation"""
        self.token_metrics.append(TokenGenerationMetrics(
            timestamp=time.perf_counter() - self.start_time,
            token_index=token_index,
            phase=phase,
            latency=latency,
            gpu_memory=self._get_gpu_memory_stats(),
            cpu_memory=self._get_cpu_memory_stats(),
            gpu_util=self._get_gpu_utilization(),
            cache_size=cache_size,
            cache_metrics=cache_metrics
        ))
    
    def _get_gpu_memory_stats(self) -> Dict[str, float]:
        """Get detailed GPU memory statistics"""
        if not torch.cuda.is_available():
            return {}
            
        return {
            'allocated': torch.cuda.memory_allocated() / 1024**2,
            'reserved': torch.cuda.memory_reserved() / 1024**2,
            'max_allocated': torch.cuda.max_memory_allocated() / 1024**2,
            'max_reserved': torch.cuda.max_memory_reserved() / 1024**2,
            'fragmentation': self._calculate_memory_fragmentation()
        }
    
    def _get_cpu_memory_stats(self) -> Dict[str, float]:
        """Get detailed CPU memory statistics"""
        process = psutil.Process()
        return {
            'rss': process.memory_info().rss / 1024**2,
            'vms': process.memory_info().vms / 1024**2,
            'shared': process.memory_info().shared / 1024**2,
            'percent': process.memory_percent()
        }
    
    def _calculate_memory_fragmentation(self) -> float:
        """Calculate memory fragmentation with improved accuracy"""
        if not torch.cuda.is_available():
            return 0.0
        
        allocated = torch.cuda.memory_allocated()
        reserved = torch.cuda.memory_reserved()
        
        if reserved == 0:
            return 0.0
            
        # Calculate fragmentation in MB
        fragmentation = (reserved - allocated) / 1024**2
        
        return fragmentation
    
    def _get_gpu_utilization(self) -> float:
        """Get GPU utilization percentage with improved accuracy"""
        if not torch.cuda.is_available():
            return 0.0
        try:
            # Try using nvidia-smi through subprocess if available
            import subprocess
            result = subprocess.check_output(
                ['nvidia-smi', '--query-gpu=utilization.gpu', '--format=csv,noheader,nounits'],
                encoding='utf-8'
            )
            return float(result.strip()) / 100.0
        except:
            # Fallback to torch CUDA metrics
            return torch.cuda.utilization() / 100.0
    
    def _calculate_memory_bandwidth(self) -> float:
        """Calculate memory bandwidth usage between events"""
        if not torch.cuda.is_available():
            return 0.0
        return self.start_event.elapsed_time(self.end_event) * 1e-3  # Convert to seconds
    
    def get_summary(self) -> Dict[str, Dict[str, float]]:
        """Get enhanced statistical summary of collected metrics"""
        # Debug print
        logger.debug(f"Number of token metrics collected: {len(self.token_metrics)}")
        logger.debug(f"Phases recorded: {set(m.phase for m in self.token_metrics)}")
        
        summary = {
            'phases': {},
            'memory': {
                'peak_gpu_allocated': max(m.gpu_memory.get('allocated', 0) for m in self.token_metrics),
                'peak_gpu_reserved': max(m.gpu_memory.get('reserved', 0) for m in self.token_metrics),
                'peak_cpu': max(m.cpu_memory.get('rss', 0) for m in self.token_metrics),
                'mean_fragmentation': np.mean([m.gpu_memory.get('fragmentation', 0) for m in self.token_metrics]),
            },
            'performance': {
                'mean_latency': np.mean([m.latency for m in self.token_metrics]),
                'p90_latency': np.percentile([m.latency for m in self.token_metrics], 90),
                'mean_gpu_util': np.mean([m.gpu_util for m in self.token_metrics]),
                'tokens_per_second': len(self.token_metrics) / (self.token_metrics[-1].timestamp if self.token_metrics else 1)
            }
        }
        
        # Add phase-specific metrics
        for phase in ['prefill', 'decode']:
            phase_tokens = [m for m in self.token_metrics if m.phase == phase]
            if phase_tokens:
                summary['phases'][phase] = {
                    'token_count': len(phase_tokens),
                    'mean_latency': np.mean([m.latency for m in phase_tokens]),
                    'peak_memory': max(m.gpu_memory.get('allocated', 0) for m in phase_tokens),
                    'mean_cache_size': np.mean([m.cache_size for m in phase_tokens])
                }
        
        return summary

# Inference

In [27]:
class LlamaDeviceManager:
    """Manages device placement for Llama model without modifying its internals"""
    
    def __init__(self, model, cpu_layers):
        self.model = model
        self.cpu_layers = set(cpu_layers)
        self.hooks = []
        self.setup_hooks()
        
    def setup_hooks(self):
        """Setup forward pre-hooks for device management"""
        for name, module in self.model.named_modules():
            if 'layers' in name:
                layer_idx = int(name.split('.')[-1]) if name.split('.')[-1].isdigit() else -1
                if layer_idx in self.cpu_layers:
                    hook = module.register_forward_pre_hook(
                        lambda mod, inp, layer_idx=layer_idx: self._layer_pre_hook(mod, inp, layer_idx)
                    )
                    self.hooks.append(hook)
    
    def _layer_pre_hook(self, module, inputs, layer_idx):
        """Handle device transitions before layer execution"""
        if not isinstance(inputs, tuple):
            inputs = (inputs,)
            
        # Move inputs to appropriate device
        device = torch.device('cpu') if layer_idx in self.cpu_layers else torch.device('cuda:0')
        processed_inputs = []
        
        for x in inputs:
            if isinstance(x, torch.Tensor):
                processed_inputs.append(x.to(device))
            elif isinstance(x, tuple):
                processed_inputs.append(tuple(t.to(device) if isinstance(t, torch.Tensor) else t for t in x))
            else:
                processed_inputs.append(x)
                
        return tuple(processed_inputs)

    def prepare_inputs(self, input_ids, attention_mask=None, position_ids=None):
        """Prepare inputs with proper device placement"""
        device = torch.device('cuda:0')  # Start on GPU
        prepared_inputs = {
            'input_ids': input_ids.to(device),
            'attention_mask': attention_mask.to(device) if attention_mask is not None else None,
            'position_ids': position_ids.to(device) if position_ids is not None else None
        }
        return prepared_inputs
    
    def cleanup(self):
        """Remove hooks when done"""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()


class EnhancedLlamaInference:
    """Enhanced inference wrapper with fixed device management"""
    
    def __init__(self, model_name: str, cache_config: Optional[CacheConfig] = None):
        self.cache_config = cache_config or CacheConfig(strategy="dynamic")
        
        # Create device map
        device_map = {"model.embed_tokens": "cpu"}  # Keep embeddings on CPU
        if self.cache_config.cpu_decode and self.cache_config.cpu_decode.enabled:
            for i in range(16):  # Assuming 16 layers
                if i in self.cache_config.cpu_decode.cpu_layers:
                    device_map[f"model.layers.{i}"] = "cpu"
                else:
                    device_map[f"model.layers.{i}"] = "cuda:0"
        else:
            device_map = "auto"
            
        # Load model with device map
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map=device_map,
            low_cpu_mem_usage=True
        )
        
        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            
        # Initialize cache
        self.cache = CacheConfig.get_cache(self.cache_config, self.model.config)
        
    def run_inference(
        self,
        input_text: str,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
    ) -> Dict:
        metrics_tracker = EnhancedMetricsTracker()
        
        try:
            # Tokenize on CPU first
            inputs = self.tokenizer(
                input_text,
                return_tensors="pt",
                truncation=True
            )
            
            # Keep inputs on CPU initially since embed_tokens is on CPU
            input_ids = inputs["input_ids"]
            attention_mask = inputs.get("attention_mask")
            
            metrics_tracker.start_phase('prefill')
            
            with torch.inference_mode():
                outputs = self.model.generate(
                    input_ids=input_ids,  # Keep on CPU
                    attention_mask=attention_mask,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    past_key_values=self.cache,
                    use_cache=True,
                    # Ensure proper device handling for cache
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )
                
            metrics_tracker.end_phase(
                tokens_processed=outputs.shape[1],
                cache_metrics=self._get_cache_metrics()
            )
            
            return {
                "text": self.tokenizer.decode(outputs[0], skip_special_tokens=True),
                "metrics": metrics_tracker.get_summary()
            }
            
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            raise
            
    def _get_cache_metrics(self) -> Dict[str, float]:
        """Get cache memory metrics"""
        if not hasattr(self.cache, 'get_seq_length'):
            return {}
            
        return {
            'cache_size': self.cache.get_seq_length(),
            'gpu_memory': torch.cuda.memory_allocated() / 1024**2 if torch.cuda.is_available() else 0,
        }

    def cleanup(self):
        """Cleanup resources"""
        del self.model
        del self.cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# Testing

In [11]:
def find_max_context_length(
    strategy: Dict,
    start_length: int = 1024,
    max_length: int = 32768,
    tolerance: int = 512,
    model_name: str = "meta-llama/Llama-3.2-1B"
) -> Tuple[int, Dict]:
    """
    Binary search to find maximum context length for a given strategy.
    Returns (max_length, metrics)
    """
    left, right = start_length, max_length
    max_successful_length = 0
    max_successful_metrics = None
    
    while left <= right:
        mid = (left + right) // 2
        # Round to nearest multiple of tolerance
        test_length = (mid // tolerance) * tolerance
        
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.reset_peak_memory_stats()
            
            # Initialize model with strategy
            llm = LlamaInference(
                model_name=model_name,
                cache_config=strategy['config']
            )
            
            # Load test data
            document_loader = EnhancedDocumentLoader()
            input_text = document_loader.load_chunk(chunk_size=test_length)
            
            # Run inference with small output to test context handling
            result = llm.run_inference(
                input_text=input_text,
                max_new_tokens=20  # Small output for testing
            )
            
            # Get memory metrics
            metrics = {
                'context_length': test_length,
                'peak_memory': torch.cuda.max_memory_allocated()/1024**2,
                'reserved_memory': torch.cuda.max_memory_reserved()/1024**2,
                'throughput': result['metrics']['performance']['tokens_per_second'],
                'strategy': strategy['name']
            }
            
            # Update maximum successful length
            max_successful_length = test_length
            max_successful_metrics = metrics
            
            # Try larger context
            left = mid + 1
            
        except torch.cuda.OutOfMemoryError:
            # Try smaller context
            right = mid - 1
        
        finally:
            if 'llm' in locals():
                llm.cleanup()
    
    return max_successful_length, max_successful_metrics

def run_enhanced_benchmark(
    strategies: List[Dict] = None,
    context_checkpoints: List[int] = [4096, 8192, 16384, 32768],
) -> Dict:
    """
    Enhanced benchmark that:
    1. Finds max context length for each strategy
    2. Tests performance at various context lengths up to max
    """
    if strategies is None:
        strategies = [
            {
                "name": "dynamic",
                "config": CacheConfig(strategy="dynamic", decode_on_cpu=False)
            },
            {
                "name": "quantized_4bit",
                "config": CacheConfig(
                    strategy="quantized",
                    decode_on_cpu=False,
                    quantization={'nbits': 4, 'residual_length': 128}
                )
            },
            {
                "name": "quantized_2bit",
                "config": CacheConfig(
                    strategy="quantized",
                    decode_on_cpu=False,
                    quantization={'nbits': 2, 'residual_length': 128}
                )
            },
            {
                "name": "offloaded",
                "config": CacheConfig(strategy="offloaded", decode_on_cpu=False)
            }
        ]
    
    results = {
        'max_context': {},
        'performance_curve': {},
    }
    
    # Find max context length for each strategy
    print("\nTesting maximum context lengths:")
    for strategy in strategies:
        print(f"\nStrategy: {strategy['name']}")
        max_length, metrics = find_max_context_length(strategy)
        results['max_context'][strategy['name']] = {
            'max_length': max_length,
            'metrics': metrics
        }
        print(f"Max context length: {max_length}")
        print(f"Peak memory: {metrics['peak_memory']:.0f}MB")
    
    # Test performance at different context lengths
    print("\nTesting performance curve:")
    for strategy in strategies:
        max_length = results['max_context'][strategy['name']]['max_length']
        results['performance_curve'][strategy['name']] = []
        
        # Test at checkpoints up to max length
        valid_checkpoints = [l for l in context_checkpoints if l <= max_length]
        
        for length in valid_checkpoints:
            try:
                llm = LlamaInference(
                    model_name="meta-llama/Llama-3.2-1B", # change llama
                    cache_config=strategy['config']
                )
                
                document_loader = EnhancedDocumentLoader()
                input_text = document_loader.load_chunk(chunk_size=length)
                
                result = llm.run_inference(
                    input_text=input_text,
                    max_new_tokens=20
                )
                
                metrics = {
                    'context_length': length,
                    'peak_memory': torch.cuda.max_memory_allocated()/1024**2,
                    'reserved_memory': torch.cuda.max_memory_reserved()/1024**2,
                    'throughput': result['metrics']['performance']['tokens_per_second'],
                    'latency': result['metrics']['performance']['mean_latency']
                }
                
                results['performance_curve'][strategy['name']].append(metrics)
                
            except Exception as e:
                print(f"Error testing {strategy['name']} at length {length}: {str(e)}")
                continue
            
            finally:
                if 'llm' in locals():
                    llm.cleanup()
    
    return results

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    print("Starting enhanced benchmark...")
    
    results = run_enhanced_benchmark()
    
    # Print results in a structured format
    print("\nBenchmark Results:")
    print("=" * 80)
    
    print("\nMaximum Context Lengths:")
    for strategy, data in results['max_context'].items():
        print(f"\nStrategy: {strategy}")
        print(f"Max Length: {data['max_length']}")
        print(f"Peak Memory: {data['metrics']['peak_memory']:.0f}MB")
        print(f"Throughput: {data['metrics']['throughput']:.2f} tokens/s")
    
    print("\nPerformance Curves:")
    for strategy, curve in results['performance_curve'].items():
        print(f"\nStrategy: {strategy}")
        for point in curve:
            print(f"Context {point['context_length']}: "
                  f"{point['throughput']:.2f} tokens/s, "
                  f"{point['peak_memory']:.0f}MB peak")

INFO:__main__:Loading model 'meta-llama/Llama-3.2-1B' with cache strategy 'dynamic'
INFO:accelerate.utils.modeling:We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Starting enhanced benchmark...

Testing maximum context lengths:

Strategy: dynamic


INFO:__main__:Loading model 'meta-llama/Llama-3.2-1B' with cache strategy 'dynamic'
INFO:accelerate.utils.modeling:We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


KeyboardInterrupt: 

# CPU Decode

In [1]:
from typing import Dict, Optional, List
import torch
import logging

class CPUDecodeManager:
    """Manages CPU decode offloading for Llama model layers"""
    
    def __init__(
        self,
        model_config: PretrainedConfig,
        cpu_layers: Optional[List[int]] = None,
        use_fp32: bool = True,
    ):
        self.num_layers = model_config.num_hidden_layers
        self.cpu_layers = cpu_layers or list(range(self.num_layers))
        self.use_fp32 = use_fp32
        self.logger = logging.getLogger(__name__)
        
        # Track device mapping
        self.layer_devices: Dict[int, str] = {}
        self._initialize_device_mapping()
        
    def _initialize_device_mapping(self):
        """Initialize the device mapping for all layers"""
        for layer_idx in range(self.num_layers):
            if layer_idx in self.cpu_layers:
                self.layer_devices[layer_idx] = "cpu"
            else:
                self.layer_devices[layer_idx] = "cuda:0"
                
    def prepare_inputs_for_layer(
        self,
        layer_idx: int,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> tuple:
        """Move inputs to appropriate device for layer"""
        target_device = self.layer_devices[layer_idx]
        
        # Convert to fp32 if on CPU and configured
        if target_device == "cpu" and self.use_fp32:
            hidden_states = hidden_states.to(dtype=torch.float32)
            
        # Move tensors to target device
        hidden_states = hidden_states.to(target_device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(target_device)
            
        return hidden_states, attention_mask
        
    def prepare_layer(self, layer: nn.Module, layer_idx: int):
        """Move layer to appropriate device and dtype"""
        target_device = self.layer_devices[layer_idx]
        target_dtype = torch.float32 if (target_device == "cpu" and self.use_fp32) else torch.float16
        
        # Move layer parameters
        layer.to(device=target_device, dtype=target_dtype)
        
    def get_layer_device(self, layer_idx: int) -> str:
        """Get target device for given layer"""
        return self.layer_devices.get(layer_idx, "cuda:0")
        
    def optimize_memory_transfers(self, hidden_states: torch.Tensor, from_layer: int, to_layer: int) -> torch.Tensor:
        """Optimize memory transfers between layers"""
        from_device = self.get_layer_device(from_layer)
        to_device = self.get_layer_device(to_layer)
        
        if from_device != to_device:
            # Handle dtype conversion if needed
            if to_device == "cpu" and self.use_fp32:
                hidden_states = hidden_states.to(dtype=torch.float32)
            elif from_device == "cpu" and self.use_fp32:
                hidden_states = hidden_states.to(dtype=torch.float16)
                
            hidden_states = hidden_states.to(to_device)
            
        return hidden_states

NameError: name 'PretrainedConfig' is not defined