In [1]:
# from google.colab import drive
# drive.mount('/content/drive')
# path = "/content/drive/MyDrive/code/"
path = ""

In [2]:
# !pip install d2l==1.0.0-alpha1.post0

In [3]:
# some code borrowed from
# https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html
# and http://d2l.ai/

In [4]:
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count
from PIL import Image

import torch
from torch import nn
from d2l import torch as d2l
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

In [5]:
def add_to_class(Class):
    def wrapper(obj):
        setattr(Class, obj.__name__, obj)
    return wrapper

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# tetris environment

In [7]:
# piece is a tensor of coordinates with min x and y values being 0
test = torch.tensor([[0,0],[0,1],[0,2],[1,1]])  # a T-piece

In [8]:
def set_corner(piece):
    return piece - piece.min(axis=0).values

In [9]:
def coords_to_2d(piece):
    piece_2d = torch.zeros(*(piece.max(axis=0).values+1).tolist(), dtype=torch.uint8)
    piece_2d[piece[:,0],piece[:,1]] = 1
    return piece_2d

In [10]:
def get_rotations(piece):
    return [
        coords_to_2d(
            set_corner(piece @ torch.linalg.matrix_power(
                torch.tensor([[0, 1], [-1, 0]]), i))) for i in range(4)
    ]

In [11]:
def kill_duplicates(array):
    # quick and dirty way of killing duplicate rotations
    num_elems = len(array)
    duplicated = []
    for i in range(len(array)):
        for j in range(i):
            if torch.equal(array[i], array[j]):
                duplicated.append(i)
                break
    return [array[i] for i in range(num_elems) if i not in duplicated]

In [12]:
def visualise(piece_2d):
    print('\n'.join(' '.join('%d' % x for x in y)
                    for y in torch.flipud(piece_2d)).replace('0', '.').replace('1', '#'))

In [13]:
shapes = [
        torch.tensor(_) for _ in [
            [[0, 0], [0, 1], [0, 2], [0, 3]],  # I
            [[1, 0], [0, 0], [0, 1], [0, 2]],  # J
            [[0, 0], [0, 1], [0, 2], [1, 2]],  # L
            [[0, 0], [0, 1], [1, 0], [1, 1]],  # O
            [[0, 0], [0, 1], [1, 1], [1, 2]],  # S
            [[1, 0], [1, 1], [0, 1], [0, 2]],  # Z
            [[0, 0], [0, 1], [0, 2], [1, 1]],  # T
        ]
    ]

pieces_2d = [kill_duplicates(get_rotations(shapes[i])) for i in range(len(shapes))]

In [14]:
class Playfield(d2l.HyperParameters):
    global pieces_2d
    def __init__(self, width=10, height=20, game_over_cost=200, clear_reward=10, pieces=pieces_2d):
        self.save_hyperparameters()
        self.board = torch.zeros(height, width, dtype=torch.uint8)
        self.time_alive = 0
        self.lines_cleared = 0
        self.dead = False
        self.memory = deque([], maxlen=10000)

In [15]:
@add_to_class(Playfield)
def show_board(self):
    visualise(self.board)

In [16]:
@add_to_class(Playfield)
def reset_board(self):
    self.board = torch.zeros(self.height, self.width, dtype=torch.uint8)
    self.dead = False
    self.time_alive, self.lines_cleared = 0, 0

In [17]:
def col_heights(board):
    flipped = torch.flipud(board)
    c_heights = (board.shape[0]-flipped.argmax(axis=0))*flipped.max(axis=0).values
    c_holes = board.argmin(axis=0)+1
    c_holes = c_holes*(c_holes < c_heights)
    return c_heights, c_holes

@add_to_class(Playfield)
def heights(self):
    return col_heights(self.board)[0]

@add_to_class(Playfield)
def holes(self):
    return col_heights(self.board)[1]

In [18]:
def clear_lines(board):
    remaining = (board.sum(axis=1) != board.shape[1])
    num_remaining = remaining.sum().item()
    return F.pad(board[remaining, :],
                 pad=(0, 0, 0, board.shape[0] -
                      num_remaining)), board.shape[0]-num_remaining

