In [None]:
import torch
from HopfieldNetworkPyTorch.ModernHopfieldNetwork import ModernHopfieldNetwork, InteractionFunction
from torchvision import datasets
from torchvision.transforms import ToTensor

import matplotlib as mpl
from matplotlib.figure import Figure
from matplotlib.axes import Axes
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Pick the compute device once; every later cell moves tensors to `device`.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
    print("WARNING: This script can take a substantial amount of time without a GPU!")

In [None]:
# Shape (rows, columns) of a single image in this task
imageShape = (28, 28)

# Threshold of the Heaviside step used to binarise the images.
# With intensities normalised to [0, 1], pixels below this value become -1 and pixels above it become 1.
imageThreshold = 0.5

# Number of classes in the task: one per digit, 0 through 9
numClasses = 10

def displayStatesAsImage(statesTensor: torch.Tensor, numImages: int, fig_kw: "dict | None" = None) -> tuple[Figure, np.ndarray]:
    """
    Display the image neurons of the first numImages states on a square grid of subplots.

    statesTensor is expected to have shape (imageShape[0]*imageShape[1] + numClasses, N), one state
    per column. Only the first imageShape[0]*imageShape[1] rows (the image neurons) of each
    displayed column are used; they are reshaped to imageShape and drawn with imshow.

    Parameters
    ----------
    statesTensor : torch.Tensor
        States to display, one per column.
    numImages : int
        Number of leading columns to display. Clamped to the number of columns available.
    fig_kw : dict, optional
        Extra keyword arguments forwarded to plt.subplots (e.g. {"figsize": (12, 12)}).

    Returns
    -------
    (fig, axes)
        The created figure and a 2D array of its axes (always an array, even for a 1x1 grid,
        unless the caller overrides `squeeze` through fig_kw).

    Raises
    ------
    ValueError
        If there is not at least one image to display.
    """

    # Never index past the last available state
    numImages = min(numImages, statesTensor.shape[1])
    if numImages < 1:
        raise ValueError("displayStatesAsImage needs at least one image to display")

    # Smallest square grid that fits numImages subplots
    numSubplot = int(np.ceil(np.sqrt(numImages)))

    # squeeze=False keeps `axes` a 2D array for every grid size; caller-supplied keys take precedence
    subplotKwargs = {"squeeze": False, **(fig_kw or {})}
    fig, axes = plt.subplots(numSubplot, numSubplot, **subplotKwargs)

    flatAxes = np.ravel(axes)
    for ax in flatAxes:
        ax.axis("off")

    # Move only the pixels that are actually displayed to the CPU, in a single transfer
    numPixels = imageShape[0] * imageShape[1]
    images = statesTensor[:numPixels, :numImages].detach().to("cpu").numpy()

    for itemIndex, ax in zip(range(numImages), flatAxes):
        ax.imshow(images[:, itemIndex].reshape(imageShape))

    return fig, axes

def getClassesFromStateTensor(stateTensor: torch.Tensor):
    """
    Return the class label of every state (column) in stateTensor.

    The class neurons are the final numClasses rows of the tensor; the label of a state is the
    index, within those rows, of its largest class neuron.
    """

    classNeurons = stateTensor[-numClasses:, :]
    return classNeurons.argmax(dim=0)

# Load the MNIST dataset

In [None]:
def buildStateTensor(dataset) -> torch.Tensor:
    """
    Convert an MNIST dataset into a tensor of network states, one state per column.

    Each column is the binarised, flattened image (the "image neurons") followed by numClasses
    "class neurons" holding a +1/-1 one-hot encoding of the label.

    Parameters
    ----------
    dataset
        Object exposing `.data` (images, one per leading index) and `.targets` (integer labels),
        as the torchvision MNIST dataset does.

    Returns
    -------
    torch.Tensor
        Floating point tensor with one row per neuron and one column per item.
    """

    # Flatten every image and put items along the columns; dividing by 255 normalises
    # the intensities to [0, 1] (assumes 8-bit pixel values)
    pixels = torch.flatten(dataset.data, start_dim=1).T / 255.0

    # Heaviside step in a single pass: below the threshold -> -1, otherwise -> +1.
    # Unlike two separate masked assignments, a pixel exactly at the threshold cannot be left unmapped.
    ones = torch.ones_like(pixels)
    imageNeurons = torch.where(pixels < imageThreshold, -ones, ones)

    # Start every class neuron at -1, then put a +1 at the index of the true class (one hot encoding).
    # Created as floats so the concatenation below does not depend on implicit type promotion.
    classNeurons = torch.full((dataset.targets.shape[0], numClasses), -1.0)
    classNeurons = classNeurons.scatter(1, dataset.targets.view(-1, 1), 1.0).T

    # Finally, stack the image neurons on top of the class neurons
    return torch.cat((imageNeurons, classNeurons))


trainingData = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)
trainingStates = buildStateTensor(trainingData)
del trainingData    # Only the state tensor is needed from here on

testingData = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)
testingStates = buildStateTensor(testingData)
del testingData     # Only the state tensor is needed from here on

fig, _ = displayStatesAsImage(trainingStates, 16, fig_kw={"figsize": (12,12)})
fig.suptitle("Training States")
plt.show()

# Define Network Parameters

