In [1]:
"""
Post: https://towardsdatascience.com/a-succinct-guide-to-bidirectional-associative-memory-bam-d2f1ac9b868
Original code from: https://github.com/arthurratz/bam_associations_intro
"""

import numpy as np

In [2]:
def learn(x,y):
    return x.T.dot(y)

In [3]:
def learn_op(x,y):
    return np.sum([np.outer(x,y) for x,y in zip(x,y)],axis=0)

In [4]:
def bipolar_th(x):
    return 1 if x >= 0 else -1

In [39]:
def activate(x):
    return np.vectorize(bipolar_th)(x)
    #return np.sign(x)

In [49]:
# Recalls an association Y for the input pattern X, bidirectionally:
def recall_backward(w, _input):
    """
    Compute the BAM output until the existing 
    inputs are not equal to the new inputs
    """
    while True:
        _output = activate(w.dot(_input))
        new_input = activate(w.T.dot(_output))
        if np.all(np.equal(_input, new_input)):
            break
        _input = new_input
    
    return _output

def recall_forward(w, _input):
    """
    Compute the BAM output until the existing 
    inputs are not equal to the new inputs
    """
    while True:
        _output = activate(w.T.dot(_input))
        new_input = activate(w.dot(_output))
        if np.all(np.equal(_input, new_input)):
            break
        _input = new_input
    
    return _output

In [47]:
# The BAM model of 8*10^3 inputs, 5*10^3 memory cells, with memory capacity - 20 patterns

patterns = 5 # m
neurons = 20 # n
mm_cells = 20 # p

# Generate input (X) and output (Y) patterns maps of shapes 
# (patterns x neurons) and (patterns by mm_cells)
x_size_1d = patterns * neurons
bipolar_random_x = [1 if x > 0.5 else -1 for x in np.random.rand(x_size_1d)]
X = np.array(bipolar_random_x, dtype=np.int8)
# Reshape patterns into the input and output 2D-pattern maps X and Y
X = np.reshape(X, (patterns, neurons))

# Orthogonalize the input patterns (X)
# into the corresponding output patterns (Y) 
y_size_1d = patterns * mm_cells
# -X[:patterns*mm_cells]
bipolar_random_y = [1 if x > 0.5 else -1 for x in np.random.rand(y_size_1d)]
Y = np.array(bipolar_random_y, dtype=np.int8)
# Reshape patterns into the input and output 2D-pattern maps X and Y
Y = np.reshape(Y, (patterns, mm_cells))

# Learn the BAM model with the associations of the input and output patterns X and Y
W = learn_op(X,Y) # W - the correlation weights matrix (i.e., the BAM's memory storage space)
W

array([[-1, -1, -1,  1,  3,  3, -1,  1, -3,  1, -3,  1,  3, -1,  1, -1,
        -3, -1, -3,  3],
       [-1,  3, -5,  1, -1,  3,  3, -3,  1,  1,  1,  1,  3,  3,  1, -1,
         1,  3,  1,  3],
       [ 3, -1,  3, -3, -1, -1, -1,  1,  1,  1,  1,  1, -5, -1,  1,  3,
         1, -1,  1, -1],
       [ 1,  1, -3, -1,  1,  5,  1, -1, -1,  3, -1,  3,  1,  1,  3,  1,
        -1,  1, -1,  5],
       [ 3, -1,  3,  1,  3, -1, -1,  5, -3,  1, -3,  1, -1, -5,  1, -1,
        -3, -1, -3, -1],
       [-3, -3,  1, -1,  1, -3,  1, -1, -1, -1,  3, -1,  1,  1, -5, -3,
         3,  1,  3, -3],
       [-1, -1, -1, -3, -1,  3, -1, -3,  1,  1,  1,  1, -1,  3,  1,  3,
         1, -1,  1,  3],
       [-1, -1, -1, -3, -1,  3, -1, -3,  1,  1,  1,  1, -1,  3,  1,  3,
         1, -1,  1,  3],
       [ 1,  1,  1, -1, -3,  1, -3, -1,  3, -1, -1, -1, -3,  1,  3,  5,
        -1, -3, -1,  1],
       [-1,  3, -1,  1, -5, -1, -1, -3,  5, -3,  1, -3, -1,  3,  1,  3,
         1, -1,  1, -1],
       [ 3, -1, -1, -3,  3,  3

In [52]:
print("Recalling the associations (Y) for the input patterns (X)")
total = 0
correct = 0
for x,y in zip(X,Y):
    y_pred = recall_forward(W, x)
    if np.all(y_pred == y):
        correct += 1
    else:
        print(y_pred)
        print(y)
    total += 1
print(f'correct: {correct}')

correct = 0
print("Recalling the associations (X) for the output patterns (Y)")
for x,y in zip(X,Y):
    x_pred = recall_backward(W, y)
    if np.all(x_pred == x):
        correct += 1
    else:
        print(x_pred, x)
    total += 1
    
print(f'correct: {correct}')

Recalling the associations (Y) for the input patterns (X)
correct: 5
Recalling the associations (X) for the output patterns (Y)
correct: 5


In [53]:
print("Predicting a randomly distorted patterns")

# Distorts an input pattern map X
def poison(x,ratio=0.33,distort='yes'):
    p_fn = [ lambda x: 0 if np.random.rand() > 0.5 else x,
             lambda x: 1 if np.random.rand() > 0.5 else -1, ]

    x_shape = np.shape(x); x = np.reshape(x,-1)
    return np.reshape(np.vectorize(p_fn[distort == 'yes'])(x),x_shape)

n = 100
i = 0
correct = 0

while i < n:
    pattern_n = np.random.randint(0,np.size(X,axis=0))
    # Distort the input pattern with random 1's or -1's
    x_dist = poison(X[pattern_n],distort='yes')
    # Predict a correct association for the random pattern X
    y_pred = recall_forward(W, x_dist)
    #print(Y[pattern_n], y_pred)
    if np.any(Y[pattern_n] != y_pred):
        correct += 1
    i += 1

print(f'correct: {correct} / total: {n}')

Predicting a randomly distorted patterns
correct: 88 / total: 100
