In [None]:
import gymnasium as gym
import torch
import torch.nn as nn
from minigrid.wrappers import ImgObsWrapper
from minigrid.core.constants import COLOR_NAMES

from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.vec_env import DummyVecEnv
import numpy as np

# 1. Custom Logic Wrapper because we use 7x7
class GoToRedBallLogicWrapper(gym.Wrapper):
    """
    Applies penalties based on symbolic rules for the GoToRedBall task.

    - Penalizes attempting to pick up a ball whose color is not red.
    - Penalizes calling 'done' when the agent is not facing the red ball
      (a ball cell can never be occupied by the agent, so "at the red ball"
      means the red ball is the cell directly in front of the agent).

    After every step, ``info['logic_violation']`` is set to the name of the
    violated rule (``'pickup_wrong_ball'`` / ``'premature_done'``) or ``None``.

    Args:
        env: Environment whose ``unwrapped`` exposes ``actions``, ``grid``,
            ``front_pos`` and ``agent_pos`` (MiniGrid-style).
        violation_penalty: Value added to the reward on each violation.
        verbose: When True (default), print the reset banners / locate
            messages. Warnings are always printed.
    """

    def __init__(self, env: gym.Env, violation_penalty: float = -0.5, verbose: bool = True):
        super().__init__(env)
        self.penalty = violation_penalty
        self.verbose = verbose
        self.red_color_index = COLOR_NAMES.index('red')
        # Grid coordinates (x, y) of the red ball, located on every reset.
        self.red_ball_pos = None
        assert hasattr(self.env.unwrapped, 'actions'), "Environment must have an 'actions' attribute"

    @staticmethod
    def _color_index(cell) -> int:
        """Return the COLOR_NAMES index of ``cell.color``, or -1 if unknown.

        Accepts colors stored either as names (MiniGrid stores strings such
        as 'red') or as integer indices.
        """
        color = getattr(cell, 'color', None)
        if isinstance(color, str):
            try:
                return COLOR_NAMES.index(color)
            except ValueError:
                return -1
        if isinstance(color, int) and 0 <= color < len(COLOR_NAMES):
            return color
        return -1

    @staticmethod
    def _is_ball(cell) -> bool:
        """Return True if ``cell`` is a grid object of type 'ball'."""
        return cell is not None and getattr(cell, 'type', None) == 'ball'

    def _log(self, message: str) -> None:
        """Print ``message`` only when verbose logging is enabled."""
        if self.verbose:
            print(message)

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        grid = self.env.unwrapped.grid
        self.red_ball_pos = None
        ball_found_debug = False

        self._log(f"\n--- RESETTING ENV --- Grid Size: {grid.width}x{grid.height}")

        # Scan column by column (x outer, y inner); stop at the first red ball.
        for i in range(grid.width):
            for j in range(grid.height):
                cell = grid.get(i, j)
                if not self._is_ball(cell):
                    continue
                ball_found_debug = True
                if self._color_index(cell) == self.red_color_index:
                    self.red_ball_pos = (i, j)
                    break
            if self.red_ball_pos is not None:
                break

        if self.red_ball_pos is None:
            print(f"Warning: Red ball not found during reset! (Any ball objects found: {ball_found_debug})")
        else:
            self._log(f"FOUND: Red ball successfully located at {self.red_ball_pos}")
        self._log("--- RESET COMPLETE ---")

        return obs, info

    def step(self, action):
        base_env = self.env.unwrapped

        # Inspect the cell in front of the agent BEFORE stepping: a successful
        # pickup removes the object from the grid, so a post-step lookup would
        # never see the ball. The rule penalizes the attempt itself.
        cell_in_front = None
        if action == base_env.actions.pickup:
            fx, fy = base_env.front_pos
            cell_in_front = base_env.grid.get(fx, fy)

        obs, reward, terminated, truncated, info = self.env.step(action)
        info['logic_violation'] = None

        if (
            action == base_env.actions.pickup
            and self._is_ball(cell_in_front)
            and self._color_index(cell_in_front) != self.red_color_index
        ):
            reward += self.penalty
            info['logic_violation'] = 'pickup_wrong_ball'

        if action == base_env.actions.done:
            if self.red_ball_pos is None:
                print("Warning: Checking done action but red_ball_pos is None.")
            else:
                # The agent cannot stand on a ball, so compare the faced cell.
                facing = tuple(int(v) for v in base_env.front_pos)
                if facing != self.red_ball_pos:
                    reward += self.penalty
                    info['logic_violation'] = 'premature_done'
        return obs, reward, terminated, truncated, info

