Let's start by importing numpy!

In [2]:
import numpy as np

Importing numpy as np is a common convention but I won't judge you if you picked your favorite short string!

## Creating numpy arrays

Let's start with some basics. np.array() creates numpy arrays from python lists.

In [3]:
x = np.array([[1, 2], [2, 3], [3, 4]])
print(f'{x = }')

x = array([[1, 2],
       [2, 3],
       [3, 4]])


Here are a few more common ways of creating numpy arrays:

In [4]:
x = np.arange(2, 10, 2)
print(f'{x = }')

x = np.zeros((3, 2))
print(f'{x = }')

x = np.ones((3, 2))
print(f'{x = }')

x = np.tril(np.ones((3,3)))  # tril => lower triangle, useful for creating causal masks!
print(f'{x = }')

x = array([2, 4, 6, 8])
x = array([[0., 0.],
       [0., 0.],
       [0., 0.]])
x = array([[1., 1.],
       [1., 1.],
       [1., 1.]])
x = array([[1., 0., 0.],
       [1., 1., 0.],
       [1., 1., 1.]])


The np.random package can be used to create arrays with random numbers:

In [5]:
# random arrays of shape = (3, 2)
x = np.random.rand(3, 2)  # random samples from a uniform distribution over [0, 1)
print(f'{x = }')

x = np.random.randn(3, 2)  # random samples from the "standard normal" distribution
print(f'{x = }')

x = np.random.randint(low=3, high=10, size=(3,2))  # random integers sampled from the range [low, high)
print(f'{x = }')

x = array([[0.71247287, 0.42528162],
       [0.08128043, 0.38892688],
       [0.63015336, 0.47903512]])
x = array([[ 0.41907239,  0.23361579],
       [-0.41244341, -0.92222489],
       [-0.65285907,  0.1902672 ]])
x = array([[4, 7],
       [4, 3],
       [4, 5]])


## dtype and shape
Commonly useful properties of a numpy array are its shape and dtype.

In [6]:
print(f'{x.shape = }')
print(f'{x.dtype = }')

x.shape = (3, 2)
x.dtype = dtype('int64')


The shape and dtype of a numpy array can be changed to other compatible values, for instance:

In [7]:
x = np.array([[1, 2], [2, 3], [3, 4]])
print(f'{x.shape = }')
print(f'{x.dtype = }')

# reshape
x_23 = np.reshape(x, [2, 3])
x_6 = x.reshape((-1, 1))  # numpy infers the value of -1, provided it's inferable.
print(f'{x_23 = }, {x_23.shape = }')
print(f'{x_6 = }, {x_6.shape = }')

# change dtypes
x_bool = np.array([True, False, False, True])
x_int = x_bool.astype(np.int32)
print(f'{x_int = }, {x_int.dtype = }')


x.shape = (3, 2)
x.dtype = dtype('int64')
x_23 = array([[1, 2, 2],
       [3, 3, 4]]), x_23.shape = (2, 3)
x_6 = array([[1],
       [2],
       [2],
       [3],
       [3],
       [4]]), x_6.shape = (6, 1)
x_int = array([1, 0, 0, 1], dtype=int32), x_int.dtype = dtype('int32')


## Indexing

Let's look at a few examples:

In [8]:
x = np.array([[1, 2], [2, 3], [3, 4]])
print(f'{x = }')

# access an element, python-style
print(f'{x[0][1] = }')

# access an element, numpy-style
print(f'{x[0, 1] = }')

# access a row
print(f'{x[0] = }')

# access a column
print(f'{x[:, 0] = }')  # : means all the values from that axis

x = array([[1, 2],
       [2, 3],
       [3, 4]])
x[0][1] = np.int64(2)
x[0, 1] = np.int64(2)
x[0] = array([1, 2])
x[:, 0] = array([1, 2, 3])


Even though python-style access and numpy-style access look identical, they can be deceptively different. Let's look at an example:

In [9]:
print(f'{x[:, 0] = }')
print(f'{x[:][0] = }')

x[:, 0] = array([1, 2, 3])
x[:][0] = array([1, 2])


`x[:, 0]` returned the first column, but `x[:][0]` returned the first row. What happened here?

`x[:][0]` creates a chain of two accesses: it first evaluates `x[:]`, which returns all of x, then evaluates res[0], which takes element 0 of the original array, hence returning the first row of the array. This behavior is consistent in python (`[[1, 2], [2, 3]][:][0] == [1, 2]`)

`x[:, 0]` indexes both axes at once: all rows, column 0, and returns the first column.

