In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import gym
import numpy as np
import collections
import math
import tqdm

from CartpoleAgent import FOLCartpoleAgent

Observation = collections.namedtuple("Observation", ("Position", "Velocity", "Angle", "AngVelocity"))
Transition = collections.namedtuple("Transition", ("state", "action", "next_state", "reward", "done"))

EPSILON_START = 1.0
EPSILON_END = 0.1
WARMUP_EPISODES = 250
NUM_EPISODES = 500

n_bin_args = {
    "n_pos": 5,
    "n_vel": 5,
    "n_ang": 10,
    "n_angvel": 5
}

limits = {
    "Position": [-1.2, 1.2],
    "Velocity": [-2, 2],
    "Angle": [-0.2094395, 0.2094395],
    "AngVelocity": [-3, 3]
}

agent = FOLCartpoleAgent(n_bin_args, n_nodes = 5, limits = limits, t="double")
env = gym.make("CartPole-v1")

def train():
    epsilon = EPSILON_START
    episode_num = 0
    episode_runs = []
    episode_losses = []
    for episode in tqdm.tqdm(range(NUM_EPISODES)):
        total_loss, step = 0,0
        state, info = env.reset()
        while True:
            if np.random.random() < epsilon:
                action = agent.sample_random_action()
            else:
                action = agent.get_action(state)
            next_state, reward, terminal, truncated, info = env.step(action)
            reward = reward/10 if not terminal else 0
            agent.remember(Transition(state, action, next_state, reward, terminal))
            loss = agent.optimize()
            state = next_state
            
            if loss is not None:
                total_loss += loss
                step += 1
            
            if terminal or truncated:
                if step > 0:
                    print("Run: " + str(episode) + ", score: " + str(step) + ", episode_loss: " + str(total_loss/step))
                    episode_runs.append(step)
                    episode_losses.append(total_loss/step)
                    epsilon -= (EPSILON_START - EPSILON_END)/WARMUP_EPISODES
                    
                    epsilon = min(epsilon, EPSILON_END)
                break

In [None]:
train()

  8%|████▉                                                          | 39/500 [00:43<08:32,  1.11s/it]

Run: 38, score: 14, episode_loss: 24.0


  8%|█████                                                          | 40/500 [01:37<22:50,  2.98s/it]

Run: 39, score: 14, episode_loss: 30.540901592799596


  8%|█████▏                                                         | 41/500 [03:13<57:07,  7.47s/it]

Run: 40, score: 28, episode_loss: 38.50791159697941


  8%|█████                                                        | 42/500 [06:03<2:16:25, 17.87s/it]

Run: 41, score: 48, episode_loss: 37.240512082974114


  9%|█████▏                                                       | 43/500 [06:53<2:37:12, 20.64s/it]

Run: 42, score: 13, episode_loss: 37.913462565495415


  9%|█████▎                                                       | 44/500 [07:53<3:08:24, 24.79s/it]

Run: 43, score: 16, episode_loss: 36.32812616229057


  9%|█████▍                                                       | 45/500 [08:38<3:29:01, 27.56s/it]

Run: 44, score: 13, episode_loss: 38.062186864706185


  9%|█████▌                                                       | 46/500 [10:56<5:41:39, 45.15s/it]

Run: 45, score: 37, episode_loss: 40.21743051425831


  9%|█████▋                                                       | 47/500 [11:35<5:31:55, 43.96s/it]

Run: 46, score: 11, episode_loss: 34.80129367654974


 10%|█████▊                                                       | 48/500 [12:44<6:10:43, 49.21s/it]

Run: 47, score: 18, episode_loss: 38.443033615748085


 10%|█████▉                                                       | 49/500 [13:46<6:32:10, 52.17s/it]

Run: 48, score: 19, episode_loss: 35.50263249246698


 10%|██████                                                       | 50/500 [14:38<6:30:21, 52.05s/it]

Run: 49, score: 15, episode_loss: 45.862502098083496


 10%|██████                                                      | 51/500 [17:19<10:03:18, 80.62s/it]

Run: 50, score: 38, episode_loss: 37.58060119026586


 10%|██████▎                                                      | 52/500 [18:11<9:03:07, 72.74s/it]

Run: 51, score: 13, episode_loss: 40.59615245232215


 11%|██████▍                                                      | 53/500 [18:58<8:07:46, 65.47s/it]

Run: 52, score: 12, episode_loss: 35.336799462636314


 11%|██████▌                                                      | 54/500 [19:40<7:18:17, 58.96s/it]

Run: 53, score: 11, episode_loss: 37.5724965875799


 11%|██████▋                                                      | 55/500 [20:57<7:55:11, 64.07s/it]

Run: 54, score: 24, episode_loss: 36.98948089281718


 11%|██████▊                                                      | 56/500 [22:55<9:51:18, 79.91s/it]

Run: 55, score: 32, episode_loss: 34.04008932411671


 11%|██████▉                                                      | 57/500 [23:44<8:43:15, 70.87s/it]

Run: 56, score: 14, episode_loss: 37.04136858667646


 12%|███████                                                      | 58/500 [24:36<7:59:11, 65.05s/it]

Run: 57, score: 12, episode_loss: 35.45833174387614


 12%|███████▏                                                     | 59/500 [25:28<7:31:01, 61.36s/it]

Run: 58, score: 11, episode_loss: 38.71136171167547


 12%|███████▎                                                     | 60/500 [26:25<7:19:47, 59.97s/it]

Run: 59, score: 13, episode_loss: 40.350306254166824


 12%|███████▍                                                     | 61/500 [28:25<9:30:35, 77.99s/it]

Run: 60, score: 31, episode_loss: 37.50201479081185


 12%|███████▌                                                     | 62/500 [29:27<8:54:40, 73.24s/it]

