In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# I just found out that you don't need a custom autograd function
# (e.g., a MemoryEfficientLinear class).
# The reason we needed a custom autograd function was to let it generalize to other loss functions.
# Here I achieve the same thing by accepting a loss_fn,
# so it generalizes to any loss function that accepts its inputs in the same style.
def chunk_function(batch, linear, labels, loss_fn, n_chunks=1):
    """Evaluate ``loss_fn`` on ``linear(batch)`` chunk by chunk along dim 0.

    ``batch`` and ``labels`` are split with ``torch.chunk``; each chunk is
    projected by ``linear``, cast to float32, flattened to ``(-1, last_dim)``
    and passed to ``loss_fn`` together with the flattened label chunk.

    Args:
        batch: Input tensor; split along its first dimension.
        linear: Callable applied to every input chunk (e.g. ``nn.Linear``).
        labels: Target tensor; split along its first dimension in lockstep
            with ``batch``.
        loss_fn: Callable ``loss_fn(logits_2d, targets_1d)`` returning a
            scalar. The per-chunk results are combined as a weighted
            average, so this assumes a mean-reduced loss -- TODO confirm for
            losses that skip targets (e.g. ``ignore_index``) or use class
            weights.
        n_chunks: Requested number of chunks.

    Returns:
        The combined loss (a tensor; the int ``0`` if no chunk was produced).
    """
    batch_chunks = torch.chunk(batch, n_chunks)
    labels_chunks = torch.chunk(labels, n_chunks)
    total_targets = labels.numel()
    total_loss = 0

    for batch_chunk, labels_chunk in zip(batch_chunks, labels_chunks):
        x = linear(batch_chunk).float()
        loss = loss_fn(x.view(-1, x.shape[-1]), labels_chunk.view(-1))
        # torch.chunk may return fewer than n_chunks pieces and a smaller
        # last piece, so weight each chunk by its share of the targets
        # instead of dividing by n_chunks. For equally sized chunks this is
        # the same 1 / n_chunks factor as before.
        # Guard against 0 / 0 when labels is empty.
        weight = labels_chunk.numel() / total_targets if total_targets else 1.0
        total_loss = total_loss + loss * weight
    return total_loss

def test_small_case():
    """Compare chunked and unchunked loss/gradient on a tiny CPU problem.

    For each loss function the reference value is computed on the whole
    batch and compared (forward value and gradient w.r.t. ``X``) against
    ``chunk_function``. Results are printed rather than asserted.
    """
    torch.manual_seed(0)
    batch_size, seq_len, hidden_dim, vocab_size = 2, 4, 8, 16

    X = torch.randn(batch_size, seq_len, hidden_dim, requires_grad=True)
    labels = torch.randint(0, vocab_size, (batch_size, seq_len))
    linear = nn.Linear(hidden_dim, vocab_size)

    loss_fns = [F.cross_entropy, F.multi_margin_loss]

    for loss_fn in loss_fns:
        print(f"Testing with {loss_fn.__name__}")

        out1 = loss_fn(linear(X).float().view(-1, vocab_size), labels.view(-1))
        grad1 = torch.autograd.grad(out1, X)[0]

        # Split into one chunk per batch row. With the default n_chunks=1
        # nothing is split and the comparison below is trivially true.
        out2 = chunk_function(X, linear, labels, loss_fn, n_chunks=batch_size)
        grad2 = torch.autograd.grad(out2, X)[0]

        print("Forward pass matches:", torch.allclose(out1, out2))
        print("Backward pass matches:", torch.allclose(grad1, grad2))

def test_large_case():
    """Run the unchunked and the chunked loss on a large CUDA problem.

    The unchunked cross-entropy is attempted first and is expected to run
    out of memory; the chunked version is then run with 4 chunks. Only the
    forward pass is exercised. Skipped with a message when no CUDA device
    is available.

    Raises:
        RuntimeError: Any non-OOM error from the unchunked attempt, or any
            error from the chunked run.
    """
    if not torch.cuda.is_available():
        print("CUDA not available; skipping large case")
        return

    batch_size, seq_len, hidden_dim, vocab_size = 4, 4096, 4096, 128000

    X = torch.randn(batch_size, seq_len, hidden_dim, requires_grad=True, device="cuda")
    labels = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
    linear = nn.Linear(hidden_dim, vocab_size).cuda()

    try:
        # The result is deliberately not kept: holding it (and its autograd
        # graph) would pin GPU memory during the chunked run below.
        F.cross_entropy(linear(X).float().view(-1, vocab_size), labels.view(-1))
        print("Original approach succeeded (unexpected)")
    except RuntimeError as e:
        # Only an out-of-memory failure is expected here; anything else is a
        # real error and must not be swallowed.
        if "out of memory" not in str(e):
            raise
        print("Original approach OOM as expected")

    # Hand cached blocks from the failed/finished attempt back to the driver
    # before the chunked run allocates.
    torch.cuda.empty_cache()

    chunk_function(X, linear, labels, F.cross_entropy, 4)
    print("Memory efficient approach succeeded")

# Run both checks when this cell/script executes. The small case runs on CPU;
# the large case places its tensors and the linear layer on a CUDA device.
test_small_case()
test_large_case()

# This will work with GRPO and Llama because nothing prevents it from doing so.
# Also, we have already validated that the losses match,
# and it already generalizes to other loss functions because it accepts a loss_fn.
# BTW, F.mse_loss and some other losses won't work, since they accept inputs of different shapes;
# even though you could write a custom autograd function, MSE would still require modifications, which prevents generalization.

Testing with cross_entropy
Forward pass matches: True
Backward pass matches: True
Testing with multi_margin_loss
Forward pass matches: True
Backward pass matches: True
Original approach OOM as expected
Memory efficient approach succeeded
