diff --git a/custompuppet.go b/custompuppet.go index ebb4061..bf5619a 100644 --- a/custompuppet.go +++ b/custompuppet.go @@ -72,6 +72,8 @@ func (user *User) CustomIntent() *appservice.IntentAPI { } func (user *User) tryAutomaticDoublePuppeting() { + user.autoDoublePuppetLock.Lock() + defer user.autoDoublePuppetLock.Unlock() if !user.bridge.Config.CanAutoDoublePuppet(user.MXID) || user.DoublePuppetIntent != nil { return } diff --git a/database/message.go b/database/message.go index 842a72b..f9b3cfa 100644 --- a/database/message.go +++ b/database/message.go @@ -75,15 +75,15 @@ const ( ) func (mq *MessageQuery) GetBySlackID(ctx context.Context, key PortalKey, slackID string) ([]*Message, error) { - return mq.QueryMany(ctx, getMessageBySlackIDQuery, key, slackID) + return mq.QueryMany(ctx, getMessageBySlackIDQuery, key.TeamID, key.ChannelID, slackID) } func (mq *MessageQuery) GetFirstPartBySlackID(ctx context.Context, key PortalKey, slackID string) (*Message, error) { - return mq.QueryOne(ctx, getFirstMessagePartBySlackIDQuery, key, slackID) + return mq.QueryOne(ctx, getFirstMessagePartBySlackIDQuery, key.TeamID, key.ChannelID, slackID) } func (mq *MessageQuery) GetLastPartBySlackID(ctx context.Context, key PortalKey, slackID string) (*Message, error) { - return mq.QueryOne(ctx, getLastMessagePartBySlackIDQuery, key, slackID) + return mq.QueryOne(ctx, getLastMessagePartBySlackIDQuery, key.TeamID, key.ChannelID, slackID) } func (mq *MessageQuery) GetByMXID(ctx context.Context, eventID id.EventID) (*Message, error) { @@ -91,19 +91,19 @@ func (mq *MessageQuery) GetByMXID(ctx context.Context, eventID id.EventID) (*Mes } func (mq *MessageQuery) GetFirstInChannel(ctx context.Context, key PortalKey) (*Message, error) { - return mq.QueryOne(ctx, getFirstMessageInChannelQuery, key) + return mq.QueryOne(ctx, getFirstMessageInChannelQuery, key.TeamID, key.ChannelID) } func (mq *MessageQuery) GetLastInChannel(ctx context.Context, key PortalKey) (*Message, error) { - return mq.QueryOne(ctx, getLastMessageInChannelQuery, key) + return mq.QueryOne(ctx, getLastMessageInChannelQuery, key.TeamID, key.ChannelID) } func (mq *MessageQuery) GetFirstInThread(ctx context.Context, key PortalKey, threadID string) (*Message, error) { - return mq.QueryOne(ctx, getFirstMessageInThreadQuery, key, threadID) + return mq.QueryOne(ctx, getFirstMessageInThreadQuery, key.TeamID, key.ChannelID, threadID) } func (mq *MessageQuery) GetLastInThread(ctx context.Context, key PortalKey, threadID string) (*Message, error) { - return mq.QueryOne(ctx, getLastMessageInThreadQuery, key, threadID) + return mq.QueryOne(ctx, getLastMessageInThreadQuery, key.TeamID, key.ChannelID, threadID) } type PartType string @@ -196,7 +196,7 @@ func (m *Message) SlackURLPath() string { } func (m *Message) sqlVariables() []any { - return []any{m.TeamID, m.ChannelID, m.MessageID, m.Part, dbutil.StrPtr(m.ThreadID), m.AuthorID, m.MXID} + return []any{m.TeamID, m.ChannelID, m.MessageID, m.Part, m.ThreadID, m.AuthorID, m.MXID} } func (m *Message) Insert(ctx context.Context) error { diff --git a/database/puppet.go b/database/puppet.go index 81b2f25..182d454 100644 --- a/database/puppet.go +++ b/database/puppet.go @@ -86,7 +86,7 @@ func (p *Puppet) Scan(row dbutil.Scannable) (*Puppet, error) { var avatarURL sql.NullString err := row.Scan( &p.TeamID, &p.UserID, - &p.Name, &p.NameSet, &p.Avatar, &avatarURL, &p.IsBot, + &p.Name, &p.Avatar, &avatarURL, &p.IsBot, &p.NameSet, &p.AvatarSet, &p.ContactInfoSet, ) if err != nil { diff --git a/database/teaminfo.go b/database/teaminfo.go index 6ee5f2e..098fefb 100644 --- a/database/teaminfo.go +++ b/database/teaminfo.go @@ -73,7 +73,7 @@ type TeamPortal struct { func (tp *TeamPortal) Scan(row dbutil.Scannable) (*TeamPortal, error) { var mxid, avatarMXC sql.NullString - err := row.Scan(&tp.ID, &mxid, &tp.Domain, &tp.URL, &tp.NameSet, &tp.AvatarSet, &avatarMXC, &tp.NameSet, &tp.AvatarSet) + err := row.Scan(&tp.ID, &mxid, &tp.Domain, &tp.URL, &tp.Name, &tp.Avatar, &avatarMXC, &tp.NameSet, &tp.AvatarSet) if err != nil { return nil, err } diff --git a/database/upgrades/00-latest-revision.sql b/database/upgrades/00-latest-revision.sql index 70d7ec7..f60e695 100644 --- a/database/upgrades/00-latest-revision.sql +++ b/database/upgrades/00-latest-revision.sql @@ -18,7 +18,7 @@ CREATE TABLE team_portal ( CREATE TABLE portal ( team_id TEXT NOT NULL, channel_id TEXT NOT NULL, - receiver TEXT NOT NULL, + receiver TEXT NOT NULL, -- TODO add receiver to primary key mxid TEXT, type INT NOT NULL DEFAULT 0, @@ -37,7 +37,7 @@ CREATE TABLE portal ( first_slack_id TEXT, - PRIMARY KEY (team_id, channel_id, receiver), + PRIMARY KEY (team_id, channel_id), CONSTRAINT portal_mxid_unique UNIQUE (mxid), CONSTRAINT portal_team_fkey FOREIGN KEY (team_id) REFERENCES team_portal (id) ON DELETE CASCADE ON UPDATE CASCADE diff --git a/database/userteam.go b/database/userteam.go index 038b1f3..2d986c7 100644 --- a/database/userteam.go +++ b/database/userteam.go @@ -70,7 +70,7 @@ const ( token=$5, cookie_token=$6, in_space=$7 - WHERE team_id=$1 AND user_id=$2 + WHERE team_id=$1 AND user_id=$2 AND user_mxid=$3 ` deleteUserTeamQuery = ` DELETE FROM user_team WHERE team_id=$1 AND user_id=$2 diff --git a/portal.go b/portal.go index db335f9..f08fa81 100644 --- a/portal.go +++ b/portal.go @@ -122,6 +122,10 @@ func (br *SlackBridge) loadPortal(ctx context.Context, dbPortal *database.Portal if key == nil { return nil } + // Get team beforehand to ensure it exists in the database + if br.GetTeamByID(key.TeamID) == nil { + br.ZLog.Warn().Str("team_id", key.TeamID).Msg("Failed to get team by ID before inserting portal") + } dbPortal = br.DB.Portal.New() dbPortal.PortalKey = *key @@ -147,6 +151,7 @@ func (br *SlackBridge) newPortal(dbPortal *database.Portal) *Portal { bridge: br, matrixMessages: make(chan portalMatrixMessage, br.Config.Bridge.PortalMessageBuffer), + slackMessages: make(chan portalSlackMessage, br.Config.Bridge.PortalMessageBuffer), Team: br.GetTeamByID(dbPortal.TeamID), } @@ -396,6 +401,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserTeam, ch portal.bridge.portalsLock.Unlock() portal.updateLogger() portal.zlog.Info().Msg("Matrix room created") + portal.Team.AddPortalToSpace(ctx, portal) err = portal.Update(ctx) if err != nil { @@ -1275,6 +1281,7 @@ func (portal *Portal) UpdateInfo(ctx context.Context, source *UserTeam, meta *sl changed = portal.UpdateName(ctx, meta) || changed changed = portal.UpdateTopic(ctx, meta) || changed + changed = portal.Team.AddPortalToSpace(ctx, portal) || changed if changed { portal.UpdateBridgeInfo(ctx) diff --git a/teamportal.go b/teamportal.go index 94d1907..41754fb 100644 --- a/teamportal.go +++ b/teamportal.go @@ -368,7 +368,7 @@ func (team *Team) RemoveMXID(ctx context.Context) { // } -func (team *Team) AddPortal(ctx context.Context, portal *Portal) bool { +func (team *Team) AddPortalToSpace(ctx context.Context, portal *Portal) bool { if len(team.MXID) == 0 { team.log.Error().Msg("Tried to add portal to team that has no matrix ID") if portal.InSpace { diff --git a/user.go b/user.go index 8db96a8..261113e 100644 --- a/user.go +++ b/user.go @@ -47,10 +47,11 @@ type User struct { teams map[string]*UserTeam - spaceCreateLock sync.Mutex - PermissionLevel bridgeconfig.PermissionLevel - DoublePuppetIntent *appservice.IntentAPI - CommandState *commands.CommandState + spaceCreateLock sync.Mutex + autoDoublePuppetLock sync.Mutex + PermissionLevel bridgeconfig.PermissionLevel + DoublePuppetIntent *appservice.IntentAPI + CommandState *commands.CommandState } func (user *User) GetPermissionLevel() bridgeconfig.PermissionLevel { @@ -402,7 +403,7 @@ func (user *User) GetSpaceRoom(ctx context.Context) (id.RoomID, error) { if err != nil { user.zlog.Err(err).Msg("Failed to save user after creating space room") } - user.ensureInvited(ctx, nil, user.SpaceRoom, false) + user.ensureInvited(ctx, user.bridge.Bot, user.SpaceRoom, false) return user.SpaceRoom, nil } diff --git a/userteam.go b/userteam.go index 44481af..200aaab 100644 --- a/userteam.go +++ b/userteam.go @@ -65,10 +65,18 @@ func (ut *UserTeam) GetRemoteName() string { } func (br *SlackBridge) loadUserTeam(ctx context.Context, dbUserTeam *database.UserTeam, key *database.UserTeamMXIDKey) *UserTeam { + var team *Team + var user *User if dbUserTeam == nil { if key == nil { return nil } + // Get team and user beforehand to ensure they exist in the database + team = br.unlockedGetTeamByID(key.TeamID, false) + if team == nil { + br.ZLog.Warn().Str("team_id", key.TeamID).Msg("Failed to get team by ID before inserting user team") + } + user = br.unlockedGetUserByMXID(key.UserMXID, false) dbUserTeam = br.DB.UserTeam.New() dbUserTeam.UserTeamMXIDKey = *key err := dbUserTeam.Insert(ctx) @@ -90,8 +98,12 @@ func (br *SlackBridge) loadUserTeam(ctx context.Context, dbUserTeam *database.Us } br.userTeamsByID[userTeam.UserTeamKey] = userTeam - userTeam.Team = br.unlockedGetTeamByID(dbUserTeam.TeamID, true) - userTeam.User = br.unlockedGetUserByMXID(dbUserTeam.UserMXID, true) + if team == nil || user == nil { + team = br.unlockedGetTeamByID(dbUserTeam.TeamID, false) + user = br.unlockedGetUserByMXID(dbUserTeam.UserMXID, false) + } + userTeam.Team = team + userTeam.User = user existingUT, alreadyExists = userTeam.User.teams[userTeam.TeamID] if alreadyExists { @@ -176,11 +188,12 @@ type slackgoZerolog struct { } func (l slackgoZerolog) Output(i int, s string) error { - l.Debug().Msg(s) + l.Debug().Msg(strings.TrimSpace(s)) return nil } func (ut *UserTeam) Connect() { + ut.User.tryAutomaticDoublePuppeting() evt := ut.Log.Trace() hasTraceLog := evt.Enabled() evt.Discard() @@ -234,7 +247,7 @@ func (ut *UserTeam) Sync(ctx context.Context, meta *slack.TeamInfo) { } } ut.AddToSpace(ctx) - ut.User.ensureInvited(ctx, nil, ut.Team.MXID, false) + ut.User.ensureInvited(ctx, ut.bridge.Bot, ut.Team.MXID, false) ut.syncPortals(ctx) ut.SyncEmojis(ctx) }