-
Notifications
You must be signed in to change notification settings - Fork 15
/
histogram_collection.py
200 lines (163 loc) · 6.79 KB
/
histogram_collection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
from __future__ import annotations
from typing import TYPE_CHECKING, Container, Mapping, cast
import numpy as np
from physt._construction import calculate_1d_bins
from physt.binnings import BinningBase, as_binning
from physt.histogram1d import Histogram1D, ObjectWithBinning
if TYPE_CHECKING:
from typing import Any, Dict, Optional, Tuple
import physt
from physt.binnings import BinningLike
from physt.typing_aliases import ArrayLike
class HistogramCollection(Container[Histogram1D], ObjectWithBinning):
    """Experimental collection of histograms.

    It contains (potentially name-addressable) 1-D histograms
    with a shared binning.

    Attributes
    ----------
    histograms: The contained histograms, in insertion order.
    name: Optional name of the collection.
    title: Optional title; falls back to `name` when not given.
    """

    def __init__(
        self,
        *histograms: Histogram1D,
        binning: Optional[BinningLike] = None,
        title: Optional[str] = None,
        name: Optional[str] = None,
    ):
        """Create the collection from histograms or from a binning.

        Parameters
        ----------
        histograms: Histograms to include; all must have equal binning.
        binning: Binning of an initially empty collection. Must not be
            given together with histograms.
        title: Title of the collection (defaults to `name`).
        name: Name of the collection.

        Raises
        ------
        ValueError
            If both histograms and binning are given, if the histograms
            do not share one binning, or if neither is given.
        """
        self.histograms = list(histograms)
        if histograms:
            # Test against None explicitly (as the other branch does): a truthiness
            # test is ambiguous for array-like binnings and lets falsy values slip through.
            if binning is not None:
                raise ValueError(
                    "When creating collection from histograms, binning is deduced from them."
                )
            self._binning = histograms[0].binning
            if not all(h.binning == self._binning for h in histograms):
                raise ValueError("All histograms should share the same binning.")
        else:
            if binning is None:
                raise ValueError("Either binning or at least one histogram must be provided.")
            self._binning = as_binning(binning)
        self.name = name
        self.title = title or self.name

    def __contains__(self, item) -> bool:
        """Check whether the collection contains `item`.

        A `Histogram1D` is looked up by membership in the collection;
        anything else is treated as a key for `__getitem__` (a histogram
        name or a list index). A failed lookup yields False instead of
        propagating the lookup error.
        """
        if isinstance(item, Histogram1D):
            return item in self.histograms
        try:
            self[item]
        except (KeyError, IndexError, TypeError):
            # KeyError: unknown name; IndexError: position out of range;
            # TypeError: object that is neither a name nor a valid list index.
            return False
        return True

    def __iter__(self):
        """Iterate over the contained histograms in insertion order."""
        return iter(self.histograms)

    def __len__(self):
        """Return the number of contained histograms."""
        return len(self.histograms)

    def copy(self) -> "HistogramCollection":
        """Return a new collection with copied histograms and a copied binning.

        Every copied histogram is re-pointed to the same single copy
        of the binning.
        """
        # TODO: The binnings are probably not consistent in the copies
        binning_copy = self.binning.copy()
        histograms = [h.copy() for h in self.histograms]
        for histogram in histograms:
            histogram._binning = binning_copy
        if not histograms:
            # Without histograms, the constructor cannot deduce the binning
            # (and would raise), so it has to be passed explicitly.
            return HistogramCollection(binning=binning_copy, title=self.title, name=self.name)
        return HistogramCollection(*histograms, title=self.title, name=self.name)

    @property
    def binning(self) -> BinningBase:
        """The binning shared by all histograms in the collection."""
        return self._binning

    @property
    def axis_name(self) -> str:
        """Axis name of the first histogram, or "axis0" for an empty collection."""
        return self.histograms[0].axis_name if self.histograms else "axis0"

    @property
    def axis_names(self) -> Tuple[str]:
        """One-element tuple containing `axis_name`."""
        return (self.axis_name,)

    def add(self, histogram: Histogram1D) -> None:
        """Add a histogram to the collection.

        Raises
        ------
        ValueError
            If the binning of the histogram differs from the collection's.
        """
        if self.binning and not self.binning == histogram.binning:
            raise ValueError("Cannot add histogram with different binning.")
        self.histograms.append(histogram)

    def create(
        self, name: str, values, *, weights=None, dropna: bool = True, **kwargs
    ) -> Histogram1D:
        """Create a histogram with the collection's binning, fill it and append it.

        Parameters
        ----------
        name: Name of the new histogram.
        values: Values passed to the histogram's `fill_n`.
        weights: Optional weights passed to `fill_n`.
        dropna: Passed to `fill_n`.
        kwargs: Extra arguments for the `Histogram1D` constructor; they
            override the default `axis_name` taken from the collection.

        Returns
        -------
        The newly created histogram.
        """
        # TODO: Rename!
        init_kwargs: Dict[str, Any] = {"axis_name": self.axis_name}
        init_kwargs.update(kwargs)
        histogram = Histogram1D(binning=self.binning, name=name, **init_kwargs)
        histogram.fill_n(values, weights=weights, dropna=dropna)
        self.histograms.append(histogram)
        return histogram

    def __getitem__(self, item) -> Histogram1D:
        """Return a histogram by name (str) or by list index.

        For a string, the first histogram with that name is returned.

        Raises
        ------
        KeyError
            If `item` is a string and no histogram has that name.
        """
        if isinstance(item, str):
            for histogram in self.histograms:
                if histogram.name == item:
                    return histogram
            raise KeyError(f"Collection does not contain histogram named '{item}'.")
        return self.histograms[item]

    def __eq__(self, other) -> bool:
        """Compare with another collection.

        Equal only if `other` is exactly a `HistogramCollection` of the
        same length whose histograms are pairwise equal (in order).
        """
        return (
            (type(other) == HistogramCollection)
            and (len(other) == len(self))
            and all((h1 == h2) for h1, h2 in zip(self.histograms, other.histograms))
        )

    def normalize_bins(self, inplace: bool = False) -> "HistogramCollection":
        """Normalize each bin in the collection so that the sum is 1.0 for each bin.

        Note: If a bin is zero in all histograms, the division is 0/0,
        so the result for that bin is NaN (assuming floating-point
        NumPy arrays), not a finite number.

        Parameters
        ----------
        inplace: Modify this collection instead of a copy.

        Returns
        -------
        The normalized collection (`self` if `inplace`).
        """
        col = self if inplace else self.copy()
        # Per-bin totals are taken before any histogram is modified.
        sums = self.sum().frequencies
        for h in col.histograms:
            h.set_dtype(float)
            h._frequencies /= sums
            h._errors2 /= sums**2  # TODO: Does this make sense?
        return col

    def normalize_all(self, inplace: bool = False) -> "HistogramCollection":
        """Normalize all histograms so that total content of each of them is equal to 1.0.

        Parameters
        ----------
        inplace: Modify this collection instead of a copy.

        Returns
        -------
        The normalized collection (`self` if `inplace`).
        """
        col = self if inplace else self.copy()
        for h in col.histograms:
            h.normalize(inplace=True)
        return col

    def sum(self) -> Histogram1D:
        """Return the sum of all contained histograms.

        For an empty collection, a zero-filled integer histogram with
        the collection's binning is returned.
        """
        if not self.histograms:
            return Histogram1D(
                data=np.zeros((self.binning.bin_count)), dtype=np.int64, binning=self.binning
            )
        return cast(Histogram1D, sum(self.histograms))

    @property
    def plot(self) -> "physt.plotting.PlottingProxy":
        """Proxy to plotting.

        This attribute is a special proxy to plotting. In the most
        simple cases, it can be used as a method. For more sophisticated
        use, see the documentation for physt.plotting package.
        """
        # Imported lazily, only when plotting is actually requested.
        from physt.plotting import PlottingProxy

        return PlottingProxy(self)

    @classmethod
    def multi_h1(
        cls, a_dict: Mapping[str, ArrayLike], bins=None, **kwargs
    ) -> "HistogramCollection":
        """Create a collection from multiple datasets.

        The binning is calculated from all values concatenated, and
        one histogram per dictionary item is created, named by its key.

        Parameters
        ----------
        a_dict: Mapping of histogram name to its values.
        bins: Passed to `calculate_1d_bins`.
        kwargs: `title` and `name` are used for the collection; the rest
            is passed to `calculate_1d_bins`.
        """
        # TODO: Change into a function in facade
        # Pop the collection-level options first so that they are not
        # forwarded to the bin calculation.
        title = kwargs.pop("title", None)
        name = kwargs.pop("name", None)
        mega_values: np.ndarray = np.concatenate(list(a_dict.values()))  # type: ignore
        binning = calculate_1d_bins(mega_values, bins, **kwargs)
        collection = HistogramCollection(binning=binning, title=title, name=name)
        for key, value in a_dict.items():
            collection.create(key, value)
        return collection

    @classmethod
    def from_dict(cls, a_dict: Dict[str, Any]) -> "HistogramCollection":
        """Create a collection from its dictionary representation.

        Only the "histograms" item is read; each of its elements is
        converted using `physt.io.create_from_dict`.
        NOTE(review): an empty "histograms" list makes the constructor
        raise ValueError, because no binning is available - confirm
        whether empty collections should be (de)serializable.
        """
        from physt.io import create_from_dict

        histograms = (
            cast(Histogram1D, create_from_dict(item, "HistogramCollection", check_version=False))
            for item in a_dict["histograms"]
        )
        return HistogramCollection(*histograms)

    def to_dict(self) -> Dict[str, Any]:
        """Return the dictionary representation of the collection.

        It contains the type tag and the `to_dict()` output of each
        histogram; the collection's name and title are not included.
        """
        return {
            "histogram_type": "histogram_collection",
            "histograms": [h.to_dict() for h in self.histograms],
        }

    def to_json(self, path: Optional[str] = None, **kwargs) -> str:
        """Convert to JSON representation.

        Parameters
        ----------
        path: Where to write the JSON.
        kwargs: Passed to `save_json`.

        Returns
        -------
        The JSON representation.
        """
        from .io import save_json

        return save_json(self, path, **kwargs)