## TrainTestSplit

Create a train-test-split for the datasets found in the pipeline

In [34]:
# imports
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GroupShuffleSplit
import numpy as np
import os
import pandas as pd

In [9]:
# Create the output directory for the split files if it does not exist yet.
# exist_ok=True keeps this cell re-runnable: with exist_ok=False the second
# execution raises FileExistsError because the directory is already there.
os.makedirs("../../etl/data/intermediate/TrainTestSplit", exist_ok=True)

In [40]:
# parameters
# input: the full extract that gets split
FULL_DATA_PATH = "../../etl/data/raw/01_extract.csv"

# outputs: one CSV per split
TRAIN_DATA_PATH = "../../etl/data/intermediate/TrainTestSplit/01_train.csv"
TEST_DATA_PATH = "../../etl/data/intermediate/TrainTestSplit/01_test.csv"
VAL_DATA_PATH = "../../etl/data/intermediate/TrainTestSplit/01_val.csv"

# seed passed to the splitters below
RANDOM_STATE = 42

In [60]:
# read the full extract CSV into a DataFrame
df = pd.read_csv(FULL_DATA_PATH)

In [61]:
# display the loaded frame (no grouping is performed in this cell)
df

Unnamed: 0,doc_id,verb_form,verb_form_start,verb_form_end,verb_lemma,arg1,arg1_start,arg1_end,arg1_pos,arg1_head,...,arg2,arg2_start,arg2_end,arg2_pos,arg2_head,arg2_head_start,arg2_head_end,rel_type,pred_serial,full_sentence_text
0,4bc8c13ddaa028e64a34ce08397157b846fb4de3ad26e3...,abgestraft,26,36,abstrafen,Alexis Tsipras,5,19,N,Alexis,...,Dass Alexis Tsipras jetzt abgestraft wurde,0,43,$.,.,143,144,neutral,"Predicate(type='neutral', args=(Head(sentence=...","Dass Alexis Tsipras jetzt abgestraft wurde , h..."
1,4bc8c13ddaa028e64a34ce08397157b846fb4de3ad26e3...,enttäuschen,140,151,enttäuschen,die neue Regierung,75,93,N,Regierung,...,die Hoffnungen auf einen spürbaren Aufschwung,94,139,N,Hoffnungen,98,108,neutral,"Predicate(type='neutral', args=(Head(sentence=...",Wenn die Kreditgeber Athen nicht zusätzlichen ...
2,4bc8c13ddaa028e64a34ce08397157b846fb4de3ad26e3...,beenden,119,126,beenden,Will er dem Land etwas Gutes tun,0,33,$.,.,...,die politische Polarisierung beenden,90,127,N,Polarisierung,105,118,neutral,"Predicate(type='neutral', args=(Head(sentence=...","Will er dem Land etwas Gutes tun , dann sollte..."
3,290f3971010f6d9385e896208f328948f5fb3f9bc0caeb...,akzeptieren,69,80,akzeptieren,Pajtim Kasami,0,14,N,Pajtim,...,die Kurzarbeit nun doch,81,105,N,Kurzarbeit,85,95,pro,"Predicate(type='pro', args=(Head(sentence=12, ...","Pajtim Kasami , Ermir Lenjani , Birama Ndoye u..."
4,290f3971010f6d9385e896208f328948f5fb3f9bc0caeb...,entliess,30,38,entlassen,der FC Sion,39,50,N,FC,...,Fussball Neun Spieler,8,29,N,Fussball,8,16,neutral,"Predicate(type='neutral', args=(Head(sentence=...",( dpa ) Fussball Neun Spieler entliess der FC ...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
832168,9701f0c776430a365f0619836f34658459be788b9a6646...,äussert,5,12,äussern,Frau A,13,19,N,Frau,...,massive Vorwürfe gegen ihre Vorgesetzten,20,60,N,Vorwürfe,28,36,neutral,"Predicate(type='neutral', args=(Head(sentence=...",Dann äussert Frau A massive Vorwürfe gegen ihr...
832169,9701f0c776430a365f0619836f34658459be788b9a6646...,stelle,184,190,stellen,das Kommando,171,183,N,Kommando,...,einen Antrag auf ihren Ausschluss,191,224,N,Antrag,197,203,neutral,"Predicate(type='neutral', args=(Head(sentence=...",Am 5. Dezember stellt der Chef der Abteilung M...
832170,d0fc434ce0021b0dff7b52157d352daaff8d1a65f4640a...,ausgewiesen,77,88,ausweisen,Der Konzern,0,11,N,Konzern,...,eine Liquidität,32,47,N,Liquidität,37,47,con,"Predicate(type='con', args=(Head(sentence=19, ...",Der Konzern hatte im April noch eine Liquiditä...
832171,d0fc434ce0021b0dff7b52157d352daaff8d1a65f4640a...,belebt,8,14,beleben,das Verkehrsaufkommen,20,41,N,Verkehrsaufkommen,...,das Verkehrsaufkommen,20,41,N,Verkehrsaufkommen,24,41,con,"Predicate(type='con', args=(Head(sentence=25, ...",Seitdem belebt sich das Verkehrsaufkommen auf ...


**Balance the dataset**

Make sure that each class is represented equally.

