In [60]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import certifi

os.environ["SSL_CERT_FILE"] = certifi.where()

import torchvision.transforms as transforms
import torchvision.datasets as datasets

import matplotlib.pyplot as plt

import tensorkrowch as tk
device = torch.device('cpu')

# set to float64
torch.set_default_dtype(torch.float64)

# Hyperparameters

# Data
dataset_name = 'mnist'
batch_size = 64
image_size = 15
input_size = image_size ** 2


# # Model
# bond_dim = 50
# init_method = 'unit'
# block_length = 2
# in_dim = 2
# cum_percentage = 0.98

In [61]:
# Dataset
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Resize(image_size, antialias=True),
                               ])

train_dataset = datasets.MNIST(root='data/',
                               train=True,
                               transform=transform,
                               download=True)
test_dataset = datasets.MNIST(root='data/',
                              train=False,
                              transform=transform,
                              download=True)

train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True)
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=True)

In [62]:
class MPS_DMRG(tk.models.MPSLayer):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.parameterize(set_param=False, override=True)

        self.out_node.get_axis('input').name = 'output'

        self.block_position = None
        self.block_length = None

    @property
    def block(self):
        if self.block_position is not None:
            return self.mats_env[self.block_position]
        return None

    def merge_block(self, block_position, block_length):
        if block_position + block_length > self.n_features:
            raise ValueError(
                f'Last position of the block ({block_position + block_length}) '
                f'exceeds the range of MPS sites ({self.n_features})')
        elif block_length < 1:
            raise ValueError(
                '`block_length` should be greater than or equal to 1')

        if self.block_position is not None:
            raise ValueError(
                'Cannot create block if there is already a merged block')

        block_nodes = self.mats_env[block_position:(block_position + block_length)]

        block = block_nodes[0]
        for node in block_nodes[1:]:
            block = tk.contract_between_(block, node)
        block = block.parameterize(True)
        block.name = 'block'

        self.block_position = block_position
        self.block_length = block_length
        self._mats_env = self._mats_env[:block_position] + [block] + \
            self._mats_env[(block_position + block_length):]

    def unmerge_block(self, side='right', rank=None, cum_percentage=None):
        block = self.block

        block_nodes = []
        for i in range(self.block_length - 1):
            node1_axes = block.axes[:2]
            node2_axes = block.axes[2:]

            node, block = tk.split_(block,
                                    node1_axes,
                                    node2_axes,
                                    side=side,
                                    rank=rank,
                                    cum_percentage=cum_percentage)
            block.get_axis('split').name = 'left'
            node.get_axis('split').name = 'right'
            node.name = f'mats_env_({self.block_position + i})'

            block_nodes.append(node)

        block.name = f'mats_env_({self.block_position + i + 1})'
        block_nodes.append(block)

        self._mats_env = self._mats_env[:self.block_position] + block_nodes + \
            self._mats_env[(self.block_position + 1):]

        self.block_position = None
        self.block_length = None

    def contract(self):
        result_mats = []
        for node in self.mats_env:
            while any(['input' in name for name in node.axes_names]):
                for axis in node.axes:
                    if 'input' in axis.name:
                        data_node = node.neighbours(axis)
                        node = node @ data_node
                        break
            result_mats.append(node)

        result_mats = [self.left_node] + result_mats + [self.right_node]

        result = result_mats[0]
        for node in result_mats[1:]:
            result @= node

        return result

# Model hyperparameters
embedding_dim = 2
output_dim = 1
bond_dim = 50
init_method = 'unit'
block_length = 2
cum_percentage = 0.98
num_classes = 10
# Initialize network
model_name = 'mps_dmrg'
mps = MPS_DMRG(n_features=input_size,
               in_dim=embedding_dim,
               out_dim=num_classes,
               bond_dim=bond_dim,
               boundary='obc',
               init_method=init_method,
               device=device)

# Important to set data nodes before merging nodes
mps.set_data_nodes()

def embedding(x):
    x = tk.embeddings.unit(x, dim=embedding_dim)
    return x

# def embedding(x):
#     x_bool = x >= 0.5
#     x_ohot = torch.nn.functional.one_hot(x_bool.long(), num_classes=2)
#     return x_ohot.to(torch.get_default_dtype())

# mps.norm()

In [63]:
# Hyperparameters
learning_rate = 1e-3
weight_decay = 1e-8
num_epochs = 100
move_block_epochs = 100

# Loss and optimizer
criterion = nn.CrossEntropyLoss()

# Check accuracy on training & test to see how good our model is
def check_accuracy(loader, model):
    num_correct = 0
    num_samples = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            x = x.reshape(x.shape[0], -1)

            scores = model(embedding(x))
            _, predictions = scores.max(1)
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)

        accuracy = float(num_correct) / float(num_samples) * 100
    model.train()
    return accuracy

