Skip to content

Commit

Permalink
[DAP/Whisper] Add log information for whisper preprocess.
Browse files Browse the repository at this point in the history
  • Loading branch information
taiqzheng committed Jul 2, 2024
1 parent 566ce84 commit c622a8f
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 137 deletions.
72 changes: 69 additions & 3 deletions examples/BuddyWhisper/whisper-main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,72 @@

#include "whisper-main.h"

// -----------------------------------------------------------------------------
// Helper Functions
// -----------------------------------------------------------------------------

/// Print [Log] label in bold blue format.
void printLogLabel() { std::cout << "\033[34;1m[Log] \033[0m"; }

/// Print information for each iteration.
void printIterInfo(size_t iterIdx, std::string str, double time) {
std::cout << "\033[32;1m[Iteration " << iterIdx << "] \033[0m";
std::cout << "Token: " << str << " | "
<< "Time: " << time << "s" << std::endl;
}

/// Load parameters into data container.
void loadParameters(const std::string &paramFilePath,
MemRef<float, 1> &params) {
const auto loadStart = std::chrono::high_resolution_clock::now();
std::ifstream paramFile(paramFilePath, std::ios::in | std::ios::binary);
if (!paramFile.is_open()) {
throw std::runtime_error("[Error] Failed to open params file!");
}
printLogLabel();
std::cout << "Loading params..." << std::endl;
printLogLabel();
std::cout << "Params file: " << std::filesystem::canonical(paramFilePath)
<< std::endl;
paramFile.read(reinterpret_cast<char *>(params.getData()),
sizeof(float) * (params.getSize()));
if (paramFile.fail()) {
throw std::runtime_error("Error occurred while reading params file!");
}
paramFile.close();
const auto loadEnd = std::chrono::high_resolution_clock::now();
const std::chrono::duration<double, std::milli> loadTime =
loadEnd - loadStart;
printLogLabel();
std::cout << "Params load time: " << (double)(loadTime.count()) / 1000
<< "s\n"
<< std::endl;
}

/// Calculate audioInput from rawAudioData.
void runPreprocess(MemRef<double, 1> &rawAudioData,
MemRef<float, 3> &audioFeatures) {
// Move data into container.
intptr_t dataShape[1] = {AudioDataLength};
rawAudioData = std::move(MemRef<double, 1>(rawSpeech, dataShape));
printLogLabel();
std::cout << "Preprocessing audio..." << std::endl;
const auto loadStart = std::chrono::high_resolution_clock::now();
dap::whisperPreprocess(&rawAudioData, &audioFeatures);
const auto loadEnd = std::chrono::high_resolution_clock::now();
const std::chrono::duration<double, std::milli> loadTime =
loadEnd - loadStart;
printLogLabel();
std::cout << "Audio preprocess time: " << (double)(loadTime.count()) / 1000
<< "s\n"
<< std::endl;
}

/// Find the index of the max value.
int findMaxIndex(const float *start, const float *end) {
return std::distance(start, std::max_element(start, end));
}

// -----------------------------------------------------------------------------
// Whisper Inference Main Entry
// -----------------------------------------------------------------------------
Expand All @@ -39,6 +105,7 @@ int main() {
// - Output container.
// - Parameters container.
Text<size_t, 2> outputContainer;
MemRef<double, 1> rawAudioContainer({AudioDataLength});
MemRef<float, 3> audioInput({1, 80, 3000});
MemRef<float, 3> resultContainer[2] = {
MemRef<float, 3>({1, 1500, 512}, false, 0),
Expand All @@ -50,11 +117,10 @@ int main() {
/// Fill data into containers
// - Output: register vocabulary.
// - Parameters: load parameters from the `arg0` file into the container.
// - Input: generate audioInput from rawAudioData.
// - Input: compute audioInput.
outputContainer.loadVocab(vocabDir);
loadParameters(paramsDir, paramsContainer);
rawAudioData = std::move(MemRef<double, 1>(rawSpeech, inputShape));
dap::whisperPreprocess(&rawAudioData, &audioInput);
runPreprocess(rawAudioContainer, audioInput);

/// Run Whisper Inference
// - Perform the forward function.
Expand Down
63 changes: 3 additions & 60 deletions examples/BuddyWhisper/whisper-main.h

Large diffs are not rendered by default.

80 changes: 80 additions & 0 deletions examples/DAPDialect/WhisperPreprocess.cpp

Large diffs are not rendered by default.

74 changes: 0 additions & 74 deletions examples/DAPDialect/whisperPreprocess.cpp

This file was deleted.

0 comments on commit c622a8f

Please sign in to comment.