In [22]:
import torch
import torch.nn as nn
import torch.cuda.nvtx as nvtx
from einops import rearrange, einsum

In [23]:
import sys
sys.path.append("../")

from cs336_basics.cs336_basics.model import BasicsTransformerLM
from cs336_basics.cs336_basics.optimizer import get_cosine_lr, AdamW
from cs336_basics.cs336_basics.data import get_batch
from cs336_basics.cs336_basics.nn_utils import cross_entropy

In [None]:
# Accumulation-precision experiment (currently disabled: every line is commented out,
# so running this cell prints nothing; the recorded output below it comes from an
# earlier run with the code uncommented).
#
# Each loop adds 0.01 to an accumulator 1000 times and prints the result, varying
# the dtype of the accumulator and of the increment:
#   1. float32 accumulator, float32 increment
#   2. float16 accumulator, float16 increment
#   3. float32 accumulator, float16 increment (added in place, no explicit cast)
#   4. float32 accumulator, float16 increment explicitly cast to float32 first
# In the recorded output, cases 3 and 4 print the same value.
#
# s = torch.tensor(0,dtype=torch.float32)
# for i in range(1000):
#     s += torch.tensor(0.01,dtype=torch.float32)
# print(s)
# s = torch.tensor(0,dtype=torch.float16)
# for i in range(1000):
#     s += torch.tensor(0.01,dtype=torch.float16)
# print(s)
# s = torch.tensor(0,dtype=torch.float32)
# for i in range(1000):
#     s += torch.tensor(0.01,dtype=torch.float16)
# print(s)
# s = torch.tensor(0,dtype=torch.float32)
# for i in range(1000):
#     x = torch.tensor(0.01,dtype=torch.float16)
#     s += x.type(torch.float32)
# print(s)

tensor(10.0001)
tensor(9.9531, dtype=torch.float16)
tensor(10.0021)
tensor(10.0021)


In [78]:
class ToyModel(nn.Module):
    """Small Linear -> ReLU -> LayerNorm -> Linear network for dtype inspection.

    The forward pass prints the dtype of the activation after the first
    linear+ReLU stage and after the LayerNorm, which makes it easy to see
    which ops run in reduced precision under `torch.autocast`.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        hidden_features: Width of the hidden layer (and of the LayerNorm).
            Defaults to 10, the value that was previously hard-coded.
    """

    def __init__(self, in_features: int, out_features: int, hidden_features: int = 10):
        super().__init__()
        # Sub-modules are created in the same order as before so that seeded
        # parameter initialisation is unchanged for the default hidden width.
        self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
        self.ln = nn.LayerNorm(hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run the network on `x`, printing intermediate activation dtypes.

        Args:
            x: Input tensor whose last dimension is `in_features`.

        Returns:
            The output of the second linear layer (last dimension `out_features`).
        """
        x = self.relu(self.fc1(x))
        print(f"fc1: {x.dtype}")
        x = self.ln(x)
        print(f"ln: {x.dtype}")
        x = self.fc2(x)
        return x

In [80]:
# Mixed-precision dtype probe: run one forward/backward pass of ToyModel under
# bfloat16 autocast and print the dtypes of parameters, activations, logits,
# loss and gradients.
model: nn.Module = ToyModel(in_features=10, out_features=10)
# NOTE(review): `dtype` is not used anywhere else in this cell.
dtype: torch.dtype = torch.float32
# Fall back to CPU so the cell still runs on a machine without a GPU.
device: str = "cuda" if torch.cuda.is_available() else "cpu"
x: torch.Tensor = torch.rand((10, 10))
target: torch.Tensor = torch.randint(0, 2, (10,))

criterion = cross_entropy
# NOTE(review): the optimizer is constructed but never stepped in this cell.
optimizer = AdamW(model.parameters())
# Build the scaler for the device actually in use (matches the autocast below).
scaler = torch.amp.GradScaler(device)

model = model.to(device)
x = x.to(device)
target = target.to(device)

# Parameter dtypes before the forward pass. This loop reads `param.dtype`,
# so it is labelled "param dtype" (it was previously mislabelled "grad dtype").
for name, param in model.named_parameters():
    print(f"{name} param dtype: {param.dtype}")
print("\n")


with torch.autocast(dtype=torch.bfloat16, device_type=device):
    # ToyModel.forward prints the activation dtypes after fc1 and after ln.
    logits = model(x)
    print(f"logits type: {logits.dtype}")
    loss = criterion(logits, target)
    print(f"loss type: {loss.dtype}")
# Backward runs outside the autocast context, on the scaled loss.
scaler.scale(loss).backward()
print(f"loss type after grad_scale: {loss.dtype}")
print("\n")

# Gradient dtypes after the backward pass.
for name, param in model.named_parameters():
    if param.grad is not None:
        print(f"{name} grad dtype: {param.grad.dtype}")

fc1.weight grad dtype: torch.float32
ln.weight grad dtype: torch.float32
ln.bias grad dtype: torch.float32
fc2.weight grad dtype: torch.float32


fc1: torch.bfloat16
ln: torch.float32
logits type: torch.bfloat16
loss type: torch.float32
loss type after grad_scale: torch.float32


fc1.weight grad dtype: torch.float32
ln.weight grad dtype: torch.float32
ln.bias grad dtype: torch.float32
fc2.weight grad dtype: torch.float32


In [51]:
from jaxtyping import Float, Bool, Int
from torch import Tensor
import torch.nn as nn
import math
import torch.nn.functional as F
import torch
import timeit

import sys
sys.path.append("../")

from cs336_basics.cs336_basics.nn_utils import *
from cs336_basics.cs336_basics.optimizer import AdamW
from cs336_basics.cs336_basics.model import *

In [44]:
class CausalMultiHeadSelfAttention(nn.Module):
    """Attention over caller-supplied Q/K/V followed by an output projection.

    This version owns only the output projection: the query, key and value
    tensors are passed in directly rather than being projected from an input
    sequence. Head splitting, token positions / RoPE and causal-mask
    construction are not performed here.

    NOTE(review): no mask is passed to `scaled_dot_product_attention`, so
    whether the attention is actually causal depends on that helper — confirm.

    Args:
        d_model: int
            Input and output width of the output projection.
    """

    def __init__(
        self,
        d_model: int,
    ):
        super().__init__()
        self.d_model = d_model
        self.output_proj = Linear(d_model, d_model)

    def forward(self, Q, K, V) -> Float[Tensor, " ... seq d_v"]:
        """Attend with the given queries, keys and values, then project.

        Args:
            Q: Query tensor, forwarded unchanged to `scaled_dot_product_attention`.
            K: Key tensor, forwarded unchanged to `scaled_dot_product_attention`.
            V: Value tensor, forwarded unchanged to `scaled_dot_product_attention`.

        Returns:
            The attention result passed through `output_proj`.
        """
        # assumes the last dimension of the attention result equals d_model,
        # since it feeds a Linear(d_model, d_model) — TODO confirm against callers
        attended = scaled_dot_product_attention(Q=Q, K=K, V=V)
        return self.output_proj(attended)

In [45]:
# Instantiate the attention module with a model width of 16.
a = CausalMultiHeadSelfAttention(d_model=16)