# How to write PyTorch code that isn't (too) bad

This is a guide to avoiding some common pitfalls, not to writing the fastest code possible (you probably wouldn't have time for that anyway).

### Motivation

- Compute is an expensive shared resource.
- Wasting compute means literally just creating entropy (__CLIMATE CHANGE IS A THING__).
- Other people also have projects they want to do.

### Some online guides that go more in-depth

- [Official Performance Tuning Guide](https://docs.pytorch.org/tutorials/recipes/recipes/tuning_guide.html)
- [Official PyTorch Guide on CUDA](https://docs.pytorch.org/docs/stable/notes/cuda.html)

### A Note on training on more than one GPU

Long story short, you probably don't need to.

#### But I want my model to be faster

PyTorch supports training on more than one GPU with [Distributed Data Parallel](https://docs.pytorch.org/tutorials/intermediate/ddp_tutorial.html). Implementing it can be fairly straightforward, depending on the use case.

_However_, training performance usually scales sublinearly with GPU count. That is to say, the performance gain is likely not going to be that big, especially considering the increased usage of limited shared resources.

Training on multiple GPUs introduces significant, unavoidable overhead. If you don't know what you are doing, the effort is likely not worth it for student projects.

#### But my model doesn't fit in one GPU

There are models (especially modern LLMs) that require more than the 50-80 GB of VRAM a typical datacenter GPU can offer. If you need to work with a model like that, multi-GPU might be hard to avoid. Luckily for you, most libraries centered around these models (like [transformers](https://huggingface.co/docs/transformers/index)) will handle the hard work for you. Consult their documentation.

If you are working with large LLMs via Transformers, consider using a [Quantized Model](https://huggingface.co/docs/transformers/quantization/overview) and consult their [documentation](https://huggingface.co/docs/transformers/v5.0.0rc2/en/llm_tutorial_optimization#1-lower-precision) for other ways of reducing memory requirements, such as lowering precision or paging parts of the model. Multi-GPU should be treated as a last resort.

### We now switch to a bigger Model to better see the speedups

(The code below is based on the aforementioned GPT2 from scratch tutorial.)

In [17]:
from src.training import *
device = "cuda:2"

In [18]:
model = GPT(GPTConfig(vocab_size=50304))
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr)
enc = tiktoken.get_encoding('gpt2')
train_loader = DataLoader(B, T, 1, 1)


start = time.time()
avg_batch_time = 0
for i in range(max_steps):
    start_batch = time.time()

    optimizer.zero_grad()

    x, y = train_loader.next_batch()
    x, y = x.to(device), y.to(device)

    logits, loss = model(x, y)

    loss.backward()
    optimizer.step()

    torch.cuda.synchronize()
    end_batch = time.time()
    batch_time = int((end_batch - start_batch) * 1000)
    print(f"step {i + 1}, {batch_time}ms elapsed")
    avg_batch_time += batch_time
end = time.time()
print(f"{int((end - start) * 1000)}ms elapsed, {avg_batch_time / max_steps}ms avg batch time")

loaded 338020 tokens
1 epoch = 20 batches
step 1, 803ms elapsed
step 2, 554ms elapsed
step 3, 588ms elapsed
step 4, 572ms elapsed
step 5, 574ms elapsed
step 6, 580ms elapsed
step 7, 574ms elapsed
step 8, 581ms elapsed
step 9, 580ms elapsed
step 10, 575ms elapsed
step 11, 587ms elapsed
step 12, 577ms elapsed
step 13, 582ms elapsed
step 14, 581ms elapsed
step 15, 581ms elapsed
step 16, 577ms elapsed
step 17, 577ms elapsed
step 18, 583ms elapsed
step 19, 588ms elapsed
step 20, 583ms elapsed
step 21, 586ms elapsed
step 22, 586ms elapsed
step 23, 591ms elapsed
step 24, 578ms elapsed
step 25, 603ms elapsed
step 26, 584ms elapsed
step 27, 602ms elapsed
step 28, 588ms elapsed
step 29, 604ms elapsed
step 30, 588ms elapsed
step 31, 595ms elapsed
step 32, 594ms elapsed
step 33, 599ms elapsed
step 34, 590ms elapsed
step 35, 604ms elapsed
step 36, 592ms elapsed
step 37, 597ms elapsed
step 38, 599ms elapsed
step 39, 598ms elapsed
step 40, 601ms elapsed
step 41, 601ms elapsed
step 42, 603ms elapsed

### Reducing precision for performance gains

#### Matrix multiplication

`torch.set_float32_matmul_precision` reduces the internal precision of float32 matrix multiplications on CUDA devices to accelerate computation with minimal numerical effects ([Details](https://docs.pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html)). This allows the GPU to use dedicated "Tensor Cores", which are optimized for reduced-precision (TF32 or 16-bit) matrix multiplication.

Your GPU might not support this (Nvidia datacenter GPUs like those in KIGS or the HPC in Erlangen do). Most libraries that implement models (e.g. transformers) will either do this by default or let you enable it through them.
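A minimal sketch of how the setting can be inspected and changed. It only touches a global flag, so it also runs on machines without a GPU (where it has no effect on speed); the variable names are just for illustration.

```python
import torch

# The default is "highest", i.e. full float32 precision for matmuls.
previous = torch.get_float32_matmul_precision()

# Allow TF32 (or an equivalent reduced-precision path) for float32 matmuls.
torch.set_float32_matmul_precision("high")
current = torch.get_float32_matmul_precision()
print(previous, "->", current)

# Restore the earlier setting so other code is not affected.
torch.set_float32_matmul_precision(previous)
```

The setting is process-wide, so it is usually set once at the top of a training script.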

#### Autocast

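Wrapping code in `with torch.autocast(...)` automatically runs supported operations (such as matrix multiplications) in a 16-bit floating point format instead of float32. With `dtype=torch.bfloat16`, the mantissa shrinks from 23 bits to 7 while the 8-bit exponent stays the same, which preserves the range of values. (`torch.float16` keeps 10 mantissa bits but has a smaller exponent and therefore a smaller range; TensorFloat-32, which also has a 10-bit mantissa, is controlled by the matmul precision setting above, not by autocast.) This improves performance and memory usage at a usually negligible cost in precision. It's best practice to wrap only the forward pass (including the loss computation) in the autocast context, as the loss of precision can cause issues elsewhere.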
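To see what autocast does to dtypes, here is a small sketch with a toy `nn.Linear` layer (not the GPT model from this guide). It falls back to the CPU if no GPU is available, which is an assumption made purely so the snippet runs anywhere.

```python
import torch
import torch.nn as nn

# Fall back to CPU autocast when no GPU is present (assumption for portability).
device_type = "cuda" if torch.cuda.is_available() else "cpu"

layer = nn.Linear(8, 4).to(device_type)
x = torch.randn(2, 8, device=device_type)

with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
    out = layer(x)  # the matmul inside the layer runs in bfloat16

print(layer.weight.dtype)  # parameters are still stored in float32
print(out.dtype)           # the output of the autocast region is bfloat16
```

Note that the model weights are not converted; only the computation inside the context uses the lower precision.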

In [19]:
#reduce precision for matrix multiplication
#note: each call overrides the previous one, so only the last setting is active
torch.set_float32_matmul_precision('high') #allows TF32
torch.set_float32_matmul_precision('medium') #allows BF16

model = GPT(GPTConfig(vocab_size=50304))
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr)
enc = tiktoken.get_encoding('gpt2')
train_loader = DataLoader(B, T, 1, 1)


start = time.time()
avg_batch_time = 0
for i in range(max_steps):
    start_batch = time.time()

    optimizer.zero_grad()

    x, y = train_loader.next_batch()
    x, y = x.to(device), y.to(device)

    with torch.autocast(device_type='cuda', dtype=torch.bfloat16): #bf16 for forward pass only
        logits, loss = model(x, y)

    loss.backward()
    optimizer.step()

    torch.cuda.synchronize()
    end_batch = time.time()
    batch_time = int((end_batch - start_batch) * 1000)
    print(f"step {i + 1}, {batch_time}ms elapsed")
    avg_batch_time += batch_time
end = time.time()
print(f"{int((end - start) * 1000)}ms elapsed, {avg_batch_time / max_steps}ms avg batch time")

loaded 338020 tokens
1 epoch = 20 batches
step 1, 194ms elapsed
step 2, 152ms elapsed
step 3, 148ms elapsed
step 4, 151ms elapsed
step 5, 148ms elapsed
step 6, 152ms elapsed
step 7, 150ms elapsed
step 8, 151ms elapsed
step 9, 150ms elapsed
step 10, 150ms elapsed
step 11, 150ms elapsed
step 12, 149ms elapsed
step 13, 143ms elapsed
step 14, 153ms elapsed
step 15, 149ms elapsed
step 16, 151ms elapsed
step 17, 150ms elapsed
step 18, 150ms elapsed
step 19, 150ms elapsed
step 20, 150ms elapsed
step 21, 154ms elapsed
step 22, 149ms elapsed
step 23, 151ms elapsed
step 24, 150ms elapsed
step 25, 151ms elapsed
step 26, 150ms elapsed
step 27, 151ms elapsed
step 28, 151ms elapsed
step 29, 148ms elapsed
step 30, 152ms elapsed
step 31, 148ms elapsed
step 32, 152ms elapsed
step 33, 150ms elapsed
step 34, 150ms elapsed
step 35, 151ms elapsed
step 36, 152ms elapsed
step 37, 151ms elapsed
step 38, 149ms elapsed
step 39, 151ms elapsed
step 40, 149ms elapsed
step 41, 151ms elapsed
step 42, 152ms elapsed

## Looking under the hood

### Resources
- [Official PyTorch Docs for compiler](https://docs.pytorch.org/docs/stable/torch.compiler.html)
- [Official Tutorial for working with compile](https://docs.pytorch.org/docs/stable/compile/programming_model.html)
- [A guide on integrating control flow into high performance PyTorch](https://blog.ezyang.com/2025/09/so-you-want-to-control-flow-in-pt2/)
- [torch.compile, the missing manual](https://docs.google.com/document/d/1y5CRfMLdwEoF1nTk9q8qEu1mgMUuUtvhklPKJ2emLU8/edit?tab=t.0#heading=h.ivdr7fmrbeab) - this is probably the best resource for working with compile out there ([There is also an overview video for it](https://www.youtube.com/live/rew5CSUaIXg))







### Autograd

Autograd is the engine that PyTorch uses to perform automatic differentiation. When an operation is performed on a tensor that requires gradients, the result stores that information, building a graph-like structure that contains a "history" of what happened to that tensor. Following the chain rule, Autograd can then differentiate all those operations in reverse order to arrive at the gradient of the final result with respect to the tensors involved. Autograd does this when `.backward()` is called.
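As a tiny worked example of the above (a toy function, not part of the training code): for $y = \sum_i x_i^2$ the gradient is $2x$, which is what Autograd returns.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Each operation is recorded in the graph via grad_fn.
y = (x ** 2).sum()
print(y.grad_fn)  # the node for the last recorded operation (the sum)

# Walk the graph backwards and accumulate dy/dx into x.grad.
y.backward()
print(x.grad)  # tensor([2., 4., 6.])
```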

### Dynamo

Dynamo is a just-in-time compiler that runs at runtime and turns your Python code into a computation graph. This graph is then handed to a compiler backend (by default Inductor, which generates Triton kernels on GPUs) to produce optimized code that can run much faster than regular eager-mode Python.

Because the compiled code is more static than the original Python, Dynamo uses so-called _Guards_ to check properties of the inputs (shapes, dtypes, etc.) whenever the compiled code is called. This prevents errors due to shape mismatches etc. by forcing a recompilation if the inputs change. These recompilations hurt performance and should generally be avoided.
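One way to watch guards trigger a recompilation is a custom backend that simply counts how often Dynamo hands it a graph. This is a sketch under assumptions: `counting_backend`, `compile_log` and `double` are made-up names, and the backend skips all optimization and just runs the captured graph as-is.

```python
import torch

compile_log = []  # one entry per graph that Dynamo compiles

def counting_backend(gm, example_inputs):
    # Hypothetical debugging backend: record the compilation, do no optimization.
    compile_log.append([tuple(t.shape) for t in example_inputs if isinstance(t, torch.Tensor)])
    return gm.forward

@torch.compile(backend=counting_backend)
def double(x):
    return x * 2

double(torch.ones(4))                       # first call: compiles
double(torch.ones(4))                       # same shape and dtype: guards pass, no recompile
double(torch.ones(4, dtype=torch.float64))  # dtype changed: a guard fails, recompile

print(len(compile_log), "compilations")
```

In real code, the same information is easier to get from the recompilation logs described in the debugging section.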

### Inductor

Inductor serves as the interface between Dynamo and low-level backends like Triton. It also optimizes the graphs produced by Dynamo to generate better kernels.

### The CPU-GPU bottleneck

PyTorch can, without any special code, run processes asynchronously on the CPU and GPU.

The central element to doing so efficiently is understanding how PyTorch handles tensors on the GPU:

Operations on the GPU aren't actually completed when their line of Python code is executed, merely scheduled. This means that any subsequent line of Python code can continue to run, __so long as it doesn't depend on the result of the GPU operation__.

Consequently, there is a set of operations that force the GPU and CPU to synchronize, which can lead to (sometimes massive) idle time. These operations include:

- Moving tensors, models, etc. between CPU and GPU (with `.to(device)`, `.cpu()`, `.item()`, explicit conversions such as `.cpu().numpy()` or `int(...)`, etc.)
- Any Python control structure (if, while, for, ...) that uses an operation on a GPU tensor to check for a truth value or iterates over one.
- Explicitly calling `torch.cuda.synchronize()` (this has some actual use cases though, like accurately timing GPU execution in the examples I have used)
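A common instance of this is logging the loss with `.item()` in every step. The sketch below uses dummy loss values (an assumption for illustration) to contrast that with keeping the running sum on the device and transferring it once.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-ins for the per-step losses of a training loop.
losses = [torch.tensor(float(i), device=device) for i in range(5)]

# Synchronizing variant: .item() makes the CPU wait for the GPU in every step.
total_sync = 0.0
for loss in losses:
    total_sync += loss.item()

# Deferred variant: accumulate on the device, transfer once at the end.
running = torch.zeros((), device=device)
for loss in losses:
    running += loss.detach()
total_deferred = running.item()  # the only synchronization point

print(total_sync, total_deferred)  # 10.0 10.0
```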

### Multithreading is very simple and you should probably use it

Many model pipelines contain various preprocessing, monitoring, evaluation and logging steps that need to be performed by the CPU. As explained in the previous section, certain operations force synchronization between the CPU and GPU. Regular Python multithreading provides a very simple way for GPU and CPU work to run simultaneously without blocking each other. It works well as long as the CPU work spends its time waiting (I/O, `time.sleep`) or in native code that releases the GIL. A general guideline here is to start threads as early as possible and join them as late as possible.
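The "start early, join late" idea can be sketched without any PyTorch at all. `cpu_side_task` and `main_work` are hypothetical stand-ins (implemented with `time.sleep`) for a logging/evaluation step and a training step.

```python
import time
from concurrent.futures import ThreadPoolExecutor

def cpu_side_task(duration):
    # Hypothetical stand-in for logging or evaluation work.
    time.sleep(duration)
    return duration

def main_work(duration):
    # Hypothetical stand-in for the training step.
    time.sleep(duration)

start = time.perf_counter()
with ThreadPoolExecutor(max_workers=2) as pool:
    future = pool.submit(cpu_side_task, 0.2)  # start as early as possible
    main_work(0.2)                            # runs while the task is in flight
    result = future.result()                  # join as late as possible
elapsed = time.perf_counter() - start

# Both steps take 0.2 s, but they overlap, so the total stays close to 0.2 s.
print(f"{elapsed:.2f}s elapsed")
```

A `ThreadPoolExecutor` is used here instead of raw `threading.Thread` objects because it makes it easy to get a return value back from the worker.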

#### Very basic Strategy:
- Time all the GPU and CPU operations separately by using `torch.cuda.synchronize()` to force python to wait for GPU operations to finish.
- Figure out what CPU operations take significant time and try to either
    - A: Run them in parallel to GPU operations by avoiding any CPU-GPU synchronization
    - B: Explicitly multithread them
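For the timing step, a small helper along these lines can be handy. This is a sketch: `timed` and the `timings` dict are made-up names, and the helper only calls `torch.cuda.synchronize()` when a GPU is actually present.

```python
import time
from contextlib import contextmanager

import torch

def _sync():
    # Wait for all queued GPU work; a no-op on machines without CUDA.
    if torch.cuda.is_available():
        torch.cuda.synchronize()

@contextmanager
def timed(label, timings):
    # Store the wall-clock time of the enclosed block in timings[label] (milliseconds).
    _sync()
    start = time.perf_counter()
    try:
        yield
    finally:
        _sync()
        timings[label] = (time.perf_counter() - start) * 1000

timings = {}
device = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.randn(512, 512, device=device)

with timed("matmul", timings):
    b = a @ a
with timed("cpu_work", timings):
    total = sum(range(100_000))

print(timings)
```

Remove the helper again once you are done measuring, since the extra synchronization itself costs performance.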

In [20]:
#these are dummy functions used for illustration purposes
import time

def preprocessing(iteration): #example of process that needs to run before the training step
    time.sleep(.1)

def evaluate(result): #example of process that needs to run after the training step and requires the result
    time.sleep(.4)


model = GPT(GPTConfig(vocab_size=50304))
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr)
enc = tiktoken.get_encoding('gpt2')
train_loader = DataLoader(B, T, 1, 1)


start = time.time()
avg_batch_time = 0
for i in range(max_steps):
    start_batch = time.time()

    preprocessing(i)
    optimizer.zero_grad()

    x, y = train_loader.next_batch()
    x, y = x.to(device), y.to(device)

    with torch.autocast(device_type='cuda', dtype=torch.bfloat16): #bf16 for forward pass only
        logits, loss = model(x, y)

    loss.backward()
    optimizer.step()

    evaluate(loss.cpu().detach())

    torch.cuda.synchronize()
    end_batch = time.time()
    batch_time = int((end_batch - start_batch) * 1000)
    print(f"step {i + 1}, {batch_time}ms elapsed")
    avg_batch_time += batch_time
end = time.time()
print(f"{int((end - start) * 1000)}ms elapsed, {avg_batch_time / max_steps}ms avg batch time")

loaded 338020 tokens
1 epoch = 20 batches
step 1, 657ms elapsed
step 2, 650ms elapsed
step 3, 647ms elapsed
step 4, 652ms elapsed
step 5, 647ms elapsed
step 6, 651ms elapsed
step 7, 644ms elapsed
step 8, 659ms elapsed
step 9, 646ms elapsed
step 10, 649ms elapsed
step 11, 650ms elapsed
step 12, 648ms elapsed
step 13, 646ms elapsed
step 14, 647ms elapsed
step 15, 654ms elapsed
step 16, 648ms elapsed
step 17, 645ms elapsed
step 18, 647ms elapsed
step 19, 646ms elapsed
step 20, 649ms elapsed
step 21, 652ms elapsed
step 22, 646ms elapsed
step 23, 651ms elapsed
step 24, 647ms elapsed
step 25, 650ms elapsed
step 26, 648ms elapsed
step 27, 648ms elapsed
step 28, 650ms elapsed
step 29, 649ms elapsed
step 30, 647ms elapsed
step 31, 654ms elapsed
step 32, 648ms elapsed
step 33, 648ms elapsed
step 34, 649ms elapsed
step 35, 648ms elapsed
step 36, 645ms elapsed
step 37, 652ms elapsed
step 38, 645ms elapsed
step 39, 652ms elapsed
step 40, 644ms elapsed
step 41, 649ms elapsed
step 42, 649ms elapsed

In [21]:
import threading
model = GPT(GPTConfig(vocab_size=50304))
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr)
enc = tiktoken.get_encoding('gpt2')
train_loader = DataLoader(B, T, 1, 1)


start = time.time()
avg_batch_time = 0
preprocessing_thread = threading.Thread(target=preprocessing, args=(0,))
preprocessing_thread.start() # start immediately for first sample

eval_threads = [] #collect evaluation threads

for i in range(max_steps):
    start_batch = time.time()

    optimizer.zero_grad()

    x, y = train_loader.next_batch()
    x, y = x.to(device), y.to(device)

    preprocessing_thread.join() # needs to finish before step is performed
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16): #bf16 for forward pass only
        logits, loss = model(x, y)

    #restart preprocessing immediately for next iteration
    preprocessing_thread = threading.Thread(target=preprocessing, args=(i+1,))
    preprocessing_thread.start()

    loss.backward()
    optimizer.step()

    #start evaluation as soon as possible
    evaluation_thread = threading.Thread(target=evaluate, args=(loss.cpu().detach(),))
    evaluation_thread.start()

    #collect evaluation threads so they can be joined later
    eval_threads.append(evaluation_thread)

    torch.cuda.synchronize()
    end_batch = time.time()
    batch_time = int((end_batch - start_batch) * 1000)
    print(f"step {i + 1}, {batch_time}ms elapsed")
    avg_batch_time += batch_time

for t in eval_threads:
    t.join()

end = time.time()
print(f"{int((end - start) * 1000)}ms elapsed, {avg_batch_time / max_steps}ms avg batch time")

loaded 338020 tokens
1 epoch = 20 batches
step 1, 255ms elapsed
step 2, 148ms elapsed
step 3, 151ms elapsed
step 4, 153ms elapsed
step 5, 153ms elapsed
step 6, 155ms elapsed
step 7, 150ms elapsed
step 8, 150ms elapsed
step 9, 153ms elapsed
step 10, 149ms elapsed
step 11, 151ms elapsed
step 12, 149ms elapsed
step 13, 151ms elapsed
step 14, 155ms elapsed
step 15, 147ms elapsed
step 16, 151ms elapsed
step 17, 153ms elapsed
step 18, 158ms elapsed
step 19, 151ms elapsed
step 20, 154ms elapsed
step 21, 155ms elapsed
step 22, 153ms elapsed
step 23, 151ms elapsed
step 24, 155ms elapsed
step 25, 151ms elapsed
step 26, 154ms elapsed
step 27, 149ms elapsed
step 28, 149ms elapsed
step 29, 153ms elapsed
step 30, 151ms elapsed
step 31, 151ms elapsed
step 32, 152ms elapsed
step 33, 156ms elapsed
step 34, 152ms elapsed
step 35, 155ms elapsed
step 36, 152ms elapsed
step 37, 152ms elapsed
step 38, 150ms elapsed
step 39, 153ms elapsed
step 40, 157ms elapsed
step 41, 151ms elapsed
step 42, 153ms elapsed



### torch.compile is really neat

`torch.compile` is what calls Dynamo to optimize your code. It can be used in various ways, usually either on an `nn.Module` (inference) or a function (train loop). How, when and where it is used has a big impact on results, as do certain arguments. I do not understand it well enough to give general advice.


#### Setup
You will need the Triton backend to use compile on Nvidia GPUs. It can usually be installed like any Python package. A [Windows version](https://github.com/woct0rdho/triton-windows) also exists now. There are other backends for other use cases/hardware.


In [22]:
model = GPT(GPTConfig(vocab_size=50304))
model.to(device)

model = torch.compile(model)

optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr)
enc = tiktoken.get_encoding('gpt2')
train_loader = DataLoader(B, T, 1, 1)


def train(it):
    optimizer.zero_grad()

    x, y = train_loader.next_batch()
    x, y = x.to(device), y.to(device)

    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        logits, loss = model(x, y)

    loss.backward()
    optimizer.step()


start = time.time()
avg_batch_time = 0
for i in range(max_steps):
    start_batch = time.time()

    train(i)

    torch.cuda.synchronize()
    end_batch = time.time()
    batch_time = int((end_batch - start_batch) * 1000)
    print(f"step {i + 1}, {batch_time}ms elapsed")
    avg_batch_time += batch_time
end = time.time()
print(f"{int((end - start) * 1000)}ms elapsed, {avg_batch_time / max_steps}ms avg batch time")

loaded 338020 tokens
1 epoch = 20 batches
step 1, 2150ms elapsed
step 2, 124ms elapsed
step 3, 130ms elapsed
step 4, 131ms elapsed
step 5, 131ms elapsed
step 6, 130ms elapsed
step 7, 129ms elapsed
step 8, 130ms elapsed
step 9, 130ms elapsed
step 10, 127ms elapsed
step 11, 130ms elapsed
step 12, 129ms elapsed
step 13, 131ms elapsed
step 14, 130ms elapsed
step 15, 129ms elapsed
step 16, 126ms elapsed
step 17, 129ms elapsed
step 18, 133ms elapsed
step 19, 135ms elapsed
step 20, 132ms elapsed
step 21, 132ms elapsed
step 22, 131ms elapsed
step 23, 130ms elapsed
step 24, 128ms elapsed
step 25, 134ms elapsed
step 26, 134ms elapsed
step 27, 132ms elapsed
step 28, 137ms elapsed
step 29, 135ms elapsed
step 30, 131ms elapsed
step 31, 131ms elapsed
step 32, 132ms elapsed
step 33, 134ms elapsed
step 34, 126ms elapsed
step 35, 134ms elapsed
step 36, 132ms elapsed
step 37, 135ms elapsed
step 38, 130ms elapsed
step 39, 135ms elapsed
step 40, 135ms elapsed
step 41, 133ms elapsed
step 42, 132ms elapsed


#### Some notes and guidelines
- Passing `mode="reduce-overhead"` can help when compiling small models/batches (try it out for each use case).
- Graph breaks and recompilations kill performance. Make your performance-critical code as static as possible (no dynamic shape changes, no control flow) and avoid putting Python operations into it.
- Avoid in-place operations when possible (use `x = x + 1` instead of `x.add_(1)` or `x += 1`); AOTAutograd sometimes has trouble with them and they can cause weird issues.
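As a quick reminder of which tensor methods are in-place: in PyTorch, the variants with a trailing underscore modify the tensor itself, while the ones without return a new tensor. A minimal sketch:

```python
import torch

x = torch.zeros(3)

y = x.add(1)               # out-of-place: returns a new tensor, x is untouched
x_after_add = x.clone()    # snapshot for comparison

x.add_(1)                  # in-place: modifies x itself (same as x += 1)

print(x_after_add)  # tensor([0., 0., 0.])
print(y)            # tensor([1., 1., 1.])
print(x)            # tensor([1., 1., 1.])
```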

#### Debugging
Run something like
```
TORCH_LOGS=+all python gpt_demo.py
```
to produce logs for Dynamo, Autograd and Inductor. This will probably be overwhelming and not very helpful. Check out [this page](https://docs.pytorch.org/docs/stable/compile/programming_model.observability.html#torch-logs) for details on more fine-grained arguments. Passing `recompiles` instead of `all` can often be a good first step.

`TORCH_TRACE` is also a thing, set it if it works for you (I have had issues with it before, not sure why).




### Avoiding Computation of Unnecessary Gradients

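Autograd tracks operations and computes gradients for any tensor that has `requires_grad=True` (and for everything computed from such a tensor). Wherever you don't need gradients, gradient tracking should be disabled.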
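One typical case is freezing part of a model. The sketch below uses two toy layers, `backbone` and `head` (hypothetical names), and turns off gradients for the first one.

```python
import torch
import torch.nn as nn

backbone = nn.Linear(16, 16)  # hypothetical frozen part
head = nn.Linear(16, 2)       # hypothetical trainable part

# Freeze the backbone: Autograd will not compute gradients for its parameters.
for p in backbone.parameters():
    p.requires_grad_(False)

x = torch.randn(4, 16)
loss = head(backbone(x)).sum()
loss.backward()

print(backbone.weight.grad)    # None, nothing was computed for the frozen layer
print(head.weight.grad.shape)  # torch.Size([2, 16])
```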

In [23]:
x = torch.tensor([1,2,3])
x.requires_grad

False

In [24]:
import torch.nn as nn
import torch
layer = nn.Linear(100, 100).to("cuda:2")
layer.weight.requires_grad

True

In [25]:
x = torch.rand(100, device="cuda:2")
y_hat = layer(x)
print(y_hat.grad_fn)

<ViewBackward0 object at 0x77f231013e20>


`detach()` returns a tensor that shares the same data but is cut off from the computation graph, so no gradient flows through it.

In [26]:
print(y_hat.detach().grad_fn)

None
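A short CPU-only sketch of what `detach()` does and does not do: the detached tensor has no history, but the original tensor and its graph are left untouched.

```python
import torch

w = torch.ones(3, requires_grad=True)
y = w * 2
d = y.detach()

print(d.requires_grad)               # False: d is not part of the graph
print(d.data_ptr() == y.data_ptr())  # True: d shares memory with y
print(y.grad_fn is not None)         # True: y still has its history

# Backpropagating through y still works after detaching.
y.sum().backward()
print(w.grad)  # tensor([2., 2., 2.])
```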


The `with torch.no_grad()` context manager disables gradient computation in its context. This is useful for forward passes during inference.

In [27]:
with torch.no_grad():
    x = torch.rand(100, device="cuda:2")
    y_hat = layer(x)
    print(y_hat.grad_fn)

None
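`torch.no_grad()` can also be applied as a decorator, which is convenient for inference helpers. A minimal sketch with a toy layer on the CPU (`predict` is a made-up function name):

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)

@torch.no_grad()  # everything inside predict runs without gradient tracking
def predict(x):
    return layer(x)

out = predict(torch.rand(4))
print(out.requires_grad)  # False
print(out.grad_fn)        # None
```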
