In [1]:
import itertools
import math

import torch
from torch import nn

import cs336_basics
from cs336_basics.data import get_batch
from cs336_basics.model import BasicsTransformerLM
from cs336_basics.nn_utils import cross_entropy
from cs336_basics.optimizer import AdamW

In [None]:
from torch import Tensor
from jaxtyping import Float, Bool, Int
from cs336_basics.nn_utils import softmax
from cs336_basics.model import Linear

def scaled_dot_product_attention(
    Q: Float[Tensor, " ... queries d_k"],
    K: Float[Tensor, " ... keys    d_k"],
    V: Float[Tensor, " ... keys    d_v"],
    mask: Bool[Tensor, " ... queries keys"] | None = None,
) -> Float[Tensor, " ... queries d_v"]:
    """Scaled dot-product attention.

    This function implements Eq. 1 of the Transformer paper:
    softmax(Q K^T / sqrt(d_k)) V.

    Args:
        Q: Tensor of queries, may have any number of leading dimensions.
        K: Tensor of keys, sharing leading dimensions with Q.
        V: Tensor of values, sharing leading dimensions with Q and K.
        mask: An (optional) boolean mask of shape (..., queries, keys).
            Attention scores for positions with a mask value of `False` are
            set to -inf before the softmax, so they do not affect the
            softmaxed attention probabilities.

    Returns:
        Tensor of shape (..., queries, d_v) with the attention-weighted
        combination of the value vectors.
    """

    d_k = K.shape[-1]

    # (..., queries, d_k) @ (..., d_k, keys) -> (..., queries, keys).
    # Plain torch.matmul is used so the function needs nothing beyond torch
    # and math; leading dimensions broadcast as usual.
    attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        # Keep scores where mask is True, replace the rest with -inf.
        attention_scores = torch.where(mask, attention_scores, float("-inf"))

    attention_weights = softmax(attention_scores, dim=-1)  # Softmax over the key dimension

    # (..., queries, keys) @ (..., keys, d_v) -> (..., queries, d_v)
    return torch.matmul(attention_weights, V)


class AttentionWithLinear(nn.Module):
    """Single-head causal self-attention over fixed random Q/K/V tensors.

    The module draws its own query, key and value tensors at construction
    time, so `forward` takes no inputs.

    Args:
        batch_size: Batch size of the random Q/K/V tensors.
        seq_len: Length of sequence.
        d_model: Head embedding size.
        vocab_size: Output width of the `lm_head` projection.
        device: Device on which Q, K and V are allocated.

    Returns:
        Tensor of shape `(batch_size, sequence_length, d_model)`.
    """

    def __init__(
        self,
        batch_size,
        seq_len,
        d_model,
        vocab_size,
        device,
    ):
        super().__init__()
        # Remembered so forward() can build the causal mask without relying
        # on a notebook-level global of the same name.
        self.seq_len = seq_len
        self.Q = torch.randn((batch_size, seq_len, d_model), device=device)
        self.K = torch.randn((batch_size, seq_len, d_model), device=device)
        self.V = torch.randn((batch_size, seq_len, d_model), device=device)
        # NOTE(review): lm_head is constructed but never applied in forward();
        # confirm whether the attention output is meant to be projected.
        self.lm_head = Linear(d_model, vocab_size)

    def forward(self) -> Float[Tensor, " ... seq d_v"]:
        """Run causally-masked attention over the stored Q, K and V."""
        # Construct causal mask: True on and below the diagonal. It must be
        # boolean (it is used as a torch.where condition) and live on the
        # same device as the attention scores.
        causal_mask = torch.tril(
            torch.ones(self.seq_len, self.seq_len, device=self.Q.device)
        ).bool()

        # Shape: (batch_size, sequence_length, d_model)
        attn_output = scaled_dot_product_attention(
            Q=self.Q, K=self.K, V=self.V, mask=causal_mask
        )

        return attn_output

