In [None]:
import math
import time
import cv2
import mss
import numpy as np
from PIL import Image
from keras.models import Sequential, Model, load_model,model_from_json
from keras.layers import Conv2D, Dense, MaxPooling2D, Flatten, Input, BatchNormalization, Dropout, Add
from keras.optimizers import RMSprop, Adam
import tensorflow as tf
from keras import backend as K
import os
import datetime
from tensorflow.keras.callbacks import TensorBoard
import sys
import ctypes
from ctypes import wintypes, c_int, byref
import time, random
from collections import deque, Counter
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from datetime import datetime

# Handle to user32.dll. use_last_error=True makes ctypes keep the Win32 error
# code of each call so ctypes.get_last_error() can report it afterwards.
user32 = ctypes.WinDLL('user32', use_last_error=True)

# Values for INPUT.type: which member of INPUT's union carries the event.
INPUT_MOUSE    = 0
INPUT_KEYBOARD = 1
INPUT_HARDWARE = 2

# Bit flags for KEYBDINPUT.dwFlags.
KEYEVENTF_EXTENDEDKEY = 0x0001
KEYEVENTF_KEYUP       = 0x0002
KEYEVENTF_UNICODE     = 0x0004
KEYEVENTF_SCANCODE    = 0x0008

# MapVirtualKeyExW translation mode: virtual-key code -> scan code.
MAPVK_VK_TO_VSC = 0

# msdn.microsoft.com/en-us/library/dd375731
# NOTE(review): in that virtual-key table 0x12 is VK_MENU (either Alt key);
# the left-Alt-specific VK_LMENU is 0xA4. Confirm which one the remap expects.
VK_LMENU    = 0x12 # Left Alt key 18
VK_LCONTROL = 0xA2 # Left Ctrl key 162

# Playstation valid buttons: each controller button is driven by sending the
# keyboard key whose virtual-key code is assigned here (presumably matched by
# a key mapping on the remote-play side -- confirm).
DPAD_UP     = 0x57 # W key 87
DPAD_DOWN   = 0x53 # S key 83
DPAD_LEFT   = 0x41 # A key 65
DPAD_RIGHT  = 0x44 # D key 68
CROSS       = 0x45 # E key 69
SQUARE      = 0x52 # R key 82
CIRCLE      = 0x54 # T key 84
TRIANGLE    = 0x59 # Y key 89
R1          = 0x51 # Q key 81
TOUCHPAD    = 0x58 # X l1 key 88

# Diagonal directions as pairs of D-pad keys.
# NOTE(review): when one of these pairs is itself the chosen action,
# InputHandler.execute_action iterates the tuple and presses, holds and
# releases each key in turn, i.e. NOT simultaneously. Only a nested tuple
# (as in CROSS_SQUARE) is pressed as a chord. Confirm this is intended.
DIAG_DOWN_LEFT = (DPAD_LEFT, DPAD_DOWN)
DIAG_DOWN_RIGHT = (DPAD_RIGHT, DPAD_DOWN)
DIAG_UP_LEFT = (DPAD_LEFT, DPAD_UP)
DIAG_UP_RIGHT = (DPAD_RIGHT, DPAD_UP)

# Multi-button attacks. A nested tuple is pressed as a chord. A leading 0
# presses nothing but still costs one hold_delay() during playback.
# NOTE(review): CROSS_CIRCLE (four TRIANGLE presses) and SQUARE_TRIANGLE
# (CROSS then TRIANGLE) do not press the buttons their names suggest. Confirm.
CROSS_SQUARE = (0,(CROSS, SQUARE))
CROSS_CIRCLE = (TRIANGLE,TRIANGLE,TRIANGLE,TRIANGLE)
SQUARE_TRIANGLE = (CROSS,TRIANGLE)
TRIANGLE_CIRCLE = (0,(TRIANGLE, CIRCLE))

# Scripted string. InputHandler.execute_action special-cases action index 0
# with its own hand-written timing, so ATTACK01 must stay first in `attack`.
ATTACK01 = (SQUARE,TRIANGLE,DPAD_RIGHT,DPAD_RIGHT,TRIANGLE,DPAD_RIGHT,CIRCLE)

# Press-duration styles sampled by InputHandler.get_actions.
delay = ['hold','tap']
# available inputs by type (0 means "no input")
direction = [0, DPAD_UP, DPAD_DOWN, DPAD_LEFT, DPAD_RIGHT, DIAG_DOWN_LEFT, DIAG_DOWN_RIGHT, DIAG_UP_LEFT, DIAG_UP_RIGHT]
attack = [ATTACK01,TRIANGLE, CIRCLE, CROSS, SQUARE,0, DPAD_UP, DPAD_DOWN, DPAD_LEFT, DPAD_RIGHT, DIAG_DOWN_LEFT, DIAG_DOWN_RIGHT, DIAG_UP_LEFT, DIAG_UP_RIGHT,
           CROSS_SQUARE, CROSS_CIRCLE, SQUARE_TRIANGLE, TRIANGLE_CIRCLE,R1,]
rage = [R1]

# all valid actions; an action index selects one entry of this list
valid_actions = attack
# valid_actions = [(x,y) for x in direction for y in attack]

# ctypes.wintypes does not define ULONG_PTR; alias the pointer-sized WPARAM so
# the INPUT structures below can declare their dwExtraInfo field.
wintypes.ULONG_PTR = wintypes.WPARAM

