# Supervised Learning with Kernel Methods

Here we use a variational classifier to learn a nonlinear boundary using kernel methods.

The variational circuit rchitecture is specified by [Farhi and Neven (2018)](https://arxiv.org/abs/1802.06002). 
The kernel map is specified by [Havlicek et al (2018)](https://arxiv.org/abs/1804.11326)

In [1]:
import pennylane as qml
from pennylane import numpy as np
from pennylane.optimize import NesterovMomentumOptimizer

from scipy.stats import unitary_group

In [2]:
dev = qml.device('default.qubit', wires=2)

In [3]:
def U_phi(x):
    #print(x, np.shape(x))
    # x3 := (pi - x1)(pi - x2)
    x_0, x_1, x_2 = x[0], x[1], x[2]
    #print(x_0, x_1, x_2)
        
    qml.RZ( x_0 , wires=0)
    qml.RZ( x_1 , wires=1)
    
    qml.CNOT(wires=[0,1])
    qml.RZ(x_2,wires=1)
    qml.CNOT(wires=[0,1])

In [4]:
def layer(W): # 6 weights are specified at each layer
    
    # euler angles
    qml.Rot(W[0, 0], W[0, 1], W[0, 2], wires=0)
    qml.Rot(W[1, 0], W[1, 1], W[1, 2], wires=1)

    qml.CNOT(wires=[0, 1])

In [5]:
def featuremap(x):
    for i in range(2):
        qml.Hadamard(wires=0)
        qml.Hadamard(wires=1)
        U_phi(x)

In [6]:
@qml.qnode(dev)
def circuit1(weights, x):

    featuremap(x)

    for W in weights:
        layer(W)

    return qml.expval.PauliZ(wires=0)

@qml.qnode(dev)
def circuit2(weights, x):

    featuremap(x)

    for W in weights:
        layer(W)

    return qml.expval.PauliZ(wires=1)

In [7]:
def variational_classifier(var, x): # x is a keyword argument -> fixed (not trained)
    weights = var[0]
    bias = var[1]

    return circuit1(weights, x) * circuit2(weights, x) + bias

In [8]:
def square_loss(labels, predictions):

    loss = 0
    for l, p in zip(labels, predictions):
        loss = loss + (l - p) ** 2
    loss = loss / len(labels)

    return loss

In [18]:
def accuracy(labels, predictions):
    print(labels, predictions)
    loss = 0
    for l, p in zip(labels, predictions):
        if abs(l - p) < 1e-5:
            loss = loss + 1
    loss = loss / len(labels)

    return loss

In [10]:
def cost(var, X, Y):

    predictions = [variational_classifier(var, x) for x in X]
    if (len(Y) == num_data):
        print("[(pred, label), ...]: ", list(zip(predictions, Y)))
    return square_loss(Y, predictions) 

In [11]:
random_U = unitary_group.rvs(4)
random_U = random_U / (np.linalg.det(random_U) ** (1/4)) # so that det = 1


print("random unitary: ", random_U)
print("det", np.linalg.det(random_U))

random unitary:  [[-0.18162439-0.0564827 j -0.46921616+0.35182063j -0.18429372+0.71136095j
   0.15362355+0.23723818j]
 [ 0.16715375+0.52652749j  0.45677116+0.07710775j -0.26365744+0.07976064j
  -0.05430152+0.63357499j]
 [ 0.56985238-0.46167089j -0.06118629+0.56898051j -0.03623803-0.27996401j
  -0.07567228+0.22187009j]
 [-0.27638715-0.21068722j -0.22530634-0.25145063j  0.51402037-0.20053536j
  -0.01917148+0.67855121j]]
det (1-1.2836953722228375e-16j)


In [12]:
@qml.qnode(dev)
def data_label_1(x):
    #print(u)
    #print("label the following:", x)
    featuremap(x)
    qml.QubitUnitary(random_U, wires=[0,1])
    
    return qml.expval.PauliZ(wires=0)

@qml.qnode(dev)
def data_label_2(x):
    #print("label the following:", x)
    featuremap(x)
    qml.QubitUnitary(random_U, wires=[0,1])
    
    return qml.expval.PauliZ(wires=1)

In [13]:
thresh = 0.3

X = np.array([])
Y = np.array([])
ctr = 0 # num valid data pts
maxval = 0.0
minval = 0.0

np.random.seed(0)

while ctr < 40:
    x = np.random.rand(2) * 2 * np.pi
    x = np.append(x, (np.pi - x[0]) * (np.pi - x[1]))
    y_1 = data_label_1(x)
    y_2 = data_label_2(x)
    #print(y_1, y_2, y_1 * y_2)
    if (np.abs(y_1 * y_2) > maxval):
        maxval = y_1 * y_2
        print("new max separation: ", maxval)
    elif (y_1 * y_2 < minval):
        minval = y_1 * y_2
        print("new min separation: ", minval)
        
    if y_1 * y_2 > thresh:
        Y = np.append(Y, +1)
        X = np.append(X, x)
        ctr += 1
        print("+1")
    elif y_1 * y_2 < -1 * thresh:
        Y = np.append(Y, -1)
        X = np.append(X, x)
        ctr += 1
        print("-1")

new max separation:  -0.022712201818728222
new max separation:  0.05424512049119921
new max separation:  -0.11149464029858032
new max separation:  0.21272145481532936
new min separation:  -0.02488334739968215
new min separation:  -0.15068202989514268
new max separation:  -0.3994676946123235
-1
new max separation:  -0.006211618700467554
new max separation:  0.1422915430133009
new max separation:  -0.1995703597559318
new max separation:  -0.1220707091424859
new max separation:  0.10902943610614052
new max separation:  0.17724482877543116
new max separation:  -0.26515825998189013
new max separation:  -0.22340171561511674
new max separation:  0.10590256058931216
new max separation:  0.1218657623659504
new max separation:  -0.6053511176490141
-1
new max separation:  0.11505356882003209
new max separation:  0.8540874911941555
+1
new min separation:  -0.2510149117117334
new min separation:  -0.6461257785348852
-1
+1
+1
+1
-1
+1
+1
-1
-1
-1
+1
+1
-1
-1
+1
-1
+1
-1
new min separation:  -0.68408

In [14]:
X = X.reshape(-1, 3)
print(X)
print(Y)

[[ 4.88930306  5.46644755  4.0631731 ]
 [ 1.23516341  2.31676857  1.57246875]
 [ 6.13717092  3.80035648  1.97337861]
 [ 4.61936028  6.04560893  4.29146123]
 [ 5.1132425   2.49131905 -1.28211186]
 [ 5.53613466  3.65224517  1.22277891]
 [ 6.00725065  4.04630976  2.5926098 ]
 [ 4.48771044  6.27594084  4.21920183]
 [ 3.62683419  1.49472468 -0.79912874]
 [ 5.86983967  3.85766187  1.95361371]
 [ 0.15506102  0.4225419   8.12053109]
 [ 6.22248312  1.36280395 -5.48025318]
 [ 1.30357293  2.6683775   0.86977879]
 [ 3.55264688  1.15158117 -0.81800263]
 [ 4.55904385  0.07180084 -4.35128007]
 [ 1.78140145  2.38715147  1.02618426]
 [ 1.63000004  2.34873722  1.19847442]
 [ 5.98080373  3.61755124  1.35134689]
 [ 5.15703191  5.7104335   5.17734269]
 [ 6.21664498  0.41031843 -8.39881113]
 [ 1.4616385   2.18981178  1.59894823]
 [ 5.97501147  1.46662272 -4.74589133]
 [ 1.7117715   2.38168472  1.08653244]
 [ 5.27278076  1.49377595 -3.51180737]
 [ 0.95727285  2.62314426  1.1324571 ]
 [ 6.08083326  3.43617918

In [15]:
num_data = len(Y)
num_train = int(0.5 * num_data)

print(num_data, num_train)

index = np.random.permutation(range(num_data))
X_train = X[index[:num_train]]
Y_train = Y[index[:num_train]]

X_test = X[index[num_train:]]
Y_test = Y[index[num_train:]]

40 20


In [16]:
num_qubits = 2
num_layers = 6
var_init = (0.01 * np.random.randn(num_layers, num_qubits, 3), 0.0)

In [17]:
opt = NesterovMomentumOptimizer(0.01)
batch_size = 5

# train the variational classifier
var = var_init
for it in range(200):

    # Update the weights by one optimizer step
    batch_index = np.random.randint(0, num_train, (batch_size, ))
    X_train_batch = X_train[batch_index]
    Y_train_batch = Y_train[batch_index]
    var = opt.step(lambda v: cost(v, X_train_batch, Y_train_batch), var)

    # Compute predictions on train and validation set
    predictions_train = [np.sign(variational_classifier(var, x)) for f in X_train]
    predictions_test = [np.sign(variational_classifier(var, x)) for f in X_test]
    
    # Compute accuracy on train and validation set
    acc_train = accuracy(Y_train, predictions_train)
    acc_test = accuracy(Y_test, predictions_test)

    print("Iter: {:5d} | Cost: {:0.7f} | Acc train: {:0.7f} | Acc validation: {:0.7f} "
          "".format(it+1, cost(var, X, Y), acc_train, acc_test))

[(pred, label), ...]:  [(0.06633853010549862, -1.0), (0.01373184079976447, -1.0), (-0.10471900942317844, 1.0), (0.00028403747678261773, -1.0), (-0.014165984076751657, 1.0), (-0.05906532345589827, 1.0), (-0.4145266265758054, 1.0), (-0.03660467852163929, -1.0), (-0.01951335843232637, 1.0), (-0.0818233942146746, 1.0), (0.07914809245664634, -1.0), (0.10074695329972971, -1.0), (-0.08051838464324262, -1.0), (-0.15920488967541088, 1.0), (-0.00368363134438833, 1.0), (0.062369524862901426, -1.0), (0.023368584134035204, -1.0), (-0.02700314879217639, 1.0), (0.078334234435845, -1.0), (0.2569602118805494, 1.0), (0.013613109428815528, -1.0), (0.01388880693652942, -1.0), (0.04206677894823284, -1.0), (0.054661388679618644, 1.0), (-0.07722697822167693, -1.0), (-0.3808721306418403, 1.0), (0.11132899080690092, -1.0), (0.001822148261042232, -1.0), (-0.11154177286507022, 1.0), (-0.020228281555113366, -1.0), (0.001588234616479254, 1.0), (0.00616342780524481, -1.0), (-0.01016337382395937, -1.0), (-0.06292352

[(pred, label), ...]:  [(0.11275483172281338, -1.0), (0.10219891737621262, -1.0), (0.09902751871984235, 1.0), (0.0905275509914577, -1.0), (0.10266834399733718, 1.0), (-0.058095845045733646, 1.0), (0.01296421876150354, 1.0), (0.07604450319264489, -1.0), (0.23075473322857054, 1.0), (0.10968732844989582, 1.0), (0.1689292486342269, -1.0), (0.24226620080768888, -1.0), (0.0027648403364185364, -1.0), (0.15414264622803023, 1.0), (0.07224849246610957, 1.0), (0.11465281954694949, -1.0), (0.08421660068870873, -1.0), (0.010634612671276453, 1.0), (0.1756615065841452, -1.0), (0.2681785324713483, 1.0), (0.09492035441191238, -1.0), (0.10722833282568811, -1.0), (0.09719211850842645, -1.0), (0.25768427244115094, 1.0), (0.02883621466788855, -1.0), (-0.3545789950073808, 1.0), (0.16766452941027107, -1.0), (0.10671702696950443, -1.0), (0.24749544059816211, 1.0), (0.04740514884427624, -1.0), (0.10253386592932377, 1.0), (0.0941643223767675, -1.0), (0.027140888792732007, -1.0), (0.09126034069770834, -1.0), (0.

[(pred, label), ...]:  [(0.0740001374959797, -1.0), (0.061821490795226273, -1.0), (0.1858563936038244, 1.0), (0.05000872439213566, -1.0), (0.08182201966331373, 1.0), (-0.1229053016172864, 1.0), (0.30525037848318926, 1.0), (0.041104880992689545, -1.0), (0.33271036341266097, 1.0), (0.17026735486974104, 1.0), (0.144538721770224, -1.0), (0.14804391293913957, -1.0), (-0.07547021917508505, -1.0), (0.3285159123256313, 1.0), (0.04189743404387615, 1.0), (0.08028444005907842, -1.0), (0.04085099886805014, -1.0), (-0.03856360314755619, 1.0), (0.09007784888604682, -1.0), (0.06768549004597685, 1.0), (0.05089369504852192, -1.0), (0.08425067749721392, -1.0), (0.05821769797061507, -1.0), (0.3731179325607689, 1.0), (-0.05642012690431884, -1.0), (-0.39993922361578693, 1.0), (0.14879103596256127, -1.0), (0.08435923122053625, -1.0), (0.43613689828947017, 1.0), (-0.019554887385143216, -1.0), (0.0882738904712093, 1.0), (0.03988584465630541, -1.0), (-0.051758252192692325, -1.0), (0.14038251082684083, -1.0), (

[(pred, label), ...]:  [(-0.19445649899432715, -1.0), (-0.14543541753440759, -1.0), (0.09833158261414122, 1.0), (-0.20662036056907807, -1.0), (-0.15761748041137855, 1.0), (-0.33869176533203776, 1.0), (0.2977833819202687, 1.0), (-0.20895568804444958, -1.0), (0.20152451375457586, 1.0), (0.05409050977814327, 1.0), (-0.037683946720350414, -1.0), (-0.21721136226171406, -1.0), (-0.39872460250191083, -1.0), (0.2503147588078872, 1.0), (-0.1198120188276757, 1.0), (-0.17866929541744495, -1.0), (-0.2043899948917725, -1.0), (-0.21951468910706745, 1.0), (-0.13627459167780862, -1.0), (-0.1655700756244039, 1.0), (-0.15772796585675147, -1.0), (-0.06069529685821286, -1.0), (-0.19782313829606496, -1.0), (0.3035688944280203, 1.0), (-0.33315823781849974, -1.0), (-0.5349843698951556, 1.0), (-0.09999889665641855, -1.0), (-0.09800817914372487, -1.0), (0.3293380011253641, 1.0), (-0.32061571746259754, -1.0), (-0.06501088785816199, 1.0), (-0.18881066507389593, -1.0), (-0.39297952591482493, -1.0), (0.01810244638

[(pred, label), ...]:  [(-0.4565760277256458, -1.0), (-0.25402009195781894, -1.0), (0.1236855475635193, 1.0), (-0.4679362354618974, -1.0), (-0.48487061827064365, 1.0), (-0.4486439514684682, 1.0), (0.3535362435077152, 1.0), (-0.4775491902458772, -1.0), (0.09551374211044644, 1.0), (0.036258496246887734, 1.0), (-0.1493772223970463, -1.0), (-0.637511810821365, -1.0), (-0.661376972195199, -1.0), (0.2237455453941341, 1.0), (-0.17072987970647213, 1.0), (-0.4379523589851587, -1.0), (-0.4258257394152285, -1.0), (-0.272148166011951, 1.0), (-0.2952869747261522, -1.0), (-0.12048852807811258, 1.0), (-0.2967462060763427, -1.0), (-0.12303025950810872, -1.0), (-0.4460029962531674, -1.0), (0.30918886110604915, 1.0), (-0.4721825414122064, -1.0), (-0.45667229739055226, 1.0), (-0.3556400521343623, -1.0), (-0.18701005557972064, -1.0), (0.22239637799304868, 1.0), (-0.5834340248393384, -1.0), (-0.08344173499881465, 1.0), (-0.3024564859217813, -1.0), (-0.685181401890104, -1.0), (-0.009795257532745183, -1.0), 

[(pred, label), ...]:  [(-0.6634053035570584, -1.0), (-0.45110954708180284, -1.0), (0.38140649407121374, 1.0), (-0.6677733837690577, -1.0), (-0.7480522830775356, 1.0), (-0.2467324190508265, 1.0), (0.5102609442254723, 1.0), (-0.6094004488976514, -1.0), (0.10711770450470436, 1.0), (0.19709766340347057, 1.0), (-0.23873128925552173, -1.0), (-0.8232926967520969, -1.0), (-0.7339853735582772, -1.0), (0.3508670338938376, 1.0), (-0.14894323188991923, 1.0), (-0.6734653051234541, -1.0), (-0.6932721221791185, -1.0), (0.010699848193842995, 1.0), (-0.5222634431505481, -1.0), (0.20282168393526756, 1.0), (-0.5594028042913991, -1.0), (-0.22601754510505084, -1.0), (-0.6971633840526328, -1.0), (0.2263624148403885, 1.0), (-0.44641620023854645, -1.0), (-0.08398839122201993, 1.0), (-0.5540367896520937, -1.0), (-0.41758660595014324, -1.0), (0.24961164458393986, 1.0), (-0.7419916271556742, -1.0), (0.08929016727966836, 1.0), (-0.41644656704863986, -1.0), (-0.7443272756383267, -1.0), (-0.2265996904710832, -1.0)

[(pred, label), ...]:  [(-0.4971759096377099, -1.0), (-0.5933182777208338, -1.0), (0.7805237345417887, 1.0), (-0.6021678277550988, -1.0), (-0.6047370629379714, 1.0), (0.08136878610330234, 1.0), (0.711354725028133, 1.0), (-0.4636741429022591, -1.0), (0.3715174202778332, 1.0), (0.5493988251671338, 1.0), (-0.12166917946768152, -1.0), (-0.5683191427932373, -1.0), (-0.56128829079143, -1.0), (0.6418289055608949, 1.0), (0.08087755781224865, 1.0), (-0.5360022379489888, -1.0), (-0.6677445722961667, -1.0), (0.35209127913050997, 1.0), (-0.5538621024543469, -1.0), (0.47401075625132966, 1.0), (-0.6664399342761045, -1.0), (-0.28399266233109893, -1.0), (-0.6043142565530442, -1.0), (0.2745810208190022, 1.0), (-0.35872909420565124, -1.0), (0.15585769791810236, 1.0), (-0.3466175892009393, -1.0), (-0.5852499322838648, -1.0), (0.5180604934113462, 1.0), (-0.6325184589664985, -1.0), (0.4723484581208845, 1.0), (-0.48835847605243204, -1.0), (-0.49180582718414173, -1.0), (-0.430586145966442, -1.0), (-0.5568651

[(pred, label), ...]:  [(-0.23853424183317518, -1.0), (-0.5700815538503993, -1.0), (0.9763883502217688, 1.0), (-0.497091076249883, -1.0), (-0.2539984086951187, 1.0), (0.3195722138189453, 1.0), (0.7608479860291398, 1.0), (-0.31203150525541434, -1.0), (0.6690946361664606, 1.0), (0.7887414590013575, 1.0), (0.14541899595285887, -1.0), (-0.28425648739797205, -1.0), (-0.4099539663729822, -1.0), (0.8470314550535525, 1.0), (0.5713221668462922, 1.0), (-0.29057495762937235, -1.0), (-0.47516590495805766, -1.0), (0.5686725612987295, 1.0), (-0.3113046480727777, -1.0), (0.6934032284831065, 1.0), (-0.5246048996392055, -1.0), (-0.2794379781070063, -1.0), (-0.3817460340943155, -1.0), (0.4573244379555078, 1.0), (-0.2787367013192841, -1.0), (0.3460989628222108, 1.0), (-0.058011595787281434, -1.0), (-0.5106407029302652, -1.0), (0.7801845431470916, 1.0), (-0.49165408719693837, -1.0), (0.792988666990004, 1.0), (-0.47656977457169436, -1.0), (-0.27688824575220783, -1.0), (-0.38617951539245887, -1.0), (-0.5375

[(pred, label), ...]:  [(-0.26149321788147, -1.0), (-0.5897692251022282, -1.0), (0.9187726353792809, 1.0), (-0.5582022401641771, -1.0), (-0.11301829828637239, 1.0), (0.33924024721723034, 1.0), (0.6423942031060998, 1.0), (-0.33878550181563777, -1.0), (0.6823135796187794, 1.0), (0.7451261487544902, 1.0), (0.2272242622196518, -1.0), (-0.23176234304995164, -1.0), (-0.4760735514285865, -1.0), (0.8086616226696903, 1.0), (0.7770660475511588, 1.0), (-0.31434125503181803, -1.0), (-0.4894547342146089, -1.0), (0.58485544796318, 1.0), (-0.19432950469243326, -1.0), (0.6870614650083744, 1.0), (-0.5004063561042535, -1.0), (-0.27890737883674827, -1.0), (-0.4052993739051218, -1.0), (0.47833211622886984, 1.0), (-0.3116945013321855, -1.0), (0.4012633363536213, 1.0), (-0.06296365279785388, -1.0), (-0.49368888119744453, -1.0), (0.7718882324348395, 1.0), (-0.5597584398252831, -1.0), (0.8999788055730529, 1.0), (-0.5163259291567077, -1.0), (-0.34824801339965855, -1.0), (-0.37600647927529424, -1.0), (-0.512151

[(pred, label), ...]:  [(-0.4303311793913025, -1.0), (-0.6586686132189149, -1.0), (0.8566565985046479, 1.0), (-0.6534946838488254, -1.0), (-0.16568058862911542, 1.0), (0.2391822182037991, 1.0), (0.5710872087845821, 1.0), (-0.4247264310137009, -1.0), (0.5914774403060892, 1.0), (0.6595517017189505, 1.0), (0.15974532161218175, -1.0), (-0.3691551515009317, -1.0), (-0.5996883257050782, -1.0), (0.7406709280877978, 1.0), (0.6647728155464987, 1.0), (-0.47730993164146907, -1.0), (-0.6271278276719455, -1.0), (0.5216387692835937, 1.0), (-0.2676822965982987, -1.0), (0.592929506677322, 1.0), (-0.6010869657995599, -1.0), (-0.2900120297715695, -1.0), (-0.5592486201593772, -1.0), (0.41919978789546414, 1.0), (-0.38939256258813504, -1.0), (0.3478499991016766, 1.0), (-0.22891021597194414, -1.0), (-0.5700995991621556, -1.0), (0.6947028896555792, 1.0), (-0.6897597115964412, -1.0), (0.8426771429523419, 1.0), (-0.58109122543555, -1.0), (-0.5202000870823261, -1.0), (-0.4390458926885741, -1.0), (-0.53942735933

[(pred, label), ...]:  [(-0.482002160670176, -1.0), (-0.6882274198211702, -1.0), (0.8895293967113023, 1.0), (-0.6488732306960348, -1.0), (-0.24539855798249843, 1.0), (0.2147016749240689, 1.0), (0.605641036964671, 1.0), (-0.4216894598088333, -1.0), (0.5646589724844902, 1.0), (0.668201263655757, 1.0), (0.11444510604323083, -1.0), (-0.4138816958760979, -1.0), (-0.5886954118444566, -1.0), (0.7543781924491475, 1.0), (0.561101835217048, 1.0), (-0.5291560832558561, -1.0), (-0.6741822438881657, -1.0), (0.5333124159427891, 1.0), (-0.37071500234282284, -1.0), (0.576297896296514, 1.0), (-0.6618000128894379, -1.0), (-0.3086117713751556, -1.0), (-0.6068504048237933, -1.0), (0.4272449664136111, 1.0), (-0.3946344125639658, -1.0), (0.34742371588826687, 1.0), (-0.2952094504092699, -1.0), (-0.6251578568007498, -1.0), (0.681499522213033, 1.0), (-0.6881464971782577, -1.0), (0.7795279367372953, 1.0), (-0.5932428972930268, -1.0), (-0.5281252175335345, -1.0), (-0.492727365936677, -1.0), (-0.6042307134976421,

[(pred, label), ...]:  [(-0.3974682030461014, -1.0), (-0.6375715016110844, -1.0), (0.9925151724775649, 1.0), (-0.6216904170098656, -1.0), (-0.22706700678562045, 1.0), (0.28906941283931925, 1.0), (0.7135157113318732, 1.0), (-0.42595625423370076, -1.0), (0.6207771706628687, 1.0), (0.7781493113634851, 1.0), (0.1314136420452655, -1.0), (-0.4289367119522731, -1.0), (-0.5525629956782427, -1.0), (0.8395605953586377, 1.0), (0.5439551744583478, 1.0), (-0.44679028868597914, -1.0), (-0.5975732764972225, -1.0), (0.6229871868410874, 1.0), (-0.3418104221225028, -1.0), (0.6124391055804366, 1.0), (-0.5917766078060082, -1.0), (-0.3289684132235454, -1.0), (-0.5262535930544993, -1.0), (0.5396704006733245, 1.0), (-0.4063024269678145, -1.0), (0.4100861223910716, 1.0), (-0.21486171285661534, -1.0), (-0.5567230019665717, -1.0), (0.7495230256956085, 1.0), (-0.63134132590656, -1.0), (0.7098116164441453, 1.0), (-0.5700178884337116, -1.0), (-0.4555947088247325, -1.0), (-0.4178130223378055, -1.0), (-0.58725630351

[(pred, label), ...]:  [(-0.32775023737484377, -1.0), (-0.6257055321486847, -1.0), (1.075808997735272, 1.0), (-0.6198302340047119, -1.0), (-0.10685194227127431, 1.0), (0.39763069096875814, 1.0), (0.764608719899145, 1.0), (-0.42321494606645693, -1.0), (0.6661358758415274, 1.0), (0.8702379081756648, 1.0), (0.20086674639823826, -1.0), (-0.3172854214805497, -1.0), (-0.561431217467235, -1.0), (0.8889530393964323, 1.0), (0.6010275494688793, 1.0), (-0.3776079837724712, -1.0), (-0.5381693152443223, -1.0), (0.7363601883714926, 1.0), (-0.22855745869995514, -1.0), (0.6714762846347018, 1.0), (-0.530276652088853, -1.0), (-0.29715110687850843, -1.0), (-0.4624182814714184, -1.0), (0.5900394392533621, 1.0), (-0.4346765506423492, -1.0), (0.5148534248103083, 1.0), (-0.13400883251035378, -1.0), (-0.512451437472019, -1.0), (0.7855424421145953, 1.0), (-0.6228122650798267, -1.0), (0.7610007778144553, 1.0), (-0.586764644445873, -1.0), (-0.4262203511506918, -1.0), (-0.3855656308022515, -1.0), (-0.538487665697

[(pred, label), ...]:  [(-0.23553704303387624, -1.0), (-0.5697740629990159, -1.0), (1.1440750379452505, 1.0), (-0.5932807836188221, -1.0), (0.0658643847463509, 1.0), (0.5153788458904213, 1.0), (0.7908647760407022, 1.0), (-0.39286077629578653, -1.0), (0.732911171609655, 1.0), (0.9508619225512571, 1.0), (0.29603228758573175, -1.0), (-0.18409830493292806, -1.0), (-0.5484816741199422, -1.0), (0.9396752879200477, 1.0), (0.7161693900871413, 1.0), (-0.28446032376804953, -1.0), (-0.44672242241897064, -1.0), (0.8510539643351103, 1.0), (-0.06265045007194148, -1.0), (0.7541495188209253, 1.0), (-0.42306453367009456, -1.0), (-0.2276092690830802, -1.0), (-0.37189153288765453, -1.0), (0.6379518211557288, 1.0), (-0.42769328555626374, -1.0), (0.6328798422603923, 1.0), (-0.03202755521080852, -1.0), (-0.4195232298814604, -1.0), (0.8304390830981757, 1.0), (-0.5928852581751806, -1.0), (0.874105173369062, 1.0), (-0.5641975402712757, -1.0), (-0.38143189081636975, -1.0), (-0.3064779278147294, -1.0), (-0.42972

[(pred, label), ...]:  [(-0.3070164069274566, -1.0), (-0.6208944303378928, -1.0), (1.1101009629156375, 1.0), (-0.6085761220098962, -1.0), (0.07450605352643175, 1.0), (0.5016436747517222, 1.0), (0.7200306778616212, 1.0), (-0.3859334930421229, -1.0), (0.6451953345752233, 1.0), (0.9208537853657488, 1.0), (0.2600307993951424, -1.0), (-0.22938493222607592, -1.0), (-0.5619222010608463, -1.0), (0.8496449933741008, 1.0), (0.6554152927581186, 1.0), (-0.35611023140580267, -1.0), (-0.5155368440526567, -1.0), (0.8605287706097481, 1.0), (-0.1260712168904634, -1.0), (0.7647048265783136, 1.0), (-0.4940949435050811, -1.0), (-0.24345395727187824, -1.0), (-0.4415885483169132, -1.0), (0.5841474869135945, 1.0), (-0.42731740594420364, -1.0), (0.6411570193125116, 1.0), (-0.10864830047488433, -1.0), (-0.48673154287469744, -1.0), (0.7265377053171176, 1.0), (-0.6254220371781773, -1.0), (0.8872191221304464, 1.0), (-0.5921729003275519, -1.0), (-0.4210242944836853, -1.0), (-0.3737734118268544, -1.0), (-0.47894519

[(pred, label), ...]:  [(-0.6242060237453313, -1.0), (-0.6631437567648701, -1.0), (0.8866354967687755, 1.0), (-0.5983153626581974, -1.0), (-0.0713255077724344, 1.0), (0.2614139929499544, 1.0), (0.48839669240596173, 1.0), (-0.3300215839067397, -1.0), (0.40836747420774366, 1.0), (0.6633821398583796, 1.0), (0.09372915278120121, -1.0), (-0.3835440828716414, -1.0), (-0.5665060852561017, -1.0), (0.5935661163119192, 1.0), (0.41372409338547256, 1.0), (-0.6640223015604998, -1.0), (-0.7652283582498189, -1.0), (0.7019551968221596, 1.0), (-0.3916452865910253, -1.0), (0.7087627981292093, 1.0), (-0.7137652790279956, -1.0), (-0.20256452676448095, -1.0), (-0.7215616100772659, -1.0), (0.3852469758642493, 1.0), (-0.3078048876739037, -1.0), (0.5512219720875695, 1.0), (-0.47037328382746696, -1.0), (-0.6404607971845728, -1.0), (0.4620364770916181, 1.0), (-0.6962875331526872, -1.0), (0.8968733868968524, 1.0), (-0.5286241684361737, -1.0), (-0.5916197907843801, -1.0), (-0.5045979818399653, -1.0), (-0.55671419

[(pred, label), ...]:  [(-0.6479372679721896, -1.0), (-0.7962558919242438, -1.0), (0.845378520340502, 1.0), (-0.729665705527393, -1.0), (-0.16896512503144098, 1.0), (0.21831127812233977, 1.0), (0.4450307390605196, 1.0), (-0.48519272536862135, -1.0), (0.342772177065512, 1.0), (0.6371419199181328, 1.0), (0.0014138461404297822, -1.0), (-0.5204663766283927, -1.0), (-0.6852499002069654, -1.0), (0.5407395075565116, 1.0), (0.3170710272291526, 1.0), (-0.6927441275793718, -1.0), (-0.8213779149568712, -1.0), (0.6443894768621364, 1.0), (-0.4847357294681579, -1.0), (0.6077877375262252, 1.0), (-0.7956307463580122, -1.0), (-0.3741217736920627, -1.0), (-0.7618249530165098, -1.0), (0.35309144747273213, 1.0), (-0.49016020586396347, -1.0), (0.4558504630747765, 1.0), (-0.48201352881053394, -1.0), (-0.743160851297396, -1.0), (0.4086608170591697, 1.0), (-0.7930991052609813, -1.0), (0.7089294141120895, 1.0), (-0.692456821580423, -1.0), (-0.6469422869836682, -1.0), (-0.6085629875091253, -1.0), (-0.7002394719

[(pred, label), ...]:  [(-0.47124295534503596, -1.0), (-0.7610134607447157, -1.0), (0.9354745563011561, 1.0), (-0.7068505544225756, -1.0), (-0.11231424154045645, 1.0), (0.3646828991506015, 1.0), (0.5214282134125058, 1.0), (-0.5029757560448518, -1.0), (0.4291598833773996, 1.0), (0.7688376997424734, 1.0), (0.038830754549267044, -1.0), (-0.5157966991166673, -1.0), (-0.6348586217013443, -1.0), (0.6285576726801333, 1.0), (0.3835862445425208, 1.0), (-0.5235302176877628, -1.0), (-0.6828621473058307, -1.0), (0.7444301042313733, 1.0), (-0.4133543142040484, -1.0), (0.6408043953079563, 1.0), (-0.6871201835953019, -1.0), (-0.45840875896806316, -1.0), (-0.6051508876284466, -1.0), (0.4692092783906004, 1.0), (-0.5363861872111461, -1.0), (0.5038469458624217, 1.0), (-0.2978042825691177, -1.0), (-0.6630008922765619, -1.0), (0.4936164659857981, 1.0), (-0.711964133731257, -1.0), (0.6358429013448746, 1.0), (-0.7046048595557294, -1.0), (-0.5128808399156081, -1.0), (-0.5321170420459339, -1.0), (-0.7068168542

[(pred, label), ...]:  [(-0.3527941069215691, -1.0), (-0.6288450725936137, -1.0), (1.0373165926959163, 1.0), (-0.5986301704507361, -1.0), (0.05381725044955726, 1.0), (0.5188719994742971, 1.0), (0.6060351010345217, 1.0), (-0.3920165440944029, -1.0), (0.5573600987880941, 1.0), (0.8822053015237528, 1.0), (0.16933526889793304, -1.0), (-0.3670637267661512, -1.0), (-0.5315579285032798, -1.0), (0.7374851244466816, 1.0), (0.5218603757741807, 1.0), (-0.4047192473061417, -1.0), (-0.5591095189241794, -1.0), (0.8955722071206509, 1.0), (-0.2546245157122784, -1.0), (0.7757297948448735, 1.0), (-0.5476272779949448, -1.0), (-0.3355708839084506, -1.0), (-0.4852737640725474, -1.0), (0.6050857748716478, 1.0), (-0.42151886251983683, -1.0), (0.6695007936879454, 1.0), (-0.1773571896769972, -1.0), (-0.5200950816193245, -1.0), (0.6016396034196458, 1.0), (-0.6052520882273564, -1.0), (0.7782953108053124, 1.0), (-0.5808725485542957, -1.0), (-0.40786470031431143, -1.0), (-0.38839647378608133, -1.0), (-0.5629550440

[(pred, label), ...]:  [(-0.3811336554956481, -1.0), (-0.595725710179724, -1.0), (1.0595646186414813, 1.0), (-0.5675172434350351, -1.0), (0.1221134187395945, 1.0), (0.5195412431473772, 1.0), (0.6293817833559636, 1.0), (-0.332566720168695, -1.0), (0.5888030172219418, 1.0), (0.884011943405542, 1.0), (0.2331241446672773, -1.0), (-0.2717691647409723, -1.0), (-0.5245964358634583, -1.0), (0.7648488435045459, 1.0), (0.5766464995289551, 1.0), (-0.42809569158796495, -1.0), (-0.5673709127027446, -1.0), (0.9211760805336231, 1.0), (-0.18835259181962427, -1.0), (0.82088819098347, 1.0), (-0.5323385400069689, -1.0), (-0.2166623021213503, -1.0), (-0.5036253048711989, -1.0), (0.6186277071057815, 1.0), (-0.3592355081047714, -1.0), (0.7277927032986132, 1.0), (-0.2010693708622363, -1.0), (-0.49548881472346956, -1.0), (0.6315498561196875, 1.0), (-0.6094851662685399, -1.0), (0.9021831120863713, 1.0), (-0.5340464651501216, -1.0), (-0.4382913757133801, -1.0), (-0.36504258612569623, -1.0), (-0.4826740533906111

[(pred, label), ...]:  [(-0.3035531678466066, -1.0), (-0.6199914082704383, -1.0), (1.1204976536747435, 1.0), (-0.5855548261275654, -1.0), (0.15117193709912025, 1.0), (0.5848077664966014, 1.0), (0.6946119467813837, 1.0), (-0.36572913839006504, -1.0), (0.6130599375566778, 1.0), (0.9553069054391359, 1.0), (0.26077809340560276, -1.0), (-0.25141109464396194, -1.0), (-0.5426009908603969, -1.0), (0.8102864808232512, 1.0), (0.5732788743017472, 1.0), (-0.3523029471569321, -1.0), (-0.5108327655578944, -1.0), (0.9550454099746308, 1.0), (-0.14416806359917578, -1.0), (0.8307797072459717, 1.0), (-0.49569014376963627, -1.0), (-0.23622625378246143, -1.0), (-0.43602217230348683, -1.0), (0.6343884527415695, 1.0), (-0.43527856908954055, -1.0), (0.7261234618572361, 1.0), (-0.11442274211196618, -1.0), (-0.4863027400325759, -1.0), (0.6669724739399965, 1.0), (-0.6044712093707438, -1.0), (0.8478599234901995, 1.0), (-0.593752153268694, -1.0), (-0.40128058374903763, -1.0), (-0.3668366698741332, -1.0), (-0.48908

[(pred, label), ...]:  [(-0.35429110770242533, -1.0), (-0.7517232609866779, -1.0), (1.0490619377376325, 1.0), (-0.6730124503766076, -1.0), (0.0674778478995906, 1.0), (0.5169796154088149, 1.0), (0.6237944732313754, 1.0), (-0.45196825945781705, -1.0), (0.508715304087154, 1.0), (0.8871432916607295, 1.0), (0.1709388580001116, -1.0), (-0.30215256297157145, -1.0), (-0.6257051955648116, -1.0), (0.7174873129523511, 1.0), (0.4788166427062188, 1.0), (-0.4047510936344908, -1.0), (-0.5771033271076713, -1.0), (0.8736529228648308, 1.0), (-0.23945371526211412, -1.0), (0.7230908799426425, 1.0), (-0.5906355145768878, -1.0), (-0.34294580700734545, -1.0), (-0.4929820204537242, -1.0), (0.5229918098826714, 1.0), (-0.5677539073201185, -1.0), (0.6335425898328314, 1.0), (-0.16579060900560988, -1.0), (-0.6086615239026543, -1.0), (0.5750350525981102, 1.0), (-0.6812355313072194, -1.0), (0.7195295052530826, 1.0), (-0.7368795974501902, -1.0), (-0.4580523926282025, -1.0), (-0.5079245740655477, -1.0), (-0.6218584970

[(pred, label), ...]:  [(-0.22874130346641725, -1.0), (-0.704928813302192, -1.0), (1.0629166047102108, 1.0), (-0.656400937526718, -1.0), (0.1388111079503473, 1.0), (0.5587772884188824, 1.0), (0.6382313209019257, 1.0), (-0.45717855916461225, -1.0), (0.586469822289643, 1.0), (0.9234975189660133, 1.0), (0.2018654948145736, -1.0), (-0.2718227404264029, -1.0), (-0.5880402985157824, -1.0), (0.7698137300009081, 1.0), (0.619436811012503, 1.0), (-0.28409356877832626, -1.0), (-0.4658223230797308, -1.0), (0.8866081926522716, 1.0), (-0.14283760881868707, -1.0), (0.7128261472008653, 1.0), (-0.4837573567780097, -1.0), (-0.39348148732966093, -1.0), (-0.37669217247958986, -1.0), (0.5784298651025717, 1.0), (-0.5904478405901366, -1.0), (0.6340090854891418, 1.0), (-0.03941658539373283, -1.0), (-0.5226266638778863, -1.0), (0.6529224344715446, 1.0), (-0.6226253155912856, -1.0), (0.7071626702770981, 1.0), (-0.7332673845142867, -1.0), (-0.36712737956402186, -1.0), (-0.4346810015937823, -1.0), (-0.59222916933

[(pred, label), ...]:  [(-0.2595693672300432, -1.0), (-0.6406485600650824, -1.0), (1.0822500928101149, 1.0), (-0.5939927295671786, -1.0), (0.19597174655645036, 1.0), (0.5677082615324724, 1.0), (0.6471396276194383, 1.0), (-0.37065410800160736, -1.0), (0.6245001263663555, 1.0), (0.9273177600728949, 1.0), (0.2630478417572021, -1.0), (-0.22794966216410417, -1.0), (-0.5390434455083798, -1.0), (0.7935987455199611, 1.0), (0.6751360984142009, 1.0), (-0.31157623267724444, -1.0), (-0.479152898759198, -1.0), (0.9344473343387203, 1.0), (-0.11003300897314028, -1.0), (0.7900890673322586, 1.0), (-0.47424332206434305, -1.0), (-0.28411930564464516, -1.0), (-0.3991479623378993, -1.0), (0.6430070237257621, 1.0), (-0.46946970890849893, -1.0), (0.7134249171948517, 1.0), (-0.06717228762696409, -1.0), (-0.48590068162776157, -1.0), (0.6793839492111768, 1.0), (-0.5964747407625512, -1.0), (0.8448548596335189, 1.0), (-0.6328979951787775, -1.0), (-0.3726416543302423, -1.0), (-0.3822341707380257, -1.0), (-0.512184

[(pred, label), ...]:  [(-0.32747979814796035, -1.0), (-0.6005076804200029, -1.0), (1.0759004386323094, 1.0), (-0.5476016317679359, -1.0), (0.19512945152590783, 1.0), (0.5362870003666168, 1.0), (0.6478083039876845, 1.0), (-0.3128533001900011, -1.0), (0.6214174436095756, 1.0), (0.9052219386283846, 1.0), (0.2672066532025025, -1.0), (-0.26074108559205345, -1.0), (-0.5048205999315167, -1.0), (0.7926336446284485, 1.0), (0.6310193258659805, 1.0), (-0.3754434474098679, -1.0), (-0.526176773261303, -1.0), (0.9311511512293746, 1.0), (-0.14985935447460014, -1.0), (0.8365656986405501, 1.0), (-0.5058191546617381, -1.0), (-0.2093941277262793, -1.0), (-0.4555095988468204, -1.0), (0.6564168363698529, 1.0), (-0.3751031306503311, -1.0), (0.7276005008066069, 1.0), (-0.14217684740483028, -1.0), (-0.4846681325036739, -1.0), (0.6698126533668184, 1.0), (-0.5841198867854727, -1.0), (0.8962024696948135, 1.0), (-0.5521793465537963, -1.0), (-0.3972161451936863, -1.0), (-0.3594076744335395, -1.0), (-0.47406241245

[(pred, label), ...]:  [(-0.30969383676154566, -1.0), (-0.6430840678163467, -1.0), (1.083385885327104, 1.0), (-0.5616036353651612, -1.0), (0.1839242685904223, 1.0), (0.5428005189652785, 1.0), (0.6539035671961346, 1.0), (-0.33417290923667536, -1.0), (0.6086019271332657, 1.0), (0.9165296562266956, 1.0), (0.24205666906591733, -1.0), (-0.2854673164269665, -1.0), (-0.5094668104386157, -1.0), (0.7916812645867706, 1.0), (0.6074538656126036, 1.0), (-0.36013855987970517, -1.0), (-0.5246791135131346, -1.0), (0.9220375761666502, 1.0), (-0.18128002010022004, -1.0), (0.8224456139607472, 1.0), (-0.527941035023866, -1.0), (-0.25829379672914105, -1.0), (-0.44485720149453983, -1.0), (0.6227164852287104, 1.0), (-0.41901627435187255, -1.0), (0.6973175898229863, 1.0), (-0.12552678529444616, -1.0), (-0.5220297517777399, -1.0), (0.6648434013794345, 1.0), (-0.5844296257672097, -1.0), (0.8374707890415303, 1.0), (-0.600613614035229, -1.0), (-0.37758579587600577, -1.0), (-0.4038761619997137, -1.0), (-0.52353248

[(pred, label), ...]:  [(-0.32507753359256486, -1.0), (-0.6512305681910401, -1.0), (1.0577507340534233, 1.0), (-0.6108920698369946, -1.0), (0.19196008541899137, 1.0), (0.5243726421747842, 1.0), (0.6342494264938564, 1.0), (-0.39429933485682456, -1.0), (0.6136292962747434, 1.0), (0.896379762085018, 1.0), (0.21543066100120717, -1.0), (-0.3367434290535788, -1.0), (-0.5600650379886554, -1.0), (0.794268793327154, 1.0), (0.6081056316732782, 1.0), (-0.37578281387131507, -1.0), (-0.5371476924696571, -1.0), (0.887101981432679, 1.0), (-0.16609437564503057, -1.0), (0.7916414435849103, 1.0), (-0.5231302138387212, -1.0), (-0.301475249050834, -1.0), (-0.4604637846968892, -1.0), (0.6035416901116828, 1.0), (-0.46712067734330665, -1.0), (0.6544357457651917, 1.0), (-0.13739908844791013, -1.0), (-0.5122600893593184, -1.0), (0.6692878724091941, 1.0), (-0.6254847413524122, -1.0), (0.7931801304306895, 1.0), (-0.6268885466626448, -1.0), (-0.4147700145280173, -1.0), (-0.38973671848024816, -1.0), (-0.5168085751

[(pred, label), ...]:  [(-0.2617469113731169, -1.0), (-0.6262950807948504, -1.0), (1.041357631632641, 1.0), (-0.6556714483274586, -1.0), (0.18940944301109236, 1.0), (0.5272030625007501, 1.0), (0.6323667996653237, 1.0), (-0.47607869903400135, -1.0), (0.6481484727281887, 1.0), (0.902576081136584, 1.0), (0.20032110805779005, -1.0), (-0.41107169472782257, -1.0), (-0.5956530625827717, -1.0), (0.8181496094559322, 1.0), (0.6298053858083924, 1.0), (-0.3148576611462466, -1.0), (-0.4791924900104493, -1.0), (0.8508548567637729, 1.0), (-0.12198724641773387, -1.0), (0.7373072300059058, 1.0), (-0.45371398331881707, -1.0), (-0.3793223767131395, -1.0), (-0.40218320875064795, -1.0), (0.6336783097718468, 1.0), (-0.5419786433557572, -1.0), (0.5885029995968234, 1.0), (-0.06992812276629537, -1.0), (-0.4477482364387083, -1.0), (0.7029323149371882, 1.0), (-0.6315212461027816, -1.0), (0.693237818882969, 1.0), (-0.6497591359741443, -1.0), (-0.40108451124836286, -1.0), (-0.3214375286847535, -1.0), (-0.502268629

[(pred, label), ...]:  [(-0.36508202394895656, -1.0), (-0.6774044380406627, -1.0), (0.9832842788897234, 1.0), (-0.6827687260238915, -1.0), (0.12708042139436496, 1.0), (0.43975233921009166, 1.0), (0.5706295934267416, 1.0), (-0.48477337949670024, -1.0), (0.5790112704197349, 1.0), (0.8281613724136543, 1.0), (0.15561325028984424, -1.0), (-0.4623065021259202, -1.0), (-0.6323846253775416, -1.0), (0.7516573016036039, 1.0), (0.5595201891410775, 1.0), (-0.4150004077379092, -1.0), (-0.5716472035706177, -1.0), (0.7915617622792546, 1.0), (-0.20173550399632417, -1.0), (0.7089715351946939, 1.0), (-0.5425427150634587, -1.0), (-0.3760490314523088, -1.0), (-0.4988050635321221, -1.0), (0.5796084804561876, 1.0), (-0.5408922608851341, -1.0), (0.5448588325824952, 1.0), (-0.17403264170693744, -1.0), (-0.5239119604993066, -1.0), (0.6311548335630625, 1.0), (-0.6861427041911072, -1.0), (0.6998772374891089, 1.0), (-0.671508053417714, -1.0), (-0.4788835415597335, -1.0), (-0.3899075020410368, -1.0), (-0.545858141