Skip to content

Commit

Permalink
Rename "get_embedding_model" to "load_embedding_model"
Browse files Browse the repository at this point in the history
  • Loading branch information
auroracramer committed Apr 12, 2019
1 parent 6b2653c commit 5829321
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 38 deletions.
4 changes: 2 additions & 2 deletions docs/tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ By default, the corresponding model file is loaded every time this function is c
import openl3
import soundfile as sf
model = openl3.models.get_embedding_model(input_repr="mel256", content_type="music", embedding_size=6144)
model = openl3.models.load_embedding_model(input_repr="mel256", content_type="music", embedding_size=6144)
emb, ts = openl3.get_embedding(audio, sr, model=model)
To compute embeddings for an audio file and save them locally, you can use code like the following:
Expand Down Expand Up @@ -113,7 +113,7 @@ Like before, you can also load the model before processing the file so that load
import openl3
import numpy as np
model = openl3.models.get_embedding_model(input_repr="mel256", content_type="music", embedding_size=6144)
model = openl3.models.load_embedding_model(input_repr="mel256", content_type="music", embedding_size=6144)
audio_filepath = '/path/to/file.wav'
# Saves the file to '/path/to/file.npz'
Expand Down
4 changes: 2 additions & 2 deletions openl3/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
from openl3.openl3_exceptions import OpenL3Error
from openl3 import process_file
from openl3.models import get_embedding_model
from openl3.models import load_embedding_model
from argparse import ArgumentParser, RawDescriptionHelpFormatter, ArgumentTypeError
from collections import Iterable
from six import string_types
Expand Down Expand Up @@ -85,7 +85,7 @@ def run(inputs, output_dir=None, suffix=None, input_repr="mel256", content_type=
sys.exit(-1)

# Load model
model = get_embedding_model(input_repr, content_type, embedding_size)
model = load_embedding_model(input_repr, content_type, embedding_size)

# Process all files in the arguments
for filepath in file_list:
Expand Down
4 changes: 2 additions & 2 deletions openl3/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from numbers import Real
import warnings
import keras
from .models import get_embedding_model
from .models import load_embedding_model
from .openl3_exceptions import OpenL3Error
from .openl3_warnings import OpenL3Warning

Expand Down Expand Up @@ -114,7 +114,7 @@ def get_embedding(audio, sr, model=None, input_repr="mel256",

# Get embedding model
if model is None:
model = get_embedding_model(input_repr, content_type, embedding_size)
model = load_embedding_model(input_repr, content_type, embedding_size)

audio_len = audio.size
frame_len = TARGET_SR
Expand Down
6 changes: 3 additions & 3 deletions openl3/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
}


def get_embedding_model(input_repr, content_type, embedding_size):
def load_embedding_model(input_repr, content_type, embedding_size):
"""
Returns a model with the given characteristics. Loads the model
if the model has not been loaded yet.
Expand All @@ -54,7 +54,7 @@ def get_embedding_model(input_repr, content_type, embedding_size):
warnings.simplefilter("ignore")
m = MODELS[input_repr]()

m.load_weights(get_embedding_model_path(input_repr, content_type))
m.load_weights(load_embedding_model_path(input_repr, content_type))

# Pooling for final output embedding size
pool_size = POOLINGS[input_repr][embedding_size]
Expand All @@ -64,7 +64,7 @@ def get_embedding_model(input_repr, content_type, embedding_size):
return m


def get_embedding_model_path(input_repr, content_type):
def load_embedding_model_path(input_repr, content_type):
"""
Returns the local path to the model weights file for the model
with the given characteristics
Expand Down
30 changes: 23 additions & 7 deletions tests/notebooks/generate_openl3_regression_data.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@
"outputs": [],
"source": [
"# path to store output embeddings\n",
"output_dir = os.path.expanduser('~/Downloads/')"
"output_dir = os.path.expanduser('~/openl3_output/')"
]
},
{
Expand All @@ -221,9 +221,25 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From /beegfs/jtc440/miniconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:1204: calling reduce_max (from tensorflow.python.ops.math_ops) with keep_dims is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"keep_dims is deprecated, use keepdims instead\n",
"WARNING:tensorflow:From /beegfs/jtc440/miniconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:1238: calling reduce_sum (from tensorflow.python.ops.math_ops) with keep_dims is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"keep_dims is deprecated, use keepdims instead\n",
"WARNING:tensorflow:From /beegfs/jtc440/miniconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:1255: calling reduce_prod (from tensorflow.python.ops.math_ops) with keep_dims is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"keep_dims is deprecated, use keepdims instead\n"
]
}
],
"source": [
"# compute mel256/music/6144 regression embedding\n",
"suffix=None\n",
Expand All @@ -247,7 +263,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -352,9 +368,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:openl3]",
"display_name": "Python 3",
"language": "python",
"name": "conda-env-openl3-py"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -366,7 +382,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
"version": "3.6.4"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_get_embedding():
assert not np.any(np.isnan(emb1))