In [17]:
import torch

# Build a 3x3 integer matrix holding 1..9 in row-major order
matrix = torch.arange(1, 10).reshape(3, 3)

# With the default diagonal offset of 0, tril keeps the main diagonal and
# everything below it, and zeroes the entries above it
lower_triangular_matrix = matrix.tril()
print("Lower triangular matrix (offset=0):\n", lower_triangular_matrix)

Lower triangular matrix (offset=0):
 tensor([[1, 0, 0],
        [4, 5, 0],
        [7, 8, 9]])


In [None]:

BATCH_SIZE = 8
d_models = [16, 32, 64, 128]
seq_lens = [256, 1024, 4096, 8192, 16384]

# Prefer the GPU, but fall back to the CPU so the cell still runs on a
# machine without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Sweep every (d_model, seq_len) combination. zip() would pair the two lists
# element-wise and, because they differ in length, silently drop the last
# sequence length.
for d_model, seq_len in itertools.product(d_models, seq_lens):
    # Allocate directly on the target device instead of CPU-then-copy.
    Q = torch.randn((BATCH_SIZE, seq_len, d_model), device=DEVICE)
    # Stop after the first configuration (d_model=16, seq_len=256).
    break

In [None]:
# --- Data / batching ---
VOCAB_SIZE = 10_000
CONTEXT_LENGTH = 256
BATCH_SIZE = 4

# --- Transformer architecture ---
NUM_LAYERS = 12
NUM_HEADS = 12
D_MODEL = 768
D_FF = 4 * D_MODEL  # 3072
ROPE_THETA = 10_000

# --- Benchmark settings ---
WARMUP_STEPS = 5


In [None]:

# Instantiate the language model from the hyperparameter constants, then
# place it on the first CUDA device. The trailing semicolon suppresses the
# cell's output.
model = BasicsTransformerLM(
    vocab_size=VOCAB_SIZE,
    context_length=CONTEXT_LENGTH,
    d_model=D_MODEL,
    num_layers=NUM_LAYERS,
    num_heads=NUM_HEADS,
    d_ff=D_FF,
    rope_theta=ROPE_THETA,
)
model.to("cuda:0");

In [None]:
import numpy as np
# To benchmark on real tokens instead, load the tokenized validation split:
# np_file = "../data/ts_valid.npy"
# dataset = np.load(np_file)

# Synthetic token stream: 1024 token ids drawn uniformly from [0, VOCAB_SIZE).
# A seeded generator keeps the data identical across kernel restarts.
dataset = np.random.default_rng(0).integers(0, VOCAB_SIZE, size=1024)


In [None]:
import timeit
from functools import partial

def train_step(do_backward=False):
    """Run one forward pass, optionally followed by backward + optimizer step.

    Reads the notebook-level `dataset`, `model`, `BATCH_SIZE` and
    `CONTEXT_LENGTH`; when `do_backward` is true it also uses the
    notebook-level `optimizer`.
    NOTE(review): `optimizer` must already exist when this is called with
    do_backward=True — confirm it is created in an earlier cell.

    Args:
        do_backward: If true, compute the cross-entropy loss, backpropagate
            and apply an optimizer update after the forward pass.
    """
    inputs, targets = get_batch(dataset, BATCH_SIZE, CONTEXT_LENGTH, "cuda")
    logits = model(inputs)

    if not do_backward:
        # Forward-only: wait for the queued GPU work before returning so
        # wall-clock timing covers the whole step.
        torch.cuda.synchronize()
        return

    optimizer.zero_grad()
    cross_entropy(logits, targets).backward()
    optimizer.step()
    # Wait for the queued GPU work before returning (see above).
    torch.cuda.synchronize()

