In [1]:
import sys
sys.path.append("../processing/")

In [2]:
import gzip
import pickle
import random
import ujson as json
import numpy as np
import pandas as pd

from collections import defaultdict
from tqdm import tqdm

from _config import Config
from utils import json_paths_iter

In [3]:
# Build, for every dataset, a map root_tweet_id => per-conversation toxicity
# stats measured at several prefix lengths.
ds_names = ["news", "midterms"]
# prefix lengths (number of time-ordered tweets) at which a conversation is split
prefixes = list(range(10, 110, 10))
# a tweet is counted as toxic when its score is strictly above this value
TOX_THRESHOLD = 0.531
# dataset name => {root_tweet_id => stats dict}
ds_map = {}

for ds_name in ds_names:
    print("\n>>", ds_name)

    ds = {}
    conf = Config(ds_name)
    fpaths = json_paths_iter(conf.conversations_no_embs_jsons_dir)

    for i, fpath in enumerate(fpaths):
        if i % 10000 == 0: print(i, end=" ")

        # use a context manager so the gzip handle is closed right after
        # parsing instead of lingering until garbage collection
        with gzip.open(fpath) as fin:
            conversation = json.load(fin)

        root_tweet_id = conversation["reply_tree"]["tweet"]

        # sort tweets chronologically
        tweets = list(conversation["tweets"].values())
        tweets.sort(key=lambda x: x["time"])

        # id_str => tox score
        # NOTE(review): confirm tweet["id"] has the same type as the keys of
        # this dict; on a mismatch every lookup below falls back to -1.
        tox_scores = conversation["toxicity_scores"]

        # tox_scores array => 1-indexed (slot k holds the k-th tweet in time order)
        N = len(tweets)
        tox_scores_arr = np.zeros(N + 1)

        # dummy value in slot 0: a number (not None) that is never above the threshold
        tox_scores_arr[0] = -100
        for t_idx, tweet in enumerate(tweets):
            # tweets without a score get -1, i.e. they count as non-toxic
            tox_scores_arr[t_idx + 1] = tox_scores.get(tweet["id"], -1)

        # cum_sum[k] == number of toxic tweets among the first k tweets
        tox_scores_bin = tox_scores_arr > TOX_THRESHOLD
        tox_scores_cum_sum = np.cumsum(tox_scores_bin, dtype=float)
        tox_total = tox_scores_cum_sum[-1]

        assert np.sum(tox_scores_bin) == tox_scores_cum_sum[-1]

        # tweet data
        tweet_dict = {}
        tweet_dict["root_tweet_id"] = root_tweet_id
        tweet_dict["root_tweet_type"] = conversation["root_tweet_type"]
        tweet_dict["n"] = N

        for pre_n in prefixes:
            # only conversations with at least 2 * prefix tweets get p{prefix}_* fields
            if 2 * pre_n <= N:
                # [1, prefix]: toxic tweets inside the prefix
                pre_n_tox = tox_scores_cum_sum[pre_n]

                # (prefix, 2 * prefix]: toxic tweets in the window right after the prefix
                suf_i_tox = tox_scores_cum_sum[2 * pre_n] - pre_n_tox

                # (prefix, end]: size, toxic count and toxic fraction of the full suffix
                suf_n = N - pre_n
                suf_n_tox = tox_total - pre_n_tox
                suf_f_tox = suf_n_tox / suf_n

                tweet_dict[f"p{pre_n}_pre_n_tox"] = pre_n_tox
                tweet_dict[f"p{pre_n}_suf_i_tox"] = suf_i_tox
                tweet_dict[f"p{pre_n}_suf_n"] = suf_n
                tweet_dict[f"p{pre_n}_suf_n_tox"] = suf_n_tox
                tweet_dict[f"p{pre_n}_suf_f_tox"] = suf_f_tox

        # save tweet data
        ds[root_tweet_id] = tweet_dict

    # write ds to ds_map
    ds_map[ds_name] = ds


