In [1]:
import os

# Process-wide settings, kept in the first cell so they are already in place
# when the libraries imported in the following cells are loaded.
os.environ.update({
    # Expose only the GPU with index 1 to this process.
    'CUDA_VISIBLE_DEVICES': '1',
    # TensorFlow memory-growth switch.
    # NOTE(review): no TensorFlow import is visible in this notebook - confirm it is still needed.
    'TF_FORCE_GPU_ALLOW_GROWTH': 'true',
    # Presumably enables the hf_transfer download backend of huggingface_hub - confirm it is installed.
    'HF_HUB_ENABLE_HF_TRANSFER': '1',
})

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
from peft import LoraConfig, get_peft_model, AutoPeftModelForCausalLM, PeftModel
from cut_cross_entropy.transformers import cce_patch
import torch
import transformers
import numpy as np
import random

In [3]:
# !pip3.10 uninstall transformers -y
# !pip3.10 install -e .

In [4]:
# Tokenizer belonging to the same checkpoint that the model cell loads.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path = 'HuggingFaceTB/SmolLM2-360M-Instruct',
)

In [5]:
# Load the checkpoint in bfloat16 with the flex-attention implementation,
# move it to the GPU, then hand it to cce_patch.
# NOTE(review): cce_patch presumably swaps in the Cut Cross-Entropy loss path -
# confirm against the cut_cross_entropy documentation.
model = cce_patch(
    AutoModelForCausalLM.from_pretrained(
        'HuggingFaceTB/SmolLM2-360M-Instruct',
        torch_dtype = torch.bfloat16,
        attn_implementation = 'flex_attention',
    ).cuda()
)

In [6]:
# Optimise only the parameters that currently require gradients.
trainable_parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
# `trainer` is the AdamW optimizer used by the backward/step cells below.
trainer = torch.optim.AdamW(params = trainable_parameters, lr = 2e-4)

In [7]:
# Dummy batch: every token id is 1, and each row carries positions 0..maxlen-1.
maxlen, batch_size = 8192, 2
input_ids = torch.ones((batch_size, maxlen), dtype = torch.int32)
position_ids = torch.arange(maxlen).unsqueeze(0).repeat(batch_size, 1)

In [8]:
def generate_random_lengths(total_length, num_documents):
    """Randomly split `total_length` tokens into `num_documents` non-empty documents.

    Every document starts with one token; each remaining token is then assigned
    to a uniformly chosen document, so the returned lengths sum to `total_length`.

    Args:
        total_length: Total number of tokens to distribute.
        num_documents: Number of documents to split the tokens into.

    Returns:
        A list of `num_documents` integers, each at least 1, whose sum is
        `total_length`. An empty list when both arguments are 0.

    Raises:
        ValueError: If `num_documents` is negative, or if `total_length` cannot
            be split into `num_documents` non-empty documents.
    """
    if num_documents < 0:
        raise ValueError(f"num_documents must be non-negative, got {num_documents}")
    # Without this check a too-small total would silently yield lengths that
    # sum to more than total_length, and zero documents with a positive total
    # would fail inside random.randint with an unrelated message.
    if total_length < num_documents or (num_documents == 0 and total_length != 0):
        raise ValueError(
            f"cannot split {total_length} tokens into {num_documents} non-empty documents"
        )

    lengths = [1] * num_documents
    remaining_length = total_length - num_documents
    for _ in range(remaining_length):
        # Same draw as before, so seeded runs produce the same split.
        index = random.randint(0, num_documents - 1)
        lengths[index] += 1

    return lengths

def length_to_offsets(lengths, device):
    """Turn per-document lengths into cumulative boundary offsets.

    Args:
        lengths: Iterable of integer document lengths.
        device: Device on which the offsets tensor is created.

    Returns:
        A 1-D tensor with ``len(lengths) + 1`` entries: a leading 0 followed by
        the running total of `lengths`, so consecutive entries delimit one
        document each.
    """
    # Prepend the zero start offset, then accumulate along the only dimension.
    padded = torch.tensor([0, *lengths], device=device, dtype=torch.int32)
    return padded.cumsum(dim=-1)

In [9]:
# One offsets tensor per batch row: each row of `maxlen` tokens is randomly
# split into 10 documents and stored as cumulative boundaries on the GPU.
lengths = [
    length_to_offsets(generate_random_lengths(maxlen, 10), 'cuda')
    for _ in range(batch_size)
]
lengths

[tensor([   0,  867, 1698, 2520, 3312, 4155, 4976, 5772, 6596, 7405, 8192],
        device='cuda:0'),
 tensor([   0,  810, 1630, 2419, 3265, 4085, 4911, 5773, 6563, 7389, 8192],
        device='cuda:0')]

In [10]:
# Move the dummy batch onto the GPU alongside the model.
input_ids, position_ids = input_ids.cuda(), position_ids.cuda()

In [11]:
# Forward pass that also computes the loss; the inputs double as the labels.
# `attention_mask` is given the list of per-row cumulative document offsets
# instead of a 0/1 mask.
# NOTE(review): presumably this relies on the locally installed transformers
# build understanding that format - confirm.
o = model(
    input_ids = input_ids,
    position_ids = position_ids,
    attention_mask = lengths,
    labels = input_ids,
).loss

In [12]:
# Reset the optimizer's gradients first, then back-propagate the loss `o`
# produced by the forward cell; the order of these two calls matters.
trainer.zero_grad()
o.backward()

In [13]:
# Apply one AdamW update to the trainable parameters using the gradients
# computed by the backward pass.
trainer.step()