In [None]:
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Dict, List, Tuple

def print_activation_dimensions(model: nn.Module, model_name: str = "DeepSeek R1"):
    """
    Report a model's configuration and the weight shape of each weighted submodule.

    Prints a header, selected fields of ``model.config``, then every module
    yielded by ``model.named_modules()`` that owns a non-``None`` ``weight``,
    grouped by the category returned from ``get_layer_type``. A GCG-oriented
    summary is printed last.

    Note: the shapes reported are parameter (``weight``) shapes; no forward
    pass is run here.

    Args:
        model: Loaded model exposing a ``config`` attribute.
        model_name: Label shown in the header line.
    """
    print(f"=== {model_name} Layer Activation Dimensions ===\n")

    cfg = model.config
    print("Model Configuration:")

    # Fields marked with the sentinel must exist on the config; the others
    # fall back to a placeholder when the attribute is absent.
    must_exist = object()
    config_rows = (
        ("Hidden size", "hidden_size", must_exist),
        ("Number of layers", "num_hidden_layers", must_exist),
        ("Number of attention heads", "num_attention_heads", must_exist),
        ("Intermediate size", "intermediate_size", "N/A"),
        ("Vocab size", "vocab_size", must_exist),
        ("Max position embeddings", "max_position_embeddings", "N/A"),
    )
    for label, attr, fallback in config_rows:
        if fallback is must_exist:
            value = getattr(cfg, attr)
        else:
            value = getattr(cfg, attr, fallback)
        print(f"  {label}: {value}")
    print()

    # One record per submodule that actually carries a weight tensor.
    weighted_layers = [
        {
            'name': qualified_name,
            'type': get_layer_type(qualified_name, submodule),
            'weight_shape': tuple(submodule.weight.shape),
            'module_type': type(submodule).__name__,
        }
        for qualified_name, submodule in model.named_modules()
        if hasattr(submodule, 'weight') and submodule.weight is not None
    ]

    # Detailed listing first, then the condensed GCG summary.
    print_layers_by_type(weighted_layers)
    print_gcg_summary(weighted_layers, cfg)

def get_layer_type(name: str, module: nn.Module) -> str:
    """Categorize a module by its qualified name for GCG analysis.

    Classification is a case-insensitive substring match on ``name``;
    ``module`` is accepted for interface compatibility but is not inspected.

    Args:
        name: Qualified module name as produced by ``named_modules()``.
        module: The module itself (currently unused).

    Returns:
        One of ``'embedding'``, ``'normalization'``, ``'attention_query'``,
        ``'attention_key'``, ``'attention_value'``, ``'attention_output'``,
        ``'attention_other'``, ``'mlp_gate'``, ``'mlp_up'``, ``'mlp_down'``,
        ``'mlp_other'``, ``'output_head'`` or ``'other'``.
    """
    name_lower = name.lower()

    if 'embed' in name_lower:
        return 'embedding'

    # Normalization is tested before attention/MLP: names such as
    # 'post_attention_layernorm' contain 'attention' but are norm layers and
    # would otherwise be reported as 'attention_other'.
    if 'norm' in name_lower:
        return 'normalization'

    if 'attention' in name_lower or 'attn' in name_lower:
        if 'q_proj' in name_lower or 'query' in name_lower:
            return 'attention_query'
        if 'k_proj' in name_lower or 'key' in name_lower:
            return 'attention_key'
        if 'v_proj' in name_lower or 'value' in name_lower:
            return 'attention_value'
        if 'o_proj' in name_lower or 'out' in name_lower:
            return 'attention_output'
        return 'attention_other'

    if 'mlp' in name_lower or 'feed_forward' in name_lower or 'ffn' in name_lower:
        if 'gate' in name_lower:
            return 'mlp_gate'
        if 'up' in name_lower:
            return 'mlp_up'
        if 'down' in name_lower:
            return 'mlp_down'
        return 'mlp_other'

    if 'lm_head' in name_lower or 'output' in name_lower:
        return 'output_head'

    return 'other'

def print_layers_by_type(layer_info: List[Dict]):
    """Print the layer records bucketed by their ``'type'`` field.

    Buckets are emitted in alphabetical order of the type name; within a
    bucket, records keep the order they had in ``layer_info``. Each record
    must provide ``'name'``, ``'type'``, ``'weight_shape'`` and
    ``'module_type'``.
    """
    buckets = {}
    for record in layer_info:
        buckets.setdefault(record['type'], []).append(record)

    for category in sorted(buckets):
        heading = category.upper().replace('_', ' ')
        print(f"\n--- {heading} LAYERS ---")
        for record in buckets[category]:
            shape_text = str(record['weight_shape'])
            # Fixed-width columns keep names and shapes aligned.
            print(f"  {record['name']:<50} | {shape_text:<20} | {record['module_type']}")

