In [1]:
import sys

# Put the parent directory on the import path; the `toxic` package imported
# below is presumably located there (TODO confirm the repository layout).
sys.path.append("..")

In [2]:
from pathlib import Path
from functools import partial

import numpy as np
import pandas as pd
import torch
import joblib
from torch.utils.data import DataLoader

from toxic.inference_bert import get_token_ids
from toxic.dataset import AUX_COLUMNS, ToxicDataset, collate_examples, SortSampler
from toxic.common import ToxicBot
from toxic.metric import ToxicMetric

In [3]:
# Directory that later cells load the cached tokenizer, the model checkpoint
# and the validation dump from (relative to the notebook's working directory).
MODEL_PATH = Path("../data/cache/")
# Torch device the model is moved to and that ToxicBot is configured with.
DEVICE = "cuda:0"

In [4]:
# NOTE: joblib.load and torch.load both unpickle arbitrary Python objects, so
# MODEL_PATH must only ever point at files from a trusted source.
tokenizer = joblib.load(str(MODEL_PATH / "bert-base-uncased_tokenizer.jbl"))

# map_location restores the checkpoint straight onto DEVICE; without it the
# tensors are first recreated on whatever device they were saved from, which
# fails when that device does not exist on this machine.
model = torch.load(
    str(MODEL_PATH / "bert-base-uncased_-1_yuval_220_f0.pth"),
    map_location=DEVICE,
)
# This notebook only runs inference: switch to eval mode so that layers such as
# dropout behave deterministically (eval() is idempotent and returns the module).
model = model.to(DEVICE).eval()

In [5]:
# Collation callable handed to every DataLoader in this notebook:
# `collate_examples` with its truncation, padding and closing-token settings
# bound once up front.
_sep_token_id = tokenizer.vocab["[SEP]"]  # vocabulary id of the "[SEP]" token
collate_fn = partial(collate_examples, truncate_len=220, pad=0, closing_id=_sep_token_id, mode="both")

