<a href="https://colab.research.google.com/github/byi8220/unsloth-puzzles/blob/main/Problem5/Unsloth_Problem_5_Other_Loss_Functions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Unsloth Problem 5 - Memory Efficient Backprop for other loss functions

#### Ran on a Colab L4 GPU instance with 24 GB VRAM, 53 GB RAM

#### Considerations
Any degree of chunking will introduce more addition ops, which can lead to compounding numerical error. This is more pronounced along the `qlen` dimension.

The selection of `mel_num_chunks` heavily influences memory saved and numerical accuracy. The most stable configuration is `(B,1,1)`, where you only chunk by batch. If you have very large batch sizes, this is sufficient.

---
## Setup
---

In [None]:
# Code to install Unsloth, Triton, Torch etc
%%capture
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29 peft trl triton
!pip install --no-deps cut_cross_entropy unsloth_zoo
!pip install sentencepiece protobuf datasets huggingface_hub hf_transfer
!pip install --no-deps unsloth

In [None]:
# Helpful functions used through the entire notebook
import torch
import torch.nn as nn
from transformers import set_seed
import time
import inspect
import os
# Compute capability of the CUDA device, unpacked as (major, minor).
major_version, minor_version = torch.cuda.get_device_capability()
# True when the compute-capability major version is 8 or higher.
HAS_BFLOAT16 = (major_version >= 8)
from inspect import currentframe as _C, getframeinfo
# `_F(_C())` evaluates to the line number of the call site; the tests pass it
# as the `line` argument of `assert_same` so a failure reports where it came from.
_F = lambda c: getframeinfo(c).lineno # Gets line number
WARN = lambda x: print(f"\033[31m{x}\033[0m") # Red colored warnings

# https://stackoverflow.com/questions/18425225/getting-the-name-of-a-variable-as-a-string
def NAME(var):
    """Return the name the caller uses for `var`, or "" when there is none.

    The local variables of the calling frame are scanned in order and the
    first name bound to the very same object (identity, not equality) wins.
    """
    caller_locals = inspect.currentframe().f_back.f_locals
    for candidate, value in caller_locals.items():
        if value is var:
            return candidate
    return ""

def assert_same(x, y, line, dtype, atol=None, rtol=None):
    """Assert that `x` has dtype `dtype` and is close to `y`.

    Args:
        x: Tensor under test; its dtype must equal `dtype`.
        y: Tensor `x` is compared against.
        line: Line number reported in the failure message.
        dtype: Expected dtype of `x`.
        atol: Absolute tolerance forwarded to `torch.testing.assert_close`.
        rtol: Relative tolerance forwarded to `torch.testing.assert_close`.

    Raises:
        AssertionError: If `x.dtype` differs from `dtype`.
        RuntimeError: If the closeness check fails. The message names the
            variables the *caller* passed in, and the original error is
            attached as the cause.
    """
    assert x.dtype == dtype, f"Expected dtype {dtype}, got {x.dtype}"
    try:
        torch.testing.assert_close(x, y, check_stride=True, atol=atol, rtol=rtol)
    except Exception as error:
        # Look the names up in our caller's frame. Doing the lookup from a
        # helper called here would inspect this function's own locals and
        # always report the parameter names "x" and "y".
        caller_locals = inspect.currentframe().f_back.f_locals

        def _name_of(obj):
            return next(
                (name for name, value in caller_locals.items() if value is obj),
                "",
            )

        raise RuntimeError(
            f"Failed allclose at line [{line}]: {_name_of(x)}, {_name_of(y)}\n{str(error)}"
        ) from error

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

---
<a name="problem-5-impl"></a>
## `MemoryEfficientLinear` Implementation for other functions


In [None]:
import torch.nn.functional as F
import math

def transformation_function_with_other_loss(batch, linear, labels,
                                            down_projection_function,
                                            reduction="mean"):
    """Reference path: project `batch`, upcast the result, then apply the loss.

    `linear` is called on `batch`, the result is converted with `.float()`,
    and `down_projection_function(logits, labels, reduction=reduction)` is
    returned unchanged.
    """
    logits = linear(batch)  # Up projection to large space
    return down_projection_function(logits.float(), labels, reduction=reduction)

# Repeat the test above for various functions
def nll_loss_f(batch, labels, reduction="mean"):
    """The nonlinear part of `transformation_function`, using NLL loss.

    Applies log-softmax over the last dimension of `batch`, flattens every
    leading dimension, and scores the result against the flattened `labels`.
    The inputs/targets used in the tests are not meaningful; only output
    equivalence between the two code paths matters.
    """
    # Down projection to small space
    num_classes = batch.shape[-1]
    log_probs = F.log_softmax(batch, -1).reshape(-1, num_classes)
    return F.nll_loss(log_probs, labels.reshape(-1), reduction=reduction)

def bce_loss_f(batch, labels, reduction="mean"):
    """The nonlinear part of `transformation_function`, using BCE loss.

    The last dimension of `batch` is collapsed by averaging its softmax, the
    result is squashed with a sigmoid, and binary cross entropy against
    `labels` is returned. The inputs/targets used in the tests are not
    meaningful; only output equivalence between the two code paths matters.
    """
    # Down projection to small space
    averaged = torch.softmax(batch, -1).mean(-1)
    probabilities = torch.sigmoid(averaged)
    return F.binary_cross_entropy(probabilities, labels, reduction=reduction)

def kl_div_loss_f(batch, labels, reduction="mean"):
    """The nonlinear part of `transformation_function`, using KL-divergence loss.

    The last dimension of `batch` is collapsed by averaging its softmax, the
    result is squashed with a sigmoid, and that value is handed to the KL
    divergence loss as its input with `labels` as the target. The
    inputs/targets used in the tests are not meaningful; only output
    equivalence between the two code paths matters.
    """
    # Down projection to small space
    averaged = torch.softmax(batch, -1).mean(-1)
    squashed = torch.sigmoid(averaged)
    return F.kl_div(squashed, labels, reduction=reduction)

def cross_entropy_f(x, labels, reduction="mean"):
    """The nonlinear part of `transformation_function`, using cross entropy.

    Every leading dimension of `x` and `labels` is flattened so the loss sees
    one row of logits per label.
    """
    # Down projection to small space
    flat_logits = x.reshape(-1, x.shape[-1])
    flat_labels = labels.reshape(-1)
    return F.cross_entropy(flat_logits, flat_labels, reduction=reduction)

class MemoryEfficientLinear(torch.autograd.Function):
    """Fused `forward_function(X @ W, labels)` that never keeps all of `X @ W`.

    IMO, the spec is a bit vague; the arguments are interpreted here as
    MemoryEfficientLinear.forward(X, W, labels, fn) = fn(XW, labels).

    The batch dimension of `X` is split into `mel_num_chunks` equal chunks.
    Each chunk's projection is computed, upcast to float32, reduced by
    `forward_function`, and released before the next chunk is projected.
    The backward pass recomputes each chunk's projection instead of storing it.

    `forward_function(logits, labels)` is called with its default arguments,
    and each chunk's result is weighted by the number of label elements in
    that chunk before dividing by the total. The combined value therefore
    equals the unchunked one when `forward_function` is a mean over every
    label element.
    NOTE(review): for losses that skip some labels (e.g. an ignore index) the
    weights are only exact when every chunk skips the same number of labels.
    `ignore_index` is stored on `ctx` but not otherwise used.
    """

    @staticmethod
    # (bsz, qlen, hd) @ (hd, vocab) -> (bsz, qlen, vocab)
    def forward(ctx, X, W, labels, forward_function, mel_num_chunks=1, ignore_index=-100):
        """Compute the chunk-weighted average of `forward_function(X @ W, labels)`.

        Args:
            X: Input whose first dimension is the batch dimension that gets chunked.
            W: Weight, used as `F.linear(X_chunk, W.T)`.
            labels: Targets, sliced along the first dimension like `X`.
            forward_function: Callable `(logits, labels) -> scalar loss`.
            mel_num_chunks: Number of equal batch chunks; must divide `X.shape[0]`.
            ignore_index: Stored on `ctx` only.

        Returns:
            Sum over chunks of (label count * chunk loss), divided by the
            total label count.
        """
        # NOTE: I wasn't sure what `allows_dynamic_chunk_sizes` means here.
        # I interpreted it to mean "let the user specify the number of chunks,
        # and the chunks will be sized accordingly."
        ctx.mel_num_chunks = mel_num_chunks # How to chunk `XW` over batches

        # Require uniform chunk size, for cleaner computations involving
        # `ForCausalLMLoss` and `num_items_in_batch`.
        bsz = X.shape[0]
        assert mel_num_chunks >= 1, "mel_num_chunks must be a positive integer"
        assert bsz % mel_num_chunks == 0, "mel_num_chunks must divide the batch size"
        assert mel_num_chunks <= bsz, "mel_num_chunks cannot exceed the batch size"
        b_per_chunk = bsz // mel_num_chunks

        # Perform `forward_function` in chunks, and reduce them into `output`
        output = 0.0
        N = 0
        for b in range(mel_num_chunks):
            b0, b1 = b * b_per_chunk, (b + 1) * b_per_chunk
            l_slice = labels[b0:b1]
            # Reduce (bsz, qlen, vocab) to (b_per_chunk, qlen, vocab)
            with torch.no_grad():
                XW_slice = F.linear(X[b0:b1], W.T).float()
            n_items = torch.numel(l_slice)
            output += n_items * forward_function(XW_slice, l_slice)
            N += n_items
            # Drop this chunk's logits now, so they are not still alive while
            # the next chunk's logits are being materialized.
            del XW_slice

        ctx.save_for_backward(X, W, labels)
        ctx.forward_function = forward_function
        ctx.N = N
        ctx.ignore_index = ignore_index
        return output / N

    # L(X,W,T,f) = f(XW, T)
    # dL/dX = dL/df * df/d(XW) * d(XW)/dX
    # dL/dW = dL/df * df/d(XW) * d(XW)/dW
    # We want to avoid materializing df/d(XW) to save on memory,
    # as XW is the large tensor we are trying to avoid materializing
    @staticmethod
    def backward(ctx, dY):
        """Recompute each chunk under autograd and accumulate dX / dW.

        Gradients are only allocated and computed for the inputs autograd
        reports as needing them (`ctx.needs_input_grad`); the others come
        back as `None`. One value is returned per `forward` input.
        """
        # As written we need to retain at least all of X, W, labels
        # (This could possibly be optimized more)
        X, W, labels = ctx.saved_tensors

        need_dX, need_dW = ctx.needs_input_grad[0], ctx.needs_input_grad[1]
        if not (need_dX or need_dW):
            return None, None, None, None, None, None

        # The absolute minimum memory usage this function can possibly incur is
        # that required for the returned gradients, so skip the buffer of any
        # input that does not need one (e.g. a frozen W).
        dX = torch.zeros_like(X) if need_dX else None
        dW = torch.zeros_like(W) if need_dW else None
        b_per_chunk = X.shape[0] // ctx.mel_num_chunks
        # Same upstream gradient for every chunk: forward divided the sum by N.
        grad_scale = dY / ctx.N

        for b in range(ctx.mel_num_chunks):
            b0, b1 = b * b_per_chunk, (b + 1) * b_per_chunk
            X_slice = X[b0:b1].detach().requires_grad_(need_dX)
            W_leaf = W.detach().requires_grad_(need_dW)
            l_slice = labels[b0:b1].detach()
            with torch.enable_grad():
                XW_slice = F.linear(X_slice, W_leaf.T).float()
                out = ctx.forward_function(XW_slice, l_slice) * torch.numel(l_slice)
            wanted = [t for t, needed in ((X_slice, need_dX), (W_leaf, need_dW)) if needed]
            # From my testing this appears to use more memory than hardcoded matmul (sometimes)
            grads = list(torch.autograd.grad(out, wanted, grad_scale,
                                             retain_graph=False, create_graph=False))
            # Release the recomputed logits before touching the next chunk.
            del XW_slice, out
            # `grads` is ordered like `wanted`: dX first (if requested), then dW.
            if need_dX:
                dX[b0:b1] = grads.pop(0).to(dX.dtype)
            if need_dW:
                dW += grads.pop(0).to(dW.dtype)

        return dX, dW, None, None, None, None



---
## `MemoryEfficientLinear` Tests

This is just all of Problem 5, Test 1, repeated for other loss functions.


### Test Batch Only

#### Functions tested

CrossEntropyLoss, NLLLoss, BCELoss, KLDivLoss - Mean reduction

In [None]:
#### This could (and for real code, should) be modularized better.
# But for the purpose of this problem set, I'm trying to keep things organized
# by requirement.
import torch
from functools import partial
import gc

# Make sure we're actually using cuda
device = 'cuda'

# Default config parameters
# Smaller than the CE loss test, so I can actually get through them in colab.
bsz = 32
qlen = 512
hd = 512
vocab = 32 * 1024 # 32k
dtype = torch.bfloat16
num_chunks = 8

# The projection we are trying to optimize is of the form:
# (bsz, qlen, hd) @ (hd, vocab) -> (bsz, qlen, vocab)

# Create input

batch = torch.randn((bsz, qlen, hd), dtype=dtype, requires_grad=True)
batch.retain_grad()
ce_labels = torch.randint(0, vocab, (bsz, qlen), dtype=torch.long).to(device)
nll_labels = torch.randint(0, vocab, (bsz, qlen), dtype=torch.long).to(device)
bce_labels = torch.sigmoid(nll_labels).to(torch.float32).to(device)
kl_labels = torch.sigmoid(nll_labels).to(torch.float32).to(device)
initial_W = torch.randn(hd, vocab, dtype=dtype)

functions_to_test = [cross_entropy_f, nll_loss_f, bce_loss_f, kl_div_loss_f]
labels_for_test = [ce_labels, nll_labels, bce_labels, kl_labels]

#### Run basic layer
for down_projection_function, l in zip(functions_to_test, labels_for_test):
    # Initialize a linear layer
    base_linear_ = torch.nn.Linear(hd, vocab, bias=False, dtype=dtype).to(device)
    base_linear_.weight = torch.nn.Parameter(
        initial_W.clone().detach().T.to(device))
    base_linear_.weight.grad = None

    # Prepare input data
    batch_in = batch.clone().detach().to(device).requires_grad_()
    batch_in.retain_grad()
    batch_in.grad = None

    # Clear memory and stats to profile forward
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    # (bsz, qlen, hd) @ (hd, vocab) -> (bsz, qlen, vocab) logits.
    loss_expected = transformation_function_with_other_loss(batch_in, base_linear_, l.detach(), down_projection_function)
    base_linear_forward_mem = torch.cuda.max_memory_allocated()

    # Without checkpointing, `backward()` doubles our memory usage since we need
    # to persist intermediate state.

    # Clear memory and stats to profile backward()
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    # Do the backward
    loss_expected.backward()

    # Measure peak memory usage
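    # NOTE: the byte counts below appear to come from the larger CE-loss test
    # config, not the smaller bsz/qlen/hd/vocab set at the top of this cell;
    # the same terms apply here with proportionally smaller sizes.
    # With no checkpointing we expect this to account for: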
    # 1. The parameters of base_linear - 4096 * 128k * sizeof(bfloat16) = 1GB
    # 2. The upcasted, materialized logits - 2 * 4096 * 128k * sizeof(float) = 4GB
    # 3. The memory needed for backprop - At least 4 GB
    # 4. The gradients of `base_linear` - 4096 * 128k * sizeof(bfloat16) = 1GB
    # 5. `batch` and `labels`, which are negligibly small (Under 200 MB)
    # 6. The computed losses, which are just scalars
    base_linear_backward_mem = torch.cuda.max_memory_allocated()

    # Move to CPU to not interfere VRAM measurement
    base_linear_batch_grad_values = batch_in.grad.clone().detach().to('cpu')
    base_linear_grad_values = base_linear_.weight.grad.clone().detach().transpose(0,1).to('cpu')

    # Free memory
    del base_linear_, batch_in

    torch.cuda.empty_cache()

    #### Run memory efficient layer

    # Initialize a linear parameter.
    mem_eff_linear_ = torch.nn.Parameter(initial_W.clone().detach().to(device))
    memEffLinear = MemoryEfficientLinear.apply
    mem_eff_linear_.grad = None

    # Prepare input data
    batch_in = batch.clone().detach().to(device).requires_grad_()
    batch_in.retain_grad()
    batch_in.grad = None

    # Clear memory and stats to profile forward
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    loss_actual = memEffLinear(batch_in,
                              mem_eff_linear_,
                              l.detach(),
                              down_projection_function,
                              num_chunks)
    mem_eff_linear_forward_mem = torch.cuda.max_memory_allocated()

    # Clear memory and stats to profile backward()
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    # Do the backward
    loss_actual.backward()

    # Measure peak memory usage
    mem_eff_linear_backward_mem = torch.cuda.max_memory_allocated()

    # Move to CPU to not interfere VRAM measurement
    mem_eff_batch_grad_values = batch_in.grad.clone().detach().to('cpu')
    mem_eff_linear_grad_values = mem_eff_linear_.grad.clone().detach().to('cpu')

    # Free memory
    del mem_eff_linear_, memEffLinear, batch_in
    torch.cuda.empty_cache()
    # Compare forward memory usage from the above test
    forward_vram_change = (mem_eff_linear_forward_mem - base_linear_forward_mem) / base_linear_forward_mem

    base_linear_forward_vram_gb = (base_linear_forward_mem) / (1024**3)
    base_mem_eff_forward_vram_gb = (mem_eff_linear_forward_mem) / (1024**3)
    print("#### Func {} ####".format(down_projection_function))
    print("Peak Memory usage during basic linear forward(): {:.2f} GB".format(base_linear_forward_vram_gb))
    print("Peak Memory usage during memory efficient linear forward(): {:.2f} GB".format(base_mem_eff_forward_vram_gb))
    print("Change: {:.2f} GB".format(base_mem_eff_forward_vram_gb - base_linear_forward_vram_gb))
    print("% Change: {:.2f}%".format(forward_vram_change * 100))
    # Compare backward memory usage from the above test
    backward_vram_change = (mem_eff_linear_backward_mem - base_linear_backward_mem) / base_linear_backward_mem

    base_linear_backward_vram_gb = (base_linear_backward_mem) / (1024**3)
    base_mem_eff_backward_vram_gb = (mem_eff_linear_backward_mem) / (1024**3)
    print("Peak Memory usage during basic linear backward(): {:.2f} GB".format(base_linear_backward_vram_gb))
    print("Peak Memory usage during memory efficient linear backward(): {:.2f} GB".format(base_mem_eff_backward_vram_gb))
    print("Change: {:.2f} GB".format(base_mem_eff_backward_vram_gb - base_linear_backward_vram_gb))
    print("% Change: {:.2f}%".format(backward_vram_change * 100))

    # Show losses from the above runs are equivalent
    # Bfloat16 is quite imprecise: https://nhigham.com/tag/bfloat16/
    # "bfloat16 numbers have the equivalent of about three decimal digits of precision"
    print("loss_expected", loss_expected)
    print("loss_actual", loss_actual)
    assert_same(loss_expected, loss_actual, _F(_C()), loss_actual.dtype)

    # Show gradients are equivalent
    # Assert X is same in batch and mem_eff case
    assert_same(mem_eff_batch_grad_values, base_linear_batch_grad_values,
                _F(_C()), mem_eff_batch_grad_values.dtype)
    # Assert W is same in batch and mem_eff case
    assert_same(mem_eff_linear_grad_values, base_linear_grad_values,
                _F(_C()), mem_eff_linear_grad_values.dtype)
    print()

#### Func <function cross_entropy_f at 0x78df40e5a160> ####
Peak Memory usage during basic linear forward(): 4.06 GB
Peak Memory usage during memory efficient linear forward(): 0.78 GB
Change: -3.27 GB
% Change: -80.72%
Peak Memory usage during basic linear backward(): 6.06 GB
Peak Memory usage during memory efficient linear backward(): 1.24 GB
Change: -4.82 GB
% Change: -79.57%
loss_expected tensor(92.5661, device='cuda:0', grad_fn=<NllLossBackward0>)
loss_actual tensor(92.5661, device='cuda:0', grad_fn=<MemoryEfficientLinearBackward>)

#### Func <function nll_loss_f at 0x78df40e0e200> ####
Peak Memory usage during basic linear forward(): 4.25 GB
Peak Memory usage during memory efficient linear forward(): 0.88 GB
Change: -3.38 GB
% Change: -79.40%
Peak Memory usage during basic linear backward(): 6.16 GB
Peak Memory usage during memory efficient linear backward(): 1.33 GB
Change: -4.83 GB
% Change: -78.39%
loss_expected tensor(92.7468, device='cuda:0', grad_fn=<NllLossBackward0>)
loss



#### Func <function kl_div_loss_f at 0x78df40e5a2a0> ####
Peak Memory usage during basic linear forward(): 4.25 GB
Peak Memory usage during memory efficient linear forward(): 0.88 GB
Change: -3.38 GB
% Change: -79.40%
Peak Memory usage during basic linear backward(): 8.16 GB
Peak Memory usage during memory efficient linear backward(): 1.58 GB
Change: -6.58 GB
% Change: -80.62%
loss_expected tensor(-0.5000, device='cuda:0', grad_fn=<MeanBackward0>)
loss_actual tensor(-0.5000, device='cuda:0', grad_fn=<MemoryEfficientLinearBackward>)



### ForCausalLMLoss

In [None]:
# Test this works on `ForCausalLMLoss`
import torch
from typing import Callable, List, Optional, Tuple, Union
from transformers.models.llama.modeling_llama import LlamaForCausalLM, KwargsForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.cache_utils import Cache
from transformers.processing_utils import Unpack
from transformers.loss.loss_utils import ForCausalLMLoss
from functools import partial
import torch.nn as nn
import gc

# Make sure we're actually using cuda
device = 'cuda'

bsz = 4
qlen = 2048
hd = 2048
vocab = 128256
dtype = torch.bfloat16
num_chunks = 4
# Create input

batch = torch.randn((bsz, qlen, hd), dtype=dtype, requires_grad=True)
batch.retain_grad()
labels = torch.randint(0, vocab, (bsz, qlen), dtype=torch.long).to(device)
labels[:,:qlen//4] = -100
# labels = torch.randint(0, vocab, (bsz, qlen), dtype=torch.long).to(device)
initial_W = torch.randn(hd, vocab, dtype=dtype)

# Initialize a linear layer
base_linear_ = torch.nn.Linear(hd, vocab, bias=False, dtype=dtype).to(device)
base_linear_.weight = torch.nn.Parameter(
    initial_W.clone().detach().T.to(device))
base_linear_.weight.grad = None

# Prepare input data
batch_in = batch.clone().detach().to(device).requires_grad_()
batch_in.retain_grad()
batch_in.grad = None

# Clear memory and stats to profile forward
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

# (bsz, qlen, hd) @ (hd, vocab) -> (bsz, qlen, vocab) logits.
loss_expected = transformation_function_with_other_loss(batch_in, base_linear_, labels.detach(),
                                                        partial(ForCausalLMLoss,
                                                                vocab_size=vocab))
base_linear_forward_mem = torch.cuda.max_memory_allocated()

# Without checkpointing, `backward()` doubles our memory usage since we need
# to persist intermediate state.

# Clear memory and stats to profile backward()
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

# Do the backward
loss_expected.backward()

base_linear_backward_mem = torch.cuda.max_memory_allocated()

# Move to CPU to not interfere VRAM measurement
base_linear_batch_grad_values = batch_in.grad.clone().detach().to('cpu')
base_linear_grad_values = base_linear_.weight.grad.clone().detach().transpose(0,1).to('cpu')

# Free memory
del base_linear_, batch_in

torch.cuda.empty_cache()

#### Run memory efficient layer

# Initialize a linear parameter.
mem_eff_linear_ = torch.nn.Parameter(initial_W.clone().detach().to(device))
memEffLinear = MemoryEfficientLinear.apply
mem_eff_linear_.grad = None

# Prepare input data
batch_in = batch.clone().detach().to(device).requires_grad_()
batch_in.retain_grad()
batch_in.grad = None

# Clear memory and stats to profile forward
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

loss_actual = memEffLinear(batch_in,
                              mem_eff_linear_,
                              labels.detach(),
                              partial(ForCausalLMLoss,vocab_size=vocab),
                              num_chunks)
mem_eff_linear_forward_mem = torch.cuda.max_memory_allocated()

# Clear memory and stats to profile backward()
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

# Do the backward
loss_actual.backward()

# Measure peak memory usage
mem_eff_linear_backward_mem = torch.cuda.max_memory_allocated()

# Move to CPU to not interfere VRAM measurement
mem_eff_batch_grad_values = batch_in.grad.clone().detach().to('cpu')
mem_eff_linear_grad_values = mem_eff_linear_.grad.clone().detach().to('cpu')

# Free memory
del mem_eff_linear_, memEffLinear, batch_in
torch.cuda.empty_cache()
# Compare forward memory usage from the above test
forward_vram_change = (mem_eff_linear_forward_mem - base_linear_forward_mem) / base_linear_forward_mem

base_linear_forward_vram_gb = (base_linear_forward_mem) / (1024**3)
base_mem_eff_forward_vram_gb = (mem_eff_linear_forward_mem) / (1024**3)
# Label the report with the loss this cell actually ran. `down_projection_function`
# is only a leftover loop variable from the earlier cell and would name the
# wrong function (or be undefined if that cell was not run).
print("#### Func {} ####".format(ForCausalLMLoss))
print("Peak Memory usage during basic linear forward(): {:.2f} GB".format(base_linear_forward_vram_gb))
print("Peak Memory usage during memory efficient linear forward(): {:.2f} GB".format(base_mem_eff_forward_vram_gb))
print("Change: {:.2f} GB".format(base_mem_eff_forward_vram_gb - base_linear_forward_vram_gb))
print("% Change: {:.2f}%".format(forward_vram_change * 100))
# Compare backward memory usage from the above test
backward_vram_change = (mem_eff_linear_backward_mem - base_linear_backward_mem) / base_linear_backward_mem

base_linear_backward_vram_gb = (base_linear_backward_mem) / (1024**3)
base_mem_eff_backward_vram_gb = (mem_eff_linear_backward_mem) / (1024**3)
print("Peak Memory usage during basic linear backward(): {:.2f} GB".format(base_linear_backward_vram_gb))
print("Peak Memory usage during memory efficient linear backward(): {:.2f} GB".format(base_mem_eff_backward_vram_gb))
print("Change: {:.2f} GB".format(base_mem_eff_backward_vram_gb - base_linear_backward_vram_gb))
print("% Change: {:.2f}%".format(backward_vram_change * 100))

# Show losses from the above runs are equivalent
# Bfloat16 is quite imprecise: https://nhigham.com/tag/bfloat16/
# "bfloat16 numbers have the equivalent of about three decimal digits of precision"
print("loss_expected", loss_expected)
print("loss_actual", loss_actual)
assert_same(loss_expected, loss_actual, _F(_C()), loss_actual.dtype)

# Show gradients are equivalent
# Assert X is same in batch and mem_eff case
assert_same(mem_eff_batch_grad_values, base_linear_batch_grad_values,
            _F(_C()), mem_eff_batch_grad_values.dtype)
# Assert W is same in batch and mem_eff case
assert_same(mem_eff_linear_grad_values, base_linear_grad_values,
            _F(_C()), mem_eff_linear_grad_values.dtype)
print()

#### Func <function kl_div_loss_f at 0x78df40e5a2a0> ####
Peak Memory usage during basic linear forward(): 8.55 GB
Peak Memory usage during memory efficient linear forward(): 4.12 GB
Change: -4.43 GB
% Change: -51.84%
Peak Memory usage during basic linear backward(): 12.37 GB
Peak Memory usage during memory efficient linear backward(): 6.51 GB
Change: -5.86 GB
% Change: -47.37%
loss_expected tensor(198.6714, device='cuda:0', grad_fn=<NllLossBackward0>)
loss_actual tensor(198.6715, device='cuda:0', grad_fn=<MemoryEfficientLinearBackward>)



### SelectiveLogSoftmax

This is for GRPO support, however the signature is rather different from the baseline MemoryEfficientLinear module.

**NOTE:** Unfortunately, this operation isn't very stable in bfloat16. Getting this more precise will require some further thought. However, we might be able to get away with using this for GRPO training anyway.

In [None]:
# Test this works on selective_log_softmax: https://github.com/huggingface/trl/blob/main/trl/trainer/utils.py#L1659
import torch
from trl.trainer.utils import selective_log_softmax
from typing import Callable, List, Optional, Tuple, Union
from transformers.models.llama.modeling_llama import LlamaForCausalLM, KwargsForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.cache_utils import Cache
from transformers.processing_utils import Unpack
from transformers.loss.loss_utils import ForCausalLMLoss
from functools import partial
import torch.nn as nn
import torch.nn.functional as F
import gc


# Make sure we're actually using cuda
device = 'cuda'
def transformation_function_for_sls(batch, linear, labels,
                                           down_projection_function):
    """Reference path for the SLS test: project, upcast, then select.

    `linear` is called on `batch`, the result is converted with `.float()`,
    and `down_projection_function(logits, labels)` is returned unchanged.
    """
    logits = linear(batch)  # Up projection to large space
    return down_projection_function(logits.float(), labels)

# Linear -> selective_log_softmax fusion
# This is specialized for the signature of `selective_log_softmax()`
class MemoryEfficientLinearSLS(torch.autograd.Function):
    """Chunked `forward_function(X @ W, index)` with a per-position output.

    Specialized for the signature of `selective_log_softmax()`: instead of a
    scalar loss, `forward_function` returns one value per (batch, position),
    which is written into a `(bsz, qlen)` float32 output. The batch dimension
    is split into `mel_num_chunks` equal chunks, and the backward pass
    recomputes each chunk's projection rather than storing it.
    """

    @staticmethod
    def forward(ctx, X, W, index, forward_function, mel_num_chunks=1):
        """Fill a `(X.shape[0], X.shape[1])` output chunk by chunk.

        Args:
            X: Input whose first dimension is chunked.
            W: Weight, used as `F.linear(X_chunk, W.T)`.
            index: Second argument of `forward_function`, sliced like `X`.
            forward_function: Callable `(logits, index) -> (b_per_chunk, qlen)`.
            mel_num_chunks: Number of equal batch chunks; must divide `X.shape[0]`.
        """
        ctx.mel_num_chunks = mel_num_chunks

        bsz, qlen = X.shape[0], X.shape[1]
        assert mel_num_chunks >= 1, "mel_num_chunks must be a positive integer"
        assert bsz % mel_num_chunks == 0, "mel_num_chunks must divide the batch size"
        assert mel_num_chunks <= bsz, "mel_num_chunks cannot exceed the batch size"
        b_per_chunk = bsz // mel_num_chunks

        # selective_log_softmax
        # Allocate next to `X` instead of relying on a module-level `device`
        # global, so the function works wherever its inputs live.
        output = torch.zeros(bsz, qlen, device=X.device)
        for b in range(mel_num_chunks):
            b0, b1 = b * b_per_chunk, (b + 1) * b_per_chunk
            with torch.no_grad():
                XW_slice = F.linear(X[b0:b1], W.T).float()
            output[b0:b1] = forward_function(XW_slice, index[b0:b1])
            # Drop this chunk's logits before the next chunk is projected.
            del XW_slice

        ctx.save_for_backward(X, W, index)
        ctx.forward_function = forward_function
        return output

    @staticmethod
    def backward(ctx, dY):
        """Recompute each chunk under autograd and accumulate dX / dW.

        `dY` is sliced along the batch dimension to match each chunk's output.
        Gradients are only allocated and computed for the inputs autograd
        reports as needing them (`ctx.needs_input_grad`). One value is
        returned per `forward` input (five in total).
        """
        X, W, index = ctx.saved_tensors

        need_dX, need_dW = ctx.needs_input_grad[0], ctx.needs_input_grad[1]
        if not (need_dX or need_dW):
            return None, None, None, None, None

        dX = torch.zeros_like(X) if need_dX else None
        dW = torch.zeros_like(W) if need_dW else None
        b_per_chunk = X.shape[0] // ctx.mel_num_chunks

        for b in range(ctx.mel_num_chunks):
            b0, b1 = b * b_per_chunk, (b + 1) * b_per_chunk
            X_slice = X[b0:b1].detach().requires_grad_(need_dX)
            W_leaf = W.detach().requires_grad_(need_dW)
            l_slice = index[b0:b1].detach()
            with torch.enable_grad():
                XW_slice = F.linear(X_slice, W_leaf.T).float()
                out = ctx.forward_function(XW_slice, l_slice)
            wanted = [t for t, needed in ((X_slice, need_dX), (W_leaf, need_dW)) if needed]
            grads = list(torch.autograd.grad(out, wanted, dY[b0:b1],
                                             retain_graph=False, create_graph=False))
            # Release the recomputed logits before touching the next chunk.
            del XW_slice, out
            # `grads` is ordered like `wanted`: dX first (if requested), then dW.
            if need_dX:
                dX[b0:b1] = grads.pop(0).to(dX.dtype)
            if need_dW:
                dW += grads.pop(0).to(dW.dtype)

        return dX, dW, None, None, None

bsz = 2
qlen = 2048
hd = 2048
vocab = 128256
dtype = torch.bfloat16 # torch.float32 is stable and passes, but bfloat16 is too imprecise
num_chunks = 2

# Create input
batch = torch.randn((bsz, qlen, hd), dtype=dtype, requires_grad=True)
batch.retain_grad()
# The ids select entries along the vocab dimension of the logits, so draw them
# from [0, vocab). Drawing from [0, qlen) would only ever exercise the first
# `qlen` vocab columns.
input_ids = torch.randint(0, vocab, (bsz, qlen), dtype=torch.long).to(device)
initial_W = torch.randn(hd, vocab, dtype=dtype)

# Initialize a linear layer
base_linear_ = torch.nn.Linear(hd, vocab, bias=False, dtype=dtype).to(device)
base_linear_.weight = torch.nn.Parameter(
    initial_W.clone().detach().T.to(device))
base_linear_.weight.grad = None

# Prepare input data
batch_in = batch.clone().detach().to(device).requires_grad_()
batch_in.retain_grad()
batch_in.grad = None

# Clear memory and stats to profile forward
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

# (bsz, qlen, hd) @ (hd, vocab) -> (bsz, qlen, vocab) logits.
loss_expected = transformation_function_for_sls(batch_in, base_linear_, input_ids.detach(),
                                                          selective_log_softmax)
base_linear_forward_mem = torch.cuda.max_memory_allocated()

# Without checkpointing, `backward()` doubles our memory usage since we need
# to persist intermediate state.

# Clear memory and stats to profile backward()
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

# Do the backward
loss_expected.backward(torch.ones_like(input_ids))

base_linear_backward_mem = torch.cuda.max_memory_allocated()

# Move to CPU to not interfere VRAM measurement
base_linear_batch_grad_values = batch_in.grad.clone().detach().to('cpu')
base_linear_grad_values = base_linear_.weight.grad.clone().detach().transpose(0,1).to('cpu')

# Free memory
del base_linear_, batch_in

torch.cuda.empty_cache()

#### Run memory efficient layer

# Initialize a linear parameter.
mem_eff_linear_ = torch.nn.Parameter(initial_W.clone().detach().to(device))
memEffLinear = MemoryEfficientLinearSLS.apply
mem_eff_linear_.grad = None

# Prepare input data
batch_in = batch.clone().detach().to(device).requires_grad_()
batch_in.retain_grad()
batch_in.grad = None

# Clear memory and stats to profile forward
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

loss_actual = memEffLinear(batch_in,
                              mem_eff_linear_,
                              input_ids.detach(),
                              selective_log_softmax,
                              num_chunks)
mem_eff_linear_forward_mem = torch.cuda.max_memory_allocated()

# Clear memory and stats to profile backward()
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

# Do the backward
loss_actual.backward(torch.ones_like(input_ids))

# Measure peak memory usage
mem_eff_linear_backward_mem = torch.cuda.max_memory_allocated()

# Move to CPU to not interfere VRAM measurement
mem_eff_batch_grad_values = batch_in.grad.clone().detach().to('cpu')
mem_eff_linear_grad_values = mem_eff_linear_.grad.clone().detach().to('cpu')

# Free memory
del mem_eff_linear_, memEffLinear, batch_in
torch.cuda.empty_cache()

# Compare forward memory usage from the above test
forward_vram_change = (mem_eff_linear_forward_mem - base_linear_forward_mem) / base_linear_forward_mem

base_linear_forward_vram_gb = (base_linear_forward_mem) / (1024**3)
base_mem_eff_forward_vram_gb = (mem_eff_linear_forward_mem) / (1024**3)

print("Peak Memory usage during basic linear forward(): {:.2f} GB".format(base_linear_forward_vram_gb))
print("Peak Memory usage during memory efficient linear forward(): {:.2f} GB".format(base_mem_eff_forward_vram_gb))
print("Change: {:.2f} GB".format(base_mem_eff_forward_vram_gb - base_linear_forward_vram_gb))
print("% Change: {:.2f}%".format(forward_vram_change * 100))
# Compare backward memory usage from the above test
backward_vram_change = (mem_eff_linear_backward_mem - base_linear_backward_mem) / base_linear_backward_mem

base_linear_backward_vram_gb = (base_linear_backward_mem) / (1024**3)
base_mem_eff_backward_vram_gb = (mem_eff_linear_backward_mem) / (1024**3)
print("Peak Memory usage during basic linear backward(): {:.2f} GB".format(base_linear_backward_vram_gb))
print("Peak Memory usage during memory efficient linear backward(): {:.2f} GB".format(base_mem_eff_backward_vram_gb))
print("Change: {:.2f} GB".format(base_mem_eff_backward_vram_gb - base_linear_backward_vram_gb))
print("% Change: {:.2f}%".format(backward_vram_change * 100))

# Show losses from the above runs are equivalent
# Bfloat16 is quite imprecise: https://nhigham.com/tag/bfloat16/
# "bfloat16 numbers have the equivalent of about three decimal digits of precision"
print("loss_expected", loss_expected)
print("loss_actual", loss_actual)
assert_same(loss_expected, loss_actual, _F(_C()), loss_actual.dtype)

# Show gradients are equivalent
# Assert X is same in batch and mem_eff case
assert_same(mem_eff_batch_grad_values, base_linear_batch_grad_values,
            _F(_C()), mem_eff_batch_grad_values.dtype)

# dW is not numerically stable
# Report the magnitude of the element-wise error. A signed difference lets
# positive and negative errors cancel in the mean, and makes `min` report the
# most negative error rather than the smallest one. The subtraction is done in
# float32 so the difference itself is not rounded to the gradients' dtype.
dw_abs_diff = (mem_eff_linear_grad_values.float() - base_linear_grad_values.float()).abs()
print("dW absolute difference:", dw_abs_diff)
print("min dW difference:", torch.min(dw_abs_diff))
print("mean dW difference:", torch.mean(dw_abs_diff))
print("max dW difference:", torch.max(dw_abs_diff))

Peak Memory usage during basic linear forward(): 5.54 GB
Peak Memory usage during memory efficient linear forward(): 5.02 GB
Change: -0.52 GB
% Change: -9.39%
Peak Memory usage during basic linear backward(): 7.44 GB
Peak Memory usage during memory efficient linear backward(): 6.45 GB
Change: -0.99 GB
% Change: -13.27%
loss_expected tensor([[-247.5000, -247.1250, -180.2837,  ..., -183.6250, -230.3139,
         -198.3490],
        [-244.7500, -178.7121, -133.5025,  ..., -214.6331, -224.1269,
         -202.0156]], device='cuda:0', grad_fn=<SubBackward0>)
loss_actual tensor([[-247.5000, -247.1250, -180.2837,  ..., -183.6250, -230.3139,
         -198.3490],
        [-244.7500, -178.7121, -133.5025,  ..., -214.6331, -224.1269,
         -202.0156]], device='cuda:0',
       grad_fn=<MemoryEfficientLinearSLSBackward>)
dW relative difference: tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00, -2.4414e-04,  0.0000e+00,  ..