/
handler.go
93 lines (86 loc) · 2.21 KB
/
handler.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
// Copyright 2021 The Mellium Contributors.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.
package blocklist
import (
	"encoding/xml"
	"fmt"

	"mellium.im/xmlstream"
	"mellium.im/xmpp/jid"
	"mellium.im/xmpp/mux"
	"mellium.im/xmpp/stanza"
)
// Handle returns a mux option that routes every blocking command payload to h:
// blocklist queries (IQ get) as well as block and unblock requests (IQ set).
func Handle(h Handler) mux.Option {
	payload := func(local string) xml.Name {
		return xml.Name{Space: NS, Local: local}
	}
	return func(m *mux.ServeMux) {
		// Registration order matches the payload order listed above.
		for _, register := range []mux.Option{
			mux.IQ(stanza.GetIQ, payload("blocklist"), h),
			mux.IQ(stanza.SetIQ, payload("block"), h),
			mux.IQ(stanza.SetIQ, payload("unblock"), h),
		} {
			register(m)
		}
	}
}
// Handler can be used to respond to incoming blocking command requests.
//
// Any field may be left nil, in which case the corresponding callback is
// simply skipped; a blocklist request with a nil List still receives a result
// containing an empty blocklist.
type Handler struct {
	// Block is called once for each item in an incoming block request, with
	// the JID taken from that item.
	Block func(jid.JID)

	// Unblock is called once for each item in an incoming unblock request,
	// with the JID taken from that item.
	Unblock func(jid.JID)

	// UnblockAll is called when an incoming unblock request contains no child
	// elements at all.
	UnblockAll func()

	// List is called in its own goroutine when a blocklist request is
	// received. Each JID sent on the channel is written to the response as an
	// item. The handler closes the channel after List returns, so List must
	// not close it itself (doing so would cause a double-close panic).
	List func(chan<- jid.JID)
}
// HandleIQ implements mux.IQHandler.
//
// For a "blocklist" request it writes an IQ result whose payload contains one
// <item jid="…"/> per JID produced by h.List (an empty list when h.List is
// nil).
//
// For "block" and "unblock" requests it calls h.Block or h.Unblock with the
// jid attribute of each <item/> child. An "unblock" request with no child
// elements calls h.UnblockAll instead. Because the payload comes from the
// remote entity, an item whose jid attribute is missing or malformed results
// in an error rather than a panic.
func (h Handler) HandleIQ(iq stanza.IQ, r xmlstream.TokenReadEncoder, start *xml.StartElement) error {
	if start.Name.Local == "blocklist" {
		res := iq.Result(xmlstream.Wrap(nil, *start))
		// Copy the start IQ and start payload first.
		_, err := xmlstream.Copy(r, xmlstream.LimitReader(res, 2))
		if err != nil {
			return err
		}
		if h.List != nil {
			c := make(chan jid.JID)
			go func() {
				defer close(c)
				h.List(c)
			}()
			for j := range c {
				// After a write error keep receiving (and discarding) values so
				// that List can run to completion; returning early would leave
				// its goroutine blocked forever on the unbuffered send.
				if err != nil {
					continue
				}
				_, err = xmlstream.Copy(r, xmlstream.Wrap(nil, xml.StartElement{
					Name: xml.Name{Space: NS, Local: "item"},
					Attr: []xml.Attr{{
						Name:  xml.Name{Local: "jid"},
						Value: j.String(),
					}},
				}))
			}
			if err != nil {
				return err
			}
		}
		// Copy the end payload and end IQ.
		_, err = xmlstream.Copy(r, xmlstream.LimitReader(res, 2))
		return err
	}

	iter := xmlstream.NewIter(r)
	var found bool
	for iter.Next() {
		// Any child element counts as found, so an unblock request carrying
		// unknown children is never mistaken for an "unblock all" request.
		found = true
		itemStart, _ := iter.Current()
		if itemStart == nil || itemStart.Name.Local != "item" {
			continue
		}

		// Look the attribute up by name instead of assuming it comes first.
		var jstr string
		for _, attr := range itemStart.Attr {
			if attr.Name.Local == "jid" {
				jstr = attr.Value
				break
			}
		}
		j, err := jid.Parse(jstr)
		if err != nil {
			return fmt.Errorf("blocklist: parsing item jid %q in %s request: %w", jstr, start.Name.Local, err)
		}

		switch start.Name.Local {
		case "block":
			if h.Block != nil {
				h.Block(j)
			}
		case "unblock":
			if h.Unblock != nil {
				h.Unblock(j)
			}
		}
	}
	if err := iter.Err(); err != nil {
		return err
	}
	if !found && start.Name.Local == "unblock" && h.UnblockAll != nil {
		h.UnblockAll()
	}
	// NOTE(review): no IQ result is written for block/unblock requests, but
	// XEP-0191 expects them to be acknowledged. Confirm whether the caller is
	// responsible for sending the response.
	return nil
}