In [27]:
import collections
import math

import gym
import numpy as np
import torch
import torch.nn as nn
from lnn import (And, Fact, Implies, Loss, Model, Or, Predicate, Predicates,
                 Propositions, Variable, Variables, World)

In [33]:
# Named view of one CartPole observation vector. Field order is assumed to
# match the environment's observation indices 0..3 (the probe cell reads
# index 1 as velocity and index 3 as angular velocity), so `Observation(*obs)`
# labels each component.
Observation = collections.namedtuple(
    "Observation",
    ["Position", "Velocity", "Angle", "AngVelocity"],
)

In [2]:
# Single CartPole environment shared by the probing and agent cells. Later
# cells unpack five values from `env.step` (terminated/truncated API,
# presumably gym >= 0.26).
env_id = "CartPole-v1"
env = gym.make(env_id)

In [87]:
# Propositional warm-up: assert (p1 ∧ p2) ∧ (p1 ∨ p2) → p3 as an axiom, feed
# p1 = p2 = TRUE, then train against the label p3 = FALSE.
# NOTE(review): with p1 and p2 TRUE the antecedent is TRUE, so the axiom and
# the label disagree about p3 — presumably intentional, to show supervised
# training resolving a contradiction; confirm.
model = Model()

p1, p2, p3 = Propositions("p1", "p2", "p3")
a1 = And(p1, p2)
o1 = Or(p1, p2)
a2 = And(a1, o1)
i1 = Implies(a2, p3, world=World.AXIOM)

model.add_knowledge(a1, a2, o1, i1)
model.add_data({p1: Fact.TRUE, p2: Fact.TRUE})
model.add_labels({p3: Fact.FALSE})

model.train(losses=[Loss.SUPERVISED])
model.print(params=True)

In [174]:
## How to train a model: first-order example with one free variable X.
model = Model()
p1, p2, p3 = Predicates("P1", "P2", "P3")
x = Variable("X")

# A single conjunction over the shared variable X, evaluated per grounding
# ('0' and '1').
a1 = And(p1(x), p2(x), p3(x))
model.add_knowledge(a1)

# Observed truth value of each predicate for each grounding.
model.add_data({
    p1: {"0": Fact.FALSE, "1": Fact.FALSE},
    p2: {"0": Fact.TRUE, "1": Fact.TRUE},
    p3: {"0": Fact.TRUE, "1": Fact.FALSE},
})
# Both groundings are labelled TRUE even though P1 is FALSE for both (and P3
# for '1'); the saved output of this cell shows training driving the weights
# of those operands to 0 (w: [0. 1. 0.]).
model.add_labels({a1: {"0": Fact.TRUE, "1": Fact.TRUE}})

model.train(losses=[Loss.SUPERVISED])
model.print(params=True)





***************************************************************************
                                LNN Model

OPEN And: (P1(0) ∧ P2(0) ∧ P3(0)) 
params  α: 1.0,  β: 1.0,  w: [0. 1. 0.]
'0'                                                         TRUE (1.0, 1.0)
'1'                                                         TRUE (1.0, 1.0)

OPEN Predicate: P3 
params  α: 1.0
'0'                                                         TRUE (1.0, 1.0)
'1'                                                        FALSE (0.0, 0.0)

OPEN Predicate: P2 
params  α: 1.0
'0'                                                         TRUE (1.0, 1.0)
'1'                                                         TRUE (1.0, 1.0)

OPEN Predicate: P1 
params  α: 1.0
'0'                                                        FALSE (0.0, 0.0)
'1'                                                        FALSE (0.0, 0.0)

***************************************************************************


In [137]:
# Print the current `model` without parameters.
# NOTE(review): this cell's execution count (137) precedes the training cell
# (174), and the saved output shows an `And: (P1(0) ∧ P2(0))` model that no
# visible cell builds — it reflects hidden kernel state. Re-run after the
# training cell (or delete) before sharing.
model.print()


***************************************************************************
                                LNN Model

OPEN And: (P1(0) ∧ P2(0)) 

OPEN Predicate: P2 

OPEN Predicate: P1 

***************************************************************************


In [38]:
# Probe how large |cart velocity| (obs[1]) and |pole angular velocity| (obs[3])
# get over many episodes that always take action 0; the printed maxima are
# what the bin limits further down are based on.
N_PROBE_EPISODES = 10_000

