Skip to content

Commit

Permalink
channels: handle messages Matrix <-> TG
Browse files Browse the repository at this point in the history
Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
  • Loading branch information
sumnerevans committed Jul 15, 2024
1 parent 5483561 commit 8a8ca90
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 78 deletions.
83 changes: 74 additions & 9 deletions pkg/connector/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ import (
"go.mau.fi/mautrix-telegram/pkg/connector/ids"
"go.mau.fi/mautrix-telegram/pkg/connector/media"
"go.mau.fi/mautrix-telegram/pkg/connector/msgconv"
"go.mau.fi/mautrix-telegram/pkg/connector/store"
"go.mau.fi/mautrix-telegram/pkg/connector/util"
)

type TelegramClient struct {
main *TelegramConnector
ScopedStore *store.ScopedStore
telegramUserID int64
loginID networkid.UserLoginID
userID networkid.UserID
Expand Down Expand Up @@ -102,13 +104,28 @@ func NewTelegramClient(ctx context.Context, tc *TelegramConnector, login *bridge
UpdateDispatcher: tg.NewUpdateDispatcher(),
EntityHandler: client.onEntityUpdate,
}
dispatcher.OnNewMessage(client.onUpdateNewMessage)
dispatcher.OnNewChannelMessage(client.onUpdateNewChannelMessage)
dispatcher.OnNewMessage(func(ctx context.Context, e tg.Entities, update *tg.UpdateNewMessage) error {
return client.onUpdateNewMessage(ctx, update)
})
dispatcher.OnNewChannelMessage(func(ctx context.Context, e tg.Entities, update *tg.UpdateNewChannelMessage) error {
fmt.Printf("%+v\n", update)
return client.onUpdateNewMessage(ctx, update)
})
dispatcher.OnUserName(client.onUserName)
dispatcher.OnDeleteMessages(client.onDeleteMessages)
dispatcher.OnEditMessage(client.onMessageEdit)
dispatcher.OnDeleteMessages(func(ctx context.Context, e tg.Entities, update *tg.UpdateDeleteMessages) error {
return client.onDeleteMessages(ctx, update)
})
dispatcher.OnDeleteChannelMessages(func(ctx context.Context, e tg.Entities, update *tg.UpdateDeleteChannelMessages) error {
return client.onDeleteMessages(ctx, update)
})
dispatcher.OnEditMessage(func(ctx context.Context, e tg.Entities, update *tg.UpdateEditMessage) error {
return client.onMessageEdit(ctx, update)
})
dispatcher.OnEditChannelMessage(func(ctx context.Context, e tg.Entities, update *tg.UpdateEditChannelMessage) error {
return client.onMessageEdit(ctx, update)
})

store := tc.Store.GetScopedStore(telegramUserID)
client.ScopedStore = tc.Store.GetScopedStore(telegramUserID)

updatesManager := updates.New(updates.Config{
OnChannelTooLong: func(channelID int64) {
Expand All @@ -117,12 +134,12 @@ func NewTelegramClient(ctx context.Context, tc *TelegramConnector, login *bridge
},
Handler: dispatcher,
Logger: zaplog.Named("gaps"),
Storage: store,
AccessHasher: store,
Storage: client.ScopedStore,
AccessHasher: client.ScopedStore,
})

client.client = telegram.NewClient(tc.Config.AppID, tc.Config.AppHash, telegram.Options{
SessionStorage: store,
SessionStorage: client.ScopedStore,
Logger: zaplog,
UpdateHandler: updatesManager,
})
Expand Down Expand Up @@ -182,7 +199,7 @@ func (t *TelegramClient) Disconnect() {
}

func (t *TelegramClient) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) (*bridgev2.ChatInfo, error) {
fmt.Printf("%+v\n", portal)
fmt.Printf("get chat info %+v\n", portal)
peerType, id, err := ids.ParsePortalID(portal.ID)
if err != nil {
return nil, err
Expand Down Expand Up @@ -251,6 +268,54 @@ func (t *TelegramClient) GetChatInfo(ctx context.Context, portal *bridgev2.Porta
}
}

for _, user := range fullChat.Users {
memberList.Members = append(memberList.Members, bridgev2.ChatMember{
EventSender: bridgev2.EventSender{
IsFromMe: user.GetID() == t.telegramUserID,
SenderLogin: ids.MakeUserLoginID(user.GetID()),
Sender: ids.MakeUserID(user.GetID()),
},
})
}
case ids.PeerTypeChannel:
accessHash, found, err := t.ScopedStore.GetChannelAccessHash(ctx, t.telegramUserID, id)
if err != nil {
return nil, fmt.Errorf("failed to get channel access hash: %w", err)
} else if !found {
return nil, fmt.Errorf("channel access hash not found for %d", id)
}
fullChat, err := t.client.API().ChannelsGetFullChannel(ctx, &tg.InputChannel{ChannelID: id, AccessHash: accessHash})
if err != nil {
return nil, err
}
for _, c := range fullChat.Chats {
if c.GetID() == id {
switch chat := c.(type) {
case *tg.Chat:
name = chat.Title
case *tg.Channel:
name = chat.Title
}
break
}
}

chatFull, ok := fullChat.FullChat.(*tg.ChatFull)
if !ok {
return nil, fmt.Errorf("full chat is not %T", chatFull)
}

if photo, ok := chatFull.GetChatPhoto(); ok {
avatar = &bridgev2.Avatar{
ID: ids.MakeAvatarID(photo.GetID()),
Get: func(ctx context.Context) (data []byte, err error) {
data, _, err = media.NewTransferer(t.client.API()).WithPhoto(photo).Download(ctx)
return
},
}
}

memberList.IsFull = false
for _, user := range fullChat.Users {
memberList.Members = append(memberList.Members, bridgev2.ChatMember{
EventSender: bridgev2.EventSender{
Expand Down
24 changes: 16 additions & 8 deletions pkg/connector/directdownload.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,26 @@ func (tc *TelegramConnector) Download(ctx context.Context, mediaID networkid.Med
&tg.InputMessageID{ID: int(info.MessageID)},
})
case ids.PeerTypeChannel:
// TODO test this
messages, err = client.client.API().ChannelsGetMessages(ctx, &tg.ChannelsGetMessagesRequest{
Channel: &tg.InputChannel{ChannelID: info.ChatID},
ID: []tg.InputMessageClass{
&tg.InputMessageID{ID: int(info.MessageID)},
},
})
var accessHash int64
var found bool
accessHash, found, err = client.ScopedStore.GetChannelAccessHash(ctx, client.telegramUserID, info.ChatID)
if err != nil {
return nil, fmt.Errorf("failed to get channel access hash: %w", err)
} else if !found {
return nil, fmt.Errorf("channel access hash not found for %d", info.ChatID)
} else {
messages, err = client.client.API().ChannelsGetMessages(ctx, &tg.ChannelsGetMessagesRequest{
Channel: &tg.InputChannel{ChannelID: info.ChatID, AccessHash: accessHash},
ID: []tg.InputMessageClass{
&tg.InputMessageID{ID: int(info.MessageID)},
},
})
}
default:
return nil, fmt.Errorf("unknown peer type %s", info.PeerType)
}
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get messages for %+v: %w", info, err)
}

var msgMedia tg.MessageMediaClass
Expand Down
21 changes: 0 additions & 21 deletions pkg/connector/ids/ids.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,27 +92,6 @@ func ParsePortalID(portalID networkid.PortalID) (pt PeerType, id int64, err erro
return
}

func InputPeerForPortalID(portalID networkid.PortalID) (tg.InputPeerClass, error) {
peerType, id, err := ParsePortalID(portalID)
if err != nil {
return nil, err
}
switch peerType {
case PeerTypeUser:
return &tg.InputPeerUser{UserID: id}, nil
case PeerTypeChat:
return &tg.InputPeerChat{ChatID: id}, nil
case PeerTypeChannel:
return &tg.InputPeerChannel{ChannelID: id}, nil
default:
panic("invalid peer type")
}
}

func InputPeerForPortalKey(portalKey networkid.PortalKey) (tg.InputPeerClass, error) {
return InputPeerForPortalID(portalKey.ID)
}

func MakeAvatarID(photoID int64) networkid.AvatarID {
return networkid.AvatarID(strconv.FormatInt(photoID, 10))
}
Expand Down
16 changes: 10 additions & 6 deletions pkg/connector/matrix.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,11 @@ func getMediaFilenameAndCaption(content *event.MessageEventContent) (filename, c
}

func (t *TelegramClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.MatrixMessage) (resp *bridgev2.MatrixMessageResponse, err error) {
sender := message.NewSender(t.client.API())
peer, err := ids.InputPeerForPortalID(msg.Portal.ID)
peer, err := t.inputPeerForPortalID(ctx, msg.Portal.ID)
if err != nil {
return nil, err
}
builder := sender.To(peer)
builder := message.NewSender(t.client.API()).To(peer)

// TODO handle sticker

Expand Down Expand Up @@ -173,8 +172,13 @@ func (t *TelegramClient) HandleMatrixMessageRemove(ctx context.Context, msg *bri
return err
} else if messageID, err := ids.ParseMessageID(dbMsg.ID); err != nil {
return err
} else if peer, err := t.inputPeerForPortalID(ctx, msg.Portal.ID); err != nil {
return err
} else {
_, err = message.NewSender(t.client.API()).Self().Revoke().Messages(ctx, messageID)
_, err := message.NewSender(t.client.API()).
To(peer).
Revoke().
Messages(ctx, messageID)
return err
}
}
Expand Down Expand Up @@ -224,7 +228,7 @@ func (t *TelegramClient) appendEmojiID(reactionList []tg.ReactionClass, emojiID
}

func (t *TelegramClient) HandleMatrixReaction(ctx context.Context, msg *bridgev2.MatrixReaction) (reaction *database.Reaction, err error) {
peer, err := ids.InputPeerForPortalID(msg.Portal.ID)
peer, err := t.inputPeerForPortalID(ctx, msg.Portal.ID)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -255,7 +259,7 @@ func (t *TelegramClient) HandleMatrixReaction(ctx context.Context, msg *bridgev2
}

func (t *TelegramClient) HandleMatrixReactionRemove(ctx context.Context, msg *bridgev2.MatrixReactionRemove) error {
peer, err := ids.InputPeerForPortalID(msg.Portal.ID)
peer, err := t.inputPeerForPortalID(ctx, msg.Portal.ID)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/connector/store/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,6 @@ func (c *Container) Upgrade(ctx context.Context) error {
return c.Database.Upgrade(ctx)
}

func (c *Container) GetScopedStore(telegramUserID int64) *scopedStore {
return &scopedStore{c.Database, telegramUserID}
func (c *Container) GetScopedStore(telegramUserID int64) *ScopedStore {
return &ScopedStore{c.Database, telegramUserID}
}
40 changes: 20 additions & 20 deletions pkg/connector/store/scoped_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ import (
"go.mau.fi/util/dbutil"
)

// scopedStore is a wrapper around a database that implements
// ScopedStore is a wrapper around a database that implements
// [session.Storage] scoped to a specific Telegram user ID.
type scopedStore struct {
type ScopedStore struct {
db *dbutil.Database
telegramUserID int64
}
Expand Down Expand Up @@ -60,22 +60,22 @@ const (
`
)

var _ session.Storage = (*scopedStore)(nil)
var _ session.Storage = (*ScopedStore)(nil)

func (s *scopedStore) LoadSession(ctx context.Context) (sessionData []byte, err error) {
func (s *ScopedStore) LoadSession(ctx context.Context) (sessionData []byte, err error) {
row := s.db.QueryRow(ctx, loadSessionQuery, s.telegramUserID)
err = row.Scan(&sessionData)
return
}

func (s *scopedStore) StoreSession(ctx context.Context, data []byte) error {
func (s *ScopedStore) StoreSession(ctx context.Context, data []byte) error {
_, err := s.db.Exec(ctx, storeSessionQuery, s.telegramUserID, data)
return err
}

var _ updates.StateStorage = (*scopedStore)(nil)
var _ updates.StateStorage = (*ScopedStore)(nil)

func (s *scopedStore) ForEachChannels(ctx context.Context, userID int64, f func(ctx context.Context, channelID int64, pts int) error) error {
func (s *ScopedStore) ForEachChannels(ctx context.Context, userID int64, f func(ctx context.Context, channelID int64, pts int) error) error {
s.assertUserIDMatches(userID)
rows, err := s.db.Query(ctx, allChannelsQuery, userID)
if err != nil {
Expand All @@ -93,7 +93,7 @@ func (s *scopedStore) ForEachChannels(ctx context.Context, userID int64, f func(
return nil
}

func (s *scopedStore) GetChannelPts(ctx context.Context, userID int64, channelID int64) (pts int, found bool, err error) {
func (s *ScopedStore) GetChannelPts(ctx context.Context, userID int64, channelID int64) (pts int, found bool, err error) {
s.assertUserIDMatches(userID)
err = s.db.QueryRow(ctx, getChannelPtsQuery, userID, channelID).Scan(&pts)
if errors.Is(err, sql.ErrNoRows) {
Expand All @@ -102,13 +102,13 @@ func (s *scopedStore) GetChannelPts(ctx context.Context, userID int64, channelID
return pts, err == nil, err
}

func (s *scopedStore) SetChannelPts(ctx context.Context, userID int64, channelID int64, pts int) (err error) {
func (s *ScopedStore) SetChannelPts(ctx context.Context, userID int64, channelID int64, pts int) (err error) {
s.assertUserIDMatches(userID)
_, err = s.db.Exec(ctx, setChannelPtsQuery, userID, channelID, pts)
return
}

func (s *scopedStore) GetState(ctx context.Context, userID int64) (state updates.State, found bool, err error) {
func (s *ScopedStore) GetState(ctx context.Context, userID int64) (state updates.State, found bool, err error) {
s.assertUserIDMatches(userID)
err = s.db.QueryRow(ctx, getStateQuery, userID).Scan(&state.Pts, &state.Qts, &state.Date, &state.Seq)
if errors.Is(err, sql.ErrNoRows) {
Expand All @@ -117,45 +117,45 @@ func (s *scopedStore) GetState(ctx context.Context, userID int64) (state updates
return state, err == nil, err
}

func (s *scopedStore) SetState(ctx context.Context, userID int64, state updates.State) (err error) {
func (s *ScopedStore) SetState(ctx context.Context, userID int64, state updates.State) (err error) {
s.assertUserIDMatches(userID)
_, err = s.db.Exec(ctx, setStateQuery, userID, state.Pts, state.Qts, state.Date, state.Seq)
return
}

func (s *scopedStore) SetPts(ctx context.Context, userID int64, pts int) (err error) {
func (s *ScopedStore) SetPts(ctx context.Context, userID int64, pts int) (err error) {
s.assertUserIDMatches(userID)
_, err = s.db.Exec(ctx, setPtsQuery, userID, pts)
return
}

func (s *scopedStore) SetQts(ctx context.Context, userID int64, qts int) (err error) {
func (s *ScopedStore) SetQts(ctx context.Context, userID int64, qts int) (err error) {
s.assertUserIDMatches(userID)
_, err = s.db.Exec(ctx, setQtsQuery, userID, qts)
return
}

func (s *scopedStore) SetSeq(ctx context.Context, userID int64, seq int) (err error) {
func (s *ScopedStore) SetSeq(ctx context.Context, userID int64, seq int) (err error) {
s.assertUserIDMatches(userID)
_, err = s.db.Exec(ctx, setSeqQuery, userID, seq)
return
}

func (s *scopedStore) SetDate(ctx context.Context, userID int64, date int) (err error) {
func (s *ScopedStore) SetDate(ctx context.Context, userID int64, date int) (err error) {
s.assertUserIDMatches(userID)
_, err = s.db.Exec(ctx, setDateQuery, userID, date)
return
}

func (s *scopedStore) SetDateSeq(ctx context.Context, userID int64, date int, seq int) (err error) {
func (s *ScopedStore) SetDateSeq(ctx context.Context, userID int64, date int, seq int) (err error) {
s.assertUserIDMatches(userID)
_, err = s.db.Exec(ctx, setDateSeqQuery, userID, date, seq)
return
}

var _ updates.ChannelAccessHasher = (*scopedStore)(nil)
var _ updates.ChannelAccessHasher = (*ScopedStore)(nil)

func (s *scopedStore) GetChannelAccessHash(ctx context.Context, userID int64, channelID int64) (accessHash int64, found bool, err error) {
func (s *ScopedStore) GetChannelAccessHash(ctx context.Context, userID int64, channelID int64) (accessHash int64, found bool, err error) {
s.assertUserIDMatches(userID)
err = s.db.QueryRow(ctx, getChannelAccessHashQuery, userID, channelID).Scan(&accessHash)
if errors.Is(err, sql.ErrNoRows) {
Expand All @@ -164,15 +164,15 @@ func (s *scopedStore) GetChannelAccessHash(ctx context.Context, userID int64, ch
return accessHash, err == nil, err
}

func (s *scopedStore) SetChannelAccessHash(ctx context.Context, userID int64, channelID int64, accessHash int64) (err error) {
func (s *ScopedStore) SetChannelAccessHash(ctx context.Context, userID int64, channelID int64, accessHash int64) (err error) {
s.assertUserIDMatches(userID)
_, err = s.db.Exec(ctx, setChannelAccessHashQuery, userID, channelID, accessHash)
return
}

// Helper Functions

func (s *scopedStore) assertUserIDMatches(userID int64) {
func (s *ScopedStore) assertUserIDMatches(userID int64) {
if s.telegramUserID != userID {
panic(fmt.Sprintf("scoped store for %d function called with user ID %d", s.telegramUserID, userID))
}
Expand Down
Loading

0 comments on commit 8a8ca90

Please sign in to comment.