Skip to content

Commit

Permalink
Merge pull request #137 from jiangyi15/plot_interf
Browse files Browse the repository at this point in the history
feat: config.plot_partial_wave_interf(res1, res2)
  • Loading branch information
jiangyi15 committed Feb 5, 2024
2 parents 47285b5 + fea1d3f commit a2be967
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 11 deletions.
70 changes: 59 additions & 11 deletions tf_pwa/amp/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,27 @@ def get_bin_index(self, m):

@register_particle("linear_npy")
class InterpLinearNpy(InterpolationParticle):
"""
Linear interpolation model from a `.npy` file with array of [mi, re(ai), im(ai)].
Required `file: path_of_file.npy`, for the path of `.npy` file.
The example is `exp(5 I m)`.
.. plot::
>>> import tempfile
>>> import numpy as np
>>> from tf_pwa.utils import plot_particle_model
>>> a = tempfile.mktemp(".npy")
>>> m = np.linspace(0.2, 0.9)
>>> mi = m[::5]
>>> np.save(a, np.stack([mi, np.cos(mi*5), np.sin(mi*5)], axis=-1))
>>> axs = plot_particle_model("linear_npy", {"file": a})
>>> _ = axs[3].plot(np.cos(m*5), np.sin(m*5), "--")
"""

def __init__(self, *args, **kwargs):
self.input_file = kwargs.get("file")
self.data = np.load(self.input_file)
Expand All @@ -103,29 +124,56 @@ def init_params(self):

def get_point_values(self):
v_r = np.concatenate([[0.0], self.data[:, 1], [0.0]])
v_i = np.concatenate([[0.0], self.data[:, 1], [0.0]])
v_i = np.concatenate([[0.0], self.data[:, 2], [0.0]])
return self.data[:, 0], v_r, v_i

def interp(self, m):
x, p_r, p_i = self.get_point_values()
bin_idx = tf.raw_ops.Bucketize(input=m, boundaries=x)
bin_idx = (bin_idx + len(self.bound)) % len(self.bound)
ret_r_l = tf.gather(p_r[1:], bin_idx)
ret_i_l = tf.gather(p_r[1:], bin_idx)
ret_r_r = tf.gather(p_r[:-1], bin_idx)
ret_i_r = tf.gather(p_r[:-1], bin_idx)
delta = np.concatenate(
[[1.0], self.data[1:, 1] - self.data[:-1, 1], [1.0]]
)
bin_idx = (bin_idx) % (len(self.bound) + 1)
ret_r_r = tf.gather(p_r[1:], bin_idx)
ret_i_r = tf.gather(p_i[1:], bin_idx)
ret_r_l = tf.gather(p_r[:-1], bin_idx)
ret_i_l = tf.gather(p_i[:-1], bin_idx)
delta = np.concatenate([[1e20], x[1:] - x[:-1], [1e20]])
x_left = np.concatenate([[x[0] - 1], x])
delta = tf.gather(delta, bin_idx)
x_left = tf.gather(x_left, bin_idx)
step = (m - x_left) / delta
a = step * (ret_r_l - ret_r_r)
b = step * (ret_i_l - ret_i_r)
a = step * (ret_r_r - ret_r_l) + ret_r_l
b = step * (ret_i_r - ret_i_l) + ret_i_l
return tf.complex(a, b)


@register_particle("linear_txt")
class InterpLinearTxt(InterpLinearNpy):
"""Linear interpolation model from a `.txt` file with array of [mi, re(ai), im(ai)].
Required `file: path_of_file.txt`, for the path of `.txt` file.
The example is `exp(5 I m)`.
.. plot::
>>> import tempfile
>>> import numpy as np
>>> from tf_pwa.utils import plot_particle_model
>>> a = tempfile.mktemp(".txt")
>>> m = np.linspace(0.2, 0.9)
>>> mi = m[::5]
>>> np.savetxt(a, np.stack([mi, np.cos(mi*5), np.sin(mi*5)], axis=-1))
>>> axs = plot_particle_model("linear_txt", {"file": a})
>>> _ = axs[3].plot(np.cos(m*5), np.sin(m*5), "--")
"""

def __init__(self, *args, **kwargs):
self.input_file = kwargs.get("file")
self.data = np.loadtxt(self.input_file)
points = self.data[:, 0]
kwargs["points"] = points
super(InterpLinearNpy, self).__init__(*args, **kwargs)


@register_particle("interp")
class Interp(InterpolationParticle):
"""linear interpolation for real number"""
Expand Down
42 changes: 42 additions & 0 deletions tf_pwa/config_loader/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,48 @@ def plot_partial_wave(
)


@ConfigLoader.register_function()
def plot_partial_wave_interf(self, res1, res2, **kwargs):

labels = ["data"]
if self.config["data"].get("model", "auto") == "cfit":
labels.append("background")
elif self.config["data"].get("bg", None) is not None:
labels.append("background")

if kwargs.get("ref_amp", None) is not None:
labels.append("reference fit")
labels.append("total fit")

if kwargs.get("force_legend_labels", None) is not None:
labels = kwargs["force_legend_labels"]
del kwargs["force_legend_labels"]

labels += [str(res1), str(res2), "sum", "interference"]

if not isinstance(res1, list):
res1 = [res1]
if not isinstance(res2, list):
res2 = [res2]

amp = self.get_amplitude()

def weights_function(data, **kwargs):
with amp.temp_used_res(res1):
a = amp(data)
with amp.temp_used_res(res2):
b = amp(data)
with amp.temp_used_res(res1 + res2):
ab = amp(data)
return [a, b, ab, ab - a - b]

self.plot_partial_wave(
partial_waves_function=weights_function,
force_legend_labels=labels,
**kwargs,
)


@ConfigLoader.register_function()
def _get_plot_partial_wave_input(
self,
Expand Down
11 changes: 11 additions & 0 deletions tf_pwa/tests/test_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,17 @@ def test_cfit(gen_toy):
linestyle_file="toy_data/a.yml",
chains_id_method="res",
)
config.plot_partial_wave_interf(
"R_BC",
"R_BD",
prefix="toy_data/figure/interf_",
)
config.plot_partial_wave_interf(
"R_BC",
"R_BD",
prefix="toy_data/figure/interf2_",
ref_amp=amp,
)
config.get_plotter().save_all_frame(prefix="toy_data/figure/s3", idx=0)
plotter = config.get_plotter("toy_data/a.yml", use_weighted=True)
plotter.smooth = True
Expand Down

0 comments on commit a2be967

Please sign in to comment.