Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: introduce tun #54

Merged
merged 9 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions internal/tun/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
// Package tun is the public interface for the minivpn application. It exposes a tun device interface
// where the user of the application can write to and read from.
package tun
129 changes: 129 additions & 0 deletions internal/tun/setup.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package tun

import (
"github.com/ooni/minivpn/internal/controlchannel"
"github.com/ooni/minivpn/internal/datachannel"
"github.com/ooni/minivpn/internal/model"
"github.com/ooni/minivpn/internal/networkio"
"github.com/ooni/minivpn/internal/packetmuxer"
"github.com/ooni/minivpn/internal/reliabletransport"
"github.com/ooni/minivpn/internal/runtimex"
"github.com/ooni/minivpn/internal/session"
"github.com/ooni/minivpn/internal/tlssession"
"github.com/ooni/minivpn/internal/workers"
)

// connectChannel connects an existing channel (a "signal" in Qt terminology)
// to a nil pointer to channel (a "slot" in Qt terminology).
func connectChannel[T any](signal chan T, slot **chan T) {
runtimex.Assert(signal != nil, "signal is nil")
runtimex.Assert(slot == nil || *slot == nil, "slot or *slot aren't nil")
*slot = &signal
}

// startWorkers starts all the workers. See the [ARCHITECTURE]
// file for more information about the workers.
//
// [ARCHITECTURE]: https://github.com/ooni/minivpn/blob/main/ARCHITECTURE.md
func startWorkers(logger model.Logger, sessionManager *session.Manager,
tunDevice *TUN, conn networkio.FramingConn, options *model.Options) *workers.Manager {
// create a workers manager
workersManager := workers.NewManager()

// create the networkio service.
nio := &networkio.Service{
MuxerToNetwork: make(chan []byte, 1<<5),
NetworkToMuxer: nil, // ok
}

// create the packetmuxer service.
muxer := &packetmuxer.Service{
MuxerToReliable: nil, // ok
MuxerToData: nil, // ok
NotifyTLS: nil,
HardReset: make(chan any, 1),
DataOrControlToMuxer: make(chan *model.Packet),
MuxerToNetwork: nil, // ok
NetworkToMuxer: make(chan []byte),
}

// connect networkio and packetmuxer
connectChannel(nio.MuxerToNetwork, &muxer.MuxerToNetwork)
connectChannel(muxer.NetworkToMuxer, &nio.NetworkToMuxer)

// create the datachannel service.
datach := &datachannel.Service{
MuxerToData: make(chan *model.Packet),
DataOrControlToMuxer: nil, // ok
KeyReady: make(chan *session.DataChannelKey, 1),
TUNToData: tunDevice.tunDown,
DataToTUN: tunDevice.tunUp,
}

// connect the packetmuxer and the datachannel
connectChannel(datach.MuxerToData, &muxer.MuxerToData)
connectChannel(muxer.DataOrControlToMuxer, &datach.DataOrControlToMuxer)

// create the reliabletransport service.
rel := &reliabletransport.Service{
DataOrControlToMuxer: nil, // ok
ControlToReliable: make(chan *model.Packet),
MuxerToReliable: make(chan *model.Packet),
ReliableToControl: nil, // ok
}

// connect reliable service and packetmuxer.
connectChannel(rel.MuxerToReliable, &muxer.MuxerToReliable)
connectChannel(muxer.DataOrControlToMuxer, &rel.DataOrControlToMuxer)

// create the controlchannel service.
ctrl := &controlchannel.Service{
NotifyTLS: nil, // ok
ControlToReliable: nil, // ok
ReliableToControl: make(chan *model.Packet),
TLSRecordToControl: make(chan []byte),
TLSRecordFromControl: nil, // ok
}

// connect the reliable service and the controlchannel service
connectChannel(rel.ControlToReliable, &ctrl.ControlToReliable)
connectChannel(ctrl.ReliableToControl, &rel.ReliableToControl)

// create the tlssession service
tlsx := &tlssession.Service{
NotifyTLS: make(chan *model.Notification, 1),
KeyUp: nil,
TLSRecordUp: make(chan []byte),
TLSRecordDown: nil,
}

// connect the tlsstate service and the controlchannel service
connectChannel(tlsx.NotifyTLS, &ctrl.NotifyTLS)
connectChannel(tlsx.TLSRecordUp, &ctrl.TLSRecordFromControl)
connectChannel(ctrl.TLSRecordToControl, &tlsx.TLSRecordDown)

// connect tlsstate service and the datachannel service
connectChannel(datach.KeyReady, &tlsx.KeyUp)

// connect the muxer and the tlsstate service
connectChannel(tlsx.NotifyTLS, &muxer.NotifyTLS)

logger.Debugf("%T: %+v", nio, nio)
logger.Debugf("%T: %+v", muxer, muxer)
logger.Debugf("%T: %+v", rel, rel)
logger.Debugf("%T: %+v", ctrl, ctrl)
logger.Debugf("%T: %+v", tlsx, tlsx)

// start all the workers
nio.StartWorkers(logger, workersManager, conn)
muxer.StartWorkers(logger, workersManager, sessionManager)
rel.StartWorkers(logger, workersManager, sessionManager)
ctrl.StartWorkers(logger, workersManager, sessionManager)
datach.StartWorkers(logger, workersManager, sessionManager, options)
tlsx.StartWorkers(logger, workersManager, sessionManager, options)

// tell the packetmuxer that it should handshake ASAP
muxer.HardReset <- true

return workersManager
}
208 changes: 208 additions & 0 deletions internal/tun/tun.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
package tun