Numpy arrays are mutable:

In [10]:
x = np.array([[1, 2], [2, 3], [3, 4]])
print(f'{x = }')

print("# edit an element")
x[0][1] = -1
print(f'{x = }')

print("# edit a row")
x[0] = [-1, -2]
print(f'{x = }')

print("# edit a column")
x[:, 0] = [-1, -2, -3]
print(f'{x = }')


x = array([[1, 2],
       [2, 3],
       [3, 4]])
# edit an element
x = array([[ 1, -1],
       [ 2,  3],
       [ 3,  4]])
# edit a row
x = array([[-1, -2],
       [ 2,  3],
       [ 3,  4]])
# edit a column
x = array([[-1, -2],
       [-2,  3],
       [-3,  4]])


Let's look at a few more examples of using `:` and `::` to access numpy arrays

In [11]:
x = np.array([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]])
print(f'{x = }')

# access first 2 rows
print(f'{x[:2] = }')

# start at row 1, stop before row 4 (the end index is exclusive)
print(f'{x[1:4] = }')

# start at row 1, stop before row 4, access every 2nd row
print(f'{x[1:4:2] = }')

# start at row 0, stop before row 4, access every 2nd row
print(f'{x[:4:2] = }')

# start at row 1, go to the end, access every 2nd row
print(f'{x[1::2] = }')

# start at the beginning, go to the end, access every 2nd row
print(f'{x[::2] = }')

# a step of -1 walks over the rows in reverse order (last row first)
print(f'{x[::-1] = }')

# rows in reverse order, and only column 0 of each row
print(f'{x[::-1, 0] = }')


x = array([[1, 2],
       [2, 3],
       [3, 4],
       [4, 5],
       [5, 6],
       [6, 7]])
x[:2] = array([[1, 2],
       [2, 3]])
x[1:4] = array([[2, 3],
       [3, 4],
       [4, 5]])
x[1:4:2] = array([[2, 3],
       [4, 5]])
x[:4:2] = array([[1, 2],
       [3, 4]])
x[1::2] = array([[2, 3],
       [4, 5],
       [6, 7]])
x[::2] = array([[1, 2],
       [3, 4],
       [5, 6]])
x[::-1] = array([[6, 7],
       [5, 6],
       [4, 5],
       [3, 4],
       [2, 3],
       [1, 2]])
x[::-1, 0] = array([6, 5, 4, 3, 2, 1])


Unlike python, numpy arrays can also be indexed using integer lists:

In [12]:
x = np.array([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]])
print(f'{x = }')

print(f'{x[[1, 3, 1]] = }')
print(f'{x[np.array([1, 3, 5]), np.array([0, 1, 0])] = }')  # x[1, 0], x[3, 1], x[5, 0]

x = array([[1, 2],
       [2, 3],
       [3, 4],
       [4, 5],
       [5, 6],
       [6, 7]])
x[[1, 3, 1]] = array([[2, 3],
       [4, 5],
       [2, 3]])
x[np.array([1, 3, 5]), np.array([0, 1, 0])] = array([2, 5, 6])


This can be used to reorder arrays; the example below reorders array y in the decreasing order of array x. This is useful, for instance, when sampling from a language model, where the vocab ids have to be ordered by the probabilities generated by the model.



In [13]:
x = np.random.permutation(5)
y = np.random.randn(5)
sorted_indices = np.argsort(x)
reverse_sorted_indices = sorted_indices[::-1]

print(f'{x = }')
print(f'{y = }')
print(f'{reverse_sorted_indices = }')
print(f'{y[reverse_sorted_indices] = }')

x = array([1, 3, 0, 4, 2])
y = array([-0.94947826, -0.42810292,  0.49767165,  0.93196266,  0.20026202])
reverse_sorted_indices = array([3, 1, 4, 0, 2])
y[reverse_sorted_indices] = array([ 0.93196266, -0.42810292,  0.20026202, -0.94947826,  0.49767165])


We can also filter numpy arrays using bool lists or arrays for indexing:

In [14]:
x = np.array([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]])
mask = np.array([True, False, False, True, True, False])
print(f'{x = }')
print(f'{mask = }')

print(f'{x[mask] = }')

x = array([[1, 2],
       [2, 3],
       [3, 4],
       [4, 5],
       [5, 6],
       [6, 7]])
mask = array([ True, False, False,  True,  True, False])
x[mask] = array([[1, 2],
       [4, 5],
       [5, 6]])


This can be used to filter numpy arrays based on the values of (the same or other) numpy arrays, e.g. selecting model outputs belonging to a particular class of examples.

