<a href="https://colab.research.google.com/github/byi8220/unsloth-puzzles/blob/main/Problem3/Unsloth_Problem_3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Unsloth Problem 3 - Make torch.compile work without graph breaks for QLoRA

Run on an NVIDIA L4 Colab instance (the T4 does not support bfloat16).

**NOTE:** Funnily enough, this exact problem is discussed at https://discuss.pytorch.org/t/how-to-solve-the-graph-break-happen-in-torch-compile/216858/1

In [1]:
%%capture
# Code to install Unsloth, Triton, Torch etc
# We're installing torch nightly to work around https://discuss.pytorch.org/t/how-to-solve-the-graph-break-happen-in-torch-compile/216858/1
# This leads to some fun version breaks...
!pip install --no-cache-dir --force-reinstall accelerate huggingface_hub datasets trl hf_transfer triton

!pip install --no-deps --no-cache-dir --force-reinstall bitsandbytes
!pip install --no-cache-dir --force-reinstall transformers==4.49.0

!pip install --pre --no-cache-dir --force-reinstall torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124
!pip install --force-reinstall -U numpy
!pip install --force-reinstall -U scipy

In [2]:
# Helpful functions used through the entire notebook
import torch
import torch.nn as nn
from transformers import set_seed
import time
import inspect
import os
major_version, minor_version = torch.cuda.get_device_capability()
HAS_BFLOAT16 = (major_version >= 8)
from inspect import currentframe as _C, getframeinfo
_F = lambda c: getframeinfo(c).lineno # Gets line number
WARN = lambda x: print(f"\033[31m{x}\033[0m") # Red colored warnings

# https://stackoverflow.com/questions/18425225/getting-the-name-of-a-variable-as-a-string
def NAME(var):
    callers_local_vars = inspect.currentframe().f_back.f_locals.items()
    names = [var_name for var_name, var_val in callers_local_vars if var_val is var]
    return names[0] if len(names) != 0 else ""

def assert_same(x, y, line, dtype):
    assert(x.dtype == dtype)
    try: torch.testing.assert_close(x, y, check_stride = True)
    except Exception as error:
        raise RuntimeError(
            f"Failed allclose at line [{line}]: {NAME(x)}, {NAME(y)}\n{str(error)}"
        )

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

---
---
---
<a name="COMPILE"></a>
## C) Make `torch.compile` work without graph breaks for QLoRA [Difficulty: Easy to Medium] [Max points: 9]

1. Goal: Write a single Python script like task B), except the goal is to `torch.compile` all modules if possible.

2. There must NOT be graph breaks, and excessive re-compilations should not be seen.

3. You should see at most about 30 compilations. Over 60 is definitely wrong.

4. The loss must match with the non compiled module.

5. Utilize patching as much as possible.

6. Think about which areas might need disabling for compilation. Think about regional compilation. How do we compile sections efficiently?

7. Log memory / VRAM usage, and monitor speedups as well.

8. Must work for QLoRA.
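
On items 5 and 6: one reason class-level patching suits regional compilation is that assigning to `Class.forward` replaces the method for every instance at once, so all repeated layers go through a single wrapped function. The pure-Python sketch below illustrates only that mechanism. `traced` and `Block` are hypothetical stand-ins (for `torch.compile` and a repeated decoder block), not real APIs, and whether compiled code is actually reused across layers depends on Dynamo's guards.

```python
import functools

def traced(fn):
    """Hypothetical stand-in for torch.compile: counts how often it wraps a function."""
    traced.wrap_count += 1

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        wrapper.calls += 1
        return fn(*args, **kwargs)

    wrapper.calls = 0
    return wrapper

traced.wrap_count = 0

class Block:
    """Toy layer; a real model repeats a block like this many times."""
    def __init__(self, scale):
        self.scale = scale

    def forward(self, x):
        return x * self.scale

# Instances created before the patch still pick it up, because method
# lookup goes through the class.
layers = [Block(s) for s in (1, 2, 3)]
Block.forward = traced(Block.forward)

outputs = [layer.forward(10) for layer in layers]
print(outputs, traced.wrap_count, Block.forward.calls)
```

The wrapper is created once but serves all three layers, which is the property the `LlamaMLP.forward` patch in the next cell relies on.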

We provide a script below that shows how to detect graph breaks. It also applies `torch.compile` to the Llama MLP:

In [3]:
import torch
torch_compile_options = {
    "epilogue_fusion"   : True,
    "max_autotune"      : True,
    "shape_padding"     : True,
    "trace.enabled"     : True,
    "triton.cudagraphs" : False,
}

# Enable `fullgraph` to stop processing on graph break.
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def compiled_llama_mlp(self, x):
    down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
    return down_proj

import transformers.models.llama.modeling_llama
transformers.models.llama.modeling_llama.LlamaMLP.forward = compiled_llama_mlp

# Compile flex attn
import transformers.integrations.flex_attention
transformers.integrations.flex_attention.flex_attention_forward = torch.compile(
    transformers.integrations.flex_attention.flex_attention_forward,
    fullgraph = True, dynamic = True, options = torch_compile_options
)

# Compile layernorm
transformers.models.llama.modeling_llama.LlamaRMSNorm.forward = torch.compile(
    transformers.models.llama.modeling_llama.LlamaRMSNorm.forward,
    fullgraph = True, dynamic = True, options = torch_compile_options
)


In [4]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import get_peft_model, LoraConfig, TaskType
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = \
    "expandable_segments:True,"\
    "roundup_power2_divisions:[32:256,64:128,256:64,>:32]"

max_seq_length = 1024
torch.set_default_dtype(torch.float16)
model_name = "unsloth/Llama-3.2-1B-Instruct-bnb-4bit"
dtype = torch.float16
bnb_config = BitsAndBytesConfig(
    load_in_4bit              = True,
    bnb_4bit_use_double_quant = True,
    bnb_4bit_quant_type       = "nf4",
    bnb_4bit_compute_dtype    = dtype,
)

model2 = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map = "auto",
    attn_implementation = "flex_attention",
    quantization_config = bnb_config,
    torch_dtype = dtype, # Need to manually move dtypes
)
# Compile loss function
model2.loss_function = torch.compile(model2.loss_function, fullgraph = True, dynamic = True, options = torch_compile_options)

# Need to manually set compute_dtype.
setattr(model2.config.quantization_config, "bnb_4bit_compute_dtype", dtype)

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "right"

# Get LoRA and setup model
lora_config = LoraConfig(
    r = 64,
    lora_alpha = 128,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
    lora_dropout = 0,
    bias = "none",
    task_type = TaskType.CAUSAL_LM,
)

model2 = get_peft_model(model2, lora_config)

with torch.no_grad():
    for name, param in model2.named_parameters():
        if ".lora_A." in name or ".lora_B." in name: param.requires_grad_(True)
        else: param.requires_grad_(False)

# Gradient checkpointing currently causes torch.compile to be disabled, so leave it off
# model2.gradient_checkpointing_enable()
model2.enable_input_require_grads()

# Get dataset
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
url = "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"
dataset = load_dataset("json", data_files = {"train" : url}, split = "train[:10%]")



We provide full logging for `torch.compile` like below:

In [5]:
# Must show that no graph breaks occur with torch.compile
import os
os.environ["TORCHDYNAMO_VERBOSE"] = "1"
os.environ["TORCHINDUCTOR_FORCE_DISABLE_CACHES"] = "1"
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

import logging
torch._inductor.config.debug = True
torch._logging.set_logs(
    dynamo = logging.WARN,
    inductor = logging.WARN,
    graph_breaks = True,
    recompiles = True,
    recompiles_verbose = True,
    compiled_autograd_verbose = True,
    # aot_joint_graph = True, # Enable for more logs
    # aot_graphs = True,
)
torch._dynamo.config.verbose = True
torch._dynamo.config.suppress_errors = False

# Solution: Patching model to allow compilation

Why are there graph breaks? In the stack trace from the unpatched run, a few lines stand out:

```
Reason: Unsupported: call_method UserDefinedObjectVariable(Params4bit) t [] {}
```

```
Reason: Unsupported: Graph break due to unsupported builtin None._SimpleCData.__new__. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.
```

These errors suggest that fixing the graph breaks requires getting custom ops working for the 4-bit quantized params.

That is, `torch.compile` should stop tracing into the dequant code and treat it as an isolated black box.

One option is to wrap [the code which actually calls the CUDA dequant](https://github.com/bitsandbytes-foundation/bitsandbytes/blob/e772a9e8723cfc2036fecc830c328ad3b9705250/bitsandbytes/functional.py#L1028-L1046) in a custom operator, which may be hackable with some aggressive monkeypatching somewhere within [MatMul4Bit](https://github.com/bitsandbytes-foundation/bitsandbytes/blob/e772a9e8723cfc2036fecc830c328ad3b9705250/bitsandbytes/autograd/_functions.py#L441).


The first error expands to:
```
Unsupported: call_method UserDefinedObjectVariable(Params4bit) t [] {}

from user code:
   File "<ipython-input-4-1b0083b2d5de>", line 12, in compiled_llama_mlp
    down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  File "/usr/local/lib/python3.11/dist-packages/peft/tuners/lora/bnb.py", line 496, in forward
    result = self.base_layer(x, *args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/bitsandbytes/nn/modules.py", line 484, in forward
    return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
```

This suggests that the break is triggered by `Params4bit.t()`, which is called when the quantized weights are passed to the matmul.

One way of solving this is to:

1. Refactor `Params4bit` so that the transposed weights can be obtained without a graph break
2. Refactor and patch `Linear4bit` to use this new path
3. Wrap `dequantize_4bit` in a custom op

Alternatively, since `Params4bit` is a subclass of `Parameter`, it may be possible to convince Dynamo to stop wrapping it in `UserDefinedObjectVariable` and instead treat it as a `TensorVariable`.
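
Step 3 in the list above has a practical catch: custom op schemas describe tensors and simple scalars, not arbitrary Python objects, so a nested quant state has to be unpacked into plain fields on the way in and rebuilt inside the op. The sketch below shows that round trip in pure Python. `ToyState`, `flatten` and `rebuild` are hypothetical names for illustration, not the bitsandbytes `QuantState` API, and lists of floats stand in for tensors.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

@dataclass
class ToyState:
    """Hypothetical stand-in for a nested quantization state."""
    shape: List[int]
    blocksize: int
    absmax: List[float]                   # would be a tensor in the real setting
    state2: Optional["ToyState"] = None   # second-level state for double quantization

Flat = Tuple[List[int], int, List[float], Optional[int], Optional[List[float]]]

def flatten(state: ToyState) -> Flat:
    """Unpack the nested object into plain fields an op schema could describe."""
    inner = state.state2
    return (
        state.shape,
        state.blocksize,
        state.absmax,
        None if inner is None else inner.blocksize,
        None if inner is None else inner.absmax,
    )

def rebuild(flat: Flat) -> ToyState:
    """Inverse of flatten: what the body of the op would do before dequantizing."""
    shape, blocksize, absmax, blocksize2, absmax2 = flat
    state2 = None
    if absmax2 is not None:
        state2 = ToyState(shape=[], blocksize=blocksize2, absmax=absmax2)
    return ToyState(shape, blocksize, absmax, state2)

nested = ToyState([4, 4], 64, [0.5], ToyState([], 256, [1.5]))
flat = flatten(nested)
roundtrip = rebuild(flat)
print(flat)
print(roundtrip == nested)
```

The wrapper and op defined later in this notebook follow the same split: the compiled wrapper unpacks `quant_state`, and the op body reassembles it before calling the original dequantize function.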

In [6]:
# Define a custom matmul_4bit function
import torch
from typing import Callable, Optional, Tuple, List
import bitsandbytes.functional as F
from bitsandbytes.functional import QuantState
from bitsandbytes.nn.modules import Params4bit, Linear4bit, fix_4bit_weight_quant_state_from_module
import bitsandbytes

def _l4b_forward(self, x):
    fix_4bit_weight_quant_state_from_module(self)

    # weights are cast automatically as Int8Params, but the bias has to be cast manually
    if self.bias is not None and self.bias.dtype != x.dtype:
        self.bias.data = self.bias.data.to(x.dtype)

    if not self.compute_type_is_set:
        self.set_compute_type(x)
        self.compute_type_is_set = True

    inp_dtype = x.dtype
    if self.compute_dtype is not None:
        x = x.to(self.compute_dtype)
    bias = None if self.bias is None else self.bias.to(self.compute_dtype)
    weight_tensor = self.weight.data

    return bitsandbytes.matmul_4bit(x, weight_tensor.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)

Linear4bit.forward = _l4b_forward


In [7]:
# Define a custom dequantize op which doesn't cause graph breaks
import torch
from typing import Callable, Optional, Tuple, List
import bitsandbytes.functional as F
from bitsandbytes.functional import QuantState

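# Keep a handle on the original only once, so re-running this cell does not capture the patched version
if '_dequantize_4bit' not in globals():
    _dequantize_4bit = F.dequantize_4bit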

@torch.library.custom_op("my_qlora::dequantize_4bit", mutates_args=(["out"]))
def dequantize_4bit_op(
    A: torch.Tensor,
    shape: List[int],
    absmax: Optional[torch.Tensor] = None,
    code: Optional[torch.Tensor] = None,
    blocksize: int = 4096,
    dtype: torch.dtype = torch.float16,
    offset: Optional[torch.Tensor] = None,
    absmax2: Optional[torch.Tensor] = None,
    code2: Optional[torch.Tensor] = None,
    dtype2: Optional[torch.dtype] = None,
    blocksize2: Optional[int] = None,
    quant_type: str = "fp4",
    out: torch.Tensor = None, # `out` is not optional (torch.compile doesn't like returning its inputs)
) -> None:
    # Rebuild quant state
    state2 = None
    if code2 is not None:
        state2 = QuantState(
            absmax=absmax2,
            blocksize=blocksize2,
            code=code2,
        )
    state = QuantState(
        shape=shape,
        absmax=absmax,
        code=code,
        blocksize=blocksize,
        dtype=dtype,
        offset=offset,
        state2=state2
    )
    _dequantize_4bit(A, state, absmax, out, blocksize, quant_type)

# Need to transform `quant_state` into a form accepted by custom ops
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def dequantize_4bit_wrapper(
    A: torch.Tensor,
    quant_state: Optional[QuantState] = None,
    absmax: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    blocksize: int = 64,
    quant_type: str ="fp4",
):
    # Unpack quant state
    if absmax is None:
        absmax = quant_state.absmax
    code = quant_state.code
    blocksize = blocksize if quant_state is None else quant_state.blocksize
    dtype = quant_state.dtype
    offset = quant_state.offset

    state2 = quant_state.state2
    absmax2, code2, dtype2, blocksize2 = None, None, None, None
    if quant_state.nested:
        absmax2, code2, dtype2 = state2.absmax, state2.code, state2.dtype
        blocksize2 = blocksize if state2.blocksize is None else state2.blocksize

    is_transposed = A.shape[0] == 1
    if out is None:
        out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)
    dequantize_4bit_op(
        A,
        quant_state.shape,
        absmax, code, blocksize,
        dtype,
        offset,
        absmax2, code2, dtype2, blocksize2,
        quant_type,
        out,
    )
    if is_transposed:
        return out.t()
    return out

F.dequantize_4bit = dequantize_4bit_wrapper # Put our new op in.

assert(_dequantize_4bit != F.dequantize_4bit)
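
The `'_dequantize_4bit' in globals()` guard in the cell above matters because notebook cells get re-run. Without it, a second run would capture the already-patched wrapper as the "original", and the op would then call back into the wrapper. The pure-Python sketch below shows one way to make such a patch idempotent. `lib` and `install_patch` are hypothetical stand-ins, not part of bitsandbytes.

```python
import types

# Hypothetical stand-in for a third-party module such as `bitsandbytes.functional`.
lib = types.SimpleNamespace(dequantize=lambda x: x * 2)

def install_patch(module):
    """Wrap `module.dequantize`; safe to call any number of times."""
    if not hasattr(module, "_original_dequantize"):
        # Stash the real implementation only on the first call.
        module._original_dequantize = module.dequantize

    original = module._original_dequantize

    def patched(x):
        patched.calls += 1
        return original(x)  # always delegates to the stashed original

    patched.calls = 0
    module.dequantize = patched

install_patch(lib)
install_patch(lib)  # simulates re-running the notebook cell

result = lib.dequantize(21)
print(result, lib.dequantize.calls)
```

Storing the original on the module rather than in notebook globals is a design choice for this sketch. It also gives a simple way to undo the patch by assigning `_original_dequantize` back.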

In [8]:
#@title Retrain the model with our patched kernels.
import gc
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
peak_memory_before = torch.cuda.max_memory_allocated()
trainer = SFTTrainer(
    model = model2,
    train_dataset = dataset,
    processing_class = tokenizer,
    args = SFTConfig(
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 2,
        warmup_steps = 1,
        max_steps = 10, # Run many steps just so compilation is actually worth it.
        logging_steps = 1,
        output_dir = "outputs_new",
        seed = 3407,
        max_seq_length = max_seq_length,
        fp16 = model2.get_input_embeddings().weight.dtype == torch.float16,
        bf16 = model2.get_input_embeddings().weight.dtype == torch.bfloat16,
        report_to = "none", # For W&B
        dataset_num_proc = 4,
    ),
)
patched_stats = trainer.train()

torch.cuda.synchronize()
peak_memory_after = torch.cuda.max_memory_allocated()

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
W0410 03:47:20.411000 14548 torch/_inductor/debug.py:454] [1/0] model__1_inference_0 debug trace: /content/torch_compile_debug/run_2025_04_10_03_47_20_295335-pid_14548/torchinductor/model__1_inference_0.0
V0410 03:47:20.654000 14548 torch/_dynamo/guards.py:2974] [1/1] [__recompiles_verbose] Recompiling function dequantize_4bit_wrapper in <ipython-input-7-3aa9c7a62eec>:46
V0410 03:47:20.654000 14548 torch/_dynamo/guards.py:2974] [1/1] [__recompiles_verbose]     triggered by the following guard failure(s):
V0410 03:47:20.654000 14548 torch/_dynamo/guards.py:2974] [1/1] [__recompiles_verbose]     guard 0 failures:
V0410 03:47:20.654000 14548 torch/_dynamo/guards.py:2974] [1/1] [__recompiles_verbose]     - 1/0: L['quant_state'].state2.a

Step,Training Loss
1,1.5386
2,2.3994
3,2.4433
4,3.437
5,2.0862
6,2.8635
7,2.0892
8,1.4658
9,2.0902
10,2.3545
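
To check the recompilation budget from the task description, the `__recompiles_verbose` lines in the log above can be tallied per function. The sketch below parses the three log lines copied from this run. The regular expression is an assumption based on the format seen here and may need adjusting for other PyTorch versions. It counts recompiles only, so first-time compilations are not included.

```python
import re
from collections import Counter

# Log lines copied from the training output above.
log_text = """\
V0410 03:47:20.654000 14548 torch/_dynamo/guards.py:2974] [1/1] [__recompiles_verbose] Recompiling function dequantize_4bit_wrapper in <ipython-input-7-3aa9c7a62eec>:46
V0410 03:47:20.654000 14548 torch/_dynamo/guards.py:2974] [1/1] [__recompiles_verbose]     triggered by the following guard failure(s):
V0410 03:47:20.654000 14548 torch/_dynamo/guards.py:2974] [1/1] [__recompiles_verbose]     guard 0 failures:
"""

# Matches only the headline line of each recompile event, not its indented details.
RECOMPILE = re.compile(
    r"\[(?P<frame>\d+)/(?P<attempt>\d+)\] \[__recompiles_verbose\] "
    r"Recompiling function (?P<name>\S+)"
)

def count_recompiles(text):
    """Return a Counter mapping function name -> number of recompile events."""
    counts = Counter()
    for line in text.splitlines():
        match = RECOMPILE.search(line)
        if match:
            counts[match.group("name")] += 1
    return counts

recompiles = count_recompiles(log_text)
print(dict(recompiles))
print("within budget:", sum(recompiles.values()) <= 30)
```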


In [9]:
mem_diff = peak_memory_after - peak_memory_before
mem_diff_gb = (mem_diff) / (1024**3)

print("Peak VRAM usage during training: {:.2f} GB".format(mem_diff_gb))
print("Train Stats", patched_stats)

Peak VRAM usage during training: 0.87 GB
Train Stats TrainOutput(global_step=10, training_loss=2.2767750144004824, metrics={'train_runtime': 31.8307, 'train_samples_per_second': 0.628, 'train_steps_per_second': 0.314, 'total_flos': 10954170839040.0, 'train_loss': 2.2767750144004824})


Compare with: https://colab.research.google.com/drive/1mzqLo8c9lJ0eewV858qp5zEwx9t2M1gp#scrollTo=my-M_OQnDfPk

```
# Reference results, with graph breaks
Peak VRAM usage during training: 0.88 GB
Train Stats TrainOutput(global_step=10, training_loss=2.3851507186889647, metrics={'train_runtime': 62.6108, 'train_samples_per_second': 0.319, 'train_steps_per_second': 0.16, 'total_flos': 10592155496448.0, 'train_loss': 2.3851507186889647})
```

Peak VRAM usage is essentially unchanged (0.87 GB vs. 0.88 GB), while `train_runtime` drops from 62.6 s to 31.8 s.
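
For a side-by-side view, the figures from the two `TrainOutput` lines above can be compared directly. This is plain arithmetic on the reported numbers, using `train_runtime` as the only timing signal available here. The loss gap is included because requirement 4 asks for matching losses, though the two runs may differ in ways not shown in this notebook.

```python
# Figures copied from the two TrainOutput lines above.
patched = {"runtime_s": 31.8307, "peak_gb": 0.87, "loss": 2.2767750144004824}
reference = {"runtime_s": 62.6108, "peak_gb": 0.88, "loss": 2.3851507186889647}

speedup = reference["runtime_s"] / patched["runtime_s"]
vram_delta_gb = patched["peak_gb"] - reference["peak_gb"]
loss_gap = abs(patched["loss"] - reference["loss"])

print(f"speedup: {speedup:.2f}x")
print(f"peak VRAM delta: {vram_delta_gb:+.2f} GB")
print(f"mean training loss gap: {loss_gap:.4f}")
```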
