Revisiting the recursive processing of `analysis_schema` models

In [80]:
# A small analysis-schema document: one model-wide dataset with a deliberately
# fake path ("not/a/real/file"), plus a ProjectionPlot that lists two more
# datasets inline.
data_dict = {
  "$schema": "../analysis_schema/yt_analysis_schema.json",
  "Data": [{"FileName": "not/a/real/file", "DatasetName": "blah"}],  
  "Plot": [
    {
      "ProjectionPlot": {
        "Dataset": [
          {
            "FileName": "../../Data/IsolatedGalaxy/galaxy0030/galaxy0030",
            "DatasetName": "IG"
          },
          {
            "FileName": "../../Data/enzo_tiny_cosmology/DD0000/DD0000",
            "DatasetName": "Enzo"
          }
        ],
        "Axis":"y",
        "FieldNames": {
          "field": "density",
          "field_type": "gas"
        },
        "WeightFieldName": {
          "field": "temperature",
          "field_type": "gas"
        }
      }
    }
  ]
}


In [81]:
from analysis_schema import ytModel

In [82]:
# parse/validate the raw dictionary into a ytModel instance
yt_model = ytModel.parse_obj(data_dict)

In [83]:
# display the parsed model; schema keys show up under their yt-side names
# (e.g. FileName -> fn, Axis -> normal, WeightFieldName -> weight_field)
yt_model

ytModel(Data=[Dataset(DatasetName='blah', fn=PosixPath('not/a/real/file'), comments=None)], Plot=[Visualizations(SlicePlot=None, ProjectionPlot=ProjectionPlot(ds=[Dataset(DatasetName='IG', fn=PosixPath('../../Data/IsolatedGalaxy/galaxy0030/galaxy0030'), comments=None), Dataset(DatasetName='Enzo', fn=PosixPath('../../Data/enzo_tiny_cosmology/DD0000/DD0000'), comments=None)], fields=FieldNames(field='density', field_type='gas', comments=None), normal='y', center=None, width=None, axes_unit=None, weight_field=FieldNames(field='temperature', field_type='gas', comments=None), max_level=None, origin=None, right_handed=None, fontsize=None, field_parameters=None, method=None, data_source=None, Comments=None), PhasePlot=None)])

First, let's do a parsing step that walks the model and adds each dataset reference (name and file path) to the data store. Here's the `DatasetFixture` class from `analysis_schema._data_store`, redefined here in case I want to modify it:

In [84]:
import yt 
import warnings

class DatasetFixture:
    """
    A class to hold all references and instantiated datasets.

    Two dictionaries are kept:

    * ``all_data`` maps a dataset name to the file it is loaded from.
    * ``_instantiated_datasets`` caches datasets already loaded with
      ``yt.load``, keyed by the same dataset name.

    ``_instantiate_data`` loads a dataset only if it isn't already cached.
    """

    def __init__(self):
        self.all_data = {}
        self._instantiated_datasets = {}

    def add_to_alldata(self, fn: str, dataset_name: str):
        """
        Register a dataset reference.

        Parameters
        ----------
        fn : str or path-like
            The file the dataset is loaded from (the pydantic models pass a
            ``pathlib.Path`` here).
        dataset_name : str or None
            The name to store the reference under. If None, ``fn`` itself
            is used as the name.

        Re-registering an existing name with a different ``fn`` emits a
        warning, overwrites the reference, and discards any dataset already
        loaded under that name so it is reloaded from the new file.
        """
        # NOTE(review): kept for backwards compatibility; this only remembers
        # the most recently registered file, not a per-dataset value.
        self.fn = fn
        if dataset_name is None:
            dataset_name = fn

        if dataset_name in self.all_data and self.all_data[dataset_name] != fn:
            warnings.warn(
                f"duplicate dataset name {dataset_name!r}: replacing "
                f"{self.all_data[dataset_name]!r} with {fn!r}",
                stacklevel=2,
            )
            # a cached dataset for the old file would otherwise be returned
            # by _instantiate_data for the new reference
            self._instantiated_datasets.pop(dataset_name, None)

        self.all_data[dataset_name] = fn

    def _instantiate_data(
        self,
        dataset_name: str,
    ):
        """
        Return the loaded dataset registered under ``dataset_name``.

        The first request loads it with ``yt.load`` from the file stored in
        ``all_data`` and caches it in ``_instantiated_datasets``; later
        requests return the cached object without reloading.

        Raises
        ------
        KeyError
            If ``dataset_name`` was never registered via ``add_to_alldata``.
        """
        if dataset_name in self._instantiated_datasets:
            return self._instantiated_datasets[dataset_name]
        ds = yt.load(self.all_data[dataset_name])
        self._instantiated_datasets[dataset_name] = ds
        return ds




