<a href="https://colab.research.google.com/github/peremartra/optipfair/blob/main/examples/pruning_compatibility_check.ipynb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# OptiPFair Notebook Series – Pruning Compatibility Checker

![optiPfair Logo](https://github.com/peremartra/optipfair/blob/main/images/optiPfair.png?raw=true)


This notebook checks whether your transformer model is compatible with the **structured pruning** capabilities of [OptiPFair](https://github.com/peremartra/optipfair), which prunes the MLP layers of transformer models built on GLU-based MLP blocks.

**In 30 seconds, you'll know:**
- Can I prune this model with OptiPFair?
- What's the model architecture?
- What are the MLP expansion ratios?
- Any specific recommendations?

**Supported architectures:** Llama, Mistral, Gemma, Qwen, Phi, and other GLU-based models.
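
For orientation, a GLU-based MLP block is built from three projections. The formulation below is a sketch of the usual definition, assuming bias-free linear layers and writing $\phi$ for the activation function; the symbols $d_{\text{model}}$ and $d_{\text{ff}}$ are notation introduced here and do not come from OptiPFair:

$$
\mathrm{MLP}(x) = W_{\text{down}}\bigl(\phi(W_{\text{gate}}\,x) \odot W_{\text{up}}\,x\bigr),
\qquad
W_{\text{gate}}, W_{\text{up}} \in \mathbb{R}^{d_{\text{ff}} \times d_{\text{model}}},\quad
W_{\text{down}} \in \mathbb{R}^{d_{\text{model}} \times d_{\text{ff}}}
$$

$$
\text{expansion ratio} = 100 \cdot \frac{d_{\text{ff}}}{d_{\text{model}}}\ \%
$$

In a Hugging Face config, $d_{\text{model}}$ corresponds to `hidden_size` and $d_{\text{ff}}$ to `intermediate_size`. In structured pruning of this kind, removing intermediate neuron $i$ generally means dropping row $i$ of $W_{\text{gate}}$ and $W_{\text{up}}$ together with column $i$ of $W_{\text{down}}$, which is why the check looks for all three projections.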

## Recommended Environment

- **Platform**: [Google Colab](https://colab.research.google.com)  
- **Hardware**: GPU runtime (recommended: T4 or better for 1B–3B models)  
- **Dependencies**: `transformers` and `torch`, installed in the first cell (OptiPFair itself is not needed to run this check)

## by Pere Martra

- [LinkedIn](https://www.linkedin.com/in/pere-martra)  
- [GitHub](https://github.com/peremartra)  
- [X / Twitter](https://x.com/peremartra)


## Setup

In [1]:
# Install the dependencies for this check (OptiPFair itself is not required here)
!pip install transformers torch -q


In [2]:
import torch
from transformers import AutoModel, AutoConfig
import warnings
warnings.filterwarnings('ignore')

print("✅ Setup complete!")
print(f"🔥 PyTorch version: {torch.__version__}")
print(f"🤗 Device: {'GPU' if torch.cuda.is_available() else 'CPU'}")
print("📦 OptipFair will be used for actual pruning operations")

✅ Setup complete!
🔥 PyTorch version: 2.6.0+cu124
🤗 Device: GPU
📦 OptipFair will be used for actual pruning operations


## Model Input
**Enter your model name below:**  
You can use any Hugging Face model ID (e.g., `microsoft/Phi-3-mini-4k-instruct`, `google/gemma-2-2b`)


In [3]:
# 👇 EDIT THIS: Enter your model name
MODEL_NAME = "google/gemma-3-1b-pt"  # Change this to test your model
# Other models to try:
# MODEL_NAME = "meta-llama/Llama-3.2-1B"
# MODEL_NAME = "Qwen/Qwen3-0.6B"
print(f"🔍 Checking compatibility for: {MODEL_NAME}")

🔍 Checking compatibility for: google/gemma-3-1b-pt


## Compatibility Analysis

In [4]:
def check_model_compatibility(model_name):
    """
    Comprehensive compatibility check for OptipFair pruning
    """
    try:
        print("🔄 Loading model configuration...")
        config = AutoConfig.from_pretrained(model_name)

        # Initialize results
        results = {
            "model_name": model_name,
            "compatible": False,
            "architecture": "Unknown",
            "issues": [],
            "recommendations": [],
            "details": {}
        }

        # Extract basic info with proper handling of missing fields
        results["details"]["model_type"] = getattr(config, 'model_type', 'Unknown')
        results["details"]["num_layers"] = getattr(config, 'num_hidden_layers', 'N/A')
        results["details"]["hidden_size"] = getattr(config, 'hidden_size', 'N/A')
        results["details"]["intermediate_size"] = getattr(config, 'intermediate_size', 'N/A')

        # Calculate expansion ratio from config if possible
        hidden_size = getattr(config, 'hidden_size', None)
        intermediate_size = getattr(config, 'intermediate_size', None)

        if hidden_size and intermediate_size and hidden_size > 0:
            config_expansion_ratio = (intermediate_size / hidden_size) * 100
            results["details"]["config_expansion_ratio"] = f"{config_expansion_ratio:.0f}%"
        else:
            results["details"]["config_expansion_ratio"] = "N/A"

        print(f"📊 Model type: {results['details']['model_type']}")
        print(f"📊 Layers: {results['details']['num_layers']}")
        print(f"📊 Hidden size: {results['details']['hidden_size']}")
        print(f"📊 Intermediate size: {results['details']['intermediate_size']}")
        print(f"📊 Config expansion ratio: {results['details']['config_expansion_ratio']}")

        return results

    except Exception as e:
        print(f"❌ Error loading model: {str(e)}")
        return None

In [5]:
# Run the check
compatibility_results = check_model_compatibility(MODEL_NAME)

🔄 Loading model configuration...



📊 Model type: gemma3_text
📊 Layers: 26
📊 Hidden size: 1152
📊 Intermediate size: 6912
📊 Config expansion ratio: 600%
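
The 600% figure follows directly from the two config values printed above. A minimal sketch of the arithmetic, using the numbers reported for `google/gemma-3-1b-pt`; the helper name `expansion_ratio_percent` is illustrative and is not part of OptiPFair:

```python
def expansion_ratio_percent(hidden_size: int, intermediate_size: int) -> float:
    """Return the MLP expansion ratio as a percentage of the hidden size."""
    if hidden_size <= 0:
        raise ValueError("hidden_size must be positive")
    return intermediate_size / hidden_size * 100


# Values printed above for google/gemma-3-1b-pt
ratio = expansion_ratio_percent(hidden_size=1152, intermediate_size=6912)
print(f"{ratio:.0f}%")  # 6912 / 1152 = 6, so this prints 600%
```

A ratio of 100% would mean the intermediate layer is no wider than the hidden state.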


In [6]:
def analyze_mlp_structure(model_name, config):
    """
    Analyze MLP structure for pruning compatibility
    """
    try:
        print("\n🔍 Analyzing MLP structure...")

        # Load the full model (weights included) so the layer structure can be inspected
        model = AutoModel.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto" if torch.cuda.is_available() else "cpu",
            trust_remote_code=True
        )

        # Try to find the first transformer layer (different models have different structures)
        first_layer = None
        if hasattr(model, 'layers') and len(model.layers) > 0:
            first_layer = model.layers[0]
        elif hasattr(model, 'transformer') and hasattr(model.transformer, 'h') and len(model.transformer.h) > 0:
            first_layer = model.transformer.h[0]
        elif hasattr(model, 'model') and hasattr(model.model, 'layers') and len(model.model.layers) > 0:
            first_layer = model.model.layers[0]

        if first_layer is None:
            print("⚠️  Could not find transformer layers")
            return None

        # Try to find MLP/feed_forward layer
        mlp = None
        if hasattr(first_layer, 'mlp'):
            mlp = first_layer.mlp
        elif hasattr(first_layer, 'feed_forward'):
            mlp = first_layer.feed_forward
        elif hasattr(first_layer, 'ffn'):
            mlp = first_layer.ffn

        if mlp is None:
            print("⚠️  Could not find MLP layer")
            return None

        # Check for GLU structure (gate_proj + up_proj + down_proj)
        has_gate_proj = hasattr(mlp, 'gate_proj')
        has_up_proj = hasattr(mlp, 'up_proj')
        has_down_proj = hasattr(mlp, 'down_proj')

        # Alternative names for some models
        if not has_gate_proj:
            has_gate_proj = hasattr(mlp, 'w1') or hasattr(mlp, 'gate_linear')
        if not has_up_proj:
            has_up_proj = hasattr(mlp, 'w3') or hasattr(mlp, 'up_linear')
        if not has_down_proj:
            has_down_proj = hasattr(mlp, 'w2') or hasattr(mlp, 'down_linear')

        # Calculate expansion ratio
        expansion_ratio = 0
        input_dim = 0
        output_dim = 0

        try:
            if has_gate_proj and has_up_proj:
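                # Look up the gate projection under any of the names accepted above
                gate_layer = next(
                    (getattr(mlp, name) for name in ('gate_proj', 'w1', 'gate_linear') if hasattr(mlp, name)),
                    None
                )
                if gate_layer is not None and hasattr(gate_layer, 'in_features') and hasattr(gate_layer, 'out_features'):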
                    input_dim = gate_layer.in_features
                    output_dim = gate_layer.out_features
                    expansion_ratio = (output_dim / input_dim) * 100
        except Exception as e:
            print(f"⚠️  Could not calculate expansion ratio: {str(e)}")

        # Clean up memory
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return {
            "has_glu": has_gate_proj and has_up_proj and has_down_proj,
            "expansion_ratio": expansion_ratio,
            "input_dim": input_dim,
            "output_dim": output_dim,
            "found_mlp": True
        }

    except Exception as e:
        print(f"⚠️  Could not analyze MLP structure: {str(e)}")
        return {
            "has_glu": False,
            "expansion_ratio": 0,
            "input_dim": 0,
            "output_dim": 0,
            "found_mlp": False,
            "error": str(e)
        }

# Analyze MLP structure (stays None if the configuration could not be loaded)
mlp_analysis = None
if compatibility_results:
    mlp_analysis = analyze_mlp_structure(MODEL_NAME, AutoConfig.from_pretrained(MODEL_NAME))


🔍 Analyzing MLP structure...


Some weights of Gemma3TextModel were not initialized from the model checkpoint at google/gemma-3-1b-pt and are newly initialized: ['embed_tokens.weight', 'layers.0.input_layernorm.weight', 'layers.0.mlp.down_proj.weight', ...] (output truncated)
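
The GLU check in the cell above only looks at attribute names, so it can be tried without downloading any weights. The sketch below reproduces that name-based test on stand-in objects. `GLU_ALIASES` and `has_glu_structure` are hypothetical names introduced for illustration, and the `SimpleNamespace` objects merely imitate the attribute layout of a Llama-style MLP and of a GPT-2-style MLP (which exposes `c_fc` and `c_proj` instead):

```python
from types import SimpleNamespace

# Attribute names the check above accepts for each GLU projection
GLU_ALIASES = {
    "gate": ("gate_proj", "w1", "gate_linear"),
    "up": ("up_proj", "w3", "up_linear"),
    "down": ("down_proj", "w2", "down_linear"),
}


def has_glu_structure(mlp) -> bool:
    """True if the module exposes a gate, an up and a down projection."""
    return all(
        any(hasattr(mlp, name) for name in aliases)
        for aliases in GLU_ALIASES.values()
    )


# Stand-ins for real modules: only the attribute names matter here
llama_style = SimpleNamespace(gate_proj=object(), up_proj=object(), down_proj=object())
gpt2_style = SimpleNamespace(c_fc=object(), c_proj=object())

print(has_glu_structure(llama_style))  # True
print(has_glu_structure(gpt2_style))   # False
```

A two-projection MLP fails the test because it has no separate gate branch.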

In [7]:
def generate_final_assessment(results, mlp_analysis):
    """
    Generate final compatibility assessment
    """
    if not results:
        print("❌ INCOMPATIBLE: Could not load model configuration")
        return False

    if not mlp_analysis or not mlp_analysis.get("found_mlp", False):
        print("❌ INCOMPATIBLE: Could not analyze MLP structure")
        results["issues"].append("❌ Could not access MLP layers")
        results["compatible"] = False
        return results

    # Update results with MLP analysis
    results["details"].update(mlp_analysis)

    # Determine compatibility
    model_type = results["details"]["model_type"].lower()
    has_glu = mlp_analysis["has_glu"]
    expansion_ratio = mlp_analysis["expansion_ratio"]

    # Known compatible architectures
    compatible_types = ["llama", "mistral", "gemma", "qwen", "phi"]

    if any(comp_type in model_type for comp_type in compatible_types):
        results["architecture"] = "Supported"
        if has_glu and expansion_ratio > 100:
            results["compatible"] = True
            results["recommendations"].append(f"✅ Perfect! GLU structure detected with {expansion_ratio:.0f}% expansion")
        else:
            results["issues"].append("❌ No GLU structure found or insufficient expansion")
            if expansion_ratio > 0:
                results["issues"].append(f"⚠️  Expansion ratio: {expansion_ratio:.0f}% (more than 100% required)")
    else:
        results["architecture"] = "Unknown/Unsupported"
        results["issues"].append(f"❌ Architecture '{model_type}' not yet supported")
        results["recommendations"].append("📧 Request support via GitHub issues")

    return results

# Generate final assessment
final_results = generate_final_assessment(compatibility_results, mlp_analysis)
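
The architecture test in the cell above is a substring match, which is why a versioned identifier such as `gemma3_text` still maps to a known family. A small sketch of that behaviour; the helper name `matched_family` is hypothetical and only mirrors the `any(...)` expression used in `generate_final_assessment`:

```python
COMPATIBLE_TYPES = ["llama", "mistral", "gemma", "qwen", "phi"]


def matched_family(model_type: str):
    """Return the first known family contained in model_type, else None."""
    model_type = model_type.lower()
    return next((t for t in COMPATIBLE_TYPES if t in model_type), None)


print(matched_family("gemma3_text"))  # gemma
print(matched_family("gpt2"))         # None
```

Any identifier that merely contains one of these strings also matches, so the name test alone is not conclusive. The GLU and expansion-ratio checks that follow it are what decide compatibility.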

## Final Results

In [8]:
def display_results(results):
    """
    Display the final compatibility results in a clean format
    """
    print("=" * 60)
    print("🎯 OPTIPFAIR PRUNING COMPATIBILITY REPORT")
    print("=" * 60)

    # Header
    status_emoji = "✅" if results["compatible"] else "❌"
    status_text = "COMPATIBLE" if results["compatible"] else "NOT COMPATIBLE"

    print(f"\n{status_emoji} STATUS: {status_text}")
    print(f"🏗️  ARCHITECTURE: {results['architecture']}")
    print(f"📦 MODEL: {results['model_name']}")

    # Details
    print(f"\n📊 TECHNICAL DETAILS:")
    details = results["details"]
    print(f"   • Model Type: {details.get('model_type', 'unknown')}")
    print(f"   • Layers: {details.get('num_layers', 'unknown')}")
    print(f"   • Hidden Size: {details.get('hidden_size', 'unknown')}")

    if details.get("expansion_ratio", 0) > 0:
        print(f"   • MLP Expansion: {details['expansion_ratio']:.0f}%")
        print(f"   • GLU Structure: {'✅ Yes' if details.get('has_glu') else '❌ No'}")

    # Issues
    if results["issues"]:
        print(f"\n⚠️  ISSUES FOUND:")
        for issue in results["issues"]:
            print(f"   {issue}")

    # Recommendations
    if results["recommendations"]:
        print(f"\n💡 RECOMMENDATIONS:")
        for rec in results["recommendations"]:
            print(f"   {rec}")

    # Next steps
    print(f"\n🚀 NEXT STEPS:")
    if results["compatible"]:
        print("   📦 Install OptipFair: pip install optipfair")
        print("   📝 Check the examples/ folder in OptipFair repository")
        print("   🔗 https://github.com/peremartra/optipfair")
    else:
        print("   📧 Open an issue: https://github.com/peremartra/optipfair/issues")
        print("   📚 Check supported models: https://github.com/peremartra/optipfair#supported-models")

    print("=" * 60)

# Display final results
if final_results:
    display_results(final_results)
else:
    print("❌ Could not complete compatibility check")

🎯 OPTIPFAIR PRUNING COMPATIBILITY REPORT

✅ STATUS: COMPATIBLE
🏗️  ARCHITECTURE: Supported
📦 MODEL: google/gemma-3-1b-pt

📊 TECHNICAL DETAILS:
   • Model Type: gemma3_text
   • Layers: 26
   • Hidden Size: 1152
   • MLP Expansion: 600%
   • GLU Structure: ✅ Yes

💡 RECOMMENDATIONS:
   ✅ Perfect! GLU structure detected with 600% expansion

🚀 NEXT STEPS:
   📦 Install OptipFair: pip install optipfair
   📝 Check the examples/ folder in OptipFair repository
   🔗 https://github.com/peremartra/optipfair


## 🔗 What's Next?

### ✅ If your model is compatible:
- **Try pruning:** Run the `basic_pruning_mlp.ipynb` notebook  
- **Optimize further:** Experiment with different pruning percentages
- **Visualize:** Check out `visualization_compatibility_check.ipynb`
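
To get a feel for what different pruning percentages mean, the sketch below estimates the remaining intermediate size and the resulting expansion ratio for the `google/gemma-3-1b-pt` values reported earlier. It is arithmetic only: `pruned_mlp_size` is a hypothetical helper, and rounding the removed neuron count down is an assumption, so the sizes OptiPFair actually produces may differ slightly.

```python
def pruned_mlp_size(intermediate_size: int, prune_percent: float) -> int:
    """Neurons left after removing prune_percent of the intermediate layer.

    Assumption: the number of removed neurons is rounded down; the
    library's actual rounding may differ.
    """
    removed = int(intermediate_size * prune_percent / 100)
    return intermediate_size - removed


hidden_size, intermediate_size = 1152, 6912  # google/gemma-3-1b-pt values from the report
for pct in (10, 20, 40):
    kept = pruned_mlp_size(intermediate_size, pct)
    print(f"{pct}% pruned -> {kept} neurons, expansion {kept / hidden_size * 100:.0f}%")
```

The hidden size is unchanged by MLP pruning, so the expansion ratio shrinks in proportion to the neurons removed.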

### ❌ If your model is not compatible:
- **Request support:** Open an issue on [GitHub](https://github.com/peremartra/optipfair/issues)
- **Check updates:** New architectures are added regularly
- **Contribute:** Help us add support for your model!

---

## 📚 Learn More

- **📖 Documentation:** [OptiPFair GitHub](https://github.com/peremartra/optipfair)  
- **📝 Tutorials:** [Large Language Models Course](https://github.com/peremartra/Large-Language-Model-Notebooks-Course)
- **🎯 Research:** [GLU Expansion Ratios Paper](https://osf.io/preprints/osf/qgxea)

---

If you found this notebook useful, the best way to support the OptiPFair project is by **starring it on GitHub**. Your support is a huge help in boosting the project's visibility and reaching more developers and researchers.

### ➡️ [**Star OptiPFair on GitHub**](https://github.com/peremartra/optipfair)
