In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

In [None]:
# Load the Connect 4 dataset: one row per position, with columns for the board
# encoding ("board", "-", ...) and the target move policy ("policy", ...).
data = pd.read_csv(filepath_or_buffer='../datasets/connect4data.csv', index_col=False)
data.head(n=5)

In [None]:
# Number of samples (rows) in the dataset; used for reshaping and splitting.
ENTRIES = len(data)

In [None]:
# The column names suggest the CSV header repeats "-" for unnamed columns, which
# pandas de-duplicates as "-", "-.1", "-.2", ...  The board encoding spans
# "board" through "-.82" (84 values); the policy spans "policy" through "-.88"
# (7 values).
BOARD_COLUMNS = ["board", "-"] + [f"-.{i}" for i in range(1, 83)]
POLICY_COLUMNS = ["policy"] + [f"-.{i}" for i in range(83, 89)]

states = data[BOARD_COLUMNS].to_numpy()
policies = data[POLICY_COLUMNS].to_numpy()

# Fail fast if the layout changes: boards are later reshaped to 6x7x2 and
# policies are plotted as one bar per board column.
assert states.shape[1] == 6 * 7 * 2, states.shape
assert policies.shape[1] == 7, policies.shape

firstpolicy = policies[0]
# Keep this as the last expression so the notebook actually displays it.
states[0]

In [None]:
# Unflatten each 84-value row into a (6 rows, 7 columns, 2 channels) board so
# the CNN receives image-like input.
states = np.reshape(states, (ENTRIES, 6, 7, 2))

In [None]:
# Pair each board with its target policy so the two stay aligned when shuffled.
# NOTE(review): this rebinds `data` from the DataFrame to a list, so the
# column-selection cell fails if it is re-run after this one.
data = [(board, policy) for board, policy in zip(states, policies)]

In [None]:
# Shuffle with a fixed seed so the train/test split is reproducible across
# kernel restarts. Build a new list instead of shuffling `data` in place, so
# re-running this cell yields the same split rather than reshuffling an
# already-shuffled list.
SPLIT_SEED = 42
TRAIN_FRACTION = 0.8
split_rng = np.random.default_rng(SPLIT_SEED)
shuffled = [data[i] for i in split_rng.permutation(len(data))]

# split into train and test
split_at = int(len(shuffled) * TRAIN_FRACTION)
train_data = shuffled[:split_at]
test_data = shuffled[split_at:]

In [None]:
# Build the policy network with the `get_model` factory from the `net` module,
# then print its layer-by-layer summary.
from net import get_model

model = get_model()
model.summary()

In [None]:
# Unpack the (board, policy) pairs into stacked input/target arrays for Keras.
xs = np.array([pair[0] for pair in train_data])
ys = np.array([pair[1] for pair in train_data])
xs_test = np.array([pair[0] for pair in test_data])
ys_test = np.array([pair[1] for pair in test_data])

In [None]:
# Log training metrics (plus weight histograms every epoch) to a timestamped
# TensorBoard run directory.
import datetime

log_dir = f"logs/fit/{datetime.datetime.now():%Y%m%d-%H%M%S}"
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir,
                                                      histogram_freq=1)


In [None]:
# Train for a fixed number of epochs, reporting loss on the held-out split each
# epoch. Keep the returned History so loss curves can be inspected afterwards
# instead of only echoing its repr.
# NOTE(review): no early stopping; consider tf.keras.callbacks.EarlyStopping
# if validation loss starts rising well before EPOCHS.
EPOCHS = 500
BATCH_SIZE = 32
history = model.fit(
    xs,
    ys,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_data=(xs_test, ys_test),
    callbacks=[tensorboard_callback],
)

In [None]:
# Compare generalisation: evaluate on the held-out split, then on the training
# split (a large gap suggests overfitting).
loss_test = model.evaluate(xs_test, ys_test)
loss_train = model.evaluate(xs, ys)
print(f"Test loss: {loss_test}\nTrain loss: {loss_train}")

In [None]:
# Target policy stored for the first dataset row, for comparison with the
# model's prediction on an empty board.
# NOTE(review): assumes row 0 is the empty starting position; confirm.
plt.bar(range(7), firstpolicy)
plt.title("Target policy for the first dataset row")
plt.xlabel("column")
plt.ylabel("policy value")
plt.show()

In [None]:
# Model prediction for an empty board (no pieces in either channel), which
# should resemble the starting-position target policy.
mock_board = np.zeros((6, 7, 2))
dist = model.predict(mock_board[np.newaxis])[0]  # add a batch axis of size 1
plt.bar(range(7), dist)
plt.title("Predicted policy: empty board")
plt.xlabel("column")
plt.ylabel("policy value")
plt.show()

In [None]:
# Model prediction after one piece at row index 5 (presumably the bottom row)
# of column 4. Start from a fresh board so re-running this cell always shows
# this exact position instead of accumulating pieces from other cells.
mock_board = np.zeros((6, 7, 2))
mock_board[5, 3, 0] = 1  # place a piece in column 4 (index 3), channel 0
dist = model.predict(mock_board[np.newaxis])[0]  # add a batch axis of size 1
plt.bar(range(7), dist)
plt.title("Predicted policy: one piece in column 4")
plt.xlabel("column")
plt.ylabel("policy value")
plt.show()


In [None]:
# Model prediction for a three-piece position. Build the whole position here
# rather than relying on pieces left in `mock_board` by earlier cells, so the
# cell gives the same result however often or in whatever order it is run.
mock_board = np.zeros((6, 7, 2))
mock_board[5, 3, 0] = 1  # channel 0 piece in column 4 (index 3)
mock_board[4, 3, 1] = 1  # channel 1 piece stacked on top in column 4 (index 3)
mock_board[5, 2, 0] = 1  # channel 0 piece in column 3 (index 2)
dist = model.predict(mock_board[np.newaxis])[0]  # add a batch axis of size 1
plt.bar(range(7), dist)
plt.title("Predicted policy: three pieces")
plt.xlabel("column")
plt.ylabel("policy value")
plt.show()


In [None]:
# Optional: uncomment the line below to save the trained network to disk.
# model.save('direct_conv_policy.h5')