In [1]:
# 2026/2/6
# zhangzhong
# REINFORCE with baseline (Monte Carlo actor-critic)

In [2]:
from torch import nn, Tensor
from torch.distributions import Categorical
import torch
from torch.optim import Adam
from tqdm import tqdm

# Module-level list of the two action names.
# NOTE(review): nothing in the visible cells reads this global; Environment
# keeps its own copy in `self.actions`, which is what Agent indexes into.
actions = ["LEFT", "RIGHT"]


class Environment:
    """One-dimensional corridor with integer cells ``0 .. goal``.

    The walker starts in cell 0; the episode is finished once it stands on
    ``goal``. "LEFT" moves one cell down (never below 0) and "RIGHT" moves one
    cell up (never above ``goal``).
    """

    # TODO: discrete state mapping to embedding
    def __init__(self, goal: int = 4):
        self.actions = ["LEFT", "RIGHT"]
        self.state: int = 0
        self.goal: int = goal

    def step(self, action: str) -> tuple[float, int]:
        """Apply ``action`` and return ``(reward, new_state)``.

        The reward is -1 for "LEFT" and 1 for every other action string. An
        action that is neither "LEFT" nor "RIGHT" leaves the state unchanged.
        """
        went_left = action == "LEFT"
        if went_left:
            # Clamp at the left wall.
            self.state = max(0, self.state - 1)
        elif action == "RIGHT":
            # Clamp at the goal cell.
            self.state = min(self.goal, self.state + 1)

        # Sparse-reward alternative kept for reference:
        # return (10 if self.is_done() else -1, self.state)
        reward = -1 if went_left else 1
        return reward, self.state

    def is_done(self) -> bool:
        """Return True when the walker is on the goal cell."""
        return self.goal == self.state

    def reset(self):
        """Put the walker back on cell 0."""
        self.state = 0

In [3]:
class Agent(nn.Module):
    """Stochastic policy over the environment's two actions.

    A state index is looked up in an embedding table and passed through a
    small MLP that produces one logit per action.
    """

    def __init__(self, env: Environment) -> None:
        super().__init__()

        embedding_dim, hidden_dim = 8, 32

        # The table needs one row per state 0..goal inclusive. Sizing it with
        # `env.goal` rows (the earlier version) indexed out of range on the
        # goal state.
        self.state_embedding = nn.Embedding(num_embeddings=env.goal + 1, embedding_dim=embedding_dim)
        self.env = env

        self.policy_net = nn.Sequential(
            nn.Linear(in_features=embedding_dim, out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim, out_features=2),
        )

    def curr_state(self) -> Tensor:
        """Return the environment's current state as a one-element tensor."""
        return torch.tensor(data=[self.env.state])

    def get_action(self, state: Tensor) -> tuple[str, Tensor]:
        """Sample an action without tracking gradients (inspection helper).

        Prints the argmax index of the logits and the Categorical object, then
        returns the sampled action name and its log-probability.
        """
        with torch.no_grad():
            logits: Tensor = self.policy_net(self.state_embedding(state))
            print(logits.argmax())

            policy = Categorical(logits=logits)
            print(policy)
            choice = policy.sample()
            log_prob = policy.log_prob(value=choice)
        return self.env.actions[choice.int()], log_prob

    # https://docs.pytorch.org/docs/stable/distributions.html#categorical
    def sample_action(self, state: Tensor) -> tuple[str, Tensor]:
        """Sample an action for ``state`` with gradient tracking enabled.

        The log-probability is available as soon as the action is sampled, so
        it is returned together with the action name.
        """
        policy = Categorical(logits=self.policy_net(self.state_embedding(state)))
        choice = policy.sample()
        return self.env.actions[choice.int()], policy.log_prob(value=choice)


In [4]:
# Smoke test: build the environment and the policy, then sample one action for
# state 0. The cell displays the returned (action name, log-prob) pair.
env = Environment()
agent = Agent(env=env)
agent.sample_action(state=torch.tensor(data=[0]))

('LEFT', tensor([-0.4201], grad_fn=<SqueezeBackward1>))

In [5]:
class ValueNet(nn.Module):
    """State-value estimator: embedding lookup followed by a small MLP.

    It has its own embedding table and layers, separate from the policy
    network's.
    """

    def __init__(self, env: Environment) -> None:
        super().__init__()

        self.env = env

        embedding_dim, hidden_dim = 8, 32

        # One embedding row per state 0..goal inclusive.
        self.state_embedding = nn.Embedding(num_embeddings=env.goal + 1, embedding_dim=embedding_dim)

        self.value_net = nn.Sequential(
            nn.Linear(in_features=embedding_dim, out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim, out_features=1),
        )

    def forward(self, state: Tensor) -> Tensor:
        """Return the value estimate for each state index in ``state``."""
        return self.value_net(self.state_embedding(state))

