# Optimization of Fiducial State via Gradient Descent

Plans:

- Individual state optimization: For 3 player quantum prisoner's dilemma: find the best fiducial state
  - Start from some random $\ket{\psi(A)}$
  - Get an NE state on this particular orbit, $\ket{\psi(A_1, A_2, A_3)}$. Note that we need to keep track of the individual tensors, whereas previously we only tracked the computational-basis components of $\ket{\psi}$.
  - Get gradients directly on entries of the $A$ matrices, $A_i \mapsto A_i + \delta A_i$.
  - Take updated matrices:
    - The gradient should be already orthogonal to the "unitary directions" because by definition of NE the states are stable against these deviations.
    - After updating, we need to bring the tensors back to canonical form.
- Batching using the pytorch library for more efficient numerics?

## Setup

In [103]:
from misc_torch import *
import numpy as np
import torch as t
import einops
from jaxtyping import Complex, Float

# Set default device (use GPU if available)
# `device` and `default_dtype` are module-level globals: later cells read them
# whenever they create new tensors.
device = t.device('cuda' if t.cuda.is_available() else 'cpu')
default_dtype = t.float64
print(f"Using device: {device}, default dtype: {default_dtype}")

Using device: cpu, default dtype: torch.float64


### Define the Hamiltonian

In [106]:
# One diagonal 8x8 operator per player, listed in the order of the eight
# three-site basis indices.
# NOTE(review): presumably these are the per-player payoffs of the three-player
# prisoner's dilemma -- confirm against the game definition.
H = [t.diag(t.tensor([6., 3., 3., 0., 10., 6., 6., 2.], dtype=default_dtype, device=device)),
     t.diag(t.tensor([6., 3., 10., 6., 3., 0., 6., 2.], dtype=default_dtype, device=device)),
     t.diag(t.tensor([6., 10., 3., 6., 3., 6., 0., 2.], dtype=default_dtype, device=device))]
# Split each 8x8 matrix into six axes of size 2: the first three come from the
# row index, the last three from the column index.
H = [h.reshape(2, 2, 2, 2, 2, 2) for h in H]
# Stack the three operators along a new leading axis so that all of them can be
# contracted against a state in a single call.
H_all_in_one = t.stack(H, dim=0)

print(f"H_all_in_one.shape: {H_all_in_one.shape}")


H_all_in_one.shape: torch.Size([3, 2, 2, 2, 2, 2, 2])


## MPS representation of states

In [107]:
def get_state_from_tensors(A_list: list[t.Tensor], bc: str = 'PBC'):
    """
    Contract a list of MPS site tensors into a single normalized state tensor.

    Args:
        A_list: site tensors, each with axes (phys, chi_l, chi_r).
        bc: boundary condition. Only "PBC" (periodic) is implemented.

    Returns:
        Tensor with one physical axis per site, divided by its 2-norm.

    Raises:
        NotImplementedError: if ``bc == "OBC"``.
        ValueError: if ``bc`` is neither "PBC" nor "OBC".
    """
    # Validate the boundary condition first. An unrecognised value previously
    # fell through both branches and failed on an unbound local at the return.
    if bc == "OBC":
        raise NotImplementedError("OBC not implemented yet")
    if bc != "PBC":
        raise ValueError(f"Unknown boundary condition {bc!r}; expected 'PBC' or 'OBC'")

    # `mps_2form` is provided by `misc_torch`.
    # NOTE(review): presumably it returns the tensors in canonical form -- confirm.
    Psi = mps_2form(A_list)
    L = len(Psi)

    # Absorb one site at a time: the growing tensor keeps its physical axes in
    # front and its two open bond axes (chi_l, chi_r) at the end.
    psi = Psi[0]
    for A in Psi[1:]:
        psi = einops.einsum(psi, A, "... chi_l bond, phys bond chi_r -> ... phys chi_l chi_r")

    # Close the ring by tracing over the two remaining bond axes.
    psi = t.diagonal(psi, dim1=-2, dim2=-1).sum(dim=-1)

    # <psi|psi>, contracting all L physical axes.
    site_axes = list(range(L))
    MPS_norm = t.sqrt(t.tensordot(psi, psi.conj(), dims=(site_axes, site_axes)))
    return psi / MPS_norm

