In [2]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m1.1/1.1 MB[0m [31m45.7 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m29.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [19]:
import gym
import torch
import numpy as np
from torch.distributions import Categorical
from torch_geometric.data import Data
from itertools import permutations
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch_geometric.nn import GCNConv, LayerNorm, global_add_pool
import random


torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.cuda.empty_cache()


class Utils:
    @staticmethod
    def set_seed(seed):
        np.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)
        if torch.backends.cudnn.enabled:
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True


class GraphDataProcessor:
    @staticmethod
    def create_graph_data(data, index_pairs):
        edge_index = list(permutations(range(len(data)), 2))
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        node_features = torch.tensor([[data[i], data[j]] for i, j in index_pairs], dtype=torch.float)
        return Data(x=node_features, edge_index=edge_index)

    @staticmethod
    def construct_index_pairs(num_nodes):
        index_pairs = [(i, i + 1) for i in range(0, num_nodes - 1, 2)]
        index_pairs.extend((i, num_nodes - i - 1) for i in range(num_nodes // 2))
        return index_pairs



class EchoStateNetwork(nn.Module):
    def __init__(self, input_dim, reservoir_size, spectral_radius=0.9, sparsity=0.5, leaky_rate=0.2):
        super(EchoStateNetwork, self).__init__()
        self.reservoir_size = reservoir_size
        self.spectral_radius = spectral_radius
        self.leaky_rate = leaky_rate

        self.W_in = (torch.rand(reservoir_size, input_dim) - 0.5) * 2 / input_dim

        W = torch.rand(reservoir_size, reservoir_size) - 0.5
        mask = torch.rand(reservoir_size, reservoir_size) > sparsity
        W[mask] = 0

        eigenvector = torch.rand(reservoir_size, 1)
        for _ in range(50):  # Power iteration
            eigenvector = W @ eigenvector
            eigenvector = eigenvector / eigenvector.norm()
        max_eigenvalue = eigenvector.norm()
        self.W = W * (spectral_radius / max_eigenvalue)

        self.register_buffer("state", torch.zeros(reservoir_size))

    def forward(self, x):
        device = x.device
        self.state = self.state.to(device)
        self.W_in = self.W_in.to(device)
        self.W = self.W.to(device)

        self.state = (1 - self.leaky_rate) * self.state + self.leaky_rate * torch.tanh(self.W_in @ x + self.W @ self.state)
        self.state = self.state / (self.state.norm(dim=0, keepdim=True).clamp(min=1e-6))
        return self.state

class GraphReinforceAgent(nn.Module):
    def __init__(self, input_dimension, output_dimension, esn_reservoir_size=500, hidden_layer_dimension=128, learning_rate=0.0005):
        super(GraphReinforceAgent, self).__init__()
        self.esn = EchoStateNetwork(input_dim=input_dimension, reservoir_size=esn_reservoir_size)
        self.graph_convolution_layer = GCNConv(2, hidden_layer_dimension)
        self.hidden_linear_layer = nn.Linear(hidden_layer_dimension + esn_reservoir_size, hidden_layer_dimension)
        self.output_layer = nn.Linear(hidden_layer_dimension, output_dimension)
        self.normalization_layer = LayerNorm(hidden_layer_dimension)
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)
        self.scheduler = StepLR(self.optimizer, step_size=100, gamma=0.1)
        self.experience_memory = []

    def store_transition(self, transition):
        self.experience_memory.append(transition)

    def forward(self, node_features, edge_index, esn_state):
        node_features = F.relu(self.graph_convolution_layer(node_features, edge_index))
        node_features = global_add_pool(self.normalization_layer(node_features), torch.LongTensor([0] * 4).to(node_features.device))

        node_features = torch.cat((node_features, esn_state), dim=1)
        node_features = F.relu(self.hidden_linear_layer(node_features))
        output = self.output_layer(node_features)
        return F.log_softmax(output, dim=1)

    def optimize(self, discount_factor):
        cumulative_reward, discounted_rewards, running_discounted_reward = 0, [], 0
        for reward, log_prob in reversed(self.experience_memory):
            running_discounted_reward = reward + discount_factor * running_discounted_reward
            discounted_rewards.append(running_discounted_reward)
        discounted_rewards = np.array(discounted_rewards)
        rewards_mean, rewards_std_dev = discounted_rewards.mean(), discounted_rewards.std()
        self.optimizer.zero_grad()

        for reward, log_prob in reversed(self.experience_memory):
            cumulative_reward = reward + discount_factor * cumulative_reward
            policy_loss = -log_prob * ((cumulative_reward - rewards_mean) / rewards_std_dev)
            policy_loss.backward()
        self.optimizer.step()
        self.scheduler.step()
        self.experience_memory = []



if __name__ == "__main__":
    env = gym.make("CartPole-v1")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    seed = 1234
    Utils.set_seed(seed)
    env.seed(seed)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    learning_rate = 0.0005
    episodes = 500
    gamma = 0.99
    print_interval = 10

    agent = GraphReinforceAgent(state_dim, action_dim, esn_reservoir_size=500, hidden_layer_dimension=128, learning_rate=learning_rate).to(device)
    score_list = []
    index_pairs = GraphDataProcessor.construct_index_pairs(state_dim)

    for episode in range(episodes):
        state = env.reset()
        score = 0
        done = False

        while not done:
            esn_state = agent.esn(torch.tensor(state, dtype=torch.float).to(device))
            graph_data = GraphDataProcessor.create_graph_data(state, index_pairs)
            action_probs = agent(graph_data.x.to(device), graph_data.edge_index.to(device), esn_state.unsqueeze(0))


            action_distribution = Categorical(torch.exp(action_probs))
            action = action_distribution.sample()


            next_state, reward, done, _ = env.step(action.item())
            agent.store_transition((reward, action_probs[0][action]))
            state = next_state
            score += reward


        agent.optimize(gamma)
        score_list.append(score)

        if episode % print_interval == 0 and episode != 0:
            avg_score = sum(score_list[-print_interval:]) / print_interval
            print(f"Episode {episode}, Avg Score: {avg_score}")

    env.close()


Episode 10, Avg Score: 93.9
Episode 20, Avg Score: 476.5
Episode 30, Avg Score: 392.0
Episode 40, Avg Score: 319.3
Episode 50, Avg Score: 346.8
Episode 60, Avg Score: 407.7
Episode 70, Avg Score: 469.5
Episode 80, Avg Score: 454.9
Episode 90, Avg Score: 500.0
Episode 100, Avg Score: 500.0
Episode 110, Avg Score: 480.9
Episode 120, Avg Score: 500.0
Episode 130, Avg Score: 500.0
Episode 140, Avg Score: 500.0
Episode 150, Avg Score: 500.0
Episode 160, Avg Score: 488.1
Episode 170, Avg Score: 500.0
Episode 180, Avg Score: 500.0
Episode 190, Avg Score: 500.0
Episode 200, Avg Score: 500.0
Episode 210, Avg Score: 500.0
Episode 220, Avg Score: 500.0
Episode 230, Avg Score: 500.0
Episode 240, Avg Score: 500.0
Episode 250, Avg Score: 500.0
Episode 260, Avg Score: 500.0
Episode 270, Avg Score: 500.0
Episode 280, Avg Score: 500.0
Episode 290, Avg Score: 500.0
Episode 300, Avg Score: 500.0
Episode 310, Avg Score: 500.0
Episode 320, Avg Score: 497.4
Episode 330, Avg Score: 474.4
Episode 340, Avg Sco