In [None]:
from estival.utils import sample
import pandas as pd
import numpy as np

In [None]:
# Short aliases for the estival sampling helpers referenced by the cells below.
esamp = sample.esamptools
SampleIterator = esamp.SampleIterator

In [None]:
# Seeded generator so the demo frame is identical on every Restart & Run All.
rng = np.random.default_rng(0)
# Two columns of 16 standard-normal draws each ("x" is drawn before "y").
df = pd.DataFrame({"x": rng.normal(size=(16,)), "y": rng.normal(size=(16,))})

In [None]:
si = esamp.validate_samplecontainer(df)

In [None]:
si.convert(esamp.SampleTypes.LIST_OF_DICTS)

In [None]:
class SampleFormatConverter:
    """Base class describing how one named sample format maps to/from SampleIterator.

    Subclasses override ``is_format`` plus whichever of ``to_sampleiterator`` /
    ``from_sampleiterator`` they support; anything left un-overridden raises
    ``NotImplementedError`` when called.

    Attributes:
        format_name: Name under which the format is identified.
        _convert_to: Maps a destination format name to a callable converting
            an object of *this* format into that destination.
        _convert_from: Maps a source format name to a callable converting an
            object of that source format into *this* format.
    """

    def __init__(self, format_name: str):
        self.format_name = format_name
        # Bound methods are looked up on ``self``, so a subclass override is
        # what ends up stored in these tables.
        self._convert_to = dict(SampleIterator=self.to_sampleiterator)
        self._convert_from = dict(SampleIterator=self.from_sampleiterator)

    def is_format(self, obj):
        """Tell whether ``obj`` is in this converter's format (abstract)."""
        raise NotImplementedError()

    def to_sampleiterator(self, obj):
        """Convert ``obj`` from this format into a SampleIterator (abstract)."""
        raise NotImplementedError()

    def from_sampleiterator(self, si):
        """Convert the SampleIterator ``si`` into this format (abstract)."""
        raise NotImplementedError()
        

In [None]:
class IndexSampleListConverter(SampleFormatConverter):
    """Converter for a list of ``(index, sample_dict)`` tuples ("ListOfIndexSample").

    In addition to the SampleIterator round trip, a direct shortcut to the
    "ListOfDicts" format is registered in ``_convert_to``.
    """

    def __init__(self):
        super().__init__("ListOfIndexSample")
        # Dropping the index labels needs no intermediate SampleIterator.
        self._convert_to["ListOfDicts"] = self.to_lod

    def is_format(self, obj):
        """Return True if ``obj`` is a non-empty list of ``(index, dict)`` tuples.

        Only the first element is inspected. An empty list returns False
        instead of raising IndexError, since its format cannot be determined.
        """
        if not isinstance(obj, list) or not obj:
            return False
        ref_sample = obj[0]
        return (
            isinstance(ref_sample, tuple)
            and len(ref_sample) == 2
            and isinstance(ref_sample[1], dict)
        )

    def to_sampleiterator(self, obj):
        """Build a SampleIterator from the sample dicts, indexed by the tuple keys."""
        lod = [v for k, v in obj]
        index = pd.Index([k for k, v in obj])
        si = esamp._lod_to_si(lod)
        # Return value is discarded; presumably set_index updates ``si`` in place.
        si.set_index(index)
        return si

    def from_sampleiterator(self, obj):
        """Return ``[(index, sample), ...]`` from the iterator's ``iterrows()``."""
        return [(k, v) for k, v in obj.iterrows()]

    def to_lod(self, obj):
        """Strip the index labels, returning just the list of sample dicts."""
        return [v for (k, v) in obj]

In [None]:
class ListOfDictsConverter(SampleFormatConverter):
    """Converter for a plain list of sample dicts ("ListOfDicts")."""

    def __init__(self):
        super().__init__("ListOfDicts")

    def is_format(self, obj):
        """Return True if ``obj`` is a non-empty list whose first item is a dict.

        Only the first element is inspected. An empty list returns False
        instead of raising IndexError, since its format cannot be determined.
        """
        if not isinstance(obj, list) or not obj:
            return False
        return isinstance(obj[0], dict)

    def to_sampleiterator(self, obj):
        """Delegate to ``esamp._lod_to_si``."""
        return esamp._lod_to_si(obj)

    def from_sampleiterator(self, si):
        """Return the sample of each ``iterrows()`` pair, dropping the index label."""
        return [v for k, v in si.iterrows()]
        

In [None]:
def dataframe_to_sampleiterator(in_data: pd.DataFrame):
    """Wrap the columns of ``in_data`` in a SampleIterator sharing the frame's index.

    Each column is converted with ``to_numpy()`` and stored under its column
    label.
    """
    column_arrays = {}
    for label in in_data.columns:
        column_arrays[label] = in_data[label].to_numpy()  # type: ignore
    return SampleIterator(column_arrays, index=in_data.index)

