# TM backward prediction pass (aka backtracking)

На выходе forward prediction pass имеется цепочка из N последовательных переходов между суперпозициями состояний.
Обратный проход предсказания призван восстановить по этой цепочке переходов и некоторому подмножеству конечного суперсостояния (т.е. по тем битам конечной SDR, "в которые" мы хотим попасть через N шагов) цепочку конкретных переходов.

Цель:

- реализовать бэктрекинг для однозначных последовательностей
  - с одним хвостом
  - с двумя хвостами
- реализовать бэктрекинг для последовательностей, неоднозначных для first-order TM

In [None]:
import numpy as np
from htm.bindings.sdr import SDR
from htm.algorithms import TemporalMemory as TM, Connections

def format_sdr(sdr, layer_ind=0):
    """Render an SDR as a string of digits for console display.

    Only the first ``min(input_size, col_size * 5)`` entries are rendered, and
    a space is inserted before every ``col_size``-th entry so that blocks of
    ``col_size`` entries are visually separated. When ``sdr.dense`` is 2-D,
    the slice ``dense[:, layer_ind]`` is rendered instead of the whole array.

    Relies on the module-level globals ``input_size`` and ``col_size``.
    """
    dense = sdr.dense
    bits = sdr.dense[:, layer_ind] if len(dense.shape) == 2 else dense
    limit = min(input_size, col_size * 5)

    pieces = []
    for idx in range(limit):
        # Separator between consecutive col_size-wide groups.
        if idx and not idx % col_size:
            pieces.append(' ')
        pieces.append(str(bits[idx]))
    return ''.join(pieces)

def print_formatted_active_cells(val, active_cells):
    """Print ``active_cells`` one cell layer per row, tagged as "Active".

    The first row renders layer 0 prefixed by ``val`` (left-aligned, width 4).
    Rows for layers 1 .. cells_per_column-1 follow with a blank prefix.
    """
    first_row = format_sdr(active_cells)
    print(f'{val:<4} | {first_row} Active')
    for layer in range(1, cells_per_column):
        print('     | ' + format_sdr(active_cells, layer))
    
def print_formatted_predictive_cells(anomaly, predictive_cells):
    """Print ``predictive_cells`` one cell layer per row, tagged as "Predicted".

    The first row renders layer 0 prefixed by ``anomaly`` formatted to two
    decimals. The remaining layers (1 .. cells_per_column-1) get a blank prefix.
    """
    prefix = f'{anomaly:.2f}'
    print(f'{prefix} | {format_sdr(predictive_cells)} Predicted')
    for layer in range(1, cells_per_column):
        print('     | ' + format_sdr(predictive_cells, layer))

def encode_sdr_val(val, size=None, width=None):
    """Encode integer ``val`` as a dense block code.

    Bits ``[val * width, (val + 1) * width)`` are set to 1 in an otherwise
    all-zero int8 vector of length ``size``.

    Args:
        val: non-negative integer value to encode.
        size: length of the output vector; defaults to the module-level
            ``input_size``.
        width: number of consecutive bits per value; defaults to the
            module-level ``col_size``.

    Returns:
        np.ndarray of dtype int8 and shape ``(size,)``.

    Raises:
        ValueError: if the block for ``val`` does not fit inside the vector.
    """
    if size is None:
        size = input_size
    if width is None:
        width = col_size

    start = val * width
    stop = start + width
    # Without this check a negative value or one past the end would be
    # silently encoded as an empty or truncated block.
    if val < 0 or stop > size:
        raise ValueError(
            f'cannot encode value {val}: bits [{start}, {stop}) fall outside '
            f'a vector of size {size}'
        )

    data = np.zeros(size, dtype=np.int8)
    data[start:stop] = 1
    return data

