In [1]:
# Testing imports
import os
import pytest
import numpy.testing as npt
import inspect

# Main module imports
import librosa
import os
import numpy as np
import random
from functools import wraps
import math
from scipy.signal import butter, lfilter

# Audio class

In [2]:
# Registry of the names of all functions decorated as audio manipulations.
audio_manipulations = []

class Audio():
    '''
    A class to hold audio
    
    This class is a mutable container for
    information about an approx. 5s chunk
    of audio. It has the following attributes:
    
        self.samples (np.ndarray), 1-D array of audio samples
        self.sample_rate (float or int), samples/sec
            of self.samples
        self.possible_manipulations (list of strings), names
            accepted by add_manipulation(). Defaults to the
            module-level `audio_manipulations` registry.
        self.manipulations (list of tuples), a list of 
            manipulations that have been performed on 
            this audio. Each tuple has the name of the
            function used to manipulate the audio and the
            dict of keyword arguments it was called with
        self.sources (list of tuples), a list of tuples
            where each tuple contains a path to the source
            and a (start time, duration) tuple in seconds
        self.labels (set of strings), the class 
            labels for all the contributing audio sources
    
    It may also have these attributes:
        self.original_path (string): path to original object.
            This is optional for convenience in creating/testing
            Audio objects from samples.
    '''
        
    def __repr__(self):
        return f'Audio({self.sources}, {self.labels})'
    
    def __init__(self, label, path = None, samples = None, sample_rate = None):
        '''
        Create an Audio object from a file or from raw samples
        
        Inputs:
            label (string, list of strings, or set of strings):
                label(s) for this audio
            path (string): path of an audio file to load
            samples (np.ndarray): 1-D array of samples, used
                instead of `path`
            sample_rate (float or int): required with `samples`
        
        Raises:
            ValueError: both or neither of `path`/`samples` were
                given, `samples` is empty or not a 1-D ndarray, or
                `sample_rate` is missing
            FileNotFoundError: `path` does not exist
        '''

        self.original_path = None
        self.samples = None
        self.sample_rate = None
        self.possible_manipulations = []
        self.manipulations = []
        self.sources = []
        self.labels = set()
        
        # Obtain samples and sample rate
        if path:
            # `samples` may be an array, whose truth value is ambiguous,
            # so compare against None explicitly
            if samples is not None:
                raise ValueError('Only one of `path` or `samples` can be provided')
            self.original_path = path
                
            # Load all samples
            samples, sample_rate = self._load_audio(path)
            self.set_samples(samples)
            self.sample_rate = sample_rate   
        elif samples is not None and np.size(samples) > 0:
            # Emptiness is judged by size, not by `.any()`: silence
            # (all-zero samples) is perfectly valid audio
            if not sample_rate:
                raise ValueError('If `samples` are used, must provide `sample_rate`')
            self.set_samples(samples)
            self.sample_rate = sample_rate
        else:
            raise ValueError('A non-empty `path` or `samples` must be provided')
        
        # Add label or labels
        self.add_label(label)
        
        # Set possible manipulations to global variable by default.
        # This is the registry list itself, not a copy, so manipulations
        # registered later are accepted too.
        self.possible_manipulations = audio_manipulations
        
    def _load_audio(self, path):  
        '''Return (samples, sample_rate) as loaded by librosa from `path`.'''
        if not os.path.exists(path):
            raise FileNotFoundError(f"File {path} does not exist")
        return librosa.load(path) #load samples from path

    def set_samples(self, samples):
        '''
        Replace self.samples after validating the new array
        
        Raises:
            ValueError: `samples` is not a 1-D numpy.ndarray
        '''
        if not isinstance(samples, np.ndarray):
            raise ValueError(f'Samples must be numpy.ndarray')
        if samples.ndim != 1:
            raise ValueError(f'Multi-channel sample array provided. (Dimensions: {samples.shape})')
        
        self.samples = samples
    
    def set_possible_manipulations(self, manip_list):
        '''
        Set self.possible_manipulations
        
        This determines what manipulations will be 
        considered valid in the add_manipulation step
        
        Inputs:
            manip_list: list of strings of names of
                possible manipulations
        '''
        
        self.possible_manipulations = manip_list
        
    
    def add_manipulation(self, manip, manip_kwargs):
        '''
        Add a manipulation to the list of manipulations
        
        Appends a tuple, (manip, manip_kwargs), to the
        list of manipulations for this audio, 
        self.manipulations.
        
        Inputs: 
            manip: function name of the manipulation
            manip_kwargs: keyword arguments used to
                call the manipulation function
        
        Returns:
            True on success
        
        Raises:
            ValueError: `manip` is not a possible manipulation
                or `manip_kwargs` is not a dict
        '''
        
        # Explicit checks rather than `assert`, which is stripped under -O
        if manip not in self.possible_manipulations:
            raise ValueError(f'Invalid input to Audio.add_manipulation(): {manip} not in list of valid manipulations')
            
        if type(manip_kwargs) != dict:
            raise ValueError(f'Invalid inputs to Audio.add_manipulation(): {manip_kwargs} must be dict')
        
        self.manipulations.append((manip, manip_kwargs))
        return True
    
    
    def add_label(self, label):
        '''
        Add label or set of labels to self.labels
        
        Inputs:
            label (string, list of strings, or
                set of strings): labels to add.
                Any other type is ignored.
        '''
        if type(label) == str:
            self.labels.add(label)
        elif type(label) == list or type(label) == set:
            # Add each element, not the collection itself
            for single_label in label:
                self.labels.add(single_label)
    
    
    def add_source(self, path = None, start_time = None, duration = None, source = None):
        '''
        Add a source to the list of sources
        
        Appends a tuple, (path, dur_tuple),
        to the list of sources for this audio,
        self.sources. Can be called either with each 
        individual component (path, start_time, duration)
        or with a `source` tuple straight from another 
        Audio object.
        
        Inputs:
            path (string): path to the source file
            start_time (float): start time in seconds
                within the source file (0 is allowed)
            duration (float): duration in seconds
                within the source file
            source (tuple): a tuple containing
                the above information:
                (path, (start_time, duration))
        
        Returns:
            True on success
        
        Raises:
            ValueError: malformed `source`, a missing component,
                or a start time/duration not convertible to float
            FileNotFoundError: the source path does not exist
        '''
        
        # If using source
        if source is not None:
            if type(source) != tuple or len(source) != 2:
                raise ValueError(f'Audio.add_source() input source must be a tuple of len 2. Got: {source}')
            path, dur_tuple = source
            if type(dur_tuple) != tuple or len(dur_tuple) != 2:
                raise ValueError(f'Second element of input source must be a tuple of len 2. Got: {dur_tuple}')
            start_time, duration = dur_tuple
        # If using path, start_time, and duration:
        else:
            # Compare the numbers against None: a start time of 0 seconds is valid
            if (not path) or (start_time is None) or (duration is None):
                raise ValueError(f'If not calling Audio.add_source() with a source, must provide all three of: path, start_time, duration')
            source = (path, (start_time, duration))
        
        # Check inputs for correctness
        if not os.path.exists(path):
            raise FileNotFoundError(f'Source file does not exist: {path}')
        try:
            float(start_time)
            float(duration)
        except (TypeError, ValueError):
            # float(None) raises TypeError, float('a') raises ValueError
            raise ValueError(f'start time and duration must be floats. Given type(start_time) == {type(start_time)}, type(duration) == {type(duration)}')
        
        # Append to list of sources
        self.sources.append(source)
        
        return True

### Tests of Audio()

In [3]:
def test_Audio_path_loading_error_checking():
    '''Constructing an Audio from a nonexistent file raises FileNotFoundError.'''
    missing_file = "SirNotAppearingOnThisFilesystem.wav"
    with pytest.raises(FileNotFoundError):
        Audio(label = 'test_label', path = missing_file)
test_Audio_path_loading_error_checking()

In [4]:
def test_Audio_set_samples_error_checking():
    '''set_samples() rejects a non-array and leaves the stored samples untouched.'''
    # Build the fixture OUTSIDE pytest.raises, so that an error while
    # loading the file cannot be mistaken for the expected ValueError
    chunk = Audio(path = '../tests/veryshort.wav', label = 'test')
    samples = chunk.samples

    # Test handling of bad sample setting
    with pytest.raises(ValueError):
        chunk.set_samples('1')

    # The failed call must not have replaced the original array
    assert chunk.samples is samples
test_Audio_set_samples_error_checking()

In [5]:
def test_Audio_add_source_input_checking_not_all_args():
    '''add_source() raises ValueError when a component of the source is missing.'''
    source_path = '../tests/veryshort.wav'
    chunk = Audio(path = source_path, label = 'test')

    incomplete_calls = (
        {'path': 'me'},                     # no start time, no duration
        {'path': 'me', 'start_time': 1},    # no duration
        {'source': (source_path, ('a'))},   # ('a') is a plain string, not a 2-tuple
    )
    for call_kwargs in incomplete_calls:
        with pytest.raises(ValueError):
            chunk.add_source(**call_kwargs)
test_Audio_add_source_input_checking_not_all_args()

def test_Audio_add_source_input_checking_bad_path():
    '''add_source() raises FileNotFoundError for a nonexistent source file, in both call forms.'''
    source_path = '../tests/veryshort.wav'
    chunk = Audio(path = source_path, label = 'test')

    bad_path_calls = (
        {'source': ('me', (1, 2))},
        {'path': 'me', 'start_time': 1, 'duration': 2},
    )
    for call_kwargs in bad_path_calls:
        with pytest.raises(FileNotFoundError):
            chunk.add_source(**call_kwargs)
test_Audio_add_source_input_checking_bad_path()


def test_Audio_add_source_input_checking_bad_timing():
    '''add_source() raises ValueError for a non-numeric start time or duration.'''
    source_path = '../tests/veryshort.wav'
    chunk = Audio(path = source_path, label = 'test')

    # Bad start time or duration. Each call gets its own pytest.raises
    # block: a block ends at the first exception, so calls sharing one
    # block after the first would never be executed.
    bad_timing_calls = (
        {'source': (source_path, ('a', 1))},
        {'source': (source_path, (1, 'a'))},
        {'path': source_path, 'start_time': 'a', 'duration': 1},
        {'path': source_path, 'start_time': 1, 'duration': 'a'},
    )
    for call_kwargs in bad_timing_calls:
        with pytest.raises(ValueError):
            chunk.add_source(**call_kwargs)
test_Audio_add_source_input_checking_bad_timing()

def test_Audio_add_source_actually_works():
    '''A valid source is recorded, whether given as components or as a source tuple.'''
    source_path = '../tests/veryshort.wav'
    expected_sources = [('../tests/veryshort.wav', (1, 1))]

    valid_calls = (
        {'path': source_path, 'start_time': 1, 'duration': 1},
        {'source': (source_path, (1, 1))},
    )
    for call_kwargs in valid_calls:
        # Fresh object each time so the source list starts out empty
        chunk = Audio(path = source_path, label = 'test')
        chunk.add_source(**call_kwargs)
        assert chunk.sources == expected_sources
test_Audio_add_source_actually_works()

# Manipulations wrapper

