In [1]:
import analysis_schema
import yt
import inspect


In [2]:

class YTrunner:
    def __init__(self):
        self.registry = {}
        
        

In [4]:

analysis_schema.base_model.ytBaseModel.schema()

{'title': 'ytBaseModel',
 'description': 'A class to connect attributes and their values to yt operations and their\nkeyword arguments.\n\nArgs:\n    BaseModel ([type]): A pydantic basemodel in the form of a json schema\n\nRaises:\n    AttributeError: [description]\n\nReturns:\n    [list]: A list of yt classes to be run and then displayed',
 'type': 'object',
 'properties': {}}

In [6]:
analysis_schema.data_classes.Slice.schema()

{'title': 'Slice',
 'description': 'An axis-aligned 2-d slice data selection object',
 'type': 'object',
 'properties': {'axis': {'title': 'Axis',
   'anyOf': [{'type': 'integer'}, {'type': 'string'}]},
  'coord': {'title': 'Coord', 'type': 'number'}},
 'required': ['axis', 'coord']}

In [2]:
data_dict = {
  "$schema": "../analysis_schema/yt_analysis_schema.json",
  "Data": [{"FileName": "not/a/real/file", "DatasetName": "blah"}],  
  "Plot": [
    {
      "ProjectionPlot": {
        "Dataset": [
          {
            "FileName": "../../Data/IsolatedGalaxy/galaxy0030/galaxy0030",
            "DatasetName": "IG"
          },
          {
            "FileName": "../../Data/enzo_tiny_cosmology/DD0000/DD0000",
            "DatasetName": "Enzo"
          }
        ],
        "Axis":"y",
        "FieldNames": {
          "field": "density",
          "field_type": "gas"
        },
        "WeightFieldName": {
          "field": "temperature",
          "field_type": "gas"
        }
      }
    }
  ]
}


In [3]:
yt_model = analysis_schema.ytModel.parse_obj(data_dict)


In [4]:
yt_model

ytModel(Data=[Dataset(DatasetName='blah', fn=PosixPath('not/a/real/file'), comments=None)], Plot=[Visualizations(SlicePlot=None, ProjectionPlot=ProjectionPlot(ds=[Dataset(DatasetName='IG', fn=PosixPath('../../Data/IsolatedGalaxy/galaxy0030/galaxy0030'), comments=None), Dataset(DatasetName='Enzo', fn=PosixPath('../../Data/enzo_tiny_cosmology/DD0000/DD0000'), comments=None)], fields=FieldNames(field='density', field_type='gas', comments=None), normal='y', center=None, width=None, axes_unit=None, weight_field=FieldNames(field='temperature', field_type='gas', comments=None), max_level=None, origin=None, right_handed=None, fontsize=None, field_parameters=None, method=None, data_source=None, Comments=None), PhasePlot=None)])

In [37]:
p  = yt_model.Plot[0]

In [38]:
p.ProjectionPlot

ProjectionPlot(ds=[Dataset(DatasetName='IG', fn=PosixPath('../../Data/IsolatedGalaxy/galaxy0030/galaxy0030'), comments=None), Dataset(DatasetName='Enzo', fn=PosixPath('../../Data/enzo_tiny_cosmology/DD0000/DD0000'), comments=None)], fields=FieldNames(field='density', field_type='gas', comments=None), normal='y', center=None, width=None, axes_unit=None, weight_field=FieldNames(field='temperature', field_type='gas', comments=None), max_level=None, origin=None, right_handed=None, fontsize=None, field_parameters=None, method=None, data_source=None, Comments=None)

In [56]:
import abc

class YTRunner(abc.ABC):
    def __init__(self, pydantic_class):
        self.pydantic_class = pydantic_class        
        
    @abc.abstractmethod
    def process_pydantic(self, pydantic_instance):
        # retrieve the arguments required for run_function
        pass

    def run(self, pydantic_instance=None):    
        if not isinstance(pydantic_instance, self.pydantic_class):
            raise TypeError("provided pydantic model instance does not match")             
        return self.process_pydantic(pydantic_instance)
                             
            
# example of a pydantic model not explicitly in the yt api            
class YTFieldRunner(YTRunner): 
    
    def __init__(self):        
        super().__init__(analysis_schema.data_classes.FieldNames)    

    def process_pydantic(self, pydantic_instance):
        return (pydantic_instance.field, pydantic_instance.field_type)
    
    
    
