In [None]:
import sys
# Make the project's htm_rl sources importable from this notebook. The path is
# relative, so it resolves against the current working directory; it must be
# appended before the project imports below run.
sys.path.append('../htm_rl/htm_rl/')

# Project modules (resolved through the path added above).
from agent.agent import Agent, AgentRunner
from agent.memory import Memory, TemporalMemory
from agent.planner import Planner
from common.sa_sdr_encoder import SaSdrEncoder, format_sa_superposition
from common.base_sa import SaRelatedComposition, Sa, SaSuperposition
from common.int_sdr_encoder import IntSdrEncoder, IntRangeEncoder
from common.int_sdr_encoder import SequenceSdrEncoder
from envs.mymdp import GridWorld


# Third-party numerics/plotting (NOTE(review): not used in the cells visible
# here; presumably used later in the notebook).
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Grid layout, one inner list per row (3 rows x 3 columns). Three distinct cell
# codes appear: 0, 1 and 2. NOTE(review): their meaning (e.g. free floor /
# obstacle / goal) is defined by GridWorld — confirm there.
world_description = [
    [2, 0, 0],
    [1, 1, 0],
    [0, 0, 0],
]

In [None]:
# Build the environment. The size tuple is derived from the layout itself so it
# can never disagree with `world_description` (for the 3x3 layout this is the
# same (3, 3) as before). NOTE(review): assumes GridWorld's second argument is
# (n_rows, n_cols), matching the row-major layout — confirm for non-square grids.
# The agent starts at row 2, column 0.
gw = GridWorld(
    world_description,
    (len(world_description), len(world_description[0])),
    agent_initial_position={'row': 2, 'column': 0},
)

In [None]:
# Display the grid size GridWorld stored (last expression is shown by Jupyter).
gw.world_size

In [None]:
# Draw the current grid and agent position for a visual sanity check.
gw.render()

In [None]:
# Episode step cap. NOTE(review): not referenced in the cells visible here;
# presumably used further down the notebook.
max_steps = 12

# The state is encoded as a sequence of per-feature SDRs. Each field encoder is
# called with trailing positional args (5, 4). NOTE(review): by analogy with the
# action encoder's keyword call these are presumably (value_bits,
# activation_threshold) — confirm against IntSdrEncoder / IntRangeEncoder.
n_rows, n_cols = gw.world_size[0], gw.world_size[1]
state_encoder = SequenceSdrEncoder(
    'state',
    encoders=[
        # NOTE(review): distance range tied to the row count — confirm intent.
        IntSdrEncoder('distance', n_rows, 5, 4),
        # Three values, matching the three cell codes used in world_description.
        IntSdrEncoder('surface', 3, 5, 4),
        # Signed offsets span -(extent-1)..(extent-1) along *their own* axis,
        # so non-square worlds are encoded correctly (rows use n_rows,
        # columns use n_cols).
        IntRangeEncoder('row', (-(n_rows - 1), n_rows - 1), 5, 4),
        IntRangeEncoder('column', (-(n_cols - 1), n_cols - 1), 5, 4),
        # Four values, presumably the four headings.
        IntSdrEncoder('direction', 4, 5, 4),
    ],
    # Equals the number of field encoders above.
    size=5,
)

In [None]:
# Inspect the state encoder: print value_bits, then display total_bits.
print(state_encoder.value_bits)
state_encoder.total_bits

In [None]:
# One encoded value per discrete action id reported by the environment.
# Uses value_bits=6 and activation_threshold=4 (the state fields use 5 and 4).
action_encoder = IntSdrEncoder(
    'action',
    gw.n_actions,
    activation_threshold=4,
    value_bits=6,
)

In [None]:
# Combine the state and action encoders into a single state-action encoder.
sa_encoder = SaSdrEncoder(state_encoder, action_encoder)

In [None]:
# Display the combined encoder's total_bits, value_bits and activation_threshold;
# these feed the TemporalMemory configuration below.
sa_encoder.total_bits, sa_encoder.value_bits, sa_encoder.activation_threshold

In [None]:
# Display the action encoder's activation threshold (constructed with 4 above).
action_encoder.activation_threshold

In [None]:
# Display value_bits squared. NOTE(review): the purpose of this quantity is not
# evident from the visible cells — document or remove once confirmed.
sa_encoder.value_bits ** 2

In [None]:
# Temporal memory over the full state-action SDR. Its thresholds and synapse
# caps are tied to the combined encoder's sizes rather than hard-coded.
tm = TemporalMemory(
    # Topology: one column per encoder bit, 8 cells per column.
    n_columns=sa_encoder.total_bits,
    cells_per_column=8,
    # Segment activation and learning share the encoder's threshold.
    activation_threshold=sa_encoder.activation_threshold,
    learning_threshold=sa_encoder.activation_threshold,
    # Synapse growth is capped at the encoder's value_bits.
    maxNewSynapseCount=sa_encoder.value_bits,
    maxSynapsesPerSegment=sa_encoder.value_bits,
    # Permanence dynamics: initial permanence equals the connected level, and
    # both decrements are 0 — presumably so new synapses connect immediately
    # and are never weakened (NOTE(review): confirm against the TM docs).
    initial_permanence=0.5,
    connected_permanence=0.5,
    permanenceIncrement=0.1,
    permanenceDecrement=0,
    predictedSegmentDecrement=0,
)

In [None]:
# Compare the TM's activation threshold with the encoder's value_bits.
tm.activation_threshold, sa_encoder.value_bits

In [None]:
# Read back the TM's segment/synapse limits to check that the constructor
# arguments (maxNewSynapseCount, maxSynapsesPerSegment) were applied.
tm.getMaxSegmentsPerCell(), tm.getMaxNewSynapseCount(), tm.getMaxSynapsesPerSegment()


In [None]:
# Wrap the temporal memory with the state-action encoder and two formatting
# callables (sa_encoder.format, format_sa_superposition). NOTE(review): the
# formatters are presumably used for verbose output during training — confirm.
memory = Memory(tm, sa_encoder, sa_encoder.format, format_sa_superposition)

In [None]:
# Fixed action script replayed every training cycle: the pattern
# (2, 2, 1) twice, then (2, 2). NOTE(review): action ids index GridWorld's
# action space; their meaning is defined there.
actions = [2, 2, 1] * 2 + [2, 2]

In [None]:
# Replay the fixed action script for several cycles, feeding every
# (state, action) pair to the memory for training.
verbosity = 3
n_cycles = 3
for i in range(n_cycles):
    if verbosity > 1:
        print()
        print(f'*** cycle {i+1} ***')
        print()
    state, reward, done = gw.reset(), 0, False
    for action in actions:
        if verbosity > 1:
            gw.render()
            print(f'Action {action} State: {state}')
        memory.train(Sa(state, action), verbosity)
        # Keep reward/done in sync with the environment instead of leaving them
        # at their reset values (same 4-tuple unpacking as before).
        # NOTE(review): `done` is not used to stop early — the whole script is
        # replayed; add `if done: break` if GridWorld must not step past the end.
        state, reward, done, info = gw.step(action)