In [1]:
import os
import certifi

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# BACKEND
# Point Python's SSL layer at certifi's CA bundle — presumably so the HTTPS
# MNIST download below can verify certificates on this machine.
os.environ["SSL_CERT_FILE"] = certifi.where()
# Every floating-point tensor created without an explicit dtype (e.g. the MPS
# cores from torch.randn) will be float64.
torch.set_default_dtype(torch.float64)

# Device used by every later cell (model cores, batches, accuracy checks).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's global RNG so core initialisation and DataLoader shuffling are
# reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# HYPERPARAMETERS

# Data
dataset_name = 'mnist'
batch_size = 64
image_size = 15               # images are resized to image_size x image_size
input_size = image_size ** 2  # number of pixels in one flattened image

In [2]:
# Dataset
# ToTensor turns each PIL digit into a (1, H, W) tensor; Resize then shrinks it
# to (1, image_size, image_size).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(image_size, antialias=True),
])

DATA_ROOT = 'data/'
train_dataset = datasets.MNIST(root=DATA_ROOT,
                               train=True,
                               transform=transform,
                               download=True)
test_dataset = datasets.MNIST(root=DATA_ROOT,
                              train=False,
                              transform=transform,
                              download=True)

train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True)
# Evaluation does not need shuffling: a fixed order keeps test passes
# deterministic and does not draw from torch's global RNG.
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=False)