def normalize_mps_tensor(A: t.Tensor):
    """
    Rescale an MPS tensor so that the largest-magnitude eigenvalue of its
    transfer matrix becomes 1.

    A: tensor of shape (phys, chi_l, chi_r)

    The transfer matrix is quadratic in A, hence the division by the square
    root of the eigenvalue magnitude.
    """
    spectrum = t.linalg.eigvals(compute_transfer_matrix(A))
    spectral_radius = spectrum.abs().max()
    return A / spectral_radius.sqrt()

def compute_transfer_matrix(A: t.Tensor):
    """
    Build the transfer matrix of a single MPS tensor.

    A has axes (phys, chi_l, chi_r). The physical axis of A is contracted with
    that of A.conj(); the two left bond axes are fused into the row index and
    the two right bond axes into the column index.
    """
    # Emit the axes already ordered as (left, left-conj, right, right-conj) so
    # that they can be fused pairwise.
    ket_bra = einops.einsum(A, A.conj(), "p l r, p lc rc -> l lc r rc")
    return einops.rearrange(ket_bra, "l lc r rc -> (l lc) (r rc)")

How to compute the normalization factor for individual tensors in an MPS? Canonical form does not fix the norm...

In [108]:
### Testing `normalize_mps_tensor` and 'mps_2form'

# Three random real site tensors with axes (phys=2, chi_l=3, chi_r=3).
A_list = [t.randn(2, 3, 3, dtype=default_dtype, device=device) for _ in range(3)]
# Rescale every tensor by the leading eigenvalue of its own transfer matrix.
for i, A in enumerate(A_list):
    A_list[i] = normalize_mps_tensor(A)
# `mps_2form` comes from `misc_torch`.
# NOTE(review): presumably it returns the tensors in canonical form -- confirm.
Psi = mps_2form(A_list)
# get_state_from_tensors divides by the state norm, so the printed norm is
# expected to be 1 regardless of the per-tensor rescaling above.
psi = get_state_from_tensors(A_list, bc="PBC")
print(f"norm of psi for random tensors, with max_eigval normalization: {t.linalg.norm(psi).item()}")

def test_isometric(psi_tensor: Complex[t.Tensor, "phys chi_l chi_r"], atol: float = 1e-10) -> str | None:
    """
    Report whether an MPS tensor satisfies the left and/or right isometry condition.

    Args:
        psi_tensor: site tensor with axes (phys, chi_l, chi_r).
        atol: a condition counts as satisfied when the Frobenius norm of its
            deviation from the identity is strictly below this value.

    Returns:
        'A' if only the left condition holds, 'B' if only the right one holds,
        'AB' if both hold, and None (after printing a warning) if neither holds.
    """
    bra = psi_tensor.conj()
    # Left condition: contract the physical and left bond axes.
    left_contracted = einops.einsum(psi_tensor, bra, "phys chi_l chi_r, phys chi_l chi_rc -> chi_r chi_rc")
    # Right condition: contract the physical and right bond axes.
    right_contracted = einops.einsum(psi_tensor, bra, "phys chi_l chi_r, phys chi_lc chi_r -> chi_l chi_lc")

    # Build the reference identities on the tensor's own dtype/device so the
    # check does not depend on module-level globals.
    eye_kwargs = dict(dtype=psi_tensor.dtype, device=psi_tensor.device)
    left_canonical_err = t.linalg.norm(left_contracted - t.eye(psi_tensor.shape[2], **eye_kwargs)).item()
    right_canonical_err = t.linalg.norm(right_contracted - t.eye(psi_tensor.shape[1], **eye_kwargs)).item()

    # Decide each condition exactly once. The previous chain of strict
    # comparisons fell through to 'AB' when an error was exactly equal to atol.
    is_left = left_canonical_err < atol
    is_right = right_canonical_err < atol

    if is_left:
        print(f"Left canonical form: {left_canonical_err}")

    if is_right:
        print(f"Right canonical form: {right_canonical_err}")

    if is_left and is_right:
        return 'AB'
    if is_left:
        return 'A'
    if is_right:
        return 'B'

    print(f"Warning: left_canonical_err = {left_canonical_err}, right_canonical_err = {right_canonical_err}, the tensor is not in canonical form")
    return None

