Skip to content

Commit

Permalink
Merge pull request #11 from lucabaldini/line_fitting
Browse files Browse the repository at this point in the history
Line fitting
  • Loading branch information
lucabaldini committed Oct 11, 2023
2 parents 517906d + 6715c21 commit 729c0c1
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 146 deletions.
8 changes: 7 additions & 1 deletion docs/release.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,14 @@ Release notes
=============


* Merging https://github.com/lucabaldini/hexsample/pull/10
* uncertainties added as a requirement.
* PlotCard class completely refactored.
* Updating the hxview script.


*hexsample (0.1.0) - Tue, 10 Oct 2023 10:31:12 +0200*

* Merging https://github.com/lucabaldini/hexsample/pull/10
* Initial setup of the repository.
* Simple versioning system in plac
* Simple versioning system in plac
7 changes: 7 additions & 0 deletions hexsample/bin/hxview.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
import numpy as np

from hexsample.app import ArgumentParser
from hexsample.fitting import fit_histogram
from hexsample.hist import Histogram1d
from hexsample.io import ReconInputFile
from hexsample.modeling import Gaussian
from hexsample.plot import plt


Expand All @@ -50,6 +52,11 @@ def hxview(**kwargs):
binning = np.linspace(rec_energy.min(), rec_energy.max(), 100)
h_rec = Histogram1d(binning).fill(rec_energy)
h_rec.plot()
model = Gaussian() + Gaussian()
#model = fit_histogram(GaussianLineForestCuK(), h_rec)
fit_histogram(model, h_rec, p0=(1., 8000., 150., 1., 8900., 150.))
model.plot()
model.stat_box()
#h_mc = Histogram1d(binning).fill(mc_energy)
#h_mc.plot()
plt.figure('Cluster size')
Expand Down
113 changes: 10 additions & 103 deletions hexsample/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,98 +19,11 @@

import numpy as np

from hexsample.plot import plt
from hexsample.plot import plt, PlotCard

# pylint: disable=invalid-name


class StatBox:

"""Class describing a text box, to be used for the fit stat boxes.
Parameters
----------
position : str of tuple
It can either be a two-element tuple (in which case the argument is
interpreted as a position in absolute coordinates, with the reference
corner determined by the alignment flags), or a string in the
set ['upper left', 'upper right', 'lower left', 'lower rigth'].
If position is a string, the alignment flags are ignored.
halign : str
The horizontal alignment ('left' | 'center' | 'right')
valign : str
The vertical alignment ('top' | 'center' | 'bottom')
"""

HORIZONTAL_PADDING = 0.025
VERTICAL_PADDING = 0.035
_left, _right = HORIZONTAL_PADDING, 1 - HORIZONTAL_PADDING
_bottom, _top = VERTICAL_PADDING, 1 - VERTICAL_PADDING
POSITION_DICT = {
'upper left': (_left, _top, 'left', 'top'),
'upper right': (_right, _top, 'right', 'top'),
'lower left': (_left, _bottom, 'left', 'bottom'),
'lower right': (_right, _bottom, 'right', 'bottom')
}
DEFAULT_BBOX = dict(boxstyle='round', facecolor='white', alpha=0.75)

def __init__(self, position : str = 'upper left', halign : str = 'left',
valign : str = 'top') -> None:
"""Constructor.
"""
self.set_position(position, halign, valign)
self.text = ''

def set_position(self, position, halign='left', valign='top'):
"""Set the position of the bounding box.
"""
if isinstance(position, str):
self.x0, self.y0, self.halign,\
self.valign = self.POSITION_DICT[position]
else:
self.x0, self.y0 = position
self.halign, self.valign = halign, valign

def add_entry(self, label, value=None, error=None):
"""Add an entry to the stat box.
"""
if value is None and error is None:
self.text += '%s\n' % label
elif value is not None and error is None:
try:
self.text += '%s: %g\n' % (label, value)
except TypeError:
self.text += '%s: %s\n' % (label, value)
elif value is not None and error is not None:
if error > 0:
self.text += '%s: %g $\\pm$ %g\n' % (label, value, error)
else:
self.text += '%s: %g (frozen)\n' % (label, value)

def plot(self, **kwargs):
"""Plot the stat box.
Parameters
----------
**kwargs : dict
The options to be passed to `plt.text()`
"""
def set_kwargs_default(key, value):
"""
"""
if key not in kwargs:
kwargs[key] = value

