Skip to content

Commit

Permalink
fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
dmeoli committed Mar 6, 2021
1 parent 3c831d2 commit a900f21
Show file tree
Hide file tree
Showing 4 changed files with 551 additions and 858 deletions.
66 changes: 50 additions & 16 deletions notebooks/ml/CM_SVC_report_experiments.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,10 @@
"outputs": [],
"source": [
"from optiml.ml.svm import PrimalSVC\n",
"from optiml.ml.svm.losses import hinge, squared_hinge\n",
"from optiml.ml.utils import generate_linearly_separable_overlap_data, generate_nonlinearly_separable_data, plot_svm_hyperplane\n",
"from optiml.ml.svm.losses import hinge\n",
"from optiml.ml.utils import generate_linearly_separable_overlap_data, plot_svm_hyperplane\n",
"\n",
"from optiml.opti.unconstrained.stochastic import AdaGrad, StochasticGradientDescent\n",
"from optiml.opti.unconstrained.stochastic import AdaGrad\n",
"from optiml.opti.utils import plot_trajectory_optimization\n",
"\n",
"from sklearn.svm import LinearSVC as SkLinearSVC\n",
Expand Down Expand Up @@ -177,7 +177,7 @@
" scoring=primal_accuracy_scorer,\n",
" cv=3, # 3 fold cross validation\n",
" n_jobs=-1, # use all processors\n",
" refit='accuracy', # refit the best model (wrt best accuracy) on the full development set\n",
" refit='accuracy', # refit the best model (wrt accuracy) on the full development set\n",
" return_train_score=True,\n",
" verbose=True).fit(X, y)\n",
"\n",
Expand Down Expand Up @@ -265,7 +265,7 @@
" scoring=primal_accuracy_scorer,\n",
" cv=3, # 3 fold cross validation\n",
" n_jobs=-1, # use all processors\n",
" refit='accuracy', # refit the best model (wrt best accuracy) on the full development set\n",
" refit='accuracy', # refit the best model (wrt accuracy) on the full development set\n",
" return_train_score=True,\n",
" verbose=True).fit(X, y)"
]
Expand Down Expand Up @@ -498,7 +498,7 @@
" scoring=dual_accuracy_scorer,\n",
" cv=3, # 3 fold cross validation\n",
" n_jobs=-1, # use all processors\n",
" refit='accuracy', # refit the best model (wrt best accuracy) on the full development set\n",
" refit='accuracy', # refit the best model (wrt accuracy) on the full development set\n",
" return_train_score=True,\n",
" verbose=True).fit(X, y)\n",
"\n",
Expand Down Expand Up @@ -611,7 +611,7 @@
" scoring=dual_accuracy_scorer,\n",
" cv=3, # 3 fold cross validation\n",
" n_jobs=-1, # use all processors\n",
" refit='accuracy', # refit the best model (wrt best accuracy) on the full development set\n",
" refit='accuracy', # refit the best model (wrt accuracy) on the full development set\n",
" return_train_score=True,\n",
" verbose=True).fit(X, y)"
]
Expand Down Expand Up @@ -827,7 +827,7 @@
" scoring=dual_accuracy_scorer,\n",
" cv=3, # 3 fold cross validation\n",
" n_jobs=-1, # use all processors\n",
" refit='accuracy', # refit the best model (wrt best accuracy) on the full development set\n",
" refit='accuracy', # refit the best model (wrt accuracy) on the full development set\n",
" return_train_score=True,\n",
" verbose=True).fit(X, y)\n",
"\n",
Expand Down Expand Up @@ -960,7 +960,7 @@
" scoring=dual_accuracy_scorer,\n",
" cv=3, # 3 fold cross validation\n",
" n_jobs=-1, # use all processors\n",
" refit='accuracy', # refit the best model (wrt best accuracy) on the full development set\n",
" refit='accuracy', # refit the best model (wrt accuracy) on the full development set\n",
" return_train_score=True,\n",
" verbose=True).fit(X, y)\n",
"\n",
Expand Down Expand Up @@ -1235,7 +1235,7 @@
" scoring=dual_accuracy_scorer,\n",
" cv=3, # 3 fold cross validation\n",
" n_jobs=-1, # use all processors\n",
" refit='accuracy', # refit the best model (wrt best accuracy) on the full development set\n",
" refit='accuracy', # refit the best model (wrt accuracy) on the full development set\n",
" return_train_score=True,\n",
" verbose=True).fit(X, y)"
]
Expand Down Expand Up @@ -1576,7 +1576,7 @@
" scoring=dual_accuracy_scorer,\n",
" cv=3, # 3 fold cross validation\n",
" n_jobs=-1, # use all processors\n",
" refit='accuracy', # refit the best model (wrt best accuracy) on the full development set\n",
" refit='accuracy', # refit the best model (wrt accuracy) on the full development set\n",
" return_train_score=True,\n",
" verbose=True).fit(X, y)"
]
Expand Down Expand Up @@ -1660,7 +1660,7 @@
" scoring=dual_accuracy_scorer,\n",
" cv=3, # 3 fold cross validation\n",
" n_jobs=-1, # use all processors\n",
" refit='accuracy', # refit the best model (wrt best accuracy) on the full development set\n",
" refit='accuracy', # refit the best model (wrt accuracy) on the full development set\n",
" return_train_score=True,\n",
" verbose=True).fit(X, y)"
]
Expand Down Expand Up @@ -1837,7 +1837,7 @@
" scoring=dual_accuracy_scorer,\n",
" cv=3, # 3 fold cross validation\n",
" n_jobs=-1, # use all processors\n",
" refit='accuracy', # refit the best model (wrt best accuracy) on the full development set\n",
" refit='accuracy', # refit the best model (wrt accuracy) on the full development set\n",
" return_train_score=True,\n",
" verbose=True).fit(X, y)"
]
Expand Down Expand Up @@ -1956,7 +1956,7 @@
" scoring=dual_accuracy_scorer,\n",
" cv=3, # 3 fold cross validation\n",
" n_jobs=-1, # use all processors\n",
" refit='accuracy', # refit the best model (wrt best accuracy) on the full development set\n",
" refit='accuracy', # refit the best model (wrt accuracy) on the full development set\n",
" return_train_score=True,\n",
" verbose=True).fit(X, y)"
]
Expand Down Expand Up @@ -2214,6 +2214,40 @@
"### Primal formulation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hidePrompt": true
},
"outputs": [],
"source": [
"from optiml.ml.svm import PrimalSVC\n",
"from optiml.ml.svm.losses import squared_hinge\n",
"from optiml.ml.utils import generate_linearly_separable_overlap_data, plot_svm_hyperplane\n",
"\n",
"from optiml.opti.unconstrained.stochastic import StochasticGradientDescent\n",
"from optiml.opti.utils import plot_trajectory_optimization\n",
"\n",
"from sklearn.svm import LinearSVC as SkLinearSVC\n",
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"hidePrompt": true
},
"outputs": [],
"source": [
"def primal_accuracy_scorer(svc, X, y):\n",
" return {'accuracy': svc.score(X, y),\n",
" 'nr_support_vectors': len(np.argwhere(np.abs(svc.decision_function(X)) <= 1).ravel())}"
]
},
{
"cell_type": "code",
"execution_count": 47,
Expand Down Expand Up @@ -2249,7 +2283,7 @@
" scoring=primal_accuracy_scorer,\n",
" cv=3, # 3 fold cross validation\n",
" n_jobs=-1, # use all processors\n",
" refit='accuracy', # refit the best model (wrt best accuracy) on the full development set\n",
" refit='accuracy', # refit the best model (wrt accuracy) on the full development set\n",
" return_train_score=True,\n",
" verbose=True).fit(X, y)\n",
"\n",
Expand Down Expand Up @@ -2337,7 +2371,7 @@
" scoring=primal_accuracy_scorer,\n",
" cv=3, # 3 fold cross validation\n",
" n_jobs=-1, # use all processors\n",
" refit='accuracy', # refit the best model (wrt best accuracy) on the full development set\n",
" refit='accuracy', # refit the best model (wrt accuracy) on the full development set\n",
" return_train_score=True,\n",
" verbose=True).fit(X, y)"
]
Expand Down
1,341 changes: 500 additions & 841 deletions notebooks/ml/CM_SVR_report_experiments.ipynb

Large diffs are not rendered by default.

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion optiml/ml/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def generate_nonlinearly_separable_data(size=100, random_state=None):

def generate_nonlinearly_regression_data(size=100, random_state=None):
rs = np.random.RandomState(random_state)
X = np.sort(4 * np.pi * rs.uniform(size=size))
X = np.sort(2 * np.pi * rs.uniform(size=size))
y = np.sin(X)
y += 0.25 * (0.5 - rs.uniform(size=size)) # noise
return X.reshape(-1, 1), y
Expand Down

0 comments on commit a900f21

Please sign in to comment.