diff --git a/.github/workflows/link-check.yml b/.github/workflows/link-check.yml index ea566b07..7bb6dd9a 100644 --- a/.github/workflows/link-check.yml +++ b/.github/workflows/link-check.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out repo - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Run markdown link check uses: gaurav-nelson/github-action-markdown-link-check@v1 diff --git a/.github/workflows/svgo.yml b/.github/workflows/svgo.yml index bde9b886..9873706c 100644 --- a/.github/workflows/svgo.yml +++ b/.github/workflows/svgo.yml @@ -10,12 +10,12 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out repo - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: ref: ${{ github.head_ref }} - name: Set up node - uses: actions/setup-node@v3 + uses: actions/setup-node@v4 - name: Install SVGO run: npm install --global svgo diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c8b24229..59039f34 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,9 @@ repos: hooks: - id: ruff args: [--fix] + types_or: [python, jupyter] - id: ruff-format + types_or: [python, jupyter] - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.7.1 diff --git a/dataset_exploration/boltztrap_mp/explore_boltztrap_mp.py b/dataset_exploration/boltztrap_mp/explore_boltztrap_mp.py index 0fda3cde..b20bbd08 100644 --- a/dataset_exploration/boltztrap_mp/explore_boltztrap_mp.py +++ b/dataset_exploration/boltztrap_mp/explore_boltztrap_mp.py @@ -28,10 +28,10 @@ # %% +import matplotlib.pyplot as plt from matminer.datasets import load_dataset from pymatviz import ptable_heatmap -from pymatviz.plot_defaults import plt # %% diff --git a/dataset_exploration/camd_2022/explore_camd_2022.py b/dataset_exploration/camd_2022/explore_camd_2022.py index 0973ad7a..f7074a3f 100644 --- a/dataset_exploration/camd_2022/explore_camd_2022.py +++ b/dataset_exploration/camd_2022/explore_camd_2022.py @@ -17,12 +17,12 @@ # %% import os +import matplotlib.pyplot as plt import pandas as pd import requests from pymatgen.symmetry.groups import SpaceGroup from pymatviz import annotate_bars, count_elements, ptable_heatmap, spacegroup_sunburst -from pymatviz.plot_defaults import plt # %% Download data (if needed) diff --git a/dataset_exploration/matbench/dielectric/explore_dielectric.py b/dataset_exploration/matbench/dielectric/explore_dielectric.py index e0a75c35..39080ad3 100644 --- a/dataset_exploration/matbench/dielectric/explore_dielectric.py +++ b/dataset_exploration/matbench/dielectric/explore_dielectric.py @@ -1,15 +1,17 @@ # %% +import matplotlib.pyplot as plt +import plotly.express as px from aviary.wren.utils import count_wyckoff_positions, get_aflow_label_from_spglib from matminer.datasets import load_dataset from tqdm import tqdm from pymatviz import ( + crystal_sys_order, ptable_heatmap, ptable_heatmap_plotly, spacegroup_hist, spacegroup_sunburst, ) -from pymatviz.plot_defaults import crystal_sys_order, plt, px from pymatviz.utils import crystal_sys_from_spg_num diff --git a/dataset_exploration/matbench/expt_gap/explore_expt_gap.py b/dataset_exploration/matbench/expt_gap/explore_expt_gap.py index b5decb31..389e7ded 100644 --- a/dataset_exploration/matbench/expt_gap/explore_expt_gap.py +++ b/dataset_exploration/matbench/expt_gap/explore_expt_gap.py @@ -1,11 +1,12 @@ # %% from __future__ import annotations +import matplotlib.pyplot as plt +import plotly.express as px from matminer.datasets import load_dataset from pymatgen.core import Composition from pymatviz import ptable_heatmap -from pymatviz.plot_defaults import plt, px """Stats for the matbench_expt_gap dataset. diff --git a/dataset_exploration/matbench/jdft2d/explore_jdft2d.py b/dataset_exploration/matbench/jdft2d/explore_jdft2d.py index 1df417e4..38b9ffa3 100644 --- a/dataset_exploration/matbench/jdft2d/explore_jdft2d.py +++ b/dataset_exploration/matbench/jdft2d/explore_jdft2d.py @@ -1,9 +1,9 @@ # %% +import matplotlib.pyplot as plt from matminer.datasets import load_dataset from tqdm import tqdm from pymatviz import ptable_heatmap, spacegroup_hist, spacegroup_sunburst -from pymatviz.plot_defaults import plt """Stats for the matbench_jdft2d dataset. diff --git a/dataset_exploration/matbench/log_g+kvrh/explore_log_g+krvh.py b/dataset_exploration/matbench/log_g+kvrh/explore_log_g+krvh.py index 39523017..3cde8f07 100644 --- a/dataset_exploration/matbench/log_g+kvrh/explore_log_g+krvh.py +++ b/dataset_exploration/matbench/log_g+kvrh/explore_log_g+krvh.py @@ -13,14 +13,20 @@ # %% from time import perf_counter +import matplotlib.pyplot as plt import numpy as np +import plotly.express as px from aviary.wren.utils import count_wyckoff_positions, get_aflow_label_from_spglib from matminer.datasets import load_dataset from pymatgen.core import Structure from tqdm import tqdm -from pymatviz import ptable_heatmap, spacegroup_hist, spacegroup_sunburst -from pymatviz.plot_defaults import crystal_sys_order, plt, px +from pymatviz import ( + crystal_sys_order, + ptable_heatmap, + spacegroup_hist, + spacegroup_sunburst, +) from pymatviz.utils import crystal_sys_from_spg_num diff --git a/dataset_exploration/matbench/mp_e_form/explore_mp_e_form.py b/dataset_exploration/matbench/mp_e_form/explore_mp_e_form.py index 4183685b..fe84597a 100644 --- a/dataset_exploration/matbench/mp_e_form/explore_mp_e_form.py +++ b/dataset_exploration/matbench/mp_e_form/explore_mp_e_form.py @@ -1,8 +1,8 @@ # %% +import matplotlib.pyplot as plt from matminer.datasets import load_dataset from pymatviz import ptable_heatmap -from pymatviz.plot_defaults import plt """Stats for the matbench_mp_e_form dataset. diff --git a/dataset_exploration/matbench/mp_gap/explore_mp_gap.py b/dataset_exploration/matbench/mp_gap/explore_mp_gap.py index 673f2333..c273ef87 100644 --- a/dataset_exploration/matbench/mp_gap/explore_mp_gap.py +++ b/dataset_exploration/matbench/mp_gap/explore_mp_gap.py @@ -1,8 +1,8 @@ # %% +import matplotlib.pyplot as plt from matminer.datasets import load_dataset from pymatviz import ptable_heatmap -from pymatviz.plot_defaults import plt """Stats for the matbench_mp_gap dataset. diff --git a/dataset_exploration/matbench/perovskites/explore_perovskites.py b/dataset_exploration/matbench/perovskites/explore_perovskites.py index 05713e01..146c0101 100644 --- a/dataset_exploration/matbench/perovskites/explore_perovskites.py +++ b/dataset_exploration/matbench/perovskites/explore_perovskites.py @@ -1,4 +1,5 @@ # %% +import matplotlib.pyplot as plt from matminer.datasets import load_dataset from tqdm import tqdm @@ -8,7 +9,6 @@ ptable_heatmap, spacegroup_sunburst, ) -from pymatviz.plot_defaults import plt from pymatviz.utils import crystal_sys_from_spg_num diff --git a/dataset_exploration/matbench/phonons/explore_phonons.py b/dataset_exploration/matbench/phonons/explore_phonons.py index 48bac29d..5ccbd700 100644 --- a/dataset_exploration/matbench/phonons/explore_phonons.py +++ b/dataset_exploration/matbench/phonons/explore_phonons.py @@ -17,11 +17,11 @@ # %% +import matplotlib.pyplot as plt from matminer.datasets import load_dataset from tqdm import tqdm from pymatviz import ptable_heatmap, spacegroup_hist -from pymatviz.plot_defaults import plt # %% diff --git a/dataset_exploration/matbench/steels/explore_steels.py b/dataset_exploration/matbench/steels/explore_steels.py index 12a9bed3..a7aa6e4a 100644 --- a/dataset_exploration/matbench/steels/explore_steels.py +++ b/dataset_exploration/matbench/steels/explore_steels.py @@ -9,10 +9,10 @@ # %% +import matplotlib.pyplot as plt from matminer.datasets import load_dataset from pymatviz import ptable_heatmap -from pymatviz.plot_defaults import plt # %% diff --git a/dataset_exploration/ricci_carrier_transport/explore_carrier_transport.py b/dataset_exploration/ricci_carrier_transport/explore_carrier_transport.py index f21b2eab..1af55109 100644 --- a/dataset_exploration/ricci_carrier_transport/explore_carrier_transport.py +++ b/dataset_exploration/ricci_carrier_transport/explore_carrier_transport.py @@ -19,11 +19,11 @@ # %% +import matplotlib.pyplot as plt from matminer.datasets import load_dataset from tqdm import tqdm from pymatviz import ptable_heatmap, spacegroup_hist -from pymatviz.plot_defaults import plt # %% diff --git a/dataset_exploration/wbm/explore_wbm.py b/dataset_exploration/wbm/explore_wbm.py index c73e8165..706acdf6 100644 --- a/dataset_exploration/wbm/explore_wbm.py +++ b/dataset_exploration/wbm/explore_wbm.py @@ -1,8 +1,8 @@ # %% import pandas as pd +import plotly.express as px -from pymatviz import ptable_heatmap_plotly, spacegroup_sunburst -from pymatviz.plot_defaults import crystal_sys_order, px +from pymatviz import crystal_sys_order, ptable_heatmap_plotly, spacegroup_sunburst from pymatviz.utils import crystal_sys_from_spg_num diff --git a/examples/matbench_dielectric_eda.ipynb b/examples/matbench_dielectric_eda.ipynb index ff0d7059..166268d9 100644 --- a/examples/matbench_dielectric_eda.ipynb +++ b/examples/matbench_dielectric_eda.ipynb @@ -7,7 +7,7 @@ "source": [ "# Matbench Dielectric Dataset\n", "\n", - "Exploratory Data Analysis (EDA). [MPContribs link](https://ml.materialsproject.org/projects/matbench_dielectric)" + "Exploratory Data Analysis (EDA). [MPContribs link](https://ml.materialsproject.org/projects/matbench_dielectric)\n" ] }, { @@ -26,11 +26,13 @@ "metadata": {}, "outputs": [], "source": [ + "import matplotlib.pyplot as plt\n", + "import plotly.express as px\n", + "import plotly.io as pio\n", "from matminer.datasets import load_dataset\n", "from tqdm import tqdm\n", "\n", "from pymatviz import ptable_heatmap, spacegroup_hist, spacegroup_sunburst\n", - "from pymatviz.plot_defaults import pio, plt, px\n", "from pymatviz.utils import get_crystal_sys\n", "\n", "\n", diff --git a/examples/mp_bimodal_e_form.ipynb b/examples/mp_bimodal_e_form.ipynb index 1b1b4e6d..dd4c0106 100644 --- a/examples/mp_bimodal_e_form.ipynb +++ b/examples/mp_bimodal_e_form.ipynb @@ -36,7 +36,6 @@ "__author__ = \"Janosh Riebesell\"\n", "__date__ = \"2022-08-11\"\n", "\n", - "pio.templates.default = \"plotly_white\"\n", "pio.renderers.default = \"png\"" ] }, diff --git a/examples/mprester_ptable.ipynb b/examples/mprester_ptable.ipynb index e354cf50..69b508ea 100644 --- a/examples/mprester_ptable.ipynb +++ b/examples/mprester_ptable.ipynb @@ -15,7 +15,7 @@ "outputs": [], "source": [ "# dash needed for interactive plots\n", - "!pip install pymatviz dash\n" + "!pip install pymatviz dash" ] }, { @@ -38,12 +38,11 @@ "__date__ = \"2022-07-21\"\n", "\n", "\n", - "pio.templates.default = \"plotly_white\"\n", "# Interactive plotly figures don't show up on GitHub.\n", "# https://github.com/plotly/plotly.py/issues/931\n", "# change renderer from \"svg\" to \"notebook\" to get hover tooltips back\n", "# (but blank plots on GitHub!)\n", - "pio.renderers.default = \"png\"\n" + "pio.renderers.default = \"png\"" ] }, { @@ -60,7 +59,7 @@ } ], "source": [ - "print(\", \".join(MPRester().materials.summary.available_fields))\n" + "print(\", \".join(MPRester().materials.summary.available_fields))" ] }, { @@ -75,7 +74,7 @@ " mp_data = mpr.materials.summary.search(\n", " # nelements=[4, None], # 4 or less elements\n", " fields=[\"material_id\", \"formula\", \"nelements\"]\n", - " )\n" + " )" ] }, { @@ -85,7 +84,7 @@ "outputs": [], "source": [ "df_mp = pd.DataFrame(map(dict, mp_data)).set_index(\"material_id\")\n", - "df_mp.head()\n" + "df_mp.head()" ] }, { @@ -106,7 +105,7 @@ "# %store df_mp\n", "\n", "# uncomment line to load cached MP data from disk\n", - "%store -r df_mp\n" + "%store -r df_mp" ] }, { @@ -123,7 +122,7 @@ "compound_counts_by_arity = {\n", " key: (df_mp.nelements == idx).sum()\n", " for idx, key in enumerate(elem_counts_by_arity, 1)\n", - "}\n" + "}" ] }, { @@ -188,7 +187,7 @@ " f\"Element distribution of {n_compounds:,} {arity_label} compounds in \"\n", " \"Materials Project\",\n", " fontsize=16,\n", - " )\n" + " )" ] }, { @@ -212,7 +211,7 @@ " )\n", " fig.update_layout(title=dict(text=title, x=0.45, y=0.93))\n", " arity_figs[arity_label] = fig\n", - " # fig.show() # uncomment to show plotly figures\n" + " # fig.show() # uncomment to show plotly figures" ] }, { @@ -266,7 +265,7 @@ " return arity_figs[dropdown_value]\n", "\n", "\n", - "app.run(debug=True, mode=\"inline\")\n" + "app.run(debug=True, mode=\"inline\")" ] } ], diff --git a/pymatviz/__init__.py b/pymatviz/__init__.py index 6d23b8c4..63636fc7 100644 --- a/pymatviz/__init__.py +++ b/pymatviz/__init__.py @@ -4,7 +4,9 @@ from importlib.metadata import PackageNotFoundError, version +import matplotlib.pyplot as plt import plotly.express as px +import plotly.io as pio from pymatviz.correlation import marchenko_pastur, marchenko_pastur_pdf from pymatviz.cumulative import cumulative_error, cumulative_residual @@ -49,6 +51,11 @@ pass # package not installed +# define a sensible order for crystal systems across plots +crystal_sys_order = ( + "cubic hexagonal trigonal tetragonal orthorhombic monoclinic triclinic".split() +) + bandgap_col = "band_gap" charge_col = "total_charge" crystal_sys_col = "crystal_system" @@ -91,5 +98,53 @@ symmetry_col: "Symmetry", volume_col: f"Volume {cubic_angstrom}", volume_per_atom_col: f"Volume {angstrom_per_atom}", + "n_atoms": "Atom Count", + "n_elems": "Element Count", + "gap expt": "Experimental band gap (eV)", + "crystal_sys": "Crystal system", + "n": "Refractive index n", + "spg_num": "Space group", + "n_wyckoff": "Number of Wyckoff positions", + "n_sites": "Number of unit cell sites", + "energy_per_atom": "Energy (eV/atom)", } -px.defaults.template = "plotly_white" + +# uncomment to hide math loading MathJax message in bottom left corner of plotly PDFs +# https://github.com/plotly/Kaleido/issues/122#issuecomment-994906924 +# pio.kaleido.scope.mathjax = None + + +""" +Importing this module has side-effects that apply sensible (often, not always) global +defaults settings for plotly and matplotlib like increasing font size, prettier +axis labels (plotly only) and higher figure resolution (matplotlib only). + +To use it, simply import this module before generating any plots: + +import pymatviz +""" + +plt.rc("font", size=16) +plt.rc("savefig", bbox="tight", dpi=200) +plt.rc("figure", dpi=200, titlesize=18) +plt.rcParams["figure.constrained_layout.use"] = True + + +axis_template = dict( + mirror=True, + showline=True, + ticks="outside", + zeroline=True, + linewidth=1, +) +white_axis_template = axis_template | dict(linecolor="black", gridcolor="lightgray") +pio.templates["pymatviz_white"] = pio.templates["plotly_white"].update( + layout=dict(xaxis=axis_template, yaxis=axis_template) +) +black_axis_template = axis_template | dict(linecolor="white", gridcolor="darkgray") +pio.templates["pymatviz_black"] = pio.templates["plotly_dark"].update( + layout=dict(xaxis=axis_template, yaxis=axis_template) +) + +px.defaults.template = "pymatviz_white" +pio.templates.default = "pymatviz_white" diff --git a/pymatviz/plot_defaults.py b/pymatviz/plot_defaults.py deleted file mode 100644 index 6e223757..00000000 --- a/pymatviz/plot_defaults.py +++ /dev/null @@ -1,44 +0,0 @@ -import matplotlib.pyplot as plt -import plotly.express as px -import plotly.io as pio - - -""" -Importing this module has side-effects that apply sensible (often, not always) global -defaults settings for plotly and matplotlib like increasing font size, prettier -axis labels (plotly only) and higher figure resolution (matplotlib only). - -To use it, simply import this module before generating any plots: - -import pymatviz.plot_defaults -# or -from pymatviz.plot_defaults import plt, px, pio -""" - -# prettier -px.defaults.labels = { - "n_atoms": "Atom Count", - "n_elems": "Element Count", - "gap expt": "Experimental band gap (eV)", - "crystal_sys": "Crystal system", - "n": "Refractive index n", - "spg_num": "Space group", - "n_wyckoff": "Number of Wyckoff positions", - "n_sites": "Number of unit cell sites", - "energy_per_atom": "Energy (eV/atom)", -} - -pio.templates.default = "plotly_white" - -# https://github.com/plotly/Kaleido/issues/122#issuecomment-994906924 -# pio.kaleido.scope.mathjax = None - -crystal_sys_order = ( - "cubic hexagonal trigonal tetragonal orthorhombic monoclinic triclinic".split() -) - - -plt.rc("font", size=16) -plt.rc("savefig", bbox="tight", dpi=200) -plt.rc("figure", dpi=200, titlesize=18) -plt.rcParams["figure.constrained_layout.use"] = True diff --git a/pyproject.toml b/pyproject.toml index bbb76b23..93459c9d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,7 +80,6 @@ no_implicit_optional = false [tool.ruff] target-version = "py39" -include = ["**/pyproject.toml", "*.ipynb", "*.py", "*.pyi"] select = [ "B", # flake8-bugbear "C4", # flake8-comprehensions