Creating a balanced test set for flat training (cross validation)

In [1]:
import os
import yaml
import pickle

import numpy as np
import pandas as pd

from matplotlib import pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split

In [2]:
cd ../src/

/Users/cock/kDrive/PhD/Projects/Labs/beerslaw-lab/src


# Data


In [3]:
with open('../data/post_test/rankings_scored.pkl', 'rb') as fp:
    rankings = pickle.load(fp)
    rankings['prior_4cat_knowledge'] = rankings['prior_4cat'].apply(lambda x: x[0] if x != 'none' else 0)
    rankings['total_score'] = rankings['q1_score'] + rankings['q2_score'] + rankings['q3_score'] + rankings['q4_score'] + rankings['q5_score'] + rankings['q6_score']
    
# stratification columns (ideal separation)
stratification = [
    'language', 'field', 'year', 'gender', 'prior_4cat_knowledge', 'total_score'
]
rankings['stratification_column_v0'] = rankings['language'] + '_' + rankings['field'] + '_' + rankings['year'] + '_' + rankings['gender'].astype(str) + '_' + rankings['prior_4cat_knowledge'].astype(str) + '_' + rankings['total_score'].astype(str)
rankings = rankings.dropna(subset=['stratification_column_v0'])

# Check the labels that are alone 
## Prune language
clean = rankings[['username', 'stratification_column_v0']].groupby('stratification_column_v0').nunique().reset_index()
clean = clean[clean['username'] == 1]
unique_attributes = list(clean['stratification_column_v0'])
def reduce_attribute_no_lang(row, unique_attributes) -> str:
    if row['stratification_column_v0'] in unique_attributes:
        new_strat = row['field'] + '_' + row['year'] + '_' + str(row['gender']) + '_' + str(row['prior_4cat_knowledge']) + '_' + str(row['total_score'])
        return new_strat
    else:
        return row['stratification_column_v0']
rankings['stratification_column_v1'] = rankings.apply(lambda x: reduce_attribute_no_lang(x, unique_attributes), axis=1)
## Prune language
clean = rankings[['username', 'stratification_column_v1']].groupby('stratification_column_v1').nunique().reset_index()
clean = clean[clean['username'] == 1]
unique_attributes = list(clean['stratification_column_v1'])
def reduce_attribute_no_lang(row, unique_attributes) -> str:
    if row['stratification_column_v1'] in unique_attributes:
        new_strat = row['year'] + '_' + str(row['gender']) + '_' + str(row['prior_4cat_knowledge']) + '_' + str(row['total_score'])
        return new_strat
    else:
        return row['stratification_column_v1']
rankings['stratification_column_v2'] = rankings.apply(lambda x: reduce_attribute_no_lang(x, unique_attributes), axis=1)

## Round Score
clean = rankings[['username', 'stratification_column_v2']].groupby('stratification_column_v2').nunique().reset_index()
clean = clean[clean['username'] == 1]
unique_attributes = list(clean['stratification_column_v2'])
def reduce_attribute_no_lang(row, unique_attributes) -> str:
    if row['stratification_column_v2'] in unique_attributes:
        score = str(int(row['total_score']/6))
        new_strat = row['year'] + '_' + str(row['gender']) + '_' + str(row['prior_4cat_knowledge']) + '_' + score
        return new_strat
    else:
        return row['stratification_column_v2']
rankings['stratification_column_v3'] = rankings.apply(lambda x: reduce_attribute_no_lang(x, unique_attributes), axis=1)

# Minority genders together
clean = rankings[['username', 'stratification_column_v3']].groupby('stratification_column_v3').nunique().reset_index()
clean = clean[clean['username'] == 1]
unique_attributes = list(clean['stratification_column_v3'])
def reduce_attribute_no_lang(row, unique_attributes) -> str:
    if row['stratification_column_v3'] in unique_attributes:
        score = str(int(row['total_score']/6))
        gender = str(int(row['gender'] > 1))
        new_strat = row['year'] + '_' + gender + '_' + str(row['prior_4cat_knowledge']) + '_' + score
        return new_strat
    else:
        return row['stratification_column_v3']
rankings['stratification_column_v4'] = rankings.apply(lambda x: reduce_attribute_no_lang(x, unique_attributes), axis=1)

