Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Track "active" WhatsApp users and implement blocking when reaching the limit #323

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ type Config struct {
Avatar string `yaml:"avatar"`
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: Add a notice to the admin room

} `yaml:"bot"`

Limits struct {
MaxPuppetLimit uint `yaml:"max_puppet_limit"`
MinPuppetActiveDays uint `yaml:"min_puppet_activity_days"`
PuppetInactivityDays uint `yaml:"puppet_inactivity_days"`
BlockOnLimitReached bool `yaml:"block_on_limit_reached"`
} `yaml:"limits"`

ASToken string `yaml:"as_token"`
HSToken string `yaml:"hs_token"`
} `yaml:"appservice"`
Expand Down Expand Up @@ -93,6 +100,10 @@ func (config *Config) CanDoublePuppet(userID id.UserID) bool {
func (config *Config) setDefaults() {
config.AppService.Database.MaxOpenConns = 20
config.AppService.Database.MaxIdleConns = 2
config.AppService.Limits.MaxPuppetLimit = 0
config.AppService.Limits.MinPuppetActiveDays = 0
config.AppService.Limits.PuppetInactivityDays = 30
config.AppService.Limits.BlockOnLimitReached = false
config.WhatsApp.OSName = "Mautrix-WhatsApp bridge"
config.WhatsApp.BrowserName = "mx-wa"
config.Bridge.setDefaults()
Expand Down
49 changes: 37 additions & 12 deletions database/puppet.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (pq *PuppetQuery) New() *Puppet {
}

func (pq *PuppetQuery) GetAll() (puppets []*Puppet) {
rows, err := pq.db.Query("SELECT jid, avatar, avatar_url, displayname, name_quality, custom_mxid, access_token, next_batch, enable_presence, enable_receipts FROM puppet")
rows, err := pq.db.Query("SELECT jid, avatar, avatar_url, displayname, name_quality, custom_mxid, access_token, next_batch, enable_presence, enable_receipts, first_activity_ts, last_activity_ts FROM puppet")
if err != nil || rows == nil {
return nil
}
Expand All @@ -54,23 +54,23 @@ func (pq *PuppetQuery) GetAll() (puppets []*Puppet) {
}

func (pq *PuppetQuery) Get(jid whatsapp.JID) *Puppet {
row := pq.db.QueryRow("SELECT jid, avatar, avatar_url, displayname, name_quality, custom_mxid, access_token, next_batch, enable_presence, enable_receipts FROM puppet WHERE jid=$1", jid)
row := pq.db.QueryRow("SELECT jid, avatar, avatar_url, displayname, name_quality, custom_mxid, access_token, next_batch, enable_presence, enable_receipts, first_activity_ts, last_activity_ts FROM puppet WHERE jid=$1", jid)
if row == nil {
return nil
}
return pq.New().Scan(row)
}

func (pq *PuppetQuery) GetByCustomMXID(mxid id.UserID) *Puppet {
row := pq.db.QueryRow("SELECT jid, avatar, avatar_url, displayname, name_quality, custom_mxid, access_token, next_batch, enable_presence, enable_receipts FROM puppet WHERE custom_mxid=$1", mxid)
row := pq.db.QueryRow("SELECT jid, avatar, avatar_url, displayname, name_quality, custom_mxid, access_token, next_batch, enable_presence, enable_receipts, first_activity_ts, last_activity_ts FROM puppet WHERE custom_mxid=$1", mxid)
if row == nil {
return nil
}
return pq.New().Scan(row)
}

func (pq *PuppetQuery) GetAllWithCustomMXID() (puppets []*Puppet) {
rows, err := pq.db.Query("SELECT jid, avatar, avatar_url, displayname, name_quality, custom_mxid, access_token, next_batch, enable_presence, enable_receipts FROM puppet WHERE custom_mxid<>''")
rows, err := pq.db.Query("SELECT jid, avatar, avatar_url, displayname, name_quality, custom_mxid, access_token, next_batch, enable_presence, enable_receipts, first_activity_ts, last_activity_ts FROM puppet WHERE custom_mxid<>''")
if err != nil || rows == nil {
return nil
}
Expand All @@ -91,18 +91,20 @@ type Puppet struct {
Displayname string
NameQuality int8

CustomMXID id.UserID
AccessToken string
NextBatch string
EnablePresence bool
EnableReceipts bool
CustomMXID id.UserID
AccessToken string
NextBatch string
EnablePresence bool
EnableReceipts bool
FirstActivityTs int64
LastActivityTs int64
}

func (puppet *Puppet) Scan(row Scannable) *Puppet {
var displayname, avatar, avatarURL, customMXID, accessToken, nextBatch sql.NullString
var quality sql.NullInt64
var quality, firstActivityTs, lastActivityTs sql.NullInt64
var enablePresence, enableReceipts sql.NullBool
err := row.Scan(&puppet.JID, &avatar, &avatarURL, &displayname, &quality, &customMXID, &accessToken, &nextBatch, &enablePresence, &enableReceipts)
err := row.Scan(&puppet.JID, &avatar, &avatarURL, &displayname, &quality, &customMXID, &accessToken, &nextBatch, &enablePresence, &enableReceipts, &firstActivityTs, &lastActivityTs)
if err != nil {
if err != sql.ErrNoRows {
puppet.log.Errorln("Database scan failed:", err)
Expand All @@ -118,6 +120,8 @@ func (puppet *Puppet) Scan(row Scannable) *Puppet {
puppet.NextBatch = nextBatch.String
puppet.EnablePresence = enablePresence.Bool
puppet.EnableReceipts = enableReceipts.Bool
puppet.FirstActivityTs = firstActivityTs.Int64
puppet.LastActivityTs = lastActivityTs.Int64
return puppet
}

Expand All @@ -133,6 +137,27 @@ func (puppet *Puppet) Update() {
_, err := puppet.db.Exec("UPDATE puppet SET displayname=$1, name_quality=$2, avatar=$3, avatar_url=$4, custom_mxid=$5, access_token=$6, next_batch=$7, enable_presence=$8, enable_receipts=$9 WHERE jid=$10",
puppet.Displayname, puppet.NameQuality, puppet.Avatar, puppet.AvatarURL.String(), puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch, puppet.EnablePresence, puppet.EnableReceipts, puppet.JID)
if err != nil {
puppet.log.Warnfln("Failed to update %s->%s: %v", puppet.JID, err)
puppet.log.Warnfln("Failed to update %s: %v", puppet.JID, err)
}
}

func (puppet *Puppet) UpdateActivityTs(ts uint64) {
var signedTs = int64(ts)
if puppet.LastActivityTs > signedTs {
return
}
puppet.log.Debugfln("Updating activity time for %s to %d", puppet.JID, signedTs)
puppet.LastActivityTs = signedTs
_, err := puppet.db.Exec("UPDATE puppet SET last_activity_ts=$1 WHERE jid=$2", puppet.LastActivityTs, puppet.JID)
if err != nil {
puppet.log.Warnfln("Failed to update last_activity_ts for %s: %v", puppet.JID, err)
}

if puppet.FirstActivityTs == 0 {
puppet.FirstActivityTs = signedTs
_, err = puppet.db.Exec("UPDATE puppet SET first_activity_ts=$1 WHERE jid=$2 AND first_activity_ts is NULL", puppet.FirstActivityTs, puppet.JID)
if err != nil {
puppet.log.Warnfln("Failed to update first_activity_ts %s: %v", puppet.JID, err)
}
}
}
16 changes: 16 additions & 0 deletions database/upgrades/2021-07-07-puppet-activity.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package upgrades

import (
"database/sql"
)

func init() {
upgrades[21] = upgrade{"Add ", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE puppet ADD COLUMN first_activity_ts BIGINT`)
if err != nil {
return err
}
_, err = tx.Exec(`ALTER TABLE puppet ADD COLUMN last_activity_ts BIGINT`)
return err
}}
}
2 changes: 1 addition & 1 deletion database/upgrades/upgrades.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ type upgrade struct {
fn upgradeFunc
}

const NumberOfUpgrades = 21
const NumberOfUpgrades = 22

var upgrades [NumberOfUpgrades]upgrade

Expand Down
11 changes: 11 additions & 0 deletions example-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,17 @@ appservice:
as_token: "This value is generated when generating the registration"
hs_token: "This value is generated when generating the registration"

# Limit usage of the bridge
limits:
# The maximum number of bridge puppets that can be "active" before the limit is reached
max_puppet_limit: 0
# The minimum amount of days a puppet must be active for before they are considered "active".
min_puppet_activity_days: 0
# The number of days after a puppets last activity where they are considered inactive again.
puppet_inactivity_days: 30
# Should the bridge block traffic when a limit has been reached
block_on_limit_reached: false

metrics:
# Whether or not to enable prometheus metrics
enabled: false
Expand Down
43 changes: 40 additions & 3 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,16 @@ import (
"maunium.net/go/mautrix-whatsapp/database/upgrades"
)

const ONE_DAY_S = 24 * 60 * 60

var (
// These are static
Name = "mautrix-whatsapp"
URL = "https://github.com/tulir/mautrix-whatsapp"
// This is changed when making a release
Version = "0.1.7"
// This is filled by init()
WAVersion = ""
WAVersion = ""
VersionString = ""
// These are filled at build time with the -X linker flag
Tag = "unknown"
Expand Down Expand Up @@ -148,6 +150,7 @@ type Bridge struct {
Relaybot *User
Crypto Crypto
Metrics *MetricsHandler
PuppetActivity *PuppetActivity

usersByMXID map[id.UserID]*User
usersByJID map[whatsapp.JID]*User
Expand Down Expand Up @@ -182,6 +185,10 @@ func NewBridge() *Bridge {
portalsByJID: make(map[database.PortalKey]*Portal),
puppets: make(map[whatsapp.JID]*Puppet),
puppetsByCustomMXID: make(map[id.UserID]*Puppet),
PuppetActivity: &PuppetActivity{
currentUserCount: 0,
isBlocked: false,
},
}

var err error
Expand Down Expand Up @@ -222,7 +229,6 @@ func (bridge *Bridge) Init() {
}
_, _ = bridge.AS.Init()
bridge.Bot = bridge.AS.BotIntent()

bridge.Log = log.Create()
bridge.Config.Logging.Configure(bridge.Log)
log.DefaultLogger = bridge.Log.(*log.BasicLogger)
Expand Down Expand Up @@ -270,7 +276,37 @@ func (bridge *Bridge) Init() {
bridge.MatrixHandler = NewMatrixHandler(bridge)
bridge.Formatter = NewFormatter(bridge)
bridge.Crypto = NewCryptoHelper(bridge)
bridge.Metrics = NewMetricsHandler(bridge.Config.Metrics.Listen, bridge.Log.Sub("Metrics"), bridge.DB)
bridge.Metrics = NewMetricsHandler(bridge.Config.Metrics.Listen, bridge.Log.Sub("Metrics"), bridge.DB, bridge.PuppetActivity)
}

func (mh *Bridge) UpdateActivePuppetCount() {
mh.Log.Debugfln("Updating active puppet count")

var minActivityTime = int64(ONE_DAY_S * mh.Config.AppService.Limits.MinPuppetActiveDays)
var maxActivityTime = int64(ONE_DAY_S * mh.Config.AppService.Limits.PuppetInactivityDays)
var activePuppetCount uint
var firstActivityTs, lastActivityTs int64

rows, active_err := mh.DB.Query("SELECT first_activity_ts, last_activity_ts FROM puppet WHERE first_activity_ts is not NULL")
if active_err != nil {
mh.Log.Warnln("Failed to scan number of active puppets:", active_err)
} else {
defer rows.Close()
for rows.Next() {
rows.Scan(&firstActivityTs, &lastActivityTs)
var secondsOfActivity = lastActivityTs - firstActivityTs
var isInactive = time.Now().Unix()-lastActivityTs > maxActivityTime
if !isInactive && secondsOfActivity > minActivityTime && secondsOfActivity < maxActivityTime {
activePuppetCount++
}
}
if mh.Config.AppService.Limits.BlockOnLimitReached {
mh.PuppetActivity.isBlocked = mh.Config.AppService.Limits.MaxPuppetLimit < activePuppetCount
}
mh.Log.Debugfln("Current active puppet count is %d (max %d)", activePuppetCount, mh.Config.AppService.Limits.MaxPuppetLimit)
mh.PuppetActivity.currentUserCount = activePuppetCount
}

}

func (bridge *Bridge) Start() {
Expand Down Expand Up @@ -303,6 +339,7 @@ func (bridge *Bridge) Start() {
go bridge.Crypto.Start()
}
go bridge.StartUsers()
bridge.UpdateActivePuppetCount()
if bridge.Config.Metrics.Enabled {
go bridge.Metrics.Start()
}
Expand Down
36 changes: 27 additions & 9 deletions metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@ import (
)

type MetricsHandler struct {
db *database.Database
server *http.Server
log log.Logger
db *database.Database
server *http.Server
log log.Logger
puppetActivity *PuppetActivity

running bool
ctx context.Context
Expand All @@ -50,6 +51,8 @@ type MetricsHandler struct {
countCollection prometheus.Histogram
disconnections *prometheus.CounterVec
puppetCount prometheus.Gauge
activePuppetCount prometheus.Gauge
bridgeBlocked prometheus.Gauge
userCount prometheus.Gauge
messageCount prometheus.Gauge
portalCount *prometheus.GaugeVec
Expand All @@ -67,17 +70,17 @@ type MetricsHandler struct {
bufferLength *prometheus.GaugeVec
}

func NewMetricsHandler(address string, log log.Logger, db *database.Database) *MetricsHandler {
func NewMetricsHandler(address string, log log.Logger, db *database.Database, puppetActivity *PuppetActivity) *MetricsHandler {
portalCount := promauto.NewGaugeVec(prometheus.GaugeOpts{
Name: "whatsapp_portals_total",
Help: "Number of portal rooms on Matrix",
}, []string{"type", "encrypted"})
return &MetricsHandler{
db: db,
server: &http.Server{Addr: address, Handler: promhttp.Handler()},
log: log,
running: false,

db: db,
server: &http.Server{Addr: address, Handler: promhttp.Handler()},
log: log,
running: false,
puppetActivity: puppetActivity,
matrixEventHandling: promauto.NewHistogramVec(prometheus.HistogramOpts{
Name: "matrix_event",
Help: "Time spent processing Matrix events",
Expand All @@ -103,6 +106,14 @@ func NewMetricsHandler(address string, log log.Logger, db *database.Database) *M
Name: "whatsapp_puppets_total",
Help: "Number of WhatsApp users bridged into Matrix",
}),
activePuppetCount: promauto.NewGauge(prometheus.GaugeOpts{
Name: "whatsapp_active_puppets_total",
Help: "Number of active WhatsApp users bridged into Matrix",
}),
bridgeBlocked: promauto.NewGauge(prometheus.GaugeOpts{
Half-Shot marked this conversation as resolved.
Show resolved Hide resolved
Name: "whatsapp_bridge_blocked",
Help: "Is the bridge currently blocking messages",
}),
userCount: promauto.NewGauge(prometheus.GaugeOpts{
Name: "whatsapp_users_total",
Help: "Number of Matrix users using the bridge",
Expand Down Expand Up @@ -238,6 +249,13 @@ func (mh *MetricsHandler) updateStats() {
mh.puppetCount.Set(float64(puppetCount))
}

mh.activePuppetCount.Set(float64(mh.puppetActivity.currentUserCount))
if mh.puppetActivity.isBlocked {
mh.bridgeBlocked.Set(1)
} else {
mh.bridgeBlocked.Set(0)
}

var userCount int
err = mh.db.QueryRowContext(mh.ctx, `SELECT COUNT(*) FROM "user"`).Scan(&userCount)
if err != nil {
Expand Down
16 changes: 13 additions & 3 deletions portal.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,10 @@ func (portal *Portal) handleMessage(msg PortalMessage, isBackfill bool) {
portal.log.Warnln("handleMessage called even though portal.MXID is empty")
return
}
if portal.bridge.PuppetActivity.isBlocked {
portal.log.Warnln("Bridge is blocking messages")
return
}
var triedToHandle bool
var trackMessageCallback func()
dataType := reflect.TypeOf(msg.data)
Expand Down Expand Up @@ -1393,6 +1397,9 @@ func (portal *Portal) HandleTextMessage(source *User, message whatsapp.TextMessa
} else {
portal.finishHandling(source, message.Info.Source, resp.EventID)
}
sender := portal.bridge.GetPuppetByJID(message.Info.SenderJid)
sender.UpdateActivityTs(message.Info.Timestamp)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want to be careful not to store historical message times.

portal.bridge.UpdateActivePuppetCount()
return true
}

Expand Down Expand Up @@ -2236,9 +2243,12 @@ func (portal *Portal) sendDeliveryReceipt(eventID id.EventID) {
}

func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event) {
if !portal.HasRelaybot() && (
(portal.IsPrivateChat() && sender.JID != portal.Key.Receiver) ||
portal.sendMatrixConnectionError(sender, evt.ID)) {
if portal.bridge.PuppetActivity.isBlocked {
portal.log.Warnln("Bridge is blocking messages")
return
}
if !portal.HasRelaybot() && ((portal.IsPrivateChat() && sender.JID != portal.Key.Receiver) ||
portal.sendMatrixConnectionError(sender, evt.ID)) {
return
}
portal.log.Debugfln("Received event %s", evt.ID)
Expand Down
5 changes: 5 additions & 0 deletions puppet.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ import (
"maunium.net/go/mautrix-whatsapp/database"
)

type PuppetActivity struct {
currentUserCount uint
isBlocked bool
}

var userIDRegex *regexp.Regexp

func (bridge *Bridge) ParsePuppetMXID(mxid id.UserID) (whatsapp.JID, bool) {
Expand Down