In [102]:
from lnn import Model, Predicate, Variable, And, Or, Predicates, Fact, Loss, Propositions, Implies, World, Variables
import torch.nn as nn
import gym
import numpy as np
import torch

In [2]:
# Instantiate the CartPole-v1 environment; later cells probe it via `env`.
env = gym.make("CartPole-v1")

In [87]:
# Small propositional example: ((p1 ∧ p2) ∧ (p1 ∨ p2)) → p3.
model = Model()

# Three propositions; the connectives below are built bottom-up from them.
p1, p2, p3 = Propositions("p1", "p2", "p3")
a1 = And(p1, p2)
o1 = Or(p1, p2)
a2 = And(a1, o1)
# The implication is the only formula declared with World.AXIOM.
i1 = Implies(a2, p3, world= World.AXIOM)

In [89]:
# Bare expression so the notebook displays the model's parameter tensors.
model.parameters()

[tensor(1.),
 tensor(1.),
 tensor([1., 1.], requires_grad=True),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor([1., 1.], requires_grad=True),
 tensor(1.),
 tensor(1.),
 tensor([1., 1.], requires_grad=True),
 tensor(1.),
 tensor(1., requires_grad=True),
 tensor([1., 1.], requires_grad=True),
 tensor(1.)]

In [92]:
# Register the connectives and the implication with the model.
model.add_knowledge(a1, a2, o1, i1)
# Input facts: both antecedent propositions are TRUE.
model.add_data({
    p1: Fact.TRUE,
    p2: Fact.TRUE
})


# Training label for the consequent: p3 is labelled FALSE.
model.add_labels({
    p3: Fact.FALSE
})

# Train with the supervised loss only, then print every node with its parameters.
model.train(losses = [Loss.SUPERVISED])
model.print(params = True)


***************************************************************************
                                LNN Model

AXIOM Implies: (((p1 ∧ p2) ∧ (p1 ∨ p2)) → p3)               TRUE (1.0, 1.0)
params  α: 1.0,  β: 1.0,  w: [1. 1.]
OPEN Proposition: p3                                        TRUE (1.0, 1.0)
params  α: 1.0
OPEN And: ((p1 ∧ p2) ∧ (p1 ∨ p2))                           TRUE (1.0, 1.0)
params  α: 1.0,  β: 1.0,  w: [1. 1.]
OPEN Or: (p1 ∨ p2)                                          TRUE (1.0, 1.0)
params  α: 1.0,  β: 1.0,  w: [1. 1.]
OPEN And: (p1 ∧ p2)                                         TRUE (1.0, 1.0)
params  α: 1.0,  β: 1.0,  w: [1. 1.]
OPEN Proposition: p2                                        TRUE (1.0, 1.0)
params  α: 1.0
OPEN Proposition: p1                                        TRUE (1.0, 1.0)
params  α: 1.0
***************************************************************************


In [145]:
# construct the model from formulae
# First-order example: P1(X) ∧ P2(X) over the two groundings '0' and '1'.
model = Model()
p1, p2 = Predicates("P1", "P2")
x = Variable("X")

# Both predicates share the variable X, so the conjunction is evaluated per grounding.
a1 = And(p1(x), p2(x))
model.add_knowledge(a1)

# Data: P1 is FALSE and P2 is TRUE for both groundings.
model.add_data({
    p1: {
        '0': Fact.FALSE,
        '1': Fact.FALSE
    },
    p2: {
        '0': Fact.TRUE,
        '1': Fact.TRUE
    }
})
# Labels: the conjunction is labelled TRUE for both groundings even though P1 is FALSE.
model.add_labels({
    a1: {
        '0': Fact.TRUE,
        '1': Fact.TRUE
    }
})

# Train with the supervised loss only, then print every node with its parameters.
model.train(losses = [Loss.SUPERVISED])
model.print(params=True)





***************************************************************************
                                LNN Model

OPEN And: (P1(0) ∧ P2(0)) 
params  α: 1.0,  β: 1.0,  w: [0. 1.]
'0'                                                         TRUE (1.0, 1.0)
'1'                                                         TRUE (1.0, 1.0)

