From 8395ca80a792641a34705a9c0d7f65149aa39b48 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Mon, 1 Dec 2025 11:22:02 +0800 Subject: [PATCH 1/2] Introduces PyTorch dot helper Provides a reusable tensor dot routine to standardize future vector operations --- kernel_course/pytorch_ops/dot.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 kernel_course/pytorch_ops/dot.py diff --git a/kernel_course/pytorch_ops/dot.py b/kernel_course/pytorch_ops/dot.py new file mode 100644 index 0000000..cdb1115 --- /dev/null +++ b/kernel_course/pytorch_ops/dot.py @@ -0,0 +1,21 @@ +import torch + + +def dot( + x: torch.Tensor, + y: torch.Tensor, +) -> torch.Tensor: + """ + Computes the dot product of two tensors using PyTorch operations. + + Args: + x (torch.Tensor): First tensor. + y (torch.Tensor): Second tensor. + + Returns: + torch.Tensor: The dot product of `x` and `y`. + """ + + z = torch.sum(torch.mul(x, y)) + + return z From 5d0bd042bb27c746a61a917dec274b11ba611fbd Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Mon, 1 Dec 2025 11:22:12 +0800 Subject: [PATCH 2/2] Marks dot as available in PyTorch Updates the BLAS kernel matrix so the PyTorch column reflects the completed dot implementation, keeping the support table accurate for users. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 5cf5b02..53cfe8f 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ The following common BLAS kernels have been implemented in multiple frameworks. | [swap](./docs/swap.md) | swap vectors | $x \leftrightarrow y$ | $0$ | $4n$ | [✅](./kernel_course/python_ops/swap.py) | [✅](./kernel_course/pytorch_ops/swap.py) | [✅](./kernel_course/triton_ops/swap.py) | ❌ | [✅](./tests/test_swap.py) | | [scal](./docs/scal.md) | scale vector | $y = \alpha y$ | $n$ | $2n$ | [✅](./kernel_course/python_ops/scal.py) | [✅](./kernel_course/pytorch_ops/scal.py) | [✅](./kernel_course/triton_ops/scal.py) | ❌ | [✅](./tests/test_scal.py) | | [axpby](./docs/axpby.md) | update vector| $y = \alpha x + \beta y$ | $3n$ | $3n$ | [✅](./kernel_course/python_ops/axpby.py) | [✅](./kernel_course/pytorch_ops/axpby.py) | [✅](./kernel_course/triton_ops/axpby.py) | ❌ | [✅](./tests/test_axpby.py) | -| [dot](./docs/dot.md) | dot product | $z = x^\top y$ | $2n$ | $2n$ | [✅](./kernel_course/python_ops/dot.py) | ❌ | ❌ | ❌ | ❌ | +| [dot](./docs/dot.md) | dot product | $z = x^\top y$ | $2n$ | $2n$ | [✅](./kernel_course/python_ops/dot.py) | [✅](./kernel_course/pytorch_ops/dot.py) | ❌ | ❌ | ❌ | | gemv | general matrix-vector multiply | $y = \alpha A x + \beta y$ | $2mn$ | $mn + n + 2m$ | ❌ | ❌ | ❌ | ❌ | ❌ | | geru | general rank-1 update | $A = A + \alpha x y^\top$ | $2mn$ | $2mn + m + n$ | ❌ | ❌ | ❌ | ❌ | ❌ | | gemm | general matrix-matrix multiply | $C = \alpha A B + \beta C$ | $2mnk$ | $mk + nk + 2mn$ | ❌ | ❌ | ❌ | ❌ | ❌ |