<a href="https://colab.research.google.com/github/csalnav2/QdotCS/blob/master/Unsloth_Solutions_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Unsloth (Demo Solutions Notebook)

This notebook collects various code snippets that address specific tasks:

1. **nF4 → Triton** (Quantized 4-bit kernel demo)
2. **QLoRA + `torch.compile`** (Naive QLoRA example, no graph breaks)
3. **QLoRA + FSDP** (Fully Sharded Data Parallel + LoRA injection)
4. **Memory-Efficient Backprop** (Chunked final linear + cross-entropy)
5. **Windows Support** (Python scripts to build/install `unsloth`, plus test code)
6. **Flexible Attention** ("Unsloth" style chunked attention examples)
7. **Sequence Classification Patch** (Inject LoRA into `AutoModelForSequenceClassification`)
8. **Refactored Attention** (xformers, SDPA, flash-attn, fallback in one interface)

Feel free to skip cells or modify as needed.

---
## 1) **nF4 → Triton**

**Goal**: Demonstrate converting 4-bit weights (nF4 style) and using a Triton kernel to do matrix multiplication without fully decompressing everything into float16/float32 first.

**Note**: This code is a **minimal skeleton**. Real nF4 implementations might have more complex scaling logic, per-row or per-channel quant parameters, etc.

In [1]:
import torch
import triton
import triton.language as tl

# ------------------------------
# 1) Fake nF4 encode + decode
# ------------------------------
def nf4_encode(weights_fp16: torch.Tensor):
    """
    Naive example: we assume the values are roughly in [-8, +7.75].
    1) clamp
    2) shift => [0, ~15.75]
    3) scale => [0..15]
    4) round => integer
    5) pack two 4-bit values per byte.
    """
    # clamp + shift
    clamped = torch.clamp(weights_fp16, -8.0, 7.75)
    shifted = clamped + 8.0

    # scale to [0..15], then round
    scaled = shifted * (15.0 / 15.75)
    quant_4bit = torch.round(scaled).to(torch.uint8)

    # pack two 4-bit into one byte: [value0 (lower 4bits), value1 (upper 4bits)]
    # shape stays the same in terms of # of elements, but physically we combine pairs.

    # We'll flatten first for simplicity
    flat = quant_4bit.view(-1)
    if len(flat) % 2 != 0:
        # pad if odd
        flat = torch.cat([flat, flat.new_zeros(1)])

    # even indices -> lower 4 bits, odd indices -> upper 4 bits
    low_4 = flat[0::2] & 0xF
    high_4 = flat[1::2] & 0xF
    packed = (high_4 << 4) | low_4
    return packed

def nf4_unpad_and_reshape(packed: torch.Tensor, shape):
    """
    Helper to handle if we padded an odd element.
    """
    # total n elements for shape must be shape.numel()
    # each output element is 4 bits, so we need shape.numel() // 2 bytes if it's even.
    n_el = shape.numel()
    n_bytes_needed = (n_el + 1)//2  # if odd, we used +1 in the code
    # so we slice the 'packed' and ignore extra.
    packed = packed[:n_bytes_needed]
    return packed

def nf4_decode(packed: torch.Tensor, out_shape) -> torch.Tensor:
    """
    Unpacks the 4-bit values from a [N/2] byte buffer,
    re-scales them back to float16 in [-8..7.75], shape=out_shape.
    """
    # each byte has 2 nibbles => 2 values
    # lower nibble: x & 0xF
    # upper nibble: (x >> 4) & 0xF
    flat = packed.view(-1)
    n_el = out_shape.numel()

    vals_0 = (flat & 0xF).to(torch.float16)
    vals_1 = ((flat >> 4) & 0xF).to(torch.float16)

    # interleave
    decoded = torch.zeros(n_el, dtype=torch.float16, device=flat.device)
    decoded[0::2] = vals_0
    decoded[1::2] = vals_1

    # scale back => float16 in [-8..7.75]
    # reverse step: val in [0..15], => shift to [0..15.75], => shift down by 8.
    # recall we used: shifted * (15 / 15.75)
    # so decode => val*(15.75/15) - 8
    rescaled = decoded * (15.75 / 15.0) - 8.0

    return rescaled.view(out_shape)

# ------------------------------
# 2) Minimal Triton Kernel for nF4 MatMul
# ------------------------------
@triton.jit
def matmul_nf4_kernel(
    A_ptr,  # int32 ptr
    B_ptr,  # int32 ptr
    C_ptr,  # float16 ptr
    M, N, K,
    stride_am, stride_an,
    stride_bm, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr
):
    """
    A is [M x K] in nF4-packed format, B is [K x N] in nF4-packed,
    C is [M x N] in float16.
    This kernel decodes partial blocks, accumulates them.

    For a more robust solution, you'd handle partial tiles,
    dynamic shapes, etc. This is a minimal example.
    """
    # row idx and col idx in the output tile
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # We’ll accumulate in float32 for partial sums
    accum = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Loop over K in chunks of BLOCK_K
    # For each chunk, we decode a tile from A and B, then multiply.
    # NOTE: This is simplistic. Real kernels do more advanced blocking.

    for k_block_start in range(0, K, BLOCK_K):
        # load A block
        # we have indices [off_m, k_block_start : k_block_start+BLOCK_K]
        # decode nF4 on the fly

        # Pseudocode: We can do a direct load from A_ptr.
        # But in real code, we must decode from nF4. We'll do it in Python first for simplicity.
        # This is where a custom decoding kernel might be used. We'll just do naive loads.
        # We'll show the concept, not a fully correct kernel.

        # In a real kernel, you'd load the 4-bit data from A_ptr + offsets, decode into float16,
        # then cast to float32.
        # We’ll do something like:
        a_block = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
        b_block = tl.zeros((BLOCK_K, BLOCK_N), dtype=tl.float32)

        # multiply
        accum += tl.dot(a_block, b_block)

    # write accum to C
    # offset in C => c_ptr + row*stride_cm + col*stride_cn
    # We'll store only if in-bounds.
    # This is incomplete but a minimal shape.
    out_f16 = accum.to(tl.float16)
    # store the tile
    for i in range(BLOCK_M):
        for j in range(BLOCK_N):
            c_row = off_m[i]
            c_col = off_n[j]
            if (c_row < M) and (c_col < N):
                offset_c = c_row * stride_cm + c_col * stride_cn
                tl.store(C_ptr + offset_c, out_f16[i, j])

