Skip to content

Commit

Permalink
fixed network search i waves, looks like it works really nicely
Browse files Browse the repository at this point in the history
  • Loading branch information
lutteropp committed Dec 9, 2020
1 parent 763a2cc commit fa50575
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 11 deletions.
39 changes: 30 additions & 9 deletions src/Api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "optimization/BranchLengthOptimization.hpp"
#include "optimization/TopologyOptimization.hpp"
#include "DebugPrintFunctions.hpp"
#include "utils.hpp"

namespace netrax {

Expand Down Expand Up @@ -217,7 +218,11 @@ void NetraxInstance::updateReticulationProbs(AnnotatedNetwork &ann_network) {
netrax::optimize_reticulations(ann_network, 100);
double new_score = scoreNetwork(ann_network);
std::cout << "BIC score after updating reticulation probs: " << new_score << "\n";
assert(new_score <= old_score);
if (definitelyGreaterThan(new_score, old_score)) {
std::cout << "old score: " << old_score << "\n";
std::cout << "new score: " << new_score << "\n";
assert(new_score <= old_score);
}
}

/**
Expand All @@ -237,7 +242,11 @@ void NetraxInstance::optimizeModel(AnnotatedNetwork &ann_network) {
ann_network.raxml_treeinfo->optimize_model(ann_network.options.lh_epsilon);
double new_score = scoreNetwork(ann_network);
std::cout << "BIC score after model optimization: " << new_score << "\n";
assert(new_score <= old_score);
if (definitelyGreaterThan(new_score, old_score)) {
std::cout << "old score: " << old_score << "\n";
std::cout << "new score: " << new_score << "\n";
assert(new_score <= old_score);
}
}

/**
Expand All @@ -252,7 +261,11 @@ void NetraxInstance::optimizeBranches(AnnotatedNetwork &ann_network) {
assert(netrax::computeLoglikelihood(ann_network, 1, 1) == netrax::computeLoglikelihood(ann_network, 0, 1));
double new_score = scoreNetwork(ann_network);
std::cout << "BIC score after branch length optimization: " << new_score << "\n";
assert(new_score <= old_score);
if (definitelyGreaterThan(new_score, old_score)) {
std::cout << "old score: " << old_score << "\n";
std::cout << "new score: " << new_score << "\n";
assert(new_score <= old_score);
}
}

/**
Expand All @@ -266,7 +279,11 @@ void NetraxInstance::optimizeTopology(AnnotatedNetwork &ann_network, const std::
greedyHillClimbingTopology(ann_network, types);
double new_score = scoreNetwork(ann_network);
std::cout << "BIC after topology optimization: " << new_score << "\n";
assert(new_score <= old_score);
if (definitelyGreaterThan(new_score, old_score)) {
std::cout << "old score: " << old_score << "\n";
std::cout << "new score: " << new_score << "\n";
assert(new_score <= old_score);
}
}

/**
Expand All @@ -280,11 +297,12 @@ void NetraxInstance::optimizeTopology(AnnotatedNetwork &ann_network, MoveType& t
greedyHillClimbingTopology(ann_network, type);
double new_score = scoreNetwork(ann_network);
//std::cout << "BIC after topology optimization: " << new_score << "\n";
if (new_score > old_score) {

if (definitelyGreaterThan(new_score, old_score)) {
std::cout << "old score: " << old_score << "\n";
std::cout << "new score: " << new_score << "\n";
assert(new_score <= old_score);
}
assert(new_score <= old_score);
}

double NetraxInstance::optimizeEverythingRun(AnnotatedNetwork & ann_network, std::vector<MoveType>& typesBySpeed, const std::chrono::high_resolution_clock::time_point& start_time) {
Expand Down Expand Up @@ -359,14 +377,17 @@ void NetraxInstance::optimizeEverything(AnnotatedNetwork &ann_network) {
updateReticulationProbs(ann_network);
optimizeModel(ann_network);
double new_score = scoreNetwork(ann_network);
std::cout << "Initial optimized network loglikelihood: " << computeLoglikelihood(ann_network) << "\n";
std::cout << "Initial optimized network BIC score: " << new_score << "\n";
std::cout << "Initial optimized " << ann_network.network.num_reticulations() << "-reticulation network loglikelihood: " << computeLoglikelihood(ann_network) << "\n";
std::cout << "Initial optimized " << ann_network.network.num_reticulations() << "-reticulation network BIC score: " << new_score << "\n";
//assert(new_score <= initial_score);

//double_check_likelihood(ann_network);

optimizeEverythingRun(ann_network, typesBySpeed, start_time);

std::cout << "Best optimized " << ann_network.network.num_reticulations() << "-reticulation network loglikelihood: " << computeLoglikelihood(ann_network) << "\n";
std::cout << "Best optimized " << ann_network.network.num_reticulations() << "-reticulation network BIC score: " << new_score << "\n";

std::cout << "Statistics on which moves were taken:\n";
for (const auto& entry : ann_network.stats.moves_taken) {
std::cout << toString(entry.first) << ": " << entry.second << "\n";
Expand Down Expand Up @@ -402,7 +423,7 @@ void NetraxInstance::optimizeEverythingInWaves(AnnotatedNetwork &ann_network) {
std::cout << "Best optimized " << ann_network.network.num_reticulations() << "-reticulation network loglikelihood: " << computeLoglikelihood(ann_network) << "\n";
std::cout << "Best optimized " << ann_network.network.num_reticulations() << "-reticulation network BIC score: " << new_score << "\n";

if (new_score >= best_score) { // score did not get worse
if (new_score <= best_score) { // score did not get worse
best_score = new_score;
if (ann_network.network.num_reticulations() < ann_network.options.max_reticulations) {
seen_improvement = true;
Expand Down
57 changes: 55 additions & 2 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,59 @@ int parseOptions(int argc, char **argv, netrax::NetraxOptions *options) {
return 0;
}

void run_single_start_waves(NetraxOptions& netraxOptions, std::mt19937& rng) {
std::vector<MoveType> typesBySpeed = {MoveType::RNNIMove, MoveType::RSPR1Move, MoveType::TailMove, MoveType::HeadMove};
auto start_time = std::chrono::high_resolution_clock::now();
double best_score = std::numeric_limits<double>::infinity();
int best_num_reticulations = 0;
netrax::AnnotatedNetwork ann_network = NetraxInstance::build_annotated_network(netraxOptions);
NetraxInstance::init_annotated_network(ann_network, rng);

std::cout << "Initial network is:\n" << toExtendedNewick(ann_network) << "\n\n";
std::string best_network = toExtendedNewick(ann_network);

bool seen_improvement = true;
while (seen_improvement) {
seen_improvement = false;

double new_score = NetraxInstance::optimizeEverythingRun(ann_network, typesBySpeed, start_time);
std::cout << "Best optimized " << ann_network.network.num_reticulations() << "-reticulation network loglikelihood: " << NetraxInstance::computeLoglikelihood(ann_network) << "\n";
std::cout << "Best optimized " << ann_network.network.num_reticulations() << "-reticulation network BIC score: " << new_score << "\n";

if (new_score < best_score) {
best_score = new_score;
best_num_reticulations = ann_network.network.num_reticulations();
std::cout << "IMPROVED BEST SCORE FOUND SO FAR: " << best_score << "\n\n";
NetraxInstance::writeNetwork(ann_network, netraxOptions.output_file);
best_network = toExtendedNewick(ann_network);
std::cout << "Better network written to " << netraxOptions.output_file << "\n";
} else {
std::cout << "REMAINED BEST SCORE FOUND SO FAR: " << best_score << "\n";
}

if (new_score <= best_score) { // score did not get worse
if (ann_network.network.num_reticulations() < ann_network.options.max_reticulations) {
seen_improvement = true;
NetraxInstance::add_extra_reticulations(ann_network, ann_network.network.num_reticulations() + 1);
NetraxInstance::optimizeBranches(ann_network);
NetraxInstance::optimizeModel(ann_network);
NetraxInstance::updateReticulationProbs(ann_network);
new_score = NetraxInstance::scoreNetwork(ann_network);
std::cout << "Initial optimized " << ann_network.network.num_reticulations() << "-reticulation network loglikelihood: " << NetraxInstance::computeLoglikelihood(ann_network) << "\n";
std::cout << "Initial optimized " << ann_network.network.num_reticulations() << "-reticulation network BIC score: " << new_score << "\n";
}
}
}

std::cout << "The inferred network has " << best_num_reticulations << " reticulations and this BIC score: " << best_score << "\n\n";
std::cout << "Best found network is:\n" << best_network << "\n\n";

std::cout << "Statistics on which moves were taken:\n";
for (const auto& entry : ann_network.stats.moves_taken) {
std::cout << toString(entry.first) << ": " << entry.second << "\n";
}
}

void run_single_start(NetraxOptions& netraxOptions, std::mt19937& rng) {
double best_score = std::numeric_limits<double>::infinity();

Expand All @@ -60,7 +113,7 @@ void run_single_start(NetraxOptions& netraxOptions, std::mt19937& rng) {
std::cout << "Initial network is:\n" << toExtendedNewick(ann_network) << "\n\n";
std::string best_network = toExtendedNewick(ann_network);

NetraxInstance::optimizeEverythingInWaves(ann_network);
NetraxInstance::optimizeEverything(ann_network);
double final_bic = NetraxInstance::scoreNetwork(ann_network);
std::cout << "The inferred network has " << ann_network.network.num_reticulations() << " reticulations and this BIC score: " << final_bic << "\n\n";
if (final_bic < best_score) {
Expand Down Expand Up @@ -271,7 +324,7 @@ int main(int argc, char **argv) {
}

if (!netraxOptions.start_network_file.empty()) {
run_single_start(netraxOptions, rng);
run_single_start_waves(netraxOptions, rng);
} else {
run_random(netraxOptions, rng);
}
Expand Down

0 comments on commit fa50575

Please sign in to comment.