Skip to content

Commit

Permalink
PhononDosPlotter.plot_dos() add support for existing plt.Axes (#3487
Browse files Browse the repository at this point in the history
)

* breaking: PhononBSPlotter.save_plot remove redundant img_format keyword

use filename extension to determine image format

* formatting

* Add support for existing plt axes in
PhononDosPlotter.plot_dos()

* pretty print DELTA as $\Delta$ in PhononBSPlotter.get_ticks
  • Loading branch information
janosh committed Nov 27, 2023
1 parent ec750ca commit 68cf6b4
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,7 @@ def make_supergraph(graph, multiplicity, periodicity_vectors):
connecting_edges.append((n1, n2, key, new_data))
else:
if not np.all(np.array(data["delta"]) == 0):
print(
"delta not equal to periodicity nor 0 ... : ",
n1,
n2,
key,
data["delta"],
data,
)
print("delta not equal to periodicity nor 0 ... : ", n1, n2, key, data["delta"], data)
input("Are we ok with this ?")
other_edges.append((n1, n2, key, data))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1129,22 +1129,22 @@ def _get_map(self, isite):
target_cns = [cg.coordination_number for cg in target_cgs]
for ii in range(min([len(maps_and_surfaces), self.max_nabundant])):
my_map_and_surface = maps_and_surfaces[order[ii]]
mymap = my_map_and_surface["map"]
cn = mymap[0]
my_map = my_map_and_surface["map"]
cn = my_map[0]
if cn not in target_cns or cn > 12 or cn == 0:
continue
all_conditions = [params[2] for params in my_map_and_surface["parameters_indices"]]
if self._additional_condition not in all_conditions:
continue
cg, cgdict = self.structure_environments.ce_list[self.structure_environments.sites_map[isite]][mymap[0]][
mymap[1]
cg, cgdict = self.structure_environments.ce_list[self.structure_environments.sites_map[isite]][my_map[0]][
my_map[1]
].minimum_geometry(symmetry_measure_type=self._symmetry_measure_type)
if (
cg in self.target_environments
and cgdict["symmetry_measure"] <= self.max_csm
and cgdict["symmetry_measure"] < current_target_env_csm
):
current_map = mymap
current_map = my_map
current_target_env_csm = cgdict["symmetry_measure"]
if current_map is not None:
return current_map
Expand Down
4 changes: 2 additions & 2 deletions pymatgen/analysis/pourbaix_diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,15 +473,15 @@ def __init__(
for entry in ion_entries:
ion_elts = list(set(entry.elements) - ELEMENTS_HO)
# TODO: the logic here for ion concentration setting is in two
# places, in PourbaixEntry and here, should be consolidated
# places, in PourbaixEntry and here, should be consolidated
if len(ion_elts) == 1:
entry.concentration = conc_dict[ion_elts[0].symbol] * entry.normalization_factor
elif len(ion_elts) > 1 and not entry.concentration:
raise ValueError("Elemental concentration not compatible with multi-element ions")

self._unprocessed_entries = solid_entries + ion_entries

if not len(solid_entries + ion_entries) == len(entries):
if len(solid_entries + ion_entries) != len(entries):
raise ValueError('All supplied entries must have a phase type of either "Solid" or "Ion"')

if self.filter_solids:
Expand Down
27 changes: 13 additions & 14 deletions pymatgen/phonon/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(self, stack: bool = False, sigma: float | None = None) -> None:
)
self.stack = stack
self.sigma = sigma
self._doses: dict = {}
self._doses: dict[str, dict[Literal["frequencies", "densities"], np.ndarray]] = {}

def add_dos(self, label: str, dos: PhononDos) -> None:
"""Adds a dos for plotting.
Expand Down Expand Up @@ -138,6 +138,7 @@ def get_plot(
ylim: float | None = None,
units: Literal["thz", "ev", "mev", "ha", "cm-1", "cm^-1"] = "thz",
legend: dict | None = None,
ax: Axes | None = None,
) -> Axes:
"""Get a matplotlib plot showing the DOS.
Expand All @@ -149,6 +150,8 @@ def get_plot(
legend: dict with legend options. For example, {"loc": "upper right"}
will place the legend in the upper right corner. Defaults to
{"fontsize": 30}.
ax (Axes): An existing axes object onto which the plot will be
added. If None, a new figure will be created.
"""
legend = legend or {"fontsize": 30}
unit = freq_units(units)
Expand All @@ -161,7 +164,7 @@ def get_plot(
y = None
all_densities = []
all_frequencies = []
ax = pretty_plot(12, 8)
ax = pretty_plot(12, 8, ax=ax)

# Note that this complicated processing of frequencies is to allow for
# stacked plots in matplotlib.
Expand Down Expand Up @@ -516,30 +519,27 @@ def show(
"""Show the plot using matplotlib.
Args:
ylim: Specify the y-axis (frequency) limits; by default None let
the code choose.
units: units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1.
ylim (float): Specifies the y-axis limits.
units ("thz" | "ev" | "mev" | "ha" | "cm-1" | "cm^-1"): units for the frequencies.
"""
self.get_plot(ylim, units=units)
plt.show()

def save_plot(
self,
filename: str | PathLike,
img_format: str = "eps",
ylim: float | None = None,
units: Literal["thz", "ev", "mev", "ha", "cm-1", "cm^-1"] = "thz",
) -> None:
"""Save matplotlib plot to a file.
Args:
filename: Filename to write to.
img_format: Image format to use. Defaults to EPS.
ylim: Specifies the y-axis limits.
units: units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1.
filename (str | Path): Filename to write to.
ylim (float): Specifies the y-axis limits.
units ("thz" | "ev" | "mev" | "ha" | "cm-1" | "cm^-1"): units for the frequencies.
"""
self.get_plot(ylim=ylim, units=units)
plt.savefig(filename, format=img_format)
plt.savefig(filename)
plt.close()

def show_proj(
Expand Down Expand Up @@ -598,9 +598,8 @@ def get_ticks(self) -> dict[str, list]:
elif point.label.startswith("\\") or point.label.find("_") != -1:
tick_labels.append(f"${point.label}$")
else:
label = point.label
if label == "GAMMA":
label = r"$\Gamma$"
# map atomate2 all-upper-case point.labels to pretty LaTeX
label = dict(GAMMA=r"$\Gamma$", DELTA=r"$\Delta$").get(point.label, point.label)
tick_labels.append(label)
previous_label = point.label
previous_branch = this_branch
Expand Down

0 comments on commit 68cf6b4

Please sign in to comment.