In [2]:
import numpy as np
import pandas as pd
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier

In [3]:
df = pd.read_csv('./dataframe.csv')
df = df.drop(['Unnamed: 0'], axis=1)
df = df[df.language != 'Dutch']

In [4]:
def concat_dummies(df, dummies):
    for dummy in list(dummies):
        df[dummy] = dummies[dummy]

In [5]:
def merge_wrong_corrected_columns(df, wrong_column, corrected_column, prefix):
    wrong = pd.get_dummies(df[wrong_column], prefix=prefix)
    corrected = pd.get_dummies(df[corrected_column], prefix=prefix)
    for wc in list(set(wrong) - set(corrected)):
        corrected[wc] = 0
    for cc in list(set(corrected) - set(wrong)):
        wrong[cc] = 0
    wrong = wrong.where(wrong == 1, -3)
    wrong = wrong.where(wrong == -3, -1)
    corrected = corrected.where(corrected != 1, 4)
    merged = wrong.add(corrected)
    merged = merged.where(merged != 3, 0)
    return merged

In [6]:
merged_pos = merge_wrong_corrected_columns(df, 'error_pos', 'correct_pos', 'u')
merged_pos_2 = merge_wrong_corrected_columns(df, 'error_pos_2', 'correct_pos_2', 'b')
merged_pos_2 = merged_pos_2.where(merged_pos_2 != 1, 0.5)
merged_pos_2 = merged_pos_2.where(merged_pos_2 != -1, -1.5)
merged_pos_2 = merged_pos_2.where(merged_pos_2 != 0, -0.5)
merged_pos_3 = merge_wrong_corrected_columns(df, 'error_pos_3', 'correct_pos_3', 't')
merged_pos_3 = merged_pos_3.where(merged_pos_3 != 1, 0.25)
merged_pos_3 = merged_pos_3.where(merged_pos_3 != -1, -1.75)
merged_pos_3 = merged_pos_3.where(merged_pos_3 != 0, -0.75)
languages = pd.get_dummies(df['language'], prefix='lang')
concat_dummies(df, merged_pos)
concat_dummies(df, merged_pos_2)
concat_dummies(df, merged_pos_3)
concat_dummies(df, languages)
df.describe()

Unnamed: 0,score,error_position,u_*,u_CC,u_CD,u_DT,u_EX,u_FW,u_IN,u_JJ,...,lang_Italian,lang_Japanese,lang_Korean,lang_Polish,lang_Portuguese,lang_Russian,lang_Spanish,lang_Swedish,lang_Thai,lang_Turkish
count,45268.0,45268.0,45268.0,45268.0,45268.0,45268.0,45268.0,45268.0,45268.0,45268.0,...,45268.0,45268.0,45268.0,45268.0,45268.0,45268.0,45268.0,45268.0,45268.0,45268.0
mean,26.375762,32.776398,-2.960436,-2.935031,-2.985774,-2.470885,-2.97908,-2.999293,-2.43216,-2.728926,...,0.053901,0.055072,0.05114,0.062517,0.052885,0.073606,0.160577,0.009808,0.06097,0.064549
std,5.52249,33.920929,0.352294,0.469493,0.218181,1.257216,0.273868,0.049736,1.226062,0.910498,...,0.225825,0.228123,0.220285,0.242094,0.223806,0.261132,0.367144,0.098551,0.239278,0.245731
min,0.0,0.0,-3.0,-3.0,-3.0,-3.0,-3.0,-3.0,-3.0,-3.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,23.0,10.0,-3.0,-3.0,-3.0,-3.0,-3.0,-3.0,-3.0,-3.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,26.0,22.0,-3.0,-3.0,-3.0,-3.0,-3.0,-3.0,-3.0,-3.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
75%,30.0,43.0,-3.0,-3.0,-3.0,-3.0,-3.0,-3.0,-3.0,-3.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
max,40.0,341.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


In [7]:
y = df['error_type']
data = df[list(languages) + list(merged_pos) + list(merged_pos_2) 
        + list(merged_pos_3) + ['score']]

