Skip to content

Commit

Permalink
keytap3 : multi-thread wasm version
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed May 1, 2022
1 parent c139e16 commit cac8409
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 113 deletions.
3 changes: 1 addition & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,9 @@ if (EMSCRIPTEN)
set(CMAKE_EXE_LINKER_FLAGS " \
--bind \
--use-preload-cache \
--closure 1 \
-s ASSERTIONS=1 \
-s NO_EXIT_RUNTIME=0 \
-s PTHREAD_POOL_SIZE=8 \
-s PTHREAD_POOL_SIZE=16 \
-s INITIAL_MEMORY=536870912 \
")
elseif(MINGW)
Expand Down
5 changes: 4 additions & 1 deletion audio-logger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,12 @@ bool AudioLogger::install(Parameters && parameters) {
return false;
}

if (SDL_Init(SDL_INIT_AUDIO) < 0) {
static bool isInitialized = false;
if (!isInitialized && SDL_Init(SDL_INIT_AUDIO) < 0) {
SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
return false;
}
isInitialized = true;

int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
printf("Found %d capture devices:\n", nDevices);
Expand Down Expand Up @@ -134,6 +136,7 @@ bool AudioLogger::install(Parameters && parameters) {
}

printf("Opened capture device succesfully!\n");
printf(" DeviceId: %d\n", data.deviceIdIn);
printf(" Frequency: %d\n", obtainedSpec.freq);
printf(" Format: %d (%d bytes)\n", obtainedSpec.format, data.sampleSize_bytes);
printf(" Channels: %d\n", obtainedSpec.channels);
Expand Down
47 changes: 5 additions & 42 deletions common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#include <fstream>
#include <deque>
#include <algorithm>
#include <condition_variable>

#ifndef pi
#define pi 3.1415926535897932384626433832795
Expand Down Expand Up @@ -511,21 +510,7 @@ std::tuple<TValueCC, TOffset> findBestCC(
auto sum02 = std::get<1>(ret);

#ifdef __EMSCRIPTEN__
TOffset cbesto = -1;
TValueCC cbestcc = -1.0f;

for (int o = -alignWindow; o <= alignWindow; ++o) {
auto cc = calcCC(waveform0, waveform1, sum0, sum02, is00, is0 + o, is1 + o);
if (cc > cbestcc) {
cbesto = o;
cbestcc = cc;
}
}

if (cbestcc > bestcc) {
bestcc = cbestcc;
besto = cbesto;
}
int nWorkers = std::min(4, std::max(1, int(std::thread::hardware_concurrency()) - 2));
#else
int nWorkers = std::min(4u, std::thread::hardware_concurrency());
std::mutex mutex;
Expand Down Expand Up @@ -621,15 +606,12 @@ bool calculateSimilartyMap(
res.resize(nPresses);
for (auto & x : res) x.resize(nPresses);

int nFinished = 0;
#ifdef __EMSCRIPTEN__
int nWorkers = std::max(1, std::min(4, int(std::thread::hardware_concurrency()) - 4));
int nWorkers = std::min(kMaxThreads, std::max(1, int(std::thread::hardware_concurrency()) - 2));
#else
int nWorkers = std::thread::hardware_concurrency();
#endif

std::mutex mutex;
std::condition_variable cv;
std::vector<std::thread> workers(nWorkers);
for (int iw = 0; iw < (int) workers.size(); ++iw) {
auto & worker = workers[iw];
Expand Down Expand Up @@ -668,18 +650,10 @@ bool calculateSimilartyMap(
}
avgcc /= (nPresses - 1);
}

{
std::lock_guard<std::mutex> lock(mutex);
++nFinished;
cv.notify_one();
}
}, iw);
worker.detach();
}

std::unique_lock<std::mutex> lock(mutex);
cv.wait(lock, [&]() { return nFinished == nWorkers; });
for (auto & worker : workers) worker.join();

return true;
}
Expand All @@ -700,15 +674,12 @@ bool calculateSimilartyMap(
res.resize(nPresses);
for (auto & x : res) x.resize(nPresses);

int nFinished = 0;
#ifdef __EMSCRIPTEN__
int nWorkers = std::max(1, std::min(4, int(std::thread::hardware_concurrency()) - 4));
int nWorkers = std::min(kMaxThreads, std::max(1, int(std::thread::hardware_concurrency()) - 2));
#else
int nWorkers = std::thread::hardware_concurrency();
#endif

std::mutex mutex;
std::condition_variable cv;
std::vector<std::thread> workers(nWorkers);
for (int iw = 0; iw < (int) workers.size(); ++iw) {
auto & worker = workers[iw];
Expand Down Expand Up @@ -747,18 +718,10 @@ bool calculateSimilartyMap(
}
avgcc /= (nPresses - 1);
}

{
std::lock_guard<std::mutex> lock(mutex);
++nFinished;
cv.notify_one();
}
}, iw);
worker.detach();
}

