Experimental WebGPU backend for PyTorch
Not even a 0.0.1 release yet! I made the repository public so that you can give me support and I can get some dopamine out of it (building alone, in private, after a day job, without positive feedback is quite difficult, at least for me!)
Goals:
- Run PyTorch on WebGPU: device="webgpu"
- Compile PyTorch code for WebGPU: @torch.compile(m, backend=webgpu)
- High performance without platform-specific (CUDA, MPS, ROCm) kernels. Five ingredients are enough to get there: PyTorch, Python, C++, WGSL shaders and a WebGPU runtime. Currently, torch-webgpu uses Google Dawn
Add tensors on WebGPU and move data between CPU and WebGPU!
import torch
import torch_webgpu  # makes device="webgpu" available
a = torch.tensor([-1.5, 2.7, 1.0, 2.0], device="webgpu")
b = torch.tensor([-1.0, 0.9, 1.1, -2.1], device="webgpu")
result = a + b
expected = torch.tensor([-2.5, 3.6, 2.1, -0.1], device="cpu")
assert torch.allclose(result.to("cpu"), expected)
This is a TL;DR showcase of where we currently are with torch-webgpu. It will be updated regularly as new features land
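The showcase pattern (compute on WebGPU, copy the result back, compare with a CPU result) can be generalized to other ops. The following is only a sketch: `check_against_cpu` is a hypothetical helper name that is not part of torch-webgpu, and the tolerances are arbitrary choices rather than project defaults.

```python
import torch


def check_against_cpu(op, *inputs, device="webgpu", rtol=1e-4, atol=1e-5):
    """Run `op` on CPU and on `device`, then compare both results on CPU.

    Hypothetical helper for illustration; not part of torch-webgpu.
    `inputs` are expected to be CPU tensors.
    """
    # Reference result computed entirely on the CPU.
    expected = op(*inputs)
    # Same op on the target device; inputs are moved with .to(device).
    moved = [t.to(device) for t in inputs]
    result = op(*moved).to("cpu")
    return torch.allclose(result, expected, rtol=rtol, atol=atol)


# Usage on WebGPU (assumes torch_webgpu is built and importable):
#   import torch_webgpu
#   a = torch.tensor([-1.5, 2.7, 1.0, 2.0])
#   b = torch.tensor([-1.0, 0.9, 1.1, -2.1])
#   check_against_cpu(torch.add, a, b)
```

Passing `device="cpu"` makes the helper compare the CPU against itself, which is a cheap way to confirm the helper works before pointing it at WebGPU.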
Only for developers and curious very early adopters
- Clone this repo
- Install google/dawn. Guide: https://github.com/google/dawn/blob/main/docs/quickstart-cmake.md. Set DAWN_PREFIX to dawn/install/Release inside your dawn repo, for example DAWN_PREFIX=/home/user/dawn/install/Release
- In this repo, run ./build.sh
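Since the build depends on DAWN_PREFIX, a small pre-flight check can catch an unset or mistyped path before running ./build.sh. This is a minimal sketch: `check_dawn_prefix` is a hypothetical helper, it only verifies that the directory exists, and build.sh itself may require more than that.

```python
import os
from pathlib import Path


def check_dawn_prefix(env=None):
    """Return the DAWN_PREFIX directory as a Path, or raise with a hint.

    Hypothetical pre-flight check; it does not inspect the Dawn install.
    """
    # Default to the real environment; a dict can be passed for testing.
    env = os.environ if env is None else env
    value = env.get("DAWN_PREFIX", "")
    if not value:
        raise RuntimeError("DAWN_PREFIX is not set; point it at dawn/install/Release")
    prefix = Path(value).expanduser()
    if not prefix.is_dir():
        raise RuntimeError(f"DAWN_PREFIX={value!r} is not an existing directory")
    return prefix
```

Calling `check_dawn_prefix()` with no arguments reads the current environment and returns the resolved path when it exists.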
In Python:
import torch_webgpu
Now you can use device="webgpu" and .to("webgpu") to run PyTorch on a real WebGPU device!
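For scripts that should also run on machines where torch-webgpu has not been built, the device string can be chosen at runtime. This is a sketch under assumptions: `pick_device` is a hypothetical helper, and it only checks that the `torch_webgpu` module can be found, not that a GPU adapter is actually available.

```python
import importlib.util


def pick_device(module_name="torch_webgpu"):
    """Return "webgpu" if the backend module can be found, else "cpu".

    Only checks that the module is discoverable; it does not import it.
    """
    found = importlib.util.find_spec(module_name) is not None
    return "webgpu" if found else "cpu"
```

When the helper returns "webgpu", `import torch_webgpu` still has to run before tensors are created on that device.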
This list helps me pick what to work on next, aside from adding new ops
- only float32 supported
- wgpu::Queue.Submit() handled synchronously
- not enough unit tests (standardized testing of out-of-tree backends is still in progress as of Dec 2025; I hope to involve torch-webgpu in this effort)
- some ops might fall back to CPU
- CPU <-> WebGPU
- CUDA <-> WebGPU
- MPS <-> WebGPU
- Intel Gaudi <-> WebGPU
- XLA <-> WebGPU
How serious are you about this project? Is it research or a PoC, or are you going to make it production quality?
Once we hit version 1.0.0, torch-webgpu will be a production-ready PyTorch backend. WebGPU is an exciting, emerging technology. As of Nov 2025 all major browsers support WebGPU. I think that it's highly important to build a bridge between PyTorch and WebGPU.
We'll see. Ideally I'd see it as a part of PyTorch core, but we need to reach very high quality first before we can ask the PyTorch maintainers about it
I have very little time and need to be picky about contributions, so please make sure the code you contribute is:
- well thought out
- covered with unit tests
- fully understood by you, every line of it
- as concise as possible - I can't handle PRs that are too big, sorry!
Use LLMs at your discretion, but provide an exhaustive explanation of what you built and why. Write it yourself to show that you really understand it
I can understand if that sounds too picky, but since I build this project after hours, I need to cut any additional noise. Sorry and thanks for understanding!
That's ok. The main goal here is to build a bridge (for the community) and to learn ML compilers in depth (for me). The project moves regularly, at its own pace. Things improve, cover more use cases, get more tests, get rethought and rewritten. The journey, insights and learning matter more than raw development velocity. That's the tradeoff I chose
You can fund the project to give me more spare time to work on it. My email: github@maczan.pl
- empty.memory_format
- empty_strided
- as_strided
- copy_
- _copy_from
- to.device
- empty_like
- zeros_like
- ones_like
- arange
- full
- rand
- randn
- clone
- to.dtype
- to
- quantize_per_tensor
- dequantize
f32 only for now!
- add.Tensor
- gelu
- silu
- relu
- masked_select
- add.Scalar
- add
- sub.Tensor
- sub
- mul.Tensor
- mul
- div.Tensor
- div
- neg
- pow.Tensor_Scalar
- pow
- sqrt
- rsqrt
- abs
- exp
- log
- tanh
- sigmoid
- clamp_min
- clamp
- round
- floor
- ceil
- minimum
- maximum
- where.self
- where
- masked_fill
- bitwise_and.Tensor
- eq.Tensor
- ne.Scalar
- ne.Tensor
- lt.Tensor
- le.Tensor
- gt.Tensor
- ge.Tensor
- sum.dim_IntList
- sum
- mean.dim
- mean
- amax
- amin
- argmax
- argmin
- var_mean
- topk
- view
- resize
- reshape
- flatten
- permute
- transpose.int
- transpose
- contiguous
- unsqueeze
- squeeze
- cat
- stack
- slice.Tensor
- slice
- select
- narrow
- expand
- broadcast_to
- index_select
- addmm
- mm
- bmm
- matmul
- scaled_dot_product_attention
- _log_softmax
- softmax.int
- softmax
- layer_norm
- native_layer_norm
- rms_norm
- batch_norm
- group_norm
- embedding
- conv2d
- conv2d_backward
- adaptive_avg_pool2d
- max_pool2d
- interpolate
As references, I mainly use Ascend's NPU backend for PyTorch (https://github.com/ascend/pytorch), Elie's WebGPU guide (https://eliemichel.github.io/LearnWebGPU/index.html), the WGSL spec (https://www.w3.org/TR/WGSL/) and the PyTorch PrivateUse1 custom backend docs (https://docs.pytorch.org/tutorials/advanced/privateuseone.html, https://docs.pytorch.org/tutorials/advanced/extend_dispatcher.html, https://docs.pytorch.org/tutorials/advanced/dispatcher)
Note: This project is unrelated to webgpu-torch, which is a neat PyTorch reimplementation in TypeScript targeting WebGPU
