In [1]:
from lnn import Model, Predicate, Variable, And, Or, Predicates, Fact, Loss, Propositions, Implies, World, Variables
import numpy as np	
import torch
import torch.nn as nn
import torch.nn.functional as F
import collections
import math

Observation = collections.namedtuple("Observation", ("Position", "Velocity", "Angle", "AngVelocity"))
Transition = collections.namedtuple("Transition", ("state", "action", "next_state", "reward", "done"))

class LNNCartpole():
    def __init__(self, num_nodes, n_pos, n_vel, n_ang, n_angvel, left):
        def create_predicates(n_nodes, name, var):
            predicate_list = []
            for i in range(n_nodes):
                predicate_list.append(Predicate(name + str(i+1))(var))
                predicate_list.append(Predicate(name + str(-(i+1)))(var))
            return predicate_list

        def create_n_ary_and(num_nodes, preds):
            and_list = []
            for _ in range(num_nodes):
                and_list.append(And(*preds["Position"], *preds["Velocity"], *preds["Angle"], *preds["AngVelocity"]))
            return and_list

        def create_n_ary_or(and_list):
            return Or(*and_list)

        self.model = Model()
        x = Variable('x')
        self.preds = {
            "Position": create_predicates(n_pos, "pos", x),
            "Velocity": create_predicates(n_vel, "vel", x),
            "Angle": create_predicates(n_ang, "ang", x),
            "AngVelocity": create_predicates(n_angvel, "angvel", x)
        }
        self.and_nodes = create_n_ary_and(num_nodes, self.preds)
        self.or_node = create_n_ary_or(self.and_nodes)

        self.model.add_knowledge(*self.and_nodes, self.or_node)
        
        self.left = left

    def generate_state_dictionary(self, processed_fol_arr):
        d = []
        for key in self.preds:
            value_array = []
            for i, fol in enumerate(processed_fol_arr):
                positive, value = fol[key]
                if self.left:
                    positive = not(positive)
                index = 2*(value-1) if positive else 2*value-1
                for j in range(len(self.preds[key])):
                    if len(value_array) <= j:
                        value_array.append({})
                    
                    if j == index:
                        value_array[j][str(i)] = Fact.TRUE
                    else:
                        value_array[j][str(i)] = Fact.FALSE
                
                predicate_array = np.array(self.preds[key], dtype = object)[:, 0]
                
                d.append(dict(zip(predicate_array, value_array)))
        res = {**d[0], **d[1], **d[2], **d[3]}
        return res

    def generate_label_dictionary(self, qval_arr, err=0.1):
        '''
            params:
                qval_arr: array of qvals for training
                err: float error radius on truth bounds

            returns:
                label_dict: dictionary {str(i): (qval-err, qval+err)} for each qval in qval_arr
        '''
        label_dict = {self.or_node: {str(i): (max(qval-err, 0.), min(qval+err, 1.)) for i, qval in enumerate(qval_arr)}}
        return label_dict

    def forward(self, processed_fol_arr):
        '''
            params:
                processed_fol_arr: array of fol observations used to generate state dict

            returns:
                output: bsz x 2 tensor of lower/upper bounds for each batch example
        '''
        self.model.flush()
        state_dict = self.generate_state_dictionary(processed_fol_arr)
        self.model.add_data(state_dict)
        self.model.infer()
        return self.or_node.get_data()

    def train_step(self, obs, labels, steps=1):
        '''
            params:
                obs: array of dictionaries corresponding to first order logic of input nodes
                labels: array of floats corresponding to the labels of observations

            returns:
                loss: loss over training	
        '''
        assert len(obs) == labels.shape(0)

        self.model.flush()

        state_dict = self.generate_state_dictionary(obs)
        self.model.add_data(state_dict)
        label_dict = self.generate_label_dictionary(labels)
        self.model.add_labels(label_dict)
        epochs, loss = self.model.train(losses=[Loss.SUPERVISED], epochs=steps)
        return loss

