In [1]:
# Imports as always.
import os
import re
import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import dataclasses
from dataclasses import dataclass

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Subset, DataLoader
from sklearn.model_selection import train_test_split

import torchvision
from torchvision import transforms

from PIL import Image

from hflayers import Hopfield, HopfieldLayer, HopfieldPooling

from datetime import datetime

from tqdm.notebook import tqdm

from data_handling import ISICDataset

# Ignore warnings.
import warnings
warnings.filterwarnings('ignore')

# Beautification.
sns.set_context('paper')
sns.set_style('darkgrid')

print(f'CUDA is available for use with PyTorch: {torch.cuda.is_available()}')

print(f'Installed Python version:  {sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}')
print(f'Installed PyTorch version: {torch.__version__}')

# Helper function to send a tensor/model/etc. to the CPU/GPU accordingly.
def to_device(x):
    if torch.cuda.is_available():
        return x.cuda()
    else:
        return x.cpu()
    
# Helper function for closing figures.
def close_figures():
    while len(plt.get_fignums()) > 0:
        plt.close()
        
# Get the current data and time as a string.
date_string = datetime.now().strftime('%Y-%m-%d-(%H-%M-%S)')

CUDA is available for use with PyTorch: True
Installed Python version:  3.8.18
Installed PyTorch version: 2.1.2+cu121


# Hopfield as Pooling

This notebook will look at `HopfieldPooling` as a direct substitute for `MaxPooling`. We'll be doing this for MNIST classification rather than ISIC segmentation; firstly to avoid any up-scaling complications (e.g. with unpooling), and secondly because fuck you.

### From the continuous Hopfield paper

The `HopfieldPooling` layer is designed for fixed pattern search, pooling operations, and memories like LSTMs or GRUs. The state (i.e. query) pattern is static, and may be learned during training.

If only one static state pattern (i.e. query) exists, then this is de facto a pooling over the sequence. This static state pattern is considered a "prototype pattern" and consequently learned in the Hopfield pooling layer. Note that the pooling always operates over the *token* dimension  (i.e. the sequence length), not the embedding dimension.

![Hopfield pooling diagram](./hopfield_pooling_diagram.png)

```
hopfield_pooling = HopfieldPooling(
    input_size=4,       # Y
    hidden_size=3,      # Q
    scaling=beta,
    quantity=2)         # No. state patterns

# Stored_pattern and pattern_projection
hopfield_pooling(Y)
```

## Data Handling

In [2]:
# Pre-define a couple of transform functions to and from tensors and images.
tensor_to_image = transforms.ToPILImage()
image_to_tensor = transforms.ToTensor()

In [3]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(0, 1)])

# Define the train dataset.
train_dataset = torchvision.datasets.MNIST(
    root='./data/MNIST',
    train=True,
    download=True,
    transform=transform
)

# Define the test dataset.
test_dataset = torchvision.datasets.MNIST(
    root='./data/MNIST',
    train=False,
    download=True,
    transform=transform
)

# Train-val split.
train_idx, val_idx = train_test_split(list(range(len(train_dataset))), test_size=.2)
train_subset = Subset(train_dataset, train_idx)
val_subset = Subset(train_dataset, val_idx)

# Package into data loaders.
batch_size = 16
train_dataloader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_subset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

## MNIST Classification Training Loop

In [4]:
def train(model, device, train_loader, optimizer):
    model.train()
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

## Simple CNN Classifier

In [5]:
class StandardCNN(nn.Module):
    def __init__(self):
        super(StandardCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

In [6]:
# Model, optimiser, and scheduler.
standard_cnn_model = StandardCNN().to('cuda')
optimiser = torch.optim.Adam(standard_cnn_model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimiser, step_size=1, gamma=.7)

# Training.
epochs = 5
for epoch_idx in tqdm(range(1, epochs + 1), desc='Training'):
    train(standard_cnn_model, 'cuda', train_dataloader, optimiser) 
    test(standard_cnn_model, 'cuda', val_dataloader)
    scheduler.step()

Training:   0%|          | 0/5 [00:00<?, ?it/s]

Average loss: 0.0808, Accuracy: 11712/12000 (98%)
Average loss: 0.0502, Accuracy: 11820/12000 (98%)
Average loss: 0.0434, Accuracy: 11851/12000 (99%)
Average loss: 0.0390, Accuracy: 11862/12000 (99%)
Average loss: 0.0406, Accuracy: 11866/12000 (99%)


In [7]:
# Save model.
torch.save(standard_cnn_model.state_dict(), f'./models/cnn/saves/{date_string}')

## CNN with Learnable Hopfield Pooling

In [8]:
def get_sinusoidal_encoding(n_tokens, token_length):
    def get_position_angle_vector(i):
        return [i / np.power(10000, 2 * (j // 2) / token_length) for j in range(token_length)]

    table = np.array([get_position_angle_vector(i) for i in range(n_tokens)])
    table[:, 0::2] = np.sin(table[:, 0::2])
    table[:, 1::2] = np.cos(table[:, 1::2])

    return torch.FloatTensor(table).unsqueeze(0)

In [9]:
# Sequence-embedding network.
class Embedding(nn.Module):
    def __init__(self, image_size, patch_size, channels, embed_dim):
        super().__init__()
        self.image_size = int(image_size)
        self.patch_size = int(patch_size)
        self.channels = int(channels)
        self.embed_dim = int(embed_dim)
        
        # Trainable linear projection for mapping dimension of patches.
        self.W_E = nn.Parameter(torch.randn(self.patch_size * self.patch_size * self.channels, self.embed_dim))
        
        # Fixed sinusoidal positional embedding.
        self.n_patches = self.image_size ** 2 // self.patch_size ** 2
        self.PE = get_sinusoidal_encoding(n_tokens=self.n_patches, token_length=self.embed_dim)
        
    def forward(self, x):
        # Patching.
        patches = x.unfold(1, self.channels , self.channels).unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
        patches = patches.contiguous().view(patches.size(0), -1, self.channels  * self.patch_size * self.patch_size).float()
        
        # Patch embeddings.
        patch_embeddings = torch.matmul(patches, self.W_E)
        
        # Position embeddings.
        embeddings = patch_embeddings + self.PE
        
        # Transpose so that each column represents a patch embedding.
        #embeddings = torch.transpose(embeddings, 1, 2)
        
        return embeddings
    
# Shape check.
channels, image_size, patch_size, embed_dim = 3, 64, 16, 768
embedding_layer = Embedding(image_size, patch_size, channels, embed_dim)
x = torch.randn(batch_size, channels, image_size, image_size)
y = embedding_layer(x)
print(f'Embedding layer: input shape {x.shape} -> output shape {y.shape}')

Embedding layer: input shape torch.Size([16, 3, 64, 64]) -> output shape torch.Size([16, 16, 768])


In [None]:
# Define the Hopfield pooling substitute for Max pooling.
class HopfieldImagePooling(nn.Module):
    def __init__(self, image_size, patch_size, channels, embed_dim):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.channels = channels
        self.embed_dim = embed_dim
        
        # Embedding layer.
        
        
    def forward(self, x):
        # Convolution (batch_size, channels, image_size, image_size) -> Embedding (batch_size, n_patches, embed_dim).
        