## Training test
### A notebook for prototyping MPS training code


In [1]:
import numpy as np
import torch

from models import MPS, ComplexTensor
from utils import build_ghz_plus
from qtools import pauli_exp
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from utils import MeasurementDataset,do_local_sgd_training, evaluate, do_validation

import matplotlib.pyplot as plt
import datetime
from utils import make_linear_schedule, make_exp_schedule

Loaded libmkl_rt.so for dgesvd


Pick the system size, the dimension of the local Hilbert space, and the initial bond dimension.

In [2]:
# System size, local Hilbert-space dimension and initial bond dimension,
# in that order; all three are passed to the MPS constructor.
L, local_dim, bond_dim = 4, 2, 2

In [3]:
# Device on which the model and data tensors are placed. The CPU is hard-coded here;
# the commented-out line below selects the first CUDA device when one is available.
# dev = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
dev = torch.device("cpu")

Initialize the MPS model

In [4]:
# Build the MPS model from the system size, local dimension and bond dimension chosen above.
psi = MPS(L=L, local_dim=local_dim, bond_dim=bond_dim)

In [5]:
# Move the model to the selected device.
# NOTE(review): the return value is discarded, so this assumes MPS.to works in place — confirm.
psi.to(device=dev)

The ground-truth state (a GHZ state with no phase)

In [6]:
# Reference state for a system of L sites, used as the ground truth.
ghz_plus = build_ghz_plus(L)
# NOTE(review): the result of .to() is discarded; this only has an effect if
# .to() moves the object in place — confirm against build_ghz_plus's return type.
ghz_plus.to(device=dev)

`samples_tr`: lists the indices of the observed basis states for each measurement (0 = spin up, 1 = spin down).

`settings`: array of the corresponding measurement angles.

In [7]:
# Alternative dataset (random-basis measurements), kept for reference:
# fname_settings = "test_datasets/settings_ghz_plus_random_basis_L=%d.npy"%L
# fname_samples = "test_datasets/samples_ghz_plus_random_basis_L=%d.npy"%L
fname_settings = "datasets/mps_sampled/ghz_plus_L=%d_angles.npy" % L
fname_samples = "datasets/mps_sampled/ghz_plus_L=%d_outcomes.npy" % L

# Only the first Nsamp records of each file are used.
Nsamp = 20000
samples = np.load(fname_samples)[:Nsamp]
settings = np.load(fname_settings)[:Nsamp]

# Convert raw outcomes to integer basis-state indices via (1 - outcome) / 2.
# NOTE(review): presumably outcomes are stored as +1/-1, giving indices 0/1 — confirm.
samples_tr = torch.tensor((1 - samples) / 2).to(dtype=torch.long, device=dev)

In [8]:
# The last axis of `settings` holds the two angles: index 0 -> theta, index 1 -> phi.
# Both are converted to float32 tensors on the selected device.
theta, phi = (
    torch.tensor(settings[..., k], dtype=torch.float32, device=dev)
    for k in (0, 1)
)

`U` holds the unitaries corresponding to each angle setting.

In [9]:
# Build the rotation unitaries from the measurement angles, and expose their
# real and imaginary parts under separate names.
U = pauli_exp(theta, phi)
rotations_real, rotations_imag = U.real, U.imag

In [None]:
# NOTE(review): estimate_fidelity is not among the names imported in this notebook's
# import cell; it must be imported (or defined) before this cell can run — confirm
# which module provides it.
estimate_fidelity(ghz_plus, samples_tr, U)

A dataset which yields outcomes and corresponding rotations

In [10]:
# Size of the training split: 90% of the Nsamp loaded samples, truncated to an int.
# The remaining Nsamp - Ntr samples form the validation split.
Ntr=int(.9 * Nsamp)

In [11]:
# Pair every measurement outcome with the rotation it was measured under,
# then randomly partition the dataset into training and validation subsets.
ds = MeasurementDataset(samples=samples_tr, rotations=U)
split_sizes = [Ntr, Nsamp - Ntr]
ds_tr, ds_val = random_split(ds, split_sizes)

In [17]:
# Pull a slice out of the full dataset for inspection.
# NOTE(review): presumably (0, 20) selects records 0..19 — the displayed rotation
# shape [20, 4, 2, 2] is consistent with that; confirm against MeasurementDataset.unpack.
s,u=ds.unpack(0, 20)

Batch size for training (the learning-rate schedules are defined further below):

In [19]:
# Display the rotations returned by ds.unpack(0, 20).
u

(ComplexTensor shape torch.Size([20, 4, 2, 2]),)

In [None]:
# Number of samples per training batch; passed to do_validation.
# NOTE(review): 1028 is not a power of two — confirm whether 1024 was intended.
batch_size = 1028


Number of epochs of training

In [None]:
# Number of training epochs; passed to do_validation and also used to scale
# the timescales of the learning-rate and s2-penalty schedules.
epochs = 50

Max number of singular values to keep, and cutoff below which to truncate singular values

In [None]:
# Truncation settings: the cap on the number of singular values kept, and the
# cutoff below which singular values are truncated.
max_sv, cutoff = 10, 1e-4

In [None]:
Nsamp = 5

In [None]:
lr_scale = 10**np.random.uniform(-6, 0, Nsamp)
s2_scale = 10**np.random.uniform(-2, 1, Nsamp)
s2_timescale = np.random.uniform(.2, 1, Nsamp) * epochs
lr_timescale = np.random.uniform(.5, 1, Nsamp) * epochs

learning_rates = [make_exp_schedule(A, tau) for (A, tau) in zip(lr_scale, lr_timescale)]
s2_penalties = [make_exp_schedule(A, tau) for (A, tau) in zip(s2_scale, s2_timescale)]

In [None]:
# One hyperparameter dictionary per trial, pairing each learning-rate schedule
# with the s2-penalty schedule at the same position.
params = [
    {"learning_rate": lr, "s2_penalty": s2}
    for lr, s2 in zip(learning_rates, s2_penalties)
]

In [None]:
def max_sv_to_keep(ep):
    """Return the cap on the number of singular values to keep at epoch `ep`.

    The cap is constant: `ep` is ignored and the current value of the
    notebook-level `max_sv` is returned (looked up at call time).
    """
    return max_sv

In [None]:
# Run validation over every hyperparameter configuration in `params`,
# training on ds_tr and evaluating on ds_val.
validation_options = dict(
    cutoff=cutoff,
    max_sv_to_keep=max_sv_to_keep,
    use_cache=True,
    early_stopping=True,
    verbose=True,
)
scores = do_validation(ds_tr, ds_val, batch_size, epochs, params, **validation_options)

In [None]:
# Display the value returned by do_validation.
scores

In [None]:
lr_

fidelity = logdict['fidelity']
loss = logdict['loss']
max_bond_dim = logdict['max_bond_dim']
eigs = logdict['eigenvalues']
s2 = logdict['s2']
val_loss = logdict['val_loss']

t = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")

%matplotlib inline
fig, ax = plt.subplots()
plt.plot(loss, label='training set')
plt.plot(val_loss, label='val set')
plt.legend()
plt.xlabel("training step")
plt.title("batch NLL loss %s" % t)
# fig.savefig("assets/nll_loss_example_{0}.png".format(t))