From 5a68173fe39345b8473e04bfa67cae5a13f6ca7f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 15 Mar 2023 16:18:51 +0200 Subject: [PATCH] Add support for unix sockets in appservice module Closes #116 Co-authored-by: Boris Rybalkin --- appservice/appservice.go | 93 ++++++++++++++++++++++++++++------- appservice/http.go | 38 +++++++++++--- bridge/bridge.go | 9 ++-- bridge/bridgeconfig/config.go | 2 +- bridge/crypto.go | 14 ++---- client.go | 2 +- url.go | 2 +- 7 files changed, 116 insertions(+), 44 deletions(-) diff --git a/appservice/appservice.go b/appservice/appservice.go index 17a08d88..099e4b27 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -7,10 +7,14 @@ package appservice import ( + "context" "fmt" + "net" "net/http" "net/http/cookiejar" + "net/url" "os" + "strings" "sync" "syscall" "time" @@ -20,6 +24,7 @@ import ( "github.com/rs/zerolog" "golang.org/x/net/publicsuffix" "gopkg.in/yaml.v3" + "maunium.net/go/maulogger/v2/maulogadapt" "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" @@ -100,7 +105,7 @@ type StateStore interface { // It also serves as the appservice instance struct. type AppService struct { HomeserverDomain string - HomeserverURL string + hsURLForClient *url.URL Host HostConfig Registration *Registration @@ -178,6 +183,14 @@ func (hc *HostConfig) Address() string { return fmt.Sprintf("%s:%d", hc.Hostname, hc.Port) } +func (hc *HostConfig) IsUnixSocket() bool { + return strings.HasPrefix(hc.Hostname, "/") +} + +func (hc *HostConfig) IsConfigured() bool { + return hc.IsUnixSocket() || hc.Port != 0 +} + // Save saves this config into a file at the given path. func (as *AppService) Save(path string) error { data, err := yaml.Marshal(as) @@ -249,29 +262,73 @@ func (as *AppService) BotIntent() *IntentAPI { return as.botIntent } +func (as *AppService) SetHomeserverURL(homeserverURL string) error { + parsedURL, err := url.Parse(homeserverURL) + if err != nil { + return err + } + + as.hsURLForClient = parsedURL + if as.hsURLForClient.Scheme == "unix" { + as.hsURLForClient.Scheme = "http" + as.hsURLForClient.Host = "unix" + as.hsURLForClient.Path = "" + } else if as.hsURLForClient.Scheme == "" { + as.hsURLForClient.Scheme = "https" + } + as.hsURLForClient.RawPath = parsedURL.EscapedPath() + + jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) + as.HTTPClient = &http.Client{Timeout: 180 * time.Second, Jar: jar} + if parsedURL.Scheme == "unix" { + as.HTTPClient.Transport = &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", parsedURL.Path) + }, + } + } + return nil +} + +func (as *AppService) NewMautrixClient(userID id.UserID) *mautrix.Client { + client := &mautrix.Client{ + HomeserverURL: as.hsURLForClient, + UserID: userID, + SetAppServiceUserID: true, + AccessToken: as.Registration.AppToken, + UserAgent: as.UserAgent, + StateStore: as.StateStore, + Log: as.Log.With().Str("as_user_id", userID.String()).Logger(), + Client: as.HTTPClient, + DefaultHTTPRetries: as.DefaultHTTPRetries, + } + client.Logger = maulogadapt.ZeroAsMau(&client.Log) + return client +} + +func (as *AppService) NewExternalMautrixClient(userID id.UserID, token string, homeserverURL string) (*mautrix.Client, error) { + client := as.NewMautrixClient(userID) + client.AccessToken = token + if homeserverURL != "" { + client.Client = &http.Client{Timeout: 180 * time.Second} + var err error + client.HomeserverURL, err = mautrix.ParseAndNormalizeBaseURL(homeserverURL) + if err != nil { + return nil, err + } + } + return client, nil +} + func (as *AppService) makeClient(userID id.UserID) *mautrix.Client { as.clientsLock.Lock() defer as.clientsLock.Unlock() client, ok := as.clients[userID] - if ok { - return client - } - - client, err := mautrix.NewClient(as.HomeserverURL, userID, as.Registration.AppToken) - if err != nil { - as.Log.Error().Err(err).Msg("Failed to create mautrix client instance") - return nil + if !ok { + client = as.NewMautrixClient(userID) + as.clients[userID] = client } - client.UserAgent = as.UserAgent - client.Syncer = nil - client.Store = nil - client.StateStore = as.StateStore - client.SetAppServiceUserID = true - client.Log = as.Log.With().Str("as_user_id", client.UserID.String()).Logger() - client.Client = as.HTTPClient - client.DefaultHTTPRetries = as.DefaultHTTPRetries - as.clients[userID] = client return client } diff --git a/appservice/http.go b/appservice/http.go index 2a35dd99..06ac7788 100644 --- a/appservice/http.go +++ b/appservice/http.go @@ -11,8 +11,10 @@ import ( "encoding/json" "errors" "io" + "net" "net/http" "strings" + "syscall" "time" "github.com/gorilla/mux" @@ -25,17 +27,15 @@ import ( // Start starts the HTTP server that listens for calls from the Matrix homeserver. func (as *AppService) Start() { - var err error as.server = &http.Server{ - Addr: as.Host.Address(), Handler: as.Router, } - if len(as.Host.TLSCert) == 0 || len(as.Host.TLSKey) == 0 { - as.Log.Info().Str("address", as.Host.Address()).Msg("Starting HTTP listener") - err = as.server.ListenAndServe() + var err error + if as.Host.IsUnixSocket() { + err = as.listenUnix() } else { - as.Log.Info().Str("address", as.Host.Address()).Msg("Starting HTTP listener with TLS") - err = as.server.ListenAndServeTLS(as.Host.TLSCert, as.Host.TLSKey) + as.server.Addr = as.Host.Address() + err = as.listenTCP() } if err != nil && !errors.Is(err, http.ErrServerClosed) { as.Log.Error().Err(err).Msg("Error in HTTP listener") @@ -44,6 +44,30 @@ func (as *AppService) Start() { } } +func (as *AppService) listenUnix() error { + socket := as.Host.Hostname + _ = syscall.Unlink(socket) + defer func() { + _ = syscall.Unlink(socket) + }() + listener, err := net.Listen("unix", socket) + if err != nil { + return err + } + as.Log.Info().Str("socket", socket).Msg("Starting unix socket HTTP listener") + return as.server.Serve(listener) +} + +func (as *AppService) listenTCP() error { + if len(as.Host.TLSCert) == 0 || len(as.Host.TLSKey) == 0 { + as.Log.Info().Str("address", as.server.Addr).Msg("Starting HTTP listener") + return as.server.ListenAndServe() + } else { + as.Log.Info().Str("address", as.server.Addr).Msg("Starting HTTP listener with TLS") + return as.server.ListenAndServeTLS(as.Host.TLSCert, as.Host.TLSKey) + } +} + func (as *AppService) Stop() { if as.server == nil { return diff --git a/bridge/bridge.go b/bridge/bridge.go index f4ddbd63..c327c933 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -301,7 +301,7 @@ func (br *Bridge) ensureConnection() { os.Exit(17) } - if br.SpecVersions.UnstableFeatures["fi.mau.msc2659"] && br.AS.Host.Port != 0 { + if br.SpecVersions.UnstableFeatures["fi.mau.msc2659"] && br.AS.Host.IsConfigured() { txnID := br.Bot.TxnID() resp, err := br.Bot.AppservicePing(br.Config.AppService.ID, txnID) if err != nil { @@ -505,7 +505,7 @@ func (br *Bridge) init() { br.Crypto = NewCryptoHelper(br) - hsURL := br.AS.HomeserverURL + hsURL := br.Config.Homeserver.Address if br.Config.Homeserver.PublicAddress != "" { hsURL = br.Config.Homeserver.PublicAddress } @@ -546,13 +546,12 @@ func (br *Bridge) start() { br.LogDBUpgradeErrorAndExit("matrix_state", err) } - if br.AS.Host.Port != 0 { + if br.AS.Host.IsConfigured() { br.ZLog.Debug().Msg("Starting application service HTTP server") go br.AS.Start() } else { - br.ZLog.Debug().Msg("Appservice port not configured, not starting HTTP server") + br.ZLog.Debug().Msg("Appservice config doesn't have port nor unix socket path, not starting HTTP server") } - br.ZLog.Debug().Msg("Checking connection to homeserver") br.ensureConnection() go br.fetchMediaConfig() diff --git a/bridge/bridgeconfig/config.go b/bridge/bridgeconfig/config.go index 3b0a6ea8..ed0eb0d2 100644 --- a/bridge/bridgeconfig/config.go +++ b/bridge/bridgeconfig/config.go @@ -102,7 +102,7 @@ func (config *BaseConfig) GenerateRegistration() *appservice.Registration { func (config *BaseConfig) MakeAppService() *appservice.AppService { as := appservice.Create() as.HomeserverDomain = config.Homeserver.Domain - as.HomeserverURL = config.Homeserver.Address + _ = as.SetHomeserverURL(config.Homeserver.Address) as.Host.Hostname = config.AppService.Hostname as.Host.Port = config.AppService.Port as.DefaultHTTPRetries = 4 diff --git a/bridge/crypto.go b/bridge/crypto.go index e0d249d9..7dc42c1b 100644 --- a/bridge/crypto.go +++ b/bridge/crypto.go @@ -136,23 +136,15 @@ func (helper *CryptoHelper) loginBot() (*mautrix.Client, bool, error) { if len(deviceID) > 0 { helper.log.Debug().Str("device_id", deviceID.String()).Msg("Found existing device ID for bot in database") } - client, err := mautrix.NewClient(helper.bridge.AS.HomeserverURL, "", "") - if err != nil { - return nil, deviceID != "", fmt.Errorf("failed to initialize client: %w", err) - } - client.StateStore = helper.bridge.AS.StateStore - client.Log = helper.log.With().Str("as_user_id", helper.bridge.AS.BotMXID().String()).Logger() - client.Client = helper.bridge.AS.HTTPClient - client.DefaultHTTPRetries = helper.bridge.AS.DefaultHTTPRetries + // Create a new client instance with the default AS settings (including as_token), + // the Login call will then override the access token in the client. + client := helper.bridge.AS.NewMautrixClient(helper.bridge.AS.BotMXID()) flows, err := client.GetLoginFlows() if err != nil { return nil, deviceID != "", fmt.Errorf("failed to get supported login flows: %w", err) } else if !flows.HasFlow(mautrix.AuthTypeAppservice) { return nil, deviceID != "", fmt.Errorf("homeserver does not support appservice login") } - // We set the API token to the AS token here to authenticate the appservice login - // It'll get overridden after the login - client.AccessToken = helper.bridge.AS.Registration.AppToken resp, err := client.Login(&mautrix.ReqLogin{ Type: mautrix.AuthTypeAppservice, Identifier: mautrix.UserIdentifier{ diff --git a/client.go b/client.go index f6e8b2d1..2923eaea 100644 --- a/client.go +++ b/client.go @@ -2001,7 +2001,7 @@ func (cli *Client) TxnID() string { // NewClient creates a new Matrix Client ready for syncing func NewClient(homeserverURL string, userID id.UserID, accessToken string) (*Client, error) { - hsURL, err := parseAndNormalizeBaseURL(homeserverURL) + hsURL, err := ParseAndNormalizeBaseURL(homeserverURL) if err != nil { return nil, err } diff --git a/url.go b/url.go index 0011d4a4..5ea03f1d 100644 --- a/url.go +++ b/url.go @@ -13,7 +13,7 @@ import ( "strings" ) -func parseAndNormalizeBaseURL(homeserverURL string) (*url.URL, error) { +func ParseAndNormalizeBaseURL(homeserverURL string) (*url.URL, error) { hsURL, err := url.Parse(homeserverURL) if err != nil { return nil, err