# MPS and Canonical Form - MNIST

In [1]:
import time
from functools import partial

import torch
import torch.nn as nn
from torchvision import transforms, datasets

import tensorkrowch as tk

In [2]:
# Miscellaneous initialization: fix the global torch RNG so runs are repeatable
torch.manual_seed(0)

# Training parameters
# -------------------
# Dataset sizes and batching
num_train, num_test = 60000, 10000
batch_size = 500
image_size = (28, 28)

# Optimization schedule
num_epochs = 10
num_epochs_canonical = 3
learn_rate = 1e-4
l2_reg = 0.0

# MPS dimensions: `d_phys` features per pixel (must match the 3 features stacked
# by `embedding`) and bond dimension `d_bond`
d_phys = 3
d_bond = 10

In [3]:
# Run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One site per pixel plus one extra site; presumably that extra site carries the
# `n_labels` output index (confirm against tensorkrowch's MPSLayer docs)
mps = tk.MPSLayer(
    n_sites=image_size[0] * image_size[1] + 1,
    d_phys=d_phys,
    n_labels=10,
    d_bond=d_bond,
).to(device)

In [4]:
# Before starting training, configure tensorkrowch's memory options
# (enable `auto_stack`, disable `auto_unbind`) and trace the network once.
mps.auto_stack = True
mps.auto_unbind = False
# Dummy input laid out as (n_pixels, batch=1, d_phys), the same layout the
# training loop produces with `.permute(2, 0, 1)`; created directly on `device`.
mps.trace(torch.zeros(image_size[0] * image_size[1], 1, d_phys, device=device))

In [5]:
# Cross-entropy over the model's class scores, optimized with Adam.
# `weight_decay=l2_reg` applies L2 regularization (disabled when l2_reg == 0).
loss_fun = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    mps.parameters(),
    lr=learn_rate,
    weight_decay=l2_reg,
)

In [6]:
def embedding(image: torch.Tensor) -> torch.Tensor:
    """Map every pixel value ``x`` to the 3-feature vector ``(1, x, 1 - x)``.

    The features are stacked along a new dimension inserted at position 1,
    so an input of shape ``(C, H, W)`` becomes ``(C, 3, H, W)``.
    """
    features = (torch.ones_like(image), image, 1 - image)
    return torch.stack(features, dim=1)

# Image pipeline: resize to `image_size`, convert to a float tensor, then expand
# every pixel into the 3 features produced by `embedding`.
transform = transforms.Compose([transforms.Resize(image_size),
                                transforms.ToTensor(),
                                transforms.Lambda(embedding)])

# MNIST train/test splits, downloaded into ./data on first use
train_set = datasets.MNIST('./data', download=True, transform=transform)
test_set = datasets.MNIST('./data', download=True, transform=transform,
                          train=False)

In [7]:
# Put MNIST data into dataloaders. Each split is shuffled by a random sampler over
# its first `num_*` indices; `drop_last=True` discards any incomplete final batch,
# so every batch holds exactly `batch_size` images.
split_sizes = {"train": num_train, "test": num_test}
split_datasets = {"train": train_set, "test": test_set}

samplers = {}
loaders = {}
num_batches = {}
for name, size in split_sizes.items():
    samplers[name] = torch.utils.data.SubsetRandomSampler(range(size))
    loaders[name] = torch.utils.data.DataLoader(
        split_datasets[name],
        batch_size=batch_size,
        sampler=samplers[name],
        drop_last=True,
    )
    # Number of full batches per pass (consistent with drop_last=True)
    num_batches[name] = size // batch_size

# Summarize the run configuration
print(
    f"Training on {num_train} MNIST images \n"
    f"(testing on {num_test}) for {num_epochs} epochs"
)
print(f"Using Adam w/ learning rate = {learn_rate:.1e}")
if l2_reg > 0:
    print(f" * L2 regularization = {l2_reg:.2e}")
print()

Training on 60000 MNIST images 
(testing on 10000) for 10 epochs
Using Adam w/ learning rate = 1.0e-04