# Make sure we can load a model and pass it in
model = openl3.models.get_embedding_model("linear", "env", 6144)
model = openl3.models.load_embedding_model("linear", "env", 6144)
emb1load, ts1load = openl3.get_embedding(audio, sr,
model=model, center=True, hop_size=hop_size, verbose=1)
assert np.all(np.abs(emb1load - emb1) < tol)
Expand Down
42 changes: 21 additions & 21 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,59 @@
from openl3.models import get_embedding_model, get_embedding_model_path
from openl3.models import load_embedding_model, load_embedding_model_path


def test_get_embedding_model_path():
embedding_model_path = get_embedding_model_path('linear', 'music')
def test_load_embedding_model_path():
embedding_model_path = load_embedding_model_path('linear', 'music')
assert '/'.join(embedding_model_path.split('/')[-2:]) == 'openl3/openl3_audio_linear_music.h5'

embedding_model_path = get_embedding_model_path('linear', 'env')
embedding_model_path = load_embedding_model_path('linear', 'env')
assert '/'.join(embedding_model_path.split('/')[-2:]) == 'openl3/openl3_audio_linear_env.h5'

embedding_model_path = get_embedding_model_path('mel128', 'music')
embedding_model_path = load_embedding_model_path('mel128', 'music')
assert '/'.join(embedding_model_path.split('/')[-2:]) == 'openl3/openl3_audio_mel128_music.h5'

embedding_model_path = get_embedding_model_path('mel128', 'env')
embedding_model_path = load_embedding_model_path('mel128', 'env')
assert '/'.join(embedding_model_path.split('/')[-2:]) == 'openl3/openl3_audio_mel128_env.h5'

embedding_model_path = get_embedding_model_path('mel256', 'music')
embedding_model_path = load_embedding_model_path('mel256', 'music')
assert '/'.join(embedding_model_path.split('/')[-2:]) == 'openl3/openl3_audio_mel256_music.h5'

embedding_model_path = get_embedding_model_path('mel256', 'env')
embedding_model_path = load_embedding_model_path('mel256', 'env')
assert '/'.join(embedding_model_path.split('/')[-2:]) == 'openl3/openl3_audio_mel256_env.h5'


def test_get_embedding_model():
m = get_embedding_model('linear', 'music', 6144)
def test_load_embedding_model():
m = load_embedding_model('linear', 'music', 6144)
assert m.output_shape[1] == 6144

m = get_embedding_model('linear', 'music', 512)
m = load_embedding_model('linear', 'music', 512)
assert m.output_shape[1] == 512

m = get_embedding_model('linear', 'env', 6144)
m = load_embedding_model('linear', 'env', 6144)
assert m.output_shape[1] == 6144

m = get_embedding_model('linear', 'env', 512)
m = load_embedding_model('linear', 'env', 512)
assert m.output_shape[1] == 512

m = get_embedding_model('mel128', 'music', 6144)
m = load_embedding_model('mel128', 'music', 6144)
assert m.output_shape[1] == 6144

m = get_embedding_model('mel128', 'music', 512)
m = load_embedding_model('mel128', 'music', 512)
assert m.output_shape[1] == 512

m = get_embedding_model('mel128', 'env', 6144)
m = load_embedding_model('mel128', 'env', 6144)
assert m.output_shape[1] == 6144

m = get_embedding_model('mel128', 'env', 512)
m = load_embedding_model('mel128', 'env', 512)
assert m.output_shape[1] == 512

m = get_embedding_model('mel256', 'music', 6144)
m = load_embedding_model('mel256', 'music', 6144)
assert m.output_shape[1] == 6144

m = get_embedding_model('mel256', 'music', 512)
m = load_embedding_model('mel256', 'music', 512)
assert m.output_shape[1] == 512

m = get_embedding_model('mel256', 'env', 6144)
m = load_embedding_model('mel256', 'env', 6144)
assert m.output_shape[1] == 6144

m = get_embedding_model('mel256', 'env', 512)
m = load_embedding_model('mel256', 'env', 512)
assert m.output_shape[1] == 512

0 comments on commit 5829321

Please sign in to comment.