# In [None]:
import argparse
import pandas as pd
import torch
from data_preprocess import preprocess_data, data_loaders
from dataset import IMDbDataset, get_data_loaders, split_data
from model import TransformerNet, setup_model
from train import train_f
from test import test_model, test_result

def parse_args(argv=None):
    """Define the experiment's command-line options and parse them.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, exactly as the
            zero-argument call always did. Pass an explicit list (``[]`` for
            pure defaults) when the process command line carries flags this
            parser does not define, e.g. under a test runner or an
            interactive session.

    Returns:
        argparse.Namespace with one attribute per option declared below.

    Raises:
        SystemExit: raised by argparse on unknown or malformed arguments.
    """
    parser = argparse.ArgumentParser(description='argparser')

    parser.add_argument("--vocab_size", type=int, default=121302, help="Vocabulary size")
    parser.add_argument("--embedding_size", type=int, default=64, help="Embedding size")
    parser.add_argument("--batch_size", type=int, default=128, help="Batch size")
    parser.add_argument("--hidden_size", type=int, default=768, help="Hidden size")
    parser.add_argument("--n_heads", type=int, default=8, help="Number of attention heads")
    parser.add_argument("--n_layers", type=int, default=6, help="Number of transformer layers")
    parser.add_argument("--n_labels", type=int, default=2, help="Number of labels")
    parser.add_argument("--dropout", type=float, default=0.15, help="Dropout rate")
    parser.add_argument("--seq_length", type=int, default=256, help="Maximum sequence length")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
    # argparse applies `type` only to string defaults, so the default is
    # written as a float to match what a user-supplied value would produce.
    parser.add_argument("--grad_clip", type=float, default=5.0, help="Gradient clipping")
    parser.add_argument("--train_size", type=float, default=0.5, help="Training set size")
    parser.add_argument("--val_size", type=float, default=0.5, help="Validation set size")
    parser.add_argument("--epochs", type=int, default=8, help="Number of training epochs")
    parser.add_argument("--es_threshold", type=int, default=3, help="Early stopping threshold")

    return parser.parse_args(argv)

if __name__ == "__main__":
    # Script entry point: parse options, load the CSV, build the model and
    # loaders, train, then evaluate on the test loader.
    args = parse_args()

    data = pd.read_csv('/content/drive/MyDrive/Colab Notebooks/transformer/imdb_processed.csv')

    # Use the GPU when torch reports one, otherwise stay on the CPU.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    features, labels = preprocess_data(data, args)

    # The model is built before the loaders, as in the original ordering;
    # NOTE(review): both steps may consume random state, so keep this order.
    model = TransformerNet(
        args.vocab_size,
        args.embedding_size,
        args.hidden_size,
        args.n_heads,
        args.n_layers,
        args.seq_length,
        args.n_labels,
        args.dropout,
    )
    model = model.to(device=device)
    criterion, optimizer = setup_model(model, args.learning_rate)

    train_loader, val_loader, test_loader = data_loaders(features, labels, args)

    # Fit the model; train_f hands back four values, unpacked here as
    # per-run loss/accuracy histories for the training and validation loaders.
    train_losses, train_accuracies, val_losses, val_accuracies = train_f(
        model,
        optimizer,
        train_loader,
        val_loader,
        criterion,
        args.epochs,
        args.es_threshold,
        args,
    )

    # Evaluate on the test loader and pass the metrics to the reporting helper.
    test_accuracy, test_loss = test_model(model, test_loader, criterion)
    test_result(test_accuracy, test_loss)