# Conservative Q-Learning (CQL) for Offline RL

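This notebook demonstrates Conservative Q-Learning (CQL), a value-based offline RL method that learns conservative (pessimistic) Q-values. Actions that are poorly covered by the fixed training dataset are therefore not overestimated, which is the main way distribution shift breaks naive offline Q-learning.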

## Why CQL?
- Works well with mixed-quality offline datasets
- No transformers or sequence modeling required
- Simple value-based learning with a conservative penalty

## Workflow
1. Load an offline dataset of transitions (s, a, r, s', done)
2. Train Q-network with CQL loss
3. Evaluate learned Q-function on held-out data
4. (Optional) Derive a policy from the Q-function and evaluate in OpenScope
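Steps 1 and 2 work on individual transitions rather than whole episodes. The following is a minimal sketch of flattening episodes into (s, a, r, s', done) arrays. It assumes a hypothetical per-episode dict with a flat observation vector. The repository's `Episode` class in `data.offline_dataset` may store its fields differently, and it may use dict observations.

```python
import numpy as np

def episodes_to_transitions(episodes):
    """Flatten episodes into (s, a, r, s', done) arrays.

    Hypothetical episode format (the repository's Episode class may differ):
      'observations': array (T + 1, obs_dim), including the state after the last action
      'actions':      array (T,)
      'rewards':      array (T,)
      'terminal':     True if the episode truly ended, False if cut off by a step limit
    """
    s, a, r, s_next, done = [], [], [], [], []
    for ep in episodes:
        obs = np.asarray(ep['observations'], dtype=np.float32)
        T = len(ep['actions'])
        if T == 0:
            continue
        s.append(obs[:-1])
        s_next.append(obs[1:])
        a.append(np.asarray(ep['actions']))
        r.append(np.asarray(ep['rewards'], dtype=np.float32))
        d = np.zeros(T, dtype=np.float32)
        # Only a true terminal sets done = 1; a step-limit cut-off keeps done = 0 so TD targets still bootstrap
        d[-1] = float(ep['terminal'])
        done.append(d)
    return tuple(np.concatenate(x) for x in (s, a, r, s_next, done))

eps = [
    {'observations': np.arange(8).reshape(4, 2), 'actions': [0, 1, 2], 'rewards': [1.0, 0.0, 1.0], 'terminal': True},
    {'observations': np.zeros((3, 2)), 'actions': [1, 1], 'rewards': [0.5, 0.5], 'terminal': False},
]
S, A, R, S2, D = episodes_to_transitions(eps)
print(S.shape, D.tolist())  # (5, 2) [0.0, 0.0, 1.0, 0.0, 0.0]
```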

## Prerequisites
- Offline dataset collected from OpenScope (random/heuristic policies)
- GPU recommended for faster training



## Section 1: Setup and Imports

Let's import the necessary modules and set up the environment.


In [None]:
import sys
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import random_split, DataLoader

# Add parent directory to path
sys.path.insert(0, str(Path.cwd().parent))

from data.offline_dataset import OfflineDatasetCollector, Episode
from training.cql_trainer import CQLTrainer, CQLConfig
from environment.utils import get_device

print("✅ Imports successful!")
print(f"PyTorch: {torch.__version__}")
print(f"Device: {get_device()}")


## Section 2: Load or Collect Offline Dataset

You can either:
- Load a pre-collected dataset from disk, or
- Collect a small dataset now (slower; requires OpenScope server)

We'll show both patterns.
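Whichever option you use, workflow step 3 calls for held-out data. Splitting by episode, rather than by individual transition, prevents consecutive and highly correlated steps of the same episode from landing on both sides of the split. This is a minimal sketch. The 10% fraction and the seed are illustrative choices. Training on `train_eps` alone also assumes that `CQLTrainer` does not already hold out data of its own.

```python
import random

def split_episodes(episodes, holdout_frac=0.1, seed=0):
    """Split a list of episodes into (train, holdout) without splitting any single episode."""
    idx = list(range(len(episodes)))
    random.Random(seed).shuffle(idx)  # deterministic shuffle for a given seed
    # Hold out at least one episode when there are two or more
    n_holdout = max(1, int(len(episodes) * holdout_frac)) if len(episodes) > 1 else 0
    holdout = [episodes[i] for i in idx[:n_holdout]]
    train = [episodes[i] for i in idx[n_holdout:]]
    return train, holdout

# Usage once `episodes` has been loaded or collected:
# train_eps, holdout_eps = split_episodes(episodes)
```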


In [None]:
# Option A: Load pre-collected dataset
from pathlib import Path
from data.offline_dataset import OfflineDatasetCollector

data_path = Path("../data/offline_data.pkl")

if data_path.exists():
    print(f"📦 Loading episodes from {data_path}")
    episodes = OfflineDatasetCollector.load_episodes(str(data_path))
else:
    print(f"⚠️ No dataset found at {data_path}. You can collect a small one now.")
    episodes = []

print(f"Episodes loaded: {len(episodes)}")


In [None]:
# Option B: Collect a small dataset now (requires OpenScope server)
from environment import PlaywrightEnv, create_default_config

if len(episodes) == 0:
    try:
        print("🎬 Collecting a small dataset (random policy)...")
        env = PlaywrightEnv(headless=True, timewarp=5, max_aircraft=5, episode_length=300)
        collector = OfflineDatasetCollector(env)
        episodes = collector.collect_random_episodes(num_episodes=25, max_steps=100, verbose=True)
        print(f"Collected {len(episodes)} episodes")
    except Exception as e:
        print(f"Failed to collect data: {e}")



## Section 3: Train CQL

Configure and train the Conservative Q-Learning trainer. `cql_alpha` sets the weight of the conservative penalty relative to the TD loss.
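To build intuition for what training optimizes, the sketch below performs one CQL(H) gradient step for a Q-network over a single discrete action set. It is not the `CQLTrainer` implementation. OpenScope actions are factored into several components, and the small MLP, the target network, and the batch layout here are illustrative assumptions. The `alpha` argument stands in for `cql_alpha`.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

def cql_step(q_net, target_net, optimizer, batch, gamma=0.99, alpha=5.0):
    """One CQL(H) update on a batch of discrete-action transitions.

    batch: dict with 's' (B, obs_dim), 'a' (B,) long, 'r' (B,), 's_next' (B, obs_dim), 'done' (B,)
    """
    q_all = q_net(batch['s'])                                     # (B, num_actions)
    q_data = q_all.gather(1, batch['a'].unsqueeze(1)).squeeze(1)  # Q(s, a) for dataset actions

    # Standard TD target from a frozen target network; no bootstrapping past terminal states
    with torch.no_grad():
        next_q = target_net(batch['s_next']).max(dim=1).values
        target = batch['r'] + gamma * (1.0 - batch['done']) * next_q
    td_loss = F.mse_loss(q_data, target)

    # Conservative penalty: push down a soft maximum over all actions, push up dataset actions
    cql_penalty = (torch.logsumexp(q_all, dim=1) - q_data).mean()

    loss = td_loss + alpha * cql_penalty
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return {'loss': loss.item(), 'td_loss': td_loss.item(), 'cql_penalty': cql_penalty.item()}

# Toy usage with random data
torch.manual_seed(0)
obs_dim, num_actions, B = 4, 3, 32
q_net = nn.Sequential(nn.Linear(obs_dim, 64), nn.ReLU(), nn.Linear(64, num_actions))
target_net = copy.deepcopy(q_net)
opt = torch.optim.Adam(q_net.parameters(), lr=3e-4)
batch = {
    's': torch.randn(B, obs_dim),
    'a': torch.randint(0, num_actions, (B,)),
    'r': torch.randn(B),
    's_next': torch.randn(B, obs_dim),
    'done': torch.zeros(B),
}
print(cql_step(q_net, target_net, opt, batch))
```

The penalty is never negative, because log-sum-exp upper-bounds the largest Q-value at each state. It shrinks only when the dataset action comes to dominate that soft maximum.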

In [None]:
if len(episodes) == 0:
    raise RuntimeError("No offline episodes available. Please load or collect data first.")

# Configure CQL
config = CQLConfig(
    max_aircraft=5,
    num_epochs=10,       # keep small for demo
    batch_size=256,
    learning_rate=3e-4,
    cql_alpha=5.0,
)

trainer = CQLTrainer(config)
history = trainer.train(episodes)

print("\nTraining complete!")
print({k: v[-1] for k, v in history.items() if len(v) > 0})


## Section 4: Visualize Training Metrics

Plot CQL losses over epochs to verify training behavior.
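The loss curves show whether optimization is progressing, but they do not show whether the Q-function is actually conservative. One simple check, sketched here, compares the average Q-value of dataset actions with that of uniformly random actions on the same states. After CQL training, dataset actions would be expected to score higher. The `q_fn(states, actions)` interface and the single discrete action set are assumptions. With the OpenScope network, you would sample each action component the way `sample_greedy_action` does in Section 5.

```python
import torch

@torch.no_grad()
def conservatism_gap(q_fn, states, data_actions, num_actions, num_random=8):
    """Mean Q(s, a_data) minus mean Q(s, a_random) over the same batch of states.

    q_fn: assumed interface, (states, actions LongTensor (N,)) -> Q-values (N,).
    A positive gap means dataset actions are valued above uniformly random ones.
    """
    q_data = q_fn(states, data_actions).mean()
    q_rand = torch.stack([
        q_fn(states, torch.randint(0, num_actions, data_actions.shape, device=data_actions.device)).mean()
        for _ in range(num_random)
    ])
    return (q_data - q_rand.mean()).item()

# Toy check: Q = 1 for action 0 and 0 otherwise, and the dataset always takes action 0
torch.manual_seed(0)
table = torch.tensor([1.0, 0.0, 0.0])
toy_q = lambda s, a: table[a]
gap = conservatism_gap(toy_q, torch.zeros(100, 4), torch.zeros(100, dtype=torch.long), num_actions=3)
print(f"gap = {gap:.2f}")  # near 2/3, since random actions pick action 0 about a third of the time
```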


In [None]:
import matplotlib.pyplot as plt

# Plot training curves
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

axes[0].plot(history.get("train_losses", []))
axes[0].set_title("Total Loss")
axes[0].set_xlabel("Epoch")
axes[0].set_ylabel("Loss")
axes[0].grid(True)

axes[1].plot(history.get("td_losses", []))
axes[1].set_title("TD Loss")
axes[1].set_xlabel("Epoch")
axes[1].set_ylabel("Loss")
axes[1].grid(True)

axes[2].plot(history.get("cql_losses", []))
axes[2].set_title("CQL Penalty")
axes[2].set_xlabel("Epoch")
axes[2].set_ylabel("Loss")
axes[2].grid(True)

plt.tight_layout()
plt.show()


## Section 5: Greedy Policy Extraction (Optional)

Extract a simple policy by sampling candidate actions and choosing the one with the highest Q(s, a). Because only a random subset of actions is scored, this is an approximate argmax and is meant for demonstration only.
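Random sampling can miss the true argmax. If the action space is small enough, you can instead score every combination in chunks. With the per-component ranges hard-coded in `sample_greedy_action`, there are 6 × 5 × 18 × 13 × 8 = 56,160 combinations, which takes about seven forward passes at a chunk size of 8,192. This sketch uses a hypothetical `score_fn` that wraps the Q-network for one fixed observation. It is not part of the repository.

```python
import itertools
import torch

# Per-component sizes mirroring sample_greedy_action; they must match the Q-network's action heads
ACTION_SIZES = {'aircraft_id': 6, 'command_type': 5, 'altitude': 18, 'heading': 13, 'speed': 8}

@torch.no_grad()
def exhaustive_greedy_action(score_fn, sizes=ACTION_SIZES, chunk=8192):
    """Score every action combination in chunks and return (best_action, best_q).

    score_fn: hypothetical callable mapping {component: LongTensor (n,)} -> Q-values (n,)
    for one fixed observation; it is responsible for moving tensors to the network's device.
    """
    names = list(sizes)
    # All combinations as rows of component indices: shape (num_combinations, num_components)
    grid = torch.tensor(list(itertools.product(*(range(sizes[k]) for k in names))))
    best_q, best_action = float('-inf'), None
    for start in range(0, grid.shape[0], chunk):
        block = grid[start:start + chunk]
        q = score_fn({k: block[:, j] for j, k in enumerate(names)}).reshape(-1)
        idx = int(torch.argmax(q))
        if q[idx].item() > best_q:
            best_q = q[idx].item()
            best_action = {k: int(block[idx, j]) for j, k in enumerate(names)}
    return best_action, best_q

# Toy score: prefers altitude 17, heading 6 and speed 7, and ignores the other components
toy_score = lambda a: a['altitude'].float() - (a['heading'] - 6).abs().float() + 0.1 * a['speed'].float()
action, q = exhaustive_greedy_action(toy_score)
print(action, round(q, 2))
```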


In [None]:
import torch

@torch.no_grad()
def sample_greedy_action(q_network, obs, num_candidates: int = 64):
    """Sample num_candidates random actions and return the one with the highest Q(s, a).

    All candidates are scored in a single batched forward pass. The per-component
    ranges below are hard-coded and must match the Q-network's action heads
    (e.g. the aircraft_id range is tied to max_aircraft).
    """
    device = next(q_network.parameters()).device
    n = num_candidates

    def repeat_obs(x, dtype):
        # Add a batch dimension and repeat the single observation once per candidate
        t = torch.as_tensor(x, dtype=dtype, device=device)
        return t.unsqueeze(0).repeat(n, *([1] * t.dim()))

    obs_t = {
        'aircraft': repeat_obs(obs['aircraft'], torch.float32),
        'aircraft_mask': repeat_obs(obs['aircraft_mask'], torch.bool),
        'global_state': repeat_obs(obs['global_state'], torch.float32),
    }

    # Sample n random candidate actions, one index per action component
    act = {
        'aircraft_id': torch.randint(0, 6, (n,), device=device),
        'command_type': torch.randint(0, 5, (n,), device=device),
        'altitude': torch.randint(0, 18, (n,), device=device),
        'heading': torch.randint(0, 13, (n,), device=device),
        'speed': torch.randint(0, 8, (n,), device=device),
    }

    q = q_network(obs_t, act).reshape(-1)  # one Q-value per candidate
    best = int(torch.argmax(q))
    best_action = {k: int(v[best]) for k, v in act.items()}
    return best_action, float(q[best])

# Example usage (requires a fresh observation 'obs'):
# obs, _ = env.reset()
# action, q_value = sample_greedy_action(trainer.q_network, obs)
# print(action, q_value)


## Summary and Next Steps

- Trained a Conservative Q-Learning (CQL) Q-network from offline OpenScope trajectories.
- Visualized TD and conservative losses to verify stable training.
- Provided a simple greedy action extractor for demo purposes.

Next:
- Evaluate greedy policy in OpenScope (short rollout).
- Compare CQL vs Decision Transformer on sample efficiency and stability.
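For the first of these next steps, a short rollout loop could look like the sketch below. It assumes a Gymnasium-style API, in which `reset()` returns `(obs, info)` (as the Section 5 usage comment also assumes) and `step(action)` returns `(obs, reward, terminated, truncated, info)`. Check `PlaywrightEnv` for its actual return signature. A suitable `policy` would be something like `lambda obs: sample_greedy_action(trainer.q_network, obs)[0]`.

```python
def evaluate_policy(env, policy, num_episodes=3, max_steps=100):
    """Run a policy for a few episodes and return the list of undiscounted returns.

    Assumes a Gymnasium-style env: reset() -> (obs, info),
    step(action) -> (obs, reward, terminated, truncated, info).
    """
    returns = []
    for _ in range(num_episodes):
        obs, _ = env.reset()
        total = 0.0
        for _ in range(max_steps):
            obs, reward, terminated, truncated, _ = env.step(policy(obs))
            total += float(reward)
            if terminated or truncated:
                break
        returns.append(total)
    return returns

# Usage (requires a running OpenScope server and a trained trainer):
# returns = evaluate_policy(env, lambda obs: sample_greedy_action(trainer.q_network, obs)[0])
```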


# Conservative Q-Learning (CQL) for Offline RL

This notebook demonstrates **Conservative Q-Learning (CQL)**, a value-based offline RL approach that learns Q-functions conservatively, so that actions poorly covered by the dataset are not overestimated when distribution shift occurs.

## Key Concepts

**Traditional Q-Learning:**
- Learns Q(s, a) via TD learning
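- Can overestimate Q-values for out-of-distribution (OOD) actions, i.e. actions rarely or never taken in the dataset
- Offline, no new data corrects these errors: a greedy policy exploits the overestimated OOD actions, and the resulting distribution shift hurts performance at deployment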

**Conservative Q-Learning (CQL):**
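- Adds a conservative penalty to the Q-learning (TD) objective
- Pushes Q-values down across all actions while pushing them up on actions that appear in the dataset
- As a result, OOD actions are valued below dataset actions instead of being overestimated
- Typically more robust than standard Q-learning when training purely offline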
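For reference, the CQL(H) objective from Kumar et al. (2020), *Conservative Q-Learning for Offline Reinforcement Learning*, combines both ideas above. Here $\mathcal{D}$ is the offline dataset, $\hat{\pi}_\beta$ is the behavior policy that generated it, and $\hat{\mathcal{B}}^{\pi}\hat{Q}$ is the empirical Bellman backup computed with a target Q-function $\hat{Q}$:

$$
\min_{Q}\; \alpha\, \mathbb{E}_{s \sim \mathcal{D}}\left[\log \sum_{a} \exp Q(s,a) - \mathbb{E}_{a \sim \hat{\pi}_\beta(\cdot \mid s)}\big[Q(s,a)\big]\right] + \frac{1}{2}\, \mathbb{E}_{(s,a,s') \sim \mathcal{D}}\left[\big(Q(s,a) - \hat{\mathcal{B}}^{\pi}\hat{Q}(s,a)\big)^{2}\right]
$$

The log-sum-exp term pushes Q-values down across all actions, most strongly on the largest ones. The subtracted expectation pushes them back up on actions the dataset actually contains. The last term is the usual TD error. The notebook's `cql_alpha` presumably plays the role of $\alpha$. How `CQLTrainer` applies this to OpenScope's factored action space is not shown here.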

## Workflow

1. **Collect offline data** - Random and heuristic policies
2. **Train CQL Q-network** - Learn conservative Q-function
3. **Extract policy** - Greedy policy from learned Q-function
4. **Evaluate** - Test policy on environment
5. **Compare** - CQL vs Decision Transformer vs PPO

## Prerequisites

- OpenScope server running at http://localhost:3003 (for data collection)
- GPU recommended for faster training
- Estimated time: 40-50 minutes (data collection + training)
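Data collection fails if the OpenScope server is not running. A quick pre-flight check, sketched with Python's standard `socket` module, confirms that something is listening on the documented `localhost:3003`. It does not verify that the listener is actually OpenScope.

```python
import socket

def port_is_open(host="localhost", port=3003, timeout=1.0):
    """Return True if a TCP connection to host:port succeeds within `timeout` seconds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False

if not port_is_open():
    print("⚠️ Nothing is listening on localhost:3003; start the OpenScope server before collecting data.")
```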


## 📚 Learning Objectives

By the end of this notebook, you will understand:

1. **Value-Based Offline RL** - Learning Q-functions from fixed datasets
2. **Conservative Regularization** - Preventing distribution shift via conservative updates
3. **CQL Algorithm** - How conservative penalties work in Q-learning
4. **Policy Extraction** - Getting policies from Q-functions (greedy action selection)
5. **CQL vs DT** - When to use value-based vs sequence modeling approaches

**Estimated Time**: 40-50 minutes (includes data collection and training)  
**Prerequisites**: Understanding of Q-learning, basic RL concepts  
**Hardware**: GPU recommended for faster training


## Section 1: Setup and Imports


In [None]:
# Apply nest_asyncio for Jupyter compatibility
import nest_asyncio
nest_asyncio.apply()

import sys
from pathlib import Path
import numpy as np
import torch
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

# Add parent directory to path
sys.path.insert(0, str(Path.cwd().parent))

from environment import PlaywrightEnv, create_default_config
from data.offline_dataset import OfflineDatasetCollector, Episode
from models.q_network import QNetwork
from models.config import create_default_network_config
from training.cql_trainer import CQLTrainer, CQLConfig
from environment.utils import get_device

print("✅ Imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {get_device()}")


## Section 2: Collect Offline Data

First, we collect offline episodes using random and heuristic policies. This creates a diverse dataset with varying quality.
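To confirm that the random and heuristic sources really differ in quality once they are collected, you can compare their return distributions. This is a minimal sketch. It assumes that each episode exposes its per-step rewards as `ep.rewards`, which is an assumption, since the collection cell below only uses `ep.length`.

```python
import numpy as np

def summarize_returns(named_episodes):
    """Compute and print undiscounted-return statistics for each data source.

    named_episodes: dict mapping a source name to a list of episodes, where each
    episode is assumed to expose a per-step `rewards` sequence.
    """
    stats = {}
    for name, eps in named_episodes.items():
        if not eps:
            continue  # skip empty sources instead of computing statistics on nothing
        returns = np.array([float(np.sum(ep.rewards)) for ep in eps])
        s = {'n': len(returns), 'mean': float(returns.mean()), 'std': float(returns.std()),
             'min': float(returns.min()), 'max': float(returns.max())}
        stats[name] = s
        print(f"{name:>10}: n={s['n']}  mean={s['mean']:.2f}  std={s['std']:.2f}  "
              f"min={s['min']:.2f}  max={s['max']:.2f}")
    return stats

# After the collection cell below has run:
# summarize_returns({'random': random_episodes, 'heuristic': heuristic_episodes})
```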


In [None]:
# Create OpenScope environment
env = PlaywrightEnv(
    airport="KLAS",
    max_aircraft=5,  # Small for faster data collection
    headless=True,
    timewarp=5,
    episode_length=600,  # 10 minutes
)

print("✅ Environment created")

# Initialize collector
collector = OfflineDatasetCollector(env)

# Collect random episodes
print("\n📊 Collecting offline data...")
print("   Random policy episodes...")
random_episodes = collector.collect_random_episodes(
    num_episodes=100,  # Increase for real training
    max_steps=100,
    verbose=True
)

# Collect heuristic episodes
print("\n   Heuristic policy episodes...")
heuristic_episodes = collector.collect_heuristic_episodes(
    num_episodes=100,
    max_steps=100,
    verbose=True
)

# Combine all episodes
all_episodes = random_episodes + heuristic_episodes

print(f"\n✅ Collected {len(all_episodes)} episodes")
print(f"   Total timesteps: {sum(ep.length for ep in all_episodes)}")


## Section 3: Train CQL Q-Network

Now we train the CQL algorithm to learn a conservative Q-function.
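After training, a rough sanity check is to compare Q-predictions on dataset (s, a) pairs with the discounted return actually observed from each step onward, using the same `gamma` as the config below. Q-values far above any return seen in the data are a warning sign of overestimation. This is only a coarse check. Episodes cut off at `max_steps` have truncated returns, and the learned Q-function describes the greedy policy rather than the data-collecting policy. The per-episode reward array used here is an assumption about the data format.

```python
import numpy as np

def discounted_returns(rewards, gamma=0.99):
    """Return G_t = sum_k gamma^k * r_{t+k} for every step t of one episode (reverse accumulation)."""
    rewards = np.asarray(rewards, dtype=np.float64)
    returns = np.zeros_like(rewards)
    running = 0.0
    for t in range(len(rewards) - 1, -1, -1):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

print(discounted_returns([1, 1, 1], gamma=0.5).tolist())  # [1.75, 1.5, 1.0]

# Hypothetical comparison, given Q-predictions q_pred for one episode's (s, a) pairs:
# overestimate = q_pred - discounted_returns(ep_rewards, gamma=config.gamma)
```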


In [None]:
# Configure CQL training
config = CQLConfig(
    max_aircraft=5,
    num_epochs=50,  # reduce (e.g. to 10) for a quick demo
    batch_size=256,
    learning_rate=3e-4,
    cql_alpha=5.0,  # Conservative penalty weight
    gamma=0.99,
    checkpoint_dir="../checkpoints/cql",
    save_every=10,
    device=get_device(),
)

print("📋 CQL Training Configuration:")
print(f"   Max aircraft: {config.max_aircraft}")
print(f"   Epochs: {config.num_epochs}")
print(f"   Batch size: {config.batch_size}")
print(f"   Learning rate: {config.learning_rate}")
print(f"   CQL alpha (conservative penalty): {config.cql_alpha}")
print(f"   Gamma (discount): {config.gamma}")
print(f"   Device: {config.device}")

# Create trainer
trainer = CQLTrainer(config)

print("\n✅ CQL trainer created")

# Train CQL
print("\n🚀 Starting CQL training...")
print("   This learns a conservative Q-function from the offline dataset")
history = trainer.train(all_episodes)

print("\n✅ CQL training complete!")
