# In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from datasets import load_from_disk
import numpy as np
import matplotlib.pyplot as plt
import random

# Fixed parameters
L = max_len = 7  # maximum number of steps; L and max_len are two names for the same value

def reward_fn(path_idx, stop_probs, lam1=1.0, lam2=1.0):
    """Return the probability-weighted reward of a stop distribution for one path index.

    ``stop_probs[1:]`` is used as a weight vector over lengths ``0..len-2``. The mass
    in ``stop_probs[0]`` is folded into the first of those weights. Each length ``l``
    is scored as ``lam1 * l - lam2 * |path_idx - l|``, and the weighted sum is
    returned.

    NOTE(review): presumably ``stop_probs`` is a probability distribution over stop
    positions; this function does not check that. It needs at least two entries.
    """
    # np.array (not asarray) guarantees a private copy, so the in-place
    # fold below never touches the caller's data.
    weights = np.array(stop_probs[1:], dtype=np.float32)
    weights[0] += stop_probs[0]
    lengths = np.arange(weights.shape[0])
    per_length_reward = lam1 * lengths - lam2 * np.abs(path_idx - lengths)
    return np.sum(weights * per_length_reward)

class SharedStatesDataset(Dataset):
    """In-memory dataset that pre-computes every training record up front.

    Each item is a dict with three keys:
      "states":      [record['eagle_1_forward'], ..., record[f'eagle_{max_len}_forward']]
                     (max_len entries)
      "stop_probs":  [record['action_0']['stop'], ..., record[f'action_{max_len}']['stop']]
                     (max_len + 1 entries)
      "all_rewards": [reward_fn(i, stop_probs[i]) for i in 0..max_len]
                     (max_len + 1 entries)
    NOTE(review): the shapes of the per-step state features and stop entries are
    defined by the upstream dataset and are not checked here.
    """

    def __init__(self, dataset, max_len=7):
        self.samples = [self._prepare(record, max_len) for record in dataset]

    @staticmethod
    def _prepare(record, max_len):
        """Build the cached dict for a single raw record."""
        states = [record[f'eagle_{step}_forward'] for step in range(1, max_len + 1)]
        stop_rows = [record[f'action_{step}']['stop'] for step in range(max_len + 1)]
        rewards = [reward_fn(idx, row) for idx, row in enumerate(stop_rows)]
        return {
            "states": states,
            "all_rewards": rewards,
            "stop_probs": stop_rows,
        }

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

def collate_fn(batch):
    """Stack a list of dataset items into batch tensors.

    Returns:
        states:      float32 tensor built from each item's "states" ([B, T, state_dim]
                     when every item has equal-length rows).
        all_rewards: float32 tensor built from each item's "all_rewards" ([B, T + 1]).
        stop_probs:  plain Python list of each item's "stop_probs", left unconverted
                     because its inner shape is not guaranteed to be rectangular.
    """
    state_rows = [item["states"] for item in batch]
    states = torch.tensor(state_rows, dtype=torch.float32)
    reward_rows = [item["all_rewards"] for item in batch]
    all_rewards = torch.tensor(reward_rows, dtype=torch.float32)
    stop_probs = [item["stop_probs"] for item in batch]
    return states, all_rewards, stop_probs

# Actor-Critic network definition
class ActorCriticNet(nn.Module):
    """Shared LSTM trunk feeding a 2-way policy head (stop/go) and a scalar value head."""

    def __init__(self, state_dim=10, hidden_dim=128):
        super().__init__()
        # Registration order (lstm, actor_head, critic_head) fixes both the
        # state_dict layout and the order in which parameters are initialised.
        self.lstm = nn.LSTM(input_size=state_dim, hidden_size=hidden_dim, batch_first=True)
        self.actor_head = nn.Linear(in_features=hidden_dim, out_features=2)   # action: stop/go
        self.critic_head = nn.Linear(in_features=hidden_dim, out_features=1)  # state value

    def forward(self, x):
        """Run the trunk and both heads.

        Args:
            x: batch-first input of shape [B, T, state_dim].

        Returns:
            (logits [B, T, 2], values [B, T])
        """
        features = self.lstm(x)[0]
        policy_logits = self.actor_head(features)
        state_values = self.critic_head(features).squeeze(-1)
        return policy_logits, state_values

def sample_trajectory(logits, values, device, greedy=False):
    """
    Roll out a single stop/go trajectory for one example.

    At each step t an action is chosen from logits[t]: 0 = stop, 1 = go. The
    rollout ends at the first stop, or after logits.size(0) steps.

    Args:
        logits: per-step action logits of shape [T, 2].
        values: per-step critic values; only values[0..traj_len-1] are read.
        device: kept for backward compatibility. Action tensors are created on
            logits' own device.
        greedy: if True take the argmax action instead of sampling.

    Returns:
        actions: list[int] of length traj_len (0/1), traj_len <= T.
        log_probs: tensor [traj_len], log-probability of each taken action.
        state_values: tensor [traj_len], values[t] for each visited step.
        traj_len: int, number of decisions taken (1 <= traj_len <= T).

    Raises:
        ValueError: if logits has no steps.
    """
    num_steps = logits.size(0)
    if num_steps == 0:
        raise ValueError("sample_trajectory requires at least one step of logits")

    traj_actions = []
    traj_log_probs = []
    traj_values = []
    for t in range(num_steps):
        # Parameterising by logits computes log_prob via log_softmax, which is
        # numerically stabler than log(softmax(...)).
        dist = torch.distributions.Categorical(logits=logits[t])
        action_t = torch.argmax(logits[t]) if greedy else dist.sample()
        action = action_t.item()
        traj_actions.append(action)
        traj_log_probs.append(dist.log_prob(action_t))
        traj_values.append(values[t])
        if action == 0:  # stop
            break
    return traj_actions, torch.stack(traj_log_probs), torch.stack(traj_values), len(traj_actions)