maxvel = 0
maxangvel = 0
for episode in range(N_PROBE_EPISODES):
    env.reset()
    while True:
        next_state, reward, terminated, truncated, _ = env.step(0)
        # The 5-tuple step API splits "done" into terminated / truncated; the
        # episode is over when either is set, so stop stepping in both cases.
        if terminated or truncated:
            break
        # NOTE(review): the observation returned on the final step is not
        # counted; include it if terminal states must fall inside the bins.
        maxvel = max(maxvel, abs(next_state[1]))
        maxangvel = max(maxangvel, abs(next_state[3]))
print(maxvel, maxangvel)

2.0012715 2.9858377


In [32]:
# Evenly spaced grids symmetric about 0 for each observation component:
# (half-range, number of points). Velocity, angular velocity and position use
# a 0.1 step; the angle grid spans ±0.2094395 rad.
velocities, angvelocities, pos, ang = (
    np.linspace(-limit, limit, num=count)
    for limit, count in ((2, 41), (3, 61), (1.2, 25), (0.2094395, 25))
)

In [103]:
# Symmetric [low, high] range per observation component. The velocity and
# angular-velocity bounds match the maxima printed by the probing cell
# (≈2.0 and ≈3.0); FOLCartpoleAgent.env2fol clamps values beyond them into
# the outermost bin.
LIMIT_DICT = {
    "Position": [-1.2, 1.2],
    "Velocity": [-2, 2],
    "Angle": [-0.2094395, 0.2094395],
    "AngVelocity": [-3, 3],
}

# Number of bins on each side of zero, per component.
BIN_ARGS = dict.fromkeys(LIMIT_DICT, 10)

# Width of one bin: the positive limit divided by that component's bin count.
BIN_SIZES = {
    name: LIMIT_DICT[name][1] / n_bins for name, n_bins in BIN_ARGS.items()
}

In [112]:
# Discretise one freshly reset observation.
# `env.reset()` returns a pair here; element 0 is presumably the observation
# (the (obs, info) form that goes with the 5-tuple step API) — confirm.
# FIXME: `env2fol` is not defined at module level — it only exists as the
# method FOLCartpoleAgent.env2fol (defined in a later cell), so this cell
# raises NameError on a fresh kernel. The saved output (a list of tuples)
# also predates that method, which returns a dict.
state = env.reset()
env2fol(state[0])

[('Position', True, 1),
 ('Velocity', True, 1),
 ('Angle', False, 2),
 ('AngVelocity', True, 1)]

In [192]:
class LNNCartpole():
    """LNN over binned CartPole observations: one OR of several ANDs.

    Every observation component gets a family of Predicates, one per signed
    bin, stored in the order ``[name1, name-1, name2, name-2, ...]``.
    ``num_nodes`` conjunctions are built over *all* predicates, combined by a
    single disjunction, and everything is added to ``self.model``.

    NOTE(review): every AND gets identical operands, so they can only diverge
    through learned operand weights — presumably intended as a learnable DNF;
    confirm.
    """

    def __init__(self, n_pos, n_vel, n_ang, n_angvel, num_nodes):
        """
        Args:
            n_pos, n_vel, n_ang, n_angvel: number of bins on each side of zero
                for position, velocity, angle and angular velocity.
            num_nodes: number of AND nodes feeding the OR node.
        """
        self.model = Model()
        # Each family gets its own fresh list of predicates.
        self.preds = {
            "Position": self._create_predicates(n_pos, "pos"),
            "Velocity": self._create_predicates(n_vel, "vel"),
            "Angle": self._create_predicates(n_ang, "ang"),
            "AngVelocity": self._create_predicates(n_angvel, "angvel"),
        }
        # Flatten the families so each AND receives Predicate objects, not the
        # dict's keys.
        all_preds = [pred for family in self.preds.values() for pred in family]
        self.and_nodes = [And(*all_preds) for _ in range(num_nodes)]
        self.or_node = Or(*self.and_nodes)
        self.model.add_knowledge(*self.and_nodes, self.or_node)

    @staticmethod
    def _create_predicates(n_nodes, name):
        """Return new Predicates named name1, name-1, ..., name<n>, name-<n>."""
        predicates = []
        for i in range(1, n_nodes + 1):
            predicates.append(Predicate(name + str(i)))
            predicates.append(Predicate(name + str(-i)))
        return predicates

    def generate_initial_state_dictionary(self, raw_bin_dict):
        """Build LNN input facts for one discretised observation.

        Args:
            raw_bin_dict: dict feature name -> (positive, value) with value in
                1..n_bins, i.e. the format FOLCartpoleAgent.env2fol returns.

        Returns:
            dict mapping every predicate to ``{0: Fact}``: TRUE for the
            predicate of the observed signed bin, FALSE for all others.
            NOTE(review): uses the single grounding 0 — confirm this matches
            how the model is queried.

        Raises:
            ValueError: if a bin value lies outside 1..n_bins.
        """
        facts = {}
        for key, preds in self.preds.items():
            positive, value = raw_bin_dict[key]
            n_bins = len(preds) // 2
            if not 1 <= value <= n_bins:
                raise ValueError(f"{key} bin {value} outside 1..{n_bins}")
            # Bin v's positive predicate sits at 2*(v-1); its negative twin
            # is the next slot, 2*v - 1.
            index = 2 * (value - 1) if positive else 2 * value - 1
            for i, pred in enumerate(preds):
                # A fresh dict per predicate: a shared dict would make every
                # entry alias the same (last-written) value.
                facts[pred] = {0: Fact.TRUE if i == index else Fact.FALSE}
        return facts
            
        

