Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add last seen field to client (fixes #400) #582

Merged
merged 2 commits into from
Aug 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion api/application_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,9 @@ func (s *ApplicationSuite) Test_ensureApplicationHasCorrectJsonRepresentation()
Description: "mydesc",
Image: "asd",
Internal: true,
LastUsed: nil,
}
test.JSONEquals(s.T(), actual, `{"id":1,"token":"Aasdasfgeeg","name":"myapp","description":"mydesc", "image": "asd", "internal":true, "defaultPriority":0}`)
test.JSONEquals(s.T(), actual, `{"id":1,"token":"Aasdasfgeeg","name":"myapp","description":"mydesc", "image": "asd", "internal":true, "defaultPriority":0, "lastUsed":null}`)
}

func (s *ApplicationSuite) Test_CreateApplication_expectBadRequestOnEmptyName() {
Expand Down
2 changes: 1 addition & 1 deletion api/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (s *ClientSuite) AfterTest(suiteName, testName string) {

func (s *ClientSuite) Test_ensureClientHasCorrectJsonRepresentation() {
actual := &model.Client{ID: 1, UserID: 2, Token: "Casdasfgeeg", Name: "myclient"}
test.JSONEquals(s.T(), actual, `{"id":1,"token":"Casdasfgeeg","name":"myclient"}`)
test.JSONEquals(s.T(), actual, `{"id":1,"token":"Casdasfgeeg","name":"myclient","lastUsed":null}`)
}

func (s *ClientSuite) Test_CreateClient_mapAllParameters() {
Expand Down
25 changes: 25 additions & 0 deletions api/stream/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,19 @@ func New(pingPeriod, pongTimeout time.Duration, allowedWebSocketOrigins []string
}
}

// CollectConnectedClientTokens returns all tokens of the connected clients.
func (a *API) CollectConnectedClientTokens() []string {
a.lock.RLock()
defer a.lock.RUnlock()
var clients []string
for _, cs := range a.clients {
for _, c := range cs {
clients = append(clients, c.token)
}
}
return uniq(clients)
}

// NotifyDeletedUser closes existing connections for the given user.
func (a *API) NotifyDeletedUser(userID uint) error {
a.lock.Lock()
Expand Down Expand Up @@ -155,6 +168,18 @@ func (a *API) Close() {
}
}

func uniq[T comparable](s []T) []T {
m := make(map[T]struct{})
for _, v := range s {
m[v] = struct{}{}
}
var r []T
for k := range m {
r = append(r, k)
}
return r
}

func isAllowedOrigin(r *http.Request, allowedOrigins []*regexp.Regexp) bool {
origin := r.Header.Get("origin")
if origin == "" {
Expand Down
92 changes: 75 additions & 17 deletions api/stream/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"sort"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -56,8 +57,8 @@ func TestWriteMessageFails(t *testing.T) {
wsURL := wsURL(server.URL)
user := testClient(t, wsURL)

// the server may take some time to register the client
time.Sleep(100 * time.Millisecond)
waitForConnectedClients(api, 1)

clients := clients(api, 1)
assert.NotEmpty(t, clients)

Expand Down Expand Up @@ -86,13 +87,13 @@ func TestWritePingFails(t *testing.T) {
user := testClient(t, wsURL)
defer user.conn.Close()

// the server may take some time to register the client
time.Sleep(100 * time.Millisecond)
waitForConnectedClients(api, 1)

clients := clients(api, 1)

assert.NotEmpty(t, clients)

time.Sleep(api.pingPeriod) // waiting for ping
time.Sleep(api.pingPeriod + (50 * time.Millisecond)) // waiting for ping

api.Notify(1, &model.MessageExternal{Message: "HI"})
user.expectNoMessage()
Expand Down Expand Up @@ -147,8 +148,8 @@ func TestCloseClientOnNotReading(t *testing.T) {
assert.Nil(t, err)
defer ws.Close()

// the server may take some time to register the client
time.Sleep(100 * time.Millisecond)
waitForConnectedClients(api, 1)

assert.NotEmpty(t, clients(api, 1))

time.Sleep(api.pingPeriod + api.pongTimeout)
Expand All @@ -167,8 +168,9 @@ func TestMessageDirectlyAfterConnect(t *testing.T) {

user := testClient(t, wsURL)
defer user.conn.Close()
// the server may take some time to register the client
time.Sleep(100 * time.Millisecond)

waitForConnectedClients(api, 1)

api.Notify(1, &model.MessageExternal{Message: "msg"})
user.expectMessage(&model.MessageExternal{Message: "msg"})
}
Expand All @@ -184,8 +186,9 @@ func TestDeleteClientShouldCloseConnection(t *testing.T) {

user := testClient(t, wsURL)
defer user.conn.Close()
// the server may take some time to register the client
time.Sleep(100 * time.Millisecond)

waitForConnectedClients(api, 1)

api.Notify(1, &model.MessageExternal{Message: "msg"})
user.expectMessage(&model.MessageExternal{Message: "msg"})

Expand Down Expand Up @@ -230,8 +233,7 @@ func TestDeleteMultipleClients(t *testing.T) {
defer userThreeAndroid.conn.Close()
userThree := []*testingClient{userThreeAndroid}

// the server may take some time to register the client
time.Sleep(100 * time.Millisecond)
waitForConnectedClients(api, len(userOne)+len(userTwo)+len(userThree))

api.Notify(1, &model.MessageExternal{ID: 4, Message: "there"})
expectMessage(&model.MessageExternal{ID: 4, Message: "there"}, userOne...)
Expand Down Expand Up @@ -294,8 +296,7 @@ func TestDeleteUser(t *testing.T) {
defer userThreeAndroid.conn.Close()
userThree := []*testingClient{userThreeAndroid}

// the server may take some time to register the client
time.Sleep(100 * time.Millisecond)
waitForConnectedClients(api, len(userOne)+len(userTwo)+len(userThree))

api.Notify(1, &model.MessageExternal{ID: 4, Message: "there"})
expectMessage(&model.MessageExternal{ID: 4, Message: "there"}, userOne...)
Expand All @@ -322,6 +323,43 @@ func TestDeleteUser(t *testing.T) {
api.Close()
}

func TestCollectConnectedClientTokens(t *testing.T) {
mode.Set(mode.TestDev)

defer leaktest.Check(t)()
userIDs := []uint{1, 1, 1, 2, 2}
tokens := []string{"1-1", "1-2", "1-2", "2-1", "2-2"}
i := 0
server, api := bootTestServer(func(context *gin.Context) {
auth.RegisterAuthentication(context, nil, userIDs[i], tokens[i])
i++
})
defer server.Close()

wsURL := wsURL(server.URL)
userOneConnOne := testClient(t, wsURL)
defer userOneConnOne.conn.Close()
userOneConnTwo := testClient(t, wsURL)
defer userOneConnTwo.conn.Close()
userOneConnThree := testClient(t, wsURL)
defer userOneConnThree.conn.Close()
waitForConnectedClients(api, 3)

ret := api.CollectConnectedClientTokens()
sort.Strings(ret)
assert.Equal(t, []string{"1-1", "1-2"}, ret)

userTwoConnOne := testClient(t, wsURL)
defer userTwoConnOne.conn.Close()
userTwoConnTwo := testClient(t, wsURL)
defer userTwoConnTwo.conn.Close()
waitForConnectedClients(api, 5)

ret = api.CollectConnectedClientTokens()
sort.Strings(ret)
assert.Equal(t, []string{"1-1", "1-2", "2-1", "2-2"}, ret)
}

func TestMultipleClients(t *testing.T) {
mode.Set(mode.TestDev)

Expand Down Expand Up @@ -354,8 +392,7 @@ func TestMultipleClients(t *testing.T) {
defer userThreeAndroid.conn.Close()
userThree := []*testingClient{userThreeAndroid}

// the server may take some time to register the client
time.Sleep(100 * time.Millisecond)
waitForConnectedClients(api, len(userOne)+len(userTwo)+len(userThree))

// there should not be messages at the beginning
expectNoMessage(userOne...)
Expand Down Expand Up @@ -474,6 +511,17 @@ func clients(api *API, user uint) []*client {
return api.clients[user]
}

func countClients(a *API) int {
a.lock.RLock()
defer a.lock.RUnlock()

var i int
for _, clients := range a.clients {
i += len(clients)
}
return i
}

func testClient(t *testing.T, url string) *testingClient {
client := createClient(t, url)
startReading(client)
Expand Down Expand Up @@ -560,3 +608,13 @@ func staticUserID() gin.HandlerFunc {
auth.RegisterAuthentication(context, nil, 1, "customtoken")
}
}

func waitForConnectedClients(api *API, count int) {
for i := 0; i < 10; i++ {
if countClients(api) == count {
// ok
return
}
time.Sleep(10 * time.Millisecond)
}
}
27 changes: 21 additions & 6 deletions auth/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import (
"errors"
"strings"
"time"

"github.com/gin-gonic/gin"
"github.com/gotify/server/v2/auth/password"
Expand All @@ -20,6 +21,8 @@
GetPluginConfByToken(token string) (*model.PluginConf, error)
GetUserByName(name string) (*model.User, error)
GetUserByID(id uint) (*model.User, error)
UpdateClientTokensLastUsed(tokens []string, t *time.Time) error
UpdateApplicationTokenLastUsed(token string, t *time.Time) error
}

// Auth is the provider for authentication middleware.
Expand Down Expand Up @@ -56,10 +59,16 @@
if user != nil {
return true, true, user.ID, nil
}
if token, err := a.DB.GetClientByToken(tokenID); err != nil {
if client, err := a.DB.GetClientByToken(tokenID); err != nil {
return false, false, 0, err
} else if token != nil {
return true, true, token.UserID, nil
} else if client != nil {
now := time.Now()
if client.LastUsed == nil || client.LastUsed.Add(5*time.Minute).Before(now) {
if err := a.DB.UpdateClientTokensLastUsed([]string{tokenID}, &now); err != nil {
return false, false, 0, err

Check warning on line 68 in auth/authentication.go

View check run for this annotation

Codecov / codecov/patch

auth/authentication.go#L68

Added line #L68 was not covered by tests
}
}
return true, true, client.UserID, nil
}
return false, false, 0, nil
})
Expand All @@ -71,10 +80,16 @@
if user != nil {
return true, false, 0, nil
}
if token, err := a.DB.GetApplicationByToken(tokenID); err != nil {
if app, err := a.DB.GetApplicationByToken(tokenID); err != nil {
return false, false, 0, err
} else if token != nil {
return true, true, token.UserID, nil
} else if app != nil {
now := time.Now()
if app.LastUsed == nil || app.LastUsed.Add(5*time.Minute).Before(now) {
if err := a.DB.UpdateApplicationTokenLastUsed(tokenID, &now); err != nil {
return false, false, 0, err

Check warning on line 89 in auth/authentication.go

View check run for this annotation

Codecov / codecov/patch

auth/authentication.go#L89

Added line #L89 was not covered by tests
}
}
return true, true, app.UserID, nil
}
return false, false, 0, nil
})
Expand Down
7 changes: 7 additions & 0 deletions database/application.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package database

import (
"time"

"github.com/gotify/server/v2/model"
"github.com/jinzhu/gorm"
)
Expand Down Expand Up @@ -56,3 +58,8 @@ func (d *GormDatabase) GetApplicationsByUser(userID uint) ([]*model.Application,
func (d *GormDatabase) UpdateApplication(app *model.Application) error {
return d.DB.Save(app).Error
}

// UpdateApplicationTokenLastUsed updates the last used time of the application token.
func (d *GormDatabase) UpdateApplicationTokenLastUsed(token string, t *time.Time) error {
return d.DB.Model(&model.Application{}).Where("token = ?", token).Update("last_used", t).Error
}
10 changes: 10 additions & 0 deletions database/application_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package database

import (
"time"

"github.com/gotify/server/v2/model"
"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -40,6 +42,14 @@ func (s *DatabaseSuite) TestApplication() {
assert.Equal(s.T(), app, newApp)
}

lastUsed := time.Now().Add(-time.Hour)
s.db.UpdateApplicationTokenLastUsed(app.Token, &lastUsed)
newApp, err = s.db.GetApplicationByID(app.ID)
if assert.NoError(s.T(), err) {
assert.Equal(s.T(), lastUsed.Unix(), newApp.LastUsed.Unix())
}
app.LastUsed = &lastUsed

newApp.Image = "asdasd"
assert.NoError(s.T(), s.db.UpdateApplication(newApp))

Expand Down
7 changes: 7 additions & 0 deletions database/client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package database

import (
"time"

"github.com/gotify/server/v2/model"
"github.com/jinzhu/gorm"
)
Expand Down Expand Up @@ -55,3 +57,8 @@ func (d *GormDatabase) DeleteClientByID(id uint) error {
func (d *GormDatabase) UpdateClient(client *model.Client) error {
return d.DB.Save(client).Error
}

// UpdateClientTokensLastUsed updates the last used timestamp of clients.
func (d *GormDatabase) UpdateClientTokensLastUsed(tokens []string, t *time.Time) error {
return d.DB.Model(&model.Client{}).Where("token IN (?)", tokens).Update("last_used", t).Error
}
9 changes: 9 additions & 0 deletions database/client_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package database

import (
"time"

"github.com/gotify/server/v2/model"
"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -44,6 +46,13 @@ func (s *DatabaseSuite) TestClient() {
assert.Equal(s.T(), updateClient, updatedClient)
}

lastUsed := time.Now().Add(-time.Hour)
s.db.UpdateClientTokensLastUsed([]string{client.Token}, &lastUsed)
newClient, err = s.db.GetClientByID(client.ID)
if assert.NoError(s.T(), err) {
assert.Equal(s.T(), lastUsed.Unix(), newClient.LastUsed.Unix())
}

s.db.DeleteClientByID(client.ID)

if clients, err := s.db.GetClientsByUser(user.ID); assert.NoError(s.T(), err) {
Expand Down
Loading
Loading