<a href="https://colab.research.google.com/github/csalnav2/QdotCS/blob/master/Unsloth_Solutions_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Unsloth (Demo Solutions Notebook)

This notebook collects various code snippets that address specific tasks:

1. **nF4 → Triton** (Quantized 4-bit kernel demo)
2. **QLoRA + `torch.compile`** (Naive QLoRA example, no graph breaks)
3. **QLoRA + FSDP** (Fully Sharded Data Parallel + LoRA injection)
4. **Memory-Efficient Backprop** (Chunked final linear + cross-entropy)
5. **Windows Support** (Python scripts to build/install `unsloth`, plus test code)
6. **Flexible Attention** ("Unsloth" style chunked attention examples)
7. **Sequence Classification Patch** (Inject LoRA into `AutoModelForSequenceClassification`)
8. **Refactored Attention** (xformers, SDPA, flash-attn, fallback in one interface)

Feel free to skip cells or modify as needed.

---
## 1) **nF4 → Triton**

**Goal**: Demonstrate converting 4-bit weights (nF4 style) and using a Triton kernel to do matrix multiplication without fully decompressing everything into float16/float32 first.

**Note**: This code is a **minimal skeleton**. Real nF4 implementations might have more complex scaling logic, per-row or per-channel quant parameters, etc.

In [1]:
# If in Google Colab or a fresh environment, install specific versions:
# (You can comment these out if you already have matching versions.)
!pip uninstall -y torch triton
!pip install --no-cache-dir torch==2.1.0+cu121 torchvision==0.16.0+cu121 torchaudio==2.1.0+cu121 --index-url https://download.pytorch.org/whl/cu121
!pip install --no-cache-dir triton==2.1.0

import torch
import triton
import triton.language as tl

print("PyTorch version:", torch.__version__)
print("Triton version:", triton.__version__)

# ---------------------------------------------------------------------
# Hard-coded tile sizes => compile-time constants
# ---------------------------------------------------------------------
BLOCK_M = 64
BLOCK_N = 64
BLOCK_K = 64

@triton.jit
def nf4_tile_matmul(
    A_ptr, B_ptr, C_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr
):
    """
    Single-tile NF4 decode + partial matmul.
    - M,N,K => must each be 64 exactly.
    - A => shape [M, K//2] in 4-bit. B => shape [K, N//2].
    - Output => [M, N] float16 in C.

    We'll decode each nibble from A, B => cast to float32 => do dot => store float16.
    """
    # Decode A => shape [BLOCK_M, BLOCK_K]
    rowA = tl.arange(0, BLOCK_M)
    colA = tl.arange(0, BLOCK_K)
    rowA = rowA[:, None]  # shape [BLOCK_M,1]
    colA = colA[None, :]  # shape [1,BLOCK_K]
    linearA = rowA*stride_am + colA*stride_ak
    byteA   = linearA // 2
    nibSelA = linearA & 1
    bytesA  = tl.load(A_ptr + byteA)
    shiftA  = nibSelA * 4
    nibA    = (bytesA >> shiftA) & 0xF
    valA    = tl.cast(nibA, tl.float32)

    # Decode B => shape [BLOCK_K, BLOCK_N]
    rowB = tl.arange(0, BLOCK_K)
    colB = tl.arange(0, BLOCK_N)
    rowB = rowB[:, None]  # shape [BLOCK_K,1]
    colB = colB[None, :]  # shape [1,BLOCK_N]
    linearB = rowB*stride_bk + colB*stride_bn
    byteB   = linearB // 2
    nibSelB = linearB & 1
    bytesB  = tl.load(B_ptr + byteB)
    shiftB  = nibSelB * 4
    nibB    = (bytesB >> shiftB) & 0xF
    valB    = tl.cast(nibB, tl.float32)

    # partial dot => shape [BLOCK_M, BLOCK_N]
    accum = tl.dot(valA, valB)

    # store to C => shape [BLOCK_M, BLOCK_N]
    out_f16 = accum.to(tl.float16)
    rowC = tl.arange(0, BLOCK_M)[:, None]
    colC = tl.arange(0, BLOCK_N)[None, :]
    c_offset = rowC*stride_cm + colC*stride_cn
    tl.store(C_ptr + c_offset, out_f16)

def nf4_tile_matmul_host(A_4bit, B_4bit, M, N, K):
    """
    Host function: A_4bit => [M, K//2], B_4bit => [K, N//2], each nibble = 4 bits
    Output => [M, N] float16.

    M,N,K must be 64 to match the tile kernel.
    """
    device = A_4bit.device
    C = torch.empty((M, N), dtype=torch.float16, device=device)

    # row-major strides
    stride_am = K
    stride_ak = 1
    stride_bk = N
    stride_bn = 1
    stride_cm = N
    stride_cn = 1

    # Single tile => (1,1) grid
    nf4_tile_matmul[(1,1)](
        A_4bit, B_4bit, C,
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K
    )
    return C

# ---------------------------------------------------------------------
# Demo usage
# ---------------------------------------------------------------------
device = "cuda"
M, K, N = 64, 64, 64  # must match BLOCK_M,N,K=64