# Classify the first tensor returned by `mps_2form` ('A', 'B', 'AB' or None).
form = test_isometric(Psi[0])
    

norm of psi for random tensors, with max_eigval normalization: 1.0
Left canonical form: 2.482534153247273e-16


## Find Nash equilibria

The hyperparameter $\alpha$ should be set to a larger value for the algorithm to work; otherwise the convergence check would pass trivially...

In [159]:
def apply_unitary(unitary, A):
    """Act with a matrix on the physical leg of A, whose axes are (phys, chi_l, chi_r).

    The second index of `unitary` is summed against the physical axis of A;
    its first index becomes the new physical axis. The bond axes are untouched.
    """
    return einops.einsum(
        unitary, A, "new_phys phys, phys chi_l chi_r -> new_phys chi_l chi_r"
    )

def _local_energy_matrices(psi, H, L):
    """Return one matrix per player whose trace is that player's energy in `psi`.

    For player i, H[i] is applied to `psi` and the result is contracted with
    `psi.conj()` over the two sites other than i, leaving site i's axis open on
    both sides. The axis lists passed to `tensordot` are hard-coded for three
    sites.
    """
    matrices = []
    for i in range(L):
        j = (i - 1) % L
        k = (i + 1) % L
        # Contract the last three axes of H[i] with the three site axes of psi.
        dE = t.tensordot(H[i], psi, dims=([4, 5, 3], [1, 2, 0]))
        # Contract the neighbouring sites j and k with the conjugate state.
        dE = t.tensordot(psi.conj(), dE, dims=([j, k], [j, k]))
        matrices.append(dE)
    return matrices


def find_nash_eq(Psi, H, max_iter=10000, alpha=10, convergence_threshold=1e-6, symmetric=False, trace_history=False):
    """
    Find Nash equilibrium using best response dynamics

    Every sweep computes, for each player, a local 2x2 energy matrix, turns
    ``I - alpha * dE / ||dE||`` into a unitary via its SVD, applies all the
    unitaries at once and then compares the energies before and after.

    Args:
        Psi: List[t.Tensor] - MPS tensors of shape (phys, chi_l, chi_r).
            The list itself is not modified; the function works on a copy.
        H: List[t.Tensor] - Hamiltonian tensors for each player
        max_iter: Maximum number of sweeps (at most this many are performed)
        alpha: Step size multiplying the normalized local energy matrix
        convergence_threshold: Sweeping stops once the summed positive energy
            change of a sweep drops below this value
        symmetric: Whether to use same random unitary for all players
        trace_history: Whether to record energy history

    Returns:
        dict with keys: converged, energy, state, num_iters, exploitability, energy_history (optional).
        If no sweep is performed (max_iter <= 0), `energy` is an empty list and
        `exploitability` is infinity.
    """
    # Work on a shallow copy so that the caller's list keeps its original tensors.
    Psi = list(Psi)
    L = len(Psi)
    _iter = 0
    converged = False

    # Defined up front so that the result is well-formed even without a sweep.
    E_new = []
    local_max_epl = float("inf")

    # Initialize with random unitaries.
    # One matrix is drawn per player even in symmetric mode, which keeps the
    # NumPy random stream consumption independent of `symmetric`.
    U0 = t.tensor(np.linalg.qr(np.random.randn(2, 2))[0], dtype=default_dtype, device=device)
    for i in range(L):
        U = t.tensor(np.linalg.qr(np.random.randn(2, 2))[0], dtype=default_dtype, device=device)
        unitary = U0 if symmetric else U
        Psi[i] = apply_unitary(unitary, Psi[i])

    E_history = [] if trace_history else None
    identity = t.eye(2, dtype=default_dtype, device=device)

    # `<` rather than `<=`: the loop previously ran max_iter + 1 sweeps.
    while _iter < max_iter and not converged:
        # Energies BEFORE the update, and the unitaries derived from them.
        psi = get_state_from_tensors(Psi)
        local_matrices = _local_energy_matrices(psi, H, L)
        E_old = [t.trace(dE).real.item() for dE in local_matrices]

        unitaries = []
        for dE in local_matrices:
            # A vanishing matrix has no direction to normalize; leaving it at
            # zero yields the identity update instead of a NaN from 0 / 0.
            norm = t.linalg.norm(dE)
            if norm > 0:
                dE = dE / norm
            U, _, Vh = t.linalg.svd(identity - alpha * dE)
            unitaries.append((U @ Vh).T.conj())

        if trace_history:
            E_history.append(E_old)

        # Apply all unitaries simultaneously.
        for i in range(L):
            Psi[i] = apply_unitary(unitaries[i], Psi[i])

        # Energies AFTER the update, for the convergence check.
        psi = get_state_from_tensors(Psi)
        E_new = [t.trace(dE).real.item() for dE in _local_energy_matrices(psi, H, L)]

        local_max_epl = sum(max(new - old, 0) for new, old in zip(E_new, E_old))
        converged = local_max_epl < convergence_threshold
        _iter = _iter + 1

    if not converged:
        print(f"Warning: the differential BR dynamics did not converge up to threshold {convergence_threshold}")

    result = {
        'converged': converged,
        'energy': E_new,
        'state': Psi,
        'num_iters': _iter,
        'exploitability': local_max_epl,
    }
    if trace_history:
        result['energy_history'] = E_history

    return result