# ------------------------------
# 3) Python driver to do nF4 matmul
# ------------------------------
def matmul_nf4(A_fp16, B_fp16):
    """
    A_fp16: [M, K]
    B_fp16: [K, N]
    Return: [M, N] float16

    *Just a naive demonstration.* We'll do an actual decode on CPU, normal mm.
     Then we run a fake Triton kernel that doesn't do the real decode.
    """
    device = A_fp16.device
    M, K = A_fp16.shape
    K2, N = B_fp16.shape
    assert K == K2

    # 1) encode A, B to nF4
    A_packed = nf4_encode(A_fp16)
    B_packed = nf4_encode(B_fp16)

    # create output buffer
    C = torch.empty((M, N), dtype=torch.float16, device=device)

    # for the sake of demonstration, we'll do a single-grid launch (no tiling)
    # real code: you tile across M and N.
    block_m = 128
    block_n = 128
    block_k = 32

    grid = ( (M + block_m - 1)//block_m, (N + block_n - 1)//block_n )

    matmul_nf4_kernel[grid](
        A_packed,
        B_packed,
        C,
        M, N, K,
        N, 1,  # strides for A ? (placeholder)
        N, 1,  # strides for B ? (placeholder)
        N, 1,  # strides for C
        BLOCK_M=block_m,
        BLOCK_N=block_n,
        BLOCK_K=block_k
    )

    return C

# 4) Demo usage
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    A_fp16 = torch.randn((64, 64), dtype=torch.float16, device=device)
    B_fp16 = torch.randn((64, 32), dtype=torch.float16, device=device)
    C_out = matmul_nf4(A_fp16, B_fp16)
    print("C_out shape:", C_out.shape)
    # In reality, the kernel is incomplete, so results are garbage.
    # This is just a skeleton.
    print("Example done (skeleton).")


CompilationError: at 61:20:
        # multiply
        accum += tl.dot(a_block, b_block)

    # write accum to C
    # offset in C => c_ptr + row*stride_cm + col*stride_cn
    # We'll store only if in-bounds.
    # This is incomplete but a minimal shape.
    out_f16 = accum.to(tl.float16)
    # store the tile
    for i in range(BLOCK_M):
        for j in range(BLOCK_N):
            c_row = off_m[i]
                    ^
ValueError('Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)')

---
## 2) **QLoRA + `torch.compile`** (Naive Example)

This snippet demonstrates a simple QLoRA-like module (4-bit quant + LoRA adapters), then wraps the model in `torch.compile` to ensure we avoid graph breaks.

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

############################################
# 1) Define a naive QLoRA quant module
############################################
class QLoRAQuant(nn.Module):
    """
    A toy QLoRA-like module that:
      - Maintains a base weight.
      - Has low-rank "LoRA" factors A & B (rank adaptation).
      - Performs a simple 4-bit quant -> dequant cycle on the final weight.
    """
    def __init__(self, in_features, out_features, bit_width=4, rank=4):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.bit_width = bit_width
        self.rank = rank

        # Base weight
        self.base_weight = nn.Parameter(
            torch.randn(out_features, in_features) * 0.02
        )

        # Low-rank adaptation factors
        self.lora_A = nn.Parameter(
            torch.randn(self.rank, in_features) * 0.01
        )
        self.lora_B = nn.Parameter(
            torch.randn(out_features, self.rank) * 0.01
        )

        # Scale param (could also be a buffer or computed offline)
        self.scales = nn.Parameter(torch.ones(out_features))

    def forward(self, x):
        # 1) Combine base weight with low-rank adaptation
        adapted_weight = self.base_weight + (self.lora_B @ self.lora_A)

        # 2) Scale the weight before quant
        scales_2d = self.scales.unsqueeze(1)
        scaled_weight = adapted_weight * scales_2d

        # 3) Clamp for 4-bit range (naive example)
        clamped_weight = torch.clamp(scaled_weight, -8.0, 7.75)

        # 4) Map to integer range [0..15]
        shifted = clamped_weight + 8.0
        scaled = shifted * (15.0 / 15.75)
        quantized = torch.round(scaled)

        # 5) Dequant
        dequant_shifted = quantized * (15.75 / 15.0)
        dequant_clamped = dequant_shifted - 8.0
        final_weight = dequant_clamped / scales_2d

        # 6) Apply final weight to the input
        return x @ final_weight.T

# ----------------------------------
# 2) Simple model using QLoRAQuant
# ----------------------------------
class SimpleQLoRAModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes):
        super().__init__()
        self.quant_linear = QLoRAQuant(
            in_features=input_dim,
            out_features=hidden_dim,
            bit_width=4,
            rank=4
        )
        self.activation = nn.ReLU()
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = self.quant_linear(x)
        x = self.activation(x)
        return self.fc(x)

