# Test MPS solver

This notebook tests various aspects of the solver and makes sure they behave as expected.

## Setup


In [1]:
import os
# Enable MPS fallback for unsupported operations (e.g., linalg_qr)
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

from tqdm import tqdm
import torch as t
from functools import reduce
import numpy as np
np.set_printoptions(precision=3, suppress=True)
import matplotlib.pyplot as plt
import math
from scipy.optimize import differential_evolution
import einops
from jaxtyping import Float
import pandas as pd
import wandb
import argparse

from opt_mps_fiducial_state import apply_unitary
from mps_utils import to_canonical_form, to_comp_basis, get_rand_mps, get_product_state, get_ghz_state, apply_random_unitaries, test_canonical_form
from IPython.display import HTML, display
from game import get_default_3players, get_default_2players, get_default_H

from solver import *

Using Apple Silicon GPU (MPS) with CPU fallback for unsupported ops


## Initial equilibrium finding

Given a two-qubit state, compute its Nash equilibrium.

In [5]:
# Build the state with get_rand_mps, then report the shape of every site tensor.
Psi = get_rand_mps(L=2, chi=3, d_phys=2, dtype=np.complex128)
for site, tensor in enumerate(Psi):
    print(f"Site {site} has shape {tensor.shape}")

# Gathers statistics of the state

Site 0 has shape (2, 1, 2)
Site 1 has shape (2, 2, 1)


In [28]:
# Expand the original MPS with to_comp_basis *before* it is handed to
# to_canonical_form, so `psi` reflects Psi exactly as created above.
psi = to_comp_basis(Psi)
# Convert the state to the canonical form selected by form='B'.
Psi_ = to_canonical_form(Psi, form='B')
# NOTE(review): the labels below assume Psi (as returned by get_rand_mps) is
# left-canonical and that form='B' means right-canonical — confirm in mps_utils.
print(f"Left canonical form: {Psi}")
print(f"Right canonical form: {Psi_}")
# NOTE(review): psi and psi_ are never compared in this cell; presumably they
# should describe the same state — TODO confirm and add an assertion.
psi_ = to_comp_basis(Psi_)

Left canonical form: [array([[[-0.606-0.12j , -0.518+0.592j]],

       [[-0.333+0.713j, -0.354-0.506j]]]), array([[[-0.821-0.08j ],
        [ 0.158+0.082j]],

       [[ 0.444-0.196j],
        [ 0.225+0.04j ]]])]
Right canonical form: [array([[[-0.56 -0.146j,  0.148-0.169j]],

       [[-0.299+0.703j,  0.101+0.144j]]]), array([[[-0.858-0.083j],
        [-0.412-0.296j]],

       [[ 0.464-0.205j],
        [-0.856-0.096j]]])]


In [34]:
# NOTE: batch perturb must start from B canonical form
Psi_batch, original_S, batch_perturbed_S = batch_perturb(
    Psi_, batch_size=1, lr=0.03, site=0
)
# Only the two values returned alongside Psi_batch are printed, not Psi_batch itself.
print(original_S)
print(batch_perturbed_S)

# Keep element 0 of every entry of the batch (batch_size=1 above).
Psi_test = [site_batch[0] for site_batch in Psi_batch]
test_canonical_form(Psi_test, form='A')

# Passing the result through to_canonical_form(form='A') again must leave
# every tensor numerically unchanged.
Psi_test_ = to_canonical_form(Psi_test, form='A')
assert all(map(np.allclose, Psi_test, Psi_test_))


[[0.959 0.285]]
[[0.957 0.291]]


In [38]:
def get_schmidt_eigval(Psi):
    """Return the largest singular value of site 0 of ``Psi`` in 'B' form.

    ``Psi`` is first passed through ``to_canonical_form(Psi, form='B')``.
    The site-0 tensor of the result, with axes ``(d_phys, chi_l, chi_r)``,
    is flattened into a ``(d_phys * chi_l, chi_r)`` matrix and its singular
    values are computed.

    Args:
        Psi: The MPS, in whatever representation ``to_canonical_form``
            accepts (the notebook passes the output of ``get_rand_mps``).

    Returns:
        The first entry of the singular-value array. ``np.linalg.svd``
        returns singular values in descending order, so this is the
        largest one (the larger Schmidt value).
    """
    Psi_ = to_canonical_form(Psi, form='B')
    # Merge the physical and left-bond axes into the row index of a matrix.
    site0_matrix = einops.rearrange(
        Psi_[0],
        'd_phys chi_l chi_r -> (d_phys chi_l) chi_r'
    )
    # Only the singular values are needed, so do not build U and Vh.
    S = np.linalg.svd(site0_matrix, compute_uv=False)
    return S[0]  # The larger Schmidt value

# Bare last expression so the notebook displays the returned value.
get_schmidt_eigval(Psi)
    

0.9585012601927729