diff --git a/api4/channel.go b/api4/channel.go index fd33eb8828715..e5544d025549d 100644 --- a/api4/channel.go +++ b/api4/channel.go @@ -18,6 +18,7 @@ func InitChannel() { BaseRoutes.Channels.Handle("", ApiSessionRequired(createChannel)).Methods("POST") BaseRoutes.Channels.Handle("/direct", ApiSessionRequired(createDirectChannel)).Methods("POST") BaseRoutes.Channels.Handle("/members/{user_id:[A-Za-z0-9]+}/view", ApiSessionRequired(viewChannel)).Methods("POST") + BaseRoutes.Channels.Handle("/ids", ApiSessionRequired(getChannelsByIds)).Methods("POST") BaseRoutes.Team.Handle("/channels", ApiSessionRequired(getPublicChannelsForTeam)).Methods("GET") @@ -486,6 +487,28 @@ func viewChannel(c *Context, w http.ResponseWriter, r *http.Request) { ReturnStatusOK(w) } +func getChannelsByIds(c *Context, w http.ResponseWriter, r *http.Request) { + channelIds := model.ArrayFromJson(r.Body) + if len(channelIds) == 0 { + c.SetInvalidParam("channel_ids") + return + } + + for _, cid := range channelIds { + if len(cid) != 26 { + c.SetInvalidParam("channel_id") + return + } + } + + if channels, err := app.GetChannelsByIds(channelIds, c.Session.UserId); err != nil { + c.Err = err + return + } else { + w.Write([]byte(channels.ToJson())) + } +} + func updateChannelMemberRoles(c *Context, w http.ResponseWriter, r *http.Request) { c.RequireChannelId().RequireUserId() if c.Err != nil { diff --git a/api4/channel_test.go b/api4/channel_test.go index ef0d35e4b9c17..1e20419c67491 100644 --- a/api4/channel_test.go +++ b/api4/channel_test.go @@ -6,6 +6,7 @@ package api4 import ( "fmt" "net/http" + "sort" "strconv" "testing" @@ -1022,6 +1023,57 @@ func TestViewChannel(t *testing.T) { CheckNoError(t, resp) } +func TestGetChannelsByIds(t *testing.T) { + th := Setup().InitBasic() + defer TearDown() + Client := th.Client + input := []string{th.BasicChannel.Id} + output := []string{th.BasicChannel.DisplayName} + + channels, resp := Client.GetChannelsByIds(input) + CheckNoError(t, resp) + + if len(*channels) != 1 { + t.Fatal("should return 1 channel") + } + + if (*channels)[0].DisplayName != output[0] { + t.Fatal("missing channel") + } + + input = append(input, GenerateTestId()) + input = append(input, th.BasicChannel2.Id) + output = append(output, th.BasicChannel2.DisplayName) + sort.Strings(output) + + channels, resp = Client.GetChannelsByIds(input) + CheckNoError(t, resp) + + if len(*channels) != 2 { + t.Fatal("should return 2 channels") + } + + for i, c := range *channels { + if c.DisplayName != output[i] { + t.Fatal("missing channel") + } + } + + _, resp = Client.GetChannelsByIds([]string{}) + CheckBadRequestStatus(t, resp) + + _, resp = Client.GetChannelsByIds([]string{"junk"}) + CheckBadRequestStatus(t, resp) + + _, resp = Client.GetChannelsByIds([]string{GenerateTestId()}) + CheckNotFoundStatus(t, resp) + + Client.Logout() + + _, resp = Client.GetChannelsByIds(input) + CheckUnauthorizedStatus(t, resp) +} + func TestGetChannelUnread(t *testing.T) { th := Setup().InitBasic().InitSystemAdmin() defer TearDown() diff --git a/app/channel.go b/app/channel.go index d66624f2c4267..8edefc3a9d9d1 100644 --- a/app/channel.go +++ b/app/channel.go @@ -671,6 +671,14 @@ func GetChannelsUserNotIn(teamId string, userId string, offset int, limit int) ( } } +func GetChannelsByIds(channelIds []string, userId string) (*model.ChannelList, *model.AppError) { + if result := <-Srv.Store.Channel().GetChannelsByIds(channelIds, userId); result.Err != nil { + return nil, result.Err + } else { + return result.Data.(*model.ChannelList), nil + } +} + func GetPublicChannelsForTeam(teamId string, offset int, limit int) (*model.ChannelList, *model.AppError) { if result := <-Srv.Store.Channel().GetPublicChannelsForTeam(teamId, offset, limit); result.Err != nil { return nil, result.Err diff --git a/i18n/en.json b/i18n/en.json index d16a288da636e..a48cc2cb0f46e 100644 --- a/i18n/en.json +++ b/i18n/en.json @@ -4707,6 +4707,14 @@ "id": "store.sql_channel.get_channels.not_found.app_error", "translation": "No channels were found" }, + { + "id": "store.sql_channel.get_channels_by_ids.get.app_error", + "translation": "We couldn't get the channels" + }, + { + "id": "store.sql_channel.get_channels_by_ids.not_found.app_error", + "translation": "No channel found" + }, { "id": "store.sql_channel.get_deleted_by_name.existing.app_error", "translation": "We couldn't find the existing deleted channel" diff --git a/model/client4.go b/model/client4.go index 3aef5019c0f90..633ef2dc0314e 100644 --- a/model/client4.go +++ b/model/client4.go @@ -1010,6 +1010,16 @@ func (c *Client4) ViewChannel(userId string, view *ChannelView) (bool, *Response } } +// GetChannelsByIds gets a list of channels that the user is member of +func (c *Client4) GetChannelsByIds(channelIds []string) (*ChannelList, *Response) { + if r, err := c.DoApiPost(c.GetChannelsRoute()+"/ids", ArrayToJson(channelIds)); err != nil { + return nil, &Response{StatusCode: r.StatusCode, Error: err} + } else { + defer closeBody(r) + return ChannelListFromJson(r.Body), BuildResponse(r) + } +} + // GetChannelUnread will return a ChannelUnread object that contains the number of // unread messages and mentions for a user. func (c *Client4) GetChannelUnread(channelId, userId string) (*ChannelUnread, *Response) { diff --git a/store/sql_channel_store.go b/store/sql_channel_store.go index d72722f7cce34..45fedf7498901 100644 --- a/store/sql_channel_store.go +++ b/store/sql_channel_store.go @@ -541,6 +541,45 @@ func (s SqlChannelStore) GetChannels(teamId string, userId string) StoreChannel return storeChannel } +func (s SqlChannelStore) GetChannelsByIds(channelIds []string, userId string) StoreChannel { + storeChannel := make(StoreChannel, 1) + + go func() { + result := StoreResult{} + + props := make(map[string]interface{}) + props["userId"] = userId + + idQuery := "" + + for index, channelId := range channelIds { + if len(idQuery) > 0 { + idQuery += ", " + } + + props["channelId"+strconv.Itoa(index)] = channelId + idQuery += ":channelId" + strconv.Itoa(index) + } + + data := &model.ChannelList{} + _, err := s.GetReplica().Select(data, "SELECT Channels.* FROM Channels, ChannelMembers WHERE Id = ChannelId AND UserId = :userId AND DeleteAt = 0 AND ChannelId IN ("+idQuery+") ORDER BY DisplayName", props) + + if err != nil { + result.Err = model.NewLocAppError("SqlChannelStore.GetChannelsByIds", "store.sql_channel.get_channels_by_ids.get.app_error", nil, err.Error()) + } + + if len(*data) == 0 { + result.Err = model.NewAppError("SqlChannelStore.GetChannelsByIds", "store.sql_channel.get_channels_by_ids.not_found.app_error", nil, "", http.StatusNotFound) + } + + result.Data = data + storeChannel <- result + close(storeChannel) + }() + + return storeChannel +} + func (s SqlChannelStore) GetMoreChannels(teamId string, userId string, offset int, limit int) StoreChannel { storeChannel := make(StoreChannel, 1) diff --git a/store/sql_channel_store_test.go b/store/sql_channel_store_test.go index f347fa438d6c7..1a1b58131a333 100644 --- a/store/sql_channel_store_test.go +++ b/store/sql_channel_store_test.go @@ -732,6 +732,110 @@ func TestChannelStoreGetChannels(t *testing.T) { store.Channel().InvalidateAllChannelMembersForUser(m1.UserId) } +func TestChannelStoreGetChannelsByIds(t *testing.T) { + Setup() + + co1 := model.Channel{} + co1.TeamId = model.NewId() + co1.DisplayName = "Channel1" + co1.Name = "a" + model.NewId() + "b" + co1.Type = model.CHANNEL_OPEN + Must(store.Channel().Save(&co1)) + + co2 := model.Channel{} + co2.TeamId = model.NewId() + co2.DisplayName = "Channel2" + co2.Name = "a" + model.NewId() + "b" + co2.Type = model.CHANNEL_OPEN + Must(store.Channel().Save(&co2)) + + cp3 := model.Channel{} + cp3.TeamId = model.NewId() + cp3.DisplayName = "Channel3" + cp3.Name = "a" + model.NewId() + "b" + cp3.Type = model.CHANNEL_PRIVATE + Must(store.Channel().Save(&cp3)) + + cm1 := model.ChannelMember{} + cm1.ChannelId = co1.Id + cm1.UserId = model.NewId() + cm1.NotifyProps = model.GetDefaultChannelNotifyProps() + Must(store.Channel().SaveMember(&cm1)) + + cm2 := model.ChannelMember{} + cm2.ChannelId = co1.Id + cm2.UserId = model.NewId() + cm2.NotifyProps = model.GetDefaultChannelNotifyProps() + Must(store.Channel().SaveMember(&cm2)) + + cm2.ChannelId = co2.Id + Must(store.Channel().SaveMember(&cm2)) + + cm2.ChannelId = cp3.Id + Must(store.Channel().SaveMember(&cm2)) + + cids := []string{co1.Id} + cresult := <-store.Channel().GetChannelsByIds(cids, cm1.UserId) + list := cresult.Data.(*model.ChannelList) + + if len(*list) != 1 { + t.Fatal("should return 1 channel") + } + + if (*list)[0].Id != co1.Id { + t.Fatal("missing channel") + } + + cids = append(cids, co2.Id) + cresult = <-store.Channel().GetChannelsByIds(cids, cm1.UserId) + list = cresult.Data.(*model.ChannelList) + + if len(*list) != 1 { + t.Fatal("should return 1 channel") + } + + cresult = <-store.Channel().GetChannelsByIds(cids, cm2.UserId) + list = cresult.Data.(*model.ChannelList) + + if len(*list) != 2 { + t.Fatal("should return 2 channels") + } + + cids = append(cids, cp3.Id) + cresult = <-store.Channel().GetChannelsByIds(cids, cm2.UserId) + list = cresult.Data.(*model.ChannelList) + + if len(*list) != 3 { + t.Fatal("should return 3 channels") + } + + for i, c := range *list { + if c.Id != cids[i] { + t.Fatal("missing channel") + } + } + + cids = append(cids, model.NewId()) + cresult = <-store.Channel().GetChannelsByIds(cids, cm2.UserId) + list = cresult.Data.(*model.ChannelList) + + if len(*list) != 3 { + t.Fatal("should return 3 channels") + } + + cids = cids[:0] + cids = append(cids, model.NewId()) + cresult = <-store.Channel().GetChannelsByIds(cids, cm2.UserId) + list = cresult.Data.(*model.ChannelList) + + if len(*list) != 0 { + t.Fatal("should not return a channel") + } + + store.Channel().InvalidateAllChannelMembersForUser(cm1.UserId) + store.Channel().InvalidateAllChannelMembersForUser(cm2.UserId) +} + func TestChannelStoreGetMoreChannels(t *testing.T) { Setup() diff --git a/store/store.go b/store/store.go index 323727697b20d..2479042c22866 100644 --- a/store/store.go +++ b/store/store.go @@ -101,6 +101,7 @@ type ChannelStore interface { GetByNameIncludeDeleted(team_id string, name string, allowFromCache bool) StoreChannel GetDeletedByName(team_id string, name string) StoreChannel GetChannels(teamId string, userId string) StoreChannel + GetChannelsByIds(channelIds []string, userId string) StoreChannel GetMoreChannels(teamId string, userId string, offset int, limit int) StoreChannel GetPublicChannelsForTeam(teamId string, offset int, limit int) StoreChannel GetChannelCounts(teamId string, userId string) StoreChannel