In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from game import Board, ACTION_SIZE, rotate, rotate_bundle
from model import NNet, MCTS
import numpy as np
import pickle
import random
import json
from tqdm import tqdm
import multiprocessing
import os
# Show this kernel's process id so the (long-running) training process can be found later, e.g. to monitor or stop it.
print(f"pid={os.getpid()}")

In [2]:
# Module-level debugging hook: self_play overwrites this with a copy of the board
# right before each move is searched and played, so the position preceding the
# latest move can be inspected (e.g. after an exception during self-play).
debug = None

def count_data(mcts: MCTS):
    """Return the number of states in ``mcts.Ns`` whose visit count exceeds 10.

    self_play uses this as its measure of how much usable training data the
    search tree currently holds.
    """
    return sum(1 for visits in mcts.Ns.values() if visits > 10)

def self_play(mcts: MCTS, batch_num=100, timeout=10, data_threshold=10000, max_moves=100):
    """Play self-play games with ``mcts`` until its tree holds enough data.

    Before every game ``count_data(mcts)`` is re-read; once it exceeds
    ``data_threshold`` no further games are played. Each game starts from a
    fresh ``Board`` and runs for at most ``max_moves`` moves, or until the
    board gains a ``winner`` attribute. Every move, including the one that
    ends the game, is printed with the mover's colour, the chosen action and
    the value ``mcts.query_v`` reports for it.

    Args:
        mcts: search tree shared by all games. The threshold check assumes
            ``mcts.best_move`` grows ``mcts.Ns`` as it searches — TODO confirm.
        batch_num: maximum number of games to play.
        timeout: search budget forwarded to ``mcts.best_move`` for every move
            (units defined by MCTS; presumably seconds).
        data_threshold: stop once more than this many states qualify in
            ``count_data``.
        max_moves: cap on the number of moves in a single game.
    """
    # Debugging hook: keep a copy of the position before each move at module level.
    global debug
    for batch in range(batch_num):
        print(f"Batch {batch}")
        num = count_data(mcts)
        if num > data_threshold:
            print("Data is enough, break", num, ">", data_threshold)
            break
        print("Data is not enough, continue self play", num, "<=", data_threshold)
        board = Board()
        for _ in range(max_moves):
            debug = board.copy()
            action = mcts.best_move(board, timeout)
            color = board.color  # colour of the side making this move
            value = mcts.query_v(board, action)
            board.place(*board.int2move(action))
            # Print before the winner check so the game-ending move is shown too.
            print(str(board))
            # 'O_X'[color + 1] assumes board.color is encoded as -1 / 0 / 1 — TODO confirm in Board.
            print(f"Color: {'O_X'[color+1]}, Action: {action}, Value: {value}")
            if hasattr(board, "winner"):
                break


In [3]:
# Round counter for the training loop: each round dumps its samples to
# f"data{ver}.pkl" and then increments it. Re-running this cell resets the
# numbering, so later rounds will overwrite earlier data files.
ver = 1

In [4]:
import pickle
def save_mcts(mcts, filename):
    """Pickle the search statistics of ``mcts`` to ``filename``.

    The file contains a single dict mapping the attribute names "Ps", "Ns",
    "Qsa" and "Nsa" to the corresponding attributes of ``mcts``; load_mcts
    reads this format back.
    """
    with open(filename, "wb") as out:
        snapshot = {name: getattr(mcts, name) for name in ("Ps", "Ns", "Qsa", "Nsa")}
        pickle.dump(snapshot, out)
def load_mcts(mcts, filename):
    """Restore Ps, Ns, Qsa and Nsa on ``mcts`` from a file written by save_mcts.

    The attributes are overwritten in that order; a missing key raises
    KeyError after the earlier ones have already been replaced.

    Security note: pickle.load can execute arbitrary code, so only load files
    this notebook produced itself.
    """
    with open(filename, "rb") as src:
        snapshot = pickle.load(src)
        for name in ("Ps", "Ns", "Qsa", "Nsa"):
            setattr(mcts, name, snapshot[name])
def get_data(mcts: MCTS, num_samples=20000):
    """Sample training examples from the states stored in an MCTS tree.

    ``num_samples`` states are drawn with replacement, each with probability
    proportional to its visit count ``mcts.Ns[state]``. For every drawn state
    the policy target is a one-hot vector on ``mcts.best_move(board, timeout=0)``
    and the value target is ``float(mcts.Qsa[state, action])``. Each example is
    emitted in four rotations (``rotate_bundle`` for the input, ``rotate`` for
    the policy), so the result holds ``4 * num_samples`` tuples
    ``(input, policy, value)``.

    Args:
        mcts: search tree whose ``Ns``/``Qsa`` statistics are sampled.
        num_samples: number of states to draw.

    Returns:
        List of ``(input, policy, value)`` tuples; an empty list when
        ``mcts.Ns`` holds no states.
    """
    keys = list(mcts.Ns.keys())
    if not keys:
        # Nothing to sample from; normalising zero total visits would give NaN weights.
        return []
    visits = np.array([mcts.Ns[k] for k in keys], dtype=np.float64)
    probs = visits / visits.sum()
    chosen = np.random.choice(len(keys), num_samples, replace=True, p=probs)

    data = []
    for idx in tqdm(chosen):
        key = keys[idx]
        board = Board.from_state(key)
        net_input = board.bundled_input(board.legal_moves_input())
        policy = np.zeros(ACTION_SIZE)
        # timeout=0: presumably picks the move from existing statistics without
        # further search — TODO confirm in MCTS.best_move.
        action = mcts.best_move(board, timeout=0)
        policy[action] = 1.0
        value = float(mcts.Qsa[key, action])
        for rotation in range(4):
            if rotation:
                net_input = rotate_bundle(net_input, board.n)
                # The last policy entry is carried over unrotated — presumably a
                # non-board action (e.g. pass); only the board part is rotated.
                last = policy[-1]
                policy = np.array(rotate(policy[:-1], board.n).tolist() + [last])
            data.append((net_input, policy, value))
    return data

