Skip to content

Commit

Permalink
hyperparams object?
Browse files Browse the repository at this point in the history
  • Loading branch information
gcattan committed Nov 19, 2021
1 parent 467358c commit 2ad611f
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 5 deletions.
15 changes: 10 additions & 5 deletions pyriemann_qiskit/classification.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Module for classification function."""
import numpy as np

from .utils.hyperparams import QuanticHyperParams

from sklearn.base import BaseEstimator, ClassifierMixin

from qiskit import BasicAer, IBMQ
from qiskit.circuit.library import ZZFeatureMap, TwoLocal
from qiskit.circuit.library import TwoLocal
from qiskit.aqua import QuantumInstance, aqua_globals
from qiskit.aqua.quantum_instance import logger
from qiskit.aqua.algorithms import QSVM, SklearnSVM, VQC
Expand Down Expand Up @@ -45,6 +47,8 @@ class QuanticClassifierBase(BaseEstimator, ClassifierMixin):
the classification task will be running on a IBM quantum backend
verbose : bool (default:True)
If true will output all intermediate results and logs
hyper_params : QuanticHyperParams (default:QuanticHyperParams())
The hyper parameters for quantum classifiers
Notes
-----
Expand All @@ -59,6 +63,7 @@ class QuanticClassifierBase(BaseEstimator, ClassifierMixin):
--------
QuanticSVM
QuanticVQC
QuanticHyperParams
References
----------
Expand All @@ -68,11 +73,12 @@ class QuanticClassifierBase(BaseEstimator, ClassifierMixin):
"""

def __init__(self, quantum=True, q_account_token=None, verbose=True):
def __init__(self, quantum=True, q_account_token=None, verbose=True, hyper_params=QuanticHyperParams()):
self.verbose = verbose
self._log("Initializing Quantum Classifier")
self.q_account_token = q_account_token
self.quantum = quantum
self.hyper_params = hyper_params
# protected field for child classes
self._training_input = {}

Expand Down Expand Up @@ -148,8 +154,7 @@ def fit(self, X, y):

feature_dim = get_feature_dimension(self._training_input)
self._log("Feature dimension = ", feature_dim)
self._feature_map = ZZFeatureMap(feature_dimension=feature_dim, reps=2,
entanglement='linear')
self._feature_map = self.hyper_params.feature_map(feature_dim)
if self.quantum:
if not hasattr(self, "_backend"):
def filters(device):
Expand Down Expand Up @@ -247,7 +252,7 @@ def _init_algo(self, feature_dim):
if self.quantum:
classifier = QSVM(self._feature_map, self._training_input)
else:
classifier = SklearnSVM(self._training_input)
classifier = SklearnSVM(self._training_input, gamma=self.hyper_params.gamma)
return classifier

def predict_proba(self, X):
Expand Down
5 changes: 5 additions & 0 deletions pyriemann_qiskit/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from . import hyperparams

__all__ = [
'hyperparams',
]
45 changes: 45 additions & 0 deletions pyriemann_qiskit/utils/hyperparams.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from qiskit.circuit.library import ZZFeatureMap

default = {
"gamma": None, #quantum=False
"lambda2":0.001, #QSVC
"feature_map":
lambda dim:
ZZFeatureMap(feature_dimension=dim, reps=2, entanglement='linear'),
"nshots":None,
"optimizer":None, #VQC
"var_form":None, #VQC
"enforce_spd":None,
"output_norm":None, #quantum=Fale
"l2reg":None
}


class QuanticHyperParams():
"""
This class is a wrapper around all hyper parameters for quantum classifier.
Parameters
----------
gamma: TODO
lambda2 : L2 norm regularization factor (QSVM)
feature_map : (feature_dim)->(Union[QuantumCircuit, FeatureMap])
Feature map module, used to transform data (QSVM and VQC)
Notes
-----
.. versionadded:: 0.0.1
See Also
--------
QuanticClassifierBase
QuanticSVM
QuanticVQC
"""
def __init__(self, gamma=None, lambda2=default["lambda2"], feature_map=default["feature_map"]):
self.gamma = gamma
self.lambda2 = lambda2
self.feature_map = feature_map

0 comments on commit 2ad611f

Please sign in to comment.