In [1]:
# %cd /home/gsantoss/frodo

from board3 import Board3, sqr_distance, empty_cells, tod_cells
from controller3 import ActionController, MW_CELLS
import time
from heapq import heappush, heappop
from tqdm.auto import tqdm
import random
import torch
import torch.nn as nn
import torch.optim as optim
import math
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
from torch.distributions.categorical import Categorical
from collections import Counter, deque
from mcts import search
from nnl import gather_history
import concurrent.futures
import torch.multiprocessing as tmp

In [2]:
def to_emb(board: Board3):
    """Encode a board as a single-row feature tensor.

    The row is the concatenation of:
      * six coordinates (player y/x, enemy y/x, todd y/x), each divided by 3;
      * sixteen 0/1 flags, where flag ``y * 4 + x`` is set for every ``(y, x)``
        found in ``board.mw``.

    Returns:
        A tensor of shape ``(1, 22)``.
    """
    # The 16 flags / stride of 4 / divisor of 3 presumably reflect a 4x4 grid
    # (coordinates 0..3) — TODO confirm against Board3.
    (py, px), (ey, ex), (ty, tx) = (
        board.get_player_position(),
        board.get_enemy_position(),
        board.get_todd_position(),
    )
    positions = torch.Tensor([[py, px, ey, ex, ty, tx]]) / 3

    occupancy = torch.zeros(1, 16)
    for y, x in board.mw:
        occupancy[0, y * 4 + x] = 1

    return torch.cat([positions, occupancy], dim=1)


def emb_mem(mem, nc=2):
    """Embed a sequence of ``(board, action)`` pairs as one memory tensor.

    Each pair becomes one row: the ``to_emb`` features of the board followed
    by a one-hot encoding of the action over ``nc`` classes.

    Args:
        mem: iterable of ``(board, action)`` pairs.
        nc: number of action classes used for the one-hot part.

    Returns:
        A tensor of shape ``(1, len(mem), board_features + nc)``.
    """
    rows = [
        torch.cat(
            [
                to_emb(board),
                nn.functional.one_hot(torch.LongTensor([action]), num_classes=nc),
            ],
            dim=1,
        )
        for board, action in mem
    ]
    # Stack the per-step rows, then add the leading batch dimension.
    return torch.cat(rows, dim=0).unsqueeze(0)

