Skip to content

Commit

Permalink
Merge branch '9-fat-bands' into 'master'
Browse files Browse the repository at this point in the history
Resolve "Fat bands"

Closes orest-d#9

See merge request schlipf/py4vasp!8
  • Loading branch information
martin-schlipf committed Jan 16, 2020
2 parents 66a58c6 + 506f721 commit f44184e
Show file tree
Hide file tree
Showing 2 changed files with 206 additions and 18 deletions.
134 changes: 121 additions & 13 deletions src/py4vasp/data/band.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
import re
import functools
import itertools
import numpy as np
import plotly.graph_objects as go
from collections import namedtuple


class Band:
_Index = namedtuple("_Index", "spin, atom, orbital")
_Atom = namedtuple("_Atom", "indices, label")
_Orbital = namedtuple("_Orbital", "indices, label")
_Spin = namedtuple("_Spin", "indices, label")

def __init__(self, vaspout):
self._fermi_energy = vaspout["results/dos/efermi"][()]
self._kpoints = vaspout["results/eigenvalues/kpoint_coords"]
Expand All @@ -18,33 +26,91 @@ def __init__(self, vaspout):
self._indices = vaspout[indices_key] if indices_key in vaspout else []
labels_key = "input/kpoints/labels_kpoints"
self._labels = vaspout[labels_key] if labels_key in vaspout else []
self._has_projectors = "results/projectors" in vaspout
if self._has_projectors:
self._init_projectors(vaspout)

def _init_projectors(self, vaspout):
self._projections = vaspout["results/projectors/par"]
ion_types = vaspout["results/positions/ion_types"]
ion_types = [type.decode().strip() for type in ion_types]
self._init_atom_dict(ion_types, vaspout["results/positions/number_ion_types"])
orbitals = vaspout["results/projectors/lchar"]
orbitals = [orb.decode().strip() for orb in orbitals]
self._init_orbital_dict(orbitals)
self._init_spin_dict()

def _init_atom_dict(self, ion_types, number_ion_types):
num_atoms = self._projections.shape[1]
all_atoms = self._Atom(indices=range(num_atoms), label=None)
self._atom_dict = {"*": all_atoms}
start = 0
for type, number in zip(ion_types, number_ion_types):
_range = range(start, start + number)
self._atom_dict[type] = self._Atom(indices=_range, label=type)
for i in _range:
# create labels like Si_1, Si_2, Si_3 (starting at 1)
label = type + "_" + str(_range.index(i) + 1)
self._atom_dict[str(i + 1)] = self._Atom(indices=[i], label=label)
start += number
# atoms may be preceeded by :
for key in self._atom_dict.copy():
self._atom_dict[key + ":"] = self._atom_dict[key]

def _init_orbital_dict(self, orbitals):
num_orbitals = self._projections.shape[2]
all_orbitals = self._Orbital(indices=range(num_orbitals), label=None)
self._orbital_dict = {"*": all_orbitals}
for i, orbital in enumerate(orbitals):
self._orbital_dict[orbital] = self._Orbital(indices=[i], label=orbital)
if "px" in self._orbital_dict:
self._orbital_dict["p"] = self._Orbital(indices=range(1, 4), label="p")
self._orbital_dict["d"] = self._Orbital(indices=range(4, 9), label="d")
self._orbital_dict["f"] = self._Orbital(indices=range(9, 16), label="f")

def _init_spin_dict(self):
labels = ["up", "down"] if self._spin_polarized else [None]
self._spin_dict = {
key: self._Spin(indices=[i], label=key) for i, key in enumerate(labels)
}

def read(self):
def read(self, selection=None):
kpoints = self._kpoints[:]
return {
"kpoints": kpoints,
"kpoint_distances": self._kpoint_distances(kpoints),
"kpoint_labels": self._kpoint_labels(),
"fermi_energy": self._fermi_energy,
**self._shift_bands_by_fermi_energy(),
"projections": self._read_projections(selection),
}

def plot(self):
band = self.read()
num_bands = band["bands"].shape[-1]
kdists = band["kpoint_distances"]
# insert NaN to split separate bands
kdist = np.tile([*kdists, np.NaN], num_bands)
bands = np.append(
band["bands"], [np.repeat(np.NaN, num_bands)], axis=0
).flatten(order="F")
def plot(self, selection=None, width=0.5):
kdists = self._kpoint_distances(self._kpoints[:])
fatband_kdists = np.concatenate((kdists, np.flip(kdists)))
bands = self._shift_bands_by_fermi_energy()
projections = self._read_projections(selection)
ticks = [*kdists[:: self._line_length], kdists[-1]]
labels = self._ticklabels()
data = []
for key, lines in bands.items():
if len(projections) == 0:
data.append(self._scatter(key, kdists, lines))
for name, proj in projections.items():
if self._spin_polarized and not key in name:
continue
upper = lines + width * proj
lower = lines - width * proj
fatband_lines = np.concatenate((lower, np.flip(upper, axis=0)), axis=0)
plot = self._scatter(name, fatband_kdists, fatband_lines)
plot.fill = "toself"
plot.mode = "none"
data.append(plot)
default = {
"xaxis": {"tickmode": "array", "tickvals": ticks, "ticktext": labels},
"yaxis": {"title": {"text": "Energy (eV)"}},
}
return go.Figure(data=go.Scatter(x=kdist, y=bands), layout=default)
return go.Figure(data=data, layout=default)

