Skip to content

Commit

Permalink
check if fee has been collected
Browse files Browse the repository at this point in the history
  • Loading branch information
johncantrell97 committed May 17, 2024
1 parent 0513e0c commit 5983ec0
Showing 1 changed file with 95 additions and 23 deletions.
118 changes: 95 additions & 23 deletions src/lsps2/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,13 @@ enum HTLCInterceptedAction {
ForwardPayment(ForwardPaymentAction),
}

/// Possible actions that need to be taken when a payment is forwarded.
#[derive(Debug, PartialEq)]
enum PaymentForwardedAction {
ForwardPayment(ForwardPaymentAction),
ForwardHTLCs(ForwardHTLCsAction),
}

/// The forwarding of a payment while skimming the JIT channel opening fee.
#[derive(Debug, PartialEq)]
struct ForwardPaymentAction(ChannelId, FeePayment);
Expand Down Expand Up @@ -318,23 +325,42 @@ impl OutboundJITChannelState {
}

fn payment_forwarded(
&mut self,
) -> Result<(Self, Option<ForwardHTLCsAction>), ChannelStateError> {
&mut self, skimmed_fee_msat: Option<u64>,
) -> Result<(Self, Option<PaymentForwardedAction>), ChannelStateError> {
match self {
OutboundJITChannelState::PendingPaymentForward {
payment_queue, channel_id, ..
payment_queue,
channel_id,
opening_fee_msat,
} => {
let mut payment_queue_lock = payment_queue.lock().unwrap();
let payment_forwarded =
OutboundJITChannelState::PaymentForwarded { channel_id: *channel_id };
let htlcs = payment_queue_lock
.clear()
.into_iter()
.map(|(_, htlcs)| htlcs)
.flatten()
.collect();
let forward_htlcs = ForwardHTLCsAction(*channel_id, htlcs);
Ok((payment_forwarded, Some(forward_htlcs)))

let skimmed_fee_msat = skimmed_fee_msat.unwrap_or(0);
let remaining_fee = opening_fee_msat.saturating_sub(skimmed_fee_msat);

if remaining_fee > 0 {
let (state, payment_action) = try_get_payment(
Arc::clone(payment_queue),
payment_queue_lock,
*channel_id,
remaining_fee,
);
Ok((state, payment_action.map(|pa| PaymentForwardedAction::ForwardPayment(pa))))
} else {
let payment_forwarded =
OutboundJITChannelState::PaymentForwarded { channel_id: *channel_id };
let htlcs = payment_queue_lock
.clear()
.into_iter()
.map(|(_, htlcs)| htlcs)
.flatten()
.collect();
let forward_htlcs = ForwardHTLCsAction(*channel_id, htlcs);
Ok((
payment_forwarded,
Some(PaymentForwardedAction::ForwardHTLCs(forward_htlcs)),
))
}
},
OutboundJITChannelState::PaymentForwarded { channel_id } => {
let payment_forwarded =
Expand Down Expand Up @@ -368,6 +394,10 @@ impl OutboundJITChannel {
}
}

pub fn has_paid_fee(&self) -> bool {
matches!(self.state, OutboundJITChannelState::PaymentForwarded { .. })
}

fn htlc_intercepted(
&mut self, htlc: InterceptedHTLC,
) -> Result<Option<HTLCInterceptedAction>, LightningError> {
Expand All @@ -391,8 +421,10 @@ impl OutboundJITChannel {
Ok(action)
}

fn payment_forwarded(&mut self) -> Result<Option<ForwardHTLCsAction>, LightningError> {
let (new_state, action) = self.state.payment_forwarded()?;
fn payment_forwarded(
&mut self, skimmed_fee_msat: Option<u64>,
) -> Result<Option<PaymentForwardedAction>, LightningError> {
let (new_state, action) = self.state.payment_forwarded(skimmed_fee_msat)?;
self.state = new_state;
Ok(action)
}
Expand Down Expand Up @@ -818,7 +850,9 @@ where
/// greater or equal to 0.0.107.
///
/// [`Event::PaymentForwarded`]: lightning::events::Event::PaymentForwarded
pub fn payment_forwarded(&self, next_channel_id: ChannelId) -> Result<(), APIError> {
pub fn payment_forwarded(
&self, next_channel_id: ChannelId, skimmed_fee_msat: Option<u64>,
) -> Result<bool, APIError> {
if let Some(counterparty_node_id) =
self.peer_by_channel_id.read().unwrap().get(&next_channel_id)
{
Expand All @@ -832,8 +866,10 @@ where
if let Some(jit_channel) =
peer_state.outbound_channels_by_intercept_scid.get_mut(&intercept_scid)
{
match jit_channel.payment_forwarded() {
Ok(Some(ForwardHTLCsAction(channel_id, htlcs))) => {
match jit_channel.payment_forwarded(skimmed_fee_msat) {
Ok(Some(PaymentForwardedAction::ForwardHTLCs(
ForwardHTLCsAction(channel_id, htlcs),
))) => {
for htlc in htlcs {
self.channel_manager.get_cm().forward_intercepted_htlc(
htlc.intercept_id,
Expand All @@ -843,6 +879,29 @@ where
)?;
}
},
Ok(Some(PaymentForwardedAction::ForwardPayment(
ForwardPaymentAction(
channel_id,
FeePayment { htlcs, opening_fee_msat },
),
))) => {
let amounts_to_forward_msat =
calculate_amount_to_forward_per_htlc(
&htlcs,
opening_fee_msat,
);

for (intercept_id, amount_to_forward_msat) in
amounts_to_forward_msat
{
self.channel_manager.get_cm().forward_intercepted_htlc(
intercept_id,
&channel_id,
*counterparty_node_id,
amount_to_forward_msat,
)?;
}
},
Ok(None) => {},
Err(e) => {
return Err(APIError::APIMisuseError {
Expand All @@ -853,6 +912,7 @@ where
})
},
}
return Ok(jit_channel.has_paid_fee());
}
} else {
return Err(APIError::APIMisuseError {
Expand All @@ -868,7 +928,7 @@ where
}
}

Ok(())
Ok(false)
}

/// Used by LSP to fail intercepted htlcs backwards when the channel open fails for any reason.
Expand Down Expand Up @@ -1476,12 +1536,18 @@ mod tests {
}
state = new_state;
}

// TODO: how do I get the expected skimmed amount here

// Payment completes, queued payments get forwarded.
{
let (new_state, action) = state.payment_forwarded().unwrap();
let (new_state, action) = state.payment_forwarded(Some(100_000_000)).unwrap();
assert!(matches!(new_state, OutboundJITChannelState::PaymentForwarded { .. }));
match action {
Some(ForwardHTLCsAction(channel_id, htlcs)) => {
Some(PaymentForwardedAction::ForwardHTLCs(ForwardHTLCsAction(
channel_id,
htlcs,
))) => {
assert_eq!(channel_id, ChannelId([200; 32]));
assert_eq!(
htlcs,
Expand Down Expand Up @@ -1617,12 +1683,18 @@ mod tests {
}
state = new_state;
}

// TODO: how do I grab the expected skimmed fee amount here.

// Payment completes, queued payments get forwarded.
{
let (new_state, action) = state.payment_forwarded().unwrap();
let (new_state, action) = state.payment_forwarded(Some(100_000_000)).unwrap();
assert!(matches!(new_state, OutboundJITChannelState::PaymentForwarded { .. }));
match action {
Some(ForwardHTLCsAction(channel_id, htlcs)) => {
Some(PaymentForwardedAction::ForwardHTLCs(ForwardHTLCsAction(
channel_id,
htlcs,
))) => {
assert_eq!(channel_id, ChannelId([200; 32]));
assert_eq!(
htlcs,
Expand Down

0 comments on commit 5983ec0

Please sign in to comment.