In [2]:
# import libraries
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression


### Text handling module

In [11]:
# Demo of the packaged text-handling helper on a tiny name table.
from process import texthandling as th

a = pd.DataFrame(
    {
        'id': [1, 2, 3],
        'name': ['Mr.Tanaka', 'Ms.Suzuki', 'Mrs.Black'],
    }
)
# Titles to look for inside each name; 'Mrs' is listed ahead of 'Mr',
# which is a substring of it.
b = ['Mrs', 'Ms', 'Mr']

x = th.FlagWords(a['name'], b, independent=True)
x['Mrs']

0    0
1    0
2    1
Name: Mrs, dtype: object

In [21]:
def FlagWords(text_df, words, independent=False):
    """Build a 0/1 indicator table telling which words occur in each text.

    Parameters
    ----------
    text_df : iterable of str
        Texts to scan (e.g. a pandas Series of strings). Each element is
        tested with the ``in`` operator, so matching is by substring.
    words : list of str
        Words to look for; they become the columns of the result, in order.
    independent : bool, default False
        When truthy, scanning of a text stops at the first word (in ``words``
        order) that is found, so at most one column is flagged per row.
        When falsy, every word found in the text is flagged.

    Returns
    -------
    pandas.DataFrame
        One row per text (default 0..n-1 index), one integer column per word,
        holding 1 where the word was flagged for that text and 0 otherwise.
    """
    rows = []
    for text in text_df:
        matched = []
        for word in words:
            if word in text:
                matched.append(word)
                if independent:
                    # Only the first hit counts in "independent" mode.
                    break
        rows.append([int(word in matched) for word in words])

    # Build the frame once from the collected rows: appending row by row is
    # quadratic, and DataFrame.append no longer exists in pandas >= 2.0.
    return pd.DataFrame(rows, columns=words).astype('int')

0    0
1    0
2    1
Name: Mrs, dtype: int64

# Cross Validation

In [98]:
# Generate sample data: three standard-normal explanatory columns plus a
# noise column, a linear score built from them, and a binary label obtained
# by thresholding the score at 0.2.
np.random.seed(123)
df = (
    pd.DataFrame(
        # Columns are drawn one after another in this order, so the random
        # stream is consumed as explain1, explain2, explain3, error.
        {
            column: np.random.randn(1000)
            for column in ('explain1', 'explain2', 'explain3', 'error')
        }
    )
    .assign(
        score=lambda d: d['explain1'] + d['explain2'] + d['explain3'] + d['error'] * 0.5
    )
    .assign(survival=lambda d: (d['score'] >= 0.2).astype(int))
)


In [99]:
# Fit a logistic regression of the binary label on the three explanatory
# columns, then report the fitted parameters and the in-sample score.
x_var = df[['explain1', 'explain2', 'explain3']]
y_var = df['survival']

lr = LogisticRegression()
lr.fit(x_var, y_var)

print(lr.intercept_, lr.coef_)
print(lr.score(x_var, y_var))


[-0.60972616] [[ 3.07527999  3.30663655  3.11866776]]
0.912


In [100]:
def fitmodel(method, x, y):
    """Call ``method.fit(x, y)`` and return whatever that call returns.

    ``method`` can be any object exposing a ``fit(x, y)`` method; ``x`` and
    ``y`` are passed through untouched.
    """
    return method.fit(x, y)
    
# Fit via the helper. Note: this rebinds `x`, which earlier in the notebook
# held the FlagWords demo result; from here on `x` is the value returned by
# `lr.fit(...)`.
x = fitmodel(method=lr, x=x_var, y=y_var)

In [101]:
# Score the fitted model on the same rows it was trained on.
pred = x.predict(x_var)
actual = y_var

# Row-level comparison: 1 where the prediction equals the label, else 0.
compare = pd.DataFrame({'pred': pred, 'actual': actual})
compare['correct'] = (compare['pred'] == compare['actual']).astype(int)
cross = pd.crosstab(compare['pred'], compare['actual'], margins=True)

# Fraction of rows predicted correctly.
accuracy = compare['correct'].sum() / compare['correct'].count()
print('accuracy: {0:.3f}'.format(accuracy))

accuracy: 0.912


### Calculate accuracy

