# Optimizing Fiducial State via Stochastic Gradient Descent

An updated plan: The update should be a perturbation at the orthogonality center instead of directly on the matrix product state, as in the latter case a 'small' change in the matrix elements does not mean a small change in the wavefunction


## Setup

In [35]:
import os
# Enable MPS fallback for unsupported operations (e.g., linalg_qr)
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

import torch as t
import numpy as np
import pandas as pd
import einops
import torch.nn as nn
import torch.optim as optim
from jaxtyping import Complex, Float

from misc_torch import mps_2form, mps_overlap

if t.backends.mps.is_available():
    device = t.device('mps')
    print("Using Apple Silicon GPU (MPS) with CPU fallback for unsupported ops")
elif t.cuda.is_available():
    device = t.device('cuda')
    print("Using NVIDIA GPU (CUDA)")
else:
    device = t.device('cpu')
    print("Using CPU")

default_dtype = t.float32
print(f"Device: {device}, dtype: {default_dtype}")

Using Apple Silicon GPU (MPS) with CPU fallback for unsupported ops
Device: mps, dtype: torch.float32


## Utilities

### Initialization

In [49]:
def rand_init_finite_mps(d_phys, chi, n_sites, device, dtype):
    Psi = []
    if n_sites == 1:
        print("Why bother using MPS for a single site?")
        return None
    
    Psi.append(t.randn(d_phys, 1, chi, dtype=dtype, device=device))
    for i in list[int](range(n_sites))[1:-1]:
        A = t.randn(d_phys, chi, chi, dtype=dtype, device=device)
        Psi.append(A)

    Psi.append(t.randn(d_phys, chi, 1, dtype=dtype, device=device))
    return Psi

def to_comp_basis(Psi: list[t.Tensor]) -> list[t.Tensor]:
    psi = Psi[0]
    for A in Psi[1:]:
        psi = einops.einsum(psi, A, "... chi_l bond, phys bond chi_r -> ... phys chi_l chi_r")
    
    psi = psi.squeeze().reshape(2**len(Psi))
    return psi


### Converting to left/right normalized form

In [None]:
def to_canonical_form(Psi: list[t.Tensor], form: str = 'left', orth_center) -> tuple[list[t.Tensor], t.Tensor]:
    if form == 'left' or 'A':
        psi = Psi[0]
        d_phys = psi.shape[0]
        for j in range(len(Psi)):
            psi_grouped = einops.rearrange(psi, 'd_phys chi_l chi_r -> (d_phys chi_l) chi_r')
            U, S, Vh = t.linalg.svd(psi_grouped)
            Psi[j] = einops.rearrange(U, '(d_phys chi_l) chi_r -> d_phys chi_l chi_r', d_phys=d_phys)

            


def to_isometric_form(Psi: list[t.Tensor], form: str = 'left') -> tuple[list[t.Tensor], t.Tensor]:
    if form == 'left' or 'A':
        psi = Psi[0] # shape: (d_phys, 1, chi)
        d_phys = psi.shape[0]
        for j in range(len(Psi)):
            psi_grouped = einops.rearrange(psi, 'd_phys chi_l chi_r -> (d_phys chi_l) chi_r')
            left_iso, orth_center = t.linalg.qr(psi_grouped)
            print(left_iso.shape)
            print(orth_center.shape)
            Psi[j] = einops.rearrange(left_iso, '(d_phys chi_l) chi_r -> d_phys chi_l chi_r', d_phys=d_phys)
            if j < len(Psi) - 1:
                psi = einops.einsum(orth_center, Psi[j+1], 'chi_l bond, d_phys bond chi_r -> d_phys chi_l chi_r')
            else:
                # we are at the last site
                right_orth = orth_center.squeeze(-1)
    elif form == 'right' or 'B':
        raise NotImplementedError("Right isometric form not implemented")

    return Psi, right_orth




## Perturbation on MPS

In [52]:
chi = 10
d_phys = 2
n_sites = 3
boundary_vec = t.tensor([1.] + [0.] * (chi-1), dtype=default_dtype, device=device)

Psi = rand_init_finite_mps(d_phys, chi, n_sites, device, default_dtype)
print(Psi[1].shape)

Psi_simplified, right_orth = left_normalize(Psi)
print([Psi_simplified[i].shape for i in range(len(Psi_simplified))])


torch.Size([2, 10, 10])
torch.Size([2, 2])
torch.Size([2, 10])
torch.Size([4, 4])
torch.Size([4, 10])
torch.Size([8, 1])
torch.Size([1, 1])
[torch.Size([2, 1, 2]), torch.Size([2, 2, 4]), torch.Size([2, 4, 1])]


In [17]:
# check left orthogonality condition

Psi[1].shape


torch.Size([2, 2, 4])