In [None]:
from tqdm import tqdm

# Train network
block_position = 0
direction = 1
mps.merge_block(block_position, block_length)
mps.trace(torch.zeros(1, input_size, embedding_dim, device=device))
optimizer = optim.Adam(mps.parameters(),
                       lr=learning_rate,
                       weight_decay=weight_decay)

for epoch in range(num_epochs):
    pbar = tqdm(train_loader, total=len(train_loader), desc=f'Epoch {epoch + 1}')
    for batch_idx, (data, targets) in enumerate(pbar):
        # Get data to cuda if possible
        data = data.to(device)
        targets = targets.to(device)

        # Get to correct shape
        data = data.reshape(data.shape[0], -1)

        # Forward
        scores = mps(embedding(data))
        loss = criterion(scores, targets)

        # # Forward
        # p = mps(embedding(data))
        # log_z = mps.norm(log_scale=True)
        # # loss = (log_z - p.log()).sum()
        # loss  =

        # Backward
        optimizer.zero_grad()
        loss.backward()

        pbar.set_postfix(loss=loss.item())

        # Gradient descent
        optimizer.step()

        if (batch_idx + 1) % move_block_epochs == 0:
            if block_position + direction + block_length > mps.n_features:
                direction *= -1
            if block_position + direction < 0:
                direction *= -1
            if block_length == mps.n_features:
                direction = 0

            if direction >= 0:
                mps.unmerge_block(side='left',
                                  rank=bond_dim,
                                  cum_percentage=cum_percentage)
            else:
                mps.unmerge_block(side='right',
                                  rank=bond_dim,
                                  cum_percentage=cum_percentage)

            block_position += direction
            mps.merge_block(block_position, block_length)
            mps.trace(torch.zeros(1, input_size, embedding_dim, device=device))
            optimizer = optim.Adam(mps.parameters(),
                                   lr=learning_rate,
                                   weight_decay=weight_decay)

    train_acc = check_accuracy(train_loader, mps)
    test_acc = check_accuracy(test_loader, mps)

    print(f'* Epoch {epoch + 1:<3} ({block_position=}, {direction=})=>'
          f' Train. Acc.: {train_acc:.2f},'
          f' Test Acc.: {test_acc:.2f}')

# Reset before saving the model
mps.reset()
torch.save(mps.state_dict(), f'models/{model_name}_{dataset_name}.pt')

  return tensor[index]
Epoch 1: 100%|██████████| 938/938 [03:17<00:00,  4.75it/s, loss=2.29]


* Epoch 1   (block_position=9, direction=1)=> Train. Acc.: 11.83, Test Acc.: 11.31


Epoch 2: 100%|██████████| 938/938 [03:27<00:00,  4.52it/s, loss=2.28]


* Epoch 2   (block_position=18, direction=1)=> Train. Acc.: 11.99, Test Acc.: 11.83


Epoch 3: 100%|██████████| 938/938 [03:15<00:00,  4.80it/s, loss=2.31]


* Epoch 3   (block_position=27, direction=1)=> Train. Acc.: 12.41, Test Acc.: 12.29


Epoch 4: 100%|██████████| 938/938 [03:06<00:00,  5.03it/s, loss=2.34]


* Epoch 4   (block_position=36, direction=1)=> Train. Acc.: 12.79, Test Acc.: 11.96


Epoch 5: 100%|██████████| 938/938 [03:13<00:00,  4.86it/s, loss=2.28]


* Epoch 5   (block_position=45, direction=1)=> Train. Acc.: 14.59, Test Acc.: 13.92


Epoch 6: 100%|██████████| 938/938 [03:01<00:00,  5.17it/s, loss=2.26]


* Epoch 6   (block_position=54, direction=1)=> Train. Acc.: 17.15, Test Acc.: 16.06


Epoch 7: 100%|██████████| 938/938 [02:46<00:00,  5.62it/s, loss=2.21]


* Epoch 7   (block_position=63, direction=1)=> Train. Acc.: 18.09, Test Acc.: 16.77


Epoch 8: 100%|██████████| 938/938 [02:53<00:00,  5.42it/s, loss=2.22]


* Epoch 8   (block_position=72, direction=1)=> Train. Acc.: 20.85, Test Acc.: 19.03


Epoch 9: 100%|██████████| 938/938 [03:01<00:00,  5.16it/s, loss=2.15]


* Epoch 9   (block_position=81, direction=1)=> Train. Acc.: 22.78, Test Acc.: 20.51


Epoch 10: 100%|██████████| 938/938 [02:48<00:00,  5.56it/s, loss=2.09]