### Tests

Find Nash equilibrium starting from

1. GHZ state
2. $W$ state
3. random MPS state


In [113]:
# Initialize GHZ state in PyTorch
# Each site tensor has axes (phys, chi_l, chi_r): the phys=0 slice keeps only
# bond entry (0, 0) and the phys=1 slice only bond entry (1, 1).
ghz_state = [t.tensor([[[1., 0.], [0., 0.]], [[0., 0.], [0., 1.]]], dtype=default_dtype, device=device) for _ in range(3)]

# NOTE: find_nash_eq begins by applying random local unitaries drawn from
# NumPy's global RNG, so the result changes from run to run unless it is seeded.
result = find_nash_eq(ghz_state, H, max_iter=10000, alpha=10, convergence_threshold=1e-6, symmetric=False, trace_history=False)
# Keep the final tensors around; later cells read `Psi`.
Psi = result['state']
print("Final state:")
print(get_state_from_tensors(result['state']))
print(f"\nConverged: {result['converged']}")
print(f"Number of iterations: {result['num_iters']}")
print(f"Final energies: {result['energy']}")
print(f"Local Exploitability: {result['exploitability']}")

Final state:
tensor([[[ 0.0009,  0.0013],
         [ 0.4171,  0.5710]],

        [[-0.5710,  0.4171],
         [ 0.0013, -0.0009]]], dtype=torch.float64)

Converged: True
Number of iterations: 49
Final energies: [4.826038186594807, 4.673957465565634, 4.500000779798526]
Local Exploitability: 8.914493614398111e-07


## Evolving the fiducial state

In [111]:
def compute_energy(Psi: list[t.Tensor], H: t.Tensor):
    """
    Evaluate every stacked Hamiltonian on the state built from the MPS tensors Psi.

    This is the finite three-site setting (no thermodynamic limit): the tensors
    are contracted into a normalized state and one real number is returned per
    entry along the leading axis of H.
    Shape of H is (3, 2, 2, 2, 2, 2, 2)
    """
    ket_axes = 'a1 a2 a3'
    bra_axes = 'b1 b2 b3'
    # state[a] * H[batch, a, b] * conj(state)[b], summed over a and b.
    pattern = f"{ket_axes}, batch {ket_axes} {bra_axes}, {bra_axes} -> batch"

    state = get_state_from_tensors(Psi)
    expectation = einops.einsum(state, H, state.conj(), pattern)
    return t.real(expectation)

# Sanity check: per-player energies for the tensors currently bound to `Psi`.
compute_energy(Psi, H_all_in_one)

tensor([4.5000, 4.7735, 4.7265], dtype=torch.float64)

### Batch compute exploitability

