In [1]:
import numpy as np
import torch
import random

In [2]:
import torch.nn as nn
import torch.nn.functional as F

NUM_MOVES = 81 # Max number of valid moves in a game of NestedTTT

class NestedTTTNet(nn.Module):
    def __init__(self, inner_board_units = 32, outer_board_units = 64, final_units = 256, n_res_layers = 2):
        super(NestedTTTNet, self).__init__()
        self.n_res_layers = n_res_layers
        self.outer_board_units = outer_board_units
        self.final_units = final_units
        
        self.conv1 = nn.Conv2d(3, inner_board_units, 3, padding = 1)
        self.bn_inner = nn.BatchNorm2d(inner_board_units)
        self.res1 = nn.Conv2d(inner_board_units, inner_board_units, 3, padding = 1)
        
        self.conv2_collapse = nn.Conv2d(inner_board_units, outer_board_units, 3)
        self.bn_outer = nn.BatchNorm2d(outer_board_units)
        self.res2 = nn.Conv2d(outer_board_units, outer_board_units, 3, padding = 1)
        
        self.conv3_collapse = nn.Conv2d(outer_board_units, final_units, 3)
        self.bn_linear = nn.BatchNorm1d(final_units)
        
        self.policy_head = nn.Sequential(
            nn.Linear(final_units, NUM_MOVES),
            nn.Softmax(dim = 1)
        )
        
        self.value_head = nn.Sequential(
            nn.Linear(final_units, final_units // 2),
            nn.ReLU(),
            nn.Linear(final_units // 2, 1),
            nn.Tanh()
        )
        
    def forward(self, board_states):
        '''
        board_states.size() == (-1, 3, 3, 3, 3, 3)
                            == (batch_size, [x-owned, o-owned, turn-state], outer_row, outer_col, inner_row, inner_col)
        '''
        # Move the outer board states to the left so their intra-square can be resolved before introducing inter-square
        board_states = board_states.permute(0, 2, 3, 1, 4, 5).contiguous().view(-1, 3, 3, 3)
        
        x = F.relu(self.bn_inner(self.conv1(board_states)))
        for _ in range(self.n_res_layers):
            x_conv = F.relu(self.bn_inner(self.res1(x)))
            x = F.relu(x + self.bn_inner(self.res1(x_conv)))
            
        x = F.relu(self.bn_outer(self.conv2_collapse(x)))
        x = x.view(-1, 3, 3, self.outer_board_units).permute(0, 3, 1, 2).contiguous() # Move them back to the right
        
        for _ in range(self.n_res_layers):
            x_conv = F.relu(self.bn_outer(self.res2(x)))
            x = F.relu(x + self.bn_outer(self.res2(x_conv)))
        
        x = F.relu(self.bn_linear(self.conv3_collapse(x).view(-1, self.final_units)))
        
        policy = self.policy_head(x)
        value = self.value_head(x)
        
        return policy, value

In [3]:
import torch

class NestedTTT:

    anti_diag_tensor = torch.ByteTensor([[1 if r + c == 2 else 0 for c in range(3)] for r in range(3)])

    def __init__(self):
        #Board State Tensor
        #Outermost dimension:
        #First value represents Player 0's (X's) holdings, second represents O's
        #Third is filled with the number of the current player to move (0 = X, 1 = O)
        #Rest of the dimensions are (Outer Row, Outer Column, Inner Row, Inner Column)
        self.state = torch.full((3, 3, 3, 3, 3), 0, dtype = torch.float32, requires_grad = False)

        #Valid Move Tensor
        #(Outer Row, Outer Column, Inner Row, Inner Column)
        self.valid_moves = torch.full((3, 3, 3, 3), 1, dtype = torch.uint8, requires_grad = False)

        #Summary Board Tensor
        #(Player, Row, Column)
        self.summary_boards = torch.full((2, 3, 3), 0, dtype = torch.uint8, requires_grad = False)

        #Current Player
        self.current_player = 0
        
    def check_win(self, board):
        """
        Short-circuits if current player has less than 3 squares on the current board
        If they do, checks all rows, columns, and diagonals for a win

        Returns 0 if the current player did not win the board in question, or 1 if they did
        """
        if board.sum().item() < 3:
            return 0
        else:
            return (board.sum(0) == 3).any().item() or \
                   (board.sum(1) == 3).any().item() or \
                   (board.diag().sum() == 3).item()   or \
                   (board.masked_select(self.anti_diag_tensor).sum() == 3).item()

    def make_move(self, move):
        """
        Takes a move:
          ({0 for X, 1 for O}, outer row, outer col, inner row, inner col)
        Returns 1 if the move won the game, 0 otherwise
        """
        if not self.valid_moves[move[1:]].item():
            raise ValueError("Invalid move placement")
        if move[0] != self.state[2,0,0,0,0].item():
            raise ValueError("Only the current player can make a move")
        self.state[move] = 1
        self.valid_moves[move[1:]] = 0
        self.state[2] = 1 - self.state[2]

        if self.check_win(self.state[move[:3]]):
            self.summary_boards[move[:3]] = 1
            self.valid_moves[move[1:3]] = 0
            if self.check_win(self.summary_boards[move[0]]):
                self.valid_moves.fill_(0)
                return 1

        return 0

    def undo_move(self, move):
        if not self.state[move].item():
            raise ValueError("Can only undo where a move has been played")
        if move[0] == self.current_player:
            raise ValueError("Cannot undo a move for the current player")
        if self.state[(2,0,0,0,0)].item() != self.current_player:
            raise ValueError("Desynchronization between current player and state tensor")
        self.state[move] = 0
        self.valid_moves[move[1:3]] = 1 - self.state[:2, move[1], move[2]].sum(0)

        self.summary_boards[move[:3]] = 0
        self.state[2] = 1 - self.state[2]

    def switch_player(self):
        self.current_player = (self.current_player + 1) % 2
        
    def copy(self):
        c = self.__new__(NestedTTT)
        c.state = self.state.clone().detach().detach()
        c.valid_moves = self.valid_moves.clone().detach()
        c.summary_boards = self.summary_boards.clone().detach()
        c.current_player = self.current_player
        return c
    
    def reset(self):
        self.__init__()
        
    def get_valid_moves(self):
        return self.valid_moves.nonzero().tolist()

In [4]:
from math import sqrt
import random

EXPLORATION_PARAM = sqrt(2)

class Node:
    """
    Nodes represent a state in the game.  They can have edges which
    indicate legal actions to take from the state
    """
    def __init__(self, incoming_edge):
        self.parent_edge = incoming_edge
        self.leaf = True
        self.edges = dict()

    def choose_action(self, temperature = 1):
        if not self.edges:
            raise ValueError("Cannot choose action from an unexplored node")

        visits_exp = sum([edge.N ** (1 / temperature) for edge in self.edges.values()])
        thresh = random.random()
        cumsum = 0
        for action, edge in self.edges.items():
            cumsum = cumsum + edge.N ** (1 / temperature)
            if thresh < cumsum / visits_exp:
                return action

    def selection_step(self):
        if self.leaf:
            raise ValueError("Cannot perform selection step at a leaf")
        max_result = None
        max_action = None
        
        total_trials = sum([edge.N for edge in self.edges.values()])

        for action, edge in self.edges.items():
            current_estimate = edge.value_estimate(total_trials)
            if max_result is None or max_result < current_estimate:
                max_result = current_estimate
                max_action = action

        return max_action
    
    def expansion(self, actions, ps):
        if actions:
            for action, p in zip(actions, ps):
                self.edges[action] = Edge(self, p)
            self.leaf = False
        
    def backpropagate(self, value):
        if self.parent_edge is None:
            raise ValueError("Can't start backpropagating at the root")
        else:
            #Current state for player A is a result of the action (edge) chosen by player B
            #If the current state has a good value for player A, then it has a bad value for player B
            #Value of current state (node) is inverted to update value of action taken (incoming edge)
            self.parent_edge.backpropagate(-value)

class Edge:
    """
    Edges represent transitions between game states, initiated by an action
    Edges keep track of the number of visits, the estimated value of taking
    their associated action, and the prior probability of taking that action
    """
    def __init__(self, origin, p):
        self.origin = origin
        self.destination = Node(self)
        self.N = 0 # Visits
        self.Q = 0 # Value estimate
        self.P = p # NN-generated prior

    def backpropagate(self, value):
        self.N += 1
        self.Q = self.Q * ((self.N - 1) / self.N) + value / self.N
        
        #Alternating edges indicate alternating player actions, so the value of states will be inverted at each layer
        if self.origin.parent_edge is not None:
            self.origin.parent_edge.backpropagate(-value)

    def value_estimate(self, parent_trials):
        return self.Q + EXPLORATION_PARAM * self.P * parent_trials ** .5 / (1 + self.N)
    
    def __str__(self):
        return "N={}, Q={}, P={}".format(self.N, self.Q, self.P)
    
    def __repr__(self):
        return str(self)

In [15]:
from time import time
import logging

class AlphaMCTSAgent:
    def __init__(self, control_net = None):
        self.root = Node(None)
        self.playout_total = 0
        self.control_net = control_net
        
    def update_control_net(self, control_net):
        self.control_net = control_net
        
    def search(self, game, turn_num, allotted_playouts = 800):
        if self.control_net is None:
            raise ValueError("Control net must be set before starting searches")
        playouts = 0

        start_time = time()
        while playouts < allotted_playouts:
            current_game = game.copy()
            current_game, node = self._selection(current_game)

            actions = current_game.get_valid_moves()
            ps, v = self.control_net(current_game.state.unsqueeze(0))
            ps.squeeze_()
            v = v.squeeze().item()

            indices = [self.ttt_position_to_index(action) for action in actions]
            
            actions = [(round(current_game.state[2,0,0,0,0].item()),) + tuple(action) for action in actions]

            node.expansion(actions, ps.detach().numpy()[indices])
            
            if id(self.root) != id(node):
                node.backpropagate(v)

            playouts += 1

        temp = 1 if turn_num < 30 else .1
        max_action = self.root.choose_action(temp)
        
        # debug info
        self.playout_total = playouts
        self.action_node = max_action
        
        actions = game.get_valid_moves()
        indices = [self.ttt_position_to_index(action) for action in actions]
        actions = [(round(game.state[2,0,0,0,0].item()),) + tuple(action) for action in actions]
        
        mcts_probs = {indices[i]: self.root.edges[actions[i]].N for i in range(len(indices))}
        total_trials = sum(mcts_probs.values())
        mcts_probs = [mcts_probs.get(x, 0) / total_trials for x in range(81)]
        
        return max_action, mcts_probs
    
    def take_action(self, action):
        try:
            self.root = self.root.edges[action].destination
        except KeyError:
            print("Action not found, throwing away tree")
            self.reset()
            
    def reset(self):
        self.root = Node(None)

    def _selection(self, state):
        """
        Progress through the tree of Nodes, starting at the root, until a
        leaf is found or there are unexplored actions at the level we are
        exploring.

        Uses the UCT algorithm to determine which nodes to progress to.
        """
        node = self.root
        while not node.leaf:
            move = node.selection_step()
            node = node.edges[move].destination
            state.make_move(move)
        return state, node
    
    @staticmethod
    def index_to_ttt_position(idx):
        return idx // 27 % 3, idx // 9 % 3, idx // 3 % 3, idx % 3

    @staticmethod
    def ttt_position_to_index(position):
        return position[0] * 27 + position[1] * 9 + position[2] * 3 + position[3]

In [16]:
from collections import deque
import random

class ReplayBuffer:
    def __init__(self, capacity = 100000):
        self.buffer = deque(maxlen = capacity)
    
    def push(self, data):
        self.buffer.append(data)
        
    def extend(self, data):
        self.buffer.extend(data)
        
    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)
    
    def __len__(self):
        return len(self.buffer)

In [17]:
# One class that has access to:
  # Game simulation
  # Control network
  # Current network
  # Replay buffer
  # MCTS agent
  # Training code  -- Another process?

In [25]:
import torch
import random

def zero_gen():
    while True:
        yield 0

def one_neg_one_gen():
    while True:
        yield 1
        yield -1

class SelfPlayTrainer:
    def __init__(self, agent, game, buffer_file = None, weights_file = None, n_batches = 0):
        self.agent = agent
        self.game = game
        self.replay_buffer = ReplayBuffer()
        
        if buffer_file is not None:
            self.replay_buffer.buffer = pickle.load(open(buffer_file, "rb"))
        
        self.current_network = NestedTTTNet()
        self.control_network = NestedTTTNet()
        
        if weights_file is not None:
            self.control_network.load_state_dict(torch.load(weights_file))
        
        self.current_network.load_state_dict(self.control_network.state_dict())
        self.control_network.eval()
        self.current_network.train()
        
        self.agent.update_control_net(self.control_network)
        
        self.n_batches = n_batches
        
        self.optim = torch.optim.Adam(self.current_network.parameters(), lr = .01, weight_decay = 10e-4)
    
    def generate_self_play_data(self, n_games = 100):
        for _ in range(n_games):
            turn_num = 0
            self.game.reset()
            self.agent.reset()
            result = 0
            player_num = 0

            states = []
            move_vectors = []

            while len(self.game.get_valid_moves()) > 0:
                move, move_probs = self.agent.search(self.game.copy(), turn_num, allotted_playouts = 400)

                states.append(self.game.state.tolist())
                move_vectors.append(move_probs)

                result = self.game.make_move(move)
                if not result:
                    self.game.switch_player()
                    self.agent.take_action(move)
                    turn_num += 1
                    player_num = (player_num + 1) % 2

            if not result:
                self.replay_buffer.extend(list(zip(states, move_vectors, zero_gen())))
            else:
                self.replay_buffer.extend(list(zip(states[::-1], move_vectors[::-1], one_neg_one_gen()))[::-1])

    def compare_control_to_train(self):
        self.current_network.eval()
        old_agent = AlphaMCTSAgent(control_net = self.control_network)
        new_agent = AlphaMCTSAgent(control_net = self.current_network)
        
        agents = [old_agent, new_agent]
        
        wins = 0
        ties = 0
        
        game = self.game.copy()
        
        for game_num in range(100):
            game.reset()
            agents[0].reset()
            agents[1].reset()
            result = 0
            player_num = game_num // 50 #Both take first turn 50 times
            turn_num = 100 #Turn down the temperature
            
            while len(game.get_valid_moves()) > 0:
                move, _ = agents[player_num].search(game.copy(), turn_num, allotted_playouts = 800)
                _, _ = agents[1 - player_num].search(game.copy(), turn_num, allotted_playouts = 800)

                result = self.game.make_move(move)
                if not result:
                    game.switch_player()
                    agents[0].take_action(move)
                    agents[1].take_action(move)
                    player_num = (player_num + 1) % 2

            if not result:
                ties += 1
            elif result and current_player == 1:
                wins += 1
        
        if wins + .5 * ties >= 55:
            print("Challenger network won {} games and tied {} games; it becomes new control network".format(wins, ties))
            torch.save(self.current_network.state_dict(), "control_weights_{}.pth".format(self.n_batches))
            self.control_network.load_state_dict(self.current_network.state_dict())
        else:
            print("Challenger network not sufficiently better; {} wins and {} ties".format(wins, ties))
        
        self.control_network.eval()
        self.current_network.train()            
    
    def train_on_batch(self, batch_size = 32):
        if len(self.replay_buffer) < batch_size:
            return
        
        self.current_network.train()
        
        sample = self.replay_buffer.sample(batch_size)
        states, probs, rewards = zip(*sample)
        states = torch.FloatTensor(states).requires_grad_(True)
        probs = torch.FloatTensor(probs).requires_grad_(True)
        rewards = torch.FloatTensor(rewards).unsqueeze(1).requires_grad_(True)
        self.optim.zero_grad()
        
        ps, vs = trainer.current_network(states)
        
        loss = torch.nn.functional.mse_loss(vs, rewards) - (ps.log() * probs).sum()
        loss.backward()
        
        self.optim.step()
        
        self.n_batches += 1
        
        return loss.item()
    
    def run(self, total_runs = 10, self_play_games = 100, training_batches = 200, batch_size = 32):
        losses = []
        for run_num in range(1, total_runs+1):
            print("Run {} of {}".format(run_num, total_runs))
            for selfplay_num in range(1, self_play_games + 1):
                self.generate_self_play_data(1)
                print("\tFinished self-play game {} of {} (Buffer size {})".format(selfplay_num, self_play_games, len(self.replay_buffer)))
            print("Finished {} self-play games".format(self_play_games))
            for _ in range(training_batches):
                losses.append(self.train_on_batch(batch_size))
                if len(losses) == 5:
                    print("\tLoss for last 5 batches: {}".format(sum(losses)))
                    losses = []
        
        self.compare_control_to_train()

In [26]:
trainer = SelfPlayTrainer(AlphaMCTSAgent(), NestedTTT())

In [39]:
trainer.run(1, 0, 30000)

Run 1 of 1
Finished 0 self-play games
	Loss for last 5 batches: 619.5038223266602
	Loss for last 5 batches: 616.9755706787109
	Loss for last 5 batches: 632.480094909668
	Loss for last 5 batches: 627.7477569580078
	Loss for last 5 batches: 620.2631530761719
	Loss for last 5 batches: 620.7949600219727
	Loss for last 5 batches: 622.7703094482422
	Loss for last 5 batches: 614.2876358032227
	Loss for last 5 batches: 615.6514434814453
	Loss for last 5 batches: 622.0893859863281
	Loss for last 5 batches: 608.3180770874023
	Loss for last 5 batches: 635.6034622192383
	Loss for last 5 batches: 632.2602920532227
	Loss for last 5 batches: 625.8481369018555
	Loss for last 5 batches: 616.8213806152344
	Loss for last 5 batches: 622.3674163818359
	Loss for last 5 batches: 629.5902633666992
	Loss for last 5 batches: 622.5725173950195
	Loss for last 5 batches: 633.1866683959961
	Loss for last 5 batches: 623.2088623046875
	Loss for last 5 batches: 619.0293807983398
	Loss for last 5 batches: 623.910507202

	Loss for last 5 batches: 625.5573425292969
	Loss for last 5 batches: 631.7596054077148
	Loss for last 5 batches: 618.2097930908203
	Loss for last 5 batches: 617.1615829467773
	Loss for last 5 batches: 636.5772552490234
	Loss for last 5 batches: 639.4394226074219
	Loss for last 5 batches: 625.2365188598633
	Loss for last 5 batches: 624.9333419799805
	Loss for last 5 batches: 624.4702377319336
	Loss for last 5 batches: 626.2000045776367
	Loss for last 5 batches: 613.6240615844727
	Loss for last 5 batches: 613.2272644042969
	Loss for last 5 batches: 627.3467102050781
	Loss for last 5 batches: 618.4839553833008
	Loss for last 5 batches: 624.8309860229492
	Loss for last 5 batches: 633.0959396362305
	Loss for last 5 batches: 614.5800476074219
	Loss for last 5 batches: 610.8126831054688
	Loss for last 5 batches: 619.8279190063477
	Loss for last 5 batches: 624.5617599487305
	Loss for last 5 batches: 614.26171875
	Loss for last 5 batches: 623.0623550415039
	Loss for last 5 batches: 632.0245895

	Loss for last 5 batches: 624.1297607421875
	Loss for last 5 batches: 616.0171203613281
	Loss for last 5 batches: 621.5968933105469
	Loss for last 5 batches: 613.0711898803711
	Loss for last 5 batches: 633.8810195922852
	Loss for last 5 batches: 622.2598495483398
	Loss for last 5 batches: 620.2450561523438
	Loss for last 5 batches: 619.9661178588867
	Loss for last 5 batches: 614.572395324707
	Loss for last 5 batches: 630.0176620483398
	Loss for last 5 batches: 631.5188674926758
	Loss for last 5 batches: 620.9067153930664
	Loss for last 5 batches: 610.4141693115234
	Loss for last 5 batches: 620.3015060424805
	Loss for last 5 batches: 630.3010711669922
	Loss for last 5 batches: 608.8565444946289
	Loss for last 5 batches: 629.8126602172852
	Loss for last 5 batches: 617.3827514648438
	Loss for last 5 batches: 630.1281509399414
	Loss for last 5 batches: 630.411247253418
	Loss for last 5 batches: 628.6357879638672
	Loss for last 5 batches: 628.7556228637695
	Loss for last 5 batches: 632.1241

	Loss for last 5 batches: 623.7795257568359
	Loss for last 5 batches: 629.0427703857422
	Loss for last 5 batches: 615.5578460693359
	Loss for last 5 batches: 624.552978515625
	Loss for last 5 batches: 622.757438659668
	Loss for last 5 batches: 619.0261306762695
	Loss for last 5 batches: 622.6587219238281
	Loss for last 5 batches: 621.1754150390625
	Loss for last 5 batches: 622.6700973510742
	Loss for last 5 batches: 617.5886840820312
	Loss for last 5 batches: 625.8371810913086
	Loss for last 5 batches: 616.8326721191406
	Loss for last 5 batches: 622.8044281005859
	Loss for last 5 batches: 621.2694854736328
	Loss for last 5 batches: 617.2003479003906
	Loss for last 5 batches: 617.9424362182617
	Loss for last 5 batches: 616.7885131835938
	Loss for last 5 batches: 624.3287887573242
	Loss for last 5 batches: 634.9415740966797
	Loss for last 5 batches: 613.8180618286133
	Loss for last 5 batches: 613.6225814819336
	Loss for last 5 batches: 611.1263275146484
	Loss for last 5 batches: 625.8172

	Loss for last 5 batches: 627.423469543457
	Loss for last 5 batches: 629.3079452514648
	Loss for last 5 batches: 622.072151184082
	Loss for last 5 batches: 623.7713241577148
	Loss for last 5 batches: 624.9899215698242
	Loss for last 5 batches: 611.3786849975586
	Loss for last 5 batches: 611.5485305786133
	Loss for last 5 batches: 617.7333450317383
	Loss for last 5 batches: 616.8922882080078
	Loss for last 5 batches: 620.4868545532227
	Loss for last 5 batches: 618.9091033935547
	Loss for last 5 batches: 626.1132125854492
	Loss for last 5 batches: 615.2046432495117
	Loss for last 5 batches: 623.955078125
	Loss for last 5 batches: 614.5669403076172
	Loss for last 5 batches: 617.7656631469727
	Loss for last 5 batches: 621.6399459838867
	Loss for last 5 batches: 606.0004196166992
	Loss for last 5 batches: 636.6131439208984
	Loss for last 5 batches: 615.8760833740234
	Loss for last 5 batches: 633.8162612915039
	Loss for last 5 batches: 616.8188705444336
	Loss for last 5 batches: 620.78649902

	Loss for last 5 batches: 634.5682754516602
	Loss for last 5 batches: 625.2084426879883
	Loss for last 5 batches: 624.8352966308594
	Loss for last 5 batches: 633.5164566040039
	Loss for last 5 batches: 627.1563110351562
	Loss for last 5 batches: 625.7212829589844
	Loss for last 5 batches: 616.5391235351562
	Loss for last 5 batches: 623.6807327270508
	Loss for last 5 batches: 614.4197006225586
	Loss for last 5 batches: 612.2201690673828
	Loss for last 5 batches: 620.3474960327148
	Loss for last 5 batches: 626.0092010498047
	Loss for last 5 batches: 621.0910949707031
	Loss for last 5 batches: 627.101203918457
	Loss for last 5 batches: 624.9694671630859
	Loss for last 5 batches: 626.4357452392578
	Loss for last 5 batches: 633.9826583862305
	Loss for last 5 batches: 619.7064666748047
	Loss for last 5 batches: 619.5982666015625
	Loss for last 5 batches: 620.8780670166016
	Loss for last 5 batches: 620.6679916381836
	Loss for last 5 batches: 631.8873443603516
	Loss for last 5 batches: 618.228

	Loss for last 5 batches: 620.349494934082
	Loss for last 5 batches: 626.0971527099609
	Loss for last 5 batches: 612.6644134521484
	Loss for last 5 batches: 629.4139633178711
	Loss for last 5 batches: 621.9611892700195
	Loss for last 5 batches: 613.1655578613281
	Loss for last 5 batches: 634.1270599365234
	Loss for last 5 batches: 619.5074462890625
	Loss for last 5 batches: 613.7618713378906
	Loss for last 5 batches: 626.1775894165039
	Loss for last 5 batches: 618.6329116821289
	Loss for last 5 batches: 616.0005416870117
	Loss for last 5 batches: 617.5896530151367
	Loss for last 5 batches: 619.1653747558594
	Loss for last 5 batches: 608.8685760498047
	Loss for last 5 batches: 624.8856658935547
	Loss for last 5 batches: 611.1333236694336
	Loss for last 5 batches: 613.7142105102539
	Loss for last 5 batches: 629.0030059814453
	Loss for last 5 batches: 616.2134704589844
	Loss for last 5 batches: 624.1238021850586
	Loss for last 5 batches: 629.4047012329102
	Loss for last 5 batches: 617.879

	Loss for last 5 batches: 619.3973083496094
	Loss for last 5 batches: 616.1474151611328
	Loss for last 5 batches: 606.0713882446289
	Loss for last 5 batches: 618.159065246582
	Loss for last 5 batches: 626.4625473022461
	Loss for last 5 batches: 620.4588851928711
	Loss for last 5 batches: 616.0340728759766
	Loss for last 5 batches: 623.0324783325195
	Loss for last 5 batches: 625.2618179321289
	Loss for last 5 batches: 621.6729965209961
	Loss for last 5 batches: 608.9871292114258
	Loss for last 5 batches: 626.1924896240234
	Loss for last 5 batches: 614.9895477294922
	Loss for last 5 batches: 624.6957321166992
	Loss for last 5 batches: 600.1023712158203
	Loss for last 5 batches: 633.1692199707031
	Loss for last 5 batches: 613.7763748168945
	Loss for last 5 batches: 624.0037612915039
	Loss for last 5 batches: 631.6962661743164
	Loss for last 5 batches: 619.452522277832
	Loss for last 5 batches: 627.6825942993164
	Loss for last 5 batches: 617.7007675170898
	Loss for last 5 batches: 641.3156

	Loss for last 5 batches: 619.0522842407227
	Loss for last 5 batches: 618.1801376342773
	Loss for last 5 batches: 626.8066940307617
	Loss for last 5 batches: 624.3341827392578
	Loss for last 5 batches: 613.9219741821289
	Loss for last 5 batches: 629.49267578125
	Loss for last 5 batches: 617.5326843261719
	Loss for last 5 batches: 611.8031005859375
	Loss for last 5 batches: 620.140266418457
	Loss for last 5 batches: 623.2110824584961
	Loss for last 5 batches: 619.5630874633789
	Loss for last 5 batches: 616.0394744873047
	Loss for last 5 batches: 631.2383880615234
	Loss for last 5 batches: 621.9070663452148
	Loss for last 5 batches: 619.5241088867188
	Loss for last 5 batches: 632.74853515625
	Loss for last 5 batches: 620.2431564331055
	Loss for last 5 batches: 622.5754089355469
	Loss for last 5 batches: 631.213493347168
	Loss for last 5 batches: 614.7797241210938
	Loss for last 5 batches: 618.5107345581055
	Loss for last 5 batches: 622.1021499633789
	Loss for last 5 batches: 643.69706726

	Loss for last 5 batches: 626.7633743286133
	Loss for last 5 batches: 616.6999588012695
	Loss for last 5 batches: 620.5741348266602
	Loss for last 5 batches: 625.859977722168
	Loss for last 5 batches: 603.0045394897461
	Loss for last 5 batches: 622.7180252075195
	Loss for last 5 batches: 624.3259353637695
	Loss for last 5 batches: 637.4766387939453
	Loss for last 5 batches: 622.2389602661133
	Loss for last 5 batches: 618.4001617431641
	Loss for last 5 batches: 627.7022705078125
	Loss for last 5 batches: 621.2270355224609
	Loss for last 5 batches: 628.3573608398438
	Loss for last 5 batches: 631.397216796875
	Loss for last 5 batches: 620.5634155273438
	Loss for last 5 batches: 625.0666732788086
	Loss for last 5 batches: 623.4162902832031
	Loss for last 5 batches: 629.9369201660156
	Loss for last 5 batches: 620.5973587036133
	Loss for last 5 batches: 621.2229309082031
	Loss for last 5 batches: 624.5205764770508
	Loss for last 5 batches: 617.2468185424805
	Loss for last 5 batches: 619.0456

	Loss for last 5 batches: 601.4698715209961
	Loss for last 5 batches: 619.5819244384766
	Loss for last 5 batches: 625.4755325317383
	Loss for last 5 batches: 625.9417953491211
	Loss for last 5 batches: 624.0665435791016
	Loss for last 5 batches: 632.4534454345703
	Loss for last 5 batches: 625.1736602783203
	Loss for last 5 batches: 625.2648620605469
	Loss for last 5 batches: 627.9201126098633
	Loss for last 5 batches: 620.2826385498047
	Loss for last 5 batches: 626.8454132080078
	Loss for last 5 batches: 609.6402053833008
	Loss for last 5 batches: 629.5185317993164
	Loss for last 5 batches: 624.5886001586914
	Loss for last 5 batches: 633.5375366210938
	Loss for last 5 batches: 603.4433898925781
	Loss for last 5 batches: 624.7780914306641
	Loss for last 5 batches: 621.7499084472656
	Loss for last 5 batches: 626.055061340332
	Loss for last 5 batches: 623.6078414916992
	Loss for last 5 batches: 616.2292785644531
	Loss for last 5 batches: 633.9651641845703
	Loss for last 5 batches: 628.242

	Loss for last 5 batches: 611.9560241699219
	Loss for last 5 batches: 624.9466018676758
	Loss for last 5 batches: 611.6551361083984
	Loss for last 5 batches: 616.6437606811523
	Loss for last 5 batches: 624.5307846069336
	Loss for last 5 batches: 611.383186340332
	Loss for last 5 batches: 619.283088684082
	Loss for last 5 batches: 627.2478408813477
	Loss for last 5 batches: 599.5911102294922
	Loss for last 5 batches: 620.3711395263672
	Loss for last 5 batches: 639.0061645507812
	Loss for last 5 batches: 619.8637390136719
	Loss for last 5 batches: 629.5639801025391
	Loss for last 5 batches: 616.5596389770508
	Loss for last 5 batches: 624.5702209472656
	Loss for last 5 batches: 620.8977584838867
	Loss for last 5 batches: 611.2435684204102
	Loss for last 5 batches: 620.564208984375
	Loss for last 5 batches: 607.4180145263672
	Loss for last 5 batches: 620.0914840698242
	Loss for last 5 batches: 610.9449157714844
	Loss for last 5 batches: 609.3451995849609
	Loss for last 5 batches: 628.18248

	Loss for last 5 batches: 618.5159378051758
	Loss for last 5 batches: 628.1734313964844
	Loss for last 5 batches: 620.629280090332
	Loss for last 5 batches: 623.1659164428711
	Loss for last 5 batches: 617.6929702758789
	Loss for last 5 batches: 612.3947067260742
	Loss for last 5 batches: 614.8458557128906
	Loss for last 5 batches: 612.0023956298828
	Loss for last 5 batches: 614.6522598266602
	Loss for last 5 batches: 623.4795837402344
	Loss for last 5 batches: 627.4942779541016
	Loss for last 5 batches: 623.4455871582031
	Loss for last 5 batches: 623.0117797851562
	Loss for last 5 batches: 624.4417114257812
	Loss for last 5 batches: 620.7941436767578
	Loss for last 5 batches: 620.2859954833984
	Loss for last 5 batches: 613.7908401489258
	Loss for last 5 batches: 632.1869354248047
	Loss for last 5 batches: 608.6446914672852
	Loss for last 5 batches: 618.1867218017578
	Loss for last 5 batches: 602.2728805541992
	Loss for last 5 batches: 618.2952117919922
	Loss for last 5 batches: 610.960

	Loss for last 5 batches: 607.4837646484375
	Loss for last 5 batches: 626.4439468383789
	Loss for last 5 batches: 613.4609527587891
	Loss for last 5 batches: 609.7420120239258
	Loss for last 5 batches: 610.5840759277344
	Loss for last 5 batches: 604.8994750976562
	Loss for last 5 batches: 623.6975173950195
	Loss for last 5 batches: 623.8045272827148
	Loss for last 5 batches: 628.2409362792969
	Loss for last 5 batches: 621.7820053100586
	Loss for last 5 batches: 622.9592590332031
	Loss for last 5 batches: 624.7495651245117
	Loss for last 5 batches: 614.2679138183594
	Loss for last 5 batches: 613.7499313354492
	Loss for last 5 batches: 620.4616470336914
	Loss for last 5 batches: 629.7053298950195
	Loss for last 5 batches: 622.7777099609375
	Loss for last 5 batches: 607.2471160888672
	Loss for last 5 batches: 620.7845153808594
	Loss for last 5 batches: 611.611442565918
	Loss for last 5 batches: 620.1641693115234
	Loss for last 5 batches: 623.8096694946289
	Loss for last 5 batches: 625.388

	Loss for last 5 batches: 624.5940322875977
	Loss for last 5 batches: 614.1932144165039
	Loss for last 5 batches: 614.3256225585938
	Loss for last 5 batches: 632.2599563598633
	Loss for last 5 batches: 628.3733291625977
	Loss for last 5 batches: 618.5448913574219
	Loss for last 5 batches: 618.5103988647461
	Loss for last 5 batches: 626.9700622558594
	Loss for last 5 batches: 620.9442138671875
	Loss for last 5 batches: 619.6933822631836
	Loss for last 5 batches: 618.7152709960938
	Loss for last 5 batches: 616.6642150878906
	Loss for last 5 batches: 629.2242202758789
	Loss for last 5 batches: 628.4526672363281
	Loss for last 5 batches: 624.7635116577148
	Loss for last 5 batches: 612.6932678222656
	Loss for last 5 batches: 617.0099639892578
	Loss for last 5 batches: 634.0630416870117
	Loss for last 5 batches: 623.9814224243164
	Loss for last 5 batches: 632.2470932006836
	Loss for last 5 batches: 633.2116012573242
	Loss for last 5 batches: 613.5600280761719
	Loss for last 5 batches: 632.17

	Loss for last 5 batches: 622.8884048461914
	Loss for last 5 batches: 623.2931518554688
	Loss for last 5 batches: 605.6925506591797
	Loss for last 5 batches: 624.6440124511719
	Loss for last 5 batches: 622.1975936889648
	Loss for last 5 batches: 621.8024520874023
	Loss for last 5 batches: 611.9894485473633
	Loss for last 5 batches: 624.6884613037109
	Loss for last 5 batches: 616.0556945800781
	Loss for last 5 batches: 617.8135299682617
	Loss for last 5 batches: 617.9010391235352
	Loss for last 5 batches: 617.1594619750977
	Loss for last 5 batches: 617.1243515014648
	Loss for last 5 batches: 634.2785263061523
	Loss for last 5 batches: 622.6867904663086
	Loss for last 5 batches: 629.4927062988281
	Loss for last 5 batches: 626.9230651855469
	Loss for last 5 batches: 619.0762786865234
	Loss for last 5 batches: 637.447265625
	Loss for last 5 batches: 618.8758773803711
	Loss for last 5 batches: 623.5746383666992
	Loss for last 5 batches: 630.234375
	Loss for last 5 batches: 616.7877578735352

	Loss for last 5 batches: 622.6517715454102
	Loss for last 5 batches: 624.0445327758789
	Loss for last 5 batches: 618.6636428833008
	Loss for last 5 batches: 639.0219116210938
	Loss for last 5 batches: 625.462272644043
	Loss for last 5 batches: 629.5547256469727
	Loss for last 5 batches: 613.9684677124023
	Loss for last 5 batches: 636.4607696533203
	Loss for last 5 batches: 630.0886306762695
	Loss for last 5 batches: 616.799201965332
	Loss for last 5 batches: 616.3177871704102
	Loss for last 5 batches: 626.7265853881836
	Loss for last 5 batches: 609.3867797851562
	Loss for last 5 batches: 620.2155914306641
	Loss for last 5 batches: 629.2334747314453
	Loss for last 5 batches: 626.3018188476562
	Loss for last 5 batches: 625.4251327514648
	Loss for last 5 batches: 619.2043075561523
	Loss for last 5 batches: 624.2146453857422
	Loss for last 5 batches: 636.6684875488281
	Loss for last 5 batches: 620.2276840209961
	Loss for last 5 batches: 614.8795623779297
	Loss for last 5 batches: 617.0391

	Loss for last 5 batches: 611.7064819335938
	Loss for last 5 batches: 609.3132171630859
	Loss for last 5 batches: 617.5491561889648
	Loss for last 5 batches: 623.639778137207
	Loss for last 5 batches: 622.1807861328125
	Loss for last 5 batches: 627.7609786987305
	Loss for last 5 batches: 612.9428176879883
	Loss for last 5 batches: 629.5691986083984
	Loss for last 5 batches: 629.4510269165039
	Loss for last 5 batches: 622.4982986450195
	Loss for last 5 batches: 625.0761795043945
	Loss for last 5 batches: 611.1005783081055
	Loss for last 5 batches: 619.9580230712891
	Loss for last 5 batches: 624.4959182739258
	Loss for last 5 batches: 629.5687713623047
	Loss for last 5 batches: 620.167366027832
	Loss for last 5 batches: 630.4694137573242
	Loss for last 5 batches: 619.6136779785156
	Loss for last 5 batches: 630.1640014648438
	Loss for last 5 batches: 625.8111801147461
	Loss for last 5 batches: 626.9584121704102
	Loss for last 5 batches: 620.951545715332
	Loss for last 5 batches: 620.99562

	Loss for last 5 batches: 635.6028823852539
	Loss for last 5 batches: 614.7853088378906
	Loss for last 5 batches: 628.0316009521484
	Loss for last 5 batches: 617.2824172973633
	Loss for last 5 batches: 604.6794586181641
	Loss for last 5 batches: 618.5458526611328
	Loss for last 5 batches: 615.8666763305664
	Loss for last 5 batches: 628.3020629882812
	Loss for last 5 batches: 625.3618545532227
	Loss for last 5 batches: 616.2127838134766
	Loss for last 5 batches: 609.1052551269531
	Loss for last 5 batches: 608.1830139160156
	Loss for last 5 batches: 625.7568054199219
	Loss for last 5 batches: 606.7919921875
	Loss for last 5 batches: 616.9694671630859
	Loss for last 5 batches: 622.1424179077148
	Loss for last 5 batches: 619.1458892822266
	Loss for last 5 batches: 621.8493194580078
	Loss for last 5 batches: 629.1564331054688
	Loss for last 5 batches: 624.1336212158203
	Loss for last 5 batches: 625.5676803588867
	Loss for last 5 batches: 630.8153915405273
	Loss for last 5 batches: 629.51750

	Loss for last 5 batches: 613.5635833740234
	Loss for last 5 batches: 619.4406204223633
	Loss for last 5 batches: 625.2646179199219
	Loss for last 5 batches: 601.4146728515625
	Loss for last 5 batches: 618.5783157348633
	Loss for last 5 batches: 615.7856826782227
	Loss for last 5 batches: 619.4009017944336
	Loss for last 5 batches: 624.3568954467773
	Loss for last 5 batches: 627.0924224853516
	Loss for last 5 batches: 612.6870651245117
	Loss for last 5 batches: 621.7621383666992
	Loss for last 5 batches: 620.1903228759766
	Loss for last 5 batches: 609.0563049316406
	Loss for last 5 batches: 636.7998428344727
	Loss for last 5 batches: 608.9538116455078
	Loss for last 5 batches: 611.1245040893555
	Loss for last 5 batches: 614.9663314819336
	Loss for last 5 batches: 608.4625015258789
	Loss for last 5 batches: 619.968994140625
	Loss for last 5 batches: 616.6438674926758
	Loss for last 5 batches: 631.4749145507812
	Loss for last 5 batches: 603.6104049682617
	Loss for last 5 batches: 621.015

	Loss for last 5 batches: 631.0073013305664
	Loss for last 5 batches: 614.9323120117188
	Loss for last 5 batches: 608.3475341796875
	Loss for last 5 batches: 630.274787902832
	Loss for last 5 batches: 618.3491134643555
	Loss for last 5 batches: 636.63623046875
	Loss for last 5 batches: 629.1388168334961
	Loss for last 5 batches: 621.0441513061523
	Loss for last 5 batches: 614.6122055053711
	Loss for last 5 batches: 626.6187515258789
	Loss for last 5 batches: 637.4892959594727
	Loss for last 5 batches: 627.8242263793945
	Loss for last 5 batches: 613.4682464599609
	Loss for last 5 batches: 622.1634826660156
	Loss for last 5 batches: 624.210205078125
	Loss for last 5 batches: 623.8376007080078
	Loss for last 5 batches: 623.9069061279297
	Loss for last 5 batches: 619.5625228881836
	Loss for last 5 batches: 617.368522644043
	Loss for last 5 batches: 618.92919921875
	Loss for last 5 batches: 627.7644958496094
	Loss for last 5 batches: 607.8791732788086
	Loss for last 5 batches: 623.910812377

	Loss for last 5 batches: 633.8552093505859
	Loss for last 5 batches: 629.938117980957
	Loss for last 5 batches: 620.6337432861328
	Loss for last 5 batches: 619.7459945678711
	Loss for last 5 batches: 631.8469543457031
	Loss for last 5 batches: 620.5527725219727
	Loss for last 5 batches: 629.4309768676758
	Loss for last 5 batches: 622.6030960083008
	Loss for last 5 batches: 628.1426239013672
	Loss for last 5 batches: 628.3809204101562
	Loss for last 5 batches: 616.5539703369141
	Loss for last 5 batches: 614.9721755981445
	Loss for last 5 batches: 603.5353927612305
	Loss for last 5 batches: 610.8718338012695
	Loss for last 5 batches: 629.4878463745117
	Loss for last 5 batches: 624.4727630615234
	Loss for last 5 batches: 619.1972198486328
	Loss for last 5 batches: 629.9946212768555
	Loss for last 5 batches: 635.719108581543
	Loss for last 5 batches: 629.3201904296875
	Loss for last 5 batches: 625.0456619262695
	Loss for last 5 batches: 625.3467483520508
	Loss for last 5 batches: 619.3956

	Loss for last 5 batches: 623.7795562744141
	Loss for last 5 batches: 635.6050186157227
	Loss for last 5 batches: 625.723030090332
	Loss for last 5 batches: 634.4549865722656
	Loss for last 5 batches: 623.6157684326172
	Loss for last 5 batches: 623.3015060424805
	Loss for last 5 batches: 619.1212539672852
	Loss for last 5 batches: 623.9802398681641
	Loss for last 5 batches: 619.680793762207
	Loss for last 5 batches: 628.6558837890625
	Loss for last 5 batches: 616.9484481811523
	Loss for last 5 batches: 615.0107192993164
	Loss for last 5 batches: 617.787956237793
	Loss for last 5 batches: 625.9780120849609
	Loss for last 5 batches: 621.3031768798828
	Loss for last 5 batches: 618.0880813598633
	Loss for last 5 batches: 612.1628875732422
	Loss for last 5 batches: 624.2603912353516
	Loss for last 5 batches: 623.7531127929688
	Loss for last 5 batches: 619.2780990600586
	Loss for last 5 batches: 614.6976470947266
	Loss for last 5 batches: 619.4861831665039
	Loss for last 5 batches: 612.49548

	Loss for last 5 batches: 627.7197570800781
	Loss for last 5 batches: 618.3731155395508
	Loss for last 5 batches: 630.8525619506836
	Loss for last 5 batches: 622.4410629272461
	Loss for last 5 batches: 623.2313613891602
	Loss for last 5 batches: 620.2875137329102
	Loss for last 5 batches: 610.6579895019531
	Loss for last 5 batches: 615.8878402709961
	Loss for last 5 batches: 625.3491134643555
	Loss for last 5 batches: 616.3301544189453
	Loss for last 5 batches: 616.8374633789062
	Loss for last 5 batches: 624.3797760009766
	Loss for last 5 batches: 620.2088775634766
	Loss for last 5 batches: 620.0550003051758
	Loss for last 5 batches: 607.1961669921875
	Loss for last 5 batches: 615.7392196655273
	Loss for last 5 batches: 635.1636123657227
	Loss for last 5 batches: 622.4784622192383
	Loss for last 5 batches: 631.5393142700195
	Loss for last 5 batches: 620.6067428588867
	Loss for last 5 batches: 615.362419128418
	Loss for last 5 batches: 621.4527435302734
	Loss for last 5 batches: 618.631

	Loss for last 5 batches: 618.2972564697266
	Loss for last 5 batches: 611.7771301269531
	Loss for last 5 batches: 621.5517425537109
	Loss for last 5 batches: 619.3676681518555
	Loss for last 5 batches: 617.5492935180664
	Loss for last 5 batches: 613.3819122314453
	Loss for last 5 batches: 627.043083190918
	Loss for last 5 batches: 619.8581314086914
	Loss for last 5 batches: 635.7582702636719
	Loss for last 5 batches: 624.4783401489258
	Loss for last 5 batches: 623.459587097168
	Loss for last 5 batches: 632.4863662719727
	Loss for last 5 batches: 624.3144607543945
	Loss for last 5 batches: 621.1538391113281
	Loss for last 5 batches: 621.2056503295898
	Loss for last 5 batches: 623.0418548583984
	Loss for last 5 batches: 624.4502182006836
	Loss for last 5 batches: 625.7138137817383
	Loss for last 5 batches: 620.4660873413086
	Loss for last 5 batches: 607.1304550170898
	Loss for last 5 batches: 624.8083343505859
	Loss for last 5 batches: 627.288215637207
	Loss for last 5 batches: 609.36023

	Loss for last 5 batches: 614.0530090332031
	Loss for last 5 batches: 621.7968978881836
	Loss for last 5 batches: 611.5879135131836
	Loss for last 5 batches: 628.6121444702148
	Loss for last 5 batches: 621.8645248413086
	Loss for last 5 batches: 615.2507781982422
	Loss for last 5 batches: 623.6416702270508
	Loss for last 5 batches: 626.2879486083984
	Loss for last 5 batches: 625.945556640625
	Loss for last 5 batches: 610.6781692504883
	Loss for last 5 batches: 617.9212188720703
	Loss for last 5 batches: 615.0138549804688
	Loss for last 5 batches: 609.1622695922852
	Loss for last 5 batches: 620.2449951171875
	Loss for last 5 batches: 620.2008285522461
	Loss for last 5 batches: 626.1098327636719
	Loss for last 5 batches: 624.3063735961914
	Loss for last 5 batches: 616.8806838989258
	Loss for last 5 batches: 623.6520843505859
	Loss for last 5 batches: 618.5977401733398
	Loss for last 5 batches: 609.4058990478516
	Loss for last 5 batches: 628.2525939941406
	Loss for last 5 batches: 631.159

	Loss for last 5 batches: 616.8961334228516
	Loss for last 5 batches: 617.4299087524414
	Loss for last 5 batches: 614.2913055419922
	Loss for last 5 batches: 623.5818252563477
	Loss for last 5 batches: 634.1767730712891
	Loss for last 5 batches: 626.469612121582
	Loss for last 5 batches: 620.573860168457
	Loss for last 5 batches: 624.1410064697266
	Loss for last 5 batches: 623.0407180786133
	Loss for last 5 batches: 616.77294921875
	Loss for last 5 batches: 635.7318954467773
	Loss for last 5 batches: 617.2989501953125
	Loss for last 5 batches: 616.7073974609375
	Loss for last 5 batches: 635.9307174682617
	Loss for last 5 batches: 639.7347717285156
	Loss for last 5 batches: 632.4867095947266
	Loss for last 5 batches: 621.3542098999023
	Loss for last 5 batches: 623.07421875
	Loss for last 5 batches: 626.7631149291992
	Loss for last 5 batches: 626.8816528320312
	Loss for last 5 batches: 623.0488204956055
	Loss for last 5 batches: 631.3950958251953
	Loss for last 5 batches: 628.08564758300

	Loss for last 5 batches: 615.4293899536133
	Loss for last 5 batches: 636.7816772460938
	Loss for last 5 batches: 620.8553848266602
	Loss for last 5 batches: 629.9621047973633
	Loss for last 5 batches: 618.5258865356445
	Loss for last 5 batches: 628.1450042724609
	Loss for last 5 batches: 621.6420516967773
	Loss for last 5 batches: 628.0823211669922
	Loss for last 5 batches: 610.9947280883789
	Loss for last 5 batches: 620.801643371582
	Loss for last 5 batches: 612.0737609863281
	Loss for last 5 batches: 615.3076934814453
	Loss for last 5 batches: 616.5887832641602
	Loss for last 5 batches: 632.6037521362305
	Loss for last 5 batches: 618.0727310180664
	Loss for last 5 batches: 629.3633270263672
	Loss for last 5 batches: 617.0744018554688
	Loss for last 5 batches: 631.0284423828125
	Loss for last 5 batches: 614.0686264038086
	Loss for last 5 batches: 618.8810043334961
	Loss for last 5 batches: 633.4111099243164
	Loss for last 5 batches: 624.0230484008789
	Loss for last 5 batches: 620.426

	Loss for last 5 batches: 617.3535537719727
	Loss for last 5 batches: 619.0312881469727
	Loss for last 5 batches: 622.4192047119141
	Loss for last 5 batches: 610.3487777709961
	Loss for last 5 batches: 630.833122253418
	Loss for last 5 batches: 606.3819122314453
	Loss for last 5 batches: 624.4832000732422
	Loss for last 5 batches: 617.8890075683594
	Loss for last 5 batches: 627.7072296142578
	Loss for last 5 batches: 632.2181549072266
	Loss for last 5 batches: 622.8483963012695
	Loss for last 5 batches: 618.7981262207031
	Loss for last 5 batches: 623.0785980224609
	Loss for last 5 batches: 616.0032272338867
	Loss for last 5 batches: 634.5557250976562
	Loss for last 5 batches: 620.8296890258789
	Loss for last 5 batches: 621.9666748046875
	Loss for last 5 batches: 632.6778411865234
	Loss for last 5 batches: 600.8701477050781
	Loss for last 5 batches: 613.0245971679688
	Loss for last 5 batches: 614.9880828857422
	Loss for last 5 batches: 628.003662109375
	Loss for last 5 batches: 636.7044

	Loss for last 5 batches: 608.7450103759766
	Loss for last 5 batches: 637.368766784668
	Loss for last 5 batches: 625.0664291381836
	Loss for last 5 batches: 619.4346237182617
	Loss for last 5 batches: 619.2102966308594
	Loss for last 5 batches: 623.0515365600586
	Loss for last 5 batches: 605.7835998535156
	Loss for last 5 batches: 613.8307800292969
	Loss for last 5 batches: 632.7296142578125
	Loss for last 5 batches: 625.328254699707
	Loss for last 5 batches: 615.8970718383789
	Loss for last 5 batches: 622.2962265014648
	Loss for last 5 batches: 626.3149948120117
	Loss for last 5 batches: 626.131965637207
	Loss for last 5 batches: 622.0307083129883
	Loss for last 5 batches: 622.1618576049805
	Loss for last 5 batches: 606.5009536743164
	Loss for last 5 batches: 617.7463531494141
	Loss for last 5 batches: 618.5802154541016
	Loss for last 5 batches: 622.0924453735352
	Loss for last 5 batches: 629.7820053100586
	Loss for last 5 batches: 635.4334411621094
	Loss for last 5 batches: 634.10060

	Loss for last 5 batches: 615.483024597168
	Loss for last 5 batches: 611.516227722168
	Loss for last 5 batches: 614.5423431396484
	Loss for last 5 batches: 620.942741394043
	Loss for last 5 batches: 623.6966018676758
	Loss for last 5 batches: 620.1510772705078
	Loss for last 5 batches: 623.6224594116211
	Loss for last 5 batches: 622.6757431030273
	Loss for last 5 batches: 616.5242080688477
	Loss for last 5 batches: 621.0811996459961
	Loss for last 5 batches: 623.4852905273438
	Loss for last 5 batches: 629.2010040283203
	Loss for last 5 batches: 623.7059097290039
	Loss for last 5 batches: 624.5064239501953
	Loss for last 5 batches: 609.7808380126953
	Loss for last 5 batches: 611.2548675537109
	Loss for last 5 batches: 624.4407348632812
	Loss for last 5 batches: 623.5468978881836
	Loss for last 5 batches: 620.4817733764648
	Loss for last 5 batches: 628.9982833862305
	Loss for last 5 batches: 616.9763031005859
	Loss for last 5 batches: 626.2640228271484
	Loss for last 5 batches: 625.49306

	Loss for last 5 batches: 622.6270599365234
	Loss for last 5 batches: 622.977294921875
	Loss for last 5 batches: 624.8667755126953
	Loss for last 5 batches: 629.8073196411133
	Loss for last 5 batches: 622.1313552856445
	Loss for last 5 batches: 629.111701965332
	Loss for last 5 batches: 616.0930480957031
	Loss for last 5 batches: 615.214973449707
	Loss for last 5 batches: 620.763542175293
	Loss for last 5 batches: 630.2136001586914
	Loss for last 5 batches: 631.9077072143555
	Loss for last 5 batches: 619.8510208129883
	Loss for last 5 batches: 615.0225601196289
	Loss for last 5 batches: 616.9824447631836
	Loss for last 5 batches: 611.2778930664062
	Loss for last 5 batches: 604.296875
	Loss for last 5 batches: 623.7285614013672
	Loss for last 5 batches: 599.1784515380859
	Loss for last 5 batches: 614.7247085571289
	Loss for last 5 batches: 607.0470123291016
	Loss for last 5 batches: 627.0322494506836
	Loss for last 5 batches: 612.1145477294922
	Loss for last 5 batches: 625.1609573364258

	Loss for last 5 batches: 612.8546905517578
	Loss for last 5 batches: 603.4023513793945
	Loss for last 5 batches: 619.3019943237305
	Loss for last 5 batches: 619.0001525878906
	Loss for last 5 batches: 613.969612121582
	Loss for last 5 batches: 613.8365097045898
	Loss for last 5 batches: 627.5362777709961
	Loss for last 5 batches: 617.7236404418945
	Loss for last 5 batches: 623.6233901977539
	Loss for last 5 batches: 614.5812911987305
	Loss for last 5 batches: 599.6281127929688
	Loss for last 5 batches: 626.7739334106445
	Loss for last 5 batches: 609.1735458374023
	Loss for last 5 batches: 618.7493515014648
	Loss for last 5 batches: 621.2779159545898
	Loss for last 5 batches: 631.0158538818359
	Loss for last 5 batches: 609.2898712158203


In [28]:
torch.save(trainer.current_network.state_dict(), "latest_weights.pth")

In [29]:
import pickle
len(trainer.replay_buffer)

100000

In [40]:
trainer.current_network.eval()
old_agent = AlphaMCTSAgent(control_net = trainer.control_network)
new_agent = AlphaMCTSAgent(control_net = trainer.current_network)

agents = [old_agent, new_agent]

wins = 0
ties = 0

game = NestedTTT()

for game_num in range(100):
    game.reset()
    agents[0].reset()
    agents[1].reset()
    result = 0
    player_num = game_num // 50 #Both take first turn 50 times
    turn_num = 100 #Turn down the temperature

    while len(game.get_valid_moves()) > 0:
        move, _ = agents[player_num].search(game.copy(), turn_num, allotted_playouts = 800)
        _, _ = agents[1 - player_num].search(game.copy(), turn_num, allotted_playouts = 800)

        result = game.make_move(move)
        if not result:
            game.switch_player()
            agents[0].take_action(move)
            agents[1].take_action(move)
            player_num = (player_num + 1) % 2

    if not result:
        print("Game {}: Tie".format(game_num))
        ties += 1
    elif result and player_num == 1:
        print("Game {}: Win".format(game_num))
        wins += 1

if wins + .5 * ties >= 55:
    print("Challenger network won {} games and tied {} games; it becomes new control network".format(wins, ties))
    torch.save(trainer.current_network.state_dict(), "control_weights_{}.pth".format(trainer.n_batches))
    trainer.control_network.load_state_dict(trainer.current_network.state_dict())
else:
    print("Challenger network not sufficiently better; {} wins and {} ties".format(wins, ties))

trainer.control_network.eval()
trainer.current_network.train()            

Game 0: Tie
Game 22: Tie
Game 25: Tie
Game 28: Tie
Game 31: Tie
Game 54: Tie
Game 61: Tie
Game 62: Tie
Game 63: Tie
Game 67: Tie
Game 87: Tie
Game 93: Tie
Game 98: Tie
Challenger network not sufficiently better; 0 wins and 13 ties


NestedTTTNet(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn_inner): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (res1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2_collapse): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (bn_outer): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (res2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3_collapse): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1))
  (bn_linear): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (policy_head): Sequential(
    (0): Linear(in_features=256, out_features=81, bias=True)
    (1): Softmax()
  )
  (value_head): Sequential(
    (0): Linear(in_features=256, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=1, bias=True)
    (3): Tanh()
  )
)

In [41]:
trainer.current_network.load_state_dict(trainer.control_network.state_dict())

In [42]:
from multiprocessing import Pool, cpu_count

cpu_count()

8

In [43]:
pool = Pool(cpu_count()-1)
pool

<multiprocessing.pool.Pool at 0x1d49bd35908>

In [45]:
def run_train_parallel(trainer):
    print("Starting self-play game")
    trainer.generate_self_play_data(1)
    print("Self play finished, returning buffer contents")
    return list(trainer.replay_buffer.buffer)

In [None]:
trainers = []
for _ in range(cpu_count()-1):
    trainer = SelfPlayTrainer(AlphaMCTSAgent(), NestedTTT(), weights_file = "control_weights_3000.pth")
    trainers.append(trainer)
results = pool.map(run_train_parallel, trainers)

In [None]:
results

In [None]:
print("HI")