Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Swift API for keyword spotting. #1027

Merged
merged 1 commit into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/scripts/test-swift.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ echo "pwd: $PWD"
cd swift-api-examples
ls -lh

# Run the keyword-spotting example, then delete the example binary and any
# sherpa-onnx-kws-* entries left in this directory.
./run-keyword-spotting-from-file.sh
rm ./keyword-spotting-from-file
rm -rf sherpa-onnx-kws-*

./run-streaming-hlg-decode-file.sh
rm ./streaming-hlg-decode-file
rm -rf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18
Expand Down
1 change: 1 addition & 0 deletions swift-api-examples/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ sherpa-onnx-paraformer-zh-2023-09-14
!*.sh
*.bak
streaming-hlg-decode-file
keyword-spotting-from-file
108 changes: 108 additions & 0 deletions swift-api-examples/SherpaOnnx.swift
Original file line number Diff line number Diff line change
Expand Up @@ -832,3 +832,111 @@ class SherpaOnnxSpokenLanguageIdentificationWrapper {
return SherpaOnnxSpokenLanguageIdentificationResultWrapper(result: result)
}
}

// keyword spotting

/// Swift-side owner of a `SherpaOnnxKeywordResult` obtained from the C API.
///
/// The wrapped pointer is handed to `DestroyKeywordResult` when this object is
/// deallocated. `init` accepts a nil pointer, so every accessor tolerates a nil
/// `result` (and nil fields inside it) and falls back to an empty value instead
/// of trapping.
class SherpaOnnxKeywordResultWrapper {
  /// A pointer to the underlying counterpart in C
  let result: UnsafePointer<SherpaOnnxKeywordResult>!

  /// The `keyword` C string converted to a Swift `String`.
  /// Returns "" when either the result or its `keyword` field is nil.
  var keyword: String {
    guard let keywordPointer = result?.pointee.keyword else {
      return ""
    }
    return String(cString: keywordPointer)
  }

  /// The `count` field of the C result (0 when there is no result).
  /// `tokens` uses it as the number of entries to read from `tokens_arr`.
  var count: Int32 {
    return result?.pointee.count ?? 0
  }

  /// The entries of `tokens_arr` converted to Swift strings.
  /// Nil entries are skipped; a missing result or array yields `[]`.
  var tokens: [String] {
    guard let tokensPointer = result?.pointee.tokens_arr else {
      return []
    }

    // A negative count would make `0..<count` trap, so clamp it to zero.
    let tokenCount = Int(max(count, 0))
    return (0..<tokenCount).compactMap { index -> String? in
      guard let tokenPointer = tokensPointer[index] else {
        return nil
      }
      return String(cString: tokenPointer)
    }
  }

  /// - Parameter result: Pointer to a C keyword result. This wrapper passes it
  ///   to `DestroyKeywordResult` on deallocation, so the caller must not free it.
  init(result: UnsafePointer<SherpaOnnxKeywordResult>!) {
    self.result = result
  }

  deinit {
    if let result {
      DestroyKeywordResult(result)
    }
  }
}

/// Assembles the C `SherpaOnnxKeywordSpotterConfig` struct from Swift values.
///
/// The integer arguments are narrowed to `Int32` and `keywordsFile` is
/// converted with `toCPointer` before being stored in the struct.
///
/// - Parameters:
///   - featConfig: Stored as `feat_config`.
///   - modelConfig: Stored as `model_config`.
///   - keywordsFile: Path stored as `keywords_file`.
///   - maxActivePaths: Stored as `max_active_paths`. Defaults to 4.
///   - numTrailingBlanks: Stored as `num_trailing_blanks`. Defaults to 1.
///   - keywordsScore: Stored as `keywords_score`. Defaults to 1.0.
///   - keywordsThreshold: Stored as `keywords_threshold`. Defaults to 0.25.
/// - Returns: The populated configuration struct.
func sherpaOnnxKeywordSpotterConfig(
  featConfig: SherpaOnnxFeatureConfig,
  modelConfig: SherpaOnnxOnlineModelConfig,
  keywordsFile: String,
  maxActivePaths: Int = 4,
  numTrailingBlanks: Int = 1,
  keywordsScore: Float = 1.0,
  keywordsThreshold: Float = 0.25
) -> SherpaOnnxKeywordSpotterConfig {
  let activePaths = Int32(maxActivePaths)
  let trailingBlanks = Int32(numTrailingBlanks)
  let keywordsFilePointer = toCPointer(keywordsFile)

  let spotterConfig = SherpaOnnxKeywordSpotterConfig(
    feat_config: featConfig,
    model_config: modelConfig,
    max_active_paths: activePaths,
    num_trailing_blanks: trailingBlanks,
    keywords_score: keywordsScore,
    keywords_threshold: keywordsThreshold,
    keywords_file: keywordsFilePointer
  )
  return spotterConfig
}

