In [None]:
class SnowflakesLayer(meta.Augmenter):
    """Augmenter that draws a single layer of snowflakes onto images.

    Each image is processed independently by :meth:`draw_on_image`, which
    builds a flake mask in the following steps:

    1. Sprinkle salt noise (probability ``density``) onto an all-zero grid
       that is shrunk by the factor ``clip(1 - flake_size, 0.001, 1.0)``.
       After upscaling, each noise pixel therefore covers a larger image
       area for larger ``flake_size`` values.
    2. Attenuate the noise with a coarse gate sampled from
       ``Beta(1.0, 1.0 - density_uniformity)``, so that flakes end up denser
       in some regions than in others. Then upscale to the image size.
    3. Apply a gaussian blur and then a motion blur. The gaussian sigma is
       derived from ``blur_sigma_fraction`` and clipped to
       ``blur_sigma_limits``. The motion blur direction is ``angle``, and
       its kernel size grows with ``speed``.
    4. Apply gamma contrast and a gain controlled by
       ``flake_size_uniformity``. Then blend the mask into the image via a
       clipped sum followed by a per-pixel maximum.

    Parameters
    ----------
    density
        Stored as-is and later passed as ``p`` to ``arithmetic.Salt``.
    density_uniformity
        Continuous parameter in ``[0.0, 1.0]``. ``1.0 - value`` is used as
        the second shape parameter of the gating beta distribution.
    flake_size
        Continuous parameter in ``[1e-4, 1.0]``. Controls how strongly the
        noise grid is downscaled before being upscaled to the image size.
    flake_size_uniformity
        Continuous parameter in ``[0.0, 1.0]``. Lower values increase the
        gamma and the gain applied to the blurred flake mask.
    angle
        Continuous parameter, forwarded as ``angle`` to ``blur.MotionBlur``.
    speed
        Continuous parameter in ``[0.0, 1.0]``. Scales the motion blur
        kernel size (relative to the larger image side) and the strength of
        both blend passes.
    blur_sigma_fraction
        Continuous parameter in ``[0.0, 1.0]``. Gaussian blur sigma as a
        fraction of the larger image side.
    blur_sigma_limits : tuple of number, optional
        ``(min, max)`` bounds used to clip the gaussian blur sigma. The same
        bounds apply to all images.
    seed, name, random_state, deterministic
        Forwarded unchanged to ``meta.Augmenter``.

    """

    def __init__(self, density, density_uniformity, flake_size,
                 flake_size_uniformity, angle, speed, blur_sigma_fraction,
                 blur_sigma_limits=(0.5, 3.75),
                 seed=None, name=None,
                 random_state="deprecated", deterministic="deprecated"):
        super(SnowflakesLayer, self).__init__(
            seed=seed, name=name,
            random_state=random_state, deterministic=deterministic)

        unit_interval = (0.0, 1.0)

        # Not converted here: the raw value is handed to arithmetic.Salt.
        self.density = density
        self.density_uniformity = iap.handle_continuous_param(
            density_uniformity, "density_uniformity",
            value_range=unit_interval)
        # The lower bound is kept slightly above zero.
        self.flake_size = iap.handle_continuous_param(
            flake_size, "flake_size", value_range=(1e-4, 1.0))
        self.flake_size_uniformity = iap.handle_continuous_param(
            flake_size_uniformity, "flake_size_uniformity",
            value_range=unit_interval)
        self.angle = iap.handle_continuous_param(angle, "angle")
        self.speed = iap.handle_continuous_param(
            speed, "speed", value_range=unit_interval)
        self.blur_sigma_fraction = iap.handle_continuous_param(
            blur_sigma_fraction, "blur_sigma_fraction",
            value_range=unit_interval)

        # (min, max) clip bounds for the gaussian sigma, shared by all images.
        self.blur_sigma_limits = blur_sigma_limits

        # (height, width) of the coarse gating noise, shared by all images.
        self.gate_noise_size = (8, 8)

    # Added in 0.4.0.
    def _augment_batch_(self, batch, random_state, parents, hooks):
        """Replace every image of ``batch`` by its snowflake-augmented copy."""
        if batch.images is None:
            return batch

        image_batch = batch.images
        image_rngs = random_state.duplicate(len(image_batch))
        for idx, (img, img_rng) in enumerate(zip(image_batch, image_rngs)):
            batch.images[idx] = self.draw_on_image(img, img_rng)
        return batch

    def get_parameters(self):
        """Return the parameters of this augmenter.

        See :func:`~imgaug.augmenters.meta.Augmenter.get_parameters`.
        """
        return [
            self.density, self.density_uniformity,
            self.flake_size, self.flake_size_uniformity,
            self.angle, self.speed,
            self.blur_sigma_fraction, self.blur_sigma_limits,
            self.gate_noise_size,
        ]

    def draw_on_image(self, image, random_state):
        """Render one snowflake layer onto ``image``.

        Parameters
        ----------
        image : ndarray
            Three-dimensional image whose last axis has size 1 or 3 (both
            asserted). Presumably ``uint8``; TODO confirm against callers.
        random_state
            RNG used to draw all per-image samples and to seed the noise,
            gating and motion-blur steps.

        Returns
        -------
        ndarray
            Blended image as ``uint8``.

        """
        assert image.ndim == 3, (
            "Expected input image to be three-dimensional, "
            "got %d dimensions." % (image.ndim,))
        assert image.shape[2] in [1, 3], (
            "Expected to get image with a channel axis of size 1 or 3, "
            "got %d (shape: %s)" % (image.shape[2], image.shape))

        noise_rng, gate_rng = random_state.duplicate(2)

        # The order of these draws determines which values a given seed
        # produces, so it must not be changed.
        size_value = self.flake_size.draw_sample(random_state)
        uniformity_value = self.flake_size_uniformity.draw_sample(
            random_state)
        angle_value = self.angle.draw_sample(random_state)
        speed_value = self.speed.draw_sample(random_state)
        sigma_fraction = self.blur_sigma_fraction.draw_sample(random_state)

        height, width, nb_channels = image.shape

        # Salt noise on a shrunken grid: the more it is shrunk, the larger
        # each flake becomes once upscaled back to (height, width).
        scale = np.clip(1.0 - size_value, 0.001, 1.0)
        small_height = max(1, int(height*scale))
        small_width = max(1, int(width*scale))
        flakes = self._generate_noise(
            small_height, small_width, self.density, noise_rng)

        # Modulate flake density spatially, then bring to full resolution.
        flakes = self._gate(
            flakes,
            iap.Beta(1.0, 1.0 - self.density_uniformity),
            self.gate_noise_size,
            gate_rng)
        flakes = ia.imresize_single_image(
            flakes, (height, width), interpolation="cubic")

        # Soften with gaussian blur, then smear along the fall direction.
        sigma = np.clip(max(height, width) * sigma_fraction,
                        self.blur_sigma_limits[0], self.blur_sigma_limits[1])
        blurred = self._blur(flakes, sigma)
        blurred = self._motion_blur(blurred,
                                    angle=angle_value,
                                    speed=speed_value,
                                    random_state=random_state)

        flake_layer = self._postprocess_noise(
            blurred, uniformity_value, nb_channels)
        return self._blend(image, speed_value, flake_layer)

    @classmethod
    def _generate_noise(cls, height, width, density, random_state):
        """Return a ``(height, width)`` uint8 canvas with salt noise added."""
        salt = arithmetic.Salt(p=density, random_state=random_state)
        canvas = np.zeros((height, width), dtype=np.uint8)
        return salt.augment_image(canvas)

    @classmethod
    def _gate(cls, noise, gate_noise, gate_size, random_state):
        """Multiply ``noise`` by an upscaled, [0, 1]-clipped gate sample.

        ``gate_noise`` is sampled at ``gate_size``, resized to the noise's
        spatial shape with cubic interpolation and clipped to ``[0, 1]``.
        The product is clipped to ``[0, 255]`` and returned as uint8.
        """
        # Observed behaviour of the beta gate: most samples lie near 1.0,
        # values near 0.0 are rare, and the mean is roughly 0.6-0.75.
        coarse = gate_noise.draw_samples(gate_size, random_state)
        upscaled = ia.imresize_single_image(
            coarse, noise.shape[0:2], interpolation="cubic")
        weights = np.clip(upscaled, 0.0, 1.0)
        gated = noise.astype(np.float32) * weights
        return np.clip(gated, 0, 255).astype(np.uint8)

    @classmethod
    def _blur(cls, noise, sigma):
        """Gaussian-blur ``noise`` with ``sigma`` via ``blur_gaussian_``."""
        return blur.blur_gaussian_(noise, sigma=sigma)

    @classmethod
    def _motion_blur(cls, noise, angle, speed, random_state):
        """Motion-blur ``noise`` with a kernel of ``speed * max(H, W)``.

        Returns ``noise`` unchanged when that kernel size is 1 or smaller.
        """
        kernel_size = int(speed * max(noise.shape[0:2]))
        if kernel_size > 1:
            # MotionBlur rejects kernel sizes below 3, hence the lower bound.
            blurer = blur.MotionBlur(
                k=max(kernel_size, 3), angle=angle, direction=1.0,
                random_state=random_state)
            return blurer.augment_image(noise)
        return noise

    # Added in 0.4.0.
    @classmethod
    def _postprocess_noise(cls, noise_small_blur,
                           flake_size_uniformity_sample, nb_channels):
        """Contrast-adjust the flake mask and replicate it per channel.

        The gamma makes flake sizes less uniform. The subsequent gain makes
        the flakes visible again. Both grow as
        ``flake_size_uniformity_sample`` decreases.
        """
        non_uniformity = 1 - flake_size_uniformity_sample
        gamma = 1.0 + 2*non_uniformity
        gain = 1.0 + 5*non_uniformity

        adjusted = contrast.GammaContrast(gamma).augment_image(
            noise_small_blur)
        adjusted = adjusted.astype(np.float32) * gain
        return np.tile(adjusted[..., np.newaxis], (1, 1, nb_channels))

    # Added in 0.4.0.
    @classmethod
    def _blend(cls, image, speed_sample, noise_small_blur_rgb):
        """Blend the flake mask into ``image`` and return a uint8 result."""
        # A weak additive pass gives faint, glowy flakes. A stronger max pass
        # then draws the main flakes. Both passes scale with the speed.
        glow_strength = 0.1 + 20*speed_sample
        flake_strength = 1.0 + 20*speed_sample

        # _blend_by_sum already casts to uint8, so ``blended`` is uint8.
        blended = cls._blend_by_sum(
            image.astype(np.float32), glow_strength * noise_small_blur_rgb)
        return cls._blend_by_max(
            blended, flake_strength * noise_small_blur_rgb)

    # TODO replace this by a function from module blend.py
    @classmethod
    def _blend_by_sum(cls, image_f32, noise_small_blur_rgb):
        """Add the mask to the image, clip to [0, 255] and cast to uint8."""
        total = image_f32 + noise_small_blur_rgb
        return np.clip(total, 0, 255).astype(np.uint8)

    # TODO replace this by a function from module blend.py
    @classmethod
    def _blend_by_max(cls, image_f32, noise_small_blur_rgb):
        """Take the per-pixel maximum, clip to [0, 255] and cast to uint8."""
        brightest = np.maximum(image_f32, noise_small_blur_rgb)
        return np.clip(brightest, 0, 255).astype(np.uint8)