In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import gym
import numpy as np
import collections
import math
import tqdm

from CartpoleAgent import FOLCartpoleAgent

# Record types: a raw CartPole observation, and one replay-memory entry.
Observation = collections.namedtuple(
    "Observation",
    ["Position", "Velocity", "Angle", "AngVelocity"],
)
Transition = collections.namedtuple(
    "Transition",
    ["state", "action", "next_state", "reward", "done"],
)

# Exploration schedule and run length used by train().
EPSILON_START = 1.0
EPSILON_END = 0.1
WARMUP_EPISODES = 30
NUM_EPISODES = 50

# Per-variable settings handed to the agent.
# presumably the number of discretisation bins per state variable — confirm
# against FOLCartpoleAgent.
n_bin_args = dict(n_pos=2, n_vel=2, n_ang=2, n_angvel=2)

# [low, high] pair per observation field; 0.2094395 rad is 12 degrees.
limits = dict(
    Position=[-1.2, 1.2],
    Velocity=[-2, 2],
    Angle=[-0.2094395, 0.2094395],
    AngVelocity=[-3, 3],
)

# Shared, module-level agent and environment used by the cells below.
agent = FOLCartpoleAgent(n_bin_args, n_nodes=5, limits=limits, t="double")
env = gym.make("CartPole-v1")

def train():
    """Train the module-level ``agent`` on ``env`` with epsilon-greedy exploration.

    Runs ``NUM_EPISODES`` episodes. On every environment step the transition is
    stored via ``agent.remember`` and ``agent.optimize()`` is called once; steps
    for which ``optimize`` returned a loss are counted and their losses summed.

    Epsilon starts at ``EPSILON_START`` and is lowered linearly by
    ``(EPSILON_START - EPSILON_END) / WARMUP_EPISODES`` after each episode that
    contained at least one optimised step, never going below ``EPSILON_END``.

    Returns:
        tuple[list, list]: ``(episode_runs, episode_losses)``. For every episode
        with at least one optimised step these hold the number of optimised
        steps (the printed "score") and the mean loss over those steps.
    """
    epsilon = EPSILON_START
    episode_runs = []
    episode_losses = []
    for episode in tqdm.tqdm(range(NUM_EPISODES)):
        total_loss, step = 0, 0
        state, info = env.reset()
        while True:
            # Epsilon-greedy action selection.
            if np.random.random() < epsilon:
                action = agent.sample_random_action()
            else:
                action = agent.get_action(state)
            next_state, reward, terminal, truncated, info = env.step(action)
            # Scale the reward down; a terminal step earns nothing.
            reward = reward / 10 if not terminal else 0
            agent.remember(Transition(state, action, next_state, reward, terminal))
            loss = agent.optimize()
            state = next_state

            # optimize() may return None; only steps with a loss are counted.
            if loss is not None:
                total_loss += loss
                step += 1

            if terminal or truncated:
                if step > 0:
                    print("Run: " + str(episode) + ", score: " + str(step) + ", episode_loss: " + str(total_loss/step))
                    episode_runs.append(step)
                    episode_losses.append(total_loss/step)
                    # Linear decay, clamped from below. The previous min() forced
                    # epsilon straight to EPSILON_END after the first decay and
                    # then let it keep falling past the floor.
                    epsilon -= (EPSILON_START - EPSILON_END)/WARMUP_EPISODES
                    epsilon = max(epsilon, EPSILON_END)
                break
    return episode_runs, episode_losses

In [None]:
# Run the epsilon-greedy training loop defined above on the module-level agent/env.
train()

In [None]:
import itertools

In [None]:
# Flatten the per-node parameter collections into one lazy iterator; as the last
# expression of the cell, the resulting chain object is what gets displayed.
# NOTE(review): this reads `agent.lnn`, while the distillation cells use
# `agent.left_lnn` / `agent.right_lnn` — confirm the agent exposes `lnn`.
itertools.chain.from_iterable([n.parameters() for n in agent.lnn.model.nodes.values()])

In [None]:
# Print whatever parameters() returns for every node of the agent's LNN model.
# NOTE(review): uses `agent.lnn`; other cells use `agent.left_lnn` /
# `agent.right_lnn` — confirm which attribute the agent actually has.
for val in agent.lnn.model.nodes.values():
    print(val.parameters())

