Skip to content
This repository has been archived by the owner on Nov 19, 2022. It is now read-only.

Commit

Permalink
Add optional bandwidth quota limit
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelmota committed Aug 20, 2019
1 parent b7631df commit b248243
Show file tree
Hide file tree
Showing 9 changed files with 371 additions and 42 deletions.
28 changes: 22 additions & 6 deletions cmd/streamhut/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/streamhut/streamhut/pkg/db/sqlite3db"
"github.com/streamhut/streamhut/pkg/httpserver"
"github.com/streamhut/streamhut/pkg/tcpserver"
"github.com/streamhut/streamhut/pkg/util"
"github.com/streamhut/streamhut/pkg/wsserver"
)

Expand All @@ -26,7 +27,8 @@ var ErrDBTypeUnsupported = errors.New("Database type is unsupported")
// ErrChannelRequired ...
var ErrChannelRequired = errors.New("Channel is required")

var yellow = color.New(color.FgYellow).SprintFunc()
var yellow = color.New(color.FgYellow)
var yellowSprintf = color.New(color.FgYellow).SprintFunc()
var green = color.New(color.FgGreen)

func main() {
Expand Down Expand Up @@ -55,6 +57,8 @@ For more info, visit: https://github.com/streamhut/streamhut`,
var shareBaseURL string
var webTarURL string
var webDir string
var humanBandwidthQuotaLimit string
var humanBandwidthQuotaDuration string

serverCmd := &cobra.Command{
Use: "server",
Expand Down Expand Up @@ -83,11 +87,16 @@ For more info, visit: https://github.com/streamhut/streamhut`,
shareBaseURL = fmt.Sprintf("http://127.0.0.1:%d/", httpPort)
}

bandwidthQuotaLimit := util.StorageSizeToUint64(humanBandwidthQuotaLimit)
bandwidthQuotaDuration := util.DurationStringToType(humanBandwidthQuotaDuration)

tcpServer := tcpserver.NewServer(&tcpserver.Config{
WS: ws,
Port: tcpPort,
DB: db,
ShareBaseURL: shareBaseURL,
WS: ws,
Port: tcpPort,
DB: db,
ShareBaseURL: shareBaseURL,
BandwidthQuotaLimit: bandwidthQuotaLimit,
BandwidthQuotaDuration: bandwidthQuotaDuration,
})

go func() {
Expand All @@ -111,6 +120,11 @@ For more info, visit: https://github.com/streamhut/streamhut`,
green.Printf("HTTP/WebSocket port: %d\n", server.Port())
green.Printf("TCP port: %d\n", tcpServer.Port())

if tcpServer.BandwidthQuotaEnabled() {
yellow.Printf("Bandwidth quota limit: %s\n", tcpServer.BandwidthQuotaLimit().String())
yellow.Printf("Bandwidth quota duration: %s\n", tcpServer.BandwidthQuotaDuration().String())
}

return server.Start()
},
}
Expand All @@ -122,6 +136,8 @@ For more info, visit: https://github.com/streamhut/streamhut`,
serverCmd.Flags().StringVarP(&shareBaseURL, "share-base-url", "", os.Getenv("HOST_URL"), "Share base URL. Example: \"https://stream.ht/\"")
serverCmd.Flags().StringVarP(&webTarURL, "web-tar-url", "", httpserver.DefaultWebTarURL, "Web app tarball url to download")
serverCmd.Flags().StringVarP(&webDir, "web-dir", "", httpserver.DefaultWebDir, "Web app directory")
serverCmd.Flags().StringVarP(&humanBandwidthQuotaLimit, "bandwidth-quota-limit", "", os.Getenv("BANDWIDTH_QUOTA_LIMIT"), "bandwidth quota limit (eg. 100kb, 1mb, 1gb, etc)")
serverCmd.Flags().StringVarP(&humanBandwidthQuotaDuration, "bandwidth-quota-duration", "", os.Getenv("BANDWIDTH_QUOTA_DURATION"), "bandwidth quota duration (eg. 45s, 10m, 1h, 1d, 1w, etc)")

