# Debug: PyTorch Quantile Dtype Error in Wanda Selectivity

## Problem Description

When running the Wanda IDF pruning method, we encounter this error:
```
RuntimeError: quantile() input tensor must be either float or double dtype
```

This error occurs in `lib/wanda_selectivity.py` at line 68 in the `finalize()` method when calling `torch.quantile()`.

## Error Location
- **File:** `lib/wanda_selectivity.py`
- **Method:** `SelectivityStatsLight.finalize()`
- **Line:** `q90_per_channel = torch.quantile(all_samples, 0.9, dim=0)`

The issue is that `torch.quantile()` accepts only float32 or float64 tensors, while `all_samples` inherits whatever dtype the sampled inputs had (for example float16 under mixed precision, or an integer dtype such as int32 or int64).

In [None]:
# Setup Environment and Imports
import torch
import numpy as np
import warnings
warnings.filterwarnings('ignore')

print(f"✅ PyTorch version: {torch.__version__}")
print(f"✅ Device available: {'CUDA' if torch.cuda.is_available() else 'CPU'}")

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

## 1. Reproduce the Error

Let's reproduce the exact error that occurs in the `SelectivityStatsLight.finalize()` method:

In [None]:
# Reproduce the problematic scenario
def reproduce_quantile_error():
    """Reproduce the exact error from wanda_selectivity.py"""
    
    # Create sample data that might have incompatible dtype
    # This simulates what happens in the SelectivityStatsLight class
    samples = []
    
    # Simulate different potential dtypes that could cause issues
    problematic_dtypes = [torch.int32, torch.int64, torch.float16]
    
    for dtype in problematic_dtypes:
        print(f"\n🧪 Testing with dtype: {dtype}")
        
        # Create sample tensors (simulating the samples list in SelectivityStatsLight)
        if dtype in [torch.int32, torch.int64]:
            # Integer data (could come from tokenizer outputs or similar)
            sample_data = torch.randint(0, 100, (50, 512), dtype=dtype)
        else:
            # Float16 data (common in mixed precision training)
            sample_data = torch.randn(50, 512, dtype=dtype)
        
        samples.append(sample_data)
        
        # Try to concatenate and run quantile (this will fail)
        try:
            all_samples = torch.cat(samples, dim=0)
            print(f"   Combined tensor dtype: {all_samples.dtype}")
            print(f"   Combined tensor shape: {all_samples.shape}")
            
            # This line raises for any dtype other than float32/float64 (float16 included)
            q90_per_channel = torch.quantile(all_samples, 0.9, dim=0)
            print(f"   ✅ Success! Quantile computed successfully")
            
        except RuntimeError as e:
            print(f"   ❌ Error: {e}")
        
        # Clear samples for next iteration
        samples.clear()

reproduce_quantile_error()

## 2. Investigate Data Types

Let's examine which tensor dtypes are compatible with `torch.quantile()`:

In [None]:
# Test which dtypes work with torch.quantile()
def test_quantile_dtypes():
    """Test various tensor dtypes with torch.quantile()"""
    
    test_dtypes = [
        torch.float32,   # Should work
        torch.float64,   # Should work  
        torch.float16,   # Won't work (half precision is rejected)
        torch.int32,     # Won't work
        torch.int64,     # Won't work
        torch.bool,      # Won't work
    ]
    
    print("Testing torch.quantile() compatibility with different dtypes:\n")
    
    for dtype in test_dtypes:
        try:
            # Create test tensor
            if dtype == torch.bool:
                test_tensor = torch.randint(0, 2, (100, 10), dtype=torch.int32).bool()
            elif dtype in [torch.int32, torch.int64]:
                test_tensor = torch.randint(0, 100, (100, 10), dtype=dtype)
            else:
                test_tensor = torch.randn(100, 10, dtype=dtype)
            
            # Try quantile operation
            result = torch.quantile(test_tensor, 0.9, dim=0)
            print(f"✅ {dtype}: SUCCESS - Result dtype: {result.dtype}")
            
        except Exception as e:
            print(f"❌ {dtype}: FAILED - {str(e)[:50]}...")

test_quantile_dtypes()

## 3. Check Tensor Properties

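The class below mirrors the original `update()`/`finalize()` logic: samples are moved to the CPU without a dtype conversion, so `torch.cat` hands `torch.quantile()` whatever dtype the input had: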

In [None]:
# Simulate the problematic SelectivityStatsLight behavior
class SelectivityStatsLightOriginal:
    """Original (buggy) version that causes the dtype error"""
    
    def __init__(self, num_channels: int):
        self.num_channels = num_channels
        self.samples = []
        self.max_samples = 200

    def update(self, x: torch.Tensor):
        if x.numel() == 0:
            return
        
        n_rows = x.shape[0]
        if len(self.samples) < self.max_samples:
            n_to_sample = min(50, n_rows, self.max_samples - len(self.samples))
            if n_to_sample > 0:
                indices = torch.randperm(n_rows, device=x.device)[:n_to_sample]
                # BUG: Not ensuring dtype is float - just moving to CPU
                sampled = x[indices].cpu()  # This preserves original dtype!
                self.samples.append(sampled)
    
    def finalize(self):
        if len(self.samples) == 0:
            return torch.ones(self.num_channels), torch.ones(self.num_channels)
        
        # BUG: torch.cat preserves the dtype of input tensors
        all_samples = torch.cat(self.samples, dim=0)
        print(f"   All samples dtype before quantile: {all_samples.dtype}")
        
        # This will fail if all_samples is not float32/float64
        q90_per_channel = torch.quantile(all_samples, 0.9, dim=0)
        return torch.ones(self.num_channels), torch.ones(self.num_channels)

# Test the buggy version
print("Testing original (buggy) SelectivityStatsLight:\n")

# Test with float16 input (common in mixed precision)
stats_buggy = SelectivityStatsLightOriginal(num_channels=128)
test_input = torch.randn(100, 128, dtype=torch.float16)

print(f"Input tensor dtype: {test_input.dtype}")
stats_buggy.update(test_input)

try:
    stats_buggy.finalize()
    print("✅ No error occurred")
except RuntimeError as e:
    print(f"❌ Error: {e}")
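
To find out which dtype actually reaches the statistics collector in a real run, one option is to probe the layer inputs directly. The sketch below is an assumption about how this could be done, not code from `lib/wanda_selectivity.py`: `attach_dtype_probes` is a hypothetical helper, and the toy `nn.Sequential` stands in for the real model (bfloat16 is used only because it runs on CPU; `torch.quantile()` rejects it just like float16).

```python
import torch
import torch.nn as nn


def attach_dtype_probes(model):
    """Hypothetical helper: record the input dtype seen by every nn.Linear."""
    seen = {}      # module name -> dtype of its first positional input
    handles = []   # hook handles, so the probes can be removed afterwards

    def make_probe(name):
        def probe(module, args):
            if args and isinstance(args[0], torch.Tensor):
                seen[name] = args[0].dtype
        return probe

    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            handles.append(module.register_forward_pre_hook(make_probe(name)))
    return seen, handles


# Toy stand-in for the pruned model (an assumption for illustration only).
toy = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4)).to(torch.bfloat16)
seen, handles = attach_dtype_probes(toy)

with torch.no_grad():
    toy(torch.randn(8, 16).to(torch.bfloat16))

# Remove the probes so they do not affect later forward passes.
for handle in handles:
    handle.remove()

for name, dtype in seen.items():
    print(f"{name}: input dtype {dtype}")
```

If every probed layer reports a half-precision dtype, the cast belongs wherever the samples are stored, which is what the fix in the next section does.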

## 4. Fix the Dtype Issue

Now let's implement the fix by converting tensors to float32 before the quantile operation:

In [None]:
# Fixed version of SelectivityStatsLight
class SelectivityStatsLightFixed:
    """Fixed version that handles dtype conversion properly"""
    
    def __init__(self, num_channels: int):
        self.num_channels = num_channels
        self.samples = []
        self.max_samples = 200

    def update(self, x: torch.Tensor):
        if x.numel() == 0:
            return
        
        n_rows = x.shape[0]
        if len(self.samples) < self.max_samples:
            n_to_sample = min(50, n_rows, self.max_samples - len(self.samples))
            if n_to_sample > 0:
                indices = torch.randperm(n_rows, device=x.device)[:n_to_sample]
                # FIX 1: Ensure float dtype when storing samples
                sampled = x[indices].cpu().float()  # Convert to float32
                self.samples.append(sampled)
    
    def finalize(self):
        if len(self.samples) == 0:
            return torch.ones(self.num_channels), torch.ones(self.num_channels)
        
        all_samples = torch.cat(self.samples, dim=0)
        
        # FIX 2: Double-check dtype before quantile operation
        if all_samples.dtype not in [torch.float32, torch.float64]:
            all_samples = all_samples.float()
        
        print(f"   All samples dtype after fix: {all_samples.dtype}")
        
        # Compute per-channel stats
        mean_per_channel = all_samples.mean(dim=0)
        median_per_channel = torch.median(all_samples, dim=0).values
        
        # IDF calculation
        above_median = (all_samples > median_per_channel.unsqueeze(0)).float().mean(dim=0)
        p_j = torch.clamp(above_median, 1e-6, 0.999)
        idf_scores = torch.log(1.0 / (p_j + 1e-9))
        idf_scores = torch.clamp(idf_scores, 0.1, 10.0)
        
        # Spikiness calculation - this will now work!
        q90_per_channel = torch.quantile(all_samples, 0.9, dim=0)
        spikiness_scores = torch.ones(self.num_channels)
        
        for j in range(self.num_channels):
            top_vals = all_samples[all_samples[:, j] >= q90_per_channel[j], j]
            if len(top_vals) > 0:
                mu_top = top_vals.mean().item()
                mu = mean_per_channel[j].item()
                spikiness_scores[j] = mu_top / (mu + 1e-9)
        
        spikiness_scores = torch.clamp(spikiness_scores, 1.0, 20.0)
        
        return idf_scores, spikiness_scores

# Test the fixed version
print("Testing fixed SelectivityStatsLight:\n")

stats_fixed = SelectivityStatsLightFixed(num_channels=128)

# Test with various problematic dtypes
test_dtypes = [torch.float16, torch.int32, torch.float32]

for dtype in test_dtypes:
    print(f"\n🧪 Testing with {dtype}:")
    
    if dtype == torch.int32:
        test_input = torch.randint(0, 100, (100, 128), dtype=dtype)
    else:
        test_input = torch.randn(100, 128, dtype=dtype)
    
    print(f"   Input dtype: {test_input.dtype}")
    
    # Reset stats for each test
    stats_fixed.samples = []
    stats_fixed.update(test_input)
    
    try:
        idf_scores, spikiness_scores = stats_fixed.finalize()
        print(f"   ✅ Success! IDF shape: {idf_scores.shape}, Spikiness shape: {spikiness_scores.shape}")
    except Exception as e:
        print(f"   ❌ Error: {e}")
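
The spikiness loop in `finalize()` runs one Python iteration (and two `.item()` calls) per channel. As a possible follow-up, the same quantity can be computed for all channels at once with a boolean mask. This is a sketch under the assumption that the loop's semantics should be kept exactly; `spikiness_loop` and `spikiness_vectorized` are illustrative names, not functions from the repository.

```python
import torch


def spikiness_loop(all_samples, q90):
    """Reference: the per-channel loop used in finalize()."""
    mean = all_samples.mean(dim=0)
    scores = torch.ones(all_samples.shape[1])
    for j in range(all_samples.shape[1]):
        top_vals = all_samples[all_samples[:, j] >= q90[j], j]
        if len(top_vals) > 0:
            scores[j] = top_vals.mean().item() / (mean[j].item() + 1e-9)
    return torch.clamp(scores, 1.0, 20.0)


def spikiness_vectorized(all_samples, q90):
    """Same quantity for all channels at once, using a boolean mask."""
    mean = all_samples.mean(dim=0)
    mask = all_samples >= q90.unsqueeze(0)  # shape: (rows, channels)
    # Sum only the values at or above each channel's 90th percentile.
    top_sum = torch.where(mask, all_samples, torch.zeros_like(all_samples)).sum(dim=0)
    top_count = mask.sum(dim=0)
    mu_top = top_sum / top_count.clamp(min=1)
    scores = mu_top / (mean + 1e-9)
    # Channels with no value at or above the threshold keep the default 1.0.
    scores = torch.where(top_count > 0, scores, torch.ones_like(scores))
    return torch.clamp(scores, 1.0, 20.0)


torch.manual_seed(0)
samples = torch.randn(200, 64).abs()  # non-negative, like activation magnitudes
q90 = torch.quantile(samples, 0.9, dim=0)

loop_scores = spikiness_loop(samples, q90)
vec_scores = spikiness_vectorized(samples, q90)
print("Match:", torch.allclose(loop_scores, vec_scores, atol=1e-4))
```

Both versions share the loop's behavior for channels whose mean is negative or close to zero; the sketch changes only how the result is computed, not what it means.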

## 5. Verify the Solution

Let's demonstrate that our fix resolves the original error:

In [None]:
# Comprehensive verification of the fix
def verify_fix_comprehensive():
    """Test the fix with edge cases and various scenarios"""
    
    print("🔍 Comprehensive Fix Verification\n")
    
    # Test scenarios that could occur in real usage
    test_scenarios = [
        ("Mixed precision (float16)", torch.float16, lambda: torch.randn(150, 256, dtype=torch.float16)),
        ("Integer weights", torch.int32, lambda: torch.randint(-50, 50, (150, 256), dtype=torch.int32)),
        ("Large integer values", torch.int64, lambda: torch.randint(-1000, 1000, (150, 256), dtype=torch.int64)),
        ("Normal float32", torch.float32, lambda: torch.randn(150, 256, dtype=torch.float32)),
        ("Double precision", torch.float64, lambda: torch.randn(150, 256, dtype=torch.float64)),
    ]
    
    for scenario_name, dtype, data_generator in test_scenarios:
        print(f"📋 Testing: {scenario_name}")
        
        # Test original (should fail for some dtypes)
        stats_original = SelectivityStatsLightOriginal(num_channels=256)
        test_data = data_generator()
        
        print(f"   Input: {test_data.dtype}, shape: {test_data.shape}")
        stats_original.update(test_data)
        
        try:
            stats_original.finalize()
            print("   Original: ✅ Success")
        except RuntimeError as e:
            print(f"   Original: ❌ Failed - {str(e)[:40]}...")
        
        # Test fixed version (should always work)
        stats_fixed = SelectivityStatsLightFixed(num_channels=256)
        stats_fixed.update(test_data)
        
        try:
            idf, spikiness = stats_fixed.finalize()
            print(f"   Fixed: ✅ Success - IDF: {idf.dtype}, Spikiness: {spikiness.dtype}")
        except Exception as e:
            print(f"   Fixed: ❌ Failed - {e}")
        
        print()

verify_fix_comprehensive()
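
One caveat for the integer scenarios: float32 has a 24-bit significand, so `.float()` is exact for float16 inputs but rounds integers whose magnitude exceeds 2**24. The values tested above (at most ±1000) are unaffected. The sketch below only illustrates the rounding; using `.double()` for integer inputs is a suggestion, not part of the applied fix.

```python
import torch

big = torch.tensor([2**24, 2**24 + 1, 2**24 + 2], dtype=torch.int64)

as_f32 = big.float()    # what the fix does
as_f64 = big.double()   # wider alternative for integer inputs

print(as_f32.to(torch.int64).tolist())  # → [16777216, 16777216, 16777218]
print(as_f64.to(torch.int64).tolist())  # → [16777216, 16777217, 16777218]

# torch.quantile() accepts float64 as well, and returns float64.
median = torch.quantile(as_f64, 0.5)
print(median.dtype)
```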

## 6. Test with Different Data Types

Final validation with extreme edge cases:

In [None]:
# Edge case testing
def test_edge_cases():
    """Test with edge cases that might occur in practice"""
    
    print("🚀 Edge Case Testing\n")
    
    edge_cases = [
        ("Empty tensor", lambda: torch.empty(0, 128)),
        ("Single sample", lambda: torch.randn(1, 128, dtype=torch.float16)),
        ("Very small values", lambda: torch.randn(100, 128) * 1e-8),
        ("Very large values", lambda: torch.randn(100, 128) * 1e8),
        ("All zeros", lambda: torch.zeros(100, 128)),
        ("All ones", lambda: torch.ones(100, 128)),
        ("Boolean as int", lambda: torch.randint(0, 2, (100, 128), dtype=torch.int32)),
    ]
    
    for case_name, data_generator in edge_cases:
        print(f"🧪 Testing: {case_name}")
        
        try:
            test_data = data_generator()
            print(f"   Data: {test_data.dtype}, shape: {test_data.shape}")
            
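            if test_data.numel() == 0:
                # update() ignores empty input, so finalize() returns its all-ones fallback
                print("   Empty input: expecting the all-ones fallback from finalize()")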
            
            stats = SelectivityStatsLightFixed(num_channels=test_data.shape[-1])
            stats.update(test_data)
            
            idf, spikiness = stats.finalize()
            print(f"   ✅ Success - IDF range: [{idf.min():.3f}, {idf.max():.3f}]")
            print(f"                Spikiness range: [{spikiness.min():.3f}, {spikiness.max():.3f}]")
            
        except Exception as e:
            print(f"   ❌ Failed: {e}")
        
        print()

test_edge_cases()

## Summary

### ✅ **Problem Fixed!**

The error `RuntimeError: quantile() input tensor must be either float or double dtype` was caused by:

1. **Root Cause**: `torch.quantile()` only accepts float32 or float64 tensors, but the code was passing tensors with other dtypes (like float16, int32, int64).

2. **Location**: In `lib/wanda_selectivity.py`, line 68 in the `SelectivityStatsLight.finalize()` method.

3. **Solution Applied**: Two fixes were implemented:
   - **Fix 1**: In `update()` method - Convert tensors to float32 when storing: `.cpu().float()`
   - **Fix 2**: In `finalize()` method - Double-check dtype before quantile operation

### 🔧 **Code Changes Made**

In `lib/wanda_selectivity.py`:

```python
# Line ~35: In update() method
sampled = x[indices].cpu().float()  # Added .float()

# Line ~55: In finalize() method  
if all_samples.dtype not in [torch.float32, torch.float64]:
    all_samples = all_samples.float()
```

### 🚀 **Ready to Run**

Your original command should now work without errors:

```bash
python main.py \
  --model baffo32/decapoda-research-llama-7B-hf \
  --prune_method wanda_idf \
  --sparsity_ratio 0.5 \
  --sparsity_type unstructured \
  --save out/wanda_idf/
```