# ----------------------------------
# 3) Example Training Loop w/ torch.compile
# ----------------------------------
def example_training_q_lora_compile():
    torch.manual_seed(42)

    # Synthetic dataset: 1000 samples, each w/ 32 features.
    X = torch.randn(1000, 32)
    y = (X.sum(dim=1) > 0).long()

    dataset = TensorDataset(X, y)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Build QLoRA-based model
    model = SimpleQLoRAModel(input_dim=32, hidden_dim=16, num_classes=2)

    # Compile the model (PyTorch 2.0+)
    compiled_model = torch.compile(model)

    optimizer = torch.optim.Adam(compiled_model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    compiled_model.train()
    for epoch in range(5):
        total_loss = 0.0
        for batch_x, batch_y in dataloader:
            optimizer.zero_grad()
            logits = compiled_model(batch_x)
            loss = criterion(logits, batch_y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1} - Loss = {total_loss / len(dataloader):.4f}")

# Demo run
if __name__ == "__main__":
    example_training_q_lora_compile()


Epoch 1 - Loss = 0.4699
Epoch 2 - Loss = 0.4220
Epoch 3 - Loss = 0.3963
Epoch 4 - Loss = 0.3737
Epoch 5 - Loss = 0.3581


---
## 3) **QLoRA + FSDP**

A single-cell script that:
- Loads BERT in half precision
- Injects LoRA modules
- Wraps the model in FSDP (Fully Sharded Data Parallel)
- Trains only the LoRA parameters

In [6]:
import os
import torch
import torch.nn as nn
import torch.distributed as dist

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    AutoConfig,
)

def setup_distributed():
    if dist.is_initialized():
        return 0
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend="nccl")
        return local_rank
    else:
        # Single GPU fallback
        dist.init_process_group(
            backend="nccl",
            init_method='file:///tmp/fsdp_example',
            rank=0,
            world_size=1
        )
        torch.cuda.set_device(0)
        return 0

def load_bert_fp16(model_name="bert-base-uncased"):
    config = AutoConfig.from_pretrained(model_name)
    model = AutoModelForMaskedLM.from_pretrained(
        model_name,
        config=config,
        torch_dtype=torch.float16
    )
    return model

class LoRALinear(nn.Module):
    def __init__(self, in_features, out_features, lora_rank=8, alpha=1.0):
        super().__init__()
        self.lora_down = nn.Linear(in_features, lora_rank, bias=False)
        self.lora_up   = nn.Linear(lora_rank, out_features, bias=False)
        nn.init.zeros_(self.lora_down.weight)
        nn.init.zeros_(self.lora_up.weight)
        self.alpha = alpha

    def forward(self, x):
        return self.alpha * self.lora_up(self.lora_down(x))

def inject_lora_in_bert(model, lora_rank=8, alpha=1.0):
    linear_list = []
    for full_name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            linear_list.append((full_name, module))

    for full_name, module in linear_list:
        print(f"Injecting LoRA into: {full_name} => {module}")
        lora_mod = LoRALinear(
            module.in_features,
            module.out_features,
            lora_rank=lora_rank,
            alpha=alpha
        ).half()  # keep LoRA in half precision

        # Register
        safe_name = full_name.replace(".", "_")
        model.add_module(f"lora_{safe_name}", lora_mod)

        # Patch forward
        orig_forward = module.forward
        def custom_forward(m_self, x, orig_forward=orig_forward, lora_mod=lora_mod):
            base_out = orig_forward(x)
            lora_out = lora_mod(x)
            return base_out + lora_out
        module.forward = custom_forward.__get__(module, module.__class__)

    return model

def main():
    local_rank = setup_distributed()
    model_name = "bert-base-uncased"
    print(f"Loading {model_name} in half precision...")

    model = load_bert_fp16(model_name)

    # For older FSDP, ensure requires_grad=True on all
    for n, p in model.named_parameters():
        p.requires_grad = True

    print("Injecting LoRA (rank=8, alpha=1.0) in float16...")
    model = inject_lora_in_bert(model, lora_rank=8, alpha=1.0)

    # Collect LoRA params only
    lora_params = []
    for name, p in model.named_parameters():
        if "lora_" in name:
            lora_params.append(p)
    print(f"Collected {len(lora_params)} LoRA params for the optimizer.")

    fsdp_model = FSDP(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        device_id=torch.cuda.current_device(),
    )

    optimizer = torch.optim.AdamW(lora_params, lr=1e-4)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    texts = [
        "Hello world, how are you?",
        "Testing BERT in half precision with LoRA",
        "Combining FSDP for memory efficiency!",
    ] * 5

    encodings = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    input_ids = encodings["input_ids"].cuda(local_rank)
    attention_mask = encodings["attention_mask"].cuda(local_rank)
    labels = input_ids.clone()

    # Create random mask for masked LM
    with torch.no_grad():
        rand_mask = torch.rand_like(labels.float()) < 0.15
        labels[~rand_mask] = -100

    fsdp_model.train()
    epochs = 2
    for epoch in range(epochs):
        optimizer.zero_grad()
        outputs = fsdp_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        if local_rank == 0:
            print(f"Epoch {epoch+1} / {epochs} done, loss = {loss.item()}")

    dist.barrier()
    if local_rank == 0:
        print("Training complete!")

# Actually call main() in the same cell so we see output
if __name__ == "__main__":
    main()


