In [1]:
import time
from functools import partial
from statistics import mean, stdev

import fasttext
import torch
from torch.utils.data import DataLoader

from config import SG_CORPUS, SG_FULL, CHECKPOINTS_DIR, PROBLEM_TEST
from data import HatefulTweets, TextDataset
from experiment import run_repeated_mlp, test_inference_time, calculate_memory_usage
from nn import BinaryMLP
from text_processing import get_fasttext_embeddings

In [None]:
# Repeated MLP experiment on the SG_CORPUS skip-gram embeddings, tagged "mlp_corpus"
# (the tag presumably prefixes the saved checkpoints, cf. "mlp_corpus_1.ckpt" loaded below).
run_repeated_mlp(
    SG_CORPUS,
    name="mlp_corpus",
)

In [None]:
# Same repeated MLP experiment, this time on the SG_FULL skip-gram embeddings, tagged "mlp_full".
run_repeated_mlp(
    SG_FULL,
    name="mlp_full",
)

In [None]:
# Load the SG_CORPUS fastText model and bind it as the text -> embedding function
# consumed by TextDataset.
embeddings_model = fasttext.load_model(str(SG_CORPUS))
get_embeddings = partial(get_fasttext_embeddings, embeddings_model)

# The MLP input width must equal the embedding size. Read it from the loaded
# fastText model so the two cannot drift apart (assumes get_fasttext_embeddings
# returns vectors of the model's native dimension — TODO confirm).
emb_dim = embeddings_model.get_dimension()

dataset = TextDataset(PROBLEM_TEST, get_embeddings)
# shuffle=False keeps a fixed sample order between runs; drop_last=True makes every
# batch exactly 128 samples (uniform timing) at the cost of skipping the final
# partial batch. pin_memory=True speeds up host -> GPU copies.
loader = DataLoader(
    dataset,
    batch_size=128,
    pin_memory=True,
    shuffle=False,
    num_workers=0,
    drop_last=True,
)

# Checkpoint of the "mlp_corpus" experiment (presumably run #1), matching the
# SG_CORPUS embeddings loaded above.
checkpoint_file = CHECKPOINTS_DIR / "mlp_corpus_1.ckpt"
model = (
    BinaryMLP.load_from_checkpoint(
        checkpoint_file,
        emb_dim=emb_dim,
        hidden_dims=[512, 256, 128, 64],
        learning_rate=1e-4,  # training-only hyperparameter; presumably required by BinaryMLP.__init__
    )
    .cuda()
    # load_from_checkpoint does not switch the module to eval mode; do it here so
    # dropout/normalization layers (if any) behave as at inference time. Chained
    # (eval() returns the module) so the cell still displays no output.
    .eval()
)

In [None]:
# Time inference of the loaded checkpoint over the test loader
# (measurement details live in experiment.test_inference_time).
test_inference_time(
    model,
    loader,
)

In [None]:
# Report the memory usage of the loaded model, as computed by experiment.calculate_memory_usage.
calculate_memory_usage(
    model,
)