Train an MPS model on a GHZ state using measurement data taken in random local bases.

In [1]:
import numpy as np
import torch

In [2]:
from models import MPS, ComplexTensor
from utils import build_ghz_plus

Choose the system size (number of sites), the dimension of the local Hilbert space, and the initial bond dimension.

In [3]:
# System size (number of sites), local Hilbert-space dimension, and initial MPS bond dimension.
L = 12
local_dim = 2
bond_dim = 2

# Seed the RNGs so that any random model initialization and the DataLoader shuffling
# (shuffle=True) are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [4]:
# Compute device. Set USE_CUDA = True to use the first GPU when one is available.
# NOTE(review): later cells call `.numpy()` on model amplitudes, which may require CPU tensors.
USE_CUDA = False
dev = torch.device("cuda:0" if USE_CUDA and torch.cuda.is_available() else "cpu")

Initialize the MPS model

In [5]:
# MPS ansatz over L sites with the chosen local dimension and initial bond dimension.
psi = MPS(bond_dim=bond_dim, local_dim=local_dim, L=L)

In [6]:
# Move the model to the selected device. The return value is not stored, so this presumably
# relies on `.to` acting in place (as torch.nn.Module.to does) -- TODO confirm for MPS.
psi.to(device=dev)

The ground-truth state: a GHZ state with no relative phase.

In [7]:
# Ground-truth GHZ+ state on L sites; it is later passed to training as `ground_truth_mps`.
ghz_plus = build_ghz_plus(L)
# Same device as the model. Return value discarded -- presumably an in-place move; confirm.
ghz_plus.to(device=dev)

`samples_tr`: the index of the observed local basis state at each site for every measurement (0 = spin up, 1 = spin down). The raw outcomes $s = \pm 1$ are mapped to indices via $(1 - s)/2$.

`settings`: the corresponding measurement angles; `settings[..., 0]` holds $\theta$ and `settings[..., 1]` holds $\phi$.

In [8]:
# Pre-sampled measurement records for the GHZ+ state on L sites.
# Alternative dataset paths:
#   "test_datasets/settings_ghz_plus_random_basis_L=%d.npy" / "test_datasets/samples_ghz_plus_random_basis_L=%d.npy"
fname_settings = "datasets/mps_sampled/ghz_plus_L=%d_angles.npy" % L
fname_samples = "datasets/mps_sampled/ghz_plus_L=%d_outcomes.npy" % L
Nsamp = 20000  # number of measurement records to use (fewer if the files are shorter)

samples = np.load(fname_samples)[:Nsamp]
settings = np.load(fname_settings)[:Nsamp]

# Every outcome record must pair with exactly one settings record.
assert len(samples) == len(settings), (
    "got %d outcome records but %d settings records" % (len(samples), len(settings)))
# (1 - s) / 2 maps +1 -> 0 (spin up) and -1 -> 1 (spin down). Any other value (e.g. 0/1 data)
# would be silently truncated to a wrong index by the cast to long, so reject it up front.
assert np.isin(samples, (-1, 1)).all(), (
    "expected measurement outcomes in {-1, +1} in %s" % fname_samples)
samples_tr = torch.tensor((1 - samples) / 2).to(dtype=torch.long, device=dev)

In [9]:
# Split the measurement angles along the last axis: component 0 -> theta, component 1 -> phi.
theta, phi = [
    torch.tensor(settings[..., k], dtype=torch.float32, device=dev)
    for k in (0, 1)
]

In [10]:
from qtools import pauli_exp

`U` holds the unitary corresponding to each $(\theta, \phi)$ angle pair.

In [11]:
# Measurement-basis unitaries, plus their real and imaginary parts as separate names.
U = pauli_exp(theta, phi)
rotations_real, rotations_imag = U.real, U.imag

In [12]:
from torch.utils.data import TensorDataset, DataLoader

In [13]:
from utils import MeasurementDataset,do_local_sgd_training

A dataset which yields outcomes and corresponding rotations

In [14]:
# Pair each encoded outcome record with its measurement unitaries.
ds = MeasurementDataset(rotations=U, samples=samples_tr)

Batch size and learning rate for training:

In [15]:
# Mini-batch size and SGD learning rate.
# NOTE(review): 1028 is an unusual batch size -- possibly intended as 1024; confirm.
batch_size, lr = 1028, 1e-2

In [16]:
# Shuffled mini-batches over the measurement dataset.
dl = DataLoader(ds, shuffle=True, batch_size=batch_size)

Number of epochs of training

In [17]:
epochs: int = 10  # full passes over the DataLoader

Number of gradient-descent steps to perform at each bond

In [18]:
nstep: int = 1  # gradient steps per bond

Maximum number of singular values to keep, and the cutoff below which singular values are truncated

In [19]:
# Upper bound on kept singular values, and the truncation threshold for small ones.
max_sv, cutoff = 20, 1e-3

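Regularization term: a penalty on the Rényi-2 entropy whose weight decays with the epoch as $e^{-\mathrm{epoch}}$. During epoch 0 at most 2 singular values are kept; afterwards up to `max_sv`.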

In [20]:
# Total number of mini-batch steps over the whole run.
Nstep = len(dl) * epochs

In [21]:
def s2_schedule(ep):
    """Weight of the Renyi-2 entropy penalty at epoch `ep`; decays as exp(-ep)."""
    return np.exp(-ep)


def max_sv_to_keep(ep):
    """Maximum number of singular values to keep during epoch `ep`.

    Epoch 0 is capped at 2; later epochs allow up to the global `max_sv`.
    """
    return 2 if ep < 1 else max_sv

# Alternative that disables the entropy penalty: s2_schedule = np.zeros(Nstep)

In [22]:
# Train psi on the measurement batches, tracking progress against the ground-truth state.
# NOTE(review): the saved output shows this run was interrupted (KeyboardInterrupt), in which
# case `logdict` is never assigned and the cells that read it fail.
logdict = do_local_sgd_training(
    psi,
    dl,
    epochs=epochs,
    learning_rate=lr,
    s2_schedule=s2_schedule,
    nstep=nstep,
    cutoff=cutoff,
    max_sv_to_keep=max_sv_to_keep,
    ground_truth_mps=ghz_plus,
    use_cache=True,
    verbose=True,
)

Finished epoch 0 in 3.939 sec
Model shape:  [(1, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 1)]
Finished epoch 1 in 13.560 sec
Model shape:  [(1, 2), (2, 4), (4, 8), (8, 16), (16, 20), (20, 20), (20, 20), (20, 16), (16, 8), (8, 4), (4, 2), (2, 1)]
Finished epoch 2 in 14.248 sec
Model shape:  [(1, 2), (2, 4), (4, 8), (8, 16), (16, 20), (20, 20), (20, 20), (20, 16), (16, 8), (8, 4), (4, 2), (2, 1)]


KeyboardInterrupt: 

In [None]:
# Training diagnostics recorded by do_local_sgd_training.
fidelity, losses = logdict['fidelity'], logdict['loss']

In [None]:
import matplotlib.pyplot as plt

In [None]:
import datetime

In [None]:
# Wall-clock timestamp used to tag plot titles.
t = format(datetime.datetime.now(), "%Y-%m-%d_%H:%M:%S")

In [None]:
%matplotlib inline
fig, ax = plt.subplots()
plt.plot(losses, label='model')
plt.legend()
plt.xlabel("training step")
plt.title("batch NLL loss %s" % t)
# fig.savefig("assets/nll_loss_example_{0}.png".format(t))

In [None]:
# Fidelity values logged during training, on a fresh figure with labels matching the loss plot.
fig, ax = plt.subplots()
ax.plot(fidelity)
ax.set_xlabel("training step")
ax.set_ylabel("fidelity")
ax.set_title("fidelity");

from tools import generate_binary_space

# Enumerated computational basis states as a long tensor; presumably one bitstring per row
# (2**L rows) -- generate_binary_space lives in `tools`, confirm its output layout there.
basis = torch.tensor(generate_binary_space(L),dtype=torch.long)

# Print the model's normalized probability for each basis state. Only states with probability
# >= prob_print_threshold are shown, so the output is not one line per basis state
# (2**L lines if `basis` enumerates all bitstrings); the rest are counted.
prob_print_threshold = 1e-3
with torch.no_grad():
    n_hidden = 0
    for state in basis:
        p = psi.prob_normalized(state).item()
        if p >= prob_print_threshold:
            print("{0} has probability {1:.4f}".format(state, p))
        else:
            n_hidden += 1
    print("({0} basis states with probability < {1:g} not shown)".format(n_hidden, prob_print_threshold))

# Compare model amplitudes with the ground-truth amplitudes. Only states where either magnitude
# is >= amp_print_threshold are shown, so the output is not one line per basis state
# (2**L lines if `basis` enumerates all bitstrings); the rest are counted.
amp_print_threshold = 1e-2
with torch.no_grad():
    n_hidden = 0
    for state in basis:
        a = psi.amplitude_normalized(state)
        atrue = ghz_plus.amplitude_normalized(state)
        a_val = a.numpy().item()
        atrue_val = atrue.numpy().item()
        if max(abs(a_val), abs(atrue_val)) >= amp_print_threshold:
            print("{0} has amplitude {1:.4f} (target: {2:.4f})".format(state, a_val, atrue_val))
        else:
            n_hidden += 1
    print("({0} basis states with |amplitude| < {1:g} for both model and target not shown)".format(
        n_hidden, amp_print_threshold))