In [6]:
# sample episode

# Number of sampled steps gathered for one update; collect_training_data stops
# once it has this many, and the shape asserts in the loss/training code
# compare against it.
batch_size = 128

def sample_episode(env: Environment, agent: Agent, max_steps: int = 100) -> list[tuple[Tensor, str, float, Tensor]]:
    """Roll out one episode using the agent's stochastic policy.

    The environment is reset, then actions are sampled until the environment
    reports that it is done or ``max_steps`` transitions have been recorded.

    Args:
        env: Environment that is reset and stepped.
        agent: Policy providing ``sample_action(state=...)``.
        max_steps: Upper bound on the number of recorded transitions. Defaults
            to 100, the value that was previously hard-coded.

    Returns:
        One ``(state, action, reward, log_prob)`` tuple per transition, where
        ``state`` is the state the action was sampled in and ``log_prob`` is
        the tensor returned by ``agent.sample_action``.
    """
    env.reset()

    episode = []
    # Read the start state from the environment that was just reset and will
    # be stepped below, rather than from whatever environment the agent holds.
    curr_state: Tensor = torch.tensor(data=[env.state])

    for _ in range(max_steps):
        # Check for termination before sampling, so no action is ever drawn
        # from the terminal state. Checking at the end of the loop body was
        # the earlier, buggy placement.
        if env.is_done():
            break

        action, log_prob = agent.sample_action(state=curr_state)
        reward, next_state = env.step(action=action)

        episode.append((curr_state, action, reward, log_prob))
        curr_state = torch.tensor(data=[next_state])

    return episode

In [7]:
# Collect one episode with the current (untrained) policy and print every
# (state, action, reward, log_prob) step.

episode = sample_episode(env, agent)
print(episode)

