Creating a balanced test set for flat training (cross validation)

In [1]:
import os
import yaml
import pickle

import numpy as np
import pandas as pd

from matplotlib import pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split

In [2]:
cd ../src/

/Users/cock/kDrive/PhD/Projects/Labs/beerslaw-lab/src


# Data


In [3]:
with open('../data/post_test/rankings_scored.pkl', 'rb') as fp:
    rankings = pickle.load(fp)
    rankings['prior_4cat_knowledge'] = rankings['prior_4cat'].apply(lambda x: x[0] if x != 'none' else 0)
    rankings['total_score'] = rankings['q1_score'] + rankings['q2_score'] + rankings['q3_score'] + rankings['q4_score'] + rankings['q5_score'] + rankings['q6_score']
    
# stratification columns (ideal separation)
stratification = [
    'language', 'field', 'year', 'gender', 'prior_4cat_knowledge', 'total_score'
]
rankings['stratification_column_v0'] = rankings['language'] + '_' + rankings['field'] + '_' + rankings['year'] + '_' + rankings['gender'].astype(str) + '_' + rankings['prior_4cat_knowledge'].astype(str) + '_' + rankings['total_score'].astype(str)
rankings = rankings.dropna(subset=['stratification_column_v0'])

# Check the labels that are alone 
## Prune language
clean = rankings[['username', 'stratification_column_v0']].groupby('stratification_column_v0').nunique().reset_index()
clean = clean[clean['username'] == 1]
unique_attributes = list(clean['stratification_column_v0'])
def reduce_attribute_no_lang(row, unique_attributes) -> str:
    if row['stratification_column_v0'] in unique_attributes:
        new_strat = row['field'] + '_' + row['year'] + '_' + str(row['gender']) + '_' + str(row['prior_4cat_knowledge']) + '_' + str(row['total_score'])
        return new_strat
    else:
        return row['stratification_column_v0']
rankings['stratification_column_v1'] = rankings.apply(lambda x: reduce_attribute_no_lang(x, unique_attributes), axis=1)
## Prune language
clean = rankings[['username', 'stratification_column_v1']].groupby('stratification_column_v1').nunique().reset_index()
clean = clean[clean['username'] == 1]
unique_attributes = list(clean['stratification_column_v1'])
def reduce_attribute_no_lang(row, unique_attributes) -> str:
    if row['stratification_column_v1'] in unique_attributes:
        new_strat = row['year'] + '_' + str(row['gender']) + '_' + str(row['prior_4cat_knowledge']) + '_' + str(row['total_score'])
        return new_strat
    else:
        return row['stratification_column_v1']
rankings['stratification_column_v2'] = rankings.apply(lambda x: reduce_attribute_no_lang(x, unique_attributes), axis=1)

## Round Score
clean = rankings[['username', 'stratification_column_v2']].groupby('stratification_column_v2').nunique().reset_index()
clean = clean[clean['username'] == 1]
unique_attributes = list(clean['stratification_column_v2'])
def reduce_attribute_no_lang(row, unique_attributes) -> str:
    if row['stratification_column_v2'] in unique_attributes:
        score = str(int(row['total_score']/6))
        new_strat = row['year'] + '_' + str(row['gender']) + '_' + str(row['prior_4cat_knowledge']) + '_' + score
        return new_strat
    else:
        return row['stratification_column_v2']
rankings['stratification_column_v3'] = rankings.apply(lambda x: reduce_attribute_no_lang(x, unique_attributes), axis=1)

# Minority genders together
clean = rankings[['username', 'stratification_column_v3']].groupby('stratification_column_v3').nunique().reset_index()
clean = clean[clean['username'] == 1]
unique_attributes = list(clean['stratification_column_v3'])
def reduce_attribute_no_lang(row, unique_attributes) -> str:
    if row['stratification_column_v3'] in unique_attributes:
        score = str(int(row['total_score']/6))
        gender = str(int(row['gender'] > 1))
        new_strat = row['year'] + '_' + gender + '_' + str(row['prior_4cat_knowledge']) + '_' + score
        return new_strat
    else:
        return row['stratification_column_v3']
rankings['stratification_column_v4'] = rankings.apply(lambda x: reduce_attribute_no_lang(x, unique_attributes), axis=1)

# Strip the year out
clean = rankings[['username', 'stratification_column_v4']].groupby('stratification_column_v4').nunique().reset_index()
clean = clean[clean['username'] == 1]
unique_attributes = list(clean['stratification_column_v4'])
def reduce_attribute_no_lang(row, unique_attributes) -> str:
    if row['stratification_column_v4'] in unique_attributes:
        score = str(int(row['total_score']/6))
        gender = str(int(row['gender'] > 1))
        new_strat = gender + '_' + str(row['prior_4cat_knowledge']) + '_' + score
        return new_strat
    else:
        return row['stratification_column_v4']