class PerformanceMonitor:
    """Class for monitoring and visualizing the performance of the Tekken Bot.

    Metrics are accumulated in plain lists:

    * ``update`` adds one entry per training episode.
    * ``update_eval`` adds one entry per evaluation run.

    ``save_data`` writes them as CSV files under ``<save_dir>/data``, and
    ``plot_and_save`` draws them as PNG figures under ``<save_dir>/plots``.
    """

    def __init__(self, save_dir='performance_data'):
        """Create the output folders, empty metric buffers and plot settings.

        Args:
            save_dir: root folder that receives the ``plots/`` and ``data/``
                sub-folders.
        """
        self.save_dir = save_dir
        self._ensure_directories(save_dir)

        # Training metrics (one entry per update() call unless noted)
        self.episode_rewards = []
        self.episode_lengths = []
        self.avg_q_values = []
        self.losses = []              # only appended when update() gets a loss
        self.epsilons = []
        self.max_combos = []
        self.difficulty_levels = []
        self.steps_history = []       # cumulative step count passed to update()
        self.action_distribution = Counter()
        self.hit_rates = []
        self.damage_taken_rates = []
        self.training_times = []      # seconds elapsed since start_time

        # Training progress tracking
        self.episodes_completed = 0
        self.total_steps = 0
        self.start_time = datetime.now()

        # Evaluation metrics (one entry per update_eval() call)
        self.eval_rewards = []
        self.eval_steps = []
        self.eval_q_values = []
        self.eval_max_combos = []
        self.eval_episodes = []

        # Performance metrics
        self.fps_history = []
        self.training_speed_history = []  # cumulative average steps per second
        self.memory_usage = []            # not written by any method of this class

        # Plot configuration
        self._apply_plot_style()
        plt.ion()  # Turn on interactive mode
        self.fig_size = (12, 10)

    @staticmethod
    def _ensure_directories(save_dir):
        """Create ``save_dir`` plus its ``plots/`` and ``data/`` sub-folders if missing."""
        # Check every level: a save_dir that already exists may still lack one
        # of the sub-folders that savefig()/to_csv() write into.
        for subdir in ('plots', 'data'):
            os.makedirs(os.path.join(save_dir, subdir), exist_ok=True)

    @staticmethod
    def _apply_plot_style():
        """Select matplotlib's seaborn dark-grid style under its current or legacy name."""
        # Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*'.
        # Later releases dropped the old names, and plt.style.use() then raises.
        # If neither name is available, keep the default style.
        for style_name in ('seaborn-v0_8-darkgrid', 'seaborn-darkgrid'):
            if style_name in plt.style.available:
                plt.style.use(style_name)
                return

    def update(self, episode, reward, steps, avg_q, epsilon, loss, max_combo, difficulty_level, hits=0, damage_taken=0, actions=None):
        """Record the metrics of one finished training episode.

        Args:
            episode: episode number; data is saved to CSV whenever it is a
                multiple of 5.
            reward: reward reported for the episode.
            steps: cumulative number of steps so far. This is not the episode
                length, which is derived as the difference from the previous call.
            avg_q: average Q-value reported for the episode.
            epsilon: current exploration rate.
            loss: latest training loss, or None to skip recording it.
            max_combo: longest combo reported for the episode.
            difficulty_level: difficulty level reported for the episode.
            hits: hit count for the episode.
            damage_taken: damage taken during the episode.
            actions: optional iterable of actions, counted into
                ``action_distribution``.
        """
        self.episodes_completed = episode
        self.total_steps = steps

        # steps is cumulative, so the episode length is the delta to the
        # previous cumulative value (which equals the sum of earlier lengths).
        previous_steps = self.steps_history[-1] if self.steps_history else 0

        # Store episode metrics
        self.episode_rewards.append(reward)
        self.episode_lengths.append(steps - previous_steps)
        self.avg_q_values.append(avg_q)
        self.epsilons.append(epsilon)
        self.max_combos.append(max_combo)
        self.difficulty_levels.append(difficulty_level)
        self.steps_history.append(steps)

        if loss is not None:
            self.losses.append(loss)

        # Store action distribution
        if actions is not None:
            self.action_distribution.update(actions)

        # Store hit and damage metrics
        self.hit_rates.append(hits)
        self.damage_taken_rates.append(damage_taken)

        # Wall-clock time since the monitor was created
        elapsed_time = (datetime.now() - self.start_time).total_seconds()
        self.training_times.append(elapsed_time)

        # Average training speed since start (cumulative steps per second)
        steps_per_second = steps / elapsed_time if elapsed_time > 0 else 0
        self.training_speed_history.append(steps_per_second)

        # Save data periodically
        if episode % 5 == 0:
            self.save_data()

    def update_eval(self, episode, eval_results):
        """Record one evaluation run and save all data.

        Args:
            episode: training episode at which the evaluation ran.
            eval_results: mapping with keys 'mean_reward', 'mean_length',
                'mean_q' and 'max_combo'.
        """
        self.eval_episodes.append(episode)
        self.eval_rewards.append(eval_results['mean_reward'])
        self.eval_steps.append(eval_results['mean_length'])
        self.eval_q_values.append(eval_results['mean_q'])
        self.eval_max_combos.append(eval_results['max_combo'])

        # Save evaluation data
        self.save_data()

    def update_fps(self, fps):
        """Append one frames-per-second sample."""
        self.fps_history.append(fps)

    def plot_and_save(self):
        """Write the dashboard and every individual plot, sharing one timestamp."""
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")

        # Create main dashboard
        self.create_performance_dashboard(timestamp)

        # Create detailed individual plots
        self.plot_rewards(timestamp)
        self.plot_q_values(timestamp)
        self.plot_loss(timestamp)
        self.plot_exploration(timestamp)
        self.plot_combos(timestamp)
        self.plot_action_distribution(timestamp)
        self.plot_training_speed(timestamp)
        self.plot_evaluation_metrics(timestamp)

        print(f"Performance plots saved to {self.save_dir}/plots/")

    def create_performance_dashboard(self, timestamp):
        """Save a 3x2 grid of the key training metrics as dashboard_<timestamp>.png."""
        plt.figure(figsize=(16, 12))

        # Episode rewards, with evaluation results overlaid when available
        plt.subplot(3, 2, 1)
        plt.plot(self.episode_rewards, label='Episode Reward')
        if self.eval_episodes:
            plt.scatter(self.eval_episodes, self.eval_rewards, color='red', label='Evaluation')
        plt.title('Episode Rewards')
        plt.xlabel('Episode')
        plt.ylabel('Reward')
        plt.legend()

        # Q-values
        plt.subplot(3, 2, 2)
        plt.plot(self.avg_q_values, label='Avg Q-Value')
        plt.title('Average Q-Values')
        plt.xlabel('Episode')
        plt.ylabel('Q-Value')

        # Loss (panel left empty when no loss has been recorded)
        if self.losses:
            plt.subplot(3, 2, 3)
            plt.plot(self.losses)
            plt.title('Loss')
            plt.xlabel('Training Step (sampled)')
            plt.ylabel('Loss')

        # Epsilon
        plt.subplot(3, 2, 4)
        plt.plot(self.epsilons)
        plt.title('Exploration Rate (Epsilon)')
        plt.xlabel('Episode')
        plt.ylabel('Epsilon')

        # Max combos
        plt.subplot(3, 2, 5)
        plt.plot(self.max_combos, label='Max Combo')
        plt.title('Max Combo per Episode')
        plt.xlabel('Episode')
        plt.ylabel('Max Combo')

        # Training speed
        plt.subplot(3, 2, 6)
        plt.plot(self.training_speed_history)
        plt.title('Training Speed')
        plt.xlabel('Episode')
        plt.ylabel('Steps per Second')

        plt.tight_layout()
        plt.savefig(f"{self.save_dir}/plots/dashboard_{timestamp}.png", dpi=150)
        plt.close()

    def _plot_single_series(self, values, title, xlabel, ylabel, file_prefix, timestamp):
        """Plot one series in its own figure and save it as <file_prefix>_<timestamp>.png."""
        plt.figure(figsize=self.fig_size)
        plt.plot(values)
        plt.title(title)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.savefig(f"{self.save_dir}/plots/{file_prefix}_{timestamp}.png")
        plt.close()

    def plot_rewards(self, timestamp):
        """Save raw episode rewards and a moving average (window <= 10)."""
        plt.figure(figsize=self.fig_size)

        # Raw rewards
        plt.subplot(2, 1, 1)
        plt.plot(self.episode_rewards)
        plt.title('Episode Rewards')
        plt.xlabel('Episode')
        plt.ylabel('Reward')

        # Moving average; the window shrinks while fewer than 10 episodes exist
        plt.subplot(2, 1, 2)
        window_size = min(10, len(self.episode_rewards))
        if window_size > 0:
            moving_avg = pd.Series(self.episode_rewards).rolling(window=window_size).mean()
            plt.plot(moving_avg)
            plt.title(f'Moving Average Reward (Window={window_size})')
            plt.xlabel('Episode')
            plt.ylabel('Average Reward')

        plt.tight_layout()
        plt.savefig(f"{self.save_dir}/plots/rewards_{timestamp}.png")
        plt.close()

    def plot_q_values(self, timestamp):
        """Save the per-episode average Q-value plot."""
        self._plot_single_series(self.avg_q_values, 'Average Q-Values', 'Episode',
                                 'Q-Value', 'q_values', timestamp)

    def plot_loss(self, timestamp):
        """Save the training-loss plot; does nothing if no loss was recorded."""
        if not self.losses:
            return
        self._plot_single_series(self.losses, 'Training Loss', 'Training Step (sampled)',
                                 'Loss', 'loss', timestamp)

    def plot_exploration(self, timestamp):
        """Save the exploration-rate (epsilon) plot."""
        self._plot_single_series(self.epsilons, 'Exploration Rate (Epsilon)', 'Episode',
                                 'Epsilon', 'exploration', timestamp)

    def plot_combos(self, timestamp):
        """Save the max-combo-per-episode plot."""
        self._plot_single_series(self.max_combos, 'Max Combo per Episode', 'Episode',
                                 'Max Combo', 'combos', timestamp)

    def plot_action_distribution(self, timestamp):
        """Save a bar chart of action counts, ordered by action key."""
        if not self.action_distribution:
            return

        plt.figure(figsize=self.fig_size)

        actions = list(self.action_distribution.keys())
        counts = list(self.action_distribution.values())

        # Sort by action index
        sorted_indices = np.argsort(actions)
        actions = [actions[i] for i in sorted_indices]
        counts = [counts[i] for i in sorted_indices]

        plt.bar(actions, counts)
        plt.title('Action Distribution')
        plt.xlabel('Action Index')
        plt.ylabel('Count')

        plt.savefig(f"{self.save_dir}/plots/action_distribution_{timestamp}.png")
        plt.close()

    def plot_training_speed(self, timestamp):
        """Save training speed and, when samples exist, the FPS history."""
        plt.figure(figsize=self.fig_size)

        plt.subplot(2, 1, 1)
        plt.plot(self.training_speed_history)
        plt.title('Training Speed')
        plt.xlabel('Episode')
        plt.ylabel('Steps per Second')

        if self.fps_history:
            plt.subplot(2, 1, 2)
            plt.plot(self.fps_history)
            plt.title('Frames Per Second')
            plt.xlabel('Sample')
            plt.ylabel('FPS')

        plt.tight_layout()
        plt.savefig(f"{self.save_dir}/plots/training_speed_{timestamp}.png")
        plt.close()

    def plot_evaluation_metrics(self, timestamp):
        """Save a 2x2 grid of evaluation metrics; does nothing before any evaluation."""
        if not self.eval_episodes:
            return

        # (values, panel title, y label) for each 2x2 panel, in grid order
        panels = [
            (self.eval_rewards, 'Evaluation Rewards', 'Mean Reward'),
            (self.eval_q_values, 'Evaluation Q-Values', 'Mean Q-Value'),
            (self.eval_steps, 'Evaluation Episode Lengths', 'Mean Steps'),
            (self.eval_max_combos, 'Evaluation Max Combos', 'Max Combo'),
        ]

        plt.figure(figsize=(16, 12))
        for position, (values, title, ylabel) in enumerate(panels, start=1):
            plt.subplot(2, 2, position)
            plt.plot(self.eval_episodes, values, marker='o')
            plt.title(title)
            plt.xlabel('Training Episode')
            plt.ylabel(ylabel)

        plt.tight_layout()
        plt.savefig(f"{self.save_dir}/plots/evaluation_{timestamp}.png")
        plt.close()

    def save_data(self):
        """Write training, evaluation and action-count metrics to CSV files."""
        # Save training metrics (all lists below grow in lockstep in update())
        training_data = {
            'episode': list(range(len(self.episode_rewards))),
            'reward': self.episode_rewards,
            'episode_length': self.episode_lengths,
            'q_value': self.avg_q_values,
            'epsilon': self.epsilons,
            'max_combo': self.max_combos,
            'difficulty': self.difficulty_levels,
            'total_steps': self.steps_history,
            'training_speed': self.training_speed_history,
            'training_time': self.training_times
        }

        df_training = pd.DataFrame(training_data)
        df_training.to_csv(f"{self.save_dir}/data/training_metrics.csv", index=False)

        # Save evaluation metrics if available
        if self.eval_episodes:
            eval_data = {
                'training_episode': self.eval_episodes,
                'eval_reward': self.eval_rewards,
                'eval_steps': self.eval_steps,
                'eval_q_value': self.eval_q_values,
                'eval_max_combo': self.eval_max_combos
            }

            df_eval = pd.DataFrame(eval_data)
            df_eval.to_csv(f"{self.save_dir}/data/evaluation_metrics.csv", index=False)

        # Save action distribution
        if self.action_distribution:
            action_data = {'action': list(self.action_distribution.keys()),
                          'count': list(self.action_distribution.values())}
            df_actions = pd.DataFrame(action_data)
            df_actions.to_csv(f"{self.save_dir}/data/action_distribution.csv", index=False)

        print(f"Metrics data saved to {self.save_dir}/data/")

    def get_summary_stats(self):
        """Return a dict of headline statistics, or a message string if no data exists.

        Averages cover the last 10 episodes (rewards, Q-values) and the last
        100 FPS samples, or everything recorded when there are fewer.
        """
        if not self.episode_rewards:
            return "No training data available."

        # A [-N:] slice already returns the whole list when it is shorter
        # than N, so no separate length check is needed.
        summary = {
            'episodes_completed': self.episodes_completed,
            'total_steps': self.total_steps,
            'avg_reward': np.mean(self.episode_rewards[-10:]),
            'max_reward': max(self.episode_rewards),
            'avg_q_value': np.mean(self.avg_q_values[-10:]),
            'current_epsilon': self.epsilons[-1] if self.epsilons else 0,
            'max_combo_achieved': max(self.max_combos) if self.max_combos else 0,
            'training_time': self.training_times[-1] if self.training_times else 0,
            'fps': np.mean(self.fps_history[-100:]) if self.fps_history else 0,
            'steps_per_second': self.training_speed_history[-1] if self.training_speed_history else 0
        }

        # Format training time
        hours, remainder = divmod(summary['training_time'], 3600)
        minutes, seconds = divmod(remainder, 60)
        summary['training_time_formatted'] = f"{int(hours)}h {int(minutes)}m {int(seconds)}s"

        return summary

    def print_summary(self):
        """Print a formatted summary of the training progress."""
        stats = self.get_summary_stats()

        if isinstance(stats, str):
            print(stats)
            return

        print("\n" + "="*50)
        print("TRAINING SUMMARY")
        print("="*50)
        print(f"Episodes Completed: {stats['episodes_completed']}")
        print(f"Total Steps: {stats['total_steps']}")
        print(f"Recent Average Reward: {stats['avg_reward']:.2f}")
        print(f"Max Reward Achieved: {stats['max_reward']:.2f}")
        print(f"Current Average Q-Value: {stats['avg_q_value']:.4f}")
        print(f"Current Exploration Rate: {stats['current_epsilon']:.4f}")
        print(f"Max Combo Achieved: {stats['max_combo_achieved']}")
        print(f"Training Speed: {stats['steps_per_second']:.2f} steps/sec")
        print(f"Average FPS: {stats['fps']:.1f}")
        print(f"Total Training Time: {stats['training_time_formatted']}")
        print("="*50)

class MOUSEINPUT(ctypes.Structure):
    """ctypes mirror of the Win32 MOUSEINPUT structure (one member of INPUT's union).

    Field order and types define the binary layout handed to SendInput, so
    they must stay in sync with the Win32 declaration. dwExtraInfo relies on
    the module-level wintypes.ULONG_PTR alias.
    """
    _fields_ = (("dx",          wintypes.LONG),
                ("dy",          wintypes.LONG),
                ("mouseData",   wintypes.DWORD),
                ("dwFlags",     wintypes.DWORD),
                ("time",        wintypes.DWORD),
                ("dwExtraInfo", wintypes.ULONG_PTR))

class KEYBDINPUT(ctypes.Structure):
    """ctypes mirror of the Win32 KEYBDINPUT structure (a keyboard event for SendInput).

    Field order and types define the binary layout SendInput reads, so keep
    them in sync with the Win32 declaration. dwExtraInfo relies on the
    module-level wintypes.ULONG_PTR alias.
    """
    _fields_ = (("wVk",         wintypes.WORD),
                ("wScan",       wintypes.WORD),
                ("dwFlags",     wintypes.DWORD),
                ("time",        wintypes.DWORD),
                ("dwExtraInfo", wintypes.ULONG_PTR))

    def __init__(self, *args, **kwds):
        """Initialise fields as ctypes.Structure does, then derive wScan from wVk.

        Unless KEYEVENTF_UNICODE is set in dwFlags, wScan is overwritten with
        the value MapVirtualKeyExW(wVk, MAPVK_VK_TO_VSC, 0) returns, so any
        wScan passed by the caller is replaced.
        """
        super(KEYBDINPUT, self).__init__(*args, **kwds)
        # some programs use the scan code even if KEYEVENTF_SCANCODE
        # isn't set in dwFlags, so attempt to map the correct code.
        if not self.dwFlags & KEYEVENTF_UNICODE:
            self.wScan = user32.MapVirtualKeyExW(self.wVk,
                                                 MAPVK_VK_TO_VSC, 0)
