In [2]:
import time
from functools import partial
from statistics import mean, stdev

import fasttext
import torch
from torch.utils.data import DataLoader

from config import SG_CORPUS, SG_FULL, CHECKPOINTS_DIR, PROBLEM_TEST
from data import HatefulTweets, TextDataset
from experiment import (
    run_repeated_mlp,
    test_inference_time,
    calculate_memory_usage,
    check_errors,
)
from nn import BinaryMLP
from text_processing import get_fasttext_embeddings

In [2]:
# Repeated MLP experiment on the fastText model stored at SG_CORPUS.
# The cell output shows one run per global seed (1-10) and a dict of
# aggregated "value ± spread" metrics; the exact aggregation lives in
# experiment.run_repeated_mlp (presumably mean ± std — confirm there).
run_repeated_mlp(
    SG_CORPUS,
    name="mlp_corpus",
)

Global seed set to 1
Global seed set to 2
Global seed set to 3
Global seed set to 4
Global seed set to 5
Global seed set to 6
Global seed set to 7
Global seed set to 8
Global seed set to 9
Global seed set to 10


{'test/loss': '0.3520 ± 0.0164',
 'test/f1': '0.5026 ± 0.0283',
 'test/acc': '0.8948 ± 0.0058',
 'test/precision': '0.6879 ± 0.0458',
 'test/recall': '0.3970 ± 0.0292',
 'train/loss': '0.1511 ± 0.0316',
 'train/f1': '0.8504 ± 0.0278',
 'train/acc': '0.9741 ± 0.0050',
 'train/precision': '0.8369 ± 0.0359',
 'train/recall': '0.8646 ± 0.0213',
 'train_time': '23.5230 ± 2.0934'}

In [3]:
# Same repeated MLP experiment, this time with the fastText model stored at
# SG_FULL, so the two metric dicts can be compared side by side.
# As the output shows, the run is repeated for global seeds 1-10.
run_repeated_mlp(
    SG_FULL,
    name="mlp_full",
)

Global seed set to 1
Global seed set to 2
Global seed set to 3
Global seed set to 4
Global seed set to 5
Global seed set to 6
Global seed set to 7
Global seed set to 8
Global seed set to 9
Global seed set to 10


{'test/loss': '0.3528 ± 0.0139',
 'test/f1': '0.4802 ± 0.0343',
 'test/acc': '0.8893 ± 0.0058',
 'test/precision': '0.6491 ± 0.0430',
 'test/recall': '0.3828 ± 0.0400',
 'train/loss': '0.1464 ± 0.0315',
 'train/f1': '0.8536 ± 0.0268',
 'train/acc': '0.9747 ± 0.0050',
 'train/precision': '0.8413 ± 0.0394',
 'train/recall': '0.8667 ± 0.0163',
 'train_time': '24.0255 ± 2.2360'}

In [3]:
# Use the GPU when one is present, otherwise fall back to the CPU so this cell
# still runs on machines without CUDA (an unconditional .cuda() raises there).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the fastText model stored at SG_CORPUS and bind it into a single-argument
# embedding function that TextDataset can call.
embeddings_model = fasttext.load_model(str(SG_CORPUS))
get_embeddings = partial(get_fasttext_embeddings, embeddings_model)

dataset = TextDataset(PROBLEM_TEST, get_embeddings)
# drop_last=True discards the final partial batch, so every batch handed to the
# timing helper has exactly 128 items.
loader = DataLoader(
    dataset,
    batch_size=128,
    # Pinned host memory only helps host-to-GPU copies; skip it on CPU-only hosts.
    pin_memory=device.type == "cuda",
    drop_last=True,
)

# Checkpoint written by the "mlp_corpus" experiment (presumably the seed-1 run,
# judging by the file name — confirm against experiment.run_repeated_mlp).
checkpoint_file = CHECKPOINTS_DIR / "mlp_corpus_1.ckpt"
model = BinaryMLP.load_from_checkpoint(
    checkpoint_file,
    # Remap stored tensors so a GPU-saved checkpoint also loads on CPU.
    map_location=device,
    # Constructor arguments are passed explicitly; they must match the
    # architecture that produced the checkpoint.
    emb_dim=300,
    hidden_dims=[512, 256, 128, 64],
    learning_rate=1e-4,
).to(device).eval()  # inference only: make eval mode explicit instead of relying on the helpers



In [6]:
# Time the loaded model over the fixed-size batches of `loader`
# (built with drop_last=True above). The helper returns a "value ± spread"
# string; units and aggregation are defined in experiment.test_inference_time.
test_inference_time(
    model,
    loader,
)

'0.0003 ± 0.0001'

In [5]:
# Report the memory footprint of the loaded model; the helper returns a
# formatted string in MB (see the cell output). What exactly is counted is
# defined in experiment.calculate_memory_usage.
calculate_memory_usage(model)

'1.261 MB'

In [7]:
# Rebuild the loader without drop_last: the timing loader above discards the
# final partial batch, whereas the error analysis should cover every example
# in the test set.
# NOTE(review): this rebinds `loader`, so re-running the timing cell after this
# one would measure a different loader — run cells top to bottom.
loader = DataLoader(dataset, batch_size=128, pin_memory=True)

# Prints correct/incorrect counts and the most confidently misclassified tweets
# for each error direction (see the cell output).
check_errors(
    model,
    PROBLEM_TEST,
    loader,
)

Predicted correctly: 896
Predicted incorrectly: 104

Non-hate tweets predicted as hate: 26
Most misclassified examples:
	Prob: 1.000 	Text: 'celny snajperski strzał w lewacką chołotę'
	Prob: 0.997 	Text: 'bo ty tam pracujesz oszołomie'
	Prob: 0.974 	Text: 'wracaj szybko bo widzisz jak bez ciebie nam idzie'
	Prob: 0.912 	Text: 'wieczna zdrada nie zdrada trzeba rozmawiać pierdolenie od rzeczy'
	Prob: 0.902 	Text: 'droga pkamilko leczyć się leczyć póki czas'
	Prob: 0.862 	Text: 'też jesteś kwiatem tylko że chwastem'
	Prob: 0.822 	Text: 'rt panie kropiwnicki w latach 80 wojsko polskie skladalo przysięgę na wierność w szeregach armi radzieckiej'
	Prob: 0.806 	Text: 'tomaszowi szkoda że pmm nie schował zarobionej kasy w lisiej norze'
	Prob: 0.805 	Text: 'rt tomaszowi szkoda że pmm nie schował zarobionej kasy w lisiej norze'
	Prob: 0.786 	Text: 'ten to już zupełnie odwiesił mózg na kołek chory mózg'

Hate tweets predicted as non-hate: 78
Most misclassified examples:
	Prob: 0.039 	Text: 'zrzek