In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase

# Local checkpoint directory holding the base causal LM and its tokenizer.
# NOTE(review): "pretrained_moels" looks like a typo of "pretrained_models"; it is
# left unchanged because it has to match the directory name on disk — confirm.
model_name = '../pretrained_moels/Qwen2.5-0.5B'
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Make sure both special tokens are defined by borrowing one from the other.
# Batch padding needs a pad token, so a checkpoint that ships only an EOS token
# must fall back to it (and vice versa for a checkpoint without an EOS token).
if tokenizer.eos_token is None:
    tokenizer.eos_token = tokenizer.pad_token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

  from .autonotebook import tqdm as notebook_tqdm


In [29]:
from dataclasses import dataclass
import torch.nn as nn
import torch.nn.functional as F

class SFTDataset(Dataset):
    """Prompt/response pairs tokenized for supervised fine-tuning.

    Each item is the prompt tokens followed by the response tokens, with
    ``token_type_ids`` marking which segment a token belongs to
    (0 = prompt, 1 = response) so that a training loop can restrict the
    loss to the response.

    Args:
        data: mapping with parallel sequences under ``'prompt'`` and ``'response'``.
        tokenizer: callable returning a mapping with an ``'input_ids'`` list.
        max_len: maximum number of tokens in one returned example.
    """

    def __init__(self, data, tokenizer, max_len):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Number of prompt/response pairs."""
        return len(self.data['prompt'])

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``token_type_ids`` lists for pair ``idx``."""
        prompt = self.data['prompt'][idx]
        response = self.data['response'][idx]

        prompt_ids = list(
            self.tokenizer(prompt, truncation=True, max_length=self.max_len)['input_ids']
        )
        response_ids = list(
            self.tokenizer(response, truncation=True, max_length=self.max_len)['input_ids']
        )

        # Close the response with EOS so the end of an answer is a supervised
        # target; skipped when the tokenizer has no EOS id or already added one.
        eos_id = getattr(self.tokenizer, 'eos_token_id', None)
        if eos_id is not None and (not response_ids or response_ids[-1] != eos_id):
            response_ids.append(eos_id)

        # Each segment was truncated on its own, so the concatenation could be
        # up to twice max_len; cap the combined sequence as well.
        input_ids = (prompt_ids + response_ids)[:self.max_len]
        token_type_ids = ([0] * len(prompt_ids) + [1] * len(response_ids))[:self.max_len]

        return {
            'input_ids': input_ids,
            # No padding happens here, so every position is a real token.
            'attention_mask': [1] * len(input_ids),
            'token_type_ids': token_type_ids,
        }

@dataclass
class DataCollator:
    """Pad a list of tokenized examples into one batch and attach ``labels``.

    Attributes:
        tokenizer: tokenizer whose ``pad`` method builds the padded tensors.
    """
    tokenizer: PreTrainedTokenizerBase

    def __call__(self, examples):
        """Return the padded batch (PyTorch tensors) with an added ``labels`` entry."""
        batch = self.tokenizer.pad(examples, padding=True, return_tensors='pt')

        # Clone so that in-place edits of the labels (e.g. masking with -100)
        # can never modify the input ids through a shared tensor.
        labels = batch['input_ids'].clone()

        # Padding positions are not real targets; -100 is the ignore index of
        # PyTorch's cross-entropy and of the Hugging Face causal-LM loss.
        if 'attention_mask' in batch:
            labels[batch['attention_mask'] == 0] = -100

        batch['labels'] = labels
        return batch

# Toy supervised fine-tuning corpus, written as (prompt, response) pairs so each
# question sits next to its answer.
_sft_examples = [
    (
        "Tell me about the capital of France.",
        "The capital of France is Paris.",
    ),
    (
        "Translate 'good morning' into Spanish.",
        "In Spanish, 'good morning' is 'buenos días'.",
    ),
    (
        "What is the boiling point of water in Celsius?",
        "The boiling point of water is 100 degrees Celsius at standard pressure.",
    ),
    (
        "Write a short greeting to a new customer.",
        "Welcome to our store! We are glad to have you here.",
    ),
    (
        "Who is the author of 'Pride and Prejudice'?",
        "The author is Jane Austen.",
    ),
    (
        "Explain what machine learning is in one sentence.",
        "Machine learning is a field of AI where computers learn from data to make predictions or decisions.",
    ),
]

# Column-oriented view of the pairs: two parallel lists keyed by role.
train_data = {
    "prompt": [prompt for prompt, _ in _sft_examples],
    "response": [response for _, response in _sft_examples],
}

# Dataset -> collator -> shuffled mini-batches of two examples each.
train_dataset = SFTDataset(train_data, tokenizer, max_len=512)
data_collator = DataCollator(tokenizer=tokenizer)
train_dataloder = DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=data_collator,
)


def cross_entropy_loss_manual(logits, targets, ignore_index=None):
    """
    Mean cross-entropy computed by hand from log-softmax.

    logits: [N, C]  float, unnormalized class scores
    targets: [N]  long, class indices
    ignore_index: optional target value; rows whose target equals it are
        dropped before averaging.

    Returns a scalar tensor: the mean negative log-likelihood over the kept
    rows. When no row is kept the result is 0 (still attached to the autograd
    graph) instead of the NaN that averaging an empty tensor would produce.
    """
    # 1. Drop ignored rows first so log_softmax only runs on rows that count.
    if ignore_index is not None:
        keep = targets != ignore_index
        logits = logits[keep]
        targets = targets[keep]

    # Nothing to average: the sum of an empty selection is 0 and keeps a
    # grad_fn, so calling backward() on it is still valid.
    if targets.numel() == 0:
        return logits.sum()

    # 2. Log-probabilities per class, then pick the one of the target class.
    #    gather works on whatever device the inputs live on.
    log_probs = F.log_softmax(logits, dim=-1)  # [N_kept, C]
    nll_loss = -log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

    # 3. Average over the kept rows.
    return nll_loss.mean()

from torch.optim import Adam

# NOTE(review): weight_decay was 0.99. With Adam the decay is added to the
# gradient as an L2 term, so a value that large swamps the loss gradient and
# mostly shrinks the weights. 0.01 is a conventional setting — confirm.
optim = Adam(model.parameters(), weight_decay=0.01, lr=5e-5)

model.train()

# One pass over the data; the loss is computed on response tokens only.
for batch in train_dataloder:
    # Clear gradients from the previous step BEFORE the new backward pass.
    # Zeroing between backward() and step() would discard the gradients that
    # step() is about to apply.
    optim.zero_grad()

    # Pass only what the forward pass needs: `labels` would trigger a second,
    # unmasked loss inside the model, and `token_type_ids` is not a declared
    # forward argument of causal LMs such as Qwen2.
    outputs = model(
        input_ids=batch['input_ids'],
        attention_mask=batch['attention_mask'],
    )
    logits = outputs.logits
    labels = batch['labels']
    token_type_ids = batch['token_type_ids']

    # Exclude type-0 positions (the prompt) from the loss. Padded positions are
    # expected to be type 0 as well — tokenizer.pad fills token_type_ids with
    # pad_token_type_id, which defaults to 0; confirm for this tokenizer.
    shift_labels = labels.clone()
    shift_labels[token_type_ids == 0] = -100

    # Next-token prediction: the logits at position t are scored against the
    # token at position t + 1.
    shift_labels = shift_labels[:, 1:].reshape(-1)
    shift_logits = logits[:, :-1].reshape(-1, logits.shape[-1])

    loss = cross_entropy_loss_manual(shift_logits, shift_labels, ignore_index=-100)

    loss.backward()
    optim.step()