In [None]:
# import random
def _extract_rules(named_params):
    """Read the conjunction rules and disjunction weights out of one LNN.

    Args:
        named_params: mapping from parameter name to weight container, as
            returned by ``model.named_parameters()``. Only keys containing
            ``'.weights'`` are considered.

    Returns:
        tuple: ``(rules, or_weights)``. ``rules`` has one string per ``∧``
        weight entry, joining the inputs whose weight exceeds
        ``1 / number_of_weights`` (an empty string when none does).
        ``or_weights`` is the value of the last ``∨`` weight entry seen, or
        ``[]`` if there was none.
    """
    rules = []
    or_weights = []
    suffix = ').weights'
    for key in named_params:
        if '.weights' not in key:
            continue
        # '∨' is tested first on purpose: a key containing '∨' is treated as
        # the disjunction even if it also contains '∧'.
        if '∨' in key:
            or_weights = named_params[key]
        elif '∧' in key:
            weights = named_params[key]
            # An input counts as used when its weight is above the uniform share.
            is_used = weights > (1. / len(weights))
            # Keys are expected to look like '(<a> ∧ <b> ...).weights': strip
            # the leading '(' and the trailing ').weights', then split.
            input_names = key[1:-len(suffix)].split(' ∧ ')
            # The last three characters of each input name are dropped —
            # presumably a variable suffix such as '(x)'; TODO confirm.
            kept = [name[:-3] for i, name in enumerate(input_names) if is_used[i]]
            rules.append(' ∧ '.join(kept))
    return rules, or_weights


def distill_policy(self):
    """Distil readable rules from the agent's left and right LNNs.

    Args:
        self: agent-like object exposing ``left_lnn.model`` and
            ``right_lnn.model``, each with a ``named_parameters()`` method.

    Returns:
        tuple: ``(lpolicy, lor_weights, rpolicy, ror_weights)`` — the rule
        strings and disjunction weights for the left LNN, followed by the same
        pair for the right LNN.
    """
    lpolicy, lor_weights = _extract_rules(self.left_lnn.model.named_parameters())
    rpolicy, ror_weights = _extract_rules(self.right_lnn.model.named_parameters())
    return lpolicy, lor_weights, rpolicy, ror_weights

In [None]:
# Distil both LNNs, then print each side's rules followed by its OR weights
# (left side first, then right).
lpolicy, lor_weights, rpolicy, ror_weights = distill_policy(agent)
for rules, disjunction_weights in ((lpolicy, lor_weights), (rpolicy, ror_weights)):
    for rule in rules:
        print(rule)
    print(disjunction_weights)

In [None]:
# Inspect which inputs the left LNN's neurons actually use: a weight counts as
# "used" when it exceeds the uniform share 1/len(weights).
parameters = agent.left_lnn.model.parameters_grouped_by_neuron()
# Not used below; kept available for interactive inspection.
named_params = agent.left_lnn.model.named_parameters()
used = []
used_ands = torch.Tensor()
for x in parameters:
    # presumably index 2 of each group's 'params' holds the weights — TODO confirm
    if x['neuron_type'] == 'And':
        weights = x['params'][2]
        used.append(weights > (1./len(weights)))
    elif x['neuron_type'] == 'Or':
        # Read the weights from the group that matched. The earlier
        # parameters[-1] lookup was only right when the Or neuron was last.
        or_weights = x['params'][2]
        used_ands = or_weights > (1./len(or_weights))
# One mask per And neuron, then the mask over the Or neuron's inputs.
for u in used:
    print(u)
print(used_ands)

In [None]:
# Dump the left LNN model via its own print method, with params=True.
agent.left_lnn.model.print(params = True)

In [None]:
# Display the model's parameters() result as the cell output.
# NOTE(review): uses `agent.lnn`; the distillation cells use `agent.left_lnn` /
# `agent.right_lnn` — confirm the agent exposes `lnn`.
agent.lnn.model.parameters()

In [None]:
# Smoke test: reset the environment and ask the agent for one action.
# env.reset() is unpacked as a 2-tuple; the second element is discarded.
state, _ = env.reset()
action = agent.get_action(state)

In [None]:
def distill2(self):
    '''
    Placeholder for an alternative distillation routine; not implemented yet.

    NOTE(review): a later cell also declares a function named distill2 with a
    different signature; whichever cell runs last rebinds the name.

    returns:
        None -- the body consists of this docstring only.
    '''

