# Imports

In [1]:
import math
import pandas as pd
import pennylane as qml
import time

from matplotlib import pyplot as plt
from pennylane import numpy as np
from pennylane.templates import AmplitudeEmbedding, AngleEmbedding
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Model Params

In [2]:
np.random.seed(42)
initial_params = np.random.random([15])

INITIALIZATION_METHOD = 'Angle'
BATCH_SIZE = 20
EPOCHS = 400

STEP_SIZE = 0.01
BETA_1 = 0.9
BETA_2 = 0.99
EPSILON = 0.00000001

TRAINING_SIZE = 0.78
VALIDATION_SIZE = 0.07
TEST_SIZE = 1-TRAINING_SIZE-VALIDATION_SIZE

initial_time = time.time()

# Fetch Dataset

In [3]:
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784')

# Import dataset

In [4]:
examples = mnist.data
classes = mnist.target

x = []
y = []
for (example, label) in zip(examples, classes):
    if label in ["0", "1", "2", "3"]:
        x.append(example)
        y.append(-1)
    else:
        x.append(example)
        y.append(1)

In [5]:
x = np.array(x)
y = np.array(y)

# Normalize pixels values
x = x / 255

X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=TEST_SIZE, shuffle=True)

In [6]:
validation_indexes = np.random.random_integers(len(X_train), size=(math.floor(len(X_train)*VALIDATION_SIZE),))
X_validation = [X_train[n] for n in validation_indexes]
y_validation = [y_train[n] for n in validation_indexes]

In [7]:
pca = PCA(n_components=8)
pca.fit(X_train)
X_train = pca.transform(X_train)
X_validation = pca.transform(X_validation)
X_test = pca.transform(X_test)

preprocessing_time = time.time()

# Circuit creation

In [8]:
device = qml.device("default.qubit", wires=8)

In [9]:
@qml.qnode(device)
def circuit(features, params):
    # Load state
    if INITIALIZATION_METHOD == 'Amplitude':
        AmplitudeEmbedding(features=features, wires=range(8), normalize=True, pad_with=0.)
    else:
        AngleEmbedding(features=features, wires=range(8), rotation='Y')

    # First layer
    qml.RY(params[0], wires=0)
    qml.RY(params[1], wires=1)
    qml.CNOT(wires=[0, 1])
    qml.RY(params[2], wires=2)
    qml.RY(params[3], wires=3)
    qml.CNOT(wires=[2, 3])
    qml.RY(params[4], wires=4)
    qml.RY(params[5], wires=5)
    qml.CNOT(wires=[5, 4])
    qml.RY(params[6], wires=6)
    qml.RY(params[7], wires=7)
    qml.CNOT(wires=[7, 6])

    # Second layer
    qml.RY(params[8], wires=1)
    qml.RY(params[9], wires=2)
    qml.CNOT(wires=[1, 2])
    qml.RY(params[10], wires=5)
    qml.RY(params[11], wires=6)
    qml.CNOT(wires=[6, 5])

    # Third layer
    qml.RY(params[12], wires=2)
    qml.RY(params[13], wires=5)
    qml.CNOT(wires=[2, 5])

    # Fourth layer
    qml.RY(params[14], wires=5)

    # Measurement
    return qml.expval(qml.PauliZ(5))

## Circuit example

In [10]:
features = X_train[0]
print(f"Inital parameters: {initial_params}\n")
print(f"Example features: {features}\n")
print(f"Expectation value: {circuit(features, initial_params)}\n")
print(circuit.draw())

Inital parameters: [0.37454012 0.95071431 0.73199394 0.59865848 0.15601864 0.15599452
 0.05808361 0.86617615 0.60111501 0.70807258 0.02058449 0.96990985
 0.83244264 0.21233911 0.18182497]

Example features: [ 3.38764876 -1.94605776  0.49431788  0.86600787  0.48087724 -0.11930587
 -0.66084286  2.41654873]

Expectation value: -0.005924716389768325

 0: ──RY(3.39)────RY(0.375)───╭C────────────────────────────────────────────┤     
 1: ──RY(-1.95)───RY(0.951)───╰X──RY(0.601)───╭C────────────────────────────┤     
 2: ──RY(0.494)───RY(0.732)───╭C──RY(0.708)───╰X──RY(0.832)──╭C─────────────┤     
 3: ──RY(0.866)───RY(0.599)───╰X─────────────────────────────│──────────────┤     
 4: ──RY(0.481)───RY(0.156)───╭X─────────────────────────────│──────────────┤     
 5: ──RY(-0.119)──RY(0.156)───╰C──RY(0.0206)──╭X──RY(0.212)──╰X──RY(0.182)──┤ ⟨Z⟩ 
 6: ──RY(-0.661)──RY(0.0581)──╭X──RY(0.97)────╰C────────────────────────────┤     
 7: ──RY(2.42)────RY(0.866)───╰C──────────────────────────────────────

# Accuracy test definition

