In [1]:
# Mount Google Drive at /content/drive; the checkpoint and log directories
# used further down in this notebook live under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# List the contents of the checkpoint directory on the mounted Drive.
%ls /content/drive/MyDrive/IU/reinforcement_learning/v8/checkpoints

/content/drive/MyDrive/IU/reinforcement_learning/v6/checkpoints


In [None]:
!apt-get update
!apt-get install -y xvfb python3-opengl ffmpeg
!pip install gymnasium[atari]
!pip install gymnasium[accept-rom-license]
!pip install ale-py
!pip install torch torchvision torchaudio moviepy matplotlib

import os
os.environ['SDL_VIDEODRIVER'] = 'dummy'

import gymnasium as gym
from gymnasium.wrappers import AtariPreprocessing
from gymnasium.wrappers import FrameStackObservation
import ale_py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
import random
import math
from pathlib import Path
from datetime import datetime
import json
import pickle
import gc
import psutil

# --- Output directories ---
checkpoint_dir = Path('/content/drive/MyDrive/IU/reinforcement_learning/v8/checkpoints')
logs_dir = Path('/content/drive/MyDrive/IU/reinforcement_learning/v8/logs')
# parents=True: on a fresh Drive the intermediate folders (e.g. `v8`) do not
# exist yet, and mkdir() without it raises FileNotFoundError.
checkpoint_dir.mkdir(parents=True, exist_ok=True)
logs_dir.mkdir(parents=True, exist_ok=True)

# --- Environment ---
# frameskip=1 on the base env: frame skipping is left to AtariPreprocessing.
env = gym.make('ALE/Pong-v5', render_mode='rgb_array', frameskip=1)
# Grayscale observations kept unscaled (scale_obs=False); the replay buffer and
# the training loop divide by 255 themselves.
env = AtariPreprocessing(env, grayscale_obs=True, scale_obs=False, terminal_on_life_loss=True)
# Stack the 4 most recent frames; the network's first conv layer expects 4 channels.
env = FrameStackObservation(env, stack_size=4)

num_state_feats = env.observation_space.shape
num_actions = env.action_space.n
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}, state shape: {num_state_feats}, actions: {num_actions}")

# --- Dueling DQN ---
class DuelingDQN(nn.Module):
    """Dueling Q-network: three conv layers, one hidden layer, then V and A heads.

    The first conv layer takes 4 input channels and the hidden layer expects a
    64 x 7 x 7 feature map. The conv/hidden weights are drawn from a zero-mean
    normal distribution with std ``sqrt(2 / fan)`` and their biases are zeroed;
    the two heads keep PyTorch's default initialisation.
    """

    def __init__(self, num_actions):
        super().__init__()
        # Each layer is created and then re-initialised before the next one is
        # built, so random numbers are consumed in a fixed, reproducible order.
        self.conv1 = self._normal_init(nn.Conv2d(4, 32, kernel_size=8, stride=4), 4 * 84 * 84)
        self.conv2 = self._normal_init(nn.Conv2d(32, 64, kernel_size=4, stride=2), 32 * 4 * 8 * 8)
        self.conv3 = self._normal_init(nn.Conv2d(64, 64, kernel_size=3, stride=1), 64 * 32 * 4 * 4)
        self.fc1 = self._normal_init(nn.Linear(64 * 7 * 7, 512), 64 * 64 * 3 * 3)

        # Dueling heads: scalar state value and one advantage per action.
        self.V = nn.Linear(512, 1)
        self.A = nn.Linear(512, num_actions)

    @staticmethod
    def _normal_init(layer, fan):
        """Re-draw ``layer.weight`` from N(0, sqrt(2 / fan)) and zero its bias."""
        nn.init.normal_(layer.weight, 0.0, math.sqrt(2.0 / fan))
        layer.bias.data.fill_(0.0)
        return layer

    def forward(self, x):
        """Return Q-values of shape (batch, num_actions) for a batch of inputs."""
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = F.relu(conv(features))
        flat = features.view(features.size(0), -1)  # (batch, 64 * 7 * 7)
        hidden = F.relu(self.fc1(flat))
        value = self.V(hidden)
        advantage = self.A(hidden)
        # Subtract the per-sample mean advantage before adding the state value.
        centred = advantage - advantage.mean(dim=1, keepdim=True)
        return value + centred

