In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque, namedtuple
import pygame
import copy
import time
import matplotlib.pyplot as plt
import os

# For reproducibility
random.seed(42)
torch.manual_seed(42)
np.random.seed(42)

pygame 2.6.1 (SDL 2.28.4, Python 3.12.8)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
SHAPES = {
    'I': [[(0,0), (0,1), (0,2), (0,3)],
          [(0,0), (1,0), (2,0), (3,0)]],
    'O': [[(0,0), (0,1), (1,0), (1,1)]],
    'T': [[(0,1), (1,0), (1,1), (1,2)],
          [(0,1), (1,1), (1,2), (2,1)],
          [(1,0), (1,1), (1,2), (2,1)],
          [(0,1), (1,0), (1,1), (2,1)]],
    'L': [[(0,2), (1,0), (1,1), (1,2)],
          [(0,1), (1,1), (2,1), (2,2)],
          [(1,0), (1,1), (1,2), (2,0)],
          [(0,0), (0,1), (1,1), (2,1)]],
}

# Define Transition at module level
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

In [17]:
class TetrisEnv:
    def __init__(self, width=10, height=20):
        self.width = width
        self.height = height
        self.board = None
        self.current_piece = None
        self.current_pos = None
        self.current_rotation = None
        self.score = 0
        
        # Enhanced color scheme
        self.colors = {
            'I': (0, 240, 240),     # Bright Cyan
            'O': (240, 240, 0),     # Bright Yellow
            'T': (160, 0, 240),     # Bright Purple
            'L': (240, 160, 0),     # Bright Orange
            'grid': (40, 40, 40),   # Darker Gray
            'background': (0, 0, 0), # Black
            'ghost': (128, 128, 128, 128)  # Semi-transparent ghost piece
        }
        
        # Reward weights
        self.reward_weights = {
            'lines_cleared': 20,      # Base multiplier for lines cleared
            'tetris_bonus': 50,       # Additional bonus for clearing 4 lines
            'height_penalty': 2.5,    # Penalty for stack height
            'hole_penalty': 3.0,      # Penalty for creating holes
            'surface_penalty': 0.2,   # Penalty for surface roughness
            'well_reward': 0.3,       # Reward for creating wells for long pieces
            'clear_field_bonus': 100, # Bonus for clearing the field
            'game_over_penalty': 50   # Penalty for game over
        }
        
        self.board = np.zeros((height, width), dtype=int)
        self.spawn_piece()
    
    def reset(self):
        """Reset the environment to initial state"""
        self.board = np.zeros((self.height, self.width), dtype=int)
        self.score = 0
        self.spawn_piece()
        return self._get_state()
    
    def spawn_piece(self):
        """Spawn a new piece at the top of the board"""
        self.current_piece = random.choice(list(SHAPES.keys()))
        self.current_rotation = 0
        self.current_pos = [0, self.width // 2 - 1]
    
    def _get_current_shape(self):
        """Get the current piece's shape based on rotation"""
        return SHAPES[self.current_piece][self.current_rotation]
    
    def _is_valid_move(self, pos, shape):
        """Check if the given position is valid for the current piece"""
        for block in shape:
            y = pos[0] + block[0]
            x = pos[1] + block[1]
            if (x < 0 or x >= self.width or 
                y >= self.height or 
                (y >= 0 and self.board[y][x])):
                return False
        return True
    
    def _get_column_height(self, col):
        """Get the height of blocks in a given column"""
        for row in range(self.height):
            if self.board[row][col] == 1:
                return self.height - row
        return 0
    
    def _get_column_holes(self, col):
        """Count holes in a given column"""
        holes = 0
        block_found = False
        for row in range(self.height):
            if self.board[row][col] == 1:
                block_found = True
            elif block_found and self.board[row][col] == 0:
                holes += 1
        return holes
    
    def _get_max_height(self):
        """Get the maximum height of the stack"""
        for y in range(self.height):
            if 1 in self.board[y]:
                return self.height - y
        return 0
    
    def _count_holes(self):
        """Count total holes in the board"""
        return sum(self._get_column_holes(col) for col in range(self.width))
    
    def _clear_lines(self):
        """Clear complete lines and return the number of lines cleared"""
        lines_cleared = 0
        y = self.height - 1
        while y >= 0:
            if np.all(self.board[y]):
                self.board = np.vstack((np.zeros((1, self.width)), 
                                      self.board[:y], 
                                      self.board[y+1:]))
                lines_cleared += 1
            else:
                y -= 1
        return lines_cleared
    
    def _get_surface_roughness(self):
        """Calculate surface roughness"""
        heights = [self._get_column_height(col) for col in range(self.width)]
        return sum(abs(heights[i] - heights[i-1]) for i in range(1, len(heights)))
    
    def _detect_wells(self):
        """Detect and count wells suitable for I-pieces"""
        wells = 0
        heights = [self._get_column_height(col) for col in range(self.width)]
        
        for col in range(self.width):
            if col == 0:
                if heights[col] + 3 < heights[col + 1]:
                    wells += 1
            elif col == self.width - 1:
                if heights[col] + 3 < heights[col - 1]:
                    wells += 1
            else:
                if heights[col] + 3 < heights[col - 1] and heights[col] + 3 < heights[col + 1]:
                    wells += 1
        return wells
    
    def _get_state(self):
        """Get the current state representation"""
        board_state = self.board.flatten()
        piece_state = np.zeros(len(SHAPES))
        piece_state[list(SHAPES.keys()).index(self.current_piece)] = 1
        heights = np.array([self._get_column_height(col) for col in range(self.width)])
        holes = np.array([self._get_column_holes(col) for col in range(self.width)])
        bumpiness = np.diff(heights)
        
        return np.concatenate([
            board_state,
            piece_state,
            heights / self.height,
            holes / self.height,
            bumpiness / self.height
        ])
    
    def step(self, action):
        """Execute one step in the environment"""
        reward = 0
        done = False
        initial_height = self._get_max_height()
        initial_holes = self._count_holes()
        
        # Handle actions
        if action == 0:  # Move left
            new_pos = [self.current_pos[0], self.current_pos[1] - 1]
            if self._is_valid_move(new_pos, self._get_current_shape()):
                self.current_pos = new_pos
        elif action == 1:  # Move right
            new_pos = [self.current_pos[0], self.current_pos[1] + 1]
            if self._is_valid_move(new_pos, self._get_current_shape()):
                self.current_pos = new_pos
        elif action == 2:  # Rotate
            new_rotation = (self.current_rotation + 1) % len(SHAPES[self.current_piece])
            new_shape = SHAPES[self.current_piece][new_rotation]
            if self._is_valid_move(self.current_pos, new_shape):
                self.current_rotation = new_rotation
        
        # Try to move down
        new_pos = [self.current_pos[0] + 1, self.current_pos[1]]
        if self._is_valid_move(new_pos, self._get_current_shape()):
            self.current_pos = new_pos
        else:
            # Place the piece
            for block in self._get_current_shape():
                y = self.current_pos[0] + block[0]
                x = self.current_pos[1] + block[1]
                if y >= 0:
                    self.board[y][x] = 1
            
            # Calculate rewards
            lines_cleared = self._clear_lines()
            new_height = self._get_max_height()
            new_holes = self._count_holes()
            surface_roughness = self._get_surface_roughness()
            
            reward = self._calculate_reward(
                lines_cleared,
                new_height - initial_height,
                new_holes - initial_holes,
                surface_roughness
            )
            
            # Spawn new piece
            self.spawn_piece()
            if not self._is_valid_move(self.current_pos, self._get_current_shape()):
                done = True
                reward -= self.reward_weights['game_over_penalty']
        
        return self._get_state(), reward, done

In [19]:
class TetrisVisualizer:
    def __init__(self, env, cell_size=30):
        """Initialize the Tetris visualizer"""
        pygame.init()
        self.env = env
        self.cell_size = cell_size
        self.width = env.width * cell_size
        self.height = env.height * cell_size
        
        # Add padding for score display
        self.padding = 100
        self.screen = pygame.display.set_mode((self.width + self.padding, self.height))
        pygame.display.set_caption('Tetris AI')
        
        # Initialize fonts
        self.font = pygame.font.Font(None, 36)
        
        # Add gradient effects
        self.gradient_colors = self._create_gradient()
    
    def _create_gradient(self):
        """Create a subtle gradient effect for blocks"""
        gradient = {}
        for piece, color in self.env.colors.items():
            if isinstance(color, tuple) and piece != 'grid':
                # Handle both RGB and RGBA colors
                if len(color) == 4:  # RGBA
                    r, g, b, a = color
                else:  # RGB
                    r, g, b = color
                gradient[piece] = {
                    'main': (r, g, b),
                    'light': (min(r + 30, 255), min(g + 30, 255), min(b + 30, 255)),
                    'dark': (max(r - 30, 0), max(g - 30, 0), max(b - 30, 0))
                }
        return gradient
    
    def _draw_block(self, x, y, color_key):
        """Draw a single block with gradient effect"""
        if color_key in self.gradient_colors:
            colors = self.gradient_colors[color_key]
            # Main block
            pygame.draw.rect(self.screen, colors['main'],
                           [x, y, self.cell_size, self.cell_size])
            # Light edge
            pygame.draw.line(self.screen, colors['light'],
                           (x, y), (x + self.cell_size, y), 2)
            pygame.draw.line(self.screen, colors['light'],
                           (x, y), (x, y + self.cell_size), 2)
            # Dark edge
            pygame.draw.line(self.screen, colors['dark'],
                           (x + self.cell_size, y), (x + self.cell_size, y + self.cell_size), 2)
            pygame.draw.line(self.screen, colors['dark'],
                           (x, y + self.cell_size), (x + self.cell_size, y + self.cell_size), 2)
        else:
            # Fallback for pieces without gradient colors
            pygame.draw.rect(self.screen, self.env.colors.get(color_key, (128, 128, 128)),
                           [x, y, self.cell_size, self.cell_size])
    
    def _get_ghost_position(self):
        """Calculate the position where the current piece would land"""
        if not self.env.current_piece:
            return 0
        
        ghost_y = self.env.current_pos[0]
        while self.env._is_valid_move([ghost_y + 1, self.env.current_pos[1]], 
                                    self.env._get_current_shape()):
            ghost_y += 1
        return ghost_y
    
    def draw_board(self):
        """Draw the complete game board with all pieces and UI elements"""
        self.screen.fill(self.env.colors['background'])
        
        # Draw the grid
        for y in range(self.env.height):
            fade = 1 - (y / self.env.height * 0.5)  # Creates a subtle fade effect
            grid_color = tuple(int(c * fade) for c in self.env.colors['grid'])
            for x in range(self.env.width):
                pygame.draw.rect(self.screen, grid_color,
                               [x * self.cell_size, y * self.cell_size,
                                self.cell_size, self.cell_size], 1)
        
        # Draw ghost piece if there is a current piece
        if self.env.current_piece:
            ghost_y = self._get_ghost_position()
            ghost_color = self.env.colors['ghost']
            if len(ghost_color) == 3:  # Convert RGB to RGBA
                ghost_color = (*ghost_color, 128)
            ghost_surface = pygame.Surface((self.cell_size, self.cell_size), pygame.SRCALPHA)
            pygame.draw.rect(ghost_surface, ghost_color, ghost_surface.get_rect())
            
            for block in self.env._get_current_shape():
                x = (self.env.current_pos[1] + block[1]) * self.cell_size
                y = (ghost_y + block[0]) * self.cell_size
                if ghost_y + block[0] >= 0:
                    self.screen.blit(ghost_surface, (x, y))
        
        # Draw placed blocks
        for y in range(self.env.height):
            for x in range(self.env.width):
                if self.env.board[y][x]:
                    self._draw_block(x * self.cell_size, y * self.cell_size, 
                                   self.env.current_piece)
        
        # Draw current piece
        if self.env.current_piece:
            for block in self.env._get_current_shape():
                x = (self.env.current_pos[1] + block[1]) * self.cell_size
                y = (self.env.current_pos[0] + block[0]) * self.cell_size
                if self.env.current_pos[0] + block[0] >= 0:
                    self._draw_block(x, y, self.env.current_piece)
        
        # Draw score and other stats
        self._draw_stats()
        
        pygame.display.flip()
    
    def _draw_stats(self):
        """Draw score and other statistics"""
        score_text = self.font.render(f'Score: {self.env.score}', True, (255, 255, 255))
        self.screen.blit(score_text, (self.width + 10, 20))
    
    def close(self):
        """Clean up and close the visualization"""
        pygame.quit()

In [8]:
class DQN(nn.Module):
    def __init__(self, input_size, output_size):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(input_size, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, output_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        
    def forward(self, x):
        x = self.dropout(self.relu(self.fc1(x)))
        x = self.dropout(self.relu(self.fc2(x)))
        x = self.relu(self.fc3(x))
        return self.fc4(x)

In [9]:
class ReplayMemory:
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = deque([], maxlen=capacity)
    
    def push(self, *args):
        self.memory.append(Transition(*args))
    
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)
    
    def __len__(self):
        return len(self.memory)

In [10]:
class TetrisAgent:
    def __init__(self, state_size, action_size):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.state_size = state_size
        self.action_size = action_size
        
        self.policy_net = DQN(state_size, action_size).to(self.device)
        self.target_net = DQN(state_size, action_size).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        
        self.batch_size = 256
        self.gamma = 0.99
        self.eps_start = 1.0
        self.eps_end = 0.01
        self.eps_decay = 2000
        self.target_update = 5
        self.learning_rate = 0.0001
        
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.learning_rate)
        self.memory = ReplayMemory(10000)
        self.steps_done = 0
        
        # Training statistics
        self.episodes_completed = 0
        self.best_score = float('-inf')
        self.recent_scores = deque(maxlen=100)

