# QML model that learns to distinguish 5 different handwritten digits with 7 input dimensions

This notebook uses Data ReUploading, to learn to classify 5 different handwirtten digits.

In [1]:
%matplotlib inline
%config InlineBackend.figure_format='retina'

from pennylane.optimize import AdamOptimizer, GradientDescentOptimizer

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

import numpy as np


from scipy.sparse.linalg import expm
from scipy.sparse import coo_matrix, csc_matrix, diags, identity


import jax.numpy as jnp
from jax import grad, vmap, jit
from jax import random

import jax

import scipy

from tqdm import tqdm

from sklearn.utils import shuffle

import qutip as q

from tqdm import tqdm


import itertools


from sklearn.utils import shuffle
from sklearn import datasets, svm, metrics
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

We again first define the spin matrices and the functions that act on a state. I will use a spin length of 4 which yields 9 states. I will use every second qudit state to label the digits.

In [2]:
l = 2 # spin length

# let's put together spin matrices
dim_qudit = 2*l+1
qudit_range = np.arange(l, -(l+1),-1)

Id  =  scipy.sparse.csc_matrix(identity(dim_qudit))
Lx  =  scipy.sparse.csc_matrix(1/2*diags([np.sqrt([(l-m+1)*(l+m) for m in qudit_range[:-1]]), np.sqrt([(l+m+1)*(l-m) for m in qudit_range[1:]]) ], [-1, 1]))
Lz  =  scipy.sparse.csc_matrix(diags([qudit_range], [0]))
Lz2 = Lz.multiply(Lz)
Lx2 = Lx.multiply(Lx)

Lx = Lx.toarray()
Lz = Lz.toarray()
Lz2 = Lz2.toarray()


Lx = jnp.array(Lx)
Lz = jnp.array(Lz)
Lz2 = jnp.array(Lz2)



In [3]:
def Rx(psi, theta, conj=False):
    Ux = jax.scipy.linalg.expm(-1j*theta*Lx)
    psi = jnp.dot(Ux,psi)
    return psi

def Rz(psi, theta, conj=False):
    Uz = jax.scipy.linalg.expm(-1j*theta*Lz)
    psi = jnp.dot(Uz,psi)
    return psi

def Rz2(psi, theta, conj=False):
    Uz2 = jax.scipy.linalg.expm(-1j*theta*Lz2)
    psi = jnp.dot(Uz2,psi)
    return psi

def Initialization(psi :np.array, x1: float, x2: float, x3: float, x4: float, x5: float, x6, x7)-> np.array:
    psi = Rz(psi, x1)
    psi = Rx(psi, x2)
    psi = Rz(psi, x3)
    psi = Rx(psi, x4)
    psi = Rz(psi, x5)
    psi = Rz(psi, x6)
    psi = Rx(psi, x7)
    return psi


def varaince_z(psi):
    a = psi.T.conj()@Lz@psi
    b = psi.T.conj()@Lz**2@psi
    
    return jnp.abs((b-a**2).real)

def measure(psi):
    prob = psi**2
    measurement = np.random.choice(np.arange(len(state)), p=prob)
    return measurement

def expect(psi, oper):
    return (psi.T.conj()@oper@psi).real



In [4]:
state_labels = [[i] for i in range(2*l+1)]
used_labels = [[0], [1], [2], [3], [4]]

def cost_circ(params, x, y, beta, return_ind, state_labels=used_labels):

    loss_overlap = 0.0
    loss_variance = 0.0
    dm_labels = state_labels
    
    overlap, psi = circ(params, x, y, var_return=True)
    loss_overlap +=(1 - overlap**2)

    loss_variance += (varaince_z(psi)**2)*beta
    loss = loss_overlap + loss_variance
    if return_ind:
        return loss, loss_variance, loss_overlap
    return loss 


