# Data management for model inputs and outputs

This notebook explores the use of xarray and pydantic classes for the management of input and output surrogate model variables. For the purpose of this demo, the data for process variables has been organized into the file: `online_model/var_config.py`; however, the intention for these methods is that they will be called directly from the model build in order to organize the process variables for subsequent server use. 

Three approaches will be used in organizing the model input/outputs and their metadata:
1. Use of xarray variables
2. Pydantic class implementation with validation
3. Directly from dictionary

Once the data has been constructed, it will be served using both the Channel Access and PVAccess servers. Here, the variables are converted into the dictionaries used to build the process variable database (pvdb) entries. The PVAccess server could be refactored to use the new data structure directly instead of constructing the pvdb; however, the Channel Access server uses the pvdb in its construction, so the pvdb conversion step will still be necessary.

## Current data requirements:
In its current iteration, the server and client use the process variable database constructed from the model info keys provided in the hdf5 file associated with the model. This serves as a single source of truth composed at runtime for both client and server, which should be avoided going forward. The process variable database should instead be derived from an input/output collection that is saved along with the model build and may be replicated across model builds. Current data used client and server-side is outlined below:


| Attribute   | Used server-side | Used client-side | Required for pva | Required for CA |
|-------------|------------------|------------------|------------------|-----------------|
| name        | ✓                | ✓                | ✓                | ✓               |
| pv_type     | ✓                | ✓                | ✓                | ✓               |
| value *     | ✓ (inputs)       |                  | ✓ (inputs)       | ✓ (inputs)      |
| default **  |                  | ✓                |                  |                 |
| units       |                  | ✓                |                  |                 |
| value_range | ✓ (CA)           | ✓ (sliders)      |                  | ✓               |
| is_input    | ✓                | ✓                | ✓                | ✓               |
| type        | ✓ (CA)           |                  |                  | ✓               |
| precision   | ✓ (CA)           |                  |                  | ✓               |

\* Value for outputs is not required because it is computed directly from the model before serving.
<br>
\** Defaults are currently used as placeholders. There is also a concept of a missing value default for IOCs that could be accounted for in a similar manner.

Additionally, this collection might also be extended to encompass build instructions. This will be explored in section 3.
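
As a rough illustration of the attributes in the table above, a single collection entry could be written as a plain dictionary. The sketch below is hypothetical: the key names mirror those read from `VARIABLES` later in this notebook (`pv_type`, `value_range`, `default`, `value`, `units`, `is_input`), while the two example entries and the `split_by_io` helper are invented for illustration and are not part of `online_model`.

```python
# Hypothetical entries; names and numbers are illustrative only
EXAMPLE_VARIABLES = {
    "example_input": {
        "pv_type": "scalar",
        "value_range": [0.1, 0.5],
        "default": 0.35,
        "value": 0.35,
        "units": "mm",
        "is_input": 1,
    },
    "example_output": {
        "pv_type": "scalar",
        "value_range": [0.0, 1.0],
        "default": 0.0,
        "value": None,  # outputs are computed by the model before serving
        "units": "mm",
        "is_input": 0,
    },
}


def split_by_io(variables):
    """Return (inputs, outputs) dictionaries keyed by variable name."""
    inputs = {name: cfg for name, cfg in variables.items() if cfg["is_input"]}
    outputs = {name: cfg for name, cfg in variables.items() if not cfg["is_input"]}
    return inputs, outputs


inputs, outputs = split_by_io(EXAMPLE_VARIABLES)
```

Splitting on `is_input` is the same decision the pvdb conversion functions in this notebook make when they build the input and output databases.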

## 1. xarray for metadata 
The process variables are each constructed using the xarray.Variable class. All variables are then used to construct a Dataset. 

In [1]:
import xarray as xr
import numpy as np
from typing import List, Tuple, Union

# This is a remnant that this work would ultimately remove
# The protocol is currently used in the assembly of the pvdb because of 
# extensions necessary for Channel Access AreaDetector variables
import os
os.environ["PROTOCOL"] = "pva"

from online_model import DEFAULT_PRECISION

# xarray does not propagate attributes by default
xr.set_options(keep_attrs=True)

def build_variable(*, 
                   name: str, 
                   pv_type: str, 
                   value: Union[np.ndarray, float], 
                   value_range: List[float], 
                   default: Union[np.ndarray, float], 
                   dim: Tuple[str, ...],
                   units: str, 
                   is_input: int,
                   precision: int=DEFAULT_PRECISION) -> xr.Variable:
    
    # assign default for none value
    if value is None:
        value = default
    
    # need to convert a scalar to a one-element array to work with xarray
    if isinstance(value, (int, float)):
        value = np.array([value])
        
    attributes = {
              "pv_type": pv_type, 
              "range" : value_range,
              "default": default,
              "units": units,
              "name": name,
              "is_input": is_input,
              "precision": precision
             }
        
    # drop None keys, cannot write None values to hdf5 file
    # note: this could be problematic with required attributes
    # attributes = {k: v for k, v in attributes.items() if v is not None}

    variable = xr.Variable(dim,
                           value,
                           attrs=attributes
                        )
    
    return variable


# Example variable using misc data
example_variable = build_variable(
    name = "example",
    pv_type = "scalar",
    value_range=[1.000000e-01, 5.000000e-01], 
    default=3.47986980e-01,
    value= np.array([3.47986980e-01]), # MUST BE AN ARRAY!
    dim = ("length",),
    units = "mm",
    is_input = 1,
)


### Now, use the build_variable function to create an xarray Dataset from our model info:

In [2]:
from online_model.var_config import VARIABLES


variables = {}

for variable, configs in VARIABLES.items():
    variables[variable] = build_variable(
        name = variable,
        pv_type = configs["pv_type"],
        value_range=configs["value_range"], 
        default=configs["default"],
        value= configs["value"],
        dim = configs["xarray_dim"],
        units = configs["units"],
        is_input = configs["is_input"]
    )
    
    
dset = xr.Dataset(variables)
# show dset explore view
dset

### Prepare to serve
Note: Enforcing required attributes for an entry would require substantial manual extension of `build_variable`. Without that enforcement, entries can be built with missing required fields or incorrectly typed attributes.
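
To give a sense of what that manual extension involves, here is a minimal sketch of a hand-written check for an attribute dictionary like the one assembled in `build_variable`. The `REQUIRED_ATTRS` mapping and the `check_attributes` helper are hypothetical and only cover presence and type; every new attribute would need another entry.

```python
# Hypothetical mapping of required attribute names to accepted types
REQUIRED_ATTRS = {
    "name": str,
    "pv_type": str,
    "is_input": int,
    "precision": int,
}


def check_attributes(attrs):
    """Return a list of problems found in an attribute dictionary."""
    problems = []
    for key, expected in REQUIRED_ATTRS.items():
        if attrs.get(key) is None:
            problems.append(f"missing required attribute: {key}")
        elif not isinstance(attrs[key], expected):
            problems.append(
                f"{key} should be {expected.__name__}, "
                f"got {type(attrs[key]).__name__}"
            )
    return problems


good = {"name": "example", "pv_type": "scalar", "is_input": 1, "precision": 8}
bad = {"name": "example", "pv_type": "scalar", "is_input": "yes"}

good_problems = check_attributes(good)
bad_problems = check_attributes(bad)
```

Here `good_problems` is empty, while `bad_problems` reports the mistyped `is_input` and the missing `precision`.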

In [3]:
from online_model.server import ca, pva
from online_model.model.MySurrogateModel import MySurrogateModel


def pvdb_from_xarray(dset):
    input_pvdb = {}
    output_pvdb = {}
    
    for variable in dset.keys():
        
        entry = {
            "prec": dset[variable].attrs["precision"],
            "units": dset[variable].attrs["units"],
            "range": dset[variable].attrs["range"],
            "type": "float" # For channel access
        }
        
        
        if dset[variable].attrs["is_input"]:
            
            # set values for the inputs
            if dset[variable].attrs["pv_type"] == "scalar":
                entry["value"] = dset[variable].values[0] # Have to extract our scalar value from the xarray        
                
            else:
                entry["value"] = dset[variable].values # pass the underlying ndarray rather than the DataArray
            
            input_pvdb[variable] = entry
            
        else:
            output_pvdb[variable]  = entry
    
    return input_pvdb, output_pvdb

### Run Channel Access Server
Note: The `build_image_pvs` function used below is a utility function that adds the appropriate image process variables for the AreaDetector naming scheme. This will ultimately be included in the initial pvdb construction.

In [4]:
from online_model.util import build_image_pvs
from online_model import DEFAULT_PRECISION, DEFAULT_COLOR_MODE, IMAGE_UNITS, IMAGE_SHAPE, PREFIX, MODEL_KWARGS
from online_model.server import ca

input_pvdb, output_pvdb = pvdb_from_xarray(dset)


image_pvs = build_image_pvs(
        "x:y", 
        IMAGE_SHAPE,
        IMAGE_UNITS,
        DEFAULT_PRECISION,
        DEFAULT_COLOR_MODE,
    )

output_pvdb.update(image_pvs)


server = ca.CAServer(MySurrogateModel, MODEL_KWARGS, input_pvdb, output_pvdb, PREFIX)
server.start_server()

Loaded Attributes successfully
Loaded Architecture successfully
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
Loaded Weights successfully
{'laser_radius': 0.34798698, 'maxb(2)': 0.0402751972, 'phi(1)': -7.99101687, 'total_charge:value': 141.576322, 'in_xmin': -0.000353964583, 'in_xmax': 0.000345778376, 'in_ymin': -0.000347874295, 'in_ymax': 0.000345778376, 'image': array([[[0.],
        [0.],
        [0.],
        ...,
        [0.],
        [0.],
        [0.]],

       [[0.],
        [0.],
        [0.],
        ...,
        [0.],
        [0.],
        [0.]],

       [[0.],
        [0.],
        [0.],
        ...,
        [0.],
        [0.],
        [0.]],

       ...,

       [[0.],
        [0.],
        [0.],
        ...,
        [0.],
        [0.],
        [0.]],

       [[0.],
        [0.],
        [0.],
        ...,
        [0.],
        [0.],
        [0.]],

       [[0.],
        [0.],
        [0.],
        ...,
        [0.],
        [0.],
      

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
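
The traceback above is cut off, so the line in the server that raised the error is not visible here. The message itself is what NumPy raises whenever an array with more than one element is evaluated as a single boolean, for example in an `if` statement. The sketch below reproduces that general failure mode with a small stand-in array and shows `np.array_equal` as one unambiguous alternative; it is not a diagnosis of the server code.

```python
import numpy as np

image = np.zeros((2, 2))  # small stand-in for an image value

# Element-wise comparison returns an array, which has no single truth value
try:
    if image == 0:
        pass
    message = None
except ValueError as err:
    message = str(err)

# Whole-array comparison that returns a single boolean
same = np.array_equal(image, np.zeros((2, 2)))
```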

### Run PVAccess server

In [None]:
from online_model import DEFAULT_PRECISION, PREFIX, MODEL_KWARGS
from online_model.server import pva

input_pvdb, output_pvdb = pvdb_from_xarray(dset)

server = pva.PVAServer(MySurrogateModel, MODEL_KWARGS, input_pvdb, output_pvdb, PREFIX)
server.start_server()

In [None]:
server.stop_server()

Saving

# Build variables using classes
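Advantages of a class-based implementation:
- Typing and required fields are enforced on construction.
- Other formatting rules can be enforced through validators.
- View-based config requirements, for example slider settings, can be attached to the variable.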

In [None]:
from pydantic import BaseModel
from typing import Optional, List, Any, Union
from enum import IntEnum


# custom validator for ndarrays
class NumpyNDArray(np.ndarray):
    @classmethod
    def __get_validators__(cls):
        yield cls.validate
        
    @classmethod
    def validate(cls, v: Any) -> np.ndarray:
        # validate data...
        if not isinstance(v, np.ndarray):
            raise TypeError('Numpy array required')
        return v

# custom validator for xarray DataArrays
class XarrayDataArray(xr.DataArray):
    __slots__ = [] # xarray requires explicit definition of slots on subclasses
    
    @classmethod
    def __get_validators__(cls):
        yield cls.validate
        
    @classmethod
    def validate(cls, v: Any) -> xr.DataArray:
        # validate data...
        if not isinstance(v, xr.DataArray):
            raise TypeError('xarray DataArray required')
        return v

class IOEnum(IntEnum):
    output = 0
    input = 1

class ProcessVariable(BaseModel):
    name: str
    io_type: IOEnum   #requires selection of input or output for creation
    units: Optional[str]
        
    # fixed type to be passed to pvdb
    class Meta:
        type: str = "float"
        precision: int = DEFAULT_PRECISION
        
        
class ScalarProcessVariable(ProcessVariable):
    value: Optional[float]
    default: Optional[float]
    range:  Optional[Union[NumpyNDArray, XarrayDataArray]] 
    

class NDProcessVariable(ProcessVariable):
    value: Optional[NumpyNDArray]
    default: Optional[NumpyNDArray]
    range:  Optional[Union[NumpyNDArray, XarrayDataArray]]     
    
    
test_param = NDProcessVariable(
                    name = "test", 
                    io_type = 1, 
                    value=np.array([5,  2],), 
                    default= np.array([5.0,  2.0],), 
                    range=np.array([5.0, 4.0])
            )


In [None]:
# Example of incorrectly typed parameter:
test_param = ScalarProcessVariable(
                    name="test", 
                    io_type = 1, 
                    units="test", 
                    value=5.0, 
                    default=np.array([5, 2]), # for a scalar parameter, a float should be passed
                    range=np.array([5, 4])
                )
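
The cell above is intended to fail validation, because `default` is declared as a float on `ScalarProcessVariable`. As a smaller, self-contained sketch of the same behavior, the hypothetical `ScalarSketch` model below (not one of the notebook's classes) catches pydantic's `ValidationError` and reads out which field was rejected.

```python
from typing import Optional

from pydantic import BaseModel, ValidationError


class ScalarSketch(BaseModel):
    # Hypothetical stand-in for a scalar process variable
    name: str
    default: Optional[float] = None


try:
    ScalarSketch(name="test", default=[5.0, 2.0])  # a list is not a float
    rejected_fields = []
except ValidationError as err:
    # each error records the location of the offending field
    rejected_fields = [error["loc"][0] for error in err.errors()]
```

Catching the error at construction time is what allows a malformed variable to be reported before anything is written or served.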

# Use classes to set up process variables

In [None]:
variables = []

for variable, configs in VARIABLES.items():
    # var_config stores the range under "value_range" (as read in the xarray section);
    # the validators expect an array, so convert the range when one is defined
    value_range = configs.get("value_range") # not necessarily defined
    if value_range is not None:
        value_range = np.asarray(value_range)

    if configs["pv_type"] == "scalar":
        var = ScalarProcessVariable(
            name = variable, 
            io_type = configs["is_input"], 
            units = configs.get("units"), # not necessarily defined 
            value = configs.get("value"), # not necessarily defined
            default = configs.get("default"), # not necessarily defined
            range = value_range,
        )
    
    elif configs["pv_type"] == "image":
        var = NDProcessVariable(
            name = variable, 
            io_type = configs["is_input"], 
            units = configs.get("units"), # not necessarily defined 
            value = configs.get("value"), # not necessarily defined
            default = configs.get("default"), # not necessarily defined
            range = value_range,
        )

    else:
        # avoid silently appending a stale or undefined variable
        raise ValueError(f"Unsupported pv_type for {variable}: {configs['pv_type']}")
    
    variables.append(var)

# Create pvdb from the classes

Requirements to serve:

PVA
- initial value for inputs

CA
- build pvdb

Note that the xarray approach needs an extra processing step to go from xarray to PVA, compared with using a dictionary.

In [None]:
def pvdb_from_classes(variables):
    input_pvdb = {}
    output_pvdb = {}
    
    for variable in variables:
        entry = {
            "prec": variable.Meta.precision,
            "units": variable.units,
            "range": variable.range,
            "value": variable.value,
            "type": variable.Meta.type,
        }
        
        if variable.io_type.name == "input":
            input_pvdb[variable.name] = entry
            
        elif variable.io_type.name == "output":
            output_pvdb[variable.name]  = entry
            
        else:
            # pydantic validation should prohibit any other assignment
            pass 
    
    return input_pvdb, output_pvdb

    
input_pvdb, output_pvdb = pvdb_from_classes(variables)
input_pvdb
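
The `io_type` routing in `pvdb_from_classes` relies on `IntEnum` behavior from the standard library: an integer flag such as `is_input` maps onto a named member, and any other integer is rejected. The sketch below redefines `IOEnum` as in the class cell so that it runs on its own; the `route` helper is hypothetical.

```python
from enum import IntEnum


class IOEnum(IntEnum):
    output = 0
    input = 1


def route(flag):
    """Map an integer flag to the name of the pvdb it belongs in."""
    return f"{IOEnum(flag).name}_pvdb"


# Any integer without a matching member raises ValueError
try:
    IOEnum(2)
    rejected = False
except ValueError:
    rejected = True
```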

In [None]:
input_pvdb, output_pvdb = pvdb_from_classes(variables)

server = pva.PVAServer(MySurrogateModel, MODEL_KWARGS, input_pvdb, output_pvdb, PREFIX)
server.start_server()


# CONS OF XARRAY
- The dimension requirements do not fit well with the concept of a single value, so the implementation is awkward. Why treat a float as an array?
- xarray variable dimensions require a definition that is inconsistent with the openPMD standard.
- A value must be passed to construct the array, which is problematic for outputs.
- The code has to be manipulated to accommodate xarray rather than xarray fitting naturally into this use case.
- To get the benefits of xarray attribute-style indexing (`data_set.laser_radius.value`, etc.), parameter names must be pythonic. Otherwise, item access must be used: `data_set["phi(1)"].value`.
- The name of the variable is tracked redundantly within the variable so that the variable is a complete record.
- Adding a new attribute means adding more function arguments.
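
Whether a name supports attribute-style access can be checked with `str.isidentifier()`. The sketch below applies it to variable names that appear in the server output earlier in this notebook; it only illustrates the naming constraint and does not use xarray.

```python
names = ["laser_radius", "maxb(2)", "phi(1)", "total_charge:value", "in_xmin"]

# Valid Python identifiers can be reached as attributes (provided they do not
# clash with existing attributes); the rest need item access, e.g. data_set["phi(1)"]
attribute_ok = [name for name in names if name.isidentifier()]
item_only = [name for name in names if not name.isidentifier()]
```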

# PROS OF CLASSES
- Configs are extensible.
- The flexibility of pydantic allows us to validate the data before we even write it.
- Settings could be further broken down into input/output variables.
- Settings could be provided for plots, sliders, etc.

![loading_diagram.png](attachment:loading_diagram.png) 