In [15]:
def save_state(self, filename):
        try:
            state = {
                'policy_net_state': self.policy_net.state_dict(),
                'target_net_state': self.target_net.state_dict(),
                'optimizer_state': self.optimizer.state_dict(),
                'steps_done': self.steps_done,
                'memory': list(self.memory.memory),
                'epsilon': self.eps_start,
                'training_stats': {
                    'episodes_completed': self.episodes_completed,
                    'best_score': self.best_score,
                    'recent_scores': list(self.recent_scores)
                }
            }
            
            directory = 'saved_models'
            os.makedirs(directory, exist_ok=True)
            
            filepath = os.path.join(directory, filename)
            torch.save(state, filepath)
            print(f"Model saved to {filepath}")
        except Exception as e:
            print(f"Error saving model: {e}")
    
def load_state(self, filename):
        try:
            filepath = os.path.join('saved_models', filename)
            if not os.path.exists(filepath):
                print(f"No saved model found at {filepath}")
                return False
            
            state = torch.load(filepath, map_location=self.device)
            
            self.policy_net.load_state_dict(state['policy_net_state'])
            self.target_net.load_state_dict(state['target_net_state'])
            self.optimizer.load_state_dict(state['optimizer_state'])
            self.steps_done = state['steps_done']
            
            # Restore training statistics
            training_stats = state.get('training_stats', {})
            self.episodes_completed = training_stats.get('episodes_completed', 0)
            self.best_score = training_stats.get('best_score', float('-inf'))
            self.recent_scores = deque(training_stats.get('recent_scores', []), maxlen=100)
            
            # Restore replay memory
            self.memory = ReplayMemory(self.memory.capacity)
            self.memory.memory = deque(state['memory'], maxlen=self.memory.capacity)
            
            print(f"Model loaded from {filepath}")
            return True
        except Exception as e:
            print(f"Error loading model: {e}")
            return False
            