def run_test(warmup_steps, train_steps, do_backward):
    """Time `train_steps` calls of `train_step` after a warm-up phase.

    Args:
        warmup_steps: Number of untimed steps run first.
        train_steps: Number of timed steps.
        do_backward: Forwarded to `train_step`; selects forward-only versus
            forward + backward + optimizer step.

    Prints the total elapsed wall-clock time for the timed steps.
    """
    train_step_ = partial(train_step, do_backward=do_backward)

    # Warm up with the same configuration that is timed, so the backward and
    # optimizer paths are also exercised before measurement begins.
    for _ in range(warmup_steps):
        train_step_()

    elapsed = timeit.timeit(train_step_, number=train_steps)
    print(f"Time for {train_steps} training step: {elapsed:.6f} seconds")

In [None]:
# train_step looks up `optimizer` as a notebook-level global when
# do_backward=True, so it has to exist before the benchmark runs.
# NOTE(review): assumes AdamW accepts the conventional `lr` keyword — confirm
# against cs336_basics.optimizer.
optimizer = AdamW(model.parameters(), lr=1e-3)

run_test(warmup_steps=WARMUP_STEPS, train_steps=10, do_backward=True)

In [None]:
import torch
from torch import nn

class ToyModel(nn.Module):
    """Small two-layer network: Linear -> ReLU -> LayerNorm -> Linear.

    The hidden width is fixed at 10 and neither linear layer has a bias.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        hidden = 10
        self.fc1 = nn.Linear(in_features, hidden, bias=False)
        self.ln = nn.LayerNorm(hidden)
        self.fc2 = nn.Linear(hidden, out_features, bias=False)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map `x` of shape (..., in_features) to (..., out_features)."""
        return self.fc2(self.ln(self.relu(self.fc1(x))))

INPUT_SIZE = 3
OUTPUT_SIZE = 4
BATCH_SIZE = 4
model = ToyModel(INPUT_SIZE, OUTPUT_SIZE).cuda()

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

loss_fn = nn.MSELoss()

x = torch.randn(BATCH_SIZE, INPUT_SIZE).cuda()
y = torch.randn(BATCH_SIZE, OUTPUT_SIZE).cuda()


DTYPE = torch.bfloat16

# Forward pass and loss under autocast, printing the dtype after each stage.
# The stages follow ToyModel.forward: fc1 -> ReLU -> LayerNorm -> fc2.
with torch.autocast(device_type="cuda", dtype=DTYPE):
    print(f"Input: {x.dtype}")
    x_after_fc1 = model.relu(model.fc1(x))
    print(f"After fc1: {x_after_fc1.dtype}")
    x_after_ln = model.ln(x_after_fc1)
    print(f"After LN: {x_after_ln.dtype}")
    y_pred = model.fc2(x_after_ln)
    print(f"Logits: {y_pred.dtype}")
    # nn.MSELoss takes (input, target).
    loss = loss_fn(y_pred, y)
    print(f"Loss: {loss.dtype}")

# Backward pass and optimizer step run outside the autocast context: PyTorch
# recommends wrapping only the forward pass and the loss computation.
print("="*60)
optimizer.zero_grad()
for name, param in model.named_parameters():
    print(name, param.dtype)
loss.backward()
optimizer.step()
print("="*60)
for name, param in model.named_parameters():
    print(name, param.dtype, param.grad.dtype)


In [None]:
import torch
x = torch.tensor([1e-3], dtype=torch.float16)
rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True))
print(rms)
x = torch.tensor([1e-3], dtype=torch.bfloat16)
rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True))
print(rms)

In [None]:
# Print each parameter's name, its dtype, and the dtype of its gradient.
# A parameter that has not received a gradient yet has grad=None, so guard
# the attribute access and print None in that case.
for name, param in model.named_parameters():
    grad_dtype = param.grad.dtype if param.grad is not None else None
    print(name, param.dtype, grad_dtype)

