-
Notifications
You must be signed in to change notification settings - Fork 2
/
__init__.py
120 lines (109 loc) · 4.16 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# Copyright (c) 2017, Apple Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
"""
List of all external dependencies for this package. Each one is imported
as an optional include.
"""
from distutils.version import StrictVersion as _StrictVersion
import logging as _logging
import re as _re
def __get_version(version):
    """Parse the leading ``MAJOR.MINOR.PATCH`` part of a version string.

    Trailing qualifiers are ignored, so ``'1.6.1'``, ``'1.6.1rc'`` and
    ``'1.6.1.dev'`` all parse as ``1.6.1``.

    :param version: A version string, or any object whose ``str()`` is one
        (e.g. a package's ``__version__``).
    :returns: A ``StrictVersion`` built from the ``MAJOR.MINOR.PATCH`` prefix.
    :raises ValueError: If ``str(version)`` does not start with
        ``MAJOR.MINOR.PATCH``.
    """
    # Raw string: '\d' and '\.' in a plain literal are invalid escape sequences.
    # re.match anchors at the start, like the previous '^' + re.search.
    match = _re.match(r'\d+\.\d+\.\d+', str(version))
    if match is None:
        # Previously this crashed with AttributeError on None.group(0).
        raise ValueError('Unrecognized version string: %r' % (version,))
    return _StrictVersion(match.group(0))
# ---------------------------------------------------------------------------------------
# scikit-learn: HAS_SKLEARN is cleared when sklearn is missing or older than
# SKLEARN_MIN_VERSION.
HAS_SKLEARN = True
SKLEARN_MIN_VERSION = '0.15'

def __get_sklearn_version(version):
    """Parse the leading ``MAJOR.MINOR`` part of a scikit-learn version string.

    Trailing qualifiers are ignored, so ``'0.15b'`` and ``'0.16bf'`` parse as
    ``0.15`` and ``0.16`` respectively.

    :param version: A version string, or any object whose ``str()`` is one
        (e.g. ``sklearn.__version__``).
    :returns: A ``StrictVersion`` built from the ``MAJOR.MINOR`` prefix.
    :raises ValueError: If ``str(version)`` does not start with ``MAJOR.MINOR``.
    """
    # Raw string: '\d' and '\.' in a plain literal are invalid escape sequences.
    # re.match anchors at the start, like the previous '^' + re.search.
    match = _re.match(r'\d+\.\d+', str(version))
    if match is None:
        # Previously this crashed with AttributeError on None.group(0).
        raise ValueError('Unrecognized scikit-learn version string: %r' % (version,))
    return _StrictVersion(match.group(0))
try:
    import sklearn
    # Disable the converter (with a warning) when sklearn is older than supported.
    if __get_sklearn_version(sklearn.__version__) < _StrictVersion(SKLEARN_MIN_VERSION):
        HAS_SKLEARN = False
        # logging.warn is deprecated (removed in Python 3.13); use warning().
        _logging.warning('scikit-learn version %s is not supported. Minimum required version: %s. '
                         'Disabling scikit-learn conversion API.',
                         sklearn.__version__, SKLEARN_MIN_VERSION)
except Exception:
    # sklearn is missing, or probing it failed (e.g. an unparsable version):
    # treat it as unavailable. Unlike a bare `except:`, this does not swallow
    # KeyboardInterrupt / SystemExit.
    HAS_SKLEARN = False
# ---------------------------------------------------------------------------------------
# libsvm: its Python bindings are expected to be importable as the top-level
# module `svm`. HAS_LIBSVM records whether that import succeeded.
HAS_LIBSVM = True
try:
    import svm
except Exception:
    # Any failure while loading the bindings (not only ImportError) disables
    # libsvm support. Unlike a bare `except:`, KeyboardInterrupt / SystemExit
    # still propagate.
    HAS_LIBSVM = False
# ---------------------------------------------------------------------------------------
# XGBoost: HAS_XGBOOST records whether the package could be imported.
HAS_XGBOOST = True
try:
    import xgboost
except Exception:
    # Any failure while importing xgboost (not only ImportError) disables it.
    # Unlike a bare `except:`, KeyboardInterrupt / SystemExit still propagate.
    HAS_XGBOOST = False
