# Exercise 1: Tensor basics 
In this exercise you will learn the basics of tensor creation, manipulation, indexing, broadcasting, vectorization, einsum, and attention masking fundamentals. These basics are important for understanding the more complex implementations later on, so make sure you understand them well.

**To complete this exercise fill in all TODOs in the functions below.** 

Make sure to check the output of your function and whether or not it fulfills the requirements outlined in the function definition. Do NOT change the function signature or name since we will be running checks on your functions during grading.

### Shape legend used in this notebook
- `B`: batch size
- `T`: sequence length / time
- `D`: feature dimension
- `H`: number of attention heads
- `Dh`: per-head feature dimension

### Debugging tip: what to print
When you get a shape error, print:
- `x.shape`, `x.dtype`, `x.device`
- `x.is_contiguous()` (important for `view`)
For masks also print:
- `mask.shape`, `mask.dtype`, `mask.sum()` and a small slice like `mask[0, :10]`
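The checklist above can be wrapped in a small helper so you don't retype it. This is a minimal sketch. `describe` is a hypothetical name and not a PyTorch API. For boolean masks it also prints the number of `True` entries and the first few values.

```python
import torch

def describe(name: str, t: torch.Tensor) -> str:
    """Hypothetical debugging helper: print (and return) the attributes listed above."""
    info = (f"{name}: shape={tuple(t.shape)} dtype={t.dtype} "
            f"device={t.device} contiguous={t.is_contiguous()}")
    if t.dtype == torch.bool:
        # For masks: how many entries are True, plus the first 10 values
        info += f" sum={int(t.sum())} first={t.flatten()[:10].tolist()}"
    print(info)
    return info

describe("x", torch.randn(2, 3).t())                   # contiguous=False after the transpose
describe("mask", torch.tensor([[True, False, True]]))
```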

### Reproducibility tip: seeding in PyTorch
Many operations in deep learning involve randomness (e.g., initializing model weights, shuffling data, dropout, random augmentations).
**Seeding** sets the starting state of PyTorch’s random number generator so that these random choices become **repeatable**.

- If you set the same seed and run the same code again, you should get the same *random* tensors / initial weights.
- If you don’t set a seed, results can vary between runs.

Common usage: `torch.manual_seed(seed)`

Note: even with fixed seeds, some GPU operations can still be non-deterministic due to performance optimizations. For this assignment, seeding is mainly to make debugging easier and to ensure everyone can reproduce the same intermediate results. If you are given a seed, make sure to use it when creating tensors or performing other operations.
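A quick sketch of what seeding buys you: reseeding restores the generator state, so the same calls produce the same values. Calls made without reseeding continue from the advanced state and produce new values.

```python
import torch

torch.manual_seed(42)
a = torch.randn(3)
torch.manual_seed(42)
b = torch.randn(3)        # same seed, same draws as `a`
c = torch.randn(3)        # generator state has advanced: different values

print(torch.equal(a, b))  # True
print(torch.equal(a, c))  # False
```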

## Tensor creation
This warmup exercise teaches you how to create tensors with different shapes and values. A few details about tensor creation that are good to know:
- `torch.tensor([...])` infers the dtype from the Python values (ints → `torch.int64`, floats → the default float dtype, usually `torch.float32`).
- `torch.arange(start, end)` is **end-exclusive**.
- `torch.linspace(start, end, steps)` is **end-inclusive**.
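The three details above can be checked directly. This sketch assumes the default float dtype has not been changed with `torch.set_default_dtype`.

```python
import torch

print(torch.tensor([1, 2, 3]).dtype)    # torch.int64
print(torch.tensor([1.0, 2, 3]).dtype)  # torch.float32 (one float promotes the whole list)
print(torch.arange(0, 5))               # tensor([0, 1, 2, 3, 4]): 5 is excluded
print(torch.linspace(0, 1, steps=5))    # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000]): 1 is included
```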

In [2]:
from collections.abc import Sequence
import torch

In [6]:
def make_tensor(data, dtype: torch.dtype | None = None, device: torch.device | str | None = None) -> torch.Tensor:
    """ Create a tensor from Python data (list/tuple/nested lists). """
    return torch.tensor(data=data, dtype=dtype, device=device)

x = make_tensor([[1, 2], [3, 4]], dtype=torch.float32)
x

tensor([[1., 2.],
        [3., 4.]])

In [7]:
def make_zeros(shape: Sequence[int], dtype: torch.dtype | None = None, device: torch.device | str | None = None) -> torch.Tensor:
    """Create a tensor filled with zeros."""
    return torch.zeros(size=shape, dtype=dtype, device=device)

z = make_zeros((2, 3), dtype=torch.float64)
z

tensor([[0., 0., 0.],
        [0., 0., 0.]], dtype=torch.float64)

In [9]:
def make_ones_like(x: torch.Tensor) -> torch.Tensor:
    """Create a tensor of ones with the same shape, dtype, and device as x. """
    return torch.ones_like(input=x)

base = torch.randn(2, 3, dtype=torch.float32)
ones = make_ones_like(base)
ones

tensor([[1., 1., 1.],
        [1., 1., 1.]])

In [10]:
def make_arange(start: int, end: int, step: int = 1, dtype: torch.dtype | None = None, device: torch.device | str | None = None) -> torch.Tensor:
    """Create a 1D tensor containing values [start, start+step, ..., < end]."""
    return torch.arange(start=start, end=end, step=step, dtype=dtype, device=device)

ar = make_arange(0, 5, 2, dtype=torch.int64)
ar

tensor([0, 2, 4])

In [11]:
def make_linspace(start: float, end: float, steps: int, dtype: torch.dtype | None = None, device: torch.device | str | None = None) -> torch.Tensor:
    """Create a 1D tensor with evenly spaced values from start to end (inclusive)."""
    return torch.linspace(start=start, end=end, steps=steps, dtype=dtype, device=device)

ls = make_linspace(0.0, 1.0, steps=5, dtype=torch.float32)
ls

tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])

In [12]:
def make_randn(shape: Sequence[int], seed: int | None = None, dtype: torch.dtype | None = None, device: torch.device | str | None = None) -> torch.Tensor:
    """Create a tensor filled with values from a standard normal distribution."""
    if seed is not None:
        torch.manual_seed(seed)
    return torch.randn(size=shape, dtype=dtype, device=device)

a = make_randn((2, 3), seed=123, dtype=torch.float32)
a

tensor([[-0.1115,  0.1204, -0.3696],
        [-0.2404, -1.1969,  0.2093]])

In [14]:
def cast_dtype_and_move(x: torch.Tensor, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Convert tensor dtype and move to device."""
    return x.to(device=device, dtype=dtype)

casted = cast_dtype_and_move(torch.tensor([1, 2, 3]), torch.device("cpu"), torch.float32)
casted

tensor([1., 2., 3.])

## Shape manipulation
Now that we have covered the basic tensor creation schemes, we focus on shape manipulation. Understanding the differences between these mechanisms is key for building larger systems, and many people still get them wrong. 
The core ideas to understand are:
- **Contiguous tensors** store data in a single, row-major memory layout.
- Many ops (especially slicing like `x[:, ::2]`, `transpose`, `permute`) often create **non-contiguous** tensors (no copy but different strides).
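- `view(...)` is **zero-copy**, but the new shape must be compatible with the tensor's existing strides. In practice this usually means **contiguous** memory, and `view` throws an error otherwise.
- `reshape(...)` returns a view when the strides allow it. Otherwise (typically for non-contiguous inputs) it will **allocate/copy**.
- `contiguous()` returns a contiguous copy when the tensor isn’t contiguous, and the tensor itself when it already is.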

If you *need* a view after reordering dims: call `x = x.contiguous()` first (this makes a contiguous copy).
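The sketch below shows these rules on a transposed tensor. `view` refuses the flat shape, and `reshape` silently copies. On a contiguous input, `reshape` returns a view that shares storage. Storage is compared with `data_ptr()`.

```python
import torch

x = torch.arange(6).reshape(2, 3)       # contiguous, strides (3, 1)
xt = x.t()                              # transpose: shape (3, 2), strides (1, 3), no copy
print(xt.is_contiguous())               # False

try:
    xt.view(6)                          # strides incompatible with a flat view
except RuntimeError as err:
    print("view failed:", err)

flat = xt.reshape(6)                    # falls back to allocating a copy
print(flat.data_ptr() == x.data_ptr())  # False: new storage
print(flat)                             # tensor([0, 3, 1, 4, 2, 5])

same = x.reshape(3, 2)                  # contiguous input: reshape returns a view
print(same.data_ptr() == x.data_ptr())  # True

flat2 = xt.contiguous().view(6)         # explicit copy, then a zero-copy view
print(torch.equal(flat, flat2))         # True
```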

In [16]:
def reshape_tensor(x: torch.Tensor, new_shape: Sequence[int]) -> torch.Tensor:
    """Reshape tensor to new_shape (may return a view or a copy)."""
    return x.reshape(new_shape)

x = torch.arange(6)
y = reshape_tensor(x, (2, 3))
x, y

(tensor([0, 1, 2, 3, 4, 5]),
 tensor([[0, 1, 2],
         [3, 4, 5]]))

In [17]:
def view_tensor(x: torch.Tensor, new_shape: Sequence[int]) -> torch.Tensor:
    """View tensor as new_shape (requires contiguous memory and doesn't allocate new memory for the tensor data)."""
    return x.view(new_shape)

y_view = view_tensor(x, (2, 3))
y_view

tensor([[0, 1, 2],
        [3, 4, 5]])

In [19]:
def flatten_from_dim(x: torch.Tensor, start_dim: int = 0) -> torch.Tensor:
    """Flatten a tensor starting from start_dim into a single dimension."""
    return x.flatten(start_dim=start_dim)

x2 = torch.randn(2, 3, 4)
flat = flatten_from_dim(x2, start_dim=1)
x2, flat

(tensor([[[-1.2203,  1.3139,  1.0533,  0.1388],
          [-0.2044, -2.2685, -0.9133, -0.4204],
          [ 0.2436, -0.0567,  0.3784,  1.6863]],
 
         [[ 0.2553, -0.5496,  1.0042,  0.8272],
          [ 1.5434,  0.1406,  1.0617, -0.9929],
          [-1.6025, -1.0764,  0.9031, -0.7218]]]),
 tensor([[-1.2203,  1.3139,  1.0533,  0.1388, -0.2044, -2.2685, -0.9133, -0.4204,
           0.2436, -0.0567,  0.3784,  1.6863],
         [ 0.2553, -0.5496,  1.0042,  0.8272,  1.5434,  0.1406,  1.0617, -0.9929,
          -1.6025, -1.0764,  0.9031, -0.7218]]))

In [20]:
def add_singleton_dim(x: torch.Tensor, dim: int) -> torch.Tensor:
    """Insert a size-1 dimension at position dim."""
    return x.unsqueeze(dim=dim)

x3 = torch.randn(5, 7)
x3s = add_singleton_dim(x3, dim=1)
x3, x3s

(tensor([[ 1.0720,  1.5026, -0.8190,  0.2686, -2.2150, -1.3193, -2.0915],
         [ 0.9629, -0.9948,  1.2176, -0.2282,  1.3382,  1.9929,  1.3708],
         [-0.5009, -0.2793,  1.2311, -1.0973, -0.9669,  0.7763, -0.2582],
         [-2.0407, -0.8016, -0.8183, -1.1820, -0.2877, -0.6043,  1.3334],
         [-1.4053, -0.5922, -0.2548,  1.1517, -0.0179,  0.4264, -0.7657]]),
 tensor([[[ 1.0720,  1.5026, -0.8190,  0.2686, -2.2150, -1.3193, -2.0915]],
 
         [[ 0.9629, -0.9948,  1.2176, -0.2282,  1.3382,  1.9929,  1.3708]],
 
         [[-0.5009, -0.2793,  1.2311, -1.0973, -0.9669,  0.7763, -0.2582]],
 
         [[-2.0407, -0.8016, -0.8183, -1.1820, -0.2877, -0.6043,  1.3334]],
 
         [[-1.4053, -0.5922, -0.2548,  1.1517, -0.0179,  0.4264, -0.7657]]]))

In [26]:
def remove_singleton_dims(x: torch.Tensor, dim: int | None = None) -> torch.Tensor:
    """Remove size-1 dimensions."""
    return x.squeeze(dim=dim) if dim is not None else x.squeeze()

x4 = torch.randn(2, 1, 3)
x4s = remove_singleton_dims(x4)
x4, x4s

(tensor([[[-0.2005, -0.1195,  1.1332]],
 
         [[ 0.6291, -0.8709, -0.7470]]]),
 tensor([[-0.2005, -0.1195,  1.1332],
         [ 0.6291, -0.8709, -0.7470]]))

In [27]:
def transpose_last_two(x: torch.Tensor) -> torch.Tensor:
    """Swap the last two dimensions of x."""
    return x.transpose(-1, -2)

x6 = torch.randn(2, 3, 4)
x6t = transpose_last_two(x6)
x6, x6t

(tensor([[[-0.6062,  0.4771,  0.7203, -0.0215],
          [ 1.0731, -0.1408, -0.5394, -1.2782],
          [-0.2589,  1.3113, -0.0360,  0.2118]],
 
         [[-0.0086,  1.8576,  2.1321, -0.5056],
          [ 1.6921, -1.0944, -1.0197, -0.5399],
          [ 1.2117, -0.8632,  1.3337,  0.0771]]]),
 tensor([[[-0.6062,  1.0731, -0.2589],
          [ 0.4771, -0.1408,  1.3113],
          [ 0.7203, -0.5394, -0.0360],
          [-0.0215, -1.2782,  0.2118]],
 
         [[-0.0086,  1.6921,  1.2117],
          [ 1.8576, -1.0944, -0.8632],
          [ 2.1321, -1.0197,  1.3337],
          [-0.5056, -0.5399,  0.0771]]]))

In [28]:
def permute_bhwc_to_bchw(x: torch.Tensor) -> torch.Tensor:
    """Convert (B, H, W, C) tensor into (B, C, H, W)."""
    return x.permute(0, 3, 1, 2)

x7 = torch.randn(8, 32, 32, 3)
x7p = permute_bhwc_to_bchw(x7)
x7.shape, x7p.shape

(torch.Size([8, 32, 32, 3]), torch.Size([8, 3, 32, 32]))

In [30]:
def make_contiguous(x: torch.Tensor) -> torch.Tensor:
    """Return x if it is already contiguous; otherwise return a contiguous copy."""
    return x.contiguous()

x8 = torch.randn(4, 6)[:, ::2]
x8c = make_contiguous(x8)
print(x8.is_contiguous(), x8c.is_contiguous())
x8, x8c

False True


(tensor([[ 0.0186, -1.3608, -0.2622],
         [ 0.4351, -0.4247, -0.0134],
         [-0.8891,  0.0802,  0.6056],
         [ 1.0850, -0.2268, -1.3068]]),
 tensor([[ 0.0186, -1.3608, -0.2622],
         [ 0.4351, -0.4247, -0.0134],
         [-0.8891,  0.0802,  0.6056],
         [ 1.0850, -0.2268, -1.3068]]))

## Indexing
Now that we know how to create and manipulate tensors, we need to understand how to extract parts of them using indexing. 
- Basic slicing (`x[a:b]`) returns a view when possible.
- “Fancy” (advanced) indexing with lists or tensors of indices allocates a new tensor (a copy), so writing to the result does not change the original.
- In-place vs out-of-place matters: if a function says “return a copy, leave the input unchanged”, you need `clone()`.
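A small sketch of the three cases above. Writes through a basic slice reach the original tensor. Writes to an advanced-indexing result do not. `clone()` gives you an independent copy on purpose.

```python
import torch

x = torch.arange(6).reshape(2, 3)   # [[0, 1, 2], [3, 4, 5]]

row = x[0, :]            # basic slicing: a view into x's storage
row[0] = 100             # writes through to x
print(x[0, 0].item())    # 100

picked = x[[0, 1], 0]    # advanced (list) indexing: a new tensor
picked[0] = -1           # does not touch x
print(x[0, 0].item())    # still 100

safe = x.clone()         # explicit, independent copy
safe[1, 1] = 0
print(x[1, 1].item())    # 4: the original is unchanged
```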

In [31]:
def slice_rows(x: torch.Tensor, start: int, end: int) -> torch.Tensor:
    """Slice rows in a 2D tensor: x[start:end, :]."""
    return x[start:end, :]

x = torch.arange(12).reshape(4, 3)
rows = slice_rows(x, 1, 3)
x, rows

(tensor([[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8],
         [ 9, 10, 11]]),
 tensor([[3, 4, 5],
         [6, 7, 8]]))

In [33]:
def select_columns(x: torch.Tensor, cols: Sequence[int]) -> torch.Tensor:
    """Select specific columns from a 2D tensor."""
    return x[:, cols]

cols = select_columns(x, [0, 2])
x, cols

(tensor([[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8],
         [ 9, 10, 11]]),
 tensor([[ 0,  2],
         [ 3,  5],
         [ 6,  8],
         [ 9, 11]]))

In [35]:
def get_diagonal(x: torch.Tensor) -> torch.Tensor:
    """Get the diagonal of a 2D tensor."""
    n = min(x.shape[0], x.shape[1])  # diagonal length; also correct for non-square tensors
    idx = torch.arange(n, device=x.device)
    return x[idx, idx]

d = get_diagonal(torch.tensor([[1, 2], [3, 4]]))
d

tensor([1, 4])

In [36]:
def set_subtensor(x: torch.Tensor, row_idx: int, col_idx: int, value: float) -> torch.Tensor:
    """Return a copy of x where x[row_idx, col_idx] is set to value."""
    x_copy = x.clone()
    x_copy[row_idx, col_idx] = value
    return x_copy

base = torch.zeros(2, 2)
out = set_subtensor(base, 0, 1, 5.0)
base, out

(tensor([[0., 0.],
         [0., 0.]]),
 tensor([[0., 5.],
         [0., 0.]]))

In [38]:
def gather_rows(x: torch.Tensor, row_indices: torch.Tensor) -> torch.Tensor:
    """Gather (concat) rows from x using row_indices."""
    return x[row_indices]

x2 = torch.tensor([[10, 11], [20, 21], [30, 31]])
idx = torch.tensor([2, 0])
gathered = gather_rows(x2, idx)
x2, idx, gathered

(tensor([[10, 11],
         [20, 21],
         [30, 31]]),
 tensor([2, 0]),
 tensor([[30, 31],
         [10, 11]]))

## Broadcasting and reducing
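Broadcasting is the PyTorch mechanism that lets you apply elementwise ops to tensors of different shapes without Python loops. Understanding how it works lets you trace shapes through complicated systems. The broadcasting rules to know are:
- Dimensions align from the **right**. A tensor with fewer dimensions is treated as if it had extra leading dimensions of size 1.
- Two aligned dimensions are compatible if they are equal or one of them is **1**. The size-1 dimension is stretched to match the other.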
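The rules can be traced on a concrete pair of shapes. In this sketch, `torch.broadcast_shapes` applies the same rule without allocating any data. The last case fails because its trailing dimensions are 3 and 2.

```python
import torch

a = torch.ones(4, 1, 3)          # (4, 1, 3)
b = torch.ones(5, 1)             #    (5, 1) -> treated as (1, 5, 1)
c = a + b
print(c.shape)                   # torch.Size([4, 5, 3])

# The same shape rule, computed without creating tensors:
print(torch.broadcast_shapes((4, 1, 3), (5, 1)))  # torch.Size([4, 5, 3])

try:
    torch.ones(2, 3) + torch.ones(2)   # trailing dims 3 vs 2: incompatible
except RuntimeError as err:
    print("broadcast failed:", err)
```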

### Reduction ops and `keepdim`

When you reduce over a dimension (e.g. `sum`, `mean`, `max`), PyTorch can either:

- **remove** the reduced dimension (`keepdim=False`, default), or
- **keep** it as size 1 (`keepdim=True`)

Keeping the dimension is often helpful because it makes broadcasting back “just work”.

#### Shape diagram examples

Assume `x` has shape `(B, T, D)`:

**Sum over time**
- `x.sum(dim=1)` → shape `(B, D)`
- `x.sum(dim=1, keepdim=True)` → shape `(B, 1, D)`

**Mean over features**
- `x.mean(dim=2)` → shape `(B, T)`
- `x.mean(dim=2, keepdim=True)` → shape `(B, T, 1)`

#### Why `keepdim=True` helps with broadcasting

Example: center `x` by subtracting the mean over `T`

- If `m = x.mean(dim=1)` has shape `(B, D)`, then `x - m` aligns `(B,T,D)` with `(1,B,D)`. It **fails** for most shapes, and when `T == B` (or `T == 1`) it runs but silently produces the wrong result.
- If `m = x.mean(dim=1, keepdim=True)` has shape `(B,1,D)`, then `x - m` **works** via broadcasting.
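Here is a sketch of the centering example with concrete sizes (`B, T, D = 2, 5, 3` are arbitrary choices). It uses `keepdim=True`, shows the failing version, and re-inserts a dropped axis with `unsqueeze`.

```python
import torch

B, T, D = 2, 5, 3
x = torch.randn(B, T, D)

m_keep = x.mean(dim=1, keepdim=True)   # (B, 1, D)
centered = x - m_keep                  # (B, T, D) - (B, 1, D): broadcasts over T

m_drop = x.mean(dim=1)                 # (B, D)
try:
    x - m_drop                         # aligns (B, T, D) with (1, B, D): 5 vs 2 fails
except RuntimeError as err:
    print("shape mismatch:", err)

# If the axis was already dropped, re-insert it before broadcasting:
centered2 = x - m_drop.unsqueeze(1)
print(torch.allclose(centered, centered2))                                 # True
print(torch.allclose(centered.mean(dim=1), torch.zeros(B, D), atol=1e-5))  # True
```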

In [39]:
def sum_over_dim(x: torch.Tensor, dim: int, keepdim: bool = False) -> torch.Tensor:
    """Sum tensor values along dimension dim."""
    return x.sum(dim=dim, keepdim=keepdim)

x = torch.ones(2, 3)
y = sum_over_dim(x, dim=1)
x, y

(tensor([[1., 1., 1.],
         [1., 1., 1.]]),
 tensor([3., 3.]))

In [40]:
def mean_over_dim(x: torch.Tensor, dim: int, keepdim: bool = False) -> torch.Tensor:
    """Mean along dimension dim."""
    return x.mean(dim=dim, keepdim=keepdim)

x2 = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
y2 = mean_over_dim(x2, dim=0)
x2, y2

(tensor([[1., 2.],
         [3., 4.]]),
 tensor([2., 3.]))

In [43]:
def max_over_dim(x: torch.Tensor, dim: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Max values and argmax indices along dimension dim."""
    result = torch.max(x, dim=dim)
    return result.values, result.indices

x3 = torch.tensor([[1.0, 5.0], [3.0, 2.0]])
values, idx = max_over_dim(x3, dim=1)
x3, values, idx

(tensor([[1., 5.],
         [3., 2.]]),
 tensor([5., 3.]),
 tensor([1, 0]))

In [44]:
def argmax_over_dim(x: torch.Tensor, dim: int) -> torch.Tensor:
    """Argmax indices along dimension dim."""
    return torch.argmax(x, dim=dim)

idx2 = argmax_over_dim(x3, dim=1)
x3, idx2

(tensor([[1., 5.],
         [3., 2.]]),
 tensor([1, 0]))

In [45]:
def broadcast_add_vector(x: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Add a vector v to each row of a 2D tensor x using broadcasting."""
    return x + v

x4 = torch.zeros(3, 2)
v = torch.tensor([10.0, 20.0])
y4 = broadcast_add_vector(x4, v)
x4, v, y4

(tensor([[0., 0.],
         [0., 0.],
         [0., 0.]]),
 tensor([10., 20.]),
 tensor([[10., 20.],
         [10., 20.],
         [10., 20.]]))

## Vectorization
Python loops are slow because every iteration carries interpreter and per-op overhead, so we want to avoid them as much as possible. PyTorch gives us many tools to do so. We cover these basics:
- `cat` vs `stack` (concatenate existing dims vs create a new dim)
- `repeat` vs `expand`
- `scatter_add` / `index_add` for accumulation
- `where` for conditional selection
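The accumulation ops from the list can be sketched as follows. Both in-place methods add `values[i]` into position `indices[i]` without building an `(N, size)` comparison mask. The example reuses the values from the `scatter_add_1d` check further below.

```python
import torch

values = torch.tensor([3.0, 1.0, 4.0, 1.0, 5.0])
indices = torch.tensor([0, 2, 0, 1, 2])   # int64 positions in [0, 3)

# index_add_(dim, index, source): out[indices[i]] += values[i] along dim 0
out1 = torch.zeros(3).index_add_(0, indices, values)

# scatter_add_(dim, index, src): same accumulation; index has the same number of dims as src
out2 = torch.zeros(3).scatter_add_(0, indices, values)

print(out1)   # tensor([7., 1., 6.])
print(out2)   # tensor([7., 1., 6.])
```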

### `expand` vs `repeat`

- `repeat(...)` **copies** data → larger tensor with independent storage.
- `expand(...)` **does not copy** data → it creates a *view* with clever strides.

This has two important implications:

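1) `expand` can only grow **size-1 dimensions** (broadcasting a singleton) or add new leading dimensions. Every other dimension must keep its size (pass `-1` to leave it unchanged).
2) The expanded tensor may have **many positions pointing to the same memory**.  
   Writing to it in place can therefore produce surprising results (several positions change at once), and many in-place ops refuse to run on such tensors and raise an error.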

Rule of thumb:
- Use `expand` for read-only broadcasting.
- Use `repeat` if you truly need independent copies.
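The sketch below makes the difference visible. The expanded view has stride 0 along the broadcast dimension and shares the source's storage, so editing the source shows up in every column. The repeated copy is unaffected. The last lines show an in-place op being rejected on an overlapping view.

```python
import torch

x = torch.tensor([[1], [2], [3]])       # (3, 1)
e = x.expand(3, 4)                      # view: stride 0 along the new columns
r = x.repeat(1, 4)                      # independent copy, shape (3, 4)

print(e.stride())                       # (1, 0)
print(e.data_ptr() == x.data_ptr())     # True: shares x's storage
print(r.data_ptr() == x.data_ptr())     # False

x[0, 0] = 100                           # modify the source
print(e[0])                             # tensor([100, 100, 100, 100])
print(r[0])                             # tensor([1, 1, 1, 1])

z = torch.zeros(1).expand(3)            # three positions, one memory location
try:
    z.add_(1)                           # in-place write to overlapping memory
except RuntimeError as err:
    print("in-place op rejected:", err)
```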


NOTE: From now on, we strongly encourage you to write your own quick checks that call the functions and verify their output. As before, you are still required to fill in the TODOs in each function.

In [46]:
def concat_tensors(tensors: Sequence[torch.Tensor], dim: int = 0) -> torch.Tensor:
    """Concatenate tensors along dim. NOTE: This will always allocate new memory"""
    return torch.cat(tensors=tensors, dim=dim)

t1 = torch.tensor([[1, 2], [3, 4]])
t2 = torch.tensor([[5, 6], [7, 8]])
t_cat = concat_tensors([t1, t2], dim=0)
t1, t2, t_cat

(tensor([[1, 2],
         [3, 4]]),
 tensor([[5, 6],
         [7, 8]]),
 tensor([[1, 2],
         [3, 4],
         [5, 6],
         [7, 8]]))

In [48]:
def stack_tensors(tensors: Sequence[torch.Tensor], dim: int = 0) -> torch.Tensor:
    """Stack tensors along a new dimension dim."""
    return torch.stack(tensors=tensors, dim=dim)

t1 = torch.tensor([1, 2])
t2 = torch.tensor([3, 4])
t_stack = stack_tensors([t1, t2], dim=0)
print(t1.shape, t2.shape, t_stack.shape)
t1, t2, t_stack

torch.Size([2]) torch.Size([2]) torch.Size([2, 2])


(tensor([1, 2]),
 tensor([3, 4]),
 tensor([[1, 2],
         [3, 4]]))

In [49]:
def repeat_tensor(x: torch.Tensor, repeats: Sequence[int]) -> torch.Tensor:
    """Repeat tensor along each dimension."""
    return x.repeat(repeats)

x5 = torch.tensor([[1, 2], [3, 4]])
x5_repeated = repeat_tensor(x5, repeats=[2, 3])
x5, x5_repeated

(tensor([[1, 2],
         [3, 4]]),
 tensor([[1, 2, 1, 2, 1, 2],
         [3, 4, 3, 4, 3, 4],
         [1, 2, 1, 2, 1, 2],
         [3, 4, 3, 4, 3, 4]]))

In [51]:
def expand_tensor(x: torch.Tensor, *sizes: int) -> torch.Tensor:
    """Expand tensor to a larger size without copying data (sizes can be -1 to keep the original dimension)."""
    return x.expand(*sizes)

x6 = torch.tensor([[1], [2], [3]])
x6_expanded = expand_tensor(x6, 3, 4)
print(x6.shape, x6_expanded.shape)
x6, x6_expanded

torch.Size([3, 1]) torch.Size([3, 4])


(tensor([[1],
         [2],
         [3]]),
 tensor([[1, 1, 1, 1],
         [2, 2, 2, 2],
         [3, 3, 3, 3]]))

In [53]:
def cumsum_over_dim(x: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Cumulative sum along dim."""
    return x.cumsum(dim=dim)

x7 = torch.tensor([[1, 2], [3, 4]])
x7_cumsum = cumsum_over_dim(x7, dim=1)
x7, x7_cumsum

(tensor([[1, 2],
         [3, 4]]),
 tensor([[1, 3],
         [3, 7]]))

In [55]:
def where_select(mask: torch.Tensor, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Elementwise select: return a where mask is True else b. mask must be broadcastable to a and b."""
    return torch.where(condition=mask, input=a, other=b)

mask = torch.tensor([[True, False], [False, True]])
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[10, 20], [30, 40]])
selected = where_select(mask, a, b)
mask, a, b, selected

(tensor([[ True, False],
         [False,  True]]),
 tensor([[1, 2],
         [3, 4]]),
 tensor([[10, 20],
         [30, 40]]),
 tensor([[ 1, 20],
         [30,  4]]))

In [56]:
def one_hot(indices: torch.Tensor, num_classes: int, dtype: torch.dtype | None = None) -> torch.Tensor:
    """
    Create one-hot encodings.
    Output is a tensor of the same shape as indices with an added dimension of size num_classes at the end, 
    where the value along that dimension is 1 if it matches the index and 0 otherwise.

    Shapes:
    - indices: (...,) integer tensor
    Return:
    - out: (..., num_classes)

    Requirements:
    - Must work for arbitrary leading shape.
    - No Python loops.
    """
    out = torch.zeros(*indices.shape, num_classes, dtype=dtype or indices.dtype, device=indices.device)
    out.scatter_(-1, indices.unsqueeze(-1), 1)
    return out

indices = torch.tensor([0, 2, 1])
num_classes = 4
one_hot_encoded = one_hot(indices, num_classes)
indices, one_hot_encoded

(tensor([0, 2, 1]),
 tensor([[1, 0, 0, 0],
         [0, 0, 1, 0],
         [0, 1, 0, 0]]))

In [60]:
def scatter_add_1d(
    values: torch.Tensor, indices: torch.Tensor, size: int
) -> torch.Tensor:
    """
    Sum `values` into an output vector at positions `indices`.

    Shapes:
    - values: (N,)
    - indices: (N,) integer indices in [0, size)
    Return:
    - out: (size,) with same dtype and device as values

    Requirement:
    - no Python loops
    """
    mask = indices.unsqueeze(1) == torch.arange(size, device=values.device)
    return (values.unsqueeze(1) * mask).sum(dim=0)

values = torch.tensor([3.0, 1.0, 4.0, 1.0, 5.0])
indices = torch.tensor([0, 2, 0, 1, 2])
out = scatter_add_1d(values, indices, size=3)
values, indices, out

(tensor([3., 1., 4., 1., 5.]), tensor([0, 2, 0, 1, 2]), tensor([7., 1., 6.]))

In [61]:
def batched_token_histogram(tokens: torch.Tensor, vocab_size: int) -> torch.Tensor:
    """
    Count token occurrences per batch item.

    Shapes:
    - tokens: (B, T) int64
    Return:
    - counts: (B, vocab_size) where counts[b, v] = number of times token v appears in tokens[b] 

    Requirements:
    - No Python loops over B or T.
    """
    mask = tokens.unsqueeze(-1) == torch.arange(vocab_size, device=tokens.device)
    return mask.sum(dim=1)

tokens = torch.tensor([[5, 2, 5, 2, 0],
                       [1, 1, 3, 0, 1]])
hist = batched_token_histogram(tokens, vocab_size=6)
tokens, hist

(tensor([[5, 2, 5, 2, 0],
         [1, 1, 3, 0, 1]]),
 tensor([[1, 0, 2, 0, 0, 2],
         [1, 3, 0, 1, 0, 0]]))

In [73]:
def masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: int) -> torch.Tensor:
    """
    Mean over `dim` considering only mask==True entries.

    Convention:
    - mask: bool tensor broadcastable to x
    - mask==True means "keep this entry"

    Return: same shape as x.mean(dim=dim)

    Requirements:
    - Avoid division by zero: if all mask are False along `dim`, define mean as 0.
    """
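    # Broadcast the mask to x's shape first so that .sum(dim=dim) is valid for any broadcastable mask
    mask = mask.expand_as(x).to(x.dtype)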
    numerator = (x * mask).sum(dim=dim)
    denominator = mask.sum(dim=dim)
    return torch.where(denominator == 0, torch.zeros_like(numerator), numerator / denominator)

x = torch.tensor([[1.0, 2.0, 3.0, 4.0],
                  [1.0, 2.0, 3.0, 4.0]])
mask = torch.tensor([[True,  False, True,  False],
                     [False, False, False, False]])
masked_means = masked_mean(x, mask, dim=1)
x, mask, masked_means

(tensor([[1., 2., 3., 4.],
         [1., 2., 3., 4.]]),
 tensor([[ True, False,  True, False],
         [False, False, False, False]]),
 tensor([2., 0.]))

## Einsum warmup

Now that you’re comfortable with shapes and broadcasting, we’ll introduce `torch.einsum`, a concise way to express tensor operations by explicitly naming axes and summing over repeated indices.


### The idea
You describe each input tensor by labeling its dimensions with letters, e.g.
- `x: (B, T, D)` → `"btd"`
- `W: (D, H)`    → `"dh"`

Then you tell einsum what output labels you want:
- `"btd,dh->bth"`

### Rules of einsum
1) **Same letter = same axis** (must match in size, except broadcastable size-1).
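2) **Letters that appear in the inputs but not in the output are summed over** (a “contraction”). Usually such a letter is shared by several inputs (like `d` in `"btd,dh->bth"`), but it can also appear in only one input (as in `"btd->bt"`).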
3) **Letters that appear in the output are kept** (in that order).
4) You can **reorder axes** just by changing the output label order.

### Tiny cheat sheet
- Sum over an axis: `"btd->bt"` (sums over `d`)
- Transpose: `"ij->ji"`
- Dot product: `"d,d->"` or batched `"btd,btd->bt"`
- Matrix multiply: `"ik,kj->ij"`
- Batched matmul: `"bij,bjk->bik"`
- Outer product: `"i,j->ij"`
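One way to convince yourself that the cheat sheet is right is to compare each equation against the equivalent standard op. This is a sketch with arbitrary small shapes. The loose `atol` allows for different summation orders between kernels.

```python
import torch

torch.manual_seed(0)
a, b = torch.randn(3, 4), torch.randn(4, 5)
u, w = torch.randn(4), torch.randn(4)
bx, by = torch.randn(2, 3, 4), torch.randn(2, 4, 5)

def close(p: torch.Tensor, q: torch.Tensor) -> bool:
    return torch.allclose(p, q, atol=1e-6)

assert close(torch.einsum("ij->ji", a), a.T)                         # transpose
assert close(torch.einsum("ij->i", a), a.sum(dim=1))                 # sum over j
assert close(torch.einsum("d,d->", u, w), torch.dot(u, w))           # dot product
assert close(torch.einsum("ik,kj->ij", a, b), a @ b)                 # matmul
assert close(torch.einsum("bij,bjk->bik", bx, by), torch.bmm(bx, by))  # batched matmul
assert close(torch.einsum("i,j->ij", u, w), torch.outer(u, w))       # outer product
print("all cheat-sheet identities hold")
```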

### How to derive an einsum (recommended workflow)
1) Write down shapes with named axes (e.g. `q: b h t d`, `k: b h s d`).
2) Decide which axes you want to **sum over** (give them the same letter in both inputs and leave that letter out of the output).
3) Decide which axes you want to **keep** in the output (write them after `->`).

In this section, you’ll use einsum to implement building blocks that show up in attention:
- linear projections (`x @ W`)
- dot products
- attention score matrices (`QKᵀ`)
- applying attention weights (`softmax(scores) @ V`)

NOTE: For these exercises you are required to use `torch.einsum`, not `matmul` (we check). You are not required to understand the attention mechanism at this point; the exercises are solvable without it. It is, however, good to remember these implementations for later exercises.

In [72]:
def einsum_linear_btd_dh_to_bth(x: torch.Tensor, W: torch.Tensor) -> torch.Tensor:
    """
    Linear projection using einsum.

    Shapes:
    - x: (B, T, D)
    - W: (D, H)
    Return:
    - y: (B, T, H)
    """
    return torch.einsum("btd,dh->bth", x, W)

x = torch.randn(2, 3, 4)
W = torch.randn(4, 5)
y = einsum_linear_btd_dh_to_bth(x, W)
x.shape, W.shape, y.shape

(torch.Size([2, 3, 4]), torch.Size([4, 5]), torch.Size([2, 3, 5]))

In [78]:
def einsum_pairwise_dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Pairwise dot product between x and y.

    Shapes:
    - x: (B, T, D)
    - y: (B, T, D)
    Return:
    - dots: (B, T) where dots[b,t] = dot(x[b,t], y[b,t])
    """
    return torch.einsum("btd,btd->bt", x, y)

x = torch.tensor([[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]],
                 [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]]])
y = torch.tensor([[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]],
                 [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]]])
dots = einsum_pairwise_dot(x, y)
print(x.shape, y.shape, dots.shape)
x, y, dots

torch.Size([2, 3, 4]) torch.Size([2, 3, 4]) torch.Size([2, 3])


(tensor([[[ 1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.],
          [ 9., 10., 11., 12.]],
 
         [[ 1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.],
          [ 9., 10., 11., 12.]]]),
 tensor([[[ 1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.],
          [ 9., 10., 11., 12.]],
 
         [[ 1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.],
          [ 9., 10., 11., 12.]]]),
 tensor([[ 30., 174., 446.],
         [ 30., 174., 446.]]))

In [77]:
def einsum_qk_scores(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """
    Compute attention scores QK^T using einsum.

    Shapes:
    - q: (B, H, T, Dh)
    - k: (B, H, T, Dh)
    Return:
    - scores: (B, H, T, T) where scores[b,h,i,j] = dot(q[b,h,i], k[b,h,j])
    """
    return torch.einsum("bhid,bhjd->bhij", q, k)

q = torch.randn(2, 4, 3, 5)
k = torch.randn(2, 4, 3, 5)
scores = einsum_qk_scores(q, k)
q.shape, k.shape, scores.shape

(torch.Size([2, 4, 3, 5]), torch.Size([2, 4, 3, 5]), torch.Size([2, 4, 3, 3]))

In [79]:
def einsum_apply_attention(weights: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """
    Apply attention weights to values using einsum.

    Shapes:
    - weights: (B, H, T, T)
    - v:       (B, H, T, Dh)
    Return:
    - out:     (B, H, T, Dh) where out[b,h,i] = sum_j weights[b,h,i,j] * v[b,h,j]
    """
    return torch.einsum("bhij,bhjd->bhid", weights, v)

weights = torch.randn(2, 4, 3, 3)
v = torch.randn(2, 4, 3, 5)
out = einsum_apply_attention(weights, v)
weights.shape, v.shape, out.shape

(torch.Size([2, 4, 3, 3]), torch.Size([2, 4, 3, 5]), torch.Size([2, 4, 3, 5]))

## Attention Fundamentals
This exercise introduces some building blocks of the attention mechanism which we will encounter extensively throughout the course. It's not yet required for you to fully understand the mechanism to implement the exercises. However, it's good to remember these building blocks for the future. 

To complete the exercises you should familiarize yourself with these topics:
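- Stable softmax (recommended reading): https://jaykmody.com/blog/stable-softmax/
- Masking: typically this means setting masked logits to -inf (or a large negative value) *before* the softmax.
- For attention: with the convention used below (True = not allowed), causal masks are strictly upper-triangular, so no position can attend to the future.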
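A short sketch of why the max-subtraction trick matters. In float32, `exp(1000)` overflows to `inf`, so the naive formula yields `inf / inf = nan`. Softmax is unchanged by subtracting a constant from every logit, so shifting by the max gives the correct result. The last lines show that a `-inf` logit receives exactly zero probability.

```python
import torch

x = torch.tensor([1000.0, 1001.0, 1002.0])

naive = torch.exp(x) / torch.exp(x).sum()    # exp overflows to inf -> inf / inf = nan
shifted = x - x.max()                        # largest logit becomes 0, so exp(...) <= 1
stable = torch.exp(shifted) / torch.exp(shifted).sum()

print(naive)                                           # tensor([nan, nan, nan])
print(torch.allclose(stable, torch.softmax(x, dim=0)))  # True

logits = torch.tensor([2.0, float("-inf"), 1.0])       # middle position masked
p = torch.softmax(logits, dim=0)
print(p[1].item())                                     # 0.0
```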

In [80]:
def stable_softmax(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Numerically stable softmax along `dim`.

    Requirements:
    - Must not overflow for large values in x.
    - Output sums to 1 along `dim`.
    """
    x = x - x.max(dim=dim, keepdim=True).values
    exp_x = torch.exp(x)
    return exp_x / exp_x.sum(dim=dim, keepdim=True)

x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
softmax_x = stable_softmax(x, dim=1)
x, softmax_x

(tensor([[1., 2., 3.],
         [4., 5., 6.]]),
 tensor([[0.0900, 0.2447, 0.6652],
         [0.0900, 0.2447, 0.6652]]))

In [81]:
def masked_fill_tensor(x: torch.Tensor, mask: torch.Tensor, value: float) -> torch.Tensor:
    """
    Return a copy of x where positions with mask == True are replaced by `value`.
    
    Requirements:
    - mask must be broadcastable to x.
    - do NOT modify x in-place.
    """
    return x.masked_fill(mask, value)  # out-of-place: returns a new tensor and leaves x untouched

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
mask = torch.tensor([[True, False], [False, True]])
value = -1.0
masked_filled = masked_fill_tensor(x, mask, value)
x, mask, masked_filled

(tensor([[1., 2.],
         [3., 4.]]),
 tensor([[ True, False],
         [False,  True]]),
 tensor([[-1.,  2.],
         [ 3., -1.]]))

In [82]:
def masked_softmax(x: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Softmax over x with a boolean mask.

    Convention:
    - mask == True means "invalid and must receive probability 0".
    - Do masking before softmax (i.e., set invalid logits to -inf or a large negative value).

    Requirements:
    - Must be numerically stable.
    - Output must be exactly 0 where mask==True.
    - If all entries are masked along `dim`, return all zeros along `dim`.
    - You may reuse functions you implemented above.
    """
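    # Masked logits become -inf, so exp() maps them to exactly 0 after the max shift.
    x = masked_fill_tensor(x, mask, float('-inf'))
    out = stable_softmax(x, dim=dim)
    # A fully masked slice has max = -inf, and (-inf) - (-inf) = nan; map those nans to 0.
    return out.nan_to_num(0.0)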

x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
mask = torch.tensor([[False, True, False], [True, False, True]])
masked_softmax_x = masked_softmax(x, mask, dim=1)
x, mask, masked_softmax_x

(tensor([[1., 2., 3.],
         [4., 5., 6.]]),
 tensor([[False,  True, False],
         [ True, False,  True]]),
 tensor([[0.1192, 0.0000, 0.8808],
         [0.0000, 1.0000, 0.0000]]))

In [86]:
def make_causal_mask(T: int, device: torch.device | str | None = None) -> torch.Tensor:
    """
    Create a causal (future-masking) boolean mask of shape (T, T).

    Convention:
    - mask[i, j] == True  => position (i attends to j) is NOT allowed (j is in the future)
    - mask[i, j] == False => allowed

    So this is a strictly upper-triangular mask (True only above the diagonal).

    Return:
    - mask: boolean tensor on the specified device

    Example (T=4):
        [[F, T, T, T],
         [F, F, T, T],
         [F, F, F, T],
         [F, F, F, F]]
    """
    return torch.triu(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=1)

T = 4
causal_mask = make_causal_mask(T)
print(causal_mask)

tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])


In [91]:
def apply_causal_mask(attn_logits: torch.Tensor, value: float = -1e9) -> torch.Tensor:
    """
    Apply a causal mask to attention logits.

    Expected shapes:
    - attn_logits: (..., T, T)

    Returns:
    - masked logits (same shape) where masked positions have been set to `value`.

    Notes:
    - Create a causal mask for the final two dims.
    - Broadcast it across leading dims.
    - You may reuse functions declared above.
    """
    T = attn_logits.shape[-1]
    mask = make_causal_mask(T, device=attn_logits.device)
    return masked_fill_tensor(attn_logits, mask, value)

attn_logits = torch.randn(2, 4, 3, 3)
masked_attn_logits = apply_causal_mask(attn_logits, value=float('-inf'))
attn_logits.shape, masked_attn_logits.shape
attn_logits, masked_attn_logits

(tensor([[[[ 0.3327,  0.0627, -2.0045],
           [-0.8967, -1.1738, -1.0775],
           [ 0.0554,  0.5573, -0.1508]],
 
          [[ 0.9447,  2.0411,  0.6932],
           [-0.8074, -1.3391, -0.7592],
           [-1.0000,  2.2931, -1.4080]],
 
          [[ 0.5483, -1.2771, -1.0473],
           [ 0.0253,  0.8101,  1.6036],
           [ 0.9964, -0.4209,  0.0042]],
 
          [[ 0.7538,  0.3938,  1.7975],
           [-0.9047, -0.8821,  0.0347],
           [-0.5886, -0.1708, -0.6585]]],
 
 
         [[[-0.4928,  2.1953,  0.7555],
           [ 0.9967, -0.2279,  2.0150],
           [ 1.1887,  1.7528, -0.5652]],
 
          [[ 0.7452,  0.9624, -1.0270],
           [ 0.2820, -0.6895, -0.8664],
           [-0.0208, -0.4424, -1.3617]],
 
          [[ 2.6034, -0.7616,  0.6102],
           [ 0.3048,  0.4401,  1.3703],
           [-0.0659, -0.3316,  0.9226]],
 
          [[ 1.0229,  0.4207,  1.9750],
           [ 0.2810,  0.3507, -1.3199],
           [-0.9578, -0.1753,  1.2907]]]]),
 tensor([[[[