>> news
0 10000 20000 30000 40000 50000 60000 70000 80000 90000 100000 110000 120000 130000 140000 150000 160000 170000 180000 190000 200000 210000 220000 230000 240000 250000 260000 270000 280000 290000 300000 310000 320000 330000 340000 350000 360000 370000 380000 390000 400000 410000 420000 430000 440000 450000 460000 470000 480000 490000 500000 510000 
>> midterms
0 10000 20000 30000 40000 50000 60000 70000 80000 90000 100000 110000 120000 130000 140000 150000 160000 170000 180000 190000 200000 210000 220000 230000 240000 250000 260000 270000 280000 290000 300000 310000 320000 330000 340000 350000 360000 370000 380000 390000 400000 410000 420000 430000 440000 450000 460000 470000 480000 490000 500000 510000 520000 530000 540000 550000 560000 570000 580000 590000 600000 610000 620000 630000 640000 650000 660000 670000 

In [4]:
# Persist the per-dataset stats map as a gzip-compressed pickle.
ds_map_fpath = f"{Config().root}/data/modeling/prefix/label_maps.pkl.gz"

with gzip.open(ds_map_fpath, mode="wb") as fout:
    pickle.dump(ds_map, fout, protocol=4)

### Create labels

In [3]:
# Reload the per-dataset stats map written by the cell above.
# NOTE: pickle executes arbitrary code on load; only open files produced by this notebook.
ds_map_fpath = f"{Config().root}/data/modeling/prefix/label_maps.pkl.gz"

# context manager closes the gzip handle once the pickle has been read
with gzip.open(ds_map_fpath, "rb") as fin:
    ds_map = pickle.load(fin)

In [22]:
# Dataset to label in this run, the prefix lengths to consider, and the
# suffix metrics the labels are derived from.
ds_name = "midterms"
ds = ds_map[ds_name]
prefixes = list(range(10, 110, 10))
metrics = ["suf_i_tox", "suf_f_tox"]

In [23]:
# 1] for each prefix & metric collect all (root_tweet_id, value) pairs,
#    once over all conversations and once per prefix-toxicity count

acc = defaultdict(list)

for conv in ds.values():
    for p in prefixes:
        # conversations without stats for this prefix are skipped entirely
        if f"p{p}_pre_n_tox" not in conv:
            continue

        pre_n_tox = int(conv[f"p{p}_pre_n_tox"])

        for metric in metrics:
            id_val_pair = (conv["root_tweet_id"], conv[f"p{p}_{metric}"])

            for bucket in ("all", f"pre_n_tox_{pre_n_tox}"):
                acc[f"p{p}__{metric}__{bucket}"].append(id_val_pair)

In [24]:
# 2] for each partition (prefix / metric / bucket):
# --- compute the aggregate stats
# --- split the examples
# --- downsample to get class balance

def balance_samples(x, y, RNG):
    """Downsample the longer of two lists so both end up the same length.

    Args:
        x, y: the two lists to balance.
        RNG: object exposing ``sample(population, k)`` (e.g. ``random.Random``).

    Returns:
        ``(x, y)`` where the longer input has been replaced by a random sample
        of the shorter one's length; the shorter input (or both, when the
        lengths already match) is returned unchanged.
    """
    target = min(len(x), len(y))

    if len(x) > target:
        x = RNG.sample(x, target)
    elif len(y) > target:
        y = RNG.sample(y, target)

    assert len(x) == len(y)

    return x, y

# key => {"<q50": [...], ">=q50": [...], "<=q33": [...], ">=q66": [...]} of root ids
paritions = {}
# buckets with fewer examples than this get aggregate stats but no partition
min_bucket_size = 200
agg_stats = {"n": {}, "q33": {}, "q50": {}, "q66": {}}

for key, root_id_val_pairs in acc.items():

    vals = [pair[1] for pair in root_id_val_pairs]
    q33, q50, q66 = np.quantile(vals, [1/3, 0.5, 2/3])

    # save aggregates (recorded for every key, including the small ones)
    for stat_name, stat_val in (("n", len(vals)), ("q33", q33), ("q50", q50), ("q66", q66)):
        agg_stats[stat_name][key] = stat_val

    if len(vals) < min_bucket_size:
        continue

    # split root ids into partitions; the original pair order is kept
    # median split: every example lands on exactly one side
    q50_above = [rid for rid, val in root_id_val_pairs if val >= q50]
    q50_below = [rid for rid, val in root_id_val_pairs if not val >= q50]

    # tercile split: the upper test wins when a value satisfies both,
    # and values strictly between q33 and q66 are dropped
    q66_above = [rid for rid, val in root_id_val_pairs if val >= q66]
    q33_below = [
        rid for rid, val in root_id_val_pairs
        if not val >= q66 and val <= q33
    ]

    # downsample to achieve class balance (generator re-seeded for every key)
    RNG = random.Random(42)
    q50_below, q50_above = balance_samples(q50_below, q50_above, RNG)
    q33_below, q66_above = balance_samples(q33_below, q66_above, RNG)

    paritions[key] = {
        "<q50": q50_below, ">=q50": q50_above,
        "<=q33": q33_below, ">=q66": q66_above
    }