import (
"bytes"
"context"
"errors"
"net"
"os"
"sync"
"time"

"github.com/apex/log"
"github.com/ooni/minivpn/internal/model"
"github.com/ooni/minivpn/internal/networkio"
"github.com/ooni/minivpn/internal/session"
)

var (
ErrInitializationTimeout = errors.New("timeout while waiting for TUN to start")
)

// StartTUN initializes and starts the TUN device over the vpn.
// If the passed context expires before the TUN device is ready,
func StartTUN(ctx context.Context, conn networkio.FramingConn, options *model.Options) (*TUN, error) {
// create a session
sessionManager, err := session.NewManager(log.Log)
if err != nil {
return nil, err
}

// create the TUN that will OWN the connection
tunnel := newTUN(log.Log, conn, sessionManager)

// start all the workers
workers := startWorkers(log.Log, sessionManager, tunnel, conn, options)
tunnel.whenDone(func() {
workers.StartShutdown()
workers.WaitWorkersShutdown()
})

// Await for the signal from the session manager to tell us we're ready to start accepting data.
// In practice, this means that we already have a valid TunnelInfo at this point
// (i.e., three way handshake has completed, and we have valid keys).

select {
case <-ctx.Done():
return nil, ErrInitializationTimeout
case <-sessionManager.Ready:
return tunnel, nil
}
}

// TUN allows to use channels to read and write. It also OWNS the underlying connection.
// TUN implements net.Conn
type TUN struct {
// ensure idempotency.
closeOnce sync.Once

// conn is the underlying connection.
conn networkio.FramingConn

// hangup is used to let methods know the connection is closed.
hangup chan any

// logger implements model.Logger
logger model.Logger

// network is the underlying network for the passed [networkio.FramingConn].
network string

// used to buffer reads from above.
readBuffer *bytes.Buffer

// readDeadline is used to set the read deadline.
readDeadline tunDeadline

// session is the session manager
session *session.Manager

// tunDown moves bytes down to the data channel.
tunDown chan []byte

// tunUp moves bytes up from the data channel.
tunUp chan []byte

// callback to be executed on shutdown.
whenDoneFn func()

// writeDeadline is used to set the write deadline.
writeDeadline tunDeadline
}

// newTUN creates a new TUN.
// This function TAKES OWNERSHIP of the conn.
func newTUN(logger model.Logger, conn networkio.FramingConn, session *session.Manager) *TUN {
return &TUN{
closeOnce: sync.Once{},
conn: conn,
hangup: make(chan any),
logger: logger,
network: conn.LocalAddr().Network(),
readBuffer: &bytes.Buffer{},
readDeadline: makeTUNDeadline(),
session: session,
tunDown: make(chan []byte),
tunUp: make(chan []byte, 10),
// this function is explicitely set empty so that we can safely use a callback even if not set.
whenDoneFn: func() {},
writeDeadline: makeTUNDeadline(),
}
}

// whenDone registers a callback to be called on shutdown.
// This is useful to propagate shutdown to workers.
func (t *TUN) whenDone(fn func()) {
t.whenDoneFn = fn
}

func (t *TUN) Close() error {
t.closeOnce.Do(func() {
close(t.hangup)
// We OWN the connection
t.conn.Close()
// execute any shutdown callback
t.whenDoneFn()
})
return nil
}

func (t *TUN) Read(data []byte) (int, error) {
for {
count, _ := t.readBuffer.Read(data)
if count > 0 {
// log.Printf("[tunbio] received %d bytes", len(data))
return count, nil
}
if isClosedChan(t.readDeadline.wait()) {
return 0, os.ErrDeadlineExceeded
}
select {
case extra := <-t.tunUp:
t.readBuffer.Write(extra)
case <-t.hangup:
return 0, net.ErrClosed
case <-t.readDeadline.wait():
return 0, os.ErrDeadlineExceeded
}
}
}

func (t *TUN) Write(data []byte) (int, error) {
if isClosedChan(t.writeDeadline.wait()) {
return 0, os.ErrDeadlineExceeded
}
select {
case t.tunDown <- data:
return len(data), nil
case <-t.hangup:
return 0, net.ErrClosed
case <-t.writeDeadline.wait():
return 0, os.ErrDeadlineExceeded
}
}

func (t *TUN) LocalAddr() net.Addr {
ip := t.session.TunnelInfo().IP
return &tunBioAddr{ip, t.network}
}

func (t *TUN) RemoteAddr() net.Addr {
gw := t.session.TunnelInfo().GW
return &tunBioAddr{gw, t.network}
}

func (t *TUN) SetDeadline(tm time.Time) error {
t.readDeadline.set(tm)
t.writeDeadline.set(tm)
return nil
}

func (t *TUN) SetReadDeadline(tm time.Time) error {
t.readDeadline.set(tm)
return nil
}

func (t *TUN) SetWriteDeadline(tm time.Time) error {
t.writeDeadline.set(tm)
return nil
}

// tunBioAddr is the type of address returned by [*TUN]
type tunBioAddr struct {
addr string
net string
}

var _ net.Addr = &tunBioAddr{}

// Network implements net.Addr. It returns the network
// for the underlying connection.
func (t *tunBioAddr) Network() string {
return t.net
}

// String implements net.Addr
func (t *tunBioAddr) String() string {
return t.addr
}
Loading