In [6]:
def audio_manipulation(func):
    '''
    Functionality for audio manipulation functions
    
    Wrapper for audio manipulation that ensures the
    input to the function is of class Audio and 
    adds a record of the manipulation to the 
    Audio object's `manipulations` attribute.
    
    When a manipulation function is defined
    with this wrapper, its name is appended to
    the module-level list of valid manipulations,
    `audio_manipulations` (once per name, so
    re-running a notebook cell does not create
    duplicate entries).
    
    Inputs:
      - func (function): a function taking an Audio object
        either as its first positional argument or as the 
        keyword argument `audio`. Function must return an 
        Audio object and a dict of keyword arguments that 
        does not include the reference to the manipulated 
        Audio object
    
    Returns:
      - wrapped version of the function that returns only 
        the Audio object.
    
    Raises (from the wrapped function):
      - ValueError if no Audio object was passed, or if the
        manipulation cannot be recorded on the returned object
    '''
    
    # Register the manipulation name; reading and mutating the
    # module-level list needs no `global` statement
    if func.__name__ not in audio_manipulations:
        audio_manipulations.append(func.__name__)
    
    @wraps(func) #Allows us to call help(func)
    def validate_audio(*args, **kwargs):
        # Prefer the keyword argument, fall back to the first positional one
        if 'audio' in kwargs:
            audio_arg = kwargs['audio']
        elif args:
            audio_arg = args[0]
        else:
            raise ValueError("an Audio object must be provided as first argument or keyword argument 'audio'")
            
        if not isinstance(audio_arg, Audio):
            raise ValueError("an Audio object must be provided as first argument or keyword argument 'audio'")
            
        # Run manipulation
        manipulated_audio, arguments = func(*args, **kwargs)
        
        # Add manipulation to list of manipulations
        manipulated_audio.add_manipulation(func.__name__, arguments)
        
        return manipulated_audio
    
    return validate_audio
    
    

### Tests for wrapper

In [7]:
def test_audio_wrapper_arg_checking():
    '''The wrapper raises ValueError unless it receives an Audio object.'''
    # `audio` keyword is present but is not an Audio object
    @audio_manipulation
    def function_with_good_kwarg(audio = None):
        return None
    with pytest.raises(ValueError):
        function_with_good_kwarg(audio = 'hah')

    # No `audio` keyword and no positional argument at all
    @audio_manipulation 
    def function_with_bad_kwarg(notaudio = None):
        return None
    with pytest.raises(ValueError):
        function_with_bad_kwarg(notaudio = 'not')
        
test_audio_wrapper_arg_checking()

# Tests for all manipulations

In [8]:
# Rebind the registry to a new, empty list. Names registered before this
# point (e.g. by the throwaway functions decorated in the wrapper tests)
# are dropped; manipulations decorated in later cells register here.
audio_manipulations = []

# The list below and the parametrize decorators that use it do nothing
# inside the notebook: the registry was just emptied, so the list is empty.
# They are here for when these tests are moved into a test file, where the
# registry will hold the names of the real manipulation functions.
# Each registered name is resolved to its function object with eval().
functional_audio_manipulations = [eval(func_string) for func_string in audio_manipulations]
@pytest.mark.parametrize('function', functional_audio_manipulations)
def test_audio_manipulation_audio_is_arg(function):
    '''A manipulation must accept a parameter named `audio`.'''
    parameters = inspect.signature(function).parameters
    # Plain lookup: raises KeyError when there is no `audio` parameter
    parameters['audio']
    
    
@pytest.mark.parametrize('function', functional_audio_manipulations)
def test_audio_manipulation_returns_Audio(function):
    '''Calling a manipulation with its defaults must return an Audio object.'''
    result = function(audio = Audio(path = '../tests/silence_10s.mp3', label='silence'))

    # Exact type match, as opposed to e.g. a tuple or a bool
    assert type(result) == Audio


@pytest.mark.parametrize('function', functional_audio_manipulations)
def test_audio_manipulation_adds_manipulation(function):    
    '''
    A manipulation called with defaults must record
    (its name, its default keyword arguments without `audio`)
    as the first entry of the returned object's manipulations.
    '''
    source_audio = Audio(path = '../tests/silence_10s.mp3', label='silence')
    first_manipulation = function(audio = source_audio).manipulations[0]
    
    # Expected record: every parameter mapped to its default value...
    expected_kwargs = {
        name: parameter.default
        for name, parameter in inspect.signature(function).parameters.items()
    }
    # ...except the Audio object itself (KeyError if there is no such parameter)
    expected_kwargs.pop('audio')
    
    assert first_manipulation == (function.__name__, expected_kwargs)


### Meta-tests: tests for function-testing tests

In [9]:
def test_audio_manipulation_test_catches_no_audio_arg():
    '''The signature check raises KeyError for a function without an `audio` parameter.'''
    function_without_audio_arg = lambda not_audio: None
    with pytest.raises(KeyError):
        test_audio_manipulation_audio_is_arg(function_without_audio_arg)
test_audio_manipulation_test_catches_no_audio_arg()

In [10]:
def test_audio_manipulation_test_catches_wrong_return_format():
    '''The return-type check fails for functions that do not return an Audio object.'''
    # Manipulation does not return correct type
    def function_returning_wrong_type(audio = None):
        return True
    with pytest.raises(AssertionError):
        test_audio_manipulation_returns_Audio(function_returning_wrong_type)
    
    # Manipulation returns a tuple instead of a single Audio object.
    # Only the call under test sits inside pytest.raises; the checked
    # test builds its own Audio object, so none is created here.
    def function_with_two_returns(audio = None):
        return 'a', 'b'
    with pytest.raises(AssertionError):
        test_audio_manipulation_returns_Audio(function_with_two_returns) 
test_audio_manipulation_test_catches_wrong_return_format()

In [11]:
def test_audio_manipulation_test_catches_audio_object_not_removed_from_manips():
    '''The record check fails when a manipulation leaves `audio` in its recorded arguments.'''
    # Manipulation does not delete 'audio' object from arguments
    def function_that_does_not_remove_audio_from_options(audio = None):
        # Must stay the first statement so that only the call arguments are captured
        arguments = locals()
        audio.set_possible_manipulations(['function_that_does_not_remove_audio_from_options'])
        audio.add_manipulation('function_that_does_not_remove_audio_from_options', arguments)
        return audio

    # Only the call under test sits inside pytest.raises; the checked
    # test builds its own Audio object, so none is created here.
    with pytest.raises(AssertionError):
        test_audio_manipulation_adds_manipulation(function_that_does_not_remove_audio_from_options)   
test_audio_manipulation_test_catches_audio_object_not_removed_from_manips()

In [12]:
# Manipulation works exactly as it's supposed to
def test_audio_manipulation_test_passes_good_manipulation_addition():
    '''The record check passes for a manipulation that records its arguments correctly.'''
    def function_that_works(audio = None, another = 'default'):
        # Must stay the first statement so that only the call arguments are captured
        arguments = locals()
        audio = Audio(path = '../tests/silence_10s.mp3', label='silence')
        audio.set_possible_manipulations(['function_that_works'])
        
        # The Audio object itself must not be part of the record
        del arguments['audio']
        audio.add_manipulation('function_that_works', arguments)
        return audio
    test_audio_manipulation_adds_manipulation(function_that_works)
test_audio_manipulation_test_passes_good_manipulation_addition()

# Audio augmentations

## Chunk extraction

### Helper function: `wraparound_extract()`

In [13]:
def wraparound_extract(original, begin, length):
    '''
    Read `length` elements from an array, cycling past its end
    
    Starting at index `begin` (taken modulo the array length),
    read `length` consecutive elements. Whenever the end of the
    array is reached, reading continues from index 0, as many
    times as needed. For instance:
    
    wraparound_extract(
        original = [0, 5, 10],
        begin = 1, 
        length = 7) -> [5, 10, 0, 5, 10, 0, 5]
    
    Args:
        original (np.array): the original array 
        begin (int): beginning position to extract
        length (int): number of elements to extract
    
    Returns:
        np.array holding the extracted elements
    '''

    assert(type(original) == np.ndarray)
    total = original.shape[0]
    begin = begin % total

    # Everything from the start position to the end of the array
    head = original[begin:]

    # How many elements are still missing once the head is used up
    shortfall = length - head.shape[0]

    # No wrapping needed: a plain slice does the job
    if shortfall <= 0:
        return original[begin:begin + length]

    # Wrapping needed: whole copies of the array, then a partial copy
    full_copies, remainder = divmod(shortfall, total)
    return np.concatenate((
        head,
        np.tile(original, full_copies),
        original[:remainder],
    ))

### Test helper function

In [14]:
def test_wraparound_extract():
    '''Check wraparound_extract() on a two-element array for a table of cases.'''
    # (begin, length, expected result)
    cases = [
        (0, 1, [0]),                                # zero beginning, stopping before the end
        (0, 2, [0, 1]),                             # zero beginning, reaching the end exactly
        (0, 2, [0, 1]),                             # zero beginning, not wrapping
        (0, 3, [0, 1, 0]),                          # zero beginning, wrapping around
        (1, 1, [1]),                                # nonzero beginning, not wrapping
        (1, 3, [1, 0, 1]),                          # nonzero beginning, wrapping around
        (1, 10, [1, 0, 1, 0, 1, 0, 1, 0, 1, 0]),    # multiwrap
        (5, 3, [1, 0, 1]),                          # beginning past the end wraps too
    ]
    for begin, length, expected in cases:
        extracted = wraparound_extract(original = np.array([0, 1]), begin = begin, length = length)
        npt.assert_array_equal(extracted, np.array(expected))

test_wraparound_extract()

### Main function: `get_chunk()`

In [15]:
@audio_manipulation
def get_chunk(
    audio,
    start_position = None, # randomize start position
    duration = 5, # 5 seconds
    duration_jitter = 0.5, #jitter duration +- 0.5s
    chance_random_skip = 0.3 #randomly skip 30% of the time
):
    '''
    Extracts chunk of audio with some augmentation
    
    Extracts samples of audio from a master list
    of samples. 
    
    Available data augmentation options include:
        - selecting a position to start extracting from
          or allowing function to randomly choose start
        - selecting duration of chunk and allowing
          for random jitter of duration
        - randomly skipping some number of samples from
          0 to the length of the chunk
    
    If the chunk to be extracted reaches the end of the
    samples, the chunk will "wrap around" and start
    reading from the beginning of the samples.
    
    Args:
        audio (instance of class Audio): Audio object to remove chunk from.
            It must have been loaded from a file, because the chunk is
            recorded as a source of that file.
        start_position (int): position in the file, in samples, to 
            start extracting from (0 is the first sample). If None, 
            the start position is chosen randomly
        duration (float): desired duration, in seconds, 
            of chunk to extract
        duration_jitter (float): if this value is not 0,
            the duration of the chunk extracted will 
            be randomly selected from the range 
            (duration - duration_jitter, duration + duration_jitter)
        chance_random_skip (float between 0 and 1):
            percent chance of random skipping. In a random skip,
            a position within the chunk will be randomly
            selected, and from that position in the 
            audio file, a random number of samples will 
            be skipped. The number of samples skipped is between
            0 and the number of samples in the entire chunk.
            The chunk keeps its full length either way.
    
    Returns to wrapper:
        audio (Audio): manipulated audio object
        options (dict): options the function was called with
    
    Returns when wrapped:
        audio (Audio): manipulated audio object
    
    Raises:
        ValueError: `audio` has no original path
        
    '''
    # Get the given arguments. This must stay the first statement,
    # so that only the call arguments are captured.
    options = locals()
    del options['audio']

    # Fail before touching the samples, not halfway through
    if not audio.original_path:
        raise ValueError('get_chunk() requires an Audio object loaded from a file (no original path set)')
    
    # Get a random start position. Compare against None so that an
    # explicit start position of 0 is honored.
    num_samples = len(audio.samples)
    if start_position is None:
        start_position = random.randint(0, num_samples)

    # Convert seconds to samples
    seconds_to_extract = duration + random.uniform(-duration_jitter, duration_jitter)
    samples_to_extract = int(seconds_to_extract * audio.sample_rate)
    
    # Get chunks with skip in the middle with probability = chance_random_skip
    if random.random() < chance_random_skip:
        position_to_skip = random.randint(0, samples_to_extract)
        amount_to_skip = random.randint(0, samples_to_extract)

        # wraparound_extract() takes (begin, length), not (begin, end):
        # read `position_to_skip` samples, jump over `amount_to_skip`
        # samples, then read whatever is left of the chunk
        chunk_2_start = start_position + position_to_skip + amount_to_skip
        chunk_2_length = samples_to_extract - position_to_skip
        
        chunk_1 = wraparound_extract(audio.samples, start_position, position_to_skip)
        chunk_2 = wraparound_extract(audio.samples, chunk_2_start, chunk_2_length)
        chunk = np.concatenate((chunk_1, chunk_2))
    
    # Otherwise get contiguous chunk
    else:
        chunk = wraparound_extract(audio.samples, start_position, samples_to_extract) 
        
    start_position_seconds = start_position / audio.sample_rate
    start_and_len = (start_position_seconds, seconds_to_extract)
    
    # Update attributes of audio object; the chunk is recorded
    # as a single (path, (start, duration)) source tuple
    audio.set_samples(samples = chunk)
    audio.add_source(source = (audio.original_path, start_and_len))
    
    return audio, options
    
