Skip to content

Commit

Permalink
Implement test for build_route_from_hops.
Browse files Browse the repository at this point in the history
  • Loading branch information
tnull committed May 25, 2022
1 parent 75ca50f commit f6607a2
Showing 1 changed file with 75 additions and 13 deletions.
88 changes: 75 additions & 13 deletions lightning/src/routing/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ impl Readable for Route {

/// Parameters needed to find a [`Route`].
///
/// Passed to [`find_route`] and also provided in [`Event::PaymentPathFailed`] for retrying a failed
/// payment path.
/// Passed to [`find_route`] and [`build_route_from_hops`], but also provided in
/// [`Event::PaymentPathFailed`] for retrying a failed payment path.
///
/// [`Event::PaymentPathFailed`]: crate::util::events::Event::PaymentPathFailed
#[derive(Clone, Debug)]
Expand Down Expand Up @@ -676,16 +676,11 @@ pub fn find_route<L: Deref, S: Score>(
) -> Result<Route, LightningError>
where L::Target: Logger {
let network_graph = network.read_only();
match get_route(
our_node_pubkey, &route_params.payment_params, &network_graph, first_hops, route_params.final_value_msat,
route_params.final_cltv_expiry_delta, logger, scorer, random_seed_bytes
) {
Ok(mut route) => {
add_random_cltv_offset(&mut route, &route_params.payment_params, &network_graph, random_seed_bytes);
Ok(route)
},
Err(err) => Err(err),
}
let mut route = get_route(our_node_pubkey, &route_params.payment_params, &network_graph, first_hops,
route_params.final_value_msat, route_params.final_cltv_expiry_delta, logger, scorer,
random_seed_bytes)?;
add_random_cltv_offset(&mut route, &route_params.payment_params, &network_graph, random_seed_bytes);
Ok(route)
}

pub(crate) fn get_route<L: Deref, S: Score>(
Expand Down Expand Up @@ -1785,10 +1780,57 @@ fn add_random_cltv_offset(route: &mut Route, payment_params: &PaymentParameters,
}
}

/// Build a route from us (payer) with the given hops ending at the target node (payee).
///
/// Re-uses logic from `find_route`, so the restrictions described there also apply here.
pub fn build_route_from_hops<L: Deref>(
our_node_pubkey: &PublicKey, hops: &[PublicKey], route_params: &RouteParameters, network: &NetworkGraph,
logger: L, random_seed_bytes: &[u8; 32]) -> Result<Route, LightningError>
where L::Target: Logger {
let network_graph = network.read_only();
let mut route = build_route_from_hops_internal(
our_node_pubkey, hops, &route_params.payment_params, &network_graph,
route_params.final_value_msat, route_params.final_cltv_expiry_delta, logger, random_seed_bytes)?;
add_random_cltv_offset(&mut route, &route_params.payment_params, &network_graph, random_seed_bytes);
Ok(route)
}

fn build_route_from_hops_internal<L: Deref>(
our_node_pubkey: &PublicKey, hops: &[PublicKey], payment_params: &PaymentParameters,
network_graph: &ReadOnlyNetworkGraph, final_value_msat: u64, final_cltv_expiry_delta: u32,
logger: L, random_seed_bytes: &[u8; 32]) -> Result<Route, LightningError> where L::Target: Logger {

struct HopScorer<'a> { hops: &'a [PublicKey] }

impl Score for HopScorer<'_> {
fn channel_penalty_msat(&self, _short_channel_id: u64, source: &NodeId, target: &NodeId,
_usage: ChannelUsage) -> u64
{
for i in 0..self.hops.len()-1 {
let cur_id = NodeId::from_pubkey(&self.hops[i]);
let next_id = NodeId::from_pubkey(&self.hops[i+1]);
if cur_id == *source && next_id == *target {
return 0;
}
}
u64::max_value()
}

fn payment_path_failed(&mut self, _path: &[&RouteHop], _short_channel_id: u64) {}

fn payment_path_successful(&mut self, _path: &[&RouteHop]) {}
}

let scorer = HopScorer { hops: &[&[*our_node_pubkey],hops].concat() };

get_route(our_node_pubkey, payment_params, network_graph, None, final_value_msat,
final_cltv_expiry_delta, logger, &scorer, random_seed_bytes)
}

#[cfg(test)]
mod tests {
use routing::network_graph::{NetworkGraph, NetGraphMsgHandler, NodeId};
use routing::router::{get_route, add_random_cltv_offset, default_node_features,
use routing::router::{get_route, build_route_from_hops_internal, add_random_cltv_offset, default_node_features,
PaymentParameters, Route, RouteHint, RouteHintHop, RouteHop, RoutingFees,
DEFAULT_MAX_TOTAL_CLTV_EXPIRY_DELTA, MAX_PATH_LENGTH_ESTIMATE};
use routing::scoring::{ChannelUsage, Score};
Expand Down Expand Up @@ -5486,6 +5528,26 @@ mod tests {
assert!(path_plausibility.iter().all(|x| *x));
}

#[test]
fn builds_correct_path_from_hops() {
let (secp_ctx, network, _, _, logger) = build_graph();
let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
let network_graph = network.read_only();

let keys_manager = test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
let random_seed_bytes = keys_manager.get_secure_random_bytes();

let payment_params = PaymentParameters::from_node_id(nodes[3]);
let hops = [nodes[1], nodes[2], nodes[4], nodes[3]];
let route = build_route_from_hops_internal(&our_id, &hops, &payment_params,
&network_graph, 100, 0, Arc::clone(&logger), &random_seed_bytes).unwrap();
let route_hop_pubkeys = route.paths[0].iter().map(|hop| hop.pubkey).collect::<Vec<_>>();
assert_eq!(hops.len(), route.paths[0].len());
for (idx, hop_pubkey) in hops.iter().enumerate() {
assert!(*hop_pubkey == route_hop_pubkeys[idx]);
}
}

#[cfg(not(feature = "no-std"))]
pub(super) fn random_init_seed() -> u64 {
// Because the default HashMap in std pulls OS randomness, we can use it as a (bad) RNG.
Expand Down

0 comments on commit f6607a2

Please sign in to comment.