Skip to content

Commit

Permalink
[test]: Adding tests to get projection handler
Browse files Browse the repository at this point in the history
  • Loading branch information
hentt30 committed Aug 8, 2021
1 parent 6b0b204 commit 17b110c
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 29 deletions.
1 change: 1 addition & 0 deletions minushalf/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
from .read_minushalf_yaml import ReadMinushalf
from .create_software_module import CreateSoftwareModule
from .run_calculations import RunCalculations
from .get_band_projections import GetBandProjections
Original file line number Diff line number Diff line change
Expand Up @@ -4,51 +4,35 @@
import collections
import loguru
import pandas as pd
from minushalf.softwares import Vasp
from minushalf.data import Softwares
from minushalf.handlers import BaseHandler
from minushalf.interfaces import SoftwaresAbstractFactory, BandProjectionFile, MinushalfYaml
from minushalf.utils import (BandProjectionFile, band_structure, projection_to_df)
from minushalf.interfaces import BandProjectionFile, MinushalfYaml
from minushalf.utils import (BandStructure, projection_to_df)


class GetBandProjections(BaseHandler):
"""
Uses the software module to extract the character of the bands (VBM and CBM)
"""
def _get_band_structure(
self, software_module: SoftwaresAbstractFactory) -> BandProjectionFile:
"""
Return band structure class
"""
eigenvalues = software_module.get_eigenvalues()
fermi_energy = software_module.get_fermi_energy()
atoms_map = software_module.get_atoms_map()
num_bands = software_module.get_number_of_bands()
band_projection_file = software_module.get_band_projection_class()

return BandProjectionFile(eigenvalues, fermi_energy, atoms_map, num_bands,
band_projection_file)

def _get_vbm_projection(self,
band_structure: BandProjectionFile) -> pd.DataFrame:
def _get_vbm_projection(
self, band_structure: BandProjectionFile) -> pd.DataFrame:
"""
Returns vbm projection
"""
vbm_projection = band_structure.vbm_projection()
normalized_df = projection_to_df(vbm_projection)
return normalized_df

def _get_cbm_projection(self,
band_structure: BandProjectionFile) -> pd.DataFrame:
def _get_cbm_projection(
self, band_structure: BandProjectionFile) -> pd.DataFrame:
"""
Returns cbm projection
"""
cbm_projection = band_structure.cbm_projection()
normalized_df = projection_to_df(cbm_projection)
return normalized_df

