In [1]:
%matplotlib inline

In [2]:
import sys
import numpy as np
import tensorflow as tf

from batcher import Batcher
from metrics import acc012

In [3]:
np.random.seed(1337)

In [4]:
X_train = np.load('../data/bin/train/deal.npy')
y_train = np.load('../data/bin/train/tricks_notrump.npy')

X_val = np.load('../data/bin/val/deal.npy')
y_val = np.load('../data/bin/val/tricks_notrump.npy')

n_examples = X_train.shape[0]

n_h = X_train.shape[1]
n_w = X_train.shape[2]
n_c = X_train.shape[3]

batch_size = 64
n_iterations = 500000
display_step = 1000

learning_rate = 0.001

n_hidden_units = 128

l2_reg = 0.05

In [5]:
X_train.shape, X_val.shape, y_train.shape, y_val.shape

((800000, 4, 13, 4), (100000, 4, 13, 4), (800000, 1), (100000, 1))

In [6]:
X = tf.placeholder(tf.float32, shape=[None, n_h, n_w, n_c])
Y = tf.placeholder(tf.float32, shape=[1, None])

conv1_w = tf.get_variable('c1w', shape=[1, 4, 4, 32], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer(seed=1337))
conv1_z = tf.nn.conv2d(X, filter=conv1_w, strides=[1,1,1,1], padding='SAME')
conv1_a = tf.nn.relu(conv1_z)

conv2_w = tf.get_variable('c2w', shape=[1, 4, 32, 64], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer(seed=1337))
conv2_z = tf.nn.conv2d(conv1_a, filter=conv2_w, strides=[1,1,1,1], padding='SAME')
conv2_a = tf.nn.relu(conv2_z)

conv3_w = tf.get_variable('c3w', shape=[1, 4, 64, 128], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer(seed=1337))
conv3_z = tf.nn.conv2d(conv2_a, filter=conv3_w, strides=[1,1,1,1], padding='SAME')
conv3_a = tf.nn.relu(conv3_z)

conv4_w = tf.get_variable('c4w', shape=[4, 4, 128, 512], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer(seed=1337))
conv4_z = tf.nn.conv2d(conv3_a, filter=conv4_w, strides=[1,1,1,1], padding='VALID')
conv4_a = tf.nn.relu(conv4_z)

fc_in = tf.contrib.layers.flatten(conv4_a)
fc_w = tf.get_variable('fcw', shape=[n_hidden_units, fc_in.shape.as_list()[1]], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer(seed=1337))
fc_b = tf.Variable(np.zeros((n_hidden_units, 1)), dtype=tf.float32)
fc_z = tf.add(tf.matmul(fc_w, tf.transpose(fc_in)), fc_b)
fc_a = tf.nn.relu(fc_z)

w_out = tf.get_variable('w_out', shape=[1, n_hidden_units], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer(seed=1337))
b_out = tf.Variable(np.zeros((1, 1)), dtype=tf.float32)
pred = tf.add(tf.matmul(w_out, fc_a), b_out)

In [9]:
cost = tf.reduce_mean(tf.squared_difference(pred, Y))

In [10]:
optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9).minimize(cost)

In [11]:
init = tf.global_variables_initializer()

In [None]:
batch = Batcher(n_examples, batch_size)
cost_train_batch = Batcher(n_examples, 10000)
cost_val_batch = Batcher(100000, 10000)

In [None]:
costs = []
with tf.Session() as sess:
    sess.run(init)
    
    for iteration in range(n_iterations):
        x_batch, y_batch = batch.next_batch([X_train, y_train])
        
        if iteration % display_step == 0:
            sys.stdout.write('*')
            x_batch_c, y_batch_c = cost_train_batch.next_batch([X_train, y_train])
            x_batch_v, y_batch_v = cost_val_batch.next_batch([X_val, y_val])
            c = sess.run(cost, feed_dict={X: x_batch_c, Y: y_batch_c.T})
            costs.append(c)
            pred_train = sess.run(pred, feed_dict={X: x_batch_c, Y: y_batch_c.T})
            pred_val = sess.run(pred, feed_dict={X: x_batch_v, Y: y_batch_v.T})
            print('it={} cost={}'.format(iteration, c))
            print(acc012(y_batch_c, pred_train.T))
            print(acc012(y_batch_v, pred_val.T))
            
        sess.run(optimizer, feed_dict={X:x_batch, Y:y_batch.T})
            

