<a href="https://colab.research.google.com/github/linhkid/batch_invariant_ops_extra/blob/main/batch_invariant_extra.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
!pip install -e .

Obtaining file:///content/drive/MyDrive/Workarchive/Experiments/batch_invariant_ops
  Installing build dependencies ... done
  Checking if build backend supports build_editable ... done
  Getting requirements to build editable ... done
  Preparing editable metadata (pyproject.toml) ... done
Building wheels for collected packages: batch-invariant-ops
  Building editable for batch-invariant-ops (pyproject.toml) ... done
  Created wheel for batch-invariant-ops: filename=batch_invariant_ops-0.1.0-0.editable-py3-none-any.whl size=4026 sha256=f2d6c243a009d65e5921cb10d77e6cd0a6004eb53f4061fcca25a228b22865d6
  Stored in directory: /tmp/pip-ephem-wheel-cache-jjmggupv/wheels/12/93/97/6a8f2178a73349082bbf89122de6424efd891a2f124c8cac45
Successfully built batch-invariant-ops
Installing collected packages: batch-invariant-ops
Successfully installed batch-invariant-ops-0.1.0


In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from batch_invariant_ops import set_batch_invariant_mode

# Load a model (e.g., GPT-2)
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model.eval()

# Prepare input
text = "The future of AI is"
inputs = tokenizer(text, return_tensors="pt")

# Run inference with batch-invariant mode for deterministic results
with set_batch_invariant_mode():
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=50,
            do_sample=False,  # Greedy decoding; temperature only applies when sampling
            pad_token_id=tokenizer.eos_token_id,  # GPT-2 defines no pad token
        )

result = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(result)

For a classification task:


In [None]:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from batch_invariant_ops import set_batch_invariant_mode

# Load a sentiment analysis model
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
model.eval()

# Prepare batch with repeated inputs (to test batch invariance)
text = "This movie is fantastic!"
inputs = tokenizer([text] * 4, return_tensors="pt", padding=True)

# Enable batch-invariant mode for deterministic results
with set_batch_invariant_mode():
    with torch.no_grad():
        outputs = model(**inputs)
        # All items in batch should have identical logits
        logits = outputs.logits

# Check that every batch item matches the first one exactly (no tolerance)
print(f"All outputs identical: {(logits == logits[0]).all().item()}")
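The cell above repeats one sentence inside a single batch, so it compares rows of the same forward pass. A stricter check runs the model twice with different batch sizes, the same way `test_batch_invariance` does for `torch.mm` further down. The sketch below assumes that "batch-invariant" means `fn(batch[:1])` equals `fn(batch)[:1]` exactly. The helper name `max_batch_difference` and the toy `nn.Linear` layer are illustrative and not part of `batch_invariant_ops`.

```python
import torch


def max_batch_difference(fn, batch):
    """Largest absolute gap between fn(batch[:1]) and fn(batch)[:1].

    A return value of 0.0 means the first item got the same result
    whether it was processed alone or together with the rest.
    """
    with torch.no_grad():
        alone = fn(batch[:1])
        together = fn(batch)[:1]
    return (alone - together).abs().max().item()


# Toy stand-in for a real model; any callable on a batched tensor works.
torch.manual_seed(0)
layer = torch.nn.Linear(16, 8)
x = torch.randn(4, 16)
print(f"Max difference: {max_batch_difference(layer, x)}")
```

To apply this to the sentiment model, one option is to pass `lambda ids: model(input_ids=ids).logits` as `fn` and `inputs["input_ids"]` as `batch`. That works here only because every row has the same length, so no padding mask is needed.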

In [None]:
import torch
from batch_invariant_ops import set_batch_invariant_mode
torch.set_default_device('cuda')

# Enter the mode once up front so any one-time logging happens before the test
with set_batch_invariant_mode(True):
    pass

def test_batch_invariance():
    B, D = 2048, 4096
    a = torch.linspace(-100, 100, B*D).reshape(B, D)
    b = torch.linspace(-100, 100, D*D).reshape(D, D)

    # Method 1: Matrix-vector multiplication (batch size 1)
    out1 = torch.mm(a[:1], b)

    # Method 2: Matrix-matrix multiplication, then slice (full batch)
    out2 = torch.mm(a, b)[:1]

    # Check if results are identical
    diff = (out1 - out2).abs().max()
    print(f"Difference: {diff.item()}")
    return diff.item() == 0

# Test with standard PyTorch (likely to show differences)
print("Standard PyTorch:")
with set_batch_invariant_mode(False):
    is_deterministic = test_batch_invariance()
    print(f"Deterministic: {is_deterministic}")

# Test with batch-invariant operations
print("\nBatch-Invariant Mode:")
with set_batch_invariant_mode(True):
    is_deterministic = test_batch_invariance()
    print(f"Deterministic: {is_deterministic}")
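Why would slicing before or after the matmul change the result? A likely reason, stated here as an assumption and not read from the kernels themselves, is that a matmul kernel may split its inner reduction differently depending on the batch size, and floating-point addition is not associative. The pure-Python sketch below shows only that second fact: the same four numbers summed in one chunk and in two chunks give different totals. The function name `chunked_sum` is hypothetical.

```python
def chunked_sum(values, chunk_size):
    """Sum `values` chunk by chunk, then add up the partial sums.

    Explicit loops are used instead of the built-in sum(), which may
    apply compensated summation to floats on recent Python versions.
    """
    partials = []
    for start in range(0, len(values), chunk_size):
        acc = 0.0
        for v in values[start:start + chunk_size]:
            acc += v
        partials.append(acc)

    total = 0.0
    for p in partials:
        total += p
    return total


values = [1e16, 1.0, -1e16, 1.0]  # the exact sum is 2.0

one_chunk = chunked_sum(values, 4)   # ((1e16 + 1) - 1e16) + 1
two_chunks = chunked_sum(values, 2)  # (1e16 + 1) + (-1e16 + 1)

print(f"One chunk:  {one_chunk}")
print(f"Two chunks: {two_chunks}")
print(f"Identical:  {one_chunk == two_chunks}")
```

Neither order recovers the exact sum, and they disagree with each other. A batch-invariant kernel avoids this by keeping the reduction order for each row fixed regardless of how many rows are in the batch.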

In [None]:
!python test_attention_experiment.py --implementation both

In [None]:
!python test_mm_and_bmm.py

In [8]:
#!/usr/bin/env python3
"""
A comprehensive analysis suite to demonstrate the effects of torch.mm
batch-dependency across various conditions in a transformer architecture.

This script includes experiments for:
1. Basic forward pass accumulation in a single layer.
2. The impact on mixed-precision casting (FP16, BF16).
3. The interaction with INT8 quantization on a supported submodule (Linear layers).
4. Error accumulation across a deep, multi-layer network.
5. The effect on non-linear activations (GELU) and gradient flow (backward pass).
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
# The batch_invariant_ops library is assumed to be in the same directory or installed
from batch_invariant_ops import set_batch_invariant_mode
from torch.ao.quantization import get_default_qconfig

# ======================================================================
# Model Definitions
# ======================================================================

class SimpleTransformerLayer(nn.Module):
    """
    A simple transformer layer that uses multiple torch.mm operations.
    Updated to store intermediate activations and FFN inputs.
    """
    def __init__(self, d_model=768, num_heads=12, d_ff=3072):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.q_linear = nn.Linear(d_model, d_model, bias=False)
        self.k_linear = nn.Linear(d_model, d_model, bias=False)
        self.v_linear = nn.Linear(d_model, d_model, bias=False)
        self.out_linear = nn.Linear(d_model, d_model, bias=False)

        self.ff1 = nn.Linear(d_model, d_ff, bias=False)
        self.ff2 = nn.Linear(d_ff, d_model, bias=False)

        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

        self.intermediate_activation = None
        self.ffn_input = None # Store input to FFN for quantization experiment

    def forward(self, x):
        batch_size, seq_len, d_model = x.shape

        q = self.q_linear(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        k = self.k_linear(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        v = self.v_linear(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.d_k ** 0.5)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)

        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        out = self.out_linear(out)

        x_after_attn = self.ln1(x + out)
        self.ffn_input = x_after_attn

        ff_intermediate = F.gelu(self.ff1(x_after_attn))
        self.intermediate_activation = ff_intermediate
        ff_out = self.ff2(ff_intermediate)
        x_final = self.ln2(x_after_attn + ff_out)

        return x_final

class QuantizableLinearBlock(nn.Module):
    """
    An isolated, minimal block containing only quantizable nn.Linear layers.
    This avoids unsupported ops like LayerNorm and GELU during quantization.
    """
    def __init__(self, d_model=768, d_ff=3072):
        super().__init__()
        self.ff1 = nn.Linear(d_model, d_ff, bias=False)
        self.ff2 = nn.Linear(d_ff, d_model, bias=False)
        # Use the torch.ao.quantization namespace, matching the prepare/convert calls below
        self.quant = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        # We don't apply GELU here to keep the module simple for quantization
        x = self.ff1(x)
        x = self.ff2(x)
        x = self.dequant(x)
        return x

# ======================================================================
# Experiment Functions
# ======================================================================

def test_accumulation_in_transformer():
    print("\n" + "="*70)
    print("ACCUMULATION EXPERIMENT: Large Transformer Layer")
    print("="*70)

    torch.set_default_device('cuda')
    model = SimpleTransformerLayer(d_model=1536, num_heads=32).eval()

    batch_size, seq_len, d_model = 64, 256, 1536
    x = torch.linspace(-1000, 1000, batch_size * seq_len * d_model).reshape(batch_size, seq_len, d_model)

    print(f"Model: SimpleTransformerLayer (d_model={d_model}, heads=32)")
    print(f"Total parameters: ~{sum(p.numel() for p in model.parameters())/1e6:.1f}M")

    print("\n1. Standard PyTorch Mode:")
    print("-" * 50)
    with set_batch_invariant_mode(False):
        model(x[:1])
        ffn_input_single_std = model.ffn_input.detach().clone()
        model(x)
        ffn_input_batch_std = model.ffn_input.detach().clone()
    # Compare the pre-FFN activations; these are what the later experiments consume
    diff = (ffn_input_single_std - ffn_input_batch_std[:1]).abs()
    print(f"  Max difference (pre-FFN): {diff.max().item():.6f}")
    print(f"  Result: {'✗ Batch-dependent' if diff.max().item() > 1e-6 else '✓ Batch-invariant'}")

    print("\n2. Batch-Invariant Mode:")
    print("-" * 50)
    with set_batch_invariant_mode(True):
        model(x[:1])
        ffn_input_single_inv = model.ffn_input.detach().clone()
        model(x)
        ffn_input_batch_inv = model.ffn_input.detach().clone()
    diff_inv = (ffn_input_single_inv - ffn_input_batch_inv[:1]).abs()
    print(f"  Max difference (pre-FFN): {diff_inv.max().item():.6f}")
    print(f"  Result: {'✓ Batch-invariant' if diff_inv.max().item() < 1e-6 else '✗ Batch-dependent'}")

    return ffn_input_batch_std, ffn_input_batch_inv

def test_mixed_precision(ffn_input_std, ffn_input_inv):
    print("\n" + "="*70)
    print("MIXED PRECISION EXPERIMENT")
    print("="*70)
    for dtype, name in [(torch.float32, "FP32"), (torch.float16, "FP16"), (torch.bfloat16, "BF16")]:
        print(f"\n{name} precision:")
        out_std_prec, out_inv_prec = ffn_input_std.to(dtype), ffn_input_inv.to(dtype)
        final_diff = (out_std_prec - out_inv_prec).abs().max().item()
        print(f"  Final difference after casting to {name}: {final_diff:.6f}")

def test_quantization_effects(ffn_input_std, ffn_input_inv):
    print("\n" + "="*70)
    print("QUANTIZATION EXPERIMENT ON LINEAR BLOCK (Final)")
    print("="*70)

    torch.set_default_device('cpu')
    ffn_input_std = ffn_input_std.to('cpu')
    ffn_input_inv = ffn_input_inv.to('cpu')

    d_model, d_ff = ffn_input_std.shape[-1], 3072*2

    print(f"Isolating the Linear Layers (d_model={d_model}, d_ff={d_ff})")

    # Get FP32 baseline difference from the isolated linear block
    fp32_linear_block = QuantizableLinearBlock(d_model=d_model, d_ff=d_ff).eval()
    out_std_fp32 = fp32_linear_block(ffn_input_std)
    out_inv_fp32 = fp32_linear_block(ffn_input_inv)
    fp32_diff = (out_std_fp32 - out_inv_fp32).abs().max().item()
    print(f"\nFP32 baseline difference from Linear block: {fp32_diff:.6f}")

    # Quantize the Linear block and test again
    print("\nQuantizing the Linear block:")
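    quant_linear_block = QuantizableLinearBlock(d_model=d_model, d_ff=d_ff).eval()
    # Reuse the FP32 block's weights so the FP32 and INT8 differences are comparable
    quant_linear_block.load_state_dict(fp32_linear_block.state_dict())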
    quant_linear_block.qconfig = get_default_qconfig("x86")
    torch.ao.quantization.prepare(quant_linear_block, inplace=True)
    quant_linear_block(ffn_input_std) # Calibration
    torch.ao.quantization.convert(quant_linear_block, inplace=True)

    out_std_quant = quant_linear_block(ffn_input_std)
    out_inv_quant = quant_linear_block(ffn_input_inv)
    quant_diff = (out_std_quant - out_inv_quant).abs().max().item()

    print(f"  Difference after INT8 quantization:    {quant_diff:.6f}")
    print("  → Batch differences are affected by the noise from quantization.")

def test_deep_accumulation():
    print("\n" + "="*70)
    print("DEEP ACCUMULATION EXPERIMENT: 8-Layer Transformer")
    print("="*70)

    torch.set_default_device('cuda')
    num_layers, d_model, heads = 8, 768, 24
    layers = nn.ModuleList([SimpleTransformerLayer(d_model=d_model, num_heads=heads) for _ in range(num_layers)]).eval()

    batch_size, seq_len = 32, 128
    x = torch.linspace(-5000, 5000, batch_size * seq_len * d_model).reshape(batch_size, seq_len, d_model)

    print(f"Model: {num_layers} transformer layers (d_model={d_model}, heads={heads})")

    with set_batch_invariant_mode(False):
        x_single, x_batch = x[:1].clone(), x.clone()
        for i, layer in enumerate(layers):
            x_single, x_batch = layer(x_single), layer(x_batch)
            diff = (x_single - x_batch[:1]).abs().max().item()
            print(f"Layer {i+1} difference (standard): {diff:.6f}")

def test_gradient_and_activation_effects():
    print("\n" + "="*70)
    print("GRADIENT & ACTIVATION EXPERIMENT")
    print("="*70)

    torch.set_default_device('cuda')
    model = SimpleTransformerLayer(d_model=768, num_heads=16).train()

    batch_size, seq_len, d_model = 32, 128, 768
    x = torch.linspace(-100, 100, batch_size * seq_len * d_model).reshape(batch_size, seq_len, d_model)
    x.requires_grad = True

    print("\n1. Standard PyTorch Mode:")
    print("-" * 50)
    with set_batch_invariant_mode(False):
        model.zero_grad()
        out_single = model(x[:1]); out_single.sum().backward()
        grad_single = x.grad[:1].detach().clone()

        x.grad.zero_()
        out_batch = model(x); out_batch.sum().backward()
        grad_batch_first = x.grad[:1].detach().clone()

    grad_diff_std = (grad_single - grad_batch_first).abs().max().item()
    print(f"  Max gradient difference: {grad_diff_std:.6f}")
    print(f"  Result: {'✗ Batch-dependent' if grad_diff_std > 1e-6 else '✓ Batch-invariant'}")

    print("\n2. Batch-Invariant Mode:")
    print("-" * 50)
    with set_batch_invariant_mode(True):
        x.grad.zero_(); model.zero_grad()
        out_single_inv = model(x[:1]); out_single_inv.sum().backward()
        grad_single_inv = x.grad[:1].detach().clone()

        x.grad.zero_()
        out_batch_inv = model(x); out_batch_inv.sum().backward()
        grad_batch_first_inv = x.grad[:1].detach().clone()

    grad_diff_inv = (grad_single_inv - grad_batch_first_inv).abs().max().item()
    print(f"  Max gradient difference: {grad_diff_inv:.6f}")
    print(f"  Result: {'✓ Batch-invariant' if grad_diff_inv < 1e-6 else '✗ Batch-dependent'}")

def main():
    use_cuda = torch.cuda.is_available()
    device_name = torch.cuda.get_device_name(0) if use_cuda else 'CPU'
    print(f"Using device: {device_name}")
    print(f"PyTorch version: {torch.__version__}")

    if use_cuda:
        ffn_in_std, ffn_in_inv = test_accumulation_in_transformer()
        test_mixed_precision(ffn_in_std, ffn_in_inv)
        test_quantization_effects(ffn_in_std, ffn_in_inv)
        test_deep_accumulation()
        test_gradient_and_activation_effects()
    else:
        print("\nCUDA not available. Skipping GPU-dependent experiments.")

if __name__ == "__main__":
    main()



Using device: NVIDIA A100-SXM4-80GB
PyTorch version: 2.8.0+cu126

ACCUMULATION EXPERIMENT: Large Transformer Layer
Model: SimpleTransformerLayer (d_model=1536, heads=32)
Total parameters: ~18.9M

1. Standard PyTorch Mode:
--------------------------------------------------
  Max difference (pre-FFN): 0.000010
  Result: ✗ Batch-dependent

2. Batch-Invariant Mode:
--------------------------------------------------
  Max difference (pre-FFN): 0.000000
  Result: ✓ Batch-invariant

MIXED PRECISION EXPERIMENT

FP32 precision:
  Final difference after casting to FP32: 0.005751

FP16 precision:
  Final difference after casting to FP16: 0.005859

BF16 precision:
  Final difference after casting to BF16: 0.015625

QUANTIZATION EXPERIMENT ON LINEAR BLOCK (Final)
Isolating the Linear Layers (d_model=1536, d_ff=6144)

FP32 baseline difference from Linear block: 0.001808

Quantizing the Linear block:


For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  torch.ao.quantization.prepare(quant_linear_block, inplace=True)

  Difference after INT8 quantization:    0.034143
  → Batch differences are affected by the noise from quantization.

DEEP ACCUMULATION EXPERIMENT: 8-Layer Transformer
Model: 8 transformer layers (d_model=768, heads=24)
Layer 1 difference (standard): 0.000002
Layer 2 difference (standard): 0.000003
Layer 3 difference (standard): 0.000003
Layer 4 difference (standard): 0.000003
Layer 5 difference (standard): 0.000003
Layer 6 difference (standard): 0.000003
Layer 7 difference (standard): 0.000003
Layer 8 difference (standard): 0.000003

GRADIENT & ACTIVATION EXPERIMENT

1. Standard PyTorch Mode:
--------------------------------------------------
  Max gradient difference: 0.000000
  Result: ✓ Batch-invariant

2. Batch-Invariant Mode:
--------------------------------------------------
  Max gradient difference: 0.000000
  Result: ✓ Batch-invariant
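
A note on reading the mixed-precision numbers above: the BF16 difference of 0.015625 is exactly 2^-6, which is the gap between neighbouring BF16 values in the range [2, 4). That suggests, without proving it, that the two tensors differ by a single rounding step after the cast. The sketch below computes that gap from the number of explicit mantissa bits in each format. The helper name `spacing_at` and the sample magnitude 2.5 are illustrative choices.

```python
import math

# Explicit mantissa (fraction) bits of each floating-point format.
MANTISSA_BITS = {"FP32": 23, "FP16": 10, "BF16": 7}


def spacing_at(value, mantissa_bits):
    """Gap between adjacent representable numbers near `value`.

    Valid for normal (non-subnormal) magnitudes: the gap is
    2 ** (exponent - mantissa_bits), where 2 ** exponent <= |value|.
    """
    _, exp = math.frexp(abs(value))  # abs(value) = m * 2**exp with 0.5 <= m < 1
    exponent = exp - 1
    return 2.0 ** (exponent - mantissa_bits)


for name, bits in MANTISSA_BITS.items():
    print(f"{name}: spacing near 2.5 = {spacing_at(2.5, bits)}")
```

The coarser the format, the larger the smallest visible difference, so a gap that is tiny in FP32 can either vanish or be amplified to a full rounding step once the activations are cast down.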
