Skip to content

Commit

Permalink
shifted get_workflow_results() to class method
Browse files Browse the repository at this point in the history
  • Loading branch information
jrudz committed Dec 15, 2023
1 parent da9a36a commit c616845
Showing 1 changed file with 100 additions and 93 deletions.
193 changes: 100 additions & 93 deletions atomisticparsers/h5md/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,111 +609,114 @@ def parse_method(self):

# Get the interactions
connectivity = self.h5md_groups.get('connectivity')
if not connectivity:
return
if connectivity:
atom_labels = self.atom_parameters.get('label')
interaction_keys = ['bonds', 'angles', 'dihedrals', 'impropers']
interactions_by_type = []
for interaction_key in interaction_keys:
interaction_list = self._data_parser.get_value(connectivity, interaction_key)
if interaction_list is None:
continue
elif isinstance(interaction_list, h5py.Group):
self.logger.warning('Time-dependent interactions currently not supported. These values will not be stored')
continue

atom_labels = self.atom_parameters.get('label')
interaction_keys = ['bonds', 'angles', 'dihedrals', 'impropers']
interactions_by_type = []
for interaction_key in interaction_keys:
interaction_list = self._data_parser.get_value(connectivity, interaction_key)
if interaction_list is None:
continue
elif isinstance(interaction_list, h5py.Group):
self.logger.warning('Time-dependent interactions currently not supported. These values will not be stored')
continue

interaction_type_dict = {
'type': interaction_key,
'n_interactions': len(interaction_list),
'n_atoms': len(interaction_list[0]),
'atom_indices': interaction_list,
'atom_labels': np.array(atom_labels)[interaction_list] if atom_labels is not None else None
}
interactions_by_type.append(interaction_type_dict)
self.parse_interactions_by_type(interactions_by_type, sec_model)
interaction_type_dict = {
'type': interaction_key,
'n_interactions': len(interaction_list),
'n_atoms': len(interaction_list[0]),
'atom_indices': interaction_list,
'atom_labels': np.array(atom_labels)[interaction_list] if atom_labels is not None else None
}
interactions_by_type.append(interaction_type_dict)
self.parse_interactions_by_type(interactions_by_type, sec_model)

# Get the force calculation parameters
force_calculation_parameters = self.parameter_info.get('force_calculations')
if force_calculation_parameters is None:
return
if force_calculation_parameters:
sec_force_calculations = sec_force_field.m_create(ForceCalculations)
sec_neighbor_searching = sec_force_calculations.m_create(NeighborSearching)

for key, val in force_calculation_parameters.items():
if not isinstance(val, dict):
# key = self.resolve_key(ForceCalculations, key, val)
if self.is_valid_key_val(ForceCalculations, key, val):
sec_force_calculations.m_set(sec_force_calculations.m_get_quantity_definition(key), val)
else:
units = val.units if hasattr(val, 'units') else None
val = val.magnitude if units is not None else val
sec_force_calculations.x_h5md_parameters.append(ParamEntry(kind=key, value=val, unit=units))
else:
if key == 'neighbor_searching':
for neigh_key, neigh_val in val.items():
# neigh_key = self.resolve_key(NeighborSearching, neigh_key, neigh_val)
if self.is_valid_key_val(NeighborSearching, neigh_key, neigh_val):
sec_neighbor_searching.m_set(sec_neighbor_searching.m_get_quantity_definition(neigh_key), neigh_val)
else:
units = val.units if hasattr(val, 'units') else None
val = val.magnitude if units is not None else val
sec_neighbor_searching.x_h5md_parameters.append(ParamEntry(kind=key, value=val, unit=units))
else:
self.logger.warning('Unknown parameters in force calculations section. These will not be stored.')

sec_force_calculations = sec_force_field.m_create(ForceCalculations)
sec_neighbor_searching = sec_force_calculations.m_create(NeighborSearching)
def get_workflow_properties_dict(
self, observables: dict, property_type_key=None, property_type_value_key=None,
properties_known={}, property_keys_list=[], property_value_keys_list=[]):

for key, val in force_calculation_parameters.items():
if not isinstance(val, dict):
# key = self.resolve_key(ForceCalculations, key, val)
if self.is_valid_key_val(ForceCalculations, key, val):
sec_force_calculations.m_set(sec_force_calculations.m_get_quantity_definition(key), val)
else:
units = val.units if hasattr(val, 'units') else None
val = val.magnitude if units is not None else val
sec_force_calculations.x_h5md_parameters.append(ParamEntry(kind=key, value=val, unit=units))
def populate_property_dict(property_dict, val_name, val, flag_known_property=False):
if val is None:
return
value_unit = val.units if hasattr(val, 'units') else None
value_magnitude = val.magnitude if hasattr(val, 'units') else val
if flag_known_property:
property_dict[val_name] = value_magnitude * value_unit if value_unit else value_magnitude
else:
if key == 'neighbor_searching':
for neigh_key, neigh_val in val.items():
# neigh_key = self.resolve_key(NeighborSearching, neigh_key, neigh_val)
if self.is_valid_key_val(NeighborSearching, neigh_key, neigh_val):
sec_neighbor_searching.m_set(sec_neighbor_searching.m_get_quantity_definition(neigh_key), neigh_val)
else:
units = val.units if hasattr(val, 'units') else None
val = val.magnitude if units is not None else val
sec_neighbor_searching.x_h5md_parameters.append(ParamEntry(kind=key, value=val, unit=units))
else:
self.logger.warning('Unknown parameters in force calculations section. These will not be stored.')
property_dict[f'{val_name}_unit'] = str(value_unit) if value_unit else None
property_dict[f'{val_name}_magnitude'] = value_magnitude

