Skip to content

Commit

Permalink
[feat]: Adding handler to extract band projections
Browse files Browse the repository at this point in the history
  • Loading branch information
hentt30 committed Aug 3, 2021
1 parent a2c2631 commit 6b0b204
Show file tree
Hide file tree
Showing 6 changed files with 221 additions and 50 deletions.
127 changes: 127 additions & 0 deletions minushalf/handlers/get_projections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
"""
Get band projections
"""
import collections
import loguru
import pandas as pd
from minushalf.softwares import Vasp
from minushalf.data import Softwares
from minushalf.handlers import BaseHandler
from minushalf.interfaces import SoftwaresAbstractFactory, BandProjectionFile, MinushalfYaml
from minushalf.utils import (BandProjectionFile, band_structure, projection_to_df)


class GetBandProjections(BaseHandler):
"""
Uses the software module to extract the character of the bands (VBM and CBM)
"""
def _get_band_structure(
self, software_module: SoftwaresAbstractFactory) -> BandProjectionFile:
"""
Return band structure class
"""
eigenvalues = software_module.get_eigenvalues()
fermi_energy = software_module.get_fermi_energy()
atoms_map = software_module.get_atoms_map()
num_bands = software_module.get_number_of_bands()
band_projection_file = software_module.get_band_projection_class()

return BandProjectionFile(eigenvalues, fermi_energy, atoms_map, num_bands,
band_projection_file)

def _get_vbm_projection(self,
band_structure: BandProjectionFile) -> pd.DataFrame:
"""
Returns vbm projection
"""
vbm_projection = band_structure.vbm_projection()
normalized_df = projection_to_df(vbm_projection)
return normalized_df

def _get_cbm_projection(self,
band_structure: BandProjectionFile) -> pd.DataFrame:
"""
Returns cbm projection
"""
cbm_projection = band_structure.cbm_projection()
normalized_df = projection_to_df(cbm_projection)
return normalized_df

def _get_band_projection(self, band_structure: BandProjectionFile, kpoint: int,
band: int) -> pd.DataFrame:
"""
Returns band projection
"""
band_projection = band_structure.band_projection(kpoint, band)
normalized_df = projection_to_df(band_projection)
return normalized_df

def _select_vbm(self, minushalf_yaml: MinushalfYaml,
band_structure: BandProjectionFile) -> pd.DataFrame:
"""
Select and returns vbm character
"""
overwrite_vbm = minushalf_yaml.get_overwrite_vbm()

if overwrite_vbm:
return self._get_band_projection(band_structure, *overwrite_vbm)

return self._get_vbm_projection()

def _select_cbm(self, minushalf_yaml: MinushalfYaml,
band_structure: BandProjectionFile) -> pd.DataFrame:
"""
Select and returns cbm character
"""
overwrite_cbm = minushalf_yaml.get_overwrite_cbm()

if overwrite_cbm:
return self._get_band_projection(band_structure, *overwrite_cbm)

return self._get_cbm_projection()

def _is_valence_correction(self, correction_code: str) -> bool:
"""
Verify if the correction is a valence correction
"""
return "v" in correction_code

def _is_conduction_correction(self, correction_code: str) -> bool:
"""
Verify if the correction is a conduction correction
"""
return "c" in correction_code

def _get_projections(
self, minushalf_yaml: MinushalfYaml,
band_structure: BandProjectionFile) -> collections.defaultdict:
"""
Returns an dictionary with the projections necessary to the correction
"""
correction_code = minushalf_yaml.get_correction_code()

projections = collections.defaultdict(dict)

## Select confuction and valence index
if self._is_valence_correction(correction_code):
projections["valence"] = self._select_vbm(minushalf_yaml,
band_structure)
if self._is_conduction_correction(correction_code):
projections["conduction"] = self._select_cbm(
minushalf_yaml, band_structure)
return projections

def action(self, request: any) -> any:
"""
Uses the software module to extract band projections
"""
loguru.logger.info("Extracting band projections")
software_module, minushalf_yaml = request["software_module"], request[
"minushalf_yaml"]
band_structure = self._get_band_structure(software_module)

## Add projections to object
request["projections"] = self._get_projections(minushalf_yaml,
band_structure)

return request
16 changes: 16 additions & 0 deletions minushalf/interfaces/minushalf_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,19 @@ def get_software_name(self) -> str:
"""
Returns the name of the software that runs first principles calculations
"""
@abstractmethod
def get_correction_code(self) -> str:
"""
Returns the code used to identify the correction
"""
@abstractmethod
def get_overwrite_vbm(self) -> str:
"""
Returns the parameter that overwrites vbm
"""

@abstractmethod
def get_overwrite_cbm(self) -> str:
"""
Returns the parameter that overwrites cbm
"""
5 changes: 3 additions & 2 deletions minushalf/io/correction.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Class for correction input parameters in minushalf.yaml
"""
from shutil import Error
from minushalf.interfaces.minushalf_yaml_tags import MinushalfYamlTags
import loguru
from minushalf.data import CorrectionCode
Expand Down Expand Up @@ -116,7 +117,7 @@ def to_dict(self):
"""
Return dictionary with the class variables
"""
parameters_dict = self.__dict__
parameters_dict = self.__dict__.copy()

## removing private variables
parameters_dict["correction_code"] = parameters_dict.pop(
Expand All @@ -126,4 +127,4 @@ def to_dict(self):
parameters_dict["overwrite_cbm"] = parameters_dict.pop(
"_overwrite_cbm", None)

return self.__dict__
return parameters_dict
18 changes: 18 additions & 0 deletions minushalf/io/minushalf_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,24 @@ def get_command(self) -> list:
"""
return self.software_configurations.command

