Skip to content

Commit

Permalink
Samplingtools new features
Browse files Browse the repository at this point in the history
- All classes Initialization with optional parameters that are passed to ``set_param``
- Samplingplanner method ``product`` to create cartesian product (grid) of input variables to create test cases
  • Loading branch information
Felix-Mac committed Sep 14, 2022
1 parent 22c9644 commit 47029a2
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 5 deletions.
11 changes: 8 additions & 3 deletions do_mpc/sampling/datahandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ class DataHandler:
The list of all samples originates from :py:class:`do_mpc.sampling.samplingplanner.SamplingPlanner` and is used to
initiate this class (``sampling_plan``).
The class can be created with optional keyword arguments which are passed to :py:meth:`set_param`.
**Configuration and retrieving processed data:**
1. Initiate the object with the ``sampling_plan`` originating from :py:class:`do_mpc.sampling.samplingplanner.SamplingPlanner`.
Expand Down Expand Up @@ -67,7 +69,7 @@ def sample_function(alpha, beta):
dh[:]
"""
def __init__(self, sampling_plan):
def __init__(self, sampling_plan, **kwargs):
self.flags = {
'set_post_processing' : False,
}
Expand All @@ -89,6 +91,9 @@ def __init__(self, sampling_plan):

self.pre_loaded_data = {'id':[], 'data':[]}

if kwargs:
self.set_param(**kwargs)


@property
def data_dir(self):
Expand Down Expand Up @@ -193,13 +198,13 @@ def filter(self, input_filter=None, output_filter=None):
:param filter_fun: Function to filter the data.
:type filter: Function or BuiltinFunction_or_Method
:type filter_fun: Function or BuiltinFunction_or_Method
:raises assertion: No post processing function is set
:raises assertion: filter_fun must be either Function of BuiltinFunction_or_Method
:return: Returns the post processed samples that satisfy the filter
:rtype: dict
:rtype: list
"""
assert isinstance(input_filter, (types.FunctionType, types.BuiltinFunctionType, type(None))), 'input_filter must be either Function or BuiltinFunction_or_Method, you have {}'.format(type(input_filter))
assert isinstance(output_filter, (types.FunctionType, types.BuiltinFunctionType, type(None))), 'output_filter must be either Function or BuiltinFunction_or_Method, you have {}'.format(type(output_filter))
Expand Down
9 changes: 8 additions & 1 deletion do_mpc/sampling/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ class Sampler:
Initiate the class by passing a :py:class:`do_mpc.sampling.samplingplanner.SamplingPlanner` (``sampling_plan``) object.
The class can be configured to create samples based on the defined cases in the ``sampling_plan``.
The class can be created with optional keyword arguments which are passed to :py:meth:`set_param`.
**Configuration and sampling:**
1. (Optional) use :py:meth:`set_param` to configure the class. Use :py:attr:`data_dir` to choose the save location for the samples.
Expand Down Expand Up @@ -54,7 +56,7 @@ def sample_function(alpha, beta):
sampler.sample_data()
"""
def __init__(self, sampling_plan):
def __init__(self, sampling_plan, **kwargs):
assert isinstance(sampling_plan, list), 'sampling_plan must be a list'
assert np.all([isinstance(plan_i, dict) for plan_i in sampling_plan]), 'All elements of sampling plan must be a dictionary.'

Expand Down Expand Up @@ -82,6 +84,11 @@ def __init__(self, sampling_plan):
self.print_progress = True
self.n_processes = 1

if kwargs:
self.set_param(**kwargs)



@property
def data_dir(self):
"""Set the save directory for the results.
Expand Down
45 changes: 44 additions & 1 deletion do_mpc/sampling/samplingplanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@
import pdb
import scipy.io as sio
import copy
import itertools
from do_mpc.tools import load_pickle, save_pickle


class SamplingPlanner:
"""A class for generating sampling plans.
These sampling plans will be executed by :py:class:`do_mpc.sampling.sampler.Sampler` to generate data.
The class can be created with optional keyword arguments which are passed to :py:meth:`set_param`.
**Configuration and sampling plan generation:**
1. Set variables which should be sampled with :py:func:`set_sampling_var`.
Expand All @@ -26,7 +29,7 @@ class SamplingPlanner:
5. Export the plan with all sampling cases with :py:meth:`export`
"""
def __init__(self):
def __init__(self, **kwargs):
self.sampling_vars = []
self.sampling_var_names = []
self.sampling_plan = []
Expand All @@ -41,6 +44,9 @@ def __init__(self):
self.overwrite = False
self.id_precision = 3

if kwargs:
self.set_param(**kwargs)

@property
def data_dir(self):
"""Set the save directory for the ``samplingplan``.
Expand Down Expand Up @@ -213,6 +219,43 @@ def gen_sampling_plan(self, n_samples):

return self.sampling_plan

def product(self, **kwargs):
"""Cartesian product of input variables.
This method is inspired by `itertools.product <https://docs.python.org/3/library/itertools.html#itertools.product>`_.
Must pass a list for each ``sampling_var`` that should be considered. Not all ``sampling_vars`` must be referenced.
Sampling vars that are excluded, will generate a value according to their assigned ``fun_var_pdf`` (see :py:meth:`set_sampling_var`).
:param kwargs: Keyword arguments of the form ``var_name=var_values``.
:type kwargs: dict
:return: None
:rtype: NoneType
"""
# Check if all key word values are lists:
check = np.alltrue([isinstance(v, list) for v in kwargs.values()])
if not check:
raise ValueError('keyword values must be lists')

# Check if all key words are existing sampling variables:
keys = kwargs.keys()
check = np.alltrue([v in self.sampling_var_names for v in keys])
if not check:
raise ValueError('keyword names must be existing sampling variables')

# Create cartesian product of all values passen in kwargs:
values = list(itertools.product(*list(kwargs.values())))

# Create new sampling cases:
for value in values:
# Zip together the value(s) of the current case with the respective keys:
case = dict(zip(keys, value))
# Add sampling case
self.add_sampling_case(**case)


return self.sampling_plan



def export(self, sampling_plan_name):
"""Export SamplingPlan in pickle format.
Expand Down

0 comments on commit 47029a2

Please sign in to comment.