def _get_band_projection(self, band_structure: BandProjectionFile, kpoint: int,
band: int) -> pd.DataFrame:
def _get_band_projection(self, band_structure: BandProjectionFile,
kpoint: int, band: int) -> pd.DataFrame:
"""
Returns band projection
"""
Expand All @@ -66,7 +50,7 @@ def _select_vbm(self, minushalf_yaml: MinushalfYaml,
if overwrite_vbm:
return self._get_band_projection(band_structure, *overwrite_vbm)

return self._get_vbm_projection()
return self._get_vbm_projection(band_structure)

def _select_cbm(self, minushalf_yaml: MinushalfYaml,
band_structure: BandProjectionFile) -> pd.DataFrame:
Expand All @@ -78,7 +62,7 @@ def _select_cbm(self, minushalf_yaml: MinushalfYaml,
if overwrite_cbm:
return self._get_band_projection(band_structure, *overwrite_cbm)

return self._get_cbm_projection()
return self._get_cbm_projection(band_structure)

def _is_valence_correction(self, correction_code: str) -> bool:
"""
Expand Down Expand Up @@ -118,8 +102,7 @@ def action(self, request: any) -> any:
loguru.logger.info("Extracting band projections")
software_module, minushalf_yaml = request["software_module"], request[
"minushalf_yaml"]
band_structure = self._get_band_structure(software_module)

band_structure = BandStructure.create(software_module)
## Add projections to object
request["projections"] = self._get_projections(minushalf_yaml,
band_structure)
Expand Down
25 changes: 24 additions & 1 deletion minushalf/utils/band_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from math import inf
from collections import defaultdict
import numpy as np
from minushalf.interfaces import BandProjectionFile
from minushalf.interfaces import BandProjectionFile, SoftwaresAbstractFactory


class BandStructure():
Expand Down Expand Up @@ -186,3 +186,26 @@ def band_gap(self) -> dict:
cbm_eigenval - vbm_eigenval
}
return gap_report

@staticmethod
def create(software_module: SoftwaresAbstractFactory,
base_path: str = '.'):
"""
Create band structure class from ab inition results
Args:
software_module (SoftwaresAbstractFactory): Holds the results of first principles
output calculations
base_path (str): Path to first principles output files
Returns:
band_strucure (BandStructure): Class with band structure informations
"""
eigenvalues = software_module.get_eigenvalues(base_path=base_path)
fermi_energy = software_module.get_fermi_energy(base_path=base_path)
atoms_map = software_module.get_atoms_map(base_path=base_path)
num_bands = software_module.get_number_of_bands(base_path=base_path)
band_projection_file = software_module.get_band_projection_class(
base_path=base_path)

return BandStructure(eigenvalues, fermi_energy, atoms_map, num_bands,
band_projection_file)
89 changes: 89 additions & 0 deletions tests/unit/handlers/test_get_band_projections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""
Test get projections handler
"""
from unittest.mock import patch
from minushalf.io import MinushalfYaml
from minushalf.softwares import Vasp
from minushalf.handlers import GetBandProjections
from minushalf.utils import BandStructure


def test_get_valence_projections(file_path):
"""
Get projections of GaN 3d
"""
## Variables
minushalf_yaml = MinushalfYaml.from_file()
minushalf_yaml.correction.correction_code = "vc"
software_module = Vasp()
request = {
"minushalf_yaml": minushalf_yaml,
"software_module": software_module,
}
band_structure = BandStructure.create(software_module,
file_path("/gan-3d/"))

## Mock method in class BandStructure
with patch('minushalf.utils.BandStructure.create') as mock:
mock.return_value = band_structure
handler = GetBandProjections()
response = handler.action(request)

## Assertions
assert response["projections"]["valence"]["d"]["Ga"] == 16
assert response["projections"]["conduction"]["s"]["N"] == 56
assert "projections" in response


def test_get_conduction_and_valence_projections(file_path):
"""
Get projections of GaN 3d
"""
## Variables
minushalf_yaml = MinushalfYaml.from_file()
software_module = Vasp()
request = {
"minushalf_yaml": minushalf_yaml,
"software_module": software_module,
}
band_structure = BandStructure.create(software_module,
file_path("/gan-3d/"))

## Mock method in class BandStructure
with patch('minushalf.utils.BandStructure.create') as mock:
mock.return_value = band_structure
handler = GetBandProjections()
response = handler.action(request)

## Assertions
assert response["projections"]["valence"]["d"]["Ga"] == 16
assert response["projections"]["valence"]["p"]["N"] == 78
assert "projections" in response


def test_get_projections_with_ovewrite(file_path):
"""
Get overwrited projections of GaN 3d
"""
## Variables
minushalf_yaml = MinushalfYaml.from_file(
file_path("/minushalf_yaml/minushalf_filled_out.yaml"))
software_module = Vasp()
request = {
"minushalf_yaml": minushalf_yaml,
"software_module": software_module,
}
band_structure = BandStructure.create(software_module,
file_path("/gan-3d/"))

## Mock method in class BandStructure
with patch('minushalf.utils.BandStructure.create') as mock:
mock.return_value = band_structure
handler = GetBandProjections()
response = handler.action(request)
print(response["projections"])
## Assertions
assert response["projections"]["valence"]["d"]["Ga"] == 98
assert "projections" in response
assert "valence" in response["projections"]
assert ("conduction" in response["projections"]) == False
Empty file.
31 changes: 31 additions & 0 deletions tests/unit/utils/test_band_structure_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
from minushalf.utils import BandStructure
from minushalf.softwares.vasp import Procar, Vasprun, Eigenvalues
from minushalf.softwares import Vasp


def test_is_metal_gan_3d(file_path):
Expand All @@ -28,6 +29,19 @@ def test_is_metal_gan_3d(file_path):
assert band_structure.is_metal() is False


def test_create_bandstrutcture_gan_3d(file_path):
"""
Tests create method for GaN 3d
"""
software_module = Vasp()

band_structure = BandStructure.create(software_module,
file_path("/gan-3d/"))

assert isinstance(band_structure, BandStructure)
assert band_structure.is_metal() is False


def test_band_gap_gan_3d(file_path):
"""
Tests GaN 3d band gap .
Expand All @@ -49,6 +63,7 @@ def test_band_gap_gan_3d(file_path):
assert np.isclose(band_structure.band_gap()["gap"], 1.5380389999999995)



def test_vbm_index_gan_3d(file_path):
"""
Test if the index of the valence maximum band is correct
Expand Down Expand Up @@ -255,6 +270,22 @@ def test_vbm_index_bn_2d(file_path):
assert vbm_index[0] == kpoint_vbm
assert vbm_index[1] == band_vbm

def test_create_bandstrutcture_bn_2d(file_path):
"""
Tests create method for BN 2d
"""
softare_module = Vasp()
kpoint_vbm = 24
band_vbm = 4

band_structure = BandStructure.create(softare_module,
file_path("/bn-2d/"))

vbm_index = band_structure.vbm_index()
assert isinstance(band_structure, BandStructure)
assert vbm_index[0] == kpoint_vbm
assert vbm_index[1] == band_vbm


def test_cbm_index_bn_2d(file_path):
"""
Expand Down

0 comments on commit 17b110c

Please sign in to comment.