In [179]:
def batch_compute_energy(Psi: list[t.Tensor], H: t.Tensor, Psi_batch: Float[t.Tensor, "batch phys chi_l chi_r"], active_site: int):
    """
    Compute the energy of the state Psi under the Hamiltonian H in batches

    The tensor at `active_site` is replaced by a batch of candidate tensors and
    the energy H[active_site] is evaluated for every candidate.

    Args:
        Psi: site tensors with axes (phys, chi_l, chi_r), forming a periodic ring.
        H: indexable collection of per-site operators; H[active_site] must have
            one "ket" axis followed by one "bra" axis per site
            (2 * len(Psi) axes in total).
        Psi_batch: candidate tensors for the active site, axes
            (batch, phys, chi_l, chi_r).
        active_site: index of the site whose tensor is replaced.

    Returns:
        Tensor of length `batch` holding energy / norm for each candidate (the
        contracted states are not normalized beforehand).
    """
    n_sites = len(Psi)

    def next_site(site: int):
        return (site + 1) % n_sites

    H_active_site = H[active_site]
    other_sites = [i for i in range(n_sites) if i != active_site]

    # Ring contraction: tensor i carries physical index a{i} and bond indices
    # b{i} (left) and b{i+1} (right); the last bond wraps around to b0.
    operand_specs = [f"batch a{active_site} b{active_site} b{next_site(active_site)}"]
    operand_specs += [f"a{i} b{i} b{next_site(i)}" for i in other_sites]
    ket_inds = " ".join(f"a{i}" for i in range(n_sites))
    contraction_spec = f"{', '.join(operand_specs)} -> batch {ket_inds}"
    psi_batch = einops.einsum(Psi_batch, *[Psi[i] for i in other_sites], contraction_spec)

    # Expectation value and squared norm for any number of sites (the index
    # strings used to be written out for exactly three sites).
    bra_inds = " ".join(f"b{i}" for i in range(n_sites))
    energy_spec = f"batch {ket_inds}, {ket_inds} {bra_inds}, batch {bra_inds} -> batch"
    norm_spec = f"batch {ket_inds}, batch {ket_inds} -> batch"
    E = einops.einsum(psi_batch, H_active_site, psi_batch.conj(), energy_spec)
    norm = einops.einsum(psi_batch, psi_batch.conj(), norm_spec)
    return E / norm

def batch_compute_exploitability(Psi, H, num_samples: int = 1000):
    """
    Estimate, per site, the largest energy gain reachable by a single-site rotation.

    `num_samples` angles are taken evenly on [0, pi]; each defines the real 2x2
    matrix [[cos, sin], [-sin, cos]], which is applied to the physical leg of
    one site at a time. The return value has one entry per site: the maximum
    over samples of the energy increase relative to the unrotated state,
    clamped below at zero.
    """
    angles = t.linspace(0, t.pi, num_samples, dtype=default_dtype, device=device)
    cos_a, sin_a = t.cos(angles), t.sin(angles)
    # Assemble the (n_sample, 2, 2) stack row by row.
    rotations = t.stack(
        [t.stack([cos_a, sin_a], dim=-1), t.stack([-sin_a, cos_a], dim=-1)],
        dim=-2,
    )

    per_site_energy = []
    for site, site_tensor in enumerate(Psi):
        rotated = einops.einsum(
            rotations, site_tensor, "n_sample d1 d2, d2 chi_l chi_r -> n_sample d1 chi_l chi_r"
        )
        per_site_energy.append(batch_compute_energy(Psi, H, rotated, site))

    deviated: Float[t.Tensor, "n_player n_sample"] = t.stack(per_site_energy)
    baseline = einops.repeat(compute_energy(Psi, H), "n_player -> n_player n_sample", n_sample=num_samples)
    gain = t.clamp(deviated - baseline, min=0)
    return gain.max(dim=1).values


### Optimizer loop

In [242]:
# Set up the optimizer and parameters.
# Some test codes for one-step gradient descent on the MPS space
import torch.nn as nn
import torch.optim as optim
# Three random site tensors, each wrapped as a trainable parameter with its own
# Adam optimizer (one optimizer per player).
Psi = [t.randn(2, 3, 3, dtype=default_dtype, device=device) for _ in range(3)]
params_list = [nn.Parameter(Psi[i]) for i in range(len(Psi))]

# print(f"state before optimization step: {Psi}")