OPEN Predicate: P2 
params  α: 1.0
'0'                                                         TRUE (1.0, 1.0)
'1'                                                         TRUE (1.0, 1.0)

OPEN Predicate: P1 
params  α: 1.0
'0'                                                        FALSE (0.0, 0.0)
'1'                                                        FALSE (0.0, 0.0)

***************************************************************************


In [137]:
# Print the model's nodes without their parameters.
model.print()


***************************************************************************
                                LNN Model

OPEN And: (P1(0) ∧ P2(0)) 

OPEN Predicate: P2 

OPEN Predicate: P1 

***************************************************************************


In [3]:
# Bare expression so the notebook displays the observation space (bounds, shape, dtype).
env.observation_space

Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)

In [178]:
# Empirically estimate how large the two velocity components (observation
# indices 1 and 3) get when action 0 is applied on every step of 100 episodes.
# The observation returned together with done=True is not counted.
maxvel = 0
maxangvel = 0
for _episode in range(100):
    env.reset()
    done = False
    while not done:
        next_state, _, done, _ = env.step(0)
        if not done:
            # max() keeps the current value on ties, i.e. only a strictly
            # larger magnitude replaces the running maximum.
            maxvel = max(maxvel, abs(next_state[1]))
            maxangvel = max(maxangvel, abs(next_state[3]))
print(maxvel, maxangvel)

1.9946576 2.9537146


In [6]:
# Evenly spaced grid points used to discretise each observation component.
# The velocity ranges presumably come from the empirical maxima printed by the
# rollout cell (about 1.99 and 2.95) -- TODO confirm.
velocities = np.linspace(-2, 2, 41)  # 41 points from -2 to 2
angvelocities = np.linspace(-3, 3, 61)  # 61 points from -3 to 3
pos = np.linspace(-1.2, 1.2, 25)  # 25 points from -1.2 to 1.2
ang = np.linspace(-0.2094395, 0.2094395, 25)  # 25 points; bound presumably in radians -- TODO confirm

In [7]:
# Fresh model plus one (initially empty) predicate list per observation component.
model = Model()
velocity_variables, angvelocity_variables = [], []
pos_variables, ang_variables = [], []

In [8]:
# Create one predicate per grid point of each observation component, named
# "<prefix>_1" ... "<prefix>_N".
# The counts are taken from the grid arrays rather than repeated as literals,
# so they cannot drift apart if a grid is resized. The lists are rebuilt
# (not appended to), so re-running this cell does not accumulate duplicates.
# Note: this cell only creates the predicates; it does not add them to `model`.
velocity_variables = [
    Predicate("vel_" + str(i + 1)) for i in range(len(velocities))
]
angvelocity_variables = [
    Predicate("angvel_" + str(i + 1)) for i in range(len(angvelocities))
]
pos_variables = [
    Predicate("pos_" + str(i + 1)) for i in range(len(pos))
]
ang_variables = [
    Predicate("ang_" + str(i + 1)) for i in range(len(ang))
]

In [189]:
def create_predicates(num_bins):
    """Return a list of `num_bins` new predicates named "p_1" ... "p_<num_bins>"."""
    return [Predicate("p_" + str(idx)) for idx in range(1, num_bins + 1)]

def create_n_ary_and(num_nodes, predicate_list):
    """Return `num_nodes` separate And nodes, each taking every predicate as an operand.

    Each node is built from the same `predicate_list`, so the nodes start
    out structurally identical.
    """
    return [And(*predicate_list) for _ in range(num_nodes)]

def create_n_ary_or(and_list):
    """Return a single Or node whose operands are all the nodes in `and_list`."""
    disjunction = Or(*and_list)
    return disjunction