# ---------------------------------------------------------------------------------------
# Keras with the TensorFlow backend.
#   HAS_KERAS_TF  -- True when a supported Keras 1.x (TensorFlow backend) is usable.
#   HAS_KERAS2_TF -- True when a supported Keras 2.x (TensorFlow backend) is usable.
HAS_KERAS_TF = True
HAS_KERAS2_TF = True
KERAS_MIN_VERSION = '1.2.2'
KERAS_MAX_VERSION = '2.0.4'  # newest Keras known to be fully compatible; newer only warns
TF_MIN_VERSION = '1.0.0'
TF_MAX_VERSION = '1.1.1'     # newest TensorFlow known to be fully compatible; newer only warns
try:
    import sys
    # The `StringIO` module exists only on Python 2. A plain `import StringIO`
    # raised ImportError on Python 3, which the outer handler swallowed, so
    # Keras support was always disabled there. Fall back to io.StringIO.
    try:
        from StringIO import StringIO as _StringIO
    except ImportError:
        from io import StringIO as _StringIO

    # Prevent keras from printing things that are not errors to standard error:
    # capture stderr while keras is imported.
    _stderr = sys.stderr
    _captured_stderr = _StringIO()
    try:
        sys.stderr = _captured_stderr
        import keras
    except BaseException:
        # Print out any actual error message and re-raise.
        sys.stderr = _stderr
        sys.stderr.write(_captured_stderr.getvalue())
        raise
    finally:
        sys.stderr = _stderr

    import tensorflow
    tf_ver = __get_version(tensorflow.__version__)
    k_ver = __get_version(keras.__version__)

    # Choose the Keras API generation FIRST. Every check below only turns
    # support off, so none of them can be overwritten. (When this split ran
    # last, a Keras 1 older than KERAS_MIN_VERSION ended up re-enabled.)
    if k_ver >= _StrictVersion('2.0.0'):
        # Using Keras 2 rather than 1
        HAS_KERAS_TF = False
        HAS_KERAS2_TF = True
    else:
        # Using Keras 1 rather than 2
        HAS_KERAS_TF = True
        HAS_KERAS2_TF = False

    # Keras too old: disable conversion.
    if k_ver < _StrictVersion(KERAS_MIN_VERSION):
        HAS_KERAS_TF = False
        HAS_KERAS2_TF = False
        _logging.warning('Keras version %s is not supported. Minimum required version: %s. '
                         'Keras conversion will be disabled.',
                         keras.__version__, KERAS_MIN_VERSION)

    # Keras newer than the last known-compatible version: warn only.
    # (Such a version is >= 2.0.0, so HAS_KERAS_TF is already False above.)
    if k_ver > _StrictVersion(KERAS_MAX_VERSION):
        _logging.warning('Keras version %s detected. Last version known to be fully compatible of Keras is %s .',
                         keras.__version__, KERAS_MAX_VERSION)

    # TensorFlow too old: disable conversion.
    if tf_ver < _StrictVersion(TF_MIN_VERSION):
        HAS_KERAS_TF = False
        HAS_KERAS2_TF = False
        _logging.warning('TensorFlow version %s is not supported. Minimum required version: %s. '
                         'Keras conversion will be disabled.',
                         tensorflow.__version__, TF_MIN_VERSION)

    # TensorFlow newer than the last known-compatible version: warn only.
    if tf_ver > _StrictVersion(TF_MAX_VERSION):
        _logging.warning('TensorFlow version %s detected. Last version known to be fully compatible is %s .',
                         tensorflow.__version__, TF_MAX_VERSION)

    # Only the TensorFlow backend is supported.
    if keras.backend.backend() != 'tensorflow':
        HAS_KERAS_TF = False
        HAS_KERAS2_TF = False
        _logging.warning('Unsupported Keras backend (only Tensorflow is currently supported). '
                         'Keras conversion will be disabled.')
except Exception:
    # Keras/TensorFlow missing, an unparsable version, or any other failure
    # while probing them: treat Keras conversion as unavailable. Unlike a bare
    # `except:`, this does not swallow KeyboardInterrupt / SystemExit.
    HAS_KERAS_TF = False
    HAS_KERAS2_TF = False