In [1]:
from nes_py.wrappers import JoypadSpace
import gym_tetris
from gym_tetris.actions import SIMPLE_MOVEMENT,MOVEMENT
import numpy as np
import random
import numpy as np
from matplotlib import pyplot as plt

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque

device = torch.device("cpu")


In [3]:
actionCount = 0

In [4]:
piece_type_lookup = {
    'Tu': 'T', 'Tr': 'T', 'Td': 'T', 'Tl': 'T',
    'Jl': 'J', 'Ju': 'J', 'Jr': 'J', 'Jd': 'J',
    'Zh': 'Z', 'Zv': 'Z',
    'O': 'O',
    'Sh': 'S', 'Sv': 'S',
    'Lr': 'L', 'Ld': 'L', 'Ll': 'L', 'Lu': 'L',
    'Iv': 'I', 'Ih': 'I',
    'none': 'none'
}

In [5]:
tetris_start_positions = {
    'Tu': (-2, 4),  # Top of 'T' piece, up orientation, centered
    'Tr': (-2, 4),  # Top of 'T' piece, right orientation, centered
    'Td': (-2, 4),  # Top of 'T' piece, down orientation, centered
    'Tl': (-2, 4),  # Top of 'T' piece, left orientation, centered
    'Jl': (-2, 4),  # Top of 'J' piece, left orientation, centered
    'Ju': (-2, 4),  # Top of 'J' piece, up orientation, centered
    'Jr': (-2, 4),  # Top of 'J' piece, right orientation, centered
    'Jd': (-2, 4),  # Top of 'J' piece, down orientation, centered
    'Zh': (-2, 4),  # Top of 'Z' piece, horizontal orientation, centered
    'Zv': (-2, 4),  # Top of 'Z' piece, vertical orientation, centered
    'O':  (-2, 4),  # Top of 'O' piece, centered
    'Sh': (-2, 4),  # Top of 'S' piece, horizontal orientation, centered
    'Sv': (-2, 4),  # Top of 'S' piece, vertical orientation, centered
    'Lr': (-2, 4),  # Top of 'L' piece, right orientation, centered
    'Ld': (-2, 4),  # Top of 'L' piece, down orientation, centered
    'Ll': (-2, 4),  # Top of 'L' piece, left orientation, centered
    'Lu': (-2, 4),  # Top of 'L' piece, up orientation, centered
    # Top of 'I' piece, vertical orientation, slightly left centered to fit 4 blocks
    'Iv': (-4, 3),
    # Top of 'I' piece, horizontal orientation, slightly left centered to fit 4 blocks
    'Ih': (-1, 3)
}

In [6]:
def calculate_heights(grid):
    # This function finds the height of each column in the grid.
    # It calculates height from the bottom to the first non-zero cell encountered from the top.
    heights = np.zeros(grid.shape[1], dtype=int)
    for col in range(grid.shape[1]):
        column = grid[:, col]  # Extract the entire column
        first_filled = np.where(column > 0)[0]
        if first_filled.size > 0:
            heights[col] = grid.shape[0] - first_filled.min()
    return heights

In [7]:
def statePreprocess(state):
    #the shape of the play area is from 48 to 208 in the x direction and 96 to 176 in the y direction
    state = state[48:208,96:176]
    grayscale = np.dot(state[...,:3], [0.2989, 0.5870, 0.1140])
    binary_array = grayscale.reshape(20,8,10,8).max(axis=(1,3)) > 0
    return binary_array.astype(int)

In [8]:
def one_hot_piece(piece):
    # Extended mapping to include variations like 'Td', 'Ld', etc.
    mapping = {
    'Tu': 0,
    'Tr': 1,
    'Td': 2,
    'Tl': 3,
    'Jl': 4,
    'Ju': 5,
    'Jr': 6,
    'Jd': 7,
    'Zh': 8,
    'Zv': 9,
    'O': 10,
    'Sh': 11,
    'Sv': 12,
    'Lr': 13,
    'Ld': 14,
    'Ll': 15,
    'Lu': 16,
    'Iv': 17,
    'Ih': 18,
    'none': -1}
    vector = [0] * len(mapping)
    if piece in mapping:  # Check if the piece is recognized
        vector[mapping[piece]] = 1
    else:
        print('Piece not recognized:', piece)
    return vector