def print_gcg_summary(layer_info: List[Dict], config):
    """Print summary information useful for GCG implementation.

    Selects the records whose name contains ``'layers.'`` and whose type is
    ``'attention_output'`` or ``'mlp_down'``. If any are found, prints the
    number of distinct block indices, the hidden size from ``config`` and up
    to five example layer names. If none are found, only the header is
    printed.

    Args:
        layer_info: Records with at least ``'name'`` and ``'type'`` keys.
        config: Object exposing ``hidden_size``.
    """
    import re

    print("\n=== GCG Implementation Summary ===")

    transformer_layers = [
        layer for layer in layer_info
        if 'layers.' in layer['name'] and layer['type'] in ('attention_output', 'mlp_down')
    ]
    if not transformer_layers:
        return

    # The block index is the path component right after 'layers.'. The name
    # may start with 'layers.' (no leading dot) when an inner decoder is
    # analysed directly, so accept both forms instead of splitting on
    # '.layers.', which raised IndexError for unprefixed names.
    index_pattern = re.compile(r'(?:^|\.)layers\.([^.]+)')
    block_ids = set()
    for layer in transformer_layers:
        found = index_pattern.search(layer['name'])
        if found:
            block_ids.add(found.group(1))
    print(f"Transformer layers detected: {len(block_ids)}")

    # Get representative dimensions
    hidden_dim = config.hidden_size
    print(f"Hidden dimension: {hidden_dim}")

    # Common target layers for refusal direction analysis
    print("\nRecommended layers for refusal direction extraction:")
    print(f"  - Residual stream dimension: {hidden_dim}")
    print("  - MLP output layers: Look for 'mlp.down_proj' or similar")
    print("  - Attention output layers: Look for 'self_attn.o_proj' or similar")

    # Print some example layer names for targeting
    print("\nExample targetable layer names:")
    for layer in transformer_layers[:5]:
        print(f"  - {layer['name']}")

def analyze_specific_layers(model: nn.Module, layer_patterns: List[str] = None):
    """
    Analyze specific layers matching given patterns.
    Useful for targeting specific components in GCG.

    Each pattern is turned into a regular expression by replacing every
    ``*`` with ``\\d+``; the rest of the pattern is used as regex text as-is.
    It is matched with ``re.search``, so it may match anywhere in the
    module's qualified name. Only modules whose ``weight`` exists and is not
    ``None`` are considered.

    For each pattern, prints the match count and the first three matches
    (name and weight shape), or a "no match" line.

    Args:
        model: Module whose ``named_modules()`` are scanned.
        layer_patterns: Patterns to look up; when ``None``, a default set of
            attention-output and MLP projection patterns is used.
    """
    import re

    if layer_patterns is None:
        # Default patterns for common refusal direction extraction points
        layer_patterns = [
            'layers.*.self_attn.o_proj',
            'layers.*.mlp.down_proj',
            'layers.*.mlp.gate_proj',
            'model.layers.*.self_attn.o_proj',
            'model.layers.*.mlp.down_proj'
        ]

    print("\n=== Analyzing Specific Layer Patterns ===")

    # Walk the module tree once rather than once per pattern. Modules that
    # register ``weight`` as None still satisfy hasattr(), so check the
    # value itself to avoid an AttributeError on ``.shape``.
    weighted_modules = [
        (name, module.weight.shape)
        for name, module in model.named_modules()
        if getattr(module, 'weight', None) is not None
    ]

    for pattern in layer_patterns:
        print(f"\nPattern: {pattern}")
        regex = re.compile(pattern.replace('*', r'\d+'))

        matching_layers = [
            (name, shape) for name, shape in weighted_modules if regex.search(name)
        ]

        if not matching_layers:
            print("  No matching layers found")
            continue

        print(f"  Found {len(matching_layers)} matching layers:")
        for name, shape in matching_layers[:3]:  # Show first 3
            print(f"    {name}: {shape}")
        if len(matching_layers) > 3:
            print(f"    ... and {len(matching_layers) - 3} more")

# Example usage
def main():
    """
    Demonstrate loading a DeepSeek R1 checkpoint ahead of dimension analysis.

    Loads the distilled Llama-8B checkpoint in bfloat16 with
    ``device_map="cuda:3"`` and prints guidance messages. The analysis calls
    are left commented out, so the loaded model is not inspected here; adapt
    the loading code and enable them for your own setup.
    """
    print("Loading DeepSeek R1 model...")

    # Example model loading (adapt the checkpoint and device to your setup)
    checkpoint = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
    load_options = {
        "torch_dtype": torch.bfloat16,
        "device_map": "cuda:3",
    }
    model = AutoModelForCausalLM.from_pretrained(checkpoint, **load_options)

    # The analysis itself is disabled in this template; enable when ready:
    # print_activation_dimensions(model, "DeepSeek R1")
    # analyze_specific_layers(model)

    guidance = (
        "Replace the model loading section with your actual loaded DeepSeek R1 model.",
        "Then call: print_activation_dimensions(your_model)",
    )
    for message in guidance:
        print(message)

# Run the example only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()

Loading DeepSeek R1 model...
Replace the model loading section with your actual loaded DeepSeek R1 model.
Then call: print_activation_dimensions(your_model)
