WIP

## MNIST Classification / Generation using AUNN

## Setup & Definition

In [13]:
import torch
import random
import time
import torch.optim as optim
import torch.nn as nn
import numpy as np
from tqdm.notebook import tqdm
from matplotlib import pyplot as plt

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

#deallocate all cuda memory
torch.cuda.empty_cache()

#print cuda memory
print(torch.cuda.memory_allocated())

cuda
0


In [14]:
class AUNNModel(nn.Module):
    def __init__(
        self, 
        embedding_dim:int,
        output_dim:int, 
        num_layers:int, 
        hidden_dim:int):        

        assert num_layers >= 2, "Number of layers must be at least 2"

        super(AUNNModel, self).__init__() 
    
        self.embedding_dim = embedding_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

        # Input Layer
        self.input_layer =  nn.Linear(self.embedding_dim, self.hidden_dim)

        # Hidden Layers
        self.layers = nn.ModuleList()
        for _ in range(self.num_layers - 2):  # Exclude input and output layers
            self.layers.append(nn.Sequential(
                nn.Linear(self.hidden_dim, self.hidden_dim),
                nn.SiLU(),
                nn.RMSNorm(self.hidden_dim)
            ))

        # Output Layer
        self.output_layer = nn.Linear(self.hidden_dim, self.output_dim)

        # Initialize weights
        self._initialize_weights()

    def _initialize_weights(self):

        for m in self.modules():
            if isinstance(m, nn.Linear):
                # Kaiming He initialization for Swish activation
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def count_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    # binary
    def encode(self, x: torch.Tensor):
        dim = self.embedding_dim
        encoding = ((x.unsqueeze(1) >> torch.arange(dim, device=x.device)) & 1).to(torch.float32)
        return encoding

    def forward(self, indices):
        
        x = self.encode(indices)
        x = self.input_layer(x)
        x = x + nn.SiLU()(x)

        for layer in self.layers:
            x = x + layer(x)  # MLP output with skip connection

        x = self.output_layer(x)
        return x

In [15]:
# Define a function to save the model checkpoint
def save_checkpoint(model, params, optimizer, losses, filename="checkpoint.pth"):
    
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'losses': losses,
        'params':{}
    }

    keys = ['embedding_dim', 'output_dim', 'num_layers', 'hidden_dim']
    assert all(k in params for k in keys)
    for k in keys:
        checkpoint['params'][k] = params[k]

    torch.save(checkpoint, filename)
    print(f"Checkpoint saved with loss {losses[-1]:.4f}")

In [16]:
def load_checkpoint(filename="checkpoint.pth"):

    checkpoint = torch.load(filename, weights_only=True)
    
    params = checkpoint['params']
    model = AUNNModel(**params)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    
    optimizer = optim.AdamW(model.parameters())
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
    losses = checkpoint['losses']

    print(f"Checkpoint loaded: loss {losses[-1]:.4f}")

    return model, optimizer, losses

## Prepare MNIST

In [31]:
import struct
from array import array

def load_mnist(images_path, labels_path, shuffle:bool=False, binarize:bool=True, seed=42):

    labels = []
    with open(labels_path, 'rb') as file:
        magic, size = struct.unpack(">II", file.read(8)) 
        if magic != 2049:
            raise ValueError('Magic number mismatch, expected 2049, got {}'.format(magic))
        labels = array("B", file.read())

    images = []
    rows, cols = None, None
    with open(images_path, 'rb') as file:
        magic, size, rows, cols = struct.unpack(">IIII", file.read(16))
        if magic != 2051:
            raise ValueError('Magic number mismatch, expected 2051, got {}'.format(magic))
        data = array("B", file.read())
        for i in range(size):
            img = np.array(data[i * rows * cols:(i + 1) * rows * cols], dtype=np.uint8)
            if binarize:
                img = np.where(img > 0, 1, 0) 
            img.resize((rows, cols))
            images.append(img)

    assert len(images) == len(labels)

    if shuffle:
        random.seed(seed)
        indices = list(range(len(images)))
        random.shuffle(indices)
        images = [images[i] for i in indices]
        labels = [labels[i] for i in indices]

    return images, labels


