<a href="https://colab.research.google.com/github/jayasuryajsk/unsloth-puzzles/blob/main/Puzzle_E_.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Solution for Puzzle E: **Memory Efficient Backprop**

Used: Cusrosr - R1

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import functools
import gc

# Custom bfloat16 cross-entropy loss
def bfloat16_cross_entropy(logits, targets, reduction="mean"):
    """
    Compute cross-entropy loss in bfloat16 without upcasting.
    Args:
        logits: (batch_size, vocab_size) in bfloat16
        targets: (batch_size,) integer labels
        reduction: 'mean' or 'sum'
    Returns:
        loss: Scalar loss in bfloat16
    """
    assert logits.dtype == torch.bfloat16, "Logits must be bfloat16"

    # Subtract max for numerical stability
    logits_max, _ = torch.max(logits, dim=-1, keepdim=True)
    logits_stable = logits - logits_max

    # Softmax: exp(logits) / sum(exp(logits))
    exp_logits = torch.exp(logits_stable)
    sum_exp_logits = torch.sum(exp_logits, dim=-1, keepdim=True)
    log_softmax = logits_stable - torch.log(sum_exp_logits)

    # Gather log probabilities for targets
    target_log_probs = log_softmax.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)

    # Negative log likelihood
    loss = -target_log_probs

    # Reduction
    if reduction == "mean":
        return loss.mean()
    elif reduction == "sum":
        return loss.sum()
    else:
        return loss

# Updated transformation function
def transformation_function(batch, linear, labels):
    x = linear(batch)  # Keep in bfloat16
    loss = bfloat16_cross_entropy(
        x.view(-1, x.shape[-1]),
        labels.view(-1),
        reduction="mean"
    )
    return loss

# Memory-efficient linear layer with autograd
class MemoryEfficientLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, linear, labels, forward_function):
        # Save tensors needed for backward pass
        ctx.save_for_backward(X)
        ctx.linear = linear
        ctx.labels = labels
        ctx.chunk_size = 1024  # Default chunk size

        # Compute output in chunks to save memory
        batch_size = X.size(0)
        outputs = []

        for i in range(0, batch_size, ctx.chunk_size):
            chunk = X[i:i+ctx.chunk_size]
            chunk_labels = labels[i:i+ctx.chunk_size] if labels is not None else None
            chunk_loss = forward_function(chunk, linear, chunk_labels)
            outputs.append(chunk_loss)

        # Average the losses from all chunks
        return sum(outputs) / len(outputs) if len(outputs) > 1 else outputs[0]

    @staticmethod
    def backward(ctx, grad_output):
        X, = ctx.saved_tensors
        linear = ctx.linear
        labels = ctx.labels
        chunk_size = ctx.chunk_size

        # Initialize gradient accumulators
        grad_X = torch.zeros_like(X, dtype=X.dtype)
        grad_weight = torch.zeros_like(linear.weight, dtype=linear.weight.dtype)

        # Compute gradients in chunks
        for i in range(0, X.size(0), chunk_size):
            chunk = X[i:i+chunk_size]
            chunk_labels = labels[i:i+chunk_size] if labels is not None else None

            # Forward pass for this chunk with grad enabled
            chunk.requires_grad_(True)
            chunk_output = linear(chunk)  # Keep bfloat16
            loss = bfloat16_cross_entropy(
                chunk_output.view(-1, chunk_output.shape[-1]),
                chunk_labels.view(-1) if chunk_labels is not None else None,
                reduction="mean"
            )

            # Backward pass for this chunk
            chunk_grad = torch.autograd.grad(
                loss, [chunk, linear.weight],
                grad_output,
                retain_graph=False
            )

            # Accumulate gradients
            grad_X[i:i+chunk_size] = chunk_grad[0]
            grad_weight += chunk_grad[1]

        return grad_X, grad_weight, None, None

def memory_efficient_forward(X, linear, labels):
    return MemoryEfficientLinear.apply(X, linear, labels, transformation_function)