class HARDWAREINPUT(ctypes.Structure):
    """ctypes mirror of the Win32 HARDWAREINPUT structure (one member of INPUT's union).

    Only declared here so the INPUT union has its full size and layout; field
    order and types must match the Win32 declaration.
    """
    _fields_ = (("uMsg",    wintypes.DWORD),
                ("wParamL", wintypes.WORD),
                ("wParamH", wintypes.WORD))

class INPUT(ctypes.Structure):
    """ctypes mirror of the Win32 INPUT structure consumed by SendInput.

    ``type`` (INPUT_MOUSE / INPUT_KEYBOARD / INPUT_HARDWARE) says which union
    member is meaningful. Because ``_input`` is listed in ``_anonymous_``, the
    union members are addressed directly, e.g.
    ``INPUT(type=INPUT_KEYBOARD, ki=KEYBDINPUT(...))``.
    """
    class _INPUT(ctypes.Union):
        # The three event kinds share the same storage.
        _fields_ = (("ki", KEYBDINPUT),
                    ("mi", MOUSEINPUT),
                    ("hi", HARDWAREINPUT))
    _anonymous_ = ("_input",)
    _fields_ = (("type",   wintypes.DWORD),
                ("_input", _INPUT))

def _check_count(result, func, args):
    if result == 0:
        raise ctypes.WinError(ctypes.get_last_error())
    return args

class InputHandler:
    """Inject keyboard input into the PS4 Remote Play window via user32.SendInput."""

    LPINPUT = ctypes.POINTER(INPUT)
    # Configure the SendInput prototype once, when the class is created: typed
    # arguments, plus an errcheck hook that raises when 0 events were sent.
    user32.SendInput.errcheck = _check_count
    user32.SendInput.argtypes = (wintypes.UINT, # nInputs
                                 LPINPUT,       # pInputs
                                 ctypes.c_int)  # cbSize

    def __init__(self):
        # Filled in by get_remote_play_pid() once the window has been found.
        self.PS4RemotePlayHWND = 0
        self.PS4RemotePlayPID = 0
        # Pending (key, action_type, delay) commands for execute_command_buffer().
        self.command_buffer = []

    def queue_command(self, key, action_type, delay=0):
        """Queue a 'press' or 'release' of ``key``, followed by ``delay`` seconds of sleep."""
        self.command_buffer.append((key, action_type, delay))

    def execute_command_buffer(self):
        """Run every queued command in order, then empty the queue.

        Commands whose action_type is neither 'press' nor 'release' send no
        input, but their delay is still honoured.
        """
        for key, action_type, delay in self.command_buffer:
            if action_type == 'press':
                self.press_key(key)
            elif action_type == 'release':
                self.release_key(key)

            if delay > 0:
                time.sleep(delay)

        # Clear buffer after execution
        self.command_buffer = []

    def get_actions(self, amount):
        """Draw ``amount`` random inputs plus a random press duration.

        Returns:
            A list whose first element is a one-item list holding the press
            duration in seconds (0.04-0.06 for 'hold', 0.02-0.04 for 'tap'),
            followed by the drawn inputs. Draws equal to 0 ("no input") are
            dropped, so fewer than ``amount`` inputs may be returned.
        """
        actions = [[]]
        for _ in range(amount):
            # Fair coin between the direction pool and the attack pool, then a
            # uniform pick over the whole chosen pool. Bounding the attack pick
            # by len(direction) would never reach the later attack entries.
            pool = direction if random.randint(0, 1) == 0 else attack
            action = random.choice(pool)
            # Only add the input if it is not 0. 0 Is the same as nothing.
            if action != 0:
                actions.append(action)
        # Get the delay time for pressing these keys
        if random.choice(delay) == 'hold':
            actions[0].append(random.uniform(0.04, 0.06))
        else:
            actions[0].append(random.uniform(0.02, 0.04))
        return actions

    def get_action(self, index):
        """Print and return ``valid_actions[index]``."""
        print('Valid Actions:', valid_actions[index])
        return valid_actions[index]

    def get_remote_play_pid(self):
        """Find the visible top-level window whose title contains "REPL4Y".

        For each match, store the owning process id (as a ``c_int``) in
        ``self.PS4RemotePlayPID`` and the window handle in
        ``self.PS4RemotePlayHWND``. Enumeration is never stopped early, so the
        last match wins. Both attributes are left untouched when nothing matches.
        """
        # register winapi functions
        EnumWindows = ctypes.windll.user32.EnumWindows
        EnumWindowsProc = ctypes.WINFUNCTYPE(ctypes.c_bool, ctypes.POINTER(ctypes.c_int), ctypes.POINTER(ctypes.c_int))
        GetWindowText = ctypes.windll.user32.GetWindowTextW
        GetWindowTextLength = ctypes.windll.user32.GetWindowTextLengthW
        IsWindowVisible = ctypes.windll.user32.IsWindowVisible
        GetWindowThreadProcessId = ctypes.windll.user32.GetWindowThreadProcessId

        def foreach_window(hwnd, lParam):
            # EnumWindows calls back with exactly (hwnd, lParam); ``self`` is
            # captured from the enclosing method rather than passed in.
            if IsWindowVisible(hwnd):
                length = GetWindowTextLength(hwnd)
                buff = ctypes.create_unicode_buffer(length + 1)
                GetWindowText(hwnd, buff, length + 1)
                try:
                    if "REPL4Y" in buff.value:
                        # get the processid from the hwnd
                        processID = c_int()
                        GetWindowThreadProcessId(hwnd, byref(processID))
                        self.PS4RemotePlayPID = processID
                        self.PS4RemotePlayHWND = hwnd
                except Exception:
                    # str() is required: adding the exception class itself to
                    # a str would raise TypeError inside this handler.
                    print("Unexpected error:" + str(sys.exc_info()[0]))
            # Returning True keeps EnumWindows going.
            return True

        # Hold the ctypes callback in a local so it stays alive for the whole
        # (synchronous) EnumWindows call.
        callback = EnumWindowsProc(foreach_window)
        EnumWindows(callback, 0)

    def press_key(self, hexKeyCode):
        """Send a key-down event for virtual-key code ``hexKeyCode``."""
        print(f"Pressing key: {hexKeyCode}")
        x = INPUT(type=INPUT_KEYBOARD,
                  ki=KEYBDINPUT(wVk=hexKeyCode))
        user32.SendInput(1, ctypes.byref(x), ctypes.sizeof(x))

    def release_key(self, hexKeyCode):
        """Send a key-up event for virtual-key code ``hexKeyCode``."""
        print(f"Releasing key: {hexKeyCode}")
        x = INPUT(type=INPUT_KEYBOARD,
                  ki=KEYBDINPUT(wVk=hexKeyCode,
                                dwFlags=KEYEVENTF_KEYUP))
        user32.SendInput(1, ctypes.byref(x), ctypes.sizeof(x))

    def focus_window(self, hwnd):
        """Bring ``hwnd`` to the foreground, then send the Ctrl+Alt remap toggle."""
        ctypes.windll.user32.SetForegroundWindow(hwnd)
        self.activate_remap()

    def activate_remap(self):
        """Tap Ctrl+Alt after a short settle delay."""
        time.sleep(0.5)
        self.press_key(VK_LCONTROL)
        self.press_key(VK_LMENU)
        time.sleep(0.01)
        self.release_key(VK_LCONTROL)
        self.release_key(VK_LMENU)
        time.sleep(0.01)

    def hold_delay(self):
        """Sleep for a random 0.30-0.35 s button hold."""
        time.sleep(random.uniform(0.3, 0.35))

    def process_keys(self, keys, action_type):
        """Recursively press or release every non-zero key in ``keys``.

        Tuples are flattened recursively; 0 entries send nothing.
        """
        if isinstance(keys, tuple):
            for key in keys:
                self.process_keys(key, action_type)
        else:
            if keys != 0:
                if action_type == 'press':
                    self.press_key(keys)
                elif action_type == 'release':
                    self.release_key(keys)

    def _perform_attack01(self):
        """Hand-timed key timeline for ATTACK01 (82, 89, 68, 68, 89, 68, 84)."""
        self.press_key(TOUCHPAD)
        self.press_key(SQUARE)
        self.hold_delay()
        self.release_key(SQUARE)
        self.press_key(TRIANGLE)
        self.hold_delay()
        self.release_key(TRIANGLE)
        self.press_key(DPAD_RIGHT)
        self.hold_delay()
        self.release_key(DPAD_RIGHT)
        # Second RIGHT stays held through the next TRIANGLE press.
        self.press_key(DPAD_RIGHT)
        self.hold_delay()
        self.press_key(TRIANGLE)
        self.hold_delay()
        self.release_key(TRIANGLE)
        self.press_key(DPAD_RIGHT)
        time.sleep(random.uniform(0.03, 0.05))
        self.release_key(DPAD_RIGHT)
        # Final RIGHT + CIRCLE together.
        self.press_key(DPAD_RIGHT)
        self.press_key(CIRCLE)
        self.hold_delay()
        self.release_key(CIRCLE)
        # Release RIGHT too, otherwise it stays held after the action ends.
        self.release_key(DPAD_RIGHT)

    def execute_action(self, actionIndex):
        """Perform ``valid_actions[actionIndex]`` with X (TOUCHPAD, key 88) held.

        - Index 0: the hand-timed ATTACK01 sequence.
        - Tuple/list: each element is pressed, held for hold_delay(), then
          released before the next; a nested tuple element presses its keys
          together.
        - Single non-zero key: tapped with hold_delay().
        - 0: reported as an invalid action structure; nothing is pressed.
        """
        action = valid_actions[actionIndex]
        print(f"Executing action: {action}")  # Debug print
        self.press_key(TOUCHPAD)
        if actionIndex == 0:
            self._perform_attack01()
        # Check if action is a tuple or list
        elif isinstance(action, (tuple, list)):
            self.press_key(TOUCHPAD)
            for key_sequence in action:
                # Press keys in the current sequence
                self.process_keys(key_sequence, 'press')
                # Hold delay after pressing keys
                self.hold_delay()
                # Release keys in the current sequence
                self.process_keys(key_sequence, 'release')
        else:
            # Handle single key press action
            if action != 0:
                self.process_keys(action, 'press')
                self.hold_delay()
                self.process_keys(action, 'release')
            else:
                print(f"Invalid action structure: {action}")
        self.release_key(TOUCHPAD)

    def execute_actions(self, actions):
        """Execute each action index in ``actions`` in order."""
        print("in execute_Actions", actions)
        for action in actions:
            self.execute_action(action)

