## 策略模型定义

In [None]:
import torch
import torch.nn as nn

class LSTMPolicyNet(nn.Module):
    """Recurrent policy: an LSTM encoder followed by an MLP head that emits
    per-timestep action logits.

    Args:
        state_dim: size of each per-step state vector (LSTM input size).
        lstm_hidden: LSTM hidden-state size.
        mlp_hidden: width of the MLP head's hidden layer.
        num_layers: number of stacked LSTM layers.
        dropout: dropout between stacked LSTM layers; only passed to the LSTM
            when ``num_layers > 1`` because nn.LSTM applies it to every layer
            except the last, so it has no effect on a single layer.
        action_dim: number of discrete actions, i.e. logits per timestep.
            Defaults to 2, the head size this class originally hard-coded.
    """

    def __init__(self, state_dim=10, lstm_hidden=128, mlp_hidden=64, num_layers=1, dropout=0.1, action_dim=2):
        super().__init__()
        self.state_dim = state_dim
        self.lstm_hidden = lstm_hidden
        self.num_layers = num_layers
        self.action_dim = action_dim

        self.lstm = nn.LSTM(
            input_size=state_dim,
            hidden_size=lstm_hidden,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
        )
        # Module layout (mlp.0 / mlp.2) is unchanged, so saved state_dicts still load.
        self.mlp = nn.Sequential(
            nn.Linear(lstm_hidden, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, action_dim),  # one logit per action
        )

    def forward(self, state_seq, hidden=None):
        """Compute action logits for every timestep.

        Args:
            state_seq: [B, T, state_dim] batch-first state sequence.
            hidden: optional (h0, c0) initial LSTM state, each
                [num_layers, B, lstm_hidden]; nn.LSTM uses zeros when None.

        Returns:
            (logits, hidden): logits is [B, T, action_dim]; hidden is the final
            (h_n, c_n) LSTM state, which can be fed back to continue a sequence.
        """
        lstm_out, hidden = self.lstm(state_seq, hidden)  # lstm_out: [B, T, lstm_hidden]
        logits = self.mlp(lstm_out)                      # logits: [B, T, action_dim]
        return logits, hidden

    def act(self, state_seq, hidden=None, deterministic=False):
        """Select an action at every timestep.

        Args:
            state_seq: [B, T, state_dim] state sequence.
            hidden: optional (h0, c0) initial LSTM state.
            deterministic: if True take the argmax action; otherwise sample
                from the categorical distribution over actions.

        Returns:
            (actions, probs, hidden): actions is [B, T] (long), probs is
            [B, T, action_dim], hidden is the final LSTM state.
        """
        logits, hidden = self.forward(state_seq, hidden)
        probs = torch.softmax(logits, dim=-1)
        if deterministic:
            actions = torch.argmax(probs, dim=-1)
        else:
            actions = torch.distributions.Categorical(probs=probs).sample()
        return actions, probs, hidden

    def reset_hidden(self, batch_size):
        """Return a zero (h0, c0) LSTM state for ``batch_size`` sequences.

        Each tensor has shape [num_layers, batch_size, lstm_hidden]. It is
        created directly on the model's device and in the model's dtype, so it
        matches the LSTM weights.
        """
        ref = next(self.parameters())
        shape = (self.num_layers, batch_size, self.lstm_hidden)
        h0 = torch.zeros(shape, device=ref.device, dtype=ref.dtype)
        c0 = torch.zeros(shape, device=ref.device, dtype=ref.dtype)
        return (h0, c0)

## 加载ReplayDataset

In [1]:
import datasets

# Load the pre-computed replay buffer from disk. The repr recorded later in this
# notebook shows a DatasetDict with "train" and "test" splits.
ReplayDataset = datasets.load_from_disk("../data/scores_rb/shareGPT-llama3-d7-topk10-t1")

In [2]:
# Make row/column access return torch tensors instead of Python lists.
ReplayDataset.set_format("torch")
# Sanity check: the nested "stop" field of an action column is now a tensor.
type(ReplayDataset["train"][0]["action_7"]["stop"])

torch.Tensor