[(tensor([0]), 'RIGHT', 1, tensor([-1.0700], grad_fn=<SqueezeBackward1>)), (tensor([1]), 'LEFT', -1, tensor([-0.4928], grad_fn=<SqueezeBackward1>)), (tensor([0]), 'RIGHT', 1, tensor([-1.0700], grad_fn=<SqueezeBackward1>)), (tensor([1]), 'LEFT', -1, tensor([-0.4928], grad_fn=<SqueezeBackward1>)), (tensor([0]), 'LEFT', -1, tensor([-0.4201], grad_fn=<SqueezeBackward1>)), (tensor([0]), 'LEFT', -1, tensor([-0.4201], grad_fn=<SqueezeBackward1>)), (tensor([0]), 'RIGHT', 1, tensor([-1.0700], grad_fn=<SqueezeBackward1>)), (tensor([1]), 'LEFT', -1, tensor([-0.4928], grad_fn=<SqueezeBackward1>)), (tensor([0]), 'RIGHT', 1, tensor([-1.0700], grad_fn=<SqueezeBackward1>)), (tensor([1]), 'LEFT', -1, tensor([-0.4928], grad_fn=<SqueezeBackward1>)), (tensor([0]), 'LEFT', -1, tensor([-0.4201], grad_fn=<SqueezeBackward1>)), (tensor([0]), 'LEFT', -1, tensor([-0.4201], grad_fn=<SqueezeBackward1>)), (tensor([0]), 'RIGHT', 1, tensor([-1.0700], grad_fn=<SqueezeBackward1>)), (tensor([1]), 'RIGHT', 1, tensor([-0.

In [8]:
## calculate rewards-to-go

from typing import Iterator


def calculate_rewards_to_go(rewards: list[float], gamma: float = 1.0) -> list[float]:
    """Return the discounted reward-to-go for every position in ``rewards``.

    Entry ``t`` of the result is ``rewards[t] + gamma * result[t + 1]``, with
    the term after the last reward taken as 0.
    """
    returns: list[float] = [0.0] * len(rewards)
    running = 0
    # Walk backwards so each entry can reuse the return of its successor.
    for idx in range(len(rewards) - 1, -1, -1):
        running = rewards[idx] + gamma * running
        returns[idx] = running
    return returns


In [9]:
## Sanity check for rewards-to-go: three unit rewards, first undiscounted and
## then with gamma = 0.9.
rewards = [1.0, 1.0, 1.0]
print(calculate_rewards_to_go(rewards, gamma=1.0))
print(calculate_rewards_to_go(rewards, gamma=0.9))

[3.0, 2.0, 1.0]
[2.71, 1.9, 1.0]


In [10]:
# deal with rewards sampled from a episode

def make_rewards_to_go(episode: list, gamma: float) -> list[float]:
    """Compute the discounted reward-to-go for each step of an episode.

    ``episode`` is a list of 4-tuples whose third element is the step reward.
    """
    step_rewards: list[float] = []
    for _state, _action, reward, _log_prob in episode:
        step_rewards.append(reward)
    return calculate_rewards_to_go(rewards=step_rewards, gamma=gamma)

In [11]:
# How do we collect one batch of data?
# There are two ways to implement this:
# 1. Follow the formula strictly: finish the computation for each episode,
#    then aggregate over N episodes.
# 2. Treat every step as independent: collect batch_size steps and compute on
#    those.

# The second is clearly simpler to implement, so use it (author's note:
# mathematically the two are the same).
# Original plan: skip reward-to-go for now. The note below and
# make_rewards_to_go supersede that.

# Discount factor used when computing rewards-to-go.
gamma = 0.99

# One more point: the reward attached to each step of an episode should be its
# reward-to-go. A helper function computes it.

def collect_training_data() -> tuple[list[Tensor], list[str], list[Tensor], list[float]]:
    """Sample episodes until ``batch_size`` steps have been gathered.

    Uses the module-level ``env``, ``agent``, ``gamma`` and ``batch_size``.
    Rewards-to-go are computed on each complete episode before its steps are
    added, so the last episode may contribute only its leading steps.

    Returns:
        ``(states, actions, log_probs, rewards_to_go)`` as parallel lists.
    """
    states, actions_taken, log_probs, returns = [], [], [], []

    def batch_is_full() -> bool:
        return len(log_probs) >= batch_size

    while True:
        trajectory = sample_episode(env=env, agent=agent)
        discounted = make_rewards_to_go(episode=trajectory, gamma=gamma)

        for step, reward_to_go in zip(trajectory, discounted):
            state, action, _, log_prob = step

            states.append(state)
            actions_taken.append(action)
            # TODO(policy-gradient / PPO-ready):
            # The log_prob stored here still carries its autograd graph and is
            # used right away for policy_loss.backward() in the same round.
            #
            # That is mathematically correct for a "sample once -> update
            # once" Monte Carlo REINFORCE / actor-critic, but it has limits:
            # - the same batch cannot be used for several policy updates
            #   (e.g. PPO's multi-epoch updates)
            # - old trajectories cannot be reused after a parameter update
            #
            # If PPO / multiple updates are implemented later:
            # - store (state, action) instead
            # - recompute log_prob during the update phase
            # This is the standard approach in Spinning Up / PPO.
            log_probs.append(log_prob)
            returns.append(reward_to_go)

            if batch_is_full():
                break

        if batch_is_full():
            break

    return states, actions_taken, log_probs, returns

# What does one batch update need?
# optimizer.zero_grad()
# compute_loss -> loss
# loss.backward()
# optimizer.step()
#

In [12]:
# Collect one batch and inspect the per-step log-probabilities and
# rewards-to-go.
batch_states, batch_actions, batch_log_probs, batch_rewards_to_go = collect_training_data()
print(batch_log_probs, batch_rewards_to_go)

[tensor([-0.4201], grad_fn=<SqueezeBackward1>), tensor([-1.0700], grad_fn=<SqueezeBackward1>), tensor([-0.4928], grad_fn=<SqueezeBackward1>), tensor([-0.4201], grad_fn=<SqueezeBackward1>), tensor([-1.0700], grad_fn=<SqueezeBackward1>), tensor([-0.4928], grad_fn=<SqueezeBackward1>), tensor([-0.4201], grad_fn=<SqueezeBackward1>), tensor([-0.4201], grad_fn=<SqueezeBackward1>), tensor([-0.4201], grad_fn=<SqueezeBackward1>), tensor([-1.0700], grad_fn=<SqueezeBackward1>), tensor([-0.9440], grad_fn=<SqueezeBackward1>), tensor([-0.8148], grad_fn=<SqueezeBackward1>), tensor([-0.8638], grad_fn=<SqueezeBackward1>), tensor([-0.4201], grad_fn=<SqueezeBackward1>), tensor([-1.0700], grad_fn=<SqueezeBackward1>), tensor([-0.4928], grad_fn=<SqueezeBackward1>), tensor([-0.4201], grad_fn=<SqueezeBackward1>), tensor([-0.4201], grad_fn=<SqueezeBackward1>), tensor([-0.4201], grad_fn=<SqueezeBackward1>), tensor([-0.4201], grad_fn=<SqueezeBackward1>), tensor([-1.0700], grad_fn=<SqueezeBackward1>), tensor([-0.9

In [13]:
def compute_value_loss(values: Tensor, batch_rewards_to_go: list[float]) -> Tensor:
    """Mean squared error between value predictions and rewards-to-go.

    The last dimension of ``values`` is squeezed away before the rewards-to-go
    targets are subtracted.
    """
    targets = torch.tensor(batch_rewards_to_go)
    residuals = values.squeeze(dim=-1) - targets
    return (residuals ** 2).mean()

def compute_loss(batch_log_probs: list[Tensor], batch_rewards_to_go: list[float], batch_vt: "list[float] | None" = None) -> Tensor:
    """Policy-gradient loss with an optional baseline.

    Computes ``-mean(log_prob * (reward_to_go - baseline))`` over the batch.

    Args:
        batch_log_probs: Per-step log-probability tensors. They must still be
            attached to the autograd graph for the loss to have gradients.
        batch_rewards_to_go: Per-step rewards-to-go.
        batch_vt: Per-step baseline values. ``None`` (the default) means a
            zero baseline for every step.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: If the batch is empty or the three inputs differ in
            length.
    """
    n_steps = len(batch_log_probs)
    if n_steps == 0:
        raise ValueError("compute_loss needs at least one sampled step")

    if batch_vt is None:
        # The former default `[0]` made zip() stop after the first step; use
        # one zero per step instead.
        batch_vt = [0.0] * n_steps

    if not (len(batch_rewards_to_go) == len(batch_vt) == n_steps):
        # zip() would silently truncate to the shortest input.
        raise ValueError(
            f"length mismatch: {n_steps} log-probs, "
            f"{len(batch_rewards_to_go)} rewards-to-go, {len(batch_vt)} baselines"
        )

    weighted = [
        log_prob * (reward - vt)
        for log_prob, reward, vt in zip(batch_log_probs, batch_rewards_to_go, batch_vt)
    ]

    # torch.tensor(weighted) would build a new tensor by copying the values,
    # which drops the graph (requires_grad=False). torch.stack keeps it.
    stacked = torch.stack(weighted)
    return -stacked.mean()




In [14]:
def train_one_epoch(env: Environment, agent: Agent, value_net: ValueNet, optimizer4agent, optimizer4value):
    """Run one update: sample a batch, then step the policy and the value net.

    The policy is updated with the policy-gradient loss, using the value net's
    detached predictions as the baseline. The value net is then regressed onto
    the rewards-to-go.

    NOTE(review): `env` is not used in this body, and `agent` is only used in
    an assertion. The batch comes from collect_training_data(), which takes no
    arguments and reads the module-level `env` and `agent`.
    """
    # 1. collect training data of this epoch
    batch_states, batch_actions, batch_log_probs, batch_rewards_to_go = collect_training_data()

    # 2. zero policy net grads
    optimizer4agent.zero_grad()
    # The value optimizer is zeroed here as well (author's note: "probably the
    # right place").
    optimizer4value.zero_grad()

    # 3. compute vt by using value net
    values = value_net.forward(state=torch.stack(batch_states).squeeze(dim=-1))
    assert values.shape == (batch_size, 1)
    # detach() is required here to isolate the policy loss from the value net:
    # it ensures the policy loss does not propagate gradients into value_net.
    # TODO(actor-critic):
    # The actor and the critic currently use completely separate parameters, so:
    # - policy_loss.backward() only produces gradients for the actor
    # - value_loss.backward() only produces gradients for the critic
    # That holds both mathematically and in practice.
    #
    # WARNING: if a shared embedding / shared trunk is introduced later
    # (actor and critic sharing some parameters), gradient flow must be
    # controlled strictly:
    # - the advantage in the policy loss must use value.detach()
    # - the actor's and critic's backward / optimizer.step must be kept apart
    # Otherwise the critic gets dragged by policy-loss gradients and training
    # becomes unstable.
    batch_vt: list[float] = values.detach().squeeze(dim=-1).tolist()

    # 4. policy compute loss
    # Historical note from the author: "Big problem! The loss computed this way
    # has no gradient! Why?"
    policy_loss = compute_loss(batch_log_probs, batch_rewards_to_go, batch_vt=batch_vt)

    # 5. policy net backward propagation
    policy_loss.backward()
    # 1) actor must have grads
    assert any(p.grad is not None for p in agent.parameters())
    # 2) critic must NOT get grads from policy loss
    # NOTE(review): this presumably relies on zero_grad() resetting grads to
    # None rather than to zero tensors. Confirm for the installed PyTorch
    # version.
    assert all(p.grad is None for p in value_net.parameters())

    # 6. optimize policy net
    optimizer4agent.step()

    ## 7. compute vt loss and optimize
    # Author's note: zeroing the value optimizer here was first thought to be
    # a bug; testing showed its placement made no difference, so the call
    # below stays disabled.
    # optimizer4value.zero_grad()
    value_loss = compute_value_loss(values, batch_rewards_to_go)
    value_loss.backward()
    assert any(p.grad is not None for p in value_net.parameters())
    optimizer4value.step()

In [15]:
max_epochs = 200 # empirically, 200 epochs fit the value net reasonably well
# Fresh environment, policy, value net and one Adam optimizer per network.
env = Environment()
agent = Agent(env=env)
value_net = ValueNet(env=env)
optimizer4agent = Adam(agent.parameters(), lr=1e-3)
optimizer4value = Adam(value_net.parameters(), lr=1e-3)

# Print the policy net's choice for each non-terminal state (0 .. goal-1) to
# check whether the selected actions look right (before training).
for i in range(env.goal):
    action, logp= agent.get_action(state=torch.tensor(data=[i]))
    print(action, logp)

# Print the value net's estimate for each non-terminal state to check whether
# it is trained correctly (before training).
for i in range(env.goal):
    output = value_net.forward(state=torch.tensor(data=[i]))
    print(output.float())

tensor(0)
Categorical(logits: torch.Size([1, 2]))
LEFT tensor([-0.6802])
tensor(0)
Categorical(logits: torch.Size([1, 2]))
RIGHT tensor([-0.9141])
tensor(1)
Categorical(logits: torch.Size([1, 2]))
RIGHT tensor([-0.6536])
tensor(1)
Categorical(logits: torch.Size([1, 2]))
LEFT tensor([-0.7967])
tensor([[-0.0881]], grad_fn=<AddmmBackward0>)
tensor([[0.0705]], grad_fn=<AddmmBackward0>)
tensor([[-0.0484]], grad_fn=<AddmmBackward0>)
tensor([[0.0648]], grad_fn=<AddmmBackward0>)


In [16]:
# Training loop: one batch collection and one update of each network per epoch.
for epoch in tqdm(range(max_epochs)):
    train_one_epoch(env=env, agent=agent, value_net=value_net, optimizer4agent=optimizer4agent, optimizer4value=optimizer4value)

  0%|          | 0/200 [00:00<?, ?it/s]

100%|██████████| 200/200 [00:04<00:00, 47.58it/s]


In [17]:
# Print the policy net's choice for each non-terminal state to check whether
# the selected actions look right (after training).
for i in range(env.goal):
    action, logp= agent.get_action(state=torch.tensor(data=[i]))
    print(action, logp)

# Print the value net's estimate for each non-terminal state to check whether
# it is trained correctly (after training).
for i in range(env.goal):
    output = value_net.forward(state=torch.tensor(data=[i]))
    print(output.float())
# Computing V_t directly from its definition (the sum of discounted rewards)
# should give roughly 4 3 2 1 for an always-RIGHT policy; with gamma = 0.99 the
# exact values are slightly smaller.

tensor(1)
Categorical(logits: torch.Size([1, 2]))
RIGHT tensor([-0.0375])
tensor(1)
Categorical(logits: torch.Size([1, 2]))
RIGHT tensor([-0.0366])
tensor(1)
Categorical(logits: torch.Size([1, 2]))
RIGHT tensor([-0.1620])
tensor(1)
Categorical(logits: torch.Size([1, 2]))
RIGHT tensor([-0.1496])
tensor([[3.8382]], grad_fn=<AddmmBackward0>)
tensor([[2.9628]], grad_fn=<AddmmBackward0>)
tensor([[1.9706]], grad_fn=<AddmmBackward0>)
tensor([[1.0019]], grad_fn=<AddmmBackward0>)
