In [28]:
import os
import torch
import numpy as np
from matplotlib import pyplot as plt

from model import ZeroTTT
from database import DataBase

In [29]:
# Training configuration shared by the cells below.
batch_size = 100      # passed as batch_size to database.prepare_batches
epochs = 1            # number of passes of the training loop over `names`
lr = 3e-4             # learning rate passed to ZeroTTT
weight_decay = 1e-4   # weight decay passed to ZeroTTT

# Bookkeeping note on replay-buffer indices. It is a bare string expression and
# the last statement of the cell, so the notebook echoes it as the cell output.
'''
Remember that last index of eigth batch games is 218
                            ninth batch games is 248
                            10th batch games is 288
'''

'\nRemember that last index of eigth batch games is 218\n                            ninth batch games is 248\n                            10th batch games is 288\n'

In [30]:
# Replay-buffer helper and the root directory the training files are read from.
database = DataBase()
db_path = "/storage/replay_buffer"

# No brain/optimizer checkpoint paths are given here
# (presumably this starts from untrained weights -- confirm in model.py).
model = ZeroTTT(
    brain_path=None,
    opt_path=None,
    lr=lr,
    weight_decay=weight_decay,
    board_len=10,
)

In [31]:
# Quick size check: print whatever parameter count the model reports.
print(model.get_parameter_count())

183432


In [32]:
def _sorted_paths(subdir):
    """Return the full paths of every entry in ``db_path/subdir``, sorted by file name."""
    folder = os.path.join(db_path, subdir)
    return [os.path.join(folder, name) for name in sorted(os.listdir(folder))]


def _file_index(path):
    """Return the integer that follows the last "_" in the file name, extension ignored."""
    stem = os.path.splitext(os.path.basename(path))[0]
    return int(stem.rsplit("_", 1)[-1])


# All three listings are derived from db_path, so pointing the notebook at
# another replay buffer only requires changing db_path.
state_paths = _sorted_paths("states")
policy_paths = _sorted_paths("policy_labels")
value_paths = _sorted_paths("value_labels")

# zip() silently drops the surplus entries of the longer lists, which would
# hide a missing or extra file; fail loudly instead.
if not (len(state_paths) == len(policy_paths) == len(value_paths)):
    raise ValueError(
        "Replay buffer folders hold different numbers of files: "
        f"{len(state_paths)} states, {len(policy_paths)} policy labels, "
        f"{len(value_paths)} value labels"
    )

# Keep only the files whose index (parsed from the state file name) is at
# least min_index. Raise min_index to train on the more recent games only.
min_index = 0
filtered_names = [
    triple
    for triple in zip(state_paths, policy_paths, value_paths)
    if _file_index(triple[0]) >= min_index
]
names = filtered_names

# Hold-out split is currently disabled; uncomment to reserve the last two files.
#test_set = names[-2:]
#names = names[:-2]
test_set = []

In [33]:
# Number of (state, policy, value) file triples selected for training.
len(names)

40

In [34]:
def test_loss(model):
    model.brain.eval()
    total_p_loss = 0.0
    total_v_loss = 0.0
    for s, p, v in test_set:
        batch_sts, batch_pls, batch_vls = database.prepare_batches(batch_size=batch_size, from_memory_paths=(s, p, v))
        for b_nr in range(len(batch_sts)):
            batch_st, batch_pl, batch_vl = batch_sts[b_nr], batch_pls[b_nr], batch_vls[b_nr]
            
            batch_pl = torch.from_numpy(batch_pl).to(model.device)
            batch_vl = torch.from_numpy(batch_vl).float().to(model.device)
            prob, val = model.predict(batch_st, interpret_policy=False)
            val = val.flatten()

            p_loss = model.policy_loss(prob, batch_pl)
            v_loss = model.value_loss(val, batch_vl)
        
            total_p_loss += p_loss.item()
            total_v_loss += v_loss.item()
    return total_p_loss/(len(batch_sts)*len(test_set)), total_v_loss/(len(batch_sts)*len(test_set))

In [35]:
# Per-epoch loss histories, plotted by the cells below.
train_policy_losses = []
train_value_losses = []
test_policy_losses = []
test_value_losses = []

for e in range(epochs):
    model.brain.train()
    cumulative_policy_epoch_loss = 0.0
    cumulative_value_epoch_loss = 0.0
    epoch_batch_count = 0  # batches seen this epoch; files may yield different counts

    for s_name, p_name, v_name in names:
        batch_sts, batch_pls, batch_vls = database.prepare_batches(batch_size=batch_size, from_memory_paths=(s_name, p_name, v_name))
        for batch_st, batch_pl, batch_vl in zip(batch_sts, batch_pls, batch_vls):
            model.optimizer.zero_grad()

            batch_pl = torch.from_numpy(batch_pl).to(model.device)
            batch_vl = torch.from_numpy(batch_vl).float().to(model.device)
            prob, val = model.predict(batch_st, interpret_output=False)
            val = val.flatten()

            p_loss = model.policy_loss(prob, batch_pl)
            v_loss = model.value_loss(val, batch_vl)

            cumulative_policy_epoch_loss += p_loss.item()
            cumulative_value_epoch_loss += v_loss.item()
            epoch_batch_count += 1

            # Policy and value heads are optimized jointly on the summed loss.
            loss = p_loss + v_loss
            loss.backward()

            model.optimizer.step()

    if epoch_batch_count == 0:
        raise RuntimeError("No training batches were produced; check `names` and the replay buffer.")

    # Mean train loss per batch. Dividing by the counted batches stays correct
    # when the files do not all produce the same number of batches.
    cumulative_policy_epoch_loss /= epoch_batch_count
    cumulative_value_epoch_loss /= epoch_batch_count

    # Test-set evaluation is currently disabled (test_set is empty):
    # test_epoch_policy_loss, test_epoch_value_loss = test_loss(model)
    print(f"Epoch #{e} train policy loss: {cumulative_policy_epoch_loss} | train value loss: {cumulative_value_epoch_loss}")
    # print(f"Test policy loss: {test_epoch_policy_loss} | Test value loss: {test_epoch_value_loss}")
    train_policy_losses.append(cumulative_policy_epoch_loss)
    train_value_losses.append(cumulative_value_epoch_loss)
    # test_policy_losses.append(test_epoch_policy_loss)
    # test_value_losses.append(test_epoch_value_loss)

    # Checkpoint after every epoch (same file names each time).
    model.save_brain("trained_model_0", "trained_opt_state_0")

Epoch #0 train policy loss: 4.169333491732179 | train value loss: 0.542679855262395
Saving brain...


In [None]:
# Value loss per epoch. The curves are labeled so train and test can be told
# apart; the test curve is empty while test evaluation is disabled.
plt.plot(train_value_losses, label="train")
plt.plot(test_value_losses, label="test")
plt.xlabel("epoch")
plt.ylabel("value loss")
plt.title("Value loss per epoch")
plt.legend()
plt.show()


In [None]:
# Policy loss per epoch. The curves are labeled so train and test can be told
# apart; the test curve is empty while test evaluation is disabled.
plt.plot(train_policy_losses, label="train")
plt.plot(test_policy_losses, label="test")
plt.xlabel("epoch")
plt.ylabel("policy loss")
plt.title("Policy loss per epoch")
plt.legend()
plt.show()