In [3]:
def pos_encode(max_len, d_model, dtype=torch.float32):
    """Build a sinusoidal positional-encoding table.

    Even columns hold ``sin(pos * freq)`` and odd columns ``cos(pos * freq)``
    with ``freq_k = exp(-2k * ln(10000) / d_model)``.

    Args:
        max_len: number of positions (rows).
        d_model: encoding width (columns); may be odd.
        dtype: dtype of the returned table.

    Returns:
        A tensor of shape ``(max_len, d_model)``.
    """
    position = torch.arange(max_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
    angles = position * div_term

    pe = torch.zeros(max_len, d_model, dtype=dtype)
    pe[:, 0::2] = torch.sin(angles)
    # For odd d_model there is one fewer odd column than even column, so the
    # cosine half must drop the last frequency to match the target slice.
    pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
    return pe


class FP(nn.Module):
    """Policy/value network: the current state attends over an action history.

    The state vector ``x`` is projected to ``e_dim`` and used as a length-1
    target sequence of a Transformer decoder; the history ``h`` is projected
    to ``e_dim``, given sinusoidal positional encodings and used as the
    decoder memory. Two heads read the decoder output.

    Args:
        n_dim: width of the state vector ``x``.
        a_space: number of actions (width of the policy head).
        m_space: width of each history step in ``h``.
        e_dim: Transformer model width.
        ff_dim: Transformer feed-forward width.
        n_layers: number of decoder layers.
        n_heads: number of attention heads.
        max_len: longest history (``h.shape[1]``) the positional table covers.
    """

    def __init__(self, n_dim, a_space, m_space, e_dim=384, ff_dim=1024, n_layers=4, n_heads=12, max_len=15):
        super(FP, self).__init__()

        # Construction order is kept (fl, fh, decoder, fc, fv) so seeded
        # initialisation and state_dict keys match earlier checkpoints.
        self.fl = self._projection(n_dim, e_dim)
        self.fh = self._projection(m_space, e_dim)

        # NOTE(review): nn.TransformerDecoder works on clones of the layer it
        # is handed, so `dec_l` itself presumably takes no part in forward();
        # it is kept as an attribute so existing state_dict keys stay valid.
        self.dec_l = nn.TransformerDecoderLayer(d_model=e_dim, nhead=n_heads, dim_feedforward=ff_dim, batch_first=True)
        self.dec = nn.TransformerDecoder(self.dec_l, num_layers=n_layers)

        # Registered as a buffer so it follows the module across .to()/.cuda();
        # persistent=False keeps it out of state_dict (keys are unchanged).
        self.register_buffer("pe", pos_encode(max_len, e_dim), persistent=False)

        self.fc = self._head(e_dim, a_space, nn.LogSoftmax(dim=-1))
        self.fv = self._head(e_dim, 1, nn.Tanh())

    @staticmethod
    def _projection(in_dim, out_dim):
        """Two-layer ReLU/Dropout projection from ``in_dim`` to ``out_dim``."""
        return nn.Sequential(
            nn.Linear(in_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, out_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
        )

    @staticmethod
    def _head(in_dim, out_dim, activation):
        """Output head: hidden layer of 64 units followed by ``activation``."""
        return nn.Sequential(
            nn.Linear(in_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, out_dim),
            activation,
        )

    def forward(self, x, h):
        """Return ``(log_policy, value)`` for states ``x`` and histories ``h``.

        Args:
            x: state features, indexed as ``(batch, n_dim)``.
            h: history features, indexed as ``(batch, steps, m_space)``.

        Returns:
            ``log_policy`` of shape ``(batch, a_space)`` (log-softmax output)
            and ``value`` of shape ``(batch, 1)`` (tanh output).

        Raises:
            ValueError: if the history is longer than the positional table.
        """
        steps = h.shape[1]
        if steps > self.pe.shape[0]:
            raise ValueError(
                f"history length {steps} exceeds max_len {self.pe.shape[0]}"
            )

        nx = self.fl(x)
        nh = self.fh(h) + self.pe[:steps, :].unsqueeze(0)
        # The target sequence has length 1, so the mean just drops that axis.
        hidden = self.dec(nx.unsqueeze(1), nh).mean(dim=1)
        return self.fc(hidden), self.fv(hidden)


In [4]:

class NNT:
    """Node of the network-guided search tree.

    A node holds a board, the side to move (``player`` is ``1`` or ``-1``),
    the per-side history tensors ``m1``/``m2`` and its search statistics:
    ``prob`` (prior from the network), ``value`` (running estimate), ``n``
    (visit count) and ``action`` (the move that produced the node).
    """

    def __init__(self, board, player, m1, m2):
        self.board = board
        self.controller = ActionController(board)
        self.player = player
        self.m1 = m1
        self.m2 = m2
        # Search statistics; children stay None until the node is expanded.
        self.children = None
        self.action = None
        self.prob = 0
        self.value = 0
        self.n = 0

    def get_winner(self):
        """Return 1 on win, -1 on lose or block, None if not terminal."""
        ctrl = self.controller
        if ctrl.is_win():
            return 1
        if ctrl.is_lose() or ctrl.is_block():
            return -1
        return None

    def search(self, fp):
        """Run one simulation from this node and return the backed-up value.

        Terminal nodes record ``-winner`` and return it; unexpanded nodes are
        expanded and return the negated network value; otherwise the child
        maximising ``value + 0.5 * prob * sqrt(sum of child visits) / (n + 1)``
        is searched and its result is averaged into this node.
        """
        # NOTE(review): children created by expand() get their terminal value
        # before the board is swapped, whereas expanded nodes store the network
        # value of their own board — confirm the intended sign convention.
        outcome = self.get_winner()
        if outcome is not None:
            self.n += 1
            self.value = -outcome
            return self.value

        if self.children is None:
            self.expand(fp)
            return -self.value

        exploration = math.sqrt(sum(child.n for child in self.children))

        def score(child):
            return child.value + 0.5 * child.prob * (exploration / (child.n + 1))

        best = max(self.children, key=score)
        backed_up = best.search(fp)
        self.value = (self.value * self.n + backed_up) / (self.n + 1)
        self.n += 1

        return -backed_up

    def expand(self, fp):
        """Create one child per available move and evaluate this node with ``fp``."""
        self.children = []

        # Query the network with a copy of the mover's own history.
        own_memory = self.m1 if self.player == 1 else self.m2
        m = own_memory + 0

        with torch.no_grad():
            o, v = fp(to_emb(self.board), m)

        for p in self.controller.get_available_moves():
            b = self.board.copy()
            ActionController(b).execute_action(p)
            b.step(500)

            # Only the mover's history advances: drop the oldest step and
            # append the (post-move board, action) embedding.
            step_emb = emb_mem([(b, p)], ActionController.get_action_space())
            if self.player == 1:
                m1 = torch.cat([self.m1[:, 1:, :], step_emb], dim=1)
                m2 = self.m2
            else:
                m1 = self.m1
                m2 = torch.cat([self.m2[:, 1:, :], step_emb], dim=1)

            child = NNT(b, -self.player, m1, m2)
            child.action = p
            child.prob = o[:, p].exp().item()

            # Terminal status is read before the board is swapped below.
            result = child.get_winner()
            child.value = 0 if result is None else result
            child.n = 1

            b.swap_enemy()
            self.children.append(child)

        self.value = v.item()
        self.n = len(self.children)

    def get_policy(self):
        """Return the children's visit counts normalised over the action space."""
        visits = [0] * ActionController.get_action_space()
        for child in self.children:
            visits[child.action] = child.n

        total = sum(visits)
        return torch.Tensor(visits) / total


In [5]:

def run_episode(fp, mem_length = 10, frames=100, mcts_iter=50):
    """Play one self-play game with search and return its training samples.

    Args:
        fp: network used by the tree search; switched to eval mode.
        mem_length: number of steps in each side's initial history.
        frames: maximum number of moves to play.
        mcts_iter: search simulations run before each move.

    Returns:
        A list of ``(state_emb, memory, policy, value)`` tuples, one per move.
        ``state_emb`` and ``memory`` describe the position the mover decided
        from, ``policy`` is the normalised visit distribution at that position
        and ``value`` is the outcome label (all zeros if ``frames`` ran out).
    """
    fp.eval()
    b = Board3(walk_time=200)

    # Build both sides' initial histories, each seen from its own side.
    y1 = emb_mem([(b, 0)] * mem_length, ActionController.get_action_space())
    b.swap_enemy()
    y2 = emb_mem([(b, 0)] * mem_length, ActionController.get_action_space())
    b.swap_enemy()

    nt = NNT(b, 1, y1, y2)

    h = []

    for _ in range(frames):
        for _ in range(mcts_iter):
            nt.search(fp)

        mt = max(nt.children, key=lambda x: x.n)

        # Record the position the policy was computed FOR (the node's own
        # board), matching what the network is fed during search and battles;
        # previously the post-move board was stored here.
        memory = nt.m1 if nt.player == 1 else nt.m2
        h.append((to_emb(nt.board), memory, nt.get_policy()))

        # Child boards are stored swapped; swap a copy back to read the
        # outcome from the mover's side.
        nb = mt.board.copy()
        nb.swap_enemy()
        act = ActionController(nb)

        if act.is_win():
            return _label_history(h, 1, -1)
        elif act.is_lose():
            return _label_history(h, -1, 1)
        elif act.is_block():
            return _label_history(h, -1, 0)

        nt = mt

    return [(x, y, z, 0) for x, y, z in h]


def _label_history(h, last_value, other_value):
    """Attach outcome labels: the final mover's samples get ``last_value``, the rest ``other_value``."""
    last = len(h) - 1
    return [
        (x, y, z, last_value if (last - k) % 2 == 0 else other_value)
        for k, (x, y, z) in enumerate(h)
    ]


def train_nn(data, fp, epochs = 5, lr=0.001, batch_size=64):
    """Train a freshly built network on self-play samples and return it.

    Args:
        data: iterable of ``(state_emb, memory, policy, value)`` samples.
        fp: current network; accepted but not used (warm-starting from it
            is disabled).
        epochs: passes over the data.
        lr: AdamW learning rate.
        batch_size: mini-batch size.

    Returns:
        The newly trained network created by ``nfp()``.
    """
    states, memories, policies, values = list(zip(*data))
    dataset = TensorDataset(
        torch.cat(states, dim=0),
        torch.cat(memories, dim=0),
        torch.stack(policies),
        torch.Tensor(values),
    )

    # Training always starts from a new model; `fp` is intentionally left out.
    nf = nfp()

    optimizer = optim.AdamW(nf.parameters(), lr=lr)
    policy_criterion = nn.CrossEntropyLoss()
    value_criterion = nn.MSELoss()

    epoch_losses = []
    for _ in range(epochs):
        batch_losses = []
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        for state_b, memory_b, policy_b, value_b in loader:
            optimizer.zero_grad()
            logits, predicted = nf(state_b, memory_b)
            loss = policy_criterion(logits, policy_b) + value_criterion(predicted.squeeze(1), value_b)
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())
        # Mean loss per epoch, kept for optional plotting.
        epoch_losses.append(sum(batch_losses) / len(batch_losses))

    return nf


def run_battle(fp, nf, mem_length = 10, frames=100):
    """Play one greedy (no search) game of ``fp`` against ``nf``.

    ``fp`` moves first. Each side picks the arg-max action of its own policy
    output and keeps its own rolling history.

    Args:
        fp: first player's network; switched to eval mode.
        nf: second player's network; switched to eval mode.
        mem_length: number of steps in each side's initial history.
        frames: maximum number of rounds (one move per side per round).

    Returns:
        ``1`` if ``fp`` wins, ``-1`` if ``nf`` wins, ``0`` if ``frames`` runs
        out. A lose or block right after a side's own move counts against it.
    """
    fp.eval()
    nf.eval()

    b = Board3(walk_time=200)
    n_actions = ActionController.get_action_space()

    # Initial histories, each built with the board seen from that side.
    y1 = emb_mem([(b, 0)] * mem_length, n_actions)
    b.swap_enemy()
    y2 = emb_mem([(b, 0)] * mem_length, n_actions)
    b.swap_enemy()

    act = ActionController(b)

    models = (fp, nf)
    memories = [y1, y2]
    # Result, from fp's point of view, when the side to move wins.
    win_result = (1, -1)

    for _ in range(frames):
        for side in (0, 1):
            with torch.no_grad():
                o, _ = models[side](to_emb(b), memories[side])
            # NOTE(review): the arg-max runs over every action and is not
            # restricted to get_available_moves() as the tree search is.
            a = o.exp().argmax().item()

            act.execute_action(a)
            b.step(500)

            # Append the POST-move board with the action, the same convention
            # NNT.expand uses for the histories the networks are trained on
            # (previously the pre-move board was stored here).
            memories[side] = torch.cat(
                [memories[side][:, 1:, :], emb_mem([(b, a)], n_actions)], dim=1
            )

            if act.is_win():
                return win_result[side]
            elif act.is_lose() or act.is_block():
                return -win_result[side]

            # Hand the board to the other side.
            b.swap_enemy()

    return 0


def eval_nn(fp, nf, n = 40):
    """Play ``n`` battles of ``fp`` against ``nf`` and summarise the results.

    Returns:
        ``(win_rate, draw_rate, loss_rate)`` from ``fp``'s point of view; a
        result of ``1`` is a win, ``0`` a draw and anything else a loss.
    """
    outcomes = [run_battle(fp, nf) for _ in range(n)]

    wins = sum(1 for r in outcomes if r == 1)
    draws = sum(1 for r in outcomes if r != 1 and r == 0)
    losses = len(outcomes) - wins - draws

    return wins / n, draws / n, losses / n


def nfp():
    """Create a new ``FP`` whose input widths are measured on a probe board."""
    probe = Board3(walk_time=200)

    state_width = to_emb(probe).shape[1]
    n_actions = ActionController.get_action_space()
    # A two-step history is enough to read off the per-step feature width.
    memory_width = emb_mem([(probe, 0)] * 2, n_actions).shape[-1]

    return FP(state_width, n_actions, memory_width)

In [6]:
# fp = nfp()
#
# with concurrent.futures.ProcessPoolExecutor() as executor:
#     futures = set()
#     for _ in range(1):
#         futures.add(executor.submit(run_episode, fp))
#
#     done, futures = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_COMPLETED)
#     for fut in done:
#         result = fut.result()
#         print(result)
# run_episode(fp)

# p = tqdm(total=1000)
#
# data = []
# with concurrent.futures.ProcessPoolExecutor() as executor:
#     futures = set()
#     for _ in range(12):
#         futures.add(executor.submit(run_episode, fp))
#
#     while len(data) < 1000:
#         # Wait until at least one future is completed.
#         done, futures = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_COMPLETED)
#         for fut in done:
#             result = fut.result()
#             if result:
#                 data.extend(result)
#                 p.update(len(result))
#             # Submit a new simulation to keep the pool busy.
#             futures.add(executor.submit(run_episode, fp))
#             # Break early if we have reached the desired amount.
#             if len(data) >= 1000:
#                 break

In [7]:
# fp.load_state_dict(torch.load('models/fp.pth', map_location=torch.device('cpu'), weights_only=True), strict=False)
fp = nfp()
fp.eval()

# Replay buffer of self-play samples; the oldest entries fall out past 5000.
h = deque(maxlen=5000)

for _ in tqdm(range(20)):

    # Collect at least 500 NEW samples with the current best network on every
    # iteration. Checking only len(h) < 500 was satisfied for good after the
    # first iteration, so later rounds retrained on the same stale games.
    new_samples = 0
    while new_samples < 500:
        episode = run_episode(fp)
        h.extend(episode)
        new_samples += len(episode)

    nf = train_nn(h, fp, epochs = 3, lr=0.001, batch_size=64)

    # Candidate (nf) plays first against the incumbent (fp); w is nf's win rate.
    w, d, l = eval_nn(nf, fp)
    print(w, d, l)
    if w > 0.55:
        print('Upgrade')
        fp = nf


  0%|          | 0/20 [00:00<?, ?it/s]

KeyboardInterrupt: 