In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from collections import Counter
from sklearn.utils import shuffle
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

df_train = pd.read_csv('./titanic/train.csv')
df_test  = pd.read_csv('./titanic/test.csv')
df_sub   = pd.read_csv('./titanic/gender_submission.csv')


In [2]:
df_train.drop(['Name','Ticket','Cabin'],axis=1,inplace=True)
sex = pd.get_dummies(df_train['Sex'],drop_first=True)
embark = pd.get_dummies(df_train['Embarked'],drop_first=True)
df_train = pd.concat([df_train,sex,embark],axis=1)
df_train.drop(['Sex','Embarked'],axis=1,inplace=True)

df_test.drop( ['Name','Ticket','Cabin'],axis=1,inplace=True)
sex = pd.get_dummies(df_test['Sex'],drop_first=True)
embark = pd.get_dummies(df_test['Embarked'],drop_first=True)
df_test = pd.concat([df_test,sex,embark],axis=1)
df_test.drop(['Sex','Embarked'],axis=1,inplace=True)

In [4]:
df_train.head()

Unnamed: 0,PassengerId,Survived,Pclass,Age,SibSp,Parch,Fare,male,Q,S
0,1,0,3,22.0,1,0,7.25,1,0,1
1,2,1,1,38.0,1,0,71.2833,0,0,0
2,3,1,3,26.0,0,0,7.925,0,0,1
3,4,1,1,35.0,1,0,53.1,0,0,1
4,5,0,3,35.0,0,0,8.05,1,0,1


In [5]:
df_test.head()

Unnamed: 0,PassengerId,Pclass,Age,SibSp,Parch,Fare,male,Q,S
0,892,3,34.5,0,0,7.8292,1,1,0
1,893,3,47.0,1,0,7.0,0,0,1
2,894,2,62.0,0,0,9.6875,1,1,0
3,895,3,27.0,0,0,8.6625,1,0,1
4,896,3,22.0,1,1,12.2875,0,0,1


In [8]:
df_train[df_train.isnull().values == True]

Unnamed: 0,PassengerId,Survived,Pclass,Age,SibSp,Parch,Fare,male,Q,S
5,6,0,3,,0,0,8.4583,1,1,0
17,18,1,2,,0,0,13.0000,1,0,1
19,20,1,3,,0,0,7.2250,0,0,0
26,27,0,3,,0,0,7.2250,1,0,0
28,29,1,3,,0,0,7.8792,0,1,0
...,...,...,...,...,...,...,...,...,...,...
859,860,0,3,,0,0,7.2292,1,0,0
863,864,0,3,,8,2,69.5500,0,0,1
868,869,0,3,,0,0,9.5000,1,0,1
878,879,0,3,,0,0,7.8958,1,0,1


In [11]:
def Class_Mean_Age(column):
    age = column[0]
    Class = column[1]
    
    if pd.isnull(age):
        if Class ==1:
            return 37
        elif Class == 2:
            return 29
        else:
            return 24
    else:
        return age
df_train['Age'] = df_train[['Age','Pclass']].apply(Class_Mean_Age, axis=1)

In [13]:
len(df_train[df_train.isnull().values == True])

0

In [14]:
Scaler1 = StandardScaler()
Scaler2 = StandardScaler()

train_columns = df_train.columns
test_columns  = df_test.columns

df_train = pd.DataFrame(Scaler1.fit_transform(df_train))
df_test  = pd.DataFrame(Scaler2.fit_transform(df_test))

df_train.columns = train_columns
df_test.columns  = test_columns

features = df_train.iloc[:,2:].columns.tolist()
target   = df_train.loc[:, 'Survived'].name

X_train = df_train.iloc[:,2:].values
y_train = df_train.loc[:, 'Survived'].values

In [15]:
train_columns

Index(['PassengerId', 'Survived', 'Pclass', 'Age', 'SibSp', 'Parch', 'Fare',
       'male', 'Q', 'S'],
      dtype='object')

In [16]:
df_train