In [8]:
# Train the MPS classifier for `num_epochs` epochs. After each epoch, report the
# mean training loss, the training accuracy and the test accuracy.
for epoch_num in range(1, num_epochs + 1):
    running_train_loss = 0.0  # sum of per-batch mean losses (Python float)
    n_train_batches = 0
    train_correct = 0
    train_seen = 0

    for inputs, labels in loaders["train"]:
        # Flatten each embedded image to (batch, d_phys, n_pixels), then reorder to
        # (n_pixels, batch, d_phys), the same layout used for `mps.trace`.
        # `inputs.size(0)` (rather than `batch_size`) keeps a short final batch
        # working should `drop_last` ever be disabled.
        inputs = inputs.view(
            inputs.size(0), d_phys, image_size[0] * image_size[1]).permute(2, 0, 1)
        inputs, labels = inputs.to(device), labels.to(device)

        scores = mps(inputs)
        loss = loss_fun(scores, labels)

        # Accumulate plain Python numbers (not device tensors) for reporting
        with torch.no_grad():
            preds = scores.argmax(dim=1)
            train_correct += (preds == labels).sum().item()
            train_seen += labels.size(0)
            running_train_loss += loss.item()
            n_train_batches += 1

        # Backpropagate and update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate on the test set without building autograd graphs
    with torch.no_grad():
        test_correct = 0
        test_seen = 0

        for inputs, labels in loaders["test"]:
            inputs = inputs.view(
                inputs.size(0), d_phys, image_size[0] * image_size[1]).permute(2, 0, 1)
            inputs, labels = inputs.to(device), labels.to(device)

            # Call our MPS to get logit scores and predictions
            preds = mps(inputs).argmax(dim=1)
            test_correct += (preds == labels).sum().item()
            test_seen += labels.size(0)

    print(f'* Epoch {epoch_num}: '
          f'Train. Loss: {running_train_loss / n_train_batches:.4f}, '
          f'Train. Acc.: {train_correct / train_seen:.4f}, '
          f'Test Acc.: {test_correct / test_seen:.4f}')

* Epoch 1: Train. Loss: 0.9456, Train. Acc.: 0.6753, Test Acc.: 0.8924
* Epoch 2: Train. Loss: 0.2921, Train. Acc.: 0.9121, Test Acc.: 0.9360
* Epoch 3: Train. Loss: 0.2066, Train. Acc.: 0.9378, Test Acc.: 0.9443
* Epoch 4: Train. Loss: 0.1642, Train. Acc.: 0.9502, Test Acc.: 0.9595
* Epoch 5: Train. Loss: 0.1317, Train. Acc.: 0.9601, Test Acc.: 0.9632
* Epoch 6: Train. Loss: 0.1135, Train. Acc.: 0.9654, Test Acc.: 0.9655
* Epoch 7: Train. Loss: 0.1046, Train. Acc.: 0.9687, Test Acc.: 0.9668
* Epoch 8: Train. Loss: 0.0904, Train. Acc.: 0.9720, Test Acc.: 0.9723
* Epoch 9: Train. Loss: 0.0836, Train. Acc.: 0.9740, Test Acc.: 0.9725
* Epoch 10: Train. Loss: 0.0751, Train. Acc.: 0.9764, Test Acc.: 0.9748


In [9]:
# Original number of parameters and their memory footprint (before canonicalization)
params = list(mps.parameters())
n_params = sum(p.nelement() for p in params)
memory = sum(p.nelement() * p.element_size() for p in params)  # Bytes
print(f'Nº params:     {n_params}')
print(f'Memory module: {memory / 1024**2:.4f} MB')  # MegaBytes

Nº params:     235660
Memory module: 0.8990 MB


In [10]:
# Canonicalize SVD
# ----------------
# Bring the trained MPS to canonical form via SVD, truncating with
# `cum_percentage=0.98` (see tensorkrowch's `canonicalize` docs for the exact
# truncation criterion). Bond dimensions can shrink (compare the parameter counts),
# so the network is traced again with a dummy (n_pixels, 1, d_phys) input.
mps.canonicalize(cum_percentage=0.98)
mps.trace(torch.zeros(image_size[0] * image_size[1], 1, d_phys, device=device))

