In [22]:
import random
import numpy as np
import torch
import pandas as pd
from pathlib import Path
from torch.utils.data import random_split

In [23]:
# Seed every random number generator available to this notebook with the same value.
seed = 42
for _seed_rng in (
    random.seed,                 # Python standard library generator
    np.random.seed,              # NumPy global generator
    torch.manual_seed,           # PyTorch default generator
    torch.cuda.manual_seed_all,  # PyTorch generators on all CUDA devices
):
    _seed_rng(seed)

In [24]:
# Name of the embedding model; it is used as a sub-directory name in the input and output paths.
model_name = 'tw-roberta'

In [25]:
# Input: the full table stored under final/<model_name>/.
path_data_dir = Path('../data')
path_final = path_data_dir / 'final' / model_name / 'full.csv'

# Output directory for the per-target-set files, one sub-directory per model.
# (An earlier assignment to the bare 'split' directory was immediately
# overwritten by this one and has been removed.)
path_split_dir = path_data_dir / 'split' / model_name

# One output file per target set.
path_mbti = path_split_dir / 'mbti.csv'
path_bigfive_c = path_split_dir / 'bigfive_c.csv'
path_bigfive_s = path_split_dir / 'bigfive_s.csv'

In [26]:
# Load the full table for the selected model and display it.
# NOTE(review): the recorded output contains an 'Unnamed: 0' column, which looks like a saved
# row index read back as data — consider index_col=0 if it is not needed; confirm against the
# code that wrote the file.
data = pd.read_csv(path_final)
data

Unnamed: 0,AUTHOR,mbtiEXT,mbtiSEN,mbtiTHI,mbtiJUD,cEXT,cNEU,cAGR,cCON,cOPN,...,758,759,760,761,762,763,764,765,766,767
0,-9221022384933360074,0.0,0.0,1.0,0.0,,,,,,...,-0.060962,-0.022286,-0.146039,-0.055642,-0.021573,0.056082,0.105208,-0.247814,-0.008487,0.002584
1,-9220321758358532571,,,,,1.0,1.0,1.0,1.0,1.0,...,-0.084201,0.005282,-0.147341,-0.029248,0.010502,0.081995,-0.056109,-0.160702,-0.012305,0.027042
2,-9220031623198266213,0.0,1.0,1.0,1.0,,,,,,...,-0.052625,-0.021188,-0.144721,-0.074823,-0.009746,0.036002,0.028212,-0.135139,-0.008191,-0.007997
3,-9219633155989415906,0.0,0.0,1.0,0.0,,,,,,...,-0.064515,-0.010986,-0.162659,-0.070636,-0.012977,0.066532,0.039103,-0.140113,-0.008603,0.000269
4,-9219237589017844173,0.0,0.0,0.0,0.0,,,,,,...,-0.074869,-0.016893,-0.147531,-0.062592,0.028137,0.053380,0.011517,-0.106821,-0.012819,0.006908
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21679,9220307502816513261,1.0,0.0,0.0,0.0,,,,,,...,-0.047444,-0.016894,-0.154476,-0.064496,0.002011,0.055266,0.054305,-0.141798,-0.012529,0.013597
21680,9220556403022889385,0.0,0.0,1.0,1.0,,,,,,...,-0.038752,-0.029328,-0.130875,-0.067572,-0.017599,0.048109,0.071662,-0.167837,0.005230,0.012262
21681,9221651641191792423,0.0,0.0,0.0,0.0,,,,,,...,-0.068792,-0.019671,-0.167552,-0.061184,0.008157,0.051426,0.075112,-0.188497,-0.017943,0.017325
21682,9222607780732095571,0.0,0.0,0.0,1.0,,,,,,...,-0.064733,-0.016261,-0.132525,-0.057998,-0.002772,0.075129,0.055707,-0.120860,-0.014478,0.005084


In [27]:
# Feature columns. The 768 embedding columns are named by their position as a string,
# '0' through '767'.
emb_cols = [str(position) for position in range(768)]
stat_cols = ['NUM_CHARS', 'NUM_UPPERCASED', 'NUM_EMOJI', 'NUM_POSTS']

