Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 65 additions & 128 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"os/signal"
"runtime"
"strconv"
"strings"
"syscall"
"time"

Expand All @@ -25,6 +26,34 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)

// Helper function to format endpoints correctly
func formatEndpoint(endpoint string) string {
if endpoint == "" {
return ""
}
// Check if it's already a valid host:port that SplitHostPort can parse (e.g., [::1]:8080 or 1.2.3.4:8080)
_, _, err := net.SplitHostPort(endpoint)
if err == nil {
return endpoint // Already valid, no change needed
}

// If it failed, it might be our malformed "ipv6:port" string. Let's check and fix it.
lastColon := strings.LastIndex(endpoint, ":")
if lastColon > 0 { // Ensure there is a colon and it's not the first character
hostPart := endpoint[:lastColon]
// Check if the host part is a literal IPv6 address
if ip := net.ParseIP(hostPart); ip != nil && ip.To4() == nil {
// It is! Reformat it with brackets.
portPart := endpoint[lastColon+1:]
return fmt.Sprintf("[%s]:%s", hostPart, portPart)
}
}

// If it's not the specific malformed case, return it as is.
return endpoint
}


func main() {
// Check if we're running as a Windows service
if isWindowsService() {
Expand Down Expand Up @@ -498,29 +527,6 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
go keepSendingUDPHolePunch(legacyHolePunchData.Endpoint, id, sourcePort, legacyHolePunchData.ServerPubKey)
})

olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data)

jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}

if err := json.Unmarshal(jsonData, &holePunchData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return
}

// Create a new stopHolepunch channel for the new set of goroutines
stopHolepunch = make(chan struct{})

// Start a single hole punch goroutine for all exit nodes
logger.Info("Starting hole punch for %d exit nodes", len(holePunchData.ExitNodes))
go keepSendingUDPHolePunchToMultipleExitNodes(holePunchData.ExitNodes, id, sourcePort)
})

// Register handlers for different message types
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data)

Expand Down Expand Up @@ -558,98 +564,55 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
}

tdev, err = func() (tun.Device, error) {
tunFdStr := os.Getenv(ENV_WG_TUN_FD)

// if on macOS, call findUnusedUTUN to get a new utun device
if runtime.GOOS == "darwin" {
interfaceName, err := findUnusedUTUN()
if err != nil {
return nil, err
}
return tun.CreateTUN(interfaceName, mtuInt)
}

if tunFdStr == "" {
return tun.CreateTUN(interfaceName, mtuInt)
if tunFdStr := os.Getenv(ENV_WG_TUN_FD); tunFdStr != "" {
return createTUNFromFD(tunFdStr, mtuInt)
}

return createTUNFromFD(tunFdStr, mtuInt)
return tun.CreateTUN(interfaceName, mtuInt)
}()

if err != nil {
logger.Error("Failed to create TUN device: %v", err)
return
}

realInterfaceName, err2 := tdev.Name()
if err2 == nil {
if realInterfaceName, err2 := tdev.Name(); err2 == nil {
interfaceName = realInterfaceName
}

// open UAPI file (or use supplied fd)
fileUAPI, err := func() (*os.File, error) {
uapiFdStr := os.Getenv(ENV_WG_UAPI_FD)
if uapiFdStr == "" {
return uapiOpen(interfaceName)
}

// use supplied fd

fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
if err != nil {
return nil, err
if uapiFdStr := os.Getenv(ENV_WG_UAPI_FD); uapiFdStr != "" {
fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
if err != nil { return nil, err }
return os.NewFile(uintptr(fd), ""), nil
}

return os.NewFile(uintptr(fd), ""), nil
return uapiOpen(interfaceName)
}()
if err != nil {
logger.Error("UAPI listen error: %v", err)
os.Exit(1)
return
}

dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(
mapToWireGuardLogLevel(loggerLevel),
"wireguard: ",
))

errs := make(chan error)
if err != nil { logger.Error("UAPI listen error: %v", err); os.Exit(1); return }

dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))

uapiListener, err = uapiListen(interfaceName, fileUAPI)
if err != nil {
logger.Error("Failed to listen on uapi socket: %v", err)
os.Exit(1)
}
if err != nil { logger.Error("Failed to listen on uapi socket: %v", err); os.Exit(1) }

