## MNIST Classification / Generation using AUNN

### Setup & Definition

In [None]:
import torch
import random
import torch.optim as optim
import torch.nn as nn
import numpy as np
from tqdm.notebook import tqdm
from matplotlib import pyplot as plt

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

# Release cached-but-unused blocks held by PyTorch's CUDA allocator
# (memory still referenced by live tensors is not freed by this call).
torch.cuda.empty_cache()

# Show how many bytes are currently occupied by tensors on the CUDA device.
print(torch.cuda.memory_allocated())

In [None]:
# def positional_encoding(x: torch.Tensor, dim: int = 96):
#     """Apply fourier positional encoding to the input.

#     Args:
#       x (torch.Tensor): a 1 dimension tensor of indices
#       dim (optional, int): dimension of positional encoding. max index representable is 2^(dim//2+1). Default: 96.
#     Returns:
#       (torch.Tensor): Positional encoding of the input tensor. dimension: [x.size(0), dim]
#     """
#     position = x.unsqueeze(1)
#     device = x.device  # Get the device of x

#     # Create div_term on the same device as x
#     base = 10_000_000.0
#     div_term = torch.exp(
#         torch.arange(0, dim, 2, dtype=torch.float32, device=device) *\
#             (-np.log(base) / dim)
#     )

#     # Create pe on the same device as x
#     pe = torch.zeros(x.size(0), dim, device=device)

#     # Perform computations
#     pe[:, 0::2] = torch.sin(position * div_term)
#     pe[:, 1::2] = torch.cos(position * div_term)
#     return pe


def positional_encoding(x, dim: int = 48):
    """
    Encode integer positions as their binary digits, least-significant bit first.

    Column ``k`` of the result is bit ``k`` of the position (the bit worth ``2**k``),
    so position 5 with ``dim=3`` becomes ``[1., 0., 1.]``.

    Args:
        x: Integer tensor of positions with shape [N]
        dim (int): Number of bits to extract (output dimension). Default is 48.

    Returns:
        torch.Tensor: float32 tensor with shape [N, dim] holding 0.0 / 1.0 per bit.
    """

    # One shift amount per output column, created on the same device as x
    bit_positions = torch.arange(dim, device=x.device)

    # [N, 1] >> [dim] broadcasts to [N, dim]; bit k ends up in the lowest position
    shifted = x.unsqueeze(1) >> bit_positions

    # Keep only the lowest bit of every shifted value
    lowest_bits = shifted & 1

    return lowest_bits.to(torch.float32)


# def positional_encoding(x: torch.Tensor, dim: int = 48):
#     """
#     Positional encoding using sine functions whose periods are powers of two,
#     starting from 2^0 (one sine per dimension; no cosine terms are used).

#     Args:
#         x (torch.Tensor): Input tensor of positions with shape [N]
#         dim (int): Total dimension of the encoding. Must be even.

#     Returns:
#         torch.Tensor: A tensor with shape [N, dim] containing the positional encodings.
#     """
#     assert dim % 2 == 0, "Encoding dimension (dim) must be even."

#     # Powers of two 2^0 .. 2^(dim-1); the angle below divides by them, so they act as periods
#     frequency_powers = torch.arange(0, dim, dtype=torch.float32, device=x.device)
#     frequencies = 2 ** frequency_powers  # Shape: [num_frequencies]

#     # Compute the angles: [N, num_frequencies]
#     x = x.unsqueeze(1)  # Shape: [N, 1]
#     angles = (2 * torch.pi * x) / frequencies  # Broadcasting over x and frequencies

#     # Compute the positional encodings
#     encoding = torch.sin(angles)

#     return encoding

