In [1]:
from typing import Callable
import random
from collections import namedtuple

import jax
import numpy as np
import optax

import pcax as px
import pcax.nn as pxnn
import pcax.functional as pxf
import pcax.utils as pxu

import torchvision
from torchvision import datasets, transforms
import torch
import matplotlib.pyplot as plt

2024-05-27 00:06:17.121400: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [2]:
# Container for the three dataloaders built by get_dataloaders().
Dataset = namedtuple("Dataset", "train_loader val_loader test_loader")

# This is a simple collate function that stacks numpy arrays used to interface
# the PyTorch dataloader with JAX. In the future we hope to provide custom dataloaders
# that are independent of PyTorch.
def numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)

# The dataloader assumes cuda is being used, as such it sets 'pin_memory = True' and
# 'prefetch_factor = 2'. Note that the batch size should be constant during training, so
# we set 'drop_last = True' to avoid having to deal with variable batch sizes.
class TorchDataloader(torch.utils.data.DataLoader):
    def __init__(
        self,
        dataset,
        batch_size=1,
        shuffle=None,
        sampler=None,
        batch_sampler=None,
        num_workers=16,
        pin_memory=True,
        timeout=0,
        worker_init_fn=None,
        persistent_workers=True,
        prefetch_factor=2,
    ):
        super(self.__class__, self).__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=numpy_collate,
            pin_memory=pin_memory,
            drop_last=True if batch_sampler is None else None,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
            persistent_workers=persistent_workers,
            prefetch_factor=prefetch_factor,
        )

# Function to add noise to the labels in the dataset
def add_label_noise(dataset, noise_level=0.2):
    targets = np.array(dataset.targets)
    num_classes = len(np.unique(targets))
    num_noisy = int(noise_level * len(targets))
    noisy_indices = np.random.choice(len(targets), num_noisy, replace=False)

    for idx in noisy_indices:
        original_label = targets[idx]
        new_label = original_label
        while new_label == original_label:
            new_label = np.random.randint(0, num_classes)
        targets[idx] = new_label

    dataset.targets = torch.tensor(targets)
    return dataset

# Function to get the dataloaders
def get_dataloaders(dataset_name, train_subset_size, batch_size, noise_level=0.2):
    if dataset_name.lower() == "mnist":
        ds = datasets.MNIST
    else:
        raise NotImplementedError(f"Dataset {dataset_name} isn't available")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        transforms.Lambda(lambda x: x.view(-1).numpy())  # Flatten the image to a vector
    ])

    train_set = ds(root='./data', download=True, train=True, transform=transform)
    train_set = add_label_noise(train_set, noise_level=noise_level)

    val_subset_size = int(0.2 * train_subset_size)
    random_train_indices = np.random.choice(len(train_set), size=train_subset_size, replace=False)
    remaining_indices = np.setdiff1d(np.arange(len(train_set)), random_train_indices)
    random_val_indices = np.random.choice(remaining_indices, size=val_subset_size, replace=False)

    train_loader = TorchDataloader(
        train_set, batch_size=batch_size, num_workers=16,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(random_train_indices))
    val_loader = TorchDataloader(
        train_set, batch_size=batch_size, num_workers=16,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(random_val_indices))

    test_set = ds(root='./data', download=True, train=False, transform=transform)
    test_loader = TorchDataloader(
        test_set, batch_size=batch_size, shuffle=False, num_workers=16)

    return Dataset(train_loader=train_loader, val_loader=val_loader, test_loader=test_loader)

In [3]:
class Model(px.Module):
    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, act_fn: Callable[[jax.Array], jax.Array]
    ) -> None:
        super().__init__()

        self.act_fn = px.static(act_fn)

        self.layers = [
            pxnn.Linear(input_dim, hidden_dim),
            pxnn.Linear(hidden_dim, output_dim)
        ]

        # create a glorot uniform initializer:
        # see: https://pytorch.org/docs/2.0/nn.init.html?highlight=xavier#torch.nn.init.xavier_uniform_
        # see: https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.initializers.variance_scaling.html#jax.nn.initializers.variance_scaling
        #initializer = jax.nn.initializers.glorot_uniform() # this is wrong
        # relu adjust JAX scale value
        scale_ = 6.0
        initializer_ = jax.nn.initializers.variance_scaling(scale=scale_, mode='fan_avg', distribution='uniform')
        # more here: https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.initializers.variance_scaling.html#jax.nn.initializers.variance_scaling
        # now apply glorot uniform initialization to the weights only
        for l in self.layers:
            l.nn.weight.set(initializer_(px.RKG(), l.nn.weight.shape))

    @staticmethod
    def name():
        return "two_layer_nn"

    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = self.act_fn(layer(x))

        x = self.layers[-1](x)

        return x