optimizer_list = [optim.Adam([params_list[i]], lr=0.01) for i in range(len(Psi))]

# one round of update
# E[i].backward() writes gradients into every parameter reached by the graph,
# but only parameter i is zeroed right before and stepped right after, so
# player i's update uses the gradient of E[i] alone.
# NOTE(review): stepping a parameter in place while the retained graph is still
# used by later backward calls may trip autograd's in-place modification
# check -- confirm.
E = compute_energy(params_list, H_all_in_one)
for i in range(E.shape[0]):
    optimizer_list[i].zero_grad()
    E[i].backward(retain_graph=True)
    optimizer_list[i].step()
  
# need to re-canonicalize the tensors
params_list = mps_2form(params_list)
E_intermediate = compute_energy(params_list, H_all_in_one)
# Relax the updated tensors to a Nash equilibrium with local unitaries.
result = find_nash_eq(params_list, H, max_iter=10000, alpha=10, convergence_threshold=1e-6, symmetric=False, trace_history=False)
Psi = result['state']
params_list = [nn.Parameter(Psi[i]) for i in range(len(Psi))]

# print(f"state after optimization step: {Psi}")
E_new = compute_energy(params_list, H_all_in_one)
# Total energy before the step, after re-canonicalization, and at the equilibrium.
print(sum(E))
print(sum(E_intermediate))
print(sum(E_new))

tensor(11.51, dtype=torch.float64, grad_fn=<AddBackward0>)
tensor(11.51, dtype=torch.float64, grad_fn=<AddBackward0>)
tensor(6.41, dtype=torch.float64, grad_fn=<AddBackward0>)


In [243]:
# Full optimization loop
from tqdm import tqdm
import pandas as pd

# Random initial site tensors; each player owns one tensor and one optimizer.
Psi = [t.randn(2, 3, 3, dtype=default_dtype, device=device) for _ in range(3)]
Psi_params = [nn.Parameter(Psi[i]) for i in range(len(Psi))]
optimizer_list = [optim.Adam([Psi_params[i]], lr=0.005) for i in range(len(Psi))]

# Per-iteration records; turned into a DataFrame during post-processing.
df = []

for _ in tqdm(range(10)):
    E = compute_energy(Psi_params, H_all_in_one)
    n_players = E.shape[0]

    # Gradient of each player's own energy with respect to that player's own
    # tensor. Calling E[i].backward() for every i while zeroing only parameter
    # i mixes the players' gradients: backward() accumulates into all
    # parameters, so player 0 would be stepped with d(E0 + E1 + E2), player 1
    # with d(E1 + E2) and player 2 with dE2.
    own_grads = [
        t.autograd.grad(E[i], Psi_params[i], retain_graph=(i < n_players - 1))[0]
        for i in range(n_players)
    ]
    for param, grad in zip(Psi_params, own_grads):
        param.grad = grad

    # All gradients are in place; now update the tensors simultaneously.
    for optimizer in optimizer_list:
        optimizer.step()

    with t.no_grad():
        # Re-canonicalize, relax to a Nash equilibrium with local unitaries and
        # write the equilibrium tensors back into the trainable parameters.
        Psi_canonical_tensors = mps_2form(Psi_params)
        result = find_nash_eq(Psi_canonical_tensors, H, max_iter=10000, alpha=10, convergence_threshold=1e-7, symmetric=False, trace_history=False)
        ne_state = result['state']
        for i in range(len(Psi)):
            Psi_params[i].data = ne_state[i]

        # Exploitability estimated over a dense grid of single-site rotations.
        global_expl = batch_compute_exploitability(ne_state, H_all_in_one, num_samples=10000)

    df.append({
        'energy': result['energy'],
        'converged': result['converged'],
        'state': t.stack(result['state']).detach().cpu().numpy(),
        'num_iters': result['num_iters'],
        'local_expl': result['exploitability'],
        'global_expl': global_expl.detach().cpu().numpy(),
        'state_': get_state_from_tensors(result['state']).flatten().detach().cpu().numpy()
    })
    



  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:06<00:00,  1.59it/s]


### Logging and computation of multipartite entanglement