def evaluate_policy(model, data_loader, device, greedy=True):
    """Roll out the policy on every example in data_loader and collect trajectory lengths.

    Puts the model in eval mode and runs without gradients. Each batch from
    data_loader is expected to be a (states, rewards, stop_probs) triple, as
    produced by collate_fn; only states is used.

    Args:
        model: network returning (logits [B, T, 2], values [B, T]).
        data_loader: iterable of (states, _, _) batches.
        device: device to run the model on.
        greedy: take argmax actions instead of sampling.

    Returns:
        list[int]: traj_len reported by sample_trajectory for each example, i.e. the
        number of decisions taken (1..T). A rollout that stops at the last step and
        one that never stops both report T.
    """
    model.eval()
    all_lengths = []
    with torch.no_grad():
        for states, _, _ in data_loader:
            logits, _ = model(states.to(device))
            # Critic values are irrelevant here. Size the placeholder by the
            # model's actual number of steps (not a global constant) so it can
            # never be shorter than the logits.
            placeholder_values = torch.zeros(logits.size(1), device=device)
            for logits_b in logits:
                *_, traj_len = sample_trajectory(logits_b, placeholder_values, device, greedy=greedy)
                all_lengths.append(traj_len)
    return all_lengths

def _trajectory_loss(logits_b, values_b, rewards_b, device, value_coef=0.5):
    """Sample one trajectory for a single example and return its actor-critic loss.

    Args:
        logits_b: stop/go logits for one example, [T, 2].
        values_b: critic values for the same example, [T].
        rewards_b: rewards indexed by outcome. Expected to have T + 1 entries
            (the dataset builds max_len + 1 rewards for max_len states).
        device: device on which the reward is placed.
        value_coef: weight of the critic (squared-advantage) term.

    Returns:
        Scalar tensor: policy loss + value_coef * value loss.
    """
    acts, log_probs, value_traj, traj_len = sample_trajectory(logits_b, values_b, device)
    # T binary decisions give T + 1 outcomes. Stopping at step t selects
    # reward t (t = 0..T-1); running all T steps without stopping selects
    # reward T. Stopping at the last step and never stopping must get
    # different rewards, otherwise the final action cannot affect the return.
    stopped = acts[-1] == 0
    outcome = traj_len - 1 if stopped else traj_len
    reward = rewards_b[outcome].to(device)
    # Per-step baseline. values_b comes from the LSTM over the fixed state
    # sequence, so values_b[t] does not depend on the sampled actions and is a
    # valid baseline for log pi(a_t). A single baseline taken at the last
    # visited step would depend on which actions were sampled and bias the
    # policy gradient.
    advantages = reward - value_traj
    policy_loss = -(log_probs * advantages.detach()).sum()
    value_loss = advantages.pow(2).mean()
    return policy_loss + value_coef * value_loss


def _batch_loss(logits, values, all_rewards, device):
    """Mean of _trajectory_loss over every example in the batch."""
    return torch.stack([
        _trajectory_loss(logits[b], values[b], all_rewards[b], device)
        for b in range(logits.size(0))
    ]).mean()


if __name__ == "__main__":
    dataset_dict = load_from_disk("your_replay_dataset_path")
    train_set = dataset_dict["train"]
    test_set = dataset_dict["test"]
    train_dataset = SharedStatesDataset(train_set, max_len=L)
    test_dataset = SharedStatesDataset(test_set, max_len=L)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn)

    state_dim = np.array(train_dataset[0]["states"]).shape[1]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ActorCriticNet(state_dim=state_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    num_epochs = 10
    train_losses = []
    test_losses = []
    best_test_loss = float('inf')

    for epoch in range(num_epochs):
        # --- training ---
        model.train()
        total_loss = 0.0
        total_batches = 0
        for states, all_rewards, _ in train_loader:
            states = states.to(device)
            all_rewards = all_rewards.to(device)
            logits, values = model(states)  # [B, T, 2], [B, T]
            batch_loss = _batch_loss(logits, values, all_rewards, device)
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            total_loss += batch_loss.item()
            total_batches += 1

        avg_train_loss = total_loss / total_batches if total_batches > 0 else 0

        # --- test loss: same sampled-trajectory loss, without gradients ---
        # NOTE(review): trajectories are sampled, so this metric is noisy and the
        # "best" checkpoint below is chosen on a stochastic estimate.
        model.eval()
        total_loss = 0.0
        total_batches = 0
        with torch.no_grad():
            for states, all_rewards, _ in test_loader:
                states = states.to(device)
                all_rewards = all_rewards.to(device)
                logits, values = model(states)
                total_loss += _batch_loss(logits, values, all_rewards, device).item()
                total_batches += 1
        avg_test_loss = total_loss / total_batches if total_batches > 0 else 0

        train_losses.append(avg_train_loss)
        test_losses.append(avg_test_loss)
        print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Test Loss: {avg_test_loss:.4f}")

        if avg_test_loss < best_test_loss:
            best_test_loss = avg_test_loss
            torch.save(model.state_dict(), "actor_critic_sampled_traj_best.pt")
            print(f"Best test loss updated: {best_test_loss:.4f}, model checkpoint saved.")

    # Plot the loss curves
    plt.figure()
    plt.plot(range(1, num_epochs+1), train_losses, label="Train Loss")
    plt.plot(range(1, num_epochs+1), test_losses, label="Test Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.title("Actor-Critic Sampled Trajectory Loss Curve")
    plt.savefig("actor_critic_sampled_traj_loss_curve.png")
    plt.show()