yt_registry = {}
yt_registry[analysis_schema.data_classes.FieldNames] = YTFieldRunner()
    
# a generic, recursive runner
class GenericYtFunction(YTRunner):    
    
    def __init__(self, pydantic_class, yt_func_handle = None, _known_kwargs = None): 
        super().__init__(pydantic_class)
        
        if yt_func_handle is None: 
            yt_module = self.yt_handle_module()
            if yt_module:            
                yt_func_handle = getattr(yt_module, pydantic_class.__name__)
        self.yt_func_handle = yt_func_handle
        self.yt_func_spec = getfullargspec(self.yt_func_handle)
        
        if _known_kwargs is None:
            _known_kwargs = ()            
        self._known_kwargs = _known_kwargs
        
    def yt_handle_module(self, pydantic_class):
        if hasattr(yt, pydantic_class.__name__):
            return yt
        return None
    
    @staticmethod
    def expand_pydantic(arg_value):        
        if type(arg_value) in yt_registry:
            arg_value = yt_registry[type(arg_value)].process_pydantic(arg_value)
        else:                    
            result = GenericYtFunction(type(arg_value)).process_pydantic(arg_value)
            # if it worked, add it to the registry for next time
            yt_registry[type(arg_value)] = GenericYtFunction(type(arg_value))
            arg_value = result
        return arg_value
    
    @staticmethod
    def process_func_spec_args(self, pydantic_instance) -> tuple:
        
        # the argument position number at which we have default values (a little
        # hacky, should be a better way to do this, and not sure how to scale it to
        # include *args and **kwargs)
        n_args = len(func_spec.args)  # number of arguments
        if func_spec.defaults is None:
            # no default args, make sure we never get there...
            named_kw_start_at = n_args + 1
        else:
            # the position at which named keyword args start
            named_kw_start_at = n_args - len(func_spec.defaults)
            
        the_args = []
        for arg_i, arg in enumerate(self.yt_func_spec.args):           
            if arg in ["self", "cls"]:
                continue

            # get the value for this argument. If it's not there, attempt to set default
            # values for arguments needed for yt but not exposed in our pydantic class
            try:
                arg_value = getattr(pydantic_instance, arg)
                if arg_value is None:
                    default_index = arg_i - named_kw_start_at
                    arg_value = func_spec.defaults[default_index]
            except AttributeError:
                if arg_i >= named_kw_start_at:
                    # we are in the named keyword arguments, grab the default
                    # the func_spec.defaults tuple 0 index is the first named
                    # argument, so need to offset the arg_i counter
                    default_index = arg_i - named_kw_start_at
                    arg_value = func_spec.defaults[default_index]
                else:
                    raise AttributeError(f"could not find {arg}")

            if isinstance(arg_value, pydantic.BaseModel):
                arg_value = self.expand_pydantic(arg_value)                

            the_args.append(arg_value)
        return tuple(the_args)
        
    def process_known_kwargs(self, pydantic_instance):
        kwarg_dict = {}
        if hasattr(self, "_known_kwargs"):
            for kw in self._known_kwargs:
                arg_value = getattr(pydantic_model, kw, None)
                if isinstance(arg_value, pydantic.BaseModel):
                       arg_value = self.expand_pydantic(arg_value)
                kwarg_dict[kw] = arg_value
        return kwarg_dict
                
    def retrieve_args_and_kwargs(self, pydantic_instance):
        the_args = self.process_func_spec_args(pydantic_instance)
        kwargd_dict = self.process_known_kwargs(pydantic_instance)                        
        return the_args, kwarg_dict

    def process_pydantic(self, pydantic_instance=None):        
        args, kwargs = self.retrieve_args_and_kwargs(pydantic_instance)
        return self.yt_func_handle(*args, **kwargs)
                                                        

In [51]:
from analysis_schema.data_classes import ProjectionPlot, FieldNames, Sphere

In [53]:
Sphere.schema()