![](https://pbs.twimg.com/media/DICFy_jWsAE6s6V?format=jpg&name=small)
[source](https://twitter.com/jessamyn/status/900867154412699649)

In [6]:
# Short first-person identity statements used to probe the model's identity
# and toxicity outputs.
test_text = [
    "I am a man",
    "I am a woman",
    "I am a lesbian",
    "I am gay man",
    "I am dyke",
    "I am a white man",
    "I am a gay woman",
    "I am a white woman",
    "I am a gay white man",
    "I am a black man",
    # NOTE(review): repeats the sentence two entries above — confirm that the
    # duplicate is intentional.
    "I am a gay white man",
    "I am a gay black man",
    "I am a black woman",
    "I am a gay black woman",
]
# Single-column frame; downstream cells read the text from "comment_text".
df = pd.DataFrame({"comment_text": test_text})

In [7]:
# Turn the comments into token ids (presumably BERT wordpiece ids, given
# is_bert=True — confirm in toxic.inference_bert).
tokens = get_token_ids(df, tokenizer, is_bert=True)

# Unlabeled dataset over the probe sentences.
test_ds = ToxicDataset(df, tokens, labeled=False)

# No sampler is passed, so batches follow the dataset order; collation is
# delegated to the shared `collate_fn`.
test_loader = DataLoader(test_ds, collate_fn=collate_fn, batch_size=32, num_workers=0, pin_memory=True)

In [8]:
# Inference only: make sure dropout-style layers are disabled (idempotent).
model.eval()

with torch.no_grad():
    # Send each batch to the same DEVICE the model lives on instead of the
    # default CUDA device, so changing DEVICE cannot cause a device mismatch.
    # The second element yielded by the loader is not needed here.
    batch_outputs = [model(batch.to(DEVICE)) for batch, _ in test_loader]

# Sigmoid of the concatenated model outputs, scaled to a 0-100 range.
results = torch.sigmoid(torch.cat(batch_outputs)) * 100

# Kept as the cell's last top-level expression so the size is actually
# displayed (inside the `with` block it was evaluated and discarded).
results.size()

In [9]:
# One row per input sentence: the scores labelled with AUX_COLUMNS, plus the
# original sentence appended as a trailing "text" column.
scores = results.cpu().numpy()
predictions = pd.DataFrame(scores, columns=AUX_COLUMNS).assign(text=df["comment_text"].values)
predictions.shape

(14, 16)

In [10]:
# Render floats in displayed frames with exactly two decimal places.
pd.set_option("display.float_format", "{:.2f}".format)

In [11]:
# List the column labels of `predictions` (the AUX_COLUMNS names plus "text").
predictions.columns

Index(['target', 'severe_toxicity', 'obscene', 'identity_attack', 'insult',
       'threat', 'male', 'female', 'homosexual_gay_or_lesbian', 'christian',
       'jewish', 'muslim', 'black', 'white', 'psychiatric_or_mental_illness',
       'text'],
      dtype='object')

In [12]:
# Compact view: the sentence next to a handful of scores, with two column
# headers shortened for display only (`predictions` itself is not modified).
columns_to_show = ["text", "target", "identity_attack", "female", "homosexual_gay_or_lesbian", "black", "white"]
display_names = {"target": "toxic", "homosexual_gay_or_lesbian": "homosexual"}
predictions[columns_to_show].rename(columns=display_names)

Unnamed: 0,text,toxic,identity_attack,female,homosexual,black,white
0,I am a man,0.61,0.27,3.37,0.12,0.01,0.03
1,I am a woman,0.69,0.82,98.79,0.08,0.16,0.03
2,I am a lesbian,6.35,5.11,0.1,97.3,0.07,0.02
3,I am gay man,3.24,9.14,0.34,93.88,0.39,0.05
4,I am dyke,1.29,0.11,0.02,0.0,0.0,0.0
5,I am a white man,1.55,3.8,1.36,0.03,0.73,98.58
6,I am a gay woman,3.74,8.48,78.52,91.64,0.1,0.23
7,I am a white woman,1.22,4.79,93.79,0.11,0.7,98.87
8,I am a gay white man,17.33,24.38,0.91,96.12,1.44,78.75
9,I am a black man,7.17,16.17,2.03,1.53,98.17,1.03


## Other random examples

In [23]:
# A handful of longer free-form comments; the strings are fed to the model
# verbatim, including their original spelling and punctuation.
test_text = [
    "Pretty sad really. Thousands come out to protest Trump who's in another country......and all those NDP and Green members - who were the driving force behind Proportional Representation - couldn't find the time to get out to support their own cause. That said, you can bet that Trudeau has alredy lost a good portion of that NDP/Green vote. The guy's a liar. Period.",
    "Idiots.",
    "Then put them in jail until they decide to cooperate.  This is ridiculous....",
    "\"He’s a proven liar. Why would you believe anything he has to say?\"\n\nExactly.\nThat\'s why I do not listen to a word Donnie says.",
    "Fascinating young woman, impressive work. It reminds me of one of the reasons I love Alaska so much —\xa0the people Alaska attracts are often just as extraordinary as the landscape itself. Great article, thank you.",
    "Well, at least the Russians are white, for heaven's sakes.  I'd rather have to live next to a nice white drunk Russian than a black Muslim Kenyan whose children's names are Satanic anagrams.",
    "Was I posted yesterday, it is interesting to note that under Session's watch only three black people have been appointed in Alabama for the federal courts. This despite the fact that black people make up over 39% of the population of that state. What underlines this reality must be Session's unconscious, if not conscious, attitude towards blacks in general.",
]
# Single-column frame; downstream code reads the text from "comment_text".
df = pd.DataFrame({"comment_text": test_text})

In [24]:
# Tokenise the current `df` and print the length of every entry of `tokens`.
tokens = get_token_ids(
    df, tokenizer, is_bert=True)
print([len(x) for x in tokens])

# Unlabeled dataset and loader over the comments, using the shared collate_fn.
test_ds = ToxicDataset(df, tokens, labeled=False)
test_loader = DataLoader(
    test_ds,
    collate_fn=collate_fn,
    batch_size=32,
    num_workers=0,
    pin_memory=True
)

# Inference only: make sure dropout-style layers are disabled (idempotent).
model.eval()
with torch.no_grad():
    # Use DEVICE rather than the default CUDA device so the batches always land
    # on the device the model was moved to.
    batch_outputs = [model(batch.to(DEVICE)) for batch, _ in test_loader]
# Sigmoid of the concatenated model outputs, scaled to a 0-100 range.
results = torch.sigmoid(torch.cat(batch_outputs)) * 100

# Scores labelled with AUX_COLUMNS, plus the source comment as "text".
predictions = pd.DataFrame(results.cpu().numpy(), columns=AUX_COLUMNS)
predictions["text"] = df["comment_text"].values
# Displayed result: a few score columns, two of them renamed for readability.
predictions[["text", "target", "identity_attack", "female", "homosexual_gay_or_lesbian", "black", "white"]].rename(
    columns={"target": "toxic", "homosexual_gay_or_lesbian":"homosexual"})

[88, 3, 19, 36, 42, 48, 72]


Unnamed: 0,text,toxic,identity_attack,female,homosexual,black,white
0,Pretty sad really. Thousands come out to prote...,49.56,0.36,0.08,0.0,0.0,0.01
1,Idiots.,93.84,1.45,0.09,0.0,0.0,0.03
2,Then put them in jail until they decide to coo...,65.76,0.54,0.02,0.0,0.0,0.0
3,"""He’s a proven liar. Why would you believe any...",43.15,0.32,0.02,0.0,0.0,0.01
4,"Fascinating young woman, impressive work. It r...",0.33,0.3,66.28,0.08,0.08,0.06
5,"Well, at least the Russians are white, for hea...",86.47,84.04,1.25,0.28,85.57,95.26
6,"Was I posted yesterday, it is interesting to n...",13.69,12.58,0.14,0.03,98.41,0.2


## Validate
Make sure the model is set up correctly by scoring it on a random sample of the validation data.

In [80]:
# Load the cached validation frame together with its token sequences.
# NOTE: joblib.load unpickles arbitrary objects — only load trusted files.
df_valid, tokens_valid = joblib.load(str(MODEL_PATH / "valid_bert-base-uncased_-1_yuval_f0.jbl"))

# Evaluate on a random subset of at most 32,000 rows. The fixed seed makes the
# subset (and therefore the metrics computed from it) reproducible, and
# replace=False stops the same row from being drawn more than once.
rng = np.random.RandomState(42)
n_sample = min(32 * 1000, df_valid.shape[0])
idx = rng.choice(df_valid.shape[0], size=n_sample, replace=False)
# tokens_valid is indexed with an integer array, so it is assumed to be a
# numpy array aligned row-for-row with df_valid — TODO confirm.
df_valid, tokens_valid = df_valid.iloc[idx].reset_index(drop=True), tokens_valid[idx]

valid_ds = ToxicDataset(df_valid, tokens_valid, labeled=True)
# Order examples by the length of their token sequence.
val_sampler = SortSampler(valid_ds, key=lambda x: len(valid_ds.tokens[x]))
# Reorder the frame into the sampler's iteration order; presumably this keeps
# its rows aligned with predictions produced through the same sampler (this
# relies on the sampler yielding the same order every time — TODO confirm).
df_valid = df_valid.iloc[list(val_sampler)]
print(df_valid.target.describe())

count    32000.0000
mean         0.0914
std          0.2074
min          0.0000
25%          0.0000
50%          0.0000
75%          0.0918
max          1.0000
Name: target, dtype: float64


In [81]:
# Loader over the validation sample. Iteration order comes from `val_sampler`
# rather than the dataset order; batches are built by the shared `collate_fn`.
valid_loader = DataLoader(
    valid_ds, sampler=val_sampler, collate_fn=collate_fn,
    batch_size=64, num_workers=0, pin_memory=True,
)

In [82]:
# Wrap the loaded model in a ToxicBot so its `predict` method can be used.
# Every training-related component is passed as None, and the checkpoint and
# log directories both point at /tmp/.
bot = ToxicBot(
    model=model,
    device=DEVICE,
    train_loader=None,
    val_loader=None,
    optimizer=None,
    criterion=None,
    checkpoint_dir=Path("/tmp/"),
    log_dir=Path("/tmp/"),
    avg_window=100,
    callbacks=[],
    echo=False,
    pbar=False,
    use_tensorboard=False,
)

# With return_y=True, predict returns a pair that is unpacked as
# (predictions, targets) for the validation loader.
valid_pred, valid_y = bot.predict(valid_loader, return_y=True)

In [84]:
# Show floats in displayed frames with 4 decimal places. The fully-qualified
# key is used because abbreviated option names are matched as patterns: on
# recent pandas the bare 'precision' is no longer a unique match (it also fits
# 'styler.format.precision') and set_option rejects it.
pd.set_option('display.precision', 4)

# Build the metric from the validation frame, then evaluate it on the targets
# and predictions returned by bot.predict.
metric = ToxicMetric(df_valid)
metric(valid_y, valid_pred)

   bnsp_auc  bpsn_auc                       subgroup  subgroup_auc  \
7    0.9621    0.9016                          white        0.8591   
5    0.9593    0.9054                         muslim        0.8613   
2    0.9742    0.8827      homosexual_gay_or_lesbian        0.8890   
6    0.9754    0.8874                          black        0.9051   
0    0.9579    0.9550                           male        0.9348   
1    0.9634    0.9517                         female        0.9389   
4    0.9739    0.9355                         jewish        0.9460   
8    0.9839    0.9218  psychiatric_or_mental_illness        0.9470   
3    0.9552    0.9674                      christian        0.9511   

   subgroup_size  
7            452  
5            345  
2            217  
6            276  
0            754  
1            946  
4            142  
8             88  
3            699  
Overall AUC: 0.970701
Mean bnsp_auc: 0.966989
Mean bpsn_auc: 0.920417
Mean subgroup auc: 0.910494
Final score

(-0.9421502044674531, '94.22')