set_kwargs_default('horizontalalignment', self.halign)
set_kwargs_default('verticalalignment', self.valign)
set_kwargs_default('bbox', self.DEFAULT_BBOX)
set_kwargs_default('transform', plt.gca().transAxes)
plt.text(self.x0, self.y0, self.text.strip('\n'), **kwargs)



class FitModelBase:

"""Base class for a fittable model.
Expand Down Expand Up @@ -366,26 +279,19 @@ def plot(self, *parameters, **kwargs):
"""
if len(parameters) == len(self):
self.parameters = parameters
display_stat_box = kwargs.pop('display_stat_box', False)
x = np.linspace(self.xmin, self.xmax, 1000)
y = self(x, *parameters)
plt.plot(x, y, **kwargs)
if display_stat_box:
self.stat_box(**kwargs)

def stat_box(self, position=None, plot=True, **kwargs):
def stat_box(self, **kwargs):
"""Plot a ROOT-style stat box for the model.
"""
if position is None:
position = self.DEFAULT_STAT_BOX_POSITION
box = StatBox(position)
box.add_entry('Fit model: %s' % self.name())
box.add_entry('Chisquare', '%.1f / %d' % (self.chisq, self.ndof))
box = PlotCard()
box.add_string('Fit model', self.name())
box.add_string('Chisquare', '%.1f / %d' % (self.chisq, self.ndof))
for name, value, error in self.parameter_status():
box.add_entry(name, value, error)
if plot:
box.plot(**kwargs)
return box
box.add_quantity(name, value, error)
box.plot(**kwargs)

def __len__(self):
"""Return the number of model parameters.
Expand All @@ -407,9 +313,10 @@ def __add__(self, other):

class _model(FitModelBase):

PARAMETER_NAMES = m1.PARAMETER_NAMES + m2.PARAMETER_NAMES
PARAMETER_NAMES = [f'{name}1' for name in m1.PARAMETER_NAMES] + \
[f'{name}2' for name in m2.PARAMETER_NAMES]
PARAMETER_DEFAULT_VALUES = m1.PARAMETER_DEFAULT_VALUES + \
m2.PARAMETER_DEFAULT_VALUES
m2.PARAMETER_DEFAULT_VALUES
DEFAULT_PLOTTING_RANGE = (xmin, xmax)
PARAMETER_DEFAULT_BOUNDS = (-np.inf, np.inf)

Expand Down
114 changes: 72 additions & 42 deletions hexsample/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from loguru import logger
import matplotlib
from matplotlib import pyplot as plt
import uncertainties

if sys.flags.interactive:
plt.ion()
Expand Down Expand Up @@ -74,53 +75,88 @@ class PlotCard(dict):

"""Small class reperesenting a text card.
This is essentially a dictionary that is capable of plotting itself on
a matplotlib figure in the form of a multi-line graphic card.
This is essentially a list of key-value pairs that is capable of plotting
itself on a matplotlib figure in the form of a multi-line graphic card.
Possible uses include a statistical box for the results of a fit.
Arguments
---------
data : dict
A dictionary holding the lines to be displayed in the card.
Note that the semantics of the object is intentionally simple---add stuff
and plot once; we do non support, e.g., updating the values after the fact.
"""

KEY_KWARGS = dict(color='gray', size='x-small', ha='left', va='top')
VALUE_KWARGS = dict(color='black', size='small', ha='left', va='top')
_LABEL_KWARGS = dict(color='gray', size='x-small', ha='right', va='top')
_CONTENT_KWARGS = dict(color='black', size='small', ha='right', va='top')

def __init__(self, data : dict = None) -> None:
def __init__(self) -> None:
"""Constructor.
"""
super().__init__()
if data is not None:
for key, value in data.items():
self.add_line(key, value)
self._item_list = []

def add_line(self, key : str, value : float, fmt : str = '%g', units : str = None) -> None:
"""Set the value for a given key.
def add_string(self, label : str, content : str) -> None:
"""Add a label-content pair to the card. This is the workhorse methods,
and specialized methods below use this internally.
Arguments
---------
key : str
The key, i.e., the explanatory text for a given value.
label : str
The label, i.e., the explanatory text for a given value.
value : float, optional
The actual value (if None, a blank line will be added).
content : str
The actual text content.
"""
self._item_list.append((label, content))