{'title': 'Sphere',
 'description': 'A sphere of points defined by a *center* and a *radius*.',
 'type': 'object',
 'properties': {'Center': {'title': 'Center',
   'type': 'array',
   'items': {'type': 'number'}},
  'Radius': {'title': 'Radius',
   'anyOf': [{'type': 'number'},
    {'type': 'array',
     'minItems': 2,
     'maxItems': 2,
     'items': [{'type': 'number'}, {'type': 'string'}]}]},
  'DataSet': {'$ref': '#/definitions/Dataset'}},
 'required': ['Center', 'Radius'],
 'definitions': {'Dataset': {'title': 'Dataset',
   'description': 'The dataset to load. Filename (fn) must be a string.\n\nRequired fields: Filename',
   'type': 'object',
   'properties': {'DatasetName': {'title': 'Datasetname', 'type': 'string'},
    'FileName': {'title': 'Filename',
     'description': 'A string containing the (path to the file and the) file name',
     'type': 'string',
     'format': 'path'},
    'comments': {'title': 'Comments', 'type': 'string'}},
   'required': ['DatasetName', 'FileNam

In [42]:
field_runner = YTFieldRunner()
field_runner.run(p.ProjectionPlot.fields)

('density', 'gas')

In [49]:
type(p).__name__

'Visualizations'

In [13]:



import inspect
import yt
import typing

In [9]:
inspect.getfullargspec(yt.SlicePlot)

FullArgSpec(args=['ds', 'normal', 'fields', 'axis'], varargs='args', varkw='kwargs', defaults=(None, None, None), kwonlyargs=[], kwonlydefaults=None, annotations={})

In [14]:
typing.get_type_hints(yt.SlicePlot)

{}

In [15]:
import pydantic

In [16]:
pydantic.create_model?

[0;31mSignature:[0m
[0mpydantic[0m[0;34m.[0m[0mcreate_model[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0m__model_name[0m[0;34m:[0m [0mstr[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0;34m*[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0m__config__[0m[0;34m:[0m [0mUnion[0m[0;34m[[0m[0mType[0m[0;34m[[0m[0mpydantic[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mBaseConfig[0m[0;34m][0m[0;34m,[0m [0mNoneType[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0m__base__[0m[0;34m:[0m [0mUnion[0m[0;34m[[0m[0mNoneType[0m[0;34m,[0m [0mType[0m[0;34m[[0m[0mForwardRef[0m[0;34m([0m[0;34m'Model'[0m[0;34m)[0m[0;34m][0m[0;34m,[0m [0mTuple[0m[0;34m[[0m[0mType[0m[0;34m[[0m[0mForwardRef[0m[0;34m([0m[0;34m'Model'[0m[0;34m)[0m[0;34m][0m[0;34m,[0m [0;34m...[0m[0;34m][0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0m__module__[0m[0;34m:[0m [0mstr[0m [0;34m=

In [50]:
fas = inspect.getfullargspec(yt.SlicePlot)
fas

FullArgSpec(args=['ds', 'normal', 'fields', 'axis'], varargs='args', varkw='kwargs', defaults=(None, None, None), kwonlyargs=[], kwonlydefaults=None, annotations={})

In [20]:
fas.kwonlyargs

[]

In [30]:
model_fields = {}
for ky, val in zip(fas.args, (typing.Any, str, str, typing.Any)):
    model_fields[ky] = val
    
model_fields

{'ds': typing.Any, 'normal': str, 'fields': str, 'axis': typing.Any}

In [32]:
SlicePlotModel = pydantic.create_model(__model_name='SlicePlot', fields=str)

In [33]:
SlicePlotModel

pydantic.main.SlicePlot

In [24]:
yt.SlicePlot?

[0;31mSignature:[0m [0myt[0m[0;34m.[0m[0mSlicePlot[0m[0;34m([0m[0mds[0m[0;34m,[0m [0mnormal[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m [0mfields[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m [0maxis[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m [0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
A factory function for
:class:`yt.visualization.plot_window.AxisAlignedSlicePlot`
and :class:`yt.visualization.plot_window.OffAxisSlicePlot` objects.  This
essentially allows for a single entry point to both types of slice plots,
the distinction being determined by the specified normal vector to the
slice.

The returned plot object can be updated using one of the many helper
functions defined in PlotWindow.

Parameters
----------

ds : :class:`yt.data_objects.static_output.Dataset`
    This is the dataset object corresponding to the
    simulation output to be plotted.
normal : int or one of 'x', 'y', 'z', or se