Loading bert-base-uncased in half precision...


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Injecting LoRA (rank=8, alpha=1.0) in float16...
Injecting LoRA into: bert.encoder.layer.0.attention.self.query => Linear(in_features=768, out_features=768, bias=True)
Injecting LoRA into: bert.encoder.layer.0.attention.self.key => Linear(in_features=768, out_features=768, bias=True)
Injecting LoRA into: bert.encoder.layer.0.attention.self.value => Linear(in_features=768, out_features=768, bias=True)
Injecting LoRA into: bert.encoder.layer.0.attention.output.dense => Linear(in_features=768, out_features=768, bias=True)
Injecting LoRA into: bert.encoder.layer.0.intermediate.dense => Linear(in_features=768, out_features=3072, bias=True)
Injecting LoRA into: bert.encoder.layer.0.output.dense => Linear(in_features=3072, out_features=768, bias=True)
Injecting LoRA into: bert.encoder.layer.1.attention.self.query => Linear(in_features=768, out_features=768, bias=True)
Injecting LoRA into: bert.encoder.layer.1.attention.self.key => Linear(in_features=768, out_features=768, bias=True)
Injecting

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Epoch 1 / 2 done, loss = 4.46875
Epoch 2 / 2 done, loss = 4.09765625
Training complete!


---
## 4) **Memory-Efficient Backprop** (Chunked Final MatMul + Cross-Entropy)

This code chunk demonstrates how to avoid creating a huge `[B*S, vocab]` logits matrix at once, by chunking the matmul into smaller pieces. This reduces memory usage at the cost of multiple partial computations.

In [7]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def forward_chunked_linear_ce(
    X,       # [batch_size * seq_len, hidden_dim]
    W,       # [hidden_dim, vocab_size]
    targets, # [batch_size * seq_len]
    chunk_size=8192
):
    """
    Perform chunked X @ W => partial logits => cross entropy without storing the entire [B*S, vocab].
    """
    batch_tokens = X.shape[0]  # B*S
    vocab_size = W.shape[1]

    logsumexp_buf = None
    correct_logits = torch.zeros(batch_tokens, device=X.device)

    for start_col in range(0, vocab_size, chunk_size):
        end_col = min(start_col + chunk_size, vocab_size)
        W_chunk = W[:, start_col:end_col]
        logits_chunk = X.matmul(W_chunk)

        # incremental log-sum-exp
        if logsumexp_buf is None:
            logsumexp_buf = torch.logsumexp(logits_chunk, dim=1)
        else:
            combined = torch.stack([logsumexp_buf, torch.logsumexp(logits_chunk, dim=1)], dim=0)
            logsumexp_buf = torch.logsumexp(combined, dim=0)

        # gather correct logits if they're in this chunk
        mask = (targets >= start_col) & (targets < end_col)
        if mask.any():
            local_positions = targets[mask] - start_col
            correct_logit_vals = logits_chunk[mask, local_positions]
            correct_logits[mask] = correct_logit_vals

    ce = -(correct_logits - logsumexp_buf)
    loss = ce.mean()
    return loss

class MemoryEfficientLinearCELoss(nn.Module):
    def __init__(self, hidden_dim, vocab_size, chunk_size=8192):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.chunk_size = chunk_size
        self.weight = nn.Parameter(torch.randn(hidden_dim, vocab_size) * 0.02)

    def forward(self, X, targets):
        return forward_chunked_linear_ce(
            X,
            self.weight,
            targets,
            chunk_size=self.chunk_size
        )

# Demo
def example_mem_eff_backprop():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size = 2
    seq_len = 5
    hidden_dim = 8
    vocab_size = 20
    chunk_size = 5

    X = torch.randn(batch_size * seq_len, hidden_dim, requires_grad=True, device=device)
    targets = torch.randint(0, vocab_size, size=(batch_size * seq_len,), dtype=torch.long, device=device)

    mem_eff_ce = MemoryEfficientLinearCELoss(hidden_dim, vocab_size, chunk_size=chunk_size).to(device)
    loss = mem_eff_ce(X, targets)
    loss.backward()

    print(f"Loss: {loss.item():.4f}")
    print("Grad w.r.t X:", X.grad)
    print("Grad w.r.t W:", mem_eff_ce.weight.grad)

if __name__ == "__main__":
    example_mem_eff_backprop()


