In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
class MiniColumn:
    """Placeholder for a single mini-column; for now it only stores its input width.

    Nothing in the visible notebook (including ``Column``) uses this class yet.
    """

    def __init__(self, input_size):
        # Stored as given; no validation or derived state.
        self.input_size = input_size
        


class Column:
    """A population of ``n_mini_cols`` units whose state evolves by thresholded
    prediction and lateral inhibition.

    One call to :meth:`update` applies

        state <- state + alpha * (input + gamma * pred - inhib)

    where ``pred`` and ``inhib`` are both computed from the binary mask of
    units whose state is strictly greater than ``thresh``.

    Args:
        input_size: width of the raw input vectors projected by ``in_weight``.
        n_mini_cols: number of units in the column.
        alpha: step size of each state update.
        thresh: activation threshold (strict ``>``).
        gamma: weight of the prediction term in the update.

    NOTE(review): the weights are ``nn.Parameter`` objects, but ``Column`` is
    not an ``nn.Module``, so they are not registered anywhere (there is no
    ``.parameters()`` to hand to an optimizer). Subclass ``nn.Module`` if the
    weights are meant to be trained.
    """

    def __init__(self, input_size, n_mini_cols, alpha=0.1, thresh=0.5, gamma=0.1):
        self.input_size = input_size
        self.n_mini_cols = n_mini_cols
        self.alpha = alpha
        self.thresh = thresh
        self.gamma = gamma

        # (n_mini_cols, input_size): F.linear(input, in_weight) maps the last
        # input dimension from input_size to n_mini_cols.
        self.in_weight = nn.Parameter(torch.randn((n_mini_cols, input_size)))
        # (n_mini_cols, n_mini_cols): maps the active-unit mask to a prediction drive.
        self.pred_weight = nn.Parameter(torch.randn((n_mini_cols, n_mini_cols)))
        # -1 on the diagonal, +1 elsewhere. `update` *subtracts* this term, so an
        # active unit raises its own state and lowers every other unit's state.
        self.lateral_inhibition = nn.Parameter(
            torch.ones((n_mini_cols, n_mini_cols)) - 2 * torch.eye(n_mini_cols)
        )

    def init_state(self, batch_size=1):
        """Return an all-zero state of shape ``(batch_size, n_mini_cols)``."""
        return torch.zeros((batch_size, self.n_mini_cols))

    def _active(self, state):
        """Binary mask (same dtype as ``state``) of units whose state exceeds ``thresh``."""
        return (state > self.thresh).to(state.dtype)

    def predict(self, state):
        """Prediction drive from the active units: ``active @ pred_weight.T`` (no bias)."""
        return F.linear(self._active(state), self.pred_weight)

    def inhibit(self, state):
        """Lateral-inhibition drive: ``active @ lateral_inhibition.T`` (no bias)."""
        return F.linear(self._active(state), self.lateral_inhibition)

    def update(self, state, input):
        """Advance ``state`` by one step and return the new state.

        Args:
            state: current state, last dimension ``n_mini_cols``. Not modified
                in place.
            input: drive already in mini-column space (last dimension
                ``n_mini_cols``); ``__call__`` performs that projection.

        Returns:
            ``state + alpha * (input + gamma * predict(state) - inhibit(state))``.
        """
        pred = self.predict(state)
        inhib = self.inhibit(state)
        return state + self.alpha * (input + self.gamma * pred - inhib)

    def __call__(self, input, state=None):
        """Project ``input`` into mini-column space and apply one update step.

        Args:
            input: raw input, expected shape ``(batch, input_size)``.
            state: previous state ``(batch, n_mini_cols)``. When ``None``, a
                zero state from :meth:`init_state` is used. Feed the returned
                state back in to step through a sequence one element at a time.

        Returns:
            The new state, shape ``(batch, n_mini_cols)``.
        """
        if state is None:
            state = self.init_state(input.shape[0])
        projected = F.linear(input, self.in_weight)
        return self.update(state, projected)
    


In [18]:
# Toy corpus: four 6-character sequences, each starting with 'H'.
Sequences = [
    list('HELLO.'),
    list('HIBROS'),
    list('HELPME'),
    list('HEALME'),
]

# Vocabulary: 'A'..'Z' map to 0..25 and '.' maps to 26.
str2idx = dict(zip('ABCDEFGHIJKLMNOPQRSTUVWXYZ.', range(27)))

def seqs2tensor(seqs):
    """Map each symbol of each sequence through ``str2idx``.

    Returns a float tensor of shape ``(len(seqs), len(seqs[0]))``. All
    sequences must have the same length. An unknown symbol raises ``KeyError``.
    """
    index_rows = [list(map(str2idx.__getitem__, seq)) for seq in seqs]
    return torch.tensor(index_rows).float()
# One-hot encode the corpus as (n_sequences, seq_len, vocab_size).
# `num_classes` is pinned to the vocabulary size so the feature width does not
# depend on which symbols happen to appear in the corpus.
data = F.one_hot(
    seqs2tensor(Sequences).long(),
    num_classes=len(str2idx),
).float()
data.shape

torch.Size([4, 6, 27])

In [None]:
# Column draws its weights with torch.randn; seed so reruns are reproducible.
torch.manual_seed(0)
# Input width follows the one-hot encoding rather than a hard-coded 27.
col = Column(data.shape[-1], n_mini_cols=64)
state = col.init_state(data.shape[0])