In [None]:
torch.mul()
# Elementwise multiplication
# Broadcasts

torch.dot()
# Dot product of two one-dimensional (1D only!) tensors
# Does no broadcasting

torch.mm()
# Do matrix mul of two 2D matrices
# Also does not broadcast

torch.bmm()
# Do matrix mul of two 3D matrices: (b x n x m) @ (b x m x p) -> (b x n x p)
# Also does not broadcast

torch.matmul() == @
# Depending on input will do dot(), mm() or bmm() broadcasted to tensors of any size
# This does broadcasting so may lead to undesired results, be careful with it



Broadcasting:

Note - this depends on the operation first! If the operation itself does not broadcast, then no broadcasting will happen,
even if the tensors are broadcastable.

Two tensors are broadcastable if (and):

1. Each tensor has at least one dim
2. When iterating over the dims, starting at the last one (or): 
    The dim sizes must be equal OR
    One of the dim sizes is 1
    One of the dims does not exist

In [2]:
import torch
# Batched matmul: the leading (batch) dims are broadcast against each other,
# while the trailing two dims are the matrices that actually get multiplied.
a = torch.randn(1, 64, 1152, 1, 8)     # batch dims (1, 64, 1152), matrix (1 x 8)
b = torch.randn(10, 1, 1152, 8, 16)    # batch dims (10, 1, 1152), matrix (8 x 16)

c = torch.matmul(a, b)
c.shape

torch.Size([10, 64, 1152, 1, 16])

What we see is that the first dim in a (size 1) is copied out 10 times, and likewise the second dim in b (size 1) is copied out 64 times
The 4th dim in c is not a result of broadcasting, there it's just doing (1x8) @ (8x16) -> (1x16)

In [7]:
a = torch.randn(1, 64, 1152, 1, 8)
# b has no batch dims at all, so its single (8 x 16) matrix is reused for every
# batch entry of a. Only the batch dims are broadcast; the trailing two dims
# follow the matmul rule (1 x 8) @ (8 x 16) -> (1 x 16).
b = torch.randn(8, 16)

c = torch.matmul(a, b)
c.shape

torch.Size([1, 64, 1152, 1, 16])

In [7]:
# What we would have to write by hand if there were no broadcasting:

a = torch.arange(4)
# tensor([0, 1, 2, 3])

b = torch.ones(4, 4)
#tensor([[1., 1., 1., 1.],
#        [1., 1., 1., 1.],
#        [1., 1., 1., 1.],
#        [1., 1., 1., 1.]])

# Without broadcasting: stretch a to b's shape ourselves before adding

a = a.expand_as(b)
#tensor([[0, 1, 2, 3],
#        [0, 1, 2, 3],
#        [0, 1, 2, 3],
#        [0, 1, 2, 3]])

torch.add(a, b)