In [None]:
# Dimension of our network will be the flattened image shape (28*28) plus a number of neurons equal to the number of classes (10)
dimension = (imageShape[0] * imageShape[1]) + numClasses

# Since we will not update the "image neurons" we mask only the "class neurons".
# The class neurons are the final numClasses entries of a state, so the mask is derived from
# numClasses rather than a literal 10 and stays correct if the number of classes changes.
neuronMask = torch.arange(dimension - numClasses, dimension)

# We must also set the various ModernHopfieldNetwork parameters

# Argument of the rectified polynomial interaction function, and the number of memory vectors
interactionVertex = 15
numMemories = 100

# Optimisation settings forwarded to learnMemories
initialLearningRate = 1e-3
learningRateDecay = 0.999
momentum = 0.6
initialTemperature = 500
finalTemperature = 500
errorPower = 1

# Training length and the batch sizes handed to the network constructor
maximumTrainingEpochs = 1000
itemBatchSize = 1024
neuronBatchSize = 10


# Create and Train the Network

In [None]:
# Build the network from the parameters defined above
network = ModernHopfieldNetwork(
    dimension=dimension,
    nMemories=numMemories,
    interactionFunction=InteractionFunction.RectifiedPolynomialInteractionFunction(interactionVertex),
    torchDevice=device,
    itemBatchSize=itemBatchSize,
    neuronBatchSize=neuronBatchSize,
)

# Draw the initial memories from a normal distribution centred on 0 ...
newMemories = torch.normal(mean=torch.zeros_like(network.memories))

# ... then redraw the class neurons (the final numClasses rows) centred on -0.3,
# biasing the network's class neurons towards negative values
classNeuronMeans = torch.full_like(network.memories[-numClasses:, :], -0.3)
newMemories[-numClasses:, :] = torch.normal(mean=classNeuronMeans)

# Keep every entry inside [-1, 1] before handing the memories to the network
newMemories.clamp_(-1, 1)
network.setMemories(newMemories)

In [None]:
# Learn the memory vectors, using only the first 2**14 training items (items are columns)
trainingStates = trainingStates[:, :2**14]

# Only the neurons listed in neuronMask (the class neurons) are passed as the mask.
# NOTE(review): `precision` is presumably a convergence tolerance and `verbose` a logging level;
# confirm against the learnMemories documentation.
history = network.learnMemories(
    trainingStates.to(device),
    neuronMask=neuronMask,
    maxEpochs=maximumTrainingEpochs,
    initialLearningRate=initialLearningRate,
    learningRateDecay=learningRateDecay,
    momentum=momentum,
    initialTemperature=initialTemperature,
    finalTemperature=finalTemperature,
    errorPower=errorPower,
    precision=1e-30,
    verbose=2,
)

# Plot one loss value per entry of the returned history
plt.plot(np.arange(len(history)), history)
plt.gca().set(title="Loss History", xlabel="Epoch", ylabel="Loss")
plt.show()

# Relax and Predict the Testing Data

In [None]:
# Work on a copy so the true labels stored in testingStates are left untouched
relaxedTestingStates = testingStates.clone().to(device)

# Erase the label information: every class neuron starts at -1
relaxedTestingStates[-numClasses:, :] = -1

# Step only the class neurons, with an identity activation so their raw values are kept
network.stepStates(
    relaxedTestingStates,
    neuronMask=neuronMask,
    activationFunction=lambda state: state,
)

# Bring the result back to the CPU for the analysis below
relaxedTestingStates = relaxedTestingStates.to("cpu")

In [None]:
# Apply one random permutation to both tensors so true and relaxed states stay aligned
shuffleIndices = torch.randperm(testingStates.shape[1])
testingStates, relaxedTestingStates = (
    testingStates[:, shuffleIndices],
    relaxedTestingStates[:, shuffleIndices],
)

# True labels come from the untouched class neurons, predictions from the relaxed ones
trueTestingClassLabel = getClassesFromStateTensor(testingStates)
predictedClassLabel = getClassesFromStateTensor(relaxedTestingStates)

# Softmax over the class neurons of each item; not used below, kept available for inspection
predictedClassSoftmax = torch.softmax(relaxedTestingStates[-numClasses:, :], dim=0)

numItems = 16
print(f"True Class Labels:\t\t{trueTestingClassLabel[:numItems]}")
print(f"Predicted Class Labels:\t\t{predictedClassLabel[:numItems]}")

# Show the first numItems test images, each titled with its true and predicted label
fig, axes = displayStatesAsImage(testingStates, numItems, fig_kw={"figsize": (12,12)})
labelledAxes = zip(trueTestingClassLabel, predictedClassLabel, np.ravel(axes))
for trueLabel, predLabel, ax in labelledAxes:
    ax.set_title(f"True: {trueLabel}\nPred: {predLabel}")
plt.tight_layout()
plt.show()

# Display the Network Memories

In [None]:
# Put the learned memories in a random order before displaying them
shuffleIndices = torch.randperm(numMemories)
memories = network.memories[:, shuffleIndices]

# First a 16-memory sample, then every memory
for memoryCount in (16, numMemories):
    displayStatesAsImage(memories, numImages=memoryCount, fig_kw={"figsize": (12,12)})
    plt.show()