In [19]:
def state_change(new_board, new_heights, new_holes, reward, cleared):
    return {
        'board': new_board,
        'heights': new_heights,
        'holes': new_holes,
        'reward': reward,
        'cleared': cleared
    }

In [20]:
@add_to_class(Playfield)
def next_states(self, poly_num=None):
    # Return a list of tuples (future board, future heights, reward, lines cleared).
    # I set reward to be -(change in board heights),
    # with the exception of a game over, in which reward is -game_over_cost.
    board_heights = self.heights()
    next_states = []
    if not poly_num:
        poly_num = random.randint(0, len(self.pieces) - 1)
    polyomino = self.pieces[poly_num]
    for rot in polyomino:
        rot_height, rot_width = rot.shape
        lower_boundary = rot.argmax(axis=0)
        for i in range(self.width - rot_width + 1):
            sitting_height = (board_heights[i:i + rot_width] -
                              lower_boundary).max()
            if sitting_height + rot_height < self.height:
                new_board = self.board.clone()
                new_board[torch.nonzero(rot)[:, 0] + sitting_height,
                          torch.nonzero(rot)[:, 1] + i] = 1
                new_board[:], cleared = clear_lines(new_board)
                new_heights, new_holes = col_heights(new_board)
#                 reward = torch.tensor(cleared * self.clear_reward)
                reward = 0.1*(board_heights.sum() - new_heights.sum(
                )) + cleared * self.clear_reward
                next_states.append(
                    state_change(new_board, new_heights, new_holes, reward, cleared))
    if not (len(next_states)):
        next_states = [
            state_change(
                torch.zeros(self.height, self.width, dtype=torch.uint8),
                torch.zeros(self.width), torch.zeros(self.width), torch.tensor(-self.game_over_cost), 0)
        ]
        self.dead = True
    return next_states

In [21]:
@add_to_class(Playfield)
def update_board_state(self, state):
    # set the board to a state (chosen from next_states)
    self.board = state['board']
    self.lines_cleared += state['cleared']
    self.time_alive += 1
    if self.dead:
        self.memory.append({'lines_cleared': self.lines_cleared, 'time_alive': self.time_alive})
        self.reset_board()
        # still need to write some code that tracks game statistics over multiple runs for the main loop

In [22]:
# To test everything works correctly so far, play by greedily choosing the action with highest immediate reward


def greedy_step(pfield, printing=True):
    next_states = pfield.next_states()
    max_reward = max([x['reward'] for x in next_states])
    pfield.update_board_state(
        random.choice([x for x in next_states if x['reward'] == max_reward]))
    if printing:
        pfield.show_board()
        print(
            f"{p.time_alive} piece drop{'' if pfield.time_alive==1 else 's'}, {pfield.lines_cleared} lines cleared"
        )


p = Playfield()
greedy_step(p,False)
while p.time_alive:
    greedy_step(p,False)

In [23]:
def state_to_float(state):
    # in retrospect, this should have been a class, not a dictionary
    return {
        'board': state[board].float(),
        'heights': state[heights].float(),
        'holes': state[holes].float(),
        'reward': reward,
        'cleared': cleared
    }

# DQN setup

In [24]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [25]:
# this is a bit different to the usual transition because the action is
# completely characterised by the next state, so we just drop action
# reward is also gone because "next_states" are tuples anyway
Transition = namedtuple('Transition', ('state', 'next_states'))


class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

In [26]:
# write DQN with lazy, also do some conv on heights and stuff
# also need to pass in height data (maybe hole data? this is encapsuled somewhat in the loss function though)
# maybe i should multiply cost of holes (e.g. by adding multiples of (cost-4))

In [27]:
class heights_CNN(nn.Module):
    def __init__(self, c1=256, c2=128, c3=512-20, c4=20, **kwargs):
        super().__init__(**kwargs)
        self.conv1 = nn.LazyConv1d(c1, kernel_size=3, padding=0, stride=1)
        self.conv2 = nn.LazyConv1d(c2, kernel_size=3, padding=1, stride=1)
        self.bn1 = nn.LazyBatchNorm1d()
        self.bn2 = nn.LazyBatchNorm1d()
        self.out_layer = nn.LazyLinear(c3)
        
    def forward(self, x):
        x = x.to(device)
        y = F.relu(self.bn1(self.conv1(x)))
        y = F.relu(self.bn2(self.conv2(y)))
        y = torch.flatten(y,1,2)
        y = F.relu(self.out_layer(y))
        y = torch.concat((y,torch.squeeze(x.flatten(1,2),dim=1)),dim=1)
        return y

