Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

admin: Add notify endpoint (POST). #555

Merged
merged 2 commits into from
Jul 24, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 44 additions & 11 deletions server/admin/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ import (
)

const (
pongStr = "pong"
maxUInt16 = int(^uint16(0))
pongStr = "pong"
maxUInt16 = int(^uint16(0))
defaultTimeout = time.Hour * 72
)

// writeJSON marshals the provided interface and writes the bytes to the
Expand Down Expand Up @@ -255,30 +256,62 @@ func (s *Server) apiUnban(w http.ResponseWriter, r *http.Request) {
writeJSON(w, res)
}

// apiNotifyAll is the handler for the '/notifyall' API request.
func (s *Server) apiNotifyAll(w http.ResponseWriter, r *http.Request) {
func toNote(r *http.Request) (*msgjson.Message, int, error) {
body, err := ioutil.ReadAll(r.Body)
r.Body.Close()
if err != nil {
http.Error(w, fmt.Sprintf("unable to read request body: %v", err), http.StatusInternalServerError)
return
return nil, http.StatusInternalServerError, fmt.Errorf("unable to read request body: %v", err)
}
if len(body) == 0 {
http.Error(w, "no message to broadcast", http.StatusBadRequest)
return
return nil, http.StatusBadRequest, errors.New("no message to broadcast")
}
// Remove trailing newline if present. A newline is added by the curl
// command when sending from file.
if body[len(body)-1] == '\n' {
body = body[:len(body)-1]
}
if len(body) > maxUInt16 {
http.Error(w, fmt.Sprintf("cannot send messages larger than %d bytes", maxUInt16), http.StatusBadRequest)
return
return nil, http.StatusBadRequest, fmt.Errorf("cannot send messages larger than %d bytes", maxUInt16)
}
msg, err := msgjson.NewNotification(msgjson.NotifyRoute, string(body))
if err != nil {
http.Error(w, fmt.Sprintf("unable to create notification: %v", err), http.StatusInternalServerError)
return nil, http.StatusInternalServerError, fmt.Errorf("unable to create notification: %v", err)
}
return msg, 0, nil
}

// apiNotify is the handler for the '/account/{accountID}/notify?timeout=TIMEOUT'
// API request.
func (s *Server) apiNotify(w http.ResponseWriter, r *http.Request) {
acctIDStr := chi.URLParam(r, accountIDKey)
acctID, err := decodeAcctID(acctIDStr)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
timeout := defaultTimeout
if timeoutStr := r.URL.Query().Get(timeoutToken); timeoutStr != "" {
var err error
timeout, err = time.ParseDuration(timeoutStr)
if err != nil {
http.Error(w, fmt.Sprintf("invalid timeout %q: %v", timeoutStr, err), http.StatusBadRequest)
return
}
}
msg, errCode, err := toNote(r)
if err != nil {
http.Error(w, err.Error(), errCode)
return
}
s.core.Notify(acctID, msg, timeout)
w.WriteHeader(http.StatusOK)
}

// apiNotifyAll is the handler for the '/notifyall' API request.
func (s *Server) apiNotifyAll(w http.ResponseWriter, r *http.Request) {
msg, errCode, err := toNote(r)
if err != nil {
http.Error(w, err.Error(), errCode)
return
}
s.core.NotifyAll(msg)
Expand Down
7 changes: 5 additions & 2 deletions server/admin/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ const (
accountIDKey = "account"
ruleToken = "rule"
messageToken = "message"
timeoutToken = "timeout"
)

var (
Expand All @@ -44,8 +45,9 @@ var (

// SvrCore is satisfied by server/dex.DEX.
type SvrCore interface {
Accounts() ([]*db.Account, error)
AccountInfo(account.AccountID) (*db.Account, error)
Accounts() (accts []*db.Account, err error)
AccountInfo(acctID account.AccountID) (*db.Account, error)
Notify(acctID account.AccountID, msg *msgjson.Message, timeout time.Duration)
NotifyAll(msg *msgjson.Message)
ConfigMsg() json.RawMessage
MarketRunning(mktName string) (found, running bool)
Expand Down Expand Up @@ -134,6 +136,7 @@ func NewServer(cfg *SrvConfig) (*Server, error) {
rm.Get("/", s.apiAccountInfo)
rm.Get("/ban", s.apiBan)
rm.Get("/unban", s.apiUnban)
rm.Post("/notify", s.apiNotify)
})
r.Post("/notifyall", s.apiNotifyAll)
r.Get("/markets", s.apiMarkets)
Expand Down
73 changes: 72 additions & 1 deletion server/admin/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ func (c *TCore) Penalize(_ account.AccountID, _ account.Rule) error {
func (c *TCore) Unban(_ account.AccountID) error {
return c.unbanErr
}
func (c *TCore) NotifyAll(msg *msgjson.Message) {}
func (c *TCore) Notify(_ account.AccountID, _ *msgjson.Message, _ time.Duration) {}
func (c *TCore) NotifyAll(_ *msgjson.Message) {}

// genCertPair generates a key/cert pair to the paths provided.
func genCertPair(certFile, keyFile string) error {
Expand Down Expand Up @@ -1086,6 +1087,76 @@ func TestUnban(t *testing.T) {
}
}

func TestNotify(t *testing.T) {
core := new(TCore)
srv := &Server{
core: core,
}
mux := chi.NewRouter()
mux.Route("/account/{"+accountIDKey+"}/notify", func(rm chi.Router) {
rm.Post("/", srv.apiNotify)
})
acctIDStr := "0a9912205b2cbab0c25c2de30bda9074de0ae23b065489a99199bad763f102cc"
msgStr := "Hello world.\nAll your base are belong to us."
tests := []struct {
name, txt, acctID, timeout string
wantCode int
}{{
name: "ok no timeout",
acctID: acctIDStr,
txt: msgStr,
wantCode: http.StatusOK,
}, {
name: "ok with timeout",
acctID: acctIDStr,
timeout: "5h3m59s",
txt: msgStr,
wantCode: http.StatusOK,
}, {
name: "ok at max size",
acctID: acctIDStr,
txt: string(make([]byte, maxUInt16)),
wantCode: http.StatusOK,
}, {
name: "message too long",
acctID: acctIDStr,
txt: string(make([]byte, maxUInt16+1)),
wantCode: http.StatusBadRequest,
}, {
name: "bad duration",
acctID: acctIDStr,
timeout: "1d",
txt: msgStr,
wantCode: http.StatusBadRequest,
}, {
name: "account id not hex",
acctID: "nothex",
txt: msgStr,
wantCode: http.StatusBadRequest,
}, {
name: "account id wrong length",
acctID: acctIDStr[2:],
txt: msgStr,
wantCode: http.StatusBadRequest,
}, {
name: "no message",
acctID: acctIDStr,
wantCode: http.StatusBadRequest,
}}
for _, test := range tests {
w := httptest.NewRecorder()
br := bytes.NewReader([]byte(test.txt))
r, _ := http.NewRequest("POST", "https://localhost/account/"+test.acctID+"/notify?timeout="+test.timeout, br)
r.RemoteAddr = "localhost"

mux.ServeHTTP(w, r)

if w.Code != test.wantCode {
t.Fatalf("%q: apiNotify returned code %d, expected %d", test.name, w.Code, test.wantCode)
}
}
}

func TestNotifyAll(t *testing.T) {
core := new(TCore)
srv := &Server{
Expand Down
10 changes: 10 additions & 0 deletions server/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,16 @@ func (auth *AuthManager) RequestWithTimeout(user account.AccountID, msg *msgjson
return auth.request(user, msg, f, expireTimeout, 0, expire)
}

// Notify sends a message to a client. The message should be a notification.
// See msgjson.NewNotification. The notification is abandoned upon timeout
// being reached.
func (auth *AuthManager) Notify(acctID account.AccountID, msg *msgjson.Message, timeout time.Duration) {
missedFn := func() {
log.Warnf("user %s missed notification: \n%v", acctID, msg)
}
auth.SendWhenConnected(acctID, msg, timeout, missedFn)
}

// Penalize signals that a user has broken a rule of community conduct, and that
// their account should be penalized.
func (auth *AuthManager) Penalize(user account.AccountID, rule account.Rule) error {
Expand Down
7 changes: 7 additions & 0 deletions server/dex/dex.go
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,13 @@ func (dm *DEX) Unban(aid account.AccountID) error {
return dm.authMgr.Unban(aid)
}

// Notify sends a text notification to a connected client. If not currently
// connected, sending will be attempted on connection unless the duration of
// timeout has passed.
func (dm *DEX) Notify(acctID account.AccountID, msg *msgjson.Message, timeout time.Duration) {
dm.authMgr.Notify(acctID, msg, timeout)
}

// NotifyAll sends a text notification to all connected clients.
func (dm *DEX) NotifyAll(msg *msgjson.Message) {
dm.server.Broadcast(msg)
Expand Down
2 changes: 2 additions & 0 deletions spec/admin.mediawiki
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ The server will provide an HTTP API for performing various adminstrative tasks.
|-
| /account/{accountID}/ban?rule=RULE || GET || ban an account for violating [[community.mediawiki/#Rules_of_Community_Conduct|a rule]]
|-
| /account/{accountID}/notify?timeout=TIMEOUT || POST || send a notification containing text in the request body to account. If not currently connected, the notificatin will be sent upon reconnect unless timeout duration has passed. default timeout is 72 hours. timeout should be of the form #h#m#s (i.e. "2h" or "5h30m"). Header Content-Type must be set to "text/plain"
JoeGruffins marked this conversation as resolved.
Show resolved Hide resolved
|-
| /markets || GET || display status information for all markets
|-
| /market/{marketID} || GET || display status information for a specific market
Expand Down