From 1040c5913d00fb7ad854173c190d94c9fffa3205 Mon Sep 17 00:00:00 2001 From: jrudz Date: Fri, 15 Dec 2023 15:34:58 +0100 Subject: [PATCH] another round of review fixes --- atomisticparsers/h5md/metainfo/h5md.py | 2 +- atomisticparsers/h5md/nomad_plugin.yaml | 5 +-- atomisticparsers/h5md/parser.py | 48 ++++++++++++++++--------- 3 files changed, 36 insertions(+), 19 deletions(-) diff --git a/atomisticparsers/h5md/metainfo/h5md.py b/atomisticparsers/h5md/metainfo/h5md.py index 72cdbd8f..361ddfd8 100644 --- a/atomisticparsers/h5md/metainfo/h5md.py +++ b/atomisticparsers/h5md/metainfo/h5md.py @@ -187,4 +187,4 @@ class Run(simulation.run.Run): x_h5md_author = SubSection(sub_section=Author.m_def) - x_h5md_creator = SubSection(sub_section=simulation.run.Program.m_def) \ No newline at end of file + x_h5md_creator = SubSection(sub_section=simulation.run.Program.m_def) diff --git a/atomisticparsers/h5md/nomad_plugin.yaml b/atomisticparsers/h5md/nomad_plugin.yaml index 62cd8ca1..cb97a11a 100644 --- a/atomisticparsers/h5md/nomad_plugin.yaml +++ b/atomisticparsers/h5md/nomad_plugin.yaml @@ -1,12 +1,13 @@ code_category: Atomistic code -code_homepage: http://www.??.org/ +# code_homepage: http://www.??.org/ +# TODO is this the correct category code_name: H5MD metadata: codeCategory: Atomistic code codeLabel: H5MD codeLabelStyle: All in capitals codeName: h5md - codeUrl: http://www.??.org/ + # codeUrl: http://www.??.org/ parserDirName: dependencies/parsers/atomistic/atomisticparsers/h5md/ parserGitUrl: https://github.com/nomad-coe/atomistic-parsers.git parserSpecific: '' diff --git a/atomisticparsers/h5md/parser.py b/atomisticparsers/h5md/parser.py index a4559390..d7ebbb79 100644 --- a/atomisticparsers/h5md/parser.py +++ b/atomisticparsers/h5md/parser.py @@ -21,6 +21,9 @@ import logging import h5py +from typing import List, Dict, Tuple, Any, Union, Iterable, cast, Callable, TYPE_CHECKING +from h5py import Group + from nomad.datamodel import EntryArchive from nomad.metainfo.util import MEnum from nomad.parsing.file_parser import FileParser @@ -124,6 +127,18 @@ def get_value(self, group, path: str, default=None): def parse(self, quantity_key: str = None, **kwargs): pass + # def parse(self, key, **kwargs): + # source = kwargs.get('source', self.filehdf5) + # path = kwargs.get('path') + # attribute = kwargs.get('attribute') + # if attribute is not None: + # value = self.hdf5_attr_getter(source, path, attribute) + # else: + # value = self.hdf5_getter(source, path) + # # i do not know how to construct the full path here relative to root + # full_path = f'...{path}.{attribute}' + # self._results[full_path] = value + class H5MDParser(MDParser): def __init__(self): @@ -326,25 +341,26 @@ def observable_info(self): if observables_group is None: return self._observable_info - def get_observable_paths(observable_group, current_path, paths): + def get_observable_paths(observable_group: Dict, current_path: str) -> List: + paths = [] for obs_key in observable_group.keys(): path = obs_key + '.' observable = self._data_parser.get_value(observable_group, obs_key) observable_type = self._data_parser.get_value(observable_group, obs_key).attrs.get('type') if not observable_type: - paths = get_observable_paths(observable, current_path + path, paths) + paths.extend(get_observable_paths(observable, f'{current_path}{path}')) else: paths.append(current_path + path[:-1]) return paths - observable_paths = get_observable_paths(observables_group, current_path='', paths=[]) + observable_paths = get_observable_paths(observables_group, current_path='') for path in observable_paths: observable = self._data_parser.get_value(observables_group, path) observable_type = self._data_parser.get_value(observables_group, path).attrs.get('type') observable_name = path.split('.')[0] - observable_label = '-'.join(path.split('.')[1:]) if len(path.split('.')) > 1 else '' + observable_label = '-'.join(path.split('.')[1:]) if observable_name not in self._observable_info[observable_type].keys(): self._observable_info[observable_type][observable_name] = {} self._observable_info[observable_type][observable_name][observable_label] = {} @@ -356,10 +372,11 @@ def get_observable_paths(observable_group, current_path, paths): self._observable_info[observable_type][observable_name][observable_label][key] = observable_attribute return self._observable_info - def get_atomsgroup_fromh5md(self, nomad_sec, h5md_sec_particlesgroup): + def parse_atomsgroup(self, nomad_sec, h5md_sec_particlesgroup: Group): for i_key, key in enumerate(h5md_sec_particlesgroup.keys()): particles_group = {group_key: self._data_parser.get_value(h5md_sec_particlesgroup[key], group_key) for group_key in h5md_sec_particlesgroup[key].keys()} - sec_atomsgroup = nomad_sec.m_create(AtomsGroup) + sec_atomsgroup = AtomsGroup() + nomad_sec.atoms_group.append(sec_atomsgroup) sec_atomsgroup.type = particles_group.pop('type', None) sec_atomsgroup.index = i_key sec_atomsgroup.atom_indices = particles_group.pop('indices', None) @@ -376,9 +393,9 @@ def get_atomsgroup_fromh5md(self, nomad_sec, h5md_sec_particlesgroup): sec_atomsgroup.x_h5md_parameters.append(ParamEntry(kind=particles_group_key, value=val, unit=units)) # get the next atomsgroup if particles_subgroup: - self.get_atomsgroup_fromh5md(sec_atomsgroup, particles_subgroup) + self.parse_atomsgroup(sec_atomsgroup, particles_subgroup) - def is_valid_key_val(self, metainfo_class, key, val): + def is_valid_key_val(self, metainfo_class, key: str, val) -> bool: if hasattr(metainfo_class, key): quant_type = getattr(metainfo_class, key).get('type') is_MEnum = isinstance(quant_type, MEnum) if quant_type else False @@ -397,7 +414,7 @@ def parameter_info(self): if parameters_group is None: return self._parameter_info - def get_parameters(parameter_group): + def get_parameters(parameter_group: Group) -> Dict: param_dict = {} for key, val in parameter_group.items(): if isinstance(val, h5py.Group): @@ -500,7 +517,7 @@ def format_times(times): if key == 'forces': data[key] = dict(total=dict(value=val[system_index])) else: - if key in BaseCalculation.__dict__.keys(): + if hasattr(BaseCalculation, key): data[key] = val[system_index] else: unit = None @@ -516,7 +533,7 @@ def format_times(times): if obs_index: val = observable.get('value', [None] * (obs_index + 1))[obs_index] if 'energ' in observable_type: # TODO check for energies or energy when matching name - if key in Energy.__dict__.keys(): + if hasattr(Energy, key): data['energy'][key] = dict(value=val) else: data_h5md['x_h5md_energy_contributions'].append(EnergyEntry(kind=map_key, value=val)) @@ -526,7 +543,7 @@ def format_times(times): else: key = map_key - if key in BaseCalculation.__dict__.keys(): + if hasattr(BaseCalculation, key): data[key] = val else: unit = None @@ -559,6 +576,7 @@ def parse_system(self): n_frames = len(system_info.get('times', [])) self._system_time_map = {} self._system_step_map = {} + for frame in range(n_frames): # if (n % self.frame_rate) > 0: # continue @@ -588,7 +606,7 @@ def parse_system(self): if frame == 0: # TODO extend to time-dependent topologies topology = self._data_parser.get_value(connectivity, 'particles_group', None) if topology: - self.get_atomsgroup_fromh5md(sec_run.system[frame], topology) + self.parse_atomsgroup(sec_run.system[frame], topology) def parse_method(self): @@ -639,7 +657,6 @@ def parse_method(self): for key, val in force_calculation_parameters.items(): if not isinstance(val, dict): - # key = self.resolve_key(ForceCalculations, key, val) if self.is_valid_key_val(ForceCalculations, key, val): sec_force_calculations.m_set(sec_force_calculations.m_get_quantity_definition(key), val) else: @@ -649,7 +666,6 @@ def parse_method(self): else: if key == 'neighbor_searching': for neigh_key, neigh_val in val.items(): - # neigh_key = self.resolve_key(NeighborSearching, neigh_key, neigh_val) if self.is_valid_key_val(NeighborSearching, neigh_key, neigh_val): sec_neighbor_searching.m_set(sec_neighbor_searching.m_get_quantity_definition(neigh_key), neigh_val) else: @@ -660,7 +676,7 @@ def parse_method(self): self.logger.warning('Unknown parameters in force calculations section. These will not be stored.') def get_workflow_properties_dict( - self, observables: dict, property_type_key=None, property_type_value_key=None, + self, observables: Dict, property_type_key=None, property_type_value_key=None, properties_known={}, property_keys_list=[], property_value_keys_list=[]): def populate_property_dict(property_dict, val_name, val, flag_known_property=False):