# New test accuracy
with torch.no_grad():
    test_correct = 0
    test_seen = 0

    for inputs, labels in loaders["test"]:
        # (batch, ...) -> (batch, d_phys, n_pixels) -> (n_pixels, batch, d_phys)
        inputs = inputs.view(
            inputs.size(0), d_phys, image_size[0] * image_size[1]).permute(2, 0, 1)
        inputs, labels = inputs.to(device), labels.to(device)

        # Call our MPS to get logit scores and predictions
        preds = mps(inputs).argmax(dim=1)
        test_correct += (preds == labels).sum().item()
        test_seen += labels.size(0)

test_acc = test_correct / test_seen
print(f"Test Acc.: {test_acc:.4f}\n")

# Number of parameters and memory footprint after canonicalization
params = list(mps.parameters())
n_params = sum(p.nelement() for p in params)
memory = sum(p.nelement() * p.element_size() for p in params)  # Bytes
print(f'Nº params:     {n_params}')
print(f'Memory module: {memory / 1024**2:.4f} MB\n')  # MegaBytes

Test Acc.: 0.9196

Nº params:     150710
Memory module: 0.5749 MB



In [11]:
# Continue training the canonicalized (compressed) MPS for `num_epochs_canonical`
# more epochs. A fresh optimizer is created because canonicalization changed the
# parameter tensors (the parameter count differs), so the old Adam state no
# longer matches them.
# NOTE(review): the original header said "obtaining canonical form after each
# epoch", but the MPS is NOT re-canonicalized here. If that was intended, call
# `mps.canonicalize(...)`, re-trace and re-create the optimizer at the end of
# each epoch.
optimizer = torch.optim.Adam(mps.parameters(),
                             lr=learn_rate,
                             weight_decay=l2_reg)

for epoch_num in range(1, num_epochs_canonical + 1):
    running_train_loss = 0.0  # sum of per-batch mean losses (Python float)
    n_train_batches = 0
    train_correct = 0
    train_seen = 0

    for inputs, labels in loaders["train"]:
        # (batch, ...) -> (batch, d_phys, n_pixels) -> (n_pixels, batch, d_phys);
        # `inputs.size(0)` keeps a short final batch working.
        inputs = inputs.view(
            inputs.size(0), d_phys, image_size[0] * image_size[1]).permute(2, 0, 1)
        inputs, labels = inputs.to(device), labels.to(device)

        scores = mps(inputs)
        loss = loss_fun(scores, labels)

        # Accumulate plain Python numbers (not device tensors) for reporting
        with torch.no_grad():
            preds = scores.argmax(dim=1)
            train_correct += (preds == labels).sum().item()
            train_seen += labels.size(0)
            running_train_loss += loss.item()
            n_train_batches += 1

        # Backpropagate and update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate on the test set without building autograd graphs
    with torch.no_grad():
        test_correct = 0
        test_seen = 0

        for inputs, labels in loaders["test"]:
            inputs = inputs.view(
                inputs.size(0), d_phys, image_size[0] * image_size[1]).permute(2, 0, 1)
            inputs, labels = inputs.to(device), labels.to(device)

            # Call our MPS to get logit scores and predictions
            preds = mps(inputs).argmax(dim=1)
            test_correct += (preds == labels).sum().item()
            test_seen += labels.size(0)

    print(f'* Epoch {epoch_num}: '
          f'Train. Loss: {running_train_loss / n_train_batches:.4f}, '
          f'Train. Acc.: {train_correct / train_seen:.4f}, '
          f'Test Acc.: {test_correct / test_seen:.4f}')

* Epoch 1: Train. Loss: 0.1018, Train. Acc.: 0.9684, Test Acc.: 0.9693
* Epoch 2: Train. Loss: 0.0815, Train. Acc.: 0.9746, Test Acc.: 0.9698
* Epoch 3: Train. Loss: 0.0716, Train. Acc.: 0.9778, Test Acc.: 0.9721