In [4]:
def ce_loss(output, one_hot_label):
    return -one_hot_label * jax.nn.log_softmax(output)


@pxf.vmap({"model": None}, in_axes=0, out_axes=0)
def forward(x, *, model: Model):
    return model(x)


@pxf.vmap({"model": None}, in_axes=(0, 0), out_axes=(None, 0), axis_name="batch")
def loss(x, y, *, model: Model):
    y_ = model(x)
    return jax.lax.pmean(ce_loss(y_, y).sum(), "batch"), y_


@pxf.jit()
def train_on_batch(x: jax.Array, y: jax.Array, *, model: Model, optim_w: pxu.Optim):
    model.train()

    with pxu.step(model):
        (e, y_), g = pxf.value_and_grad(pxu.Mask(pxnn.LayerParam, [False, True]), has_aux=True)(loss)(x, y, model=model)
    optim_w.step(model, g["model"])


def train(dl, *, model: Model, optim_w: pxu.Optim):
    for x, y in dl:
        train_on_batch(x, jax.nn.one_hot(y, 10), model=model, optim_w=optim_w)


@pxf.jit()
def eval_on_batch(x: jax.Array, y: jax.Array, *, model: Model):
    model.eval()

    with pxu.step(model):
        e, y_ = loss(x, jax.nn.one_hot(y, 10), model=model)
        y_ = y_.argmax(axis=-1)

    return (y_ == y).mean(), y_, e


def eval(dl, *, model: Model):
    acc = []
    es = []
    ys_ = []
    for x, y in dl:
        a, y_, e = eval_on_batch(x, y, model=model)
        acc.append(a)
        es.append(e)
        ys_.append(y_)

    return np.mean(acc), np.concatenate(ys_), np.mean(es)

In [5]:
batch_size = 128
nm_epochs = 4000
model = Model(input_dim=784, hidden_dim=128, output_dim=10, act_fn=jax.nn.relu)

# Assuming dataset is the namedtuple with the dataloaders
dataset = get_dataloaders("mnist", train_subset_size=4000, batch_size=batch_size, noise_level=0.0)
# Check the sizes of the datasets
print(f"Training set: {len(dataset.train_loader.sampler)} samples")
print(f"Validation set: {len(dataset.val_loader.sampler)} samples")
print(f"Test set: {len(dataset.test_loader.dataset)} samples")

with pxu.step(model):
    optim_w = pxu.Optim(optax.sgd(1e-2, momentum=0.95), pxu.Mask(pxnn.LayerParam)(model))

Training set: 4000 samples
Validation set: 800 samples
Test set: 10000 samples


In [10]:
# show number of classes in the dataset by accessing countingn the unique labels
print(len(np.unique(dataset.train_loader.dataset.targets)))
# compute input dimension of the model
input_dim = dataset.train_loader.dataset[0][0].shape[0]
print(input_dim)
# show the name of the dataset
print(type(dataset.train_loader.dataset).__name__)
# print name of the model
print(model.name())
# now compute the total number of parameters in the model
print(model)

10
784
MNIST
two_layer_nn
(Model):
  .layers[0].nn.weight: LayerParam([128,784], float32)
  .layers[0].nn.bias: LayerParam([128], float32)
  .layers[1].nn.weight: LayerParam([10,128], float32)
  .layers[1].nn.bias: LayerParam([10], float32)


In [17]:
# Number of elements in the first layer's bias (equals hidden_dim).
model.layers[0].nn.bias.size

128

