In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset


# Number of per-step states (eagle_1_forward .. eagle_{max_len}_forward)
# read from every dataset record.
max_len = 7
# Size of one state vector; used as the LSTM input size of the policy net.
state_dim = 10

class SharedStatesDataset(Dataset):
    """In-memory dataset of per-step states and recorded stop values.

    Every record of ``dataset`` is read once at construction time. Each
    stored item is a dict with two keys:

    * ``"states"``: the values found under ``eagle_1_forward`` ..
      ``eagle_{max_len}_forward`` (``max_len`` entries; presumably each is
      a state vector of length ``state_dim`` -- confirm against the data).
    * ``"stop_probs"``: the ``'stop'`` entry of ``action_0`` ..
      ``action_{max_len}`` (``max_len + 1`` entries).
    """

    def __init__(self, dataset, max_len=7):
        # Only eagle_1_forward .. eagle_{max_len}_forward are used as states,
        # while the actions start at index 0 and therefore have one more entry.
        state_keys = [f'eagle_{step}_forward' for step in range(1, max_len + 1)]
        action_keys = [f'action_{step}' for step in range(max_len + 1)]

        def build_item(record):
            return {
                "states": [record[key] for key in state_keys],
                "stop_probs": [record[key]['stop'] for key in action_keys],
            }

        self.samples = [build_item(record) for record in dataset]

    def __len__(self):
        """Return the number of stored items."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the item dict stored at position ``idx``."""
        return self.samples[idx]

def collate_fn(batch):
    """Merge a list of dataset items into one batch.

    Args:
        batch: list of dicts, each with a "states" and a "stop_probs" entry.

    Returns:
        (states, stop_probs): ``states`` is a float32 tensor built from all
        "states" entries, with the batch as its leading dimension;
        ``stop_probs`` is a plain Python list holding every item's
        "stop_probs" entry, in batch order.
    """
    state_rows, stop_rows = [], []
    for item in batch:
        state_rows.append(item["states"])
        stop_rows.append(item["stop_probs"])
    states = torch.tensor(state_rows, dtype=torch.float32)
    return states, stop_rows

import datasets

# Load the dataset saved on disk and take its "train" split.
ReplayDataset = datasets.load_from_disk("../data/scores_rb/shareGPT-llama3-d7-topk10-t1")
train_set = ReplayDataset["train"]
# Materialise every record into state / stop-value lists (done once, in memory).
train_dataset = SharedStatesDataset(train_set, max_len=max_len)
# Shuffled mini-batches of 64 items; collate_fn turns the states into one tensor.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



class LSTMPolicyNet(nn.Module):
    """LSTM policy that scores two actions at every step of a state sequence.

    At each step the LSTM output is concatenated with the raw input state of
    that step and passed through a two-layer MLP that emits two logits.

    Args:
        state_dim: size of one input state vector.
        lstm_hidden: hidden size of the LSTM.
        mlp_hidden: hidden size of the MLP head.
        num_layers: number of stacked LSTM layers.
        dropout: dropout rate of the MLP head; also applied between LSTM
            layers, but only when ``num_layers > 1``.
    """

    def __init__(self, state_dim=10, lstm_hidden=128, mlp_hidden=128, num_layers=1, dropout=0.1):
        super().__init__()
        self.state_dim = state_dim
        self.lstm_hidden = lstm_hidden
        self.num_layers = num_layers

        self.lstm = nn.LSTM(
            input_size=state_dim,
            hidden_size=lstm_hidden,
            num_layers=num_layers,
            batch_first=True,
            # nn.LSTM warns about a non-zero dropout on a single layer.
            dropout=dropout if num_layers > 1 else 0
        )
        self.mlp = nn.Sequential(
            nn.Linear(lstm_hidden + state_dim, mlp_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden, 2)  # logits of the two actions
        )

    def forward(self, state_seq, hidden=None):
        """Compute per-step action logits.

        Args:
            state_seq: tensor of shape [B, T, state_dim].
            hidden: optional ``(h0, c0)`` initial LSTM state; zeros are used
                by the LSTM when it is ``None``.

        Returns:
            (logits, hidden): ``logits`` has shape [B, T, 2]; ``hidden`` is
            the final ``(h_n, c_n)`` LSTM state.
        """
        lstm_out, hidden = self.lstm(state_seq, hidden)  # lstm_out: [B, T, lstm_hidden]
        # The raw state is fed to the head next to the LSTM features.
        features = torch.cat((lstm_out, state_seq), dim=-1)  # [B, T, lstm_hidden + state_dim]
        logits = self.mlp(features)  # [B, T, 2]
        return logits, hidden

    def act(self, state_seq, hidden=None, deterministic=False):
        """Pick an action for every step of every sequence.

        Args:
            state_seq: tensor of shape [B, T, state_dim].
            hidden: optional ``(h0, c0)`` initial LSTM state.
            deterministic: take the arg-max action instead of sampling.

        Returns:
            (actions, probs, hidden): ``actions`` has shape [B, T], ``probs``
            has shape [B, T, 2] (softmax of the logits), and ``hidden`` is
            the final LSTM state.
        """
        # Call the module rather than .forward() so registered hooks still run.
        logits, hidden = self(state_seq, hidden)
        probs = torch.softmax(logits, dim=-1)
        if deterministic:
            actions = torch.argmax(probs, dim=-1)
        else:
            actions = torch.distributions.Categorical(probs=probs).sample()
        return actions, probs, hidden

    def reset_hidden(self, batch_size):
        """Create an all-zero initial LSTM state.

        The tensors are allocated with the dtype and device of the model's
        parameters, so they can be passed to ``forward`` even after the model
        has been moved or cast (e.g. ``.double()`` / ``.half()``).

        Args:
            batch_size: number of sequences in the batch.

        Returns:
            ``(h0, c0)``, each of shape [num_layers, batch_size, lstm_hidden].
        """
        reference = next(self.parameters())
        shape = (self.num_layers, batch_size, self.lstm_hidden)
        h0 = reference.new_zeros(shape)
        c0 = reference.new_zeros(shape)
        return (h0, c0)

