From 5c5068d3adc478f8c4437af4c69339f747b55826 Mon Sep 17 00:00:00 2001 From: Denis Issoupov Date: Mon, 27 Nov 2023 10:17:22 +0000 Subject: [PATCH] Ported from porto and xpki --- .VERSION | 1 + .github/dependabot.yml | 16 + .github/workflows/unittest.yml | 63 +++ .gitignore | 12 + .golangci.yml | 15 + .project/config.yml | 3 + .project/config_var.sh | 4 + .project/gomod-project.mk | 263 +++++++++++ .project/yaml.sh | 72 ++++ LICENSE | 20 + Makefile | 35 ++ configloader/README.md | 56 +++ configloader/configloader.go | 283 ++++++++++++ configloader/configloader_test.go | 213 +++++++++ configloader/expand.go | 125 ++++++ configloader/load.go | 114 +++++ configloader/load_test.go | 160 +++++++ .../testdata/override/custom_list.yaml | 9 + .../testdata/test_config-override.yaml | 8 + configloader/testdata/test_config.json | 6 + configloader/testdata/test_config.yaml | 44 ++ .../testdata/test_config.yaml.hostmap | 4 + fileutil/doc.go | 2 + fileutil/folders.go | 87 ++++ fileutil/folders_test.go | 78 ++++ fileutil/reloader/reloader.go | 130 ++++++ fileutil/reloader/reloader_test.go | 93 ++++ fileutil/resolve/resolve.go | 48 +++ fileutil/resolve/resolve_test.go | 82 ++++ fileutil/testdata/test_config.json | 6 + fileutil/testdata/test_config.yaml | 5 + flake/LICENSE | 21 + flake/README.md | 69 +++ flake/flake.go | 272 ++++++++++++ flake/flake_test.go | 302 +++++++++++++ go.mod | 32 ++ go.sum | 80 ++++ guid/guid.go | 17 + guid/guid_test.go | 15 + math/compare.go | 54 +++ math/compare_test.go | 112 +++++ netutil/freeport.go | 45 ++ netutil/freeport_test.go | 15 + netutil/localip.go | 83 ++++ netutil/localip_test.go | 53 +++ netutil/net.go | 54 +++ netutil/net_test.go | 51 +++ netutil/nodeinfo.go | 62 +++ netutil/nodeinfo_test.go | 26 ++ netutil/urls.go | 41 ++ netutil/urls_test.go | 121 ++++++ slices/slices.go | 254 +++++++++++ slices/slices_test.go | 408 ++++++++++++++++++ slices/uint64s.go | 19 + slices/uint64s_test.go | 20 + urlutil/urlutil.go | 52 +++ urlutil/urlutil_test.go | 40 ++ 57 files changed, 4375 insertions(+) create mode 100644 .VERSION create mode 100644 .github/dependabot.yml create mode 100644 .github/workflows/unittest.yml create mode 100644 .gitignore create mode 100644 .golangci.yml create mode 100644 .project/config.yml create mode 100755 .project/config_var.sh create mode 100644 .project/gomod-project.mk create mode 100755 .project/yaml.sh create mode 100644 LICENSE create mode 100644 Makefile create mode 100644 configloader/README.md create mode 100644 configloader/configloader.go create mode 100644 configloader/configloader_test.go create mode 100644 configloader/expand.go create mode 100644 configloader/load.go create mode 100644 configloader/load_test.go create mode 100644 configloader/testdata/override/custom_list.yaml create mode 100644 configloader/testdata/test_config-override.yaml create mode 100644 configloader/testdata/test_config.json create mode 100644 configloader/testdata/test_config.yaml create mode 100644 configloader/testdata/test_config.yaml.hostmap create mode 100644 fileutil/doc.go create mode 100644 fileutil/folders.go create mode 100644 fileutil/folders_test.go create mode 100644 fileutil/reloader/reloader.go create mode 100644 fileutil/reloader/reloader_test.go create mode 100644 fileutil/resolve/resolve.go create mode 100644 fileutil/resolve/resolve_test.go create mode 100644 fileutil/testdata/test_config.json create mode 100644 fileutil/testdata/test_config.yaml create mode 100644 flake/LICENSE create mode 100644 flake/README.md create mode 100644 flake/flake.go create mode 100644 flake/flake_test.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 guid/guid.go create mode 100644 guid/guid_test.go create mode 100644 math/compare.go create mode 100644 math/compare_test.go create mode 100644 netutil/freeport.go create mode 100644 netutil/freeport_test.go create mode 100644 netutil/localip.go create mode 100644 netutil/localip_test.go create mode 100644 netutil/net.go create mode 100644 netutil/net_test.go create mode 100644 netutil/nodeinfo.go create mode 100644 netutil/nodeinfo_test.go create mode 100644 netutil/urls.go create mode 100644 netutil/urls_test.go create mode 100644 slices/slices.go create mode 100644 slices/slices_test.go create mode 100644 slices/uint64s.go create mode 100644 slices/uint64s_test.go create mode 100644 urlutil/urlutil.go create mode 100644 urlutil/urlutil_test.go diff --git a/.VERSION b/.VERSION new file mode 100644 index 0000000..aa33868 --- /dev/null +++ b/.VERSION @@ -0,0 +1 @@ +v0.1 \ No newline at end of file diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..cdf3128 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,16 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for all configuration options: +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates + +version: 2 +updates: + - package-ecosystem: "gomod" + directory: "/" # Location of package manifests + schedule: + interval: "weekly" + # Enable version updates for Actions + - package-ecosystem: "github-actions" + directory: "/" # Location of package manifests + schedule: + interval: "weekly" diff --git a/.github/workflows/unittest.yml b/.github/workflows/unittest.yml new file mode 100644 index 0000000..95eb0ec --- /dev/null +++ b/.github/workflows/unittest.yml @@ -0,0 +1,63 @@ +name: Build + +on: + push: + branches: + - main + tags: + - "v*" + pull_request: + +jobs: + + UnitTest: + runs-on: ubuntu-latest + env: + ITEST_IMAGE_TAG: rc-${{ github.event.number }} + COMMIT_SHA: ${{ github.event.pull_request.head.sha }} + RUN_ID: ${{ github.run_id }} + PULL_NUMBER: ${{ github.event.pull_request.number }} + MIN_TESTCOV: 80 + + steps: + - name: Create code coverage status for the current commit + if: github.event_name == 'pull_request' + run: | + curl "https://${GIT_USER}:${GIT_TOKEN}@api.github.com/repos/${GITHUB_REPOSITORY}/statuses/${COMMIT_SHA}" -d "{\"state\": \"pending\",\"target_url\": \"https://github.com/${GITHUB_REPOSITORY}/pull/${PULL_NUMBER}/checks?check_run_id=${RUN_ID}\",\"description\": \"in progress — This check has started... \",\"context\": \"code cov\"}" + env: + GIT_TOKEN: ${{ secrets.GITHUB_TOKEN }} + GIT_USER: ${{ github.actor }} + + - name: Checkout code + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v4 + with: + go-version-file: go.mod + + - name: Prepare + run: make vars tools generate + + - name: UnitTest + run: make build covtest + + - name: Generate covarege Status + if: github.event_name == 'pull_request' + run: | + set -x + PROJECT_NAME=${PROJECT_NAME} + + total=`go tool cover -func=coverage.out | grep total | grep -Eo '[0-9]+\.[0-9]+'` + echo "total cov: $total" + (( $(echo "$total > ${MIN_TESTCOV}" | bc -l) )) && STATE=success || STATE=failure + curl "https://${GIT_USER}:${GIT_TOKEN}@api.github.com/repos/${GITHUB_REPOSITORY}/statuses/${COMMIT_SHA}" -d "{\"state\": \"${STATE}\",\"target_url\": \"https://github.com/${GITHUB_REPOSITORY}/pull/${PULL_NUMBER}/checks?check_run_id=${RUN_ID}\",\"description\": \"${total}%\",\"context\": \"code cov\"}" + env: + GIT_TOKEN: ${{ secrets.GITHUB_TOKEN }} + GIT_USER: ${{ github.actor }} + + - name: coveralls + #if: github.event_name == 'pull_request' + env: + COVERALLS_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: make coveralls-github diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..30ffb4b --- /dev/null +++ b/.gitignore @@ -0,0 +1,12 @@ +/.coverage +coverage.out + +bin/ +.idea/ +.tmp/ +.docker/ +.DS_Store + +debug +debug.test +.vscode/launch.json \ No newline at end of file diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..6eaaaa1 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,15 @@ +run: + skip-files: + - ".*_test\\.go$" + +linters: + enable: + - revive + +issues: + exclude: + - S1040 # type assertion to the same type: rawBody already has type io.ReadSeeker (gosimple) + - S1030 # should use w.String() instead of string(w.Bytes()) (gosimple) + - SA1019 # tc.gtc.OverrideServerName is deprecated: use grpc.WithAuthority instead. Will be supported throughout 1.x. (staticcheck) + - SA1019 # "io/ioutil" has been deprecated since Go 1.16: As of Go 1.16, the same functionality is now provided by package io or package os, and those implementations should be preferred in new code. See the specific function documentation for details. (staticcheck) + - SA1019 # "io/ioutil" has been deprecated since Go 1.16: As of Go 1.16, the same functionality is now provided by package io or package os, and those implementations should be preferred in new code. See the specific function documentation for details. (staticcheck) diff --git a/.project/config.yml b/.project/config.yml new file mode 100644 index 0000000..0cf4c04 --- /dev/null +++ b/.project/config.yml @@ -0,0 +1,3 @@ +project: + org: github.com/effective-security + name: x \ No newline at end of file diff --git a/.project/config_var.sh b/.project/config_var.sh new file mode 100755 index 0000000..60ed42e --- /dev/null +++ b/.project/config_var.sh @@ -0,0 +1,4 @@ +#!/bin/bash +source .project/yaml.sh +create_variables .project/config.yml +eval $(printf "echo $%s" "$1") diff --git a/.project/gomod-project.mk b/.project/gomod-project.mk new file mode 100644 index 0000000..69a92b5 --- /dev/null +++ b/.project/gomod-project.mk @@ -0,0 +1,263 @@ +# gomod-project.mk: this contains commonly used helpers for makefiles. +SHELL=/bin/bash + +# Used envaronment variables: +# +# PROJ_DIR +# project's absolute root directory +# +# PROJ_BIN +# project's bin folder +# +# ORG_NAME +# Git organization name, for example: github.com/go-phorce +# +# PROJ_NAME +# Git project name, for example: go-makefile +# +# REPO_NAME +# Git repo name consists of the org and project: github.com/go-phorce/go-makefile +# +# PROJ_GOFILES +# List of all .go files in the project, exluding vendor and tools +# +# Test flags: +# +# TEST_RACEFLAG +# Use -race when running go test +# +# TEST_GORACEOPTIONS +# Race options +# +# Functions: +# +# show_dep_updates {folder} +# Show dependencies updates in {folder} +# +# httpsclone {org} {repo} {destination_dir} +# +# go_test_cover +# +# go_test_cover_junit + + +PROJ_ROOT := $(shell pwd) + +## Project variables +ORG_NAME := $(shell .project/config_var.sh project_org) +PROJ_NAME := $(shell .project/config_var.sh project_name) +REPO_NAME := ${ORG_NAME}/${PROJ_NAME} +PROJ_PACKAGE := ${REPO_NAME} + +## Common variables +HOSTNAME := $(shell echo $$HOSTNAME) +UNAME := $(shell uname) +GITHUB_HOST := github.com +GOLANG_HOST := golang.org +# GIT_DIRTY is empty if the project is not modified, otherwise it's current host name +GIT_DIRTY := $(shell git describe --dirty --always --tags --long | grep -q -e '-dirty' && echo -$$HOSTNAME) +GIT_HASH := $(shell git rev-parse --short HEAD) +# number of commits +COMMITS_COUNT := $(shell git rev-list --count ${GIT_HASH}) +# +PROD_VERSION := $(shell cat .VERSION) +GIT_VERSION := $(shell printf %s.%d%s ${PROD_VERSION} ${COMMITS_COUNT} ${GIT_DIRTY}) +COVPATH=.coverage + +# List of all .go files in the project, excluding vendor and .tools +#GOFILES_NOVENDOR = $(shell find . -type f -name '*.go' -not -path "./vendor/*" -not -path "./.tools/*" -not -path "./.gopath/*") + +export PROJ_DIR=$(PROJ_ROOT) +export PROJ_BIN=$(PROJ_ROOT)/bin +export GOBIN=$(PROJ_ROOT)/bin +export PATH := ${PATH}:${PROJ_BIN} + +# List of all .go files in the project, exluding vendor and tools +PROJ_GOFILES = $(shell find . -type f -name '*.go' -not -path "./vendor/*" -not -path "./.gopath/*" -not -path "./.tools/*") + +COVERAGE_EXCLUSIONS="/rt\.go|/bindata\.go|_test\.go|_mock\.go|main\.go" + +# flags +INTEGRATION_TAG="integration" +TEST_RACEFLAG ?= +TEST_GORACEOPTIONS ?= + +# flag to enable golang race detector. Usage: make $(test_target) RACE=true. For example, make test RACE=true +RACE ?= +ifeq ($(RACE),true) + TEST_GORACEOPTIONS = "log_path=${PROJ_DIR}/${COVPATH}/race/report" + TEST_RACEFLAG = -race +endif + +## Common targets/functions for golang projects +# They assume that +# a) GOPATH has been set with an export GOPATH somewhere +# b) the Makefile variable PROJ_PACKAGE has been set to the name of the go pacakge to operate on +# + +# go_test_cover will run go test on a package tree, with code coverage turned on, it writes coverage results +# to ./${COVPATH} +# the 5 params are +# 1) the working dir to run the tests in +# 2) the flags to run the tests with +# 3) flag to enable race detector +# 4) options to race detector such as log_path for storing the results of the race detector +# 5) the name of the PROJ_DIR package to test +# 6) the list of source exclusions to apply to the generated code coverage result calculation +# +# it assumes you've built the cov-report tool into ${TOOLS_BIN} +# +define go_test_cover + echo "Testing in $(1)" + rm -rf ${COVPATH} + mkdir -p ${COVPATH}/race + exitCode=0 \ + && cd $(1) && go list $(5)/... | ( while read -r pkg; do \ + result=`GORACE=$(4) go test -p 40 $(2) $$pkg -coverpkg=$(5)/... -covermode=count $(3) \ + -coverprofile=${COVPATH}/cc_$$(echo $$pkg | tr "/" "_").out \ + 2>&1 | grep --invert-match "warning: no packages"` \ + && test_result=`echo "$$result" | tail -1` \ + && echo "$$test_result" \ + && if echo $$test_result | grep ^FAIL ; then \ + exitCode=1 && echo "Test for $$pkg failed. Result: $$result, exit code: $$exitCode" \ + ; fi \ + ; done \ + && echo "Completed with status code $$exitCode" \ + && if [ $$exitCode -ne "0" ] ; then echo "Test failed, exit code: $$exitCode" && exit $$exitCode ; fi ) + cov-report -ex $(6) -cc ${COVPATH}/combined.out ${COVPATH}/cc*.out + cp ${COVPATH}/combined.out ${PROJ_DIR}/coverage.out +endef + +# same as go_test_cover except it also generates results in the junit format +# assuming ${TOOLS_BIN} contains go-junit-report & cov-report +define go_test_cover_junit + echo "Testing in $(1)" + rm -rf ${COVPATH} + mkdir -p ${COVPATH}/race + set -o pipefail; failure=0; while read -r pkg; do \ + cd $(1) && GORACE=$(4) go test $(2) -v $$pkg -coverpkg=$(5)/... -covermode=count $(3) \ + -coverprofile=${COVPATH}/cc_$$(echo $$pkg | tr "/" "_").out \ + >> ${COVPATH}/citest_$$(echo $(5) | tr "/" "_").log \ + || failure=1; \ + done <<< "$$(cd $(1) && go list $(5)/...)" && \ + cat ${COVPATH}/citest_$$(echo $(5) | tr "/" "_").log | go-junit-report >> ${COVPATH}/citest_$$(echo $(5) | tr "/" "_").xml && \ + exit $$failure +endef + +# list the make targets +# http://stackoverflow.com/questions/4219255/how-do-you-get-the-list-of-targets-in-a-makefile/15058900#15058900 +no_targets__: +list: + sh -c "$(MAKE) -p no_targets__ | awk -F':' '/^[a-zA-Z0-9][^\$$#\/\\t=]*:([^=]|$$)/ {split(\$$1,A,/ /);for(i in A)print A[i]}' | grep -v '__\$$' | sort" + +# +# print environment variables +# +vars: + echo "PATH=$(PATH)" + echo "PROJ_DIR=$(PROJ_DIR)" + echo "PROJ_REPO_TARGET=$(PROJ_REPO_TARGET)" + echo "GOROOT=$(GOROOT)" + echo "GOBIN=$(GOBIN)" + echo "GOPATH=$(GOPATH)" + echo "PROJ_PACKAGE=$(PROJ_PACKAGE)" + echo "TOOLS_PATH=$(TOOLS_PATH)" + echo "GIT_VERSION=$(GIT_VERSION)" + go version + +# +# list packages +# +lspkg: + go list ./... + +# +# print out GO environment +# +env: + go env + +# +# GO test with bench +# +bench: + go test ${TEST_RACEFLAG} -bench . ${PROJ_PACKAGE}/... + +generate: + go generate ./... + +fmt: + echo "Running Fmt" + go fmt ./... + +vet: + echo "Running vet" + go vet ${BUILD_FLAGS} ${PROJ_PACKAGE}/... + +vulns: + echo "Running vulns" + govulncheck ${PROJ_PACKAGE}/... + +lint: fmt vet vulns + echo "Running lint" + golangci-lint run --timeout 10m0s + +test: + echo "Running test ${TEST_FLAGS} ${TEST_RACEFLAG}" + go test ${TEST_FLAGS} ${TEST_RACEFLAG} ${PROJ_PACKAGE}/... + +testshort: + echo "Running testshort" + go test ${BUILD_FLAGS} ${TEST_RACEFLAG} ./... --test.short + +# you can run a subset of tests with make sometests testname= +sometests: + go test ${BUILD_FLAGS} ${TEST_RACEFLAG} ./... --test.short -run $(testname) + +covtest: fmt vet + echo "Running covtest" + $(call go_test_cover,${PROJ_DIR},${BUILD_FLAGS},${TEST_RACEFLAG},${TEST_GORACEOPTIONS},.,${COVERAGE_EXCLUSIONS}) + +# Runs integration tests as well +testint: fmt vet lint + echo "Running testint" + go test ${TEST_RACEFLAG} -tags=${INTEGRATION_TAG} ${PROJ_PACKAGE}/... + +# shows the coverages results assuming they were already generated by a call to go_test_cover +coverage: + echo "Running coverage" + go tool cover -html="${COVPATH}/combined.out" + +# generates a HTML based code coverage report, and writes it to a file in the results directory +# assumes you've run go_test_cover (or go_test_cover_junit) +cicoverage: + echo "Running cicoverage" + mkdir -p ${COVPATH}/cover + go tool cover -html="${COVPATH}/combined.out" -o "${COVPATH}/cover/coverage.html" + +# as Jenkins runs citestint as well which will run all unit tests + integration tests with code coverage +# this unitest step can skip coverage reporting which speeds it up massively +citest: vet lint + echo "Running citest" + $(call go_test_cover_junit,${PROJ_DIR},${BUILD_FLAGS},${TEST_RACEFLAG},${TEST_GORACEOPTIONS},.,${COVERAGE_EXCLUSIONS}) + cov-report -fmt xml -o ${COVPATH}/coverage.xml -ex ${COVERAGE_EXCLUSIONS} -cc ${COVPATH}/combined.out ${COVPATH}/cc*.out + cov-report -fmt ds -o ${COVPATH}/summary.xml -ex ${COVERAGE_EXCLUSIONS} ${COVPATH}/cc*.out + +coveralls: + echo "Running coveralls" + goveralls -v -coverprofile=coverage.out -service=travis-ci -package ./... + +help: + echo "make vars - print make variables" + echo "make env - pring GO environment" + echo "make lspkg - list GO packeges in the current project" + echo "make generate - generate GO files" + echo "make bench - GO test with bench" + echo "make fmt - run go fmt on project files" + echo "make vet - run go vet on project files" + echo "make lint - run go lint on project files" + echo "make test - run test" + echo "make testshort - run test with -short flag" + echo "make covtest - run test with coverage report" + echo "make coverage - open coverage report" + echo "make coveralls - publish coverage to coveralls" diff --git a/.project/yaml.sh b/.project/yaml.sh new file mode 100755 index 0000000..04eac1d --- /dev/null +++ b/.project/yaml.sh @@ -0,0 +1,72 @@ +#!/bin/sh +#!/bin/bash +# +# MIT License +# +# Copyright (c) 2017 Jonathan Peres +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +# https://github.com/jasperes/bash-yaml +# +function parse_yaml() { + local yaml_file=$1 + local prefix=$2 + local s + local w + local fs + + s='[[:space:]]*' + w='[a-zA-Z0-9_.-]*' + fs="$(echo @|tr @ '\034')" + + ( + sed -ne '/^--/s|--||g; s|\"|\\\"|g; s/\s*$//g;' \ + -e "/#.*[\"\']/!s| #.*||g; /^#/s|#.*||g;" \ + -e "s|^\($s\)\($w\)$s:$s\"\(.*\)\"$s\$|\1$fs\2$fs\3|p" \ + -e "s|^\($s\)\($w\)$s[:-]$s\(.*\)$s\$|\1$fs\2$fs\3|p" | + + awk -F"$fs" '{ + indent = length($1)/2; + if (length($2) == 0) { conj[indent]="+";} else {conj[indent]="";} + vname[indent] = $2; + for (i in vname) {if (i > indent) {delete vname[i]}} + if (length($3) > 0) { + vn=""; for (i=0; i 0 { + overrideCfg, _, err := f.ResolveConfigFile(f.overrideCfg) + if err != nil { + return err + } + logger.KV(xlog.TRACE, "override", overrideCfg) + ops = append(ops, yamlcfg.File(overrideCfg)) + } + + provider, err := yamlcfg.NewYAML(ops...) + if err != nil { + return errors.Wrap(err, "failed to load configuration") + } + + err = provider.Get(yamlcfg.Root).Populate(config) + if err != nil { + return errors.Wrap(err, "failed to parse configuration") + } + + return nil +} + +func (f *Factory) getVariableValues(environment string) map[string]string { + ret := map[string]string{ + "${HOSTNAME}": f.nodeInfo.HostName(), + "${NODENAME}": f.nodeInfo.NodeName(), + "${LOCALIP}": f.nodeInfo.LocalIP(), + "${USER}": f.userName(), + "${NORMALIZED_USER}": f.normalizedUserName(), + "${ENVIRONMENT}": environment, + "${ENVIRONMENT_UPPERCASE}": strings.ToUpper(environment), + } + + if len(f.envPrefix) > 0 { + for _, x := range os.Environ() { + kvp := strings.SplitN(x, "=", 2) + + env, val := kvp[0], kvp[1] + if strings.HasPrefix(env, f.envPrefix) { + formattedKey := fmt.Sprintf("${%v}", env) + if _, ok := ret[formattedKey]; !ok { + logger.KV(xlog.DEBUG, "set", formattedKey) + ret[formattedKey] = val + } + } + } + } + + return ret +} + +// ResolveConfigFile returns absolute path for the config file +func (f *Factory) ResolveConfigFile(configFile string) (absConfigFile, baseDir string, err error) { + if configFile == "" { + panic("config file not provided!") + //configFile = ConfigFileName + } + + if filepath.IsAbs(configFile) { + // for absolute, use the folder containing the config file + baseDir = filepath.Dir(configFile) + absConfigFile = configFile + return + } + + for _, absDir := range f.searchDirs { + absConfigFile, err = resolve.File(configFile, absDir) + if err == nil && absConfigFile != "" { + baseDir = absDir + logger.KV(xlog.DEBUG, "resolved", absConfigFile) + return + } + } + + err = errors.Errorf("file %q not found in [%s]", configFile, strings.Join(f.searchDirs, ",")) + return +} + +func (f *Factory) userName() string { + if f.user == nil { + userName := userName() + f.user = &userName + } + return *f.user +} + +func (f *Factory) normalizedUserName() string { + username := f.userName() + return strings.Replace(username, ".", "", -1) +} + +func userName() string { + u, err := user.Current() + if err != nil { + logger.Panicf("unable to determine current user: %v", err) + } + return u.Username +} diff --git a/configloader/configloader_test.go b/configloader/configloader_test.go new file mode 100644 index 0000000..f7fe0c8 --- /dev/null +++ b/configloader/configloader_test.go @@ -0,0 +1,213 @@ +package configloader + +import ( + "fmt" + "os" + "strings" + "testing" + + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" +) + +func TestNewFactory(t *testing.T) { + f, err := NewFactory(nil, nil, "PORTO_") + assert.NoError(t, err) + assert.NotNil(t, f) + + var c struct{} + + _, err = f.Load("notfound-config.yaml", &c) + require.Error(t, err) + assert.Equal(t, `file "notfound-config.yaml" not found in []`, err.Error()) +} + +func TestLoadYAML(t *testing.T) { + cfgFile, err := GetAbsFilename("testdata/test_config.yaml", ".") + require.NoError(t, err, "unable to determine config file") + + f, err := NewFactory(nil, []string{"testdata/override"}, "PORTO_") + require.NoError(t, err) + + c := new(configuration) + _, err = f.Load(cfgFile, c) + assert.EqualError(t, err, "environment variable not set: NODENAME") + + t.Setenv("NODENAME", "cluster1") + c = new(configuration) + _, err = f.Load(cfgFile, c) + assert.EqualError(t, err, "secret loader not provided") + assert.Equal(t, "cluster1", c.ClusterName) +} + +func TestLoadYAMLOverrideByHostname(t *testing.T) { + cfgFile, err := GetAbsFilename("testdata/test_config.yaml", ".") + require.NoError(t, err, "unable to determine config file") + + f, err := NewFactory(nil, []string{"testdata/override"}, "TEST_") + require.NoError(t, err) + + sp := &mockSecret{ + secrets: map[string]string{ + "secret1": "api-key1", + "secret2": "api-key2", + }, + } + f.WithSecretProvider(sp) + + t.Setenv("TEST_HOSTNAME", "UNIT_TEST") + + c := new(configuration) + _, err = f.Load(cfgFile, c) + assert.EqualError(t, err, "environment variable not set: NODENAME") + + c = new(configuration) + t.Setenv("NODENAME", "UNIT_TEST") + _, err = f.Load(cfgFile, c) + require.NoError(t, err, "failed to load config: %v", cfgFile) + assert.Equal(t, "UNIT_TEST", c.Environment) // lower cased + assert.Equal(t, "local", c.Region) + assert.Equal(t, "porto-pod", c.ServiceName) + assert.NotEmpty(t, c.ClusterName) + + assert.Equal(t, "api-key2", c.ClientAPIKey) + + assert.Equal(t, fmt.Sprintf("/tmp/porto-%s/logs", c.Environment), c.Logs.Directory) + assert.Equal(t, 3, c.Logs.MaxAgeDays) + assert.Equal(t, 10, c.Logs.MaxSizeMb) + + assert.Equal(t, fmt.Sprintf("/tmp/porto-%s/audit", c.Environment), c.Audit.Directory) + assert.Equal(t, 99, c.Audit.MaxAgeDays) + assert.Equal(t, 99, c.Audit.MaxSizeMb) + + assert.Equal(t, "UNIT_TEST", c.Templates["environment"]) + assert.Equal(t, "UNIT_TEST", c.Templates["ENVIRONMENT"]) + + b, err := yaml.Marshal(c) + require.NoError(t, err) + assert.NotContains(t, string(b), "${") + + for k, v := range c.Templates { + assert.NotContains(t, v, "${", "%s is not extrapolated: %s", k, v) + } + + for idx, v := range c.List { + assert.NotContains(t, v, "${", "list[%d] is not extrapolated: %s", idx, v) + } + assert.Len(t, c.List, 4) +} + +func TestLoadYAMLWithOverride(t *testing.T) { + cfgFile, err := GetAbsFilename("testdata/test_config.yaml", ".") + require.NoError(t, err, "unable to determine config file") + + f, err := NewFactory(nil, []string{"testdata/override"}, "TEST_") + require.NoError(t, err) + + os.Setenv("TEST_HOSTNAME", "UNIT_TEST") + os.Setenv("NODENAME", "UNIT_TEST") + + f.WithOverride("custom_list.yaml") + f.WithEnvironment("test2") + + sp := &mockSecret{ + secrets: map[string]string{ + "secret1": "api-key1", + "secret2": "api-key2", + }, + } + f.WithSecretProvider(sp) + + var c configuration + _, err = f.Load(cfgFile, &c) + require.NoError(t, err, "failed to load config: %v", cfgFile) + assert.Equal(t, "test2", c.Environment) + assert.Equal(t, "test-override", c.Region) + assert.Equal(t, "porto-pod", c.ServiceName) + assert.NotEmpty(t, c.ClusterName) + assert.Equal(t, "api-key2", c.ClientAPIKey) + + assert.Equal(t, fmt.Sprintf("/tmp/porto-%s/logs", c.Environment), c.Logs.Directory) + assert.Equal(t, 3, c.Logs.MaxAgeDays) + assert.Equal(t, 10, c.Logs.MaxSizeMb) + + assert.Equal(t, fmt.Sprintf("/tmp/porto-%s/audit", c.Environment), c.Audit.Directory) + assert.Equal(t, 99, c.Audit.MaxAgeDays) + assert.Equal(t, 99, c.Audit.MaxSizeMb) + + assert.Equal(t, "test2", c.Templates["environment"]) + assert.Equal(t, "TEST2", c.Templates["ENVIRONMENT"]) + + b, err := yaml.Marshal(c) + require.NoError(t, err) + assert.NotContains(t, string(b), "${") + + for k, v := range c.Templates { + assert.NotContains(t, v, "${", "%s is not extrapolated: %s", k, v) + } + + for idx, v := range c.List { + assert.NotContains(t, v, "${", "list[%d] is not extrapolated: %s", idx, v) + } + assert.Len(t, c.List, 5) +} + +// configuration contains the user configurable data for the service +type configuration struct { + + // Region specifies the Region / Datacenter where the instance is running + Region string `json:"region,omitempty" yaml:"region,omitempty"` + + // Environment specifies the environment where the instance is running: prod|stage|dev + Environment string `json:"environment,omitempty" yaml:"environment,omitempty"` + + // ServiceName specifies the service name to be used in logs, metrics, etc + ServiceName string `json:"service,omitempty" yaml:"service,omitempty"` + + // ClusterName specifies the cluster name + ClusterName string `json:"cluster,omitempty" yaml:"cluster,omitempty"` + + // ClientAPIKey specifies the API key + ClientAPIKey string `json:"client_api_key,omitempty" yaml:"client_api_key,omitempty"` + + // Audit contains configuration for the audit logger + Audit Logger `json:"audit" yaml:"audit"` + + // Logs contains configuration for the logger + Logs Logger `json:"logs" yaml:"logs"` + + Templates map[string]string `json:"templates" yaml:"templates"` + + List []string `json:"list" yaml:"list"` + + Map map[string]*Logger `json:"map_log" yaml:"map_log"` +} + +// Logger contains information about the configuration of a logger/log rotation +type Logger struct { + + // Directory contains where to store the log files; if value is empty, them stderr is used for output + Directory string `json:"directory,omitempty" yaml:"directory,omitempty"` + + // MaxAgeDays controls how old files are before deletion + MaxAgeDays int `json:"max_age_days,omitempty" yaml:"max_age_days,omitempty"` + + // MaxSizeMb contols how large a single log file can be before its rotated + MaxSizeMb int `json:"max_size_mb,omitempty" yaml:"max_size_mb,omitempty"` +} + +type mockSecret struct { + secrets map[string]string +} + +func (s *mockSecret) GetSecret(name string) (string, error) { + tokens := strings.Split(name, "/") + sec := s.secrets[tokens[0]] + if sec != "" { + return sec, nil + + } + return "", errors.Errorf("secret not found: %s", name) +} diff --git a/configloader/expand.go b/configloader/expand.go new file mode 100644 index 0000000..2b79129 --- /dev/null +++ b/configloader/expand.go @@ -0,0 +1,125 @@ +package configloader + +import ( + "os" + "reflect" + "strings" + + "github.com/effective-security/xlog" + "github.com/pkg/errors" +) + +// Expander is used to expand variables in the input object +type Expander struct { + Variables map[string]string + SecretProvider SecretProvider +} + +// ExpandAll replace variables in the input object, using default Expander. +// The input object must be a pointer to a struct. +// If secrets are used, SecretProviderInstance must be set. +// The values started with env:// , file:// or secret:// must be resolved. +// The values inside ${} will be tried to be resolved, +// if not found will be substiduted with empy values as per os.Getenv function. +func ExpandAll(obj interface{}) error { + e := Expander{SecretProvider: SecretProviderInstance} + return e.ExpandAll(obj) +} + +// ExpandAll replace variables in the input object +func (f *Expander) ExpandAll(obj interface{}) error { + return f.doSubstituteEnvVars(reflect.ValueOf(obj)) +} + +// Expand replace variables in the input string +func (f *Expander) Expand(s string) (string, error) { + // try first prefix + s, err := ResolveValueWithSecrets(s, f.SecretProvider) + if err != nil { + return s, err + } + + if strings.Contains(s, "${") { + for key, value := range f.Variables { + s = strings.Replace(s, key, value, -1) + } + } + + if strings.Contains(s, "${") { + s = os.Expand(s, func(env string) string { + if strings.HasPrefix(env, SecretSource) && f.SecretProvider != nil { + name := strings.TrimPrefix(env, SecretSource) + sec, err := f.SecretProvider.GetSecret(name) + if err != nil { + logger.KV(xlog.ERROR, "secret", name, "err", err.Error()) + } + return sec + } + + if va, ok := f.Variables[env]; ok { + return va + } + return os.Getenv(env) + }) + } + if strings.Contains(s, "${") { + return s, errors.Errorf("unable to resolve variables: %s", s) + } + return s, nil +} + +func (f *Expander) doSubstituteEnvVars(v reflect.Value) error { + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + if !v.IsValid() { + return nil + } + + switch v.Kind() { + case reflect.Struct: + for i := 0; i < v.NumField(); i++ { + if err := f.doSubstituteEnvVars(v.Field(i)); err != nil { + return err + } + } + case reflect.Slice: + for i := 0; i < v.Len(); i++ { + if err := f.doSubstituteEnvVars(v.Index(i)); err != nil { + return err + } + } + case reflect.String: + if v.CanSet() { + val, err := f.Expand(v.String()) + if err != nil { + return err + } + v.SetString(val) + } + case reflect.Ptr: + if err := f.doSubstituteEnvVars(v.Elem()); err != nil { + return err + } + case reflect.Map: + if v.Type().String() == "map[string]string" { + m := v.Interface().(map[string]string) + for k, v := range m { + val, err := f.Expand(v) + if err != nil { + return err + } + m[k] = val + } + } else { + iter := v.MapRange() + for iter.Next() { + if err := f.doSubstituteEnvVars(iter.Value()); err != nil { + return err + } + } + } + default: + } + return nil +} diff --git a/configloader/load.go b/configloader/load.go new file mode 100644 index 0000000..4d27e43 --- /dev/null +++ b/configloader/load.go @@ -0,0 +1,114 @@ +package configloader + +import ( + "encoding/json" + "os" + "strings" + + "github.com/pkg/errors" + "gopkg.in/yaml.v3" +) + +const ( + // FileSource specifies to load config from a file + FileSource = "file://" + // EnvSource specifies to load config from an environment variable + EnvSource = "env://" + // SecretSource specifies to load config from a secret manager + SecretSource = "secret://" +) + +// SecretProvider is an interface to provide secrets +type SecretProvider interface { + GetSecret(name string) (string, error) +} + +// SecretProviderInstance is a global instance of SecretLoader +var SecretProviderInstance SecretProvider + +// ResolveValue returns value loaded from file:// or env:// +// If val does not start with file:// or env://, then the value is returned as is +func ResolveValue(val string) (string, error) { + return ResolveValueWithSecrets(val, SecretProviderInstance) +} + +// ResolveValue returns value loaded from file:// or env:// +// If val does not start with file:// or env://, then the value is returned as is +func ResolveValueWithSecrets(val string, loader SecretProvider) (string, error) { + if strings.HasPrefix(val, FileSource) { + fn := strings.TrimPrefix(val, FileSource) + f, err := os.ReadFile(fn) + if err != nil { + return val, errors.WithStack(err) + } + // file content + val = string(f) + } else if strings.HasPrefix(val, EnvSource) { + env := strings.TrimPrefix(val, EnvSource) + // ENV content + val = os.Getenv(env) + if val == "" { + return "", errors.Errorf("environment variable not set: %s", env) + } + } else if strings.HasPrefix(val, SecretSource) { + if loader == nil { + return "", errors.Errorf("secret loader not provided") + } + name := strings.TrimPrefix(val, SecretSource) + sec, err := loader.GetSecret(name) + if err != nil { + return val, errors.WithMessage(err, "unable to load secret") + } + val = sec + } + + return val, nil +} + +// UnmarshalAndExpand load JSON or YAML file to an interface and expands variables +func UnmarshalAndExpand(file string, v interface{}) error { + err := Unmarshal(file, v) + if err != nil { + return err + } + + return ExpandAll(v) +} + +// Unmarshal JSON or YAML file to an interface +func Unmarshal(file string, v interface{}) error { + b, err := os.ReadFile(file) + if err != nil { + return errors.WithMessagef(err, "unable to read file") + } + + if strings.HasSuffix(file, ".json") { + err = json.Unmarshal(b, v) + if err != nil { + return errors.WithMessagef(err, "unable parse JSON: %s", file) + } + } else { + err = yaml.Unmarshal(b, v) + if err != nil { + return errors.WithMessagef(err, "unable parse YAML: %s", file) + } + } + return nil +} + +// Marshal saves object to file +func Marshal(fn string, value interface{}) error { + var data []byte + var err error + if strings.HasSuffix(fn, ".json") { + data, err = json.MarshalIndent(value, "", " ") + } else { + data, err = yaml.Marshal(value) + } + + if err != nil { + return errors.WithMessage(err, "failed to encode") + } + + return os.WriteFile(fn, data, os.ModePerm) +} diff --git a/configloader/load_test.go b/configloader/load_test.go new file mode 100644 index 0000000..8547e12 --- /dev/null +++ b/configloader/load_test.go @@ -0,0 +1,160 @@ +package configloader_test + +import ( + "os" + "path" + "strings" + "testing" + + "github.com/effective-security/x/configloader" + "github.com/effective-security/x/guid" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_LoadConfigWithSchema_plain(t *testing.T) { + c, err := configloader.ResolveValue("test_data") + require.NoError(t, err) + assert.Equal(t, "test_data", c) +} + +func Test_LoadConfigWithSchema_file(t *testing.T) { + c, err := configloader.ResolveValue("file://./load.go") + require.NoError(t, err) + require.NotEmpty(t, c) + assert.Contains(t, c, "package configloader") +} + +func Test_SaveConfigWithSchema_file(t *testing.T) { + tmpDir := t.TempDir() + file := path.Join(tmpDir, guid.MustCreate()) + cfg := "file://" + file + + err := os.WriteFile(file, []byte("test"), os.ModePerm) + require.NoError(t, err) + + c, err := configloader.ResolveValue(cfg) + require.NoError(t, err) + assert.Equal(t, "test", c) + + t.Setenv("PORTO_TEST", "test") + c, err = configloader.ResolveValue("env://PORTO_TEST") + require.NoError(t, err) + assert.Equal(t, "test", c) +} + +func Test_ConfigWithSchema_Secret(t *testing.T) { + cfg := "secret://key1" + + configloader.SecretProviderInstance = nil + _, err := configloader.ResolveValue(cfg) + assert.EqualError(t, err, "secret loader not provided") + + configloader.SecretProviderInstance = &mockSecret{ + secrets: map[string]string{ + "key1": "value1", + }, + } + + val, err := configloader.ResolveValue(cfg) + require.NoError(t, err) + assert.Equal(t, "value1", val) + + _, err = configloader.ResolveValue("secret://key2") + assert.EqualError(t, err, "unable to load secret: secret not found: key2") +} + +type mockSecret struct { + secrets map[string]string +} + +func (s *mockSecret) GetSecret(name string) (string, error) { + tokens := strings.Split(name, "/") + sec := s.secrets[tokens[0]] + if sec != "" { + return sec, nil + + } + return "", errors.Errorf("secret not found: %s", name) +} + +type config struct { + Service string + Region string + Cluster string + Environment string +} + +func Test_Unmarshal(t *testing.T) { + tmp := t.TempDir() + + var v config + err := configloader.Unmarshal("testdata/test_config.yaml", &v) + require.NoError(t, err) + + assert.Equal(t, "porto-pod", v.Service) + assert.Equal(t, "local", v.Region) + assert.Equal(t, "env://NODENAME", v.Cluster) + assert.Equal(t, "test", v.Environment) + + fn := path.Join(tmp, "test_config.yaml") + err = configloader.Marshal(fn, &v) + require.NoError(t, err) + + var v2 config + err = configloader.Unmarshal(fn, &v2) + require.NoError(t, err) + assert.Equal(t, v, v2) + + err = configloader.Unmarshal("testdata/test_config.json", &v) + require.NoError(t, err) + + assert.Equal(t, "porto-pod", v.Service) + assert.Equal(t, "local", v.Region) + assert.Equal(t, "${NODENAME}", v.Cluster) + assert.Equal(t, "test", v.Environment) + + fn = path.Join(tmp, "test_config.json") + err = configloader.Marshal(fn, &v) + require.NoError(t, err) + encoded, err := os.ReadFile(fn) + require.NoError(t, err) + assert.Equal(t, + `{ + "Service": "porto-pod", + "Region": "local", + "Cluster": "${NODENAME}", + "Environment": "test" +}`, + string(encoded)) + + err = configloader.Unmarshal(fn, &v2) + require.NoError(t, err) + assert.Equal(t, v, v2) +} + +func Test_UnmarshalAndExpand(t *testing.T) { + configloader.SecretProviderInstance = nil + t.Setenv("NODENAME", "") + + v := new(config) + err := configloader.UnmarshalAndExpand("testdata/test_config.yaml", v) + assert.EqualError(t, err, "environment variable not set: NODENAME") + + configloader.SecretProviderInstance = &mockSecret{ + secrets: map[string]string{ + "secret1": "api-key1", + "secret2": "api-key2", + }, + } + t.Setenv("NODENAME", "cluster1") + + v = new(config) + err = configloader.UnmarshalAndExpand("testdata/test_config.yaml", v) + require.NoError(t, err) + assert.Equal(t, "porto-pod", v.Service) + assert.Equal(t, "local", v.Region) + assert.Equal(t, "cluster1", v.Cluster) + assert.Equal(t, "test", v.Environment) +} diff --git a/configloader/testdata/override/custom_list.yaml b/configloader/testdata/override/custom_list.yaml new file mode 100644 index 0000000..ea0508e --- /dev/null +++ b/configloader/testdata/override/custom_list.yaml @@ -0,0 +1,9 @@ +--- +region : test-override + +list: + - ${USER} + - ${NORMALIZED_USER} + - ${ENVIRONMENT} + - ${HOSTNAME}/${NODENAME} + - ${USER}@${HOSTNAME}:${LOCALIP} \ No newline at end of file diff --git a/configloader/testdata/test_config-override.yaml b/configloader/testdata/test_config-override.yaml new file mode 100644 index 0000000..b6ed6c2 --- /dev/null +++ b/configloader/testdata/test_config-override.yaml @@ -0,0 +1,8 @@ +--- +environment: UNIT_TEST + +audit: + max_age_days: 99 + max_size_mb: 99 + +client_api_key: secret://secret2/api-key2 \ No newline at end of file diff --git a/configloader/testdata/test_config.json b/configloader/testdata/test_config.json new file mode 100644 index 0000000..fb80c21 --- /dev/null +++ b/configloader/testdata/test_config.json @@ -0,0 +1,6 @@ +{ + "service": "porto-pod", + "region": "local", + "cluster": "${NODENAME}", + "environment": "test" +} \ No newline at end of file diff --git a/configloader/testdata/test_config.yaml b/configloader/testdata/test_config.yaml new file mode 100644 index 0000000..661bc4e --- /dev/null +++ b/configloader/testdata/test_config.yaml @@ -0,0 +1,44 @@ +--- +service: porto-pod +region : local +cluster: env://NODENAME +environment: test + +# configuration for the logger +logs: + # contains where to store the log files; if value is empty, them stderr is used for output + directory: /tmp/porto-${ENVIRONMENT}/logs + # controls how old files are before deletion / rotation + max_age_days: 3 + # contols how large a single log file can be before its rotated + max_size_mb: 10 + +# configuration for the audit logger +audit: + directory: /tmp/porto-${ENVIRONMENT}/audit + max_age_days: 14 + max_size_mb: 10 + +templates: + hostname: ${HOSTNAME} + nodename: ${NODENAME} + ip: ${LOCALIP} + user: domain/${USER} + nuser: ${NORMALIZED_USER} + environment: ${ENVIRONMENT} + ENVIRONMENT: ${ENVIRONMENT_UPPERCASE} + +list: + - ${USER} + - ${ENVIRONMENT} + - ${HOSTNAME}/${NODENAME} + - ${USER}@${HOSTNAME}:${LOCALIP} + +map_log: + logs: + directory: /tmp/${USER}/logs + audits: + directory: /tmp/${USER}/audit + + +client_api_key: secret://secret1/api-key1 diff --git a/configloader/testdata/test_config.yaml.hostmap b/configloader/testdata/test_config.yaml.hostmap new file mode 100644 index 0000000..a5aa8c9 --- /dev/null +++ b/configloader/testdata/test_config.yaml.hostmap @@ -0,0 +1,4 @@ +--- +override: + UNIT_TEST: test_config-override.yaml + ~: notfound.yaml diff --git a/fileutil/doc.go b/fileutil/doc.go new file mode 100644 index 0000000..de57821 --- /dev/null +++ b/fileutil/doc.go @@ -0,0 +1,2 @@ +// Package fileutil provides utilities for file operations +package fileutil diff --git a/fileutil/folders.go b/fileutil/folders.go new file mode 100644 index 0000000..5a0eb9e --- /dev/null +++ b/fileutil/folders.go @@ -0,0 +1,87 @@ +package fileutil + +import ( + "os" + + "github.com/pkg/errors" +) + +// FolderExists ensures that folder exists +func FolderExists(dir string) error { + if dir == "" { + return errors.Errorf("invalid parameter: dir") + } + + stat, err := os.Stat(dir) + if err != nil { + return errors.WithStack(err) + } + + if !stat.IsDir() { + return errors.Errorf("not a folder: %q", dir) + } + + return nil +} + +// FileExists ensures that file exists +func FileExists(file string) error { + if file == "" { + return errors.Errorf("invalid parameter: file") + } + + stat, err := os.Stat(file) + if err != nil { + return errors.WithStack(err) + } + + if stat.IsDir() { + return errors.Errorf("not a file: %q", file) + } + + return nil +} + +// SubfolderNames returns list of subfolders in provided folder +func SubfolderNames(folder string) ([]string, error) { + var list []string + + f, err := os.Open(folder) + if err != nil { + return nil, errors.WithStack(err) + } + dirs, err := f.ReadDir(-1) + f.Close() + if err != nil { + return nil, errors.WithStack(err) + } + for _, d := range dirs { + if d.IsDir() { + list = append(list, d.Name()) + } + } + + return list, nil +} + +// FileNames returns list of files in provided folder. +func FileNames(folder string) ([]string, error) { + var list []string + + f, err := os.Open(folder) + if err != nil { + return nil, errors.WithStack(err) + } + dirs, err := f.ReadDir(-1) + f.Close() + if err != nil { + return nil, errors.WithStack(err) + } + for _, d := range dirs { + if !d.IsDir() { + list = append(list, d.Name()) + } + } + + return list, nil +} diff --git a/fileutil/folders_test.go b/fileutil/folders_test.go new file mode 100644 index 0000000..9adae0a --- /dev/null +++ b/fileutil/folders_test.go @@ -0,0 +1,78 @@ +package fileutil_test + +import ( + "fmt" + "io/ioutil" + "os" + "path" + "testing" + + "github.com/effective-security/x/fileutil" + "github.com/effective-security/x/guid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_FolderExists(t *testing.T) { + tmpDir := path.Join(os.TempDir(), "fileutil-test", guid.MustCreate()) + + err := os.MkdirAll(tmpDir, os.ModePerm) + require.NoError(t, err) + + defer os.RemoveAll(tmpDir) + + assert.Error(t, fileutil.FolderExists("")) + assert.NoError(t, fileutil.FolderExists(tmpDir)) + + err = fileutil.FolderExists(tmpDir + "/a") + require.Error(t, err) + assert.Equal(t, fmt.Sprintf("stat %s: no such file or directory", tmpDir+"/a"), err.Error()) + + err = fileutil.FolderExists("./folders.go") + require.Error(t, err) + assert.Equal(t, "not a folder: \"./folders.go\"", err.Error()) +} + +func Test_FileExists(t *testing.T) { + tmpDir := path.Join(os.TempDir(), "fileutil-test", guid.MustCreate()) + + err := os.MkdirAll(tmpDir, os.ModePerm) + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + file := path.Join(tmpDir, "file.txt") + err = ioutil.WriteFile(file, []byte("FileExists"), 0644) + require.NoError(t, err) + + assert.Error(t, fileutil.FileExists("")) + assert.NoError(t, fileutil.FileExists(file)) + + err = fileutil.FileExists(tmpDir) + require.Error(t, err) + assert.Equal(t, fmt.Sprintf("not a file: %q", tmpDir), err.Error()) + + err = fileutil.FileExists(tmpDir + "/a") + require.Error(t, err) + assert.Equal(t, fmt.Sprintf("stat %s: no such file or directory", tmpDir+"/a"), err.Error()) +} + +func Test_SubfolderNames(t *testing.T) { + l, err := fileutil.SubfolderNames(".") + require.NoError(t, err) + assert.Len(t, l, 3) + + _, err = fileutil.SubfolderNames("./notfound") + assert.EqualError(t, err, "open ./notfound: no such file or directory") + + _, err = fileutil.SubfolderNames("./folders.go") + assert.EqualError(t, err, "readdirent ./folders.go: not a directory") +} + +func Test_FileNames(t *testing.T) { + l, err := fileutil.FileNames("./resolve") + require.NoError(t, err) + assert.Len(t, l, 2) + + _, err = fileutil.FileNames("./notfound") + assert.EqualError(t, err, "open ./notfound: no such file or directory") +} diff --git a/fileutil/reloader/reloader.go b/fileutil/reloader/reloader.go new file mode 100644 index 0000000..b3d2345 --- /dev/null +++ b/fileutil/reloader/reloader.go @@ -0,0 +1,130 @@ +package reloader + +import ( + "os" + "sync" + "sync/atomic" + "time" + + "github.com/effective-security/xlog" + "github.com/pkg/errors" +) + +var logger = xlog.NewPackageLogger("github.com/effective-security/x/fileutil", "reloader") + +// Wrap time.Tick so we can override it in tests. +var makeTicker = func(interval time.Duration) (func(), <-chan time.Time) { + t := time.NewTicker(interval) + return t.Stop, t.C +} + +// OnChangedFunc is a called when the file has been modified +type OnChangedFunc func(filePath string, modifiedAt time.Time) + +// Reloader keeps necessary info to provide reloaded certificate +type Reloader struct { + lock sync.RWMutex + loadedAt time.Time + count uint32 + filePath string + fileModifiedAt time.Time + onChangedFunc OnChangedFunc + inProgress bool + stopChan chan<- struct{} + closed bool +} + +// NewReloader return an instance of the file re-loader +func NewReloader(filePath string, checkInterval time.Duration, onChangedFunc OnChangedFunc) (*Reloader, error) { + result := &Reloader{ + filePath: filePath, + onChangedFunc: onChangedFunc, + stopChan: make(chan struct{}), + } + + logger.KV(xlog.INFO, "status", "started", "file", filePath) + + stopChan := make(chan struct{}) + result.stopChan = stopChan + tickerStop, tickChan := makeTicker(checkInterval) + go func() { + for { + select { + case <-stopChan: + tickerStop() + logger.KV(xlog.INFO, "status", "closed", "count", result.LoadedCount(), "file", filePath) + return + case <-tickChan: + modified := false + fi, err := os.Stat(filePath) + if err == nil { + modified = fi.ModTime().After(result.fileModifiedAt) + if modified { + result.fileModifiedAt = fi.ModTime() + err := result.Reload() + if err != nil { + logger.KV(xlog.ERROR, "err", err) + } + } + } else { + logger.KV(xlog.WARNING, "reason", "stat", "file", filePath, "err", err) + } + } + } + }() + return result, nil +} + +// Reload will explicitly call the callback function +func (k *Reloader) Reload() error { + k.lock.Lock() + if k.inProgress { + k.lock.Unlock() + return nil + } + + k.inProgress = true + defer func() { + k.inProgress = false + k.lock.Unlock() + }() + + atomic.AddUint32(&k.count, 1) + k.loadedAt = time.Now().UTC() + + go k.onChangedFunc(k.filePath, k.fileModifiedAt) + + return nil +} + +// LoadedAt return the last time when the pair was loaded +func (k *Reloader) LoadedAt() time.Time { + k.lock.RLock() + defer k.lock.RUnlock() + + return k.loadedAt +} + +// LoadedCount returns the number of times the pair was loaded from disk +func (k *Reloader) LoadedCount() uint32 { + return atomic.LoadUint32(&k.count) +} + +// Close will close the reloader and release its resources +func (k *Reloader) Close() error { + if k == nil { + return nil + } + + k.lock.RLock() + defer k.lock.RUnlock() + + if k.closed { + return errors.New("already closed") + } + + k.closed = true + k.stopChan <- struct{}{} + + return nil +} diff --git a/fileutil/reloader/reloader_test.go b/fileutil/reloader/reloader_test.go new file mode 100644 index 0000000..1fb5c66 --- /dev/null +++ b/fileutil/reloader/reloader_test.go @@ -0,0 +1,93 @@ +package reloader_test + +import ( + "fmt" + "io/ioutil" + "os" + "path/filepath" + "testing" + "time" + + "github.com/effective-security/x/fileutil/reloader" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_Reloader(t *testing.T) { + now := time.Now().UTC() + + file := filepath.Join(os.TempDir(), "test-reloaded.txt") + + callbackCount := 0 + lastModifiedAt := time.Now() + onChangedFunc := func(fn string, modifiedAt time.Time) { + assert.Equal(t, file, fn) + if callbackCount > 0 { + assert.True(t, modifiedAt.After(lastModifiedAt), fmt.Sprintf("this=%v, last=%v", modifiedAt, lastModifiedAt)) + } + lastModifiedAt = modifiedAt + callbackCount++ + } + + err := ioutil.WriteFile(file, []byte("Test_Reloader"), os.ModePerm) + require.NoError(t, err) + + k, err := reloader.NewReloader(file, 100*time.Millisecond, onChangedFunc) + require.NoError(t, err) + require.NotNil(t, k) + defer k.Close() + + _ = k.Reload() + + loadedAt := k.LoadedAt() + assert.True(t, loadedAt.After(now), "loaded time must be after test start time") + assert.Equal(t, uint32(1), k.LoadedCount()) + + err = ioutil.WriteFile(file, []byte("Test_Reloader2"), os.ModePerm) + require.NoError(t, err) + time.Sleep(2 * time.Millisecond) + err = ioutil.WriteFile(file, []byte("Test_Reloader3"), os.ModePerm) + require.NoError(t, err) + + time.Sleep(200 * time.Millisecond) + + loadedAt2 := k.LoadedAt() + count := int(k.LoadedCount()) + assert.Equal(t, callbackCount, count) + assert.True(t, count >= 2 && count <= 4, "must be loaded at start, whithin period and after, loaded: %d", k.LoadedCount()) + assert.True(t, loadedAt2.After(loadedAt), "re-loaded time must be after last loaded time") + + err = ioutil.WriteFile(file, []byte("Test_Reloader4"), os.ModePerm) + require.NoError(t, err) + time.Sleep(2 * time.Millisecond) + err = ioutil.WriteFile(file, []byte("Test_Reloader5"), os.ModePerm) + require.NoError(t, err) + + time.Sleep(200 * time.Millisecond) + + loadedAt3 := k.LoadedAt() + count = int(k.LoadedCount()) + assert.Equal(t, callbackCount, count) + assert.True(t, count >= 3 && count <= 5, "must be loaded at start, whithin period and after, loaded: %d", k.LoadedCount()) + assert.True(t, loadedAt3.After(loadedAt2), "re-loaded time must be after last loaded time") +} + +func Test_ReloaderClose(t *testing.T) { + var k *reloader.Reloader + assert.NotPanics(t, func() { + k.Close() + }) + + file := filepath.Join(os.TempDir(), "test-reloaded.txt") + + k, err := reloader.NewReloader(file, 100*time.Millisecond, func(fn string, modifiedAt time.Time) {}) + require.NoError(t, err) + require.NotNil(t, k) + + err = k.Close() + assert.NoError(t, err) + + err = k.Close() + require.Error(t, err) + assert.Equal(t, "already closed", err.Error()) +} diff --git a/fileutil/resolve/resolve.go b/fileutil/resolve/resolve.go new file mode 100644 index 0000000..db7c680 --- /dev/null +++ b/fileutil/resolve/resolve.go @@ -0,0 +1,48 @@ +package resolve + +import ( + "os" + "path/filepath" + + "github.com/pkg/errors" +) + +// Directory returns absolute dir name relative to baseDir, +// or NewNotFound error. +func Directory(dir string, baseDir string, create bool) (resolved string, err error) { + if dir == "" { + return dir, nil + } + if filepath.IsAbs(dir) { + resolved = dir + } else { + resolved = filepath.Join(baseDir, dir) + } + if _, err := os.Stat(resolved); os.IsNotExist(err) { + if create { + if err = os.MkdirAll(resolved, 0744); err != nil { + return "", errors.WithMessagef(err, "crerate dir: %q", resolved) + } + } else { + return resolved, errors.WithMessagef(err, "not found: %v", resolved) + } + } + return resolved, nil +} + +// File returns absolute file name relative to baseDir, +// or NewNotFound error. +func File(file string, baseDir string) (resolved string, err error) { + if file == "" { + return file, nil + } + if filepath.IsAbs(file) { + resolved = file + } else if baseDir != "" { + resolved = filepath.Join(baseDir, file) + } + if _, err := os.Stat(resolved); os.IsNotExist(err) { + return resolved, errors.WithMessagef(err, "not found: %v", resolved) + } + return resolved, nil +} diff --git a/fileutil/resolve/resolve_test.go b/fileutil/resolve/resolve_test.go new file mode 100644 index 0000000..b13ff0a --- /dev/null +++ b/fileutil/resolve/resolve_test.go @@ -0,0 +1,82 @@ +package resolve_test + +import ( + "fmt" + "os" + "path" + "path/filepath" + "strings" + "testing" + + "github.com/effective-security/x/fileutil/resolve" + "github.com/effective-security/x/guid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ResolveDirectory(t *testing.T) { + tmpDir := path.Join(os.TempDir(), "resolve-test", guid.MustCreate()) + testData := []struct { + dir string + baseDir string + create bool + err string + }{ + { + dir: "a1/a2", + baseDir: tmpDir, + create: false, + err: "no such file or directory", + }, + { + dir: "a1/a2", + baseDir: tmpDir, + create: true, + err: "", + }, + { + dir: "a1/a2", + baseDir: tmpDir, + create: false, + err: "", + }, + } + + // Run test + for idx, v := range testData { + t.Run(fmt.Sprintf("[%d] %s", idx, v.dir), func(t *testing.T) { + d, err := resolve.Directory(v.dir, v.baseDir, v.create) + if v.err != "" { + require.Error(t, err) + assert.True(t, strings.Contains(err.Error(), v.err)) + } else { + assert.NoError(t, err) + assert.NotEmpty(t, d) + assert.True(t, strings.HasSuffix(d, v.dir)) + } + }) + } +} + +func Test_File(t *testing.T) { + f, err := resolve.File("", ".") + assert.NoError(t, err) + assert.Empty(t, f) + + f = "resolve.go" + + // now f is relative to current folder + f2, err := resolve.File(f, ".") + assert.NoError(t, err) + assert.Equal(t, f, f2) + + fabs, err := filepath.Abs(f2) + require.NoError(t, err) + + f3, err := resolve.File(fabs, "/does/not/matter") + assert.NoError(t, err) + assert.Equal(t, fabs, f3) + + _, err = resolve.File(fabs+".junk", "/does/not/matter") + assert.Error(t, err) +} diff --git a/fileutil/testdata/test_config.json b/fileutil/testdata/test_config.json new file mode 100644 index 0000000..d46c0f2 --- /dev/null +++ b/fileutil/testdata/test_config.json @@ -0,0 +1,6 @@ +{ + "service": "porto-pod", + "region": "local", + "cluster": "cl1", + "environment": "test" +} \ No newline at end of file diff --git a/fileutil/testdata/test_config.yaml b/fileutil/testdata/test_config.yaml new file mode 100644 index 0000000..8ec8a11 --- /dev/null +++ b/fileutil/testdata/test_config.yaml @@ -0,0 +1,5 @@ +--- +service: porto-pod +region : local +cluster: cl1 +environment: test diff --git a/flake/LICENSE b/flake/LICENSE new file mode 100644 index 0000000..81795bf --- /dev/null +++ b/flake/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright 2015 Sony Corporation + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/flake/README.md b/flake/README.md new file mode 100644 index 0000000..7bcc24f --- /dev/null +++ b/flake/README.md @@ -0,0 +1,69 @@ +Flake +========= + +A fork from https://github.com/sony/sonyflake + +Flake is a distributed unique ID generator inspired by [Twitter's Snowflake](https://blog.twitter.com/2010/announcing-snowflake). + +Differences from the original Sonyflake: +- panic instead of returning errors, as these errors are mostly non actionable and should never occur: `NextID() uint64` +- time units are 1 msec instead of 10 +- 16 bits for a machine id, +- 6 bits for a sequence number (64 per 1 ms) +- 41 bits for time in units of 1 msec + +As a result, Flake has the following advantages and disadvantages: + +- The lifetime (69 years) is similar to that of Snowflake (69 years) +- It can work on more distributed machines (2^16) than Snowflake (2^10) +- It can generate 2^6 IDs per 1 msec at most in a single machine/thread + +Installation +------------ + +``` +go get github.com/effective-security/porto/pkg/flake +``` + +Usage +----- + +The function NewIDGenerator creates a new IDGenerator instance. + +```go +func NewIDGenerator(st Settings) IDGenerator +``` + +You can configure Flake by the struct Settings: + +```go +type Settings struct { + StartTime time.Time + MachineID func() (uint16, error) + CheckMachineID func(uint16) bool +} +``` + +- StartTime is the time since which the Flake time is defined as the elapsed time. + If StartTime is 0, the start time of the Sonyflake is set to "2021-01-01 00:00:00 +0000 UTC". + If StartTime is ahead of the current time, Flake is not created. + +- MachineID returns the unique ID of the Flake instance. + If MachineID returns an error, Flake will panic. + If MachineID is nil, default MachineID is used. + Default MachineID returns the lower 8 bits of the private IP address. + +- CheckMachineID validates the uniqueness of the machine ID. + If CheckMachineID returns false, Flake will panic. + If CheckMachineID is nil, no validation is done. + +In order to get a new unique ID, you just have to call the method NextID. + +```go +func (sf *Flake) NextID() uint64 +``` + +License +------- + +The MIT License (MIT) diff --git a/flake/flake.go b/flake/flake.go new file mode 100644 index 0000000..241eedb --- /dev/null +++ b/flake/flake.go @@ -0,0 +1,272 @@ +// Package flake implements Snowflake, a distributed unique ID generator inspired by Twitter's Snowflake. +// +// A Flake ID is composed of +// +// 39 bits for time in units of 10 msec +// 8 bits for a sequence number +// 16 bits for a machine id +package flake + +import ( + "hash/fnv" + "net" + "os" + "strings" + "sync" + "time" + + "github.com/effective-security/xlog" +) + +var logger = xlog.NewPackageLogger("github.com/effective-security/x", "flake") + +// IDGenerator defines an interface to generate unique ID accross the cluster +type IDGenerator interface { + // NextID generates a next unique ID. + NextID() uint64 +} + +// DefaultIDGenerator for the app +var DefaultIDGenerator IDGenerator + +// NowFunc returns the current time; it's overridden in tests. +var NowFunc = time.Now + +func init() { + DefaultIDGenerator = NewIDGenerator(Settings{ + StartTime: DefaultStartTime, + }) +} + +// These constants are the bit lengths of Flake ID parts. +const ( + BitLenMachineID = 16 // bit length of machine id, 2^16 + BitLenSequence = 6 // bit length of sequence number + BitLenTime = 63 - BitLenMachineID - BitLenSequence // bit length of time + MaskSequence16 = uint16(1<= current + sf.sequence = (sf.sequence + 1) & MaskSequence16 + if sf.sequence > sf.maxSequence { + sf.maxSequence = sf.sequence + } + + if sf.sequence == 0 { + sf.elapsedTime++ + overtime := sf.elapsedTime - current + sleep := sleepTime((overtime)) + //logger.Noticef("sleep_overtime=%v", sleep) + time.Sleep(sleep) + } + } + + sf.lastID = sf.toID() + return sf.lastID +} + +func toFlakeTime(t time.Time) int64 { + return t.UnixNano() / FlakeTimeUnit +} + +func fromFlakeTime(f int64) time.Time { + return time.Unix(0, f*FlakeTimeUnit).UTC() +} + +func currentElapsedTime(startTime int64) int64 { + return toFlakeTime(NowFunc()) - startTime +} + +func sleepTime(overtime int64) time.Duration { + return time.Nanosecond * + time.Duration(overtime*FlakeTimeUnit-NowFunc().UnixNano()%FlakeTimeUnit) +} + +func (sf *Flake) toID() uint64 { + if sf.elapsedTime >= 1<> 63 + time := id >> (BitLenSequence + BitLenMachineID) + sequence := (id & MaskSequence) >> BitLenMachineID + machineID := id & MaskMachineID + return map[string]uint64{ + "id": id, + "msb": msb, + "time": time, + "sequence": sequence, + "machine_id": machineID, + } +} + +// IDTime returns the timestamp of the flake ID. +func IDTime(g IDGenerator, id uint64) time.Time { + start := int64(0) + if fl, ok := g.(*Flake); ok { + start = fl.startTime + } + return fromFlakeTime(start + int64(id>>(BitLenSequence+BitLenMachineID))) +} + +// FirstID returns the first ID generated by the generator. +func FirstID(g IDGenerator) uint64 { + if fl, ok := g.(*Flake); ok { + return fl.firstID + } + return 0 +} + +// LastID returns the last ID generated by the generator. +func LastID(g IDGenerator) uint64 { + if fl, ok := g.(*Flake); ok { + return fl.lastID + } + return 0 +} diff --git a/flake/flake_test.go b/flake/flake_test.go new file mode 100644 index 0000000..d592cd6 --- /dev/null +++ b/flake/flake_test.go @@ -0,0 +1,302 @@ +package flake + +import ( + "fmt" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" + + mapset "github.com/deckarep/golang-set" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + sf *Flake + startTime int64 + machineID uint64 +) + +func init() { + var st Settings + st.StartTime = NowFunc() + + sf = NewIDGenerator(st).(*Flake) + + startTime = toFlakeTime(st.StartTime) + machineID = uint64(sf.machineID) +} + +func TestFlakeOnce(t *testing.T) { + assert.Equal(t, int64(1e6), FlakeTimeUnit) + + id0 := sf.NextID() + parts := Decompose(id0) + fTime := parts["time"] + start := fromFlakeTime(sf.startTime) + t.Logf("start: %s, parts: %+v", start.Format(time.RFC3339), parts) + + tim := fromFlakeTime(int64(fTime)) + assert.True(t, tim.Before(start)) + + sleepTime := uint64(5 * FlakeTimeUnit / int64(time.Millisecond)) + time.Sleep(time.Duration(sleepTime) * time.Duration(FlakeTimeUnit)) + + id := sf.NextID() + parts = Decompose(id) + t.Logf("parts: %+v", parts) + + actualMSB := parts["msb"] + assert.Equal(t, uint64(0), actualMSB) + + actualTime := parts["time"] + assert.LessOrEqual(t, sleepTime, actualTime+2, "unexpected time: %d", actualTime) + assert.LessOrEqual(t, actualTime, sleepTime+2, "unexpected time: %d", actualTime) + + actualSequence := parts["sequence"] + assert.Equal(t, uint64(0), actualSequence) + + actualMachineID := parts["machine_id"] + assert.Equal(t, uint64(machineID), uint64(actualMachineID)) +} + +func currentTime() int64 { + return toFlakeTime(NowFunc()) +} + +func TestFlakeFor10Sec(t *testing.T) { + var numID uint32 + var lastID uint64 + var maxSequence uint64 + + initial := currentTime() + current := initial + const maxTime = 10 * int64(time.Second) / FlakeTimeUnit + for current-initial < maxTime { + id := sf.NextID() + parts := Decompose(id) + numID++ + + require.Greater(t, id, lastID, "duplicated id") + lastID = id + + current = currentTime() + + actualMSB := parts["msb"] + require.Equal(t, uint64(0), actualMSB) + + actualTime := int64(parts["time"]) + overtime := startTime + actualTime - current + require.LessOrEqual(t, overtime, int64(2), "unexpected overtime", overtime) + + actualSequence := parts["sequence"] + if maxSequence < actualSequence { + maxSequence = actualSequence + } + + actualMachineID := parts["machine_id"] + require.Equal(t, uint64(machineID), uint64(actualMachineID)) + } + + assert.GreaterOrEqualf(t, maxSequence, uint64(1< gopkg.in/yaml.v2 v2.2.8 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..b00c333 --- /dev/null +++ b/go.sum @@ -0,0 +1,80 @@ +github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/deckarep/golang-set v1.8.0 h1:sk9/l/KqpunDwP7pSjUg0keiOOLEnOBHzykLrsPppp4= +github.com/deckarep/golang-set v1.8.0/go.mod h1:5nI87KwE7wgsBU1F4GKAw2Qod7p5kyS383rP6+o6qqo= +github.com/effective-security/xlog v0.6.0 h1:n1MzotZSHZ1+XMO3CQcc7xEO8y+0BMbNEHA0SsTLs/8= +github.com/effective-security/xlog v0.6.0/go.mod h1:ZDG9qha5Mt18D5DNd/8WhHXzw3f9JeOUVcXHYvWu3/U= +github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/oleiade/reflections v1.0.1 h1:D1XO3LVEYroYskEsoSiGItp9RUxG6jWnCVvrqH0HHQM= +github.com/oleiade/reflections v1.0.1/go.mod h1:rdFxbxq4QXVZWj0F+e9jqjDkc7dbp97vkRixKo2JR60= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= +github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +go.uber.org/atomic v1.5.0 h1:OI5t8sDa1Or+q8AeE+yKeB/SDYioSHAgcVljj9JIETY= +go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/config v1.4.0 h1:upnMPpMm6WlbZtXoasNkK4f0FhxwS+W4Iqz5oNznehQ= +go.uber.org/config v1.4.0/go.mod h1:aCyrMHmUAc/s2h9sv1koP84M9ZF/4K+g2oleyESO/Ig= +go.uber.org/multierr v1.4.0 h1:f3WCSC2KzAcBXGATIxAB1E2XuCpNU255wNKZ505qi3E= +go.uber.org/multierr v1.4.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= +go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee h1:0mgffUl7nfd+FpvXMVz4IDEaUSmT1ysygQC7qYo7sG4= +go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= +golang.org/x/mod v0.9.0 h1:KENHtAZL2y3NLMYZeHY9DW8HW8V+kQyJsY/V9JlKvCs= +golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f h1:v4INt8xihDGvnrfjMDVXGxw9wrfxYyCjk0KbXjhR55s= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191104232314-dc038396d1f0/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= +honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= diff --git a/guid/guid.go b/guid/guid.go new file mode 100644 index 0000000..75117c4 --- /dev/null +++ b/guid/guid.go @@ -0,0 +1,17 @@ +package guid + +import ( + "crypto/rand" + "fmt" +) + +// MustCreate returns GUID +func MustCreate() string { + b := make([]byte, 16) + _, err := rand.Read(b) + if err != nil { + panic(err) + } + + return fmt.Sprintf("%X-%X-%X-%X-%X", b[0:4], b[4:6], b[6:8], b[8:10], b[10:]) +} diff --git a/guid/guid_test.go b/guid/guid_test.go new file mode 100644 index 0000000..3c53941 --- /dev/null +++ b/guid/guid_test.go @@ -0,0 +1,15 @@ +package guid_test + +import ( + "testing" + + "github.com/effective-security/x/guid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_Guid(t *testing.T) { + g := guid.MustCreate() + require.NotEmpty(t, g) + assert.Equal(t, 36, len(g)) +} diff --git a/math/compare.go b/math/compare.go new file mode 100644 index 0000000..f599b8a --- /dev/null +++ b/math/compare.go @@ -0,0 +1,54 @@ +// Package math implements basic operations on various types +package math + +import ( + "time" +) + +// Max returns the larger of the 2 supplied int's +func Max(a, b int) int { + if a > b { + return a + } + return b +} + +// Min returns the smaller of the 2 supplied int's +func Min(a, b int) int { + if a < b { + return a + } + return b +} + +// MaxUint64 returns the larger of the 2 supplied uint64's +func MaxUint64(a, b uint64) uint64 { + if a > b { + return a + } + return b +} + +// MinUint64 returns the smaller of the 2 supplied uint64's +func MinUint64(a, b uint64) uint64 { + if a < b { + return a + } + return b +} + +// MinDuration returns the smaller of teh 2 supplied Durations +func MinDuration(a, b time.Duration) time.Duration { + if a < b { + return a + } + return b +} + +// MaxDuration returns the larger of the 2 supplied Durations +func MaxDuration(a, b time.Duration) time.Duration { + if a > b { + return a + } + return b +} diff --git a/math/compare_test.go b/math/compare_test.go new file mode 100644 index 0000000..7387d8a --- /dev/null +++ b/math/compare_test.go @@ -0,0 +1,112 @@ +package math + +import ( + "math" + "testing" + "time" +) + +func TestMath_Max(t *testing.T) { + vals := [][]int{ + {0, 0, 0}, + {0, 1, 1}, + {1, 0, 1}, + {42, 0, 42}, + {999999, 999998, 999999}, + {-1, 0, 0}, + {0, -1, 0}, + {1, -1, 1}, + } + + for _, v := range vals { + r := Max(v[0], v[1]) + if r != v[2] { + t.Errorf("Max(%v,%v) returned %v, expecting %v", v[0], v[1], r, v[2]) + } + } +} + +func TestMath_Min(t *testing.T) { + vals := [][]int{ + {0, 0, 0}, + {0, 1, 0}, + {1, 0, 0}, + {42, 0, 0}, + {999999, 999998, 999998}, + {-1, 0, -1}, + {0, -1, -1}, + {1, -1, -1}, + } + + for _, v := range vals { + r := Min(v[0], v[1]) + if r != v[2] { + t.Errorf("Min(%v,%v) returned %v, expecting %v", v[0], v[1], r, v[2]) + } + } +} + +func TestMath_MaxUint64(t *testing.T) { + vals := [][]uint64{ + {0, 0, 0}, + {0, 1, 1}, + {1, 0, 1}, + {42, 0, 42}, + {999999, 999998, 999999}, + {math.MaxUint64, 999998, math.MaxUint64}, + } + + for _, v := range vals { + r := MaxUint64(v[0], v[1]) + if r != v[2] { + t.Errorf("MaxUint64(%v,%v) returned %v, expecting %v", v[0], v[1], r, v[2]) + } + } +} + +func TestMath_MinUint64(t *testing.T) { + vals := [][]uint64{ + {0, 0, 0}, + {0, 1, 0}, + {1, 0, 0}, + {42, 0, 0}, + {math.MaxUint64, 999998, 999998}, + } + + for _, v := range vals { + r := MinUint64(v[0], v[1]) + if r != v[2] { + t.Errorf("MinUint64(%v,%v) returned %v, expecting %v", v[0], v[1], r, v[2]) + } + } +} + +func TestMath_MinDuration(t *testing.T) { + vals := [][]time.Duration{ + {time.Second, time.Minute, time.Second}, + {0, 1, 0}, + {time.Minute, time.Second, time.Second}, + {time.Second, time.Second, time.Second}, + } + for _, v := range vals { + r := MinDuration(v[0], v[1]) + if r != v[2] { + t.Errorf("MinDuration(%v,%v) returned %v, expecting %v", v[0], v[1], r, v[2]) + } + } +} + +func TestMath_MaxDuration(t *testing.T) { + vals := [][]time.Duration{ + {time.Second, time.Minute, time.Minute}, + {0, 1, 1}, + {time.Minute, time.Second, time.Minute}, + {time.Second, time.Second, time.Second}, + } + for _, v := range vals { + r := MaxDuration(v[0], v[1]) + if r != v[2] { + t.Errorf("MaxDuration(%v,%v) returned %v, expecting %v", v[0], v[1], r, v[2]) + } + } +} diff --git a/netutil/freeport.go b/netutil/freeport.go new file mode 100644 index 0000000..1b1bfaf --- /dev/null +++ b/netutil/freeport.go @@ -0,0 +1,45 @@ +package netutil + +import ( + "net" + + "github.com/effective-security/xlog" + "github.com/pkg/errors" +) + +var logger = xlog.NewPackageLogger("github.com/effective-security/x", "netutil") + +// FindFreePort returns a free port found on a host +func FindFreePort(host string, maxAttempts int) (int, error) { + if host == "" { + host = "localhost" + } + if maxAttempts < 1 { + maxAttempts = 1 + } + + for i := 0; i < maxAttempts; i++ { + addr, err := net.ResolveTCPAddr("tcp", net.JoinHostPort(host, "0")) + if err != nil { + logger.KV(xlog.ERROR, + "reason", "unable to resolve tcp addr", + "err", err.Error()) + continue + } + l, err := net.ListenTCP("tcp", addr) + if err != nil { + l.Close() + logger.KV(xlog.ERROR, + "reason", "unable to listen", + "addr", addr, + "err", err.Error()) + continue + } + + port := l.Addr().(*net.TCPAddr).Port + l.Close() + return port, nil + } + + return 0, errors.Errorf("no free port found") +} diff --git a/netutil/freeport_test.go b/netutil/freeport_test.go new file mode 100644 index 0000000..5695e79 --- /dev/null +++ b/netutil/freeport_test.go @@ -0,0 +1,15 @@ +package netutil_test + +import ( + "testing" + + "github.com/effective-security/x/netutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_FindFreePort(t *testing.T) { + p, err := netutil.FindFreePort("", 0) + require.NoError(t, err) + assert.NotEmpty(t, p) +} diff --git a/netutil/localip.go b/netutil/localip.go new file mode 100644 index 0000000..67ea3ff --- /dev/null +++ b/netutil/localip.go @@ -0,0 +1,83 @@ +package netutil + +import ( + "net" + "time" + + "github.com/pkg/errors" +) + +var cidrs []*net.IPNet + +func init() { + maxCidrBlocks := []string{ + "127.0.0.1/8", // localhost + "10.0.0.0/8", // 24-bit block + "172.16.0.0/12", // 20-bit block + "192.168.0.0/16", // 16-bit block + "169.254.0.0/16", // link local address + "::1/128", // localhost IPv6 + "fc00::/7", // unique local address IPv6 + "fe80::/10", // link local address IPv6 + } + + cidrs = make([]*net.IPNet, len(maxCidrBlocks)) + for i, maxCidrBlock := range maxCidrBlocks { + _, cidr, _ := net.ParseCIDR(maxCidrBlock) + cidrs[i] = cidr + } +} + +// IsPrivateAddress works by checking if the address is under private CIDR blocks. +// List of private CIDR blocks can be seen on : +// +// https://en.wikipedia.org/wiki/Private_network +// +// https://en.wikipedia.org/wiki/Link-local_address +func IsPrivateAddress(address string) (bool, error) { + ipAddress := net.ParseIP(address) + if ipAddress == nil { + return false, errors.New("address is not valid") + } + + for i := range cidrs { + if cidrs[i].Contains(ipAddress) { + return true, nil + } + } + + return false, nil +} + +// GetLocalIP returns the non loopback local IP of the host +func GetLocalIP() (string, error) { + addrs, err := net.InterfaceAddrs() + if err != nil { + return "", errors.WithStack(err) + } + for _, address := range addrs { + // check the address type and if it is not a loopback the display it + if ipnet, ok := address.(*net.IPNet); ok && !ipnet.IP.IsLoopback() { + if ipnet.IP.To4() != nil { + return ipnet.IP.String(), nil + } + } + } + return "", errors.New("unable to resolve local IP address") +} + +// WaitForNetwork will wait until the local IP is available or timeout ocurred +func WaitForNetwork(d time.Duration) (ipaddr string, err error) { + ipaddr, err = GetLocalIP() + if err != nil { + cutoff := time.Now().Add(d) + for cutoff.After(time.Now()) { + ipaddr, err = GetLocalIP() + if err == nil { + break + } + time.Sleep(time.Second) + } + } + return +} diff --git a/netutil/localip_test.go b/netutil/localip_test.go new file mode 100644 index 0000000..b479a78 --- /dev/null +++ b/netutil/localip_test.go @@ -0,0 +1,53 @@ +package netutil_test + +import ( + "testing" + "time" + + "github.com/effective-security/x/netutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_LocalIP(t *testing.T) { + ip, err := netutil.GetLocalIP() + require.NoError(t, err, "failed to resolve local IP address") + assert.NotEmpty(t, ip) + + _, err = netutil.IsPrivateAddress(ip) + require.NoError(t, err) +} + +func TestIsPrivateAddr(t *testing.T) { + testData := map[string]bool{ + "127.0.0.0": true, + "10.0.0.0": true, + "169.254.0.0": true, + "192.168.0.0": true, + "::1": true, + "fc00::": true, + + "172.15.0.0": false, + "172.16.0.0": true, + "172.31.0.0": true, + "172.32.0.0": false, + + "147.12.56.11": false, + } + + for addr, isLocal := range testData { + isPrivate, err := netutil.IsPrivateAddress(addr) + require.NoError(t, err) + assert.Equal(t, isLocal, isPrivate, addr) + } +} + +func Test_WaitForNetwork(t *testing.T) { + ip, err := netutil.WaitForNetwork(0) + require.NoError(t, err) + assert.NotEmpty(t, ip) + + ip, err = netutil.WaitForNetwork(time.Second) + require.NoError(t, err) + assert.NotEmpty(t, ip) +} diff --git a/netutil/net.go b/netutil/net.go new file mode 100644 index 0000000..8e4568a --- /dev/null +++ b/netutil/net.go @@ -0,0 +1,54 @@ +package netutil + +import ( + "fmt" + "net" + "os" + "syscall" +) + +// namedAddress represents a TCP Network address based on host name rather +// than IP address. it implements the net.Addr interface, and is used to +// define the node identity to raft. +type namedAddress struct { + network string + host string + port uint16 +} + +// newNamedAddress verifies that the supplied net/host name is resolvable and +// if so returns a namedAddress that represents it, otherwise the resolve +// error is returned +func newNamedAddress(network, host string, port uint16) (*namedAddress, error) { + na := &namedAddress{network: network, host: host, port: port} + _, err := na.Resolve() + if err != nil { + na = nil + } + return na, err +} + +// Network() is part of the net.Addr interface +func (a *namedAddress) Network() string { + return a.network +} + +// String() is part of the net.Addr interface +func (a *namedAddress) String() string { + return fmt.Sprintf("%v:%v", a.host, a.port) +} + +// Resolve() resolves this named address to a specific TCP Address. +func (a *namedAddress) Resolve() (*net.TCPAddr, error) { + return net.ResolveTCPAddr(a.Network(), a.String()) +} + +// IsAddrInUse checks whether the given error indicates "address in use" +func IsAddrInUse(err error) bool { + if err, ok := err.(*net.OpError); ok { + if err, ok := err.Err.(*os.SyscallError); ok { + return err.Err == syscall.EADDRINUSE + } + } + return false +} diff --git a/netutil/net_test.go b/netutil/net_test.go new file mode 100644 index 0000000..8aaaffa --- /dev/null +++ b/netutil/net_test.go @@ -0,0 +1,51 @@ +package netutil + +import ( + "net" + "testing" + + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNamedAddress_Network(t *testing.T) { + a := namedAddress{network: "tcp4", host: "localhost", port: 8080} + assert.Equal(t, "tcp4", a.Network(), "Network() should return tcp4") +} + +func TestNamedAddress_isNetAddr(t *testing.T) { + // this will fail to compile if namedAddress doesn't implement the net.Addr interface + var _ net.Addr = &namedAddress{} +} + +func TestNamedAddress_String(t *testing.T) { + st := func(h string, p uint16, expected string) { + a := namedAddress{network: "tcp", host: h, port: p} + assert.Equal(t, expected, a.String(), "Unexpected result for String() for %+v", a) + } + st("localhost", 8080, "localhost:8080") + st("ekspand.com", 5001, "ekspand.com:5001") + st("", 5001, ":5001") +} + +func TestNamedAddress_Resolve(t *testing.T) { + a := namedAddress{network: "tcp", host: "localhost", port: 7070} + addr, err := a.Resolve() + require.NoError(t, err, "Error calling resolve %v", a) + assert.Equal(t, 7070, addr.Port, "Wrong port resolved") +} + +func TestNamedAddress_New(t *testing.T) { + a, err := newNamedAddress("tcp", "localhost", 7070) + require.NoError(t, err, "Error creating newNamedAddress") + assert.Equal(t, "tcp", a.Network(), "Unexpected network in namedAddress") + assert.Equal(t, "localhost:7070", a.String(), "Unexpected outout for String()") + + _, err = newNamedAddress("bob", "", 0) + assert.Error(t, err) +} + +func TestIsAddrInUse(t *testing.T) { + assert.False(t, IsAddrInUse(errors.Errorf("not"))) +} diff --git a/netutil/nodeinfo.go b/netutil/nodeinfo.go new file mode 100644 index 0000000..f6946e5 --- /dev/null +++ b/netutil/nodeinfo.go @@ -0,0 +1,62 @@ +package netutil + +import ( + "os" + + "github.com/pkg/errors" +) + +// NodeInfo is an interface to provide host and IP address for the node in the cluster. +type NodeInfo interface { + HostName() string + LocalIP() string + NodeName() string +} + +// GetNodeNameFn is an interface to return an application specific +// node name from the host name. +type GetNodeNameFn func(hostname string) string + +type localNodeInfo struct { + hostname string + ipAddr string + // nodename is derived from the host name, + // for example, by trimming www. prefix and .com suffix + nodename string +} + +// NewNodeInfo constructs node info using provided extractor function. +// If extractor is nil, then host name is used. +func NewNodeInfo(extractor GetNodeNameFn) (NodeInfo, error) { + var err error + localInfo := new(localNodeInfo) + if localInfo.hostname, err = os.Hostname(); err != nil { + return nil, errors.WithMessagef(err, "unable to determine hostname") + } + + if localInfo.ipAddr, err = GetLocalIP(); err != nil { + return nil, errors.WithMessagef(err, "unable to determine local IP address") + } + + localInfo.nodename = localInfo.hostname + if extractor != nil { + localInfo.nodename = extractor(localInfo.hostname) + } + + return localInfo, nil +} + +// HostName return a host name +func (l *localNodeInfo) HostName() string { + return l.hostname +} + +// LocalIP return local IP address +func (l *localNodeInfo) LocalIP() string { + return l.ipAddr +} + +// NodeName returns node name derived from hostname +func (l *localNodeInfo) NodeName() string { + return l.nodename +} diff --git a/netutil/nodeinfo_test.go b/netutil/nodeinfo_test.go new file mode 100644 index 0000000..dc5f0dc --- /dev/null +++ b/netutil/nodeinfo_test.go @@ -0,0 +1,26 @@ +package netutil + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_NodeInfo(t *testing.T) { + n, err := NewNodeInfo(nil) + require.NoError(t, err) + require.NotNil(t, n) + assert.NotEmpty(t, n.HostName()) + assert.NotEmpty(t, n.LocalIP()) + assert.Equal(t, n.HostName(), n.NodeName()) + + n, err = NewNodeInfo(func(hostname string) string { + return "nodename" + }) + require.NoError(t, err) + require.NotNil(t, n) + assert.NotEmpty(t, n.HostName()) + assert.NotEmpty(t, n.LocalIP()) + assert.Equal(t, "nodename", n.NodeName()) +} diff --git a/netutil/urls.go b/netutil/urls.go new file mode 100644 index 0000000..2038331 --- /dev/null +++ b/netutil/urls.go @@ -0,0 +1,41 @@ +package netutil + +import ( + "net/url" + "strings" + + "github.com/pkg/errors" +) + +// ParseURLs creates a list of URLs from lists of hosts +func ParseURLs(list []string) ([]*url.URL, error) { + urls := make([]*url.URL, len(list)) + for i, host := range list { + u, err := url.Parse(strings.TrimSpace(host)) + if err != nil { + return nil, errors.WithStack(err) + } + urls[i] = u + } + + return urls, nil +} + +// ParseURLsFromString creates a list of URLs from a coma-separated lists of hosts +func ParseURLsFromString(hosts string) ([]*url.URL, error) { + hosts = strings.TrimSpace(hosts) + if len(hosts) == 0 { + return nil, nil + } + list := strings.Split(hosts, ",") + return ParseURLs(list) +} + +// JoinURLs returns coma-separated lists of URLs in string format +func JoinURLs(list []*url.URL) string { + strs := make([]string, len(list)) + for i, url := range list { + strs[i] = url.String() + } + return strings.Join(strs, ",") +} diff --git a/netutil/urls_test.go b/netutil/urls_test.go new file mode 100644 index 0000000..5d25c39 --- /dev/null +++ b/netutil/urls_test.go @@ -0,0 +1,121 @@ +package netutil + +import ( + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ParseURL(t *testing.T) { + tcases := []struct { + in string + scheme string + host string + port string + path string + err string + }{ + {"localhost", "", "", "", "localhost", ""}, + {"localhost.com:8080", "localhost.com", "", "", "", ""}, + } + + for _, tc := range tcases { + t.Run(tc.in, func(t *testing.T) { + u, err := url.Parse(tc.in) + if tc.err == "" { + require.NoError(t, err) + assert.Equal(t, tc.scheme, u.Scheme) + assert.Equal(t, tc.host, u.Host) + assert.Equal(t, tc.port, u.Port()) + } else { + assert.Error(t, err) + } + }) + } +} + +func Test_ParseURLs(t *testing.T) { + tcases := []struct { + tname string + hosts []string + err string + }{ + {"from nil", nil, ""}, + {"from_empty", nil, ""}, + {"valid", []string{"localhost", "123.74.56.18", "ekspand.com", "ekspand.com:80"}, ""}, + {"valid with path", []string{"../dir/"}, ""}, + {"valid1 with page", []string{"foo.html"}, ""}, + {"invalid with ip", []string{"http://192.168.0.%31/"}, "error"}, + {"invalid with code", []string{"http://[fe80::%231]:8080/"}, "error"}, + } + + for _, tc := range tcases { + t.Run(tc.tname, func(t *testing.T) { + l, err := ParseURLs(tc.hosts) + if tc.err == "" { + require.NoError(t, err) + assert.Equal(t, len(tc.hosts), len(l)) + } else { + if !assert.Error(t, err) { + for _, u := range l { + t.Logf("parsed url: %s", u.String()) + } + } + } + }) + } +} + +func Test_ParseURLsFromString(t *testing.T) { + tcases := []struct { + tname string + hosts string + exp int + err string + }{ + {"from_empty", "", 0, ""}, + {"valid", "localhost,123.74.56.18,ekspand.com", 3, ""}, + {"valid with path", "../dir/,../dir2/", 2, ""}, + {"valid1 with page", "foo.html,foo.html", 2, ""}, + {"invalid with ip", "http://192.168.0.%31/", 0, "error"}, + {"invalid with code", "http://[fe80::%231]:8080/", 0, "error"}, + } + + for _, tc := range tcases { + t.Run(tc.tname, func(t *testing.T) { + l, err := ParseURLsFromString(tc.hosts) + if tc.err == "" { + require.NoError(t, err) + assert.Equal(t, tc.exp, len(l)) + } else { + if !assert.Error(t, err) { + for _, u := range l { + t.Logf("parsed url: %s", u.String()) + } + } + } + }) + } +} + +func Test_JoinURLs(t *testing.T) { + tcases := []struct { + in string + out string + }{ + {in: "localhost,123.74.56.18,http://ekspand.com", out: "localhost,123.74.56.18,http://ekspand.com"}, + {in: "https://123.74.56.18,unix://localhost:123456", out: "https://123.74.56.18,unix://localhost:123456"}, + } + + for _, tc := range tcases { + t.Run(tc.in, func(t *testing.T) { + l, err := ParseURLsFromString(tc.in) + require.NoError(t, err) + + str := JoinURLs(l) + assert.Equal(t, tc.out, str) + }) + } +} diff --git a/slices/slices.go b/slices/slices.go new file mode 100644 index 0000000..92dcf38 --- /dev/null +++ b/slices/slices.go @@ -0,0 +1,254 @@ +// Package slices provides additional slice functions on common slice types +package slices + +import ( + "strings" +) + +// ByteSlicesEqual returns true only if the contents of the 2 slices are the same +func ByteSlicesEqual(a, b []byte) bool { + if len(a) != len(b) { + return false + } + for idx, v := range a { + if v != b[idx] { + return false + } + } + return true +} + +// StringSlicesEqual returns true only if the contents of the 2 slices are the same +func StringSlicesEqual(a, b []string) bool { + if len(a) != len(b) { + return false + } + for idx, v := range a { + if v != b[idx] { + return false + } + } + return true +} + +// ContainsString returns true if the items slice contains a value equal to item +// Note that this can end up traversing the entire slice, and so is only really +// suitable for small slices, for larger data sets, consider using a map instead. +func ContainsString(items []string, item string) bool { + for _, x := range items { + if x == item { + return true + } + } + return false +} + +// StringContainsOneOf returns true if one of items slice is a substring of specified value. +func StringContainsOneOf(item string, items []string) bool { + for _, x := range items { + if strings.Contains(item, x) { + return true + } + } + return false +} + +// StringStartsWithOneOf returns true if one of items slice is a prefix of specified value. +func StringStartsWithOneOf(value string, items []string) bool { + for _, x := range items { + if strings.HasPrefix(value, x) { + return true + } + } + return false +} + +// ContainsStringEqualFold returns true if the items slice contains a value equal to item +// ignoring case [i.e. using EqualFold] +// Note that this can end up traversing the entire slice, and so is only really +// suitable for small slices, for larger data sets, consider using a map instead. +func ContainsStringEqualFold(items []string, item string) bool { + for _, x := range items { + if strings.EqualFold(x, item) { + return true + } + } + return false +} + +// CloneStrings will return an independnt copy of the src slice, it preserves +// the distinction between a nil value and an empty slice. +func CloneStrings(src []string) []string { + if src != nil { + c := make([]string, len(src)) + copy(c, src) + return c + } + return nil +} + +// NvlString returns the first string from the supplied list that has len() > 0 +// or "" if all the strings are empty +func NvlString(items ...string) string { + for _, x := range items { + if len(x) > 0 { + return x + } + } + return "" +} + +// Prefixed returns a new slice of strings with each input item prefixed by the supplied prefix +// e.g. Prefixed("foo", []string{"bar","bob"}) would return []string{"foobar", "foobob"} +// the input slice is not modified. +func Prefixed(prefix string, items []string) []string { + return MapStringSlice(items, func(in string) string { + return prefix + in + }) +} + +// Suffixed returns a new slice of strings which each input item suffixed by the supplied suffix +// e.g. Suffixed("foo", []string{"bar","bob"}) would return []string{"barfoo", "bobfoo"} +// the input slice is not modified +func Suffixed(suffix string, items []string) []string { + return MapStringSlice(items, func(in string) string { + return in + suffix + }) +} + +// Quoted returns a new slice of strings where each input stream has been wrapped in quotes +func Quoted(items []string) []string { + return MapStringSlice(items, func(in string) string { + return `"` + in + `"` + }) +} + +// MapStringSlice returns a new slices of strings that is the result of applies mapFn +// to each string in the input slice. +func MapStringSlice(items []string, mapFn func(in string) string) []string { + res := make([]string, len(items)) + for idx, v := range items { + res[idx] = mapFn(v) + } + return res +} + +// BoolSlicesEqual returns true only if the contents of the 2 slices are the same +func BoolSlicesEqual(a, b []bool) bool { + if len(a) != len(b) { + return false + } + for idx, v := range a { + if v != b[idx] { + return false + } + } + return true +} + +// StringsCoalesce returns the first non-empty string value +func StringsCoalesce(str ...string) string { + for _, s := range str { + if len(s) > 0 { + return s + } + } + return "" +} + +// StringUpto returns the beginning of the string up to `max` +func StringUpto(str string, max int) string { + if len(str) > max { + return str[:max] + } + return str +} + +// Int64SlicesEqual returns true only if the contents of the 2 slices are the same +func Int64SlicesEqual(a, b []int64) bool { + if len(a) != len(b) { + return false + } + for idx, v := range a { + if v != b[idx] { + return false + } + } + return true +} + +// Uint64SlicesEqual returns true only if the contents of the 2 slices are the same +func Uint64SlicesEqual(a, b []uint64) bool { + if len(a) != len(b) { + return false + } + for idx, v := range a { + if v != b[idx] { + return false + } + } + return true +} + +// Float64SlicesEqual returns true only if the contents of the 2 slices are the same +func Float64SlicesEqual(a, b []float64) bool { + if len(a) != len(b) { + return false + } + for idx, v := range a { + if v != b[idx] { + return false + } + } + return true +} + +// UniqueStrings removes duplicates from the given list +func UniqueStrings(dups []string) []string { + if len(dups) < 2 { + return dups + } + keys := make(map[string]bool) + list := []string{} + + for _, entry := range dups { + if _, value := keys[entry]; !value { + keys[entry] = true + list = append(list, entry) + } + } + return list +} + +// NumbersCoalesce returns the first value from the supplied list that is not 0, or 0 if there are no values that are not zero +func NumbersCoalesce[T ~int | ~int32 | ~uint | ~uint32 | ~int64 | ~uint64](items ...T) T { + for _, x := range items { + if x != 0 { + return x + } + } + return 0 +} + +// Measurable interface +type Measurable[T any] interface { + ~string | ~[]string | ~[]T +} + +// Coalesce returns the first non-empty value +func Coalesce[M Measurable[any]](args ...M) M { + for _, s := range args { + if len(s) > 0 { + return s + } + } + return args[0] +} + +// Select returns a if cond is true, otherwise b +func Select[T any](cond bool, a, b T) T { + if cond { + return a + } + return b +} diff --git a/slices/slices_test.go b/slices/slices_test.go new file mode 100644 index 0000000..a1ccf98 --- /dev/null +++ b/slices/slices_test.go @@ -0,0 +1,408 @@ +package slices + +import ( + "fmt" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSlices_NvlString(t *testing.T) { + v := func(exp string, items ...string) { + act := NvlString(items...) + if act != exp { + t.Errorf("Expecting NvlString(%v) to return %s, but got %s", items, exp, act) + } + } + v("") + v("", "") + v("", "", "") + v("a", "a") + v("a", "a", "") + v("a", "", "", "a") + v("b", "", "b", "a") +} + +func TestSlices_CloneStrings(t *testing.T) { + c := CloneStrings(nil) + if c != nil { + t.Errorf("CloneStrings with a nil src should return nil, but returned %#v", c) + } + c = CloneStrings([]string{}) + if c == nil || len(c) != 0 { + t.Errorf("CloneStrings with a non-nil, but empty slice, should return a new empty slice, but got %#v", c) + } + s := []string{"a", "b", "c"} + c = CloneStrings(s) + if !reflect.DeepEqual(s, c) { + t.Errorf("CloneString() returned different contents to source, got %+v, expecting %+v", c, s) + } + s[0] = "x" + if c[0] != "a" { + t.Errorf("CloneString() didn't return a Clone, it was mutated by mutating the source") + } +} + +func TestSlices_ContainsString(t *testing.T) { + s := []string{"a", "b", "c", "foo", "bar", "qux"} + missing := []string{"bob", "quxx"} + testSlicesContains(t, s, missing, "q", func(items interface{}, item interface{}) bool { + return ContainsString(items.([]string), item.(string)) + }) + if ContainsString(nil, "") { + t.Errorf("a nil slice shouldn't contain anything!") + } +} + +func TestSlices_ContainsStringEqualFold(t *testing.T) { + src := []string{"one", "TWO", "Three"} + tests := []string{"ONE", "One", "two", "three"} + m := []string{"", "oned", "Four"} + for _, item := range append(src, tests...) { + if !ContainsStringEqualFold(src, item) { + t.Errorf("Expecting to find %q in %v, but didn't", item, src) + } + } + for _, item := range m { + if ContainsStringEqualFold(src, item) { + t.Errorf("Not expecting to find %q in %v, but did", item, src) + } + } +} + +func TestSlices_StringContainsOneOf(t *testing.T) { + tcases := []struct { + str string + slices []string + exp bool + }{ + {"Daniel", []string{"foo", "bar"}, false}, + {"Daniel", []string{"foo", "el"}, true}, + {"Daniel", []string{"foo", "da"}, false}, + {"Daniel", []string{"foo", "Dan"}, true}, + } + for idx, tc := range tcases { + res := StringContainsOneOf(tc.str, tc.slices) + if res != tc.exp { + t.Errorf("case %d failed", idx) + } + } +} + +func TestSlices_StringStartsWithOneOf(t *testing.T) { + tcases := []struct { + str string + slices []string + exp bool + }{ + {"Daniel", []string{"foo", "bar"}, false}, + {"foo_Daniel", []string{"foo", "el"}, true}, + {"daniel", []string{"foo", "da"}, true}, + {"Daniel", []string{"foo", "Dan"}, true}, + } + for idx, tc := range tcases { + res := StringStartsWithOneOf(tc.str, tc.slices) + if res != tc.exp { + t.Errorf("case %d failed", idx) + } + } +} + +func testSlicesContains(t *testing.T, items interface{}, missing interface{}, newItem interface{}, containsFunc func(items interface{}, item interface{}) bool) { + vm := reflect.ValueOf(missing) + for i := 0; i < vm.Len(); i++ { + if containsFunc(items, vm.Index(i).Interface()) { + t.Errorf("Item %v wasn't in items slice, but contains said it was!", vm.Index(i)) + } + } + vi := reflect.ValueOf(items) + for i := 0; i < vi.Len(); i++ { + if !containsFunc(items, vi.Index(i).Interface()) { + t.Errorf("Item %v is at index %d in slice, but contains said it wasn't in the slice", vi.Index(i), i) + } + } + vi = reflect.Append(vi, reflect.ValueOf(newItem)) + if !containsFunc(vi.Interface(), newItem) { + t.Errorf("Item %v was added to slice, but contains didn't spot it", newItem) + } + if containsFunc(vi.Slice(1, vi.Len()-1).Interface(), vi.Index(0).Interface()) { + t.Errorf("Item %v wasn't in the modified slice, but contains said it was", vi.Index(0)) + } +} + +func TestSlices_ByteSlicesEqual(t *testing.T) { + bytes := []interface{}{ + []byte{}, + []byte{1}, + []byte{1, 2, 3}, + []byte{1, 2, 3, 4}, + []byte{2, 2, 3, 4}, + []byte{1, 2, 3, 5}, + } + testSlicesEquals(t, "Byte", bytes, bytes[2], []byte{1, 2, 3}, func(x, y interface{}) bool { + return ByteSlicesEqual(x.([]byte), y.([]byte)) + }) + if ByteSlicesEqual(nil, []byte{1}) || ByteSlicesEqual([]byte{1}, nil) { + t.Errorf("ByteSliceEqual for a nil slice shouldn't return true when the other slice has items in it") + } + if !ByteSlicesEqual(nil, nil) || !ByteSlicesEqual(nil, []byte{}) { + t.Errorf("ByteSlicesEquals for a nil & empty slice should return true") + } +} + +func TestSlices_StringSlicesEqual(t *testing.T) { + strings := []interface{}{ + []string{}, + []string{""}, + []string{"aa"}, + []string{"aa", "bb"}, + []string{"aa", "bb", "cc"}, + []string{"bb", "bb", "cc"}, + []string{"aa", "bb", "bb"}, + } + testSlicesEquals(t, "String", strings, []string{"aa", "bb", "cc"}, strings[4], func(x, y interface{}) bool { + return StringSlicesEqual(x.([]string), y.([]string)) + }) + if StringSlicesEqual(nil, []string{"a"}) || StringSlicesEqual([]string{"a"}, nil) { + t.Errorf("StringSlicesEqual for nil and a slice with an item in it should return false") + } + if !StringSlicesEqual(nil, nil) || !StringSlicesEqual(nil, []string{}) { + t.Errorf("StringSlicesEqual for a nil and empty slice should return true") + } +} + +func assertStringSlicesEqual(t *testing.T, preamble string, exp []string, act []string) { + if len(act) != len(exp) { + t.Errorf("%s: expected to get %d items, but got %d", preamble, len(exp), len(act)) + } else { + for i, a := range act { + if a != exp[i] { + t.Errorf("%s: at index %d expected to get %q, but got %q", preamble, i, exp[i], a) + } + } + } +} + +func TestSlices_Quoted(t *testing.T) { + c := func(in, exp []string) { + res := Quoted(in) + assertStringSlicesEqual(t, fmt.Sprintf("Quoted(%v)", in), exp, res) + } + c([]string{}, []string{}) + c([]string{"bob "}, []string{`"bob "`}) + c([]string{"b", "a", "c"}, []string{`"b"`, `"a"`, `"c"`}) +} + +func TestSlices_Prefixed(t *testing.T) { + c := func(p string, items []string, exp []string) { + act := Prefixed(p, items) + assertStringSlicesEqual(t, fmt.Sprintf("Prefixed(%v,%v)", p, items), exp, act) + } + c("bob", []string{}, []string{}) + c("bob", []string{"alice"}, []string{"bobalice"}) + c("bob", []string{"alice", "eve"}, []string{"bobalice", "bobeve"}) + c("", []string{"alice", "eve"}, []string{"alice", "eve"}) +} + +func TestSlices_Suffix(t *testing.T) { + c := func(p string, items []string, exp []string) { + act := Suffixed(p, items) + assertStringSlicesEqual(t, fmt.Sprintf("Suffixed(%v,%v)", p, items), exp, act) + } + c("bob", []string{}, []string{}) + c("bob", []string{"alice"}, []string{"alicebob"}) + c("bob", []string{"alice", "eve"}, []string{"alicebob", "evebob"}) + c("", []string{"alice", "eve"}, []string{"alice", "eve"}) +} + +func TestSlices_Int64SlicesEqual(t *testing.T) { + vals := []interface{}{ + []int64{}, + []int64{0}, + []int64{1}, + []int64{42, 43}, + []int64{42, 43, 0}, + []int64{41, 43, 0}, + []int64{42, 43, 43}, + } + testSlicesEquals(t, "Int64", vals, []int64{42, 43, 0}, vals[4], func(x, y interface{}) bool { + return Int64SlicesEqual(x.([]int64), y.([]int64)) + }) + if Int64SlicesEqual(nil, []int64{1}) || Int64SlicesEqual([]int64{1}, nil) { + t.Errorf("Int64SlicesEqual for a nil slice and a slice with items should return false") + } + if !Int64SlicesEqual(nil, nil) || !Int64SlicesEqual(nil, []int64{}) { + t.Errorf("Int64SlicesEqual for a nil slice and an empty slice should return true") + } +} + +func TestSlices_NvlInt(t *testing.T) { + c := func(exp int, items ...int) { + act := NumbersCoalesce(items...) + if act != exp { + t.Errorf("Expecting NvlInt(%v) to return %d, but got %d", items, exp, act) + } + } + c(0) + c(0, 0) + c(10, 10) + c(10, 10, 0) + c(-10, -10) + c(10, 0, 10) + c(-5, 0, -5, 10) +} + +func TestSlices_NvlInt64(t *testing.T) { + c := func(exp int64, items ...int64) { + act := NumbersCoalesce(items...) + if act != exp { + t.Errorf("Expecting NvlInt64(%v) to return %d, but got %d", items, exp, act) + } + } + c(0) + c(0, 0) + c(10, 10) + c(10, 10, 0) + c(-10, -10) + c(10, 0, 10) + c(-5, 0, -5, 10) +} + +func TestSlices_UInt64SlicesEqual(t *testing.T) { + vals := []interface{}{ + []uint64{}, + []uint64{0}, + []uint64{1}, + []uint64{42, 43}, + []uint64{42, 43, 0}, + []uint64{41, 43, 0}, + []uint64{42, 43, 43}, + } + testSlicesEquals(t, "Uint64", vals, []uint64{42, 43, 0}, vals[4], func(x, y interface{}) bool { + return Uint64SlicesEqual(x.([]uint64), y.([]uint64)) + }) + if Uint64SlicesEqual(nil, []uint64{1}) || Uint64SlicesEqual([]uint64{1}, nil) { + t.Errorf("Uint64SlicesEqual for a nil slice and a slice with items should return false") + } + if !Uint64SlicesEqual(nil, nil) || !Uint64SlicesEqual(nil, []uint64{}) { + t.Errorf("Uint64SlicesEqual for a nil slice and an empty slice should return true") + } +} + +func TestSlices_NvlUint64(t *testing.T) { + c := func(exp uint64, items ...uint64) { + act := NumbersCoalesce(items...) + if act != exp { + t.Errorf("Expecting NvlUnt64(%v) to return %d, but got %d", items, exp, act) + } + } + c(0) + c(0, 0) + c(10, 10) + c(10, 10, 0) + c(10, 0, 10) + c(5, 0, 5, 10) + c(5, 0, 5, 0) +} + +func TestSlices_BoolSlicesEqual(t *testing.T) { + bools := []interface{}{ + []bool{}, + []bool{false}, + []bool{true}, + []bool{false, false}, + []bool{false, false, true}, + []bool{true, false, true}, + []bool{false, false, false}, + } + testSlicesEquals(t, "Bool", bools, []bool{false, false, true}, bools[4], func(x, y interface{}) bool { + return BoolSlicesEqual(x.([]bool), y.([]bool)) + }) + if BoolSlicesEqual(nil, []bool{false}) || BoolSlicesEqual([]bool{false}, nil) { + t.Errorf("BoolSlicesEqual for a nil and slice with items should return false") + } + if !BoolSlicesEqual(nil, nil) || !BoolSlicesEqual(nil, []bool{}) { + t.Errorf("BoolSlicesEqual for a nil and empty slice should return true") + } +} + +func TestSlices_FloatSlicesEqual(t *testing.T) { + vals := []interface{}{ + []float64{}, + []float64{0}, + []float64{1, 2}, + []float64{3, 4, 5}, + []float64{2.0, 4, 5}, + []float64{3, 4, 4}, + } + testSlicesEquals(t, "Float64", vals, []float64{2.0, 4, 5}, vals[4], func(x, y interface{}) bool { + return Float64SlicesEqual(x.([]float64), y.([]float64)) + }) + if Float64SlicesEqual(nil, []float64{0}) || Float64SlicesEqual([]float64{0}, nil) { + t.Errorf("Float64SlicesEqual for a nil and slice with items should return false") + } + if !Float64SlicesEqual(nil, nil) || !Float64SlicesEqual(nil, []float64{}) { + t.Errorf("Float64SlicesEqual for a nil and empty slice should return true") + } +} + +func testSlicesEquals(t *testing.T, funcName string, vals []interface{}, goodVal1 interface{}, goodVal2 interface{}, equalsFunc func(x, y interface{}) bool) { + for i, x := range vals { + for j, y := range vals { + r := equalsFunc(x, y) + if (i == j) && !r { + t.Errorf("%vSlicesEqual for the same slice shouldn't return false! (%v,%v)", funcName, x, y) + } else if (i != j) && r { + t.Errorf("%vSlicesEqual for different slices should return false! (%v,%v)", funcName, x, y) + } + } + } + if !equalsFunc(goodVal1, goodVal2) { + t.Errorf("Different slices with the same contents should return true for %vSlicesEqual (%v,%v)", funcName, goodVal1, goodVal2) + } +} + +func TestSlices_StringsCoalesce(t *testing.T) { + assert.Equal(t, "", StringsCoalesce()) + assert.Equal(t, "1", StringsCoalesce("1", "2", "3")) + assert.Equal(t, "2", StringsCoalesce("", "2", "3")) + assert.Equal(t, "3", StringsCoalesce("", "", "3")) +} + +func TestSlices_Coalesce(t *testing.T) { + assert.Equal(t, "1", Coalesce("1", "2", "3")) + assert.Equal(t, "2", Coalesce("", "2", "3")) + assert.Equal(t, "3", Coalesce("", "", "3")) + + assert.Equal(t, []string{"1"}, Coalesce([]string{"1"}, []string{"2", "3"})) + assert.Equal(t, []string{""}, Coalesce([]string{""}, []string{"2", "3"})) + assert.Equal(t, []string{"2", "3"}, Coalesce([]string{}, []string{"2", "3"})) + var empty []string + assert.Equal(t, []string{"3"}, Coalesce(empty, empty, []string{"3"})) +} + +func TestSlices_StringUpto(t *testing.T) { + assert.Equal(t, "", StringUpto("", 0)) + assert.Equal(t, "", StringUpto("", 2)) + assert.Equal(t, "", StringUpto("11", 0)) + assert.Equal(t, "1", StringUpto("11", 1)) + assert.Equal(t, "11", StringUpto("11", 2)) + assert.Equal(t, "11", StringUpto("11", 3)) +} + +func Test_removeDuplicates(t *testing.T) { + dups := []string{"12", "45", "45", "78", "12", "porto"} + noDups := UniqueStrings(dups) + assert.Equal(t, len(noDups), 4) +} + +func TestNvlNumber(t *testing.T) { + assert.Equal(t, 1, NumbersCoalesce(0, 1)) + assert.Equal(t, uint64(1), NumbersCoalesce(0, uint64(1))) +} + +func TestChoise(t *testing.T) { + assert.Equal(t, 1, Select(false, 0, 1)) + assert.Equal(t, uint64(0), Select(true, 0, uint64(1))) +} diff --git a/slices/uint64s.go b/slices/uint64s.go new file mode 100644 index 0000000..95cb571 --- /dev/null +++ b/slices/uint64s.go @@ -0,0 +1,19 @@ +package slices + +// Uint64s is a slice of uint64, that knows how to be sorted, using sort.Sort +type Uint64s []uint64 + +// Len returns the length of the slice, as required by sort.Interface +func (a Uint64s) Len() int { + return len(a) +} + +// Less returns the true if the value at index i is smaller than the value at index j, as required by sort.Interface +func (a Uint64s) Less(i, j int) bool { + return a[i] < a[j] +} + +// Swap swaps the values at the indicated indexes, as required by sort.Interface +func (a Uint64s) Swap(i, j int) { + a[i], a[j] = a[j], a[i] +} diff --git a/slices/uint64s_test.go b/slices/uint64s_test.go new file mode 100644 index 0000000..59d7ec7 --- /dev/null +++ b/slices/uint64s_test.go @@ -0,0 +1,20 @@ +package slices + +import ( + "reflect" + "sort" + "testing" +) + +func TestUint64s_Sort(t *testing.T) { + c := func(src, exp Uint64s) { + sort.Sort(src) + if !reflect.DeepEqual(exp, src) { + t.Errorf("Expecting sorted to be %v, but was %v", exp, src) + } + } + c(Uint64s{5, 15, 6, 22, 1, 1}, Uint64s{1, 1, 5, 6, 15, 22}) + c(Uint64s{5}, Uint64s{5}) + c(Uint64s{}, Uint64s{}) + c(Uint64s{5, 5, 1, 1}, Uint64s{1, 1, 5, 5}) +} diff --git a/urlutil/urlutil.go b/urlutil/urlutil.go new file mode 100644 index 0000000..4ff2ed4 --- /dev/null +++ b/urlutil/urlutil.go @@ -0,0 +1,52 @@ +package urlutil + +import ( + "net/http" + "net/url" +) + +// XForwardedProtoHeader contains the protocol +const XForwardedProtoHeader = "X-Forwarded-Proto" + +// GetQueryString returns Query parameter +func GetQueryString(u *url.URL, name string) string { + vals, ok := u.Query()[name] + if !ok || len(vals) == 0 { + return "" + } + return vals[0] +} + +// GetValue returns a Query parameter +func GetValue(vals url.Values, name string) string { + v, ok := vals[name] + if !ok || len(v) == 0 { + return "" + } + return v[0] +} + +// GetPublicEndpointURL returns complete server URL for given relative end-point +func GetPublicEndpointURL(r *http.Request, relativeEndpoint string) *url.URL { + proto := r.URL.Scheme + + // Allow upstream proxies to specify the forwarded protocol. Allow this value + // to override our own guess. + if specifiedProto := r.Header.Get(XForwardedProtoHeader); specifiedProto != "" { + proto = specifiedProto + } + + host := r.URL.Host + if host == "" { + host = r.Host + } + if proto == "" { + proto = "https" + } + + return &url.URL{ + Scheme: proto, + Host: host, + Path: relativeEndpoint, + } +} diff --git a/urlutil/urlutil_test.go b/urlutil/urlutil_test.go new file mode 100644 index 0000000..749eebd --- /dev/null +++ b/urlutil/urlutil_test.go @@ -0,0 +1,40 @@ +package urlutil_test + +import ( + "net/http" + "net/url" + "testing" + + "github.com/effective-security/x/urlutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetQueryString(t *testing.T) { + u, err := url.Parse("http://localhost?q=test") + require.NoError(t, err) + + assert.Equal(t, "test", urlutil.GetQueryString(u, "q")) + assert.Equal(t, "", urlutil.GetQueryString(u, "p")) + + vals := u.Query() + assert.Equal(t, "test", urlutil.GetValue(vals, "q")) + assert.Equal(t, "", urlutil.GetValue(vals, "p")) +} + +func TestGetPublicServerURL(t *testing.T) { + r, err := http.NewRequest(http.MethodGet, "/v1/status", nil) + require.NoError(t, err) + + u := urlutil.GetPublicEndpointURL(r, "/v1").String() + assert.Equal(t, "https:///v1", u) + + r.URL.Scheme = "https" + r.Host = "martini.com:8443" + u = urlutil.GetPublicEndpointURL(r, "/v1").String() + assert.Equal(t, "https://martini.com:8443/v1", u) + + r.Header.Set(urlutil.XForwardedProtoHeader, "http") + u = urlutil.GetPublicEndpointURL(r, "/v1").String() + assert.Equal(t, "http://martini.com:8443/v1", u) +}