class FOLCartpoleAgent():
    MAXLEN = 10_000
    MIN_REPLAY_SIZE = 1_000
    BATCH_SIZE = 64
    GAMMA = 0.9
    LR = 0.01

    def __init__(self, n_bin_args, n_nodes, limits):
        self.left_lnn = LNNCartpole(n_nodes, **n_bin_args, left = True)
        self.right_lnn = LNNCartpole(n_nodes, **n_bin_args, left = False)
        self.limits = limits
        self.bin_args = n_bin_args
        self.bin_sizes = {}
        for (key1, key2) in zip(self.limits, self.bin_args):
            self.bin_sizes[key1] = self.limits[key1][1]/self.bin_args[key2]

        self.replay_memory = collections.deque([], maxlen = self.MAXLEN)

    def envs2fol(self, obs_arr):
        ret = []
        for obs in obs_arr:
            ret.append(self.env2fol(obs))
        return ret

    def env2fol(self, obs):
        assert obs.shape == (4,)
        obs = Observation(*obs)
        ret = {}
        for (key1, key2) in zip(self.limits, self.bin_args):
            val = getattr(obs, key1)
            positive = (val >= 0)
            if positive:
                val_bin = math.ceil(val/self.bin_sizes[key1])
                if val/self.bin_sizes[key1] - int(val/self.bin_sizes[key1]) == 0:
                    val_bin += 1
                val_bin = min(val_bin, self.bin_args[key2])
            else:
                val_bin = math.floor(val/self.bin_sizes[key1])
                val_bin = max(val_bin, -self.bin_args[key2])

            ret[key1] = (positive, abs(val_bin))
        return ret

    def remember(self, obs):
        '''
            obs: namedtuple given by (state, action, reward, next_state, done)
        '''
        self.replay_memory.append(obs)

    def optimize(self):
        if len(self.replay_memory) < self.MIN_REPLAY_SIZE:
            return

        transitions = [self.replay_memory[idx] for idx in np.random.permutation(len(self.replay_memory))[:self.MINIBATCH_SIZE]]
        batch = Transition(*zip(*transitions))

        #action_batch = torch.tensor(batch.action, device = self.device, dtype = torch.int64)
        reward_batch = torch.tensor(batch.reward, device = self.device)

        final_mask = torch.tensor([val == False for val in batch.done], device = self.device)		

        next_state_batch = self.envs2fol(np.array(batch.next_state)[final_mask])

        right_next_values = self.right_lnn.forward(next_state_batch).mean(dim=1)
        left_next_values = self.left_lnn.forward(next_state_batch).mean(dim=1)

        next_state_values = torch.zeros(self.MINIBATCH_SIZE, device = self.device)
        next_state_values[final_mask] = torch.cat((right_next_values, left_next_values), dim=1).max(dim=1)

        expected_next_state_values = next_state_values * self.GAMMA + reward_batch

        left_mask = torch.tensor([val == 0 for val in batch.action], device = self.device) #True is left, False is Right
        right_mask = left_mask == False

        state_batch_right = self.envs2fol(np.array(batch.state)[right_mask])
        state_batch_left = self.envs2fol(np.array(batch.state)[left_mask])

        loss_left = self.left_lnn.train_step(state_batch_left, expected_next_state_values[left_mask])
        loss_right = self.right_lnn.train_step(state_batch_right, expected_next_state_values[right_mask])

        return loss_left + loss_right


    def sample_random_action(self):
        '''
            0: left
            1: right
        '''
        return np.random.randint(2)

    def get_action(self, state):
        state_fol = [self.env2fol(state)]
        left_q = self.left_lnn.forward(state_fol).mean(dim=1)
        right_q = self.right_lnn.forward(state_fol).mean(dim=1)
        return torch.argmax(torch.cat((left_q, right_q), dim = 0))





In [10]:
model = Model()

p1, p2 = Predicates('p1', 'p2')
x = Variable('x')

a1 = And(p1(x), p2(x))

model.add_knowledge(a1)

model.add_data({
    p1: {
        '0': Fact.TRUE,
        '1': Fact.FALSE
    },
    p2: {
        '0': Fact.TRUE,
        '1': Fact.TRUE
    }
})

model.add_labels({
    a1: {
        '0': Fact.TRUE,
        '1': Fact.TRUE
    }
})

In [13]:
model.train(losses = [Loss.SUPERVISED])
model.print(params = True)


***************************************************************************
                                LNN Model

OPEN And: (p1(0) ∧ p2(0)) 
params  α: 1.0,  β: 1.0,  w: [0. 1.]
'0'                                                         TRUE (1.0, 1.0)
'1'                                                         TRUE (1.0, 1.0)

OPEN Predicate: p2 
params  α: 1.0
'0'                                                         TRUE (1.0, 1.0)
'1'                                                         TRUE (1.0, 1.0)

OPEN Predicate: p1 
params  α: 1.0
'0'                                                         TRUE (1.0, 1.0)
'1'                                                        FALSE (0.0, 0.0)

***************************************************************************


In [18]:
torch.optim.Adam(model.parameters())

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: False
    lr: 0.001
    maximize: False
    weight_decay: 0
)

In [166]:
import gym 
EPSILON_START = 1.0
EPSILON_END = 0.1
WARMUP_STEPS = 1_000
NUM_EPISODES = 10_000

n_bin_args = {
    "n_pos": 10,
    "n_vel": 5,
    "n_ang": 10,
    "n_angvel": 5
}