A_4bit = torch.randint(0, 256, (M, K//2), dtype=torch.uint8, device=device)
B_4bit = torch.randint(0, 256, (K, N//2), dtype=torch.uint8, device=device)

C_out = nf4_tile_matmul_host(A_4bit, B_4bit, M, N, K)
print("C_out shape:", C_out.shape)
print("C_out[:5, :5] =>\n", C_out[:5, :5])
print("Done. If 'map::at' error appears, it's likely a Triton environment bug.")


Found existing installation: torch 2.5.1+cu124
Uninstalling torch-2.5.1+cu124:
  Successfully uninstalled torch-2.5.1+cu124
Found existing installation: triton 3.1.0
Uninstalling triton-3.1.0:
  Successfully uninstalled triton-3.1.0
Looking in indexes: https://download.pytorch.org/whl/cu121
Collecting torch==2.1.0+cu121
  Downloading https://download.pytorch.org/whl/cu121/torch-2.1.0%2Bcu121-cp311-cp311-linux_x86_64.whl (2200.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 GB[0m [31m263.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchvision==0.16.0+cu121
  Downloading https://download.pytorch.org/whl/cu121/torchvision-0.16.0%2Bcu121-cp311-cp311-linux_x86_64.whl (7.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m111.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchaudio==2.1.0+cu121
  Downloading https://download.pytorch.org/whl/cu121/torchaudio-2.1.0%2Bcu121-cp311-cp311-linux_x86_64.whl (3.3 MB)
[2K  

AssertionError: libcuda.so cannot found!


For the following code regarding the nf4 to Triton conversion, this was the output on A100  on Google Colab but Google colab for some strange reason cannot recognize libcuda.so. The code stopped working and did not save the output. But, here is the actual out put for the code above. This code should work in differennt Colab, I think there is a glitch in Google colab. Currently trying to reproduce the ecexutation of this code in a different Colan environment.
PyTorch version: 2.5.1+cu124
Triton version: 3.1.0
C_out shape: torch.Size([64, 64])
C_out[:5, :5] =>
 tensor([[3144., 3652., 3420., 2944., 3216.],
        [3230., 3672., 3480., 3682., 3348.],
        [2744., 3316., 3248., 3372., 2804.],
        [3776., 3744., 3800., 3614., 3416.],
        [3976., 4264., 3874., 3886., 3506.]], device='cuda:0',
       dtype=torch.float16)

In [None]:
# Single cell for Colab or Jupyter
# 1) Force reinstall pinned Torch/Triton from PyTorch's cu121 index
# 2) If Transformers 4.32+ not found on PyPI for Py 3.11 => installs from source (latest master).
# 3) Attempt bitsandbytes unpinned (remove if wheels are missing).

print("=== [1] Uninstall old Torch/Transformers/bitsandbytes/peft/accelerate/trl ===")
!pip uninstall -y torch triton transformers bitsandbytes peft accelerate trl

print("\n=== [2] Install Torch 2.1.0+cu121 & Triton 2.1.0 from PyTorch's cu121 index ===")
INSTALL_CMD_TORCH = """
pip install --upgrade --force-reinstall --no-cache-dir \
  torch==2.1.0+cu121 torchvision==0.16.0+cu121 torchaudio==2.1.0+cu121 \
    --index-url https://download.pytorch.org/whl/cu121 \
  triton==2.1.0
"""
print(INSTALL_CMD_TORCH)
res = get_ipython().system(INSTALL_CMD_TORCH)

print("\n=== [3] Attempting Transformers >=4.32 from PyPI, bitsandbytes unpinned, plus peft/accelerate/trl pinned ===")
INSTALL_CMD_PYPI = """
pip install --upgrade --force-reinstall --no-cache-dir \
  "transformers>=4.32.0,<5.0.0" \
  bitsandbytes \
  peft==0.5.0 \
  accelerate==0.23.0 \
  trl==0.6.0
"""
print(INSTALL_CMD_PYPI)
res = get_ipython().system(INSTALL_CMD_PYPI)

if res != 0:
    print("\nPyPI install for Transformers >=4.32 or bitsandbytes might have failed. Attempting source install for Transformers.\n")
    # 4) If Transformers fails from PyPI, we do a source install from Git
    #    This ensures we have the latest code with is_torch_less_than_1_11.
    #    But if bitsandbytes still fails, you may need to remove it or build from source.
    INSTALL_CMD_SRC = """
pip install --force-reinstall --no-cache-dir git+https://github.com/huggingface/transformers.git@main peft==0.5.0 accelerate==0.23.0 trl==0.6.0 bitsandbytes
"""
    print(INSTALL_CMD_SRC)
    get_ipython().system(INSTALL_CMD_SRC)

print("\n=== [4] Show final installed versions ===")
!pip show torch
!pip show transformers
!pip show bitsandbytes
!pip show peft
!pip show trl
!pip show accelerate

print("\n=== [5] Attempt `import transformers.trainer` ===")
try:
    import transformers
    print(f"transformers version => {transformers.__version__}")
    from transformers import trainer
    print("SUCCESS: `transformers.trainer` imported => 'is_torch_less_than_1_11' error is gone!")
except ImportError as e:
    print("ImportError =>", e)
except Exception as ex:
    print("Unexpected error =>", ex)


=== [1] Uninstall old Torch/Transformers/bitsandbytes/peft/accelerate/trl ===
[0mFound existing installation: torch 2.1.0+cu121
Uninstalling torch-2.1.0+cu121:
  Successfully uninstalled torch-2.1.0+cu121
[0mFound existing installation: triton 3.2.0
Uninstalling triton-3.2.0:
  Successfully uninstalled triton-3.2.0
[0m
=== [2] Install Torch 2.1.0+cu121 & Triton 2.1.0 from PyTorch's cu121 index ===

pip install --upgrade --force-reinstall --no-cache-dir   torch==2.1.0+cu121 torchvision==0.16.0+cu121 torchaudio==2.1.0+cu121     --index-url https://download.pytorch.org/whl/cu121   triton==2.1.0

[0mLooking in indexes: https://download.pytorch.org/whl/cu121
Collecting torch==2.1.0+cu121
  Downloading https://download.pytorch.org/whl/cu121/torch-2.1.0%2Bcu121-cp311-cp311-linux_x86_64.whl (2200.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 GB[0m [31m295.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchvision==0.16.0+cu121
  Downloading https://downl


=== [3] Attempting Transformers >=4.32 from PyPI, bitsandbytes unpinned, plus peft/accelerate/trl pinned ===

pip install --upgrade --force-reinstall --no-cache-dir   "transformers>=4.32.0,<5.0.0"   bitsandbytes   peft==0.5.0   accelerate==0.23.0   trl==0.6.0

[0mCollecting transformers<5.0.0,>=4.32.0
  Downloading transformers-4.49.0-py3-none-any.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m113.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting bitsandbytes
  Downloading bitsandbytes-0.45.2-py3-none-manylinux_2_24_x86_64.whl.metadata (5.8 kB)
Collecting peft==0.5.0
  Downloading peft-0.5.0-py3-none-any.whl.metadata (22 kB)
Collecting accelerate==0.23.0
  Downloading accelerate-0.23.0-py3-none-any.whl.metadata (18 kB)
Collecting trl==0.6.0
  Downloading trl-0.6.0-py3-none-any.whl.metadata (9.8 kB)
Collecting numpy>=1.17 (from peft==0.5.0)
  Downloading numpy-2.2.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.me


PyPI install for Transformers >=4.32 or bitsandbytes might have failed. Attempting source install for Transformers.


pip install --force-reinstall --no-cache-dir git+https://github.com/huggingface/transformers.git@main peft==0.5.0 accelerate==0.23.0 trl==0.6.0 bitsandbytes

[0mCollecting git+https://github.com/huggingface/transformers.git@main
  Cloning https://github.com/huggingface/transformers.git (to revision main) to /tmp/pip-req-build-8x7auish
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers.git /tmp/pip-req-build-8x7auish
  Resolved https://github.com/huggingface/transformers.git to commit e18f233f6c8cba029324e2868fb68abdaf6badf3
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting peft==0.5.0
  Downloading peft-0.5.0-py3-none-any.whl.metadata (22 kB)
Collecting accelerate==0.23.0
  Downloading accel


=== [4] Show final installed versions ===
[0mName: torch
Version: 2.6.0
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: packages@pytorch.org
License: BSD-3-Clause
Location: /usr/local/lib/python3.11/dist-packages
Requires: filelock, fsspec, jinja2, networkx, nvidia-cublas-cu12, nvidia-cuda-cupti-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-runtime-cu12, nvidia-cudnn-cu12, nvidia-cufft-cu12, nvidia-curand-cu12, nvidia-cusolver-cu12, nvidia-cusparse-cu12, nvidia-cusparselt-cu12, nvidia-nccl-cu12, nvidia-nvjitlink-cu12, nvidia-nvtx-cu12, sympy, triton, typing-extensions
Required-by: accelerate, bitsandbytes, fastai, peft, sentence-transformers, timm, torchaudio, torchvision, trl
[0mName: transformers
Version: 4.50.0.dev0
Summary: State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow
Home-page: https://github.com/huggingface/transformers
Author: The Hugging Face team (past and

  warn(


SUCCESS: `transformers.trainer` imported => 'is_torch_less_than_1_11' error is gone!


---
## 2) **QLoRA + `torch.compile`** (Naive Example)

This snippet demonstrates a simple QLoRA-like module (4-bit quant + LoRA adapters), then wraps the model in `torch.compile` to ensure we avoid graph breaks.

In [3]:

#!/usr/bin/env python3
"""
Full Modified Code Example

Goals:
 - Uninstall conflicting packages
 - Install pinned versions (torch==2.5.1+cu124, transformers==4.48.3, bitsandbytes==0.45.2, etc.)
 - Verify bitsandbytes can be imported without error
 - Provide minimal example code

Usage:
 - In a fresh environment (e.g., Google Colab cell), copy/paste and run
 - If "torch==2.5.1+cu124" doesn't exist, adjust to a valid version or skip pinning
"""

import sys, subprocess

def uninstall_conflicts():
    print("=== [A] Uninstalling old/conflicting packages ===")
    pkgs = ["torch","triton","transformers","bitsandbytes","peft","accelerate","trl"]
    cmd = [sys.executable, "-m", "pip", "uninstall", "-y"] + pkgs
    subprocess.run(cmd, check=False)

def install_pinned_versions():
    print("\n=== [B] Installing pinned packages (torch==2.5.1+cu124, etc.) ===")
    pinned = [
        # Example pinned: PyTorch 2.5.1+cu124 (MAY NOT EXIST publicly, adjust as needed!)
        "torch==2.5.1+cu124",
        # (Optional) If you need torchvision/torchaudio pinned, add them here:
        # "torchvision==0.16.1+cu124",
        # "torchaudio==2.5.1+cu124",
        # Provide an index-url if needed or skip if torch is from public nightly/dev
        # e.g. "--index-url", "https://download.pytorch.org/whl/nightly/cu124",

        # Transformers 4.48.3 (older version)
        "transformers==4.48.3",

        # bitsandbytes pinned
        "bitsandbytes==0.45.2",

        # Possibly add extras:
        # "peft==0.4.0",
        # "accelerate==0.20.3",
    ]

    # Force reinstall w/ no-cache
    cmd = [sys.executable, "-m", "pip", "install",
           "--upgrade", "--force-reinstall", "--no-cache-dir"] + pinned
    subprocess.run(cmd, check=False)

def verify_install():
    print("\n=== [C] Verifying installed versions ===")
    for pkg in ["torch","transformers","bitsandbytes"]:
        print(f"\n--- {pkg} ---")
        cmd = [sys.executable, "-m", "pip", "show", pkg]
        subprocess.run(cmd, check=False)

def minimal_test():
    """
    Minimal test code showing bitsandbytes can be imported.
    """
    print("\n=== [D] Attempting minimal import test ===")
    import torch
    print("PyTorch version =>", torch.__version__)

    import transformers
    print("Transformers version =>", transformers.__version__)

    import bitsandbytes as bnb
    print("bitsandbytes version =>", bnb.__version__)

def main():
    uninstall_conflicts()
    install_pinned_versions()
    verify_install()
    try:
        minimal_test()
        print("\nSuccess: bitsandbytes is now installed and importable!")
    except ImportError as e:
        print("\nImportError occurred:", e)
        print("Verify the pinned versions exist or remove strict version pins.")

if __name__ == "__main__":
    main()




=== [A] Uninstalling old/conflicting packages ===

=== [B] Installing pinned packages (torch==2.5.1+cu124, etc.) ===

=== [C] Verifying installed versions ===

--- torch ---

--- transformers ---

--- bitsandbytes ---

=== [D] Attempting minimal import test ===
PyTorch version => 2.5.1+cu124
Transformers version => 4.48.3

ImportError occurred: No module named 'bitsandbytes'
Verify the pinned versions exist or remove strict version pins.


---
## 3) **QLoRA + FSDP**

A single-cell script that:
- Loads BERT in half precision
- Injects LoRA modules
- Wraps the model in FSDP (Fully Sharded Data Parallel)
- Trains only the LoRA parameters

In [1]:
import os
import torch
import torch.nn as nn
import torch.distributed as dist

try:
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import ShardingStrategy
except ImportError as e:
    raise ImportError(
        "Your PyTorch version does not support FSDP properly. "
        "Please install PyTorch >= 2.0.0. Error detail:\n" + str(e)
    )

from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    AutoConfig,
)

##############################################################################
# 1) Distributed Setup
##############################################################################
def setup_distributed():
    """
    Checks environment variables (RANK, WORLD_SIZE) to see if we're in a
    multi-process environment. If not found, sets up a fallback single process.
    """
    if dist.is_initialized():
        return 0

    # Check if we have RANK/WORLD_SIZE
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend="nccl")
        return local_rank
    else:
        # Single GPU fallback
        dist.init_process_group(
            backend="nccl",
            init_method='file:///tmp/fsdp_example',  # a local temp file
            rank=0,
            world_size=1
        )
        torch.cuda.set_device(0)
        return 0

##############################################################################
# 2) Load BERT in half precision
##############################################################################
def load_bert_fp16(model_name="bert-base-uncased"):
    """
    Loads BERT in half precision for masked LM.
    """
    config = AutoConfig.from_pretrained(model_name)
    model = AutoModelForMaskedLM.from_pretrained(
        model_name,
        config=config,
        torch_dtype=torch.float16  # half precision
    )
    return model

##############################################################################
# 3) LoRALinear injection
##############################################################################
class LoRALinear(nn.Module):
    """
    Minimal LoRA injection. We'll add a rank-limited "down -> up" path.
    alpha scales the LoRA output.
    """
    def __init__(self, in_features, out_features, lora_rank=8, alpha=1.0):
        super().__init__()
        self.lora_down = nn.Linear(in_features, lora_rank, bias=False)
        self.lora_up   = nn.Linear(lora_rank, out_features, bias=False)
        nn.init.zeros_(self.lora_down.weight)
        nn.init.zeros_(self.lora_up.weight)
        self.alpha = alpha

    def forward(self, x):
        return self.alpha * self.lora_up(self.lora_down(x))

def inject_lora_in_bert(model, lora_rank=8, alpha=1.0):
    """
    Iterates over all nn.Linear in the BERT model, injecting a LoRALinear module
    and patching the forward to combine base + lora output.
    """
    linear_list = []
    for full_name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            linear_list.append((full_name, module))

    for full_name, module in linear_list:
        print(f"Injecting LoRA into: {full_name} => {module}")
        lora_mod = LoRALinear(
            module.in_features,
            module.out_features,
            lora_rank=lora_rank,
            alpha=alpha
        ).half()  # keep LoRA in half

        # Register as a submodule so params appear in model.named_parameters
        safe_name = full_name.replace(".", "_")
        model.add_module(f"lora_{safe_name}", lora_mod)

        # Patch forward
        orig_forward = module.forward

        def custom_forward(m_self, x,
                           orig_forward=orig_forward,
                           lora_mod=lora_mod):
            base_out = orig_forward(x)
            lora_out = lora_mod(x)
            return base_out + lora_out

        # Monkey-patch
        module.forward = custom_forward.__get__(module, module.__class__)

    return model

##############################################################################
# 4) Main: LoRA + FSDP Fine-tuning
##############################################################################
def main():
    local_rank = setup_distributed()
    model_name = "bert-base-uncased"

    if local_rank == 0:
        print(f"Loading {model_name} in half precision...")

    model = load_bert_fp16(model_name)

    # Ensure all parameters require grad
    for n, p in model.named_parameters():
        p.requires_grad = True

    if local_rank == 0:
        print("Injecting LoRA (rank=8, alpha=1.0) in float16...")

    model = inject_lora_in_bert(model, lora_rank=8, alpha=1.0)

    # Collect LoRA params only => partial finetuning
    lora_params = []
    for name, p in model.named_parameters():
        if "lora_" in name:
            lora_params.append(p)
    if local_rank == 0:
        print(f"Collected {len(lora_params)} LoRA params for the optimizer.")

    # Wrap with FSDP
    fsdp_model = FSDP(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        device_id=torch.cuda.current_device(),
    )

    optimizer = torch.optim.AdamW(lora_params, lr=1e-4)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    texts = [
        "Hello world, how are you?",
        "Testing BERT in half precision with LoRA",
        "Combining FSDP for memory efficiency!",
    ] * 5

    encodings = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    input_ids = encodings["input_ids"].cuda(local_rank)
    attention_mask = encodings["attention_mask"].cuda(local_rank)
    labels = input_ids.clone()

    # Create random mask for masked LM
    with torch.no_grad():
        rand_mask = torch.rand_like(labels.float()) < 0.15
        labels[~rand_mask] = -100

    fsdp_model.train()

    epochs = 2
    for epoch in range(epochs):
        optimizer.zero_grad()
        outputs = fsdp_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        if local_rank == 0:
            print(f"Epoch {epoch+1} / {epochs}, loss = {loss.item()}")

    dist.barrier()
    if local_rank == 0:
        print("Training complete!")

if __name__ == "__main__":
    main()


Loading bert-base-uncased in half precision...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

BertForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another archite

Injecting LoRA (rank=8, alpha=1.0) in float16...
Injecting LoRA into: bert.encoder.layer.0.attention.self.query => Linear(in_features=768, out_features=768, bias=True)
Injecting LoRA into: bert.encoder.layer.0.attention.self.key => Linear(in_features=768, out_features=768, bias=True)
Injecting LoRA into: bert.encoder.layer.0.attention.self.value => Linear(in_features=768, out_features=768, bias=True)
Injecting LoRA into: bert.encoder.layer.0.attention.output.dense => Linear(in_features=768, out_features=768, bias=True)
Injecting LoRA into: bert.encoder.layer.0.intermediate.dense => Linear(in_features=768, out_features=3072, bias=True)
Injecting LoRA into: bert.encoder.layer.0.output.dense => Linear(in_features=3072, out_features=768, bias=True)
Injecting LoRA into: bert.encoder.layer.1.attention.self.query => Linear(in_features=768, out_features=768, bias=True)
Injecting LoRA into: bert.encoder.layer.1.attention.self.key => Linear(in_features=768, out_features=768, bias=True)
Injecting

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Epoch 1 / 2, loss = 6.38671875
Epoch 2 / 2, loss = 5.890625
Training complete!


---
## 4) **Memory-Efficient Backprop** (Chunked Final MatMul + Cross-Entropy)

This code chunk demonstrates how to avoid creating a huge `[B*S, vocab]` logits matrix at once, by chunking the matmul into smaller pieces. This reduces memory usage at the cost of multiple partial computations.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

##############################################################################
# 1) A chunk-level transformation: project => cross-entropy
##############################################################################
def transformation_function(x_chunk, linear_module, labels_chunk):
    """
    x_chunk: [chunk_size, hidden_dim]
    linear_module: nn.Linear (hidden_dim -> vocab_size)
    labels_chunk: [chunk_size]
    Returns a scalar cross-entropy loss
    """
    # For numeric stability if half
    logits = linear_module(x_chunk).float()
    ce = nn.CrossEntropyLoss(reduction="mean")
    loss = ce(logits, labels_chunk)
    return loss

##############################################################################
# 2) A function to do chunk-based forward with checkpointing
##############################################################################
def forward_chunked_checkpoints(X, linear_module, labels, chunk_size=1024):
    """
    Single-phase approach: we chunk X & labels, build a sub-graph for each chunk
    wrapped in torch.utils.checkpoint => memory is freed after each chunk.

    X: [batch_size * seq_len, hidden_dim], requires_grad=True
    labels: [batch_size * seq_len]
    linear_module: final projection
    chunk_size: how many rows to process per chunk
    """
    total_loss = None
    num_rows = X.shape[0]

    for start in range(0, num_rows, chunk_size):
        end = min(start + chunk_size, num_rows)
        x_chunk = X[start:end]
        label_chunk = labels[start:end]

        # We'll define a small wrapper for checkpoint
        def chunk_fn(x_sub, l_sub):
            # This function *must* build the subgraph for this chunk
            # Return the chunk's cross-entropy
            return transformation_function(x_sub, linear_module, l_sub)

        # We call checkpoint => PyTorch discards chunk's intermediate activations
        # except what's needed to rebuild in backward pass
        chunk_loss = checkpoint(chunk_fn, x_chunk, label_chunk)

        # Accumulate chunk losses
        if total_loss is None:
            total_loss = chunk_loss
        else:
            total_loss = total_loss + chunk_loss

    return total_loss

##############################################################################
# 3) Demo usage
##############################################################################
def demo_memory_efficient_linear():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Running on:", device)

    # Example shapes
    batch_size = 4
    seq_len = 4096
    hidden_dim = 4096
    vocab_size = 128000
    chunk_size = 1024  # micro-batch

    # Large input & labels
    X = torch.randn(batch_size * seq_len, hidden_dim,
                    device=device, dtype=torch.float16, requires_grad=True)
    labels = torch.randint(0, vocab_size, (batch_size * seq_len,),
                           device=device, dtype=torch.long)

    # Big final projection
    linear_module = nn.Linear(hidden_dim, vocab_size).to(device, dtype=torch.float16)

    # Single-phase forward (with chunk-based checkpointing)
    total_loss = forward_chunked_checkpoints(X, linear_module, labels, chunk_size)
    print("Loss:", total_loss.item())

    # Standard backward => triggers re-run for each chunk
    total_loss.backward()

    print("X.grad shape:", X.grad.shape)
    print("linear_module.weight.grad shape:", linear_module.weight.grad.shape)

if __name__ == "__main__":
    demo_memory_efficient_linear()


Running on: cuda


  return fn(*args, **kwargs)


Loss: 190.93194580078125
X.grad shape: torch.Size([16384, 4096])
linear_module.weight.grad shape: torch.Size([128000, 4096])


In [None]:
# ============================================
# 1) Confirm GPU type (T4, A100, etc.).
# ============================================
!nvidia-smi

# ============================================
# 2) [Optional] Install system-level CUDA 11.8 libs
#    so bitsandbytes can find libcusparse.so.11, etc.
#    If you get 'libcusparse.so.11 not found' errors,
#    installing these packages often helps.
# ============================================
!apt-get update -y
!apt-get install -y --no-install-recommends \
    cuda-cudart-11-8 \
    cuda-cusparse-11-8 \
    cuda-libraries-11-8

# ============================================
# 3) Wipe older Torch/bitsandbytes/xformers/triton
#    to avoid conflicts.
# ============================================
!pip uninstall -y torch bitsandbytes xformers triton

# ============================================
# 4) Install PyTorch 2.0.1+cu118, matching torchvision/torchaudio.
# ============================================
!pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 \
    --extra-index-url https://download.pytorch.org/whl/cu118

# ============================================
# 5) (Optional) Re-install pinned bitsandbytes, xformers, triton
#    to confirm environment is consistent.
#    (Though build_unsloth.py may also install them depending on the markers.)
# ============================================
!pip install bitsandbytes==0.41.1 xformers==0.0.22 triton==2.0.0 \
    --extra-index-url https://download.pytorch.org/whl/cu118


Thu Feb 20 16:26:18 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   30C    P0             45W /  400W |       0MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

---
## 5) **Windows Support**

Below are two scripts:
- **`build_unsloth.py`**: Creates a `pyproject.toml`, builds a wheel, and installs it.
- **`test_deps.py`**: Installs bitsandbytes, xformers, triton, then tests them.

These are primarily relevant for letting `unsloth` (and associated libraries) build on Windows.

In [None]:
%%writefile build_unsloth.py
import os
import sys
import subprocess

# 1) Write pyproject.toml with correct license syntax, allowing Python 3.9+
toml_content = """\
[project]
name = "unsloth"
version = "0.1.0"
description = "unsloth: Windows-friendly package for bitsandbytes, xformers, triton"
readme = "README.md"
requires-python = ">=3.9"

[project.license]
text = "MIT"

authors = [
  { name = "Your Name", email = "you@example.com" }
]

# Dependencies only install if environment markers match (e.g., Windows).
# On Colab Linux + CUDA 11.8, these might not do anything,
# but we still define them to show the "Windows-friendly" idea.
dependencies = [
  "torch==2.0.1+cu118; platform_system=='Windows'",
  "transformers==4.30.2",
  "accelerate==0.20.3",
  "bitsandbytes==0.39.1",
  "xformers==0.0.20",
  "triton==2.0.0",
]

[build-system]
requires = ["setuptools>=61", "wheel"]
build-backend = "setuptools.build_meta"
"""

with open("pyproject.toml", "w", encoding="utf-8") as f:
    f.write(toml_content)

# 2) Minimal package structure
os.makedirs("src/unsloth", exist_ok=True)
with open("src/unsloth/__init__.py", "w", encoding="utf-8") as f:
    f.write('# unsloth package init - minimal\n')

# Minimal README
with open("README.md", "w", encoding="utf-8") as f:
    f.write("# unsloth\n\nA Windows-friendly package with bitsandbytes, xformers, triton.\n")

print("=== pyproject.toml created. Attempting to build and install locally... ===")

# 3) Upgrade pip and install build tools
subprocess.run([
    "python", "-m", "pip", "install", "--upgrade",
    "pip", "build", "setuptools>=61", "wheel"
], check=True)

# 4) Build the wheel
build_result = subprocess.run(["python", "-m", "build"], capture_output=True, text=True)
if build_result.returncode != 0:
    print("ERROR: Build failed. Output:\n")
    print(build_result.stdout)
    print(build_result.stderr)
    sys.exit(1)

# 5) Check dist/ directory
if not os.path.isdir("dist"):
    print("ERROR: 'dist/' directory not found, build likely failed.")
    sys.exit(1)

dist_files = os.listdir("dist")
if not dist_files:
    print("ERROR: 'dist/' directory is empty, no wheel found.")
    sys.exit(1)

wheel_files = [f for f in dist_files if f.endswith(".whl")]
if not wheel_files:
    print("ERROR: No .whl file found in dist/. Found:", dist_files)
    sys.exit(1)

wheel_path = os.path.join("dist", wheel_files[0])

# 6) Install the wheel with extra index for cu118
cmd = [
    "python",
    "-m",
    "pip",
    "install",
    wheel_path,
    "--extra-index-url",
    "https://download.pytorch.org/whl/cu118"
]
print("\nInstalling wheel with command:", " ".join(cmd))

install_result = subprocess.run(cmd, capture_output=True, text=True)
if install_result.returncode != 0:
    print("ERROR: Failed to install the wheel. Output:\n")
    print(install_result.stdout)
    print(install_result.stderr)
    sys.exit(1)

print("Successfully installed the unsloth wheel from dist/!\n")
print("Installation log:")
print(install_result.stdout)

# ============== End of build_unsloth.py ==============


Overwriting build_unsloth.py


In [None]:
!python build_unsloth.py

=== pyproject.toml created. Attempting to build and install locally... ===
Collecting pip
  Downloading pip-25.0.1-py3-none-any.whl.metadata (3.7 kB)
Collecting build
  Downloading build-1.2.2.post1-py3-none-any.whl.metadata (6.5 kB)
Collecting setuptools>=61
  Downloading setuptools-75.8.0-py3-none-any.whl.metadata (6.7 kB)
Collecting pyproject_hooks (from build)
  Downloading pyproject_hooks-1.2.0-py3-none-any.whl.metadata (1.3 kB)
Downloading pip-25.0.1-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m22.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading build-1.2.2.post1-py3-none-any.whl (22 kB)
Downloading setuptools-75.8.0-py3-none-any.whl (1.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m64.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyproject_hooks-1.2.0-py3-none-any.whl (10 kB)
Installing collected packages: setuptools, pyproject_hooks, pip, build
  Attempting uninstall: se

In [None]:
################################################################################
# ONE-CELL COLAB SCRIPT: PyTorch Nightly (2.2.0 + cu121),
# bitsandbytes 0.45.2, xformers 0.0.24, tested on A100 CUDA 12.x
################################################################################

print("=== Checking GPU and driver info ===")
!nvidia-smi

print("\n=== 1) Uninstall older Torch, bitsandbytes, xformers, triton ===")
!pip uninstall -y torch bitsandbytes xformers triton

print("\n=== 2) Install PyTorch NIGHTLY 2.2.0+cu121, plus torchvision, torchaudio")
print("         from the official 'nightly/cu121' index. ===")

# We use --pre (pre-release) and a special index URL for nightly cu121 builds.
!pip install --pre torch torchvision torchaudio \
    --index-url https://download.pytorch.org/whl/nightly/cu121

print("\n=== 3) Install bitsandbytes 0.45.2 and xformers 0.0.24 (built for Torch 2.2.0+cu121) ===")
# We'll just use PyPI. bitsandbytes 0.45.2 has CUDA 12.1 support.
# xformers 0.0.24 is built for Torch 2.2.0+cu121, so it won't conflict.
!pip install bitsandbytes==0.45.2 xformers==0.0.24

print("\n=== 4) Write test_deps.py script to verify bitsandbytes, xformers, and triton ===")

test_deps_code = """import os
import sys
import torch

os.environ["BNB_CUDA_VERSION"] = "121"  # bitsandbytes tries libbitsandbytes_cuda121.so

# 1) Test bitsandbytes
try:
    import bitsandbytes as bnb
    print("\\n=== bitsandbytes import OK ===")
    linear_8bit = bnb.nn.Linear8bitLt(128, 64).cuda()
    dummy_in = torch.randn(16, 128, device='cuda', dtype=torch.float16)
    dummy_out = linear_8bit(dummy_in)
    print('bitsandbytes linear8bit forward pass successful. Output shape:', dummy_out.shape)
except Exception as ex:
    print('bitsandbytes usage error:', ex)
    sys.exit(1)

# 2) Test xformers
try:
    import xformers
    print("\\n=== xformers import OK ===")
    from xformers.ops import fmha
    q = torch.randn((1, 32, 8, 64), device='cuda', dtype=torch.float16)
    k = torch.randn((1, 32, 8, 64), device='cuda', dtype=torch.float16)
    v = torch.randn((1, 32, 8, 64), device='cuda', dtype=torch.float16)
    out = fmha.memory_efficient_attention(q, k, v)
    print('xformers fmha output shape:', out.shape)
except Exception as ex:
    print('xformers usage error:', ex)
    sys.exit(1)

# 3) Test triton (bundled in Torch 2.2.0 nightly)
try:
    import triton
    import triton.language as tl
    print("\\n=== triton import OK ===")

    @triton.jit
    def add_kernel(x_ptr, y_ptr, output_ptr, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(0)
        offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        x = tl.load(x_ptr + offset)
        y = tl.load(y_ptr + offset)
        tl.store(output_ptr + offset, x + y)

    x_t = torch.randn(1024, device='cuda')
    y_t = torch.randn(1024, device='cuda')
    output_t = torch.empty(1024, device='cuda')
    grid = (1024 // 256,)
    add_kernel[grid](x_t, y_t, output_t, BLOCK_SIZE=256)
    print('triton add_kernel test, first 5 results:', output_t[:5].tolist())
except Exception as ex:
    print('triton usage error:', ex)
    sys.exit(1)

print('\\nAll tests passed! bitsandbytes, xformers, and triton are working.')
"""

with open("test_deps.py", "w") as f:
    f.write(test_deps_code)

print("\n=== 5) Run test_deps.py to confirm everything works with Torch 2.2.0+cu121 ===")
!python test_deps.py


=== Checking GPU and driver info ===
Thu Feb 20 17:19:27 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   30C    P0             45W /  400W |       0MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
           

---
## 6) **Flexible Attention**

Here’s a snippet that builds various attention masks (causal, sliding, etc.) and uses a chunked approach, plus `torch.compile` if you like. This demonstration shows different mask types in one place.

In [None]:
import sys
import math
import torch

def build_attention_mask(seq_len, mask_type="causal", window_size=64, device="cuda"):
    """
    Creates an attention mask:
      - "causal": blocks j > i (standard auto-regressive mask).
      - "sliding": local window = ±window_size around each token.
    """
    mask = torch.zeros(seq_len, seq_len, device=device)
    if mask_type == "causal":
        # Triangular upper matrix => block j>i
        casual_mat = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
        mask[casual_mat.bool()] = float("-1e9")
    elif mask_type == "sliding":
        # For each position i, block everything outside [i - window_size, i + window_size]
        for i in range(seq_len):
            left = max(0, i - window_size)
            right = min(seq_len, i + window_size + 1)
            mask[i, :left] = float("-1e9")
            mask[i, right:] = float("-1e9")
    else:
        raise ValueError(f"Unknown mask_type={mask_type}")
    return mask

def flex_attention(q, k, v, attn_mask):
    """
    Simple scaled dot-product attention:
      q, k, v: shape [batch, seq_len, d_model]
      attn_mask: shape [seq_len, seq_len], large negative => blocked
    """
    d_model = q.shape[-1]
    # (batch, seq_len, d_model) @ (batch, d_model, seq_len) => (batch, seq_len, seq_len)
    attn_scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(d_model)

    # Apply the mask (broadcast => (batch, seq_len, seq_len))
    attn_scores = attn_scores + attn_mask.unsqueeze(0)

    # Softmax and multiply by v
    attn_probs = torch.softmax(attn_scores, dim=-1)
    out = torch.bmm(attn_probs, v)
    return out

# Fallback approach for Python 3.11:
# - If Python < 3.11 => we compile
# - If Python >= 3.11 => skip compile to avoid runtime error
if sys.version_info < (3, 11):
    compiled_flex_attention = torch.compile(flex_attention, mode="default")
    print("Using torch.compile on Python < 3.11.")
else:
    compiled_flex_attention = flex_attention
    print("Skipping torch.compile (Python 3.11+ not yet supported).")

def run_flex_attention_demo():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size = 2
    d_model = 64

    for mask_type in ["causal", "sliding"]:
        print(f"\n===> Testing mask_type = {mask_type}")
        for seq_len in [128, 256, 300, 512]:
            q = torch.randn(batch_size, seq_len, d_model, device=device)
            k = torch.randn(batch_size, seq_len, d_model, device=device)
            v = torch.randn(batch_size, seq_len, d_model, device=device)

            base_mask = build_attention_mask(seq_len, mask_type=mask_type, device=device)
            out = compiled_flex_attention(q, k, v, base_mask)
            print(f"seq_len={seq_len}, out.shape={out.shape}, mask_type={mask_type}")

if __name__ == "__main__":
    run_flex_attention_demo()


Skipping torch.compile (Python 3.11+ not yet supported).

===> Testing mask_type = causal
seq_len=128, out.shape=torch.Size([2, 128, 64]), mask_type=causal
seq_len=256, out.shape=torch.Size([2, 256, 64]), mask_type=causal
seq_len=300, out.shape=torch.Size([2, 300, 64]), mask_type=causal
seq_len=512, out.shape=torch.Size([2, 512, 64]), mask_type=causal

===> Testing mask_type = sliding
seq_len=128, out.shape=torch.Size([2, 128, 64]), mask_type=sliding
seq_len=256, out.shape=torch.Size([2, 256, 64]), mask_type=sliding
seq_len=300, out.shape=torch.Size([2, 300, 64]), mask_type=sliding
seq_len=512, out.shape=torch.Size([2, 512, 64]), mask_type=sliding


---
## 7) **Sequence Classification Patch** (LoRA + `AutoModelForSequenceClassification`)

We patch `AutoModelForSequenceClassification` by injecting LoRA modules into every `nn.Linear` in the model, then fine-tune only the LoRA parameters on a toy dataset.

In [None]:
################################################################################
# SINGLE-CELL COLAB SCRIPT:
# LoRA BERT classification w/ Torch 2.1.0+cu121 & Transformers 4.31.0
# Removing peft & older libraries => fix the 'adapter_kwargs' error.
################################################################################

print("=== Checking GPU / driver info ===")
!nvidia-smi

print("\n=== 1) Uninstall conflicting packages (torch, transformers, peft, xformers, etc.) ===")
!pip uninstall -y torch transformers peft xformers tokenizers bitsandbytes

print("\n=== 2) Install Torch 2.1.0+cu121 & Transformers==4.31.0 ===")
!pip install torch==2.1.0+cu121 --index-url https://download.pytorch.org/whl/cu121
!pip install transformers==4.31.0

print("\n=== 3) Running your LoRA BERT classification code ===")

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    AutoConfig,
)

class ToyClassificationDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=32):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt"
        )
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long)
        }

class LoRALinear(nn.Module):
    def __init__(self, in_features, out_features, lora_rank=4, alpha=1.0):
        super().__init__()
        self.lora_down = nn.Linear(in_features, lora_rank, bias=False)
        self.lora_up   = nn.Linear(lora_rank, out_features, bias=False)
        nn.init.zeros_(self.lora_down.weight)
        nn.init.zeros_(self.lora_up.weight)
        self.alpha = alpha

    def forward(self, x):
        return self.alpha * self.lora_up(self.lora_down(x))

def patch_model_for_sequence_classification(model, lora_rank=4, alpha=1.0):
    modules_to_patch = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            modules_to_patch.append((name, module))

    for full_name, module in modules_to_patch:
        safe_name = full_name.replace(".", "_")
        lora_mod = LoRALinear(
            module.in_features,
            module.out_features,
            lora_rank=lora_rank,
            alpha=alpha
        ).to(module.weight.device, module.weight.dtype)

        # Register it
        model.add_module(f"lora_{safe_name}", lora_mod)

        # Patch forward
        orig_forward = module.forward
        def custom_forward(m_self, x, orig_forward=orig_forward, lora_layer=lora_mod):
            base_out = orig_forward(x)
            lora_out = lora_layer(x)
            return base_out + lora_out

        module.forward = custom_forward.__get__(module, module.__class__)

    return model

def finetune_sequence_classification():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_name = "bert-base-uncased"
    num_labels = 2

    config = AutoConfig.from_pretrained(model_name, num_labels=num_labels)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, config=config)
    model.to(device)

    # Inject LoRA
    patch_model_for_sequence_classification(model, lora_rank=4, alpha=1.0)

    texts = [
        "I love this product, it is amazing!",
        "This is the worst experience of my life.",
        "The movie was quite entertaining.",
        "Horrible service, will not come back!"
    ]
    labels = [1, 0, 1, 0]
    dataset = ToyClassificationDataset(texts, labels, tokenizer, max_length=16)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

    # Only train LoRA params
    lora_params = []
    for param_name, param in model.named_parameters():
        if "lora_" in param_name:
            param.requires_grad = True
            lora_params.append(param)
        else:
            param.requires_grad = False

    optimizer = optim.AdamW(lora_params, lr=1e-4)
    model.train()
    epochs = 3
    for epoch in range(epochs):
        total_loss = 0.0
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            optimizer.zero_grad()
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{epochs}, avg_loss={avg_loss:.4f}")

    model.eval()
    sample_text = ["I dislike the taste, not recommended."]
    enc = tokenizer(sample_text, truncation=True, padding=True, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = model(**enc).logits
    preds = torch.argmax(logits, dim=-1)
    print("\nInference Test:")
    print(f"Input: {sample_text}")
    print(f"Logits: {logits.cpu().numpy()}")
    print(f"Predicted label: {preds.item()} (0=Neg,1=Pos)")

finetune_sequence_classification()


=== Checking GPU / driver info ===
Thu Feb 20 18:43:30 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   30C    P0             45W /  400W |       0MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
             

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1/3, avg_loss=0.8157
Epoch 2/3, avg_loss=0.7862
Epoch 3/3, avg_loss=0.7097

Inference Test:
Input: ['I dislike the taste, not recommended.']
Logits: [[-0.36591572  0.18960014]]
Predicted label: 1 (0=Neg,1=Pos)


---
## 8) **Refactored Attention**

Merging `xformers`, PyTorch’s SDPA, `flash_attn`, and a fallback “flex” approach in a single function.

In [None]:
import warnings

try:
    import xformers.ops as xops
    XFORMERS_AVAILABLE = True
except ImportError:
    XFORMERS_AVAILABLE = False

try:
    import flash_attn
    FLASH_ATTN_AVAILABLE = True
except ImportError:
    FLASH_ATTN_AVAILABLE = False

SDPA_AVAILABLE = hasattr(torch.nn.functional, "scaled_dot_product_attention")

def flex_custom_attention(q, k, v, attn_mask=None):
    d_k = q.shape[-1]
    scores = torch.matmul(q, k.transpose(-1, -2)) / (d_k ** 0.5)
    if attn_mask is not None:
        scores = scores + attn_mask
    weights = torch.softmax(scores, dim=-1)
    weights = weights.to(v.dtype)
    out = torch.matmul(weights, v)
    return out

def xformers_attention(q, k, v, attn_mask=None):
    B, H, L, D = q.shape
    q_ = q.reshape(B*H, L, D)
    k_ = k.reshape(B*H, L, D)
    v_ = v.reshape(B*H, L, D)

    bool_mask = None
    if attn_mask is not None:
        expanded = attn_mask.expand(B, H, L, L).reshape(B*H, L, L)
        bool_mask = (expanded < -1e4)
    out = xops.memory_efficient_attention(
        q_, k_, v_,
        attn_mask=bool_mask,
        p=0.0
    )
    return out.reshape(B, H, L, D)

def flash_attention(q, k, v, attn_mask=None):
    import flash_attn
    B, H, L, D = q.shape
    q_ = q.reshape(B*H, L, D)
    k_ = k.reshape(B*H, L, D)
    v_ = v.reshape(B*H, L, D)
    out = flash_attn.flash_attn_func(
        q_, k_, v_,
        dropout_p=0.0,
        softmax_scale=None,
        causal=False
    )
    return out.reshape(B, H, L, D)

def sdpa_attention(q, k, v, attn_mask=None):
    from torch.nn.functional import scaled_dot_product_attention as sdpa
    B, H, L, D = q.shape
    q_ = q.permute(2, 0, 1, 3).reshape(L, B*H, D)
    k_ = k.permute(2, 0, 1, 3).reshape(L, B*H, D)
    v_ = v.permute(2, 0, 1, 3).reshape(L, B*H, D)

    am = None
    if attn_mask is not None:
        am = attn_mask.expand(B, H, L, L).reshape(B*H, L, L)
    out_ = sdpa(q_, k_, v_, attn_mask=am, dropout_p=0.0, is_causal=False)
    out = out_.reshape(L, B, H, D).permute(1, 2, 0, 3)
    return out

def unified_attention(q, k, v, attn_mask=None, backend="auto"):
    if backend == "auto":
        if XFORMERS_AVAILABLE:
            backend = "xformers"
        elif FLASH_ATTN_AVAILABLE:
            backend = "flash"
        elif SDPA_AVAILABLE:
            backend = "sdpa"
        else:
            backend = "flex"

    if backend == "xformers":
        if not XFORMERS_AVAILABLE:
            raise RuntimeError("xformers not installed!")
        return xformers_attention(q, k, v, attn_mask)
    elif backend == "flash":
        if not FLASH_ATTN_AVAILABLE:
            raise RuntimeError("flash_attn not installed!")
        return flash_attention(q, k, v, attn_mask)
    elif backend == "sdpa":
        if not SDPA_AVAILABLE:
            raise RuntimeError("PyTorch >=2.0 needed for SDPA!")
        return sdpa_attention(q, k, v, attn_mask)
    else:
        return flex_custom_attention(q, k, v, attn_mask)

# Demo usage
def example_unified_attention():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    B, H, L, D = 2, 4, 16, 64
    q = torch.randn(B, H, L, D, device=device, dtype=torch.float16)
    k = torch.randn(B, H, L, D, device=device, dtype=torch.float16)
    v = torch.randn(B, H, L, D, device=device, dtype=torch.float16)
    attn_mask = torch.zeros((B, 1, L, L), device=device, dtype=torch.float32)
    blocked = torch.rand((B, 1, L, L), device=device) < 0.2
    attn_mask[blocked] = float("-inf")
    out_flex = unified_attention(q, k, v, attn_mask, backend="flex")
    print("fallback =>", out_flex.shape)

if __name__ == "__main__":
    example_unified_attention()


fallback => torch.Size([2, 4, 16, 64])


---
## Final Notes

- This notebookincludes separate code snippets for each task.
- Some cells (like the nF4 → Triton example) are skeletons or placeholders to illustrate core ideas.