# Smoke run: load ten seconds of silence and extract a chunk with default options
my_audio = get_chunk(audio = Audio(label='silence', path = '../tests/silence_10s.mp3'))

### Test get_chunk

In [16]:
# Run the generic manipulation contract checks against get_chunk
for manipulation_check in (
    test_audio_manipulation_audio_is_arg,
    test_audio_manipulation_returns_Audio,
    test_audio_manipulation_adds_manipulation,
):
    manipulation_check(get_chunk)

## Cyclic shift

### Helper function: `shift_array()`

In [17]:
def shift_array(array, split_point = None):
    '''
    Shift array cyclicly by a random or given amount
    
    Equivalent to splitting array into two parts at
    some element, then switching the order of the parts.
    
    Args: 
        array (np.array): 1D-array to be split
        split_point (float): fraction from 0 to 1 describing
            where in array to split -- for testing purposes.
            The split index is floor(split_point * length), so
            0 leaves the array unshifted. For stochastic 
            splitting, leave as None.
    
    Returns:
        shifted_array: shifted array (a new array)
    '''
    
    assert(type(array) == np.ndarray)
    length = array.shape[0]
    
    # Stochastic split index, or floor of split_point * length of array.
    # Compare against None so that an explicit split_point of 0
    # is honored instead of being replaced by a random one.
    if split_point is None:
        split_index = random.randint(0, length)
    else:
        split_index = int(split_point * length)
    
    return np.concatenate((array[split_index:], array[:split_index]))

### Tests for helper function

In [18]:
def test_array_shifting():
    '''Check shift_array() with a seeded random split and with fixed split points.'''
    # Random splitting, made repeatable by seeding the generator
    random.seed(100)
    randomly_shifted = shift_array(np.array((0, 1, 2, 3, 4, 5, 6, 7)))
    npt.assert_array_equal(randomly_shifted, np.array([2, 3, 4, 5, 6, 7, 0, 1]))

    # Deterministic splitting: (input, expected result of splitting at 0.5)
    halfway_cases = [
        ([0, 1, 2], [1, 2, 0]),         # odd length: split index is floored
        ([0, 1, 2, 3], [2, 3, 0, 1]),   # even length
    ]
    for values, expected in halfway_cases:
        npt.assert_array_equal(shift_array(np.array(values), split_point=0.5), np.array(expected))
test_array_shifting()

### Main function: `cyclic_shift()`

In [19]:
@audio_manipulation
def cyclic_shift(audio, split_point = None):
    '''
    Cyclicly shift the samples of an Audio object
    
    Inputs: 
        audio (Audio object): audio whose samples are shifted
        split_point: passed on to shift_array(); None 
            means a random split position
    
    Returns to wrapper:
        audio (Audio): the same object, with shifted samples
        options (dict): options the function was called with
    '''
    
    # Call arguments to be recorded, without the Audio object itself
    options = {'split_point': split_point}
    
    audio.set_samples(shift_array(audio.samples, split_point = split_point))
    
    return audio, options

### Test cyclic_shift

In [20]:
# Run the generic manipulation contract checks against cyclic_shift
for manipulation_check in (
    test_audio_manipulation_audio_is_arg,
    test_audio_manipulation_returns_Audio,
    test_audio_manipulation_adds_manipulation,
):
    manipulation_check(cyclic_shift)

## Divided-samples augmentations: time & freq

### Helper function to divide samples randomly: `divide_samples()`

In [21]:

def divide_samples(
    samples,
    sample_rate,
    low_duration = 0.5,
    high_duration = 5
):
    '''
    Divide audio samples into random-sized segments
    
    Divide audio samples into random-sized segments
    between the desired durations. The number
    of segments is not deterministic. The last segment
    holds whatever samples remain and may be shorter
    than the minimum duration.
    
    Segment sizes are drawn as whole numbers of samples
    between int(low_duration * sample_rate) and
    int(high_duration * sample_rate). If the lower bound 
    truncates to 0 samples, zero-length segments may 
    appear in the result.
    
    Args:
        samples (np.ndarray): 1d array of samples
        sample_rate (int or float): sample rate of samples
        low_duration (float): minimum duration
            in seconds of any segment
        high_duration (float): maximum duration
            in seconds of any segment
    
    Returns:
        segments, list of sample arrays
    
    Raises:
        ValueError: there are samples to divide but the maximum
            segment size is less than one sample
    '''

    min_chunk = int(low_duration * sample_rate)
    max_chunk = int(high_duration * sample_rate)
    
    samples_to_take = samples.copy()

    # With a maximum size of 0 every segment would be empty
    # and the loop below would never consume anything
    if samples_to_take.shape[0] and max_chunk < 1:
        raise ValueError(
            f'high_duration * sample_rate must be at least 1 sample. Got: {high_duration} * {sample_rate}')
    
    segments = []
    
    while samples_to_take.shape[0]:
        seg_size = random.randint(min_chunk, max_chunk)
        # Cut the first seg_size samples off; the rest is divided next
        segment, samples_to_take = np.split(samples_to_take, [seg_size])
        segments.append(segment)
    
    return segments
    
    

### Test helper function

In [22]:
def test_divide_samples_at_set_amount():
    '''With equal minimum and maximum duration, the samples are cut at fixed positions.'''
    # Test chunk division at set amount
    array0 = np.array([0, 0, 0])
    array1 = np.array([1, 1, 1])
    array2 = np.array([2])
    all_arrays = (array0, array1, array2)
    cat_arrays = np.concatenate(all_arrays)
    divisions = divide_samples(samples=cat_arrays, sample_rate=1, low_duration=3, high_duration=3)

    # Without this, too few (or no) divisions would pass the loop below unnoticed
    assert len(divisions) == len(all_arrays)
    
    for idx, division in enumerate(divisions):
        npt.assert_array_equal(division, all_arrays[idx])
        
test_divide_samples_at_set_amount()

In [23]:
def test_divide_samples_at_random_position():
    '''Random division of the numbers 0-9 gives known segments for a fixed seed.'''
    # Seed the generator so that the random segment sizes are repeatable
    random.seed(333)
    
    # Segments obtained with random.seed(333)
    predetermined = [np.array([0, 1, 2, 3, 4, 5, 6, 7]), np.array([8, 9])]

    divisions = divide_samples(
        samples=np.arange(10),
        sample_rate=1,
        low_duration=0,
        high_duration=10)
    
    for idx in range(len(divisions)):
        npt.assert_array_equal(divisions[idx], predetermined[idx])
        
test_divide_samples_at_random_position()

### Helper function to concatenate divisions: `combine_samples()`

In [24]:
def combine_samples(divided):
    '''
    Join divided sample arrays end to end
    
    Inverse of divide_samples(): concatenates the
    segments, in order, into one array. The segments
    may have been modified (pitch shifted, time 
    stretched, etc.) in the meantime.
    
    Args:
        divided (list of np.ndarrays): list of sample arrays
            divided by divide_samples()
    
    Returns:
        sample arrays concatenated back into a single array
    '''
    
    recombined = np.concatenate(divided, axis = 0)
    return recombined

### Test helper function

In [25]:
def test_combine_samples():
    '''Dividing samples and recombining them gives back the original samples.'''
    original_samples, rate = librosa.load('../tests/silence_10s.mp3')
    segments = divide_samples(
        original_samples,
        sample_rate=rate,
        low_duration=0.5,
        high_duration=4)
    npt.assert_array_equal(combine_samples(segments), original_samples)
test_combine_samples()

### Time stretch audio: `time_stretch_divisions()`

In [26]:
@audio_manipulation
def time_stretch_divisions(
    audio,
    low_division_duration = 0.5,
    high_division_duration = 4,
    chance_per_division = 0.50,
    mean_stretch = 1,
    sd_stretch = 0.05
):
    '''
    Time stretch divisions
    
    Given an Audio object, divide its samples into 
    random-sized divisions, time stretch each division 
    with some probability, and store the recombined 
    divisions back in the Audio object.
    
    Args:
        audio (Audio object): audio object to
            be divided and time-stretched
        low_division_duration (float): minimum duration
            in seconds of any segment
        high_division_duration (float): maximum duration
            in seconds of any segment
        chance_per_division (float between 0 and 1): for
            each division, the chance it will be time-stretched
        mean_stretch (float): the mean stretch multiplier.
            == 1 is no stretch; > 1 is sped up, < 1 is slowed down
        sd_stretch (float > 0): the sd of the stretch 
            distribution. 
    
    Returns to wrapper:
        audio (Audio): manipulated audio object
        options (dict): options the function was called with
    
    Returns when wrapped:
        audio (Audio): manipulated audio object
    '''
    # Call arguments to be recorded, without the Audio object itself
    options = {
        'low_division_duration': low_division_duration,
        'high_division_duration': high_division_duration,
        'chance_per_division': chance_per_division,
        'mean_stretch': mean_stretch,
        'sd_stretch': sd_stretch,
    }

    def stretch_with_chance(division):
        # Leave the division alone unless the random draw selects it
        if not (random.random() < chance_per_division):
            return division
        # The factor is drawn even when the division turns out to be 
        # too short to stretch, which keeps seeded runs reproducible
        stretch_factor = np.random.normal(
            loc = mean_stretch,
            scale = sd_stretch)
        if len(division) > 1:
            return librosa.effects.time_stretch(y = division, rate = stretch_factor)
        return division
    
    divisions = divide_samples(
        audio.samples,
        sample_rate = audio.sample_rate, 
        low_duration = low_division_duration,
        high_duration = high_division_duration)
    
    # Processed one by one, in order
    stretched_divisions = [stretch_with_chance(division) for division in divisions]
    
    audio.set_samples(combine_samples(stretched_divisions))
    
    return audio, options

### Test time stretching

In [27]:
# Run the generic manipulation contract checks against time_stretch_divisions
for manipulation_check in (
    test_audio_manipulation_audio_is_arg,
    test_audio_manipulation_returns_Audio,
    test_audio_manipulation_adds_manipulation,
):
    manipulation_check(time_stretch_divisions)

In [28]:
def test_random_time_stretching():
    '''Time stretching ten evenly spaced samples gives known values for fixed seeds.'''
    ramp = Audio(label = 'test', samples = np.linspace(0, 1, 10), sample_rate=1)

    # Both generators must be seeded: the divisions and the per-division
    # chance come from `random`, the stretch factor from np.random.normal
    random.seed(33)
    np.random.seed(99)
    stretched = time_stretch_divisions(ramp)

    # Samples expected for the two seeds set above
    predetermined = np.array([
        0., 0.11111111, 0.22222222, 0.33333333, 0.44444444,
        0.55555556, 0.53444054, 0.62240575, 0.88888889, 1.])
    
    npt.assert_array_almost_equal(stretched.samples, predetermined)
test_random_time_stretching()

### Frequency shift the divisions: `pitch_shift_divisions()`