In [11]:
def measure_accuracy(x, y, circuit_params):
    class_errors = 0

    for example, example_class in zip(x, y):
        predicted_value = circuit(example, circuit_params)

        if (example_class > 0 and predicted_value <= 0) or (example_class <= 0 and predicted_value > 0):
            class_errors += 1

    return 1 - (class_errors/len(y))

# Training

In [12]:
params = initial_params
opt = qml.AdamOptimizer(stepsize=STEP_SIZE, beta1=BETA_1, beta2=BETA_2, eps=EPSILON)
test_accuracies = []

for i in range(len(X_train)):
    features = X_train[i]
    expected_value = y_train[i]

    def cost(circuit_params):
        value = circuit(features, circuit_params)
        return ((expected_value - value) ** 2)/len(X_train)

    params = opt.step(cost, params)

    if i % BATCH_SIZE == 0:
        print(f"epoch {i//BATCH_SIZE}")
    
    if i % (10*BATCH_SIZE) == 0:
        current_accuracy = measure_accuracy(X_validation, y_validation, params)
        test_accuracies.append(current_accuracy)
        print(f"accuracy: {current_accuracy}")

    if len(test_accuracies) == 30:
        print(f"test_accuracies: {test_accuracies}")

        if np.allclose(test_accuracies, 30*[test_accuracies[-1]]):
            break

        del test_accuracies[0]

epoch 0
accuracy: 0.5277310924369748
epoch 1
epoch 2
epoch 3
epoch 4
epoch 5
epoch 6
epoch 7
epoch 8
epoch 9
epoch 10
accuracy: 0.5822328931572629
epoch 11
epoch 12
epoch 13
epoch 14
epoch 15
epoch 16
epoch 17
epoch 18
epoch 19
epoch 20
accuracy: 0.5932773109243697
epoch 21
epoch 22
epoch 23
epoch 24
epoch 25
epoch 26
epoch 27
epoch 28
epoch 29
epoch 30
accuracy: 0.5980792316926771
epoch 31
epoch 32
epoch 33
epoch 34
epoch 35
epoch 36
epoch 37
epoch 38
epoch 39
epoch 40
accuracy: 0.6172869147659064
epoch 41
epoch 42
epoch 43
epoch 44
epoch 45
epoch 46
epoch 47
epoch 48
epoch 49
epoch 50
accuracy: 0.6093637454981993
epoch 51
epoch 52
epoch 53
epoch 54
epoch 55
epoch 56
epoch 57
epoch 58
epoch 59
epoch 60
accuracy: 0.6069627851140456
epoch 61
epoch 62
epoch 63
epoch 64
epoch 65
epoch 66
epoch 67
epoch 68
epoch 69
epoch 70
accuracy: 0.6230492196878752
epoch 71
epoch 72
epoch 73
epoch 74
epoch 75
epoch 76
epoch 77
epoch 78
epoch 79
epoch 80
accuracy: 0.6292917166866747
epoch 81
epoch 82
ep