Unnamed: 0,PassengerId,Survived,Pclass,Age,SibSp,Parch,Fare,male,Q,S
0,-1.730108,-0.789272,0.827377,-0.533834,0.432793,-0.473674,-0.502445,0.737695,-0.307562,0.619306
1,-1.726220,1.266990,-1.566107,0.674891,0.432793,-0.473674,0.786845,-1.355574,-0.307562,-1.614710
2,-1.722332,1.266990,0.827377,-0.231653,-0.474545,-0.473674,-0.488854,-1.355574,-0.307562,0.619306
3,-1.718444,1.266990,-1.566107,0.448255,0.432793,-0.473674,0.420730,-1.355574,-0.307562,0.619306
4,-1.714556,-0.789272,0.827377,0.448255,-0.474545,-0.473674,-0.486337,0.737695,-0.307562,0.619306
...,...,...,...,...,...,...,...,...,...,...
886,1.714556,-0.789272,-0.369365,-0.156107,-0.474545,-0.473674,-0.386671,0.737695,-0.307562,0.619306
887,1.718444,1.266990,-1.566107,-0.760469,-0.474545,-0.473674,-0.044381,-1.355574,-0.307562,0.619306
888,1.722332,-0.789272,0.827377,-0.382743,0.432793,2.008933,-0.176263,-1.355574,-0.307562,0.619306
889,1.726220,1.266990,-1.566107,-0.231653,-0.474545,-0.473674,-0.044381,0.737695,-0.307562,-1.614710


In [18]:
target

'Survived'

In [19]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.autograd import Variable

In [20]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(8, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 2)
        self.dropout = nn.Dropout(0.2)
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        return x
model = Net()
print(model)

