In [1]:
from typing import List, Optional, Tuple, Union, Unpack
import torch
import torch.nn.functional as F
import os

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import LlamaPreTrainedModel, GenerationMixin, LlamaModel
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM, KwargsForCausalLM, LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC
from transformers import Cache
from transformers.utils import replace_return_docstrings, add_start_docstrings_to_model_forward
from transformers.utils.deprecation import deprecate_kwarg
from transformers import AutoConfig

def transformation_function(batch, linear, labels):
    x = linear(batch).float()
    return F.cross_entropy(x.view(-1, x.shape[-1]), labels.view(-1))

class MemoryEfficientLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, linear, labels, forward_function, chunk_size):
        ctx.save_for_backward(X, labels)
        ctx.linear = linear
        ctx.forward_function = forward_function
        ctx.chunk_size = chunk_size

        total_output = 0.0
        n = X.shape[0]
        for i in range(0, n, chunk_size):
            end_idx = min(i + chunk_size, n)
            batch_chunk = X[i:end_idx]
            labels_chunk = labels[i:end_idx]
            output_chunk = forward_function(batch_chunk, linear, labels_chunk)
            total_output += output_chunk * (end_idx - i) / n
        return total_output

    @staticmethod
    def backward(ctx, grad_output):
        X, labels = ctx.saved_tensors
        linear = ctx.linear
        forward_function = ctx.forward_function
        chunk_size = ctx.chunk_size

        grad_X = torch.zeros_like(X)
        n = X.shape[0]
        for i in range(0, n, chunk_size):
            end_idx = min(i + chunk_size, n)
            batch_chunk = X[i:end_idx].detach().requires_grad_(True)
            labels_chunk = labels[i:end_idx]

            with torch.enable_grad():
                output_chunk = forward_function(batch_chunk, linear, labels_chunk)

            chunk_grad = torch.autograd.grad(
                output_chunk, batch_chunk,
                grad_output * (end_idx - i) / n
            )[0]
            grad_X[i:end_idx] = chunk_grad
        return grad_X, None, None, None, None

def memory_efficient_linear(X, linear, labels, forward_function=transformation_function, chunk_size=1):
    return MemoryEfficientLinear.apply(X, linear, labels, forward_function, chunk_size)

device = "cuda"

batch_size = 4
seq_len = 1024
hidden_dim = 1024
vocab_size = 128000

linear = torch.nn.Linear(hidden_dim, vocab_size).to(device)

X = torch.randn(batch_size, seq_len, hidden_dim, requires_grad=True, device=device)
labels = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)

torch.cuda.reset_peak_memory_stats(device)
loss_std = transformation_function(X, linear, labels)
loss_std.backward()
std_mem = torch.cuda.max_memory_allocated(device)
print("standard loss:", loss_std.item())
print("standard max gpu memory: {:.2f} MB".format(std_mem / (1024 * 1024)))

linear.zero_grad()
if X.grad is not None:
    X.grad.zero_()

### CHUNK SIZE 1 means 1 batch once now not chunk amount, we have 4 batches, standard one will process all at once
### but ours will process one batch at once + 1 store of loss float so in theory
### it guaranteed %75ish vram reduction, but in practice its like %50+ as shown below
### if it does not do %50 you can do more batches to get benefit like 6 or 8
chunk_size = 1

torch.cuda.reset_peak_memory_stats(device)
loss_chunked = memory_efficient_linear(X, linear, labels, forward_function=transformation_function, chunk_size=chunk_size)
loss_chunked.backward()
chunked_mem = torch.cuda.max_memory_allocated(device)
print("Memory-Efficient loss:", loss_chunked.item())
print("Memory-Efficient max gpu memory: {:.2f} MB".format(chunked_mem / (1024 * 1024)))

reduction_pct = (std_mem - chunked_mem) / std_mem * 100
print("VRAM reduction: {:.1f}%".format(reduction_pct))

