In [197]:
import copy
import torch
import torch.nn.functional as F
import pandas as pd
import ast
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.optim as optim
import numpy as np
from tqdm import tqdm
from torch import nn
import gymnasium as gym
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3 import PPO

$\mathcal{L}_{\text{DPO}}(\pi_\theta; \pi_{\text{ref}}) = -\mathbb{E}_{(s, y_w, y_l) \sim \mathcal{D}} \left[ \log \sigma\left( \beta \left( \log \frac{\pi_\theta(y_w \mid s)}{\pi_{\text{ref}}(y_w \mid s)} - \log \frac{\pi_\theta(y_l \mid s)}{\pi_{\text{ref}}(y_l \mid s)} \right) \right) \right]
$

$reward = \beta \cdot \log\left( \frac{\pi_{\text{ref}}(a \mid x)}{\pi_\theta(a \mid x)} \right)
$

In [4]:
def dpo_loss(pi_logps_w, pi_logps_l, ref_logps_w, ref_logps_l, beta):
    """
    pi_logps_w: log πθ(y_w | x), shape [B]
    pi_logps_l: log πθ(y_l | x), shape [B]
    ref_logps_w: log πref(y_w | x), shape [B]
    ref_logps_l: log πref(y_l | x), shape [B]
    beta: temperature scaling the KL divergence (float)
    
    Returns:
    - losses: Tensor of shape [B]
    - rewards: Diagnostic reward signal, shape [B]
    """
    # KL-adjusted log-ratio difference
    pi_logratios = pi_logps_w - pi_logps_l
    ref_logratios = ref_logps_w - ref_logps_l

    logits = beta * (pi_logratios - ref_logratios)
    losses = -F.logsigmoid(logits)

    # Diagnostic reward (not used for gradient)
    rewards = beta * (pi_logps_w - ref_logps_w).detach()

    return losses.mean(), rewards

$\log \pi_\theta(\tau |s) = \sum_{t=1}^{T} \log \pi_\theta(a_t \mid s_t)$




In [None]:

#wrapper for policy in DPO 
#pi: a policy stochastic passed with shape n_states times n_actions
class PolicyWrapper_tabular:
    def __init__(self, pi):
        """
        pi: Tensor of shape [num_states, num_actions] with action probabilities
            Make sure it's a valid probability distribution (rows sum to 1)
        """
        self.pi = pi.clamp(min=1e-8)  # prevent log(0)

    def log_prob_trajectory(self, state, trajectories):
        """
        trajectories: List of length B (batch), each a list of (s_t, a_t) pairs
        Returns: Tensor of shape [B] with summed log-probabilities per trajectory
        """
        batch_logps = list()
        for traj in trajectories:
            logps = [torch.log(self.pi[s, a]) for (s, a) in traj]
            traj_logp = torch.stack(logps).sum()
            batch_logps.append(traj_logp)
        return torch.stack(batch_logps)  # shape [B]
        
        

In [194]:
#now suppose policy is NN : 

class Policy(nn.Module): # definie the policy network
    def __init__(self, state_size=4, action_size=2, hidden_size=32, device='cpu'):
        super(Policy, self).__init__()
        self.fc1 = nn.Linear(state_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, action_size)
        self.device = device
        self.to(device)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = self.fc2(x)
        return F.softmax(x, dim=1) # we just consider 1 dimensional probability of action

    def predict(self, state, deterministic=True):
        """
        SB3-compatible predict method.
        Inputs:
            state: np.ndarray or torch.Tensor of shape [state_dim] or [1, state_dim]
            deterministic: if True, select action with highest probability
        Returns:
            action (int), state (None)
        """
        if isinstance(state, np.ndarray):
            state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
        elif isinstance(state, torch.Tensor) and state.dim() == 1:
            state = state.unsqueeze(0).float().to(self.device)
        else:
            state = state.float().to(self.device)

        with torch.no_grad():
            probs = self.forward(state).squeeze(0)
            if deterministic:
                action = torch.argmax(probs).item()
            else:
                action = Categorical(probs).sample().item()
        return action, None
    
    def log_prob_trajectory(self, state, trajectories):
        logps_batch = list()
        for traj in trajectories:
            logps = list()
            for s, a in traj:
                s_tensor = torch.tensor(s, dtype=torch.float32).unsqueeze(0).to(self.device)
                probs = self.forward(s_tensor).squeeze(0)
                log_prob = torch.log(probs[a])
                logps.append(log_prob)
            logps_batch.append(torch.stack(logps).sum())
        return torch.stack(logps_batch)
    

In [None]:

