-
Notifications
You must be signed in to change notification settings - Fork 164
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: khalid-davis <huangqinkai1@huawei.com>
- Loading branch information
1 parent
8daa0c5
commit e5d5990
Showing
20 changed files
with
726 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import logging | ||
|
||
from . import joint_inference | ||
from .context import context | ||
|
||
|
||
def log_configure(): | ||
logging.basicConfig( | ||
format='[%(asctime)s][%(name)s][%(levelname)s][%(lineno)s]: ' | ||
'%(message)s', | ||
level=logging.INFO) | ||
|
||
|
||
LOG = logging.getLogger(__name__) | ||
|
||
log_configure() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import os | ||
|
||
|
||
class BaseConfig: | ||
"""The base config, the value can not be changed.""" | ||
# dataset | ||
train_dataset_url = os.getenv("TRAIN_DATASET_URL") | ||
test_dataset_url = os.getenv("TEST_DATASET_URL") | ||
# k8s crd info | ||
namespace = os.getenv("NAMESPACE", "") | ||
worker_name = os.getenv("WORKER_NAME", "") | ||
service_name = os.getenv("SERVICE_NAME", "") | ||
|
||
model_url = os.getenv("MODEL_URL") | ||
|
||
# user parameter | ||
parameters = os.getenv("PARAMETERS") | ||
# Hard Example Mining Algorithm | ||
hem_name = os.getenv("HEM_NAME") | ||
hem_parameters = os.getenv("HEM_PARAMETERS") | ||
|
||
def __init__(self): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import logging | ||
from enum import Enum | ||
|
||
LOG = logging.getLogger(__name__) | ||
|
||
|
||
class Framework(Enum): | ||
Tensorflow = "tensorflow" | ||
Pytorch = "pytorch" | ||
Mindspore = "mindspore" | ||
|
||
|
||
class K8sResourceKind(Enum): | ||
JOINT_INFERENCE_SERVICE = "jointinferenceservice" | ||
|
||
|
||
class K8sResourceKindStatus(Enum): | ||
COMPLETED = "completed" | ||
FAILED = "failed" | ||
RUNNING = "running" | ||
|
||
|
||
FRAMEWORK = Framework.Tensorflow # TODO: should read from env. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import json | ||
import logging | ||
|
||
from neptune.common.config import BaseConfig | ||
|
||
LOG = logging.getLogger(__name__) | ||
|
||
|
||
def parse_parameters(parameters): | ||
""" | ||
:param parameters: | ||
[{"key":"batch_size","value":"32"}, | ||
{"key":"learning_rate","value":"0.001"}, | ||
{"key":"min_node_number","value":"3"}] | ||
----> | ||
:return: | ||
{'batch_size':32, 'learning_rate':0.001, 'min_node_number'=3} | ||
""" | ||
p = {} | ||
if parameters is None or len(parameters) == 0: | ||
LOG.info(f"PARAMETERS={parameters}, return empty dict.") | ||
return p | ||
j = json.loads(parameters) | ||
for d in j: | ||
p[d.get('key')] = d.get('value') | ||
return p | ||
|
||
|
||
class Context: | ||
"""The Context provides the capability of obtaining the context of the | ||
`PARAMETERS` and `HEM_PARAMETERS` field""" | ||
|
||
def __init__(self): | ||
self.parameters = parse_parameters(BaseConfig.parameters) | ||
self.hem_parameters = parse_parameters(BaseConfig.hem_parameters) | ||
|
||
def get_context(self): | ||
return self.parameters | ||
|
||
def get_parameters(self, param, default=None): | ||
"""get the value of the key `param` in `PARAMETERS`, | ||
if not exist, the default value is returned""" | ||
value = self.parameters.get(param) | ||
return value if value else default | ||
|
||
def get_hem_parameters(self, param, default=None): | ||
"""get the value of the key `param` in `HEM_PARAMETERS`, | ||
if not exist, the default value is returned""" | ||
value = self.hem_parameters.get(param) | ||
return value if value else default | ||
|
||
|
||
context = Context() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .base import BaseFilter, ThresholdFilter | ||
from .image_classification.hard_mine_filters import CrossEntropyFilter | ||
from .object_detection.scores_filters import IBTFilter |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
class BaseFilter: | ||
"""The base class to define unified interface.""" | ||
|
||
def hard_judge(self, infer_result=None): | ||
"""predict function, and it must be implemented by | ||
different methods class. | ||
:param infer_result: prediction result | ||
:return: `True` means hard sample, `False` means not a hard sample. | ||
""" | ||
raise NotImplementedError | ||
|
||
|
||
class ThresholdFilter(BaseFilter): | ||
def __init__(self, threshold=0.5): | ||
self.threshold = threshold | ||
|
||
def hard_judge(self, infer_result=None): | ||
""" | ||
:param infer_result: [N, 6], (x0, y0, x1, y1, score, class) | ||
:return: `True` means hard sample, `False` means not a hard sample. | ||
""" | ||
if not infer_result: | ||
return True | ||
|
||
image_score = 0 | ||
for bbox in infer_result: | ||
image_score += bbox[4] | ||
|
||
average_score = image_score / (len(infer_result) or 1) | ||
return average_score < self.threshold |
1 change: 1 addition & 0 deletions
1
lib/neptune/hard_example_mining/hard_example_helpers/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .data_check_utils import data_check |
8 changes: 8 additions & 0 deletions
8
lib/neptune/hard_example_mining/hard_example_helpers/data_check_utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
import logging | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def data_check(data): | ||
"""Check the data in [0,1].""" | ||
return 0 <= float(data) <= 1 |
Empty file.
53 changes: 53 additions & 0 deletions
53
lib/neptune/hard_example_mining/image_classification/hard_mine_filters.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import logging | ||
import math | ||
|
||
from neptune.hard_example_mining import BaseFilter | ||
from neptune.hard_example_mining.hard_example_helpers import data_check | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class CrossEntropyFilter(BaseFilter): | ||
""" Implement the hard samples discovery methods named IBT | ||
(image-box-thresholds). | ||
:param threshold_cross_entropy: threshold_cross_entropy to filter img, | ||
whose hard coefficient is less than | ||
threshold_cross_entropy. And its default value is | ||
threshold_cross_entropy=0.5 | ||
""" | ||
|
||
def __init__(self, threshold_cross_entropy=0.5): | ||
self.threshold_cross_entropy = threshold_cross_entropy | ||
|
||
def hard_judge(self, infer_result=None): | ||
"""judge the img is hard sample or not. | ||
:param infer_result: | ||
prediction classes list, | ||
such as [class1-score, class2-score, class2-score,....], | ||
where class-score is the score corresponding to the class, | ||
class-score value is in [0,1], who will be ignored if its value | ||
not in [0,1]. | ||
:return: `True` means a hard sample, `False` means not a hard sample. | ||
""" | ||
if infer_result is None: | ||
logger.warning(f'infer result is invalid, value: {infer_result}!') | ||
return False | ||
elif len(infer_result) == 0: | ||
return False | ||
else: | ||
log_sum = 0.0 | ||
data_check_list = [class_probability for class_probability | ||
in infer_result | ||
if data_check(class_probability)] | ||
if len(data_check_list) == len(infer_result): | ||
for class_data in data_check_list: | ||
log_sum += class_data * math.log(class_data) | ||
confidence_score = 1 + 1.0 * log_sum / math.log( | ||
len(infer_result)) | ||
return confidence_score >= self.threshold_cross_entropy | ||
else: | ||
logger.warning("every value of infer_result should be in " | ||
f"[0,1], your data is {infer_result}") | ||
return False |
Empty file.
56 changes: 56 additions & 0 deletions
56
lib/neptune/hard_example_mining/object_detection/scores_filters.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import logging | ||
|
||
from neptune.hard_example_mining import BaseFilter | ||
from neptune.hard_example_mining.hard_example_helpers import data_check | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class IBTFilter(BaseFilter): | ||
"""Implement the hard samples discovery methods named IBT | ||
(image-box-thresholds). | ||
:param threshold_img: threshold_img to filter img, whose hard coefficient | ||
is less than threshold_img. | ||
:param threshold_box: threshold_box to calculate hard coefficient, formula | ||
is hard coefficient = number(prediction_boxes less than | ||
threshold_box)/number(prediction_boxes) | ||
""" | ||
|
||
def __init__(self, threshold_img=0.5, threshold_box=0.5): | ||
self.threshold_box = threshold_box | ||
self.threshold_img = threshold_img | ||
|
||
def hard_judge(self, infer_result=None): | ||
"""Judge the img is hard sample or not. | ||
:param infer_result: | ||
prediction boxes list, | ||
such as [bbox1, bbox2, bbox3,....], | ||
where bbox = [xmin, ymin, xmax, ymax, score, label] | ||
score should be in [0,1], who will be ignored if its value not | ||
in [0,1]. | ||
:return: `True` means a hard sample, `False` means not a hard sample. | ||
""" | ||
if infer_result is None: | ||
logger.warning(f'infer result is invalid, value: {infer_result}!') | ||
return False | ||
elif len(infer_result) == 0: | ||
return False | ||
else: | ||
data_check_list = [bbox[4] for bbox in infer_result | ||
if data_check(bbox[4])] | ||
if len(data_check_list) == len(infer_result): | ||
confidence_score_list = [ | ||
float(box_score) for box_score in data_check_list | ||
if float(box_score) <= self.threshold_box] | ||
if (len(confidence_score_list) / len(infer_result)) \ | ||
>= (1 - self.threshold_img): | ||
return True | ||
else: | ||
return False | ||
else: | ||
logger.warning( | ||
"every value of infer_result should be in [0,1], " | ||
f"your data is {infer_result}") | ||
return False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .joint_inference import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import json | ||
|
||
|
||
class ServiceInfo: | ||
def __init__(self): | ||
self.startTime = '' | ||
self.updateTime = '' | ||
self.inferenceNumber = 0 | ||
self.hardExampleNumber = 0 | ||
self.uploadCloudRatio = 0 | ||
|
||
@staticmethod | ||
def from_json(json_str): | ||
info = ServiceInfo() | ||
info.__dict__ = json.loads(json_str) | ||
return info | ||
|
||
def to_json(self): | ||
info = json.dumps(self.__dict__) | ||
return info |
Oops, something went wrong.