In [28]:
class column_conv(nn.Module):
    pass

In [29]:
class board_CNN(nn.Module):
    def __init__(self, c1=64, c2=32, c3=32, c4=32, c5=512, **kwargs):
        super().__init__(**kwargs)
        self.conv1 = nn.LazyConv2d(c1, kernel_size=3, padding=0, stride=1)
        self.conv2 = nn.LazyConv2d(c2, kernel_size=3, padding=1, stride=2)
        self.conv3 = nn.LazyConv2d(c3, kernel_size=(3,1), padding=0, stride=2)
        self.conv4 = nn.LazyConv2d(c4, kernel_size=3, padding=1, stride=1)
        self.bn1 = nn.LazyBatchNorm2d()
        self.bn2 = nn.LazyBatchNorm2d()
        self.bn3 = nn.LazyBatchNorm2d()
        self.bn4 = nn.LazyBatchNorm2d()
        self.out_layer = nn.LazyLinear(c5)
    
    def forward(self, x):
        x = x.to(device)
        y = F.pad(x,(1,1,1,0),value=1)
        y = F.relu(self.bn1(self.conv1(y)))
        y = F.relu(self.bn2(self.conv2(y)))
        y = F.relu(self.bn3(self.conv3(y)))
        y = F.relu(self.bn4(self.conv4(y)))
        y = torch.flatten(y,1,3)
        y = F.relu(self.out_layer(y))
#         y = torch.concat((y,torch.squeeze(x,dim=1)),dim=1)
        return y

In [30]:
class DQN(nn.Module):
    def __init__(self, lr=1e-4, width=128, loss_mem_length=100, **kwargs):
        super().__init__(**kwargs)
        self.height_net = heights_CNN(64,64,128-20,20)
        self.board_net = board_CNN(32,32,16,16,128)
        self.dense1 = nn.LazyLinear(width)
        self.dense2 = nn.LazyLinear(width)
        self.out = nn.LazyLinear(1)
        self.optimiser = optim.RMSprop(self.parameters(),weight_decay=5e-5,lr=lr)
        self.loss_memory = deque([], maxlen=loss_mem_length)
    
    def forward(self, height_tensor, board_tensor):
        height_tensor = height_tensor.to(device)
        board_tensor = board_tensor.to(device)
        h = self.height_net(height_tensor)
        b = self.board_net(board_tensor)
        y = torch.cat((h,b),1)
        y = F.relu(self.dense1(y))
        y = F.relu(self.dense2(y))
        y = self.out(y)
        return y

In [31]:
@add_to_class(DQN)
def loss(self, y_hat, y):
    fn = nn.SmoothL1Loss()
    return fn(y_hat, y)

In [32]:
@add_to_class(DQN)
def train_step(self, batch_loss):
    self.optimiser.zero_grad()
    batch_loss.backward()
    for param in self.parameters():
        param.grad.data.clamp_(-1, 1)
    self.optimiser.step()

In [33]:
# batch size, num channels
test_cnn = heights_CNN()
test_heights = torch.concat(
    (p.heights().reshape(1, 1, -1) * torch.tensor([[[1]], [[1]]]).float(),
     p.holes().reshape(1, 1, -1) * torch.tensor([[[1]], [[1]]]).float()), 1)
test_cnn(test_heights).shape



torch.Size([2, 512])

In [34]:
test_cnn_2 = board_CNN()
test_boards = p.board.reshape(1,1,p.board.shape[0],p.board.shape[1])*torch.tensor([[[[1]]],[[[1]]]]).float()
test_cnn_2(test_boards).shape

torch.Size([2, 512])

In [35]:
test_boards.shape,test_heights.shape

(torch.Size([2, 1, 20, 10]), torch.Size([2, 2, 10]))

In [36]:
test_dqn = DQN()
test_heights = torch.concat(
    (p.heights().reshape(1, 1, -1) * torch.tensor([[[1]], [[1]]]).float(),
     p.holes().reshape(1, 1, -1) * torch.tensor([[[1]], [[1]]]).float()), 1)
