Skip to content

Commit

Permalink
more type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
kmike committed Jul 21, 2017
1 parent 94c45c5 commit 8ec4f13
Show file tree
Hide file tree
Showing 18 changed files with 199 additions and 58 deletions.
1 change: 1 addition & 0 deletions eli5/_decision_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def get_decision_path_explanation(estimator, doc, vec, vectorized,
target_names, targets, top_targets,
is_regression, is_multiclass, proba,
get_score_weights):
# type: (...) -> Explanation

display_names = get_target_display_names(
original_display_names, target_names, targets, top_targets, proba)
Expand Down
2 changes: 2 additions & 0 deletions eli5/_feature_importances.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

def get_feature_importances_filtered(coef, feature_names, flt_indices, top,
coef_std=None):
# type: (...) -> FeatureImportances
if flt_indices is not None:
coef = coef[flt_indices]
if coef_std is not None:
Expand All @@ -27,6 +28,7 @@ def get_feature_importance_explanation(estimator, vec, coef, feature_names,
estimator_feature_names=None,
num_features=None,
coef_std=None):
# type: (...) -> Explanation
feature_names, flt_indices = get_feature_names_filtered(
estimator, vec,
feature_names=feature_names,
Expand Down
37 changes: 30 additions & 7 deletions eli5/_feature_names.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,28 @@
import re
import six
from typing import Any, Iterable, Tuple, Sized, List
from typing import (
Any, Iterable, Iterator, Tuple, Sized, List, Optional, Dict,
Union, Callable, Pattern
)

import numpy as np # type: ignore
import scipy.sparse as sp # type: ignore


class FeatureNames(Sized):
class FeatureNames(Sized, Iterable):
"""
A list-like object with feature names. It allows
feature names for unknown features to be generated using
a provided template, and to avoid making copies of large objects
in get_feature_names.
"""
def __init__(self, feature_names=None, bias_name=None,
unkn_template=None, n_features=None):
def __init__(self,
feature_names=None,
bias_name=None, # type: str
unkn_template=None, # type: str
n_features=None, # type: int
):
# type: (...) -> None
if not (feature_names is not None or
(unkn_template is not None and n_features)):
raise ValueError(
Expand All @@ -31,16 +39,22 @@ def __init__(self, feature_names=None, bias_name=None,
'unkn_template should be set for sparse features')
self.feature_names = feature_names
self.unkn_template = unkn_template
self.n_features = n_features or len(feature_names)
self.n_features = n_features or len(feature_names) # type: int
self.bias_name = bias_name

def __repr__(self):
# type: () -> str
return '<FeatureNames: {} features {} bias>'.format(
self.n_features, 'with' if self.has_bias else 'without')

def __len__(self):
# type: () -> int
return self.n_features + int(self.has_bias)

def __iter__(self):
# type: () -> Iterator[str]
return (self[i] for i in range(len(self)))

def __getitem__(self, idx):
if isinstance(idx, slice):
return self._slice(idx)
Expand Down Expand Up @@ -71,15 +85,18 @@ def _slice(self, aslice):

@property
def has_bias(self):
# type: () -> bool
return self.bias_name is not None

@property
def bias_idx(self):
# type: () -> Optional[int]
if self.has_bias:
return self.n_features
return None

def filtered(self, feature_filter, x=None):
# type: (Any, Any) -> Tuple[FeatureNames, List[int]]
# type: (Callable, Any) -> Tuple[FeatureNames, List[int]]
""" Return feature names filtered by ``feature_filter``,
and indices of filtered elements.
"""
Expand Down Expand Up @@ -120,7 +137,12 @@ def filtered(self, feature_filter, x=None):
),
indices)

def handle_filter(self, feature_filter, feature_re, x=None):
def handle_filter(self,
feature_filter,
feature_re, # type: Pattern[str]
x=None, # type: Any
):
# type: (...) -> Tuple[FeatureNames, Union[List[int], None]]
if feature_re is not None and feature_filter:
raise ValueError('pass either feature_filter or feature_re')
if feature_re is not None:
Expand Down Expand Up @@ -156,6 +178,7 @@ def add_feature(self, feature):