In [None]:
class DataFrameConverter(SampleFormatConverter):
    """Converter between pandas DataFrames and SampleIterator ("DataFrame")."""

    def __init__(self):
        super().__init__("DataFrame")

    def is_format(self, obj):
        """Return True when ``obj`` is a pandas DataFrame."""
        matches = isinstance(obj, pd.DataFrame)
        return matches

    def to_sampleiterator(self, obj):
        """Delegate to ``dataframe_to_sampleiterator``."""
        return dataframe_to_sampleiterator(obj)

    def from_sampleiterator(self, obj):
        """Build a DataFrame from the iterator's "list_of_dicts" form, reusing its index."""
        records = obj.convert("list_of_dicts")
        return pd.DataFrame(records, index=obj.index)

In [None]:
esamp.xarray_to_sampleiterator?

In [None]:
import xarray
import arviz

In [None]:
class InferenceDataConverter(SampleFormatConverter):
    """One-way converter from arviz InferenceData to SampleIterator.

    ``from_sampleiterator`` is not overridden, so converting *to* the
    "InferenceData" format raises NotImplementedError from the base class.
    """

    def __init__(self):
        super().__init__("InferenceData")

    def is_format(self, obj):
        """Return True when ``obj`` is an ``arviz.InferenceData``."""
        matches = isinstance(obj, arviz.InferenceData)
        return matches

    def to_sampleiterator(self, obj):
        """Delegate to ``esamp.idata_to_sampleiterator``."""
        converted = esamp.idata_to_sampleiterator(obj)
        return converted

In [None]:
class XArrayConverter(SampleFormatConverter):
    """One-way converter from ``xarray.Dataset`` to SampleIterator.

    ``from_sampleiterator`` is not overridden, so converting *to* the
    "XArrayDataset" format raises NotImplementedError from the base class.
    """

    def __init__(self):
        super().__init__("XArrayDataset")

    def is_format(self, obj):
        """Return True when ``obj`` is an ``xarray.Dataset``."""
        matches = isinstance(obj, xarray.Dataset)
        return matches

    def to_sampleiterator(self, obj):
        """Delegate to ``esamp.xarray_to_sampleiterator``."""
        converted = esamp.xarray_to_sampleiterator(obj)
        return converted

In [None]:
class SampleIteratorConverter(SampleFormatConverter):
    """Identity converter for the "SampleIterator" format itself.

    Both conversion directions hand back the very object they were given.
    """

    def __init__(self):
        super().__init__("SampleIterator")

    def is_format(self, obj):
        """Return True when ``obj`` is a SampleIterator instance."""
        matches = isinstance(obj, SampleIterator)
        return matches

    def from_sampleiterator(self, obj):
        """Return ``obj`` unchanged."""
        return obj

    def to_sampleiterator(self, obj):
        """Return ``obj`` unchanged."""
        return obj

In [None]:
import pymc as pm

# Minimal one-variable model, sampled to produce `idata`, which later cells
# pass to `fm.convert` and `arviz.extract`.
with pm.Model() as model:
    x = pm.Uniform('x')
    # Fixed seed so the draws are reproducible across Restart & Run All.
    idata = pm.sample(random_seed=0)

In [None]:
# NOTE(review): `fm` is first assigned in a later cell (`fm = FormatManager()`),
# so this cell fails on a fresh kernel when run top to bottom. The same
# conversion is repeated after that cell as `idata_si = fm.convert(idata)`.
fm.convert(idata)

In [None]:
class FormatManager:
    """Registry of format converters with automatic source-format detection.

    Converters are stored by ``format_name`` in registration order.
    ``get_format`` returns the first registered converter whose ``is_format``
    accepts the object, so registration order decides ties.
    """

    def __init__(self):
        self.converters = {}
        # The SampleIterator format is the hub every other format routes through.
        self.register(SampleIteratorConverter())

    def register(self, format_converter):
        """Add ``format_converter`` under its ``format_name``.

        Raises:
            KeyError: if a converter with the same name is already registered.
        """
        name = format_converter.format_name
        if name in self.converters:
            raise KeyError(name, "format already exists")
        self.converters[name] = format_converter

    def get_format(self, obj):
        """Return the name of the first registered format that accepts ``obj``.

        Raises:
            TypeError: if no registered converter recognises ``obj``.
        """
        for name, converter in self.converters.items():
            if converter.is_format(obj):
                return name
        raise TypeError("Unknown format")

    def get_converter(self, obj):
        """Return the converter instance for the detected format of ``obj``."""
        return self.converters[self.get_format(obj)]

    def convert(self, obj, dest_format="SampleIterator"):
        """Convert ``obj`` from its detected format into ``dest_format``.

        Resolution order:
          1. ``obj`` is already in ``dest_format``: returned unchanged.
          2. The source converter has a direct ``_convert_to`` entry.
          3. The destination converter has a direct ``_convert_from`` entry.
          4. Otherwise route through SampleIterator.

        Raises:
            TypeError: if the format of ``obj`` is not recognised.
            KeyError: if ``dest_format`` is not a registered format name.
            NotImplementedError: if a required conversion direction is not
                implemented by the converter involved.
        """
        src_fmt = self.get_format(obj)

        if src_fmt == dest_format:
            return obj

        # Reuse the name detected above; get_converter(obj) would run every
        # is_format check a second time.
        in_converter = self.converters[src_fmt]
        try:
            out_converter = self.converters[dest_format]
        except KeyError:
            raise KeyError(
                dest_format,
                f"unknown destination format; registered formats: {list(self.converters)}",
            ) from None

        if convert_to := in_converter._convert_to.get(dest_format):
            return convert_to(obj)
        elif convert_from := out_converter._convert_from.get(src_fmt):
            return convert_from(obj)

        si = in_converter.to_sampleiterator(obj)
        return out_converter.from_sampleiterator(si)

