Commit

Add Caffe2 ONNX label service
daemon committed Oct 18, 2017
1 parent cb55b43 · commit f411e57
Showing 6 changed files with 77 additions and 35 deletions.
config.json (3 changes: 2 additions & 1 deletion)
@@ -4,7 +4,8 @@
     "server.socket_host": "0.0.0.0"
   },
   "commands": "yes,no,up,down,left,right,on,off,stop,go",
-  "model_path": "model/google-speech-dataset.pt",
+  "model_path": "model/google-speech-dataset-full.onnx",
+  "backend": "caffe2",
   "train_script": "utils/model.py",
   "speech_dataset_path": "/tmp/speech_dataset/",
   "model_options": {
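
For context, a minimal sketch (not part of the commit) of how the two new keys are consumed, assuming server.py reads this file with json.load:

import json

with open("config.json") as f:
    config = json.load(f)
print(config["backend"])     # "caffe2" -> selects the ONNX/Caffe2 label service
print(config["model_path"])  # now an ONNX export instead of a .pt checkpoint
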
requirements.txt (2 changes: 2 additions & 0 deletions)
@@ -1,3 +1,4 @@
+chainmap
 cherrypy>=11.0.0,<=11.99
 librosa>=0.5,<=0.5.99
 numpy>=1.12
@@ -8,3 +9,4 @@ PyOpenGL_accelerate
 pyttsx3==2.6
 requests>=2.18,<=2.99
 SpeechRecognition==3.7.1

server.py (11 changes: 9 additions & 2 deletions)
@@ -10,7 +10,7 @@
 import cherrypy
 import numpy as np
 
-from service import LabelService, TrainingService
+from service import Caffe2LabelService, TorchLabelService, TrainingService
 from service import stride
 
 def json_in(f):
@@ -130,8 +130,15 @@ def start(config):
     speech_dataset_path = make_abspath(config["speech_dataset_path"])
     commands = ["__silence__", "__unknown__"]
     commands.extend(config["commands"].split(","))
+
+    backend = config["backend"]
+    if backend.lower() == "caffe2":
+        lbl_service = Caffe2LabelService(model_path, commands)
+    elif backend.lower() == "pytorch":
+        lbl_service = TorchLabelService(model_path, labels=commands, no_cuda=config["model_options"]["no_cuda"])
+    else:
+        raise ValueError("Backend {} not supported!".format(backend))
 
-    lbl_service = LabelService(model_path, labels=commands, no_cuda=config["model_options"]["no_cuda"])
     train_service = TrainingService(train_script, speech_dataset_path, config["model_options"])
     cherrypy.tree.mount(ListenEndpoint(lbl_service), "/listen", rest_config)
     cherrypy.tree.mount(DataEndpoint(train_service), "/data", rest_config)
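
For reference, the dispatch above restated as a standalone sketch; the helper make_label_service is hypothetical, not part of this commit, and assumes the config keys shown in config.json:

from service import Caffe2LabelService, TorchLabelService

def make_label_service(config, model_path, commands):
    # Route to the new ONNX/Caffe2 service or the original PyTorch one.
    backend = config["backend"].lower()
    if backend == "caffe2":
        return Caffe2LabelService(model_path, commands)
    if backend == "pytorch":
        return TorchLabelService(model_path, labels=commands,
                                 no_cuda=config["model_options"]["no_cuda"])
    raise ValueError("Backend {} not supported!".format(backend))

Factoring the branch into one helper keeps start() flat and gives a single place to register future backends.
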
service.py (74 changes: 52 additions & 22 deletions)
@@ -8,29 +8,23 @@

 import librosa
 import numpy as np
-import torch
-import torch.nn.functional as F
+try:
+    import torch
+    import torch.nn.functional as F
+except ImportError:
+    pass
+try:
+    import onnx
+    import onnx_caffe2.backend
+except ImportError:
+    pass
 
-from utils.manage_audio import AudioSnippet
+from utils.manage_audio import AudioSnippet, preprocess_audio
 import utils.model as model

 def _softmax(x):
     return np.exp(x) / np.sum(np.exp(x))
 
 class LabelService(object):
-    def __init__(self, model_filename, no_cuda=False, labels=["_silence_", "_unknown_", "command", "random"]):
-        self.labels = labels
-        self.model_filename = model_filename
-        self.no_cuda = no_cuda
-        self.filters = librosa.filters.dct(40, 40)
-        self.reload()
-
-    def reload(self):
-        config = model.find_config(model.ConfigType.CNN_TRAD_POOL2)
-        config["n_labels"] = len(self.labels)
-        self.model = model.SpeechModel(config)
-        if not self.no_cuda:
-            self.model.cuda()
-        self.model.load(self.model_filename)
-        self.model.eval()

     def evaluate(self, speech_dirs, indices=[]):
         dir_labels = {}
         if indices:
@@ -51,6 +45,42 @@ def evaluate(self, speech_dirs, indices=[]):
             accuracy.append(int(label == dir_labels[folder]))
         return sum(accuracy) / len(accuracy)

+    def label(self, wav_data):
+        raise NotImplementedError
+
+class Caffe2LabelService(LabelService):
+    def __init__(self, onnx_filename, labels):
+        self.labels = labels
+        self.model_filename = onnx_filename
+        self.filters = librosa.filters.dct(40, 40)
+        self._graph = onnx.load(onnx_filename)
+        self._in_name = self._graph.graph.input[0].name
+        self.model = onnx_caffe2.backend.prepare(self._graph)
+
+    def label(self, wav_data):
+        wav_data = np.frombuffer(wav_data, dtype=np.int16) / 32768.
+        model_in = np.expand_dims(preprocess_audio(wav_data, 40, self.filters), 0)
+        model_in = model_in.astype(np.float32)
+        predictions = _softmax(self.model.run({self._in_name: model_in})[0]).squeeze()
+        return (self.labels[np.argmax(predictions)], max(predictions))

+class TorchLabelService(LabelService):
+    def __init__(self, model_filename, no_cuda=False, labels=["_silence_", "_unknown_", "command", "random"]):
+        self.labels = labels
+        self.model_filename = model_filename
+        self.no_cuda = no_cuda
+        self.filters = librosa.filters.dct(40, 40)
+        self.reload()
+
+    def reload(self):
+        config = model.find_config(model.ConfigType.CNN_TRAD_POOL2)
+        config["n_labels"] = len(self.labels)
+        self.model = model.SpeechModel(config)
+        if not self.no_cuda:
+            self.model.cuda()
+        self.model.load(self.model_filename)
+        self.model.eval()

     def label(self, wav_data):
         """Labels audio data as one of the specified trained labels
@@ -61,7 +91,7 @@ def label(self, wav_data):
             A (most likely label, probability) tuple
         """
         wav_data = np.frombuffer(wav_data, dtype=np.int16) / 32768.
-        model_in = model.preprocess_audio(wav_data, 40, self.filters).unsqueeze(0)
+        model_in = torch.from_numpy(preprocess_audio(wav_data, 40, self.filters)).unsqueeze(0)
         model_in = torch.autograd.Variable(model_in, requires_grad=False)
         if not self.no_cuda:
             model_in = model_in.cuda()
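
A minimal usage sketch of the new Caffe2 path (assumptions: onnx and onnx-caffe2 are installed, the ONNX export exists at this path, and the input is raw 16 kHz mono 16-bit PCM with any RIFF header already stripped):

from service import Caffe2LabelService

commands = ["__silence__", "__unknown__", "yes", "no", "up", "down",
            "left", "right", "on", "off", "stop", "go"]
svc = Caffe2LabelService("model/google-speech-dataset-full.onnx", commands)

with open("clip.raw", "rb") as f:   # hypothetical one-second PCM capture
    wav_data = f.read()
label, prob = svc.label(wav_data)
print(label, prob)
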
utils/manage_audio.py (8 changes: 8 additions & 0 deletions)
@@ -1,3 +1,4 @@
+from __future__ import print_function
 import argparse
 import os
 import random
@@ -13,6 +14,13 @@ def set_speech_format(f):
     f.setsampwidth(2)
     f.setframerate(16000)
 
+def preprocess_audio(data, n_mels, dct_filters):
+    data = librosa.feature.melspectrogram(data, sr=16000, n_mels=n_mels, hop_length=160, n_fft=480, fmin=20, fmax=4000)
+    data[data > 0] = np.log(data[data > 0])
+    data = [np.matmul(dct_filters, x) for x in np.split(data, data.shape[1], axis=1)]
+    data = np.array(data, order="F").squeeze(2).astype(np.float32)
+    return data

 class AudioSnippet(object):
     _dct_filters = librosa.filters.dct(40, 40)
     def __init__(self, byte_data=b"", dtype=np.int16):
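
For intuition, a sketch of what preprocess_audio produces, assuming a one-second 16 kHz clip (the all-zero input is just a stand-in):

import librosa
import numpy as np
from utils.manage_audio import preprocess_audio

dct_filters = librosa.filters.dct(40, 40)
data = np.zeros(16000, dtype=np.float32)   # 1 s of silence
feats = preprocess_audio(data, 40, dct_filters)
print(feats.shape, feats.dtype)            # (101, 40) float32: (frames, dct_coeffs)
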
utils/model.py (14 changes: 4 additions & 10 deletions)
@@ -1,11 +1,11 @@
-from collections import ChainMap
 from enum import Enum
 import hashlib
 import math
 import os
 import random
 import re
 
+from chainmap import ChainMap
 from torch.autograd import Variable
 import librosa
 import numpy as np
@@ -14,6 +14,8 @@
 import torch.nn.functional as F
 import torch.utils.data as data
 
+from manage_audio import preprocess_audio
+
 class SimpleCache(dict):
     def __init__(self, limit):
         super().__init__()
@@ -128,14 +130,6 @@ def forward(self, x):
         x = self.dropout(x)
         return self.output(x)
 
-def preprocess_audio(data, n_mels, dct_filters):
-    data = librosa.feature.melspectrogram(data, sr=16000, n_mels=n_mels, hop_length=160, n_fft=480,
-                                          fmin=20, fmax=4000)
-    data[data > 0] = np.log(data[data > 0])
-    data = [np.matmul(dct_filters, x) for x in np.split(data, data.shape[1], axis=1)]
-    data = np.array(data, order="F").squeeze(2).astype(np.float32)
-    return torch.from_numpy(data)  # shape: (frames, dct_coeffs)

 class DatasetType(Enum):
     TRAIN = 0
     DEV = 1
@@ -222,7 +216,7 @@ def preprocess(self, example, silence=False):
         if random.random() < self.noise_prob or silence:
             a = random.random() * 0.1
             data = np.clip(a * bg_noise + data, -1, 1)
-        data = preprocess_audio(data, self.n_mels, self.filters)
+        data = torch.from_numpy(preprocess_audio(data, self.n_mels, self.filters))
         self._audio_cache[example] = data
         return data

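For completeness, a sketch of how the ONNX file referenced in config.json might be produced from a trained SpeechModel. This is an assumption for illustration: the export call is not part of this commit, and the 12-label / 101-frame shapes simply mirror the defaults above.

import torch
from torch.autograd import Variable
import utils.model as model

config = model.find_config(model.ConfigType.CNN_TRAD_POOL2)
config["n_labels"] = 12                     # __silence__, __unknown__ + 10 commands
net = model.SpeechModel(config)
net.load("model/google-speech-dataset.pt")  # the old PyTorch checkpoint
net.eval()

dummy = Variable(torch.zeros(1, 101, 40))   # (batch, frames, dct_coeffs), assumed
torch.onnx.export(net, dummy, "model/google-speech-dataset-full.onnx")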
