# In [None]:
import argparse
import logging
import torch
import nltk
from torch.utils.data import DataLoader
from transformers import T5Tokenizer, DataCollatorForSeq2Seq
from preprocess import load_multi30k_data, convert_examples_to_features
from metric import compute_metrics
from model import create_model
from train import create_training_args, create_trainer
from translate import translate_1
# Download NLTK's "punkt" tokenizer data at import time; presumably required by
# the metric/translation helpers imported above — TODO confirm which one uses it.
nltk.download('punkt')

# Configure the root logger to emit INFO-level records so the logger.info
# score reports in the __main__ section are actually printed.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def parse_args(argv=None):
    """Parse the command-line options for the T5 German-to-English script.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default), argparse reads ``sys.argv[1:]``, exactly as before.
            Passing an explicit list (e.g. ``[]``) is useful inside a
            notebook, where ``sys.argv`` may contain kernel arguments this
            parser does not recognize.

    Returns:
        argparse.Namespace with ``data_dir``, ``model_ckpt``, ``output_dir``
        and ``max_token_length`` attributes.
    """
    # A single parser: previously a second bare ArgumentParser() replaced this
    # one and silently dropped the description from --help output.
    parser = argparse.ArgumentParser(description="Translate German to English using T5 model")

    parser.add_argument("--data_dir", type=str, default="/content/drive/MyDrive/Colab Notebooks/multi30k",
                        help="Directory containing data files")
    parser.add_argument("--model_ckpt", type=str, default="t5-small", help="Path to the model checkpoint")
    parser.add_argument("--output_dir", type=str, default="chkpt", help="Directory to save model checkpoints")
    parser.add_argument("--max_token_length", type=int, default=128, help="Maximum token length")

    return parser.parse_args(argv)

if __name__ == "__main__":
    # EvalPrediction is the container Trainer passes to compute_metrics; imported
    # locally so this fix leaves the file-level import block untouched.
    from transformers import EvalPrediction

    args = parse_args()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load the raw Multi30k splits and tokenize every split with the T5 tokenizer.
    dataset = load_multi30k_data(args.data_dir)
    tokenizer = T5Tokenizer.from_pretrained(args.model_ckpt)

    tokenized_datasets = dataset.map(
        lambda examples: convert_examples_to_features(examples, tokenizer, args.max_token_length),
        batched=True,
        # Drop the original columns so only the features produced above remain.
        remove_columns=dataset["train"].column_names,
    )

    # Create and fine-tune the model.
    model = create_model(args.model_ckpt, device)
    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
    training_args = create_training_args(args.output_dir)
    trainer = create_trainer(model, training_args, tokenized_datasets, data_collator, tokenizer, compute_metrics=compute_metrics)

    trainer.train()

    # Score the "valid" split. trainer.predict returns a PredictionOutput
    # (predictions, label_ids, metrics); compute_metrics is also handed to the
    # Trainer above, so it is assumed to follow the Trainer contract and take an
    # EvalPrediction — rewrap instead of passing the 3-field tuple.
    eval_preds = trainer.predict(tokenized_datasets["valid"])
    metric_result = compute_metrics(
        EvalPrediction(predictions=eval_preds.predictions, label_ids=eval_preds.label_ids)
    )
    logger.info(f'BLEU Score: {metric_result["bleu"]:.4f}')

    # Translate the "test" split with the fine-tuned model, reusing the
    # tokenizer loaded above, and report the average BLEU-4 from translate_1.
    test_dataloader = DataLoader(
        tokenized_datasets["test"], batch_size=32, collate_fn=data_collator
    )
    average_bleu_score = translate_1(model, test_dataloader, tokenizer, args.max_token_length)
    logger.info(f'Average BLEU-4 Score: {average_bleu_score:.4f}')
