In [246]:
# 2026/2/6
# zhangzhong
# REINFORCE

In [247]:
from torch import nn, Tensor
from torch.distributions import Categorical
import torch
from torch.optim import Adam
from tqdm import tqdm

# Names of the two moves available in the corridor task. Environment builds its own
# identical list in self.actions, and that instance attribute is the one
# Agent.sample_action indexes into.
# NOTE(review): this module-level copy is not referenced by any cell visible here.
actions = ["LEFT", "RIGHT"]


class Environment:
    """One-dimensional corridor: the agent starts in cell 0 and must reach cell ``goal``.

    The position is clamped to ``[0, goal]``. Every step yields a reward of ``-1.0``
    except the step that lands on the goal, which yields ``10.0``.

    Attributes are documented inline in ``__init__``.
    """

    # TODO: discrete state mapping to embedding
    def __init__(self, goal: int = 4):
        # Index of the terminal cell (also the right-hand wall).
        self.goal: int = goal
        # Current cell index; every episode starts at 0.
        self.state: int = 0
        # Valid action names; the position in this list is the action index a
        # policy is expected to emit.
        self.actions = ["LEFT", "RIGHT"]

    def step(self, action: str) -> tuple[float, int]:
        """Apply ``action`` and return ``(reward, new_state)``.

        Args:
            action: One of the names in ``self.actions``.

        Returns:
            ``(10.0, state)`` when the move ends on the goal, ``(-1.0, state)`` otherwise.

        Raises:
            ValueError: if ``action`` is not a known action name. Previously an
                unknown name was silently treated as "stay put" and still cost a
                step, which hid typos in the caller.
        """
        if action == "LEFT":
            # The left wall is at 0: moving left there leaves the state unchanged.
            self.state = max(0, self.state - 1)
        elif action == "RIGHT":
            # The goal doubles as the right wall.
            self.state = min(self.goal, self.state + 1)
        else:
            raise ValueError(
                f"unknown action {action!r}; expected one of {self.actions}"
            )
        # Float literals so the returned reward matches the annotated type.
        reward = 10.0 if self.is_done() else -1.0
        return (reward, self.state)

    def curr_state(self) -> Tensor:
        """Return the current cell index as a one-element tensor."""
        return torch.tensor(data=[self.state])

    def is_done(self) -> bool:
        """Return True when the agent is standing on the goal cell."""
        return self.state == self.goal

    def reset(self):
        """Move the agent back to cell 0."""
        self.state = 0

In [248]:
class Agent(nn.Module):
    """Stochastic policy over the environment's discrete actions.

    Each integer state is looked up in a learned embedding table and passed through
    a small MLP that emits one logit per entry of ``env.actions``.

    Args:
        env: Environment the agent acts in. Only ``env.goal`` (to size the
            embedding table) and ``env.actions`` (to size the output layer and to
            map sampled indices back to action names) are read.
        embedding_dim: Width of the per-state embedding vector.
        hidden_dim: Width of the hidden layer of the policy MLP.
    """

    def __init__(self, env: Environment, embedding_dim: int = 8, hidden_dim: int = 32) -> None:
        super().__init__()

        # States run from 0 to env.goal inclusive, so the table needs goal + 1 rows
        # for every state (the goal included) to have an embedding.
        self.state_embedding = nn.Embedding(num_embeddings=env.goal + 1, embedding_dim=embedding_dim)
        self.env = env

        self.policy_net = nn.Sequential(
            nn.Linear(in_features=embedding_dim, out_features=hidden_dim),
            nn.ReLU(),
            # One logit per action rather than a hard-coded 2, so the head always
            # matches the list that sample_action indexes into.
            nn.Linear(in_features=hidden_dim, out_features=len(env.actions)),
        )

    # https://docs.pytorch.org/docs/stable/distributions.html#categorical
    def sample_action(self, state: Tensor) -> tuple[str, Tensor]:
        """Sample one action for ``state`` from the current policy.

        Args:
            state: Integer tensor holding a single state index (the notebook
                passes shape ``(1,)``). Only one state is supported because the
                sampled index is converted with ``.item()``.

        Returns:
            ``(action_name, log_prob)``. The log-probability is computed at
            sampling time and returned with its autograd graph attached so the
            caller can build the policy-gradient loss from it.
        """
        embeddings: Tensor = self.state_embedding(state)
        logits = self.policy_net(embeddings)
        action_dist = Categorical(logits=logits)
        action_index = action_dist.sample()
        return self.env.actions[int(action_index.item())], action_dist.log_prob(value=action_index)