# 2. Custom CNN Feature Extractor
class MiniGridCNN(BaseFeaturesExtractor):
    """
    Small convolutional feature extractor for MiniGrid-style observations.

    Expects a channel-first ``(C, H, W)`` Box space. Three 3x3 convolutions
    (16 -> 32 -> 64 channels, stride 1, padding 1, each followed by ReLU)
    are flattened and projected by a Linear + ReLU head to ``features_dim``.
    """

    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 64):
        super().__init__(observation_space, features_dim)
        space_shape = observation_space.shape
        channels = space_shape[0]
        height = space_shape[1]
        width = space_shape[2]
        print(f"Initializing MiniGridCNN with input shape: ({channels}, {height}, {width})")

        # Build conv -> ReLU pairs; module order matches indices 0..6 of a
        # hand-written Sequential, so state_dict keys are unchanged.
        conv_layers = []
        prev_channels = channels
        for next_channels in (16, 32, 64):
            conv_layers.append(nn.Conv2d(prev_channels, next_channels, kernel_size=3, stride=1, padding=1))
            conv_layers.append(nn.ReLU())
            prev_channels = next_channels
        conv_layers.append(nn.Flatten())
        self.cnn = nn.Sequential(*conv_layers)

        # Probe the conv stack with one sampled observation to size the head.
        with torch.no_grad():
            probe = torch.as_tensor(observation_space.sample()[None]).float()
            flat_size = self.cnn(probe).shape[1]
            print(f"Flattened CNN output size: {flat_size}")

        self.linear = nn.Sequential(nn.Linear(flat_size, features_dim), nn.ReLU())

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        features = self.cnn(observations.float())
        return self.linear(features)

def make_env(env_id='BabyAI-GoToRedBallGrey-v0', violation_penalty=-0.5, seed=0):
    """
    Return a zero-argument factory that creates and wraps the environment.

    The wrapped env is ``gym.make(env_id)`` -> ``ImgObsWrapper`` ->
    ``GoToRedBallLogicWrapper``.

    Args:
        env_id: Gymnasium environment id.
        violation_penalty: Penalty forwarded to the logic wrapper.
        seed: If not None, the environment RNG is seeded via
            ``env.reset(seed=seed)`` and the action space via
            ``action_space.seed(seed)`` for reproducibility. Pass None to
            leave seeding to the caller (e.g. ``make_vec_env(seed=...)``).
    """
    def _init():
        env = gym.make(env_id, render_mode='rgb_array')
        env = ImgObsWrapper(env)
        env = GoToRedBallLogicWrapper(env, violation_penalty=violation_penalty)
        if seed is not None:
            # Gymnasium keeps the RNG seeded here across later unseeded
            # resets, so this fixes the sequence of generated levels.
            env.reset(seed=seed)
            env.action_space.seed(seed)
        return env
    return _init

