/
proto.go
93 lines (76 loc) · 1.87 KB
/
proto.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
package exchange
import (
"context"
"time"
"golang.org/x/xerrors"
"github.com/gotd/td/bin"
"github.com/gotd/td/clock"
"github.com/gotd/td/internal/proto"
"github.com/gotd/td/internal/proto/codec"
"github.com/gotd/td/transport"
)
// unencryptedWriter sends and receives proto.UnencryptedMessage values
// over a transport connection, applying a timeout to every send and
// every receive attempt.
type unencryptedWriter struct {
	// clock supplies the current time used when generating message IDs.
	clock clock.Clock
	// conn is the connection messages are sent to and received from.
	conn transport.Conn
	// timeout bounds each individual Send and Recv call.
	timeout time.Duration
	// input is the message type expected in IDs of received messages;
	// output is the message type stamped into IDs of sent messages.
	input, output proto.MessageType
}
// writeUnencrypted encodes data, wraps the result in a
// proto.UnencryptedMessage and sends it over w.conn.
//
// b is used as scratch space: it is reset, filled with the encoded
// payload, reset again and filled with the encoded envelope, which is
// what is finally sent. The send is bounded by w.timeout on top of any
// deadline already carried by ctx.
//
// Any error from encoding or sending is returned as is.
func (w unencryptedWriter) writeUnencrypted(ctx context.Context, b *bin.Buffer, data bin.Encoder) error {
	// Stage 1: encode the payload into the scratch buffer.
	b.Reset()
	if encodeErr := data.Encode(b); encodeErr != nil {
		return encodeErr
	}

	// Stage 2: build the envelope. The message ID is derived from the
	// current clock reading and the outgoing message type. The payload is
	// taken via b.Copy() before b is reset and reused for the envelope.
	msgID := proto.NewMessageID(w.clock.Now(), w.output)
	envelope := proto.UnencryptedMessage{
		MessageID:   int64(msgID),
		MessageData: b.Copy(),
	}

	b.Reset()
	if encodeErr := envelope.Encode(b); encodeErr != nil {
		return encodeErr
	}

	// Stage 3: send with a per-call timeout.
	sendCtx, cancel := context.WithTimeout(ctx, w.timeout)
	defer cancel()

	return w.conn.Send(sendCtx, b)
}
// tryRead performs a single receive from w.conn into b, bounded by
// w.timeout in addition to any deadline already carried by ctx.
// The error from Recv, if any, is returned unchanged.
func (w unencryptedWriter) tryRead(ctx context.Context, b *bin.Buffer) error {
	recvCtx, cancel := context.WithTimeout(ctx, w.timeout)
	defer cancel()

	return w.conn.Recv(recvCtx, b)
}
// isClient reports whether this writer produces client-originated
// messages, i.e. whether its output type is proto.MessageFromClient.
func (w unencryptedWriter) isClient() bool {
	switch w.output {
	case proto.MessageFromClient:
		return true
	default:
		return false
	}
}
// readUnencrypted receives one proto.UnencryptedMessage from w.conn,
// validates the type encoded in its message ID and decodes the message
// payload into data.
//
// b is used as scratch space for both the raw message and, afterwards,
// its payload.
//
// When acting as a client, a receive that fails with a
// *codec.ProtocolErr carrying codec.CodeAuthKeyNotFound is retried.
// The retry loop has no iteration limit of its own; it ends on a
// successful receive or on any other error (each attempt is made via
// tryRead, which applies w.timeout). All other errors are returned
// unchanged.
func (w unencryptedWriter) readUnencrypted(ctx context.Context, b *bin.Buffer, data bin.Decoder) error {
	b.Reset()

	for {
		recvErr := w.tryRead(ctx, b)
		if recvErr == nil {
			break
		}

		// Only a client retries, and only on the "auth key not found"
		// protocol error; everything else is fatal for this read.
		var protoErr *codec.ProtocolErr
		retry := w.isClient() &&
			xerrors.As(recvErr, &protoErr) &&
			protoErr.Code == codec.CodeAuthKeyNotFound
		if !retry {
			return recvErr
		}
	}

	var envelope proto.UnencryptedMessage
	if decodeErr := envelope.Decode(b); decodeErr != nil {
		return decodeErr
	}

	if idErr := w.checkMsgID(envelope.MessageID); idErr != nil {
		return idErr
	}

	// Point the buffer at the payload so data can decode from it.
	b.ResetTo(envelope.MessageData)

	return data.Decode(b)
}
// checkMsgID verifies that the message type encoded in id equals the
// type this writer expects to receive (w.input). It returns a
// "bad msg type" error on mismatch and nil otherwise.
func (w unencryptedWriter) checkMsgID(id int64) error {
	msgType := proto.MessageID(id).Type()
	if msgType == w.input {
		return nil
	}

	return xerrors.New("bad msg type")
}