In [9]:
from nes_py.wrappers import JoypadSpace
import gym_tetris
from gym_tetris.actions import SIMPLE_MOVEMENT,MOVEMENT
import numpy as np
import random
import numpy as np
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim

In [10]:
def statePreprocess(state):
    #the shape of the play area is from 48 to 208 in the x direction and 96 to 176 in the y direction
    state = state[48:208,96:176]
    grayscale = np.dot(state[...,:3], [0.2989, 0.5870, 0.1140])
    binary_array = grayscale.reshape(20,8,10,8).max(axis=(1,3)) > 0
    return binary_array.astype(int)

In [11]:
def one_hot_piece(piece):
    # Extended mapping to include variations like 'Td', 'Ld', etc.
    mapping = {
    'Tu': 0,
    'Tr': 1,
    'Td': 2,
    'Tl': 3,
    'Jl': 4,
    'Ju': 5,
    'Jr': 6,
    'Jd': 7,
    'Zh': 8,
    'Zv': 9,
    'O': 10,
    'Sh': 11,
    'Sv': 12,
    'Lr': 13,
    'Ld': 14,
    'Ll': 15,
    'Lu': 16,
    'Iv': 17,
    'Ih': 18}
    vector = [0] * len(mapping)
    if piece in mapping:  # Check if the piece is recognized
        vector[mapping[piece]] = 1
    else:
        print('Piece not recognized:', piece)
    return vector

env = gym_tetris.make('TetrisA-v3')
env = JoypadSpace(env, SIMPLE_MOVEMENT)
state = env.reset()
##state =torch.tensor(np.array(state, copy = True), dtype=torch.float32)

In [12]:
def process_state(grid, current_piece, next_piece):
    grid = statePreprocess(grid)
    # Flatten the grid
    flat_grid = grid.reshape(-1).astype(float)  # Convert grid to a flat, float array

    # One-hot encode the current and next pieces
    current_piece_vector = one_hot_piece(current_piece)
    next_piece_vector = one_hot_piece(next_piece)

    #calculate the total used blocks in the grid
    total_blocks = np.sum(grid)

    # Combine the flattened grid and the piece vectors into one state vector
    return torch.tensor(np.concatenate([flat_grid, current_piece_vector, next_piece_vector]), dtype=torch.float32)

In [13]:
class DQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(DQN, self).__init__()
        self.conv1 = nn.Sequential(nn.Linear(input_dim, 128), nn.ReLU())
        self.conv2 = nn.Sequential(nn.Linear(128, 128), nn.ReLU())
        self.conv3 = nn.Sequential(nn.Linear(128, 128), nn.ReLU())
        self.fc = nn.Linear(128, output_dim)
        
        self._create_weights()
        
    def _create_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.fc(x)
        return x

In [14]:
#load the model
input_dim = 200 + 19 +19  # 200 for the grid, 14 for the one-hot encoded pieces
output_dim = len(SIMPLE_MOVEMENT)  # Number of possible actions
model = DQN(input_dim, output_dim)
print(model)
model.load_state_dict(torch.load('model.pth'))
env = gym_tetris.make('TetrisA-v2')
env = JoypadSpace(env, SIMPLE_MOVEMENT)
model.eval()

DQN(
  (conv1): Sequential(
    (0): Linear(in_features=238, out_features=128, bias=True)
    (1): ReLU()
  )
  (conv2): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=True)
    (1): ReLU()
  )
  (conv3): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=True)
    (1): ReLU()
  )
  (fc): Linear(in_features=128, out_features=6, bias=True)
)


DQN(
  (conv1): Sequential(
    (0): Linear(in_features=238, out_features=128, bias=True)
    (1): ReLU()
  )
  (conv2): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=True)
    (1): ReLU()
  )
  (conv3): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=True)
    (1): ReLU()
  )
  (fc): Linear(in_features=128, out_features=6, bias=True)
)

In [15]:
episodes = 10000

In [16]:
for i in range(episodes):
    state = env.reset()
    state, reward, done, info = env.step(0)
    current_piece = info['current_piece']
    next_piece = info['next_piece']
    state = process_state(state, current_piece, next_piece)
    while not done:
        with torch.no_grad():
            action = model(state)
            print(action)
            action = torch.argmax(action)
        state, reward, done, info = env.step(action.item())
        current_piece = info['current_piece']
        next_piece = info['next_piece']
        state = process_state(state, current_piece, next_piece)
        env.render()
        
env.close()


    
    

tensor([-0.0493,  0.1431,  0.0159,  0.0271, -0.1268,  0.1935])
tensor([-0.0493,  0.1431,  0.0159,  0.0271, -0.1268,  0.1935])
tensor([-0.0493,  0.1431,  0.0159,  0.0271, -0.1268,  0.1935])
tensor([-0.0493,  0.1431,  0.0159,  0.0271, -0.1268,  0.1935])
tensor([-0.0493,  0.1431,  0.0159,  0.0271, -0.1268,  0.1935])
tensor([-0.0250,  0.0560, -0.0608, -0.0090, -0.0094,  0.0276])
tensor([-0.0250,  0.0560, -0.0608, -0.0090, -0.0094,  0.0276])
tensor([ 0.1098,  0.1491,  0.0466,  0.0124, -0.1336,  0.0827])
tensor([-0.0105,  0.1179,  0.0666,  0.0137, -0.0590,  0.0510])
tensor([-0.0105,  0.1179,  0.0666,  0.0137, -0.0590,  0.0510])
tensor([-0.0105,  0.1179,  0.0666,  0.0137, -0.0590,  0.0510])
tensor([-0.0105,  0.1179,  0.0666,  0.0137, -0.0590,  0.0510])
tensor([-0.0105,  0.1179,  0.0666,  0.0137, -0.0590,  0.0510])
tensor([-0.0105,  0.1179,  0.0666,  0.0137, -0.0590,  0.0510])
tensor([-0.0105,  0.1179,  0.0666,  0.0137, -0.0590,  0.0510])
tensor([-0.0105,  0.1179,  0.0666,  0.0137, -0.0590,  0