In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import certifi

os.environ["SSL_CERT_FILE"] = certifi.where()

import torchvision.transforms as transforms
import torchvision.datasets as datasets

import matplotlib.pyplot as plt

import tensorkrowch as tk

In [None]:
# Build a ring of randomly initialised nodes inside a plain TensorNetwork.
# The node-creation and connection loops must actually run: the next cell
# looks up 'node_(99)', which exists only if these loops have executed.
n_sites = 100
mps = tk.TensorNetwork(name='mps')

# One node per site, each with axes ('left', 'input', 'right') of sizes (2, 5, 2)
for i in range(n_sites):
    tk.randn(shape=(2, 5, 2),
             axes_names=('left', 'input', 'right'),
             name=f'node_({i})',
             network=mps)

# Connect every 'right' edge to the 'left' edge of the following node; the
# modulo links the last node back to the first one, closing the ring
for i in range(n_sites):
    mps[f'node_({i})']['right'] ^ mps[f'node_({(i + 1) % n_sites})']['left']

In [5]:
# Shape of the tensor stored in node 'node_(99)'; this lookup assumes a node
# with that name has already been added to `mps`.
mps['node_(99)'].size()

torch.Size([2, 5, 2])

In [20]:
class MPS_DMRG(tk.models.MPSLayer):
    """MPS layer in which a run of neighbouring sites can be merged into a
    single node (the "block") and later split back into individual sites.

    At most one block exists at a time. While a block is merged,
    ``block_position`` holds its index inside ``mats_env`` and
    ``block_length`` the number of original sites it replaced; both are
    ``None`` otherwise.
    """

    def __init__(self, *args, **kwargs):
        """Build the underlying ``MPSLayer`` (all arguments are forwarded
        unchanged) and prepare it for block merging."""
        super().__init__(*args, **kwargs)

        # NOTE(review): presumably turns every node into a non-trainable one so
        # that only the merged block is trained -- confirm against the
        # tensorkrowch `parameterize` documentation.
        self.parameterize(set_param=False, override=True)

        # Rename the 'input' axis of the output node to 'output'; `contract`
        # relies on this to skip that axis when looking for 'input' axes
        self.out_node.get_axis('input').name = 'output'

        # No block is merged initially
        self.block_position = None
        self.block_length = None

    @property
    def block(self):
        """The currently merged block node, or ``None`` if nothing is merged."""
        if self.block_position is not None:
            return self.mats_env[self.block_position]
        return None

    def merge_block(self, block_position, block_length):
        """Contract ``block_length`` consecutive nodes of ``mats_env``,
        starting at ``block_position``, into one parameterized node named
        ``'block'`` that replaces them in ``_mats_env``.

        Parameters
        ----------
        block_position : int
            Index in ``mats_env`` of the first node of the block (>= 0).
        block_length : int
            Number of consecutive nodes to merge (>= 1).

        Raises
        ------
        ValueError
            If the block would run past ``n_features`` sites, if
            ``block_length`` is smaller than 1, if ``block_position`` is
            negative, or if a block is already merged.
        """
        if block_position + block_length > self.n_features:
            raise ValueError(
                f'Last position of the block ({block_position + block_length}) '
                f'exceeds the range of MPS sites ({self.n_features})')
        elif block_length < 1:
            raise ValueError(
                '`block_length` should be greater than or equal to 1')
        elif block_position < 0:
            # A negative position would slice `_mats_env` from the end and
            # leave `block` pointing at the wrong node once the list shrinks
            raise ValueError(
                '`block_position` should be greater than or equal to 0')

        if self.block_position is not None:
            raise ValueError(
                'Cannot create block if there is already a merged block')

        block_nodes = self.mats_env[block_position:(block_position + block_length)]

        # Contract the selected nodes pairwise, left to right, into one node
        block = block_nodes[0]
        for node in block_nodes[1:]:
            block = tk.contract_between_(block, node)
        block = block.parameterize(True)
        block.name = 'block'

        self.block_position = block_position
        self.block_length = block_length
        # Replace the merged nodes by the single block node
        self._mats_env = self._mats_env[:block_position] + [block] + \
            self._mats_env[(block_position + block_length):]

    def unmerge_block(self, side='right', rank=None, cum_percentage=None):
        """Split the merged block back into ``block_length`` nodes and put
        them in ``_mats_env`` in place of the block.

        The block is split ``block_length - 1`` times with ``tk.split_``; each
        split separates the first two axes of what remains of the block into a
        new node. ``side``, ``rank`` and ``cum_percentage`` are forwarded to
        ``tk.split_`` unchanged.

        Raises
        ------
        ValueError
            If there is no merged block to split.
        """
        if self.block_position is None:
            raise ValueError(
                'Cannot unmerge block if there is no merged block')

        block = self.block

        block_nodes = []
        for i in range(self.block_length - 1):
            # assumes the first two axes are the ones belonging to the
            # left-most remaining site -- TODO confirm axis ordering
            node1_axes = block.axes[:2]
            node2_axes = block.axes[2:]

            node, block = tk.split_(block,
                                    node1_axes,
                                    node2_axes,
                                    side=side,
                                    rank=rank,
                                    cum_percentage=cum_percentage)
            # The new edge created by the split becomes the bond between the
            # extracted node ('right') and the rest of the block ('left')
            block.get_axis('split').name = 'left'
            node.get_axis('split').name = 'right'
            node.name = f'mats_env_({self.block_position + i})'

            block_nodes.append(node)

        # Index of the last site, computed from `block_length` instead of the
        # loop variable so that a block of length 1 (no splits) also works
        last_position = self.block_position + self.block_length - 1
        block.name = f'mats_env_({last_position})'
        block_nodes.append(block)

        self._mats_env = self._mats_env[:self.block_position] + block_nodes + \
            self._mats_env[(self.block_position + 1):]

        self.block_position = None
        self.block_length = None

    def contract(self):
        """Contract the whole network and return the resulting node.

        Every node in ``mats_env`` is first contracted with the neighbours
        attached to its axes whose name contains ``'input'`` (a merged block
        may have several such axes); the results are then contracted in order
        between ``left_node`` and ``right_node``.
        """
        result_mats = []
        for node in self.mats_env:
            # Each contraction returns a new node, so rescan its axes until no
            # 'input' axis is left
            while any('input' in name for name in node.axes_names):
                for axis in node.axes:
                    if 'input' in axis.name:
                        data_node = node.neighbours(axis)
                        node = node @ data_node
                        break
            result_mats.append(node)

        result_mats = [self.left_node] + result_mats + [self.right_node]

        result = result_mats[0]
        for node in result_mats[1:]:
            result @= node

        return result