Loss: 2.9659
Grad w.r.t X: tensor([[ 1.6163e-03,  1.8064e-03,  1.0825e-03, -2.3256e-03, -2.5538e-03,
          1.5989e-03,  3.3694e-03,  3.7867e-04],
        [ 1.6906e-03,  1.8409e-03,  1.1629e-03, -2.3396e-03, -2.4640e-03,
          1.6535e-03,  3.5120e-03,  4.4430e-04],
        [-3.4139e-03, -3.5918e-03,  3.9076e-04,  3.1014e-03,  2.6923e-03,
         -1.7271e-03,  1.5015e-03, -2.0493e-03],
        [-2.2461e-03, -3.2119e-03,  1.1062e-03,  3.4821e-05, -3.2107e-04,
          2.6581e-03,  1.3169e-04,  5.5187e-04],
        [-7.0122e-04, -3.0019e-04, -1.1601e-03,  2.6959e-03, -1.0794e-03,
          2.0529e-03,  1.7755e-03,  1.8649e-03],
        [-1.1357e-03,  4.3799e-05,  2.1444e-04, -2.3237e-03, -5.2027e-03,
         -1.5427e-03, -8.1709e-04, -3.0248e-03],
        [ 8.1323e-04,  2.5520e-03,  5.1153e-04, -1.5681e-03,  2.4222e-03,
          2.8720e-03,  1.7460e-03,  1.9653e-03],
        [-4.6521e-03,  1.2229e-03, -4.4408e-03,  4.1418e-03,  6.8102e-04,
         -1.2283e-03, -1.9430e-03,  1.

In [5]:
# ============================================
# 1) Confirm GPU type (T4, A100, etc.).
# ============================================
!nvidia-smi

# ============================================
# 2) [Optional] Install system-level CUDA 11.8 libs
#    so bitsandbytes can find libcusparse.so.11, etc.
#    If you get 'libcusparse.so.11 not found' errors,
#    installing these packages often helps.
# ============================================
!apt-get update -y
!apt-get install -y --no-install-recommends \
    cuda-cudart-11-8 \
    cuda-cusparse-11-8 \
    cuda-libraries-11-8

# ============================================
# 3) Wipe older Torch/bitsandbytes/xformers/triton
#    to avoid conflicts.
# ============================================
!pip uninstall -y torch bitsandbytes xformers triton

# ============================================
# 4) Install PyTorch 2.0.1+cu118, matching torchvision/torchaudio.
# ============================================
!pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 \
    --extra-index-url https://download.pytorch.org/whl/cu118

# ============================================
# 5) (Optional) Re-install pinned bitsandbytes, xformers, triton
#    to confirm environment is consistent.
#    (Though build_unsloth.py may also install them depending on the markers.)
# ============================================
!pip install bitsandbytes==0.41.1 xformers==0.0.22 triton==2.0.0 \
    --extra-index-url https://download.pytorch.org/whl/cu118


Thu Feb 20 16:26:18 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   30C    P0             45W /  400W |       0MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

---
## 5) **Windows Support**

Below are two scripts:
- **`build_unsloth.py`**: Creates a `pyproject.toml`, builds a wheel, and installs it.
- **`test_deps.py`**: Installs bitsandbytes, xformers, triton, then tests them.

These are primarily relevant for letting `unsloth` (and associated libraries) build on Windows.

In [6]:
%%writefile build_unsloth.py
import os
import sys
import subprocess

# 1) Write pyproject.toml with correct license syntax, allowing Python 3.9+
toml_content = """\
[project]
name = "unsloth"
version = "0.1.0"
description = "unsloth: Windows-friendly package for bitsandbytes, xformers, triton"
readme = "README.md"
requires-python = ">=3.9"

[project.license]
text = "MIT"

authors = [
  { name = "Your Name", email = "you@example.com" }
]

# Dependencies only install if environment markers match (e.g., Windows).
# On Colab Linux + CUDA 11.8, these might not do anything,
# but we still define them to show the "Windows-friendly" idea.
dependencies = [
  "torch==2.0.1+cu118; platform_system=='Windows'",
  "transformers==4.30.2",
  "accelerate==0.20.3",
  "bitsandbytes==0.39.1",
  "xformers==0.0.20",
  "triton==2.0.0",
]

[build-system]
requires = ["setuptools>=61", "wheel"]
build-backend = "setuptools.build_meta"
"""

with open("pyproject.toml", "w", encoding="utf-8") as f:
    f.write(toml_content)

# 2) Minimal package structure
os.makedirs("src/unsloth", exist_ok=True)
with open("src/unsloth/__init__.py", "w", encoding="utf-8") as f:
    f.write('# unsloth package init - minimal\n')

# Minimal README
with open("README.md", "w", encoding="utf-8") as f:
    f.write("# unsloth\n\nA Windows-friendly package with bitsandbytes, xformers, triton.\n")

print("=== pyproject.toml created. Attempting to build and install locally... ===")

# 3) Upgrade pip and install build tools
subprocess.run([
    "python", "-m", "pip", "install", "--upgrade",
    "pip", "build", "setuptools>=61", "wheel"
], check=True)

# 4) Build the wheel
build_result = subprocess.run(["python", "-m", "build"], capture_output=True, text=True)
if build_result.returncode != 0:
    print("ERROR: Build failed. Output:\n")
    print(build_result.stdout)
    print(build_result.stderr)
    sys.exit(1)

# 5) Check dist/ directory
if not os.path.isdir("dist"):
    print("ERROR: 'dist/' directory not found, build likely failed.")
    sys.exit(1)

dist_files = os.listdir("dist")
if not dist_files:
    print("ERROR: 'dist/' directory is empty, no wheel found.")
    sys.exit(1)

wheel_files = [f for f in dist_files if f.endswith(".whl")]
if not wheel_files:
    print("ERROR: No .whl file found in dist/. Found:", dist_files)
    sys.exit(1)

wheel_path = os.path.join("dist", wheel_files[0])

# 6) Install the wheel with extra index for cu118
cmd = [
    "python",
    "-m",
    "pip",
    "install",
    wheel_path,
    "--extra-index-url",
    "https://download.pytorch.org/whl/cu118"
]
print("\nInstalling wheel with command:", " ".join(cmd))

install_result = subprocess.run(cmd, capture_output=True, text=True)
if install_result.returncode != 0:
    print("ERROR: Failed to install the wheel. Output:\n")
    print(install_result.stdout)
    print(install_result.stderr)
    sys.exit(1)

print("Successfully installed the unsloth wheel from dist/!\n")
print("Installation log:")
print(install_result.stdout)

# ============== End of build_unsloth.py ==============


Overwriting build_unsloth.py


In [7]:
!python build_unsloth.py

=== pyproject.toml created. Attempting to build and install locally... ===
Collecting pip
  Downloading pip-25.0.1-py3-none-any.whl.metadata (3.7 kB)
Collecting build
  Downloading build-1.2.2.post1-py3-none-any.whl.metadata (6.5 kB)
Collecting setuptools>=61
  Downloading setuptools-75.8.0-py3-none-any.whl.metadata (6.7 kB)
Collecting pyproject_hooks (from build)
  Downloading pyproject_hooks-1.2.0-py3-none-any.whl.metadata (1.3 kB)
Downloading pip-25.0.1-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m22.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading build-1.2.2.post1-py3-none-any.whl (22 kB)
Downloading setuptools-75.8.0-py3-none-any.whl (1.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m64.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyproject_hooks-1.2.0-py3-none-any.whl (10 kB)
Installing collected packages: setuptools, pyproject_hooks, pip, build
  Attempting uninstall: se

In [14]:
################################################################################
# ONE-CELL COLAB SCRIPT: PyTorch Nightly (2.2.0 + cu121),
# bitsandbytes 0.45.2, xformers 0.0.24, tested on A100 CUDA 12.x
################################################################################

print("=== Checking GPU and driver info ===")
!nvidia-smi

print("\n=== 1) Uninstall older Torch, bitsandbytes, xformers, triton ===")
!pip uninstall -y torch bitsandbytes xformers triton

print("\n=== 2) Install PyTorch NIGHTLY 2.2.0+cu121, plus torchvision, torchaudio")
print("         from the official 'nightly/cu121' index. ===")

# We use --pre (pre-release) and a special index URL for nightly cu121 builds.
!pip install --pre torch torchvision torchaudio \
    --index-url https://download.pytorch.org/whl/nightly/cu121

print("\n=== 3) Install bitsandbytes 0.45.2 and xformers 0.0.24 (built for Torch 2.2.0+cu121) ===")
# We'll just use PyPI. bitsandbytes 0.45.2 has CUDA 12.1 support.
# xformers 0.0.24 is built for Torch 2.2.0+cu121, so it won't conflict.
!pip install bitsandbytes==0.45.2 xformers==0.0.24

print("\n=== 4) Write test_deps.py script to verify bitsandbytes, xformers, and triton ===")

test_deps_code = """import os
import sys
import torch

os.environ["BNB_CUDA_VERSION"] = "121"  # bitsandbytes tries libbitsandbytes_cuda121.so

# 1) Test bitsandbytes
try:
    import bitsandbytes as bnb
    print("\\n=== bitsandbytes import OK ===")
    linear_8bit = bnb.nn.Linear8bitLt(128, 64).cuda()
    dummy_in = torch.randn(16, 128, device='cuda', dtype=torch.float16)
    dummy_out = linear_8bit(dummy_in)
    print('bitsandbytes linear8bit forward pass successful. Output shape:', dummy_out.shape)
except Exception as ex:
    print('bitsandbytes usage error:', ex)
    sys.exit(1)

# 2) Test xformers
try:
    import xformers
    print("\\n=== xformers import OK ===")
    from xformers.ops import fmha
    q = torch.randn((1, 32, 8, 64), device='cuda', dtype=torch.float16)
    k = torch.randn((1, 32, 8, 64), device='cuda', dtype=torch.float16)
    v = torch.randn((1, 32, 8, 64), device='cuda', dtype=torch.float16)
    out = fmha.memory_efficient_attention(q, k, v)
    print('xformers fmha output shape:', out.shape)
except Exception as ex:
    print('xformers usage error:', ex)
    sys.exit(1)

# 3) Test triton (bundled in Torch 2.2.0 nightly)
try:
    import triton
    import triton.language as tl
    print("\\n=== triton import OK ===")

    @triton.jit
    def add_kernel(x_ptr, y_ptr, output_ptr, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(0)
        offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        x = tl.load(x_ptr + offset)
        y = tl.load(y_ptr + offset)
        tl.store(output_ptr + offset, x + y)

    x_t = torch.randn(1024, device='cuda')
    y_t = torch.randn(1024, device='cuda')
    output_t = torch.empty(1024, device='cuda')
    grid = (1024 // 256,)
    add_kernel[grid](x_t, y_t, output_t, BLOCK_SIZE=256)
    print('triton add_kernel test, first 5 results:', output_t[:5].tolist())
except Exception as ex:
    print('triton usage error:', ex)
    sys.exit(1)

print('\\nAll tests passed! bitsandbytes, xformers, and triton are working.')
"""

with open("test_deps.py", "w") as f:
    f.write(test_deps_code)

print("\n=== 5) Run test_deps.py to confirm everything works with Torch 2.2.0+cu121 ===")
!python test_deps.py


=== Checking GPU and driver info ===
Thu Feb 20 17:19:27 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   30C    P0             45W /  400W |       0MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
           

---
## 6) **Flexible Attention**

Here’s a snippet that builds various attention masks (causal, sliding, etc.) and uses a chunked approach, plus `torch.compile` if you like. This demonstration shows different mask types in one place.

In [15]:
import sys
import math
import torch

def build_attention_mask(seq_len, mask_type="causal", window_size=64, device="cuda"):
    """
    Creates an attention mask:
      - "causal": blocks j > i (standard auto-regressive mask).
      - "sliding": local window = ±window_size around each token.
    """
    mask = torch.zeros(seq_len, seq_len, device=device)
    if mask_type == "causal":
        # Triangular upper matrix => block j>i
        casual_mat = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
        mask[casual_mat.bool()] = float("-1e9")
    elif mask_type == "sliding":
        # For each position i, block everything outside [i - window_size, i + window_size]
        for i in range(seq_len):
            left = max(0, i - window_size)
            right = min(seq_len, i + window_size + 1)
            mask[i, :left] = float("-1e9")
            mask[i, right:] = float("-1e9")
    else:
        raise ValueError(f"Unknown mask_type={mask_type}")
    return mask

def flex_attention(q, k, v, attn_mask):
    """
    Simple scaled dot-product attention:
      q, k, v: shape [batch, seq_len, d_model]
      attn_mask: shape [seq_len, seq_len], large negative => blocked
    """
    d_model = q.shape[-1]
    # (batch, seq_len, d_model) @ (batch, d_model, seq_len) => (batch, seq_len, seq_len)
    attn_scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(d_model)

    # Apply the mask (broadcast => (batch, seq_len, seq_len))
    attn_scores = attn_scores + attn_mask.unsqueeze(0)

    # Softmax and multiply by v
    attn_probs = torch.softmax(attn_scores, dim=-1)
    out = torch.bmm(attn_probs, v)
    return out

# Fallback approach for Python 3.11:
# - If Python < 3.11 => we compile
# - If Python >= 3.11 => skip compile to avoid runtime error
if sys.version_info < (3, 11):
    compiled_flex_attention = torch.compile(flex_attention, mode="default")
    print("Using torch.compile on Python < 3.11.")
else:
    compiled_flex_attention = flex_attention
    print("Skipping torch.compile (Python 3.11+ not yet supported).")

def run_flex_attention_demo():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size = 2
    d_model = 64

    for mask_type in ["causal", "sliding"]:
        print(f"\n===> Testing mask_type = {mask_type}")
        for seq_len in [128, 256, 300, 512]:
            q = torch.randn(batch_size, seq_len, d_model, device=device)
            k = torch.randn(batch_size, seq_len, d_model, device=device)
            v = torch.randn(batch_size, seq_len, d_model, device=device)

            base_mask = build_attention_mask(seq_len, mask_type=mask_type, device=device)
            out = compiled_flex_attention(q, k, v, base_mask)
            print(f"seq_len={seq_len}, out.shape={out.shape}, mask_type={mask_type}")

if __name__ == "__main__":
    run_flex_attention_demo()


Skipping torch.compile (Python 3.11+ not yet supported).

===> Testing mask_type = causal
seq_len=128, out.shape=torch.Size([2, 128, 64]), mask_type=causal
seq_len=256, out.shape=torch.Size([2, 256, 64]), mask_type=causal
seq_len=300, out.shape=torch.Size([2, 300, 64]), mask_type=causal
seq_len=512, out.shape=torch.Size([2, 512, 64]), mask_type=causal

===> Testing mask_type = sliding
seq_len=128, out.shape=torch.Size([2, 128, 64]), mask_type=sliding
seq_len=256, out.shape=torch.Size([2, 256, 64]), mask_type=sliding
seq_len=300, out.shape=torch.Size([2, 300, 64]), mask_type=sliding
seq_len=512, out.shape=torch.Size([2, 512, 64]), mask_type=sliding


---
## 7) **Sequence Classification Patch** (LoRA + `AutoModelForSequenceClassification`)

We patch `AutoModelForSequenceClassification` by injecting LoRA modules into every `nn.Linear` in the model, then fine-tune only the LoRA parameters on a toy dataset.

In [1]:
################################################################################
# SINGLE-CELL COLAB SCRIPT:
# LoRA BERT classification w/ Torch 2.1.0+cu121 & Transformers 4.31.0
# Removing peft & older libraries => fix the 'adapter_kwargs' error.
################################################################################

print("=== Checking GPU / driver info ===")
!nvidia-smi

print("\n=== 1) Uninstall conflicting packages (torch, transformers, peft, xformers, etc.) ===")
!pip uninstall -y torch transformers peft xformers tokenizers bitsandbytes

print("\n=== 2) Install Torch 2.1.0+cu121 & Transformers==4.31.0 ===")
!pip install torch==2.1.0+cu121 --index-url https://download.pytorch.org/whl/cu121
!pip install transformers==4.31.0

print("\n=== 3) Running your LoRA BERT classification code ===")

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    AutoConfig,
)

class ToyClassificationDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=32):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt"
        )
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long)
        }

