## 策略模型定义

In [None]:
import torch
import torch.nn as nn

class LSTMPolicyNet(nn.Module):
    """LSTM policy that emits 2-way action logits at every step of a state sequence.

    Args:
        state_dim: number of input features per time step.
        lstm_hidden: hidden size of the LSTM.
        mlp_hidden: width of the hidden layer in the 2-layer MLP head.
        num_layers: number of stacked LSTM layers.
        dropout: dropout between stacked LSTM layers. nn.LSTM only applies it
            between layers, so it is forced to 0 when num_layers == 1.
    """

    def __init__(self, state_dim=10, lstm_hidden=128, mlp_hidden=64, num_layers=1, dropout=0.1):
        super().__init__()
        self.state_dim = state_dim
        self.lstm_hidden = lstm_hidden
        self.num_layers = num_layers

        self.lstm = nn.LSTM(
            input_size=state_dim,
            hidden_size=lstm_hidden,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
        )
        self.mlp = nn.Sequential(
            nn.Linear(lstm_hidden, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, 2),  # logits for the 2 actions
        )

    def forward(self, state_seq, hidden=None):
        """Run the LSTM + MLP head over a batch of sequences.

        Args:
            state_seq: [B, T, state_dim] (batch-first).
            hidden: optional (h0, c0) initial LSTM state, each [num_layers, B, lstm_hidden].

        Returns:
            (logits, hidden): logits is [B, T, 2], the per-step action scores;
            hidden is the final (h_n, c_n) LSTM state.
        """
        lstm_out, hidden = self.lstm(state_seq, hidden)  # [B, T, lstm_hidden]
        logits = self.mlp(lstm_out)                      # [B, T, 2]
        return logits, hidden

    def act(self, state_seq, hidden=None, deterministic=False):
        """Choose an action at every time step.

        Args:
            state_seq: [B, T, state_dim].
            hidden: optional (h0, c0) initial LSTM state.
            deterministic: take the argmax action instead of sampling.

        Returns:
            (actions, probs, hidden): actions is a [B, T] int64 tensor, probs is the
            [B, T, 2] softmax over the logits, hidden is the final LSTM state.
        """
        logits, hidden = self(state_seq, hidden)
        probs = torch.softmax(logits, dim=-1)
        if deterministic:
            actions = torch.argmax(logits, dim=-1)
        else:
            # Build the distribution from logits: it normalizes in log-space,
            # which is more stable than re-normalizing an already-softmaxed tensor.
            actions = torch.distributions.Categorical(logits=logits).sample()
        return actions, probs, hidden

    def reset_hidden(self, batch_size):
        """Return an all-zero (h0, c0) initial LSTM state for `batch_size` sequences.

        The tensors are created directly on the model's device and with the
        model's parameter dtype, so they match a model moved with .to()/.half()/.double().
        """
        ref = next(self.parameters())
        shape = (self.num_layers, batch_size, self.lstm_hidden)
        h0 = torch.zeros(shape, device=ref.device, dtype=ref.dtype)
        c0 = torch.zeros(shape, device=ref.device, dtype=ref.dtype)
        return (h0, c0)

## 加载ReplayDataset

In [1]:
import datasets

# On-disk replay buffer saved with `datasets`' save_to_disk; it is expected to
# contain a "train" split, which the cells that follow index into.
_REPLAY_PATH = "../data/scores_rb/shareGPT-llama3-d7-topk10-t1"
ReplayDataset = datasets.load_from_disk(_REPLAY_PATH)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Make column access return torch tensors (the `datasets` default returns Python objects).
ReplayDataset.set_format("torch")
# Sanity check: the per-action "stop" field should now come back as a torch.Tensor.
type(ReplayDataset["train"][0]["action_7"]["stop"])

torch.Tensor

In [2]:
# Inspect the "stop" probabilities stored for action_7 of the first training example.
# NOTE(review): the recorded output is a Python list, presumably captured before
# set_format("torch") was run (both cells are labelled In [2]).
ReplayDataset["train"][0]["action_7"]["stop"]