test_boards = p.board.reshape(1,1,p.board.shape[0],p.board.shape[1])*torch.tensor([[[[1]]],[[[1]]]]).float()
test_dqn(test_heights, test_boards).shape

torch.Size([2, 1])

In [37]:
count_parameters(test_dqn)

156333

# set up training

In [38]:
BATCH_SIZE = 64
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.01
EPS_DECAY = 200

BOARD_WIDTH = 10
BOARD_HEIGHT = 20
GAME_OVER_COST = 200
CLEAR_REWARD = 10

MOVES_PER_TRAIN = 50

Q_net = DQN()
p = Playfield(width=BOARD_WIDTH, height=BOARD_HEIGHT, game_over_cost=GAME_OVER_COST, clear_reward=CLEAR_REWARD)
init_heights = torch.concat(
    (p.heights().reshape(1, 1, -1) * torch.tensor([[[1]]]).float(),
     p.holes().reshape(1, 1, -1) * torch.tensor([[[1]]]).float()), 1)
init_boards = p.board.reshape(1,1,p.board.shape[0],p.board.shape[1])*torch.tensor([[[[1]]]]).float()
# print(init_heights.shape,init_boards.shape)
Q_net(init_heights,init_boards)
Q_net.apply(d2l.init_cnn)

memory = ReplayMemory(10000)

In [39]:
def next_states_to_tensors(states):
    out = {
        'board_tensor':
        torch.concat(
            tuple(t['board'].reshape(1, 1, BOARD_HEIGHT, BOARD_WIDTH)
                  for t in states)).float(),
        'heights_tensor':
        torch.concat(
            tuple(
                torch.concat((t['heights'].reshape(1, 1, BOARD_WIDTH),
                              t['holes'].reshape(1, 1, BOARD_WIDTH)), 1)
                for t in states)).float(),
        'reward_tensor':
        torch.concat(tuple(t['reward'].reshape(1,1) for t in states))
    }
#     print(out['heights_tensor'].shape)
    return out

In [40]:
def q_value_single(state):
    return Q_net(torch.concat((state['heights'],state['holes']),1), state['board'])

In [41]:
# Use epsilon-greedy. I might implement a policy that chooses moves with probability dependent on Q later
# (hm maybe i could also have a chance of running the step with highest reward, seems good for convergence)
def do_action(eps1, eps2, gamma=GAMMA):
    # eps1 is the probability of doing the move with the highest immediate value
    # eps2 is the probability of doing a completely random move
    heights, holes = p.heights(), p.holes()
    now_state = {
            'board': p.board.reshape(1, 1, BOARD_HEIGHT, BOARD_WIDTH).float(),
            'heights': heights.reshape(1, 1, BOARD_WIDTH).float(),
            'holes': holes.reshape(1, 1, BOARD_WIDTH).float(),
        }
    next_polyomino = random.randint(0, len(p.pieces) - 1)
    next_states = p.next_states(next_polyomino)
    memory.push(now_state, next_states)
    next_states_tensors = next_states_to_tensors(next_states)
    heights_tensor = next_states_tensors['heights_tensor']
    board_tensor = next_states_tensors['board_tensor']
    reward_tensor = next_states_tensors['reward_tensor']
    if random.random() > eps1+eps2:
        future_q_values = Q_net(heights_tensor, board_tensor)
        move_values = gamma*future_q_values+reward_tensor
        action = next_states[future_q_values.argmax()]
    elif random.random() > eps1:
        action = random.choice(next_states)
    else:
        max_reward = max([x['reward'] for x in next_states])
        action = random.choice([x for x in next_states if x['reward'] == max_reward])
    p.update_board_state(action)
#     p.show_board()
#     return action

In [42]:
do_action(0,0)

In [43]:
# for i in range(128):
#     do_action()