go func() {
for {
conn, err := uapiListener.Accept()
if err != nil {
errs <- err
return
}
if err != nil { return }
go dev.IpcHandle(conn)
}
}()

logger.Info("UAPI listener started")

// Bring up the device
err = dev.Up()
if err != nil {
logger.Error("Failed to bring up WireGuard device: %v", err)
}

// configure the interface
err = ConfigureInterface(realInterfaceName, wgData)
if err != nil {
logger.Error("Failed to configure interface: %v", err)
}

// Set tunnel IP in HTTP server
if httpServer != nil {
httpServer.SetTunnelIP(wgData.TunnelIP)
}
if err = dev.Up(); err != nil { logger.Error("Failed to bring up WireGuard device: %v", err) }
if err = ConfigureInterface(interfaceName, wgData); err != nil { logger.Error("Failed to configure interface: %v", err) }
if httpServer != nil { httpServer.SetTunnelIP(wgData.TunnelIP) }

peerMonitor = peermonitor.NewPeerMonitor(
func(siteID int, connected bool, rtt time.Duration) {
Expand Down Expand Up @@ -680,28 +643,18 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
doHolepunch,
)

// loop over the sites and call ConfigurePeer for each one
for _, site := range wgData.Sites {
for i := range wgData.Sites {
site := &wgData.Sites[i] // Use a pointer to modify the struct in the slice
if httpServer != nil {
httpServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false)
}
err = ConfigurePeer(dev, site, privateKey, endpoint)
if err != nil {
logger.Error("Failed to configure peer: %v", err)
return
}

err = addRouteForServerIP(site.ServerIP, interfaceName)
if err != nil {
logger.Error("Failed to add route for peer: %v", err)
return
}
// Format the endpoint before configuring the peer.
site.Endpoint = formatEndpoint(site.Endpoint)

// Add routes for remote subnets
if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil {
logger.Error("Failed to add routes for remote subnets: %v", err)
return
}
if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil { logger.Error("Failed to configure peer: %v", err); return }
if err := addRouteForServerIP(site.ServerIP, interfaceName); err != nil { logger.Error("Failed to add route for peer: %v", err); return }
if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil { logger.Error("Failed to add routes for remote subnets: %v", err); return }

logger.Info("Configured peer %s", site.PublicKey)
}
Expand Down Expand Up @@ -748,12 +701,11 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
break
}
}

if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
logger.Error("Failed to update peer: %v", err)
// Send error response if needed
return
}

// Format the endpoint before updating the peer.
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)

if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { logger.Error("Failed to update peer: %v", err); return }

// Remove old remote subnet routes if they changed
if oldRemoteSubnets != siteConfig.RemoteSubnets {
Expand All @@ -771,12 +723,8 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {

// Update successful
logger.Info("Successfully updated peer for site %d", updateData.SiteId)
// If this is part of a WgData structure, update it
for i, site := range wgData.Sites {
if site.SiteId == updateData.SiteId {
wgData.Sites[i] = siteConfig
break
}
for i := range wgData.Sites {
if wgData.Sites[i].SiteId == updateData.SiteId { wgData.Sites[i] = siteConfig; break }
}
} else {
logger.Error("WireGuard device not initialized")
Expand Down Expand Up @@ -811,23 +759,12 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {

// Add the peer to WireGuard
if dev != nil {
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
logger.Error("Failed to add peer: %v", err)
return
}

// Add route for the new peer
err = addRouteForServerIP(siteConfig.ServerIP, interfaceName)
if err != nil {
logger.Error("Failed to add route for new peer: %v", err)
return
}
// Format the endpoint before adding the new peer.
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)

// Add routes for remote subnets
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil {
logger.Error("Failed to add routes for remote subnets: %v", err)
return
}
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { logger.Error("Failed to add peer: %v", err); return }
if err := addRouteForServerIP(siteConfig.ServerIP, interfaceName); err != nil { logger.Error("Failed to add route for new peer: %v", err); return }
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil { logger.Error("Failed to add routes for remote subnets: %v", err); return }

// Add successful
logger.Info("Successfully added peer for site %d", addData.SiteId)
Expand Down