def update_epsilon(self):
    """Update exploration rate"""
    self.eps_start = max(self.eps_end, self.eps_start * 0.995)  # Decay epsilon
    
def get_state_dict(self):
    """Get a dictionary of current agent state"""
    return {
        'steps': self.steps_done,
        'episodes': self.episodes_completed,
        'best_score': self.best_score,
        'recent_avg': np.mean(self.recent_scores) if self.recent_scores else 0,
        'epsilon': self.eps_start
    }
    
def add_experience(self, state, action, next_state, reward, done):
    """Add an experience to memory and handle end-of-episode tasks"""
    # Store transition in memory
    self.memory.push(state, action, next_state, reward)
    
    if done:
        self.episodes_completed += 1
        self.update_epsilon()
        
def train_step(self, state, action, next_state, reward, done):
    """Perform a single training step"""
    # Add experience to memory
    self.add_experience(state, action, next_state, reward, done)
    
    # Optimize model
    self.optimize_model()
    
    # Update target network if needed
    if self.steps_done % (self.target_update * self.batch_size) == 0:
        self.target_net.load_state_dict(self.policy_net.state_dict())
        
    return self.get_state_dict()

In [11]:
def select_action(self, state):
        sample = random.random()
        eps_threshold = self.eps_end + (self.eps_start - self.eps_end) * \
                       np.exp(-1. * self.steps_done / self.eps_decay)
        self.steps_done += 1
        
        if sample > eps_threshold:
            with torch.no_grad():
                state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
                return self.policy_net(state).max(1)[1].view(1, 1)
        else:
            return torch.tensor([[random.randrange(self.action_size)]], 
                              device=self.device, dtype=torch.long)
    