/// Pairs a keyword spotter created through the C API with one stream created
/// from it, and releases both when the wrapper is deallocated.
class SherpaOnnxKeywordSpotterWrapper {
  /// A pointer to the underlying counterpart in C
  let spotter: OpaquePointer!

  /// The stream created from `spotter` in `init`; all methods operate on it.
  var stream: OpaquePointer!

  /// Creates the spotter from `config`, then a stream from that spotter.
  init(config: UnsafePointer<SherpaOnnxKeywordSpotterConfig>!) {
    self.spotter = CreateKeywordSpotter(config)
    self.stream = CreateKeywordStream(self.spotter)
  }

  deinit {
    // The stream is destroyed before the spotter it was created from.
    if let activeStream = stream {
      DestroyOnlineStream(activeStream)
    }

    if let activeSpotter = spotter {
      DestroyKeywordSpotter(activeSpotter)
    }
  }

  /// Passes `samples` to the stream together with their sample rate.
  ///
  /// - Parameters:
  ///   - samples: Audio samples to feed to the stream.
  ///   - sampleRate: Sample rate reported for `samples`. Defaults to 16000.
  func acceptWaveform(samples: [Float], sampleRate: Int = 16000) {
    let rate = Int32(sampleRate)
    let sampleCount = Int32(samples.count)
    AcceptWaveform(stream, rate, samples, sampleCount)
  }

  /// Returns `true` when `IsKeywordStreamReady` reports 1 for the stream.
  func isReady() -> Bool {
    return IsKeywordStreamReady(spotter, stream) == 1
  }

  /// Runs `DecodeKeywordStream` on the stream.
  func decode() {
    DecodeKeywordStream(spotter, stream)
  }

  /// Fetches the current result for the stream, wrapped so that it is
  /// released when the returned wrapper is deallocated.
  func getResult() -> SherpaOnnxKeywordResultWrapper {
    return SherpaOnnxKeywordResultWrapper(result: GetKeywordResult(spotter, stream))
  }

  /// Signal that no more audio samples will be available.
  /// After this call, you cannot call acceptWaveform() any more.
  func inputFinished() {
    InputFinished(stream)
  }
}
83 changes: 83 additions & 0 deletions swift-api-examples/keyword-spotting-from-file.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import AVFoundation

extension AudioBuffer {
  /// Copies the buffer's contents into a new `[Float]`.
  /// NOTE(review): assumes the buffer holds 32-bit float samples — confirm
  /// against the audio format of the caller.
  func array() -> [Float] {
    let samples = UnsafeBufferPointer<Float>(self)
    return [Float](samples)
  }
}

extension AVAudioPCMBuffer {
  /// Returns the samples of `mBuffers` (the first buffer in this PCM buffer's
  /// audio buffer list) as a `[Float]`.
  func array() -> [Float] {
    let firstBuffer = audioBufferList.pointee.mBuffers
    return firstBuffer.array()
  }
}

