# Data management for model inputs and outputs

This notebook explores the use of xarray and pydantic classes for the management of input and output surrogate model variables. For the purpose of this demo, the data for process variables has been organized into the file: `online_model/var_config.py`; however, the intention for these methods is that they will be called directly from the model build in order to organize the process variables for subsequent server use. 

Two approaches will be used in organizing the model inputs/outputs and their metadata:
1. Use of xarray variables
2. Pydantic class implementation with validation

Once the data has been constructed, it will be served using both the Channel Access and PVAccess servers. For each server, the data are converted into the dictionaries used to build the process variable database (pvdb) entries. The PVAccess server could be refactored to use the new data structure directly instead of constructing the pvdb; however, the Channel Access server uses the pvdb in its construction, so the pvdb conversion step will remain necessary.

## Current data requirements:
In its current iteration (as reflected by the pva_refactor branch), the server and client use the process variable database constructed from the model info keys provided in the hdf5 file associated with the model. Although this gives client and server a single source of truth, that source is composed at runtime from the model file, which should be avoided going forward. The process variable database should instead be derived from an input/output collection that is saved along with the model build and may be replicated across model builds. The data currently used on the client and server sides is outlined below:


| Attribute   | Used server-side | Used client-side | Required for pva | Required for CA |
|-------------|------------------|------------------|------------------|-----------------|
| name        | ✓                | ✓                | ✓                | ✓               |
| pv_type     | ✓                | ✓                | ✓                | ✓               |
| value *     | ✓ (inputs)       |                  | ✓ (inputs)       | ✓ (inputs)      |
| default **  |                  | ✓                |                  |                 |
| units       |                  | ✓                |                  |                 |
| value_range | ✓ (CA)           | ✓ (sliders)      |                  | ✓               |
| is_input    | ✓                | ✓                | ✓                | ✓               |
| type        | ✓ (CA)           |                  |                  | ✓               |
| precision   | ✓ (CA)           |                  |                  | ✓               |

Additionally, this collection might be extended to encompass build instructions. This is explored in the "Beyond server config" section.



\* Value for outputs is not required because it is computed directly from the model before serving.
<br>
\** Defaults are currently used as placeholders. There is also a concept of a missing-value default for IOCs that could be accounted for in a similar manner.
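
The table can also be read as a machine-checkable specification. The sketch below is a minimal illustration and not part of the `online_model` package: the `REQUIREMENTS` mapping and the helper names are hypothetical, and the entries are transcribed from the "Required for pva" and "Required for CA" columns above.

```python
from typing import Dict, Set

# Transcribed from the "Required for pva" / "Required for CA" columns above.
# "all" = required for every variable, "inputs" = required for inputs only.
# (Hypothetical helper data, not part of online_model.)
REQUIREMENTS: Dict[str, Dict[str, str]] = {
    "name": {"pva": "all", "ca": "all"},
    "pv_type": {"pva": "all", "ca": "all"},
    "value": {"pva": "inputs", "ca": "inputs"},
    "value_range": {"ca": "all"},
    "is_input": {"pva": "all", "ca": "all"},
    "type": {"ca": "all"},
    "precision": {"ca": "all"},
}


def required_attributes(protocol: str, is_input: bool) -> Set[str]:
    """Return the attributes a variable must define for the given protocol."""
    required = set()
    for attr, scopes in REQUIREMENTS.items():
        scope = scopes.get(protocol)
        if scope == "all" or (scope == "inputs" and is_input):
            required.add(attr)
    return required


def missing_attributes(entry: dict, protocol: str) -> Set[str]:
    """List required attributes that are absent from a variable definition."""
    return required_attributes(protocol, bool(entry.get("is_input"))) - set(entry)


# An input that is only partially specified
partial_input = {"name": "laser_radius", "pv_type": "scalar", "is_input": 1}
print(sorted(missing_attributes(partial_input, "pva")))  # ['value']
print(sorted(missing_attributes(partial_input, "ca")))   # ['precision', 'type', 'value', 'value_range']
```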

## 1. xarray for metadata 
The process variables are each constructed using the `xarray.Variable` class. All variables are then used to construct a Dataset. 
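
A minimal sketch of the idea with made-up values: two `xarray.Variable` objects carry their metadata in `attrs`, and the Dataset keeps those attributes, so inputs can be selected by introspecting them. The names `laser_radius` and `x:y` and all attribute values here are illustrative only.

```python
import numpy as np
import xarray as xr

# Scalars must be wrapped in a 1-element array and given a dimension name
radius = xr.Variable(("length",), np.array([0.35]),
                     attrs={"units": "mm", "is_input": 1})
# Image-like output with two named dimensions (placeholder data)
image = xr.Variable(("x", "y"), np.zeros((2, 3)),
                    attrs={"units": "counts", "is_input": 0})

demo_dset = xr.Dataset({"laser_radius": radius, "x:y": image})

# Attributes travel with each data variable, so inputs can be found by introspection
inputs = [name for name, arr in demo_dset.data_vars.items() if arr.attrs["is_input"]]
print(inputs)  # ['laser_radius']
print(demo_dset["laser_radius"].values[0])  # the scalar is recovered by indexing
```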

In [1]:
import xarray as xr
import numpy as np
from typing import List, Optional, Tuple, Union

# xarray does not propagate attributes by default
xr.set_options(keep_attrs = True)

def build_variable(*, 
                   name: str, 
                   pv_type: str, 
                   value: Union[np.ndarray, float, None], 
                   value_range: List[float], 
                   default: Union[np.ndarray, float], 
                   dim: Tuple[str, ...], 
                   units: Optional[str], 
                   is_input: int,
                   precision: int,
                   color_mode: Optional[int], # None unless image pv
                   shape: Optional[Tuple[int, ...]] # None unless image pv
                  ) -> xr.Variable: 
    
    # assign default for none value
    if value is None:
        value = default
    
    # xarray needs an array, so wrap scalars in a 1-element array
    if isinstance(value, (int, float)):
        value = np.array([value])
        
    attributes = {
        "pv_type": pv_type, 
        "range": value_range,
        "default": default,
        "units": units,
        "name": name,
        "is_input": is_input,
        "precision": precision,
    }
    
    if pv_type == "image":
        attributes["color_mode"] = color_mode
        attributes["shape"] = shape
        
    variable = xr.Variable(dim,
                           value,
                           attrs=attributes
                        )
    
    return variable


# Example variable using misc data
example_variable = build_variable(
    name = "example",
    pv_type = "scalar",
    value_range=[1.000000e-01, 5.000000e-01], 
    default=3.47986980e-01,
    value=np.array([3.47986980e-01]), # a bare float would also work; build_variable wraps scalars in an array
    dim = ("length",),
    units = "mm",
    is_input = 1,
    precision = 8,
    color_mode = None,
    shape = None
)


### Now, use the `build_variable` function to create an xarray Dataset from our model info:

In [3]:
from online_model.var_config import VARIABLES


variables = {}

for variable, configs in VARIABLES.items():
    variables[variable] = build_variable(
        name = variable,
        pv_type = configs.get("pv_type"),
        value_range=configs.get("range"), 
        default=configs.get("default"),
        value= configs.get("value"),
        dim = configs.get("xarray_dim"),
        units = configs.get("units"),
        is_input = configs.get("is_input"),
        precision = 8, 
        color_mode = configs.get("color_mode"),
        shape = configs.get("shape")
    )
    
    
dset = xr.Dataset(variables)
# show dset explore view
dset

### Prepare to serve
Note: Enforcing required attributes for each entry would require substantial manual extension of `build_variable`. Without such checks, required fields can go missing or attributes can be stored with the wrong type, and nothing will flag it.
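
To illustrate the problem, the sketch below hand-writes the kind of check `build_variable` would need: a plain attribute dictionary, like the one stored in an xarray `attrs`, accepts any key and any value, so missing and mistyped attributes have to be caught manually. `EXPECTED_TYPES` and `check_attrs` are hypothetical and only cover attributes stored by `build_variable` above.

```python
from typing import Any, Dict, List

# Hypothetical expected types for the attributes build_variable stores.
# "units" may legitimately be None for unitless variables.
EXPECTED_TYPES: Dict[str, tuple] = {
    "name": (str,),
    "pv_type": (str,),
    "range": (list,),
    "units": (str, type(None)),
    "is_input": (int,),
    "precision": (int,),
}


def check_attrs(attrs: Dict[str, Any]) -> List[str]:
    """Return a list of problems found in an attribute dictionary."""
    problems = []
    for key, allowed in EXPECTED_TYPES.items():
        if key not in attrs:
            problems.append(f"missing {key!r}")
        elif not isinstance(attrs[key], allowed):
            problems.append(f"{key!r} has unexpected type {type(attrs[key]).__name__}")
    return problems


# A plain dict accepts the misspelled "unit" key and the string precision silently
bad_attrs = {"name": "example", "pv_type": "scalar", "range": [0.1, 0.5],
             "unit": "mm", "is_input": 1, "precision": "8"}
print(check_attrs(bad_attrs))
```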

In [None]:
from online_model.server import ca, pva
from online_model.util import build_image_pvs  # used by pvdb_from_xarray for CA image PVs
from online_model.model.MySurrogateModel import MySurrogateModel
PREFIX = "smvm"

MODEL_FILE = "online_model/files/CNN_051620_SurrogateModel.h5"
STOCK_LASER_IMAGE = "online_model/files/example_input_image.npy"

MODEL_KWARGS = {
    "model_file": MODEL_FILE,
    "stock_image_input": np.load(STOCK_LASER_IMAGE),
}

ARRAY_PVS = ["x:y"]



def pvdb_from_xarray(dset, protocol):
    input_pvdb = {}
    output_pvdb = {}

    for variable in dset.keys():

        entry = {
            "prec": dset[variable].attrs["precision"],
            "units": dset[variable].attrs["units"],
            "range": dset[variable].attrs["range"],
            "type": "float",  # For channel access
        }

        # set up area detector pvs
        if protocol == "ca" and dset[variable].attrs["pv_type"] == "image":
            image_pvs = build_image_pvs(
                variable,
                dset[variable].attrs["shape"],
                dset[variable].attrs["units"],
                dset[variable].attrs["precision"],
                dset[variable].attrs["color_mode"],
            )

            if dset[variable].attrs["is_input"] == 1:
                input_pvdb.update(image_pvs)

            elif dset[variable].attrs["is_input"] == 0:
                output_pvdb.update(image_pvs)

        else:

            if dset[variable].attrs["is_input"]:

                # set values for the inputs
                if dset[variable].attrs["pv_type"] == "scalar":
                    # extract the scalar from its 1-element array
                    entry["value"] = dset[variable].values[0]

                else:
                    # pass the underlying numpy array rather than the DataArray
                    entry["value"] = dset[variable].values

                input_pvdb[variable] = entry

            else:
                output_pvdb[variable] = entry

    return input_pvdb, output_pvdb

In [None]:
# save dataset to file
import pickle

with open("online_model/files/xarray_dset.pickle",  "wb") as f:
    pickle.dump(dset, f)

### Run Channel Access Server
Note: The `build_image_pvs` function used below is a utility function that adds the appropriate image process variables for the AreaDetector naming scheme. This will ultimately be included in the initial pvdb construction.
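
For readers unfamiliar with the AreaDetector scheme, the sketch below shows the general shape of such a helper. It is not the `online_model.util.build_image_pvs` implementation: the function name `sketch_image_pvs`, the PV suffixes (modeled loosely on AreaDetector record names such as `ArrayData`, `ArraySize0_RBV` and `ColorMode_RBV`), the dimension ordering, and the entry keys (`count`, `prec`) are all assumptions.

```python
from typing import Dict, Tuple


def sketch_image_pvs(name: str, shape: Tuple[int, int], units: str,
                     precision: int, color_mode: int) -> Dict[str, dict]:
    """Hypothetical: expand one image variable into AreaDetector-style PV entries."""
    n_rows, n_cols = shape
    return {
        # flattened pixel data, served as a float waveform of n_rows * n_cols points
        f"{name}:ArrayData": {"type": "float", "count": n_rows * n_cols,
                               "prec": precision, "units": units},
        # assumed ordering: fastest-varying dimension (x, i.e. columns) first
        f"{name}:ArraySize0_RBV": {"type": "int", "value": n_cols},
        f"{name}:ArraySize1_RBV": {"type": "int", "value": n_rows},
        f"{name}:ColorMode_RBV": {"type": "int", "value": color_mode},
    }


pvs = sketch_image_pvs("x:y", (50, 50), "counts", 8, 0)
print(sorted(pvs))
```

A client could then rebuild the 2-D image by reshaping the `ArrayData` waveform with the two size PVs.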

In [None]:
from online_model.util import build_image_pvs
from online_model.server import ca

input_pvdb, output_pvdb = pvdb_from_xarray(dset, "ca")

server = ca.CAServer(MySurrogateModel, MODEL_KWARGS, input_pvdb, output_pvdb, PREFIX, ARRAY_PVS)
server.start_server()

### Run PVAccess server
(works with Server.ipynb)

In [None]:
from online_model.server import pva

input_pvdb, output_pvdb = pvdb_from_xarray(dset, "pva")

server = pva.PVAServer(MySurrogateModel, MODEL_KWARGS, input_pvdb, output_pvdb, PREFIX, ARRAY_PVS)
server.start_server()

## 2. Build variables using pydantic classes
Pydantic is a library that enforces type hints at runtime, so errors in variable configurations are caught when the variable is constructed. Pydantic also provides control over dictionary conversion and JSON schema generation.

In [4]:
import numpy as np
import xarray as xr
from enum import Enum
from typing import Any, List, Union, Optional
from pydantic import BaseModel, Field


# custom validator for ndarrays
class NumpyNDArray(np.ndarray):
    @classmethod
    def __get_validators__(cls):
        yield cls.validate

    @classmethod
    def validate(cls, v: Any) -> np.ndarray:
        # validate data...
        if not isinstance(v, np.ndarray):
            raise TypeError("Numpy array required")
        return v
    

class IOEnum(str, Enum):
    pv_in = "input"
    pv_out = "output"


class ProcessVariable(BaseModel):
    name: str
    io_type: IOEnum  # requires selection of input or output for creation
    # defaults for pvdb
    value_type: str = Field("float", alias = "type")
    precision: int = 8

    class Config:
        use_enum_values = True


class ScalarProcessVariable(ProcessVariable):
    value: Optional[float]
    default: Optional[float]
        
    # alias allows us to define dict representation
    value_range: list = Field(alias="range") 
    units: Optional[str]


class NDProcessVariable(ProcessVariable):
    value: Optional[NumpyNDArray]
    default: Optional[NumpyNDArray]
        
    # alias allows us to define dict representation
    value_range: list = Field(alias="range") 
    units: str


class ImageProcessVariable(NDProcessVariable):
    color_mode: int = 0
    shape: tuple  # need for channel access AreaDetector


## Example of an incorrectly typed process variable:

In [5]:
incorrect_var = ScalarProcessVariable(
                    name="test", 
                    io_type = "input", 
                    units="test", 
                    value=5.0, 
                    default=np.array([5, 2]), # invalid: a scalar parameter expects a float
                    range=[5, 4]
            )

ValidationError: 1 validation error for ScalarProcessVariable
default
  value is not a valid float (type=type_error.float)

## Example of a correctly typed variable:

In [6]:
test_param = ScalarProcessVariable(
                    name="test", 
                    io_type = "input", 
                    units="test_units", 
                    value=5.0, 
                    default=0.0, # for a scalar parameter, a float should be passed
                    range=[5, 4]
        )

## Dictionary representation can be used directly to build the pvdb

In [7]:
test_param.dict(exclude_unset=True, exclude={"io_type"}, by_alias=True)

{'name': 'test',
 'value': 5.0,
 'default': 0.0,
 'range': [5, 4],
 'units': 'test_units'}
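
Note that the output above contains neither `type` nor `precision`: `exclude_unset=True` drops every field that was not passed explicitly, including fields that only have defaults, although the table at the top lists both as required for Channel Access. The sketch below compares `exclude_unset` with `exclude_none`, which keeps defaults and drops only `None` values. `ScalarSketch` is a trimmed, hypothetical stand-in for `ScalarProcessVariable`, and the `prec` alias (matching the key written by `pvdb_from_xarray`) is an assumption, not part of the classes above. The `dump` helper exists only so the sketch runs on both pydantic v1 (`.dict()`) and v2 (`.model_dump()`).

```python
from typing import Optional

from pydantic import BaseModel, Field


class ScalarSketch(BaseModel):
    """Hypothetical, trimmed stand-in for ScalarProcessVariable."""
    name: str
    value_type: str = Field("float", alias="type")
    precision: int = Field(8, alias="prec")  # assumed alias matching pvdb_from_xarray
    value: Optional[float] = None
    units: Optional[str] = None


def dump(model: BaseModel, **kwargs) -> dict:
    # pydantic v2 provides model_dump(); v1 only has dict()
    method = getattr(model, "model_dump", None) or model.dict
    return method(**kwargs)


var = ScalarSketch(name="test", value=5.0)

# Defaults are dropped because they were never set explicitly
print(dump(var, exclude_unset=True, by_alias=True))
# Defaults are kept; only fields that are None (here: units) are dropped
print(dump(var, exclude_none=True, by_alias=True))
```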

## As with the xarray configuration, the classes can be used to set up process variables on a case-by-case basis

In [None]:
variables = []

for variable, configs in VARIABLES.items():
    
    if configs["is_input"] == 1:
        io_type = "input"
        
    else:
        io_type = "output"
    
    
    if configs["pv_type"] == "scalar":
        var = ScalarProcessVariable(
            name = variable,
            io_type = io_type,
            **configs
        )
    
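    elif configs["pv_type"] == "image":
        # ImageProcessVariable carries the shape and color_mode that
        # pvdb_from_classes needs to build the Channel Access image PVs
        var = ImageProcessVariable(
            name = variable,
            io_type = io_type,
            **configs
        )

    else:
        raise ValueError(f"Unsupported pv_type: {configs['pv_type']}")
    
    variables.append(var)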

## Create pvdb from the class instances:

Apart from the Channel Access image variables, the pvdb entries are created simply by calling the `dict` method.

In [None]:
def pvdb_from_classes(variables, protocol):
    input_pvdb = {}
    output_pvdb = {}

    for variable in variables:
        # no manual formatting needed, and fields can be included/excluded as required
        # by_alias=True writes fields under their aliases ("type", "range"), names that would shadow Python builtins as field names
        entry = variable.dict(exclude_unset=True, exclude={"io_type"}, by_alias=True)

        if protocol == "ca" and isinstance(variable, (ImageProcessVariable,)):
            image_pvs = build_image_pvs(
                variable.name,
                variable.shape,
                variable.units,
                variable.precision,
                variable.color_mode,
            )

            if variable.io_type == "input":
                input_pvdb.update(image_pvs)

            elif variable.io_type == "output":
                output_pvdb.update(image_pvs)

        else:
            if variable.io_type == "input":
                input_pvdb[variable.name] = entry

            elif variable.io_type == "output":
                output_pvdb[variable.name] = entry

            else:
                # pydantic enum validation will prohibit any other assignment
                pass

    return input_pvdb, output_pvdb

### Run Channel Access Server
Note: The `build_image_pvs` function used below is a utility function that adds the appropriate image process variables for the AreaDetector naming scheme. This will ultimately be included in the initial pvdb construction.

In [None]:
from online_model.server import ca

input_pvdb, output_pvdb = pvdb_from_classes(variables, "ca")

server = ca.CAServer(MySurrogateModel, MODEL_KWARGS, input_pvdb, output_pvdb, PREFIX, ARRAY_PVS)
server.start_server()

### Run PVAccess Server 
(works with Server.ipynb)

In [None]:
from online_model.server import pva

input_pvdb, output_pvdb = pvdb_from_classes(variables, "pva")

server = pva.PVAServer(MySurrogateModel, MODEL_KWARGS, input_pvdb, output_pvdb, PREFIX, ARRAY_PVS)
server.start_server()


# Conclusions

The xarray implementation has several awkward aspects:
- The array dimension field is unused in our server/client framework; however, xarray variables require a dimension definition. This forced assignment could contradict quantity standards (like openPMD, though this is admittedly pedantic). 
- A value field must be passed to construct the array. This is inconsistent with how output variables are currently used, since they don't require values for initial construction.
- Scalar process variables must be wrapped in an array and then indexed to extract the value.
- To use xarray's attribute-style (dot) access, a strict pythonic naming scheme has to be enforced for the input/output names (for example, `data_set.laser_radius.values` vs. `data_set["phi(1)"].values`). Pythonic aliases could be used within the xarray structure, with the proper name stored in an attribute; however, this is redundant and would likely lead to confusion.
- Each new attribute adds another function argument, and the argument list could expand dramatically depending on the amount of client-building automation desired.
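
The naming constraint above can be checked with the standard library alone. The sketch below (the helper name is hypothetical) flags which variable names could be reached with dot syntax; any other name, such as `phi(1)` or the image PV `x:y`, has to be looked up with item access such as `data_set["x:y"]`.

```python
import keyword


def supports_dot_access(name: str) -> bool:
    """True if `name` can appear after a dot, as in data_set.<name>."""
    return name.isidentifier() and not keyword.iskeyword(name)


for pv_name in ["laser_radius", "phi(1)", "x:y"]:
    print(pv_name, supports_dot_access(pv_name))
```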


## Beyond server config
Other configurations will be needed in order to expand this framework beyond the current model implementation. On the client side, instructions for rendering must be passed on the basis of process variable type. In order to do this, the type of the variable must be known. With xarray this requires the introspection of attributes for each variable. With the class implementation, this could be enforced with a type validation on the process variable class. 
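
One way to express rendering instructions by variable type is to dispatch on the variable's class. The sketch below uses `functools.singledispatch`, with two plain dataclasses standing in for the pydantic process-variable classes to keep it dependency-free (dispatch works the same way on pydantic models, since it only looks at the class). The class names and widget labels are hypothetical.

```python
from dataclasses import dataclass
from functools import singledispatch
from typing import Tuple


@dataclass
class ScalarPV:
    name: str
    is_input: bool = True


@dataclass
class ImagePV:
    name: str
    shape: Tuple[int, int] = (0, 0)


@singledispatch
def render_instruction(pv):
    # Fallback for classes without a registered rule
    raise TypeError(f"No rendering rule for {type(pv).__name__}")


@render_instruction.register
def _(pv: ScalarPV):
    # Inputs get an interactive control, outputs a read-only display
    return {"widget": "slider" if pv.is_input else "readout", "pv": pv.name}


@render_instruction.register
def _(pv: ImagePV):
    return {"widget": "image", "pv": pv.name, "shape": pv.shape}


print(render_instruction(ScalarPV("laser_radius")))
print(render_instruction(ImagePV("x:y", (50, 50))))
```

Adding a new variable class then means registering one more rule instead of inspecting attributes at every call site.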

Client-side configurations will also require the abstraction of indicators that are currently hard-coded. For example, slider inclusion/exclusion will need to be abstracted before accommodating other models. This could be accomplished with an exposure attribute or with the optional inclusion of SliderConfig settings that could enforce additional range constraints, set the step size, etc. Other controls beyond sliders may be required, each with its own specific, and sometimes optional, settings. Pydantic models are well suited to this application and may be preferable to attribute expansion in xarray.
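
A minimal sketch of the optional-settings idea, assuming pydantic: `SliderConfig`, `ScalarInput` and `slider_bounds` are hypothetical, and a variable with `slider=None` would simply not be exposed as a slider. Only constraints that behave the same in pydantic v1 and v2 are used (`Field(gt=0)` for the step size); the additional range constraint is applied in a plain helper rather than a validator.

```python
from typing import Optional, Tuple

from pydantic import BaseModel, Field, ValidationError


class SliderConfig(BaseModel):
    step: float = Field(0.01, gt=0)    # step size must be positive
    min_value: Optional[float] = None  # optional bounds tighter than the PV range
    max_value: Optional[float] = None


class ScalarInput(BaseModel):
    name: str
    value_range: list = Field(..., alias="range")
    slider: Optional[SliderConfig] = None  # None -> no slider for this variable


def slider_bounds(var: ScalarInput) -> Tuple[float, float]:
    """Intersect the variable's range with any bounds set in its SliderConfig."""
    low, high = var.value_range
    if var.slider is not None:
        if var.slider.min_value is not None:
            low = max(low, var.slider.min_value)
        if var.slider.max_value is not None:
            high = min(high, var.slider.max_value)
    return low, high


radius = ScalarInput(name="laser_radius", range=[0.1, 0.5], slider={"min_value": 0.2})
print(slider_bounds(radius))  # (0.2, 0.5)

try:
    SliderConfig(step=0)
except ValidationError as err:
    print("rejected:", err)
```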

# Other considerations

- How will these inputs and outputs be saved? These input and output variables need to be decoupled from the current hdf5 file so that the client can operate without a distribution of the original model.
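
One possible answer, sketched under the assumption that a JSON file is an acceptable format: write the variable definitions (for example, the dictionaries produced by the pydantic classes) to a file saved alongside the model build, so the client can read them without the hdf5 model. The helper names are hypothetical. Unlike the pickled xarray Dataset above, a JSON file is human-readable and is generally not tied to the Python or xarray version that wrote it; the trade-off is that NumPy arrays and tuples come back as plain lists.

```python
import json
import os
import tempfile
from typing import Any, Dict, List


def _to_json(obj: Any) -> Any:
    # NumPy arrays and NumPy scalars expose tolist(); anything else is unsupported
    if hasattr(obj, "tolist"):
        return obj.tolist()
    raise TypeError(f"Cannot serialize {type(obj).__name__}")


def save_variables(variables: List[Dict[str, Any]], path: str) -> None:
    with open(path, "w") as f:
        json.dump(variables, f, indent=2, default=_to_json)


def load_variables(path: str) -> List[Dict[str, Any]]:
    with open(path) as f:
        return json.load(f)


variables = [
    {"name": "laser_radius", "io_type": "input", "range": [0.1, 0.5],
     "default": 0.35, "units": "mm"},
    {"name": "x:y", "io_type": "output", "shape": (50, 50), "color_mode": 0},
]

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "variables.json")
    save_variables(variables, path)
    loaded = load_variables(path)

print(loaded[1]["shape"])  # [50, 50]: the tuple comes back as a list
```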

![Program structure](online_model/files/program_structure.png)