In [104]:
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import sklearn
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import SelectFromModel
from sklearn.model_selection import train_test_split

from qiskit import BasicAer
from qiskit import IBMQ
from qiskit.tools.monitor import job_monitor
from qiskit.circuit.library import ZZFeatureMap
from qiskit.aqua import QuantumInstance, aqua_globals
from qiskit.aqua.algorithms import QSVM, SklearnSVM
from qiskit.aqua.components.multiclass_extensions import AllPairs, OneAgainstRest, ErrorCorrectingCode
from qiskit.aqua.utils import split_dataset_to_data_and_labels, map_label_to_class_name

# One seed for the whole notebook: handed to Aqua's global random_seed here and
# reused as seed_simulator / seed_transpiler when the QuantumInstance is built.
seed = 10599
aqua_globals.random_seed = seed

In [105]:
### Import data
df = pd.read_csv("data/stem_processed_all.csv")

### Subset: keep only the first 50 rows
df = df.iloc[0:50]

### Select label column
LABEL_COL = 'TECH3'

x = df.drop(columns=[LABEL_COL])
y = df[LABEL_COL]

### Train-test split, then let a random forest keep at most MAX_FEATURES features.
### Both steps are seeded with the notebook-wide `seed`, so the split and the
### selected columns are the same on every Restart & Run All.
MAX_FEATURES = 5
xtrain_raw, xtest_raw, ytrain, ytest = train_test_split(
    x, y, test_size=0.2, random_state=seed)

sel = SelectFromModel(RandomForestClassifier(n_estimators=100, random_state=seed),
                      max_features=MAX_FEATURES)
sel.fit(xtrain_raw, ytrain)

# get_support() is a boolean mask over the input columns; index with it directly
# instead of looping and comparing each flag to True.
support = sel.get_support()
newdf_columns = list(x.columns[support])
xtrain = xtrain_raw[newdf_columns].copy()
xtest = xtest_raw[newdf_columns].copy()

feature_dim = len(xtrain.columns)

### Reformat the data into dictionaries as is preferred by qiskit functions:
### {class label: 2-D float array of that class's rows}.
### Only labels 0, 1 and 2 are kept; rows with any other label value are dropped.
train_inp = {label: xtrain[ytrain == label].values.astype(float) for label in range(3)}
test_inp = {label: xtest[ytest == label].values.astype(float) for label in range(3)}

# All test rows stacked in class order (0, 1, 2).
total_array = np.concatenate(list(test_inp.values()))

In [106]:
### Classical SVM
#################

# Fit and score the classical baseline 10 times and report the mean accuracy
# and mean wall-clock time per run.
# perf_counter() is used because each run is very short (the recorded average
# is about 3 ms); the clock is stopped right after .run() so bookkeeping is not timed.
accs = []
times = []
for i in range(10):
    start = time.perf_counter()
    result = SklearnSVM(train_inp, test_inp, total_array, multiclass_extension=AllPairs()).run()
    times.append(time.perf_counter() - start)
    accs.append(result['testing_accuracy'])

print(np.mean(accs))
print(np.mean(times))

0.5555555555555556
0.0031497001647949217


In [111]:
from matplotlib import pyplot as plt
from sklearn import svm

def f_importances(coef, names):
    """Show a horizontal bar chart of feature weights in ascending order.

    Args:
        coef: iterable of weights, one per feature.
        names: iterable of feature names, paired positionally with ``coef``.

    The (weight, name) pairs are sorted together, so ties on weight fall back
    to comparing the names. Nothing is returned; the figure is shown via pyplot.
    """
    ranked = sorted(zip(coef, names))
    weights, labels = zip(*ranked)
    positions = range(len(labels))
    plt.barh(positions, weights, align='center')
    plt.yticks(positions, labels)
    plt.show()

features_names = xtrain.columns

# Give the fitted model its own name: assigning it to `svm` would replace the
# `sklearn.svm` module imported at the top of this cell, so any later
# `svm.SVC(...)` call would fail.
linear_svc = svm.SVC(kernel='linear')
linear_svc.fit(xtrain.to_numpy(), ytrain.to_numpy())
print(features_names)

# The recorded output is a 3 x 5 array (rows of weights over the 5 selected
# features), so pass one row at a time if plotting with f_importances.
linear_svc.coef_

Index(['TECH6', 'REASON2d-3', 'REASON2e-2', 'ETHN5-1', 'RELATE2-3'], dtype='object')


array([[-1.125     ,  0.875     , -0.375     ,  0.75      ,  0.375     ],
       [-1.33333333,  0.66666667,  0.        ,  0.        ,  0.66666667],
       [-0.5       ,  0.        ,  1.5       , -2.        ,  1.        ]])

In [55]:
### View available backends

IBMQ.load_account()
provider = IBMQ.get_provider(group='open')