Net(
  (fc1): Linear(in_features=8, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=2, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
)


In [23]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

In [31]:
batch_size = 64
n_epochs = 500
batch_no = len(X_train) // batch_size

print(batch_no)

train_loss = 0
train_loss_min = np.Inf
for epoch in range(n_epochs):
    for i in range(batch_no):
        print(train_loss)

        start = i*batch_size
        end = start+batch_size
        x_var = torch.FloatTensor(X_train[start:end])
        y_var = torch.LongTensor(y_train[start:end])
        
        optimizer.zero_grad()
        output = model(x_var)
        loss = criterion(output,y_var)
        loss.backward()
        optimizer.step()
        
        values, labels = torch.max(output, 1)
        num_right = np.sum(labels.data.numpy() == y_train[start:end])
        train_loss += loss.item()*batch_size
        # print(len(X_train[start:end]))
        # print(x_var)

    
    train_loss = train_loss / len(X_train)
    if train_loss <= train_loss_min:
        print("Validation loss decreased ({:6f} ===> {:6f}). Saving the model...".format(train_loss_min,train_loss))
        torch.save(model.state_dict(), "model.pt")
        train_loss_min = train_loss
    
    if epoch % 200 == 0:
        print('')
        print("Epoch: {} \tTrain Loss: {} \tTrain Accuracy: {}".format(epoch+1, train_loss,num_right / len(y_train[start:end]) ))
    # break
print('Training Ended! ')

13
0
26.661746978759766
53.78306198120117
73.67473030090332
101.19891548156738
129.60194778442383
146.74592781066895
171.14139366149902
197.38611793518066
219.63089752197266
239.78802108764648
264.8507671356201
282.5609321594238
Validation loss decreased (   inf ===> 0.341393). Saving the model...

Epoch: 1 	Train Loss: 0.34139253849667467 	Train Accuracy: 0.0
0.34139253849667467
25.693347952315033
51.31673433537168
71.43636514897031
99.30953409428281
130.89189340824765
145.74236299748105
170.8803920959918
197.3219490265338
220.75556566471738
238.8215542053668
264.05210687870664
283.15221597904844
0.3423566465264851
26.15768715189758
50.44913193033996
69.73783394144836
98.5383615140802
130.02761932658018
147.60624214457334
171.6765870695001
196.4163102750665
220.81370445536436
239.30493637369932
265.67997833536924
283.6478719358331
0.34310314225881655
25.73630115556448
49.247150154587914
66.97516605424612
92.79946682023733
122.80591938066213
140.78313800859183
166.43423435258597
194.57

175.69195203437022
201.1938689197462
225.32925252570323
244.03004293097666
271.0227454151075
290.5010798419874
0.34740978491040514
24.52422481786939
49.042207300779545
68.99032169592115
95.58697658788893
126.77211338293287
143.94985920202467
166.76740986120436
194.0630317331526
218.19706875097486
235.35935360204908
259.80944973242015
279.8239780069563
0.3375846010401029
26.392741308803775
50.38413153707526
70.15810881673835
96.65195761739753
127.5422144991114
144.52265654623054
168.5555907350733
196.8923007112696
220.88978491842292
239.22722922384284
265.8067980867579
284.0690108400538
0.34160464070661595
27.15489237569197
52.268900325521074
71.17088358662947
99.52666895650252
130.60980646870956
147.459034373983
171.70149271748886
196.94644778035507
221.1304668691002
237.91755717061386
264.66497652791367
281.8581814076744
0.3400681604067844
26.96307116638823
52.689853011481006
73.51409274231597
100.95640116821929
134.22403651367827
150.75532466065093
175.66512805115386
201.632918654791

199.7140980603333
223.54144295474248
241.96432694216924
269.20525368472295
287.36462029238896
0.3465836115961386
27.28188062392524
52.029708153832466
72.01875615798286
99.74792028151802
131.18764997206978
146.67153669081978
170.77772451125435
196.35883069716743
219.91629148207954
237.4735806056147
262.6885807582026
282.8315479823237
0.3430361199267791
27.215001528251975
53.1216597008594
73.31373733757815
101.47784561394533
131.15065902947268
147.54893821953615
170.84756798027834
198.89497703789553
221.87037986992678
241.15461487053713
267.546918337334
285.28632110833007
0.3449257781529542
25.112450497391237
52.8930996825719
70.78200806880237
99.01671304011585
127.85536469722034
143.69432343745473
167.32091225886586
193.90928353572133
217.72708215022328
235.69358529353383
260.78015221858266
279.6420277526403
0.33959892745264453
27.252238545372567
51.63225677962549
69.96055153365381
97.84884193892725
130.07554939742334
147.4734089707876
171.62047127242334
197.9352581834341
221.9255822038

97.5433118221626
129.06657193873974
145.18304036830517
168.07754300807568
195.70401357387158
219.2208497402534
237.3537194607124
264.1247269985542
281.96191190456005
0.3418822301755834
23.72075414790019
48.2234225286375
65.2102446569334
94.52326917780742
125.29113531245098
143.4019351018797
168.74032926691876
194.7146029485594
218.21913862360822
235.49513006342755
261.68933248652326
279.98207426203595
0.3379314943878617
23.970743994387863
49.58341489282536
69.64492116418766
97.5255440279572
127.62060437646794
143.92099080530096
167.2849949404328
193.4198711916535
218.24086842981268
234.9735764071076
259.810129980716
278.24984822717596
0.33767871327369087
24.931743425798103
50.1477086967942
69.17392329640357
94.48407153553931
123.65009097523658
141.1484411186692
164.2033021873948
191.3645055717942
215.42170123524636
232.94186381764382
259.31126193470925
275.4257677025315
0.33294268841903985
26.576417648746187
53.12933703656357
72.11586161847275
99.88706752057236
131.27545520062608
149.1

129.58121913231733
144.28042548455122
167.96161597527387
194.77597564019086
217.85774367607954
237.06652587212446
262.4008516625297
281.7883905724418
0.3395511800944618
28.22018548795579
52.66174432584641
72.86960145780442
100.76945421048997
132.02344056913256
149.74560853788256
173.85115358182787
202.006156175822
225.49777719327807
243.6251270467937
267.52631685087084
284.11955568143725
0.3457316958604236
22.634515722960032
47.08932681915144
65.90516277191023
92.73290058013777
123.29593082305769
139.25212665435652
163.71607013580183
189.69644542571882
213.59029384490827
231.8467063510118
255.04197498199323
274.4343642794786
0.3305169029093683
26.239544002518745
51.779531566605655
70.24179658125409
98.61862572859296
129.9209958000041
144.69134816359053
167.83696850965987
194.57943257521163
219.91071614454756
237.4126482887126
262.08995541761885
279.59794148634444
0.33912165380833803
24.500742747558338
48.0954674027097
68.66188700415013
98.05857165075658
127.21169169165013
144.150032832

211.9091132651753
231.0581439506001
256.09116397013133
276.02992663492626
0.3351527730502338
25.612164644388123
50.529713777688904
70.08103480735198
97.69642558493987
127.26860728660003
144.9601623097934
168.29062762656585
194.81968417563812
215.40164294639007
232.3718310872592
255.79473605552093
272.808480409769
0.32932666084357404
23.682910950394355
48.350486786576
66.30785945198127
93.71976855537483
124.21598819038459
141.32170871040412
166.02699664375373
191.46862033149787
215.0246639563758
232.0605240180213
257.2350845648963
276.65382388374394
0.33528114091156114
28.193689115154726
55.06447387467621
75.76323676835297
104.55233550797699
133.886693723309
148.90296531449553
172.2930772377133
200.61837363969084
226.00664878617522
243.29105735551116
268.7242047859311
286.1812131477475
0.3467962345906914
24.63742466781823
51.10277577255456
68.5853101811239
97.51026364181726
128.45029278610437
145.25734921310632
169.63233204696863
197.21859761093347
219.0014192662069
237.4285718998739
26

261.1757274735588
280.08632241374437
0.3380513620551207
24.970497651117622
48.630707306879344
66.49487738134223
95.541878266352
125.43821386815864
142.2728009653266
165.73161177160102
192.7667775583442
214.539925141352
233.15194181920845
259.5853810739692
277.8535409403022
0.33723816678448376
26.96821770474835
51.71184143826886
71.79550347134992
100.17711815640851
129.60459503934308
146.7722509841673
169.8041990737669
194.587152336096
218.5004938583128
236.08916458890363
260.53656754300516
276.4957416991941
0.33406588669762377
25.337775679788443
49.82998797532067
68.92334124680505
97.80150172348962
129.1065077507113
145.64121386643396
170.3923144065951
196.1365084373568
218.45127818222986
235.81367632981286
259.6473650657748
279.1175226890658
0.33809172020344874
25.12282835350423
49.71347891197591
67.60335432396322
97.5760620723763
131.2242668758431
148.143275130848
171.14147077904136
197.101469863514
219.52534185753257
237.5475319515267
259.3899849544564
277.89878355369956
0.337228880

137.8168853605522
161.4233020628227
188.41121519453657
212.34806478865278
230.36615408308637
254.2921222532524
272.4334873045219
0.3308172433425295
26.10358907866968
50.4299192636062
70.85918714690698
98.68856527495873
128.61484243560326
145.6599760263015
169.01794149566186
196.96331884551537
220.40783407378686
237.82080938506616
261.8420725076247
278.92991163421163
0.33768110459555156
29.02284078782309
51.5009387034725
70.7254374522518
98.06824903672445
129.7706378001522
145.10223131364097
167.6478531855892
193.1980984706234
218.41917543595542
236.92530565446128
263.638888647015
280.23074655717124
0.3400394505008783
25.636557776306542
48.30997582135049
66.58534165081826
94.94502754864541
125.67355843243448
142.7241317814335
165.09146805462686
191.5644638126835
214.49372406658975
231.81181641278116
256.8717586582645
275.2254116123661
0.3315777518605794
26.608612511382063
52.94204566262718
70.0527119376516
96.77904937051292
126.40530631326195
141.75769278787132
164.74676177285667
191.03

168.7030928923773
193.14358742973081
217.12264474174253
233.82201035758726
258.6802714659857
276.3788054778265
0.334498794346645
24.033379943638636
48.163499267369104
66.09204808023532
95.73330441263278
123.04735127237399
139.42275753763278
164.17021694925387
190.74072399881442
214.88009968545992
232.46288433816989
256.7859930112656
274.9535974576767
0.33286804666720476
27.04625076761447
50.745061345007045
68.86298508157931
96.11045784464083
128.87315315714085
145.72006744852268
171.53757805338108
198.4213213395139
220.93062538614475
238.08297867288792
260.86613792886936
280.2344850968381
0.33985920085639626
26.195146752736278
49.78950710429878
67.76727504863472
94.37418193950386
124.705791665456
141.0359289182636
164.67162914409369
190.39877529277533
210.79175968303412
228.72234554424017
253.33554477824896
271.69637317790716
0.3291772016118509
27.73313456855521
50.74535304389701
68.99492007575736
98.53930598579154
128.56392222724662
143.87430983863578
167.86695510230766
193.7415068753

255.99148675337813
274.37138291731856
0.33288151349701356
29.227155271431585
54.24511105146088
72.50461155500092
100.09006649579682
132.36615139569915
148.03952175702727
172.60693699445403
198.51198536481536
220.74752956952727
237.15728527631438
261.9449592265097
278.5138087901083
0.33557817945672647
27.662060952283873
53.55266306410028
70.81949160108758
99.29615328321648
128.91438982496453
144.75822851667596
168.28017637739373
192.5065157176037
217.51142714033318
232.95450041303826
255.62211821088982
273.4978620768566
0.3316476526753825
27.41245629220175
52.2150247278707
70.30095841317343
98.43918396859335
129.4255616846578
144.1111610117452
166.9649799051778
194.97792317300008
218.41016079812215
235.6165263834493
261.02529599099324
277.9296844187276
0.3402886661151194
23.387392547951055
46.99868061862976
65.88555195652039
95.9030728610614
123.05451443515808
138.58195736728698
162.22023823581725
186.8479962619403
210.7455163272479
225.91896488986998
249.46010449253112
266.226906326637

245.8070033364888
263.8122103982564
0.3207576661057072
25.079891958708245
49.534943380827386
66.6056983857102
93.14592055367895
123.11422232675024
139.25210455941624
162.3791721253342
187.36788443612522
211.04037169503636
227.22889021920628
250.14014510201878
266.83568266915745
0.3218779888115555
26.845424207683624
50.71509507560354
68.58662942313772
95.91369965934378
125.91461709403616
140.60144761466603
163.45470956229786
188.18844941520314
210.375093015667
226.74459222220997
249.95758775138478
266.47675660514454
0.3226625349749042
26.412044706605762
49.672910871766895
67.18054026629815
95.70602435138116
124.20260828997979
140.20903605487237
163.90157336261163
187.5895940687884
209.66901797320733
225.4705755141009
250.43749255206475
268.1344282057513
0.326807867868205
23.82103432355668
46.755349051339884
63.36358154950883
89.18092430768755
119.69494904218462
134.69489277539995
156.78277767835405
180.96845425306108
202.03120220838335
218.70832623182085
243.7246226185396
261.3104647511

182.35842997696471
205.5386858668573
222.42856318619323
246.82882792618346
265.00172526505065
0.3211634890361809
24.58395416896782
47.37531311672173
65.26037819545708
93.4497740999981
123.61045296352349
140.0874713198467
162.81571610134085
188.3305190340801
214.90149529140433
232.320421530418
255.25673325221976
270.19925434749564
0.32829291017949735
26.529583040794734
51.33592802675665
69.24518210085333
99.62104612978399
130.43767553957403
147.2170639673084
168.5085621515125
195.73655325563848
220.5045395532459
238.18357855471075
261.25337607057986
277.8105660119861
0.3355186647213769
23.86653406083954
47.2631672090329
63.93710977495087
90.9129245942868
120.614874667529
136.15216237962616
158.512992686633
183.7760599321164
207.29369527757538
224.53999502122772
249.39037305772675
268.2383306688107
0.3262371998918352
24.74023485003832
48.88063478289965
67.10339402972582
96.41355943499926
126.36459207354906
141.44064569293383
163.3555493336809
188.1628880482805
211.06694269000414
227.9388

162.59930932924772
188.81420648500944
212.34540498659635
229.00296342775846
251.0923684875729
269.21000039980436
0.32602448830484165
24.19248528847574
48.99896378884195
66.38982339272378
93.53119988808511
122.47310395607828
138.4172901667106
161.16031880745766
185.6956867731559
208.3900398767936
224.58713384041664
245.05086751351234
263.41686292061684
0.31869054772246297
26.068501720207813
48.87047697045195
67.05824399926055
94.97331548668731
124.10245634056915
140.23105932213653
164.04946065880645
188.2981960866629
209.59669614769805
228.83842397667755
251.8445942495047
269.44137121178494
0.3265769171364445
24.995901792014376
49.51844665285422
66.49066039797141
92.80722495790793
121.16080543276145
137.91244575258565
160.9721877264626
187.64960357424093
210.67688056703878
226.7899138617165
251.69568320986104
270.2125365423806
0.3290132386836548
25.478047738805724
49.54588640762897
66.99553621841999
93.87985361649127
123.3527959687862
139.56660021378133
163.16258753372762
186.1600297792

220.52946856579638
241.17006686291552
258.1715774830327
Validation loss decreased (0.311679 ===> 0.311432). Saving the model...
0.3114319725666501
24.194728939119386
47.70131692129712
65.42770585257153
91.45832070547681
121.48222169119458
137.3056393547688
159.33724602896314
185.400354473177
209.25471505362134
226.66550454336743
252.07828721243482
269.29486474234204
0.32747691043385596
25.879890850863543
48.14397662051687
66.17354815371999
93.34042208560476
124.28590815433034
139.85228484042653
161.72501891025075
185.99199431308278
209.73558562167653
227.232892445407
251.58059637912282
268.36063330539235
0.3270679108128779
25.369578897141004
48.349309503220105
65.38147121586658
91.10943370976307
119.48866039433338
136.49836689152576
158.91832119145252
183.36560398259022
206.61333423771717
222.36443763890125
245.63652092137195
262.5430074326024
0.31911325354276254
26.60191392798124
50.587612627932415
68.1438050260037
97.61414003271757
127.04661607641874
142.51508664984402
163.8239760388

202.01419175894424
218.51033319266006
240.85772813589736
259.5131446245448
0.31499757471076745
28.265274903201977
51.414816758060375
68.30658807458869
94.92403306665412
124.13020982446662
139.0902394265174
161.3620957345008
187.37144174280158
211.41425599756232
228.9878224343299
254.94529237451545
270.9738015145301
0.3272298614819844
24.694123629792532
47.09825647281011
63.311946276887255
89.71852243350347
120.27412736819585
136.53900659488042
157.53759134219487
182.6131013672193
207.25983179019292
224.23326242373784
248.7830175201734
266.82823312686287
0.3230638850596707
24.913485561695413
48.55076125749619
63.97725871690781
93.99014094957383
120.50488856920273
136.55578807481797
158.37563327440293
183.3309593547252
206.46694759019883
223.32201579698594
245.9982891429332
264.3397770274547
0.32116642922674404
26.657288942044126
51.6454176041535
68.67936268777166
98.2163061234406
127.13412800759588
144.7670110794953
168.27114621133123
194.55134716958318
216.5878004166535
233.82963505715

In [32]:
X_test = df_test.iloc[:,1:].values
X_test_var = Variable(torch.FloatTensor(X_test), requires_grad=False) 
with torch.no_grad():
    test_result = model(X_test_var)
values, labels = torch.max(test_result, 1)
survived = labels.data.numpy()

In [33]:
submission = pd.DataFrame({'PassengerId': df_sub['PassengerId'], 'Survived': survived})
submission.to_csv('submission.csv', index=False)