In [2]:
import copy
import random
import time
from collections import deque, namedtuple

import numpy as np
import pygame
import torch
import torch.nn as nn
import torch.optim as optim

pygame 2.6.1 (SDL 2.28.4, Python 3.12.8)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [3]:
# Tetromino definitions: piece name -> list of rotation states. Each rotation
# state is four (row, col) block offsets relative to the piece's anchor
# position; consecutive entries are successive 90-degree clockwise rotations.
SHAPES = dict(
    # I: horizontal bar, then vertical bar
    I=[
        [(0, 0), (0, 1), (0, 2), (0, 3)],
        [(0, 0), (1, 0), (2, 0), (3, 0)],
    ],
    # O: 2x2 square, identical under rotation
    O=[
        [(0, 0), (0, 1), (1, 0), (1, 1)],
    ],
    # T: stem pointing up, right, down, left
    T=[
        [(0, 1), (1, 0), (1, 1), (1, 2)],
        [(0, 1), (1, 1), (1, 2), (2, 1)],
        [(1, 0), (1, 1), (1, 2), (2, 1)],
        [(0, 1), (1, 0), (1, 1), (2, 1)],
    ],
    # L: four orientations
    L=[
        [(0, 2), (1, 0), (1, 1), (1, 2)],
        [(0, 1), (1, 1), (2, 1), (2, 2)],
        [(1, 0), (1, 1), (1, 2), (2, 0)],
        [(0, 0), (0, 1), (1, 1), (2, 1)],
    ],
)