# Strip the year out
clean = rankings[['username', 'stratification_column_v4']].groupby('stratification_column_v4').nunique().reset_index()
clean = clean[clean['username'] == 1]
unique_attributes = list(clean['stratification_column_v4'])
def reduce_attribute_no_lang(row, unique_attributes) -> str:
    if row['stratification_column_v4'] in unique_attributes:
        score = str(int(row['total_score']/6))
        gender = str(int(row['gender'] > 1))
        new_strat = gender + '_' + str(row['prior_4cat_knowledge']) + '_' + score
        return new_strat
    else:
        return row['stratification_column_v4']
rankings['stratification_column_v5'] = rankings.apply(lambda x: reduce_attribute_no_lang(x, unique_attributes), axis=1)

# Strip Gender
clean = rankings[['username', 'stratification_column_v5']].groupby('stratification_column_v5').nunique().reset_index()
clean = clean[clean['username'] == 1]
unique_attributes = list(clean['stratification_column_v5'])
def reduce_attribute_no_lang(row, unique_attributes) -> str:
    if row['stratification_column_v5'] in unique_attributes:
        score = str(int(row['total_score']/6))
        new_strat = str(row['prior_4cat_knowledge']) + '_' + score
        return new_strat
    else:
        return row['stratification_column_v5']
rankings['stratification_column_v6'] = rankings.apply(lambda x: reduce_attribute_no_lang(x, unique_attributes), axis=1)

# Group remains together
clean = rankings[['username', 'stratification_column_v6']].groupby('stratification_column_v6').nunique().reset_index()
clean = clean[clean['username'] == 1]
unique_attributes = list(clean['stratification_column_v6'])
def reduce_attribute_no_lang(row, unique_attributes) -> str:
    if row['stratification_column_v6'] in unique_attributes:
        return 'no_group'
    else:
        return row['stratification_column_v6']
rankings['stratification_column_v7'] = rankings.apply(lambda x: reduce_attribute_no_lang(x, unique_attributes), axis=1)


rankings['stratification_column'] = rankings['stratification_column_v7'].copy()




In [4]:
train, test = train_test_split(rankings, test_size=0.35, random_state=0, stratify=rankings[['stratification_column']])

In [5]:
for strat in stratification:
    
    test_strat = test[[strat, 'username']].groupby(strat).nunique().reset_index()
    test_strat['username'] = test_strat['username'] / sum(test_strat['username'])
    test_strat.columns = [strat, 'test']
    
    train_strat = train[[strat, 'username']].groupby(strat).nunique().reset_index()
    train_strat['username'] = train_strat['username'] / sum(train_strat['username'])
    train_strat.columns = [strat, 'train']
    
    strat_df = test_strat.merge(train_strat, on=strat, how='inner')
    
    print(strat)
    display(strat_df)
    print()

language


Unnamed: 0,language,test,train
0,Deutsch,0.78626,0.777778
1,Français,0.21374,0.222222



field


Unnamed: 0,field,test,train
0,Biology,0.091603,0.106996
1,Chemistry,0.625954,0.625514
2,"Chemistry, Textiles",0.167939,0.152263
3,Fast track,0.053435,0.045267
4,Pharma Chemistry,0.061069,0.069959



year


Unnamed: 0,year,test,train
0,1st,0.450382,0.44856
1,2nd,0.374046,0.37037
2,3rd,0.175573,0.18107



gender


Unnamed: 0,gender,test,train
0,1,0.503817,0.460905
1,2,0.450382,0.497942
2,3,0.022901,0.020576
3,4,0.022901,0.020576



prior_4cat_knowledge


Unnamed: 0,prior_4cat_knowledge,test,train
0,0,0.122137,0.135802
1,0,0.328244,0.320988
2,1,0.198473,0.197531
3,2,0.076336,0.057613
4,3,0.274809,0.288066



total_score


Unnamed: 0,total_score,test,train
0,0,0.099237,0.115226
1,1,0.160305,0.144033
2,2,0.145038,0.152263
3,3,0.175573,0.18107
4,4,0.206107,0.226337
5,5,0.160305,0.127572
6,6,0.053435,0.053498





In [6]:
test

