# Test xarray and h5netcdf

In [23]:
from pydantic import BaseModel
import numpy as np
from typing import List, Tuple
import xarray as xr

# Set up basic parameter classes with pint units (pydantic's BaseModel is imported but not yet used to enforce fields)

In [24]:
import numpy
import pandas as pd
import pint

# Single pint unit registry shared by every quantity built in this notebook;
# Parameter.from_variable also parses saved unit strings through it.
ureg = pint.UnitRegistry()


class Parameter:
    """A named model parameter whose value, default and range are pint quantities.

    ``default`` and the ``range`` bounds are converted to the units of
    ``value`` so that every quantity stored on the instance shares one unit.
    Subclasses implement ``to_variable`` to export the parameter to xarray.
    """

    def __init__(self, *,
                 name: str,
                 value: pint.Quantity,
                 default: pint.Quantity,
                 range: List[pint.Quantity],
                 dim: Tuple[str, ...]
                 ):
        """Store the parameter, normalising default and range to ``value``'s units.

        Args:
            name: Parameter name (callers use it as the xarray variable name).
            value: Current value; its units become ``self.units``.
            default: Default value; converted to ``value.units``.
            range: ``[min, max]`` bounds, given either as a list of quantities
                or as a quantity array. Every bound is converted to
                ``value.units``; the caller's object is left untouched.
            dim: xarray dimension names for the value (one per axis).

        Raises:
            pint.DimensionalityError: if ``default`` or a bound has units that
                cannot be converted to ``value.units`` (raised by ``Quantity.to``).
        """
        units = value.units
        self.default = default.to(units)
        # Build a fresh list instead of assigning into ``range``: in-place
        # assignment mutated the caller's argument, and when ``range`` is a
        # pint quantity array, item assignment converts the new element back
        # into the array's own units, so the conversion was silently lost.
        self.range = [bound.to(units) for bound in range]
        self.name = name
        self.value = value
        self.units = units
        self.dim = dim

    @classmethod
    def from_variable(cls, variable):
        """Rebuild a parameter from an xarray object written by ``to_variable``.

        ``variable`` must expose ``name``, ``values``, ``dims`` and ``attrs``
        containing ``"units"``, ``"min"``, ``"max"`` and ``"default"``, e.g. a
        ``DataArray`` taken from a reloaded dataset (``dataset["image"]``).
        """
        attrs = variable.attrs
        units = ureg.parse_expression(attrs["units"])
        return cls(
            # Plain attribute access: a trailing comma here would turn the
            # name into a one-element tuple.
            name=variable.name,
            value=variable.values * units,
            default=attrs["default"] * units,
            range=[attrs["min"] * units, attrs["max"] * units],
            dim=variable.dims,
        )
        


class ScalarParameter(Parameter):
    """Parameter holding a single scalar quantity, stored as a length-1 variable."""

    dtype = "numpy.float64"

    def to_variable(self):
        """Return an ``xr.Variable`` along ``self.dim`` holding the magnitude.

        The variable carries ``min``, ``max``, ``default`` (magnitudes in
        ``self.units``) and ``units`` (string form of ``self.units``) attrs.
        """
        # atleast_1d wraps a scalar magnitude exactly like np.array([x]) did,
        # but leaves an already 1-D magnitude alone. A parameter rebuilt with
        # from_variable (whose value is a length-1 array) therefore round-trips
        # instead of gaining an extra axis that no longer matches self.dim.
        value = np.atleast_1d(self.value.magnitude)
        attrs = {
            "min": self.range[0].magnitude,
            "max": self.range[1].magnitude,
            "default": self.default.magnitude,
            "units": str(self.units),
        }
        return xr.Variable(self.dim, value, attrs=attrs)


        
    
        
class ImageParameter(Parameter):
    """Parameter holding an array-valued (image) quantity."""

    dtype = "numpy.ndarray"

    def to_variable(self):
        """Return an ``xr.Variable`` along ``self.dim`` holding the image magnitudes.

        The variable carries ``min``, ``max``, ``default`` (magnitudes in
        ``self.units``) and ``units`` (string form of ``self.units``) attrs.
        """
        # Read this instance's value; the old code read the notebook-global
        # ``parameter``, which only worked when called from a loop variable of
        # that name. Pass the bare magnitude: handing a pint Quantity to
        # xarray makes pint strip the units with a warning.
        value = self.value.magnitude
        attrs = {
            "min": self.range[0].magnitude,
            "max": self.range[1].magnitude,
            "default": self.default.magnitude,
            "units": str(self.units),
        }
        return xr.Variable(self.dim, value, attrs=attrs)
        

# Build the parameters

In [25]:
# Build one Parameter per model input. Each value/default/range is a pint
# quantity made by multiplying a magnitude by a unit quantity parsed with ureg.
laser_units = ureg.parse_expression("mm")
laser_radius = ScalarParameter(
    name="laser_radius", 
    range=[1.000000e-01 * laser_units, 5.000000e-01 * laser_units], 
    default=3.47986980e-01 * laser_units,
    value=3.47986980e-01 * laser_units,
    dim = ("length",)
)