In [15]:
x = np.arange(-5, 5)
y = np.random.randn(x.shape[0])
print(f'{x = }')
print(f'{y = }')
print(f'{y[x > 2] = }')

x = array([-5, -4, -3, -2, -1,  0,  1,  2,  3,  4])
y = array([ 0.81478051,  0.56263881, -0.26024234,  1.98922655, -0.67673687,
        0.31322671, -0.20885959, -0.33416936,  0.75858371,  0.12762243])
y[x > 2] = array([0.75858371, 0.12762243])


A quick reminder that numpy arrays are mutable, so all access patterns above can be used to set the values of the array at those indices (in contrast, jax arrays are immutable).

## Broadcasting

Let's look at a few simple operations.

In [16]:
x = np.array([[1, 2], [3, 4], [5, 6]])
y = np.array([[2, 3], [4, 5], [6, 7]])

print(f'{x = }')
print(f'{y = }')
print(f'{x + y = }')
print(f'{x - y = }')
print(f'{x * y = }')
print(f'{x / y = }')

x = array([[1, 2],
       [3, 4],
       [5, 6]])
y = array([[2, 3],
       [4, 5],
       [6, 7]])
x + y = array([[ 3,  5],
       [ 7,  9],
       [11, 13]])
x - y = array([[-1, -1],
       [-1, -1],
       [-1, -1]])
x * y = array([[ 2,  6],
       [12, 20],
       [30, 42]])
x / y = array([[0.5       , 0.66666667],
       [0.75      , 0.8       ],
       [0.83333333, 0.85714286]])


All the above operations are element-wise (the operation is applied on each x[i, j] and y[i, j] pair) and easy to understand. What if shapes aren't the same?

Numpy tries to broadcast the shape of the smaller array across the shape of the larger array to make them compatible. Here's the simplest example:

In [17]:
x = np.array([[1,2,3], [2,3,4], [4,5,6]])
print(f'{x = }')
print(f'{1 + x = }')

x = array([[1, 2, 3],
       [2, 3, 4],
       [4, 5, 6]])
1 + x = array([[2, 3, 4],
       [3, 4, 5],
       [5, 6, 7]])


Here, the scalar 1 gets broadcasted to the shape of x, i.e. (3, 3) so it can be added to x. Let's look at a less simple example:

In [18]:
x = np.array([1, 2, 3])
y = np.array([[1, 2, 3], [2, 3, 4], [4, 5, 6]])
print(f'{x = }, {x.shape = }')
print(f'{y = }, {y.shape = }')
print(f'{x + y = }')

x = array([1, 2, 3]), x.shape = (3,)
y = array([[1, 2, 3],
       [2, 3, 4],
       [4, 5, 6]]), y.shape = (3, 3)
x + y = array([[2, 4, 6],
       [3, 5, 7],
       [5, 7, 9]])


x was replicated across axis = 0, so the above operation was equivalent to `[[1, 2, 3], [2, 3, 4], [4, 5, 6]] + [[1, 2, 3], [1, 2, 3], [1, 2, 3]]`. What if we wanted to replicate x across axis = 1 instead, and perform the equivalent of `[[1, 2, 3], [2, 3, 4], [4, 5, 6]] + [[1, 1, 1], [2, 2, 2], [3, 3, 3]]`?

We can do this without making needless copies of the smaller array, by adding new dimensions to the array. Here are 3 equivalent ways to do this:

In [19]:
x = np.array([1, 2, 3])
print(f'{x = }, {x.shape = }')

print()

print(f'{x[None, :] = }, {x[None, :].shape = }')
print(f'{x[np.newaxis, :] = }, {x[np.newaxis, :].shape = }')
print(f'{np.expand_dims(x, axis=0) = }, {np.expand_dims(x, axis=0).shape = }')

print()

print(f'{x[:, None] = }, {x[:, None].shape = }')
print(f'{x[:, np.newaxis] = }, {x[:, np.newaxis].shape = }')
print(f'{np.expand_dims(x, axis=1) = }, {np.expand_dims(x, axis=1).shape = }')


x = array([1, 2, 3]), x.shape = (3,)

x[None, :] = array([[1, 2, 3]]), x[None, :].shape = (1, 3)
x[np.newaxis, :] = array([[1, 2, 3]]), x[np.newaxis, :].shape = (1, 3)
np.expand_dims(x, axis=0) = array([[1, 2, 3]]), np.expand_dims(x, axis=0).shape = (1, 3)

x[:, None] = array([[1],
       [2],
       [3]]), x[:, None].shape = (3, 1)
