In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import embgam
from datasets import load_dataset
import pandas as pd
from toxigen import label_annotations
import numpy as np
import sklearn.metrics
import pickle as pkl

### Load data

In [2]:
# `use_auth_token=True` makes `datasets` send the locally stored Hugging Face token
# (the dataset presumably requires an authenticated account).
# NOTE(review): newer `datasets` releases deprecate this argument in favour of `token=True`.
TG = load_dataset("skg/toxigen-data", name="annotated", use_auth_token=True)

# Keep the raw test frame: its metadata columns (e.g. `target_group`) are needed
# later for the per-group breakdown.
df_TG_test = pd.DataFrame(TG["test"])

# `label_annotations` (from toxigen) presumably derives the binary 'label' column
# used for fitting and evaluation — confirm in toxigen.
df_train, df_test = (
    label_annotations(frame)
    for frame in (pd.DataFrame(TG["train"]), df_TG_test)
)

Reusing dataset toxigen-data (/home/chansingh/.cache/huggingface/datasets/skg___toxigen-data/annotated/1.1.0/3dd39bc1508e10d3eebcca2f60948e1529149c78a24594fd929aaa1f1bda74d0)


  0%|          | 0/2 [00:00<?, ?it/s]

### Fit model

In [None]:
# Fit an EmbGAM classifier on the 'tomh/toxigen_roberta' checkpoint with
# ngrams=2 and all_ngrams=False (presumably: bigrams only rather than all
# 1..2-grams — confirm in embgam), then pickle it so the evaluation cell can
# reload it without refitting.
m = embgam.EmbGAMClassifier(
    checkpoint='tomh/toxigen_roberta',
    ngrams=2,
    all_ngrams=False,
)
m.fit(df_train['text'], df_train['label'])
# Optional step, intentionally left disabled:
# m.cache_linear_coefs(df_test['text'])

# The context manager guarantees the pickle is flushed and the file handle closed.
with open('toxigen_embgam_ngrams=2_roberta.pkl', 'wb') as model_file:
    pkl.dump(m, model_file)

### Evaluate performance

In [6]:
# Reload the classifier pickled by the fit cell, so evaluation works without refitting.
# NOTE: pickle.load can execute arbitrary code — only load model files you created or trust.
with open('toxigen_embgam_ngrams=2_roberta.pkl', 'rb') as model_file:
    m = pkl.load(model_file)

def get_metrics(m, df):
    """Score a fitted classifier on a labelled DataFrame.

    Args:
        m: fitted model exposing ``predict`` and ``predict_proba`` over texts.
        df: DataFrame with a ``'text'`` column (model input) and a
            ``'label'`` column (ground truth).

    Returns:
        Tuple ``(accuracy, roc_auc)``. ROC AUC is computed from column 1 of
        ``predict_proba`` (presumably the positive class — assumes a binary task).

    Raises:
        ValueError: propagated from ``roc_auc_score``, e.g. when ``df`` contains
            only one label value.
    """
    texts = df['text']
    y_true = df['label']
    y_pred = m.predict(texts)
    y_proba = m.predict_proba(texts)
    return (
        sklearn.metrics.accuracy_score(y_true, y_pred),
        sklearn.metrics.roc_auc_score(y_true, y_proba[:, 1]),
    )

# Overall held-out performance; one metric per output line.
acc, roc_auc = get_metrics(m, df_test)
print(
    f'Test accuracy {acc:0.2f}',
    f'Test ROC AUC {roc_auc:0.2f}',
    sep='\n',
)

Test accuracy 0.68
Test ROC AUC 0.77




In [17]:
# ROC AUC broken down by the target group of each test example.
# Groups whose subset holds a single label make roc_auc_score raise ValueError;
# they are reported as NaN instead of aborting the loop. Any other error propagates.
# NOTE(review): the mask is built from df_TG_test but applied to df_test, so this
# assumes label_annotations preserves the original row index — confirm.
rocs = []
target_groups = df_TG_test['target_group'].unique()
for target_group in target_groups:
    group_mask = df_TG_test['target_group'] == target_group
    df_group = df_test[group_mask]
    try:
        _, roc_auc = get_metrics(m, df_group)
    except ValueError:
        # ROC AUC is undefined for this group (e.g. only one class present).
        roc_auc = np.nan
    rocs.append(roc_auc)
pd.DataFrame.from_dict({'target_group': target_groups, 'roc': rocs})



Unnamed: 0,target_group,roc
0,black/african-american folks,0.860465
1,black folks / african-americans,0.888889
2,mexican folks,0.702128
3,women,0.849112
4,native american/indigenous folks,0.897436
5,native american folks,
6,folks with physical disabilities,0.678796
7,latino/hispanic folks,0.611529
8,chinese folks,0.777778
9,middle eastern folks,0.836094


### Interpret model

In [None]:
# Interpret the fitted model: m.coefs_dict_ maps each ngram to its learned
# coefficient. List the extremes at both ends.
TOP_K = 8  # number of ngrams to list at each extreme


def print_top_ngrams(coefs_dict, title, k=TOP_K, reverse=False):
    """Print ``title`` followed by the ``k`` ngrams with the smallest
    coefficients, or the largest when ``reverse=True``."""
    print(title)
    ranked = sorted(coefs_dict.items(), key=lambda item: item[1], reverse=reverse)
    for ngram, coef in ranked[:k]:
        print('\t', ngram, round(coef, 2))


print(f'Total ngram coefficients: {len(m.coefs_dict_)}')
print_top_ngrams(m.coefs_dict_, 'Most positive ngrams', reverse=True)
print_top_ngrams(m.coefs_dict_, 'Most negative ngrams')