In [None]:
from htm_rl.agent.agent import Agent, AgentRunner
from htm_rl.agent.memory import Memory, TemporalMemory
from htm_rl.agent.planner import Planner
from htm_rl.common.s_sdr_encoder import StateSDREncoder
from htm_rl.common.sa_sdr_encoder import SaSdrEncoder, format_sa_superposition
from htm_rl.common.base_sa import SaRelatedComposition, Sa, SaSuperposition
from htm_rl.common.int_sdr_encoder import IntSdrEncoder, IntRangeEncoder
from htm_rl.envs.gridworld_pomdp import GridWorld

from tqdm import tqdm
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from random import randint

In [None]:
def init_memory(pars, sa_encoder, start_indicator=None, output_file=None):
    """Build a ``Memory`` backed by a freshly constructed ``TemporalMemory``.

    Args:
        pars: keyword arguments forwarded verbatim to ``TemporalMemory``.
        sa_encoder: state-action encoder; passed to ``Memory`` together with
            its bound ``format`` method.
        start_indicator: forwarded to ``Memory`` as ``start_indicator``.
        output_file: forwarded to ``Memory`` as ``output_file``.

    Returns:
        The newly constructed ``Memory`` instance.
    """
    return Memory(
        TemporalMemory(**pars),
        sa_encoder,
        sa_encoder.format,
        format_sa_superposition,
        start_indicator=start_indicator,
        output_file=output_file,
    )

In [None]:
# Grid layout, presumably one inner list per grid row. The cell codes are
# interpreted by GridWorld — presumably 0 = free, 1 = wall, 2 = goal; TODO confirm.
world_description = [
    [2, 0, 0],
    [1, 1, 0],
    [0, 0, 0],
]

In [None]:
# Environment built from `world_description`. The agent starts at row 2,
# column 0, and only the 'window' variable is observable.
# NOTE(review): (3, 3) duplicates the size of `world_description` and must be
# kept in sync with it by hand. The window offsets are relative coordinates
# whose axis convention is defined by GridWorld — TODO confirm.
gw = GridWorld(
    world_description,
    (3, 3),
    agent_initial_position={'row': 2, 'column': 0},
    observable_vars=['window'],
    window_coords={
        'top_left': (1, -1),
        'bottom_right': (0, 1),
    },
)

In [None]:
# Visualise the environment (output format is defined by GridWorld.render).
gw.render()

In [None]:
# Inspect what the agent can see: the environment's observable state and its
# filtered observation (presumably the 'window' view configured on `gw`).
gw.observable_state, gw.filtered_observation

In [None]:
# Encoder for the observed state. The positional args ('state', 4, (2, 3), 1)
# are opaque here. (2, 3) presumably matches the 2x3 observation window set on
# `gw`, and 4 / 1 presumably give the number of cell values / bits per value
# — TODO confirm against StateSDREncoder's signature.
state_encoder = StateSDREncoder('state', 4, (2, 3), 1)

# Number of SDR bits given to each action value; the action encoder uses it
# for both value_bits and activation_threshold.
action_bits = 5

In [None]:
# Integer encoder for action ids (presumably 0 .. gw.n_actions - 1).
# activation_threshold equals value_bits, so presumably every bit of an
# action's SDR must be active for it to count as present — TODO confirm
# IntSdrEncoder semantics.
action_encoder = IntSdrEncoder(
    'action',
    gw.n_actions,
    value_bits=action_bits,
    activation_threshold=action_bits,
)

In [None]:
# Combined state-action encoder built from the two encoders above; its sizes
# drive the TemporalMemory column count and thresholds in `pars`.
sa_encoder = SaSdrEncoder(state_encoder, action_encoder)

In [None]:
# Sanity check of the combined encoding sizes that feed the TemporalMemory parameters.
sa_encoder.total_bits, sa_encoder.value_bits, sa_encoder.activation_threshold