tensor([[1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [1., 2., 3., 4.]])

In [9]:
# With broadcasting: the 1-D tensor is stretched across the rows of the 2-D
# tensor automatically, no manual expand() needed.
# Built from fresh tensors so this cell does not rely on `a` having been
# overwritten by an expand() call elsewhere in the notebook.

broadcast_sum = torch.arange(4) + torch.ones(4, 4)
broadcast_sum

tensor([[1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [1., 2., 3., 4.]])

A very common case where broadcasting will not tell you that you're making an error is in loss functions.

In [13]:
import torch.nn as nn

loss_function = nn.MSELoss()
# Deliberately mismatched shapes: (2, 3, 5) against (3, 5). Instead of raising
# an error, the label is broadcast along the missing leading dim.
# The prediction is not called `input`, because that would shadow Python's
# builtin input() for the rest of the kernel session.
prediction = torch.randn(2, 3, 5)
label = torch.rand(3, 5)

loss = loss_function(prediction, label)
loss

# This mistake was happening so often that the warning below was added to the loss function code :)


  return F.mse_loss(input, target, reduction=self.reduction)


tensor(1.7476)

Multidim matrix multiplication
------------------------------
It's always only the last two dims that are matrix multiplied, with or without broadcasting as per all of the above.

In [31]:
# Let's create some 3D matrices

# Three 2x2 matrices, filled with 3s, 4s and 5s respectively
a = torch.stack([torch.full((2, 2), fill) for fill in (3., 4., 5.)])

#tensor([[[3., 3.],
#         [3., 3.]],
#        [[4., 4.],
#         [4., 4.]],
#        [[5., 5.],
#         [5., 5.]]])

# Same layout, filled with 6s, 7s and 8s
b = torch.stack([torch.full((2, 2), fill) for fill in (6., 7., 8.)])

#tensor([[[6., 6.],
#         [6., 6.]],
#        [[7., 7.],
#         [7., 7.]],
#        [[8., 8.],
#         [8., 8.]]])

# We know the product is going to be (3,2,2), with matmul on the last two dims

torch.matmul(a, b)

tensor([[[36., 36.],
         [36., 36.]],

        [[56., 56.],
         [56., 56.]],

        [[80., 80.],
         [80., 80.]]])

In [30]:
# A small 2x2 matrix with four distinct entries
test = torch.tensor([[0., 1.],
                     [2., 3.]])
# Printing it out shows it in row x column format:
#tensor([[0., 1.],
#        [2., 3.]])


In [35]:
# Let's do 4D!

# Read the shape (4,3,2,2) as "we have 4 groups of 3 groups of 2x2 matrices":
# group 0 holds three 2x2 matrices full of 3s, group 1 three full of 4s, etc.
a = torch.stack([torch.full((3, 2, 2), fill) for fill in (3., 4., 5., 6.)])

#tensor([[[[3., 3.],
#          [3., 3.]],
#         [[3., 3.],
#          [3., 3.]],
#         [[3., 3.],
#          [3., 3.]]],
#        [[[4., 4.],
#          [4., 4.]],
#         [[4., 4.],
#          [4., 4.]],
#         [[4., 4.],
#          [4., 4.]]],
#        [[[5., 5.],
#          [5., 5.]],
#         [[5., 5.],
#          [5., 5.]],
#         [[5., 5.],
#          [5., 5.]]],
#        [[[6., 6.],
#          [6., 6.]],
#         [[6., 6.],
#          [6., 6.]],
#         [[6., 6.],
#          [6., 6.]]]])

# Same layout for b, with groups full of 7s, 8s, 9s and 10s
b = torch.stack([torch.full((3, 2, 2), fill) for fill in (7., 8., 9., 10.)])

torch.matmul(a, b)

tensor([[[[ 42.,  42.],
          [ 42.,  42.]],

         [[ 42.,  42.],
          [ 42.,  42.]],

         [[ 42.,  42.],
          [ 42.,  42.]]],


        [[[ 64.,  64.],
          [ 64.,  64.]],

         [[ 64.,  64.],
          [ 64.,  64.]],

         [[ 64.,  64.],
          [ 64.,  64.]]],


        [[[ 90.,  90.],
          [ 90.,  90.]],

         [[ 90.,  90.],
          [ 90.,  90.]],

         [[ 90.,  90.],
          [ 90.,  90.]]],


        [[[120., 120.],
          [120., 120.]],

         [[120., 120.],
          [120., 120.]],

         [[120., 120.],
          [120., 120.]]]])

Difference between view and reshape

In [44]:
a = torch.arange(24)
# Print the (4,6) result followed by the source tensor, first for view and
# then for reshape; a itself is printed unchanged (still 1-D) both times.
print(a.view(4, 6), a, a.reshape(4, 6), a, sep="\n")


tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23]])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23])
tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23]])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23])


They seem to do the same thing; however, view only works on contiguous data, so it won't work after you have transposed.

In [60]:
# Build the (4,6) tensor right here so this cell does not depend on whatever
# `a` happens to hold from a previously executed cell.
a = torch.arange(24).view(4, 6)  # Can also use reshape, same result
a_T = a.transpose(0, 1)
print(a)
print(a_T)
print(a.data_ptr())    # Returns address of first element in tensor
print(a_T.data_ptr())  # Same address: transpose copies nothing, it only swaps the strides, so a_T is no longer contiguous

tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23]])
tensor([[ 0,  6, 12, 18],
        [ 1,  7, 13, 19],
        [ 2,  8, 14, 20],
        [ 3,  9, 15, 21],
        [ 4, 10, 16, 22],
        [ 5, 11, 17, 23]])
4570799040
4570799040


In [69]:
# a_T.view(4, 6) would raise an error here, because a_T is not contiguous:

print(a_T.is_contiguous())

# However (at this point a_T is (6,4) because it's transposed):
# reshape will make a new copy of the non-contiguous tensor, but only if it
# really has to reshape, so NOT if we use (6,4) here.
a_T_reshaped = torch.reshape(a_T, (4, 6))
print(a_T_reshaped)
print(a_T_reshaped.data_ptr())

False
tensor([[ 0,  6, 12, 18,  1,  7],
        [13, 19,  2,  8, 14, 20],
        [ 3,  9, 15, 21,  4, 10],
        [16, 22,  5, 11, 17, 23]])
4319112960
