In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import gym
import numpy as np
import collections
import math
import tqdm

from CartpoleAgent import FOLCartpoleAgent

# Record types: a four-field observation and the replay transition that
# train() hands to agent.remember().
Observation = collections.namedtuple(
    "Observation", ["Position", "Velocity", "Angle", "AngVelocity"]
)
Transition = collections.namedtuple(
    "Transition", ["state", "action", "next_state", "reward", "done"]
)

# Epsilon-greedy schedule and run length consumed by train().
EPSILON_START, EPSILON_END = 1.0, 0.1
WARMUP_EPISODES = 25
NUM_EPISODES = 50

# Per-variable bin counts passed to the agent
# (presumably the number of bins on each side of zero -- TODO confirm in FOLCartpoleAgent).
n_bin_args = dict(n_pos=2, n_vel=2, n_ang=2, n_angvel=2)

# [low, high] range for each state variable, in the same order as the
# Observation fields.
limits = dict(
    Position=[-1.2, 1.2],
    Velocity=[-2, 2],
    Angle=[-0.2094395, 0.2094395],
    AngVelocity=[-3, 3],
)

# Module-level agent and environment shared by every cell below.
agent = FOLCartpoleAgent(n_bin_args, n_nodes=5, limits=limits, t="double")
env = gym.make("CartPole-v1")

def train():
    """Train the module-level ``agent`` on the module-level ``env``.

    Runs ``NUM_EPISODES`` episodes with epsilon-greedy exploration. Every
    environment step is stored with ``agent.remember`` and followed by one
    call to ``agent.optimize``. Epsilon starts at ``EPSILON_START`` and is
    lowered by a fixed amount after each episode that performed at least one
    optimisation step; it never drops below ``EPSILON_END``.

    Returns:
        tuple[list, list]: ``(episode_runs, episode_losses)``. For every
        episode with at least one optimisation step these hold the number of
        such steps and the mean loss over them, in episode order.
    """
    epsilon = EPSILON_START
    # Linear schedule: WARMUP_EPISODES decays take epsilon from start to end.
    epsilon_decay = (EPSILON_START - EPSILON_END) / WARMUP_EPISODES
    episode_runs = []
    episode_losses = []
    for episode in tqdm.tqdm(range(NUM_EPISODES)):
        total_loss, step = 0, 0
        state, info = env.reset()
        while True:
            # Explore with probability epsilon, otherwise ask the agent.
            if np.random.random() < epsilon:
                action = agent.sample_random_action()
            else:
                action = agent.get_action(state)
            next_state, reward, terminal, truncated, info = env.step(action)
            # Scale the step reward down; the terminating step earns nothing.
            reward = reward / 10 if not terminal else 0
            agent.remember(Transition(state, action, next_state, reward, terminal))
            loss = agent.optimize()
            state = next_state

            # Only steps for which optimize() produced a loss are counted, so
            # the reported "score" is the number of optimised steps.
            # NOTE(review): optimize() presumably returns None until enough
            # transitions are buffered -- confirm in FOLCartpoleAgent.
            if loss is not None:
                total_loss += loss
                step += 1

            if terminal or truncated:
                if step > 0:
                    mean_loss = total_loss / step
                    print(f"Run: {episode}, score: {step}, episode_loss: {mean_loss!s}")
                    episode_runs.append(step)
                    episode_losses.append(mean_loss)
                    # Clamp from below. The previous min() collapsed epsilon
                    # to EPSILON_END after the very first decay.
                    epsilon = max(epsilon - epsilon_decay, EPSILON_END)
                break
    return episode_runs, episode_losses

In [2]:
# Run the training loop. The saved output below ends in a KeyboardInterrupt
# at 49/50 episodes, i.e. this run was stopped before it completed.
train()

  return F.mse_loss(input, target, reduction=self.reduction)
 94%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▏      | 47/50 [00:21<00:01,  2.22it/s]

Run: 46, score: 13, episode_loss: 24.0


 96%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▍    | 48/50 [00:38<00:01,  1.05it/s]

Run: 47, score: 8, episode_loss: 24.06250023841858


 98%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▋  | 49/50 [01:15<00:02,  2.39s/it]