# Sum tree backing prioritized experience replay.
class SumTree:
    """Binary sum tree over ``capacity`` priorities with an aligned data ring buffer.

    ``tree`` stores ``2 * capacity - 1`` nodes. The leaves (indices
    ``capacity - 1`` .. ``2 * capacity - 2``) hold item priorities, and each
    internal node holds the sum of its two children, so ``tree[0]`` is the total
    priority. ``data[i]`` is the item for leaf ``i + capacity - 1``. Once full,
    ``add`` overwrites the oldest slot.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        self.data = np.zeros(capacity, dtype=object)
        self.write = 0       # next data slot to (over)write
        self.n_entries = 0   # slots filled so far, capped at capacity

    def _propagate(self, idx, change):
        """Add ``change`` to every ancestor of node ``idx``, up to and including the root."""
        # Iterative walk that stops at the root. With capacity == 1 the single
        # leaf IS the root, so no ancestor exists to update.
        while idx != 0:
            idx = (idx - 1) // 2
            self.tree[idx] += change

    def _retrieve(self, idx, s):
        """Descend from node ``idx`` to the leaf whose cumulative-priority range holds ``s``."""
        tree = self.tree
        size = len(tree)
        while True:
            left = 2 * idx + 1
            if left >= size:
                return idx
            right = left + 1
            # Never enter a zero-mass subtree while its sibling has mass.
            # Float round-off can leave ``s`` just above the left sum, which
            # would otherwise route the search to an empty (never written) leaf.
            if tree[right] <= 0 or (s <= tree[left] and tree[left] > 0):
                idx = left
            else:
                s -= tree[left]
                idx = right

    def total(self):
        """Return the sum of all stored priorities."""
        return self.tree[0]

    def add(self, p, data):
        """Store ``data`` with priority ``p`` in the next ring-buffer slot."""
        idx = self.write + self.capacity - 1

        self.data[self.write] = data
        self.update(idx, p)

        self.write += 1
        if self.write >= self.capacity:
            self.write = 0

        if self.n_entries < self.capacity:
            self.n_entries += 1

    def update(self, idx, p):
        """Set the priority of tree node ``idx`` (a leaf index) to ``p`` and fix the sums."""
        change = p - self.tree[idx]

        self.tree[idx] = p
        self._propagate(idx, change)

    def get(self, s):
        """Return ``(tree_index, priority, data)`` for the leaf selected by ``s``.

        ``s`` is expected to lie in ``[0, total()]``.
        """
        idx = self._retrieve(0, s)
        dataIdx = idx - self.capacity + 1

        return (idx, self.tree[idx], self.data[dataIdx])

    def get_batch(self, batch_size):
        """Stratified sample of ``batch_size`` items.

        ``[0, total()]`` is split into ``batch_size`` equal segments and one
        point is drawn uniformly in each. ``batch_size`` must be positive.

        Returns:
            Three parallel lists: tree indices, priorities and data items.
        """
        indices = []
        priorities = []
        data = []
        segment = self.total() / batch_size

        for i in range(batch_size):
            s = random.uniform(segment * i, segment * (i + 1))
            idx, p, item = self.get(s)

            indices.append(idx)
            priorities.append(p)
            data.append(item)

        return indices, priorities, data

# Hyperparameters
# --- Observation: IMAGE_STACK grayscale frames of IMAGE_HEIGHT x IMAGE_WIDTH
IMAGE_STACK = 4
IMAGE_WIDTH = IMAGE_HEIGHT = 84

# --- Optimisation
HUBER_LOSS = 1.0          # delta of huber_loss: quadratic below it, linear above
LEARNING_RATE = 1e-4
MEMORY_CAPACITY = 500_000
BATCH_SIZE = 64
GAMMA = 0.99              # presumably the reward discount factor -- confirm in the training loop

# --- Exploration schedule
MAX_EPSILON, MIN_EPSILON = 1.0, 0.05
EXPLORATION_STOP = 500_000
# Chosen so that exp(-LAMBDA * EXPLORATION_STOP) == 0.01, i.e. an exponential
# decay that shrinks by 100x over EXPLORATION_STOP steps.
LAMBDA = -math.log(0.01) / EXPLORATION_STOP

# --- Scheduling intervals
UPDATE_TARGET_FREQUENCY = 2_000
FRAME_SKIP = 2            # process every FRAME_SKIP-th frame
REPLAY_PERIOD = 4         # run a replay step every REPLAY_PERIOD steps

# --- Tracking variables
best_eval_reward = float('-inf')
running_reward = 0

def huber_loss(y_true, y_pred):
    """Mean Huber loss between targets and predictions, using HUBER_LOSS as delta.

    Per element, with e = y_true - y_pred:
    0.5 * e**2 where |e| < HUBER_LOSS, else HUBER_LOSS * (|e| - 0.5 * HUBER_LOSS).
    The element losses are averaged into a single scalar tensor.
    """
    residual = y_true - y_pred
    magnitude = tf.abs(residual)
    quadratic = 0.5 * tf.square(residual)
    linear = HUBER_LOSS * (magnitude - 0.5 * HUBER_LOSS)
    return tf.reduce_mean(tf.where(magnitude < HUBER_LOSS, quadratic, linear))

def preprocess_frame(frame):
    """Convert a raw frame into a normalized grayscale IMAGE_HEIGHT x IMAGE_WIDTH image.

    A frame with more than one channel on axis 2 is reduced to luminance from
    its first three channels (weights 0.299, 0.587, 0.114). Any other frame is
    used as-is. The image is resized with area interpolation, divided by 255,
    and returned as float32.
    """
    has_channels = frame.ndim > 2 and frame.shape[2] > 1
    # NOTE(review): the weights assume channels 0..2 are R, G, B. Screenshots
    # taken with mss are typically BGRA, which would swap the R and B weights.
    # Confirm the caller's channel order.
    luminance = np.dot(frame[..., :3], [0.299, 0.587, 0.114]) if has_channels else frame

    shrunk = cv2.resize(luminance, (IMAGE_WIDTH, IMAGE_HEIGHT), interpolation=cv2.INTER_AREA)
    return (shrunk / 255.0).astype(np.float32)

def create_state(frame):
    """Build an initial state by repeating ``frame`` IMAGE_STACK times.

    For a 2-D (H, W) frame the result has shape (H, W, IMAGE_STACK), i.e. the
    copies are stacked on the last (channel) axis. The result is a transposed
    view of the stacked array.
    """
    copies = [frame for _ in range(IMAGE_STACK)]
    cube = np.asarray(copies)               # (IMAGE_STACK, H, W)
    return np.transpose(cube, axes=(1, 2, 0))

def update_state(state, new_frame):
    """Return a new state with the oldest frame dropped and ``new_frame`` appended.

    Frames live on axis 2. Every frame moves one slot toward index 0, and
    ``new_frame`` is written into the last slot, cast to ``state``'s dtype.
    A new array is allocated on every call; ``state`` itself is not modified,
    so earlier states stay valid wherever they are stored.
    """
    shifted = np.empty_like(state)
    shifted[:, :, :-1] = state[:, :, 1:]
    shifted[:, :, -1] = new_frame
    return shifted

def augment_state(state):
    """Return a randomly perturbed copy of ``state`` for data augmentation.

    Three independent perturbations, each applied with a fixed probability and
    each followed by clipping to [0, 1] (so ``state`` is assumed normalized to
    that range):

    * brightness: scale by a factor in [0.8, 1.2] (p = 0.3);
    * contrast: stretch about the global mean by a factor in [0.8, 1.2] (p = 0.3);
    * noise: add Gaussian noise with std 0.01 (p = 0.2).

    ``state`` is never modified. A floating-point input keeps its dtype in the
    result, whichever perturbations fired.
    """
    # Make a copy to avoid modifying the original
    augmented = np.copy(state)
    source_dtype = augmented.dtype

    # Random brightness adjustment
    if random.random() < 0.3:
        brightness_factor = random.uniform(0.8, 1.2)
        augmented = np.clip(augmented * brightness_factor, 0, 1.0)

    # Random contrast adjustment
    if random.random() < 0.3:
        contrast_factor = random.uniform(0.8, 1.2)
        mean = np.mean(augmented)
        augmented = np.clip((augmented - mean) * contrast_factor + mean, 0, 1.0)

    # Adding random noise - improves robustness
    if random.random() < 0.2:
        noise = np.random.normal(0, 0.01, augmented.shape)
        augmented = np.clip(augmented + noise, 0, 1.0)

    # np.random.normal yields float64, which would silently upcast e.g. a
    # float32 state on only some calls. Cast back for a consistent dtype.
    if np.issubdtype(source_dtype, np.floating):
        augmented = augmented.astype(source_dtype, copy=False)
    return augmented

class ResidualBlock(tf.keras.layers.Layer):
    """Two-conv residual block: relu(BN(conv2(relu(BN(conv1(x))))) + skip(x)).

    Args:
        filters: output channels of both convolutions (and of the projection).
        kernel_size: kernel size used for both 3x3-style convolutions.
        strides: stride of the first convolution. When ``strides > 1`` a 1x1
            strided Conv2D projects the input on the skip path. Otherwise the
            input is added unchanged.

    NOTE(review): with ``strides == 1`` no projection is created, so the input
    must already have ``filters`` channels for the final addition to be
    shape-compatible.
    """
    def __init__(self, filters, kernel_size=3, strides=1, **kwargs):
        super(ResidualBlock, self).__init__(**kwargs)
        # Kept as attributes so get_config can round-trip the constructor args.
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        
        # Sublayers are assigned in a fixed order (conv1, bn1, conv2, bn2, skip);
        # keep this order stable so previously saved weights still line up.
        self.conv1 = Conv2D(filters, kernel_size, strides=strides, padding='same')
        self.bn1 = BatchNormalization()
        self.conv2 = Conv2D(filters, kernel_size, padding='same')
        self.bn2 = BatchNormalization()
        
        # Skip connection: 1x1 projection only when spatial size is reduced.
        self.skip = Conv2D(filters, 1, strides=strides, padding='same') if strides > 1 else None
        
    def call(self, inputs):
        """Apply the residual branch and add the (optionally projected) input."""
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = tf.nn.relu(x)
        
        x = self.conv2(x)
        x = self.bn2(x)
        
        # Apply skip connection
        if self.skip is not None:
            skip = self.skip(inputs)
        else:
            skip = inputs
            
        # Activation is applied after the addition.
        return tf.nn.relu(x + skip)
    
    def get_config(self):
        """Return the base layer config extended with this block's constructor args."""
        config = super(ResidualBlock, self).get_config()
        config.update({
            'filters': self.filters,
            'kernel_size': self.kernel_size,
            'strides': self.strides
        })
        return config

