-
Notifications
You must be signed in to change notification settings - Fork 0
/
hopfield.py
116 lines (96 loc) · 3.9 KB
/
hopfield.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import json
import random
import os
import numpy as np
import perceptron as p
import utils
# Build pattern array from txt file with '*', ' '
def get_pattern_array(filename: str) -> np.ndarray:
    """Read a text pattern file and flatten it into a bipolar vector.

    Every character of every line (the newline excluded) becomes one entry,
    in reading order: ``'*'`` maps to ``1`` and any other character maps to
    ``-1``.

    Args:
        filename: Path of the pattern file to read.

    Returns:
        A 1-D integer array with one entry per character read. An empty
        file yields an empty integer array.
    """
    with open(filename) as file:
        values = [
            1 if car == '*' else -1
            for line in file
            for car in line.strip('\n')
        ]
    # Pin the dtype: without it an empty file would produce a float64 array,
    # unlike every non-empty file.
    return np.array(values, dtype=int)
# Build the pattern matrix (N rows, one column per file) from the files in a directory
def pattern_matrix(dirpath: str, N: int) -> tuple:
    """Load every pattern file found in a directory into a single matrix.

    Args:
        dirpath: Directory whose entries are all read as pattern files.
        N: Number of elements each pattern must contain.

    Returns:
        A ``(letter_list, patterns)`` tuple. ``letter_list`` holds the file
        names in sorted order and ``patterns`` is an
        ``(N, len(letter_list))`` integer matrix whose column ``i`` is the
        pattern read from ``letter_list[i]``.

    Raises:
        ValueError: If a pattern file does not contain exactly ``N``
            elements.
    """
    # Read from the directory that was passed in, not from a module-level
    # global. os.listdir returns entries in arbitrary order, so sort them to
    # keep the column order reproducible between runs.
    letter_list = sorted(os.listdir(dirpath))
    patterns = np.zeros((N, len(letter_list)), dtype=int)
    for i, pattern_filename in enumerate(letter_list):
        pattern = get_pattern_array(os.path.join(dirpath, pattern_filename))
        # Check the size explicitly: NumPy would otherwise silently broadcast
        # a one-element pattern over the whole column.
        if len(pattern) != N:
            raise ValueError(
                f'Pattern file {pattern_filename!r} has {len(pattern)} '
                f'elements, expected {N}'
            )
        patterns[:, i] = pattern
    return letter_list, patterns
# Print a flat array of side * side values as a side x side grid
def print_pattern(pattern: np.ndarray, side: int) -> None:
    """Print a flat pattern as a ``side`` x ``side`` grid on stdout.

    Entries greater than zero are drawn as ``'*'`` and all others as a
    space. Every row of ``side`` characters is followed by a newline.

    Args:
        pattern: Indexable sequence holding at least ``side * side`` values,
            stored row by row.
        side: Number of rows and of columns in the printed grid.

    Raises:
        IndexError: If ``pattern`` holds fewer than ``side * side`` values.
    """
    for row in range(side):
        start = row * side
        cells = (
            '*' if pattern[start + col] > 0 else ' '
            for col in range(side)
        )
        # One print per row always terminates the row with a newline,
        # including the single row of a 1 x 1 pattern.
        print(''.join(cells))
# Generate pattern mutation from pattern using pm probability
def get_mutated_pattern(pattern: np.ndarray, pm: float) -> np.ndarray:
    """Return a copy of ``pattern`` with randomly chosen entries negated.

    One ``random.random()`` draw is made per entry, in index order, and the
    entry is multiplied by -1 when its draw is below ``pm``. The input array
    is left untouched.

    Args:
        pattern: Pattern to mutate.
        pm: Per-entry probability of flipping the sign.

    Returns:
        A new array holding the mutated pattern.
    """
    mutated = np.copy(pattern)
    # Draw once per entry, in order, so the random stream is consumed exactly
    # as an element-by-element loop would consume it.
    flip_mask = np.array(
        [random.random() < pm for _ in range(len(mutated))],
        dtype=bool,
    )
    mutated[flip_mask] = mutated[flip_mask] * -1
    return mutated
# Pattern length is fixed
# Patterns are SIDE x SIDE grids flattened into vectors of N entries.
SIDE = 5
N = SIDE * SIDE

# Colors handed to utils.string_with_color for the final verdict.
RED = "#FF0000"
GREEN = "#00FF00"

# Read the run parameters from the config file (JSON text is UTF-8).
with open("config.json", encoding="utf-8") as file:
    config = json.load(file)

pattern_dirpath: str = config["hopfield"]["pattern_dir"]
pm: float = config["hopfield"]["mutation_prob"]
max_iterations: int = config["hopfield"]["max_iterations"]
plot_boolean: bool = config["plot"]

# Build the pattern matrix [e1 e2 e3 ...], with len(ei) = N.
letter_list, patterns = pattern_matrix(pattern_dirpath, N)

# Print the dot product of every pair of stored patterns: the closer to 0,
# the more orthogonal the two patterns are.
pattern_count = patterns.shape[1]
for i in range(pattern_count):
    for j in range(i + 1, pattern_count):
        print(f'Producto interno entre {letter_list[i]} y {letter_list[j]}: {np.dot(patterns[:, i], patterns[:, j])}')

# Pick one stored pattern at random and mutate it to obtain the query.
query_num = random.randint(0, patterns.shape[1] - 1)
query_pattern = get_mutated_pattern(patterns[:, query_num], pm)

# Initialize the Hopfield perceptron with the stored patterns and the query.
algo: p.HopfieldPerceptron = p.HopfieldPerceptron(patterns, query_pattern)

# Print the initial query.
print('------------------')
print_pattern(query_pattern, SIDE)
print('------------------')

# Iterate until the network reports it is over or the iteration cap is hit.
# `s` starts as the query so it is always bound, even when no iteration runs.
# NOTE(review): assumes the state before the first iteration equals the query
# pattern - confirm against HopfieldPerceptron.
s: np.ndarray = query_pattern
count: int = 0
converged: bool = algo.is_over()
while not converged and count < max_iterations:
    s = algo.iterate()
    print_pattern(s, SIDE)
    print('------------------')
    count += 1
    converged = algo.is_over()

# Report why the loop stopped. Decide from `converged` rather than from the
# counter, so a run that finishes exactly on the last permitted iteration is
# not misreported as having hit the limit.
if not converged:
    print(f'Se ha alcanzado el {utils.string_with_color("límite de iteraciones", RED)} (probablemente por loop). Saliendo...')
else:
    # The final state is spurious unless it equals one of the stored patterns.
    spurious = True
    for i in range(patterns.shape[1]):
        if np.array_equal(s, patterns[:, i]):
            (correct, color) = ("es correcto", GREEN) if letter_list[i] == letter_list[query_num] else ("es incorrecto", RED)
            print(f'El estado final {utils.string_with_color(correct, color)}. Coincide con {letter_list[i]} ({count} iter) y era {letter_list[query_num]}.')
            spurious = False
            break
    if spurious:
        print(f'El estado final {utils.string_with_color("es espúreo", RED)}. Debería coincidir con {letter_list[query_num]}.')

    # Print the energy values recorded by the perceptron.
    print(f'\nEnergy values: \n{algo.energy}')

    # Generate plots only when enabled in the config.
    if plot_boolean:
        # Init plotter
        utils.init_plotter()

        # Plot energy = f(t)
        utils.plot_values(range(len(algo.energy)), 'iteration', algo.energy, 'energy', sci_y=False, ticks=range(len(algo.energy)))

        # Hold execution to show plots
        utils.hold_execution()