In [2]:
"""
ViZDoom + Stable-Baselines3 PPO
- Train headless (no on-screen rendering)
- After training, record an evaluation rollout
- Save MP4 and GIF, each <= 100 MB by lowering FPS (frame dropping)

Requirements (install if missing):
  pip install vizdoom stable-baselines3 gymnasium imageio imageio-ffmpeg
"""

import os, io, sys, time, math, importlib, subprocess
import numpy as np

# ---------- Optional: ensure key video deps ----------
def ensure(spec, import_name=None):
    """Import a module, pip-installing its distribution first if it is missing.

    Args:
        spec: Requirement string handed to ``pip install`` (e.g. ``"imageio"``
            or ``"pkg[extra]"``).
        import_name: Importable module name when it differs from the
            distribution name. Defaults to ``spec`` with any ``[extras]``
            suffix removed.

    Raises:
        subprocess.CalledProcessError: If ``pip install`` exits non-zero.
    """
    module_name = import_name or spec.split("[")[0]
    try:
        importlib.import_module(module_name)
    except ImportError:
        # Use the running interpreter's pip so the package lands in the same
        # environment this process imports from.
        subprocess.check_call([sys.executable, "-m", "pip", "install", spec])
        # Finders cache directory listings; without this a package installed
        # after interpreter start-up may not be importable in this process.
        importlib.invalidate_caches()

# Only the lightweight video dependencies are auto-installed here; the other
# packages named in the module docstring (vizdoom, stable-baselines3,
# gymnasium) are imported below without any install fallback.
ensure("imageio")
# The pip distribution name and the importable module name differ for this one.
ensure("imageio-ffmpeg", "imageio_ffmpeg")

import imageio.v2 as imageio
import vizdoom as vzd
import gymnasium as gym
from gymnasium import spaces
from stable_baselines3 import PPO

# Force headless SDL when no window is desired (harmless if window_visible=True).
# setdefault() only fills in the variable when it is not already present, so a
# driver explicitly chosen in the caller's environment is left untouched.
os.environ.setdefault("SDL_VIDEODRIVER", "dummy")
os.environ.setdefault("SDL_AUDIODRIVER", "dummy")


# =========================
# Doom Environment (Gymnasium)
# =========================
def initialize_doom(window_visible: bool = False):
    """Build, configure and start a ViZDoom game on the bundled "basic" scenario.

    Args:
        window_visible: Whether ViZDoom should open an on-screen window.

    Returns:
        An initialised ``vzd.DoomGame`` (``init()`` has already been called).
    """
    doom = vzd.DoomGame()

    # Scenario and map
    doom.set_doom_scenario_path(vzd.scenarios_path + "/basic.wad")
    doom.set_doom_map("map01")

    # Rendering: must stay identical between training and evaluation because
    # the observation shape is derived from it.
    doom.set_screen_resolution(vzd.ScreenResolution.RES_640X480)
    doom.set_screen_format(vzd.ScreenFormat.RGB24)

    # Buttons, in the order that defines the discrete action indices
    enabled_buttons = (
        vzd.Button.MOVE_LEFT,
        vzd.Button.MOVE_RIGHT,
        vzd.Button.ATTACK,
    )
    for button in enabled_buttons:
        doom.add_available_button(button)

    # Extra game variable exposed alongside the screen buffer
    doom.add_available_game_variable(vzd.GameVariable.AMMO2)

    # Difficulty level
    doom.set_doom_skill(3)

    # Presentation: window on request only, never any sound, agent-driven mode
    doom.set_window_visible(window_visible)
    doom.set_sound_enabled(False)
    doom.set_mode(vzd.Mode.PLAYER)

    # Episode bounds, expressed in game tics
    doom.set_episode_timeout(3000)
    doom.set_episode_start_time(10)

    doom.init()
    return doom


