# Imports

In [1]:
import math
import pandas as pd
import pennylane as qml
import time

from matplotlib import pyplot as plt
from pennylane import numpy as np
from pennylane.templates import AmplitudeEmbedding, AngleEmbedding
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Model Params

In [2]:
np.random.seed(42)
initial_params = np.random.random([45])

INITIALIZATION_METHOD = 'Angle'
BATCH_SIZE = 20
EPOCHS = 400

STEP_SIZE = 0.01
BETA_1 = 0.9
BETA_2 = 0.99
EPSILON = 0.00000001

TRAINING_SIZE = 0.78
VALIDATION_SIZE = 0.07
TEST_SIZE = 1-TRAINING_SIZE-VALIDATION_SIZE

initial_time = time.time()

# Fetch Dataset

In [3]:
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784')

# Import dataset

In [4]:
examples = mnist.data
classes = mnist.target

x = []
y = []
for (example, label) in zip(examples, classes):
    if label in ["0", "2", "4", "6", "8"]:
        x.append(example)
        y.append(-1)
    else:
        x.append(example)
        y.append(1)

In [5]:
x = np.array(x)
y = np.array(y)

# Normalize pixels values
x = x / 255

X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=TEST_SIZE, shuffle=True)

In [6]:
validation_indexes = np.random.random_integers(len(X_train), size=(math.floor(len(X_train)*VALIDATION_SIZE),))
X_validation = [X_train[n] for n in validation_indexes]
y_validation = [y_train[n] for n in validation_indexes]

In [7]:
pca = PCA(n_components=8)
pca.fit(X_train)
X_train = pca.transform(X_train)
X_validation = pca.transform(X_validation)
X_test = pca.transform(X_test)

preprocessing_time = time.time()

# Circuit creation

In [8]:
device = qml.device("default.qubit", wires=8)

In [9]:
@qml.qnode(device)
def circuit(features, params):
    # Load state
    if INITIALIZATION_METHOD == 'Amplitude':
        AmplitudeEmbedding(features=features, wires=range(8), normalize=True, pad_with=0.)
    else:
        AngleEmbedding(features=features, wires=range(8), rotation='Y')

    # First layer
    qml.U3(params[0], params[1], params[2], wires=0)
    qml.U3(params[3], params[4], params[5], wires=1)
    qml.CNOT(wires=[0, 1])
    qml.U3(params[6], params[7], params[8], wires=2)
    qml.U3(params[9], params[10], params[11], wires=3)
    qml.CNOT(wires=[2, 3])
    qml.U3(params[12], params[13], params[14], wires=4)
    qml.U3(params[15], params[16], params[17], wires=5)
    qml.CNOT(wires=[5, 4])
    qml.U3(params[18], params[19], params[20], wires=6)
    qml.U3(params[21], params[22], params[23], wires=7)
    qml.CNOT(wires=[7, 6])

    # Second layer
    qml.U3(params[24], params[25], params[26], wires=1)
    qml.U3(params[27], params[28], params[29], wires=2)
    qml.CNOT(wires=[1, 2])
    qml.U3(params[30], params[31], params[32], wires=5)
    qml.U3(params[33], params[34], params[35], wires=6)
    qml.CNOT(wires=[6, 5])

    # Third layer
    qml.U3(params[36], params[37], params[38], wires=2)
    qml.U3(params[39], params[40], params[41], wires=5)
    qml.CNOT(wires=[2, 5])

    # Fourth layer
    qml.U3(params[42], params[43], params[44], wires=5)

    # Measurement
    return qml.expval(qml.PauliZ(5))

## Circuit example

In [10]:
features = X_train[0]
print(f"Inital parameters: {initial_params}\n")
print(f"Example features: {features}\n")
print(f"Expectation value: {circuit(features, initial_params)}\n")
print(circuit.draw())

