# In [None]:
import torch
import torch.nn as nn
import argparse
from data_preprocess import preprocess_data, myTokenizerDE
from model import TranslateTransformer
from train import train_f
from evaluate import evaluate_model
import logging
# Directory passed to train_f as its checkpoint location.
# NOTE(review): this looks like a placeholder path -- set it before running.
checkpoint_dir = '/path/to/checkpoints'
# Passed to train_f as its save-frequency argument
# (presumably "save every N epochs" -- confirm against train_f).
save_frequency = 1

def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line hyperparameters for a training run.

    Also configures root logging at INFO level and logs the learning rate
    and batch size that were selected.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            ``parse_args()`` callers behave as before.

    Returns:
        The populated ``argparse.Namespace`` with ``learning_rate``,
        ``epochs``, ``device``, ``embedding_size``, ``num_heads``,
        ``num_encoder_layers``, ``num_decoder_layers``, ``max_len`` and
        ``batch_size``.
    """
    # A single parser: previously a second ArgumentParser() replaced this one
    # and silently dropped the description.
    parser = argparse.ArgumentParser(description='argparser')

    parser.add_argument('--learning_rate', type=float, default=5e-05, help='Learning rate')
    parser.add_argument('--epochs', type=int, default=1, help='Number of training epochs')
    parser.add_argument('--device', type=str, default='cuda', help='Device (cuda or cpu)')
    parser.add_argument('--embedding_size', type=int, default=128, help='Embedding size')
    parser.add_argument('--num_heads', type=int, default=8, help='Number of attention heads')
    parser.add_argument('--num_encoder_layers', type=int, default=3, help='Number of encoder layers')
    parser.add_argument('--num_decoder_layers', type=int, default=3, help='Number of decoder layers')
    parser.add_argument('--max_len', type=int, default=227, help='Maximum sequence length')
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size')

    args = parser.parse_args(argv)

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    logger.info("Training started...")
    # Lazy %-formatting: the message is only built if INFO is enabled.
    logger.info(
        "Hyperparameters: learning_rate=%s, batch_size=%s",
        args.learning_rate,
        args.batch_size,
    )

    return args

if __name__ == "__main__":
    # Script entry point: parse CLI options, build the model, train it, then
    # print the BLEU score returned by evaluate_model on the test iterator.
    args = parse_args()

    # Honour --device, falling back to CPU when CUDA is requested but not
    # available. With the default ('cuda') this picks CUDA if present, else CPU.
    requested_device = args.device
    if requested_device.startswith('cuda') and not torch.cuda.is_available():
        requested_device = 'cpu'
    device = torch.device(requested_device)

    train_iterator, valid_iterator, test_iterator, SRC, TARGET = preprocess_data()
    src_vocab_size = len(SRC.vocab)
    trg_vocab_size = len(TARGET.vocab)

    model = TranslateTransformer(
        embedding_size=args.embedding_size,
        src_vocab_size=src_vocab_size,
        trg_vocab_size=trg_vocab_size,
        src_pad_idx=SRC.vocab.stoi["<pad>"],
        num_heads=args.num_heads,
        num_encoder_layers=args.num_encoder_layers,
        num_decoder_layers=args.num_decoder_layers,
        max_len=args.max_len,
    )
    # NOTE(review): the model is not moved to `device` here; presumably
    # train_f does that with the device it receives -- confirm.

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    # The padding index to ignore must come from the TARGET vocabulary: the
    # model and train_f are sized by trg_vocab_size, so the loss labels are
    # target-side token ids (the source pad index only matches by coincidence).
    criterion = nn.CrossEntropyLoss(ignore_index=TARGET.vocab.stoi["<pad>"])

    train_losses, valid_losses = train_f(
        model,
        train_iterator,
        valid_iterator,
        optimizer,
        criterion,
        device,
        checkpoint_dir,
        args.epochs,
        save_frequency,
        trg_vocab_size,
    )

    bleu_score = evaluate_model(model, test_iterator, SRC, TARGET, myTokenizerDE)
    print(f'BLEU-4 Score: {bleu_score:.2f}')