In [None]:
class AUNNModel(nn.Module):
    """Fully connected network with SiLU activations and additive skip connections.

    Layout: one ``Linear + SiLU`` input block, ``num_layers - 2`` hidden
    ``Linear + SiLU`` blocks of width ``hidden_dim``, and a final ``Linear``
    producing ``output_dim`` values. After every second hidden block the
    activations saved at the previous skip point are added back in.

    Args:
        embedding_dim (int): Size of each input vector.
        output_dim (int): Size of each output vector.
        num_layers (int): Total layer count including input and output layers;
            must be even and at least 2.
        hidden_dim (int): Width of the input block's output and of every hidden block.
    """

    def __init__(self, embedding_dim:int, output_dim:int, num_layers:int, hidden_dim:int):
        super().__init__()

        assert num_layers % 2 == 0 and num_layers >= 2, "Number of layers must be even and at least 2."

        self.embedding_dim = embedding_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

        # Submodules are created in input -> hidden -> output order; parameter
        # names and the order of random initialisation both follow from it.
        self.input_layer = self._dense_block(embedding_dim, hidden_dim)
        self.layers = nn.ModuleList(
            self._dense_block(hidden_dim, hidden_dim)
            for _ in range(num_layers - 2)  # input and output layers are not counted here
        )
        self.output_layer = nn.Linear(hidden_dim, output_dim)

        self._initialize_weights()

    @staticmethod
    def _dense_block(in_features, out_features):
        """Return a ``Linear`` layer followed by a SiLU activation."""
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.SiLU(),
        )

    def _initialize_weights(self):
        """Re-initialise every ``Linear``: Kaiming-normal weights (default arguments), zero biases."""
        linear_layers = (m for m in self.modules() if isinstance(m, nn.Linear))
        for linear in linear_layers:
            nn.init.kaiming_normal_(linear.weight)
            if linear.bias is not None:
                nn.init.zeros_(linear.bias)

    def forward(self, x):
        """Run the network on ``x`` and return the output layer's raw values."""
        hidden = self.input_layer(x)
        skip = hidden  # activations to add back at the next skip point

        for position, block in enumerate(self.layers, start=1):
            hidden = block(hidden)

            # Odd-numbered hidden blocks just pass through; the skip is
            # applied after every second one.
            if position % 2:
                continue
            hidden = hidden + skip
            skip = hidden

        return self.output_layer(hidden)

In [None]:
# Define a function to save the model checkpoint
def save_checkpoint(model, params, optimizer, losses, filename="checkpoint.pth"):
    """Write model weights, optimizer state, loss history and model hyperparameters to disk.

    Args:
        model: Module whose ``state_dict()`` is stored.
        params: Mapping that must contain 'embedding_dim', 'output_dim',
            'num_layers' and 'hidden_dim'; these four values are stored
            alongside the weights so the model can be rebuilt on load.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        losses: Sequence of recorded losses; may be empty.
        filename: Destination path handed to ``torch.save``.

    Raises:
        AssertionError: If ``params`` lacks any of the required keys.
    """

    keys = ['embedding_dim', 'output_dim', 'num_layers', 'hidden_dim']

    # Validate before touching the disk so a bad call never leaves a partial checkpoint behind
    missing = [k for k in keys if k not in params]
    assert not missing, f"params is missing required keys: {missing}"

    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'losses': losses,
    }
    checkpoint.update({k: params[k] for k in keys})

    torch.save(checkpoint, filename)

    # losses is empty until the first epoch has been logged; indexing [-1]
    # unconditionally would raise right after a successful save.
    if len(losses) > 0:
        print(f"Checkpoint saved with loss {losses[-1]:.4f}")
    else:
        print("Checkpoint saved (no losses recorded yet)")

In [None]:
def load_checkpoint(filename="checkpoint.pth"):
    """Rebuild the model and optimizer stored in a checkpoint file.

    The checkpoint is expected to contain the keys 'model_state_dict',
    'optimizer_state_dict', 'losses', and the four ``AUNNModel`` constructor
    arguments ('embedding_dim', 'output_dim', 'num_layers', 'hidden_dim').

    Args:
        filename: Path of the checkpoint to read.

    Returns:
        tuple: ``(model, optimizer, losses)`` where ``model`` is an ``AUNNModel``
        on the module-level ``device``, ``optimizer`` is an ``AdamW`` over its
        parameters with the saved state restored, and ``losses`` is the stored
        loss history.
    """

    # map_location moves the stored tensors onto the active device, so a
    # checkpoint written on a GPU machine can still be opened on a CPU-only one.
    checkpoint = torch.load(filename, map_location=device, weights_only=True)

    keys = ['embedding_dim', 'output_dim', 'num_layers', 'hidden_dim']
    params = {k: checkpoint[k] for k in keys}

    model = AUNNModel(**params)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)

    # Build the optimizer after the model is on its final device, then restore its state
    optimizer = optim.AdamW(model.parameters())
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    losses = checkpoint['losses']

    # A checkpoint saved before any epoch was logged has an empty loss history
    if len(losses) > 0:
        print(f"Checkpoint loaded: loss {losses[-1]:.4f}")
    else:
        print("Checkpoint loaded: no losses recorded yet")

    return model, optimizer, losses