class Model:
    """Holds an online Q-network and a target Q-network plus training helpers.

    NOTE: this class name shadows ``keras.models.Model`` imported at the top of
    the file; Keras functional models are therefore built via ``tf.keras.Model``.

    Args:
        input_shape: stored for reference only; the networks are built from
            IMAGE_HEIGHT / IMAGE_WIDTH / IMAGE_STACK.
        actionCnt: number of Q-value outputs.
        model, target_model: pre-built networks (e.g. loaded from disk). When
            ``model`` is None both networks are created fresh.
        use_dueling: build the dueling architecture instead of the plain one.
    """
    def __init__(self, input_shape, actionCnt, model=None, target_model=None, use_dueling=True):
        self.input_shape = input_shape
        self.actionCnt = actionCnt
        self.steps = 0  # number of train() calls; drives the LR / batch-size schedule
        self.use_dueling = use_dueling
        
        # Setup TensorBoard with better logging
        current_time = datetime.now().strftime("%Y%m%d-%H%M%S")
        self.log_dir = f"logs/fit/{current_time}"
        self.tensorboard_callback = TensorBoard(
            log_dir=self.log_dir,
            histogram_freq=1,
            update_freq='epoch',
            profile_batch=0
        )
        
        # Base learning rate that train() scales down over time
        self.initial_learning_rate = LEARNING_RATE
        
        if model is not None:
            self.model = model
            self.target_model = target_model
        else:
            if use_dueling:
                self.model = self._createDuelingModel()
                self.target_model = self._createDuelingModel()
                # Ensure target model has same weights
                self.update_target_model()
            else:
                self.model = self._createModel()
                self.target_model = self._createModel()
                self.update_target_model()
                
    def _createModel(self):
        """Build and compile the plain (non-dueling) Q-network."""
        model = Sequential()
        
        # Input shape in HWC format - height, width, channels(frames)
        input_shape = (IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_STACK)
        
        # First convolutional layer
        model.add(Conv2D(32, (8, 8), strides=(4, 4), activation='relu', 
                         input_shape=input_shape, padding='same'))
        
        # Add residual blocks for better gradient flow
        model.add(ResidualBlock(64, 4, strides=2))
        model.add(ResidualBlock(64, 3))
        
        # Flatten and fully connected layers
        model.add(Flatten())
        model.add(Dense(512, activation='relu'))
        model.add(Dropout(0.2))
        model.add(Dense(units=self.actionCnt, activation='linear'))
        
        # Use Adam optimizer with gradient clipping
        opt = Adam(learning_rate=LEARNING_RATE, clipnorm=1.0)
        model.compile(loss=huber_loss, optimizer=opt)
        model.summary()
        return model
        
    def _createDuelingModel(self):
        """Build and compile the dueling Q-network (separate value and advantage heads)."""
        # Input in HWC format
        input_shape = (IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_STACK)
        input_layer = Input(shape=input_shape)
        
        # Convolutional layers with residual connections
        conv1 = Conv2D(32, (8, 8), strides=(4, 4), activation='relu', padding='same')(input_layer)
        res1 = ResidualBlock(64, 4, strides=2)(conv1)
        res2 = ResidualBlock(64, 3)(res1)
        
        flat = Flatten()(res2)
        
        # Value stream (estimates state value)
        value_fc = Dense(512, activation='relu')(flat)
        value_dropout = Dropout(0.2)(value_fc)
        value = Dense(1)(value_dropout)
        
        # Advantage stream (estimates action advantages)
        adv_fc = Dense(512, activation='relu')(flat)
        adv_dropout = Dropout(0.2)(adv_fc)
        advantage = Dense(self.actionCnt)(adv_dropout)
        
        # Q = V + (A - mean(A)); subtracting the mean makes V and A identifiable.
        # NOTE(review): raw tf ops on symbolic Keras tensors work in tf.keras 2.x;
        # Keras 3 requires wrapping them in a layer. Confirm the target Keras version.
        outputs = value + (advantage - tf.reduce_mean(advantage, axis=1, keepdims=True))
        
        model = tf.keras.Model(inputs=input_layer, outputs=outputs)
        opt = Adam(learning_rate=LEARNING_RATE, clipnorm=1.0)
        model.compile(loss=huber_loss, optimizer=opt)
        model.summary()
        
        return model

    def _set_learning_rate(self, lr):
        """Write ``lr`` into the online model's optimizer.

        Prefers the ``learning_rate`` variable (present on current tf.keras and
        Keras optimizers). Falls back to ``K.set_value(optimizer.lr, ...)`` when
        that attribute is not an assignable variable.
        """
        optimizer = self.model.optimizer
        lr_var = getattr(optimizer, 'learning_rate', None)
        if hasattr(lr_var, 'assign'):
            lr_var.assign(lr)
        else:
            K.set_value(optimizer.lr, lr)
        
    def train(self, x, y, sample_weight=None, epochs=1, verbose=0):
        """Fit the online model on (x, y) with a step-decayed learning rate.

        Schedule (by number of calls so far, after incrementing):
            <= 50k: 1.0x, > 50k: 0.5x, > 100k: 0.25x, > 200k: 0.1x
        of ``initial_learning_rate``. The batch size grows by 1 per 50k calls,
        capped at 128.

        Returns:
            Whatever ``self.model.fit`` returns (a Keras History object).
        """
        self.steps += 1
        
        # Thresholds are tested from largest to smallest so each tier is
        # reachable; testing ">50000" first would shadow every later branch.
        if self.steps > 200000:
            lr = self.initial_learning_rate * 0.1
        elif self.steps > 100000:
            lr = self.initial_learning_rate * 0.25
        elif self.steps > 50000:
            lr = self.initial_learning_rate * 0.5
        else:
            lr = self.initial_learning_rate
            
        self._set_learning_rate(lr)
        
        # Use larger batch size as training progresses for efficiency
        batch_size = min(128, 32 + self.steps // 50000)
        
        fit_kwargs = {
            'batch_size': batch_size,
            'epochs': epochs,
            'verbose': verbose,
            'callbacks': [self.tensorboard_callback],
        }
        if sample_weight is not None:
            fit_kwargs['sample_weight'] = sample_weight
        return self.model.fit(x, y, **fit_kwargs)
            
    def train_on_batch(self, x, y, sample_weight=None):
        """Single gradient step on the online model (no LR schedule applied)."""
        return self.model.train_on_batch(x, y, sample_weight=sample_weight)
        
    def predict(self, s, target=False):
        """Predict Q-values with the online (default) or target network.

        On any exception the error and input shape are printed and zeros are
        returned: shape (batch, actionCnt) for 4-D input, else (actionCnt,).
        """
        try:
            if target:
                return self.target_model.predict(s, verbose=0)
            else:
                return self.model.predict(s, verbose=0)
        except Exception as e:
            print(f"Error in prediction: {e}")
            print(f"Input shape: {s.shape}")
            # Return zeros if prediction fails
            if len(s.shape) == 4:  # Batch of states
                return np.zeros((s.shape[0], self.actionCnt))
            else:  # Single state
                return np.zeros(self.actionCnt)
                
    def predict_one(self, s, target=False):
        """Predict Q-values for one state and return them as a flat 1-D array."""
        # Ensure correct shape: HWC format with batch dimension
        if len(s.shape) == 3:  # If no batch dimension
            s = np.expand_dims(s, axis=0)  # Add batch dimension
        return self.predict(s, target).flatten()
        
    def update_target_model(self):
        """Copy the online network's weights into the target network."""
        self.target_model.set_weights(self.model.get_weights())
        
    def save_models(self, path, prefix=""):
        """Save both networks as .h5 under ``path``; returns True on success, False on error."""
        try:
            self.model.save(f"{path}/{prefix}model.h5")
            self.target_model.save(f"{path}/{prefix}target_model.h5")
            print(f"Models saved to {path}")
            return True
        except Exception as e:
            print(f"Error saving models: {e}")
            return False

# Enhanced Memory implementation with fixed-size efficient sampling
class Memory:
    """Prioritised experience-replay buffer backed by a SumTree.

    A transition's priority is ``(|td_error| + epsilon) ** alpha``. Transitions
    added without an error receive the largest priority seen so far.
    ``sample`` anneals ``beta`` towards 1 and returns importance-sampling
    weights normalised so that the largest weight is 1.
    """
    def __init__(self, capacity, epsilon=0.01, alpha=0.6, beta=0.4, beta_increment=0.001):
        self.tree = SumTree(capacity)
        self.capacity = capacity
        # priority shaping: epsilon keeps priorities non-zero, alpha controls skew
        self.epsilon = epsilon
        self.alpha = alpha
        # importance-sampling exponent and its per-sample() growth
        self.beta = beta
        self.beta_increment = beta_increment
        self.max_priority = 1.0

    def _getPriority(self, error):
        """Map a TD error to a sampling priority."""
        return (np.abs(error) + self.epsilon) ** self.alpha

    def add(self, error, sample):
        """Insert ``sample``; ``error=None`` uses the current max priority."""
        priority = self.max_priority if error is None else self._getPriority(error)
        self.tree.add(priority, sample)
        self.max_priority = max(self.max_priority, priority)

    def sample(self, n):
        """Draw ``n`` transitions.

        Returns:
            (list of (tree_index, transition), normalised IS weights array)
        """
        self.beta = min(1.0, self.beta + self.beta_increment)

        tree_indices, priorities, transitions = self.tree.get_batch(n)

        probabilities = np.array(priorities) / self.tree.total()
        is_weights = np.power(self.capacity * probabilities, -self.beta)
        is_weights /= is_weights.max()

        return list(zip(tree_indices, transitions)), is_weights

    def update(self, idx, error):
        """Re-prioritise the entry at tree index ``idx`` from a fresh TD error."""
        priority = self._getPriority(error)
        self.tree.update(idx, priority)
        self.max_priority = max(self.max_priority, priority)

    def size(self):
        """Number of transitions currently stored."""
        return self.tree.n_entries

class LearningAgent:
    """Double-DQN agent with prioritised replay, epsilon-greedy exploration,
    frame skipping, and a small penalty for repeating recent actions.
    """
    # Class-level defaults; instances shadow them once they assign their own value.
    steps = 0
    latest_Q = 0
    epsilon = MAX_EPSILON
    difficulty_level = 5  # opponent difficulty, kept within 1..10 by adapt_difficulty
    combo_counter = 0
    defensive_stance = False

    def __init__(self, learning=False, epsilon=1.0, alpha=0.5):
        """Create the input handler, Q-networks and replay memory.

        Args:
            learning: when False, ``choose_action`` always picks uniformly at random.
            epsilon, alpha: accepted for compatibility but not read here;
                exploration starts from the class attribute (MAX_EPSILON).
        """
        self.input_handler = InputHandler()
        self.learning = learning
        self.inputShape = (IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_STACK)  # HWC
        self.numActions = len(valid_actions)
        self.model = Model(self.inputShape, self.numActions, use_dueling=True)
        self.memory = Memory(MEMORY_CAPACITY)

        # Bounded metric histories
        self.rewards_history, self.q_values_history, self.loss_history = (
            deque(maxlen=10000) for _ in range(3)
        )

        # Diversity penalty over the last 10 chosen actions
        self.prev_actions = deque(maxlen=10)
        self.action_repeat_penalty = -0.1

        # Frame-skip bookkeeping
        self.frame_skip_counter = 0
        self.last_action = 0

        # Combo statistics
        self.current_combo = 0
        self.max_combo = 0

        # Monitoring counters
        self.action_counts = Counter()
        self.hits_landed = 0
        self.damage_taken = 0

    def observe(self, sample):
        """Store one (s, a, r, s_) transition and advance step-based schedules.

        With probability 0.2 both states are augmented before storage. The TD
        error from ``get_targets`` is the insertion priority. The target network
        is synced every UPDATE_TARGET_FREQUENCY steps, and epsilon decays
        exponentially with the step count.
        """
        state, action, reward, next_state = sample

        if random.random() < 0.2:
            state = augment_state(state)
            if next_state is not None:
                next_state = augment_state(next_state)

        transition = (state, action, reward, next_state)
        _, _, td_errors = self.get_targets([(0, transition)])
        self.memory.add(td_errors[0], transition)

        if self.steps % UPDATE_TARGET_FREQUENCY == 0:
            self.model.update_target_model()
            print(f"Target network updated at step {self.steps}")

        self.steps += 1
        decay = math.exp(-LAMBDA * self.steps)
        self.epsilon = MIN_EPSILON + (MAX_EPSILON - MIN_EPSILON) * decay

        self.rewards_history.append(reward)

        if self.steps % 100 != 0:
            return

        self.q_values_history.append(np.mean(self.model.predict_one(state)))
        # Slicing covers both the "fewer than 100" and "at least 100" cases.
        avg_reward = np.mean(list(self.rewards_history)[-100:])
        print(f"Step: {self.steps}, Epsilon: {self.epsilon:.4f}, Avg Reward: {avg_reward:.4f}, Latest Q: {self.latest_Q:.4f}")

    def get_targets(self, batch):
        """Compute Double-DQN training targets.

        Args:
            batch: sequence of (tree_index, (s, a, r, s_)) pairs; ``s_`` is None
                for terminal transitions.

        Returns:
            (x, y, errors): stacked states, target Q-vectors, and absolute TD
            errors for the taken actions. Also sets ``self.latest_Q`` to the last
            target value.
        """
        blank = np.zeros(self.inputShape)
        transitions = [entry[1] for entry in batch]

        states = np.array([t[0] for t in transitions])
        next_states = np.array([blank if t[3] is None else t[3] for t in transitions])

        q_online = self.model.predict(states)
        q_online_next = self.model.predict(next_states, target=False)  # picks a'
        q_target_next = self.model.predict(next_states, target=True)   # evaluates a'

        count = len(batch)
        x = np.zeros((count, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_STACK))
        y = np.zeros((count, self.numActions))
        errors = np.zeros(count)

        for i, t in enumerate(transitions):
            s, a, r, s_ = t[0], t[1], t[2], t[3]

            target = q_online[i].copy()
            before = target[a]

            if s_ is None:
                target[a] = r
            else:
                best_next = np.argmax(q_online_next[i])
                target[a] = r + GAMMA * q_target_next[i][best_next]

            x[i] = s
            y[i] = target
            errors[i] = abs(before - target[a])
            self.latest_Q = target[a]

        return (x, y, errors)

    def choose_action(self, state):
        """Pick an action index with frame skipping and epsilon-greedy exploration.

        While a skip window is active the previous action is repeated. Otherwise
        a random action is taken (always when not learning, or with probability
        epsilon). Failing that, the argmax of Q plus a repeat penalty is used.
        """
        self.state = state

        if 0 < self.frame_skip_counter < FRAME_SKIP:
            self.frame_skip_counter += 1
            return self.last_action
        self.frame_skip_counter = 1

        penalties = np.zeros(self.numActions)
        for recent in self.prev_actions:
            penalties[recent] += self.action_repeat_penalty

        last_index = len(valid_actions) - 1
        if not self.learning:
            action = random.randint(0, last_index)
        elif random.uniform(0, 1) < self.epsilon:
            action = random.randint(0, last_index)
            print("Exploring with random action:", action)
        else:
            q_values = self.model.predict_one(state)
            action = np.argmax(q_values + penalties)
            print(f"Exploiting with action {action}, Q-value: {q_values[action]:.4f}")

        self.prev_actions.append(action)
        self.last_action = action
        self.action_counts[action] += 1

        return action

    def execute_action(self, action):
        """Forward the chosen action index to the input handler."""
        self.input_handler.execute_action(action)

    def replay(self):
        """One prioritised-replay training step; returns the loss, or None if memory is too small."""
        if self.memory.size() < BATCH_SIZE:
            return None

        batch, is_weights = self.memory.sample(BATCH_SIZE)
        x, y, errors = self.get_targets(batch)

        for (tree_idx, _), err in zip(batch, errors):
            self.memory.update(tree_idx, err)

        loss = self.model.train_on_batch(x, y, sample_weight=is_weights)
        self.loss_history.append(loss)
        return loss

    def play(self, state):
        """Greedy action for evaluation (no exploration, no frame skip)."""
        self.state = state
        q_values = self.model.predict_one(state)
        action = np.argmax(q_values)
        self.action_counts[action] += 1
        return action

    def adapt_difficulty(self, mean_reward):
        """Step difficulty up when mean_reward > 5, down when < -5 (clamped to 1..10); return it."""
        if mean_reward > 5.0:
            self.difficulty_level = min(10, self.difficulty_level + 1)
            print(f"Increasing difficulty to level {self.difficulty_level}")
        elif mean_reward < -5.0:
            self.difficulty_level = max(1, self.difficulty_level - 1)
            print(f"Decreasing difficulty to level {self.difficulty_level}")
        return self.difficulty_level

    def update_combo(self, hit_successful):
        """Extend the current combo on a hit (tracking the max) or reset it otherwise."""
        if not hit_successful:
            self.current_combo = 0
            return
        self.current_combo += 1
        self.max_combo = max(self.max_combo, self.current_combo)

    def get_combo_bonus(self):
        """Bonus of 0.1 per combo hit, capped at 0.5."""
        return min(0.5, 0.1 * self.current_combo)

    def reset_episode_stats(self):
        """Clear per-episode counters; ``max_combo`` is kept as a running statistic."""
        self.current_combo = 0
        self.damage_taken = 0
        self.hits_landed = 0

class Vision:
    """Screen capture and HP-bar based reward extraction.

    ``side`` is 'left' or 'right' and names which health bar belongs to the
    agent. Capture regions are fixed screen coordinates (see class attributes).
    """
    screen = {'top': 110, 'left': 0, 'width': 1920, 'height': 970}
    leftHPCapture = {'top': 110, 'left': 240, 'width': 620, 'height': 20}
    rightHPCapture = {'top': 110, 'left': 1043, 'width': 620, 'height': 20}
    
    positive = 1    # AI hit the opponent
    negative = -1   # AI took a hit
    
    def __init__(self, side):
        self.side = side
        with mss.mss() as sct:
            self.prevLeftHP = self.numpy_img_to_gray(np.array(sct.grab(self.leftHPCapture)))
            self.prevRightHP = self.numpy_img_to_gray(np.array(sct.grab(self.rightHPCapture)))
        
        # For combo tracking
        self.combo_counter = 0
        self.last_hit_time = 0
        self.combo_timeout = 1.0  # seconds between hits to count as combo
        
        # For defensive move detection
        self.defensive_stance = False
        self.last_damage_taken = 0
        
        # For state normalization
        self.frames_seen = 0
        self.running_mean = 0
        self.running_std = 0
        
        # Performance tracking
        self.frame_times = deque(maxlen=100)
        self.last_frame_time = time.time()
        
        # Add for performance monitoring
        self.hits_landed = 0
        self.damage_taken = 0
        self.total_frames = 0
        self.performance_monitor = None

    def numpy_img_to_gray(self, img):
        """Weighted sum of the first three channels (0.299, 0.587, 0.114)."""
        # NOTE(review): mss documents grab() pixels as BGRA, so these RGB weights
        # land on B, G, R here. The result is only compared frame-to-frame against
        # a fixed threshold in _reward_from_hp, so changing the channel order would
        # require re-tuning that threshold. Confirm before "fixing".
        return np.dot(img[...,:3], [0.299, 0.587, 0.114])

    def get_current_screen(self):
        """Capture the game area and return it preprocessed via ``preprocess_frame``."""
        start_time = time.time()
        
        with mss.mss() as sct:
            sct_img = sct.grab(self.screen)
            img = Image.frombytes('RGB', sct_img.size, sct_img.rgb)
            img = img.resize((IMAGE_WIDTH, IMAGE_HEIGHT), Image.LANCZOS)
            currScreen = np.array(img)
            
        # Preprocess the screen
        processed = preprocess_frame(currScreen)
        
        # Welford update of the per-frame mean brightness. Despite its name,
        # running_std accumulates the sum of squared deviations (M2); the
        # standard deviation would be sqrt(running_std / frames_seen).
        self.frames_seen += 1
        delta = processed.mean() - self.running_mean
        self.running_mean += delta / self.frames_seen
        delta2 = processed.mean() - self.running_mean
        self.running_std += delta * delta2
        
        # Track frame processing time
        frame_time = time.time() - start_time
        self.frame_times.append(frame_time)
        
        # Print average FPS periodically
        if self.frames_seen % 100 == 0:
            avg_frame_time = np.mean(self.frame_times)
            fps = 1.0 / avg_frame_time if avg_frame_time > 0 else 0
            print(f"Average FPS: {fps:.1f}")
            
            # Update performance monitor with FPS if available
            if self.performance_monitor is not None:
                self.performance_monitor.update_fps(fps)
            
        return processed

    def get_reward(self):
        """Capture both HP bars and score their change since the previous call.

        Returns:
            (reward, hit_landed, took_damage), see ``_reward_from_hp``.
        """
        with mss.mss() as sct:
            rawLeftHP = np.array(sct.grab(self.leftHPCapture))
            rawRightHP = np.array(sct.grab(self.rightHPCapture))

        currLeftHP = self.numpy_img_to_gray(rawLeftHP)
        currRightHP = self.numpy_img_to_gray(rawRightHP)
        return self._reward_from_hp(currLeftHP, currRightHP)

    def _reward_from_hp(self, currLeftHP, currRightHP):
        """Turn one frame of gray HP-bar images into a shaped reward.

        A bar "takes damage" when more than 10 of its pixels (amount > 0.1) got
        darker by more than 125 gray levels. Damage to the opponent's bar gives
        +(1 + amount) and extends the timed combo. Damage to our own bar gives
        -(1 + amount), breaks the combo and enters a defensive stance.

        Other terms:
          * a combo bonus of min(0.2 * combo, 1.5), paid only on a frame that
            landed a hit with combo > 1;
          * a constant -0.005 time penalty;
          * +0.2 once the defensive stance has lasted 2 s without damage.

        Side effects: updates prevLeftHP/prevRightHP, combo/damage counters,
        total_frames, and reports FPS every 100 frames when timings exist.

        Returns:
            (reward, hit_landed, took_damage)
        """
        reward = 0
        hit_landed = False
        took_damage = False

        # Only pixels that darkened count; brightening (refill, flashes) is ignored.
        diffLeftHP = (self.prevLeftHP - currLeftHP).clip(min=0)
        diffRightHP = (self.prevRightHP - currRightHP).clip(min=0)

        # Pixels that dropped by more than 125 gray levels, divided by 100.
        left_damage_amount = (diffLeftHP > 125).sum() / 100.0
        right_damage_amount = (diffRightHP > 125).sum() / 100.0

        # Left bar is processed before right, matching the capture order.
        for bar_side, damage_amount in (('left', left_damage_amount),
                                        ('right', right_damage_amount)):
            if damage_amount > 0.1:
                if self.side == bar_side:
                    # Our own bar dropped
                    reward -= 1 + damage_amount
                    self.defensive_stance = True
                    self.last_damage_taken = time.time()
                    self.combo_counter = 0
                    took_damage = True
                    self.damage_taken += damage_amount
                else:
                    # Opponent's bar dropped
                    reward += 1 + damage_amount
                    self.update_combo()
                    hit_landed = True
                    self.hits_landed += 1

        # combo_counter persists between hits (it only resets on damage or on a
        # hit after the timeout), so the bonus must be tied to a hit this frame.
        # Otherwise every idle frame after a combo would keep collecting it.
        if hit_landed and self.combo_counter > 1:
            reward += min(self.combo_counter * 0.2, 1.5)

        # Small per-frame penalty to discourage passivity
        reward += -0.005

        # Reward surviving 2 s after being hit
        if self.defensive_stance and time.time() - self.last_damage_taken > 2.0:
            if not took_damage:
                reward += 0.2
            self.defensive_stance = False

        self.prevLeftHP = currLeftHP
        self.prevRightHP = currRightHP

        self.total_frames += 1

        # Report FPS every 100 frames, but only when timings exist; averaging an
        # empty deque would yield NaN and report a bogus 0 FPS.
        monitor = getattr(self, 'performance_monitor', None)
        if self.total_frames % 100 == 0 and monitor is not None and len(self.frame_times) > 0:
            avg_frame_time = np.mean(self.frame_times)
            fps = 1.0 / avg_frame_time if avg_frame_time > 0 else 0
            monitor.update_fps(fps)

        return reward, hit_landed, took_damage
    
    def update_combo(self):
        """Count a landed hit towards the combo if within ``combo_timeout`` of the last one."""
        current_time = time.time()
        
        # Check if this hit is part of a combo (within timeout)
        if current_time - self.last_hit_time < self.combo_timeout:
            self.combo_counter += 1
            print(f"Combo x{self.combo_counter}!")
        else:
            # Start new combo
            self.combo_counter = 1
            
        self.last_hit_time = current_time
        return self.combo_counter

TRIAL = 1
TOTAL_TESTS = 1
TOTAL_EPISODES = 100
EVAL_INTERVAL = 5  # Evaluate every 5 episodes

def save_models(agent, curr_test, is_best=False, directory='Models/trial 2'):
    """Persist networks, training histories and a replay-memory sample.

    Writes, under ``directory``:
      * ``TekkenBotDDQN_<curr_test>.h5/.json`` and the ``_Target_`` pair, or the
        ``_best`` names when ``is_best`` is True (``curr_test`` is then unused);
      * ``rewards_history.npy`` / ``q_values_history.npy`` / ``loss_history.npy``
        for each non-empty history;
      * ``memory_samples.npy``: a 1-D object array of 1000 transitions drawn
        proportionally to priority, only when memory holds more than 1000.

    Args:
        agent: object exposing ``model.model``, ``model.target_model``, the three
            history deques and ``memory``.
        curr_test: suffix for regular checkpoint names.
        is_best: save under the fixed "best" names instead.
        directory: output folder (created if missing); the default matches the
            path ``import_model`` reads from.
    """
    os.makedirs(directory, exist_ok=True)

    # Names keep a leading "/" because they are concatenated onto `directory`.
    if is_best:
        model_name = "/TekkenBotDDQN_best.h5"
        target_name = "/TekkenBotDDQN_Target_best.h5"
    else:
        model_name = f"/TekkenBotDDQN_{curr_test}.h5"
        target_name = f"/TekkenBotDDQN_Target_{curr_test}.h5"

    # Architecture JSON next to each .h5 (online first, then target)
    for net, name in ((agent.model.model, model_name),
                      (agent.model.target_model, target_name)):
        with open(f"{directory}{name.replace('.h5', '.json')}", "w") as json_file:
            json_file.write(net.to_json())

    agent.model.model.save(directory + model_name, overwrite=True)
    agent.model.target_model.save(directory + target_name, overwrite=True)

    print(f"Models saved: {model_name} and {target_name}")

    histories = (
        ("rewards_history", agent.rewards_history),
        ("q_values_history", agent.q_values_history),
        ("loss_history", agent.loss_history),
    )
    for stem, history in histories:
        if len(history) > 0:
            np.save(f"{directory}/{stem}.npy", np.array(list(history)))

    # Save replay memory samples for later resume
    if agent.memory.size() > 1000:
        print("Saving memory samples...")
        tree = agent.memory.tree
        samples = []
        for _ in range(1000):
            _, _, sample = tree.get(random.uniform(0, tree.total()))
            samples.append(sample)

        # Fill a 1-D object array element by element so each entry stays one
        # (s, a, r, s_) transition. np.array(list_of_tuples, dtype=object) may
        # instead spread the tuple fields across a second axis.
        packed = np.empty(len(samples), dtype=object)
        for i, sample in enumerate(samples):
            packed[i] = sample

        np.save(f"{directory}/memory_samples.npy", packed)
        print("Memory samples saved.")

def evaluate(agent, num_episodes=10, performance_monitor=None, episode_seconds=59):
    """Run greedy (no-exploration) episodes and report aggregate metrics.

    Each episode stacks the first captured frame into an initial state. It then
    repeatedly plays the greedy action, reads the HP-bar reward, and shifts the
    new frame into the state until ``episode_seconds`` of wall-clock time have
    passed. When the mean reward beats the module-level ``best_eval_reward``
    (treated as -inf if it was never set), the global is updated and the models
    are saved as "best".

    Args:
        agent: the LearningAgent being evaluated.
        num_episodes: number of evaluation episodes.
        performance_monitor: optional object receiving ``update_fps`` (via
            Vision) and ``update_eval(steps, results)``.
        episode_seconds: wall-clock length of one episode (default 59 s).

    Returns:
        dict with rewards, mean_reward, episode_lengths, mean_length, mean_q,
        max_combo.
    """
    global best_eval_reward
    
    agent.input_handler.activate_remap()
    vision = Vision('left')
    
    # Set performance monitor for vision if provided
    if performance_monitor is not None:
        vision.performance_monitor = performance_monitor
    
    rewards = []
    episode_lengths = []
    q_values = []
    
    print(f"Starting evaluation for {num_episodes} episodes...")
    
    for episode in range(num_episodes):
        # Reset episode stats
        agent.reset_episode_stats()
        vision.hits_landed = 0
        vision.damage_taken = 0
        
        # Initial state: the first frame repeated IMAGE_STACK times
        state = create_state(vision.get_current_screen())
        
        done = False
        total_reward = 0
        steps = 0
        start_time = time.time()
        
        while not done:
            # Use deterministic policy (no exploration)
            actionIndex = agent.play(state)
            agent.execute_action(actionIndex)
            
            # Capture new observation and reward
            new_obs = vision.get_current_screen()
            reward, hit, damage = vision.get_reward()
            
            # Update agent's internal combo counter
            if hit:
                agent.update_combo(True)
                agent.hits_landed += 1
            elif damage:
                agent.update_combo(False)
                agent.damage_taken += 1
            
            # Track Q-values during evaluation (second forward pass on `state`)
            q_vals = agent.model.predict_one(state)
            q_values.append(np.mean(q_vals))
            
            state = update_state(state, new_obs)
            
            total_reward += reward
            steps += 1
            
            # End the episode once the wall-clock budget is spent
            if time.time() - start_time > episode_seconds:
                done = True
                
        rewards.append(total_reward)
        episode_lengths.append(steps)
        print(f"Eval Episode {episode+1}: Reward = {total_reward:.2f}, Steps = {steps}, Max Combo: {agent.max_combo}")
    
    # Calculate evaluation metrics
    mean_reward = np.mean(rewards)
    std_reward = np.std(rewards)
    mean_length = np.mean(episode_lengths)
    std_length = np.std(episode_lengths)
    mean_q = np.mean(q_values) if q_values else 0
    
    print(f'Evaluation Results:')
    print(f'Average Reward: {mean_reward:.2f} ± {std_reward:.2f}')
    print(f'Average Episode Length: {mean_length:.2f} ± {std_length:.2f}')
    print(f'Average Q Value: {mean_q:.4f}')
    
    # Treat a never-initialised global as -inf instead of raising NameError
    # after a full (and slow) evaluation run.
    best_so_far = globals().get('best_eval_reward', float('-inf'))
    if mean_reward > best_so_far:
        best_eval_reward = mean_reward
        save_models(agent, -1, is_best=True)
        print(f"New best model saved with reward: {mean_reward:.2f}")
    
    results = {
        'rewards': rewards,
        'mean_reward': mean_reward,
        'episode_lengths': episode_lengths,
        'mean_length': mean_length,
        'mean_q': mean_q,
        'max_combo': agent.max_combo
    }
    
    # The monitor gets its own dict so it never aliases the returned one.
    if performance_monitor is not None:
        performance_monitor.update_eval(agent.steps, dict(results))
    
    return results

def import_model(agent, model_dir='Models/trial 2'):
    """Load previously saved online/target networks into ``agent``.

    The "best" checkpoint pair is tried first; if it is missing OR fails to
    load, the regular checkpoint pair is tried next (previously a broken best
    checkpoint aborted loading entirely). When the best pair loads, saved
    replay-memory samples are restored as well, if present.

    Args:
        agent: The learning agent. ``agent.inputShape`` and
            ``agent.numActions`` are read; ``agent.model`` is replaced on
            success and ``agent.memory`` is filled when samples are restored.
        model_dir: Directory holding the checkpoint files. Defaults to the
            directory that used to be hard-coded here.

    Returns:
        True if a model pair was loaded, False if the agent keeps its fresh model.
    """
    candidates = (
        # (online-network file, target-network file, success message, restore replay memory?)
        ('TekkenBotDDQN_best.h5', 'TekkenBotDDQN_Target_best.h5',
         'Best model loaded successfully.', True),
        ('TekkenBotDDQN_1.h5', 'TekkenBotDDQN_Target_1.h5',
         'Model loaded successfully.', False),
    )
    load_failed = False
    for model_file, target_file, success_msg, restore_memory in candidates:
        model_path = os.path.join(model_dir, model_file)
        target_path = os.path.join(model_dir, target_file)
        if not (os.path.exists(model_path) and os.path.exists(target_path)):
            continue
        try:
            _load_model_pair(agent, model_path, target_path)
        except Exception as e:
            # A corrupt/incompatible checkpoint should not prevent trying the next one.
            load_failed = True
            print(f'Error loading model from {model_path}:')
            print(e)
            continue
        print(success_msg)
        if restore_memory:
            _load_memory_samples(agent, os.path.join(model_dir, 'memory_samples.npy'))
        return True

    if load_failed:
        print('Starting with fresh model.')
    else:
        print('No existing model found. Starting with fresh model.')
    return False


def _load_model_pair(agent, model_path, target_path):
    """Replace ``agent.model`` with a wrapper built from a saved online/target pair."""
    # Custom layer / loss must be registered for Keras to deserialize the .h5 files.
    custom_objects = {
        'ResidualBlock': ResidualBlock,
        'huber_loss': huber_loss,
    }
    # Both networks are loaded before the assignment, so agent.model is left
    # untouched if either load raises.
    agent.model = Model(agent.inputShape, agent.numActions,
        load_model(model_path, custom_objects=custom_objects),
        load_model(target_path, custom_objects=custom_objects),
        use_dueling=True)


def _load_memory_samples(agent, memory_path):
    """Best-effort restore of replay-memory samples saved with ``np.save``; errors are printed, not raised."""
    if not os.path.exists(memory_path):
        return
    try:
        # NOTE(review): allow_pickle=True unpickles the file contents, which can
        # execute arbitrary code — only load files this program wrote itself.
        memory_samples = np.load(memory_path, allow_pickle=True)
        print(f"Loading {len(memory_samples)} memory samples...")
        for sample in memory_samples:
            # Priority passed as None — per the original note this presumably
            # means "use max priority"; confirm against agent.memory.add.
            agent.memory.add(None, sample)
        print(f"Memory initialized with {agent.memory.size()} samples.")
    except Exception as e:
        print(f"Error loading memory samples: {e}")
    
def play(agent, episode_seconds=59, action_delay=0.05):
    """Run the trained agent greedily and log per-episode performance.

    Loops until interrupted with Ctrl+C. Every ``episode_seconds`` seconds of
    wall-clock time the accumulated stats are pushed to a ``PerformanceMonitor``
    (saving under ``performance_data/play``) and reset.

    Args:
        agent: Trained learning agent; ``agent.play(state)`` selects each action.
        episode_seconds: Length of one logged play "episode" in seconds.
            Defaults to the previously hard-coded value.
        action_delay: Seconds slept after each step to throttle the action rate.

    The performance summary and plots are written in a ``finally`` clause, so
    they are produced both on Ctrl+C and when an unexpected error ends play
    (the error still propagates to the caller).
    """
    performance_monitor = PerformanceMonitor(save_dir='performance_data/play')

    agent.input_handler.activate_remap()
    vision = Vision('left')
    vision.performance_monitor = performance_monitor

    # Seed the frame stack by repeating the first screen IMAGE_STACK times and
    # moving the stack axis last (assumes each screen is a 2-D frame — confirm).
    initial_screen = vision.get_current_screen()
    state = np.array([initial_screen] * IMAGE_STACK).transpose(1, 2, 0)

    print("Starting play mode with trained agent. Press Ctrl+C to stop.")

    # Reset stats for play session
    agent.reset_episode_stats()
    episode_start_time = time.time()
    episode_rewards = 0
    episode_actions = []
    episode_num = 0

    try:
        while True:
            # Greedy action from the trained policy
            actionIndex = agent.play(state)
            agent.execute_action(actionIndex)
            episode_actions.append(actionIndex)

            # Observe the result of the action
            screenCap = vision.get_current_screen()
            reward, hit, damage = vision.get_reward()
            state = update_state(state, screenCap)

            # Update combo counter
            if hit:
                agent.update_combo(True)
                agent.hits_landed += 1
                print(f"Hit landed! Combo: {agent.current_combo}")
            elif damage:
                agent.update_combo(False)
                agent.damage_taken += 1
                print("Took damage!")

            episode_rewards += reward

            # End of a logged episode - report to the performance monitor
            if time.time() - episode_start_time > episode_seconds:
                avg_q = np.mean(agent.model.predict_one(state))

                performance_monitor.update(
                    episode=episode_num,
                    reward=episode_rewards,
                    steps=agent.steps,
                    avg_q=avg_q,
                    epsilon=0.0,  # No exploration in play mode
                    loss=None,
                    max_combo=agent.max_combo,
                    difficulty_level=agent.difficulty_level,
                    hits=agent.hits_landed,
                    damage_taken=agent.damage_taken,
                    actions=episode_actions
                )

                print(f"Play Episode {episode_num+1} completed. Reward: {episode_rewards:.2f}, Max Combo: {agent.max_combo}")

                # Reset for next episode
                episode_num += 1
                episode_start_time = time.time()
                episode_rewards = 0
                episode_actions = []
                agent.reset_episode_stats()

            # Throttle the action rate
            time.sleep(action_delay)

    except KeyboardInterrupt:
        print('Play mode stopped.')
    finally:
        # Always persist the session's performance data, however play ended.
        performance_monitor.print_summary()
        performance_monitor.plot_and_save()

def run(agent):
    """Train ``agent`` online against the live game, with periodic evaluation.

    Training is split into ``TOTAL_TESTS`` segments of ``TOTAL_EPISODES``
    episodes; each episode ends after ~59 s of wall-clock time. Every
    ``EVAL_INTERVAL`` episodes (counted within the current segment) the agent
    is evaluated and ``agent.adapt_difficulty`` is called with the mean
    evaluation reward. Models are saved at the end of each segment.

    Ctrl+C stops training early. In every case the ``finally`` clause saves the
    models (unless the final save already happened), the per-episode record,
    the evaluation results, and the performance plots/summary — each once.
    """
    performance_monitor = PerformanceMonitor(save_dir='performance_data')

    agent.input_handler.activate_remap()
    vision = Vision('left')
    vision.performance_monitor = performance_monitor

    def fresh_state():
        # Frame stack built by repeating the current screen IMAGE_STACK times,
        # stack axis last (assumes each screen is a 2-D frame — confirm).
        screen = vision.get_current_screen()
        return np.array([screen] * IMAGE_STACK).transpose(1, 2, 0)

    state = fresh_state()

    episode_start_time = time.time()
    global_start_time = time.time()

    # Training statistics
    reward_total = 0
    episode = 0        # episode index within the current test segment
    true_episode = 0   # episode index across all segments
    curr_test = 1
    record_size = TOTAL_EPISODES * TOTAL_TESTS + 1
    # Columns: episode, reward, cumulative steps, avg Q, difficulty level, max combo
    reward_per_episode = np.zeros((record_size, 6))
    i = 0

    eval_results = []
    final_models_saved = False

    agent.reset_episode_stats()
    vision.hits_landed = 0
    vision.damage_taken = 0

    # For action distribution tracking
    episode_actions = []

    try:
        print("Starting training. Press Ctrl+C to stop.")
        while True:
            # Choose and execute an action under the current policy
            actionIndex = agent.choose_action(state)
            agent.execute_action(actionIndex)
            episode_actions.append(actionIndex)

            # Observe the result of the action
            screenCap = vision.get_current_screen()
            reward, hit, damage = vision.get_reward()

            # Update agent's combo counter
            if hit:
                agent.update_combo(True)
                agent.hits_landed += 1
            elif damage:
                agent.update_combo(False)
                agent.damage_taken += 1

            # Copy so the stored `state` is not mutated by update_state.
            next_state = update_state(state.copy(), screenCap)
            agent.observe((state, actionIndex, reward, next_state))

            # Perform experience replay periodically
            if agent.steps % REPLAY_PERIOD == 0:
                loss = agent.replay()
                if loss is not None:
                    agent.loss_history.append(loss)

            state = next_state
            reward_total += reward

            # End of episode (~60 seconds)
            if time.time() - episode_start_time > 59:
                avg_q = np.mean(list(agent.q_values_history)[-100:]) if agent.q_values_history else 0
                avg_loss = np.mean(list(agent.loss_history)[-100:]) if agent.loss_history and len(agent.loss_history) >= 100 else None

                reward_per_episode[i] = [true_episode, reward_total, agent.steps, avg_q, agent.difficulty_level, agent.max_combo]

                performance_monitor.update(
                    episode=true_episode,
                    reward=reward_total,
                    steps=agent.steps,
                    avg_q=avg_q,
                    epsilon=agent.epsilon,
                    loss=avg_loss,
                    max_combo=agent.max_combo,
                    difficulty_level=agent.difficulty_level,
                    hits=agent.hits_landed,
                    damage_taken=agent.damage_taken,
                    actions=episode_actions
                )

                # Training speed over the whole run
                elapsed = time.time() - global_start_time
                steps_per_second = agent.steps / elapsed if elapsed > 0 else 0

                print(f"Episode {true_episode+1} completed. Reward: {reward_total:.2f}, Steps: {agent.steps}, "
                      f"Epsilon: {agent.epsilon:.4f}, Max Combo: {agent.max_combo}, "
                      f"Steps/sec: {steps_per_second:.1f}")

                episode += 1
                true_episode += 1
                i += 1

                # Reset per-episode stats
                agent.reset_episode_stats()
                vision.hits_landed = 0
                vision.damage_taken = 0
                episode_actions = []
                reward_total = 0

                # Evaluate periodically
                if episode % EVAL_INTERVAL == 0:
                    print(f"Performing evaluation after episode {true_episode}")
                    eval_result = evaluate(agent, num_episodes=5, performance_monitor=performance_monitor)
                    eval_results.append(eval_result)

                    # Adjust difficulty based on performance
                    agent.adapt_difficulty(eval_result['mean_reward'])

                    # Evaluation drove the game for several episodes; the
                    # buffered frames are stale, so resume from the current screen.
                    state = fresh_state()

                # Start the next episode's clock only AFTER any evaluation, so
                # evaluation time is not charged to the next training episode.
                episode_start_time = time.time()

            # End of test segment
            if episode >= TOTAL_EPISODES:
                save_models(agent, curr_test)
                if curr_test >= TOTAL_TESTS:
                    final_models_saved = True
                    print("Training completed. Final model saved.")
                    break

                curr_test += 1
                episode = 0
                print(f"Starting test segment {curr_test}")

                # Evaluate between test segments, then restart the episode
                # from a fresh observation and a fresh clock.
                evaluate(agent, num_episodes=10, performance_monitor=performance_monitor)
                state = fresh_state()
                episode_start_time = time.time()

    except KeyboardInterrupt:
        print("\nTraining interrupted by user.")

    finally:
        print("Saving models and training history...")
        if not final_models_saved:
            save_models(agent, curr_test)

        # Save episode rewards (rows past index i are unused zeros)
        os.makedirs('Models/trial 2', exist_ok=True)
        np.savetxt('Models/trial 2/episodesAndRewards.txt', reward_per_episode, fmt='%.4f')
        print("Episodes and rewards saved to episodesAndRewards.txt")

        # Save evaluation results
        if eval_results:
            eval_dir = 'Models/trial 2/evaluations'
            os.makedirs(eval_dir, exist_ok=True)
            np.save(f"{eval_dir}/eval_results.npy", eval_results)

        # Final performance plots (once)
        performance_monitor.plot_and_save()

        # Final statistics
        elapsed = time.time() - global_start_time
        hours = elapsed // 3600
        minutes = (elapsed % 3600) // 60
        seconds = elapsed % 60
        final_rate = agent.steps / elapsed if elapsed > 0 else 0

        print(f"Training complete. Total time: {int(hours)}h {int(minutes)}m {int(seconds)}s")
        print(f"Total steps: {agent.steps}, Steps per second: {final_rate:.1f}")
        print(f"Max combo achieved: {agent.max_combo}")

        performance_monitor.print_summary()

if __name__ == '__main__':
    try:
        # Create the learning agent (epsilon starts at MAX_EPSILON, alpha=0.6)
        agent = LearningAgent(learning=True, epsilon=MAX_EPSILON, alpha=0.6)

        # Try to load existing model if available
        imported = import_model(agent)

        # Ask user whether to train or play; tolerate stray whitespace and case.
        mode = input("Enter mode (train/play): ").strip().lower()

        if mode == 'train':
            print("Starting training mode...")
            run(agent)
        elif mode == 'play':
            if not imported:
                print("Warning: No model found for play mode. Training a simple model first...")
                # NOTE: run() executes the full training schedule. Press Ctrl+C
                # to cut it short — run() catches the interrupt, saves, and
                # returns, after which play mode starts.
                run(agent)
            print("Starting play mode...")
            play(agent)
        else:
            print("Invalid mode. Please enter 'train' or 'play'.")

    except Exception as e:
        print(f"An error occurred: {str(e)}")
        # Print stack trace for debugging
        import traceback
        traceback.print_exc()

    finally:
        print('Session ended. Thanks for playing!')

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 84, 84, 4)]  0           []                               
                                                                                                  
 conv2d (Conv2D)                (None, 21, 21, 32)   8224        ['input_1[0][0]']                
                                                                                                  
 residual_block (ResidualBlock)  (None, 11, 11, 64)  101056      ['conv2d[0][0]']                 
                                                                                                  
 residual_block_1 (ResidualBloc  (None, 11, 11, 64)  74368       ['residual_block[0][0]']         
 k)                                                                                           

  plt.style.use('seaborn-darkgrid')


