<a href="https://colab.research.google.com/github/epeay/tetris-ml/blob/main/tetris.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install gymnasium

Collecting gymnasium
  Downloading gymnasium-0.29.1-py3-none-any.whl (953 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m953.9/953.9 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
Collecting farama-notifications>=0.0.1 (from gymnasium)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Installing collected packages: farama-notifications, gymnasium
Successfully installed farama-notifications-0.0.4 gymnasium-0.29.1


In [2]:
import gymnasium as gym
from gym import spaces
import numpy as np
import pdb



class TetrominoPiece:

    BLOCK = '▆'

    def __init__(self, shape, patterns):
        self.shape = shape
        self.patterns = patterns
        self.rot = 0

    def __str__(self) -> str:
        return f"TetrominoPiece(shape={Tetrominos.shape_name(self.shape)}, rot={self.rot*90}, pattern= {self.printable_pattern(oneline=True)})"

    def printable_pattern(self, oneline=False):
        ret = []
        pattern = self.get_pattern()
        for i, row in enumerate(pattern):
            row_str = " ".join([str(c) for c in row])
            ret.append(row_str)

            if not oneline:
                ret.append("\n")
            else:
                if i < len(pattern)-1:
                    ret.append(" / ",)
        ret = "".join(ret).replace('1', TetrominoPiece.BLOCK).replace('0', '_')
        return "".join(ret)




    def get_pattern(self):
        return self.patterns[self.rot]

    def rotate(self):
        """Rotates IN PLACE, and returns the new pattern"""
        self.rot = (self.rot + 1) % 4
        return self.get_pattern()

    def get_height(self):
        return len(self.patterns[self.rot])

    def get_width(self):
        return max([len(x) for x in self.patterns[self.rot]])

    def get_bottom_offsets(self):
        """
        For each column in the shape, returns the gap between the bottom of
        the shape (across all columns) and the bottom of the shape in that
        column.

        Returned values in the list would expect to contain at least one 0, and
        no values higher than the height of the shape.

        For example, an S piece:
        _ X X
        X X _

        Would have offsets [0, 0, 1] in this current rotation. This method is
        used in determining if a piece will fit at a certain position
        in the board.
        """
        # print(self)
        pattern = self.get_pattern()
        # pdb.set_trace()
        ret = [len(pattern)+1 for x in range(len(pattern[0]))]
        # Iterates rows from top, down
        for ri in range(len(pattern)):
            # Given a T shape:
            # X X X
            # _ X _
            # Start with row [X X X] (ri=0, offset=1)
            row = pattern[ri]
            # print(f"Testing row {row} at index {ri}")
            for ci, col in enumerate(row):
                if col == 1:
                    offset = len(pattern) - ri - 1
                    ret[ci] = offset

            # Will return [1, 0, 1] for a T shape

        if max(ret) >= len(pattern):
          print(f"Pattern:")
          print(pattern)
          print(f"Bottom Offsets: {ret}")
          print(f"Shape: {self.shape}")
          raise ValueError("Tetromino pattern has incomplete bottom offsets")

        return ret

    def get_top_offsets(self):
        """
        Returns the height of the shape at each column.

        For example, an S piece:
        _ X X
        X X _

        Would have offsets [1, 2, 2] in this current rotation. This provides
        guidance on how to update the headroom list.

        Ideally we should cache this.
        """
        pattern = self.get_pattern()
        ret = [0 for x in len(pattern[0])]
        for ri, row in enumerate(range(pattern, )):
            for col in pattern[row]:
                if pattern[row][col] == 1:
                    ret[col] = max(ret[col], row)
        return ret


class Tetrominos:
    O = 1
    I = 2
    S = 3
    Z = 4
    T = 5
    J = 6
    L = 7

    base_patterns = {
        # X X
        # X X
        O: np.array([[1, 1], [1, 1]]),

        # X X X X
        I: np.array([[1, 1, 1, 1]]),

        # _ X X
        # X X _
        S: np.array([[0, 1, 1], [1, 1, 0]]),
        Z: np.array([[1, 1, 0], [0, 1, 1]]),
        T: np.array([[1, 1, 1], [0, 1, 0]]),
        J: np.array([[1, 0, 0], [1, 1, 1]]),
        L: np.array([[0, 0, 1], [1, 1, 1]])
    }

    # Stores patterns for each tetromino, at each rotation
    cache = {}

    def num_tetrominos():
        return len(Tetrominos.base_patterns.keys())

    @staticmethod
    def shape_name(shape):
        if shape == Tetrominos.O:
            return "O"
        elif shape == Tetrominos.I:
            return "I"
        elif shape == Tetrominos.S:
            return "S"
        elif shape == Tetrominos.Z:
            return "Z"
        elif shape == Tetrominos.T:
            return "T"
        elif shape == Tetrominos.J:
            return "J"
        elif shape == Tetrominos.L:
            return "L"
        else:
            raise ValueError("Invalid shape")


    @staticmethod
    def make(shape):
        """
        shape:
        """
        if not Tetrominos.cache:
            for shape, pattern in Tetrominos.base_patterns.items():
                Tetrominos.cache[shape] = [
                    pattern,
                    np.rot90(pattern),
                    np.rot90(pattern, 2),
                    np.rot90(pattern, 3)
                ]


        if shape not in Tetrominos.base_patterns.keys():
            raise ValueError("Invalid shape")

        return TetrominoPiece(shape, Tetrominos.cache[shape])

class TetrisBoard:

    BLOCK = '▆'

    def __init__(self, height, width):
        self.height = height
        self.width = width
        self.reset()

    def reset(self):
        self.board = np.zeros((self.height, self.width), dtype=int)
        self.headroom = [self.height for _ in range(self.width)]
        self.piece = None

    def remove_tetris(self):
        to_delete = []
        for r, row in enumerate(self.board):
            if sum(row) == self.width:
                to_delete.append(r)

        if to_delete:
          self.board = np.delete(self.board, to_delete, axis=0)
          self.board.resize((self.height, self.width))
          pdb.set_trace()

    def place_piece(self, piece:TetrominoPiece, logical_coords):
        """
        Places a piece at the specified column. Dynamically calculates correct
        height for the piece.

        piece: a TetrominoPiece object
        logical_coords: The logical row and column for the bottom left
            of the piece's pattern
        """
        pattern = piece.get_pattern()
        bottom_offsets = np.array(piece.get_bottom_offsets())
        # TODO don't calculate all bottoms because we don't need them all

        lrow = logical_coords[0]
        lcol = logical_coords[1]

        p_height = piece.get_height()

        for r in range(p_height):
            pattern_row = pattern[len(pattern)-1-r]
            board_row = self.board[lrow-1+r]

            for i, c in enumerate(pattern_row):
                # Iff c is 1, push it to the board
                board_row[lcol-1+i] |= c


    def find_logical_BL_placement(self, piece:TetrominoPiece, col):
        """
        Returns the logical row and column of the bottom left corner of the
        pattern, such that when placed, the piece will sit flush against existing
        tower parts, and not exceed the max board height.

        Given:
        BOARD       PIECE
        5 _ _ _ _
        4 _ _ _ X
        3 _ _ X X   X X X X
        2 _ X X _
        1 X X X X

        Returns (5, 1)

        piece: a TetrominoPiece object
        col: zero-index column to place the 0th column of the piece.
        """
        pattern = piece.get_pattern()
        bottom_offsets = np.array(piece.get_bottom_offsets())
        # TODO don't calculate all bottoms because we don't need them all
        board_heights = np.array(self.get_tops()[col:col+piece.get_width()])

        # Given:
        # BOARD       PIECE
        # 5 _ _ _ _
        # 4 _ _ _ X
        # 3 _ _ X X   X X X X
        # 2 _ X X _
        # 1 X X X X
        # Tops -> [1,2,3,4]
        #
        # The sideways I has bottom offsets [0,0,0,0]
        # Start at min(board_tops)+1 and try to place the piece.
        #
        # If placing on row 2, the piece heights would be [2,2,2,2]
        # Board heights are [1,2,3,4], so this
        # doesn't clear the board for all columns. Try placing on row 3.
        # [3,3,3,3] > [1,2,3,4] ? False
        # Try row 4... False. Try row 5...
        # [5,5,5,5] > [1,2,3,4] ? True
        # So we place the piece on row 5 (index 4)
        #
        # 5 X X X X
        # 4 _ _ _ X
        # 3 _ _ X X
        # 2 _ X X _
        # 1 X X X X
        # (yes, this is a horrible move)

        p_height = piece.get_height()
        p_width = piece.get_width()
        can_place = False

        # TODO Pick better min test height
        # If there's a very narrow, tall tower, and you're placing a flat I
        # just to the left of it, you'll likely test placement for each level of
        # the tower until the piece clears it.
        for place_row in range(min(board_heights)+1, max(board_heights)+2):
            # In the example, place_row would be 2...3...4...5

            # Is [2,2,2,2] > [1,2,3,4] ?
            # Does this placement not interfere with existing board pieces?
            print(f"Trying placement at row {place_row}")
            print(f"{(bottom_offsets + place_row)} > {board_heights}")
            bottom_clears_board = all((bottom_offsets + place_row) > board_heights)

            if not bottom_clears_board:
                continue

            # Check the final height
            if place_row-1 + p_height > self.height:
                raise ValueError(f"Requested placement at col {col+1} would require rows {place_row}-{place_row-1 + p_height}. Piece {piece}")

            can_place = True
            break

        if not can_place:
            # pdb.set_trace()
            raise ValueError(f"Piece failed to be placed at lcolumn {col+1}")

        return (place_row, col+1)


    def render(self, mode='human'):
        output = False
        for i, row in enumerate(reversed(self.board)):
            if sum(row) == 0 and not output:
                continue
            else:
                output = True

            for cell in row:
                if cell == 1:
                    print(TetrisBoard.BLOCK, end=' ')
                else:
                    print('_', end=' ')
            print()

        if not output:
            print("<<EMPTY BOARD>>")



    def get_tops(self):
        """
        Gets the height of each column on the board.
        This is gonna be inefficient for now.

        A board with only an I at the left side would return [4, 0, 0, ...]
        """
        tops = [0 for _ in range(self.width)]
        for r, row in enumerate(self.board):
            if sum(row) == 0:
                break

            for col, val in enumerate(row):
                if val == 1:
                    tops[col] = r+1

        return tops




class TetrisEnv(gym.Env):
    def __init__(self):
        super(TetrisEnv, self).__init__()
        self.board_height = 20
        self.board_width = 10
        self.board = TetrisBoard(self.board_height, self.board_width)
        self.current_piece = None
        self.pieces = Tetrominos()
        self.reward_history = deque(maxlen=10)

        # Action space: tuple (column, rotation)
        # TODO Limit action width properly
        self.action_space = spaces.MultiDiscrete([self.board_width, 4])

        # Observation space: the board state
        self.observation_space = spaces.Box(
            low=0,
            high=1,
            shape=(self.board_height * self.board_width + Tetrominos.num_tetrominos(),),
            dtype=int
            )

        self.reset()

    def reset(self):
        self.board.reset()
        self.current_piece = self._get_random_piece()
        return self._get_board_state()

    def step(self, action):
        col, rotation = action

        # Rotate the piece to the desired rotation
        for _ in range(rotation):
            self.current_piece.rotate() # Rotates IN PLACE

        lcoords = None

        try:
            # Find where the piece would sit on the board
            lcoords = self.board.find_logical_BL_placement(self.current_piece, col)
        except ValueError as e:
            print(e)
            done = True
            # TODO Account for a fatal placement
            reward = self._calculate_reward()
            return self._get_board_state(), reward, done, {}


        if self._is_valid_action(self.current_piece, lcoords):
            self.board.place_piece(self.current_piece, lcoords)
            reward = self._calculate_reward()
            self.reward_history.append(reward)
            done = self._is_done()
            self.board.remove_tetris()
            self.current_piece = self._get_random_piece()
        else:
            print("Invalid Action")
            self.board.render()
            print(f"Action Column: {col+1} (1-{self.board_width})")
            print(f"Piece: {self.current_piece}")
            reward = self._calculate_reward() * 0.5
            done = True

        next_state = self._get_board_state()
        return next_state, reward, done, {}

    def render(self):
        self.board.render()

    def _get_random_piece(self):
        return self.pieces.make(np.random.randint(1, 8))

    def _is_valid_action(self, piece, coords):
        piece = self.current_piece

        col = coords[1] - 1

        if col < 0 or col > self.board_width:
            return False

        if col + piece.get_width() > self.board_width:
            return False
        return True

    def _calculate_reward(self):

        # Evaluate line pack
        # Packed lines produces a higher score
        # Big narrow tower would produce a low score
        active_lines = 0
        board_tiles = 0
        lines_cleared = 0
        for row in self.board.board:
            row_sum = sum(row)
            board_tiles += row_sum
            if row_sum == 0:
                continue

            active_lines += 1
            if row_sum == self.board.width:
                lines_cleared += 1

        line_score = (board_tiles+(5*lines_cleared)) / float(self.board_width * active_lines)
        reward = line_score  # That's all for now
        return reward

    def _is_done(self):
        return False

    def _get_board_state(self):
        # Get the current board state
        board_state = self.board.board.flatten()

        # Create a one-hot encoding for the current piece
        piece_one_hot = np.zeros(Tetrominos.num_tetrominos())
        piece_one_hot[self.current_piece.shape - 1] = 1

        # Concatenate the board state and the one-hot encoding
        return np.concatenate((board_state, piece_one_hot))

    # def _get_column_heights(self):
    #     """Return the heights of each column."""
    #     return [self.board_height - np.argmax(self.board.board[:, col][::-1] != 0) for col in range(self.board_width)]

    def _choose_action(self):
        """Choose the action based on a greedy strategy (lowest column height)."""
        column_heights = self._get_column_heights()
        min_height = min(column_heights)
        best_columns = [i for i, height in enumerate(column_heights) if height == min_height]
        best_column = np.random.choice(best_columns)  # Choose randomly among best columns
        best_rotation = 0  # You could implement a strategy to choose the best rotation
        return (best_column, best_rotation)


def main():
  # for x in range(1,8):
  #     s = Tetrominos.make(x)
  #     print(s)
  #     print(s.get_bottom_offsets())
  #     s.rotate()
  #     print(s)
  #     print(s.get_bottom_offsets())
  #     s.rotate()
  #     print(s)
  #     print(s.get_bottom_offsets())
  #     s.rotate()
  #     print(s)
  #     print(s.get_bottom_offsets())
  #     print("---")



  # p = Tetrominos.make(Tetrominos.O)
  # print(p.get_pattern())
  # print("---")
  # print(p.get_bottom_offsets())


  # import sys
  # sys.exit()


  # Example usage
  env = TetrisEnv()
  state = env.reset()

  done = False
  loop_limit = 10
  loop = 0
  while not done and loop < loop_limit:
      action = env.action_space.sample()  # Random action for demonstration
      next_state, reward, done, _ = env.step(action)
      env.board.render()
      print(f"Reward: {reward}, Done: {done}")
      loop += 1




In [3]:
import torch
import torch.nn as nn
import torch.optim as optim

class DQN(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(state_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, action_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x








import random
from collections import deque

class DQNAgent:
    def __init__(self, state_dim, action_dim, learning_rate=0.001, discount_factor=0.99, exploration_rate=1.0, exploration_decay=0.995, min_exploration_rate=0.01, replay_buffer_size=10000, batch_size=64):
        self.state_dim = state_dim
        self.action_dim = action_dim[0] * action_dim[1]  # Total number of actions
        self.discount_factor = discount_factor
        self.exploration_rate = exploration_rate
        self.exploration_decay = exploration_decay
        self.min_exploration_rate = min_exploration_rate
        self.replay_buffer = deque(maxlen=replay_buffer_size)
        self.batch_size = batch_size

        self.model = DQN(state_dim, self.action_dim)
        self.target_model = DQN(state_dim, self.action_dim)
        self.update_target_model()

        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        self.loss_fn = nn.MSELoss()

    def update_target_model(self):
        self.target_model.load_state_dict(self.model.state_dict())

    def remember(self, state, action, reward, next_state, done):
        self.replay_buffer.append((state, action, reward, next_state, done))

    def choose_action(self, state):
        if random.uniform(0, 1) < self.exploration_rate:
            return (random.randint(0, self.action_dim // 4 - 1), random.randint(0, 4 - 1))
        else:
            state = torch.FloatTensor(state).unsqueeze(0)
            q_values = self.model(state)
            action_index = torch.argmax(q_values).item()
            return (action_index // 4, action_index % 4)

    def replay(self):
        if len(self.replay_buffer) < self.batch_size:
            return

        batch = random.sample(self.replay_buffer, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.FloatTensor(states)
        actions = torch.LongTensor([a[0] * 4 + a[1] for a in actions])
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(next_states)
        dones = torch.FloatTensor(dones)

        q_values = self.model(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        next_q_values = self.target_model(next_states).max(1)[0]
        target_q_values = rewards + self.discount_factor * next_q_values * (1 - dones)

        loss = self.loss_fn(q_values, target_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def decay_exploration_rate(self):
        self.exploration_rate = max(self.min_exploration_rate, self.exploration_rate * self.exploration_decay)


  and should_run_async(code)


In [4]:
# Ensure TetrisEnv is defined as before

# Initialize Tetris environment
env = TetrisEnv()
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.nvec

agent = DQNAgent(state_dim, action_dim)

num_episodes = 100
target_update_interval = 10

for episode in range(num_episodes):
    state = env.reset().flatten()  # Flatten the state to fit the input of the network
    total_reward = 0
    done = False

    while not done:
        action = agent.choose_action(state)
        next_state, reward, done, _ = env.step(action)
        # env.board.render()
        next_state = next_state.flatten()

        agent.remember(state, action, reward, next_state, done)
        agent.replay()
        state = next_state
        total_reward += reward

    agent.decay_exploration_rate()

    if episode % target_update_interval == 0:
        agent.update_target_model()

    print(f"Episode: {episode + 1}, Total Reward: {total_reward}, Exploration Rate: {agent.exploration_rate}")

print("Training completed.")


  and should_run_async(code)


Trying placement at row 1
[2 1 1] > [0 0 0]
Trying placement at row 1
[1 1] > [0 2]
Trying placement at row 2
[2 2] > [0 2]
Trying placement at row 3
[3 3] > [0 2]
Trying placement at row 1
[1 1 1] > [0 0 0]
Trying placement at row 1
[2 1] > [0]
Invalid Action
_ ▆ ▆ _ _ _ _ _ _ _ 
_ ▆ ▆ _ _ _ _ _ _ _ 
_ _ ▆ ▆ _ _ ▆ _ _ _ 
_ _ _ ▆ ▆ ▆ ▆ ▆ _ _ 
Action Column: 10 (1-10)
Piece: TetrominoPiece(shape=T, rot=270, pattern= _ ▆ / ▆ ▆ / _ ▆)
Episode: 1, Total Reward: 0.85, Exploration Rate: 0.995
Trying placement at row 1
[1 1 2] > [0 0]
operands could not be broadcast together with shapes (3,) (2,) 
Episode: 2, Total Reward: nan, Exploration Rate: 0.990025
Trying placement at row 1
[1 1] > [0 0]
Trying placement at row 1
[1 1 1 1] > [0 0 0 0]
Trying placement at row 2
[2 2 2 2] > [2 2 1 1]
Trying placement at row 3
[3 3 3 3] > [2 2 1 1]
Trying placement at row 1
[1 3] > [0]
Invalid Action
_ _ ▆ ▆ ▆ ▆ _ _ _ _ 
_ _ ▆ ▆ _ _ _ _ _ _ 
_ _ ▆ ▆ ▆ ▆ ▆ ▆ _ _ 
Action Column: 10 (1-10)
Piece: TetrominoPie

  line_score = (board_tiles+(5*lines_cleared)) / float(self.board_width * active_lines)


[5 5 5] > [ 8  2 11]
Trying placement at row 6
[6 6 6] > [ 8  2 11]
Trying placement at row 7
[7 7 7] > [ 8  2 11]
Trying placement at row 8
[8 8 8] > [ 8  2 11]
Trying placement at row 9
[9 9 9] > [ 8  2 11]
Trying placement at row 10
[10 10 10] > [ 8  2 11]
Trying placement at row 11
[11 11 11] > [ 8  2 11]
Trying placement at row 12
[12 12 12] > [ 8  2 11]
Trying placement at row 11
[13 11] > [12 10]
Trying placement at row 7
[7 7 8] > [13  6]
operands could not be broadcast together with shapes (3,) (2,) 
Episode: 4, Total Reward: 6.05820512820513, Exploration Rate: 0.9801495006250001
Trying placement at row 1
[2 1 2] > [0 0 0]
Trying placement at row 1
[3 1] > [2 0]
Trying placement at row 1
[1 1] > [0 2]
Trying placement at row 2
[2 2] > [0 2]
Trying placement at row 3
[3 3] > [0 2]
Trying placement at row 1
[2 1 1] > [0]
Invalid Action
_ _ _ ▆ _ _ _ _ _ _ 
_ _ _ ▆ _ _ _ _ _ _ 
_ _ ▆ ▆ _ ▆ ▆ _ _ _ 
_ _ _ ▆ ▆ ▆ ▆ _ _ _ 
_ _ _ _ ▆ _ ▆ _ _ _ 
Action Column: 10 (1-10)
Piece: Tetromin

  states = torch.FloatTensor(states)


Episode: 12, Total Reward: 0.9, Exploration Rate: 0.9416228069143757
Trying placement at row 1
[1 1] > [0 0]
Trying placement at row 1
[2 2 1] > [0 0 0]
Trying placement at row 1
[1 1] > [2 0]
Trying placement at row 2
[2 2] > [2 0]
Trying placement at row 3
[3 3] > [2 0]
Trying placement at row 3
[4 3 3] > [2 4 4]
Trying placement at row 4
[5 4 4] > [2 4 4]
Trying placement at row 5
[6 5 5] > [2 4 4]
Trying placement at row 6
[6 6 6] > [6 6 5]
Trying placement at row 7
[7 7 7] > [6 6 5]
Trying placement at row 1
[1 1 2] > [7 7 0]
Trying placement at row 2
[2 2 3] > [7 7 0]
Trying placement at row 3
[3 3 4] > [7 7 0]
Trying placement at row 4
[4 4 5] > [7 7 0]
Trying placement at row 5
[5 5 6] > [7 7 0]
Trying placement at row 6
[6 6 7] > [7 7 0]
Trying placement at row 7
[7 7 8] > [7 7 0]
Trying placement at row 8
[8 8 9] > [7 7 0]
Trying placement at row 3
[3 3 3] > [2 2 2]
Trying placement at row 1
[1 2] > [9 0]
Trying placement at row 2
[2 3] > [9 0]
Trying placement at row 3
[3 4]