In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
# Code to set up the assignment
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/

!pip3 install pybind11

In [None]:
!make

In [None]:
!python3 -m pytest -l -v -k "pad_forward"

In [2]:
import sys
sys.path.append('./python')

In [3]:
import needle as ndl
import needle.ops as ops
import numpy as np
import torch

### Compare grads

In [95]:
ndl_tensor = ndl.Tensor(a, requires_grad=True)
b = ndl_tensor.maximum(0)
out = b.sum()
out.backward()
ndl_tensor.grad

out_grad: [1. 1. 1.], out_grad_shape: (3,) out_shape: (1, 3),  in_shape: (4, 3)


needle.Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.])

In [88]:
torch_tensor = torch.tensor(a, requires_grad=True)
torch_b = torch.softmax(torch_tensor, 0)
out = torch_b.sum()
out.backward()
torch_tensor.grad

tensor([[0.0000e+00, 2.0976e-17, 0.0000e+00],
        [0.0000e+00, 3.7275e-17, 0.0000e+00],
        [0.0000e+00, 1.8664e-17, 0.0000e+00],
        [0.0000e+00, 3.4107e-17, 0.0000e+00]], dtype=torch.float64)

In [93]:
torch_b

tensor([[0.1794, 0.1889, 0.2503],
        [0.2095, 0.3357, 0.2113],
        [0.2759, 0.1681, 0.2383],
        [0.3352, 0.3072, 0.3001]], dtype=torch.float64,
       grad_fn=<SoftmaxBackward0>)

In [117]:
T, d = 100, 128
attn = torch.nn.MultiheadAttention(d, 1, bias=False, batch_first=True)
N = 10
M = torch.triu(-float("inf")*torch.ones(T,T),1)
X = torch.randn(N, T, d)
Y_, A_ = attn(X,X,X, attn_mask=M)

In [118]:
W_KQV = attn.in_proj_weight.detach().numpy().T

In [119]:
X.shape

torch.Size([10, 100, 128])

In [120]:
W_KQV.shape

(128, 384)

In [122]:
W_out = attn.out_proj.weight.detach().numpy().T
W_out.shape

(128, 128)

### Compare matmul performance

In [112]:
a = np.random.random((128, 128))
b = np.random.random((128, 512))

numpy_array_a = ndl.Tensor(a, device=ndl.cpu_numpy())
numpy_array_b = ndl.Tensor(b, device=ndl.cpu_numpy())

cpu_array_a = ndl.Tensor(a, device=ndl.cpu())
cpu_array_b = ndl.Tensor(b, device=ndl.cpu())

In [113]:
%%timeit
numpy_array_a @ numpy_array_b

922 µs ± 236 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [116]:
%%timeit
cpu_array_a @ cpu_array_b

4.73 ms ± 34.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [115]:
%%timeit
a @ b

332 µs ± 62.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## Batch matmul

In [117]:
device = ndl.cpu()

In [170]:
b1_shape = (10, 4, 100, 32)
b2_shape = (10, 4, 32, 100)

In [180]:
a = np.random.random(b1_shape)
b = np.random.random(b2_shape)
res_ = (a @ b)
res_.shape

(10, 4, 100, 100)

In [179]:
(a.reshape(-1, 32) @ b.swapaxes(-1, -2).reshape(-1, 32).swapaxes(-1, -2))

ValueError: cannot reshape array of size 16000000 into shape (10,4,100,100)

In [178]:
10 * 4 * 100 * 100

400000

In [159]:
res = a.reshape(-1, 16) @ b
res = res.reshape((10, 128, 64))
res.shape

(10, 128, 64)

In [184]:
torch_bmm = torch.bmm(torch.tensor(a), torch.tensor(b).reshape(1, 16, 64).broadcast_to((10, 16, 64)))
torch_bmm.shape

RuntimeError: shape '[1, 16, 64]' is invalid for input of size 128000

In [161]:
np.linalg.norm(torch_bmm.numpy() - res)

0.0

### Batch matmul numpy backend

In [255]:
ndl_a = ndl.Tensor(a, device=ndl.cpu_numpy())
ndl_b = ndl.Tensor(b, device=ndl.cpu_numpy())

ndl_a.cached_data = ndl_a.cached_data.numpy()
ndl_b.cached_data = ndl_b.cached_data.numpy()

ndl_a.shape, ndl_a.shape

((5, 4, 2, 3), (5, 4, 2, 3))

In [256]:
res = ndl_a @ ndl_b

