Skip to content

Commit

Permalink
Add source and destination nodes to routing::Score
Browse files Browse the repository at this point in the history
Expand routing::Score::channel_penalty_msat to include the source and
destination node ids of the channel. This allows scores to avoid certain
nodes altogether if desired.
  • Loading branch information
jkczyz committed Oct 18, 2021
1 parent 57f0822 commit d8f0266
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 6 deletions.
6 changes: 5 additions & 1 deletion lightning/src/routing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@ pub mod network_graph;
pub mod router;
pub mod scorer;

use routing::network_graph::NodeId;

/// An interface used to score payment channels for path finding.
///
/// Scoring is in terms of fees willing to be paid in order to avoid routing through a channel.
pub trait Score {
/// Returns the fee in msats willing to be paid to avoid routing through the given channel.
fn channel_penalty_msat(&self, short_channel_id: u64) -> u64;
fn channel_penalty_msat(
&self, short_channel_id: u64, source: &NodeId, destination: &NodeId
) -> u64;
}
73 changes: 69 additions & 4 deletions lightning/src/routing/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ where L::Target: Logger {
}

let path_penalty_msat = $next_hops_path_penalty_msat
.checked_add(scorer.channel_penalty_msat($chan_id.clone()))
.checked_add(scorer.channel_penalty_msat($chan_id.clone(), &$src_node_id, &$dest_node_id))
.unwrap_or_else(|| u64::max_value());
let new_graph_node = RouteGraphNode {
node_id: $src_node_id,
Expand Down Expand Up @@ -973,15 +973,17 @@ where L::Target: Logger {
_ => aggregate_next_hops_fee_msat.checked_add(999).unwrap_or(u64::max_value())
}) { Some( val / 1000 ) } else { break; }; // converting from msat or breaking if max ~ infinity

let src_node_id = NodeId::from_pubkey(&hop.src_node_id);
let dest_node_id = NodeId::from_pubkey(&prev_hop_id);
aggregate_next_hops_path_penalty_msat = aggregate_next_hops_path_penalty_msat
.checked_add(scorer.channel_penalty_msat(hop.short_channel_id))
.checked_add(scorer.channel_penalty_msat(hop.short_channel_id, &src_node_id, &dest_node_id))
.unwrap_or_else(|| u64::max_value());

// We assume that the recipient only included route hints for routes which had
// sufficient value to route `final_value_msat`. Note that in the case of "0-value"
// invoices where the invoice does not specify value this may not be the case, but
// better to include the hints than not.
if !add_entry!(hop.short_channel_id, NodeId::from_pubkey(&hop.src_node_id), NodeId::from_pubkey(&prev_hop_id), directional_info, reqd_channel_cap, &empty_channel_features, aggregate_next_hops_fee_msat, path_value_msat, aggregate_next_hops_path_htlc_minimum_msat, aggregate_next_hops_path_penalty_msat) {
if !add_entry!(hop.short_channel_id, src_node_id, dest_node_id, directional_info, reqd_channel_cap, &empty_channel_features, aggregate_next_hops_fee_msat, path_value_msat, aggregate_next_hops_path_htlc_minimum_msat, aggregate_next_hops_path_penalty_msat) {
// If this hop was not used then there is no use checking the preceding hops
// in the RouteHint. We can break by just searching for a direct channel between
// last checked hop and first_hop_targets
Expand Down Expand Up @@ -1322,7 +1324,8 @@ where L::Target: Logger {

#[cfg(test)]
mod tests {
use routing::network_graph::{NetworkGraph, NetGraphMsgHandler};
use routing;
use routing::network_graph::{NetworkGraph, NetGraphMsgHandler, NodeId};
use routing::router::{get_route, Route, RouteHint, RouteHintHop, RouteHop, RoutingFees};
use routing::scorer::Scorer;
use chain::transaction::OutPoint;
Expand Down Expand Up @@ -4377,6 +4380,68 @@ mod tests {
assert_eq!(path, vec![2, 4, 7, 10]);
}

struct BadChannelScorer {
short_channel_id: u64,
}

impl routing::Score for BadChannelScorer {
fn channel_penalty_msat(&self, short_channel_id: u64, _source: &NodeId, _destination: &NodeId) -> u64 {
if short_channel_id == self.short_channel_id { u64::max_value() } else { 0 }
}
}

struct BadNodeScorer {
node_id: NodeId,
}

impl routing::Score for BadNodeScorer {
fn channel_penalty_msat(&self, _short_channel_id: u64, _source: &NodeId, destination: &NodeId) -> u64 {
if *destination == self.node_id { u64::max_value() } else { 0 }
}
}

#[test]
fn avoids_routing_through_bad_channels_and_nodes() {
let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
let (_, our_id, _, nodes) = get_nodes(&secp_ctx);

// A path to nodes[6] exists when no penalties are applied to any channel.
let scorer = Scorer::new(0);
let route = get_route(
&our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None,
&last_hops(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger), &scorer
).unwrap();
let path = route.paths[0].iter().map(|hop| hop.short_channel_id).collect::<Vec<_>>();

assert_eq!(route.get_total_fees(), 100);
assert_eq!(route.get_total_amount(), 100);
assert_eq!(path, vec![2, 4, 6, 11, 8]);

// A different path to nodes[6] exists if channel 6 cannot be routed over.
let scorer = BadChannelScorer { short_channel_id: 6 };
let route = get_route(
&our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None,
&last_hops(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger), &scorer
).unwrap();
let path = route.paths[0].iter().map(|hop| hop.short_channel_id).collect::<Vec<_>>();

assert_eq!(route.get_total_fees(), 300);
assert_eq!(route.get_total_amount(), 100);
assert_eq!(path, vec![2, 4, 7, 10]);

// A path to nodes[6] does not exist if nodes[2] cannot be routed through.
let scorer = BadNodeScorer { node_id: NodeId::from_pubkey(&nodes[2]) };
match get_route(
&our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None,
&last_hops(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger), &scorer
) {
Err(LightningError { err, .. } ) => {
assert_eq!(err, "Failed to find a path to the given destination");
},
Ok(_) => panic!("Expected error"),
}
}

#[test]
fn total_fees_single_path() {
let route = Route {
Expand Down
8 changes: 7 additions & 1 deletion lightning/src/routing/scorer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@

use routing;

use routing::network_graph::NodeId;

/// [`routing::Score`] implementation that provides reasonable default behavior.
///
/// Used to apply a fixed penalty to each channel, thus avoiding long paths when shorter paths with
Expand Down Expand Up @@ -71,5 +73,9 @@ impl Default for Scorer {
}

impl routing::Score for Scorer {
fn channel_penalty_msat(&self, _short_channel_id: u64) -> u64 { self.base_penalty_msat }
fn channel_penalty_msat(
&self, _short_channel_id: u64, _source: &NodeId, _destination: &NodeId
) -> u64 {
self.base_penalty_msat
}
}

0 comments on commit d8f0266

Please sign in to comment.