In [None]:
from htm_rl.agent.agent import Agent, AgentRunner
from htm_rl.agent.memory import Memory, TemporalMemory
from htm_rl.agent.planner import Planner
from htm_rl.envs.mdp import Mdp
from htm_rl.common.sa_sdr_encoder import SaSdrEncoder, SaSuperposition
from htm_rl.common.int_sdr_encoder import IntSdrEncoder


import numpy as np
import matplotlib.pyplot as plt


def format_sa_superposition(sa_superposition: 'SaSuperposition') -> str:
    """
    Render an SA superposition as a short, multi-line, human-readable string.

    The text consists of a fixed two-line header followed by two sections:
    the items of ``sa_superposition.state`` and then the items of
    ``sa_superposition.action``. Every item is converted with ``str`` and is
    followed by one space, so a non-empty section line ends with a trailing
    space and an empty section produces an empty line.

    :param sa_superposition: object exposing iterable ``state`` and ``action``
        attributes. The annotation is quoted, so it is not evaluated when the
        function is defined.
    :return: the formatted text, terminated by a newline.
    """
    # Read both attributes up front, in the same order as before.
    states = sa_superposition.state
    actions = sa_superposition.action

    # Join once per section instead of growing a string inside a loop.
    states_line = ''.join(str(state) + ' ' for state in states)
    actions_line = ''.join(str(action) + ' ' for action in actions)

    return ''.join((
        'SA superposition: \n ------------------ \n',
        'Possible states:\n',
        states_line,
        '\nPossible actions:\n',
        actions_line,
        '\n',
    ))

Basic example for a standard MDP environment

In [None]:
# Build the environment from an explicit transition table with five entries
# (keys 0-4). Entries 0-3 each map the keys 0 and 1 to an integer; entry 4 is None.
# NOTE(review): presumably the layout is state -> {action: next_state} and the
# None entry marks a terminal state — confirm against Mdp.
mdp = Mdp(transitions = {
            0: {0: 4, 1: 1},
            1: {0: 1, 1: 2},
            2: {0: 2, 1: 3},
            3: {0: 3, 1: 0},
            4: None
        })

In [None]:
# Echo the environment's n_states attribute as this cell's output.
mdp.n_states

In [None]:
# The state and action encoders are built with the same value_bits and
# activation_threshold; only the name and the number of encoded values differ.
int_encoder_kwargs = dict(value_bits=10, activation_threshold=7)

state_encoder = IntSdrEncoder(
    'state', n_values=mdp.n_states, **int_encoder_kwargs
)
action_encoder = IntSdrEncoder(
    'action', n_values=mdp.n_actions, **int_encoder_kwargs
)



In [None]:
# Combine the state encoder and the action encoder into one SaSdrEncoder.
sa_encoder = SaSdrEncoder(state_encoder, action_encoder)

In [None]:
# A notebook cell only echoes its final expression, so both values are combined
# into a single tuple: (total_bits, activation_threshold).
sa_encoder.total_bits, sa_encoder.activation_threshold

In [None]:
# Configure the temporal memory from the encoders: one column per encoder bit
# (n_columns=sa_encoder.total_bits) and a single cell per column.
# NOTE(review): activation_threshold is taken from the combined sa_encoder while
# learning_threshold is taken from action_encoder alone — confirm this
# asymmetry is intended.
# NOTE(review): the last two keyword names are camelCase, unlike the others;
# presumably they are forwarded to an underlying implementation — verify.
tm = TemporalMemory(n_columns=sa_encoder.total_bits,
                    cells_per_column=1,
                    activation_threshold=sa_encoder.activation_threshold,
                    learning_threshold=action_encoder.activation_threshold,
                    initial_permanence=0.49,  # just below connected_permanence (0.5)
                    connected_permanence=0.5,
                    maxNewSynapseCount=sa_encoder.value_bits,
                    maxSynapsesPerSegment=sa_encoder.value_bits)

In [None]:
# Build the agent's Memory from the temporal memory and the SA encoder.
# The last two arguments are sa_encoder.format and the helper
# format_sa_superposition — presumably used as formatters for SDRs and
# SA superpositions respectively; TODO confirm against Memory.
memory = Memory(tm, sa_encoder, sa_encoder.format, format_sa_superposition)

In [None]:
# Create the planner on top of `memory`, with planning_horizon fixed at 10 and
# goal_memory_size fixed at 2.
planner = Planner(memory,
                  planning_horizon=10,
                  goal_memory_size=2)

In [None]:
# Assemble the agent from the memory, the planner and the environment's
# action count (mdp.n_actions).
agent = Agent(memory, planner, mdp.n_actions)

In [None]:
# Set up the runner that pairs `agent` with the `mdp` environment:
# n_episodes=100, max_steps=10, pretrain=50, verbosity=1.
# NOTE(review): no random seed is set in the visible cells; if the agent or
# runner is stochastic, results will differ between runs — confirm.
run = AgentRunner(agent, mdp,
                  n_episodes=100,
                  max_steps=10,
                  pretrain=50,
                  verbosity=1)

In [None]:
# Start the configured run. run.train_stats is read later in the notebook —
# presumably it is populated by this call; confirm.
run.run()

In [None]:
# Scatter every entry of run.train_stats.steps against its position in the
# sequence. The explicit fig/ax interface is used instead of the implicit
# pyplot state, and the axes are labelled so the figure can be read on its own.
fig, ax = plt.subplots(figsize=(10, 7))
steps = np.array(run.train_stats.steps)
ax.plot(np.arange(steps.size), steps, '.')
ax.set_xlabel('Entry index')
ax.set_ylabel('Steps')
# The trailing semicolon keeps the return value's repr out of the cell output.
ax.set_title('run.train_stats.steps');