In [257]:
res.cached_data = ndl.NDArray(res.cached_data)
# ndl_a.cached_data = ndl.NDArray(ndl_a.cached_data)
# ndl_b.cached_data = ndl.NDArray(ndl_b.cached_data)

In [258]:
out_ = res.sum()
out_.backward()

AssertionError: 

In [186]:
b1_shape = (5, 4, 2, 3)
b2_shape = (5, 4, 3, 2)

a = np.random.random(b1_shape)
b = np.random.random(b2_shape)

In [210]:
torch_a = torch.tensor(a, requires_grad=True)
torch_b = torch.tensor(b, requires_grad=True)

# torch_bmm = torch.bmm(torch_a, torch_b)
torch_bmm = torch_a @ torch_b
out = torch_bmm.sum()
out.backward()

In [218]:
torch_b.grad

tensor([[[[0.8718, 0.8718],
          [0.9529, 0.9529],
          [1.1576, 1.1576]],

         [[1.6606, 1.6606],
          [0.4437, 0.4437],
          [1.0558, 1.0558]],

         [[1.7624, 1.7624],
          [0.6240, 0.6240],
          [0.4162, 0.4162]],

         [[1.1117, 1.1117],
          [0.9589, 0.9589],
          [0.7407, 0.7407]]],


        [[[1.2667, 1.2667],
          [1.0882, 1.0882],
          [0.8433, 0.8433]],

         [[0.7540, 0.7540],
          [0.3090, 0.3090],
          [1.6774, 1.6774]],

         [[1.3287, 1.3287],
          [0.9941, 0.9941],
          [1.1644, 1.1644]],

         [[1.4826, 1.4826],
          [1.7378, 1.7378],
          [1.2141, 1.2141]]],


        [[[0.8733, 0.8733],
          [1.4612, 1.4612],
          [1.6045, 1.6045]],

         [[1.3020, 1.3020],
          [1.6444, 1.6444],
          [1.4164, 1.4164]],

         [[0.8837, 0.8837],
          [0.5235, 0.5235],
          [1.4137, 1.4137]],

         [[0.8736, 0.8736],
          [1.0256, 1.0

In [259]:
b1_shape = (2, 3, 4, 5)
b2_shape = (2, 3, 5, 6)

a = np.random.random(b1_shape)
b = np.random.random(b2_shape)

In [261]:
dot = a @ b
dot.shape

(2, 3, 4, 6)

In [281]:
placeholder = np.zeros((2 * 3 * 4 * 6,))

In [269]:
2 * 3 * 4 * 6

144

In [270]:
24 * 36

864

In [273]:
144 * 2 * 3

864

In [310]:
dot_ = (a.reshape(-1, 5) @ b.swapaxes(-1, -2).reshape(-1, 5).swapaxes(-1, -2))
dot_.shape

(24, 36)

In [330]:
dot_ = a.sum(axis=0).sum(axis=0) @ b.sum(axis=0).sum(axis=0)
dot_ = np.broadcast_to(dot_.reshape(1, 1, 4, 6), (2, 3, 4, 6))
np.linalg.norm(dot - dot_)

567.6207842665684

In [276]:
24 // 2, 36 // 3

(12, 12)

In [280]:
dot_.sum(axis=0).shape

(36,)

In [329]:
dot_ = (a.reshape(-1, 5) @ b.swapaxes(-1, -2).reshape(-1, 5).swapaxes(-1, -2))
dot_.shape

q = 0
sum_ = 0
cnt = 0
for i in range(0, 144, 6):
    placeholder[cnt] = dot_.reshape(-1)[i:i+6].sum() / 6
        
np.linalg.norm(dot - placeholder.reshape(2, 3, 4, 6))

4.647398387241218

### Batch matmul via split and stack

In [76]:
b1_shape = (2, 3, 4, 5)
b2_shape = (2, 3, 5, 6)

a = np.random.random(b1_shape).astype(np.float32)
b = np.random.random(b2_shape).astype(np.float32)

ndl_a = ndl.Tensor(a, device=ndl.cpu(), requires_grad=True)
ndl_b = ndl.Tensor(b, device=ndl.cpu(), requires_grad=True)

torch_a = torch.tensor(a, requires_grad=True, dtype=torch.float32)
torch_b = torch.tensor(b, requires_grad=True, dtype=torch.float32)

In [77]:
dot = a @ b
dot.shape

(2, 3, 4, 6)

In [78]:
splitted_a = ndl.ops.split(ndl_a.reshape((6, 4, 5)), axis=0)
splitted_b = ndl.ops.split(ndl_b.reshape((6, 5, 6)), axis=0)

tensors_to_stack = []

for i in range(6):
    tensors_to_stack.append(splitted_a[i] @ splitted_b[i])
    
ndl_bmm = ndl.ops.stack(tensors_to_stack, axis=0).reshape((2, 3, 4, 6))
ndl_bmm.shape

(2, 3, 4, 6)

In [79]:
loss_ndl = ndl_bmm.sum()
loss_ndl.backward()

In [80]:
np.linalg.norm(dot - ndl_bmm.detach().numpy())

7.822755e-07

In [81]:
torch_bmm = torch_a @ torch_b
loss = torch_bmm.sum()
loss.backward()

In [82]:
np.linalg.norm(dot - torch_bmm.detach().numpy())

7.822755e-07

In [83]:
np.linalg.norm(torch_a.grad.numpy() - ndl_a.grad.numpy()), np.linalg.norm(torch_b.grad.numpy() - ndl_b.grad.numpy())

(0.0, 0.0)

In [98]:
a.shape[:-2]

(2,)

In [95]:
from functools import reduce

reduce(lambda a, b: a * b, (4,))

4

In [87]:
b1_shape = (2, 3, 5)
b2_shape = (5, 4)

a = np.random.random(b1_shape).astype(np.float32)
b = np.random.random(b2_shape).astype(np.float32)

ndl_a = ndl.Tensor(a, device=ndl.cpu(), requires_grad=True)
ndl_b = ndl.Tensor(b, device=ndl.cpu(), requires_grad=True)

torch_a = torch.tensor(a, requires_grad=True, dtype=torch.float32)
torch_b = torch.tensor(b, requires_grad=True, dtype=torch.float32)

In [88]:
dot = a @ b
dot.shape

(2, 3, 4)

In [89]:
splitted_a = ndl.ops.split(ndl_a.reshape((2, 3, 5)), axis=0)

tensors_to_stack = []

for i in range(2):
    tensors_to_stack.append(splitted_a[i] @ ndl_b)
    
ndl_bmm = ndl.ops.stack(tensors_to_stack, axis=0).reshape((2, 3, 4))
print("shape", ndl_bmm.shape)

loss_ndl = ndl_bmm.sum()
loss_ndl.backward()

np.linalg.norm(dot - ndl_bmm.detach().numpy())

shape (2, 3, 4)


3.7697288e-07

In [90]:
torch_bmm = torch_a @ torch_b
loss = torch_bmm.sum()
loss.backward()

np.linalg.norm(dot - torch_bmm.detach().numpy())

0.0

In [91]:
np.linalg.norm(torch_a.grad.numpy() - ndl_a.grad.numpy()), np.linalg.norm(torch_b.grad.numpy() - ndl_b.grad.numpy())

(0.0, 4.7683716e-07)

In [127]:
b1_shape = (2, 3, 5)
b2_shape = (5, 4)

a = np.random.random(b1_shape).astype(np.float32)
b = np.random.random(b2_shape).astype(np.float32)

ndl_a = ndl.Tensor(a, device=ndl.cpu(), requires_grad=True)
ndl_b = ndl.Tensor(b, device=ndl.cpu(), requires_grad=True)

torch_a = torch.tensor(a, requires_grad=True, dtype=torch.float32)
torch_b = torch.tensor(b, requires_grad=True, dtype=torch.float32)

In [129]:
ndl_bmm = ndl_a @ ndl_b
loss_ndl = ndl_bmm.sum()
loss_ndl.backward()

torch_bmm = torch_a @ torch_b
loss = torch_bmm.sum()
loss.backward()

np.linalg.norm(torch_bmm.detach().numpy() - ndl_bmm.detach().numpy())

3.674277e-07

In [130]:
np.linalg.norm(torch_a.grad.numpy() - ndl_a.grad.numpy()), np.linalg.norm(torch_b.grad.numpy() - ndl_b.grad.numpy())

(2.9200194e-07, 6.7434956e-07)

In [132]:
b1_shape = (2, 3, 4, 5)
b2_shape = (2, 3, 5, 6)

a = np.random.random(b1_shape).astype(np.float32)
b = np.random.random(b2_shape).astype(np.float32)

ndl_a = ndl.Tensor(a, device=ndl.cpu(), requires_grad=True)
ndl_b = ndl.Tensor(b, device=ndl.cpu(), requires_grad=True)

torch_a = torch.tensor(a, requires_grad=True, dtype=torch.float32)
torch_b = torch.tensor(b, requires_grad=True, dtype=torch.float32)

ndl_bmm = ndl_a @ ndl_b
loss_ndl = ndl_bmm.sum()
loss_ndl.backward()

torch_bmm = torch_a @ torch_b
loss = torch_bmm.sum()
loss.backward()

np.linalg.norm(torch_bmm.detach().numpy() - ndl_bmm.detach().numpy())

0.0

In [133]:
np.linalg.norm(torch_a.grad.numpy() - ndl_a.grad.numpy()), np.linalg.norm(torch_b.grad.numpy() - ndl_b.grad.numpy())

(0.0, 0.0)

## Compare mh attention activations

In [162]:
device = ndl.cpu()

In [163]:
T, d = 100, 128
heads = 4
N = 10
M = torch.triu(-float("inf") * torch.ones(T, T), 1).numpy()
X = np.random.random((N, T, d))

In [164]:
torch_x = torch.tensor(X, dtype=torch.float32)
torch_mask = torch.tensor(M, dtype=torch.float32)

attn = torch.nn.MultiheadAttention(d, heads, bias=False, batch_first=True)
Y_, A_ = attn(torch_x, torch_x, torch_x, attn_mask=torch_mask)

In [165]:
k, q, v = np.split(attn.in_proj_weight.detach().numpy().T, 3, axis=-1)
w_kqv = ndl.nn.Parameter(attn.in_proj_weight.detach().numpy().T, device=device)
w_k = ndl.nn.Parameter(k, device=device)
w_q = ndl.nn.Parameter(q, device=device)
w_v = ndl.nn.Parameter(v, device=device)
w_out = ndl.nn.Parameter(attn.out_proj.weight.detach().numpy().T, device=device)

In [168]:
ndl_multihead_attention = ndl.nn.MultiHeadedAttention(heads, d, device=device)
ndl_multihead_attention.w_k = w_k
ndl_multihead_attention.w_q = w_q
ndl_multihead_attention.w_v = w_v
ndl_multihead_attention.w_out = w_out

ndl_x = ndl.Tensor(X, device=device)
ndl_mask = ndl.Tensor(M, device=device)

Y, A = ndl_multihead_attention(ndl_x, ndl_mask)

key: (10, 4, 100, 32), query.T: (10, 4, 32, 100)


AssertionError: 

In [30]:
a = ndl.Tensor(np.arange(9), device=ndl.cpu()).reshape((3, 3))
a

needle.Tensor([[0. 1. 2.]
 [3. 4. 5.]
 [6. 7. 8.]])

In [43]:
splitted = ndl.ops.split(w_kqv, axis=-1)

A device: cpu()
A : <class 'needle.backend_ndarray.ndarray.NDArray'>


In [47]:
splitted[0].shape, len(splitted)

((128,), 384)

In [63]:
stacked_first = ndl.ops.stack([tensor for i, tensor in enumerate(splitted) if i >= 128 and i < 128 * 2], axis=-1)
stacked_first.shape

(128, 128)

In [59]:
splitted

needle.TensorTuple(needle.Tensor([ 0.06078031 -0.0874882   0.10658623  0.01015931  0.0166598   0.07308573
  0.06505946 -0.04554389 -0.0248808   0.07836054 -0.0536439   0.06354582
  0.02857671  0.07657666 -0.060147    0.04223699  0.06297494  0.05656952
 -0.02007353  0.08819615  0.0846461  -0.05070292 -0.07401869 -0.01569892
  0.07307563 -0.10099024  0.05383591  0.09265185 -0.09417942 -0.10062625
  0.05386952  0.03341784 -0.0738204  -0.09635603  0.00274712 -0.06404778
 -0.07953996 -0.01619335  0.00156156  0.06637786 -0.10782114 -0.01826818
  0.07448841  0.10718491  0.08484554 -0.02647539 -0.10449789  0.01096009
  0.0521246  -0.05338632  0.06731283 -0.06714717 -0.00628517  0.02763208
  0.00543028 -0.07608956 -0.00989367 -0.09907368  0.02831867  0.04427879
  0.03527419 -0.08539858  0.01581849 -0.01648275  0.0266635   0.00604277
  0.01077894  0.06465601 -0.09762365  0.02458257  0.04755712  0.08795565
 -0.06416573 -0.07385759  0.0577533  -0.01916929  0.04783488 -0.03907819
 -0.02901503  0.06