-
Notifications
You must be signed in to change notification settings - Fork 26
/
lightning.py
37 lines (28 loc) · 1.03 KB
/
lightning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from benchopt import BaseSolver
from benchopt import safe_import_context
with safe_import_context() as import_ctx:
from lightning.regression import CDRegressor
# TODO: lightning always fits an intercept,
# so it does not optimize the same cost function.
class Solver(BaseSolver):
    """Benchmark solver that wraps lightning's ``CDRegressor``."""

    name = 'Lightning'

    # The package is pulled from its git repository through pip.
    install_cmd = 'conda'
    requirements = [
        'pip:git+https://github.com/scikit-learn-contrib/lightning.git'
    ]
    references = [
        (
            'M. Blondel, K. Seki and K. Uehara, '
            '"Block coordinate descent algorithms for large-scale '
            'sparse multiclass classification" '
            'Mach. Learn., vol. 93, no. 1, pp. 31-52 (2013)'
        ),
    ]

    def set_objective(self, X, y, lmbd):
        """Store the problem data and build the estimator.

        ``X``, ``y`` and ``lmbd`` are kept on the instance, and a
        ``CDRegressor`` with squared loss and an l1 penalty is created
        with ``alpha`` set to ``lmbd`` and ``C`` fixed to 1.
        """
        self.X = X
        self.y = y
        self.lmbd = lmbd

        # NOTE(review): the tiny tolerance is presumably there so that the
        # iteration cap set in ``run`` decides when fitting stops -- confirm.
        estimator_params = dict(
            loss='squared',
            penalty='l1',
            C=1,
            alpha=self.lmbd,
            tol=1e-15,
        )
        self.clf = CDRegressor(**estimator_params)

    def run(self, n_iter):
        """Fit the estimator on the stored data with ``max_iter=n_iter``."""
        estimator = self.clf
        estimator.max_iter = n_iter
        estimator.fit(self.X, self.y)

    def get_result(self):
        """Return the fitted coefficients as a flattened array."""
        coefficients = self.clf.coef_
        return coefficients.flatten()