# Target columns, one list per target set.
mbti_columns = ['mbtiEXT', 'mbtiSEN', 'mbtiTHI', 'mbtiJUD']
bigfive_c_columns = ['cEXT', 'cNEU', 'cAGR', 'cCON', 'cOPN']
bigfive_s_columns = ['sEXT', 'sNEU', 'sAGR', 'sCON', 'sOPN']
bigfive_columns = [*bigfive_c_columns, *bigfive_s_columns]

# All targets: MBTI first, then both Big Five sets.
target_cols = [*mbti_columns, *bigfive_columns]

In [28]:
# Pair every flat column name with its top-level group to obtain two-level column labels.
author_list = [('AUTHOR', 'AUTHOR')]
stat_list = [('STATS', column_name) for column_name in stat_cols]
emb_list = [('CLS', column_name) for column_name in emb_cols]
target_list = [('TARGET', column_name) for column_name in target_cols]

# Column order: author, statistics, embeddings, targets.
tuples = [*author_list, *stat_list, *emb_list, *target_list]
multiindex = pd.MultiIndex.from_tuples(tuples, names=['GROUP', 'FEATURE'])

In [29]:
# NOTE(review): recomputes `tuples` from the same four lists as the cell above, so the value
# is unchanged — candidate for removal.
tuples = author_list + stat_list + emb_list + target_list

In [30]:
# NOTE(review): rebuilds `multiindex` from the same `tuples` and level names as two cells above —
# candidate for removal.
multiindex = pd.MultiIndex.from_tuples(tuples, names=['GROUP', 'FEATURE'])

In [31]:
# Start from an empty frame that carries only the two-level (GROUP, FEATURE) column layout,
# and display it to check the header.
data_multiindexed = pd.DataFrame(columns=multiindex)
data_multiindexed

GROUP,AUTHOR,STATS,STATS,STATS,STATS,CLS,CLS,CLS,CLS,CLS,...,TARGET,TARGET,TARGET,TARGET,TARGET,TARGET,TARGET,TARGET,TARGET,TARGET
FEATURE,AUTHOR,NUM_CHARS,NUM_UPPERCASED,NUM_EMOJI,NUM_POSTS,0,1,2,3,4,...,cEXT,cNEU,cAGR,cCON,cOPN,sEXT,sNEU,sAGR,sCON,sOPN


In [32]:
# Copy the flat columns of `data` into the grouped layout, one top-level group per assignment.
# Only AUTHOR, stat_cols, emb_cols and target_cols are copied; any other column of `data`
# is not carried over.
data_multiindexed['AUTHOR'] = data['AUTHOR']
data_multiindexed['STATS'] = data[stat_cols]
data_multiindexed['CLS'] = data[emb_cols]
data_multiindexed['TARGET'] = data[target_cols]

In [33]:
# Use the author column as the row index and remove it from the columns (drop=True).
# The result is assigned back to the same name, so this cell cannot be re-run on its own:
# after the first run the ('AUTHOR', 'AUTHOR') column no longer exists.
data_multiindexed = data_multiindexed.set_index(('AUTHOR', 'AUTHOR'), drop=True)

In [34]:
# Give the row index the plain name 'AUTHOR'.
data_multiindexed.index.names = ['AUTHOR']

In [35]:
# Display the frame now that it is indexed by author.
data_multiindexed

