Skip to content

Commit

Permalink
Fix WireGuard connection cleanup if start fails
Browse files Browse the repository at this point in the history
Sometimes connection Start could finish setting up the tunnel, but if a
handshake wait timeout occurred it would not clean up. Now Start always calls Stop
on error, which performs the needed cleanup logic.
  • Loading branch information
anjmao committed Jan 31, 2020
1 parent 158f11d commit 26ba14c
Show file tree
Hide file tree
Showing 11 changed files with 531 additions and 129 deletions.
16 changes: 15 additions & 1 deletion cmd/di_desktop.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package cmd

import (
"fmt"
"time"

"github.com/ethereum/go-ethereum/common"
"github.com/mysteriumnetwork/node/communication"
Expand All @@ -44,6 +45,8 @@ import (
openvpn_service "github.com/mysteriumnetwork/node/services/openvpn/service"
"github.com/mysteriumnetwork/node/services/wireguard"
wireguard_connection "github.com/mysteriumnetwork/node/services/wireguard/connection"
"github.com/mysteriumnetwork/node/services/wireguard/endpoint"
"github.com/mysteriumnetwork/node/services/wireguard/resources"
wireguard_service "github.com/mysteriumnetwork/node/services/wireguard/service"
"github.com/mysteriumnetwork/node/session"
"github.com/mysteriumnetwork/node/session/connectivity"
Expand Down Expand Up @@ -290,8 +293,19 @@ func (di *Dependencies) registerConnections(nodeOptions node.Options) {

func (di *Dependencies) registerWireguardConnection(nodeOptions node.Options) {
wireguard.Bootstrap()
dnsManager := wireguard_connection.NewDNSManager()
handshakeWaiter := wireguard_connection.NewHandshakeWaiter()
endpointFactory := func() (wireguard.ConnectionEndpoint, error) {
resourceAllocator := resources.NewAllocator(nil, wireguard_service.DefaultOptions.Subnet)
return endpoint.NewConnectionEndpoint(nil, resourceAllocator, 0)
}
connFactory := func() (connection.Connection, error) {
return wireguard_connection.NewConnection(nodeOptions.Directories.Config, di.IPResolver, di.NATPinger)
opts := wireguard_connection.Options{
DNSConfigDir: nodeOptions.Directories.Config,
StatsUpdateInterval: 1 * time.Second,
HandshakeTimeout: 1 * time.Minute,
}
return wireguard_connection.NewConnection(opts, di.IPResolver, di.NATPinger, endpointFactory, dnsManager, handshakeWaiter)
}
di.ConnectionRegistry.Register(wireguard.ServiceType, connFactory)
}
Expand Down
3 changes: 3 additions & 0 deletions e2e/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,9 @@ func consumerConnectFlow(t *testing.T, tequilapi *tequilapi_client.Client, consu
assert.Equal(t, connectionStatus.SessionID, se.SessionID)
assert.Equal(t, "New", se.Status)

// Keep active connection for some time to check for statistics change.
time.Sleep(5 * time.Second)

err = tequilapi.ConnectionDestroy()
assert.NoError(t, err)

Expand Down
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -838,8 +838,6 @@ gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bl
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo=
gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
Expand Down
9 changes: 8 additions & 1 deletion mobile/mysterium/entrypoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"time"

"github.com/mysteriumnetwork/node/identity/registry"
wireguard_connection "github.com/mysteriumnetwork/node/services/wireguard/connection"
"github.com/mysteriumnetwork/node/session/pingpong"
pc "github.com/mysteriumnetwork/payments/crypto"
"github.com/pkg/errors"
Expand Down Expand Up @@ -517,10 +518,16 @@ func (mb *MobileNode) OverrideOpenvpnConnection(tunnelSetup Openvpn3TunnelSetup)
func (mb *MobileNode) OverrideWireguardConnection(wgTunnelSetup WireguardTunnelSetup) {
wireguard.Bootstrap()
factory := func() (connection.Connection, error) {
opts := wireGuardOptions{
statsUpdateInterval: 1 * time.Second,
handshakeTimeout: 1 * time.Minute,
}
return NewWireGuardConnection(
wgTunnelSetup,
opts,
newWireguardDevice(wgTunnelSetup),
mb.ipResolver,
mb.natPinger,
wireguard_connection.NewHandshakeWaiter(),
)
}
mb.connectionRegistry.Register(wireguard.ServiceType, factory)
Expand Down
175 changes: 106 additions & 69 deletions mobile/mysterium/wireguard_connection_setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@ import (
"bufio"
"encoding/json"
"strings"
"sync"
"time"

"github.com/mysteriumnetwork/node/consumer"
"github.com/mysteriumnetwork/node/core/connection"
"github.com/mysteriumnetwork/node/core/ip"
"github.com/mysteriumnetwork/node/nat/traversal"
"github.com/mysteriumnetwork/node/services/wireguard"
wireguard_connection "github.com/mysteriumnetwork/node/services/wireguard/connection"
"github.com/mysteriumnetwork/node/services/wireguard/key"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
Expand All @@ -53,8 +55,13 @@ type WireguardTunnelSetup interface {
SetSessionName(session string)
}

// wireGuardOptions holds the timing settings of a mobile WireGuard connection.
type wireGuardOptions struct {
	// statsUpdateInterval is the pause between two statistics updates;
	// handshakeTimeout is the timeout handed to the handshake waiter
	// while the connection is starting.
	statsUpdateInterval, handshakeTimeout time.Duration
}

// NewWireGuardConnection creates a new wireguard connection
func NewWireGuardConnection(tunnelSetup WireguardTunnelSetup, ipResolver ip.Resolver, natPinger natPinger) (connection.Connection, error) {
func NewWireGuardConnection(opts wireGuardOptions, device wireguardDevice, ipResolver ip.Resolver, natPinger natPinger, handshakeWaiter wireguard_connection.HandshakeWaiter) (connection.Connection, error) {
privateKey, err := key.GeneratePrivateKey()
if err != nil {
return nil, err
Expand All @@ -66,24 +73,28 @@ func NewWireGuardConnection(tunnelSetup WireguardTunnelSetup, ipResolver ip.Reso
pingerStop: make(chan struct{}),
stateCh: make(chan connection.State, 100),
statisticsCh: make(chan consumer.SessionStatistics, 100),
opts: opts,
device: device,
privateKey: privateKey,
tunnelSetup: tunnelSetup,
ipResolver: ipResolver,
natPinger: natPinger,
handshakeWaiter: handshakeWaiter,
}, nil
}

type wireguardConnection struct {
closeOnce sync.Once
done chan struct{}
pingerStop chan struct{}
statsCheckerStop chan struct{}
stateCh chan connection.State
statisticsCh chan consumer.SessionStatistics
opts wireGuardOptions
privateKey string
tunnelSetup WireguardTunnelSetup
device *device.Device
device wireguardDevice
ipResolver ip.Resolver
natPinger natPinger
handshakeWaiter wireguard_connection.HandshakeWaiter
}

func (c *wireguardConnection) State() <-chan connection.State {
Expand All @@ -94,16 +105,21 @@ func (c *wireguardConnection) Statistics() <-chan consumer.SessionStatistics {
return c.statisticsCh
}

// TODO:(anjmao): Rewrite error handling and cleanup. Currently cleanup is assumed to work only if
// init is done correctly, but if it fails at any other step the user will see a broken state.
// See https://github.com/mysteriumnetwork/node/issues/1499.
func (c *wireguardConnection) Start(options connection.ConnectOptions) (err error) {
var config wireguard.ServiceConfig
err = json.Unmarshal(options.SessionConfig, &config)
if err != nil {
return errors.Wrap(err, "could not parse wireguard session config")
}

c.stateCh <- connection.Connecting

defer func() {
if err != nil {
c.Stop()
}
}()

if config.LocalPort > 0 {
err = c.natPinger.PingProvider(
config.Provider.Endpoint.IP.String(),
Expand All @@ -117,42 +133,16 @@ func (c *wireguardConnection) Start(options connection.ConnectOptions) (err erro
}
}

log.Debug().Msg("Creating tunnel device")
tunDevice, err := newTunnDevice(c.tunnelSetup, config)
if err != nil {
return errors.Wrap(err, "could not create tunnel device")
}

devApi := device.NewDevice(tunDevice, device.NewLogger(device.LogLevelDebug, "[userspace-wg]"))
defer func() {
if err != nil && devApi != nil {
devApi.Close()
}
}()

err = setupWireguardDevice(devApi, c.privateKey, config)
if err != nil {
return errors.Wrap(err, "could not setup device")
}
devApi.Up()
socket, err := peekLookAtSocketFd4(devApi)
if err != nil {
return errors.Wrap(err, "could not get socket")
}
err = c.tunnelSetup.Protect(socket)
if err != nil {
return errors.Wrap(err, "could not protect socket")
if err := c.device.Start(c.privateKey, config); err != nil {
return errors.Wrap(err, "could not start device")
}

c.device = devApi
c.stateCh <- connection.Connecting

go c.updateStatsPeriodically(time.Second)

if err := wireguard.WaitHandshake(c.getDeviceStats, c.done); err != nil {
if err := c.handshakeWaiter.Wait(c.device.Stats, c.opts.handshakeTimeout, c.done); err != nil {
return errors.Wrap(err, "failed to handshake")
}

go c.updateStatsPeriodically(c.opts.statsUpdateInterval)

log.Debug().Msg("Connected successfully")
c.stateCh <- connection.Connected
return nil
Expand All @@ -164,19 +154,16 @@ func (c *wireguardConnection) Wait() error {
}

func (c *wireguardConnection) Stop() {
c.stateCh <- connection.Disconnecting
c.updateStatistics()
if c.device != nil {
c.device.Close()
c.device.Wait()
}
c.stateCh <- connection.NotConnected

close(c.done)
close(c.statsCheckerStop)
close(c.pingerStop)
close(c.stateCh)
close(c.statisticsCh)
c.closeOnce.Do(func() {
c.stateCh <- connection.Disconnecting
c.device.Stop()
c.stateCh <- connection.NotConnected

close(c.done)
close(c.statsCheckerStop)
close(c.pingerStop)
close(c.stateCh)
})
}

func (c *wireguardConnection) GetConfig() (connection.ConsumerConfig, error) {
Expand Down Expand Up @@ -208,8 +195,20 @@ func (c *wireguardConnection) isNoopPinger() bool {
return ok
}

func (c *wireguardConnection) updateStatistics() {
stats, err := c.getDeviceStats()
// updateStatsPeriodically sends connection statistics every `duration` until
// statsCheckerStop is signalled. It closes statisticsCh itself on exit, as
// this loop is what triggers the sends through sendStats.
func (c *wireguardConnection) updateStatsPeriodically(duration time.Duration) {
	for running := true; running; {
		select {
		case <-c.statsCheckerStop:
			running = false
		case <-time.After(duration):
			c.sendStats()
		}
	}
	close(c.statisticsCh)
}

func (c *wireguardConnection) sendStats() {
stats, err := c.device.Stats()
if err != nil {
log.Error().Err(err).Msg("Error updating statistics")
return
Expand All @@ -221,31 +220,69 @@ func (c *wireguardConnection) updateStatistics() {
}
}

func (c *wireguardConnection) getDeviceStats() (*wireguard.Stats, error) {
deviceState, err := wireguard.ParseUserspaceDevice(c.device.IpcGetOperation)
// wireguardDevice is the set of device operations the mobile WireGuard
// connection relies on.
type wireguardDevice interface {
	// Start brings the device up using the given private key and
	// session config.
	Start(privateKey string, config wireguard.ServiceConfig) error
	// Stats returns the current device statistics.
	Stats() (*wireguard.Stats, error)
	// Stop shuts the device down.
	Stop()
}

// newWireguardDevice returns a wireguardDevice bound to the given tunnel
// setup. The underlying device is left unset here.
func newWireguardDevice(tunnelSetup WireguardTunnelSetup) wireguardDevice {
	impl := new(wireguardDeviceImpl)
	impl.tunnelSetup = tunnelSetup
	return impl
}

// wireguardDeviceImpl is the wireguardDevice returned by newWireguardDevice.
type wireguardDeviceImpl struct {
	// tunnelSetup is the tunnel setup supplied to newWireguardDevice.
	tunnelSetup WireguardTunnelSetup
	// device is the underlying device; Stop and Stats check it for nil
	// before using it.
	device *device.Device
}

func (w *wireguardDeviceImpl) Start(privateKey string, config wireguard.ServiceConfig) error {
log.Debug().Msg("Creating tunnel device")
tunDevice, err := w.newTunnDevice(w.tunnelSetup, config)
if err != nil {
return nil, errors.Wrap(err, "could not parse userspace wg device state")
return errors.Wrap(err, "could not create tunnel device")
}
stats, err := wireguard.ParseDevicePeerStats(deviceState)

w.device = device.NewDevice(tunDevice, device.NewLogger(device.LogLevelDebug, "[userspace-wg]"))

err = w.applyConfig(w.device, privateKey, config)
if err != nil {
return nil, errors.Wrap(err, "could not get userspace wg peer stats")
return errors.Wrap(err, "could not setup device configuration")
}
return stats, nil
w.device.Up()
socket, err := peekLookAtSocketFd4(w.device)
if err != nil {
return errors.Wrap(err, "could not get socket")
}
err = w.tunnelSetup.Protect(socket)
if err != nil {
return errors.Wrap(err, "could not protect socket")
}
return nil
}

func (c *wireguardConnection) updateStatsPeriodically(duration time.Duration) {
for {
select {
case <-time.After(duration):
c.updateStatistics()
// Stop closes the underlying device. It does nothing when the device is
// unset.
func (w *wireguardDeviceImpl) Stop() {
	if w.device == nil {
		return
	}
	w.device.Close()
}

case <-c.statsCheckerStop:
return
}
// Stats reads the userspace device state through its IPC get operation and
// turns it into peer statistics. It returns an error when the device is
// unset or when either parsing step fails.
func (w *wireguardDeviceImpl) Stats() (*wireguard.Stats, error) {
	dev := w.device
	if dev == nil {
		return nil, errors.New("device is not started")
	}

	state, err := wireguard.ParseUserspaceDevice(dev.IpcGetOperation)
	if err != nil {
		return nil, errors.Wrap(err, "could not parse userspace wg device state")
	}

	peerStats, err := wireguard.ParseDevicePeerStats(state)
	if err != nil {
		return nil, errors.Wrap(err, "could not get userspace wg peer stats")
	}
	return peerStats, nil
}

func setupWireguardDevice(devApi *device.Device, privateKey string, config wireguard.ServiceConfig) error {
func (w *wireguardDeviceImpl) applyConfig(devApi *device.Device, privateKey string, config wireguard.ServiceConfig) error {
deviceConfig := wireguard.DeviceConfig{
PrivateKey: privateKey,
ListenPort: config.LocalPort,
Expand All @@ -268,7 +305,7 @@ func setupWireguardDevice(devApi *device.Device, privateKey string, config wireg
return nil
}

func newTunnDevice(wgTunnSetup WireguardTunnelSetup, config wireguard.ServiceConfig) (tun.Device, error) {
func (w *wireguardDeviceImpl) newTunnDevice(wgTunnSetup WireguardTunnelSetup, config wireguard.ServiceConfig) (tun.Device, error) {
consumerIP := config.Consumer.IPAddress
prefixLen, _ := consumerIP.Mask.Size()
wgTunnSetup.NewTunnel()
Expand Down

0 comments on commit 26ba14c

Please sign in to comment.