In [9]:
def count_covered_voids(board):
    """
    Count the number of empty cells directly covered by a full cell in a Tetris board.

    :param board: 2D numpy array representing the Tetris board, where 1 is a full cell and 0 is empty.
    :return: Integer count of covered voids.
    """
    covered_voids = 0
    rows, cols = board.shape

    # Iterate over each cell in the board except for the bottom row
    for r in range(rows - 1):
        for c in range(cols):
            if board[r, c] == 1 and board[r + 1, c] == 0:
                # If the current cell is full and the cell directly below it is empty, count it as a covered void
                covered_voids += 1

    return covered_voids

In [10]:
def calculate_reward(clearLines, heightDiff, MaxHeight, done, fixState):
    C = 10
    aplha = 0.5
    beta = 1
    detla = 0.1
    global actionCount
    emptyCells = count_covered_voids(fixState)
    difference =np.sum(np.abs(heightDiff - np.median(heightDiff)))
    reward = C * clearLines - aplha * difference - beta * MaxHeight + 1 + detla * emptyCells
    if done:
        reward -=50
    else :
        reward += 1

    return reward

In [11]:
class DQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(DQN, self).__init__()
        self.conv1 = nn.Sequential(nn.Linear(input_dim, 128), nn.ReLU())
        self.conv2 = nn.Sequential(nn.Linear(128, 128), nn.ReLU())
        self.conv3 = nn.Sequential(nn.Linear(128, 128), nn.ReLU())
        self.fc = nn.Linear(128, output_dim)
        
        self._create_weights()
        
    def _create_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)


    def forward(self, x):

        x = self.conv1(x)

        x = self.conv2(x)

        x = self.conv3(x)

        x = self.fc(x)

        return x

In [12]:
def inputData(heightDiff, current_piece, next_piece):
    heightDiff = heightDiff.reshape(-1).astype(float)

    # One-hot encode the current and next pieces
    current_piece_vector = one_hot_piece(current_piece)
    next_piece_vector = one_hot_piece(next_piece)

    # Combine the flattened grid and the piece vectors into one state vector
    return torch.tensor(np.concatenate([heightDiff, current_piece_vector, next_piece_vector]), dtype=torch.float32)

In [13]:
# Initialize the DQN
input_dim = 10 + 20 +20  # 10 for the height difference of the tetris grid, 19 for the one-hot encoded pieces 2 for none piece
output_dim = len(SIMPLE_MOVEMENT)  # Number of possible actions
model = DQN(input_dim, output_dim)
model.to(device)

# Initialize the optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Initialize the replay memory
replay_memory = deque(maxlen = 20000)
batch_size = 128

epsilon = 1.0  # Starting value of epsilon
epsilon_min = 0.01  # Minimum value of epsilon
epsilon_decay = 0.995  # Decay multiplier for epsilon

In [14]:
#number of episodes
episodes = 10000

env = gym_tetris.make('TetrisA-v3')
env = JoypadSpace(env, SIMPLE_MOVEMENT)
model.train()

episode_rewards = []

loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.MSELoss()