In [None]:
def distill3(self):
    """Placeholder for a further distillation variant; not implemented yet.

    The definition previously had no body, which made the cell a SyntaxError.

    Raises:
        NotImplementedError: always, until an implementation is written.
    """
    raise NotImplementedError("distill3 is not implemented yet")

In [None]:
# `distill` is not defined or imported anywhere in this notebook, so the call
# raised NameError; use the distill_policy function defined above instead.
# NOTE(review): confirm distill_policy is what was intended here.
# policy is the 4-tuple (lpolicy, lor_weights, rpolicy, ror_weights).
policy = distill_policy(agent)
print(policy)

In [None]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import random
import collections
import tqdm

from CartPole import DQNSolver

# Replay-memory record for the DQN solver.
# NOTE: this rebinds the name `Observation`, which an earlier cell defined with
# different fields (Position, Velocity, Angle, AngVelocity).
Observation = collections.namedtuple('Observation', ('state', 'action', 'reward', 'next_state', 'done'))

# Environment and solver, sized from the environment's spaces.
env = gym.make("CartPole-v1")
observation_space = env.observation_space.shape[0]
action_space = env.action_space.n
dqn_solver = DQNSolver(observation_space, action_space)

# Exploration schedule: multiplicative decay from epsilon_start down to epsilon_end.
epsilon_start = 1
epsilon_end = 0.1
epsilon_decay = 0.97
epsilon = epsilon_start

max_episodes = 150
episode = 0

# Per-episode history: number of optimised steps and mean loss over them.
episode_runs = []
episode_losses = []
while episode < max_episodes:
    episode += 1
    # env.reset() returns a tuple; its first element is the observation.
    state = env.reset()
    state = torch.tensor(np.reshape(state[0], [1, observation_space]))
    step = 0
    episode_loss = 0
    while True:
        action = dqn_solver.choose_action(state, epsilon)
        state_next, reward, terminal, truncated, info = env.step(action)
        # A terminal step earns no reward.
        reward = reward if not terminal else 0
        state_next = torch.tensor(np.reshape(state_next, [1, observation_space]))
        dqn_solver.remember(Observation(state, action, reward, state_next, terminal))
        loss = dqn_solver.replay()
        # Always advance the state. This used to happen only when replay()
        # returned a loss, so until then the solver kept acting on, and
        # storing, the stale initial state.
        state = state_next

        # replay() may return None; only steps with a loss are counted.
        if loss is not None:
            episode_loss += loss
            step += 1

        if terminal or truncated:
            if step > 0:
                print("Run: " + str(episode) + ", score: " + str(step) + ", episode_loss: " + str(episode_loss/step))
                episode_runs.append(step)
                episode_losses.append(episode_loss/step)
                # Decay, clamped from below. The previous min() dropped epsilon
                # to epsilon_end after the first decay and let it sink further.
                epsilon = max(epsilon * epsilon_decay, epsilon_end)
            break

In [None]:
# Distil the trained DQN solver with n_bins=[6,6,6,6]
# (presumably one bin count per observation dimension — TODO confirm against DQNSolver.distill).
ret = dqn_solver.distill(n_bins=[6,6,6,6])

In [None]:
# Show the element at index 5 of the distillation result; what that element
# represents is defined by DQNSolver.distill, which is not visible here.
print(ret[5])

In [None]:
def distill2(self, n_bins=[4,4,4,4], lims=[[-1.2, 1.2], [-2, 2], [-0.2094395, 0.2094395], [-3, 3]]):
    """Placeholder for a binned distillation routine; not implemented yet.

    The original body was an unfinished ``self.`` expression followed by
    ``return ret`` with ``ret`` never assigned, so the cell failed with a
    SyntaxError. The signature is kept unchanged.

    Args:
        self: object to distil from (what it must provide is not yet defined).
        n_bins: four integers, defaulting to ``[4, 4, 4, 4]``.
        lims: four ``[low, high]`` pairs.

    NOTE: both defaults are mutable lists shared across calls; an
    implementation must not modify them in place.

    Raises:
        NotImplementedError: always, until an implementation is written.
    """
    raise NotImplementedError("distill2 is not implemented yet")