#pi should be inherit from nn.Module
#optimizer should be initialize with pi parameters
def update_policy_with_DPO(pi, pref_dataset,nber_epoch, beta, optimizer):
    """
    pi_ref : reference policy
    pref_dataset : preference dataset, dataloader that return state , trajectoryw, trajectoryl
    nber_epoch : number of epoch
    beta the hyperparameter corresponding to the importance of KL div
    optimize for the SGD update
    return the updated policy with DPO
    """
    pi_ref = copy.deepcopy(pi) # at first pi is initialized to pi_ref
    for epoch in range(nber_epoch):
        pbar = tqdm(pref_dataset, desc=f"Epoch {epoch+1}/{nber_epoch}")
        for batch in pbar:
            state, t_w, t_l = batch  # Lists of batch_size elements

            optimizer.zero_grad()

            # t_w and t_l are List[List[Tuple[state, action]]]
            pi_logps_w = pi.log_prob_trajectory(state, t_w)
            pi_logps_l = pi.log_prob_trajectory(state, t_l)

            with torch.no_grad():
                pi_ref_logps_w = pi_ref.log_prob_trajectory(state, t_w)
                pi_ref_logps_l = pi_ref.log_prob_trajectory(state, t_l)

            loss, _ = dpo_loss(pi_logps_w, pi_logps_l, pi_ref_logps_w, pi_ref_logps_l, beta)
            loss.backward()
            optimizer.step()

            # Optionally show loss in tqdm
            pbar.set_postfix({'loss': loss.item()})
    

In [192]:
import pandas as pd
import ast
from torch.utils.data import Dataset

class PrefDataset(Dataset):
    def __init__(self, df):
        self.samples = []
        for _, row in df.iterrows():
            # Convert initial state to tensor
            state = torch.tensor(row['initial_state'], dtype=torch.float32)

            # Convert each step in trajectories to (state_tensor, action_tensor)
            traj_w = [(torch.tensor(step['state'], dtype=torch.float32),
                       torch.tensor(step['action'], dtype=torch.long)) 
                      for step in row['preferred']]

            traj_l = [(torch.tensor(step['state'], dtype=torch.float32),
                       torch.tensor(step['action'], dtype=torch.long)) 
                      for step in row['rejected']]

            self.samples.append((state, traj_w, traj_l))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]
    
def collate_fn(batch):
    # batch is a list of tuples: (initial_state_tensor, traj_w_tensor_list, traj_l_tensor_list)
    initial_states, traj_ws, traj_ls = zip(*batch)
    return list(initial_states), list(traj_ws), list(traj_ls)


In [203]:
# Load and parse CSV
df = pd.read_csv('trajectory_pairs.csv')
for col in ['initial_state', 'preferred', 'rejected']:
    df[col] = df[col].apply(ast.literal_eval)
# Instantiate dataset and dataloader
dataset = PrefDataset(df)
dataloader = DataLoader(
    dataset,
    batch_size=5,
    shuffle=True,
    collate_fn=collate_fn
)
# Initialize policy
policy = Policy()
# Initialize optimizer
optimizer = optim.Adam(policy.parameters(), lr=1e-4)
# training
update_policy_with_DPO(pi=policy, pref_dataset=dataloader,nber_epoch=20, beta=1, optimizer=optimizer)

  s_tensor = torch.tensor(s, dtype=torch.float32).unsqueeze(0).to(self.device)
Epoch 1/20: 100%|██████████| 20/20 [00:03<00:00,  5.38it/s, loss=0.608]
Epoch 2/20: 100%|██████████| 20/20 [00:04<00:00,  4.52it/s, loss=0.62] 
Epoch 3/20: 100%|██████████| 20/20 [00:04<00:00,  4.92it/s, loss=0.568]
Epoch 4/20: 100%|██████████| 20/20 [00:03<00:00,  5.23it/s, loss=0.565]
Epoch 5/20: 100%|██████████| 20/20 [00:03<00:00,  5.11it/s, loss=0.404]
Epoch 6/20: 100%|██████████| 20/20 [00:03<00:00,  5.43it/s, loss=0.533]
Epoch 7/20: 100%|██████████| 20/20 [00:03<00:00,  5.42it/s, loss=0.412]
Epoch 8/20: 100%|██████████| 20/20 [00:04<00:00,  4.87it/s, loss=0.185]
Epoch 9/20: 100%|██████████| 20/20 [00:03<00:00,  5.43it/s, loss=0.259]
Epoch 10/20: 100%|██████████| 20/20 [00:03<00:00,  5.07it/s, loss=0.601]
Epoch 11/20: 100%|██████████| 20/20 [00:03<00:00,  5.16it/s, loss=0.366]
Epoch 12/20: 100%|██████████| 20/20 [00:04<00:00,  4.80it/s, loss=0.589]
Epoch 13/20: 100%|██████████| 20/20 [00:03<00:00,  5.1

In [None]:
#example of evaluation
env = gym.make('CartPole-v0')
env.reset(seed=0)
evaluate_policy(policy, env, n_eval_episodes=10)