In [1]:
import random

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch

from leap_ec import ops
from leap_ec.algorithm import generational_ea
from leap_ec.executable_rep.executable import ArgmaxExecutable, WrapperDecoder
from leap_ec.executable_rep.problems import EnvironmentProblem
from leap_ec.probe import FitnessPlotProbe
from leap_ec.representation import Representation

from leap_torch.decoders import NumpyDecoder
from leap_torch.initializers import create_instance
from leap_torch.ops import mutate_guassian, uniform_crossover

%matplotlib qt5

In [2]:
# Device passed to the leap_torch decoder for the evolved networks.
device = "cpu"

# Evolutionary-algorithm budget.
POP_SIZE = 100
GENERATIONS = 250

# Passed to EnvironmentProblem as its first two positional arguments:
# presumably the number of episodes per fitness evaluation and the maximum
# number of environment steps per episode — confirm against leap_ec docs.
RUNS_PER_FITNESS_EVAL = 5
SIMULATION_STEPS = 500

# Seed every RNG the run draws from (stdlib `random`, NumPy and torch) so the
# evolutionary run is repeatable.
# NOTE(review): the Gymnasium environment is not seeded here, so episode
# start states may still vary between runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

QSocketNotifier: Can only be used with threads started with QThread


In [3]:
import torch
import torch.nn as nn
import snntorch as snn


# Define Network
class SpikingNetwork(nn.Module):
    """Two-layer spiking network: Linear -> Leaky LIF -> Linear -> Leaky LIF.

    The same input ``x`` is presented on every one of ``num_steps`` simulation
    steps, and the output-layer spikes are summed over those steps.

    Args:
        num_inputs: Size of the last dimension of the input tensor.
        num_hidden: Number of hidden LIF neurons.
        num_outputs: Number of output LIF neurons.
        start_beta: Initial beta passed to both ``snn.Leaky`` layers.
        num_steps: Number of simulation steps per forward pass.
    """

    def __init__(self, num_inputs, num_hidden, num_outputs, start_beta=0.1, num_steps=10):
        super().__init__()

        self.num_steps = num_steps

        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        # Re-initialise every weight and bias of both layers from U(-1, 1),
        # replacing PyTorch's default Linear initialisation.
        with torch.no_grad():
            for layer in (self.fc1, self.fc2):
                for param in layer.parameters():
                    nn.init.uniform_(param, -1, 1)

        # learn_beta / learn_threshold presumably register beta and threshold
        # as module parameters, so they become part of parameters() as well.
        self.lif1 = snn.Leaky(start_beta, learn_beta=True, learn_threshold=True)
        self.lif2 = snn.Leaky(start_beta, learn_beta=True, learn_threshold=True)

    def forward(self, x):
        """Simulate ``num_steps`` steps on ``x`` and return summed output spikes.

        Args:
            x: Input tensor whose last dimension is ``num_inputs``; leading
                dimensions are carried through by the Linear layers.

        Returns:
            Tensor of shape ``x.shape[:-1] + (num_outputs,)`` holding the
            output-layer spikes summed over all steps (all zeros when
            ``num_steps`` is 0).
        """
        # Fresh membrane state on every call: nothing carries over between
        # forward passes.
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # x is identical on every step, so its projection is computed once.
        cur1 = self.fc1(x)

        # Allocate the accumulator up front so num_steps == 0 returns zeros
        # instead of None.
        spk2_sum = torch.zeros(
            *cur1.shape[:-1], self.fc2.out_features,
            dtype=cur1.dtype, device=cur1.device,
        )

        for _ in range(self.num_steps):
            spk1, mem1 = self.lif1(cur1, mem1)
            spk2, mem2 = self.lif2(self.fc2(spk1), mem2)
            spk2_sum += spk2

        return spk2_sum

In [4]:
# Width of the spiking network's hidden layer.
HIDDEN_NEURONS = 20

env = gym.make("CartPole-v1")

# NumpyDecoder presumably turns a genome into a torch model on `device`;
# ArgmaxExecutable presumably chooses the action with the largest network
# output — confirm against leap_torch / leap_ec docs.
decoder = WrapperDecoder(
    wrapped_decoder=NumpyDecoder(device=device), decorator=ArgmaxExecutable
)

# Draw on a fresh figure so re-running this cell does not plot onto stale axes.
fig, ax = plt.subplots()
# NOTE(review): ylim=(0, 1) is kept from the original; CartPole rewards are far
# larger, so this relies on the probe growing its limits — confirm.
plot_probe = FitnessPlotProbe(
    ylim=(0, 1), xlim=(0, GENERATIONS),
    modulo=1, ax=ax
)

# Keep the return value (presumably the final population) instead of letting
# the notebook print its repr as cell output.
final_population = generational_ea(
    max_generations=GENERATIONS, pop_size=POP_SIZE,

    problem=EnvironmentProblem(
        RUNS_PER_FITNESS_EVAL, SIMULATION_STEPS,
        environment=env, fitness_type="reward", gui=False
    ),

    representation=Representation(
        initialize=create_instance(
            SpikingNetwork,
            env.observation_space.shape[0], HIDDEN_NEURONS, env.action_space.n
        ), decoder=decoder
    ),

    pipeline=[
        ops.tournament_selection,
        ops.clone,
        mutate_guassian(std=0.05, p_mutate=0.01),
        uniform_crossover(),
        ops.evaluate,
        ops.pool(size=POP_SIZE),
        plot_probe,
    ]
)




qt.qpa.wayland: Failed to initialize EGL display 3001
qt.qpa.wayland: Wayland does not support QWindow::requestActivate()
qt.qpa.wayland: Wayland does not support QWindow::requestActivate()
qt.qpa.wayland: Wayland does not support QWindow::requestActivate()
qt.qpa.wayland: Wayland does not support QWindow::requestActivate()
qt.qpa.wayland: Wayland does not support QWindow::requestActivate()
qt.qpa.wayland: Wayland does not support QWindow::requestActivate()
qt.qpa.wayland: Wayland does not support QWindow::requestActivate()
qt.qpa.wayland: Wayland does not support QWindow::requestActivate()
qt.qpa.wayland: Wayland does not support QWindow::requestActivate()
qt.qpa.wayland: Wayland does not support QWindow::requestActivate()
qt.qpa.wayland: Wayland does not support QWindow::requestActivate()
qt.qpa.wayland: Wayland does not support QWindow::requestActivate()
qt.qpa.wayland: Wayland does not support QWindow::requestActivate()
qt.qpa.wayland: Wayland does not support QWindow::requestAct

: 

: 