In [None]:
# Inspect the "stop" values stored for action_7 of the first training sample.
# NOTE(review): the output recorded under this cell is the DatasetDict repr of
# ReplayDataset, not the value of this expression — it looks stale; re-run to refresh.
ReplayDataset["train"][0]["action_7"]["stop"]

DatasetDict({
    train: Dataset({
        features: ['eagle_1_forward', 'eagle_2_forward', 'eagle_3_forward', 'eagle_4_forward', 'eagle_5_forward', 'eagle_6_forward', 'eagle_7_forward', 'eagle_8_forward', 'action_0', 'action_1', 'action_2', 'action_3', 'action_4', 'action_5', 'action_6', 'action_7'],
        num_rows: 150726
    })
    test: Dataset({
        features: ['eagle_1_forward', 'eagle_2_forward', 'eagle_3_forward', 'eagle_4_forward', 'eagle_5_forward', 'eagle_6_forward', 'eagle_7_forward', 'eagle_8_forward', 'action_0', 'action_1', 'action_2', 'action_3', 'action_4', 'action_5', 'action_6', 'action_7'],
        num_rows: 37682
    })
})

In [None]:
def decode_action_seq(action_idx, max_len=7):
    """Decode an action-column index into its per-step action list.

    An index i below ``max_len`` yields i ones followed by a single zero:
    action_0 -> [0], action_1 -> [1, 0], and so on. Any index >= ``max_len``
    yields ``max_len`` ones and no trailing zero, e.g. action_7 -> [1] * 7.
    Presumably 1 means "continue" and 0 means "stop" — confirm against the
    dataset's "stop" field.
    """
    if action_idx >= max_len:
        return [1] * max_len
    continues = [1] * action_idx
    return continues + [0]

def build_trajectories(dataset, reward_fun, max_len=7):
    """Expand every ReplayDataset sample into ``max_len + 1`` trajectories.

    For each sample and each action column ``action_{i}`` (i = 0..max_len),
    one trajectory dict is produced with:
      - 'actions': ``decode_action_seq(i, max_len)``
      - 'states':  ``sample[f"eagle_{t}_forward"]`` for t = 1..len(actions),
                   i.e. one state per action
      - 'reward':  ``reward_fun(i, sample[f"action_{i}"]["stop"])``

    NOTE(review): the exact shape/meaning of the "stop" field is not visible
    here; the original comment described it as a probability distribution of
    length seq_len + 1 — confirm against the dataset.

    Returns:
        list of dicts with keys 'states', 'actions', 'reward'.
    """
    trajectories = []
    for sample in dataset:
        for action_idx in range(max_len + 1):
            actions = decode_action_seq(action_idx, max_len)
            states = [sample[f"eagle_{step}_forward"] for step in range(1, len(actions) + 1)]
            stop_probs = sample[f"action_{action_idx}"]["stop"]
            trajectories.append({
                'states': states,
                'actions': actions,
                'reward': reward_fun(action_idx, stop_probs),
            })
    return trajectories

## 奖励函数

In [None]:
def reward_fun(action_idx, stop_probs):
    """Reward for the trajectory decoded from action column ``action_{action_idx}``.

    Placeholder that must be implemented before building trajectories.
    ``build_trajectories`` calls it as ``reward_fun(i, stop_probs)`` with
    ``i`` in ``0..max_len`` and ``stop_probs = sample[f"action_{i}"]["stop"]``.
    The result is stacked with other trajectories' rewards into a float32
    batch tensor, so it should presumably be a scalar.

    Raises:
        NotImplementedError: always, until a real reward is defined. (The
            previous zero-argument stub failed with a TypeError at the call
            site instead.)
    """
    raise NotImplementedError(
        "reward_fun(action_idx, stop_probs) is a placeholder; define the reward before calling build_trajectories"
    )

In [None]:
# Expand every training sample into one trajectory per action column.
# NOTE(review): build_trajectories calls reward_fun(action_idx, stop_probs);
# reward_fun must accept those two arguments and presumably return a scalar.
train_set = ReplayDataset["train"]
train_trajectories = build_trajectories(train_set, reward_fun)

## 训练代码

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