x[:, np.newaxis] = array([[1],
       [2],
       [3]]), x[:, np.newaxis].shape = (3, 1)
np.expand_dims(x, axis=1) = array([[1],
       [2],
       [3]]), np.expand_dims(x, axis=1).shape = (3, 1)


Let's use this to control how x is broadcasted across y:

In [20]:
x = np.array([1, 2, 3])
y = np.array([[1, 2, 3], [2, 3, 4], [4, 5, 6]])
print(f'{x = }, {x.shape = }')
print(f'{y = }, {y.shape = }')
print(f'{x[None, :] + y = }')
print(f'{x[:, None] + y = }')

x = array([1, 2, 3]), x.shape = (3,)
y = array([[1, 2, 3],
       [2, 3, 4],
       [4, 5, 6]]), y.shape = (3, 3)
x[None, :] + y = array([[2, 4, 6],
       [3, 5, 7],
       [5, 7, 9]])
x[:, None] + y = array([[2, 3, 4],
       [4, 5, 6],
       [7, 8, 9]])


Broadcasting is pretty commonly used in numpy code, e.g. broadcasting the bias across a batch of features in a linear layer, but it can be tricky. Always test your broadcasting code on small examples to make sure it's working correctly!

## Let's use numpy!
Let's use numpy to implement some components commonly used in neural networks! This will introduce us to a few common numpy functions and help us get familiar with the indexing and broadcasting concepts we just read about.

Please feel free to look up the expressions on the internet and try implementing these yourself and use the tests to verify your approach before looking at the implementations provided here!

### ReLU

ReLU (rectified linear unit) is the most basic activation function used in neural networks and is defined as relu(x) = x if x > 0 else 0. Let's implement it in numpy!

In [21]:
def relu(x):
  """Rectified linear unit: keeps the positive entries of x and replaces the rest with 0."""
  is_positive = x > 0
  return np.where(is_positive, x, 0)

Let's test it:

In [22]:
num_tests = 5
for _ in range(num_tests):
  x = np.random.permutation(np.arange(-5, 5))
  print(f'{x = }')
  print(f'{relu(x) = }')

x = array([-3, -4,  0, -5,  3, -2,  4,  1, -1,  2])
relu(x) = array([0, 0, 0, 0, 3, 0, 4, 1, 0, 2])
x = array([-5,  4, -3,  2,  0,  3,  1, -4, -1, -2])
relu(x) = array([0, 4, 0, 2, 0, 3, 1, 0, 0, 0])
x = array([ 3, -4, -5, -3,  1,  2,  4, -2, -1,  0])
relu(x) = array([3, 0, 0, 0, 1, 2, 4, 0, 0, 0])
x = array([ 4,  2, -2,  1, -1, -4,  0,  3, -3, -5])
relu(x) = array([4, 2, 0, 1, 0, 0, 0, 3, 0, 0])
x = array([ 0, -5,  1,  2, -1, -4,  4, -3,  3, -2])
relu(x) = array([0, 0, 1, 2, 0, 0, 4, 0, 3, 0])


We can also implement relu(x) as `np.maximum(x, 0)`, which returns the element-wise maximum of the two arrays x and 0 (broadcasted). This is different from `np.max(a)`, which returns the max of array a (optionally, along an axis). Try modifying the implementation and run the tests!

BTW, `relu(x) = x[x > 0]` is incorrect. Can you explain why? Modify the function and run the tests to see what happens.

### Sigmoid
Sigmoid is used to convert the model output (i.e. logit) to a probability, commonly used for binary classification problems. Let's implement it!

In [23]:
def sigmoid(x):
  """Logistic function: maps x (scalar or array) to 1 / (1 + e^-x), element-wise."""
  denom = 1 + np.exp(-x)
  return 1 / denom

It's harder to inspect correctness, but still, let's run a few tests!

In [24]:
num_tests = 5
for _ in range(num_tests):
  x = np.random.randn(5)
  print(f'{x = }')
  print(f'{sigmoid(x) = }')