GROUP,STATS,STATS,STATS,STATS,CLS,CLS,CLS,CLS,CLS,CLS,...,TARGET,TARGET,TARGET,TARGET,TARGET,TARGET,TARGET,TARGET,TARGET,TARGET
FEATURE,NUM_CHARS,NUM_UPPERCASED,NUM_EMOJI,NUM_POSTS,0,1,2,3,4,5,...,cEXT,cNEU,cAGR,cCON,cOPN,sEXT,sNEU,sAGR,sCON,sOPN
AUTHOR,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
-9221022384933360074,82.842105,7.924812,0.067669,133,0.011379,0.121943,-0.035148,-0.159384,0.204076,-0.159495,...,,,,,,,,,,
-9220321758358532571,4075.000000,139.000000,0.000000,1,-0.037997,0.160573,-0.025189,-0.198521,0.314517,0.054473,...,1.0,1.0,1.0,1.0,1.0,,,,,
-9220031623198266213,59.416667,2.958333,0.341667,120,0.007463,0.160259,-0.032039,-0.139000,0.192175,-0.062065,...,,,,,,,,,,
-9219633155989415906,178.041667,5.354167,0.000000,48,0.004201,0.146285,-0.031886,-0.149592,0.195900,-0.065319,...,,,,,,,,,,
-9219237589017844173,77.430380,3.303797,0.221519,158,0.015168,0.169178,-0.036061,-0.185521,0.251327,-0.082999,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9220307502816513261,125.840000,4.840000,0.000000,50,0.016355,0.150665,-0.038156,-0.169107,0.177899,-0.105115,...,,,,,,,,,,
9220556403022889385,134.159091,6.681818,0.000000,44,0.020765,0.136793,-0.028575,-0.145969,0.188023,-0.111762,...,,,,,,,,,,
9221651641191792423,173.071429,4.976190,0.000000,42,0.011746,0.165671,-0.030071,-0.145691,0.181874,-0.086711,...,,,,,,,,,,
9222607780732095571,186.740000,5.980000,0.000000,50,0.030282,0.130152,-0.034604,-0.149935,0.169514,-0.074337,...,,,,,,,,,,


In [36]:
# Sort the two-level column index. The sort is lexicographic, so the string-named CLS
# columns come out as '0', '1', '10', '100', ... rather than in numeric order.
data_multiindexed_sorted = data_multiindexed.sort_index(axis=1)

# Shared feature block: every column whose top-level group is not 'TARGET'.
all_columns_except_target = data_multiindexed_sorted.loc[:, data_multiindexed_sorted.columns.get_level_values(0) != 'TARGET']

# Pick each target set by the exact group label 'TARGET'. The previous `slice('TARGET')` is
# the open-ended slice "every group up to and including 'TARGET'", which also spans the
# 'CLS' and 'STATS' groups and only worked because no target name occurs under those groups.
mbti_selection = data_multiindexed_sorted.loc[:, ('TARGET', mbti_columns)]
bigfive_c_selection = data_multiindexed_sorted.loc[:, ('TARGET', bigfive_c_columns)]
bigfive_s_selection = data_multiindexed_sorted.loc[:, ('TARGET', bigfive_s_columns)]

# One frame per target set: the shared features followed by that set's targets.
mbti_df = pd.concat([all_columns_except_target, mbti_selection], axis=1)
bf_c_df = pd.concat([all_columns_except_target, bigfive_c_selection], axis=1)
bf_s_df = pd.concat([all_columns_except_target, bigfive_s_selection], axis=1)

In [37]:
# Display the MBTI frame; at this point rows without MBTI targets are still present (NaN targets).
mbti_df

GROUP,CLS,CLS,CLS,CLS,CLS,CLS,CLS,CLS,CLS,CLS,CLS,CLS,CLS,STATS,STATS,STATS,STATS,TARGET,TARGET,TARGET,TARGET
FEATURE,0,1,10,100,101,102,103,104,105,106,...,98,99,NUM_CHARS,NUM_EMOJI,NUM_POSTS,NUM_UPPERCASED,mbtiEXT,mbtiJUD,mbtiSEN,mbtiTHI
AUTHOR,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
-9221022384933360074,0.011379,0.121943,0.021768,0.053510,-0.008066,-0.076857,0.010531,-0.013488,-0.006139,0.046846,...,0.091728,-0.032075,82.842105,0.067669,133,7.924812,0.0,0.0,0.0,1.0
-9220321758358532571,-0.037997,0.160573,0.040159,0.073303,0.019503,-0.152491,0.078980,-0.034442,-0.002517,0.143091,...,0.114990,0.024144,4075.000000,0.000000,1,139.000000,,,,
-9220031623198266213,0.007463,0.160259,0.010096,-0.028306,-0.001064,-0.111850,0.091180,-0.041328,-0.009941,0.046663,...,0.081681,-0.024641,59.416667,0.341667,120,2.958333,0.0,1.0,1.0,1.0
-9219633155989415906,0.004201,0.146285,-0.002562,0.016444,0.017426,-0.079033,0.035775,-0.004088,-0.014403,0.057938,...,0.086518,-0.040992,178.041667,0.000000,48,5.354167,0.0,0.0,0.0,1.0
-9219237589017844173,0.015168,0.169178,-0.024717,-0.012770,0.027100,-0.114756,0.064650,-0.014268,-0.015988,0.054172,...,0.076033,-0.045083,77.430380,0.221519,158,3.303797,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9220307502816513261,0.016355,0.150665,0.000408,0.023503,-0.000010,-0.115132,0.059027,-0.026924,-0.003065,0.056899,...,0.075520,-0.036230,125.840000,0.000000,50,4.840000,1.0,0.0,0.0,0.0
9220556403022889385,0.020765,0.136793,0.013010,0.007891,0.010322,-0.090868,0.045060,-0.014999,-0.032346,0.043911,...,0.070653,-0.028178,134.159091,0.000000,44,6.681818,0.0,1.0,0.0,1.0
9221651641191792423,0.011746,0.165671,0.011132,0.027535,0.020733,-0.053291,0.037028,-0.015085,-0.013423,0.065237,...,0.090778,-0.045856,173.071429,0.000000,42,4.976190,0.0,0.0,0.0,0.0
9222607780732095571,0.030282,0.130152,0.002258,-0.011900,0.011808,-0.066177,0.034080,-0.017137,0.002918,0.061081,...,0.074286,-0.019485,186.740000,0.000000,50,5.980000,0.0,1.0,0.0,0.0


