# Experiment 4: Document-level embeddings

In [None]:
from data_classes.TextLightningDataModule import TextLightningDataModule
from models.ClassifierSystem import LightningClassifier
from data_classes.pretrained_embeddings import get_pretrained_embeddings
from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor
from pytorch_lightning.callbacks import ModelCheckpoint


In [None]:
# ---- Data and model settings ----
dataset = "IMDB"        # dataset name handed to TextLightningDataModule
num_class = 2           # number of classes handed to LightningClassifier
embedding = "Glove"     # which pretrained embedding to load
max_vectors = 20000     # passed as max_vectors to get_pretrained_embeddings
dim = 300               # vector size requested from get_pretrained_embeddings
# Truncation length handed to every datamodule (234 + 2 * 173 = 580).
# NOTE(review): the origin of 234 and 173 is not stated here — presumably a
# length statistic plus two spreads; confirm and document the source.
trunc = 234 + 2 * 173

# ---- Training settings ----
max_epochs = 20
patience = 6            # early-stopping patience, in validation checks
monitor = "Val Loss"    # metric name watched by early stopping / checkpointing
lr = 0.001
batch_size = 32
num_workers = 0
advanced_metrics = False

# NOTE(review): num_layers is not passed to anything in the cells shown.
num_layers = 1

# ---- Log locations ----
log_file = "exp4"           # TensorBoard save directory / run-name prefix
log_file_csv = "exp4_csv"   # CSV logger save directory


In [None]:
# Load the vocabulary and vector table for the configured pretrained
# embedding; both are reused by every datamodule and classifier below.
vocab, vectors = get_pretrained_embeddings(
    embedding=embedding,
    max_vectors=max_vectors,
    dim=dim,
)


In [None]:
# Datamodule for the configured dataset in the default format; the training
# loop uses it for every model type except "sentence-bert" and "wme".
imdb_data = TextLightningDataModule(
    vocab,
    dataset=dataset,
    trunc=trunc,
    batch_size=batch_size,
    num_workers=num_workers,
)


In [None]:
# "IMDBSentence" data in the "bert" format, consumed by the
# "sentence-bert" model type in the training loop.
imdb_data_bert = TextLightningDataModule(
    vocab,
    dataset="IMDBSentence",
    format="bert",
    shuffle=True,
    trunc=trunc,
    batch_size=batch_size,
    num_workers=num_workers,
)


In [None]:
# "IMDBSentence" data in the "wme" format, consumed by the "wme" model type
# in the training loop.
imdb_data_wme = TextLightningDataModule(
    vocab,
    dataset="IMDBSentence",
    format="wme",
    shuffle=True,
    trunc=trunc,
    batch_size=batch_size,
    num_workers=num_workers,
)


In [None]:
# Train and test every (model type, output layer) combination, repeating the
# whole grid `num_runs` times so run-to-run variance shows up in the logs.
num_runs = 5

# (model_type, embedding size) pairs. The size is bound to `embedding_size`
# rather than `dim` so the loop does not overwrite the global `dim` setting
# that was passed to get_pretrained_embeddings.
model_configs = [
    ("pretrained-average", 300),
    ("from-scratch-average", 300),
    ("sentence-bert", 384),
    ("wme", 300),
]
output_layer_types = ["linear", "MLP"]

# Model types that need a specially formatted datamodule; every other model
# type trains on `imdb_data`.
datamodule_by_model_type = {
    "sentence-bert": imdb_data_bert,
    "wme": imdb_data_wme,
}

for _ in range(num_runs):
    for model_type, embedding_size in model_configs:
        for output_layer_type in output_layer_types:
            name = f"{log_file}-{model_type}-{output_layer_type}"
            logger_tensor = TensorBoardLogger(log_file, name=name)
            logger_csv = CSVLogger(log_file_csv, name=name)
            # The checkpoint callback must be registered with the Trainer;
            # otherwise ckpt_path="best" below resolves against Lightning's
            # default checkpoint, which does not track `monitor`.
            checkpoint_callback = ModelCheckpoint(monitor=monitor)
            early_stopping = EarlyStopping(monitor=monitor, patience=patience)
            trainer = Trainer(
                max_epochs=max_epochs,
                gpus=1,
                auto_select_gpus=True,
                callbacks=[early_stopping, checkpoint_callback],
                logger=[logger_tensor, logger_csv],
            )
            classifier = LightningClassifier(
                embedding_level="sentence",
                num_class=num_class,
                vocab=vocab,
                vectors=vectors,
                embedding_size=embedding_size,
                learning_rate=lr,
                model_type=model_type,
                output_layer_type=output_layer_type,
                advanced_metrics=advanced_metrics,
            )
            data = datamodule_by_model_type.get(model_type, imdb_data)
            trainer.fit(classifier, data)
            # Evaluate the checkpoint with the best `monitor` value on the
            # same datamodule used for training.
            trainer.test(datamodule=data, ckpt_path="best")