In [116]:
class FOLCartpoleAgent():
    """CartPole agent that discretises observations for two LNNs.

    Holds ``left_lnn`` / ``right_lnn`` (presumably one per action), the bin
    layout used by :meth:`env2fol`, and a bounded replay memory. The
    hyper-parameters (GAMMA, BATCH_SIZE, LR, ...) suggest replay-based
    Q-learning, but :meth:`optimize` is not implemented yet.
    """
    MAXLEN = 10_000
    MIN_REPLAY_SIZE = 1_000
    BATCH_SIZE = 64
    GAMMA = 0.9
    LR = 0.01

    def __init__(self, n_bin_args, n_nodes, limits):
        """
        Args:
            n_bin_args: dict feature name -> number of bins on each side of
                zero, keyed like ``Observation`` (e.g. ``BIN_ARGS``).
            n_nodes: number of AND nodes per LNN.
            limits: dict feature name -> [low, high] (e.g. ``LIMIT_DICT``);
                only ``high`` is used to size the bins.
        """
        # The LNN factory takes bin counts positionally in Observation order;
        # unpacking the dict itself would pass its keys (strings).
        bin_counts = [n_bin_args[name] for name in Observation._fields]
        # NOTE(review): `create_lnn` is not defined anywhere in this notebook;
        # LNNCartpole(n_pos, n_vel, n_ang, n_angvel, num_nodes) has the
        # matching signature and is presumably what is meant — confirm.
        self.left_lnn = create_lnn(*bin_counts, n_nodes)
        self.right_lnn = create_lnn(*bin_counts, n_nodes)
        self.limits = limits
        self.bin_args = n_bin_args
        # Width of one bin: the positive limit split into that feature's bins.
        self.bin_sizes = {
            key: self.limits[key][1] / self.bin_args[key] for key in self.bin_args
        }

        self.replay_memory = collections.deque([], maxlen=self.MAXLEN)

    def env2fol(self, obs):
        """Discretise a raw observation into signed magnitude bins.

        Args:
            obs: array of shape (4,) in ``Observation`` field order.

        Returns:
            dict feature name -> (positive, bin), where ``positive`` is
            ``val >= 0`` and ``bin`` is in 1..n_bins. With r = val / size,
            non-negative values land in bin floor(r) + 1 and negative values
            in bin |floor(r)|; both are clamped to n_bins.
        """
        assert obs.shape == (4,)
        obs = Observation(*obs)
        ret = {}
        for key in self.limits:
            val = getattr(obs, key)
            n_bins = self.bin_args[key]
            ratio = val / self.bin_sizes[key]
            positive = (val >= 0)
            if positive:
                # ceil(r) bumped by one on exact multiples == floor(r) + 1.
                val_bin = min(math.floor(ratio) + 1, n_bins)
            else:
                val_bin = max(math.floor(ratio), -n_bins)
            ret[key] = (positive, abs(val_bin))
        return ret

    def remember(self, obs):
        '''Append one transition to the replay memory.

        The deque is bounded by MAXLEN, so the oldest transition is dropped
        once it is full.

        obs: namedtuple given by (state, action, reward, next_state, done)
        '''
        self.replay_memory.append(obs)

    def optimize(self):
        """Training step; waits until MIN_REPLAY_SIZE transitions are stored.

        TODO: the update itself is not implemented yet — nothing happens
        after the warm-up check.
        """
        if len(self.replay_memory) < self.MIN_REPLAY_SIZE:
            return
        
        