# MPS and Canonical Form - FashionMNIST

In [1]:
import time
from functools import partial

import torch
import torch.nn as nn
from torchvision import transforms, datasets

import tensorkrowch as tk

In [2]:
# Reproducibility: seed PyTorch's global random number generator
torch.manual_seed(0)

# Data and optimization settings
num_train, num_test = 60_000, 10_000   # samples drawn from the train / test splits
batch_size = 500
image_size = (28, 28)                  # target size handed to transforms.Resize
num_epochs, num_epochs_canonical = 30, 3  # epochs before / after canonicalization
learn_rate = 1e-4
l2_reg = 0.0                           # forwarded to Adam as weight_decay

# MPS dimensions
d_phys, d_bond = 2, 10

In [3]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# One site per pixel plus one extra site (presumably the output/label node
# -- TODO confirm against the tensorkrowch docs); 10 labels, one per class
mps = tk.MPSLayer(
    n_sites=image_size[0] * image_size[1] + 1,
    d_phys=d_phys,
    n_labels=10,
    d_bond=d_bond,
).to(device)

In [4]:
# Choose the memory-management modes before training (auto_stack on,
# auto_unbind off), then trace the model once with an all-zeros input of
# shape (n_pixels, 1, d_phys) created directly on the target device
mps.auto_stack = True
mps.auto_unbind = False
mps.trace(
    torch.zeros(image_size[0] * image_size[1], 1, d_phys, device=device)
)

In [5]:
# Cross-entropy on the class scores, minimized with Adam over the MPS parameters
loss_fun = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    mps.parameters(), lr=learn_rate, weight_decay=l2_reg
)

In [6]:
def embedding(image: torch.Tensor) -> torch.Tensor:
    """Stack three feature maps of ``image`` along a new dimension 1.

    The stacked features are, in order: a tensor of ones shaped like
    ``image``, ``image`` itself, and ``1 - image``.
    """
    constant = torch.ones_like(image)
    complement = 1 - image
    return torch.stack((constant, image, complement), dim=1)

# Preprocessing: resize to `image_size`, convert to a tensor, then apply
# `tk.add_ones` along dim 1 (presumably the [1, x] feature map that matches
# d_phys = 2 -- TODO confirm against the tensorkrowch docs)
transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Lambda(partial(tk.add_ones, dim=1)),
    ]
)

# Both splits are downloaded to ./data on first use and share the same transform
train_set = datasets.FashionMNIST(
    root='./data', train=True, download=True, transform=transform
)
test_set = datasets.FashionMNIST(
    root='./data', train=False, download=True, transform=transform
)

In [7]:
# Wrap the FashionMNIST splits in dataloaders. Each split is restricted to
# its first `num_train` / `num_test` indices, visited in random order.
samplers = dict(
    train=torch.utils.data.SubsetRandomSampler(range(num_train)),
    test=torch.utils.data.SubsetRandomSampler(range(num_test)),
)

# drop_last=True discards any incomplete final batch, so every batch holds
# exactly `batch_size` samples
loaders = {
    "train": torch.utils.data.DataLoader(
        train_set,
        batch_size=batch_size,
        sampler=samplers["train"],
        drop_last=True,
    ),
    "test": torch.utils.data.DataLoader(
        test_set,
        batch_size=batch_size,
        sampler=samplers["test"],
        drop_last=True,
    ),
}

# Number of full batches per split (consistent with drop_last=True)
num_batches = {
    "train": num_train // batch_size,
    "test": num_test // batch_size,
}

# Summary banner for the run
print(
    f"Training on {num_train} MNIST images \n"
    f"(testing on {num_test}) for {num_epochs} epochs"
)
print(f"Using Adam w/ learning rate = {learn_rate:.1e}")
if l2_reg > 0:
    print(f" * L2 regularization = {l2_reg:.2e}")
print()

Training on 60000 MNIST images 
(testing on 10000) for 30 epochs
Using Adam w/ learning rate = 1.0e-04



