## 使用 TensorFlow 搭建神经网络预测泰坦尼克号乘客生存率

比赛地址：[Titanic: Machine Learning from Disaster](https://www.kaggle.com/c/titanic)

本文将从以下几方面介绍解决方案：
1. 添加库
2. 定义全局变量
3. 加载数据文件
4. 清洗数据
5. 特征工程
6. 搭建神经网络（前向传播）
7. 训练（反向传播）
8. 测试模型效果 && 生成提交结果

## 1. 添加库

In [3]:
import tensorflow as tf
import model
import pandas as pd
import numpy as np

## 2. 定义全局变量

In [4]:
# Global configuration shared by the data-loading cells, Train() and Test().

# --- File locations ---
train_data_file = r'./data/train.csv'   # training CSV (contains the Survived label)
test_data_file = r'./data/test.csv'     # presumably read by get_test_data() in train.py — not used directly here
test_label_file = r'./data/gender_submission.csv'  # presumably the test labels used by get_test_data() — verify in train.py
model_save_path = r'./ckpt/model'       # checkpoint prefix: written by Train(), restored by Test()
output_file = r'./result.csv'           # submission CSV written by Test()

# --- Optimisation ---
learning_rate = 0.001   # Adam learning rate
BATCH_SIZE = 20         # rows per mini-batch

# --- Network shape ---
input_size = 11         # number of engineered feature columns in x
hidden_size = 20        # units in the single hidden layer
output_size = 2         # one-hot classes [Dead, Survived]

epoch = 2000            # number of full passes over the training data
KEEP_PROB = 0.5         # dropout keep probability while training (1.0 is fed for evaluation)

## 3. 加载数据文件

In [5]:
# Load the training CSV. The numeric feature columns are read as float32 and the
# label as int64; columns not listed in `dtype` keep the dtype pandas infers.
data = pd.read_csv(train_data_file, 
    sep=',', 
    dtype={
        'Name' : 'str',
        'Survived' : 'int64',
        'Pclass' : 'float32',   
        'Sex' : 'str',
        'Age' : 'float32',
        'SibSp' : 'float32',
        'Parch' : 'float32',
        'Fare' : 'float32',
        'Embarked' : 'str',
    }
)

In [6]:
# Preview the first five rows of the raw training data.
data.head()

Unnamed: 0,PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
0,1,0,3.0,"Braund, Mr. Owen Harris",male,22.0,1.0,0.0,A/5 21171,7.25,,S
1,2,1,1.0,"Cumings, Mrs. John Bradley (Florence Briggs Th...",female,38.0,1.0,0.0,PC 17599,71.283302,C85,C
2,3,1,3.0,"Heikkinen, Miss. Laina",female,26.0,0.0,0.0,STON/O2. 3101282,7.925,,S
3,4,1,1.0,"Futrelle, Mrs. Jacques Heath (Lily May Peel)",female,35.0,1.0,0.0,113803,53.099998,C123,S
4,5,0,3.0,"Allen, Mr. William Henry",male,35.0,0.0,0.0,373450,8.05,,S


In [7]:
# Summary statistics of the numeric columns. Age has only 714 non-null values
# out of 891 rows, so it needs to be imputed during cleaning.
data.describe()

Unnamed: 0,PassengerId,Survived,Pclass,Age,SibSp,Parch,Fare
count,891.0,891.0,891.0,714.0,891.0,891.0,891.0
mean,446.0,0.383838,2.308642,29.699118,0.523008,0.381594,32.204208
std,257.353842,0.486592,0.836071,14.526497,1.102744,0.806057,49.693428
min,1.0,0.0,1.0,0.42,0.0,0.0,0.0
25%,223.5,0.0,2.0,20.125,0.0,0.0,7.9104
50%,446.0,0.0,3.0,28.0,0.0,0.0,14.4542
75%,668.5,1.0,3.0,38.0,1.0,0.0,31.0
max,891.0,1.0,3.0,80.0,8.0,6.0,512.329224


## 4. 清洗数据

In [8]:
# Build the training labels y_ as a one-hot matrix with columns [Dead, Survived].
# Dead is 1 for every row whose Survived value is not 1, and 0 otherwise.
# The frame is converted to a float32 NumPy array so TensorFlow can consume it.
y_1 = data.loc[:, 'Survived']
y_0 = y_1.map(lambda s: 1 if s != 1 else 0)
y_ = pd.concat([y_0, y_1], axis=1, keys=['Dead', 'Survived']).astype('float32').values

In [9]:
# Select the columns used as features to build the training matrix x.
x = data.loc[:,['Name','Pclass','Sex','Age','SibSp','Parch','Fare','Embarked']]

In [10]:
# Inspect the training features before cleaning.
x.head(10)

Unnamed: 0,Name,Pclass,Sex,Age,SibSp,Parch,Fare,Embarked
0,"Braund, Mr. Owen Harris",3.0,male,22.0,1.0,0.0,7.25,S
1,"Cumings, Mrs. John Bradley (Florence Briggs Th...",1.0,female,38.0,1.0,0.0,71.283302,C
2,"Heikkinen, Miss. Laina",3.0,female,26.0,0.0,0.0,7.925,S
3,"Futrelle, Mrs. Jacques Heath (Lily May Peel)",1.0,female,35.0,1.0,0.0,53.099998,S
4,"Allen, Mr. William Henry",3.0,male,35.0,0.0,0.0,8.05,S
5,"Moran, Mr. James",3.0,male,,0.0,0.0,8.4583,Q
6,"McCarthy, Mr. Timothy J",1.0,male,54.0,0.0,0.0,51.862499,S
7,"Palsson, Master. Gosta Leonard",3.0,male,2.0,3.0,1.0,21.075001,S
8,"Johnson, Mrs. Oscar W (Elisabeth Vilhelmina Berg)",3.0,female,27.0,0.0,2.0,11.1333,S
9,"Nasser, Mrs. Nicholas (Adele Achem)",2.0,female,14.0,1.0,0.0,30.070801,C


In [11]:
# Clean the data.
# Encode Sex as an integer (female -> 0, male -> 1). An explicit dict with `map`
# avoids the deprecated downcasting behaviour of `replace`.
x['Sex'] = x['Sex'].map({'female': 0, 'male': 1}).astype('int32')
# Fill missing Embarked values with 'S', then encode the port letters as integers.
x['Embarked'] = x['Embarked'].fillna('S')
mapping = {'C':0,'Q':1,'S':2}
x['Embarked'] = x['Embarked'].map(mapping)
# Fill missing Fare values with the column median.
x['Fare'] = x['Fare'].fillna(x['Fare'].median())

# Impute missing Age values with a random-forest regressor trained on the rows
# where Age is known.
# NOTE(review): 'Survived' is used as a predictor, and it is not available for
# the test set, so the test data cannot be imputed the same way. Confirm how
# get_test_data() handles Age.
from sklearn.ensemble import RandomForestRegressor
age = data[['Age','Survived','Fare','Parch','SibSp','Pclass']]
age_notnull = age.loc[(data.Age.notnull())]
age_isnull = age.loc[(data.Age.isnull())]
# Skip imputation when nothing is missing: predict() rejects an empty input.
if not age_isnull.empty:
    X = age_notnull.values[:,1:]   # predictors: every column except Age
    Y = age_notnull.values[:,0]    # target: Age
    # A fixed random_state makes the imputed ages, and therefore every
    # downstream feature and the trained model, reproducible across runs.
    rfr = RandomForestRegressor(n_estimators=1000, n_jobs=-1, random_state=0)
    rfr.fit(X,Y)
    predictAges = rfr.predict(age_isnull.values[:,1:])
    x.loc[(x.Age.isnull()),'Age'] = predictAges

  from numpy.core.umath_tests import inner1d


In [12]:
# Inspect the features after cleaning: Sex and Embarked are now integers and
# missing Age values have been imputed.
x.head(10)

Unnamed: 0,Name,Pclass,Sex,Age,SibSp,Parch,Fare,Embarked
0,"Braund, Mr. Owen Harris",3.0,1,22.0,1.0,0.0,7.25,2
1,"Cumings, Mrs. John Bradley (Florence Briggs Th...",1.0,0,38.0,1.0,0.0,71.283302,0
2,"Heikkinen, Miss. Laina",3.0,0,26.0,0.0,0.0,7.925,2
3,"Futrelle, Mrs. Jacques Heath (Lily May Peel)",1.0,0,35.0,1.0,0.0,53.099998,2
4,"Allen, Mr. William Henry",3.0,1,35.0,0.0,0.0,8.05,2
5,"Moran, Mr. James",3.0,1,23.713577,0.0,0.0,8.4583,1
6,"McCarthy, Mr. Timothy J",1.0,1,54.0,0.0,0.0,51.862499,2
7,"Palsson, Master. Gosta Leonard",3.0,1,2.0,3.0,1.0,21.075001,2
8,"Johnson, Mrs. Oscar W (Elisabeth Vilhelmina Berg)",3.0,0,27.0,0.0,2.0,11.1333,2
9,"Nasser, Mrs. Nicholas (Adele Achem)",2.0,0,14.0,1.0,0.0,30.070801,0


## 5. 特征工程

In [13]:
# Construct new features.

# Child: 1 for passengers aged 16 or younger, else 0. This is the same
# threshold as the first Age_bin bucket below.
x['Child'] = (x['Age'] <= 16).astype('int32')

# FamilySize: SibSp + Parch + 1, i.e. the family size including the passenger.
x['FamilySize'] = x['SibSp'] + x['Parch'] + 1
x['FamilySize'] = x['FamilySize'].astype('int32')

# IsAlone: 1 when FamilySize == 1 (travelling without family), else 0.
x['IsAlone'] = (x['FamilySize'] == 1).astype('int64')

# Age_bin: bucket Age into (0,16], (16,32], (32,48], (48,1200] and encode as 0-3.
x['Age_bin'] = pd.cut(x['Age'], bins=[0, 16, 32, 48, 1200],
                      labels=['Children', 'Teenage', 'Adult', 'Elder'])
mapping = {'Children':0,'Teenage':1,'Adult':2,'Elder':3}
x['Age_bin'] = x['Age_bin'].map(mapping)


# Fare_bin: bucket the ticket Fare at roughly its quartiles (7.91 / 14.45 / 31,
# matching the 25% / 50% / 75% rows of data.describe()) and encode as 0-3.
x['Fare_bin'] = pd.cut(x['Fare'], bins=[-1, 7.91, 14.45, 31, 12000],
                       labels=['Low_fare', 'median_fare', 'Average_fare', 'high_fare'])
mapping = {'Low_fare':0,'median_fare':1,'Average_fare':2,'high_fare':3}
x['Fare_bin'] = x['Fare_bin'].map(mapping)

# 处理 Name 特征
import re
def get_title(name):
    """Return the honorific title in a passenger name, e.g. 'Mr' for 'Braund, Mr. Owen Harris'.

    The title is the first run of ASCII letters that is preceded by a space and
    immediately followed by a period.

    Args:
        name: the passenger's full name.

    Returns:
        The title without the trailing period. Returns "" when no title is
        found or when ``name`` is not a string (e.g. a missing value / NaN).
    """
    if not isinstance(name, str):
        return ""
    # Raw string: "\." is a regex escape for a literal dot. In a normal string
    # literal it is an invalid escape sequence (DeprecationWarning, and a
    # SyntaxWarning on Python 3.12+).
    title_search = re.search(r' ([A-Za-z]+)\.', name)
    if title_search:
        return title_search.group(1)
    return ""
# Derive the integer Title feature from each Name:
#   1. extract the honorific with get_title,
#   2. collapse infrequent honorifics into 'Rare',
#   3. fold the variant spellings Mlle/Ms -> Miss and Mme -> Mrs,
#   4. encode as Mr=1, Miss=2, Mrs=3, Master=4, Rare=5, with anything else 0.
mapping = {"Mr": 1, "Miss": 2, "Mrs": 3, "Master": 4, "Rare": 5}
x['Title'] = (
    x['Name'].apply(get_title)
    .replace(['Lady', 'Countess', 'Capt', 'Col', 'Don',
              'Dr', 'Major', 'Rev', 'Sir', 'Jonkheer', 'Dona'], 'Rare')
    .replace({'Mlle': 'Miss', 'Ms': 'Miss', 'Mme': 'Mrs'})
    .map(mapping)
    .fillna(0)
)

# Drop the raw columns that the engineered features above replace.
x.drop(columns=["Name", "Age", "Fare"], inplace=True)

In [14]:
# Inspect the final feature matrix (11 columns) before converting it to NumPy.
x.head(10)

Unnamed: 0,Pclass,Sex,SibSp,Parch,Embarked,Child,FamilySize,IsAlone,Age_bin,Fare_bin,Title
0,3.0,1,1.0,0.0,2,0,2,0,1,0,1
1,1.0,0,1.0,0.0,0,0,2,0,2,3,3
2,3.0,0,0.0,0.0,2,0,1,1,1,1,2
3,1.0,0,1.0,0.0,2,0,2,0,2,3,3
4,3.0,1,0.0,0.0,2,0,1,1,2,1,1
5,3.0,1,0.0,0.0,1,0,1,1,1,1,1
6,1.0,1,0.0,0.0,2,0,1,1,3,3,1
7,3.0,1,3.0,1.0,2,1,5,0,0,2,4
8,3.0,0,0.0,2.0,2,0,3,0,1,1,3
9,2.0,0,1.0,0.0,0,1,2,0,0,2,3


In [15]:
# Convert to a NumPy array so it can be fed to TensorFlow placeholders.
x = x.values

## 6. 搭建神经网络（前向传播）

In [16]:
def input_placeholder(input_size, output_size):
    """Create the graph's input placeholders.

    Args:
        input_size: number of feature columns per sample.
        output_size: number of classes in the one-hot labels.

    Returns:
        Tuple ``(x, y_, keep_prob)``:
        - a float32 feature placeholder of shape [None, input_size],
        - a float32 label placeholder of shape [None, output_size],
        - a float32 dropout keep-probability placeholder with unspecified
          shape (intended to be fed a scalar).
    """
    features = tf.placeholder(tf.float32, [None, input_size])
    labels = tf.placeholder(tf.float32, [None, output_size])
    return features, labels, tf.placeholder(tf.float32)

def forward(x, w1, w2, b1, b2, keep_prob=1.0):
    """Compute the logits of a one-hidden-layer perceptron (MLP).

    Architecture: ``x @ w1 + b1 -> dropout -> ReLU -> @ w2 + b2``. With the
    notebook's global configuration this is input_size=11 features,
    hidden_size=20 hidden units and output_size=2 output units (Dead / Survived).
    No softmax is applied here: the loss consumes these logits directly, and
    callers apply ``tf.nn.softmax`` themselves when they need probabilities.

    Args:
        x: feature tensor of shape [batch, in].
        w1, b1: hidden-layer weights [in, hidden] and bias [hidden].
        w2, b2: output-layer weights [hidden, out] and bias [out].
        keep_prob: dropout keep probability; 1.0 disables dropout.

    Returns:
        Un-normalised logits of shape [batch, out].
    """
    hidden = tf.matmul(x, w1) + b1
    # Applying dropout before ReLU gives the same result as applying it after:
    # dropout only zeroes units or scales them by a positive factor, and both
    # operations commute with ReLU.
    hidden = tf.nn.dropout(hidden, keep_prob=keep_prob)
    hidden = tf.nn.relu(hidden)
    return tf.matmul(hidden, w2) + b2

def loss(y, y_):
    """Return the mean softmax cross-entropy between logits ``y`` and one-hot labels ``y_``.

    Uses ``softmax_cross_entropy_with_logits_v2`` because the v1 op is
    deprecated, as the warning printed during training shows.
    ``tf.stop_gradient`` on the labels keeps the v1 behaviour of not
    back-propagating into the labels.

    Args:
        y: logits of shape [batch, classes].
        y_: one-hot (or probability) labels with the same shape.

    Returns:
        Scalar mean loss over the batch.
    """
    per_example = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=tf.stop_gradient(y_), logits=y)
    return tf.reduce_mean(per_example)

