In [1]:
import math

import torch
import torch.nn.functional as f

In [2]:
# 5x5 ramp image holding 0..24 in row-major order, laid out as (N=1, C=1, H=5, W=5), float32.
img = torch.arange(25, dtype=torch.float32).reshape(1, 1, 5, 5)
img

tensor([[[[ 0.,  1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.,  9.],
          [10., 11., 12., 13., 14.],
          [15., 16., 17., 18., 19.],
          [20., 21., 22., 23., 24.]]]])

In [8]:
# Identity sampling grid in normalized [-1, 1] coordinates, sized from `img` rather than
# hard-coded so it always matches the image it will sample.
# grid_sample reads grid[..., 0] as x (width axis) and grid[..., 1] as y (height axis).
H, W = img.shape[-2:]
xs = torch.linspace(-1, 1, W)
ys = torch.linspace(-1, 1, H)
# indexing='ij' -> both outputs are (H, W): grid_y varies along rows, grid_x along columns.
grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')

grid = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0)   # (1, H, W, 2)
grid.shape


torch.Size([1, 5, 5, 2])

In [13]:
# Resample img at the identity grid. With align_corners=True, -1/+1 land on the centers of the
# corner pixels, so the output reproduces the input (see the printed comparison below).
out = f.grid_sample(img, grid, mode='bilinear', padding_mode='zeros', align_corners=True)

In [14]:
# Show the input and the identity-resampled output with the singleton batch/channel axes dropped.
for caption, shown in (("input:\n", img), ("output:\n", out)):
    print(caption, shown.squeeze())

input:
 tensor([[ 0.,  1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14.],
        [15., 16., 17., 18., 19.],
        [20., 21., 22., 23., 24.]])
output:
 tensor([[ 0.,  1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14.],
        [15., 16., 17., 18., 19.],
        [20., 21., 22., 23., 24.]])


In [19]:
# Constant translation of +0.3 in normalized x (y unchanged). Shape (1, 1, 1, 2) broadcasts
# against the (1, H, W, 2) grid. With align_corners=True, 0.3 normalized = 0.3 * (W - 1) / 2
# pixels, i.e. 0.6 px for W = 5, matching the +0.6 seen in the output. The right column samples
# past the last pixel, so zero padding pulls those values down.
disp = torch.tensor([0.3, 0.0]).reshape(1, 1, 1, 2)
print(disp)
grid_trans = torch.add(grid, disp)
out_trans = f.grid_sample(img, grid_trans, align_corners=True)
for caption, shown in (("input:\n", img), ("output:\n", out_trans)):
    print(caption, shown.squeeze())

tensor([[[[0.3000, 0.0000]]]])
input:
 tensor([[ 0.,  1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14.],
        [15., 16., 17., 18., 19.],
        [20., 21., 22., 23., 24.]])
output:
 tensor([[ 0.6000,  1.6000,  2.6000,  3.6000,  1.6000],
        [ 5.6000,  6.6000,  7.6000,  8.6000,  3.6000],
        [10.6000, 11.6000, 12.6000, 13.6000,  5.6000],
        [15.6000, 16.6000, 17.6000, 18.6000,  7.6000],
        [20.6000, 21.6000, 22.6000, 23.6000,  9.6000]])


In [25]:
# Affine warp of the sampling grid: rotate by 30 degrees and scale by 2 in normalized coordinates.
theta = math.radians(30)
scale = 2.0
cos_t, sin_t = math.cos(theta), math.sin(theta)
A = torch.tensor([
    [scale * cos_t, -scale * sin_t],
    [scale * sin_t,  scale * cos_t],
], dtype=torch.float32)   # (2, 2) linear part
print(A.shape)
b = torch.tensor([0.0, 0.0])   # translation in normalized coords, (2,)
print(b.shape)
print(grid.shape)

# grid is (N, H, W, 2). Keep the batch index n in the output: the previous 'ij,nhwj->hwi'
# summed over n, which only worked because N == 1 and would silently add grids for N > 1.
grid_affine = torch.einsum('ij,nhwj->nhwi', A, grid) + b   # (N, H, W, 2)

# grid_sample "pulls" values: out(p) = img(A @ p + b), so the visible warp is the inverse of
# (A, b). Sample points that land outside [-1, 1] read zeros (default padding_mode='zeros').
out_aff = f.grid_sample(img, grid_affine, align_corners=True)

print(out_aff.squeeze())


torch.Size([2, 2])
torch.Size([2])
torch.Size([1, 5, 5, 2])
tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.3397,  4.3397,  2.7705,  0.0000],
        [ 0.0000,  5.2679, 12.0000, 18.7321,  0.0000],
        [ 0.0000,  3.6603, 19.6603,  6.0910,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]])


In [None]:
# 1. Matrix multiplication (A @ B)
A = torch.randn(2, 3)
B = torch.randn(3, 4)
# Contract the shared inner index k: C[i, j] = sum_k A[i, k] * B[k, j].
C1 = torch.einsum('ik,kj->ij', A, B)
print("Q1 shape:", C1.shape)   # expected: (2,4)

# 2. Batched matrix multiplication
A = torch.randn(5, 2, 3)   # (N,2,3)
B = torch.randn(5, 3, 4)   # (N,3,4)
# The batch index b is kept in the output, so each pair A[b], B[b] is multiplied independently.
C2 = torch.einsum('bik,bkj->bij', A, B)
print("Q2 shape:", C2.shape)   # expected: (5,2,4)