var host string
var port uint
Expand Down Expand Up @@ -170,7 +186,7 @@ func handleExit(cb func()) {
signal.Notify(gracefulStop, syscall.SIGINT)
go func() {
sig := <-gracefulStop
fmt.Printf("Caught signal: %+v\n%s", sig, yellow("Shutting down..."))
fmt.Printf("Caught signal: %+v\n%s", sig, yellowSprintf("Shutting down..."))
cb()
os.Exit(0)
}()
Expand Down
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.1 // indirect
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f // indirect
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pelletier/go-toml v1.4.0 // indirect
github.com/pkg/errors v0.8.1 // indirect
github.com/prometheus/client_golang v1.1.0 // indirect
Expand All @@ -45,7 +46,7 @@ require (
golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4 // indirect
golang.org/x/mobile v0.0.0-20190806162312-597adff16ade // indirect
golang.org/x/net v0.0.0-20190724013045-ca1201d0de80 // indirect
golang.org/x/sys v0.0.0-20190804053845-51ab0e2deafa // indirect
golang.org/x/sys v0.0.0-20190812073006-9eafafc0a87e // indirect
golang.org/x/tools v0.0.0-20190809145639-6d4652c779c4 // indirect
google.golang.org/grpc v1.22.1 // indirect
honnef.co/go/tools v0.0.1-2019.2.2 // indirect
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3Rllmb
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U=
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
github.com/pelletier/go-toml v1.4.0/go.mod h1:PN7xzY2wHTK0K9p34ErDQMlFxa51Fk0OUruD3k1mMwo=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
Expand Down Expand Up @@ -259,6 +261,8 @@ golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20190801041406-cbf593c0f2f3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190804053845-51ab0e2deafa h1:KIDDMLT1O0Nr7TSxp8xM5tJcdn8tgyAONntO829og1M=
golang.org/x/sys v0.0.0-20190804053845-51ab0e2deafa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190812073006-9eafafc0a87e h1:TsjK5I7fXk8f2FQrgu6NS7i5Qih3knl2FL1htyguLRE=
golang.org/x/sys v0.0.0-20190812073006-9eafafc0a87e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
Expand Down
13 changes: 8 additions & 5 deletions pkg/db/sqlite3db/sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ func NewDB(config *Config) *DB {
log.Fatal(err)
}

svc := &DB{db}
svc := &DB{
db: db,
}

for _, line := range svc.schema() {
if len(line) < 3 {
continue
Expand Down Expand Up @@ -153,8 +156,8 @@ var insertMu sync.Mutex

// InsertStreamLog ...
func (d *DB) InsertStreamLog(vLog *types.StreamLog) {
//insertMu.Lock()
//defer insertMu.Unlock()
insertMu.Lock()
defer insertMu.Unlock()
tx, err := d.db.Begin()
if err != nil {
log.Fatal(3, err)
Expand All @@ -181,8 +184,8 @@ func (d *DB) InsertStreamLog(vLog *types.StreamLog) {

// InsertStreamMessage ...
func (d *DB) InsertStreamMessage(msg *types.StreamMessage) {
//insertMu.Lock()
//defer insertMu.Unlock()
insertMu.Lock()
defer insertMu.Unlock()
tx, err := d.db.Begin()
if err != nil {
log.Fatal(3, err)
Expand Down
1 change: 0 additions & 1 deletion pkg/httpserver/httpserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ func (s *Server) channelHandler(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
id := vars["channelId"]
_ = id
fmt.Println("W")
if strings.Contains(r.URL.String(), ".websocket") {
w.Write([]byte(""))
return
Expand Down
146 changes: 122 additions & 24 deletions pkg/tcpserver/tcpserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,32 +12,55 @@ import (
"strings"
"time"

"github.com/patrickmn/go-cache"
gocache "github.com/patrickmn/go-cache"
uuid "github.com/satori/go.uuid"
"github.com/streamhut/streamhut/common/byteutil"
"github.com/streamhut/streamhut/common/stringutil"
"github.com/streamhut/streamhut/common/util"
common "github.com/streamhut/streamhut/common/util"
"github.com/streamhut/streamhut/pkg/db"
"github.com/streamhut/streamhut/pkg/util"
"github.com/streamhut/streamhut/pkg/wsserver"
"github.com/streamhut/streamhut/types"
)

// ErrInvalidQuotaLimit ...
var ErrInvalidQuotaLimit = fmt.Sprintf("Invalid quota size")

// ErrInvalidQuotaDuration ...
var ErrInvalidQuotaDuration = fmt.Sprintf("Invalid quota duration")

// DefaultBandwidthQuotaLimit ...
var DefaultBandwidthQuotaLimit = 1000 * 1000 * 10 // 10mb

// DefaultBandwidthQuotaDuration ...
var DefaultBandwidthQuotaDuration = 1 * time.Minute

// BandwidthQuotaLimit type
type BandwidthQuotaLimit uint64

// Server ...
type Server struct {
host string
listener net.Listener
port uint
ws *wsserver.WS
db db.DB
shareBaseURL string
host string
listener net.Listener
port uint
ws *wsserver.WS
db db.DB
shareBaseURL string
cache *gocache.Cache
bandwidthQuotaLimit BandwidthQuotaLimit
bandwidthQuotaDuration time.Duration
}

// Config ...
type Config struct {
Host string
Port uint
WS *wsserver.WS
DB db.DB
ShareBaseURL string
Host string
Port uint
WS *wsserver.WS
DB db.DB
ShareBaseURL string
BandwidthQuotaLimit uint64
BandwidthQuotaDuration time.Duration
}

// NewServer ...
Expand All @@ -49,12 +72,23 @@ func NewServer(config *Config) *Server {
}
}

if config.BandwidthQuotaLimit > 0 && config.BandwidthQuotaDuration == 0 {
log.Fatal(ErrInvalidQuotaDuration)
}

if config.BandwidthQuotaDuration > 0 && config.BandwidthQuotaLimit == 0 {
log.Fatal(ErrInvalidQuotaLimit)
}

return &Server{
host: config.Host,
port: config.Port,
ws: config.WS,
db: config.DB,
shareBaseURL: shareBaseURL,
host: config.Host,
port: config.Port,
ws: config.WS,
db: config.DB,
shareBaseURL: shareBaseURL,
cache: cache.New(gocache.DefaultExpiration, gocache.DefaultExpiration),
bandwidthQuotaLimit: BandwidthQuotaLimit(config.BandwidthQuotaLimit),
bandwidthQuotaDuration: config.BandwidthQuotaDuration,
}
}

Expand Down Expand Up @@ -89,7 +123,7 @@ func (s *Server) randChannel() string {
for {
channel := stringutil.RandStringRunes(6)
_, ok := s.ws.Socks[channel]
if !ok && util.ValidChannelName(channel) {
if !ok && common.ValidChannelName(channel) {
return channel
}
}
Expand All @@ -103,10 +137,16 @@ func (s *Server) channelTaken(channel string) bool {
func (s *Server) handleRequest(client *wsserver.Conn) {
reader := bufio.NewReader(client.Netconn)
index := 0
expired := false
channelReadExpired := false

var ip string
if addr, ok := client.Netconn.RemoteAddr().(*net.TCPAddr); ok {
ip = addr.IP.String()
}

// NOTE: a timeout to allow reading of channel first
time.AfterFunc(5*time.Millisecond, func() {
expired = true
channelReadExpired = true
if client.Channel == "" {
client.Channel = s.randChannel()
}
Expand All @@ -117,9 +157,42 @@ func (s *Server) handleRequest(client *wsserver.Conn) {

for {
line := make([]byte, 1024)
_, err := reader.Read(line)
n, err := reader.Read(line)
switch err {
case nil:
if s.bandwidthQuotaLimit > 0 {
_, exp, found := s.cache.GetWithExpiration(ip)
if found {
expiresIn := time.Duration(exp.Unix()-time.Now().Unix()) * time.Second

if expiresIn.Seconds() == 0 {
client.ResetBandwidthQuota()
s.cache.Delete(ip)
} else {
msg := fmt.Sprintf("streamhut: bandwidth quota reached. Try again in %vs", expiresIn.Seconds())
log.Printf("quota reached for ip %v; can retry in %vs\n", ip, expiresIn.Seconds())
client.Netconn.Write([]byte(msg))

// NOTE: timeout must be less than channelReadExpired
time.Sleep(4 * time.Millisecond)
client.Netconn.Close()
return
}
}

client.TollBandwidth(uint64(n))
if client.BandwidthQuotaUsed() > s.bandwidthQuotaLimit.Uint64() {
expiresIn := time.Duration(int(s.bandwidthQuotaDuration.Seconds())-time.Now().Second()) * time.Second
s.cache.Set(ip, time.Now(), expiresIn)
msg := fmt.Sprintf("\nstreamhut: bandwidth quota reached. Try again in %vs", expiresIn.Seconds())
log.Printf("quota reached for ip %v; can retry in %vs\n", ip, expiresIn.Seconds())
client.Netconn.Write([]byte(msg))
time.Sleep(1 * time.Second)
client.Netconn.Close()
return
}
}

// echo back to client
client.Netconn.Write(line)
case io.EOF:
Expand All @@ -128,13 +201,13 @@ func (s *Server) handleRequest(client *wsserver.Conn) {
os.Exit(0)
}

if index == 0 && !expired {
if index == 0 && !channelReadExpired {
if len(line) > 0 && line[0] == '#' {
re := regexp.MustCompile(`#([a-zA-Z0-9]+)\n?\r?`)
matches := re.FindAllStringSubmatch(string(line), -1)
if len(matches) > 0 && len(matches[0]) > 1 {
client.Channel = util.NormalizeChannelName(matches[0][1])
if !util.ValidChannelName(client.Channel) {
client.Channel = common.NormalizeChannelName(matches[0][1])
if !common.ValidChannelName(client.Channel) {
msg := fmt.Sprintf("streamhut: channel name %q is not available", client.Channel)
client.Netconn.Write([]byte(msg))
if err := client.Netconn.Close(); err != nil {
Expand Down Expand Up @@ -213,3 +286,28 @@ func (s *Server) shareURL(client *wsserver.Conn) string {
func (s *Server) Port() uint {
return s.port
}

// BandwidthQuotaLimit ...
func (s *Server) BandwidthQuotaLimit() BandwidthQuotaLimit {
return s.bandwidthQuotaLimit
}

// BandwidthQuotaDuration ...
func (s *Server) BandwidthQuotaDuration() time.Duration {
return s.bandwidthQuotaDuration
}

// BandwidthQuotaEnabled ...
func (s *Server) BandwidthQuotaEnabled() bool {
return s.bandwidthQuotaLimit > 0
}

// Uint64 ...
func (b BandwidthQuotaLimit) Uint64() uint64 {
return uint64(b)
}

// String ...
func (b BandwidthQuotaLimit) String() string {
return util.Uint64ToStorageSizeString(uint64(b))
}

0 comments on commit b248243

Please sign in to comment.