Skip to content

Commit

Permalink
all_road_paths (#5)
Browse files Browse the repository at this point in the history
* init compile_commands.json config

* okay

* fix

* fix test

---------

Co-authored-by: TANG ZHIXIONG <zhixiong.tang@momenta.ai>
  • Loading branch information
district10 and zhixiong-tang committed May 23, 2024
1 parent dcc1903 commit 225ae16
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 2 deletions.
14 changes: 14 additions & 0 deletions .vscode/c_cpp_properties.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{
"configurations": [
{
"name": "Linux",
"defines": [],
"compilerPath": "/usr/bin/clang",
"cStandard": "c11",
"cppStandard": "c++17",
"intelliSenseMode": "clang-x64",
"compileCommands": "${workspaceFolder}/build/compile_commands.json"
}
],
"version": 4
}
9 changes: 9 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{"name":"Python: Current File","type":"python","request":"launch","program":"${file}","console":"integratedTerminal"},
]
}
5 changes: 5 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"files.associations": {
"optional": "cpp"
}
}
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "scikit_build_core.build"

[project]
name = "fast_viterbi"
version = "0.1.0"
version = "0.1.1"
description="a viterbi algo collection"
readme = "README.md"
authors = [
Expand Down
47 changes: 47 additions & 0 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ struct Seq {
}
};

using Roads = std::unordered_set<std::vector<int64_t>, hash_vector<std::vector<int64_t>>>;

namespace std {
template <>
struct hash<Seq> {
Expand Down Expand Up @@ -209,6 +211,50 @@ struct FastViterbi {
return true;
}

std::vector<std::vector<int64_t>> all_road_paths() const {
if (roads_.empty() || sp_paths_.empty()) {
return {};
}
std::vector<Roads> prev_paths(K_);
for (auto &pair : heads_) {
auto cidx = pair.first;
auto nidx = roads_[0][cidx];
prev_paths[cidx].insert({nidx});
}
for (int n = 0; n < N_ - 1; ++n) {
std::vector<Roads> curr_paths(K_);
auto &paths = sp_paths_.at(n);
auto &layer = links_[n];
for (int i = 0; i < K_; ++i) {
const auto &heads = prev_paths[i];
if (heads.empty() || layer[i].empty()) {
continue;
}
auto &p = paths.at(i);
for (auto &pair : layer[i]) {
int j = pair.first;
const auto &sig = p.at(j);
if (sig.size() == 1) {
curr_paths[j].insert(heads.begin(), heads.end());
continue;
}
for (auto copy : heads) {
copy.insert(copy.end(), sig.begin() + 1, sig.end());
curr_paths[j].insert(std::move(copy));
}
}
}
prev_paths = std::move(curr_paths);
}
Roads ret;
for (auto seqs : prev_paths) {
for (auto &seq : seqs) {
ret.insert(seq);
}
}
return {ret.begin(), ret.end()};
}

std::tuple<double, std::vector<int>, std::vector<int64_t>> inference(const std::vector<int64_t> &road_path) const {
if (roads_.empty() || sp_paths_.empty()) {
return std::make_tuple(pos_inf, std::vector<int>{}, std::vector<int64_t>{});
Expand Down Expand Up @@ -361,6 +407,7 @@ PYBIND11_MODULE(_core, m) {
.def("setup_roads", &FastViterbi::setup_roads, "roads"_a)
.def("setup_shortest_road_paths", &FastViterbi::setup_shortest_road_paths, "sp_paths"_a)
//
.def("all_road_paths", &FastViterbi::all_road_paths)
.def("inference", py::overload_cast<const std::vector<int64_t> &>(&FastViterbi::inference, py::const_),
"road_path"_a, py::call_guard<py::gil_scoped_release>())

Expand Down
2 changes: 1 addition & 1 deletion tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def test_version():
assert m.__version__ == "0.1.0"
assert m.__version__ == "0.1.1"


def test_add():
Expand Down

0 comments on commit 225ae16

Please sign in to comment.