Run: 48, score: 21, episode_loss: 36.49827176048642


 98%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▋  | 49/50 [01:42<00:02,  2.09s/it]


KeyboardInterrupt: 

In [11]:
import itertools

In [15]:
# Chain each node's parameters() iterable into one lazy iterator. Nothing is
# consumed here, so the cell only displays the chain object itself.
itertools.chain.from_iterable([n.parameters() for n in agent.lnn.model.nodes.values()])

<itertools.chain at 0x22bff913eb0>

In [17]:
# parameters() hands back a generator, so this prints one generator repr per
# node rather than any parameter values.
for val in agent.lnn.model.nodes.values():
    print(val.parameters())

<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 0x0000022BFFA7DE70>
<generator object Formula.parameters at 

In [3]:
# Inspect the model's node table; the saved output shows integer keys mapped
# to formula objects (an And node followed by Predicate nodes).
agent.lnn.model.nodes

{0: <lnn.symbolic.logic.n_ary_neuron.And at 0x22bf882ab90>,
 1: <lnn.symbolic.logic.leaf_formula.Predicate at 0x22bf87baef0>,
 2: <lnn.symbolic.logic.leaf_formula.Predicate at 0x22bf87bae00>,
 3: <lnn.symbolic.logic.leaf_formula.Predicate at 0x22bf87baa10>,
 4: <lnn.symbolic.logic.leaf_formula.Predicate at 0x22bf87ba230>,
 5: <lnn.symbolic.logic.leaf_formula.Predicate at 0x22bf87bb370>,
 6: <lnn.symbolic.logic.leaf_formula.Predicate at 0x22bf87bba60>,
 7: <lnn.symbolic.logic.leaf_formula.Predicate at 0x22bf87f4ca0>,
 8: <lnn.symbolic.logic.leaf_formula.Predicate at 0x22bf87f4e20>,
 9: <lnn.symbolic.logic.leaf_formula.Predicate at 0x22bf87f4fd0>,
 10: <lnn.symbolic.logic.leaf_formula.Predicate at 0x22bf87f5180>,
 11: <lnn.symbolic.logic.leaf_formula.Predicate at 0x22bf87f5330>,
 12: <lnn.symbolic.logic.leaf_formula.Predicate at 0x22bf87f54e0>,
 13: <lnn.symbolic.logic.leaf_formula.Predicate at 0x22bf87f5690>,
 14: <lnn.symbolic.logic.leaf_formula.Predicate at 0x22bf87f5840>,
 15: <lnn.s

In [8]:
# Print the LNN model's formulas.
# NOTE(review): params=True presumably adds parameter values to the listing -- confirm against the lnn docs.
agent.lnn.model.print(params = True)


***************************************************************************
                                LNN Model

OPEN Or: ((pos1(0) ∧ notpos1(0) ∧ pos-1(0) ∧ notpos-1(0) ∧ pos2(0) ∧ notpos2(0) ∧ pos-2(0) ∧ notpos-2(0) ∧ pos3(0) ∧ notpos3(0) ∧ pos-3(0) ∧ notpos-3(0) ∧ pos4(0) ∧ notpos4(0) ∧ pos-4(0) ∧ notpos-4(0) ∧ pos5(0) ∧ notpos5(0) ∧ pos-5(0) ∧ notpos-5(0) ∧ vel1(0) ∧ notvel1(0) ∧ vel-1(0) ∧ notvel-1(0) ∧ vel2(0) ∧ notvel2(0) ∧ vel-2(0) ∧ notvel-2(0) ∧ vel3(0) ∧ notvel3(0) ∧ vel-3(0) ∧ notvel-3(0) ∧ vel4(0) ∧ notvel4(0) ∧ vel-4(0) ∧ notvel-4(0) ∧ vel5(0) ∧ notvel5(0) ∧ vel-5(0) ∧ notvel-5(0) ∧ ang1(0) ∧ notang1(0) ∧ ang-1(0) ∧ notang-1(0) ∧ ang2(0) ∧ notang2(0) ∧ ang-2(0) ∧ notang-2(0) ∧ ang3(0) ∧ notang3(0) ∧ ang-3(0) ∧ notang-3(0) ∧ ang4(0) ∧ notang4(0) ∧ ang-4(0) ∧ notang-4(0) ∧ ang5(0) ∧ notang5(0) ∧ ang-5(0) ∧ notang-5(0) ∧ ang6(0) ∧ notang6(0) ∧ ang-6(0) ∧ notang-6(0) ∧ ang7(0) ∧ notang7(0) ∧ ang-7(0) ∧ notang-7(0) ∧ ang8(0) ∧ notang8(0) ∧ ang-8(0) ∧ notang-8(0) ∧ ang9(

In [7]:
# Show the model's parameters; the saved output is a list of tensors.
agent.lnn.model.parameters()

[tensor(1.),
 tensor(1.),
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 ten

In [3]:
# Smoke test: reset the environment and ask the agent for an action in the
# resulting initial state.
# NOTE(review): binding `_` at cell level likely shadows IPython's
# last-output variable for the rest of the session -- confirm if that matters.
state, _ = env.reset()
action = agent.get_action(state)

0.0


In [9]:
def distill(self, n_bins=[4,4,4,4], lims=[[-1.2, 1.2], [-2, 2], [-0.2094395, 0.2094395], [-3, 3]]):
    """Tabulate an agent's chosen action over a regular grid of states.

    Each state dimension ``d`` is split into ``n_bins[d]`` equal-width bins
    across ``lims[d]``; the agent is queried once at the midpoint of every
    grid cell, in row-major order (last dimension varies fastest).

    Args:
        self: The agent to query. For a ``FOLCartpoleAgent`` the grid is
            taken from the agent itself (``2 * bin_args[key]`` bins per
            variable and its ``limits``), overriding ``n_bins`` and ``lims``,
            and ``get_action`` is called. Any other object must provide
            ``choose_action``.
        n_bins: Number of bins per state dimension.
        lims: ``[low, high]`` bounds per state dimension.

    Returns:
        torch.Tensor: Tensor of shape ``n_bins`` holding the value returned
        by the agent for each grid cell.
    """
    # Resolve the agent type once instead of on every grid point.
    is_fol_agent = isinstance(self, FOLCartpoleAgent)
    if is_fol_agent:  # this method can be directly copied for dqn agent, I think
        n_bins = [2 * self.bin_args[key] for key in self.bin_args]  # includes pos and neg
        lims = [self.limits[key] for key in self.limits]
    pick_action = self.get_action if is_fol_agent else self.choose_action

    lows = [lim[0] for lim in lims]
    steps = [(lim[1] - lim[0]) / n for lim, n in zip(lims, n_bins)]
    ret = torch.ones(n_bins)

    # np.ndindex walks the grid in the same order as nested loops, so one
    # progress bar covers the whole sweep and any number of dimensions works.
    for idx in tqdm.tqdm(np.ndindex(*n_bins), total=ret.numel()):
        # Bin midpoint computed directly from the indices: no in-place
        # accumulation or manual resets, and the agent gets a fresh tensor
        # on every query.
        state = torch.Tensor(
            [low + (i + 0.5) * step for low, i, step in zip(lows, idx, steps)]
        )
        ret[idx] = pick_action(state)
    return ret

In [10]:
# Tabulate the agent's action over its state grid and dump the full tensor.
policy = distill(agent)
print(policy)

  0%|                                                                                                                           | 0/4 [00:00<?, ?it/s]
  0%|                                                                                                                           | 0/4 [00:00<?, ?it/s][A
 25%|████████████████████████████▊                                                                                      | 1/4 [00:09<00:27,  9.11s/it][A
 50%|█████████████████████████████████████████████████████████▌                                                         | 2/4 [00:17<00:17,  8.73s/it][A
 75%|██████████████████████████████████████████████████████████████████████████████████████▎                            | 3/4 [00:25<00:08,  8.58s/it][A
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:34<00:00,  8.64s/it][A
 25%|████████████████████████████▊                                             

tensor([[[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]],


        [[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]],


        [[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],