# Keep backends whose configuration reports more than 5 qubits and is not a simulator.
provider.backends(
    filters=lambda device: device.configuration().n_qubits > 5
    and not device.configuration().simulator
)

In [107]:
### Quantum SVM
###############

### UFUNC ERROR with qasm_simulator ###
# NOTE(review): the run recorded under this cell stopped with
# "ValueError: Complex data not supported" and printed a complex-valued array.
# Presumably an intermediate result comes back as complex with zero imaginary
# part -- confirm the cause before relying on this cell's results.

class_labels = [0, 1, 2]

# Feature map sized to the number of selected features.
feature_map = ZZFeatureMap(
    feature_dimension=feature_dim,
    reps=2,
    entanglement='linear',
)

# Alternate multiclass extension: OneAgainstRest(), ErrorCorrectingCode(code_size=5), AllPairs()
qsvm = QSVM(
    feature_map,
    train_inp,
    test_inp,
    total_array,
    multiclass_extension=ErrorCorrectingCode(code_size=5),
)

# Local simulator by default; the commented line selects a provider backend instead.
backend = BasicAer.get_backend('qasm_simulator')
# backend = provider.get_backend("ibmq_16_melbourne")
quantum_instance = QuantumInstance(
    backend,
    shots=256,
    seed_simulator=seed,
    seed_transpiler=seed,
)

result = qsvm.run(quantum_instance)

# Dump every entry of the result dictionary, one per line.
for k, v in result.items():
    print(f'{k} : {v}')

ValueError: Complex data not supported
[[ 1.        +0.j -0.99885606+0.j -0.99828688+0.j  0.99943082+0.j
   1.        +0.j  1.        +0.j  0.99885606+0.j  0.99943082+0.j
  -0.99943082+0.j  0.99885606+0.j -0.99828688+0.j  1.        +0.j
   1.        +0.j  0.99943082+0.j  0.99943082+0.j]
 [ 1.        +0.j -0.33321053+0.j -0.33284766+0.j  0.99963713+0.j
   1.        +0.j  1.        +0.j  0.33321053+0.j  0.99963713+0.j
  -0.99963713+0.j  0.33321053+0.j -0.33284766+0.j  1.        +0.j
   1.        +0.j  0.99963713+0.j  0.99963713+0.j]
 [ 1.        +0.j -0.99946532+0.j  0.99984383+0.j -0.99930915+0.j
   1.        +0.j  1.        +0.j  0.99946532+0.j -0.99930915+0.j
   0.99930915+0.j  0.99946532+0.j  0.99984383+0.j  1.        +0.j
   1.        +0.j -0.99930915+0.j -0.99930915+0.j]
 [ 1.        +0.j  0.32916662+0.j  0.4690441 +0.j  0.86012251+0.j
   1.        +0.j  1.        +0.j -0.32916662+0.j  0.86012251+0.j
  -0.86012251+0.j -0.32916662+0.j  0.4690441 +0.j  1.        +0.j
   1.        +0.j  0.86012251+0.j  0.86012251+0.j]
 [ 1.        +0.j  0.02624995+0.j  0.63422717+0.j  0.39202279+0.j
   1.        +0.j  1.        +0.j -0.02624995+0.j  0.39202279+0.j
  -0.39202279+0.j -0.02624995+0.j  0.63422717+0.j  1.        +0.j
   1.        +0.j  0.39202279+0.j  0.39202279+0.j]
 [ 1.        +0.j  0.99915712+0.j  0.99956283+0.j  0.99959429+0.j
   1.        +0.j  1.        +0.j -0.99915712+0.j  0.99959429+0.j
  -0.99959429+0.j -0.99915712+0.j  0.99956283+0.j  1.        +0.j
   1.        +0.j  0.99959429+0.j  0.99959429+0.j]
 [ 1.        +0.j  0.99958481+0.j  0.99984366+0.j  0.99974115+0.j
   1.        +0.j  1.        +0.j -0.99958481+0.j  0.99974115+0.j
  -0.99974115+0.j -0.99958481+0.j  0.99984366+0.j  1.        +0.j
   1.        +0.j  0.99974115+0.j  0.99974115+0.j]
 [ 1.        +0.j -0.99946532+0.j  0.99984383+0.j -0.99930915+0.j
   1.        +0.j  1.        +0.j  0.99946532+0.j -0.99930915+0.j
   0.99930915+0.j  0.99946532+0.j  0.99984383+0.j  1.        +0.j
   1.        +0.j -0.99930915+0.j -0.99930915+0.j]
 [ 1.        +0.j  0.03735194+0.j  0.14608187+0.j  0.89127008+0.j
   1.        +0.j  1.        +0.j -0.03735194+0.j  0.89127008+0.j
  -0.89127008+0.j -0.03735194+0.j  0.14608187+0.j  1.        +0.j
   1.        +0.j  0.89127008+0.j  0.89127008+0.j]]