def encode_patches(img, N):
    """
    Encodes a (28x28) binary image into patches of size NxN.
    Each NxN patch (N^2 bits) is converted into a single integer 
    in [0..(2^(N^2)-1)].
    
    Returns a flattened 1D array of length (28//N)*(28//N).
    """
    # 1. Reshape: (dim//N, N, dim//N, N)
    # 2. Transpose to group NxN patches -> (28//N, 28//N, N, N)
    #    so each (i, j) index references one NxN patch
    h = w = img.shape[0]
    assert img.shape[0] == img.shape[1], "Image must be square."
    assert h % N == 0 and w % N == 0, "Image dimensions must be divisible by N."
    arr_NxN = img.reshape(h//N, N, w//N, N).transpose(0, 2, 1, 3)
    # shape: (h//N, w//N, N, N)

    # 3. Flatten NxN -> N^2 bits: shape becomes (h//N, w//N, N^2)
    arr_flat = arr_NxN.reshape(h//N, w//N, N*N)

    # 4. Multiply each bit by powers of 2 to combine into single integers
    #    For example, if N=2, bitweights = [1, 2, 4, 8].
    bitweights = np.array([1 << i for i in range(N*N)], dtype=arr_flat.dtype)
    encoded = (arr_flat * bitweights).sum(axis=-1)  # shape (h//N, w//N)

    # 5. Flatten to 1D array
    return encoded.ravel()


def decode_patches(encoded, N, dim=28):
    """
    Decodes a flattened array of integers (each representing N^2 bits)
    back to a (28x28) binary image with NxN patches.
    """
    h = w = dim
    num_patches = (h // N) * (w // N)
    
    # 1. Reshape from (num_patches,) -> (h//N, w//N)
    arr_2d = encoded.reshape(h//N, w//N)

    # 2. Extract bits for each integer using bitwise AND with [1, 2, 4, ..., 2^(N^2-1)]
    bitweights = np.array([1 << i for i in range(N*N)], dtype=np.uint16)
    # shape -> (h//N, w//N, N^2)
    bits = ((arr_2d[..., None] & bitweights) > 0).astype(np.uint16)

    # 3. Reshape bits back to NxN patches: (h//N, w//N, N, N)
    arr_patches = bits.reshape(h//N, w//N, N, N)

    # 4. Transpose to (h//N, N, w//N, N) then reshape to (h, w)
    arr_patches = arr_patches.transpose(0, 2, 1, 3)  # shape -> (h//N, N, w//N, N)
    decoded = arr_patches.reshape(h, w)

    return decoded


def make_string(images, labels):
    
    targets = []
    for img, label in zip(images, labels):

        #prep img data, conver 0-255 to 0-1 float
        img = img.flatten().astype(np.float32) / 255.0
        img = img.reshape(-1, 1) #784x1
        print(img.shape)
        #add 10 columns of just zeros, new shape = 784x11
        img = np.concatenate([img, np.zeros((len(img), 10), dtype=np.float32)], axis=1)
        display(img)
        display(img[:5]) #display first 5 rows
        print(img.shape)

        #prep label data - create one-hot encoding
        label_data = np.zeros(10+1, dtype=np.float32)
        label_data[label+1] = 1.0
        num_repeats = len(img) // 8
        label_data = np.tile(label_data, num_repeats).reshape(num_repeats, -1)
        display(label_data)
        display(label_data[:5]) #display first 5 rows
        break

        #add to target
        targets.append(label)
        targets.append(img)
        targets.append(label)

        #define lengths
        img_len = len(img)
        lbl_len = len(label)
        ex_len = img_len + lbl_len * 2

    targets = np.concatenate(targets, axis=0)
    return (ex_len, img_len, lbl_len), targets


In [12]:
x = np.zeros(10)
x[1] = 1
print(x)
x = np.tile(x, 3).reshape(3, -1)
print(x)

[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
[[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]]


In [32]:
from pathlib import Path

cur_dir = Path().resolve()
input_path = cur_dir / 'mnist'
training_images_filepath = input_path / 'train-images.idx3-ubyte'
training_labels_filepath = input_path /'train-labels.idx1-ubyte'
test_images_filepath = input_path / 't10k-images.idx3-ubyte'
test_labels_filepath = input_path / 't10k-labels.idx1-ubyte'

train_images, train_labels = load_mnist(training_images_filepath, training_labels_filepath, shuffle=True, binarize=False)
lengths, train_data  = make_string(train_images, train_labels)
print(f"{len(train_data):,} training samples")

(784, 1)


array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)

(784, 11)


array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

array([[0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]], dtype=float32)

ValueError: need at least one array to concatenate

In [None]:
img = train_images[999]
img = np.where(img > 0.1, 1, 0)  #binarize
plt.imshow(img, cmap='gray')
plt.show()

enc = encode_patches(img, 1)

img = decode_patches(enc, 1)
plt.imshow(img, cmap='gray')
plt.show()

In [None]:
ex_len, img_len, lbl_len = lengths
print(f"Example length: {ex_len}, Image length: {img_len}, Label length: {lbl_len}")

In [None]:
ex_num = 100
offset = ex_num * ex_len

#get label bytes
lbl_1 = train_data[offset:offset+lbl_len]
offset += lbl_len
print(lbl_1)

#get image bytes
img = train_data[offset:offset+img_len]
print(img.reshape(28,28))
offset += img_len
img = decode_patches(img, 1)
plt.imshow(img, cmap='gray')
plt.show()

#get label bytes
lbl_2 = train_data[offset:offset+lbl_len]
print(lbl_2)

## Train

In [None]:
train_data = torch.tensor(train_data, dtype=torch.long).to(device)
print(f"train_data: {train_data.shape}")

In [None]:
# Hyperparameters

embedd_dim = 32
num_layers = 12    # Must be even and at least 2 (bc of skip connections)
hidden_dim = 768   # Size of hidden layers
output_dim = 2**4  # bc N=2 patches

# Initialize the model

model = AUNNModel(
    embedding_dim=embedd_dim,
    output_dim=output_dim,
    num_layers=num_layers, 
    hidden_dim=hidden_dim).to(device)
print(f"Model has {model.count_params():,} parameters")

In [12]:
# Initialize, loss function, and optimizer

criterion = nn.CrossEntropyLoss()
lr = 0.001
optimizer = optim.AdamW(model.parameters(), lr=lr)

In [13]:
losses = []
accuracies = []

In [None]:
from tqdm.auto import tqdm

model.train()
context_len = 4096
history = []
base_len = 1024

for i in tqdm(range(len(train_data) - context_len), total=len(train_data) - context_len):
        
    start = i
    end = i + context_len + 1
    targets = train_data[start:end]
    abs_indices = torch.arange(start, end)
    data_indices = abs_indices % ex_len
    assert base_len >= ex_len
    ex_indices = abs_indices // ex_len
    ex_indices = ex_indices * base_len
    indices = ex_indices + data_indices
    indices = indices.to(device)
        
    j = 0
    while True:

        outputs = model(indices)
        loss = criterion(outputs, targets)
        loss_val = loss.item()

        predicted = outputs.argmax(dim=1)
        a = targets[-1].item()
        b = predicted[-1].item()

        if j == 0:
            cur_ind = indices[-1].item()
            rel_ind = cur_ind % base_len

            if rel_ind == 0:
                display(f'----- EXAMPLE {cur_ind // base_len} START -----')
                display('LABEL START')

            elif rel_ind == lbl_len:
                display(f"LABEL END")
                display('IMAGE START')
                # see what the model would-have predicted for img at this point
                img_inds = torch.arange(cur_ind, cur_ind + img_len).to(device)
                img_outputs = model(img_inds)
                img_predicted = img_outputs.argmax(dim=1)
                img_predicted = img_predicted.cpu().numpy()
                img_data = decode_patches(img_predicted, 1)
                plt.imshow(img_data, cmap='gray')
                plt.show()

            elif rel_ind == lbl_len + img_len:
                display('IMAGE END')
                if len(history) >= img_len:
                    img_data = history[-img_len:]
                    img_data = np.array(img_data, dtype=np.uint16)
                    img_data = decode_patches(img_data, 1)
                    plt.imshow(img_data, cmap='gray')
                    plt.show()
                display('LABEL START')

            elif rel_ind == lbl_len + img_len + lbl_len:
                display('LABEL END')
            losses.append(loss_val)
            accuracies.append(a == b)
            history.append(b)
            print(a,b,'T' if a == b else 'F', f"{loss_val:f} @{cur_ind}" )
        j += 1

        loss_thresh = 0.0001
        if a == b and loss_val < loss_thresh:
            break

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()


In [None]:
save_checkpoint(model, {
    'embedding_dim': embedd_dim,
    'output_dim': output_dim,
    'num_layers': num_layers,
    'hidden_dim': hidden_dim
}, optimizer, losses, filename="mnist/checkpoint.pth")

In [None]:
# Plot the loss curve
downsample = 1
window = 100

temp = []
for epoch in losses[num_batches*2::downsample]:
    avg = np.mean(epoch)
    temp.append(avg)
temp = np.convolve(temp, np.ones(window)/window, mode='valid')

plt.figure(figsize=(20, 6))
plt.plot(temp, label="Training Loss")
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.title("Training Loss")
plt.legend()
plt.show()

# Plot the accuracy curve
temp = []
for epoch in accuracies[num_batches*2::downsample]:
    avg = np.mean(epoch)
    temp.append(avg)
temp = np.convolve(temp, np.ones(window)/window, mode='valid')

plt.figure(figsize=(20, 6))
plt.plot(temp, label="Training Accuracy")
plt.xlabel("Steps")
plt.ylabel("Accuracy")
plt.title("Training Accuracy")
plt.legend()
plt.show()

In [None]:
entry_num = num_ex - 100
offset = entry_num * ex_len
data = train_data[offset:offset+ex_len].cpu().numpy()

lbl_1 = data[0:lbl_len]
img = data[lbl_len:lbl_len+img_len]
lbl_2 = data[lbl_len+img_len:]

print(lbl_1)
img = decode_patches(img, 2)
plt.imshow(img, cmap='gray')
plt.show()
print(lbl_2)

In [None]:
entry_num = num_ex - 100
offset = entry_num * ex_len
abs_indices = torch.arange(offset, offset+ex_len, device=device)
data_indices = abs_indices % ex_len
ex_indices = abs_indices // ex_len
ex_indices = ex_indices * 256
indices = ex_indices + data_indices
with torch.no_grad():
    outputs = model(indices)
    outputs = torch.argmax(outputs, dim=1)

outputs = outputs.cpu().numpy()
lbl_1 = outputs[0:lbl_len]
img = outputs[lbl_len:lbl_len+img_len]
lbl_2 = outputs[lbl_len+img_len:]

print(lbl_1)
img = decode_patches(img, 2)
plt.imshow(img, cmap='gray')
plt.show()
print(lbl_2)

In [None]:
for i in range(100):
    entry_num = num_ex+100+i #unseen data
    offset = entry_num * ex_len
    abs_indices = torch.arange(offset, offset+ex_len, device=device)
    data_indices = abs_indices % ex_len
    ex_indices = abs_indices // ex_len
    ex_indices = ex_indices * 256
    indices = ex_indices + data_indices
    with torch.no_grad():
        outputs = model(indices)
        outputs = torch.argmax(outputs, dim=1)

    outputs = outputs.cpu().numpy()
    lbl_1 = outputs[0:lbl_len]
    img = outputs[lbl_len:lbl_len+img_len]
    lbl_2 = outputs[lbl_len+img_len:]

    print(lbl_1)
    img = decode_patches(img, 2)
    plt.imshow(img, cmap='gray')
    plt.show()
    print(lbl_2)
    print('-'*100)