*it=0 cost=44.44297790527344
(0.014200000000000001, 0.051299999999999998, 0.1076)
(0.014999999999999999, 0.054399999999999997, 0.1192)
*it=1000 cost=1.5723474025726318
(0.31230000000000002, 0.7752, 0.95420000000000005)
(0.31669999999999998, 0.77629999999999999, 0.95120000000000005)
*it=2000 cost=1.4288051128387451
(0.3427, 0.80520000000000003, 0.96050000000000002)
(0.32379999999999998, 0.79169999999999996, 0.95689999999999997)
*it=3000 cost=1.344244122505188
(0.35170000000000001, 0.82130000000000003, 0.96399999999999997)
(0.34710000000000002, 0.82389999999999997, 0.9667)
*it=4000 cost=1.2466201782226562
(0.3644, 0.82979999999999998, 0.96999999999999997)
(0.35949999999999999, 0.83150000000000002, 0.97099999999999997)
*it=5000 cost=1.2668870687484741
(0.36409999999999998, 0.83040000000000003, 0.96860000000000002)
(0.36499999999999999, 0.83799999999999997, 0.96889999999999998)
*it=6000 cost=1.2145185470581055
(0.36770000000000003, 0.83550000000000002, 0.97160000000000002)
(0.3740999999999

*it=55000 cost=0.7303171753883362
(0.47010000000000002, 0.92330000000000001, 0.9909)
(0.46639999999999998, 0.91659999999999997, 0.98850000000000005)
*it=56000 cost=0.7285425662994385
(0.47820000000000001, 0.9264, 0.99070000000000003)
(0.45700000000000002, 0.91659999999999997, 0.98939999999999995)
*it=57000 cost=0.6886695027351379
(0.48089999999999999, 0.92769999999999997, 0.99250000000000005)
(0.47289999999999999, 0.92310000000000003, 0.98960000000000004)
*it=58000 cost=0.7043746113777161
(0.48270000000000002, 0.92679999999999996, 0.99029999999999996)
(0.47289999999999999, 0.9194, 0.99080000000000001)
*it=59000 cost=0.6893830299377441
(0.49070000000000003, 0.92930000000000001, 0.99270000000000003)
(0.46729999999999999, 0.9173, 0.98909999999999998)
*it=60000 cost=0.7254316210746765
(0.48070000000000002, 0.92259999999999998, 0.9909)
(0.47670000000000001, 0.9214, 0.99060000000000004)
*it=61000 cost=0.6710036396980286
(0.48759999999999998, 0.93630000000000002, 0.99180000000000001)
(0.4713,

*it=109000 cost=0.5222280025482178
(0.54049999999999998, 0.95960000000000001, 0.99590000000000001)
(0.50880000000000003, 0.93969999999999998, 0.9919)
*it=110000 cost=0.549863338470459
(0.5323, 0.95499999999999996, 0.99580000000000002)
(0.49419999999999997, 0.93130000000000002, 0.99309999999999998)
*it=111000 cost=0.5439617037773132
(0.53590000000000004, 0.95379999999999998, 0.99629999999999996)
(0.49490000000000001, 0.93410000000000004, 0.99019999999999997)
*it=112000 cost=0.5150355696678162
(0.53620000000000001, 0.95989999999999998, 0.997)
(0.50190000000000001, 0.93269999999999997, 0.99119999999999997)
*it=113000 cost=0.4800330102443695
(0.56979999999999997, 0.96450000000000002, 0.99729999999999996)
(0.51100000000000001, 0.93240000000000001, 0.99160000000000004)
*it=114000 cost=0.5049015879631042
(0.55310000000000004, 0.95999999999999996, 0.997)
(0.50280000000000002, 0.93430000000000002, 0.99239999999999995)
*it=115000 cost=0.507064700126648
(0.55289999999999995, 0.95889999999999997, 

*it=163000 cost=0.3724261224269867
(0.61140000000000005, 0.98019999999999996, 0.99870000000000003)
(0.51480000000000004, 0.9375, 0.99239999999999995)
*it=164000 cost=0.36229604482650757
(0.624, 0.98060000000000003, 0.99890000000000001)
(0.50839999999999996, 0.93600000000000005, 0.99129999999999996)
*it=165000 cost=0.3663780391216278
(0.61380000000000001, 0.98360000000000003, 0.99939999999999996)
(0.50619999999999998, 0.9325, 0.9919)
*it=166000 cost=0.36087682843208313
(0.61070000000000002, 0.98350000000000004, 0.99890000000000001)
(0.49830000000000002, 0.93130000000000002, 0.99070000000000003)
*it=167000 cost=0.3651680648326874
(0.62029999999999996, 0.98270000000000002, 0.99950000000000006)
(0.51319999999999999, 0.93330000000000002, 0.98939999999999995)
*it=168000 cost=0.35368871688842773
(0.62409999999999999, 0.9829, 0.999)
(0.50960000000000005, 0.93500000000000005, 0.99209999999999998)
*it=169000 cost=0.3508398234844208
(0.62660000000000005, 0.98260000000000003, 0.99950000000000006)


*it=217000 cost=0.23936264216899872
(0.70499999999999996, 0.99539999999999995, 0.99990000000000001)
(0.49209999999999998, 0.93079999999999996, 0.99119999999999997)
*it=218000 cost=0.2426588386297226
(0.70109999999999995, 0.99650000000000005, 1.0)
(0.49780000000000002, 0.93179999999999996, 0.99060000000000004)
*it=219000 cost=0.24697749316692352
(0.69189999999999996, 0.99619999999999997, 0.99990000000000001)
(0.49480000000000002, 0.9274, 0.99070000000000003)
*it=220000 cost=0.24103181064128876
(0.70309999999999995, 0.99650000000000005, 1.0)
(0.49659999999999999, 0.93410000000000004, 0.99139999999999995)
*it=221000 cost=0.24491152167320251
(0.69520000000000004, 0.99660000000000004, 1.0)
(0.50129999999999997, 0.93069999999999997, 0.99150000000000005)
*it=222000 cost=0.24403434991836548
(0.70089999999999997, 0.996, 1.0)
(0.49430000000000002, 0.92720000000000002, 0.98929999999999996)
*it=223000 cost=0.234504833817482
(0.70809999999999995, 0.99670000000000003, 1.0)
(0.4899, 0.926899999999999

*it=275000 cost=0.15900680422782898
(0.79020000000000001, 0.99919999999999998, 1.0)
(0.47649999999999998, 0.92130000000000001, 0.98870000000000002)
*it=276000 cost=0.14475388824939728
(0.81299999999999994, 0.99990000000000001, 1.0)
(0.47610000000000002, 0.91969999999999996, 0.98850000000000005)
*it=277000 cost=0.14736869931221008
(0.81179999999999997, 0.99939999999999996, 1.0)
(0.47689999999999999, 0.92420000000000002, 0.99019999999999997)
*it=278000 cost=0.14362746477127075
(0.81559999999999999, 1.0, 1.0)
(0.48930000000000001, 0.92259999999999998, 0.9919)
*it=279000 cost=0.14724045991897583
(0.81730000000000003, 0.99980000000000002, 1.0)
(0.48149999999999998, 0.91859999999999997, 0.98929999999999996)
*it=280000 cost=0.14420756697654724
(0.81850000000000001, 1.0, 1.0)
(0.47360000000000002, 0.9254, 0.98980000000000001)
*it=281000 cost=0.1455322504043579
(0.81159999999999999, 0.99960000000000004, 1.0)
(0.47389999999999999, 0.92049999999999998, 0.98960000000000004)
*it=282000 cost=0.14493

*it=336000 cost=0.09614759683609009
(0.88949999999999996, 1.0, 1.0)
(0.4667, 0.9083, 0.98560000000000003)
*it=337000 cost=0.09365596622228622
(0.8992, 1.0, 1.0)
(0.45989999999999998, 0.90920000000000001, 0.98599999999999999)
*it=338000 cost=0.08801426738500595
(0.9073, 1.0, 1.0)
(0.46489999999999998, 0.91539999999999999, 0.98729999999999996)
*it=339000 cost=0.0850159302353859
(0.91259999999999997, 1.0, 1.0)
(0.46810000000000002, 0.91369999999999996, 0.98740000000000006)
*it=340000 cost=0.08716710656881332
(0.90820000000000001, 1.0, 1.0)
(0.46450000000000002, 0.91159999999999997, 0.98780000000000001)
*it=341000 cost=0.08658051490783691
(0.91149999999999998, 1.0, 1.0)
(0.46739999999999998, 0.91110000000000002, 0.98809999999999998)
*it=342000 cost=0.08692701905965805
(0.90810000000000002, 1.0, 1.0)
(0.4698, 0.91749999999999998, 0.98960000000000004)
*it=343000 cost=0.08886852860450745
(0.90329999999999999, 1.0, 1.0)
(0.4592, 0.91410000000000002, 0.98650000000000004)
*it=344000 cost=0.09019