In [8]:
# Train for `num_epochs` epochs; after each one report the mean training
# loss / accuracy over the epoch and the accuracy on the test loader.
n_pixels = image_size[0] * image_size[1]  # loop-invariant, computed once

for epoch_num in range(1, num_epochs + 1):
    running_train_loss = 0.0
    running_train_acc = 0.0

    for inputs, labels in loaders["train"]:
        # Flatten each sample to (d_phys, n_pixels), then reorder the batch
        # to (n_pixels, batch_size, d_phys) before calling the MPS
        inputs = inputs.view(batch_size, d_phys, n_pixels).permute(2, 0, 1)
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass: class scores and the loss to minimize
        scores = mps(inputs)
        loss = loss_fun(scores, labels)

        # Running metrics are kept as plain Python floats: `.item()` avoids
        # accumulating a device tensor across the whole epoch
        with torch.no_grad():
            preds = scores.argmax(dim=1)
            running_train_acc += (preds == labels).sum().item() / batch_size
            running_train_loss += loss.item()

        # Backpropagate and update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate on the test loader without tracking gradients
    running_test_acc = 0.0
    with torch.no_grad():
        for inputs, labels in loaders["test"]:
            inputs = inputs.view(batch_size, d_phys, n_pixels).permute(2, 0, 1)
            inputs, labels = inputs.to(device), labels.to(device)

            # Predicted class = index of the largest score
            preds = mps(inputs).argmax(dim=1)
            running_test_acc += (preds == labels).sum().item() / batch_size

    print(f'* Epoch {epoch_num}: '
          f'Train. Loss: {running_train_loss / num_batches["train"]:.4f}, '
          f'Train. Acc.: {running_train_acc / num_batches["train"]:.4f}, '
          f'Test Acc.: {running_test_acc / num_batches["test"]:.4f}')

* Epoch 1: Train. Loss: 1.3908, Train. Acc.: 0.4461, Test Acc.: 0.7251
* Epoch 2: Train. Loss: 0.6236, Train. Acc.: 0.7636, Test Acc.: 0.7842
* Epoch 3: Train. Loss: 0.5318, Train. Acc.: 0.8020, Test Acc.: 0.7989
* Epoch 4: Train. Loss: 0.4872, Train. Acc.: 0.8198, Test Acc.: 0.8069
* Epoch 5: Train. Loss: 0.4533, Train. Acc.: 0.8324, Test Acc.: 0.8296
* Epoch 6: Train. Loss: 0.4351, Train. Acc.: 0.8378, Test Acc.: 0.8253
* Epoch 7: Train. Loss: 0.4111, Train. Acc.: 0.8473, Test Acc.: 0.8333
* Epoch 8: Train. Loss: 0.3967, Train. Acc.: 0.8521, Test Acc.: 0.8431
* Epoch 9: Train. Loss: 0.3875, Train. Acc.: 0.8570, Test Acc.: 0.8324
* Epoch 10: Train. Loss: 0.3825, Train. Acc.: 0.8572, Test Acc.: 0.8461
* Epoch 11: Train. Loss: 0.3754, Train. Acc.: 0.8591, Test Acc.: 0.8472
* Epoch 12: Train. Loss: 0.3680, Train. Acc.: 0.8620, Test Acc.: 0.8452
* Epoch 13: Train. Loss: 0.3640, Train. Acc.: 0.8635, Test Acc.: 0.8521
* Epoch 14: Train. Loss: 0.3559, Train. Acc.: 0.8668, Test Acc.: 0.8561
*

In [9]:
# Original number of parameters, and the memory their storage occupies
n_params = sum(p.nelement() for p in mps.parameters())
memory = sum(p.nelement() * p.element_size() for p in mps.parameters())  # Bytes
print(f'Nº params:     {n_params}')
print(f'Memory module: {memory / 1024**2:.4f} MB')  # MegaBytes

Nº params:     157440
Memory module: 0.6006 MB


