In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import math
from ndlinear import NdLinear 

In [None]:
# Data loading & normalization
# ToTensor() scales pixels to [0, 1]; Normalize then maps each channel to [-1, 1].
# The CIFAR-10 images are RGB, so one mean/std is given per channel explicitly
# instead of relying on a length-1 tuple being broadcast over three channels.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Training split: reshuffled on every pass over the loader.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# Test split: fixed order, used only for evaluation.
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)


In [None]:
# Define a CNN-NdLinear model. 
class NdCNN(nn.Module):
    """Small CNN whose convolutional features are transformed by an NdLinear layer.

    Pipeline: conv(3->32) -> ReLU -> 2x2 max-pool -> conv(32->64) -> ReLU ->
    2x2 max-pool -> NdLinear -> flatten -> ReLU -> linear classifier.

    Args:
        input_shape: Per-sample dimensions handed to ``NdLinear`` as its input
            dims. Presumably this must match the conv stack's per-sample output,
            i.e. 64 channels and the input height/width divided by 4 (the
            notebook uses ``(64, 8, 8)``) -- confirm against NdLinear's docs.
        hidden_size: Per-sample output dimensions requested from ``NdLinear``;
            their product is the width of the flattened feature vector.
        num_classes: Number of logits produced by the final linear layer.
            Defaults to 100 to keep the original behaviour; a smaller value is
            sufficient when the dataset has fewer classes (e.g. 10 for CIFAR-10).
    """

    def __init__(self, input_shape, hidden_size, num_classes=100):
        super().__init__()

        # 3x3 convolutions with padding=1 keep the spatial size unchanged.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.ndlinear = NdLinear(input_shape, hidden_size)
        # The classifier consumes the flattened NdLinear output.
        final_dim = math.prod(hidden_size)
        self.fc_out = nn.Linear(final_dim, num_classes)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for a batch of images."""
        # Each conv stage ends in a 2x2 max-pool, halving height and width.
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.ndlinear(x)
        # Collapse every non-batch dimension; flatten() also copes with
        # non-contiguous tensors, which view() would reject.
        x = x.flatten(1)
        x = self.fc_out(self.relu(x))
        return x

In [None]:
# Set the device: prefer a CUDA GPU, then Apple's MPS backend, otherwise the CPU.
# torch.backends.mps is looked up defensively so that a PyTorch build without
# that attribute does not raise AttributeError here.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    compute_device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    compute_device = torch.device("mps")
else:
    # Explain why MPS could not be used before falling back to the CPU.
    if _mps_backend is None or not _mps_backend.is_built():
        print("MPS not available because the current PyTorch install was not "
                "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
                "and/or you do not have an MPS-enabled device on this machine.")
    compute_device = torch.device("cpu")

In [None]:
# Build the network, move it to the selected device, then create the loss and optimizer.
nd_cnn = NdCNN(input_shape=(64, 8, 8), hidden_size=(32, 8, 8))
nd_cnn = nd_cnn.to(compute_device)
# Counter-example (incorrect usage): with one-dimensional shapes the layer
# is equivalent to a naive nn.Linear layer.
# nd_cnn = NdCNN((64,), (32,)).to(compute_device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=nd_cnn.parameters(), lr=0.001)

In [None]:
# Training loop. Display loss and accuracy for each epoch.
epochs = 20
ndcnn_loss = []  # mean training loss, one entry per epoch
ndcnn_acc = []   # test-set accuracy in percent, one entry per epoch
# Number of trainable parameters in the model.
params_ndcnn = sum(p.numel() for p in nd_cnn.parameters() if p.requires_grad)
for epoch in range(epochs):
    # ---- training pass ----
    nd_cnn.train()
    running_loss = 0.0
    correct_ndcnn, total = 0, 0
    for images, labels in trainloader:
        images, labels = images.to(compute_device), labels.to(compute_device)
        optimizer.zero_grad()
        outputs_hyper = nd_cnn(images)
        loss_hyper = criterion(outputs_hyper, labels)
        loss_hyper.backward()
        optimizer.step()
        running_loss += loss_hyper.item()
    # Record the epoch average once, after all batches. Appending inside the
    # batch loop would store one partial running average per batch.
    ndcnn_loss.append(running_loss / len(trainloader))

    # ---- evaluation pass on the test set ----
    # eval() puts the model in inference mode; train() at the top of the
    # epoch loop switches it back.
    nd_cnn.eval()
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(compute_device), labels.to(compute_device)
            outputs_hyper = nd_cnn(images)
            # The predicted class is the index of the largest logit.
            predicted_hyper = outputs_hyper.argmax(dim=1)
            correct_ndcnn += (predicted_hyper == labels).sum().item()
            total += labels.size(0)
    ndcnn_acc.append(100 * correct_ndcnn / total)
    print(f"Epoch {epoch+1}/{epochs} - Loss: {ndcnn_loss[-1]:.4f}, Acc: {ndcnn_acc[-1]:.2f}%")