In [None]:
from cs336_systems.ddp_overlap_bucketed import get_buckets
import numpy as np
from cs336_basics.data import get_batch
from cs336_basics.model import BasicsTransformerLM
from cs336_basics.nn_utils import cross_entropy

In [None]:
# Large configuration (d_model=1600, 48 layers, 25 heads) whose parameters are bucketed
# with get_buckets further down. Chaining .to() (nn.Module.to returns the module itself)
# avoids ending the cell with an expression that dumps the full 48-layer module repr.
module = BasicsTransformerLM(
    vocab_size=10_000,
    context_length=512,
    d_model=1600,
    d_ff=6400,
    num_layers=48,
    num_heads=25,
    rope_theta=10_000,
).to("cuda")

In [None]:
# Seed NumPy's global RNG so the synthetic token stream is reproducible (and, presumably,
# get_batch's sampling too if it draws from np.random — TODO confirm in cs336_basics.data).
np.random.seed(0)
# 1,024 random token ids in [0, 10_000), matching the model's vocab_size
dataset = np.random.randint(0, 10_000, 1024)
# presumably (batch_size=4, context_length=512) on the GPU — TODO confirm get_batch's signature
x, y = get_batch(dataset, 4, 512, "cuda")
y_hat = module(x)
loss = cross_entropy(y_hat, y)
loss.backward()  # populates param.grad, inspected in the following cells

In [None]:
# Walk trainable parameters in reverse registration order (the order get_buckets uses)
# and summarize each gradient by shape instead of dumping every gradient tensor of the
# 48-layer model; None marks a parameter that backward() did not reach.
for name, param in reversed(list(module.named_parameters())):
    if param.requires_grad:
        grad = param.grad
        print(name, None if grad is None else tuple(grad.shape))

In [None]:
# `get_buckets` here is the cs336_systems import, unless the local redefinition further
# down has already run in this kernel. Entries are treated as parameters (their .grad is
# read) — TODO confirm against cs336_systems.ddp_overlap_bucketed.
buckets = get_buckets(module, 10.0)
# Report every bucketed parameter that backward() left without a gradient.
for param in [p for group in buckets for p in group]:
    if param.grad is not None:
        continue
    print(param)

In [None]:
# Index 0 along dim 0 of the gradient of the first entry in the last bucket; a bare
# expression uses the notebook's rich display instead of print().
buckets[-1][0].grad[0]

In [None]:
# shape of the first entry in the last bucket (same torch.Size that .size() returns)
buckets[-1][0].shape

In [None]:
# shape of the second entry in the last bucket (same torch.Size that .size() returns)
buckets[-1][1].shape

In [None]:
import torch  # torch is otherwise first imported in a later cell; needed on a fresh kernel

# Minimal autograd example (reuses the names x, y from the batch cell): y = x + 2, dy/dx = 1.
x = torch.tensor(1.0, requires_grad=True)
y = x + 2

In [None]:
# y = x + 2 is 0-dim, so backward() needs no explicit gradient (it uses 1.0); fills x.grad.
y.backward()

In [None]:
# dy/dx for y = x + 2: tensor(1.) after a single backward() call (repeated calls accumulate)
x.grad

In [None]:
# z shares x's data but is detached from the autograd graph (it does not require grad)
z = x.detach()

In [None]:
# A detached tensor is a new tensor object that does not carry x's .grad. A bare `z.grad`
# evaluates to None and displays nothing, so make the check explicit (displays True).
z.grad is None

In [None]:
import torch
import torch.nn as nn

class _FC2(nn.Module):
    """Linear 10 -> 50 layer whose bias is frozen (``requires_grad=False``)."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 50, bias=True)
        # freeze only the bias; the weight stays trainable
        self.fc.bias.requires_grad = False

    def forward(self, x):
        return self.fc(x)

class ToyModel(nn.Module):
    """Small 10 -> 10 -> 50 -> 5 MLP mixing trainable and frozen parameters.

    In this notebook it is passed to ``get_buckets``, which must skip the frozen
    parameters (``fc2.fc.bias`` and ``no_grad_fixed_param``).
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10, bias=False)
        self.fc2 = _FC2()
        self.fc3 = nn.Linear(50, 5, bias=False)
        self.relu = nn.ReLU()
        # a parameter that never needs gradients, to exercise requires_grad filtering
        self.no_grad_fixed_param = nn.Parameter(torch.tensor([2.0, 2.0]), requires_grad=False)

    def forward(self, x):
        """Map (..., 10) inputs to (..., 5) outputs: fc1 -> relu -> fc2 -> relu -> fc3.

        Without this method nn.Module's default forward raises, so the model could not be called.
        """
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)

