In [None]:
from pathlib import Path

import pandas as pd
import torch
import torchtext
from torchtext.vocab import GloVe
from torchtext.data import get_tokenizer
import pytorch_lightning as pl
from torchmetrics import Accuracy, Precision, Recall, F1Score
from pytorch_lightning.loggers import TensorBoardLogger
from torch import nn

from rnn_dataset import RnnDataset
from rnn_trainer import RnnTrainer

# Fix the global random seed via Lightning's helper so the training runs below are repeatable.
pl.seed_everything(seed=42)

In [None]:
def run(model_name: str, GRU: bool, checkpoint_dir="../checkpoints/", log_dir="../logs", max_epochs=10, seed=42):
    """Train one RNN baseline, then evaluate the best checkpoint of this run on the test set.

    Args:
        model_name: Used as the checkpoint filename and as the TensorBoard
            log sub-directory / run name.
        GRU: Forwarded unchanged to ``RnnTrainer(GRU=...)``; presumably True
            selects a GRU cell and False an LSTM -- TODO confirm in rnn_trainer.
        checkpoint_dir: Directory where ModelCheckpoint writes the checkpoint.
        log_dir: Parent directory for the TensorBoard logs
            (logs go to ``{log_dir}/{model_name}``).
        max_epochs: Number of training epochs.
        seed: If not None, all RNGs are re-seeded before the model is built,
            so the result does not depend on which runs executed earlier in
            the same kernel. Pass None to keep the current RNG state.

    Returns:
        The list of metric dicts returned by ``trainer.test`` (one entry per
        test dataloader).
    """
    if seed is not None:
        pl.seed_everything(seed)

    # Keep only the checkpoint with the highest "avg_acc"; assumes RnnTrainer
    # logs a metric under that exact name -- TODO confirm.
    acc_ckpt = pl.callbacks.ModelCheckpoint(
        monitor="avg_acc",
        mode="max",
        verbose=True,
        dirpath=checkpoint_dir,
        filename=model_name,
    )

    logger = TensorBoardLogger(f"{log_dir}/{model_name}", name=model_name)

    model = RnnTrainer(text_col="text", GRU=GRU)

    trainer = pl.Trainer(
        accelerator="gpu",
        precision=16,
        max_epochs=max_epochs,
        # NOTE(review): auto_select_gpus is deprecated in newer Lightning
        # releases -- drop it when upgrading.
        auto_select_gpus=True,
        callbacks=[acc_ckpt],
        fast_dev_run=False,
        detect_anomaly=False,
        logger=logger,
    )

    trainer.fit(model)
    # Test the checkpoint this run actually saved. A hard-coded
    # "<dir>/<name>.ckpt" path is wrong on re-runs: when that file already
    # exists, ModelCheckpoint saves "<name>-v1.ckpt" instead, and the stale
    # file from an earlier run would be evaluated.
    return trainer.test(model, ckpt_path=acc_ckpt.best_model_path, verbose=True)

In [None]:
# Train and test each baseline, appending its test metrics to a shared report.
# Append mode means re-running this cell adds new sections rather than
# overwriting earlier ones.
RESULTS_PATH = Path("../results/baseline.txt")
RESULTS_PATH.parent.mkdir(parents=True, exist_ok=True)

# (model_name, GRU flag) passed straight to run(); the section header is the
# upper-cased model name.
for model_name, use_gru in [("lstm", False), ("gru", True)]:
    # Release cached GPU memory left over from the previous run.
    torch.cuda.empty_cache()
    res = run(model_name, use_gru)
    # Write after every run so a crash in a later model keeps earlier results.
    with open(RESULTS_PATH, "a", encoding="utf-8") as f:
        f.write(f"=== {model_name.upper()} ===\n")
        # res[0] holds the metrics of the first (only) test dataloader.
        for k, v in res[0].items():
            f.write(f"{k}: {v}\n")
        f.write("\n\n\n")