# **Hopfield Network with SKImage Data**

「scikit-image（skimage）の画像をネットワークに記憶させる → その画像にノイズを加える → ネットワークで元の画像を復元する」という、ホップフィールドネットワークの一連の動作を試すデモコードです。ランタイムはCPUで問題ありません。

In [None]:
import numpy as np
np.random.seed(1)
from matplotlib import pyplot as plt
import skimage.data
from skimage.color import rgb2gray
from skimage.filters import threshold_mean
from skimage.transform import resize
import matplotlib.cm as cm
from tqdm import tqdm

class HopfieldNetwork(object):
    """Binary (+1/-1) Hopfield network trained with a mean-centred Hebbian rule.

    Typical use: call ``train_weights`` with flattened patterns, then
    ``predict`` with (possibly corrupted) patterns to recall stored ones.
    """

    def train_weights(self, train_data):
        """Learn the weight matrix from the stored patterns.

        Args:
            train_data: Non-empty sequence of equal-length 1-D patterns
                (the demo passes +1/-1 integer arrays).

        Sets ``self.num_neuron`` (pattern length N) and ``self.W`` (an N x N
        symmetric matrix with a zero diagonal).

        Raises:
            ValueError: If ``train_data`` is empty or is not a sequence of
                equal-length 1-D patterns.
        """
        print("Start to train weights...")
        patterns = np.asarray(train_data, dtype=float)
        if patterns.ndim != 2 or patterns.shape[0] == 0:
            raise ValueError(
                "train_data must be a non-empty sequence of equal-length 1-D patterns")
        num_data, self.num_neuron = patterns.shape

        # Mean activity over every neuron of every pattern; patterns are
        # centred on it before the Hebbian update.
        rho = patterns.mean()
        centered = patterns - rho

        # Hebb rule: sum over patterns of outer(t, t), computed as a single
        # matrix product instead of allocating an N x N temporary per pattern.
        W = centered.T @ centered

        # No self-connections. Clear the diagonal in place rather than building
        # (and subtracting) another N x N diagonal matrix.
        np.fill_diagonal(W, 0)
        W /= num_data

        self.W = W

    def predict(self, data, num_iter=20, threshold=0, asyn=False, async_updates=100):
        """Run the network dynamics on each pattern in ``data``.

        Args:
            data: Sequence of 1-D patterns of length ``self.num_neuron``.
            num_iter: Maximum number of update iterations per pattern.
            threshold: Threshold subtracted from each neuron's input.
            asyn: If True, use random single-neuron (asynchronous) updates;
                otherwise update all neurons at once (synchronous).
            async_updates: Number of random single-neuron updates performed
                per iteration when ``asyn`` is True.

        Returns:
            List with one final state array per input pattern. The input
            patterns themselves are never modified.
        """
        print("Start to predict...")
        self.num_iter = num_iter
        self.threshold = threshold
        self.asyn = asyn
        self.async_updates = async_updates

        # _run works on its own copy of each pattern, so no defensive copy of
        # the whole dataset is needed here.
        return [self._run(pattern) for pattern in tqdm(data)]

    def _run(self, init_s):
        """Iterate the dynamics from ``init_s`` and return the final state.

        Runs at most ``self.num_iter`` iterations and stops early once an
        iteration leaves the state unchanged. In asynchronous mode this only
        means none of the sampled neurons changed during that sweep.
        """
        s = np.array(init_s, copy=True)
        for _ in range(self.num_iter):
            previous = s.copy()
            if self.asyn:
                self._async_sweep(s)
            else:
                s = self._sync_update(s)
            # Compare states, not energies: equal energy does not imply a fixed
            # point (with threshold 0, s and -s have the same energy, so a
            # synchronous s <-> -s oscillation would look "converged").
            if np.array_equal(s, previous):
                break
        return s

    def _sync_update(self, s):
        """Return the state after updating every neuron at once from ``s``."""
        field = self.W @ s - self.threshold
        # np.sign maps a zero field to 0, which is not a valid +1/-1 state;
        # a neuron with zero net input keeps its previous value instead.
        return np.where(field == 0, s, np.sign(field))

    def _async_sweep(self, s):
        """Update ``self.async_updates`` randomly chosen neurons of ``s`` in place, one at a time."""
        for _ in range(self.async_updates):
            idx = np.random.randint(0, self.num_neuron)
            field = self.W[idx] @ s - self.threshold
            # Same tie rule as the synchronous update: zero input keeps the value.
            if field != 0:
                s[idx] = np.sign(field)

    def energy(self, s):
        """Return the energy E(s) = -1/2 * s^T W s + sum(s * threshold).

        Uses ``self.threshold``, which is set by ``predict``.
        """
        return -0.5 * s @ self.W @ s + np.sum(s * self.threshold)

    def plot_weights(self):
        """Save a heat map of the weight matrix to ``weights.png``.

        Requires ``train_weights`` to have been called. The figure is saved
        but neither shown nor closed here.
        """
        plt.figure(figsize=(6, 5))
        w_mat = plt.imshow(self.W, cmap=cm.coolwarm)
        plt.colorbar(w_mat)
        plt.title("Network Weights")
        plt.tight_layout()
        print("Saving network weights matrix...")
        plt.savefig("weights.png")
        print("Weights plot saved as weights.png.")

