# Imports

In [1]:
import math
import pandas as pd
import pennylane as qml
import time

from keras.datasets import mnist
from matplotlib import pyplot as plt
from pennylane import numpy as np
from pennylane.templates import AmplitudeEmbedding, AngleEmbedding
from pennylane.templates.subroutines import ArbitraryUnitary
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Model Params

In [2]:
np.random.seed(131)
initial_params = np.random.random([45])

INITIALIZATION_METHOD = 'Angle'
BATCH_SIZE = 20
EPOCHS = 400

STEP_SIZE = 0.01
BETA_1 = 0.9
BETA_2 = 0.99
EPSILON = 0.00000001

TRAINING_SIZE = 0.78
VALIDATION_SIZE = 0.07
TEST_SIZE = 1-TRAINING_SIZE-VALIDATION_SIZE

initial_time = time.time()

# Import dataset

In [3]:
(train_X, train_y), (test_X, test_y) = mnist.load_data()
examples = np.append(train_X, test_X, axis=0)
examples = examples.reshape(70000, 28*28)
classes = np.append(train_y, test_y)

In [4]:
x = []
y = []
for (example, label) in zip(examples, classes):
    if label in [0, 1, 2, 3]:
        x.append(example)
        y.append(-1)
    else:
        x.append(example)
        y.append(1)

In [5]:
x = np.array(x)
y = np.array(y)

# Normalize pixels values
x = x / 255

X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=TEST_SIZE, shuffle=True)

In [6]:
validation_indexes = np.random.random_integers(len(X_train), size=(math.floor(len(X_train)*VALIDATION_SIZE),))
X_validation = [X_train[n] for n in validation_indexes]
y_validation = [y_train[n] for n in validation_indexes]

In [7]:
pca = PCA(n_components=8)
pca.fit(X_train)
X_train = pca.transform(X_train)
X_validation = pca.transform(X_validation)
X_test = pca.transform(X_test)

preprocessing_time = time.time()

# Circuit creation

In [8]:
device = qml.device("default.qubit", wires=8)

In [9]:
@qml.qnode(device)
def circuit(features, params):
    # Load state
    if INITIALIZATION_METHOD == 'Amplitude':
        AmplitudeEmbedding(features=features, wires=range(8), normalize=True, pad_with=0.)
    else:
        AngleEmbedding(features=features, wires=range(8), rotation='Y')

    # First layer
    qml.U3(params[0], params[1], params[2], wires=0)
    qml.U3(params[3], params[4], params[5], wires=1)
    qml.CNOT(wires=[0, 1])
    qml.U3(params[6], params[7], params[8], wires=2)
    qml.U3(params[9], params[10], params[11], wires=3)
    qml.CNOT(wires=[2, 3])
    qml.U3(params[12], params[13], params[14], wires=4)
    qml.U3(params[15], params[16], params[17], wires=5)
    qml.CNOT(wires=[5, 4])
    qml.U3(params[18], params[19], params[20], wires=6)
    qml.U3(params[21], params[22], params[23], wires=7)
    qml.CNOT(wires=[7, 6])

    # Second layer
    qml.U3(params[24], params[25], params[26], wires=1)
    qml.U3(params[27], params[28], params[29], wires=2)
    qml.CNOT(wires=[1, 2])
    qml.U3(params[30], params[31], params[32], wires=5)
    qml.U3(params[33], params[34], params[35], wires=6)
    qml.CNOT(wires=[6, 5])

    # Third layer
    qml.U3(params[36], params[37], params[38], wires=2)
    qml.U3(params[39], params[40], params[41], wires=5)
    qml.CNOT(wires=[2, 5])

    # Fourth layer
    qml.U3(params[42], params[43], params[44], wires=5)

    # Measurement
    return qml.expval(qml.PauliZ(5))

## Circuit example

In [10]:
features = X_train[0]
print(f"Inital parameters: {initial_params}\n")
print(f"Example features: {features}\n")
print(f"Expectation value: {circuit(features, initial_params)}\n")
print(circuit.draw())

Inital parameters: [0.65015361 0.94810917 0.38802889 0.64129616 0.69051205 0.12660931
 0.23946678 0.25415707 0.42644165 0.83900255 0.74503365 0.38067928
 0.26169292 0.05333379 0.43689638 0.20897912 0.59441102 0.09890353
 0.22409353 0.5842624  0.95908107 0.20988382 0.66133746 0.50261295
 0.32029143 0.12506485 0.80688893 0.98696002 0.54304141 0.23132314
 0.60351254 0.17669598 0.88653747 0.58902228 0.72117264 0.27567029
 0.78811469 0.1326223  0.39971595 0.62982409 0.42404345 0.16187284
 0.52034418 0.6070413  0.5808057 ]