In [44]:
def train_Q_net(gamma, min_memory):
    if len(memory) < min_memory:
        return

    transitions = memory.sample(BATCH_SIZE)

    board_batches = [
        next_states_to_tensors(transition[1])['board_tensor']
        for transition in transitions
    ]
    height_batches = [
        next_states_to_tensors(transition[1])['heights_tensor']
        for transition in transitions
    ]

    now_q_values = torch.concat(
        tuple(q_value_single(transitions[i][0]) for i in range(BATCH_SIZE)))

    future_q_values = [
        Q_net(height_batches[i], board_batches[i]) for i in range(BATCH_SIZE)
    ]  # list (of length BATCH_SIZE) of tensors of q_values

    actions = [
        transitions[i][1][future_q_values[i].argmax()]
        for i in range(BATCH_SIZE)
    ]  # the future states chosen by the Q-net in each state in the batch

    rewards = torch.concat(tuple(t['reward'].reshape(1, -1) for t in actions))

    v_values = torch.concat(
        tuple(
            Q_net(
                torch.concat((
                    actions[i]['heights'].reshape(1, 1, BOARD_WIDTH).float(),
                    actions[i]['holes'].reshape(1, 1, BOARD_WIDTH).float()
                ), 1), actions[i]['board'].reshape(1, 1, BOARD_HEIGHT,
                                                   BOARD_WIDTH).float())
            for i in range(BATCH_SIZE)))

    q_loss = Q_net.loss(now_q_values, rewards + gamma * v_values)

    Q_net.train_step(q_loss)

    Q_net.loss_memory.append(q_loss.detach().item())

In [45]:
# train_Q_net()

# training loop

In [46]:
from IPython.display import clear_output

BATCH_SIZE = 100
GAMMA = 0.95
EPS_START = 0.95
EPS_END = 0.05
EPS_DECAY = 100000

BOARD_WIDTH = 10
BOARD_HEIGHT = 20
GAME_OVER_COST = 100
CLEAR_REWARD = 10

MOVES_PER_TRAIN = 200

LOAD_PARAMS = False
# change this to True to load a previously saved model

Q_net = DQN(loss_mem_length=20)
p = Playfield(width=BOARD_WIDTH,
              height=BOARD_HEIGHT,
              game_over_cost=GAME_OVER_COST,
              clear_reward=CLEAR_REWARD)
init_heights = torch.concat((p.heights().reshape(
    1, 1, -1).float(), p.holes().reshape(1, 1, -1).float()), 1)
init_boards = p.board.reshape(1, 1, p.board.shape[0],
                              p.board.shape[1]) * torch.tensor([[[[1]]]
                                                                ]).float()

if LOAD_PARAMS:
    Q_net.load_state_dict(torch.load(f'{path}Q_net.params'))
    with open(f"{path}num_steps.txt", "r") as f:
        num_steps = int(f.readlines()[0])
else:
    Q_net.apply(d2l.init_cnn)
    Q_net(init_heights, init_boards)
    num_steps = 0

memory = ReplayMemory(5000)

# num_steps = 0
while True:
    eps_thresh = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * num_steps / EPS_DECAY)
    for i in range(MOVES_PER_TRAIN):
#         do_action(0.04, 0.04, GAMMA)
        do_action(eps_thresh * 0.5, eps_thresh * 0.5)
        #         do_action(eps_thresh * eps_thresh, eps_thresh * (1-eps_thresh))  # i should make it greedy first then switch over
        num_steps += 1
    train_Q_net(GAMMA, 1000)
    if num_steps % 100 == 0:
        clear_output(wait=True)
        print(p.memory[-1])
        print(
            f"num_steps: {num_steps}, eps_thresh: {eps_thresh}, loss: {sum(Q_net.loss_memory)/max(len(Q_net.loss_memory),1)}"
        )
        p.show_board()
    if num_steps % 2000 == 0:
        torch.save(Q_net.state_dict(), f'{path}Q_net.params')
        with open(f"{path}num_steps.txt", "w") as f:
            f.write(str(num_steps))
    if num_steps % 100000 == 0:
        torch.save(Q_net.state_dict(), f'{path}Q_net_{num_steps}.params')