def add_quantity(self, label : str, value : float, error : float = None,
fmt : str = None, units : str = None) -> None:
"""Add a numerical quantity to the card.
This can be either a numerical value, or the result of a measurement
(i.e., including its uncertainity).
Arguments
---------
label : str
The label, i.e., the explanatory text for a given value.
value : float
The numerical value of the quantity.
error : float, optional
The uncertainty of the measurement.
fmt : str
The string format to be used to render the value.
fmt : str, optional
Optional format string for the quantity---this is ignored if the error
is defined, as in that case the rultes for the significant digits take
precedence.
units : str
The measurement units for the value.
units : str, optional
Optional measurement errors.
"""
self[key] = (value, fmt, units)
if error is not None and error > 0.:
content = f'{uncertainties.ufloat(value, error):P}'
elif fmt is not None:
content = f'{value:{fmt}}'
else:
content = f'{value}'
if units is not None:
content = f'{content} {units}'
self.add_string(label, content)

def add_blank(self) -> None:
"""Add a blank line.
"""
self.add_string('', '')

def draw(self, axes = None, x : float = 0.05, y : float = 0.95, line_spacing : float = 0.075,
spacing_ratio : float = 0.75) -> None:
"""Draw the card.
def plot(self, x : float = 0.95, y : float = 0.95, line_spacing : float = 0.075,
spacing_ratio : float = 0.75, **kwargs) -> None:
"""Plot the card.
Arguments
---------
x0, y0 : float
The absolute coordinates of the top-left corner of the card.
x : float
The absolute x-coordinate of the top-left corner of the card.
y : float
The absolute x-coordinate of the top-left corner of the card.
line_spacing : float
The line spacing in units of the total height of the current axes.
Expand All @@ -129,22 +165,16 @@ def draw(self, axes = None, x : float = 0.05, y : float = 0.95, line_spacing : f
The fractional line spacing assigned to the key label.
"""
# pylint: disable=invalid-name
if axes is None:
axes = plt.gca()
key_norm = spacing_ratio / (1. + spacing_ratio)
value_norm = 1. - key_norm
for kwargs in (self.KEY_KWARGS, self.VALUE_KWARGS):
kwargs['transform'] = axes.transAxes
for key, (value, fmt, units) in self.items():
if value is None:
y -= 0.5 * line_spacing
continue
axes.text(x, y, key, **self.KEY_KWARGS)
label_kwargs = self._LABEL_KWARGS.copy()
label_kwargs.update(kwargs)
content_kwargs = self._CONTENT_KWARGS.copy()
content_kwargs.update(kwargs)
for label, content in self._item_list:
plt.gca().text(x, y, label, transform=plt.gca().transAxes, **label_kwargs)
y -= key_norm * line_spacing
value = fmt % value
if units is not None:
value = f'{value} {units}'
axes.text(x, y, value, **self.VALUE_KWARGS)
plt.gca().text(x, y, content, transform=plt.gca().transAxes, **content_kwargs)
y -= value_norm * line_spacing


Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ dependencies = [
"scipy",
"tables",
"tqdm",
"uncertainties",
"xraydb"
]

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ pydata-sphinx-theme
scipy
tables
tqdm
uncertainties
xraydb
41 changes: 41 additions & 0 deletions tests/test_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (C) 2022 luca.baldini@pi.infn.it
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""Test suite for plot.py
"""

from hexsample.plot import plt, PlotCard


def test_card():
"""Test for the plot cards.
"""
card = PlotCard()
card.add_string('Label', 'Content')
card.add_blank()
card.add_quantity('Fixed float', 1.0)
card.add_quantity('Formatted fixed float', 1.0, fmt='.5f')
card.add_quantity('Fixed int', 1)
card.add_quantity('Parameter 1', 1.23456, 0.53627)
card.add_quantity('Fixed float', 1.0, units='cm')
card.add_quantity('Fixed int', 1, units='cm')
card.add_quantity('Parameter 1', 1.23456, 0.53627, units='cm')
card.plot()
card.plot(0.05, 0.95, ha='left')


if __name__ == '__main__':
test_card()
plt.show()

0 comments on commit 729c0c1

Please sign in to comment.