Pressing key: 162
Pressing key: 18
Releasing key: 162
Releasing key: 18
Starting training. Press Ctrl+C to stop.
Exploring with random action: 5
Executing action: 0
Pressing key: 88
Invalid action structure: 0
Releasing key: 88


  img = img.resize((IMAGE_WIDTH, IMAGE_HEIGHT), Image.LANCZOS)


Target network updated at step 0
Executing action: 0
Pressing key: 88
Invalid action structure: 0
Releasing key: 88
Exploring with random action: 12
Executing action: (65, 87)
Pressing key: 88
Pressing key: 88
Pressing key: 65
Releasing key: 65
Pressing key: 87
Releasing key: 87
Releasing key: 88
Executing action: (65, 87)
Pressing key: 88
Pressing key: 88
Pressing key: 65
Releasing key: 65
Pressing key: 87
Releasing key: 87
Releasing key: 88
Exploring with random action: 14
Executing action: (0, (69, 82))
Pressing key: 88
Pressing key: 88
Pressing key: 69
Pressing key: 82
Releasing key: 69
Releasing key: 82
Releasing key: 88
Executing action: (0, (69, 82))
Pressing key: 88
Pressing key: 88
Pressing key: 69
Pressing key: 82
Releasing key: 69
Releasing key: 82
Releasing key: 88
Exploring with random action: 2
Executing action: 84
Pressing key: 88
Pressing key: 84
Releasing key: 84
Releasing key: 88
Executing action: 84
Pressing key: 88
Pressing key: 84
Releasing key: 84
Releasing key: 8

: 

In [None]:
# Removed: a run of stray characters that appear to be the bot's simulated key
# presses (X/E/R/T/A/Y/D/S/W) typed into this cell while training ran.
# As code it raised NameError on Restart & Run All.