### Load MNIST

In [None]:
import struct
from array import array

def load_mnist(images_path, labels_path, shuffle=False, seed=42):
    """Read an MNIST IDX image/label file pair and binarize the images.

    Args:
        images_path: Path to the IDX image file (magic number 2051).
        labels_path: Path to the IDX label file (magic number 2049).
        shuffle (bool): When True, images and labels are permuted together.
        seed (int): Seed used for the permutation. Note that this reseeds the
            global ``random`` module.

    Returns:
        tuple: ``(images, labels)``. ``images`` is a list of 2-D integer arrays
        of shape ``(rows, cols)`` in which every non-zero pixel is 1 and every
        zero pixel is 0. ``labels`` is an ``array('B')`` when ``shuffle`` is
        False and a list of ints when it is True.

    Raises:
        ValueError: If a magic number is wrong, if the image file holds fewer
            pixels than its header announces, or if the number of images and
            labels differ.
    """

    with open(labels_path, 'rb') as file:
        magic, size = struct.unpack(">II", file.read(8))
        if magic != 2049:
            raise ValueError('Magic number mismatch, expected 2049, got {}'.format(magic))
        labels = array("B", file.read())

    with open(images_path, 'rb') as file:
        magic, size, rows, cols = struct.unpack(">IIII", file.read(16))
        if magic != 2051:
            raise ValueError('Magic number mismatch, expected 2051, got {}'.format(magic))
        raw = file.read()

    # A short file must be rejected explicitly: padding the last image with
    # zeros would silently corrupt the dataset.
    expected = size * rows * cols
    if len(raw) < expected:
        raise ValueError(
            'Image file is truncated, expected {} pixel bytes, got {}'.format(expected, len(raw))
        )

    # Decode and binarize every image in one vectorised pass, then split the
    # (size, rows, cols) block into one 2-D array per image.
    pixels = np.frombuffer(raw, dtype=np.uint8, count=expected)
    binary = np.where(pixels > 0, 1, 0).reshape(size, rows, cols)
    images = list(binary)

    if len(images) != len(labels):
        raise ValueError(
            'Image/label count mismatch, got {} images and {} labels'.format(len(images), len(labels))
        )

    if shuffle:
        random.seed(seed)
        indices = list(range(len(images)))
        random.shuffle(indices)
        # Apply the same permutation to both so image i keeps its label
        images = [images[i] for i in indices]
        labels = [labels[i] for i in indices]

    return images, labels

In [None]:
from pathlib import Path

# The raw IDX files are expected in an `mnist` folder under the working directory
cur_dir = Path().resolve()
input_path = cur_dir / 'mnist'
training_images_filepath = input_path / 'train-images.idx3-ubyte'
training_labels_filepath = input_path /'train-labels.idx1-ubyte'
test_images_filepath = input_path / 't10k-images.idx3-ubyte'
test_labels_filepath = input_path / 't10k-labels.idx1-ubyte'

# Only the training split is loaded here, shuffled with a fixed seed
images, labels = load_mnist(
    images_path=training_images_filepath,
    labels_path=training_labels_filepath,
    shuffle=True,
    seed=0,
)

# Group image positions by digit so examples of a given label can be picked later
label2idx = {}
for idx, label in enumerate(labels):
    label2idx.setdefault(label, []).append(idx)

# Number of pixels in one image
img_size = images[0].size

# Sanity check: show the first image together with its label and the pixel count
plt.imshow(images[0], cmap='gray')
plt.show()
print(labels[0])
print(img_size)

### Train

#### Define

In [None]:
# Hyperparameters

# Width of the binary positional encoding fed to the model (one input per bit)
embedd_dim = 48
num_layers = 8     # Must be even and at least 2 (bc of skip connections)
hidden_dim = 512  # Size of hidden layers
# Two logits per position: every target element is a single bit (0 or 1)
output_dim = 2

# Layout of one training entry: BOS marker, one-hot label, image pixels,
# one-hot label again, EOS marker.
eos_bos_len = 8
label_len = 10
entry_length = eos_bos_len * 2 + label_len * 2 + img_size
# Entries are placed at positions that are multiples of 1024 (see data_provider),
# so an entry has to fit inside one 1024-wide slot.
assert entry_length < 1024, "Entry length must be less than 1024, improves positional encoding consistency."
print(entry_length)

