In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

%matplotlib inline

In [None]:
# Seed for NumPy's global random generator; it is passed to np.random.seed()
# right before the folds are built so the split can be reproduced.
seed = 19

In [None]:
# Load the training set (path is relative to the notebook's working directory)
# and show the first rows as a quick sanity check.
df = pd.read_csv('../dataset/train.csv')
df.head()

In [None]:
# Label columns analysed in this notebook. The order matters: the split section
# packs them into the bits of df['xxx'], first column = most significant bit.
cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

In [None]:
# Per-label positive count and its share of all rows.
n_rows = len(df)
print('total', n_rows)
print('-----------------')
for col in cols:
    cnt = df[col].sum()
    pct = cnt * 100.0 / n_rows
    print(col, cnt, '({:.2f}%)'.format(pct))

In [None]:
# Correlation matrix of the label columns, drawn as a heat map.
fig, ax = plt.subplots(1, 1)
cax = ax.matshow(df[cols].corr())
fig.colorbar(cax)

# Pin one tick to every matrix row/column and label it directly. Setting the
# labels first (padded with '') and a MultipleLocator afterwards only lines up
# when the locator happens to emit an extra tick at -1.
tick_positions = np.arange(len(cols))
ax.set_xticks(tick_positions)
ax.set_xticklabels(cols, rotation=90)
ax.set_yticks(tick_positions)
ax.set_yticklabels(cols);  # trailing ';' hides the list-of-Text repr in the cell output

In [None]:
# Pairwise label overlap: for every unordered pair of labels, count the rows
# that carry both. A pair is flagged with '!!!' when that count reaches 66% of
# the rarer label's own count.
totals = {name: df[name].sum() for name in cols}
for i, col1 in enumerate(cols):
    cnt1 = totals[col1]
    for col2 in cols[i+1:]:
        cnt2 = totals[col2]
        cnt_common = (df[col1] * df[col2]).sum()
        e = cnt_common >= min(cnt1, cnt2) * 0.66
        flag = '!!!' if e else ''
        print(col1, cnt1, col2, cnt2, 'common', cnt_common, flag)

In [None]:
# Compare 'toxic' against the remaining labels ("xxx" = any non-'toxic' label).
cols2 = list(cols)
cols2.remove('toxic')

a = df[cols2].sum(axis=1)  # per-row count of non-'toxic' labels
toxic = df['toxic']

# Rows with at least one other label but toxic == 0.
cnt = ((toxic == 0) & (a > 0)).sum()
print('xxx but not toxic =', cnt)

# Rows with toxic > 0 and none of the other labels.
cnt = ((toxic > 0) & (a == 0)).sum()
print('toxic but not xxx =', cnt)

### Split into folds, stratified by label combination

In [None]:
# Encode each row's label combination as one integer: the columns in `cols`
# are read as binary digits, with the first column as the most significant.
code = pd.Series(0, index=df.index)
for col in cols:
    code = code * 2 + df[col]
df['xxx'] = code

# Uncomment to inspect how many rows fall into each combination:
# df['xxx'].value_counts()

In [None]:
def get_split_indices(num, split_cnt):
    """Shuffle the positions ``0 .. num-1`` and cut them into ``split_cnt`` chunks.

    The chunks are as even as possible: ``np.array_split`` gives the first
    ``num % split_cnt`` chunks one extra element, so chunk sizes differ by at
    most one. (Handing the whole remainder to the last chunk, e.g. sizes
    1,1,1,1,5 for num=9 and split_cnt=5, makes the last fold systematically
    larger once many groups are combined.)

    Shuffling uses NumPy's global random state, so call ``np.random.seed``
    beforehand for a reproducible result.

    Parameters
    ----------
    num : int
        Number of positions to distribute.
    split_cnt : int
        Number of chunks to produce; must be positive.

    Returns
    -------
    list of numpy.ndarray
        Exactly ``split_cnt`` integer arrays that together contain every
        position once. Some are empty when ``num < split_cnt``.

    Raises
    ------
    ValueError
        Raised by ``np.array_split`` when ``split_cnt`` is not positive.
    """
    indices = np.arange(num)
    np.random.shuffle(indices)
    return np.array_split(indices, split_cnt)

def do_split(split_cnt=5):
    """Build ``split_cnt`` folds of the global ``df``, stratified by ``df['xxx']``.

    Rows are grouped by their ``xxx`` value and each group is distributed over
    the folds separately with ``get_split_indices``, so every fold receives a
    share of every label combination.

    The groups come from the data itself rather than a fixed ``range(64)``,
    so no row is silently left out if the number of label columns (and hence
    the range of ``xxx``) changes.

    Parameters
    ----------
    split_cnt : int, default 5
        Number of folds.

    Returns
    -------
    list of list
        ``split_cnt`` lists holding index labels of ``df``. With the default
        RangeIndex these equal row positions, which is what ``df.iloc`` needs.
    """
    df_indices = [[] for _ in range(split_cnt)]
    # groupby visits the xxx values that occur, in ascending order.
    for _, dfx in df.groupby('xxx'):
        indices_lst = get_split_indices(dfx.shape[0], split_cnt)
        for fold, indices in zip(df_indices, indices_lst):
            # `indices` are positions inside dfx; map them back to df's index.
            fold.extend(dfx.index[indices].tolist())
    return df_indices

In [None]:
# Seed NumPy's global generator immediately before splitting: the shuffles in
# get_split_indices draw from it, so the same seed and data give the same folds.
np.random.seed(seed)

df_indices = do_split()

In [None]:
# validate1: every row of df must appear in exactly one fold.
a = []
for fold in df_indices:
    a += fold
# The length check catches an index placed in two folds, which the set
# comparison alone would hide; the set comparison catches missing rows.
len(a) == df.shape[0] and set(a) == set(range(df.shape[0]))

In [None]:
# validate2: print the label distribution of every fold so it can be compared
# with the whole-data figures. Iterates over df_indices itself, so the check
# follows whatever number of folds do_split produced.
for fold in df_indices:
    df1 = df.iloc[fold]  # fold entries are used as row positions here
    n_fold = df1.shape[0]
    print('total', n_fold)
    print('-----------------')
    for col in cols:
        cnt = df1[col].sum()
        print(col, cnt, '({:.2f}%)'.format(cnt*100.0/n_fold))
    print('\n')

In [None]:
# The folds generally have different lengths, so handing the list of lists to
# np.savez makes NumPy build a ragged array, which raises ValueError on
# NumPy >= 1.24. Build the 1-D object array explicitly, one fold per element.
# Read it back with np.load(..., allow_pickle=True).
indices_arr = np.empty(len(df_indices), dtype=object)
for k, fold in enumerate(df_indices):
    indices_arr[k] = fold
np.savez('../dataset/train_split.npz', indices=indices_arr)