class LoRALinear(nn.Module):
    def __init__(self, in_features, out_features, lora_rank=4, alpha=1.0):
        super().__init__()
        self.lora_down = nn.Linear(in_features, lora_rank, bias=False)
        self.lora_up   = nn.Linear(lora_rank, out_features, bias=False)
        nn.init.zeros_(self.lora_down.weight)
        nn.init.zeros_(self.lora_up.weight)
        self.alpha = alpha

    def forward(self, x):
        return self.alpha * self.lora_up(self.lora_down(x))

def patch_model_for_sequence_classification(model, lora_rank=4, alpha=1.0):
    modules_to_patch = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            modules_to_patch.append((name, module))

    for full_name, module in modules_to_patch:
        safe_name = full_name.replace(".", "_")
        lora_mod = LoRALinear(
            module.in_features,
            module.out_features,
            lora_rank=lora_rank,
            alpha=alpha
        ).to(module.weight.device, module.weight.dtype)

        # Register it
        model.add_module(f"lora_{safe_name}", lora_mod)

        # Patch forward
        orig_forward = module.forward
        def custom_forward(m_self, x, orig_forward=orig_forward, lora_layer=lora_mod):
            base_out = orig_forward(x)
            lora_out = lora_layer(x)
            return base_out + lora_out

        module.forward = custom_forward.__get__(module, module.__class__)

    return model