In [10]:
# Canonicalize SVD
# ----------------
# Canonicalize the trained MPS with cum_percentage=0.98, then trace it again
# with an all-zeros input of shape (n_pixels, 1, d_phys) on the target device
mps.canonicalize(cum_percentage=0.98)
mps.trace(
    torch.zeros(image_size[0] * image_size[1], 1, d_phys, device=device)
)

# New test accuracy: mean per-batch accuracy of the canonicalized model
# over the test loader, computed without tracking gradients
n_pixels = image_size[0] * image_size[1]  # loop-invariant, computed once
running_acc = 0.0

with torch.no_grad():
    for inputs, labels in loaders["test"]:
        # Flatten each sample to (d_phys, n_pixels), then reorder the batch
        # to (n_pixels, batch_size, d_phys) before calling the MPS
        inputs = inputs.view(batch_size, d_phys, n_pixels).permute(2, 0, 1)
        inputs, labels = inputs.to(device), labels.to(device)

        # Predicted class = index of the largest score
        preds = mps(inputs).argmax(dim=1)
        running_acc += (preds == labels).sum().item() / batch_size

print(f"Test Acc.: {running_acc / num_batches['test']:.4f}\n")

# Number of parameters after canonicalization, and the memory they occupy
n_params = sum(p.nelement() for p in mps.parameters())
memory = sum(p.nelement() * p.element_size() for p in mps.parameters())  # Bytes
print(f'Nº params:     {n_params}')
print(f'Memory module: {memory / 1024**2:.4f} MB\n')  # MegaBytes

Test Acc.: 0.8321

Nº params:     100664
Memory module: 0.3840 MB



In [11]:
# Continue training the canonicalized model for `num_epochs_canonical` epochs.
# A fresh optimizer is built from the current `mps.parameters()` (presumably
# because canonicalization replaced the parameter tensors -- TODO confirm).
# NOTE(review): nothing in this cell re-canonicalizes between epochs; add an
# explicit `mps.canonicalize(...)` step if that was the intent.
optimizer = torch.optim.Adam(mps.parameters(),
                             lr=learn_rate,
                             weight_decay=l2_reg)

n_pixels = image_size[0] * image_size[1]  # loop-invariant, computed once

for epoch_num in range(1, num_epochs_canonical + 1):
    running_train_loss = 0.0
    running_train_acc = 0.0

    for inputs, labels in loaders["train"]:
        # Flatten each sample to (d_phys, n_pixels), then reorder the batch
        # to (n_pixels, batch_size, d_phys) before calling the MPS
        inputs = inputs.view(batch_size, d_phys, n_pixels).permute(2, 0, 1)
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass: class scores and the loss to minimize
        scores = mps(inputs)
        loss = loss_fun(scores, labels)

        # Running metrics are kept as plain Python floats: `.item()` avoids
        # accumulating a device tensor across the whole epoch
        with torch.no_grad():
            preds = scores.argmax(dim=1)
            running_train_acc += (preds == labels).sum().item() / batch_size
            running_train_loss += loss.item()

        # Backpropagate and update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate on the test loader without tracking gradients
    running_test_acc = 0.0
    with torch.no_grad():
        for inputs, labels in loaders["test"]:
            inputs = inputs.view(batch_size, d_phys, n_pixels).permute(2, 0, 1)
            inputs, labels = inputs.to(device), labels.to(device)

            # Predicted class = index of the largest score
            preds = mps(inputs).argmax(dim=1)
            running_test_acc += (preds == labels).sum().item() / batch_size

    print(f'* Epoch {epoch_num}: '
          f'Train. Loss: {running_train_loss / num_batches["train"]:.4f}, '
          f'Train. Acc.: {running_train_acc / num_batches["train"]:.4f}, '
          f'Test Acc.: {running_test_acc / num_batches["test"]:.4f}')

* Epoch 1: Train. Loss: 0.3074, Train. Acc.: 0.8843, Test Acc.: 0.8691
* Epoch 2: Train. Loss: 0.2964, Train. Acc.: 0.8894, Test Acc.: 0.8686
* Epoch 3: Train. Loss: 0.2900, Train. Acc.: 0.8916, Test Acc.: 0.8725
