Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ type Request struct {
// User Configuration
AuthHandler func(username string, realm string, srcAddr net.Addr) (key []byte, ok bool)

// Quota Handler
QuotaHandler func(username string, realm string, srcAddr net.Addr) (ok bool)

Log logging.LeveledLogger
Realm string
ChannelBindTimeout time.Duration
Expand Down
7 changes: 7 additions & 0 deletions internal/server/turn.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,13 @@ func handleAllocateRequest(req Request, stunMsg *stun.Message) error { //nolint:
// server is free to define this allocation quota any way it wishes,
// but SHOULD define it based on the username used to authenticate
// the request, and not on the client's transport address.
if req.QuotaHandler != nil && !req.QuotaHandler(usernameAttr.String(), realmAttr.String(), req.SrcAddr) {
quotaReachedMsg := buildMsg(stunMsg.TransactionID,
stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse),
&stun.ErrorCodeAttribute{Code: stun.CodeAllocQuotaReached})

return buildAndSend(req.Conn, req.SrcAddr, quotaReachedMsg...)
}

// 8. Also at any point, the server MAY choose to reject the request
// with a 300 (Try Alternate) error if it wishes to redirect the
Expand Down
3 changes: 3 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ const (
type Server struct {
log logging.LeveledLogger
authHandler AuthHandler
quotaHandler QuotaHandler
realm string
channelBindTimeout time.Duration
nonceHash *server.NonceHash
Expand Down Expand Up @@ -59,6 +60,7 @@ func NewServer(config ServerConfig) (*Server, error) { //nolint:gocognit,cyclop
server := &Server{
log: loggerFactory.NewLogger("turn"),
authHandler: config.AuthHandler,
quotaHandler: config.QuotaHandler,
realm: config.Realm,
channelBindTimeout: config.ChannelBindTimeout,
packetConnConfigs: config.PacketConnConfigs,
Expand Down Expand Up @@ -231,6 +233,7 @@ func (s *Server) readLoop(conn net.PacketConn, allocationManager *allocation.Man
Buff: buf[:n],
Log: s.log,
AuthHandler: s.authHandler,
QuotaHandler: s.quotaHandler,
Realm: s.realm,
AllocationManager: allocationManager,
ChannelBindTimeout: s.channelBindTimeout,
Expand Down
9 changes: 9 additions & 0 deletions server_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ func GenerateAuthKey(username, realm, password string) []byte {
// allocation's lifecycle.
type EventHandler = allocation.EventHandler

// QuotaHandler is a callback allows allocations to be rejected when a per-user quota is
// exceeded. If the callback returns true the allocation request is accepted, otherwise it is
// rejected and a 486 (Allocation Quota Reached) error is returned to the user.
type QuotaHandler func(username, realm string, srcAddr net.Addr) (ok bool)

// ServerConfig configures the Pion TURN Server.
type ServerConfig struct {
// PacketConnConfigs and ListenerConfigs are a list of all the turn listeners
Expand All @@ -130,6 +135,10 @@ type ServerConfig struct {
// allowing users to customize Pion TURN with custom behavior
AuthHandler AuthHandler

// QuotaHandler is a callback used to reject new allocations when a
// per-user quota is exceeded.
QuotaHandler QuotaHandler

// EventHandlers is a set of callbacks for tracking allocation lifecycle.
EventHandler EventHandler

Expand Down
52 changes: 52 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1011,6 +1011,58 @@ func TestSTUNOnly(t *testing.T) {
assert.NoError(t, conn.Close())
}

func TestQuotaReached(t *testing.T) {
serverAddr, err := net.ResolveUDPAddr("udp4", "0.0.0.0:3478")
assert.NoError(t, err)

serverConn, err := net.ListenPacket(serverAddr.Network(), serverAddr.String())
assert.NoError(t, err)

defer serverConn.Close() //nolint:errcheck

credMap := map[string][]byte{"user": GenerateAuthKey("user", "pion.ly", "pass")}
server, err := NewServer(ServerConfig{
AuthHandler: func(username, _ string, _ net.Addr) (key []byte, ok bool) {
if pw, ok := credMap[username]; ok {
return pw, true
}
return nil, false //nolint:nlreturn
},
QuotaHandler: func(_, _ string, _ net.Addr) (ok bool) { return false },
Realm: "pion.ly",
PacketConnConfigs: []PacketConnConfig{{
PacketConn: serverConn,
RelayAddressGenerator: &RelayAddressGeneratorStatic{
RelayAddress: net.ParseIP("127.0.0.1"),
Address: "0.0.0.0",
},
}},
LoggerFactory: logging.NewDefaultLoggerFactory(),
})
assert.NoError(t, err)

defer server.Close() //nolint:errcheck

conn, err := net.ListenPacket("udp4", "0.0.0.0:0")
assert.NoError(t, err)

client, err := NewClient(&ClientConfig{
Conn: conn,
STUNServerAddr: "127.0.0.1:3478",
TURNServerAddr: "127.0.0.1:3478",
Username: "user",
Password: "pass",
Realm: "pion.ly",
LoggerFactory: logging.NewDefaultLoggerFactory(),
})
assert.NoError(t, err)
assert.NoError(t, client.Listen())
defer client.Close()

_, err = client.Allocate()
assert.Equal(t, err.Error(), "Allocate error response (error 486: )")
}

func RunBenchmarkServer(b *testing.B, clientNum int) { //nolint:cyclop
b.Helper()

Expand Down
Loading