
# FP8 Inference Pipeline for Mixedbread Embedding Model

This notebook sets up FP8-based inference for `mixedbread-ai/mxbai-embed-2d-large-v1` using TorchAO, falling back to standard precision when FP8 is unavailable.

**Key Features:**
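- FP8 computation during inference without converting the stored model weights
- Model weights remain in their original precision (FP32 by default when loaded with `SentenceTransformer`)
- Linear-layer matrix multiplications run in FP8 on supported GPUs, which can improve throughput
- Adaptive layers: keep only the first `ADAPTIVE_LAYERS` encoder layers (configurable)
- Automatic fallback to standard precision (with `torch.compile`) when no FP8 path applies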
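To make "computation in FP8" concrete: FP8 inference typically uses the E4M3 format (1 sign, 4 exponent, 3 mantissa bits, largest finite value 448) together with a scale factor chosen per tensor at run time. The pure-Python sketch below simulates that dynamic scaling and rounding for intuition only. `round_to_e4m3` and `fake_quantize_e4m3` are hypothetical helpers, not torchao functions, and real FP8 kernels operate on whole tensors and accumulate their results in higher precision.

```python
import math

E4M3_MAX = 448.0           # largest finite E4M3 value (the "e4m3fn" variant PyTorch uses)
E4M3_MIN_EXPONENT = -6     # smallest normal exponent; subnormals share its spacing
E4M3_MANTISSA_BITS = 3


def round_to_e4m3(x: float) -> float:
    """Round x to the nearest E4M3 value (round-half-to-even, saturating at +/-448)."""
    if x == 0.0 or math.isnan(x):
        return x
    sign = -1.0 if x < 0 else 1.0
    magnitude = min(abs(x), E4M3_MAX)  # saturate, as scaled FP8 casts usually clamp first
    _, exp2 = math.frexp(magnitude)    # magnitude = m * 2**exp2 with 0.5 <= m < 1
    exponent = max(exp2 - 1, E4M3_MIN_EXPONENT)
    step = 2.0 ** (exponent - E4M3_MANTISSA_BITS)  # gap between neighbouring E4M3 values
    rounded = round(magnitude / step) * step       # Python's round() is half-to-even
    return sign * min(rounded, E4M3_MAX)


def fake_quantize_e4m3(values: list[float]) -> list[float]:
    """Scale so max |v| maps to 448, round each value to E4M3, then undo the scale."""
    amax = max(abs(v) for v in values)
    if amax == 0.0:
        return list(values)
    scale = E4M3_MAX / amax
    return [round_to_e4m3(v * scale) / scale for v in values]


print(fake_quantize_e4m3([0.3, 1.0]))  # 0.3 comes back as roughly 0.2857
```

With only three mantissa bits, scaled values between 128 and 256 are spaced 16 apart, so 0.3 (scaled to about 134.4) lands on 128 and comes back as 128/448. That loss of precision is what FP8 trades for faster matrix multiplications.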

**Requirements:**
- CUDA-capable GPU with FP8 support (compute capability 8.9 or newer, e.g., NVIDIA L4 or H100)
- Google Colab with a GPU runtime enabled

**⚠️ Important:** Before running this notebook, make sure to:
1. Go to **Runtime → Change runtime type**
2. Select a **GPU** hardware accelerator (an L4 or H100 runtime, if available, is needed for FP8; a T4 runs the notebook but without FP8)
3. Click **Save**


## Step 1: Install Dependencies

Install required packages including torchao, sentence-transformers, and loguru.


In [1]:
# Install required packages
%pip install torchao sentence-transformers loguru -q



## Step 2: Import Libraries

Import necessary libraries and check for CUDA availability.


In [2]:
import torch
from sentence_transformers import SentenceTransformer

# Try to import torchao
try:
    import torchao
    TORCHAO_AVAILABLE = True
    print("✓ torchao imported successfully")
except ImportError:
    TORCHAO_AVAILABLE = False
    print("⚠ Warning: torchao not available. Install with: pip install torchao")

# Try to import loguru, fallback to print
try:
    from loguru import logger
    print("✓ loguru imported successfully")
except ImportError:
    import logging
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    logger = logging.getLogger(__name__)
    print("✓ Using standard logging (loguru not available)")

# Check CUDA availability
print(f"\nCUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA Device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Version: {torch.version.cuda}")


✓ torchao imported successfully
✓ loguru imported successfully

CUDA Available: True
CUDA Device: Tesla T4
CUDA Version: 12.6
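The cell above only reports whether CUDA exists. FP8 tensor cores require CUDA compute capability 8.9 or newer (Ada Lovelace such as the L4, Hopper such as the H100); the Tesla T4 shown here is compute capability 7.5, so no FP8 path can run on it. A small sketch of an explicit check; `supports_fp8` and `FP8_MIN_CAPABILITY` are hypothetical names, while `torch.cuda.get_device_capability` is a standard PyTorch call.

```python
# torch is optional here so the helper itself can be exercised without a GPU.
try:
    import torch
except ImportError:
    torch = None

# Ada Lovelace (sm_89) and newer architectures have FP8 tensor cores.
FP8_MIN_CAPABILITY = (8, 9)


def supports_fp8(capability: tuple[int, int]) -> bool:
    """Return True if a (major, minor) CUDA compute capability has FP8 tensor cores."""
    return tuple(capability) >= FP8_MIN_CAPABILITY


if torch is not None and torch.cuda.is_available():
    cap = torch.cuda.get_device_capability(0)
    print(f"Compute capability {cap[0]}.{cap[1]} -> FP8 supported: {supports_fp8(cap)}")
```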


## Step 3: FP8 Inference Setup Function

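Define `setup_fp8_inference`, which tries three methods in order and returns the model unchanged (standard precision) if none applies. Methods 1 and 2 reference torchao names (`torchao.replace`, `torchao.layers.FP8Linear`, `torchao.quantization.apply_fp8_quantization`) that are not part of current torchao releases, so they are skipped; Method 3 only wraps the model in `torch.compile`, which does not by itself switch any computation to FP8.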


In [3]:
def setup_fp8_inference(model: SentenceTransformer) -> SentenceTransformer:
    """
    Try to configure the model for FP8 inference with torchao.
    Falls back to the unmodified model (standard precision) if no FP8 path applies.

    Args:
        model: SentenceTransformer model to configure

    Returns:
        The model, with FP8 applied if one of the methods below succeeded
    """
    if not TORCHAO_AVAILABLE:
        logger.warning("torchao not available. Install with: pip install torchao")
        return model

    logger.info("Setting up FP8 inference with torchao...")

    # Check CUDA availability
    if not torch.cuda.is_available():
        logger.warning("CUDA not available. FP8 inference requires CUDA-capable GPU.")
        return model

    device = next(model[0].auto_model.parameters()).device
    if device.type != 'cuda':
        logger.info("Moving model to CUDA...")
        model = model.to('cuda')

    try:
        # Get the underlying PyTorch model
        pytorch_model = model[0].auto_model

        # Method 1: torchao.replace + torchao.layers.FP8Linear. These names are not in
        # current torchao releases, so this branch raises ImportError and is skipped.
        try:
            from torchao.replace import replace

            # Check if FP8Linear is available
            if hasattr(torchao, 'layers'):
                from torchao.layers import FP8Linear

                def replace_with_fp8_linear(module):
                    """Replace a Linear layer with FP8Linear for FP8 computation."""
                    if isinstance(module, torch.nn.Linear):
                        return FP8Linear.from_float(module)
                    return module

                pytorch_model = replace(pytorch_model, target=torch.nn.Linear,
                                      replacement_fn=replace_with_fp8_linear)
                model[0].auto_model = pytorch_model
                logger.info("✓ Applied FP8 using torchao.replace with FP8Linear")
                logger.info("  Note: Model weights remain in original precision, computation uses FP8")
                return model
        except (ImportError, AttributeError) as e:
            logger.debug(f"Method 1 (torchao.replace) not available: {e}")

        # Method 2: torchao.quantization.apply_fp8_quantization (not in current torchao; skipped silently)
        try:
            import torchao.quantization as tq
            if hasattr(tq, 'apply_fp8_quantization'):
                pytorch_model = tq.apply_fp8_quantization(pytorch_model, fp8_format="E4M3")
                model[0].auto_model = pytorch_model
                logger.info("✓ Applied FP8 using torchao.quantization")
                return model
        except (ImportError, AttributeError) as e:
            logger.debug(f"Method 2 (torchao.quantization) not available: {e}")

        # Method 3: torch.compile. Compiles the model in its existing dtype; it does not enable FP8 by itself.
        if hasattr(torch, 'compile'):
            try:
                pytorch_model = torch.compile(
                    pytorch_model,
                    mode='reduce-overhead',
                    fullgraph=False
                )
                model[0].auto_model = pytorch_model
                logger.info("✓ Using torch.compile (may use FP8 if hardware supports)")
                return model
            except Exception as e:
                logger.debug(f"torch.compile failed: {e}")

        logger.warning("Could not apply FP8 inference. Using standard precision.")
        logger.info("  This may be due to:")
        logger.info("  - torchao version not supporting FP8 computation")
        logger.info("  - Hardware not supporting FP8 (requires compute capability 8.9+, e.g. L4 or H100)")
        logger.info("  - Missing torchao dependencies")
        return model

    except Exception as e:
        logger.error(f"Error setting up FP8 inference: {e}")
        import traceback
        logger.debug(traceback.format_exc())
        logger.warning("Falling back to standard precision")
        return model
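Methods 1 and 2 above call torchao names that current releases do not provide. The closest real torchao API to this notebook's stated goal (FP8 matmuls, weights kept in their original dtype) is `torchao.float8.convert_to_float8_training`, which swaps `nn.Linear` layers for `Float8Linear` and casts weights and activations to FP8 with dynamic scaling on each forward pass. It is built for training, so treat the sketch below as an assumption-laden starting point rather than a tested inference recipe; for inference-only use, torchao's `quantize_` with a float8 config (which stores the weights in FP8) is the more common route. `fp8_module_filter` and `convert_linears_to_float8` are hypothetical names; the multiple-of-16 filter reflects the dimension requirement of FP8 matmul kernels.

```python
def fp8_module_filter(module, fqn: str) -> bool:
    """Return True if a layer should be swapped to FP8 (both dims multiples of 16)."""
    in_features = getattr(module, "in_features", None)
    out_features = getattr(module, "out_features", None)
    if in_features is None or out_features is None:
        return True  # not a Linear-like layer; the converter only swaps nn.Linear anyway
    return in_features % 16 == 0 and out_features % 16 == 0


def convert_linears_to_float8(pytorch_model):
    """Swap eligible nn.Linear layers for torchao Float8Linear; returns (model, applied)."""
    try:
        # Imported lazily so the cell still runs where torchao is not installed.
        from torchao.float8 import convert_to_float8_training
    except ImportError:
        return pytorch_model, False
    pytorch_model = convert_to_float8_training(
        pytorch_model, module_filter_fn=fp8_module_filter
    )
    return pytorch_model, True


# Hypothetical usage, only worthwhile on an FP8-capable GPU (compute capability >= 8.9):
# model[0].auto_model, applied = convert_linears_to_float8(model[0].auto_model)
# model[0].auto_model = torch.compile(model[0].auto_model)  # recommended for float8 speed
```

All Linear layers in this BERT-large-style encoder (hidden size 1024, intermediate size 4096) pass the filter; it matters mainly for small output heads with odd dimensions.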


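## Step 4: Configuration

Set the model name, whether to attempt FP8, and how many encoder layers to keep.


In [4]: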
# Configuration
MODEL_NAME = "mixedbread-ai/mxbai-embed-2d-large-v1"
ENABLE_FP8 = True
ADAPTIVE_LAYERS = 22  # Recommended: 20-24 layers

print("=" * 60)
print("FP8 Inference Pipeline for Mixedbread Embedding Model")
print("=" * 60)
print(f"Model: {MODEL_NAME}")
print(f"FP8 Enabled: {ENABLE_FP8}")
print(f"Adaptive Layers: {ADAPTIVE_LAYERS}")
print("")


FP8 Inference Pipeline for Mixedbread Embedding Model
Model: mixedbread-ai/mxbai-embed-2d-large-v1
FP8 Enabled: True
Adaptive Layers: 22



## Step 5: Load Model and Apply FP8

Load the model, keep only the first `ADAPTIVE_LAYERS` encoder layers, move it to the GPU, and try to configure FP8 inference.


In [5]:
# Load the model
logger.info("Loading model...")
model = SentenceTransformer(MODEL_NAME)

# Apply adaptive layers
if hasattr(model[0].auto_model, 'encoder') and hasattr(model[0].auto_model.encoder, 'layer'):
    original_layers = len(model[0].auto_model.encoder.layer)
    model[0].auto_model.encoder.layer = model[0].auto_model.encoder.layer[:ADAPTIVE_LAYERS]
    logger.info(f"✓ Adaptive layers: {original_layers} -> {len(model[0].auto_model.encoder.layer)}")

# Move to CUDA if available
if torch.cuda.is_available():
    model = model.to('cuda')
    logger.info(f"✓ Model moved to CUDA: {torch.cuda.get_device_name(0)}")

# Setup FP8 inference
if ENABLE_FP8:
    model = setup_fp8_inference(model)


2025-11-26 18:19:21.132 | INFO     | __main__:<cell line: 0>:2 - Loading model...
2025-11-26 18:19:34.875 | INFO     | __main__:<cell line: 0>:9 - ✓ Adaptive layers: 24 -> 22
2025-11-26 18:19:34.880 | INFO     | __main__:<cell line: 0>:14 - ✓ Model moved to CUDA: Tesla T4
2025-11-26 18:19:34.881 | INFO     | __main__:setup_fp8_inference:16 - Setting up FP8 inference with torchao...
2025-11-26 18:19:34.882 | DEBUG    | __main__:setup_fp8_inference:54 - Method 1 (torchao.replace) not available: No module named 'torchao.replace'
2025-11-26 18:19:38.034 | INFO     | __main__:setup_fp8_inference:76 - ✓ Using torch.compile (may use FP8 if hardware supports)
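The log shows Method 1 failing (`No module named 'torchao.replace'`) and Method 3 being applied, so this run compiled the model with `torch.compile` but swapped in no FP8 layers; the "may use FP8" wording of the last log message is not borne out on a T4. Instead of relying on log messages, you can count the layers that were actually converted. A minimal sketch; `count_float8_linears` is a hypothetical helper that only recognises torchao's `Float8Linear` (as produced by `convert_to_float8_training`). torchao's `quantize_`-based FP8 keeps `nn.Linear` modules and replaces their weight tensors instead, so this count would stay at zero for that route.

```python
def count_float8_linears(model) -> tuple[int, int]:
    """Return (number of torchao Float8Linear modules, total number of modules)."""
    class_names = [type(m).__name__ for m in model.modules()]
    return sum(name == "Float8Linear" for name in class_names), len(class_names)


# Usage on the model configured above; for the recorded run this reports 0 FP8 layers:
# n_fp8, n_total = count_float8_linears(model[0].auto_model)
# print(f"{n_fp8} of {n_total} modules are Float8Linear")
```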


## Step 6: Run Inference

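Encode a few sample sentences. In the recorded run (Tesla T4) no FP8 method was applied and `setup_fp8_inference` fell back to `torch.compile`, so these embeddings were computed in standard precision; the first call is also slow because it triggers compilation.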


In [6]:
# Test inference
logger.info("")
logger.info("Running inference test...")
test_sentences = [
    "This is a sample sentence for embedding.",
    "FP8 inference provides faster computation.",
    "Mixedbread models are efficient for embeddings.",
    "The model uses FP8 computation without weight conversion.",
    "This enables better performance on supported hardware."
]

logger.info(f"Encoding {len(test_sentences)} sentences...")
embeddings = model.encode(test_sentences, show_progress_bar=True)

logger.info("")
logger.info("✓ Inference completed successfully!")
logger.info(f"  Embedding shape: {embeddings.shape}")
logger.info(f"  Embedding dtype: {embeddings.dtype}")
logger.info(f"  Number of sentences: {len(test_sentences)}")
logger.info(f"  Embedding dimension: {embeddings.shape[1]}")

# Show sample embedding values
logger.info("")
logger.info("Sample embedding (first 10 values):")
logger.info(f"  {embeddings[0][:10]}")

print("\n" + "=" * 60)
print("FP8 Inference Pipeline - SUCCESS")
print("=" * 60)


2025-11-26 18:20:17.363 | INFO     | __main__:<cell line: 0>:2 - 
2025-11-26 18:20:17.364 | INFO     | __main__:<cell line: 0>:3 - Running inference test...
2025-11-26 18:20:17.365 | INFO     | __main__:<cell line: 0>:12 - Encoding 5 sentences...



W1126 18:20:39.260000 506 torch/_inductor/utils.py:1558] [0/0] Not enough SMs to use max_autotune_gemm mode
2025-11-26 18:20:46.749 | INFO     | __main__:<cell line: 0>:15 - 
2025-11-26 18:20:46.751 | INFO     | __main__:<cell line: 0>:16 - ✓ Inference completed successfully!
2025-11-26 18:20:46.753 | INFO     | __main__:<cell line: 0>:17 -   Embedding shape: (5, 1024)
2025-11-26 18:20:46.754 | INFO     | __main__:<cell line: 0>:18 -   Embedding dtype: float32
2025-11-26 18:20:46.755 | INFO     | __main__:<cell line: 0>:19 -   Number of sentences: 5
2025-11-26 18:20:46.759 | INFO     | __main__:<cell line: 0>:20 -   Embedding dimension: 1024


FP8 Inference Pipeline - SUCCESS
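The SUCCESS banner only means `encode` ran without errors; it does not check whether FP8 was used or how far the embeddings drift from full precision. A minimal sketch for the second check, comparing matching rows against reference embeddings from an unmodified copy of the model. `cosine_similarity` and `min_rowwise_cosine` are hypothetical helpers, and no particular similarity threshold is implied.

```python
import math


def cosine_similarity(a, b) -> float:
    """Cosine similarity of two equal-length vectors (lists or 1-D NumPy arrays)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def min_rowwise_cosine(reference, candidate) -> float:
    """Worst agreement between matching rows of two embedding matrices."""
    if len(reference) != len(candidate):
        raise ValueError("embedding matrices must have the same number of rows")
    return min(cosine_similarity(r, c) for r, c in zip(reference, candidate))


# Hypothetical usage (apply the same ADAPTIVE_LAYERS truncation to the reference model):
# reference_model = SentenceTransformer(MODEL_NAME)
# reference = reference_model.encode(test_sentences)
# print(min_rowwise_cosine(reference, embeddings))  # closer to 1.0 means less drift
```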


## Step 7: Test with Custom Sentences (Optional)

You can test the model with your own sentences here.


In [8]:
# Test with your own sentences
custom_sentences = [
    "Bengaluru is capital of karnataka",
    "Karnataka is seventh largest state in india"
]

# Encode the custom sentences
custom_embeddings = model.encode(custom_sentences, show_progress_bar=True)
print(f"Custom embeddings shape: {custom_embeddings.shape}")
print(f"First embedding sample: {custom_embeddings[0][:5]}")



W1126 18:25:24.037000 506 torch/fx/experimental/symbolic_shapes.py:6833] [0/1] _maybe_guard_rel() was called on non-relation expression Eq(s18, s43) | Eq(s43, 1)
W1126 18:25:24.044000 506 torch/fx/experimental/symbolic_shapes.py:6833] [0/1] _maybe_guard_rel() was called on non-relation expression Eq(s41, s53) | Eq(s53, 1)


Custom embeddings shape: (2, 1024)
First embedding sample: [-0.21950784  0.13991517 -0.43079486 -0.4858035  -1.291174  ]
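A common next step with custom embeddings is retrieval: rank a set of documents by cosine similarity to a query embedding. A minimal pure-Python sketch; `rank_by_similarity` and the example query are hypothetical, and recent sentence-transformers releases also provide `model.similarity` for computing the scores themselves.

```python
import math


def rank_by_similarity(query, documents) -> list[int]:
    """Return document indices sorted from most to least cosine-similar to the query."""
    def cosine(a, b):
        dot = sum(x * y for x, y in zip(a, b))
        return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

    scores = [cosine(query, doc) for doc in documents]
    return sorted(range(len(documents)), key=lambda i: scores[i], reverse=True)


# Hypothetical usage, treating the custom sentences as a tiny corpus:
# query_embedding = model.encode(["What is the capital of Karnataka?"])[0]
# print(rank_by_similarity(query_embedding, custom_embeddings))
```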


## Notes

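- **FP8 Support**: FP8 inference requires a GPU with FP8 tensor cores (compute capability 8.9 or newer, e.g., NVIDIA L4 or H100). The Tesla T4 used in the recorded run (compute capability 7.5) has none, so FP8 was not applied there
- **Model Weights**: The model weights remain in their original precision (FP32 as loaded here)
- **Computation**: When an FP8 path applies, the linear-layer matrix multiplications run in FP8; other operations such as layer norm and softmax stay in higher precision
- **Fallback**: If no FP8 method applies, `setup_fp8_inference` falls back to `torch.compile` in the model's original precision (as in the recorded run), or to the unmodified model if compilation also fails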
- **Adaptive Layers**: You can adjust `ADAPTIVE_LAYERS` (recommended: 20-24) to balance speed and quality
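Adaptive layers and FP8 are both speed/quality trade-offs, but the notebook does not time anything. A minimal timing sketch; `benchmark` and its `warmup`/`repeats` parameters are hypothetical. Warm-up calls are excluded because the first call through a `torch.compile`d model includes compilation (the first `encode` in the log above took about 29 seconds). `model.encode` returns a NumPy array by default, which copies results back to the CPU, so the GPU work has finished by the time each call returns.

```python
import statistics
import time
from typing import Callable


def benchmark(fn: Callable[[], object], *, warmup: int = 2, repeats: int = 5) -> float:
    """Return the median wall-clock seconds per call of fn(), excluding warm-up calls."""
    for _ in range(warmup):
        fn()  # first calls may trigger torch.compile / CUDA-graph capture; not timed
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


# Hypothetical usage, e.g. once per ADAPTIVE_LAYERS setting or with and without FP8:
# seconds = benchmark(lambda: model.encode(test_sentences))
# print(f"median encode time: {seconds * 1000:.1f} ms")
```

Using the median rather than the mean keeps a single slow call (for example a recompilation triggered by a new batch shape, as in the Step 7 warnings) from dominating the result.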