# set up maxb(2) (tesla)
b_units = ureg.parse_expression("T")
max_b = ScalarParameter(
    name="maxb(2)", 
    range = np.array([0.000000e+00, 1.000000e-01]) * b_units,
    default = 4.02751972e-02 * b_units,
    value = 4.02751972e-02 * b_units,
    dim = ("field_strength",)
)


# set up phi(1) (degrees)
phi_units = ureg.parse_expression("degree")
phi_1 = ScalarParameter(
    name="phi(1)", 
    range = np.array([-1.000000e+01, 1.000000e+01]) * phi_units, 
    default = -7.99101687e00 * phi_units,
    value = -7.99101687e00 * phi_units,
    dim = ("phi",)
)

# set up total charge (picocoulombs)
charge_units = ureg.parse_expression("pC")
total_charge = ScalarParameter(
    name="total_charge:value", 
    range=np.array([0.000000e+00, 3.000000e+02])*charge_units, 
    default=1.41576322e02 * charge_units,
    value=1.41576322e02 * charge_units,
    dim = ("charge",)
)

# set up xmin (metres)
x_units = ureg.parse_expression("m")
x_min = ScalarParameter(
    name="xmin", 
    range=np.array([-4.216000e-04, 3.977000e-04])*x_units, 
    default = -3.53964583e-04 * x_units,
    value = -3.53964583e-04 * x_units,
    dim = ("length",)
)

# set up x max (same range as xmin)
x_max = ScalarParameter(
    name="xmax", 
    range=np.array([-4.216000e-04, 3.977000e-04])*x_units, 
    default = 3.44330666e-04*x_units,
    value = 3.44330666e-04*x_units,
    dim = ("length",)
)

# set up y min (metres)
# NOTE(review): the y range (about +/-0.11 m) is ~300x wider than the x range
# (about +/-4e-4 m) while the y defaults are ~3.5e-4 m; confirm the exponent
# (e-01 vs e-04) is intended.
y_units = ureg.parse_expression("m")
y_min = ScalarParameter(
    name="ymin", 
    range = np.array([-1.117627e-01, 1.120053e-01]) * y_units,
    default = -3.47874295e-04 * y_units,
    value = -3.47874295e-04 * y_units,
    dim = ("length",)
)

# set up y max (same range as ymin; see the note above it)
y_max = ScalarParameter(
    name = "ymax", 
    range = np.array([-1.117627e-01, 1.120053e-01]) * y_units, 
    default = 3.45778376e-04 * y_units,
    value = 3.45778376e-04 * y_units,
    dim = ("length",)
)


# image input; the path is resolved relative to the current working directory
default_image_array = np.load("online_model/files/example_input_image.npy")

# The saved array is not stored in the 50x50 image shape; reshape it
# (reshape raises if the file does not hold exactly 2500 values).
default_image_array = default_image_array.reshape((50,50))
# NOTE(review): x_units/y_units is m/m, i.e. dimensionless; presumably the image
# is meant to be a unitless intensity -- confirm.
image = ImageParameter(
    name="image", 
    range= np.array([0.0, 9.0]) * x_units/y_units,
    default = default_image_array * x_units/y_units,
    value = default_image_array * x_units/y_units,
    dim =  ("x", "y", )
)

# All parameters that are exported to the dataset, in order.
parameter_list = [laser_radius, max_b, phi_1, total_charge, x_min, x_max, y_min, y_max, image]

# Create xarray variables from the Parameters

In [26]:
# Convert every Parameter to an xarray Variable, keyed by parameter name;
# the keys become the variable names of the Dataset built from this dict.
# NOTE(review): kept as a plain module-level for-loop so that `parameter` stays
# bound globally while to_variable() runs -- ImageParameter.to_variable has
# read that global instead of `self`; confirm before turning this into a
# comprehension.
variables = {}

for parameter in parameter_list:
    variables[parameter.name] = parameter.to_variable()

# Create dataset from variables

In [27]:
# Assemble the exported variables into one Dataset, then inspect a stored bound.
dset = xr.Dataset(data_vars=variables)
dset["laser_radius"].attrs["min"]

0.1

In [28]:
# Units are stored as a plain string attribute on each variable.
dset["laser_radius"].attrs["units"]

'millimeter'

In [29]:
# Close the dataset opened by a previous run of the reload cell (if any)
# before overwriting save_test.h5, presumably so the open file handle does not
# get in the way of the write. On a fresh top-to-bottom run `reloaded` does not
# exist yet, and the unguarded close() raised NameError.
try:
    reloaded.close()
except NameError:
    pass
dset.to_netcdf('save_test.h5', engine='h5netcdf')

  return array(a, dtype, copy=False, order=order)


In [30]:
# Reopen the file just written, with CF decoding enabled.
reloaded = xr.open_dataset(
    "save_test.h5",
    engine="h5netcdf",
    decode_cf=True,
)


In [31]:
# "image" was created as an ImageParameter (a 2-D array over ("x", "y")), so
# rebuild it with that class; ScalarParameter is meant for single values.
parameter = ImageParameter.from_variable(reloaded["image"])

In [33]:
# Display the reloaded image variable.
reloaded.image