In [38]:
# Keep only the rows that have no missing value in any feature or target column of the frame.
mbti_df = mbti_df.dropna(axis=0)
bf_c_df = bf_c_df.dropna(axis=0)
bf_s_df = bf_s_df.dropna(axis=0)

# to_csv does not create missing directories, so make sure the model-specific
# output directory exists before writing (no-op when it is already there).
path_split_dir.mkdir(parents=True, exist_ok=True)

# Write one file per target set, including the AUTHOR index and the column header.
mbti_df.to_csv(path_mbti, index=True, header=True)
bf_c_df.to_csv(path_bigfive_c, index=True, header=True)
bf_s_df.to_csv(path_bigfive_s, index=True, header=True)

In [39]:
# Display the MBTI frame after rows with missing values were dropped.
mbti_df

GROUP,CLS,CLS,CLS,CLS,CLS,CLS,CLS,CLS,CLS,CLS,CLS,CLS,CLS,STATS,STATS,STATS,STATS,TARGET,TARGET,TARGET,TARGET
FEATURE,0,1,10,100,101,102,103,104,105,106,...,98,99,NUM_CHARS,NUM_EMOJI,NUM_POSTS,NUM_UPPERCASED,mbtiEXT,mbtiJUD,mbtiSEN,mbtiTHI
AUTHOR,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
-9221022384933360074,0.011379,0.121943,0.021768,0.053510,-0.008066,-0.076857,0.010531,-0.013488,-0.006139,0.046846,...,0.091728,-0.032075,82.842105,0.067669,133,7.924812,0.0,0.0,0.0,1.0
-9220031623198266213,0.007463,0.160259,0.010096,-0.028306,-0.001064,-0.111850,0.091180,-0.041328,-0.009941,0.046663,...,0.081681,-0.024641,59.416667,0.341667,120,2.958333,0.0,1.0,1.0,1.0
-9219633155989415906,0.004201,0.146285,-0.002562,0.016444,0.017426,-0.079033,0.035775,-0.004088,-0.014403,0.057938,...,0.086518,-0.040992,178.041667,0.000000,48,5.354167,0.0,0.0,0.0,1.0
-9219237589017844173,0.015168,0.169178,-0.024717,-0.012770,0.027100,-0.114756,0.064650,-0.014268,-0.015988,0.054172,...,0.076033,-0.045083,77.430380,0.221519,158,3.303797,0.0,0.0,0.0,0.0
-9214568075844254832,0.012004,0.170368,-0.006627,-0.007512,0.014462,-0.092926,0.072816,-0.030916,-0.024403,0.053216,...,0.083053,-0.040561,100.617284,0.000000,81,6.716049,0.0,0.0,0.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9220307502816513261,0.016355,0.150665,0.000408,0.023503,-0.000010,-0.115132,0.059027,-0.026924,-0.003065,0.056899,...,0.075520,-0.036230,125.840000,0.000000,50,4.840000,1.0,0.0,0.0,0.0
9220556403022889385,0.020765,0.136793,0.013010,0.007891,0.010322,-0.090868,0.045060,-0.014999,-0.032346,0.043911,...,0.070653,-0.028178,134.159091,0.000000,44,6.681818,0.0,1.0,0.0,1.0
9221651641191792423,0.011746,0.165671,0.011132,0.027535,0.020733,-0.053291,0.037028,-0.015085,-0.013423,0.065237,...,0.090778,-0.045856,173.071429,0.000000,42,4.976190,0.0,0.0,0.0,0.0
9222607780732095571,0.030282,0.130152,0.002258,-0.011900,0.011808,-0.066177,0.034080,-0.017137,0.002918,0.061081,...,0.074286,-0.019485,186.740000,0.000000,50,5.980000,0.0,1.0,0.0,0.0