def create_lnn(num_bins, num_nodes):
    """Build a model holding an Or over `num_nodes` Ands of `num_bins` predicates.

    Args:
        num_bins: number of predicates ("p_1" ... "p_<num_bins>") to create.
        num_nodes: number of And nodes, each taking all predicates as operands.

    Returns:
        A new Model with every And node and the top-level Or node added as knowledge.
    """
    lnn_model = Model()
    conjunctions = create_n_ary_and(num_nodes, create_predicates(num_bins))
    disjunction = create_n_ary_or(conjunctions)

    lnn_model.add_knowledge(*conjunctions, disjunction)
    return lnn_model

In [197]:
# Start a new episode and keep what reset() returns as `state`.
state = env.reset()

In [202]:
# Check the shape of the observation once it is wrapped in a tensor.
torch.tensor(state).shape

torch.Size([4])

In [None]:
def env_2_fol(obs):
    """Check that `obs` is a single CartPole observation.

    The original cell was left unfinished (`assert obs.shape = ` is a
    SyntaxError). The shape check is completed here; the conversion of the
    observation into first-order-logic facts is still to be written.

    Args:
        obs: array-like with a `.shape` attribute (e.g. a numpy array or a
            torch tensor).

    Returns:
        None for now; only the validation is implemented.

    Raises:
        AssertionError: if `obs` does not have shape (4,).
    """
    # The environment's observation space and the tensor built from a reset
    # state both have shape (4,). tuple() makes the comparison independent of
    # whether .shape is a plain tuple or a torch.Size.
    assert tuple(obs.shape) == (4,), (
        f"expected an observation of shape (4,), got {tuple(obs.shape)}"
    )
    # TODO: translate the validated observation into facts for the LNN model.

In [190]:
# Build a model with 10 predicates feeding 3 And nodes under one Or node.
left = create_lnn(10, 3)

In [192]:
# Print the structure of the freshly built model (no parameters shown).
left.print()


***************************************************************************
                                LNN Model

OPEN Or: ((p_1(0) ∧ p_2(1) ∧ p_3(2) ∧ p_4(3) ∧ p_5(4) ∧ p_6(5) ∧ p_7(6) ∧ p_8(7) ∧ p_9(8) ∧ p_10(9)) ∨ (p_1(0) ∧ p_2(1) ∧ p_3(2) ∧ p_4(3) ∧ p_5(4) ∧ p_6(5) ∧ p_7(6) ∧ p_8(7) ∧ p_9(8) ∧ p_10(9)) ∨ (p_1(0) ∧ p_2(1) ∧ p_3(2) ∧ p_4(3) ∧ p_5(4) ∧ p_6(5) ∧ p_7(6) ∧ p_8(7) ∧ p_9(8) ∧ p_10(9))) 

OPEN And: (p_1(0) ∧ p_2(1) ∧ p_3(2) ∧ p_4(3) ∧ p_5(4) ∧ p_6(5) ∧ p_7(6) ∧ p_8(7) ∧ p_9(8) ∧ p_10(9)) 

OPEN And: (p_1(0) ∧ p_2(1) ∧ p_3(2) ∧ p_4(3) ∧ p_5(4) ∧ p_6(5) ∧ p_7(6) ∧ p_8(7) ∧ p_9(8) ∧ p_10(9)) 

OPEN And: (p_1(0) ∧ p_2(1) ∧ p_3(2) ∧ p_4(3) ∧ p_5(4) ∧ p_6(5) ∧ p_7(6) ∧ p_8(7) ∧ p_9(8) ∧ p_10(9)) 

OPEN Predicate: p_10 

OPEN Predicate: p_9 

OPEN Predicate: p_8 

OPEN Predicate: p_7 

OPEN Predicate: p_6 

OPEN Predicate: p_5 

OPEN Predicate: p_4 

OPEN Predicate: p_3 

OPEN Predicate: p_2 

OPEN Predicate: p_1 

************************************************************

In [180]:
# Scratch list of 10 predicates for trying out the n-ary connectives.
test = create_predicates(10)

In [186]:
# Unpack the list so every predicate becomes an operand of one And node;
# the bare expression makes the notebook display the node's repr.
And(*test)

<lnn.symbolic.logic.n_ary_neuron.And at 0x1836dd306d0>