In [25]:
# 3] accumulate labels for each tweet for different prefix / metrics / aggregation

# tweet_id1 => {p10__suf_f_tox__tox_bucket__q33_v_q66: True}
labels = defaultdict(dict)


def _record_labels(labels, partition, label_key):
    """Write the boolean labels of one partition into ``labels``.

    Args:
        labels: root_id => {label name => bool}; updated in place.
        partition: dict with the ">=q50", "<q50", ">=q66" and "<=q33" lists of root ids.
        label_key: prefix of the label names written for this partition.
    """
    for bucket, label_name, value in (
        (">=q50", ">=q50", True),
        ("<q50", ">=q50", False),
        (">=q66", "q33_v_q66", True),
        ("<=q33", "q33_v_q66", False),
    ):
        for root_id in partition[bucket]:
            labels[root_id][f"{label_key}__{label_name}"] = value


for prefix in prefixes:
    for metric in metrics:
        # all conversations for this prefix / metric
        # (a KeyError here means this bucket was never added to `paritions`)
        key_all = f"p{prefix}__{metric}__all"
        _record_labels(labels, paritions[key_all], key_all)

        # toxicity buckets: a prefix of `prefix` tweets holds 0..prefix toxic
        # tweets, so iterate over this loop's own prefix rather than a loop
        # variable left over from another cell
        key_tox = f"p{prefix}__{metric}__tox_bucket"
        for pre_tox_n in range(prefix + 1):
            key_p = f"p{prefix}__{metric}__pre_n_tox_{pre_tox_n}"

            # buckets that were too small to be partitioned have no entry
            if key_p not in paritions:
                continue

            _record_labels(labels, paritions[key_p], key_tox)

In [26]:
# + other conversation stats: merge each labelled conversation's stats with
# its labels (label entries win on a key clash)
out = {
    root_tweet_id: {**ds[root_tweet_id], **conv_labels}
    for root_tweet_id, conv_labels in labels.items()
}

In [27]:
# output: write the merged stats + labels of this dataset as a gzip-compressed pickle
out_fpath = f"{Config().modeling_dir}/prefix/datasets/{ds_name}_labels.pkl.gz"

with gzip.open(out_fpath, mode="wb") as fout:
    pickle.dump(out, fout, protocol=4)

print("DONE!")

DONE!


In [None]:
# list(out.items())[0]

### Sanity Checks

In [29]:
# Re-derive every label from the stored stats and count how many labelled
# conversations fall into each (partition key, bucket) pair.
bucket_sizes = defaultdict(int)

