In [1]:
import os

# Select the GPU before torch (imported in the next cell) initialises CUDA;
# changing these variables after CUDA is initialised has no effect.
# CUDA enumerates devices fastest-first by default, whereas nvidia-smi numbers
# them by PCI bus. Forcing PCI_BUS_ID order makes index '2' refer to the same
# physical card nvidia-smi reports as GPU 2.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
from peft import LoraConfig, get_peft_model, AutoPeftModelForCausalLM, PeftModel
from cut_cross_entropy.transformers import cce_patch
from cut_cross_entropy import linear_cross_entropy
import torch
import transformers
import numpy as np
import random
from torchviz import make_dot

In [3]:
# Tokenizer for the same checkpoint as the model loaded further down.
# NOTE(review): not used by the visible cells (input_ids are written by hand).
tokenizer = AutoTokenizer.from_pretrained(
    'HuggingFaceTB/SmolLM2-135M-Instruct',
)

In [4]:
!nvidia-smi

Fri Jan 17 13:39:11 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.42.06              Driver Version: 555.42.06      CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 3090 Ti     Off |   00000000:01:00.0 Off |                  Off |
| 30%   33C    P8             22W /  400W |    1004MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA GeForce RTX 3090 Ti     Off |   00

In [5]:
# Base checkpoint in bfloat16, moved to the single GPU exposed via CUDA_VISIBLE_DEVICES.
model = (
    AutoModelForCausalLM
    .from_pretrained('HuggingFaceTB/SmolLM2-135M-Instruct', torch_dtype=torch.bfloat16)
    .cuda()
)

In [6]:
# LoRA hyper-parameters. lora_alpha is twice the rank, so with PEFT's standard
# (non-rsLoRA) scaling lora_alpha / r the adapter update is multiplied by 2.
rank = 64
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=rank,
    lora_alpha=2 * rank,
    lora_dropout=0.0,
    bias="none",
    # Adapt only the input embedding and the output projection.
    # NOTE(review): if this checkpoint ties embed_tokens and lm_head, separate
    # adapters on each let the effective input/output matrices diverge — confirm intended.
    target_modules=["embed_tokens", "lm_head"],
)

In [7]:
# Inject LoRA layers into the modules named in target_modules. The adapter is
# registered under PEFT's default name 'default', which the cell that reads
# lora_A / lora_B / scaling relies on.
model = get_peft_model(model=model, peft_config=peft_config)



In [8]:
# Toy batch: one sequence of three token ids, shape (1, 3), int64, on the GPU.
input_ids = torch.tensor([[1, 2, 3]], device='cuda')

In [9]:
# Full forward pass. output_hidden_states=True exposes the per-layer activations;
# the last one is what the fused loss cell consumes.
# NOTE(review): this also materialises o.logits through lm_head, which the
# cut-cross-entropy path exists to avoid. If o.logits is never needed, calling
# the inner decoder instead would skip that work.
o = model(
    output_hidden_states=True,
    input_ids=input_ids,
)

In [10]:
# Classifier matrix for the fused loss: the weight of the LoRA-wrapped lm_head.
# The LoRA factors are read separately and passed to the loss on their own.
classifier = (
    model.lm_head
    .weight
)

In [11]:
# (vocab_size, hidden_size) of the output projection.
classifier.size()

torch.Size([49152, 576])

In [12]:
# LoRA factors of the output projection for adapter 'default'.
# `alpha` holds PEFT's per-adapter *scaling* factor (lora_alpha / r for standard
# LoRA), not lora_alpha itself; the name matches the loss function's keyword.
c_a, c_b, alpha = (
    model.lm_head.lora_A['default'].weight,
    model.lm_head.lora_B['default'].weight,
    model.lm_head.scaling['default'],
)

In [13]:
# Cut Cross-Entropy's fused linear + cross-entropy on the final hidden states.
# shift=True scores position t against target token t+1.
# NOTE(review): c_a / c_b / alpha appear not to be part of upstream
# cut_cross_entropy's linear_cross_entropy. This assumes a fork that applies
# alpha * (c_b @ c_a) on top of `classifier` inside the kernel — confirm against that fork.
manual_shift_loss = linear_cross_entropy(
    o.hidden_states[-1],
    classifier,
    targets=input_ids,
    shift=True,
    c_a=c_a,
    c_b=c_b,
    alpha=alpha,
)

In [14]:
# Sanity check: the fused loss should match a plain cross-entropy over the logits
# from the full forward pass, with the same shift (position t predicts token t+1).
# The tolerance is loose because o.logits is presumably bf16 — TODO confirm dtype.
# NOTE: with PEFT's default init, lora_B presumably starts at zero, so this check
# does not yet exercise the c_a / c_b / alpha path; re-run it after training steps.
reference_shift_loss = torch.nn.functional.cross_entropy(
    o.logits[:, :-1].flatten(0, 1).float(),
    input_ids[:, 1:].flatten(),
)
torch.testing.assert_close(
    manual_shift_loss.detach().float(),
    reference_shift_loss.detach(),
    rtol=2e-2,
    atol=1e-2,
)
manual_shift_loss

tensor(19.7411, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)

In [15]:
!nvidia-smi

Fri Jan 17 13:39:22 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.42.06              Driver Version: 555.42.06      CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 3090 Ti     Off |   00000000:01:00.0 Off |                  Off |
| 30%   33C    P8             22W /  400W |    1004MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA GeForce RTX 3090 Ti     Off |   00