def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages):
    old_logits = old_logits.to(torch.float32)
    new_logits = new_logits.to(torch.float32)
    input_ids  = input_ids.unsqueeze(-1)

    old_x = torch.gather(old_logits, dim = -1, index = input_ids).squeeze(-1)
    new_x = torch.gather(new_logits, dim = -1, index = input_ids).squeeze(-1)

    old_logsumexp = torch.logsumexp(old_logits, dim=-1)
    old = old_x - old_logsumexp

    new_logsumexp = torch.logsumexp(new_logits, dim=-1)
    new = new_x - new_logsumexp

    kl_i = torch.exp(old - new) - (old - new) - 1.0

    loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1)
    loss_i = -(loss_i - beta * kl_i)

    mask = mask.to(torch.float32)
    n_mask_per_reward = mask.sum(1)

    loss = (loss_i * mask).sum() / mask.sum()
    
    with torch.inference_mode():
        completion_length = n_mask_per_reward.mean()
        mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward
        mean_kl = mean_kl_per_reward.mean()
    return loss, completion_length, mean_kl

def grpo_accumulated_loss(
    old_hidden_states,
    new_hidden_states,
    lm_head,
    input_ids,
    logits_to_keep,
    completion_mask,
    beta,
    advantages,
):
    os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
    
    completion_input_ids = input_ids[:, -logits_to_keep:]
    old_hidden_states = old_hidden_states[:, -logits_to_keep-1:]
    new_hidden_states = new_hidden_states[:, -logits_to_keep-1:]
    new_logits = torch.matmul(new_hidden_states, lm_head.weight.t())
    new_logits = new_logits[:, :-1, :]
    old_logits = torch.matmul(old_hidden_states, lm_head.weight.t())
    old_logits = old_logits[:, :-1, :]
    loss, completion_length, mean_kl = grpo_compute_loss(
        old_logits, new_logits, completion_input_ids, completion_mask, beta, advantages,
    )
    return loss, completion_length, mean_kl


# Memory efficient GRPO implementation using batch chunking similar to MemoryEfficientLinear
class MemoryEfficientGRPO(torch.autograd.Function):
    @staticmethod
    def forward(ctx, old_hidden_states, new_hidden_states, lm_head, input_ids, logits_to_keep, completion_mask, beta, advantages, chunk_size):
        ctx.save_for_backward(old_hidden_states, new_hidden_states, input_ids, completion_mask, advantages)
        ctx.lm_head = lm_head
        ctx.logits_to_keep = logits_to_keep
        ctx.beta = beta
        ctx.chunk_size = chunk_size

        bsz = input_ids.shape[0]
        total_loss = 0.0
        total_completion_length = 0.0
        total_mean_kl = 0.0
        for i in range(0, bsz, chunk_size):
            end_idx = min(i + chunk_size, bsz)
            old_h_chunk = old_hidden_states[i:end_idx][:, -logits_to_keep-1:]
            new_h_chunk = new_hidden_states[i:end_idx][:, -logits_to_keep-1:]
            input_ids_chunk = input_ids[i:end_idx]
            completion_input_ids = input_ids_chunk[:, -logits_to_keep:]
            comp_mask_chunk = completion_mask[i:end_idx]
            adv_chunk = advantages[i:end_idx]

            new_logits_chunk = torch.matmul(new_h_chunk, lm_head.weight.t())[:, :-1, :]
            old_logits_chunk = torch.matmul(old_h_chunk, lm_head.weight.t())[:, :-1, :]
            loss_chunk, completion_length_chunk, mean_kl_chunk = grpo_compute_loss(
                old_logits_chunk, new_logits_chunk, completion_input_ids, comp_mask_chunk, beta, adv_chunk
            )
            weight = (end_idx - i) / bsz
            total_loss += loss_chunk * weight
            total_completion_length += completion_length_chunk * weight
            total_mean_kl += mean_kl_chunk * weight
        return total_loss, total_completion_length, total_mean_kl

    @staticmethod
    def backward(ctx, grad_loss, grad_completion_length, grad_mean_kl):
        old_hidden_states, new_hidden_states, input_ids, completion_mask, advantages = ctx.saved_tensors
        lm_head = ctx.lm_head
        logits_to_keep = ctx.logits_to_keep
        beta = ctx.beta
        chunk_size = ctx.chunk_size
        bsz = input_ids.shape[0]
        grad_old = torch.zeros_like(old_hidden_states)
        grad_new = torch.zeros_like(new_hidden_states)
        for i in range(0, bsz, chunk_size):
            end_idx = min(i + chunk_size, bsz)
            old_h_chunk = old_hidden_states[i:end_idx].detach().requires_grad_(True)
            new_h_chunk = new_hidden_states[i:end_idx].detach().requires_grad_(True)
            input_ids_chunk = input_ids[i:end_idx]
            comp_mask_chunk = completion_mask[i:end_idx]
            adv_chunk = advantages[i:end_idx]
            with torch.enable_grad():
                old_h_chunk_slice = old_h_chunk[:, -logits_to_keep-1:]
                new_h_chunk_slice = new_h_chunk[:, -logits_to_keep-1:]
                completion_input_ids = input_ids_chunk[:, -logits_to_keep:]
                new_logits_chunk = torch.matmul(new_h_chunk_slice, lm_head.weight.t())[:, :-1, :]
                old_logits_chunk = torch.matmul(old_h_chunk_slice, lm_head.weight.t())[:, :-1, :]
                loss_chunk, _, _ = grpo_compute_loss(
                    old_logits_chunk, new_logits_chunk, completion_input_ids, comp_mask_chunk, beta, adv_chunk
                )
            weight = (end_idx - i) / bsz
            grad_inputs = torch.autograd.grad(loss_chunk, (old_h_chunk, new_h_chunk), grad_loss * weight, retain_graph=False)
            grad_old[i:end_idx] = grad_inputs[0]
            grad_new[i:end_idx] = grad_inputs[1]
        # return two grads because two hidden states, since its grpo objective.
        return grad_old, grad_new, None, None, None, None, None, None, None