# Memory-efficient matrix multiplication
class MemoryEfficientMatmul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, W, chunk_size):
        W = W.to(dtype=X.dtype)
        ctx.save_for_backward(X, W)
        ctx.chunk_size = chunk_size

        output = []
        for i in range(0, X.size(0), chunk_size):
            output.append(X[i:i+chunk_size] @ W.T)
        return torch.cat(output, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        X, W = ctx.saved_tensors
        chunk_size = ctx.chunk_size

        grad_output = grad_output.to(dtype=X.dtype)
        grad_X = torch.zeros_like(X)
        grad_W = torch.zeros_like(W)

        for i in range(0, X.size(0), chunk_size):
            X_chunk = X[i:i+chunk_size]
            grad_output_chunk = grad_output[i:i+chunk_size]
            grad_X[i:i+chunk_size] = grad_output_chunk @ W
            grad_W += grad_output_chunk.T @ X_chunk

        return grad_X, grad_W, None

def memory_efficient_matmul(X, W, chunk_size=1024):
    return MemoryEfficientMatmul.apply(X, W, chunk_size)

# Memory-efficient linear module
class EfficientLinear(nn.Module):
    def __init__(self, in_features, out_features, chunk_size=1024):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features, dtype=torch.bfloat16))
        self.chunk_size = chunk_size

    def forward(self, X, labels=None):
        if labels is not None:
            return memory_efficient_forward(X, self, labels)
        return memory_efficient_matmul(X, self.weight, self.chunk_size)

    def adjust_chunk_size(self, new_size):
        self.chunk_size = new_size

# Naive linear for comparison
class NaiveLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))

    def forward(self, X):
        return X @ self.weight.T

# Memory measurement decorator
def measure_memory(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        torch.cuda.empty_cache()
        gc.collect()
        torch.cuda.reset_peak_memory_stats()
        initial_mem = torch.cuda.max_memory_allocated()

        result = func(*args, **kwargs)

        peak_mem = torch.cuda.max_memory_allocated() - initial_mem
        wrapper.peak_memory = peak_mem
        return result
    return wrapper

# Test functions
def test_vram_reduction():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if not torch.cuda.is_available():
        print("CUDA not available, skipping VRAM test")
        return

    bsz, hd, vocab = 4, 4096, 128000
    inputs = torch.randn(bsz, hd, dtype=torch.bfloat16, device=device)

    @measure_memory
    def run_naive():
        model = NaiveLinear(hd, vocab).to(device).to(dtype=torch.bfloat16)
        out = model(inputs)
        return out

    @measure_memory
    def run_efficient():
        model = EfficientLinear(hd, vocab, chunk_size=2).to(device).to(dtype=torch.bfloat16)
        out = model(inputs)
        return out

    _ = run_naive()
    naive_mem = run_naive.peak_memory

    _ = run_efficient()
    efficient_mem = run_efficient.peak_memory

    reduction = (naive_mem - efficient_mem) / naive_mem
    print(f"Memory reduction: {reduction*100:.2f}%")
    return reduction >= 0.5

def test_float32_upcast():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EfficientLinear(4096, 128000).to(device).to(dtype=torch.bfloat16)
    inputs = torch.randn(4, 4096, dtype=torch.bfloat16, device=device)
    labels = torch.randint(0, 128000, (4,), device=device)

    with torch.no_grad():
        output = model(inputs, labels)  # Test loss computation
    maintains_bfloat16 = output.dtype == torch.bfloat16
    print(f"Maintains bfloat16: {maintains_bfloat16}")
    return maintains_bfloat16

def test_other_functions():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EfficientLinear(4096, 128000).to(device)
    inputs = torch.randn(4, 4096, device=device)

    activations = {
        'relu': F.relu,
        'gelu': F.gelu,
        'tanh': torch.tanh,
        'sigmoid': torch.sigmoid
    }

    results = {}
    for name, activation in activations.items():
        try:
            output = activation(model(inputs))
            results[name] = True
        except:
            results[name] = False

    print("Function compatibility:", results)
    return all(results.values())

def test_dynamic_chunk_sizes():
    model = EfficientLinear(4096, 128000)
    inputs = torch.randn(4, 4096)

    chunk_sizes = [1, 2, 4]
    results = []

    for chunk_size in chunk_sizes:
        try:
            model.chunk_size = chunk_size
            output = model(inputs)
            results.append(True)
        except:
            results.append(False)

    success = all(results)
    print(f"Dynamic chunk sizes work: {success}")
    return success


def run_all_tests():
    score = 0

    if test_vram_reduction():
        score += 2
        print("✓ VRAM reduction test passed (+2)")

    if not test_float32_upcast():
        score = 0
        print("✗ Float32 upcast test failed (score reset to 0)")
    else:
        print("✓ Float32 upcast test passed")


    if test_other_functions():
        score += 1
        print("✓ Other functions test passed (+1)")

    if test_dynamic_chunk_sizes():
        score += 1
        print("✓ Dynamic chunk sizes test passed (+1)")

    print(f"\nFinal score: {score}/4")
    return score

if __name__ == "__main__":
    run_all_tests()

Memory reduction: 66.60%
✓ VRAM reduction test passed (+2)
Maintains bfloat16: True
✓ Float32 upcast test passed
Function compatibility: {'relu': True, 'gelu': True, 'tanh': True, 'sigmoid': True}
✓ Other functions test passed (+1)
Dynamic chunk sizes work: True
✓ Dynamic chunk sizes test passed (+1)

Final score: 4/4