[0.0061492919921875,
 0.08392333984375,
 0.03704833984375,
 0.1397705078125,
 0.07598876953125,
 0.007415771484375,
 0.021484375,
 0.2437744140625,
 0.3843262195587158]

## 奖励函数

In [None]:
import numpy as np
# 1. Reward function (expected reward)
def reward_fn(action_idx, stop_probs, lam1=1.0, lam2=1.0):
    """Return the expected reward of choosing action index ``action_idx``.

    ``stop_probs[k]`` is treated as the probability mass of outcome ``k``.
    Outcomes 0 and 1 are merged into one bucket, leaving ``len(stop_probs) - 1``
    buckets indexed by ``l = 0, 1, ...``. Bucket ``l`` earns
    ``lam1 * l - lam2 * |action_idx - l|``, and the result is the
    probability-weighted sum of those rewards.

    Args:
        action_idx: chosen action index ``n``.
        stop_probs: 1-D sequence (list, numpy array or CPU torch tensor) with at
            least 2 entries. It is never modified.
        lam1: weight of the bucket-index term.
        lam2: weight of the ``|n - l|`` mismatch penalty.

    Returns:
        The expected reward as a Python float.

    Raises:
        ValueError: if ``stop_probs`` is not 1-D or has fewer than 2 entries.
    """
    # Copy into a fresh float64 array. Slicing an ndarray/tensor yields a view,
    # so merging in place on a slice would write into the caller's data.
    probs = np.array(stop_probs, dtype=np.float64)
    if probs.ndim != 1 or probs.size < 2:
        raise ValueError(
            f"stop_probs must be 1-D with at least 2 entries, got shape {probs.shape}"
        )
    merged = probs[1:]
    merged[0] += probs[0]  # merge the stop probabilities of outcomes 0 and 1 (our copy only)
    lengths = np.arange(merged.size)
    rewards = lam1 * lengths - lam2 * np.abs(action_idx - lengths)
    return float(np.dot(merged, rewards))

In [None]:
from torch.utils.data import Dataset
# 2. Trajectory dataset preprocessing
def decode_action_seq(action_idx, max_len=7):
    """Expand an action index into its per-step binary action sequence.

    action_0 -> [0], action_1 -> [1, 0], ..., action_{max_len-1} -> [1]*(max_len-1) + [0];
    any index >= max_len -> [1] * max_len (no trailing 0).
    Presumably 1 = "continue" and 0 = "stop" — confirm against the data producer.
    """
    ones = [1] * min(action_idx, max_len)
    if action_idx >= max_len:
        return ones
    return ones + [0]

def build_trajectories(dataset, max_len=7):
    """Expand every dataset row into ``max_len + 1`` reward-weighted trajectories.

    For each row and each action index ``i`` in ``0..max_len``:
      * "actions" is ``decode_action_seq(i, max_len)``;
      * "states" holds ``row["eagle_{t}_forward"]`` for ``t = 1..len(actions)``;
      * "reward" is ``reward_fn(i, row["action_{i}"]["stop"])``.

    NOTE(review): an inline comment used to claim the "stop" field has
    ``seq_len + 1`` entries, but the recorded notebook output shows 9 entries for
    action_7 (seq_len 7) — confirm the actual layout.
    """
    out = []
    for row in dataset:
        for action_idx in range(max_len + 1):
            actions = decode_action_seq(action_idx, max_len)
            out.append({
                "states": [row[f"eagle_{step}_forward"] for step in range(1, len(actions) + 1)],
                "actions": actions,
                "reward": reward_fn(action_idx, row[f"action_{action_idx}"]["stop"]),
            })
    return out