class TrajectoryDataset(torch.utils.data.Dataset):
    """Map-style dataset over trajectory dicts with 'states', 'actions', 'reward'.

    Each item is returned as tensors:
      - "states":  [T, state_dim] float32
      - "actions": [T] long
      - "reward":  0-d float32
    """

    def __init__(self, trajectories):
        self.trajectories = trajectories

    def __len__(self):
        return len(self.trajectories)

    def __getitem__(self, idx):
        traj = self.trajectories[idx]
        # 'states' may be a list of per-step tensors (the source HF dataset is in
        # torch format) or a list of lists. torch.tensor() rejects a list of
        # multi-element tensors, so convert each step and stack along a new
        # time dimension.
        states = torch.stack([torch.as_tensor(step, dtype=torch.float32) for step in traj["states"]])
        return {
            "states": states,
            "actions": torch.as_tensor(traj["actions"], dtype=torch.long),
            # as_tensor avoids torch.tensor's copy-construct warning when the
            # reward is already a tensor.
            "reward": torch.as_tensor(traj["reward"], dtype=torch.float32),
        }

def collate_fn(batch):
    """Pad a list of trajectory items into dense batch tensors.

    Args:
        batch: list of dicts with "states" [T_i, state_dim], "actions" [T_i]
            and 0-d "reward" tensors.

    Returns:
        (states, actions, rewards): states is [B, T_max, state_dim] float32 and
        actions is [B, T_max] long, both zero-padded past each item's length;
        rewards is [B].

    Note: actions are padded with 0, which is also a real action id (see
    decode_action_seq), so the padded steps cannot be told apart from real
    ones in the output; callers must mask them using per-item lengths.
    """
    seqs = [item["states"] for item in batch]
    acts = [item["actions"] for item in batch]
    reward_batch = torch.stack([item["reward"] for item in batch])

    longest = max(seq.shape[0] for seq in seqs)
    feat = seqs[0].shape[1]
    n = len(seqs)

    state_batch = torch.zeros((n, longest, feat), dtype=torch.float32)
    action_batch = torch.zeros((n, longest), dtype=torch.long)
    for row, (seq, act) in enumerate(zip(seqs, acts)):
        length = seq.shape[0]
        state_batch[row, :length, :] = seq
        action_batch[row, :length] = act
    return state_batch, action_batch, reward_batch



train_dataset = TrajectoryDataset(train_trajectories)


def _collate_with_mask(batch):
    """Collate via ``collate_fn`` and add a [B, T] float mask of real (unpadded) steps.

    ``collate_fn`` pads actions with 0, which is also a valid action id, so the
    loss has to ignore padded timesteps explicitly.
    """
    states, actions, rewards = collate_fn(batch)
    lengths = torch.tensor([item["states"].shape[0] for item in batch])
    steps = torch.arange(states.shape[1])
    mask = (steps.unsqueeze(0) < lengths.unsqueeze(1)).float()
    return states, actions, rewards, mask


train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=_collate_with_mask)

# LSTM policy network
state_dim = len(train_trajectories[0]["states"][0])
action_dim = 2
# The policy head's default output size is 2 logits; rely on that default
# and verify it, rather than passing an action_dim keyword.
model = LSTMPolicyNet(state_dim=state_dim)
if model.mlp[-1].out_features != action_dim:
    raise ValueError(f"policy head emits {model.mlp[-1].out_features} logits, expected {action_dim}")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(20):
    total_loss, total_reward = 0.0, 0.0
    for states, actions, rewards, mask in train_loader:
        logits, _ = model(states)
        log_probs = nn.functional.log_softmax(logits, dim=-1)
        action_log_probs = log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
        # Zero out padded timesteps so they contribute neither loss nor gradient.
        action_log_probs = action_log_probs * mask
        # REINFORCE weighted by each trajectory's total reward (no baseline).
        # NOTE(review): without a baseline, all-positive rewards push up the
        # likelihood of every trajectory; consider subtracting a mean reward.
        loss = -(action_log_probs.sum(dim=1) * rewards).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_reward += rewards.sum().item()
    print(f"Epoch {epoch+1}: Loss={total_loss:.4f}, AvgReward={total_reward/len(train_dataset):.4f}")

# Save the trained policy weights.
torch.save(model.state_dict(), "lstm_policy.pt")