std::unique_lock<std::mutex> lock(mutex);
cv.wait(lock, [&]() { return nFinished == nWorkers; });
for (auto & worker : workers) worker.join();

return true;
}
Expand Down
3 changes: 2 additions & 1 deletion constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
#include <cstdint>

#ifdef __EMSCRIPTEN__
static constexpr int64_t kSamplesPerFrame = 2048;
static constexpr int32_t kMaxThreads = 8;
static constexpr int64_t kSamplesPerFrame = 1024;
#else
static constexpr int64_t kSamplesPerFrame = 512;
#endif
Expand Down
155 changes: 90 additions & 65 deletions keytap3-app.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ struct StateRecording {
std::atomic_bool interrupt = false;
std::atomic_bool doRecord = false;
std::atomic_bool doneRecording = false;
std::atomic_bool doneCutoffSearch = false;
size_t totalSize_bytes = 0;

std::string pathOutput = "record.kbd";
Expand All @@ -71,7 +72,7 @@ struct StateRecording {
AudioLogger::Callback cbAudio;

bool init() {
if (isStarted && !doneRecording) {
if (isStarted) {
return false;
}

Expand All @@ -80,6 +81,7 @@ struct StateRecording {
interrupt = false;
doRecord = true;
doneRecording = false;
doneCutoffSearch = false;
totalSize_bytes = 0;

waveformF.clear();
Expand All @@ -97,7 +99,7 @@ struct StateRecording {

if (interrupt) {
doneRecording = true;
printf("\n[!] Recoding interrupted\n");
printf("\n[!] Recording interrupted\n");
return;
}
if (nKeysToCapture <= nKeysHave && !doneRecording) {
Expand All @@ -110,6 +112,7 @@ struct StateRecording {

const auto tEnd = std::chrono::high_resolution_clock::now();
const auto tDiff = std::chrono::duration_cast<std::chrono::milliseconds>(tEnd - tStart).count();

if (tDiff > 3) {
printf("[!] Audio callback took %ld ms\n", (long) tDiff);
}
Expand Down Expand Up @@ -299,7 +302,6 @@ bool AppInterface::init(State & state) {

printf("[+] Starting recording. nKeysToCapture: %d\n", state.recording.nKeysToCapture);
state.state = State::Recording;
state.recording.init();
}
} else if (cmd == "stop") {
if (state.state == State::Recording) {
Expand Down Expand Up @@ -354,6 +356,7 @@ bool AppInterface::init(State & state) {
} break;
case State::Recording:
{
state.recording.init();
state.recording.update();

if (state.worker.joinable() == false) {
Expand All @@ -373,15 +376,19 @@ bool AppInterface::init(State & state) {
const auto tEnd = std::chrono::high_resolution_clock::now();

printf("[+] Recording took %4.3f seconds\n", toSeconds(tStart, tEnd));

printf("[+] Finding best cutoff frequency ..\n");
state.recording.updateWorker(state.dataOutput, -1.0f);

state.recording.doneCutoffSearch = true;
});
}

if (state.recording.doneRecording) {
if (state.recording.doneCutoffSearch) {
state.recording.audioLogger.terminate();
state.recording.isStarted = false;
state.worker.join();

state.recording.updateWorker(state.dataOutput, -1.0f);

// write record.kbd
{
std::ofstream fout(state.recording.pathOutput, std::ios::binary);
Expand Down Expand Up @@ -488,86 +495,104 @@ bool AppInterface::init(State & state) {

n = keyPresses.size();

const int ncc = std::min(32, n);
for (int j = 0; j < ncc; ++j) {
printf("%2d: ", j);
for (int i = 0; i < ncc; ++i) {
printf("%6.3f ", similarityMap[j][i].cc);
if (n > 0) {
const int ncc = std::min(32, n);
for (int j = 0; j < ncc; ++j) {
printf("%2d: ", j);
for (int i = 0; i < ncc; ++i) {
printf("%6.3f ", similarityMap[j][i].cc);
}
printf("\n");
}
printf("\n");
}
printf("\n");

auto minCC = similarityMap[0][1].cc;
auto maxCC = similarityMap[0][1].cc;
for (int j = 0; j < n - 1; ++j) {
for (int i = j + 1; i < n; ++i) {
minCC = std::min(minCC, similarityMap[j][i].cc);
maxCC = std::max(maxCC, similarityMap[j][i].cc);

auto minCC = similarityMap[0][1].cc;
auto maxCC = similarityMap[0][1].cc;
for (int j = 0; j < n - 1; ++j) {
for (int i = j + 1; i < n; ++i) {
minCC = std::min(minCC, similarityMap[j][i].cc);
maxCC = std::max(maxCC, similarityMap[j][i].cc);
}
}
}

printf("[+] Similarity map: min = %g, max = %g\n", minCC, maxCC);
printf("[+] Similarity map: min = %g, max = %g\n", minCC, maxCC);
}
}
{
Cipher::Processor processor;

Cipher::TParameters params;
params.maxClusters = 29;
params.wEnglishFreq = 20.0;
params.nHypothesesToKeep = std::max(100, 2100 - 10*std::min(200, std::max(0, ((int) keyPresses.size() - 100))));
processor.init(params, state.decoding.freqMap6, similarityMap);
if (n > 0) {
const int nThread = std::min(8, std::max(1, int(std::thread::hardware_concurrency()) - 2));

printf("[+] Attempting to recover the text from the recording. nHypothesesToKeep = %d\n", params.nHypothesesToKeep);
printf("[+] Attempting to recover the text from the recording, nThreads = %d\n", nThread);

std::vector<Cipher::TResult> clusterings;
for (int iMain = 0; iMain < 16; ++iMain) {
Cipher::Processor processor;

Cipher::TParameters params;
params.maxClusters = 30;
params.wEnglishFreq = 30.0;
params.fSpread = 0.5 + 0.1*iMain;
params.nHypothesesToKeep = std::max(100, 500 - 2*std::min(200, std::max(0, ((int) keyPresses.size() - 100))));
processor.init(params, state.decoding.freqMap6, similarityMap);

// clustering
{
const auto tStart = std::chrono::high_resolution_clock::now();

for (int nIter = 0; nIter < 16; ++nIter) {
auto clusteringsCur = processor.getClusterings(32);
std::vector<Cipher::TResult> clusterings;

for (int i = 0; i < (int) clusteringsCur.size(); ++i) {
printf("[+] Clustering %d: pClusters = %g\n", i, clusteringsCur[i].pClusters);
clusterings.push_back(std::move(clusteringsCur[i]));
// clustering
{
const auto tStart = std::chrono::high_resolution_clock::now();

for (int nIter = 0; nIter < 8; ++nIter) {
auto clusteringsCur = processor.getClusterings(2);

for (int i = 0; i < (int) clusteringsCur.size(); ++i) {
clusterings.push_back(std::move(clusteringsCur[i]));
}

params.maxClusters = 30 + 8*(nIter + 1);
processor.init(params, state.decoding.freqMap6, similarityMap);
}

params.maxClusters = 29 + 4*(nIter + 1);
processor.init(params, state.decoding.freqMap6, similarityMap);
const auto tEnd = std::chrono::high_resolution_clock::now();
printf("[+] Clustering took %4.3f seconds, fSpread = %g\n", toSeconds(tStart, tEnd), params.fSpread);
}

const auto tEnd = std::chrono::high_resolution_clock::now();
printf("[+] Clustering took %4.3f seconds\n", toSeconds(tStart, tEnd));
}

params.hint.clear();
params.hint.resize(n, -1);
printf("\n[+] Recovering the unknown text:\n\n");
params.hint.clear();
params.hint.resize(n, -1);

[[maybe_unused]] int nConverged = 0;
while (true) {
// beam search
{
for (int i = 0; i < (int) clusterings.size(); ++i) {
Cipher::beamSearch(params, state.decoding.freqMap6, clusterings[i]);
printf(" ");
Cipher::printDecoded(clusterings[i].clusters, clusterings[i].clMap, params.hint);
printf(" [%8.3f %8.3f]\n", clusterings[i].p, clusterings[i].pClusters);
printf(" ");
Cipher::refineNearby(params, state.decoding.freqMap6, clusterings[i]);
printf("\n");

if (state.decoding.interrupt) {
printf("\n[!] Analysis interrupted\n");
break;
}
std::vector<std::thread> workers(std::min(nThread, (int) clusterings.size()));

std::mutex mutexPrint;
for (int i = 0; i < nThread; ++i) {
workers[i] = std::thread([&, i]() {
for (int j = i; j < (int) clusterings.size(); j += nThread) {
Cipher::beamSearch(params, state.decoding.freqMap6, clusterings[j]);
mutexPrint.lock();
printf(" ");
Cipher::printDecoded(clusterings[j].clusters, clusterings[j].clMap, params.hint);
printf(" [%8.3f %8.3f]\n", clusterings[j].p, clusterings[j].pClusters);
mutexPrint.unlock();

if (state.decoding.interrupt) {
break;
}
}
});
}
}

break;
for (auto& worker : workers) {
worker.join();
}

if (state.decoding.interrupt) {
printf("\n[!] Analysis interrupted\n");
break;
}
}
}
} else {
printf("[!] No keys found\n");
}

printf("[+] Done\n");
Expand Down
Loading

0 comments on commit cac8409

Please sign in to comment.