In [62]:
# Downsample every rel_type class to the size of the rarest class.
# - GroupBy.sample takes a seed, so the balanced subset is the same on every
#   run (the previous per-group x.sample(...) call was unseeded).
# - df is never rebound to a GroupBy object, and the result gets a flat
#   RangeIndex instead of a (rel_type, position) MultiIndex, so re-running
#   this cell on the already balanced frame still works.
smallest_class_size = int(df.groupby("rel_type").size().min())
df = (
    df.groupby("rel_type")
    .sample(n=smallest_class_size, random_state=RANDOM_STATE)
    .reset_index(drop=True)
)

df.rel_type.value_counts()

con        89673
neutral    89673
pro        89673
Name: rel_type, dtype: int64

**How big is the problem with multi-PAS per sentence?**

Multiple PAS from the same sentence should not be split across the dataset splits.

In [63]:
# number of rows (PAS) per (document, sentence) pair
sentence_key = ["doc_id", "full_sentence_text"]
df_occ = (
    df.groupby(sentence_key)
    .size()
    .reset_index(name='counts')
)

# how many sentences carry 1, 2, 3, ... rows
df_occ_freq = df_occ.groupby(["counts"]).size()
df_occ_freq

counts
1    258388
2      5026
3       185
4         6
dtype: int64

**Option 1: Train-test-splitting**

Only problem: We may have sentences within the same documents with multiple PAS that are split across the sets.

In [64]:
# Stratified split on rel_type: 30 % of the rows are held out first and then
# halved into test and val. Grouping by document / sentence is ignored here.
train, test_val = train_test_split(
    df,
    test_size=0.3,
    stratify=df["rel_type"],
    random_state=RANDOM_STATE,
)
test, val = train_test_split(
    test_val,
    test_size=0.5,
    stratify=test_val["rel_type"],
    random_state=RANDOM_STATE,
)

# print the class distribution of each split, then write it to its CSV
for header, split_df, out_path in (
    ("VALUE COUNTS: train", train, TRAIN_DATA_PATH),
    ("VALUE COUNTS: test", test, TEST_DATA_PATH),
    ("VALUE COUNTS: val", val, VAL_DATA_PATH),
):
    print(header)
    print(split_df.rel_type.value_counts())
    split_df.to_csv(out_path, index=False)


VALUE COUNTS: train
neutral    62771
pro        62771
con        62771
Name: rel_type, dtype: int64
VALUE COUNTS: test
neutral    13451
pro        13451
con        13451
Name: rel_type, dtype: int64
VALUE COUNTS: val
pro        13451
con        13451
neutral    13451
Name: rel_type, dtype: int64


**Option 2: Train-test-splitting**

Respecting the groups: all rows that share a `doc_id` stay in the same split.

In [65]:
# preserve groups between sentences (in this case doc_id is safe enough):
# rows sharing a doc_id always land in the same split
def _split_by_doc(frame, holdout_fraction):
    """Return the two positional index arrays of one group-aware shuffle split on doc_id."""
    shuffler = GroupShuffleSplit(test_size=holdout_fraction, n_splits=1, random_state=RANDOM_STATE)
    return next(shuffler.split(frame, groups=frame['doc_id']))

# hold out 30 % of the documents for test + val
train_inds, test_val_inds = _split_by_doc(df, .30)
train, test_val = df.iloc[train_inds], df.iloc[test_val_inds]

# halve the held-out documents into test and val
test_inds, val_inds = _split_by_doc(test_val, .5)
test, val = test_val.iloc[test_inds], test_val.iloc[val_inds]

In [66]:
# print the class distribution of each split, then write it to its CSV
for header, split_df, out_path in (
    ("VALUE COUNTS: train", train, TRAIN_DATA_PATH),
    ("VALUE COUNTS: test", test, TEST_DATA_PATH),
    ("VALUE COUNTS: val", val, VAL_DATA_PATH),
):
    print(header)
    print(split_df.rel_type.value_counts())
    split_df.to_csv(out_path, index=False)

VALUE COUNTS: train
con        62798
pro        62784
neutral    62725
Name: rel_type, dtype: int64
VALUE COUNTS: test
neutral    13493
pro        13415
con        13349
Name: rel_type, dtype: int64
VALUE COUNTS: val
con        13526
pro        13474
neutral    13455
Name: rel_type, dtype: int64


In [67]:
# a small test to see whether option no 2 achieves our goals:
# no group id may appear on both sides of a GroupShuffleSplit.

X = np.ones(shape=(10, 2))
y = np.ones(shape=(10, 1))
groups = np.array([1, 1, 2, 2, 2, 3, 3, 3, 8, 8])
print(groups.shape)

# use the notebook-wide seed instead of repeating the literal
gss = GroupShuffleSplit(n_splits=1, train_size=.8, random_state=RANDOM_STATE)

for train_idx, test_idx in gss.split(X, y, groups):
    print("TRAIN:", [groups[i] for i in train_idx], "TEST:", [groups[i] for i in test_idx])
    # fail loudly rather than relying on eyeballing the printed lists
    leaked_groups = set(groups[train_idx]) & set(groups[test_idx])
    assert not leaked_groups, f"groups present in both train and test: {leaked_groups}"

(10,)
TRAIN: [1, 1, 3, 3, 3, 8, 8] TEST: [2, 2, 2]
