In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import sqrt

In [2]:
from ai import *

In [3]:
from torch.utils.data import Dataset, DataLoader
import pickle
import random
class GameDataset(Dataset):
    """Replay buffer of self-play samples persisted to a pickle file.

    On construction, the new ``record`` is merged with a random subset (at most
    ``max_history`` samples) of the record previously stored at ``filepath``.
    The merged record is then written back to ``filepath``, so the file acts as
    a bounded replay buffer across training runs.

    Args:
        record: dict of parallel lists under the keys 'input', 'policy' and
            'value'. It is copied, not modified.
        filepath: pickle file that previous samples are read from and the
            merged record is written to.
        max_history: maximum number of previously stored samples to keep.

    Each item is a dict with keys 'input', 'value' and 'policy'.
    """

    _KEYS = ('input', 'policy', 'value')

    def __init__(self, record, filepath='/home/smg/game_record.pkl', max_history=100000):
        # Copy the lists so the caller's record is not extended in place.
        self.record = {key: list(values) for key, values in record.items()}
        history = self._load_history(filepath, max_history)
        for key in self._KEYS:
            self.record[key] += history[key]
        # Persist to the same path that was read, so a custom filepath round-trips.
        with open(filepath, 'wb') as f:
            pickle.dump(self.record, f)

    @classmethod
    def _load_history(cls, filepath, max_history):
        """Return up to ``max_history`` randomly sampled stored samples (empty lists if none)."""
        empty = {key: [] for key in cls._KEYS}
        try:
            # NOTE: pickle.load can execute arbitrary code; only point this at
            # files written by this script.
            with open(filepath, 'rb') as f:
                stored = pickle.load(f)
            file_length = len(stored['input'])
            sample_key = random.sample(range(file_length), min(file_length, max_history))
            return {key: [stored[key][i] for i in sample_key] for key in cls._KEYS}
        except FileNotFoundError:
            # First run: nothing has been stored yet.
            return empty
        except (OSError, EOFError, pickle.UnpicklingError, KeyError, IndexError, TypeError) as exc:
            # Unreadable or malformed file: start a fresh buffer rather than abort
            # training, but report it instead of failing silently.
            print('Could not load {}: {!r}; starting a new record'.format(filepath, exc))
            return empty

    def __len__(self):
        return len(self.record['input'])

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict with keys 'input', 'value' and 'policy'."""
        if torch.is_tensor(idx):
            idx = idx.item()

        return {'input': self.record['input'][idx],
                'value': self.record['value'][idx],
                'policy': self.record['policy'][idx]}

# Starting board: 7 rows of 7 cells. Corners hold the player ids used by
# SidusPlayer (1 and 2); 0 presumably marks an empty cell.
initial_state = [
    [1, 0, 0, 0, 0, 0, 2],
    [0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0],
    [2, 0, 0, 0, 0, 0, 1],
]
def _move_to_policy(move):
    """Return the policy index of ``move``; an empty tuple (no move) maps to index 24."""
    if move == tuple():
        return 24
    return get_move_key(move[0][0], move[0][1], move[1][0], move[1][1])


def self_play(match_num=100, max_turn=100, max_visit=20):
    """Play ``match_num`` self-play games and return the collected training samples.

    Each game uses fresh ``SidusPlayer`` instances for players 1 and 2. They
    alternate moves (player 1 first) from a copy of ``initial_state``. A game
    stops when ``is_end`` reports the end, or after ``max_turn`` plies. For every
    ply, the position seen by the mover and the chosen move index are recorded.
    Once the game is over, every position is labelled with a value target.

    Args:
        match_num: number of games to play.
        max_turn: maximum number of plies per game.
        max_visit: search budget passed to ``SidusPlayer.move``.

    Returns:
        dict of parallel lists:
        - 'input': mover-relative positions from ``list_to_tensor_input``.
        - 'policy': long tensors holding move indices.
        - 'value': float tensors, +1/-1 from the mover's point of view.
        Player 1's positions get +1 only when ``get_winner`` reports player 1.
        Any other result (including a draw, if get_winner reports one) gives
        player 1 -1.
    """
    game_record = {'input': [],
                   'policy': [],
                   'value': []}
    print('Matching', end=' ')
    for i in range(match_num):
        players = ((1, SidusPlayer(1, use_gpu=True)), (2, SidusPlayer(2, use_gpu=True)))
        # Copy so an in-place update_board cannot corrupt the shared start position.
        board = [row[:] for row in initial_state]
        game_start = len(game_record['input'])
        for turn in range(max_turn):
            player_id, player = players[turn % 2]
            move = player.move(board, max_visit=max_visit, is_train=True)
            policy = _move_to_policy(move)
            game_record['input'].append(list_to_tensor_input(board, player_id))
            game_record['policy'].append(torch.tensor(policy, dtype=torch.long))
            board = update_board(board, move)
            if is_end(board):
                break
        # Label every position of this game, including games cut off at
        # max_turn, so 'value' always stays aligned with 'input'.
        value = 1 if get_winner(board, 1, 2) == 1 else -1
        for _ in range(game_start, len(game_record['input'])):
            game_record['value'].append(torch.tensor(value, dtype=torch.float))
            value = -value
        print(i, end=' ')
    print()
    return game_record

def train(record, epoch_num=10, device='cuda', model_path='model.pt',
          checkpoint_path='checkpoint_aginst.tar', num_samples=10000,
          batch_size=32, lr=0.00001, print_interval=100):
    """Fine-tune the network stored at ``model_path`` on ``record`` plus the replay buffer.

    The data comes from ``GameDataset(record)``. Each epoch draws ``num_samples``
    samples with replacement. Losses are printed every ``print_interval`` batches,
    and a checkpoint (model, optimizer, last loss) is written to
    ``checkpoint_path`` at the same time. The weights are saved back to
    ``model_path`` after every epoch.

    Args:
        record: dict with 'input', 'policy' and 'value' lists (see GameDataset).
        epoch_num: number of passes over ``num_samples`` random samples.
        device: torch device used for the model and batches.
        model_path: state-dict file that is loaded at the start and overwritten per epoch.
        checkpoint_path: file for the periodic training checkpoint.
        num_samples: samples drawn (with replacement) per epoch.
        batch_size: DataLoader batch size.
        lr: Adam learning rate.
        print_interval: number of batches between loss reports and checkpoints.
    """
    print('Training')
    datafeeder = GameDataset(record)
    if len(datafeeder) == 0:
        # RandomSampler cannot draw from an empty dataset.
        print('No training samples; skipping')
        return
    randomsampler = torch.utils.data.RandomSampler(datafeeder, replacement=True, num_samples=num_samples)
    # The sampler draws fresh random indices on every pass, so one loader serves all epochs.
    dataloader = DataLoader(datafeeder, batch_size=batch_size, sampler=randomsampler)

    net = SidusAtaxxNet()
    net.load_state_dict(torch.load(model_path, map_location=device))
    net.train()
    net.to(device)

    # NLLLoss expects log-probabilities; assumes SidusAtaxxNet's policy head
    # ends in log_softmax — TODO confirm.
    criterion_pol = nn.NLLLoss()
    criterion_val = nn.MSELoss()
    optim = torch.optim.Adam(net.parameters(), lr=lr)
    alpha = 1e0  # weight of the value loss relative to the policy loss

    for epoch in range(epoch_num):
        total_loss = 0.0
        loss_pol_sum = 0.0
        loss_val_sum = 0.0
        # MSE can exceed 1, so start the running extremes at +/-inf.
        loss_val_max = float('-inf')
        loss_val_min = float('inf')
        for i, data in enumerate(dataloader):
            policy, value = net(data['input'].to(device))
            target_value = data['value'].to(device)
            loss_pol = criterion_pol(policy, data['policy'].to(device))
            # Flatten both sides: a (B, 1) value head against (B,) targets would
            # otherwise broadcast to (B, B) inside MSELoss.
            # Assumes one scalar value per sample.
            loss_val = criterion_val(value.reshape(-1), target_value.reshape(-1))
            loss = 2 * (loss_pol + alpha * loss_val) / (1 + alpha)

            optim.zero_grad()
            loss.backward()
            optim.step()

            loss_val_float = loss_val.item()
            loss_pol_sum += loss_pol.item()
            loss_val_sum += loss_val_float
            total_loss += loss.item()
            loss_val_max = max(loss_val_max, loss_val_float)
            loss_val_min = min(loss_val_min, loss_val_float)
            if i % print_interval == print_interval - 1:
                print('[%d, %5d] loss: %.8f\tpol: %.8f\tval: %.8f %.5f %.5f' %
                    (epoch + 1, i + 1, total_loss/print_interval, loss_pol_sum/print_interval, loss_val_sum/print_interval, loss_val_max, loss_val_min))
                total_loss = 0.0
                loss_pol_sum = 0.0
                loss_val_sum = 0.0
                loss_val_max = float('-inf')
                loss_val_min = float('inf')
                torch.save({'epoch': epoch,
                            'step': i,
                            'model_state_dict': net.state_dict(),
                            'optimizer_state_dict': optim.state_dict(),
                            # Detached so the checkpoint does not carry the autograd graph.
                            'loss': loss.detach().cpu()}, checkpoint_path)
        torch.save(net.state_dict(), model_path)

def compete(match_num, target_num, max_turn=100, max_visit=50,
            champ_path='model_top.pt', chall_path='model.pt'):
    """Pit the challenger model against the current champion; promote it if it wins enough.

    The champion (loaded from ``champ_path``) always plays as player 1, and the
    challenger (``chall_path``) as player 2. A game ends when ``is_end`` reports
    the end or after ``max_turn`` plies; the winner is then read with
    ``get_winner``. Prints 'W' for a challenger win and '_' otherwise. Stops
    early once ``target_num`` wins are out of reach.

    Args:
        match_num: maximum number of games.
        target_num: challenger wins needed for promotion.
        max_turn: maximum number of plies per game.
        max_visit: search budget passed to ``SidusPlayer.move``.
        champ_path: champion weights; overwritten on promotion.
        chall_path: challenger weights.

    Returns:
        True if the challenger won at least ``target_num`` games, in which case
        its weights are saved to ``champ_path``; otherwise False.
    """
    print('Competing')
    AI_champ = SidusPlayer(1, filepath=champ_path, use_gpu=True)
    AI_chall = SidusPlayer(2, filepath=chall_path, use_gpu=True)
    players = (AI_champ, AI_chall)
    win_count = 0
    for i in range(match_num):
        # Copy so an in-place update_board cannot corrupt the shared start position.
        board = [row[:] for row in initial_state]
        for turn in range(max_turn):
            move = players[turn % 2].move(board, max_visit=max_visit)
            board = update_board(board, move)
            if is_end(board):
                break
        # Score every game, including ones cut off at max_turn.
        if get_winner(board, 1, 2) == 2:
            win_count += 1
            print('W', end='')
        else:
            print('_', end='')
        # Remaining games cannot lift the challenger to target_num wins.
        if match_num - (i + 1) + win_count < target_num:
            break
    print()
    if win_count >= target_num:
        torch.save(AI_chall.net.state_dict(), champ_path)
        return True
    return False



In [4]:
# Self-play -> train -> evaluate loop. Runs forever by default (interrupt the
# kernel to stop); set MAX_ITERATIONS to an int to bound it.
SELF_PLAY_GAMES = 10    # games generated per iteration
TRAIN_EPOCHS = 10       # epoch_num passed to train()
EVAL_MATCHES = 7        # evaluation games per iteration
EVAL_WINS_NEEDED = 5    # challenger wins compete() requires to promote the new model
MAX_ITERATIONS = None   # None = unbounded

count = 0
while MAX_ITERATIONS is None or count < MAX_ITERATIONS:
    count += 1
    game_record = self_play(SELF_PLAY_GAMES)
    train(game_record, TRAIN_EPOCHS)
    new_champ = compete(EVAL_MATCHES, EVAL_WINS_NEEDED)
    if new_champ:
        print('selfplay {}: New champ'.format(count))
    else:
        print('selfplay {}: End'.format(count))

Matching01234

In [5]:
def print_dataset(record, idx):
    """Print one stored training sample for manual inspection.

    Prints, in order:
    - the index;
    - the sum and the full value of ``input[0] - input[1]`` (the difference of
      the first two planes of the input);
    - the policy index split into three base-7 digits;
    - the value target.

    What the three digits mean is defined by ``get_move_key`` — TODO confirm.
    """
    print('idx', idx)
    planes = record['input'][idx]
    plane_diff = planes[0] - planes[1]
    print(torch.sum(plane_diff))
    print(plane_diff)
    move_key = record['policy'][idx].item()
    rest, low = divmod(move_key, 7)
    high, mid = divmod(rest, 7)
    print(high, mid, low)
    print(record['value'][idx])

In [6]:
# Retrain without new self-play games: train() builds a GameDataset from an
# empty record, so it only sees the samples already stored on disk.
empty_record = {key: [] for key in ('input', 'policy', 'value')}
train(empty_record)

Training