rankings['stratification_column_v5'] = rankings.apply(lambda x: reduce_attribute_no_lang(x, unique_attributes), axis=1)

# Strip Gender
clean = rankings[['username', 'stratification_column_v5']].groupby('stratification_column_v5').nunique().reset_index()
clean = clean[clean['username'] == 1]
unique_attributes = list(clean['stratification_column_v5'])
def reduce_attribute_no_lang(row, unique_attributes) -> str:
    if row['stratification_column_v5'] in unique_attributes:
        score = str(int(row['total_score']/6))
        new_strat = str(row['prior_4cat_knowledge']) + '_' + score
        return new_strat
    else:
        return row['stratification_column_v5']
rankings['stratification_column_v6'] = rankings.apply(lambda x: reduce_attribute_no_lang(x, unique_attributes), axis=1)

# Group remains together
clean = rankings[['username', 'stratification_column_v6']].groupby('stratification_column_v6').nunique().reset_index()
clean = clean[clean['username'] == 1]
unique_attributes = list(clean['stratification_column_v6'])
def reduce_attribute_no_lang(row, unique_attributes) -> str:
    if row['stratification_column_v6'] in unique_attributes:
        return 'no_group'
    else:
        return row['stratification_column_v6']
rankings['stratification_column_v7'] = rankings.apply(lambda x: reduce_attribute_no_lang(x, unique_attributes), axis=1)


rankings['stratification_column'] = rankings['stratification_column_v7'].copy()




In [4]:
train, test = train_test_split(rankings, test_size=0.35, random_state=0, stratify=rankings[['stratification_column']])

In [5]:
for strat in stratification:
    
    test_strat = test[[strat, 'username']].groupby(strat).nunique().reset_index()
    test_strat['username'] = test_strat['username'] / sum(test_strat['username'])
    test_strat.columns = [strat, 'test']
    
    train_strat = train[[strat, 'username']].groupby(strat).nunique().reset_index()
    train_strat['username'] = train_strat['username'] / sum(train_strat['username'])
    train_strat.columns = [strat, 'train']
    
    strat_df = test_strat.merge(train_strat, on=strat, how='inner')
    
    print(strat)
    display(strat_df)
    print()

language


Unnamed: 0,language,test,train
0,Deutsch,0.78626,0.777778
1,Français,0.21374,0.222222



field


Unnamed: 0,field,test,train
0,Biology,0.091603,0.106996
1,Chemistry,0.625954,0.625514
2,"Chemistry, Textiles",0.167939,0.152263
3,Fast track,0.053435,0.045267
4,Pharma Chemistry,0.061069,0.069959



year


Unnamed: 0,year,test,train
0,1st,0.450382,0.44856
1,2nd,0.374046,0.37037
2,3rd,0.175573,0.18107



gender


Unnamed: 0,gender,test,train
0,1,0.503817,0.460905
1,2,0.450382,0.497942
2,3,0.022901,0.020576
3,4,0.022901,0.020576



prior_4cat_knowledge


Unnamed: 0,prior_4cat_knowledge,test,train
0,0,0.122137,0.135802
1,0,0.328244,0.320988
2,1,0.198473,0.197531
3,2,0.076336,0.057613
4,3,0.274809,0.288066



total_score


Unnamed: 0,total_score,test,train
0,0,0.099237,0.115226
1,1,0.160305,0.144033
2,2,0.145038,0.152263
3,3,0.175573,0.18107
4,4,0.206107,0.226337
5,5,0.160305,0.127572
6,6,0.053435,0.053498





In [8]:
with open('../data/experiment_keys/flatstrat_testusernames.pkl', 'wb') as fp:
    pickle.dump(list(test['username']), fp)
with open('../data/experiment_keys/flatstrat_trainusernames.pkl', 'wb') as fp:
    pickle.dump(list(train['username']), fp)

In [18]:
for col in rankings.columns:
    print(col, '         ', rankings[rankings['username'] == 'wu7kdm6q'][col].iloc[0])

username           wu7kdm6q
start_time           {'time': 1620814035}
exploration_time           {'time': 1620814117}
ranking_task_time           {'time': 1620814181}
ranking           0312
ranking_confidence           wrong field
ranking_time           1620814270
q1           0.37
q1_conf           50
q1_time           1620814323
q2           1.59
q2_conf           0
q2_time           1620814355
q3           0.95
q3_conf           50
q3_time           1620814399
q4           0.2
q4_conf           35
q4_time           1620814447
q5_colour0           0
q5_colour1           0
q5_colour2           100
q5_colour3           0
q5_time           1620814512
q6_colour0           0
q6_colour1           100
q6_colour2           0
q6_colour3           0
q6_time           1620814711
q7_colour0           missing
q7_colour1           missing
q7_colour2           missing
q7_colour3           missing
q7_time           missing
q8_colour0           missing
q8_colour1           missing
q8_colour2         