In [1]:
import numpy as np
import torch
import torch.nn as nn
import gym

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.cuda.empty_cache()


def set_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


class BayesianReadout(nn.Module):
    def __init__(self, input_dim, output_dim, dropout_rate=0.2):
        super(BayesianReadout, self).__init__()
        self.fc = nn.Linear(input_dim, output_dim)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        return self.dropout(self.fc(x))


class EchoStateNetwork(nn.Module):
    def __init__(
        self, input_dim, reservoir_size, spectral_radius=0.9, sparsity=0.5, leaky_rate=0.2
    ):
        super(EchoStateNetwork, self).__init__()
        self.reservoir_size = reservoir_size
        self.spectral_radius = spectral_radius
        self.leaky_rate = leaky_rate

        self.W_in = (torch.rand(reservoir_size, input_dim) - 0.5) * 2 / input_dim

        W = torch.rand(reservoir_size, reservoir_size) - 0.5
        mask = torch.rand(reservoir_size, reservoir_size) > sparsity
        W[mask] = 0

        eigenvector = torch.rand(reservoir_size, 1)
        for _ in range(50):
            eigenvector = W @ eigenvector
            eigenvector = eigenvector / eigenvector.norm()
        max_eigenvalue = eigenvector.norm()
        self.W = W * (spectral_radius / max_eigenvalue)

        self.register_buffer("state", torch.zeros(reservoir_size))

        self.readout = BayesianReadout(reservoir_size, 2, dropout_rate=0.2)

    def forward(self, x):
        device = x.device
        self.state = self.state.to(device)
        self.W_in = self.W_in.to(device)
        self.W = self.W.to(device)

        self.state = (1 - self.leaky_rate) * self.state + self.leaky_rate * torch.tanh(
            self.W_in @ x + self.W @ self.state
        )
        self.state = self.state / (self.state.norm(dim=0, keepdim=True).clamp(min=1e-6))
        return self.readout(self.state)


class PolicyNetwork(nn.Module):
    def __init__(self, esn, action_dim):
        super(PolicyNetwork, self).__init__()
        self.esn = esn
        self.action_dim = action_dim
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        logits = self.esn(x)
        probabilities = self.softmax(logits)
        return probabilities


def train():
    SEED = 1234
    set_seed(SEED)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    env = gym.make("CartPole-v1")
    env.seed(SEED)

    input_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    reservoir_dim = 500

    esn = EchoStateNetwork(
        input_dim=input_dim,
        reservoir_size=reservoir_dim,
        spectral_radius=0.9,
        sparsity=0.5,
        leaky_rate=0.2
    ).to(device)

    policy = PolicyNetwork(esn, action_dim).to(device)

    optimizer = torch.optim.Adam(policy.parameters(), lr=0.01)
    gamma = 0.99
    num_samples = 50

    for episode in range(500):
        state = env.reset()
        state = torch.tensor(state, dtype=torch.float32).to(device)
        rewards = []
        log_probs = []

        done = False
        while not done:
            action_probs_samples = [policy(state) for _ in range(num_samples)]
            action_probs_mean = torch.stack(action_probs_samples).mean(0)

            dist = torch.distributions.Categorical(action_probs_mean)
            action = dist.sample()

            assert env.action_space.contains(action.item()), f"Invalid action: {action.item()}"

            log_prob = dist.log_prob(action)
            log_probs.append(log_prob)

            next_state, reward, done, _ = env.step(action.item())
            next_state = torch.tensor(next_state, dtype=torch.float32).to(device)
            rewards.append(reward)

            state = next_state

        discounted_rewards = []
        R = 0
        for r in reversed(rewards):
            R = r + gamma * R
            discounted_rewards.insert(0, R)
        discounted_rewards = torch.tensor(discounted_rewards, dtype=torch.float32).to(device)

        discounted_rewards = (discounted_rewards - discounted_rewards.mean()) / (discounted_rewards.std() + 1e-8)

        loss = 0
        for log_prob, R in zip(log_probs, discounted_rewards):
            loss += -log_prob * R

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if episode % 10 == 0:
            print(f"Episode {episode}, Total Reward: {sum(rewards)}")

    env.close()


if __name__ == "__main__":
    train()


  deprecation(
  deprecation(
  deprecation(
  if not isinstance(terminated, (bool, np.bool8)):


Episode 0, Total Reward: 8.0
Episode 10, Total Reward: 100.0
Episode 20, Total Reward: 86.0
Episode 30, Total Reward: 121.0
Episode 40, Total Reward: 107.0
Episode 50, Total Reward: 94.0
Episode 60, Total Reward: 298.0
Episode 70, Total Reward: 176.0
Episode 80, Total Reward: 386.0
Episode 90, Total Reward: 94.0
Episode 100, Total Reward: 91.0
Episode 110, Total Reward: 255.0
Episode 120, Total Reward: 475.0
Episode 130, Total Reward: 288.0
Episode 140, Total Reward: 296.0
Episode 150, Total Reward: 229.0
Episode 160, Total Reward: 475.0
Episode 170, Total Reward: 500.0
Episode 180, Total Reward: 500.0
Episode 190, Total Reward: 483.0
Episode 200, Total Reward: 500.0
Episode 210, Total Reward: 500.0
Episode 220, Total Reward: 500.0
Episode 230, Total Reward: 372.0
Episode 240, Total Reward: 500.0
Episode 250, Total Reward: 500.0
Episode 260, Total Reward: 178.0
Episode 270, Total Reward: 352.0
Episode 280, Total Reward: 425.0
Episode 290, Total Reward: 500.0
Episode 300, Total Reward: 