In [227]:
def accuracy(pred, actual):
    """Compare predicted labels with actual labels and summarise agreement.

    Prints one line with the number of rows and the fraction predicted
    correctly, then returns a dict with three entries:

    - ``'crosstable'``: cross-tabulation of pred (rows) against actual
      (columns), including the ``All`` margins;
    - ``'accuracy'``: number of correct rows divided by number of rows;
    - ``'compare'``: frame with columns ``pred``, ``actual`` and ``correct``
      (1 where the two agree, 0 otherwise).
    """
    table = pd.DataFrame({'pred': pred, 'actual': actual})
    table['correct'] = (table['pred'] == table['actual']).astype(int)

    contingency = pd.crosstab(table['pred'], table['actual'], margins=True)
    hit_rate = table['correct'].sum() / table['correct'].count()
    n_rows = len(table)
    print('size: {0:d}, accuracy: {1:.3f}'.format(n_rows, hit_rate))

    return {'crosstable': contingency, 'accuracy': hit_rate, 'compare': table}

# Run the accuracy helper on the model's predictions for the training rows.
_pred = x.predict(x_var)
_actual = y_var
accu = accuracy(pred=_pred, actual=_actual)

size: 1000, accuracy: 0.912


### Data splitting

In [215]:
def dfsplit(train, num, seed):
    """Return a copy of ``train`` with a random group label per row.

    Parameters
    ----------
    train : pandas.DataFrame
        Frame to label; it is not modified.
    num : int
        Number of groups. Labels are integers drawn from ``0 .. num - 1``.
    seed : int
        Seed for the label generator; the same seed gives the same labels.

    Returns
    -------
    pandas.DataFrame
        ``train`` plus an integer column ``splitflg`` holding the labels.
        Labels are drawn independently per row, so group sizes are only
        approximately equal.
    """
    # A private RandomState yields exactly the stream that
    # np.random.seed(seed) + np.random.randint(...) would, but leaves the
    # process-wide NumPy generator untouched for the rest of the notebook.
    rng = np.random.RandomState(seed)
    return train.assign(splitflg=rng.randint(0, num, len(train)))
    
# Preview the first 10 rows of `df` with their `splitflg` labels
# (5 groups, seed 123).
print(dfsplit(df, 5, 123).head(10))

      error  explain1  explain2  explain3     score  survival  splitflg
0 -0.450599 -1.085631 -0.748827 -1.774224 -3.833981         0         2
1  0.609590  0.997345  0.567595 -1.201377  0.668358         1         4
2  1.173744  0.282978  0.718151  1.096257  2.684258         1         2
3  0.871815 -1.506295 -0.999381  0.861037 -1.208731         0         1
4  1.904723 -0.578600  0.474898 -1.520367 -0.671707         0         3
5  0.133491  1.651437 -1.868500 -0.447440 -0.597758         0         2
6  1.281844 -2.426679 -0.202659  0.463487 -1.524929         0         3
7 -1.159187 -0.428913 -1.134248  0.392493 -1.750261         0         1
8  0.870502  1.265936 -0.807699 -1.627167 -0.733679         0         1
9 -0.209635 -0.866740 -1.276077  0.260010 -1.987626         0         0


In [222]:
# Label the rows, then keep every row whose label is not 3
# (presumably the training portion when group 3 is held out).
df1 = dfsplit(df, 5, 123)
df1.loc[df1['splitflg'] != 3]


Unnamed: 0,error,explain1,explain2,explain3,score,survival,splitflg
0,-0.450599,-1.085631,-0.748827,-1.774224,-3.833981,0,2
1,0.609590,0.997345,0.567595,-1.201377,0.668358,1,4
2,1.173744,0.282978,0.718151,1.096257,2.684258,1,2
3,0.871815,-1.506295,-0.999381,0.861037,-1.208731,0,1
5,0.133491,1.651437,-1.868500,-0.447440,-0.597758,0,2
7,-1.159187,-0.428913,-1.134248,0.392493,-1.750261,0,1
8,0.870502,1.265936,-0.807699,-1.627167,-0.733679,0,1
9,-0.209635,-0.866740,-1.276077,0.260010,-1.987626,0,0
10,-0.146671,-0.678886,0.553626,-0.607853,-0.806448,0,1
11,1.380591,-0.094709,0.553874,1.198668,2.348129,1,1