# Create main (online) and target neural networks.
main_nn = DuelingDQN(num_actions).to(device)
target_nn = DuelingDQN(num_actions).to(device)
# Start the target network as an exact copy of the online network. Without this
# the two nets are independent random initialisations, and every TD target
# computed before the first periodic sync comes from an unrelated network.
target_nn.load_state_dict(main_nn.state_dict())

# Loss function and optimizer (only the online network is optimised).
optimizer = torch.optim.Adam(main_nn.parameters(), lr=1e-5)
loss_fn = nn.SmoothL1Loss() # Huber loss

# --- Epsilon-greedy ---
def select_epsilon_greedy_action(state, epsilon):
    """Pick an action epsilon-greedily with respect to ``main_nn``.

    Args:
        state: input tensor passed straight to ``main_nn``.
        epsilon: probability of taking a random action.

    Returns:
        A sample from ``env.action_space`` when exploring, otherwise the
        integer index of the largest Q-value predicted by ``main_nn``.
    """
    explore = np.random.rand() < epsilon
    if not explore:
        # Greedy branch: inference only, so no autograd graph is needed.
        with torch.no_grad():
            q_values = main_nn(state).cpu().numpy()
        return int(np.argmax(q_values))
    return env.action_space.sample()

# --- Replay Buffer ---
class UniformBuffer:
    """Fixed-capacity experience replay buffer that samples uniformly.

    Transitions are kept in a Python list used as a ring buffer: once ``size``
    entries exist, new transitions overwrite old slots in insertion order.
    States are stored as ``uint8`` arrays and converted to ``float32`` in
    [0, 1] only when a batch is sampled.
    """

    # Shape used for the (empty) state tensors when sampling zero transitions
    # from an empty buffer, matching the original hard-coded layout.
    _DEFAULT_STATE_SHAPE = (4, 84, 84)

    def __init__(self, size, device):
        """
        Args:
            size: maximum number of transitions kept.
            device: torch device that sampled batches are moved to.
        """
        self._size = size
        self.buffer = []
        self.device = device
        self._next_idx = 0  # slot the next transition is written to

    @staticmethod
    def _as_uint8(frames):
        """Return ``frames`` as a C-contiguous ``uint8`` array.

        Non-uint8 input whose maximum is <= 1.0 is treated as normalised and
        scaled by 255 first; anything else is cast directly. A uint8,
        already-contiguous array is returned as-is (no copy).
        """
        if frames.dtype != np.uint8:
            if frames.max() <= 1.0:
                frames = frames * 255
            frames = frames.astype(np.uint8)
        return np.ascontiguousarray(frames)

    def add(self, state, action, reward, next_state, done):
        """Store one transition, overwriting a previous slot once full."""
        transition = (self._as_uint8(state), action, reward, self._as_uint8(next_state), done)
        if self._next_idx >= len(self.buffer):
            self.buffer.append(transition)
        else:
            self.buffer[self._next_idx] = transition
        self._next_idx = (self._next_idx + 1) % self._size

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

    def sample(self, num_samples):
        """Draw ``num_samples`` distinct transitions uniformly at random.

        Returns:
            Tuple ``(states, actions, rewards, next_states, dones)`` of tensors
            on ``self.device``. States are float32 scaled to [0, 1], actions
            are int64, rewards and dones are float32.

        Raises:
            ValueError: from ``np.random.choice`` when more samples are
                requested than transitions are stored.
        """
        indices = np.random.choice(len(self.buffer), num_samples, replace=False)

        # Take the state shape from the stored data instead of hard-coding
        # (4, 84, 84), so the buffer works for any observation layout.
        state_shape = self.buffer[0][0].shape if self.buffer else self._DEFAULT_STATE_SHAPE

        # Pre-allocate the batch arrays and fill them row by row.
        states = np.empty((num_samples, *state_shape), dtype=np.float32)
        next_states = np.empty((num_samples, *state_shape), dtype=np.float32)
        actions = np.empty(num_samples, dtype=np.int64)
        rewards = np.empty(num_samples, dtype=np.float32)
        dones = np.empty(num_samples, dtype=np.float32)

        for row, buffer_idx in enumerate(indices):
            s, a, r, ns, d = self.buffer[buffer_idx]
            states[row] = s.astype(np.float32) / 255.0
            next_states[row] = ns.astype(np.float32) / 255.0
            actions[row] = a
            rewards[row] = r
            dones[row] = d

        # Convert to tensors on the target device.
        batch = (states, actions, rewards, next_states, dones)
        return tuple(torch.from_numpy(array).to(self.device) for array in batch)