x = array([-0.56682805, -2.24890747,  0.27267574, -1.07624555,  1.39636642])
sigmoid(x) = array([0.36196906, 0.09544375, 0.56774968, 0.25421717, 0.80160666])
x = array([ 0.41083961, -0.28242042,  0.70086132, -1.52032512, -0.19486307])
sigmoid(x) = array([0.60128918, 0.42986048, 0.66837871, 0.17941365, 0.4514378 ])
x = array([ 1.07106844, -0.66546428, -1.04139873, -0.34232793, -0.91302592])
sigmoid(x) = array([0.74480005, 0.33951321, 0.2608802 , 0.41524411, 0.28638104])
x = array([ 1.03158171, -0.02043043, -0.12344788, -1.64196494,  0.99479116])
sigmoid(x) = array([0.73722243, 0.49489257, 0.46917716, 0.16219787, 0.73003323])
x = array([-0.42999825,  0.46108848,  0.21802558,  0.79601311,  0.00290826])
sigmoid(x) = array([0.39412675, 0.61327236, 0.5542915 , 0.689121  , 0.50072706])


One visual check for correctness is that the sigmoid is always positive (can you explain why?). We can also check for correctness using a few known values, e.g. x = 0, x = inf, and x = -inf. Can you explain the following outputs?

In [25]:
  print(f'{sigmoid(0) = }')  # 0.5
  print(f'{sigmoid(-np.inf) = }')  # 0.0
  print(f'{sigmoid(np.inf) = }')  # 1.0

sigmoid(0) = np.float64(0.5)
sigmoid(-np.inf) = np.float64(0.0)
sigmoid(np.inf) = np.float64(1.0)


### Softmax

In [26]:
def softmax(x, axis=None):
  """Softmax of x over the given axis (axis=None normalizes over the whole array).

  The sum is taken with keepdims=True so that it broadcasts back over x
  along the reduced axis.
  """
  numerators = np.exp(x)
  return numerators / np.sum(numerators, axis=axis, keepdims=True)

Let's test this before diving into the details:

In [27]:
x = np.array([[1, 1], [3, 0]])
print(f'{x = }')
print(f'{softmax(x) = }')

x = np.array([[1, 2], [3, 0]])
print(f'{x = }')
print(f'{softmax(x, axis=0) = }')

x = np.array([[1, 2], [3, 0]])
print(f'{x = }')
print(f'{softmax(x, axis=1) = }')

x = array([[1, 1],
       [3, 0]])
softmax(x) = array([[0.1024912, 0.1024912],
       [0.7573132, 0.0377044]])
x = array([[1, 2],
       [3, 0]])
softmax(x, axis=0) = array([[0.11920292, 0.88079708],
       [0.88079708, 0.11920292]])
x = array([[1, 2],
       [3, 0]])
softmax(x, axis=1) = array([[0.26894142, 0.73105858],
       [0.95257413, 0.04742587]])


There are a few things to talk about here, let's break them down step by step.

Let's first talk about the axis.

Many numpy operations allow you to apply them over one or more (or all) axes of the array. axis=None applies the operation over the entire array. But, commonly you'd apply softmax over a batch of activations (e.g. of shape (B, H), where B is the batch size and H is the hidden dimension) and here you need to apply softmax over each row of features separately rather than the entire array.

Compare the examples above:
+ #1 applies softmax to the entire array, and so all the values of the softmax add up to 1.
+ #2 applies softmax to each column of the array, and so all the values of the softmax in each col add up to 1.
+ #3 applies softmax to each row separately, and so all the values of the softmax in each row add up to 1. This is what we want in the (B, H) case.

So, axis=k denotes which dimension of the array to apply the operation over. If the shape of the array is (x0, x1, ..., xn), then axis=k applies the operation over the xk dim of the array. In the 2-D case discussed above, axis=0 applies it over the "batch" dim (i.e. down each column), and axis=1 applies it over the "feature" dim (i.e. across each row).

Let's now talk about `keepdims`.

When we apply "reduce" operations like sum, max, etc. the number of output dimensions is the number of input dims - the number of axes we applied the operation over. Setting keepdims=True preserves the number of dimensions of the input. For instance, for x of shape (B, H), the output has shape:
  + for axis=None (sum over all axes), scalar for keepdims=False, (1, 1) for keepdims=True.
  + for axis=0, (H,) for keepdims=False, (1, H) for keepdims=True.
  + for axis=1, (B,) for keepdims=False, (B, 1) for keepdims=True.

Let's see this in action:

In [28]:
x = np.array([[1, 1], [3, 0]])

print(f'{x = }, {x.shape = }')
for axis in (None, 0, 1):
  print(f'{axis = }')
  sum_nokeepdim = np.sum(x, axis=axis, keepdims=False)
  sum_keepdim = np.sum(x, axis=axis, keepdims=True)
  print(f'{np.sum(x, axis=axis, keepdims=False) = }, {sum_nokeepdim.shape = }')
  print(f'{np.sum(x, axis=axis, keepdims=True) = }, {sum_keepdim.shape = }')



