The following code is our implementation of Conservative Q-Learning (CQL) for offline RL over the encoded MathDial dataset (we discuss the state-action space more in our paper).

The implementation of CQL (particularly the replay buffer design and loss function) draws reference from:
- Aviral Kumar et al. "Conservative Q-Learning for Offline Reinforcement Learning" (https://arxiv.org/abs/2006.04779)
- Reference implementation: https://github.com/aviralkumar2907/CQL (official implementation by paper authors)
- Reference implementation: https://github.com/BY571/CQL (for discrete action space)


In [None]:
# Standard libraries
import numpy as np
import pandas as pd
import json
from collections import defaultdict, Counter
from tqdm import tqdm

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset

# Scikit-learn
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
"""
Helper to create a train test split since sklearn train_test_split won't work for our dataset's object type. We split the dataframe at episode boundaries using 'done' flags.
"""
def split_dataframe_by_episodes(df, test_size=0.2, val_size=0.1, random_state=42):
    """Split ``df`` into train/validation/test frames at episode boundaries.

    An episode is a run of consecutive rows that ends with a row whose
    ``done`` value equals 1; any rows after the last such row form one final
    episode. Whole episodes are assigned to a split, so none is cut in half.

    Args:
        df: DataFrame with a ``done`` column, rows stored in episode order.
        test_size: Fraction of episodes placed in the test split.
        val_size: Fraction of all episodes placed in the validation split
            (rescaled internally because it is carved out of the non-test part).
        random_state: Seed forwarded to both ``train_test_split`` calls.

    Returns:
        ``(train_df, val_df, test_df)``, each re-indexed from 0.
    """
    # Boundaries are row POSITIONS (they feed ``iloc`` below), so they are
    # derived from the positional array rather than from index labels; this
    # keeps the split correct for frames whose index is not 0..n-1.
    done_positions = np.flatnonzero((df['done'] == 1).to_numpy())

    episode_boundaries = []
    start = 0
    for pos in done_positions:
        end = int(pos) + 1
        episode_boundaries.append((start, end))
        start = end
    if start < len(df):
        episode_boundaries.append((start, len(df)))

    train_val_eps, test_eps = train_test_split(
        episode_boundaries,
        test_size=test_size,
        random_state=random_state
    )
    # val_size is relative to the full data; rescale it to the remaining part.
    rel_val_size = val_size / (1 - test_size)
    train_eps, val_eps = train_test_split(
        train_val_eps,
        test_size=rel_val_size,
        random_state=random_state
    )

    def collect(spans):
        # Expand (start, end) spans into row positions, keeping span order.
        positions = [p for s, e in spans for p in range(s, e)]
        return df.iloc[positions].reset_index(drop=True)

    return collect(train_eps), collect(val_eps), collect(test_eps)

def split_dataframe_by_episodes_no_val(df, test_size=0.2, random_state=42):
    """Split ``df`` into train/test frames at episode boundaries.

    An episode is a run of consecutive rows that ends with a row whose
    ``done`` value equals 1; any rows after the last such row form one final
    episode. Whole episodes are assigned to a split.

    Args:
        df: DataFrame with a ``done`` column, rows stored in episode order.
        test_size: Fraction of episodes placed in the test split.
        random_state: Seed forwarded to ``train_test_split``.

    Returns:
        ``(train_df, test_df)``, each re-indexed from 0.
    """
    # Boundaries are row POSITIONS (they feed ``iloc`` below), so they are
    # derived from the positional array rather than from index labels; this
    # keeps the split correct for frames whose index is not 0..n-1.
    done_positions = np.flatnonzero((df['done'] == 1).to_numpy())

    episode_boundaries = []
    current_episode_start = 0
    for pos in done_positions:
        end = int(pos) + 1
        episode_boundaries.append((current_episode_start, end))
        current_episode_start = end
    if current_episode_start < len(df):
        episode_boundaries.append((current_episode_start, len(df)))

    # split episode boundaries
    train_episodes, test_episodes = train_test_split(
        episode_boundaries,
        test_size=test_size,
        random_state=random_state
    )

    train_indices = [p for start, end in train_episodes for p in range(start, end)]
    test_indices = [p for start, end in test_episodes for p in range(start, end)]

    train_df = df.iloc[train_indices].reset_index(drop=True)
    test_df = df.iloc[test_indices].reset_index(drop=True)

    return train_df, test_df

In [None]:
# Helper class to convert numpy arrays to Pytorch tensors for custom dataset
class TutorDataset(Dataset):
    """Dataset of (state, action) pairs held as PyTorch tensors.

    States are stored as float32 and actions as int64 (``torch.long``).
    """

    def __init__(self, states, actions):
        state_tensor = torch.tensor(states, dtype=torch.float32)
        action_tensor = torch.tensor(actions, dtype=torch.long)
        self.states = state_tensor
        self.actions = action_tensor

    def __len__(self):
        # One sample per action label.
        n_samples = len(self.actions)
        return n_samples

    def __getitem__(self, i):
        sample = (self.states[i], self.actions[i])
        return sample

# Implements standard BC using FFNN w/ 2 hidden layers, 256 units per layer, ReLU activation, no dropout
class BCModel(nn.Module):
    """Behaviour-cloning classifier: a feed-forward net mapping states to action logits.

    The stack is Linear -> ReLU -> Linear -> ReLU -> Linear, exposed as
    ``self.net`` so that the linear layers sit at indices 0, 2 and 4.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=256):
        super().__init__()
        widths = [input_dim, hidden_dim, hidden_dim, output_dim]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.pop()  # the output layer emits raw logits, so no trailing ReLU
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.net(x)
        return logits

def prepare_data_for_bc(df, encoder=None):
    """Turn the raw dataframe into train/validation arrays for behaviour cloning.

    Rows without a ``next_action_id`` are dropped. Missing categorical values
    are filled with -1 and missing numerical values with the column mean (0
    if the whole column is missing). Categorical features are one-hot encoded
    and appended after the numerical features. Action ids are remapped to
    0..k-1 in sorted order of the ids present in ``df``.

    Args:
        df: DataFrame containing the state feature columns and
            ``next_action_id``. It is not modified.
        encoder: Optional already-fitted ``OneHotEncoder``; a new one is fitted
            on ``df`` when omitted.

    Returns:
        ``(train_states, val_states, train_actions, val_actions, num_actions)``
        from a random 80/20 row-level split with ``random_state=42``.
    """
    state_features = ['misconception_type','convo_turn','previous_action_id',
                   'listen_to_feedback','problem_progress','progress_delta',
                   'correct_solution','next_action_hint_strength']

    categorical_features = ['misconception_type','previous_action_id',
                 'listen_to_feedback','correct_solution','next_action_hint_strength']
    numerical_features = [f for f in state_features if f not in categorical_features]

    action_column = 'next_action_id'
    # Explicit copy: the column assignments below must act on our own frame,
    # not on a possible view of the caller's data (avoids chained-assignment
    # warnings and accidental mutation).
    df = df.dropna(subset=[action_column]).copy()
    df[action_column] = df[action_column].astype(int)
    for feature in categorical_features:
        df[feature] = df[feature].fillna(-1)
    for feature in numerical_features:
        fill_value = df[feature].mean() if not df[feature].isna().all() else 0
        df[feature] = df[feature].fillna(fill_value)

    if encoder is None:
        encoder = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
        encoded_features = encoder.fit_transform(df[categorical_features])
    else:
        encoded_features = encoder.transform(df[categorical_features])

    numerical_data = df[numerical_features].values
    # State layout: numerical columns first, then the one-hot block.
    states = np.hstack((numerical_data, encoded_features))
    actions = df[action_column].values
    # ``unique_actions`` is sorted; ``remapped_actions[i]`` is the rank of
    # ``actions[i]`` in it, i.e. the contiguous class label.
    unique_actions, remapped_actions = np.unique(actions, return_inverse=True)
    train_states, val_states, train_actions, val_actions = train_test_split(
        states, remapped_actions, test_size=0.2, random_state=42
    )
    return train_states, val_states, train_actions, val_actions, len(unique_actions)

def train_bc_model(train_states, val_states, train_actions, val_actions, num_actions, epochs=50, batch_size=128, lr=1e-3):
    """Train a ``BCModel`` with cross-entropy loss and Adam.

    After every epoch the model is evaluated on the validation set; whenever
    validation accuracy improves, the weights are written to
    ``best_bc_model.pt``.

    Returns:
        The model as it stands after the final epoch. Note that this is not
        necessarily the best checkpoint; that one is only saved to disk.
    """
    net = BCModel(input_dim=train_states.shape[1], output_dim=num_actions)

    train_loader = DataLoader(TutorDataset(train_states, train_actions), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(TutorDataset(val_states, val_actions), batch_size=batch_size)

    loss_fn = nn.CrossEntropyLoss()
    adam = optim.Adam(net.parameters(), lr=lr)

    def run_epoch(loader, training):
        # One pass over ``loader``; returns (mean batch loss, #correct, #seen).
        loss_sum = 0
        n_correct = 0
        n_seen = 0
        for batch_states, batch_actions in loader:
            scores = net(batch_states)
            batch_loss = loss_fn(scores, batch_actions)
            if training:
                adam.zero_grad()
                batch_loss.backward()
                adam.step()
            loss_sum += batch_loss.item()
            guesses = torch.max(scores.data, 1)[1]
            n_seen += batch_actions.size(0)
            n_correct += (guesses == batch_actions).sum().item()
        return loss_sum / len(loader), n_correct, n_seen

    train_losses, val_losses = [], []
    best_val_acc = 0.0
    for epoch in range(epochs):
        net.train()
        train_loss, _, _ = run_epoch(train_loader, training=True)
        train_losses.append(train_loss)

        net.eval()
        with torch.no_grad():
            val_loss, val_correct, val_total = run_epoch(val_loader, training=False)
        val_losses.append(val_loss)

        val_acc = 100 * val_correct / val_total
        if val_acc > best_val_acc:
            # Checkpoint only on a strict improvement in validation accuracy.
            best_val_acc = val_acc
            torch.save(net.state_dict(), 'best_bc_model.pt')

        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

    print(f'Best val accuracy: {best_val_acc:.2f}%')
    return net

In [None]:
# replay buffer: stores sequences of state-action-reward transitions (s,a,r,s',done)
class TransitionBuffer:
    """Append-only store of (s, a, r, s', done) transitions in parallel lists."""

    # Attribute names, in the same order as the arguments of ``push``.
    _FIELDS = ('states', 'actions', 'rewards', 'next_states', 'dones')

    def __init__(self):
        for field in self._FIELDS:
            setattr(self, field, [])

    def push(self, s, a, r, s2, done):
        """Append one transition; each value goes to its matching list."""
        for field, value in zip(self._FIELDS, (s, a, r, s2, done)):
            getattr(self, field).append(value)

    def __len__(self):
        # The number of stored transitions equals the number of actions.
        return len(self.actions)

def build_transition_buffer(df, reward_fn, meta_map, orig_to_idx, encoder=None, categorical_features=None, episode_column=None):
    """Build a ``TransitionBuffer`` of (s, a, r, s', done) tuples from ``df``.

    Each stored state is ``[numerical features | one-hot categorical features]``.
    Rewards are computed from the *raw* feature vector (``state_features``
    order), because the reward functions in this file index the state by raw
    feature position (e.g. index 1 = ``convo_turn``, 5 = ``progress_delta``),
    which does not match the encoded layout.

    Args:
        df: DataFrame with the state feature columns, ``next_action_id`` and,
            optionally, ``done``. It is not modified.
        reward_fn: Callable ``(raw_state, action_idx, raw_next_state, action_meta)``.
        meta_map: Nested ``{category: {strategy: {level: original_action_id}}}``.
            NOTE(review): the leaves are treated as ORIGINAL action ids (that
            is what ``main`` passes); confirm no caller passes model indices.
        orig_to_idx: Mapping from original action id to model action index.
            Rows whose action is missing or not in this mapping are skipped.
        encoder: Optional fitted ``OneHotEncoder``; fitted here when omitted.
        categorical_features: Columns to one-hot encode (defaults to the five
            categorical state features).
        episode_column: Column identifying episodes. When given, rows are
            grouped by it and ordered by ``convo_turn``. When omitted and a
            ``done`` column exists, rows are consumed in file order and
            ``done`` flags delimit episodes (a global sort by ``convo_turn``
            would interleave different conversations). When neither exists,
            the frame is treated as one episode sorted by ``convo_turn``.

    Returns:
        ``(buffer, orig_to_idx)`` where the buffer's fields are numpy arrays.
    """
    # Metadata keyed by ORIGINAL action id, matching the leaves of ``meta_map``.
    orig_id_to_meta = {
        orig_id: {'category': cat, 'strategy': strat, 'level': lvl}
        for cat, strategies in meta_map.items()
        for strat, levels in strategies.items()
        for lvl, orig_id in levels.items()
    }

    state_features = [
        'misconception_type','convo_turn','previous_action_id',
        'listen_to_feedback','problem_progress','progress_delta',
        'correct_solution','next_action_hint_strength'
    ]
    if categorical_features is None:
        categorical_features = [
            'misconception_type','previous_action_id',
            'listen_to_feedback','correct_solution','next_action_hint_strength'
        ]
    num_feats = [f for f in state_features if f not in categorical_features]

    # From here on index labels equal row positions, so the per-row arrays
    # below can be indexed safely even if the caller's frame was filtered.
    df = df.reset_index(drop=True)

    cat_values = df[categorical_features].values
    if encoder is None:
        encoder = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
        # Fit on the same array form that is transformed below.
        encoder.fit(cat_values)
    cat_array = encoder.transform(cat_values)

    # Per-row arrays, all aligned by position.
    encoded_states = np.hstack((df[num_feats].to_numpy(dtype=float), cat_array))
    raw_states = df[state_features].to_numpy()
    action_ids = df['next_action_id'].to_numpy() if 'next_action_id' in df.columns else None
    done_flags = df['done'].to_numpy() if 'done' in df.columns else None

    if episode_column:
        ordered = df.sort_values([episode_column, 'convo_turn'])
        episode_rows = [grp.index.to_numpy() for _, grp in ordered.groupby(episode_column)]
    elif done_flags is not None:
        episode_rows = [np.arange(len(df))]
    else:
        episode_rows = [df.sort_values('convo_turn').index.to_numpy()]

    buffer = TransitionBuffer()

    # build transitions per episode
    for rows in episode_rows:
        prev_state = prev_raw = prev_idx = prev_meta = None

        for pos in rows:
            done_flag = bool(done_flags[pos]) if done_flags is not None else False
            orig_id = action_ids[pos] if action_ids is not None else None

            if pd.isna(orig_id) or int(orig_id) not in orig_to_idx:
                # Unusable action: skip the row, but never let a pending
                # transition leak across an episode boundary.
                if done_flag:
                    prev_state = prev_raw = prev_idx = prev_meta = None
                continue

            orig_id = int(orig_id)
            action_idx = orig_to_idx[orig_id]
            state = encoded_states[pos]
            raw = raw_states[pos]
            meta = orig_id_to_meta.get(orig_id)

            if prev_state is not None:
                r = reward_fn(prev_raw, prev_idx, raw, prev_meta)
                # NOTE(review): this transition carries the NEXT row's done
                # flag, so the step into the final state is also terminal and
                # does not bootstrap from the terminal reward pushed below.
                buffer.push(prev_state, prev_idx, r, state, done_flag)

            prev_state, prev_raw, prev_idx, prev_meta = state, raw, action_idx, meta

            if done_flag:
                # Self-transition that carries the end-of-episode reward.
                term_r = new_terminal(raw)
                buffer.push(state, action_idx, term_r, state, True)
                prev_state = prev_raw = prev_idx = prev_meta = None

    buffer.states = np.array(buffer.states)
    buffer.actions = np.array(buffer.actions, dtype=int)
    buffer.rewards = np.array(buffer.rewards, dtype=float)
    buffer.next_states = np.array(buffer.next_states)
    buffer.dones = np.array(buffer.dones, dtype=bool)

    return buffer, orig_to_idx


def new_terminal(state):
    """Return the end-of-episode reward for ``state``.

    Reads ``state[4]`` as the problem progress and ``state[6]`` as the
    correct-solution indicator, i.e. it expects the raw feature ordering, not
    a vector where one-hot columns follow the numerical ones.

    Returns 5.0 when ``state[6] > 0``; otherwise progress capped at 50,
    rescaled to at most 2.0.
    """
    progress_cap = 50.0  # keeps the progress reward below the reward for a correct solution
    capped_fraction = min(state[4], progress_cap) / progress_cap
    if state[6] > 0:
        # a correct solution earns more than any amount of partial progress
        return 5.0
    return 2.0 * capped_fraction

"""
Our custom-designed reward function follows the principle of scaffolding first, then telling. More about the reward function design is in the paper.
"""
def hybrid_reward(state, action_id, next_state=None, action_meta=None):
    """Per-step reward: progress term + step penalty + strategy bonus.

    Reads ``state[1]`` as the conversation turn and ``state[5]`` as the
    progress delta. ``action_id`` and ``next_state`` are accepted but unused.

    Args:
        state: Indexable state vector (see above for the positions read).
        action_id: Unused.
        next_state: Unused.
        action_meta: Optional dict with a ``'category'`` key and, optionally,
            a ``'strategy'`` key; drives the strategy bonus.

    Returns:
        ``5.0 * progress_delta - 0.1 + 2.0 * strategy_bonus``.
    """
    turn = state[1]
    progress_delta = state[5]

    # share of the conversation already elapsed, saturating at turn 8
    turn_progress = min(1.0, turn / 8.0)

    def telling_bonus(meta):
        # Telling is penalised early (scaled by how much it reveals) and
        # rewarded late, but only if the student actually progressed.
        severity_by_strategy = {
            'Full Reveal (Answer)': 1.5,  # strongest hint, strongest penalty
            'Conceptual Hint': 0.6,
        }
        severity = 1.0
        if 'strategy' in meta:
            severity = severity_by_strategy.get(meta['strategy'], 1.0)
        early_penalty = -0.3 * (1.0 - turn_progress) * severity
        late_bonus = 0.2 * turn_progress if progress_delta > 0 else 0.0
        return early_penalty + late_bonus

    # pedagogical / action-specific component
    strategy_bonus = 0.0
    if action_meta is not None:
        category = action_meta['category']
        if category in ['Focus', 'Probing']:
            # scaffolding is worth most at the start and fades out
            strategy_bonus = 0.2 * (1.0 - turn_progress)
        elif category == 'Telling':
            strategy_bonus = telling_bonus(action_meta)

    progress_reward = 5.0 * progress_delta
    step_penalty = -0.1
    return progress_reward + step_penalty + 2.0*strategy_bonus

In [None]:
"""
Helper function that is essentially asking: if we started from this random state in the buffer and followed the agent's policy, what rewards would the agent have earned if the environment happened to transition exactly as recorded in the buffer?

Terminates if the conversation is done, or by max steps (should be 80), or if the buffer ends.
"""
def evaluate_policy(agent, buffer, reward_fn, num_episodes=20, max_steps=80, gamma=0.98, action_meta_map=None, action_map=None):
    """Estimate the discounted return of ``agent`` by replaying the buffer.

    Each evaluation episode starts at a random buffer index and then walks
    forward through consecutive buffer entries. The agent's chosen action
    does not influence the next state (there is no simulator); it only
    enters the reward via ``reward_fn``.

    Args:
        agent: Object with ``select_action(state)``; if it has ``q_net`` or
            ``model`` with an ``eval`` method, that module is put in eval mode.
        buffer: Sized object with ``states`` and, optionally, ``dones``.
        reward_fn: Callable ``(state, action_idx, next_state, action_meta)``.
        num_episodes: Number of rollouts.
        max_steps: Maximum steps per rollout.
        gamma: Discount factor.
        action_meta_map: Optional nested ``{category: {strategy: {level: id}}}``.
        action_map: Optional ``{original_action_id: model_index}``. When given,
            the ids in ``action_meta_map`` are translated through it so that
            metadata is looked up by the index the agent returns; ids missing
            from the mapping are ignored. When omitted, ids are used as-is.

    Returns:
        ``(mean_return, total_returns)``.

    Raises:
        ValueError: If the buffer is empty.

    NOTE(review): ``reward_fn`` receives ``buffer.states`` entries; confirm
    their layout matches the indices the reward function reads.
    """
    if hasattr(agent, 'q_net') and hasattr(agent.q_net, 'eval'):
        agent.q_net.eval()
    elif hasattr(agent, 'model') and hasattr(agent.model, 'eval'):
        agent.model.eval()

    flat_idx_to_meta = {}
    if action_meta_map:
        for cat, strategies in action_meta_map.items():
            for strat, levels in strategies.items():
                for lvl, idx in levels.items():
                    if action_map:
                        if idx not in action_map:
                            continue
                        idx = action_map[idx]
                    flat_idx_to_meta[idx] = {
                        'category': cat,
                        'strategy': strat,
                        'level': lvl
                    }

    buffer_len = len(buffer)
    if buffer_len == 0:
        raise ValueError("cannot evaluate a policy on an empty buffer")
    # randint needs high > low; short buffers simply always start at index 0.
    start_upper = max(1, buffer_len - max_steps)

    def state_at(index):
        # Return an independent numpy copy of the stored state.
        stored = buffer.states[index]
        if isinstance(stored, torch.Tensor):
            return stored.cpu().numpy()
        return stored.copy()

    total_returns = []
    episode_lengths = []
    termination_reasons = {"max_steps": 0, "done": 0, "buffer_end": 0}

    for ep in range(num_episodes):
        # sample a random starting state
        start_idx = np.random.randint(0, start_upper)
        state = state_at(start_idx)

        episode_return = 0
        episode_length = 0
        discount = 1.0

        for step in range(max_steps):
            # select action using the policy
            action_idx = agent.select_action(state)

            # IMPORTANT: the chosen action cannot change the transition (no
            # simulator), so the next state is simply the next buffer entry;
            # only the reward depends on the agent's action.
            next_idx = min(start_idx + step + 1, buffer_len - 1)
            next_state = state_at(next_idx)

            # check for done flag
            done = False
            if hasattr(buffer, 'dones') and len(buffer.dones) > next_idx:
                done = buffer.dones[next_idx]

            action_meta = None
            if action_meta_map:
                action_meta = flat_idx_to_meta.get(action_idx)

            # compute reward and update the discounted return
            reward = reward_fn(state, action_idx, next_state, action_meta)
            episode_return += discount * reward
            discount *= gamma
            episode_length += 1
            state = next_state

            # check for termination
            if done:
                termination_reasons["done"] += 1
                break

            # reached max number of steps (some episodes were too long, like > 180 turns)
            if step >= max_steps - 1:
                termination_reasons["max_steps"] += 1
                break

            # reached the end of the buffer
            if next_idx >= buffer_len - 1:
                termination_reasons["buffer_end"] += 1
                break

        total_returns.append(episode_return)
        episode_lengths.append(episode_length)

    mean_return = np.mean(total_returns)
    std_return = np.std(total_returns)

    print(f"\nEvaluation over {num_episodes} episodes:")
    print(f"  Mean return: {mean_return:.4f} ± {std_return:.4f}")
    print(f"  Mean episode length: {np.mean(episode_lengths):.2f}")
    print(f"  Termination reasons: {termination_reasons}")

    return mean_return, total_returns

In [None]:
"""
This is a class to wrap the BC model to make it more compatible with the evaluation framework when I do the BC-init (warm start).
1. Takes a trained BC model and an action mapping dictionary
2. Has a select_action method that takes a state and returns an action index
"""
class BCPolicyWrapper:
    """Adapter giving a BC classifier the ``select_action`` interface.

    ``q_net`` is an alias of ``model``. ``inv_map`` is the inverse of
    ``action_map``; ``select_action`` returns the model's own output index and
    does not translate it through ``inv_map``.
    """

    def __init__(self, model, action_map, device="cpu"):
        self.device = device
        self.model = model.to(device)
        self.q_net = self.model  # alias for compatibility
        self.inv_map = dict((new_id, old_id) for old_id, new_id in action_map.items())

    def select_action(self, state):
        """Return the index of the largest logit for ``state``."""
        if isinstance(state, torch.Tensor):
            # add the batch dimension only when it is missing
            batch = state.unsqueeze(0) if state.dim() == 1 else state
            batch = batch.to(self.device)
        else:
            batch = torch.FloatTensor(state).unsqueeze(0).to(self.device)

        with torch.no_grad():
            best = self.model(batch).argmax(dim=1)
        return best.item()

def main():
    """Load ``data.csv`` and build the transition buffer with ``hybrid_reward``.

    The resulting buffer and mapping are local to this function and are not
    returned.
    """
    csv_file_path = "data.csv"

    # predefined action mappings: original action id -> contiguous index (50 entries)
    action_map = {0: 0, 1: 1, 2: 2, 3: 3, 5: 4, 6: 5, 7: 6, 8: 7, 11: 8, 12: 9,
              13: 10, 16: 11, 17: 12, 18: 13, 20: 14, 21: 15, 22: 16, 23: 17,
              26: 18, 27: 19, 28: 20, 31: 21, 32: 22, 36: 23, 37: 24, 38: 25,
              41: 26, 42: 27, 43: 28, 45: 29, 46: 30, 47: 31, 48: 32, 54: 33,
              55: 34, 56: 35, 57: 36, 58: 37, 59: 38, 60: 39, 65: 40, 66: 41,
              67: 42, 70: 43, 71: 44, 72: 45, 73: 46, 75: 47, 76: 48, 77: 49}

    # category -> strategy -> level -> action id. The ids here are the keys of
    # ``action_map`` above (original ids), except that id 3 has no entry here.
    action_meta_map = {
        "Focus": {
            "Seek Next Step": {1: 0, 2: 1, 3: 2},
            "Confirm Calculation": {1: 5, 2: 6, 3: 7, 4: 8},
            "Re-direct to Sub-Problem": {2: 11, 3: 12, 4: 13},
            "Highlight Missing Info": {2: 16, 3: 17, 4: 18}
        },
        "Probing": {
            "Ask for Explanation": {1: 20, 2: 21, 3: 22, 4: 23},
            "Seek Self-Correction": {2: 26, 3: 27, 4: 28},
            "Hypothetical Variation": {2: 31, 3: 32},
            "Check Understanding/Concept": {2: 36, 3: 37, 4: 38},
            "Encourage Comparison": {2: 41, 3: 42, 4: 43}
        },
        "Telling": {
            "Partial Reveal (Strategy)": {1: 45, 2: 46, 3: 47, 4: 48},
            "Full Reveal (Answer)": {1: 54, 2: 55, 3: 56, 4: 57, 5: 58, 6: 59},
            "Corrective Explanation": {1: 60}
        },
        "Generic": {
            "Acknowledgment/Praise": {1: 65, 2: 66, 3: 67},
            "Summarize Progress": {1: 70, 2: 71, 3: 72, 4: 73},
            "General Inquiry/Filler": {1: 75, 2: 76, 3: 77}
        }
    }

    df = pd.read_csv(csv_file_path)
    # build transition buffer (no episode column is supplied)
    buffer, action_mapping = build_transition_buffer(
        df=df,
        reward_fn=hybrid_reward,
        meta_map=action_meta_map,
        orig_to_idx=action_map,
        episode_column=None
    )


# Run main() only when executed as a script, not on import.
if __name__ == "__main__":
    main()

In [None]:
"""
Initialize a CQL agent using weights from a pre-trained BC model
Motivation: to give CQL a good starting point, since BC has already learned meaningful state representations (warm start)
"""
def convert_bc_to_cql_model(bc_model, action_dim):
    """Create a ``CQLAgent`` whose two hidden layers start from ``bc_model``.

    The first two linear layers (``net[0]`` and ``net[2]``) are copied; the
    output layer keeps its fresh initialisation. The target network is then
    synchronised with the warm-started Q-network.

    Args:
        bc_model: Trained model exposing ``net`` with linear layers at
            indices 0 and 2.
        action_dim: Number of actions for the new Q-network.

    Returns:
        The warm-started ``CQLAgent``.
    """
    bc_first = bc_model.net[0]
    # Size the Q-network from the BC model itself; relying on the agent's
    # default width makes the copy fail for any BC model not 256 units wide.
    cql_agent = CQLAgent(
        state_dim=bc_first.in_features,
        action_dim=action_dim,
        hidden_dim=bc_first.out_features
    )
    with torch.no_grad():
        for layer_idx in (0, 2):
            source = bc_model.net[layer_idx]
            target = cql_agent.q_net.net[layer_idx]
            target.weight.copy_(source.weight)
            target.bias.copy_(source.bias)
    cql_agent.target_q_net.load_state_dict(cql_agent.q_net.state_dict())
    return cql_agent


class QNetwork(nn.Module):
    """Feed-forward network mapping a state to one Q-value per action.

    Layout of ``self.net``: Linear, ReLU, Linear, ReLU, Linear (linear layers
    at indices 0, 2 and 4).
    """

    def __init__(self, input_dim, output_dim, hidden_dim=256):
        super().__init__()
        first_hidden = nn.Linear(input_dim, hidden_dim)
        second_hidden = nn.Linear(hidden_dim, hidden_dim)
        q_head = nn.Linear(hidden_dim, output_dim)
        self.net = nn.Sequential(
            first_hidden,
            nn.ReLU(),
            second_hidden,
            nn.ReLU(),
            q_head,
        )

    def forward(self, x):
        q_values = self.net(x)
        return q_values
"""
Why CQL?
It's offline and good to ensure a conservatively learned policy - don't overestimate out of distribution action rewards.

Vanilla Q-learning overestimates values for unseen (s,a) pairs, so CQL adds a regularization term that penalizes Q values for actions not seen in the dataset, making the policy more conservative. This is important in educational settings, where the policy can affect student learning outcomes. Also, we don't have a good simulator to interact with the environment.
"""
class CQLAgent:
    """Discrete-action Conservative Q-Learning agent.

    Holds an online Q-network, a target Q-network updated by Polyak (soft)
    averaging after every gradient step, an Adam optimiser and per-update
    training statistics.
    """

    def __init__(
        self,
        state_dim,
        action_dim,
        hidden_dim=256,
        lr=3e-4,
        gamma=0.98,
        tau=0.005,
        cql_alpha=1.0,
        device="cuda" if torch.cuda.is_available() else "cpu"
    ):
        """Build both networks and the optimiser.

        Args:
            state_dim: Size of the state vector.
            action_dim: Number of discrete actions.
            hidden_dim: Width of both hidden layers.
            lr: Adam learning rate.
            gamma: Discount factor used in the TD target.
            tau: Soft-update rate for the target network.
            cql_alpha: Weight of the CQL regulariser in the total loss.
            device: Device on which the networks live.
        """
        self.action_dim = action_dim
        self.gamma = gamma
        self.tau = tau
        self.cql_alpha = cql_alpha
        self.device = device

        # need 2 Q-networks, target network used for stable learning
        self.q_net = QNetwork(state_dim, action_dim, hidden_dim).to(device)
        self.target_q_net = QNetwork(state_dim, action_dim, hidden_dim).to(device)
        self.target_q_net.load_state_dict(self.q_net.state_dict())

        # the target network is never trained by gradient descent
        for param in self.target_q_net.parameters():
            param.requires_grad = False

        self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)

        self.training_stats = {
            'q_loss': [],
            'cql_loss': [],
            'total_loss': [],
            'avg_q_values': []
        }

    def select_action(self, state):
        """Return the greedy action index for a single state.

        Accepts a list, numpy array or tensor; a 1-D input gets a batch
        dimension added, an already-batched single state is used as-is.
        """
        with torch.no_grad():
            state = torch.as_tensor(state, dtype=torch.float32, device=self.device)
            if state.dim() == 1:
                state = state.unsqueeze(0)
            q_values = self.q_net(state)
            action = q_values.argmax(dim=1).item()
        return action

    # NOTE: The CQL loss implementation is adapted from
    # https://github.com/BY571/CQL/blob/main/CQL-DQN/agent.py
    # with some modifications for our specific use case.
    def update(self, batch):
        """Run one gradient step on ``td_loss + cql_alpha * cql_loss``.

        Args:
            batch: ``(states, actions, rewards, next_states, dones)`` tensors.
                ``dones`` is used arithmetically as ``1 - dones``.

        Returns:
            ``(td_loss, cql_loss, total_loss)`` as Python floats.
        """
        states, actions, rewards, next_states, dones = batch
        states = states.to(self.device)
        actions = actions.to(self.device)
        rewards = rewards.to(self.device)
        next_states = next_states.to(self.device)
        dones = dones.to(self.device)

        # compute Q values and target Q values
        q_values = self.q_net(states)
        with torch.no_grad():
            next_q_values = self.target_q_net(next_states)
            next_actions = next_q_values.argmax(dim=1)
            # squeeze(1), not squeeze(): a batch of size 1 must stay 1-D so it
            # lines up with ``rewards``/``dones`` without relying on broadcasting.
            next_q = next_q_values.gather(1, next_actions.unsqueeze(1)).squeeze(1)
            target_q = rewards + (1 - dones) * self.gamma * next_q

        # Standard Bellman (TD) error on the actions present in the dataset.
        q_values_sampled = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
        td_loss = F.mse_loss(q_values_sampled, target_q)

        # CQL regulariser. logsumexp over actions,
        #   log(exp(x_1) + ... + exp(x_n)),
        # is a smooth, differentiable approximation of the max. Minimising
        # (logsumexp - Q of the dataset action) pushes Q-values of the other
        # actions down relative to the action recorded in the dataset.
        logsumexp_q = torch.logsumexp(q_values, dim=1)
        cql_loss = (logsumexp_q - q_values_sampled).mean()

        # total loss
        loss = td_loss + self.cql_alpha * cql_loss

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Soft (Polyak) target update after every step: the target moves a
        # fraction ``tau`` towards the online network.
        for param, target_param in zip(self.q_net.parameters(), self.target_q_net.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

        self.training_stats['q_loss'].append(td_loss.item())
        self.training_stats['cql_loss'].append(cql_loss.item())
        self.training_stats['total_loss'].append(loss.item())
        self.training_stats['avg_q_values'].append(q_values.mean().item())

        return td_loss.item(), cql_loss.item(), loss.item()

    def save(self, path):
        """Write networks, optimiser state and training statistics to ``path``."""
        torch.save({
            'q_net': self.q_net.state_dict(),
            'target_q_net': self.target_q_net.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'training_stats': self.training_stats
        }, path)

    def load(self, path):
        """Restore the state written by ``save``.

        Tensors are mapped onto ``self.device`` so a checkpoint saved on GPU
        can be loaded on a CPU-only machine.
        """
        checkpoint = torch.load(path, map_location=self.device)
        self.q_net.load_state_dict(checkpoint['q_net'])
        self.target_q_net.load_state_dict(checkpoint['target_q_net'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.training_stats = checkpoint['training_stats']

    def plot_training_curves(self):
        """Plot the four training statistics and save ``cql_training_curves.png``."""
        fig, axs = plt.subplots(2, 2, figsize=(15, 10))

        # (axes, stats key, title, y label) for each panel
        panels = [
            (axs[0, 0], 'q_loss', 'TD Loss', 'Loss'),
            (axs[0, 1], 'cql_loss', 'CQL Loss', 'Loss'),
            (axs[1, 0], 'total_loss', 'Total Loss', 'Loss'),
            (axs[1, 1], 'avg_q_values', 'Average Q Values', 'Q Value'),
        ]
        for ax, key, title, ylabel in panels:
            ax.plot(self.training_stats[key])
            ax.set_title(title)
            ax.set_xlabel('Updates')
            ax.set_ylabel(ylabel)

        plt.tight_layout()
        plt.savefig('cql_training_curves.png')
        plt.close()

def prepare_buffer_for_training(buffer, batch_size=128, action_map=None):
    """Wrap the buffer's transitions in a shuffling ``DataLoader``.

    Args:
        buffer: Object exposing ``states``, ``actions``, ``rewards``,
            ``next_states`` and ``dones``.
        batch_size: Mini-batch size.
        action_map: Optional mapping applied to every stored action; actions
            missing from the mapping fall back to index 0.

    Returns:
        ``DataLoader`` over float32 states, int64 actions, float32 rewards,
        float32 next states and float32 done flags.
    """
    def as_float(values):
        return torch.as_tensor(np.array(values), dtype=torch.float32)

    # map the original action IDs -> model indices
    if action_map:
        action_tensor = torch.tensor(
            [action_map.get(raw_action, 0) for raw_action in buffer.actions],
            dtype=torch.long,
        )
    else:
        action_tensor = torch.as_tensor(buffer.actions, dtype=torch.long)

    transitions = TensorDataset(
        as_float(buffer.states),
        action_tensor,
        as_float(buffer.rewards),
        as_float(buffer.next_states),
        as_float(buffer.dones),
    )
    return DataLoader(transitions, batch_size=batch_size, shuffle=True)


"""
Train a CQL agent on the given transition buffer
1. Prepare data from the transition buffer
2. Initialize the CQL agent
3. Train over multiple epochs
4. Periodically evaluate performance
5. Save the best model
"""
def train_cql(buffer, state_dim, action_dim, num_epochs=100, batch_size=128,
              hidden_dim=256, lr=3e-4, gamma=0.98, tau=0.005, cql_alpha=1.0, eval_buffer=None,
              eval_interval=5, model_save_path='cql_agent.pt', action_map=None, action_meta_map=None):
    """Train a ``CQLAgent`` on ``buffer`` and return it.

    Every ``eval_interval`` epochs the agent is evaluated with
    ``evaluate_policy``; the checkpoint with the highest mean return is
    written to ``model_save_path``. After training, the loss curves are
    plotted, the best checkpoint from THIS run (if any was written) is
    reloaded, and a final evaluation is printed.

    Args:
        buffer: Training transitions.
        state_dim, action_dim, hidden_dim, lr, gamma, tau, cql_alpha:
            Forwarded to ``CQLAgent``.
        num_epochs: Number of passes over the buffer.
        batch_size: Mini-batch size.
        eval_buffer: Buffer used for evaluation; defaults to ``buffer``.
        eval_interval: Evaluate every this many epochs.
        model_save_path: Where the best checkpoint is stored.
        action_map: Forwarded to ``prepare_buffer_for_training``.
        action_meta_map: Forwarded to every ``evaluate_policy`` call.

    Returns:
        The trained agent (best checkpoint if one was saved, otherwise the
        weights after the last epoch).
    """
    dataloader = prepare_buffer_for_training(buffer, batch_size, action_map)

    agent = CQLAgent(
        state_dim=state_dim,
        action_dim=action_dim,
        hidden_dim=hidden_dim,
        lr=lr,
        gamma=gamma,
        tau=tau,
        cql_alpha=cql_alpha
    )

    print(f"Training CQL agent with {len(buffer)} transitions")
    print(f"State dim: {state_dim}, Action dim: {action_dim}")
    print(f"CQL alpha: {cql_alpha}, Learning rate: {lr}")

    best_eval_return = -float('inf')
    checkpoint_saved = False  # True once this run has written model_save_path
    buf = eval_buffer if eval_buffer is not None else buffer
    for epoch in range(num_epochs):
        epoch_td_loss = 0
        epoch_cql_loss = 0
        epoch_total_loss = 0
        num_batches = 0

        agent.q_net.train()
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")

        for batch in progress_bar:
            td_loss, cql_loss, total_loss = agent.update(batch)
            epoch_td_loss += td_loss
            epoch_cql_loss += cql_loss
            epoch_total_loss += total_loss
            num_batches += 1

            progress_bar.set_postfix({
                'td_loss': td_loss,
                'cql_loss': cql_loss,
                'total_loss': total_loss
            })

        # calculate average losses (guard against an empty loader)
        denom = max(num_batches, 1)
        epoch_td_loss /= denom
        epoch_cql_loss /= denom
        epoch_total_loss /= denom
        print(f"\nEpoch {epoch+1}/{num_epochs} - TD Loss: {epoch_td_loss:.4f}, "
              f"CQL Loss: {epoch_cql_loss:.4f}, Total Loss: {epoch_total_loss:.4f}")

        # evaluation
        if (epoch + 1) % eval_interval == 0:
            mean_return, _ = evaluate_policy(
                agent=agent,
                buffer=buf,
                reward_fn=hybrid_reward,
                num_episodes=20,
                max_steps=80,
                action_meta_map=action_meta_map
            )
            print(f"Evaluation - Mean Return: {mean_return:.4f}")

            # save best model
            if mean_return > best_eval_return:
                best_eval_return = mean_return
                agent.save(model_save_path)
                checkpoint_saved = True
                print(f"New best model saved with return: {best_eval_return:.4f}")

    agent.plot_training_curves()
    # Only reload what this run wrote: with fewer epochs than eval_interval no
    # checkpoint exists, and a file left by an earlier run must not be used.
    if checkpoint_saved:
        agent.load(model_save_path)
    # Same reward definition as the periodic evaluations (including the
    # strategy bonus driven by action_meta_map).
    # NOTE(review): max_steps is 60 here but 80 above; confirm that is intended.
    final_mean_return, final_returns = evaluate_policy(
        agent=agent,
        buffer=buf,
        reward_fn=hybrid_reward,
        num_episodes=50,
        max_steps=60,
        action_meta_map=action_meta_map
    )
    print(f"\nFINAL: Mean Return: {final_mean_return:.4f}")

    return agent



In [None]:
def compute_confidence_interval(data, confidence=0.95):
    """Return the sample mean with a two-sided Student-t confidence interval.

    Args:
        data: Sized collection of numeric samples.
        confidence: Coverage of the interval (0.95 gives a 95% interval).

    Returns:
        Tuple ``(mean, lower, upper)`` where ``lower``/``upper`` are the mean
        minus/plus the t-based half-width of the interval.
    """
    degrees_of_freedom = len(data) - 1
    centre = np.mean(data)
    standard_error = stats.sem(data)
    # Two-sided interval: take the upper-tail quantile at (1 + confidence) / 2.
    t_critical = stats.t.ppf((1 + confidence) / 2, degrees_of_freedom)
    half_width = standard_error * t_critical
    return centre, centre - half_width, centre + half_width

class RandomPolicy:
    """Baseline policy that picks an action index uniformly at random."""

    def __init__(self, action_dim):
        # Number of discrete actions; sampled indices lie in [0, action_dim).
        self.action_dim = action_dim

    def select_action(self, state):
        """Return a random action index; ``state`` is ignored.

        Draws from NumPy's global random generator, so results follow
        ``np.random.seed``.
        """
        return np.random.randint(0, self.action_dim)

def evaluate_all_models_on_same_states(
    agents, labels, buffer, reward_fn,
    num_episodes=50, max_steps=80, action_meta_map=None, gamma=0.98
):
    """Score several policies on the same sampled start states.

    Start indices are drawn once and reused for every agent, so the returned
    arrays are paired episode-by-episode across agents. Each rollout walks
    forward through consecutive buffer entries regardless of the action the
    agent picks; the chosen action only influences the reward via
    ``reward_fn``.

    Args:
        agents: Policies exposing ``select_action(state)``.
        labels: One result key per agent, in the same order as ``agents``.
        buffer: Object supporting ``len()`` and ``states[i].copy()``; if it
            has a ``dones`` attribute, a truthy flag ends the rollout.
        reward_fn: Called as
            ``reward_fn(state, action_idx, next_state, action_meta)``.
        num_episodes: Number of start states to sample.
        max_steps: Maximum number of steps per rollout.
        action_meta_map: Nested ``{category: {strategy: {level: action_id}}}``
            mapping. When ``None`` no metadata is available and ``reward_fn``
            receives ``None`` as ``action_meta``.
        gamma: Per-step discount factor applied to rewards.

    Returns:
        Dict mapping each label to a NumPy array of length ``num_episodes``
        holding the discounted return of every episode.

    Raises:
        ValueError: If ``agents`` and ``labels`` differ in length or the
            buffer is empty.
    """
    agents = list(agents)
    labels = list(labels)
    if len(agents) != len(labels):
        # zip() would silently drop the extras and leave all-zero result rows.
        raise ValueError(
            f"agents and labels must have the same length "
            f"(got {len(agents)} and {len(labels)})"
        )

    # Invert the nested map into {action_id: metadata}. A missing map yields
    # an empty lookup instead of crashing on ``None.items()``.
    flat_idx_to_meta = {
        idx: {'category': cat, 'strategy': strat, 'level': lvl}
        for cat, strategies in (action_meta_map or {}).items()
        for strat, levels in strategies.items()
        for lvl, idx in levels.items()
    }
    # NOTE(review): the lookup is keyed by the ids stored in action_meta_map;
    # confirm that agent.select_action() returns ids in that same id space.

    num_states = len(buffer)
    if num_states == 0:
        raise ValueError("buffer is empty; cannot sample evaluation start states")
    last_idx = num_states - 1

    # Leave room for a full-length rollout when the buffer is long enough;
    # otherwise fall back to any index that still has a successor, instead of
    # letting randint fail on an empty range.
    if num_states > max_steps:
        start_high = num_states - max_steps
    else:
        start_high = max(last_idx, 1)

    # sample once so every agent sees identical start states
    start_indices = np.random.randint(0, start_high, size=num_episodes)

    # prepare storage
    all_returns = {lbl: np.zeros(num_episodes) for lbl in labels}
    has_dones = hasattr(buffer, 'dones')

    for ep_idx, start_idx in enumerate(start_indices):
        for agent, lbl in zip(agents, labels):
            state = buffer.states[start_idx].copy()
            episode_return = 0.0
            discount = 1.0

            for step in range(max_steps):
                action_idx = agent.select_action(state)
                # Clamp so the rollout never reads past the end of the buffer.
                next_idx = min(start_idx + step + 1, last_idx)
                next_state = buffer.states[next_idx].copy()
                # NOTE(review): termination is read from the *next* index;
                # confirm this matches how the buffer stores `dones`.
                done = bool(buffer.dones[next_idx]) if has_dones else False
                action_meta = flat_idx_to_meta.get(action_idx)
                r = reward_fn(state, action_idx, next_state, action_meta)

                episode_return += discount * r
                discount *= gamma
                state = next_state
                if done or next_idx >= last_idx:
                    break

            all_returns[lbl][ep_idx] = episode_return

    return all_returns

def collect_seed_stats(seed, agents, labels, buffer, reward_fn, action_meta_map, num_episodes=100, max_steps=60):
    """Evaluate all agents under one NumPy seed and summarise their returns.

    Seeds NumPy's global generator, evaluates every agent on the same start
    states, and reports a mean with 95% confidence interval per label plus
    paired t-tests for a fixed set of label pairs.

    Args:
        seed: Value passed to ``np.random.seed``.
        agents: Policies to evaluate.
        labels: Result keys, one per agent; must include ``'BC'``, ``'CQL'``,
            ``'BC_init'`` and ``'Random'`` for the paired tests below.
        buffer: Transition buffer supplying evaluation states.
        reward_fn: Reward function forwarded to the evaluator.
        action_meta_map: Action metadata map forwarded to the evaluator.
        num_episodes: Episodes evaluated per agent.
        max_steps: Step cap per episode.

    Returns:
        Dict with one ``{'mean': ..., 'ci': [lo, hi]}`` entry per label and a
        ``'paired_tests'`` entry mapping ``'<a>_vs_<b>'`` to
        ``{'t': ..., 'p': ...}``.
    """
    np.random.seed(seed)
    returns_by_label = evaluate_all_models_on_same_states(
        agents,
        labels,
        buffer=buffer,
        reward_fn=reward_fn,
        num_episodes=num_episodes,
        max_steps=max_steps,
        action_meta_map=action_meta_map
    )

    # per-method mean and 95% confidence interval
    summary = {}
    for label in labels:
        centre, lower, upper = compute_confidence_interval(returns_by_label[label])
        summary[label] = {'mean': centre, 'ci': [lower, upper]}

    # Paired tests are valid because every agent is scored on the same
    # episodes, so entry i of each array refers to the same start state.
    comparisons = (('BC', 'Random'), ('CQL', 'Random'), ('BC_init', 'Random'), ('CQL', 'BC'))
    paired_tests = {}
    for first, second in comparisons:
        t_stat, p_value = stats.ttest_rel(returns_by_label[first], returns_by_label[second])
        paired_tests[f'{first}_vs_{second}'] = {'t': t_stat, 'p': p_value}

    summary['paired_tests'] = paired_tests
    return summary

def main():
    """Run the full experiment end to end.

    Loads ``data.csv``, trains a behaviour-cloning model and two CQL agents
    on the training split, evaluates all policies (plus a random baseline) on
    the test buffer over five seeds, and saves the per-seed statistics to
    ``all_seeds.npy``.
    """
    # NOTE: no global NumPy/PyTorch seed is set in this function, so training
    # may differ between runs unless the helpers seed internally.

    # 1. load full dataset
    df = pd.read_csv("data.csv")

    # 2. fit one shared one-hot encoder on the full data so that every split
    #    is encoded with the same columns
    categorical_features = [
        'misconception_type',
        'previous_action_id',
        'listen_to_feedback',
        'correct_solution',
        'next_action_hint_strength',
    ]
    global_encoder = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
    global_encoder.fit(df[categorical_features])

    # 3. split into train/val/test at episode boundaries
    train_df, val_df, test_df = split_dataframe_by_episodes(
        df, test_size=0.2, val_size=0.1, random_state=22
    )

    # 4. train BC on TRAINING DATA only
    (train_states, val_states,
     train_actions, val_actions, num_actions) = prepare_data_for_bc(train_df, encoder=global_encoder)
    bc_model = train_bc_model(train_states, val_states, train_actions, val_actions, num_actions)

    # action mappings: original action IDs -> model indices (0-49)
    action_map = {0: 0, 1: 1, 2: 2, 3: 3, 5: 4, 6: 5, 7: 6, 8: 7, 11: 8, 12: 9,
                 13: 10, 16: 11, 17: 12, 18: 13, 20: 14, 21: 15, 22: 16, 23: 17,
                 26: 18, 27: 19, 28: 20, 31: 21, 32: 22, 36: 23, 37: 24, 38: 25,
                 41: 26, 42: 27, 43: 28, 45: 29, 46: 30, 47: 31, 48: 32, 54: 33,
                 55: 34, 56: 35, 57: 36, 58: 37, 59: 38, 60: 39, 65: 40, 66: 41,
                 67: 42, 70: 43, 71: 44, 72: 45, 73: 46, 75: 47, 76: 48, 77: 49}

    # category -> strategy -> level -> action ID
    action_meta_map = {
        "Focus": {
            "Seek Next Step": {1: 0, 2: 1, 3: 2},
            "Confirm Calculation": {1: 5, 2: 6, 3: 7, 4: 8},
            "Re-direct to Sub-Problem": {2: 11, 3: 12, 4: 13},
            "Highlight Missing Info": {2: 16, 3: 17, 4: 18}
        },
        "Probing": {
            "Ask for Explanation": {1: 20, 2: 21, 3: 22, 4: 23},
            "Seek Self-Correction": {2: 26, 3: 27, 4: 28},
            "Hypothetical Variation": {2: 31, 3: 32},
            "Check Understanding/Concept": {2: 36, 3: 37, 4: 38},
            "Encourage Comparison": {2: 41, 3: 42, 4: 43}
        },
        "Telling": {
            "Partial Reveal (Strategy)": {1: 45, 2: 46, 3: 47, 4: 48},
            "Full Reveal (Answer)": {1: 54, 2: 55, 3: 56, 4: 57, 5: 58, 6: 59},
            "Corrective Explanation": {1: 60}
        },
        "Generic": {
            "Acknowledgment/Praise": {1: 65, 2: 66, 3: 67},
            "Summarize Progress": {1: 70, 2: 71, 3: 72, 4: 73},
            "General Inquiry/Filler": {1: 75, 2: 76, 3: 77}
        }
    }

    def make_buffer(split_df):
        # Build one transition buffer with the shared reward function,
        # mappings and encoder; the second return value is not used here.
        split_buffer, _ = build_transition_buffer(
            df=split_df,
            reward_fn=hybrid_reward,
            meta_map=action_meta_map,
            orig_to_idx=action_map,
            encoder=global_encoder,
            categorical_features=categorical_features
        )
        return split_buffer

    # build the transition buffers (train/val/test) with the reward function
    train_buffer = make_buffer(train_df)
    val_buffer = make_buffer(val_df)
    test_buffer = make_buffer(test_df)

    # print buffer statistics for the train and test splits
    for split_name, split_buffer in (("Train", train_buffer), ("Test", test_buffer)):
        print(f"*{split_name} Buffer Statistics*")
        print(f"\tTotal transitions: {len(split_buffer)}")
        print(f"\tState dimension: {split_buffer.states.shape[1]}")
        print(f"\tNumber of unique actions: {len(set(split_buffer.actions))}")

    # settings shared by both CQL training runs
    shared_cql_kwargs = dict(
        buffer=train_buffer,
        eval_buffer=val_buffer,
        state_dim=train_buffer.states.shape[1],
        action_dim=num_actions,
        batch_size=128,
        hidden_dim=256,
        gamma=0.98,
        tau=0.005,
        action_map=action_map,
        action_meta_map=action_meta_map
    )

    # 1: Training CQL from scratch
    print("\nTraining CQL from scratch!")
    cql_agent = train_cql(
        num_epochs=100,
        lr=1e-4,
        cql_alpha=2.0,
        model_save_path='best_cql_model.pt',
        **shared_cql_kwargs
    )

    # 2: Initializing CQL with BC weights
    print("\nTraining BC-init CQL!")
    bc_initialized_cql = convert_bc_to_cql_model(bc_model, num_actions)
    bc_initialized_cql.save('bc_initialized_cql.pt')

    # NOTE(review): the converted model is saved to disk but is not passed to
    # train_cql below; confirm train_cql picks up the BC weights some other
    # way, otherwise this run also starts from scratch.
    bc_init_cql_agent = train_cql(
        num_epochs=40,
        lr=3e-5,
        cql_alpha=0.85,
        model_save_path='best_initialized_cql_model.pt',
        **shared_cql_kwargs
    )

    # policies to compare, including a uniformly random baseline
    bc_agent = BCPolicyWrapper(bc_model, action_map, device=cql_agent.device)
    random_agent = RandomPolicy(num_actions)

    agents = [bc_agent, cql_agent, bc_init_cql_agent, random_agent]
    labels = ['BC', 'CQL', 'BC_init', 'Random']

    # run over 5 seeds, always on the held-out test buffer
    all_seeds = []
    for seed in range(5):
        seed_stats = collect_seed_stats(
            seed,
            agents,
            labels,
            buffer=test_buffer,
            reward_fn=hybrid_reward,
            action_meta_map=action_meta_map
        )
        all_seeds.append(seed_stats)

    np.save('all_seeds.npy', all_seeds)

# Script entry point: run the full experiment only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()