In [6]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from tqdm.auto import tqdm
import numpy as np
import numba as nb

In [3]:
# Define LSTM network
class LSTMNet(nn.Module):
    """Sequence classifier: a stacked LSTM followed by a linear read-out.

    The LSTM is run over the full input sequence; the top-layer output at the
    last time step is mapped to ``num_classes`` scores by ``fc``.

    Args:
        input_size: number of features per time step.
        hidden_size: width of the LSTM hidden state.
        num_layers: number of stacked LSTM layers.
        num_classes: number of scores produced by the final linear layer.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Compute class scores for a batch of sequences.

        Args:
            x: tensor of shape (batch, seq_len, input_size); the LSTM is
                configured with ``batch_first=True``.

        Returns:
            Tensor of shape (batch, num_classes) with unnormalised scores.
        """
        # When no state is passed, nn.LSTM starts from zero hidden/cell states
        # created with the input's dtype and device. This avoids allocating
        # float32 CPU tensors and copying them over on every forward pass, and
        # keeps the module usable after .double()/.half().
        out, _ = self.lstm(x)
        # Only the last time step feeds the classifier.
        return self.fc(out[:, -1, :])

In [4]:
# Define ZOrderTransform class
class ZOrderTransform:
    """Flatten an image into a 1-D sequence by walking a Z-order (Morton) curve.

    Morton code ``d`` is decoded into ``(x, y)`` by taking the even-position
    bits of ``d`` as the column ``x`` and the odd-position bits as the row
    ``y``. Codes are visited in increasing order over a ``2**n`` x ``2**n``
    grid and codes that fall outside the image are skipped, so an ``H`` x ``W``
    image that fits inside the grid yields exactly ``H * W`` values (784 for a
    28x28 image).

    Args:
        n: order of the curve (non-negative integer). The grid must satisfy
            ``2**n >= max(H, W)`` for every pixel to be visited. The bit masks
            used for decoding handle codes of up to 32 bits, i.e. ``n <= 16``.
    """

    def __init__(self, n):
        if n < 0:
            raise ValueError("n must be a non-negative integer, got {!r}".format(n))
        self.n = n  # n is the order of the curve

    def _deinterleave(self, x):
        """Pack the even-position bits of ``x`` into the low bits of the result.

        Works on Python ints and on NumPy integer arrays alike.
        """
        # Drop the odd-position bits first; without this mask they leak into
        # the result during the shift-and-merge steps below.
        x = x & 0x55555555
        x = (x | (x >> 1)) & 0x33333333
        x = (x | (x >> 2)) & 0x0F0F0F0F
        x = (x | (x >> 4)) & 0x00FF00FF
        x = (x | (x >> 8)) & 0x0000FFFF
        return x

    def _zorder_to_coordinates(self, d):
        """Decode Morton code(s) ``d`` into ``(x, y)`` = (column, row)."""
        return self._deinterleave(d), self._deinterleave(d >> 1)

    def __call__(self, image):
        """Return the pixels of ``image`` in Z-order as a float32 tensor.

        Args:
            image: anything ``np.asarray`` turns into an array with at least
                two dimensions; the first two axes are treated as
                (row, column).

        Returns:
            Float32 tensor whose first dimension enumerates the visited
            pixels; any trailing axes of the input are kept.

        Raises:
            ValueError: if the input has fewer than two dimensions.
        """
        image_array = np.asarray(image)
        if image_array.ndim < 2:
            raise ValueError(
                "expected an image with at least 2 dimensions, got shape {}".format(image_array.shape)
            )
        height, width = image_array.shape[:2]

        # Codes >= 4**k, with 2**k >= max(height, width), always decode to a
        # coordinate outside the image, so enumerating beyond that is wasted work.
        order = min(self.n, max(height, width, 1).__sub__(1).bit_length())
        codes = np.arange(4 ** order, dtype=np.int64)
        cols, rows = self._zorder_to_coordinates(codes)

        # Rows are bounded by the height and columns by the width, matching the
        # [row, col] indexing below.
        inside = (rows < height) & (cols < width)
        pixels = image_array[rows[inside], cols[inside]]
        return torch.from_numpy(pixels.astype(np.float32))


In [7]:
 # Parameters
input_size = 1  # each time step of the sequence is a single greyscale pixel value
hidden_size = 128
num_layers = 2
num_classes = 10
num_epochs = 2
batch_size = 64
learning_rate = 0.001
seed = 0

# Seed PyTorch so weight initialisation and DataLoader shuffling are repeatable
torch.manual_seed(seed)

# Load MNIST dataset
# MNIST images are 28x28; a 2**5 = 32 pixel grid is the smallest power-of-two
# square that covers them, so a higher curve order only adds wasted iterations.
# NOTE(review): no ToTensor/normalisation step is applied, so the model
# presumably sees raw 0-255 intensities - consider scaling the inputs.
transform = transforms.Compose([ZOrderTransform(n=5)])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Create the LSTM network
model = LSTMNet(input_size, hidden_size, num_layers, num_classes)

# Loss and optimizer
criterion = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=learning_rate)

# Train the model
for epoch in range(num_epochs):
    # Wrap the loader itself (not enumerate(...)) so tqdm can read len() and
    # show a real progress bar with a total.
    progress = tqdm(train_loader, desc='Epoch {}/{}'.format(epoch + 1, num_epochs))
    for i, (images, labels) in enumerate(progress):
        # (batch, seq_len) -> (batch, seq_len, input_size). The sequence length
        # is inferred from the batch instead of being hard-coded, so the loop
        # keeps working whatever length the transform emits.
        images = images.reshape(images.size(0), -1, input_size)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, len(train_loader), loss.item()))

0it [00:00, ?it/s]

KeyboardInterrupt: 