In [29]:
@audio_manipulation
def pitch_shift_divisions(
    audio,
    low_division_duration = 0.5,
    high_division_duration = 4,
    chance_per_division = 0.40,
    mean_shift = 0,
    sd_shift = 0.25
):
    '''
    Pitch shift divisions
    
    Split the samples of an Audio object into divisions and,
    with probability `chance_per_division`, pitch-shift each
    division by an amount drawn from a normal distribution.
    Shifts are expressed in (fractional) half-steps, 
    e.g. 0.25 = 1/4th of a half-step = 25 cents.
    
    Args:
        audio (Audio object): audio object to
            be divided and pitch-shifted
        low_division_duration (float): minimum duration
            in seconds of any segment
        high_division_duration (float): maximum duration
            in seconds of any segment
        chance_per_division (float between 0 and 1): for
            each division, the chance it will be pitch-shifted
        mean_shift (float): the mean pitch shift in (fractional) 
            half-steps. == 0 is no shift; > 0 is shift up; 
            < 0 is shift down
        sd_shift (float > 0): the sd of the shift distribution,
            in (fractional) half-steps
    
    Returns:
        audio, the same Audio object with its samples replaced
            by the recombined (possibly shifted) divisions
        options, dict of the arguments of this call, 
            without `audio`
    '''
    # Taken before any other local exists so that only
    # the call arguments are recorded
    options = locals()
    del options['audio']
    
    source_samples, rate = audio.samples, audio.sample_rate
    pieces = divide_samples(
        source_samples,
        sample_rate = rate, 
        low_duration = low_division_duration,
        high_duration = high_division_duration)
    
    processed = []
    for piece in pieces:
        # Leave this division untouched unless the chance draw hits
        if not (random.random() < chance_per_division):
            processed.append(piece)
            continue
        
        half_steps = np.random.normal(loc = mean_shift, scale = sd_shift)
        processed.append(
            librosa.effects.pitch_shift(y = piece, sr = rate, n_steps = half_steps))
        
    audio.set_samples(combine_samples(processed))
    
    return audio, options

### Test pitch shifting

In [30]:
# Run the shared test_audio_manipulation_* checks against pitch_shift_divisions
test_audio_manipulation_audio_is_arg(pitch_shift_divisions)
test_audio_manipulation_returns_Audio(pitch_shift_divisions)
test_audio_manipulation_adds_manipulation(pitch_shift_divisions)

## Random audio filtering: `random_filter()`

In [31]:
@audio_manipulation
def random_filter(
    audio,
    percent_chance = 0.20,
    filter_type = None,
    filter_order = None,
    filter_low = None,
    filter_high = None,
    error_check = True
):
    '''
    Randomly filter audio samples
    
    With some probability, apply a filter to `samples`. 
    Some or all of the filter's characteristics can be 
    provided by the user; otherwise, they are
    are randomly selected from the following options:
    
    Type: lowpass, highpass, bandpass, bandstop
    Order: 1-5
    Low cutoff frequency: from 1Hz to (sample_rate/2) - 1 Hz
    High cutoff frequency (bandpass 
        and bandstop filters): from low_freq+1 
        to (sample_rate/2) - 1 Hz
        
    If filter output contains values not between -1.0 and 1.0,
    the original signal is returned to avoid glitchy filters.
    '''
    
    options = locals()
    del options['audio']
    
    samples = audio.samples
    sample_rate = audio.sample_rate
    
    if random.random() < percent_chance:
        
        # Nyquist frequency
        nyq = 0.5 * sample_rate
        
        # Select random filter choices
        if not filter_type: filter_type = random.choice(
            ['lowpass', 'highpass', 'bandpass', 'bandstop'])
        if not filter_order: filter_order = random.randint(1, 5)
        if not filter_low: filter_low = random.randint(1, (nyq - 1))
        if not filter_high:
            if filter_type in ['bandpass', 'bandstop']:
                filter_high = random.randint(filter_low, nyq - 1)
            else:
                filter_high = nyq - 1
        

        # Filter the audio
        low = filter_low / nyq
        high = filter_high / nyq
        b, a = butter(filter_order, [low, high], btype='band')
        filtered = lfilter(b, a, samples)

         # Set samples to filtered if not error checking, or if passes error check
        if not error_check:
            audio.set_samples(filtered)
        elif error_check:
            if ~(np.less(filtered, -1, where=~np.isnan(filtered)).any()) and \
               ~(np.greater(filtered, 1, where=~np.isnan(filtered)).any()):
                audio.set_samples(filtered)
    
    return audio, options
    

In [32]:
def test_filter_err_checking():
    # This audio contains values above 1 naturally,
    # and will cause errors in the filters:
    audio = Audio(path='../tests/1min.wav', label='tests')
    original_samples = audio.samples
    assert(~(audio.samples > 1).any())

    # This filter will produce an invalid output 
    # i.e., the array will contain values above 1
    filtered_not_checked = random_filter(
        audio,
        percent_chance = 1,
        filter_type = 'highpass',
        filter_order = 5,
        filter_low = 20,
        filter_high = 30,
        error_check = False
    )
    assert(filtered_not_checked.samples is not original_samples)

    
    audio = Audio(path='../tests/1min.wav', label='tests')
    original_samples = audio.samples
    assert(~(audio.samples > 1).any())
    # The same filter as above, but with error checking: 
    # the error check should flag the invalid content
    # in the filtered result and return the original array
    filtered_checked = random_filter(
        audio,
        percent_chance = 1,
        filter_type = 'highpass',
        filter_order = 5,
        filter_low = 20,
        filter_high = 30,
        #error_check = True # Error checking by default
    )
    assert(filtered_checked.samples is original_samples)
    
test_filter_err_checking()

## Adding audio chunks

### Helper function, fade audio in or out: `fade()`

In [33]:
def fade(array, fade_len, start_amp=1):
    '''
    Fade audio in or out
    
    Multiplies `array` by a linear ramp of `fade_len` samples.
    A fade-in ramp is placed at the start of the array and a
    fade-out ramp at the end; all other samples are unchanged.
    
    Args:
        array (np.array): 1d audio array to fade
            in or out
        fade_len (int): the number of samples over which
            the fade should occur; must not be larger than 
            array.shape[0]
        start_amp (int, 1 or 0): whether to start at full 
            volume and fade out (1) or start at
            0 volume and fade in (0)
            
    Returns:
        a new array, the faded copy of `array`
        
    Raises:
        ValueError: if start_amp is not exactly the int 0 or 1
            (floats such as 1.0 and bools are rejected)
        IndexError: if fade_len is larger than the array
    '''
    
    # Compare by type and value rather than identity: `is` against an
    # int literal depends on interpreter caching. The exact type check
    # keeps rejecting 1.0 and True, which compare equal to 1.
    if type(start_amp) is not int or start_amp not in (0, 1):
        raise ValueError(f'start_amp must be either 0 or 1. Got {start_amp}')
    
    pad_len = int(array.shape[0] - fade_len)
    if pad_len < 0:
        raise IndexError(f'Given value of fade_len ({fade_len}) is longer than the number of samples in array ({array.shape[0]})')
    
    # Construct fade filter.
    # The ramp is built from the end amplitude to the start amplitude
    # and then flipped, so that for fade_len == 1 the single sample
    # takes the end amplitude (np.linspace keeps its first value).
    end_amp = int(not start_amp)
    fade_filter = np.flip(np.linspace(end_amp, start_amp, fade_len))
    
    # Pad filter with 1s for array length: on the right for a
    # fade in at the start, on the left for a fade out at the end
    if start_amp == 0:
        pad_widths = (0, pad_len)
    else:
        pad_widths = (pad_len, 0)
    fade_filter_padded = np.pad(
        fade_filter,
        pad_widths,
        constant_values = 1,
        mode = 'constant'
    )
    return np.multiply(array, fade_filter_padded)

In [34]:
# start_amp must be exactly the int 0 or 1: equal-valued
# floats and bools are rejected
def test_only_binary_start_amp():
    for not_binary in (1.0, True):
        with pytest.raises(ValueError):
            fade(array = np.array((1, 1, 1)), fade_len=3, start_amp=not_binary)

# A fade longer than the array is an error
def test_fade_too_long():
    five_ones = np.array((1, 1, 1, 1, 1))
    with pytest.raises(IndexError):
        fade(array = five_ones, fade_len=6, start_amp=1)
        
# Fade in over the entire array (fade_len == array length)
def test_fade_on_exact_length_array():
    five_ones = np.array((1, 1, 1, 1, 1))
    expected = np.array([0., 0.25, 0.5, 0.75, 1.])
    npt.assert_array_equal(fade(array = five_ones, fade_len=5, start_amp=0), expected)

# Fade out only touches the last fade_len samples of a longer array
def test_fade_on_long_array():
    seven_ones = np.array((1, 1, 1, 1, 1, 1, 1))
    expected = np.array([1., 1., 1., 0.75, 0.5, 0.25, 0.])
    npt.assert_array_equal(fade(array = seven_ones, fade_len=5, start_amp=1), expected)
    

test_only_binary_start_amp()
test_fade_too_long()
test_fade_on_exact_length_array()
test_fade_on_long_array()

### Helper function, pairwise sample summer: `sum_samples()`

In [35]:
def sum_samples(
    samples_original,
    samples_new,
    sample_rate,
    wraparound_fill = False,
    fade_out = True,
    fade_duration = 0.1
):
    '''
    Sums two arrays of audio samples
    
    Combines audio samples, samples_new, on top
    of samples_original, overlaying samples_new
    so it begins at the same time as samples_original.
    The result always has the length of samples_original.
    
    Args:
        samples_original (np.array): samples to 
            overlay new samples on
        samples_new (np.array): samples to be
            overlayed on original samples. If shorter
            than samples_original, can either be repeated/
            wrapped around to reach length of
            samples_original, or can be zero-padded, 
            optionally after being faded out. If longer,
            it is truncated.
        sample_rate (int or float): mutual sample rate
            of both samples_original and samples_new
        wraparound_fill (bool): whether or not to 
            fill in short samples_new by wrapping around
        fade_out (bool): whether or not to fade out 
            short samples_new. If wraparound_fill == True,
            this option does not apply.
        fade_duration (float >= 0): length of the fade-out
            in seconds. The fade never exceeds the length
            of samples_new.
            
    Returns:
        summed samples
        
    Raises:
        ValueError: if fade_duration is negative
    '''
    
    if fade_duration < 0:
        raise ValueError(f'fade_duration must be non-negative. Got {fade_duration}')
    
    original_len = samples_original.shape[0]
    new_len = samples_new.shape[0]
    discrepancy = original_len - new_len
    
    if discrepancy <= 0:
        # New samples are long enough: cut off any excess
        samples_to_add = samples_new[:original_len]
        
    elif wraparound_fill:
        # Make up length by repeating/"wrapping around"
        samples_to_add = wraparound_extract(
            original = samples_new,
            begin = 0,
            length = original_len)
        
    else:
        # Make up length with zero-padding, after an optional fade-out
        samples_to_add = samples_new
        if fade_out:
            # The fade spans fade_duration seconds, 
            # capped at the number of new samples
            fade_samples = min(math.ceil(fade_duration * sample_rate), new_len)
            samples_to_add = fade(
                array = samples_to_add,
                fade_len = fade_samples,
                start_amp = 1,
            )
        
        samples_to_add = np.pad(
            samples_to_add,
            (0, discrepancy),
            constant_values = 0,
            mode='constant'
        )
        
    return np.add(samples_original, samples_to_add)

