# Policy Value Network

## Policy Value Network Architecture

In [1]:
import torch
import numpy as np
import torch.nn as nn

class PolicyValueNetwork(nn.Module):



    def __init__(self, kernel, padding=0):
        super(PolicyValueNetwork, self).__init__()
        self.VERBOSE = False

        # Define the convolutional and batch normalisation layers
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=kernel , padding=padding)
        self.bnorm1 = nn.BatchNorm2d(32)
        
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=kernel , padding=padding)
        self.bnorm2 = nn.BatchNorm2d(64)
        
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=kernel , padding=padding)
        self.bnorm3 = nn.BatchNorm2d(128)

        # Define the fully connected layers
        self.fc1 = nn.Linear(128 * 9 * 9, 512)
        self.fc_policy = nn.Linear(512, 9*9+1)
        self.fc_value = nn.Linear(512, 1)

    def forward(self, x):
        VERBOSE = self.VERBOSE
        if VERBOSE: print('VERBOSE',VERBOSE)
        if VERBOSE: print(f'>> forward >>>>>>>>>>>>>>>>>>')
        if VERBOSE: print(f'>>>>>>>>>> {x.shape, x.dtype}')
        if VERBOSE: print(f'>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')

        # Pass the input through the convolutional layers
        if VERBOSE: print(f'  >> conv1', end=' ')
        x = self.conv1(x)
        if VERBOSE: print(f'  >>> {x.shape}')
        
        if VERBOSE: print(f'  >> batch', end=' ')
        x = self.bnorm1(x)
        if VERBOSE: print(f'  >>> {x.shape}')
            
        if VERBOSE: print(f'  >> relu', end=' ')
        x = nn.ReLU()(x)
        if VERBOSE: print(f'  >>> {x.shape}')
        
        
        
        if VERBOSE: print(f'  >> conv2', end=' ')
        x = self.conv2(x)
        if VERBOSE: print(f'  >>> {x.shape}')
        
        if VERBOSE: print(f'  >> batch', end=' ')
        x = self.bnorm2(x)
        if VERBOSE: print(f'  >>> {x.shape}')
        
        if VERBOSE: print(f'  >> relu', end=' ')
        x = nn.ReLU()(x)
        if VERBOSE: print(f'  >>> {x.shape}')
        
        
        if VERBOSE: print(f'  >> conv3', end=' ')
        x = self.conv3(x)
        if VERBOSE: print(f'  >>> {x.shape}')
        
        if VERBOSE: print(f'  >> batch', end=' ')
        x = self.bnorm3(x)
        if VERBOSE: print(f'  >>> {x.shape}')
        
        if VERBOSE: print(f'  >> relu', end=' ')
        x = nn.ReLU()(x)
        if VERBOSE: print(f'  >>> {x.shape}')
        
        

        # Flatten the output and pass it through the fully connected layers
        if VERBOSE: print(f'  >> flatten', end=' ')
        x = x.view(x.size(0), -1)
        if VERBOSE: print(f'  >>> {x.shape}')
        
        if VERBOSE: print(f'  >> fc1', end=' ')
        x = self.fc1(x)
        if VERBOSE: print(f'  >>> {x.shape}')
        
        if VERBOSE: print(f'  >> relu', end=' ')
        x = nn.ReLU()(x)
        if VERBOSE: print(f'  >>> {x.shape}')

        # Compute the policy and value outputs
        if VERBOSE: print(f'  >> compute output', end=' ')
        policy = nn.Softmax(1)(self.fc_policy(x))        
        value = self.fc_value(x)
        # print(value.shape)
        
        if VERBOSE: print(f'  >>> policy: {policy.shape}', end=' ')
        if VERBOSE: print(f'  >>> value: {value.shape}')

        if VERBOSE: print(f'>> forward complete >>>>>>>>>>>>>>>>')
        return policy, value

    def set_verbose(self, verbose):
        self.VERBOSE = verbose
        return
    

model = PolicyValueNetwork(kernel=3, padding=1)
model.set_verbose(True)

data = torch.load('board_data.pt')
data = data[:512].reshape(-1,1,9,9)
model.forward(data)


VERBOSE True
>> forward >>>>>>>>>>>>>>>>>>
>>>>>>>>>> (torch.Size([512, 1, 9, 9]), torch.float32)
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
  >> conv1   >>> torch.Size([512, 32, 9, 9])
  >> batch   >>> torch.Size([512, 32, 9, 9])
  >> relu   >>> torch.Size([512, 32, 9, 9])
  >> conv2   >>> torch.Size([512, 64, 9, 9])
  >> batch   >>> torch.Size([512, 64, 9, 9])
  >> relu   >>> torch.Size([512, 64, 9, 9])
  >> conv3   >>> torch.Size([512, 128, 9, 9])
  >> batch   >>> torch.Size([512, 128, 9, 9])
  >> relu   >>> torch.Size([512, 128, 9, 9])
  >> flatten   >>> torch.Size([512, 10368])
  >> fc1   >>> torch.Size([512, 512])
  >> relu   >>> torch.Size([512, 512])
  >> compute output   >>> policy: torch.Size([512, 82])   >>> value: torch.Size([512, 1])