In [None]:
# TemporalMemory hyper-parameters, forwarded verbatim by `init_memory`.
# Every count/threshold is tied to sa_encoder.value_bits, with one cell per
# column. The mixed snake_case / camelCase keys are kept exactly as written.
# NOTE(review): the TM activation_threshold uses sa_encoder.value_bits rather
# than sa_encoder.activation_threshold (inspected in the previous cell) —
# confirm that is intended.
pars = {
    'n_columns': sa_encoder.total_bits,
    'cells_per_column': 1,
    'activation_threshold': sa_encoder.value_bits,
    'learning_threshold': sa_encoder.value_bits,
    'initial_permanence': 0.6,
    'connected_permanence': 0.5,
    'maxNewSynapseCount': sa_encoder.value_bits,
    'maxSynapsesPerSegment': sa_encoder.value_bits,
    'permanenceIncrement': 0.1,
    'permanenceDecrement': 0.025,
    'predictedSegmentDecrement': 0.005,
}

In [None]:
# Memory over state-action SDRs using the TM parameters in `pars`;
# start_indicator and output_file are left at their None defaults.
memory = init_memory(pars, sa_encoder)

In [None]:
# Inspect the state encoder's activation threshold.
state_encoder.activation_threshold

In [None]:
# Inspect the number of bits the state encoder uses per value.
state_encoder.value_bits

In [None]:
# Planner over the learned memory.
# NOTE(review): magic positional args. 14 is presumably the planning horizon
# (the same value is later passed to run.agent.set_planning_horizon); the
# meaning of 1 and 0.2 is not visible here — TODO confirm against Planner's
# signature.
planner = Planner(memory, 14, 1, 0.2, state_encoder)

In [None]:
# Agent combining the memory and the planner; gw.n_actions presumably sizes
# its action space.
agent = Agent(memory, planner, gw.n_actions)

In [None]:
# Runner settings, named instead of passed as bare positional numbers.
# The positional order (n_episodes, max_steps, pretrain, verbosity) is inferred
# from the attributes a later cell reassigns on `run` (n_episodes, pretrain,
# verbosity) — TODO confirm against AgentRunner's signature.
max_steps = 25            # presumably the cap on steps per episode
n_train_episodes = 300    # presumably AgentRunner's n_episodes
n_pretrain = 100          # presumably AgentRunner's pretrain
train_verbosity = 0       # presumably AgentRunner's verbosity (0 = quiet)
run = AgentRunner(agent, gw, n_train_episodes, max_steps, n_pretrain, train_verbosity)

In [None]:
# Execute the runner with the settings passed to AgentRunner above.
run.run()

In [None]:
# Inspect the union of the first goal stored in the planner's episode goal memory.
# NOTE(review): presumably raises IndexError if no goals were recorded — confirm
# that `goals` is a sequence that is non-empty after training.
run.agent.planner.episode_goal_memory.goals[0].union

In [None]:
# Demo episode: reconfigure the already-trained runner *in place* and run a
# single episode with verbosity 3 and pretrain set to 0.
# NOTE(review): this mutates `run`. Re-executing the training cell above after
# this one reuses these settings (1 episode, verbosity 3); use
# Restart & Run All to get back to the original configuration.
run.agent.set_planning_horizon(14)
run.verbosity = 3
run.n_episodes = 1
run.pretrain = 0
run.run()

In [None]:
# Last recorded reward (presumably from the demo episode, if the cell above was run).
run.train_stats.rewards[-1]

In [None]:
# Scatter of the step count recorded for each run, in recording order
# (presumably one entry per episode). Uses the explicit Figure/Axes interface
# with a title and axis labels so the plot stands alone. `fig` and `steps` are
# still defined for any later cells that use them.
fig, ax = plt.subplots(figsize=(10, 7))
steps = np.array(run.train_stats.steps)
ax.plot(np.arange(steps.size), steps, '.')
# Trailing ';' suppresses the repr of ax.set's return value in the notebook.
ax.set(title='Steps per episode', xlabel='episode', ylabel='steps');
