/
packet_server.go
148 lines (125 loc) · 3.54 KB
/
packet_server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
package server
import (
"crypto/sha256"
"fmt"
"github.com/bokysan/socketace/v2/internal/streams"
"github.com/bokysan/socketace/v2/internal/util/addr"
"github.com/bokysan/socketace/v2/internal/util/cert"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"github.com/xtaci/kcp-go/v5"
"golang.org/x/crypto/pbkdf2"
"net"
"strings"
)
// ListenerFromPacketConn builds a net.Listener on top of an already opened packet connection.
// The block cipher may be nil; StartupPacket passes nil when the server runs unencrypted.
type ListenerFromPacketConn func(crypt kcp.BlockCrypt, packetConn net.PacketConn) (net.Listener, error)
// PacketServer accepts connections over a packet-oriented transport (by default kcp on top of
// a net.PacketConn, see DefaultListenerFromPacketConn) and hands each accepted connection to
// AcceptConnection together with the configured upstream channels.
type PacketServer struct {
	// Embedded server certificate/TLS configuration; passed to AcceptConnection for every
	// accepted connection.
	cert.ServerConfig
	// Address is the listen address. A non-empty password in its user info enables AES
	// encryption of the packet stream; StartupPacket strips the user info after reading it.
	Address addr.ProtoAddress `json:"address"`
	// Channels lists the names of the upstream channels this server exposes; it is used to
	// filter the Channels handed to Startup/StartupPacket.
	Channels []string `json:"channels"`
	// PacketConnection is the socket to serve on. When nil at startup, StartupPacket opens one
	// for Address.
	PacketConnection net.PacketConn
	// upstreams holds the channels left after filtering by Channels.
	upstreams Channels
	// listener is created by StartupPacket and closed by Shutdown.
	listener net.Listener
	// done is set by Shutdown and polled by the accept loop.
	// NOTE(review): it is written and read from different goroutines without any
	// synchronization, which is a data race; consider an atomic flag.
	done bool
}
// NewPacketServer returns a pointer to a zero-valued PacketServer. Address, Channels and
// (optionally) PacketConnection are expected to be filled in before Startup is called.
func NewPacketServer() *PacketServer {
	server := new(PacketServer)
	return server
}
// String renders the server's listen address; it is what the startup log lines print.
// After StartupPacket has run, the address no longer carries user info, so any password is
// not part of the output.
func (st *PacketServer) String() string {
	text := st.Address.String()
	return fmt.Sprint(text)
}
// DefaultListenerFromPacketConn is the ListenerFromPacketConn used by Startup. It serves kcp
// sessions over conn with forward error correction configured as 10 data shards and 3 parity
// shards. block is the session cipher; StartupPacket passes nil for an unencrypted server.
//
// On failure it returns a literal nil listener together with the error from kcp.ServeConn.
func DefaultListenerFromPacketConn(block kcp.BlockCrypt, conn net.PacketConn) (net.Listener, error) {
	const (
		dataShards   = 10
		parityShards = 3
	)
	l, err := kcp.ServeConn(block, dataShards, parityShards, conn)
	if err != nil {
		// Do not return l here: a nil *kcp.Listener stored in a net.Listener interface would
		// compare as non-nil for callers.
		return nil, err
	}
	return l, nil
}
// Startup starts the packet server using the default kcp listener factory
// (DefaultListenerFromPacketConn). See StartupPacket for what happens during startup and
// which errors can be returned.
func (st *PacketServer) Startup(channels Channels) error {
	var factory ListenerFromPacketConn = DefaultListenerFromPacketConn
	return st.StartupPacket(channels, factory)
}
// StartupPacket starts the packet server, using createListenerFunc to turn the packet
// connection into a net.Listener.
//
// Steps:
//   - filters channels down to the names in st.Channels and stores the result as upstreams;
//   - if st.Address carries a non-empty password, derives an AES block cipher from it;
//   - removes the user info from st.Address so the password never shows up in String()/logs;
//   - opens a packet socket for the address unless st.PacketConnection is already set;
//   - builds the listener and starts the accept loop on a background goroutine.
//
// If cipher or listener creation fails after this call opened the socket itself, that socket
// is closed again and st.PacketConnection is reset to nil. A caller-supplied
// PacketConnection is never closed here. All returned errors are wrapped with a stack trace.
func (st *PacketServer) StartupPacket(channels Channels, createListenerFunc ListenerFromPacketConn) error {
	upstreams, err := channels.Filter(st.Channels)
	if err != nil {
		return errors.WithStack(err)
	}
	st.upstreams = upstreams

	// Keep a copy that still carries the credentials; it is only used to resolve the socket
	// address below, after the user info has been stripped from st.Address.
	a := st.Address
	var secure bool
	var pass []byte
	if st.Address.User != nil {
		if p, set := st.Address.User.Password(); set && p != "" {
			secure = true
			pass = []byte(p)
		}
	}
	st.Address.User = nil

	ownsConn := false
	if st.PacketConnection == nil {
		n, err := a.Addr()
		if err != nil {
			return errors.WithStack(err)
		}
		conn, err := net.ListenPacket(n.Network(), n.String())
		if err != nil {
			return errors.WithStack(err)
		}
		st.PacketConnection = conn
		ownsConn = true
	}

	// releaseConn undoes the socket creation above when a later setup step fails, so a failed
	// startup does not leave the port bound.
	releaseConn := func() {
		if !ownsConn {
			return
		}
		if cerr := st.PacketConnection.Close(); cerr != nil {
			log.WithError(cerr).Warnf("Error closing packet connection of %s", st.String())
		}
		st.PacketConnection = nil
	}

	var block kcp.BlockCrypt
	if secure {
		log.Infof("Starting AES-encrypted packet server at %s", st.String())
		// Not the best way to calculate salt but still better than nothing
		salt := sha256.Sum256(pass)
		// AES only accepts 16, 24 or 32 byte keys (crypto/aes.NewCipher, which
		// kcp.NewAESBlockCrypt delegates to), so the previous 64-byte key was always
		// rejected. Derive 32 bytes to select AES-256.
		// NOTE: the client must derive its key the same way (salt, iterations, length).
		key := pbkdf2.Key(pass, salt[:], 1024, 32, sha256.New)
		b, err := kcp.NewAESBlockCrypt(key)
		if err != nil {
			releaseConn()
			return errors.WithStack(err)
		}
		block = b
	} else {
		log.Infof("Starting plain packet server at %s", st.String())
	}

	listener, err := createListenerFunc(block, st.PacketConnection)
	if err != nil {
		releaseConn()
		return errors.WithStack(err)
	}
	st.listener = listener

	go st.acceptConnection()
	return nil
}
// acceptConnection is the accept loop started by StartupPacket. It runs until one of two
// things happens: Shutdown sets st.done, or the listener reports that it has been closed.
//
// Each accepted connection is wrapped as a named "packet" connection and passed to
// AcceptConnection with the server config and the filtered upstreams. Accept errors that do
// not indicate a closed listener are logged, and the loop keeps accepting.
//
// NOTE(review): st.done is written by Shutdown on another goroutine without synchronization
// (data race). AcceptConnection is also called inline; if it blocks for the lifetime of a
// connection, further clients wait until it returns. Confirm against its implementation.
func (st *PacketServer) acceptConnection() {
	for !st.done {
		conn, err := st.listener.Accept()
		if conn != nil {
			conn = streams.NewNamedConnection(conn, "packet")
			log.Debugf("New connection detected: %+v", conn)
		}
		if st.done {
			if conn != nil && err == nil {
				streams.TryClose(conn)
			}
			break
		}
		if err != nil {
			if conn != nil {
				streams.TryClose(conn)
			}
			msg := err.Error()
			// A closed listener fails every later Accept immediately, so retrying would spin
			// forever. A closed socket reports "use of closed network connection"; kcp-go's
			// listener reports io.ErrClosedPipe ("io: read/write on closed pipe") once closed.
			if strings.Contains(msg, "use of closed network connection") || strings.Contains(msg, "closed pipe") {
				log.Debugf("Packet listener at %s closed, stopping accept loop: %v", st.String(), err)
				return
			}
			log.WithError(err).Errorf("Error accepting the connection: %v", err)
			continue
		}
		// Even though the connection might be secured by an AES-encrypted symmetric cipher, we
		// state here "secure=false" to enable the client to provide StartTLS and do a potential
		// host check and/or identify itself with a client certificate
		if err = AcceptConnection(conn, &st.ServerConfig, false, st.upstreams); err != nil {
			log.WithError(err).Errorf("Error accepting connection: %v", err)
		}
	}
}
// Shutdown stops the packet server. It first raises the done flag polled by the accept loop,
// then closes the listener through streams.LogClose and returns that call's result.
//
// NOTE(review): st.done is written here and read by the accept goroutine without
// synchronization (data race). kcp.ServeConn does not appear to take ownership of the packet
// connection, so a socket opened by StartupPacket is presumably left open here; confirm.
func (st *PacketServer) Shutdown() error {
	// Raise the flag before closing, so the Accept failure caused by the close is treated as
	// a deliberate shutdown rather than an error.
	st.done = true
	l := st.listener
	return streams.LogClose(l)
}