for root_id, conv_labels in labels.items():
    conv_dict = ds[root_id]
    n = conv_dict["n"]

    for p in prefixes:
        for m in metrics:
            key_all_q50 = f"p{p}__{m}__all__>=q50"
            key_all_q33 = f"p{p}__{m}__all__q33_v_q66"
            key_tox_q50 = f"p{p}__{m}__tox_bucket__>=q50"
            key_tox_q33 = f"p{p}__{m}__tox_bucket__q33_v_q66"
            label_keys = (key_all_q50, key_all_q33, key_tox_q50, key_tox_q33)

            # (*) size > prefix size
            found = any(label_key in conv_labels for label_key in label_keys)
            if found:
                # assert n >= 2 * p
                assert n >= p + 10

            # (*) if q33_v_q66 = True => >=q50 = True (when the latter exists)
            for key_q33, key_q50 in (
                (key_all_q33, key_all_q50),
                (key_tox_q33, key_tox_q50),
            ):
                if conv_labels.get(key_q33, False):
                    assert conv_labels.get(key_q50, True) == True

            if not found:
                continue

            # (*) check the equalities
            m_val = conv_dict[f"p{p}_{m}"]
            pre_n_tox = int(conv_dict[f"p{p}_pre_n_tox"])
            part_all = f"p{p}__{m}__all"
            part_tox = f"p{p}__{m}__pre_n_tox_{pre_n_tox}"

            # (label key, partition key,
            #  stat + bucket when the label is True,
            #  stat + bucket when the label is False,
            #  whether the False side uses a strict "<" comparison)
            checks = (
                (key_all_q50, part_all, "q50", ">=q50", "q50", "<q50", True),
                (key_all_q33, part_all, "q66", ">=q66", "q33", "<=q33", False),
                (key_tox_q50, part_tox, "q50", ">=q50", "q50", "<q50", True),
                (key_tox_q33, part_tox, "q66", ">=q66", "q33", "<=q33", False),
            )

            for label_key, part_key, hi_stat, hi_bucket, lo_stat, lo_bucket, lo_strict in checks:
                if label_key not in conv_labels:
                    continue

                if conv_labels[label_key]:
                    bucket_sizes[(part_key, hi_bucket)] += 1
                    assert m_val >= agg_stats[hi_stat][part_key]
                elif lo_strict:
                    bucket_sizes[(part_key, lo_bucket)] += 1
                    assert m_val < agg_stats[lo_stat][part_key]
                else:
                    bucket_sizes[(part_key, lo_bucket)] += 1
                    assert m_val <= agg_stats[lo_stat][part_key]

# (*) are all tweet partitions accounted for
for (key, qq), count in bucket_sizes.items():
    assert len(paritions[key][qq]) == count

In [30]:
# (*) check the aggregate statistics: rebuild them independently with pandas,
# starting from one long-format row per (conversation, prefix, metric)
df_list = []

for conv in ds.values():
    for p in prefixes:
        if f"p{p}_pre_n_tox" not in conv:
            continue

        pre_n_tox_count = int(conv[f"p{p}_pre_n_tox"])

        for metric_name in ("suf_i_tox", "suf_f_tox"):
            df_list.append({
                "p": p,
                "pre_n_tox": pre_n_tox_count,
                "metric": metric_name,
                "val": conv[f"p{p}_{metric_name}"],
            })

df = pd.DataFrame(df_list)

In [31]:
def f_q33(x):
    """Return the 1/3 quantile of ``x`` (used as a groupby aggregator)."""
    lower_tercile = np.quantile(x, 1/3)
    return lower_tercile

def f_q50(x):
    """Return the 0.5 quantile (median) of ``x`` (used as a groupby aggregator)."""
    median = np.quantile(x, 0.5)
    return median

def f_q66(x):
    """Return the 2/3 quantile of ``x`` (used as a groupby aggregator)."""
    upper_tercile = np.quantile(x, 2/3)
    return upper_tercile

# count + tercile / median aggregates; the function names become column labels
quantile_aggs = ["count", f_q33, f_q50, f_q66]

# per (prefix, metric) over all conversations
df_med_all = (
    df.groupby(["p", "metric"])["val"]
    .agg(quantile_aggs)
    .reset_index()
)

# per (prefix, metric, number of toxic tweets in the prefix)
df_med_pre = (
    df.groupby(["p", "metric", "pre_n_tox"])["val"]
    .agg(quantile_aggs)
    .reset_index()
)

In [32]:
# the pandas aggregates over all conversations must equal the stored agg_stats
for _, row in df_med_all.iterrows():
    p = row["p"]
    m = row["metric"]
    key = f"p{p}__{m}__all"

    assert row["count"] == agg_stats["n"][key]
    for q_name in ("q33", "q50", "q66"):
        assert row[f"f_{q_name}"] == agg_stats[q_name][key]

In [33]:
# the pandas aggregates per prefix-toxicity count must equal the stored agg_stats
for _, row in df_med_pre.iterrows():
    p = row["p"]
    m = row["metric"]
    pre_n_tox = int(row["pre_n_tox"])
    key = f"p{p}__{m}__pre_n_tox_{pre_n_tox}"

    assert row["count"] == agg_stats["n"][key]
    for q_name in ("q33", "q50", "q66"):
        assert row[f"f_{q_name}"] == agg_stats[q_name][key]