epoch 351
epoch 352
epoch 353
epoch 354
epoch 355
epoch 356
epoch 357
epoch 358
epoch 359
epoch 360
accuracy: 0.6477791116446578
test_accuracies: [0.6230492196878752, 0.6292917166866747, 0.6333733493397359, 0.6309723889555823, 0.6350540216086434, 0.6410564225690276, 0.6386554621848739, 0.62953181272509, 0.6355342136854742, 0.6432172869147659, 0.6460984393757503, 0.6506602641056423, 0.6518607442977191, 0.6429771908763505, 0.6436974789915966, 0.6477791116446578, 0.6388955582232894, 0.6432172869147659, 0.6482593037214885, 0.6458583433373349, 0.6441776710684274, 0.6475390156062425, 0.6460984393757503, 0.6564225690276111, 0.6518607442977191, 0.6501800720288116, 0.6470588235294117, 0.6581032412965186, 0.6552220888355342, 0.6477791116446578]
epoch 361
epoch 362
epoch 363
epoch 364
epoch 365
epoch 366
epoch 367
epoch 368
epoch 369
epoch 370
accuracy: 0.6585834333733493
test_accuracies: [0.6292917166866747, 0.6333733493397359, 0.6309723889555823, 0.6350540216086434, 0.6410564225690276, 0.638655

epoch 463
epoch 464
epoch 465
epoch 466
epoch 467
epoch 468
epoch 469
epoch 470
accuracy: 0.6516206482593037
test_accuracies: [0.6506602641056423, 0.6518607442977191, 0.6429771908763505, 0.6436974789915966, 0.6477791116446578, 0.6388955582232894, 0.6432172869147659, 0.6482593037214885, 0.6458583433373349, 0.6441776710684274, 0.6475390156062425, 0.6460984393757503, 0.6564225690276111, 0.6518607442977191, 0.6501800720288116, 0.6470588235294117, 0.6581032412965186, 0.6552220888355342, 0.6477791116446578, 0.6585834333733493, 0.648499399759904, 0.6542617046818727, 0.653781512605042, 0.6581032412965186, 0.6564225690276111, 0.651140456182473, 0.6509003601440576, 0.6472989195678271, 0.6482593037214885, 0.6516206482593037]
epoch 471
epoch 472
epoch 473
epoch 474
epoch 475
epoch 476
epoch 477
epoch 478
epoch 479
epoch 480
accuracy: 0.6516206482593037
test_accuracies: [0.6518607442977191, 0.6429771908763505, 0.6436974789915966, 0.6477791116446578, 0.6388955582232894, 0.6432172869147659, 0.6482593

epoch 575
epoch 576
epoch 577
epoch 578
epoch 579
epoch 580
accuracy: 0.653061224489796
test_accuracies: [0.6460984393757503, 0.6564225690276111, 0.6518607442977191, 0.6501800720288116, 0.6470588235294117, 0.6581032412965186, 0.6552220888355342, 0.6477791116446578, 0.6585834333733493, 0.648499399759904, 0.6542617046818727, 0.653781512605042, 0.6581032412965186, 0.6564225690276111, 0.651140456182473, 0.6509003601440576, 0.6472989195678271, 0.6482593037214885, 0.6516206482593037, 0.6516206482593037, 0.6501800720288116, 0.6369747899159663, 0.6496998799519809, 0.6472989195678271, 0.6521008403361345, 0.646578631452581, 0.6578631452581032, 0.6518607442977191, 0.6453781512605041, 0.653061224489796]
epoch 581
epoch 582
epoch 583
epoch 584
epoch 585
epoch 586
epoch 587
epoch 588
epoch 589
epoch 590
accuracy: 0.643937575030012
test_accuracies: [0.6564225690276111, 0.6518607442977191, 0.6501800720288116, 0.6470588235294117, 0.6581032412965186, 0.6552220888355342, 0.6477791116446578, 0.65858343337

epoch 690
accuracy: 0.6463385354141657
test_accuracies: [0.653781512605042, 0.6581032412965186, 0.6564225690276111, 0.651140456182473, 0.6509003601440576, 0.6472989195678271, 0.6482593037214885, 0.6516206482593037, 0.6516206482593037, 0.6501800720288116, 0.6369747899159663, 0.6496998799519809, 0.6472989195678271, 0.6521008403361345, 0.646578631452581, 0.6578631452581032, 0.6518607442977191, 0.6453781512605041, 0.653061224489796, 0.643937575030012, 0.6564225690276111, 0.6516206482593037, 0.6521008403361345, 0.6561824729891956, 0.6475390156062425, 0.6535414165666267, 0.6516206482593037, 0.6540216086434574, 0.6542617046818727, 0.6463385354141657]
epoch 691
epoch 692
epoch 693
epoch 694
epoch 695
epoch 696
epoch 697
epoch 698
epoch 699
epoch 700
accuracy: 0.6521008403361345
test_accuracies: [0.6581032412965186, 0.6564225690276111, 0.651140456182473, 0.6509003601440576, 0.6472989195678271, 0.6482593037214885, 0.6516206482593037, 0.6516206482593037, 0.6501800720288116, 0.6369747899159663, 0.

epoch 801
epoch 802
epoch 803
epoch 804
epoch 805
epoch 806
epoch 807
epoch 808
epoch 809
epoch 810
accuracy: 0.6619447779111645
test_accuracies: [0.6472989195678271, 0.6521008403361345, 0.646578631452581, 0.6578631452581032, 0.6518607442977191, 0.6453781512605041, 0.653061224489796, 0.643937575030012, 0.6564225690276111, 0.6516206482593037, 0.6521008403361345, 0.6561824729891956, 0.6475390156062425, 0.6535414165666267, 0.6516206482593037, 0.6540216086434574, 0.6542617046818727, 0.6463385354141657, 0.6521008403361345, 0.6593037214885955, 0.6593037214885955, 0.6648259303721489, 0.6595438175270107, 0.6540216086434574, 0.6569027611044418, 0.6408163265306123, 0.6542617046818727, 0.6518607442977191, 0.6518607442977191, 0.6619447779111645]
epoch 811
epoch 812
epoch 813
epoch 814
epoch 815
epoch 816
epoch 817
epoch 818
epoch 819
epoch 820
accuracy: 0.658343337334934
test_accuracies: [0.6521008403361345, 0.646578631452581, 0.6578631452581032, 0.6518607442977191, 0.6453781512605041, 0.653061224

epoch 913
epoch 914
epoch 915
epoch 916
epoch 917
epoch 918
epoch 919
epoch 920
accuracy: 0.6501800720288116
test_accuracies: [0.6561824729891956, 0.6475390156062425, 0.6535414165666267, 0.6516206482593037, 0.6540216086434574, 0.6542617046818727, 0.6463385354141657, 0.6521008403361345, 0.6593037214885955, 0.6593037214885955, 0.6648259303721489, 0.6595438175270107, 0.6540216086434574, 0.6569027611044418, 0.6408163265306123, 0.6542617046818727, 0.6518607442977191, 0.6518607442977191, 0.6619447779111645, 0.658343337334934, 0.648499399759904, 0.6535414165666267, 0.6564225690276111, 0.6581032412965186, 0.6561824729891956, 0.6578631452581032, 0.6585834333733493, 0.6427370948379352, 0.6581032412965186, 0.6501800720288116]
epoch 921
epoch 922
epoch 923
epoch 924
epoch 925
epoch 926
epoch 927
epoch 928
epoch 929
epoch 930
accuracy: 0.6453781512605041
test_accuracies: [0.6475390156062425, 0.6535414165666267, 0.6516206482593037, 0.6540216086434574, 0.6542617046818727, 0.6463385354141657, 0.652100

epoch 1022
epoch 1023
epoch 1024
epoch 1025
epoch 1026
epoch 1027
epoch 1028
epoch 1029
epoch 1030
accuracy: 0.6585834333733493
test_accuracies: [0.6595438175270107, 0.6540216086434574, 0.6569027611044418, 0.6408163265306123, 0.6542617046818727, 0.6518607442977191, 0.6518607442977191, 0.6619447779111645, 0.658343337334934, 0.648499399759904, 0.6535414165666267, 0.6564225690276111, 0.6581032412965186, 0.6561824729891956, 0.6578631452581032, 0.6585834333733493, 0.6427370948379352, 0.6581032412965186, 0.6501800720288116, 0.6453781512605041, 0.6518607442977191, 0.6472989195678271, 0.6496998799519809, 0.6458583433373349, 0.651140456182473, 0.6554621848739496, 0.6554621848739496, 0.6468187274909964, 0.6557022809123649, 0.6585834333733493]
epoch 1031
epoch 1032
epoch 1033
epoch 1034
epoch 1035
epoch 1036
epoch 1037
epoch 1038
epoch 1039
epoch 1040
accuracy: 0.6597839135654262
test_accuracies: [0.6540216086434574, 0.6569027611044418, 0.6408163265306123, 0.6542617046818727, 0.6518607442977191, 

epoch 1131
epoch 1132
epoch 1133
epoch 1134
epoch 1135
epoch 1136
epoch 1137
epoch 1138
epoch 1139
epoch 1140
accuracy: 0.6566626650660263
test_accuracies: [0.6564225690276111, 0.6581032412965186, 0.6561824729891956, 0.6578631452581032, 0.6585834333733493, 0.6427370948379352, 0.6581032412965186, 0.6501800720288116, 0.6453781512605041, 0.6518607442977191, 0.6472989195678271, 0.6496998799519809, 0.6458583433373349, 0.651140456182473, 0.6554621848739496, 0.6554621848739496, 0.6468187274909964, 0.6557022809123649, 0.6585834333733493, 0.6597839135654262, 0.6542617046818727, 0.6566626650660263, 0.6545018007202881, 0.6489795918367347, 0.6468187274909964, 0.653781512605042, 0.642016806722689, 0.6494597839135654, 0.6518607442977191, 0.6566626650660263]
epoch 1141
epoch 1142
epoch 1143
epoch 1144
epoch 1145
epoch 1146
epoch 1147
epoch 1148
epoch 1149
epoch 1150
accuracy: 0.6561824729891956
test_accuracies: [0.6581032412965186, 0.6561824729891956, 0.6578631452581032, 0.6585834333733493, 0.6427370

epoch 1241
epoch 1242
epoch 1243
epoch 1244
epoch 1245
epoch 1246
epoch 1247
epoch 1248
epoch 1249
epoch 1250
accuracy: 0.6499399759903961
test_accuracies: [0.6496998799519809, 0.6458583433373349, 0.651140456182473, 0.6554621848739496, 0.6554621848739496, 0.6468187274909964, 0.6557022809123649, 0.6585834333733493, 0.6597839135654262, 0.6542617046818727, 0.6566626650660263, 0.6545018007202881, 0.6489795918367347, 0.6468187274909964, 0.653781512605042, 0.642016806722689, 0.6494597839135654, 0.6518607442977191, 0.6566626650660263, 0.6561824729891956, 0.6523409363745498, 0.6547418967587035, 0.6496998799519809, 0.641296518607443, 0.6509003601440576, 0.6523409363745498, 0.6588235294117647, 0.6557022809123649, 0.6573829531812725, 0.6499399759903961]
epoch 1251
epoch 1252
epoch 1253
epoch 1254
epoch 1255
epoch 1256
epoch 1257
epoch 1258
epoch 1259
epoch 1260
accuracy: 0.6456182472989196
test_accuracies: [0.6458583433373349, 0.651140456182473, 0.6554621848739496, 0.6554621848739496, 0.646818727

epoch 1351
epoch 1352
epoch 1353
epoch 1354
epoch 1355
epoch 1356
epoch 1357
epoch 1358
epoch 1359
epoch 1360
accuracy: 0.6617046818727491
test_accuracies: [0.6545018007202881, 0.6489795918367347, 0.6468187274909964, 0.653781512605042, 0.642016806722689, 0.6494597839135654, 0.6518607442977191, 0.6566626650660263, 0.6561824729891956, 0.6523409363745498, 0.6547418967587035, 0.6496998799519809, 0.641296518607443, 0.6509003601440576, 0.6523409363745498, 0.6588235294117647, 0.6557022809123649, 0.6573829531812725, 0.6499399759903961, 0.6456182472989196, 0.6403361344537815, 0.6540216086434574, 0.6557022809123649, 0.6571428571428571, 0.6600240096038416, 0.6549819927971188, 0.6597839135654262, 0.6516206482593037, 0.6417767106842738, 0.6617046818727491]
epoch 1361
epoch 1362
epoch 1363
epoch 1364
epoch 1365
epoch 1366
epoch 1367
epoch 1368
epoch 1369
epoch 1370
accuracy: 0.6607442977190876
test_accuracies: [0.6489795918367347, 0.6468187274909964, 0.653781512605042, 0.642016806722689, 0.649459783

epoch 1461
epoch 1462
epoch 1463
epoch 1464
epoch 1465
epoch 1466
epoch 1467
epoch 1468
epoch 1469
epoch 1470
accuracy: 0.6593037214885955
test_accuracies: [0.6496998799519809, 0.641296518607443, 0.6509003601440576, 0.6523409363745498, 0.6588235294117647, 0.6557022809123649, 0.6573829531812725, 0.6499399759903961, 0.6456182472989196, 0.6403361344537815, 0.6540216086434574, 0.6557022809123649, 0.6571428571428571, 0.6600240096038416, 0.6549819927971188, 0.6597839135654262, 0.6516206482593037, 0.6417767106842738, 0.6617046818727491, 0.6607442977190876, 0.6499399759903961, 0.658343337334934, 0.65906362545018, 0.6470588235294117, 0.6617046818727491, 0.6552220888355342, 0.6561824729891956, 0.6581032412965186, 0.6547418967587035, 0.6593037214885955]
epoch 1471
epoch 1472
epoch 1473
epoch 1474
epoch 1475
epoch 1476
epoch 1477
epoch 1478
epoch 1479
epoch 1480
accuracy: 0.6468187274909964
test_accuracies: [0.641296518607443, 0.6509003601440576, 0.6523409363745498, 0.6588235294117647, 0.655702280

epoch 1571
epoch 1572
epoch 1573
epoch 1574
epoch 1575
epoch 1576
epoch 1577
epoch 1578
epoch 1579
epoch 1580
accuracy: 0.6379351740696279
test_accuracies: [0.6557022809123649, 0.6571428571428571, 0.6600240096038416, 0.6549819927971188, 0.6597839135654262, 0.6516206482593037, 0.6417767106842738, 0.6617046818727491, 0.6607442977190876, 0.6499399759903961, 0.658343337334934, 0.65906362545018, 0.6470588235294117, 0.6617046818727491, 0.6552220888355342, 0.6561824729891956, 0.6581032412965186, 0.6547418967587035, 0.6593037214885955, 0.6468187274909964, 0.6569027611044418, 0.6573829531812725, 0.6648259303721489, 0.6453781512605041, 0.6521008403361345, 0.648499399759904, 0.6470588235294117, 0.6525810324129652, 0.6535414165666267, 0.6379351740696279]
epoch 1581
epoch 1582
epoch 1583
epoch 1584
epoch 1585
epoch 1586
epoch 1587
epoch 1588
epoch 1589
epoch 1590
accuracy: 0.6482593037214885
test_accuracies: [0.6571428571428571, 0.6600240096038416, 0.6549819927971188, 0.6597839135654262, 0.65162064

epoch 1681
epoch 1682
epoch 1683
epoch 1684
epoch 1685
epoch 1686
epoch 1687
epoch 1688
epoch 1689
epoch 1690
accuracy: 0.6523409363745498
test_accuracies: [0.65906362545018, 0.6470588235294117, 0.6617046818727491, 0.6552220888355342, 0.6561824729891956, 0.6581032412965186, 0.6547418967587035, 0.6593037214885955, 0.6468187274909964, 0.6569027611044418, 0.6573829531812725, 0.6648259303721489, 0.6453781512605041, 0.6521008403361345, 0.648499399759904, 0.6470588235294117, 0.6525810324129652, 0.6535414165666267, 0.6379351740696279, 0.6482593037214885, 0.6617046818727491, 0.6561824729891956, 0.6542617046818727, 0.6429771908763505, 0.6545018007202881, 0.657623049219688, 0.6494597839135654, 0.6463385354141657, 0.6561824729891956, 0.6523409363745498]
epoch 1691
epoch 1692
epoch 1693
epoch 1694
epoch 1695
epoch 1696
epoch 1697
epoch 1698
epoch 1699
epoch 1700
accuracy: 0.6513805522208884
test_accuracies: [0.6470588235294117, 0.6617046818727491, 0.6552220888355342, 0.6561824729891956, 0.65810324

epoch 1791
epoch 1792
epoch 1793
epoch 1794
epoch 1795
epoch 1796
epoch 1797
epoch 1798
epoch 1799
epoch 1800
accuracy: 0.6501800720288116
test_accuracies: [0.6648259303721489, 0.6453781512605041, 0.6521008403361345, 0.648499399759904, 0.6470588235294117, 0.6525810324129652, 0.6535414165666267, 0.6379351740696279, 0.6482593037214885, 0.6617046818727491, 0.6561824729891956, 0.6542617046818727, 0.6429771908763505, 0.6545018007202881, 0.657623049219688, 0.6494597839135654, 0.6463385354141657, 0.6561824729891956, 0.6523409363745498, 0.6513805522208884, 0.6501800720288116, 0.6549819927971188, 0.6525810324129652, 0.6501800720288116, 0.6470588235294117, 0.6518607442977191, 0.657623049219688, 0.6573829531812725, 0.651140456182473, 0.6501800720288116]
epoch 1801
epoch 1802
epoch 1803
epoch 1804
epoch 1805
epoch 1806
epoch 1807
epoch 1808
epoch 1809
epoch 1810
accuracy: 0.6444177671068427
test_accuracies: [0.6453781512605041, 0.6521008403361345, 0.648499399759904, 0.6470588235294117, 0.652581032

epoch 1901
epoch 1902
epoch 1903
epoch 1904
epoch 1905
epoch 1906
epoch 1907
epoch 1908
epoch 1909
epoch 1910
accuracy: 0.657623049219688
test_accuracies: [0.6542617046818727, 0.6429771908763505, 0.6545018007202881, 0.657623049219688, 0.6494597839135654, 0.6463385354141657, 0.6561824729891956, 0.6523409363745498, 0.6513805522208884, 0.6501800720288116, 0.6549819927971188, 0.6525810324129652, 0.6501800720288116, 0.6470588235294117, 0.6518607442977191, 0.657623049219688, 0.6573829531812725, 0.651140456182473, 0.6501800720288116, 0.6444177671068427, 0.666266506602641, 0.6581032412965186, 0.6573829531812725, 0.663625450180072, 0.6518607442977191, 0.6602641056422569, 0.6554621848739496, 0.6578631452581032, 0.6566626650660263, 0.657623049219688]
epoch 1911
epoch 1912
epoch 1913
epoch 1914
epoch 1915
epoch 1916
epoch 1917
epoch 1918
epoch 1919
epoch 1920
accuracy: 0.6595438175270107
test_accuracies: [0.6429771908763505, 0.6545018007202881, 0.657623049219688, 0.6494597839135654, 0.646338535414

epoch 2011
epoch 2012
epoch 2013
epoch 2014
epoch 2015
epoch 2016
epoch 2017
epoch 2018
epoch 2019
epoch 2020
accuracy: 0.6477791116446578
test_accuracies: [0.6525810324129652, 0.6501800720288116, 0.6470588235294117, 0.6518607442977191, 0.657623049219688, 0.6573829531812725, 0.651140456182473, 0.6501800720288116, 0.6444177671068427, 0.666266506602641, 0.6581032412965186, 0.6573829531812725, 0.663625450180072, 0.6518607442977191, 0.6602641056422569, 0.6554621848739496, 0.6578631452581032, 0.6566626650660263, 0.657623049219688, 0.6595438175270107, 0.6557022809123649, 0.6595438175270107, 0.6578631452581032, 0.6451380552220889, 0.6521008403361345, 0.6573829531812725, 0.6549819927971188, 0.653781512605042, 0.6472989195678271, 0.6477791116446578]
epoch 2021
epoch 2022
epoch 2023
epoch 2024
epoch 2025
epoch 2026
epoch 2027
epoch 2028
epoch 2029
epoch 2030
accuracy: 0.6525810324129652
test_accuracies: [0.6501800720288116, 0.6470588235294117, 0.6518607442977191, 0.657623049219688, 0.65738295318

epoch 2121
epoch 2122
epoch 2123
epoch 2124
epoch 2125
epoch 2126
epoch 2127
epoch 2128
epoch 2129
epoch 2130
accuracy: 0.6458583433373349
test_accuracies: [0.6573829531812725, 0.663625450180072, 0.6518607442977191, 0.6602641056422569, 0.6554621848739496, 0.6578631452581032, 0.6566626650660263, 0.657623049219688, 0.6595438175270107, 0.6557022809123649, 0.6595438175270107, 0.6578631452581032, 0.6451380552220889, 0.6521008403361345, 0.6573829531812725, 0.6549819927971188, 0.653781512605042, 0.6472989195678271, 0.6477791116446578, 0.6525810324129652, 0.6602641056422569, 0.6554621848739496, 0.6523409363745498, 0.6340936374549819, 0.653781512605042, 0.6535414165666267, 0.651140456182473, 0.6494597839135654, 0.636014405762305, 0.6458583433373349]
epoch 2131
epoch 2132
epoch 2133
epoch 2134
epoch 2135
epoch 2136
epoch 2137
epoch 2138
epoch 2139
epoch 2140
accuracy: 0.6472989195678271
test_accuracies: [0.663625450180072, 0.6518607442977191, 0.6602641056422569, 0.6554621848739496, 0.65786314525

epoch 2231
epoch 2232
epoch 2233
epoch 2234
epoch 2235
epoch 2236
epoch 2237
epoch 2238
epoch 2239
epoch 2240
accuracy: 0.646578631452581
test_accuracies: [0.6578631452581032, 0.6451380552220889, 0.6521008403361345, 0.6573829531812725, 0.6549819927971188, 0.653781512605042, 0.6472989195678271, 0.6477791116446578, 0.6525810324129652, 0.6602641056422569, 0.6554621848739496, 0.6523409363745498, 0.6340936374549819, 0.653781512605042, 0.6535414165666267, 0.651140456182473, 0.6494597839135654, 0.636014405762305, 0.6458583433373349, 0.6472989195678271, 0.6470588235294117, 0.6475390156062425, 0.6410564225690276, 0.6521008403361345, 0.6506602641056423, 0.6525810324129652, 0.6554621848739496, 0.653061224489796, 0.6564225690276111, 0.646578631452581]
epoch 2241
epoch 2242
epoch 2243
epoch 2244
epoch 2245
epoch 2246
epoch 2247
epoch 2248
epoch 2249
epoch 2250
accuracy: 0.6516206482593037
test_accuracies: [0.6451380552220889, 0.6521008403361345, 0.6573829531812725, 0.6549819927971188, 0.65378151260

epoch 2341
epoch 2342
epoch 2343
epoch 2344
epoch 2345
epoch 2346
epoch 2347
epoch 2348
epoch 2349
epoch 2350
accuracy: 0.6444177671068427
test_accuracies: [0.6523409363745498, 0.6340936374549819, 0.653781512605042, 0.6535414165666267, 0.651140456182473, 0.6494597839135654, 0.636014405762305, 0.6458583433373349, 0.6472989195678271, 0.6470588235294117, 0.6475390156062425, 0.6410564225690276, 0.6521008403361345, 0.6506602641056423, 0.6525810324129652, 0.6554621848739496, 0.653061224489796, 0.6564225690276111, 0.646578631452581, 0.6516206482593037, 0.6535414165666267, 0.6585834333733493, 0.658343337334934, 0.6566626650660263, 0.6669867947178871, 0.6559423769507804, 0.6549819927971188, 0.6559423769507804, 0.6410564225690276, 0.6444177671068427]
epoch 2351
epoch 2352
epoch 2353
epoch 2354
epoch 2355
epoch 2356
epoch 2357
epoch 2358
epoch 2359
epoch 2360
accuracy: 0.6535414165666267
test_accuracies: [0.6340936374549819, 0.653781512605042, 0.6535414165666267, 0.651140456182473, 0.649459783913

epoch 2451
epoch 2452
epoch 2453
epoch 2454
epoch 2455
epoch 2456
epoch 2457
epoch 2458
epoch 2459
epoch 2460
accuracy: 0.6600240096038416
test_accuracies: [0.6410564225690276, 0.6521008403361345, 0.6506602641056423, 0.6525810324129652, 0.6554621848739496, 0.653061224489796, 0.6564225690276111, 0.646578631452581, 0.6516206482593037, 0.6535414165666267, 0.6585834333733493, 0.658343337334934, 0.6566626650660263, 0.6669867947178871, 0.6559423769507804, 0.6549819927971188, 0.6559423769507804, 0.6410564225690276, 0.6444177671068427, 0.6535414165666267, 0.6549819927971188, 0.6557022809123649, 0.6482593037214885, 0.6540216086434574, 0.6554621848739496, 0.6487394957983194, 0.6533013205282112, 0.6581032412965186, 0.6561824729891956, 0.6600240096038416]
epoch 2461
epoch 2462
epoch 2463
epoch 2464
epoch 2465
epoch 2466
epoch 2467
epoch 2468
epoch 2469
epoch 2470
accuracy: 0.643937575030012
test_accuracies: [0.6521008403361345, 0.6506602641056423, 0.6525810324129652, 0.6554621848739496, 0.65306122

epoch 2561
epoch 2562
epoch 2563
epoch 2564
epoch 2565
epoch 2566
epoch 2567
epoch 2568
epoch 2569
epoch 2570
accuracy: 0.658343337334934
test_accuracies: [0.658343337334934, 0.6566626650660263, 0.6669867947178871, 0.6559423769507804, 0.6549819927971188, 0.6559423769507804, 0.6410564225690276, 0.6444177671068427, 0.6535414165666267, 0.6549819927971188, 0.6557022809123649, 0.6482593037214885, 0.6540216086434574, 0.6554621848739496, 0.6487394957983194, 0.6533013205282112, 0.6581032412965186, 0.6561824729891956, 0.6600240096038416, 0.643937575030012, 0.64921968787515, 0.6557022809123649, 0.6554621848739496, 0.6564225690276111, 0.6528211284513805, 0.6559423769507804, 0.6523409363745498, 0.6480192076830733, 0.6501800720288116, 0.658343337334934]
epoch 2571
epoch 2572
epoch 2573
epoch 2574
epoch 2575
epoch 2576
epoch 2577
epoch 2578
epoch 2579
epoch 2580
accuracy: 0.657623049219688
test_accuracies: [0.6566626650660263, 0.6669867947178871, 0.6559423769507804, 0.6549819927971188, 0.65594237695

epoch 2671
epoch 2672
epoch 2673
epoch 2674
epoch 2675
epoch 2676
epoch 2677
epoch 2678
epoch 2679
epoch 2680
accuracy: 0.6552220888355342
test_accuracies: [0.6482593037214885, 0.6540216086434574, 0.6554621848739496, 0.6487394957983194, 0.6533013205282112, 0.6581032412965186, 0.6561824729891956, 0.6600240096038416, 0.643937575030012, 0.64921968787515, 0.6557022809123649, 0.6554621848739496, 0.6564225690276111, 0.6528211284513805, 0.6559423769507804, 0.6523409363745498, 0.6480192076830733, 0.6501800720288116, 0.658343337334934, 0.657623049219688, 0.6585834333733493, 0.6588235294117647, 0.6602641056422569, 0.65906362545018, 0.6540216086434574, 0.6533013205282112, 0.6561824729891956, 0.6523409363745498, 0.6494597839135654, 0.6552220888355342]
epoch 2681
epoch 2682
epoch 2683
epoch 2684
epoch 2685
epoch 2686
epoch 2687
epoch 2688
epoch 2689
epoch 2690
accuracy: 0.6518607442977191
test_accuracies: [0.6540216086434574, 0.6554621848739496, 0.6487394957983194, 0.6533013205282112, 0.65810324129

epoch 2781
epoch 2782
epoch 2783
epoch 2784
epoch 2785
epoch 2786
epoch 2787
epoch 2788
epoch 2789
epoch 2790
accuracy: 0.6602641056422569
test_accuracies: [0.6554621848739496, 0.6564225690276111, 0.6528211284513805, 0.6559423769507804, 0.6523409363745498, 0.6480192076830733, 0.6501800720288116, 0.658343337334934, 0.657623049219688, 0.6585834333733493, 0.6588235294117647, 0.6602641056422569, 0.65906362545018, 0.6540216086434574, 0.6533013205282112, 0.6561824729891956, 0.6523409363745498, 0.6494597839135654, 0.6552220888355342, 0.6518607442977191, 0.6612244897959183, 0.6602641056422569, 0.64921968787515, 0.64921968787515, 0.6547418967587035, 0.6629051620648259, 0.6540216086434574, 0.6645858343337334, 0.6607442977190876, 0.6602641056422569]
epoch 2791
epoch 2792
epoch 2793
epoch 2794
epoch 2795
epoch 2796
epoch 2797
epoch 2798
epoch 2799
epoch 2800
accuracy: 0.6549819927971188
test_accuracies: [0.6564225690276111, 0.6528211284513805, 0.6559423769507804, 0.6523409363745498, 0.648019207683

epoch 2891
epoch 2892
epoch 2893
epoch 2894
epoch 2895
epoch 2896
epoch 2897
epoch 2898
epoch 2899
epoch 2900
accuracy: 0.6595438175270107
test_accuracies: [0.6602641056422569, 0.65906362545018, 0.6540216086434574, 0.6533013205282112, 0.6561824729891956, 0.6523409363745498, 0.6494597839135654, 0.6552220888355342, 0.6518607442977191, 0.6612244897959183, 0.6602641056422569, 0.64921968787515, 0.64921968787515, 0.6547418967587035, 0.6629051620648259, 0.6540216086434574, 0.6645858343337334, 0.6607442977190876, 0.6602641056422569, 0.6549819927971188, 0.6571428571428571, 0.6434573829531813, 0.6578631452581032, 0.6417767106842738, 0.6619447779111645, 0.6535414165666267, 0.6595438175270107, 0.6605042016806723, 0.6559423769507804, 0.6595438175270107]
epoch 2901
epoch 2902
epoch 2903
epoch 2904
epoch 2905
epoch 2906
epoch 2907
epoch 2908
epoch 2909
epoch 2910
accuracy: 0.6561824729891956
test_accuracies: [0.65906362545018, 0.6540216086434574, 0.6533013205282112, 0.6561824729891956, 0.652340936374

In [13]:
print("Optimized rotation angles: {}".format(params))

training_time = time.time()

Optimized rotation angles: [ 1.14277742  2.89032394  0.78507158  0.82246409 -0.16620923  0.1504343
 -2.08231968  0.70290122  0.52019361 -0.26239755 -0.60214949  0.41877036
  0.88214876  1.03109232  0.72760926]


# Testing

In [14]:
accuracy = measure_accuracy(X_test, y_test, params)
print(accuracy)

test_time = time.time()

0.6490476190476191


In [15]:
print(f"pre-processing time: {preprocessing_time-initial_time}")
print(f"training time: {training_time - preprocessing_time}")
print(f"test time: {test_time - training_time}")
print(f"total time: {test_time - initial_time}")

pre-processing time: 77.05595469474792
training time: 34120.02218103409
test time: 240.83839988708496
total time: 34437.91653561592