def _all_feature_names(name):
# type: (Union[str, bytes, List[Dict]]) -> List[str]
""" All feature names for a feature: usually just the feature itself,
but can be several features for unhashed features with collisions.
"""
Expand Down
4 changes: 3 additions & 1 deletion eli5/_graphviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@


def is_supported():
# type: () -> bool
try:
graphviz.Graph().pipe('svg')
return True
Expand All @@ -11,8 +12,9 @@ def is_supported():


def dot2svg(dot):
    # type: (str) -> str
    """ Render Graphviz data to SVG """
    rendered = graphviz.Source(dot).pipe(format='svg')
    svg = rendered.decode('utf8')  # type: str
    # strip doctype and xml declaration: keep everything from the
    # first '<svg' onwards
    start = svg.index('<svg')
    return svg[start:]
4 changes: 2 additions & 2 deletions eli5/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
from typing import Any, List, Tuple, Union
from typing import Any, List, Tuple, Union, Optional

from .base_utils import attrs
from .formatters.features import FormattedFeatureName
Expand All @@ -23,7 +23,7 @@ def __init__(self,
targets=None, # type: List[TargetExplanation]
feature_importances=None, # type: FeatureImportances
decision_tree=None, # type: TreeInfo
highlight_spaces=None,
highlight_spaces=None, # type: Optional[bool]
transition_features=None, # type: TransitionFeatureWeights
):
# type: (...) -> None
Expand Down
63 changes: 49 additions & 14 deletions eli5/formatters/html.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from itertools import groupby
from typing import List
from typing import List, Optional, Tuple

import numpy as np # type: ignore
from jinja2 import Environment, PackageLoader # type: ignore

from eli5 import _graphviz
from eli5.base import TargetExplanation
from eli5.base import (Explanation, TargetExplanation, FeatureWeights,
FeatureWeight)
from eli5.utils import max_or_0
from .utils import (
format_signed, format_value, format_weight, has_any_values_for_weights,
Expand All @@ -33,10 +34,16 @@
))


def format_as_html(explanation, include_styles=True, force_weights=True,
show=fields.ALL, preserve_density=None,
highlight_spaces=None, horizontal_layout=True,
show_feature_values=False):
def format_as_html(explanation, # type: Explanation
include_styles=True, # type: bool
force_weights=True, # type: bool
show=fields.ALL,
preserve_density=None, # type: Optional[bool]
highlight_spaces=None, # type: Optional[bool]
horizontal_layout=True, # type: bool
show_feature_values=False # type: bool
):
# type: (...) -> str
""" Format explanation as html.
Most styles are inline, but some are included separately in <style> tag,
you can omit them by passing ``include_styles=False`` and call
Expand Down Expand Up @@ -125,14 +132,18 @@ def format_as_html(explanation, include_styles=True, force_weights=True,


def format_html_styles():
    # type: () -> str
    """ Render only the ``styles.html`` template; use together with
    ``format_as_html(explanation, include_styles=False)``.
    """
    styles_template = template_env.get_template('styles.html')
    return styles_template.render()


def render_targets_weighted_spans(targets, preserve_density):
# type: (List[TargetExplanation], bool) -> List[str]
def render_targets_weighted_spans(
targets, # type: List[TargetExplanation]
preserve_density, # type: Optional[bool]
):
# type: (...) -> List[str]
""" Return a list of rendered weighted spans for targets.
Function must accept a list in order to select consistent weight
ranges across all targets.
Expand Down Expand Up @@ -163,7 +174,11 @@ def render_weighted_spans(pws):
key=lambda x: x[1]))


def _colorize(token, weight, weight_range):
def _colorize(token, # type: str
weight, # type: float
weight_range, # type: float
):
# type: (...) -> str
""" Return token wrapped in a span with some styles
(calculated from weight and weight_range) applied.
"""
Expand Down Expand Up @@ -191,17 +206,22 @@ def _colorize(token, weight, weight_range):