# --------------------------
# Level 2: Dot product & weighted sum
# --------------------------

# 3. Pixelwise dot product over last dimension
x = torch.randn(10, 20, 3)
y = torch.randn(10, 20, 3)
# Keep the spatial indices (h, w); contract the channel index c.
dot = torch.einsum('hwc,hwc->hw', x, y)
print("Q3 shape:", dot.shape)   # expected: (10,20)


# 4. Row-wise weighted sum (sum over W axis)
A = torch.randn(6, 8)   # (H,W)
w = torch.randn(6)      # (H,)
# row_sum[r] = w[r] * sum_c A[r, c]: the column index is summed away, each row scaled by its weight.
row_sum = torch.einsum('rc,r->r', A, w)
print("Q4 shape:", row_sum.shape)   # expected: (6,)

# --------------------------
# Level 3: 2D Affine Warp
# --------------------------
# Inputs here are named grid_q5 / grid_q6 instead of `grid` so this section does not overwrite
# the notebook's sampling grid used by the grid_sample cells above (keeps them re-runnable).

# 5. 2D affine warp (no batch)
A = torch.randn(2, 2)
grid_q5 = torch.randn(30, 40, 2)   # (H,W,2)
# out[h, w, i] = sum_j A[i, j] * grid[h, w, j], i.e. A @ p for every point p.
grid2 = torch.einsum('ij,hwj->hwi', A, grid_q5)
print("Q5 shape:", grid2.shape)   # expected: (30,40,2)


# 6. 2D affine warp with batch
A = torch.randn(5, 2, 2)            # (N,2,2)
grid_q6 = torch.randn(5, 30, 40, 2) # (N,H,W,2)
# n appears in both operands and the output, so batch item n is warped by its own A[n].
grid2b = torch.einsum('nij,nhwj->nhwi', A, grid_q6)
print("Q6 shape:", grid2b.shape)  # expected: (5,30,40,2)

# --------------------------
# Level 4: 3D Affine Warp
# --------------------------
# Inputs are named grid_q7 / grid_q8 so this section does not overwrite the notebook's
# sampling `grid` used by the grid_sample cells above.

# 7. 3D affine warp (no batch)
A = torch.randn(3, 3)
grid_q7 = torch.randn(10, 20, 30, 3)   # (D,H,W,3)
# out[d, h, w, i] = sum_j A[i, j] * grid[d, h, w, j]
grid3 = torch.einsum('ij,dhwj->dhwi', A, grid_q7)
print("Q7 shape:", grid3.shape)     # expected: (10,20,30,3)


# 8. 3D affine warp with batch
A = torch.randn(4, 3, 3)                # (N,3,3)
grid_q8 = torch.randn(4, 10, 20, 30, 3) # (N,D,H,W,3)
# Batch index n is shared, so each volume n is warped by its own A[n].
grid3b = torch.einsum('nij,ndhwj->ndhwi', A, grid_q8)
print("Q8 shape:", grid3b.shape)     # expected: (4,10,20,30,3)


# --------------------------
# Level 5: Displacement add
# --------------------------
# Inputs are named grid_q*/disp_q* so this section does not overwrite the notebook's
# sampling `grid` (and translation `disp`) used by the grid_sample cells above.

# 9. 2D displacement addition using einsum (no '+')
I = torch.eye(2)
grid_q9 = torch.randn(50, 60, 2)
disp_q9 = torch.randn(50, 60, 2)
# grid + disp == [I | I] @ [grid ; disp]: stacking the two point sets along the coordinate
# axis lets a single contraction perform the addition, so no '+' is needed.
grid_disp = torch.einsum(
    'ij,hwj->hwi',
    torch.cat([I, I], dim=1),                # (2, 4)
    torch.cat([grid_q9, disp_q9], dim=-1),   # (H, W, 4)
)
print("Q9 shape:", grid_disp.shape)  # expected: (50,60,2)


# 10. 3D displacement addition with batch using einsum
I = torch.eye(3)
grid_q10 = torch.randn(2, 10, 20, 30, 3)
disp_q10 = torch.randn(2, 10, 20, 30, 3)
# Same trick in 3D with a batch axis: [I | I] is (3, 6), the concatenated points are (..., 6).
grid_disp3 = torch.einsum(
    'ij,ndhwj->ndhwi',
    torch.cat([I, I], dim=1),
    torch.cat([grid_q10, disp_q10], dim=-1),
)
print("Q10 shape:", grid_disp3.shape)   # expected: (2,10,20,30,3)

Q1 shape: torch.Size([2, 4])
Q2 shape: torch.Size([5, 2, 4])
Q3 shape: torch.Size([10, 20])
Q4 shape: torch.Size([6])
Q5 shape: torch.Size([30, 40, 2])
Q6 shape: torch.Size([5, 30, 40, 2])
Q7 shape: torch.Size([10, 20, 30, 3])
Q8 shape: torch.Size([4, 10, 20, 30, 3])
Q9 shape: torch.Size([50, 60, 2])
Q10 shape: torch.Size([2, 10, 20, 30, 3])


Q3 shape: torch.Size([10, 20])
Q4 shape: torch.Size([6])


Q5 shape: torch.Size([30, 40, 2])
Q6 shape: torch.Size([5, 30, 40, 2])
