In [None]:
import torch

# There are 3 main ways to multiply matrices in PyTorch (plus one function that is often confused with them)
- `torch.matmul()`
- `torch.mm()`
- `torch.bmm()`
- `torch.mul()` — this does NOT perform traditional matrix multiplication (it multiplies corresponding elements), but it is often mentioned alongside the others.

In [None]:
# Two small int32 operands shared by every example in this notebook.
# mat1 holds 1..4 laid out row by row; mat2 is all ones.
mat1 = torch.arange(1, 5, dtype=torch.int32).reshape(2, 2)
mat2 = torch.ones(2, 2, dtype=torch.int32)
print(f"Matrix 1:\n{mat1}\nMatrix 2:\n{mat2}")

In [None]:
# Compute each product once, then print it after a short explanation.
# mat1 and mat2 are the 2x2 int32 matrices defined in the previous cell.
matmul_out = torch.matmul(mat1, mat2)
mm_out = torch.mm(mat1, mat2)
# unsqueeze(0) prepends a batch dimension of size 1, because torch.bmm requires 3D inputs.
bmm_out = torch.bmm(mat1.unsqueeze(0), mat2.unsqueeze(0))
mul_out = torch.mul(mat1, mat2)

# One entry per function: (call being shown, explanation lines, result), printed in order.
demos = [
    (
        "torch.matmul(mat1, mat2):",
        ["Most common way to multiply matrices in PyTorch. Will automatically handle broadcasting and can handle higher dimensions."],
        matmul_out,
    ),
    (
        "torch.mm(mat1, mat2):",
        ["torch.mm is a specialized function for 2D matrix multiplication. It does not handle broadcasting and expects both inputs to be 2D tensors."],
        mm_out,
    ),
    (
        "torch.bmm(mat1.unsqueeze(0), mat2.unsqueeze(0)):",
        ["torch.bmm is used for batch matrix multiplication. It expects 3D tensors where the first dimension is the batch size. Here, we unsqueeze to add a batch dimension."],
        bmm_out,
    ),
    (
        "torch.mul(mat1, mat2):",
        [
            "torch.mul performs element-wise multiplication. It does not perform matrix multiplication but multiplies corresponding elements of the two tensors.",
            "Note: This is equivalent to using the * operator for element-wise multiplication.",
        ],
        mul_out,
    ),
]

for call, notes, result in demos:
    print(call)
    for note in notes:
        print(note)
    print(result)
    print('\n')

# Check the claims printed above instead of leaving them to the reader.
# The three matrix-multiplication routes agree (bmm after dropping its size-1 batch dim),
# `@` gives the same result as torch.matmul, and torch.mul matches the * operator.
assert torch.equal(mm_out, matmul_out)
assert torch.equal(bmm_out.squeeze(0), matmul_out)
assert torch.equal(mat1 @ mat2, matmul_out)
assert torch.equal(mul_out, mat1 * mat2)

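For details on broadcasting, see the PyTorch broadcasting semantics notes. Broadcasting is relevant to `torch.matmul` and `torch.mul`; `torch.mm` and `torch.bmm` do not broadcast.
https://docs.pytorch.org/docs/stable/notes/broadcasting.html#broadcasting-semantics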