## MNIST Classification / Generation using AUNN

## Setup & Definition

In [None]:
import torch
import random
import time
import torch.optim as optim
import torch.nn as nn
import numpy as np
from tqdm.notebook import tqdm
from matplotlib import pyplot as plt

# Pick the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

# Ask PyTorch's caching allocator to release its unused cached CUDA memory.
torch.cuda.empty_cache()

# Show how many bytes are currently allocated to tensors on the CUDA device.
print(torch.cuda.memory_allocated())

In [2]:
class AUNNModel(nn.Module):
    """MLP with skip connections that maps integer indices to output logits.

    Each index is turned into a fixed sinusoidal encoding of width
    ``embedding_dim`` (see :meth:`encode`), projected to ``hidden_dim`` by the
    input layer, passed through ``num_layers - 2`` residual Linear+SiLU blocks
    and finally projected to ``output_dim`` logits.

    Args:
        embedding_dim: Width of the positional encoding fed to the input layer.
        output_dim: Number of logits produced per index.
        num_layers: Total layer count including the input and output layers;
            must be at least 2.
        hidden_dim: Width of the hidden layers.
    """

    def __init__(
        self,
        embedding_dim: int,
        output_dim: int,
        num_layers: int,
        hidden_dim: int):

        assert num_layers >= 2, "Number of layers must be at least 2"

        super().__init__()

        self.embedding_dim = embedding_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

        # # Initialize fourier embedding parameters
        # thresh = 0.1
        # context_len = 4
        # alpha = -np.log(thresh) / (context_len ** 2)
        # M = embedding_dim // 2
        # w_init = np.random.normal(0, np.sqrt(2 * alpha), size=M).astype(np.float32)
        # b_init = np.random.uniform(0, 2 * np.pi, size=M).astype(np.float32)
        # self.register_buffer("w", torch.tensor(w_init, device=device))
        # self.register_buffer("b", torch.tensor(b_init, device=device))

        # Input Layer
        self.input_layer = nn.Linear(self.embedding_dim, self.hidden_dim)

        # Hidden Layers
        self.layers = nn.ModuleList()
        for _ in range(self.num_layers - 2):  # Exclude input and output layers
            self.layers.append(nn.Sequential(
                nn.Linear(self.hidden_dim, self.hidden_dim),
                nn.SiLU()
            ))

        # Output Layer
        self.output_layer = nn.Linear(self.hidden_dim, self.output_dim)

        # Initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        """Apply Kaiming-normal weights and zero biases to every Linear layer."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # kaiming_normal_ is called with its default arguments.
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def count_params(self):
        """Return the number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    #transformer style sinuisoidal
    def encode(self, x: torch.Tensor):
        """Apply transformer-style sinusoidal positional encoding to indices.

        Args:
            x (torch.Tensor): 1-dimensional tensor of indices.

        Returns:
            torch.Tensor: Encoding of shape ``[x.size(0), embedding_dim]`` on
            the same device as ``x``. Even columns hold sines and odd columns
            hold cosines of ``index * frequency``.
        """
        dim = self.embedding_dim
        device = x.device  # Build every intermediate on the device of x

        # Geometric frequency ladder, one frequency per sin/cos column pair.
        div_term = torch.exp(
            torch.arange(0, dim, 2, dtype=torch.float32, device=device)
            * (-np.log(10000.0) / dim)
        )

        # Form the angles in float64. Multiplying the integer indices by a
        # float32 tensor would first round them to float32, which cannot
        # represent consecutive integers above 2**24, so neighbouring large
        # indices would receive identical encodings.
        angles = x.unsqueeze(1).to(torch.float64) * div_term.to(torch.float64)

        pe = torch.zeros(x.size(0), dim, device=device)
        pe[:, 0::2] = torch.sin(angles).to(pe.dtype)
        # For an odd dim there is one cosine column fewer than sine columns.
        pe[:, 1::2] = torch.cos(angles[:, : dim // 2]).to(pe.dtype)
        return pe

    # #sinusoidal binary
    # def encode(self, x: torch.Tensor):

    #     dim = self.embedding_dim
    #     assert dim % 2 == 0, "Encoding dimension (dim) must be even."

    #     # Determine the number of frequencies
    #     num_frequencies = dim // 2

    #     # Frequencies corresponding to powers of two, starting from 2^2
    #     frequency_powers = torch.arange(2, 2 + num_frequencies, dtype=torch.float32, device=x.device)
    #     frequencies = 2 ** frequency_powers  # Shape: [num_frequencies]

    #     # Compute the angles: [N, num_frequencies]
    #     x = x.unsqueeze(1)  # Shape: [N, 1]
    #     angles = (2 * torch.pi * x) / frequencies  # Broadcasting over x and frequencies

    #     # Compute the positional encodings
    #     encoding = torch.zeros(x.size(0), dim, device=x.device)
    #     encoding[:, 0::2] = torch.sin(angles)  # Even indices: sin
    #     encoding[:, 1::2] = torch.cos(angles)  # Odd indices: cos

    #     return encoding

    # #binary
    # def encode(self, x: torch.Tensor):
    #     dim = self.embedding_dim
    #     encoding = ((x.unsqueeze(1) >> torch.arange(dim, device=x.device)) & 1).to(torch.float32)
    #     return encoding

    # #fourier
    # def encode(self, x: torch.Tensor):
    #     inps = torch.outer(x, self.w) + self.b
    #     scale = np.sqrt(2.0 / self.embedding_dim)
    #     embed = scale * torch.cat([torch.cos(inps), torch.sin(inps)], dim=-1)
    #     return embed

    def forward(self, indices):
        """Return logits of shape ``[len(indices), output_dim]`` for the indices."""
        x = self.encode(indices)
        x = self.input_layer(x)
        # Functional SiLU avoids constructing a new nn.SiLU module per call.
        x = x + nn.functional.silu(x)

        for layer in self.layers:
            x = x + layer(x)  # MLP output with skip connection

        return self.output_layer(x)

In [3]:
# Define a function to save the model checkpoint
def save_checkpoint(model, params, optimizer, losses, filename="checkpoint.pth"):
    """Write model/optimizer state, loss history and model params to disk.

    Args:
        model: Module whose ``state_dict()`` is stored.
        params: Mapping that must contain 'embedding_dim', 'output_dim',
            'num_layers' and 'hidden_dim'; only these four entries are stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        losses: Loss history stored as-is; may be empty.
        filename: Destination path handed to ``torch.save``.

    Raises:
        KeyError: If ``params`` lacks any of the required keys. Nothing is
            written in that case.
    """
    keys = ('embedding_dim', 'output_dim', 'num_layers', 'hidden_dim')

    # Validate explicitly: an ``assert`` would be stripped under ``python -O``.
    missing = [k for k in keys if k not in params]
    if missing:
        raise KeyError(f"params is missing required keys: {missing}")

    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'losses': losses,
        'params': {k: params[k] for k in keys},
    }

    torch.save(checkpoint, filename)

    # An empty history must not crash the call after the file has been written.
    if len(losses) > 0:
        print(f"Checkpoint saved with loss {losses[-1]:.4f}")
    else:
        print("Checkpoint saved (no losses recorded yet)")

In [4]:
def load_checkpoint(filename="checkpoint.pth"):
    """Rebuild the model and optimizer from a checkpoint file.

    Args:
        filename: Path of a checkpoint containing the keys 'params',
            'model_state_dict', 'optimizer_state_dict' and 'losses'.

    Returns:
        tuple: ``(model, optimizer, losses)`` where the model is an
        ``AUNNModel`` moved to the module-level ``device``, the optimizer is
        an ``AdamW`` instance with restored state, and ``losses`` is the
        stored loss history.
    """
    # Load onto the CPU first so that a checkpoint saved from a GPU can also be
    # restored on a machine without CUDA; the model is moved to `device` below.
    checkpoint = torch.load(filename, map_location='cpu', weights_only=True)

    params = checkpoint['params']
    model = AUNNModel(**params)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)

    # Create the optimizer after the move so it tracks the on-device parameters.
    optimizer = optim.AdamW(model.parameters())
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    losses = checkpoint['losses']

    # A checkpoint written before any training step has an empty history.
    if len(losses) > 0:
        print(f"Checkpoint loaded: loss {losses[-1]:.4f}")
    else:
        print("Checkpoint loaded: no losses recorded")

    return model, optimizer, losses

## Prepare MNIST

In [5]:
import struct
from array import array

def load_mnist(images_path, labels_path, shuffle=False, seed=42):
    """Read an IDX image/label file pair and binarize the images.

    Args:
        images_path: Path to the IDX image file (magic number 2051).
        labels_path: Path to the IDX label file (magic number 2049).
        shuffle: When True, images and labels are permuted together.
        seed: Seed of the private random generator used for the permutation.

    Returns:
        tuple: ``(images, labels)``. ``images`` is a list of ``(rows, cols)``
        arrays holding 1 where the source pixel was non-zero and 0 elsewhere.
        ``labels`` is an ``array('B')`` when ``shuffle`` is False and a list
        of ints when ``shuffle`` is True.

    Raises:
        ValueError: On a magic-number mismatch, or when the image file holds
            fewer pixel bytes than its header announces.
    """
    with open(labels_path, 'rb') as file:
        magic, _ = struct.unpack(">II", file.read(8))
        if magic != 2049:
            raise ValueError('Magic number mismatch, expected 2049, got {}'.format(magic))
        labels = array("B", file.read())

    with open(images_path, 'rb') as file:
        magic, size, rows, cols = struct.unpack(">IIII", file.read(16))
        if magic != 2051:
            raise ValueError('Magic number mismatch, expected 2051, got {}'.format(magic))
        pixels = np.frombuffer(file.read(), dtype=np.uint8)

    expected = size * rows * cols
    if pixels.size < expected:
        raise ValueError(
            'Image file is truncated: expected {} pixel bytes, got {}'.format(expected, pixels.size))

    # Binarize every pixel in one vectorised pass, then split per image.
    binarized = np.where(pixels[:expected] > 0, 1, 0).reshape(size, rows, cols)
    images = list(binarized)

    assert len(images) == len(labels)

    if shuffle:
        # A private generator yields the same permutation as seeding the
        # module-level one, without resetting the caller's global random state.
        indices = list(range(len(images)))
        random.Random(seed).shuffle(indices)
        images = [images[i] for i in indices]
        labels = [labels[i] for i in indices]

    return images, labels

def make_bitstring(images, labels):
    """Flatten (image, label) pairs into a single 1-D sequence of bits.

    For every pair the output holds the flattened image followed by the
    label's 4-bit binary code (least-significant bit first, not one-hot),
    repeated until it is as long as the image.

    Args:
        images: Iterable of arrays; each flattened length must be divisible by 4.
        labels: Iterable of integer labels, aligned with ``images``.

    Returns:
        np.ndarray: Concatenation of all image and label segments.
    """
    bit_positions = np.arange(4)
    segments = []

    for picture, digit in zip(images, labels):
        pixels = picture.flatten()

        # Binary code of the label: bit k is (digit >> k) & 1.
        code = (np.array([digit]) >> bit_positions) & 1
        assert len(pixels) % 4 == 0, "Image data length must be divisible by 4"
        repeated_code = np.concatenate([code for _ in range(len(pixels) // 4)])

        segments.extend((pixels, repeated_code))

    return np.concatenate(segments, axis=0)

In [None]:
from pathlib import Path

# Locate the MNIST IDX files inside the "mnist" folder of the working directory.
cur_dir = Path().resolve()
input_path = cur_dir.joinpath('mnist')
training_images_filepath = input_path.joinpath('train-images.idx3-ubyte')
training_labels_filepath = input_path.joinpath('train-labels.idx1-ubyte')
test_images_filepath = input_path.joinpath('t10k-images.idx3-ubyte')
test_labels_filepath = input_path.joinpath('t10k-labels.idx1-ubyte')

# Load the shuffled training split and flatten it into one long bit sequence.
train_images, train_labels = load_mnist(
    training_images_filepath,
    training_labels_filepath,
    shuffle=True,
)
train_bitstring = make_bitstring(train_images, train_labels)
print(f"{len(train_bitstring):,} training samples")

In [None]:
# Print the first two examples of the bitstring: each occupies 1568 positions,
# 784 image bits followed by 784 label bits (only the first 4 are shown).
for example_no in range(2):
    offset = example_no * 1568

    img = train_bitstring[offset:offset + 784]
    img = img.reshape(28, 28)
    print(img)

    lbl = train_bitstring[offset + 784:offset + 1568][0:4]
    print(lbl)

## Train

In [17]:
# Hyperparameters
# The first four values are the constructor arguments of AUNNModel; the last
# two are read by the training loop.

embedd_dim = 128  # width of the positional encoding (AUNNModel's embedding_dim)
num_layers = 12    # total layers incl. input and output; AUNNModel asserts >= 2 (NOTE(review): an earlier note said "must be even", but the model code does not enforce that)
hidden_dim = 512  # Size of hidden layers
output_dim = 2  # logits per index; targets are single bits scored with CrossEntropyLoss
batch_size = 4096  # number of consecutive bitstring indices per batch
num_epochs = 50000  # number of full passes over the bitstring

In [None]:
# Build the network on the selected device, then the loss and the optimizer.

model = AUNNModel(
    embedding_dim=embedd_dim,
    output_dim=output_dim,
    num_layers=num_layers,
    hidden_dim=hidden_dim,
)
model = model.to(device)

# Cross-entropy over the output logits; AdamW with its default settings.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters())

print(f"Model has {model.count_params():,} parameters")

In [None]:
ex_len = 784 * 2  # positions per example: 784 image bits + 784 label bits
num_ex = 10       # number of examples used by the optional subset experiments below

# Optional experiments on a small subset of the data (kept for reference):
# train_bitstring = np.concat([train_bitstring[0:ex_len*num_ex]] * 1000)
# train_bitstring = torch.tensor(train_bitstring[0:ex_len*num_ex], dtype=torch.uint8).to(device)

# The training loop passes slices of train_bitstring straight to the loss and
# the evaluation cells call .cpu() on it, so it has to be a tensor on `device`.
# The guard keeps the cell safe to re-run after the conversion has happened.
if not torch.is_tensor(train_bitstring):
    train_bitstring = torch.tensor(train_bitstring, dtype=torch.uint8).to(device)
print(len(train_bitstring))

In [20]:
# Per-batch training history. Kept outside the training cell so that re-running
# the loop appends to the existing history instead of resetting it.
losses, accuracies = [], []

In [None]:
# Training loop
# Each batch covers a run of consecutive positions: the position indices are the
# model inputs and the bits stored at those positions are the targets.

accumulation_steps = 1

# Ceiling division: the final batch may be shorter than batch_size.
num_batches = (len(train_bitstring) + batch_size - 1) // batch_size

# num_batches = 10

# Start from clean gradients in case an interrupted earlier run left some behind.
optimizer.zero_grad(set_to_none=True)

for epoch in tqdm(range(num_epochs)):
    for batch_num in tqdm(range(num_batches), leave=False, disable=False):

        start = batch_num * batch_size
        end = min(start + batch_size, len(train_bitstring))
        # as_tensor is a no-op for a tensor already on `device` and converts a
        # NumPy slice otherwise.
        targets = torch.as_tensor(train_bitstring[start:end], device=device)
        inputs = torch.arange(start, end, device=device)

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Record the unscaled loss and the accuracy of this batch
        cur_loss = loss.item()
        losses.append(cur_loss)
        with torch.no_grad():
            preds = outputs.argmax(dim=1)
            accuracy = (preds == targets).float().mean().item()
        accuracies.append(accuracy)

        # Check for NaN *before* backward/step: once NaN gradients reach the
        # optimizer the weights themselves become NaN and cannot recover.
        if np.isnan(cur_loss):
            print("ERR loss is NaN")
            optimizer.zero_grad(set_to_none=True)  # drop the partial accumulation
            continue

        # Scale loss for gradient accumulation
        (loss / accumulation_steps).backward()

        # Perform optimization every accumulation_steps batches
        if (batch_num + 1) % accumulation_steps == 0 or batch_num == num_batches - 1:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # Clip gradients
            optimizer.step()
            optimizer.zero_grad(set_to_none=True) # Flush gradients

        # Batch logging
        if batch_num % 1000 == 0 and batch_num != 0:
            avg_loss = np.mean(losses[-1000:])
            avg_accuracy = np.mean(accuracies[-1000:])
            print(f"Batch [{batch_num}/{num_batches}], Loss: {avg_loss:.8f}, Accuracy: {avg_accuracy:.8f}", end="\r")

        # if batch_num == 10: break

    # Epoch logging
    avg_loss = np.mean(losses[-num_batches:])
    avg_accuracy = np.mean(accuracies[-num_batches:])
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.8f}, Accuracy: {avg_accuracy:.8f}")

In [None]:
# Plot the loss curves
# Clip losses to at most 1: any value that is not < 1 (including NaN) becomes 1.
temp = [1 if not x < 1 else x for x in losses]

fig, ax = plt.subplots(figsize=(20, 6))
ax.plot(temp, label="Training Loss")
ax.set_xlabel("Steps")
ax.set_ylabel("Loss")
ax.set_title("Training Loss")
ax.legend()
plt.show()

In [None]:
# Store the model together with the constructor arguments needed to rebuild it.
model_params = dict(
    embedding_dim=embedd_dim,
    output_dim=output_dim,
    num_layers=num_layers,
    hidden_dim=hidden_dim,
)
save_checkpoint(model, model_params, optimizer, losses, filename="mnist/checkpoint.pth")

In [None]:
# Compare the stored bits of one training entry with the model's prediction.
entry_num = 0
ex_len = 784 * 2
start = entry_num * ex_len
end = start + ex_len

# Ground truth: image bits, then the repeated 4-bit label code.
bitstring = train_bitstring[start:end].cpu().numpy()
img, lbl = bitstring[0:784], bitstring[784:784*2]
print(img.reshape(28, 28))
print(lbl.reshape(-1, 4))

# Prediction: most likely bit at every position of the same entry.
indices = torch.arange(start, end, device=device)
with torch.no_grad():
    outputs = model(indices).argmax(dim=1)

outputs = outputs.cpu().numpy()
img, lbl = outputs[0:784], outputs[784:784*2]
print(img.reshape(28, 28))
print(lbl.reshape(-1, 4))

In [None]:
# Predict every bit of entry 1000 and print it as image plus label rows.
entry_num = 1000
start = entry_num * ex_len
end = start + ex_len

indices = torch.arange(start, end, device=device)
with torch.no_grad():
    outputs = model(indices).argmax(dim=1)

outputs = outputs.cpu().numpy()
img, lbl = outputs[0:784], outputs[784:784*2]
print(img.reshape(28, 28))
print(lbl.reshape(-1, 4))

In [None]:
#attempt conditioning
# Fine-tune the model on a short prefix (10 "begin" bits followed by a 10-bit
# one-hot label) placed at a fixed offset, then let it predict the positions
# that follow the prefix.

# NOTE(review): presumably chosen to lie past the end of the training
# bitstring; confirm against len(train_bitstring).
base_index = 100000000

bos = np.ones(10, dtype=np.uint8)
lbl = np.zeros(10, dtype=np.uint8)
lbl[5] = 1

target = np.concatenate([bos, lbl], axis=0)
target = torch.tensor(target, dtype=torch.uint8, device=device)
indices = torch.arange(base_index, base_index + len(target), device=device)

#train the model
model.train()
num_steps = 10
optimizer.zero_grad(set_to_none=True)  # do not mix in gradients left by earlier cells
for _ in range(num_steps):
    outputs = model(indices)
    loss = criterion(outputs, target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    print(loss.item())

#test the model
model.eval()
entry_len = 814  # 10 bos bits + 10 label bits + 784 image bits + 10 trailing bits
indices = torch.arange(base_index, base_index + entry_len, device=device)
# Inference only: skip autograd bookkeeping, as the other evaluation cells do.
with torch.no_grad():
    outputs = model(indices)
    outputs = torch.argmax(outputs, dim=1)
outputs = outputs.cpu().numpy()

bos = outputs[:10]
lbl = outputs[10:20]
img = outputs[20:20+784]
lbl2 = outputs[20+784:20+784+10]

print('-'*20)
print(bos)
print(lbl)
print(img.reshape(28, 28))
print(lbl2)