Unnamed: 0,username,start_time,exploration_time,ranking_task_time,ranking,ranking_confidence,ranking_time,q1,q1_conf,q1_time,...,total_score,stratification_column_v0,stratification_column_v1,stratification_column_v2,stratification_column_v3,stratification_column_v4,stratification_column_v5,stratification_column_v6,stratification_column_v7,stratification_column
220,wu7kdm6q,{'time': 1620814035},{'time': 1620814117},{'time': 1620814181},0312,wrong field,1620814270,0.37,50,1620814323,...,3,Deutsch_Chemistry_1st_2_2_3,Chemistry_1st_2_2_3,1st_2_2_3,1st_2_2_0,1st_2_2_0,1st_2_2_0,1st_2_2_0,1st_2_2_0,1st_2_2_0
43,krd7m9vb,{'time': 1624537954},{'time': 1624538284},{'time': 1624538338},3120,75,1624538441,148,100,1624538502,...,2,Deutsch_Chemistry_1st_1_0_2,Deutsch_Chemistry_1st_1_0_2,Deutsch_Chemistry_1st_1_0_2,Deutsch_Chemistry_1st_1_0_2,Deutsch_Chemistry_1st_1_0_2,Deutsch_Chemistry_1st_1_0_2,Deutsch_Chemistry_1st_1_0_2,Deutsch_Chemistry_1st_1_0_2,Deutsch_Chemistry_1st_1_0_2
285,x844md8u,{'time': 1624351606},{'time': 1624351868},{'time': 1624352134},3201,0,1624352252,0.37,45,1624352340,...,3,Deutsch_Pharma Chemistry_1st_1_0_3,Deutsch_Pharma Chemistry_1st_1_0_3,Deutsch_Pharma Chemistry_1st_1_0_3,Deutsch_Pharma Chemistry_1st_1_0_3,Deutsch_Pharma Chemistry_1st_1_0_3,Deutsch_Pharma Chemistry_1st_1_0_3,Deutsch_Pharma Chemistry_1st_1_0_3,Deutsch_Pharma Chemistry_1st_1_0_3,Deutsch_Pharma Chemistry_1st_1_0_3
179,tzbaaz7e,{'time': 1624455151},{'time': 1624455199},{'time': 1624455285},0312,50,1624455774,0.37,100,1624455990,...,3,Deutsch_Chemistry_2nd_1_1_3,Deutsch_Chemistry_2nd_1_1_3,Deutsch_Chemistry_2nd_1_1_3,Deutsch_Chemistry_2nd_1_1_3,Deutsch_Chemistry_2nd_1_1_3,Deutsch_Chemistry_2nd_1_1_3,Deutsch_Chemistry_2nd_1_1_3,Deutsch_Chemistry_2nd_1_1_3,Deutsch_Chemistry_2nd_1_1_3
164,9aagpn4d,{'time': 1631256182},{'time': 1631256337},{'time': 1631256986},2301,55,1631257289,"A = 100*1*(0,74/200) = 0,37",85,1631257555,...,4,Français_Chemistry_2nd_2_3_4,Français_Chemistry_2nd_2_3_4,Français_Chemistry_2nd_2_3_4,Français_Chemistry_2nd_2_3_4,Français_Chemistry_2nd_2_3_4,Français_Chemistry_2nd_2_3_4,Français_Chemistry_2nd_2_3_4,Français_Chemistry_2nd_2_3_4,Français_Chemistry_2nd_2_3_4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8,r2e3h3tm,{'time': 1622441148},{'time': 1622441495},{'time': 1622442244},2031,75,1622442286,3.5,25,1622442359,...,0,Deutsch_Chemistry_2nd_2_0_0,Deutsch_Chemistry_2nd_2_0_0,Deutsch_Chemistry_2nd_2_0_0,Deutsch_Chemistry_2nd_2_0_0,Deutsch_Chemistry_2nd_2_0_0,Deutsch_Chemistry_2nd_2_0_0,Deutsch_Chemistry_2nd_2_0_0,Deutsch_Chemistry_2nd_2_0_0,Deutsch_Chemistry_2nd_2_0_0
246,5es5yqs8,{'time': 1624345591},{'time': 1624346153},{'time': 1624346487},0213,75,1624346526,1.48,70,1624346603,...,4,Deutsch_Pharma Chemistry_1st_1_0_4,Deutsch_Pharma Chemistry_1st_1_0_4,Deutsch_Pharma Chemistry_1st_1_0_4,Deutsch_Pharma Chemistry_1st_1_0_4,Deutsch_Pharma Chemistry_1st_1_0_4,Deutsch_Pharma Chemistry_1st_1_0_4,Deutsch_Pharma Chemistry_1st_1_0_4,Deutsch_Pharma Chemistry_1st_1_0_4,Deutsch_Pharma Chemistry_1st_1_0_4
79,kq2e6dgu,{'time': 1623311170},{'time': 1623311324},{'time': 1623311339},3012,0,1623311443,keine ahnung,0,1623311530,...,2,Deutsch_Chemistry_1st_1_0_2,Deutsch_Chemistry_1st_1_0_2,Deutsch_Chemistry_1st_1_0_2,Deutsch_Chemistry_1st_1_0_2,Deutsch_Chemistry_1st_1_0_2,Deutsch_Chemistry_1st_1_0_2,Deutsch_Chemistry_1st_1_0_2,Deutsch_Chemistry_1st_1_0_2,Deutsch_Chemistry_1st_1_0_2
133,2ejxq2u8,{'time': 1624871156},{'time': 1624871315},{'time': 1624871415},2301,55,1624871498,Eine Höhere da das Gefäss kleiner ist.,80,1624871543,...,1,Deutsch_Chemistry_1st_1_0_1,Deutsch_Chemistry_1st_1_0_1,Deutsch_Chemistry_1st_1_0_1,Deutsch_Chemistry_1st_1_0_1,Deutsch_Chemistry_1st_1_0_1,Deutsch_Chemistry_1st_1_0_1,Deutsch_Chemistry_1st_1_0_1,Deutsch_Chemistry_1st_1_0_1,Deutsch_Chemistry_1st_1_0_1