In [20]:
# PyTorch's recommended init gain for ReLU (sqrt(2)); `torch` is already
# imported at the top of the notebook.
# To match torch.nn.init.xavier_uniform_(w, gain=g) with JAX's
# variance_scaling(scale, mode='fan_avg', distribution='uniform'), use
# scale = g ** 2, i.e. 2.0 for ReLU.
torch.nn.init.calculate_gain("relu")

1.4142135623730951

In [None]:
# Per-epoch learning curves for every split. They are kept for plotting and
# saved to disk by the next cell.
train_losses, val_losses, test_losses = [], [], []
train_accuracies, val_accuracies, test_accuracies = [], [], []

for epoch in range(nm_epochs):
    # One pass of weight updates over the training subset.
    train(dataset.train_loader, model=model, optim_w=optim_w)

    # Re-evaluate every split with the freshly updated weights.
    a_train, ys_train, e_train = eval(dataset.train_loader, model=model)
    a_val, ys_val, e_val = eval(dataset.val_loader, model=model)
    a_test, ys_test, e_test = eval(dataset.test_loader, model=model)

    train_losses.append(e_train)
    val_losses.append(e_val)
    test_losses.append(e_test)
    train_accuracies.append(a_train)
    val_accuracies.append(a_val)
    test_accuracies.append(a_test)

    # Report accuracies first, then losses.
    print(f"Epoch {epoch + 1}/{nm_epochs} - Training Accuracy: {a_train * 100:.2f}% - Validation Accuracy: {a_val * 100:.2f}% - Test Accuracy: {a_test * 100:.2f}%")
    print(f"Epoch {epoch + 1}/{nm_epochs} - Training Loss: {e_train:.4f} - Validation Loss: {e_val:.4f} - Test Loss: {e_test:.4f}")

In [None]:
# Persist the learning curves, so the plots below can be regenerated without
# re-training.
np.savez(
    "mnist_bp_results.npz",
    train_losses=train_losses,
    val_losses=val_losses,
    test_losses=test_losses,
    train_accuracies=train_accuracies,
    val_accuracies=val_accuracies,
    test_accuracies=test_accuracies,
)

In [None]:
# Reload the saved learning curves. The `with` block closes the underlying
# .npz file handle. Indexing the NpzFile reads each array into memory, so the
# names below stay valid after the block exits.
with np.load("mnist_bp_results.npz") as results:
    train_losses = results["train_losses"]
    val_losses = results["val_losses"]
    test_losses = results["test_losses"]
    train_accuracies = results["train_accuracies"]
    val_accuracies = results["val_accuracies"]
    test_accuracies = results["test_accuracies"]

In [None]:
# Plot the learning curves. Validation curves are hidden by default; set
# SHOW_VAL_CURVES = True to include them. The epoch axis is derived from each
# curve's length instead of reassigning the global `nm_epochs`, so re-running
# the training cell afterwards still uses the configured epoch count.
SHOW_VAL_CURVES = False


def _plot_curves(curves, ylabel, title):
    """Draw one line per ``(label, values)`` pair against 1-based epoch numbers."""
    fig, ax = plt.subplots(figsize=(24, 12))
    for label, values in curves:
        ax.plot(range(1, len(values) + 1), values, label=label)
    ax.set_xlabel('Epochs')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)
    plt.show()


loss_curves = [('Training Loss', train_losses), ('Test Loss', test_losses)]
accuracy_curves = [('Training Accuracy', train_accuracies), ('Test Accuracy', test_accuracies)]
if SHOW_VAL_CURVES:
    loss_curves.insert(1, ('Validation Loss', val_losses))
    accuracy_curves.insert(1, ('Validation Accuracy', val_accuracies))

_plot_curves(loss_curves, 'Loss', 'BP: Loss vs. Epochs')
_plot_curves(accuracy_curves, 'Accuracy', 'BP: Accuracy vs. Epochs')

In [None]:
# Display the model's parameter summary (parameter paths, shapes and dtypes).
print(model)

In [None]:
# Scratch cell: plot one period of sin(x) sampled at 100 evenly spaced points.
x = np.linspace(0.0, 2.0 * np.pi, num=100)
y = np.sin(x)
fig, ax = plt.subplots()
ax.plot(x, y)
plt.show()