In [None]:
def get_buckets(model, bucket_size_mb: float) -> list[list[torch.Tensor]]:
    """Group the gradients of ``model``'s trainable parameters into size-capped buckets.

    Parameters are visited in reverse registration order (``reversed(model.parameters())``)
    and packed greedily: a gradient joins the current bucket while the bucket stays within
    the cap; a gradient that alone exceeds the cap gets a bucket of its own.

    Note: this definition shadows the ``get_buckets`` imported from
    ``cs336_systems.ddp_overlap_bucketed`` at the top of the notebook.

    Args:
        model: module whose parameters are bucketed.
        bucket_size_mb: bucket cap in MiB (1 MiB = 1024**2 bytes). Sizes are measured in
            bytes via ``element_size()``, so non-float32 parameters are counted correctly.

    Returns:
        A list of buckets, each a list of ``param.grad`` tensors. Parameters with
        ``requires_grad=False`` are skipped and left untouched. A trainable parameter
        with no gradient yet gets a zero-filled ``.grad`` allocated (existing gradients
        are never overwritten, and gradients never alias the parameter's storage).
    """
    bucket_size_limit = bucket_size_mb * 1024**2  # cap in bytes
    buckets = []
    current_bucket = []
    current_size = 0  # bytes already in current_bucket
    for param in reversed(list(model.parameters())):
        # skip params that do not need gradient synchronization
        if not param.requires_grad:
            continue
        # allocate a separate zero buffer instead of aliasing param.data, which would
        # clobber real gradients and let in-place collectives corrupt the weights
        if param.grad is None:
            param.grad = torch.zeros_like(param)
        grad = param.grad
        size = param.nelement() * param.element_size()
        if size > bucket_size_limit:
            # oversized gradient: close the open bucket, then give it its own bucket
            if current_bucket:
                buckets.append(current_bucket)
                current_bucket = []
                current_size = 0
            buckets.append([grad])
        elif current_size + size > bucket_size_limit:
            # would overflow: close the open bucket and start a new one with this gradient
            buckets.append(current_bucket)
            current_bucket = [grad]
            current_size = size
        else:
            current_bucket.append(grad)
            current_size += size
    # flush the last partially filled bucket
    if current_bucket:
        buckets.append(current_bucket)

    return buckets
        

In [None]:
model = ToyModel()
# Label each flag with its parameter name; the frozen ones (requires_grad=False) are
# no_grad_fixed_param and fc2.fc.bias.
for name, param in model.named_parameters():
    print(name, param.requires_grad)

In [None]:
# Flatten every parameter on its own, i.e. one single-tensor "bucket" per parameter.
buckets_flat = list(map(lambda p: torch._utils._flatten_dense_tensors([p]), model.parameters()))

In [None]:
# class of a flattened buffer
buckets_flat[0].__class__

In [None]:
# Bucket a fresh ToyModel with a 0.004 MB cap, then flatten each bucket into one buffer.
buckets = get_buckets(ToyModel(), 0.004)
buckets_flat = list(map(torch._utils._flatten_dense_tensors, buckets))

In [None]:
# shape of the first tensor in the first bucket (same torch.Size as .shape)
buckets[0][0].size()

In [None]:
# the first bucket flattened into a single buffer
buckets_flat[0]

In [None]:
# Split each flat buffer back into tensors shaped like the entries of its source bucket.
buckets_unflat = list(map(torch._utils._unflatten_dense_tensors, buckets_flat, buckets))

In [None]:
# third tensor recovered from the first flat buffer; with the 0.004 MB cap all three trainable
# ToyModel gradients presumably share bucket 0, making this fc1.weight's — verify
buckets_unflat[0][2]

In [None]:
# Recompute the flat buffers from the current buckets (same computation as the earlier cell).
buckets_flat = list(map(torch._utils._flatten_dense_tensors, buckets))

In [None]:
# the per-tensor (unflattened) contents of the first bucket, for comparison with the round trip
buckets[0]

In [None]:
# private torch helper: split the first flat buffer back into tensors shaped like buckets[0]
torch._utils._unflatten_dense_tensors(buckets_flat[0], buckets[0])

In [None]:
import torch
# v = torch.tensor([0., 0., 0.], requires_grad=True)
x = torch.tensor([0., 0., 0.], requires_grad=True)
v = x * 2
lr = 0.01
# simulate a simple SGD update
h = v.register_post_accumulate_grad_hook(lambda p: p.add_(p.grad, alpha=-lr))
v.backward(torch.tensor([1., 2., 3.]))
v

h.remove()  # removes the hook

In [None]:
import torch

# Leaf inputs of a tiny graph used to explore tensor hooks.
a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)

c = a * b  # non-leaf: c = 6


