[WIP] Luigi pipeline update #78 (merged)
Commits (29, all by jorshi):

- 17dd7bd Adding config for nsynth
- ec435fe Copying speech commands -- getting download and extraction running
- e0bbce0 Metadata creation for nsynth
- e7799f7 Renaming nsynth to nsynth pitch
- fa4f90b Remove monowavtrim
- 04c8604 Named import for luigi utils
- e9b14e3 Starting to create config classes
- 1feb45b Adding name partition configs
- 518ebfc Dynamic creation of the download and extract tasks
- 46c3d8a Process metadata being built dynamically
- 94ae944 Starting to genericize the audio pipeline
- bd1a729 Moved all the audio processing out of speech commands
- c7fc530 Adding some docstrings
- 6d7fb68 Cleaning up config
- c4c05e6 versioned task name
- c6591ff Merge branch 'nsynth' into luigi-pipeline-update
- 56cda4c Updating nsynth config
- 998acd0 remove string config passing into dataset builder
- b7a01b2 Formatting
- fe76e96 sample rate can be passed in as a command line arg
- b2b3f56 Cleanup
- 8ac1c0e Move dataset specific config into the same files as tasks
- d419651 Remove config folder
- 2bd7ffb Adding dataset preprocessing usage to readme
- 32a1933 A bit of cleanup
- 01ba125 Updating subsample numbers
- d8f5ef1 Update some docstrings in the dataset builder
- 63b3d29 Removing command-line invocation from individual tasks -- must be fro…
- eb65831 Adding click requirement
Files changed:
The first file touched is the flake8 config; the hunk (`@@ -4,4 +4,4 @@`) is an apparent whitespace-only change to these lines:

```ini
max-line-length = 88
extend-ignore =
    # See https://github.com/PyCQA/pycodestyle/issues/373
    E203,
```
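For context: E203 ("whitespace before ':'") fires on slice formatting that Black produces, which is why it is ignored here; a minimal illustration (the variable names are invented):

```python
samples = list(range(100))
offset, frame_size = 10, 32

# Black formats slices over complex expressions with spaces around ":";
# pycodestyle would flag that space as E203, hence the ignore above.
chunk = samples[offset + 1 : offset + frame_size]
```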
An empty file was added, and four files were deleted (their contents are not shown; presumably the old config files, per the "Remove config folder" commit).
New file, 69 lines (presumably heareval/tasks/dataset_config.py, given the import in nsynth_pitch.py below):
""" | ||
Generic configuration used by all tasks | ||
""" | ||
|
||
from typing import Dict, List | ||
|
||
|
||
class DatasetConfig: | ||
""" | ||
A base class config class for HEAR datasets. | ||
|
||
Args: | ||
task_name: Unique name for this task | ||
version: version string for the dataset | ||
download_urls: A dictionary of URLs to download the dataset files from | ||
sample_duration: All samples with be padded / trimmed to this length | ||
""" | ||
|
||
def __init__( | ||
self, task_name: str, version: str, download_urls: Dict, sample_duration: float | ||
): | ||
self.task_name = task_name | ||
self.version = version | ||
self.download_urls = download_urls | ||
self.sample_duration = sample_duration | ||
|
||
@property | ||
def versioned_task_name(self): | ||
return f"{self.task_name}-{self.version}" | ||
|
||
|
||
class PartitionConfig: | ||
""" | ||
A configuration class for creating named partitions in a dataset | ||
|
||
Args: | ||
name: name of the partition | ||
max_files: an integer number of samples to cap this partition at, | ||
defaults to None for no maximum. | ||
""" | ||
|
||
def __init__(self, name: str, max_files: int = None): | ||
self.name = name | ||
self.max_files = max_files | ||
|
||
|
||
class PartitionedDatasetConfig(DatasetConfig): | ||
""" | ||
A base class config class for HEAR datasets. This config should be used when | ||
there are pre-defined data partitions. | ||
|
||
Args: | ||
task_name: Unique name for this task | ||
version: version string for the dataset | ||
download_urls: A dictionary of URLs to download the dataset files from | ||
sample_duration: All samples with be padded / trimmed to this length | ||
partitions: A list of PartitionConfig objects describing the partitions | ||
""" | ||
|
||
def __init__( | ||
self, | ||
task_name: str, | ||
version: str, | ||
download_urls: Dict, | ||
sample_duration: float, | ||
partitions: List[PartitionConfig], | ||
): | ||
super().__init__(task_name, version, download_urls, sample_duration) | ||
self.partitions = partitions |
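A quick usage sketch of these config classes; the task name, URL, and counts here are hypothetical, not from the diff:

```python
from heareval.tasks.dataset_config import PartitionConfig, PartitionedDatasetConfig

# Hypothetical task config -- all values below are made up for illustration.
example = PartitionedDatasetConfig(
    task_name="example-task",
    version="v1.0.0",
    download_urls={"train": "https://example.com/train.tar.gz"},
    sample_duration=4.0,
    partitions=[
        PartitionConfig(name="train", max_files=1000),
        PartitionConfig(name="test", max_files=None),  # no cap
    ],
)
print(example.versioned_task_name)  # -> "example-task-v1.0.0"
```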
New file, 139 lines: heareval/tasks/nsynth_pitch.py (per the runner import below):
```python
#!/usr/bin/env python3
"""
Pre-processing pipeline for NSynth pitch detection
"""

import os
from pathlib import Path
from functools import partial
import logging
from typing import List

import luigi
import pandas as pd
from slugify import slugify

from heareval.tasks.dataset_config import (
    PartitionedDatasetConfig,
    PartitionConfig,
)
from heareval.tasks.util.dataset_builder import DatasetBuilder
import heareval.tasks.util.luigi as luigi_util

logger = logging.getLogger("luigi-interface")


# Dataset configuration
class NSynthPitchConfig(PartitionedDatasetConfig):
    def __init__(self):
        super().__init__(
            task_name="nsynth-pitch",
            version="v2.2.3",
            download_urls={
                "train": "http://download.magenta.tensorflow.org/datasets/nsynth/nsynth-train.jsonwav.tar.gz",  # noqa: E501
                "valid": "http://download.magenta.tensorflow.org/datasets/nsynth/nsynth-valid.jsonwav.tar.gz",  # noqa: E501
                "test": "http://download.magenta.tensorflow.org/datasets/nsynth/nsynth-test.jsonwav.tar.gz",  # noqa: E501
            },
            # All samples will be trimmed / padded to this length
            sample_duration=4.0,
            # Pre-defined partitions in the dataset. Number of files in each split is
            # train: 85,111; valid: 10,102; test: 4,890. These counts will be a bit
            # lower after filtering the pitches to the piano range.
            # To subsample a partition, set max_files to an integer.
            # TODO: Should we subsample NSynth?
            partitions=[
                PartitionConfig(name="train", max_files=10000),
                PartitionConfig(name="valid", max_files=1000),
                PartitionConfig(name="test", max_files=None),
            ],
        )
        # We only include pitches that are on a standard 88-key MIDI piano
        self.pitch_range = (21, 108)


config = NSynthPitchConfig()


class ConfigureProcessMetaData(luigi_util.WorkTask):
    """
    Custom metadata pre-processing for the NSynth task. Creates a metadata csv
    file that will be used by downstream luigi tasks to curate the final dataset.
    """

    outfile = luigi.Parameter()

    def requires(self):
        # Supplied dynamically via DatasetBuilder.build_task (see main() below)
        raise NotImplementedError

    @staticmethod
    def get_rel_path(root: Path, item: str) -> Path:
        # Creates the relative path to an audio file given the note_str
        audio_path = root.joinpath("audio")
        filename = f"{item}.wav"
        return audio_path.joinpath(filename)

    @staticmethod
    def slugify_file_name(filename: str) -> str:
        return f"{slugify(filename)}.wav"

    def get_split_metadata(self, split: str) -> pd.DataFrame:
        logger.info(f"Preparing metadata for {split}")

        # Loads and prepares the metadata for a specific split
        split_path = Path(self.requires()[split].workdir).joinpath(f"nsynth-{split}")

        metadata = pd.read_json(split_path.joinpath("examples.json"), orient="index")

        # Filter out pitches that are not within the range
        metadata = metadata[metadata["pitch"] >= config.pitch_range[0]]
        metadata = metadata[metadata["pitch"] <= config.pitch_range[1]]

        metadata = metadata.assign(label=lambda df: df["pitch"])
        metadata = metadata.assign(
            relpath=lambda df: df["note_str"].apply(
                partial(self.get_rel_path, split_path)
            )
        )
        metadata = metadata.assign(
            slug=lambda df: df["note_str"].apply(self.slugify_file_name)
        )
        metadata = metadata.assign(partition=lambda df: split)
        metadata = metadata.assign(
            filename_hash=lambda df: df["slug"].apply(luigi_util.filename_to_int_hash)
        )

        return metadata[luigi_util.PROCESSMETADATACOLS]

    def run(self):
        # Get metadata for each of the data splits
        process_metadata = pd.concat(
            [self.get_split_metadata(split) for split in self.requires()]
        )

        process_metadata.to_csv(
            os.path.join(self.workdir, self.outfile),
            columns=luigi_util.PROCESSMETADATACOLS,
            header=False,
            index=False,
        )

        self.mark_complete()


def main(num_workers: int, sample_rates: List[int]):
    builder = DatasetBuilder(config)

    # Build the dataset pipeline with the custom metadata configuration task
    download_tasks = builder.download_and_extract_tasks()
    configure_metadata = builder.build_task(
        ConfigureProcessMetaData,
        requirements=download_tasks,
        params={"outfile": "process_metadata.csv"},
    )
    audio_tasks = builder.prepare_audio_from_metadata_task(
        configure_metadata, sample_rates
    )

    builder.run(audio_tasks, num_workers=num_workers)
```
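The `slug` and `filename_hash` columns give each sample a stable file name and a deterministic integer key for downstream use. A standalone sketch of those two transforms; `filename_to_int_hash` lives in the repo's luigi utils, so the md5-based stand-in below is an assumption, and the `note_str` value is only an illustrative NSynth-style name:

```python
import hashlib

from slugify import slugify


def filename_to_int_hash(filename: str) -> int:
    # Assumed stand-in for luigi_util.filename_to_int_hash: any stable
    # string -> int mapping would serve the same purpose.
    return int(hashlib.md5(filename.encode("utf-8")).hexdigest(), 16)


note_str = "guitar_acoustic_010-060-075"  # illustrative NSynth-style note id
slug = f"{slugify(note_str)}.wav"
print(slug)                        # guitar-acoustic-010-060-075.wav
print(filename_to_int_hash(slug))  # a large, deterministic integer
```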
New file, 53 lines: the click-based runner (presumably heareval/tasks/runner.py):
```python
#!/usr/bin/env python3
"""
Runs a luigi pipeline to build a dataset
"""

import logging
import multiprocessing
from typing import Optional

import click

import heareval.tasks.speech_commands as speech_commands
import heareval.tasks.nsynth_pitch as nsynth_pitch

logger = logging.getLogger("luigi-interface")

tasks = {"speech_commands": speech_commands, "nsynth_pitch": nsynth_pitch}


@click.command()
@click.argument("task")
@click.option(
    "--num-workers",
    default=None,
    help="Number of CPU workers to use when running. "
    "If not provided all CPUs are used.",
    type=int,
)
@click.option(
    "--sample-rate",
    default=None,
    help="Perform resampling only to this sample rate. "
    "By default we resample to 16000, 22050, 44100, 48000.",
    type=int,
)
def run(
    task: str, num_workers: Optional[int] = None, sample_rate: Optional[int] = None
):
    if num_workers is None:
        num_workers = multiprocessing.cpu_count()
    logger.info(f"Using {num_workers} workers")

    if sample_rate is None:
        sample_rates = [16000, 22050, 44100, 48000]
    else:
        sample_rates = [sample_rate]

    tasks[task].main(num_workers=num_workers, sample_rates=sample_rates)


if __name__ == "__main__":
    run()
```
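Assuming this file lands at heareval/tasks/runner.py (the path is not shown in the diff), a run might look like `python -m heareval.tasks.runner nsynth_pitch --num-workers 4 --sample-rate 16000`. The same command line can be exercised in-process through click's test runner:

```python
from click.testing import CliRunner

from heareval.tasks.runner import run  # assumed module path

# Invokes the CLI in-process; equivalent to the shell command above.
result = CliRunner().invoke(run, ["nsynth_pitch", "--sample-rate", "16000"])
print(result.exit_code)
```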
Review comment: Full NSynth is quite large. Should we subsample it?