# --- Training step with Double DQN ---
def train_step(states, actions, rewards, next_states, dones):
    """Run one Double-DQN gradient step on ``main_nn`` and return the loss.

    The online network (``main_nn``) chooses the next action and the target
    network (``target_nn``) evaluates it. Uses the module-level ``discount``,
    ``loss_fn`` and ``optimizer``.

    Args:
        states, next_states: batched network inputs.
        actions: int64 tensor of shape (batch,) with the actions taken.
        rewards: float tensor of shape (batch,).
        dones: float tensor of shape (batch,); 1.0 disables bootstrapping.

    Returns:
        The scalar loss as a tensor detached from the autograd graph.
    """
    # The TD target is a constant for the optimiser, so compute it without
    # recording an autograd graph (saves two forward passes' worth of
    # activations per step).
    with torch.no_grad():
        next_actions = main_nn(next_states).argmax(dim=-1, keepdim=True)
        # squeeze(-1), not squeeze(): a batch of one must stay shape (1,).
        next_qs = target_nn(next_states).gather(1, next_actions).squeeze(-1)
        target = rewards + (1 - dones) * discount * next_qs

    chosen_qs = main_nn(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)
    loss = loss_fn(chosen_qs, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Detach so a caller holding the loss for logging does not keep the graph alive.
    return loss.detach()

# Hyperparameters
num_episodes = 5000    # last episode index of the training loop
epsilon = 1.0          # initial probability of a random action
batch_size = 32        # transitions drawn per training step
discount = 0.99        # factor applied to the bootstrapped next-state value
buffer_size = 25000    # replay buffer capacity (transitions)
save_interval = 50     # episodes between checkpoints
log_interval = 10      # episodes between log entries

buffer = UniformBuffer(buffer_size, device)

# Training state (replaced wholesale when a checkpoint is loaded).
training_state = dict(
    episode=0,
    epsilon=epsilon,
    cur_frame=0,
    last_100_ep_rewards=[],
    # One independent, initially empty series per logged quantity.
    training_logs={
        series: []
        for series in ('episodes', 'rewards', 'avg_rewards', 'losses', 'epsilon_values', 'frames')
    },
)

# Load from checkpoint if available
def load_latest_checkpoint():
    """Restore networks, optimizer, training state and replay buffer.

    Looks in ``checkpoint_dir`` for ``checkpoint_ep<N>[_<timestamp>].pth`` files
    and loads the one with the highest episode number (ties broken by file
    name, i.e. by timestamp). Files whose name does not match that pattern are
    ignored instead of aborting the search. The companion replay-buffer pickle
    ``buffer_<checkpoint name>.pkl`` is loaded when it exists.

    Both files are read from disk *before* any global object is modified, so a
    corrupt or unreadable file cannot leave the models half-restored.

    Side effects on success: updates ``main_nn``, ``target_nn``, ``optimizer``
    and ``buffer`` in place and rebinds the global ``training_state``.

    Returns:
        True if a checkpoint was restored, False otherwise.
    """
    import re  # local import: only needed to parse checkpoint file names

    global training_state

    candidates = []
    for path in checkpoint_dir.glob("checkpoint_ep*.pth"):
        match = re.fullmatch(r"checkpoint_ep(\d+)(?:_.*)?", path.stem)
        if match:
            # (episode, name) sorts by episode first, then by timestamp suffix.
            candidates.append((int(match.group(1)), path.name, path))
    if not candidates:
        print("No checkpoints found. Starting from scratch.")
        return False

    latest_checkpoint = max(candidates)[2]
    print(f"Loading checkpoint: {latest_checkpoint}")

    try:
        # weights_only=False performs a full unpickle (the checkpoint also holds
        # the plain-Python training state), so no safe-globals allowlist is
        # involved. SECURITY: only load checkpoints this notebook wrote itself.
        checkpoint = torch.load(latest_checkpoint, map_location=device, weights_only=False)

        buffer_filename = f"buffer_{latest_checkpoint.name.replace('.pth', '.pkl')}"
        buffer_filepath = checkpoint_dir / buffer_filename
        restored_buffer = None
        if buffer_filepath.exists():
            print("Loading replay buffer...")
            # SECURITY: pickle.load can execute arbitrary code; trusted files only.
            with open(buffer_filepath, 'rb') as f:
                restored_buffer = pickle.load(f)

        # Everything was read successfully; now apply it.
        main_nn.load_state_dict(checkpoint['main_nn_state_dict'])
        target_nn.load_state_dict(checkpoint['target_nn_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        training_state = checkpoint['training_state']

        if restored_buffer is not None:
            if len(restored_buffer) > buffer._size:
                # The buffer was saved with a larger capacity; keep the list
                # within the current capacity by retaining its tail.
                restored_buffer = restored_buffer[-buffer._size:]
            buffer.buffer = restored_buffer
            # NOTE: the original write position is not stored on disk, so it is
            # approximated from the length (0 once the buffer is full).
            buffer._next_idx = len(buffer.buffer) % buffer._size
            print(f"Loaded buffer with {len(buffer.buffer)} experiences")
        else:
            print(f"No buffer file found at {buffer_filepath}, starting with empty buffer")

        print(f"Resumed from episode {checkpoint['episode']}, frame {training_state['cur_frame']}")
        return True

    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        print("Starting from scratch.")
        return False

# Load checkpoint if available
checkpoint_loaded = load_latest_checkpoint()

# --- Checkpoints ---
def save_checkpoint(episode, main_nn, target_nn, optimizer, buffer, training_state):
    """Write a model checkpoint and the replay buffer to ``checkpoint_dir``.

    Two files are produced, sharing a timestamped stem:
      * ``checkpoint_ep<episode>_<YYYYmmdd_HHMMSS>.pth``: network and
        optimizer state dicts plus ``training_state`` (via ``torch.save``).
      * ``buffer_checkpoint_ep<episode>_<...>.pkl``: the full list of stored
        transitions (via ``pickle``). The whole buffer is written each time.

    Each file is first written under a ``.tmp`` name and then renamed into
    place, so an interrupted session cannot leave a truncated file under the
    final name for a later resume to pick up as the newest checkpoint.

    Returns:
        Path of the ``.pth`` checkpoint file.
    """
    filename = f"checkpoint_ep{episode}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pth"
    filepath = checkpoint_dir/filename
    buffer_filepath = checkpoint_dir/f"buffer_{filename.replace('.pth','.pkl')}"
    checkpoint = {
        'episode': episode,
        'main_nn_state_dict': main_nn.state_dict(),
        'target_nn_state_dict': target_nn.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'training_state': training_state
    }

    # Buffer first: once the .pth exists, its companion buffer file is complete.
    buffer_tmp = buffer_filepath.with_name(buffer_filepath.name + '.tmp')
    with open(buffer_tmp, 'wb') as f:
        pickle.dump(buffer.buffer, f)
    buffer_tmp.replace(buffer_filepath)

    # The temporary name ends in ".tmp", so a "checkpoint_ep*.pth" glob never
    # matches a partially written checkpoint.
    checkpoint_tmp = filepath.with_name(filename + '.tmp')
    torch.save(checkpoint, checkpoint_tmp)
    checkpoint_tmp.replace(filepath)

    print(f"Checkpoint saved: {filepath}")
    return filepath

# --- Memory monitoring function ---
def print_memory_usage():
    """Print the system RAM usage percentage and, if CUDA is available, the
    current and peak GPU memory allocated by PyTorch (in GB)."""
    print(f"RAM Usage: {psutil.virtual_memory().percent:.1f}%")
    if not torch.cuda.is_available():
        # CPU-only runtime: there are no GPU statistics to report.
        return
    print(f"GPU Memory: {torch.cuda.memory_allocated()/1024**3:.2f}GB / {torch.cuda.max_memory_allocated()/1024**3:.2f}GB")

# --- Restore loop variables (fresh defaults, or values from a loaded checkpoint) ---
last_100_ep_rewards = training_state['last_100_ep_rewards']
cur_frame = training_state['cur_frame']
epsilon = training_state['epsilon']
start_episode = training_state['episode']
loss_val = torch.tensor(0.0)
last_saved_episode = None  # episode of the most recent checkpoint written in this run

print(f"Starting training from episode {start_episode} with {len(buffer)} experiences in buffer...")
print_memory_usage()

# Start training. One training step is run after every environment step once
# the buffer holds at least one batch.
for episode in range(start_episode + 1, num_episodes + 1):  # Start from next episode
    state, _ = env.reset()
    ep_reward, done = 0, False

    # Periodic garbage collection at episode start
    if episode % 25 == 0:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    while not done:
        # Scale the uint8 observation to [0, 1] and add a batch dimension.
        state_tensor = torch.from_numpy(state.astype(np.float32) / 255.0).unsqueeze(0).to(device)
        action = select_epsilon_greedy_action(state_tensor, epsilon)

        # Drop the temporary tensor immediately
        del state_tensor

        next_state, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        ep_reward += reward          # episode return uses the raw reward
        reward = np.sign(reward)     # the stored/learned reward is clipped to {-1, 0, 1}

        # Store `terminated`, not `done`: a truncated episode (time limit) ends
        # the rollout but the state is not terminal, so the TD target must
        # still bootstrap from next_state.
        buffer.add(state, action, reward, next_state, terminated)
        state = next_state
        cur_frame += 1
        # Linear epsilon decay per frame, floored near 0.01.
        if epsilon > 0.01:
            epsilon -= 1.1e-6

        if len(buffer) >= batch_size:
            states, actions, rewards, next_states, dones = buffer.sample(batch_size)
            loss_val = train_step(states, actions, rewards, next_states, dones)

            # Release batch tensors after the training step
            del states, actions, rewards, next_states, dones

        # Copy main_nn weights to target_nn.
        if cur_frame % 10000 == 0:
            target_nn.load_state_dict(main_nn.state_dict())

    # Sliding window over the last 100 episode returns.
    if len(last_100_ep_rewards) == 100:
        last_100_ep_rewards.pop(0)
    last_100_ep_rewards.append(ep_reward)

    training_state.update({'episode':episode,'epsilon':epsilon,'cur_frame':cur_frame,'last_100_ep_rewards':last_100_ep_rewards})

    if episode % log_interval == 0:
        avg_reward = np.mean(last_100_ep_rewards) if last_100_ep_rewards else ep_reward
        loss_float = float(loss_val.item()) if isinstance(loss_val, torch.Tensor) else 0.0
        logs = training_state['training_logs']
        logs['episodes'].append(episode)
        logs['rewards'].append(ep_reward)
        logs['avg_rewards'].append(avg_reward)
        logs['losses'].append(loss_float)
        logs['epsilon_values'].append(epsilon)
        logs['frames'].append(cur_frame)
        print(f"Ep {episode}/{num_episodes}, ε={epsilon:.3f}, Loss={loss_float:.4f}, AvgR={avg_reward:.2f}")
        print_memory_usage()

        # Overwrite the JSON log so the latest metrics are always on disk.
        with open(logs_dir/"latest_logs.json",'w') as f:
            json.dump(logs,f,indent=2)

    if episode % save_interval == 0:
        save_checkpoint(episode, main_nn, target_nn, optimizer, buffer, training_state)
        last_saved_episode = episode

    # Memory check
    if psutil.virtual_memory().percent > 85:
        print("WARNING: High memory usage detected. Forcing garbage collection...")
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

print("Training completed!")
# Final save, skipped when the last episode was already checkpointed by the
# periodic save above (avoids writing the whole replay buffer twice in a row).
if training_state['episode'] != last_saved_episode:
    save_checkpoint(training_state['episode'], main_nn, target_nn, optimizer, buffer, training_state)

0% [Working]            Hit:1 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease
0% [Waiting for headers] [Connecting to security.ubuntu.com (185.125.190.36)] [                                                                               Get:2 https://cli.github.com/packages stable InRelease [3,917 B]
0% [Waiting for headers] [Connecting to security.ubuntu.com (185.125.190.36)] [                                                                               Hit:3 http://archive.ubuntu.com/ubuntu jammy InRelease
0% [Waiting for headers] [Connecting to security.ubuntu.com (185.125.190.36)] [                                                                               Get:4 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease [1,581 B]
0% [Waiting for headers] [Waiting for headers] [Waiting for headers] [4 InRelea0% [Waiting for headers] [Waiting for headers] [Waiting for headers] [Connected                                