In [4]:
import numpy as np
import torch
import torch.nn as nn
import gym

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.cuda.empty_cache()

SEED = 1234


def set_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    gym.utils.seeding.np_random(seed)



class EchoStateNetwork(nn.Module):
    def __init__(
        self, input_dim, reservoir_size, spectral_radius=0.9, sparsity=0.5, leaky_rate=0.2
    ):
        super(EchoStateNetwork, self).__init__()
        self.reservoir_size = reservoir_size
        self.spectral_radius = spectral_radius
        self.leaky_rate = leaky_rate


        self.W_in = (torch.rand(reservoir_size, input_dim) - 0.5) * 2 / input_dim


        W = torch.rand(reservoir_size, reservoir_size) - 0.5
        mask = torch.rand(reservoir_size, reservoir_size) > sparsity
        W[mask] = 0


        eigenvector = torch.rand(reservoir_size, 1)
        for _ in range(50):
            eigenvector = W @ eigenvector
            eigenvector = eigenvector / eigenvector.norm()
        max_eigenvalue = eigenvector.norm()
        self.W = W * (spectral_radius / max_eigenvalue)


        self.register_buffer("state", torch.zeros(reservoir_size))

    def forward(self, x):
        device = x.device
        self.state = self.state.to(device)
        self.W_in = self.W_in.to(device)
        self.W = self.W.to(device)

        self.state = (1 - self.leaky_rate) * self.state + self.leaky_rate * torch.tanh(
            self.W_in @ x + self.W @ self.state
        )
        self.state = self.state / (self.state.norm(dim=0, keepdim=True).clamp(min=1e-6))
        return self.state



class SymbolicReasoningModule:
    def __init__(self, device):
        self.device = device
        self.rules = {
            "pole_angle": lambda angle: [0.9, 0.1] if angle < -0.1 else ([0.1, 0.9] if angle > 0.1 else [0.5, 0.5]),
            "cart_position": lambda pos: [0.8, 0.2] if pos < -1 else ([0.2, 0.8] if pos > 1 else [0.5, 0.5])
        }

    def forward(self, state):
        pole_angle = state[2].item()
        cart_position = state[0].item()

        angle_output = self.rules["pole_angle"](pole_angle)
        position_output = self.rules["cart_position"](cart_position)

        symbolic_output = [(a + b) / 2 for a, b in zip(angle_output, position_output)]
        return torch.tensor(symbolic_output, dtype=torch.float32, device=self.device)

    def refine_rules(self, feedback):
        for key in self.rules:
            if feedback[key] < 0:
                self.rules[key] = lambda x: [0.6, 0.4] if x < -0.1 else ([0.4, 0.6] if x > 0.1 else [0.5, 0.5])



class NeuroSymbolicEchoStateNetwork(nn.Module):
    def __init__(self, input_dim, reservoir_dim, output_dim, device, symbolic_dim=2, spectral_radius=0.9, sparsity=0.1, leaky_rate=0.2):
        super(NeuroSymbolicEchoStateNetwork, self).__init__()
        self.device = device


        self.esn = EchoStateNetwork(
            input_dim=input_dim,
            reservoir_size=reservoir_dim,
            spectral_radius=spectral_radius,
            sparsity=sparsity,
            leaky_rate=leaky_rate,
        )


        self.symbolic_module = SymbolicReasoningModule(device=device)


        self.readout = nn.Linear(reservoir_dim + symbolic_dim, output_dim).to(device)

    def forward(self, x):

        esn_output = self.esn(x)

        symbolic_output = self.symbolic_module.forward(x)

        combined_input = torch.cat((esn_output, symbolic_output))

        output = self.readout(combined_input)
        return output

    def refine_symbolic_rules(self, feedback):
        self.symbolic_module.refine_rules(feedback)



class PolicyNetwork(nn.Module):
    def __init__(self, esn, action_space):
        super(PolicyNetwork, self).__init__()
        self.esn = esn
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        logits = self.esn(x)
        probabilities = self.softmax(logits)
        return probabilities



def train():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    set_seed(SEED)

    env = gym.make("CartPole-v1")
    env.seed(SEED)

    input_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    reservoir_dim = 500

    esn = NeuroSymbolicEchoStateNetwork(
        input_dim=input_dim,
        reservoir_dim=reservoir_dim,
        output_dim=action_dim,
        device=device,
    )
    policy = PolicyNetwork(esn, action_dim).to(device)

    optimizer = torch.optim.Adam(policy.parameters(), lr=0.01)
    gamma = 0.99

    for episode in range(500):
        state = env.reset()
        state = torch.tensor(state, dtype=torch.float32, device=device)
        rewards = []
        log_probs = []
        feedback = {"pole_angle": 0, "cart_position": 0}

        done = False
        while not done:
            action_probs = policy(state)
            dist = torch.distributions.Categorical(action_probs)
            action = dist.sample()

            log_prob = dist.log_prob(action)
            log_probs.append(log_prob)

            next_state, reward, done, _ = env.step(action.item())
            next_state = torch.tensor(next_state, dtype=torch.float32, device=device)
            rewards.append(reward)

            feedback["pole_angle"] += reward if next_state[2].abs() < 0.1 else -1
            feedback["cart_position"] += reward if next_state[0].abs() < 1 else -1

            state = next_state

        esn.refine_symbolic_rules(feedback)

        discounted_rewards = []
        R = 0
        for r in reversed(rewards):
            R = r + gamma * R
            discounted_rewards.insert(0, R)
        discounted_rewards = torch.tensor(discounted_rewards, dtype=torch.float32, device=device)

        discounted_rewards = (discounted_rewards - discounted_rewards.mean()) / (discounted_rewards.std() + 1e-8)

        loss = 0
        for log_prob, R in zip(log_probs, discounted_rewards):
            loss += -log_prob * R

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if episode % 10 == 0:
            print(f"Episode {episode}, Total Reward: {sum(rewards)}")

    env.close()


train()


Using device: cuda
Episode 0, Total Reward: 11.0
Episode 10, Total Reward: 56.0
Episode 20, Total Reward: 82.0
Episode 30, Total Reward: 112.0
Episode 40, Total Reward: 163.0
Episode 50, Total Reward: 305.0
Episode 60, Total Reward: 108.0
Episode 70, Total Reward: 317.0
Episode 80, Total Reward: 500.0
Episode 90, Total Reward: 500.0
Episode 100, Total Reward: 134.0
Episode 110, Total Reward: 72.0
Episode 120, Total Reward: 68.0
Episode 130, Total Reward: 156.0
Episode 140, Total Reward: 100.0
Episode 150, Total Reward: 128.0
Episode 160, Total Reward: 201.0
Episode 170, Total Reward: 500.0
Episode 180, Total Reward: 500.0
Episode 190, Total Reward: 500.0
Episode 200, Total Reward: 500.0
Episode 210, Total Reward: 500.0
Episode 220, Total Reward: 500.0
Episode 230, Total Reward: 500.0
Episode 240, Total Reward: 390.0
Episode 250, Total Reward: 500.0
Episode 260, Total Reward: 500.0
Episode 270, Total Reward: 500.0
Episode 280, Total Reward: 500.0
Episode 290, Total Reward: 500.0
Episode