Example features: [ 1.64521311  1.2675723  -0.46002775  5.2635177  -0.75769887 -0.39391971
  1.99552138  0.41458452]

Expectation value: 0.0507315826660395

 0: ──RY(1.65)────Rot(0.388, 0.65, -0.388)─────Rϕ(0.388)───Rϕ(0.948)───╭C───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤     
 1: ──RY(1.27)────Rot(0.127, 0.641, -0.127)────Rϕ(0.127)───Rϕ(0.691)───╰X──Rot(0.807

# Accuracy test definition

In [11]:
def measure_accuracy(x, y, circuit_params):
    class_errors = 0

    for example, example_class in zip(x, y):
        predicted_value = circuit(example, circuit_params)

        if (example_class > 0 and predicted_value <= 0) or (example_class <= 0 and predicted_value > 0):
            class_errors += 1

    return 1 - (class_errors/len(y))

# Training

In [12]:
params = initial_params
opt = qml.AdamOptimizer(stepsize=STEP_SIZE, beta1=BETA_1, beta2=BETA_2, eps=EPSILON)
test_accuracies = []
best_validation_accuracy = 0.0
best_params = []

for i in range(len(X_train)):
    features = X_train[i]
    expected_value = y_train[i]

    def cost(circuit_params):
        value = circuit(features, circuit_params)
        return ((expected_value - value) ** 2)/len(X_train)

    params = opt.step(cost, params)

    if i % BATCH_SIZE == 0:
        print(f"epoch {i//BATCH_SIZE}")

    if i % (10*BATCH_SIZE) == 0:
        current_accuracy = measure_accuracy(X_validation, y_validation, params)
        test_accuracies.append(current_accuracy)
        print(f"accuracy: {current_accuracy}")

        if current_accuracy > best_validation_accuracy:
            print("best accuracy so far!")
            best_validation_accuracy = current_accuracy
            best_params = params

    if len(test_accuracies) == 30:
        print(f"test_accuracies: {test_accuracies}")

        if np.allclose(best_validation_accuracy, test_accuracies[0]):
            params = best_params
            break

        del test_accuracies[0]

epoch 0
accuracy: 0.46962785114045613
best accuracy so far!
epoch 1
epoch 2
epoch 3
epoch 4
epoch 5
epoch 6
epoch 7
epoch 8
epoch 9
epoch 10
accuracy: 0.5270108043217288
best accuracy so far!
epoch 11
epoch 12
epoch 13
epoch 14
epoch 15
epoch 16
epoch 17
epoch 18
epoch 19
epoch 20
accuracy: 0.5296518607442977
best accuracy so far!
epoch 21
epoch 22
epoch 23
epoch 24
epoch 25
epoch 26
epoch 27
epoch 28
epoch 29
epoch 30
accuracy: 0.5611044417767107
best accuracy so far!
epoch 31
epoch 32
epoch 33
epoch 34
epoch 35
epoch 36
epoch 37
epoch 38
epoch 39
epoch 40
accuracy: 0.5759903961584634
best accuracy so far!
epoch 41
epoch 42
epoch 43
epoch 44
epoch 45
epoch 46
epoch 47
epoch 48
epoch 49
epoch 50
accuracy: 0.5771908763505402
best accuracy so far!
epoch 51
epoch 52
epoch 53
epoch 54
epoch 55
epoch 56
epoch 57
epoch 58
epoch 59
epoch 60
accuracy: 0.5903961584633853
best accuracy so far!
epoch 61
epoch 62
epoch 63
epoch 64
epoch 65
epoch 66
epoch 67
epoch 68
epoch 69
epoch 70
accuracy: 0.5

epoch 341
epoch 342
epoch 343
epoch 344
epoch 345
epoch 346
epoch 347
epoch 348
epoch 349
epoch 350
accuracy: 0.6232893157262905
best accuracy so far!
test_accuracies: [0.5903961584633853, 0.5697478991596638, 0.5721488595438176, 0.5891956782713086, 0.5923169267707082, 0.5920768307322929, 0.5815126050420167, 0.5992797118847539, 0.5867947178871549, 0.5726290516206483, 0.5891956782713086, 0.602641056422569, 0.6064825930372149, 0.6012004801920768, 0.5944777911164465, 0.6074429771908764, 0.600720288115246, 0.5937575030012006, 0.6048019207683073, 0.6012004801920768, 0.6165666266506602, 0.6110444177671068, 0.6074429771908764, 0.6024009603841536, 0.6067226890756303, 0.6064825930372149, 0.6122448979591837, 0.6180072028811525, 0.6180072028811525, 0.6232893157262905]
epoch 351
epoch 352
epoch 353
epoch 354
epoch 355
epoch 356
epoch 357
epoch 358
epoch 359
epoch 360
accuracy: 0.62953181272509
best accuracy so far!
test_accuracies: [0.5697478991596638, 0.5721488595438176, 0.5891956782713086, 0.5923

epoch 451
epoch 452
epoch 453
epoch 454
epoch 455
epoch 456
epoch 457
epoch 458
epoch 459
epoch 460
accuracy: 0.6273709483793517
test_accuracies: [0.602641056422569, 0.6064825930372149, 0.6012004801920768, 0.5944777911164465, 0.6074429771908764, 0.600720288115246, 0.5937575030012006, 0.6048019207683073, 0.6012004801920768, 0.6165666266506602, 0.6110444177671068, 0.6074429771908764, 0.6024009603841536, 0.6067226890756303, 0.6064825930372149, 0.6122448979591837, 0.6180072028811525, 0.6180072028811525, 0.6232893157262905, 0.62953181272509, 0.6304921968787516, 0.6290516206482593, 0.6103241296518607, 0.6237695078031212, 0.6391356542617046, 0.631452581032413, 0.6280912364945979, 0.624249699879952, 0.6254501800720288, 0.6273709483793517]
epoch 461
epoch 462
epoch 463
epoch 464
epoch 465
epoch 466
epoch 467
epoch 468
epoch 469
epoch 470
accuracy: 0.6372148859543818
test_accuracies: [0.6064825930372149, 0.6012004801920768, 0.5944777911164465, 0.6074429771908764, 0.600720288115246, 0.59375750300

epoch 565
epoch 566
epoch 567
epoch 568
epoch 569
epoch 570
accuracy: 0.6304921968787516
test_accuracies: [0.6074429771908764, 0.6024009603841536, 0.6067226890756303, 0.6064825930372149, 0.6122448979591837, 0.6180072028811525, 0.6180072028811525, 0.6232893157262905, 0.62953181272509, 0.6304921968787516, 0.6290516206482593, 0.6103241296518607, 0.6237695078031212, 0.6391356542617046, 0.631452581032413, 0.6280912364945979, 0.624249699879952, 0.6254501800720288, 0.6273709483793517, 0.6372148859543818, 0.6280912364945979, 0.624249699879952, 0.6316926770708283, 0.6328931572629052, 0.6309723889555823, 0.6302521008403361, 0.6333733493397359, 0.628811524609844, 0.6319327731092437, 0.6304921968787516]
epoch 571
epoch 572
epoch 573
epoch 574
epoch 575
epoch 576
epoch 577
epoch 578
epoch 579
epoch 580
accuracy: 0.6290516206482593
test_accuracies: [0.6024009603841536, 0.6067226890756303, 0.6064825930372149, 0.6122448979591837, 0.6180072028811525, 0.6180072028811525, 0.6232893157262905, 0.6295318127

epoch 679
epoch 680
accuracy: 0.6427370948379352
best accuracy so far!
test_accuracies: [0.6103241296518607, 0.6237695078031212, 0.6391356542617046, 0.631452581032413, 0.6280912364945979, 0.624249699879952, 0.6254501800720288, 0.6273709483793517, 0.6372148859543818, 0.6280912364945979, 0.624249699879952, 0.6316926770708283, 0.6328931572629052, 0.6309723889555823, 0.6302521008403361, 0.6333733493397359, 0.628811524609844, 0.6319327731092437, 0.6304921968787516, 0.6290516206482593, 0.6331332533013205, 0.6350540216086434, 0.6213685474189676, 0.6278511404561824, 0.6362545018007203, 0.6307322929171668, 0.6278511404561824, 0.6324129651860744, 0.6357743097238895, 0.6427370948379352]
epoch 681
epoch 682
epoch 683
epoch 684
epoch 685
epoch 686
epoch 687
epoch 688
epoch 689
epoch 690
accuracy: 0.6319327731092437
test_accuracies: [0.6237695078031212, 0.6391356542617046, 0.631452581032413, 0.6280912364945979, 0.624249699879952, 0.6254501800720288, 0.6273709483793517, 0.6372148859543818, 0.62809123

epoch 788
epoch 789
epoch 790
accuracy: 0.6271308523409364
test_accuracies: [0.6316926770708283, 0.6328931572629052, 0.6309723889555823, 0.6302521008403361, 0.6333733493397359, 0.628811524609844, 0.6319327731092437, 0.6304921968787516, 0.6290516206482593, 0.6331332533013205, 0.6350540216086434, 0.6213685474189676, 0.6278511404561824, 0.6362545018007203, 0.6307322929171668, 0.6278511404561824, 0.6324129651860744, 0.6357743097238895, 0.6427370948379352, 0.6319327731092437, 0.6278511404561824, 0.6206482593037215, 0.6312124849939976, 0.6309723889555823, 0.6321728691476591, 0.6309723889555823, 0.6297719087635054, 0.6345738295318127, 0.6247298919567827, 0.6271308523409364]
epoch 791
epoch 792
epoch 793
epoch 794
epoch 795
epoch 796
epoch 797
epoch 798
epoch 799
epoch 800
accuracy: 0.6276110444177672
test_accuracies: [0.6328931572629052, 0.6309723889555823, 0.6302521008403361, 0.6333733493397359, 0.628811524609844, 0.6319327731092437, 0.6304921968787516, 0.6290516206482593, 0.6331332533013205

epoch 898
epoch 899
epoch 900
accuracy: 0.6326530612244898
test_accuracies: [0.6213685474189676, 0.6278511404561824, 0.6362545018007203, 0.6307322929171668, 0.6278511404561824, 0.6324129651860744, 0.6357743097238895, 0.6427370948379352, 0.6319327731092437, 0.6278511404561824, 0.6206482593037215, 0.6312124849939976, 0.6309723889555823, 0.6321728691476591, 0.6309723889555823, 0.6297719087635054, 0.6345738295318127, 0.6247298919567827, 0.6271308523409364, 0.6276110444177672, 0.6211284513805522, 0.6230492196878752, 0.631452581032413, 0.626890756302521, 0.6357743097238895, 0.6338535414165667, 0.6379351740696279, 0.6393757503001201, 0.6328931572629052, 0.6326530612244898]
epoch 901
epoch 902
epoch 903
epoch 904
epoch 905
epoch 906
epoch 907
epoch 908
epoch 909
epoch 910
accuracy: 0.6340936374549819
test_accuracies: [0.6278511404561824, 0.6362545018007203, 0.6307322929171668, 0.6278511404561824, 0.6324129651860744, 0.6357743097238895, 0.6427370948379352, 0.6319327731092437, 0.6278511404561824

epoch 1007
epoch 1008
epoch 1009
epoch 1010
accuracy: 0.6436974789915966
test_accuracies: [0.6312124849939976, 0.6309723889555823, 0.6321728691476591, 0.6309723889555823, 0.6297719087635054, 0.6345738295318127, 0.6247298919567827, 0.6271308523409364, 0.6276110444177672, 0.6211284513805522, 0.6230492196878752, 0.631452581032413, 0.626890756302521, 0.6357743097238895, 0.6338535414165667, 0.6379351740696279, 0.6393757503001201, 0.6328931572629052, 0.6326530612244898, 0.6340936374549819, 0.6475390156062425, 0.6352941176470588, 0.636734693877551, 0.6211284513805522, 0.6422569027611045, 0.637454981992797, 0.646578631452581, 0.6429771908763505, 0.6376950780312125, 0.6436974789915966]
epoch 1011
epoch 1012
epoch 1013
epoch 1014
epoch 1015
epoch 1016
epoch 1017
epoch 1018
epoch 1019
epoch 1020
accuracy: 0.6372148859543818
test_accuracies: [0.6309723889555823, 0.6321728691476591, 0.6309723889555823, 0.6297719087635054, 0.6345738295318127, 0.6247298919567827, 0.6271308523409364, 0.627611044417767

epoch 1111
epoch 1112
epoch 1113
epoch 1114
epoch 1115
epoch 1116
epoch 1117
epoch 1118
epoch 1119
epoch 1120
accuracy: 0.64921968787515
test_accuracies: [0.631452581032413, 0.626890756302521, 0.6357743097238895, 0.6338535414165667, 0.6379351740696279, 0.6393757503001201, 0.6328931572629052, 0.6326530612244898, 0.6340936374549819, 0.6475390156062425, 0.6352941176470588, 0.636734693877551, 0.6211284513805522, 0.6422569027611045, 0.637454981992797, 0.646578631452581, 0.6429771908763505, 0.6376950780312125, 0.6436974789915966, 0.6372148859543818, 0.6340936374549819, 0.6276110444177672, 0.6300120048019208, 0.646578631452581, 0.6381752701080432, 0.6451380552220889, 0.6405762304921969, 0.646578631452581, 0.6523409363745498, 0.64921968787515]
epoch 1121
epoch 1122
epoch 1123
epoch 1124
epoch 1125
epoch 1126
epoch 1127
epoch 1128
epoch 1129
epoch 1130
accuracy: 0.6480192076830733
test_accuracies: [0.626890756302521, 0.6357743097238895, 0.6338535414165667, 0.6379351740696279, 0.6393757503001201

epoch 1221
epoch 1222
epoch 1223
epoch 1224
epoch 1225
epoch 1226
epoch 1227
epoch 1228
epoch 1229
epoch 1230
accuracy: 0.6448979591836734
test_accuracies: [0.636734693877551, 0.6211284513805522, 0.6422569027611045, 0.637454981992797, 0.646578631452581, 0.6429771908763505, 0.6376950780312125, 0.6436974789915966, 0.6372148859543818, 0.6340936374549819, 0.6276110444177672, 0.6300120048019208, 0.646578631452581, 0.6381752701080432, 0.6451380552220889, 0.6405762304921969, 0.646578631452581, 0.6523409363745498, 0.64921968787515, 0.6480192076830733, 0.653061224489796, 0.6545018007202881, 0.646578631452581, 0.6540216086434574, 0.6386554621848739, 0.6496998799519809, 0.6460984393757503, 0.6506602641056423, 0.6525810324129652, 0.6448979591836734]
epoch 1231
epoch 1232
epoch 1233
epoch 1234
epoch 1235
epoch 1236
epoch 1237
epoch 1238
epoch 1239
epoch 1240
accuracy: 0.6501800720288116
test_accuracies: [0.6211284513805522, 0.6422569027611045, 0.637454981992797, 0.646578631452581, 0.642977190876350

epoch 1331
epoch 1332
epoch 1333
epoch 1334
epoch 1335
epoch 1336
epoch 1337
epoch 1338
epoch 1339
epoch 1340
accuracy: 0.6391356542617046
test_accuracies: [0.6300120048019208, 0.646578631452581, 0.6381752701080432, 0.6451380552220889, 0.6405762304921969, 0.646578631452581, 0.6523409363745498, 0.64921968787515, 0.6480192076830733, 0.653061224489796, 0.6545018007202881, 0.646578631452581, 0.6540216086434574, 0.6386554621848739, 0.6496998799519809, 0.6460984393757503, 0.6506602641056423, 0.6525810324129652, 0.6448979591836734, 0.6501800720288116, 0.6504201680672269, 0.651140456182473, 0.6432172869147659, 0.6496998799519809, 0.648499399759904, 0.62953181272509, 0.6304921968787516, 0.6331332533013205, 0.6372148859543818, 0.6391356542617046]
epoch 1341
epoch 1342
epoch 1343
epoch 1344
epoch 1345
epoch 1346
epoch 1347
epoch 1348
epoch 1349
epoch 1350
accuracy: 0.6331332533013205
test_accuracies: [0.646578631452581, 0.6381752701080432, 0.6451380552220889, 0.6405762304921969, 0.646578631452581

In [13]:
print("Optimized rotation angles: {}".format(params))

training_time = time.time()

Optimized rotation angles: [ 0.54442105  0.94810917 -0.04531381  2.1977282   1.32078135  0.61475414
  0.56989393  0.30233714  0.59857158  2.07907191  3.53077682  3.82646467
 -0.36404656 -0.75502631  0.61696244  3.00324139  0.06137447  0.51320395
  0.57398783  0.32243283  0.07787437  0.91820119  0.66133746  0.01443033
 -0.57393487  0.12506485  3.19608375  0.66767786  2.15701089  0.27950322
  0.77658027 -0.07174518  0.35350092  0.76613862  0.72117264 -0.51184977
  0.62611711  0.1326223   0.38004169  1.00413993 -0.13533982  0.06913828
  0.61414005  0.6070413  -0.02900983]


# Testing

In [14]:
accuracy = measure_accuracy(X_test, y_test, params)
print(accuracy)

test_time = time.time()

0.6573333333333333


In [15]:
print(f"pre-processing time: {preprocessing_time-initial_time}")
print(f"training time: {training_time - preprocessing_time}")
print(f"test time: {test_time - training_time}")
print(f"total time: {test_time - initial_time}")

pre-processing time: 11.026961326599121
training time: 9180.524121046066
test time: 125.15459513664246
total time: 9316.705677509308
