In [1]:
import torch.utils.data as Data
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorForSeq2Seq

In [2]:
# Load the 'train' split of the finance-alpaca instruction dataset, then
# preview the first ten rows (slicing a Dataset yields a dict of column -> list).
train_dataset = load_dataset('gbharti/finance-alpaca', split='train')
train_dataset[0:10]

{'instruction': ['For a car, what scams can be plotted with 0% financing vs rebate?',
  'Why does it matter if a Central Bank has a negative rather than 0% interest rate?',
  'Where should I be investing my money?',
  'Specifically when do options expire?',
  'Negative Balance from Automatic Options Exercise. What to do?',
  'Approximation of equity value for company in default',
  'Is it true that 90% of investors lose their money?',
  'Can a company charge you for services never requested or received?',
  'Working out if I should be registered as self-employed in the UK',
  'About eToro investments'],
 'input': ['', '', '', '', '', '', '', '', '', ''],
 'output': ["The car deal makes money 3 ways. If you pay in one lump payment. If the payment is greater than what they paid for the car, plus their expenses, they make a profit. They loan you the money. You make payments over months or years, if the total amount you pay is greater than what they paid for the car, plus their expenses, p

In [3]:
# Local checkout of Qwen2.5-0.5B-Instruct. A raw string is required here: in a
# normal literal '\h' and '\Q' are invalid escape sequences (SyntaxWarning on
# Python 3.12+). NOTE(review): absolute local path — consider making it configurable.
MODEL_PATH = r'E:\huggingface_models\Qwen2.5-0.5B-Instruct'
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

In [4]:
def process_func(example):
    """Tokenize one Alpaca-style row for supervised causal-LM fine-tuning.

    The prompt (a fixed system turn plus the user turn) is rendered through the
    tokenizer's chat template with the generation prompt appended. Its tokens
    are masked out of the labels with -100, so only the response tokens and a
    trailing EOS token are supervised.

    Args:
        example: one dataset row with string fields 'instruction' and 'output',
            and an optional 'input' field carrying extra context.

    Returns:
        dict with equal-length lists 'input_ids', 'attention_mask' and 'labels'.

    Relies on the module-level ``tokenizer``.
    """
    user_content = example['instruction'].strip()
    # Alpaca rows may carry additional context in 'input'; append it rather
    # than silently dropping it. Empty/missing context leaves the prompt as-is.
    context = (example.get('input') or '').strip()
    if context:
        user_content = f"{user_content}\n\n{context}"

    messages = [
        {"role": "system", "content": "You are a financial advisor."},
        {"role": "user", "content": user_content},
    ]
    instruction = tokenizer.apply_chat_template(conversation=messages,
                                                add_generation_prompt=True,
                                                tokenize=True,
                                                return_dict=True)
    # The prompt already came from the chat template, so the response is
    # tokenized without inserting any extra special tokens.
    response = tokenizer(f"{example['output']}", add_special_tokens=False)

    input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.eos_token_id]
    attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1]
    # -100 is PyTorch cross-entropy's default ignore_index: prompt positions
    # contribute no loss; only the response and the EOS token are learned.
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.eos_token_id]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }


# Apply process_func row by row; the new tokenized columns are added next to
# the original text columns, which are still present afterwards.
train_dataset = train_dataset.map(function=process_func)
train_dataset

Dataset({
    features: ['instruction', 'input', 'output', 'text', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 68912
})

In [5]:
# Keep only the tokenized fields for the collator. Filtering against
# column_names (and guarding add_column) makes this cell safe to re-run;
# otherwise a second run raises because the columns are already gone / 'yy' exists.
RAW_TEXT_COLUMNS = ['instruction', 'input', 'output', 'text']
train_dataset = train_dataset.remove_columns(
    [col for col in RAW_TEXT_COLUMNS if col in train_dataset.column_names])
# 'yy' holds each row's position 0..n-1.
# NOTE(review): presumably a debug/trace index — confirm it is still needed.
if 'yy' not in train_dataset.column_names:
    train_dataset = train_dataset.add_column('yy', list(range(len(train_dataset))))
train_dataset

Dataset({
    features: ['input_ids', 'attention_mask', 'labels', 'yy'],
    num_rows: 68912
})

In [6]:
# DataCollatorForSeq2Seq dynamically pads the inputs of each batch to its
# longest sequence, and pads the labels as well.
dac = Data.DataLoader(train_dataset,
                      collate_fn=DataCollatorForSeq2Seq(tokenizer=tokenizer,
                                                        padding=True),  # padding=True is already the default
                      batch_size=2)

# Peek at a single batch (next(iter(...)) replaces the for-loop + break idiom).
batch = next(iter(dac))
print(batch.keys())
print(batch['input_ids'])
print(batch['attention_mask'])
print(batch['labels'])

[{'input_ids': [151644, 8948, 198, 2610, 525, 264, 5896, 36613, 13, 151645, 198, 151644, 872, 198, 2461, 264, 1803, 11, 1128, 63055, 646, 387, 67583, 448, 220, 15, 4, 28542, 6165, 89946, 30, 151645, 198, 151644, 77091, 198, 785, 1803, 3484, 3643, 3220, 220, 18, 5510, 13, 1416, 498, 2291, 304, 825, 48529, 8160, 13, 1416, 279, 8160, 374, 7046, 1091, 1128, 807, 7171, 369, 279, 1803, 11, 5519, 862, 18024, 11, 807, 1281, 264, 11372, 13, 2379, 11679, 498, 279, 3220, 13, 1446, 1281, 14173, 916, 3951, 476, 1635, 11, 421, 279, 2790, 3311, 498, 2291, 374, 7046, 1091, 1128, 807, 7171, 369, 279, 1803, 11, 5519, 862, 18024, 11, 5519, 862, 17017, 18024, 807, 1281, 3220, 13, 4940, 3308, 279, 3220, 4990, 1635, 311, 2525, 304, 11, 476, 807, 4559, 697, 11679, 311, 2441, 2562, 311, 633, 279, 3220, 10596, 714, 304, 264, 9155, 3311, 13, 1446, 6559, 304, 264, 1803, 323, 807, 4559, 432, 518, 264, 11372, 13, 4940, 3308, 429, 501, 7745, 1410, 387, 264, 48529, 2629, 476, 264, 11679, 389, 279, 1483, 1803, 1112, 