def memory_efficient_grpo_loss(old_hidden_states, new_hidden_states, lm_head, input_ids, logits_to_keep, completion_mask, beta, advantages, chunk_size=1):
    return MemoryEfficientGRPO.apply(old_hidden_states, new_hidden_states, lm_head, input_ids, logits_to_keep, completion_mask, beta, advantages, chunk_size)

def transformation_function_autogressive(batch, linear, labels):
    logits = linear(batch).float()
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]
    return F.cross_entropy(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))

class BatchChunkedLlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

            logits_to_keep (`int` or `torch.Tensor`, *optional*):
                If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
                If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
                This is useful when using packed tensor format (single dimension for batch and sequence length).

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        ### my modification
        loss = memory_efficient_linear(hidden_states[:, slice_indices, :], self.lm_head, labels=labels, forward_function=transformation_function_autogressive, chunk_size=1)
        return loss

old_hidden_states = torch.randn(batch_size, seq_len, hidden_dim, device="cuda", requires_grad=True)
new_hidden_states = torch.randn(batch_size, seq_len, hidden_dim, device="cuda", requires_grad=True)
lm_head = torch.nn.Linear(hidden_dim, vocab_size, bias=False, device="cuda")
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
logits_to_keep = seq_len-1
completion_mask = torch.ones((batch_size, logits_to_keep), device="cuda")
beta = 0.04
advantages = torch.randn(batch_size, device="cuda")

# VRAM benchmarking for GRPO

torch.cuda.reset_peak_memory_stats(device)
loss_std_grpo, comp_length_std_grpo, mean_kl_std_grpo = grpo_accumulated_loss(
    old_hidden_states,
    new_hidden_states,
    lm_head,
    input_ids,
    logits_to_keep,
    completion_mask,
    beta,
    advantages
)
loss_std_grpo.backward()
std_mem_grpo = torch.cuda.max_memory_allocated(device)
print("Standard GRPO Loss:", loss_std_grpo.item())
print("Standard GRPO max gpu memory: {:.2f} MB".format(std_mem_grpo / (1024 * 1024)))

