-
-
Notifications
You must be signed in to change notification settings - Fork 389
/
_metric.py
171 lines (130 loc) · 4.78 KB
/
_metric.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
class IMetric(ABC):
    """Base interface that every metric implements.

    A metric folds new data into internal state via :meth:`update`,
    turns that state into a result via :meth:`compute`, and clears it
    via :meth:`reset`.

    Args:
        compute_on_call: when ``True``, calling the metric returns the
            result of :meth:`compute` right after updating (used for
            per-batch logging); when ``False`` the call returns whatever
            :meth:`update` returned.
            default: ``True``
    """

    def __init__(self, compute_on_call: bool = True):
        """Remember whether calling the metric should also compute it."""
        self.compute_on_call = compute_on_call

    @abstractmethod
    def reset(self) -> None:
        """Restore the metric to its initial state.

        By default, this is called at the start of each loader
        (`on_loader_start` event).
        """

    @abstractmethod
    def update(self, *args, **kwargs) -> Any:
        """Fold newly passed data into the metric's state.

        By default, this is called at the end of each batch
        (`on_batch_end` event).

        Args:
            *args: positional inputs, defined by the concrete metric.
            **kwargs: keyword inputs, defined by the concrete metric.
        """

    @abstractmethod
    def compute(self) -> Any:
        """Produce the metric value from its accumulated state.

        By default, this is called at the end of each loader
        (`on_loader_end` event).

        Returns:
            Any: computed value, # noqa: DAR202
            preferably in key-value form
        """

    def __call__(self, *args, **kwargs) -> Any:
        """Update the state with the given inputs, then report a value.

        By default, this is called at the end of each batch
        (`on_batch_end` event).

        Args:
            *args: positional arguments forwarded to :meth:`update`.
            **kwargs: keyword arguments forwarded to :meth:`update`.

        Returns:
            Any: the result of :meth:`compute` when ``compute_on_call``
            is set, otherwise the value returned by :meth:`update`.
        """
        update_result = self.update(*args, **kwargs)
        # The flag is read only after the update, matching the original order.
        if not self.compute_on_call:
            return update_result
        return self.compute()
class ICallbackBatchMetric(IMetric):
    """Interface for all batch-based Metrics.

    Args:
        compute_on_call: Computes and returns metric value during metric call.
            Used for per-batch logging.
            default: ``True``
        prefix: metrics prefix; ``None`` (or any falsy value) is stored as ``""``
        suffix: metrics suffix; ``None`` (or any falsy value) is stored as ``""``
    """

    def __init__(
        self,
        compute_on_call: bool = True,
        prefix: Optional[str] = None,
        suffix: Optional[str] = None,
    ):
        """Init.

        Args:
            compute_on_call: forwarded to ``IMetric.__init__``.
            prefix: metrics prefix, ``None`` means no prefix.
            suffix: metrics suffix, ``None`` means no suffix.
        """
        super().__init__(compute_on_call=compute_on_call)
        # Normalize None / falsy inputs to an empty string so the attributes
        # are always ``str``.
        self.prefix = prefix or ""
        self.suffix = suffix or ""

    @abstractmethod
    def update_key_value(self, *args, **kwargs) -> Dict[str, float]:
        """Updates the metric with new input and returns the result.

        As a batch-based metric, this is expected to be called at the end
        of each batch (`on_batch_end` event).

        Args:
            *args: some args
            **kwargs: some kwargs

        Returns:
            Dict: computed value in key-value format. # noqa: DAR202
        """
        pass

    @abstractmethod
    def compute_key_value(self) -> Dict[str, float]:
        """Computes the metric based on its accumulated state.

        By default, this is called at the end of each loader
        (`on_loader_end` event).

        Returns:
            Dict: computed value in key-value format. # noqa: DAR202
        """
        pass
class ICallbackLoaderMetric(IMetric):
    """Interface for all loader-based Metrics.

    Args:
        compute_on_call: Computes and returns metric value during metric call.
            Used for per-batch logging.
            default: ``True``
        prefix: metrics prefix; ``None`` (or any falsy value) is stored as ``""``
        suffix: metrics suffix; ``None`` (or any falsy value) is stored as ``""``
    """

    def __init__(
        self,
        compute_on_call: bool = True,
        prefix: Optional[str] = None,
        suffix: Optional[str] = None,
    ):
        """Init.

        Args:
            compute_on_call: forwarded to ``IMetric.__init__``.
            prefix: metrics prefix, ``None`` means no prefix.
            suffix: metrics suffix, ``None`` means no suffix.
        """
        super().__init__(compute_on_call=compute_on_call)
        # Normalize None / falsy inputs to an empty string so the attributes
        # are always ``str``.
        self.prefix = prefix or ""
        self.suffix = suffix or ""

    @abstractmethod
    def reset(self, num_batches: int, num_samples: int) -> None:
        """Resets the metric to its initial state.

        By default, this is called at the start of each loader
        (`on_loader_start` event).

        Note: unlike ``IMetric.reset()``, this signature takes two extra
        required arguments, so callers must supply the expected loader size.

        Args:
            num_batches: number of expected batches.
            num_samples: number of expected samples.
        """
        pass

    @abstractmethod
    def update(self, *args, **kwargs) -> None:
        """Updates the metrics state using the passed data.

        By default, this is called at the end of each batch
        (`on_batch_end` event). Unlike ``IMetric.update``, the return
        value is declared as ``None``.

        Args:
            *args: some args :)
            **kwargs: some kwargs ;)
        """
        pass

    @abstractmethod
    def compute_key_value(self) -> Dict[str, float]:
        """Computes the metric based on its accumulated state.

        By default, this is called at the end of each loader
        (`on_loader_end` event).

        Returns:
            Dict: computed value in key-value format. # noqa: DAR202
        """
        # @TODO: could be refactored - we need a custom exception here;
        # this method is needed only for callback metric logging
        pass
# Names exported by ``from <module> import *``: the three metric interfaces.
__all__ = [
    "IMetric",
    "ICallbackBatchMetric",
    "ICallbackLoaderMetric",
]