def sample_trajectory(logits, device):
    """Sample one trajectory, step by step, until the stop action is drawn.

    Args:
        logits: tensor of shape [L, 2] with the action logits of every step.
        device: unused; kept so existing callers keep working.

    Returns:
        (actions, log_probs, traj_len): ``actions`` is a list of ints,
        ``log_probs`` a list of 0-dim tensors (one per sampled action, still
        connected to the autograd graph of ``logits``), and ``traj_len`` the
        number of steps taken. Sampling ends right after the first action 0
        (stop); if it is never drawn, all L steps are used.
    """
    actions, log_probs = [], []
    for step_logits in logits:
        # Building the distribution from logits avoids the detour through
        # softmax probabilities, whose logarithm loses precision for very
        # unlikely actions.
        dist = torch.distributions.Categorical(logits=step_logits)
        action = dist.sample()
        choice = action.item()  # one host sync per step instead of two
        actions.append(choice)
        log_probs.append(dist.log_prob(action))
        if choice == 0:  # stop
            break
    return actions, log_probs, len(actions)

# Assumes a trajectory is rewarded only at its last step
# (that reward being traj_len - 1).
def get_stepwise_returns(traj_len, L):
    """Build the per-step value list of one trajectory.

    Args:
        traj_len: number of steps in the trajectory.
        L: unused by this function.

    Returns:
        A list of ``traj_len`` values: ``0.0`` for every step but the last,
        and ``traj_len - 1`` for the last one.

    NOTE(review): these are per-step rewards rather than reward-to-go, so in
    a log-prob-weighted loss only the final action receives credit -- confirm
    that this is intended.
    """
    final_reward = traj_len - 1
    per_step = [0.0] * final_reward
    per_step.append(final_reward)
    return per_step

model = LSTMPolicyNet(state_dim=state_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-4)

# Training loop: one policy-gradient update per mini-batch.
# stop_probs comes from the loader but is not used by this loss.
for states, stop_probs in train_loader:
    states = states.to(device)
    B = states.size(0)
    # The model returns (logits, hidden). Binding the whole tuple to `logits`
    # would make logits[b] index the tuple instead of the batch dimension.
    logits, _ = model(states)  # [B, T, 2]
    batch_loss = 0
    for b in range(B):
        # Sample until the stop action (0) is drawn or the sequence ends.
        acts, log_probs, traj_len = sample_trajectory(logits[b], device)
        returns = get_stepwise_returns(traj_len, logits.size(1))
        log_probs = torch.stack(log_probs)  # [traj_len]
        returns = torch.tensor(returns, dtype=torch.float32, device=device)  # [traj_len]
        # Negated because the optimizer minimises while the weighted
        # log-probability should be maximised.
        loss = - (log_probs * returns).sum()
        batch_loss += loss
    batch_loss = batch_loss / B
    optimizer.zero_grad()
    batch_loss.backward()
    optimizer.step()