for episode in range(episodes):
    state = env.reset()
    oldInfo = {"current_piece": 'none', "next_piece": 'none'}
    oldState = np.zeros((20, 10))
    fixState = np.zeros((20, 10))
    oldheightDiff = np.zeros((1, 10))
    newheightDiff = np.zeros((1, 10))
    total_reward = 0
    done = False
    oldData = inputData(oldheightDiff, oldInfo['current_piece'], oldInfo['next_piece'])
    global actionCount
    actionCount = 0
    randomAction = 0
    modelAction = 0
    while not done:
        actionCount += 1
        # Exploration vs Exploitation
        if random.random() < epsilon:
            action = random.randint(0, 5)
            randomAction += 1
        else:
            q_values = model(oldData)
            action = torch.argmax(q_values).item()
            modelAction += 1
        newState, reward, done, newInfo = env.step(action)
        #env.render()
        newState = statePreprocess(newState)
        
        #calculate the height difference
        # print(oldInfo['current_piece'], newInfo['current_piece'])
        # print(oldInfo['next_piece'], newInfo['next_piece'])
        if (oldInfo['current_piece'] != None and newInfo['current_piece'] != None) and (oldInfo['next_piece'] != None and newInfo['next_piece'] != None):

            if (piece_type_lookup[oldInfo['current_piece']] != piece_type_lookup[newInfo['current_piece']] or
                oldInfo['current_piece'] != newInfo['current_piece']):
                
                if oldInfo['current_piece'] != 'none':
                    prev_piece = oldInfo['current_piece']
                    
                    start_row, start_col = tetris_start_positions[prev_piece]
                    # Adjust for actual board size and orientation specifics
                    piece_array = np.zeros_like(oldState)
                    piece_height, piece_width = piece_array.shape
                    
                    for r in range(piece_height):
                        for c in range(piece_width):
                            if start_row + r < 0 or start_row + r >= 20 or start_col + c < 0 or start_col + c >= 10:
                                continue
                            if oldState[start_row + r, start_col + c] == 1:
                                oldState[start_row + r, start_col + c] = 0
                                
                    fixState = (oldState + newState) > 0
                    newheightDiff = calculate_heights(fixState)
                
                newRewrad = calculate_reward(newInfo['number_of_lines'], newheightDiff, np.max(newheightDiff), done, fixState)
                reward = newRewrad - total_reward
            
        newData = inputData(newheightDiff, newInfo['current_piece'], newInfo['next_piece'])
        
        # Add the inputdata, action, reward, next state, and done to the replay memory
        replay_memory.append((torch.tensor(newData, dtype=torch.float), torch.tensor(action, dtype=torch.long), torch.tensor(
            reward, dtype=torch.float), torch.tensor(oldData, dtype=torch.float), torch.tensor(done, dtype=torch.float)))

        
        # Training from replay buffer
        if len(replay_memory) >= batch_size:
            batch = random.sample(replay_memory, batch_size)
            data_batch, action_batch, reward_batch, next_state_batch, done_batch = map(torch.stack, zip(*batch))
            
            q_values = model(data_batch)
            with torch.no_grad():
                next_q_values = model(next_state_batch)
            target_q_values = reward_batch + 0.99 * torch.max(next_q_values, dim=1).values * (1 - done_batch)
            loss = loss_fn(q_values[range(batch_size), action_batch], target_q_values)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        oldheightDiff = newheightDiff
        oldInfo = newInfo
        oldState = newState
        oldData = newData
        total_reward = newRewrad
    episode_rewards.append(total_reward)
    print(f'Episode {episode + 1}, total reward: {total_reward}','actionCount:', actionCount, 'randomAction:', randomAction, 'modelAction:', modelAction)
    epsilon = max(epsilon_min, epsilon_decay * epsilon)
    if episode % 50 == 0:
        torch.save(model.state_dict(), 'model.pth')
        plt.plot(episode_rewards)
        plt.show()
    

  logger.warn(
  logger.deprecation(
  if not isinstance(done, (bool, np.bool8)):
  replay_memory.append((torch.tensor(newData, dtype=torch.float), torch.tensor(action, dtype=torch.long), torch.tensor(
  reward, dtype=torch.float), torch.tensor(oldData, dtype=torch.float), torch.tensor(done, dtype=torch.float)))


learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
learning
l

KeyboardInterrupt: 