# 🤙 gpt-oss-20b on NVIDIA Brev

<div style="background: linear-gradient(90deg, #00ff87 0%, #60efff 100%); padding: 1px; border-radius: 8px; margin: 20px 0;">
    <div style="background: #0a0a0a; padding: 20px; border-radius: 7px;">
        <p style="color: #60efff; margin: 0;"><strong>⚡ Powered by Brev</strong> | Converted from <a href="https://github.com/unslothai/notebooks/blob/main/nb/gpt-oss-20b.ipynb" style="color: #00ff87;">Unsloth Notebook</a></p>
    </div>
</div>

## 📋 Configuration

<table style="width: auto; margin-left: 0; border-collapse: collapse; border: 2px solid #808080;">
    <thead>
        <tr style="border-bottom: 2px solid #808080;">
            <th style="text-align: left; padding: 8px 12px; border-right: 2px solid #808080; font-weight: bold;">Parameter</th>
            <th style="text-align: left; padding: 8px 12px; font-weight: bold;">Value</th>
        </tr>
    </thead>
    <tbody>
        <tr>
            <td style="text-align: left; padding: 8px 12px; border-right: 1px solid #808080;"><strong>Model</strong></td>
            <td style="text-align: left; padding: 8px 12px;">gpt-oss-20b</td>
        </tr>
        <tr>
            <td style="text-align: left; padding: 8px 12px; border-right: 1px solid #808080;"><strong>Recommended GPU</strong></td>
            <td style="text-align: left; padding: 8px 12px;">A100-40GB</td>
        </tr>
        <tr>
            <td style="text-align: left; padding: 8px 12px; border-right: 1px solid #808080;"><strong>Min VRAM</strong></td>
            <td style="text-align: left; padding: 8px 12px;">24 GB</td>
        </tr>
        <tr>
            <td style="text-align: left; padding: 8px 12px; border-right: 1px solid #808080;"><strong>Batch Size</strong></td>
            <td style="text-align: left; padding: 8px 12px;">2</td>
        </tr>
        <tr>
            <td style="text-align: left; padding: 8px 12px; border-right: 1px solid #808080;"><strong>Categories</strong></td>
            <td style="text-align: left; padding: 8px 12px;">reasoning, fine-tuning, large-model</td>
        </tr>
    </tbody>
</table>

## 🔧 Key Adaptations for Brev

- ✅ Replaced Colab-specific installation with conda-based Unsloth
- ✅ Converted magic commands to subprocess calls
- ✅ Removed Google Drive dependencies
- ✅ Updated output paths to `/workspace/`
- ✅ Added `device_map="auto"` for multi-GPU support
- ✅ Optimized batch sizes for NVIDIA GPUs

## 📚 Resources

- [Unsloth Documentation](https://docs.unsloth.ai/)
- [Brev Documentation](https://docs.nvidia.com/brev)
- [Original Notebook](https://github.com/unslothai/notebooks/blob/main/nb/gpt-oss-20b.ipynb)



<div class="align-center">
<a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
<a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
<a href="https://docs.unsloth.ai/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">GitHub</a> </i> ⭐
</div>

To install Unsloth on your local device, follow [our guide](https://docs.unsloth.ai/get-started/install-and-update). This notebook is licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).

You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), and [how to save it](#Save).


### News


Unsloth's [Docker image](https://hub.docker.com/r/unsloth/unsloth) is here! Start training with no setup & environment issues. [Read our Guide](https://docs.unsloth.ai/new/how-to-train-llms-with-unsloth-and-docker).

[gpt-oss RL](https://docs.unsloth.ai/new/gpt-oss-reinforcement-learning) is now supported with the fastest inference & lowest VRAM. Try our [new notebook](https://github.com/unslothai/notebooks/blob/main/nb/gpt-oss-(20B)-GRPO.ipynb) which creates kernels!

Introducing [Vision](https://docs.unsloth.ai/new/vision-reinforcement-learning-vlm-rl) and [Standby](https://docs.unsloth.ai/basics/memory-efficient-rl) for RL! Train Qwen, Gemma, and other VLMs with GSPO, even faster and with less VRAM.

Unsloth now supports Text-to-Speech (TTS) models. Read our [guide here](https://docs.unsloth.ai/basics/text-to-speech-tts-fine-tuning).

Visit our docs for all our [model uploads](https://docs.unsloth.ai/get-started/all-our-models) and [notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks).


### Installation

In [None]:
# Environment Check for Brev
import sys
import os
import shutil

print(f"Python executable: {sys.executable}")
print(f"Python version: {sys.version}")

# Configure PyTorch cache directories to avoid permission errors
# MUST be set before any torch imports
# Prefer /ephemeral for Brev instances (larger scratch space)

# Test if /ephemeral exists and is actually writable (not just readable)
use_ephemeral = False
if os.path.exists("/ephemeral"):
    try:
        test_file = "/ephemeral/.write_test"
        with open(test_file, "w") as f:
            f.write("test")
        os.remove(test_file)
        use_ephemeral = True
    except (PermissionError, OSError):
        pass

if use_ephemeral:
    cache_base = "/ephemeral/torch_cache"
    triton_cache = "/ephemeral/triton_cache"
    tmpdir = "/ephemeral/tmp"
    print("Using /ephemeral for cache (Brev scratch space)")
else:
    cache_base = os.path.expanduser("~/.cache/torch/inductor")
    triton_cache = os.path.expanduser("~/.cache/triton")
    tmpdir = os.path.expanduser("~/.cache/tmp")
    print("Using home directory for cache")

# Set ALL PyTorch/Triton cache and temp directories
os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_base
os.environ["TORCH_COMPILE_DIR"] = cache_base
os.environ["TRITON_CACHE_DIR"] = triton_cache
os.environ["XDG_CACHE_HOME"] = os.path.expanduser("~/.cache")
os.environ["TMPDIR"] = tmpdir  # Override system /tmp
os.environ["TEMP"] = tmpdir
os.environ["TMP"] = tmpdir

# Create cache directories (mode 0o777 is requested, but the process umask still applies)
for cache_dir in [cache_base, triton_cache, tmpdir, os.environ["XDG_CACHE_HOME"]]:
    os.makedirs(cache_dir, mode=0o777, exist_ok=True)

# Clean up any old compiled caches that point to /tmp
old_cache = os.path.join(os.getcwd(), "unsloth_compiled_cache")
if os.path.exists(old_cache):
    print(f"⚠️  Removing old compiled cache: {old_cache}")
    shutil.rmtree(old_cache, ignore_errors=True)

print(f"✅ PyTorch cache: {cache_base}")

try:
    from unsloth import FastLanguageModel
    import transformers
    print("\n✅ Unsloth already available")
    print(f"   Unsloth: {FastLanguageModel.__module__}")
    print(f"   Transformers: {transformers.__version__}")
    
    # Check whether transformers matches the pinned version
    # (transformers.__version__ avoids the deprecated pkg_resources API)
    if transformers.__version__ != "4.56.2":
        print(f"   ⚠️  Transformers {transformers.__version__} != 4.56.2, may need adjustment")
    
    print("   ✅ All packages OK, skipping installation")
except ImportError:
    print("\n⚠️  Unsloth not found - installing required packages...")
    import subprocess
    
    # Find uv in common locations
    uv_paths = [
        "uv",  # In PATH
        os.path.expanduser("~/.venv/bin/uv"),
        os.path.expanduser("~/.cargo/bin/uv"),
        "/usr/local/bin/uv"
    ]
    
    uv_cmd = None
    for path in uv_paths:
        try:
            result = subprocess.run([path, "--version"], capture_output=True, timeout=2)
            if result.returncode == 0:
                uv_cmd = path
                print(f"   Found uv at: {path}")
                break
        except (FileNotFoundError, subprocess.TimeoutExpired):
            continue
    
    print(f"\nInstalling packages into: {sys.executable}")
    
    if uv_cmd:
        print("Using uv package manager...\n")
        try:
            subprocess.check_call([uv_cmd, "pip", "install", "unsloth"])
            subprocess.check_call([uv_cmd, "pip", "install", "transformers==4.56.2"])
            subprocess.check_call([uv_cmd, "pip", "install", "--no-deps", "trl==0.22.2"])
            print("\n✅ Installation complete")
        except subprocess.CalledProcessError as e:
            print(f"⚠️  uv install failed: {e}")
            uv_cmd = None  # Fall back to pip
    
    if not uv_cmd:
        print("Using pip package manager...\n")
        try:
            # Ensure pip is available
            subprocess.run([sys.executable, "-m", "ensurepip", "--upgrade"], 
                         capture_output=True, timeout=30)
            # Install packages
            subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "unsloth"])
            subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "transformers==4.56.2"])
            subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "--no-deps", "trl==0.22.2"])
            print("\n✅ Installation complete")
        except subprocess.CalledProcessError as e:
            print(f"❌ Installation failed: {e}")
            print("   This may be due to permission issues.")
            print("   Packages may already be installed - attempting to continue...")
    
    # Verify installation
    try:
        from unsloth import FastLanguageModel
        print("✅ Unsloth is now available")
    except ImportError as e:
        print(f"❌ Unsloth still not available: {e}")
        print("⚠️  Please check setup script ran successfully or restart instance")

### Unsloth

# Goal: Make faster kernels with Reinforcement Learning

Our goal is to make a faster matrix multiplication kernel by doing RL on GPT-OSS 20B with Unsloth.

<img src="https://upload.wikimedia.org/wikipedia/commons/thumb/1/18/Matrix_multiplication_qtl1.svg/500px-Matrix_multiplication_qtl1.svg.png" height=200 />

You will learn how to:
1. Counteract **reward hacking** such as cheating, caching, and laziness.
2. Time kernels, check their correctness, and enforce time limits.
3. Make good **reward functions**.
4. Do RL seriously to make optimized CUDA kernels.

In [None]:
from unsloth import FastLanguageModel
import torch
max_seq_length = 768 # Can increase for longer RL output
lora_rank = 4 # Larger rank = smarter, but slower
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/gpt-oss-20b",
    max_seq_length = max_seq_length,
    load_in_4bit = True, # False for LoRA 16bit
    offload_embedding = True, # Reduces VRAM by 1GB
    device_map = "auto", # Let Accelerate place the model on the available GPUs
)

We now add a small number of LoRA weights to GPT-OSS, so we only need to train those instead of the full model.

In [None]:
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = lora_rank*2, # *2 speeds up training
    use_gradient_checkpointing = "unsloth", # Reduces memory usage
    random_state = 3407,
)
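To see why this is cheap: for a weight matrix of shape `d_out x d_in`, LoRA trains two small matrices, B (`d_out x r`) and A (`r x d_in`), and adds their product, scaled by `lora_alpha / r` in standard LoRA implementations such as PEFT, to the frozen weight. The sketch below counts those parameters for a hypothetical 4096 x 4096 projection (an illustrative shape, not gpt-oss's actual layer size); the helper names are ours.

```python
def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters LoRA adds to one d_out x d_in weight: B (d_out x r) plus A (r x d_in)."""
    return r * (d_in + d_out)

def lora_scaling(lora_alpha: float, r: int) -> float:
    """Standard LoRA multiplies the low-rank update B @ A by lora_alpha / r."""
    return lora_alpha / r

# Hypothetical 4096 x 4096 projection, chosen only for illustration
d_in, d_out, r = 4096, 4096, 4
full = d_in * d_out
lora = lora_param_count(d_in, d_out, r)
print(f"full weight: {full:,} params, LoRA adapter: {lora:,} params ({lora / full:.2%})")
print("update scaling with lora_alpha = 2 * r:", lora_scaling(2 * r, r))
```

With `r = 4` the adapter is about 0.2% of this layer's size, and it grows linearly as `lora_rank` increases.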

# Optimized matrix multiplication

NumPy has optimized matrix multiplication kernels for CPUs via BLAS. For GPUs, one can use CUDA-accelerated cuBLAS kernels, which PyTorch calls under the hood.

To generate random matrices for matrix multiplication, we can use the function below:

In [None]:
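import numpy as np
def generate_random_matrices(seed = 3407, n = 256):
    # Sample shapes (n x k) and (k x m) with every dimension in [1, n]
    random_state = np.random.RandomState(seed)
    n, k, m = random_state.randint(1, n+1, size = 3)
    # Use the seeded generator (not the global np.random) so the values are reproducible too
    A = random_state.uniform(-10, 10, size = (n, k))
    B = random_state.uniform(-10, 10, size = (k, m))
    return A, A.tolist(), B, B.tolist()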

Let's generate small matrices and look at their product:

In [None]:
A, A_list, B, B_list = generate_random_matrices(seed = 42, n = 5)
print(A)
print(B)
print(np.matmul(A, B))

We can ask an LLM to generate a simple matrix multiplication kernel in pure Python, and then measure the difference between the kernel's result and the actual result:

In [None]:
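def calculate_difference(pred, real):
    # Returns (max absolute error, mean squared error); (5, 5) flags an invalid prediction
    import numpy as np
    if pred is None: return 5, 5
    assert real is not None
    try:
        pred = np.asarray(pred, dtype = np.float64)
    except Exception:
        return 5, 5
    # Reject wrong shapes (instead of letting broadcasting hide them) and NaN/inf outputs
    if pred.shape != real.shape: return 5, 5
    if not np.all(np.isfinite(pred)): return 5, 5
    difference = pred - real
    amax_error = float(np.amax(np.abs(difference)))
    mse_error  = float(np.mean(np.square(difference)))
    return amax_error, mse_error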

In [None]:
# Kernel generated by GPT-5
def matmul(A, B):
    z, s = zip, sum
    Bt = list(z(*B))
    return [[s(a*b for a, b in z(row, col)) for col in Bt] for row in A]

We see the error below is very small, so that's good!

In [None]:
prediction = matmul(A_list, B_list)
calculate_difference(prediction, np.matmul(A, B))
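The zip-based kernel above builds each output entry as a dot product of a row of A and a column of B. One alternative an RL run might discover is the i-k-j loop order, which scales whole rows of B and accumulates them into the output row. The sketch below (both functions are ours, not from the original notebook) checks that the two orderings agree; whether i-k-j is actually faster in CPython depends on matrix shapes and interpreter version, so it should be measured rather than assumed.

```python
def matmul_naive(A, B):
    """Textbook i-j-k triple loop over lists of lists."""
    n, k, m = len(A), len(B), len(B[0])
    C = [[0.0] * m for _ in range(n)]
    for i in range(n):
        for j in range(m):
            s = 0.0
            for p in range(k):
                s += A[i][p] * B[p][j]
            C[i][j] = s
    return C

def matmul_ikj(A, B):
    """i-k-j ordering: scale row B[p] by A[i][p] and accumulate into output row i."""
    m = len(B[0])
    C = []
    for row in A:
        acc = [0.0] * m
        for a, b_row in zip(row, B):
            # Walks whole rows of B instead of gathering columns
            acc = [c + a * b for c, b in zip(acc, b_row)]
        C.append(acc)
    return C

A_small = [[1.0, 2.0], [3.0, 4.0]]
B_small = [[5.0, 6.0], [7.0, 8.0]]
print(matmul_naive(A_small, B_small))
print(matmul_ikj(A_small, B_small))
```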

# Countering Reward Hacking

The ultimate goal of RL is to maximize some reward (say speed, revenue, some metric).

But RL can **cheat**. When the RL algorithm learns a trick or exploits a loophole to increase the reward without actually doing the task at hand, this is called "reward hacking".

Some good examples are listed on [Wikipedia](https://en.wikipedia.org/wiki/Reward_hacking).

For matrix multiplication kernels, we might see the following issues:

* Laziness: RL learns to call NumPy, Torch, or other libraries, which use optimized BLAS or CUDA kernels.
* Caching: RL learns to cache the output and return it on later calls.
* Cheating: RL learns to find the actual output by inspecting Python global variables.
* Timer tampering: RL learns to edit the timing function so it reports 0 elapsed time.

And possibly more. We shall try to address each!

# Countering Reward Hacking 1: Stop laziness
We can stop the RL algorithm from calling optimized code by checking whether the generated code imports any modules outside the Python standard library. We used GPT-5 to help generate this check, `check_only_stdlib_imports`:

In [None]:
#@title (Collapsible code)
import ast
import sys
import sysconfig
from pathlib import Path

def _stdlib_names():
    """
    Build a set of canonical stdlib top-level module/package names.
    Uses sys.stdlib_module_names when available (3.10+), with a
    filesystem fallback for older versions/edge cases.
    """
    names = {m.lower() for m in getattr(sys, "stdlib_module_names", set())}
    names |= {m.lower() for m in sys.builtin_module_names}
    names.add("__future__")  # special-case

    # Fallback/augmentation: scan the stdlib directory
    try:
        stdlib_dir = Path(sysconfig.get_path("stdlib"))
        if stdlib_dir.exists():
            for p in stdlib_dir.iterdir():
                if p.name == "site-packages":
                    continue
                if p.suffix == ".py":
                    names.add(p.stem.lower())
                elif p.is_dir() and (p / "__init__.py").exists():
                    names.add(p.name.lower())
    except Exception:
        # conservative fallback; the names set above will still work well
        pass

    return names

_STDLIB_SET = _stdlib_names()

def check_only_stdlib_imports(code: str):
    """
    Return (ok: bool, details: dict)

    ok == True  -> all absolute imports are from the stdlib.
    ok == False -> details['non_stdlib'] lists offending top-level modules.

    details includes:
      - stdlib: sorted list of stdlib imports found
      - non_stdlib: sorted list of non-stdlib imports found
      - relative_imports: count of relative imports (always allowed here)
    """
    try:
        tree = ast.parse(code)
    except SyntaxError as e:
        return False, {
            "error": f"SyntaxError: {e}",
            "stdlib": [],
            "non_stdlib": [],
            "relative_imports": 0,
        }

    abs_imports = set()
    relative_count = 0

    class Visitor(ast.NodeVisitor):
        def visit_Import(self, node: ast.Import):
            for alias in node.names:
                abs_imports.add(alias.name.split(".")[0])
        def visit_ImportFrom(self, node: ast.ImportFrom):
            nonlocal relative_count
            if (node.level or 0) > 0:
                # relative import
                relative_count += 1
            else:
                if node.module:
                    abs_imports.add(node.module.split(".")[0])

    Visitor().visit(tree)

    stdlib_found = sorted(m for m in abs_imports if m.lower() in _STDLIB_SET)
    non_stdlib = sorted(m for m in abs_imports if m.lower() not in _STDLIB_SET)

    return len(non_stdlib) == 0, {
        "stdlib": stdlib_found,
        "non_stdlib": non_stdlib,
        "relative_imports": relative_count,
    }

For example, let's call `check_only_stdlib_imports` on a sample piece of matrix multiplication code generated by GPT-5:

In [None]:
sample = """
def matmul(A, B):
    import numpy as np
    from torch import matmul
    z, s = zip, sum
    Bt = list(z(*B))
    return [[s(a*b for a, b in z(row, col)) for col in Bt] for row in A]
"""
ok, info = check_only_stdlib_imports(sample)
print("Only stdlib imports?", ok)
print(info)
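`check_only_stdlib_imports` only inspects `import` statements, so code that loads a module dynamically, such as `__import__("numpy")` or `importlib.import_module("numpy")` (where `importlib` is itself part of the stdlib), would pass it. A hypothetical extra AST pass could flag such calls. This is a sketch, not part of the original notebook: `find_dynamic_imports` and `SUSPICIOUS_CALLS` are our names, and a name-based check can still be evaded (for example via `getattr`), so it narrows the loophole rather than closing it.

```python
import ast

# Hypothetical list of call names that can load or run code without an import statement
SUSPICIOUS_CALLS = {"__import__", "import_module", "exec", "eval", "compile"}

def find_dynamic_imports(code: str) -> list:
    """Return the sorted, de-duplicated suspicious call names found in `code`."""
    found = set()
    for node in ast.walk(ast.parse(code)):
        if isinstance(node, ast.Call):
            func = node.func
            # Plain calls like __import__("numpy") or attribute calls like importlib.import_module("numpy")
            name = func.id if isinstance(func, ast.Name) else getattr(func, "attr", None)
            if name in SUSPICIOUS_CALLS:
                found.add(name)
    return sorted(found)

sneaky = '''
def matmul(A, B):
    np = __import__("numpy")
    return np.matmul(A, B).tolist()
'''
print(find_dynamic_imports(sneaky))
```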

# Countering Reward Hacking 2: Stop cheating
We can stop the RL algorithm from using global or cached variables by restricting its `locals` and `globals`.

We are also going to use `exec` to create the function, so we pass an empty dict as its globals and collect the newly defined function in another empty dict.

In [None]:
output_function = {}
exec(sample, {}, output_function)
output_function["matmul"]

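We then disallow global variable access by rebuilding the function with empty globals via `types.FunctionType(f.__code__, {})`. Below, `import_numpy` can use the global `np` at first, but fails once it is rebuilt this way: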


In [None]:
import types
output_function["matmul"] = types.FunctionType(output_function["matmul"].__code__, {})

def import_numpy():
    np.matmul
    print("Success")

import_numpy()
import_numpy = types.FunctionType(import_numpy.__code__, {})
try:
    import_numpy()
except Exception as e:
    print(str(e))

In [None]:
def create_locked_down_function(function):
    output_function = {}
    exec(function, {}, output_function)
    new_matmul = output_function["matmul"]
    new_matmul = types.FunctionType(new_matmul.__code__, {})
    return new_matmul
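A side effect of rebuilding the function from its `__code__` object is that `types.FunctionType(code, {})` does not copy default arguments, so a kernel that tries to stash results in a mutable default argument can no longer even be called. The sketch below re-defines a minimal `lock_down` helper mirroring `create_locked_down_function` so it is self-contained; the cheating kernel is hypothetical.

```python
import types

def lock_down(source: str, name: str = "matmul"):
    """Minimal copy of create_locked_down_function: exec the source, then rebuild with empty globals."""
    namespace = {}
    exec(source, {}, namespace)
    return types.FunctionType(namespace[name].__code__, {})

# Hypothetical cheating kernel that memoizes answers in a mutable default argument
cached_source = '''
def matmul(A, B, _cache={}):
    key = (str(A), str(B))
    if key not in _cache:
        _cache[key] = [[sum(a * b for a, b in zip(r, c)) for c in zip(*B)] for r in A]
    return _cache[key]
'''

namespace = {}
exec(cached_source, {}, namespace)
original = namespace["matmul"]
print("original defaults:", original.__defaults__)   # still holds the cache dict
print("original result:", original([[1]], [[2]]))

locked = lock_down(cached_source)
print("locked defaults:", locked.__defaults__)        # None: defaults were not copied
try:
    locked([[1]], [[2]])
except TypeError as e:
    print("blocked:", e)
```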

# Countering Reward Hacking 3: Stop caching
We can stop the RL algorithm from benefiting from cached data by thrashing the CPU cache: before each trial we write to a large dummy buffer (2 GiB in the code below). We also benchmark carefully with multiple trials and loops.

We also add a **timer** so that a kernel stuck in an endless loop cannot stall training.

In [None]:
import os, gc, time, statistics
import signal
from contextlib import contextmanager
class TimeoutError(Exception): pass

@contextmanager
def time_limit(seconds):
    def _handler(signum, frame):
        raise TimeoutError(f"Timed out after {seconds}s")
    old = signal.signal(signal.SIGALRM, _handler)
    signal.setitimer(signal.ITIMER_REAL, seconds)
    try:
        yield
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0.0)
        signal.signal(signal.SIGALRM, old)

class Benchmarker:
    def __init__(self, trials = 3, loops = 1, timeout = 30):
        self.buffer = np.zeros(2 * 1024 * 1024 * 1024, dtype = np.uint8)
        self.trials = trials
        self.loops = loops
        assert timeout > 0 # Cannot be 0 since it won't work!
        self.timeout = timeout
    def thrash(self):
        # Edit the buffer to wipe cache lines
        self.buffer ^= 1
        return int(self.buffer[::4096].sum())

    def benchmark(self, function, arguments):
        assert len(arguments) == self.loops
        samples = []
        exceptions = []
        timed_out = 0
        for _ in range(self.trials):
            gc.collect(); gc.disable(); self.thrash()
            t_start = time.perf_counter_ns()
            for i in range(self.loops):
                try:
                    with time_limit(self.timeout):
                        function(*arguments[i])
                except TimeoutError as e:
                    timed_out += 1
                except Exception as e:
                    exceptions.append(str(e))
            t_end = time.perf_counter_ns()
            gc.enable()
            samples.append((t_end - t_start) // max(1, self.loops))
        return {
            "median_ns": int(statistics.median(samples)),
            "mean_ns": int(statistics.fmean(samples)),
            "stdev_ns": int(statistics.pstdev(samples) if len(samples) > 1 else 0),
            "exceptions" : exceptions,
            "timeouts" : timed_out,
        }

For example, we take the matmul kernel we had and benchmark it with a 10-second timeout:

In [None]:
A, A_list, B, B_list = generate_random_matrices(seed = 0, n = 256)
Benchmarker(trials = 1, timeout = 10).benchmark(output_function["matmul"], [(A_list, B_list)])
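The `time_limit` context manager relies on `SIGALRM`: `signal.setitimer` schedules an alarm, and the handler raises an exception inside whatever code is running. The standalone sketch below repeats that pattern (with the exception renamed to `KernelTimeout` to avoid shadowing the built-in `TimeoutError`) and uses it to interrupt a hypothetical kernel stuck in an infinite loop. Note that `SIGALRM` and `signal.setitimer` are Unix-only, Python runs signal handlers only in the main thread, and a kernel blocked inside a long C call may not be interrupted until that call returns.

```python
import signal
import time
from contextlib import contextmanager

class KernelTimeout(Exception):
    pass

@contextmanager
def time_limit(seconds: float):
    """Raise KernelTimeout if the body runs longer than `seconds` (Unix, main thread only)."""
    def _handler(signum, frame):
        raise KernelTimeout(f"Timed out after {seconds}s")
    old = signal.signal(signal.SIGALRM, _handler)
    signal.setitimer(signal.ITIMER_REAL, seconds)
    try:
        yield
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0.0)  # cancel any pending alarm
        signal.signal(signal.SIGALRM, old)         # restore the previous handler

def stuck_kernel(A, B):
    # Hypothetical kernel that never returns
    while True:
        pass

start = time.perf_counter()
try:
    with time_limit(0.2):
        stuck_kernel([[1]], [[1]])
    timed_out = False
except KernelTimeout:
    timed_out = True
elapsed = time.perf_counter() - start
print(f"timed out: {timed_out}, elapsed: {elapsed:.2f}s")
```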

# Data & RL task setup

We now need to create a prompt that tells the model what task to do. For our matrix multiplication example, we use the prompt below:

In [None]:
prompt = """
Create a new fast matrix multiplication function using only native Python code.
You are given a list of list of numbers.
Output your new function in backticks using the format below:
```python
def matmul(A, B):
    return ...
```
""".strip()
print(prompt)

First, let's prompt GPT-OSS without RL and see how it goes:

In [None]:
# Fix torch compilation cache permissions
import os
import shutil

# Test if /ephemeral is writable (not just readable)
use_ephemeral = False
if os.path.exists("/ephemeral"):
    try:
        test_file = "/ephemeral/.write_test"
        with open(test_file, "w") as f:
            f.write("test")
        os.remove(test_file)
        use_ephemeral = True
    except (PermissionError, OSError):
        pass

if use_ephemeral:
    cache_dir = "/ephemeral/torch_cache"
    triton_cache = "/ephemeral/triton_cache"
    tmpdir = "/ephemeral/tmp"
else:
    cache_dir = os.path.expanduser("~/.cache/torch/inductor")
    triton_cache = os.path.expanduser("~/.cache/triton")
    tmpdir = os.path.expanduser("~/.cache/tmp")

# Create directories with full write permissions
for d in [cache_dir, triton_cache, tmpdir]:
    os.makedirs(d, mode=0o777, exist_ok=True)

# Set ALL PyTorch/Triton cache and temp directories
os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir
os.environ["TORCH_COMPILE_DIR"] = cache_dir
os.environ["TRITON_CACHE_DIR"] = triton_cache
os.environ["TMPDIR"] = tmpdir  # Override system /tmp
os.environ["TEMP"] = tmpdir
os.environ["TMP"] = tmpdir

# Clean up any old compiled caches
old_cache = os.path.join(os.getcwd(), "unsloth_compiled_cache")
if os.path.exists(old_cache):
    shutil.rmtree(old_cache, ignore_errors=True)

print(f"✅ Torch cache: {cache_dir}")
print(f"✅ Temp dir: {tmpdir}")

text = tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}],
    tokenize = False,
    add_generation_prompt = True,
    reasoning_effort = "low",
)

from transformers import TextStreamer
_ = model.generate(
    **tokenizer(text, return_tensors = "pt").to("cuda"),
    temperature = 1.0,
    max_new_tokens = 512,
    streamer = TextStreamer(tokenizer, skip_prompt = False),
)

# Reward functions

We now design the `extract_function` function, which extracts the function wrapped in triple backticks.

We also design 4 reward functions:

1. `function_works` rewards the model if the extracted code is a valid Python function.
2. `no_cheating` checks whether the function imports non-standard-library modules, and penalizes it if so.
3. `correctness_check` checks whether the kernel's output is correct; it shouldn't generate gibberish!
4. `speed_check` compares the kernel's performance directly against NumPy's `matmul`.
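TRL's `GRPOTrainer` calls every reward function with the sampled `completions` (plus other dataset columns, such as `answer` and `reasoning_effort`, as keyword arguments) and expects back one float per completion. Because our dataset uses a conversational `prompt`, each completion is a list of message dicts, which is why the reward functions read `completion[0]["content"]`. A toy sketch with hand-written completions; the data and the `has_code_block` reward are hypothetical:

```python
# Toy data in the conversational shape the reward functions expect:
# one entry per sampled completion, each a list of message dicts.
completions = [
    [{"role": "assistant", "content": "```python\ndef matmul(A, B):\n    return A\n```"}],
    [{"role": "assistant", "content": "I refuse to write code."}],
]

def has_code_block(completions, **kwargs):
    """Hypothetical reward: +1 if the completion contains a fenced code block, else -1."""
    scores = []
    for completion in completions:
        response = completion[0]["content"]  # same access pattern as the notebook's reward functions
        scores.append(1.0 if response.count("```") >= 2 else -1.0)
    return scores

print(has_code_block(completions, reasoning_effort = ["low", "low"]))
```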

In [None]:
def extract_function(text):
    if text.count("```") >= 2:
        first = text.find("```") + 3
        second = text.find("```", first)
        fx = text[first : second].strip()
        fx = fx.removeprefix("python\n")
        fx = fx[fx.find("def"):]
        if fx.startswith("def matmul(A, B):"): return fx
    return None
print(extract_function(prompt))

Below is our `function_works` reward function, which uses Python's `exec` but guards against leaking local and global variables. We can also run `check_only_stdlib_imports` first to catch syntax errors before even executing the function. For example, the incomplete snippet `def a` fails to parse:

In [None]:
ok, info = check_only_stdlib_imports("def a")
ok, info

In [None]:
def function_works(completions, **kwargs):
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        function = extract_function(response)
        print(function)
        if function is not None:
            ok, info = check_only_stdlib_imports(function)
        if function is None or "error" in info:
            score = -2.0
        else:
            try:
                new_matmul = create_locked_down_function(function)
                score = 1.0
            except:
                score = -0.5
        scores.append(score)
    return scores

`no_cheating` checks whether the function cheated by importing optimized code such as NumPy or Torch.

In [None]:
def no_cheating(completions, **kwargs):
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        function = extract_function(response)
        if function is not None:
            ok, info = check_only_stdlib_imports(function)
        else:
            ok = False
        scores.append(1.0 if ok else -20.0) # Penalize heavily!
    return scores

Next, `correctness_check` checks whether the kernel is correct. We penalize it if the maximum absolute error is large (0.5 or more), and only reward it when both the maximum absolute error and the mean squared error are close to machine epsilon.

We have to execute the code now!

In [None]:
np.finfo(np.float64).eps
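Why allow any error at all? Even a correct kernel will rarely match `np.matmul` bit for bit, because floating-point addition is not associative, and an optimized BLAS routine may accumulate the products in a different order than a Python loop does. A tiny illustration:

```python
import numpy as np

eps = np.finfo(np.float64).eps
print(eps == 2.0 ** -52)  # float64 machine epsilon is 2^-52

# Summing the same three numbers in a different order changes the last bit
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)
print(left, right, left == right)
print(abs(left - right) <= eps)
```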

In [None]:
def correctness_check(completions, **kwargs):
    scores = []
    # Generate random matrices with every dimension between 1 and 128
    A, A_list, B, B_list = generate_random_matrices(seed = np.random.randint(10000), n = 128)
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        function = extract_function(response)
        if function is not None:
            ok, info = check_only_stdlib_imports(function)
        if function is None or "error" in info:
            scores.append(0)
            continue
        try:
            new_matmul = create_locked_down_function(function)
        except:
            scores.append(0)
            continue
        try:
            pred = new_matmul(A_list.copy(), B_list.copy())
        except:
            # Failed!
            scores.append(-2.0)
            continue
        true = np.matmul(A, B)
        amax_error, mse_error = calculate_difference(pred, true)

        # Check correctness and score!
        machine_epsilon = 100*np.finfo(np.float64).eps
        if   amax_error >= 3:   score = -3.0
        elif amax_error >= 2:   score = -2.5
        elif amax_error >= 1:   score = -2.0
        elif amax_error >= 0.5: score = -1.0
        elif amax_error >= 100*machine_epsilon: score = 0.0
        elif amax_error >= machine_epsilon: score = 1.0
        else: score = 3.0

        if   mse_error >= 3:   score += -3.0
        elif mse_error >= 2:   score += -2.5
        elif mse_error >= 1:   score += -2.0
        elif mse_error >= 0.5: score += -1.0
        elif mse_error >= 100*machine_epsilon: score += 0.0
        elif mse_error >= machine_epsilon: score += 1.0
        else: score += 3.0
        scores.append(score)
    return scores

Finally our benchmarking function for `speed_check`! We shall limit the timer to 10 seconds and do 3 trials.

In [None]:
A, A_list, B, B_list = generate_random_matrices(seed = 0, n = 256)
benchmarker = Benchmarker(trials = 3, timeout = 10)
numpy_results = benchmarker.benchmark(np.matmul, [(A, B)])
numpy_results

In [None]:
new_matmul = create_locked_down_function(extract_function(prompt))
new_results = benchmarker.benchmark(new_matmul, [(A_list, B_list)])
new_results

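We compare the two median times as a ratio and negate it for slower kernels. If the kernel is faster (its time divided by NumPy's is less than 1), we invert the ratio so that larger speedups earn larger positive rewards. Both cases are scaled down by 100.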

In [None]:
negative = -(new_results["median_ns"] / numpy_results["median_ns"]) / 100
positive = +(numpy_results["median_ns"] / new_results["median_ns"]) / 100
reward = negative if new_results["median_ns"] >= numpy_results["median_ns"] else positive
reward

In [None]:
new_results["median_ns"] = 3
numpy_results["median_ns"] = 1000
negative = -(new_results["median_ns"] / numpy_results["median_ns"]) / 100
positive = +(numpy_results["median_ns"] / new_results["median_ns"]) / 100
reward = negative if new_results["median_ns"] >= numpy_results["median_ns"] else positive
reward
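The two cells above implement a single rule. Combined with the clipping that `speed_check` applies, it can be written as one small function: a kernel k times slower than NumPy earns -k/100, a kernel k times faster earns +k/100, and the result is bounded to [-10, 10]. A standalone sketch (the name `speed_reward` is ours):

```python
def speed_reward(new_ns: float, numpy_ns: float, clip: float = 10.0) -> float:
    """Ratio-based speed reward as in the cells above, clipped to [-clip, clip] like speed_check."""
    if new_ns >= numpy_ns:
        # Slower (or equal): negative, proportional to how many times slower
        score = -(new_ns / numpy_ns) / 100
    else:
        # Faster: positive, proportional to the speedup
        score = (numpy_ns / new_ns) / 100
    return max(-clip, min(clip, score))

for new_ns, numpy_ns in [(2_000, 1_000), (1_000, 1_000), (3, 1_000), (1, 1_000_000)]:
    print(new_ns, numpy_ns, speed_reward(new_ns, numpy_ns))
```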

In [None]:
import subprocess
import sys

# Enhanced GPU check for NVIDIA Brev
print("=" * 60)
print("GPU Information")
print("=" * 60)

# Run nvidia-smi
subprocess.run(['nvidia-smi'], check=False)

# PyTorch CUDA info
import torch
print(f"\nPyTorch CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
        props = torch.cuda.get_device_properties(i)
        print(f"    Memory: {props.total_memory / 1024**3:.2f} GB")
print("=" * 60)


import gc
def speed_check(completions, **kwargs):
    scores = []
    # Generate random matrices with every dimension between 1 and 256
    A, A_list, B, B_list = generate_random_matrices(seed = np.random.randint(10000), n = 256)
    numpy_results = benchmarker.benchmark(np.matmul, [(A, B)])
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        function = extract_function(response)
        if function is not None:
            ok, info = check_only_stdlib_imports(function)
        if function is None or "error" in info:
            scores.append(0)
            continue
        try:
            new_matmul = create_locked_down_function(function)
        except:
            scores.append(0)
            continue
        new_results = benchmarker.benchmark(new_matmul, [(A_list.copy(), B_list.copy())])

        # Get score and clip to -10, 10
        negative = -(new_results["median_ns"] / numpy_results["median_ns"]) / 100
        positive = +(numpy_results["median_ns"] / new_results["median_ns"]) / 100
        score = negative if new_results["median_ns"] >= numpy_results["median_ns"] else positive
        if score >= 10:  score = 10
        if score <= -10: score = -10
        scores.append(score)
    # Free memory to counteract OOMs
    gc.collect()
    torch.cuda.empty_cache()
    return scores

We create the dataset, which repeats our prompt 1000 times. Remember to set the reasoning effort to low!

In [None]:
from datasets import Dataset
dataset = Dataset.from_list([{"prompt" : [{"role": "user", "content": prompt.strip()}], "answer" : 0, "reasoning_effort": "low"}]*1000)
maximum_length = len(tokenizer(prompt.strip())["input_ids"])
print(maximum_length)
dataset[0]

<a name="Train"></a>
### Train the model

Now set up the GRPO trainer and all its configurations! We also support GSPO, GAPO, Dr GRPO and more! Go to our docs at https://docs.unsloth.ai/ for more info!

In [None]:
max_prompt_length = maximum_length + 1 # + 1 just in case!
max_completion_length = max_seq_length - max_prompt_length

from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    temperature = 1.0,
    learning_rate = 5e-5,
    weight_decay = 0.01,
    warmup_ratio = 0.1,
    lr_scheduler_type = "linear",
    optim = "adamw_8bit",
    logging_steps = 1,
    per_device_train_batch_size=2,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 2, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_completion_length,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 100,
    save_steps = 100,
    report_to = "none", # Can use Weights & Biases
    output_dir="/workspace/outputs",

    # For optional training + evaluation
    # fp16_full_eval = True,
    # per_device_eval_batch_size = 4,
    # eval_accumulation_steps = 1,
    # eval_strategy = "steps",
    # eval_steps = 1,
)
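
The length settings and the batch settings in the config cell have to agree with each other. Below is a minimal sketch of those consistency checks. It assumes purely hypothetical numbers (a `max_seq_length` of 768 and a 300-token prompt). It also assumes the divisibility rule of recent TRL releases: the number of completions generated per step, `per_device_train_batch_size * gradient_accumulation_steps * num_processes`, must be a multiple of `num_generations`. TRL validates this itself, and its exact formula has varied across versions. `check_grpo_budget` is a hypothetical helper, not part of TRL or Unsloth.

```python
def check_grpo_budget(
    max_seq_length: int,
    prompt_tokens: int,
    per_device_train_batch_size: int,
    gradient_accumulation_steps: int,
    num_generations: int,
    num_processes: int = 1,
) -> dict:
    # Same arithmetic as the config cell: prompt length plus one spare token
    max_prompt_length = prompt_tokens + 1
    max_completion_length = max_seq_length - max_prompt_length
    if max_completion_length <= 0:
        raise ValueError("prompt leaves no room for a completion; raise max_seq_length")

    # Completions sampled per optimizer step (assumed formula, see the lead-in)
    generation_batch = per_device_train_batch_size * gradient_accumulation_steps * num_processes
    if generation_batch % num_generations != 0:
        raise ValueError("generation batch must be divisible by num_generations")

    return {
        "max_prompt_length": max_prompt_length,
        "max_completion_length": max_completion_length,
        "completions_per_step": generation_batch,
        "prompts_per_step": generation_batch // num_generations,
    }


# Hypothetical lengths; batch settings match the config cell above
print(check_grpo_budget(768, 300, per_device_train_batch_size = 2,
                        gradient_accumulation_steps = 1, num_generations = 2))
```

With these settings each step samples 2 completions for a single prompt. Raising `gradient_accumulation_steps` to 4 (as the config comment suggests) would give 4 prompts' worth of groups per step at the same memory cost per forward pass.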

And let's run the trainer! During training, the trainer prints a table of rewards. The goal is to see the `reward` column increase!

You might have to wait 150 to 200 steps before the rewards start to move, and you'll probably get 0 reward for the first 100 steps, so consider raising `max_steps` above 100 for a real run. Please be patient!

| Step | Training Loss | reward    | reward_std | completion_length | kl       |
|------|---------------|-----------|------------|-------------------|----------|
| 1    | 0.000000      | 0.125000  | 0.000000   | 200.000000        | 0.000000 |
| 2    | 0.000000      | 0.072375  | 0.248112   | 200.000000        | 0.000000 |
| 3    | 0.000000      | -0.079000 | 0.163776   | 182.500000        | 0.000005 |
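
The `reward_std` column matters because GRPO learns from reward differences *within* each group of `num_generations` completions for the same prompt. Below is a minimal sketch of the group-relative advantage as we understand TRL computes it: subtract the group mean and, by default, divide by the group standard deviation plus a small epsilon. We assume a sample standard deviation and an epsilon of 1e-4. Setting `scale_by_std=False` gives the mean-only variant used by Dr GRPO. `group_advantages` is a hypothetical helper, not TRL's code.

```python
import statistics


def group_advantages(rewards, num_generations, eps=1e-4, scale_by_std=True):
    """Group-relative advantages for rewards laid out as consecutive groups."""
    if num_generations < 2 or len(rewards) % num_generations != 0:
        raise ValueError("need complete groups of at least 2 completions")
    advantages = []
    for start in range(0, len(rewards), num_generations):
        group = rewards[start:start + num_generations]
        mean = statistics.fmean(group)
        # Sample standard deviation (n - 1 denominator) when scaling is enabled
        denom = statistics.stdev(group) + eps if scale_by_std else 1.0
        advantages.extend((r - mean) / denom for r in group)
    return advantages


# Step-1-style group: both completions earned the same reward
print(group_advantages([0.125, 0.125], num_generations=2))  # [0.0, 0.0]
# Mean-only (Dr GRPO style) advantages for a group that disagrees
print(group_advantages([1.0, 0.0], num_generations=2, scale_by_std=False))  # [0.5, -0.5]
```

With `num_generations = 2`, a group whose two completions earn the same reward (as in step 1 of the table, where `reward_std` is 0) gets zero advantage and contributes no learning signal. This is one reason the early steps can look flat.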


In [None]:
# For optional training + evaluation
# new_dataset = dataset.train_test_split(test_size = 0.01)

trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        function_works,
        no_cheating,
        correctness_check,
        speed_check,
    ],
    args = training_args,
    train_dataset = dataset,

    # For optional training + evaluation
    # train_dataset = new_dataset["train"],
    # eval_dataset = new_dataset["test"],
)
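
`GRPOTrainer` accepts a list of reward functions and, by default, adds their per-completion scores together. TRL also offers a `reward_weights` option in `GRPOConfig` to weight them. The sketch below mimics that combination in plain Python. It assumes conversational completions shaped like `[{"role": "assistant", "content": ...}]` and assumes that extra dataset columns such as `answer` are forwarded as keyword arguments. `combine_rewards`, `defines_matmul`, and `avoids_numpy` are hypothetical stand-ins, not the notebook's `function_works` or `no_cheating`.

```python
def combine_rewards(reward_funcs, prompts, completions, weights=None, **dataset_columns):
    """Sum each reward function's per-completion scores, optionally weighted."""
    if weights is None:
        weights = [1.0] * len(reward_funcs)
    if len(weights) != len(reward_funcs):
        raise ValueError("need one weight per reward function")
    totals = [0.0] * len(completions)
    for weight, func in zip(weights, reward_funcs):
        scores = func(prompts=prompts, completions=completions, **dataset_columns)
        if len(scores) != len(completions):
            raise ValueError(f"{func.__name__} returned the wrong number of scores")
        for i, score in enumerate(scores):
            totals[i] += weight * score
    return totals


# Hypothetical toy reward functions; each returns one float per completion
def defines_matmul(completions, **kwargs):
    return [1.0 if "def matmul" in c[0]["content"] else -1.0 for c in completions]


def avoids_numpy(completions, **kwargs):
    return [-20.0 if "import numpy" in c[0]["content"] else 0.0 for c in completions]


prompts = [[{"role": "user", "content": "Write a fast matmul."}]] * 2
completions = [
    [{"role": "assistant", "content": "def matmul(A, B): ..."}],
    [{"role": "assistant", "content": "import numpy\ndef matmul(A, B): ..."}],
]
print(combine_rewards([defines_matmul, avoids_numpy], prompts, completions, answer=[0, 0]))
# [1.0, -19.0]
```

Summing makes a large penalty from one function (here the cheating check) dominate any small bonus from another, which is why anti-cheating rewards are usually given a large magnitude.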

And let's train the model!

**NOTE:** A free T4 GPU can take about 5 minutes per generation because it is an older GPU. An A100 or H100 will be much faster!

In [None]:
trainer.train()

<a name="Inference"></a>
### Inference
Now let's try the model we just trained!

In [None]:
# Fix torch compilation cache permissions
import os
import shutil

# Test if /ephemeral is writable (not just readable)
use_ephemeral = False
if os.path.exists("/ephemeral"):
    try:
        test_file = "/ephemeral/.write_test"
        with open(test_file, "w") as f:
            f.write("test")
        os.remove(test_file)
        use_ephemeral = True
    except (PermissionError, OSError):
        pass

if use_ephemeral:
    cache_dir = "/ephemeral/torch_cache"
    triton_cache = "/ephemeral/triton_cache"
    tmpdir = "/ephemeral/tmp"
else:
    cache_dir = os.path.expanduser("~/.cache/torch/inductor")
    triton_cache = os.path.expanduser("~/.cache/triton")
    tmpdir = os.path.expanduser("~/.cache/tmp")

# Create the directories (mode 0o777 is requested, but the process umask may narrow it)
for d in [cache_dir, triton_cache, tmpdir]:
    os.makedirs(d, mode=0o777, exist_ok=True)

# Set ALL PyTorch/Triton cache and temp directories
os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir
os.environ["TORCH_COMPILE_DIR"] = cache_dir
os.environ["TRITON_CACHE_DIR"] = triton_cache
os.environ["TMPDIR"] = tmpdir  # Override system /tmp
os.environ["TEMP"] = tmpdir
os.environ["TMP"] = tmpdir

# Clean up any old compiled caches
old_cache = os.path.join(os.getcwd(), "unsloth_compiled_cache")
if os.path.exists(old_cache):
    shutil.rmtree(old_cache, ignore_errors=True)

print(f"✅ Torch cache: {cache_dir}")
print(f"✅ Temp dir: {tmpdir}")

text = tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}],
    tokenize = False,
    add_generation_prompt = True,
    reasoning_effort = "low",
)

from transformers import TextStreamer
_ = model.generate(
    **tokenizer(text, return_tensors = "pt").to("cuda"),
    temperature = 1.0,
    max_new_tokens = 1024,
    streamer = TextStreamer(tokenizer, skip_prompt = False),
)
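
The writability probe at the top of the inference cell can be factored into a reusable helper. This is a minimal sketch with hypothetical names (`is_writable_dir`, `pick_cache_root`). Like the original cell, it proves that a directory is usable by actually creating and deleting a file in it, here via `tempfile.NamedTemporaryFile`.

```python
import os
import tempfile


def is_writable_dir(path: str) -> bool:
    """Return True only if a file can actually be created inside `path`."""
    if not os.path.isdir(path):
        return False
    try:
        # Create and immediately delete a throwaway file, like the original probe
        with tempfile.NamedTemporaryFile(dir=path):
            pass
        return True
    except OSError:  # includes PermissionError and read-only filesystems
        return False


def pick_cache_root(preferred: str, fallback: str) -> str:
    """Use `preferred` if it is writable, otherwise create and use `fallback`."""
    root = preferred if is_writable_dir(preferred) else fallback
    os.makedirs(root, exist_ok=True)
    return root
```

For example, `pick_cache_root("/ephemeral", os.path.expanduser("~/.cache"))` returns `/ephemeral` only when a file can really be written there. Unlike the original cell, it also creates the fallback directory.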

<a name="Save"></a>
### Saving to float16 or MXFP4 for vLLM

We also support saving to `float16` directly. Select `merged_16bit` for float16 or `mxfp4` for MXFP4 (OpenAI's GPT-OSS native precision); `lora` adapters are also available as a fallback. Use `push_to_hub_merged` to upload to your Hugging Face account. You can create a personal access token at https://huggingface.co/settings/tokens.
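
To see what the `mxfp4` option trades away, here is a sketch of the block quantization defined by the OCP Microscaling (MX) specification. Each block of up to 32 values shares one power-of-two scale (stored as an 8-bit E8M0 exponent), and each value is rounded to a 4-bit FP4 E2M1 element. It was written from the public format description rather than from Unsloth's saving code. The function names are hypothetical, and the sketch ignores E8M0 range limits, NaN handling, and bit packing.

```python
import math

# Non-negative magnitudes representable by an FP4 E2M1 element
FP4_E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
EMAX_E2M1 = 2   # exponent of the largest E2M1 value (6.0 = 1.5 * 2**2)
BLOCK_SIZE = 32  # elements that share one scale in MX formats
# 4 bits per element plus one 8-bit shared scale per 32 elements
BITS_PER_WEIGHT = 4 + 8 / BLOCK_SIZE


def quantize_mxfp4_block(block):
    """Return (shared_exponent, fp4_values) for one block of floats."""
    if not 0 < len(block) <= BLOCK_SIZE:
        raise ValueError("block must hold 1 to 32 values")
    max_abs = max(abs(x) for x in block)
    if max_abs == 0.0:
        return 0, [0.0] * len(block)
    # Shared power-of-two scale that maps the largest value into [4, 8)
    shared_exp = math.floor(math.log2(max_abs)) - EMAX_E2M1
    scale = 2.0 ** shared_exp
    elements = []
    for x in block:
        magnitude = min(abs(x) / scale, FP4_E2M1_VALUES[-1])  # saturate at 6.0
        # Round to the nearest representable value; ties go to the smaller magnitude
        nearest = min(FP4_E2M1_VALUES, key=lambda v: abs(v - magnitude))
        elements.append(math.copysign(nearest, x))
    return shared_exp, elements


def dequantize_mxfp4_block(shared_exp, elements):
    """Multiply each FP4 element back by the shared scale."""
    scale = 2.0 ** shared_exp
    return [e * scale for e in elements]


print(quantize_mxfp4_block([12.0, 1.0, -3.0]))  # (1, [6.0, 0.5, -1.5])
```

For a rough sense of scale: at 4.25 bits per weight, 20 billion weights would take about 10.6 GB versus about 40 GB at 16 bits. The real checkpoint is larger than that estimate because, as we understand it, gpt-oss stores only its MoE expert weights in MXFP4.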

In [None]:
# Merge and push to hub in mxfp4 4bit format
if False:
    model.save_pretrained_merged("finetuned_model", tokenizer, save_method = "mxfp4")
if False: model.push_to_hub_merged("repo_id/repo_name", tokenizer, token = "hf...", save_method = "mxfp4")

# Merge and push to hub in 16bit
if False:
    model.save_pretrained_merged("finetuned_model", tokenizer, save_method = "merged_16bit")
if False: # Pushing to HF Hub
    model.push_to_hub_merged("hf/gpt-oss-finetune", tokenizer, save_method = "merged_16bit", token = "")

And we're done! If you have any questions about Unsloth, find a bug, want to keep up with the latest LLM news, need help, or want to join projects, feel free to join our [Discord](https://discord.gg/unsloth)!

**Additional Resources:**

- 📚 [Unsloth Documentation](https://docs.unsloth.ai) - Complete guides and examples
- 💬 [Unsloth Discord](https://discord.gg/unsloth) - Community support
- 📖 [More Notebooks](https://github.com/unslothai/notebooks) - Full collection on GitHub
- 🚀 [Brev Documentation](https://docs.nvidia.com/brev) - Deploy and scale on NVIDIA GPUs