In [36]:
def test_wrap_fade_combos():
    '''
    Check every combination of wraparound_fill and fade_out
    on small audio-like numpy arrays
    '''
    # Neither: short new samples are zero-padded
    nowrap_nofade = sum_samples(
        samples_original = np.array((1., 1., 500.)),
        samples_new = np.array((10., 11.)),
        sample_rate = 1,
        wraparound_fill = False,
        fade_out = False
    )
    npt.assert_array_equal(nowrap_nofade, np.array([11., 12.,  500.]))

    # Fade only: at sample_rate 1 the fade covers a single sample,
    # so the last new sample is silenced before zero-padding
    nowrap_fade = sum_samples(
        samples_original = np.array((1., 1., 1., 500.)),
        samples_new = np.array((10., 10.)),
        sample_rate = 1,
        wraparound_fill = False,
        fade_out = True
    )
    npt.assert_array_equal(nowrap_fade, np.array([ 11.,   1.,   1., 500.]))

    # Wraparound only: new samples are repeated to fill the length
    wrap_nofade = sum_samples(
        samples_original = np.array((1., 1., 500.)),
        samples_new = np.array((10., 11.)),
        sample_rate = 1,
        wraparound_fill = True,
        fade_out = False
    )
    npt.assert_array_equal(wrap_nofade, np.array([11., 12.,  510.]))

    # Both: fade_out does not apply when wrapping around, 
    # so the result is the same as wrap_nofade
    wrap_fade = sum_samples(
        samples_original = np.array((1., 1., 500.)),
        samples_new = np.array((10., 11.)),
        sample_rate = 1,
        wraparound_fill = True,
        fade_out = True
    )
    npt.assert_array_equal(wrap_fade, np.array([11., 12.,  510.]))
    
test_wrap_fade_combos()

def test_fade_on_actual_audio():
    '''
    Sum real audio samples with neither fade nor wraparound
    and compare with a manually zero-padded sum
    '''
    audio_original = Audio(path = '../tests/1min.wav', label = 'test')
    audio_new = cyclic_shift(audio_original)

    # The new samples are half as long as the original ones
    long_samples = audio_original.samples[:22050]
    short_samples = audio_new.samples[:11025]

    summed = sum_samples(
        samples_original = long_samples,
        samples_new = short_samples,
        sample_rate = audio_original.sample_rate,
        wraparound_fill = False,
        fade_out = False)

    # Expected result: zero-pad the short samples by hand, then add
    short_padded = np.pad(short_samples, (0, 11025), mode='constant', constant_values=0)
    npt.assert_array_equal(summed, np.add(long_samples, short_padded))
    
test_fade_on_actual_audio()

### Helper function, select audio chunks: `select_chunk()` (not implemented yet)

In [37]:
def select_chunk(
    chunk_source,
    label,
    start_position = None,
    duration = 6, # should almost always be longer than source chunk
    duration_jitter = 0,
    chance_random_skip = 0.3
):
    '''
    Get a chunk from a randomly chosen audio file
    
    Picks one of the .wav or .mp3 files (extension matched
    case-insensitively) directly inside `chunk_source`, opens
    it as an Audio object with the given `label`, and passes 
    it to get_chunk() along with the remaining arguments.
    
    Args:
        chunk_source (string): directory to pick a file from
        label (string): label given to the opened Audio object
        start_position, duration, duration_jitter, 
        chance_random_skip: passed unchanged to get_chunk()
    
    Returns:
        whatever get_chunk() returns for the chosen file
    '''
    
    def files_ending_in(extension):
        return [
            name for name in os.listdir(chunk_source)
            if name[-4:].lower() == extension]
    
    # Randomly choose and open a source audio file.
    # All wavs are listed before all mp3s.
    candidates = files_ending_in('.wav') + files_ending_in('.mp3')
    chosen_path = os.path.join(chunk_source, random.choice(candidates))
    source_audio = Audio(path = chosen_path, label = label)
    
    return get_chunk(
        source_audio,
        start_position = start_position,
        duration = duration, 
        duration_jitter = duration_jitter, 
        chance_random_skip = chance_random_skip
    )

In [38]:
# Smoke check: pick a file under ../tests/ and extract a chunk from it.
# The file is picked with random.choice(), so the output depends on the random state.
select_chunk('../tests/', label='none')

Audio([('../tests/1min.wav', (5.327936507936508, 6.0))], {'none'})

### Main function: `sum_chunks()`

In [39]:
label_dict = {'test':'../tests/'}
@audio_manipulation
def sum_chunks(
    audio,
    new_chunk_labels = ['random']*4,
    label_dict = label_dict,
    start_position = None,
    duration = 6, 
    duration_jitter = 0,
    chance_random_skip = 0.3
):
    '''
    Add random chunks to audio
    
    Grab a random number of chunks, from 0 to 4, 
    randomize their signal amplitude (multiply
    by a random number from 0 to 1), and add 
    the chunks to the audio. The labels and sources
    of every added chunk are added to the audio.
    
    Args:
        audio (Audio instance): original chunk 
        label_dict (dict): dictionary associating
            labels (keys) with paths (values). Each
            path is the place on the filesystem where 
            files of the given label can be found.
        new_chunk_labels (list of strings): list of 
            labels for new chunks, in order of potential
            addition. 
            
            Labels should be strings. Options:
            
                'original': one of the labels of `audio`
                    that is also a key in label_dict
                'different': a key in label_dict that is
                    not one of the labels of `audio`
                'random': any key in label_dict
                any key in label_dict
            
            New chunks are added with the 
            following probabilities:
                first chunk: 50%
                second chunk: 
                    if first chunk added: 40%
                    else: 0%
                third chunk:
                    if second chunk added: 30%
                    else: 0%
                fourth chunk: 
                    if third chunk added: 20%
                    else: 0%
                    
        start_position (int): position in the file to start
            extracting samples from. If None, the start position 
            is chosen randomly
        duration (float): desired duration, in seconds, 
            of chunk to extract
        duration_jitter (float): if this value is not 0,
            the duration of the chunk extracted will 
            be randomly selected from the range 
            (duration - duration_jitter, duration + duration_jitter)
        chance_random_skip (float between 0 and 1):
            percent chance of random skipping. In a random skip,
            a position within the chunk will be randomly
            selected, and from that position in the 
            audio file, a random number of samples will 
            be skipped. The number of samples skipped is between
            0 and the number of samples in the entire chunk
            
    Returns:
        audio, the same Audio object with the chunks added
        options, dict of the arguments of this call, 
            without `audio`
            
    Raises:
        ValueError: if new_chunk_labels is not a list of four
            valid labels, or if an 'original' or 'different' 
            label cannot be resolved to a key of label_dict
    '''
    # Taken before any other local exists so that only
    # the call arguments are recorded
    options = locals()
    del options['audio']
    
    if (type(new_chunk_labels) != list) or (len(new_chunk_labels) != 4):
        raise ValueError("`new_chunk_labels` must be a list of four labels")
    for label in new_chunk_labels:
        if label not in ['original', 'different', 'random'] + list(label_dict.keys()):
            raise ValueError("Labels must be in label_dict.keys() or"
                            " in ['original', 'different', 'random']")
    
    sample_rate = audio.sample_rate
    
    # Each further chunk is only considered if the previous one was added
    chunks_to_add = 0
    for chance in (0.5, 0.4, 0.3, 0.2):
        if not (random.random() < chance):
            break
        chunks_to_add += 1
    
    # Iteratively combine chunks and labels
    for requested in new_chunk_labels[:chunks_to_add]:
        
        # Resolve the requested label to a key of label_dict
        if requested == 'random':
            label = random.choice(list(label_dict.keys()))
        elif requested == 'different':
            candidates = [key for key in label_dict if key not in audio.labels]
            if not candidates:
                raise ValueError("label_dict has no label different from the labels of the audio")
            label = random.choice(candidates)
        elif requested == 'original':
            # audio.labels is a set: sort so that the choice is 
            # reproducible for a given random seed
            candidates = sorted(
                (own for own in audio.labels if own in label_dict),
                key = str)
            if not candidates:
                raise ValueError("None of the labels of the audio is in label_dict")
            label = random.choice(candidates)
        else:
            label = requested

        # Randomly grab chunk from source
        chunk_source = label_dict[label]
        new_chunk = select_chunk(
            chunk_source = chunk_source,
            label = label,
            start_position = start_position,
            duration = duration,
            duration_jitter = duration_jitter,
            chance_random_skip = chance_random_skip
        )
        
        # Randomly change amplitude of chunk: a float in [0, 1).
        # (random.randrange(0, 1) can only return 0, which silences the chunk)
        amp_modifier = random.random()
        new_chunk.set_samples(np.multiply(new_chunk.samples, amp_modifier))
        
        # Add chunks together
        summed_samples = sum_samples(
            samples_original = audio.samples,
            samples_new = new_chunk.samples,
            sample_rate = sample_rate,
            wraparound_fill = False,
            fade_out = False
        )
        
        audio.set_samples(summed_samples)
        for new_label in new_chunk.labels:
            audio.add_label(new_label)
        for new_source in new_chunk.sources:
            audio.add_source(source = new_source)
    
    return audio, options

In [40]:
# Run the shared test_audio_manipulation_* checks against sum_chunks
test_audio_manipulation_audio_is_arg(sum_chunks)
test_audio_manipulation_returns_Audio(sum_chunks)
test_audio_manipulation_adds_manipulation(sum_chunks)

# Image augmentations

In [41]:
spectrogram_manipulations = []

class Spectrogram():
    '''
    A class to hold a spectrogram
    
    Keeps a reference to the Audio object it is made from,
    the spectrogram data (filled in by manipulation functions),
    and records of the manipulations applied and the
    sources used.
    '''
    
    def __init__(self, audio = None):
        '''
        Set up a Spectrogram object but do not create spectrogram
        
        Inputs:
            audio (Audio object): required despite the default
                of None; see set_audio()
        '''

        # From audio object, filled by class method set_audio
        self.audio = None
        self.sample_rate = None
        self.samples = None
        self.labels = None
        
        # Filled by spect manipulation functions
        self.spect = None
        self.times = None
        self.freqs = None
        # Refers to the module-level list itself, not to a copy
        self.possible_manipulations = spectrogram_manipulations
        
        # Filled by class methods
        self.manipulations = []
        self.sources = []
    
        # Set self.audio, self.samples, self.sample_rate, self.labels
        self.set_audio(audio)
        
    def set_audio(self, audio):
        '''
        Set the audio this spectrogram is made from
        
        Copies the references to the samples, sample rate 
        and labels of `audio` onto this object.
        
        Raises:
            ValueError: if audio is not an instance of Audio
        '''
         
        if not isinstance(audio, Audio):
            raise ValueError('must pass an instance of Audio class')
        
        self.audio = audio
        self.samples = audio.samples
        self.sample_rate = audio.sample_rate
        self.labels = audio.labels
    
        
    def set_possible_manipulations(self, manip_list):
        '''
        Set self.possible_manipulations
        
        This determines what manipulations will be 
        considered valid in the add_manipulation step
        
        Inputs:
            manip_list: list of strings of names of
                possible manipulations
        '''
        
        self.possible_manipulations = manip_list
    
        
    def add_manipulation(self, manip, manip_kwargs):
        '''
        Add a manipulation to the list of manipulations
        
        Appends a tuple, (manip, manip_kwargs), to the
        list of manipulations for this spectrogram, 
        self.manipulations.
        
        Inputs: 
            manip: function name of the manipulation
            manip_kwargs: keyword arguments used to
                call the manipulation function
                
        Returns:
            True
            
        Raises:
            ValueError: if manip is not in 
                self.possible_manipulations or 
                manip_kwargs is not a dict
        '''
        
        # Explicit checks instead of `assert`, which is 
        # skipped when Python runs with -O
        try:
            is_known = manip in self.possible_manipulations
        except TypeError:
            # e.g. possible_manipulations is not a container
            is_known = False
        if not is_known:
            raise ValueError(f'Invalid input to Spectrogram.add_manipulation(): {manip} not in list of valid manipulations')
            
        if not isinstance(manip_kwargs, dict):
            raise ValueError(f'Invalid inputs to Spectrogram.add_manipulation(): {manip_kwargs} must be dict')
        
        self.manipulations.append((manip, manip_kwargs))
        return True
    
    
    def add_source(self, path = None, start_time = None, duration = None, source = None):
        '''
        Add a source to the list of sources
        
        Appends a tuple, (path, dur_tuple),
        to the list of sources for this spectrogram,
        self.sources. Can be called either with each 
        individual component (path, start_time, duration)
        or with a `source` tuple straight from an 
        Audio object.
        
        Inputs:
            path (string): path to the source file
            start_time (float): start time in seconds
                within the source file
            duration (float): duration in seconds
                within the source file
            source (tuple): a tuple containing
                the above information:
                (path, (start_time, duration))
                
        Returns:
            True
            
        Raises:
            ValueError: if the inputs are incomplete, badly
                shaped, or not convertible to floats
            FileNotFoundError: if path does not exist
        '''
        
        # If using source
        if source:
            if not isinstance(source, tuple) or len(source) != 2:
                raise ValueError(f'Spectrogram.add_source() input source must be a tuple of len 2. Got: {source}')
            path = source[0]
            dur_tuple = source[1]
            if not isinstance(dur_tuple, tuple) or len(dur_tuple) != 2:
                raise ValueError(f'Second element of input source must be a tuple of len 2. Got: {dur_tuple}')
            start_time = dur_tuple[0]
            duration = dur_tuple[1]
        # If using path, start_time, and duration:
        else:
            # Compare the times with None: 0 is a valid start time
            if (not path) or (start_time is None) or (duration is None):
                raise ValueError(f'If not calling Spectrogram.add_source() with a source, must provide all three of: path, start_time, duration')
            source = (path, (start_time, duration))
        
        # Check inputs for correctness
        if not os.path.exists(path):
            raise FileNotFoundError(f'Source file does not exist: {path}')
        try:
            float(start_time)
            float(duration)
        except (TypeError, ValueError):
            # float() raises TypeError for e.g. None or a list
            raise ValueError(f'start time and duration must be floats. Given type(start_time) == {type(start_time)}, type(duration) == {type(duration)}')
        
        # Append to list of sources, as given (not converted to floats)
        self.sources.append(source)
        
        return True

