In [1]:
from peft import LoraConfig, get_peft_model
import torch
from torch.utils.data import DataLoader, Dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "dbmdz/german-gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)


In [3]:
# Read the text file
with open('/home/yaning/Documents/LLM/skip_gram/winnetou_s.txt', 'r') as file:
    text_lines = file.readlines()

# # Optional: Clean up (strip extra spaces, newline characters)
# text_lines = [line.strip() for line in text_lines if line.strip() != ""]


In [12]:
len(text_lines[0])

227148

In [5]:
max_length = 150

In [6]:
# If the input text exceeds the max_length, split it into smaller chunks
def split_text_into_chunks(text, chunk_size):
    # Tokenize the entire text and split it into smaller chunks
    tokenized_text = tokenizer(text, padding=False, truncation=False, return_tensors="pt")
    total_tokens = tokenized_text['input_ids'][0]
    chunks = [total_tokens[i:i + chunk_size] for i in range(0, len(total_tokens), chunk_size)]
    return chunks

# Split into chunks (if necessary)
chunks = split_text_into_chunks(text_lines[0], max_length)


In [7]:
chunks = chunks[:10]

In [None]:
chunks

In [14]:
# Custom Dataset for single line split into chunks
class TextDataset(Dataset):
    def __init__(self, chunks):
        self.chunks = chunks

    def __len__(self):
        return len(self.chunks)

    def __getitem__(self, idx):
        return {
            "input_ids": torch.tensor(self.chunks[idx]),
            "attention_mask": torch.ones_like(self.chunks[idx]),  # Attention mask for the whole sequence
        }

# Create a dataset instance
dataset = TextDataset(chunks)

# Create a DataLoader (here, batch_size will be 1 since you have one long text)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

In [18]:
from sklearn.model_selection import train_test_split

train_size = int(0.8 * len(dataset))
train_dataset, eval_dataset = torch.utils.data.random_split(dataset, [train_size, len(dataset) - train_size])


In [10]:
chunks[0].shape

torch.Size([150])

In [11]:
# Configure LoRA settings
lora_config = LoraConfig(
    r=8,  # Rank of the low-rank matrices (can experiment with different values)
    lora_alpha=16,  # Scaling factor
    lora_dropout=0.1,  # Dropout rate for LoRA layers
    task_type="CAUSAL_LM"  # Causal Language Model task
)

# Apply LoRA to the model
lora_model = get_peft_model(model, lora_config)



In [None]:

# Data collator for language modeling (adds random padding)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Training arguments
training_args = TrainingArguments(
    output_dir="./gpt2-german-lora",
    evaluation_strategy="steps",
    eval_steps=500,
    logging_steps=500,
    save_steps=1000,
    learning_rate=5e-4,
    per_device_train_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    push_to_hub=False,
    report_to="none",  # Disable reporting to external services
    save_total_limit=2
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=data_collator,
    tokenizer=tokenizer
)

# Start training
trainer.train()


In [19]:
from transformers import Trainer, TrainingArguments

# Define training arguments
training_args = TrainingArguments(
    output_dir="./german-gpt2-lora",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=1,
    num_train_epochs=3,
    save_steps=10_000,
    save_total_limit=2,
)

# Assuming you have already set up the LoRA model (lora_model)
trainer = Trainer(
    model=lora_model,  # Your LoRA model
    args=training_args,
    train_dataset=dataset,  # Pass the dataset here
    eval_dataset=eval_dataset  # <- Add this line
)

# Start training
trainer.train()


  "input_ids": torch.tensor(self.chunks[idx]),


ValueError: The model did not return a loss from the inputs, only the following keys: logits,past_key_values. For reference, the inputs it received are input_ids,attention_mask.