[PERFORMANCE OPTIMIZATION] add dot pytorch kernel #47
| Diff line change |
|---|
| @@ -0,0 +1,21 @@ |
| import torch |
| |
| |
| def dot( |
| x: torch.Tensor, |
| y: torch.Tensor, |
| ) -> torch.Tensor: |
| """ |
| Computes the dot product of two tensors using PyTorch operations. |
| |
Suggested change:
| - Computes the dot product of two tensors using PyTorch operations. |
| + Computes the dot product of two tensors by multiplying corresponding elements |
| + and summing the results using PyTorch operations. |
Copilot AI, Dec 1, 2025
The PyTorch implementation does not flatten the input tensors before computing the dot product, unlike the Python reference implementation. The Python reference (lines 19-20) uses x.reshape(-1) and y.reshape(-1) to ensure inputs are 1D vectors.
Without flattening, this implementation will:
- Produce incorrect results for multi-dimensional tensors whose shapes differ but still broadcast, such as (3, 1) against (1, 3)
- Not be numerically equivalent to the Python reference
- Potentially fail with broadcasting errors for certain input shapes
Recommendation: Add tensor flattening before the multiplication:
x = x.reshape(-1)
y = y.reshape(-1)
z = torch.sum(torch.mul(x, y))
Suggested change:
| x = x.reshape(-1) |
| y = y.reshape(-1) |
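To make the shape concern concrete, here is a minimal sketch comparing the two variants. The names `dot_unflattened` and `dot_flattened` are hypothetical and used only for illustration; the unflattened body is assumed to mirror the multiply-and-sum expression quoted in the recommendation, since the full diff is not shown here. For inputs of identical shape both variants return the same value, so the difference appears only when the shapes differ.

```python
import torch


def dot_unflattened(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Hypothetical name: multiply and sum without reshaping
    # (assumed to match the expression under review).
    return torch.sum(torch.mul(x, y))


def dot_flattened(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Hypothetical name: flatten both inputs to 1D first, as the review recommends.
    return torch.sum(torch.mul(x.reshape(-1), y.reshape(-1)))


# A column vector and a row vector broadcast silently to a 3x3 outer product.
col = torch.tensor([[1.0], [2.0], [3.0]])  # shape (3, 1)
row = torch.tensor([[4.0, 5.0, 6.0]])      # shape (1, 3)
print(dot_unflattened(col, row).item())    # 90.0, the sum of the outer product
print(dot_flattened(col, row).item())      # 32.0, the intended dot product

# Shapes with the same element count that cannot broadcast raise an error.
a = torch.ones(2, 3)
b = torch.ones(3, 2)
try:
    dot_unflattened(a, b)
    unflattened_failed = False
except RuntimeError:
    unflattened_failed = True
print(unflattened_failed)                  # True
print(dot_flattened(a, b).item())          # 6.0
```

The silent broadcasting case is the more dangerous of the two, because it returns a plausible-looking scalar instead of raising.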
This PR marks the PyTorch `dot` implementation as complete (✅) but does not include a corresponding test file. All other complete operations (copy, swap, scal, axpby) have test files (test_copy.py, test_swap.py, test_scal.py, test_axpby.py) that validate implementations across backends.

The PR description acknowledges this: "A follow‑up `tests/test_dot.py` will treat the Python reference implementation as ground truth..." However, marking the implementation as complete without tests is inconsistent with the existing convention.

Recommendation: Either:
- Add `tests/test_dot.py` in this PR before marking as ✅, or