def circ(params, x, y, var_return=False):

    psi = 1j*jnp.zeros(int(l*2+1))
    #This is necessary since jnp arrays are immutable
    psi = jax.ops.index_add(psi, 0, 1+1j*0)
    
    label = 1j*jnp.zeros(int(l*2+1))
    #This is necessary since jnp arrays are immutable
    label = jax.ops.index_add(label, y, 1+1j*0)

    for i in range(0, len(params)-1):
    
        psi = Initialization(psi, x[0], x[1], x[2], x[3], x[4], x[5], x[6])
    
        psi = Rx(psi, params[i][0]) 
        
        psi = Rz(psi, params[i][1])
        
        psi = Rx(psi, params[i][2]) 
        
        


        psi = Rz2(psi, params[i][3])

    
    psi = Rx(psi, params[i+1][0]) 

    psi = Rz(psi, params[i+1][1])

    psi = Rx(psi, params[i+1][2]) 
        
    
    if var_return:
        return  jnp.abs(jnp.dot(psi, label)), psi
    return jnp.abs(jnp.dot(psi, label))


def test(params, x, state_labels=used_labels):
    
    fidelities = jnp.array([circ(params, x, dm[0]) for dm in state_labels])
    
    best_fidel = jnp.argmax(fidelities)


    return best_fidel

The cost function is the following:

\begin{align}
\chi_f(\theta)= \sum_{i=1}^{m} (1- \langle\tilde{\psi}_s|\psi(\vec{x}, \vec{\theta})\rangle^2)
\end{align}


In [5]:
def accuracy_score(y_true, y_pred):

    score = y_true == y_pred
    return score.sum() / len(y_true)



def iterate_minibatches(inputs, targets, batch_size):

    for start_idx in range(0, inputs.shape[0] - batch_size + 1, batch_size):
        idxs = slice(start_idx, start_idx + batch_size)
        yield inputs[idxs], targets[idxs]

In [6]:
def plot_confusion_matrix(cm,
                          target_names,
                          title='Confusion matrix',
                          cmap=None,
                          normalize=True):
    

    accuracy = np.trace(cm) / np.sum(cm).astype('float')
    misclass = 1 - accuracy

    if cmap is None:
        cmap = plt.get_cmap('Blues')

    plt.figure(figsize=(8, 6))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    if target_names is not None:
        tick_marks = np.arange(len(target_names))
        plt.xticks(tick_marks, target_names, rotation=45)
        plt.yticks(tick_marks, target_names)

    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]


    thresh = cm.max() / 1.5 if normalize else cm.max() / 2
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        if normalize:
            plt.text(j, i, "{:0.4f}".format(cm[i, j]),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")
        else:
            plt.text(j, i, "{:,}".format(cm[i, j]),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")


    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass))

Th following functions are the jitted, vectorized versions of the cost and test function.

In [7]:
@jit
def vmap_cost(params, X_batched, y_batched, beta=0.3):
    return vmap(cost_circ,  in_axes=(None, 0, 0, None, None))(params, X_batched, y_batched, beta, False).sum()/len(X_batched)

@jit
def vmap_cost_ind(params, X_batched, y_batched, beta=0.3):
    return vmap(cost_circ,  in_axes=(None, 0, 0, None, None))(params, X_batched, y_batched, beta, True)

@jit
def vmap_test(params, X_batched, state_labels):
    return vmap(test,  in_axes=(None, 0, None))(params, X_batched, state_labels)



We can now load the MNIST Handwritten dataset and pick 5 numbers out of this dataset. The numbers will correspond to the following labels:

* 0 labelstate = $|4\rangle$
* 1 labelstate = $|0\rangle$
* 7 labelstate = $|1\rangle$
* 2 labelstate = $|3\rangle$
* 4 labelstate = $|2\rangle$

Since the numbers consist of $8 \times 8$ pixels, and this is a little bit too much of input data we can use a Principle Component Analysis to reduce the dimensionality of the dataset to whatever size we want. I chose a dimensionality of $7$ in this notebook. 




In [8]:
digits = datasets.load_digits()

zeros = digits["images"][digits["target"]==0]
sevens = digits["images"][digits["target"]==7]
ones = digits["images"][digits["target"]==1]
twos = digits["images"][digits["target"]==2]
fours = digits["images"][digits["target"]==4]