def optimize_model(self):
    if len(self.memory) < self.batch_size:
        return
    
    transitions = self.memory.sample(self.batch_size)
    batch = Transition(*zip(*transitions))
    
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                          batch.next_state)), 
                                device=self.device, dtype=torch.bool)
    non_final_next_states = torch.FloatTensor([s for s in batch.next_state
                                             if s is not None]).to(self.device)
    
    state_batch = torch.FloatTensor(batch.state).to(self.device)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.FloatTensor(batch.reward).to(self.device)
    
    state_action_values = self.policy_net(state_batch).gather(1, action_batch)
    
    next_state_values = torch.zeros(self.batch_size, device=self.device)
    next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1)[0].detach()
    expected_state_action_values = (next_state_values * self.gamma) + reward_batch
    
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))
    
    self.optimizer.zero_grad()
    loss.backward()
    for param in self.policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    self.optimizer.step()

In [12]:
def train_with_visualization(num_episodes=1000, load_file=None, save_file='tetris_model.pth'):
    env = TetrisEnv()
    visualizer = TetrisVisualizer(env)
    state_size = 200 + len(SHAPES) + env.width * 2 + (env.width - 1)  # Updated state size
    action_size = 4
    
    agent = TetrisAgent(state_size, action_size)
    
    # Load previous state if specified
    if load_file:
        agent.load_state(load_file)
    
    scores = []
    best_score = float('-inf')
    
    try:
        for episode in range(num_episodes):
            state = env.reset()
            episode_score = 0
            done = False
            
            # Play one episode
            while not done:
                # Select and perform an action
                action = agent.select_action(state)
                next_state, reward, done = env.step(action.item())
                episode_score += reward
                
                # Store the transition in memory
                agent.memory.push(state, action, next_state, reward)
                
                # Move to the next state
                state = next_state
                
                # Perform optimization step
                agent.optimize_model()
                
                # Visualize every 10th episode
                if episode % 10 == 0:
                    visualizer.draw_board()
                    pygame.event.pump()  # Handle pygame events
                    time.sleep(0.05)  # Add small delay for visualization
            
            # Update training statistics
            scores.append(episode_score)
            agent.recent_scores.append(episode_score)
            if episode_score > best_score:
                best_score = episode_score
                agent.best_score = best_score
            
            # Update the target network every few episodes
            if episode % agent.target_update == 0:
                agent.target_net.load_state_dict(agent.policy_net.state_dict())
            
            # Save the model periodically
            if episode % 100 == 0:
                agent.save_state(save_file)
            
            # Print progress every 10 episodes
            if episode % 10 == 0:
                avg_score = np.mean(agent.recent_scores)
                print(f'Episode {episode}/{num_episodes}, Score: {episode_score:.2f}, Avg Score: {avg_score:.2f}, Best: {best_score:.2f}')
                
    except KeyboardInterrupt:
        print("\nTraining interrupted by user")
        agent.save_state(save_file)  # Save on interrupt
    finally:
        visualizer.close()
    
    return scores

