/
HexagonDemo.m
95 lines (73 loc) · 2.64 KB
/
HexagonDemo.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
% A simple geometrical demonstration of an overcomplete Infomax network
%
% Written by Aviv Dotan
% 11.11.2017
%
% Based on the paper:
% Shriki, O., Sompolinsky, H., & Lee, D. D. (2001). An information
% maximization approach to overcomplete and recurrent representations. In
% Advances in neural information processing systems (pp. 612-618).
% URL:
% https://papers.nips.cc/paper/1863-an-information-maximization-approach-to-overcomplete-and-recurrent-representations
clear;
close;
clc;
%% Uniformly sample 2D points from a hexagon
% The target hexagon is centered at the origin with unit circumradius
% (vertices at (0,+-1) and (+-sqrt(3)/2, +-1/2)). It is assembled from three
% congruent rhombi:
%   1. Draw points uniformly on the unit square.
%   2. Map the square affinely onto one rhombus.
%   3. Rotate the first and second thirds of the points by 120 and 240
%      degrees.
% Linear/affine maps preserve uniformity, so each rhombus is uniformly
% filled.
n_samples = 3000; % Number of training points
% Uniform points on the unit square [0,1]^2
X = rand(2, n_samples);
% rot(a) = [cos(a) sin(a); -sin(a) cos(a)], i.e. a clockwise rotation by a
rot = @(a) [ cos(a), sin(a);
            -sin(a), cos(a)];
% Rotate, stretch, then shift the square onto the rhombus with vertices
% (0,0), (+-sqrt(3)/2, 1/2) and (0,1)
stretch = diag([sqrt(3)/sqrt(2), sqrt(2)/2]);
shift = repmat([-sqrt(3)/2; 1/2], [1, n_samples]);
X = stretch*rot(pi/4)*X + shift;
% Move the 1st and 2nd thirds of the columns into the other two rhombi.
% Any remainder (n_samples not divisible by 3) stays in the first rhombus.
n_third = floor(n_samples/3);
for k = 1:2
    cols = (k-1)*n_third + (1:n_third);
    X(:, cols) = rot(2*k*pi/3)*X(:, cols);
end
%% Create an overcomplete Infomax network
% Presumably Infomax(n_inputs, n_outputs): 2D inputs mapped onto 3 units
% (overcomplete). TODO confirm against Infomax.m. The plotting code below
% treats pinv(Net.W) as having 2 rows (x1/x2 components) and one column per
% network axis.
Net = Infomax(2, 3);
%% Train the network
n_train = n_samples; % Number of learning steps
batch_size = 1;      % Number of samples per learning step
% Aim for ~100 snapshots over the run. Clamp to at least 1: for
% n_train < 100, floor() would give 0, rem(t, 0) is NaN, and nothing would be
% plotted until the final step.
plot_freq = max(1, floor(n_train/100)); % Plot frequency
figure();
for t = 1:n_train
    %% Learn
    % Choose a random batch (distinct columns within one batch)
    x = X(:, randperm(n_samples, batch_size));
    % Train the network on this batch (presumably one online update step)
    Net.Learn(x);
    %% Progress plot
    % Redraw only every plot_freq steps, and always on the final step
    if ~(rem(t, plot_freq) == 0 || t == n_train)
        continue;
    end
    % Get the network's cost, evaluated on the whole training set
    cost = Net.GetCost(X);
    % Get the network's axes: the columns of the pseudo-inverse of W
    Wpinv = pinv(Net.W);
    n_axes = size(Wpinv, 2);
    % Rescale so that the longest axis has length sqrt(3)/2. Dividing by
    % sqrt((4/3)*max squared norm) gives exactly that; it equals the distance
    % from the center to an edge of the unit-circumradius data hexagon
    % (presumably the intended reference length).
    Wpinv = Wpinv ./ sqrt((4/3)*max(sum(Wpinv.^2, 1)));
    % Plot the training data
    scatter(X(1,:), X(2,:), 'k.');
    hold on;
    % Plot the network's axes as arrows from the origin. AutoScale is off so
    % quiver draws them at their true normalized length; by default quiver
    % rescales arrow lengths (AutoScaleFactor 0.9).
    quiver(zeros(1, n_axes), zeros(1, n_axes), Wpinv(1,:), Wpinv(2,:), ...
        'b', 'LineWidth', 2, 'AutoScale', 'off');
    hold off;
    % Plot formatting
    title(['$$t=' num2str(t, '%-d') '$$ , ' ...
        '$$\varepsilon=' num2str(cost, '%-g') '$$'], ...
        'Interpreter', 'latex');
    set(gca, 'XTick', []);
    xlim([-1, 1]);
    xlabel('$$x_1$$', 'Interpreter', 'latex');
    ylim([-1, 1]);
    set(gca, 'YTick', []);
    ylabel('$$x_2$$', 'Interpreter', 'latex');
    axis square;
    % Draw the plot
    drawnow;
    pause(0.01);
end