In [3]:
!pip install numpy==1.23.5

Collecting numpy==1.23.5
  Downloading numpy-1.23.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.3 kB)
Downloading numpy-1.23.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.1/17.1 MB[0m [31m92.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 2.0.2
    Uninstalling numpy-2.0.2:
      Successfully uninstalled numpy-2.0.2
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
xarray 2025.3.1 requires numpy>=1.24, but you have numpy 1.23.5 which is incompatible.
blosc2 3.5.1 requires numpy>=1.26, but you have numpy 1.23.5 which is incompatible.
chex 0.1.89 requires numpy>=1.24.1, but you have numpy 1.23.5 which is incompatible.
imbalanced-learn 0.13.0 require

In [6]:
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# === Actor-Critic модель ===
class ActorCritic(nn.Module):
    def __init__(self, num_inputs, num_actions):
        super(ActorCritic, self).__init__()
        self.actor = nn.Sequential(
            nn.Linear(num_inputs, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, num_actions)
        )
        self.critic = nn.Sequential(
            nn.Linear(num_inputs, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 1)
        )

    def forward(self, x):
        value = self.critic(x)
        logits = self.actor(x)
        return logits, value

    def get_distribution(self, x):
        logits, _ = self.forward(x)
        return Categorical(logits=logits)

    def act(self, state):
        dist, value = self.forward(state)
        action = Categorical(logits=dist).sample()
        return action, value

# === Вспомогательные функции ===
def flat_params(model):
    params = []
    for param in model.parameters():
        params.append(param.view(-1))
    return torch.cat(params)

def assign_params(model, flat_params):
    prev_ind = 0
    for param in model.parameters():
        flat_size = int(np.prod(list(param.size())))
        param.data.copy_(
            flat_params[prev_ind:prev_ind + flat_size].view(param.size()))
        prev_ind += flat_size

def conjugate_gradient(fvp_fun, b_vec, nsteps, residual_tol=1e-10):
    x = torch.zeros_like(b_vec)
    r = b_vec.clone()
    p = r.clone()
    r_dot_r = torch.dot(r, r)
    for i in range(nsteps):
        Ap = fvp_fun(p)
        alpha = r_dot_r / torch.dot(p, Ap)
        x += alpha * p
        r -= alpha * Ap
        r_dot_r_new = torch.dot(r, r)
        if r_dot_r_new < residual_tol:
            break
        beta = r_dot_r_new / r_dot_r
        p = r + beta * p
        r_dot_r = r_dot_r_new
    return x

def fisher_vector_product(states, old_dist, model, damping=0.1):
    def FVP(v):
        kl = 0
        for state in states:
            new_dist = model.get_distribution(state)
            kl += torch.mean(torch.distributions.kl.kl_divergence(old_dist, new_dist))
        grads = torch.autograd.grad(kl, model.actor.parameters(), create_graph=True)
        flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])
        grad_kl_v = torch.dot(flat_grad_kl, v)
        grads2 = torch.autograd.grad(grad_kl_v, model.actor.parameters())
        flat_grads2 = torch.cat([grad.contiguous().view(-1) for grad in grads2])
        return flat_grads2 + damping * v
    return FVP

# === Сбор траекторий ===
def collect_trajectories(env, model, horizon=1000):
    states, actions, rewards, values, log_probs, masks = [], [], [], [], [], []

    state = env.reset()
    done = False
    total_reward = 0

    for step in range(horizon):
        state_tensor = torch.FloatTensor(state).to(device)
        with torch.no_grad():
            dist, value = model(state_tensor)
        action = Categorical(logits=dist).sample()
        log_prob = Categorical(logits=dist).log_prob(action)

        next_state, reward, terminated, truncated = env.step(action.item())
        done = terminated or truncated

        # Сохраняем данные
        states.append(state_tensor)
        actions.append(action)
        rewards.append(reward)
        values.append(value)
        log_probs.append(log_prob)
        masks.append(not done)

        state = next_state
        total_reward += reward

        if done:
            break

    # Последнее значение для GAE
    with torch.no_grad():
        next_value = model(torch.FloatTensor(next_state).to(device))[1]

    returns = discount_rewards(torch.tensor(rewards + [next_value.item()], device=device), gamma=0.99)[:-1]
    advantages = returns - torch.stack(values).squeeze()

    return (
        torch.stack(states),
        torch.stack(actions),
        torch.tensor(rewards, device=device),
        returns,
        torch.stack(log_probs),
        torch.tensor(masks, dtype=torch.float32, device=device),
        total_reward
    )

def discount_rewards(rewards, gamma):
    R = 0
    returns = []
    for r in reversed(rewards):
        R = r + gamma * R
        returns.insert(0, R)
    return torch.tensor(returns, device=device)