x = array([[1, 1],
       [3, 0]]), x.shape = (2, 2)
axis = None
np.sum(x, axis=axis, keepdims=False) = np.int64(5), sum_nokeepdim.shape = ()
np.sum(x, axis=axis, keepdims=True) = array([[5]]), sum_keepdim.shape = (1, 1)
axis = 0
np.sum(x, axis=axis, keepdims=False) = array([4, 1]), sum_nokeepdim.shape = (2,)
np.sum(x, axis=axis, keepdims=True) = array([[4, 1]]), sum_keepdim.shape = (1, 2)
axis = 1
np.sum(x, axis=axis, keepdims=False) = array([2, 3]), sum_nokeepdim.shape = (2,)
np.sum(x, axis=axis, keepdims=True) = array([[2],
       [3]]), sum_keepdim.shape = (2, 1)


Why is it important to set keepdims=True? To get the softmax, we divide the exponents of shape (B, H) by the sum of exponents. For axis=None, sumexp is a scalar and trivially broadcasted over shape (B, H). But for axis=0 and axis=1, sumexp has shape (H,) or (B,) respectively, and has to be broadcasted over shape (B, H). This fails for axis=1 or, worse, if B == H, the sum is broadcasted over the wrong dim and fails silently. Hence, we guide the broadcasting by setting keepdims=True. BTW, for axis=1 this is equivalent to `sum = np.sum(x, axis=1, keepdims=False); sum = sum[:, None]`


Let's check it out with our example (B = H = 2). Can you explain why the answer is wrong with keepdims=False?

In [29]:
x = np.array([[1, 1], [3, 0]])

print(f'{x = }, {x.shape = }')
for axis in (1,):
  print(f'{axis = }')
  sum_nokeepdim = np.sum(x, axis=axis, keepdims=False)
  sum_keepdim = np.sum(x, axis=axis, keepdims=True)
  print(f'{np.sum(x, axis=axis, keepdims=False) = }, {sum_nokeepdim.shape = }')
  print(f'{np.sum(x, axis=axis, keepdims=True) = }, {sum_keepdim.shape = }')
  print(f'Wrong: {x / sum_nokeepdim = }')
  print(f'Correct: {x / sum_keepdim = }')
  print()

x = array([[1, 1],
       [3, 0]]), x.shape = (2, 2)
axis = 1
np.sum(x, axis=axis, keepdims=False) = array([2, 3]), sum_nokeepdim.shape = (2,)
np.sum(x, axis=axis, keepdims=True) = array([[2],
       [3]]), sum_keepdim.shape = (2, 1)
Wrong: x / sum_nokeepdim = array([[0.5       , 0.33333333],
       [1.5       , 0.        ]])
Correct: x / sum_keepdim = array([[0.5, 0.5],
       [1. , 0. ]])



A general rule of thumb is that numpy can broadcast over the batch dimension, but not others. So (H,) -> (B, H) is correctly done without guidance, but (B,) -> (B, H) has to be guided using res[:, None] (or np.newaxis or np.expand_dims).

Great, hopefully your understanding of axes and broadcasting was enriched by this example! Let's apply it to improve the numerical stability of our softmax impl.

In our current impl, the sumexp term can become really large. To counter this, we subtract the max of x (along the provided axis), call it mx, from x before computing the exponents and summing them. This is equivalent to multiplying both the numerator and denominator of the softmax by exp(-mx) and so has no effect on the output of the softmax. But because the input to exp is now in the range (-inf, 0], the exp will be in the range (0, 1], which keeps the sumexp manageable.

Let's implement this!

In [30]:
def stable_softmax(x, axis=None):
  """Softmax of x over the given axis, computed on x shifted by its max along that axis.

  Shifting by the max leaves the result unchanged but keeps every input to
  np.exp at or below 0, so the exponents cannot overflow.
  """
  shifted = x - np.max(x, axis=axis, keepdims=True)
  numerators = np.exp(shifted)
  return numerators / np.sum(numerators, axis=axis, keepdims=True)

In [31]:
num_tests = 5
B = 5
H = 10
for i in range(num_tests):
  x = np.random.randn(B, H)
  s1 = softmax(x, axis=-1)
  s2 = stable_softmax(x, axis=-1)
  np.testing.assert_allclose(s1, s2)
print("The two impls matched!")

The two impls matched!


Note that as with indexing, `axis=-k` is equivalent to `axis=num_axes-k`. `axis=-1` is often used when you want to apply an operation to the feature dimension but there might be multiple "batch" dimensions, e.g. with language models, you often have a batch of token sequences of shape (B, T, H). Let's see this in action in LayerNorm!

