forked from open-mmlab/mmrazor
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
gaoyang07
committed
Sep 30, 2022
1 parent
8d603d9
commit ebb8db2
Showing
11 changed files
with
479 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .metric_predictor import MetricPredictor | ||
from .zero_shot_predictor import ZeroShotPredictor | ||
|
||
__all__ = ['MetricPredictor', 'ZeroShotPredictor'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from abc import abstractmethod | ||
|
||
from mmrazor.registry import TASK_UTILS | ||
|
||
|
||
class BasePredictor(): | ||
"""Base predictor.""" | ||
|
||
def __init__(self, handler_cfg: dict): | ||
"""init.""" | ||
self.handler_cfg = handler_cfg | ||
self.handler = TASK_UTILS.build(handler_cfg) | ||
|
||
@abstractmethod | ||
def predict(self, model, predict_args): | ||
"""predict result.""" | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .carts_handler import CartsHandler | ||
from .gp_handler import GaussProcessHandler | ||
from .mlp_handler import MLPHandler | ||
from .rbf_handler import RBFHandler | ||
|
||
__all__ = ['CartsHandler', 'GaussProcessHandler', 'MLPHandler', 'RBFHandler'] |
25 changes: 25 additions & 0 deletions
25
mmrazor/models/task_modules/predictor/handler/base_handler.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from joblib import dump, load | ||
|
||
|
||
class BaseHandler(): | ||
"""Base handler.""" | ||
|
||
def __init__(self) -> None: | ||
pass | ||
|
||
def fit(self, train_data, train_label): | ||
pass | ||
|
||
def predict(self, test_data): | ||
pass | ||
|
||
def load(self, path): | ||
"""Load pretrained weights for the handler.""" | ||
self.model = load(path) | ||
|
||
def save(self, path): | ||
"""Save the handler and return saved path for diff suffix.""" | ||
path += f'_{self.__class__.__name__}.joblib'.lower() | ||
dump(self.model, path) | ||
return path |
69 changes: 69 additions & 0 deletions
69
mmrazor/models/task_modules/predictor/handler/carts_handler.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import numpy as np | ||
from sklearn.tree import DecisionTreeRegressor | ||
|
||
from mmrazor.registry import TASK_UTILS | ||
from .base_handler import BaseHandler | ||
|
||
|
||
@TASK_UTILS.register_module() | ||
class CartsHandler(BaseHandler): | ||
"""Classification and Regression Tree. | ||
Args: | ||
n_tree (int): number of regression trees. | ||
""" | ||
|
||
def __init__(self, n_tree=1000): | ||
self.n_tree = n_tree | ||
self.model = None | ||
|
||
@staticmethod | ||
def _make_decision_trees(train_data, train_label, n_tree): | ||
"""Construct the decision trees.""" | ||
feature_record = [] | ||
tree_record = [] | ||
|
||
for i in range(n_tree): | ||
sample_idx = np.arange(train_data.shape[0]) | ||
np.random.shuffle(sample_idx) | ||
train_data = train_data[sample_idx, :] | ||
train_label = train_label[sample_idx] | ||
|
||
feature_idx = np.arange(train_data.shape[1]) | ||
np.random.shuffle(feature_idx) | ||
n_feature = np.random.randint(1, train_data.shape[1] + 1) | ||
selected_feature_ids = feature_idx[0:n_feature] | ||
feature_record.append(selected_feature_ids) | ||
|
||
dt = DecisionTreeRegressor() | ||
dt.fit(train_data[:, selected_feature_ids], train_label) | ||
tree_record.append(dt) | ||
|
||
return tree_record, feature_record | ||
|
||
def fit(self, train_data, train_label): | ||
"""Training predictor.""" | ||
self.model = self._make_decision_trees(train_data, train_label, | ||
self.n_tree) | ||
|
||
def predict(self, test_data): | ||
"""Predict the subnets' performance.""" | ||
trees, features = self.model[0], self.model[1] | ||
test_num, n_tree = len(test_data), len(trees) | ||
|
||
predict_labels = np.zeros((test_num, 1)) | ||
for i in range(test_num): | ||
this_test_data = test_data[i, :] | ||
predict_this_list = np.zeros(n_tree) | ||
|
||
for j, (tree, feature) in enumerate(zip(trees, features)): | ||
predict_this_list[j] = tree.predict([this_test_data[feature] | ||
])[0] | ||
|
||
predict_this_list = np.sort(predict_this_list) | ||
predict_this_list = predict_this_list[::-1] | ||
this_predict = np.mean(predict_this_list) | ||
predict_labels[i, 0] = this_predict | ||
|
||
return predict_labels |
112 changes: 112 additions & 0 deletions
112
mmrazor/models/task_modules/predictor/handler/gp_handler.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import numpy as np | ||
|
||
try: | ||
import pydacefit | ||
from pydacefit.dace import DACE | ||
except ImportError: | ||
pydacefit = None | ||
DACE = object | ||
|
||
from mmrazor.registry import TASK_UTILS | ||
from .base_handler import BaseHandler | ||
|
||
|
||
def get_func(): | ||
if pydacefit is None: | ||
raise RuntimeError('Failed to import pydacefit. Please run ' | ||
'"pip install pydacefit". ') | ||
|
||
from pydacefit.corr import (corr_cubic, corr_exp, corr_expg, corr_gauss, | ||
corr_spherical, corr_spline) | ||
from pydacefit.dace import regr_linear, regr_quadratic | ||
from pydacefit.regr import regr_constant | ||
|
||
REGR = { | ||
'linear': regr_linear, | ||
'constant': regr_constant, | ||
'quadratic': regr_quadratic | ||
} | ||
|
||
CORR = { | ||
'gauss': corr_gauss, | ||
'cubic': corr_cubic, | ||
'exp': corr_exp, | ||
'expg': corr_expg, | ||
'spline': corr_spline, | ||
'spherical': corr_spherical | ||
} | ||
|
||
return REGR, CORR | ||
|
||
|
||
class DACE_with_smooth(DACE): | ||
"""GP model.""" | ||
|
||
def __init__(self, regr, corr, theta=1, thetaL=0, thetaU=100): | ||
super(DACE_with_smooth, self).__init__(regr, corr, theta, thetaL, | ||
thetaU) | ||
|
||
def fit(self, X, Y): | ||
|
||
if len(Y.shape) == 1: | ||
Y = Y[:, None] | ||
|
||
if X.shape[0] != Y.shape[0]: | ||
raise Exception('X and Y must have the same number of rows.') | ||
|
||
mX, sX = np.mean(X, axis=0), np.std(X, axis=0, ddof=1) + 1e-6 | ||
mY, sY = np.mean(Y, axis=0), np.std(Y, axis=0, ddof=1) + 1e-6 | ||
|
||
nX = (X - mX) / sX | ||
nY = (Y - mY) / sY | ||
|
||
if self.tl is not None and self.tu is not None: | ||
self.model = {'nX': nX, 'nY': nY} | ||
self.boxmin() | ||
self.model = self.itpar['best'] | ||
else: | ||
from pydacefit.fit import fit | ||
self.model = fit(nX, nY, self.regr, self.kernel, self.theta) | ||
|
||
self.model = { | ||
**self.model, 'mX': mX, | ||
'sX': sX, | ||
'mY': mY, | ||
'sY': sY, | ||
'nX': nX, | ||
'nY': nY | ||
} | ||
self.model['sigma2'] = np.square(sY) @ self.model['_sigma2'] | ||
|
||
|
||
@TASK_UTILS.register_module() | ||
class GaussProcessHandler(BaseHandler): | ||
"""Gaussian Process (Kriging) | ||
Args: | ||
regr (str): regression kernel for GP model. | ||
corr (str): correlation kernel for GP model. | ||
""" | ||
|
||
def __init__(self, regr='linear', corr='gauss'): | ||
REGR, CORR = get_func() | ||
assert regr in REGR and corr in CORR, \ | ||
NotImplementedError('Unknown GP regression or correlation !') | ||
self.regr = REGR[regr] | ||
self.corr = CORR[corr] | ||
|
||
self.model = DACE_with_smooth( | ||
regr=self.regr, | ||
corr=self.corr, | ||
theta=1.0, | ||
thetaL=0.00001, | ||
thetaU=100) | ||
|
||
def fit(self, train_data, train_label): | ||
"""Training predictor.""" | ||
self.model.fit(train_data, train_label) | ||
|
||
def predict(self, test_data): | ||
"""Predict the subnets' performance.""" | ||
return self.model.predict(test_data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .base_handler import BaseHandler | ||
|
||
|
||
class MLPHandler(BaseHandler): | ||
|
||
def __init__(self) -> None: | ||
super().__init__() |
61 changes: 61 additions & 0 deletions
61
mmrazor/models/task_modules/predictor/handler/rbf_handler.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
try: | ||
import pySOT | ||
except ImportError: | ||
pySOT = None | ||
|
||
from mmrazor.registry import TASK_UTILS | ||
from .base_handler import BaseHandler | ||
|
||
|
||
@TASK_UTILS.register_module() | ||
class RBFHandler(BaseHandler): | ||
"""Radial Basis Function. | ||
Args: | ||
kernel (str): RBF kernel object. | ||
tail (str): RBF polynomial tail object. | ||
""" | ||
|
||
def __init__(self, kernel='tps', tail='linear'): | ||
if pySOT is None: | ||
raise RuntimeError('Failed to import pydacefit. Please run ' | ||
'"pip install pySOT==0.2.3". ') | ||
from pySOT.surrogate import (ConstantTail, CubicKernel, LinearTail, | ||
TPSKernel) | ||
self.kernel = kernel | ||
self.tail = tail | ||
self.model = None | ||
|
||
if kernel == 'cubic': | ||
self.kernel = CubicKernel | ||
elif self.kernel == 'tps': | ||
self.kernel = TPSKernel | ||
else: | ||
raise NotImplementedError('unknown RBF kernel') | ||
|
||
if tail == 'linear': | ||
self.tail = LinearTail | ||
elif self.tail == 'constant': | ||
self.tail = ConstantTail | ||
else: | ||
raise NotImplementedError('unknown RBF tail') | ||
|
||
def fit(self, train_data, train_label): | ||
"""Training predictor.""" | ||
if train_data.shape[0] <= train_data.shape[1]: | ||
raise ValueError('RBF only support ' | ||
f'# of samples{train_data.shape[0]}' | ||
f' > # of dimensions{train_data.shape[1]} !') | ||
from pySOT.surrogate import RBFInterpolant | ||
self.model = RBFInterpolant( | ||
dim=train_data.shape[1], | ||
kernel=self.kernel(), | ||
tail=self.tail(train_data.shape[1])) | ||
|
||
for i in range(len(train_data)): | ||
self.model.add_points(train_data[i, :], train_label[i]) | ||
|
||
def predict(self, test_data): | ||
"""Predict the subnets' performance.""" | ||
return self.model.predict(test_data) |
Oops, something went wrong.