
Fix mx.prod vjp for complex types#3433

Merged
angeloskath merged 1 commit into ml-explore:main from CameronChurchwell:complex_prod_fix
Apr 22, 2026

Conversation

@CameronChurchwell
Contributor

Fix bug in VJP for Product Reduction

I was running some experiments and noticed some strange gradient behavior. After some digging, it turned out that the issue was in how mx.prod handles complex numbers. My understanding is that MLX uses the conjugate-transpose formulation, as seen in Matmul::vjp here:

std::vector<array> Matmul::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,
    const std::vector<int>& argnums,
    const std::vector<array>&) {
  std::vector<array> vjps;
  auto& cotan = cotangents[0];
  std::vector<int> reorder(cotan.ndim());
  std::iota(reorder.begin(), reorder.end(), 0);
  std::iter_swap(reorder.end() - 1, reorder.end() - 2);
  auto& s = stream();

  auto complex_transpose = [&](const array& x) {
    return transpose(conjugate(x, s), reorder, s);
  };

  for (auto arg : argnums) {
    if (arg == 0) {
      // M X N * (K X N).T -> M X K
      vjps.push_back(matmul(cotan, complex_transpose(primals[1]), s));
    } else {
      // (M X K).T * M X N -> K X N
      vjps.push_back(matmul(complex_transpose(primals[0]), cotan, s));
    }
  }
  return vjps;
}
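To see why the conjugate belongs there, here is a standalone sketch (plain std::complex, not MLX code; the names pair_loss, mul_vjp_wrt_a, and check_conjugate_convention are illustrative). Pairing a cotangent c with the output via the real loss L(y) = Re(conj(c) * y), the VJP of y = a * b with respect to a under this convention is c * conj(b), and stepping a along that value is exactly the steepest-ascent direction for L:

```cpp
#include <cassert>
#include <cmath>
#include <complex>

using cplx = std::complex<double>;

// Real-valued loss pairing a cotangent c with an output y: L(y) = Re(conj(c) * y).
inline double pair_loss(cplx y, cplx c) { return std::real(std::conj(c) * y); }

// Conjugate-convention VJP of y = a * b with respect to a: cotan * conj(b).
inline cplx mul_vjp_wrt_a(cplx b, cplx c) { return c * std::conj(b); }

// Because y is linear in a, stepping a by eps * g changes the loss by
// exactly eps * |g|^2 when g follows the conjugate convention.
inline bool check_conjugate_convention() {
  cplx a(1.3, -0.7), b(0.4, 2.1), c(-1.1, 0.6);
  cplx g = mul_vjp_wrt_a(b, c);
  double eps = 1e-6;
  double dL = pair_loss((a + eps * g) * b, c) - pair_loss(a * b, c);
  return std::abs(dL - eps * std::norm(g)) < 1e-9;
}
```

In Matmul::vjp above, the same conjugation shows up through the complex_transpose helper, which conjugates while transposing the last two axes.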

And also in Multiply::vjp here:

std::vector<array> Multiply::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,
    const std::vector<int>& argnums,
    const std::vector<array>&) {
  std::vector<array> vjps;
  for (auto arg : argnums) {
    vjps.push_back(multiply(
        conjugate(primals[1 - arg], stream()), cotangents[0], stream()));
  }
  return vjps;
}

This PR fixes Reduce::vjp to use the same convention by conjugating exclusive_prod before multiplying it with the cotangent.
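Concretely, the convention means the cotangent flowing to each element x_i of the product is conj of the exclusive product (the product over all j != i) times the incoming cotangent. A standalone sketch of that rule (plain std::complex, not the MLX implementation; prod_vjp and check_prod_convention are illustrative names), verified by checking it gives a steepest-ascent direction for the paired real loss:

```cpp
#include <cassert>
#include <cmath>
#include <complex>
#include <vector>

using cplx = std::complex<double>;

inline cplx product(const std::vector<cplx>& x) {
  cplx p(1.0, 0.0);
  for (const auto& v : x) p *= v;
  return p;
}

// Conjugate-convention VJP of prod: grad_i = conj(prod over j != i of x_j) * cotan.
inline std::vector<cplx> prod_vjp(const std::vector<cplx>& x, cplx cotan) {
  std::vector<cplx> g(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    cplx excl(1.0, 0.0);
    for (std::size_t j = 0; j < x.size(); ++j)
      if (j != i) excl *= x[j];
    g[i] = std::conj(excl) * cotan;
  }
  return g;
}

// Pair the cotangent with the output to get a real loss, then confirm the
// gradient is a steepest-ascent direction: to first order, stepping every
// x_i by eps * g_i raises the loss by eps * sum |g_i|^2.
inline bool check_prod_convention() {
  std::vector<cplx> x = {{1.0, 2.0}, {-0.5, 0.3}, {2.0, -1.0}};
  cplx c(0.7, -1.2);
  auto g = prod_vjp(x, c);
  double eps = 1e-7, gg = 0.0;
  std::vector<cplx> stepped = x;
  for (std::size_t i = 0; i < x.size(); ++i) {
    stepped[i] += eps * g[i];
    gg += std::norm(g[i]);
  }
  double dL = std::real(std::conj(c) * product(stepped)) -
              std::real(std::conj(c) * product(x));
  return std::abs(dL - eps * gg) < 1e-6;
}
```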

I added a test which checks that the updated vjp agrees both with a hand-computed version and with the vjp implementation of multiply, since the two should always agree. I also tested these changes on a toy example and in a small experiment, and the gradient divergence issues I was seeing before are gone.
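That consistency argument can be sketched outside MLX as well (plain std::complex, illustrative names and values): backpropagating a cotangent through y = (x0 * x1) * x2 one multiply at a time, using the Multiply::vjp rule conj(other) * cotan at each step, must reproduce the conjugated exclusive product for every input:

```cpp
#include <cassert>
#include <cmath>
#include <complex>

using cplx = std::complex<double>;

// Multiply::vjp convention: grad of (u * v) with respect to u is conj(v) * cotan.
inline cplx mul_vjp(cplx other, cplx cotan) { return std::conj(other) * cotan; }

// Backprop cotangent c through y = (x0 * x1) * x2 one multiply at a time;
// each input's gradient must equal conj(exclusive product) * c, which is
// what the fixed product-reduction vjp computes directly.
inline bool check_prod_matches_chained_multiplies() {
  cplx x0(1.5, -0.4), x1(-0.8, 0.9), x2(0.3, 2.2), c(-1.0, 0.5);
  cplx t = x0 * x1;          // forward intermediate
  cplx g2 = mul_vjp(t, c);   // gradient for x2 through the outer multiply
  cplx gt = mul_vjp(x2, c);  // cotangent flowing into t
  cplx g0 = mul_vjp(x1, gt);
  cplx g1 = mul_vjp(x0, gt);
  // hand-computed: conjugated exclusive products times the cotangent
  cplx e0 = std::conj(x1 * x2) * c;
  cplx e1 = std::conj(x0 * x2) * c;
  cplx e2 = std::conj(x0 * x1) * c;
  return std::abs(g0 - e0) < 1e-12 && std::abs(g1 - e1) < 1e-12 &&
         std::abs(g2 - e2) < 1e-12;
}
```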

To my knowledge, no documentation needs to be updated.

One thing to note is that there may be other operations in MLX which also fail to follow this convention. I have checked fft, ifft, and multiply since those are the other ops I was using at the time, and those appear to be fine.

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

Member

@angeloskath left a comment


Thanks, that's a great catch!

@angeloskath merged commit 68cf2fd into ml-explore:main on Apr 22, 2026
16 checks passed