In [1]:
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '2'

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from peft import LoraConfig, get_peft_model, AutoPeftModelForCausalLM, PeftModel
from cut_cross_entropy.transformers import cce_patch
from cut_cross_entropy import linear_cross_entropy
import torch
import transformers
import numpy as np
import random
from torchviz import make_dot

In [3]:
tokenizer = AutoTokenizer.from_pretrained('unsloth/Llama-3.2-3B-Instruct')

In [4]:
class Llama(LlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        
    def forward(self, **kwargs):
        labels = kwargs.pop('labels')
        kwargs.pop('output_hidden_states', None)
        super_out = super().forward(**kwargs, output_hidden_states = True)
        if labels is not None:
            embeddings = super_out.hidden_states[-1]
            auto_shift_loss = linear_cross_entropy(embeddings, self.lm_head, labels, shift=True, impl = 'torch_compile')
            return {'loss': auto_shift_loss}
        return super_out

In [5]:
model = Llama.from_pretrained(
    'unsloth/Llama-3.2-3B-Instruct',
    torch_dtype = torch.bfloat16
).cuda()

In [6]:
model_auto = AutoModelForCausalLM.from_pretrained(
    'unsloth/Llama-3.2-3B-Instruct',
    torch_dtype = torch.bfloat16
).cuda()

In [7]:
rank = 256
peft_config = LoraConfig(
        lora_alpha=rank * 2,
        lora_dropout=0.0,
        r=rank,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["embed_tokens", "lm_head"],
    )

In [8]:
model = get_peft_model(model, peft_config)
model_auto = get_peft_model(model_auto, peft_config)



In [9]:
input_ids = tokenizer.apply_chat_template([
    {'role': 'user', 'content': 'Hi!'}
], return_tensors = 'pt').cuda()
input_ids

tensor([[128000, 128006,   9125, 128007,    271,  38766,   1303,  33025,   2696,
             25,   6790,    220,   2366,     18,    198,  15724,   2696,     25,
            220,   2437,   4448,    220,   2366,     20,    271, 128009, 128006,
            882, 128007,    271,  13347,      0, 128009]], device='cuda:0')

In [10]:
o = model(input_ids = input_ids, labels = input_ids)
o

{'loss': tensor(6.8273, device='cuda:0', grad_fn=<CompiledFunctionBackward>)}

In [15]:
o = model_auto(input_ids = input_ids, labels = input_ids)
o.loss

tensor(6.8448, device='cuda:0', grad_fn=<NllLossBackward0>)

In [16]:
!pip3.10 uninstall torch torchvision torchaudio -y

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Found existing installation: torch 2.5.1+cu121
Uninstalling torch-2.5.1+cu121:
  Successfully uninstalled torch-2.5.1+cu121
Found existing installation: torchvision 0.20.1+cu121
Uninstalling torchvision-0.20.1+cu121:
  Successfully uninstalled torchvision-0.20.1+cu121
Found existing installation: torchaudio 2.5.1+cu121
Uninstalling torchaudio-2.5.1+cu121:
  Successfully uninstalled torchaudio-2.5.1+cu121


In [None]:
!pip3.10 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Defaulting to user installation because normal site-packages is not writeable
[0mLooking in indexes: https://download.pytorch.org/whl/nightly/cu124, https://pypi.ngc.nvidia.com
Collecting torch
  Downloading https://download.pytorch.org/whl/nightly/cu124/torch-2.6.0.dev20250102%2Bcu124-cp310-cp310-manylinux_2_28_x86_64.whl (766.6 MB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m766.6/766.6 MB[0m [31m34.2 MB/s[0m eta [36m0:00:00[0mm eta [36m0:00:01[0m[36m0:00:01[0m
[?25hCollecting torchvision
  Downloading https://download.pytorch.org/whl/nightly/cu124/torchvision-0.22.0.dev20250102%2Bcu124-cp310-cp310-linux_x86_64.whl (7.4 MB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.4/7.4 MB[0m [31m11.1 MB/s[0m eta [36m0:00:00[0m MB/s[0m eta [36m0:00:01[0m:01[0m
[?25hCollecting torchaudio
  Downloading https://download.pytorch.org/whl/nightly/cu124/torchaudio-2.6.0.dev20250102%2Bcu124-cp310-cp310-linux_x86_64.whl (