zeros_y = digits["target"][digits["target"]==0]
sevens_y = digits["target"][digits["target"]==7]
ones_y = digits["target"][digits["target"]==1]
twos_y = digits["target"][digits["target"]==2]
fours_y = digits["target"][digits["target"]==4]

data = digits["images"]
targets = digits["target"]

x_data = np.vstack((zeros, sevens))
x_data = np.vstack((x_data, ones))
x_data = np.vstack((x_data, twos))
x_data = np.vstack((x_data, fours))

y_data = np.hstack((zeros_y, sevens_y))
y_data = np.hstack((y_data, ones_y))
y_data = np.hstack((y_data, twos_y))
y_data = np.hstack((y_data, fours_y))



pca = PCA(n_components=7)
print(data.shape)

x_data = pca.fit_transform(x_data.reshape(x_data.shape[0], 64))

x_data = x_data/np.amax(x_data)

y_data[y_data==2] = 3
y_data[y_data==4] = 2
y_data[y_data==0] = 4
y_data[y_data==1] = 0
y_data[y_data==7] = 1


print(np.unique(y_data))

#used_labels = [[0], [2], [4], [6], [8]]

np.unique(used_labels)

(1797, 8, 8)
[0 1 2 3 4]


array([0, 1, 2, 3, 4])

We then need the gradient of the cost function.

In [9]:
cost_circ_grad = grad(vmap_cost, argnums=0)

Now we have everything we need and can start learning. We first need to create the dataset and choose random parameters to start. 

In [10]:
batch_size = 30

x_data, y_data = shuffle(x_data, y_data)


train_x, test_x, train_y, test_y = train_test_split(x_data, y_data)
used_labels_arr = np.array(used_labels)[:, 0]

In [11]:
np.random.seed(2)
#layers = [2, 4, 6, 8, 10]
layers = [3, 4, 5, 6, 7, 8, 9]
iterations = 10
#learning_Rates = [0.05, 0.01, 0.005, 0.001, 0.0005]


beta = 0
def training(j):
    print("--------------------------------------")
    print(f"Starting with {j} layers")
    for m in tqdm(range(7, iterations)):
        
        params = np.random.uniform(size=(j, 4))*np.pi/2
        params[:, 3] = 0
        
        
        losses = []
        train_acc = []
        param_list = []

        epochs = 500

        learning_rate = 0.005
        opt = AdamOptimizer(learning_rate)
        
        train_x, test_x, train_y, test_y = train_test_split(x_data, y_data, shuffle=True, random_state=j*m+m)

        for i in range(epochs):
            #print("--------------------------------------")
            #print(f"Starting Epoch {i+1}, learning rate: {learning_rate:.2f}")

            for X_Batch, Y_Batch in iterate_minibatches(train_x, train_y, batch_size):
                #print(X_Batch.shape, Y_Batch.shape)
                params = opt.step(vmap_cost, params, grad_fn=cost_circ_grad, X_batched=X_Batch, y_batched=Y_Batch, beta=beta)
                param_list.append(params)
            pred_train = vmap_test(params, test_x, used_labels)
            pred_train = np.array(pred_train)
            for n, prediction in enumerate(pred_train):
                pred_train[n] = used_labels_arr[prediction]
            acc = accuracy_score(test_y, pred_train)
            train_acc.append(acc)
            loss = vmap_cost(params, train_x, train_y, beta)
            losses.append(loss)
            
        np.save(f"500e0-005lr_7Dim/params_{j}-layers-{m}.npy", param_list)
        np.save(f"500e0-005lr_7Dim/loss_{j}-layers-{m}.npy", losses)
        np.save(f"500e0-005lr_7Dim/acc_{j}-layers-{m}.npy", train_acc)
        np.save(f"500e0-005lr_7Dim/pred_{j}-layers-{m}.npy", pred_train)
    return True

    

In [12]:
import numpy as np
from multiprocessing import Pool

with Pool(1) as p:
        print(p.map(training, layers))


--------------------------------------
Starting with 7 layers


100%|██████████| 3/3 [58:10<00:00, 1163.35s/it]


[True]


[3, 4, 5, 6, 7, 8, 9]