Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 10, 2023
1 parent 8555fdc commit ff8191c
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 45 deletions.
63 changes: 21 additions & 42 deletions pymatgen/electronic_structure/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2614,10 +2614,15 @@ def _get_colordata(bs, elements, bs_projection):
for idx, e in enumerate(elements):
c[idx] = math.sqrt(projs[e] / total) # min is to handle round errors

c = [c[1], c[2], c[0], c[3]] # prefer blue, then red, then green or magenta, then yellow, then cyan, then black
c = [
c[1],
c[2],
c[0],
c[3],
] # prefer blue, then red, then green or magenta, then yellow, then cyan, then black
if len(elements) == 4:
# convert cmyk to rgb
c = [(1-c[0])*(1-c[3]), ((1-c[1])*(1-c[3])), ((1-c[2])*(1-c[3]))]
c = [(1 - c[0]) * (1 - c[3]), ((1 - c[1]) * (1 - c[3])), ((1 - c[2]) * (1 - c[3]))]
else:
c = [c[0], c[1], c[2]]

Expand All @@ -2640,6 +2645,7 @@ def _cmyk_triangle(ax, c_label, m_label, y_label, k_label, loc):
loc = 2

from mpl_toolkits.axes_grid1.inset_locator import inset_axes

inset_ax = inset_axes(ax, width=1.5, height=1.5, loc=loc)
mesh = 35
x = []
Expand All @@ -2648,63 +2654,36 @@ def _cmyk_triangle(ax, c_label, m_label, y_label, k_label, loc):
for c in range(0, mesh):
for ye in range(0, mesh):
for m in range(0, mesh):
if (
not (c == mesh-1 and ye == mesh-1 and m == mesh-1) and
not (c == 0 and ye == 0 and m == 0)
):
c1 = c / (c + ye + m )
ye1 = ye / (c + ye + m )
m1 = m / (c + ye + m )
x.append(0.33 * (2. * ye1 + c1) / (c1 + ye1 + m1))
if not (c == mesh - 1 and ye == mesh - 1 and m == mesh - 1) and not (c == 0 and ye == 0 and m == 0):
c1 = c / (c + ye + m)
ye1 = ye / (c + ye + m)
m1 = m / (c + ye + m)
x.append(0.33 * (2.0 * ye1 + c1) / (c1 + ye1 + m1))
y.append(0.33 * np.sqrt(3) * c1 / (c1 + ye1 + m1))
rc = (1 - c / (mesh - 1))
gc = (1 - m / (mesh - 1))
bc = (1 - ye / (mesh - 1))
rc = 1 - c / (mesh - 1)
gc = 1 - m / (mesh - 1)
bc = 1 - ye / (mesh - 1)
color.append([rc, gc, bc])

# x = [n + 0.25 for n in x] # nudge x coordinates
# y = [n + (max_y - 1) for n in y] # shift y coordinates to top
# plot the triangle
inset_ax.scatter(x, y, s=7, marker='.', edgecolor=color)
inset_ax.scatter(x, y, s=7, marker=".", edgecolor=color)
inset_ax.set_xlim([-0.35, 1.00])
inset_ax.set_ylim([-0.35, 1.00])

# add the labels
inset_ax.text(
0.70,
-0.2,
m_label,
fontsize=13,
family='Times New Roman',
color=(0, 0, 0),
horizontalalignment='left'
0.70, -0.2, m_label, fontsize=13, family="Times New Roman", color=(0, 0, 0), horizontalalignment="left"
)
inset_ax.text(
0.325,
0.70,
c_label,
fontsize=13,
family='Times New Roman',
color=(0, 0, 0),
horizontalalignment='center'
0.325, 0.70, c_label, fontsize=13, family="Times New Roman", color=(0, 0, 0), horizontalalignment="center"
)
inset_ax.text(
-0.05,
-0.2,
y_label,
fontsize=13,
family='Times New Roman',
color=(0, 0, 0),
horizontalalignment='right'
-0.05, -0.2, y_label, fontsize=13, family="Times New Roman", color=(0, 0, 0), horizontalalignment="right"
)
inset_ax.text(
0.325,
0.22,
k_label,
fontsize=13,
family='Times New Roman',
color=(1, 1, 1),
horizontalalignment='center'
0.325, 0.22, k_label, fontsize=13, family="Times New Roman", color=(1, 1, 1), horizontalalignment="center"
)

inset_ax.get_xaxis().set_visible(False)
Expand Down
10 changes: 7 additions & 3 deletions pymatgen/electronic_structure/tests/test_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,14 +244,18 @@ def test_methods(self):
v.complete_dos,
)
plt.close("all")

v = Vasprun(os.path.join(PymatgenTest.TEST_FILES_DIR, "vasprun_SrBa2Sn2O7_bands.xml"), parse_projected_eigen=True)

v = Vasprun(
os.path.join(PymatgenTest.TEST_FILES_DIR, "vasprun_SrBa2Sn2O7_bands.xml"), parse_projected_eigen=True
)
p = BSDOSPlotter()
plt = p.get_plot(
v.get_band_structure(kpoints_filename=os.path.join(PymatgenTest.TEST_FILES_DIR, "KPOINTS_SrBa2Sn2O7_bands"))
)
plt = p.get_plot(
v.get_band_structure(kpoints_filename=os.path.join(PymatgenTest.TEST_FILES_DIR, "KPOINTS_SrBa2Sn2O7_bands")),
v.get_band_structure(
kpoints_filename=os.path.join(PymatgenTest.TEST_FILES_DIR, "KPOINTS_SrBa2Sn2O7_bands")
),
v.complete_dos,
)

Expand Down

0 comments on commit ff8191c

Please sign in to comment.