In [13]:
def evaluate_agent(agent, num_episodes=10, visualize=True):
    env = TetrisEnv()
    visualizer = TetrisVisualizer(env) if visualize else None
    scores = []
    
    try:
        for episode in range(num_episodes):
            state = env.reset()
            episode_score = 0
            done = False
            
            while not done:
                # Select action without exploration
                with torch.no_grad():
                    state_tensor = torch.FloatTensor(state).unsqueeze(0).to(agent.device)
                    action = agent.policy_net(state_tensor).max(1)[1].view(1, 1)
                
                # Perform action
                next_state, reward, done = env.step(action.item())
                episode_score += reward
                state = next_state
                
                # Visualize if requested
                if visualize:
                    visualizer.draw_board()
                    pygame.event.pump()
                    time.sleep(0.1)
            
            scores.append(episode_score)
            print(f'Evaluation Episode {episode + 1}/{num_episodes}, Score: {episode_score:.2f}')
    
    except KeyboardInterrupt:
        print("\nEvaluation interrupted by user")
    finally:
        if visualize:
            visualizer.close()
    
    return scores

In [14]:
def run_tetris_training(training_episodes=1000, eval_episodes=5, save_interval=100, 
                       load_file=None, save_file='tetris_model.pth'):
    # Train the agent
    print("Starting training...")
    training_scores = train_with_visualization(
        num_episodes=training_episodes,
        load_file=load_file,
        save_file=save_file
    )
    
    # Plot training progress
    plt.figure(figsize=(12, 6))
    plt.plot(training_scores)
    plt.title('Training Progress')
    plt.xlabel('Episode')
    plt.ylabel('Score')
    plt.grid(True)
    plt.show()
    
    # Plot moving average
    window_size = 50
    moving_avg = np.convolve(training_scores, np.ones(window_size)/window_size, mode='valid')
    plt.figure(figsize=(12, 6))
    plt.plot(moving_avg)
    plt.title(f'Moving Average Score (Window Size: {window_size})')
    plt.xlabel('Episode')
    plt.ylabel('Average Score')
    plt.grid(True)
    plt.show()
    
    # Evaluate the trained agent
    print("\nEvaluating trained agent...")
    env = TetrisEnv()
    state_size = 200 + len(SHAPES) + env.width * 2 + (env.width - 1)
    action_size = 4
    agent = TetrisAgent(state_size, action_size)
    
    # Load the final saved model for evaluation
    agent.load_state(save_file)
    
    eval_scores = evaluate_agent(agent, num_episodes=eval_episodes)
    
    print(f"\nEvaluation Results:")
    print(f"Average Score: {np.mean(eval_scores):.2f}")
    print(f"Best Score: {np.max(eval_scores):.2f}")
    print(f"Worst Score: {np.min(eval_scores):.2f}")
    
    return training_scores, eval_scores

In [20]:
# To start fresh training:
# training_scores, eval_scores = run_tetris_training(training_episodes=1000)

# To continue training from a saved model:
training_scores, eval_scores = run_tetris_training(
    training_episodes=1000,
    load_file='tetris_model.pth'
)

Starting training...


AttributeError: 'TetrisAgent' object has no attribute 'load_state'

: 