In [None]:
!pip install torch_geometric

In [4]:
import numpy as np
import torch
import torch.nn as nn
import gym


torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.cuda.empty_cache()


SEED = 1234

def set_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
    gym.utils.seeding.np_random(seed)


class EchoStateNetwork(nn.Module):
    def __init__(self, input_dim, reservoir_size, spectral_radius=0.9, sparsity=0.5, leaky_rate=0.2):
        super(EchoStateNetwork, self).__init__()
        self.reservoir_size = reservoir_size
        self.spectral_radius = spectral_radius
        self.leaky_rate = leaky_rate

        self.W_in = (torch.rand(reservoir_size, input_dim) - 0.5) * 2 / input_dim

        W = torch.rand(reservoir_size, reservoir_size) - 0.5
        mask = torch.rand(reservoir_size, reservoir_size) > sparsity
        W[mask] = 0

        eigenvector = torch.rand(reservoir_size, 1)
        for _ in range(50):
            eigenvector = W @ eigenvector
            eigenvector = eigenvector / eigenvector.norm()
        max_eigenvalue = eigenvector.norm()
        self.W = W * (spectral_radius / max_eigenvalue)

        self.register_buffer("state", torch.zeros(reservoir_size))

    def forward(self, x):
        device = x.device
        self.state = self.state.to(device)
        self.W_in = self.W_in.to(device)
        self.W = self.W.to(device)

        self.state = (1 - self.leaky_rate) * self.state + self.leaky_rate * torch.tanh(self.W_in @ x + self.W @ self.state)
        self.state = self.state / (self.state.norm(dim=0, keepdim=True).clamp(min=1e-6))
        return self.state


class MultiScaleESN(nn.Module):
    def __init__(self, input_dim, reservoir_dims, output_dim, spectral_radii, sparsities, leaky_rates, device="cpu"):
        super(MultiScaleESN, self).__init__()
        self.device = device
        self.reservoirs = nn.ModuleList([
            EchoStateNetwork(input_dim, res_dim, sr, sp, lr)
            for res_dim, sr, sp, lr in zip(reservoir_dims, spectral_radii, sparsities, leaky_rates)
        ])
        self.readout = nn.Linear(sum(reservoir_dims), output_dim).to(device)

    def forward(self, x):
        x = x.to(self.device)
        states = [reservoir(x) for reservoir in self.reservoirs]
        combined_state = torch.cat(states, dim=-1)
        output = self.readout(combined_state)
        return output


class PolicyNetwork(nn.Module):
    def __init__(self, multiscale_esn, action_space):
        super(PolicyNetwork, self).__init__()
        self.multiscale_esn = multiscale_esn
        self.action_space = action_space
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        logits = self.multiscale_esn(x)
        probabilities = self.softmax(logits)
        return probabilities


def train():
    set_seed(SEED)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    env = gym.make("CartPole-v1")
    env.seed(SEED)

    input_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    # Define multi-scale reservoir parameters
    reservoir_dims = [300, 400, 500]  # Different sizes for reservoirs
    spectral_radii = [0.9, 0.9, 0.9]  # Different spectral radii
    sparsities = [0.5, 0.5, 0.5]     # Different sparsities
    leaky_rates = [0.2, 0.3, 0.4]    # Different leaky rates

    multiscale_esn = MultiScaleESN(input_dim, reservoir_dims, action_dim, spectral_radii, sparsities, leaky_rates, device).to(device)
    policy = PolicyNetwork(multiscale_esn, action_dim).to(device)

    optimizer = torch.optim.Adam(policy.parameters(), lr=0.01)
    gamma = 0.99


    for episode in range(500):
        state = env.reset()
        state = torch.tensor(state, dtype=torch.float32, device=device)
        rewards = []
        log_probs = []

        done = False
        while not done:
            action_probs = policy(state)


            assert action_probs.shape[0] == env.action_space.n, \
                f"Action probabilities dimension {action_probs.shape[0]} does not match action space {env.action_space.n}"


            dist = torch.distributions.Categorical(probs=action_probs)
            action = dist.sample()

            log_prob = dist.log_prob(action)
            log_probs.append(log_prob)

            next_state, reward, done, _ = env.step(action.item())
            next_state = torch.tensor(next_state, dtype=torch.float32, device=device)
            rewards.append(reward)

            state = next_state


        discounted_rewards = []
        R = 0
        for r in reversed(rewards):
            R = r + gamma * R
            discounted_rewards.insert(0, R)
        discounted_rewards = torch.tensor(discounted_rewards, dtype=torch.float32, device=device)


        discounted_rewards = (discounted_rewards - discounted_rewards.mean()) / (discounted_rewards.std() + 1e-8)


        loss = 0
        for log_prob, R in zip(log_probs, discounted_rewards):
            loss += -log_prob * R


        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


        if episode % 10 == 0:
            print(f"Episode {episode}, Total Reward: {sum(rewards)}")

    env.close()


train()


Using device: cuda
Episode 0, Total Reward: 11.0
Episode 10, Total Reward: 188.0
Episode 20, Total Reward: 175.0
Episode 30, Total Reward: 500.0
Episode 40, Total Reward: 248.0
Episode 50, Total Reward: 172.0
Episode 60, Total Reward: 252.0
Episode 70, Total Reward: 139.0
Episode 80, Total Reward: 404.0
Episode 90, Total Reward: 429.0
Episode 100, Total Reward: 312.0
Episode 110, Total Reward: 500.0
Episode 120, Total Reward: 173.0
Episode 130, Total Reward: 144.0
Episode 140, Total Reward: 97.0
Episode 150, Total Reward: 138.0
Episode 160, Total Reward: 110.0
Episode 170, Total Reward: 82.0
Episode 180, Total Reward: 71.0
Episode 190, Total Reward: 94.0
Episode 200, Total Reward: 71.0
Episode 210, Total Reward: 77.0
Episode 220, Total Reward: 138.0
Episode 230, Total Reward: 77.0
Episode 240, Total Reward: 79.0
Episode 250, Total Reward: 105.0
Episode 260, Total Reward: 124.0
Episode 270, Total Reward: 127.0
Episode 280, Total Reward: 105.0
Episode 290, Total Reward: 161.0
Episode 300