### LayerNorm

Layernorm is a common component of transformer architectures (recently RMSNorm has been more common because it has fewer learnable params, hence simpler).

For layernorm, we compute the mean and variance along the feature axis (i.e. for each example in the batch) and normalize the features based on these. We also have 2 learnable params to scale and shift the normalized features; for now we'll assume that these are constants, and we'll later see how to apply backprop and update (i.e. learn) these params. Let's go!

In [32]:
def layernorm(x):
  """Normalizes x over its last (feature) axis to zero mean and unit variance.

  Works for any number of leading "batch" dims, e.g. x of shape (H,), (B, H)
  or (B, T, H). The scale and shift params are fixed to 1 and 0 here, so they
  leave the normalized values unchanged.
  """
  H = x.shape[-1]
  # keepdims=True keeps a trailing dim of size 1, so the stats broadcast over the feature dim
  mean = np.mean(x, axis=-1, keepdims=True)
  var = np.var(x, axis=-1, keepdims=True)
  eps = 1e-8  # keeps the division finite when all the features of an example are equal
  norm_x = (x - mean) / np.sqrt(var + eps)
  # learnable params
  scale_H = np.ones((H,))
  shift_H = np.zeros((H,))
  # (H,) broadcasts over all the leading batch dims without guidance
  return norm_x * scale_H + shift_H

Let's test this! Again, it's hard to visually inspect for correctness, but let's create an input array of large numbers so that we can see the effect of normalization (i.e. output values are close to 0). In our impl, the scale (= 1) and shift (= 0) have no effect because we're multiplying by 1 and shifting by 0.

In [33]:
B = 2
T = 3
H = 4

x = np.random.randn(B, T, H)
# scale and shift by large numbers so that we can see the effect of layernorm
x = x * 100 + 200
print(f'{x = }')
print(f'{layernorm(x) = }')

x = array([[[357.79169561, 208.15738989, 111.13551003, 153.71720507],
        [105.66080791, 162.39187008, 172.84249259,  87.72740474],
        [241.95510612, 120.82886376, 203.76837997, 202.8443505 ]],

       [[ 17.87137652, 263.96670787, 328.09823822, 201.05003962],
        [221.92099606, 220.7273338 , 368.30185166,   2.22508811],
        [213.29101514, 167.96342076, 221.96545766, 156.90834108]]])
layernorm(x) = array([[[ 1.60992053,  0.00490126, -1.03578246, -0.57903933],
        [-0.73164101,  0.83495754,  1.12354601, -1.22686254],
        [ 1.12217042, -1.61791092,  0.25832181,  0.23741869]],

       [[-1.59644341,  0.52865092,  1.08244262, -0.01465014],
        [ 0.14254234,  0.13340798,  1.26270498, -1.5386553 ],
        [ 0.82949729, -0.78704629,  1.13885883, -1.18130984]]])


Remember our rule of thumb regarding broadcasting: numpy knows how to broadcast over the batch dim. Hence we use keepdims=True to guide the broadcasting of mean and var because these are broadcasted over the feature dim (axis=2 or axis=-1), but we don't need to guide the broadcasting of the scale and shift params which are broadcasted over the batch dims. Hence `norm_x_BTH * scale_H[None, None, :] + shift_H[None, None, :]` is redundant.

## Matrix Multiplication

If you're rusty on matrix multiplication, do a quick google search and familiarize yourself with the concept.

Let's look at a few common matrix multiplications we'll encounter in practice.

One is projecting a batch of activations from one dimension (e.g. model dimension) to another (e.g. hidden dimensions). This involves multiplying an activations array of shape B, F and a (learnable) weight array of shape F, H (B -> batch dim, F -> feature dim, H -> hidden dim) to get an activations array of shape B, H. Let's implement this in 4 ways!

In [37]:
B, F, H = 2, 3, 4
x = np.random.randn(B, F)
y = np.random.randn(F, H)
print(f'{x = }')
print(f'{y = }')

print(f'{np.dot(x, y) = }')
print(f'{np.matmul(x, y) = }')
print(f'{x@y = }')
print(f'{np.einsum("bf,fh->bh", x, y) = }')

x = array([[-0.63207154,  1.63526704,  1.05375833],
       [ 0.82515517, -1.80175095,  0.61887486]])
y = array([[-0.56440053, -0.18720503,  0.22760663, -0.62886488],
       [ 0.93584682,  0.84739081, -0.08651255,  0.19707451],
       [-0.60932031, -0.05660355,  0.32302553,  0.27083101]])