In [None]:
def compute_ent_params_from_state(state, option = 'I'):
    """
    Compute five entanglement parameters of a three-qubit state.

    Args:
        state: state amplitudes, either with three axes of size 2 or as a flat
            vector that is reshaped to (2, 2, 2).
        option: 'I' returns [I1, I2, I3, I4, I5]; 'J' returns [J1, ..., J5],
            which are fixed combinations of the I values.

    Raises:
        ValueError: if `option` is neither 'I' nor 'J'.
    """
    if state.ndim == 1:
        state = state.reshape(2, 2, 2)
    bra = state.conj()

    def purity(rho):
        # Tr(rho^2)
        return np.trace(np.linalg.matrix_power(rho, 2))

    # Single-site reduced density matrices (the other two sites traced out).
    rho_1 = einops.einsum(state, bra, 'x i j, y i j -> x y')
    rho_2 = einops.einsum(state, bra, 'i x j, i y j -> x y')
    rho_3 = einops.einsum(state, bra, 'i j x, i j y -> x y')
    # Two-site reduced density matrix of sites 1 and 2, flattened to 4x4.
    rho_12 = einops.rearrange(
        einops.einsum(state, bra, 'x1 x2 i, y1 y2 i -> x1 y1 x2 y2'),
        'x1 y1 x2 y2 -> (x1 x2) (y1 y2)',
    )

    I1 = purity(rho_1)
    I2 = purity(rho_2)
    I3 = purity(rho_3)
    I4 = np.trace(np.kron(rho_1, rho_2) @ rho_12)

    # Quartic contraction of four copies of the state with antisymmetric tensors.
    eps = np.array([[0, 1], [-1, 0]])
    det3 = 1/2 * einops.einsum(
        eps, eps, eps, eps, eps, eps, state, state, state, state,
        'i1 j1, i2 j2, k1 l1, k2 l2, i3 k3, j3 l3, i1 i2 i3, j1 j2 j3, k1 k2 k3, l1 l2 l3 ->'
    )
    I5 = np.abs(det3) ** 2

    if option == 'I':
        return np.array([I1, I2, I3, I4, I5])

    if option == 'J':
        root_I5 = np.sqrt(I5)
        J1 = 1/4 * (1 + I1 - I2 - I3 - 2 * root_I5)
        J2 = 1/4 * (1 - I1 + I2 - I3 - 2 * root_I5)
        J3 = 1/4 * (1 - I1 - I2 + I3 - 2 * root_I5)
        J4 = root_I5
        J5 = 1/4 * (3 - 3 * I1 - 3 * I2 - I3 + 4 * I4 - 2 * root_I5)
        return np.array([J1, J2, J3, J4, J5])

    raise ValueError("Invalid option")
    

In [244]:
# some post-processing
# `df` is the list of per-iteration records filled by the optimization loop;
# convert it to a DataFrame with one row per iteration.
df = pd.DataFrame(df)
def post_process(df: pd.DataFrame | list[dict]):
    """
    Add summary columns to the optimization records.

    Accepts either a DataFrame or a list of record dicts (converted to a
    DataFrame first). Adds 'welfare' (sum of each 'energy' entry), 'tot_expl'
    (sum of each 'global_expl' entry) and 'ent_params' (the 'I' entanglement
    parameters of each 'state_' entry). A DataFrame argument is modified in
    place; the frame is also returned.
    """
    frame = pd.DataFrame(df) if isinstance(df, list) else df

    def _total(values):
        # Each cell of 'energy' / 'global_expl' is itself a list or array.
        return sum(values)

    def _invariants(flat_state):
        return compute_ent_params_from_state(flat_state, option='I')

    frame['welfare'] = frame['energy'].apply(_total)
    frame['tot_expl'] = frame['global_expl'].apply(_total)

    # compute the entanglement parameters
    frame['ent_params'] = frame['state_'].apply(_invariants)
    return frame

# Add the 'welfare', 'tot_expl' and 'ent_params' columns, then show the
# per-player exploitability next to the total welfare of every iteration.
df = post_process(df)
print(df['global_expl'], '\n', df['welfare'])