def train_cycle(arr, tm, input_sdr, print_enabled=False, learn=True, reset_enabled=True):
    """Present the values of ``arr`` to ``tm`` in order.

    Args:
        arr: iterable of integer values, each encoded with ``encode_sdr_val``.
        tm: the temporal memory instance.
        input_sdr: SDR reused as the input buffer on every step.
        print_enabled: print the active and predicted cells after each step.
        learn: forwarded to ``tm.compute`` and ``tm.activateDendrites``.
        reset_enabled: call ``tm.reset()`` before the first value.
    """
    if reset_enabled:
        tm.reset()

    for value in arr:
        input_sdr.dense = encode_sdr_val(value)
        tm.compute(input_sdr, learn=learn)
        if print_enabled:
            print_formatted_active_cells(value, tm.getActiveCells())

        # Presumably refreshes segment activity from this step's active cells
        # so getPredictiveCells() reflects them (htm.core semantics).
        tm.activateDendrites(learn)
        if not print_enabled:
            continue
        print_formatted_predictive_cells(tm.anomaly, tm.getPredictiveCells())

def get_columns_from_cells(cells_sdr, tm):
    """Map cells to the columns that contain them.

    Args:
        cells_sdr: object exposing the cell indices in ``.sparse``.
        tm: temporal memory providing ``columnForCell``.

    Returns:
        Sorted list of unique column indices.
    """
    columns = {tm.columnForCell(cell) for cell in cells_sdr.sparse}
    return sorted(columns)