def get_correction_code(self) -> list:
"""
Returns the code used to identify the correction
"""
return self.correction.correction_code

def get_overwrite_vbm(self) -> list:
"""
Returns the parameter that overwrites vbm
"""
return self.correction.overwrite_vbm

def get_overwrite_cbm(self) -> list:
"""
Returns the parameter that overwrites cbm
"""
return self.correction.overwrite_cbm

@staticmethod
def _read_yaml(filename: str) -> collections.defaultdict:
"""
Expand Down
Empty file.
105 changes: 57 additions & 48 deletions tests/unit/io/test_minushalf_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,26 @@ def test_default_parameters():
**file.get_software_configurations_params()
}
default_params = {
'exchange_correlation_code': 'pb',
'calculation_code': 'ae',
'max_iterations': 100,
'potfiles_folder': 'minushalf_potfiles',
'amplitude': 1.0,
'valence_cut_guess': None,
'conduction_cut_guess': None,
'tolerance': 0.01,
'fractional_valence_treshold': 10,
'fractional_conduction_treshold': 9,
'inplace': False,
'correction_code': 'v',
'overwrite_vbm': [],
'overwrite_cbm': [],
'command': ['mpirun', 'vasp']
"exchange_correlation_code": "pb",
"calculation_code": "ae",
"max_iterations": 100,
"potfiles_folder": "minushalf_potfiles",
"amplitude": 1.0,
"valence_cut_guess": None,
"conduction_cut_guess": None,
"tolerance": 0.01,
"fractional_valence_treshold": 10,
"fractional_conduction_treshold": 9,
"inplace": False,
"correction_code": "v",
"overwrite_vbm": [],
"overwrite_cbm": [],
"command": ["mpirun", "vasp"]
}
assert file.get_command() == ['mpirun', 'vasp']
assert file.get_command() == ["mpirun", "vasp"]
assert file.get_correction_code() == "v"
assert file.get_overwrite_cbm() == []
assert file.get_overwrite_vbm() == []
assert params == default_params


Expand All @@ -54,24 +57,27 @@ def test_minushalf_without_filling_correction(file_path):
**file.get_software_configurations_params()
}
expected_params = {
'exchange_correlation_code': 'wi',
'calculation_code': 'ae',
'max_iterations': 200,
'potfiles_folder': 'minushalf_potfiles',
'amplitude': 1.0,
'valence_cut_guess': None,
'conduction_cut_guess': None,
'tolerance': 0.01,
'fractional_valence_treshold': 10,
'fractional_conduction_treshold': 9,
'inplace': False,
'correction_code': 'v',
'overwrite_vbm': [],
'overwrite_cbm': [],
'command': ['mpirun', '-np', '6', '../vasp']
"exchange_correlation_code": "wi",
"calculation_code": "ae",
"max_iterations": 200,
"potfiles_folder": "minushalf_potfiles",
"amplitude": 1.0,
"valence_cut_guess": None,
"conduction_cut_guess": None,
"tolerance": 0.01,
"fractional_valence_treshold": 10,
"fractional_conduction_treshold": 9,
"inplace": False,
"correction_code": "v",
"overwrite_vbm": [],
"overwrite_cbm": [],
"command": ["mpirun", "-np", "6", "../vasp"]
}
assert file.get_command() == ['mpirun', '-np', '6', '../vasp']
assert file.get_command() == ["mpirun", "-np", "6", "../vasp"]
assert file.get_software_name() == Softwares.vasp.value
assert file.get_correction_code() == "v"
assert file.get_overwrite_cbm() == []
assert file.get_overwrite_vbm() == []
assert params == expected_params


Expand All @@ -88,24 +94,27 @@ def test_minushalf_filled_out(file_path):
**file.get_software_configurations_params()
}
expected_params = {
'exchange_correlation_code': 'wi',
'calculation_code': 'ae',
'max_iterations': 200,
'potfiles_folder': '../potcar',
'amplitude': 3.0,
'valence_cut_guess': 2.0,
'conduction_cut_guess': 1.0,
'tolerance': 0.001,
'fractional_valence_treshold': 15,
'fractional_conduction_treshold': 23,
'inplace': True,
'correction_code': 'vf',
'overwrite_vbm': [1, 3],
'overwrite_cbm': [1, 4],
'command': ['mpirun', '-np', '6', '../vasp']
"exchange_correlation_code": "wi",
"calculation_code": "ae",
"max_iterations": 200,
"potfiles_folder": "../potcar",
"amplitude": 3.0,
"valence_cut_guess": 2.0,
"conduction_cut_guess": 1.0,
"tolerance": 0.001,
"fractional_valence_treshold": 15,
"fractional_conduction_treshold": 23,
"inplace": True,
"correction_code": "vf",
"overwrite_vbm": [1, 3],
"overwrite_cbm": [1, 4],
"command": ["mpirun", "-np", "6", "../vasp"]
}
assert file.get_command() == ['mpirun', '-np', '6', '../vasp']
assert file.get_command() == ["mpirun", "-np", "6", "../vasp"]
assert file.get_software_name() == Softwares.vasp.value
assert file.get_correction_code() == "vf"
assert file.get_overwrite_cbm() == [1,4]
assert file.get_overwrite_vbm() == [1,3]
assert params == expected_params


Expand Down

0 comments on commit 6b0b204

Please sign in to comment.