## Sieć dla agenta do gry w Connect4 

In [38]:
import matplotlib.pyplot as plt
%matplotlib inline

import torch 
import torch.nn as nn
import torch.nn.functional as F

from torchsummary import summary

from imp import reload

import DataLoader
reload(DataLoader)

from DataLoader import InMemDataLoader
from DataLoader import C4DataSet

In [20]:
class Model(nn.Module):
    """Small convolutional classifier producing three raw logits per sample.

    Pipeline: 2x2 convolution (2 -> 10 channels, padding 1) -> 2x2 max-pool ->
    batch-norm + ReLU -> flatten -> linear (120 -> 20) -> batch-norm + ReLU ->
    linear (20 -> 3).

    The first linear layer expects 10*3*4 features, which is what the
    conv/pool stage yields for a (2, 6, 8) input.
    """

    def __init__(self, dp=0.5):
        """Create the layers.

        Args:
            dp: accepted for interface compatibility; no layer uses it.
        """
        super().__init__()

        # Convolutional stage.
        self.conv1 = nn.Conv2d(in_channels=2, out_channels=10, kernel_size=2, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=10)

        # Fully connected stage.
        self.fc1 = nn.Linear(in_features=10 * 3 * 4, out_features=20)
        self.bn2 = nn.BatchNorm1d(num_features=20)

        # Output layer.
        self.fc2 = nn.Linear(in_features=20, out_features=3)

    def forward(self, x):
        """Map a batch of inputs to a (batch, 3) tensor of unnormalised scores."""
        pooled = F.max_pool2d(self.conv1(x), 2)
        feature_maps = F.relu(self.bn1(pooled))

        # Collapse channel and spatial axes into one feature axis per sample.
        flat = feature_maps.view(feature_maps.size(0), -1)

        hidden = F.relu(self.bn2(self.fc1(flat)))
        return self.fc2(hidden)

    def loss(self, Out, Targets):
        """Cross-entropy between the logits `Out` and the class indices `Targets`."""
        return F.cross_entropy(Out, Targets)

# Instantiate the network and print a per-layer summary (output shapes and
# parameter counts) for a single input of shape (2, 6, 8).
model = Model()
summary(model, (2, 6, 8))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1             [-1, 10, 7, 9]              90
       BatchNorm2d-2             [-1, 10, 3, 4]              20
            Linear-3                   [-1, 20]           2,420
       BatchNorm1d-4                   [-1, 20]              40
            Linear-5                    [-1, 3]              63
Total params: 2,633
Trainable params: 2,633
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 0.01
Estimated Total Size (MB): 0.02
----------------------------------------------------------------


In [21]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [39]:
# --- Dataset size ----------------------------------------------------------
amount_of_games = 1000  # first argument passed to C4DataSet
moves_observed = 10     # second argument passed to C4DataSet
all_samples = amount_of_games * moves_observed

# --- Split sizes and batching ----------------------------------------------
batch_size = 64
# Each split is nominally one third of the samples (rounded down).
train_size = val_size = test_size = all_samples // 3
# Fractional number of batches per training epoch (not rounded).
amount_of_train_batches = train_size / batch_size

dataset = C4DataSet(amount_of_games, moves_observed).create_data_set()

# Consecutive, non-overlapping slices. The last slice runs to the end of
# `dataset`, so it also receives whatever the integer division left over.
_val_end = train_size + val_size
train_set = dataset[:train_size]
val_set = dataset[train_size:_val_end]
test_set = dataset[_val_end:]

# One loader per split; only the training loader shuffles.
data_loaders = {
    split_name: InMemDataLoader(subset, batch_size=batch_size, shuffle=(split_name == "train"))
    for split_name, subset in (
        ("train", train_set),
        ("valid", val_set),
        ("test", test_set),
    )
}

100%|██████████| 3333/3333 [00:00<00:00, 62831.99it/s]
100%|██████████| 3333/3333 [00:00<00:00, 18828.14it/s]
100%|██████████| 3334/3334 [00:00<00:00, 58806.40it/s]