0    [0.018348212670681097, 0.026139472063660296, 0...
1    [1.576896414417675, 0.3382407577651474, 1.6375...
2    [0.03664470831248856, 0.0429613710765282, 0.19...
3    [1.246065849470967, 0.0829921160060394, 0.6527...
4    [0.0006846802399698149, 0.19539722827714412, 0...
5    [0.002882772548166823, 0.003542676323128102, 0...
6    [0.019589943484326078, 0.02284340155695963, 4....
7    [0.0005807935827775168, 1.255943963857992, 0.6...
8     [0.019986021315266722, 0.02518253218871802, 0.0]
9    [0.004297502725477909, 0.19559124587839793, 0....
Name: global_expl, dtype: object 
 0     7.96
1    13.15
2     9.15
3    11.34
4     8.68
5     7.70
6     7.87
7    11.32
8     7.86
9     8.30
Name: welfare, dtype: float64


## Analyzing results

In [247]:
import pickle
import pandas as pd
# Previously saved optimization results (absolute path on the author's machine).
filename = '/mnt/users/clin/workspace/nonlocal/nash_data/qpd_opt_results.pkl'
# SECURITY NOTE: pickle.load executes arbitrary code embedded in the file; only
# load result files that you produced yourself or otherwise trust.
with open(filename, 'rb') as f:
    df = pickle.load(f)

# Display the loaded table (this rebinds `df`, replacing the in-session results).
df





Unnamed: 0,energy,converged,state,num_iters,local_expl,global_expl,state_
0,"[2.416024223604791, 2.437621846440669, 2.50033...",True,"[[[[ 0.19 0.13 -0.92], [ 0.64 -0.38 0.1 ], [...",2,0.0,0.0008982423031183,"[0.1866866264877058, -0.1682821779283403, 0.15..."
1,"[2.3987757071661844, 2.414948591464952, 2.4533...",True,"[[[[ 0.18 -0.12 -0.92], [0.63 0.38 0.1 ], [0.1...",3,6.81e-07,0.0,"[-0.1825311892854199, 0.16335416582316772, -0...."
2,"[2.3854209320953528, 2.3967149732821342, 2.414...",True,"[[[[ 0.18 -0.12 -0.92], [0.63 0.39 0.11], [0.1...",1,0.0,9.70469583894129e-06,"[0.18101411700243095, -0.15133450069890644, 0...."
3,"[2.3676494323905746, 2.3701099154984453, 2.371...",True,"[[[[ 0.19 -0.12 -0.92], [0.62 0.39 0.11], [0.1...",1,0.0,1.917468006151779e-05,"[-0.17983724582823243, 0.13761776852771124, -0..."
4,"[2.3446026427099604, 2.342876015497899, 2.3270...",True,"[[[[ 0.19 -0.12 -0.92], [0.62 0.39 0.12], [0.1...",1,0.0,1.931344101446797e-05,"[0.1771574294847754, -0.12304683393871617, 0.1..."
5,"[2.3230076217376174, 2.314527114565141, 2.2865...",True,"[[[[ 0.19 -0.12 -0.92], [0.62 0.38 0.13], [0.1...",1,0.0,1.5852913985980877e-05,"[-0.17445527358394933, 0.10848059805507743, -0..."
6,"[2.3006257974155, 2.288471939941917, 2.2494881...",True,"[[[[ 0.19 -0.11 -0.92], [0.61 0.38 0.14], [0.1...",1,0.0,1.085507195908164e-05,"[0.17055820099874341, -0.09418411194571485, 0...."
7,"[2.2806716376905487, 2.263412111191749, 2.2175...",True,"[[[[ 0.19 -0.11 -0.92], [0.61 0.38 0.14], [0.1...",1,0.0,6.951854295333959e-06,"[-0.16654396887384482, 0.08055979194197913, -0..."
8,"[2.542845254553866, 2.5238168662333123, 2.4610...",True,"[[[[-0.27 0.08 0.94], [-0.49 -0.43 -0.1 ], [...",4,0.0,0.1630762162921986,"[-0.1128472158540858, 0.03865374812868179, -0...."
9,"[2.542790356101613, 2.5250355251847223, 2.4625...",True,"[[[[-0.27 0.09 0.94], [-0.49 -0.44 -0.1 ], [...",2,0.0,0.164227977381465,"[-0.11083128334973702, -0.04036468540993522, 0..."