In [7]:
with open('../data/experiment_keys/flatstrat_testusernames.pkl', 'wb') as fp:
    pickle.dump(list(test['username']), fp)
with open('../data/experiment_keys/flatstrat_trainusernames.pkl', 'wb') as fp:
    pickle.dump(list(train['username']), fp)

['wu7kdm6q',
 'krd7m9vb',
 'x844md8u',
 'tzbaaz7e',
 '9aagpn4d',
 'jnjd79vh',
 '2ae6q3hw',
 'chca7sdb',
 '6cs3annc',
 'neuwx57b',
 '2cetp8yc',
 'eau7bsmq',
 '4vuc8rr3',
 'xe7c36dk',
 'g43q3d94',
 'uqven68r',
 'v5w2e3zw',
 'zkrr45y5',
 'r74r26kt',
 '8ga2zn5h',
 'zy256ycq',
 '84nmc3df',
 '9qvk2wew',
 'cjarhqn9',
 'phupma28',
 '9sgu2tbg',
 '47ce49e4',
 'ryfqnvfh',
 'ydws5xx9',
 'gyerx2d9',
 'x5sm9pfu',
 'z8hvrhwb',
 'y9tk3ysm',
 'uqzxsym7',
 'araav4jr',
 'oikzz9af',
 '3s6pz8qy',
 'rwax4gk7',
 '85pdk9mq',
 'jx3yyy26',
 'zgyc948n',
 'wyj76ntd',
 '6ruh7enb',
 'fj5tdybn',
 '43e33t3h',
 'nw65tu6j',
 'upkt7qb4',
 'wvxkvhne',
 '7ygreyfc',
 '4rhnvke9',
 's78drqcg',
 'pbxyuw7u',
 'aurjfgnn',
 'smqjhu44',
 'x3ykresy',
 'mhek2323',
 '33asfz2u',
 'wnurkn96',
 'j6nndaxp',
 'wxz98urt',
 '88kjzd8b',
 '63xqh9t5',
 'chm4sr6j',
 '2hr6mkdc',
 'wpszzhxa',
 'mcjaj2aj',
 'dubyutqd',
 'v9sra3j2',
 'sgdgynxy',
 'uagxrke6',
 'unkrat9w',
 'uenn9vgu',
 'upp6pqmx',
 'knun7j9s',
 '3gqs3sgc',
 'szvqb37f',
 'sz8qvgyv',