>> forward complete >>>>>>>>>>>>>>>>


(tensor([[0.0120, 0.0121, 0.0147,  ..., 0.0091, 0.0117, 0.0118],
         [0.0118, 0.0118, 0.0145,  ..., 0.0093, 0.0119, 0.0117],
         [0.0116, 0.0115, 0.0157,  ..., 0.0105, 0.0134, 0.0108],
         ...,
         [0.0096, 0.0114, 0.0143,  ..., 0.0084, 0.0133, 0.0135],
         [0.0105, 0.0111, 0.0158,  ..., 0.0092, 0.0143, 0.0138],
         [0.0107, 0.0118, 0.0157,  ..., 0.0092, 0.0150, 0.0138]],
        grad_fn=<SoftmaxBackward0>),
 tensor([[ 8.2279e-02],
         [ 8.0492e-02],
         [ 1.2844e-01],
         [ 9.5193e-02],
         [ 7.6025e-02],
         [ 5.2270e-02],
         [-8.8554e-02],
         [-3.9973e-02],
         [-3.0436e-02],
         [-4.7412e-02],
         [-8.2306e-04],
         [ 1.4189e-02],
         [ 1.5434e-01],
         [ 1.1350e-01],
         [ 5.6272e-02],
         [ 3.3449e-02],
         [-2.2195e-03],
         [ 2.7218e-02],
         [-4.4499e-02],
         [-7.2663e-03],
         [-1.7305e-03],
         [-5.1940e-02],
         [-7.3804e-02],
      

## Network Training 

In [6]:
#TODO: check is pass move is accounted for

from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
import time

# get data
inputs = torch.load('board_data.pt')
# print(inputs.dtype)
inputs = inputs.reshape(-1,1,9,9)

plabels = torch.load('plabels.pt')
vlabels = torch.load('vlabels.pt')
# print(vlabels.shape)
labels = torch.cat((plabels, vlabels), dim=1)

print('Preprocess ::::::::::::::::::::')
print(':::  Splitting :::')
print(f'::: :::  before splitting:', tuple(inputs.shape), tuple(labels.shape))
X_train, X_test, y_train, y_test = train_test_split(inputs, labels, test_size=0.2)
print(f'::: :::  after splitting:', tuple(X_train.shape), tuple(X_test.shape), tuple(y_train.shape), tuple(y_test.shape) )

# Create a TensorDataset from the data
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)

# Create a DataLoader with a batch size of 32
print(':::  Batch processing :::')
batch_size = 1024*2
iters = int(np.ceil(X_train.shape[0] / batch_size))

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
print(f'::: :::  train loader batch size: {batch_size}')
print(f'::: :::  train loader iters: {iters}')



Preprocess ::::::::::::::::::::
:::  Splitting :::
::: :::  before splitting: (608160, 1, 9, 9) (608160, 2)
::: :::  after splitting: (486528, 1, 9, 9) (121632, 1, 9, 9) (486528, 2) (121632, 2)
:::  Batch processing :::
::: :::  train loader batch size: 2048
::: :::  train loader iters: 238


In [5]:
import torch.optim as optim
import torch.nn.functional as F

model.set_verbose(False)


# Criterions 
def pcriterion(outputs, labels):
    return F.cross_entropy(outputs, labels.long())

def vcriterion(outputs, labels):
    return F.mse_loss(outputs, labels)

def pvcriterion(outputs, labels, alpha=0.5):
    poutputs, voutputs = outputs
    plabels = labels[:,0]#.squeeze()
    vlabels = labels[:,1].view((-1,1))
    # print(vlabels)
    
    vcrit = vcriterion(voutputs, vlabels)
    pcrit = pcriterion(poutputs, plabels)
    
    return vcrit + alpha * vcrit

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)


In [None]:
# Training loop
print('\nTraining ::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::')
EPOCHS = 2
for epoch in range(EPOCHS):
    start_time = time.time()
    for i, (inputs, labels) in enumerate(train_loader):
        iter_time = time.time()
        optimizer.zero_grad()
        outputs = model(inputs)
        # print(outputs[0].dtype)
        # print(outputs[1].dtype)
        # print(labels[0].dtype)
        # print(labels[1].dtype)
        loss = pvcriterion(outputs, labels)
        # assert False

        loss.backward()
        optimizer.step()
        # if i%50 == 0:
        print(f':::  Epoch {epoch+1} ::: batch_iter({i+1}/{iters}) ::: loss:{loss.item():.4f} ::: time: {time.time()-start_time:.2f}s ::: iter_time: {time.time()-iter_time:.4f}s', end='\r')
    print('\n', end='')



Training ::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
:::  Epoch 1 ::: batch_iter(951/951) ::: loss:0.2152 ::: time: 551.47s ::: iter_time: 0.2041s
:::  Epoch 2 ::: batch_iter(902/951) ::: loss:0.2160 ::: time: 460.60s ::: iter_time: 0.4842s