def train_network(data, epoch_num=1000, pi_only=False,
                  checkpoint_path="data/cnn.pt", lr=0.0005, weight_decay=1e-4):
    """Fine-tune the network checkpoint on a batch of self-play samples.

    Loads the weights stored at ``checkpoint_path`` into a fresh
    ``NNet(0, 128, 256)``, runs ``epoch_num`` full-batch Adam steps over
    ``data`` and writes the updated weights back to the same path.

    Args:
        data: sequence of ``(input, policy_target, value_target)`` tuples as
            produced by get_data. An empty sequence is a no-op: nothing is
            loaded, trained or saved.
        epoch_num: number of optimisation steps; each step uses all of ``data``.
        pi_only: if True, optimise only the policy loss (the value loss is
            not computed).
        checkpoint_path: file the weights are read from and saved back to.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
    """
    # An empty slice would produce zero-sized tensors (and a NaN mean loss);
    # skip training and leave the checkpoint untouched instead.
    if len(data) == 0:
        return

    map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(map_location)
    saved_state = torch.load(checkpoint_path, map_location=map_location)
    nnet = NNet(0, 128, 256).to(device)
    nnet.load_state_dict(saved_state)

    inputs = torch.tensor(np.array([d[0] for d in data]), dtype=torch.float32).to(device)
    target_pi = torch.tensor(np.array([d[1] for d in data]), dtype=torch.float32).to(device)
    target_v = torch.tensor(np.array([d[2] for d in data]), dtype=torch.float32).to(device)
    target_pi = target_pi.view(-1, ACTION_SIZE)
    target_v = target_v.view(-1, 1)

    optimizer = optim.Adam(nnet.parameters(), lr=lr, weight_decay=weight_decay)
    for epoch in range(epoch_num):
        optimizer.zero_grad()
        out_pi, out_v = nnet(inputs)
        # Policy loss -mean(pi * output): a cross-entropy only if the policy head
        # returns log-probabilities — TODO confirm against NNet. The mean runs
        # over batch *and* actions, which scales it by 1/ACTION_SIZE.
        loss_pi = torch.mean(-target_pi * out_pi)
        if pi_only:
            loss = loss_pi
        else:
            loss_v = nn.MSELoss()(out_v, target_v)
            # Both heads are optimised jointly through the summed loss.
            loss = loss_pi + loss_v
        loss.backward()
        optimizer.step()
        if pi_only:
            print(f'Epoch {epoch}, Loss: {loss.item()}')
        else:
            print(f'Epoch {epoch}, Loss1: {loss_pi.item()}, Loss2: {loss_v.item()}')

    torch.save(nnet.state_dict(), checkpoint_path)


In [None]:
# One-off self-play run from the current checkpoint with a longer search budget
# (timeout=10) than the training loop uses.
# NOTE(review): the tree built here is not reused — the training-loop cell
# creates a fresh MCTS — so this cell only serves to watch/inspect play.
saved_state = torch.load("data/cnn.pt", map_location='cpu')
nnet = NNet(0, 128, 256)
nnet.load_state_dict(saved_state)
mcts = MCTS(nnet)
self_play(mcts, batch_num=100, timeout=10)

In [None]:
# Endless training loop: each round reloads the latest checkpoint, generates
# self-play data, saves it to data{ver}.pkl and fine-tunes the network on it.
CHUNK_SIZE = 50000  # samples per train_network call

nnet = NNet(0, 128, 256)
while True:
    print('----------------------------------')
    saved_state = torch.load("data/cnn.pt", map_location='cpu')
    nnet.load_state_dict(saved_state)
    mcts = MCTS(nnet)
    self_play(mcts, timeout=4)
    print('Count Data: ', count_data(mcts))
    data = get_data(mcts)
    with open(f"data{ver}.pkl", "wb") as f:
        pickle.dump(data, f)
    # Chunk over the data actually produced rather than fixed offsets: the
    # dataset size depends on get_data's sample count (4 rotations per sample),
    # and slicing past its end would hand train_network empty lists.
    for start in range(0, len(data), CHUNK_SIZE):
        train_network(data[start:start + CHUNK_SIZE], epoch_num=20)
    ver += 1
    del mcts