In [18]:
class TetrisEnv:
    """Minimal Tetris environment with a ``reset`` / ``step`` API.

    The board is a ``(height, width)`` integer grid where 0 is empty and 1 is a
    locked block. The falling piece is tracked separately by its name (a key of
    ``SHAPES``), its rotation index and its ``[row, col]`` anchor position.
    """

    def __init__(self, width=10, height=20):
        self.width = width
        self.height = height
        self.board = None
        self.current_piece = None
        self.current_pos = None
        self.current_rotation = 0
        # Sum of line-clear rewards in the current episode (reset by ``reset``).
        self.score = 0

        # RGB colours used for rendering: one per piece plus the grid lines.
        self.colors = {
            'I': (0, 255, 255),    # Cyan
            'O': (255, 255, 0),    # Yellow
            'T': (128, 0, 128),    # Purple
            'L': (255, 165, 0),    # Orange
            'grid': (128, 128, 128) # Gray
        }

        self.reset()

    def reset(self):
        """Clear the board and score, spawn a new piece, return the initial state."""
        self.board = np.zeros((self.height, self.width), dtype=int)
        self.score = 0
        self.spawn_piece()
        return self._get_state()

    def spawn_piece(self):
        """Pick a random piece and place it in rotation 0 at the top, near the centre."""
        self.current_piece = random.choice(list(SHAPES.keys()))
        self.current_rotation = 0
        self.current_pos = [0, self.width // 2 - 1]

    def _get_current_shape(self):
        """Return the block offsets of the falling piece in its current rotation."""
        return SHAPES[self.current_piece][self.current_rotation]

    def _is_valid_move(self, pos, shape):
        """Return True if ``shape`` anchored at ``pos`` is in bounds and collision-free.

        Blocks above the top edge (negative row) are allowed; only rows inside the
        board are checked against locked blocks.
        """
        for dy, dx in shape:
            y = pos[0] + dy
            x = pos[1] + dx
            if x < 0 or x >= self.width or y >= self.height:
                return False
            if y >= 0 and self.board[y, x]:
                return False
        return True

    def _clear_lines(self):
        """Remove every full row, shift the rows above down, return the count removed."""
        full_rows = np.all(self.board, axis=1)
        lines_cleared = int(full_rows.sum())
        if lines_cleared:
            # Pad with the board's own dtype: np.zeros defaults to float64, which
            # would silently turn the integer board into a float one.
            empty_rows = np.zeros((lines_cleared, self.width), dtype=self.board.dtype)
            self.board = np.vstack((empty_rows, self.board[~full_rows]))
        return lines_cleared

    def _get_state(self):
        """Return the flattened board followed by ``[row, col, rotation]`` of the piece."""
        return np.append(self.board.flatten(),
                         [self.current_pos[0], self.current_pos[1],
                          self.current_rotation])

    def step(self, action):
        """Apply ``action``, then drop the falling piece by one row.

        Actions: 0 = left, 1 = right, 2 = rotate, 3 = down. Action 3 has no
        extra effect because every step already ends with a one-row drop.
        Moves or rotations that would be invalid are ignored.

        Returns:
            ``(state, reward, done)``. ``reward`` is ``lines_cleared ** 2`` when a
            piece locks. It is ``-10`` when the next piece cannot spawn (game over).
            Otherwise it is 0.
        """
        reward = 0
        done = False

        if action == 0:  # Move left
            new_pos = [self.current_pos[0], self.current_pos[1] - 1]
            if self._is_valid_move(new_pos, self._get_current_shape()):
                self.current_pos = new_pos

        elif action == 1:  # Move right
            new_pos = [self.current_pos[0], self.current_pos[1] + 1]
            if self._is_valid_move(new_pos, self._get_current_shape()):
                self.current_pos = new_pos

        elif action == 2:  # Rotate
            new_rotation = (self.current_rotation + 1) % len(SHAPES[self.current_piece])
            new_shape = SHAPES[self.current_piece][new_rotation]
            if self._is_valid_move(self.current_pos, new_shape):
                self.current_rotation = new_rotation

        # Gravity: every action is followed by a one-row drop.
        new_pos = [self.current_pos[0] + 1, self.current_pos[1]]
        if self._is_valid_move(new_pos, self._get_current_shape()):
            self.current_pos = new_pos
        else:
            # The piece cannot fall further: lock it into the board.
            for dy, dx in self._get_current_shape():
                y = self.current_pos[0] + dy
                x = self.current_pos[1] + dx
                if y >= 0:
                    self.board[y, x] = 1

            lines_cleared = self._clear_lines()
            reward = lines_cleared ** 2
            self.score += reward

            self.spawn_piece()
            if not self._is_valid_move(self.current_pos, self._get_current_shape()):
                # The new piece overlaps the stack: game over.
                done = True
                reward = -10

        return self._get_state(), reward, done

In [20]:
class TetrisVisualizer:
    """Render a ``TetrisEnv`` into a pygame window, one square cell per board cell."""

    def __init__(self, env, cell_size=30):
        pygame.init()
        self.env = env
        self.cell_size = cell_size
        self.width = env.width * cell_size
        self.height = env.height * cell_size
        self.screen = pygame.display.set_mode((self.width, self.height))
        pygame.display.set_caption('Tetris AI')

        # Colour for locked blocks. The board stores only 0/1, so a locked
        # block no longer knows which piece it came from.
        self.placed_block_color = (0, 255, 255)  # Cyan

    def _cell_rect(self, row, col):
        """Return the pixel rectangle ``[x, y, w, h]`` of board cell (row, col)."""
        size = self.cell_size
        return [col * size, row * size, size, size]

    def draw_board(self):
        """Redraw grid lines, locked blocks and the falling piece, then flip the display."""
        self.screen.fill((0, 0, 0))

        # Grid outlines (line width 1 => unfilled rectangles)
        for y in range(self.env.height):
            for x in range(self.env.width):
                pygame.draw.rect(self.screen, self.env.colors['grid'],
                                 self._cell_rect(y, x), 1)

        # Locked blocks, drawn with the dedicated placed-block colour
        for y in range(self.env.height):
            for x in range(self.env.width):
                if self.env.board[y][x]:
                    pygame.draw.rect(self.screen, self.placed_block_color,
                                     self._cell_rect(y, x))

        # Falling piece; cells above the visible area are skipped
        if self.env.current_piece:
            piece_color = self.env.colors[self.env.current_piece]
            for block in self.env._get_current_shape():
                row = self.env.current_pos[0] + block[0]
                col = self.env.current_pos[1] + block[1]
                if row >= 0:
                    pygame.draw.rect(self.screen, piece_color, self._cell_rect(row, col))

        pygame.display.flip()

    def close(self):
        """Shut down pygame (and with it the window)."""
        pygame.quit()

In [12]:
class DQN(nn.Module):
    """Fully connected Q-network: a state vector in, one Q-value per action out.

    Architecture: ``input_size -> 128 -> 64 -> output_size`` with a ReLU after
    each hidden layer and a linear (unactivated) output layer.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        # The attribute names fc1/fc2/fc3 become the state_dict keys.
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, output_size)

    def forward(self, x):
        for hidden_layer in (self.fc1, self.fc2):
            x = torch.relu(hidden_layer(x))
        return self.fc3(x)

In [13]:
class ReplayMemory:
    """Bounded FIFO buffer of transitions for experience replay.

    Once ``capacity`` transitions are stored, each push evicts the oldest one.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        # Record type for one transition; also used by callers to re-batch
        # samples via ``memory.transition(*zip(*samples))``.
        self.transition = namedtuple('Transition',
                                     ('state', 'action', 'next_state', 'reward'))
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Store one transition given positionally as (state, action, next_state, reward)."""
        record = self.transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return ``batch_size`` transitions drawn uniformly without replacement."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

In [14]:
class TetrisAgent:
    """Deep Q-learning agent with epsilon-greedy exploration and a target network.

    ``policy_net`` is trained from replayed transitions. ``target_net`` supplies
    the bootstrap targets and is refreshed by copying the policy weights
    (done by the caller, every ``target_update`` episodes in the training loop).
    """

    def __init__(self, state_size, action_size):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Kept so random exploration samples the same action space as the network.
        self.action_size = action_size
        self.policy_net = DQN(state_size, action_size).to(self.device)
        self.target_net = DQN(state_size, action_size).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())

        self.optimizer = optim.Adam(self.policy_net.parameters())
        self.memory = ReplayMemory(10000)
        self.criterion = nn.SmoothL1Loss()  # Huber loss

        self.batch_size = 128
        self.gamma = 0.99        # discount factor
        self.eps_start = 0.9     # exploration rate at step 0
        self.eps_end = 0.05      # asymptotic exploration rate
        self.eps_decay = 200     # steps over which (epsilon - eps_end) shrinks by a factor e
        self.target_update = 10  # target-sync interval, read by the training loop
        self.steps_done = 0      # number of select_action calls so far

    def _as_float_tensor(self, data):
        """Convert array-like data (or a sequence of equal-shape arrays) to float32 on the device.

        Going through a single NumPy array avoids building a tensor from a Python
        list of ndarrays, which PyTorch warns is extremely slow.
        """
        return torch.as_tensor(np.asarray(data, dtype=np.float32), device=self.device)

    def select_action(self, state):
        """Return a ``(1, 1)`` long tensor holding the chosen action index.

        With probability epsilon, the action is uniform over ``range(self.action_size)``.
        Epsilon decays exponentially from ``eps_start`` towards ``eps_end`` as
        ``steps_done`` grows. Otherwise the action is the argmax of ``policy_net``.
        """
        sample = random.random()
        eps_threshold = self.eps_end + (self.eps_start - self.eps_end) * \
            np.exp(-1. * self.steps_done / self.eps_decay)
        self.steps_done += 1

        if sample > eps_threshold:
            with torch.no_grad():
                state_t = self._as_float_tensor(state).unsqueeze(0)
                return self.policy_net(state_t).max(1)[1].view(1, 1)
        return torch.tensor([[random.randrange(self.action_size)]],
                            device=self.device, dtype=torch.long)

    def optimize_model(self):
        """Run one gradient step on a random minibatch from replay memory.

        Does nothing until the memory holds at least ``batch_size`` transitions.
        Transitions whose ``next_state`` is None are treated as terminal. Their
        target is just the immediate reward.
        """
        if len(self.memory) < self.batch_size:
            return

        transitions = self.memory.sample(self.batch_size)
        # Transpose a list of Transitions into one Transition of tuples.
        batch = self.memory.transition(*zip(*transitions))

        non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                      device=self.device, dtype=torch.bool)
        non_final_next = [s for s in batch.next_state if s is not None]

        state_batch = self._as_float_tensor(batch.state)
        action_batch = torch.cat(batch.action).to(self.device)
        reward_batch = self._as_float_tensor(batch.reward)

        # Q(s, a) for the actions that were actually taken.
        state_action_values = self.policy_net(state_batch).gather(1, action_batch)

        # max_a' Q_target(s', a'); stays 0 for terminal transitions.
        next_state_values = torch.zeros(self.batch_size, device=self.device)
        if non_final_next:  # an all-terminal batch would otherwise feed an empty tensor
            with torch.no_grad():
                next_state_values[non_final_mask] = \
                    self.target_net(self._as_float_tensor(non_final_next)).max(1)[0]
        expected_state_action_values = reward_batch + self.gamma * next_state_values

        loss = self.criterion(state_action_values, expected_state_action_values.unsqueeze(1))

        self.optimizer.zero_grad()
        loss.backward()
        # Clamp every gradient element to [-1, 1] for stability (skips params without grads).
        nn.utils.clip_grad_value_(self.policy_net.parameters(), 1)
        self.optimizer.step()

In [22]:
# Training function with visualization
def train_with_visualization(num_episodes=1000, visualize_every=100):
    """Train a fresh ``TetrisAgent`` with DQN, rendering every ``visualize_every``-th episode.

    Args:
        num_episodes: Number of episodes to play.
        visualize_every: Render episodes whose index is a multiple of this value.
            Pass 0 or None to train without rendering.

    Closing the pygame window stops training early. The agent and the scores of
    the episodes finished so far are still returned. The visualizer is always
    closed on exit.

    Returns:
        ``(agent, scores)`` where ``scores`` holds each finished episode's total reward.
    """
    env = TetrisEnv()
    visualizer = TetrisVisualizer(env)
    state_size = env.width * env.height + 3  # flattened board + (row, col, rotation)
    action_size = 4                          # left, right, rotate, down

    agent = TetrisAgent(state_size, action_size)
    scores = []

    try:
        for episode in range(num_episodes):
            state = env.reset()
            total_reward = 0

            # Guarded so that 0 / None disables rendering instead of dividing by zero.
            should_visualize = bool(visualize_every) and episode % visualize_every == 0

            while True:
                # Keep the window responsive; closing it ends training.
                # The `finally` clause below closes the visualizer.
                if any(event.type == pygame.QUIT for event in pygame.event.get()):
                    return agent, scores

                action = agent.select_action(state)
                next_state, reward, done = env.step(action.item())
                total_reward += reward

                if should_visualize:
                    visualizer.draw_board()
                    time.sleep(0.05)

                # Terminal transitions are stored with next_state=None, which
                # optimize_model treats as "no future value".
                agent.memory.push(state, action, None if done else next_state, reward)
                state = next_state

                agent.optimize_model()

                if done:
                    break

            # Periodically copy the policy weights into the target network.
            if episode % agent.target_update == 0:
                agent.target_net.load_state_dict(agent.policy_net.state_dict())

            scores.append(total_reward)
            if episode % 10 == 0:
                print(f"Episode {episode}, Score: {total_reward}, Average Score: {np.mean(scores[-100:])}")

    finally:
        visualizer.close()

    return agent, scores

In [16]:
# Save and load functions for the trained model
def save_model(agent, filename):
    """Save the agent's networks, optimizer state and step counter to ``filename``.

    ``steps_done`` is stored so that a loader can restore the agent's position
    in its epsilon-decay schedule rather than restarting at full exploration.
    """
    torch.save({
        'policy_net_state_dict': agent.policy_net.state_dict(),
        'target_net_state_dict': agent.target_net.state_dict(),
        'optimizer_state_dict': agent.optimizer.state_dict(),
        'steps_done': int(getattr(agent, 'steps_done', 0)),
    }, filename)

In [17]:
def load_model(filename, state_size, action_size):
    """Build a ``TetrisAgent`` and restore its state from a checkpoint file.

    Tensors are mapped onto the new agent's device, so a checkpoint written on a
    GPU machine can be loaded on a CPU-only one. If the checkpoint contains a
    ``steps_done`` counter, it is restored. Checkpoints without one leave the
    counter at its initial value.
    """
    agent = TetrisAgent(state_size, action_size)
    # weights_only=True refuses arbitrary pickled objects; a checkpoint of state
    # dicts (plus an optional int counter) needs nothing more.
    checkpoint = torch.load(filename, map_location=agent.device, weights_only=True)
    agent.policy_net.load_state_dict(checkpoint['policy_net_state_dict'])
    agent.target_net.load_state_dict(checkpoint['target_net_state_dict'])
    agent.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if 'steps_done' in checkpoint:
        agent.steps_done = int(checkpoint['steps_done'])
    return agent

In [21]:
def visualize_ai_play(agent, env, visualizer, num_games=5, delay=0.1):
    """Let the agent play ``num_games`` greedy games, rendering every move.

    The agent always takes the action with the highest Q-value (no exploration).
    Each game's total reward is printed. Closing the window aborts immediately.
    In every case the visualizer is closed before returning.
    """
    for game in range(num_games):
        state = env.reset()
        total_reward = 0
        done = False

        while not done:
            # Abort cleanly if the user closes the window.
            if any(event.type == pygame.QUIT for event in pygame.event.get()):
                visualizer.close()
                return

            # Greedy action from the policy network.
            with torch.no_grad():
                batch = torch.FloatTensor(state).unsqueeze(0).to(agent.device)
                q_values = agent.policy_net(batch)
                action = q_values.max(1)[1].view(1, 1)

            state, reward, done = env.step(action.item())
            total_reward += reward

            visualizer.draw_board()
            time.sleep(delay)  # slow down so a human can follow the play

        print(f"Game {game + 1} Score: {total_reward}")

    visualizer.close()

In [24]:
# Train for 1000 episodes, rendering every 100th one
agent, scores = train_with_visualization(num_episodes=1000, visualize_every=100)

# Save the trained model
save_model(agent, 'tetris_model.pth')

# Reload the trained model, as a later session would. Sizes are derived from the
# environment instead of hard-coding 203 (= 10 * 20 board cells + 3).
env = TetrisEnv()
state_size = env.width * env.height + 3  # flattened board + (row, col, rotation)
action_size = 4                          # left, right, rotate, down
agent = load_model('tetris_model.pth', state_size, action_size)

# Watch the AI play 5 games with 0.1 second delay between moves
visualizer = TetrisVisualizer(env)
visualize_ai_play(agent, env, visualizer, num_games=5, delay=0.1)

AttributeError: 'TetrisEnv' object has no attribute 'colors'