* Epoch 10  (block_position=90, direction=1)=> Train. Acc.: 24.74, Test Acc.: 21.43


Epoch 11: 100%|██████████| 938/938 [02:49<00:00,  5.52it/s, loss=1.93]


* Epoch 11  (block_position=99, direction=1)=> Train. Acc.: 27.34, Test Acc.: 23.87


Epoch 12: 100%|██████████| 938/938 [02:46<00:00,  5.62it/s, loss=2.13]


* Epoch 12  (block_position=108, direction=1)=> Train. Acc.: 29.07, Test Acc.: 24.83


Epoch 13: 100%|██████████| 938/938 [02:42<00:00,  5.78it/s, loss=1.94]


* Epoch 13  (block_position=117, direction=1)=> Train. Acc.: 37.06, Test Acc.: 32.27


Epoch 14: 100%|██████████| 938/938 [02:39<00:00,  5.87it/s, loss=1.72]


* Epoch 14  (block_position=126, direction=1)=> Train. Acc.: 39.34, Test Acc.: 34.37


Epoch 15: 100%|██████████| 938/938 [02:31<00:00,  6.20it/s, loss=1.45]


* Epoch 15  (block_position=135, direction=1)=> Train. Acc.: 42.36, Test Acc.: 37.03


Epoch 16: 100%|██████████| 938/938 [02:30<00:00,  6.24it/s, loss=1.83]


* Epoch 16  (block_position=144, direction=1)=> Train. Acc.: 44.41, Test Acc.: 38.53


Epoch 17: 100%|██████████| 938/938 [02:31<00:00,  6.21it/s, loss=1.79]


* Epoch 17  (block_position=153, direction=1)=> Train. Acc.: 45.55, Test Acc.: 39.26


Epoch 18: 100%|██████████| 938/938 [02:26<00:00,  6.42it/s, loss=1.32]


* Epoch 18  (block_position=162, direction=1)=> Train. Acc.: 47.09, Test Acc.: 40.54


Epoch 19: 100%|██████████| 938/938 [02:21<00:00,  6.62it/s, loss=1.46]


* Epoch 19  (block_position=171, direction=1)=> Train. Acc.: 47.43, Test Acc.: 40.93


Epoch 20: 100%|██████████| 938/938 [02:17<00:00,  6.82it/s, loss=1.28]


* Epoch 20  (block_position=180, direction=1)=> Train. Acc.: 47.73, Test Acc.: 40.75


Epoch 21: 100%|██████████| 938/938 [02:16<00:00,  6.85it/s, loss=1.73]


* Epoch 21  (block_position=189, direction=1)=> Train. Acc.: 47.67, Test Acc.: 40.95


Epoch 22: 100%|██████████| 938/938 [02:15<00:00,  6.94it/s, loss=1.32]


* Epoch 22  (block_position=198, direction=1)=> Train. Acc.: 47.65, Test Acc.: 40.93


Epoch 23: 100%|██████████| 938/938 [02:13<00:00,  7.04it/s, loss=1.33]


* Epoch 23  (block_position=207, direction=1)=> Train. Acc.: 47.88, Test Acc.: 40.64


Epoch 24: 100%|██████████| 938/938 [02:06<00:00,  7.42it/s, loss=1.78]


* Epoch 24  (block_position=216, direction=1)=> Train. Acc.: 47.91, Test Acc.: 40.80


Epoch 25: 100%|██████████| 938/938 [01:53<00:00,  8.29it/s, loss=1.7] 


* Epoch 25  (block_position=221, direction=-1)=> Train. Acc.: 47.92, Test Acc.: 41.19


Epoch 26: 100%|██████████| 938/938 [01:48<00:00,  8.65it/s, loss=1.73]


* Epoch 26  (block_position=212, direction=-1)=> Train. Acc.: 47.65, Test Acc.: 40.86


Epoch 27: 100%|██████████| 938/938 [01:51<00:00,  8.43it/s, loss=1.38]


* Epoch 27  (block_position=203, direction=-1)=> Train. Acc.: 47.80, Test Acc.: 40.75


Epoch 28: 100%|██████████| 938/938 [01:53<00:00,  8.25it/s, loss=2.02] 


* Epoch 28  (block_position=194, direction=-1)=> Train. Acc.: 48.16, Test Acc.: 41.28


Epoch 29: 100%|██████████| 938/938 [02:07<00:00,  7.37it/s, loss=1.62]


* Epoch 29  (block_position=185, direction=-1)=> Train. Acc.: 47.98, Test Acc.: 40.90


Epoch 30: 100%|██████████| 938/938 [02:19<00:00,  6.74it/s, loss=1.68]


* Epoch 30  (block_position=176, direction=-1)=> Train. Acc.: 48.17, Test Acc.: 41.01