{'lines_cleared': 0, 'time_alive': 24}
num_steps: 2600, eps_thresh: 0.9286571387821183, loss: 3.124517046742969
. . . . . . . . . .
. . . . . . . . . .
. . . . . . . . . .
. . . . . . . . . .
. . . . . . . . . .
. . . . . . . . . .
. . . . . . . . . .
. . . . . . . . . .
. . . . . . . . . .
. . . . . . . . . .
. . . . . . . . . .
. . . . . . . . . .
. . . . . . . . . .
. . . . . . . . . .
. . . . . . . . . .
. . . . . . . . . .
. . . # . . # # # .
. . # # . . # . . .
. . # # # # # # . .
. # # # # # # # . .


KeyboardInterrupt: 

# meta

In [None]:
import sys

# These are the usual ipython objects, including this one you are creating
ipython_vars = ['In', 'Out', 'exit', 'quit', 'get_ipython', 'ipython_vars']

# Get a sorted list of the objects and their sizes
mem_sizes = sorted([
    (x, sys.getsizeof(globals().get(x))) for x in dir()
    if not x.startswith('_') and x not in sys.modules and x not in ipython_vars
],
                   key=lambda x: x[1],
                   reverse=True)
[print(thing[0],f"{thing[1]/1000} MB") for thing in mem_sizes if thing[1] > 1000]

# old stuff that i probably don't need

In [None]:
# old stuff, worry about it later

# concatenate the BATCH_SIZE sets of future boards into a single tensor
# same for heights
# later we will run Q on these as a group of concatenated batches...
# wait this breaks batchnorm doesn't it rip

# test_boards = torch.concat(
#     tuple(t['board'].reshape(1, 1, BOARD_HEIGHT, BOARD_WIDTH)
#           for transition in test_transitions for t in transition[1])).float()
# test_heights = torch.concat(
#     tuple(t['heights'].reshape(1, 1, BOARD_WIDTH)
#           for transition in test_transitions for t in transition[1])).float()
# test_lengths = torch.tensor([0]+[len(t[1]) for t in test_transitions]).cumsum(0)
# test_lengths
# test_q_values = Q_net(test_heights, test_boards)
# test_q_values.shape
# these two should be the same if everything is working correctly
# test_boards[test_lengths[10]:test_lengths[11]], tuple(t['board'] for t in test_transitions[10][1])
# [test_q_values[test_lengths[i]:test_lengths[i+1]].argmax() for i in range(BATCH_SIZE)]

In [None]:
# d2l.init_cnn?

In [None]:
# d2l.Classifier?

In [None]:
# # testing batch of transitions
# test_transitions = [
#     Transition(
#         {
#             'board': p.board.reshape(1, 1, BOARD_HEIGHT, BOARD_WIDTH).float(),
#             'heights': p.heights().reshape(1, 1, BOARD_WIDTH).float()
#         }, p.next_states()) for i in range(BATCH_SIZE)
# ]

In [None]:
# test_board_batches = [
#     next_states_to_tensors(transition[1])['board_tensor']
#     for transition in test_transitions
# ]
# test_height_batches = [
#     next_states_to_tensors(transition[1])['heights_tensor']
#     for transition in test_transitions
# ]

In [None]:
# test_now_q_values = torch.concat(
#     tuple(q_value_single(test_transitions[i][0]) for i in range(BATCH_SIZE))
# )

# test_future_q_values = [
#     Q_net(test_height_batches[i], test_board_batches[i])
#     for i in range(BATCH_SIZE)
# ]  # list (of length BATCH_SIZE) of tensors of q_values

# test_actions = [
#     test_transitions[i][1][test_future_q_values[i].argmax()]
#     for i in range(BATCH_SIZE)
# ]  # the future states chosen by the Q-net in each state in the batch

# test_rewards = torch.concat(
#     tuple(
#         t['reward'].reshape(1,-1) for t in test_actions
#     )
# )

# test_v_values = torch.concat(
#     tuple(
#         Q_net(
#             test_actions[i]['heights'].reshape(
#                 1, 1, BOARD_WIDTH).float(), test_actions[i]['board'].reshape(
#                     1, 1, BOARD_HEIGHT, BOARD_WIDTH).float())
#         for i in range(BATCH_SIZE)))

# # test_actions

In [None]:
# # use huber
# loss_fn = nn.SmoothL1Loss()
# test_loss = loss_fn(test_now_q_values, (test_rewards + GAMMA*test_v_values))
# test_loss

In [None]:
# Q_net.train_step(test_loss)