np.dot(x, y) = array([[ 1.24502461,  1.44439077,  0.05505604,  1.00514748],
       [-2.52897394, -1.71629091,  0.54359724, -0.70637979]])
np.matmul(x, y) = array([[ 1.24502461,  1.44439077,  0.05505604,  1.00514748],
       [-2.52897394, -1.71629091,  0.54359724, -0.70637979]])
x@y = array([[ 1.24502461,  1.44439077,  0.05505604,  1.00514748],
       [-2.52897394, -1.71629091,  0.54359724, -0.70637979]])
np.einsum("bf,fh->bh", x, y) = array([[ 1.24502461,  1.44439077,  0.05505604,  1.00514748],
       [-2.52897394, -1.71629091,  0.54359724, -0.70637979]])


Another scenario is projecting a batch of activations from the feature dimension to a scalar for a binary classifier. This involves multiplying an activations array of shape B, F and a (learnable) weight array of shape F to get an activations array of shape (B,). This is generally passed through sigmoid to get a batch of probabilities. Let's implement this in 4 ways!

In [41]:
B, F = 2, 3
x = np.random.randn(B, F)
y = np.random.randn(F)
print(f'{x = }')
print(f'{y = }')

print(f'{np.dot(x, y) = }')
print(f'{np.matmul(x, y) = }')
print(f'{x@y = }')
print(f'{np.einsum("bf,f->b", x, y) = }')

x = array([[ 1.29142654,  0.76313607, -0.16906432],
       [ 0.37681326, -1.39252537,  1.08533761]])
y = array([-0.27512246,  0.57937015,  1.1045357 ])
np.dot(x, y) = array([-0.09989978,  0.28833671])
np.matmul(x, y) = array([-0.09989978,  0.28833671])
x@y = array([-0.09989978,  0.28833671])
np.einsum("bf,f->b", x, y) = array([-0.09989978,  0.28833671])


Let's look at an example with more dimensions. In language models, we multiply a batch of token activations of shape (B, T, F) with a (learned) multi-headed query projection matrix of shape (N, F, H) to get a batch of multi-headed queries of shape (B, T, N, H), where B is the batch dim, T is the sequence length, F is the feature dim, N is the number of query heads and H is the attention dim. Let's implement this!

In [None]:
B, T, F, N, H = 2, 3, 4, 5, 6
x = np.random.randn(B, T, F)  # batch of token activations
y = np.random.randn(N, F, H)  # per-head query projection weights
print(f'{x.shape = }')
print(f'{y.shape = }')

# einsum: f appears in both inputs but not in the output, so it is the contracting dim
q_einsum = np.einsum("btf,nfh->btnh", x, y)
# tensordot: contract axis 2 of x with axis 1 of y; the remaining axes are concatenated, (B, T) + (N, H)
q_tensordot = np.tensordot(x, y, axes=([2], [1]))
# matmul: the leading dims are broadcasted, (B, T, 1, 1, F) @ (N, F, H) -> (B, T, N, 1, H)
q_matmul = np.matmul(x[:, :, None, None, :], y)[:, :, :, 0, :]

print(f'{q_einsum.shape = }')
np.testing.assert_allclose(q_einsum, q_tensordot, atol=1e-12)
np.testing.assert_allclose(q_einsum, q_matmul, atol=1e-12)
print("The three impls matched!")

Now let's look at 2-D arrays. In the 2-D case

np.einsum uses the Einstein summation convention to denote the operation. The subscripts string commonly describes the input shapes and the output shape, denoting the batch, contracting and non-contracting dimensions. Batch dimensions appear in both inputs and the output, non-contracting dimensions appear in one of the inputs and the output, contracting dimensions appear in both inputs but not in the output.

In [None]:
x = np.random.permutation(4)
y = np.random.permutation(5)
print(f'{x = }')
print(f'{y = }')
# (4, 1) times (1, 5) gives (4, 5): entry [i, j] is x[i] * y[j]
print(f'{np.dot(x[:, None], y[None, :]) = }')
# np.matmul is used here (as in the cells above); np.linalg.matmul only exists in NumPy >= 2.0
print(f'{np.matmul(x[:, None], y[None, :]) = }')

In [None]:
# "n,n->" contracts the shared dim n (an inner product), so both inputs need the same length
x = np.random.permutation(4)
y = np.random.permutation(4)
print(f'{x = }')
print(f'{y = }')
print(f'{np.einsum("n,n->", x, y) = }')
print(f'{np.dot(x, y) = }')