workflow_properties_dict = {}
for observable_type, observable_dict in observables.items():
flag_known_property = False
if observable_type in properties_known:
property_type_key = observable_type
property_type_value_key = properties_known[observable_type]
flag_known_property = True
property_dict = {property_type_value_key: []}
property_dict['label'] = observable_type
for key, observable in observable_dict.items():
property_values_dict = {'label': key}
for quant_name, val in observable.items():
if quant_name == 'val':
continue
if quant_name == 'bins':
continue
if quant_name in property_keys_list:
property_dict[quant_name] = val
if quant_name in property_value_keys_list:
property_values_dict[quant_name] = val
# TODO Still need to add custom values here.

val = observable.get('value')
populate_property_dict(property_values_dict, 'value', val, flag_known_property=flag_known_property)
bins = observable.get('bins')
populate_property_dict(property_values_dict, 'bins', bins, flag_known_property=flag_known_property)
property_dict[property_type_value_key].append(property_values_dict)

if workflow_properties_dict.get(property_type_key):
workflow_properties_dict[property_type_key].append(property_dict)
else:
workflow_properties_dict[property_type_key] = [property_dict]

return workflow_properties_dict

def parse_workflow(self):

workflow_parameters = self.parameter_info.get('workflow').get('molecular_dynamics')
# TODO should store parameters that do not match the enum vals as x_h5MD params, not sure how with MDParser??
# TODO should store parameters that do not match the enum vals as x_h5MD params,
# not sure how with MDParser??
if workflow_parameters is None:
return

def get_workflow_results(property_type_dict, observables, workflow_results):

def populate_property_dict(property_dict, val_name, val, flag_known_property=False):
if val is None:
return
value_unit = val.units if hasattr(val, 'units') else None
value_magnitude = val.magnitude if hasattr(val, 'units') else val
if flag_known_property:
property_dict[val_name] = value_magnitude * value_unit if value_unit else value_magnitude
else:
property_dict[f'{val_name}_unit'] = str(value_unit) if value_unit else None
property_dict[f'{val_name}_magnitude'] = value_magnitude

property_key = property_type_dict['property_type_key']
property_value_key = property_type_dict['property_type_value_key']
for observable_type, observable_dict in observables.items():
flag_known_property = False
if observable_type in property_type_dict['properties_known']:
property_key = observable_type
property_value_key = property_type_dict['properties_known'][observable_type]
flag_known_property = True
workflow_results[property_key] = []
property_dict = {property_value_key: []}
property_dict['label'] = observable_type
for key, observable in observable_dict.items():
property_values_dict = {'label': key}
for quant_name, val in observable.items():
if quant_name == 'val':
continue
if quant_name == 'bins':
continue
if quant_name in property_type_dict['property_keys_list']:
property_dict[quant_name] = val
if quant_name in property_type_dict['property_value_keys_list']:
property_values_dict[quant_name] = val
# TODO Still need to add custom values here.

val = observable.get('value')
populate_property_dict(property_values_dict, 'value', val, flag_known_property=flag_known_property)
bins = observable.get('bins')
populate_property_dict(property_values_dict, 'bins', bins, flag_known_property=flag_known_property)
property_dict[property_value_key].append(property_values_dict)
workflow_results[property_key].append(property_dict)

workflow_results = {}
ensemble_average_observables = self.observable_info.get('ensemble_average')
ensemble_property_dict = {
Expand All @@ -723,7 +726,9 @@ def populate_property_dict(property_dict, val_name, val, flag_known_property=Fal
'property_keys_list': EnsembleProperty.m_def.all_quantities.keys(),
'property_value_keys_list': EnsemblePropertyValues.m_def.all_quantities.keys()
}
get_workflow_results(ensemble_property_dict, ensemble_average_observables, workflow_results)
workflow_results.update(
self.get_workflow_properties_dict(ensemble_average_observables, **ensemble_property_dict)
)
correlation_function_observables = self.observable_info.get('correlation_function')
correlation_function_dict = {
'property_type_key': 'correlation_functions',
Expand All @@ -732,7 +737,9 @@ def populate_property_dict(property_dict, val_name, val, flag_known_property=Fal
'property_keys_list': CorrelationFunction.m_def.all_quantities.keys(),
'property_value_keys_list': CorrelationFunctionValues.m_def.all_quantities.keys()
}
get_workflow_results(correlation_function_dict, correlation_function_observables, workflow_results)
workflow_results.update(
self.get_workflow_properties_dict(correlation_function_observables, **correlation_function_dict)
)
self.parse_md_workflow(dict(method=workflow_parameters, results=workflow_results))

def init_parser(self):
Expand Down

0 comments on commit c616845

Please sign in to comment.