limits = {
    "Position": [-1.2, 1.2],
    "Velocity": [-2, 2],
    "Angle": [-0.2094395, 0.2094395],
    "AngVelocity": [-3, 3]
}


agent = FOLCartpoleAgent(n_bin_args, n_nodes = 3, limits = limits)
env = gym.make("CartPole-v1")

In [169]:
state, info = env.reset()
action = agent.get_action(state)
print(action)

tensor(0)


In [7]:
def train():
    epsilon = EPSILON_START
    episode_num = 0
    episode_runs = []
    episode_losses = []
    for episode in tqdm.tqdm(range(NUM_EPISODES)):
        total_loss, step = 0,0
        state, info = env.reset()
        while True:
            if np.random.random() > epsilon:
                action = agent.sample_random_action()
            else:
                action = agent.get_action(state)
            next_state, reward, terminal, truncated, info = env.step(action)
            reward = reward if not terminal else 0
            agent.remember(Transition(state, action, next_state, reward, terminal))
            #loss = agent.optimize()
            loss = None
            state = next_state

            if loss is not None:
                total_loss += loss
                step += 1
            
            if terminal or truncated:
                if step > 0:
                    print("Run: " + str(episode) + ", score: " + str(step) + ", episode_loss: " + str(total_loss/step))
                    episode_runs.append(step)
                    episode_losses.append(total_loss/step)
                    epsilon = epsilon * epsilon_decay
                    epsilon = min(epsilon, epsilon_end)
                break

In [170]:
k = agent.replay_memory

transitions = [k[idx] for idx in np.random.permutation(len(k))[:10]]
batch = Transition(*zip(*transitions))
for obs in batch.done:
    print(obs)

True
True
False
False
False
False
False
False
False
False


In [178]:
mask = torch.tensor([val == False for val in batch.done])

In [132]:
train()

 35%|██████████████████████████████████▎                                                               | 3495/10000 [00:02<00:05, 1219.76it/s]


KeyboardInterrupt: 

In [46]:
def train():
		epsilon = EPS_START
		episode_num = 0
		step = 0
		for ep in tqdm.tqdm(range(NUM_EPISODES)):
				total_loss, step = 0, 0
				for query in database.queries:
						env.reset(query)
						while True:
								if env.current_length == MAX_REGEX_LENGTH-1:
									action = vocab.word_dict["$"]
								else:
											if np.random.random() > epsilon:
													action = agent.get_action(query, env.current_state, env.current_length)
											else:
													action = env.sample_random_action()

								curr_state = copy.deepcopy(env.current_state)

								next_state, reward, length, done = env.step(action)
								if next_state is not None:
										next_state = torch.tensor(next_state)
										# if length > 1:
										#     temp_reward = env.hypothesis(curr_state, length)
										#     agent.update_replay_memory(Transition(
										#         torch.tensor(query), torch.tensor(curr_state), torch.tensor(length+1), torch.tensor(vocab.word_dict["$"]),
										#         torch.tensor([-1]*10), torch.tensor(temp_reward), torch.tensor(True)))
								else:
										next_state = torch.tensor([-1]*10)
								agent.update_replay_memory(Transition(
										torch.tensor(query), torch.tensor(curr_state), torch.tensor(length), torch.tensor(action),
										next_state, torch.tensor(reward), torch.tensor(done)))
								loss = agent.optimize()
								if loss is not None:
										total_loss += loss
										step += 1
								# if step > 0 and step % 100 == 0:
								# 		print("step {:d} | loss {:.3f} | lr {:.5f}".format(
								# 				step, total_loss/step, agent.optimizer.param_groups[0]['lr']))
								# 		losses.append(total_loss)
								# 		steps.append(step)
								loss = agent.optimize(only_success=True)
								if loss is not None:
										total_loss += loss
										step += 1
								if done:
										break

						episode_num += 1
						if len(agent.replay_memory) > agent.MIN_REPLAY_MEMORY_SIZE and episode_num % agent.update_target_every == 0:
								print("model updated")
								agent.target_model.load_state_dict(agent.policy_model.state_dict())
				if step > 0:
					print("ep {:d} | loss {:.3f}".format(ep, total_loss/step))
				# Decay epsilon
				if len(agent.replay_memory) > agent.MIN_REPLAY_MEMORY_SIZE and epsilon > EPS_END:
						epsilon *= EPS_DECAY
						epsilon = max(EPS_END, epsilon)

				if len(agent.replay_memory) > agent.MIN_REPLAY_MEMORY_SIZE:
						if ep % 10 == 0:
								prec, rec, error= test(verbose = True)
						else:
								prec, rec, error= test(verbose = False)
    #torch.save(agent.policy_model.state_dict(), "model")

<__main__.FOLCartpoleAgent at 0x2bbbd923970>