def get_corrupted_input(input, corruption_level):
    """Return a copy of ``input`` with each element's sign flipped at random.

    Args:
        input: Pattern to corrupt (the demo passes 1-D +1/-1 arrays). It is
            not modified.
        corruption_level: Probability in [0, 1] that any given element is
            negated.

    Returns:
        New array of the same dtype where each element along the first axis
        was multiplied by -1 with probability ``corruption_level``.
    """
    corrupted = np.copy(input)
    # One Bernoulli draw per element from the global NumPy RNG, so results
    # stay reproducible under np.random.seed.
    flip = np.random.binomial(n=1, p=corruption_level, size=len(input)).astype(bool)
    # Vectorized negation instead of a per-element Python loop.
    corrupted[flip] *= -1
    return corrupted

def reshape(data):
    """Fold a flattened square image back into a 2-D array.

    The side length is ``int(np.sqrt(len(data)))``, so ``data`` must have a
    perfect-square length for the reshape to succeed.
    """
    n_pixels = len(data)
    side = int(np.sqrt(n_pixels))
    square_shape = (side, side)
    return np.reshape(data, square_shape)

def plot(data, test, predicted, figsize=(5, 6)):
    """Save a grid comparing training, corrupted and recalled patterns.

    Each row shows one pattern: the training image, the corrupted input fed to
    the network, and the network's output. Every flattened pattern is folded
    back into a square image with ``reshape`` before display. The figure is
    written to ``result.png``.

    Args:
        data: Flattened training patterns.
        test: Flattened corrupted inputs, one per training pattern.
        predicted: Flattened network outputs, one per training pattern.
        figsize: Matplotlib figure size in inches.

    Raises:
        ValueError: If the three sequences differ in length.
    """
    if not (len(data) == len(test) == len(predicted)):
        raise ValueError("data, test and predicted must have the same length")
    rows = list(zip(data, test, predicted))
    titles = ('Train data', "Input data", 'Output data')

    # squeeze=False keeps axarr 2-D even for a single pattern; otherwise
    # plt.subplots returns a 1-D array and axarr[i, j] indexing fails.
    _, axarr = plt.subplots(len(rows), 3, figsize=figsize, squeeze=False)
    for ax, title in zip(axarr[0], titles):
        ax.set_title(title)
    for ax_row, images in zip(axarr, rows):
        for ax, image in zip(ax_row, images):
            ax.imshow(reshape(image), cmap='gray')
            ax.axis('off')

    plt.tight_layout()

    print("Saving prediction results...")
    plt.savefig("result.png")
    print("Prediction results saved as result.png.")

def preprocessing(img, w=128, h=128):
    """Resize an image, binarize it around its mean, and flatten it to +1/-1.

    Args:
        img: 2-D image array (the demo passes grayscale images; RGB inputs
            must be converted first, e.g. with ``rgb2gray``).
        w: Target width in pixels.
        h: Target height in pixels.

    Returns:
        1-D integer array of length ``w * h`` in row-major order. Entries are
        +1 where the resized pixel is above ``threshold_mean`` and -1 otherwise.
    """
    # skimage's output_shape is (rows, cols), i.e. (height, width).
    img = resize(img, (h, w), mode='reflect')

    # Thresholding: pixels strictly above the threshold become +1, the rest -1.
    thresh = threshold_mean(img)
    spins = np.where(img > thresh, 1, -1)

    # Flatten to one neuron per pixel.
    return spins.reshape(w * h)

def main(noise_level=0.4, asyn=False):
    """Run the demo: memorize images, corrupt them, then recall them.

    Args:
        noise_level: Probability (0 to 1) that each pixel of a test pattern
            is sign-flipped before recall.
        asyn: Passed to ``HopfieldNetwork.predict``; True selects
            asynchronous updates, False synchronous ones.

    Side effects: prints progress and saves ``result.png``.
    """
    # Load data (astronaut is RGB, so convert it to grayscale first)
    camera = skimage.data.camera()
    astronaut = rgb2gray(skimage.data.astronaut())
    horse = skimage.data.horse()

    # Merge data. More images can be appended here,
    # e.g. rgb2gray(skimage.data.coffee()).
    data = [camera, astronaut, horse]

    # Preprocessing
    print("Start to data preprocessing...")
    data = [preprocessing(d) for d in data]

    # Create Hopfield Network Model
    model = HopfieldNetwork()
    model.train_weights(data)

    # Generate test set by flipping each pixel with probability noise_level
    test = [get_corrupted_input(d, noise_level) for d in data]

    predicted = model.predict(test, threshold=0, asyn=asyn)

    # Saving the weights plot takes time and a lot of memory (the weight
    # matrix is N x N with N = 128*128 neurons by default).
    # If you have enough memory, uncomment the following line:
    # model.plot_weights()

    # Save prediction results
    plot(data, test, predicted)


# In a notebook, __name__ is also "__main__", so the demo still runs there.
if __name__ == "__main__":
    main()


Thanks to

*   https://github.com/shutakahama/Hopfield_Network
*   https://github.com/ethan-iai/hopfield-torch
*   https://github.com/takyamamoto/Hopfield-Network
*   https://github.com/kencyke/hopfield-mnist
*   https://www.kaggle.com/code/mrumuly/hopfield-network-mnist
*   https://github.com/grinvolod13/mnist-hopfield