In [9]:
from sagemaker.huggingface.model import HuggingFaceModel
from datasets import Dataset
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np
import pandas as pd
import awswrangler as wr
from quality_calculator import compute_bias_metrics_for_model, calculate_overall_auc, get_final_metric
from tqdm import tqdm
tqdm.pandas()

SEED = 1234

import os
import random
import gc
import warnings
pd.set_option('display.max_columns', None)
pd.set_option('display.max_colwidth', 500)
pd.options.mode.chained_assignment = None
warnings.simplefilter(action='ignore', category=FutureWarning)

In [10]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Today I'm going to use {device.type}")

model_name = 'unitary/toxic-bert'

model = AutoModelForSequenceClassification.from_pretrained(model_name, cache_dir = '../tmp/AutoModel')
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir = '../tmp/AutoTokenizer')    

Today I'm going to use cuda


In [11]:
def seed_everything(seed=SEED):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
seed_everything()
gc.collect()
torch.cuda.empty_cache()
# del model
# del Trainer
# del tokenizer

In [12]:
with open("../data/godel.txt") as f:
    lines = f.readlines()
godel_test_comments = pd.DataFrame(lines, columns = ['comment_text'])
godel_test_comments['comment_text'] = godel_test_comments['comment_text'].apply(lambda x: x.strip())

In [13]:
sigmoid = torch.nn.Sigmoid()

In [14]:
def predict_toxiity(text):
    inputs = tokenizer(text, padding="max_length", truncation=True, return_tensors="pt").to(device)
    with torch.no_grad():
        model.to(device)
        outputs = model(**inputs).logits
        probas = sigmoid(outputs).cpu().detach().numpy()[0]
        return {
            'toxicity': probas[0],
            'severe_toxic': probas[1],
            'obscene': probas[2],
            'threat': probas[3],
            'insult': probas[4],
            'identity_hate': probas[5]
        }

In [15]:
%%time
godel_test_comments['results'] = godel_test_comments['comment_text'].apply(predict_toxiity)

CPU times: user 3.17 s, sys: 889 ms, total: 4.06 s
Wall time: 4.07 s


In [16]:
godel_test_comments = godel_test_comments.join(pd.json_normalize(godel_test_comments['results'])).drop(['results'], axis=1)
godel_test_comments.head(3)

Unnamed: 0,comment_text,toxicity,severe_toxic,obscene,threat,insult,identity_hate
0,"Hi team, August 15th is a state holiday in Poland, so do I understand correctly that we will receive our salaries on Friday, 12th?",0.000573,0.000125,0.000177,0.000129,0.000176,0.000144
1,"Or on Tuesday, 16th",0.000824,0.000107,0.000189,0.000113,0.000182,0.000133
2,"according to contract invoices have to be payed within 15 day, no exceptions for public holidays were described. paying on Tuesday, 16th seems to be a breach of contract. Viktoryia Charnianina please correct me if I'm wrong",0.000693,0.000112,0.000171,0.000119,0.000176,0.000137


In [17]:
BUCKET_NAME = 'sagemaker-godeltech'
TEST_PATH = f"s3://{BUCKET_NAME}/data/test/test.csv"
test = wr.s3.read_csv([TEST_PATH])

In [None]:
%%time
results = test['comment_text'].progress_apply(predict_toxiity)
# results = np.vectorize(predict_toxiity)(test['comment_text'])

  3%|▎         | 6288/194641 [03:51<2:00:27, 26.06it/s]

In [None]:
predictions = np.where(pd.json_normalize(results)['toxicity'] >= 0.5, 1, 0)

In [None]:
oof_name = 'predicted_target'
identity_columns = ['male', 'female', 'homosexual_gay_or_lesbian', 'christian', 'jewish', 'muslim', 'black', 'white', 'psychiatric_or_mental_illness']
test[oof_name] = predictions
#evaluation
bias_metrics_df = compute_bias_metrics_for_model(test, identity_columns, oof_name, 'toxicity')
display(bias_metrics_df)
FINAL_SCORE = get_final_metric(bias_metrics_df, calculate_overall_auc(test, oof_name))
print(f"FINAL SCORE FOR TOXIC-BERT IS {FINAL_SCORE}")  