Inital parameters: [0.37454012 0.95071431 0.73199394 0.59865848 0.15601864 0.15599452
 0.05808361 0.86617615 0.60111501 0.70807258 0.02058449 0.96990985
 0.83244264 0.21233911 0.18182497 0.18340451 0.30424224 0.52475643
 0.43194502 0.29122914 0.61185289 0.13949386 0.29214465 0.36636184
 0.45606998 0.78517596 0.19967378 0.51423444 0.59241457 0.04645041
 0.60754485 0.17052412 0.06505159 0.94888554 0.96563203 0.80839735
 0.30461377 0.09767211 0.68423303 0.44015249 0.12203823 0.49517691
 0.03438852 0.9093204  0.25877998]

Example features: [ 1.93579439 -2.6662852  -2.77117209 -0.50435562 -0.26535735 -0.75975054
  0.03295553 -0.9811965 ]

Expectation value: 0.002696441955207407

 0: ──RY(1.94)────Rot(0.732, 0.375, -0.732)───Rϕ(0.732)──Rϕ(0.951)───╭C──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤     
 1: ──RY(-2.67)───Rot(0.156, 0.599, -0.156)───Rϕ(0.156)──Rϕ(0.156)───╰X──Rot(

# Accuracy test definition

In [11]:
def measure_accuracy(x, y, circuit_params):
    class_errors = 0

    for example, example_class in zip(x, y):
        predicted_value = circuit(example, circuit_params)

        if (example_class > 0 and predicted_value <= 0) or (example_class <= 0 and predicted_value > 0):
            class_errors += 1

    return 1 - (class_errors/len(y))

# Training

In [12]:
params = initial_params
opt = qml.AdamOptimizer(stepsize=STEP_SIZE, beta1=BETA_1, beta2=BETA_2, eps=EPSILON)
test_accuracies = []
best_validation_accuracy = 0.0
best_params = []

for i in range(len(X_train)):
    features = X_train[i]
    expected_value = y_train[i]

    def cost(circuit_params):
        value = circuit(features, circuit_params)
        return ((expected_value - value) ** 2)/len(X_train)

    params = opt.step(cost, params)

    if i % BATCH_SIZE == 0:
        print(f"epoch {i//BATCH_SIZE}")

    if i % (10*BATCH_SIZE) == 0:
        current_accuracy = measure_accuracy(X_validation, y_validation, params)
        test_accuracies.append(current_accuracy)
        print(f"accuracy: {current_accuracy}")

        if current_accuracy > best_validation_accuracy:
            print("best accuracy so far!")
            best_validation_accuracy = current_accuracy
            best_params = params

    if len(test_accuracies) == 30:
        print(f"test_accuracies: {test_accuracies}")

        if np.allclose(best_validation_accuracy, test_accuracies[0]):
            params = best_params
            break

        del test_accuracies[0]

epoch 0
accuracy: 0.4938775510204082
best accuracy so far!
epoch 1
epoch 2
epoch 3
epoch 4
epoch 5
epoch 6
epoch 7
epoch 8
epoch 9
epoch 10
accuracy: 0.5282112845138055
best accuracy so far!
epoch 11
epoch 12
epoch 13
epoch 14
epoch 15
epoch 16
epoch 17
epoch 18
epoch 19
epoch 20
accuracy: 0.5495798319327732
best accuracy so far!
epoch 21
epoch 22
epoch 23
epoch 24
epoch 25
epoch 26
epoch 27
epoch 28
epoch 29
epoch 30
accuracy: 0.5567827130852341
best accuracy so far!
epoch 31
epoch 32
epoch 33
epoch 34
epoch 35
epoch 36
epoch 37
epoch 38
epoch 39
epoch 40
accuracy: 0.5515006002400961
epoch 41
epoch 42
epoch 43
epoch 44
epoch 45
epoch 46
epoch 47
epoch 48
epoch 49
epoch 50
accuracy: 0.5620648259303722
best accuracy so far!
epoch 51
epoch 52
epoch 53
epoch 54
epoch 55
epoch 56
epoch 57
epoch 58
epoch 59
epoch 60
accuracy: 0.5615846338535414
epoch 61
epoch 62
epoch 63
epoch 64
epoch 65
epoch 66
epoch 67
epoch 68
epoch 69
epoch 70
accuracy: 0.563265306122449
best accuracy so far!
epoch 71

epoch 341
epoch 342
epoch 343
epoch 344
epoch 345
epoch 346
epoch 347
epoch 348
epoch 349
epoch 350
accuracy: 0.5903961584633853
test_accuracies: [0.5615846338535414, 0.563265306122449, 0.5719087635054021, 0.5721488595438176, 0.575750300120048, 0.5822328931572629, 0.5927971188475389, 0.6019207683073229, 0.5815126050420167, 0.5687875150060024, 0.5932773109243697, 0.5863145258103242, 0.5841536614645859, 0.5887154861944778, 0.5747899159663865, 0.5899159663865546, 0.5841536614645859, 0.5959183673469388, 0.5923169267707082, 0.5884753901560624, 0.6016806722689076, 0.6055222088835535, 0.5937575030012006, 0.6033613445378152, 0.6105642256902761, 0.6069627851140456, 0.5841536614645859, 0.5903961584633853, 0.5887154861944778, 0.5903961584633853]
epoch 351
epoch 352
epoch 353
epoch 354
epoch 355
epoch 356
epoch 357
epoch 358
epoch 359
epoch 360
accuracy: 0.5834333733493398
test_accuracies: [0.563265306122449, 0.5719087635054021, 0.5721488595438176, 0.575750300120048, 0.5822328931572629, 0.59279711

epoch 451
epoch 452
epoch 453
epoch 454
epoch 455
epoch 456
epoch 457
epoch 458
epoch 459
epoch 460
accuracy: 0.5949579831932773
test_accuracies: [0.5863145258103242, 0.5841536614645859, 0.5887154861944778, 0.5747899159663865, 0.5899159663865546, 0.5841536614645859, 0.5959183673469388, 0.5923169267707082, 0.5884753901560624, 0.6016806722689076, 0.6055222088835535, 0.5937575030012006, 0.6033613445378152, 0.6105642256902761, 0.6069627851140456, 0.5841536614645859, 0.5903961584633853, 0.5887154861944778, 0.5903961584633853, 0.5834333733493398, 0.5939975990396158, 0.5966386554621849, 0.605282112845138, 0.5954381752701081, 0.5942376950780313, 0.5836734693877551, 0.5961584633853542, 0.5968787515006002, 0.5915966386554622, 0.5949579831932773]
epoch 461
epoch 462
epoch 463
epoch 464
epoch 465
epoch 466
epoch 467
epoch 468
epoch 469
epoch 470
accuracy: 0.5930372148859544
test_accuracies: [0.5841536614645859, 0.5887154861944778, 0.5747899159663865, 0.5899159663865546, 0.5841536614645859, 0.59591

epoch 561
epoch 562
epoch 563
epoch 564
epoch 565
epoch 566
epoch 567
epoch 568
epoch 569
epoch 570
accuracy: 0.604561824729892
test_accuracies: [0.5937575030012006, 0.6033613445378152, 0.6105642256902761, 0.6069627851140456, 0.5841536614645859, 0.5903961584633853, 0.5887154861944778, 0.5903961584633853, 0.5834333733493398, 0.5939975990396158, 0.5966386554621849, 0.605282112845138, 0.5954381752701081, 0.5942376950780313, 0.5836734693877551, 0.5961584633853542, 0.5968787515006002, 0.5915966386554622, 0.5949579831932773, 0.5930372148859544, 0.5894357743097238, 0.5774309723889556, 0.5903961584633853, 0.5872749099639856, 0.5920768307322929, 0.5987995198079232, 0.5793517406962785, 0.5978391356542617, 0.6112845138055223, 0.604561824729892]
epoch 571
epoch 572
epoch 573
epoch 574
epoch 575
epoch 576
epoch 577
epoch 578
epoch 579
epoch 580
accuracy: 0.6062424969987995
test_accuracies: [0.6033613445378152, 0.6105642256902761, 0.6069627851140456, 0.5841536614645859, 0.5903961584633853, 0.5887154

epoch 671
epoch 672
epoch 673
epoch 674
epoch 675
epoch 676
epoch 677
epoch 678
epoch 679
epoch 680
accuracy: 0.6158463385354142
best accuracy so far!
test_accuracies: [0.605282112845138, 0.5954381752701081, 0.5942376950780313, 0.5836734693877551, 0.5961584633853542, 0.5968787515006002, 0.5915966386554622, 0.5949579831932773, 0.5930372148859544, 0.5894357743097238, 0.5774309723889556, 0.5903961584633853, 0.5872749099639856, 0.5920768307322929, 0.5987995198079232, 0.5793517406962785, 0.5978391356542617, 0.6112845138055223, 0.604561824729892, 0.6062424969987995, 0.5983193277310924, 0.6040816326530613, 0.611764705882353, 0.5935174069627851, 0.6038415366146459, 0.6062424969987995, 0.5899159663865546, 0.6048019207683073, 0.5915966386554622, 0.6158463385354142]
epoch 681
epoch 682
epoch 683
epoch 684
epoch 685
epoch 686
epoch 687
epoch 688
epoch 689
epoch 690
accuracy: 0.6189675870348139
best accuracy so far!
test_accuracies: [0.5954381752701081, 0.5942376950780313, 0.5836734693877551, 0.596

epoch 781
epoch 782
epoch 783
epoch 784
epoch 785
epoch 786
epoch 787
epoch 788
epoch 789
epoch 790
accuracy: 0.6153661464585835
test_accuracies: [0.5903961584633853, 0.5872749099639856, 0.5920768307322929, 0.5987995198079232, 0.5793517406962785, 0.5978391356542617, 0.6112845138055223, 0.604561824729892, 0.6062424969987995, 0.5983193277310924, 0.6040816326530613, 0.611764705882353, 0.5935174069627851, 0.6038415366146459, 0.6062424969987995, 0.5899159663865546, 0.6048019207683073, 0.5915966386554622, 0.6158463385354142, 0.6189675870348139, 0.6158463385354142, 0.6153661464585835, 0.6036014405762304, 0.5975990396158464, 0.6100840336134454, 0.5973589435774309, 0.60984393757503, 0.6112845138055223, 0.6160864345738295, 0.6153661464585835]
epoch 791
epoch 792
epoch 793
epoch 794
epoch 795
epoch 796
epoch 797
epoch 798
epoch 799
epoch 800
accuracy: 0.6132052821128451
test_accuracies: [0.5872749099639856, 0.5920768307322929, 0.5987995198079232, 0.5793517406962785, 0.5978391356542617, 0.61128451

epoch 895
epoch 896
epoch 897
epoch 898
epoch 899
epoch 900
accuracy: 0.6156062424969988
test_accuracies: [0.611764705882353, 0.5935174069627851, 0.6038415366146459, 0.6062424969987995, 0.5899159663865546, 0.6048019207683073, 0.5915966386554622, 0.6158463385354142, 0.6189675870348139, 0.6158463385354142, 0.6153661464585835, 0.6036014405762304, 0.5975990396158464, 0.6100840336134454, 0.5973589435774309, 0.60984393757503, 0.6112845138055223, 0.6160864345738295, 0.6153661464585835, 0.6132052821128451, 0.6064825930372149, 0.6024009603841536, 0.6067226890756303, 0.60984393757503, 0.6148859543817526, 0.6028811524609844, 0.6038415366146459, 0.6062424969987995, 0.594717887154862, 0.6156062424969988]
epoch 901
epoch 902
epoch 903
epoch 904
epoch 905
epoch 906
epoch 907
epoch 908
epoch 909
epoch 910
accuracy: 0.6062424969987995
test_accuracies: [0.5935174069627851, 0.6038415366146459, 0.6062424969987995, 0.5899159663865546, 0.6048019207683073, 0.5915966386554622, 0.6158463385354142, 0.6189675870

epoch 1011
epoch 1012
epoch 1013
epoch 1014
epoch 1015
epoch 1016
epoch 1017
epoch 1018
epoch 1019
epoch 1020
accuracy: 0.6057623049219688
test_accuracies: [0.5975990396158464, 0.6100840336134454, 0.5973589435774309, 0.60984393757503, 0.6112845138055223, 0.6160864345738295, 0.6153661464585835, 0.6132052821128451, 0.6064825930372149, 0.6024009603841536, 0.6067226890756303, 0.60984393757503, 0.6148859543817526, 0.6028811524609844, 0.6038415366146459, 0.6062424969987995, 0.594717887154862, 0.6156062424969988, 0.6062424969987995, 0.6081632653061224, 0.594717887154862, 0.6208883553421369, 0.6182472989195678, 0.6084033613445379, 0.6, 0.6062424969987995, 0.6079231692677071, 0.6091236494597839, 0.6076830732292917, 0.6057623049219688]
epoch 1021
epoch 1022
epoch 1023
epoch 1024
epoch 1025
epoch 1026
epoch 1027
epoch 1028
epoch 1029
epoch 1030
accuracy: 0.6081632653061224
test_accuracies: [0.6100840336134454, 0.5973589435774309, 0.60984393757503, 0.6112845138055223, 0.6160864345738295, 0.6153661

accuracy: 0.6088835534213686
test_accuracies: [0.60984393757503, 0.6148859543817526, 0.6028811524609844, 0.6038415366146459, 0.6062424969987995, 0.594717887154862, 0.6156062424969988, 0.6062424969987995, 0.6081632653061224, 0.594717887154862, 0.6208883553421369, 0.6182472989195678, 0.6084033613445379, 0.6, 0.6062424969987995, 0.6079231692677071, 0.6091236494597839, 0.6076830732292917, 0.6057623049219688, 0.6081632653061224, 0.607202881152461, 0.605282112845138, 0.6079231692677071, 0.6105642256902761, 0.6148859543817526, 0.6158463385354142, 0.6158463385354142, 0.6168067226890757, 0.6031212484993997, 0.6088835534213686]
epoch 1131
epoch 1132
epoch 1133
epoch 1134
epoch 1135
epoch 1136
epoch 1137
epoch 1138
epoch 1139
epoch 1140
accuracy: 0.6144057623049219
test_accuracies: [0.6148859543817526, 0.6028811524609844, 0.6038415366146459, 0.6062424969987995, 0.594717887154862, 0.6156062424969988, 0.6062424969987995, 0.6081632653061224, 0.594717887154862, 0.6208883553421369, 0.6182472989195678,

epoch 1241
epoch 1242
epoch 1243
epoch 1244
epoch 1245
epoch 1246
epoch 1247
epoch 1248
epoch 1249
epoch 1250
accuracy: 0.6088835534213686
test_accuracies: [0.6084033613445379, 0.6, 0.6062424969987995, 0.6079231692677071, 0.6091236494597839, 0.6076830732292917, 0.6057623049219688, 0.6081632653061224, 0.607202881152461, 0.605282112845138, 0.6079231692677071, 0.6105642256902761, 0.6148859543817526, 0.6158463385354142, 0.6158463385354142, 0.6168067226890757, 0.6031212484993997, 0.6088835534213686, 0.6144057623049219, 0.6093637454981993, 0.6110444177671068, 0.6009603841536615, 0.6031212484993997, 0.6074429771908764, 0.6225690276110445, 0.6256902761104441, 0.6055222088835535, 0.6079231692677071, 0.6165666266506602, 0.6088835534213686]
epoch 1251
epoch 1252
epoch 1253
epoch 1254
epoch 1255
epoch 1256
epoch 1257
epoch 1258
epoch 1259
epoch 1260
accuracy: 0.5932773109243697
test_accuracies: [0.6, 0.6062424969987995, 0.6079231692677071, 0.6091236494597839, 0.6076830732292917, 0.6057623049219688

epoch 1351
epoch 1352
epoch 1353
epoch 1354
epoch 1355
epoch 1356
epoch 1357
epoch 1358
epoch 1359
epoch 1360
accuracy: 0.6079231692677071
test_accuracies: [0.6105642256902761, 0.6148859543817526, 0.6158463385354142, 0.6158463385354142, 0.6168067226890757, 0.6031212484993997, 0.6088835534213686, 0.6144057623049219, 0.6093637454981993, 0.6110444177671068, 0.6009603841536615, 0.6031212484993997, 0.6074429771908764, 0.6225690276110445, 0.6256902761104441, 0.6055222088835535, 0.6079231692677071, 0.6165666266506602, 0.6088835534213686, 0.5932773109243697, 0.5968787515006002, 0.5978391356542617, 0.5963985594237695, 0.6160864345738295, 0.6132052821128451, 0.6206482593037215, 0.6156062424969988, 0.6031212484993997, 0.605282112845138, 0.6079231692677071]
epoch 1361
epoch 1362
epoch 1363
epoch 1364
epoch 1365
epoch 1366
epoch 1367
epoch 1368
epoch 1369
epoch 1370
accuracy: 0.6105642256902761
test_accuracies: [0.6148859543817526, 0.6158463385354142, 0.6158463385354142, 0.6168067226890757, 0.60312

epoch 1461
epoch 1462
epoch 1463
epoch 1464
epoch 1465
epoch 1466
epoch 1467
epoch 1468
epoch 1469
epoch 1470
accuracy: 0.6093637454981993
test_accuracies: [0.6031212484993997, 0.6074429771908764, 0.6225690276110445, 0.6256902761104441, 0.6055222088835535, 0.6079231692677071, 0.6165666266506602, 0.6088835534213686, 0.5932773109243697, 0.5968787515006002, 0.5978391356542617, 0.5963985594237695, 0.6160864345738295, 0.6132052821128451, 0.6206482593037215, 0.6156062424969988, 0.6031212484993997, 0.605282112845138, 0.6079231692677071, 0.6105642256902761, 0.6050420168067228, 0.60984393757503, 0.6105642256902761, 0.6088835534213686, 0.6132052821128451, 0.6112845138055223, 0.6076830732292917, 0.6043217286914766, 0.6112845138055223, 0.6093637454981993]
epoch 1471
epoch 1472
epoch 1473
epoch 1474
epoch 1475
epoch 1476
epoch 1477
epoch 1478
epoch 1479
epoch 1480
accuracy: 0.611764705882353
test_accuracies: [0.6074429771908764, 0.6225690276110445, 0.6256902761104441, 0.6055222088835535, 0.60792316

In [13]:
print("Optimized rotation angles: {}".format(params))

training_time = time.time()

Optimized rotation angles: [ 8.16680788e-01  9.50714306e-01  2.54205756e+00  2.26974006e+00
 -9.38713189e-01 -1.19546656e+00 -9.89243082e-01  9.87876116e-01
  5.74850983e-01  6.47741476e-01  6.66363553e-01  2.77556984e+00
  8.10930445e-01  2.39541726e+00  9.68276568e-01  1.12534577e+00
 -3.46551887e-01  4.13082897e-01  8.13601935e-01 -5.57710908e-01
 -4.45044580e-01  1.07926360e+00  2.92144649e-01 -9.10173907e-08
  1.98420843e+00  7.85175961e-01 -8.57772360e-01 -8.56064182e-01
  4.96245505e-01  1.68150383e-01  1.21750056e+00  1.30180499e-01
 -5.85742537e-01  2.32182792e+00  9.65632033e-01  3.73810393e+00
  8.76709993e-01  9.76721140e-02  6.91246867e-02  1.16688905e+00
  7.90299808e-02  1.29380304e-01 -8.11356206e-01  9.09320402e-01
  5.43731427e-02]


# Testing

In [14]:
accuracy = measure_accuracy(X_test, y_test, params)
print(accuracy)

test_time = time.time()

0.6249523809523809


In [15]:
print(f"pre-processing time: {preprocessing_time-initial_time}")
print(f"training time: {training_time - preprocessing_time}")
print(f"test time: {test_time - training_time}")
print(f"total time: {test_time - initial_time}")

pre-processing time: 75.44581437110901
training time: 28726.242733716965
test time: 518.111909866333
total time: 29319.800457954407
