# In [None]:
import argparse
import logging
import torch
import torch.nn as nn
from preprocess import load_and_preprocess_data, load_data, de_train_dataset, en_train_dataset, de_test_dataset, en_test_dataset
from train import run_training_loop
from model import Encoder, Decoder, Transformer
from preprocess import de_word2index, en_word2index, test
import torch.optim as optim
from translate import show_bleu, CustomTestData

# -Main- #

def parse_args(argv=None):
    """Parse the command-line options for the translation training script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default) argparse
            falls back to ``sys.argv[1:]``, which is the original behavior.
            Passing an explicit list (e.g. ``[]``) lets the function be
            called from tests or from environments whose ``sys.argv`` holds
            unrelated flags.

    Returns:
        argparse.Namespace with the attributes ``train_path_en``,
        ``train_path_de``, ``max_len``, ``log_file``, ``learning_rate``,
        ``batch_size``, ``n_epochs`` and ``clip``.
    """
    parser = argparse.ArgumentParser(description="Transformer Model for Machine Translation")

    # Data locations.
    parser.add_argument("--train_path_en", type=str, default="path_to_train_en.txt", help="Path to English training data")
    parser.add_argument("--train_path_de", type=str, default="path_to_train_de.txt", help="Path to German training data")

    # Translation and logging settings.
    parser.add_argument("--max_len", type=int, default=50, help="Maximum translation length")
    parser.add_argument("--log_file", type=str, default="transformer.log", help="Log file path")

    # Optimization hyperparameters.
    parser.add_argument("--learning_rate", type=float, default=0.0005, help="Learning rate")
    parser.add_argument("--batch_size", type=int, default=64, help="Dataloader batch size")
    parser.add_argument("--n_epochs", type=int, default=2, help="Number of training epochs")
    parser.add_argument("--clip", type=float, default=1.0, help="Gradient clipping threshold")

    return parser.parse_args(argv)

if __name__ == "__main__":
    # Script entry point: parse the CLI options, build the Transformer from
    # the Encoder/Decoder classes, run the training loop, then call
    # show_bleu on the German test data.

    args = parse_args()

    # NOTE(review): `device` is selected here but is not passed to the model
    # or the training loop anywhere in this block. Confirm that device
    # placement happens inside the imported modules; if it does not, the
    # model likely needs `model.to(device)`.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    logging.basicConfig(filename=args.log_file, level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
    logging.info("Run configuration: device=%s, args=%s", device, vars(args))

    # The return value is not used below; the vocabularies and datasets are
    # read from names imported from `preprocess`.
    # NOTE(review): presumably this call populates those module-level
    # objects in place — confirm against preprocess.py.
    data = load_and_preprocess_data(args.train_path_en, args.train_path_de)

    # Vocabulary sizes: German is the source side, English the target side.
    INPUT_DIM = len(de_word2index)
    OUTPUT_DIM = len(en_word2index)

    # Model hyperparameters.
    HIDDEN_DIM = 256
    ENC_LAYERS = 3
    DEC_LAYERS = 3
    ENC_HEADS = 8
    DEC_HEADS = 8
    ENC_PF_DIM = 512
    DEC_PF_DIM = 512
    ENC_DROPOUT = 0.1
    DEC_DROPOUT = 0.1

    SRC_PAD_IDX = de_word2index['<PAD>']
    TRG_PAD_IDX = en_word2index['<PAD>']

    enc = Encoder(INPUT_DIM, HIDDEN_DIM, ENC_LAYERS, ENC_HEADS, ENC_PF_DIM, ENC_DROPOUT)
    # NOTE(review): the two leading literals are passed positionally and
    # DEC_LAYERS is never used. Confirm against Decoder's signature whether
    # one of them is meant to be DEC_LAYERS; left unchanged because the
    # signature is not visible from this file.
    dec = Decoder(3, 3, OUTPUT_DIM, HIDDEN_DIM, DEC_HEADS, DEC_PF_DIM, DEC_DROPOUT)

    model = Transformer(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX)

    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
    # Skip padded target positions when computing the loss so that the
    # model is not trained (or scored) on predicting <PAD> tokens.
    criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)

    # test_loader is returned by load_data but not used in this block.
    train_loader, test_loader = load_data(args.batch_size, de_train_dataset, en_train_dataset, de_test_dataset, en_test_dataset)

    run_training_loop(model, train_loader, optimizer, criterion, args)

    # Evaluate on the German side of the test split.
    test_data_de = CustomTestData(test['de'])

    show_bleu(test_data_de, de_word2index, en_word2index, model)