In [40]:
# Inspect the first training sample; the recorded output shows a pair of
# (float tensor of shape (2, 6, 8), integer label).
print(train_set[0])

(tensor([[[0., 0., 0., 0., 0., 1., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 0., 1., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 1., 0.],
         [0., 1., 1., 0., 0., 1., 1., 0.],
         [0., 1., 0., 1., 0., 1., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 0., 1., 0., 1., 0., 0.],
         [0., 0., 0., 0., 0., 1., 1., 0.],
         [0., 1., 0., 0., 0., 1., 0., 0.],
         [1., 0., 0., 1., 0., 0., 0., 0.],
         [1., 0., 1., 0., 0., 0., 1., 0.]]]), 1)


In [23]:
def compute_error_rate(model, data_loader, device="cpu"):
  """Return the fraction of examples in `data_loader` that `model` misclassifies.

  The model is switched to evaluation mode and left in that mode on return.

  Args:
    model: callable module mapping an input batch to per-class scores
      (the class axis is dim 1).
    data_loader: iterable of batches; `batch[0]` holds the inputs and
      `batch[1]` the target class indices.
    device: device the batch tensors are moved to before the forward pass.

  Returns:
    Number of wrong predictions divided by the number of examples seen.
  """
  model.eval()

  wrong = 0
  seen = 0
  with torch.no_grad():
    for batch in data_loader:
      inputs = batch[0].to(device)
      targets = batch[1].to(device)

      scores = model(inputs)
      # Index 1 of torch.max(..., dim=1) is the position of the best class.
      predicted = torch.max(scores, dim=1)[1]

      wrong += predicted.ne(targets).sum().item()
      seen += inputs.shape[0]

  return wrong / seen

In [25]:
def train(num_of_epochs, train_loader, opt, print_every=10, device="cpu"):
  """Train the notebook-level `model` and print the validation error after each epoch.

  This function reads the module-level names `model`, `data_loaders`,
  `InMemDataLoader` and `compute_error_rate`.

  Args:
    num_of_epochs: number of passes over `train_loader`.
    train_loader: iterable of batches; `batch[0]` holds the inputs and
      `batch[1]` the target class indices.
    opt: optimiser stepped after every batch. It must have been built from
      the parameters of the current `model`; an optimiser created for an
      earlier model instance would leave this model's weights untouched.
    print_every: the batch error is printed every `print_every` iterations
      (iterations are counted across epochs, starting at 1).
    device: device the batches are moved to, also passed to the validation
      pass.
  """
  model.train()

  # Call .to(device) once on every in-memory loader (presumably moving its
  # stored tensors) before training starts.
  for loader in data_loaders.values():
    if isinstance(loader, InMemDataLoader):
      loader.to(device)

  # Loop-invariant: one loss module serves every batch.
  criterion = nn.CrossEntropyLoss()

  step = 0  # global iteration counter; named `step` so the builtin `iter` is not shadowed
  for e in range(num_of_epochs):
    # compute_error_rate() at the end of the previous epoch put the model in
    # eval mode, so training mode is re-enabled here.
    model.train()
    print(f"Epoch {e+1}")

    for batch in train_loader:
      x = batch[0].to(device)
      y = batch[1].to(device)
      opt.zero_grad()
      step += 1

      out = model(x)
      loss = criterion(out, y)
      loss.backward()
      opt.step()

      if step % print_every == 0:
        # The batch error is only needed for the progress line, so it is
        # computed only when that line is printed.
        _, pred = out.max(dim=1)
        batch_err = (pred != y).sum().item() / out.size(0)
        print(f"iter = {step}, batch_err = {batch_err * 100.0}")

    val_err = compute_error_rate(model, data_loader=data_loaders["valid"], device=device)
    print(f"val err = {100*val_err:.2f}")
    # wandb.log({'val_err': val_err, 'epoch': e+1})