Run: 61, score: 18, episode_loss: 38.78876119189792


 13%|███████▋                                                     | 63/500 [30:36<8:44:02, 71.95s/it]

Run: 62, score: 17, episode_loss: 36.7014705433565


 13%|███████▊                                                     | 64/500 [31:43<8:32:15, 70.49s/it]

Run: 63, score: 14, episode_loss: 35.535711901528494


 13%|███████▉                                                     | 65/500 [33:31<9:53:02, 81.80s/it]

Run: 64, score: 25, episode_loss: 40.76333148956299


 13%|████████                                                     | 66/500 [34:31<9:03:56, 75.20s/it]

Run: 65, score: 13, episode_loss: 40.249997249016396


 13%|████████▏                                                    | 67/500 [35:34<8:34:35, 71.31s/it]

Run: 66, score: 13, episode_loss: 41.04486850591806


 14%|████████                                                   | 68/500 [39:15<13:56:51, 116.23s/it]

Run: 67, score: 52, episode_loss: 37.95855494645926


 14%|████████▏                                                  | 69/500 [43:52<19:41:36, 164.49s/it]

Run: 68, score: 62, episode_loss: 38.21853777670091


 14%|████████▎                                                  | 70/500 [46:27<19:18:13, 161.61s/it]

Run: 69, score: 29, episode_loss: 38.41665879611311


 14%|████████▍                                                  | 71/500 [49:14<19:27:09, 163.24s/it]

Run: 70, score: 42, episode_loss: 37.76663038844154


In [11]:
import itertools

In [15]:
itertools.chain.from_iterable([n.parameters() for n in agent.lnn.model.nodes.values()])

<itertools.chain at 0x22bff913eb0>

In [17]:
for val in agent.lnn.model.nodes.values():
    print(val.parameters())

<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 

In [3]:
agent.lnn.model.nodes

{0: <lnn.symbolic.logic.n_ary_neuron.And at 0x22bf882ab90>,
 1: <lnn.symbolic.logic.leaf_formula.Predicate at 0x22bf87baef0>,
 2: <lnn.symbolic.logic.leaf_formula.Predicate at 0x22bf87bae00>,
 3: <lnn.symbolic.logic.leaf_formula.Predicate at 0x22bf87baa10>,
 4: <lnn.symbolic.logic.leaf_formula.Predicate at 0x22bf87ba230>,
 5: <lnn.symbolic.logic.leaf_formula.Predicate at 0x22bf87bb370>,
 6: <lnn.symbolic.logic.leaf_formula.Predicate at 0x22bf87bba60>,
 7: <lnn.symbolic.logic.leaf_formula.Predicate at 0x22bf87f4ca0>,
 8: <lnn.symbolic.logic.leaf_formula.Predicate at 0x22bf87f4e20>,
 9: <lnn.symbolic.logic.leaf_formula.Predicate at 0x22bf87f4fd0>,
 10: <lnn.symbolic.logic.leaf_formula.Predicate at 0x22bf87f5180>,
 11: <lnn.symbolic.logic.leaf_formula.Predicate at 0x22bf87f5330>,
 12: <lnn.symbolic.logic.leaf_formula.Predicate at 0x22bf87f54e0>,
 13: <lnn.symbolic.logic.leaf_formula.Predicate at 0x22bf87f5690>,
 14: <lnn.symbolic.logic.leaf_formula.Predicate at 0x22bf87f5840>,
 15: <lnn.s

In [8]:
agent.lnn.model.print(params = True)


***************************************************************************
                                LNN Model

OPEN Or: ((pos1(0) ∧ notpos1(0) ∧ pos-1(0) ∧ notpos-1(0) ∧ pos2(0) ∧ notpos2(0) ∧ pos-2(0) ∧ notpos-2(0) ∧ pos3(0) ∧ notpos3(0) ∧ pos-3(0) ∧ notpos-3(0) ∧ pos4(0) ∧ notpos4(0) ∧ pos-4(0) ∧ notpos-4(0) ∧ pos5(0) ∧ notpos5(0) ∧ pos-5(0) ∧ notpos-5(0) ∧ vel1(0) ∧ notvel1(0) ∧ vel-1(0) ∧ notvel-1(0) ∧ vel2(0) ∧ notvel2(0) ∧ vel-2(0) ∧ notvel-2(0) ∧ vel3(0) ∧ notvel3(0) ∧ vel-3(0) ∧ notvel-3(0) ∧ vel4(0) ∧ notvel4(0) ∧ vel-4(0) ∧ notvel-4(0) ∧ vel5(0) ∧ notvel5(0) ∧ vel-5(0) ∧ notvel-5(0) ∧ ang1(0) ∧ notang1(0) ∧ ang-1(0) ∧ notang-1(0) ∧ ang2(0) ∧ notang2(0) ∧ ang-2(0) ∧ notang-2(0) ∧ ang3(0) ∧ notang3(0) ∧ ang-3(0) ∧ notang-3(0) ∧ ang4(0) ∧ notang4(0) ∧ ang-4(0) ∧ notang-4(0) ∧ ang5(0) ∧ notang5(0) ∧ ang-5(0) ∧ notang-5(0) ∧ ang6(0) ∧ notang6(0) ∧ ang-6(0) ∧ notang-6(0) ∧ ang7(0) ∧ notang7(0) ∧ ang-7(0) ∧ notang-7(0) ∧ ang8(0) ∧ notang8(0) ∧ ang-8(0) ∧ notang-8(0) ∧ ang9(

In [7]:
agent.lnn.model.parameters()

[tensor(1.),
 tensor(1.),
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 ten

In [3]:
state, _ = env.reset()
action = agent.get_action(state)

0.0
