In [1]:
import torch

from main import play_sample_game_with_model
from neural_network_utils import load_npy_data
from torch_model import Model, train_model


In [2]:
# Load the three training arrays (NumPy) produced by load_npy_data().
# In the recorded run below their shapes were states (N, 3, 6, 7),
# policies (N, 7) and values (N,).
states, policies, values = load_npy_data()

# Convert every array to a float32 tensor and copy it onto the GPU in one pass.
states, policies, values = (
    torch.from_numpy(array).float().cuda() for array in (states, policies, values)
)

# Sanity-check the tensor shapes.
print("states.shape:", states.shape)
print("policies.shape:", policies.shape)
print("values.shape:", values.shape)

states.shape: torch.Size([32698, 3, 6, 7])
policies.shape: torch.Size([32698, 7])
values.shape: torch.Size([32698])


In [3]:
# Create a function to group the data into batches
def create_batches(states, policies, values, batch_size=64, drop_last=True):
    """Split aligned training data into consecutive mini-batches.

    Args:
        states: Indexable, sliceable sequence (e.g. a tensor) whose first
            dimension is the sample axis.
        policies: Policy targets, aligned sample-for-sample with ``states``.
        values: Value targets, aligned sample-for-sample with ``states``.
        batch_size: Number of samples per batch; must be >= 1.
        drop_last: If True (the default, matching the previous behaviour), the
            final partial batch is discarded when the sample count is not a
            multiple of ``batch_size``. If False, it is kept as a smaller batch.

    Returns:
        A list of ``(batch_states, (batch_policies, batch_values))`` tuples in
        the original sample order (no shuffling is performed).

    Raises:
        ValueError: If ``batch_size`` < 1, or if ``states``, ``policies`` and
            ``values`` do not all have the same length.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    sample_count = len(states)
    if len(policies) != sample_count or len(values) != sample_count:
        raise ValueError(
            "states, policies and values must have the same length, got "
            f"{sample_count}, {len(policies)} and {len(values)}"
        )

    # With drop_last, stop at the last full batch so every batch has exactly
    # batch_size samples.
    if drop_last:
        usable_count = sample_count - sample_count % batch_size
    else:
        usable_count = sample_count

    batched_data = []
    # Offsets are derived from batch_size (not a fixed constant) so the slice
    # width always matches the requested batch size.
    for start_index in range(0, usable_count, batch_size):
        end_index = start_index + batch_size
        batch_data = (
            states[start_index:end_index],
            (policies[start_index:end_index], values[start_index:end_index]),
        )
        batched_data.append(batch_data)

    return batched_data


# Group the GPU tensors into consecutive mini-batches of 64 samples
# (any trailing partial batch is dropped by create_batches).
train_data = create_batches(states, policies, values, batch_size=64)

In [4]:
# Seed PyTorch's RNGs (torch.manual_seed covers CPU and all CUDA devices) so the
# weight initialisation, and presumably any randomness inside train_model, is
# reproducible on Restart & Run All.
torch.manual_seed(0)

# Instantiate the model and move its parameters to the GPU
model = Model().cuda()

# Train the model with train_model's own defaults (no hyper-parameters overridden).
# NOTE(review): in the recorded log the "[epoch, 0]" loss is roughly
# 1/print_stats_every of the later values, which suggests train_model averages
# the running loss over the full print interval even at step 0 — check the
# logging in torch_model.train_model.
train_model(model, train_data)

[0, 0] loss: 0.006490335464477539
[0, 200] loss: 1.0426629739627242
[0, 400] loss: 0.9975074104219676
[1, 0] loss: 0.0055049598217010495
[1, 200] loss: 1.0275124653056265
[1, 400] loss: 0.9885934852808714
[2, 0] loss: 0.005458171963691711
[2, 200] loss: 1.0256898556277156
[2, 400] loss: 0.9849565913900733
[3, 0] loss: 0.0054616773128509525
[3, 200] loss: 1.0225708158314228
[3, 400] loss: 0.9821840541064739
[4, 0] loss: 0.005450524091720581
[4, 200] loss: 1.0190478206425906
[4, 400] loss: 0.978538531512022
[5, 0] loss: 0.005431268215179443
[5, 200] loss: 1.0176702028140425
[5, 400] loss: 0.9782892985641957
[6, 0] loss: 0.005433815717697144
[6, 200] loss: 1.0085655111819505
[6, 400] loss: 0.9775990724936128
[7, 0] loss: 0.005417351126670838
[7, 200] loss: 0.9990505203604698
[7, 400] loss: 0.973168509863317
[8, 0] loss: 0.005447244644165039
[8, 200] loss: 0.9868611437082291
[8, 400] loss: 0.9620297027379274
[9, 0] loss: 0.005465421676635742
[9, 200] loss: 0.9700974923372269
[9, 400] loss:

In [5]:
# Keep training the same `model` instance from the previous cell (this is not a
# fresh model): 20 more epochs at a smaller learning rate, logging every 10 steps.
train_model(
    model,
    train_data,
    learning_rate=1e-4,
    epochs=20,
    print_stats_every=10,
)

[0, 0] loss: 0.10973556041717529
[0, 10] loss: 0.8953492730855942
[0, 20] loss: 0.8817059189081192
[0, 30] loss: 1.0232446014881134
[0, 40] loss: 1.0223386585712433
[0, 50] loss: 0.8876881301403046
[0, 60] loss: 0.9386918127536774
[0, 70] loss: 0.9879222273826599
[0, 80] loss: 0.9937357127666473
[0, 90] loss: 1.0299693405628205
[0, 100] loss: 1.0386572420597076
[0, 110] loss: 0.9204074829816818
[0, 120] loss: 0.8544064372777939
[0, 130] loss: 0.8889709003269672
[0, 140] loss: 1.0358852565288543
[0, 150] loss: 0.8666288882493973
[0, 160] loss: 1.0046903192996979
[0, 170] loss: 0.9890736222267151
[0, 180] loss: 0.9591405779123306
[0, 190] loss: 0.9530097842216492
[0, 200] loss: 0.9139062702655792
[0, 210] loss: 0.8773505181074143
[0, 220] loss: 0.9386531949043274
[0, 230] loss: 1.0333682119846344
[0, 240] loss: 0.8629172146320343
[0, 250] loss: 1.0407519459724426
[0, 260] loss: 0.7007297813892365
[0, 270] loss: 0.989693284034729
[0, 280] loss: 1.0547072649002076
[0, 290] loss: 0.90106284

KeyboardInterrupt: 

In [14]:
# Play one demonstration game with the trained `model`. As the recorded output
# shows, the helper prints the network's policy and value outputs, whose turn it
# is, and the board after each move.
play_sample_game_with_model(model)

NN Policy: tensor([ 0.2509,  0.2600,  0.2348,  0.1534, -0.0413,  0.1108,  0.1588])
NN Value: tensor([-0.9441])
Turn: White
0  1  2  3  4  5  6
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  X  .  .  .  .  .  

NN Policy: tensor([ 0.0101,  0.2731,  0.2716,  0.2458,  0.1001,  0.2288, -0.1874])
NN Value: tensor([0.7874])
Turn: Black
0  1  2  3  4  5  6
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  O  .  .  .  .  .  
.  X  .  .  .  .  .  

NN Policy: tensor([ 0.0138,  0.2469,  0.2222,  0.3938,  0.1455,  0.1855, -0.0469])
NN Value: tensor([-0.9816])
Turn: White
0  1  2  3  4  5  6
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  .  .  .  .  .  .  
.  O  .  .  .  .  .  
.  X  .  X  .  .  .  

NN Policy: tensor([-0.1136,  0.2102,  0.1052,  0.2971,  0.2767,  0.3064, -0.0760])
NN Value: tensor([-0.9961])
Turn: Black
0  1  2  3  4  5  6
.  .  .  .  .  .  .  
.  .  . 