In [30]:
from board2 import Board2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import random
from controller import ActionController
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

In [48]:
class FrodoPolicy(nn.Module):
    """MLP policy mapping a board tensor to log-probabilities over actions.

    Args:
        action_space: number of discrete actions (width of the output layer).
        input_size: number of elements in one flattened input sample. Defaults to
            144, the value the original notebook hard-coded.
        dropout: dropout probability applied after every hidden layer.
        hidden_sizes: widths of the hidden Linear layers, in order.
    """

    def __init__(self, action_space, input_size=144, dropout=0.1,
                 hidden_sizes=(64, 128, 256, 128)):
        super().__init__()

        widths = [input_size, *hidden_sizes]
        layers = [nn.Flatten(start_dim=1)]
        for in_features, out_features in zip(widths, widths[1:]):
            layers += [nn.Linear(in_features, out_features), nn.ReLU(), nn.Dropout(dropout)]
        layers += [nn.Linear(widths[-1], action_space), nn.LogSoftmax(dim=-1)]

        # Same layer order and attribute name as before, so state_dict keys
        # (e.g. 'nl.1.weight') are unchanged for the default configuration.
        self.nl = nn.Sequential(*layers)

    def forward(self, x):
        """Return log-probabilities of shape (batch, action_space).

        Every dimension after the first is flattened before the first Linear layer,
        so the product of those dimensions must equal ``input_size``.
        """
        return self.nl(x)


class GameRunner:
    """Plays games between a player policy and an enemy policy on a fresh Board2."""

    @staticmethod
    def _outcome(controller):
        """Return the terminal value of the current position, or None if the game is not over.

        Checked in order: win -> 1, lose -> -1, block -> -0.5.
        """
        if controller.is_win():
            return 1
        if controller.is_lose():
            return -1
        if controller.is_block():
            return -0.5
        return None

    def run(self, player, enemy, runs=650, explore=True, explore_rate=0.1):
        """Play one game of at most `runs` steps and record the player's moves.

        Both policies are queried in eval mode (dropout disabled). Each policy is
        restored to its previous train/eval mode afterwards, even if an exception
        is raised. When ``explore`` is True, the player picks a uniformly random
        action with probability ``explore_rate``. Otherwise both sides take the
        argmax action.

        Returns:
            (history, value):
                history: list of ``(state, player_action)`` pairs, where ``state``
                    is a clone of ``board.grid`` with a leading batch dimension.
                value: 1 / -1 / -0.5 for win / lose / block, or 0 if the game is
                    still undecided after `runs` steps.
        """
        history = []
        board = Board2()
        controller = ActionController(board)

        # Remember the callers' modes so training code is not silently left in eval mode.
        player_was_training, enemy_was_training = player.training, enemy.training
        player.eval()
        enemy.eval()
        try:
            with torch.no_grad():
                for _ in range(runs):
                    outcome = self._outcome(controller)
                    if outcome is not None:
                        return history, outcome

                    state = board.grid.clone().unsqueeze(0)

                    if explore and random.random() < explore_rate:
                        player_action = random.randint(0, controller.get_action_space() - 1)
                    else:
                        # argmax over log-probabilities equals argmax over probabilities.
                        player_action = torch.argmax(player(state)).item()
                    history.append((state, player_action))
                    controller.execute_action(player_action)

                    # NOTE(review): the enemy is evaluated on `state`, captured before the
                    # player's move and before swap_enemy(); confirm that this is the
                    # intended observation for the enemy.
                    enemy_action = torch.argmax(enemy(state)).item()
                    board.swap_enemy()
                    controller.execute_action(enemy_action)
                    board.swap_enemy()
                    board.step()

            # The last allowed step may itself have ended the game.
            outcome = self._outcome(controller)
            return history, 0 if outcome is None else outcome
        finally:
            player.train(player_was_training)
            enemy.train(enemy_was_training)

    def test(self, player, enemy, runs=650, battles=10):
        """Play `battles` greedy games (no exploration) and tally the results.

        Returns:
            (wins, losses, draws). Blocks (-0.5) and undecided games (0) count as draws.
        """
        wins = losses = draws = 0

        for _ in range(battles):
            _, value = self.run(player, enemy, runs, explore=False)
            if value == 1:
                wins += 1
            elif value == -1:
                losses += 1
            else:
                draws += 1

        return wins, losses, draws