# === Обновление TRPO ===
def trpo_update(model, states, actions, old_log_probs, advantages, max_kl=0.01, cg_iters=10, damping=0.1):
    logits = model.actor(states)
    dist = Categorical(logits=logits)
    log_probs = dist.log_prob(actions)
    ratio = torch.exp(log_probs - old_log_probs)

    loss = -(ratio * advantages).mean()

    grads = torch.autograd.grad(loss, model.actor.parameters())
    flat_grad = torch.cat([grad.view(-1) for grad in grads])

    # Получаем FVP
    old_dist = Categorical(logits=logits.detach())
    fvp_fun = fisher_vector_product(states, old_dist, model, damping=damping)

    step_dir = conjugate_gradient(lambda v: fvp_fun(v), flat_grad, cg_iters)

    shs = 0.5 * torch.dot(step_dir, fvp_fun(step_dir))
    lm = torch.sqrt(max_kl / shs)
    full_step = step_dir * lm

    # Выполняем line search
    prev_params = flat_params(model.actor)
    success = False
    for j in range(10):
        new_params = prev_params - full_step * (0.5 ** j)
        assign_params(model.actor, new_params)

        logits_new = model.actor(states)
        if torch.isnan(logits_new).any():
          print("NaN detected in logits. Skipping update.")
          assign_params(model.actor, prev_params)
          return
        new_dist = Categorical(logits=logits_new)
        kl = torch.mean(torch.distributions.kl.kl_divergence(old_dist, new_dist))

        if kl <= max_kl:
            success = True
            break
        assign_params(model.actor, prev_params)

    # Обновляем critic
    values = model.critic(states).squeeze()
    critic_loss = F.mse_loss(values, advantages + values.detach())
    model.critic.zero_grad()
    critic_loss.backward()
    model.critic_optimizer.step()

# === Тестирование модели ===
def test_model(env, model, episodes=3):
    for _ in range(episodes):
        state = env.reset()
        done = False
        total_reward = 0
        while not done:
            with torch.no_grad():
                logits, _ = model(torch.FloatTensor(state).to(device))
                dist = Categorical(logits=logits)
                action = dist.sample()
            next_state, reward, terminated, truncated = env.step(action.item())
            done = terminated or truncated
            total_reward += reward
            state = next_state
        print(f"Test episode reward: {total_reward}")

# === Основной цикл обучения ===
def train_trpo():
    env_name = 'CartPole-v1'
    env = gym.make(env_name, render_mode="human")
    eval_env = gym.make(env_name, render_mode="human")
    num_inputs = env.observation_space.shape[0]
    num_actions = env.action_space.n

    model = ActorCritic(num_inputs, num_actions).to(device)
    model.critic_optimizer = optim.Adam(model.critic.parameters())

    for iteration in range(100):
        states, actions, rewards, returns, old_log_probs, masks, avg_reward = collect_trajectories(env, model)

        advantages = returns - model.critic(states).squeeze()
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        trpo_update(model, states, actions, old_log_probs, advantages, max_kl=0.01)

        print(f"Iteration {iteration}, Avg Reward: {avg_reward:.2f}")

        if iteration % 10 == 0:
            test_model(eval_env, model)

if __name__ == "__main__":
    train_trpo()

Iteration 0, Avg Reward: 10.00
Test episode reward: 24.0
Test episode reward: 13.0
Test episode reward: 14.0
Iteration 1, Avg Reward: 16.00
Iteration 2, Avg Reward: 16.00
Iteration 3, Avg Reward: 23.00
Iteration 4, Avg Reward: 15.00
Iteration 5, Avg Reward: 16.00
Iteration 6, Avg Reward: 10.00
NaN detected in logits. Skipping update.
Iteration 7, Avg Reward: 24.00
Iteration 8, Avg Reward: 49.00
NaN detected in logits. Skipping update.
Iteration 9, Avg Reward: 14.00
Iteration 10, Avg Reward: 11.00
Test episode reward: 15.0
Test episode reward: 16.0
Test episode reward: 40.0
Iteration 11, Avg Reward: 17.00
Iteration 12, Avg Reward: 44.00
Iteration 13, Avg Reward: 14.00
NaN detected in logits. Skipping update.
Iteration 14, Avg Reward: 42.00
Iteration 15, Avg Reward: 27.00
Iteration 16, Avg Reward: 20.00
NaN detected in logits. Skipping update.
Iteration 17, Avg Reward: 19.00
Iteration 18, Avg Reward: 39.00
NaN detected in logits. Skipping update.
Iteration 19, Avg Reward: 17.00
Iteration