In [26]:
def initialize_weights(model):
    """Re-initialise every parameter of `model` in place, chosen by parameter name.

    Rules, checked in this order on the names from `model.named_parameters()`:
      * name contains 'weight':
          - contains 'conv': normal(0, sqrt(2 / fan_in)), where fan_in is
            the product of weight dims 1..3 (a 4-D weight is expected);
          - contains 'bn': filled with ones;
          - contains 'fc': normal(0, sqrt(2 / fan_in)), where fan_in is
            weight dim 1;
          - anything else raises.
      * name contains 'bias': filled with zeros.
      * anything else raises.

    Args:
        model: module whose parameters are overwritten.

    Raises:
        Exception: 'weird weight' for a weight that matches no layer rule,
            'weird parameter' for a name that is neither weight nor bias.
    """
    with torch.no_grad():
        for name, p in model.named_parameters():
            if 'weight' in name:
                if 'conv' in name:
                    f_in = p.shape[1]*p.shape[2]*p.shape[3]
                    p.normal_(0, torch.sqrt(torch.tensor(2./f_in)))
                elif 'bn' in name:
                    # Write in place: assigning `p = torch.ones_like(p)` would
                    # only rebind the local name and leave the parameter as is.
                    p.fill_(1.)
                elif 'fc' in name:
                    f_in = p.shape[1]
                    p.normal_(0, torch.sqrt(torch.tensor(2./f_in)))
                else:
                    raise Exception('weird weight')

            elif 'bias' in name:
                p.zero_()
            else:
                raise Exception('weird parameter')

In [58]:
# Optimisation hyper-parameters.
lr, momentum, weight_decay = 0.0001, 0.9, 0.001
epochs = 40

# Run settings handed to train() / compute_error_rate() by the training cell.
_device, _print_every = "cpu", 2

# NOTE: the optimiser keeps references to the parameters of the `model` object
# that exists when this cell runs. If `model` is re-created afterwards, the
# optimiser has to be rebuilt as well, otherwise it keeps updating the old
# network.
opt = torch.optim.SGD(
    model.parameters(),
    lr=lr,
    momentum=momentum,
    weight_decay=weight_decay,
)

In [60]:
# Build a fresh network on the same device that train() and
# compute_error_rate() move the data to (`_device`); mixing it with the
# auto-detected `device` would put model and data on different devices
# whenever a GPU is available.
model = Model().to(_device)
initialize_weights(model)

# The optimiser must be created from the parameters of *this* model instance.
# An optimiser built before `model` was re-created still points at the old
# network's parameters, so the new network would never be updated.
opt = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay, momentum=momentum)

train(num_of_epochs=epochs, train_loader=data_loaders["train"], opt=opt, print_every=_print_every, device=_device)
test_error_rate = compute_error_rate(model, data_loaders["test"], device=_device)

Epoch 1
iter = 2, batch_err = 70.3125
iter = 4, batch_err = 71.875
iter = 6, batch_err = 59.375
iter = 8, batch_err = 60.9375
iter = 10, batch_err = 56.25
iter = 12, batch_err = 67.1875
iter = 14, batch_err = 64.0625
iter = 16, batch_err = 65.625
iter = 18, batch_err = 65.625
iter = 20, batch_err = 60.9375
iter = 22, batch_err = 59.375
iter = 24, batch_err = 59.375
iter = 26, batch_err = 62.5
iter = 28, batch_err = 65.625
iter = 30, batch_err = 68.75
iter = 32, batch_err = 76.5625
iter = 34, batch_err = 53.125
iter = 36, batch_err = 68.75
iter = 38, batch_err = 68.75
iter = 40, batch_err = 60.9375
iter = 42, batch_err = 60.9375
iter = 44, batch_err = 51.5625
iter = 46, batch_err = 65.625
iter = 48, batch_err = 57.8125
iter = 50, batch_err = 60.9375
iter = 52, batch_err = 71.875
val err = 66.85
Epoch 2
iter = 54, batch_err = 67.1875
iter = 56, batch_err = 60.9375
iter = 58, batch_err = 64.0625
iter = 60, batch_err = 76.5625
iter = 62, batch_err = 65.625
iter = 64, batch_err = 64.0625
it