In [None]:
class MPS(torch.nn.Module):
    """Matrix Product State (tensor train) trained with two-site block updates.

    Every core ``self.g[i]`` has axis order ``(left_bond, physical, right_bond)``.
    This is the layout that ``merge_block`` contracts over (``'pdq, qvr -> pdvr'``)
    and that ``unmerge_block`` reshapes back into. The outer bonds of the first
    and last cores have size 1.

    Args:
        n_cores: Number of cores (sites); must be at least 2.
        physical_dim: Size of every core's physical axis.
        bond_dim: Initial size of every internal bond axis.
        device: Device for the cores. ``None`` uses a notebook/module-level
            ``device`` variable if one is defined, else torch's default device.
    """

    def __init__(
        self,
        n_cores: int,
        physical_dim: int,
        bond_dim: int,
        device=None,
    ):
        super().__init__()
        if n_cores < 2:
            # With fewer than two cores, the boundary-core construction below
            # would silently build two cores anyway.
            raise ValueError(f'n_cores must be at least 2, got {n_cores}')
        if device is None:
            device = globals().get('device')
        self.n_cores = n_cores
        self.physical_dim = physical_dim
        self.bond_dim = bond_dim
        # Two-site blocks currently merged, keyed by block position.
        # NOTE: this is a plain dict, so these blocks are not part of self.parameters().
        # Pass them to the optimizer explicitly (e.g. list(self.merged.values())).
        self.merged = dict()

        # (left_bond, physical, right_bond) for every core; the boundary cores
        # carry a dummy bond of size 1 on their outer side.
        shapes = (
            [(1, physical_dim, bond_dim)]
            + [(bond_dim, physical_dim, bond_dim)] * (n_cores - 2)
            + [(bond_dim, physical_dim, 1)]
        )
        self.g = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.randn(*shape, device=device)) for shape in shapes]
        )

    def merge_block(self, block_position: int):
        """Contract cores ``block_position`` and ``block_position + 1`` into one block.

        The block has shape ``(left_bond, physical, physical, right_bond)``. It is
        stored in ``self.merged[block_position]`` as a fresh leaf ``nn.Parameter``
        that is detached from the cores, so optimisation acts on the block itself
        until ``unmerge_block`` splits it back.

        Args:
            block_position (int): Index of the left core of the pair, in
                ``[0, n_cores - 1)``.
        """
        # position range: 0: (0, 1), 1: (1, 2), ..., n_cores - 2: (n_cores - 2, n_cores - 1)
        assert 0 <= block_position < self.n_cores - 1, f'Block position {block_position} is out of range [0, {self.n_cores - 1})'
        g_tilde = torch.einsum(
            'pdq, qvr -> pdvr', self.g[block_position], self.g[block_position + 1]
        )
        # detach() returns a NEW tensor, so its result must be kept; wrapping it
        # in a Parameter makes it a trainable leaf independent of self.g.
        self.merged[block_position] = torch.nn.Parameter(g_tilde.detach())

    def unmerge_block(
        self,
        block_position: int,
        cum_percentage: float,
        side: str = 'right',
        rank: 'int | None' = None,
    ):
        """Unmerge a block at a given position.

        Pops ``self.merged[block_position]`` (shape ``(p, d, v, r)``) and splits
        it with a thin SVD of its ``(p*d, v*r)`` matricisation. The spectrum is
        truncated, and the two factors are written back as new trainable cores:
        ``self.g[block_position]`` with shape ``(p, d, k)`` and
        ``self.g[block_position + 1]`` with shape ``(k, v, r)``. The core
        Parameters are replaced, so any optimizer must be rebuilt afterwards.

        Args:
            block_position (int): Position of the block to unmerge.
            cum_percentage (float): Cumulative fraction (in ``[0, 1]``, e.g.
                0.99) of the total sum of singular values to keep. The smallest
                ``k`` whose leading singular values reach this fraction is kept,
                and at least one singular value is always kept.
            side (str, optional):
                Indicates the side to which the diagonal matrix should be contracted.
                If "left", the first resultant node's tensor will be U @ S and the other node's tensor will be V^T.
                If "right", their tensors will be U and S @ V^T, respectively.
            rank (int, optional): Upper bound on the number of singular values
                kept, i.e. on the new bond dimension. ``None`` means no cap.

        Raises:
            ValueError: If ``side`` is neither "left" nor "right". The merged
                block is left in place in that case.
            KeyError: If no block is merged at ``block_position``.
        """
        # position range: 0: (0, 1), 1: (1, 2), ..., n_cores - 2: (n_cores - 2, n_cores - 1)
        assert 0 <= block_position < self.n_cores - 1, f'Block position {block_position} is out of range [0, {self.n_cores - 1})'
        if side not in ('left', 'right'):
            raise ValueError(f"side must be 'left' or 'right', got {side!r}")
        merged = self.merged.pop(block_position)
        p, d, v, r = merged.shape

        with torch.no_grad():
            # Thin SVD: u is (p*d, K), s is (K,), vt is (K, v*r), K = min(p*d, v*r).
            u, s, vt = torch.linalg.svd(merged.reshape(p * d, v * r), full_matrices=False)

            # Truncate singular values: count those whose cumulative share is
            # still below the target; the next one reaches it. An all-zero
            # spectrum yields NaN ratios, which compare False -> k = 1.
            s_cumsum = torch.cumsum(s, dim=0)
            k = int((s_cumsum / s_cumsum[-1] < cum_percentage).sum().item()) + 1
            if rank is not None:
                k = min(k, rank)
            k = max(1, min(k, s.shape[0]))
            u, s, vt = u[:, :k], s[:k], vt[:k]

            if side == 'left':
                gl = u * s              # U @ diag(S): scales column j by s[j]
                gr = vt
            else:
                gl = u
                gr = s[:, None] * vt    # diag(S) @ V^T: scales row j by s[j]

            left_core = gl.reshape(p, d, k).contiguous()
            right_core = gr.reshape(k, v, r).contiguous()

        self.g[block_position] = torch.nn.Parameter(left_core)
        self.g[block_position + 1] = torch.nn.Parameter(right_core)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute loss w.r.t active merged core.

        Args:
            x (torch.Tensor): Input tensor of shape (B, N, D)

        Returns:
            loss (torch.Tensor): Loss of shape (1,)

        Raises:
            NotImplementedError: Always; the forward pass is not written yet.
        """
        raise NotImplementedError('MPS.forward is not implemented yet')

In [63]:
# Hyperparameters
learning_rate = 1e-3
weight_decay = 1e-8
num_epochs = 100
# NOTE: despite the name, this counts batches, not epochs. The training loop
# moves the merged block every `move_block_epochs` optimizer steps.
move_block_epochs = 100

# Loss (the optimizer itself is created in the training cell)
criterion = nn.CrossEntropyLoss()


# Check accuracy on training & test to see how good our model is
def check_accuracy(loader, model):
    """Return the classification accuracy of ``model`` on ``loader``, in percent.

    Each batch ``(x, y)`` is processed as follows:
    - moved to the notebook-level ``device``;
    - ``x`` is flattened to ``(batch, -1)``, passed through the notebook-level
      ``embedding`` feature map and then through ``model``;
    - the prediction is the argmax of the scores over dim 1.

    The model runs in eval mode without gradients. Its previous train/eval mode
    is restored afterwards, even if an exception is raised.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    was_training = model.training
    num_correct = 0
    num_samples = 0
    model.eval()

    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device)
                x = x.reshape(x.shape[0], -1)

                scores = model(embedding(x))
                _, predictions = scores.max(1)
                num_correct += int((predictions == y).sum().item())
                num_samples += predictions.size(0)
    finally:
        model.train(was_training)

    if num_samples == 0:
        raise ValueError('loader yielded no samples')
    return num_correct / num_samples * 100

In [None]:
from tqdm import tqdm