def _weight_opacity(weight, weight_range):
# type: (float, float) -> str
""" Return opacity value for given weight as a string.
"""
min_opacity = 0.8
if np.isclose(weight, 0) and np.isclose(weight_range, 0):
rel_weight = 0
rel_weight = 0.0
else:
rel_weight = abs(weight) / weight_range
return '{:.2f}'.format(min_opacity + (1 - min_opacity) * rel_weight)


_HSL_COLOR = Tuple[float, float, float]


def weight_color_hsl(weight, weight_range, min_lightness=0.8):
# type: (float, float, float) -> _HSL_COLOR
""" Return HSL color components for given weight,
where the max absolute weight is given by weight_range.
"""
Expand All @@ -213,32 +233,41 @@ def weight_color_hsl(weight, weight_range, min_lightness=0.8):


def format_hsl(hsl_color):
    # type: (_HSL_COLOR) -> str
    """ Format a (hue, saturation, lightness) triple as a css color
    string; saturation and lightness are rendered as percentages.
    """
    hue, saturation, lightness = hsl_color
    css_template = 'hsl({}, {:.2%}, {:.2%})'
    return css_template.format(hue, saturation, lightness)


def _hue(weight):
# type: (float) -> float
return 120 if weight > 0 else 0


def get_weight_range(weights):
    # type: (FeatureWeights) -> float
    """ Max absolute feature weight across ``weights.pos`` and
    ``weights.neg``, as computed by ``max_or_0``.
    """
    abs_weights = (
        abs(fw.weight)
        for group in [weights.pos, weights.neg]
        # a group that is None (or empty) contributes nothing
        for fw in group or []
    )
    return max_or_0(abs_weights)


def remaining_weight_color_hsl(ws, weight_range, pos_neg):
def remaining_weight_color_hsl(
ws, # type: List[FeatureWeight]
weight_range, # type: float
pos_neg, # type: str
):
# type: (...) -> _HSL_COLOR
""" Color for "remaining" row.
Handles a number of edge cases: if there are no weights in ws or weight_range
is zero, assume the worst (most intensive positive or negative color).
"""
sign = {'pos': 1, 'neg': -1}[pos_neg]
sign = {'pos': 1.0, 'neg': -1.0}[pos_neg]
if not ws and not weight_range:
weight = sign
weight_range = 1
weight_range = 1.0
elif not ws:
weight = sign * weight_range
else:
Expand All @@ -247,6 +276,7 @@ def remaining_weight_color_hsl(ws, weight_range, pos_neg):


def _format_unhashed_feature(feature, weight, hl_spaces):
# type: (...) -> str
""" Format unhashed feature: show first (most probable) candidate,
display other candidates in title attribute.
"""
Expand All @@ -263,6 +293,7 @@ def _format_unhashed_feature(feature, weight, hl_spaces):


def _format_feature(feature, weight, hl_spaces):
# type: (...) -> str
""" Format any feature.
"""
if isinstance(feature, FormattedFeatureName):
Expand All @@ -275,11 +306,13 @@ def _format_feature(feature, weight, hl_spaces):


def _format_single_feature(feature, weight, hl_spaces):
# type: (str, float, bool) -> str
feature = html_escape(feature)
if not hl_spaces:
return feature

def replacer(n_spaces, side):
# type: (int, str) -> str
m = '0.1em'
margins = {'left': (m, 0), 'right': (0, m), 'center': (m, m)}[side]
style = '; '.join([
Expand All @@ -296,13 +329,15 @@ def replacer(n_spaces, side):


def _format_decision_tree(treedict):
    # type: (...) -> str
    """ Format a decision tree: as SVG when the tree has graphviz data
    and graphviz rendering is supported, as text otherwise.
    """
    # is_supported() is only consulted when graphviz data is present.
    can_render_svg = treedict.graphviz and _graphviz.is_supported()
    if not can_render_svg:
        return tree2text(treedict)
    return _graphviz.dot2svg(treedict.graphviz)


def html_escape(text):
# type: (str) -> str
try:
from html import escape
except ImportError:
Expand Down

0 comments on commit 8ec4f13

Please sign in to comment.