In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from datasets import load_dataset
from dotenv import load_dotenv
import os
import re


# Load variables from a local .env file into the process environment.
_ = load_dotenv()

# The token value is read and then discarded; nothing below passes it explicitly.
# NOTE(review): presumably the Hugging Face libraries read HF_TOKEN from the
# environment on their own — confirm, otherwise this line is a no-op.
_ = os.getenv("HF_TOKEN")

# Single checkpoint id shared by tokenizer and model so the two cannot drift apart.
MODEL_NAME = "gpt2"

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Use EOS as PAD, and pad on the right-hand side of each sequence
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"


# Load GPT-2 with half-precision (float16) weights to save memory.
# This is NOT 8-bit quantization: no quantization config is passed here.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    pad_token_id=tokenizer.eos_token_id,  # keep the model's pad id in sync with the tokenizer
    torch_dtype=torch.float16,
    device_map="auto",
)

In [2]:
# Candidate Hugging Face dataset ids; only the first entry is loaded below.
datasets = [
    "microsoft/wiki_qa",
    "rajpurkar/squad",
    "SoftAge-AI/sft-conversational_dataset",
    "D1rtyB1rd/Beloved-Everyday-Conversations",
]

# Load dataset
ds = load_dataset(datasets[0])


def _take_shuffled(split, count, seed=42):
    """Shuffle `split` with `seed` and return its first `count` rows."""
    return split.shuffle(seed=seed).select(range(count))


# Keep only a small subset of each split (to save memory)
ds_train = _take_shuffled(ds["train"], 800)
ds_val = _take_shuffled(ds["validation"], 200)

In [3]:
def clean_text(text):
    """Normalize spacing and repeated punctuation in a string.

    Steps, applied in order:
      1. Drop whitespace that precedes one of ``. , ; : ! ? % )``.
      2. Collapse a run of the same mark from ``. , ; : ! ? %`` into one.
      3. Collapse every whitespace run into a single space.
      4. Strip leading and trailing whitespace.

    Args:
        text: The string to clean.

    Returns:
        The cleaned string.
    """
    # Remove unwanted spaces before punctuation
    text = re.sub(r'\s+([.,;:!?%)])', r'\1', text)
    # Replace a repeated punctuation mark with a single occurrence
    text = re.sub(r'([.,;:!?%])\1+', r'\1', text)
    # Replace multiple spaces between words with a single space
    text = re.sub(r'\s+', ' ', text)
    # str.strip() returns a new string, so its result must be returned
    # (calling it as a bare statement leaves `text` unchanged).
    return text.strip()

In [4]:
# Look at the structure: print the first record of the subsampled training split
print(ds_train[0])

{'question_id': 'Q473', 'question': 'how much does united states spend on health care', 'document_title': 'Health care in the United States', 'answer': 'The U.S. Census Bureau reported that 49.9 million residents, 16.3% of the population, were uninsured in 2010 (up from 49.0 million residents, 16.1% of the population, in 2009).', 'label': 0}


In [5]:
def format_qa(examples):
    """Build one "Human: ... Bot: ..." string per question/answer pair.

    Args:
        examples: Batch mapping that provides parallel "question" and
            "answer" sequences of strings.

    Returns:
        A dict with a single "text" key holding the formatted strings, in
        the same order as the input pairs. Each string ends with the
        tokenizer's EOS token.
    """
    def render(question, answer):
        # Capitalize first, then clean, exactly as the prompt is meant to read.
        prompt = clean_text(question.capitalize())
        # Append a question mark only when the raw question has none anywhere.
        suffix = "" if "?" in question else "?"
        return f"Human: {prompt}{suffix} Bot: {clean_text(answer)}{tokenizer.eos_token}"

    pairs = zip(examples["question"], examples["answer"])
    return {"text": [render(q, a) for q, a in pairs]}

# Run format_qa over batches of rows; the "text" list it returns becomes a new column.
train_dataset = ds_train.map(format_qa, batched=True)
val_dataset = ds_val.map(format_qa, batched=True)

Map:   0%|          | 0/800 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

In [6]:
# Inspect the first formatted training record, including its new "text" field.
print(train_dataset[0])

{'question_id': 'Q473', 'question': 'how much does united states spend on health care', 'document_title': 'Health care in the United States', 'answer': 'The U.S. Census Bureau reported that 49.9 million residents, 16.3% of the population, were uninsured in 2010 (up from 49.0 million residents, 16.1% of the population, in 2009).', 'label': 0, 'text': 'Human: How much does united states spend on health care? Bot: The U.S. Census Bureau reported that 49.9 million residents, 16.3% of the population, were uninsured in 2010 (up from 49.0 million residents, 16.1% of the population, in 2009).<|endoftext|>'}


In [7]:
# Tokenize the datasets
def tokenize_function(examples):
    """Tokenize the "text" entries of a batch, truncating to 128 tokens.

    No padding argument is passed, so padding is left to whatever consumes
    the tokenized rows.

    Args:
        examples: Batch mapping that provides a "text" sequence of strings.

    Returns:
        The tokenizer's output for the batch.
    """
    return tokenizer(examples["text"], truncation=True, max_length=128)

# Tokenize both splits and drop every pre-existing column so that only the
# tokenizer outputs remain. Using `column_names` instead of a hand-written list
# keeps the two calls in sync and keeps working if the dataset schema changes.
tokenized_train = train_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=train_dataset.column_names,
)
tokenized_val = val_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=val_dataset.column_names,
)

Map:   0%|          | 0/800 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

In [8]:
# Show the first tokenized training example.
print(tokenized_train[0])

{'input_ids': [20490, 25, 1374, 881, 857, 16503, 2585, 4341, 319, 1535, 1337, 30, 18579, 25, 383, 471, 13, 50, 13, 20962, 9840, 2098, 326, 5125, 13, 24, 1510, 5085, 11, 1467, 13, 18, 4, 286, 262, 3265, 11, 547, 32736, 287, 3050, 357, 929, 422, 5125, 13, 15, 1510, 5085, 11, 1467, 13, 16, 4, 286, 262, 3265, 11, 287, 3717, 737, 50256], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