In [None]:
# Build the registry. FormatManager() registers the SampleIterator converter
# itself; the rest are added here. Order matters: get_format returns the
# first registered converter whose is_format accepts the object.
fm = FormatManager()
fm.register(IndexSampleListConverter())
fm.register(ListOfDictsConverter())
fm.register(DataFrameConverter())
fm.register(XArrayConverter())
fm.register(InferenceDataConverter())

# Small hand-written fixtures, one per format exercised below.
# isl: list of (index label, sample dict) tuples.
isl = [(5, {"x": 0.5, "y": 1.2}), (7, {"x": 0.7, "y": 0.3})]
# lod: plain list of sample dicts.
lod = [{"x": 0.5, "y": 1.2}, {"x": 0.7, "y": 0.3}]
# si: two samples with scalar components (rebinds the `si` assigned in an earlier cell).
si = SampleIterator({"x": np.array((0.0,1.2)), "y": np.array((0.7,0.3))})
# sims: two samples where "y" has shape (2, 3), i.e. a non-scalar component.
sims = SampleIterator({"x": np.array((0.0,1.2)), "y": np.array(((0.7,0.3,0.1),(0.9,0.2,0.4)))})

In [None]:
idata_si = fm.convert(idata)

In [None]:
subset = arviz.extract(idata, num_samples=20)

In [None]:
subset_si = fm.convert(subset)

In [None]:
subset_si

In [None]:
idata["posterior"]

In [None]:
subset

In [None]:
sims.set_index(pd.MultiIndex.from_product([pd.Index([0], name="chain"), pd.Index([0,1],name="draw")]))

In [None]:
df= fm.convert(sims, "DataFrame")

In [None]:
sims.loc[:,1]

In [None]:
def to_xarray(si):
    """Assemble an ``xarray.Dataset`` holding one variable per component of ``si``.

    The leading axis of every component is named "sample" and takes its
    coordinate values from ``si.index``. Any trailing axes are named
    ``<component>_dim_<i>``, numbered from 0.
    """
    dataset = xarray.Dataset()
    for name, values in si.components.items():
        n_trailing = len(values.shape) - 1
        dim_names = ["sample"]
        for axis in range(n_trailing):
            dim_names.append(f"{name}_dim_{axis}")
        dataset[name] = xarray.DataArray(
            values, coords={"sample": si.index}, dims=dim_names
        )
    return dataset

In [None]:
to_xarray(sims)

In [None]:
# np.linspace requires an explicit `stop`; the original call omitted it and
# raised TypeError. The [0, 1] range is a placeholder for this scratch example.
xarray.DataArray(np.linspace(0.0, 1.0))

In [None]:
xarray.DataArray(np.random.normal(size=(256,4)))

In [None]:
xarray.DataArray(si.components["x"], coords={"sample": si.index})

In [None]:
from itertools import product

In [None]:
# Round-trip the `si` fixture through every (source, destination) format pair
# and report which format the result is detected as.
for src, dst in product(fm.converters, repeat=2):
    print(src, dst)
    try:
        out = fm.convert(fm.convert(si, src), dst)
    except NotImplementedError:
        # Converters that only implement to_sampleiterator cannot be used as
        # a destination; report the gap instead of aborting the whole sweep.
        print("not implemented")
        continue
    print(fm.get_format(out))

In [None]:
fm.converters

In [None]:
fm.converters

In [None]:
# Destination must be a registered format_name; "sample_iterator" is not one.
fm.convert(lod, "SampleIterator")

In [None]:
# Chain ListOfDicts -> SampleIterator -> ListOfIndexSample -> DataFrame using
# the registered format_name values (the previous names were not registered).
fm.convert(fm.convert(fm.convert(lod, "SampleIterator"), "ListOfIndexSample"), "DataFrame")

In [None]:
fm.converters

In [None]:
# NOTE(review): neither `islc` nor `index_sample_list` is assigned in any cell
# visible in this notebook — presumably an IndexSampleListConverter instance
# and a list shaped like `isl`; confirm, otherwise this fails on a fresh kernel.
si = islc.to_sampleiterator(index_sample_list)

In [None]:
[t for t in si]

In [None]:
fm.get_format(lod)

In [None]:
islc.to_sampleiterator(index_sample_list)

In [None]:
islc.is_format(index_sample_list)

In [None]:
# NOTE(review): `c` is not assigned in any cell visible in this notebook —
# presumably a converter instance; confirm, otherwise this fails on a fresh kernel.
c.is_format([(2, 5)])