def _shift_bands_by_fermi_energy(self):
if self._spin_polarized:
Expand All @@ -55,6 +121,19 @@ def _shift_bands_by_fermi_energy(self):
else:
return {"bands": self._bands[0] - self._fermi_energy}

def _read_projections(self, selection):
if selection is None:
return {}
parts = self._parse_selection(selection)
return self._read_elements(parts)

def _scatter(self, name, kdists, lines):
# insert NaN to split separate lines
num_bands = lines.shape[-1]
kdists = np.tile([*kdists, np.NaN], num_bands)
lines = np.append(lines, [np.repeat(np.NaN, num_bands)], axis=0)
return go.Scatter(x=kdists, y=lines.flatten(order="F"), name=name)

def _kpoint_distances(self, kpoints):
cartesian_kpoints = np.linalg.solve(self._cell, kpoints.T).T
kpoint_lines = np.split(cartesian_kpoints, self._num_lines)
Expand All @@ -64,6 +143,35 @@ def _kpoint_distances(self, kpoints):
)
return functools.reduce(concatenate_distances, kpoint_norms)

def _parse_selection(self, selection):
atom = self._atom_dict["*"]
selection = re.sub("\s*:\s*", ": ", selection)
for part in re.split("[ ,]+", selection):
if part in self._orbital_dict:
orbital = self._orbital_dict[part]
else:
atom = self._atom_dict[part]
orbital = self._orbital_dict["*"]
if ":" not in part: # exclude ":" because it starts a new atom
for spin in self._spin_dict.values():
yield atom, orbital, spin

def _read_elements(self, parts):
res = {}
for atom, orbital, spin in parts:
label = self._merge_labels([atom.label, orbital.label, spin.label])
index = self._Index(spin.indices, atom.indices, orbital.indices)
res[label] = self._read_element(index)
return res

def _merge_labels(self, labels):
return "_".join(filter(None, labels))

def _read_element(self, index):
sum_weight = lambda weight, i: weight + self._projections[i]
zero_weight = np.zeros(self._bands.shape[1:])
return functools.reduce(sum_weight, itertools.product(*index), zero_weight)

def _kpoint_labels(self):
if len(self._labels) == 0:
return None
Expand All @@ -75,9 +183,9 @@ def _kpoint_labels(self):
return [l.decode().strip() for l in labels]

def _ticklabels(self):
labels = [""] * (self._num_lines + 1)
labels = [" "] * (self._num_lines + 1)
for index, label in zip(self._indices, self._labels):
i = index // 2 # line has 2 ends
label = label.decode().strip()
labels[i] = (labels[i] + "|" + label) if labels[i] else label
labels[i] = (labels[i] + "|" + label) if labels[i].strip() else label
return labels
90 changes: 85 additions & 5 deletions tests/data/test_band.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def test_parabolic_band_plot(two_parabolic_bands):
fig = Band(h5f).plot()
assert fig.layout.yaxis.title.text == "Energy (eV)"
assert len(fig.data) == 1
assert fig.data[0].fill is None
assert fig.data[0].mode is None
assert len(fig.data[0].x) == len(fig.data[0].y)
num_NaN_x = np.count_nonzero(np.isnan(fig.data[0].x))
num_NaN_y = np.count_nonzero(np.isnan(fig.data[0].y))
Expand All @@ -74,7 +76,7 @@ def kpoint_path():
"kpoints": np.concatenate((first_path, second_path)),
"kdists": np.concatenate((first_dists, second_dists)),
"klabels": ([""] * (N - 1) + ["X", "Y"] + [""] * (N - 2) + ["G"]),
"ticklabels": ("", "X|Y", "G"),
"ticklabels": (" ", "X|Y", "G"),
}
h5f["input/kpoints/number_kpoints"] = N
h5f["input/kpoints/labels_kpoints"] = np.array(["X", "Y", "G"], dtype="S")
Expand Down Expand Up @@ -109,15 +111,23 @@ def test_kpoint_path_plot(kpoint_path):
@pytest.fixture
def spin_band_structure():
h5f = h5py.File(TemporaryFile(), "a")
num_bands = 5
kpoints = np.linspace(np.zeros(3), np.ones(3))
shape = (len(kpoints), 5)
size = np.prod(shape)
ref = {
"up": np.random.random((len(kpoints), num_bands)),
"down": np.random.random((len(kpoints), num_bands)),
"up": np.arange(size).reshape(shape),
"down": np.arange(size, 2 * size).reshape(shape),
"proj_up": np.random.uniform(low=0.2, size=shape),
"proj_down": np.random.uniform(low=0.2, size=shape),
"width": 0.05,
}
h5f["input/kpoints/number_kpoints"] = len(kpoints)
h5f["results/eigenvalues/kpoint_coords"] = kpoints
h5f["results/eigenvalues/eigenvalues"] = np.array([ref["up"], ref["down"]])
h5f["results/projectors/par"] = np.array([[[ref["proj_up"]]], [[ref["proj_down"]]]])
h5f["results/projectors/lchar"] = np.array(["s"], dtype="S")
h5f["results/positions/ion_types"] = np.array(["Si"], dtype="S")
h5f["results/positions/number_ion_types"] = [1]
h5f["results/dos/efermi"] = 0.0
h5f["results/positions/scale"] = 1.0
h5f["results/positions/lattice_vectors"] = np.eye(3)
Expand All @@ -126,6 +136,76 @@ def spin_band_structure():