# ---- Hyperparameters ----
# Sizes of the data and of the model's input / output spaces
input_size = 4
embedding_dim = 2
output_dim = 1
num_classes = 10

# MPS structure and initialisation
bond_dim = 50
init_method = 'unit'

# presumably used later for block merging / splitting -- verify against callers
block_length = 2
cum_percentage = 0.98

# Run on the GPU when one is available, otherwise fall back to the CPU
device = (torch.device('cuda')
          if torch.cuda.is_available()
          else torch.device('cpu'))

# ---- Network ----
model_name = 'mps_dmrg'
mps = MPS_DMRG(
    n_features=input_size,
    in_dim=embedding_dim,
    out_dim=num_classes,
    bond_dim=bond_dim,
    init_method=init_method,
    boundary='obc',
    device=device,
)

# NOTE: data nodes have to be set before any nodes are merged; the
# `mps.set_data_nodes()` call is currently disabled here.

# Display the shape of node 'mats_env_node_(3)' next to the network summary
(mps['mats_env_node_(3)'].size(), mps)


(torch.Size([50, 2, 50]),
 MPS_DMRG(
  	name: mpslayer
 	nodes: 
 		[mats_env_node_(0)
 		 mats_env_node_(1)
 		 mats_env_node_(2)
 		 mats_env_node_(3)
 		 left_node
 		 right_node]
 	edges:
 		[mats_env_node_(0)[input] <-> None
 		 mats_env_node_(1)[input] <-> None
 		 mats_env_node_(2)[output] <-> None
 		 mats_env_node_(3)[input] <-> None]))

In [25]:
# mps.set_data_nodes()
# Call unset_data_nodes() on the network (presumably detaches any data nodes
# from the input edges -- confirm against the tensorkrowch documentation).
mps.unset_data_nodes()
# Display the shape of node 'mats_env_node_(3)' together with the network repr.
mps['mats_env_node_(3)'].size(), mps

(torch.Size([50, 2, 50]),
 MPS_DMRG(
  	name: mpslayer
 	nodes: 
 		[mats_env_node_(0)
 		 mats_env_node_(1)
 		 mats_env_node_(2)
 		 mats_env_node_(3)
 		 left_node
 		 right_node]
 	edges:
 		[mats_env_node_(2)[output] <-> None
 		 mats_env_node_(0)[input] <-> None
 		 mats_env_node_(1)[input] <-> None
 		 mats_env_node_(3)[input] <-> None]))