In [None]:
# Print each parameter's name, its dtype, and the dtype of its gradient.
# A parameter that has not received a gradient yet has grad=None, so guard
# the attribute access and print None in that case.
for name, param in model.named_parameters():
    grad_dtype = param.grad.dtype if param.grad is not None else None
    print(name, param.dtype, grad_dtype)

In [None]:
import sqlite3
import pandas as pd

# Open the profiler's SQLite export; the connection is left open so that it
# can be queried again afterwards.
conn = sqlite3.connect("report-mixed.sqlite")

# Total duration (sum of end - start) of all NVTX ranges labelled
# 'Backward Pass'.
df = pd.read_sql_query(
    con=conn,
    sql="""
    SELECT text, SUM(end - start) as total_ns
    FROM NVTX_EVENTS 
    WHERE text = 'Backward Pass'
    GROUP BY text
""",
)
print(df)

In [None]:
# List the columns of the NVTX_EVENTS table. Each row returned by
# PRAGMA table_info carries the column name at index 1 and its declared
# type at index 2.
cursor = conn.cursor()
columns = cursor.execute("PRAGMA table_info(NVTX_EVENTS);").fetchall()
print("NVTX_EVENTS table schema:")
for col in columns:
    col_name, col_type = col[1], col[2]
    print(f"  {col_name} ({col_type})")

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Read the CSV data. The code below reads the columns `size`, `precision`,
# `forward_time_seconds` and `backward_time_seconds`.
df = pd.read_csv('benchmark_results.csv')

# Define size order for proper x-axis ordering
size_order = ['small', 'medium', 'large', 'xl', '2.7b']
df['size'] = pd.Categorical(df['size'], categories=size_order, ordered=True)


def _plot_pass_times(ax, frame, column, title, sizes):
    """Plot `column` against model size on `ax`, one line per precision.

    Points are placed at the integer position of their size within `sizes`,
    so the x-axis always follows `sizes` regardless of the row order in the
    CSV or of which sizes a given precision has measurements for.
    """
    for precision in ['full', 'mixed']:
        data = (
            frame[frame['precision'] == precision]
            .dropna(subset=['size'])  # sizes outside `sizes` became NaN
            .sort_values('size')
        )
        ax.plot(data['size'].cat.codes, data[column],
                marker='o', linewidth=2, label=f'{precision} precision')

    ax.set_xticks(range(len(sizes)))
    ax.set_xticklabels(sizes)
    ax.set_title(title)
    ax.set_xlabel('Model Size')
    ax.set_ylabel('Time (seconds)')
    ax.legend()
    ax.grid(True, alpha=0.3)


def _print_speedups(frame, column, heading, sizes):
    """Print full/mixed time ratios for `column`, one line per model size.

    A size lacking a measurement for either precision is reported as such
    instead of raising IndexError.
    """
    print(heading)
    for size in sizes:
        rows = frame[frame['size'] == size]
        full_times = rows.loc[rows['precision'] == 'full', column]
        mixed_times = rows.loc[rows['precision'] == 'mixed', column]
        if full_times.empty or mixed_times.empty:
            print(f"{size}: n/a (missing measurement)")
            continue
        speedup = full_times.iloc[0] / mixed_times.iloc[0]
        print(f"{size}: {speedup:.2f}x {'speedup' if speedup > 1 else 'slowdown'}")


# Create the plot
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

_plot_pass_times(ax1, df, 'forward_time_seconds',
                 'Forward Pass Time vs Model Size', size_order)
_plot_pass_times(ax2, df, 'backward_time_seconds',
                 'Backward Pass Time vs Model Size', size_order)

plt.tight_layout()
plt.savefig('benchmark_timing_plot.png', dpi=300, bbox_inches='tight')
plt.show()

# Print some analysis
_print_speedups(df, 'forward_time_seconds',
                "Forward Pass Speedup with Mixed Precision:", size_order)
_print_speedups(df, 'backward_time_seconds',
                "\nBackward Pass Speedup with Mixed Precision:", size_order)