### Tests for Spectrogram class

In [42]:
def test_Spectrogram_requires_Audio_object():
    '''
    A Spectrogram can only be created from an Audio object
    '''
    # Anything else is rejected...
    with pytest.raises(ValueError):
        Spectrogram(audio = 'not an audio object')

    # ...while an actual Audio object is accepted
    silence = Audio(path='../tests/silence_10s.mp3', label='test')
    Spectrogram(silence)
test_Spectrogram_requires_Audio_object()

In [43]:
def test_Spectrogram_add_source_not_all_arguments_provided():
    '''
    add_source() raises ValueError when information is missing
    '''
    # TODO: move this into a fixture
    source_path = '../tests/veryshort.wav'
    spect = Spectrogram(Audio(path = source_path, label = 'test'))

    incomplete_calls = [
        {'path': 'me'},
        {'path': 'me', 'start_time': 1},
        # Second element is the plain string 'a', not a tuple
        {'source': (source_path, 'a')},
    ]
    for call_kwargs in incomplete_calls:
        with pytest.raises(ValueError):
            spect.add_source(**call_kwargs)

test_Spectrogram_add_source_not_all_arguments_provided()

In [44]:
def test_Spectrogram_add_source_not_all_arguments_provided():
    # This definition replaces the test of the same name defined in
    # the previous cell, so it carries the same assertions. Without
    # them, the check would silently be lost whenever the later
    # definition is the one that gets collected.
    # TODO: move this into a fixture
    source_path = '../tests/veryshort.wav'
    audio = Audio(path = source_path, label = 'test')
    spect = Spectrogram(audio)

    # Not all arguments provided
    with pytest.raises(ValueError):
        spect.add_source(path = 'me')
    with pytest.raises(ValueError):
        spect.add_source(path = 'me', start_time = 1)    
    with pytest.raises(ValueError): 
        spect.add_source(source = (source_path, ('a')) )

def test_Spectrogram_add_source_bad_path():
    '''
    add_source() raises FileNotFoundError for a nonexistent
    path, whichever way the source is passed
    '''
    # TODO: move this into a fixture
    source_path = '../tests/veryshort.wav'
    spect = Spectrogram(Audio(path = source_path, label = 'test'))
    
    # As a source tuple
    with pytest.raises(FileNotFoundError):
        spect.add_source(source = ('me', (1, 2)))
        
    # As individual components
    with pytest.raises(FileNotFoundError):
        spect.add_source(path = 'me', start_time = 1, duration = 2)
test_Spectrogram_add_source_bad_path()
        
def test_Spectrogram_add_source_bad_duration():
    '''
    add_source() raises ValueError when the start time or
    the duration cannot be converted to a float
    '''
    # TODO: move this into a fixture
    source_path = '../tests/veryshort.wav'
    audio = Audio(path = source_path, label = 'test')
    spect = Spectrogram(audio)
        
    # Bad start time or duration.
    # Each call gets its own pytest.raises block: a block ends at the
    # first exception, so calls sharing one block after the first 
    # would never run.
    with pytest.raises(ValueError): 
        spect.add_source(source = (source_path, ('a', 1)))
    with pytest.raises(ValueError): 
        spect.add_source(source = (source_path, (1, 'a')))
    with pytest.raises(ValueError): 
        spect.add_source(path = source_path, start_time = 'a', duration = 1) 
    with pytest.raises(ValueError): 
        spect.add_source(path = source_path, start_time = 1, duration = 'a')
test_Spectrogram_add_source_bad_duration()

def test_Spectrogram_add_source_works_correctly():
    '''
    Make sure a good source is actually recorded,
    whichever way it is passed
    '''
    # TODO: move these into fixtures
    source_path = '../tests/veryshort.wav'
    expected_sources = [('../tests/veryshort.wav', (1, 1))]
    
    # Passed as individual components
    spect = Spectrogram(Audio(path = source_path, label = 'test'))
    spect.add_source(path = source_path, start_time = 1, duration = 1)
    assert spect.sources == expected_sources
   
    # Passed as a source tuple, on a fresh Spectrogram
    spect = Spectrogram(Audio(path = source_path, label = 'test'))
    spect.add_source(source = (source_path, (1, 1)))
    assert spect.sources == expected_sources
test_Spectrogram_add_source_works_correctly()

# Spectrogram manipulation wrapper

In [45]:
spectrogram_manipulations = []
def spectrogram_manipulation(func):
    '''
    Decorator for spectrogram manipulation functions
    
    The wrapped function checks that it was given an object
    of class Spectrogram, runs the manipulation, and adds a 
    record of it to the `manipulations` attribute of the 
    Spectrogram object the manipulation returned.
    
    At decoration time, the name of the function is appended 
    to the module-level list `spectrogram_manipulations`,
    the list of valid manipulations.
    
    Inputs:
      - func (function): a manipulation function that takes a 
        Spectrogram, either as its first positional argument or
        as the keyword argument `spectrogram`. It must return 
        two things: a Spectrogram object, and the arguments it
        was called with (passed on to 
        Spectrogram.add_manipulation()), not including the 
        Spectrogram object itself
    
    Returns:
      - wrapped version of the function that returns only 
        the Spectrogram object.
    '''
    
    global spectrogram_manipulations
    spectrogram_manipulations.append(func.__name__)
    
    missing_spectrogram = "a Spectrogram object must be provided as first argument or keyword argument 'spectrogram'"
    
    @wraps(func) #Allows us to call help(func)
    def validate_spectrogram(*args, **kwargs):
        # The keyword argument takes precedence 
        # over the first positional argument
        if 'spectrogram' in kwargs:
            candidate = kwargs['spectrogram']
        elif args:
            candidate = args[0]
        else:
            raise ValueError(missing_spectrogram)
            
        if type(candidate) is not Spectrogram:
            raise ValueError(missing_spectrogram)
            
        # Run manipulation, then record it on the returned object
        result, call_arguments = func(*args, **kwargs)
        result.add_manipulation(func.__name__, call_arguments)
        
        return result
    
    return validate_spectrogram

### Tests of wrapper

In [46]:
def test_spectrogram_wrapper_spectrogram_arg_is_required():
    '''
    Wrapped manipulations raise ValueError unless they are 
    given a Spectrogram object
    '''
    # Correct keyword, but the value is not a Spectrogram
    @spectrogram_manipulation
    def function_with_good_kwarg(spectrogram = None):
        pass
    
    with pytest.raises(ValueError):
        function_with_good_kwarg(spectrogram = 'hah')

    # No `spectrogram` keyword and no positional argument at all
    @spectrogram_manipulation
    def function_with_bad_kwarg(notspectrogram = None):
        pass
    
    with pytest.raises(ValueError):
        function_with_bad_kwarg(notspectrogram = 'not')
test_spectrogram_wrapper_spectrogram_arg_is_required()

# Tests for all manipulation functions

In [47]:
# Start from an empty list of registered manipulations
spectrogram_manipulations = []

# This line and the parameterizations don't do anything in the notebook
# They're just here for when these are moved into a test file
# Hence why spectrogram_manipulations is "cleared" above before running the following lines:
# with an empty list of names, the list of functions built here is empty too.
# Each registered name is turned back into the function object of that name with eval().
functional_spectrogram_manipulations = [eval(func_string) for func_string in spectrogram_manipulations]
@pytest.mark.parametrize('function', functional_spectrogram_manipulations)
def test_spectrogram_manipulation_spectrogram_is_arg(function):
    '''
    A manipulation must have a parameter called `spectrogram`
    '''
    parameters = inspect.signature(function).parameters
    # The lookup itself is the check: it throws 
    # a KeyError if 'spectrogram' is not a parameter
    parameters['spectrogram']
    
@pytest.mark.parametrize('function', functional_spectrogram_manipulations)
def test_spectrogram_manipulation_returns_Spectrogram(function):
    '''
    A manipulation called with a Spectrogram must return
    exactly one thing, a Spectrogram
    '''
    silence = Audio(path = '../tests/silence_10s.mp3', label='silence')
    result = function(spectrogram = Spectrogram(silence))
    
    # Ensure function gave us the correct return
    assert type(result) == Spectrogram

@pytest.mark.parametrize('function', functional_spectrogram_manipulations)
def test_spectrogram_manipulation_adds_manipulation(function):    
    '''
    A manipulation called with default arguments must record 
    (its name, its default arguments without `spectrogram`)
    as the first manipulation of the returned Spectrogram
    '''
    silence = Audio(path = '../tests/silence_10s.mp3', label='silence')
    result = function(spectrogram = Spectrogram(silence))
    recorded = result.manipulations[0]
    
    # Build the expected record from the defaults in the signature
    expected_kwargs = {
        name: parameter.default
        for name, parameter in inspect.signature(function).parameters.items()}
    expected_kwargs.pop('spectrogram')
    
    # Ensure function added the correct entry to the manipulation list
    assert recorded == (function.__name__, expected_kwargs)

### TODO: tests of tests

In [48]:
def test_spect_manipulation_test_catches_no_spectrogram_arg():
    '''
    The signature test must raise KeyError for a function that has
    no `spectrogram` parameter (neither positional nor keyword)
    '''
    def function_without_spect_arg(not_spectrogram):
        return None

    expected_failure = pytest.raises(KeyError)
    with expected_failure:
        test_spectrogram_manipulation_spectrogram_is_arg(function_without_spect_arg)
test_spect_manipulation_test_catches_no_spectrogram_arg()

In [49]:
def test_spect_manipulation_test_catches_wrong_return_format():
    '''
    The return-type test must fail (AssertionError) for manipulations
    that return something other than a Spectrogram
    '''
    # Returns a bool instead of a Spectrogram
    def function_returning_wrong_type(spectrogram = None):
        return True

    # Returns a tuple of two values instead of a single Spectrogram
    def function_with_two_returns(spectrogram = None):
        return 'a', 'b'

    bad_manipulations = (function_returning_wrong_type, function_with_two_returns)
    for bad_manipulation in bad_manipulations:
        with pytest.raises(AssertionError):
            test_spectrogram_manipulation_returns_Spectrogram(bad_manipulation)