/// Runs keyword spotting on a bundled test wave file and prints each
/// detected keyword.
///
/// Any failure (unreadable file, unexpected audio format, buffer allocation)
/// stops the program with a descriptive message and a non-zero exit status.
func run() {
  let filePath = "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav"
  let encoder =
    "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.onnx"
  let decoder =
    "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.onnx"
  let joiner =
    "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.onnx"
  let tokens =
    "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt"
  let keywordsFile =
    "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt"

  let transducerConfig = sherpaOnnxOnlineTransducerModelConfig(
    encoder: encoder,
    decoder: decoder,
    joiner: joiner
  )

  let modelConfig = sherpaOnnxOnlineModelConfig(
    tokens: tokens,
    transducer: transducerConfig
  )

  let featConfig = sherpaOnnxFeatureConfig(
    sampleRate: 16000,
    featureDim: 80
  )

  // `var` because the wrapper takes the config by pointer (`&config`).
  var config = sherpaOnnxKeywordSpotterConfig(
    featConfig: featConfig,
    modelConfig: modelConfig,
    keywordsFile: keywordsFile
  )

  let spotter = SherpaOnnxKeywordSpotterWrapper(config: &config)

  let fileURL = URL(fileURLWithPath: filePath)
  let audioFile: AVAudioFile
  do {
    audioFile = try AVAudioFile(forReading: fileURL)
  } catch {
    fatalError("Failed to open \(filePath): \(error)")
  }

  // The samples are fed to the spotter as 16 kHz mono Float32. `precondition`
  // (unlike `assert`) keeps these checks in optimized builds.
  let audioFormat = audioFile.processingFormat
  precondition(audioFormat.sampleRate == 16000, "Expected a 16 kHz wave file")
  precondition(audioFormat.channelCount == 1, "Expected a single-channel wave file")
  precondition(
    audioFormat.commonFormat == .pcmFormatFloat32,
    "Expected Float32 as the processing format")

  let audioFrameCount = AVAudioFrameCount(audioFile.length)
  guard
    let audioFileBuffer = AVAudioPCMBuffer(
      pcmFormat: audioFormat, frameCapacity: audioFrameCount)
  else {
    fatalError("Failed to allocate a buffer of \(audioFrameCount) frames for \(filePath)")
  }

  do {
    try audioFile.read(into: audioFileBuffer)
  } catch {
    fatalError("Failed to read \(filePath): \(error)")
  }
  spotter.acceptWaveform(samples: audioFileBuffer.array())

  // 3200 zero samples = 0.2 s of silence at 16 kHz, appended after the audio.
  let tailPadding = [Float](repeating: 0.0, count: 3200)
  spotter.acceptWaveform(samples: tailPadding)

  spotter.inputFinished()
  while spotter.isReady() {
    spotter.decode()
    let keyword = spotter.getResult().keyword
    if !keyword.isEmpty {
      print("Detected: \(keyword)")
    }
  }
}

/// Program entry point: `@main` makes `App.main()` the start of the
/// executable, and `main()` simply calls `run()`.
@main
struct App {
static func main() {
run()
}
}
34 changes: 34 additions & 0 deletions swift-api-examples/run-keyword-spotting-from-file.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#!/usr/bin/env bash
#
# Builds (when the binary does not exist yet) and runs the Swift
# keyword-spotting example against the wenetspeech KWS model.

# -e: stop at the first failing command; -x: print each command as it runs.
set -ex

# The headers and libraries used below live under ../build-swift-macos.
if [ ! -d ../build-swift-macos ]; then
echo "Please run ../build-swift-macos.sh first!"
exit 1
fi

# Download and unpack the model only when its directory is missing; the
# archive is deleted after extraction.
if [ ! -d ./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01 ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz2
tar xf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz2
rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz2
fi

# Compile the example together with the Swift wrapper unless a binary from a
# previous run is already present.
if [ ! -e ./keyword-spotting-from-file ]; then
# Note: We use -lc++ to link against libc++ instead of libstdc++
swiftc \
-lc++ \
-I ../build-swift-macos/install/include \
-import-objc-header ./SherpaOnnx-Bridging-Header.h \
./keyword-spotting-from-file.swift ./SherpaOnnx.swift \
-L ../build-swift-macos/install/lib/ \
-l sherpa-onnx \
-l onnxruntime \
-o keyword-spotting-from-file

strip keyword-spotting-from-file
else
echo "./keyword-spotting-from-file exists - skip building"
fi

# Put the install lib directory on the dynamic loader's search path, then run
# the example.
export DYLD_LIBRARY_PATH=$PWD/../build-swift-macos/install/lib:$DYLD_LIBRARY_PATH
./keyword-spotting-from-file