In [4]:
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.ppo import PPO
import torch as th
import torch.nn as nn
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from gymnasium.wrappers import ResizeObservation
from vizdoom import gymnasium_wrapper

# Import model
from models.basic_vit import BasicViT

In [14]:
# Custom Features Extractor using BasicViT
class ViTFeatureExtractor(BaseFeaturesExtractor):
    """Features extractor that encodes the ``'screen'`` image with a ``BasicViT``.

    The observation space is expected to be dict-like with a ``'screen'``
    entry holding a 3-channel image, either channels-first ``(C, H, W)`` or
    channels-last ``(H, W, C)``.

    Args:
        observation_space: Space indexable by ``'screen'``; the entry's
            ``shape`` determines the image size given to the ViT.
        features_dim: Size of the feature vector. It is passed to ``BasicViT``
            as both ``num_classes`` and ``embed_dim``.
    """

    def __init__(self, observation_space, features_dim=512):
        super().__init__(observation_space, features_dim)

        # Print raw observation space for debugging
        print(f"Observation space: {observation_space}")

        screen_space = observation_space['screen']
        shape = screen_space.shape

        print(f"Shape from observation space: {shape}")

        # VecTransposeImage wrapper changes format to channels-first (C, H, W)
        if len(shape) == 3 and shape[0] == 3:
            # Channels first format (C, H, W)
            c, h, w = shape
        else:
            # Assume channels last format (H, W, C)
            h, w, c = shape

        print(f"Extracted dimensions: h={h}, w={w}, c={c}")

        # Calculate patches (only used for the diagnostic message below)
        patch_size = 8
        n_h = h // patch_size
        n_w = w // patch_size
        n_patches = n_h * n_w

        print(f"Will create {n_patches} patches ({n_h}x{n_w}) with patch_size={patch_size}")

        # Create the ViT with fixed in_channels=3 for RGB
        # NOTE(review): pad_if_needed presumably pads images whose size is not
        # a multiple of patch_size -- confirm against models.basic_vit.
        self.vit = BasicViT(
            img_size=(h, w),
            patch_size=patch_size,
            in_channels=3,  # IMPORTANT: Force to 3 for RGB images
            num_classes=features_dim,
            embed_dim=features_dim,
            depth=4,
            num_heads=4,
            mlp_ratio=2.0,
            pad_if_needed=True
        )

    def forward(self, observations):
        """Encode a batch of screen observations into feature vectors.

        Args:
            observations: Either a dict containing a ``'screen'`` tensor or
                the screen tensor itself, shaped ``(B, C, H, W)`` or
                ``(B, H, W, C)``.

        Returns:
            The output of ``self.vit`` for the batch, fed in ``(B, C, H, W)``
            layout.
        """
        # Handle dict observations
        if isinstance(observations, dict):
            obs = observations['screen']
        else:
            obs = observations

        # Normalize exactly once.
        # Stable-Baselines3 policies preprocess observations before calling the
        # features extractor: with the default normalize_images=True, uint8
        # image spaces arrive here already as float tensors in [0, 1].
        # Dividing by 255 again would squash the input into [0, 1/255], so
        # only raw integer frames (e.g. when this module is called directly)
        # are scaled. Float inputs are trusted to be in the intended range.
        if obs.is_floating_point():
            normalized_obs = obs.float()
        else:
            normalized_obs = obs.float() / 255.0

        # Ensure BCHW format for PyTorch
        if normalized_obs.dim() == 4:
            # If batch dimension exists
            if normalized_obs.shape[1] != 3 and normalized_obs.shape[3] == 3:
                # Convert BHWC to BCHW
                normalized_obs = normalized_obs.permute(0, 3, 1, 2)

        return self.vit(normalized_obs)

# Custom Policy
class CustomViTPolicy(ActorCriticPolicy):
    """Actor-critic policy that uses ``ViTFeatureExtractor`` by default.

    All positional and keyword arguments are forwarded to
    ``ActorCriticPolicy``. ``features_extractor_class`` and
    ``features_extractor_kwargs`` default to the ViT extractor with
    ``features_dim=512`` but may be overridden by the caller.
    """

    def __init__(self, *args, **kwargs):
        # Use setdefault instead of passing these keywords explicitly next to
        # **kwargs: if the caller already supplies them (e.g. through
        # policy_kwargs, or when a saved policy is rebuilt from its stored
        # constructor parameters) an explicit duplicate raises
        # "TypeError: got multiple values for keyword argument".
        kwargs.setdefault("features_extractor_class", ViTFeatureExtractor)
        kwargs.setdefault("features_extractor_kwargs", dict(features_dim=512))
        super().__init__(*args, **kwargs)

# Environment Setup
# Vectorized environment: four copies of the basic VizDoom scenario
env = make_vec_env(env_id="VizdoomBasic-v0", n_envs=4)

# Screen geometry and number of discrete actions
# (the shape is unpacked as H, W, C -- presumably channels-last at this point)
obs_space = env.observation_space['screen']
img_height, img_width, channels = obs_space.shape
act_space = env.action_space.n

# Build the PPO agent with the ViT-based policy, train it, then persist it
model = PPO(policy=CustomViTPolicy, env=env, verbose=1)
model.learn(total_timesteps=100_000)
model.save(path="ppo_vit_vizdoom")

Using cpu device
Wrapping the env in a VecTransposeImage.
Observation space: Dict('gamevariables': Box(-3.4028235e+38, 3.4028235e+38, (1,), float32), 'screen': Box(0, 255, (3, 240, 320), uint8))
Shape from observation space: (3, 240, 320)
Extracted dimensions: h=240, w=320, c=3
Will create 1200 patches (30x40) with patch_size=8
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 190      |
|    ep_rew_mean     | -177     |
| time/              |          |
|    fps             | 13       |
|    iterations      | 1        |
|    time_elapsed    | 613      |
|    total_timesteps | 8192     |
---------------------------------


KeyboardInterrupt: 