def test_spin_band_structure_read(spin_band_structure):
h5f, ref = spin_band_structure
band = Band(h5f).read()
band = Band(h5f).read("s")
assert_allclose(band["up"], ref["up"])
assert_allclose(band["down"], ref["down"])
assert_allclose(band["projections"]["s_up"], ref["proj_up"])
assert_allclose(band["projections"]["s_down"], ref["proj_down"])


def test_spin_band_structure_plot(spin_band_structure):
h5f, ref = spin_band_structure
fig = Band(h5f).plot("Si", width=ref["width"])
assert len(fig.data) == 2
spins = ["up", "down"]
for spin, data in zip(spins, fig.data):
assert data.name == "Si_" + spin
for band, weight in zip(np.nditer(ref[spin]), np.nditer(ref["proj_" + spin])):
upper = band + ref["width"] * weight
lower = band - ref["width"] * weight
pos_upper = data.x[np.where(np.isclose(data.y, upper))]
pos_lower = data.x[np.where(np.isclose(data.y, lower))]
assert len(pos_upper) == len(pos_lower) == 1
assert_allclose(pos_upper, pos_lower)


@pytest.fixture
def projected_band_structure():
h5f = h5py.File(TemporaryFile(), "a")
kpoints = np.linspace(np.zeros(3), np.ones(3))
shape = (len(kpoints), 2)
ref = {
"bands": np.arange(np.prod(shape)).reshape(shape),
"projections": np.random.uniform(low=0.2, size=shape),
# set lower bound to avoid accidentally triggering np.isclose
"width": 0.5,
}
h5f["input/kpoints/number_kpoints"] = len(kpoints)
h5f["results/eigenvalues/kpoint_coords"] = kpoints
h5f["results/eigenvalues/eigenvalues"] = np.array([ref["bands"]])
h5f["results/projectors/par"] = np.array([[[ref["projections"]]]])
h5f["results/projectors/lchar"] = np.array(["s"], dtype="S")
h5f["results/positions/ion_types"] = np.array(["Si"], dtype="S")
h5f["results/positions/number_ion_types"] = [1]
h5f["results/dos/efermi"] = 0.0
h5f["results/positions/scale"] = 1.0
h5f["results/positions/lattice_vectors"] = np.eye(3)
return h5f, ref


def test_projected_band_structure_read(projected_band_structure):
h5f, ref = projected_band_structure
band = Band(h5f).read("Si:s")
assert_allclose(band["projections"]["Si_s"], ref["projections"])


def test_projected_band_structure_plot(projected_band_structure):
h5f, ref = projected_band_structure
fig = Band(h5f).plot("s, 1")
assert len(fig.data) == 2
assert fig.data[0].name == "s"
assert fig.data[1].name == "Si_1"
for data in fig.data:
assert len(data.x) == len(data.y)
assert data.fill == "toself"
assert data.mode == "none"
num_NaN_x = np.count_nonzero(np.isnan(data.x))
num_NaN_y = np.count_nonzero(np.isnan(data.y))
assert num_NaN_x == num_NaN_y > 0
for band, weight in zip(np.nditer(ref["bands"]), np.nditer(ref["projections"])):
upper = band + ref["width"] * weight
lower = band - ref["width"] * weight
pos_upper = data.x[np.where(np.isclose(data.y, upper))]
pos_lower = data.x[np.where(np.isclose(data.y, lower))]
assert len(pos_upper) == len(pos_lower) == 1
assert_allclose(pos_upper, pos_lower)

0 comments on commit f44184e

Please sign in to comment.