Epoch 31: 100%|██████████| 938/938 [02:07<00:00,  7.33it/s, loss=1.5] 


* Epoch 31  (block_position=167, direction=-1)=> Train. Acc.: 48.78, Test Acc.: 40.92


Epoch 32: 100%|██████████| 938/938 [02:23<00:00,  6.52it/s, loss=1.49]


* Epoch 32  (block_position=158, direction=-1)=> Train. Acc.: 49.01, Test Acc.: 40.88


Epoch 33: 100%|██████████| 938/938 [02:26<00:00,  6.39it/s, loss=1.47]


* Epoch 33  (block_position=149, direction=-1)=> Train. Acc.: 49.68, Test Acc.: 41.32


Epoch 34: 100%|██████████| 938/938 [02:31<00:00,  6.20it/s, loss=1.35]


* Epoch 34  (block_position=140, direction=-1)=> Train. Acc.: 50.09, Test Acc.: 41.11


Epoch 35: 100%|██████████| 938/938 [02:36<00:00,  5.98it/s, loss=1.45]


* Epoch 35  (block_position=131, direction=-1)=> Train. Acc.: 50.81, Test Acc.: 41.41


Epoch 36: 100%|██████████| 938/938 [02:38<00:00,  5.92it/s, loss=1.58]


* Epoch 36  (block_position=122, direction=-1)=> Train. Acc.: 52.07, Test Acc.: 41.96


Epoch 37: 100%|██████████| 938/938 [02:43<00:00,  5.75it/s, loss=1.7] 


* Epoch 37  (block_position=113, direction=-1)=> Train. Acc.: 53.47, Test Acc.: 43.17


Epoch 38: 100%|██████████| 938/938 [02:47<00:00,  5.61it/s, loss=1.3]  


* Epoch 38  (block_position=104, direction=-1)=> Train. Acc.: 58.95, Test Acc.: 48.38


Epoch 39: 100%|██████████| 938/938 [02:46<00:00,  5.64it/s, loss=1.03] 


* Epoch 39  (block_position=95, direction=-1)=> Train. Acc.: 61.00, Test Acc.: 50.63


Epoch 40: 100%|██████████| 938/938 [02:45<00:00,  5.67it/s, loss=0.893]


* Epoch 40  (block_position=86, direction=-1)=> Train. Acc.: 62.89, Test Acc.: 52.28


Epoch 41: 100%|██████████| 938/938 [02:51<00:00,  5.46it/s, loss=1.25] 


* Epoch 41  (block_position=77, direction=-1)=> Train. Acc.: 65.61, Test Acc.: 55.07


Epoch 42: 100%|██████████| 938/938 [02:55<00:00,  5.34it/s, loss=1.37] 


* Epoch 42  (block_position=68, direction=-1)=> Train. Acc.: 67.15, Test Acc.: 56.84


Epoch 43: 100%|██████████| 938/938 [03:01<00:00,  5.17it/s, loss=0.789]


* Epoch 43  (block_position=59, direction=-1)=> Train. Acc.: 68.50, Test Acc.: 58.19


Epoch 44: 100%|██████████| 938/938 [03:02<00:00,  5.15it/s, loss=1.08] 


* Epoch 44  (block_position=50, direction=-1)=> Train. Acc.: 68.68, Test Acc.: 58.54


Epoch 45: 100%|██████████| 938/938 [03:12<00:00,  4.89it/s, loss=0.911]


* Epoch 45  (block_position=41, direction=-1)=> Train. Acc.: 69.30, Test Acc.: 59.21


Epoch 46: 100%|██████████| 938/938 [03:11<00:00,  4.90it/s, loss=1.46] 


* Epoch 46  (block_position=32, direction=-1)=> Train. Acc.: 69.43, Test Acc.: 58.79


Epoch 47: 100%|██████████| 938/938 [03:17<00:00,  4.75it/s, loss=0.877]


* Epoch 47  (block_position=23, direction=-1)=> Train. Acc.: 69.29, Test Acc.: 58.96


Epoch 48: 100%|██████████| 938/938 [03:10<00:00,  4.92it/s, loss=1.03] 


* Epoch 48  (block_position=14, direction=-1)=> Train. Acc.: 69.34, Test Acc.: 58.73


Epoch 49:  95%|█████████▌| 895/938 [03:05<00:09,  4.70it/s, loss=0.769]

In [48]:
(data, targets)  = next((iter(train_loader)))

In [49]:
data = data.to(device)
data = data.reshape(data.shape[0], -1)
p = mps(embedding(data))
torch.isnan(p.log()).sum()

tensor(24)

IndexError: Node "mats_env_node_(112)" has no axis with name "input"