def accuary(y, y_):
    """Return the fraction of rows whose arg-max prediction equals the arg-max label.

    Args:
        y: predictions (logits or probabilities) of shape [batch, classes].
        y_: one-hot labels of the same shape.

    Returns:
        Scalar float32 accuracy in [0, 1].
    """
    hits = tf.cast(tf.equal(tf.argmax(y_, 1), tf.argmax(y, 1)), tf.float32)
    return tf.reduce_mean(hits)

## 7. 训练（反向传播）

In [17]:
def Train():
    """Build the MLP graph, train it on the global training arrays and save a checkpoint.

    Reads the module-level NumPy arrays ``x`` (features) and ``y_`` (one-hot
    labels) without rebinding them. Trains for ``epoch`` passes with Adam in
    mini-batches of ``BATCH_SIZE`` rows, using dropout ``KEEP_PROB``. After each
    epoch it prints the loss and accuracy on the full training set with dropout
    disabled. Finally it saves the session to ``model_save_path``.
    """
    # The graph placeholders below use distinct names (x_ph / y_ph), so the
    # global NumPy arrays x / y_ are only read here. Rebinding them through
    # `global` would replace the data with TF tensors and break any re-run.
    X, Y_ = x, y_
    n_samples = len(Y_)
    # Ceiling division: every row is used once per epoch, and no empty trailing
    # batch is fed when n_samples is an exact multiple of BATCH_SIZE.
    n_batches = (n_samples + BATCH_SIZE - 1) // BATCH_SIZE

    # Model parameters. They are created in the same order and with the same
    # names as in Test(), so the checkpoint can be restored there.
    w1 = tf.Variable(tf.random_normal([input_size, hidden_size], stddev=1.0, seed=2.0))
    w2 = tf.Variable(tf.random_normal([hidden_size, output_size], stddev=1.0, seed=2.0))
    b1 = tf.Variable(tf.zeros([hidden_size]), name='bias1')
    b2 = tf.Variable(tf.zeros([output_size]), name='bias2')

    x_ph, y_ph, keep_prob = model.input_placeholder(input_size, output_size)
    logits = model.forward(x_ph, w1, w2, b1, b2, keep_prob=keep_prob)
    loss_op = model.loss(logits, y_ph)              # computed on raw logits
    acc_op = model.accuary(tf.nn.softmax(logits), y_ph)

    # Training (back-propagation) step.
    train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss_op)

    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        print('Train start...')
        for i in range(epoch):
            for j in range(n_batches):
                start = j * BATCH_SIZE
                end = start + BATCH_SIZE
                sess.run(train_op, feed_dict={x_ph: X[start:end],
                                              y_ph: Y_[start:end],
                                              keep_prob: KEEP_PROB})
            # Full-training-set loss and accuracy after each epoch (dropout disabled).
            print(i, sess.run([loss_op, acc_op],
                              feed_dict={x_ph: X, y_ph: Y_, keep_prob: 1.0}))
        print('Train end.')

        # Save the trained variables locally.
        print('Saving model...')
        saver.save(sess, model_save_path)
        print('Save finally.')