In [150]:
# serialize the parsed model back to JSON (note: output uses the yt-side names)
print(yt_model.json(indent=4))

{
    "Data": [
        {
            "DatasetName": "blah",
            "fn": "not/a/real/file",
            "comments": null
        }
    ],
    "Plot": [
        {
            "SlicePlot": null,
            "ProjectionPlot": {
                "ds": [
                    {
                        "DatasetName": "IG",
                        "fn": "../../Data/IsolatedGalaxy/galaxy0030/galaxy0030",
                        "comments": null
                    },
                    {
                        "DatasetName": "Enzo",
                        "fn": "../../Data/enzo_tiny_cosmology/DD0000/DD0000",
                        "comments": null
                    }
                ],
                "fields": {
                    "field": "density",
                    "field_type": "gas",
                    "comments": null
                },
                "normal": "y",
                "center": null,
                "width": null,
                "axes_unit": null,
          

First, walk the whole model (the model-wide `Data` list and any datasets nested inside plots) and add every dataset to the fixture.

In [91]:
import pydantic
# `from analysis_schema import ytModel` does not bind the package name itself,
# but the annotations and isinstance checks below need it
import analysis_schema

dataset_fixture = DatasetFixture()


def add_to_fixture(pydantic_ds: analysis_schema.data_classes.Dataset):
    """Register one parsed ``Dataset`` model with the module-level ``dataset_fixture``."""
    fn = pydantic_ds.fn
    dname = pydantic_ds.DatasetName
    dataset_fixture.add_to_alldata(fn, dname)


def add_model_data_to_fixture(model_attr):
    """
    Recursively walk ``model_attr`` and register every ``Dataset`` it contains.

    ``model_attr`` may be a pydantic model (e.g. the top-level ``ytModel``),
    a list/tuple or dict whose items/values are walked in turn, or any other
    value, which is skipped because it cannot contain a dataset.
    """
    if isinstance(model_attr, analysis_schema.data_classes.Dataset):
        # found a dataset, add it to the data fixture
        add_to_fixture(model_attr)
    elif isinstance(model_attr, pydantic.BaseModel):
        # iterating a pydantic model yields (field_name, value) pairs
        for _, value in model_attr:
            add_model_data_to_fixture(value)
    elif isinstance(model_attr, (list, tuple)):
        for item in model_attr:
            add_model_data_to_fixture(item)
    elif isinstance(model_attr, dict):
        for value in model_attr.values():
            add_model_data_to_fixture(value)
    # scalars (str, numbers, None, paths, ...) cannot hold a dataset: nothing to do


In [92]:
# collect every Dataset found in yt_model into dataset_fixture
add_model_data_to_fixture(yt_model)

In [93]:
# the name -> file mapping gathered by the walk
dataset_fixture.all_data

{'blah': PosixPath('not/a/real/file'),
 'IG': PosixPath('../../Data/IsolatedGalaxy/galaxy0030/galaxy0030'),
 'Enzo': PosixPath('../../Data/enzo_tiny_cosmology/DD0000/DD0000')}

Now let's walk the model again and, for each dataset, gather the operations that use it.

In [None]:
flatted_by_ds = {}


def monitor_ds_usage(model_attr, ds_name, ds_file):
    """
    Recursively walk a pydantic model on behalf of one dataset.

    Work in progress: ``ds_name`` and ``ds_file`` identify the dataset whose
    usage is being tracked and are threaded through the recursion, but they
    are not used yet. For now this only registers each ``Dataset`` it finds
    with ``dataset_fixture``, the same as ``add_model_data_to_fixture``.
    """
    if isinstance(model_attr, analysis_schema.data_classes.Dataset):
        # found a dataset, add it to the data fixture
        add_to_fixture(model_attr)
    elif isinstance(model_attr, pydantic.BaseModel):
        # iterating a pydantic model yields (field_name, value) pairs
        for _, value in model_attr:
            monitor_ds_usage(value, ds_name, ds_file)
    elif isinstance(model_attr, (list, tuple)):
        for item in model_attr:
            monitor_ds_usage(item, ds_name, ds_file)


# DatasetFixture is not iterable itself; iterate its name -> file mapping.
# NOTE(review): the "blah" entry points at "not/a/real/file", so yt.load is
# expected to fail on it -- decide whether to skip or report such entries.
for ds_name, ds_file in dataset_fixture.all_data.items():
    ds = yt.load(ds_file)

    # now do all the operations that use that ds
    



In [148]:
import yt 
from inspect import getfullargspec

# registry of operations to take a pydantic model and return a yt-ready value


def return_field(field_name: analysis_schema.data_classes.FieldNames):
    """Convert a ``FieldNames`` model into yt's ``(field_type, field)`` tuple."""
    ftype = field_name.field_type
    fname = field_name.field
    return ftype, fname


def return_dataset(ds_model: analysis_schema.data_classes.Dataset):
    """Load the dataset described by a ``Dataset`` model via ``yt.load``."""
    path_to_load = ds_model.fn
    loaded = yt.load(path_to_load)
    return loaded


class EvaluatorRegistry:
    """
    Map pydantic model classes to callables that turn an instance of that
    class into a yt-ready value (e.g. a field tuple or a loaded dataset).
    """

    def __init__(self):
        self.registry = {}

    def add(self, pydantic_class, evaluator):
        """Register ``evaluator`` (a one-argument callable) for ``pydantic_class``."""
        self.registry[pydantic_class] = evaluator

    def evaluate(self, pydantic_model_instance):
        """
        Apply the evaluator registered for the instance's type.

        The lookup walks the instance's method resolution order, so an
        evaluator registered for a base class also handles its subclasses,
        and the most specific registered class wins.

        Returns None when no evaluator is registered for the type or any of
        its bases.
        """
        for cls in type(pydantic_model_instance).__mro__:
            if cls in self.registry:
                return self.registry[cls](pydantic_model_instance)
        return None
    
# module-level registry wired to return_field and return_dataset
registry = EvaluatorRegistry()
registry.add(analysis_schema.data_classes.FieldNames, return_field)
registry.add(analysis_schema.data_classes.Dataset, return_dataset)


# FieldNames -> ("enzo", "Density"); the value is not displayed because this
# is not the last expression in the cell
f = analysis_schema.data_classes.FieldNames.parse_obj({"field_type":"enzo", "field":"Density"})
registry.evaluate(f)


# Dataset -> yt.load(ds_model.fn). NOTE(review): this path is not the
# "../../Data/..." form used in data_dict; presumably yt resolves it on its
# own (e.g. via a configured data directory) -- the log output shows it loaded.
ds_model = analysis_schema.data_classes.Dataset.parse_obj({"FileName": "IsolatedGalaxy/galaxy0030/galaxy0030", "DatasetName":"isogal"})
ds = registry.evaluate(ds_model)






yt : [INFO     ] 2022-05-04 17:09:03,933 Parameters: current_time              = 0.0060000200028298
yt : [INFO     ] 2022-05-04 17:09:03,934 Parameters: domain_dimensions         = [32 32 32]
yt : [INFO     ] 2022-05-04 17:09:03,934 Parameters: domain_left_edge          = [0. 0. 0.]
yt : [INFO     ] 2022-05-04 17:09:03,935 Parameters: domain_right_edge         = [1. 1. 1.]
yt : [INFO     ] 2022-05-04 17:09:03,936 Parameters: cosmological_simulation   = 0


In [132]:
# NOTE(review): `YtMethod` is not defined in any active cell of this notebook
# (only a commented-out draft exists further down, and that draft names the
# attributes `callable`/`args`, not `yt_callable`/`yt_args`), so this cell
# depends on a definition from an earlier kernel session. It also rebinds `f`,
# which previously held a FieldNames model.
f = YtMethod("ProjectionPlot", yt_model.Plot[0].ProjectionPlot)
f.yt_callable

arg_value is a BaseModel
fields
field='density' field_type='gas' comments=None
<class 'analysis_schema.data_classes.FieldNames'>
arg_value is a BaseModel
weight_field
field='temperature' field_type='gas' comments=None
<class 'analysis_schema.data_classes.FieldNames'>


yt.visualization.plot_window.ProjectionPlot

In [122]:
# the (positional args, kwargs) pair assembled for the yt call
f.yt_args

([[Dataset(DatasetName='IG', fn=PosixPath('../../Data/IsolatedGalaxy/galaxy0030/galaxy0030'), comments=None),
   Dataset(DatasetName='Enzo', fn=PosixPath('../../Data/enzo_tiny_cosmology/DD0000/DD0000'), comments=None)],
  'y',
  FieldNames(field='density', field_type='gas', comments=None),
  'c',
  None,
  None,
  FieldNames(field='temperature', field_type='gas', comments=None),
  None,
  'center-window',
  True,
  18,
  None,
  None,
  'integrate',
  None,
  8.0,
  (800, 800),
  None],
 {})

In [124]:
# IPython help: show yt.ProjectionPlot's signature to compare with f.yt_args
yt.ProjectionPlot?

Init signature:
yt.ProjectionPlot(
    ds,
    axis,
    fields,
    center='c',
    width=None,
    axes_unit=None,
    weight_field=None,
    max_level=None,
    origin='center-window',
    right_handed=True,
    fontsize=18,
    field_parameters=None,

In [113]:
# the ProjectionPlot sub-model as a plain dict (nested models become dicts)
yt_model.Plot[0].ProjectionPlot.dict()

{'ds': [{'DatasetName': 'IG',
   'fn': PosixPath('../../Data/IsolatedGalaxy/galaxy0030/galaxy0030'),
   'comments': None},
  {'DatasetName': 'Enzo',
   'fn': PosixPath('../../Data/enzo_tiny_cosmology/DD0000/DD0000'),
   'comments': None}],
 'fields': {'field': 'density', 'field_type': 'gas', 'comments': None},
 'normal': 'y',
 'center': None,
 'width': None,
 'axes_unit': None,
 'weight_field': {'field': 'temperature',
  'field_type': 'gas',
  'comments': None},
 'max_level': None,
 'origin': None,
 'right_handed': None,
 'fontsize': None,
 'field_parameters': None,
 'method': None,
 'data_source': None,
 'Comments': None}

In [105]:
import inspect

In [110]:
# the whole first Plot entry; only its ProjectionPlot is set
yt_model.Plot[0].dict()

{'SlicePlot': None,
 'ProjectionPlot': {'ds': [{'DatasetName': 'IG',
    'fn': PosixPath('../../Data/IsolatedGalaxy/galaxy0030/galaxy0030'),
    'comments': None},
   {'DatasetName': 'Enzo',
    'fn': PosixPath('../../Data/enzo_tiny_cosmology/DD0000/DD0000'),
    'comments': None}],
  'fields': {'field': 'density', 'field_type': 'gas', 'comments': None},
  'normal': 'y',
  'center': None,
  'width': None,
  'axes_unit': None,
  'weight_field': {'field': 'temperature',
   'field_type': 'gas',
   'comments': None},
  'max_level': None,
  'origin': None,
  'right_handed': None,
  'fontsize': None,
  'field_parameters': None,
  'method': None,
  'data_source': None,
  'Comments': None},
 'PhasePlot': None}

In [None]:
    



# class YtMethod:
    
#     def __init__(self, method, pydantic_model):
        
#         self.method = method
#         self.callable = self._find_yt_callable()
#         self.args = self._set_args(pydantic_model)
        
#     def _sanitize_args(self):
#         clean_args = []
#         for arg in self.args:
#             if isinstance(arg, YtMethod):
#                 clean_args.append(arg.evaluate())
#             else:
#                 clean_args.append(arg)
#         return clean_args
            
        
#     def evaluate(self):
#         args = self._sanitize_args(self.args)        
#         return self.callable(*args)
        
#     def _find_yt_callable(self):
                
#         if hasattr(yt, self.method):
#             # check top level yt api
#             return getattr(yt, self.method)
#         elif self.method in yt_custom_registry:
#             return yt_custom_registry[self.method]
#         else:
#             raise RuntimeError(f"Could not find {self.method} in yt api")
            
#     def _set_args(self, pydantic_model):
        
#         # the list that we'll use to eventually call our function
#         the_args = []
#         # this method actually executes the yt code

#         func = self.callable

#         # now we get the arguments for the function:
#         # func_spec.args, which lists the named arguments and keyword arguments.
#         # ignoring vargs and kw-only args for now...
#         # see https://docs.python.org/3/library/inspect.html#inspect.getfullargspec
#         func_spec = getfullargspec(func)

#         # the argument position number at which we have default values (a little
#         # hacky, should be a better way to do this, and not sure how to scale it to
#         # include *args and **kwargs)
#         n_args = len(func_spec.args)  # number of arguments
#         if func_spec.defaults is None:
#             # no default args, make sure we never get there...
#             named_kw_start_at = n_args + 1
#         else:
#             # the position at which named keyword args start
#             named_kw_start_at = n_args - len(func_spec.defaults)

#         # loop over the call signature arguments and pull out values from our pydantic
#         # class. this is recursive!
#         for arg_i, arg in enumerate(func_spec.args):
#             # check if we've remapped the yt internal argument name for the schema
#             if arg in ["self", "cls"]:
#                 continue

#             # get the value for this argument. If it's not there, attempt to set default
#             # values for arguments needed for yt but not exposed in our pydantic class
#             try:
#                 arg_value = getattr(pydantic_model, arg)
#                 if arg_value is None:
#                     default_index = arg_i - named_kw_start_at
#                     arg_value = func_spec.defaults[default_index]
#             except AttributeError:
#                 if arg_i >= named_kw_start_at:
#                     # we are in the named keyword arguments, grab the default
#                     # the func_spec.defaults tuple 0 index is the first named
#                     # argument, so need to offset the arg_i counter
#                     default_index = arg_i - named_kw_start_at
#                     arg_value = func_spec.defaults[default_index]
#                 else:
#                     raise AttributeError(f"could not find {arg}")

#             if isinstance(arg_value, pydantic.BaseModel):
#                 YtMethod(
#                 print("arg_value is a BaseModel")
#                 print(arg)
#                 print(arg_value)
#                 print(type(arg_value))
            
#             # if _check_run(arg_value):
#             #     arg_value = arg_value._run()
                
#             the_args.append(arg_value)

#         # if this class has a list of known kwargs that we know will not be
#         # picked up by argspec, add them here. Not using inspect here because
#         # some of the yt visualization classes pass along kwargs, so we need
#         # to do this semi-manually for some classes and functions.
#         kwarg_dict = {}
#         if hasattr(self, "_known_kwargs"):
#             for kw in self._known_kwargs:
#                 arg_value = getattr(pydantic_model, kw, None)
#                 if isinstance(arg_value, pydantic.BaseModel):
#                     print(kw)
#                     print("arg_value is a BaseModel")                
#                 kwarg_dict[kw] = arg_value
                
#         return the_args, kwarg_dict

# #
# def model_to_yt_api(model_name, model_attr):