In [1]:
import os

import importlib.util

# Set these before the library imports in the next cell so they are already in
# place when torch / transformers / huggingface_hub initialise.

# Expose only physical GPU 1 to this process; the outputs below report it as cuda:0.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
# TensorFlow-only flag; this notebook does not import TensorFlow directly.
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
# hf_transfer-accelerated downloads require the optional `hf_transfer` package;
# enabling the flag without it makes hub downloads error out (NOTE(review):
# exact behaviour varies across huggingface_hub versions). Only switch it on
# when the package is actually importable.
if importlib.util.find_spec('hf_transfer') is not None:
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
from peft import LoraConfig, get_peft_model, AutoPeftModelForCausalLM, PeftModel
import torch
import transformers
import numpy as np

In [3]:
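# Optional setup: this notebook appears to rely on a locally modified transformers
# checkout (e.g. the flex_attention forward pass below passes `attention_mask` as a
# list of cumulative sequence boundaries). To install it, run from that checkout's root:
# !pip3.10 uninstall transformers -y
# !pip3.10 install -e .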

In [4]:
# Tokenizer for the same checkpoint used by both models below.
tokenizer_checkpoint = 'HuggingFaceTB/SmolLM2-135M-Instruct'
tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)

In [5]:
# Checkpoint loaded with the flex_attention backend; it is fed the packed
# batch with list-valued `attention_mask` further down.
flex_checkpoint = 'HuggingFaceTB/SmolLM2-135M-Instruct'
model = AutoModelForCausalLM.from_pretrained(
    flex_checkpoint,
    attn_implementation='flex_attention',
)
model = model.cuda()

In [6]:
# Same checkpoint with the sdpa backend, used as the reference model that
# receives an explicit 4-D block-diagonal mask further down.
sdpa_checkpoint = 'HuggingFaceTB/SmolLM2-135M-Instruct'
model_sdpa = AutoModelForCausalLM.from_pretrained(
    sdpa_checkpoint,
    attn_implementation='sdpa',
)
model_sdpa = model_sdpa.cuda()

In [7]:
texts = [
    'how to solve world hunger',
    '1+1',
]
# Render each text as a single-turn user chat, tokenize it with the chat
# template, and concatenate all of them into one packed sequence.
#   input_ids    : flat list of token ids for every text, back to back
#   position_ids : restart at 0 at the start of every text
#   lengths      : leading 0 followed by each text's token count
#                  (turned into cumulative offsets in the next cell)
input_ids, position_ids, lengths = [], [], [0]
for text in texts:
    token_ids = tokenizer.apply_chat_template([{'role': 'user', 'content': text}])
    input_ids += token_ids
    position_ids += range(len(token_ids))
    lengths.append(len(token_ids))

In [8]:
# Convert the packed Python lists to CUDA tensors.
# `lengths` becomes cumulative offsets [0, n1, n1+n2, ...] marking sequence
# boundaries inside the packed sequence.
# NOTE: this cell rebinds names created by the previous cell, so re-running it
# on its own would fail (np.cumsum cannot consume a CUDA tensor); re-run the
# previous cell first.
cumulative_offsets = np.cumsum(lengths)
lengths = torch.tensor(cumulative_offsets).cuda()
input_ids = torch.tensor(input_ids).cuda()
position_ids = torch.tensor(position_ids, dtype=torch.int32).cuda()

In [9]:
# Packed forward pass through the flex_attention model (batch of 1).
# `attention_mask` is a list holding the cumulative boundary tensor rather than
# a standard padding mask. NOTE(review): this appears to depend on the locally
# patched transformers build; stock releases may not accept it.
# NOTE(review): labels are the raw packed ids, so (assuming the usual causal-LM
# label shift) the last token of one text is scored against the first token of
# the next; packed-training collators typically set each text's first label to -100.
flex_inputs = dict(
    input_ids=input_ids[None],
    labels=input_ids[None],
    position_ids=position_ids[None],
    attention_mask=[lengths],
)
model(**flex_inputs).loss

tensor(5.5380, device='cuda:0', grad_fn=<NllLossBackward0>)

In [13]:
def block_diagonal_concat_inverted(*masks, dtype=torch.bfloat16, device=None):
    """Place square 0/1 masks on the diagonal of one matrix and convert it to an additive mask.

    Each input mask is copied (and cast to ``dtype``) into its own diagonal
    block of a zero-initialised ``(T, T)`` matrix, where ``T`` is the sum of the
    masks' first dimensions. Entries equal to 1 become ``0`` (attend). Every
    other entry becomes the most negative value of ``dtype`` (blocked), including
    all positions outside the diagonal blocks.

    Args:
        *masks: 2-D square tensors. Presumably 1 = may attend and 0 = masked,
            e.g. ``torch.tril(torch.ones(n, n))`` for one causal block.
        dtype: dtype of the returned mask. Floating dtypes use
            ``torch.finfo(dtype).min`` as the blocked value; integer dtypes use
            ``torch.iinfo(dtype).min``.
        device: device on which to build the result. ``None`` keeps the previous
            behaviour (PyTorch's default device), so existing callers are
            unaffected. Passing e.g. ``'cuda'`` builds the mask there directly
            instead of allocating on the host and copying afterwards.

    Returns:
        Tensor of shape ``(1, T, T)`` and dtype ``dtype``. With no masks the
        result has shape ``(1, 0, 0)``.
    """
    total_size = sum(mask.size(0) for mask in masks)
    combined_mask = torch.zeros(total_size, total_size, dtype=dtype, device=device)

    offset = 0
    for mask in masks:
        size = mask.size(0)
        # Copy into this mask's diagonal block. Everything outside the blocks
        # stays 0 and is therefore blocked below.
        combined_mask[offset:offset + size, offset:offset + size] = mask
        offset += size

    min_value = torch.finfo(dtype).min if dtype.is_floating_point else torch.iinfo(dtype).min
    # Create the fill values on the result's device so `where` never mixes devices.
    attend = torch.zeros((), dtype=dtype, device=combined_mask.device)
    blocked = torch.full((), min_value, dtype=dtype, device=combined_mask.device)
    inverted_mask = torch.where(combined_mask == 1, attend, blocked)
    return inverted_mask.unsqueeze(0)

# Build the explicit attention mask for the packed batch used by `model_sdpa`:
# one causal (lower-triangular) block per text, so tokens never attend across
# text boundaries.
# `lengths` holds cumulative offsets, so its diffs are the per-text token counts.
# `.tolist()` brings them to the host once instead of syncing for every element.
seq_lens = torch.diff(lengths).tolist()
causal_blocks = [torch.tril(torch.ones(n, n)) for n in seq_lens]
# Match the dtype of the model that actually consumes this mask (model_sdpa).
packed_mask = block_diagonal_concat_inverted(*causal_blocks, dtype=model_sdpa.dtype)
# Shape (1, 1, T, T): a batch dimension of 1 in front of the (1, T, T) mask.
masks = torch.stack([packed_mask], 0).to('cuda')

In [14]:
# Reference forward pass: the same packed batch through the sdpa model, with
# the explicit block-diagonal additive mask built above. The resulting loss
# should match the flex_attention loss from the earlier cell.
# NOTE(review): as there, labels are not masked at text boundaries.
sdpa_inputs = dict(
    input_ids=input_ids[None],
    labels=input_ids[None],
    position_ids=position_ids[None],
    attention_mask=masks,
)
model_sdpa(**sdpa_inputs).loss

tensor(5.5380, device='cuda:0', grad_fn=<NllLossBackward0>)