class FrodoTrainer:
    """Trains a FrodoPolicy player against a fixed (never optimised) enemy policy.

    Each iteration plays several games. Moves from won games are kept as
    supervised targets. For any other outcome the targets are replaced with
    uniformly random actions.
    """

    def train(self, runs=100, eval_every=10):
        """Run `runs` training iterations and return the trained player.

        Args:
            runs: number of calls to `train_iteration`.
            eval_every: evaluate with ``GameRunner.test`` every this many iterations
                (starting at iteration 0) and report wins/losses/draws.

        Returns:
            The trained player policy. The player and enemy are also stored as
            ``self.player`` / ``self.enemy`` before training starts, so they stay
            reachable if the loop is interrupted (e.g. KeyboardInterrupt).
        """
        board = Board2()
        controller = ActionController(board)
        action_space = controller.get_action_space()

        player = FrodoPolicy(action_space)
        # Only the player's parameters are optimised; the enemy keeps its random init.
        enemy = FrodoPolicy(action_space)
        self.player, self.enemy = player, enemy

        game_runner = GameRunner()
        # One optimizer for the whole run, so Adam's moment estimates persist across iterations.
        optimizer = optim.Adam(player.parameters(), lr=0.001)

        for i in tqdm(range(runs)):
            self.train_iteration(player, enemy, game_runner, controller, optimizer=optimizer)

            if i % eval_every == 0:
                wins, losses, draws = game_runner.test(player, enemy)
                # tqdm.write keeps the progress bar intact, unlike print().
                tqdm.write(f'Wins: {wins}, Losses: {losses}, Draws: {draws}')

        return player

    def train_iteration(self, player, enemy, game_runner, controller, optimizer=None,
                        games=10, epochs=100, batch_size=32, lr=0.001):
        """Play `games` games, then fit `player` on the collected (state, action) pairs.

        Args:
            player: policy being trained. It must output log-probabilities, because
                NLLLoss is used.
            enemy: opponent policy, passed through to ``game_runner.run``.
            game_runner: object whose ``run(player, enemy)`` returns ``(history, value)``.
            controller: used only for ``get_action_space()``.
            optimizer: optimizer over the player's parameters. A fresh Adam with
                ``lr`` is created when None.
            games: number of games to play before fitting.
            epochs: number of passes over the collected dataset.
            batch_size: mini-batch size for fitting.
            lr: learning rate for the optimizer created when ``optimizer`` is None.

        Returns:
            Per-epoch mean per-sample training loss. The list is empty if no game
            produced any player move.
        """
        action_space = controller.get_action_space()
        train_boards = []
        train_actions = []

        for _ in range(games):
            history, value = game_runner.run(player, enemy)
            if not history:
                # The game ended before the player moved: nothing to learn (and zip(*[]) would fail).
                continue

            boards, actions = zip(*history)

            if value < 1:
                actions = random.choices(range(action_space), k=len(actions))

            train_boards.extend(boards)
            train_actions.extend(actions)

        if not train_boards:
            return []

        dataset = TensorDataset(torch.cat(train_boards, dim=0), torch.LongTensor(train_actions))
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        if optimizer is None:
            optimizer = optim.Adam(player.parameters(), lr=lr)
        crit = nn.NLLLoss()
        # Make sure dropout is active while fitting, whatever mode the rollouts left the model in.
        player.train()

        loss_history = []
        for _ in range(epochs):
            epoch_loss = 0.0
            for batch_boards, batch_actions in loader:
                optimizer.zero_grad()
                loss = crit(player(batch_boards), batch_actions)
                loss.backward()
                optimizer.step()
                # NLLLoss averages over the batch; re-weight so the epoch value is a per-sample mean.
                epoch_loss += loss.item() * len(batch_actions)
            loss_history.append(epoch_loss / len(dataset))
        return loss_history

# Create the trainer. Each call to `trainer.train(...)` builds fresh player/enemy policies.
trainer = FrodoTrainer()

In [49]:
# Long-running: each iteration plays 10 games and fits the player for 100 epochs.
# The recorded run below was stopped with a KeyboardInterrupt before finishing.
trainer.train(500)

  0%|          | 0/500 [00:00<?, ?it/s]

Wins: 0, Losses: 1, Draws: 9
Wins: 0, Losses: 0, Draws: 10


KeyboardInterrupt: 