# Start from a clean default graph, as is done before Test(). Otherwise
# re-running this cell would accumulate duplicate, freshly initialised
# variables, and those would be written into the checkpoint too.
tf.reset_default_graph()
Train()

Instructions for updating:

Future major versions of TensorFlow will allow gradients to flow
into the labels input on backprop by default.

See @{tf.nn.softmax_cross_entropy_with_logits_v2}.

Train start...
0 [9.238239, 0.3815937]
1 [6.5285263, 0.35353535]
2 [4.73439, 0.3580247]
3 [3.5500216, 0.338945]
4 [2.8156452, 0.3681257]
5 [2.2658987, 0.42199776]
6 [1.9280611, 0.5274972]
7 [1.6983049, 0.5409652]
8 [1.5067697, 0.54545456]
9 [1.3395972, 0.61616164]
10 [1.2204427, 0.6251403]
11 [1.0978968, 0.6195286]
12 [1.00659, 0.6666667]
13 [0.9300843, 0.6487093]
14 [0.86463743, 0.684624]
15 [0.80138737, 0.6924804]
16 [0.755404, 0.68350166]
17 [0.72290105, 0.68013465]
18 [0.6942688, 0.6936027]
19 [0.669489, 0.7003367]
20 [0.65378803, 0.69921434]
21 [0.64093703, 0.6958474]
22 [0.6233864, 0.7037037]
23 [0.6059771, 0.70819306]
24 [0.5951998, 0.7104377]
25 [0.5851009, 0.7115601]
26 [0.57629156, 0.70819306]
27 [0.5704922, 0.7115601]
28 [0.5655452, 0.70819306]
29 [0.5610418, 0.70819306]
30 [0.5599404, 

285 [0.38963622, 0.83950615]
286 [0.38828897, 0.8406285]
287 [0.38816574, 0.83950615]
288 [0.3892389, 0.83838385]
289 [0.38967267, 0.83613914]
290 [0.38969666, 0.83838385]
291 [0.38899106, 0.84175086]
292 [0.38772774, 0.8406285]
293 [0.3874741, 0.83838385]
294 [0.38701987, 0.83501685]
295 [0.38774607, 0.83501685]
296 [0.38831767, 0.8372615]
297 [0.3878356, 0.83838385]
298 [0.38841483, 0.8372615]
299 [0.3877315, 0.83501685]
300 [0.38774994, 0.83838385]
301 [0.3876871, 0.83838385]
302 [0.388971, 0.83613914]
303 [0.3890971, 0.83613914]
304 [0.38855213, 0.83613914]
305 [0.38658503, 0.83613914]
306 [0.3860616, 0.83613914]
307 [0.3860877, 0.8372615]
308 [0.38586268, 0.83613914]
309 [0.38550085, 0.83613914]
310 [0.38581172, 0.83501685]
311 [0.38581523, 0.83613914]
312 [0.38613942, 0.83613914]
313 [0.38654703, 0.83613914]
314 [0.38586816, 0.83613914]
315 [0.38806638, 0.8372615]
316 [0.38739604, 0.83838385]
317 [0.38716236, 0.8338945]
318 [0.3861148, 0.84175086]
319 [0.3853408, 0.83501685]
320 

577 [0.3723312, 0.8439955]
578 [0.37304747, 0.84511787]
579 [0.37354994, 0.84175086]
580 [0.37269744, 0.84287316]
581 [0.3721705, 0.8406285]
582 [0.37218374, 0.83950615]
583 [0.3716794, 0.84511787]
584 [0.37213543, 0.84511787]
585 [0.37130114, 0.84511787]
586 [0.3706353, 0.84175086]
587 [0.3702187, 0.8439955]
588 [0.37153384, 0.83950615]
589 [0.3720404, 0.8305275]
590 [0.3714025, 0.8271605]
591 [0.37054923, 0.8406285]
592 [0.37063172, 0.84175086]
593 [0.36982095, 0.8406285]
594 [0.3714469, 0.82828283]
595 [0.37131956, 0.8305275]
596 [0.37069565, 0.8294052]
597 [0.37039882, 0.84287316]
598 [0.36984983, 0.84175086]
599 [0.3697864, 0.84287316]
600 [0.37024608, 0.84287316]
601 [0.3699862, 0.84175086]
602 [0.36910176, 0.8406285]
603 [0.36929592, 0.83950615]
604 [0.36991864, 0.84175086]
605 [0.36969674, 0.8473625]
606 [0.3685479, 0.84960717]
607 [0.3696313, 0.8484849]
608 [0.36961177, 0.8484849]
609 [0.3695601, 0.8484849]
610 [0.36867714, 0.84960717]
611 [0.36986703, 0.84624016]
612 [0.36906

871 [0.35968128, 0.8529742]
872 [0.3597215, 0.84960717]
873 [0.36050192, 0.8518519]
874 [0.3598458, 0.8518519]
875 [0.36008197, 0.8518519]
876 [0.3593998, 0.8518519]
877 [0.359339, 0.8518519]
878 [0.35976467, 0.8529742]
879 [0.36036813, 0.8529742]
880 [0.360028, 0.8529742]
881 [0.35946274, 0.8529742]
882 [0.35852352, 0.8563412]
883 [0.35836127, 0.8563412]
884 [0.35913312, 0.8529742]
885 [0.35911152, 0.8484849]
886 [0.35903814, 0.8518519]
887 [0.3584578, 0.8563412]
888 [0.35885102, 0.8529742]
889 [0.35927066, 0.8563412]
890 [0.35915977, 0.8529742]
891 [0.35902733, 0.8529742]
892 [0.3587951, 0.8518519]
893 [0.35805035, 0.8529742]
894 [0.35822612, 0.8529742]
895 [0.359742, 0.8529742]
896 [0.35873383, 0.8563412]
897 [0.3586582, 0.8529742]
898 [0.35940698, 0.8518519]
899 [0.3595271, 0.8507295]
900 [0.35910192, 0.8473625]
901 [0.35832235, 0.8518519]
902 [0.35864577, 0.8529742]
903 [0.3582333, 0.8518519]
904 [0.35857767, 0.8529742]
905 [0.35880342, 0.8563412]
906 [0.35920724, 0.85409653]
907 

1162 [0.35462764, 0.8563412]
1163 [0.3551331, 0.8563412]
1164 [0.3556502, 0.8552188]
1165 [0.35637376, 0.8552188]
1166 [0.3568028, 0.8552188]
1167 [0.35603854, 0.8552188]
1168 [0.3555335, 0.8563412]
1169 [0.35552257, 0.8563412]
1170 [0.35492605, 0.8563412]
1171 [0.35522088, 0.8563412]
1172 [0.35473734, 0.85409653]
1173 [0.3546946, 0.8563412]
1174 [0.35558394, 0.85746354]
1175 [0.35511032, 0.8563412]
1176 [0.3565728, 0.8563412]
1177 [0.35635343, 0.8563412]
1178 [0.35726613, 0.8597082]
1179 [0.35607734, 0.85858583]
1180 [0.35603186, 0.8563412]
1181 [0.35577253, 0.8563412]
1182 [0.35640326, 0.8563412]
1183 [0.35584256, 0.8563412]
1184 [0.3559396, 0.8563412]
1185 [0.35532743, 0.8563412]
1186 [0.355062, 0.8563412]
1187 [0.35553446, 0.8563412]
1188 [0.3548726, 0.8563412]
1189 [0.35445622, 0.8563412]
1190 [0.35452554, 0.8563412]
1191 [0.35449564, 0.8563412]
1192 [0.3551112, 0.8563412]
1193 [0.35516846, 0.8597082]
1194 [0.3552344, 0.8563412]
1195 [0.35535562, 0.8529742]
1196 [0.35593405, 0.856

1445 [0.3537435, 0.8563412]
1446 [0.3537566, 0.8563412]
1447 [0.35442016, 0.8563412]
1448 [0.35556906, 0.8552188]
1449 [0.3548464, 0.85746354]
1450 [0.3549898, 0.85746354]
1451 [0.35417613, 0.8563412]
1452 [0.35440245, 0.85858583]
1453 [0.35434532, 0.85746354]
1454 [0.35445723, 0.8563412]
1455 [0.355668, 0.85746354]
1456 [0.35439152, 0.85746354]
1457 [0.35388806, 0.85746354]
1458 [0.3547435, 0.85746354]
1459 [0.35439897, 0.85746354]
1460 [0.35456485, 0.85746354]
1461 [0.3542563, 0.8563412]
1462 [0.35505638, 0.85746354]
1463 [0.35506693, 0.8563412]
1464 [0.35509548, 0.85746354]
1465 [0.3544317, 0.8563412]
1466 [0.3542327, 0.85746354]
1467 [0.35453138, 0.85746354]
1468 [0.35404736, 0.8563412]
1469 [0.3536503, 0.8563412]
1470 [0.35382056, 0.85746354]
1471 [0.35533518, 0.85746354]
1472 [0.35390037, 0.85746354]
1473 [0.3548573, 0.8529742]
1474 [0.35542834, 0.8529742]
1475 [0.35530657, 0.85746354]
1476 [0.35438403, 0.85746354]
1477 [0.35471925, 0.85746354]
1478 [0.35432446, 0.8563412]
1479 [

1730 [0.35409862, 0.8552188]
1731 [0.35379767, 0.8563412]
1732 [0.3532043, 0.85858583]
1733 [0.35381615, 0.85746354]
1734 [0.35408944, 0.8597082]
1735 [0.35346505, 0.85746354]
1736 [0.35353166, 0.8563412]
1737 [0.3534132, 0.8552188]
1738 [0.35366657, 0.8529742]
1739 [0.35366416, 0.85746354]
1740 [0.35443142, 0.85746354]
1741 [0.35428277, 0.8563412]
1742 [0.35389215, 0.85746354]
1743 [0.35248715, 0.85746354]
1744 [0.35305148, 0.85746354]
1745 [0.35301036, 0.85746354]
1746 [0.35424462, 0.85746354]
1747 [0.35462895, 0.8529742]
1748 [0.35382625, 0.8529742]
1749 [0.35415554, 0.8529742]
1750 [0.35445732, 0.8563412]
1751 [0.35333818, 0.8563412]
1752 [0.35491595, 0.8529742]
1753 [0.35401678, 0.85746354]
1754 [0.3536609, 0.8529742]
1755 [0.35304958, 0.85746354]
1756 [0.35324892, 0.85746354]
1757 [0.35321665, 0.85746354]
1758 [0.35280868, 0.85746354]
1759 [0.35321894, 0.85858583]
1760 [0.35427138, 0.85746354]
1761 [0.35524774, 0.85409653]
1762 [0.35527363, 0.85858583]
1763 [0.35482803, 0.8552188

## 8. 测试模型效果 && 生成提交结果

In [18]:
# 导入测试数据加载函数，处理方式和步骤3-5相似，详情请见代码
from train import get_test_data

In [19]:
# Define the test procedure and generate the submission file.
def Test():
    """Restore the trained network, report loss/accuracy on the test set and write ``output_file``.

    ``get_test_data()`` (imported from ``train.py``, not shown here) returns the
    test features ``X``, the labels ``Y_`` and the ``PassengerId`` values.

    NOTE(review): ``Y_`` presumably comes from ``gender_submission.csv``
    (``test_label_file``), Kaggle's sex-only sample submission, rather than the
    true outcomes. If so, the printed "accuracy" measures agreement with that
    baseline, not real test accuracy. Confirm in train.py.
    """
    X, Y_, PassengerId = get_test_data()

    # Rebuild the same variables, in the same order, as Train() so their names
    # match the checkpoint. Their initial values are overwritten by restore().
    w1 = tf.Variable(tf.random_normal([input_size, hidden_size], stddev=1.0, seed=2.0))
    w2 = tf.Variable(tf.random_normal([hidden_size, output_size], stddev=1.0, seed=2.0))
    b1 = tf.Variable(tf.zeros([hidden_size]), name='bias1')
    b2 = tf.Variable(tf.zeros([output_size]), name='bias2')

    x_ph, y_ph, keep_prob = model.input_placeholder(input_size, output_size)
    logits = model.forward(x_ph, w1, w2, b1, b2, keep_prob=keep_prob)
    loss_op = model.loss(logits, y_ph)
    probs = tf.nn.softmax(logits)
    acc_op = model.accuary(probs, y_ph)
    # Predicted class index per row (column order of the training labels: 0 = Dead, 1 = Survived).
    pred_op = tf.argmax(probs, 1)

    # Saver used here to *restore* the trained weights.
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # Restoring replaces variable initialisation; no initializer is run.
        saver.restore(sess, model_save_path)
        loss_val, predictions, acc_result = sess.run(
            [loss_op, pred_op, acc_op],
            feed_dict={x_ph: X, y_ph: Y_, keep_prob: 1.0})  # keep_prob=1.0 disables dropout
        print('loss:', loss_val)
        print('accuary:', acc_result)

    # Build the submission: one row per passenger with columns PassengerId and Survived.
    # NOTE(review): np.hstack assumes PassengerId is a 2-D column of shape (n, 1);
    # confirm the shape returned by get_test_data().
    Survived = predictions.reshape((-1, 1))
    result = np.hstack((PassengerId, Survived))
    result = pd.DataFrame(result, columns=['PassengerId', 'Survived'])
    result.to_csv(output_file, sep=',', encoding='utf-8', index=False)

# Clear the default graph so the variables created in Test() get the same
# names as in Train() and therefore match the checkpoint.
tf.reset_default_graph() 
Test()

INFO:tensorflow:Restoring parameters from ./ckpt/model
loss: 0.82204694
accuary: 0.8827751