# Train network
# Sweep schedule: train one merged block of `block_length` cores, split it back
# into cores, then move it one position in `direction`, bouncing at the ends.
# NOTE(review): this cell does not match the `MPS` class defined earlier in the
# notebook. It calls merge_block(pos, block_length) with two arguments and
# unmerge_block(...) without a position but with `rank=`. It also uses
# mps.trace / mps.reset / mps.n_features, which that class does not define.
# NOTE(review): it relies on names never defined in this notebook: mps,
# block_length, embedding, embedding_dim, bond_dim, cum_percentage, model_name.
# The error recorded at the end of the notebook mentions graph "Node" objects,
# so it was presumably last run against a different MPS implementation.
# TODO reconcile before running.
block_position = 0
direction = 1  # +1 sweeps right, -1 sweeps left, 0 when the block spans every feature
mps.merge_block(block_position, block_length)
# Presumably traces/caches the contraction once with a dummy batch of shape
# (1, input_size, embedding_dim) — confirm against the MPS implementation in use.
mps.trace(torch.zeros(1, input_size, embedding_dim, device=device))
# Rebuilt after every block move below, presumably because merge/unmerge
# replace the trainable tensors.
optimizer = optim.Adam(mps.parameters(),
                       lr=learning_rate,
                       weight_decay=weight_decay)

for epoch in range(num_epochs):
    pbar = tqdm(train_loader, total=len(train_loader), desc=f'Epoch {epoch + 1}')
    for batch_idx, (data, targets) in enumerate(pbar):
        # Move the batch to the configured device
        data = data.to(device)
        targets = targets.to(device)

        # Flatten images to (batch, pixels)
        data = data.reshape(data.shape[0], -1)

        # Forward
        scores = mps(embedding(data))
        loss = criterion(scores, targets)

        # TODO(review): unfinished alternative objective; it looks like a negative
        # log-likelihood of the normalised MPS (log Z - log p). Finish or delete.
        # # Forward
        # p = mps(embedding(data))
        # log_z = mps.norm(log_scale=True)
        # # loss = (log_z - p.log()).sum()
        # loss  =

        # Backward
        optimizer.zero_grad()
        loss.backward()

        pbar.set_postfix(loss=loss.item())

        # Gradient descent
        optimizer.step()

        # Every `move_block_epochs` batches (not epochs), move the merged block.
        if (batch_idx + 1) % move_block_epochs == 0:
            # Reverse at the right end (the block would run past the last feature)...
            if block_position + direction + block_length > mps.n_features:
                direction *= -1
            # ...and at the left end.
            if block_position + direction < 0:
                direction *= -1
            # A block that already spans every feature cannot move.
            if block_length == mps.n_features:
                direction = 0

            # Split the trained block back into cores; which side receives the
            # singular values depends on the sweep direction.
            if direction >= 0:
                mps.unmerge_block(side='left',
                                  rank=bond_dim,
                                  cum_percentage=cum_percentage)
            else:
                mps.unmerge_block(side='right',
                                  rank=bond_dim,
                                  cum_percentage=cum_percentage)

            block_position += direction
            mps.merge_block(block_position, block_length)
            mps.trace(torch.zeros(1, input_size, embedding_dim, device=device))
            # New optimizer for the newly merged block (Adam state is discarded).
            optimizer = optim.Adam(mps.parameters(),
                                   lr=learning_rate,
                                   weight_decay=weight_decay)

    train_acc = check_accuracy(train_loader, mps)
    test_acc = check_accuracy(test_loader, mps)

    print(f'* Epoch {epoch + 1:<3} ({block_position=}, {direction=})=>'
          f' Train. Acc.: {train_acc:.2f},'
          f' Test Acc.: {test_acc:.2f}')

# Reset before saving the model
mps.reset()
# NOTE(review): torch.save raises if the `models/` directory does not exist.
torch.save(mps.state_dict(), f'models/{model_name}_{dataset_name}.pt')

  return tensor[index]
Epoch 1: 100%|██████████| 938/938 [03:17<00:00,  4.75it/s, loss=2.29]


* Epoch 1   (block_position=9, direction=1)=> Train. Acc.: 11.83, Test Acc.: 11.31


Epoch 2:  20%|█▉        | 185/938 [00:37<02:37,  4.78it/s, loss=2.29]

In [48]:
# Debug: grab a single training batch to inspect interactively.
data, targets = next(iter(train_loader))

In [49]:
# Debug: count output entries whose log is NaN (e.g. negative or NaN MPS outputs).
data = data.to(device).reshape(data.shape[0], -1)
p = mps(embedding(data))
p.log().isnan().sum()

tensor(24)

IndexError: Node "mats_env_node_(112)" has no axis with name "input"