Factor out distribution logic #59

Closed
beasteers opened this issue Apr 28, 2019 · 3 comments

@beasteers
Contributor

beasteers commented Apr 28, 2019

This isn't a high-priority issue at all, and I'm not suggesting we implement it any time soon; it's just something I've had on my mind for a while and wanted to put on paper. The idea is to factor out all of the distribution and event parameter validation so that it's a bit cleaner and easier to add new distributions. Here's a rough sketch of what I was thinking. There are obviously some things left to figure out, but I think it could simplify the Scaper core logic nicely.

Distributions

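import os
import random
import warnings

# ScaperError and ScaperWarning are Scaper's existing exception and warning classes.

def _validate_value(spec, value):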
    if value is None:
        # None is only acceptable when the spec explicitly allows it.
        if spec.get('can_be_none'):
            return
        raise ScaperError('Value for parameter {} cannot be None.'.format(spec['name']))
    
    if 'min' in spec and value < spec['min']:
        raise ScaperError('Value {} for parameter {} is below the minimum value {}.'.format(
            value, spec['name'], spec['min']))

    if 'max' in spec and value > spec['max']:
        raise ScaperError('Value {} for parameter {} exceeded maximum value {}.'.format(
            value, spec['name'], spec['max']))

    # raise when the file's existence does not match what the spec requires
    if 'is_file' in spec and os.path.isfile(value) != spec['is_file']:
        raise ScaperError('Value {} for parameter {} failed the file check (expected is_file={}).'.format(
            value, spec['name'], spec['is_file']))

    if spec.get('allowed_choices') is not None and value not in spec['allowed_choices']:
        raise ScaperError('Value {} for parameter {} not in available values: {}'.format(
            value, spec['name'], spec['allowed_choices']))

    ... # a whole suite of possible tests

class Distributions:
    '''Distribution Factory'''
    available = {}

    @classmethod
    def register(cls, distribution):
        cls.available[distribution.name] = distribution
        return distribution  # return the class so register works as a decorator

    @classmethod
    def from_tuple(cls, dist_tuple):
        return cls.available[dist_tuple[0]](*dist_tuple[1:])


class Distribution:
    def __init__(self):
        raise NotImplementedError

    def validate(self, spec):
        raise NotImplementedError

    def __call__(self):
        raise NotImplementedError

@Distributions.register
class Const(Distribution):
    name = 'const'
    
    def __init__(self, value):
        self.value = value

    def validate(self, spec):
        _validate_value(spec, self.value)

    def __call__(self):
        return self.value
    
@Distributions.register
class Choose(Distribution):
    name = 'choose'
    
    def __init__(self, choices):
        self.choices = choices

    def validate(self, spec):
        for choice in self.choices:
            _validate_value(spec, choice)
    
    def __call__(self):
        return random.choice(self.choices)

@Distributions.register
class Uniform(Distribution):
    name = 'uniform'
    
    def __init__(self, vmin, vmax):
        self.min = vmin
        self.max = vmax

    def validate(self, spec):
        _validate_value(spec, self.min)
        _validate_value(spec, self.max)

    def __call__(self):
        return random.uniform(self.min, self.max)

@Distributions.register
class Normal(Distribution):
    name = 'normal'

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def validate(self, spec):
        # spec is a dict; compare against None so that a bound of 0 still counts
        if spec.get('min') is not None or spec.get('max') is not None:
            warnings.warn(
                'A "normal" distribution tuple for {} can result in '
                'non-positive values, in which case the distribution will be '
                're-sampled until a positive value is returned: this can result '
                'in an infinite loop!'.format(spec['name']),
                ScaperWarning)

    def __call__(self):
        return random.normalvariate(self.mean, self.std)

@Distributions.register
class Truncnorm(Distribution):
    name = 'truncnorm'
    
    def __init__(self, mean, std, vmin, vmax):
        self.mean = mean
        self.std = std
        self.min = vmin
        self.max = vmax

    def validate(self, spec):
        _validate_value(spec, self.min)
        _validate_value(spec, self.max)

    def __call__(self):
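        # NOTE: clipping piles probability mass at the bounds; a true truncated
        # normal would resample or use something like scipy.stats.truncnorm.
        x = random.normalvariate(self.mean, self.std)
        x = max(x, self.min) if self.min is not None else x
        x = min(x, self.max) if self.max is not None else x
        return x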
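
As a rough, standalone illustration of how the registry idea above might be exercised, here is a minimal sketch. It re-declares a trimmed-down `Distributions`, `Const`, and `Uniform` so it runs on its own; the unknown-name check and the use of `ValueError` in place of `ScaperError` are assumptions added for the example, not part of the sketch above.

```python
import random


class Distributions:
    """Registry mapping a distribution name to its class (illustrative)."""
    available = {}

    @classmethod
    def register(cls, distribution):
        cls.available[distribution.name] = distribution
        return distribution  # returning the class keeps the decorated name usable

    @classmethod
    def from_tuple(cls, dist_tuple):
        name, *args = dist_tuple
        if name not in cls.available:
            # assumption: ValueError stands in for ScaperError here
            raise ValueError('Unknown distribution: {!r}'.format(name))
        return cls.available[name](*args)


@Distributions.register
class Const:
    name = 'const'

    def __init__(self, value):
        self.value = value

    def __call__(self):
        return self.value


@Distributions.register
class Uniform:
    name = 'uniform'

    def __init__(self, vmin, vmax):
        self.min, self.max = vmin, vmax

    def __call__(self):
        return random.uniform(self.min, self.max)


# a distribution tuple is dispatched on its first element
const = Distributions.from_tuple(('const', 3))
uniform = Distributions.from_tuple(('uniform', 0.0, 10.0))
print(const())
print(0.0 <= uniform() <= 10.0)
```

With this shape, adding a new distribution would only mean writing one decorated class; the dispatch code itself would not need to change.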

Event Spec

# default_event_validation_spec = dict(
#     min=None, max=None,
#     is_real=None, is_file=None,
#     allowed_distributions=None,
#     allowed_choices=None,
#     can_be_none=None
# )

event_validation_spec = {
    'label': dict(allowed_distributions={'const', 'choose'},
                  allowed_choices=()),
    'source_file': dict(is_file=True),
    'time': dict(min=0),
    'duration': dict(min=0, is_real=True),
    'snr': dict(is_real=True),
    'pitch_shift': dict(can_be_none=True, is_real=True),
    'time_stretch': dict(can_be_none=True, is_real=True, min=0),
}

# TODO: figure out how to pass allowed_choices

# add in name as a field (for error reporting)
for name, spec in event_validation_spec.items():
    spec['name'] = name


def sample_event_parameter(name, dist_tuple, **kw):
    # get the validation spec for the event parameter; keyword arguments
    # override individual spec fields for this call (e.g. allowed_choices)
    spec = dict(event_validation_spec[name], **kw)

    # make sure the distribution is valid for this parameter.
    if 'allowed_distributions' in spec and dist_tuple[0] not in spec['allowed_distributions']:
        raise ScaperError('Invalid distribution {} for parameter {}.'.format(
            dist_tuple[0], spec['name']))

    # create, validate, and sample from the distribution
    dist = Distributions.from_tuple(dist_tuple)
    dist.validate(spec)
    return dist()

def sample_event_spec(event_spec):
    return {
        name: sample_event_parameter(name, dist_tuple)
        for name, dist_tuple in zip(event_spec._fields, event_spec)
    }
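
One possible answer to the `allowed_choices` TODO, sketched under assumptions: keep the static spec free of call-time data and merge the allowed labels in per call. The names `ValidationError`, `EVENT_VALIDATION_SPEC`, `resolve_spec`, and `check_choice` are hypothetical and exist only for this example; `ValidationError` stands in for `ScaperError` so the snippet runs without Scaper installed.

```python
class ValidationError(Exception):
    """Hypothetical stand-in for ScaperError."""


# static part of the spec; allowed_choices is left out because it depends on
# the labels that are available at call time
EVENT_VALIDATION_SPEC = {
    'label': {'name': 'label', 'allowed_distributions': {'const', 'choose'}},
    'time': {'name': 'time', 'min': 0},
}


def resolve_spec(name, **overrides):
    """Merge per-call overrides into a copy of the static spec."""
    return dict(EVENT_VALIDATION_SPEC[name], **overrides)


def check_choice(spec, value):
    """Raise if the spec restricts choices and value is not one of them."""
    allowed = spec.get('allowed_choices')
    if allowed is not None and value not in allowed:
        raise ValidationError(
            'Value {} for parameter {} not in available values: {}'.format(
                value, spec['name'], allowed))


spec = resolve_spec('label', allowed_choices=('dog', 'siren'))
check_choice(spec, 'dog')  # passes silently
try:
    check_choice(spec, 'cat')
except ValidationError as err:
    print(err)
```

Because `resolve_spec` builds a new dict on every call, the shared spec is never mutated, so different calls can use different label sets without interfering with each other.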
@justinsalamon
Owner

justinsalamon commented Feb 15, 2020

@beasteers I'm not entirely sure what this issue is about: I can see a proposed solution, but I'm not sure I understand the problem.

Some of the distribution logic has been moved around via #54, in case you want to have a look and let me know whether that addresses your concerns.

Otherwise, could you please provide a clear description of the problem, along with an example? No need to provide example code for the solution at this stage. Thanks!

@pseeth
Collaborator

pseeth commented Feb 29, 2020

We changed how this works in #53. Does the current version of Scaper address this issue? @beasteers

@justinsalamon
Owner

Closing this out in the absence of further comments. @beasteers feel free to re-open if you think there are any issues with the new setup, thanks!