In [13]:
x_train, x_test, y_train, y_test = train_test_split(data, y, test_size=0.15)

In [14]:
dtc = DecisionTreeClassifier(random_state=0, max_depth=20)

In [15]:
dtc.fit(x_train, y_train)
dtc.score(x_test, y_test)

0.5592696215579444

In [16]:
clf = RandomForestClassifier(n_estimators=1000, max_depth=30, random_state=0)
clf.fit(x_train, y_train)

RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=30, max_features='auto', max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, n_estimators=1000, n_jobs=None,
            oob_score=False, random_state=0, verbose=0, warm_start=False)

In [17]:
clf.score(x_test, y_test)

0.6186128699749669

In [20]:
print(clf.feature_importances_)

[7.68060751e-03 8.42604501e-03 1.23801276e-02 7.69336140e-03
 8.96318691e-03 7.66132367e-03 7.61188267e-03 6.90292391e-03
 7.84632890e-03 7.33742369e-03 8.84039160e-03 1.44762474e-02
 2.44437586e-03 8.24308959e-03 8.34956953e-03 3.88101732e-03
 8.01814089e-03 1.60917670e-03 5.22435733e-02 1.85462500e-03
 1.03612333e-04 5.33627173e-02 1.62763897e-02 1.38855142e-03
 9.75068463e-04 8.00578293e-03 2.75502165e-02 4.70907211e-03
 1.46893523e-04 1.77136834e-02 8.22945049e-04 3.06828486e-05
 2.12909183e-02 7.44910688e-03 1.58584174e-02 7.43535027e-04
 3.30774699e-04 1.80181205e-03 1.67932959e-02 9.66169761e-06
 2.02740331e-02 2.02229199e-02 1.44428220e-02 7.11226851e-03
 1.54109540e-02 1.32097400e-02 3.11467316e-03 2.63715121e-03
 5.72762139e-05 3.06909199e-03 4.93626674e-02 5.13349686e-03
 6.09915377e-03 1.39770539e-03 1.55892249e-02 9.12828429e-04
 6.39473852e-05 1.49088673e-02 1.08607615e-02 9.95195163e-04
 5.06419235e-04 3.69044910e-03 2.09166302e-02 5.31328849e-03
 3.33907614e-05 9.386360

In [21]:
data

Unnamed: 0,lang_Catalan,lang_Chinese,lang_French,lang_German,lang_Greek,lang_Italian,lang_Japanese,lang_Korean,lang_Polish,lang_Portuguese,...,_ WRB EX,_ WRB JJ,_ WRB MD,_ WRB NN,_ WRB PRP,_ WRB PRP$,_ WRB RB,_ WRB TO,_ WRB VBP,score
0,0,0,0,0,0,0,0,0,0,0,...,-100,-100,-100,-100,-100,-100,-100,-100,-100,25.0
1,0,0,0,0,0,0,0,0,0,0,...,-100,-100,-100,-100,-100,-100,-100,-100,-100,25.0
2,0,0,0,0,0,0,0,0,0,0,...,-100,-100,-100,-100,-100,-100,-100,-100,-100,25.0
3,0,0,0,0,0,0,0,0,0,0,...,-100,-100,-100,-100,-100,-100,-100,-100,-100,25.0
4,0,0,0,0,0,0,0,0,0,0,...,-100,-100,-100,-100,-100,-100,-100,-100,-100,25.0
5,0,0,0,0,0,0,0,0,0,0,...,-100,-100,-100,-100,-100,-100,-100,-100,-100,25.0
6,0,0,0,0,0,0,0,0,0,0,...,-100,-100,-100,-100,-100,-100,-100,-100,-100,25.0
7,0,0,0,0,0,0,0,0,0,0,...,-100,-100,-100,-100,-100,-100,-100,-100,-100,25.0
8,0,0,0,0,0,0,0,0,0,0,...,-100,-100,-100,-100,-100,-100,-100,-100,-100,25.0
9,0,0,0,0,0,0,0,0,0,0,...,-100,-100,-100,-100,-100,-100,-100,-100,-100,25.0
