In [1]:
require 'nn';
require 'Attention';
require 'Association';
require 'image';
require 'gnuplot';

# Data Preparation 

In [2]:
-- Stimulus set: the 8 corners of a binary three-dimensional cube, one row per
-- exemplar. Column coding:
--   col 1: 1 = Big,       0 = Small
--   col 2: 1 = Rectangle, 0 = Triangle
--   col 3: 1 = Black,     0 = White
-- TODO: a clearer separation between exemplars and category types would help.
exemplar = torch.Tensor({
    {1, 1, 1},
    {0, 1, 1},
    {1, 0, 1},
    {1, 1, 0},
    {0, 0, 1},
    {1, 0, 0},
    {0, 1, 0},
    {0, 0, 0},
})

-- Category structures, one row per category type (Kruschke 1992, Fig. 4).
-- category[t][k] is 1 when exemplar k belongs to the first category under type t
-- and 0 when it belongs to the second. The training loop tests `== 1`, so each row
-- acts as a membership mask over the rows of `exemplar`.
category = torch.Tensor({
    {1, 0, 1, 1, 0, 1, 0, 0},
    {0, 1, 1, 0, 0, 1, 1, 0},
    {1, 1, 0, 1, 0, 1, 0, 0},
    {1, 1, 1, 1, 0, 0, 0, 0},
    {0, 1, 1, 1, 0, 1, 0, 0},
    {1, 0, 0, 0, 1, 1, 1, 0},
})



# Training

In [3]:
-- ALCOVE simulation. For each category type, build a fresh network
-- (nn.Attention -> nn.Association), train it for N_EPOCHS passes over all
-- exemplars, and record on every trial the probability that the model picks the
-- correct category.
--
-- Globals written:
--   target    - 2-element teacher vector, reused on every trial
--   prob_corr - prob_corr[type][epoch][trial] = P(correct response)
--   alcove, attention, association, a_out - left holding the network trained on
--                                           the last category type
local N_EPOCHS       = 50
local PHI            = 2.0     -- scaling constant of the exponential choice rule
local LR_ATTENTION   = 0.0033  -- learning rate for attention:updateParameters
local LR_ASSOCIATION = 0.03    -- learning rate for association:updateParameters
local n_types        = category:size(1)
local n_exemplars    = exemplar:size(1)

target = torch.Tensor(2)
prob_corr = torch.Tensor(n_types, N_EPOCHS, n_exemplars)
-- For each category type
for l = 1, n_types do
    -- Fresh network per category type. The constructor arguments are kept as in
    -- the original; presumably (n_dims, n_hidden, exemplars, specificity c, ...),
    -- but confirm against Attention.lua / Association.lua.
    alcove = nn.Sequential()
    attention = nn.Attention(3, 8, exemplar, 6.5, 1, 1)
    association = nn.Association(2, 8)
    alcove:add(attention)
    alcove:add(association)
    -- backward() accumulates into the gradient buffers, so start them at zero.
    -- Otherwise whatever a newly constructed module holds there leaks into the
    -- first epoch's update.
    attention:zeroGradParameters()
    association:zeroGradParameters()

    for i = 1, N_EPOCHS do
        -- For each training exemplar
        for j = 1, n_exemplars do
            -- Deterministic cyclic presentation order, shifted by one each epoch.
            local idx = (i + j) % n_exemplars + 1
            local input = exemplar[idx]
            a_out = alcove:forward(input)

            -- "Humble teacher" targets: the correct output unit is pushed to at
            -- least +1 and the other to at most -1, but outputs that already
            -- overshoot in the right direction are not penalized.
            -- Output unit 1 corresponds to category[l][idx] == 1, unit 2 to 0.
            local correct
            if category[l][idx] == 1 then
                target[1] = math.max(1, a_out[1])
                target[2] = math.min(-1, a_out[2])
                correct = 1
            else
                target[1] = math.min(-1, a_out[1])
                target[2] = math.max(1, a_out[2])
                correct = 2
            end

            -- Choice rule: P(correct) = exp(phi * a_correct) / sum_k exp(phi * a_k)
            prob_corr[l][i][j] = math.exp(a_out[correct] * PHI)
                / torch.sum(torch.mul(a_out, PHI):exp())

            -- NOTE(review): by nn convention, backward() expects dLoss/dOutput and
            -- updateParameters(lr) performs params = params - lr * gradParams.
            -- For E = 1/2 * sum((target - a)^2), dE/da = a_out - target, so passing
            -- (target - a_out) gives gradient *ascent* with standard nn modules,
            -- unless Attention/Association flip the sign internally. Confirm in
            -- Attention.lua / Association.lua before changing.
            alcove:backward(input, target - a_out)
        end
        -- NOTE(review): gradients are accumulated over the whole epoch and applied
        -- once. Kruschke (1992) describes trial-by-trial updates, so confirm that
        -- batching is intended.
        attention:updateParameters(LR_ATTENTION)
        association:updateParameters(LR_ASSOCIATION)
        association:zeroGradParameters()
        attention:zeroGradParameters()
    end
end

In [4]:
-- Average the per-trial probabilities within each epoch and plot one learning
-- curve per category type, labelled "1".."n_types".
-- As before, prob_corr is overwritten with the averaged
-- (n_types x n_epochs x 1) tensor and x holds the epoch axis.
-- The counts come from prob_corr's shape, so the plot follows any change to the
-- number of epochs or category types made in the training cell.
prob_corr = torch.mean(prob_corr, 3)
local n_types, n_epochs = prob_corr:size(1), prob_corr:size(2)
x = torch.linspace(1, n_epochs, n_epochs)

local series = {}
for t = 1, n_types do
    series[t] = {tostring(t), x, prob_corr[{t, {}, 1}]}
end
-- unpack is a global in Lua 5.1/LuaJIT and table.unpack from Lua 5.2 onward.
gnuplot.plot((table.unpack or unpack)(series))