In [None]:
# Build the network from the hyperparameters above and place it on the active device
model = AUNNModel(
    embedding_dim=embedd_dim,
    output_dim=output_dim,
    num_layers=num_layers,
    hidden_dim=hidden_dim,
)
model = model.to(device)

# Per-position classification loss, and AdamW with its default settings
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters())

In [None]:
def count_parameters(model):
    """Return the total number of elements across the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad=False) are not counted
        if param.requires_grad:
            total += param.numel()
    return total

# Report how many trainable values the freshly built model has
total_params = count_parameters(model)
print(f"trainable parameters: {total_params}")

#### One Example Random

In [None]:
# Digits to train on, and how many distinct images to keep for each of them
train_labels = list(range(10))
img_per_label = 1

# label -> the first `img_per_label` images of that digit, in dataset order
example_imgs = {}
for label in train_labels:
    print(label)
    example_imgs[label] = [images[idx] for idx in label2idx[label][:img_per_label]]

    # Show every selected image as a small thumbnail
    for img in example_imgs[label]:
        plt.figure(figsize=(2,2))
        plt.imshow(img, cmap='gray')
        plt.show()

In [None]:
def data_provider(num_epochs, num_entries, batch_size=64, seed=42, offset_entries=0):
    """Build the whole training sequence once, then yield it batch by batch.

    Each entry is the concatenation of: a begin marker (``eos_bos_len`` ones),
    the one-hot label (``label_len`` values), the flattened image, the one-hot
    label again, and an end marker (``eos_bos_len`` ones). Entry number ``k``
    (counting from ``offset_entries``) occupies positions
    ``k * 1024 .. k * 1024 + entry_length - 1``; the model input for a position
    is its ``positional_encoding``.

    Reads the module-level ``train_labels``, ``img_per_label``, ``example_imgs``,
    ``eos_bos_len``, ``label_len``, ``entry_length``, ``embedd_dim`` and ``device``.

    Args:
        num_epochs: Number of passes over the generated entries.
        num_entries: Number of entries to generate; must be a multiple of ``batch_size``.
        batch_size: Number of consecutive entries per yielded batch.
        seed: Seed for the global ``random`` module (labels, variants, batch order).
        offset_entries: Entry number assigned to the first generated entry.

    Yields:
        tuple: ``(epoch, inputs, targets)`` with ``epoch`` counted from 1,
        ``inputs`` the float32 encodings and ``targets`` the uint8 bits of one
        batch, both on ``device``.
    """

    assert num_entries % batch_size == 0, "Number of entries must be divisible by batch size."

    random.seed(seed)

    # Draw all labels first and all image variants second: this order fixes
    # how the seeded random stream is consumed.
    entry_labels = [random.choice(train_labels) for _ in range(num_entries)]
    variant_pool = list(range(img_per_label))
    entry_variants = [random.choice(variant_pool) for _ in range(num_entries)]

    print(f"Labels: {entry_labels[:5]}")

    # The begin and end markers are the same run of ones; torch.cat copies
    # its inputs, so one tensor can be reused for every entry.
    marker = torch.ones(eos_bos_len, dtype=torch.uint8)

    input_chunks = []
    target_chunks = []
    numbered_entries = enumerate(zip(entry_labels, entry_variants), start=offset_entries)
    for entry_no, (label, variant) in numbered_entries:

        # one-hot encoding of the label, used both before and after the image
        one_hot = torch.zeros(label_len, dtype=torch.uint8)
        one_hot[label] = 1

        # flattened pixels of the chosen example image
        pixels = torch.tensor(example_imgs[label][variant].flatten(), dtype=torch.uint8)

        target_chunks.extend((marker, one_hot, pixels, one_hot, marker))

        # every entry starts on a multiple of 1024
        first_position = entry_no * 1024
        positions = torch.arange(first_position, first_position + entry_length)
        input_chunks.append(positional_encoding(positions, embedd_dim))

    inputs = torch.cat(input_chunks, dim=0).to(device)
    targets = torch.cat(target_chunks, dim=0).to(device)

    print(f"Total Indices: {len(inputs)}")

    rows_per_batch = batch_size * entry_length
    batch_order = list(range(num_entries // batch_size))

    for epoch in tqdm(range(num_epochs)):

        # Visit the (fixed) batches in a new random order every epoch
        random.shuffle(batch_order)

        for batch_idx in batch_order:
            window = slice(batch_idx * rows_per_batch, (batch_idx + 1) * rows_per_batch)
            yield epoch + 1, inputs[window], targets[window]

In [None]:
# Per-epoch average training losses, appended to by the training loop.
# NOTE(review): presumably kept in its own cell so that re-running the training
# cell extends the history instead of resetting it — confirm.
losses = []

In [None]:
# Settings for this training run (passed to data_provider by the training loop)
num_epochs = 10_000
num_entries = 200     # entries generated for the run; must be a multiple of batch_size
batch_size = 100      # entries per optimisation step
offset = 0            # entry number of the first generated entry
num_batches = num_entries // batch_size

In [None]:
# Training loop
#
# Consumes batches from data_provider, takes one optimizer step per batch and
# records the average loss of every epoch in `losses`.

print("Initializing training loop...")
print(f"Number of entries: {num_entries}, Batch size: {batch_size}, Number of batches: {num_batches}")

model.train()

last_epoch = 1
epoch_losses = []
batch_num = 0

for epoch, inputs, targets in data_provider(num_epochs=num_epochs, num_entries=num_entries, batch_size=batch_size, offset_entries=offset):

    # The provider only reports the epoch number with each batch, so the end of
    # an epoch is detected when the first batch of the next one arrives.
    if epoch != last_epoch: # epoch logging
        avg_loss = sum(epoch_losses) / len(epoch_losses)
        losses.append(avg_loss)
        print(f"Epoch [{last_epoch}/{num_epochs}] completed, Loss: {avg_loss:.8f}")
        batch_num = 0
        epoch_losses = []
        last_epoch = epoch

    batch_num += 1

    # do optimization
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    cur_loss = loss.item()

    # NaN is the only value that is not equal to itself
    if cur_loss != cur_loss:
        print("ERR loss is NaN")

    epoch_losses.append(cur_loss)

    if batch_num % 10 == 0 and batch_num != num_batches: # batch logging
        print(f"Epoch [{last_epoch}/{num_epochs}], Batch [{batch_num}/{num_batches}], Loss: {cur_loss:.8f}")

# No batch follows the last epoch, so the check at the top of the loop never
# fires for it; record and report it here so the final epoch is not lost.
if epoch_losses:
    avg_loss = sum(epoch_losses) / len(epoch_losses)
    losses.append(avg_loss)
    print(f"Epoch [{last_epoch}/{num_epochs}] completed, Loss: {avg_loss:.8f}")
    epoch_losses = []

In [None]:
# Draw the recorded per-epoch average loss as a single wide line chart
fig, ax = plt.subplots(figsize=(20, 6))
ax.plot(losses, label="Training Loss")
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
ax.set_title("Training Loss")
ax.legend()
plt.show()

In [None]:
def eval_output(start_idx):
    """Print the model's predicted bits for one entry starting at ``start_idx``.

    Encodes the ``entry_length`` consecutive positions beginning at
    ``start_idx``, takes the higher-scoring class at every position, and prints
    the result split into the entry's five sections: BOS marker, label, image
    (as a 28x28 grid), repeated label, and EOS marker.

    Switches the module-level ``model`` to eval mode and leaves it there.

    Args:
        start_idx: Position of the first element of the entry.
    """
    model.eval()

    positions = torch.arange(start_idx, start_idx+entry_length).to(device)
    encoded = positional_encoding(positions, embedd_dim)
    with torch.no_grad():
        logits = model(encoded)
        predicted = logits.max(dim=1).indices
    predicted = predicted.cpu().numpy()

    # Walk through the flat prediction, cutting off one section at a time
    section_lengths = (eos_bos_len, label_len, img_size, label_len, eos_bos_len)
    sections = []
    cursor = 0
    for length in section_lengths:
        sections.append(predicted[cursor:cursor + length])
        cursor += length
    bos, label, img, label2, eos = sections

    img = img.reshape(28, 28)

    print("BOS:", bos)
    print("Label:", label)
    print("Image:", img)
    print("Label2:", label2)
    print("EOS:", eos)
    
# Which entries to inspect. Alternatives kept for convenience:
#   the first trained entry ...
# start = offset
#   ... or the first entry right after the trained range
# start = num_entries + offset
# Entry 600 lies beyond the entries generated for training in this run
# (offset .. offset + num_entries - 1).
start = 600
num_to_show = 100

for entry_idx in range(start, start+num_to_show):
    # entries are placed on multiples of 1024, matching data_provider
    seq_start = entry_idx * 1024
    print(f"Sequence @ {entry_idx}, idx = {seq_start}")
    eval_output(seq_start)
    print('-'*50)

In [None]:
# Persist the single-example run: weights, optimizer state, loss history and
# the constructor arguments needed to rebuild the model
save_checkpoint(
    model=model,
    params=dict(
        embedding_dim=embedd_dim,
        output_dim=output_dim,
        num_layers=num_layers,
        hidden_dim=hidden_dim,
    ),
    optimizer=optimizer,
    losses=losses,
    filename="mnist/checkpoint_scaffold_10x1.pth",
)

#### Multi Example Random

In [None]:
# Continue from the single-example run: replaces model, optimizer and the loss
# history with the state stored in that checkpoint
model, optimizer, losses = load_checkpoint(filename="mnist/checkpoint_scaffold_10x1.pth")

In [None]:
# Digits to train on, and how many distinct images to keep for each of them
train_labels = list(range(10))
img_per_label = 100

# label -> the first `img_per_label` images of that digit, in dataset order
# (the images are not plotted here)
example_imgs = {}
for label in train_labels:
    print(label)
    example_imgs[label] = [images[idx] for idx in label2idx[label][:img_per_label]]

In [None]:
# Settings for this training run (passed to data_provider by the training loop)
num_epochs = 10_000
num_entries = 5000    # entries generated for the run; must be a multiple of batch_size
batch_size = 100      # entries per optimisation step
offset = 0            # entry number of the first generated entry
num_batches = num_entries // batch_size

In [None]:
# Training loop
#
# Consumes batches from data_provider, takes one optimizer step per batch and
# records the average loss of every epoch in `losses`.

print("Initializing training loop...")
print(f"Number of entries: {num_entries}, Batch size: {batch_size}, Number of batches: {num_batches}")

model.train()

last_epoch = 1
epoch_losses = []
batch_num = 0

for epoch, inputs, targets in data_provider(num_epochs=num_epochs, num_entries=num_entries, batch_size=batch_size, offset_entries=offset):

    # The provider only reports the epoch number with each batch, so the end of
    # an epoch is detected when the first batch of the next one arrives.
    if epoch != last_epoch: # epoch logging
        avg_loss = sum(epoch_losses) / len(epoch_losses)
        losses.append(avg_loss)
        print(f"Epoch [{last_epoch}/{num_epochs}] completed, Loss: {avg_loss:.8f}")
        batch_num = 0
        epoch_losses = []
        last_epoch = epoch

    batch_num += 1

    # do optimization
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    cur_loss = loss.item()

    # NaN is the only value that is not equal to itself
    if cur_loss != cur_loss:
        print("ERR loss is NaN")

    epoch_losses.append(cur_loss)

    if batch_num % 10 == 0 and batch_num != num_batches: # batch logging
        print(f"Epoch [{last_epoch}/{num_epochs}], Batch [{batch_num}/{num_batches}], Loss: {cur_loss:.8f}")

# No batch follows the last epoch, so the check at the top of the loop never
# fires for it; record and report it here so the final epoch is not lost.
if epoch_losses:
    avg_loss = sum(epoch_losses) / len(epoch_losses)
    losses.append(avg_loss)
    print(f"Epoch [{last_epoch}/{num_epochs}] completed, Loss: {avg_loss:.8f}")
    epoch_losses = []

In [None]:
# Draw the recorded per-epoch average loss as a single wide line chart
fig, ax = plt.subplots(figsize=(20, 6))
ax.plot(losses, label="Training Loss")
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
ax.set_title("Training Loss")
ax.legend()
plt.show()

In [None]:
# Which entries to inspect. Alternative: begin at the first trained entry
# start = offset
# Begin at the first entry after the range generated for training
# (offset .. offset + num_entries - 1)
start = num_entries + offset
num_to_show = 100

for entry_idx in range(start, start+num_to_show):
    # entries are placed on multiples of 1024, matching data_provider
    seq_start = entry_idx * 1024
    print(f"Sequence @ {entry_idx}, idx = {seq_start}")
    eval_output(seq_start)
    print('-'*50)

In [None]:
# Persist the multi-example run: weights, optimizer state, loss history and
# the constructor arguments needed to rebuild the model
save_checkpoint(
    model=model,
    params=dict(
        embedding_dim=embedd_dim,
        output_dim=output_dim,
        num_layers=num_layers,
        hidden_dim=hidden_dim,
    ),
    optimizer=optimizer,
    losses=losses,
    filename="mnist/checkpoint_scaffold_10x100.pth",
)