In [None]:
from stable_baselines3 import PPO
# 4. Training for PPO
if __name__ == "__main__":
    # Imported locally: this cell runs before the DQN cell that imports them.
    from stable_baselines3.common.env_util import make_vec_env
    from stable_baselines3.common.vec_env import VecTransposeImage

    ENV_ID = 'BabyAI-GoToRedBallGrey-v0'
    NUM_ENVS = 4
    LOG_DIR = "./logs/gotorball_logic_custom_cnn/"
    MODEL_SAVE_PATH = "ppo_goto_redball_logic_custom_cnn"
    TOTAL_TIMESTEPS = 1_000_000
    FEATURES_DIM = 64
    VIOLATION_PENALTY = -0.5
    SEED = 0

    # make_vec_env wraps every sub-env in a Monitor (without it SB3 cannot log
    # episode reward/length) and seeds sub-env i with SEED + i.
    vec_env = make_vec_env(
        make_env(ENV_ID, violation_penalty=VIOLATION_PENALTY, seed=None),
        n_envs=NUM_ENVS,
        seed=SEED,
        vec_env_cls=DummyVecEnv,
    )
    # ImgObsWrapper yields HWC images; the CNN expects CHW.
    vec_env = VecTransposeImage(vec_env)

    print(f"Observation Space (after VecTransposeImage): {vec_env.observation_space.shape}")
    print(f"Action Space: {vec_env.action_space}")

    policy_kwargs = dict(
        features_extractor_class=MiniGridCNN,
        features_extractor_kwargs=dict(features_dim=FEATURES_DIM),
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    model = PPO(
        policy='CnnPolicy',
        env=vec_env,
        policy_kwargs=policy_kwargs,
        verbose=1,
        tensorboard_log=LOG_DIR,
        device=device,
        seed=SEED,
        n_steps=512,
        batch_size=64 * NUM_ENVS,
        n_epochs=4,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        ent_coef=0.02,
        vf_coef=0.5,
        learning_rate=1e-3,
    )

    print(f"\nStarting training on {ENV_ID} for {TOTAL_TIMESTEPS} timesteps...")
    print(f"TensorBoard logs will be saved to: {LOG_DIR}")
    print(f"Model will be saved to: {MODEL_SAVE_PATH}.zip\n")

    try:
        model.learn(
            total_timesteps=TOTAL_TIMESTEPS,
            log_interval=10,
        )
        model.save(MODEL_SAVE_PATH)
        print(f"\nTraining finished. Model saved to {MODEL_SAVE_PATH}.zip")
    finally:
        # Always release the environments, even if training is interrupted.
        vec_env.close()
        print("Environment closed.")

In [None]:
from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv, VecTransposeImage
# 4. Training for DQN
if __name__ == "__main__":
    ENV_ID = 'BabyAI-GoToRedBallGrey-v0'
    NUM_ENVS = 1
    LOG_DIR = "./logs/gotorball_logic_custom_cnn_dqn/"
    MODEL_SAVE_PATH = "dqn_goto_redball_logic_custom_cnn"
    TOTAL_TIMESTEPS = 1_000_000
    FEATURES_DIM = 64
    VIOLATION_PENALTY = -0.5
    SEED = 0

    # make_env(...) already returns a zero-argument factory, so it can be
    # passed directly. make_vec_env adds a Monitor and seeds sub-env i with
    # SEED + i (make_env's own seeding is disabled to avoid double seeding).
    vec_env = make_vec_env(
        make_env(ENV_ID, violation_penalty=VIOLATION_PENALTY, seed=None),
        n_envs=NUM_ENVS,
        seed=SEED,
        vec_env_cls=DummyVecEnv,
    )
    # ImgObsWrapper yields HWC images; the CNN expects CHW.
    vec_env = VecTransposeImage(vec_env)

    print(f"Observation Space: {vec_env.observation_space.shape}")
    print(f"Action Space: {vec_env.action_space}")

    policy_kwargs = dict(
        features_extractor_class=MiniGridCNN,
        features_extractor_kwargs=dict(features_dim=FEATURES_DIM),
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    model = DQN(
        policy='CnnPolicy',
        env=vec_env,
        policy_kwargs=policy_kwargs,
        verbose=1,
        tensorboard_log=LOG_DIR,
        device=device,
        seed=SEED,
        buffer_size=100_000,
        learning_rate=1e-4,
        learning_starts=5000,
        batch_size=32,
        tau=1.0,
        gamma=0.99,
        train_freq=(4, "step"),
        gradient_steps=1,
        target_update_interval=1000,
        exploration_fraction=0.1,
        exploration_initial_eps=1.0,
        exploration_final_eps=0.05,
        optimize_memory_usage=False,
    )

    print(f"\nStarting DQN training on {ENV_ID} for {TOTAL_TIMESTEPS} timesteps...")
    print(f"TensorBoard logs will be saved to: {LOG_DIR}")
    print(f"Model will be saved to: {MODEL_SAVE_PATH}.zip\n")

    try:
        model.learn(
            total_timesteps=TOTAL_TIMESTEPS,
            log_interval=4,
        )
        model.save(MODEL_SAVE_PATH)
        print(f"\nTraining finished. Model saved to {MODEL_SAVE_PATH}.zip")
    finally:
        # Always release the environment, even if training is interrupted.
        vec_env.close()
        print("Environment closed.")