lm_head.zero_grad()
if old_hidden_states.grad is not None:
    old_hidden_states.grad.zero_()
if new_hidden_states.grad is not None:
    new_hidden_states.grad.zero_()

### Using MemoryEfficientGRPO with chunking on batch dimension
chunk_size = 1

torch.cuda.reset_peak_memory_stats(device)
loss_chunked_grpo, comp_length_chunked_grpo, mean_kl_chunked_grpo = memory_efficient_grpo_loss(
    old_hidden_states,
    new_hidden_states,
    lm_head,
    input_ids,
    logits_to_keep,
    completion_mask,
    beta,
    advantages,
    chunk_size=chunk_size
)
loss_chunked_grpo.backward()
chunked_mem_grpo = torch.cuda.max_memory_allocated(device)
print("Memory-Efficient GRPO Loss:", loss_chunked_grpo.item())
print("Memory-Efficient GRPO max gpu memory: {:.2f} MB".format(chunked_mem_grpo / (1024 * 1024)))
reduction_pct_grpo = (std_mem_grpo - chunked_mem_grpo) / std_mem_grpo * 100
print("GRPO VRAM reduction: {:.1f}%".format(reduction_pct_grpo))

lm_head.zero_grad()
if old_hidden_states.grad is not None:
    old_hidden_states.grad.zero_()
if new_hidden_states.grad is not None:
    new_hidden_states.grad.zero_()

# import config from real Llama model but init random
config_real_llama3_2_1B = AutoConfig.from_pretrained("unsloth/Llama-3.2-1B-Instruct")
chunked_llama = BatchChunkedLlamaForCausalLM(config_real_llama3_2_1B).to(device)
classic_llama = LlamaForCausalLM(config_real_llama3_2_1B).to(device)

chunked_llama.load_state_dict(classic_llama.state_dict())

torch.manual_seed(0)
input_ids = torch.randint(0, config_real_llama3_2_1B.vocab_size, (batch_size, seq_len)).to(device)
labels = input_ids.clone()

torch.cuda.reset_peak_memory_stats(device)
outputs_classic = classic_llama(input_ids=input_ids, labels=labels)
loss_classic = outputs_classic.loss
loss_classic.backward()
std_mem = torch.cuda.max_memory_allocated(device)
print("Classic loss:", loss_classic.item())
print("Classic max gpu memory: {:.2f} MB".format(std_mem / (1024 * 1024)))

classic_llama.zero_grad()

torch.cuda.reset_peak_memory_stats(device)
loss_chunked = chunked_llama(input_ids=input_ids, labels=labels)
loss_chunked.backward()
chunked_mem = torch.cuda.max_memory_allocated(device)
print("Chunked loss:", loss_chunked.item())
print("Chunked max gpu memory: {:.2f} MB".format(chunked_mem / (1024 * 1024)))

assert torch.allclose(loss_classic, loss_chunked), "Losses do not match!"

reduction_pct = (std_mem - chunked_mem) / std_mem * 100
print("VRAM reduction: {:.1f}%".format(reduction_pct))

# Full model vram reduction cannot be as large as %50 ish because we only chunk for the last lm_head and loss
# so its expected, its still good if its %10+, it will get better on models like gemma where lm head has larger vocabulary sizes

standard loss: 11.942072868347168
standard max gpu memory: 6524.65 MB
Memory-Efficient loss: 11.942071914672852
Memory-Efficient max gpu memory: 2068.77 MB
VRAM reduction: 68.3%
Standard GRPO Loss: 0.2337828278541565
Standard GRPO max gpu memory: 13073.07 MB
Memory-Efficient GRPO Loss: 0.2337827980518341
Memory-Efficient GRPO max gpu memory: 4151.37 MB
GRPO VRAM reduction: 68.2%
Classic loss: 12.17023754119873
Classic max gpu memory: 32230.35 MB
Chunked loss: 12.170235633850098
Chunked max gpu memory: 27760.38 MB
VRAM reduction: 13.9%
