In [1]:
import tensorflow as tf
import numpy as np
import pandas as pd

In [2]:
#讀取資料
train = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/train.csv')
test = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/eval.csv')
y_train = train.pop('survived')
y_test = test.pop('survived')
train

Unnamed: 0,sex,age,n_siblings_spouses,parch,fare,class,deck,embark_town,alone
0,male,22.0,1,0,7.2500,Third,unknown,Southampton,n
1,female,38.0,1,0,71.2833,First,C,Cherbourg,n
2,female,26.0,0,0,7.9250,Third,unknown,Southampton,y
3,female,35.0,1,0,53.1000,First,C,Southampton,n
4,male,28.0,0,0,8.4583,Third,unknown,Queenstown,y
...,...,...,...,...,...,...,...,...,...
622,male,28.0,0,0,10.5000,Second,unknown,Southampton,y
623,male,25.0,0,0,7.0500,Third,unknown,Southampton,y
624,female,19.0,0,0,30.0000,First,B,Southampton,y
625,female,28.0,1,2,23.4500,Third,unknown,Southampton,n


In [3]:
categorical_columns = ['sex', 'n_siblings_spouses', 'parch', 'class', 'deck',
                       'embark_town', 'alone']
numeric_columns = ['age', 'fare']

feature_columns = []
#類別欄位處理
for feature_name in categorical_columns:
  #依據train的feature_name將categorical轉one-hot
  vocabulary = train[feature_name].unique()
  feature_columns.append(tf.feature_column.categorical_column_with_vocabulary_list(feature_name, vocabulary))

#數值欄位處理
for feature_name in numeric_columns:
  feature_columns.append(tf.feature_column.numeric_column(feature_name,
                                                          dtype=tf.float32))

In [4]:
def make_input_fn(X, y, n_epochs=10, shuffle=True, batch_size=64):
  def input_function():
    dataset = tf.data.Dataset.from_tensor_slices((dict(X), y))
    if shuffle:
      dataset = dataset.shuffle(1000)
    dataset = dataset.batch(batch_size).repeat(n_epochs)
    return dataset
  return input_function

train_input_fn = make_input_fn(train, y_train)
test_input_fn = make_input_fn(test, y_test, n_epochs=1, shuffle=False)

In [5]:
params = {
  'n_trees': 50, #樹的數量
  'max_depth': 3, #樹的深度
  'n_batches_per_layer': 5, #每一層訓練批次數
}

est = tf.estimator.BoostedTreesClassifier(feature_columns, **params)
est.train(train_input_fn, max_steps=100)

results = est.evaluate(test_input_fn)
pd.Series(results).to_frame()

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpu6u25tqv', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initiali

Unnamed: 0,0
accuracy,0.772727
accuracy_baseline,0.625
auc,0.841996
auc_precision_recall,0.808529
average_loss,0.465244
label/mean,0.375
loss,0.449376
precision,0.73494
prediction/mean,0.39867
recall,0.616162


In [6]:
pred_dicts = list(est.predict(test_input_fn))
pred_dicts

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpu6u25tqv/model.ckpt-95
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.


[{'all_class_ids': array([0, 1], dtype=int32),
  'all_classes': array([b'0', b'1'], dtype=object),
  'class_ids': array([0]),
  'classes': array([b'0'], dtype=object),
  'logistic': array([0.18777516], dtype=float32),
  'logits': array([-1.4645319], dtype=float32),
  'probabilities': array([0.81222486, 0.18777516], dtype=float32)},
 {'all_class_ids': array([0, 1], dtype=int32),
  'all_classes': array([b'0', b'1'], dtype=object),
  'class_ids': array([0]),
  'classes': array([b'0'], dtype=object),
  'logistic': array([0.34525838], dtype=float32),
  'logits': array([-0.63994765], dtype=float32),
  'probabilities': array([0.65474164, 0.34525838], dtype=float32)},
 {'all_class_ids': array([0, 1], dtype=int32),
  'all_classes': array([b'0', b'1'], dtype=object),
  'class_ids': array([1]),
  'classes': array([b'1'], dtype=object),
  'logistic': array([0.82922214], dtype=float32),
  'logits': array([1.5801246], dtype=float32),
  'probabilities': array([0.17077783, 0.82922214], dtype=float32)}