# 3. PyTorch Dataset and collate function
class TrajectoryDataset(Dataset):
    """Map-style dataset over the trajectory dicts produced by build_trajectories.

    Each item is a dict with:
        "states":  float32 tensor stacking the per-step states (first axis = step)
        "actions": int64 tensor of the per-step actions
        "reward":  0-d float32 tensor
    """

    def __init__(self, trajectories):
        self.trajectories = trajectories

    def __len__(self):
        return len(self.trajectories)

    @staticmethod
    def _states_to_tensor(states):
        """Stack a sequence of per-step states into one float32 tensor."""
        if isinstance(states, torch.Tensor):
            return states.to(torch.float32)
        # torch.tensor() cannot build from a list of multi-element tensors (what rows
        # contain after `set_format("torch")`), so stack those explicitly.
        if len(states) > 0 and isinstance(states[0], torch.Tensor):
            return torch.stack([s.to(torch.float32) for s in states])
        return torch.tensor(states, dtype=torch.float32)

    def __getitem__(self, idx):
        traj = self.trajectories[idx]
        return {
            "states": self._states_to_tensor(traj["states"]),
            "actions": torch.as_tensor(traj["actions"], dtype=torch.long),
            # as_tensor accepts Python floats, numpy scalars and 0-d tensors alike,
            # without torch.tensor()'s copy-construct warning on tensor inputs.
            "reward": torch.as_tensor(traj["reward"], dtype=torch.float32),
        }

def collate_fn(batch):
    """Pad a list of TrajectoryDataset items into dense batch tensors.

    Returns:
        (states, actions, rewards, mask):
            states  -- float32 [B, T_max, state_dim], zero-padded; 1-D per-item
                       state sequences are treated as state_dim == 1
            actions -- int64 [B, T_max], zero-padded
            rewards -- [B], the stacked per-trajectory rewards
            mask    -- bool [B, T_max], True on real (non-padded) steps
    """
    rewards = torch.stack([item["reward"] for item in batch])
    seqs = [item["states"] for item in batch]
    acts = [item["actions"] for item in batch]
    lengths = torch.tensor([seq.shape[0] for seq in seqs])

    # Give 1-D sequences (one scalar per step) a trailing feature axis of size 1.
    seqs_2d = [seq if seq.dim() > 1 else seq.unsqueeze(-1) for seq in seqs]
    padded_states = torch.nn.utils.rnn.pad_sequence(seqs_2d, batch_first=True).to(torch.float32)
    padded_actions = torch.nn.utils.rnn.pad_sequence(acts, batch_first=True).to(torch.long)

    step_ids = torch.arange(padded_states.shape[1])
    mask = step_ids.unsqueeze(0) < lengths.unsqueeze(1)
    return padded_states, padded_actions, rewards, mask


## 训练代码

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Offline policy-gradient training: every (row, action_i) pair from the replay
# buffer becomes one trajectory whose log-likelihood is weighted by its expected reward.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_set = ReplayDataset["train"]
train_trajectories = build_trajectories(train_set)
train_dataset = TrajectoryDataset(train_trajectories)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)

# Feature count of one step's state; works for lists, tensors and scalar states
# (a scalar state counts as 1 feature, matching collate_fn).
state_dim = torch.as_tensor(train_trajectories[0]["states"][0]).numel()
model = LSTMPolicyNet(state_dim=state_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    num_batches = 0
    for states, actions, rewards, mask in train_loader:
        states = states.to(device)
        actions = actions.to(device)
        rewards = rewards.to(device)
        mask = mask.to(device)

        # forward returns (logits, hidden); only the per-step logits are needed.
        logits, _ = model(states)                                                   # [B, T, 2]
        log_probs = nn.functional.log_softmax(logits, dim=-1)
        action_log_probs = log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)  # [B, T]
        # Zero out padded steps so they add nothing to the trajectory log-likelihood.
        masked_action_log_probs = action_log_probs * mask
        # REINFORCE-style loss: sum log-probs over each trajectory, weight by its reward.
        # NOTE(review): reward_fn can return negative values and no baseline is
        # subtracted, so this objective may be unbounded below for such trajectories;
        # consider subtracting a baseline (e.g. rewards.mean()).
        loss = -(masked_action_log_probs.sum(dim=1) * rewards).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    # Report the mean per-batch loss so the value does not scale with dataset size.
    print(f"Epoch {epoch+1}, Loss: {total_loss / max(num_batches, 1):.4f}")

torch.save(model.state_dict(), "lstm_policy_expected_reward.pt")