def finetune_sequence_classification():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_name = "bert-base-uncased"
    num_labels = 2

    config = AutoConfig.from_pretrained(model_name, num_labels=num_labels)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, config=config)
    model.to(device)

    # Inject LoRA
    patch_model_for_sequence_classification(model, lora_rank=4, alpha=1.0)

    texts = [
        "I love this product, it is amazing!",
        "This is the worst experience of my life.",
        "The movie was quite entertaining.",
        "Horrible service, will not come back!"
    ]
    labels = [1, 0, 1, 0]
    dataset = ToyClassificationDataset(texts, labels, tokenizer, max_length=16)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

    # Only train LoRA params
    lora_params = []
    for param_name, param in model.named_parameters():
        if "lora_" in param_name:
            param.requires_grad = True
            lora_params.append(param)
        else:
            param.requires_grad = False

    optimizer = optim.AdamW(lora_params, lr=1e-4)
    model.train()
    epochs = 3
    for epoch in range(epochs):
        total_loss = 0.0
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            optimizer.zero_grad()
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{epochs}, avg_loss={avg_loss:.4f}")

    model.eval()
    sample_text = ["I dislike the taste, not recommended."]
    enc = tokenizer(sample_text, truncation=True, padding=True, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = model(**enc).logits
    preds = torch.argmax(logits, dim=-1)
    print("\nInference Test:")
    print(f"Input: {sample_text}")
    print(f"Logits: {logits.cpu().numpy()}")
    print(f"Predicted label: {preds.item()} (0=Neg,1=Pos)")

finetune_sequence_classification()


=== Checking GPU / driver info ===
Thu Feb 20 18:43:30 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   30C    P0             45W /  400W |       0MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
             

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1/3, avg_loss=0.8157
Epoch 2/3, avg_loss=0.7862
Epoch 3/3, avg_loss=0.7097

Inference Test:
Input: ['I dislike the taste, not recommended.']
Logits: [[-0.36591572  0.18960014]]
Predicted label: 1 (0=Neg,1=Pos)


---
## 8) **Refactored Attention**

Merging `xformers`, PyTorch’s SDPA, `flash_attn`, and a fallback “flex” approach in a single function.

In [2]:
import warnings

try:
    import xformers.ops as xops
    XFORMERS_AVAILABLE = True
except ImportError:
    XFORMERS_AVAILABLE = False

try:
    import flash_attn
    FLASH_ATTN_AVAILABLE = True
except ImportError:
    FLASH_ATTN_AVAILABLE = False

SDPA_AVAILABLE = hasattr(torch.nn.functional, "scaled_dot_product_attention")

def flex_custom_attention(q, k, v, attn_mask=None):
    d_k = q.shape[-1]
    scores = torch.matmul(q, k.transpose(-1, -2)) / (d_k ** 0.5)
    if attn_mask is not None:
        scores = scores + attn_mask
    weights = torch.softmax(scores, dim=-1)
    weights = weights.to(v.dtype)
    out = torch.matmul(weights, v)
    return out

def xformers_attention(q, k, v, attn_mask=None):
    B, H, L, D = q.shape
    q_ = q.reshape(B*H, L, D)
    k_ = k.reshape(B*H, L, D)
    v_ = v.reshape(B*H, L, D)

    bool_mask = None
    if attn_mask is not None:
        expanded = attn_mask.expand(B, H, L, L).reshape(B*H, L, L)
        bool_mask = (expanded < -1e4)
    out = xops.memory_efficient_attention(
        q_, k_, v_,
        attn_mask=bool_mask,
        p=0.0
    )
    return out.reshape(B, H, L, D)

def flash_attention(q, k, v, attn_mask=None):
    import flash_attn
    B, H, L, D = q.shape
    q_ = q.reshape(B*H, L, D)
    k_ = k.reshape(B*H, L, D)
    v_ = v.reshape(B*H, L, D)
    out = flash_attn.flash_attn_func(
        q_, k_, v_,
        dropout_p=0.0,
        softmax_scale=None,
        causal=False
    )
    return out.reshape(B, H, L, D)

def sdpa_attention(q, k, v, attn_mask=None):
    from torch.nn.functional import scaled_dot_product_attention as sdpa
    B, H, L, D = q.shape
    q_ = q.permute(2, 0, 1, 3).reshape(L, B*H, D)
    k_ = k.permute(2, 0, 1, 3).reshape(L, B*H, D)
    v_ = v.permute(2, 0, 1, 3).reshape(L, B*H, D)

    am = None
    if attn_mask is not None:
        am = attn_mask.expand(B, H, L, L).reshape(B*H, L, L)
    out_ = sdpa(q_, k_, v_, attn_mask=am, dropout_p=0.0, is_causal=False)
    out = out_.reshape(L, B, H, D).permute(1, 2, 0, 3)
    return out

def unified_attention(q, k, v, attn_mask=None, backend="auto"):
    if backend == "auto":
        if XFORMERS_AVAILABLE:
            backend = "xformers"
        elif FLASH_ATTN_AVAILABLE:
            backend = "flash"
        elif SDPA_AVAILABLE:
            backend = "sdpa"
        else:
            backend = "flex"

    if backend == "xformers":
        if not XFORMERS_AVAILABLE:
            raise RuntimeError("xformers not installed!")
        return xformers_attention(q, k, v, attn_mask)
    elif backend == "flash":
        if not FLASH_ATTN_AVAILABLE:
            raise RuntimeError("flash_attn not installed!")
        return flash_attention(q, k, v, attn_mask)
    elif backend == "sdpa":
        if not SDPA_AVAILABLE:
            raise RuntimeError("PyTorch >=2.0 needed for SDPA!")
        return sdpa_attention(q, k, v, attn_mask)
    else:
        return flex_custom_attention(q, k, v, attn_mask)

# Demo usage
def example_unified_attention():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    B, H, L, D = 2, 4, 16, 64
    q = torch.randn(B, H, L, D, device=device, dtype=torch.float16)
    k = torch.randn(B, H, L, D, device=device, dtype=torch.float16)
    v = torch.randn(B, H, L, D, device=device, dtype=torch.float16)
    attn_mask = torch.zeros((B, 1, L, L), device=device, dtype=torch.float32)
    blocked = torch.rand((B, 1, L, L), device=device) < 0.2
    attn_mask[blocked] = float("-inf")
    out_flex = unified_attention(q, k, v, attn_mask, backend="flex")
    print("fallback =>", out_flex.shape)

if __name__ == "__main__":
    example_unified_attention()


fallback => torch.Size([2, 4, 16, 64])


---
## Final Notes

- This notebookincludes separate code snippets for each task.
- Some cells (like the nF4 → Triton example) are skeletons or placeholders to illustrate core ideas.
