Skip to content

Commit

Permalink
feat: add AppCode into canceled token table
Browse files Browse the repository at this point in the history
  • Loading branch information
Mmx233 committed Apr 6, 2024
1 parent 3f207b9 commit a9a13fe
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 61 deletions.
2 changes: 1 addition & 1 deletion internal/api/controllers/app/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func DestroyRefreshToken(c *gin.Context) {
return
}

err = redis.CancelToken(context.Background(), claims.ID, claims.ExpiresAt.Time)
err = redis.CancelToken(context.Background(), claims.ID, claims.AppCode, claims.ExpiresAt.Time)
if err != nil {
callback.Error(c, callback.ErrUnexpected, err)
return
Expand Down
2 changes: 1 addition & 1 deletion internal/api/controllers/user/dev/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ func ModifyApp(c *gin.Context) {
ids := make([]uint, len(records))
for i, record := range records {
ids[i] = record.ID
err = redis.CancelToken(ctx, uint64(record.ID), time.Unix(int64(record.ValidBefore), 0))
err = redis.CancelToken(ctx, uint64(record.ID), record.AppCode, time.Unix(int64(record.ValidBefore), 0))
if err != nil {
if errors.Is(err, redis.Nil) {
continue
Expand Down
21 changes: 12 additions & 9 deletions internal/api/controllers/user/logout.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ import (
"github.com/ncuhome/GeniusAuthoritarian/internal/db/redis"
"github.com/ncuhome/GeniusAuthoritarian/internal/service"
"github.com/ncuhome/GeniusAuthoritarian/internal/tools"
"gorm.io/gorm"
"time"
)

func Logout(c *gin.Context) {
userClaims := tools.GetUserInfo(c)
loginID := tools.GetUserInfo(c).ID

loginRecordSrv, err := service.LoginRecord.Begin()
if err != nil {
Expand All @@ -21,13 +23,13 @@ func Logout(c *gin.Context) {
}
defer loginRecordSrv.Rollback()

err = loginRecordSrv.SetDestroyed(uint(userClaims.ID))
err = loginRecordSrv.SetDestroyed(uint(loginID))
if err != nil {
callback.Error(c, callback.ErrDBOperation, err)
return
}

err = redis.CancelToken(context.Background(), userClaims.ID, userClaims.ExpiresAt.Time)
err = redis.NewRecordedToken().NewStorePoint(loginID).Destroy(context.Background())
if err != nil {
callback.Error(c, callback.ErrUnexpected, err)
return
Expand Down Expand Up @@ -57,14 +59,15 @@ func LogoutDevice(c *gin.Context) {
}
defer loginRecordSrv.Rollback()

userClaims := tools.GetUserInfo(c)
exist, err := loginRecordSrv.OnlineRecordExist(userClaims.UID, f.ID, daoUtil.LockForUpdate)
uid := tools.GetUserInfo(c).UID
loginRecord, err := loginRecordSrv.TakeOnlineRecord(uid, f.ID, daoUtil.LockForUpdate)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
callback.Error(c, callback.ErrTargetDeviceOffline)
return
}
callback.Error(c, callback.ErrDBOperation, err)
return
} else if !exist {
callback.Error(c, callback.ErrTargetDeviceOffline)
return
}

err = loginRecordSrv.SetDestroyed(f.ID)
Expand All @@ -73,7 +76,7 @@ func LogoutDevice(c *gin.Context) {
return
}

err = redis.CancelToken(context.Background(), userClaims.ID, userClaims.ExpiresAt.Time)
err = redis.CancelToken(context.Background(), uint64(loginRecord.ID), loginRecord.AppCode, time.Unix(int64(loginRecord.ValidBefore), 0))
if err != nil {
if errors.Is(err, redis.Nil) {
callback.Error(c, callback.ErrTargetDeviceOffline)
Expand Down
28 changes: 17 additions & 11 deletions internal/db/dao/LoginRecord.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ func (a *LoginRecord) sqlLoginValid(tx *gorm.DB) *gorm.DB {
return tx.Where("login_records.valid_before>? AND NOT login_records.destroyed=1", time.Now().Unix())
}

func (a *LoginRecord) sqlGetForCancel(tx *gorm.DB) *gorm.DB {
tx = tx.Model(&LoginRecord{})
tx = a.sqlJoinApps(tx)
tx = a.sqlLoginValid(tx)
tx = tx.Select("login_records.id", "login_records.valid_before", "apps.app_code")
return tx
}

func (a *LoginRecord) Insert(tx *gorm.DB) error {
return tx.Create(a).Error
}
Expand All @@ -62,10 +70,9 @@ func (a *LoginRecord) GetByUID(tx *gorm.DB, limit int) ([]dto.LoginRecord, error
return t, tx.Find(&t).Error
}

func (a *LoginRecord) GetByAID(tx *gorm.DB) ([]dto.LoginRecordForCancel, error) {
func (a *LoginRecord) GetForCancelByAID(tx *gorm.DB) ([]dto.LoginRecordForCancel, error) {
var t []dto.LoginRecordForCancel
tx = tx.Model(a)
tx = a.sqlLoginValid(tx)
tx = a.sqlGetForCancel(tx)
return t, tx.Where(a, "aid").Find(&t).Error
}

Expand All @@ -76,14 +83,6 @@ func (a *LoginRecord) GetValidForUser(tx *gorm.DB) ([]dto.LoginRecordOnline, err
return t, tx.Find(&t).Error
}

func (a *LoginRecord) ValidExist(tx *gorm.DB) (bool, error) {
var t bool
tx = tx.Model(&LoginRecord{}).Select("1")
tx = a.sqlLoginValid(tx)
tx = tx.Where(a, "id", "uid").Limit(1)
return t, tx.Find(&t).Error
}

func (a *LoginRecord) GetLastMonth(tx *gorm.DB) ([]LoginRecord, error) {
var t []LoginRecord
return t, tx.Model(a).Where("created_at<=?", 604800).Order("id DESC").Find(&t).Error
Expand All @@ -110,3 +109,10 @@ func (a *LoginRecord) GetAdminViews(tx *gorm.DB, startAt int64) ([]dto.LoginReco
var t = make([]dto.LoginRecordDataView, 0)
return t, tx.Model(a).Where("created_at>=?", startAt).Order("id DESC").Find(&t).Error
}

func (a *LoginRecord) TakeValidForCancel(tx *gorm.DB) (*dto.LoginRecordForCancel, error) {
var t dto.LoginRecordForCancel
tx = a.sqlGetForCancel(tx)
tx = tx.Where(a, "id", "uid")
return &t, tx.Take(&t).Error
}
1 change: 1 addition & 0 deletions internal/db/dto/LoginRecord.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,6 @@ type AdminLoginDataView struct {

type LoginRecordForCancel struct {
ID uint `json:"id"`
AppCode string `json:"appCode"`
ValidBefore uint64 `json:"validBefore"`
}
86 changes: 51 additions & 35 deletions internal/db/redis/RefreshToken.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,30 @@ package redis

import (
"context"
"encoding/json"
"github.com/Mmx233/tool"
"github.com/go-redis/redis/v8"
"github.com/ncuhome/GeniusAuthoritarian/pkg/tokenStore"
log "github.com/sirupsen/logrus"
"go/types"
"strconv"
"time"
"unsafe"
)

func CancelToken(ctx context.Context, id uint64, validBefore time.Time) error {
err := NewCanceledToken().Add(ctx, CanceledToken{
ID: id,
ValidBefore: validBefore,
})
func CancelToken(ctx context.Context, id uint64, appCode string, validBefore time.Time) error {
canceledToken := CanceledToken{
ID: id,
CanceledTokenPayload: CanceledTokenPayload{
AppCode: appCode,
ValidBefore: validBefore,
},
}
err := NewCanceledToken().Add(ctx, canceledToken)
if err != nil {
return err
}
if err = NewCanceledTokenChannel().Publish(ctx, id); err != nil {
if err = NewCanceledTokenChannel().Publish(ctx, canceledToken); err != nil {
return err
}
return NewRecordedToken().NewStorePoint(id).Destroy(ctx)
Expand All @@ -39,8 +45,12 @@ type CanceledTokenChannel struct {
key string
}

func (a CanceledTokenChannel) Publish(ctx context.Context, id uint64) error {
return Client.Publish(ctx, a.key, strconv.FormatUint(id, 10)).Err()
func (a CanceledTokenChannel) Publish(ctx context.Context, token CanceledToken) error {
data, err := json.Marshal(token)
if err != nil {
return err
}
return Client.Publish(ctx, a.key, data).Err()
}

func (a CanceledTokenChannel) Subscribe(ctx context.Context) *redis.PubSub {
Expand All @@ -53,76 +63,82 @@ func NewCanceledToken() CanceledTokenTable {
}
}

type CanceledTokenPayload struct {
AppCode string `json:"appCode"`
ValidBefore time.Time `json:"validBefore"`
}

type CanceledToken struct {
ID uint64
ValidBefore time.Time
ID uint64 `json:"id"`
CanceledTokenPayload
}

func (a CanceledToken) Key() string {
return strconv.FormatUint(a.ID, 10)
}

func (a CanceledToken) Value() string {
return a.ValidBefore.Format(time.RFC3339)
}

type CanceledTokenTable struct {
key string
}

func (a CanceledTokenTable) Add(ctx context.Context, id ...CanceledToken) error {
fields := make([]interface{}, len(id)*2)
for i, v := range id {
func (a CanceledTokenTable) Add(ctx context.Context, tokens ...CanceledToken) error {
fields := make([]interface{}, len(tokens)*2)
for i, v := range tokens {
fields[i*2] = v.Key()
fields[i*2+1] = v.Value()
data, err := json.Marshal(v.CanceledTokenPayload)
if err != nil {
return err
}
fields[i*2+1] = data
}
return Client.HSet(ctx, a.key, fields...).Err()
}

func (a CanceledTokenTable) Get(ctx context.Context) ([]uint64, error) {
func (a CanceledTokenTable) Get(ctx context.Context) ([]CanceledToken, error) {
result, err := Client.HGetAll(ctx, a.key).Result()
if err != nil {
return nil, err
}
ids := make([]uint64, len(result))
canceledTokens := make([]CanceledToken, len(result))
left, right := 0, len(result)-1
for k, v := range result {
id, err := strconv.ParseUint(k, 10, 64)
var canceledToken CanceledToken
var err error
canceledToken.ID, err = strconv.ParseUint(k, 10, 64)
if err != nil {
log.Errorln("parse id failed", err)
continue
}
validBefore, err := time.Parse(time.RFC3339, v)
if err != nil {
log.Errorln("parse time failed", err)
if err = json.Unmarshal(unsafe.Slice(unsafe.StringData(v), len(v)), &canceledToken.CanceledTokenPayload); err != nil {
log.Errorln("parse canceled token failed", err)
continue
}
if validBefore.After(time.Now()) {
ids[left] = id
if canceledToken.ValidBefore.After(time.Now()) {
canceledTokens[left] = canceledToken
left++
} else {
ids[right] = id
canceledTokens[right] = canceledToken
right--
}
}
if left != len(result)-1 {
go a.clean(ids[left+1:]...)
go a.clean(canceledTokens[left+1:]...)
}
return ids[:left], nil
return canceledTokens[:left], nil
}

func (a CanceledTokenTable) clean(id ...uint64) {
func (a CanceledTokenTable) clean(tokens ...CanceledToken) {
defer tool.Recover()
err := a.remove(context.Background(), id...)
err := a.remove(context.Background(), tokens...)
if err != nil {
log.Errorln("clean canceled token failed", err)
}
}

func (a CanceledTokenTable) remove(ctx context.Context, id ...uint64) error {
keyGroup := make([]string, len(id))
for i, v := range id {
keyGroup[i] = strconv.FormatUint(v, 10)
func (a CanceledTokenTable) remove(ctx context.Context, tokens ...CanceledToken) error {
keyGroup := make([]string, len(tokens))
for i, v := range tokens {
keyGroup[i] = v.Key()
}
return Client.HDel(ctx, a.key, keyGroup...).Err()

Expand Down
2 changes: 1 addition & 1 deletion internal/rpc/refreshToken/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (s *Server) DestroyRefreshToken(ctx context.Context, req *refreshTokenProto
return nil, status.Error(codes.Internal, "database error")
}

err = redis.CancelToken(ctx, claims.ID, claims.ExpiresAt.Time)
err = redis.CancelToken(ctx, claims.ID, claims.AppCode, claims.ExpiresAt.Time)
if err != nil {
return nil, status.Error(codes.Internal, "destroy token failed")
}
Expand Down
6 changes: 3 additions & 3 deletions internal/service/LoginRecord.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (a LoginRecordSrv) SetDestroyedByIDS(ids []uint) error {
}

func (a LoginRecordSrv) GetValidForApp(aid uint, opt ...daoUtil.ServiceOpt) ([]dto.LoginRecordForCancel, error) {
return (&dao.LoginRecord{AID: &aid}).GetByAID(daoUtil.TxOpts(a.DB, opt...))
return (&dao.LoginRecord{AID: &aid}).GetForCancelByAID(daoUtil.TxOpts(a.DB, opt...))
}

func (a LoginRecordSrv) UserHistory(uid uint, limit int) ([]dto.LoginRecord, error) {
Expand Down Expand Up @@ -90,11 +90,11 @@ func (a LoginRecordSrv) UserOnline(uid uint, currentLoginID uint) ([]dto.LoginRe
return validRecords[0:pointer], nil
}

func (a LoginRecordSrv) OnlineRecordExist(uid, id uint, opts ...daoUtil.ServiceOpt) (bool, error) {
func (a LoginRecordSrv) TakeOnlineRecord(uid, id uint, opts ...daoUtil.ServiceOpt) (*dto.LoginRecordForCancel, error) {
return (&dao.LoginRecord{
ID: id,
UID: uid,
}).ValidExist(daoUtil.TxOpts(a.DB, opts...))
}).TakeValidForCancel(daoUtil.TxOpts(a.DB, opts...))
}

func (a LoginRecordSrv) GetViewIDs(aid, startAt uint) ([]uint, error) {
Expand Down

0 comments on commit a9a13fe

Please sign in to comment.