def c_hook(grad):
    """Print the gradient flowing into `c` and pass on `grad + 2` instead."""
    print(grad)
    return grad + 2


# Tensor hooks run in registration order: c_hook sees the incoming gradient, then the lambda
# prints the value c_hook returned (the lambda returns None, so it leaves the gradient as is).
c.register_hook(c_hook)
c.register_hook(lambda grad: print(grad))
c.retain_grad()  # keep c.grad even though c is not a leaf

d = torch.tensor(4.0, requires_grad=True)
d.register_hook(lambda grad: grad + 100)  # leaf hook: d.grad receives the shifted value

e = c * d

# retain_grad() is made once: a second call on a tensor that already retains grad is a no-op.
e.retain_grad()
e.register_hook(lambda grad: grad * 2)

In [None]:
# Inspect the autograd node that produced `c`.
node = c.grad_fn
for label, value in (
    ("Backward node type:", node),
    ("Operation name:", node.name()),
    ("Next functions:", node.next_functions),
):
    print(label, value)
# Hooks registered with register_hook are listed in the tensor's private _backward_hooks.
print("\nHooks are stored on the tensor, not grad_fn:")
print("Tensor hooks:", c._backward_hooks)

In [None]:
import inspect

# Show the Python source of Tensor.backward (what `??c.backward` displays in IPython)
# using plain Python, so the cell also runs outside IPython.
print(inspect.getsource(c.backward))

In [None]:
# private attribute: the hooks registered on d via register_hook (presumably None if there are none)
d._backward_hooks

In [None]:
# private attribute: the hooks registered on c via register_hook (c_hook and the printing lambda)
c._backward_hooks

In [None]:
# Backprop from c with the implicit gradient 1.0. On a fresh run c_hook prints tensor(1.) and
# passes on tensor(3.), which the second hook prints; a.grad gets 3 * b = 9 and b.grad 3 * a = 6.
c.backward()

In [None]:
# A second backward through the same graph fails: the previous c.backward() did not pass
# retain_graph=True, so the tensors saved for the backward pass were freed. Show the error.
try:
    c.backward()
except RuntimeError as err:
    print(f"RuntimeError: {err}")

In [None]:
# displays the leaf a itself; its accumulated gradient lives in a.grad
a

In [None]:
import torch

# Fresh example (reuses the names a, b, c): c = mean(a**2) over five ones.
a = torch.ones(5, requires_grad=True)
b = a.pow(2)
c = torch.mean(b)
# keep gradients of the non-leaf tensors b and c so they can be printed
b.retain_grad()
c.retain_grad()
c.backward()

print(c.grad, b.grad, a.grad)


In [None]:
# Passing torch.tensor(1.0) is equivalent to the default for a scalar output, but this c was
# already backpropagated above without retain_graph=True, so a second pass through its graph
# raises. Show the error instead of leaving a failing cell.
try:
    c.backward(torch.tensor(1.0))
except RuntimeError as err:
    print(f"RuntimeError: {err}")

In [None]:
# Example 1: Scalar output - gradient argument defaults to 1.0
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x.sum()  # y is scalar: y = x1 + x2 + x3

print(f"y = {y}")
# For a scalar output, y.backward() is equivalent to y.backward(torch.tensor(1.0)).
# Passing torch.tensor(2.0) instead would scale the result to [2, 2, 2].
y.backward()
print(f"∂y/∂x = {x.grad}")  # Should be [1, 1, 1]

In [None]:
# c from the mean(a**2) example above (value 1.), not the earlier a * b product
c

In [None]:
import torch
import time
import os
from typing import List, Callable
import torch.nn.functional as F
import torch.distributed as dist
import torch.distributed.fsdp
from execute_util import text, image, link, system_text
from torch_util import get_device
from lecture_util import article_link
from lecture_08_utils import spawn, int_divide, summarize_tensor, get_init_params, render_duration

In [None]:
from lecture_08 import collective_operations_main, data_parallelism_main, generate_sample_data, tensor_parallelism_main

In [None]:
# Run collective_operations_main on 2 ranks; spawn presumably starts one process per rank
# (see lecture_08_utils.spawn — TODO confirm).
spawn(collective_operations_main, world_size=2)

In [None]:
# sample data shared by the data- and tensor-parallelism runs below (defined in lecture_08)
data = generate_sample_data()

In [None]:
# data parallelism on 2 ranks; num_layers / num_steps presumably size the toy model and the
# number of training steps — TODO confirm in lecture_08.data_parallelism_main
spawn(data_parallelism_main, world_size=2, data=data, num_layers=4, num_steps=1)

In [None]:
# tensor parallelism on 2 ranks over the same sample data (4 layers)
spawn(tensor_parallelism_main, world_size=2, data=data, num_layers=4)

In [None]:
from 