In [249]:
# Smoke test: build an agent on the default corridor and sample one action for
# the start state (cell 0). The (action, log_prob) pair is the cell's output.
env = Environment()
agent = Agent(env)
agent.sample_action(torch.tensor([0]))

('RIGHT', tensor([-0.5402], grad_fn=<SqueezeBackward1>))

In [250]:
# sample episode

# Target number of steps per training batch; collect_training_data reads this
# module-level value to decide when a batch is full.
batch_size = 128

def sample_episode(env: Environment, agent: Agent, max_steps: int = 100) -> list:
    """Roll out one episode from a freshly reset environment.

    Args:
        env: Environment to act in; it is reset before the rollout starts.
        agent: Policy providing ``sample_action(state=...)``.
        max_steps: Upper bound on the number of actions taken. The rollout stops
            earlier as soon as the environment reports it is done; an episode
            that hits the cap is returned truncated.

    Returns:
        A list with one ``(state, action, reward, log_prob)`` tuple per action
        taken, where ``state`` is the state tensor the action was sampled for.
        The list is empty if the environment is already done right after reset.
    """
    env.reset()

    episode = []
    curr_state: Tensor = env.curr_state()

    for _ in range(max_steps):
        # Check for termination before acting, so no action is ever sampled for a
        # state that is already terminal.
        if env.is_done():
            break

        action, log_prob = agent.sample_action(state=curr_state)
        reward, _next_state = env.step(action=action)

        # Store the state the action was taken in, not the one it led to.
        episode.append((curr_state, action, reward, log_prob))
        # Let the environment produce the state tensor so its format is defined
        # in one place only.
        curr_state = env.curr_state()

    return episode



In [251]:
# Collect one episode with the current agent and dump its raw steps.

episode = sample_episode(env=env, agent=agent)
print(episode)