test_spect_manipulation_test_catches_wrong_return_format()

In [50]:
def test_spect_manipulation_test_catches_Spectrogram_object_not_removed_from_manips():
    '''
    The manipulation-record test must fail (AssertionError) when a
    manipulation leaves the `spectrogram` object in its recorded arguments
    '''
    # Manipulation does not delete 'spectrogram' object from arguments
    def function_that_does_not_remove_spectrogram_from_options(spectrogram = None):
        # Taken first, so the recorded arguments still contain 'spectrogram'
        arguments = locals()
        spectrogram.set_possible_manipulations(['function_that_does_not_remove_spectrogram_from_options'])
        spectrogram.add_manipulation('function_that_does_not_remove_spectrogram_from_options', arguments)
        return spectrogram
    
    # The checked record has 'spectrogram' popped from the expected arguments,
    # so the extra key recorded above is expected to make the comparison fail
    with pytest.raises(AssertionError):
        test_spectrogram_manipulation_adds_manipulation(function_that_does_not_remove_spectrogram_from_options)
test_spect_manipulation_test_catches_Spectrogram_object_not_removed_from_manips()

In [51]:
# Manipulation works exactly as it's supposed to
def test_spectrogram_test_passes_good_manipulation_addition():
    '''
    The manipulation-record test must pass for a manipulation that
    records its name and its non-spectrogram arguments
    '''
    # NOTE(review): this list is not referenced again in this function
    possible_manipulations = ['function_that_works']
    def function_that_works(spectrogram = None, another = 'default'):
        # Taken before any other local is created, so only the parameters
        # are captured
        arguments = locals()
        # The spectrogram passed in is replaced by a freshly built one
        audio = Audio(path = '../tests/silence_10s.mp3', label='silence')
        spectrogram = Spectrogram(audio)
        spectrogram.set_possible_manipulations(['function_that_works'])
        
        # Drop the object itself so only the remaining arguments are recorded
        del arguments['spectrogram']
        spectrogram.add_manipulation('function_that_works', arguments)
        
        return spectrogram
    test_spectrogram_manipulation_adds_manipulation(function_that_works)
    
test_spectrogram_test_passes_good_manipulation_addition()

In [52]:
# Presumably a deliberate stop: raising here halts a top-to-bottom run before
# the draft cells below, which are shown as unexecuted (`In [None]`).
raise NotImplementedError('Spectrogram manipulation functions are not implemented yet')

NotImplementedError: Spectrogram manipulation functions are not implemented yet

# Manipulation functions

In [None]:
def make_mel_spectrogram(
    self,
    fmax = None,
    fmin = 0,

    n_mels = 128,
    S = None,
    n_fft = 2048,
    hop_length = 360,
    win_length = 1536,
    window = 'hann',
    center = True,
    power = 2.0,
):

    '''
    Make a mel spectrogram from this object's samples

    Create a mel spectrogram between a certain band of
    frequencies (fmax and fmin) and store it on the object.
    The result is stored exactly as returned by librosa;
    no decibel conversion is applied here.

    Args:
        fmax: maximum frequency to include in spectrogram.
            If None (the default), half of self.sample_rate is used
        fmin: minimum frequency to include in spectrogram
        other arguments: see the librosa documentation:
            https://librosa.github.io/librosa/generated/librosa.feature.melspectrogram.html
            https://librosa.github.io/librosa/generated/librosa.filters.mel.html

    Returns:
        None. Sets self.spect, self.times and self.freqs, records
        the manipulation, and adds every source of self.audio
    '''
    y = self.samples
    sr = self.sample_rate

    # `self` does not exist when the signature is evaluated, so the
    # sample-rate-dependent default is resolved here instead
    if fmax is None:
        fmax = sr/2

    # Store arguments used to call librosa.feature.melspectrogram.
    # Copied so the record is independent of this frame; the object itself
    # is dropped, matching the other manipulations' recorded arguments
    options = dict(locals())
    del options['self']

    self.spect = librosa.feature.melspectrogram(
        y = y,
        sr = sr,
        fmax = fmax,
        fmin = fmin,
        S = S,
        n_fft = n_fft,
        hop_length = hop_length,
        win_length = win_length,
        window = window,
        center = center,
        power = power,
        n_mels = n_mels
    )

    self.times = ['mel']
    self.freqs = ['mel']

    # Name and arguments are passed separately, as at the other call sites
    self.add_manipulation('make_mel_spectrogram', options)
    for source in self.audio.sources:
        self.add_source(source)

def make_normal_spectrogram(self, **kwargs):
    '''
    Placeholder for a non-mel spectrogram manipulation

    The function had no body, which made the whole cell a
    syntax error. Until it is written, calling it fails loudly.

    Raises:
        NotImplementedError: always
    '''
    raise NotImplementedError('make_normal_spectrogram is not implemented yet')





### Helper function to display example spectrogram

In [None]:
import librosa
import librosa.display
import numpy as np
import matplotlib.pyplot as plt

# From https://www.agiliq.com/notebook/yanny-or-laurel.html
def display_mel_spectrogram(y, sr):
    '''
    Plot a 128-band mel spectrogram of some audio

    Args:
        y: audio samples, passed to librosa.feature.melspectrogram
        sr: sample rate of y

    Returns:
        the mel spectrogram as returned by librosa (the plotted
        image is its decibel-scaled version; the return value is not)
    '''
    # Audio is passed by keyword: recent librosa releases accept it only
    # that way, and older ones accept both forms
    S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128)

    # Convert to log scale (dB)
    log_S = librosa.power_to_db(S, ref=np.max)

    # Make a new figure
    plt.figure(figsize=(12,4))

    # Display the spectrogram on a mel scale
    # sample rate and hop length parameters are used to render the time axis
    librosa.display.specshow(log_S, sr=sr, x_axis='time', y_axis='mel')

    # draw a color bar
    plt.colorbar(format='%+02.0f dB')

    # Make the figure layout compact
    plt.tight_layout()

    return S

In [None]:
# NOTE(review): `samples` and `sample_rate` are assigned by a `librosa.load`
# cell further down; confirm they are defined before this cell runs.
spect = display_mel_spectrogram(samples, sample_rate)

### Helper function to make spectrogram

In [None]:
def make_spectrogram(
    samples,
    sample_rate = 22050, 
    fmax = 10300,
    fmin = 160,
    
    n_mels = 128,
    S = None,
    n_fft = 2048,
    hop_length = 360,
    win_length = 1536,
    window = 'hann',
    center = True,
    power = 2.0,
):
    '''
    Make mel spectrogram in decibels
    
    Create a mel spectrogram in decibels between
    a certain band of frequencies (fmax and fmin).
    
    Args:
        samples (np.array): numpy array of audio samples
        sample_rate (int or float): sample rate of samples
        fmax: maximum frequency to include in spectrogram.
            Capped at half the sample rate
        fmin: minimum frequency to include in spectrogram
        other arguments: see the librosa documentation:
            https://librosa.github.io/librosa/generated/librosa.feature.melspectrogram.html
            https://librosa.github.io/librosa/generated/librosa.filters.mel.html
    
    Returns:
        decibel-formatted spectrogram in the form of an np.array,
        scaled relative to its own maximum (ref=np.max)
    '''
    # Never request frequencies above half the sample rate
    fmax = min(fmax, (sample_rate/2))
    
    # Audio is passed by keyword: recent librosa releases accept it only
    # that way, and older ones accept both forms
    spectrogram = librosa.feature.melspectrogram(
        y = samples, 
        sr = sample_rate, 
        fmax = fmax,
        fmin = fmin,
        S = S,
        n_fft = n_fft,
        hop_length = hop_length,
        win_length = win_length,
        window = window,
        center = center,
        power = power,
        n_mels = n_mels
    )
    
    return librosa.power_to_db(spectrogram, ref=np.max)

In [None]:
# Regression check: load a short clip, build a 0-1000 Hz spectrogram, and
# compare its first row against hard-coded reference values.
samples, sample_rate = librosa.load('veryshort.wav')
spect = make_spectrogram(
    samples,
    sample_rate,
    fmax = 1000,
    fmin = 0)

npt.assert_array_almost_equal(
    spect[0], 
    np.array(
       [ -7.26741 , -12.600594, -44.393555, -28.828209, -26.990265,
       -39.72174 , -33.033913, -29.21325 , -16.667488], dtype=np.float32)
    )

## Remove random high/low spectrogram rows: `remove_random_hi_lo_bands()`

In [None]:
def remove_random_hi_lo_bands(
    spectrogram,
    min_lo = 0,
    max_lo = 10,
    min_hi = 0,
    max_hi = 6,
):
    '''
    Remove random bands at top and bottom of spectrogram

    The number of rows dropped from each end is drawn uniformly
    (inclusive) from [min_lo, max_lo] for the first rows and
    [min_hi, max_hi] for the last rows.

    Args:
        spectrogram (np.array): spectrogram whose first axis
            indexes the bands; the last rows are treated as
            the high-frequency bands
        min_lo (int >= 0): fewest rows to drop from the start
        max_lo (int >= min_lo): most rows to drop from the start
        min_hi (int >= 0): fewest rows to drop from the end
        max_hi (int >= min_hi): most rows to drop from the end

    Returns:
        the spectrogram with the chosen rows removed

    Raises:
        ValueError: if any bound is negative, a minimum exceeds
            its maximum, or max_lo + max_hi exceeds the number
            of rows in the spectrogram
    '''
    # Ensure sensible hi/lo bands
    values = [min_lo, max_lo, min_hi, max_hi]
    for value in values: 
        if value < 0: 
            raise ValueError('Number of bands to remove must be positive')
    if (min_lo > max_lo) or (min_hi > max_hi):
        raise ValueError('Minimum number of bands to remove must be less than or equal to maximum')
    n_bands = spectrogram.shape[0]
    if max_lo + max_hi > n_bands:
        raise ValueError('Maximum number of bands to remove cannot be greater than number of bands in spectrogram')
    
    hi_remove = random.randint(min_hi, max_hi)
    lo_remove = random.randint(min_lo, max_lo)
    
    # The high-frequency bands are the last bands in the spectrogram.
    # An explicit end index is used because a negative-index slice breaks
    # when hi_remove is 0: `[lo:-0]` is `[lo:0]`, which is empty
    return spectrogram[lo_remove:n_bands - hi_remove]

In [None]:
# Each check pins min == max so the number of removed bands is deterministic.

# Remove the first band and the last two
removed = remove_random_hi_lo_bands(
    np.array([
        [0, 0, 0, 0],
        [1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3],
        [4, 4, 4, 4]]
    ),
    min_lo = 1,
    max_lo = 1,
    min_hi = 2,
    max_hi = 2
)
npt.assert_array_equal(removed, np.array([
    [1, 1, 1, 1],
    [2, 2, 2, 2]]))

# Remove everything: 1 low band + 2 high bands from a 3-band array
# leaves no rows
base_array = np.array([
    [0, 0, 0, 0],
    [3, 3, 3, 3],
    [4, 4, 4, 4]])
empty_array = base_array[0:0]
removed = remove_random_hi_lo_bands(
    base_array,
    min_lo = 1,
    max_lo = 1,
    min_hi = 2,
    max_hi = 2
)
npt.assert_array_equal(removed, empty_array)


# Ensure value checking of min/max bands to remove
# Min must not be greater than max
with pytest.raises(ValueError):
    test = remove_random_hi_lo_bands(
        np.array([[0, 0, 0, 0]]),
        min_lo = 2,
        max_lo = 1,
    )

# Can't remove negative bands
with pytest.raises(ValueError):
    test = remove_random_hi_lo_bands(
        np.array([[0, 0, 0, 0]]),
        min_hi = -2,
        max_hi = 2
    )

