Skip to content

Commit

Permalink
Add socket permissions config option
Browse files Browse the repository at this point in the history
  • Loading branch information
grongor committed Jul 15, 2020
1 parent f073d50 commit 4968257
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 19 deletions.
3 changes: 2 additions & 1 deletion cmd/snmp-proxy/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ import (

type Configuration struct {
Api struct {
Listen string
Listen string
SocketPermissions os.FileMode // only ever used when Listen is a Unix socket
}
Common struct {
Debug bool
Expand Down
8 changes: 7 additions & 1 deletion cmd/snmp-proxy/snmp-proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,13 @@ func main() {
mibDataProvider := mib.NewDataProvider(displayHints)
requester := snmpproxy.NewGosnmpRequester(mibDataProvider)

apiListener := snmpproxy.NewApiListener(validator, requester, config.Logger, config.Api.Listen)
apiListener := snmpproxy.NewApiListener(
validator,
requester,
config.Logger,
config.Api.Listen,
config.Api.SocketPermissions,
)

err = apiListener.Start()
if err != nil {
Expand Down
26 changes: 18 additions & 8 deletions snmpproxy/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io/ioutil"
"net"
"net/http"
"os"
"strings"
"time"

Expand All @@ -17,10 +18,11 @@ import (
)

type ApiListener struct {
validator *RequestValidator
requester Requester
logger *zap.SugaredLogger
server *http.Server
validator *RequestValidator
requester Requester
logger *zap.SugaredLogger
server *http.Server
socketPermissions os.FileMode
}

func (l *ApiListener) Start() error {
Expand All @@ -31,6 +33,12 @@ func (l *ApiListener) Start() error {

if strings.HasSuffix(l.server.Addr, ".sock") {
ln, err = net.Listen("unix", l.server.Addr)
if err == nil && l.socketPermissions != 0 {
err = os.Chmod(l.server.Addr, l.socketPermissions)
if err != nil {
ln.Close()
}
}
} else {
ln, err = net.Listen("tcp", l.server.Addr)
}
Expand Down Expand Up @@ -126,14 +134,16 @@ func NewApiListener(
requester Requester,
logger *zap.SugaredLogger,
listen string,
socketPermissions os.FileMode,
) *ApiListener {
mux := http.NewServeMux()

listener := &ApiListener{
validator: validator,
requester: requester,
logger: logger,
server: &http.Server{Addr: listen, Handler: mux},
validator: validator,
requester: requester,
logger: logger,
server: &http.Server{Addr: listen, Handler: mux},
socketPermissions: socketPermissions,
}

metricOpts := prometheus.HistogramOpts{
Expand Down
41 changes: 32 additions & 9 deletions snmpproxy/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func TestListenerErrorNotPost(t *testing.T) {

requester := &mockRequester{}

listener := snmpproxy.NewApiListener(newValidator(), requester, zap.NewNop().Sugar(), "")
listener := snmpproxy.NewApiListener(newValidator(), requester, zap.NewNop().Sugar(), "", 0)

request := httptest.NewRequest("GET", "/snmp-proxy", errReader{})

Expand All @@ -79,7 +79,7 @@ func TestListenerErrorReadingRequest(t *testing.T) {

requester := &mockRequester{}

listener := snmpproxy.NewApiListener(newValidator(), requester, zap.NewNop().Sugar(), "")
listener := snmpproxy.NewApiListener(newValidator(), requester, zap.NewNop().Sugar(), "", 0)

request := httptest.NewRequest("POST", "/snmp-proxy", errReader{})

Expand Down Expand Up @@ -117,7 +117,7 @@ func TestListenerErrorUnmarshalingRequest(t *testing.T) {
requester := &mockRequester{}
defer requester.AssertExpectations(t)

listener := snmpproxy.NewApiListener(newValidator(), requester, zap.NewNop().Sugar(), "")
listener := snmpproxy.NewApiListener(newValidator(), requester, zap.NewNop().Sugar(), "", 0)

request := httptest.NewRequest("POST", "/snmp-proxy", strings.NewReader(test.requestBody))

Expand All @@ -138,7 +138,7 @@ func TestListenerErrorRequestValidatorError(t *testing.T) {

requester := &mockRequester{}

listener := snmpproxy.NewApiListener(newValidator(), requester, zap.NewNop().Sugar(), "")
listener := snmpproxy.NewApiListener(newValidator(), requester, zap.NewNop().Sugar(), "", 0)

const requestBody = `
{
Expand Down Expand Up @@ -172,7 +172,7 @@ func TestListenerErrorRequesterError(t *testing.T) {

requester.On("ExecuteRequest", mock.Anything).Once().Return(nil, errors.New("some error"))

listener := snmpproxy.NewApiListener(newValidator(), requester, zap.NewNop().Sugar(), "")
listener := snmpproxy.NewApiListener(newValidator(), requester, zap.NewNop().Sugar(), "", 0)

request := httptest.NewRequest("POST", "/snmp-proxy", strings.NewReader(getRequestBody))

Expand All @@ -195,7 +195,7 @@ func TestListenerNoError(t *testing.T) {

requester.On("ExecuteRequest", mock.Anything).Once().Return([][]interface{}{{".1.2.3", 123}}, nil)

listener := snmpproxy.NewApiListener(newValidator(), requester, zap.NewNop().Sugar(), "")
listener := snmpproxy.NewApiListener(newValidator(), requester, zap.NewNop().Sugar(), "", 0)

request := httptest.NewRequest("POST", "/snmp-proxy", strings.NewReader(getRequestBody))

Expand All @@ -218,7 +218,7 @@ func TestStartAndClose(t *testing.T) {

requester.On("ExecuteRequest", mock.Anything).Once().Return([][]interface{}{{".1.2.3", 123}}, nil)

listener := snmpproxy.NewApiListener(newValidator(), requester, zap.NewNop().Sugar(), "localhost:15721")
listener := snmpproxy.NewApiListener(newValidator(), requester, zap.NewNop().Sugar(), "localhost:15721", 0)
listener.Start()

time.Sleep(time.Millisecond * 10)
Expand Down Expand Up @@ -250,7 +250,7 @@ func TestStartAndCloseOnSocket(t *testing.T) {
require.NoError(f.Close())
require.NoError(os.Remove(f.Name()))

listener := snmpproxy.NewApiListener(newValidator(), requester, zap.NewNop().Sugar(), f.Name())
listener := snmpproxy.NewApiListener(newValidator(), requester, zap.NewNop().Sugar(), f.Name(), 0)
err = listener.Start()
require.NoError(err)

Expand All @@ -276,12 +276,35 @@ func TestStartAndCloseOnSocket(t *testing.T) {
require.Error(err)
}

func TestStartSocketWithCorrectPermissions(t *testing.T) {
require := require.New(t)

prometheus.DefaultRegisterer = prometheus.NewRegistry()

requester := &mockRequester{}

f, err := ioutil.TempFile("", "snmp-proxy-test-*.sock")
require.NoError(err)
require.NoError(f.Close())
require.NoError(os.Remove(f.Name()))

expectedMode := os.FileMode(0124)

listener := snmpproxy.NewApiListener(newValidator(), requester, zap.NewNop().Sugar(), f.Name(), expectedMode)
err = listener.Start()
require.NoError(err)

stat, err := os.Stat(f.Name())
require.NoError(err)
require.Equal(expectedMode, stat.Mode().Perm())
}

func TestStartError(t *testing.T) {
require := require.New(t)

prometheus.DefaultRegisterer = prometheus.NewRegistry()

listener := snmpproxy.NewApiListener(newValidator(), &mockRequester{}, zap.NewNop().Sugar(), "localhost:80")
listener := snmpproxy.NewApiListener(newValidator(), &mockRequester{}, zap.NewNop().Sugar(), "localhost:80", 0)
err := listener.Start()
require.EqualError(err, "listen tcp 127.0.0.1:80: bind: permission denied")
}
Expand Down

0 comments on commit 4968257

Please sign in to comment.