From 91057af79ab004592ef05832ad1f553e454a3e6b Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Wed, 24 Jan 2024 17:35:04 +0100 Subject: [PATCH 01/78] implementation of reliable transport --- internal/packetmuxer/service.go | 26 +- internal/reliabletransport/constants.go | 21 ++ internal/reliabletransport/interfaces.go | 51 +++ internal/reliabletransport/packets.go | 158 ++++++++ internal/reliabletransport/receiver.go | 99 +++++ .../reliabletransport/reliabletransport.go | 107 +++++- internal/reliabletransport/sender.go | 106 ++++++ internal/reliabletransport/sender_test.go | 346 ++++++++++++++++++ 8 files changed, 893 insertions(+), 21 deletions(-) create mode 100644 internal/reliabletransport/constants.go create mode 100644 internal/reliabletransport/interfaces.go create mode 100644 internal/reliabletransport/packets.go create mode 100644 internal/reliabletransport/receiver.go create mode 100644 internal/reliabletransport/sender.go create mode 100644 internal/reliabletransport/sender_test.go diff --git a/internal/packetmuxer/service.go b/internal/packetmuxer/service.go index 88f3bfd9..9e126cfb 100644 --- a/internal/packetmuxer/service.go +++ b/internal/packetmuxer/service.go @@ -3,6 +3,7 @@ package packetmuxer import ( "fmt" + "time" "github.com/ooni/minivpn/internal/model" "github.com/ooni/minivpn/internal/session" @@ -48,12 +49,14 @@ func (s *Service) StartWorkers( sessionManager *session.Manager, ) { ws := &workersState{ - logger: logger, - hardReset: s.HardReset, + logger: logger, + hardReset: s.HardReset, + // initialize to a sufficiently long time from now + hardResetTicker: time.NewTicker(time.Hour * 24 * 30), notifyTLS: *s.NotifyTLS, + dataOrControlToMuxer: s.DataOrControlToMuxer, muxerToReliable: *s.MuxerToReliable, muxerToData: *s.MuxerToData, - dataOrControlToMuxer: s.DataOrControlToMuxer, muxerToNetwork: *s.MuxerToNetwork, networkToMuxer: s.NetworkToMuxer, sessionManager: sessionManager, @@ -71,6 +74,9 @@ type workersState struct { // hardReset is the channel posted to force a hard reset. hardReset <-chan any + // hardResetTicker is a channel to retry the initial send of hard reset packet. + hardResetTicker *time.Ticker + // notifyTLS is used to send notifications to the TLS service. notifyTLS chan<- *model.Notification @@ -116,6 +122,13 @@ func (ws *workersState) moveUpWorker() { return } + case <-ws.hardResetTicker.C: + // retry the hard reset, it probably was lost + if err := ws.startHardReset(); err != nil { + // error already logged + return + } + case <-ws.hardReset: if err := ws.startHardReset(); err != nil { // error already logged @@ -169,6 +182,7 @@ func (ws *workersState) moveDownWorker() { // startHardReset is invoked when we need to perform a HARD RESET. func (ws *workersState) startHardReset() error { // emit a CONTROL_HARD_RESET_CLIENT_V2 pkt + // TODO(ainghazal): we need to retry this hard reset if not ACKd in a reasonable time. packet, err := ws.sessionManager.NewPacket(model.P_CONTROL_HARD_RESET_CLIENT_V2, nil) if err != nil { ws.logger.Warnf("packetmuxer: NewPacket: %s", err.Error()) @@ -178,7 +192,10 @@ func (ws *workersState) startHardReset() error { return err } - // reset the state to become initial again + // resend if not received the server's reply in 2 seconds. + ws.hardResetTicker.Reset(time.Second * 2) + + // reset the state to become initial again. ws.sessionManager.SetNegotiationState(session.S_PRE_START) // TODO: any other change to apply in this case? @@ -198,6 +215,7 @@ func (ws *workersState) handleRawPacket(rawPacket []byte) error { // handle the case where we're performing a HARD_RESET if ws.sessionManager.NegotiationState() == session.S_PRE_START && packet.Opcode == model.P_CONTROL_HARD_RESET_SERVER_V2 { + ws.hardResetTicker.Stop() return ws.finishThreeWayHandshake(packet) } diff --git a/internal/reliabletransport/constants.go b/internal/reliabletransport/constants.go new file mode 100644 index 00000000..c9a110bc --- /dev/null +++ b/internal/reliabletransport/constants.go @@ -0,0 +1,21 @@ +package reliabletransport + +const ( + // Capacity for the array of packets that we're tracking at any given moment (outgoing). + RELIABLE_SEND_BUFFER_SIZE = 12 + + // Capacity for the array of packets that we're tracking at any given moment (incoming). + RELIABLE_RECV_BUFFER_SIZE = RELIABLE_SEND_BUFFER_SIZE + + // The maximum numbers of ACKs that we put in an array for an outgoing packet. + MAX_ACKS_PER_OUTGOING_PACKET = 4 + + // Initial timeout for TLS retransmission, in seconds. + INITIAL_TLS_TIMEOUT_SECONDS = 2 + + // Maximum backoff interval, in seconds. + MAX_BACKOFF_SECONDS = 60 + + // Default sender ticker period, in milliseconds. + SENDER_TICKER_MS = 1000 * 60 +) diff --git a/internal/reliabletransport/interfaces.go b/internal/reliabletransport/interfaces.go new file mode 100644 index 00000000..688c476d --- /dev/null +++ b/internal/reliabletransport/interfaces.go @@ -0,0 +1,51 @@ +package reliabletransport + +import ( + "github.com/ooni/minivpn/internal/model" +) + +// sequentialPacket is a packet that can return a [model.PacketID]. +type sequentialPacket interface { + ID() model.PacketID + ExtractACKs() []model.PacketID + Packet() *model.Packet +} + +// inFlightPacket is a packet that, additionally, can keep track of how many acks for a packet with a higher PID have been received. +type inFlighter interface { + sequentialPacket + ScheduleForRetransmission() +} + +// outgoingPacketHandler has methods to deal with the outgoing packets (going down). +type outgoingPacketHandler interface { + // TryInsertOutgoingPacket attempts to insert a packet into the queue. If return value is + // false, insertion was not successful. + TryInsertOutgoingPacket(*model.Packet) bool + + // MaybeEvictOrBumpPacketAfterACK removes a packet (that we received an ack for) from the in-flight packet queue. + MaybeEvictOrBumpPacketAfterACK(id model.PacketID) bool + + // NextPacketIDsToACK returns an array of pending IDs to ACK to + // our remote. The lenght of this array SHOULD not be larger than CONTROL_SEND_ACK_MAX. + // This is used to append it to the ACK array of the outgoing packet. + NextPacketIDsToACK() []model.PacketID + + // OnIncomingPacketSeen processes a notification received in the shared channel for incoming packets. + OnIncomingPacketSeen(incomingPacketSeen) +} + +// incomingPacketHandler knows how to deal with incoming packets (going up). +type incomingPacketHandler interface { + // TODO: the interface needs to add ACKs() + // NotifySeen sends notifications about an incoming packet. + NotifySeen(*model.Packet) bool + + // MaybeInsertIncoming will insert a given packet in the reliable + // incoming queue if it passes a series of sanity checks. + MaybeInsertIncoming(*model.Packet) bool + + // NextIncomingSequence gets the largest sequence of packets ready to be passed along + // to the control channel above us. + NextIncomingSequence() incomingSequence +} diff --git a/internal/reliabletransport/packets.go b/internal/reliabletransport/packets.go new file mode 100644 index 00000000..f77c9782 --- /dev/null +++ b/internal/reliabletransport/packets.go @@ -0,0 +1,158 @@ +package reliabletransport + +import ( + "fmt" + "time" + + "github.com/ooni/minivpn/internal/model" +) + +// +// A note about terminology: in the following, **receiver** is the moveUpWorker in the [reliabletransport.Service] (since it receives incoming packets), and **sender** is the moveDownWorker in the same service. The following data structures lack mutexes because they are intended to be confined to a single goroutine (one for each worker), and they only communicate via message passing. +// + +type inFlightPacket struct { + // deadline is a moment in time when is this packet scheduled for the next retransmission. + deadline time.Time + + // how many acks we've received for packets with higher PID. + higherACKs int + + // packet is the underlying packet being sent. + packet *model.Packet + + // retries is a monotonically increasing counter for retransmission. + retries uint8 +} + +func newInFlightPacket(p *model.Packet) *inFlightPacket { + return &inFlightPacket{ + deadline: time.Time{}, + higherACKs: 0, + packet: p, + retries: 0, + } +} + +func (p *inFlightPacket) ExtractACKs() []model.PacketID { + return p.packet.ACKs +} + +// TODO leaving Fast retransmission out for now. +// ACKForHigherPacket increments the number of acks received for a higher pid than this packet. This will influence the fast rexmit selection algorithm. +func (p *inFlightPacket) ACKForHigherPacket() { + p.higherACKs += 1 +} + +func (p *inFlightPacket) ScheduleForRetransmission(t time.Time) { + p.retries += 1 + p.deadline = t.Add(p.backoff()) +} + +// backoff will calculate the next retransmission interval. +func (p *inFlightPacket) backoff() time.Duration { + backoff := time.Duration(1< maxBackoff { + backoff = maxBackoff + } + return backoff +} + +// assert that inFlightWrappedPacket implements inFlightPacket and sequentialPacket +// var _ inFlightPacket = &inFlightWrappedPacket{} +// var _ sequentialPacket = &inFlightWrappedPacket{} + +// inflightSequence is a sequence of inFlightPackets. +// A inflightSequence can be sorted. +type inflightSequence []*inFlightPacket + +// nearestDeadlineTo returns the lower deadline to a passed reference time for all the packets in the inFlight queue. Used to re-arm the Ticker. We need to be careful and not pass a +func (seq inflightSequence) nearestDeadlineTo(t time.Time) time.Time { + // we default to a long wakeup + timeout := t.Add(time.Duration(SENDER_TICKER_MS) * time.Millisecond) + + for _, p := range seq { + if p.deadline.Before(timeout) { + timeout = p.deadline + } + } + + // what's past is past and we need to move on. + if timeout.Before(t) { + timeout = t.Add(time.Nanosecond) + } + return timeout +} + +// readyToSend eturns the subset of this sequence that has a expired deadline or +// is suitable for fast retransmission. +func (seq inflightSequence) readyToSend(t time.Time) inflightSequence { + expired := make([]*inFlightPacket, 0) + for _, p := range seq { + if p.higherACKs >= 3 { + fmt.Println("DEBUG: fast retransmit for", p.packet.ID) + expired = append(expired, p) + continue + } else if p.deadline.Before(t) { + expired = append(expired, p) + } + } + return expired +} + +// implement sort.Interface +func (seq inflightSequence) Len() int { + return len(seq) +} + +// implement sort.Interface +func (seq inflightSequence) Swap(i, j int) { + seq[i], seq[j] = seq[j], seq[i] +} + +// implement sort.Interface +func (seq inflightSequence) Less(i, j int) bool { + return seq[i].packet.ID < seq[j].packet.ID +} + +// A incomingSequence is an array of sequentialPackets. It's used to store both incoming and outgoing packet queues. +// a incomingSequence can be sorted. +type incomingSequence []sequentialPacket + +// implement sort.Interface +func (ps incomingSequence) Len() int { + return len(ps) +} + +// implement sort.Interface +func (ps incomingSequence) Swap(i, j int) { + ps[i], ps[j] = ps[j], ps[i] +} + +// implement sort.Interface +func (ps incomingSequence) Less(i, j int) bool { + return ps[i].ID() < ps[j].ID() +} + +type incomingPacket struct { + packet *model.Packet +} + +func (ip *incomingPacket) ID() model.PacketID { + return ip.packet.ID +} + +func (ip *incomingPacket) ExtractACKs() []model.PacketID { + return ip.packet.ACKs +} + +func (ip *incomingPacket) Packet() *model.Packet { + return ip.packet +} + +// incomingPacketSeen is a struct that the receiver sends us when a new packet is seen. +type incomingPacketSeen struct { + id model.PacketID + acks []model.PacketID +} diff --git a/internal/reliabletransport/receiver.go b/internal/reliabletransport/receiver.go new file mode 100644 index 00000000..ba1b27a4 --- /dev/null +++ b/internal/reliabletransport/receiver.go @@ -0,0 +1,99 @@ +package reliabletransport + +import ( + "sort" + + "github.com/ooni/minivpn/internal/model" +) + +// +// incomingPacketHandler implementation. +// + +// TODO rename to receiver +// reliableIncoming is the receiver part that sees incoming packets moving up the stack. +type reliableIncoming struct { + // logger is the logger to use + logger model.Logger + + // incomingPackets are packets to process (reorder) before they are passed to TLS layer. + incomingPackets incomingSequence + + // incomingSeen is a channel where we send notifications for incoming packets seen by us. + incomingSeen chan<- incomingPacketSeen + + // lastConsumed is the last [model.PacketID] that we have passed to the control layer above us. + lastConsumed model.PacketID +} + +func newReliableIncoming(logger model.Logger, i chan incomingPacketSeen) *reliableIncoming { + return &reliableIncoming{ + logger: logger, + incomingPackets: []sequentialPacket{}, + incomingSeen: i, + lastConsumed: 0, + } +} + +// NotifySeen sends a incomingPacketSeen object to the shared channel where the sender will read it. +func (r *reliableIncoming) NotifySeen(p *model.Packet) bool { + incoming := incomingPacketSeen{ + id: p.ID, + acks: p.ACKs, + } + if p.ID > 0 && p.ID <= r.lastConsumed { + r.logger.Warnf("got packet id %v, but last consumed is %v\n", p.ID, r.lastConsumed) + } + r.incomingSeen <- incoming + return true + +} + +func (r *reliableIncoming) MaybeInsertIncoming(p *model.Packet) bool { + // we drop if at capacity, by default double the size of the outgoing buffer + if len(r.incomingPackets) >= RELIABLE_RECV_BUFFER_SIZE { + r.logger.Warnf("dropping packet, buffer full with len %v", len(r.incomingPackets)) + return false + } + + inc := &incomingPacket{p} + // insert this one in the queue to pass to TLS. + r.incomingPackets = append(r.incomingPackets, inc) + return true +} + +func (r *reliableIncoming) NextIncomingSequence() incomingSequence { + last := r.lastConsumed + ready := make([]sequentialPacket, 0, RELIABLE_RECV_BUFFER_SIZE) + + // sort them so that we begin with lower model.PacketID + sort.Sort(r.incomingPackets) + keep := r.incomingPackets[:0] + + for i, p := range r.incomingPackets { + if p.ID()-last == 1 { + ready = append(ready, p) + last += 1 + } else if p.ID() > last { + // here we broke sequentiality, but we want + // to drop anything that is below lastConsumed + keep = append(keep, r.incomingPackets[i:]...) + break + } + } + r.lastConsumed = last + r.incomingPackets = keep + //if len(ready) != 0 { + //r.logger.Debugf(">> BUMP LAST CONSUMED TO %v", last) + //r.logger.Debugf(">> incoming now: %v", keep) + //} + return ready +} + +// assert that reliableIncoming implements incomingPacketHandler +var _ incomingPacketHandler = &reliableIncoming{ + logger: nil, + incomingPackets: []sequentialPacket{}, + incomingSeen: make(chan<- incomingPacketSeen), + lastConsumed: 0, +} diff --git a/internal/reliabletransport/reliabletransport.go b/internal/reliabletransport/reliabletransport.go index faf4c9e4..e588156c 100644 --- a/internal/reliabletransport/reliabletransport.go +++ b/internal/reliabletransport/reliabletransport.go @@ -4,6 +4,7 @@ package reliabletransport import ( "bytes" "fmt" + "time" "github.com/ooni/minivpn/internal/model" "github.com/ooni/minivpn/internal/session" @@ -39,8 +40,10 @@ func (s *Service) StartWorkers( workersManager *workers.Manager, sessionManager *session.Manager, ) { + ws := &workersState{ logger: logger, + incomingSeen: make(chan incomingPacketSeen, 20), dataOrControlToMuxer: *s.DataOrControlToMuxer, controlToReliable: s.ControlToReliable, muxerToReliable: s.MuxerToReliable, @@ -57,6 +60,9 @@ type workersState struct { // logger is the logger to use logger model.Logger + // incomingSeen ins the shared channel to connect sender and receiver goroutines. + incomingSeen chan incomingPacketSeen + // dataOrControlToMuxer is the channel where we write packets going down the stack. dataOrControlToMuxer chan<- *model.Packet @@ -76,7 +82,8 @@ type workersState struct { workersManager *workers.Manager } -// moveUpWorker moves packets up the stack +// moveUpWorker moves packets up the stack (receiver) +// TODO: move worker to receiver.go func (ws *workersState) moveUpWorker() { workerName := fmt.Sprintf("%s: moveUpWorker", serviceName) @@ -87,6 +94,8 @@ func (ws *workersState) moveUpWorker() { ws.logger.Debugf("%s: started", workerName) + receiver := newReliableIncoming(ws.logger, ws.incomingSeen) + // TODO: do we need to have notifications from the control channel // to reset state or can we do this implicitly? @@ -114,19 +123,39 @@ func (ws *workersState) moveUpWorker() { continue } + // TODO: drop a packet too far away (we can use lastConsumed) + // possibly ACK the incoming packet + // TODO: move this responsibility to the sender. if err := ws.maybeACK(packet); err != nil { ws.logger.Warnf("%s: cannot ACK packet: %s", workerName, err.Error()) continue } - // TODO: here we should track ACKs + ws.logger.Debugf( + "notify: ", + packet.ID, + packet.ACKs, + ) + + // TODO: possibly refactor so that the writing to the channel happens here + // the fact this channel write is hidden makes following this harder + // TODO: notify before dropping? + receiver.NotifySeen(packet) - // POSSIBLY BLOCK delivering to the upper layer - select { - case ws.reliableToControl <- packet: - case <-ws.workersManager.ShouldShutdown(): - return + if inserted := receiver.MaybeInsertIncoming(packet); !inserted { + continue + } + + // TODO drop first ------------------------------------------------ + ready := receiver.NextIncomingSequence() + for _, nextPacket := range ready { + // POSSIBLY BLOCK delivering to the upper layer + select { + case ws.reliableToControl <- nextPacket.Packet(): + case <-ws.workersManager.ShouldShutdown(): + return + } } case <-ws.workersManager.ShouldShutdown(): @@ -135,7 +164,8 @@ func (ws *workersState) moveUpWorker() { } } -// moveDownWorker moves packets down the stack +// moveDownWorker moves packets down the stack (sender) +// TODO move the worker to sender.go func (ws *workersState) moveDownWorker() { workerName := fmt.Sprintf("%s: moveDownWorker", serviceName) @@ -146,13 +176,13 @@ func (ws *workersState) moveDownWorker() { ws.logger.Debugf("%s: started", workerName) - // TODO: we should have timer for retransmission + sender := newReliableSender(ws.logger, ws.incomingSeen) + ticker := time.NewTicker(time.Duration(SENDER_TICKER_MS) * time.Millisecond) + for { // POSSIBLY BLOCK reading the next packet we should move down the stack select { case packet := <-ws.controlToReliable: - // TODO: here we should treat control packets specially - ws.logger.Infof( "> %s localID=%x remoteID=%x [%d bytes]", packet.Opcode, @@ -161,11 +191,51 @@ func (ws *workersState) moveDownWorker() { len(packet.Payload), ) - // POSSIBLY BLOCK delivering this packet to the lower layer - select { - case ws.dataOrControlToMuxer <- packet: - case <-ws.workersManager.ShouldShutdown(): - return + sender.TryInsertOutgoingPacket(packet) + // schedule for inmediate wakeup + // so that the ticker will wakeup and see if there's anything pending to be sent. + ticker.Reset(time.Nanosecond) + + case incomingSeen := <-sender.incomingSeen: + // possibly evict any acked packet + sender.OnIncomingPacketSeen(incomingSeen) + + // schedule for inmediate wakeup, because we probably need to update ACKs + ticker.Reset(time.Nanosecond) + + // TODO need to ACK here if no packets pending. + // I think we can just call withExpiredDeadline and ACK if len(expired) is 0 + + case <-ticker.C: + // First of all, we reset the ticker to the next timeout. + // By default, that's going to return one minute if there are no packets + // in the in-flight queue. + + // nearestDeadlineTo(now) ensures that we do not receive a time before now, and + // that increments the passed moment by an epsilon if all deadlines are expired, + // so it should be safe to reset the ticker with that timeout. + now := time.Now() + timeout := inflightSequence(sender.inFlight).nearestDeadlineTo(now) + + ws.logger.Debug("") + ws.logger.Debugf("next wakeup: %v", timeout.Sub(now)) + + ticker.Reset(timeout.Sub(now)) + + // we flush everything that is ready to be sent. + scheduledNow := inflightSequence(sender.inFlight).readyToSend(now) + ws.logger.Debugf(":: GOT %d packets to send\n", len(scheduledNow)) + + for _, p := range scheduledNow { + p.ScheduleForRetransmission(now) + // TODO ------------------------------------------- + // ideally, we want to append any pending ACKs here + select { + case ws.dataOrControlToMuxer <- p.packet: + ws.logger.Debugf("==> sent packet with ID: %v", p.packet.ID) + case <-ws.workersManager.ShouldShutdown(): + return + } } case <-ws.workersManager.ShouldShutdown(): @@ -177,7 +247,9 @@ func (ws *workersState) moveDownWorker() { // maybeACK sends an ACK when needed. func (ws *workersState) maybeACK(packet *model.Packet) error { // currently we are ACKing every packet - // TODO: implement better ACKing strategy + // TODO: implement better ACKing strategy - this is basically moving the responsibility + // to the sender, and then either appending up to 4 ACKs to the ACK array of an outgoing + // packet, or sending a single ACK (if there's nothing pending to be sent). // this function will fail if we don't know the remote session ID ACK, err := ws.sessionManager.NewACKForPacket(packet) @@ -188,6 +260,7 @@ func (ws *workersState) maybeACK(packet *model.Packet) error { // move the packet down. CAN BLOCK writing to the shared channel to muxer. select { case ws.dataOrControlToMuxer <- ACK: + ws.logger.Debugf("ack for remote packet id: %d", packet.ID) return nil case <-ws.workersManager.ShouldShutdown(): return workers.ErrShutdown diff --git a/internal/reliabletransport/sender.go b/internal/reliabletransport/sender.go new file mode 100644 index 00000000..a68ecd3f --- /dev/null +++ b/internal/reliabletransport/sender.go @@ -0,0 +1,106 @@ +package reliabletransport + +import ( + "sort" + + "github.com/ooni/minivpn/internal/model" +) + +// reliableSender keeps state about the outgoing packet queue, and implements outgoingPacketHandler. +// Please use the constructor `newReliableSender()` +type reliableSender struct { + // logger is the logger to use + logger model.Logger + + // incomingSeen is a channel where we receive notifications for incoming packets seen by the receiver. + incomingSeen <-chan incomingPacketSeen + + // inFlight is the array of in-flight packets. + inFlight []*inFlightPacket + + // pendingACKsToSend is the array of packets that we still need to ACK. + pendingACKsToSend []model.PacketID +} + +// newReliableSender returns a new instance of reliableOutgoing. +func newReliableSender(logger model.Logger, i chan incomingPacketSeen) *reliableSender { + return &reliableSender{ + logger: logger, + incomingSeen: i, + inFlight: make([]*inFlightPacket, 0, RELIABLE_SEND_BUFFER_SIZE), + pendingACKsToSend: []model.PacketID{}, + } +} + +// +// outgoingPacketHandler implementation. +// + +func (r *reliableSender) TryInsertOutgoingPacket(p *model.Packet) bool { + if len(r.inFlight) >= RELIABLE_SEND_BUFFER_SIZE { + r.logger.Warn("outgoing array full, dropping packet") + return false + } + new := newInFlightPacket(p) + r.inFlight = append(r.inFlight, new) + return true +} + +// MaybeEvictOrBumpPacketAfterACK iterates over all the in-flight packets. For each one, +// and either evicts it (if the PacketID matches), or bumps the internal withHigherACK count in the +// packet (if the PacketID from the ACK is higher than the packet in the queue). +func (r *reliableSender) MaybeEvictOrBumpPacketAfterACK(acked model.PacketID) bool { + // TODO: it *should* be sorted, can it be not sorted? + sort.Sort(inflightSequence(r.inFlight)) + + packets := r.inFlight + for i, p := range packets { + if acked > p.packet.ID { + // we have received an ACK for a packet with a higher pid, so let's bump it + p.ACKForHigherPacket() + + } else if acked == p.packet.ID { + + // we have a match for the ack we just received: eviction it is! + r.logger.Debugf("evicting packet %v", p.packet.ID) + + // first we swap this element with the last one: + packets[i], packets[len(packets)-1] = packets[len(packets)-1], packets[i] + + // and now exclude the last element: + r.inFlight = packets[:len(packets)-1] + + // since we had sorted the in-flight array, we're done here. + return true + } + } + return false +} + +// this should return at most MAX_ACKS_PER_OUTGOING_PACKET packet IDs. +func (r *reliableSender) NextPacketIDsToACK() []model.PacketID { + var next []model.PacketID + if len(r.pendingACKsToSend) <= MAX_ACKS_PER_OUTGOING_PACKET { + next = r.pendingACKsToSend[:len(r.pendingACKsToSend)] + r.pendingACKsToSend = r.pendingACKsToSend[:0] + return next + } + + next = r.pendingACKsToSend[:MAX_ACKS_PER_OUTGOING_PACKET] + r.pendingACKsToSend = r.pendingACKsToSend[MAX_ACKS_PER_OUTGOING_PACKET : len(r.pendingACKsToSend)-1] + return next +} + +func (r *reliableSender) OnIncomingPacketSeen(ips incomingPacketSeen) { + // we have received an incomingPacketSeen on the shared channel, we need to do two things: + + // 1. add the ID to the queue of packets to be acknowledged. + r.pendingACKsToSend = append(r.pendingACKsToSend, ips.id) + + // 2. for every ACK received, see if we need to evict or bump the in-flight packet. + for _, packetID := range ips.acks { + r.MaybeEvictOrBumpPacketAfterACK(packetID) + } +} + +var _ outgoingPacketHandler = &reliableSender{} diff --git a/internal/reliabletransport/sender_test.go b/internal/reliabletransport/sender_test.go new file mode 100644 index 00000000..5ff72170 --- /dev/null +++ b/internal/reliabletransport/sender_test.go @@ -0,0 +1,346 @@ +package reliabletransport + +import ( + "reflect" + "testing" + + "github.com/apex/log" + "github.com/ooni/minivpn/internal/model" +) + +// +// tests for reliableOutgoing +// + +func Test_reliableOutgoing_TryInsertOutgoingPacket(t *testing.T) { + log.SetLevel(log.DebugLevel) + + type fields struct { + inFlight inflightSequence + } + type args struct { + p *model.Packet + } + tests := []struct { + name string + fields fields + args args + want bool + }{ + { + name: "insert on empty array", + fields: fields{ + inFlight: inflightSequence([]*inFlightPacket{}), + }, + args: args{ + p: &model.Packet{ID: 1}, + }, + want: true, + }, + { + name: "insert on full array", + fields: fields{ + inFlight: inflightSequence([]*inFlightPacket{ + {packet: &model.Packet{ID: 1}}, + {packet: &model.Packet{ID: 2}}, + {packet: &model.Packet{ID: 3}}, + {packet: &model.Packet{ID: 4}}, + {packet: &model.Packet{ID: 5}}, + {packet: &model.Packet{ID: 6}}, + {packet: &model.Packet{ID: 7}}, + {packet: &model.Packet{ID: 8}}, + {packet: &model.Packet{ID: 9}}, + {packet: &model.Packet{ID: 10}}, + {packet: &model.Packet{ID: 11}}, + {packet: &model.Packet{ID: 12}}, + }), + }, + args: args{ + p: &model.Packet{ID: 13}, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &reliableSender{ + logger: log.Log, + inFlight: tt.fields.inFlight, + } + if got := r.TryInsertOutgoingPacket(tt.args.p); got != tt.want { + t.Errorf("reliableOutgoing.TryInsertOutgoingPacket() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_reliableOutgoing_NextPacketIDsToACK(t *testing.T) { + log.SetLevel(log.DebugLevel) + + type fields struct { + pendingACKsToSend []model.PacketID + } + tests := []struct { + name string + fields fields + want []model.PacketID + }{ + { + name: "empty array", + fields: fields{ + pendingACKsToSend: []model.PacketID{}, + }, + want: []model.PacketID{}, + }, + { + name: "single element", + fields: fields{ + pendingACKsToSend: []model.PacketID{11}, + }, + want: []model.PacketID{11}, + }, + { + name: "tree elements", + fields: fields{ + pendingACKsToSend: []model.PacketID{11, 12, 13}, + }, + want: []model.PacketID{11, 12, 13}, + }, + { + name: "five elements", + fields: fields{ + pendingACKsToSend: []model.PacketID{11, 12, 13, 14, 15}, + }, + want: []model.PacketID{11, 12, 13, 14}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &reliableSender{ + logger: log.Log, + pendingACKsToSend: tt.fields.pendingACKsToSend, + } + if got := r.NextPacketIDsToACK(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("reliableOutgoing.NextPacketIDsToACK() = %v, want %v", got, tt.want) + } + }) + } +} + +// +// tests for reliableIncoming +// + +// testIncomingPacket is a sequentialPacket for testing incomingPackets +type testIncomingPacket struct { + id model.PacketID + acks []model.PacketID +} + +func (ip *testIncomingPacket) ID() model.PacketID { + return ip.id +} + +func (ip *testIncomingPacket) ExtractACKs() []model.PacketID { + return ip.acks +} + +func (ip *testIncomingPacket) Packet() *model.Packet { + return &model.Packet{ID: ip.id} +} + +var _ sequentialPacket = &testIncomingPacket{} + +func Test_reliableQueue_MaybeInsertIncoming(t *testing.T) { + log.SetLevel(log.DebugLevel) + + type fields struct { + incomingPackets incomingSequence + } + type args struct { + p *testIncomingPacket + } + tests := []struct { + name string + fields fields + args args + want bool + }{ + { + name: "empty incoming, insert one", + fields: fields{ + incomingPackets: []sequentialPacket{}, + }, + args: args{ + &testIncomingPacket{id: 1}, + }, + want: true, + }, + { + name: "almost full incoming, insert one", + fields: fields{ + incomingPackets: []sequentialPacket{ + &testIncomingPacket{id: 1}, + &testIncomingPacket{id: 2}, + &testIncomingPacket{id: 3}, + &testIncomingPacket{id: 4}, + &testIncomingPacket{id: 5}, + &testIncomingPacket{id: 6}, + &testIncomingPacket{id: 7}, + &testIncomingPacket{id: 8}, + &testIncomingPacket{id: 9}, + &testIncomingPacket{id: 10}, + &testIncomingPacket{id: 11}, + }, + }, + args: args{ + &testIncomingPacket{id: 12}, + }, + want: true, + }, + { + name: "full incoming, cannot insert", + fields: fields{ + incomingPackets: []sequentialPacket{ + &testIncomingPacket{id: 1}, + &testIncomingPacket{id: 2}, + &testIncomingPacket{id: 3}, + &testIncomingPacket{id: 4}, + &testIncomingPacket{id: 5}, + &testIncomingPacket{id: 6}, + &testIncomingPacket{id: 7}, + &testIncomingPacket{id: 8}, + &testIncomingPacket{id: 9}, + &testIncomingPacket{id: 10}, + &testIncomingPacket{id: 11}, + &testIncomingPacket{id: 12}, + }, + }, + args: args{ + &testIncomingPacket{id: 13}, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &reliableIncoming{ + logger: log.Log, + incomingPackets: tt.fields.incomingPackets, + } + if got := r.MaybeInsertIncoming(tt.args.p.Packet()); got != tt.want { + t.Errorf("reliableQueue.MaybeInsertIncoming() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_reliableQueue_NextIncomingSequence(t *testing.T) { + log.SetLevel(log.DebugLevel) + + type fields struct { + lastConsumed model.PacketID + incomingPackets incomingSequence + } + tests := []struct { + name string + fields fields + want incomingSequence + }{ + { + name: "empty sequence", + fields: fields{ + incomingPackets: []sequentialPacket{}, + lastConsumed: model.PacketID(0), + }, + want: []sequentialPacket{}, + }, + { + name: "single packet", + fields: fields{ + lastConsumed: model.PacketID(0), + incomingPackets: []sequentialPacket{ + &testIncomingPacket{id: 1}, + }, + }, + want: []sequentialPacket{ + &testIncomingPacket{id: 1}, + }, + }, + { + name: "series of sequential packets", + fields: fields{ + lastConsumed: model.PacketID(0), + incomingPackets: []sequentialPacket{ + &testIncomingPacket{id: 1}, + &testIncomingPacket{id: 2}, + &testIncomingPacket{id: 3}, + }, + }, + want: []sequentialPacket{ + &testIncomingPacket{id: 1}, + &testIncomingPacket{id: 2}, + &testIncomingPacket{id: 3}, + }, + }, + { + name: "series of sequential packets with hole", + fields: fields{ + lastConsumed: model.PacketID(0), + incomingPackets: []sequentialPacket{ + &testIncomingPacket{id: 1}, + &testIncomingPacket{id: 2}, + &testIncomingPacket{id: 3}, + &testIncomingPacket{id: 5}, + }, + }, + want: []sequentialPacket{ + &testIncomingPacket{id: 1}, + &testIncomingPacket{id: 2}, + &testIncomingPacket{id: 3}, + }, + }, + { + name: "series of sequential packets with hole, lastConsumed higher", + fields: fields{ + lastConsumed: model.PacketID(10), + incomingPackets: []sequentialPacket{ + &testIncomingPacket{id: 1}, + &testIncomingPacket{id: 2}, + &testIncomingPacket{id: 3}, + &testIncomingPacket{id: 5}, + }, + }, + want: []sequentialPacket{}, + }, + { + name: "series of sequential packets with hole, lastConsumed higher, some above", + fields: fields{ + lastConsumed: model.PacketID(10), + incomingPackets: []sequentialPacket{ + &testIncomingPacket{id: 1}, + &testIncomingPacket{id: 2}, + &testIncomingPacket{id: 10}, + &testIncomingPacket{id: 11}, + &testIncomingPacket{id: 12}, + &testIncomingPacket{id: 20}, + }, + }, + want: []sequentialPacket{ + &testIncomingPacket{id: 11}, + &testIncomingPacket{id: 12}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &reliableIncoming{ + lastConsumed: tt.fields.lastConsumed, + incomingPackets: tt.fields.incomingPackets, + } + if got := r.NextIncomingSequence(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("reliableQueue.NextIncomingSequence() = %v, want %v", got, tt.want) + } + }) + } +} From aebf87ec93bc423b128221e40028a3882cfa2bf4 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Wed, 24 Jan 2024 17:46:21 +0100 Subject: [PATCH 02/78] rename & reordering --- internal/reliabletransport/packets.go | 1 - internal/reliabletransport/receiver.go | 101 ++++++- .../reliabletransport/reliabletransport.go | 268 ------------------ internal/reliabletransport/sender.go | 110 +++++++ internal/reliabletransport/sender_test.go | 4 +- internal/reliabletransport/service.go | 79 ++++++ 6 files changed, 283 insertions(+), 280 deletions(-) delete mode 100644 internal/reliabletransport/reliabletransport.go create mode 100644 internal/reliabletransport/service.go diff --git a/internal/reliabletransport/packets.go b/internal/reliabletransport/packets.go index f77c9782..af2824a3 100644 --- a/internal/reliabletransport/packets.go +++ b/internal/reliabletransport/packets.go @@ -38,7 +38,6 @@ func (p *inFlightPacket) ExtractACKs() []model.PacketID { return p.packet.ACKs } -// TODO leaving Fast retransmission out for now. // ACKForHigherPacket increments the number of acks received for a higher pid than this packet. This will influence the fast rexmit selection algorithm. func (p *inFlightPacket) ACKForHigherPacket() { p.higherACKs += 1 diff --git a/internal/reliabletransport/receiver.go b/internal/reliabletransport/receiver.go index ba1b27a4..69fa842b 100644 --- a/internal/reliabletransport/receiver.go +++ b/internal/reliabletransport/receiver.go @@ -1,18 +1,101 @@ package reliabletransport import ( + "bytes" + "fmt" "sort" "github.com/ooni/minivpn/internal/model" ) +// moveUpWorker moves packets up the stack (receiver) +func (ws *workersState) moveUpWorker() { + workerName := fmt.Sprintf("%s: moveUpWorker", serviceName) + + defer func() { + ws.workersManager.OnWorkerDone(workerName) + ws.workersManager.StartShutdown() + }() + + ws.logger.Debugf("%s: started", workerName) + + receiver := newReliableReceiver(ws.logger, ws.incomingSeen) + + // TODO: do we need to have notifications from the control channel + // to reset state or can we do this implicitly? + + for { + // POSSIBLY BLOCK reading a packet to move up the stack + // or POSSIBLY BLOCK waiting for notifications + select { + case packet := <-ws.muxerToReliable: + ws.logger.Infof( + "< %s localID=%x remoteID=%x [%d bytes]", + packet.Opcode, + packet.LocalSessionID, + packet.RemoteSessionID, + len(packet.Payload), + ) + + // drop a packet that is not for our session + if !bytes.Equal(packet.LocalSessionID[:], ws.sessionManager.RemoteSessionID()) { + ws.logger.Warnf( + "%s: packet with invalid RemoteSessionID: expected %x; got %x", + workerName, + ws.sessionManager.LocalSessionID(), + packet.RemoteSessionID, + ) + continue + } + + // TODO: drop a packet too far away (we can use lastConsumed) + + // possibly ACK the incoming packet + // TODO: move this responsibility to the sender. + if err := ws.maybeACK(packet); err != nil { + ws.logger.Warnf("%s: cannot ACK packet: %s", workerName, err.Error()) + continue + } + + ws.logger.Debugf( + "notify: ", + packet.ID, + packet.ACKs, + ) + + // TODO: possibly refactor so that the writing to the channel happens here + // the fact this channel write is hidden makes following this harder + // TODO: notify before dropping? + receiver.NotifySeen(packet) + + if inserted := receiver.MaybeInsertIncoming(packet); !inserted { + continue + } + + // TODO drop first ------------------------------------------------ + ready := receiver.NextIncomingSequence() + for _, nextPacket := range ready { + // POSSIBLY BLOCK delivering to the upper layer + select { + case ws.reliableToControl <- nextPacket.Packet(): + case <-ws.workersManager.ShouldShutdown(): + return + } + } + + case <-ws.workersManager.ShouldShutdown(): + return + } + } +} + // // incomingPacketHandler implementation. // -// TODO rename to receiver -// reliableIncoming is the receiver part that sees incoming packets moving up the stack. -type reliableIncoming struct { +// reliableReceiver is the receiver part that sees incoming packets moving up the stack. +// Please use the constructor `newReliableReceiver()` +type reliableReceiver struct { // logger is the logger to use logger model.Logger @@ -26,8 +109,8 @@ type reliableIncoming struct { lastConsumed model.PacketID } -func newReliableIncoming(logger model.Logger, i chan incomingPacketSeen) *reliableIncoming { - return &reliableIncoming{ +func newReliableReceiver(logger model.Logger, i chan incomingPacketSeen) *reliableReceiver { + return &reliableReceiver{ logger: logger, incomingPackets: []sequentialPacket{}, incomingSeen: i, @@ -36,7 +119,7 @@ func newReliableIncoming(logger model.Logger, i chan incomingPacketSeen) *reliab } // NotifySeen sends a incomingPacketSeen object to the shared channel where the sender will read it. -func (r *reliableIncoming) NotifySeen(p *model.Packet) bool { +func (r *reliableReceiver) NotifySeen(p *model.Packet) bool { incoming := incomingPacketSeen{ id: p.ID, acks: p.ACKs, @@ -49,7 +132,7 @@ func (r *reliableIncoming) NotifySeen(p *model.Packet) bool { } -func (r *reliableIncoming) MaybeInsertIncoming(p *model.Packet) bool { +func (r *reliableReceiver) MaybeInsertIncoming(p *model.Packet) bool { // we drop if at capacity, by default double the size of the outgoing buffer if len(r.incomingPackets) >= RELIABLE_RECV_BUFFER_SIZE { r.logger.Warnf("dropping packet, buffer full with len %v", len(r.incomingPackets)) @@ -62,7 +145,7 @@ func (r *reliableIncoming) MaybeInsertIncoming(p *model.Packet) bool { return true } -func (r *reliableIncoming) NextIncomingSequence() incomingSequence { +func (r *reliableReceiver) NextIncomingSequence() incomingSequence { last := r.lastConsumed ready := make([]sequentialPacket, 0, RELIABLE_RECV_BUFFER_SIZE) @@ -91,7 +174,7 @@ func (r *reliableIncoming) NextIncomingSequence() incomingSequence { } // assert that reliableIncoming implements incomingPacketHandler -var _ incomingPacketHandler = &reliableIncoming{ +var _ incomingPacketHandler = &reliableReceiver{ logger: nil, incomingPackets: []sequentialPacket{}, incomingSeen: make(chan<- incomingPacketSeen), diff --git a/internal/reliabletransport/reliabletransport.go b/internal/reliabletransport/reliabletransport.go deleted file mode 100644 index e588156c..00000000 --- a/internal/reliabletransport/reliabletransport.go +++ /dev/null @@ -1,268 +0,0 @@ -// Package reliabletransport implements the reliable transport. -package reliabletransport - -import ( - "bytes" - "fmt" - "time" - - "github.com/ooni/minivpn/internal/model" - "github.com/ooni/minivpn/internal/session" - "github.com/ooni/minivpn/internal/workers" -) - -var ( - serviceName = "reliabletransport" -) - -// Service is the reliable service. Make sure you initialize -// the channels before invoking [Service.StartWorkers]. -type Service struct { - // DataOrControlToMuxer is a shared channel that moves packets down to the muxer - DataOrControlToMuxer *chan *model.Packet - - // ControlToReliable moves packets down to us - ControlToReliable chan *model.Packet - - // MuxerToReliable moves packets up to us - MuxerToReliable chan *model.Packet - - // ReliableToControl moves packets up from us to the control layer above - ReliableToControl *chan *model.Packet -} - -// StartWorkers starts the reliable-transport workers. See the [ARCHITECTURE] -// file for more information about the reliable-transport workers. -// -// [ARCHITECTURE]: https://github.com/ooni/minivpn/blob/main/ARCHITECTURE.md -func (s *Service) StartWorkers( - logger model.Logger, - workersManager *workers.Manager, - sessionManager *session.Manager, -) { - - ws := &workersState{ - logger: logger, - incomingSeen: make(chan incomingPacketSeen, 20), - dataOrControlToMuxer: *s.DataOrControlToMuxer, - controlToReliable: s.ControlToReliable, - muxerToReliable: s.MuxerToReliable, - reliableToControl: *s.ReliableToControl, - sessionManager: sessionManager, - workersManager: workersManager, - } - workersManager.StartWorker(ws.moveUpWorker) - workersManager.StartWorker(ws.moveDownWorker) -} - -// workersState contains the reliable workers state -type workersState struct { - // logger is the logger to use - logger model.Logger - - // incomingSeen ins the shared channel to connect sender and receiver goroutines. - incomingSeen chan incomingPacketSeen - - // dataOrControlToMuxer is the channel where we write packets going down the stack. - dataOrControlToMuxer chan<- *model.Packet - - // controlToReliable is the channel from which we read packets going down the stack. - controlToReliable <-chan *model.Packet - - // muxerToReliable is the channel from which we read packets going up the stack. - muxerToReliable <-chan *model.Packet - - // reliableToControl is the channel where we write packets going up the stack. - reliableToControl chan<- *model.Packet - - // sessionManager manages the OpenVPN session. - sessionManager *session.Manager - - // workersManager controls the workers lifecycle. - workersManager *workers.Manager -} - -// moveUpWorker moves packets up the stack (receiver) -// TODO: move worker to receiver.go -func (ws *workersState) moveUpWorker() { - workerName := fmt.Sprintf("%s: moveUpWorker", serviceName) - - defer func() { - ws.workersManager.OnWorkerDone(workerName) - ws.workersManager.StartShutdown() - }() - - ws.logger.Debugf("%s: started", workerName) - - receiver := newReliableIncoming(ws.logger, ws.incomingSeen) - - // TODO: do we need to have notifications from the control channel - // to reset state or can we do this implicitly? - - for { - // POSSIBLY BLOCK reading a packet to move up the stack - // or POSSIBLY BLOCK waiting for notifications - select { - case packet := <-ws.muxerToReliable: - ws.logger.Infof( - "< %s localID=%x remoteID=%x [%d bytes]", - packet.Opcode, - packet.LocalSessionID, - packet.RemoteSessionID, - len(packet.Payload), - ) - - // drop a packet that is not for our session - if !bytes.Equal(packet.LocalSessionID[:], ws.sessionManager.RemoteSessionID()) { - ws.logger.Warnf( - "%s: packet with invalid RemoteSessionID: expected %x; got %x", - workerName, - ws.sessionManager.LocalSessionID(), - packet.RemoteSessionID, - ) - continue - } - - // TODO: drop a packet too far away (we can use lastConsumed) - - // possibly ACK the incoming packet - // TODO: move this responsibility to the sender. - if err := ws.maybeACK(packet); err != nil { - ws.logger.Warnf("%s: cannot ACK packet: %s", workerName, err.Error()) - continue - } - - ws.logger.Debugf( - "notify: ", - packet.ID, - packet.ACKs, - ) - - // TODO: possibly refactor so that the writing to the channel happens here - // the fact this channel write is hidden makes following this harder - // TODO: notify before dropping? - receiver.NotifySeen(packet) - - if inserted := receiver.MaybeInsertIncoming(packet); !inserted { - continue - } - - // TODO drop first ------------------------------------------------ - ready := receiver.NextIncomingSequence() - for _, nextPacket := range ready { - // POSSIBLY BLOCK delivering to the upper layer - select { - case ws.reliableToControl <- nextPacket.Packet(): - case <-ws.workersManager.ShouldShutdown(): - return - } - } - - case <-ws.workersManager.ShouldShutdown(): - return - } - } -} - -// moveDownWorker moves packets down the stack (sender) -// TODO move the worker to sender.go -func (ws *workersState) moveDownWorker() { - workerName := fmt.Sprintf("%s: moveDownWorker", serviceName) - - defer func() { - ws.workersManager.OnWorkerDone(workerName) - ws.workersManager.StartShutdown() - }() - - ws.logger.Debugf("%s: started", workerName) - - sender := newReliableSender(ws.logger, ws.incomingSeen) - ticker := time.NewTicker(time.Duration(SENDER_TICKER_MS) * time.Millisecond) - - for { - // POSSIBLY BLOCK reading the next packet we should move down the stack - select { - case packet := <-ws.controlToReliable: - ws.logger.Infof( - "> %s localID=%x remoteID=%x [%d bytes]", - packet.Opcode, - packet.LocalSessionID, - packet.RemoteSessionID, - len(packet.Payload), - ) - - sender.TryInsertOutgoingPacket(packet) - // schedule for inmediate wakeup - // so that the ticker will wakeup and see if there's anything pending to be sent. - ticker.Reset(time.Nanosecond) - - case incomingSeen := <-sender.incomingSeen: - // possibly evict any acked packet - sender.OnIncomingPacketSeen(incomingSeen) - - // schedule for inmediate wakeup, because we probably need to update ACKs - ticker.Reset(time.Nanosecond) - - // TODO need to ACK here if no packets pending. - // I think we can just call withExpiredDeadline and ACK if len(expired) is 0 - - case <-ticker.C: - // First of all, we reset the ticker to the next timeout. - // By default, that's going to return one minute if there are no packets - // in the in-flight queue. - - // nearestDeadlineTo(now) ensures that we do not receive a time before now, and - // that increments the passed moment by an epsilon if all deadlines are expired, - // so it should be safe to reset the ticker with that timeout. - now := time.Now() - timeout := inflightSequence(sender.inFlight).nearestDeadlineTo(now) - - ws.logger.Debug("") - ws.logger.Debugf("next wakeup: %v", timeout.Sub(now)) - - ticker.Reset(timeout.Sub(now)) - - // we flush everything that is ready to be sent. - scheduledNow := inflightSequence(sender.inFlight).readyToSend(now) - ws.logger.Debugf(":: GOT %d packets to send\n", len(scheduledNow)) - - for _, p := range scheduledNow { - p.ScheduleForRetransmission(now) - // TODO ------------------------------------------- - // ideally, we want to append any pending ACKs here - select { - case ws.dataOrControlToMuxer <- p.packet: - ws.logger.Debugf("==> sent packet with ID: %v", p.packet.ID) - case <-ws.workersManager.ShouldShutdown(): - return - } - } - - case <-ws.workersManager.ShouldShutdown(): - return - } - } -} - -// maybeACK sends an ACK when needed. -func (ws *workersState) maybeACK(packet *model.Packet) error { - // currently we are ACKing every packet - // TODO: implement better ACKing strategy - this is basically moving the responsibility - // to the sender, and then either appending up to 4 ACKs to the ACK array of an outgoing - // packet, or sending a single ACK (if there's nothing pending to be sent). - - // this function will fail if we don't know the remote session ID - ACK, err := ws.sessionManager.NewACKForPacket(packet) - if err != nil { - return err - } - - // move the packet down. CAN BLOCK writing to the shared channel to muxer. - select { - case ws.dataOrControlToMuxer <- ACK: - ws.logger.Debugf("ack for remote packet id: %d", packet.ID) - return nil - case <-ws.workersManager.ShouldShutdown(): - return workers.ErrShutdown - } -} diff --git a/internal/reliabletransport/sender.go b/internal/reliabletransport/sender.go index a68ecd3f..9632f611 100644 --- a/internal/reliabletransport/sender.go +++ b/internal/reliabletransport/sender.go @@ -1,11 +1,98 @@ package reliabletransport import ( + "fmt" "sort" + "time" "github.com/ooni/minivpn/internal/model" + "github.com/ooni/minivpn/internal/workers" ) +// moveDownWorker moves packets down the stack (sender) +// TODO move the worker to sender.go +func (ws *workersState) moveDownWorker() { + workerName := fmt.Sprintf("%s: moveDownWorker", serviceName) + + defer func() { + ws.workersManager.OnWorkerDone(workerName) + ws.workersManager.StartShutdown() + }() + + ws.logger.Debugf("%s: started", workerName) + + sender := newReliableSender(ws.logger, ws.incomingSeen) + ticker := time.NewTicker(time.Duration(SENDER_TICKER_MS) * time.Millisecond) + + for { + // POSSIBLY BLOCK reading the next packet we should move down the stack + select { + case packet := <-ws.controlToReliable: + ws.logger.Infof( + "> %s localID=%x remoteID=%x [%d bytes]", + packet.Opcode, + packet.LocalSessionID, + packet.RemoteSessionID, + len(packet.Payload), + ) + + sender.TryInsertOutgoingPacket(packet) + // schedule for inmediate wakeup + // so that the ticker will wakeup and see if there's anything pending to be sent. + ticker.Reset(time.Nanosecond) + + case incomingSeen := <-sender.incomingSeen: + // possibly evict any acked packet + sender.OnIncomingPacketSeen(incomingSeen) + + // schedule for inmediate wakeup, because we probably need to update ACKs + ticker.Reset(time.Nanosecond) + + // TODO need to ACK here if no packets pending. + // I think we can just call withExpiredDeadline and ACK if len(expired) is 0 + + case <-ticker.C: + // First of all, we reset the ticker to the next timeout. + // By default, that's going to return one minute if there are no packets + // in the in-flight queue. + + // nearestDeadlineTo(now) ensures that we do not receive a time before now, and + // that increments the passed moment by an epsilon if all deadlines are expired, + // so it should be safe to reset the ticker with that timeout. + now := time.Now() + timeout := inflightSequence(sender.inFlight).nearestDeadlineTo(now) + + ws.logger.Debug("") + ws.logger.Debugf("next wakeup: %v", timeout.Sub(now)) + + ticker.Reset(timeout.Sub(now)) + + // we flush everything that is ready to be sent. + scheduledNow := inflightSequence(sender.inFlight).readyToSend(now) + ws.logger.Debugf(":: GOT %d packets to send\n", len(scheduledNow)) + + for _, p := range scheduledNow { + p.ScheduleForRetransmission(now) + // TODO ------------------------------------------- + // ideally, we want to append any pending ACKs here + select { + case ws.dataOrControlToMuxer <- p.packet: + ws.logger.Debugf("==> sent packet with ID: %v", p.packet.ID) + case <-ws.workersManager.ShouldShutdown(): + return + } + } + + case <-ws.workersManager.ShouldShutdown(): + return + } + } +} + +// +// outgoingPacketHandler implementation. +// + // reliableSender keeps state about the outgoing packet queue, and implements outgoingPacketHandler. // Please use the constructor `newReliableSender()` type reliableSender struct { @@ -104,3 +191,26 @@ func (r *reliableSender) OnIncomingPacketSeen(ips incomingPacketSeen) { } var _ outgoingPacketHandler = &reliableSender{} + +// maybeACK sends an ACK when needed. +func (ws *workersState) maybeACK(packet *model.Packet) error { + // currently we are ACKing every packet + // TODO: implement better ACKing strategy - this is basically moving the responsibility + // to the sender, and then either appending up to 4 ACKs to the ACK array of an outgoing + // packet, or sending a single ACK (if there's nothing pending to be sent). + + // this function will fail if we don't know the remote session ID + ACK, err := ws.sessionManager.NewACKForPacket(packet) + if err != nil { + return err + } + + // move the packet down. CAN BLOCK writing to the shared channel to muxer. + select { + case ws.dataOrControlToMuxer <- ACK: + ws.logger.Debugf("ack for remote packet id: %d", packet.ID) + return nil + case <-ws.workersManager.ShouldShutdown(): + return workers.ErrShutdown + } +} diff --git a/internal/reliabletransport/sender_test.go b/internal/reliabletransport/sender_test.go index 5ff72170..20b94e18 100644 --- a/internal/reliabletransport/sender_test.go +++ b/internal/reliabletransport/sender_test.go @@ -224,7 +224,7 @@ func Test_reliableQueue_MaybeInsertIncoming(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - r := &reliableIncoming{ + r := &reliableReceiver{ logger: log.Log, incomingPackets: tt.fields.incomingPackets, } @@ -334,7 +334,7 @@ func Test_reliableQueue_NextIncomingSequence(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - r := &reliableIncoming{ + r := &reliableReceiver{ lastConsumed: tt.fields.lastConsumed, incomingPackets: tt.fields.incomingPackets, } diff --git a/internal/reliabletransport/service.go b/internal/reliabletransport/service.go new file mode 100644 index 00000000..e67c1f5c --- /dev/null +++ b/internal/reliabletransport/service.go @@ -0,0 +1,79 @@ +// Package reliabletransport implements the reliable transport. +package reliabletransport + +import ( + "github.com/ooni/minivpn/internal/model" + "github.com/ooni/minivpn/internal/session" + "github.com/ooni/minivpn/internal/workers" +) + +var ( + serviceName = "reliabletransport" +) + +// Service is the reliable service. Make sure you initialize +// the channels before invoking [Service.StartWorkers]. +type Service struct { + // DataOrControlToMuxer is a shared channel that moves packets down to the muxer + DataOrControlToMuxer *chan *model.Packet + + // ControlToReliable moves packets down to us + ControlToReliable chan *model.Packet + + // MuxerToReliable moves packets up to us + MuxerToReliable chan *model.Packet + + // ReliableToControl moves packets up from us to the control layer above + ReliableToControl *chan *model.Packet +} + +// StartWorkers starts the reliable-transport workers. See the [ARCHITECTURE] +// file for more information about the reliable-transport workers. +// +// [ARCHITECTURE]: https://github.com/ooni/minivpn/blob/main/ARCHITECTURE.md +func (s *Service) StartWorkers( + logger model.Logger, + workersManager *workers.Manager, + sessionManager *session.Manager, +) { + + ws := &workersState{ + logger: logger, + incomingSeen: make(chan incomingPacketSeen, 20), + dataOrControlToMuxer: *s.DataOrControlToMuxer, + controlToReliable: s.ControlToReliable, + muxerToReliable: s.MuxerToReliable, + reliableToControl: *s.ReliableToControl, + sessionManager: sessionManager, + workersManager: workersManager, + } + workersManager.StartWorker(ws.moveUpWorker) + workersManager.StartWorker(ws.moveDownWorker) +} + +// workersState contains the reliable workers state +type workersState struct { + // logger is the logger to use + logger model.Logger + + // incomingSeen ins the shared channel to connect sender and receiver goroutines. + incomingSeen chan incomingPacketSeen + + // dataOrControlToMuxer is the channel where we write packets going down the stack. + dataOrControlToMuxer chan<- *model.Packet + + // controlToReliable is the channel from which we read packets going down the stack. + controlToReliable <-chan *model.Packet + + // muxerToReliable is the channel from which we read packets going up the stack. + muxerToReliable <-chan *model.Packet + + // reliableToControl is the channel where we write packets going up the stack. + reliableToControl chan<- *model.Packet + + // sessionManager manages the OpenVPN session. + sessionManager *session.Manager + + // workersManager controls the workers lifecycle. + workersManager *workers.Manager +} From b420eb178c65618bc12f9de96e898a65d0eeb43b Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Wed, 24 Jan 2024 18:57:11 +0100 Subject: [PATCH 03/78] wip --- cmd/minivpn2/iface.go | 30 ++++++ cmd/minivpn2/main.go | 144 +++++++++++++++++++++++++ internal/reliabletransport/receiver.go | 60 +++++++---- internal/reliabletransport/sender.go | 58 +++++++--- 4 files changed, 255 insertions(+), 37 deletions(-) create mode 100644 cmd/minivpn2/iface.go create mode 100644 cmd/minivpn2/main.go diff --git a/cmd/minivpn2/iface.go b/cmd/minivpn2/iface.go new file mode 100644 index 00000000..4102d938 --- /dev/null +++ b/cmd/minivpn2/iface.go @@ -0,0 +1,30 @@ +package main + +import ( + "fmt" + "net" +) + +func getInterfaceByIP(ipAddr string) (*net.Interface, error) { + interfaces, err := net.Interfaces() + if err != nil { + return nil, err + } + + for _, iface := range interfaces { + addrs, err := iface.Addrs() + if err != nil { + return nil, err + } + + for _, addr := range addrs { + if ipNet, ok := addr.(*net.IPNet); ok && !ipNet.IP.IsLoopback() { + if ipNet.IP.String() == ipAddr { + return &iface, nil + } + } + } + } + + return nil, fmt.Errorf("interface with IP %s not found", ipAddr) +} diff --git a/cmd/minivpn2/main.go b/cmd/minivpn2/main.go new file mode 100644 index 00000000..a6004a2f --- /dev/null +++ b/cmd/minivpn2/main.go @@ -0,0 +1,144 @@ +package main + +import ( + "context" + "fmt" + "net" + "os" + "os/exec" + + "github.com/apex/log" + "github.com/ooni/minivpn/internal/model" + "github.com/ooni/minivpn/internal/networkio" + "github.com/ooni/minivpn/internal/tun" + + "github.com/Doridian/water" + "github.com/jackpal/gateway" +) + +func runCmd(binaryPath string, args ...string) { + cmd := exec.Command(binaryPath, args...) + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout + cmd.Stdin = os.Stdin + err := cmd.Run() + if nil != err { + log.WithError(err).Warn("error running /sbin/ip") + } +} + +func runIP(args ...string) { + runCmd("/sbin/ip", args...) +} + +func runRoute(args ...string) { + runCmd("/sbin/route", args...) +} + +func main() { + log.SetLevel(log.DebugLevel) + + // parse the configuration file + options, err := model.ReadConfigFile(os.Args[1]) + if err != nil { + log.WithError(err).Fatal("NewOptionsFromFilePath") + } + log.Infof("parsed options: %s", options.ServerOptionsString()) + + // TODO(ainghazal): move the initialization step to an early phase and keep a ref in the muxer + if !options.HasAuthInfo() { + log.Fatal("options are missing auth info") + } + // connect to the server + dialer := networkio.NewDialer(log.Log, &net.Dialer{}) + ctx := context.Background() + endpoint := net.JoinHostPort(options.Remote, options.Port) + conn, err := dialer.DialContext(ctx, options.Proto.String(), endpoint) + if err != nil { + log.WithError(err).Fatal("dialer.DialContext") + } + + // The TLS will expire in 60 seconds by default, but we can pass + // a shorter timeout. + //ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + //defer cancel() + + // create a vpn tun Device + tunnel, err := tun.StartTUN(ctx, conn, options) + if err != nil { + log.WithError(err).Fatal("init error") + return + } + fmt.Printf("Local IP: %s\n", tunnel.LocalAddr()) + fmt.Printf("Gateway: %s\n", tunnel.RemoteAddr()) + + // create a tun interface on the OS + iface, err := water.New(water.Config{ + DeviceType: water.TUN, + }) + if err != nil { + log.WithError(err).Fatal("Unable to allocate TUN interface:") + } + + // TODO: investigate what's the maximum working MTU, additionally get it from flag. + MTU := 1420 + iface.SetMTU(MTU) + + localAddr := tunnel.LocalAddr().String() + remoteAddr := tunnel.RemoteAddr().String() + netMask := tunnel.NetMask() + + // discover local gateway IP, to + defaultGatewayIP, err := gateway.DiscoverGateway() + if err != nil { + log.Warn("could not discover default gateway IP, routes might be broken") + } + defaultInterfaceIP, err := gateway.DiscoverInterface() + if err != nil { + log.Warn("could not discover default route interface IP, routes might be broken") + } + defaultInterface, err := getInterfaceByIP(defaultInterfaceIP.String()) + if err != nil { + log.Warn("could not get default route interface, routes might be broken") + } + + if defaultGatewayIP != nil && defaultInterface != nil { + log.Infof("route add %s gw %v dev %s", options.Remote, defaultGatewayIP, defaultInterface.Name) + runRoute("add", options.Remote, "gw", defaultGatewayIP.String(), defaultInterface.Name) + } + + // we want the network CIDR for setting up the routes + network := &net.IPNet{ + IP: net.ParseIP(localAddr).Mask(netMask), + Mask: netMask, + } + + // configure the interface and bring it up + runIP("addr", "add", localAddr, "dev", iface.Name()) + runIP("link", "set", "dev", iface.Name(), "up") + runRoute("add", remoteAddr, "gw", localAddr) + runRoute("add", "-net", network.String(), "dev", iface.Name()) + runIP("route", "add", "default", "via", remoteAddr, "dev", iface.Name()) + + go func() { + for { + packet := make([]byte, 2000) + n, err := iface.Read(packet) + if err != nil { + log.WithError(err).Fatal("error reading from tun") + } + tunnel.Write(packet[:n]) + } + }() + go func() { + for { + packet := make([]byte, 2000) + n, err := tunnel.Read(packet) + if err != nil { + log.WithError(err).Fatal("error reading from tun") + } + iface.Write(packet[:n]) + } + }() + select {} +} diff --git a/internal/reliabletransport/receiver.go b/internal/reliabletransport/receiver.go index 69fa842b..8bffb67f 100644 --- a/internal/reliabletransport/receiver.go +++ b/internal/reliabletransport/receiver.go @@ -48,31 +48,31 @@ func (ws *workersState) moveUpWorker() { continue } - // TODO: drop a packet too far away (we can use lastConsumed) - // possibly ACK the incoming packet // TODO: move this responsibility to the sender. - if err := ws.maybeACK(packet); err != nil { - ws.logger.Warnf("%s: cannot ACK packet: %s", workerName, err.Error()) + /* + if err := ws.maybeACK(packet); err != nil { + ws.logger.Warnf("%s: cannot ACK packet: %s", workerName, err.Error()) + continue + } + */ + + if inserted := receiver.MaybeInsertIncoming(packet); !inserted { + // this packet was not inserted in the queue: we drop it continue } - ws.logger.Debugf( - "notify: ", - packet.ID, - packet.ACKs, - ) - // TODO: possibly refactor so that the writing to the channel happens here // the fact this channel write is hidden makes following this harder - // TODO: notify before dropping? - receiver.NotifySeen(packet) - - if inserted := receiver.MaybeInsertIncoming(packet); !inserted { - continue + // receiver.NotifySeen(packet) + seenPacket, shouldDrop := receiver.newIncomingPacketSeen(packet) + switch shouldDrop { + case true: + receiver.logger.Warnf("got packet id %v, but last consumed is %v (dropping)\n", packet.ID, receiver.lastConsumed) + case false: + ws.incomingSeen <- seenPacket } - // TODO drop first ------------------------------------------------ ready := receiver.NextIncomingSequence() for _, nextPacket := range ready { // POSSIBLY BLOCK delivering to the upper layer @@ -109,6 +109,11 @@ type reliableReceiver struct { lastConsumed model.PacketID } +// NotifySeen implements incomingPacketHandler. +func (*reliableReceiver) NotifySeen(*model.Packet) bool { + panic("unimplemented") +} + func newReliableReceiver(logger model.Logger, i chan incomingPacketSeen) *reliableReceiver { return &reliableReceiver{ logger: logger, @@ -119,6 +124,7 @@ func newReliableReceiver(logger model.Logger, i chan incomingPacketSeen) *reliab } // NotifySeen sends a incomingPacketSeen object to the shared channel where the sender will read it. +/* func (r *reliableReceiver) NotifySeen(p *model.Packet) bool { incoming := incomingPacketSeen{ id: p.ID, @@ -131,6 +137,7 @@ func (r *reliableReceiver) NotifySeen(p *model.Packet) bool { return true } +*/ func (r *reliableReceiver) MaybeInsertIncoming(p *model.Packet) bool { // we drop if at capacity, by default double the size of the outgoing buffer @@ -166,13 +173,26 @@ func (r *reliableReceiver) NextIncomingSequence() incomingSequence { } r.lastConsumed = last r.incomingPackets = keep - //if len(ready) != 0 { - //r.logger.Debugf(">> BUMP LAST CONSUMED TO %v", last) - //r.logger.Debugf(">> incoming now: %v", keep) - //} return ready } +func (r *reliableReceiver) newIncomingPacketSeen(p *model.Packet) (incomingPacketSeen, bool) { + shouldDrop := false + incomingPacket := incomingPacketSeen{ + id: p.ID, + acks: p.ACKs, + } + r.logger.Debugf( + "notify: ", + p.ID, + p.ACKs, + ) + if p.ID > 0 && p.ID <= r.lastConsumed { + shouldDrop = true + } + return incomingPacket, shouldDrop +} + // assert that reliableIncoming implements incomingPacketHandler var _ incomingPacketHandler = &reliableReceiver{ logger: nil, diff --git a/internal/reliabletransport/sender.go b/internal/reliabletransport/sender.go index 9632f611..71afcf83 100644 --- a/internal/reliabletransport/sender.go +++ b/internal/reliabletransport/sender.go @@ -41,15 +41,39 @@ func (ws *workersState) moveDownWorker() { // so that the ticker will wakeup and see if there's anything pending to be sent. ticker.Reset(time.Nanosecond) - case incomingSeen := <-sender.incomingSeen: + case seenPacket := <-sender.incomingSeen: // possibly evict any acked packet - sender.OnIncomingPacketSeen(incomingSeen) + sender.OnIncomingPacketSeen(seenPacket) - // schedule for inmediate wakeup, because we probably need to update ACKs - ticker.Reset(time.Nanosecond) + if seenPacket.id < sender.lastACKed { + continue + } + + now := time.Now() + + // this is quite arbitrary + tooLate := now.Add(1000 * time.Millisecond) - // TODO need to ACK here if no packets pending. - // I think we can just call withExpiredDeadline and ACK if len(expired) is 0 + nextTimeout := inflightSequence(sender.inFlight).nearestDeadlineTo(now) + + if nextTimeout.After(tooLate) { + // we don't want to wait so much, so we do send the ACK immediately. + if err := ws.doSendACK(&model.Packet{ID: seenPacket.id}); err != nil { + sender.lastACKed += 1 + } + + // TODO: ------------------------------------------------------------ + // discuss: how can we gauge the sending queue? should we peek what's + // if len(ws.controlToReliable) != 0 { + } else { + // we'll be fine by having these ACKs hitching a ride on the next outgoing packet + // that is scheduled to go soon anyways + fmt.Println(">>> SHOULD SEND SOON ENOUGH, APPEND ACK!--------------") + sender.pendingACKsToSend = append(sender.pendingACKsToSend, seenPacket.acks...) + // TODO: not needed anymore. + // and now we schedule for inmediate wakeup, because we probably need to update ACKs + // ticker.Reset(time.Nanosecond) + } case <-ticker.C: // First of all, we reset the ticker to the next timeout. @@ -96,8 +120,6 @@ func (ws *workersState) moveDownWorker() { // reliableSender keeps state about the outgoing packet queue, and implements outgoingPacketHandler. // Please use the constructor `newReliableSender()` type reliableSender struct { - // logger is the logger to use - logger model.Logger // incomingSeen is a channel where we receive notifications for incoming packets seen by the receiver. incomingSeen <-chan incomingPacketSeen @@ -105,6 +127,12 @@ type reliableSender struct { // inFlight is the array of in-flight packets. inFlight []*inFlightPacket + // lastACKed is the last packet ID from the remote that we have acked + lastACKed model.PacketID + + // logger is the logger to use + logger model.Logger + // pendingACKsToSend is the array of packets that we still need to ACK. pendingACKsToSend []model.PacketID } @@ -112,9 +140,10 @@ type reliableSender struct { // newReliableSender returns a new instance of reliableOutgoing. func newReliableSender(logger model.Logger, i chan incomingPacketSeen) *reliableSender { return &reliableSender{ - logger: logger, incomingSeen: i, inFlight: make([]*inFlightPacket, 0, RELIABLE_SEND_BUFFER_SIZE), + lastACKed: model.PacketID(0), + logger: logger, pendingACKsToSend: []model.PacketID{}, } } @@ -192,13 +221,8 @@ func (r *reliableSender) OnIncomingPacketSeen(ips incomingPacketSeen) { var _ outgoingPacketHandler = &reliableSender{} -// maybeACK sends an ACK when needed. -func (ws *workersState) maybeACK(packet *model.Packet) error { - // currently we are ACKing every packet - // TODO: implement better ACKing strategy - this is basically moving the responsibility - // to the sender, and then either appending up to 4 ACKs to the ACK array of an outgoing - // packet, or sending a single ACK (if there's nothing pending to be sent). - +// doSendACK sends an ACK when needed. +func (ws *workersState) doSendACK(packet *model.Packet) error { // this function will fail if we don't know the remote session ID ACK, err := ws.sessionManager.NewACKForPacket(packet) if err != nil { @@ -208,7 +232,7 @@ func (ws *workersState) maybeACK(packet *model.Packet) error { // move the packet down. CAN BLOCK writing to the shared channel to muxer. select { case ws.dataOrControlToMuxer <- ACK: - ws.logger.Debugf("ack for remote packet id: %d", packet.ID) + ws.logger.Debugf("====> ack for remote packet id: %d", packet.ID) return nil case <-ws.workersManager.ShouldShutdown(): return workers.ErrShutdown From add18bdb43414ffd019e38a67f0ece4aa7caabf1 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Wed, 24 Jan 2024 18:58:46 +0100 Subject: [PATCH 04/78] remove command, separate pr --- cmd/minivpn2/iface.go | 30 --------- cmd/minivpn2/main.go | 144 ------------------------------------------ 2 files changed, 174 deletions(-) delete mode 100644 cmd/minivpn2/iface.go delete mode 100644 cmd/minivpn2/main.go diff --git a/cmd/minivpn2/iface.go b/cmd/minivpn2/iface.go deleted file mode 100644 index 4102d938..00000000 --- a/cmd/minivpn2/iface.go +++ /dev/null @@ -1,30 +0,0 @@ -package main - -import ( - "fmt" - "net" -) - -func getInterfaceByIP(ipAddr string) (*net.Interface, error) { - interfaces, err := net.Interfaces() - if err != nil { - return nil, err - } - - for _, iface := range interfaces { - addrs, err := iface.Addrs() - if err != nil { - return nil, err - } - - for _, addr := range addrs { - if ipNet, ok := addr.(*net.IPNet); ok && !ipNet.IP.IsLoopback() { - if ipNet.IP.String() == ipAddr { - return &iface, nil - } - } - } - } - - return nil, fmt.Errorf("interface with IP %s not found", ipAddr) -} diff --git a/cmd/minivpn2/main.go b/cmd/minivpn2/main.go deleted file mode 100644 index a6004a2f..00000000 --- a/cmd/minivpn2/main.go +++ /dev/null @@ -1,144 +0,0 @@ -package main - -import ( - "context" - "fmt" - "net" - "os" - "os/exec" - - "github.com/apex/log" - "github.com/ooni/minivpn/internal/model" - "github.com/ooni/minivpn/internal/networkio" - "github.com/ooni/minivpn/internal/tun" - - "github.com/Doridian/water" - "github.com/jackpal/gateway" -) - -func runCmd(binaryPath string, args ...string) { - cmd := exec.Command(binaryPath, args...) - cmd.Stderr = os.Stderr - cmd.Stdout = os.Stdout - cmd.Stdin = os.Stdin - err := cmd.Run() - if nil != err { - log.WithError(err).Warn("error running /sbin/ip") - } -} - -func runIP(args ...string) { - runCmd("/sbin/ip", args...) -} - -func runRoute(args ...string) { - runCmd("/sbin/route", args...) -} - -func main() { - log.SetLevel(log.DebugLevel) - - // parse the configuration file - options, err := model.ReadConfigFile(os.Args[1]) - if err != nil { - log.WithError(err).Fatal("NewOptionsFromFilePath") - } - log.Infof("parsed options: %s", options.ServerOptionsString()) - - // TODO(ainghazal): move the initialization step to an early phase and keep a ref in the muxer - if !options.HasAuthInfo() { - log.Fatal("options are missing auth info") - } - // connect to the server - dialer := networkio.NewDialer(log.Log, &net.Dialer{}) - ctx := context.Background() - endpoint := net.JoinHostPort(options.Remote, options.Port) - conn, err := dialer.DialContext(ctx, options.Proto.String(), endpoint) - if err != nil { - log.WithError(err).Fatal("dialer.DialContext") - } - - // The TLS will expire in 60 seconds by default, but we can pass - // a shorter timeout. - //ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - //defer cancel() - - // create a vpn tun Device - tunnel, err := tun.StartTUN(ctx, conn, options) - if err != nil { - log.WithError(err).Fatal("init error") - return - } - fmt.Printf("Local IP: %s\n", tunnel.LocalAddr()) - fmt.Printf("Gateway: %s\n", tunnel.RemoteAddr()) - - // create a tun interface on the OS - iface, err := water.New(water.Config{ - DeviceType: water.TUN, - }) - if err != nil { - log.WithError(err).Fatal("Unable to allocate TUN interface:") - } - - // TODO: investigate what's the maximum working MTU, additionally get it from flag. - MTU := 1420 - iface.SetMTU(MTU) - - localAddr := tunnel.LocalAddr().String() - remoteAddr := tunnel.RemoteAddr().String() - netMask := tunnel.NetMask() - - // discover local gateway IP, to - defaultGatewayIP, err := gateway.DiscoverGateway() - if err != nil { - log.Warn("could not discover default gateway IP, routes might be broken") - } - defaultInterfaceIP, err := gateway.DiscoverInterface() - if err != nil { - log.Warn("could not discover default route interface IP, routes might be broken") - } - defaultInterface, err := getInterfaceByIP(defaultInterfaceIP.String()) - if err != nil { - log.Warn("could not get default route interface, routes might be broken") - } - - if defaultGatewayIP != nil && defaultInterface != nil { - log.Infof("route add %s gw %v dev %s", options.Remote, defaultGatewayIP, defaultInterface.Name) - runRoute("add", options.Remote, "gw", defaultGatewayIP.String(), defaultInterface.Name) - } - - // we want the network CIDR for setting up the routes - network := &net.IPNet{ - IP: net.ParseIP(localAddr).Mask(netMask), - Mask: netMask, - } - - // configure the interface and bring it up - runIP("addr", "add", localAddr, "dev", iface.Name()) - runIP("link", "set", "dev", iface.Name(), "up") - runRoute("add", remoteAddr, "gw", localAddr) - runRoute("add", "-net", network.String(), "dev", iface.Name()) - runIP("route", "add", "default", "via", remoteAddr, "dev", iface.Name()) - - go func() { - for { - packet := make([]byte, 2000) - n, err := iface.Read(packet) - if err != nil { - log.WithError(err).Fatal("error reading from tun") - } - tunnel.Write(packet[:n]) - } - }() - go func() { - for { - packet := make([]byte, 2000) - n, err := tunnel.Read(packet) - if err != nil { - log.WithError(err).Fatal("error reading from tun") - } - iface.Write(packet[:n]) - } - }() - select {} -} From 2c9761fb44600e82bbc96befe0273a2c7af833a5 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Wed, 24 Jan 2024 20:58:54 +0100 Subject: [PATCH 05/78] add command to get handshake logs --- Makefile | 3 + cmd/minivpn2/iface.go | 30 +++++++++ cmd/minivpn2/main.go | 144 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 177 insertions(+) create mode 100644 cmd/minivpn2/iface.go create mode 100644 cmd/minivpn2/main.go diff --git a/Makefile b/Makefile index eaad788c..99557266 100644 --- a/Makefile +++ b/Makefile @@ -27,6 +27,9 @@ build-ndt7: bootstrap: @./scripts/bootstrap-provider ${PROVIDER} +handshake_log: + @sudo ./minivpn2 data/${PROVIDER}/config 2>&1 | grep --text -E 'info [>|<] P_*|info \[@\]' + test: GOFLAGS='-count=1' go test -v ./... diff --git a/cmd/minivpn2/iface.go b/cmd/minivpn2/iface.go new file mode 100644 index 00000000..4102d938 --- /dev/null +++ b/cmd/minivpn2/iface.go @@ -0,0 +1,30 @@ +package main + +import ( + "fmt" + "net" +) + +func getInterfaceByIP(ipAddr string) (*net.Interface, error) { + interfaces, err := net.Interfaces() + if err != nil { + return nil, err + } + + for _, iface := range interfaces { + addrs, err := iface.Addrs() + if err != nil { + return nil, err + } + + for _, addr := range addrs { + if ipNet, ok := addr.(*net.IPNet); ok && !ipNet.IP.IsLoopback() { + if ipNet.IP.String() == ipAddr { + return &iface, nil + } + } + } + } + + return nil, fmt.Errorf("interface with IP %s not found", ipAddr) +} diff --git a/cmd/minivpn2/main.go b/cmd/minivpn2/main.go new file mode 100644 index 00000000..a6004a2f --- /dev/null +++ b/cmd/minivpn2/main.go @@ -0,0 +1,144 @@ +package main + +import ( + "context" + "fmt" + "net" + "os" + "os/exec" + + "github.com/apex/log" + "github.com/ooni/minivpn/internal/model" + "github.com/ooni/minivpn/internal/networkio" + "github.com/ooni/minivpn/internal/tun" + + "github.com/Doridian/water" + "github.com/jackpal/gateway" +) + +func runCmd(binaryPath string, args ...string) { + cmd := exec.Command(binaryPath, args...) + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout + cmd.Stdin = os.Stdin + err := cmd.Run() + if nil != err { + log.WithError(err).Warn("error running /sbin/ip") + } +} + +func runIP(args ...string) { + runCmd("/sbin/ip", args...) +} + +func runRoute(args ...string) { + runCmd("/sbin/route", args...) +} + +func main() { + log.SetLevel(log.DebugLevel) + + // parse the configuration file + options, err := model.ReadConfigFile(os.Args[1]) + if err != nil { + log.WithError(err).Fatal("NewOptionsFromFilePath") + } + log.Infof("parsed options: %s", options.ServerOptionsString()) + + // TODO(ainghazal): move the initialization step to an early phase and keep a ref in the muxer + if !options.HasAuthInfo() { + log.Fatal("options are missing auth info") + } + // connect to the server + dialer := networkio.NewDialer(log.Log, &net.Dialer{}) + ctx := context.Background() + endpoint := net.JoinHostPort(options.Remote, options.Port) + conn, err := dialer.DialContext(ctx, options.Proto.String(), endpoint) + if err != nil { + log.WithError(err).Fatal("dialer.DialContext") + } + + // The TLS will expire in 60 seconds by default, but we can pass + // a shorter timeout. + //ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + //defer cancel() + + // create a vpn tun Device + tunnel, err := tun.StartTUN(ctx, conn, options) + if err != nil { + log.WithError(err).Fatal("init error") + return + } + fmt.Printf("Local IP: %s\n", tunnel.LocalAddr()) + fmt.Printf("Gateway: %s\n", tunnel.RemoteAddr()) + + // create a tun interface on the OS + iface, err := water.New(water.Config{ + DeviceType: water.TUN, + }) + if err != nil { + log.WithError(err).Fatal("Unable to allocate TUN interface:") + } + + // TODO: investigate what's the maximum working MTU, additionally get it from flag. + MTU := 1420 + iface.SetMTU(MTU) + + localAddr := tunnel.LocalAddr().String() + remoteAddr := tunnel.RemoteAddr().String() + netMask := tunnel.NetMask() + + // discover local gateway IP, to + defaultGatewayIP, err := gateway.DiscoverGateway() + if err != nil { + log.Warn("could not discover default gateway IP, routes might be broken") + } + defaultInterfaceIP, err := gateway.DiscoverInterface() + if err != nil { + log.Warn("could not discover default route interface IP, routes might be broken") + } + defaultInterface, err := getInterfaceByIP(defaultInterfaceIP.String()) + if err != nil { + log.Warn("could not get default route interface, routes might be broken") + } + + if defaultGatewayIP != nil && defaultInterface != nil { + log.Infof("route add %s gw %v dev %s", options.Remote, defaultGatewayIP, defaultInterface.Name) + runRoute("add", options.Remote, "gw", defaultGatewayIP.String(), defaultInterface.Name) + } + + // we want the network CIDR for setting up the routes + network := &net.IPNet{ + IP: net.ParseIP(localAddr).Mask(netMask), + Mask: netMask, + } + + // configure the interface and bring it up + runIP("addr", "add", localAddr, "dev", iface.Name()) + runIP("link", "set", "dev", iface.Name(), "up") + runRoute("add", remoteAddr, "gw", localAddr) + runRoute("add", "-net", network.String(), "dev", iface.Name()) + runIP("route", "add", "default", "via", remoteAddr, "dev", iface.Name()) + + go func() { + for { + packet := make([]byte, 2000) + n, err := iface.Read(packet) + if err != nil { + log.WithError(err).Fatal("error reading from tun") + } + tunnel.Write(packet[:n]) + } + }() + go func() { + for { + packet := make([]byte, 2000) + n, err := tunnel.Read(packet) + if err != nil { + log.WithError(err).Fatal("error reading from tun") + } + iface.Write(packet[:n]) + } + }() + select {} +} From 51e2d344603acac46dc28311a52d9a62ba5cb0f0 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Wed, 24 Jan 2024 21:01:38 +0100 Subject: [PATCH 06/78] remove unused code --- internal/reliabletransport/interfaces.go | 4 ---- internal/reliabletransport/receiver.go | 21 --------------------- 2 files changed, 25 deletions(-) diff --git a/internal/reliabletransport/interfaces.go b/internal/reliabletransport/interfaces.go index 688c476d..85f1fb3a 100644 --- a/internal/reliabletransport/interfaces.go +++ b/internal/reliabletransport/interfaces.go @@ -37,10 +37,6 @@ type outgoingPacketHandler interface { // incomingPacketHandler knows how to deal with incoming packets (going up). type incomingPacketHandler interface { - // TODO: the interface needs to add ACKs() - // NotifySeen sends notifications about an incoming packet. - NotifySeen(*model.Packet) bool - // MaybeInsertIncoming will insert a given packet in the reliable // incoming queue if it passes a series of sanity checks. MaybeInsertIncoming(*model.Packet) bool diff --git a/internal/reliabletransport/receiver.go b/internal/reliabletransport/receiver.go index 8bffb67f..e3f7302d 100644 --- a/internal/reliabletransport/receiver.go +++ b/internal/reliabletransport/receiver.go @@ -109,11 +109,6 @@ type reliableReceiver struct { lastConsumed model.PacketID } -// NotifySeen implements incomingPacketHandler. -func (*reliableReceiver) NotifySeen(*model.Packet) bool { - panic("unimplemented") -} - func newReliableReceiver(logger model.Logger, i chan incomingPacketSeen) *reliableReceiver { return &reliableReceiver{ logger: logger, @@ -123,22 +118,6 @@ func newReliableReceiver(logger model.Logger, i chan incomingPacketSeen) *reliab } } -// NotifySeen sends a incomingPacketSeen object to the shared channel where the sender will read it. -/* -func (r *reliableReceiver) NotifySeen(p *model.Packet) bool { - incoming := incomingPacketSeen{ - id: p.ID, - acks: p.ACKs, - } - if p.ID > 0 && p.ID <= r.lastConsumed { - r.logger.Warnf("got packet id %v, but last consumed is %v\n", p.ID, r.lastConsumed) - } - r.incomingSeen <- incoming - return true - -} -*/ - func (r *reliableReceiver) MaybeInsertIncoming(p *model.Packet) bool { // we drop if at capacity, by default double the size of the outgoing buffer if len(r.incomingPackets) >= RELIABLE_RECV_BUFFER_SIZE { From c3cb64d5d38192446ded575f78f7301e3a1541f8 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Wed, 24 Jan 2024 21:02:27 +0100 Subject: [PATCH 07/78] add sender logging --- internal/reliabletransport/sender.go | 40 ++++++++++++++++------------ 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/internal/reliabletransport/sender.go b/internal/reliabletransport/sender.go index 71afcf83..638fe891 100644 --- a/internal/reliabletransport/sender.go +++ b/internal/reliabletransport/sender.go @@ -28,13 +28,7 @@ func (ws *workersState) moveDownWorker() { // POSSIBLY BLOCK reading the next packet we should move down the stack select { case packet := <-ws.controlToReliable: - ws.logger.Infof( - "> %s localID=%x remoteID=%x [%d bytes]", - packet.Opcode, - packet.LocalSessionID, - packet.RemoteSessionID, - len(packet.Payload), - ) + logPacket(ws.logger, packet) sender.TryInsertOutgoingPacket(packet) // schedule for inmediate wakeup @@ -52,12 +46,12 @@ func (ws *workersState) moveDownWorker() { now := time.Now() // this is quite arbitrary - tooLate := now.Add(1000 * time.Millisecond) + tooLate := now.Add(50 * time.Millisecond) nextTimeout := inflightSequence(sender.inFlight).nearestDeadlineTo(now) if nextTimeout.After(tooLate) { - // we don't want to wait so much, so we do send the ACK immediately. + // we don't want to wait so much, so we do not wait for the ticker to wake up if err := ws.doSendACK(&model.Packet{ID: seenPacket.id}); err != nil { sender.lastACKed += 1 } @@ -68,10 +62,9 @@ func (ws *workersState) moveDownWorker() { } else { // we'll be fine by having these ACKs hitching a ride on the next outgoing packet // that is scheduled to go soon anyways - fmt.Println(">>> SHOULD SEND SOON ENOUGH, APPEND ACK!--------------") + fmt.Println("===> SHOULD SEND SOON ENOUGH, APPEND ACK!-----------------") sender.pendingACKsToSend = append(sender.pendingACKsToSend, seenPacket.acks...) - // TODO: not needed anymore. - // and now we schedule for inmediate wakeup, because we probably need to update ACKs + // TODO: not needed anymore right? // ticker.Reset(time.Nanosecond) } @@ -86,14 +79,15 @@ func (ws *workersState) moveDownWorker() { now := time.Now() timeout := inflightSequence(sender.inFlight).nearestDeadlineTo(now) - ws.logger.Debug("") - ws.logger.Debugf("next wakeup: %v", timeout.Sub(now)) + // ws.logger.Debug("") + // ws.logger.Debugf("next wakeup: %v", timeout.Sub(now)) ticker.Reset(timeout.Sub(now)) // we flush everything that is ready to be sent. scheduledNow := inflightSequence(sender.inFlight).readyToSend(now) - ws.logger.Debugf(":: GOT %d packets to send\n", len(scheduledNow)) + + // ws.logger.Debugf(":: GOT %d packets to send\n", len(scheduledNow)) for _, p := range scheduledNow { p.ScheduleForRetransmission(now) @@ -101,7 +95,6 @@ func (ws *workersState) moveDownWorker() { // ideally, we want to append any pending ACKs here select { case ws.dataOrControlToMuxer <- p.packet: - ws.logger.Debugf("==> sent packet with ID: %v", p.packet.ID) case <-ws.workersManager.ShouldShutdown(): return } @@ -232,9 +225,22 @@ func (ws *workersState) doSendACK(packet *model.Packet) error { // move the packet down. CAN BLOCK writing to the shared channel to muxer. select { case ws.dataOrControlToMuxer <- ACK: - ws.logger.Debugf("====> ack for remote packet id: %d", packet.ID) + logPacket(ws.logger, ACK) return nil case <-ws.workersManager.ShouldShutdown(): return workers.ErrShutdown } } + +func logPacket(logger model.Logger, packet *model.Packet) { + logger.Infof( + "> %s (id=%d) [acks=%v] localID=%x remoteID=%x [%d bytes] %v", + packet.Opcode, + packet.ID, + packet.ACKs, + packet.LocalSessionID, + packet.RemoteSessionID, + len(packet.Payload), + time.Now(), + ) +} From c85b9f3b5db6cb912df2df493c72ed4c4d8da169 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Wed, 24 Jan 2024 21:37:43 +0100 Subject: [PATCH 08/78] add precision time logging in the text handler --- cmd/minivpn2/log.go | 79 ++++++++++++++++++++++++++++++++++++++++++++ cmd/minivpn2/main.go | 23 ++++++++++--- internal/tun/tun.go | 11 ++++-- 3 files changed, 105 insertions(+), 8 deletions(-) create mode 100644 cmd/minivpn2/log.go diff --git a/cmd/minivpn2/log.go b/cmd/minivpn2/log.go new file mode 100644 index 00000000..184924e7 --- /dev/null +++ b/cmd/minivpn2/log.go @@ -0,0 +1,79 @@ +package main + +import ( + "fmt" + "io" + "os" + "sync" + "time" + + "github.com/apex/log" +) + +// Default handler outputting to stderr. +var Default = NewHandler(os.Stderr) + +// start time. +var start = time.Now() + +// colors. +const ( + none = 0 + red = 31 + green = 32 + yellow = 33 + blue = 34 + gray = 37 +) + +// Colors mapping. +var Colors = [...]int{ + log.DebugLevel: gray, + log.InfoLevel: blue, + log.WarnLevel: yellow, + log.ErrorLevel: red, + log.FatalLevel: red, +} + +// Strings mapping. +var Strings = [...]string{ + log.DebugLevel: "DEBUG", + log.InfoLevel: "INFO", + log.WarnLevel: "WARN", + log.ErrorLevel: "ERROR", + log.FatalLevel: "FATAL", +} + +// Handler implementation. +type Handler struct { + mu sync.Mutex + Writer io.Writer +} + +// New handler. +func NewHandler(w io.Writer) *Handler { + return &Handler{ + Writer: w, + } +} + +// HandleLog implements log.Handler. +func (h *Handler) HandleLog(e *log.Entry) error { + color := Colors[e.Level] + level := Strings[e.Level] + names := e.Fields.Names() + + h.mu.Lock() + defer h.mu.Unlock() + + ts := time.Since(start) // time.Microsecond + fmt.Fprintf(h.Writer, "\033[%dm%6s\033[0m[%10v] %-25s", color, level, ts, e.Message) + + for _, name := range names { + fmt.Fprintf(h.Writer, " \033[%dm%s\033[0m=%v", color, name, e.Fields.Get(name)) + } + + fmt.Fprintln(h.Writer) + + return nil +} diff --git a/cmd/minivpn2/main.go b/cmd/minivpn2/main.go index a6004a2f..7ca5804f 100644 --- a/cmd/minivpn2/main.go +++ b/cmd/minivpn2/main.go @@ -7,13 +7,13 @@ import ( "os" "os/exec" + "github.com/Doridian/water" "github.com/apex/log" + "github.com/jackpal/gateway" + "github.com/ooni/minivpn/internal/model" "github.com/ooni/minivpn/internal/networkio" "github.com/ooni/minivpn/internal/tun" - - "github.com/Doridian/water" - "github.com/jackpal/gateway" ) func runCmd(binaryPath string, args ...string) { @@ -35,6 +35,13 @@ func runRoute(args ...string) { runCmd("/sbin/route", args...) } +/* +func logWithElapsedTime(logger log.Interface, message string, start time.Time) { + elapsedTime := time.Since(startTime).Round(time.Millisecond) + logger.WithField("elapsed_time", elapsedTime).Info(message) +} +*/ + func main() { log.SetLevel(log.DebugLevel) @@ -49,10 +56,16 @@ func main() { if !options.HasAuthInfo() { log.Fatal("options are missing auth info") } + + log.SetHandler(NewHandler(os.Stderr)) + log.SetLevel(log.DebugLevel) + // connect to the server dialer := networkio.NewDialer(log.Log, &net.Dialer{}) ctx := context.Background() + endpoint := net.JoinHostPort(options.Remote, options.Port) + conn, err := dialer.DialContext(ctx, options.Proto.String(), endpoint) if err != nil { log.WithError(err).Fatal("dialer.DialContext") @@ -64,7 +77,7 @@ func main() { //defer cancel() // create a vpn tun Device - tunnel, err := tun.StartTUN(ctx, conn, options) + tunnel, err := tun.StartTUN(ctx, conn, options, log.Log) if err != nil { log.WithError(err).Fatal("init error") return @@ -88,7 +101,7 @@ func main() { remoteAddr := tunnel.RemoteAddr().String() netMask := tunnel.NetMask() - // discover local gateway IP, to + // discover local gateway IP, we need it to add a route to our remote via our network gw defaultGatewayIP, err := gateway.DiscoverGateway() if err != nil { log.Warn("could not discover default gateway IP, routes might be broken") diff --git a/internal/tun/tun.go b/internal/tun/tun.go index 51448c1e..5e0c2bd4 100644 --- a/internal/tun/tun.go +++ b/internal/tun/tun.go @@ -22,15 +22,20 @@ var ( // StartTUN initializes and starts the TUN device over the vpn. // If the passed context expires before the TUN device is ready, -func StartTUN(ctx context.Context, conn networkio.FramingConn, options *model.Options) (*TUN, error) { +func StartTUN(ctx context.Context, conn networkio.FramingConn, options *model.Options, logger model.Logger) (*TUN, error) { + // be useful if passing an empty logger + if logger == nil { + logger = log.Log + } + // create a session - sessionManager, err := session.NewManager(log.Log) + sessionManager, err := session.NewManager(logger) if err != nil { return nil, err } // create the TUN that will OWN the connection - tunnel := newTUN(log.Log, conn, sessionManager) + tunnel := newTUN(logger, conn, sessionManager) // start all the workers workers := startWorkers(log.Log, sessionManager, tunnel, conn, options) From ba650b8c88720566a10c820e76ef535d9905481d Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Wed, 24 Jan 2024 22:17:50 +0100 Subject: [PATCH 09/78] clearer logging --- Makefile | 2 +- cmd/minivpn2/log.go | 2 +- internal/tun/setup.go | 6 ------ 3 files changed, 2 insertions(+), 8 deletions(-) diff --git a/Makefile b/Makefile index 99557266..74094e95 100644 --- a/Makefile +++ b/Makefile @@ -28,7 +28,7 @@ bootstrap: @./scripts/bootstrap-provider ${PROVIDER} handshake_log: - @sudo ./minivpn2 data/${PROVIDER}/config 2>&1 | grep --text -E 'info [>|<] P_*|info \[@\]' + @sudo ./minivpn2 data/${PROVIDER}/config 2>&1 | grep --text --color=auto -E "@|P_\w+" test: GOFLAGS='-count=1' go test -v ./... diff --git a/cmd/minivpn2/log.go b/cmd/minivpn2/log.go index 184924e7..bf4a5c22 100644 --- a/cmd/minivpn2/log.go +++ b/cmd/minivpn2/log.go @@ -66,7 +66,7 @@ func (h *Handler) HandleLog(e *log.Entry) error { h.mu.Lock() defer h.mu.Unlock() - ts := time.Since(start) // time.Microsecond + ts := time.Since(start) fmt.Fprintf(h.Writer, "\033[%dm%6s\033[0m[%10v] %-25s", color, level, ts, e.Message) for _, name := range names { diff --git a/internal/tun/setup.go b/internal/tun/setup.go index 9d79499a..11e9d0bd 100644 --- a/internal/tun/setup.go +++ b/internal/tun/setup.go @@ -108,12 +108,6 @@ func startWorkers(logger model.Logger, sessionManager *session.Manager, // connect the muxer and the tlsstate service connectChannel(tlsx.NotifyTLS, &muxer.NotifyTLS) - logger.Debugf("%T: %+v", nio, nio) - logger.Debugf("%T: %+v", muxer, muxer) - logger.Debugf("%T: %+v", rel, rel) - logger.Debugf("%T: %+v", ctrl, ctrl) - logger.Debugf("%T: %+v", tlsx, tlsx) - // start all the workers nio.StartWorkers(logger, workersManager, conn) muxer.StartWorkers(logger, workersManager, sessionManager) From 0525989d3c3fc4fff11b3eec28d152dabacc651e Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Thu, 25 Jan 2024 00:32:11 +0100 Subject: [PATCH 10/78] unify packet logging --- internal/model/packet.go | 29 ++++++++++++++++++++++++++ internal/packetmuxer/service.go | 18 +++------------- internal/reliabletransport/receiver.go | 23 ++++---------------- internal/reliabletransport/sender.go | 17 ++------------- 4 files changed, 38 insertions(+), 49 deletions(-) diff --git a/internal/model/packet.go b/internal/model/packet.go index 828faded..fd59b63c 100644 --- a/internal/model/packet.go +++ b/internal/model/packet.go @@ -306,3 +306,32 @@ func (p *Packet) IsControl() bool { func (p *Packet) IsData() bool { return p.Opcode.IsData() } + +const ( + DirectionIncoming = iota + DirectionOutgoing +) + +func (p *Packet) Log(logger Logger, direction int) { + var dir string + switch direction { + case DirectionIncoming: + dir = "<" + case DirectionOutgoing: + dir = ">" + default: + logger.Warnf("wrong direction: %d", direction) + return + } + + logger.Infof( + "%s %s {id=%d, acks=%v} localID=%x remoteID=%x [%d bytes]", + dir, + p.Opcode, + p.ID, + p.ACKs, + p.LocalSessionID, + p.RemoteSessionID, + len(p.Payload), + ) +} diff --git a/internal/packetmuxer/service.go b/internal/packetmuxer/service.go index 9e126cfb..a17a0035 100644 --- a/internal/packetmuxer/service.go +++ b/internal/packetmuxer/service.go @@ -215,6 +215,7 @@ func (ws *workersState) handleRawPacket(rawPacket []byte) error { // handle the case where we're performing a HARD_RESET if ws.sessionManager.NegotiationState() == session.S_PRE_START && packet.Opcode == model.P_CONTROL_HARD_RESET_SERVER_V2 { + packet.Log(ws.logger, model.DirectionIncoming) ws.hardResetTicker.Stop() return ws.finishThreeWayHandshake(packet) } @@ -223,6 +224,7 @@ func (ws *workersState) handleRawPacket(rawPacket []byte) error { // multiplex the incoming packet POSSIBLY BLOCKING on delivering it if packet.IsControl() || packet.Opcode == model.P_ACK_V1 { + packet.Log(ws.logger, model.DirectionIncoming) select { case ws.muxerToReliable <- packet: case <-ws.workersManager.ShouldShutdown(): @@ -245,13 +247,6 @@ func (ws *workersState) finishThreeWayHandshake(packet *model.Packet) error { ws.sessionManager.SetRemoteSessionID(packet.LocalSessionID) // we need to manually ACK because the reliable layer is above us - ws.logger.Debugf( - "< %s localID=%x remoteID=%x [%d bytes]", - packet.Opcode, - packet.LocalSessionID, - packet.RemoteSessionID, - len(packet.Payload), - ) // create the ACK packet ACK, err := ws.sessionManager.NewACKForPacket(packet) @@ -298,13 +293,6 @@ func (ws *workersState) serializeAndEmit(packet *model.Packet) error { return workers.ErrShutdown } - ws.logger.Debugf( - "> %s localID=%x remoteID=%x [%d bytes]", - packet.Opcode, - packet.LocalSessionID, - packet.RemoteSessionID, - len(packet.Payload), - ) - + packet.Log(ws.logger, model.DirectionOutgoing) return nil } diff --git a/internal/reliabletransport/receiver.go b/internal/reliabletransport/receiver.go index e3f7302d..4200dac3 100644 --- a/internal/reliabletransport/receiver.go +++ b/internal/reliabletransport/receiver.go @@ -2,6 +2,7 @@ package reliabletransport import ( "bytes" + "encoding/hex" "fmt" "sort" @@ -29,13 +30,7 @@ func (ws *workersState) moveUpWorker() { // or POSSIBLY BLOCK waiting for notifications select { case packet := <-ws.muxerToReliable: - ws.logger.Infof( - "< %s localID=%x remoteID=%x [%d bytes]", - packet.Opcode, - packet.LocalSessionID, - packet.RemoteSessionID, - len(packet.Payload), - ) + packet.Log(ws.logger, model.DirectionIncoming) // drop a packet that is not for our session if !bytes.Equal(packet.LocalSessionID[:], ws.sessionManager.RemoteSessionID()) { @@ -48,27 +43,17 @@ func (ws *workersState) moveUpWorker() { continue } - // possibly ACK the incoming packet - // TODO: move this responsibility to the sender. - /* - if err := ws.maybeACK(packet); err != nil { - ws.logger.Warnf("%s: cannot ACK packet: %s", workerName, err.Error()) - continue - } - */ - if inserted := receiver.MaybeInsertIncoming(packet); !inserted { // this packet was not inserted in the queue: we drop it continue } - // TODO: possibly refactor so that the writing to the channel happens here - // the fact this channel write is hidden makes following this harder - // receiver.NotifySeen(packet) seenPacket, shouldDrop := receiver.newIncomingPacketSeen(packet) switch shouldDrop { case true: receiver.logger.Warnf("got packet id %v, but last consumed is %v (dropping)\n", packet.ID, receiver.lastConsumed) + b, _ := packet.Bytes() + fmt.Println(hex.Dump(b)) case false: ws.incomingSeen <- seenPacket } diff --git a/internal/reliabletransport/sender.go b/internal/reliabletransport/sender.go index 638fe891..4cc04e6e 100644 --- a/internal/reliabletransport/sender.go +++ b/internal/reliabletransport/sender.go @@ -28,7 +28,7 @@ func (ws *workersState) moveDownWorker() { // POSSIBLY BLOCK reading the next packet we should move down the stack select { case packet := <-ws.controlToReliable: - logPacket(ws.logger, packet) + packet.Log(ws.logger, model.DirectionOutgoing) sender.TryInsertOutgoingPacket(packet) // schedule for inmediate wakeup @@ -225,22 +225,9 @@ func (ws *workersState) doSendACK(packet *model.Packet) error { // move the packet down. CAN BLOCK writing to the shared channel to muxer. select { case ws.dataOrControlToMuxer <- ACK: - logPacket(ws.logger, ACK) + ACK.Log(ws.logger, model.DirectionOutgoing) return nil case <-ws.workersManager.ShouldShutdown(): return workers.ErrShutdown } } - -func logPacket(logger model.Logger, packet *model.Packet) { - logger.Infof( - "> %s (id=%d) [acks=%v] localID=%x remoteID=%x [%d bytes] %v", - packet.Opcode, - packet.ID, - packet.ACKs, - packet.LocalSessionID, - packet.RemoteSessionID, - len(packet.Payload), - time.Now(), - ) -} From 9e5b7449f17431dd4b1e8b51293da98ca3d2ba5d Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Thu, 25 Jan 2024 01:43:23 +0100 Subject: [PATCH 11/78] checkpoint: improve logging, hack client hello ack to get moving --- internal/packetmuxer/service.go | 14 ----- internal/reliabletransport/packets.go | 5 +- internal/reliabletransport/receiver.go | 26 ++++++--- internal/reliabletransport/sender.go | 73 ++++++++++---------------- internal/tlssession/tlsbio.go | 43 ++++++++------- internal/tlssession/tlssession.go | 2 +- 6 files changed, 72 insertions(+), 91 deletions(-) diff --git a/internal/packetmuxer/service.go b/internal/packetmuxer/service.go index a17a0035..82db1c14 100644 --- a/internal/packetmuxer/service.go +++ b/internal/packetmuxer/service.go @@ -224,7 +224,6 @@ func (ws *workersState) handleRawPacket(rawPacket []byte) error { // multiplex the incoming packet POSSIBLY BLOCKING on delivering it if packet.IsControl() || packet.Opcode == model.P_ACK_V1 { - packet.Log(ws.logger, model.DirectionIncoming) select { case ws.muxerToReliable <- packet: case <-ws.workersManager.ShouldShutdown(): @@ -246,19 +245,6 @@ func (ws *workersState) finishThreeWayHandshake(packet *model.Packet) error { // register the server's session (note: the PoV is the server's one) ws.sessionManager.SetRemoteSessionID(packet.LocalSessionID) - // we need to manually ACK because the reliable layer is above us - - // create the ACK packet - ACK, err := ws.sessionManager.NewACKForPacket(packet) - if err != nil { - return err - } - - // emit the packet - if err := ws.serializeAndEmit(ACK); err != nil { - return err - } - // advance the state ws.sessionManager.SetNegotiationState(session.S_START) diff --git a/internal/reliabletransport/packets.go b/internal/reliabletransport/packets.go index af2824a3..b4d392ad 100644 --- a/internal/reliabletransport/packets.go +++ b/internal/reliabletransport/packets.go @@ -5,6 +5,7 @@ import ( "time" "github.com/ooni/minivpn/internal/model" + "github.com/ooni/minivpn/internal/optional" ) // @@ -152,6 +153,6 @@ func (ip *incomingPacket) Packet() *model.Packet { // incomingPacketSeen is a struct that the receiver sends us when a new packet is seen. type incomingPacketSeen struct { - id model.PacketID - acks []model.PacketID + id optional.Value[model.PacketID] + acks optional.Value[[]model.PacketID] } diff --git a/internal/reliabletransport/receiver.go b/internal/reliabletransport/receiver.go index 4200dac3..a62db48c 100644 --- a/internal/reliabletransport/receiver.go +++ b/internal/reliabletransport/receiver.go @@ -7,6 +7,7 @@ import ( "sort" "github.com/ooni/minivpn/internal/model" + "github.com/ooni/minivpn/internal/optional" ) // moveUpWorker moves packets up the stack (receiver) @@ -142,15 +143,24 @@ func (r *reliableReceiver) NextIncomingSequence() incomingSequence { func (r *reliableReceiver) newIncomingPacketSeen(p *model.Packet) (incomingPacketSeen, bool) { shouldDrop := false - incomingPacket := incomingPacketSeen{ - id: p.ID, - acks: p.ACKs, + incomingPacket := incomingPacketSeen{} + if p.Opcode == model.P_ACK_V1 { + incomingPacket.acks = optional.Some(p.ACKs) + } else { + incomingPacket.id = optional.Some(p.ID) + if len(p.ACKs) != 0 { + incomingPacket.acks = optional.Some(p.ACKs) + } } - r.logger.Debugf( - "notify: ", - p.ID, - p.ACKs, - ) + + /* + r.logger.Debugf( + "notify: ", + p.ID, + p.ACKs, + ) + */ + if p.ID > 0 && p.ID <= r.lastConsumed { shouldDrop = true } diff --git a/internal/reliabletransport/sender.go b/internal/reliabletransport/sender.go index 4cc04e6e..f87d9b10 100644 --- a/internal/reliabletransport/sender.go +++ b/internal/reliabletransport/sender.go @@ -28,7 +28,6 @@ func (ws *workersState) moveDownWorker() { // POSSIBLY BLOCK reading the next packet we should move down the stack select { case packet := <-ws.controlToReliable: - packet.Log(ws.logger, model.DirectionOutgoing) sender.TryInsertOutgoingPacket(packet) // schedule for inmediate wakeup @@ -36,37 +35,10 @@ func (ws *workersState) moveDownWorker() { ticker.Reset(time.Nanosecond) case seenPacket := <-sender.incomingSeen: - // possibly evict any acked packet + // possibly evict any acked packet (in the ack array) + // and add any id to the queue of packets to ack sender.OnIncomingPacketSeen(seenPacket) - - if seenPacket.id < sender.lastACKed { - continue - } - - now := time.Now() - - // this is quite arbitrary - tooLate := now.Add(50 * time.Millisecond) - - nextTimeout := inflightSequence(sender.inFlight).nearestDeadlineTo(now) - - if nextTimeout.After(tooLate) { - // we don't want to wait so much, so we do not wait for the ticker to wake up - if err := ws.doSendACK(&model.Packet{ID: seenPacket.id}); err != nil { - sender.lastACKed += 1 - } - - // TODO: ------------------------------------------------------------ - // discuss: how can we gauge the sending queue? should we peek what's - // if len(ws.controlToReliable) != 0 { - } else { - // we'll be fine by having these ACKs hitching a ride on the next outgoing packet - // that is scheduled to go soon anyways - fmt.Println("===> SHOULD SEND SOON ENOUGH, APPEND ACK!-----------------") - sender.pendingACKsToSend = append(sender.pendingACKsToSend, seenPacket.acks...) - // TODO: not needed anymore right? - // ticker.Reset(time.Nanosecond) - } + ticker.Reset(time.Nanosecond) case <-ticker.C: // First of all, we reset the ticker to the next timeout. @@ -79,20 +51,25 @@ func (ws *workersState) moveDownWorker() { now := time.Now() timeout := inflightSequence(sender.inFlight).nearestDeadlineTo(now) - // ws.logger.Debug("") - // ws.logger.Debugf("next wakeup: %v", timeout.Sub(now)) - ticker.Reset(timeout.Sub(now)) // we flush everything that is ready to be sent. scheduledNow := inflightSequence(sender.inFlight).readyToSend(now) - // ws.logger.Debugf(":: GOT %d packets to send\n", len(scheduledNow)) - for _, p := range scheduledNow { p.ScheduleForRetransmission(now) - // TODO ------------------------------------------- - // ideally, we want to append any pending ACKs here + // append any pending ACKs + nextACKs := sender.NextPacketIDsToACK() + + // HACK: we need to account for packet IDs received below (hard reset) + // (special case) + if p.packet.ID == 1 && len(nextACKs) == 0 { + p.packet.ACKs = []model.PacketID{0} + } else { + p.packet.ACKs = nextACKs + } + + p.packet.Log(ws.logger, model.DirectionOutgoing) select { case ws.dataOrControlToMuxer <- p.packet: case <-ws.workersManager.ShouldShutdown(): @@ -121,7 +98,7 @@ type reliableSender struct { inFlight []*inFlightPacket // lastACKed is the last packet ID from the remote that we have acked - lastACKed model.PacketID + // lastACKed model.PacketID // logger is the logger to use logger model.Logger @@ -133,9 +110,9 @@ type reliableSender struct { // newReliableSender returns a new instance of reliableOutgoing. func newReliableSender(logger model.Logger, i chan incomingPacketSeen) *reliableSender { return &reliableSender{ - incomingSeen: i, - inFlight: make([]*inFlightPacket, 0, RELIABLE_SEND_BUFFER_SIZE), - lastACKed: model.PacketID(0), + incomingSeen: i, + inFlight: make([]*inFlightPacket, 0, RELIABLE_SEND_BUFFER_SIZE), + //lastACKed: model.PacketID(0), logger: logger, pendingACKsToSend: []model.PacketID{}, } @@ -200,15 +177,19 @@ func (r *reliableSender) NextPacketIDsToACK() []model.PacketID { return next } -func (r *reliableSender) OnIncomingPacketSeen(ips incomingPacketSeen) { +func (r *reliableSender) OnIncomingPacketSeen(seen incomingPacketSeen) { // we have received an incomingPacketSeen on the shared channel, we need to do two things: // 1. add the ID to the queue of packets to be acknowledged. - r.pendingACKsToSend = append(r.pendingACKsToSend, ips.id) + if !seen.id.IsNone() { + r.pendingACKsToSend = append(r.pendingACKsToSend, seen.id.Unwrap()) + } // 2. for every ACK received, see if we need to evict or bump the in-flight packet. - for _, packetID := range ips.acks { - r.MaybeEvictOrBumpPacketAfterACK(packetID) + if !seen.acks.IsNone() { + for _, packetID := range seen.acks.Unwrap() { + r.MaybeEvictOrBumpPacketAfterACK(packetID) + } } } diff --git a/internal/tlssession/tlsbio.go b/internal/tlssession/tlsbio.go index 6fec09be..00898d2b 100644 --- a/internal/tlssession/tlsbio.go +++ b/internal/tlssession/tlsbio.go @@ -2,10 +2,11 @@ package tlssession import ( "bytes" - "log" "net" "sync" "time" + + "github.com/ooni/minivpn/internal/model" ) // tlsBio allows to use channels to read and write @@ -14,70 +15,72 @@ type tlsBio struct { directionDown chan<- []byte directionUp <-chan []byte hangup chan any + logger model.Logger readBuffer *bytes.Buffer } // newTLSBio creates a new tlsBio -func newTLSBio(directionUp <-chan []byte, directionDown chan<- []byte) *tlsBio { +func newTLSBio(logger model.Logger, directionUp <-chan []byte, directionDown chan<- []byte) *tlsBio { return &tlsBio{ closeOnce: sync.Once{}, directionDown: directionDown, directionUp: directionUp, hangup: make(chan any), + logger: logger, readBuffer: &bytes.Buffer{}, } } -func (c *tlsBio) Close() error { - c.closeOnce.Do(func() { - close(c.hangup) +func (t *tlsBio) Close() error { + t.closeOnce.Do(func() { + close(t.hangup) }) return nil } -func (c *tlsBio) Read(data []byte) (int, error) { +func (t *tlsBio) Read(data []byte) (int, error) { for { - count, _ := c.readBuffer.Read(data) + count, _ := t.readBuffer.Read(data) if count > 0 { - log.Printf("[tlsbio] received %d bytes", len(data)) + t.logger.Debugf("[tlsbio] received %d bytes", len(data)) return count, nil } select { - case extra := <-c.directionUp: - c.readBuffer.Write(extra) - case <-c.hangup: + case extra := <-t.directionUp: + t.readBuffer.Write(extra) + case <-t.hangup: return 0, net.ErrClosed } } } -func (c *tlsBio) Write(data []byte) (int, error) { - log.Printf("[tlsbio] requested to write %d bytes", len(data)) +func (t *tlsBio) Write(data []byte) (int, error) { + t.logger.Debugf("[tlsbio] requested to write %d bytes", len(data)) select { - case c.directionDown <- data: + case t.directionDown <- data: return len(data), nil - case <-c.hangup: + case <-t.hangup: return 0, net.ErrClosed } } -func (c *tlsBio) LocalAddr() net.Addr { +func (t *tlsBio) LocalAddr() net.Addr { return &tlsBioAddr{} } -func (c *tlsBio) RemoteAddr() net.Addr { +func (t *tlsBio) RemoteAddr() net.Addr { return &tlsBioAddr{} } -func (c *tlsBio) SetDeadline(t time.Time) error { +func (t *tlsBio) SetDeadline(tt time.Time) error { return nil } -func (c *tlsBio) SetReadDeadline(t time.Time) error { +func (t *tlsBio) SetReadDeadline(tt time.Time) error { return nil } -func (c *tlsBio) SetWriteDeadline(t time.Time) error { +func (t *tlsBio) SetWriteDeadline(tt time.Time) error { return nil } diff --git a/internal/tlssession/tlssession.go b/internal/tlssession/tlssession.go index a227ecd8..95016bfa 100644 --- a/internal/tlssession/tlssession.go +++ b/internal/tlssession/tlssession.go @@ -101,7 +101,7 @@ func (ws *workersState) worker() { // tlsAuth runs the TLS auth algorithm func (ws *workersState) tlsAuth() error { // create the BIO to use channels as a socket - conn := newTLSBio(ws.tlsRecordUp, ws.tlsRecordDown) + conn := newTLSBio(ws.logger, ws.tlsRecordUp, ws.tlsRecordDown) defer conn.Close() // we construct the certCfg from options, that has access to the certificate material From f34af5cf7a885a91fe6aed16aa15a047453a394d Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Thu, 25 Jan 2024 02:24:01 +0100 Subject: [PATCH 12/78] checkpoint --- internal/packetmuxer/service.go | 8 ++++- internal/reliabletransport/receiver.go | 50 +++++++++++--------------- internal/reliabletransport/sender.go | 23 ++++++------ 3 files changed, 39 insertions(+), 42 deletions(-) diff --git a/internal/packetmuxer/service.go b/internal/packetmuxer/service.go index 82db1c14..5ecb53fa 100644 --- a/internal/packetmuxer/service.go +++ b/internal/packetmuxer/service.go @@ -182,7 +182,6 @@ func (ws *workersState) moveDownWorker() { // startHardReset is invoked when we need to perform a HARD RESET. func (ws *workersState) startHardReset() error { // emit a CONTROL_HARD_RESET_CLIENT_V2 pkt - // TODO(ainghazal): we need to retry this hard reset if not ACKd in a reasonable time. packet, err := ws.sessionManager.NewPacket(model.P_CONTROL_HARD_RESET_CLIENT_V2, nil) if err != nil { ws.logger.Warnf("packetmuxer: NewPacket: %s", err.Error()) @@ -248,6 +247,13 @@ func (ws *workersState) finishThreeWayHandshake(packet *model.Packet) error { // advance the state ws.sessionManager.SetNegotiationState(session.S_START) + // pass the packet up so that we can ack it properly + select { + case ws.muxerToReliable <- packet: + case <-ws.workersManager.ShouldShutdown(): + return workers.ErrShutdown + } + // attempt to tell TLS we want to handshake. // This WILL BLOCK if the notifyTLS channel // is Full, but we make sure we control that we don't pass spurious soft-reset packets while we're diff --git a/internal/reliabletransport/receiver.go b/internal/reliabletransport/receiver.go index a62db48c..ad9cdf0b 100644 --- a/internal/reliabletransport/receiver.go +++ b/internal/reliabletransport/receiver.go @@ -2,7 +2,6 @@ package reliabletransport import ( "bytes" - "encoding/hex" "fmt" "sort" @@ -31,7 +30,10 @@ func (ws *workersState) moveUpWorker() { // or POSSIBLY BLOCK waiting for notifications select { case packet := <-ws.muxerToReliable: - packet.Log(ws.logger, model.DirectionIncoming) + if packet.Opcode != model.P_CONTROL_HARD_RESET_SERVER_V2 { + // the hard reset we logged in below + packet.Log(ws.logger, model.DirectionIncoming) + } // drop a packet that is not for our session if !bytes.Equal(packet.LocalSessionID[:], ws.sessionManager.RemoteSessionID()) { @@ -44,19 +46,23 @@ func (ws *workersState) moveUpWorker() { continue } - if inserted := receiver.MaybeInsertIncoming(packet); !inserted { - // this packet was not inserted in the queue: we drop it + seen := receiver.newIncomingPacketSeen(packet) + ws.incomingSeen <- seen + + // we only want to insert control packets going to the tls layer + + if packet.Opcode != model.P_CONTROL_V1 { continue } - seenPacket, shouldDrop := receiver.newIncomingPacketSeen(packet) - switch shouldDrop { - case true: - receiver.logger.Warnf("got packet id %v, but last consumed is %v (dropping)\n", packet.ID, receiver.lastConsumed) - b, _ := packet.Bytes() - fmt.Println(hex.Dump(b)) - case false: - ws.incomingSeen <- seenPacket + if packet.ID < receiver.lastConsumed { + ws.logger.Warnf("%s: received %d but last consumed was %d", workerName, packet.ID, receiver.lastConsumed) + continue + } + + if inserted := receiver.MaybeInsertIncoming(packet); !inserted { + // this packet was not inserted in the queue: we drop it + continue } ready := receiver.NextIncomingSequence() @@ -141,30 +147,16 @@ func (r *reliableReceiver) NextIncomingSequence() incomingSequence { return ready } -func (r *reliableReceiver) newIncomingPacketSeen(p *model.Packet) (incomingPacketSeen, bool) { - shouldDrop := false +func (r *reliableReceiver) newIncomingPacketSeen(p *model.Packet) incomingPacketSeen { incomingPacket := incomingPacketSeen{} if p.Opcode == model.P_ACK_V1 { incomingPacket.acks = optional.Some(p.ACKs) } else { incomingPacket.id = optional.Some(p.ID) - if len(p.ACKs) != 0 { - incomingPacket.acks = optional.Some(p.ACKs) - } + incomingPacket.acks = optional.Some(p.ACKs) } - /* - r.logger.Debugf( - "notify: ", - p.ID, - p.ACKs, - ) - */ - - if p.ID > 0 && p.ID <= r.lastConsumed { - shouldDrop = true - } - return incomingPacket, shouldDrop + return incomingPacket } // assert that reliableIncoming implements incomingPacketHandler diff --git a/internal/reliabletransport/sender.go b/internal/reliabletransport/sender.go index f87d9b10..e6c6703e 100644 --- a/internal/reliabletransport/sender.go +++ b/internal/reliabletransport/sender.go @@ -58,16 +58,19 @@ func (ws *workersState) moveDownWorker() { for _, p := range scheduledNow { p.ScheduleForRetransmission(now) + // append any pending ACKs - nextACKs := sender.NextPacketIDsToACK() + p.packet.ACKs = sender.NextPacketIDsToACK() // HACK: we need to account for packet IDs received below (hard reset) // (special case) - if p.packet.ID == 1 && len(nextACKs) == 0 { - p.packet.ACKs = []model.PacketID{0} - } else { - p.packet.ACKs = nextACKs - } + /* + if p.packet.ID == 1 && len(nextACKs) == 0 { + p.packet.ACKs = []model.PacketID{0} + } else { + p.packet.ACKs = nextACKs + } + */ p.packet.Log(ws.logger, model.DirectionOutgoing) select { @@ -97,9 +100,6 @@ type reliableSender struct { // inFlight is the array of in-flight packets. inFlight []*inFlightPacket - // lastACKed is the last packet ID from the remote that we have acked - // lastACKed model.PacketID - // logger is the logger to use logger model.Logger @@ -110,9 +110,8 @@ type reliableSender struct { // newReliableSender returns a new instance of reliableOutgoing. func newReliableSender(logger model.Logger, i chan incomingPacketSeen) *reliableSender { return &reliableSender{ - incomingSeen: i, - inFlight: make([]*inFlightPacket, 0, RELIABLE_SEND_BUFFER_SIZE), - //lastACKed: model.PacketID(0), + incomingSeen: i, + inFlight: make([]*inFlightPacket, 0, RELIABLE_SEND_BUFFER_SIZE), logger: logger, pendingACKsToSend: []model.PacketID{}, } From ff1ec511d286a4fdb679de05343a4b8afb6390b8 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Thu, 25 Jan 2024 02:38:28 +0100 Subject: [PATCH 13/78] checkpoint --- internal/reliabletransport/receiver.go | 14 ++++++++++---- internal/reliabletransport/sender.go | 3 ++- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/internal/reliabletransport/receiver.go b/internal/reliabletransport/receiver.go index ad9cdf0b..5962aef0 100644 --- a/internal/reliabletransport/receiver.go +++ b/internal/reliabletransport/receiver.go @@ -55,10 +55,13 @@ func (ws *workersState) moveUpWorker() { continue } - if packet.ID < receiver.lastConsumed { - ws.logger.Warnf("%s: received %d but last consumed was %d", workerName, packet.ID, receiver.lastConsumed) - continue - } + // I think this check is not helping -- ain + /* + if packet.ID < receiver.lastConsumed { + ws.logger.Warnf("%s: received %d but last consumed was %d", workerName, packet.ID, receiver.lastConsumed) + continue + } + */ if inserted := receiver.MaybeInsertIncoming(packet); !inserted { // this packet was not inserted in the queue: we drop it @@ -148,6 +151,9 @@ func (r *reliableReceiver) NextIncomingSequence() incomingSequence { } func (r *reliableReceiver) newIncomingPacketSeen(p *model.Packet) incomingPacketSeen { + if len(p.ACKs) != 0 { + fmt.Println(":: seen", p.ACKs) + } incomingPacket := incomingPacketSeen{} if p.Opcode == model.P_ACK_V1 { incomingPacket.acks = optional.Some(p.ACKs) diff --git a/internal/reliabletransport/sender.go b/internal/reliabletransport/sender.go index e6c6703e..b0604d4a 100644 --- a/internal/reliabletransport/sender.go +++ b/internal/reliabletransport/sender.go @@ -6,7 +6,6 @@ import ( "time" "github.com/ooni/minivpn/internal/model" - "github.com/ooni/minivpn/internal/workers" ) // moveDownWorker moves packets down the stack (sender) @@ -195,6 +194,7 @@ func (r *reliableSender) OnIncomingPacketSeen(seen incomingPacketSeen) { var _ outgoingPacketHandler = &reliableSender{} // doSendACK sends an ACK when needed. +/* func (ws *workersState) doSendACK(packet *model.Packet) error { // this function will fail if we don't know the remote session ID ACK, err := ws.sessionManager.NewACKForPacket(packet) @@ -211,3 +211,4 @@ func (ws *workersState) doSendACK(packet *model.Packet) error { return workers.ErrShutdown } } +*/ From 06d6c374f6f1b5d698f3ecc1f61e4fd40fe77c0e Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Thu, 25 Jan 2024 03:42:48 +0100 Subject: [PATCH 14/78] x --- internal/reliabletransport/receiver.go | 8 +++++ internal/reliabletransport/sender.go | 41 +++++--------------------- 2 files changed, 15 insertions(+), 34 deletions(-) diff --git a/internal/reliabletransport/receiver.go b/internal/reliabletransport/receiver.go index 5962aef0..4db4605c 100644 --- a/internal/reliabletransport/receiver.go +++ b/internal/reliabletransport/receiver.go @@ -35,8 +35,16 @@ func (ws *workersState) moveUpWorker() { packet.Log(ws.logger, model.DirectionIncoming) } + /* + fmt.Println(">> packet session:", packet.LocalSessionID) + fmt.Println(">> our session:", ws.sessionManager.RemoteSessionID()) + */ + + fmt.Printf("%s session check: %v\n", packet.Opcode, bytes.Equal(packet.LocalSessionID[:], ws.sessionManager.RemoteSessionID())) + // drop a packet that is not for our session if !bytes.Equal(packet.LocalSessionID[:], ws.sessionManager.RemoteSessionID()) { + fmt.Println(">> not our session!!!") ws.logger.Warnf( "%s: packet with invalid RemoteSessionID: expected %x; got %x", workerName, diff --git a/internal/reliabletransport/sender.go b/internal/reliabletransport/sender.go index b0604d4a..20083fb6 100644 --- a/internal/reliabletransport/sender.go +++ b/internal/reliabletransport/sender.go @@ -28,10 +28,11 @@ func (ws *workersState) moveDownWorker() { select { case packet := <-ws.controlToReliable: - sender.TryInsertOutgoingPacket(packet) - // schedule for inmediate wakeup - // so that the ticker will wakeup and see if there's anything pending to be sent. - ticker.Reset(time.Nanosecond) + // try to insert, and if done schedule for inmediate wakeup + // so that the scheduler will wakeup + if inserted := sender.TryInsertOutgoingPacket(packet); inserted { + ticker.Reset(time.Nanosecond) + } case seenPacket := <-sender.incomingSeen: // possibly evict any acked packet (in the ack array) @@ -61,16 +62,6 @@ func (ws *workersState) moveDownWorker() { // append any pending ACKs p.packet.ACKs = sender.NextPacketIDsToACK() - // HACK: we need to account for packet IDs received below (hard reset) - // (special case) - /* - if p.packet.ID == 1 && len(nextACKs) == 0 { - p.packet.ACKs = []model.PacketID{0} - } else { - p.packet.ACKs = nextACKs - } - */ - p.packet.Log(ws.logger, model.DirectionOutgoing) select { case ws.dataOrControlToMuxer <- p.packet: @@ -180,6 +171,8 @@ func (r *reliableSender) OnIncomingPacketSeen(seen incomingPacketSeen) { // 1. add the ID to the queue of packets to be acknowledged. if !seen.id.IsNone() { + // TODO: do it only if not already in the array + // FIXME -------------------------------------- r.pendingACKsToSend = append(r.pendingACKsToSend, seen.id.Unwrap()) } @@ -192,23 +185,3 @@ func (r *reliableSender) OnIncomingPacketSeen(seen incomingPacketSeen) { } var _ outgoingPacketHandler = &reliableSender{} - -// doSendACK sends an ACK when needed. -/* -func (ws *workersState) doSendACK(packet *model.Packet) error { - // this function will fail if we don't know the remote session ID - ACK, err := ws.sessionManager.NewACKForPacket(packet) - if err != nil { - return err - } - - // move the packet down. CAN BLOCK writing to the shared channel to muxer. - select { - case ws.dataOrControlToMuxer <- ACK: - ACK.Log(ws.logger, model.DirectionOutgoing) - return nil - case <-ws.workersManager.ShouldShutdown(): - return workers.ErrShutdown - } -} -*/ From bfd4380434f76798ce466cacb57255c7b62cbd0a Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Thu, 25 Jan 2024 17:06:18 +0100 Subject: [PATCH 15/78] testing --- internal/packetmuxer/service.go | 7 +++ internal/reliabletransport/receiver.go | 3 ++ internal/reliabletransport/sender.go | 64 ++++++++++++++++++++++---- internal/session/manager.go | 6 +-- 4 files changed, 67 insertions(+), 13 deletions(-) diff --git a/internal/packetmuxer/service.go b/internal/packetmuxer/service.go index 5ecb53fa..768d4d5f 100644 --- a/internal/packetmuxer/service.go +++ b/internal/packetmuxer/service.go @@ -220,6 +220,9 @@ func (ws *workersState) handleRawPacket(rawPacket []byte) error { } // TODO: introduce other sanity checks here + // TODO *** + // TODO: make sure we're not blocking on delivering data packets up (from old sessions) + // TODO *** // multiplex the incoming packet POSSIBLY BLOCKING on delivering it if packet.IsControl() || packet.Opcode == model.P_ACK_V1 { @@ -233,6 +236,10 @@ func (ws *workersState) handleRawPacket(rawPacket []byte) error { case ws.muxerToData <- packet: case <-ws.workersManager.ShouldShutdown(): return workers.ErrShutdown + // TODO ----------------- temporary: do we get spurious data packets during hadnshake from previous sessions ------------------ + default: + ws.logger.Warnf("%s: moveUpWorker.handleRawPacket: dropped data packet", serviceName) + // TODO --------------------------------------------------------------------- } } diff --git a/internal/reliabletransport/receiver.go b/internal/reliabletransport/receiver.go index 4db4605c..42b2de26 100644 --- a/internal/reliabletransport/receiver.go +++ b/internal/reliabletransport/receiver.go @@ -40,6 +40,9 @@ func (ws *workersState) moveUpWorker() { fmt.Println(">> our session:", ws.sessionManager.RemoteSessionID()) */ + // ------ + // FIXME: do we need to act upon a HARD_RESET_V2 while we're doing a handshake? + // ------ fmt.Printf("%s session check: %v\n", packet.Opcode, bytes.Equal(packet.LocalSessionID[:], ws.sessionManager.RemoteSessionID())) // drop a packet that is not for our session diff --git a/internal/reliabletransport/sender.go b/internal/reliabletransport/sender.go index 20083fb6..439137a6 100644 --- a/internal/reliabletransport/sender.go +++ b/internal/reliabletransport/sender.go @@ -28,8 +28,7 @@ func (ws *workersState) moveDownWorker() { select { case packet := <-ws.controlToReliable: - // try to insert, and if done schedule for inmediate wakeup - // so that the scheduler will wakeup + // try to insert and schedule for inmediate wakeup if inserted := sender.TryInsertOutgoingPacket(packet); inserted { ticker.Reset(time.Nanosecond) } @@ -38,7 +37,27 @@ func (ws *workersState) moveDownWorker() { // possibly evict any acked packet (in the ack array) // and add any id to the queue of packets to ack sender.OnIncomingPacketSeen(seenPacket) - ticker.Reset(time.Nanosecond) + + if len(sender.pendingACKsToSend) == 0 { + continue + } + + // reschedule the ticker + if len(sender.pendingACKsToSend) >= 2 { + ticker.Reset(time.Nanosecond) + continue + } + + // if there's no event soon, give some time for other acks to arrive + // TODO: review if we need this optimization. + // TODO: maybe only during TLS handshake?? + now := time.Now() + timeout := inflightSequence(sender.inFlight).nearestDeadlineTo(now) + gracePeriod := time.Millisecond * 20 + if timeout.Sub(now) > gracePeriod { + fmt.Println(">> next wakeup too late, schedule in", gracePeriod) + ticker.Reset(gracePeriod) + } case <-ticker.C: // First of all, we reset the ticker to the next timeout. @@ -53,18 +72,43 @@ func (ws *workersState) moveDownWorker() { ticker.Reset(timeout.Sub(now)) - // we flush everything that is ready to be sent. scheduledNow := inflightSequence(sender.inFlight).readyToSend(now) - for _, p := range scheduledNow { - p.ScheduleForRetransmission(now) + if len(scheduledNow) > 0 { + // we flush everything that is ready to be sent. + for _, p := range scheduledNow { + p.ScheduleForRetransmission(now) - // append any pending ACKs - p.packet.ACKs = sender.NextPacketIDsToACK() + // append any pending ACKs + p.packet.ACKs = sender.NextPacketIDsToACK() - p.packet.Log(ws.logger, model.DirectionOutgoing) + p.packet.Log(ws.logger, model.DirectionOutgoing) + select { + case ws.dataOrControlToMuxer <- p.packet: + case <-ws.workersManager.ShouldShutdown(): + return + } + } + } else { + // there's nothing ready to be sent, so we see if we've got pending ACKs + if len(sender.pendingACKsToSend) == 0 { + continue + } + // special case, we want to send the clientHello as soon as possible + // (TODO: coordinate this with hardReset) + if len(sender.pendingACKsToSend) == 1 && sender.pendingACKsToSend[0] == model.PacketID(0) { + continue + } + + fmt.Println(":: CREATING ACK", len(sender.pendingACKsToSend), "pending to ack") + + ACK, err := ws.sessionManager.NewACKForPacketIDs(sender.NextPacketIDsToACK()) + if err != nil { + ws.logger.Warnf("%s: cannot create ack: %v", workerName, err.Error()) + } + ACK.Log(ws.logger, model.DirectionOutgoing) select { - case ws.dataOrControlToMuxer <- p.packet: + case ws.dataOrControlToMuxer <- ACK: case <-ws.workersManager.ShouldShutdown(): return } diff --git a/internal/session/manager.go b/internal/session/manager.go index 26339dff..2b586205 100644 --- a/internal/session/manager.go +++ b/internal/session/manager.go @@ -160,8 +160,8 @@ func (m *Manager) IsRemoteSessionIDSet() bool { // ErrNoRemoteSessionID indicates we are missing the remote session ID. var ErrNoRemoteSessionID = errors.New("missing remote session ID") -// NewACKForPacket creates a new ACK for the given packet. -func (m *Manager) NewACKForPacket(packet *model.Packet) (*model.Packet, error) { +// NewACKForPacket creates a new ACK for the given packet IDs. +func (m *Manager) NewACKForPacketIDs(ids []model.PacketID) (*model.Packet, error) { defer m.mu.Unlock() m.mu.Lock() if m.remoteSessionID.IsNone() { @@ -172,7 +172,7 @@ func (m *Manager) NewACKForPacket(packet *model.Packet) (*model.Packet, error) { KeyID: m.keyID, PeerID: [3]byte{}, LocalSessionID: m.localSessionID, - ACKs: []model.PacketID{packet.ID}, + ACKs: ids, RemoteSessionID: m.remoteSessionID.Unwrap(), ID: 0, Payload: []byte{}, From 44c6f9b705b5fe12c8cdace3f8808b838a73f2c4 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Thu, 25 Jan 2024 19:49:26 +0100 Subject: [PATCH 16/78] pass option to do just the handshake and no routes --- cmd/minivpn2/main.go | 38 ++++++++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/cmd/minivpn2/main.go b/cmd/minivpn2/main.go index 7ca5804f..9b624ca8 100644 --- a/cmd/minivpn2/main.go +++ b/cmd/minivpn2/main.go @@ -2,10 +2,12 @@ package main import ( "context" + "flag" "fmt" "net" "os" "os/exec" + "time" "github.com/Doridian/water" "github.com/apex/log" @@ -35,18 +37,28 @@ func runRoute(args ...string) { runCmd("/sbin/route", args...) } -/* -func logWithElapsedTime(logger log.Interface, message string, start time.Time) { - elapsedTime := time.Since(startTime).Round(time.Millisecond) - logger.WithField("elapsed_time", elapsedTime).Info(message) +type config struct { + skipRoute bool + configPath string + timeout int } -*/ func main() { log.SetLevel(log.DebugLevel) + cfg := &config{} + flag.BoolVar(&cfg.skipRoute, "skip-route", false, "if true, exists without setting routes (for testing)") + flag.StringVar(&cfg.configPath, "config", "", "config file to load") + flag.IntVar(&cfg.timeout, "timeout", 60, "timeout in seconds (default=60)") + flag.Parse() + + if cfg.configPath == "" { + fmt.Println("[error] need config path") + os.Exit(1) + } + // parse the configuration file - options, err := model.ReadConfigFile(os.Args[1]) + options, err := model.ReadConfigFile(cfg.configPath) if err != nil { log.WithError(err).Fatal("NewOptionsFromFilePath") } @@ -73,8 +85,8 @@ func main() { // The TLS will expire in 60 seconds by default, but we can pass // a shorter timeout. - //ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - //defer cancel() + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(cfg.timeout)*time.Second) + defer cancel() // create a vpn tun Device tunnel, err := tun.StartTUN(ctx, conn, options, log.Log) @@ -82,8 +94,14 @@ func main() { log.WithError(err).Fatal("init error") return } - fmt.Printf("Local IP: %s\n", tunnel.LocalAddr()) - fmt.Printf("Gateway: %s\n", tunnel.RemoteAddr()) + log.Infof("Local IP: %s\n", tunnel.LocalAddr()) + log.Infof("Gateway: %s\n", tunnel.RemoteAddr()) + + fmt.Println("initialization-sequence-completed") + + if cfg.skipRoute { + os.Exit(0) + } // create a tun interface on the OS iface, err := water.New(water.Config{ From 0e074d1dabb8e59982787ed94f63a595d92d61de Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Thu, 25 Jan 2024 19:49:44 +0100 Subject: [PATCH 17/78] defend if data before keys --- internal/packetmuxer/service.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/internal/packetmuxer/service.go b/internal/packetmuxer/service.go index 768d4d5f..13b014bf 100644 --- a/internal/packetmuxer/service.go +++ b/internal/packetmuxer/service.go @@ -2,6 +2,7 @@ package packetmuxer import ( + "errors" "fmt" "time" @@ -232,13 +233,16 @@ func (ws *workersState) handleRawPacket(rawPacket []byte) error { return workers.ErrShutdown } } else { + if ws.sessionManager.NegotiationState() < session.S_GENERATED_KEYS { + return errors.New("not ready to handle data") + } select { case ws.muxerToData <- packet: case <-ws.workersManager.ShouldShutdown(): return workers.ErrShutdown - // TODO ----------------- temporary: do we get spurious data packets during hadnshake from previous sessions ------------------ - default: - ws.logger.Warnf("%s: moveUpWorker.handleRawPacket: dropped data packet", serviceName) + // TODO ----------------- temporary: do we get spurious data packets during hadnshake from previous sessions ------------------ + //default: + //ws.logger.Warnf("%s: moveUpWorker.handleRawPacket: dropped data packet", serviceName) // TODO --------------------------------------------------------------------- } } From 2c6c941f80eb953571ed6f915cb884d130ce7704 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Mon, 29 Jan 2024 17:17:11 +0100 Subject: [PATCH 18/78] improve comment --- internal/packetmuxer/service.go | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/internal/packetmuxer/service.go b/internal/packetmuxer/service.go index 13b014bf..ba32c901 100644 --- a/internal/packetmuxer/service.go +++ b/internal/packetmuxer/service.go @@ -75,6 +75,9 @@ type workersState struct { // hardReset is the channel posted to force a hard reset. hardReset <-chan any + // how many times have we sent the initial hardReset packet + hardResetCount int + // hardResetTicker is a channel to retry the initial send of hard reset packet. hardResetTicker *time.Ticker @@ -182,8 +185,16 @@ func (ws *workersState) moveDownWorker() { // startHardReset is invoked when we need to perform a HARD RESET. func (ws *workersState) startHardReset() error { + ws.hardResetCount += 1 + // emit a CONTROL_HARD_RESET_CLIENT_V2 pkt - packet, err := ws.sessionManager.NewPacket(model.P_CONTROL_HARD_RESET_CLIENT_V2, nil) + // packet, err := ws.sessionManager.NewPacket(model.P_CONTROL_HARD_RESET_CLIENT_V2, nil) + first := false + if ws.hardResetCount == 1 { + first = true + } + + packet, err := ws.sessionManager.NewHardResetPacket(first) if err != nil { ws.logger.Warnf("packetmuxer: NewPacket: %s", err.Error()) return err @@ -220,11 +231,6 @@ func (ws *workersState) handleRawPacket(rawPacket []byte) error { return ws.finishThreeWayHandshake(packet) } - // TODO: introduce other sanity checks here - // TODO *** - // TODO: make sure we're not blocking on delivering data packets up (from old sessions) - // TODO *** - // multiplex the incoming packet POSSIBLY BLOCKING on delivering it if packet.IsControl() || packet.Opcode == model.P_ACK_V1 { select { @@ -240,10 +246,12 @@ func (ws *workersState) handleRawPacket(rawPacket []byte) error { case ws.muxerToData <- packet: case <-ws.workersManager.ShouldShutdown(): return workers.ErrShutdown - // TODO ----------------- temporary: do we get spurious data packets during hadnshake from previous sessions ------------------ - //default: - //ws.logger.Warnf("%s: moveUpWorker.handleRawPacket: dropped data packet", serviceName) - // TODO --------------------------------------------------------------------- + // TODO: make sure we're not blocking on delivering data packets + // TODO(ainghazal): afaik, a well-behaved server will not send us data packets + // before we have a working session. Under normal operations, the + // UDP connection in the client side should pick a different port, + // so that data sent from previous sessions will not be delivered. + // However, it might not harm to be defensive here. } } From b512ad932f63854e3496d3cb427882422f2b8774 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Mon, 29 Jan 2024 17:19:21 +0100 Subject: [PATCH 19/78] comment on newHardResetPacket --- internal/session/manager.go | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/internal/session/manager.go b/internal/session/manager.go index 2b586205..d86b64b7 100644 --- a/internal/session/manager.go +++ b/internal/session/manager.go @@ -208,6 +208,26 @@ func (m *Manager) NewPacket(opcode model.Opcode, payload []byte) (*model.Packet, return packet, nil } +// NewHardResetPacket creates a new hard reset packet for this session. +// This packet is a special case because, if we resend, we must not bump +// its packet ID. Normally retransmission is handled at the reliabletransport layer, +// but we send hard resets at the muxer. +func (m *Manager) NewHardResetPacket(first bool) (*model.Packet, error) { + packet := model.NewPacket( + model.P_CONTROL_HARD_RESET_CLIENT_V2, + m.keyID, + []byte{}, + ) + if first { + pid, _ := m.localControlPacketIDLocked() + packet.ID = pid + } else { + packet.ID = 0 + } + copy(packet.LocalSessionID[:], m.localSessionID[:]) + return packet, nil +} + var ErrExpiredKey = errors.New("expired key") // LocalDataPacketID returns an unique Packet ID for the Data Channel. It From 690762ed4ce424c468d7eadb4f217ea97eac6e7a Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Mon, 29 Jan 2024 17:22:36 +0100 Subject: [PATCH 20/78] comments --- internal/reliabletransport/receiver.go | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/internal/reliabletransport/receiver.go b/internal/reliabletransport/receiver.go index 42b2de26..800e7bf7 100644 --- a/internal/reliabletransport/receiver.go +++ b/internal/reliabletransport/receiver.go @@ -3,6 +3,7 @@ package reliabletransport import ( "bytes" "fmt" + "log" "sort" "github.com/ooni/minivpn/internal/model" @@ -35,19 +36,13 @@ func (ws *workersState) moveUpWorker() { packet.Log(ws.logger, model.DirectionIncoming) } - /* - fmt.Println(">> packet session:", packet.LocalSessionID) - fmt.Println(">> our session:", ws.sessionManager.RemoteSessionID()) - */ - - // ------ - // FIXME: do we need to act upon a HARD_RESET_V2 while we're doing a handshake? - // ------ - fmt.Printf("%s session check: %v\n", packet.Opcode, bytes.Equal(packet.LocalSessionID[:], ws.sessionManager.RemoteSessionID())) + // TODO: are we handling a HARD_RESET_V2 while we're doing a handshake? + // I'm not sure that's a valid behavior for a server. + // We should be able to deterministically test how this affects the state machine. + log.Printf("%s session check: %v\n", packet.Opcode, bytes.Equal(packet.LocalSessionID[:], ws.sessionManager.RemoteSessionID())) // drop a packet that is not for our session if !bytes.Equal(packet.LocalSessionID[:], ws.sessionManager.RemoteSessionID()) { - fmt.Println(">> not our session!!!") ws.logger.Warnf( "%s: packet with invalid RemoteSessionID: expected %x; got %x", workerName, From 675fd8f54c876dae6843726c3b937adeea66eeaf Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Mon, 29 Jan 2024 17:23:05 +0100 Subject: [PATCH 21/78] add elapsed time for benchmarking --- cmd/minivpn2/main.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cmd/minivpn2/main.go b/cmd/minivpn2/main.go index 9b624ca8..751ce740 100644 --- a/cmd/minivpn2/main.go +++ b/cmd/minivpn2/main.go @@ -72,6 +72,8 @@ func main() { log.SetHandler(NewHandler(os.Stderr)) log.SetLevel(log.DebugLevel) + start := time.Now() + // connect to the server dialer := networkio.NewDialer(log.Log, &net.Dialer{}) ctx := context.Background() @@ -98,6 +100,7 @@ func main() { log.Infof("Gateway: %s\n", tunnel.RemoteAddr()) fmt.Println("initialization-sequence-completed") + fmt.Printf("elapsed: %v\n", time.Since(start)) if cfg.skipRoute { os.Exit(0) From 5b4d2eb83d745e451bd4251e9b8458e7c024bf13 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Mon, 29 Jan 2024 17:31:11 +0100 Subject: [PATCH 22/78] log --- internal/model/packet.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/internal/model/packet.go b/internal/model/packet.go index fd59b63c..210e0f60 100644 --- a/internal/model/packet.go +++ b/internal/model/packet.go @@ -312,6 +312,7 @@ const ( DirectionOutgoing ) +// Log writes an entry in the passed logger with a representation of this packet. func (p *Packet) Log(logger Logger, direction int) { var dir string switch direction { @@ -324,7 +325,7 @@ func (p *Packet) Log(logger Logger, direction int) { return } - logger.Infof( + logger.Debugf( "%s %s {id=%d, acks=%v} localID=%x remoteID=%x [%d bytes]", dir, p.Opcode, From 8c5816cf0d8266567a33375ae800e0a06b534579 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Mon, 29 Jan 2024 17:37:52 +0100 Subject: [PATCH 23/78] add doc.go --- internal/reliabletransport/doc.go | 5 +++++ internal/reliabletransport/packets.go | 7 +++---- 2 files changed, 8 insertions(+), 4 deletions(-) create mode 100644 internal/reliabletransport/doc.go diff --git a/internal/reliabletransport/doc.go b/internal/reliabletransport/doc.go new file mode 100644 index 00000000..2ba1edf9 --- /dev/null +++ b/internal/reliabletransport/doc.go @@ -0,0 +1,5 @@ +// Package reliabletransport implements the reliable transport. +// A note about terminology: in this package, "receiver" is the moveUpWorker in the [reliabletransport.Service] (since it receives incoming packets), and +// "sender" is the moveDownWorker in the same service. The corresponding data structures lack mutexes because they are intended to be confined to a single +// goroutine (one for each worker), and they SHOULD ONLY communicate via message passing. +package reliabletransport diff --git a/internal/reliabletransport/packets.go b/internal/reliabletransport/packets.go index b4d392ad..3b3341bc 100644 --- a/internal/reliabletransport/packets.go +++ b/internal/reliabletransport/packets.go @@ -8,10 +8,8 @@ import ( "github.com/ooni/minivpn/internal/optional" ) -// -// A note about terminology: in the following, **receiver** is the moveUpWorker in the [reliabletransport.Service] (since it receives incoming packets), and **sender** is the moveDownWorker in the same service. The following data structures lack mutexes because they are intended to be confined to a single goroutine (one for each worker), and they only communicate via message passing. -// - +// inFlighPacket is an implementation of inFlighter. It is a sequential packet +// that can be scheduled for retransmission. type inFlightPacket struct { // deadline is a moment in time when is this packet scheduled for the next retransmission. deadline time.Time @@ -59,6 +57,7 @@ func (p *inFlightPacket) backoff() time.Duration { return backoff } +// TODO: revisit interfaces while writing tests. // assert that inFlightWrappedPacket implements inFlightPacket and sequentialPacket // var _ inFlightPacket = &inFlightWrappedPacket{} // var _ sequentialPacket = &inFlightWrappedPacket{} From 0c37284b65a7e9ba9ab5ef8eeb37cac9b34b6691 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Mon, 29 Jan 2024 17:38:53 +0100 Subject: [PATCH 24/78] x --- internal/reliabletransport/packets.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/reliabletransport/packets.go b/internal/reliabletransport/packets.go index 3b3341bc..194b3e08 100644 --- a/internal/reliabletransport/packets.go +++ b/internal/reliabletransport/packets.go @@ -115,8 +115,8 @@ func (seq inflightSequence) Less(i, j int) bool { return seq[i].packet.ID < seq[j].packet.ID } -// A incomingSequence is an array of sequentialPackets. It's used to store both incoming and outgoing packet queues. -// a incomingSequence can be sorted. +// An incomingSequence is an array of sequentialPackets. It's used to store both incoming and outgoing packet queues. +// An incomingSequence can be sorted. type incomingSequence []sequentialPacket // implement sort.Interface From 2fb6a3c4aa23f8ddc8267c579e9af3f2cc8d45aa Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Mon, 29 Jan 2024 17:42:10 +0100 Subject: [PATCH 25/78] remove unused code --- internal/packetmuxer/service.go | 1 - 1 file changed, 1 deletion(-) diff --git a/internal/packetmuxer/service.go b/internal/packetmuxer/service.go index ba32c901..3037832b 100644 --- a/internal/packetmuxer/service.go +++ b/internal/packetmuxer/service.go @@ -188,7 +188,6 @@ func (ws *workersState) startHardReset() error { ws.hardResetCount += 1 // emit a CONTROL_HARD_RESET_CLIENT_V2 pkt - // packet, err := ws.sessionManager.NewPacket(model.P_CONTROL_HARD_RESET_CLIENT_V2, nil) first := false if ws.hardResetCount == 1 { first = true From 80adb9565f309521cb740e201731b7717bc91137 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Mon, 29 Jan 2024 17:44:36 +0100 Subject: [PATCH 26/78] move comment location --- internal/packetmuxer/service.go | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/internal/packetmuxer/service.go b/internal/packetmuxer/service.go index 3037832b..33052160 100644 --- a/internal/packetmuxer/service.go +++ b/internal/packetmuxer/service.go @@ -239,18 +239,17 @@ func (ws *workersState) handleRawPacket(rawPacket []byte) error { } } else { if ws.sessionManager.NegotiationState() < session.S_GENERATED_KEYS { + // A well-behaved server should not send us data packets + // before we have a working session. Under normal operations, the + // connection in the client side should pick a different port, + // so that data sent from previous sessions will not be delivered. + // However, it does not harm to be defensive here. return errors.New("not ready to handle data") } select { case ws.muxerToData <- packet: case <-ws.workersManager.ShouldShutdown(): return workers.ErrShutdown - // TODO: make sure we're not blocking on delivering data packets - // TODO(ainghazal): afaik, a well-behaved server will not send us data packets - // before we have a working session. Under normal operations, the - // UDP connection in the client side should pick a different port, - // so that data sent from previous sessions will not be delivered. - // However, it might not harm to be defensive here. } } From ba94049feae301260bd2ce60191febe97655d043 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Mon, 29 Jan 2024 17:46:31 +0100 Subject: [PATCH 27/78] x --- internal/reliabletransport/packets.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/reliabletransport/packets.go b/internal/reliabletransport/packets.go index 194b3e08..0c6b5486 100644 --- a/internal/reliabletransport/packets.go +++ b/internal/reliabletransport/packets.go @@ -84,7 +84,7 @@ func (seq inflightSequence) nearestDeadlineTo(t time.Time) time.Time { return timeout } -// readyToSend eturns the subset of this sequence that has a expired deadline or +// readyToSend returns the subset of this sequence that has a expired deadline or // is suitable for fast retransmission. func (seq inflightSequence) readyToSend(t time.Time) inflightSequence { expired := make([]*inFlightPacket, 0) From c941f02dc549cab0820a704e6d221ffdd0d68541 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Mon, 29 Jan 2024 17:48:54 +0100 Subject: [PATCH 28/78] remove commented code --- internal/reliabletransport/receiver.go | 8 -------- 1 file changed, 8 deletions(-) diff --git a/internal/reliabletransport/receiver.go b/internal/reliabletransport/receiver.go index 800e7bf7..84053d8c 100644 --- a/internal/reliabletransport/receiver.go +++ b/internal/reliabletransport/receiver.go @@ -61,14 +61,6 @@ func (ws *workersState) moveUpWorker() { continue } - // I think this check is not helping -- ain - /* - if packet.ID < receiver.lastConsumed { - ws.logger.Warnf("%s: received %d but last consumed was %d", workerName, packet.ID, receiver.lastConsumed) - continue - } - */ - if inserted := receiver.MaybeInsertIncoming(packet); !inserted { // this packet was not inserted in the queue: we drop it continue From 1bc23611ae6c88b6ae5f02ac60eaca11c0f4e0e0 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Mon, 29 Jan 2024 17:54:20 +0100 Subject: [PATCH 29/78] add link in docs --- internal/reliabletransport/doc.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/internal/reliabletransport/doc.go b/internal/reliabletransport/doc.go index 2ba1edf9..f5e3e966 100644 --- a/internal/reliabletransport/doc.go +++ b/internal/reliabletransport/doc.go @@ -1,4 +1,9 @@ -// Package reliabletransport implements the reliable transport. +// Package reliabletransport implements the reliable transport module for OpenVPN. +// See [the official documentation](https://community.openvpn.net/openvpn/wiki/SecurityOverview) for a detail explanation +// of why this is needed, and how it relates to the requirements of the control channel. +// It is worth to mention that, even though the original need is to have a reliable control channel +// on top of UDP, this is also used when tunneling over TCP. +// // A note about terminology: in this package, "receiver" is the moveUpWorker in the [reliabletransport.Service] (since it receives incoming packets), and // "sender" is the moveDownWorker in the same service. The corresponding data structures lack mutexes because they are intended to be confined to a single // goroutine (one for each worker), and they SHOULD ONLY communicate via message passing. From 20f89e2a43d90468c5f950c65e4d421748a7c913 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Mon, 29 Jan 2024 17:57:54 +0100 Subject: [PATCH 30/78] rename in test --- internal/reliabletransport/sender_test.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/internal/reliabletransport/sender_test.go b/internal/reliabletransport/sender_test.go index 20b94e18..e13ed9ed 100644 --- a/internal/reliabletransport/sender_test.go +++ b/internal/reliabletransport/sender_test.go @@ -9,10 +9,10 @@ import ( ) // -// tests for reliableOutgoing +// tests for reliableSender // -func Test_reliableOutgoing_TryInsertOutgoingPacket(t *testing.T) { +func Test_reliableSender_TryInsertOutgoingPacket(t *testing.T) { log.SetLevel(log.DebugLevel) type fields struct { @@ -68,13 +68,13 @@ func Test_reliableOutgoing_TryInsertOutgoingPacket(t *testing.T) { inFlight: tt.fields.inFlight, } if got := r.TryInsertOutgoingPacket(tt.args.p); got != tt.want { - t.Errorf("reliableOutgoing.TryInsertOutgoingPacket() = %v, want %v", got, tt.want) + t.Errorf("reliableSender.TryInsertOutgoingPacket() = %v, want %v", got, tt.want) } }) } } -func Test_reliableOutgoing_NextPacketIDsToACK(t *testing.T) { +func Test_reliableSender_NextPacketIDsToACK(t *testing.T) { log.SetLevel(log.DebugLevel) type fields struct { @@ -121,14 +121,14 @@ func Test_reliableOutgoing_NextPacketIDsToACK(t *testing.T) { pendingACKsToSend: tt.fields.pendingACKsToSend, } if got := r.NextPacketIDsToACK(); !reflect.DeepEqual(got, tt.want) { - t.Errorf("reliableOutgoing.NextPacketIDsToACK() = %v, want %v", got, tt.want) + t.Errorf("reliableSender.NextPacketIDsToACK() = %v, want %v", got, tt.want) } }) } } // -// tests for reliableIncoming +// tests for reliableReceiver // // testIncomingPacket is a sequentialPacket for testing incomingPackets From 9f751fb48833336682af5805af30e1b02beb7eea Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Mon, 29 Jan 2024 18:08:30 +0100 Subject: [PATCH 31/78] remove binary --- cmd/minivpn2/iface.go | 30 ------- cmd/minivpn2/log.go | 79 ------------------- cmd/minivpn2/main.go | 178 ------------------------------------------ 3 files changed, 287 deletions(-) delete mode 100644 cmd/minivpn2/iface.go delete mode 100644 cmd/minivpn2/log.go delete mode 100644 cmd/minivpn2/main.go diff --git a/cmd/minivpn2/iface.go b/cmd/minivpn2/iface.go deleted file mode 100644 index 4102d938..00000000 --- a/cmd/minivpn2/iface.go +++ /dev/null @@ -1,30 +0,0 @@ -package main - -import ( - "fmt" - "net" -) - -func getInterfaceByIP(ipAddr string) (*net.Interface, error) { - interfaces, err := net.Interfaces() - if err != nil { - return nil, err - } - - for _, iface := range interfaces { - addrs, err := iface.Addrs() - if err != nil { - return nil, err - } - - for _, addr := range addrs { - if ipNet, ok := addr.(*net.IPNet); ok && !ipNet.IP.IsLoopback() { - if ipNet.IP.String() == ipAddr { - return &iface, nil - } - } - } - } - - return nil, fmt.Errorf("interface with IP %s not found", ipAddr) -} diff --git a/cmd/minivpn2/log.go b/cmd/minivpn2/log.go deleted file mode 100644 index bf4a5c22..00000000 --- a/cmd/minivpn2/log.go +++ /dev/null @@ -1,79 +0,0 @@ -package main - -import ( - "fmt" - "io" - "os" - "sync" - "time" - - "github.com/apex/log" -) - -// Default handler outputting to stderr. -var Default = NewHandler(os.Stderr) - -// start time. -var start = time.Now() - -// colors. -const ( - none = 0 - red = 31 - green = 32 - yellow = 33 - blue = 34 - gray = 37 -) - -// Colors mapping. -var Colors = [...]int{ - log.DebugLevel: gray, - log.InfoLevel: blue, - log.WarnLevel: yellow, - log.ErrorLevel: red, - log.FatalLevel: red, -} - -// Strings mapping. -var Strings = [...]string{ - log.DebugLevel: "DEBUG", - log.InfoLevel: "INFO", - log.WarnLevel: "WARN", - log.ErrorLevel: "ERROR", - log.FatalLevel: "FATAL", -} - -// Handler implementation. -type Handler struct { - mu sync.Mutex - Writer io.Writer -} - -// New handler. -func NewHandler(w io.Writer) *Handler { - return &Handler{ - Writer: w, - } -} - -// HandleLog implements log.Handler. -func (h *Handler) HandleLog(e *log.Entry) error { - color := Colors[e.Level] - level := Strings[e.Level] - names := e.Fields.Names() - - h.mu.Lock() - defer h.mu.Unlock() - - ts := time.Since(start) - fmt.Fprintf(h.Writer, "\033[%dm%6s\033[0m[%10v] %-25s", color, level, ts, e.Message) - - for _, name := range names { - fmt.Fprintf(h.Writer, " \033[%dm%s\033[0m=%v", color, name, e.Fields.Get(name)) - } - - fmt.Fprintln(h.Writer) - - return nil -} diff --git a/cmd/minivpn2/main.go b/cmd/minivpn2/main.go deleted file mode 100644 index 751ce740..00000000 --- a/cmd/minivpn2/main.go +++ /dev/null @@ -1,178 +0,0 @@ -package main - -import ( - "context" - "flag" - "fmt" - "net" - "os" - "os/exec" - "time" - - "github.com/Doridian/water" - "github.com/apex/log" - "github.com/jackpal/gateway" - - "github.com/ooni/minivpn/internal/model" - "github.com/ooni/minivpn/internal/networkio" - "github.com/ooni/minivpn/internal/tun" -) - -func runCmd(binaryPath string, args ...string) { - cmd := exec.Command(binaryPath, args...) - cmd.Stderr = os.Stderr - cmd.Stdout = os.Stdout - cmd.Stdin = os.Stdin - err := cmd.Run() - if nil != err { - log.WithError(err).Warn("error running /sbin/ip") - } -} - -func runIP(args ...string) { - runCmd("/sbin/ip", args...) -} - -func runRoute(args ...string) { - runCmd("/sbin/route", args...) -} - -type config struct { - skipRoute bool - configPath string - timeout int -} - -func main() { - log.SetLevel(log.DebugLevel) - - cfg := &config{} - flag.BoolVar(&cfg.skipRoute, "skip-route", false, "if true, exists without setting routes (for testing)") - flag.StringVar(&cfg.configPath, "config", "", "config file to load") - flag.IntVar(&cfg.timeout, "timeout", 60, "timeout in seconds (default=60)") - flag.Parse() - - if cfg.configPath == "" { - fmt.Println("[error] need config path") - os.Exit(1) - } - - // parse the configuration file - options, err := model.ReadConfigFile(cfg.configPath) - if err != nil { - log.WithError(err).Fatal("NewOptionsFromFilePath") - } - log.Infof("parsed options: %s", options.ServerOptionsString()) - - // TODO(ainghazal): move the initialization step to an early phase and keep a ref in the muxer - if !options.HasAuthInfo() { - log.Fatal("options are missing auth info") - } - - log.SetHandler(NewHandler(os.Stderr)) - log.SetLevel(log.DebugLevel) - - start := time.Now() - - // connect to the server - dialer := networkio.NewDialer(log.Log, &net.Dialer{}) - ctx := context.Background() - - endpoint := net.JoinHostPort(options.Remote, options.Port) - - conn, err := dialer.DialContext(ctx, options.Proto.String(), endpoint) - if err != nil { - log.WithError(err).Fatal("dialer.DialContext") - } - - // The TLS will expire in 60 seconds by default, but we can pass - // a shorter timeout. - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(cfg.timeout)*time.Second) - defer cancel() - - // create a vpn tun Device - tunnel, err := tun.StartTUN(ctx, conn, options, log.Log) - if err != nil { - log.WithError(err).Fatal("init error") - return - } - log.Infof("Local IP: %s\n", tunnel.LocalAddr()) - log.Infof("Gateway: %s\n", tunnel.RemoteAddr()) - - fmt.Println("initialization-sequence-completed") - fmt.Printf("elapsed: %v\n", time.Since(start)) - - if cfg.skipRoute { - os.Exit(0) - } - - // create a tun interface on the OS - iface, err := water.New(water.Config{ - DeviceType: water.TUN, - }) - if err != nil { - log.WithError(err).Fatal("Unable to allocate TUN interface:") - } - - // TODO: investigate what's the maximum working MTU, additionally get it from flag. - MTU := 1420 - iface.SetMTU(MTU) - - localAddr := tunnel.LocalAddr().String() - remoteAddr := tunnel.RemoteAddr().String() - netMask := tunnel.NetMask() - - // discover local gateway IP, we need it to add a route to our remote via our network gw - defaultGatewayIP, err := gateway.DiscoverGateway() - if err != nil { - log.Warn("could not discover default gateway IP, routes might be broken") - } - defaultInterfaceIP, err := gateway.DiscoverInterface() - if err != nil { - log.Warn("could not discover default route interface IP, routes might be broken") - } - defaultInterface, err := getInterfaceByIP(defaultInterfaceIP.String()) - if err != nil { - log.Warn("could not get default route interface, routes might be broken") - } - - if defaultGatewayIP != nil && defaultInterface != nil { - log.Infof("route add %s gw %v dev %s", options.Remote, defaultGatewayIP, defaultInterface.Name) - runRoute("add", options.Remote, "gw", defaultGatewayIP.String(), defaultInterface.Name) - } - - // we want the network CIDR for setting up the routes - network := &net.IPNet{ - IP: net.ParseIP(localAddr).Mask(netMask), - Mask: netMask, - } - - // configure the interface and bring it up - runIP("addr", "add", localAddr, "dev", iface.Name()) - runIP("link", "set", "dev", iface.Name(), "up") - runRoute("add", remoteAddr, "gw", localAddr) - runRoute("add", "-net", network.String(), "dev", iface.Name()) - runIP("route", "add", "default", "via", remoteAddr, "dev", iface.Name()) - - go func() { - for { - packet := make([]byte, 2000) - n, err := iface.Read(packet) - if err != nil { - log.WithError(err).Fatal("error reading from tun") - } - tunnel.Write(packet[:n]) - } - }() - go func() { - for { - packet := make([]byte, 2000) - n, err := tunnel.Read(packet) - if err != nil { - log.WithError(err).Fatal("error reading from tun") - } - iface.Write(packet[:n]) - } - }() - select {} -} From f30db744b890b925fa4cbd2af837d3aee7cd7fdc Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Mon, 29 Jan 2024 18:10:46 +0100 Subject: [PATCH 32/78] revert Makefile --- Makefile | 3 --- 1 file changed, 3 deletions(-) diff --git a/Makefile b/Makefile index 74094e95..eaad788c 100644 --- a/Makefile +++ b/Makefile @@ -27,9 +27,6 @@ build-ndt7: bootstrap: @./scripts/bootstrap-provider ${PROVIDER} -handshake_log: - @sudo ./minivpn2 data/${PROVIDER}/config 2>&1 | grep --text --color=auto -E "@|P_\w+" - test: GOFLAGS='-count=1' go test -v ./... From 582ef2004a1fa4c2c76c2078702b972db5b48df8 Mon Sep 17 00:00:00 2001 From: Ain Ghazal <99027643+ainghazal@users.noreply.github.com> Date: Tue, 30 Jan 2024 16:25:41 +0100 Subject: [PATCH 33/78] Update internal/packetmuxer/service.go Co-authored-by: Simone Basso --- internal/packetmuxer/service.go | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/internal/packetmuxer/service.go b/internal/packetmuxer/service.go index 33052160..ae429efe 100644 --- a/internal/packetmuxer/service.go +++ b/internal/packetmuxer/service.go @@ -188,10 +188,7 @@ func (ws *workersState) startHardReset() error { ws.hardResetCount += 1 // emit a CONTROL_HARD_RESET_CLIENT_V2 pkt - first := false - if ws.hardResetCount == 1 { - first = true - } + first := ws.hardResetCount == 1 packet, err := ws.sessionManager.NewHardResetPacket(first) if err != nil { From dcdffec53c7ea52f6a0c8e8937c6e03c8294650b Mon Sep 17 00:00:00 2001 From: Ain Ghazal <99027643+ainghazal@users.noreply.github.com> Date: Tue, 30 Jan 2024 16:26:05 +0100 Subject: [PATCH 34/78] Update internal/reliabletransport/doc.go Co-authored-by: Simone Basso --- internal/reliabletransport/doc.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/reliabletransport/doc.go b/internal/reliabletransport/doc.go index f5e3e966..3857d19f 100644 --- a/internal/reliabletransport/doc.go +++ b/internal/reliabletransport/doc.go @@ -1,5 +1,5 @@ // Package reliabletransport implements the reliable transport module for OpenVPN. -// See [the official documentation](https://community.openvpn.net/openvpn/wiki/SecurityOverview) for a detail explanation +// See [the official documentation](https://community.openvpn.net/openvpn/wiki/SecurityOverview) for a detailed explanation // of why this is needed, and how it relates to the requirements of the control channel. // It is worth to mention that, even though the original need is to have a reliable control channel // on top of UDP, this is also used when tunneling over TCP. From 286b1366b38b0b799b079bf3baea7fe6f0c847b5 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Tue, 30 Jan 2024 16:27:33 +0100 Subject: [PATCH 35/78] rename --- internal/reliabletransport/{interfaces.go => model.go} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename internal/reliabletransport/{interfaces.go => model.go} (100%) diff --git a/internal/reliabletransport/interfaces.go b/internal/reliabletransport/model.go similarity index 100% rename from internal/reliabletransport/interfaces.go rename to internal/reliabletransport/model.go From 1fbbee545c10faf177dd9b7017f945c12eb788ff Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Tue, 30 Jan 2024 16:28:58 +0100 Subject: [PATCH 36/78] apply suggestion --- internal/reliabletransport/model.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/reliabletransport/model.go b/internal/reliabletransport/model.go index 85f1fb3a..a4cd8aa1 100644 --- a/internal/reliabletransport/model.go +++ b/internal/reliabletransport/model.go @@ -19,8 +19,8 @@ type inFlighter interface { // outgoingPacketHandler has methods to deal with the outgoing packets (going down). type outgoingPacketHandler interface { - // TryInsertOutgoingPacket attempts to insert a packet into the queue. If return value is - // false, insertion was not successful. + // TryInsertOutgoingPacket attempts to insert a packet into the + // inflight queue. If return value is false, insertion was not successful. TryInsertOutgoingPacket(*model.Packet) bool // MaybeEvictOrBumpPacketAfterACK removes a packet (that we received an ack for) from the in-flight packet queue. From 2a15f2def28c8f92723f240b76338a0ce0f05f60 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Tue, 30 Jan 2024 16:35:08 +0100 Subject: [PATCH 37/78] apply suggestions --- internal/reliabletransport/model.go | 5 +++-- internal/reliabletransport/packets.go | 3 ++- internal/reliabletransport/receiver.go | 4 ++-- internal/reliabletransport/sender.go | 2 +- internal/reliabletransport/service.go | 1 - internal/tun/tun.go | 3 +-- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/internal/reliabletransport/model.go b/internal/reliabletransport/model.go index a4cd8aa1..659abba4 100644 --- a/internal/reliabletransport/model.go +++ b/internal/reliabletransport/model.go @@ -20,14 +20,15 @@ type inFlighter interface { // outgoingPacketHandler has methods to deal with the outgoing packets (going down). type outgoingPacketHandler interface { // TryInsertOutgoingPacket attempts to insert a packet into the - // inflight queue. If return value is false, insertion was not successful. + // inflight queue. If return value is false, insertion was not successful (e.g., too many + // packets in flight). TryInsertOutgoingPacket(*model.Packet) bool // MaybeEvictOrBumpPacketAfterACK removes a packet (that we received an ack for) from the in-flight packet queue. MaybeEvictOrBumpPacketAfterACK(id model.PacketID) bool // NextPacketIDsToACK returns an array of pending IDs to ACK to - // our remote. The lenght of this array SHOULD not be larger than CONTROL_SEND_ACK_MAX. + // our remote. The length of this array SHOULD not be larger than CONTROL_SEND_ACK_MAX. // This is used to append it to the ACK array of the outgoing packet. NextPacketIDsToACK() []model.PacketID diff --git a/internal/reliabletransport/packets.go b/internal/reliabletransport/packets.go index 0c6b5486..692a5ebc 100644 --- a/internal/reliabletransport/packets.go +++ b/internal/reliabletransport/packets.go @@ -93,7 +93,8 @@ func (seq inflightSequence) readyToSend(t time.Time) inflightSequence { fmt.Println("DEBUG: fast retransmit for", p.packet.ID) expired = append(expired, p) continue - } else if p.deadline.Before(t) { + } + if p.deadline.Before(t) { expired = append(expired, p) } } diff --git a/internal/reliabletransport/receiver.go b/internal/reliabletransport/receiver.go index 84053d8c..67cafde3 100644 --- a/internal/reliabletransport/receiver.go +++ b/internal/reliabletransport/receiver.go @@ -32,7 +32,7 @@ func (ws *workersState) moveUpWorker() { select { case packet := <-ws.muxerToReliable: if packet.Opcode != model.P_CONTROL_HARD_RESET_SERVER_V2 { - // the hard reset we logged in below + // the hard reset has already been logged by the layer below packet.Log(ws.logger, model.DirectionIncoming) } @@ -130,7 +130,7 @@ func (r *reliableReceiver) NextIncomingSequence() incomingSequence { // sort them so that we begin with lower model.PacketID sort.Sort(r.incomingPackets) - keep := r.incomingPackets[:0] + var keep incomingSequence for i, p := range r.incomingPackets { if p.ID()-last == 1 { diff --git a/internal/reliabletransport/sender.go b/internal/reliabletransport/sender.go index 439137a6..2b6fc7d1 100644 --- a/internal/reliabletransport/sender.go +++ b/internal/reliabletransport/sender.go @@ -28,7 +28,7 @@ func (ws *workersState) moveDownWorker() { select { case packet := <-ws.controlToReliable: - // try to insert and schedule for inmediate wakeup + // try to insert and schedule for immediate wakeup if inserted := sender.TryInsertOutgoingPacket(packet); inserted { ticker.Reset(time.Nanosecond) } diff --git a/internal/reliabletransport/service.go b/internal/reliabletransport/service.go index e67c1f5c..14a5ab0b 100644 --- a/internal/reliabletransport/service.go +++ b/internal/reliabletransport/service.go @@ -36,7 +36,6 @@ func (s *Service) StartWorkers( workersManager *workers.Manager, sessionManager *session.Manager, ) { - ws := &workersState{ logger: logger, incomingSeen: make(chan incomingPacketSeen, 20), diff --git a/internal/tun/tun.go b/internal/tun/tun.go index 5e0c2bd4..85d8cbab 100644 --- a/internal/tun/tun.go +++ b/internal/tun/tun.go @@ -23,9 +23,8 @@ var ( // StartTUN initializes and starts the TUN device over the vpn. // If the passed context expires before the TUN device is ready, func StartTUN(ctx context.Context, conn networkio.FramingConn, options *model.Options, logger model.Logger) (*TUN, error) { - // be useful if passing an empty logger if logger == nil { - logger = log.Log + return nil, errors.New("logger cannot be nil") } // create a session From 09fcc51b75c447e3067f60bfbc01c32e4d7c7087 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Tue, 30 Jan 2024 16:37:57 +0100 Subject: [PATCH 38/78] terminology --- internal/reliabletransport/model.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/reliabletransport/model.go b/internal/reliabletransport/model.go index 659abba4..1591b60e 100644 --- a/internal/reliabletransport/model.go +++ b/internal/reliabletransport/model.go @@ -28,7 +28,7 @@ type outgoingPacketHandler interface { MaybeEvictOrBumpPacketAfterACK(id model.PacketID) bool // NextPacketIDsToACK returns an array of pending IDs to ACK to - // our remote. The length of this array SHOULD not be larger than CONTROL_SEND_ACK_MAX. + // our remote. The length of this array MUST NOT be larger than CONTROL_SEND_ACK_MAX. // This is used to append it to the ACK array of the outgoing packet. NextPacketIDsToACK() []model.PacketID From e50e4105bd3fce901cb4012854b6a47bea3b97d1 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Tue, 30 Jan 2024 17:16:22 +0100 Subject: [PATCH 39/78] fix docs --- internal/reliabletransport/doc.go | 4 ---- internal/reliabletransport/receiver.go | 5 ++++- internal/reliabletransport/sender.go | 4 +++- internal/reliabletransport/service.go | 5 +++-- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/internal/reliabletransport/doc.go b/internal/reliabletransport/doc.go index 3857d19f..a01414f2 100644 --- a/internal/reliabletransport/doc.go +++ b/internal/reliabletransport/doc.go @@ -3,8 +3,4 @@ // of why this is needed, and how it relates to the requirements of the control channel. // It is worth to mention that, even though the original need is to have a reliable control channel // on top of UDP, this is also used when tunneling over TCP. -// -// A note about terminology: in this package, "receiver" is the moveUpWorker in the [reliabletransport.Service] (since it receives incoming packets), and -// "sender" is the moveDownWorker in the same service. The corresponding data structures lack mutexes because they are intended to be confined to a single -// goroutine (one for each worker), and they SHOULD ONLY communicate via message passing. package reliabletransport diff --git a/internal/reliabletransport/receiver.go b/internal/reliabletransport/receiver.go index 67cafde3..5081882a 100644 --- a/internal/reliabletransport/receiver.go +++ b/internal/reliabletransport/receiver.go @@ -10,7 +10,10 @@ import ( "github.com/ooni/minivpn/internal/optional" ) -// moveUpWorker moves packets up the stack (receiver) +// moveUpWorker moves packets up the stack (receiver). +// The sender and receiver data structures lack mutexs because they are +// intended to be confined to a single goroutine (one for each worker), and +// they SHOULD ONLY communicate via message passing. func (ws *workersState) moveUpWorker() { workerName := fmt.Sprintf("%s: moveUpWorker", serviceName) diff --git a/internal/reliabletransport/sender.go b/internal/reliabletransport/sender.go index 2b6fc7d1..6238630f 100644 --- a/internal/reliabletransport/sender.go +++ b/internal/reliabletransport/sender.go @@ -9,7 +9,9 @@ import ( ) // moveDownWorker moves packets down the stack (sender) -// TODO move the worker to sender.go +// The sender and receiver data structures lack mutexes because they are +// intended to be confined to a single goroutine (one for each worker), and +// they SHOULD ONLY communicate via message passing. func (ws *workersState) moveDownWorker() { workerName := fmt.Sprintf("%s: moveDownWorker", serviceName) diff --git a/internal/reliabletransport/service.go b/internal/reliabletransport/service.go index 14a5ab0b..3197e686 100644 --- a/internal/reliabletransport/service.go +++ b/internal/reliabletransport/service.go @@ -1,4 +1,3 @@ -// Package reliabletransport implements the reliable transport. package reliabletransport import ( @@ -37,7 +36,9 @@ func (s *Service) StartWorkers( sessionManager *session.Manager, ) { ws := &workersState{ - logger: logger, + logger: logger, + // incomingSeen is a buffered channel to avoid losing packets if we're busy + // processing in the sender goroutine. incomingSeen: make(chan incomingPacketSeen, 20), dataOrControlToMuxer: *s.DataOrControlToMuxer, controlToReliable: s.ControlToReliable, From 2095caba078d71070e1804888ae5fe38a1e80ce4 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Tue, 30 Jan 2024 17:45:31 +0100 Subject: [PATCH 40/78] inflight does not need to implement sort --- internal/reliabletransport/receiver_test.go | 227 ++++++++++++++++++++ 1 file changed, 227 insertions(+) create mode 100644 internal/reliabletransport/receiver_test.go diff --git a/internal/reliabletransport/receiver_test.go b/internal/reliabletransport/receiver_test.go new file mode 100644 index 00000000..60e77cc8 --- /dev/null +++ b/internal/reliabletransport/receiver_test.go @@ -0,0 +1,227 @@ +package reliabletransport + +import ( + "reflect" + "testing" + + "github.com/apex/log" + "github.com/ooni/minivpn/internal/model" +) + +// +// tests for reliableReceiver +// + +// testIncomingPacket is a sequentialPacket for testing incomingPackets +type testIncomingPacket struct { + id model.PacketID + acks []model.PacketID +} + +func (ip *testIncomingPacket) ID() model.PacketID { + return ip.id +} + +func (ip *testIncomingPacket) ExtractACKs() []model.PacketID { + return ip.acks +} + +func (ip *testIncomingPacket) Packet() *model.Packet { + return &model.Packet{ID: ip.id} +} + +var _ sequentialPacket = &testIncomingPacket{} + +func Test_reliableQueue_MaybeInsertIncoming(t *testing.T) { + log.SetLevel(log.DebugLevel) + + type fields struct { + incomingPackets incomingSequence + } + type args struct { + p *testIncomingPacket + } + tests := []struct { + name string + fields fields + args args + want bool + }{ + { + name: "empty incoming, insert one", + fields: fields{ + incomingPackets: []sequentialPacket{}, + }, + args: args{ + &testIncomingPacket{id: 1}, + }, + want: true, + }, + { + name: "almost full incoming, insert one", + fields: fields{ + incomingPackets: []sequentialPacket{ + &testIncomingPacket{id: 1}, + &testIncomingPacket{id: 2}, + &testIncomingPacket{id: 3}, + &testIncomingPacket{id: 4}, + &testIncomingPacket{id: 5}, + &testIncomingPacket{id: 6}, + &testIncomingPacket{id: 7}, + &testIncomingPacket{id: 8}, + &testIncomingPacket{id: 9}, + &testIncomingPacket{id: 10}, + &testIncomingPacket{id: 11}, + }, + }, + args: args{ + &testIncomingPacket{id: 12}, + }, + want: true, + }, + { + name: "full incoming, cannot insert", + fields: fields{ + incomingPackets: []sequentialPacket{ + &testIncomingPacket{id: 1}, + &testIncomingPacket{id: 2}, + &testIncomingPacket{id: 3}, + &testIncomingPacket{id: 4}, + &testIncomingPacket{id: 5}, + &testIncomingPacket{id: 6}, + &testIncomingPacket{id: 7}, + &testIncomingPacket{id: 8}, + &testIncomingPacket{id: 9}, + &testIncomingPacket{id: 10}, + &testIncomingPacket{id: 11}, + &testIncomingPacket{id: 12}, + }, + }, + args: args{ + &testIncomingPacket{id: 13}, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &reliableReceiver{ + logger: log.Log, + incomingPackets: tt.fields.incomingPackets, + } + if got := r.MaybeInsertIncoming(tt.args.p.Packet()); got != tt.want { + t.Errorf("reliableQueue.MaybeInsertIncoming() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_reliableQueue_NextIncomingSequence(t *testing.T) { + log.SetLevel(log.DebugLevel) + + type fields struct { + lastConsumed model.PacketID + incomingPackets incomingSequence + } + tests := []struct { + name string + fields fields + want incomingSequence + }{ + { + name: "empty sequence", + fields: fields{ + incomingPackets: []sequentialPacket{}, + lastConsumed: model.PacketID(0), + }, + want: []sequentialPacket{}, + }, + { + name: "single packet", + fields: fields{ + lastConsumed: model.PacketID(0), + incomingPackets: []sequentialPacket{ + &testIncomingPacket{id: 1}, + }, + }, + want: []sequentialPacket{ + &testIncomingPacket{id: 1}, + }, + }, + { + name: "series of sequential packets", + fields: fields{ + lastConsumed: model.PacketID(0), + incomingPackets: []sequentialPacket{ + &testIncomingPacket{id: 1}, + &testIncomingPacket{id: 2}, + &testIncomingPacket{id: 3}, + }, + }, + want: []sequentialPacket{ + &testIncomingPacket{id: 1}, + &testIncomingPacket{id: 2}, + &testIncomingPacket{id: 3}, + }, + }, + { + name: "series of sequential packets with hole", + fields: fields{ + lastConsumed: model.PacketID(0), + incomingPackets: []sequentialPacket{ + &testIncomingPacket{id: 1}, + &testIncomingPacket{id: 2}, + &testIncomingPacket{id: 3}, + &testIncomingPacket{id: 5}, + }, + }, + want: []sequentialPacket{ + &testIncomingPacket{id: 1}, + &testIncomingPacket{id: 2}, + &testIncomingPacket{id: 3}, + }, + }, + { + name: "series of sequential packets with hole, lastConsumed higher", + fields: fields{ + lastConsumed: model.PacketID(10), + incomingPackets: []sequentialPacket{ + &testIncomingPacket{id: 1}, + &testIncomingPacket{id: 2}, + &testIncomingPacket{id: 3}, + &testIncomingPacket{id: 5}, + }, + }, + want: []sequentialPacket{}, + }, + { + name: "series of sequential packets with hole, lastConsumed higher, some above", + fields: fields{ + lastConsumed: model.PacketID(10), + incomingPackets: []sequentialPacket{ + &testIncomingPacket{id: 1}, + &testIncomingPacket{id: 2}, + &testIncomingPacket{id: 10}, + &testIncomingPacket{id: 11}, + &testIncomingPacket{id: 12}, + &testIncomingPacket{id: 20}, + }, + }, + want: []sequentialPacket{ + &testIncomingPacket{id: 11}, + &testIncomingPacket{id: 12}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &reliableReceiver{ + lastConsumed: tt.fields.lastConsumed, + incomingPackets: tt.fields.incomingPackets, + } + if got := r.NextIncomingSequence(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("reliableQueue.NextIncomingSequence() = %v, want %v", got, tt.want) + } + }) + } +} From 7009ffb654594ec1806f94dd7877ecbb427c15a7 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Tue, 30 Jan 2024 19:57:25 +0100 Subject: [PATCH 41/78] implement set for ack queue --- internal/packetmuxer/service.go | 7 +- internal/reliabletransport/constants.go | 9 +- internal/reliabletransport/model.go | 26 ++- internal/reliabletransport/packets.go | 19 +- internal/reliabletransport/receiver.go | 3 + internal/reliabletransport/sender.go | 148 ++++++++----- internal/reliabletransport/sender_test.go | 247 ++++------------------ 7 files changed, 179 insertions(+), 280 deletions(-) diff --git a/internal/packetmuxer/service.go b/internal/packetmuxer/service.go index ae429efe..7cdeeb55 100644 --- a/internal/packetmuxer/service.go +++ b/internal/packetmuxer/service.go @@ -15,6 +15,11 @@ var ( serviceName = "packetmuxer" ) +const ( + // A sufficiently long wakup period to initialize a ticker with. + longWakeup = time.Hour * 24 * 30 +) + // Service is the packetmuxer service. Make sure you initialize // the channels before invoking [Service.StartWorkers]. type Service struct { @@ -53,7 +58,7 @@ func (s *Service) StartWorkers( logger: logger, hardReset: s.HardReset, // initialize to a sufficiently long time from now - hardResetTicker: time.NewTicker(time.Hour * 24 * 30), + hardResetTicker: time.NewTicker(longWakeup), notifyTLS: *s.NotifyTLS, dataOrControlToMuxer: s.DataOrControlToMuxer, muxerToReliable: *s.MuxerToReliable, diff --git a/internal/reliabletransport/constants.go b/internal/reliabletransport/constants.go index c9a110bc..abde64c7 100644 --- a/internal/reliabletransport/constants.go +++ b/internal/reliabletransport/constants.go @@ -2,14 +2,19 @@ package reliabletransport const ( // Capacity for the array of packets that we're tracking at any given moment (outgoing). - RELIABLE_SEND_BUFFER_SIZE = 12 + // This is defined by OpenVPN in ssl_pkt.h + RELIABLE_SEND_BUFFER_SIZE = 6 // Capacity for the array of packets that we're tracking at any given moment (incoming). - RELIABLE_RECV_BUFFER_SIZE = RELIABLE_SEND_BUFFER_SIZE + // This is defined by OpenVPN in ssl_pkt.h + RELIABLE_RECV_BUFFER_SIZE = 12 // The maximum numbers of ACKs that we put in an array for an outgoing packet. MAX_ACKS_PER_OUTGOING_PACKET = 4 + // How many IDs pending to be acked can we store. + ACK_SET_CAPACITY = 8 + // Initial timeout for TLS retransmission, in seconds. INITIAL_TLS_TIMEOUT_SECONDS = 2 diff --git a/internal/reliabletransport/model.go b/internal/reliabletransport/model.go index 1591b60e..b121eee0 100644 --- a/internal/reliabletransport/model.go +++ b/internal/reliabletransport/model.go @@ -11,29 +11,35 @@ type sequentialPacket interface { Packet() *model.Packet } -// inFlightPacket is a packet that, additionally, can keep track of how many acks for a packet with a higher PID have been received. -type inFlighter interface { - sequentialPacket +// retransmissionPacket is a packet that can be scheduled for retransmission. +type retransmissionPacket interface { ScheduleForRetransmission() } -// outgoingPacketHandler has methods to deal with the outgoing packets (going down). -type outgoingPacketHandler interface { +type outgoingPacketWriter interface { // TryInsertOutgoingPacket attempts to insert a packet into the // inflight queue. If return value is false, insertion was not successful (e.g., too many // packets in flight). TryInsertOutgoingPacket(*model.Packet) bool +} - // MaybeEvictOrBumpPacketAfterACK removes a packet (that we received an ack for) from the in-flight packet queue. - MaybeEvictOrBumpPacketAfterACK(id model.PacketID) bool +type seenPacketHandler interface { + // OnIncomingPacketSeen processes a notification received in the shared lateral channel where receiver + // notifies sender of incoming packets. There are two side-effects expected from this call: + // 1. The ID in incomingPacketSeen needs to be appended to the array of packets pending to be acked, if not already + // there. This insertion needs to be reflected by NextPacketIDsToACK() + // 2. Any ACK values in the incomingPacketSeen need to: + // a) evict the matching packet, if existing in the in flight queue, and + // b) increment the counter of acks-with-higher-pid for each packet with a lesser + // packet id (used for fast retransmission) + OnIncomingPacketSeen(incomingPacketSeen) +} +type outgoingPacketHandler interface { // NextPacketIDsToACK returns an array of pending IDs to ACK to // our remote. The length of this array MUST NOT be larger than CONTROL_SEND_ACK_MAX. // This is used to append it to the ACK array of the outgoing packet. NextPacketIDsToACK() []model.PacketID - - // OnIncomingPacketSeen processes a notification received in the shared channel for incoming packets. - OnIncomingPacketSeen(incomingPacketSeen) } // incomingPacketHandler knows how to deal with incoming packets (going up). diff --git a/internal/reliabletransport/packets.go b/internal/reliabletransport/packets.go index 692a5ebc..c2e99ace 100644 --- a/internal/reliabletransport/packets.go +++ b/internal/reliabletransport/packets.go @@ -63,7 +63,8 @@ func (p *inFlightPacket) backoff() time.Duration { // var _ sequentialPacket = &inFlightWrappedPacket{} // inflightSequence is a sequence of inFlightPackets. -// A inflightSequence can be sorted. +// A inflightSequence MUST be sorted (since the controlchannel has assigned sequential packet IDs when creating the +// packet) type inflightSequence []*inFlightPacket // nearestDeadlineTo returns the lower deadline to a passed reference time for all the packets in the inFlight queue. Used to re-arm the Ticker. We need to be careful and not pass a @@ -101,21 +102,6 @@ func (seq inflightSequence) readyToSend(t time.Time) inflightSequence { return expired } -// implement sort.Interface -func (seq inflightSequence) Len() int { - return len(seq) -} - -// implement sort.Interface -func (seq inflightSequence) Swap(i, j int) { - seq[i], seq[j] = seq[j], seq[i] -} - -// implement sort.Interface -func (seq inflightSequence) Less(i, j int) bool { - return seq[i].packet.ID < seq[j].packet.ID -} - // An incomingSequence is an array of sequentialPackets. It's used to store both incoming and outgoing packet queues. // An incomingSequence can be sorted. type incomingSequence []sequentialPacket @@ -135,6 +121,7 @@ func (ps incomingSequence) Less(i, j int) bool { return ps[i].ID() < ps[j].ID() } +// TODO: this is just a packet type incomingPacket struct { packet *model.Packet } diff --git a/internal/reliabletransport/receiver.go b/internal/reliabletransport/receiver.go index 5081882a..0275e59e 100644 --- a/internal/reliabletransport/receiver.go +++ b/internal/reliabletransport/receiver.go @@ -36,6 +36,7 @@ func (ws *workersState) moveUpWorker() { case packet := <-ws.muxerToReliable: if packet.Opcode != model.P_CONTROL_HARD_RESET_SERVER_V2 { // the hard reset has already been logged by the layer below + // TODO: move logging here? packet.Log(ws.logger, model.DirectionIncoming) } @@ -58,6 +59,8 @@ func (ws *workersState) moveUpWorker() { seen := receiver.newIncomingPacketSeen(packet) ws.incomingSeen <- seen + // TODO(ainghazal): drop a packet that is a replay (id <= lastConsumed, but != ACK...?) + // we only want to insert control packets going to the tls layer if packet.Opcode != model.P_CONTROL_V1 { diff --git a/internal/reliabletransport/sender.go b/internal/reliabletransport/sender.go index 6238630f..4d73df56 100644 --- a/internal/reliabletransport/sender.go +++ b/internal/reliabletransport/sender.go @@ -6,6 +6,7 @@ import ( "time" "github.com/ooni/minivpn/internal/model" + "github.com/ooni/minivpn/internal/optional" ) // moveDownWorker moves packets down the stack (sender) @@ -40,12 +41,11 @@ func (ws *workersState) moveDownWorker() { // and add any id to the queue of packets to ack sender.OnIncomingPacketSeen(seenPacket) - if len(sender.pendingACKsToSend) == 0 { + if sender.pendingACKsToSend.Len() == 0 { continue } - // reschedule the ticker - if len(sender.pendingACKsToSend) >= 2 { + if sender.pendingACKsToSend.Len() >= 2 { ticker.Reset(time.Nanosecond) continue } @@ -92,17 +92,19 @@ func (ws *workersState) moveDownWorker() { } } } else { + // TODO --- mve this to function ------------------------------------------- + // TODO: somethingToACK(state) --------------------------------------------- // there's nothing ready to be sent, so we see if we've got pending ACKs - if len(sender.pendingACKsToSend) == 0 { + if sender.pendingACKsToSend.Len() == 0 { continue } - // special case, we want to send the clientHello as soon as possible + // special case, we want to send the clientHello as soon as possible ----------------------------- // (TODO: coordinate this with hardReset) - if len(sender.pendingACKsToSend) == 1 && sender.pendingACKsToSend[0] == model.PacketID(0) { + if sender.pendingACKsToSend.Len() == 1 && *sender.pendingACKsToSend.first() == model.PacketID(0) { continue } - fmt.Println(":: CREATING ACK", len(sender.pendingACKsToSend), "pending to ack") + ws.logger.Debugf("Creating ACK: %d pending to ack", sender.pendingACKsToSend.Len()) ACK, err := ws.sessionManager.NewACKForPacketIDs(sender.NextPacketIDsToACK()) if err != nil { @@ -122,11 +124,7 @@ func (ws *workersState) moveDownWorker() { } } -// -// outgoingPacketHandler implementation. -// - -// reliableSender keeps state about the outgoing packet queue, and implements outgoingPacketHandler. +// reliableSender keeps state about the in flight packet queue, and implements outgoingPacketHandler. // Please use the constructor `newReliableSender()` type reliableSender struct { @@ -139,8 +137,8 @@ type reliableSender struct { // logger is the logger to use logger model.Logger - // pendingACKsToSend is the array of packets that we still need to ACK. - pendingACKsToSend []model.PacketID + // pendingACKsToSend is a set of packets that we still need to ACK. + pendingACKsToSend *ackSet } // newReliableSender returns a new instance of reliableOutgoing. @@ -149,14 +147,11 @@ func newReliableSender(logger model.Logger, i chan incomingPacketSeen) *reliable incomingSeen: i, inFlight: make([]*inFlightPacket, 0, RELIABLE_SEND_BUFFER_SIZE), logger: logger, - pendingACKsToSend: []model.PacketID{}, + pendingACKsToSend: newACKSet(), } } -// -// outgoingPacketHandler implementation. -// - +// implement outgoingPacketWriter func (r *reliableSender) TryInsertOutgoingPacket(p *model.Packet) bool { if len(r.inFlight) >= RELIABLE_SEND_BUFFER_SIZE { r.logger.Warn("outgoing array full, dropping packet") @@ -167,13 +162,26 @@ func (r *reliableSender) TryInsertOutgoingPacket(p *model.Packet) bool { return true } -// MaybeEvictOrBumpPacketAfterACK iterates over all the in-flight packets. For each one, -// and either evicts it (if the PacketID matches), or bumps the internal withHigherACK count in the -// packet (if the PacketID from the ACK is higher than the packet in the queue). -func (r *reliableSender) MaybeEvictOrBumpPacketAfterACK(acked model.PacketID) bool { - // TODO: it *should* be sorted, can it be not sorted? - sort.Sort(inflightSequence(r.inFlight)) +// OnIncomingPacketSeen implements seenPacketHandler +func (r *reliableSender) OnIncomingPacketSeen(seen incomingPacketSeen) { + // we have received an incomingPacketSeen on the shared channel, we need to do two things: + + // 1. add the ID to the set of packets to be acknowledged. + r.pendingACKsToSend.maybeAdd(seen.id) + + // 2. for every ACK received, see if we need to evict or bump the in-flight packet. + if seen.acks.IsNone() { + return + } + for _, packetID := range seen.acks.Unwrap() { + r.maybeEvictOrMarkWithHigherACK(packetID) + } +} +// maybeEvictOrMarkWithHigherACK iterates over all the in-flight packets. For each one, +// either evicts it (if the PacketID matches), or bumps the internal withHigherACK count in the +// packet (if the PacketID from the ACK is higher than the packet in the queue). +func (r *reliableSender) maybeEvictOrMarkWithHigherACK(acked model.PacketID) bool { packets := r.inFlight for i, p := range packets { if acked > p.packet.ID { @@ -181,7 +189,6 @@ func (r *reliableSender) MaybeEvictOrBumpPacketAfterACK(acked model.PacketID) bo p.ACKForHigherPacket() } else if acked == p.packet.ID { - // we have a match for the ack we just received: eviction it is! r.logger.Debugf("evicting packet %v", p.packet.ID) @@ -191,43 +198,86 @@ func (r *reliableSender) MaybeEvictOrBumpPacketAfterACK(acked model.PacketID) bo // and now exclude the last element: r.inFlight = packets[:len(packets)-1] - // since we had sorted the in-flight array, we're done here. + // since the in-flight array is always sorted by ascending packet-id + // (because of sequentiality assumption in the control channel), + // we're done here. return true } } return false } -// this should return at most MAX_ACKS_PER_OUTGOING_PACKET packet IDs. +// NextPacketIDsToACK implement outgoingPacketHandler func (r *reliableSender) NextPacketIDsToACK() []model.PacketID { - var next []model.PacketID - if len(r.pendingACKsToSend) <= MAX_ACKS_PER_OUTGOING_PACKET { - next = r.pendingACKsToSend[:len(r.pendingACKsToSend)] - r.pendingACKsToSend = r.pendingACKsToSend[:0] - return next + return r.pendingACKsToSend.nextToACK() +} + +var _ outgoingPacketHandler = &reliableSender{} + +// ackSet is a set of acks. The zero value struct +// is invalid, please use newACKSet. +type ackSet struct { + // m is the map we use to represent the set. + m map[model.PacketID]bool +} + +// NewACKSet creates a new empty ACK set. +func newACKSet(ids ...model.PacketID) *ackSet { + m := make(map[model.PacketID]bool) + for _, id := range ids { + m[id] = true } + return &ackSet{m} +} - next = r.pendingACKsToSend[:MAX_ACKS_PER_OUTGOING_PACKET] - r.pendingACKsToSend = r.pendingACKsToSend[MAX_ACKS_PER_OUTGOING_PACKET : len(r.pendingACKsToSend)-1] - return next +// maybeAdd unwraps the optional value, and if not empty it MUTATES the set to add a (possibly-new) +// packet ID to the set and. It returns the same set to the caller. +func (as *ackSet) maybeAdd(id optional.Value[model.PacketID]) *ackSet { + if len(as.m) >= ACK_SET_CAPACITY { + return as + } + if !id.IsNone() { + as.m[id.Unwrap()] = true + } + return as } -func (r *reliableSender) OnIncomingPacketSeen(seen incomingPacketSeen) { - // we have received an incomingPacketSeen on the shared channel, we need to do two things: +// nextToACK returns up to MAX_ACKS_PER_OUTGOING_PACKET from the set, sorted by ascending packet ID. +func (as *ackSet) nextToACK() []model.PacketID { + ids := as.sorted() + var next []model.PacketID + if len(ids) <= MAX_ACKS_PER_OUTGOING_PACKET { + next = ids + } else { + next = ids[:MAX_ACKS_PER_OUTGOING_PACKET] + } + for _, i := range next { + delete(as.m, i) + } + return next +} - // 1. add the ID to the queue of packets to be acknowledged. - if !seen.id.IsNone() { - // TODO: do it only if not already in the array - // FIXME -------------------------------------- - r.pendingACKsToSend = append(r.pendingACKsToSend, seen.id.Unwrap()) +// first returns the first packetID in the set, in ascending order. +func (as *ackSet) first() *model.PacketID { + ids := as.sorted() + if len(ids) == 0 { + return nil } + return &ids[0] +} - // 2. for every ACK received, see if we need to evict or bump the in-flight packet. - if !seen.acks.IsNone() { - for _, packetID := range seen.acks.Unwrap() { - r.MaybeEvictOrBumpPacketAfterACK(packetID) - } +// sorted returns a []model.PacketID array with the stored ids, in ascending order. +func (as *ackSet) sorted() []model.PacketID { + ids := make([]model.PacketID, 0) + for id := range as.m { + ids = append(ids, id) } + sort.SliceStable(ids, func(i, j int) bool { + return ids[i] < ids[j] + }) + return ids } -var _ outgoingPacketHandler = &reliableSender{} +func (as *ackSet) Len() int { + return len(as.m) +} diff --git a/internal/reliabletransport/sender_test.go b/internal/reliabletransport/sender_test.go index e13ed9ed..5bde6a06 100644 --- a/internal/reliabletransport/sender_test.go +++ b/internal/reliabletransport/sender_test.go @@ -6,6 +6,7 @@ import ( "github.com/apex/log" "github.com/ooni/minivpn/internal/model" + "github.com/ooni/minivpn/internal/optional" ) // @@ -47,19 +48,29 @@ func Test_reliableSender_TryInsertOutgoingPacket(t *testing.T) { {packet: &model.Packet{ID: 4}}, {packet: &model.Packet{ID: 5}}, {packet: &model.Packet{ID: 6}}, - {packet: &model.Packet{ID: 7}}, - {packet: &model.Packet{ID: 8}}, - {packet: &model.Packet{ID: 9}}, - {packet: &model.Packet{ID: 10}}, - {packet: &model.Packet{ID: 11}}, - {packet: &model.Packet{ID: 12}}, }), }, args: args{ - p: &model.Packet{ID: 13}, + p: &model.Packet{ID: 7}, }, want: false, }, + { + name: "insert on almost full array", + fields: fields{ + inFlight: inflightSequence([]*inFlightPacket{ + {packet: &model.Packet{ID: 1}}, + {packet: &model.Packet{ID: 2}}, + {packet: &model.Packet{ID: 3}}, + {packet: &model.Packet{ID: 4}}, + {packet: &model.Packet{ID: 5}}, + }), + }, + args: args{ + p: &model.Packet{ID: 6}, + }, + want: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -102,14 +113,14 @@ func Test_reliableSender_NextPacketIDsToACK(t *testing.T) { { name: "tree elements", fields: fields{ - pendingACKsToSend: []model.PacketID{11, 12, 13}, + pendingACKsToSend: []model.PacketID{12, 11, 13}, }, want: []model.PacketID{11, 12, 13}, }, { name: "five elements", fields: fields{ - pendingACKsToSend: []model.PacketID{11, 12, 13, 14, 15}, + pendingACKsToSend: []model.PacketID{15, 12, 14, 13, 11}, }, want: []model.PacketID{11, 12, 13, 14}, }, @@ -118,7 +129,7 @@ func Test_reliableSender_NextPacketIDsToACK(t *testing.T) { t.Run(tt.name, func(t *testing.T) { r := &reliableSender{ logger: log.Log, - pendingACKsToSend: tt.fields.pendingACKsToSend, + pendingACKsToSend: newACKSet(tt.fields.pendingACKsToSend...), } if got := r.NextPacketIDsToACK(); !reflect.DeepEqual(got, tt.want) { t.Errorf("reliableSender.NextPacketIDsToACK() = %v, want %v", got, tt.want) @@ -127,219 +138,51 @@ func Test_reliableSender_NextPacketIDsToACK(t *testing.T) { } } -// -// tests for reliableReceiver -// - -// testIncomingPacket is a sequentialPacket for testing incomingPackets -type testIncomingPacket struct { - id model.PacketID - acks []model.PacketID -} - -func (ip *testIncomingPacket) ID() model.PacketID { - return ip.id -} - -func (ip *testIncomingPacket) ExtractACKs() []model.PacketID { - return ip.acks -} - -func (ip *testIncomingPacket) Packet() *model.Packet { - return &model.Packet{ID: ip.id} -} - -var _ sequentialPacket = &testIncomingPacket{} - -func Test_reliableQueue_MaybeInsertIncoming(t *testing.T) { - log.SetLevel(log.DebugLevel) - +func Test_ackSet_maybeAdd(t *testing.T) { type fields struct { - incomingPackets incomingSequence + m map[model.PacketID]bool } type args struct { - p *testIncomingPacket + id optional.Value[model.PacketID] } tests := []struct { name string fields fields args args - want bool + want *ackSet }{ { - name: "empty incoming, insert one", - fields: fields{ - incomingPackets: []sequentialPacket{}, - }, - args: args{ - &testIncomingPacket{id: 1}, - }, - want: true, + name: "can add on empty set", + fields: fields{newACKSet().m}, + args: args{optional.Some(model.PacketID(1))}, + want: newACKSet(1), }, { - name: "almost full incoming, insert one", - fields: fields{ - incomingPackets: []sequentialPacket{ - &testIncomingPacket{id: 1}, - &testIncomingPacket{id: 2}, - &testIncomingPacket{id: 3}, - &testIncomingPacket{id: 4}, - &testIncomingPacket{id: 5}, - &testIncomingPacket{id: 6}, - &testIncomingPacket{id: 7}, - &testIncomingPacket{id: 8}, - &testIncomingPacket{id: 9}, - &testIncomingPacket{id: 10}, - &testIncomingPacket{id: 11}, - }, - }, - args: args{ - &testIncomingPacket{id: 12}, - }, - want: true, + name: "add duplicate on empty set", + fields: fields{newACKSet(1).m}, + args: args{optional.Some(model.PacketID(1))}, + want: newACKSet(1), }, { - name: "full incoming, cannot insert", - fields: fields{ - incomingPackets: []sequentialPacket{ - &testIncomingPacket{id: 1}, - &testIncomingPacket{id: 2}, - &testIncomingPacket{id: 3}, - &testIncomingPacket{id: 4}, - &testIncomingPacket{id: 5}, - &testIncomingPacket{id: 6}, - &testIncomingPacket{id: 7}, - &testIncomingPacket{id: 8}, - &testIncomingPacket{id: 9}, - &testIncomingPacket{id: 10}, - &testIncomingPacket{id: 11}, - &testIncomingPacket{id: 12}, - }, - }, - args: args{ - &testIncomingPacket{id: 13}, - }, - want: false, + name: "cannot add beyond capacity", + fields: fields{newACKSet(1, 2, 3, 4, 5, 6, 7, 8).m}, + args: args{optional.Some(model.PacketID(10))}, + want: newACKSet(1, 2, 3, 4, 5, 6, 7, 8), }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - r := &reliableReceiver{ - logger: log.Log, - incomingPackets: tt.fields.incomingPackets, - } - if got := r.MaybeInsertIncoming(tt.args.p.Packet()); got != tt.want { - t.Errorf("reliableQueue.MaybeInsertIncoming() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_reliableQueue_NextIncomingSequence(t *testing.T) { - log.SetLevel(log.DebugLevel) - - type fields struct { - lastConsumed model.PacketID - incomingPackets incomingSequence - } - tests := []struct { - name string - fields fields - want incomingSequence - }{ { - name: "empty sequence", - fields: fields{ - incomingPackets: []sequentialPacket{}, - lastConsumed: model.PacketID(0), - }, - want: []sequentialPacket{}, - }, - { - name: "single packet", - fields: fields{ - lastConsumed: model.PacketID(0), - incomingPackets: []sequentialPacket{ - &testIncomingPacket{id: 1}, - }, - }, - want: []sequentialPacket{ - &testIncomingPacket{id: 1}, - }, - }, - { - name: "series of sequential packets", - fields: fields{ - lastConsumed: model.PacketID(0), - incomingPackets: []sequentialPacket{ - &testIncomingPacket{id: 1}, - &testIncomingPacket{id: 2}, - &testIncomingPacket{id: 3}, - }, - }, - want: []sequentialPacket{ - &testIncomingPacket{id: 1}, - &testIncomingPacket{id: 2}, - &testIncomingPacket{id: 3}, - }, - }, - { - name: "series of sequential packets with hole", - fields: fields{ - lastConsumed: model.PacketID(0), - incomingPackets: []sequentialPacket{ - &testIncomingPacket{id: 1}, - &testIncomingPacket{id: 2}, - &testIncomingPacket{id: 3}, - &testIncomingPacket{id: 5}, - }, - }, - want: []sequentialPacket{ - &testIncomingPacket{id: 1}, - &testIncomingPacket{id: 2}, - &testIncomingPacket{id: 3}, - }, - }, - { - name: "series of sequential packets with hole, lastConsumed higher", - fields: fields{ - lastConsumed: model.PacketID(10), - incomingPackets: []sequentialPacket{ - &testIncomingPacket{id: 1}, - &testIncomingPacket{id: 2}, - &testIncomingPacket{id: 3}, - &testIncomingPacket{id: 5}, - }, - }, - want: []sequentialPacket{}, - }, - { - name: "series of sequential packets with hole, lastConsumed higher, some above", - fields: fields{ - lastConsumed: model.PacketID(10), - incomingPackets: []sequentialPacket{ - &testIncomingPacket{id: 1}, - &testIncomingPacket{id: 2}, - &testIncomingPacket{id: 10}, - &testIncomingPacket{id: 11}, - &testIncomingPacket{id: 12}, - &testIncomingPacket{id: 20}, - }, - }, - want: []sequentialPacket{ - &testIncomingPacket{id: 11}, - &testIncomingPacket{id: 12}, - }, + name: "order does not matter", + fields: fields{newACKSet(3, 2, 1).m}, + args: args{optional.Some(model.PacketID(4))}, + want: newACKSet(1, 2, 3, 4), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - r := &reliableReceiver{ - lastConsumed: tt.fields.lastConsumed, - incomingPackets: tt.fields.incomingPackets, + as := &ackSet{ + m: tt.fields.m, } - if got := r.NextIncomingSequence(); !reflect.DeepEqual(got, tt.want) { - t.Errorf("reliableQueue.NextIncomingSequence() = %v, want %v", got, tt.want) + if got := as.maybeAdd(tt.args.id); !reflect.DeepEqual(got, tt.want) { + t.Errorf("ackSet.maybeAdd() = %v, want %v", got, tt.want) } }) } From 32a83e05cbb284fe49ca985b183f01c817f49eb3 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Tue, 30 Jan 2024 20:49:47 +0100 Subject: [PATCH 42/78] add tests for pending acks + evict after ack --- internal/reliabletransport/sender.go | 10 +- internal/reliabletransport/sender_test.go | 159 ++++++++++++++++++++++ 2 files changed, 166 insertions(+), 3 deletions(-) diff --git a/internal/reliabletransport/sender.go b/internal/reliabletransport/sender.go index 4d73df56..965a6b68 100644 --- a/internal/reliabletransport/sender.go +++ b/internal/reliabletransport/sender.go @@ -157,8 +157,7 @@ func (r *reliableSender) TryInsertOutgoingPacket(p *model.Packet) bool { r.logger.Warn("outgoing array full, dropping packet") return false } - new := newInFlightPacket(p) - r.inFlight = append(r.inFlight, new) + r.inFlight = append(r.inFlight, newInFlightPacket(p)) return true } @@ -184,10 +183,12 @@ func (r *reliableSender) OnIncomingPacketSeen(seen incomingPacketSeen) { func (r *reliableSender) maybeEvictOrMarkWithHigherACK(acked model.PacketID) bool { packets := r.inFlight for i, p := range packets { + if p.packet == nil { + panic("malformed packet") + } if acked > p.packet.ID { // we have received an ACK for a packet with a higher pid, so let's bump it p.ACKForHigherPacket() - } else if acked == p.packet.ID { // we have a match for the ack we just received: eviction it is! r.logger.Debugf("evicting packet %v", p.packet.ID) @@ -269,6 +270,9 @@ func (as *ackSet) first() *model.PacketID { // sorted returns a []model.PacketID array with the stored ids, in ascending order. func (as *ackSet) sorted() []model.PacketID { ids := make([]model.PacketID, 0) + if len(as.m) == 0 { + return ids + } for id := range as.m { ids = append(ids, id) } diff --git a/internal/reliabletransport/sender_test.go b/internal/reliabletransport/sender_test.go index 5bde6a06..ac8f639f 100644 --- a/internal/reliabletransport/sender_test.go +++ b/internal/reliabletransport/sender_test.go @@ -2,6 +2,7 @@ package reliabletransport import ( "reflect" + "slices" "testing" "github.com/apex/log" @@ -187,3 +188,161 @@ func Test_ackSet_maybeAdd(t *testing.T) { }) } } + +// test the combined behavior of reacting to an incomin packet and checking +// what's left in the in flight queue and what's left in the queue of pending acks. +func Test_reliableSender_OnIncomingPacketSeen(t *testing.T) { + + idSequence := func(ifp []*inFlightPacket) []model.PacketID { + ids := make([]model.PacketID, 0) + for _, p := range ifp { + ids = append(ids, p.packet.ID) + } + return ids + } + + type fields struct { + pendingacks *ackSet + inflight []*inFlightPacket + } + type args struct { + seen []incomingPacketSeen + } + type want struct { + acks []model.PacketID + inflight []model.PacketID + } + tests := []struct { + name string + fields fields + args args + want want + }{ + { + name: "empty seen does not change anything", + fields: fields{ + pendingacks: newACKSet(), + inflight: []*inFlightPacket{ + {packet: &model.Packet{ID: 1}}, + {packet: &model.Packet{ID: 2}}}, + }, + args: args{}, + want: want{inflight: []model.PacketID{1, 2}}, + }, + { + name: "ack for 1 evicts in-flight packet 1", + fields: fields{ + pendingacks: newACKSet(), + inflight: []*inFlightPacket{ + {packet: &model.Packet{ID: 1}}, + {packet: &model.Packet{ID: 2}}}, + }, + args: args{[]incomingPacketSeen{ + { + acks: optional.Some([]model.PacketID{model.PacketID(1)}), + }, + }, + }, + want: want{inflight: []model.PacketID{2}}, + }, + { + name: "ack for 1,2 evicts in-flight packets 1,2", + fields: fields{ + pendingacks: newACKSet(), + inflight: []*inFlightPacket{ + {packet: &model.Packet{ID: 1}}, + {packet: &model.Packet{ID: 2}}}, + }, + args: args{[]incomingPacketSeen{ + { + acks: optional.Some([]model.PacketID{ + model.PacketID(2), + model.PacketID(1), + }), + }, + }, + }, + want: want{inflight: []model.PacketID{}}, + }, + { + name: "ack for non-existent packet does not evict anything", + fields: fields{ + pendingacks: newACKSet(), + inflight: []*inFlightPacket{ + {packet: &model.Packet{ID: 1}}, + {packet: &model.Packet{ID: 2}}, + {packet: &model.Packet{ID: 3}}}, + }, + args: args{[]incomingPacketSeen{ + { + acks: optional.Some([]model.PacketID{ + model.PacketID(10), + }), + }, + }, + }, + want: want{inflight: []model.PacketID{1, 2, 3}}, + }, + { + name: "duplicated ack can only evict once", + fields: fields{ + pendingacks: newACKSet(), + inflight: []*inFlightPacket{ + {packet: &model.Packet{ID: 1}}, + {packet: &model.Packet{ID: 2}}, + {packet: &model.Packet{ID: 3}}, + {packet: &model.Packet{ID: 4}}}, + }, + args: args{[]incomingPacketSeen{ + { + acks: optional.Some([]model.PacketID{ + model.PacketID(3), + model.PacketID(3), + }), + }, + }, + }, + want: want{inflight: []model.PacketID{1, 2, 4}}, + }, + { + name: "seen id adds to pending ids to ack, plus ack evicts", + fields: fields{ + pendingacks: newACKSet(4, 6), + inflight: []*inFlightPacket{ + {packet: &model.Packet{ID: 1}}, + {packet: &model.Packet{ID: 3}}}, + }, + args: args{[]incomingPacketSeen{ + // a packet seen with ID + acks + { + id: optional.Some(model.PacketID(2)), + acks: optional.Some([]model.PacketID{ + model.PacketID(1), + }), + }, + }, + }, + want: want{ + acks: []model.PacketID{2, 4, 6}, + inflight: []model.PacketID{3}}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &reliableSender{ + logger: log.Log, + inFlight: tt.fields.inflight, + pendingACKsToSend: tt.fields.pendingacks, + } + for _, seen := range tt.args.seen { + r.OnIncomingPacketSeen(seen) + } + if gotACKs := r.NextPacketIDsToACK(); !slices.Equal(gotACKs, tt.want.acks) { + t.Errorf("reliableSender.NextPacketIDsToACK() = %v, want %v", gotACKs, tt.want.acks) + } + if seq := idSequence(r.inFlight); !slices.Equal(seq, tt.want.inflight) { + t.Errorf("reliableSender.NextPacketIDsToACK() = %v, want %v", seq, tt.want.inflight) + } + }) + } +} From a814e135b81f7078fd28d6040a05da30fae830c8 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Tue, 30 Jan 2024 20:59:14 +0100 Subject: [PATCH 43/78] test next packet ids to ack --- internal/reliabletransport/sender_test.go | 44 ++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/internal/reliabletransport/sender_test.go b/internal/reliabletransport/sender_test.go index ac8f639f..a674577e 100644 --- a/internal/reliabletransport/sender_test.go +++ b/internal/reliabletransport/sender_test.go @@ -189,7 +189,49 @@ func Test_ackSet_maybeAdd(t *testing.T) { } } -// test the combined behavior of reacting to an incomin packet and checking +func Test_ackSet_nextToACK(t *testing.T) { + type fields struct { + m map[model.PacketID]bool + } + tests := []struct { + name string + fields fields + want []model.PacketID + }{ + { + name: "get all if you have <4", + fields: fields{newACKSet(1, 2, 3).m}, + want: []model.PacketID{1, 2, 3}, + }, + { + name: "get all if you have 4", + fields: fields{newACKSet(1, 2, 3, 4).m}, + want: []model.PacketID{1, 2, 3, 4}, + }, + { + name: "get 2 if you have 2, sorted", + fields: fields{newACKSet(4, 1).m}, + want: []model.PacketID{1, 4}, + }, + { + name: "get first 4 if you have >4, sorted", + fields: fields{newACKSet(5, 6, 8, 3, 2, 4, 1).m}, + want: []model.PacketID{1, 2, 3, 4}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + as := &ackSet{ + m: tt.fields.m, + } + if got := as.nextToACK(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("ackSet.nextToACK() = %v, want %v", got, tt.want) + } + }) + } +} + +// test the combined behavior of reacting to an incoming packet and checking // what's left in the in flight queue and what's left in the queue of pending acks. func Test_reliableSender_OnIncomingPacketSeen(t *testing.T) { From c1747c5fd821bcbae23056787f9b219e6d154950 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Tue, 30 Jan 2024 21:07:17 +0100 Subject: [PATCH 44/78] test ack empties --- internal/reliabletransport/sender_test.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/internal/reliabletransport/sender_test.go b/internal/reliabletransport/sender_test.go index a674577e..ec184792 100644 --- a/internal/reliabletransport/sender_test.go +++ b/internal/reliabletransport/sender_test.go @@ -231,6 +231,24 @@ func Test_ackSet_nextToACK(t *testing.T) { } } +func Test_ackSet_nextToACK_empties_set(t *testing.T) { + acks := newACKSet(1, 2, 3, 5, 4, 6, 7, 10, 9, 8) + + want1 := []model.PacketID{1, 2, 3, 4} + want2 := []model.PacketID{5, 6, 7, 8} + want3 := []model.PacketID{9, 10} + + if got := acks.nextToACK(); !reflect.DeepEqual(got, want1) { + t.Errorf("ackSet.nextToACK() = %v, want %v", got, want1) + } + if got := acks.nextToACK(); !reflect.DeepEqual(got, want2) { + t.Errorf("ackSet.nextToACK() = %v, want %v", got, want1) + } + if got := acks.nextToACK(); !reflect.DeepEqual(got, want3) { + t.Errorf("ackSet.nextToACK() = %v, want %v", got, want3) + } +} + // test the combined behavior of reacting to an incoming packet and checking // what's left in the in flight queue and what's left in the queue of pending acks. func Test_reliableSender_OnIncomingPacketSeen(t *testing.T) { From 9aaa7f18320827440a86a40ad54673fa9c666674 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Tue, 30 Jan 2024 21:15:46 +0100 Subject: [PATCH 45/78] remove tlssession logging from this pr, separated in a different one --- internal/tlssession/tlsbio.go | 43 ++++++++++++++----------------- internal/tlssession/tlssession.go | 2 +- 2 files changed, 21 insertions(+), 24 deletions(-) diff --git a/internal/tlssession/tlsbio.go b/internal/tlssession/tlsbio.go index 00898d2b..6fec09be 100644 --- a/internal/tlssession/tlsbio.go +++ b/internal/tlssession/tlsbio.go @@ -2,11 +2,10 @@ package tlssession import ( "bytes" + "log" "net" "sync" "time" - - "github.com/ooni/minivpn/internal/model" ) // tlsBio allows to use channels to read and write @@ -15,72 +14,70 @@ type tlsBio struct { directionDown chan<- []byte directionUp <-chan []byte hangup chan any - logger model.Logger readBuffer *bytes.Buffer } // newTLSBio creates a new tlsBio -func newTLSBio(logger model.Logger, directionUp <-chan []byte, directionDown chan<- []byte) *tlsBio { +func newTLSBio(directionUp <-chan []byte, directionDown chan<- []byte) *tlsBio { return &tlsBio{ closeOnce: sync.Once{}, directionDown: directionDown, directionUp: directionUp, hangup: make(chan any), - logger: logger, readBuffer: &bytes.Buffer{}, } } -func (t *tlsBio) Close() error { - t.closeOnce.Do(func() { - close(t.hangup) +func (c *tlsBio) Close() error { + c.closeOnce.Do(func() { + close(c.hangup) }) return nil } -func (t *tlsBio) Read(data []byte) (int, error) { +func (c *tlsBio) Read(data []byte) (int, error) { for { - count, _ := t.readBuffer.Read(data) + count, _ := c.readBuffer.Read(data) if count > 0 { - t.logger.Debugf("[tlsbio] received %d bytes", len(data)) + log.Printf("[tlsbio] received %d bytes", len(data)) return count, nil } select { - case extra := <-t.directionUp: - t.readBuffer.Write(extra) - case <-t.hangup: + case extra := <-c.directionUp: + c.readBuffer.Write(extra) + case <-c.hangup: return 0, net.ErrClosed } } } -func (t *tlsBio) Write(data []byte) (int, error) { - t.logger.Debugf("[tlsbio] requested to write %d bytes", len(data)) +func (c *tlsBio) Write(data []byte) (int, error) { + log.Printf("[tlsbio] requested to write %d bytes", len(data)) select { - case t.directionDown <- data: + case c.directionDown <- data: return len(data), nil - case <-t.hangup: + case <-c.hangup: return 0, net.ErrClosed } } -func (t *tlsBio) LocalAddr() net.Addr { +func (c *tlsBio) LocalAddr() net.Addr { return &tlsBioAddr{} } -func (t *tlsBio) RemoteAddr() net.Addr { +func (c *tlsBio) RemoteAddr() net.Addr { return &tlsBioAddr{} } -func (t *tlsBio) SetDeadline(tt time.Time) error { +func (c *tlsBio) SetDeadline(t time.Time) error { return nil } -func (t *tlsBio) SetReadDeadline(tt time.Time) error { +func (c *tlsBio) SetReadDeadline(t time.Time) error { return nil } -func (t *tlsBio) SetWriteDeadline(tt time.Time) error { +func (c *tlsBio) SetWriteDeadline(t time.Time) error { return nil } diff --git a/internal/tlssession/tlssession.go b/internal/tlssession/tlssession.go index 95016bfa..a227ecd8 100644 --- a/internal/tlssession/tlssession.go +++ b/internal/tlssession/tlssession.go @@ -101,7 +101,7 @@ func (ws *workersState) worker() { // tlsAuth runs the TLS auth algorithm func (ws *workersState) tlsAuth() error { // create the BIO to use channels as a socket - conn := newTLSBio(ws.logger, ws.tlsRecordUp, ws.tlsRecordDown) + conn := newTLSBio(ws.tlsRecordUp, ws.tlsRecordDown) defer conn.Close() // we construct the certCfg from options, that has access to the certificate material From cc4f35fdbef19b7f9128ee81dd9f70975475251e Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Tue, 30 Jan 2024 21:16:02 +0100 Subject: [PATCH 46/78] note --- internal/reliabletransport/sender_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/internal/reliabletransport/sender_test.go b/internal/reliabletransport/sender_test.go index ec184792..7c0b6914 100644 --- a/internal/reliabletransport/sender_test.go +++ b/internal/reliabletransport/sender_test.go @@ -306,7 +306,7 @@ func Test_reliableSender_OnIncomingPacketSeen(t *testing.T) { want: want{inflight: []model.PacketID{2}}, }, { - name: "ack for 1,2 evicts in-flight packets 1,2", + name: "ack for 2,1 evicts in-flight packets 1,2", fields: fields{ pendingacks: newACKSet(), inflight: []*inFlightPacket{ @@ -406,3 +406,5 @@ func Test_reliableSender_OnIncomingPacketSeen(t *testing.T) { }) } } + +// TODO: exercise maybeEvict + withHigherACKs From bddbfe4757ff0359e024ae09df56cd6a9e6c2895 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Wed, 31 Jan 2024 13:00:59 +0100 Subject: [PATCH 47/78] x --- internal/reliabletransport/packets.go | 2 ++ internal/reliabletransport/packets_test.go | 29 ++++++++++++++++++++++ 2 files changed, 31 insertions(+) create mode 100644 internal/reliabletransport/packets_test.go diff --git a/internal/reliabletransport/packets.go b/internal/reliabletransport/packets.go index c2e99ace..072e38a0 100644 --- a/internal/reliabletransport/packets.go +++ b/internal/reliabletransport/packets.go @@ -33,9 +33,11 @@ func newInFlightPacket(p *model.Packet) *inFlightPacket { } } +/* func (p *inFlightPacket) ExtractACKs() []model.PacketID { return p.packet.ACKs } +*/ // ACKForHigherPacket increments the number of acks received for a higher pid than this packet. This will influence the fast rexmit selection algorithm. func (p *inFlightPacket) ACKForHigherPacket() { diff --git a/internal/reliabletransport/packets_test.go b/internal/reliabletransport/packets_test.go new file mode 100644 index 00000000..f40bb46a --- /dev/null +++ b/internal/reliabletransport/packets_test.go @@ -0,0 +1,29 @@ +package reliabletransport + +import ( + "testing" + "time" +) + +func Test_inFlightPacket_backoff(t *testing.T) { + type fields struct { + retries uint8 + } + tests := []struct { + name string + fields fields + want time.Duration + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &inFlightPacket{ + retries: tt.fields.retries, + } + if got := p.backoff(); got != tt.want { + t.Errorf("inFlightPacket.backoff() = %v, want %v", got, tt.want) + } + }) + } +} From 037e6d12a1cacd29cc720615ac5b7b80962a93a0 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Wed, 31 Jan 2024 13:35:11 +0100 Subject: [PATCH 48/78] test for inflightSequence --- internal/reliabletransport/packets.go | 8 +- internal/reliabletransport/packets_test.go | 244 ++++++++++++++++++++- 2 files changed, 244 insertions(+), 8 deletions(-) diff --git a/internal/reliabletransport/packets.go b/internal/reliabletransport/packets.go index 072e38a0..c97392a5 100644 --- a/internal/reliabletransport/packets.go +++ b/internal/reliabletransport/packets.go @@ -33,12 +33,6 @@ func newInFlightPacket(p *model.Packet) *inFlightPacket { } } -/* -func (p *inFlightPacket) ExtractACKs() []model.PacketID { - return p.packet.ACKs -} -*/ - // ACKForHigherPacket increments the number of acks received for a higher pid than this packet. This will influence the fast rexmit selection algorithm. func (p *inFlightPacket) ACKForHigherPacket() { p.higherACKs += 1 @@ -123,7 +117,7 @@ func (ps incomingSequence) Less(i, j int) bool { return ps[i].ID() < ps[j].ID() } -// TODO: this is just a packet +// TODO: this is just a packet ----------------------------- type incomingPacket struct { packet *model.Packet } diff --git a/internal/reliabletransport/packets_test.go b/internal/reliabletransport/packets_test.go index f40bb46a..4b071dc4 100644 --- a/internal/reliabletransport/packets_test.go +++ b/internal/reliabletransport/packets_test.go @@ -1,8 +1,11 @@ package reliabletransport import ( + "reflect" "testing" "time" + + "github.com/ooni/minivpn/internal/model" ) func Test_inFlightPacket_backoff(t *testing.T) { @@ -14,7 +17,51 @@ func Test_inFlightPacket_backoff(t *testing.T) { fields fields want time.Duration }{ - // TODO: Add test cases. + { + name: "retries=0", + fields: fields{0}, + want: time.Second, + }, + { + name: "retries=1", + fields: fields{1}, + want: time.Second * 2, + }, + { + name: "retries=2", + fields: fields{2}, + want: time.Second * 4, + }, + { + name: "retries=3", + fields: fields{3}, + want: time.Second * 8, + }, + { + name: "retries=4", + fields: fields{4}, + want: time.Second * 16, + }, + { + name: "retries=5", + fields: fields{5}, + want: time.Second * 32, + }, + { + name: "retries=6", + fields: fields{6}, + want: time.Second * 60, + }, + { + name: "retries=10", + fields: fields{10}, + want: time.Second * 60, + }, + { + name: "retries=6", + fields: fields{6}, + want: time.Second * 60, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -27,3 +74,198 @@ func Test_inFlightPacket_backoff(t *testing.T) { }) } } + +func Test_inFlightPacket_ScheduleForRetransmission(t *testing.T) { + p0 := newInFlightPacket(&model.Packet{}) + if p0.retries != 0 { + t.Errorf("inFlightPacket.retries should be 0") + } + t0 := time.Now() + p0.ScheduleForRetransmission(t0) + if p0.retries != 1 { + t.Errorf("inFlightPacket.retries should be 0") + } + if p0.deadline != t0.Add(time.Second*2) { + t.Errorf("inFlightPacket.deadline should be 2s in the future") + } + // schedule twice now + p0.ScheduleForRetransmission(t0) + p0.ScheduleForRetransmission(t0) + if p0.retries != 3 { + t.Errorf("inFlightPacket.retries should be 3") + } + if p0.deadline != t0.Add(time.Second*8) { + t.Errorf("inFlightPacket.deadline should be 8s in the future") + } + // schedule twice again + p0.ScheduleForRetransmission(t0) + p0.ScheduleForRetransmission(t0) + if p0.retries != 5 { + t.Errorf("inFlightPacket.retries should be 5") + } + if p0.deadline != t0.Add(time.Second*32) { + t.Errorf("inFlightPacket.deadline should be 32s in the future") + } +} + +func Test_inflightSequence_nearestDeadlineTo(t *testing.T) { + t0 := time.Date(1984, time.January, 1, 0, 0, 0, 0, time.UTC) + + type args struct { + t time.Time + } + tests := []struct { + name string + seq inflightSequence + args args + want time.Time + }{ + { + name: "empty case returns one minute wakeup", + seq: []*inFlightPacket{}, + args: args{t0}, + want: t0.Add(time.Minute), + }, + { + name: "single expired deadline returns ~now", + seq: []*inFlightPacket{ + {deadline: t0.Add(-1 * time.Second)}, + }, + args: args{t0}, + want: t0.Add(time.Nanosecond), + }, + { + name: "a expired deadline returns ~now", + seq: []*inFlightPacket{ + {deadline: t0.Add(-1 * time.Second)}, + {deadline: t0.Add(-2 * time.Second)}, + {deadline: t0.Add(10 * time.Millisecond)}, + {deadline: t0.Add(50 * time.Millisecond)}, + }, + args: args{t0}, + want: t0.Add(time.Nanosecond), + }, + { + name: "with several deadlines in the future, returns the lowest", + seq: []*inFlightPacket{ + {deadline: t0.Add(10 * time.Millisecond)}, + {deadline: t0.Add(20 * time.Millisecond)}, + {deadline: t0.Add(50 * time.Millisecond)}, + {deadline: t0.Add(1 * time.Second)}, + }, + args: args{t0}, + want: t0.Add(10 * time.Millisecond), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.seq.nearestDeadlineTo(tt.args.t); !reflect.DeepEqual(got, tt.want) { + t.Errorf("inflightSequence.nearestDeadlineTo() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_inflightSequence_readyToSend(t *testing.T) { + t0 := time.Date(1984, time.January, 1, 0, 0, 0, 0, time.UTC) + + type args struct { + t time.Time + } + tests := []struct { + name string + seq inflightSequence + args args + want inflightSequence + }{ + { + name: "empty queue returns empty slice", + seq: []*inFlightPacket{}, + args: args{t0}, + want: []*inFlightPacket{}, + }, + { + name: "not expired packet returns empty slice", + seq: []*inFlightPacket{ + {deadline: t0.Add(10 * time.Millisecond)}, + }, + args: args{t0}, + want: []*inFlightPacket{}, + }, + { + name: "one expired packet among many", + seq: []*inFlightPacket{ + { + packet: &model.Packet{ID: 1}, + deadline: t0.Add(10 * time.Millisecond), + }, + { + packet: &model.Packet{ID: 2}, + deadline: t0.Add(-1 * time.Millisecond), + }, + { + packet: &model.Packet{ID: 3}, + deadline: t0.Add(20 * time.Millisecond), + }, + }, + args: args{t0}, + want: []*inFlightPacket{ + { + packet: &model.Packet{ID: 2}, + deadline: t0.Add(-1 * time.Millisecond), + }, + }, + }, + { + name: "one expired packet and two fast retransmit", + seq: []*inFlightPacket{ + { + packet: &model.Packet{ID: 1}, + deadline: t0.Add(10 * time.Millisecond), + }, + { + packet: &model.Packet{ID: 2}, + deadline: t0.Add(-1 * time.Millisecond), + }, + { + packet: &model.Packet{ID: 3}, + deadline: t0.Add(20 * time.Millisecond), + }, + { + packet: &model.Packet{ID: 4}, + deadline: t0.Add(100 * time.Millisecond), + higherACKs: 3, + }, + { + packet: &model.Packet{ID: 5}, + deadline: t0.Add(100 * time.Millisecond), + higherACKs: 5, + }, + }, + args: args{t0}, + want: []*inFlightPacket{ + { + packet: &model.Packet{ID: 2}, + deadline: t0.Add(-1 * time.Millisecond), + }, + { + packet: &model.Packet{ID: 4}, + deadline: t0.Add(100 * time.Millisecond), + higherACKs: 3, + }, + { + packet: &model.Packet{ID: 5}, + deadline: t0.Add(100 * time.Millisecond), + higherACKs: 5, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.seq.readyToSend(tt.args.t); !reflect.DeepEqual(got, tt.want) { + t.Errorf("inflightSequence.readyToSend() = %v, want %v", got, tt.want) + } + }) + } +} From 5fcc0c6c6a455f1c24a7aca46a4b72abcd47f55e Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Wed, 31 Jan 2024 15:09:08 +0100 Subject: [PATCH 49/78] fix tests for receiver --- internal/reliabletransport/packets.go | 35 ++--- internal/reliabletransport/receiver.go | 20 ++- internal/reliabletransport/receiver_test.go | 135 +++++--------------- 3 files changed, 48 insertions(+), 142 deletions(-) diff --git a/internal/reliabletransport/packets.go b/internal/reliabletransport/packets.go index c97392a5..89563e5e 100644 --- a/internal/reliabletransport/packets.go +++ b/internal/reliabletransport/packets.go @@ -1,7 +1,6 @@ package reliabletransport import ( - "fmt" "time" "github.com/ooni/minivpn/internal/model" @@ -87,7 +86,6 @@ func (seq inflightSequence) readyToSend(t time.Time) inflightSequence { expired := make([]*inFlightPacket, 0) for _, p := range seq { if p.higherACKs >= 3 { - fmt.Println("DEBUG: fast retransmit for", p.packet.ID) expired = append(expired, p) continue } @@ -98,40 +96,23 @@ func (seq inflightSequence) readyToSend(t time.Time) inflightSequence { return expired } -// An incomingSequence is an array of sequentialPackets. It's used to store both incoming and outgoing packet queues. +// An incomingSequence is an array of [model.Packet]. // An incomingSequence can be sorted. -type incomingSequence []sequentialPacket +type incomingSequence []*model.Packet // implement sort.Interface -func (ps incomingSequence) Len() int { - return len(ps) +func (seq incomingSequence) Len() int { + return len(seq) } // implement sort.Interface -func (ps incomingSequence) Swap(i, j int) { - ps[i], ps[j] = ps[j], ps[i] +func (seq incomingSequence) Swap(i, j int) { + seq[i], seq[j] = seq[j], seq[i] } // implement sort.Interface -func (ps incomingSequence) Less(i, j int) bool { - return ps[i].ID() < ps[j].ID() -} - -// TODO: this is just a packet ----------------------------- -type incomingPacket struct { - packet *model.Packet -} - -func (ip *incomingPacket) ID() model.PacketID { - return ip.packet.ID -} - -func (ip *incomingPacket) ExtractACKs() []model.PacketID { - return ip.packet.ACKs -} - -func (ip *incomingPacket) Packet() *model.Packet { - return ip.packet +func (seq incomingSequence) Less(i, j int) bool { + return seq[i].ID < seq[j].ID } // incomingPacketSeen is a struct that the receiver sends us when a new packet is seen. diff --git a/internal/reliabletransport/receiver.go b/internal/reliabletransport/receiver.go index 0275e59e..7812ac7f 100644 --- a/internal/reliabletransport/receiver.go +++ b/internal/reliabletransport/receiver.go @@ -76,7 +76,7 @@ func (ws *workersState) moveUpWorker() { for _, nextPacket := range ready { // POSSIBLY BLOCK delivering to the upper layer select { - case ws.reliableToControl <- nextPacket.Packet(): + case ws.reliableToControl <- nextPacket: case <-ws.workersManager.ShouldShutdown(): return } @@ -111,7 +111,7 @@ type reliableReceiver struct { func newReliableReceiver(logger model.Logger, i chan incomingPacketSeen) *reliableReceiver { return &reliableReceiver{ logger: logger, - incomingPackets: []sequentialPacket{}, + incomingPackets: make([]*model.Packet, 0), incomingSeen: i, lastConsumed: 0, } @@ -124,25 +124,24 @@ func (r *reliableReceiver) MaybeInsertIncoming(p *model.Packet) bool { return false } - inc := &incomingPacket{p} // insert this one in the queue to pass to TLS. - r.incomingPackets = append(r.incomingPackets, inc) + r.incomingPackets = append(r.incomingPackets, p) return true } func (r *reliableReceiver) NextIncomingSequence() incomingSequence { last := r.lastConsumed - ready := make([]sequentialPacket, 0, RELIABLE_RECV_BUFFER_SIZE) + ready := make([]*model.Packet, 0, RELIABLE_RECV_BUFFER_SIZE) // sort them so that we begin with lower model.PacketID sort.Sort(r.incomingPackets) var keep incomingSequence for i, p := range r.incomingPackets { - if p.ID()-last == 1 { + if p.ID-last == 1 { ready = append(ready, p) last += 1 - } else if p.ID() > last { + } else if p.ID > last { // here we broke sequentiality, but we want // to drop anything that is below lastConsumed keep = append(keep, r.incomingPackets[i:]...) @@ -155,9 +154,6 @@ func (r *reliableReceiver) NextIncomingSequence() incomingSequence { } func (r *reliableReceiver) newIncomingPacketSeen(p *model.Packet) incomingPacketSeen { - if len(p.ACKs) != 0 { - fmt.Println(":: seen", p.ACKs) - } incomingPacket := incomingPacketSeen{} if p.Opcode == model.P_ACK_V1 { incomingPacket.acks = optional.Some(p.ACKs) @@ -169,10 +165,10 @@ func (r *reliableReceiver) newIncomingPacketSeen(p *model.Packet) incomingPacket return incomingPacket } -// assert that reliableIncoming implements incomingPacketHandler +// assert that reliableReceiver implements incomingPacketHandler var _ incomingPacketHandler = &reliableReceiver{ logger: nil, - incomingPackets: []sequentialPacket{}, + incomingPackets: make([]*model.Packet, 0), incomingSeen: make(chan<- incomingPacketSeen), lastConsumed: 0, } diff --git a/internal/reliabletransport/receiver_test.go b/internal/reliabletransport/receiver_test.go index 60e77cc8..21c727ac 100644 --- a/internal/reliabletransport/receiver_test.go +++ b/internal/reliabletransport/receiver_test.go @@ -12,26 +12,6 @@ import ( // tests for reliableReceiver // -// testIncomingPacket is a sequentialPacket for testing incomingPackets -type testIncomingPacket struct { - id model.PacketID - acks []model.PacketID -} - -func (ip *testIncomingPacket) ID() model.PacketID { - return ip.id -} - -func (ip *testIncomingPacket) ExtractACKs() []model.PacketID { - return ip.acks -} - -func (ip *testIncomingPacket) Packet() *model.Packet { - return &model.Packet{ID: ip.id} -} - -var _ sequentialPacket = &testIncomingPacket{} - func Test_reliableQueue_MaybeInsertIncoming(t *testing.T) { log.SetLevel(log.DebugLevel) @@ -39,7 +19,7 @@ func Test_reliableQueue_MaybeInsertIncoming(t *testing.T) { incomingPackets incomingSequence } type args struct { - p *testIncomingPacket + p *model.Packet } tests := []struct { name string @@ -50,55 +30,36 @@ func Test_reliableQueue_MaybeInsertIncoming(t *testing.T) { { name: "empty incoming, insert one", fields: fields{ - incomingPackets: []sequentialPacket{}, + incomingPackets: make([]*model.Packet, 0), }, args: args{ - &testIncomingPacket{id: 1}, + &model.Packet{ID: 1}, }, want: true, }, { name: "almost full incoming, insert one", fields: fields{ - incomingPackets: []sequentialPacket{ - &testIncomingPacket{id: 1}, - &testIncomingPacket{id: 2}, - &testIncomingPacket{id: 3}, - &testIncomingPacket{id: 4}, - &testIncomingPacket{id: 5}, - &testIncomingPacket{id: 6}, - &testIncomingPacket{id: 7}, - &testIncomingPacket{id: 8}, - &testIncomingPacket{id: 9}, - &testIncomingPacket{id: 10}, - &testIncomingPacket{id: 11}, + incomingPackets: []*model.Packet{ + {ID: 1}, {ID: 2}, {ID: 3}, {ID: 4}, + {ID: 5}, {ID: 6}, {ID: 7}, {ID: 8}, + {ID: 9}, {ID: 10}, {ID: 11}, }, }, - args: args{ - &testIncomingPacket{id: 12}, - }, + args: args{&model.Packet{ID: 12}}, want: true, }, { name: "full incoming, cannot insert", fields: fields{ - incomingPackets: []sequentialPacket{ - &testIncomingPacket{id: 1}, - &testIncomingPacket{id: 2}, - &testIncomingPacket{id: 3}, - &testIncomingPacket{id: 4}, - &testIncomingPacket{id: 5}, - &testIncomingPacket{id: 6}, - &testIncomingPacket{id: 7}, - &testIncomingPacket{id: 8}, - &testIncomingPacket{id: 9}, - &testIncomingPacket{id: 10}, - &testIncomingPacket{id: 11}, - &testIncomingPacket{id: 12}, + incomingPackets: []*model.Packet{ + {ID: 1}, {ID: 2}, {ID: 3}, {ID: 4}, + {ID: 5}, {ID: 6}, {ID: 7}, {ID: 8}, + {ID: 9}, {ID: 10}, {ID: 11}, {ID: 12}, }, }, args: args{ - &testIncomingPacket{id: 13}, + &model.Packet{ID: 13}, }, want: false, }, @@ -109,7 +70,7 @@ func Test_reliableQueue_MaybeInsertIncoming(t *testing.T) { logger: log.Log, incomingPackets: tt.fields.incomingPackets, } - if got := r.MaybeInsertIncoming(tt.args.p.Packet()); got != tt.want { + if got := r.MaybeInsertIncoming(tt.args.p); got != tt.want { t.Errorf("reliableQueue.MaybeInsertIncoming() = %v, want %v", got, tt.want) } }) @@ -131,86 +92,54 @@ func Test_reliableQueue_NextIncomingSequence(t *testing.T) { { name: "empty sequence", fields: fields{ - incomingPackets: []sequentialPacket{}, + incomingPackets: []*model.Packet{}, lastConsumed: model.PacketID(0), }, - want: []sequentialPacket{}, + want: []*model.Packet{}, }, { name: "single packet", fields: fields{ lastConsumed: model.PacketID(0), - incomingPackets: []sequentialPacket{ - &testIncomingPacket{id: 1}, + incomingPackets: []*model.Packet{ + {ID: 1}, }, }, - want: []sequentialPacket{ - &testIncomingPacket{id: 1}, + want: []*model.Packet{ + {ID: 1}, }, }, { name: "series of sequential packets", fields: fields{ - lastConsumed: model.PacketID(0), - incomingPackets: []sequentialPacket{ - &testIncomingPacket{id: 1}, - &testIncomingPacket{id: 2}, - &testIncomingPacket{id: 3}, - }, - }, - want: []sequentialPacket{ - &testIncomingPacket{id: 1}, - &testIncomingPacket{id: 2}, - &testIncomingPacket{id: 3}, + lastConsumed: model.PacketID(0), + incomingPackets: []*model.Packet{{ID: 1}, {ID: 2}, {ID: 3}}, }, + want: []*model.Packet{{ID: 1}, {ID: 2}, {ID: 3}}, }, { name: "series of sequential packets with hole", fields: fields{ - lastConsumed: model.PacketID(0), - incomingPackets: []sequentialPacket{ - &testIncomingPacket{id: 1}, - &testIncomingPacket{id: 2}, - &testIncomingPacket{id: 3}, - &testIncomingPacket{id: 5}, - }, - }, - want: []sequentialPacket{ - &testIncomingPacket{id: 1}, - &testIncomingPacket{id: 2}, - &testIncomingPacket{id: 3}, + lastConsumed: model.PacketID(0), + incomingPackets: []*model.Packet{{ID: 1}, {ID: 2}, {ID: 3}, {ID: 5}}, }, + want: []*model.Packet{{ID: 1}, {ID: 2}, {ID: 3}}, }, { name: "series of sequential packets with hole, lastConsumed higher", fields: fields{ - lastConsumed: model.PacketID(10), - incomingPackets: []sequentialPacket{ - &testIncomingPacket{id: 1}, - &testIncomingPacket{id: 2}, - &testIncomingPacket{id: 3}, - &testIncomingPacket{id: 5}, - }, + lastConsumed: model.PacketID(10), + incomingPackets: []*model.Packet{{ID: 1}, {ID: 2}, {ID: 3}, {ID: 5}}, }, - want: []sequentialPacket{}, + want: []*model.Packet{}, }, { name: "series of sequential packets with hole, lastConsumed higher, some above", fields: fields{ - lastConsumed: model.PacketID(10), - incomingPackets: []sequentialPacket{ - &testIncomingPacket{id: 1}, - &testIncomingPacket{id: 2}, - &testIncomingPacket{id: 10}, - &testIncomingPacket{id: 11}, - &testIncomingPacket{id: 12}, - &testIncomingPacket{id: 20}, - }, - }, - want: []sequentialPacket{ - &testIncomingPacket{id: 11}, - &testIncomingPacket{id: 12}, + lastConsumed: model.PacketID(10), + incomingPackets: []*model.Packet{{ID: 1}, {ID: 2}, {ID: 10}, {ID: 11}, {ID: 12}, {ID: 20}}, }, + want: []*model.Packet{{ID: 11}, {ID: 12}}, }, } for _, tt := range tests { From 60cf8ac5a540ead6aabd0fbe0dc38fbc646ef3c3 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Wed, 31 Jan 2024 16:03:31 +0100 Subject: [PATCH 50/78] more coverage --- internal/reliabletransport/packets.go | 20 +- internal/reliabletransport/receiver.go | 7 +- internal/reliabletransport/sender.go | 29 +- internal/reliabletransport/sender_test.go | 339 +++++++++++++++------- 4 files changed, 258 insertions(+), 137 deletions(-) diff --git a/internal/reliabletransport/packets.go b/internal/reliabletransport/packets.go index 89563e5e..62f8e53f 100644 --- a/internal/reliabletransport/packets.go +++ b/internal/reliabletransport/packets.go @@ -52,11 +52,6 @@ func (p *inFlightPacket) backoff() time.Duration { return backoff } -// TODO: revisit interfaces while writing tests. -// assert that inFlightWrappedPacket implements inFlightPacket and sequentialPacket -// var _ inFlightPacket = &inFlightWrappedPacket{} -// var _ sequentialPacket = &inFlightWrappedPacket{} - // inflightSequence is a sequence of inFlightPackets. // A inflightSequence MUST be sorted (since the controlchannel has assigned sequential packet IDs when creating the // packet) @@ -96,6 +91,21 @@ func (seq inflightSequence) readyToSend(t time.Time) inflightSequence { return expired } +// implement sort.Interface +func (seq inflightSequence) Len() int { + return len(seq) +} + +// implement sort.Interface +func (seq inflightSequence) Swap(i, j int) { + seq[i], seq[j] = seq[j], seq[i] +} + +// implement sort.Interface +func (seq inflightSequence) Less(i, j int) bool { + return seq[i].packet.ID < seq[j].packet.ID +} + // An incomingSequence is an array of [model.Packet]. // An incomingSequence can be sorted. type incomingSequence []*model.Packet diff --git a/internal/reliabletransport/receiver.go b/internal/reliabletransport/receiver.go index 7812ac7f..de3f8e4b 100644 --- a/internal/reliabletransport/receiver.go +++ b/internal/reliabletransport/receiver.go @@ -166,9 +166,4 @@ func (r *reliableReceiver) newIncomingPacketSeen(p *model.Packet) incomingPacket } // assert that reliableReceiver implements incomingPacketHandler -var _ incomingPacketHandler = &reliableReceiver{ - logger: nil, - incomingPackets: make([]*model.Packet, 0), - incomingSeen: make(chan<- incomingPacketSeen), - lastConsumed: 0, -} +var _ incomingPacketHandler = &reliableReceiver{} diff --git a/internal/reliabletransport/sender.go b/internal/reliabletransport/sender.go index 965a6b68..4f7de589 100644 --- a/internal/reliabletransport/sender.go +++ b/internal/reliabletransport/sender.go @@ -92,7 +92,8 @@ func (ws *workersState) moveDownWorker() { } } } else { - // TODO --- mve this to function ------------------------------------------- + // TODO --- move this to function ------------------------------------------- + // TODO: somethingToACK(state) --------------------------------------------- // there's nothing ready to be sent, so we see if we've got pending ACKs if sender.pendingACKsToSend.Len() == 0 { @@ -127,7 +128,6 @@ func (ws *workersState) moveDownWorker() { // reliableSender keeps state about the in flight packet queue, and implements outgoingPacketHandler. // Please use the constructor `newReliableSender()` type reliableSender struct { - // incomingSeen is a channel where we receive notifications for incoming packets seen by the receiver. incomingSeen <-chan incomingPacketSeen @@ -142,9 +142,9 @@ type reliableSender struct { } // newReliableSender returns a new instance of reliableOutgoing. -func newReliableSender(logger model.Logger, i chan incomingPacketSeen) *reliableSender { +func newReliableSender(logger model.Logger, ch chan incomingPacketSeen) *reliableSender { return &reliableSender{ - incomingSeen: i, + incomingSeen: ch, inFlight: make([]*inFlightPacket, 0, RELIABLE_SEND_BUFFER_SIZE), logger: logger, pendingACKsToSend: newACKSet(), @@ -180,9 +180,9 @@ func (r *reliableSender) OnIncomingPacketSeen(seen incomingPacketSeen) { // maybeEvictOrMarkWithHigherACK iterates over all the in-flight packets. For each one, // either evicts it (if the PacketID matches), or bumps the internal withHigherACK count in the // packet (if the PacketID from the ACK is higher than the packet in the queue). -func (r *reliableSender) maybeEvictOrMarkWithHigherACK(acked model.PacketID) bool { - packets := r.inFlight - for i, p := range packets { +func (r *reliableSender) maybeEvictOrMarkWithHigherACK(acked model.PacketID) { + pkts := r.inFlight + for i, p := range pkts { if p.packet == nil { panic("malformed packet") } @@ -194,18 +194,13 @@ func (r *reliableSender) maybeEvictOrMarkWithHigherACK(acked model.PacketID) boo r.logger.Debugf("evicting packet %v", p.packet.ID) // first we swap this element with the last one: - packets[i], packets[len(packets)-1] = packets[len(packets)-1], packets[i] + pkts[i], pkts[len(pkts)-1] = pkts[len(pkts)-1], pkts[i] // and now exclude the last element: - r.inFlight = packets[:len(packets)-1] - - // since the in-flight array is always sorted by ascending packet-id - // (because of sequentiality assumption in the control channel), - // we're done here. - return true + r.inFlight = pkts[:len(pkts)-1] } } - return false + sort.Sort(inflightSequence(r.inFlight)) } // NextPacketIDsToACK implement outgoingPacketHandler @@ -213,7 +208,11 @@ func (r *reliableSender) NextPacketIDsToACK() []model.PacketID { return r.pendingACKsToSend.nextToACK() } +// assert reliableSender implements the needed interfaces + var _ outgoingPacketHandler = &reliableSender{} +var _ seenPacketHandler = &reliableSender{} +var _ outgoingPacketWriter = &reliableSender{} // ackSet is a set of acks. The zero value struct // is invalid, please use newACKSet. diff --git a/internal/reliabletransport/sender_test.go b/internal/reliabletransport/sender_test.go index 7c0b6914..f8fcf81d 100644 --- a/internal/reliabletransport/sender_test.go +++ b/internal/reliabletransport/sender_test.go @@ -4,16 +4,35 @@ import ( "reflect" "slices" "testing" + "time" "github.com/apex/log" "github.com/ooni/minivpn/internal/model" "github.com/ooni/minivpn/internal/optional" ) +func idSequence(s inflightSequence) []model.PacketID { + ids := make([]model.PacketID, 0) + for _, p := range s { + ids = append(ids, p.packet.ID) + } + return ids +} + // // tests for reliableSender // +func Test_newReliableSender(t *testing.T) { + s := newReliableSender(log.Log, make(chan incomingPacketSeen)) + if s.logger == nil { + t.Errorf("newReliableSender(): expected non nil logger") + } + if s.incomingSeen == nil { + t.Errorf("newReliableSender(): expected non nil incomingSeen") + } +} + func Test_reliableSender_TryInsertOutgoingPacket(t *testing.T) { log.SetLevel(log.DebugLevel) @@ -139,116 +158,6 @@ func Test_reliableSender_NextPacketIDsToACK(t *testing.T) { } } -func Test_ackSet_maybeAdd(t *testing.T) { - type fields struct { - m map[model.PacketID]bool - } - type args struct { - id optional.Value[model.PacketID] - } - tests := []struct { - name string - fields fields - args args - want *ackSet - }{ - { - name: "can add on empty set", - fields: fields{newACKSet().m}, - args: args{optional.Some(model.PacketID(1))}, - want: newACKSet(1), - }, - { - name: "add duplicate on empty set", - fields: fields{newACKSet(1).m}, - args: args{optional.Some(model.PacketID(1))}, - want: newACKSet(1), - }, - { - name: "cannot add beyond capacity", - fields: fields{newACKSet(1, 2, 3, 4, 5, 6, 7, 8).m}, - args: args{optional.Some(model.PacketID(10))}, - want: newACKSet(1, 2, 3, 4, 5, 6, 7, 8), - }, - { - name: "order does not matter", - fields: fields{newACKSet(3, 2, 1).m}, - args: args{optional.Some(model.PacketID(4))}, - want: newACKSet(1, 2, 3, 4), - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - as := &ackSet{ - m: tt.fields.m, - } - if got := as.maybeAdd(tt.args.id); !reflect.DeepEqual(got, tt.want) { - t.Errorf("ackSet.maybeAdd() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_ackSet_nextToACK(t *testing.T) { - type fields struct { - m map[model.PacketID]bool - } - tests := []struct { - name string - fields fields - want []model.PacketID - }{ - { - name: "get all if you have <4", - fields: fields{newACKSet(1, 2, 3).m}, - want: []model.PacketID{1, 2, 3}, - }, - { - name: "get all if you have 4", - fields: fields{newACKSet(1, 2, 3, 4).m}, - want: []model.PacketID{1, 2, 3, 4}, - }, - { - name: "get 2 if you have 2, sorted", - fields: fields{newACKSet(4, 1).m}, - want: []model.PacketID{1, 4}, - }, - { - name: "get first 4 if you have >4, sorted", - fields: fields{newACKSet(5, 6, 8, 3, 2, 4, 1).m}, - want: []model.PacketID{1, 2, 3, 4}, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - as := &ackSet{ - m: tt.fields.m, - } - if got := as.nextToACK(); !reflect.DeepEqual(got, tt.want) { - t.Errorf("ackSet.nextToACK() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_ackSet_nextToACK_empties_set(t *testing.T) { - acks := newACKSet(1, 2, 3, 5, 4, 6, 7, 10, 9, 8) - - want1 := []model.PacketID{1, 2, 3, 4} - want2 := []model.PacketID{5, 6, 7, 8} - want3 := []model.PacketID{9, 10} - - if got := acks.nextToACK(); !reflect.DeepEqual(got, want1) { - t.Errorf("ackSet.nextToACK() = %v, want %v", got, want1) - } - if got := acks.nextToACK(); !reflect.DeepEqual(got, want2) { - t.Errorf("ackSet.nextToACK() = %v, want %v", got, want1) - } - if got := acks.nextToACK(); !reflect.DeepEqual(got, want3) { - t.Errorf("ackSet.nextToACK() = %v, want %v", got, want3) - } -} - // test the combined behavior of reacting to an incoming packet and checking // what's left in the in flight queue and what's left in the queue of pending acks. func Test_reliableSender_OnIncomingPacketSeen(t *testing.T) { @@ -407,4 +316,212 @@ func Test_reliableSender_OnIncomingPacketSeen(t *testing.T) { } } -// TODO: exercise maybeEvict + withHigherACKs +// Here we test injecting different ACKs for a given in flight queue (with expired deadlines or not), +// and we check what do we get ready to send. +func Test_reliableSender_maybeEvictOrMarkWithHigherACK(t *testing.T) { + t0 := time.Date(1984, time.January, 1, 0, 0, 0, 0, time.UTC) + + type fields struct { + inFlight []*inFlightPacket + } + type args struct { + acked model.PacketID + } + tests := []struct { + name string + fields fields + args args + wantSequence []model.PacketID + }{ + { + name: "empty ack does not evict anything", + fields: fields{[]*inFlightPacket{ + {packet: &model.Packet{ID: 1}}, + }}, + args: args{}, + wantSequence: []model.PacketID{1}, + }, + { + name: "one ack evicts the matching inflight packet", + fields: fields{[]*inFlightPacket{ + {packet: &model.Packet{ID: 1}}, + {packet: &model.Packet{ID: 2}}, + {packet: &model.Packet{ID: 3}}, + {packet: &model.Packet{ID: 4}}, + }}, + args: args{model.PacketID(1)}, + wantSequence: []model.PacketID{2, 3, 4}, + }, + { + name: "high ack evicts only that packet", + fields: fields{[]*inFlightPacket{ + {packet: &model.Packet{ID: 1}}, + {packet: &model.Packet{ID: 2}}, + {packet: &model.Packet{ID: 3}}, + {packet: &model.Packet{ID: 4}}, + }}, + args: args{ + model.PacketID(4), + }, + wantSequence: []model.PacketID{1, 2, 3}, + }, + { + name: "high ack evicts that packet, and gets a fast rxmit if >=3", + fields: fields{[]*inFlightPacket{ + { + // expired, should be returned + packet: &model.Packet{ID: 1}, + deadline: t0.Add(-1 * time.Millisecond), + }, + { + // this one should get returned too, will get the ack counter == 3 + packet: &model.Packet{ID: 2}, + deadline: t0.Add(20 * time.Millisecond), + higherACKs: 2, + }, + { + // this one has counter to zero and not expired, should not be returned + packet: &model.Packet{ID: 3}, + deadline: t0.Add(20 * time.Millisecond), + higherACKs: 0, + }, + { + // this one is the one we're evicting so who cares + packet: &model.Packet{ID: 4}, + }, + }}, + args: args{ + // let's evict this poor packet! + model.PacketID(4), + }, + wantSequence: []model.PacketID{1, 2}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &reliableSender{ + logger: log.Log, + inFlight: tt.fields.inFlight, + } + r.maybeEvictOrMarkWithHigherACK(tt.args.acked) + gotToSend := idSequence(inflightSequence(r.inFlight).readyToSend(t0)) + if !slices.Equal(gotToSend, tt.wantSequence) { + t.Errorf("reliableSender.maybeEvictOrMarkWithHigherACK() = %v, want %v", gotToSend, tt.wantSequence) + } + }) + } +} + +// +// tests for ackSet +// + +func Test_ackSet_maybeAdd(t *testing.T) { + type fields struct { + m map[model.PacketID]bool + } + type args struct { + id optional.Value[model.PacketID] + } + tests := []struct { + name string + fields fields + args args + want *ackSet + }{ + { + name: "can add on empty set", + fields: fields{newACKSet().m}, + args: args{optional.Some(model.PacketID(1))}, + want: newACKSet(1), + }, + { + name: "add duplicate on empty set", + fields: fields{newACKSet(1).m}, + args: args{optional.Some(model.PacketID(1))}, + want: newACKSet(1), + }, + { + name: "cannot add beyond capacity", + fields: fields{newACKSet(1, 2, 3, 4, 5, 6, 7, 8).m}, + args: args{optional.Some(model.PacketID(10))}, + want: newACKSet(1, 2, 3, 4, 5, 6, 7, 8), + }, + { + name: "order does not matter", + fields: fields{newACKSet(3, 2, 1).m}, + args: args{optional.Some(model.PacketID(4))}, + want: newACKSet(1, 2, 3, 4), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + as := &ackSet{ + m: tt.fields.m, + } + if got := as.maybeAdd(tt.args.id); !reflect.DeepEqual(got, tt.want) { + t.Errorf("ackSet.maybeAdd() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_ackSet_nextToACK(t *testing.T) { + type fields struct { + m map[model.PacketID]bool + } + tests := []struct { + name string + fields fields + want []model.PacketID + }{ + { + name: "get all if you have <4", + fields: fields{newACKSet(1, 2, 3).m}, + want: []model.PacketID{1, 2, 3}, + }, + { + name: "get all if you have 4", + fields: fields{newACKSet(1, 2, 3, 4).m}, + want: []model.PacketID{1, 2, 3, 4}, + }, + { + name: "get 2 if you have 2, sorted", + fields: fields{newACKSet(4, 1).m}, + want: []model.PacketID{1, 4}, + }, + { + name: "get first 4 if you have >4, sorted", + fields: fields{newACKSet(5, 6, 8, 3, 2, 4, 1).m}, + want: []model.PacketID{1, 2, 3, 4}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + as := &ackSet{ + m: tt.fields.m, + } + if got := as.nextToACK(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("ackSet.nextToACK() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_ackSet_nextToACK_empties_set(t *testing.T) { + acks := newACKSet(1, 2, 3, 5, 4, 6, 7, 10, 9, 8) + + want1 := []model.PacketID{1, 2, 3, 4} + want2 := []model.PacketID{5, 6, 7, 8} + want3 := []model.PacketID{9, 10} + + if got := acks.nextToACK(); !reflect.DeepEqual(got, want1) { + t.Errorf("ackSet.nextToACK() = %v, want %v", got, want1) + } + if got := acks.nextToACK(); !reflect.DeepEqual(got, want2) { + t.Errorf("ackSet.nextToACK() = %v, want %v", got, want1) + } + if got := acks.nextToACK(); !reflect.DeepEqual(got, want3) { + t.Errorf("ackSet.nextToACK() = %v, want %v", got, want3) + } +} From ae27c8f3d3c1fad94caf0c793183a94349663e81 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Wed, 31 Jan 2024 16:12:30 +0100 Subject: [PATCH 51/78] x --- internal/reliabletransport/sender.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/internal/reliabletransport/sender.go b/internal/reliabletransport/sender.go index 4f7de589..226e76e1 100644 --- a/internal/reliabletransport/sender.go +++ b/internal/reliabletransport/sender.go @@ -214,8 +214,7 @@ var _ outgoingPacketHandler = &reliableSender{} var _ seenPacketHandler = &reliableSender{} var _ outgoingPacketWriter = &reliableSender{} -// ackSet is a set of acks. The zero value struct -// is invalid, please use newACKSet. +// ackSet is a set of acks. The zero value struct is invalid, please use newACKSet. type ackSet struct { // m is the map we use to represent the set. m map[model.PacketID]bool From b1b0904f4c0e45fe3e80b451632362696ce0a430 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Wed, 31 Jan 2024 18:18:11 +0100 Subject: [PATCH 52/78] tests for next wakeup --- internal/reliabletransport/sender.go | 75 +++++++++----- internal/reliabletransport/sender_test.go | 118 ++++++++++++++++++++++ 2 files changed, 165 insertions(+), 28 deletions(-) diff --git a/internal/reliabletransport/sender.go b/internal/reliabletransport/sender.go index 226e76e1..538b41d9 100644 --- a/internal/reliabletransport/sender.go +++ b/internal/reliabletransport/sender.go @@ -9,6 +9,11 @@ import ( "github.com/ooni/minivpn/internal/optional" ) +var ( + // how long to wait for possible outgoing packets before sending a pending ACK as its own packet. + gracePeriodForOutgoingACKs = time.Millisecond * 20 +) + // moveDownWorker moves packets down the stack (sender) // The sender and receiver data structures lack mutexes because they are // intended to be confined to a single goroutine (one for each worker), and @@ -40,25 +45,8 @@ func (ws *workersState) moveDownWorker() { // possibly evict any acked packet (in the ack array) // and add any id to the queue of packets to ack sender.OnIncomingPacketSeen(seenPacket) - - if sender.pendingACKsToSend.Len() == 0 { - continue - } - - if sender.pendingACKsToSend.Len() >= 2 { - ticker.Reset(time.Nanosecond) - continue - } - - // if there's no event soon, give some time for other acks to arrive - // TODO: review if we need this optimization. - // TODO: maybe only during TLS handshake?? - now := time.Now() - timeout := inflightSequence(sender.inFlight).nearestDeadlineTo(now) - gracePeriod := time.Millisecond * 20 - if timeout.Sub(now) > gracePeriod { - fmt.Println(">> next wakeup too late, schedule in", gracePeriod) - ticker.Reset(gracePeriod) + if shouldWakeup, when := sender.shouldWakeupAfterACK(time.Now()); shouldWakeup { + ticker.Reset(when) } case <-ticker.C: @@ -73,7 +61,6 @@ func (ws *workersState) moveDownWorker() { timeout := inflightSequence(sender.inFlight).nearestDeadlineTo(now) ticker.Reset(timeout.Sub(now)) - scheduledNow := inflightSequence(sender.inFlight).readyToSend(now) if len(scheduledNow) > 0 { @@ -84,7 +71,9 @@ func (ws *workersState) moveDownWorker() { // append any pending ACKs p.packet.ACKs = sender.NextPacketIDsToACK() + // log the packet p.packet.Log(ws.logger, model.DirectionOutgoing) + select { case ws.dataOrControlToMuxer <- p.packet: case <-ws.workersManager.ShouldShutdown(): @@ -92,18 +81,19 @@ func (ws *workersState) moveDownWorker() { } } } else { - // TODO --- move this to function ------------------------------------------- - - // TODO: somethingToACK(state) --------------------------------------------- - // there's nothing ready to be sent, so we see if we've got pending ACKs - if sender.pendingACKsToSend.Len() == 0 { + if !sender.hasPendingACKs() { continue } + // special case, we want to send the clientHello as soon as possible ----------------------------- // (TODO: coordinate this with hardReset) - if sender.pendingACKsToSend.Len() == 1 && *sender.pendingACKsToSend.first() == model.PacketID(0) { - continue - } + + /* + // TODO is this doing the right thing? + if sender.pendingACKsToSend.Len() == 1 && *sender.pendingACKsToSend.first() == model.PacketID(0) { + continue + } + */ ws.logger.Debugf("Creating ACK: %d pending to ack", sender.pendingACKsToSend.Len()) @@ -203,6 +193,32 @@ func (r *reliableSender) maybeEvictOrMarkWithHigherACK(acked model.PacketID) { sort.Sort(inflightSequence(r.inFlight)) } +// shouldRescheduleAfterACK checks whether we need to wakeup after receiving an ACK. +// TODO: change this depending on the handshake state -------------------------- +func (r *reliableSender) shouldWakeupAfterACK(t time.Time) (bool, time.Duration) { + if r.pendingACKsToSend.Len() == 0 { + return false, time.Minute + } + // for two or more ACKs pending, we want to send right now. + if r.pendingACKsToSend.Len() >= 2 { + return true, time.Nanosecond + } + // if we've got a single ACK to send, we give it a grace period in case no other packets are + // scheduled to go out in this time. + timeout := inflightSequence(r.inFlight).nearestDeadlineTo(t) + + if timeout.Sub(t) > gracePeriodForOutgoingACKs { + r.logger.Debugf("next wakeup too late, schedule in %v", gracePeriodForOutgoingACKs) + return true, gracePeriodForOutgoingACKs + } + return true, timeout.Sub(t) +} + +// hasPendingACKs return true if there's any ack in the pending queue +func (r *reliableSender) hasPendingACKs() bool { + return r.pendingACKsToSend.Len() != 0 +} + // NextPacketIDsToACK implement outgoingPacketHandler func (r *reliableSender) NextPacketIDsToACK() []model.PacketID { return r.pendingACKsToSend.nextToACK() @@ -257,6 +273,8 @@ func (as *ackSet) nextToACK() []model.PacketID { } // first returns the first packetID in the set, in ascending order. +// TODO -- unused, possibly delete ---- was for a special case --- +/* func (as *ackSet) first() *model.PacketID { ids := as.sorted() if len(ids) == 0 { @@ -264,6 +282,7 @@ func (as *ackSet) first() *model.PacketID { } return &ids[0] } +*/ // sorted returns a []model.PacketID array with the stored ids, in ascending order. func (as *ackSet) sorted() []model.PacketID { diff --git a/internal/reliabletransport/sender_test.go b/internal/reliabletransport/sender_test.go index f8fcf81d..488703ff 100644 --- a/internal/reliabletransport/sender_test.go +++ b/internal/reliabletransport/sender_test.go @@ -412,6 +412,124 @@ func Test_reliableSender_maybeEvictOrMarkWithHigherACK(t *testing.T) { } } +func Test_reliableSender_hasPendingACKs(t *testing.T) { + type fields struct { + pendingACKsToSend *ackSet + } + tests := []struct { + name string + fields fields + want bool + }{ + { + name: "empty acksset returns false", + fields: fields{ + newACKSet(), + }, + want: false, + }, + { + name: "not empty ackset returns true", + fields: fields{ + newACKSet(1), + }, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &reliableSender{ + logger: log.Log, + pendingACKsToSend: tt.fields.pendingACKsToSend, + } + if got := r.hasPendingACKs(); got != tt.want { + t.Errorf("reliableSender.hasPendingACKs() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_reliableSender_shouldWakeupAfterACK(t *testing.T) { + t0 := time.Date(1984, time.January, 1, 0, 0, 0, 0, time.UTC) + + type fields struct { + inflight []*inFlightPacket + pendingACKsToSend *ackSet + } + type args struct { + t time.Time + } + tests := []struct { + name string + fields fields + args args + want bool + wantDuration time.Duration + }{ + { + name: "empty ackset returns false", + fields: fields{ + pendingACKsToSend: newACKSet(), + }, + args: args{t0}, + want: false, + wantDuration: time.Minute, + }, + { + name: "len(ackset)=2 returns true", + fields: fields{ + pendingACKsToSend: newACKSet(1, 2), + }, + args: args{t0}, + want: true, + wantDuration: time.Nanosecond, + }, + { + name: "len(ackset)=1 returns grace period", + fields: fields{ + pendingACKsToSend: newACKSet(1), + }, + args: args{t0}, + want: true, + wantDuration: gracePeriodForOutgoingACKs, + }, + { + name: "len(ackset)=1 returns lower deadline if below grace period", + fields: fields{ + inflight: []*inFlightPacket{ + { + packet: &model.Packet{ID: 1}, + deadline: t0.Add(5 * time.Millisecond), + }, + { + packet: &model.Packet{ID: 2}, + deadline: t0.Add(10 * time.Millisecond), + }}, + pendingACKsToSend: newACKSet(1), + }, + args: args{t0}, + want: true, + wantDuration: time.Millisecond * 5, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &reliableSender{ + logger: log.Log, + inFlight: tt.fields.inflight, + pendingACKsToSend: tt.fields.pendingACKsToSend, + } + got, gotDuration := r.shouldWakeupAfterACK(tt.args.t) + if got != tt.want { + t.Errorf("reliableSender.shouldWakeupAfterACK() got = %v, want %v", got, tt.want) + } + if gotDuration != tt.wantDuration { + t.Errorf("reliableSender.shouldWakeupAfterACK() gotDuration = %v, want %v", gotDuration, tt.wantDuration) + } + }) + } +} + // // tests for ackSet // From a701b8379f4b24bddcd17a1fd41df7614a644bd3 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Wed, 31 Jan 2024 18:45:58 +0100 Subject: [PATCH 53/78] tests for service initialization --- internal/reliabletransport/service_test.go | 65 ++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 internal/reliabletransport/service_test.go diff --git a/internal/reliabletransport/service_test.go b/internal/reliabletransport/service_test.go new file mode 100644 index 00000000..76b4de59 --- /dev/null +++ b/internal/reliabletransport/service_test.go @@ -0,0 +1,65 @@ +package reliabletransport + +import ( + "testing" + + "github.com/apex/log" + "github.com/ooni/minivpn/internal/model" + "github.com/ooni/minivpn/internal/session" + "github.com/ooni/minivpn/internal/workers" +) + +// test that we can start the workers +func TestService_StartWorkers(t *testing.T) { + type fields struct { + DataOrControlToMuxer *chan *model.Packet + ControlToReliable chan *model.Packet + MuxerToReliable chan *model.Packet + ReliableToControl *chan *model.Packet + } + type args struct { + logger model.Logger + workersManager *workers.Manager + sessionManager *session.Manager + } + tests := []struct { + name string + fields fields + args args + }{ + { + name: "call startworkers with properly initialized channels", + fields: fields{ + DataOrControlToMuxer: func() *chan *model.Packet { + ch := make(chan *model.Packet) + return &ch + }(), + ControlToReliable: make(chan *model.Packet), + MuxerToReliable: make(chan *model.Packet), + ReliableToControl: func() *chan *model.Packet { + ch := make(chan *model.Packet) + return &ch + }(), + }, + args: args{ + logger: log.Log, + workersManager: workers.NewManager(log.Log), + sessionManager: func() *session.Manager { + m, _ := session.NewManager(log.Log) + return m + }(), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Service{ + DataOrControlToMuxer: tt.fields.DataOrControlToMuxer, + ControlToReliable: tt.fields.ControlToReliable, + MuxerToReliable: tt.fields.MuxerToReliable, + ReliableToControl: tt.fields.ReliableToControl, + } + s.StartWorkers(tt.args.logger, tt.args.workersManager, tt.args.sessionManager) + }) + } +} From 8071942c8e210df879f9cecc635800bf60881c7a Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Wed, 31 Jan 2024 19:06:42 +0100 Subject: [PATCH 54/78] add more unit tests --- internal/reliabletransport/receiver.go | 4 ++-- internal/reliabletransport/receiver_test.go | 13 +++++++++++++ internal/session/manager.go | 1 - 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/internal/reliabletransport/receiver.go b/internal/reliabletransport/receiver.go index de3f8e4b..2f81581b 100644 --- a/internal/reliabletransport/receiver.go +++ b/internal/reliabletransport/receiver.go @@ -108,11 +108,11 @@ type reliableReceiver struct { lastConsumed model.PacketID } -func newReliableReceiver(logger model.Logger, i chan incomingPacketSeen) *reliableReceiver { +func newReliableReceiver(logger model.Logger, ch chan incomingPacketSeen) *reliableReceiver { return &reliableReceiver{ logger: logger, incomingPackets: make([]*model.Packet, 0), - incomingSeen: i, + incomingSeen: ch, lastConsumed: 0, } } diff --git a/internal/reliabletransport/receiver_test.go b/internal/reliabletransport/receiver_test.go index 21c727ac..b5083201 100644 --- a/internal/reliabletransport/receiver_test.go +++ b/internal/reliabletransport/receiver_test.go @@ -12,6 +12,19 @@ import ( // tests for reliableReceiver // +func Test_newReliableReceiver(t *testing.T) { + rr := newReliableReceiver(log.Log, make(chan incomingPacketSeen)) + if rr.logger == nil { + t.Errorf("newReliableReceiver() should not have nil logger") + } + if rr.incomingPackets == nil { + t.Errorf("newReliableReceiver() should not have nil incomingPackets ch") + } + if rr.lastConsumed != 0 { + t.Errorf("newReliableReceiver() should have lastConsumed == 0") + } +} + func Test_reliableQueue_MaybeInsertIncoming(t *testing.T) { log.SetLevel(log.DebugLevel) diff --git a/internal/session/manager.go b/internal/session/manager.go index d86b64b7..7b54c4ab 100644 --- a/internal/session/manager.go +++ b/internal/session/manager.go @@ -87,7 +87,6 @@ type Manager struct { // Ready is a channel where we signal that we can start accepting data, because we've // successfully generated key material for the data channel. - // TODO(ainghazal): find a better way? Ready chan any } From 8547dafab0d0953aa87604672ab22e69804b5852 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Thu, 1 Feb 2024 15:35:11 +0100 Subject: [PATCH 55/78] first naive reordering test --- internal/model/packet.go | 7 +- internal/reliabletransport/receiver.go | 10 +- internal/reliabletransport/reliable_test.go | 128 ++++++++++++++++++++ internal/reliabletransport/sender.go | 3 +- 4 files changed, 143 insertions(+), 5 deletions(-) create mode 100644 internal/reliabletransport/reliable_test.go diff --git a/internal/model/packet.go b/internal/model/packet.go index 210e0f60..3939425d 100644 --- a/internal/model/packet.go +++ b/internal/model/packet.go @@ -325,6 +325,11 @@ func (p *Packet) Log(logger Logger, direction int) { return } + payloadLen := 0 + if p.Payload != nil { + payloadLen = len(p.Payload) + } + logger.Debugf( "%s %s {id=%d, acks=%v} localID=%x remoteID=%x [%d bytes]", dir, @@ -333,6 +338,6 @@ func (p *Packet) Log(logger Logger, direction int) { p.ACKs, p.LocalSessionID, p.RemoteSessionID, - len(p.Payload), + payloadLen, ) } diff --git a/internal/reliabletransport/receiver.go b/internal/reliabletransport/receiver.go index 2f81581b..11bcc61e 100644 --- a/internal/reliabletransport/receiver.go +++ b/internal/reliabletransport/receiver.go @@ -2,8 +2,8 @@ package reliabletransport import ( "bytes" + "encoding/hex" "fmt" - "log" "sort" "github.com/ooni/minivpn/internal/model" @@ -43,10 +43,14 @@ func (ws *workersState) moveUpWorker() { // TODO: are we handling a HARD_RESET_V2 while we're doing a handshake? // I'm not sure that's a valid behavior for a server. // We should be able to deterministically test how this affects the state machine. - log.Printf("%s session check: %v\n", packet.Opcode, bytes.Equal(packet.LocalSessionID[:], ws.sessionManager.RemoteSessionID())) + // log.Printf("%s session check: %v\n", packet.Opcode, bytes.Equal(packet.LocalSessionID[:], ws.sessionManager.RemoteSessionID())) // drop a packet that is not for our session - if !bytes.Equal(packet.LocalSessionID[:], ws.sessionManager.RemoteSessionID()) { + fmt.Println(hex.Dump((packet.RemoteSessionID[:]))) + fmt.Println(hex.Dump((ws.sessionManager.LocalSessionID()))) + + if !bytes.Equal([]byte(packet.RemoteSessionID[:]), []byte(ws.sessionManager.LocalSessionID())) { + ws.logger.Debugf("%T: %T", packet.RemoteSessionID[:], ws.sessionManager.LocalSessionID()) ws.logger.Warnf( "%s: packet with invalid RemoteSessionID: expected %x; got %x", workerName, diff --git a/internal/reliabletransport/reliable_test.go b/internal/reliabletransport/reliable_test.go new file mode 100644 index 00000000..61cb309e --- /dev/null +++ b/internal/reliabletransport/reliable_test.go @@ -0,0 +1,128 @@ +package reliabletransport + +import ( + "slices" + "testing" + + "github.com/apex/log" + "github.com/ooni/minivpn/internal/model" + "github.com/ooni/minivpn/internal/session" + "github.com/ooni/minivpn/internal/workers" +) + +// test that we're able to reorder whatever is received. +func TestReliable_Reordering_withWorkers(t *testing.T) { + type fields struct { + DataOrControlToMuxer *chan *model.Packet + ControlToReliable chan *model.Packet + MuxerToReliable chan *model.Packet + ReliableToControl *chan *model.Packet + } + type args struct { + logger model.Logger + workersManager *workers.Manager + sessionManager *session.Manager + inputSequence []int + outputSequence []int + } + getFields := func() fields { + f := fields{ + DataOrControlToMuxer: func() *chan *model.Packet { + ch := make(chan *model.Packet) + return &ch + }(), + ControlToReliable: make(chan *model.Packet), + MuxerToReliable: make(chan *model.Packet), + ReliableToControl: func() *chan *model.Packet { + ch := make(chan *model.Packet) + return &ch + }(), + } + return f + } + + getArgs := func() args { + a := args{ + logger: log.Log, + workersManager: workers.NewManager(log.Log), + sessionManager: func() *session.Manager { + m, _ := session.NewManager(log.Log) + return m + }(), + inputSequence: []int{}, + outputSequence: []int{}, + } + return a + } + + tests := []struct { + name string + fields fields + args args + }{ + { + name: "test reordering for input sequence", + fields: getFields(), + args: func() args { + args := getArgs() + args.inputSequence = []int{3, 1, 2, 4} + args.outputSequence = []int{1, 2, 3, 4} + return args + }(), + }, + { + name: "test duplicates and reordering for input sequence", + fields: getFields(), + args: func() args { + args := getArgs() + args.inputSequence = []int{3, 3, 1, 1, 2, 4} + args.outputSequence = []int{1, 2, 3, 4} + return args + }(), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Service{ + DataOrControlToMuxer: tt.fields.DataOrControlToMuxer, + ControlToReliable: tt.fields.ControlToReliable, + MuxerToReliable: tt.fields.MuxerToReliable, + ReliableToControl: tt.fields.ReliableToControl, + } + s.StartWorkers(tt.args.logger, tt.args.workersManager, tt.args.sessionManager) + + sessionID := tt.args.sessionManager.LocalSessionID() + dataIn := tt.fields.MuxerToReliable + dataOut := tt.fields.ReliableToControl + + // create a buffered channel with "enough" capacity + collectOut := make(chan *model.Packet, 1024) + + go func(chan *model.Packet) { + for { + pkt := <-*dataOut + collectOut <- pkt + } + }(collectOut) + + for _, idx := range tt.args.inputSequence { + dataIn <- &model.Packet{ + Opcode: model.P_CONTROL_V1, + RemoteSessionID: model.SessionID(sessionID), + ID: model.PacketID(idx), + } + } + + got := make([]int, 0) + + for i := 0; i < len(tt.args.outputSequence); i++ { + pkt := <-collectOut + got = append(got, int(pkt.ID)) + } + + if !slices.Equal(got, tt.args.outputSequence) { + t.Errorf("Reordering: got = %v, want %v", got, tt.args.outputSequence) + } + }) + } +} diff --git a/internal/reliabletransport/sender.go b/internal/reliabletransport/sender.go index 538b41d9..1fbfd209 100644 --- a/internal/reliabletransport/sender.go +++ b/internal/reliabletransport/sender.go @@ -100,6 +100,7 @@ func (ws *workersState) moveDownWorker() { ACK, err := ws.sessionManager.NewACKForPacketIDs(sender.NextPacketIDsToACK()) if err != nil { ws.logger.Warnf("%s: cannot create ack: %v", workerName, err.Error()) + return } ACK.Log(ws.logger, model.DirectionOutgoing) select { @@ -196,7 +197,7 @@ func (r *reliableSender) maybeEvictOrMarkWithHigherACK(acked model.PacketID) { // shouldRescheduleAfterACK checks whether we need to wakeup after receiving an ACK. // TODO: change this depending on the handshake state -------------------------- func (r *reliableSender) shouldWakeupAfterACK(t time.Time) (bool, time.Duration) { - if r.pendingACKsToSend.Len() == 0 { + if r.pendingACKsToSend.Len() <= 0 { return false, time.Minute } // for two or more ACKs pending, we want to send right now. From d07895cd9474570cfcc325f8f9218b60078ca7c2 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Thu, 1 Feb 2024 17:02:07 +0100 Subject: [PATCH 56/78] going for a walk outside --- internal/model/packet.go | 27 ++++++ internal/reliabletransport/receiver.go | 5 +- internal/reliabletransport/reliable_test.go | 100 +++++++++++++------- 3 files changed, 95 insertions(+), 37 deletions(-) diff --git a/internal/model/packet.go b/internal/model/packet.go index 3939425d..65f07afe 100644 --- a/internal/model/packet.go +++ b/internal/model/packet.go @@ -32,6 +32,33 @@ const ( P_DATA_V2 // 9 ) +// NewOpcodeFromString returns an opcode from a string representation, and an error if it cannot parse the opcode +// representation. The zero return value is invalid. +func NewOpcodeFromString(s string) (Opcode, error) { + switch s { + case "CONTROL_HARD_RESET_CLIENT_V1": + return P_CONTROL_HARD_RESET_CLIENT_V1, nil + case "CONTROL_HARD_RESET_SERVER_V1": + return P_CONTROL_HARD_RESET_SERVER_V1, nil + case "CONTROL_SOFT_RESET_V1": + return P_CONTROL_SOFT_RESET_V1, nil + case "CONTROL_V1": + return P_CONTROL_V1, nil + case "ACK_V1": + return P_ACK_V1, nil + case "DATA_V1": + return P_DATA_V1, nil + case "CONTROL_HARD_RESET_CLIENT_V2": + return P_CONTROL_HARD_RESET_CLIENT_V2, nil + case "CONTROL_HARD_RESET_SERVER_V2": + return P_CONTROL_HARD_RESET_SERVER_V2, nil + case "DATA_V2": + return P_DATA_V2, nil + default: + return 0, errors.New("unknown opcode") + } +} + // String returns the opcode string representation func (op Opcode) String() string { switch op { diff --git a/internal/reliabletransport/receiver.go b/internal/reliabletransport/receiver.go index 11bcc61e..e4a4f0ad 100644 --- a/internal/reliabletransport/receiver.go +++ b/internal/reliabletransport/receiver.go @@ -2,7 +2,6 @@ package reliabletransport import ( "bytes" - "encoding/hex" "fmt" "sort" @@ -46,11 +45,8 @@ func (ws *workersState) moveUpWorker() { // log.Printf("%s session check: %v\n", packet.Opcode, bytes.Equal(packet.LocalSessionID[:], ws.sessionManager.RemoteSessionID())) // drop a packet that is not for our session - fmt.Println(hex.Dump((packet.RemoteSessionID[:]))) - fmt.Println(hex.Dump((ws.sessionManager.LocalSessionID()))) if !bytes.Equal([]byte(packet.RemoteSessionID[:]), []byte(ws.sessionManager.LocalSessionID())) { - ws.logger.Debugf("%T: %T", packet.RemoteSessionID[:], ws.sessionManager.LocalSessionID()) ws.logger.Warnf( "%s: packet with invalid RemoteSessionID: expected %x; got %x", workerName, @@ -73,6 +69,7 @@ func (ws *workersState) moveUpWorker() { if inserted := receiver.MaybeInsertIncoming(packet); !inserted { // this packet was not inserted in the queue: we drop it + ws.logger.Debugf("Dropping packet: %v", packet.ID) continue } diff --git a/internal/reliabletransport/reliable_test.go b/internal/reliabletransport/reliable_test.go index 61cb309e..c0f346cb 100644 --- a/internal/reliabletransport/reliable_test.go +++ b/internal/reliabletransport/reliable_test.go @@ -1,17 +1,24 @@ package reliabletransport import ( + "fmt" "slices" + "sync" "testing" + "time" "github.com/apex/log" "github.com/ooni/minivpn/internal/model" "github.com/ooni/minivpn/internal/session" + "github.com/ooni/minivpn/internal/vpntest" "github.com/ooni/minivpn/internal/workers" ) -// test that we're able to reorder whatever is received. +// test that we're able to reorder (towards TLS) whatever is received (from the muxer). func TestReliable_Reordering_withWorkers(t *testing.T) { + + log.SetLevel(log.DebugLevel) + type fields struct { DataOrControlToMuxer *chan *model.Packet ControlToReliable chan *model.Packet @@ -22,7 +29,7 @@ func TestReliable_Reordering_withWorkers(t *testing.T) { logger model.Logger workersManager *workers.Manager sessionManager *session.Manager - inputSequence []int + inputSequence []string outputSequence []int } getFields := func() fields { @@ -32,9 +39,9 @@ func TestReliable_Reordering_withWorkers(t *testing.T) { return &ch }(), ControlToReliable: make(chan *model.Packet), - MuxerToReliable: make(chan *model.Packet), + MuxerToReliable: make(chan *model.Packet, 1024), ReliableToControl: func() *chan *model.Packet { - ch := make(chan *model.Packet) + ch := make(chan *model.Packet, 1024) return &ch }(), } @@ -49,7 +56,7 @@ func TestReliable_Reordering_withWorkers(t *testing.T) { m, _ := session.NewManager(log.Log) return m }(), - inputSequence: []int{}, + inputSequence: []string{}, outputSequence: []int{}, } return a @@ -61,25 +68,42 @@ func TestReliable_Reordering_withWorkers(t *testing.T) { args args }{ { - name: "test reordering for input sequence", - fields: getFields(), - args: func() args { - args := getArgs() - args.inputSequence = []int{3, 1, 2, 4} - args.outputSequence = []int{1, 2, 3, 4} - return args - }(), - }, - { - name: "test duplicates and reordering for input sequence", + name: "test proper ordering for input sequence", fields: getFields(), args: func() args { args := getArgs() - args.inputSequence = []int{3, 3, 1, 1, 2, 4} + args.inputSequence = []string{ + "[1] CONTROL_V1 +5ms", + "[2] CONTROL_V1 +5ms", + "[3] CONTROL_V1 +5ms", + "[4] CONTROL_V1 +5ms", + } args.outputSequence = []int{1, 2, 3, 4} return args }(), }, + + // not yet! :) + + /* + { + name: "test reordering for input sequence", + fields: getFields(), + args: func() args { + args := getArgs() + args.inputSequence = []string{ + "[2] CONTROL_V1 +5ms", + "[4] CONTROL_V1 +5ms", + "[3] CONTROL_V1 +5ms", + "[1] CONTROL_V1 +5ms", + } + args.outputSequence = []int{1, 2, 3, 4} + return args + }(), + }, + */ + + // TODO test duplicates } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -91,33 +115,43 @@ func TestReliable_Reordering_withWorkers(t *testing.T) { } s.StartWorkers(tt.args.logger, tt.args.workersManager, tt.args.sessionManager) - sessionID := tt.args.sessionManager.LocalSessionID() + // the only two channels we're going to be testing on this test dataIn := tt.fields.MuxerToReliable dataOut := tt.fields.ReliableToControl + sessionID := tt.args.sessionManager.LocalSessionID() - // create a buffered channel with "enough" capacity - collectOut := make(chan *model.Packet, 1024) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for _, testStr := range tt.args.inputSequence { + testPkt, err := vpntest.NewTestPacketFromString(testStr) + if err != nil { + t.Errorf("Reordering: error reading test sequence: %v", err.Error()) + } - go func(chan *model.Packet) { - for { - pkt := <-*dataOut - collectOut <- pkt - } - }(collectOut) + fmt.Printf("test packet: %v\n", testPkt) - for _, idx := range tt.args.inputSequence { - dataIn <- &model.Packet{ - Opcode: model.P_CONTROL_V1, - RemoteSessionID: model.SessionID(sessionID), - ID: model.PacketID(idx), + dataIn <- &model.Packet{ + Opcode: testPkt.Opcode, + RemoteSessionID: model.SessionID(sessionID), + ID: model.PacketID(testPkt.ID), + } + log.Debugf("sleeping for %T(%v)", testPkt.IAT, testPkt.IAT) + time.Sleep(testPkt.IAT) } - } + log.Info("test: done writing") + }() + + wg.Wait() + log.Debug("start collecting packets") got := make([]int, 0) for i := 0; i < len(tt.args.outputSequence); i++ { - pkt := <-collectOut + pkt := <-*dataOut got = append(got, int(pkt.ID)) + log.Debugf("got packet: %v", pkt.ID) } if !slices.Equal(got, tt.args.outputSequence) { From c5718ad72196712ec47bb386a93db1063455a0e9 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Thu, 1 Feb 2024 18:39:03 +0100 Subject: [PATCH 57/78] debug: wtf is going on --- internal/reliabletransport/receiver.go | 14 +++ internal/reliabletransport/reliable_test.go | 94 +++++++++++++-------- internal/reliabletransport/service.go | 2 +- 3 files changed, 74 insertions(+), 36 deletions(-) diff --git a/internal/reliabletransport/receiver.go b/internal/reliabletransport/receiver.go index e4a4f0ad..02bf2402 100644 --- a/internal/reliabletransport/receiver.go +++ b/internal/reliabletransport/receiver.go @@ -39,6 +39,8 @@ func (ws *workersState) moveUpWorker() { packet.Log(ws.logger, model.DirectionIncoming) } + fmt.Println("< from muxer", packet.ID) + // TODO: are we handling a HARD_RESET_V2 while we're doing a handshake? // I'm not sure that's a valid behavior for a server. // We should be able to deterministically test how this affects the state machine. @@ -56,33 +58,45 @@ func (ws *workersState) moveUpWorker() { continue } + fmt.Println("< create seen") seen := receiver.newIncomingPacketSeen(packet) ws.incomingSeen <- seen + fmt.Println("< wrote to seen ch") // TODO(ainghazal): drop a packet that is a replay (id <= lastConsumed, but != ACK...?) // we only want to insert control packets going to the tls layer if packet.Opcode != model.P_CONTROL_V1 { + fmt.Println("< not a control!!") continue } if inserted := receiver.MaybeInsertIncoming(packet); !inserted { // this packet was not inserted in the queue: we drop it + fmt.Println("< droppin!!") ws.logger.Debugf("Dropping packet: %v", packet.ID) continue } ready := receiver.NextIncomingSequence() + fmt.Println("< got next", ready) + for _, nextPacket := range ready { // POSSIBLY BLOCK delivering to the upper layer select { case ws.reliableToControl <- nextPacket: + fmt.Println("< wrote to control") case <-ws.workersManager.ShouldShutdown(): return } } + fmt.Println("< DONE, END LOOP") + fmt.Println("< incomingSeen:", len(ws.incomingSeen)) + fmt.Println("< muxerToReliable:", len(ws.muxerToReliable)) + fmt.Println("") + case <-ws.workersManager.ShouldShutdown(): return } diff --git a/internal/reliabletransport/reliable_test.go b/internal/reliabletransport/reliable_test.go index c0f346cb..9d9be9a4 100644 --- a/internal/reliabletransport/reliable_test.go +++ b/internal/reliabletransport/reliable_test.go @@ -5,7 +5,6 @@ import ( "slices" "sync" "testing" - "time" "github.com/apex/log" "github.com/ooni/minivpn/internal/model" @@ -45,6 +44,8 @@ func TestReliable_Reordering_withWorkers(t *testing.T) { return &ch }(), } + fmt.Println(":: muxer to reliable", len(f.MuxerToReliable)) + fmt.Println(":: reliable to control", len(f.MuxerToReliable)) return f } @@ -73,10 +74,10 @@ func TestReliable_Reordering_withWorkers(t *testing.T) { args: func() args { args := getArgs() args.inputSequence = []string{ - "[1] CONTROL_V1 +5ms", - "[2] CONTROL_V1 +5ms", - "[3] CONTROL_V1 +5ms", - "[4] CONTROL_V1 +5ms", + "[1] CONTROL_V1 +1ms", + "[2] CONTROL_V1 +1ms", + "[3] CONTROL_V1 +1ms", + "[4] CONTROL_V1 +1ms", } args.outputSequence = []int{1, 2, 3, 4} return args @@ -113,50 +114,73 @@ func TestReliable_Reordering_withWorkers(t *testing.T) { MuxerToReliable: tt.fields.MuxerToReliable, ReliableToControl: tt.fields.ReliableToControl, } - s.StartWorkers(tt.args.logger, tt.args.workersManager, tt.args.sessionManager) - // the only two channels we're going to be testing on this test dataIn := tt.fields.MuxerToReliable dataOut := tt.fields.ReliableToControl sessionID := tt.args.sessionManager.LocalSessionID() + fmt.Println("") + fmt.Println(">> initial len DATAIN ", len(dataIn)) + + // let the workers pump up the jam! + s.StartWorkers(tt.args.logger, tt.args.workersManager, tt.args.sessionManager) + + for _, testStr := range tt.args.inputSequence { + testPkt, err := vpntest.NewTestPacketFromString(testStr) + if err != nil { + t.Errorf("Reordering: error reading test sequence: %v", err.Error()) + } + + fmt.Printf("::: test packet: %v\n", testPkt) + + p := &model.Packet{ + Opcode: testPkt.Opcode, + RemoteSessionID: model.SessionID(sessionID), + ID: model.PacketID(testPkt.ID), + } + dataIn <- p + log.Infof("test: len write ch: %v", len(dataIn)) + // log.Debugf("sleeping for %T(%v)", testPkt.IAT, testPkt.IAT) + // time.Sleep(testPkt.IAT) + // time.Sleep(time.Millisecond) + } + log.Info("test: done writing") + log.Infof("test: len write ch: %v", len(dataIn)) + + fmt.Println("data out", len(*dataOut)) + + fmt.Println("s", s) + + // start the result collector in a different goroutine var wg sync.WaitGroup wg.Add(1) - go func() { + go func(ch <-chan *model.Packet) { defer wg.Done() - for _, testStr := range tt.args.inputSequence { - testPkt, err := vpntest.NewTestPacketFromString(testStr) - if err != nil { - t.Errorf("Reordering: error reading test sequence: %v", err.Error()) - } + log.Debug("start collecting packets") - fmt.Printf("test packet: %v\n", testPkt) + got := make([]int, 0) - dataIn <- &model.Packet{ - Opcode: testPkt.Opcode, - RemoteSessionID: model.SessionID(sessionID), - ID: model.PacketID(testPkt.ID), + for { + // have we read enough packets to call it a day? + if len(got) >= len(tt.args.outputSequence) { + fmt.Println("we got enough packets!", got) + break } - log.Debugf("sleeping for %T(%v)", testPkt.IAT, testPkt.IAT) - time.Sleep(testPkt.IAT) + // no, so let's keep reading until the test runner kills us + pkt := <-ch + got = append(got, int(pkt.ID)) + log.Debugf("got packet: %v", pkt.ID) } - log.Info("test: done writing") - }() - - wg.Wait() - log.Debug("start collecting packets") - got := make([]int, 0) - - for i := 0; i < len(tt.args.outputSequence); i++ { - pkt := <-*dataOut - got = append(got, int(pkt.ID)) - log.Debugf("got packet: %v", pkt.ID) - } + // let's check if what we got is correct + if !slices.Equal(got, tt.args.outputSequence) { + t.Errorf("Reordering: got = %v, want %v", got, tt.args.outputSequence) + } + }(*dataOut) - if !slices.Equal(got, tt.args.outputSequence) { - t.Errorf("Reordering: got = %v, want %v", got, tt.args.outputSequence) - } + wg.Wait() + tt.args.workersManager.StartShutdown() + tt.args.workersManager.WaitWorkersShutdown() }) } } diff --git a/internal/reliabletransport/service.go b/internal/reliabletransport/service.go index 3197e686..b2a1bd75 100644 --- a/internal/reliabletransport/service.go +++ b/internal/reliabletransport/service.go @@ -39,7 +39,7 @@ func (s *Service) StartWorkers( logger: logger, // incomingSeen is a buffered channel to avoid losing packets if we're busy // processing in the sender goroutine. - incomingSeen: make(chan incomingPacketSeen, 20), + incomingSeen: make(chan incomingPacketSeen, 100), dataOrControlToMuxer: *s.DataOrControlToMuxer, controlToReliable: s.ControlToReliable, muxerToReliable: s.MuxerToReliable, From b18c3ee5d382d8af046ef2c3caf98d7fe16396a6 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Thu, 1 Feb 2024 19:53:06 +0100 Subject: [PATCH 58/78] fix bug in sender that breaks loop --- internal/reliabletransport/receiver.go | 4 +- internal/reliabletransport/reliable_test.go | 180 ++++++++------------ internal/reliabletransport/sender.go | 2 +- 3 files changed, 78 insertions(+), 108 deletions(-) diff --git a/internal/reliabletransport/receiver.go b/internal/reliabletransport/receiver.go index 02bf2402..6e0be275 100644 --- a/internal/reliabletransport/receiver.go +++ b/internal/reliabletransport/receiver.go @@ -80,14 +80,16 @@ func (ws *workersState) moveUpWorker() { } ready := receiver.NextIncomingSequence() - fmt.Println("< got next", ready) + fmt.Println("< got next packets", len(ready)) for _, nextPacket := range ready { + fmt.Println(">> WRITE UP", nextPacket.ID) // POSSIBLY BLOCK delivering to the upper layer select { case ws.reliableToControl <- nextPacket: fmt.Println("< wrote to control") case <-ws.workersManager.ShouldShutdown(): + fmt.Println(">> GOT SHUTDOWN SIGNAL") return } } diff --git a/internal/reliabletransport/reliable_test.go b/internal/reliabletransport/reliable_test.go index 9d9be9a4..108aa1d0 100644 --- a/internal/reliabletransport/reliable_test.go +++ b/internal/reliabletransport/reliable_test.go @@ -5,6 +5,7 @@ import ( "slices" "sync" "testing" + "time" "github.com/apex/log" "github.com/ooni/minivpn/internal/model" @@ -18,112 +19,94 @@ func TestReliable_Reordering_withWorkers(t *testing.T) { log.SetLevel(log.DebugLevel) - type fields struct { - DataOrControlToMuxer *chan *model.Packet - ControlToReliable chan *model.Packet - MuxerToReliable chan *model.Packet - ReliableToControl *chan *model.Packet - } type args struct { - logger model.Logger - workersManager *workers.Manager - sessionManager *session.Manager inputSequence []string outputSequence []int } - getFields := func() fields { - f := fields{ - DataOrControlToMuxer: func() *chan *model.Packet { - ch := make(chan *model.Packet) - return &ch - }(), - ControlToReliable: make(chan *model.Packet), - MuxerToReliable: make(chan *model.Packet, 1024), - ReliableToControl: func() *chan *model.Packet { - ch := make(chan *model.Packet, 1024) - return &ch - }(), - } - fmt.Println(":: muxer to reliable", len(f.MuxerToReliable)) - fmt.Println(":: reliable to control", len(f.MuxerToReliable)) - return f - } - - getArgs := func() args { - a := args{ - logger: log.Log, - workersManager: workers.NewManager(log.Log), - sessionManager: func() *session.Manager { - m, _ := session.NewManager(log.Log) - return m - }(), - inputSequence: []string{}, - outputSequence: []int{}, - } - return a - } tests := []struct { - name string - fields fields - args args + name string + args args }{ { - name: "test proper ordering for input sequence", - fields: getFields(), - args: func() args { - args := getArgs() - args.inputSequence = []string{ - "[1] CONTROL_V1 +1ms", - "[2] CONTROL_V1 +1ms", - "[3] CONTROL_V1 +1ms", - "[4] CONTROL_V1 +1ms", - } - args.outputSequence = []int{1, 2, 3, 4} - return args - }(), + name: "test proper ordering for input sequence", + args: args{ + inputSequence: []string{ + "[1] CONTROL_V1 +5ms", + "[2] CONTROL_V1 +5ms", + "[3] CONTROL_V1 +5ms", + "[4] CONTROL_V1 +5ms", + }, + outputSequence: []int{1, 2, 3, 4}, + }, }, - - // not yet! :) - - /* - { - name: "test reordering for input sequence", - fields: getFields(), - args: func() args { - args := getArgs() - args.inputSequence = []string{ - "[2] CONTROL_V1 +5ms", - "[4] CONTROL_V1 +5ms", - "[3] CONTROL_V1 +5ms", - "[1] CONTROL_V1 +5ms", - } - args.outputSequence = []int{1, 2, 3, 4} - return args - }(), + { + name: "test reordering for input sequence", + args: args{ + inputSequence: []string{ + "[2] CONTROL_V1 +5ms", + "[4] CONTROL_V1 +5ms", + "[3] CONTROL_V1 +5ms", + "[1] CONTROL_V1 +5ms", + }, + outputSequence: []int{1, 2, 3, 4}, }, - */ - - // TODO test duplicates + }, + { + name: "test reordering for input sequence, longer waits", + args: args{ + inputSequence: []string{ + "[2] CONTROL_V1 +5ms", + "[4] CONTROL_V1 +50ms", + "[3] CONTROL_V1 +100ms", + "[1] CONTROL_V1 +100ms", + }, + outputSequence: []int{1, 2, 3, 4}, + }, + }, + { + name: "test reordering for input sequence, with duplicates", + args: args{ + inputSequence: []string{ + "[2] CONTROL_V1 +5ms", + "[2] CONTROL_V1 +5ms", + "[4] CONTROL_V1 +5ms", + "[4] CONTROL_V1 +5ms", + "[4] CONTROL_V1 +5ms", + "[1] CONTROL_V1 +5ms", + "[3] CONTROL_V1 +5ms", + "[1] CONTROL_V1 +5ms", + }, + outputSequence: []int{1, 2, 3, 4}, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := &Service{ - DataOrControlToMuxer: tt.fields.DataOrControlToMuxer, - ControlToReliable: tt.fields.ControlToReliable, - MuxerToReliable: tt.fields.MuxerToReliable, - ReliableToControl: tt.fields.ReliableToControl, - } + dataToMuxer := make(chan *model.Packet) + // the only two channels we're going to be testing on this test - dataIn := tt.fields.MuxerToReliable - dataOut := tt.fields.ReliableToControl - sessionID := tt.args.sessionManager.LocalSessionID() + dataIn := make(chan *model.Packet, 1024) + dataOut := make(chan *model.Packet, 1024) + + workersManager := workers.NewManager(log.Log) + sessionManager, err := session.NewManager(log.Log) + if err != nil { + t.Errorf("Reordering: cannot create session.Manager: %v", err.Error()) + } - fmt.Println("") - fmt.Println(">> initial len DATAIN ", len(dataIn)) + s := &Service{ + DataOrControlToMuxer: nil, + ControlToReliable: make(chan *model.Packet), + MuxerToReliable: dataIn, + ReliableToControl: nil, + } + s.DataOrControlToMuxer = &dataToMuxer + s.ReliableToControl = &dataOut + sessionID := sessionManager.LocalSessionID() // let the workers pump up the jam! - s.StartWorkers(tt.args.logger, tt.args.workersManager, tt.args.sessionManager) + s.StartWorkers(log.Log, workersManager, sessionManager) for _, testStr := range tt.args.inputSequence { testPkt, err := vpntest.NewTestPacketFromString(testStr) @@ -131,35 +114,22 @@ func TestReliable_Reordering_withWorkers(t *testing.T) { t.Errorf("Reordering: error reading test sequence: %v", err.Error()) } - fmt.Printf("::: test packet: %v\n", testPkt) - p := &model.Packet{ Opcode: testPkt.Opcode, RemoteSessionID: model.SessionID(sessionID), ID: model.PacketID(testPkt.ID), } dataIn <- p - log.Infof("test: len write ch: %v", len(dataIn)) - // log.Debugf("sleeping for %T(%v)", testPkt.IAT, testPkt.IAT) - // time.Sleep(testPkt.IAT) - // time.Sleep(time.Millisecond) + time.Sleep(testPkt.IAT) } - log.Info("test: done writing") - log.Infof("test: len write ch: %v", len(dataIn)) - - fmt.Println("data out", len(*dataOut)) - - fmt.Println("s", s) // start the result collector in a different goroutine var wg sync.WaitGroup wg.Add(1) go func(ch <-chan *model.Packet) { defer wg.Done() - log.Debug("start collecting packets") got := make([]int, 0) - for { // have we read enough packets to call it a day? if len(got) >= len(tt.args.outputSequence) { @@ -176,11 +146,9 @@ func TestReliable_Reordering_withWorkers(t *testing.T) { if !slices.Equal(got, tt.args.outputSequence) { t.Errorf("Reordering: got = %v, want %v", got, tt.args.outputSequence) } - }(*dataOut) + }(dataOut) wg.Wait() - tt.args.workersManager.StartShutdown() - tt.args.workersManager.WaitWorkersShutdown() }) } } diff --git a/internal/reliabletransport/sender.go b/internal/reliabletransport/sender.go index 1fbfd209..53578c43 100644 --- a/internal/reliabletransport/sender.go +++ b/internal/reliabletransport/sender.go @@ -100,7 +100,7 @@ func (ws *workersState) moveDownWorker() { ACK, err := ws.sessionManager.NewACKForPacketIDs(sender.NextPacketIDsToACK()) if err != nil { ws.logger.Warnf("%s: cannot create ack: %v", workerName, err.Error()) - return + continue } ACK.Log(ws.logger, model.DirectionOutgoing) select { From a6cece1ef27308d9008c844f4ec72ac7d340dd97 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Thu, 1 Feb 2024 19:56:39 +0100 Subject: [PATCH 59/78] remove debug lines --- internal/reliabletransport/receiver.go | 15 --------------- internal/reliabletransport/reliable_test.go | 2 -- 2 files changed, 17 deletions(-) diff --git a/internal/reliabletransport/receiver.go b/internal/reliabletransport/receiver.go index 6e0be275..f980632f 100644 --- a/internal/reliabletransport/receiver.go +++ b/internal/reliabletransport/receiver.go @@ -39,8 +39,6 @@ func (ws *workersState) moveUpWorker() { packet.Log(ws.logger, model.DirectionIncoming) } - fmt.Println("< from muxer", packet.ID) - // TODO: are we handling a HARD_RESET_V2 while we're doing a handshake? // I'm not sure that's a valid behavior for a server. // We should be able to deterministically test how this affects the state machine. @@ -58,47 +56,34 @@ func (ws *workersState) moveUpWorker() { continue } - fmt.Println("< create seen") seen := receiver.newIncomingPacketSeen(packet) ws.incomingSeen <- seen - fmt.Println("< wrote to seen ch") // TODO(ainghazal): drop a packet that is a replay (id <= lastConsumed, but != ACK...?) // we only want to insert control packets going to the tls layer if packet.Opcode != model.P_CONTROL_V1 { - fmt.Println("< not a control!!") continue } if inserted := receiver.MaybeInsertIncoming(packet); !inserted { // this packet was not inserted in the queue: we drop it - fmt.Println("< droppin!!") ws.logger.Debugf("Dropping packet: %v", packet.ID) continue } ready := receiver.NextIncomingSequence() - fmt.Println("< got next packets", len(ready)) for _, nextPacket := range ready { - fmt.Println(">> WRITE UP", nextPacket.ID) // POSSIBLY BLOCK delivering to the upper layer select { case ws.reliableToControl <- nextPacket: - fmt.Println("< wrote to control") case <-ws.workersManager.ShouldShutdown(): - fmt.Println(">> GOT SHUTDOWN SIGNAL") return } } - fmt.Println("< DONE, END LOOP") - fmt.Println("< incomingSeen:", len(ws.incomingSeen)) - fmt.Println("< muxerToReliable:", len(ws.muxerToReliable)) - fmt.Println("") - case <-ws.workersManager.ShouldShutdown(): return } diff --git a/internal/reliabletransport/reliable_test.go b/internal/reliabletransport/reliable_test.go index 108aa1d0..39106bd1 100644 --- a/internal/reliabletransport/reliable_test.go +++ b/internal/reliabletransport/reliable_test.go @@ -1,7 +1,6 @@ package reliabletransport import ( - "fmt" "slices" "sync" "testing" @@ -133,7 +132,6 @@ func TestReliable_Reordering_withWorkers(t *testing.T) { for { // have we read enough packets to call it a day? if len(got) >= len(tt.args.outputSequence) { - fmt.Println("we got enough packets!", got) break } // no, so let's keep reading until the test runner kills us From 5a961eb3a65c23a1244632a8fdc54c1dc316e440 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Thu, 1 Feb 2024 20:11:16 +0100 Subject: [PATCH 60/78] cleanup test a bit --- internal/reliabletransport/reliable_test.go | 75 +++++++++++---------- 1 file changed, 39 insertions(+), 36 deletions(-) diff --git a/internal/reliabletransport/reliable_test.go b/internal/reliabletransport/reliable_test.go index 39106bd1..8cb31932 100644 --- a/internal/reliabletransport/reliable_test.go +++ b/internal/reliabletransport/reliable_test.go @@ -13,6 +13,15 @@ import ( "github.com/ooni/minivpn/internal/workers" ) +func initManagers() (*workers.Manager, *session.Manager) { + w := workers.NewManager(log.Log) + s, err := session.NewManager(log.Log) + if err != nil { + panic(err) + } + return w, s +} + // test that we're able to reorder (towards TLS) whatever is received (from the muxer). func TestReliable_Reordering_withWorkers(t *testing.T) { @@ -31,10 +40,10 @@ func TestReliable_Reordering_withWorkers(t *testing.T) { name: "test proper ordering for input sequence", args: args{ inputSequence: []string{ - "[1] CONTROL_V1 +5ms", - "[2] CONTROL_V1 +5ms", - "[3] CONTROL_V1 +5ms", - "[4] CONTROL_V1 +5ms", + "[1] CONTROL_V1 +1ms", + "[2] CONTROL_V1 +1ms", + "[3] CONTROL_V1 +1ms", + "[4] CONTROL_V1 +1ms", }, outputSequence: []int{1, 2, 3, 4}, }, @@ -43,10 +52,10 @@ func TestReliable_Reordering_withWorkers(t *testing.T) { name: "test reordering for input sequence", args: args{ inputSequence: []string{ - "[2] CONTROL_V1 +5ms", - "[4] CONTROL_V1 +5ms", - "[3] CONTROL_V1 +5ms", - "[1] CONTROL_V1 +5ms", + "[2] CONTROL_V1 +1ms", + "[4] CONTROL_V1 +1ms", + "[3] CONTROL_V1 +1ms", + "[1] CONTROL_V1 +1ms", }, outputSequence: []int{1, 2, 3, 4}, }, @@ -56,9 +65,9 @@ func TestReliable_Reordering_withWorkers(t *testing.T) { args: args{ inputSequence: []string{ "[2] CONTROL_V1 +5ms", - "[4] CONTROL_V1 +50ms", - "[3] CONTROL_V1 +100ms", - "[1] CONTROL_V1 +100ms", + "[4] CONTROL_V1 +10ms", + "[3] CONTROL_V1 +1ms", + "[1] CONTROL_V1 +50ms", }, outputSequence: []int{1, 2, 3, 4}, }, @@ -67,14 +76,14 @@ func TestReliable_Reordering_withWorkers(t *testing.T) { name: "test reordering for input sequence, with duplicates", args: args{ inputSequence: []string{ - "[2] CONTROL_V1 +5ms", - "[2] CONTROL_V1 +5ms", - "[4] CONTROL_V1 +5ms", - "[4] CONTROL_V1 +5ms", - "[4] CONTROL_V1 +5ms", - "[1] CONTROL_V1 +5ms", - "[3] CONTROL_V1 +5ms", - "[1] CONTROL_V1 +5ms", + "[2] CONTROL_V1 +1ms", + "[2] CONTROL_V1 +1ms", + "[4] CONTROL_V1 +1ms", + "[4] CONTROL_V1 +1ms", + "[4] CONTROL_V1 +1ms", + "[1] CONTROL_V1 +1ms", + "[3] CONTROL_V1 +1ms", + "[1] CONTROL_V1 +1ms", }, outputSequence: []int{1, 2, 3, 4}, }, @@ -82,30 +91,26 @@ func TestReliable_Reordering_withWorkers(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + + s := &Service{} + + // just to properly initialize it, we don't care about these + s.ControlToReliable = make(chan *model.Packet) dataToMuxer := make(chan *model.Packet) + s.DataOrControlToMuxer = &dataToMuxer // the only two channels we're going to be testing on this test dataIn := make(chan *model.Packet, 1024) dataOut := make(chan *model.Packet, 1024) - workersManager := workers.NewManager(log.Log) - sessionManager, err := session.NewManager(log.Log) - if err != nil { - t.Errorf("Reordering: cannot create session.Manager: %v", err.Error()) - } - - s := &Service{ - DataOrControlToMuxer: nil, - ControlToReliable: make(chan *model.Packet), - MuxerToReliable: dataIn, - ReliableToControl: nil, - } - s.DataOrControlToMuxer = &dataToMuxer + s.MuxerToReliable = dataIn s.ReliableToControl = &dataOut - sessionID := sessionManager.LocalSessionID() + + workers, session := initManagers() + sessionID := session.LocalSessionID() // let the workers pump up the jam! - s.StartWorkers(log.Log, workersManager, sessionManager) + s.StartWorkers(log.Log, workers, session) for _, testStr := range tt.args.inputSequence { testPkt, err := vpntest.NewTestPacketFromString(testStr) @@ -127,7 +132,6 @@ func TestReliable_Reordering_withWorkers(t *testing.T) { wg.Add(1) go func(ch <-chan *model.Packet) { defer wg.Done() - got := make([]int, 0) for { // have we read enough packets to call it a day? @@ -145,7 +149,6 @@ func TestReliable_Reordering_withWorkers(t *testing.T) { t.Errorf("Reordering: got = %v, want %v", got, tt.args.outputSequence) } }(dataOut) - wg.Wait() }) } From 6abd35d1b8b48e9f94ae631f70b0834505a4e29d Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Fri, 2 Feb 2024 01:18:24 +0100 Subject: [PATCH 61/78] add vpntest module --- internal/reliabletransport/reliable_test.go | 63 +++++------- internal/vpntest/packetio.go | 103 ++++++++++++++++++++ internal/vpntest/vpntest.go | 56 +++++++++++ internal/vpntest/vpntest_test.go | 45 +++++++++ 4 files changed, 228 insertions(+), 39 deletions(-) create mode 100644 internal/vpntest/packetio.go create mode 100644 internal/vpntest/vpntest.go create mode 100644 internal/vpntest/vpntest_test.go diff --git a/internal/reliabletransport/reliable_test.go b/internal/reliabletransport/reliable_test.go index 8cb31932..39ad35c3 100644 --- a/internal/reliabletransport/reliable_test.go +++ b/internal/reliabletransport/reliable_test.go @@ -1,8 +1,6 @@ package reliabletransport import ( - "slices" - "sync" "testing" "time" @@ -88,10 +86,23 @@ func TestReliable_Reordering_withWorkers(t *testing.T) { outputSequence: []int{1, 2, 3, 4}, }, }, + { + name: "reordering with acks interspersed", + args: args{ + inputSequence: []string{ + "[2] CONTROL_V1 +5ms", + "[4] CONTROL_V1 +2ms", + "[0] ACK_V1 +1ms", + "[3] CONTROL_V1 +1ms", + "[0] ACK_V1 +1ms", + "[1] CONTROL_V1 +2ms", + }, + outputSequence: []int{1, 2, 3, 4}, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := &Service{} // just to properly initialize it, we don't care about these @@ -100,6 +111,7 @@ func TestReliable_Reordering_withWorkers(t *testing.T) { s.DataOrControlToMuxer = &dataToMuxer // the only two channels we're going to be testing on this test + // we want to buffer enough to be safe writing to them. dataIn := make(chan *model.Packet, 1024) dataOut := make(chan *model.Packet, 1024) @@ -109,47 +121,20 @@ func TestReliable_Reordering_withWorkers(t *testing.T) { workers, session := initManagers() sessionID := session.LocalSessionID() + t0 := time.Now() + // let the workers pump up the jam! s.StartWorkers(log.Log, workers, session) - for _, testStr := range tt.args.inputSequence { - testPkt, err := vpntest.NewTestPacketFromString(testStr) - if err != nil { - t.Errorf("Reordering: error reading test sequence: %v", err.Error()) - } + writer := vpntest.NewPacketWriter(dataIn) + writer.LocalSessionID = model.SessionID(sessionID) + go writer.WriteSequence(tt.args.inputSequence) - p := &model.Packet{ - Opcode: testPkt.Opcode, - RemoteSessionID: model.SessionID(sessionID), - ID: model.PacketID(testPkt.ID), - } - dataIn <- p - time.Sleep(testPkt.IAT) + reader := vpntest.NewPacketReader(dataOut) + if ok := reader.WaitForSequence(tt.args.outputSequence, t0); !ok { + got := vpntest.PacketLog(reader.ReceivedSequence()).IDSequence() + t.Errorf("Reordering: got = %v, want %v", got, tt.args.outputSequence) } - - // start the result collector in a different goroutine - var wg sync.WaitGroup - wg.Add(1) - go func(ch <-chan *model.Packet) { - defer wg.Done() - got := make([]int, 0) - for { - // have we read enough packets to call it a day? - if len(got) >= len(tt.args.outputSequence) { - break - } - // no, so let's keep reading until the test runner kills us - pkt := <-ch - got = append(got, int(pkt.ID)) - log.Debugf("got packet: %v", pkt.ID) - } - - // let's check if what we got is correct - if !slices.Equal(got, tt.args.outputSequence) { - t.Errorf("Reordering: got = %v, want %v", got, tt.args.outputSequence) - } - }(dataOut) - wg.Wait() }) } } diff --git a/internal/vpntest/packetio.go b/internal/vpntest/packetio.go new file mode 100644 index 00000000..c08b53fc --- /dev/null +++ b/internal/vpntest/packetio.go @@ -0,0 +1,103 @@ +package vpntest + +import ( + "slices" + "time" + + "github.com/apex/log" + "github.com/ooni/minivpn/internal/model" +) + +// PacketWriter is a service that writes packets into a channel. +type PacketWriter struct { + // A channel where to write packets to. + ch chan<- *model.Packet + + // LocalSessionID is needed to produce packets that pass sanity checks. + LocalSessionID model.SessionID +} + +// NewPacketWriter creates a new PacketWriter. +func NewPacketWriter(ch chan<- *model.Packet) *PacketWriter { + return &PacketWriter{ch: ch} +} + +// WriteSequence writes the passed packet sequence (in their string representation) +// to the configured channel. It will wait the specified interval between one packet and the next. +func (pw *PacketWriter) WriteSequence(seq []string) { + for _, testStr := range seq { + testPkt, err := NewTestPacketFromString(testStr) + if err != nil { + panic("PacketWriter: error reading test sequence:" + err.Error()) + } + + p := &model.Packet{ + Opcode: testPkt.Opcode, + RemoteSessionID: pw.LocalSessionID, + ID: model.PacketID(testPkt.ID), + } + pw.ch <- p + time.Sleep(testPkt.IAT) + } +} + +// LoggedPacket is a trace of a received packet. +type LoggedPacket struct { + ID int + Opcode model.Opcode + + At time.Duration +} + +// PacketLog is a sequence of LoggedPacket. +type PacketLog []*LoggedPacket + +// IDSequence returns a sequence of int from the logged packets. +func (l PacketLog) IDSequence() []int { + ids := make([]int, 0) + for _, p := range l { + ids = append(ids, int(p.ID)) + } + return ids +} + +// PacketReader reads packets from a channel. +type PacketReader struct { + ch <-chan *model.Packet + got []*LoggedPacket +} + +// NewPacketReader creates a new PacketReader. +func NewPacketReader(ch <-chan *model.Packet) *PacketReader { + return &PacketReader{ch: ch} +} + +// WaitForSequence blocks forever reading from the internal channel until the obtained +// sequence matches the len of the expected; it stores the received sequence and then returns +// true if the obtained packet ID sequence matches the expected one. +func (pr *PacketReader) WaitForSequence(seq []int, start time.Time) bool { + got := make([]*LoggedPacket, 0) + for { + // have we read enough packets to call it a day? + if len(got) >= len(seq) { + break + } + // no, so let's keep reading until the test runner kills us + pkt := <-pr.ch + got = append( + got, + &LoggedPacket{ + ID: int(pkt.ID), + Opcode: pkt.Opcode, + At: time.Since(start), + }) + log.Debugf("got packet: %v", pkt.ID) + } + pr.got = got + return slices.Equal(seq, PacketLog(got).IDSequence()) +} + +// ReceivedSequence returns the log of the received sequence. +func (pr *PacketReader) ReceivedSequence() []*LoggedPacket { + return pr.got +} diff --git a/internal/vpntest/vpntest.go b/internal/vpntest/vpntest.go new file mode 100644 index 00000000..a6382874 --- /dev/null +++ b/internal/vpntest/vpntest.go @@ -0,0 +1,56 @@ +// Package vpntest provides utilities for minivpn testing. +package vpntest + +import ( + "fmt" + "strconv" + "strings" + "time" + + "github.com/ooni/minivpn/internal/model" +) + +// TestPacket is used to simulate incoming packets over the network. The goal is to be able to +// have a compact representation of a sequence of packets, their type, and extra properties like +// inter-arrival time. +type TestPacket struct { + // ID is the packet sequence + ID int + + // Opcode is the OpenVPN packet opcode. + Opcode model.Opcode + + // IAT is the inter-arrival time until the next packet is received. + IAT time.Duration +} + +// the test packet string is in the form: +// "[ID] OPCODE +42ms" +func NewTestPacketFromString(s string) (*TestPacket, error) { + parts := strings.Split(s, " +") + + // Extracting id and opcode parts + idAndOpcode := strings.Split(parts[0], " ") + if len(idAndOpcode) != 2 { + return nil, fmt.Errorf("invalid format for ID and opcode: %s", parts[0]) + } + + id, err := strconv.Atoi(strings.Trim(idAndOpcode[0], "[]")) + if err != nil { + return nil, fmt.Errorf("failed to parse id: %v", err) + } + + opcode, err := model.NewOpcodeFromString(idAndOpcode[1]) + if err != nil { + return nil, fmt.Errorf("failed to parse opcode: %v", err) + } + + // Parsing duration part + iatStr := parts[1] + iat, err := time.ParseDuration(iatStr) + if err != nil { + return nil, fmt.Errorf("failed to parse duration: %v", err) + } + + return &TestPacket{ID: id, Opcode: opcode, IAT: iat}, nil +} diff --git a/internal/vpntest/vpntest_test.go b/internal/vpntest/vpntest_test.go new file mode 100644 index 00000000..f8c04f68 --- /dev/null +++ b/internal/vpntest/vpntest_test.go @@ -0,0 +1,45 @@ +// Package vpntest provides utilities for minivpn testing. +package vpntest + +import ( + "reflect" + "testing" + "time" + + "github.com/ooni/minivpn/internal/model" +) + +func TestNewTestPacketFromString(t *testing.T) { + type args struct { + s string + } + tests := []struct { + name string + args args + want *TestPacket + wantErr bool + }{ + { + name: "parse a correct testpacket string", + args: args{"[1] CONTROL_V1 +42ms"}, + want: &TestPacket{ + ID: 1, + Opcode: model.P_CONTROL_V1, + IAT: time.Millisecond * 42, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewTestPacketFromString(tt.args.s) + if (err != nil) != tt.wantErr { + t.Errorf("NewTestPacketFromString() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewTestPacketFromString() = %v, want %v", got, tt.want) + } + }) + } +} From bf7b83b589a2f733dd0b13ba60b222d53b7828b1 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Fri, 2 Feb 2024 01:26:55 +0100 Subject: [PATCH 62/78] cosmetic changes --- ...iable_test.go => reliable_reorder_test.go} | 7 +++-- internal/vpntest/packetio.go | 29 +++++++++---------- 2 files changed, 18 insertions(+), 18 deletions(-) rename internal/reliabletransport/{reliable_test.go => reliable_reorder_test.go} (94%) diff --git a/internal/reliabletransport/reliable_test.go b/internal/reliabletransport/reliable_reorder_test.go similarity index 94% rename from internal/reliabletransport/reliable_test.go rename to internal/reliabletransport/reliable_reorder_test.go index 39ad35c3..a9879672 100644 --- a/internal/reliabletransport/reliable_test.go +++ b/internal/reliabletransport/reliable_reorder_test.go @@ -11,6 +11,7 @@ import ( "github.com/ooni/minivpn/internal/workers" ) +// initManagers initializes a workers manager and a session manager func initManagers() (*workers.Manager, *session.Manager) { w := workers.NewManager(log.Log) s, err := session.NewManager(log.Log) @@ -21,7 +22,7 @@ func initManagers() (*workers.Manager, *session.Manager) { } // test that we're able to reorder (towards TLS) whatever is received (from the muxer). -func TestReliable_Reordering_withWorkers(t *testing.T) { +func TestReliable_Reordering_UP(t *testing.T) { log.SetLevel(log.DebugLevel) @@ -35,7 +36,7 @@ func TestReliable_Reordering_withWorkers(t *testing.T) { args args }{ { - name: "test proper ordering for input sequence", + name: "test wil well-ordered input sequence", args: args{ inputSequence: []string{ "[1] CONTROL_V1 +1ms", @@ -132,7 +133,7 @@ func TestReliable_Reordering_withWorkers(t *testing.T) { reader := vpntest.NewPacketReader(dataOut) if ok := reader.WaitForSequence(tt.args.outputSequence, t0); !ok { - got := vpntest.PacketLog(reader.ReceivedSequence()).IDSequence() + got := reader.Log().IDSequence() t.Errorf("Reordering: got = %v, want %v", got, tt.args.outputSequence) } }) diff --git a/internal/vpntest/packetio.go b/internal/vpntest/packetio.go index c08b53fc..28b7abd3 100644 --- a/internal/vpntest/packetio.go +++ b/internal/vpntest/packetio.go @@ -8,7 +8,7 @@ import ( "github.com/ooni/minivpn/internal/model" ) -// PacketWriter is a service that writes packets into a channel. +// PacketWriter writes packets into a channel. type PacketWriter struct { // A channel where to write packets to. ch chan<- *model.Packet @@ -45,8 +45,7 @@ func (pw *PacketWriter) WriteSequence(seq []string) { type LoggedPacket struct { ID int Opcode model.Opcode - - At time.Duration + At time.Duration } // PacketLog is a sequence of LoggedPacket. @@ -64,7 +63,7 @@ func (l PacketLog) IDSequence() []int { // PacketReader reads packets from a channel. type PacketReader struct { ch <-chan *model.Packet - got []*LoggedPacket + log []*LoggedPacket } // NewPacketReader creates a new PacketReader. @@ -72,20 +71,20 @@ func NewPacketReader(ch <-chan *model.Packet) *PacketReader { return &PacketReader{ch: ch} } -// WaitForSequence blocks forever reading from the internal channel until the obtained -// sequence matches the len of the expected; it stores the received sequence and then returns +// WaitForSequence loops reading from the internal channel until the logged +// sequence matches the len of the expected sequence; it returns // true if the obtained packet ID sequence matches the expected one. func (pr *PacketReader) WaitForSequence(seq []int, start time.Time) bool { - got := make([]*LoggedPacket, 0) + logged := make([]*LoggedPacket, 0) for { // have we read enough packets to call it a day? - if len(got) >= len(seq) { + if len(logged) >= len(seq) { break } // no, so let's keep reading until the test runner kills us pkt := <-pr.ch - got = append( - got, + logged = append( + logged, &LoggedPacket{ ID: int(pkt.ID), Opcode: pkt.Opcode, @@ -93,11 +92,11 @@ func (pr *PacketReader) WaitForSequence(seq []int, start time.Time) bool { }) log.Debugf("got packet: %v", pkt.ID) } - pr.got = got - return slices.Equal(seq, PacketLog(got).IDSequence()) + pr.log = logged + return slices.Equal(seq, PacketLog(logged).IDSequence()) } -// ReceivedSequence returns the log of the received sequence. -func (pr *PacketReader) ReceivedSequence() []*LoggedPacket { - return pr.got +// Log returns the log of the received packets. +func (pr *PacketReader) Log() PacketLog { + return PacketLog(pr.log) } From 39bc2227f01a257ef89522a86af1e1f9dd981f12 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Fri, 2 Feb 2024 01:45:57 +0100 Subject: [PATCH 63/78] test packetio --- .../reliable_reorder_test.go | 2 +- internal/vpntest/packetio.go | 13 ++- internal/vpntest/packetio_test.go | 81 +++++++++++++++++++ 3 files changed, 88 insertions(+), 8 deletions(-) create mode 100644 internal/vpntest/packetio_test.go diff --git a/internal/reliabletransport/reliable_reorder_test.go b/internal/reliabletransport/reliable_reorder_test.go index a9879672..0d94a626 100644 --- a/internal/reliabletransport/reliable_reorder_test.go +++ b/internal/reliabletransport/reliable_reorder_test.go @@ -36,7 +36,7 @@ func TestReliable_Reordering_UP(t *testing.T) { args args }{ { - name: "test wil well-ordered input sequence", + name: "test with a well-ordered input sequence", args: args{ inputSequence: []string{ "[1] CONTROL_V1 +1ms", diff --git a/internal/vpntest/packetio.go b/internal/vpntest/packetio.go index 28b7abd3..5f834c69 100644 --- a/internal/vpntest/packetio.go +++ b/internal/vpntest/packetio.go @@ -68,23 +68,23 @@ type PacketReader struct { // NewPacketReader creates a new PacketReader. func NewPacketReader(ch <-chan *model.Packet) *PacketReader { - return &PacketReader{ch: ch} + logged := make([]*LoggedPacket, 0) + return &PacketReader{ch: ch, log: logged} } // WaitForSequence loops reading from the internal channel until the logged // sequence matches the len of the expected sequence; it returns // true if the obtained packet ID sequence matches the expected one. func (pr *PacketReader) WaitForSequence(seq []int, start time.Time) bool { - logged := make([]*LoggedPacket, 0) for { // have we read enough packets to call it a day? - if len(logged) >= len(seq) { + if len(pr.log) >= len(seq) { break } // no, so let's keep reading until the test runner kills us pkt := <-pr.ch - logged = append( - logged, + pr.log = append( + pr.log, &LoggedPacket{ ID: int(pkt.ID), Opcode: pkt.Opcode, @@ -92,8 +92,7 @@ func (pr *PacketReader) WaitForSequence(seq []int, start time.Time) bool { }) log.Debugf("got packet: %v", pkt.ID) } - pr.log = logged - return slices.Equal(seq, PacketLog(logged).IDSequence()) + return slices.Equal(seq, PacketLog(pr.log).IDSequence()) } // Log returns the log of the received packets. diff --git a/internal/vpntest/packetio_test.go b/internal/vpntest/packetio_test.go new file mode 100644 index 00000000..3a312c7f --- /dev/null +++ b/internal/vpntest/packetio_test.go @@ -0,0 +1,81 @@ +package vpntest + +import ( + "testing" + "time" + + "github.com/ooni/minivpn/internal/model" +) + +func TestPacketReaderWriter(t *testing.T) { + type args struct { + input []string + output []int + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "simple input, simple output", + args: args{ + input: []string{ + "[1] CONTROL_V1 +0ms", + "[2] CONTROL_V1 +0ms", + "[3] CONTROL_V1 +0ms", + }, + output: []int{1, 2, 3}, + }, + want: true, + }, + { + name: "reverse in, reverse out", + args: args{ + input: []string{ + "[3] CONTROL_V1 +0ms", + "[2] CONTROL_V1 +0ms", + "[1] CONTROL_V1 +0ms", + }, + output: []int{3, 2, 1}, + }, + want: true, + }, + { + name: "holes in, holes out", + args: args{ + input: []string{ + "[0] CONTROL_V1 +0ms", + "[10] CONTROL_V1 +0ms", + "[1] CONTROL_V1 +0ms", + "[20] CONTROL_V1 +0ms", + }, + output: []int{0, 10, 1, 20}, + }, + want: true, + }, + { + name: "mismatch returns false", + args: args{ + input: []string{ + "[0] CONTROL_V1 +0ms", + "[1] CONTROL_V1 +0ms", + }, + output: []int{1, 0}, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ch := make(chan *model.Packet) + writer := NewPacketWriter(ch) + go writer.WriteSequence(tt.args.input) + reader := NewPacketReader(ch) + if ok := reader.WaitForSequence(tt.args.output, time.Now()); ok != tt.want { + got := reader.Log().IDSequence() + t.Errorf("PacketReader.WaitForSequence() = %v, want %v", got, tt.args.output) + } + }) + } +} From 0bcb8dec003129a035c2079b48f062707e8ad205 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Fri, 2 Feb 2024 12:27:34 +0100 Subject: [PATCH 64/78] parse acks --- internal/vpntest/vpntest.go | 50 ++++++++++++++++++++++++++------ internal/vpntest/vpntest_test.go | 23 +++++++++++++++ 2 files changed, 64 insertions(+), 9 deletions(-) diff --git a/internal/vpntest/vpntest.go b/internal/vpntest/vpntest.go index a6382874..bdf67fce 100644 --- a/internal/vpntest/vpntest.go +++ b/internal/vpntest/vpntest.go @@ -2,6 +2,7 @@ package vpntest import ( + "errors" "fmt" "strconv" "strings" @@ -14,11 +15,14 @@ import ( // have a compact representation of a sequence of packets, their type, and extra properties like // inter-arrival time. type TestPacket struct { + // Opcode is the OpenVPN packet opcode. + Opcode model.Opcode + // ID is the packet sequence ID int - // Opcode is the OpenVPN packet opcode. - Opcode model.Opcode + // ACKs is the ack array in this packet + ACKs []int // IAT is the inter-arrival time until the next packet is received. IAT time.Duration @@ -29,22 +33,32 @@ type TestPacket struct { func NewTestPacketFromString(s string) (*TestPacket, error) { parts := strings.Split(s, " +") - // Extracting id and opcode parts - idAndOpcode := strings.Split(parts[0], " ") - if len(idAndOpcode) != 2 { - return nil, fmt.Errorf("invalid format for ID and opcode: %s", parts[0]) + // Extracting id, opcode and ack parts + head := strings.Split(parts[0], " ") + if len(head) < 2 || len(head) > 3 { + return nil, fmt.Errorf("invalid format for ID-op-acks: %s", parts[0]) } - id, err := strconv.Atoi(strings.Trim(idAndOpcode[0], "[]")) + id, err := strconv.Atoi(strings.Trim(head[0], "[]")) if err != nil { return nil, fmt.Errorf("failed to parse id: %v", err) } - opcode, err := model.NewOpcodeFromString(idAndOpcode[1]) + opcode, err := model.NewOpcodeFromString(head[1]) if err != nil { return nil, fmt.Errorf("failed to parse opcode: %v", err) } + acks := []int{} + + if len(head) == 3 { + acks, err = parseACKs(strings.Trim(head[2], "()")) + fmt.Println("acks:", acks) + if err != nil { + return nil, fmt.Errorf("failed to parse opcode: %v", err) + } + } + // Parsing duration part iatStr := parts[1] iat, err := time.ParseDuration(iatStr) @@ -52,5 +66,23 @@ func NewTestPacketFromString(s string) (*TestPacket, error) { return nil, fmt.Errorf("failed to parse duration: %v", err) } - return &TestPacket{ID: id, Opcode: opcode, IAT: iat}, nil + return &TestPacket{ID: id, Opcode: opcode, ACKs: acks, IAT: iat}, nil +} + +var errBadACK = errors.New("wrong ack string") + +func parseACKs(s string) ([]int, error) { + acks := []int{} + h := strings.Split(s, "ack:") + if len(h) != 2 { + return acks, errBadACK + } + values := strings.Split(h[1], ",") + for _, v := range values { + n, err := strconv.Atoi(v) + if err == nil { + acks = append(acks, n) + } + } + return acks, nil } diff --git a/internal/vpntest/vpntest_test.go b/internal/vpntest/vpntest_test.go index f8c04f68..f606bd08 100644 --- a/internal/vpntest/vpntest_test.go +++ b/internal/vpntest/vpntest_test.go @@ -25,6 +25,29 @@ func TestNewTestPacketFromString(t *testing.T) { want: &TestPacket{ ID: 1, Opcode: model.P_CONTROL_V1, + ACKs: []int{}, + IAT: time.Millisecond * 42, + }, + wantErr: false, + }, + { + name: "parse a testpacket with acks", + args: args{"[1] CONTROL_V1 (ack:0,1) +42ms"}, + want: &TestPacket{ + ID: 1, + Opcode: model.P_CONTROL_V1, + ACKs: []int{0, 1}, + IAT: time.Millisecond * 42, + }, + wantErr: false, + }, + { + name: "empty acks part", + args: args{"[1] CONTROL_V1 (ack:) +42ms"}, + want: &TestPacket{ + ID: 1, + Opcode: model.P_CONTROL_V1, + ACKs: []int{}, IAT: time.Millisecond * 42, }, wantErr: false, From d85275c589d4780938f89c1734db429e68dafd29 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Fri, 2 Feb 2024 13:40:18 +0100 Subject: [PATCH 65/78] wip testing acks --- .../reliabletransport/reliable_ack_test.go | 91 +++++++++++++++++++ .../reliable_reorder_test.go | 23 ++--- internal/reliabletransport/tests.go | 21 +++++ internal/vpntest/packetio.go | 79 ++++++++++++++-- 4 files changed, 191 insertions(+), 23 deletions(-) create mode 100644 internal/reliabletransport/reliable_ack_test.go create mode 100644 internal/reliabletransport/tests.go diff --git a/internal/reliabletransport/reliable_ack_test.go b/internal/reliabletransport/reliable_ack_test.go new file mode 100644 index 00000000..88664cf7 --- /dev/null +++ b/internal/reliabletransport/reliable_ack_test.go @@ -0,0 +1,91 @@ +package reliabletransport + +import ( + "testing" + "time" + + "github.com/apex/log" + "github.com/ooni/minivpn/internal/model" + "github.com/ooni/minivpn/internal/vpntest" +) + +// test that everything that is received from below is eventually ACKed to the sender. +func TestReliable_ACK(t *testing.T) { + + log.SetLevel(log.DebugLevel) + + type args struct { + inputSequence []string + wantacks int + } + + tests := []struct { + name string + args args + }{ + { + name: "ten ordered packets in", + args: args{ + inputSequence: []string{ + "[1] CONTROL_V1 +1ms", + "[2] CONTROL_V1 +1ms", + "[3] CONTROL_V1 +1ms", + "[4] CONTROL_V1 +1ms", + "[5] CONTROL_V1 +1ms", + "[6] CONTROL_V1 +1ms", + "[7] CONTROL_V1 +1ms", + "[8] CONTROL_V1 +1ms", + "[9] CONTROL_V1 +1ms", + "[10] CONTROL_V1 +1ms", + }, + wantacks: 10, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Service{} + + // just to properly initialize it, we don't care about these + s.ControlToReliable = make(chan *model.Packet) + reliableToControl := make(chan *model.Packet) + s.ReliableToControl = &reliableToControl + + // the only two channels we're going to be testing on this test + // we want to buffer enough to be safe writing to them. + dataIn := make(chan *model.Packet, 1024) + dataOut := make(chan *model.Packet, 1024) + + s.MuxerToReliable = dataIn // up + s.DataOrControlToMuxer = &dataOut // down + + workers, session := initManagers() + + // this is our session (local to us) + localSessionID := session.LocalSessionID() + remoteSessionID := session.RemoteSessionID() + + t0 := time.Now() + + // let the workers pump up the jam! + s.StartWorkers(log.Log, workers, session) + + writer := vpntest.NewPacketWriter(dataIn) + + // TODO -- need to create a session + writer.LocalSessionID = model.SessionID(remoteSessionID) + writer.RemoteSessionID = model.SessionID(localSessionID) + + go writer.WriteSequence(tt.args.inputSequence) + + reader := vpntest.NewPacketReader(dataOut) + witness := vpntest.NewWitness(reader) + + if ok := witness.VerifyACKs(tt.args.wantacks, t0); !ok { + //log.Debug(witness.Log()) + got := witness.NumberOfACKs() + t.Errorf("Reordering: got = %v, want %v", got, tt.args.wantacks) + } + }) + } +} diff --git a/internal/reliabletransport/reliable_reorder_test.go b/internal/reliabletransport/reliable_reorder_test.go index 0d94a626..c5d9b2ea 100644 --- a/internal/reliabletransport/reliable_reorder_test.go +++ b/internal/reliabletransport/reliable_reorder_test.go @@ -6,21 +6,9 @@ import ( "github.com/apex/log" "github.com/ooni/minivpn/internal/model" - "github.com/ooni/minivpn/internal/session" "github.com/ooni/minivpn/internal/vpntest" - "github.com/ooni/minivpn/internal/workers" ) -// initManagers initializes a workers manager and a session manager -func initManagers() (*workers.Manager, *session.Manager) { - w := workers.NewManager(log.Log) - s, err := session.NewManager(log.Log) - if err != nil { - panic(err) - } - return w, s -} - // test that we're able to reorder (towards TLS) whatever is received (from the muxer). func TestReliable_Reordering_UP(t *testing.T) { @@ -120,7 +108,10 @@ func TestReliable_Reordering_UP(t *testing.T) { s.ReliableToControl = &dataOut workers, session := initManagers() - sessionID := session.LocalSessionID() + + // this is our session (local to us) + localSessionID := session.LocalSessionID() + remoteSessionID := session.RemoteSessionID() t0 := time.Now() @@ -128,7 +119,11 @@ func TestReliable_Reordering_UP(t *testing.T) { s.StartWorkers(log.Log, workers, session) writer := vpntest.NewPacketWriter(dataIn) - writer.LocalSessionID = model.SessionID(sessionID) + + // TODO -- need to create a session + writer.LocalSessionID = model.SessionID(remoteSessionID) + writer.RemoteSessionID = model.SessionID(localSessionID) + go writer.WriteSequence(tt.args.inputSequence) reader := vpntest.NewPacketReader(dataOut) diff --git a/internal/reliabletransport/tests.go b/internal/reliabletransport/tests.go new file mode 100644 index 00000000..61472042 --- /dev/null +++ b/internal/reliabletransport/tests.go @@ -0,0 +1,21 @@ +package reliabletransport + +import ( + "github.com/apex/log" + "github.com/ooni/minivpn/internal/session" + "github.com/ooni/minivpn/internal/workers" +) + +// +// Common utilities for tests in this package. +// + +// initManagers initializes a workers manager and a session manager +func initManagers() (*workers.Manager, *session.Manager) { + w := workers.NewManager(log.Log) + s, err := session.NewManager(log.Log) + if err != nil { + panic(err) + } + return w, s +} diff --git a/internal/vpntest/packetio.go b/internal/vpntest/packetio.go index 5f834c69..a8814416 100644 --- a/internal/vpntest/packetio.go +++ b/internal/vpntest/packetio.go @@ -13,8 +13,11 @@ type PacketWriter struct { // A channel where to write packets to. ch chan<- *model.Packet - // LocalSessionID is needed to produce packets that pass sanity checks. + // LocalSessionID is needed to produce incoming packets that pass sanity checks. LocalSessionID model.SessionID + + // RemoteSessionID is needed to produce ACKs. + RemoteSessionID model.SessionID } // NewPacketWriter creates a new PacketWriter. @@ -33,7 +36,8 @@ func (pw *PacketWriter) WriteSequence(seq []string) { p := &model.Packet{ Opcode: testPkt.Opcode, - RemoteSessionID: pw.LocalSessionID, + RemoteSessionID: pw.RemoteSessionID, + LocalSessionID: pw.LocalSessionID, ID: model.PacketID(testPkt.ID), } pw.ch <- p @@ -45,9 +49,20 @@ func (pw *PacketWriter) WriteSequence(seq []string) { type LoggedPacket struct { ID int Opcode model.Opcode + ACKs []model.PacketID At time.Duration } +// newLoggedPacket returns a pointer to LoggedPacket from a real packet and a origin of time. +func newLoggedPacket(p *model.Packet, origin time.Time) *LoggedPacket { + return &LoggedPacket{ + ID: int(p.ID), + Opcode: p.Opcode, + ACKs: p.ACKs, + At: time.Since(origin), + } +} + // PacketLog is a sequence of LoggedPacket. type PacketLog []*LoggedPacket @@ -60,6 +75,18 @@ func (l PacketLog) IDSequence() []int { return ids } +// acks filters the log and returns an array of ids that have been acked +// either as ack packets or as part of the ack array of an outgoing packet. +func (l PacketLog) acks() []int { + acks := []int{} + for _, p := range l { + for _, ack := range p.ACKs { + acks = append(acks, int(ack)) + } + } + return acks +} + // PacketReader reads packets from a channel. type PacketReader struct { ch <-chan *model.Packet @@ -83,19 +110,53 @@ func (pr *PacketReader) WaitForSequence(seq []int, start time.Time) bool { } // no, so let's keep reading until the test runner kills us pkt := <-pr.ch - pr.log = append( - pr.log, - &LoggedPacket{ - ID: int(pkt.ID), - Opcode: pkt.Opcode, - At: time.Since(start), - }) + pr.log = append(pr.log, newLoggedPacket(pkt, start)) log.Debugf("got packet: %v", pkt.ID) } + // TODO move the comparison to witness, leave only wait here return slices.Equal(seq, PacketLog(pr.log).IDSequence()) } +func (pr *PacketReader) WaitForNumberOfACKs(total int, start time.Time) { + for { + // have we read enough acks to call it a day? + if len(PacketLog(pr.log).acks()) >= total { + break + } + // no, so let's keep reading until the test runner kills us + pkt := <-pr.ch + pr.log = append(pr.log, newLoggedPacket(pkt, start)) + log.Debugf("got packet: %v", pkt.ID) + } +} + // Log returns the log of the received packets. func (pr *PacketReader) Log() PacketLog { return PacketLog(pr.log) } + +// A Witness checks for different conditions over a reader +type Witness struct { + reader *PacketReader +} + +func NewWitness(r *PacketReader) *Witness { + return &Witness{r} +} + +func (w *Witness) Log() PacketLog { + return w.reader.Log() +} + +// VerifyACKs tells the underlying reader to wait for a given number of acks, +// and then checks that we have an ack sequence without holes. +func (w *Witness) VerifyACKs(total int, t time.Time) bool { + w.reader.WaitForNumberOfACKs(total, t) + // TODO: compare the range here, no holes + // TODO: probl. need start idx + return true +} + +func (w *Witness) NumberOfACKs() int { + return len(w.reader.Log().acks()) +} From 31abc06af2df7ec5ff7c8a96bac9acb7254277bc Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Fri, 2 Feb 2024 16:24:33 +0100 Subject: [PATCH 66/78] ack testing utils --- internal/reliabletransport/common_test.go | 48 ++++++++++++++ .../reliabletransport/reliable_ack_test.go | 62 +++++++++++++++---- .../reliable_reorder_test.go | 9 +-- internal/reliabletransport/sender.go | 2 +- internal/reliabletransport/tests.go | 21 ------- internal/vpntest/packetio.go | 20 +++--- 6 files changed, 112 insertions(+), 50 deletions(-) create mode 100644 internal/reliabletransport/common_test.go delete mode 100644 internal/reliabletransport/tests.go diff --git a/internal/reliabletransport/common_test.go b/internal/reliabletransport/common_test.go new file mode 100644 index 00000000..ca6492e1 --- /dev/null +++ b/internal/reliabletransport/common_test.go @@ -0,0 +1,48 @@ +package reliabletransport + +import ( + "github.com/apex/log" + "github.com/ooni/minivpn/internal/bytesx" + "github.com/ooni/minivpn/internal/model" + "github.com/ooni/minivpn/internal/session" + "github.com/ooni/minivpn/internal/workers" +) + +// +// Common utilities for tests in this package. +// + +// initManagers initializes a workers manager and a session manager. +func initManagers() (*workers.Manager, *session.Manager) { + w := workers.NewManager(log.Log) + s, err := session.NewManager(log.Log) + if err != nil { + panic(err) + } + return w, s +} + +// newRandomSessionID returns a random session ID to initialize mock sessions. +func newRandomSessionID() model.SessionID { + b, err := bytesx.GenRandomBytes(8) + if err != nil { + panic(err) + } + return model.SessionID(b) +} + +func ackSetFromInts(s []int) *ackSet { + acks := make([]model.PacketID, 0) + for _, i := range s { + acks = append(acks, model.PacketID(i)) + } + return newACKSet(acks...) +} + +func ackSetFromRange(start, total int) *ackSet { + acks := make([]model.PacketID, 0) + for i := 0; i < total; i++ { + acks = append(acks, model.PacketID(start+i)) + } + return newACKSet(acks...) +} diff --git a/internal/reliabletransport/reliable_ack_test.go b/internal/reliabletransport/reliable_ack_test.go index 88664cf7..8fc9f5ed 100644 --- a/internal/reliabletransport/reliable_ack_test.go +++ b/internal/reliabletransport/reliable_ack_test.go @@ -1,6 +1,7 @@ package reliabletransport import ( + "slices" "testing" "time" @@ -16,6 +17,7 @@ func TestReliable_ACK(t *testing.T) { type args struct { inputSequence []string + start int wantacks int } @@ -38,9 +40,38 @@ func TestReliable_ACK(t *testing.T) { "[9] CONTROL_V1 +1ms", "[10] CONTROL_V1 +1ms", }, + start: 1, wantacks: 10, }, }, + { + name: "five ordered packets with offset", + args: args{ + inputSequence: []string{ + "[100] CONTROL_V1 +1ms", + "[101] CONTROL_V1 +1ms", + "[102] CONTROL_V1 +1ms", + "[103] CONTROL_V1 +1ms", + "[104] CONTROL_V1 +1ms", + }, + start: 100, + wantacks: 5, + }, + }, + { + name: "five reversed packets", + args: args{ + inputSequence: []string{ + "[5] CONTROL_V1 +1ms", + "[1] CONTROL_V1 +1ms", + "[3] CONTROL_V1 +1ms", + "[2] CONTROL_V1 +1ms", + "[4] CONTROL_V1 +1ms", + }, + start: 1, + wantacks: 5, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -48,7 +79,9 @@ func TestReliable_ACK(t *testing.T) { // just to properly initialize it, we don't care about these s.ControlToReliable = make(chan *model.Packet) - reliableToControl := make(chan *model.Packet) + // this one up to control/tls also needs to be buffered because otherwise + // we'll block on the receiver when delivering up. + reliableToControl := make(chan *model.Packet, 1024) s.ReliableToControl = &reliableToControl // the only two channels we're going to be testing on this test @@ -61,10 +94,6 @@ func TestReliable_ACK(t *testing.T) { workers, session := initManagers() - // this is our session (local to us) - localSessionID := session.LocalSessionID() - remoteSessionID := session.RemoteSessionID() - t0 := time.Now() // let the workers pump up the jam! @@ -72,19 +101,28 @@ func TestReliable_ACK(t *testing.T) { writer := vpntest.NewPacketWriter(dataIn) - // TODO -- need to create a session - writer.LocalSessionID = model.SessionID(remoteSessionID) - writer.RemoteSessionID = model.SessionID(localSessionID) + // initialize a mock session ID for our peer + peerSessionID := newRandomSessionID() + + writer.RemoteSessionID = model.SessionID(session.LocalSessionID()) + writer.LocalSessionID = peerSessionID + session.SetRemoteSessionID(peerSessionID) go writer.WriteSequence(tt.args.inputSequence) reader := vpntest.NewPacketReader(dataOut) witness := vpntest.NewWitness(reader) - if ok := witness.VerifyACKs(tt.args.wantacks, t0); !ok { - //log.Debug(witness.Log()) - got := witness.NumberOfACKs() - t.Errorf("Reordering: got = %v, want %v", got, tt.args.wantacks) + if ok := witness.VerifyNumberOfACKs(tt.args.start, tt.args.wantacks, t0); !ok { + got := len(witness.Log().ACKs()) + t.Errorf("TestACK: got = %v, want %v", got, tt.args.wantacks) + } + gotAckSet := ackSetFromInts(witness.Log().ACKs()).sorted() + wantAckSet := ackSetFromRange(tt.args.start, tt.args.wantacks).sorted() + + if !slices.Equal(gotAckSet, wantAckSet) { + t.Errorf("TestACK: got = %v, want %v", gotAckSet, wantAckSet) + } }) } diff --git a/internal/reliabletransport/reliable_reorder_test.go b/internal/reliabletransport/reliable_reorder_test.go index c5d9b2ea..b4e94823 100644 --- a/internal/reliabletransport/reliable_reorder_test.go +++ b/internal/reliabletransport/reliable_reorder_test.go @@ -109,10 +109,6 @@ func TestReliable_Reordering_UP(t *testing.T) { workers, session := initManagers() - // this is our session (local to us) - localSessionID := session.LocalSessionID() - remoteSessionID := session.RemoteSessionID() - t0 := time.Now() // let the workers pump up the jam! @@ -120,9 +116,8 @@ func TestReliable_Reordering_UP(t *testing.T) { writer := vpntest.NewPacketWriter(dataIn) - // TODO -- need to create a session - writer.LocalSessionID = model.SessionID(remoteSessionID) - writer.RemoteSessionID = model.SessionID(localSessionID) + writer.RemoteSessionID = model.SessionID(session.LocalSessionID()) + writer.LocalSessionID = newRandomSessionID() go writer.WriteSequence(tt.args.inputSequence) diff --git a/internal/reliabletransport/sender.go b/internal/reliabletransport/sender.go index 53578c43..c9013e65 100644 --- a/internal/reliabletransport/sender.go +++ b/internal/reliabletransport/sender.go @@ -237,7 +237,7 @@ type ackSet struct { m map[model.PacketID]bool } -// NewACKSet creates a new empty ACK set. +// newACKSet creates a new empty ACK set. func newACKSet(ids ...model.PacketID) *ackSet { m := make(map[model.PacketID]bool) for _, id := range ids { diff --git a/internal/reliabletransport/tests.go b/internal/reliabletransport/tests.go deleted file mode 100644 index 61472042..00000000 --- a/internal/reliabletransport/tests.go +++ /dev/null @@ -1,21 +0,0 @@ -package reliabletransport - -import ( - "github.com/apex/log" - "github.com/ooni/minivpn/internal/session" - "github.com/ooni/minivpn/internal/workers" -) - -// -// Common utilities for tests in this package. -// - -// initManagers initializes a workers manager and a session manager -func initManagers() (*workers.Manager, *session.Manager) { - w := workers.NewManager(log.Log) - s, err := session.NewManager(log.Log) - if err != nil { - panic(err) - } - return w, s -} diff --git a/internal/vpntest/packetio.go b/internal/vpntest/packetio.go index a8814416..499009d8 100644 --- a/internal/vpntest/packetio.go +++ b/internal/vpntest/packetio.go @@ -1,6 +1,7 @@ package vpntest import ( + "fmt" "slices" "time" @@ -41,6 +42,7 @@ func (pw *PacketWriter) WriteSequence(seq []string) { ID: model.PacketID(testPkt.ID), } pw.ch <- p + fmt.Println("<< wrote", p.ID) time.Sleep(testPkt.IAT) } } @@ -75,9 +77,9 @@ func (l PacketLog) IDSequence() []int { return ids } -// acks filters the log and returns an array of ids that have been acked +// ACKs filters the log and returns an array of ids that have been acked // either as ack packets or as part of the ack array of an outgoing packet. -func (l PacketLog) acks() []int { +func (l PacketLog) ACKs() []int { acks := []int{} for _, p := range l { for _, ack := range p.ACKs { @@ -120,7 +122,7 @@ func (pr *PacketReader) WaitForSequence(seq []int, start time.Time) bool { func (pr *PacketReader) WaitForNumberOfACKs(total int, start time.Time) { for { // have we read enough acks to call it a day? - if len(PacketLog(pr.log).acks()) >= total { + if len(PacketLog(pr.log).ACKs()) >= total { break } // no, so let's keep reading until the test runner kills us @@ -149,14 +151,14 @@ func (w *Witness) Log() PacketLog { } // VerifyACKs tells the underlying reader to wait for a given number of acks, -// and then checks that we have an ack sequence without holes. -func (w *Witness) VerifyACKs(total int, t time.Time) bool { +// returns true if we have the same number of acks. +func (w *Witness) VerifyNumberOfACKs(start, total int, t time.Time) bool { w.reader.WaitForNumberOfACKs(total, t) - // TODO: compare the range here, no holes - // TODO: probl. need start idx - return true + return len(w.Log().ACKs()) == total } +/* func (w *Witness) NumberOfACKs() int { - return len(w.reader.Log().acks()) + return len(w.Log().ACKs()) } +*/ From 7b92e6ce423105852b9aedc378de5652b8080f9a Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Fri, 2 Feb 2024 16:37:50 +0100 Subject: [PATCH 67/78] wip: ack duplicates compare set --- .../reliabletransport/reliable_ack_test.go | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/internal/reliabletransport/reliable_ack_test.go b/internal/reliabletransport/reliable_ack_test.go index 8fc9f5ed..e43b0337 100644 --- a/internal/reliabletransport/reliable_ack_test.go +++ b/internal/reliabletransport/reliable_ack_test.go @@ -60,6 +60,39 @@ func TestReliable_ACK(t *testing.T) { }, { name: "five reversed packets", + args: args{ + inputSequence: []string{ + "[5] CONTROL_V1 +1ms", + "[4] CONTROL_V1 +1ms", + "[3] CONTROL_V1 +1ms", + "[2] CONTROL_V1 +1ms", + "[1] CONTROL_V1 +1ms", + }, + start: 1, + wantacks: 5, + }, + }, + { + name: "ten unordered packets with duplicates", + args: args{ + inputSequence: []string{ + "[5] CONTROL_V1 +1ms", + "[1] CONTROL_V1 +1ms", + "[5] CONTROL_V1 +1ms", + "[2] CONTROL_V1 +1ms", + "[1] CONTROL_V1 +1ms", + "[4] CONTROL_V1 +1ms", + "[2] CONTROL_V1 +1ms", + "[3] CONTROL_V1 +1ms", + "[3] CONTROL_V1 +1ms", + "[4] CONTROL_V1 +1ms", + }, + start: 1, + wantacks: 5, + }, + }, + { + name: "ten packets", args: args{ inputSequence: []string{ "[5] CONTROL_V1 +1ms", From 6bdc2d556cb605020f8367106b56a85bb32465a3 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Fri, 2 Feb 2024 17:16:10 +0100 Subject: [PATCH 68/78] run the new tests in internal too --- Makefile | 4 ++-- internal/vpntest/packetio.go | 22 ++++++++++++++-------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/Makefile b/Makefile index eaad788c..620f0bbb 100644 --- a/Makefile +++ b/Makefile @@ -31,10 +31,10 @@ test: GOFLAGS='-count=1' go test -v ./... test-coverage: - go test -coverprofile=coverage.out ./vpn + go test -coverprofile=coverage.out ./vpn ./internal/... test-coverage-threshold: - go test --short -coverprofile=cov-threshold.out ./vpn + go test --short -coverprofile=cov-threshold.out ./vpn ./internal/... ./scripts/go-coverage-check.sh cov-threshold.out ${COVERAGE_THRESHOLD} test-short: diff --git a/internal/vpntest/packetio.go b/internal/vpntest/packetio.go index 499009d8..973032f4 100644 --- a/internal/vpntest/packetio.go +++ b/internal/vpntest/packetio.go @@ -1,7 +1,6 @@ package vpntest import ( - "fmt" "slices" "time" @@ -42,7 +41,6 @@ func (pw *PacketWriter) WriteSequence(seq []string) { ID: model.PacketID(testPkt.ID), } pw.ch <- p - fmt.Println("<< wrote", p.ID) time.Sleep(testPkt.IAT) } } @@ -77,13 +75,16 @@ func (l PacketLog) IDSequence() []int { return ids } -// ACKs filters the log and returns an array of ids that have been acked +// ACKs filters the log and returns an array of unique ids that have been acked // either as ack packets or as part of the ack array of an outgoing packet. func (l PacketLog) ACKs() []int { acks := []int{} for _, p := range l { for _, ack := range p.ACKs { - acks = append(acks, int(ack)) + a := int(ack) + if !contains(acks, a) { + acks = append(acks, a) + } } } return acks @@ -157,8 +158,13 @@ func (w *Witness) VerifyNumberOfACKs(start, total int, t time.Time) bool { return len(w.Log().ACKs()) == total } -/* -func (w *Witness) NumberOfACKs() int { - return len(w.Log().ACKs()) +// contains check if the element is in the slice. this is expensive, but it's only +// for tests and the alternative is to make ackSet public. +func contains(slice []int, target int) bool { + for _, item := range slice { + if item == target { + return true + } + } + return false } -*/ From 80e8e5053056a4781240440b03255661c7683c12 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Fri, 2 Feb 2024 17:46:08 +0100 Subject: [PATCH 69/78] x --- .../reliabletransport/reliable_ack_test.go | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/internal/reliabletransport/reliable_ack_test.go b/internal/reliabletransport/reliable_ack_test.go index e43b0337..fe48048b 100644 --- a/internal/reliabletransport/reliable_ack_test.go +++ b/internal/reliabletransport/reliable_ack_test.go @@ -91,20 +91,22 @@ func TestReliable_ACK(t *testing.T) { wantacks: 5, }, }, - { - name: "ten packets", - args: args{ - inputSequence: []string{ - "[5] CONTROL_V1 +1ms", - "[1] CONTROL_V1 +1ms", - "[3] CONTROL_V1 +1ms", - "[2] CONTROL_V1 +1ms", - "[4] CONTROL_V1 +1ms", + /* + { + name: "a burst of packets", + args: args{ + inputSequence: []string{ + "[5] CONTROL_V1 +1ms", + "[1] CONTROL_V1 +1ms", + "[3] CONTROL_V1 +1ms", + "[2] CONTROL_V1 +1ms", + "[4] CONTROL_V1 +1ms", + }, + start: 1, + wantacks: 5, }, - start: 1, - wantacks: 5, }, - }, + */ } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { From c84ad6515064d6fb77f729cd17cbfa8cfaeddeac Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Fri, 2 Feb 2024 19:33:39 +0100 Subject: [PATCH 70/78] return if ack error --- internal/reliabletransport/sender.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/reliabletransport/sender.go b/internal/reliabletransport/sender.go index 2abbd9e9..989e81f9 100644 --- a/internal/reliabletransport/sender.go +++ b/internal/reliabletransport/sender.go @@ -105,6 +105,7 @@ func (ws *workersState) blockOnTryingToSend(sender *reliableSender, ticker *time ACK, err := ws.sessionManager.NewACKForPacketIDs(sender.NextPacketIDsToACK()) if err != nil { ws.logger.Warnf("moveDownWorker: tryToSend: cannot create ack: %v", err.Error()) + return } ACK.Log(ws.logger, model.DirectionOutgoing) select { From cf17b413b40cbdff98f4e1cb254dc34359edfe80 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Fri, 2 Feb 2024 19:36:38 +0100 Subject: [PATCH 71/78] add targets for testing internal path --- .github/workflows/build-refactor.yml | 57 ++++++++++++++++++++++++++++ Makefile | 11 +++++- 2 files changed, 66 insertions(+), 2 deletions(-) create mode 100644 .github/workflows/build-refactor.yml diff --git a/.github/workflows/build-refactor.yml b/.github/workflows/build-refactor.yml new file mode 100644 index 00000000..a822bf03 --- /dev/null +++ b/.github/workflows/build-refactor.yml @@ -0,0 +1,57 @@ +name: build-refactor +# this action is covering internal/ tree with go1.21 + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + short-tests: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: setup go + uses: actions/setup-go@v2 + with: + go-version: '1.21' + - name: Run short tests + run: go test --short -cover ./internal/... + + gosec: + runs-on: ubuntu-latest + env: + GO111MODULE: on + steps: + - name: Checkout Source + uses: actions/checkout@v2 + - name: Run Gosec security scanner + uses: securego/gosec@master + with: + args: '-no-fail ./...' + + coverage-threshold: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: setup go + uses: actions/setup-go@v2 + with: + go-version: '1.21' + - name: Ensure coverage threshold + run: make test-coverage-threshold-refactor + + integration: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: setup go + uses: actions/setup-go@v2 + with: + go-version: '1.21' + - name: run integration tests + run: go test -v ./tests/integration + diff --git a/Makefile b/Makefile index 620f0bbb..dec4c815 100644 --- a/Makefile +++ b/Makefile @@ -31,10 +31,17 @@ test: GOFLAGS='-count=1' go test -v ./... test-coverage: - go test -coverprofile=coverage.out ./vpn ./internal/... + go test -coverprofile=coverage.out ./vpn + +test-coverage-refactor: + go test -coverprofile=coverage.out ./internal/... test-coverage-threshold: - go test --short -coverprofile=cov-threshold.out ./vpn ./internal/... + go test --short -coverprofile=cov-threshold.out ./vpn + ./scripts/go-coverage-check.sh cov-threshold.out ${COVERAGE_THRESHOLD} + +test-coverage-threshold-refactor: + go test --short -coverprofile=cov-threshold.out ./internal/... ./scripts/go-coverage-check.sh cov-threshold.out ${COVERAGE_THRESHOLD} test-short: From cd00bda457d3f0500f8307f5bc0e1e10536ee04a Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Fri, 2 Feb 2024 19:40:40 +0100 Subject: [PATCH 72/78] relax coverage threshold for now --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index dec4c815..0b330149 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ TARGET ?= "1.1.1.1" COUNT ?= 5 TIMEOUT ?= 10 LOCAL_TARGET := $(shell ip -4 addr show docker0 | grep 'inet ' | awk '{print $$2}' | cut -f 1 -d /) -COVERAGE_THRESHOLD := 88 +COVERAGE_THRESHOLD := 80 FLAGS=-ldflags="-w -s -buildid=none -linkmode=external" -buildmode=pie -buildvcs=false build: From 893f4e07503b72256ebe3f9965230c17ab9f7a4e Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Fri, 2 Feb 2024 19:42:47 +0100 Subject: [PATCH 73/78] coverage for refactor --- Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 0b330149..881d2423 100644 --- a/Makefile +++ b/Makefile @@ -41,8 +41,8 @@ test-coverage-threshold: ./scripts/go-coverage-check.sh cov-threshold.out ${COVERAGE_THRESHOLD} test-coverage-threshold-refactor: - go test --short -coverprofile=cov-threshold.out ./internal/... - ./scripts/go-coverage-check.sh cov-threshold.out ${COVERAGE_THRESHOLD} + go test --short -coverprofile=cov-threshold-refactor.out ./internal/... + ./scripts/go-coverage-check.sh cov-threshold-refactor.out ${COVERAGE_THRESHOLD} test-short: go test -race -short -v ./... From 7dba3b7d4e07977c41f4947d7d25fc0a1fd46bcb Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Fri, 2 Feb 2024 19:43:36 +0100 Subject: [PATCH 74/78] igore coverage output --- .gitignore | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 6b0e6c86..0d839243 100644 --- a/.gitignore +++ b/.gitignore @@ -8,8 +8,6 @@ *.swo *.pem *.ovpn +/*.out data/* measurements/* -coverage.out -coverage-ping.out -cov-threshold.out From b902a49bf5dd3dd4d1fd6a518de6cffddba3b229 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Fri, 2 Feb 2024 19:49:55 +0100 Subject: [PATCH 75/78] remove extra comment --- internal/reliabletransport/receiver.go | 1 - 1 file changed, 1 deletion(-) diff --git a/internal/reliabletransport/receiver.go b/internal/reliabletransport/receiver.go index fc6aa61c..180a0be8 100644 --- a/internal/reliabletransport/receiver.go +++ b/internal/reliabletransport/receiver.go @@ -39,7 +39,6 @@ func (ws *workersState) moveUpWorker() { // TODO: are we handling a HARD_RESET_V2 while we're doing a handshake? // I'm not sure that's a valid behavior for a server. // We should be able to deterministically test how this affects the state machine. - // log.Printf("%s session check: %v\n", packet.Opcode, bytes.Equal(packet.LocalSessionID[:], ws.sessionManager.RemoteSessionID())) // drop a packet that is not for our session if !bytes.Equal([]byte(packet.RemoteSessionID[:]), []byte(ws.sessionManager.LocalSessionID())) { From 56ab1339ac316f7f9731605732f1f00f7ae27328 Mon Sep 17 00:00:00 2001 From: Ain Ghazal <99027643+ainghazal@users.noreply.github.com> Date: Wed, 7 Feb 2024 16:45:08 +0100 Subject: [PATCH 76/78] Update internal/model/packet.go Co-authored-by: Simone Basso --- internal/model/packet.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/model/packet.go b/internal/model/packet.go index 65f07afe..6c5538d2 100644 --- a/internal/model/packet.go +++ b/internal/model/packet.go @@ -33,7 +33,7 @@ const ( ) // NewOpcodeFromString returns an opcode from a string representation, and an error if it cannot parse the opcode -// representation. The zero return value is invalid. +// representation. The zero return value is invalid and always coupled with a non-nil error. func NewOpcodeFromString(s string) (Opcode, error) { switch s { case "CONTROL_HARD_RESET_CLIENT_V1": From 20deba97a7b67319a9cc626c8cc44720fa31c183 Mon Sep 17 00:00:00 2001 From: Ain Ghazal <99027643+ainghazal@users.noreply.github.com> Date: Wed, 7 Feb 2024 16:46:05 +0100 Subject: [PATCH 77/78] Update internal/vpntest/packetio.go Co-authored-by: Simone Basso --- internal/vpntest/packetio.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/vpntest/packetio.go b/internal/vpntest/packetio.go index 973032f4..b5cfaebe 100644 --- a/internal/vpntest/packetio.go +++ b/internal/vpntest/packetio.go @@ -116,7 +116,7 @@ func (pr *PacketReader) WaitForSequence(seq []int, start time.Time) bool { pr.log = append(pr.log, newLoggedPacket(pkt, start)) log.Debugf("got packet: %v", pkt.ID) } - // TODO move the comparison to witness, leave only wait here + // TODO(ainghazal): move the comparison to witness, leave only wait here return slices.Equal(seq, PacketLog(pr.log).IDSequence()) } From 8a042bafdf702b8ce3925fbc477d1001466e9855 Mon Sep 17 00:00:00 2001 From: ain ghazal Date: Wed, 7 Feb 2024 17:59:05 +0100 Subject: [PATCH 78/78] address comments from code review --- internal/model/packet.go | 7 +------ internal/reliabletransport/common_test.go | 5 ++--- internal/reliabletransport/receiver.go | 2 +- internal/runtimex/runtimex.go | 10 ++++++++++ internal/vpntest/packetio.go | 4 +++- 5 files changed, 17 insertions(+), 11 deletions(-) diff --git a/internal/model/packet.go b/internal/model/packet.go index 6c5538d2..f1256a4b 100644 --- a/internal/model/packet.go +++ b/internal/model/packet.go @@ -352,11 +352,6 @@ func (p *Packet) Log(logger Logger, direction int) { return } - payloadLen := 0 - if p.Payload != nil { - payloadLen = len(p.Payload) - } - logger.Debugf( "%s %s {id=%d, acks=%v} localID=%x remoteID=%x [%d bytes]", dir, @@ -365,6 +360,6 @@ func (p *Packet) Log(logger Logger, direction int) { p.ACKs, p.LocalSessionID, p.RemoteSessionID, - payloadLen, + len(p.Payload), ) } diff --git a/internal/reliabletransport/common_test.go b/internal/reliabletransport/common_test.go index ca6492e1..63581f48 100644 --- a/internal/reliabletransport/common_test.go +++ b/internal/reliabletransport/common_test.go @@ -4,6 +4,7 @@ import ( "github.com/apex/log" "github.com/ooni/minivpn/internal/bytesx" "github.com/ooni/minivpn/internal/model" + "github.com/ooni/minivpn/internal/runtimex" "github.com/ooni/minivpn/internal/session" "github.com/ooni/minivpn/internal/workers" ) @@ -16,9 +17,7 @@ import ( func initManagers() (*workers.Manager, *session.Manager) { w := workers.NewManager(log.Log) s, err := session.NewManager(log.Log) - if err != nil { - panic(err) - } + runtimex.PanicOnError(err, "cannot create session manager") return w, s } diff --git a/internal/reliabletransport/receiver.go b/internal/reliabletransport/receiver.go index 180a0be8..4065f54b 100644 --- a/internal/reliabletransport/receiver.go +++ b/internal/reliabletransport/receiver.go @@ -41,7 +41,7 @@ func (ws *workersState) moveUpWorker() { // We should be able to deterministically test how this affects the state machine. // drop a packet that is not for our session - if !bytes.Equal([]byte(packet.RemoteSessionID[:]), []byte(ws.sessionManager.LocalSessionID())) { + if !bytes.Equal(packet.RemoteSessionID[:], ws.sessionManager.LocalSessionID()) { ws.logger.Warnf( "%s: packet with invalid RemoteSessionID: expected %x; got %x", workerName, diff --git a/internal/runtimex/runtimex.go b/internal/runtimex/runtimex.go index 5e135484..2403c825 100644 --- a/internal/runtimex/runtimex.go +++ b/internal/runtimex/runtimex.go @@ -1,6 +1,8 @@ // Package runtimex contains [runtime] extensions. package runtimex +import "fmt" + // PanicIfFalse calls panic with the given message if the given statement is false. func PanicIfFalse(stmt bool, message interface{}) { if !stmt { @@ -17,3 +19,11 @@ func PanicIfTrue(stmt bool, message interface{}) { // Assert calls panic with the given message if the given statement is false. var Assert = PanicIfFalse + +// PanicOnError calls panic() if err is not nil. The type passed to panic +// is an error type wrapping the original error. +func PanicOnError(err error, message string) { + if err != nil { + panic(fmt.Errorf("%s: %w", message, err)) + } +} diff --git a/internal/vpntest/packetio.go b/internal/vpntest/packetio.go index b5cfaebe..923bfc84 100644 --- a/internal/vpntest/packetio.go +++ b/internal/vpntest/packetio.go @@ -143,15 +143,17 @@ type Witness struct { reader *PacketReader } +// NewWitness constructs a Witness from a [PacketReader]. func NewWitness(r *PacketReader) *Witness { return &Witness{r} } +// Log returns the packet log from the internal reader this witness uses. func (w *Witness) Log() PacketLog { return w.reader.Log() } -// VerifyACKs tells the underlying reader to wait for a given number of acks, +// VerifyNumberOfACKs tells the underlying reader to wait for a given number of acks, // returns true if we have the same number of acks. func (w *Witness) VerifyNumberOfACKs(start, total int, t time.Time) bool { w.reader.WaitForNumberOfACKs(total, t)