def get_vals_from_cells(cols):
    """Return the sorted unique encoded values covered by the columns ``cols``.

    Despite the name, the input is column indices (not cell indices); each
    column maps to value ``column // col_size`` (module-level global).
    """
    values = set()
    for column in cols:
        values.add(column // col_size)
    return sorted(values)

def predict_cycle(start_val, n_steps, tm, input_sdr, print_enabled=False, learn=False, reset_enabled=True):
    """Run ``tm`` forward for ``n_steps``, feeding its predictions back as input.

    The first step is driven by the encoding of ``start_val``. Each later step
    uses, as its input, the columns of the cells predicted on the previous step.

    Args:
        start_val: integer value encoded with ``encode_sdr_val`` for step 0.
        n_steps: number of compute steps to run.
        tm: the temporal memory instance.
        input_sdr: SDR reused as the input buffer on every step.
        print_enabled: print active/predicted cells plus the predicted columns
            and values after each step. All debug output is gated on this.
        learn: forwarded to ``tm.compute`` and ``tm.activateDendrites``.
        reset_enabled: call ``tm.reset()`` before the first step.
    """
    if reset_enabled:
        tm.reset()
    input_sdr.dense = encode_sdr_val(start_val)

    for _ in range(n_steps):
        tm.compute(input_sdr, learn=learn)
        if print_enabled:
            print_formatted_active_cells(' ', tm.getActiveCells())

        tm.activateDendrites(learn)
        prediction = tm.getPredictiveCells()
        if print_enabled:
            print_formatted_predictive_cells(tm.anomaly, prediction)

        # Computed once and reused both as the next input and for display.
        next_activation_sparse = get_columns_from_cells(prediction, tm)
        if print_enabled:
            print(next_activation_sparse)
        input_sdr.sparse = next_activation_sparse
        if print_enabled:
            print(get_vals_from_cells(next_activation_sparse))

In [None]:
from collections import defaultdict

# --- Experiment configuration ---------------------------------------------
# These module-level names are read as globals by the helpers in the first
# code cell (format_sdr, encode_sdr_val, get_vals_from_cells, the print
# helpers) and by plan_to_val below.
# encode_sdr_val gives each integer value col_size consecutive columns, so
# input_size // col_size = 10 distinct values can be encoded.
input_size = 200
col_size = 20
cells_per_column = 4
# Passed to the TM as activationThreshold and also used by plan_to_val as the
# minimum column overlap with the target value.
activationThreshold = 18
# Passed to the TM and also used by plan_to_val to keep only synapses whose
# permanence is >= this value.
connectedPermanence = .5
input_sdr = SDR(input_size)

# NOTE(review): initialPermanence equals connectedPermanence, so newly grown
# synapses presumably start out connected — confirm this is intended.
tm = TM(
    columnDimensions = (input_sdr.size,),
    cellsPerColumn=cells_per_column,
    minThreshold=12,
    activationThreshold=activationThreshold,
    initialPermanence=0.5,
    connectedPermanence=connectedPermanence,
)

# Training sequences. Commented-out entries are earlier experiments; the
# "working" / "not working" labels are the author's notes on backtracking
# results for those sequence sets.
seqs = [
#     [0, 1, 2, 3],
#     [0, 2, 1, 3, 2],
#     [0, 1, 4, 3, 2],
    
    # working
#     [0, 1, 3],
#     [0, 2, 4],
    
    # not working
#     [0, 1, 3, 3, 2],
    
    # not working
    [0, 1, 4, 4, 3],
    [0, 2, 2, 4, 3],
]

# Train on a randomly chosen sequence each cycle (40 cycles per sequence in
# total, split randomly). No RNG seed is set in this file, so the training
# order varies between runs. reset_enabled=True resets the TM before every
# sequence.
for cycle in range(len(seqs) * 40):
    seq = seqs[np.random.choice(len(seqs))]
    train_cycle(seq, tm, input_sdr, print_enabled=False, learn=True, reset_enabled=True)

def get_columns_from_cells(cells_sdr, tm):
    """Return the sorted list of distinct columns owning the cells in ``cells_sdr``.

    NOTE(review): this redefines the same helper from the first code cell.
    """
    column_for = tm.columnForCell
    return sorted(set(map(column_for, cells_sdr.sparse)))

def get_vals_from_cells(cols):
    """Return the sorted unique values ``column // col_size`` for the columns ``cols``.

    NOTE(review): this redefines the same helper from the first code cell;
    the input is column indices, not cell indices.
    """
    return sorted({column // col_size for column in cols})

def _collect_transition_layer(tm, active_cell_set, print_enabled=False):
    """Group the connected synapses of the TM's active segments by owning cell.

    Returns a dict ``cell -> [(cell, presynaptic_cell, segment, synapse,
    permanence), ...]`` built from ``tm.getActiveSegments()``. Only synapses
    with permanence >= ``connectedPermanence`` (module-level global) are kept,
    and only if their presynaptic cell is in ``active_cell_set``.
    """
    connections = tm.connections
    active_segments = tm.getActiveSegments()
    if print_enabled:
        print(f'segments: {active_segments}')
        cells_for_segments = [connections.cellForSegment(segment) for segment in active_segments]
        print(f'cells_se: {cells_for_segments}')

    synapses_for_cells = defaultdict(list)
    for segment in active_segments:
        cell = connections.cellForSegment(segment)
        for synapse in connections.synapsesForSegment(segment):
            permanence = connections.permanenceForSynapse(synapse)
            if permanence < connectedPermanence:
                continue
            presynaptic_cell = connections.presynapticCellForSynapse(synapse)
            # A connected synapse onto a cell that was not active this step
            # cannot have contributed to the prediction, so it is not part of
            # the realised transition and must not be followed when
            # backtracking.
            if presynaptic_cell not in active_cell_set:
                continue
            synapses_for_cells[cell].append(
                (cell, presynaptic_cell, segment, synapse, permanence)
            )
    return dict(synapses_for_cells)


def _backtrack_transitions(connection_graph, final_cells, print_enabled=False):
    """Walk ``connection_graph`` backwards starting from ``final_cells``.

    Returns a list aligned with ``connection_graph``. Entry ``t`` maps every
    cell on the backtracked chain after transition ``t`` to the synapse tuples
    recorded for it at step ``t`` (an empty list if none were recorded).
    """
    if print_enabled and connection_graph:
        print(connection_graph[-1].keys())

    backtracking_graph = [{} for _ in connection_graph]
    cells = list(final_cells)
    for t in reversed(range(len(connection_graph))):
        layer = connection_graph[t]
        # .get avoids inserting empty entries into the recorded layer.
        backtracking_graph[t] = {cell: layer.get(cell, []) for cell in cells}
        if print_enabled:
            print('===')

        # The presynaptic cells of this transition are the chain's cells at
        # the previous step.
        cells = list({
            connection[1]
            for cell in cells
            for connection in layer.get(cell, [])
        })
        if print_enabled:
            print('---------------------')
            print(cells)
    return backtracking_graph


def plan_to_val(start_val, end_val, max_steps, tm, input_sdr, print_enabled=False, reset_enabled=True):
    """Plan from ``start_val`` towards ``end_val`` and backtrack the cell chain.

    Forward pass: the TM is run without learning for up to ``max_steps``
    steps, starting from the encoding of ``start_val``. Each step feeds the
    columns of the predicted cells back in as the next input. For every step,
    the connected synapses of the active segments are recorded. They are
    grouped by the segment's cell and restricted to presynaptic cells that
    were active on that step. The pass stops once the current input overlaps
    the encoding of ``end_val`` by at least ``activationThreshold`` bits
    (module-level global).

    Backward pass: the walk starts from the active cells lying in
    ``end_val``'s columns. From there, the recorded synapses are followed back
    one step at a time.

    Args:
        start_val: integer value encoded for the first step.
        end_val: integer value to reach.
        max_steps: maximum number of forward steps.
        tm: the temporal memory instance.
        input_sdr: SDR reused as the input buffer on every step.
        print_enabled: emit debug output. All prints are gated on this flag.
        reset_enabled: call ``tm.reset()`` before planning.

    Returns:
        ``None`` if ``end_val`` was not reached within ``max_steps`` steps
        (including ``max_steps <= 0``). Otherwise, a list with one dict per
        recorded transition, oldest first (see ``_backtrack_transitions``).
        The list is empty when the start input already satisfies the overlap
        test, so compare the result with ``is None`` rather than relying on
        truthiness.
    """
    if reset_enabled:
        tm.reset()
    end_sdr = SDR(input_size)
    end_sdr.dense = encode_sdr_val(end_val)

    connection_graph = []
    active_cells = None
    # Initialised so that max_steps <= 0 returns None instead of raising
    # UnboundLocalError at the reachability check below.
    overlap_with_end_val = 0

    input_sdr.dense = encode_sdr_val(start_val)
    for _ in range(max_steps):
        tm.compute(input_sdr, learn=False)
        active_cells = tm.getActiveCells()
        if print_enabled:
            print_formatted_active_cells(' ', active_cells)

        overlap_with_end_val = end_sdr.getOverlap(input_sdr)
        if print_enabled:
            print(overlap_with_end_val)
        if overlap_with_end_val >= activationThreshold:
            break

        tm.activateDendrites(False)
        prediction = tm.getPredictiveCells()
        if print_enabled:
            print_formatted_predictive_cells(tm.anomaly, prediction)

        active_cell_set = {int(cell) for cell in active_cells.sparse}
        connection_graph.append(
            _collect_transition_layer(tm, active_cell_set, print_enabled)
        )

        # Feed the predicted columns back in as the next step's input.
        next_activation_sparse = get_columns_from_cells(prediction, tm)
        input_sdr.sparse = next_activation_sparse
        if print_enabled:
            print(get_vals_from_cells(next_activation_sparse))

    if overlap_with_end_val < activationThreshold:
        return None

    end_columns = set(end_sdr.sparse)
    final_cells = [
        int(cell)
        for cell in active_cells.flatten().sparse
        if tm.columnForCell(cell) in end_columns
    ]
    return _backtrack_transitions(connection_graph, final_cells, print_enabled)
        
    
        

# Plan from value 0 to value 3 in at most 5 steps, with debug output enabled.
plan_to_val(
    start_val=0,
    end_val=3,
    max_steps=5,
    tm=tm,
    input_sdr=input_sdr,
    print_enabled=True,
    reset_enabled=True,
)