In [4]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments
from datasets import load_dataset, Dataset
import torch

# Load the parallel dataset from the en.txt and vi.txt files
def load_translation_dataset(en_file, vi_file):
    # Open the files and read lines
    with open(en_file, 'r', encoding='utf-8') as f_en:
        en_sentences = f_en.readlines()

    with open(vi_file, 'r', encoding='utf-8') as f_vi:
        vi_sentences = f_vi.readlines()

    # Build a dataset of input-output pairs
    assert len(en_sentences) == len(vi_sentences), "Mismatched number of lines in the translation files."

    dataset = Dataset.from_dict({
        'en': en_sentences,
        'vi': vi_sentences
    })

    return dataset

# Tokenize input/output sentences
def tokenize_translation(examples, tokenizer, max_length=128):
    inputs = tokenizer(examples['en'], max_length=max_length, truncation=True, return_tensors='pt', padding='max_length')
    outputs = tokenizer(examples['vi'], max_length=max_length, truncation=True, return_tensors='pt', padding='max_length')

    # Shift the outputs by one token for the autoregressive modeling
    labels = outputs['input_ids']
    return {
        'input_ids': inputs['input_ids'],
        'attention_mask': inputs['attention_mask'],
        'labels': labels
    }

# Train function
def train_translation_model(en_file, vi_file, model_name, output_dir, num_train_epochs=3, batch_size=8):
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)

    # Set pad_token to eos_token
    tokenizer.pad_token = tokenizer.eos_token

    # Load the dataset
    dataset = load_translation_dataset(en_file, vi_file)

    # Tokenize the dataset
    tokenized_dataset = dataset.map(lambda examples: tokenize_translation(examples, tokenizer), batched=True)

    # Load GPT-2 model
    model = GPT2LMHeadModel.from_pretrained(model_name)

    # Define training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        per_device_train_batch_size=batch_size,  # Increase batch size for faster training
        num_train_epochs=num_train_epochs,
        save_steps=1000,  # Save less frequently (adjust according to preference)
        logging_steps=500,  # Log less frequently
        fp16=True,  # Enable mixed precision for faster training on supported hardware
        logging_dir='./logs',
        evaluation_strategy="no",  # No evaluation during training to speed up process
        save_total_limit=1,  # Only keep the latest checkpoint to save disk space
    )

    # Trainer to handle the training loop
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
    )

    # Train the model
    trainer.train()
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)


# Paths to data files
en_file = "en.txt"
vi_file = "vi.txt"
model_name = 'gpt2'
output_dir = './translation_model'

# Train the model
train_translation_model(en_file, vi_file, model_name, output_dir)




Map:   0%|          | 0/5000 [00:00<?, ? examples/s]



Step,Training Loss
500,3.0542
1000,2.6892
1500,2.6682


In [26]:
def translate_text(sequence, max_length=50):
    model_path = './translation_model'
    model = GPT2LMHeadModel.from_pretrained(model_path)
    tokenizer = GPT2Tokenizer.from_pretrained(model_path)

    # Tokenize the English input
    inputs = tokenizer.encode(sequence, return_tensors='pt', padding=True)
    attention_mask = torch.ones_like(inputs)  # Create attention mask manually

    # Generate the translation with the attention mask
    output = model.generate(
        inputs,
        attention_mask=attention_mask,  # Pass the attention mask
        max_length=max_length,
        num_return_sequences=1,
        pad_token_id=model.config.eos_token_id
    )

    # Decode the output while skipping special tokens
    translation = tokenizer.decode(output[0], skip_special_tokens=True)
    return translation

# Example translation
sequence = "Welcome to the show"
translation = translate_text(sequence)
print(translation)


Welcome to the show����������������������������������������������
