# Test
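Cumulative code from parts 1–3, which contains:
- the frame-embedding, event-detection and keyboard-input components;
- the Gymnasium environment built from them;
- DQN training and evaluation for Geometry Dash.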

In [1]:
import cv2
from PIL import Image
import torch
from torchvision.models import resnet50, ResNet50_Weights
from collections import deque
import numpy as np
import matplotlib as plt
import win32gui
import win32api
import win32con
import time
import gymnasium
from gymnasium import spaces
from stable_baselines3 import DQN

# Device index of the OBS virtual camera, as passed to cv2.VideoCapture.
VIDEO_IDX: int = 1

In [2]:
# Use the GPU for ResNet feature extraction when one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [3]:
class ResNetEmbeddings:
    """Pretrained ResNet-50 used as a frozen image-embedding extractor.

    The network's final child module (the classification head) is dropped.
    The remaining layers map a preprocessed image to a feature vector, which is
    2048-d for ResNet-50, matching FrameStacker's default embedding_dim.
    """

    def __init__(self) -> None:
        # Default pretrained weights, plus the preprocessing pipeline that matches them.
        self._weights = ResNet50_Weights.DEFAULT
        self._preprocess = self._weights.transforms()

        # The full pretrained network, switched to inference mode.
        self._model = resnet50(weights=self._weights).to(device)
        self._model.eval()

        # Every child module except the last one (the fc head).
        backbone_layers = list(self._model.children())[:-1]
        self._feature_extractor = torch.nn.Sequential(*backbone_layers).to(device)

    def get_img_embedding(self, img):
        """Return the CPU embedding tensor for a single image.

        Args:
            img: an image accepted by the weights' preprocessing transform
                (the env passes an RGB PIL image).

        Returns:
            torch.Tensor on the CPU, with all singleton dimensions squeezed out.
        """
        batch = self._preprocess(img)[None, ...].to(device)  # add a batch dim of 1

        with torch.no_grad():
            features = self._feature_extractor(batch)
            embedding = features.squeeze().cpu()
        return embedding

In [4]:
class FrameStacker:
    """Keep the most recent ``stack_size`` frame embeddings as one flat vector.

    The stacked vector is used directly as the environment observation. It always
    has ``stack_size`` embeddings concatenated oldest-first, and dtype float32.
    """

    def __init__(self, stack_size=4, embedding_dim=2048):
        """
        Args:
            stack_size: number of most recent embeddings to keep (must be >= 1).
            embedding_dim: length of each 1-D embedding; the stacked vector
                then has length ``stack_size * embedding_dim``.

        Raises:
            ValueError: if ``stack_size`` is smaller than 1.
        """
        if stack_size < 1:
            raise ValueError(f"stack_size must be >= 1, got {stack_size}")
        self.stack_size = stack_size
        self.embedding_dim = embedding_dim
        # Fixed-size buffer: appending past maxlen drops the oldest embedding.
        self.stack = deque(maxlen=stack_size)

    def reset_stack(self, initial_embedding):
        """Fill the buffer with copies of ``initial_embedding``.

        Returns:
            The stacked embeddings as a flat float32 array.
        """
        self.stack.clear()
        self.stack.extend([initial_embedding] * self.stack_size)
        return self._get_stacked_embeddings()

    def add_frame(self, embedding):
        """Push ``embedding`` (dropping the oldest one).

        If the buffer is not full, e.g. ``add_frame`` is called before
        ``reset_stack``, the oldest embedding is repeated to fill it. This keeps
        the returned vector at its full length instead of shrinking the
        observation, or crashing on an empty buffer.

        Returns:
            The stacked embeddings as a flat float32 array.
        """
        self.stack.append(embedding)
        while len(self.stack) < self.stack_size:
            self.stack.appendleft(self.stack[0])
        return self._get_stacked_embeddings()

    def _get_stacked_embeddings(self):
        """Concatenate the buffered embeddings (oldest first) into one float32 array."""
        # float32 matches the dtype of the env's observation space.
        return np.concatenate(list(self.stack), axis=0).astype(np.float32, copy=False)

In [5]:
class EventDetectorWithROI:
    """Detect the death / level-complete screens with template matching.

    Each grayscale template is searched only inside its own region of interest
    (ROI). An ROI is ``(x, y, width, height)`` in pixels of the frame.
    """

    def __init__(
            self,
            completion_template_path="./template_img/complete_template.png",
            retry_template_path="./template_img/retry_template.png",
            completion_roi=(120, 110, 400, 40),
            retry_roi=(160, 375, 55, 50),
            strong_completion_threshold=0.9,
            retry_threshold=0.8,
            completion_threshold=0.8,
            ):
        """
        Args:
            completion_template_path: image of the level-complete banner.
            retry_template_path: image of the retry button.
            completion_roi: (x, y, w, h) region searched for the completion template.
            retry_roi: (x, y, w, h) region searched for the retry template.
            strong_completion_threshold: a completion score at or above this wins
                even if the retry button is also visible.
            retry_threshold: minimum TM_CCOEFF_NORMED score that reports a death.
            completion_threshold: weaker completion score accepted when no retry
                button was found.

        Raises:
            FileNotFoundError: if a template image cannot be read.
        """
        # Templates, loaded as single-channel grayscale.
        self.completion_template = self._load_template(completion_template_path)
        self.retry_template = self._load_template(retry_template_path)

        # ROI regions.
        self.retry_roi = retry_roi
        self.completion_roi = completion_roi

        # Match-score thresholds (TM_CCOEFF_NORMED, range [-1, 1]).
        self.strong_completion_threshold = strong_completion_threshold
        self.retry_threshold = retry_threshold
        self.completion_threshold = completion_threshold

        print(self.completion_template.shape)
        print(self.retry_template.shape)

    @staticmethod
    def _load_template(path):
        """Read ``path`` as grayscale, raising instead of returning None on failure."""
        template = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if template is None:
            raise FileNotFoundError(f"Could not read template image: {path}")
        return template

    def _crop_to_roi(self, frame, roi):
        """
        Crop the frame to the given ROI.
        """
        x, y, w, h = roi
        return frame[y:y+h, x:x+w]

    @staticmethod
    def _to_gray(frame):
        """Convert a frame to a single-channel grayscale array.

        PIL images are RGB-ordered (the env builds them with BGR2RGB), so they
        need RGB2GRAY. Using BGR2GRAY would swap the red/blue luma weights
        relative to the imread(GRAYSCALE) templates. NumPy frames are assumed
        to be OpenCV BGR, or already grayscale if they are 2-D.
        """
        if isinstance(frame, Image.Image):
            return cv2.cvtColor(np.asarray(frame.convert("RGB")), cv2.COLOR_RGB2GRAY)
        frame = np.asarray(frame)
        if frame.ndim == 2:
            return frame
        return cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

    def _best_score(self, gray_frame, roi, template):
        """Return the best TM_CCOEFF_NORMED score of ``template`` inside ``roi``.

        Returns -1.0 (the minimum possible score) when the cropped region is
        smaller than the template, e.g. an ROI extending past the frame edge.
        matchTemplate would raise an error in that case.
        """
        region = self._crop_to_roi(gray_frame, roi)
        t_h, t_w = template.shape[:2]
        if region.shape[0] < t_h or region.shape[1] < t_w:
            return -1.0
        return float(cv2.matchTemplate(region, template, cv2.TM_CCOEFF_NORMED).max())

    def detect_event(self, frame):
        """
        Detect if the player has died or completed the level.

        Args:
            frame: PIL image (RGB) or NumPy array (BGR or grayscale).

        Returns:
            The string "death", "completion", or None.
        """
        gray_frame = self._to_gray(frame)

        # A strong completion match wins outright; the retry button can also
        # appear after completion.
        completion_score = self._best_score(gray_frame, self.completion_roi, self.completion_template)
        if completion_score >= self.strong_completion_threshold:
            return "completion"

        # Death is detected by the retry button.
        retry_score = self._best_score(gray_frame, self.retry_roi, self.retry_template)
        if retry_score >= self.retry_threshold:
            return "death"

        # Weaker completion match (reuses the score computed above).
        if completion_score >= self.completion_threshold:
            return "completion"

        return None

In [6]:
class InputController:
    """
    Handles inputs to Geometry Dash using PyWin32
    """
    def __init__(self, window_name="Geometry Dash"):
        self.window_name = window_name
        self.hwnd = self.get_gd_hwnd()

    def get_gd_hwnd(self):
        hwnd = win32gui.FindWindow(None, self.window_name)
        if hwnd == 0:
            raise RuntimeError("Geometry Dash window not found. Make sure the game is running.")
        # print(f"Found Geometry Dash window: {hwnd}")
        return hwnd

    def send_keydown(self, key = None):
        """
        Send a KeyDown (Spacebar pressed) message to Geometry Dash.
        """
        if self.hwnd:
            if key:
                win32api.PostMessage(self.hwnd, win32con.WM_KEYDOWN, ord(key), 0)
            else:
                win32api.PostMessage(self.hwnd, win32con.WM_KEYDOWN, win32con.VK_SPACE, 0)

    def send_keyup(self, key = None):
        """
        Send a KeyUp (Spacebar released) message to Geometry Dash.
        """
        if self.hwnd:
            if key:
                win32api.PostMessage(self.hwnd, win32con.WM_KEYUP, ord(key), 0xC0000000)
            else:
                win32api.PostMessage(self.hwnd, win32con.WM_KEYUP, win32con.VK_SPACE, 0xC0000000)

    def reset_game(self):
        """
        Simulate resetting the Geometry Dash game (press 'R').
        """
        time.sleep(0.5)
        if self.hwnd:
            # Send 'r' key press to reset
            self.send_keydown('r')
            time.sleep(0.05)  # Short delay
            self.send_keyup('r')
            # print("Reset command sent to Geometry Dash!")
            self.send_keydown()
            time.sleep(0.05)  # Short delay
            self.send_keyup()

In [None]:
class GeometryDashEnv(gymnasium.Env):
    """Gymnasium environment that plays Geometry Dash from a video feed.

    Observations are the stacked ResNet embeddings of recent frames captured
    from the OBS virtual camera. Actions release (0) or hold (1) the spacebar.
    An episode ends when the event detector reports "death" or "completion".
    """

    # Reward shaping.
    DEATH_REWARD = -100       # penalty for dying
    COMPLETION_REWARD = 500   # reward for completing the level
    # Per-step survival reward. The original comment read "training on 0.2";
    # the per-step evaluation output in this notebook shows 0.2, i.e. it was
    # produced with that older value.
    SURVIVAL_REWARD = 2e-2

    def __init__(self, frame_stacker, resnet_extractor, input_controller, event_detector,
                 video_idx=VIDEO_IDX, max_read_failures=100):
        """
        Args:
            frame_stacker: builds the stacked observation from per-frame embeddings.
            resnet_extractor: maps a captured frame to an embedding.
            input_controller: sends key presses to the game.
            event_detector: recognises death / completion screens.
            video_idx: cv2.VideoCapture device index (OBS virtual camera).
            max_read_failures: consecutive failed frame reads tolerated before
                ``_capture_frame`` raises (at least 1).

        Raises:
            RuntimeError: if the capture device cannot be opened.
        """
        super().__init__()

        # Initialize components
        self.frame_stacker: FrameStacker = frame_stacker
        self.resnet_extractor: ResNetEmbeddings = resnet_extractor
        self.input_controller: InputController = input_controller
        self.event_detector: EventDetectorWithROI = event_detector
        self.max_read_failures = max(1, int(max_read_failures))

        # Open the camera (OBS virtual camera)
        self.cap = cv2.VideoCapture(video_idx)
        if not self.cap.isOpened():
            raise RuntimeError(f"Failed to open video capture device at index {video_idx}.")

        # The observation length follows the stacker's configuration instead of a hard-coded 2048.
        self.observation_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(frame_stacker.embedding_dim * frame_stacker.stack_size,),
            dtype=np.float32
        )
        self.action_space = spaces.Discrete(2)  # 0: KeyUp, 1: KeyDown

        # Environment state
        self.current_state = None
        self.done = True
        self.previous_action = None

        print("Environment initialized!")

    def reset(self, *, seed=None, options=None):
        """
        Reset the game and return ``(initial_observation, info)``.
        """
        # Seed self.np_random the way Gymnasium expects. Also seed the action
        # space, so that action_space.sample() is reproducible.
        super().reset(seed=seed)
        if seed is not None:
            self.action_space.seed(seed)

        # Reset the game
        self.input_controller.reset_game()

        # Capture the first frame and process
        frame = self._capture_frame()
        initial_embedding = self.resnet_extractor.get_img_embedding(frame)

        # Initialize the frame stack
        self.current_state = self.frame_stacker.reset_stack(initial_embedding)
        self.done = False
        self.previous_action = None

        return self.current_state, {}

    def step(self, action):
        """
        Apply ``action`` and return ``(observation, reward, terminated, truncated, info)``.
        """
        # Normalise numpy scalars / 0-d arrays (e.g. from SB3) to a plain int.
        action = int(action)

        # Only send a key event when the action changes (hold vs. release).
        if action != self.previous_action:
            if action == 0:
                self.input_controller.send_keyup()
            else:
                self.input_controller.send_keydown()
            self.previous_action = action

        # Capture the next frame and process
        frame = self._capture_frame()
        next_embedding = self.resnet_extractor.get_img_embedding(frame)
        self.current_state = self.frame_stacker.add_frame(next_embedding)

        # Detect events (death or completion)
        event = self.event_detector.detect_event(frame)
        self.done = event in ("death", "completion")

        reward = self._calculate_reward(event)

        return self.current_state, reward, self.done, False, {}

    def _capture_frame(self):
        """
        Read the next camera frame and return it as an RGB PIL image.

        Failed reads are retried with a 10 ms pause, up to
        ``self.max_read_failures`` consecutive attempts.

        Raises:
            RuntimeError: if every attempt fails, e.g. the device was released
                by ``close()`` or disconnected.
        """
        for _ in range(self.max_read_failures):
            ok, frame = self.cap.read()
            if ok:
                # OpenCV delivers BGR; PIL / torchvision expect RGB.
                return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            time.sleep(0.01)
        raise RuntimeError(
            f"Failed to read a frame from the video capture device "
            f"after {self.max_read_failures} attempts."
        )

    def _calculate_reward(self, event):
        """
        Return the reward for ``event`` ("death", "completion" or None).
        """
        if event == "death":
            return self.DEATH_REWARD
        if event == "completion":
            return self.COMPLETION_REWARD
        return self.SURVIVAL_REWARD

    def close(self):
        """
        Release the camera and perform cleanup.
        """
        if self.cap.isOpened():
            self.cap.release()
        cv2.destroyAllWindows()
        super().close()

In [8]:
# Build the pipeline components the environment is assembled from.
frame_stacker = FrameStacker(embedding_dim=2048, stack_size=4)
resnet_extractor = ResNetEmbeddings()
input_controller = InputController(window_name="Geometry Dash")
event_detector = EventDetectorWithROI(
    completion_template_path="./template_img/complete_template.png",
    retry_template_path="./template_img/retry_template.png",
    completion_roi=(120, 110, 400, 40),
    retry_roi=(160, 375, 55, 50),
)

(40, 400)
(50, 55)


In [9]:
# Assemble the Gymnasium environment from the components above.
env = GeometryDashEnv(
    frame_stacker=frame_stacker,
    resnet_extractor=resnet_extractor,
    input_controller=input_controller,
    event_detector=event_detector,
)

Environment initialized!


In [None]:
# Smoke-test the environment with random actions.
# NOTE: do not call env.close() here. It releases the capture device, and this
# same `env` object is used for DQN training below.
N_SMOKE_TEST_STEPS = 100

obs, info = env.reset()
for _ in range(N_SMOKE_TEST_STEPS):
    action = env.action_space.sample()
    obs, reward, done, truncated, info = env.step(action)
    print(f"Action: {action}, Reward: {reward}, Done: {done}, Truncated: {truncated}")
    if done or truncated:
        obs, info = env.reset()

In [10]:
# DQN agent whose Q-network is an MLP over the flat stacked-embedding observation.
model = DQN(
    "MlpPolicy",
    env,
    buffer_size=10_000,  # replay-buffer capacity (transitions)
    batch_size=16,       # minibatch size per gradient step
    tensorboard_log="./dqn_gd_tensorboard",
    verbose=1,
)

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [11]:
# Train, and always save what has been learned, even if training is interrupted
# (e.g. KeyboardInterrupt). At the ~29 fps in the log, 1M steps takes hours.
# Note: this overwrites any previously saved model at the same path.
try:
    model.learn(total_timesteps=1_000_000)
finally:
    model.save('dqn_geometry_dash')

Logging to ./dqn_gd_tensorboard\DQN_1
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.19e+03 |
|    ep_rew_mean      | 137      |
|    exploration_rate | 0.955    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 29       |
|    time_elapsed     | 161      |
|    total_timesteps  | 4744     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 2.22e-05 |
|    n_updates        | 1160     |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 2.25e+03 |
|    ep_rew_mean      | 350      |
|    exploration_rate | 0.829    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 29       |
|    time_elapsed     | 605      |
|    total_timesteps  | 17998    |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss        

In [12]:
# Reload the saved agent and attach it to the live environment.
model_load = DQN.load(path="dqn_geometry_dash", env=env)

Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [13]:
# Evaluate the agent loaded from disk (model_load, not the in-memory `model`),
# acting greedily (deterministic=True).
N_EVAL_STEPS = 1000

obs, info = env.reset()
for _ in range(N_EVAL_STEPS):
    action, _states = model_load.predict(obs, deterministic=True)
    obs, reward, done, truncated, info = env.step(action)
    print(f"Action: {action}, Reward: {reward}, Done: {done}")

    # Start a new attempt when the episode ends.
    if done or truncated:
        obs, info = env.reset()

Action: 0, Reward: 0.2, Done: False
Action: 1, Reward: 0.2, Done: False
Action: 1, Reward: 0.2, Done: False
Action: 1, Reward: 0.2, Done: False
Action: 1, Reward: 0.2, Done: False
Action: 0, Reward: 0.2, Done: False
Action: 0, Reward: 0.2, Done: False
Action: 0, Reward: 0.2, Done: False
Action: 0, Reward: 0.2, Done: False
Action: 1, Reward: 0.2, Done: False
Action: 1, Reward: 0.2, Done: False
Action: 0, Reward: 0.2, Done: False
Action: 0, Reward: 0.2, Done: False
Action: 0, Reward: 0.2, Done: False
Action: 0, Reward: 0.2, Done: False
Action: 0, Reward: 0.2, Done: False
Action: 0, Reward: 0.2, Done: False
Action: 0, Reward: 0.2, Done: False
Action: 0, Reward: 0.2, Done: False
Action: 0, Reward: 0.2, Done: False
Action: 0, Reward: 0.2, Done: False
Action: 0, Reward: 0.2, Done: False
Action: 0, Reward: 0.2, Done: False
Action: 0, Reward: 0.2, Done: False
Action: 0, Reward: 0.2, Done: False
Action: 0, Reward: 0.2, Done: False
Action: 0, Reward: 0.2, Done: False
Action: 0, Reward: 0.2, Done