-
-
Notifications
You must be signed in to change notification settings - Fork 1.6k
/
emst_main.cpp
131 lines (108 loc) · 3.89 KB
/
emst_main.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
/**
* @file emst_main.cpp
* @author Bill March (march@gatech.edu)
*
* Calls the DualTreeBoruvka algorithm from dtb.hpp.
* Can optionally call naive Boruvka's method.
*
* For algorithm details, see:
*
* @code
* @inproceedings{
* author = {March, W.B., Ram, P., and Gray, A.G.},
* title = {{Fast Euclidean Minimum Spanning Tree: Algorithm, Analysis,
* Applications.}},
* booktitle = {Proceedings of the 16th ACM SIGKDD International Conference
* on Knowledge Discovery and Data Mining},
* series = {KDD 2010},
* year = {2010}
* }
* @endcode
*/
#include "dtb.hpp"
#include <mlpack/core.hpp>
// Program name and long description handed to mlpack's PROGRAM_INFO macro.
// NOTE(review): presumably this text is what the generated --help output
// shows; confirm against the macro definition in mlpack/core.hpp.
PROGRAM_INFO("Fast Euclidean Minimum Spanning Tree", "This program can compute "
"the Euclidean minimum spanning tree of a set of input points using the "
"dual-tree Boruvka algorithm."
"\n\n"
"The output is saved in a three-column matrix, where each row indicates an "
"edge. The first column corresponds to the lesser index of the edge; the "
"second column corresponds to the greater index of the edge; and the third "
"column corresponds to the distance between the two points.");
// Command-line parameters.  main() reads each of these back by name through
// CLI::GetParam<>() / CLI::HasParam().
// "input_file": path of the dataset that main() loads (declared as required).
PARAM_STRING_REQ("input_file", "Data input file.", "i");
// "output_file": where the edge list is saved; main() only saves when this
// parameter is present.  The trailing "" is presumably the default value.
PARAM_STRING("output_file", "Data output file. Stored as an edge list.",
"o", "");
// "naive": when set, main() runs the naive computation instead of building a
// kd-tree.
PARAM_FLAG("naive", "Compute the MST using O(n^2) naive algorithm.", "n");
// "leaf_size": kd-tree leaf size; main() rejects values <= 0.  The trailing 1
// is presumably the default value.
PARAM_INT("leaf_size", "Leaf size in the kd-tree. One-element leaves give the "
"empirically best performance, but at the cost of greater memory "
"requirements.", "l", 1);
using namespace mlpack;
using namespace mlpack::emst;
using namespace mlpack::tree;
using namespace mlpack::metric;
using namespace std;
int main(int argc, char* argv[])
{
CLI::ParseCommandLine(argc, argv);
const string inputFile = CLI::GetParam<string>("input_file");
const string outputFile= CLI::GetParam<string>("output_file");
if (CLI::HasParam("output_file"))
Log::Warn << "--output_file (-o) is not specified;"
<< "no results will be saved!" << endl;
arma::mat dataPoints;
data::Load(inputFile, dataPoints, true);
// Do naive computation if necessary.
if (CLI::GetParam<bool>("naive"))
{
Log::Info << "Running naive algorithm." << endl;
DualTreeBoruvka<> naive(dataPoints, true);
arma::mat naiveResults;
naive.ComputeMST(naiveResults);
if (CLI::HasParam("output_file"))
data::Save(outputFile, naiveResults, true);
}
else
{
Log::Info << "Building tree.\n";
// Check that the leaf size is reasonable.
if (CLI::GetParam<int>("leaf_size") <= 0)
{
Log::Fatal << "Invalid leaf size (" << CLI::GetParam<int>("leaf_size")
<< ")! Must be greater than or equal to 1." << std::endl;
}
// Initialize the tree and get ready to compute the MST. Compute the tree
// by hand.
const size_t leafSize = (size_t) CLI::GetParam<int>("leaf_size");
Timer::Start("tree_building");
std::vector<size_t> oldFromNew;
KDTree<EuclideanDistance, DTBStat, arma::mat> tree(dataPoints, oldFromNew,
leafSize);
metric::LMetric<2, true> metric;
Timer::Stop("tree_building");
DualTreeBoruvka<> dtb(&tree, metric);
// Run the DTB algorithm.
Log::Info << "Calculating minimum spanning tree." << endl;
arma::mat results;
dtb.ComputeMST(results);
// Unmap the results.
arma::mat unmappedResults(results.n_rows, results.n_cols);
for (size_t i = 0; i < results.n_cols; ++i)
{
const size_t indexA = oldFromNew[size_t(results(0, i))];
const size_t indexB = oldFromNew[size_t(results(1, i))];
if (indexA < indexB)
{
unmappedResults(0, i) = indexA;
unmappedResults(1, i) = indexB;
}
else
{
unmappedResults(0, i) = indexB;
unmappedResults(1, i) = indexA;
}
unmappedResults(2, i) = results(2, i);
}
if (CLI::HasParam("output_file"))
data::Save(outputFile, unmappedResults, true);
}
}