[(tensor([0]), 'RIGHT', -1, tensor([-0.5402], grad_fn=<SqueezeBackward1>)), (tensor([1]), 'RIGHT', -1, tensor([-0.6645], grad_fn=<SqueezeBackward1>)), (tensor([2]), 'RIGHT', -1, tensor([-0.7219], grad_fn=<SqueezeBackward1>)), (tensor([3]), 'LEFT', -1, tensor([-0.7341], grad_fn=<SqueezeBackward1>)), (tensor([2]), 'LEFT', -1, tensor([-0.6652], grad_fn=<SqueezeBackward1>)), (tensor([1]), 'LEFT', -1, tensor([-0.7227], grad_fn=<SqueezeBackward1>)), (tensor([0]), 'LEFT', -1, tensor([-0.8738], grad_fn=<SqueezeBackward1>)), (tensor([0]), 'RIGHT', -1, tensor([-0.5402], grad_fn=<SqueezeBackward1>)), (tensor([1]), 'LEFT', -1, tensor([-0.7227], grad_fn=<SqueezeBackward1>)), (tensor([0]), 'RIGHT', -1, tensor([-0.5402], grad_fn=<SqueezeBackward1>)), (tensor([1]), 'LEFT', -1, tensor([-0.7227], grad_fn=<SqueezeBackward1>)), (tensor([0]), 'RIGHT', -1, tensor([-0.5402], grad_fn=<SqueezeBackward1>)), (tensor([1]), 'RIGHT', -1, tensor([-0.6645], grad_fn=<SqueezeBackward1>)), (tensor([2]), 'RIGHT', -1, ten

In [252]:
## calculate rewards-to-go

def calculate_rewards_to_go(rewards: list[float], gamma: float = 1.0):
    """Return the discounted reward-to-go for every step of one episode.

    Entry ``t`` of the result equals
    ``rewards[t] + gamma * rewards[t+1] + gamma**2 * rewards[t+2] + ...``.
    It is accumulated from the last step backwards with the recurrence
    ``G_t = rewards[t] + gamma * G_{t+1}``.

    Args:
        rewards: Per-step rewards in chronological order.
        gamma: Discount factor applied once per step into the future.

    Returns:
        A new list of the same length as ``rewards``; the input is not modified.
    """
    returns = [0] * len(rewards)
    running_return = 0
    # Walk the episode back to front, filling each slot in place so no final
    # reversal is needed.
    for idx in range(len(rewards) - 1, -1, -1):
        running_return = rewards[idx] + gamma * running_return
        returns[idx] = running_return
    return returns


In [253]:
## Sanity-check rewards-to-go: first undiscounted, then with gamma = 0.9
rewards = [1.0] * 3
print(
    calculate_rewards_to_go(rewards, gamma=1.0),
    calculate_rewards_to_go(rewards, gamma=0.9),
    sep="\n",
)

[3.0, 2.0, 1.0]
[2.71, 1.9, 1.0]


In [254]:
# deal with rewards sampled from a episode

def make_rewards_to_go(episode: list, gamma: float) -> list:
    """Compute the discounted reward-to-go for each step of a sampled episode.

    Args:
        episode: Sequence of 4-tuples whose third element is the step reward
            (the layout produced by ``sample_episode``).
        gamma: Discount factor forwarded to ``calculate_rewards_to_go``.

    Returns:
        One reward-to-go value per step, in the same order as ``episode``.
    """
    step_rewards: list[float] = [
        step_reward for (_state, _action, step_reward, _log_prob) in episode
    ]
    return calculate_rewards_to_go(rewards=step_rewards, gamma=gamma)

In [255]:
# How should one batch of training data be collected?
# There are two ways to implement this:
# 1. Follow the formula strictly: process each episode in full, then average over N episodes.
# 2. Treat every step as independent: gather batch-size steps and compute on those.

# The second is clearly simpler to implement, so that is the one used here.
# NOTE(review): the original note calls the two mathematically identical. Dividing by the
# number of steps instead of the number of episodes rescales the estimate, and a batch
# may end mid-episode -- confirm that this is acceptable.
# (The original plan was to leave reward-to-go out at first.)

# Discount factor; collect_training_data passes it on to make_rewards_to_go.
gamma = 0.99

# One more point: the reward paired with each step of an episode should be its
# reward-to-go, which a helper function can compute.

def collect_training_data(environment=None, policy=None, discount=None, num_steps=None):
    """Sample episodes until ``num_steps`` (log_prob, reward-to-go) pairs are gathered.

    Steps are taken episode by episode; the last episode is cut short once the
    batch is full. Rewards-to-go are always computed on the whole episode before
    the cut, so the values of the kept steps are unaffected.

    Every argument is optional. When omitted, the module-level ``env``, ``agent``,
    ``gamma`` and ``batch_size`` are used, so existing zero-argument calls behave
    as before.

    Args:
        environment: Environment to roll out in (default: global ``env``).
        policy: Agent to sample actions from (default: global ``agent``).
        discount: Discount factor for rewards-to-go (default: global ``gamma``).
        num_steps: Number of steps to collect (default: global ``batch_size``).

    Returns:
        ``(batch_log_probs, batch_rewards_to_go)``: two lists of length
        ``num_steps``, aligned step by step.

    Raises:
        ValueError: if ``num_steps`` is smaller than 1.
        RuntimeError: if an episode comes back empty (the environment is already
            done right after reset). Without this check the loop would never end.
    """
    # The parameter names differ from the globals on purpose, so the globals stay
    # reachable here as fallbacks.
    environment = env if environment is None else environment
    policy = agent if policy is None else policy
    discount = gamma if discount is None else discount
    num_steps = batch_size if num_steps is None else num_steps

    if num_steps < 1:
        raise ValueError(f"num_steps must be at least 1, got {num_steps}")

    batch_log_probs = []
    batch_rewards_to_go = []
    while len(batch_log_probs) < num_steps:
        episode = sample_episode(env=environment, agent=policy)
        if not episode:
            raise RuntimeError(
                "sample_episode returned no steps; the environment is done "
                "immediately after reset, so no training data can be collected"
            )
        rewards_to_go = make_rewards_to_go(episode=episode, gamma=discount)
        for (_, _, _, log_prob), reward in zip(episode, rewards_to_go):
            if len(batch_log_probs) >= num_steps:
                break
            batch_log_probs.append(log_prob)
            batch_rewards_to_go.append(reward)

    return batch_log_probs, batch_rewards_to_go

# What does one training step need from a batch? The planned sequence is:
# optimizer.zero_grad()
# compute_loss -> loss
# loss.backward()
# optimizer.step()
# 

In [256]:
# Collect one batch with the current agent and dump both lists for inspection:
# the per-step log-probabilities (still attached to the autograd graph) and the
# matching rewards-to-go.
batch_log_probs, batch_rewards_to_go = collect_training_data()
print(batch_log_probs, batch_rewards_to_go)

[tensor([-0.5402], grad_fn=<SqueezeBackward1>), tensor([-0.6645], grad_fn=<SqueezeBackward1>), tensor([-0.6652], grad_fn=<SqueezeBackward1>), tensor([-0.6645], grad_fn=<SqueezeBackward1>), tensor([-0.6652], grad_fn=<SqueezeBackward1>), tensor([-0.7227], grad_fn=<SqueezeBackward1>), tensor([-0.8738], grad_fn=<SqueezeBackward1>), tensor([-0.5402], grad_fn=<SqueezeBackward1>), tensor([-0.7227], grad_fn=<SqueezeBackward1>), tensor([-0.5402], grad_fn=<SqueezeBackward1>), tensor([-0.7227], grad_fn=<SqueezeBackward1>), tensor([-0.5402], grad_fn=<SqueezeBackward1>), tensor([-0.6645], grad_fn=<SqueezeBackward1>), tensor([-0.6652], grad_fn=<SqueezeBackward1>), tensor([-0.7227], grad_fn=<SqueezeBackward1>), tensor([-0.8738], grad_fn=<SqueezeBackward1>), tensor([-0.5402], grad_fn=<SqueezeBackward1>), tensor([-0.7227], grad_fn=<SqueezeBackward1>), tensor([-0.5402], grad_fn=<SqueezeBackward1>), tensor([-0.6645], grad_fn=<SqueezeBackward1>), tensor([-0.7219], grad_fn=<SqueezeBackward1>), tensor([-0.7

In [257]:
def compute_loss(batch_log_probs: list[Tensor], batch_rewards_to_go: list[float]) -> Tensor:
    """REINFORCE surrogate loss: minus the mean of ``log_prob * reward_to_go``.

    Args:
        batch_log_probs: Log-probabilities of the sampled actions, one tensor per
            step, each still attached to the policy's autograd graph.
        batch_rewards_to_go: Reward-to-go for the same steps, in the same order.
            These are plain numbers; they act as constant weights and receive no
            gradient.

    Returns:
        A scalar tensor whose gradient with respect to the policy parameters is
        the negated policy-gradient estimate, so minimising it maximises return.

    Raises:
        ValueError: if the two lists differ in length (``zip`` would otherwise
            drop the extra entries silently) or if the batch is empty.
    """
    if len(batch_log_probs) != len(batch_rewards_to_go):
        raise ValueError(
            f"got {len(batch_log_probs)} log-probs but "
            f"{len(batch_rewards_to_go)} rewards-to-go"
        )
    if not batch_log_probs:
        raise ValueError("cannot compute a loss from an empty batch")

    weighted_log_probs = [
        log_prob * reward_to_go
        for log_prob, reward_to_go in zip(batch_log_probs, batch_rewards_to_go)
    ]

    # torch.stack keeps every element attached to its autograd graph. Rebuilding
    # the batch with torch.tensor(...) would copy the values into a fresh tensor
    # with requires_grad=False and cut the gradient path to the policy.
    return -torch.stack(weighted_log_probs).mean()

In [258]:
def train_one_epoch(env: Environment, agent: Agent, optimizer):
    """Run one policy-gradient update: gather a batch, then take one optimizer step.

    NOTE(review): ``env`` and ``agent`` are accepted but never used in this body.
    ``collect_training_data()`` is called without arguments and works on the
    module-level ``env``/``agent``, so the update only trains the intended agent
    when it is that same global object and ``optimizer`` wraps its parameters.
    """
    # Roll out the policy first; the sampled log-probs carry the autograd graph.
    log_probs, rewards_to_go = collect_training_data()

    # Clear gradients left over from the previous update.
    optimizer.zero_grad()

    # Differentiate the surrogate loss, then apply a single parameter step.
    compute_loss(log_probs, rewards_to_go).backward()
    optimizer.step()

In [259]:
# Fresh environment and agent for the actual training run. Rebinding the globals
# `env` and `agent` matters: collect_training_data reads them at call time.
# NOTE(review): no random seed is set in the cells visible here, so results vary
# from run to run.
max_epochs = 100
env = Environment()
agent = Agent(env)
optimizer = Adam(params=agent.parameters(), lr=0.001)

# One policy update per epoch.
for epoch in tqdm(range(max_epochs)):
    train_one_epoch(env, agent, optimizer)

  0%|          | 0/100 [00:00<?, ?it/s]

100%|██████████| 100/100 [00:02<00:00, 38.63it/s]


In [263]:
# Query the trained policy once for every non-terminal state to check that it
# picks sensible actions. Each action is sampled, not the arg-max, so an
# occasional unlikely action can still appear.
for i in range(env.goal):
    action, logp = agent.sample_action(torch.tensor([i]))
    print(action, logp)

RIGHT tensor([-0.1109], grad_fn=<SqueezeBackward1>)
RIGHT tensor([-0.2101], grad_fn=<SqueezeBackward1>)
RIGHT tensor([-0.0176], grad_fn=<SqueezeBackward1>)
RIGHT tensor([-0.1070], grad_fn=<SqueezeBackward1>)