class DoomEnv(gym.Env):
    """Minimal ViZDoom -> Gymnasium wrapper for SB3 CnnPolicy (channels-first).

    Observations are ``uint8`` arrays of shape ``(3, H, W)`` taken from the
    game's screen buffer. Actions are a ``Discrete`` index that presses
    exactly one of the game's available buttons for one ``make_action`` call.
    """
    metadata = {"render_modes": []}

    def __init__(self, window_visible: bool = False):
        """Start a ViZDoom game and derive the spaces from its configuration.

        Args:
            window_visible: Forwarded to ``initialize_doom``.
        """
        super().__init__()
        # Define the attribute first so close() is safe even if start-up fails.
        self.game = None
        self.game = initialize_doom(window_visible=window_visible)
        self.h = self.game.get_screen_height()
        self.w = self.game.get_screen_width()
        # SB3 CnnPolicy expects channels-first (C,H,W) uint8
        self.observation_space = spaces.Box(low=0, high=255, shape=(3, self.h, self.w), dtype=np.uint8)
        self.action_space = spaces.Discrete(len(self.game.get_available_buttons()))

    def _blank_obs(self):
        """Return an all-zero observation matching ``observation_space``."""
        return np.zeros((3, self.h, self.w), dtype=np.uint8)

    def _obs_from_game(self):
        """Return the current screen as (C,H,W), or zeros if no state exists."""
        st = self.game.get_state()
        if st is None or st.screen_buffer is None:
            return self._blank_obs()
        # ViZDoom gives HxWx3 RGB; convert to CxHxW (a transposed view, no copy)
        return np.transpose(st.screen_buffer, (2, 0, 1))

    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(observation, info)``.

        Args:
            seed: Optional seed. It seeds ``self.np_random`` through the
                Gymnasium base class and is also forwarded to ViZDoom so the
                game itself becomes reproducible.
            options: Accepted for API compatibility; unused.
        """
        # Let Gymnasium manage np_random instead of assigning it by hand.
        super().reset(seed=seed)
        if seed is not None:
            # Assumes ViZDoom takes an unsigned 32-bit seed — TODO confirm;
            # masking keeps larger Python ints inside that range.
            self.game.set_seed(int(seed) & 0xFFFFFFFF)
        self.game.new_episode()
        return self._obs_from_game(), {}

    def step(self, action):
        """Apply one discrete action and return the Gymnasium 5-tuple.

        Args:
            action: Index into the available buttons (anything ``int()``
                accepts, e.g. the 0-d array returned by ``model.predict``).

        Returns:
            ``(obs, reward, terminated, truncated, info)``. ``obs`` is all
            zeros on the step that finishes the episode.
        """
        # Convert discrete action index -> one-hot list for ViZDoom
        buttons = [0] * int(self.action_space.n)
        buttons[int(action)] = 1
        reward = float(self.game.make_action(buttons))
        terminated = bool(self.game.is_episode_finished())

        # No screen is fetched once the episode has finished, so hand back a
        # blank frame of the right shape instead.
        obs = self._blank_obs() if terminated else self._obs_from_game()

        # NOTE(review): an episode ended by the configured episode timeout is
        # reported as terminated, never truncated — confirm whether the
        # distinction matters for value bootstrapping.
        truncated = False
        return obs, reward, terminated, truncated, {}

    def close(self):
        """Shut down the ViZDoom game; safe to call more than once."""
        game = getattr(self, "game", None)
        if game is not None:
            game.close()
            self.game = None


# =========================
# Training (headless)
# =========================
# Number of environment steps passed to PPO.learn() in train_model().
TOTAL_TIMESTEPS = 10_000   # increase for better policy
# Verbosity level forwarded to the PPO constructor.
ALG_VERBOSE = 1

def train_model():
    """Train a PPO CnnPolicy on a headless DoomEnv and return the model.

    The environment is always closed, including when construction of the
    model or training itself raises, so the game is not left running.

    Returns:
        The trained ``PPO`` model.
    """
    print("Creating headless Doom training env...")
    env = DoomEnv(window_visible=False)
    try:
        print(f"Training PPO for {TOTAL_TIMESTEPS} timesteps (no on-screen rendering)...")
        model = PPO("CnnPolicy", env, verbose=ALG_VERBOSE)
        model.learn(total_timesteps=TOTAL_TIMESTEPS)
    finally:
        env.close()

    print("Training complete.")
    return model


# =========================
# Recording & Size Capping
# =========================
# Defaults for record_rollout_to_tmp_mp4() and the size-capped transcodes in main().
EVAL_SECONDS = 20           # duration of the recorded clip
ORIG_FPS    = 35            # visual FPS for recording
# Size cap per output file; file_mb() measures in units of 1024 * 1024 bytes.
MAX_MB      = 100
FINAL_MP4   = "doom_policy.mp4"
FINAL_GIF   = "doom_policy.gif"
# Removed again at the end of main().
TMP_MP4     = "_tmp_doom_full.mp4"  # temporary file for first pass

def file_mb(path):
    """Return the size of ``path`` in units of 1024 * 1024 bytes.

    Any ``OSError`` while querying the size (for example a missing file)
    yields ``0.0`` instead of propagating.
    """
    bytes_per_mb = 1024 * 1024
    try:
        size_bytes = os.path.getsize(path)
    except OSError:
        return 0.0
    else:
        return size_bytes / bytes_per_mb

def downsample_stride(orig_fps, target_fps):
    """Return the keep-every-Nth-frame stride for going from ``orig_fps`` to ``target_fps``.

    The target is clamped to at least 1 before dividing, the ratio is rounded
    up, and the result is never smaller than 1.
    """
    divisor = target_fps if target_fps > 1 else 1
    stride = math.ceil(orig_fps / divisor)
    return stride if stride > 1 else 1

def transcode_mp4_under_cap(src_path, dst_path, orig_fps, max_mb=100, min_fps=1, step_fps=5):
    """
    Re-encode MP4 by reducing FPS via frame dropping (stride) until size <= max_mb.
    Streaming approach to avoid loading entire video into RAM.

    Every ``stride``-th source frame is kept and the output is written at
    ``orig_fps / stride`` so playback speed matches the source. Target frame
    rates that resolve to an already-tried stride are skipped, since they
    would keep exactly the same frames.

    Args:
        src_path: Source video to read frames from.
        dst_path: Output path; overwritten on every attempt.
        orig_fps: Integer frame rate of the source (must be > 0).
        max_mb: Size cap, in the units returned by ``file_mb``.
        min_fps: Lowest target frame rate to try (values below 1 act as 1).
        step_fps: Decrement between targets (values below 1 act as 1).

    Returns:
        ``(ok, fps, kept, size_mb, tried)``: ``ok`` says whether the cap was
        met, ``fps``/``kept``/``size_mb`` describe the attempt left on disk,
        and ``tried`` lists every attempt as ``(fps, kept, size_mb)``.

    Raises:
        ValueError: If ``orig_fps`` is not positive.
    """
    if orig_fps <= 0:
        raise ValueError(f"orig_fps must be positive, got {orig_fps!r}")

    # Descending targets; always ends on the floor, and is never empty even
    # when orig_fps is already below min_fps.
    floor_fps = max(1, min_fps)
    targets = list(range(orig_fps, floor_fps - 1, -max(1, step_fps)))
    if not targets or targets[-1] != floor_fps:
        targets.append(floor_fps)

    tried = []
    seen_strides = set()
    for target in targets:
        stride = downsample_stride(orig_fps, target)
        if stride in seen_strides:
            continue
        seen_strides.add(stride)

        # Frame rate that actually matches the kept frames; stays an int when
        # the division is exact.
        out_fps = orig_fps // stride if orig_fps % stride == 0 else round(orig_fps / stride, 2)

        kept = 0
        reader = imageio.get_reader(src_path)
        try:
            writer = imageio.get_writer(dst_path, format="mp4", fps=out_fps)
            try:
                for i, frm in enumerate(reader):
                    if i % stride == 0:
                        writer.append_data(frm)
                        kept += 1
            finally:
                writer.close()
        finally:
            reader.close()

        sz = file_mb(dst_path)
        tried.append((out_fps, kept, sz))
        # file_mb() reports 0.0 for a missing file; that is not a success.
        if 0 < sz <= max_mb:
            return True, out_fps, kept, sz, tried

    # If none hit the target, keep the last attempt
    last_fps, last_kept, last_sz = tried[-1]
    return False, last_fps, last_kept, last_sz, tried

def transcode_gif_under_cap(src_path, dst_path, orig_fps, max_mb=100, min_fps=1, step_fps=5):
    """
    Read frames from MP4 and write a GIF at decreasing FPS (via stride) until <= max_mb.
    Uses per-frame duration metadata; streams to avoid high RAM.

    Every ``stride``-th source frame is kept and each kept frame is given a
    duration of ``stride / orig_fps`` seconds so playback speed matches the
    source. Target frame rates that resolve to an already-tried stride are
    skipped, since they would keep exactly the same frames.

    Args:
        src_path: Source video to read frames from.
        dst_path: Output path; overwritten on every attempt.
        orig_fps: Integer frame rate of the source (must be > 0).
        max_mb: Size cap, in the units returned by ``file_mb``.
        min_fps: Lowest target frame rate to try (values below 1 act as 1).
        step_fps: Decrement between targets (values below 1 act as 1).

    Returns:
        ``(ok, fps, kept, size_mb, tried)``: ``ok`` says whether the cap was
        met, ``fps``/``kept``/``size_mb`` describe the attempt left on disk,
        and ``tried`` lists every attempt as ``(fps, kept, size_mb)``.

    Raises:
        ValueError: If ``orig_fps`` is not positive.
    """
    if orig_fps <= 0:
        raise ValueError(f"orig_fps must be positive, got {orig_fps!r}")

    # Descending targets; always ends on the floor, and is never empty even
    # when orig_fps is already below min_fps.
    floor_fps = max(1, min_fps)
    targets = list(range(orig_fps, floor_fps - 1, -max(1, step_fps)))
    if not targets or targets[-1] != floor_fps:
        targets.append(floor_fps)

    tried = []
    seen_strides = set()
    for target in targets:
        stride = downsample_stride(orig_fps, target)
        if stride in seen_strides:
            continue
        seen_strides.add(stride)

        # Frame rate that actually matches the kept frames; stays an int when
        # the division is exact.
        out_fps = orig_fps // stride if orig_fps % stride == 0 else round(orig_fps / stride, 2)
        duration = stride / float(orig_fps)

        kept = 0
        reader = imageio.get_reader(src_path)
        try:
            writer = imageio.get_writer(dst_path, format="gif")  # streaming GIF writer
            try:
                for i, frm in enumerate(reader):
                    if i % stride == 0:
                        # NOTE(review): verify that the installed imageio GIF
                        # writer honours per-frame "duration" metadata and
                        # reads it as seconds; some releases may expect the
                        # duration as a writer argument (possibly in ms)
                        # instead — TODO confirm against the pinned version.
                        writer.append_data(frm, {"duration": duration})
                        kept += 1
            finally:
                writer.close()
        finally:
            reader.close()

        sz = file_mb(dst_path)
        tried.append((out_fps, kept, sz))
        # file_mb() reports 0.0 for a missing file; that is not a success.
        if 0 < sz <= max_mb:
            return True, out_fps, kept, sz, tried

    # If none hit the target, keep the last attempt
    last_fps, last_kept, last_sz = tried[-1]
    return False, last_fps, last_kept, last_sz, tried

def record_rollout_to_tmp_mp4(model, seconds=EVAL_SECONDS, fps=ORIG_FPS, tmp_path=TMP_MP4):
    """Roll out ``model`` in a fresh headless DoomEnv and save the frames as an MP4.

    Exactly ``round(seconds * fps)`` frames are written, so the clip lasts
    ``seconds`` at ``fps`` however fast inference runs on this machine.
    Episodes that end mid-recording are restarted, and the frame recorded is
    always the observation the policy acts on (never the blank placeholder
    returned on the final step of an episode).

    Args:
        model: Object with a ``predict(obs, deterministic=True)`` method.
        seconds: Length of the recorded clip.
        fps: Frame rate the clip is written at.
        tmp_path: Output MP4 path.
    """
    print(f"Evaluating policy for ~{seconds}s -> writing {tmp_path} at {fps} FPS...")
    total_frames = max(0, int(round(seconds * fps)))
    frames = 0

    env = DoomEnv(window_visible=False)
    try:
        obs, _ = env.reset()
        writer = imageio.get_writer(tmp_path, format="mp4", fps=fps)
        try:
            while frames < total_frames:
                # Convert (3,H,W) -> (H,W,3) for video
                frame = np.ascontiguousarray(np.transpose(obs, (1, 2, 0)))
                writer.append_data(frame)
                frames += 1

                action, _ = model.predict(obs, deterministic=True)
                obs, _, terminated, truncated, _ = env.step(action)
                if terminated or truncated:
                    obs, _ = env.reset()
        finally:
            # Finalise the file even when the rollout fails part-way.
            writer.close()
    finally:
        env.close()

    print(f"Wrote temp video: {tmp_path} ({file_mb(tmp_path):.2f} MB) | frames: {frames}")


def _print_transcode_report(label, path, result):
    """Print the per-attempt table and final verdict for one transcode result."""
    ok, fps, kept, size_mb, attempts = result
    print(f"{label} attempts (fps, frames_kept, size_MB):")
    for attempt_fps, attempt_kept, attempt_mb in attempts:
        print(f"  - {attempt_fps:>2} fps | {attempt_kept:>5} frames | {attempt_mb:6.2f} MB")
    if ok:
        print(f"Final {label}: {path} | {fps} fps | {kept} frames | {size_mb:.2f} MB (<= {MAX_MB} MB)")
    else:
        print(f"Warning: {label} still > {MAX_MB} MB; final size {size_mb:.2f} MB at {fps} fps.")


def main():
    """Train the policy, record a rollout, and export size-capped MP4 and GIF files.

    The temporary first-pass recording is removed even when recording or
    transcoding fails.
    """
    model = train_model()

    try:
        # 1) Record first-pass MP4 at ORIG_FPS
        record_rollout_to_tmp_mp4(model, seconds=EVAL_SECONDS, fps=ORIG_FPS, tmp_path=TMP_MP4)

        # 2) Re-encode MP4 to <= MAX_MB by lowering FPS if needed
        mp4_result = transcode_mp4_under_cap(
            TMP_MP4, FINAL_MP4, ORIG_FPS, MAX_MB, min_fps=1, step_fps=5
        )
        _print_transcode_report("MP4", FINAL_MP4, mp4_result)

        # 3) Create GIF <= MAX_MB similarly
        gif_result = transcode_gif_under_cap(
            TMP_MP4, FINAL_GIF, ORIG_FPS, MAX_MB, min_fps=1, step_fps=5
        )
        _print_transcode_report("GIF", FINAL_GIF, gif_result)
    finally:
        # 4) Cleanup temp (best effort, but do not hide unexpected failures)
        try:
            os.remove(TMP_MP4)
        except FileNotFoundError:
            pass
        except OSError as exc:
            print(f"Warning: could not remove temp file {TMP_MP4}: {exc}")

    print("Done. Files saved:", FINAL_MP4, FINAL_GIF)


# Run the full train -> record -> transcode pipeline when executed directly
# (not when this module is imported).
if __name__ == "__main__":
    main()


Creating headless Doom training env...
Training PPO for 10000 timesteps (no on-screen rendering)...
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 102      |
|    ep_rew_mean     | 77       |
| time/              |          |
|    fps             | 31       |
|    iterations      | 1        |
|    time_elapsed    | 65       |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 160         |
|    ep_rew_mean          | 61.3        |
| time/                   |             |
|    fps                  | 9           |
|    iterations           | 2           |
|    time_elapsed         | 437         |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.018895304 |
|   