# Can't remove more bands than exist in spectrogram
# (max_lo + max_hi = 2, but the array has a single row)
with pytest.raises(ValueError): 
    test = remove_random_hi_lo_bands(
        np.array([[0, 0, 0, 0]]),
        max_lo = 1,
        max_hi = 1
    )

## Resize random rows or columns: `resize_random_bands()`

In [None]:
# Scratch check: .shape is (rows, cols), here (3, 2)
np.array([[1, 2], [1, 2], [1, 2]]).shape

In [None]:
import skimage.transform

def resize_random_bands(
    spectrogram,
    rows_or_cols = 'rows',
    chance_resize = 0.5,
    min_division_size = 10,
    max_division_size = 100,
    min_stretch_factor = 0.9,
    max_stretch_factor = 1.1
):
    '''
    Resize random row or column chunks of spectrogram
    
    With a certain percentage chance, divide 
    spectrogram into random chunks along one axis 
    (row-wise or column-wise chunks) and resize
    each chunk with a randomly chosen scaling factor.
    
    Args:
        spectrogram (np.array): the spectrogram image
        rows_or_cols (string, 'rows' or 'cols'): whether
            to divide spectrogram into chunks by rows 
            (horizontal chunks spanning all times of the
            spectrogram) or columns (vertical chunks
            spanning the whole frequency range of the 
            spectrogram)
        chance_resize (float between 0 and 1): percent
            chance of dividing up the spectrogram and
            performing resizing operations
        min_division_size (int > 0): minimum size in pixels
            for each spectrogram division
        max_division_size (int): maximum size in pixels 
            for each spectrogram division.
        min_stretch_factor (float): minimum scaling factor.
            values < 1 allow spectrogram to shrink
        max_stretch_factor (float): maximum scaling factor.
            values > 1 allow spectrogram to stretch
    
    Returns:
        either the rescaled spectrogram 
            (with probability = chance_resize) or the 
            original spectrogram (prob = 1 - chance_resize)

    Raises:
        ValueError: if any numeric argument is negative, if
            min_division_size is smaller than 1, if a minimum
            exceeds its maximum, or if rows_or_cols is neither
            'rows' nor 'cols'
    '''
    
    # Check sensibility of division sizes and stretch factors
    factors = [
        chance_resize,
        min_division_size,
        max_division_size,
        min_stretch_factor,
        max_stretch_factor
    ]
    for factor in factors: 
        if factor < 0: 
            raise ValueError('Division and stretch sizes must be > 0')
    # A zero-sized division would never advance the loop below
    if min_division_size < 1:
        raise ValueError('Division and stretch sizes must be > 0')
    if (min_division_size > max_division_size) or \
        (min_stretch_factor > max_stretch_factor):
            raise ValueError('Minimum division and stretch sizes must be smaller than maximum')
    
    # Validated before the random draw so a bad value is always reported
    if rows_or_cols == 'rows':
        axis = 0
    elif rows_or_cols == 'cols':
        axis = 1
    else:
        raise ValueError("Parameter rows_or_cols must be either 'rows' or 'cols'")
    
    # Probabilistically don't resize
    if random.random() > chance_resize:
        return spectrogram
    
    # Length along the axis being divided (rows for 'rows', columns for 'cols')
    len_spectrogram = spectrogram.shape[axis]
    len_divisions = 0
    stretched_divisions = []

    # Incrementally divide spectrogram and stretch portions
    while len_divisions < len_spectrogram: 
        # Grab randomly sized portion of spectrogram; the middle piece of the
        # three-way split is the slice [len_divisions, len_divisions + size)
        size_division = random.randint(min_division_size, max_division_size)
        division = np.split(
            spectrogram,
            (len_divisions, len_divisions + size_division),
            axis = axis
        )[1]
        
        # Stretch portion of spectrogram along the divided axis only
        stretch_factor = random.uniform(min_stretch_factor, max_stretch_factor)
        if rows_or_cols == 'cols': multiplier = (1, stretch_factor)
        else: multiplier = (stretch_factor, 1)
        # NOTE(review): newer scikit-image releases replace `multichannel`
        # with `channel_axis` -- confirm against the installed version
        stretched = skimage.transform.rescale(
            division,
            multiplier,
            preserve_range = True,
            multichannel = False
        )
        stretched_divisions.append(stretched)
       
        len_divisions += size_division
    
    # Concatenate all stretched divisions
    return np.concatenate(stretched_divisions, axis = axis)

In [None]:
# Input testing
# Every random choice is pinned (chance 1, min == max) so results are fixed.
# NOTE(review): both fixtures are square (2x2), so these checks cannot
# distinguish the row count from the column count.

# Test resizing random columns
test_boi = np.array([
    [0, 1],
    [1, 1],
])
true_2x_stretched = np.array([
    [0, 0, 1, 1],
    [1, 1, 1, 1]])

# Rounded because the rescale output is not guaranteed to be integral
stretched_2x = np.rint(resize_random_bands(
    spectrogram=test_boi,
    rows_or_cols = 'cols',
    chance_resize = 1,
    min_division_size = 1,
    max_division_size = 1,
    min_stretch_factor = 2,
    max_stretch_factor = 2))

npt.assert_array_equal(stretched_2x, true_2x_stretched)


# Test resizing random rows
test_boi = np.array([
    [0, 0],
    [1, 1],
])
true_2x_stretched = np.array([
    [0, 0],
    [0, 0],
    [1, 1],
    [1, 1]])

stretched_2x = np.rint(resize_random_bands(
    spectrogram=test_boi,
    rows_or_cols = 'rows',
    chance_resize = 1,
    min_division_size = 1,
    max_division_size = 1,
    min_stretch_factor = 2,
    max_stretch_factor = 2))

npt.assert_array_equal(stretched_2x, true_2x_stretched)

## Resize spectrogram to network dimensions: `resize_spect_random_interpolation()`

In [None]:
import PIL
import PIL.ImageOps

def resize_spect_random_interpolation(
    spectrogram,
    width,
    height,
    chance_random_interpolation = 0.15
):
    '''
    Turn a spectrogram array into a resized RGB image

    The array is flipped top-to-bottom, cast to 8-bit,
    converted to a PIL Image with inverted colors, and
    resized. The resize filter is Lanczos unless a random
    draw selects one of the alternative filters.

    Args:
        spectrogram (np.array): spectrogram np array
        width (int): target width
        height (int): target height
        chance_random_interpolation (float between 0 and 1):
            chance that a filter chosen at random from Box,
            Nearest, Bilinear, Hamming and Bicubic (as
            implemented in PIL) is used instead of Lanczos

    Returns:
        a resized PIL image, in RGB

    Raises:
        ValueError: if width or height is not exactly an int
    '''

    # Exact-type check: only plain `int` sizes are accepted
    if type(width) != int or type(height) != int:
        raise ValueError('Height and width must be given in integers')

    # Build the image: flip rows, cast to 8-bit, then invert the colors
    # NOTE(review): the uint8 cast assumes values already fit 0-255 -- confirm
    flipped = spectrogram[::-1, ...]
    as_image = PIL.Image.fromarray(flipped.astype(np.uint8))
    inverted = PIL.ImageOps.invert(as_image)

    # Lanczos unless the draw falls within chance_random_interpolation
    if random.random() > chance_random_interpolation:
        resample_filter = PIL.Image.LANCZOS
    else:
        alternatives = [
            PIL.Image.BOX,
            PIL.Image.NEAREST,
            PIL.Image.BILINEAR,
            PIL.Image.HAMMING,
            PIL.Image.BICUBIC
        ]
        resample_filter = random.choice(alternatives)

    return inverted.resize((width, height), resample_filter).convert('RGB')

In [None]:
# Smoke test: build a spectrogram from a local clip and resize it to 299x299.
samples, sample_rate = librosa.load('1min.wav')
spectrogram = make_spectrogram(samples, sample_rate)
image = resize_spect_random_interpolation(
    spectrogram = spectrogram,
    width = 299,
    height = 299,
    chance_random_interpolation = 0.15)
# Bare last expression: displays the image in the notebook
image

### TODO: more tests

## Jitter brightness, contrast, saturation: `color_jitter()`

In [None]:
import PIL.ImageEnhance

def color_jitter(
    image,
    brightness = 0.3,
    contrast = 0.3,
    saturation = 0.3,
    hue = 0.01
):
    '''
    Randomly change colors of image
    
    Randomly change the brightness, contrast,
    saturation, and hue of an image, in that order.
    Each factor is drawn uniformly from a range set by
    the corresponding argument. This function's
    behavior mimics the behavior of the ColorJitter
    object in pytorch: 
        https://github.com/pytorch/vision/blob/3483342733673c3182bd5f8a4de3723a74ce5845/torchvision/transforms/transforms.py
    by centering brightness, contrast, and saturation
    jitters around 1, and centering hue jitters around 0.
    
    Args:
        image (PIL image): image in 'RGB' mode
        brightness (float > 0, or None to skip): how much to jitter
            brightness. Jitter amount is chosen uniformly
            from [max(0, 1 - brightness), 1 + brightness].
        contrast (float > 0, or None to skip): how much to jitter
            contrast. Jitter amount is chosen uniformly
            from [max(0, 1 - contrast), 1 + contrast].
        saturation (float > 0, or None to skip): how much to jitter
            saturation. Jitter amount is chosen uniformly
            from [max(0, 1 - saturation), 1 + saturation].
        hue (float with -0.5 <= hue <= 0.5, or None to skip): how
            much to jitter hue. Jitter amount is chosen uniformly
            between -hue and hue.

    Returns:
        the jittered PIL image, in RGB

    Raises:
        ValueError: if image is not a PIL Image in 'RGB' mode,
            or if hue is outside [-0.5, 0.5]
    '''
    if type(image) != PIL.Image.Image:
        raise ValueError(f"image should be PIL Image. Got {type(image)}")
    if image.mode != 'RGB':
        raise ValueError(f"image.mode should be 'RGB'. Got {image.mode}")
    
    if brightness is not None:
        brightness_factor = random.uniform(max(0, 1 - brightness), 1 + brightness)
        enhancer = PIL.ImageEnhance.Brightness(image)
        image = enhancer.enhance(brightness_factor)

    if contrast is not None:
        contrast_factor = random.uniform(max(0, 1 - contrast), 1 + contrast)
        enhancer = PIL.ImageEnhance.Contrast(image)
        image = enhancer.enhance(contrast_factor)

    if saturation is not None:
        saturation_factor = random.uniform(max(0, 1 - saturation), 1 + saturation)
        enhancer = PIL.ImageEnhance.Color(image)
        image = enhancer.enhance(saturation_factor)
    
    if hue is not None: 
        if (hue > 0.5) or (hue < -0.5):
            raise ValueError(f"hue should be between -0.5 and 0.5. Got {hue}")            
        hue_factor = random.uniform(-hue, hue)

        h, s, v = image.convert('HSV').split()
        
        np_h = np.array(h, dtype=np.uint8)
        # Compute the shift in integer arithmetic and wrap it into 0..255
        # before building the uint8: casting a negative float straight to
        # uint8 is an out-of-range conversion whose result depends on the
        # platform. int() truncates toward zero, then % 256 wraps.
        hue_shift = np.uint8(int(hue_factor * 255) % 256)
        # uint8 addition takes care of rotation across boundaries
        with np.errstate(over='ignore'):
            np_h += hue_shift
        h = PIL.Image.fromarray(np_h, 'L')

        image = PIL.Image.merge('HSV', (h, s, v)).convert('RGB')
    
    return image

### TODO: add tests

In [None]:
# Scratch check: re-seeding with the same value makes random.random()
# repeat, so both prints show the same number.
random.seed(1)
print(random.random())
random.seed(1)
print(random.random())

In [None]:
# Smoke test on the `image` built by the resize cell above; as the last
# expression of the cell, the jittered image is displayed.
color_jitter(image)