In [40]:
# Display the bigfive_c frame after rows with missing values were dropped.
bf_c_df

GROUP,CLS,CLS,CLS,CLS,CLS,CLS,CLS,CLS,CLS,CLS,CLS,CLS,STATS,STATS,STATS,STATS,TARGET,TARGET,TARGET,TARGET,TARGET
FEATURE,0,1,10,100,101,102,103,104,105,106,...,99,NUM_CHARS,NUM_EMOJI,NUM_POSTS,NUM_UPPERCASED,cAGR,cCON,cEXT,cNEU,cOPN
AUTHOR,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
-9220321758358532571,-0.037997,0.160573,0.040159,0.073303,0.019503,-0.152491,0.078980,-0.034442,-0.002517,0.143091,...,0.024144,4075.000000,0.0,1,139.000000,1.0,1.0,1.0,1.0,1.0
-9217395741639621641,0.031393,0.112005,-0.034214,0.124763,-0.013985,-0.102971,0.008509,-0.073899,0.076637,0.036987,...,-0.002817,1855.000000,0.0,1,25.000000,1.0,0.0,0.0,0.0,1.0
-9201712306031068767,0.014218,0.132446,-0.055054,0.009627,0.036285,-0.113526,-0.053585,-0.020553,0.061702,0.079515,...,0.018606,1988.000000,0.0,1,54.000000,0.0,0.0,0.0,1.0,1.0
-9198712149500792736,0.045070,0.123840,0.097832,0.000422,-0.019057,0.023173,0.051116,-0.019236,0.048788,-0.001122,...,0.059733,3099.000000,0.0,1,86.000000,1.0,0.0,1.0,0.0,1.0
-9195745603526562446,-0.000577,0.144314,0.015660,0.057855,-0.024550,0.029852,-0.038563,-0.026313,-0.008585,0.070145,...,-0.028561,3677.000000,0.0,1,66.000000,0.0,0.0,1.0,1.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9170427891322287974,0.049749,0.101124,-0.033485,0.051001,-0.022178,-0.101269,0.014635,-0.102853,0.035185,0.036142,...,-0.095056,3178.000000,0.0,1,99.000000,1.0,1.0,1.0,0.0,1.0
9170530906082270117,0.026764,0.147973,0.003919,0.041752,0.015701,-0.071425,-0.018923,-0.011490,-0.043861,0.024762,...,-0.023291,1739.000000,0.0,1,44.000000,0.0,1.0,1.0,1.0,1.0
9172401295438939558,-0.006680,0.087605,0.035899,0.091709,-0.028121,-0.032639,-0.007399,0.004232,0.024861,0.065059,...,-0.002603,2813.000000,0.0,1,78.000000,1.0,1.0,1.0,0.0,1.0
9180056457229629599,0.089631,0.112762,-0.016099,-0.027667,-0.011114,-0.129113,0.037635,0.041825,0.040983,0.065133,...,-0.028646,3450.000000,0.0,1,77.000000,1.0,1.0,0.0,0.0,0.0


In [41]:
# Sanity check: count the columns under the 'CLS' group of the bigfive_s frame
# (768 in the recorded output, matching the length of emb_cols).
len(bf_s_df['CLS'].columns)

768