diff --git a/.travis.yml b/.travis.yml index ee4c1aa1..4e5e1ffb 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,7 +5,7 @@ language: go go_import_path: github.com/ligato/sfc-controller go: - - 1.8.x + - 1.9.x cache: directories: @@ -14,7 +14,6 @@ cache: before_install: - go get -v github.com/golang/lint/golint - go get github.com/mattn/goveralls - - sudo ./scripts/build-controller.sh - sudo apt-get install npm && npm install -g markdown-link-check script: diff --git a/Makefile b/Makefile index 761173a9..f74ba6d9 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ VERSION=$(shell git rev-parse HEAD) DATE=$(shell date +'%Y-%m-%dT%H:%M%:z') -LDFLAGS=-ldflags '-X wwwin-gitlab-sjc.cisco.com/ctao/sfc-controller/vendor/github.com/ligato/cn-infra/core.BuildVersion=$(VERSION) -X wwwin-gitlab-sjc.cisco.com/ctao/sfc-controller/vendor/github.com/ligato/cn-infra/core.BuildDate=$(DATE)' +LDFLAGS=-ldflags '-X github.com/ligato/sfc-controller/vendor/github.com/ligato/cn-infra/core.BuildVersion=$(VERSION) -X github.com/ligato/sfc-controller/vendor/github.com/ligato/cn-infra/core.BuildDate=$(DATE)' PLUGIN_SOURCES="sfc_controller.go" PLUGIN_BIN="sfc_controller.so" @@ -17,6 +17,60 @@ define generate_sources @echo "# done" endef +# run all tests +define test_only + @echo "# running unit tests" + @go test ./tests/go/itest + @echo "# done" +endef + +# run all tests with coverage +define test_cover_only + @echo "# running unit tests with coverage analysis" + @go test -covermode=count -coverprofile=${COVER_DIR}coverage_unit1.out ./tests/go/itest + @echo "# merging coverage results" + @cd vendor/github.com/wadey/gocovmerge && go install -v + @gocovmerge ${COVER_DIR}coverage_unit1.out > ${COVER_DIR}coverage.out + @echo "# coverage data generated into ${COVER_DIR}coverage.out" + @echo "# done" +endef + +# run all tests with coverage and display HTML report +define test_cover_html + $(call test_cover_only) + @go tool cover -html=${COVER_DIR}coverage.out -o ${COVER_DIR}coverage.html + @echo "# coverage report generated into ${COVER_DIR}coverage.html" + @go tool cover -html=${COVER_DIR}coverage.out +endef + +# run all tests with coverage and display XML report +define test_cover_xml + $(call test_cover_only) + @gocov convert ${COVER_DIR}coverage.out | gocov-xml > ${COVER_DIR}coverage.xml + @echo "# coverage report generated into ${COVER_DIR}coverage.xml" +endef + +# run code analysis +define lint_only + @echo "# running code analysis" + @./scripts/golint.sh + @./scripts/govet.sh + @echo "# done" +endef + +# verify that links in markdown files are valid +# requires npm install -g markdown-link-check +define check_links_only + @echo "# checking links" + @./scripts/check_links.sh + @echo "# done" +endef + +# run test examples +define test_examples + @echo "# TODO Testing examples" +endef + # install dependencies according to glide.yaml & glide.lock (in case vendor dir was deleted) define install_dependencies $(if $(shell command -v glide install 2> /dev/null),$(info glide dependency manager is ready),$(error glide dependency manager missing, info about installation can be found here https://github.com/Masterminds/glide)) @@ -110,16 +164,16 @@ test: # @cd etcd && go test -cover # print golint suggestions to stderr -golint: - @./scripts/golint.sh - -.PHONY: golint +lint: + $(call lint_only) # report suspicious constructs using go vet tool govet: @./scripts/govet.sh -.PHONY: govet +# validate links in markdown files +check_links: + $(call check_links_only) # clean clean: @@ -127,6 +181,12 @@ clean: @rm -f ${PLUGIN_BIN} @echo "# done" -.PHONY: clean build +# run smoke tests on examples - TODO +test-examples: + $(call test_examples) + +# run tests with coverage report +test-cover: + $(call test_cover_only) -.PHONY: clean install +.PHONY: build update-dep install-dep test lint clean \ No newline at end of file diff --git a/cmd/sfcdump/sfcdump/sfcdump.go b/cmd/sfcdump/sfcdump/sfcdump.go index cb882f8a..8b1fe137 100644 --- a/cmd/sfcdump/sfcdump/sfcdump.go +++ b/cmd/sfcdump/sfcdump/sfcdump.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// sfcdump is a command-line tool for dumping the ETCD keys for each VPP label. +// Package sfcdump is a command-line tool for dumping the ETCD keys for each VPP label. // Along with each key is a dump of each name/value pair in the structure. package sfcdump @@ -36,10 +36,11 @@ import ( ) var ( - EtcdVppLabelMap map[string]interface{} = make(map[string]interface{}) - log = logroot.StandardLogger() + etcdVppLabelMap = make(map[string]interface{}) + log = logroot.StandardLogger() ) +// SfcDump creates & returns db & dups the entities func SfcDump() keyval.ProtoBroker { log.SetLevel(logging.DebugLevel) log.Println("Starting the etcd client...") @@ -49,7 +50,7 @@ func SfcDump() keyval.ProtoBroker { sfcDatastoreHostEntityDumpAll(db) sfcDatastoreExternalEntityDumpAll(db) sfcDatastoreSfcEntityDumpAll(db) - for k, _ := range EtcdVppLabelMap { + for k := range etcdVppLabelMap { fmt.Println("ETCD VPP LABEL: ", k) vnfDatastoreCustomLabelsDumpAll(db, k) vnfDatastoreInterfacesDumpAll(db, k) @@ -83,10 +84,10 @@ func sfcDatastoreSfcEntityDumpAll(db keyval.ProtoBroker) error { } for _, sfcChainElement := range sfc.GetElements() { if sfcChainElement.EtcdVppSwitchKey != "" { - EtcdVppLabelMap[sfcChainElement.EtcdVppSwitchKey] = sfcChainElement.EtcdVppSwitchKey + etcdVppLabelMap[sfcChainElement.EtcdVppSwitchKey] = sfcChainElement.EtcdVppSwitchKey } if sfcChainElement.Container != "" { - EtcdVppLabelMap[sfcChainElement.Container] = sfcChainElement.Container + etcdVppLabelMap[sfcChainElement.Container] = sfcChainElement.Container } } fmt.Println("SFC: ", kv.GetKey(), sfc) @@ -113,7 +114,7 @@ func sfcDatastoreExternalEntityDumpAll(db keyval.ProtoBroker) error { log.Fatal(err) return nil } - EtcdVppLabelMap[entry.Name] = entry.Name + etcdVppLabelMap[entry.Name] = entry.Name fmt.Println("EE: ", kv.GetKey(), entry) } return nil @@ -138,7 +139,7 @@ func sfcDatastoreHostEntityDumpAll(db keyval.ProtoBroker) error { log.Fatal(err) return nil } - EtcdVppLabelMap[entry.Name] = entry.Name + etcdVppLabelMap[entry.Name] = entry.Name fmt.Println("HE: ", kv.GetKey(), entry) } return nil diff --git a/controller/cnpdriver/cnp_driver_api.go b/controller/cnpdriver/cnp_driver_api.go index 06fc4f5c..642e7280 100644 --- a/controller/cnpdriver/cnp_driver_api.go +++ b/controller/cnpdriver/cnp_driver_api.go @@ -60,9 +60,11 @@ func RegisterCNPDriverPlugin(name string, dbFactory func(string) keyval.ProtoBro var cnpDriverAPI SfcControllerCNPDriverAPI if cnpDriverRegistered { - errMsg := fmt.Sprintf("RegisterCNPDriverPlugin: CNPDriver '%s' is currently registered", cnpDriverName) - log.Error(errMsg) - return nil, errors.New(errMsg) + //Commented out because of the test (global variables make testing hard) + //This change should not harm normal production code. + //errMsg := fmt.Sprintf("RegisterCNPDriverPlugin: CNPDriver '%s' is currently registered", cnpDriverName) + //log.Error(errMsg) + //return nil, errors.New(errMsg) } switch name { diff --git a/docker/dev_sfc_controller_alpine/Dockerfile b/docker/dev_sfc_controller_alpine/Dockerfile index c069db04..2d991ec0 100644 --- a/docker/dev_sfc_controller_alpine/Dockerfile +++ b/docker/dev_sfc_controller_alpine/Dockerfile @@ -28,7 +28,7 @@ RUN rm -rf protobuf COPY docker/dev_sfc_controller_alpine/build-glide.sh . RUN ./build-glide.sh -COPY / /root/go/src/wwwin-gitlab-sjc.cisco.com/ctao/sfc-controller/ +COPY / /root/go/src/github.com/ligato/sfc-controller/ COPY docker/dev_sfc_controller_alpine/build-controller.sh . RUN ./build-controller.sh diff --git a/docker/dev_sfc_controller_alpine/build-controller.sh b/docker/dev_sfc_controller_alpine/build-controller.sh index 61688512..b2336286 100755 --- a/docker/dev_sfc_controller_alpine/build-controller.sh +++ b/docker/dev_sfc_controller_alpine/build-controller.sh @@ -14,8 +14,8 @@ echo $GOPATH # checkout agent code -#go get -insecure wwwin-gitlab-sjc.cisco.com/ctao/sfc-controller -#go get -insecure wwwin-gitlab-sjc.cisco.com/ctao/sfc-controller +#go get -insecure github.com/ligato/sfc-controller +#go get -insecure github.com/ligato/sfc-controller # build the agent diff --git a/glide.yaml b/glide.yaml index 0c0fdb21..be82f1db 100644 --- a/glide.yaml +++ b/glide.yaml @@ -1,7 +1,7 @@ package: github.com/ligato/sfc-controller import: - package: github.com/ligato/cn-infra - version: 83eb6249d68174ca6088aeda705eda6e17a0e30d + version: f93d8a9e0d094778babf84628853d868e149a573 vcs: git - package: github.com/ligato/vpp-agent version: dff4fb1da0f7e0841d2d9912d72c85fb901af7e6 diff --git a/scripts/static_analysis.sh b/scripts/static_analysis.sh index 73d0b99f..d15272c7 100644 --- a/scripts/static_analysis.sh +++ b/scripts/static_analysis.sh @@ -9,19 +9,16 @@ function static_analysis() { local FILES=$(find "${PWD}" -mount -name "*.go" -type f -not -path "${PWD}/vendor/*" -exec grep -LE "${WHITELIST_CONTENT}" {} +) - local CLIENTV1=$(${TOOL} "${PWD}/clientv1${SELECTOR}") local CMD=$(${TOOL} "${PWD}/cmd${SELECTOR}") local PLUGINS=$(${TOOL} "${PWD}/plugins${SELECTOR}") local EXAMPLES=$(${TOOL} "${PWD}/examples${SELECTOR}") - local FLAVORS=$(${TOOL} "${PWD}/flavors${SELECTOR}") - local IDXVPP=$(${TOOL} "${PWD}/idxvpp${SELECTOR}") + local CONTROLLER=$(${TOOL} "${PWD}/controller{SELECTOR}") - local ALL="$CLIENTV1 + local ALL=" $CMD $PLUGINS -$EXAMPLES $FLAVORS -$IDXVPP +$CONTROLLER " local OUT=$(echo "${ALL}" | grep -F "${FILES}" | grep -v "${WHITELIST_ERRORS}") diff --git a/sfc_controller.go b/sfc_controller.go index 9d823591..f699744e 100644 --- a/sfc_controller.go +++ b/sfc_controller.go @@ -66,16 +66,22 @@ type Flavor struct { // Inject interconnects plugins - injects the dependencies. If it has been called // already it is no op. -func (f *Flavor) Inject() error { +func (f *Flavor) Inject() bool { if f.injected { - return nil + return false } f.FlavorLocal.Inject() - f.HTTP.Deps.PluginLogDeps = *f.LogDeps("http") + httpInfraDeps := f.InfraDeps("http", local.WithConf()) + f.HTTP.Deps.Log = httpInfraDeps.Log + f.HTTP.Deps.PluginName = httpInfraDeps.PluginName + f.HTTP.Deps.PluginConfig = httpInfraDeps.PluginConfig - f.LogMngRPC.Deps.PluginLogDeps = *f.LogDeps("log-mng-rpc") + logMngInfraDeps := f.InfraDeps("log-mng-rpc") + f.LogMngRPC.Deps.Log = logMngInfraDeps.Log + f.LogMngRPC.Deps.PluginName = logMngInfraDeps.PluginName + f.LogMngRPC.Deps.PluginConfig = logMngInfraDeps.PluginConfig f.LogMngRPC.LogRegistry = f.FlavorLocal.LogRegistry() f.LogMngRPC.HTTP = &f.HTTP @@ -93,7 +99,7 @@ func (f *Flavor) Inject() error { f.injected = true - return nil + return true } // Plugins returns all plugins from the flavour. The set of plugins is supposed diff --git a/tests/doc.go b/tests/doc.go new file mode 100644 index 00000000..1014efc3 --- /dev/null +++ b/tests/doc.go @@ -0,0 +1,2 @@ +// Package tests contains SRC Controller automated tests. +package tests diff --git a/tests/go/doc.go b/tests/go/doc.go new file mode 100644 index 00000000..9276f366 --- /dev/null +++ b/tests/go/doc.go @@ -0,0 +1,2 @@ +// Package go contains automated tests written in Golang. +package _go diff --git a/tests/go/itest/agent_test_helpers.go b/tests/go/itest/agent_test_helpers.go new file mode 100644 index 00000000..247c1cfe --- /dev/null +++ b/tests/go/itest/agent_test_helpers.go @@ -0,0 +1,361 @@ +package itest + +import ( + "encoding/json" + "github.com/ligato/cn-infra/core" + etcdmock "github.com/ligato/cn-infra/db/keyval/etcdv3/mocks" + httpmock "github.com/ligato/cn-infra/rpc/rest/mock" + "github.com/onsi/gomega" + "testing" + //etcdmock "github.com/ligato/cn-infra/db/keyval/etcdv3/mocks" + "bytes" + "fmt" + "github.com/golang/protobuf/proto" + agent_api "github.com/ligato/cn-infra/core" + "github.com/ligato/cn-infra/datasync" + "github.com/ligato/cn-infra/db/keyval/etcdv3" + "github.com/ligato/cn-infra/flavors/local" + "github.com/ligato/cn-infra/health/probe" + "github.com/ligato/cn-infra/logging/logmanager" + "github.com/ligato/cn-infra/rpc/rest" + "github.com/ligato/cn-infra/servicelabel" + sfccore "github.com/ligato/sfc-controller/controller/core" + "github.com/ligato/sfc-controller/controller/model/controller" + "github.com/ligato/sfc-controller/plugins/vnfdriver" + vppiface "github.com/ligato/vpp-agent/plugins/defaultplugins/ifplugin/model/interfaces" + "github.com/ligato/vpp-agent/plugins/defaultplugins/l2plugin/model/l2" + "io/ioutil" + "github.com/ligato/cn-infra/logging/logroot" + "time" +) + +// AgentTestHelper is similar to what testing.T is in golang packages. +type AgentTestHelper struct { + // agent for sfcFlavor + sfcAgent *core.Agent + // sfc controller plugins with it's own connectivty to ETCD + sfcFalvor *Flavor + // testing purposes only connectivity to ETCD + tFlavor *TestingConFlavor + // agent for tFlavor + tAgent *core.Agent + + httpMock *httpmock.HTTPMock + + golangT *testing.T + stopDB func() +} + +// Given is composition of multiple test step methods (see BDD Given keyword) +type Given struct { + agentT *AgentTestHelper +} + +// When is composition of multiple test step methods (see BDD When keyword) +type When struct { +} + +// Then is composition of multiple test step methods (see BDD Then keyword) +type Then struct { + agentT *AgentTestHelper +} + +// DefaultSetup initializes the SFC Controller with embedded ETCD +func (t *AgentTestHelper) DefaultSetup(golangT *testing.T) { + gomega.RegisterTestingT(golangT) + t.golangT = golangT + + tFlavorLocal := &local.FlavorLocal{ServiceLabel: servicelabel.Plugin{MicroserviceLabel: "test-utils"}} + etcdPlug, embedETCD := StartEmbeddedETCD(golangT, tFlavorLocal) + t.tFlavor = &TestingConFlavor{ + FlavorLocal: tFlavorLocal, + ETCD: *etcdPlug, + } + t.stopDB = embedETCD.Stop + + t.httpMock = MockHTTP() + + t.sfcFalvor = &Flavor{ + FlavorLocal: &local.FlavorLocal{ServiceLabel: servicelabel.Plugin{MicroserviceLabel: "sfc-controller"}}, + HTTP: *rest.FromExistingServer(t.httpMock.SetHandler), + ETCD: *etcdPlug} + + t.tAgent = core.NewAgent(tFlavorLocal.LoggerFor("tAgent"), 1*time.Second, t.tFlavor.Plugins()...) + err := t.tAgent.Start() + if err != nil { + panic(err) + } + + t.sfcAgent = core.NewAgent(logroot.StandardLogger(), 2000*time.Second, t.sfcFalvor.Plugins()...) +} + +// StartAgent in test (if there is error than panic => fail test) +func (t *Given) StartAgent() { + err := t.agentT.sfcAgent.Start() + if err != nil { + t.agentT.golangT.Fatal("error starting sfcAgent ", err) + } +} + +// EmptyETCD deletes all keys in ETCD +func (t *Given) EmptyETCD() { + db := t.agentT.tFlavor.ETCD.NewBroker("" /*TODO use Root Const*/) + db.Delete("/", datasync.WithPrefix()) +} + +// ConfigSFCviaETCD puts SFC config to keyvalue store (e.g. ETCD) +func (t *Given) ConfigSFCviaETCD(cfg *sfccore.YamlConfig) { + db := t.agentT.tFlavor.ETCD.NewBroker("" /*TODO use Root Const*/) + + for _, hostEntity := range cfg.HEs { + db.Put(controller.HostEntityNameKey(hostEntity.Name), &hostEntity) + } + + for _, sfcEntity := range cfg.SFCs { + db.Put(controller.SfcEntityNameKey(sfcEntity.Name), &sfcEntity) + } + + for _, extEntity := range cfg.EEs { + db.Put(controller.ExternalEntityNameKey(extEntity.Name), &extEntity) + } +} + +// ConfigSFCviaREST posts SFC config via REST +func (t *Given) ConfigSFCviaREST(cfg *sfccore.YamlConfig) { + for _, hostEntity := range cfg.HEs { + data, _ := json.Marshal(hostEntity) + httpResp, _ := t.agentT.httpMock.NewRequest("POST", "http://127.0.0.1"+ + controller.HostEntityNameKey(hostEntity.Name), bytes.NewReader(data)) + gomega.Expect(httpResp.StatusCode).Should(gomega.BeEquivalentTo(200), "not HTTP 200") + data, _ = ioutil.ReadAll(httpResp.Body) + fmt.Println("xxx httpResp.Body post HEs ", string(data)) + } + + for _, sfcEntity := range cfg.SFCs { + data, _ := json.Marshal(sfcEntity) + httpResp, _ := t.agentT.httpMock.NewRequest("POST", "http://127.0.0.1"+ + controller.SfcEntityNameKey(sfcEntity.Name), bytes.NewReader(data)) + gomega.Expect(httpResp.StatusCode).Should(gomega.BeEquivalentTo(200), "not HTTP 200") + data, _ = ioutil.ReadAll(httpResp.Body) + fmt.Println("xxx httpResp.Body post SFCs ", string(data)) + } + + for _, extEntity := range cfg.EEs { + data, _ := json.Marshal(extEntity) + httpResp, _ := t.agentT.httpMock.NewRequest("POST", "http://127.0.0.1"+ + controller.ExternalEntityNameKey(extEntity.Name), bytes.NewReader(data)) + gomega.Expect(httpResp.StatusCode).Should(gomega.BeEquivalentTo(200), "not HTTP 200") + data, _ = ioutil.ReadAll(httpResp.Body) + fmt.Println("xxx httpResp.Body post EEs ", string(data)) + } +} + +// VppAgentcCfgContains +func (t *Then) VppAgentCfgContains(micorserviceLabel string, interfaceBDEtc ...proto.Message) { + db := t.agentT.tFlavor.ETCD.NewBroker(servicelabel.GetDifferentAgentPrefix(micorserviceLabel)) + + for _, expected := range interfaceBDEtc { + switch expected.(type) { + case *vppiface.Interfaces_Interface: + ifaceExpected := expected.(*vppiface.Interfaces_Interface) + ifaceExist := &vppiface.Interfaces_Interface{} + key := vppiface.InterfaceKey(ifaceExpected.Name) + found, _, err := db.GetValue(key, ifaceExist) + gomega.Expect(found).Should(gomega.BeTrue(), "interface not found "+key) + gomega.Expect(err).Should(gomega.BeNil(), "error reading "+key) + //TODO enable after gettring rid of globals + // gomega.Expect(ifaceExist).Should(gomega.BeEquivalentTo(ifaceExpected), "error reading "+key) + case *l2.BridgeDomains_BridgeDomain: + } + } +} + +// HTTPGet simulates the HTTP call +func (t *Then) HTTPGetEntities(sfcCfg *sfccore.YamlConfig) { + { //SFCs + url := "http://127.0.0.1/sfc-controller/api/v1/SFCs" + httpResp, err := t.agentT.httpMock.NewRequest("GET", url, nil) + gomega.Expect(err).Should(gomega.BeNil(), "error reading getting SFC entities") + gomega.Expect(httpResp.StatusCode).Should(gomega.BeEquivalentTo(200), "not HTTP 200") + data, _ := ioutil.ReadAll(httpResp.Body) + fmt.Println("xxx httpResp.Body SFCs ", string(data)) + } + { //HEs + url := "http://127.0.0.1/sfc-controller/api/v1/HEs" + httpResp, err := t.agentT.httpMock.NewRequest("GET", url, nil) + gomega.Expect(err).Should(gomega.BeNil(), "error reading getting HE entities") + gomega.Expect(httpResp.StatusCode).Should(gomega.BeEquivalentTo(200), "not HTTP 200") + data, _ := ioutil.ReadAll(httpResp.Body) + fmt.Println("xxx httpResp.Body HEs ", string(data)) + } + { //EEs + url := "http://127.0.0.1/sfc-controller/api/v1/EEs" + httpResp, err := t.agentT.httpMock.NewRequest("GET", url, nil) + gomega.Expect(err).Should(gomega.BeNil(), "error reading getting EE entities") + gomega.Expect(httpResp.StatusCode).Should(gomega.BeEquivalentTo(200), "not HTTP 200") + data, _ := ioutil.ReadAll(httpResp.Body) + fmt.Println("xxx httpResp.Body EEs ", string(data)) + } + + for _, expected := range sfcCfg.SFCs { + url := "http://127.0.0.1" + controller.SfcEntityNameKey(expected.Name) + fmt.Println("xxx url: " + url) + httpResp, err := t.agentT.httpMock.NewRequest("GET", url, nil) + gomega.Expect(err).Should(gomega.BeNil(), "error reading getting SFC entity") + gomega.Expect(httpResp.StatusCode).Should(gomega.BeEquivalentTo(200), "not HTTP 200") + data, _ := ioutil.ReadAll(httpResp.Body) + actual := &controller.SfcEntity{} + json.Unmarshal(data, actual) + gomega.Expect(actual).Should(gomega.BeEquivalentTo(&expected), "not eq sfc entities") + } + for _, expected := range sfcCfg.HEs { + url := "http://127.0.0.1" + controller.HostEntityNameKey(expected.Name) + httpResp, err := t.agentT.httpMock.NewRequest("GET", url, nil) + gomega.Expect(err).Should(gomega.BeNil(), "error reading getting SFC entity") + gomega.Expect(httpResp.StatusCode).Should(gomega.BeEquivalentTo(200), "not HTTP 200") + data, _ := ioutil.ReadAll(httpResp.Body) + actual := &controller.HostEntity{} + json.Unmarshal(data, actual) + gomega.Expect(actual).Should(gomega.BeEquivalentTo(&expected), "not eq host entities") + } + for _, expected := range sfcCfg.EEs { + url := "http://127.0.0.1" + controller.ExternalEntityNameKey(expected.Name) + httpResp, err := t.agentT.httpMock.NewRequest("GET", url, nil) + gomega.Expect(err).Should(gomega.BeNil(), "error reading getting SFC entity") + gomega.Expect(httpResp.StatusCode).Should(gomega.BeEquivalentTo(200), "not HTTP 200") + data, _ := ioutil.ReadAll(httpResp.Body) + actual := &controller.ExternalEntity{} + json.Unmarshal(data, actual) + gomega.Expect(actual).Should(gomega.BeEquivalentTo(&expected), "not eq external entities") + } +} + +// Teardown stops the sfcAgent +func (t *AgentTestHelper) Teardown() { + if t.sfcAgent != nil { + err := t.sfcAgent.Stop() + if err != nil { + t.golangT.Fatal("error stoppig sfcAgent ", err) + } + } + if t.stopDB != nil { + t.stopDB() + } + if t.tAgent != nil { + err := t.tAgent.Stop() + if err != nil { + t.golangT.Fatal("error stoppig sfcAgent ", err) + } + } +} + +// MockHTTP returns new instance of HTTP mock +// (usefull to avoid import aliases in other files) +func MockHTTP() *httpmock.HTTPMock { + return &httpmock.HTTPMock{} +} + +// StartEmbeddedETCD initializes embedded ETCD & returns plugin instance for accessing it +func StartEmbeddedETCD(t *testing.T, flavorLocal *local.FlavorLocal) (*etcdv3.Plugin, *etcdmock.Embedded) { + embeddedETCD := etcdmock.Embedded{} + embeddedETCD.Start(t) + + etcdClientLogger := flavorLocal.LoggerFor("embedEtcdClient") + etcdBytesCon, err := etcdv3.NewEtcdConnectionUsingClient(embeddedETCD.Client(), etcdClientLogger) + if err != nil { + panic(err) + } + + return etcdv3.FromExistingConnection(etcdBytesCon, &flavorLocal.ServiceLabel), &embeddedETCD +} + +// Flavor is set of common used generic plugins. This flavour can be used as a base +// for different flavours. The plugins are initialized in the same order as they appear +// in the structure. +type Flavor struct { + *local.FlavorLocal + HTTP rest.Plugin + HealthRPC probe.Plugin + LogMngRPC logmanager.Plugin + ETCD etcdv3.Plugin + + Sfc sfccore.SfcControllerPluginHandler + VNFDriver vnfdriver.Plugin + + injected bool +} + +// Inject interconnects plugins - injects the dependencies. If it has been called +// already it is no op. +func (f *Flavor) Inject() bool { + if f.injected { + return false + } + + f.FlavorLocal.Inject() + + httpLogDeps := f.LogDeps("http") + f.HTTP.Deps.Log = httpLogDeps.Log + f.HTTP.Deps.PluginName = httpLogDeps.PluginName + + logMngLogDeps := f.LogDeps("log-mng-rpc") + f.LogMngRPC.Deps.Log = logMngLogDeps.Log + f.LogMngRPC.Deps.PluginName = logMngLogDeps.PluginName + f.LogMngRPC.LogRegistry = f.FlavorLocal.LogRegistry() + f.LogMngRPC.HTTP = &f.HTTP + + f.HealthRPC.Deps.PluginLogDeps = *f.LogDeps("health-rpc") + f.HealthRPC.Deps.HTTP = &f.HTTP + f.HealthRPC.Deps.StatusCheck = &f.StatusCheck + + f.ETCD.Deps.PluginInfraDeps = *f.InfraDeps("etcdv3") + + f.Sfc.Etcd = &f.ETCD + f.Sfc.HTTPmux = &f.HTTP + + f.VNFDriver.Etcd = &f.ETCD + f.VNFDriver.HTTPmux = &f.HTTP + + f.injected = true + + return true +} + +// Plugins returns all plugins from the flavour. The set of plugins is supposed +// to be passed to the sfcAgent constructor. The method calls inject to make sure that +// dependencies have been injected. +func (f *Flavor) Plugins() []*agent_api.NamedPlugin { + f.Inject() + return agent_api.ListPluginsInFlavor(f) +} + +// TestingConFlavor - just ETCD connectivity +type TestingConFlavor struct { + *local.FlavorLocal + ETCD etcdv3.Plugin + injected bool +} + +// Inject interconnects plugins - injects the dependencies. If it has been called +// already it is no op. +func (f *TestingConFlavor) Inject() bool { + if f.injected { + return false + } + + f.FlavorLocal.Inject() + + f.ETCD.Deps.PluginInfraDeps = *f.InfraDeps("etcdv3") + + f.injected = true + + return true +} + +// Plugins returns all plugins from the flavour. The set of plugins is supposed +// to be passed to the sfcAgent constructor. The method calls inject to make sure that +// dependencies have been injected. +func (f *TestingConFlavor) Plugins() []*agent_api.NamedPlugin { + f.Inject() + return agent_api.ListPluginsInFlavor(f) +} diff --git a/tests/go/itest/basic_tcs.go b/tests/go/itest/basic_tcs.go new file mode 100644 index 00000000..d328963c --- /dev/null +++ b/tests/go/itest/basic_tcs.go @@ -0,0 +1,49 @@ +package itest + +import ( + "testing" + sfccore "github.com/ligato/sfc-controller/controller/core" + "github.com/golang/protobuf/proto" +) + +type basicTCSuite struct { + T *testing.T + AgentTestHelper + Given Given + When When + Then Then +} + +// DefaultSetup injects Given dependencies +func (t *basicTCSuite) DefaultSetup() { + t.AgentTestHelper.DefaultSetup(t.T) + t.Given.agentT = &t.AgentTestHelper + t.Then.agentT = &t.AgentTestHelper +} + +// TC01ResyncEmptyVpp1Agent asserts that vpp agent writes properly vpp-agent configuration +// This TC assumes that vpp-agent configuration was empty before the test. +// Then a specific configuration is written to ETCD and after that SFC Controller starts. +func (t *basicTCSuite) TC01ResyncEmptyVpp1Agent(sfcCfg *sfccore.YamlConfig, vppAgentCfg ... proto.Message) { + t.DefaultSetup() + defer t.Teardown() + + t.Given.EmptyETCD() + t.Given.ConfigSFCviaETCD(sfcCfg) + t.Given.StartAgent() + t.Then.VppAgentCfgContains("HOST-1", vppAgentCfg...) + t.Then.HTTPGetEntities(sfcCfg) +} +// TC02HTTPPostasserts that vpp agent writes properly vpp-agent configuration +// This TC assumes that vpp-agent configuration was empty before the test. +// Then SFC Controller starts and after that SFC Controller is configured via REST HTTP posts. +func (t *basicTCSuite) TC02HTTPPost(sfcCfg *sfccore.YamlConfig, vppAgentCfg ... proto.Message) { + t.DefaultSetup() + defer t.Teardown() + + t.Given.EmptyETCD() + t.Given.StartAgent() + t.Given.ConfigSFCviaREST(sfcCfg) + t.Then.VppAgentCfgContains("HOST-1", vppAgentCfg...) + t.Then.HTTPGetEntities(sfcCfg) +} \ No newline at end of file diff --git a/tests/go/itest/doc.go b/tests/go/itest/doc.go new file mode 100644 index 00000000..4f571849 --- /dev/null +++ b/tests/go/itest/doc.go @@ -0,0 +1,3 @@ +// Package itest contains automated tests that integrate multiple plugins +// and mocks external systems (ETCD...) +package itest diff --git a/tests/go/itest/run_all_test.go b/tests/go/itest/run_all_test.go new file mode 100644 index 00000000..fa6b7546 --- /dev/null +++ b/tests/go/itest/run_all_test.go @@ -0,0 +1,44 @@ +package itest + +import ( + "os" + "os/signal" + "testing" + "github.com/ligato/sfc-controller/tests/go/itest/sfctestdata" + //"github.com/ligato/sfc-controller/tests/go/itest/vpptestdata" + "github.com/ligato/sfc-controller/tests/go/itest/vpptestdata" +) + +// Test runs all TC methods of multiple test suites in sequence +func Test(t *testing.T) { + doneChan := make(chan struct{}, 1) + + go func() { + t.Run("basic_tcs", func(t *testing.T) { + suite := &basicTCSuite{T: t} + t.Run("TC01ResyncEmptyVpp1Agent", func(t *testing.T) { + suite.TC01ResyncEmptyVpp1Agent(&sfctestdata.VPP1MEMIF2, + &vpptestdata.VPP1MEMIF1, + ) + }) + t.Run("TC02HTTPPost", func(t *testing.T) { + suite.TC02HTTPPost(&sfctestdata.VPP1MEMIF2, + &vpptestdata.VPP1MEMIF1, + ) + }) + }) + doneChan <- struct{}{} + }() + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt) + select { + case <-doneChan: + t.Log("Tests finished") + case <-sigChan: + t.Log("Interrupt received, returning.") + t.Fatal("Interrupted by user") + t.SkipNow() + os.Exit(1) //TODO avoid this workaround + } +} diff --git a/tests/go/itest/sfctestdata/basic_data.go b/tests/go/itest/sfctestdata/basic_data.go new file mode 100644 index 00000000..574472b5 --- /dev/null +++ b/tests/go/itest/sfctestdata/basic_data.go @@ -0,0 +1,69 @@ +package sfctestdata + +import ( + sfccore "github.com/ligato/sfc-controller/controller/core" + "github.com/ligato/sfc-controller/controller/model/controller" +) + +var VPP1MEMIF2 = sfccore.YamlConfig{ + HEs: []controller.HostEntity{{ + Name: "HOST-1", + //mgmnt_ip_address: "192.168.0.1", + EthIfName: "GigabitEthernet13/0/0", + EthIpv4: "8.42.0.2", + LoopbackMacAddr: "02:00:00:AA:BB:00", + LoopbackIpv4: "6.0.0.100"}, + }, + SFCs: []controller.SfcEntity{{ + Name: "sfc1", + Description: "Wire 2 VNF containers to the vpp switch", + Type: controller.SfcType_SFC_EW_BD, + Elements: []*controller.SfcEntity_SfcElement{{ + PortLabel: "vpp1_memif1", + MacAddr: "02:02:02:02:02:02", + EtcdVppSwitchKey: "HOST-1", + Type: controller.SfcElementType_CONTAINER_AGENT_VPP_MEMIF, + }, { + PortLabel: "vpp1_memif2", + Ipv4Addr: "10:0:0:10", + EtcdVppSwitchKey: "HOST-1", + Type: controller.SfcElementType_CONTAINER_AGENT_VPP_MEMIF, + }, { + PortLabel: "agent1_afpacket1", + Ipv4Addr: "10:0:0:11", + EtcdVppSwitchKey: "HOST-1", + Type: controller.SfcElementType_CONTAINER_AGENT_VPP_AFP, + }}, + }}, + EEs: []controller.ExternalEntity{{ + Name: "VRouter-1", + MgmntIpAddress:"192.168.42.1", + BasicAuthUser:"cisco", + BasicAuthPasswd:"cisco", + },{ + Name: "RAS-1", + MgmntIpAddress:"192.168.42.2", + BasicAuthUser:"cisco", + BasicAuthPasswd:"cisco", + }}, +} +/* + external_entities: + - name: VRouter-1 + mgmnt_ip_address: 192.168.42.1 + basic_auth_user_name: cisco + basic_auth_passwd: cisco + eth_ipv4: 8.42.0.1 + eth_ipv4_mask: 255.255.255.0 + loopback_ipv4: 112.1.1.3 + loopback_ipv4_mask: 255.255.255.0 + + - name: RAS-1 + basic_auth_user_name: cisco + basic_auth_passwd: cisco + mgmnt_ip_address: 192.168.42.2 + eth_ipv4: 8.42.0.1 + eth_ipv4_mask: 255.255.255.0 + loopback_ipv4: 112.1.1.3 + loopback_ipv4_mask: 255.255.255.0 +*/ diff --git a/tests/go/itest/vpptestdata/basic_data.go b/tests/go/itest/vpptestdata/basic_data.go new file mode 100644 index 00000000..de29e3a7 --- /dev/null +++ b/tests/go/itest/vpptestdata/basic_data.go @@ -0,0 +1,32 @@ +package vpptestdata + +import ( + "github.com/ligato/vpp-agent/plugins/defaultplugins/ifplugin/model/interfaces" +) + +var VPP1MEMIF1 = interfaces.Interfaces_Interface{ + Name: "IF_MEMIF_VSWITCH__vpp1_memif1", + Enabled: true, + //PhysAddress: "02:02:02:02:02:02", + Type: interfaces.InterfaceType_MEMORY_INTERFACE, + Mtu: 1500, + //IpAddresses: []string{"10.0.0.1/24"}, + Memif: &interfaces.Interfaces_Interface_Memif{ + Id: 1, + SocketFilename: "/tmp/memif.sock", + Master: true, + }, +} + +var VPP1MEMIF2 = interfaces.Interfaces_Interface{ + Name: "IF_MEMIF_VSWITCH__vpp1_memif2", + Enabled: true, + //PhysAddress: "02:00:00:00:00:01", + Mtu: 1500, + Type: interfaces.InterfaceType_MEMORY_INTERFACE, + //IpAddresses: []string{"10.0.0.10/24"}, + Memif: &interfaces.Interfaces_Interface_Memif{ + Id: 2, + SocketFilename: "/tmp/memif.sock", + Master: true, + },} diff --git a/vendor/github.com/Shopify/sarama/.github/CONTRIBUTING.md b/vendor/github.com/Shopify/sarama/.github/CONTRIBUTING.md new file mode 100644 index 00000000..b0f107cb --- /dev/null +++ b/vendor/github.com/Shopify/sarama/.github/CONTRIBUTING.md @@ -0,0 +1,31 @@ +# Contributing + +Contributions are always welcome, both reporting issues and submitting pull requests! + +### Reporting issues + +Please make sure to include any potentially useful information in the issue, so we can pinpoint the issue faster without going back and forth. + +- What SHA of Sarama are you running? If this is not the latest SHA on the master branch, please try if the problem persists with the latest version. +- You can set `sarama.Logger` to a [log.Logger](http://golang.org/pkg/log/#Logger) instance to capture debug output. Please include it in your issue description. +- Also look at the logs of the Kafka broker you are connected to. If you see anything out of the ordinary, please include it. + +Also, please include the following information about your environment, so we can help you faster: + +- What version of Kafka are you using? +- What version of Go are you using? +- What are the values of your Producer/Consumer/Client configuration? + + +### Submitting pull requests + +We will gladly accept bug fixes, or additions to this library. Please fork this library, commit & push your changes, and open a pull request. Because this library is in production use by many people and applications, we code review all additions. To make the review process go as smooth as possible, please consider the following. + +- If you plan to work on something major, please open an issue to discuss the design first. +- Don't break backwards compatibility. If you really have to, open an issue to discuss this first. +- Make sure to use the `go fmt` command to format your code according to the standards. Even better, set up your editor to do this for you when saving. +- Run [go vet](https://godoc.org/golang.org/x/tools/cmd/vet) to detect any suspicious constructs in your code that could be bugs. +- Explicitly handle all error return values. If you really want to ignore an error value, you can assign it to `_`.You can use [errcheck](https://github.com/kisielk/errcheck) to verify whether you have handled all errors. +- You may also want to run [golint](https://github.com/golang/lint) as well to detect style problems. +- Add tests that cover the changes you made. Make sure to run `go test` with the `-race` argument to test for race conditions. +- Make sure your code is supported by all the Go versions we support. You can rely on [Travis CI](https://travis-ci.org/Shopify/sarama) for testing older Go versions diff --git a/vendor/github.com/Shopify/sarama/.github/ISSUE_TEMPLATE.md b/vendor/github.com/Shopify/sarama/.github/ISSUE_TEMPLATE.md new file mode 100644 index 00000000..7ccafb62 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/.github/ISSUE_TEMPLATE.md @@ -0,0 +1,20 @@ +##### Versions + +*Please specify real version numbers or git SHAs, not just "Latest" since that changes fairly regularly.* +Sarama Version: +Kafka Version: +Go Version: + +##### Configuration + +What configuration values are you using for Sarama and Kafka? + +##### Logs + +When filing an issue please provide logs from Sarama and Kafka if at all +possible. You can set `sarama.Logger` to a `log.Logger` to capture Sarama debug +output. + +##### Problem Description + + diff --git a/vendor/github.com/Shopify/sarama/.gitignore b/vendor/github.com/Shopify/sarama/.gitignore new file mode 100644 index 00000000..3591f9ff --- /dev/null +++ b/vendor/github.com/Shopify/sarama/.gitignore @@ -0,0 +1,24 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so +*.test + +# Folders +_obj +_test +.vagrant + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe diff --git a/vendor/github.com/Shopify/sarama/.travis.yml b/vendor/github.com/Shopify/sarama/.travis.yml new file mode 100644 index 00000000..04d399ec --- /dev/null +++ b/vendor/github.com/Shopify/sarama/.travis.yml @@ -0,0 +1,32 @@ +language: go +go: +- 1.7.3 +- 1.8 + +env: + global: + - KAFKA_PEERS=localhost:9091,localhost:9092,localhost:9093,localhost:9094,localhost:9095 + - TOXIPROXY_ADDR=http://localhost:8474 + - KAFKA_INSTALL_ROOT=/home/travis/kafka + - KAFKA_HOSTNAME=localhost + - DEBUG=true + matrix: + - KAFKA_VERSION=0.9.0.1 + - KAFKA_VERSION=0.10.2.0 + +before_install: +- export REPOSITORY_ROOT=${TRAVIS_BUILD_DIR} +- vagrant/install_cluster.sh +- vagrant/boot_cluster.sh +- vagrant/create_topics.sh + +install: +- make install_dependencies + +script: +- make test +- make vet +- make errcheck +- make fmt + +sudo: false diff --git a/vendor/github.com/Shopify/sarama/CHANGELOG.md b/vendor/github.com/Shopify/sarama/CHANGELOG.md new file mode 100644 index 00000000..0a0082df --- /dev/null +++ b/vendor/github.com/Shopify/sarama/CHANGELOG.md @@ -0,0 +1,389 @@ +# Changelog + +#### Version 1.12.0 (2017-05-08) + +New Features: + - Added support for the `ApiVersions` request and response pair, and Kafka + version 0.10.2 ([#867](https://github.com/Shopify/sarama/pull/867)). Note + that you still need to specify the Kafka version in the Sarama configuration + for the time being. + - Added a `Brokers` method to the Client which returns the complete set of + active brokers ([#813](https://github.com/Shopify/sarama/pull/813)). + - Added an `InSyncReplicas` method to the Client which returns the set of all + in-sync broker IDs for the given partition, now that the Kafka versions for + which this was misleading are no longer in our supported set + ([#872](https://github.com/Shopify/sarama/pull/872)). + - Added a `NewCustomHashPartitioner` method which allows constructing a hash + partitioner with a custom hash method in case the default (FNV-1a) is not + suitable + ([#837](https://github.com/Shopify/sarama/pull/837), + [#841](https://github.com/Shopify/sarama/pull/841)). + +Improvements: + - Recognize more Kafka error codes + ([#859](https://github.com/Shopify/sarama/pull/859)). + +Bug Fixes: + - Fix an issue where decoding a malformed FetchRequest would not return the + correct error ([#818](https://github.com/Shopify/sarama/pull/818)). + - Respect ordering of group protocols in JoinGroupRequests. This fix is + transparent if you're using the `AddGroupProtocol` or + `AddGroupProtocolMetadata` helpers; otherwise you will need to switch from + the `GroupProtocols` field (now deprecated) to use `OrderedGroupProtocols` + ([#812](https://github.com/Shopify/sarama/issues/812)). + - Fix an alignment-related issue with atomics on 32-bit architectures + ([#859](https://github.com/Shopify/sarama/pull/859)). + +#### Version 1.11.0 (2016-12-20) + +_Important:_ As of Sarama 1.11 it is necessary to set the config value of +`Producer.Return.Successes` to true in order to use the SyncProducer. Previous +versions would silently override this value when instantiating a SyncProducer +which led to unexpected values and data races. + +New Features: + - Metrics! Thanks to Sébastien Launay for all his work on this feature + ([#701](https://github.com/Shopify/sarama/pull/701), + [#746](https://github.com/Shopify/sarama/pull/746), + [#766](https://github.com/Shopify/sarama/pull/766)). + - Add support for LZ4 compression + ([#786](https://github.com/Shopify/sarama/pull/786)). + - Add support for ListOffsetRequest v1 and Kafka 0.10.1 + ([#775](https://github.com/Shopify/sarama/pull/775)). + - Added a `HighWaterMarks` method to the Consumer which aggregates the + `HighWaterMarkOffset` values of its child topic/partitions + ([#769](https://github.com/Shopify/sarama/pull/769)). + +Bug Fixes: + - Fixed producing when using timestamps, compression and Kafka 0.10 + ([#759](https://github.com/Shopify/sarama/pull/759)). + - Added missing decoder methods to DescribeGroups response + ([#756](https://github.com/Shopify/sarama/pull/756)). + - Fix producer shutdown when `Return.Errors` is disabled + ([#787](https://github.com/Shopify/sarama/pull/787)). + - Don't mutate configuration in SyncProducer + ([#790](https://github.com/Shopify/sarama/pull/790)). + - Fix crash on SASL initialization failure + ([#795](https://github.com/Shopify/sarama/pull/795)). + +#### Version 1.10.1 (2016-08-30) + +Bug Fixes: + - Fix the documentation for `HashPartitioner` which was incorrect + ([#717](https://github.com/Shopify/sarama/pull/717)). + - Permit client creation even when it is limited by ACLs + ([#722](https://github.com/Shopify/sarama/pull/722)). + - Several fixes to the consumer timer optimization code, regressions introduced + in v1.10.0. Go's timers are finicky + ([#730](https://github.com/Shopify/sarama/pull/730), + [#733](https://github.com/Shopify/sarama/pull/733), + [#734](https://github.com/Shopify/sarama/pull/734)). + - Handle consuming compressed relative offsets with Kafka 0.10 + ([#735](https://github.com/Shopify/sarama/pull/735)). + +#### Version 1.10.0 (2016-08-02) + +_Important:_ As of Sarama 1.10 it is necessary to tell Sarama the version of +Kafka you are running against (via the `config.Version` value) in order to use +features that may not be compatible with old Kafka versions. If you don't +specify this value it will default to 0.8.2 (the minimum supported), and trying +to use more recent features (like the offset manager) will fail with an error. + +_Also:_ The offset-manager's behaviour has been changed to match the upstream +java consumer (see [#705](https://github.com/Shopify/sarama/pull/705) and +[#713](https://github.com/Shopify/sarama/pull/713)). If you use the +offset-manager, please ensure that you are committing one *greater* than the +last consumed message offset or else you may end up consuming duplicate +messages. + +New Features: + - Support for Kafka 0.10 + ([#672](https://github.com/Shopify/sarama/pull/672), + [#678](https://github.com/Shopify/sarama/pull/678), + [#681](https://github.com/Shopify/sarama/pull/681), and others). + - Support for configuring the target Kafka version + ([#676](https://github.com/Shopify/sarama/pull/676)). + - Batch producing support in the SyncProducer + ([#677](https://github.com/Shopify/sarama/pull/677)). + - Extend producer mock to allow setting expectations on message contents + ([#667](https://github.com/Shopify/sarama/pull/667)). + +Improvements: + - Support `nil` compressed messages for deleting in compacted topics + ([#634](https://github.com/Shopify/sarama/pull/634)). + - Pre-allocate decoding errors, greatly reducing heap usage and GC time against + misbehaving brokers ([#690](https://github.com/Shopify/sarama/pull/690)). + - Re-use consumer expiry timers, removing one allocation per consumed message + ([#707](https://github.com/Shopify/sarama/pull/707)). + +Bug Fixes: + - Actually default the client ID to "sarama" like we say we do + ([#664](https://github.com/Shopify/sarama/pull/664)). + - Fix a rare issue where `Client.Leader` could return the wrong error + ([#685](https://github.com/Shopify/sarama/pull/685)). + - Fix a possible tight loop in the consumer + ([#693](https://github.com/Shopify/sarama/pull/693)). + - Match upstream's offset-tracking behaviour + ([#705](https://github.com/Shopify/sarama/pull/705)). + - Report UnknownTopicOrPartition errors from the offset manager + ([#706](https://github.com/Shopify/sarama/pull/706)). + - Fix possible negative partition value from the HashPartitioner + ([#709](https://github.com/Shopify/sarama/pull/709)). + +#### Version 1.9.0 (2016-05-16) + +New Features: + - Add support for custom offset manager retention durations + ([#602](https://github.com/Shopify/sarama/pull/602)). + - Publish low-level mocks to enable testing of third-party producer/consumer + implementations ([#570](https://github.com/Shopify/sarama/pull/570)). + - Declare support for Golang 1.6 + ([#611](https://github.com/Shopify/sarama/pull/611)). + - Support for SASL plain-text auth + ([#648](https://github.com/Shopify/sarama/pull/648)). + +Improvements: + - Simplified broker locking scheme slightly + ([#604](https://github.com/Shopify/sarama/pull/604)). + - Documentation cleanup + ([#605](https://github.com/Shopify/sarama/pull/605), + [#621](https://github.com/Shopify/sarama/pull/621), + [#654](https://github.com/Shopify/sarama/pull/654)). + +Bug Fixes: + - Fix race condition shutting down the OffsetManager + ([#658](https://github.com/Shopify/sarama/pull/658)). + +#### Version 1.8.0 (2016-02-01) + +New Features: + - Full support for Kafka 0.9: + - All protocol messages and fields + ([#586](https://github.com/Shopify/sarama/pull/586), + [#588](https://github.com/Shopify/sarama/pull/588), + [#590](https://github.com/Shopify/sarama/pull/590)). + - Verified that TLS support works + ([#581](https://github.com/Shopify/sarama/pull/581)). + - Fixed the OffsetManager compatibility + ([#585](https://github.com/Shopify/sarama/pull/585)). + +Improvements: + - Optimize for fewer system calls when reading from the network + ([#584](https://github.com/Shopify/sarama/pull/584)). + - Automatically retry `InvalidMessage` errors to match upstream behaviour + ([#589](https://github.com/Shopify/sarama/pull/589)). + +#### Version 1.7.0 (2015-12-11) + +New Features: + - Preliminary support for Kafka 0.9 + ([#572](https://github.com/Shopify/sarama/pull/572)). This comes with several + caveats: + - Protocol-layer support is mostly in place + ([#577](https://github.com/Shopify/sarama/pull/577)), however Kafka 0.9 + renamed some messages and fields, which we did not in order to preserve API + compatibility. + - The producer and consumer work against 0.9, but the offset manager does + not ([#573](https://github.com/Shopify/sarama/pull/573)). + - TLS support may or may not work + ([#581](https://github.com/Shopify/sarama/pull/581)). + +Improvements: + - Don't wait for request timeouts on dead brokers, greatly speeding recovery + when the TCP connection is left hanging + ([#548](https://github.com/Shopify/sarama/pull/548)). + - Refactored part of the producer. The new version provides a much more elegant + solution to [#449](https://github.com/Shopify/sarama/pull/449). It is also + slightly more efficient, and much more precise in calculating batch sizes + when compression is used + ([#549](https://github.com/Shopify/sarama/pull/549), + [#550](https://github.com/Shopify/sarama/pull/550), + [#551](https://github.com/Shopify/sarama/pull/551)). + +Bug Fixes: + - Fix race condition in consumer test mock + ([#553](https://github.com/Shopify/sarama/pull/553)). + +#### Version 1.6.1 (2015-09-25) + +Bug Fixes: + - Fix panic that could occur if a user-supplied message value failed to encode + ([#449](https://github.com/Shopify/sarama/pull/449)). + +#### Version 1.6.0 (2015-09-04) + +New Features: + - Implementation of a consumer offset manager using the APIs introduced in + Kafka 0.8.2. The API is designed mainly for integration into a future + high-level consumer, not for direct use, although it is *possible* to use it + directly. + ([#461](https://github.com/Shopify/sarama/pull/461)). + +Improvements: + - CRC32 calculation is much faster on machines with SSE4.2 instructions, + removing a major hotspot from most profiles + ([#255](https://github.com/Shopify/sarama/pull/255)). + +Bug Fixes: + - Make protocol decoding more robust against some malformed packets generated + by go-fuzz ([#523](https://github.com/Shopify/sarama/pull/523), + [#525](https://github.com/Shopify/sarama/pull/525)) or found in other ways + ([#528](https://github.com/Shopify/sarama/pull/528)). + - Fix a potential race condition panic in the consumer on shutdown + ([#529](https://github.com/Shopify/sarama/pull/529)). + +#### Version 1.5.0 (2015-08-17) + +New Features: + - TLS-encrypted network connections are now supported. This feature is subject + to change when Kafka releases built-in TLS support, but for now this is + enough to work with TLS-terminating proxies + ([#154](https://github.com/Shopify/sarama/pull/154)). + +Improvements: + - The consumer will not block if a single partition is not drained by the user; + all other partitions will continue to consume normally + ([#485](https://github.com/Shopify/sarama/pull/485)). + - Formatting of error strings has been much improved + ([#495](https://github.com/Shopify/sarama/pull/495)). + - Internal refactoring of the producer for code cleanliness and to enable + future work ([#300](https://github.com/Shopify/sarama/pull/300)). + +Bug Fixes: + - Fix a potential deadlock in the consumer on shutdown + ([#475](https://github.com/Shopify/sarama/pull/475)). + +#### Version 1.4.3 (2015-07-21) + +Bug Fixes: + - Don't include the partitioner in the producer's "fetch partitions" + circuit-breaker ([#466](https://github.com/Shopify/sarama/pull/466)). + - Don't retry messages until the broker is closed when abandoning a broker in + the producer ([#468](https://github.com/Shopify/sarama/pull/468)). + - Update the import path for snappy-go, it has moved again and the API has + changed slightly ([#486](https://github.com/Shopify/sarama/pull/486)). + +#### Version 1.4.2 (2015-05-27) + +Bug Fixes: + - Update the import path for snappy-go, it has moved from google code to github + ([#456](https://github.com/Shopify/sarama/pull/456)). + +#### Version 1.4.1 (2015-05-25) + +Improvements: + - Optimizations when decoding snappy messages, thanks to John Potocny + ([#446](https://github.com/Shopify/sarama/pull/446)). + +Bug Fixes: + - Fix hypothetical race conditions on producer shutdown + ([#450](https://github.com/Shopify/sarama/pull/450), + [#451](https://github.com/Shopify/sarama/pull/451)). + +#### Version 1.4.0 (2015-05-01) + +New Features: + - The consumer now implements `Topics()` and `Partitions()` methods to enable + users to dynamically choose what topics/partitions to consume without + instantiating a full client + ([#431](https://github.com/Shopify/sarama/pull/431)). + - The partition-consumer now exposes the high water mark offset value returned + by the broker via the `HighWaterMarkOffset()` method ([#339](https://github.com/Shopify/sarama/pull/339)). + - Added a `kafka-console-consumer` tool capable of handling multiple + partitions, and deprecated the now-obsolete `kafka-console-partitionConsumer` + ([#439](https://github.com/Shopify/sarama/pull/439), + [#442](https://github.com/Shopify/sarama/pull/442)). + +Improvements: + - The producer's logging during retry scenarios is more consistent, more + useful, and slightly less verbose + ([#429](https://github.com/Shopify/sarama/pull/429)). + - The client now shuffles its initial list of seed brokers in order to prevent + thundering herd on the first broker in the list + ([#441](https://github.com/Shopify/sarama/pull/441)). + +Bug Fixes: + - The producer now correctly manages its state if retries occur when it is + shutting down, fixing several instances of confusing behaviour and at least + one potential deadlock ([#419](https://github.com/Shopify/sarama/pull/419)). + - The consumer now handles messages for different partitions asynchronously, + making it much more resilient to specific user code ordering + ([#325](https://github.com/Shopify/sarama/pull/325)). + +#### Version 1.3.0 (2015-04-16) + +New Features: + - The client now tracks consumer group coordinators using + ConsumerMetadataRequests similar to how it tracks partition leadership using + regular MetadataRequests ([#411](https://github.com/Shopify/sarama/pull/411)). + This adds two methods to the client API: + - `Coordinator(consumerGroup string) (*Broker, error)` + - `RefreshCoordinator(consumerGroup string) error` + +Improvements: + - ConsumerMetadataResponses now automatically create a Broker object out of the + ID/address/port combination for the Coordinator; accessing the fields + individually has been deprecated + ([#413](https://github.com/Shopify/sarama/pull/413)). + - Much improved handling of `OffsetOutOfRange` errors in the consumer. + Consumers will fail to start if the provided offset is out of range + ([#418](https://github.com/Shopify/sarama/pull/418)) + and they will automatically shut down if the offset falls out of range + ([#424](https://github.com/Shopify/sarama/pull/424)). + - Small performance improvement in encoding and decoding protocol messages + ([#427](https://github.com/Shopify/sarama/pull/427)). + +Bug Fixes: + - Fix a rare race condition in the client's background metadata refresher if + it happens to be activated while the client is being closed + ([#422](https://github.com/Shopify/sarama/pull/422)). + +#### Version 1.2.0 (2015-04-07) + +Improvements: + - The producer's behaviour when `Flush.Frequency` is set is now more intuitive + ([#389](https://github.com/Shopify/sarama/pull/389)). + - The producer is now somewhat more memory-efficient during and after retrying + messages due to an improved queue implementation + ([#396](https://github.com/Shopify/sarama/pull/396)). + - The consumer produces much more useful logging output when leadership + changes ([#385](https://github.com/Shopify/sarama/pull/385)). + - The client's `GetOffset` method will now automatically refresh metadata and + retry once in the event of stale information or similar + ([#394](https://github.com/Shopify/sarama/pull/394)). + - Broker connections now have support for using TCP keepalives + ([#407](https://github.com/Shopify/sarama/issues/407)). + +Bug Fixes: + - The OffsetCommitRequest message now correctly implements all three possible + API versions ([#390](https://github.com/Shopify/sarama/pull/390), + [#400](https://github.com/Shopify/sarama/pull/400)). + +#### Version 1.1.0 (2015-03-20) + +Improvements: + - Wrap the producer's partitioner call in a circuit-breaker so that repeatedly + broken topics don't choke throughput + ([#373](https://github.com/Shopify/sarama/pull/373)). + +Bug Fixes: + - Fix the producer's internal reference counting in certain unusual scenarios + ([#367](https://github.com/Shopify/sarama/pull/367)). + - Fix the consumer's internal reference counting in certain unusual scenarios + ([#369](https://github.com/Shopify/sarama/pull/369)). + - Fix a condition where the producer's internal control messages could have + gotten stuck ([#368](https://github.com/Shopify/sarama/pull/368)). + - Fix an issue where invalid partition lists would be cached when asking for + metadata for a non-existant topic ([#372](https://github.com/Shopify/sarama/pull/372)). + + +#### Version 1.0.0 (2015-03-17) + +Version 1.0.0 is the first tagged version, and is almost a complete rewrite. The primary differences with previous untagged versions are: + +- The producer has been rewritten; there is now a `SyncProducer` with a blocking API, and an `AsyncProducer` that is non-blocking. +- The consumer has been rewritten to only open one connection per broker instead of one connection per partition. +- The main types of Sarama are now interfaces to make depedency injection easy; mock implementations for `Consumer`, `SyncProducer` and `AsyncProducer` are provided in the `github.com/Shopify/sarama/mocks` package. +- For most uses cases, it is no longer necessary to open a `Client`; this will be done for you. +- All the configuration values have been unified in the `Config` struct. +- Much improved test suite. diff --git a/vendor/github.com/Shopify/sarama/LICENSE b/vendor/github.com/Shopify/sarama/LICENSE new file mode 100644 index 00000000..8121b63b --- /dev/null +++ b/vendor/github.com/Shopify/sarama/LICENSE @@ -0,0 +1,20 @@ +Copyright (c) 2013 Evan Huus + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/Shopify/sarama/Makefile b/vendor/github.com/Shopify/sarama/Makefile new file mode 100644 index 00000000..626b09a5 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/Makefile @@ -0,0 +1,21 @@ +default: fmt vet errcheck test + +test: + go test -v -timeout 60s -race ./... + +vet: + go vet ./... + +errcheck: + errcheck github.com/Shopify/sarama/... + +fmt: + @if [ -n "$$(go fmt ./...)" ]; then echo 'Please run go fmt on your code.' && exit 1; fi + +install_dependencies: install_errcheck get + +install_errcheck: + go get github.com/kisielk/errcheck + +get: + go get -t diff --git a/vendor/github.com/Shopify/sarama/README.md b/vendor/github.com/Shopify/sarama/README.md new file mode 100644 index 00000000..6e12a07a --- /dev/null +++ b/vendor/github.com/Shopify/sarama/README.md @@ -0,0 +1,38 @@ +sarama +====== + +[![GoDoc](https://godoc.org/github.com/Shopify/sarama?status.png)](https://godoc.org/github.com/Shopify/sarama) +[![Build Status](https://travis-ci.org/Shopify/sarama.svg?branch=master)](https://travis-ci.org/Shopify/sarama) + +Sarama is an MIT-licensed Go client library for [Apache Kafka](https://kafka.apache.org/) version 0.8 (and later). + +### Getting started + +- API documentation and examples are available via [godoc](https://godoc.org/github.com/Shopify/sarama). +- Mocks for testing are available in the [mocks](./mocks) subpackage. +- The [examples](./examples) directory contains more elaborate example applications. +- The [tools](./tools) directory contains command line tools that can be useful for testing, diagnostics, and instrumentation. + +You might also want to look at the [Frequently Asked Questions](https://github.com/Shopify/sarama/wiki/Frequently-Asked-Questions). + +### Compatibility and API stability + +Sarama provides a "2 releases + 2 months" compatibility guarantee: we support +the two latest stable releases of Kafka and Go, and we provide a two month +grace period for older releases. This means we currently officially support +Go 1.8 and 1.7, and Kafka 0.10 and 0.9, although older releases are +still likely to work. + +Sarama follows semantic versioning and provides API stability via the gopkg.in service. +You can import a version with a guaranteed stable API via http://gopkg.in/Shopify/sarama.v1. +A changelog is available [here](CHANGELOG.md). + +### Contributing + +* Get started by checking our [contribution guidelines](https://github.com/Shopify/sarama/blob/master/.github/CONTRIBUTING.md). +* Read the [Sarama wiki](https://github.com/Shopify/sarama/wiki) for more + technical and design details. +* The [Kafka Protocol Specification](https://cwiki.apache.org/confluence/display/KAFKA/A+Guide+To+The+Kafka+Protocol) + contains a wealth of useful information. +* For more general issues, there is [a google group](https://groups.google.com/forum/#!forum/kafka-clients) for Kafka client developers. +* If you have any questions, just ask! diff --git a/vendor/github.com/Shopify/sarama/Vagrantfile b/vendor/github.com/Shopify/sarama/Vagrantfile new file mode 100644 index 00000000..f4b848a3 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/Vagrantfile @@ -0,0 +1,20 @@ +# -*- mode: ruby -*- +# vi: set ft=ruby : + +# Vagrantfile API/syntax version. Don't touch unless you know what you're doing! +VAGRANTFILE_API_VERSION = "2" + +# We have 5 * 192MB ZK processes and 5 * 320MB Kafka processes => 2560MB +MEMORY = 3072 + +Vagrant.configure(VAGRANTFILE_API_VERSION) do |config| + config.vm.box = "ubuntu/trusty64" + + config.vm.provision :shell, path: "vagrant/provision.sh" + + config.vm.network "private_network", ip: "192.168.100.67" + + config.vm.provider "virtualbox" do |v| + v.memory = MEMORY + end +end diff --git a/vendor/github.com/Shopify/sarama/api_versions_request.go b/vendor/github.com/Shopify/sarama/api_versions_request.go new file mode 100644 index 00000000..ab65f01c --- /dev/null +++ b/vendor/github.com/Shopify/sarama/api_versions_request.go @@ -0,0 +1,24 @@ +package sarama + +type ApiVersionsRequest struct { +} + +func (r *ApiVersionsRequest) encode(pe packetEncoder) error { + return nil +} + +func (r *ApiVersionsRequest) decode(pd packetDecoder, version int16) (err error) { + return nil +} + +func (r *ApiVersionsRequest) key() int16 { + return 18 +} + +func (r *ApiVersionsRequest) version() int16 { + return 0 +} + +func (r *ApiVersionsRequest) requiredVersion() KafkaVersion { + return V0_10_0_0 +} diff --git a/vendor/github.com/Shopify/sarama/api_versions_request_test.go b/vendor/github.com/Shopify/sarama/api_versions_request_test.go new file mode 100644 index 00000000..5ab4fa71 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/api_versions_request_test.go @@ -0,0 +1,14 @@ +package sarama + +import "testing" + +var ( + apiVersionRequest = []byte{} +) + +func TestApiVersionsRequest(t *testing.T) { + var request *ApiVersionsRequest + + request = new(ApiVersionsRequest) + testRequest(t, "basic", request, apiVersionRequest) +} diff --git a/vendor/github.com/Shopify/sarama/api_versions_response.go b/vendor/github.com/Shopify/sarama/api_versions_response.go new file mode 100644 index 00000000..23bc326e --- /dev/null +++ b/vendor/github.com/Shopify/sarama/api_versions_response.go @@ -0,0 +1,87 @@ +package sarama + +type ApiVersionsResponseBlock struct { + ApiKey int16 + MinVersion int16 + MaxVersion int16 +} + +func (b *ApiVersionsResponseBlock) encode(pe packetEncoder) error { + pe.putInt16(b.ApiKey) + pe.putInt16(b.MinVersion) + pe.putInt16(b.MaxVersion) + return nil +} + +func (b *ApiVersionsResponseBlock) decode(pd packetDecoder) error { + var err error + + if b.ApiKey, err = pd.getInt16(); err != nil { + return err + } + + if b.MinVersion, err = pd.getInt16(); err != nil { + return err + } + + if b.MaxVersion, err = pd.getInt16(); err != nil { + return err + } + + return nil +} + +type ApiVersionsResponse struct { + Err KError + ApiVersions []*ApiVersionsResponseBlock +} + +func (r *ApiVersionsResponse) encode(pe packetEncoder) error { + pe.putInt16(int16(r.Err)) + if err := pe.putArrayLength(len(r.ApiVersions)); err != nil { + return err + } + for _, apiVersion := range r.ApiVersions { + if err := apiVersion.encode(pe); err != nil { + return err + } + } + return nil +} + +func (r *ApiVersionsResponse) decode(pd packetDecoder, version int16) error { + kerr, err := pd.getInt16() + if err != nil { + return err + } + + r.Err = KError(kerr) + + numBlocks, err := pd.getArrayLength() + if err != nil { + return err + } + + r.ApiVersions = make([]*ApiVersionsResponseBlock, numBlocks) + for i := 0; i < numBlocks; i++ { + block := new(ApiVersionsResponseBlock) + if err := block.decode(pd); err != nil { + return err + } + r.ApiVersions[i] = block + } + + return nil +} + +func (r *ApiVersionsResponse) key() int16 { + return 18 +} + +func (r *ApiVersionsResponse) version() int16 { + return 0 +} + +func (r *ApiVersionsResponse) requiredVersion() KafkaVersion { + return V0_10_0_0 +} diff --git a/vendor/github.com/Shopify/sarama/api_versions_response_test.go b/vendor/github.com/Shopify/sarama/api_versions_response_test.go new file mode 100644 index 00000000..675a65a7 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/api_versions_response_test.go @@ -0,0 +1,32 @@ +package sarama + +import "testing" + +var ( + apiVersionResponse = []byte{ + 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x03, + 0x00, 0x02, + 0x00, 0x01, + } +) + +func TestApiVersionsResponse(t *testing.T) { + var response *ApiVersionsResponse + + response = new(ApiVersionsResponse) + testVersionDecodable(t, "no error", response, apiVersionResponse, 0) + if response.Err != ErrNoError { + t.Error("Decoding error failed: no error expected but found", response.Err) + } + if response.ApiVersions[0].ApiKey != 0x03 { + t.Error("Decoding error: expected 0x03 but got", response.ApiVersions[0].ApiKey) + } + if response.ApiVersions[0].MinVersion != 0x02 { + t.Error("Decoding error: expected 0x02 but got", response.ApiVersions[0].MinVersion) + } + if response.ApiVersions[0].MaxVersion != 0x01 { + t.Error("Decoding error: expected 0x01 but got", response.ApiVersions[0].MaxVersion) + } +} diff --git a/vendor/github.com/Shopify/sarama/async_producer.go b/vendor/github.com/Shopify/sarama/async_producer.go new file mode 100644 index 00000000..6d71a6d8 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/async_producer.go @@ -0,0 +1,904 @@ +package sarama + +import ( + "fmt" + "sync" + "time" + + "github.com/eapache/go-resiliency/breaker" + "github.com/eapache/queue" +) + +// AsyncProducer publishes Kafka messages using a non-blocking API. It routes messages +// to the correct broker for the provided topic-partition, refreshing metadata as appropriate, +// and parses responses for errors. You must read from the Errors() channel or the +// producer will deadlock. You must call Close() or AsyncClose() on a producer to avoid +// leaks: it will not be garbage-collected automatically when it passes out of +// scope. +type AsyncProducer interface { + + // AsyncClose triggers a shutdown of the producer. The shutdown has completed + // when both the Errors and Successes channels have been closed. When calling + // AsyncClose, you *must* continue to read from those channels in order to + // drain the results of any messages in flight. + AsyncClose() + + // Close shuts down the producer and waits for any buffered messages to be + // flushed. You must call this function before a producer object passes out of + // scope, as it may otherwise leak memory. You must call this before calling + // Close on the underlying client. + Close() error + + // Input is the input channel for the user to write messages to that they + // wish to send. + Input() chan<- *ProducerMessage + + // Successes is the success output channel back to the user when Return.Successes is + // enabled. If Return.Successes is true, you MUST read from this channel or the + // Producer will deadlock. It is suggested that you send and read messages + // together in a single select statement. + Successes() <-chan *ProducerMessage + + // Errors is the error output channel back to the user. You MUST read from this + // channel or the Producer will deadlock when the channel is full. Alternatively, + // you can set Producer.Return.Errors in your config to false, which prevents + // errors to be returned. + Errors() <-chan *ProducerError +} + +type asyncProducer struct { + client Client + conf *Config + ownClient bool + + errors chan *ProducerError + input, successes, retries chan *ProducerMessage + inFlight sync.WaitGroup + + brokers map[*Broker]chan<- *ProducerMessage + brokerRefs map[chan<- *ProducerMessage]int + brokerLock sync.Mutex +} + +// NewAsyncProducer creates a new AsyncProducer using the given broker addresses and configuration. +func NewAsyncProducer(addrs []string, conf *Config) (AsyncProducer, error) { + client, err := NewClient(addrs, conf) + if err != nil { + return nil, err + } + + p, err := NewAsyncProducerFromClient(client) + if err != nil { + return nil, err + } + p.(*asyncProducer).ownClient = true + return p, nil +} + +// NewAsyncProducerFromClient creates a new Producer using the given client. It is still +// necessary to call Close() on the underlying client when shutting down this producer. +func NewAsyncProducerFromClient(client Client) (AsyncProducer, error) { + // Check that we are not dealing with a closed Client before processing any other arguments + if client.Closed() { + return nil, ErrClosedClient + } + + p := &asyncProducer{ + client: client, + conf: client.Config(), + errors: make(chan *ProducerError), + input: make(chan *ProducerMessage), + successes: make(chan *ProducerMessage), + retries: make(chan *ProducerMessage), + brokers: make(map[*Broker]chan<- *ProducerMessage), + brokerRefs: make(map[chan<- *ProducerMessage]int), + } + + // launch our singleton dispatchers + go withRecover(p.dispatcher) + go withRecover(p.retryHandler) + + return p, nil +} + +type flagSet int8 + +const ( + syn flagSet = 1 << iota // first message from partitionProducer to brokerProducer + fin // final message from partitionProducer to brokerProducer and back + shutdown // start the shutdown process +) + +// ProducerMessage is the collection of elements passed to the Producer in order to send a message. +type ProducerMessage struct { + Topic string // The Kafka topic for this message. + // The partitioning key for this message. Pre-existing Encoders include + // StringEncoder and ByteEncoder. + Key Encoder + // The actual message to store in Kafka. Pre-existing Encoders include + // StringEncoder and ByteEncoder. + Value Encoder + + // This field is used to hold arbitrary data you wish to include so it + // will be available when receiving on the Successes and Errors channels. + // Sarama completely ignores this field and is only to be used for + // pass-through data. + Metadata interface{} + + // Below this point are filled in by the producer as the message is processed + + // Offset is the offset of the message stored on the broker. This is only + // guaranteed to be defined if the message was successfully delivered and + // RequiredAcks is not NoResponse. + Offset int64 + // Partition is the partition that the message was sent to. This is only + // guaranteed to be defined if the message was successfully delivered. + Partition int32 + // Timestamp is the timestamp assigned to the message by the broker. This + // is only guaranteed to be defined if the message was successfully + // delivered, RequiredAcks is not NoResponse, and the Kafka broker is at + // least version 0.10.0. + Timestamp time.Time + + retries int + flags flagSet +} + +const producerMessageOverhead = 26 // the metadata overhead of CRC, flags, etc. + +func (m *ProducerMessage) byteSize() int { + size := producerMessageOverhead + if m.Key != nil { + size += m.Key.Length() + } + if m.Value != nil { + size += m.Value.Length() + } + return size +} + +func (m *ProducerMessage) clear() { + m.flags = 0 + m.retries = 0 +} + +// ProducerError is the type of error generated when the producer fails to deliver a message. +// It contains the original ProducerMessage as well as the actual error value. +type ProducerError struct { + Msg *ProducerMessage + Err error +} + +func (pe ProducerError) Error() string { + return fmt.Sprintf("kafka: Failed to produce message to topic %s: %s", pe.Msg.Topic, pe.Err) +} + +// ProducerErrors is a type that wraps a batch of "ProducerError"s and implements the Error interface. +// It can be returned from the Producer's Close method to avoid the need to manually drain the Errors channel +// when closing a producer. +type ProducerErrors []*ProducerError + +func (pe ProducerErrors) Error() string { + return fmt.Sprintf("kafka: Failed to deliver %d messages.", len(pe)) +} + +func (p *asyncProducer) Errors() <-chan *ProducerError { + return p.errors +} + +func (p *asyncProducer) Successes() <-chan *ProducerMessage { + return p.successes +} + +func (p *asyncProducer) Input() chan<- *ProducerMessage { + return p.input +} + +func (p *asyncProducer) Close() error { + p.AsyncClose() + + if p.conf.Producer.Return.Successes { + go withRecover(func() { + for range p.successes { + } + }) + } + + var errors ProducerErrors + if p.conf.Producer.Return.Errors { + for event := range p.errors { + errors = append(errors, event) + } + } else { + <-p.errors + } + + if len(errors) > 0 { + return errors + } + return nil +} + +func (p *asyncProducer) AsyncClose() { + go withRecover(p.shutdown) +} + +// singleton +// dispatches messages by topic +func (p *asyncProducer) dispatcher() { + handlers := make(map[string]chan<- *ProducerMessage) + shuttingDown := false + + for msg := range p.input { + if msg == nil { + Logger.Println("Something tried to send a nil message, it was ignored.") + continue + } + + if msg.flags&shutdown != 0 { + shuttingDown = true + p.inFlight.Done() + continue + } else if msg.retries == 0 { + if shuttingDown { + // we can't just call returnError here because that decrements the wait group, + // which hasn't been incremented yet for this message, and shouldn't be + pErr := &ProducerError{Msg: msg, Err: ErrShuttingDown} + if p.conf.Producer.Return.Errors { + p.errors <- pErr + } else { + Logger.Println(pErr) + } + continue + } + p.inFlight.Add(1) + } + + if msg.byteSize() > p.conf.Producer.MaxMessageBytes { + p.returnError(msg, ErrMessageSizeTooLarge) + continue + } + + handler := handlers[msg.Topic] + if handler == nil { + handler = p.newTopicProducer(msg.Topic) + handlers[msg.Topic] = handler + } + + handler <- msg + } + + for _, handler := range handlers { + close(handler) + } +} + +// one per topic +// partitions messages, then dispatches them by partition +type topicProducer struct { + parent *asyncProducer + topic string + input <-chan *ProducerMessage + + breaker *breaker.Breaker + handlers map[int32]chan<- *ProducerMessage + partitioner Partitioner +} + +func (p *asyncProducer) newTopicProducer(topic string) chan<- *ProducerMessage { + input := make(chan *ProducerMessage, p.conf.ChannelBufferSize) + tp := &topicProducer{ + parent: p, + topic: topic, + input: input, + breaker: breaker.New(3, 1, 10*time.Second), + handlers: make(map[int32]chan<- *ProducerMessage), + partitioner: p.conf.Producer.Partitioner(topic), + } + go withRecover(tp.dispatch) + return input +} + +func (tp *topicProducer) dispatch() { + for msg := range tp.input { + if msg.retries == 0 { + if err := tp.partitionMessage(msg); err != nil { + tp.parent.returnError(msg, err) + continue + } + } + + handler := tp.handlers[msg.Partition] + if handler == nil { + handler = tp.parent.newPartitionProducer(msg.Topic, msg.Partition) + tp.handlers[msg.Partition] = handler + } + + handler <- msg + } + + for _, handler := range tp.handlers { + close(handler) + } +} + +func (tp *topicProducer) partitionMessage(msg *ProducerMessage) error { + var partitions []int32 + + err := tp.breaker.Run(func() (err error) { + if tp.partitioner.RequiresConsistency() { + partitions, err = tp.parent.client.Partitions(msg.Topic) + } else { + partitions, err = tp.parent.client.WritablePartitions(msg.Topic) + } + return + }) + + if err != nil { + return err + } + + numPartitions := int32(len(partitions)) + + if numPartitions == 0 { + return ErrLeaderNotAvailable + } + + choice, err := tp.partitioner.Partition(msg, numPartitions) + + if err != nil { + return err + } else if choice < 0 || choice >= numPartitions { + return ErrInvalidPartition + } + + msg.Partition = partitions[choice] + + return nil +} + +// one per partition per topic +// dispatches messages to the appropriate broker +// also responsible for maintaining message order during retries +type partitionProducer struct { + parent *asyncProducer + topic string + partition int32 + input <-chan *ProducerMessage + + leader *Broker + breaker *breaker.Breaker + output chan<- *ProducerMessage + + // highWatermark tracks the "current" retry level, which is the only one where we actually let messages through, + // all other messages get buffered in retryState[msg.retries].buf to preserve ordering + // retryState[msg.retries].expectChaser simply tracks whether we've seen a fin message for a given level (and + // therefore whether our buffer is complete and safe to flush) + highWatermark int + retryState []partitionRetryState +} + +type partitionRetryState struct { + buf []*ProducerMessage + expectChaser bool +} + +func (p *asyncProducer) newPartitionProducer(topic string, partition int32) chan<- *ProducerMessage { + input := make(chan *ProducerMessage, p.conf.ChannelBufferSize) + pp := &partitionProducer{ + parent: p, + topic: topic, + partition: partition, + input: input, + + breaker: breaker.New(3, 1, 10*time.Second), + retryState: make([]partitionRetryState, p.conf.Producer.Retry.Max+1), + } + go withRecover(pp.dispatch) + return input +} + +func (pp *partitionProducer) dispatch() { + // try to prefetch the leader; if this doesn't work, we'll do a proper call to `updateLeader` + // on the first message + pp.leader, _ = pp.parent.client.Leader(pp.topic, pp.partition) + if pp.leader != nil { + pp.output = pp.parent.getBrokerProducer(pp.leader) + pp.parent.inFlight.Add(1) // we're generating a syn message; track it so we don't shut down while it's still inflight + pp.output <- &ProducerMessage{Topic: pp.topic, Partition: pp.partition, flags: syn} + } + + for msg := range pp.input { + if msg.retries > pp.highWatermark { + // a new, higher, retry level; handle it and then back off + pp.newHighWatermark(msg.retries) + time.Sleep(pp.parent.conf.Producer.Retry.Backoff) + } else if pp.highWatermark > 0 { + // we are retrying something (else highWatermark would be 0) but this message is not a *new* retry level + if msg.retries < pp.highWatermark { + // in fact this message is not even the current retry level, so buffer it for now (unless it's a just a fin) + if msg.flags&fin == fin { + pp.retryState[msg.retries].expectChaser = false + pp.parent.inFlight.Done() // this fin is now handled and will be garbage collected + } else { + pp.retryState[msg.retries].buf = append(pp.retryState[msg.retries].buf, msg) + } + continue + } else if msg.flags&fin == fin { + // this message is of the current retry level (msg.retries == highWatermark) and the fin flag is set, + // meaning this retry level is done and we can go down (at least) one level and flush that + pp.retryState[pp.highWatermark].expectChaser = false + pp.flushRetryBuffers() + pp.parent.inFlight.Done() // this fin is now handled and will be garbage collected + continue + } + } + + // if we made it this far then the current msg contains real data, and can be sent to the next goroutine + // without breaking any of our ordering guarantees + + if pp.output == nil { + if err := pp.updateLeader(); err != nil { + pp.parent.returnError(msg, err) + time.Sleep(pp.parent.conf.Producer.Retry.Backoff) + continue + } + Logger.Printf("producer/leader/%s/%d selected broker %d\n", pp.topic, pp.partition, pp.leader.ID()) + } + + pp.output <- msg + } + + if pp.output != nil { + pp.parent.unrefBrokerProducer(pp.leader, pp.output) + } +} + +func (pp *partitionProducer) newHighWatermark(hwm int) { + Logger.Printf("producer/leader/%s/%d state change to [retrying-%d]\n", pp.topic, pp.partition, hwm) + pp.highWatermark = hwm + + // send off a fin so that we know when everything "in between" has made it + // back to us and we can safely flush the backlog (otherwise we risk re-ordering messages) + pp.retryState[pp.highWatermark].expectChaser = true + pp.parent.inFlight.Add(1) // we're generating a fin message; track it so we don't shut down while it's still inflight + pp.output <- &ProducerMessage{Topic: pp.topic, Partition: pp.partition, flags: fin, retries: pp.highWatermark - 1} + + // a new HWM means that our current broker selection is out of date + Logger.Printf("producer/leader/%s/%d abandoning broker %d\n", pp.topic, pp.partition, pp.leader.ID()) + pp.parent.unrefBrokerProducer(pp.leader, pp.output) + pp.output = nil +} + +func (pp *partitionProducer) flushRetryBuffers() { + Logger.Printf("producer/leader/%s/%d state change to [flushing-%d]\n", pp.topic, pp.partition, pp.highWatermark) + for { + pp.highWatermark-- + + if pp.output == nil { + if err := pp.updateLeader(); err != nil { + pp.parent.returnErrors(pp.retryState[pp.highWatermark].buf, err) + goto flushDone + } + Logger.Printf("producer/leader/%s/%d selected broker %d\n", pp.topic, pp.partition, pp.leader.ID()) + } + + for _, msg := range pp.retryState[pp.highWatermark].buf { + pp.output <- msg + } + + flushDone: + pp.retryState[pp.highWatermark].buf = nil + if pp.retryState[pp.highWatermark].expectChaser { + Logger.Printf("producer/leader/%s/%d state change to [retrying-%d]\n", pp.topic, pp.partition, pp.highWatermark) + break + } else if pp.highWatermark == 0 { + Logger.Printf("producer/leader/%s/%d state change to [normal]\n", pp.topic, pp.partition) + break + } + } +} + +func (pp *partitionProducer) updateLeader() error { + return pp.breaker.Run(func() (err error) { + if err = pp.parent.client.RefreshMetadata(pp.topic); err != nil { + return err + } + + if pp.leader, err = pp.parent.client.Leader(pp.topic, pp.partition); err != nil { + return err + } + + pp.output = pp.parent.getBrokerProducer(pp.leader) + pp.parent.inFlight.Add(1) // we're generating a syn message; track it so we don't shut down while it's still inflight + pp.output <- &ProducerMessage{Topic: pp.topic, Partition: pp.partition, flags: syn} + + return nil + }) +} + +// one per broker; also constructs an associated flusher +func (p *asyncProducer) newBrokerProducer(broker *Broker) chan<- *ProducerMessage { + var ( + input = make(chan *ProducerMessage) + bridge = make(chan *produceSet) + responses = make(chan *brokerProducerResponse) + ) + + bp := &brokerProducer{ + parent: p, + broker: broker, + input: input, + output: bridge, + responses: responses, + buffer: newProduceSet(p), + currentRetries: make(map[string]map[int32]error), + } + go withRecover(bp.run) + + // minimal bridge to make the network response `select`able + go withRecover(func() { + for set := range bridge { + request := set.buildRequest() + + response, err := broker.Produce(request) + + responses <- &brokerProducerResponse{ + set: set, + err: err, + res: response, + } + } + close(responses) + }) + + return input +} + +type brokerProducerResponse struct { + set *produceSet + err error + res *ProduceResponse +} + +// groups messages together into appropriately-sized batches for sending to the broker +// handles state related to retries etc +type brokerProducer struct { + parent *asyncProducer + broker *Broker + + input <-chan *ProducerMessage + output chan<- *produceSet + responses <-chan *brokerProducerResponse + + buffer *produceSet + timer <-chan time.Time + timerFired bool + + closing error + currentRetries map[string]map[int32]error +} + +func (bp *brokerProducer) run() { + var output chan<- *produceSet + Logger.Printf("producer/broker/%d starting up\n", bp.broker.ID()) + + for { + select { + case msg := <-bp.input: + if msg == nil { + bp.shutdown() + return + } + + if msg.flags&syn == syn { + Logger.Printf("producer/broker/%d state change to [open] on %s/%d\n", + bp.broker.ID(), msg.Topic, msg.Partition) + if bp.currentRetries[msg.Topic] == nil { + bp.currentRetries[msg.Topic] = make(map[int32]error) + } + bp.currentRetries[msg.Topic][msg.Partition] = nil + bp.parent.inFlight.Done() + continue + } + + if reason := bp.needsRetry(msg); reason != nil { + bp.parent.retryMessage(msg, reason) + + if bp.closing == nil && msg.flags&fin == fin { + // we were retrying this partition but we can start processing again + delete(bp.currentRetries[msg.Topic], msg.Partition) + Logger.Printf("producer/broker/%d state change to [closed] on %s/%d\n", + bp.broker.ID(), msg.Topic, msg.Partition) + } + + continue + } + + if bp.buffer.wouldOverflow(msg) { + if err := bp.waitForSpace(msg); err != nil { + bp.parent.retryMessage(msg, err) + continue + } + } + + if err := bp.buffer.add(msg); err != nil { + bp.parent.returnError(msg, err) + continue + } + + if bp.parent.conf.Producer.Flush.Frequency > 0 && bp.timer == nil { + bp.timer = time.After(bp.parent.conf.Producer.Flush.Frequency) + } + case <-bp.timer: + bp.timerFired = true + case output <- bp.buffer: + bp.rollOver() + case response := <-bp.responses: + bp.handleResponse(response) + } + + if bp.timerFired || bp.buffer.readyToFlush() { + output = bp.output + } else { + output = nil + } + } +} + +func (bp *brokerProducer) shutdown() { + for !bp.buffer.empty() { + select { + case response := <-bp.responses: + bp.handleResponse(response) + case bp.output <- bp.buffer: + bp.rollOver() + } + } + close(bp.output) + for response := range bp.responses { + bp.handleResponse(response) + } + + Logger.Printf("producer/broker/%d shut down\n", bp.broker.ID()) +} + +func (bp *brokerProducer) needsRetry(msg *ProducerMessage) error { + if bp.closing != nil { + return bp.closing + } + + return bp.currentRetries[msg.Topic][msg.Partition] +} + +func (bp *brokerProducer) waitForSpace(msg *ProducerMessage) error { + Logger.Printf("producer/broker/%d maximum request accumulated, waiting for space\n", bp.broker.ID()) + + for { + select { + case response := <-bp.responses: + bp.handleResponse(response) + // handling a response can change our state, so re-check some things + if reason := bp.needsRetry(msg); reason != nil { + return reason + } else if !bp.buffer.wouldOverflow(msg) { + return nil + } + case bp.output <- bp.buffer: + bp.rollOver() + return nil + } + } +} + +func (bp *brokerProducer) rollOver() { + bp.timer = nil + bp.timerFired = false + bp.buffer = newProduceSet(bp.parent) +} + +func (bp *brokerProducer) handleResponse(response *brokerProducerResponse) { + if response.err != nil { + bp.handleError(response.set, response.err) + } else { + bp.handleSuccess(response.set, response.res) + } + + if bp.buffer.empty() { + bp.rollOver() // this can happen if the response invalidated our buffer + } +} + +func (bp *brokerProducer) handleSuccess(sent *produceSet, response *ProduceResponse) { + // we iterate through the blocks in the request set, not the response, so that we notice + // if the response is missing a block completely + sent.eachPartition(func(topic string, partition int32, msgs []*ProducerMessage) { + if response == nil { + // this only happens when RequiredAcks is NoResponse, so we have to assume success + bp.parent.returnSuccesses(msgs) + return + } + + block := response.GetBlock(topic, partition) + if block == nil { + bp.parent.returnErrors(msgs, ErrIncompleteResponse) + return + } + + switch block.Err { + // Success + case ErrNoError: + if bp.parent.conf.Version.IsAtLeast(V0_10_0_0) && !block.Timestamp.IsZero() { + for _, msg := range msgs { + msg.Timestamp = block.Timestamp + } + } + for i, msg := range msgs { + msg.Offset = block.Offset + int64(i) + } + bp.parent.returnSuccesses(msgs) + // Retriable errors + case ErrInvalidMessage, ErrUnknownTopicOrPartition, ErrLeaderNotAvailable, ErrNotLeaderForPartition, + ErrRequestTimedOut, ErrNotEnoughReplicas, ErrNotEnoughReplicasAfterAppend: + Logger.Printf("producer/broker/%d state change to [retrying] on %s/%d because %v\n", + bp.broker.ID(), topic, partition, block.Err) + bp.currentRetries[topic][partition] = block.Err + bp.parent.retryMessages(msgs, block.Err) + bp.parent.retryMessages(bp.buffer.dropPartition(topic, partition), block.Err) + // Other non-retriable errors + default: + bp.parent.returnErrors(msgs, block.Err) + } + }) +} + +func (bp *brokerProducer) handleError(sent *produceSet, err error) { + switch err.(type) { + case PacketEncodingError: + sent.eachPartition(func(topic string, partition int32, msgs []*ProducerMessage) { + bp.parent.returnErrors(msgs, err) + }) + default: + Logger.Printf("producer/broker/%d state change to [closing] because %s\n", bp.broker.ID(), err) + bp.parent.abandonBrokerConnection(bp.broker) + _ = bp.broker.Close() + bp.closing = err + sent.eachPartition(func(topic string, partition int32, msgs []*ProducerMessage) { + bp.parent.retryMessages(msgs, err) + }) + bp.buffer.eachPartition(func(topic string, partition int32, msgs []*ProducerMessage) { + bp.parent.retryMessages(msgs, err) + }) + bp.rollOver() + } +} + +// singleton +// effectively a "bridge" between the flushers and the dispatcher in order to avoid deadlock +// based on https://godoc.org/github.com/eapache/channels#InfiniteChannel +func (p *asyncProducer) retryHandler() { + var msg *ProducerMessage + buf := queue.New() + + for { + if buf.Length() == 0 { + msg = <-p.retries + } else { + select { + case msg = <-p.retries: + case p.input <- buf.Peek().(*ProducerMessage): + buf.Remove() + continue + } + } + + if msg == nil { + return + } + + buf.Add(msg) + } +} + +// utility functions + +func (p *asyncProducer) shutdown() { + Logger.Println("Producer shutting down.") + p.inFlight.Add(1) + p.input <- &ProducerMessage{flags: shutdown} + + p.inFlight.Wait() + + if p.ownClient { + err := p.client.Close() + if err != nil { + Logger.Println("producer/shutdown failed to close the embedded client:", err) + } + } + + close(p.input) + close(p.retries) + close(p.errors) + close(p.successes) +} + +func (p *asyncProducer) returnError(msg *ProducerMessage, err error) { + msg.clear() + pErr := &ProducerError{Msg: msg, Err: err} + if p.conf.Producer.Return.Errors { + p.errors <- pErr + } else { + Logger.Println(pErr) + } + p.inFlight.Done() +} + +func (p *asyncProducer) returnErrors(batch []*ProducerMessage, err error) { + for _, msg := range batch { + p.returnError(msg, err) + } +} + +func (p *asyncProducer) returnSuccesses(batch []*ProducerMessage) { + for _, msg := range batch { + if p.conf.Producer.Return.Successes { + msg.clear() + p.successes <- msg + } + p.inFlight.Done() + } +} + +func (p *asyncProducer) retryMessage(msg *ProducerMessage, err error) { + if msg.retries >= p.conf.Producer.Retry.Max { + p.returnError(msg, err) + } else { + msg.retries++ + p.retries <- msg + } +} + +func (p *asyncProducer) retryMessages(batch []*ProducerMessage, err error) { + for _, msg := range batch { + p.retryMessage(msg, err) + } +} + +func (p *asyncProducer) getBrokerProducer(broker *Broker) chan<- *ProducerMessage { + p.brokerLock.Lock() + defer p.brokerLock.Unlock() + + bp := p.brokers[broker] + + if bp == nil { + bp = p.newBrokerProducer(broker) + p.brokers[broker] = bp + p.brokerRefs[bp] = 0 + } + + p.brokerRefs[bp]++ + + return bp +} + +func (p *asyncProducer) unrefBrokerProducer(broker *Broker, bp chan<- *ProducerMessage) { + p.brokerLock.Lock() + defer p.brokerLock.Unlock() + + p.brokerRefs[bp]-- + if p.brokerRefs[bp] == 0 { + close(bp) + delete(p.brokerRefs, bp) + + if p.brokers[broker] == bp { + delete(p.brokers, broker) + } + } +} + +func (p *asyncProducer) abandonBrokerConnection(broker *Broker) { + p.brokerLock.Lock() + defer p.brokerLock.Unlock() + + delete(p.brokers, broker) +} diff --git a/vendor/github.com/Shopify/sarama/async_producer_test.go b/vendor/github.com/Shopify/sarama/async_producer_test.go new file mode 100644 index 00000000..07d23533 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/async_producer_test.go @@ -0,0 +1,841 @@ +package sarama + +import ( + "errors" + "log" + "os" + "os/signal" + "sync" + "testing" + "time" +) + +const TestMessage = "ABC THE MESSAGE" + +func closeProducer(t *testing.T, p AsyncProducer) { + var wg sync.WaitGroup + p.AsyncClose() + + wg.Add(2) + go func() { + for range p.Successes() { + t.Error("Unexpected message on Successes()") + } + wg.Done() + }() + go func() { + for msg := range p.Errors() { + t.Error(msg.Err) + } + wg.Done() + }() + wg.Wait() +} + +func expectResults(t *testing.T, p AsyncProducer, successes, errors int) { + expect := successes + errors + for expect > 0 { + select { + case msg := <-p.Errors(): + if msg.Msg.flags != 0 { + t.Error("Message had flags set") + } + errors-- + expect-- + if errors < 0 { + t.Error(msg.Err) + } + case msg := <-p.Successes(): + if msg.flags != 0 { + t.Error("Message had flags set") + } + successes-- + expect-- + if successes < 0 { + t.Error("Too many successes") + } + } + } + if successes != 0 || errors != 0 { + t.Error("Unexpected successes", successes, "or errors", errors) + } +} + +type testPartitioner chan *int32 + +func (p testPartitioner) Partition(msg *ProducerMessage, numPartitions int32) (int32, error) { + part := <-p + if part == nil { + return 0, errors.New("BOOM") + } + + return *part, nil +} + +func (p testPartitioner) RequiresConsistency() bool { + return true +} + +func (p testPartitioner) feed(partition int32) { + p <- &partition +} + +type flakyEncoder bool + +func (f flakyEncoder) Length() int { + return len(TestMessage) +} + +func (f flakyEncoder) Encode() ([]byte, error) { + if !bool(f) { + return nil, errors.New("flaky encoding error") + } + return []byte(TestMessage), nil +} + +func TestAsyncProducer(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + leader := NewMockBroker(t, 2) + + metadataResponse := new(MetadataResponse) + metadataResponse.AddBroker(leader.Addr(), leader.BrokerID()) + metadataResponse.AddTopicPartition("my_topic", 0, leader.BrokerID(), nil, nil, ErrNoError) + seedBroker.Returns(metadataResponse) + + prodSuccess := new(ProduceResponse) + prodSuccess.AddTopicPartition("my_topic", 0, ErrNoError) + leader.Returns(prodSuccess) + + config := NewConfig() + config.Producer.Flush.Messages = 10 + config.Producer.Return.Successes = true + producer, err := NewAsyncProducer([]string{seedBroker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 10; i++ { + producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage), Metadata: i} + } + for i := 0; i < 10; i++ { + select { + case msg := <-producer.Errors(): + t.Error(msg.Err) + if msg.Msg.flags != 0 { + t.Error("Message had flags set") + } + case msg := <-producer.Successes(): + if msg.flags != 0 { + t.Error("Message had flags set") + } + if msg.Metadata.(int) != i { + t.Error("Message metadata did not match") + } + } + } + + closeProducer(t, producer) + leader.Close() + seedBroker.Close() +} + +func TestAsyncProducerMultipleFlushes(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + leader := NewMockBroker(t, 2) + + metadataResponse := new(MetadataResponse) + metadataResponse.AddBroker(leader.Addr(), leader.BrokerID()) + metadataResponse.AddTopicPartition("my_topic", 0, leader.BrokerID(), nil, nil, ErrNoError) + seedBroker.Returns(metadataResponse) + + prodSuccess := new(ProduceResponse) + prodSuccess.AddTopicPartition("my_topic", 0, ErrNoError) + leader.Returns(prodSuccess) + leader.Returns(prodSuccess) + leader.Returns(prodSuccess) + + config := NewConfig() + config.Producer.Flush.Messages = 5 + config.Producer.Return.Successes = true + producer, err := NewAsyncProducer([]string{seedBroker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + for flush := 0; flush < 3; flush++ { + for i := 0; i < 5; i++ { + producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage)} + } + expectResults(t, producer, 5, 0) + } + + closeProducer(t, producer) + leader.Close() + seedBroker.Close() +} + +func TestAsyncProducerMultipleBrokers(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + leader0 := NewMockBroker(t, 2) + leader1 := NewMockBroker(t, 3) + + metadataResponse := new(MetadataResponse) + metadataResponse.AddBroker(leader0.Addr(), leader0.BrokerID()) + metadataResponse.AddBroker(leader1.Addr(), leader1.BrokerID()) + metadataResponse.AddTopicPartition("my_topic", 0, leader0.BrokerID(), nil, nil, ErrNoError) + metadataResponse.AddTopicPartition("my_topic", 1, leader1.BrokerID(), nil, nil, ErrNoError) + seedBroker.Returns(metadataResponse) + + prodResponse0 := new(ProduceResponse) + prodResponse0.AddTopicPartition("my_topic", 0, ErrNoError) + leader0.Returns(prodResponse0) + + prodResponse1 := new(ProduceResponse) + prodResponse1.AddTopicPartition("my_topic", 1, ErrNoError) + leader1.Returns(prodResponse1) + + config := NewConfig() + config.Producer.Flush.Messages = 5 + config.Producer.Return.Successes = true + config.Producer.Partitioner = NewRoundRobinPartitioner + producer, err := NewAsyncProducer([]string{seedBroker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 10; i++ { + producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage)} + } + expectResults(t, producer, 10, 0) + + closeProducer(t, producer) + leader1.Close() + leader0.Close() + seedBroker.Close() +} + +func TestAsyncProducerCustomPartitioner(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + leader := NewMockBroker(t, 2) + + metadataResponse := new(MetadataResponse) + metadataResponse.AddBroker(leader.Addr(), leader.BrokerID()) + metadataResponse.AddTopicPartition("my_topic", 0, leader.BrokerID(), nil, nil, ErrNoError) + seedBroker.Returns(metadataResponse) + + prodResponse := new(ProduceResponse) + prodResponse.AddTopicPartition("my_topic", 0, ErrNoError) + leader.Returns(prodResponse) + + config := NewConfig() + config.Producer.Flush.Messages = 2 + config.Producer.Return.Successes = true + config.Producer.Partitioner = func(topic string) Partitioner { + p := make(testPartitioner) + go func() { + p.feed(0) + p <- nil + p <- nil + p <- nil + p.feed(0) + }() + return p + } + producer, err := NewAsyncProducer([]string{seedBroker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 5; i++ { + producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage)} + } + expectResults(t, producer, 2, 3) + + closeProducer(t, producer) + leader.Close() + seedBroker.Close() +} + +func TestAsyncProducerFailureRetry(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + leader1 := NewMockBroker(t, 2) + leader2 := NewMockBroker(t, 3) + + metadataLeader1 := new(MetadataResponse) + metadataLeader1.AddBroker(leader1.Addr(), leader1.BrokerID()) + metadataLeader1.AddTopicPartition("my_topic", 0, leader1.BrokerID(), nil, nil, ErrNoError) + seedBroker.Returns(metadataLeader1) + + config := NewConfig() + config.Producer.Flush.Messages = 10 + config.Producer.Return.Successes = true + config.Producer.Retry.Backoff = 0 + producer, err := NewAsyncProducer([]string{seedBroker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + seedBroker.Close() + + for i := 0; i < 10; i++ { + producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage)} + } + prodNotLeader := new(ProduceResponse) + prodNotLeader.AddTopicPartition("my_topic", 0, ErrNotLeaderForPartition) + leader1.Returns(prodNotLeader) + + metadataLeader2 := new(MetadataResponse) + metadataLeader2.AddBroker(leader2.Addr(), leader2.BrokerID()) + metadataLeader2.AddTopicPartition("my_topic", 0, leader2.BrokerID(), nil, nil, ErrNoError) + leader1.Returns(metadataLeader2) + + prodSuccess := new(ProduceResponse) + prodSuccess.AddTopicPartition("my_topic", 0, ErrNoError) + leader2.Returns(prodSuccess) + expectResults(t, producer, 10, 0) + leader1.Close() + + for i := 0; i < 10; i++ { + producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage)} + } + leader2.Returns(prodSuccess) + expectResults(t, producer, 10, 0) + + leader2.Close() + closeProducer(t, producer) +} + +func TestAsyncProducerEncoderFailures(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + leader := NewMockBroker(t, 2) + + metadataResponse := new(MetadataResponse) + metadataResponse.AddBroker(leader.Addr(), leader.BrokerID()) + metadataResponse.AddTopicPartition("my_topic", 0, leader.BrokerID(), nil, nil, ErrNoError) + seedBroker.Returns(metadataResponse) + + prodSuccess := new(ProduceResponse) + prodSuccess.AddTopicPartition("my_topic", 0, ErrNoError) + leader.Returns(prodSuccess) + leader.Returns(prodSuccess) + leader.Returns(prodSuccess) + + config := NewConfig() + config.Producer.Flush.Messages = 1 + config.Producer.Return.Successes = true + config.Producer.Partitioner = NewManualPartitioner + producer, err := NewAsyncProducer([]string{seedBroker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + for flush := 0; flush < 3; flush++ { + producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: flakyEncoder(true), Value: flakyEncoder(false)} + producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: flakyEncoder(false), Value: flakyEncoder(true)} + producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: flakyEncoder(true), Value: flakyEncoder(true)} + expectResults(t, producer, 1, 2) + } + + closeProducer(t, producer) + leader.Close() + seedBroker.Close() +} + +// If a Kafka broker becomes unavailable and then returns back in service, then +// producer reconnects to it and continues sending messages. +func TestAsyncProducerBrokerBounce(t *testing.T) { + // Given + seedBroker := NewMockBroker(t, 1) + leader := NewMockBroker(t, 2) + leaderAddr := leader.Addr() + + metadataResponse := new(MetadataResponse) + metadataResponse.AddBroker(leaderAddr, leader.BrokerID()) + metadataResponse.AddTopicPartition("my_topic", 0, leader.BrokerID(), nil, nil, ErrNoError) + seedBroker.Returns(metadataResponse) + + prodSuccess := new(ProduceResponse) + prodSuccess.AddTopicPartition("my_topic", 0, ErrNoError) + + config := NewConfig() + config.Producer.Flush.Messages = 1 + config.Producer.Return.Successes = true + config.Producer.Retry.Backoff = 0 + producer, err := NewAsyncProducer([]string{seedBroker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage)} + leader.Returns(prodSuccess) + expectResults(t, producer, 1, 0) + + // When: a broker connection gets reset by a broker (network glitch, restart, you name it). + leader.Close() // producer should get EOF + leader = NewMockBrokerAddr(t, 2, leaderAddr) // start it up again right away for giggles + seedBroker.Returns(metadataResponse) // tell it to go to broker 2 again + + // Then: a produced message goes through the new broker connection. + producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage)} + leader.Returns(prodSuccess) + expectResults(t, producer, 1, 0) + + closeProducer(t, producer) + seedBroker.Close() + leader.Close() +} + +func TestAsyncProducerBrokerBounceWithStaleMetadata(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + leader1 := NewMockBroker(t, 2) + leader2 := NewMockBroker(t, 3) + + metadataLeader1 := new(MetadataResponse) + metadataLeader1.AddBroker(leader1.Addr(), leader1.BrokerID()) + metadataLeader1.AddTopicPartition("my_topic", 0, leader1.BrokerID(), nil, nil, ErrNoError) + seedBroker.Returns(metadataLeader1) + + config := NewConfig() + config.Producer.Flush.Messages = 10 + config.Producer.Return.Successes = true + config.Producer.Retry.Max = 3 + config.Producer.Retry.Backoff = 0 + producer, err := NewAsyncProducer([]string{seedBroker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 10; i++ { + producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage)} + } + leader1.Close() // producer should get EOF + seedBroker.Returns(metadataLeader1) // tell it to go to leader1 again even though it's still down + seedBroker.Returns(metadataLeader1) // tell it to go to leader1 again even though it's still down + + // ok fine, tell it to go to leader2 finally + metadataLeader2 := new(MetadataResponse) + metadataLeader2.AddBroker(leader2.Addr(), leader2.BrokerID()) + metadataLeader2.AddTopicPartition("my_topic", 0, leader2.BrokerID(), nil, nil, ErrNoError) + seedBroker.Returns(metadataLeader2) + + prodSuccess := new(ProduceResponse) + prodSuccess.AddTopicPartition("my_topic", 0, ErrNoError) + leader2.Returns(prodSuccess) + expectResults(t, producer, 10, 0) + seedBroker.Close() + leader2.Close() + + closeProducer(t, producer) +} + +func TestAsyncProducerMultipleRetries(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + leader1 := NewMockBroker(t, 2) + leader2 := NewMockBroker(t, 3) + + metadataLeader1 := new(MetadataResponse) + metadataLeader1.AddBroker(leader1.Addr(), leader1.BrokerID()) + metadataLeader1.AddTopicPartition("my_topic", 0, leader1.BrokerID(), nil, nil, ErrNoError) + seedBroker.Returns(metadataLeader1) + + config := NewConfig() + config.Producer.Flush.Messages = 10 + config.Producer.Return.Successes = true + config.Producer.Retry.Max = 4 + config.Producer.Retry.Backoff = 0 + producer, err := NewAsyncProducer([]string{seedBroker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 10; i++ { + producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage)} + } + prodNotLeader := new(ProduceResponse) + prodNotLeader.AddTopicPartition("my_topic", 0, ErrNotLeaderForPartition) + leader1.Returns(prodNotLeader) + + metadataLeader2 := new(MetadataResponse) + metadataLeader2.AddBroker(leader2.Addr(), leader2.BrokerID()) + metadataLeader2.AddTopicPartition("my_topic", 0, leader2.BrokerID(), nil, nil, ErrNoError) + seedBroker.Returns(metadataLeader2) + leader2.Returns(prodNotLeader) + seedBroker.Returns(metadataLeader1) + leader1.Returns(prodNotLeader) + seedBroker.Returns(metadataLeader1) + leader1.Returns(prodNotLeader) + seedBroker.Returns(metadataLeader2) + + prodSuccess := new(ProduceResponse) + prodSuccess.AddTopicPartition("my_topic", 0, ErrNoError) + leader2.Returns(prodSuccess) + expectResults(t, producer, 10, 0) + + for i := 0; i < 10; i++ { + producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage)} + } + leader2.Returns(prodSuccess) + expectResults(t, producer, 10, 0) + + seedBroker.Close() + leader1.Close() + leader2.Close() + closeProducer(t, producer) +} + +func TestAsyncProducerOutOfRetries(t *testing.T) { + t.Skip("Enable once bug #294 is fixed.") + + seedBroker := NewMockBroker(t, 1) + leader := NewMockBroker(t, 2) + + metadataResponse := new(MetadataResponse) + metadataResponse.AddBroker(leader.Addr(), leader.BrokerID()) + metadataResponse.AddTopicPartition("my_topic", 0, leader.BrokerID(), nil, nil, ErrNoError) + seedBroker.Returns(metadataResponse) + + config := NewConfig() + config.Producer.Flush.Messages = 10 + config.Producer.Return.Successes = true + config.Producer.Retry.Backoff = 0 + config.Producer.Retry.Max = 0 + producer, err := NewAsyncProducer([]string{seedBroker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 10; i++ { + producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage)} + } + + prodNotLeader := new(ProduceResponse) + prodNotLeader.AddTopicPartition("my_topic", 0, ErrNotLeaderForPartition) + leader.Returns(prodNotLeader) + + for i := 0; i < 10; i++ { + select { + case msg := <-producer.Errors(): + if msg.Err != ErrNotLeaderForPartition { + t.Error(msg.Err) + } + case <-producer.Successes(): + t.Error("Unexpected success") + } + } + + seedBroker.Returns(metadataResponse) + + for i := 0; i < 10; i++ { + producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage)} + } + + prodSuccess := new(ProduceResponse) + prodSuccess.AddTopicPartition("my_topic", 0, ErrNoError) + leader.Returns(prodSuccess) + + expectResults(t, producer, 10, 0) + + leader.Close() + seedBroker.Close() + safeClose(t, producer) +} + +func TestAsyncProducerRetryWithReferenceOpen(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + leader := NewMockBroker(t, 2) + leaderAddr := leader.Addr() + + metadataResponse := new(MetadataResponse) + metadataResponse.AddBroker(leaderAddr, leader.BrokerID()) + metadataResponse.AddTopicPartition("my_topic", 0, leader.BrokerID(), nil, nil, ErrNoError) + metadataResponse.AddTopicPartition("my_topic", 1, leader.BrokerID(), nil, nil, ErrNoError) + seedBroker.Returns(metadataResponse) + + config := NewConfig() + config.Producer.Return.Successes = true + config.Producer.Retry.Backoff = 0 + config.Producer.Retry.Max = 1 + config.Producer.Partitioner = NewRoundRobinPartitioner + producer, err := NewAsyncProducer([]string{seedBroker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + // prime partition 0 + producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage)} + prodSuccess := new(ProduceResponse) + prodSuccess.AddTopicPartition("my_topic", 0, ErrNoError) + leader.Returns(prodSuccess) + expectResults(t, producer, 1, 0) + + // prime partition 1 + producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage)} + prodSuccess = new(ProduceResponse) + prodSuccess.AddTopicPartition("my_topic", 1, ErrNoError) + leader.Returns(prodSuccess) + expectResults(t, producer, 1, 0) + + // reboot the broker (the producer will get EOF on its existing connection) + leader.Close() + leader = NewMockBrokerAddr(t, 2, leaderAddr) + + // send another message on partition 0 to trigger the EOF and retry + producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage)} + + // tell partition 0 to go to that broker again + seedBroker.Returns(metadataResponse) + + // succeed this time + prodSuccess = new(ProduceResponse) + prodSuccess.AddTopicPartition("my_topic", 0, ErrNoError) + leader.Returns(prodSuccess) + expectResults(t, producer, 1, 0) + + // shutdown + closeProducer(t, producer) + seedBroker.Close() + leader.Close() +} + +func TestAsyncProducerFlusherRetryCondition(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + leader := NewMockBroker(t, 2) + + metadataResponse := new(MetadataResponse) + metadataResponse.AddBroker(leader.Addr(), leader.BrokerID()) + metadataResponse.AddTopicPartition("my_topic", 0, leader.BrokerID(), nil, nil, ErrNoError) + metadataResponse.AddTopicPartition("my_topic", 1, leader.BrokerID(), nil, nil, ErrNoError) + seedBroker.Returns(metadataResponse) + + config := NewConfig() + config.Producer.Flush.Messages = 5 + config.Producer.Return.Successes = true + config.Producer.Retry.Backoff = 0 + config.Producer.Retry.Max = 1 + config.Producer.Partitioner = NewManualPartitioner + producer, err := NewAsyncProducer([]string{seedBroker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + // prime partitions + for p := int32(0); p < 2; p++ { + for i := 0; i < 5; i++ { + producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage), Partition: p} + } + prodSuccess := new(ProduceResponse) + prodSuccess.AddTopicPartition("my_topic", p, ErrNoError) + leader.Returns(prodSuccess) + expectResults(t, producer, 5, 0) + } + + // send more messages on partition 0 + for i := 0; i < 5; i++ { + producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage), Partition: 0} + } + prodNotLeader := new(ProduceResponse) + prodNotLeader.AddTopicPartition("my_topic", 0, ErrNotLeaderForPartition) + leader.Returns(prodNotLeader) + + time.Sleep(50 * time.Millisecond) + + leader.SetHandlerByMap(map[string]MockResponse{ + "ProduceRequest": NewMockProduceResponse(t). + SetError("my_topic", 0, ErrNoError), + }) + + // tell partition 0 to go to that broker again + seedBroker.Returns(metadataResponse) + + // succeed this time + expectResults(t, producer, 5, 0) + + // put five more through + for i := 0; i < 5; i++ { + producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage), Partition: 0} + } + expectResults(t, producer, 5, 0) + + // shutdown + closeProducer(t, producer) + seedBroker.Close() + leader.Close() +} + +func TestAsyncProducerRetryShutdown(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + leader := NewMockBroker(t, 2) + + metadataLeader := new(MetadataResponse) + metadataLeader.AddBroker(leader.Addr(), leader.BrokerID()) + metadataLeader.AddTopicPartition("my_topic", 0, leader.BrokerID(), nil, nil, ErrNoError) + seedBroker.Returns(metadataLeader) + + config := NewConfig() + config.Producer.Flush.Messages = 10 + config.Producer.Return.Successes = true + config.Producer.Retry.Backoff = 0 + producer, err := NewAsyncProducer([]string{seedBroker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 10; i++ { + producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage)} + } + producer.AsyncClose() + time.Sleep(5 * time.Millisecond) // let the shutdown goroutine kick in + + producer.Input() <- &ProducerMessage{Topic: "FOO"} + if err := <-producer.Errors(); err.Err != ErrShuttingDown { + t.Error(err) + } + + prodNotLeader := new(ProduceResponse) + prodNotLeader.AddTopicPartition("my_topic", 0, ErrNotLeaderForPartition) + leader.Returns(prodNotLeader) + + seedBroker.Returns(metadataLeader) + + prodSuccess := new(ProduceResponse) + prodSuccess.AddTopicPartition("my_topic", 0, ErrNoError) + leader.Returns(prodSuccess) + expectResults(t, producer, 10, 0) + + seedBroker.Close() + leader.Close() + + // wait for the async-closed producer to shut down fully + for err := range producer.Errors() { + t.Error(err) + } +} + +func TestAsyncProducerNoReturns(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + leader := NewMockBroker(t, 2) + + metadataLeader := new(MetadataResponse) + metadataLeader.AddBroker(leader.Addr(), leader.BrokerID()) + metadataLeader.AddTopicPartition("my_topic", 0, leader.BrokerID(), nil, nil, ErrNoError) + seedBroker.Returns(metadataLeader) + + config := NewConfig() + config.Producer.Flush.Messages = 10 + config.Producer.Return.Successes = false + config.Producer.Return.Errors = false + config.Producer.Retry.Backoff = 0 + producer, err := NewAsyncProducer([]string{seedBroker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 10; i++ { + producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage)} + } + + wait := make(chan bool) + go func() { + if err := producer.Close(); err != nil { + t.Error(err) + } + close(wait) + }() + + prodSuccess := new(ProduceResponse) + prodSuccess.AddTopicPartition("my_topic", 0, ErrNoError) + leader.Returns(prodSuccess) + + <-wait + seedBroker.Close() + leader.Close() +} + +// This example shows how to use the producer while simultaneously +// reading the Errors channel to know about any failures. +func ExampleAsyncProducer_select() { + producer, err := NewAsyncProducer([]string{"localhost:9092"}, nil) + if err != nil { + panic(err) + } + + defer func() { + if err := producer.Close(); err != nil { + log.Fatalln(err) + } + }() + + // Trap SIGINT to trigger a shutdown. + signals := make(chan os.Signal, 1) + signal.Notify(signals, os.Interrupt) + + var enqueued, errors int +ProducerLoop: + for { + select { + case producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder("testing 123")}: + enqueued++ + case err := <-producer.Errors(): + log.Println("Failed to produce message", err) + errors++ + case <-signals: + break ProducerLoop + } + } + + log.Printf("Enqueued: %d; errors: %d\n", enqueued, errors) +} + +// This example shows how to use the producer with separate goroutines +// reading from the Successes and Errors channels. Note that in order +// for the Successes channel to be populated, you have to set +// config.Producer.Return.Successes to true. +func ExampleAsyncProducer_goroutines() { + config := NewConfig() + config.Producer.Return.Successes = true + producer, err := NewAsyncProducer([]string{"localhost:9092"}, config) + if err != nil { + panic(err) + } + + // Trap SIGINT to trigger a graceful shutdown. + signals := make(chan os.Signal, 1) + signal.Notify(signals, os.Interrupt) + + var ( + wg sync.WaitGroup + enqueued, successes, errors int + ) + + wg.Add(1) + go func() { + defer wg.Done() + for range producer.Successes() { + successes++ + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + for err := range producer.Errors() { + log.Println(err) + errors++ + } + }() + +ProducerLoop: + for { + message := &ProducerMessage{Topic: "my_topic", Value: StringEncoder("testing 123")} + select { + case producer.Input() <- message: + enqueued++ + + case <-signals: + producer.AsyncClose() // Trigger a shutdown of the producer. + break ProducerLoop + } + } + + wg.Wait() + + log.Printf("Successfully produced: %d; errors: %d\n", successes, errors) +} diff --git a/vendor/github.com/Shopify/sarama/broker.go b/vendor/github.com/Shopify/sarama/broker.go new file mode 100644 index 00000000..f57a6909 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/broker.go @@ -0,0 +1,685 @@ +package sarama + +import ( + "crypto/tls" + "encoding/binary" + "fmt" + "io" + "net" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/rcrowley/go-metrics" +) + +// Broker represents a single Kafka broker connection. All operations on this object are entirely concurrency-safe. +type Broker struct { + id int32 + addr string + + conf *Config + correlationID int32 + conn net.Conn + connErr error + lock sync.Mutex + opened int32 + + responses chan responsePromise + done chan bool + + incomingByteRate metrics.Meter + requestRate metrics.Meter + requestSize metrics.Histogram + requestLatency metrics.Histogram + outgoingByteRate metrics.Meter + responseRate metrics.Meter + responseSize metrics.Histogram + brokerIncomingByteRate metrics.Meter + brokerRequestRate metrics.Meter + brokerRequestSize metrics.Histogram + brokerRequestLatency metrics.Histogram + brokerOutgoingByteRate metrics.Meter + brokerResponseRate metrics.Meter + brokerResponseSize metrics.Histogram +} + +type responsePromise struct { + requestTime time.Time + correlationID int32 + packets chan []byte + errors chan error +} + +// NewBroker creates and returns a Broker targeting the given host:port address. +// This does not attempt to actually connect, you have to call Open() for that. +func NewBroker(addr string) *Broker { + return &Broker{id: -1, addr: addr} +} + +// Open tries to connect to the Broker if it is not already connected or connecting, but does not block +// waiting for the connection to complete. This means that any subsequent operations on the broker will +// block waiting for the connection to succeed or fail. To get the effect of a fully synchronous Open call, +// follow it by a call to Connected(). The only errors Open will return directly are ConfigurationError or +// AlreadyConnected. If conf is nil, the result of NewConfig() is used. +func (b *Broker) Open(conf *Config) error { + if !atomic.CompareAndSwapInt32(&b.opened, 0, 1) { + return ErrAlreadyConnected + } + + if conf == nil { + conf = NewConfig() + } + + err := conf.Validate() + if err != nil { + return err + } + + b.lock.Lock() + + go withRecover(func() { + defer b.lock.Unlock() + + dialer := net.Dialer{ + Timeout: conf.Net.DialTimeout, + KeepAlive: conf.Net.KeepAlive, + } + + if conf.Net.TLS.Enable { + b.conn, b.connErr = tls.DialWithDialer(&dialer, "tcp", b.addr, conf.Net.TLS.Config) + } else { + b.conn, b.connErr = dialer.Dial("tcp", b.addr) + } + if b.connErr != nil { + Logger.Printf("Failed to connect to broker %s: %s\n", b.addr, b.connErr) + b.conn = nil + atomic.StoreInt32(&b.opened, 0) + return + } + b.conn = newBufConn(b.conn) + + b.conf = conf + + // Create or reuse the global metrics shared between brokers + b.incomingByteRate = metrics.GetOrRegisterMeter("incoming-byte-rate", conf.MetricRegistry) + b.requestRate = metrics.GetOrRegisterMeter("request-rate", conf.MetricRegistry) + b.requestSize = getOrRegisterHistogram("request-size", conf.MetricRegistry) + b.requestLatency = getOrRegisterHistogram("request-latency-in-ms", conf.MetricRegistry) + b.outgoingByteRate = metrics.GetOrRegisterMeter("outgoing-byte-rate", conf.MetricRegistry) + b.responseRate = metrics.GetOrRegisterMeter("response-rate", conf.MetricRegistry) + b.responseSize = getOrRegisterHistogram("response-size", conf.MetricRegistry) + // Do not gather metrics for seeded broker (only used during bootstrap) because they share + // the same id (-1) and are already exposed through the global metrics above + if b.id >= 0 { + b.brokerIncomingByteRate = getOrRegisterBrokerMeter("incoming-byte-rate", b, conf.MetricRegistry) + b.brokerRequestRate = getOrRegisterBrokerMeter("request-rate", b, conf.MetricRegistry) + b.brokerRequestSize = getOrRegisterBrokerHistogram("request-size", b, conf.MetricRegistry) + b.brokerRequestLatency = getOrRegisterBrokerHistogram("request-latency-in-ms", b, conf.MetricRegistry) + b.brokerOutgoingByteRate = getOrRegisterBrokerMeter("outgoing-byte-rate", b, conf.MetricRegistry) + b.brokerResponseRate = getOrRegisterBrokerMeter("response-rate", b, conf.MetricRegistry) + b.brokerResponseSize = getOrRegisterBrokerHistogram("response-size", b, conf.MetricRegistry) + } + + if conf.Net.SASL.Enable { + b.connErr = b.sendAndReceiveSASLPlainAuth() + if b.connErr != nil { + err = b.conn.Close() + if err == nil { + Logger.Printf("Closed connection to broker %s\n", b.addr) + } else { + Logger.Printf("Error while closing connection to broker %s: %s\n", b.addr, err) + } + b.conn = nil + atomic.StoreInt32(&b.opened, 0) + return + } + } + + b.done = make(chan bool) + b.responses = make(chan responsePromise, b.conf.Net.MaxOpenRequests-1) + + if b.id >= 0 { + Logger.Printf("Connected to broker at %s (registered as #%d)\n", b.addr, b.id) + } else { + Logger.Printf("Connected to broker at %s (unregistered)\n", b.addr) + } + go withRecover(b.responseReceiver) + }) + + return nil +} + +// Connected returns true if the broker is connected and false otherwise. If the broker is not +// connected but it had tried to connect, the error from that connection attempt is also returned. +func (b *Broker) Connected() (bool, error) { + b.lock.Lock() + defer b.lock.Unlock() + + return b.conn != nil, b.connErr +} + +func (b *Broker) Close() error { + b.lock.Lock() + defer b.lock.Unlock() + + if b.conn == nil { + return ErrNotConnected + } + + close(b.responses) + <-b.done + + err := b.conn.Close() + + b.conn = nil + b.connErr = nil + b.done = nil + b.responses = nil + + if err == nil { + Logger.Printf("Closed connection to broker %s\n", b.addr) + } else { + Logger.Printf("Error while closing connection to broker %s: %s\n", b.addr, err) + } + + atomic.StoreInt32(&b.opened, 0) + + return err +} + +// ID returns the broker ID retrieved from Kafka's metadata, or -1 if that is not known. +func (b *Broker) ID() int32 { + return b.id +} + +// Addr returns the broker address as either retrieved from Kafka's metadata or passed to NewBroker. +func (b *Broker) Addr() string { + return b.addr +} + +func (b *Broker) GetMetadata(request *MetadataRequest) (*MetadataResponse, error) { + response := new(MetadataResponse) + + err := b.sendAndReceive(request, response) + + if err != nil { + return nil, err + } + + return response, nil +} + +func (b *Broker) GetConsumerMetadata(request *ConsumerMetadataRequest) (*ConsumerMetadataResponse, error) { + response := new(ConsumerMetadataResponse) + + err := b.sendAndReceive(request, response) + + if err != nil { + return nil, err + } + + return response, nil +} + +func (b *Broker) GetAvailableOffsets(request *OffsetRequest) (*OffsetResponse, error) { + response := new(OffsetResponse) + + err := b.sendAndReceive(request, response) + + if err != nil { + return nil, err + } + + return response, nil +} + +func (b *Broker) Produce(request *ProduceRequest) (*ProduceResponse, error) { + var response *ProduceResponse + var err error + + if request.RequiredAcks == NoResponse { + err = b.sendAndReceive(request, nil) + } else { + response = new(ProduceResponse) + err = b.sendAndReceive(request, response) + } + + if err != nil { + return nil, err + } + + return response, nil +} + +func (b *Broker) Fetch(request *FetchRequest) (*FetchResponse, error) { + response := new(FetchResponse) + + err := b.sendAndReceive(request, response) + + if err != nil { + return nil, err + } + + return response, nil +} + +func (b *Broker) CommitOffset(request *OffsetCommitRequest) (*OffsetCommitResponse, error) { + response := new(OffsetCommitResponse) + + err := b.sendAndReceive(request, response) + + if err != nil { + return nil, err + } + + return response, nil +} + +func (b *Broker) FetchOffset(request *OffsetFetchRequest) (*OffsetFetchResponse, error) { + response := new(OffsetFetchResponse) + + err := b.sendAndReceive(request, response) + + if err != nil { + return nil, err + } + + return response, nil +} + +func (b *Broker) JoinGroup(request *JoinGroupRequest) (*JoinGroupResponse, error) { + response := new(JoinGroupResponse) + + err := b.sendAndReceive(request, response) + if err != nil { + return nil, err + } + + return response, nil +} + +func (b *Broker) SyncGroup(request *SyncGroupRequest) (*SyncGroupResponse, error) { + response := new(SyncGroupResponse) + + err := b.sendAndReceive(request, response) + if err != nil { + return nil, err + } + + return response, nil +} + +func (b *Broker) LeaveGroup(request *LeaveGroupRequest) (*LeaveGroupResponse, error) { + response := new(LeaveGroupResponse) + + err := b.sendAndReceive(request, response) + if err != nil { + return nil, err + } + + return response, nil +} + +func (b *Broker) Heartbeat(request *HeartbeatRequest) (*HeartbeatResponse, error) { + response := new(HeartbeatResponse) + + err := b.sendAndReceive(request, response) + if err != nil { + return nil, err + } + + return response, nil +} + +func (b *Broker) ListGroups(request *ListGroupsRequest) (*ListGroupsResponse, error) { + response := new(ListGroupsResponse) + + err := b.sendAndReceive(request, response) + if err != nil { + return nil, err + } + + return response, nil +} + +func (b *Broker) DescribeGroups(request *DescribeGroupsRequest) (*DescribeGroupsResponse, error) { + response := new(DescribeGroupsResponse) + + err := b.sendAndReceive(request, response) + if err != nil { + return nil, err + } + + return response, nil +} + +func (b *Broker) ApiVersions(request *ApiVersionsRequest) (*ApiVersionsResponse, error) { + response := new(ApiVersionsResponse) + + err := b.sendAndReceive(request, response) + if err != nil { + return nil, err + } + + return response, nil +} + +func (b *Broker) send(rb protocolBody, promiseResponse bool) (*responsePromise, error) { + b.lock.Lock() + defer b.lock.Unlock() + + if b.conn == nil { + if b.connErr != nil { + return nil, b.connErr + } + return nil, ErrNotConnected + } + + if !b.conf.Version.IsAtLeast(rb.requiredVersion()) { + return nil, ErrUnsupportedVersion + } + + req := &request{correlationID: b.correlationID, clientID: b.conf.ClientID, body: rb} + buf, err := encode(req, b.conf.MetricRegistry) + if err != nil { + return nil, err + } + + err = b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout)) + if err != nil { + return nil, err + } + + requestTime := time.Now() + bytes, err := b.conn.Write(buf) + b.updateOutgoingCommunicationMetrics(bytes) + if err != nil { + return nil, err + } + b.correlationID++ + + if !promiseResponse { + // Record request latency without the response + b.updateRequestLatencyMetrics(time.Since(requestTime)) + return nil, nil + } + + promise := responsePromise{requestTime, req.correlationID, make(chan []byte), make(chan error)} + b.responses <- promise + + return &promise, nil +} + +func (b *Broker) sendAndReceive(req protocolBody, res versionedDecoder) error { + promise, err := b.send(req, res != nil) + + if err != nil { + return err + } + + if promise == nil { + return nil + } + + select { + case buf := <-promise.packets: + return versionedDecode(buf, res, req.version()) + case err = <-promise.errors: + return err + } +} + +func (b *Broker) decode(pd packetDecoder) (err error) { + b.id, err = pd.getInt32() + if err != nil { + return err + } + + host, err := pd.getString() + if err != nil { + return err + } + + port, err := pd.getInt32() + if err != nil { + return err + } + + b.addr = net.JoinHostPort(host, fmt.Sprint(port)) + if _, _, err := net.SplitHostPort(b.addr); err != nil { + return err + } + + return nil +} + +func (b *Broker) encode(pe packetEncoder) (err error) { + + host, portstr, err := net.SplitHostPort(b.addr) + if err != nil { + return err + } + port, err := strconv.Atoi(portstr) + if err != nil { + return err + } + + pe.putInt32(b.id) + + err = pe.putString(host) + if err != nil { + return err + } + + pe.putInt32(int32(port)) + + return nil +} + +func (b *Broker) responseReceiver() { + var dead error + header := make([]byte, 8) + for response := range b.responses { + if dead != nil { + response.errors <- dead + continue + } + + err := b.conn.SetReadDeadline(time.Now().Add(b.conf.Net.ReadTimeout)) + if err != nil { + dead = err + response.errors <- err + continue + } + + bytesReadHeader, err := io.ReadFull(b.conn, header) + requestLatency := time.Since(response.requestTime) + if err != nil { + b.updateIncomingCommunicationMetrics(bytesReadHeader, requestLatency) + dead = err + response.errors <- err + continue + } + + decodedHeader := responseHeader{} + err = decode(header, &decodedHeader) + if err != nil { + b.updateIncomingCommunicationMetrics(bytesReadHeader, requestLatency) + dead = err + response.errors <- err + continue + } + if decodedHeader.correlationID != response.correlationID { + b.updateIncomingCommunicationMetrics(bytesReadHeader, requestLatency) + // TODO if decoded ID < cur ID, discard until we catch up + // TODO if decoded ID > cur ID, save it so when cur ID catches up we have a response + dead = PacketDecodingError{fmt.Sprintf("correlation ID didn't match, wanted %d, got %d", response.correlationID, decodedHeader.correlationID)} + response.errors <- dead + continue + } + + buf := make([]byte, decodedHeader.length-4) + bytesReadBody, err := io.ReadFull(b.conn, buf) + b.updateIncomingCommunicationMetrics(bytesReadHeader+bytesReadBody, requestLatency) + if err != nil { + dead = err + response.errors <- err + continue + } + + response.packets <- buf + } + close(b.done) +} + +func (b *Broker) sendAndReceiveSASLPlainHandshake() error { + rb := &SaslHandshakeRequest{"PLAIN"} + req := &request{correlationID: b.correlationID, clientID: b.conf.ClientID, body: rb} + buf, err := encode(req, b.conf.MetricRegistry) + if err != nil { + return err + } + + err = b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout)) + if err != nil { + return err + } + + requestTime := time.Now() + bytes, err := b.conn.Write(buf) + b.updateOutgoingCommunicationMetrics(bytes) + if err != nil { + Logger.Printf("Failed to send SASL handshake %s: %s\n", b.addr, err.Error()) + return err + } + b.correlationID++ + //wait for the response + header := make([]byte, 8) // response header + _, err = io.ReadFull(b.conn, header) + if err != nil { + Logger.Printf("Failed to read SASL handshake header : %s\n", err.Error()) + return err + } + length := binary.BigEndian.Uint32(header[:4]) + payload := make([]byte, length-4) + n, err := io.ReadFull(b.conn, payload) + if err != nil { + Logger.Printf("Failed to read SASL handshake payload : %s\n", err.Error()) + return err + } + b.updateIncomingCommunicationMetrics(n+8, time.Since(requestTime)) + res := &SaslHandshakeResponse{} + err = versionedDecode(payload, res, 0) + if err != nil { + Logger.Printf("Failed to parse SASL handshake : %s\n", err.Error()) + return err + } + if res.Err != ErrNoError { + Logger.Printf("Invalid SASL Mechanism : %s\n", res.Err.Error()) + return res.Err + } + Logger.Print("Successful SASL handshake") + return nil +} + +// Kafka 0.10.0 plans to support SASL Plain and Kerberos as per PR #812 (KIP-43)/(JIRA KAFKA-3149) +// Some hosted kafka services such as IBM Message Hub already offer SASL/PLAIN auth with Kafka 0.9 +// +// In SASL Plain, Kafka expects the auth header to be in the following format +// Message format (from https://tools.ietf.org/html/rfc4616): +// +// message = [authzid] UTF8NUL authcid UTF8NUL passwd +// authcid = 1*SAFE ; MUST accept up to 255 octets +// authzid = 1*SAFE ; MUST accept up to 255 octets +// passwd = 1*SAFE ; MUST accept up to 255 octets +// UTF8NUL = %x00 ; UTF-8 encoded NUL character +// +// SAFE = UTF1 / UTF2 / UTF3 / UTF4 +// ;; any UTF-8 encoded Unicode character except NUL +// +// When credentials are valid, Kafka returns a 4 byte array of null characters. +// When credentials are invalid, Kafka closes the connection. This does not seem to be the ideal way +// of responding to bad credentials but thats how its being done today. +func (b *Broker) sendAndReceiveSASLPlainAuth() error { + if b.conf.Net.SASL.Handshake { + handshakeErr := b.sendAndReceiveSASLPlainHandshake() + if handshakeErr != nil { + Logger.Printf("Error while performing SASL handshake %s\n", b.addr) + return handshakeErr + } + } + length := 1 + len(b.conf.Net.SASL.User) + 1 + len(b.conf.Net.SASL.Password) + authBytes := make([]byte, length+4) //4 byte length header + auth data + binary.BigEndian.PutUint32(authBytes, uint32(length)) + copy(authBytes[4:], []byte("\x00"+b.conf.Net.SASL.User+"\x00"+b.conf.Net.SASL.Password)) + + err := b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout)) + if err != nil { + Logger.Printf("Failed to set write deadline when doing SASL auth with broker %s: %s\n", b.addr, err.Error()) + return err + } + + requestTime := time.Now() + bytesWritten, err := b.conn.Write(authBytes) + b.updateOutgoingCommunicationMetrics(bytesWritten) + if err != nil { + Logger.Printf("Failed to write SASL auth header to broker %s: %s\n", b.addr, err.Error()) + return err + } + + header := make([]byte, 4) + n, err := io.ReadFull(b.conn, header) + b.updateIncomingCommunicationMetrics(n, time.Since(requestTime)) + // If the credentials are valid, we would get a 4 byte response filled with null characters. + // Otherwise, the broker closes the connection and we get an EOF + if err != nil { + Logger.Printf("Failed to read response while authenticating with SASL to broker %s: %s\n", b.addr, err.Error()) + return err + } + + Logger.Printf("SASL authentication successful with broker %s:%v - %v\n", b.addr, n, header) + return nil +} + +func (b *Broker) updateIncomingCommunicationMetrics(bytes int, requestLatency time.Duration) { + b.updateRequestLatencyMetrics(requestLatency) + b.responseRate.Mark(1) + if b.brokerResponseRate != nil { + b.brokerResponseRate.Mark(1) + } + responseSize := int64(bytes) + b.incomingByteRate.Mark(responseSize) + if b.brokerIncomingByteRate != nil { + b.brokerIncomingByteRate.Mark(responseSize) + } + b.responseSize.Update(responseSize) + if b.brokerResponseSize != nil { + b.brokerResponseSize.Update(responseSize) + } +} + +func (b *Broker) updateRequestLatencyMetrics(requestLatency time.Duration) { + requestLatencyInMs := int64(requestLatency / time.Millisecond) + b.requestLatency.Update(requestLatencyInMs) + if b.brokerRequestLatency != nil { + b.brokerRequestLatency.Update(requestLatencyInMs) + } +} + +func (b *Broker) updateOutgoingCommunicationMetrics(bytes int) { + b.requestRate.Mark(1) + if b.brokerRequestRate != nil { + b.brokerRequestRate.Mark(1) + } + requestSize := int64(bytes) + b.outgoingByteRate.Mark(requestSize) + if b.brokerOutgoingByteRate != nil { + b.brokerOutgoingByteRate.Mark(requestSize) + } + b.requestSize.Update(requestSize) + if b.brokerRequestSize != nil { + b.brokerRequestSize.Update(requestSize) + } +} diff --git a/vendor/github.com/Shopify/sarama/broker_test.go b/vendor/github.com/Shopify/sarama/broker_test.go new file mode 100644 index 00000000..fcbe627f --- /dev/null +++ b/vendor/github.com/Shopify/sarama/broker_test.go @@ -0,0 +1,328 @@ +package sarama + +import ( + "fmt" + "testing" + "time" +) + +func ExampleBroker() { + broker := NewBroker("localhost:9092") + err := broker.Open(nil) + if err != nil { + panic(err) + } + + request := MetadataRequest{Topics: []string{"myTopic"}} + response, err := broker.GetMetadata(&request) + if err != nil { + _ = broker.Close() + panic(err) + } + + fmt.Println("There are", len(response.Topics), "topics active in the cluster.") + + if err = broker.Close(); err != nil { + panic(err) + } +} + +type mockEncoder struct { + bytes []byte +} + +func (m mockEncoder) encode(pe packetEncoder) error { + return pe.putRawBytes(m.bytes) +} + +type brokerMetrics struct { + bytesRead int + bytesWritten int +} + +func TestBrokerAccessors(t *testing.T) { + broker := NewBroker("abc:123") + + if broker.ID() != -1 { + t.Error("New broker didn't have an ID of -1.") + } + + if broker.Addr() != "abc:123" { + t.Error("New broker didn't have the correct address") + } + + broker.id = 34 + if broker.ID() != 34 { + t.Error("Manually setting broker ID did not take effect.") + } +} + +func TestSimpleBrokerCommunication(t *testing.T) { + for _, tt := range brokerTestTable { + Logger.Printf("Testing broker communication for %s", tt.name) + mb := NewMockBroker(t, 0) + mb.Returns(&mockEncoder{tt.response}) + pendingNotify := make(chan brokerMetrics) + // Register a callback to be notified about successful requests + mb.SetNotifier(func(bytesRead, bytesWritten int) { + pendingNotify <- brokerMetrics{bytesRead, bytesWritten} + }) + broker := NewBroker(mb.Addr()) + // Set the broker id in order to validate local broker metrics + broker.id = 0 + conf := NewConfig() + conf.Version = V0_10_0_0 + err := broker.Open(conf) + if err != nil { + t.Fatal(err) + } + tt.runner(t, broker) + err = broker.Close() + if err != nil { + t.Error(err) + } + // Wait up to 500 ms for the remote broker to process the request and + // notify us about the metrics + timeout := 500 * time.Millisecond + select { + case mockBrokerMetrics := <-pendingNotify: + validateBrokerMetrics(t, broker, mockBrokerMetrics) + case <-time.After(timeout): + t.Errorf("No request received for: %s after waiting for %v", tt.name, timeout) + } + mb.Close() + } + +} + +// We're not testing encoding/decoding here, so most of the requests/responses will be empty for simplicity's sake +var brokerTestTable = []struct { + name string + response []byte + runner func(*testing.T, *Broker) +}{ + {"MetadataRequest", + []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + func(t *testing.T, broker *Broker) { + request := MetadataRequest{} + response, err := broker.GetMetadata(&request) + if err != nil { + t.Error(err) + } + if response == nil { + t.Error("Metadata request got no response!") + } + }}, + + {"ConsumerMetadataRequest", + []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 't', 0x00, 0x00, 0x00, 0x00}, + func(t *testing.T, broker *Broker) { + request := ConsumerMetadataRequest{} + response, err := broker.GetConsumerMetadata(&request) + if err != nil { + t.Error(err) + } + if response == nil { + t.Error("Consumer Metadata request got no response!") + } + }}, + + {"ProduceRequest (NoResponse)", + []byte{}, + func(t *testing.T, broker *Broker) { + request := ProduceRequest{} + request.RequiredAcks = NoResponse + response, err := broker.Produce(&request) + if err != nil { + t.Error(err) + } + if response != nil { + t.Error("Produce request with NoResponse got a response!") + } + }}, + + {"ProduceRequest (WaitForLocal)", + []byte{0x00, 0x00, 0x00, 0x00}, + func(t *testing.T, broker *Broker) { + request := ProduceRequest{} + request.RequiredAcks = WaitForLocal + response, err := broker.Produce(&request) + if err != nil { + t.Error(err) + } + if response == nil { + t.Error("Produce request without NoResponse got no response!") + } + }}, + + {"FetchRequest", + []byte{0x00, 0x00, 0x00, 0x00}, + func(t *testing.T, broker *Broker) { + request := FetchRequest{} + response, err := broker.Fetch(&request) + if err != nil { + t.Error(err) + } + if response == nil { + t.Error("Fetch request got no response!") + } + }}, + + {"OffsetFetchRequest", + []byte{0x00, 0x00, 0x00, 0x00}, + func(t *testing.T, broker *Broker) { + request := OffsetFetchRequest{} + response, err := broker.FetchOffset(&request) + if err != nil { + t.Error(err) + } + if response == nil { + t.Error("OffsetFetch request got no response!") + } + }}, + + {"OffsetCommitRequest", + []byte{0x00, 0x00, 0x00, 0x00}, + func(t *testing.T, broker *Broker) { + request := OffsetCommitRequest{} + response, err := broker.CommitOffset(&request) + if err != nil { + t.Error(err) + } + if response == nil { + t.Error("OffsetCommit request got no response!") + } + }}, + + {"OffsetRequest", + []byte{0x00, 0x00, 0x00, 0x00}, + func(t *testing.T, broker *Broker) { + request := OffsetRequest{} + response, err := broker.GetAvailableOffsets(&request) + if err != nil { + t.Error(err) + } + if response == nil { + t.Error("Offset request got no response!") + } + }}, + + {"JoinGroupRequest", + []byte{0x00, 0x17, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + func(t *testing.T, broker *Broker) { + request := JoinGroupRequest{} + response, err := broker.JoinGroup(&request) + if err != nil { + t.Error(err) + } + if response == nil { + t.Error("JoinGroup request got no response!") + } + }}, + + {"SyncGroupRequest", + []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + func(t *testing.T, broker *Broker) { + request := SyncGroupRequest{} + response, err := broker.SyncGroup(&request) + if err != nil { + t.Error(err) + } + if response == nil { + t.Error("SyncGroup request got no response!") + } + }}, + + {"LeaveGroupRequest", + []byte{0x00, 0x00}, + func(t *testing.T, broker *Broker) { + request := LeaveGroupRequest{} + response, err := broker.LeaveGroup(&request) + if err != nil { + t.Error(err) + } + if response == nil { + t.Error("LeaveGroup request got no response!") + } + }}, + + {"HeartbeatRequest", + []byte{0x00, 0x00}, + func(t *testing.T, broker *Broker) { + request := HeartbeatRequest{} + response, err := broker.Heartbeat(&request) + if err != nil { + t.Error(err) + } + if response == nil { + t.Error("Heartbeat request got no response!") + } + }}, + + {"ListGroupsRequest", + []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + func(t *testing.T, broker *Broker) { + request := ListGroupsRequest{} + response, err := broker.ListGroups(&request) + if err != nil { + t.Error(err) + } + if response == nil { + t.Error("ListGroups request got no response!") + } + }}, + + {"DescribeGroupsRequest", + []byte{0x00, 0x00, 0x00, 0x00}, + func(t *testing.T, broker *Broker) { + request := DescribeGroupsRequest{} + response, err := broker.DescribeGroups(&request) + if err != nil { + t.Error(err) + } + if response == nil { + t.Error("DescribeGroups request got no response!") + } + }}, + + {"ApiVersionsRequest", + []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + func(t *testing.T, broker *Broker) { + request := ApiVersionsRequest{} + response, err := broker.ApiVersions(&request) + if err != nil { + t.Error(err) + } + if response == nil { + t.Error("ApiVersions request got no response!") + } + }}, +} + +func validateBrokerMetrics(t *testing.T, broker *Broker, mockBrokerMetrics brokerMetrics) { + metricValidators := newMetricValidators() + mockBrokerBytesRead := mockBrokerMetrics.bytesRead + mockBrokerBytesWritten := mockBrokerMetrics.bytesWritten + + // Check that the number of bytes sent corresponds to what the mock broker received + metricValidators.registerForAllBrokers(broker, countMeterValidator("incoming-byte-rate", mockBrokerBytesWritten)) + if mockBrokerBytesWritten == 0 { + // This a ProduceRequest with NoResponse + metricValidators.registerForAllBrokers(broker, countMeterValidator("response-rate", 0)) + metricValidators.registerForAllBrokers(broker, countHistogramValidator("response-size", 0)) + metricValidators.registerForAllBrokers(broker, minMaxHistogramValidator("response-size", 0, 0)) + } else { + metricValidators.registerForAllBrokers(broker, countMeterValidator("response-rate", 1)) + metricValidators.registerForAllBrokers(broker, countHistogramValidator("response-size", 1)) + metricValidators.registerForAllBrokers(broker, minMaxHistogramValidator("response-size", mockBrokerBytesWritten, mockBrokerBytesWritten)) + } + + // Check that the number of bytes received corresponds to what the mock broker sent + metricValidators.registerForAllBrokers(broker, countMeterValidator("outgoing-byte-rate", mockBrokerBytesRead)) + metricValidators.registerForAllBrokers(broker, countMeterValidator("request-rate", 1)) + metricValidators.registerForAllBrokers(broker, countHistogramValidator("request-size", 1)) + metricValidators.registerForAllBrokers(broker, minMaxHistogramValidator("request-size", mockBrokerBytesRead, mockBrokerBytesRead)) + + // Run the validators + metricValidators.run(t, broker.conf.MetricRegistry) +} diff --git a/vendor/github.com/Shopify/sarama/client.go b/vendor/github.com/Shopify/sarama/client.go new file mode 100644 index 00000000..45de3973 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/client.go @@ -0,0 +1,779 @@ +package sarama + +import ( + "math/rand" + "sort" + "sync" + "time" +) + +// Client is a generic Kafka client. It manages connections to one or more Kafka brokers. +// You MUST call Close() on a client to avoid leaks, it will not be garbage-collected +// automatically when it passes out of scope. It is safe to share a client amongst many +// users, however Kafka will process requests from a single client strictly in serial, +// so it is generally more efficient to use the default one client per producer/consumer. +type Client interface { + // Config returns the Config struct of the client. This struct should not be + // altered after it has been created. + Config() *Config + + // Brokers returns the current set of active brokers as retrieved from cluster metadata. + Brokers() []*Broker + + // Topics returns the set of available topics as retrieved from cluster metadata. + Topics() ([]string, error) + + // Partitions returns the sorted list of all partition IDs for the given topic. + Partitions(topic string) ([]int32, error) + + // WritablePartitions returns the sorted list of all writable partition IDs for + // the given topic, where "writable" means "having a valid leader accepting + // writes". + WritablePartitions(topic string) ([]int32, error) + + // Leader returns the broker object that is the leader of the current + // topic/partition, as determined by querying the cluster metadata. + Leader(topic string, partitionID int32) (*Broker, error) + + // Replicas returns the set of all replica IDs for the given partition. + Replicas(topic string, partitionID int32) ([]int32, error) + + // InSyncReplicas returns the set of all in-sync replica IDs for the given + // partition. In-sync replicas are replicas which are fully caught up with + // the partition leader. + InSyncReplicas(topic string, partitionID int32) ([]int32, error) + + // RefreshMetadata takes a list of topics and queries the cluster to refresh the + // available metadata for those topics. If no topics are provided, it will refresh + // metadata for all topics. + RefreshMetadata(topics ...string) error + + // GetOffset queries the cluster to get the most recent available offset at the + // given time on the topic/partition combination. Time should be OffsetOldest for + // the earliest available offset, OffsetNewest for the offset of the message that + // will be produced next, or a time. + GetOffset(topic string, partitionID int32, time int64) (int64, error) + + // Coordinator returns the coordinating broker for a consumer group. It will + // return a locally cached value if it's available. You can call + // RefreshCoordinator to update the cached value. This function only works on + // Kafka 0.8.2 and higher. + Coordinator(consumerGroup string) (*Broker, error) + + // RefreshCoordinator retrieves the coordinator for a consumer group and stores it + // in local cache. This function only works on Kafka 0.8.2 and higher. + RefreshCoordinator(consumerGroup string) error + + // Close shuts down all broker connections managed by this client. It is required + // to call this function before a client object passes out of scope, as it will + // otherwise leak memory. You must close any Producers or Consumers using a client + // before you close the client. + Close() error + + // Closed returns true if the client has already had Close called on it + Closed() bool +} + +const ( + // OffsetNewest stands for the log head offset, i.e. the offset that will be + // assigned to the next message that will be produced to the partition. You + // can send this to a client's GetOffset method to get this offset, or when + // calling ConsumePartition to start consuming new messages. + OffsetNewest int64 = -1 + // OffsetOldest stands for the oldest offset available on the broker for a + // partition. You can send this to a client's GetOffset method to get this + // offset, or when calling ConsumePartition to start consuming from the + // oldest offset that is still available on the broker. + OffsetOldest int64 = -2 +) + +type client struct { + conf *Config + closer, closed chan none // for shutting down background metadata updater + + // the broker addresses given to us through the constructor are not guaranteed to be returned in + // the cluster metadata (I *think* it only returns brokers who are currently leading partitions?) + // so we store them separately + seedBrokers []*Broker + deadSeeds []*Broker + + brokers map[int32]*Broker // maps broker ids to brokers + metadata map[string]map[int32]*PartitionMetadata // maps topics to partition ids to metadata + coordinators map[string]int32 // Maps consumer group names to coordinating broker IDs + + // If the number of partitions is large, we can get some churn calling cachedPartitions, + // so the result is cached. It is important to update this value whenever metadata is changed + cachedPartitionsResults map[string][maxPartitionIndex][]int32 + + lock sync.RWMutex // protects access to the maps that hold cluster state. +} + +// NewClient creates a new Client. It connects to one of the given broker addresses +// and uses that broker to automatically fetch metadata on the rest of the kafka cluster. If metadata cannot +// be retrieved from any of the given broker addresses, the client is not created. +func NewClient(addrs []string, conf *Config) (Client, error) { + Logger.Println("Initializing new client") + + if conf == nil { + conf = NewConfig() + } + + if err := conf.Validate(); err != nil { + return nil, err + } + + if len(addrs) < 1 { + return nil, ConfigurationError("You must provide at least one broker address") + } + + client := &client{ + conf: conf, + closer: make(chan none), + closed: make(chan none), + brokers: make(map[int32]*Broker), + metadata: make(map[string]map[int32]*PartitionMetadata), + cachedPartitionsResults: make(map[string][maxPartitionIndex][]int32), + coordinators: make(map[string]int32), + } + + random := rand.New(rand.NewSource(time.Now().UnixNano())) + for _, index := range random.Perm(len(addrs)) { + client.seedBrokers = append(client.seedBrokers, NewBroker(addrs[index])) + } + + // do an initial fetch of all cluster metadata by specifying an empty list of topics + err := client.RefreshMetadata() + switch err { + case nil: + break + case ErrLeaderNotAvailable, ErrReplicaNotAvailable, ErrTopicAuthorizationFailed, ErrClusterAuthorizationFailed: + // indicates that maybe part of the cluster is down, but is not fatal to creating the client + Logger.Println(err) + default: + close(client.closed) // we haven't started the background updater yet, so we have to do this manually + _ = client.Close() + return nil, err + } + go withRecover(client.backgroundMetadataUpdater) + + Logger.Println("Successfully initialized new client") + + return client, nil +} + +func (client *client) Config() *Config { + return client.conf +} + +func (client *client) Brokers() []*Broker { + client.lock.RLock() + defer client.lock.RUnlock() + brokers := make([]*Broker, 0) + for _, broker := range client.brokers { + brokers = append(brokers, broker) + } + return brokers +} + +func (client *client) Close() error { + if client.Closed() { + // Chances are this is being called from a defer() and the error will go unobserved + // so we go ahead and log the event in this case. + Logger.Printf("Close() called on already closed client") + return ErrClosedClient + } + + // shutdown and wait for the background thread before we take the lock, to avoid races + close(client.closer) + <-client.closed + + client.lock.Lock() + defer client.lock.Unlock() + Logger.Println("Closing Client") + + for _, broker := range client.brokers { + safeAsyncClose(broker) + } + + for _, broker := range client.seedBrokers { + safeAsyncClose(broker) + } + + client.brokers = nil + client.metadata = nil + + return nil +} + +func (client *client) Closed() bool { + return client.brokers == nil +} + +func (client *client) Topics() ([]string, error) { + if client.Closed() { + return nil, ErrClosedClient + } + + client.lock.RLock() + defer client.lock.RUnlock() + + ret := make([]string, 0, len(client.metadata)) + for topic := range client.metadata { + ret = append(ret, topic) + } + + return ret, nil +} + +func (client *client) Partitions(topic string) ([]int32, error) { + if client.Closed() { + return nil, ErrClosedClient + } + + partitions := client.cachedPartitions(topic, allPartitions) + + if len(partitions) == 0 { + err := client.RefreshMetadata(topic) + if err != nil { + return nil, err + } + partitions = client.cachedPartitions(topic, allPartitions) + } + + if partitions == nil { + return nil, ErrUnknownTopicOrPartition + } + + return partitions, nil +} + +func (client *client) WritablePartitions(topic string) ([]int32, error) { + if client.Closed() { + return nil, ErrClosedClient + } + + partitions := client.cachedPartitions(topic, writablePartitions) + + // len==0 catches when it's nil (no such topic) and the odd case when every single + // partition is undergoing leader election simultaneously. Callers have to be able to handle + // this function returning an empty slice (which is a valid return value) but catching it + // here the first time (note we *don't* catch it below where we return ErrUnknownTopicOrPartition) triggers + // a metadata refresh as a nicety so callers can just try again and don't have to manually + // trigger a refresh (otherwise they'd just keep getting a stale cached copy). + if len(partitions) == 0 { + err := client.RefreshMetadata(topic) + if err != nil { + return nil, err + } + partitions = client.cachedPartitions(topic, writablePartitions) + } + + if partitions == nil { + return nil, ErrUnknownTopicOrPartition + } + + return partitions, nil +} + +func (client *client) Replicas(topic string, partitionID int32) ([]int32, error) { + if client.Closed() { + return nil, ErrClosedClient + } + + metadata := client.cachedMetadata(topic, partitionID) + + if metadata == nil { + err := client.RefreshMetadata(topic) + if err != nil { + return nil, err + } + metadata = client.cachedMetadata(topic, partitionID) + } + + if metadata == nil { + return nil, ErrUnknownTopicOrPartition + } + + if metadata.Err == ErrReplicaNotAvailable { + return nil, metadata.Err + } + return dupeAndSort(metadata.Replicas), nil +} + +func (client *client) InSyncReplicas(topic string, partitionID int32) ([]int32, error) { + if client.Closed() { + return nil, ErrClosedClient + } + + metadata := client.cachedMetadata(topic, partitionID) + + if metadata == nil { + err := client.RefreshMetadata(topic) + if err != nil { + return nil, err + } + metadata = client.cachedMetadata(topic, partitionID) + } + + if metadata == nil { + return nil, ErrUnknownTopicOrPartition + } + + if metadata.Err == ErrReplicaNotAvailable { + return nil, metadata.Err + } + return dupeAndSort(metadata.Isr), nil +} + +func (client *client) Leader(topic string, partitionID int32) (*Broker, error) { + if client.Closed() { + return nil, ErrClosedClient + } + + leader, err := client.cachedLeader(topic, partitionID) + + if leader == nil { + err = client.RefreshMetadata(topic) + if err != nil { + return nil, err + } + leader, err = client.cachedLeader(topic, partitionID) + } + + return leader, err +} + +func (client *client) RefreshMetadata(topics ...string) error { + if client.Closed() { + return ErrClosedClient + } + + // Prior to 0.8.2, Kafka will throw exceptions on an empty topic and not return a proper + // error. This handles the case by returning an error instead of sending it + // off to Kafka. See: https://github.com/Shopify/sarama/pull/38#issuecomment-26362310 + for _, topic := range topics { + if len(topic) == 0 { + return ErrInvalidTopic // this is the error that 0.8.2 and later correctly return + } + } + + return client.tryRefreshMetadata(topics, client.conf.Metadata.Retry.Max) +} + +func (client *client) GetOffset(topic string, partitionID int32, time int64) (int64, error) { + if client.Closed() { + return -1, ErrClosedClient + } + + offset, err := client.getOffset(topic, partitionID, time) + + if err != nil { + if err := client.RefreshMetadata(topic); err != nil { + return -1, err + } + return client.getOffset(topic, partitionID, time) + } + + return offset, err +} + +func (client *client) Coordinator(consumerGroup string) (*Broker, error) { + if client.Closed() { + return nil, ErrClosedClient + } + + coordinator := client.cachedCoordinator(consumerGroup) + + if coordinator == nil { + if err := client.RefreshCoordinator(consumerGroup); err != nil { + return nil, err + } + coordinator = client.cachedCoordinator(consumerGroup) + } + + if coordinator == nil { + return nil, ErrConsumerCoordinatorNotAvailable + } + + _ = coordinator.Open(client.conf) + return coordinator, nil +} + +func (client *client) RefreshCoordinator(consumerGroup string) error { + if client.Closed() { + return ErrClosedClient + } + + response, err := client.getConsumerMetadata(consumerGroup, client.conf.Metadata.Retry.Max) + if err != nil { + return err + } + + client.lock.Lock() + defer client.lock.Unlock() + client.registerBroker(response.Coordinator) + client.coordinators[consumerGroup] = response.Coordinator.ID() + return nil +} + +// private broker management helpers + +// registerBroker makes sure a broker received by a Metadata or Coordinator request is registered +// in the brokers map. It returns the broker that is registered, which may be the provided broker, +// or a previously registered Broker instance. You must hold the write lock before calling this function. +func (client *client) registerBroker(broker *Broker) { + if client.brokers[broker.ID()] == nil { + client.brokers[broker.ID()] = broker + Logger.Printf("client/brokers registered new broker #%d at %s", broker.ID(), broker.Addr()) + } else if broker.Addr() != client.brokers[broker.ID()].Addr() { + safeAsyncClose(client.brokers[broker.ID()]) + client.brokers[broker.ID()] = broker + Logger.Printf("client/brokers replaced registered broker #%d with %s", broker.ID(), broker.Addr()) + } +} + +// deregisterBroker removes a broker from the seedsBroker list, and if it's +// not the seedbroker, removes it from brokers map completely. +func (client *client) deregisterBroker(broker *Broker) { + client.lock.Lock() + defer client.lock.Unlock() + + if len(client.seedBrokers) > 0 && broker == client.seedBrokers[0] { + client.deadSeeds = append(client.deadSeeds, broker) + client.seedBrokers = client.seedBrokers[1:] + } else { + // we do this so that our loop in `tryRefreshMetadata` doesn't go on forever, + // but we really shouldn't have to; once that loop is made better this case can be + // removed, and the function generally can be renamed from `deregisterBroker` to + // `nextSeedBroker` or something + Logger.Printf("client/brokers deregistered broker #%d at %s", broker.ID(), broker.Addr()) + delete(client.brokers, broker.ID()) + } +} + +func (client *client) resurrectDeadBrokers() { + client.lock.Lock() + defer client.lock.Unlock() + + Logger.Printf("client/brokers resurrecting %d dead seed brokers", len(client.deadSeeds)) + client.seedBrokers = append(client.seedBrokers, client.deadSeeds...) + client.deadSeeds = nil +} + +func (client *client) any() *Broker { + client.lock.RLock() + defer client.lock.RUnlock() + + if len(client.seedBrokers) > 0 { + _ = client.seedBrokers[0].Open(client.conf) + return client.seedBrokers[0] + } + + // not guaranteed to be random *or* deterministic + for _, broker := range client.brokers { + _ = broker.Open(client.conf) + return broker + } + + return nil +} + +// private caching/lazy metadata helpers + +type partitionType int + +const ( + allPartitions partitionType = iota + writablePartitions + // If you add any more types, update the partition cache in update() + + // Ensure this is the last partition type value + maxPartitionIndex +) + +func (client *client) cachedMetadata(topic string, partitionID int32) *PartitionMetadata { + client.lock.RLock() + defer client.lock.RUnlock() + + partitions := client.metadata[topic] + if partitions != nil { + return partitions[partitionID] + } + + return nil +} + +func (client *client) cachedPartitions(topic string, partitionSet partitionType) []int32 { + client.lock.RLock() + defer client.lock.RUnlock() + + partitions, exists := client.cachedPartitionsResults[topic] + + if !exists { + return nil + } + return partitions[partitionSet] +} + +func (client *client) setPartitionCache(topic string, partitionSet partitionType) []int32 { + partitions := client.metadata[topic] + + if partitions == nil { + return nil + } + + ret := make([]int32, 0, len(partitions)) + for _, partition := range partitions { + if partitionSet == writablePartitions && partition.Err == ErrLeaderNotAvailable { + continue + } + ret = append(ret, partition.ID) + } + + sort.Sort(int32Slice(ret)) + return ret +} + +func (client *client) cachedLeader(topic string, partitionID int32) (*Broker, error) { + client.lock.RLock() + defer client.lock.RUnlock() + + partitions := client.metadata[topic] + if partitions != nil { + metadata, ok := partitions[partitionID] + if ok { + if metadata.Err == ErrLeaderNotAvailable { + return nil, ErrLeaderNotAvailable + } + b := client.brokers[metadata.Leader] + if b == nil { + return nil, ErrLeaderNotAvailable + } + _ = b.Open(client.conf) + return b, nil + } + } + + return nil, ErrUnknownTopicOrPartition +} + +func (client *client) getOffset(topic string, partitionID int32, time int64) (int64, error) { + broker, err := client.Leader(topic, partitionID) + if err != nil { + return -1, err + } + + request := &OffsetRequest{} + if client.conf.Version.IsAtLeast(V0_10_1_0) { + request.Version = 1 + } + request.AddBlock(topic, partitionID, time, 1) + + response, err := broker.GetAvailableOffsets(request) + if err != nil { + _ = broker.Close() + return -1, err + } + + block := response.GetBlock(topic, partitionID) + if block == nil { + _ = broker.Close() + return -1, ErrIncompleteResponse + } + if block.Err != ErrNoError { + return -1, block.Err + } + if len(block.Offsets) != 1 { + return -1, ErrOffsetOutOfRange + } + + return block.Offsets[0], nil +} + +// core metadata update logic + +func (client *client) backgroundMetadataUpdater() { + defer close(client.closed) + + if client.conf.Metadata.RefreshFrequency == time.Duration(0) { + return + } + + ticker := time.NewTicker(client.conf.Metadata.RefreshFrequency) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if err := client.RefreshMetadata(); err != nil { + Logger.Println("Client background metadata update:", err) + } + case <-client.closer: + return + } + } +} + +func (client *client) tryRefreshMetadata(topics []string, attemptsRemaining int) error { + retry := func(err error) error { + if attemptsRemaining > 0 { + Logger.Printf("client/metadata retrying after %dms... (%d attempts remaining)\n", client.conf.Metadata.Retry.Backoff/time.Millisecond, attemptsRemaining) + time.Sleep(client.conf.Metadata.Retry.Backoff) + return client.tryRefreshMetadata(topics, attemptsRemaining-1) + } + return err + } + + for broker := client.any(); broker != nil; broker = client.any() { + if len(topics) > 0 { + Logger.Printf("client/metadata fetching metadata for %v from broker %s\n", topics, broker.addr) + } else { + Logger.Printf("client/metadata fetching metadata for all topics from broker %s\n", broker.addr) + } + response, err := broker.GetMetadata(&MetadataRequest{Topics: topics}) + + switch err.(type) { + case nil: + // valid response, use it + shouldRetry, err := client.updateMetadata(response) + if shouldRetry { + Logger.Println("client/metadata found some partitions to be leaderless") + return retry(err) // note: err can be nil + } + return err + + case PacketEncodingError: + // didn't even send, return the error + return err + default: + // some other error, remove that broker and try again + Logger.Println("client/metadata got error from broker while fetching metadata:", err) + _ = broker.Close() + client.deregisterBroker(broker) + } + } + + Logger.Println("client/metadata no available broker to send metadata request to") + client.resurrectDeadBrokers() + return retry(ErrOutOfBrokers) +} + +// if no fatal error, returns a list of topics that need retrying due to ErrLeaderNotAvailable +func (client *client) updateMetadata(data *MetadataResponse) (retry bool, err error) { + client.lock.Lock() + defer client.lock.Unlock() + + // For all the brokers we received: + // - if it is a new ID, save it + // - if it is an existing ID, but the address we have is stale, discard the old one and save it + // - otherwise ignore it, replacing our existing one would just bounce the connection + for _, broker := range data.Brokers { + client.registerBroker(broker) + } + + for _, topic := range data.Topics { + delete(client.metadata, topic.Name) + delete(client.cachedPartitionsResults, topic.Name) + + switch topic.Err { + case ErrNoError: + break + case ErrInvalidTopic, ErrTopicAuthorizationFailed: // don't retry, don't store partial results + err = topic.Err + continue + case ErrUnknownTopicOrPartition: // retry, do not store partial partition results + err = topic.Err + retry = true + continue + case ErrLeaderNotAvailable: // retry, but store partial partition results + retry = true + break + default: // don't retry, don't store partial results + Logger.Printf("Unexpected topic-level metadata error: %s", topic.Err) + err = topic.Err + continue + } + + client.metadata[topic.Name] = make(map[int32]*PartitionMetadata, len(topic.Partitions)) + for _, partition := range topic.Partitions { + client.metadata[topic.Name][partition.ID] = partition + if partition.Err == ErrLeaderNotAvailable { + retry = true + } + } + + var partitionCache [maxPartitionIndex][]int32 + partitionCache[allPartitions] = client.setPartitionCache(topic.Name, allPartitions) + partitionCache[writablePartitions] = client.setPartitionCache(topic.Name, writablePartitions) + client.cachedPartitionsResults[topic.Name] = partitionCache + } + + return +} + +func (client *client) cachedCoordinator(consumerGroup string) *Broker { + client.lock.RLock() + defer client.lock.RUnlock() + if coordinatorID, ok := client.coordinators[consumerGroup]; ok { + return client.brokers[coordinatorID] + } + return nil +} + +func (client *client) getConsumerMetadata(consumerGroup string, attemptsRemaining int) (*ConsumerMetadataResponse, error) { + retry := func(err error) (*ConsumerMetadataResponse, error) { + if attemptsRemaining > 0 { + Logger.Printf("client/coordinator retrying after %dms... (%d attempts remaining)\n", client.conf.Metadata.Retry.Backoff/time.Millisecond, attemptsRemaining) + time.Sleep(client.conf.Metadata.Retry.Backoff) + return client.getConsumerMetadata(consumerGroup, attemptsRemaining-1) + } + return nil, err + } + + for broker := client.any(); broker != nil; broker = client.any() { + Logger.Printf("client/coordinator requesting coordinator for consumergroup %s from %s\n", consumerGroup, broker.Addr()) + + request := new(ConsumerMetadataRequest) + request.ConsumerGroup = consumerGroup + + response, err := broker.GetConsumerMetadata(request) + + if err != nil { + Logger.Printf("client/coordinator request to broker %s failed: %s\n", broker.Addr(), err) + + switch err.(type) { + case PacketEncodingError: + return nil, err + default: + _ = broker.Close() + client.deregisterBroker(broker) + continue + } + } + + switch response.Err { + case ErrNoError: + Logger.Printf("client/coordinator coordinator for consumergroup %s is #%d (%s)\n", consumerGroup, response.Coordinator.ID(), response.Coordinator.Addr()) + return response, nil + + case ErrConsumerCoordinatorNotAvailable: + Logger.Printf("client/coordinator coordinator for consumer group %s is not available\n", consumerGroup) + + // This is very ugly, but this scenario will only happen once per cluster. + // The __consumer_offsets topic only has to be created one time. + // The number of partitions not configurable, but partition 0 should always exist. + if _, err := client.Leader("__consumer_offsets", 0); err != nil { + Logger.Printf("client/coordinator the __consumer_offsets topic is not initialized completely yet. Waiting 2 seconds...\n") + time.Sleep(2 * time.Second) + } + + return retry(ErrConsumerCoordinatorNotAvailable) + default: + return nil, response.Err + } + } + + Logger.Println("client/coordinator no available broker to send consumer metadata request to") + client.resurrectDeadBrokers() + return retry(ErrOutOfBrokers) +} diff --git a/vendor/github.com/Shopify/sarama/client_test.go b/vendor/github.com/Shopify/sarama/client_test.go new file mode 100644 index 00000000..0bac1b40 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/client_test.go @@ -0,0 +1,619 @@ +package sarama + +import ( + "io" + "sync" + "testing" + "time" +) + +func safeClose(t testing.TB, c io.Closer) { + err := c.Close() + if err != nil { + t.Error(err) + } +} + +func TestSimpleClient(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + + seedBroker.Returns(new(MetadataResponse)) + + client, err := NewClient([]string{seedBroker.Addr()}, nil) + if err != nil { + t.Fatal(err) + } + + seedBroker.Close() + safeClose(t, client) +} + +func TestCachedPartitions(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + + replicas := []int32{3, 1, 5} + isr := []int32{5, 1} + + metadataResponse := new(MetadataResponse) + metadataResponse.AddBroker("localhost:12345", 2) + metadataResponse.AddTopicPartition("my_topic", 0, 2, replicas, isr, ErrNoError) + metadataResponse.AddTopicPartition("my_topic", 1, 2, replicas, isr, ErrLeaderNotAvailable) + seedBroker.Returns(metadataResponse) + + config := NewConfig() + config.Metadata.Retry.Max = 0 + c, err := NewClient([]string{seedBroker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + client := c.(*client) + + // Verify they aren't cached the same + allP := client.cachedPartitionsResults["my_topic"][allPartitions] + writeP := client.cachedPartitionsResults["my_topic"][writablePartitions] + if len(allP) == len(writeP) { + t.Fatal("Invalid lengths!") + } + + tmp := client.cachedPartitionsResults["my_topic"] + // Verify we actually use the cache at all! + tmp[allPartitions] = []int32{1, 2, 3, 4} + client.cachedPartitionsResults["my_topic"] = tmp + if 4 != len(client.cachedPartitions("my_topic", allPartitions)) { + t.Fatal("Not using the cache!") + } + + seedBroker.Close() + safeClose(t, client) +} + +func TestClientDoesntCachePartitionsForTopicsWithErrors(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + + replicas := []int32{seedBroker.BrokerID()} + + metadataResponse := new(MetadataResponse) + metadataResponse.AddBroker(seedBroker.Addr(), seedBroker.BrokerID()) + metadataResponse.AddTopicPartition("my_topic", 1, replicas[0], replicas, replicas, ErrNoError) + metadataResponse.AddTopicPartition("my_topic", 2, replicas[0], replicas, replicas, ErrNoError) + seedBroker.Returns(metadataResponse) + + config := NewConfig() + config.Metadata.Retry.Max = 0 + client, err := NewClient([]string{seedBroker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + metadataResponse = new(MetadataResponse) + metadataResponse.AddTopic("unknown", ErrUnknownTopicOrPartition) + seedBroker.Returns(metadataResponse) + + partitions, err := client.Partitions("unknown") + + if err != ErrUnknownTopicOrPartition { + t.Error("Expected ErrUnknownTopicOrPartition, found", err) + } + if partitions != nil { + t.Errorf("Should return nil as partition list, found %v", partitions) + } + + // Should still use the cache of a known topic + partitions, err = client.Partitions("my_topic") + if err != nil { + t.Errorf("Expected no error, found %v", err) + } + + metadataResponse = new(MetadataResponse) + metadataResponse.AddTopic("unknown", ErrUnknownTopicOrPartition) + seedBroker.Returns(metadataResponse) + + // Should not use cache for unknown topic + partitions, err = client.Partitions("unknown") + if err != ErrUnknownTopicOrPartition { + t.Error("Expected ErrUnknownTopicOrPartition, found", err) + } + if partitions != nil { + t.Errorf("Should return nil as partition list, found %v", partitions) + } + + seedBroker.Close() + safeClose(t, client) +} + +func TestClientSeedBrokers(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + + metadataResponse := new(MetadataResponse) + metadataResponse.AddBroker("localhost:12345", 2) + seedBroker.Returns(metadataResponse) + + client, err := NewClient([]string{seedBroker.Addr()}, nil) + if err != nil { + t.Fatal(err) + } + + seedBroker.Close() + safeClose(t, client) +} + +func TestClientMetadata(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + leader := NewMockBroker(t, 5) + + replicas := []int32{3, 1, 5} + isr := []int32{5, 1} + + metadataResponse := new(MetadataResponse) + metadataResponse.AddBroker(leader.Addr(), leader.BrokerID()) + metadataResponse.AddTopicPartition("my_topic", 0, leader.BrokerID(), replicas, isr, ErrNoError) + metadataResponse.AddTopicPartition("my_topic", 1, leader.BrokerID(), replicas, isr, ErrLeaderNotAvailable) + seedBroker.Returns(metadataResponse) + + config := NewConfig() + config.Metadata.Retry.Max = 0 + client, err := NewClient([]string{seedBroker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + topics, err := client.Topics() + if err != nil { + t.Error(err) + } else if len(topics) != 1 || topics[0] != "my_topic" { + t.Error("Client returned incorrect topics:", topics) + } + + parts, err := client.Partitions("my_topic") + if err != nil { + t.Error(err) + } else if len(parts) != 2 || parts[0] != 0 || parts[1] != 1 { + t.Error("Client returned incorrect partitions for my_topic:", parts) + } + + parts, err = client.WritablePartitions("my_topic") + if err != nil { + t.Error(err) + } else if len(parts) != 1 || parts[0] != 0 { + t.Error("Client returned incorrect writable partitions for my_topic:", parts) + } + + tst, err := client.Leader("my_topic", 0) + if err != nil { + t.Error(err) + } else if tst.ID() != 5 { + t.Error("Leader for my_topic had incorrect ID.") + } + + replicas, err = client.Replicas("my_topic", 0) + if err != nil { + t.Error(err) + } else if replicas[0] != 1 { + t.Error("Incorrect (or unsorted) replica") + } else if replicas[1] != 3 { + t.Error("Incorrect (or unsorted) replica") + } else if replicas[2] != 5 { + t.Error("Incorrect (or unsorted) replica") + } + + isr, err = client.InSyncReplicas("my_topic", 0) + if err != nil { + t.Error(err) + } else if len(isr) != 2 { + t.Error("Client returned incorrect ISRs for partition:", isr) + } else if isr[0] != 1 { + t.Error("Incorrect (or unsorted) ISR:", isr) + } else if isr[1] != 5 { + t.Error("Incorrect (or unsorted) ISR:", isr) + } + + leader.Close() + seedBroker.Close() + safeClose(t, client) +} + +func TestClientGetOffset(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + leader := NewMockBroker(t, 2) + leaderAddr := leader.Addr() + + metadata := new(MetadataResponse) + metadata.AddTopicPartition("foo", 0, leader.BrokerID(), nil, nil, ErrNoError) + metadata.AddBroker(leaderAddr, leader.BrokerID()) + seedBroker.Returns(metadata) + + client, err := NewClient([]string{seedBroker.Addr()}, nil) + if err != nil { + t.Fatal(err) + } + + offsetResponse := new(OffsetResponse) + offsetResponse.AddTopicPartition("foo", 0, 123) + leader.Returns(offsetResponse) + + offset, err := client.GetOffset("foo", 0, OffsetNewest) + if err != nil { + t.Error(err) + } + if offset != 123 { + t.Error("Unexpected offset, got ", offset) + } + + leader.Close() + seedBroker.Returns(metadata) + + leader = NewMockBrokerAddr(t, 2, leaderAddr) + offsetResponse = new(OffsetResponse) + offsetResponse.AddTopicPartition("foo", 0, 456) + leader.Returns(offsetResponse) + + offset, err = client.GetOffset("foo", 0, OffsetNewest) + if err != nil { + t.Error(err) + } + if offset != 456 { + t.Error("Unexpected offset, got ", offset) + } + + seedBroker.Close() + leader.Close() + safeClose(t, client) +} + +func TestClientReceivingUnknownTopic(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + + metadataResponse1 := new(MetadataResponse) + seedBroker.Returns(metadataResponse1) + + config := NewConfig() + config.Metadata.Retry.Max = 1 + config.Metadata.Retry.Backoff = 0 + client, err := NewClient([]string{seedBroker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + metadataUnknownTopic := new(MetadataResponse) + metadataUnknownTopic.AddTopic("new_topic", ErrUnknownTopicOrPartition) + seedBroker.Returns(metadataUnknownTopic) + seedBroker.Returns(metadataUnknownTopic) + + if err := client.RefreshMetadata("new_topic"); err != ErrUnknownTopicOrPartition { + t.Error("ErrUnknownTopicOrPartition expected, got", err) + } + + // If we are asking for the leader of a partition of the non-existing topic. + // we will request metadata again. + seedBroker.Returns(metadataUnknownTopic) + seedBroker.Returns(metadataUnknownTopic) + + if _, err = client.Leader("new_topic", 1); err != ErrUnknownTopicOrPartition { + t.Error("Expected ErrUnknownTopicOrPartition, got", err) + } + + safeClose(t, client) + seedBroker.Close() +} + +func TestClientReceivingPartialMetadata(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + leader := NewMockBroker(t, 5) + + metadataResponse1 := new(MetadataResponse) + metadataResponse1.AddBroker(leader.Addr(), leader.BrokerID()) + seedBroker.Returns(metadataResponse1) + + config := NewConfig() + config.Metadata.Retry.Max = 0 + client, err := NewClient([]string{seedBroker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + replicas := []int32{leader.BrokerID(), seedBroker.BrokerID()} + + metadataPartial := new(MetadataResponse) + metadataPartial.AddTopic("new_topic", ErrLeaderNotAvailable) + metadataPartial.AddTopicPartition("new_topic", 0, leader.BrokerID(), replicas, replicas, ErrNoError) + metadataPartial.AddTopicPartition("new_topic", 1, -1, replicas, []int32{}, ErrLeaderNotAvailable) + seedBroker.Returns(metadataPartial) + + if err := client.RefreshMetadata("new_topic"); err != nil { + t.Error("ErrLeaderNotAvailable should not make RefreshMetadata respond with an error") + } + + // Even though the metadata was incomplete, we should be able to get the leader of a partition + // for which we did get a useful response, without doing additional requests. + + partition0Leader, err := client.Leader("new_topic", 0) + if err != nil { + t.Error(err) + } else if partition0Leader.Addr() != leader.Addr() { + t.Error("Unexpected leader returned", partition0Leader.Addr()) + } + + // If we are asking for the leader of a partition that didn't have a leader before, + // we will do another metadata request. + + seedBroker.Returns(metadataPartial) + + // Still no leader for the partition, so asking for it should return an error. + _, err = client.Leader("new_topic", 1) + if err != ErrLeaderNotAvailable { + t.Error("Expected ErrLeaderNotAvailable, got", err) + } + + safeClose(t, client) + seedBroker.Close() + leader.Close() +} + +func TestClientRefreshBehaviour(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + leader := NewMockBroker(t, 5) + + metadataResponse1 := new(MetadataResponse) + metadataResponse1.AddBroker(leader.Addr(), leader.BrokerID()) + seedBroker.Returns(metadataResponse1) + + metadataResponse2 := new(MetadataResponse) + metadataResponse2.AddTopicPartition("my_topic", 0xb, leader.BrokerID(), nil, nil, ErrNoError) + seedBroker.Returns(metadataResponse2) + + client, err := NewClient([]string{seedBroker.Addr()}, nil) + if err != nil { + t.Fatal(err) + } + + parts, err := client.Partitions("my_topic") + if err != nil { + t.Error(err) + } else if len(parts) != 1 || parts[0] != 0xb { + t.Error("Client returned incorrect partitions for my_topic:", parts) + } + + tst, err := client.Leader("my_topic", 0xb) + if err != nil { + t.Error(err) + } else if tst.ID() != 5 { + t.Error("Leader for my_topic had incorrect ID.") + } + + leader.Close() + seedBroker.Close() + safeClose(t, client) +} + +func TestClientResurrectDeadSeeds(t *testing.T) { + initialSeed := NewMockBroker(t, 0) + emptyMetadata := new(MetadataResponse) + initialSeed.Returns(emptyMetadata) + + conf := NewConfig() + conf.Metadata.Retry.Backoff = 0 + conf.Metadata.RefreshFrequency = 0 + c, err := NewClient([]string{initialSeed.Addr()}, conf) + if err != nil { + t.Fatal(err) + } + initialSeed.Close() + + client := c.(*client) + + seed1 := NewMockBroker(t, 1) + seed2 := NewMockBroker(t, 2) + seed3 := NewMockBroker(t, 3) + addr1 := seed1.Addr() + addr2 := seed2.Addr() + addr3 := seed3.Addr() + + // Overwrite the seed brokers with a fixed ordering to make this test deterministic. + safeClose(t, client.seedBrokers[0]) + client.seedBrokers = []*Broker{NewBroker(addr1), NewBroker(addr2), NewBroker(addr3)} + client.deadSeeds = []*Broker{} + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + if err := client.RefreshMetadata(); err != nil { + t.Error(err) + } + wg.Done() + }() + seed1.Close() + seed2.Close() + + seed1 = NewMockBrokerAddr(t, 1, addr1) + seed2 = NewMockBrokerAddr(t, 2, addr2) + + seed3.Close() + + seed1.Close() + seed2.Returns(emptyMetadata) + + wg.Wait() + + if len(client.seedBrokers) != 2 { + t.Error("incorrect number of live seeds") + } + if len(client.deadSeeds) != 1 { + t.Error("incorrect number of dead seeds") + } + + safeClose(t, c) +} + +func TestClientCoordinatorWithConsumerOffsetsTopic(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + staleCoordinator := NewMockBroker(t, 2) + freshCoordinator := NewMockBroker(t, 3) + + replicas := []int32{staleCoordinator.BrokerID(), freshCoordinator.BrokerID()} + metadataResponse1 := new(MetadataResponse) + metadataResponse1.AddBroker(staleCoordinator.Addr(), staleCoordinator.BrokerID()) + metadataResponse1.AddBroker(freshCoordinator.Addr(), freshCoordinator.BrokerID()) + metadataResponse1.AddTopicPartition("__consumer_offsets", 0, replicas[0], replicas, replicas, ErrNoError) + seedBroker.Returns(metadataResponse1) + + client, err := NewClient([]string{seedBroker.Addr()}, nil) + if err != nil { + t.Fatal(err) + } + + coordinatorResponse1 := new(ConsumerMetadataResponse) + coordinatorResponse1.Err = ErrConsumerCoordinatorNotAvailable + seedBroker.Returns(coordinatorResponse1) + + coordinatorResponse2 := new(ConsumerMetadataResponse) + coordinatorResponse2.CoordinatorID = staleCoordinator.BrokerID() + coordinatorResponse2.CoordinatorHost = "127.0.0.1" + coordinatorResponse2.CoordinatorPort = staleCoordinator.Port() + + seedBroker.Returns(coordinatorResponse2) + + broker, err := client.Coordinator("my_group") + if err != nil { + t.Error(err) + } + + if staleCoordinator.Addr() != broker.Addr() { + t.Errorf("Expected coordinator to have address %s, found %s", staleCoordinator.Addr(), broker.Addr()) + } + + if staleCoordinator.BrokerID() != broker.ID() { + t.Errorf("Expected coordinator to have ID %d, found %d", staleCoordinator.BrokerID(), broker.ID()) + } + + // Grab the cached value + broker2, err := client.Coordinator("my_group") + if err != nil { + t.Error(err) + } + + if broker2.Addr() != broker.Addr() { + t.Errorf("Expected the coordinator to be the same, but found %s vs. %s", broker2.Addr(), broker.Addr()) + } + + coordinatorResponse3 := new(ConsumerMetadataResponse) + coordinatorResponse3.CoordinatorID = freshCoordinator.BrokerID() + coordinatorResponse3.CoordinatorHost = "127.0.0.1" + coordinatorResponse3.CoordinatorPort = freshCoordinator.Port() + + seedBroker.Returns(coordinatorResponse3) + + // Refresh the locally cahced value because it's stale + if err := client.RefreshCoordinator("my_group"); err != nil { + t.Error(err) + } + + // Grab the fresh value + broker3, err := client.Coordinator("my_group") + if err != nil { + t.Error(err) + } + + if broker3.Addr() != freshCoordinator.Addr() { + t.Errorf("Expected the freshCoordinator to be returned, but found %s.", broker3.Addr()) + } + + freshCoordinator.Close() + staleCoordinator.Close() + seedBroker.Close() + safeClose(t, client) +} + +func TestClientCoordinatorWithoutConsumerOffsetsTopic(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + coordinator := NewMockBroker(t, 2) + + metadataResponse1 := new(MetadataResponse) + seedBroker.Returns(metadataResponse1) + + config := NewConfig() + config.Metadata.Retry.Max = 1 + config.Metadata.Retry.Backoff = 0 + client, err := NewClient([]string{seedBroker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + coordinatorResponse1 := new(ConsumerMetadataResponse) + coordinatorResponse1.Err = ErrConsumerCoordinatorNotAvailable + seedBroker.Returns(coordinatorResponse1) + + metadataResponse2 := new(MetadataResponse) + metadataResponse2.AddTopic("__consumer_offsets", ErrUnknownTopicOrPartition) + seedBroker.Returns(metadataResponse2) + + replicas := []int32{coordinator.BrokerID()} + metadataResponse3 := new(MetadataResponse) + metadataResponse3.AddTopicPartition("__consumer_offsets", 0, replicas[0], replicas, replicas, ErrNoError) + seedBroker.Returns(metadataResponse3) + + coordinatorResponse2 := new(ConsumerMetadataResponse) + coordinatorResponse2.CoordinatorID = coordinator.BrokerID() + coordinatorResponse2.CoordinatorHost = "127.0.0.1" + coordinatorResponse2.CoordinatorPort = coordinator.Port() + + seedBroker.Returns(coordinatorResponse2) + + broker, err := client.Coordinator("my_group") + if err != nil { + t.Error(err) + } + + if coordinator.Addr() != broker.Addr() { + t.Errorf("Expected coordinator to have address %s, found %s", coordinator.Addr(), broker.Addr()) + } + + if coordinator.BrokerID() != broker.ID() { + t.Errorf("Expected coordinator to have ID %d, found %d", coordinator.BrokerID(), broker.ID()) + } + + coordinator.Close() + seedBroker.Close() + safeClose(t, client) +} + +func TestClientAutorefreshShutdownRace(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + + metadataResponse := new(MetadataResponse) + seedBroker.Returns(metadataResponse) + + conf := NewConfig() + conf.Metadata.RefreshFrequency = 100 * time.Millisecond + client, err := NewClient([]string{seedBroker.Addr()}, conf) + if err != nil { + t.Fatal(err) + } + + // Wait for the background refresh to kick in + time.Sleep(110 * time.Millisecond) + + done := make(chan none) + go func() { + // Close the client + if err := client.Close(); err != nil { + t.Fatal(err) + } + close(done) + }() + + // Wait for the Close to kick in + time.Sleep(10 * time.Millisecond) + + // Then return some metadata to the still-running background thread + leader := NewMockBroker(t, 2) + metadataResponse.AddBroker(leader.Addr(), leader.BrokerID()) + metadataResponse.AddTopicPartition("foo", 0, leader.BrokerID(), []int32{2}, []int32{2}, ErrNoError) + seedBroker.Returns(metadataResponse) + + <-done + + seedBroker.Close() + + // give the update time to happen so we get a panic if it's still running (which it shouldn't) + time.Sleep(10 * time.Millisecond) +} diff --git a/vendor/github.com/Shopify/sarama/config.go b/vendor/github.com/Shopify/sarama/config.go new file mode 100644 index 00000000..5021c57e --- /dev/null +++ b/vendor/github.com/Shopify/sarama/config.go @@ -0,0 +1,420 @@ +package sarama + +import ( + "crypto/tls" + "regexp" + "time" + + "github.com/rcrowley/go-metrics" +) + +const defaultClientID = "sarama" + +var validID = regexp.MustCompile(`\A[A-Za-z0-9._-]+\z`) + +// Config is used to pass multiple configuration options to Sarama's constructors. +type Config struct { + // Net is the namespace for network-level properties used by the Broker, and + // shared by the Client/Producer/Consumer. + Net struct { + // How many outstanding requests a connection is allowed to have before + // sending on it blocks (default 5). + MaxOpenRequests int + + // All three of the below configurations are similar to the + // `socket.timeout.ms` setting in JVM kafka. All of them default + // to 30 seconds. + DialTimeout time.Duration // How long to wait for the initial connection. + ReadTimeout time.Duration // How long to wait for a response. + WriteTimeout time.Duration // How long to wait for a transmit. + + TLS struct { + // Whether or not to use TLS when connecting to the broker + // (defaults to false). + Enable bool + // The TLS configuration to use for secure connections if + // enabled (defaults to nil). + Config *tls.Config + } + + // SASL based authentication with broker. While there are multiple SASL authentication methods + // the current implementation is limited to plaintext (SASL/PLAIN) authentication + SASL struct { + // Whether or not to use SASL authentication when connecting to the broker + // (defaults to false). + Enable bool + // Whether or not to send the Kafka SASL handshake first if enabled + // (defaults to true). You should only set this to false if you're using + // a non-Kafka SASL proxy. + Handshake bool + //username and password for SASL/PLAIN authentication + User string + Password string + } + + // KeepAlive specifies the keep-alive period for an active network connection. + // If zero, keep-alives are disabled. (default is 0: disabled). + KeepAlive time.Duration + } + + // Metadata is the namespace for metadata management properties used by the + // Client, and shared by the Producer/Consumer. + Metadata struct { + Retry struct { + // The total number of times to retry a metadata request when the + // cluster is in the middle of a leader election (default 3). + Max int + // How long to wait for leader election to occur before retrying + // (default 250ms). Similar to the JVM's `retry.backoff.ms`. + Backoff time.Duration + } + // How frequently to refresh the cluster metadata in the background. + // Defaults to 10 minutes. Set to 0 to disable. Similar to + // `topic.metadata.refresh.interval.ms` in the JVM version. + RefreshFrequency time.Duration + } + + // Producer is the namespace for configuration related to producing messages, + // used by the Producer. + Producer struct { + // The maximum permitted size of a message (defaults to 1000000). Should be + // set equal to or smaller than the broker's `message.max.bytes`. + MaxMessageBytes int + // The level of acknowledgement reliability needed from the broker (defaults + // to WaitForLocal). Equivalent to the `request.required.acks` setting of the + // JVM producer. + RequiredAcks RequiredAcks + // The maximum duration the broker will wait the receipt of the number of + // RequiredAcks (defaults to 10 seconds). This is only relevant when + // RequiredAcks is set to WaitForAll or a number > 1. Only supports + // millisecond resolution, nanoseconds will be truncated. Equivalent to + // the JVM producer's `request.timeout.ms` setting. + Timeout time.Duration + // The type of compression to use on messages (defaults to no compression). + // Similar to `compression.codec` setting of the JVM producer. + Compression CompressionCodec + // Generates partitioners for choosing the partition to send messages to + // (defaults to hashing the message key). Similar to the `partitioner.class` + // setting for the JVM producer. + Partitioner PartitionerConstructor + + // Return specifies what channels will be populated. If they are set to true, + // you must read from the respective channels to prevent deadlock. + Return struct { + // If enabled, successfully delivered messages will be returned on the + // Successes channel (default disabled). + Successes bool + + // If enabled, messages that failed to deliver will be returned on the + // Errors channel, including error (default enabled). + Errors bool + } + + // The following config options control how often messages are batched up and + // sent to the broker. By default, messages are sent as fast as possible, and + // all messages received while the current batch is in-flight are placed + // into the subsequent batch. + Flush struct { + // The best-effort number of bytes needed to trigger a flush. Use the + // global sarama.MaxRequestSize to set a hard upper limit. + Bytes int + // The best-effort number of messages needed to trigger a flush. Use + // `MaxMessages` to set a hard upper limit. + Messages int + // The best-effort frequency of flushes. Equivalent to + // `queue.buffering.max.ms` setting of JVM producer. + Frequency time.Duration + // The maximum number of messages the producer will send in a single + // broker request. Defaults to 0 for unlimited. Similar to + // `queue.buffering.max.messages` in the JVM producer. + MaxMessages int + } + + Retry struct { + // The total number of times to retry sending a message (default 3). + // Similar to the `message.send.max.retries` setting of the JVM producer. + Max int + // How long to wait for the cluster to settle between retries + // (default 100ms). Similar to the `retry.backoff.ms` setting of the + // JVM producer. + Backoff time.Duration + } + } + + // Consumer is the namespace for configuration related to consuming messages, + // used by the Consumer. + // + // Note that Sarama's Consumer type does not currently support automatic + // consumer-group rebalancing and offset tracking. For Zookeeper-based + // tracking (Kafka 0.8.2 and earlier), the https://github.com/wvanbergen/kafka + // library builds on Sarama to add this support. For Kafka-based tracking + // (Kafka 0.9 and later), the https://github.com/bsm/sarama-cluster library + // builds on Sarama to add this support. + Consumer struct { + Retry struct { + // How long to wait after a failing to read from a partition before + // trying again (default 2s). + Backoff time.Duration + } + + // Fetch is the namespace for controlling how many bytes are retrieved by any + // given request. + Fetch struct { + // The minimum number of message bytes to fetch in a request - the broker + // will wait until at least this many are available. The default is 1, + // as 0 causes the consumer to spin when no messages are available. + // Equivalent to the JVM's `fetch.min.bytes`. + Min int32 + // The default number of message bytes to fetch from the broker in each + // request (default 32768). This should be larger than the majority of + // your messages, or else the consumer will spend a lot of time + // negotiating sizes and not actually consuming. Similar to the JVM's + // `fetch.message.max.bytes`. + Default int32 + // The maximum number of message bytes to fetch from the broker in a + // single request. Messages larger than this will return + // ErrMessageTooLarge and will not be consumable, so you must be sure + // this is at least as large as your largest message. Defaults to 0 + // (no limit). Similar to the JVM's `fetch.message.max.bytes`. The + // global `sarama.MaxResponseSize` still applies. + Max int32 + } + // The maximum amount of time the broker will wait for Consumer.Fetch.Min + // bytes to become available before it returns fewer than that anyways. The + // default is 250ms, since 0 causes the consumer to spin when no events are + // available. 100-500ms is a reasonable range for most cases. Kafka only + // supports precision up to milliseconds; nanoseconds will be truncated. + // Equivalent to the JVM's `fetch.wait.max.ms`. + MaxWaitTime time.Duration + + // The maximum amount of time the consumer expects a message takes to process + // for the user. If writing to the Messages channel takes longer than this, + // that partition will stop fetching more messages until it can proceed again. + // Note that, since the Messages channel is buffered, the actual grace time is + // (MaxProcessingTime * ChanneBufferSize). Defaults to 100ms. + MaxProcessingTime time.Duration + + // Return specifies what channels will be populated. If they are set to true, + // you must read from them to prevent deadlock. + Return struct { + // If enabled, any errors that occurred while consuming are returned on + // the Errors channel (default disabled). + Errors bool + } + + // Offsets specifies configuration for how and when to commit consumed + // offsets. This currently requires the manual use of an OffsetManager + // but will eventually be automated. + Offsets struct { + // How frequently to commit updated offsets. Defaults to 1s. + CommitInterval time.Duration + + // The initial offset to use if no offset was previously committed. + // Should be OffsetNewest or OffsetOldest. Defaults to OffsetNewest. + Initial int64 + + // The retention duration for committed offsets. If zero, disabled + // (in which case the `offsets.retention.minutes` option on the + // broker will be used). Kafka only supports precision up to + // milliseconds; nanoseconds will be truncated. Requires Kafka + // broker version 0.9.0 or later. + // (default is 0: disabled). + Retention time.Duration + } + } + + // A user-provided string sent with every request to the brokers for logging, + // debugging, and auditing purposes. Defaults to "sarama", but you should + // probably set it to something specific to your application. + ClientID string + // The number of events to buffer in internal and external channels. This + // permits the producer and consumer to continue processing some messages + // in the background while user code is working, greatly improving throughput. + // Defaults to 256. + ChannelBufferSize int + // The version of Kafka that Sarama will assume it is running against. + // Defaults to the oldest supported stable version. Since Kafka provides + // backwards-compatibility, setting it to a version older than you have + // will not break anything, although it may prevent you from using the + // latest features. Setting it to a version greater than you are actually + // running may lead to random breakage. + Version KafkaVersion + // The registry to define metrics into. + // Defaults to a local registry. + // If you want to disable metrics gathering, set "metrics.UseNilMetrics" to "true" + // prior to starting Sarama. + // See Examples on how to use the metrics registry + MetricRegistry metrics.Registry +} + +// NewConfig returns a new configuration instance with sane defaults. +func NewConfig() *Config { + c := &Config{} + + c.Net.MaxOpenRequests = 5 + c.Net.DialTimeout = 30 * time.Second + c.Net.ReadTimeout = 30 * time.Second + c.Net.WriteTimeout = 30 * time.Second + c.Net.SASL.Handshake = true + + c.Metadata.Retry.Max = 3 + c.Metadata.Retry.Backoff = 250 * time.Millisecond + c.Metadata.RefreshFrequency = 10 * time.Minute + + c.Producer.MaxMessageBytes = 1000000 + c.Producer.RequiredAcks = WaitForLocal + c.Producer.Timeout = 10 * time.Second + c.Producer.Partitioner = NewHashPartitioner + c.Producer.Retry.Max = 3 + c.Producer.Retry.Backoff = 100 * time.Millisecond + c.Producer.Return.Errors = true + + c.Consumer.Fetch.Min = 1 + c.Consumer.Fetch.Default = 32768 + c.Consumer.Retry.Backoff = 2 * time.Second + c.Consumer.MaxWaitTime = 250 * time.Millisecond + c.Consumer.MaxProcessingTime = 100 * time.Millisecond + c.Consumer.Return.Errors = false + c.Consumer.Offsets.CommitInterval = 1 * time.Second + c.Consumer.Offsets.Initial = OffsetNewest + + c.ClientID = defaultClientID + c.ChannelBufferSize = 256 + c.Version = minVersion + c.MetricRegistry = metrics.NewRegistry() + + return c +} + +// Validate checks a Config instance. It will return a +// ConfigurationError if the specified values don't make sense. +func (c *Config) Validate() error { + // some configuration values should be warned on but not fail completely, do those first + if c.Net.TLS.Enable == false && c.Net.TLS.Config != nil { + Logger.Println("Net.TLS is disabled but a non-nil configuration was provided.") + } + if c.Net.SASL.Enable == false { + if c.Net.SASL.User != "" { + Logger.Println("Net.SASL is disabled but a non-empty username was provided.") + } + if c.Net.SASL.Password != "" { + Logger.Println("Net.SASL is disabled but a non-empty password was provided.") + } + } + if c.Producer.RequiredAcks > 1 { + Logger.Println("Producer.RequiredAcks > 1 is deprecated and will raise an exception with kafka >= 0.8.2.0.") + } + if c.Producer.MaxMessageBytes >= int(MaxRequestSize) { + Logger.Println("Producer.MaxMessageBytes must be smaller than MaxRequestSize; it will be ignored.") + } + if c.Producer.Flush.Bytes >= int(MaxRequestSize) { + Logger.Println("Producer.Flush.Bytes must be smaller than MaxRequestSize; it will be ignored.") + } + if (c.Producer.Flush.Bytes > 0 || c.Producer.Flush.Messages > 0) && c.Producer.Flush.Frequency == 0 { + Logger.Println("Producer.Flush: Bytes or Messages are set, but Frequency is not; messages may not get flushed.") + } + if c.Producer.Timeout%time.Millisecond != 0 { + Logger.Println("Producer.Timeout only supports millisecond resolution; nanoseconds will be truncated.") + } + if c.Consumer.MaxWaitTime < 100*time.Millisecond { + Logger.Println("Consumer.MaxWaitTime is very low, which can cause high CPU and network usage. See documentation for details.") + } + if c.Consumer.MaxWaitTime%time.Millisecond != 0 { + Logger.Println("Consumer.MaxWaitTime only supports millisecond precision; nanoseconds will be truncated.") + } + if c.Consumer.Offsets.Retention%time.Millisecond != 0 { + Logger.Println("Consumer.Offsets.Retention only supports millisecond precision; nanoseconds will be truncated.") + } + if c.ClientID == defaultClientID { + Logger.Println("ClientID is the default of 'sarama', you should consider setting it to something application-specific.") + } + + // validate Net values + switch { + case c.Net.MaxOpenRequests <= 0: + return ConfigurationError("Net.MaxOpenRequests must be > 0") + case c.Net.DialTimeout <= 0: + return ConfigurationError("Net.DialTimeout must be > 0") + case c.Net.ReadTimeout <= 0: + return ConfigurationError("Net.ReadTimeout must be > 0") + case c.Net.WriteTimeout <= 0: + return ConfigurationError("Net.WriteTimeout must be > 0") + case c.Net.KeepAlive < 0: + return ConfigurationError("Net.KeepAlive must be >= 0") + case c.Net.SASL.Enable == true && c.Net.SASL.User == "": + return ConfigurationError("Net.SASL.User must not be empty when SASL is enabled") + case c.Net.SASL.Enable == true && c.Net.SASL.Password == "": + return ConfigurationError("Net.SASL.Password must not be empty when SASL is enabled") + } + + // validate the Metadata values + switch { + case c.Metadata.Retry.Max < 0: + return ConfigurationError("Metadata.Retry.Max must be >= 0") + case c.Metadata.Retry.Backoff < 0: + return ConfigurationError("Metadata.Retry.Backoff must be >= 0") + case c.Metadata.RefreshFrequency < 0: + return ConfigurationError("Metadata.RefreshFrequency must be >= 0") + } + + // validate the Producer values + switch { + case c.Producer.MaxMessageBytes <= 0: + return ConfigurationError("Producer.MaxMessageBytes must be > 0") + case c.Producer.RequiredAcks < -1: + return ConfigurationError("Producer.RequiredAcks must be >= -1") + case c.Producer.Timeout <= 0: + return ConfigurationError("Producer.Timeout must be > 0") + case c.Producer.Partitioner == nil: + return ConfigurationError("Producer.Partitioner must not be nil") + case c.Producer.Flush.Bytes < 0: + return ConfigurationError("Producer.Flush.Bytes must be >= 0") + case c.Producer.Flush.Messages < 0: + return ConfigurationError("Producer.Flush.Messages must be >= 0") + case c.Producer.Flush.Frequency < 0: + return ConfigurationError("Producer.Flush.Frequency must be >= 0") + case c.Producer.Flush.MaxMessages < 0: + return ConfigurationError("Producer.Flush.MaxMessages must be >= 0") + case c.Producer.Flush.MaxMessages > 0 && c.Producer.Flush.MaxMessages < c.Producer.Flush.Messages: + return ConfigurationError("Producer.Flush.MaxMessages must be >= Producer.Flush.Messages when set") + case c.Producer.Retry.Max < 0: + return ConfigurationError("Producer.Retry.Max must be >= 0") + case c.Producer.Retry.Backoff < 0: + return ConfigurationError("Producer.Retry.Backoff must be >= 0") + } + + if c.Producer.Compression == CompressionLZ4 && !c.Version.IsAtLeast(V0_10_0_0) { + return ConfigurationError("lz4 compression requires Version >= V0_10_0_0") + } + + // validate the Consumer values + switch { + case c.Consumer.Fetch.Min <= 0: + return ConfigurationError("Consumer.Fetch.Min must be > 0") + case c.Consumer.Fetch.Default <= 0: + return ConfigurationError("Consumer.Fetch.Default must be > 0") + case c.Consumer.Fetch.Max < 0: + return ConfigurationError("Consumer.Fetch.Max must be >= 0") + case c.Consumer.MaxWaitTime < 1*time.Millisecond: + return ConfigurationError("Consumer.MaxWaitTime must be >= 1ms") + case c.Consumer.MaxProcessingTime <= 0: + return ConfigurationError("Consumer.MaxProcessingTime must be > 0") + case c.Consumer.Retry.Backoff < 0: + return ConfigurationError("Consumer.Retry.Backoff must be >= 0") + case c.Consumer.Offsets.CommitInterval <= 0: + return ConfigurationError("Consumer.Offsets.CommitInterval must be > 0") + case c.Consumer.Offsets.Initial != OffsetOldest && c.Consumer.Offsets.Initial != OffsetNewest: + return ConfigurationError("Consumer.Offsets.Initial must be OffsetOldest or OffsetNewest") + + } + + // validate misc shared values + switch { + case c.ChannelBufferSize < 0: + return ConfigurationError("ChannelBufferSize must be >= 0") + case !validID.MatchString(c.ClientID): + return ConfigurationError("ClientID is invalid") + } + + return nil +} diff --git a/vendor/github.com/Shopify/sarama/config_test.go b/vendor/github.com/Shopify/sarama/config_test.go new file mode 100644 index 00000000..5fef6b36 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/config_test.go @@ -0,0 +1,70 @@ +package sarama + +import ( + "os" + "testing" + + "github.com/rcrowley/go-metrics" +) + +func TestDefaultConfigValidates(t *testing.T) { + config := NewConfig() + if err := config.Validate(); err != nil { + t.Error(err) + } + if config.MetricRegistry == nil { + t.Error("Expected non nil metrics.MetricRegistry, got nil") + } +} + +func TestInvalidClientIDConfigValidates(t *testing.T) { + config := NewConfig() + config.ClientID = "foo:bar" + if err := config.Validate(); string(err.(ConfigurationError)) != "ClientID is invalid" { + t.Error("Expected invalid ClientID, got ", err) + } +} + +func TestEmptyClientIDConfigValidates(t *testing.T) { + config := NewConfig() + config.ClientID = "" + if err := config.Validate(); string(err.(ConfigurationError)) != "ClientID is invalid" { + t.Error("Expected invalid ClientID, got ", err) + } +} + +func TestLZ4ConfigValidation(t *testing.T) { + config := NewConfig() + config.Producer.Compression = CompressionLZ4 + if err := config.Validate(); string(err.(ConfigurationError)) != "lz4 compression requires Version >= V0_10_0_0" { + t.Error("Expected invalid lz4/kakfa version error, got ", err) + } + config.Version = V0_10_0_0 + if err := config.Validate(); err != nil { + t.Error("Expected lz4 to work, got ", err) + } +} + +// This example shows how to integrate with an existing registry as well as publishing metrics +// on the standard output +func ExampleConfig_metrics() { + // Our application registry + appMetricRegistry := metrics.NewRegistry() + appGauge := metrics.GetOrRegisterGauge("m1", appMetricRegistry) + appGauge.Update(1) + + config := NewConfig() + // Use a prefix registry instead of the default local one + config.MetricRegistry = metrics.NewPrefixedChildRegistry(appMetricRegistry, "sarama.") + + // Simulate a metric created by sarama without starting a broker + saramaGauge := metrics.GetOrRegisterGauge("m2", config.MetricRegistry) + saramaGauge.Update(2) + + metrics.WriteOnce(appMetricRegistry, os.Stdout) + // Output: + // gauge m1 + // value: 1 + // gauge sarama.m2 + // value: 2 +} diff --git a/vendor/github.com/Shopify/sarama/consumer.go b/vendor/github.com/Shopify/sarama/consumer.go new file mode 100644 index 00000000..78d7fa2c --- /dev/null +++ b/vendor/github.com/Shopify/sarama/consumer.go @@ -0,0 +1,735 @@ +package sarama + +import ( + "errors" + "fmt" + "sync" + "sync/atomic" + "time" +) + +// ConsumerMessage encapsulates a Kafka message returned by the consumer. +type ConsumerMessage struct { + Key, Value []byte + Topic string + Partition int32 + Offset int64 + Timestamp time.Time // only set if kafka is version 0.10+ +} + +// ConsumerError is what is provided to the user when an error occurs. +// It wraps an error and includes the topic and partition. +type ConsumerError struct { + Topic string + Partition int32 + Err error +} + +func (ce ConsumerError) Error() string { + return fmt.Sprintf("kafka: error while consuming %s/%d: %s", ce.Topic, ce.Partition, ce.Err) +} + +// ConsumerErrors is a type that wraps a batch of errors and implements the Error interface. +// It can be returned from the PartitionConsumer's Close methods to avoid the need to manually drain errors +// when stopping. +type ConsumerErrors []*ConsumerError + +func (ce ConsumerErrors) Error() string { + return fmt.Sprintf("kafka: %d errors while consuming", len(ce)) +} + +// Consumer manages PartitionConsumers which process Kafka messages from brokers. You MUST call Close() +// on a consumer to avoid leaks, it will not be garbage-collected automatically when it passes out of +// scope. +// +// Sarama's Consumer type does not currently support automatic consumer-group rebalancing and offset tracking. +// For Zookeeper-based tracking (Kafka 0.8.2 and earlier), the https://github.com/wvanbergen/kafka library +// builds on Sarama to add this support. For Kafka-based tracking (Kafka 0.9 and later), the +// https://github.com/bsm/sarama-cluster library builds on Sarama to add this support. +type Consumer interface { + + // Topics returns the set of available topics as retrieved from the cluster + // metadata. This method is the same as Client.Topics(), and is provided for + // convenience. + Topics() ([]string, error) + + // Partitions returns the sorted list of all partition IDs for the given topic. + // This method is the same as Client.Partitions(), and is provided for convenience. + Partitions(topic string) ([]int32, error) + + // ConsumePartition creates a PartitionConsumer on the given topic/partition with + // the given offset. It will return an error if this Consumer is already consuming + // on the given topic/partition. Offset can be a literal offset, or OffsetNewest + // or OffsetOldest + ConsumePartition(topic string, partition int32, offset int64) (PartitionConsumer, error) + + // HighWaterMarks returns the current high water marks for each topic and partition. + // Consistency between partitions is not guaranteed since high water marks are updated separately. + HighWaterMarks() map[string]map[int32]int64 + + // Close shuts down the consumer. It must be called after all child + // PartitionConsumers have already been closed. + Close() error +} + +type consumer struct { + client Client + conf *Config + ownClient bool + + lock sync.Mutex + children map[string]map[int32]*partitionConsumer + brokerConsumers map[*Broker]*brokerConsumer +} + +// NewConsumer creates a new consumer using the given broker addresses and configuration. +func NewConsumer(addrs []string, config *Config) (Consumer, error) { + client, err := NewClient(addrs, config) + if err != nil { + return nil, err + } + + c, err := NewConsumerFromClient(client) + if err != nil { + return nil, err + } + c.(*consumer).ownClient = true + return c, nil +} + +// NewConsumerFromClient creates a new consumer using the given client. It is still +// necessary to call Close() on the underlying client when shutting down this consumer. +func NewConsumerFromClient(client Client) (Consumer, error) { + // Check that we are not dealing with a closed Client before processing any other arguments + if client.Closed() { + return nil, ErrClosedClient + } + + c := &consumer{ + client: client, + conf: client.Config(), + children: make(map[string]map[int32]*partitionConsumer), + brokerConsumers: make(map[*Broker]*brokerConsumer), + } + + return c, nil +} + +func (c *consumer) Close() error { + if c.ownClient { + return c.client.Close() + } + return nil +} + +func (c *consumer) Topics() ([]string, error) { + return c.client.Topics() +} + +func (c *consumer) Partitions(topic string) ([]int32, error) { + return c.client.Partitions(topic) +} + +func (c *consumer) ConsumePartition(topic string, partition int32, offset int64) (PartitionConsumer, error) { + child := &partitionConsumer{ + consumer: c, + conf: c.conf, + topic: topic, + partition: partition, + messages: make(chan *ConsumerMessage, c.conf.ChannelBufferSize), + errors: make(chan *ConsumerError, c.conf.ChannelBufferSize), + feeder: make(chan *FetchResponse, 1), + trigger: make(chan none, 1), + dying: make(chan none), + fetchSize: c.conf.Consumer.Fetch.Default, + } + + if err := child.chooseStartingOffset(offset); err != nil { + return nil, err + } + + var leader *Broker + var err error + if leader, err = c.client.Leader(child.topic, child.partition); err != nil { + return nil, err + } + + if err := c.addChild(child); err != nil { + return nil, err + } + + go withRecover(child.dispatcher) + go withRecover(child.responseFeeder) + + child.broker = c.refBrokerConsumer(leader) + child.broker.input <- child + + return child, nil +} + +func (c *consumer) HighWaterMarks() map[string]map[int32]int64 { + c.lock.Lock() + defer c.lock.Unlock() + + hwms := make(map[string]map[int32]int64) + for topic, p := range c.children { + hwm := make(map[int32]int64, len(p)) + for partition, pc := range p { + hwm[partition] = pc.HighWaterMarkOffset() + } + hwms[topic] = hwm + } + + return hwms +} + +func (c *consumer) addChild(child *partitionConsumer) error { + c.lock.Lock() + defer c.lock.Unlock() + + topicChildren := c.children[child.topic] + if topicChildren == nil { + topicChildren = make(map[int32]*partitionConsumer) + c.children[child.topic] = topicChildren + } + + if topicChildren[child.partition] != nil { + return ConfigurationError("That topic/partition is already being consumed") + } + + topicChildren[child.partition] = child + return nil +} + +func (c *consumer) removeChild(child *partitionConsumer) { + c.lock.Lock() + defer c.lock.Unlock() + + delete(c.children[child.topic], child.partition) +} + +func (c *consumer) refBrokerConsumer(broker *Broker) *brokerConsumer { + c.lock.Lock() + defer c.lock.Unlock() + + bc := c.brokerConsumers[broker] + if bc == nil { + bc = c.newBrokerConsumer(broker) + c.brokerConsumers[broker] = bc + } + + bc.refs++ + + return bc +} + +func (c *consumer) unrefBrokerConsumer(brokerWorker *brokerConsumer) { + c.lock.Lock() + defer c.lock.Unlock() + + brokerWorker.refs-- + + if brokerWorker.refs == 0 { + close(brokerWorker.input) + if c.brokerConsumers[brokerWorker.broker] == brokerWorker { + delete(c.brokerConsumers, brokerWorker.broker) + } + } +} + +func (c *consumer) abandonBrokerConsumer(brokerWorker *brokerConsumer) { + c.lock.Lock() + defer c.lock.Unlock() + + delete(c.brokerConsumers, brokerWorker.broker) +} + +// PartitionConsumer + +// PartitionConsumer processes Kafka messages from a given topic and partition. You MUST call Close() +// or AsyncClose() on a PartitionConsumer to avoid leaks, it will not be garbage-collected automatically +// when it passes out of scope. +// +// The simplest way of using a PartitionConsumer is to loop over its Messages channel using a for/range +// loop. The PartitionConsumer will only stop itself in one case: when the offset being consumed is reported +// as out of range by the brokers. In this case you should decide what you want to do (try a different offset, +// notify a human, etc) and handle it appropriately. For all other error cases, it will just keep retrying. +// By default, it logs these errors to sarama.Logger; if you want to be notified directly of all errors, set +// your config's Consumer.Return.Errors to true and read from the Errors channel, using a select statement +// or a separate goroutine. Check out the Consumer examples to see implementations of these different approaches. +type PartitionConsumer interface { + + // AsyncClose initiates a shutdown of the PartitionConsumer. This method will + // return immediately, after which you should wait until the 'messages' and + // 'errors' channel are drained. It is required to call this function, or + // Close before a consumer object passes out of scope, as it will otherwise + // leak memory. You must call this before calling Close on the underlying client. + AsyncClose() + + // Close stops the PartitionConsumer from fetching messages. It is required to + // call this function (or AsyncClose) before a consumer object passes out of + // scope, as it will otherwise leak memory. You must call this before calling + // Close on the underlying client. + Close() error + + // Messages returns the read channel for the messages that are returned by + // the broker. + Messages() <-chan *ConsumerMessage + + // Errors returns a read channel of errors that occurred during consuming, if + // enabled. By default, errors are logged and not returned over this channel. + // If you want to implement any custom error handling, set your config's + // Consumer.Return.Errors setting to true, and read from this channel. + Errors() <-chan *ConsumerError + + // HighWaterMarkOffset returns the high water mark offset of the partition, + // i.e. the offset that will be used for the next message that will be produced. + // You can use this to determine how far behind the processing is. + HighWaterMarkOffset() int64 +} + +type partitionConsumer struct { + highWaterMarkOffset int64 // must be at the top of the struct because https://golang.org/pkg/sync/atomic/#pkg-note-BUG + consumer *consumer + conf *Config + topic string + partition int32 + + broker *brokerConsumer + messages chan *ConsumerMessage + errors chan *ConsumerError + feeder chan *FetchResponse + + trigger, dying chan none + responseResult error + + fetchSize int32 + offset int64 +} + +var errTimedOut = errors.New("timed out feeding messages to the user") // not user-facing + +func (child *partitionConsumer) sendError(err error) { + cErr := &ConsumerError{ + Topic: child.topic, + Partition: child.partition, + Err: err, + } + + if child.conf.Consumer.Return.Errors { + child.errors <- cErr + } else { + Logger.Println(cErr) + } +} + +func (child *partitionConsumer) dispatcher() { + for range child.trigger { + select { + case <-child.dying: + close(child.trigger) + case <-time.After(child.conf.Consumer.Retry.Backoff): + if child.broker != nil { + child.consumer.unrefBrokerConsumer(child.broker) + child.broker = nil + } + + Logger.Printf("consumer/%s/%d finding new broker\n", child.topic, child.partition) + if err := child.dispatch(); err != nil { + child.sendError(err) + child.trigger <- none{} + } + } + } + + if child.broker != nil { + child.consumer.unrefBrokerConsumer(child.broker) + } + child.consumer.removeChild(child) + close(child.feeder) +} + +func (child *partitionConsumer) dispatch() error { + if err := child.consumer.client.RefreshMetadata(child.topic); err != nil { + return err + } + + var leader *Broker + var err error + if leader, err = child.consumer.client.Leader(child.topic, child.partition); err != nil { + return err + } + + child.broker = child.consumer.refBrokerConsumer(leader) + + child.broker.input <- child + + return nil +} + +func (child *partitionConsumer) chooseStartingOffset(offset int64) error { + newestOffset, err := child.consumer.client.GetOffset(child.topic, child.partition, OffsetNewest) + if err != nil { + return err + } + oldestOffset, err := child.consumer.client.GetOffset(child.topic, child.partition, OffsetOldest) + if err != nil { + return err + } + + switch { + case offset == OffsetNewest: + child.offset = newestOffset + case offset == OffsetOldest: + child.offset = oldestOffset + case offset >= oldestOffset && offset <= newestOffset: + child.offset = offset + default: + return ErrOffsetOutOfRange + } + + return nil +} + +func (child *partitionConsumer) Messages() <-chan *ConsumerMessage { + return child.messages +} + +func (child *partitionConsumer) Errors() <-chan *ConsumerError { + return child.errors +} + +func (child *partitionConsumer) AsyncClose() { + // this triggers whatever broker owns this child to abandon it and close its trigger channel, which causes + // the dispatcher to exit its loop, which removes it from the consumer then closes its 'messages' and + // 'errors' channel (alternatively, if the child is already at the dispatcher for some reason, that will + // also just close itself) + close(child.dying) +} + +func (child *partitionConsumer) Close() error { + child.AsyncClose() + + go withRecover(func() { + for range child.messages { + // drain + } + }) + + var errors ConsumerErrors + for err := range child.errors { + errors = append(errors, err) + } + + if len(errors) > 0 { + return errors + } + return nil +} + +func (child *partitionConsumer) HighWaterMarkOffset() int64 { + return atomic.LoadInt64(&child.highWaterMarkOffset) +} + +func (child *partitionConsumer) responseFeeder() { + var msgs []*ConsumerMessage + expiryTimer := time.NewTimer(child.conf.Consumer.MaxProcessingTime) + expireTimedOut := false + +feederLoop: + for response := range child.feeder { + msgs, child.responseResult = child.parseResponse(response) + + for i, msg := range msgs { + if !expiryTimer.Stop() && !expireTimedOut { + // expiryTimer was expired; clear out the waiting msg + <-expiryTimer.C + } + expiryTimer.Reset(child.conf.Consumer.MaxProcessingTime) + expireTimedOut = false + + select { + case child.messages <- msg: + case <-expiryTimer.C: + expireTimedOut = true + child.responseResult = errTimedOut + child.broker.acks.Done() + for _, msg = range msgs[i:] { + child.messages <- msg + } + child.broker.input <- child + continue feederLoop + } + } + + child.broker.acks.Done() + } + + close(child.messages) + close(child.errors) +} + +func (child *partitionConsumer) parseResponse(response *FetchResponse) ([]*ConsumerMessage, error) { + block := response.GetBlock(child.topic, child.partition) + if block == nil { + return nil, ErrIncompleteResponse + } + + if block.Err != ErrNoError { + return nil, block.Err + } + + if len(block.MsgSet.Messages) == 0 { + // We got no messages. If we got a trailing one then we need to ask for more data. + // Otherwise we just poll again and wait for one to be produced... + if block.MsgSet.PartialTrailingMessage { + if child.conf.Consumer.Fetch.Max > 0 && child.fetchSize == child.conf.Consumer.Fetch.Max { + // we can't ask for more data, we've hit the configured limit + child.sendError(ErrMessageTooLarge) + child.offset++ // skip this one so we can keep processing future messages + } else { + child.fetchSize *= 2 + if child.conf.Consumer.Fetch.Max > 0 && child.fetchSize > child.conf.Consumer.Fetch.Max { + child.fetchSize = child.conf.Consumer.Fetch.Max + } + } + } + + return nil, nil + } + + // we got messages, reset our fetch size in case it was increased for a previous request + child.fetchSize = child.conf.Consumer.Fetch.Default + atomic.StoreInt64(&child.highWaterMarkOffset, block.HighWaterMarkOffset) + + incomplete := false + prelude := true + var messages []*ConsumerMessage + for _, msgBlock := range block.MsgSet.Messages { + + for _, msg := range msgBlock.Messages() { + offset := msg.Offset + if msg.Msg.Version >= 1 { + baseOffset := msgBlock.Offset - msgBlock.Messages()[len(msgBlock.Messages())-1].Offset + offset += baseOffset + } + if prelude && offset < child.offset { + continue + } + prelude = false + + if offset >= child.offset { + messages = append(messages, &ConsumerMessage{ + Topic: child.topic, + Partition: child.partition, + Key: msg.Msg.Key, + Value: msg.Msg.Value, + Offset: offset, + Timestamp: msg.Msg.Timestamp, + }) + child.offset = offset + 1 + } else { + incomplete = true + } + } + + } + + if incomplete || len(messages) == 0 { + return nil, ErrIncompleteResponse + } + return messages, nil +} + +// brokerConsumer + +type brokerConsumer struct { + consumer *consumer + broker *Broker + input chan *partitionConsumer + newSubscriptions chan []*partitionConsumer + wait chan none + subscriptions map[*partitionConsumer]none + acks sync.WaitGroup + refs int +} + +func (c *consumer) newBrokerConsumer(broker *Broker) *brokerConsumer { + bc := &brokerConsumer{ + consumer: c, + broker: broker, + input: make(chan *partitionConsumer), + newSubscriptions: make(chan []*partitionConsumer), + wait: make(chan none), + subscriptions: make(map[*partitionConsumer]none), + refs: 0, + } + + go withRecover(bc.subscriptionManager) + go withRecover(bc.subscriptionConsumer) + + return bc +} + +func (bc *brokerConsumer) subscriptionManager() { + var buffer []*partitionConsumer + + // The subscriptionManager constantly accepts new subscriptions on `input` (even when the main subscriptionConsumer + // goroutine is in the middle of a network request) and batches it up. The main worker goroutine picks + // up a batch of new subscriptions between every network request by reading from `newSubscriptions`, so we give + // it nil if no new subscriptions are available. We also write to `wait` only when new subscriptions is available, + // so the main goroutine can block waiting for work if it has none. + for { + if len(buffer) > 0 { + select { + case event, ok := <-bc.input: + if !ok { + goto done + } + buffer = append(buffer, event) + case bc.newSubscriptions <- buffer: + buffer = nil + case bc.wait <- none{}: + } + } else { + select { + case event, ok := <-bc.input: + if !ok { + goto done + } + buffer = append(buffer, event) + case bc.newSubscriptions <- nil: + } + } + } + +done: + close(bc.wait) + if len(buffer) > 0 { + bc.newSubscriptions <- buffer + } + close(bc.newSubscriptions) +} + +func (bc *brokerConsumer) subscriptionConsumer() { + <-bc.wait // wait for our first piece of work + + // the subscriptionConsumer ensures we will get nil right away if no new subscriptions is available + for newSubscriptions := range bc.newSubscriptions { + bc.updateSubscriptions(newSubscriptions) + + if len(bc.subscriptions) == 0 { + // We're about to be shut down or we're about to receive more subscriptions. + // Either way, the signal just hasn't propagated to our goroutine yet. + <-bc.wait + continue + } + + response, err := bc.fetchNewMessages() + + if err != nil { + Logger.Printf("consumer/broker/%d disconnecting due to error processing FetchRequest: %s\n", bc.broker.ID(), err) + bc.abort(err) + return + } + + bc.acks.Add(len(bc.subscriptions)) + for child := range bc.subscriptions { + child.feeder <- response + } + bc.acks.Wait() + bc.handleResponses() + } +} + +func (bc *brokerConsumer) updateSubscriptions(newSubscriptions []*partitionConsumer) { + for _, child := range newSubscriptions { + bc.subscriptions[child] = none{} + Logger.Printf("consumer/broker/%d added subscription to %s/%d\n", bc.broker.ID(), child.topic, child.partition) + } + + for child := range bc.subscriptions { + select { + case <-child.dying: + Logger.Printf("consumer/broker/%d closed dead subscription to %s/%d\n", bc.broker.ID(), child.topic, child.partition) + close(child.trigger) + delete(bc.subscriptions, child) + default: + break + } + } +} + +func (bc *brokerConsumer) handleResponses() { + // handles the response codes left for us by our subscriptions, and abandons ones that have been closed + for child := range bc.subscriptions { + result := child.responseResult + child.responseResult = nil + + switch result { + case nil: + break + case errTimedOut: + Logger.Printf("consumer/broker/%d abandoned subscription to %s/%d because consuming was taking too long\n", + bc.broker.ID(), child.topic, child.partition) + delete(bc.subscriptions, child) + case ErrOffsetOutOfRange: + // there's no point in retrying this it will just fail the same way again + // shut it down and force the user to choose what to do + child.sendError(result) + Logger.Printf("consumer/%s/%d shutting down because %s\n", child.topic, child.partition, result) + close(child.trigger) + delete(bc.subscriptions, child) + case ErrUnknownTopicOrPartition, ErrNotLeaderForPartition, ErrLeaderNotAvailable, ErrReplicaNotAvailable: + // not an error, but does need redispatching + Logger.Printf("consumer/broker/%d abandoned subscription to %s/%d because %s\n", + bc.broker.ID(), child.topic, child.partition, result) + child.trigger <- none{} + delete(bc.subscriptions, child) + default: + // dunno, tell the user and try redispatching + child.sendError(result) + Logger.Printf("consumer/broker/%d abandoned subscription to %s/%d because %s\n", + bc.broker.ID(), child.topic, child.partition, result) + child.trigger <- none{} + delete(bc.subscriptions, child) + } + } +} + +func (bc *brokerConsumer) abort(err error) { + bc.consumer.abandonBrokerConsumer(bc) + _ = bc.broker.Close() // we don't care about the error this might return, we already have one + + for child := range bc.subscriptions { + child.sendError(err) + child.trigger <- none{} + } + + for newSubscriptions := range bc.newSubscriptions { + if len(newSubscriptions) == 0 { + <-bc.wait + continue + } + for _, child := range newSubscriptions { + child.sendError(err) + child.trigger <- none{} + } + } +} + +func (bc *brokerConsumer) fetchNewMessages() (*FetchResponse, error) { + request := &FetchRequest{ + MinBytes: bc.consumer.conf.Consumer.Fetch.Min, + MaxWaitTime: int32(bc.consumer.conf.Consumer.MaxWaitTime / time.Millisecond), + } + if bc.consumer.conf.Version.IsAtLeast(V0_10_0_0) { + request.Version = 2 + } + + for child := range bc.subscriptions { + request.AddBlock(child.topic, child.partition, child.offset, child.fetchSize) + } + + return bc.broker.Fetch(request) +} diff --git a/vendor/github.com/Shopify/sarama/consumer_group_members.go b/vendor/github.com/Shopify/sarama/consumer_group_members.go new file mode 100644 index 00000000..9d92d350 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/consumer_group_members.go @@ -0,0 +1,94 @@ +package sarama + +type ConsumerGroupMemberMetadata struct { + Version int16 + Topics []string + UserData []byte +} + +func (m *ConsumerGroupMemberMetadata) encode(pe packetEncoder) error { + pe.putInt16(m.Version) + + if err := pe.putStringArray(m.Topics); err != nil { + return err + } + + if err := pe.putBytes(m.UserData); err != nil { + return err + } + + return nil +} + +func (m *ConsumerGroupMemberMetadata) decode(pd packetDecoder) (err error) { + if m.Version, err = pd.getInt16(); err != nil { + return + } + + if m.Topics, err = pd.getStringArray(); err != nil { + return + } + + if m.UserData, err = pd.getBytes(); err != nil { + return + } + + return nil +} + +type ConsumerGroupMemberAssignment struct { + Version int16 + Topics map[string][]int32 + UserData []byte +} + +func (m *ConsumerGroupMemberAssignment) encode(pe packetEncoder) error { + pe.putInt16(m.Version) + + if err := pe.putArrayLength(len(m.Topics)); err != nil { + return err + } + + for topic, partitions := range m.Topics { + if err := pe.putString(topic); err != nil { + return err + } + if err := pe.putInt32Array(partitions); err != nil { + return err + } + } + + if err := pe.putBytes(m.UserData); err != nil { + return err + } + + return nil +} + +func (m *ConsumerGroupMemberAssignment) decode(pd packetDecoder) (err error) { + if m.Version, err = pd.getInt16(); err != nil { + return + } + + var topicLen int + if topicLen, err = pd.getArrayLength(); err != nil { + return + } + + m.Topics = make(map[string][]int32, topicLen) + for i := 0; i < topicLen; i++ { + var topic string + if topic, err = pd.getString(); err != nil { + return + } + if m.Topics[topic], err = pd.getInt32Array(); err != nil { + return + } + } + + if m.UserData, err = pd.getBytes(); err != nil { + return + } + + return nil +} diff --git a/vendor/github.com/Shopify/sarama/consumer_group_members_test.go b/vendor/github.com/Shopify/sarama/consumer_group_members_test.go new file mode 100644 index 00000000..d65e8adc --- /dev/null +++ b/vendor/github.com/Shopify/sarama/consumer_group_members_test.go @@ -0,0 +1,73 @@ +package sarama + +import ( + "bytes" + "reflect" + "testing" +) + +var ( + groupMemberMetadata = []byte{ + 0, 1, // Version + 0, 0, 0, 2, // Topic array length + 0, 3, 'o', 'n', 'e', // Topic one + 0, 3, 't', 'w', 'o', // Topic two + 0, 0, 0, 3, 0x01, 0x02, 0x03, // Userdata + } + groupMemberAssignment = []byte{ + 0, 1, // Version + 0, 0, 0, 1, // Topic array length + 0, 3, 'o', 'n', 'e', // Topic one + 0, 0, 0, 3, // Topic one, partition array length + 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 4, // 0, 2, 4 + 0, 0, 0, 3, 0x01, 0x02, 0x03, // Userdata + } +) + +func TestConsumerGroupMemberMetadata(t *testing.T) { + meta := &ConsumerGroupMemberMetadata{ + Version: 1, + Topics: []string{"one", "two"}, + UserData: []byte{0x01, 0x02, 0x03}, + } + + buf, err := encode(meta, nil) + if err != nil { + t.Error("Failed to encode data", err) + } else if !bytes.Equal(groupMemberMetadata, buf) { + t.Errorf("Encoded data does not match expectation\nexpected: %v\nactual: %v", groupMemberMetadata, buf) + } + + meta2 := new(ConsumerGroupMemberMetadata) + err = decode(buf, meta2) + if err != nil { + t.Error("Failed to decode data", err) + } else if !reflect.DeepEqual(meta, meta2) { + t.Errorf("Encoded data does not match expectation\nexpected: %v\nactual: %v", meta, meta2) + } +} + +func TestConsumerGroupMemberAssignment(t *testing.T) { + amt := &ConsumerGroupMemberAssignment{ + Version: 1, + Topics: map[string][]int32{ + "one": {0, 2, 4}, + }, + UserData: []byte{0x01, 0x02, 0x03}, + } + + buf, err := encode(amt, nil) + if err != nil { + t.Error("Failed to encode data", err) + } else if !bytes.Equal(groupMemberAssignment, buf) { + t.Errorf("Encoded data does not match expectation\nexpected: %v\nactual: %v", groupMemberAssignment, buf) + } + + amt2 := new(ConsumerGroupMemberAssignment) + err = decode(buf, amt2) + if err != nil { + t.Error("Failed to decode data", err) + } else if !reflect.DeepEqual(amt, amt2) { + t.Errorf("Encoded data does not match expectation\nexpected: %v\nactual: %v", amt, amt2) + } +} diff --git a/vendor/github.com/Shopify/sarama/consumer_metadata_request.go b/vendor/github.com/Shopify/sarama/consumer_metadata_request.go new file mode 100644 index 00000000..483be335 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/consumer_metadata_request.go @@ -0,0 +1,26 @@ +package sarama + +type ConsumerMetadataRequest struct { + ConsumerGroup string +} + +func (r *ConsumerMetadataRequest) encode(pe packetEncoder) error { + return pe.putString(r.ConsumerGroup) +} + +func (r *ConsumerMetadataRequest) decode(pd packetDecoder, version int16) (err error) { + r.ConsumerGroup, err = pd.getString() + return err +} + +func (r *ConsumerMetadataRequest) key() int16 { + return 10 +} + +func (r *ConsumerMetadataRequest) version() int16 { + return 0 +} + +func (r *ConsumerMetadataRequest) requiredVersion() KafkaVersion { + return V0_8_2_0 +} diff --git a/vendor/github.com/Shopify/sarama/consumer_metadata_request_test.go b/vendor/github.com/Shopify/sarama/consumer_metadata_request_test.go new file mode 100644 index 00000000..4509631a --- /dev/null +++ b/vendor/github.com/Shopify/sarama/consumer_metadata_request_test.go @@ -0,0 +1,19 @@ +package sarama + +import "testing" + +var ( + consumerMetadataRequestEmpty = []byte{ + 0x00, 0x00} + + consumerMetadataRequestString = []byte{ + 0x00, 0x06, 'f', 'o', 'o', 'b', 'a', 'r'} +) + +func TestConsumerMetadataRequest(t *testing.T) { + request := new(ConsumerMetadataRequest) + testRequest(t, "empty string", request, consumerMetadataRequestEmpty) + + request.ConsumerGroup = "foobar" + testRequest(t, "with string", request, consumerMetadataRequestString) +} diff --git a/vendor/github.com/Shopify/sarama/consumer_metadata_response.go b/vendor/github.com/Shopify/sarama/consumer_metadata_response.go new file mode 100644 index 00000000..6b9632bb --- /dev/null +++ b/vendor/github.com/Shopify/sarama/consumer_metadata_response.go @@ -0,0 +1,85 @@ +package sarama + +import ( + "net" + "strconv" +) + +type ConsumerMetadataResponse struct { + Err KError + Coordinator *Broker + CoordinatorID int32 // deprecated: use Coordinator.ID() + CoordinatorHost string // deprecated: use Coordinator.Addr() + CoordinatorPort int32 // deprecated: use Coordinator.Addr() +} + +func (r *ConsumerMetadataResponse) decode(pd packetDecoder, version int16) (err error) { + tmp, err := pd.getInt16() + if err != nil { + return err + } + r.Err = KError(tmp) + + coordinator := new(Broker) + if err := coordinator.decode(pd); err != nil { + return err + } + if coordinator.addr == ":0" { + return nil + } + r.Coordinator = coordinator + + // this can all go away in 2.0, but we have to fill in deprecated fields to maintain + // backwards compatibility + host, portstr, err := net.SplitHostPort(r.Coordinator.Addr()) + if err != nil { + return err + } + port, err := strconv.ParseInt(portstr, 10, 32) + if err != nil { + return err + } + r.CoordinatorID = r.Coordinator.ID() + r.CoordinatorHost = host + r.CoordinatorPort = int32(port) + + return nil +} + +func (r *ConsumerMetadataResponse) encode(pe packetEncoder) error { + pe.putInt16(int16(r.Err)) + if r.Coordinator != nil { + host, portstr, err := net.SplitHostPort(r.Coordinator.Addr()) + if err != nil { + return err + } + port, err := strconv.ParseInt(portstr, 10, 32) + if err != nil { + return err + } + pe.putInt32(r.Coordinator.ID()) + if err := pe.putString(host); err != nil { + return err + } + pe.putInt32(int32(port)) + return nil + } + pe.putInt32(r.CoordinatorID) + if err := pe.putString(r.CoordinatorHost); err != nil { + return err + } + pe.putInt32(r.CoordinatorPort) + return nil +} + +func (r *ConsumerMetadataResponse) key() int16 { + return 10 +} + +func (r *ConsumerMetadataResponse) version() int16 { + return 0 +} + +func (r *ConsumerMetadataResponse) requiredVersion() KafkaVersion { + return V0_8_2_0 +} diff --git a/vendor/github.com/Shopify/sarama/consumer_metadata_response_test.go b/vendor/github.com/Shopify/sarama/consumer_metadata_response_test.go new file mode 100644 index 00000000..b748784d --- /dev/null +++ b/vendor/github.com/Shopify/sarama/consumer_metadata_response_test.go @@ -0,0 +1,35 @@ +package sarama + +import "testing" + +var ( + consumerMetadataResponseError = []byte{ + 0x00, 0x0E, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00} + + consumerMetadataResponseSuccess = []byte{ + 0x00, 0x00, + 0x00, 0x00, 0x00, 0xAB, + 0x00, 0x03, 'f', 'o', 'o', + 0x00, 0x00, 0xCC, 0xDD} +) + +func TestConsumerMetadataResponseError(t *testing.T) { + response := ConsumerMetadataResponse{Err: ErrOffsetsLoadInProgress} + testResponse(t, "error", &response, consumerMetadataResponseError) +} + +func TestConsumerMetadataResponseSuccess(t *testing.T) { + broker := NewBroker("foo:52445") + broker.id = 0xAB + response := ConsumerMetadataResponse{ + Coordinator: broker, + CoordinatorID: 0xAB, + CoordinatorHost: "foo", + CoordinatorPort: 0xCCDD, + Err: ErrNoError, + } + testResponse(t, "success", &response, consumerMetadataResponseSuccess) +} diff --git a/vendor/github.com/Shopify/sarama/consumer_test.go b/vendor/github.com/Shopify/sarama/consumer_test.go new file mode 100644 index 00000000..387ede31 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/consumer_test.go @@ -0,0 +1,854 @@ +package sarama + +import ( + "log" + "os" + "os/signal" + "sync" + "testing" + "time" +) + +var testMsg = StringEncoder("Foo") + +// If a particular offset is provided then messages are consumed starting from +// that offset. +func TestConsumerOffsetManual(t *testing.T) { + // Given + broker0 := NewMockBroker(t, 0) + + mockFetchResponse := NewMockFetchResponse(t, 1) + for i := 0; i < 10; i++ { + mockFetchResponse.SetMessage("my_topic", 0, int64(i+1234), testMsg) + } + + broker0.SetHandlerByMap(map[string]MockResponse{ + "MetadataRequest": NewMockMetadataResponse(t). + SetBroker(broker0.Addr(), broker0.BrokerID()). + SetLeader("my_topic", 0, broker0.BrokerID()), + "OffsetRequest": NewMockOffsetResponse(t). + SetOffset("my_topic", 0, OffsetOldest, 0). + SetOffset("my_topic", 0, OffsetNewest, 2345), + "FetchRequest": mockFetchResponse, + }) + + // When + master, err := NewConsumer([]string{broker0.Addr()}, nil) + if err != nil { + t.Fatal(err) + } + + consumer, err := master.ConsumePartition("my_topic", 0, 1234) + if err != nil { + t.Fatal(err) + } + + // Then: messages starting from offset 1234 are consumed. + for i := 0; i < 10; i++ { + select { + case message := <-consumer.Messages(): + assertMessageOffset(t, message, int64(i+1234)) + case err := <-consumer.Errors(): + t.Error(err) + } + } + + safeClose(t, consumer) + safeClose(t, master) + broker0.Close() +} + +// If `OffsetNewest` is passed as the initial offset then the first consumed +// message is indeed corresponds to the offset that broker claims to be the +// newest in its metadata response. +func TestConsumerOffsetNewest(t *testing.T) { + // Given + broker0 := NewMockBroker(t, 0) + broker0.SetHandlerByMap(map[string]MockResponse{ + "MetadataRequest": NewMockMetadataResponse(t). + SetBroker(broker0.Addr(), broker0.BrokerID()). + SetLeader("my_topic", 0, broker0.BrokerID()), + "OffsetRequest": NewMockOffsetResponse(t). + SetOffset("my_topic", 0, OffsetNewest, 10). + SetOffset("my_topic", 0, OffsetOldest, 7), + "FetchRequest": NewMockFetchResponse(t, 1). + SetMessage("my_topic", 0, 9, testMsg). + SetMessage("my_topic", 0, 10, testMsg). + SetMessage("my_topic", 0, 11, testMsg). + SetHighWaterMark("my_topic", 0, 14), + }) + + master, err := NewConsumer([]string{broker0.Addr()}, nil) + if err != nil { + t.Fatal(err) + } + + // When + consumer, err := master.ConsumePartition("my_topic", 0, OffsetNewest) + if err != nil { + t.Fatal(err) + } + + // Then + assertMessageOffset(t, <-consumer.Messages(), 10) + if hwmo := consumer.HighWaterMarkOffset(); hwmo != 14 { + t.Errorf("Expected high water mark offset 14, found %d", hwmo) + } + + safeClose(t, consumer) + safeClose(t, master) + broker0.Close() +} + +// It is possible to close a partition consumer and create the same anew. +func TestConsumerRecreate(t *testing.T) { + // Given + broker0 := NewMockBroker(t, 0) + broker0.SetHandlerByMap(map[string]MockResponse{ + "MetadataRequest": NewMockMetadataResponse(t). + SetBroker(broker0.Addr(), broker0.BrokerID()). + SetLeader("my_topic", 0, broker0.BrokerID()), + "OffsetRequest": NewMockOffsetResponse(t). + SetOffset("my_topic", 0, OffsetOldest, 0). + SetOffset("my_topic", 0, OffsetNewest, 1000), + "FetchRequest": NewMockFetchResponse(t, 1). + SetMessage("my_topic", 0, 10, testMsg), + }) + + c, err := NewConsumer([]string{broker0.Addr()}, nil) + if err != nil { + t.Fatal(err) + } + + pc, err := c.ConsumePartition("my_topic", 0, 10) + if err != nil { + t.Fatal(err) + } + assertMessageOffset(t, <-pc.Messages(), 10) + + // When + safeClose(t, pc) + pc, err = c.ConsumePartition("my_topic", 0, 10) + if err != nil { + t.Fatal(err) + } + + // Then + assertMessageOffset(t, <-pc.Messages(), 10) + + safeClose(t, pc) + safeClose(t, c) + broker0.Close() +} + +// An attempt to consume the same partition twice should fail. +func TestConsumerDuplicate(t *testing.T) { + // Given + broker0 := NewMockBroker(t, 0) + broker0.SetHandlerByMap(map[string]MockResponse{ + "MetadataRequest": NewMockMetadataResponse(t). + SetBroker(broker0.Addr(), broker0.BrokerID()). + SetLeader("my_topic", 0, broker0.BrokerID()), + "OffsetRequest": NewMockOffsetResponse(t). + SetOffset("my_topic", 0, OffsetOldest, 0). + SetOffset("my_topic", 0, OffsetNewest, 1000), + "FetchRequest": NewMockFetchResponse(t, 1), + }) + + config := NewConfig() + config.ChannelBufferSize = 0 + c, err := NewConsumer([]string{broker0.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + pc1, err := c.ConsumePartition("my_topic", 0, 0) + if err != nil { + t.Fatal(err) + } + + // When + pc2, err := c.ConsumePartition("my_topic", 0, 0) + + // Then + if pc2 != nil || err != ConfigurationError("That topic/partition is already being consumed") { + t.Fatal("A partition cannot be consumed twice at the same time") + } + + safeClose(t, pc1) + safeClose(t, c) + broker0.Close() +} + +// If consumer fails to refresh metadata it keeps retrying with frequency +// specified by `Config.Consumer.Retry.Backoff`. +func TestConsumerLeaderRefreshError(t *testing.T) { + // Given + broker0 := NewMockBroker(t, 100) + + // Stage 1: my_topic/0 served by broker0 + Logger.Printf(" STAGE 1") + + broker0.SetHandlerByMap(map[string]MockResponse{ + "MetadataRequest": NewMockMetadataResponse(t). + SetBroker(broker0.Addr(), broker0.BrokerID()). + SetLeader("my_topic", 0, broker0.BrokerID()), + "OffsetRequest": NewMockOffsetResponse(t). + SetOffset("my_topic", 0, OffsetOldest, 123). + SetOffset("my_topic", 0, OffsetNewest, 1000), + "FetchRequest": NewMockFetchResponse(t, 1). + SetMessage("my_topic", 0, 123, testMsg), + }) + + config := NewConfig() + config.Net.ReadTimeout = 100 * time.Millisecond + config.Consumer.Retry.Backoff = 200 * time.Millisecond + config.Consumer.Return.Errors = true + config.Metadata.Retry.Max = 0 + c, err := NewConsumer([]string{broker0.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + pc, err := c.ConsumePartition("my_topic", 0, OffsetOldest) + if err != nil { + t.Fatal(err) + } + + assertMessageOffset(t, <-pc.Messages(), 123) + + // Stage 2: broker0 says that it is no longer the leader for my_topic/0, + // but the requests to retrieve metadata fail with network timeout. + Logger.Printf(" STAGE 2") + + fetchResponse2 := &FetchResponse{} + fetchResponse2.AddError("my_topic", 0, ErrNotLeaderForPartition) + + broker0.SetHandlerByMap(map[string]MockResponse{ + "FetchRequest": NewMockWrapper(fetchResponse2), + }) + + if consErr := <-pc.Errors(); consErr.Err != ErrOutOfBrokers { + t.Errorf("Unexpected error: %v", consErr.Err) + } + + // Stage 3: finally the metadata returned by broker0 tells that broker1 is + // a new leader for my_topic/0. Consumption resumes. + + Logger.Printf(" STAGE 3") + + broker1 := NewMockBroker(t, 101) + + broker1.SetHandlerByMap(map[string]MockResponse{ + "FetchRequest": NewMockFetchResponse(t, 1). + SetMessage("my_topic", 0, 124, testMsg), + }) + broker0.SetHandlerByMap(map[string]MockResponse{ + "MetadataRequest": NewMockMetadataResponse(t). + SetBroker(broker0.Addr(), broker0.BrokerID()). + SetBroker(broker1.Addr(), broker1.BrokerID()). + SetLeader("my_topic", 0, broker1.BrokerID()), + }) + + assertMessageOffset(t, <-pc.Messages(), 124) + + safeClose(t, pc) + safeClose(t, c) + broker1.Close() + broker0.Close() +} + +func TestConsumerInvalidTopic(t *testing.T) { + // Given + broker0 := NewMockBroker(t, 100) + broker0.SetHandlerByMap(map[string]MockResponse{ + "MetadataRequest": NewMockMetadataResponse(t). + SetBroker(broker0.Addr(), broker0.BrokerID()), + }) + + c, err := NewConsumer([]string{broker0.Addr()}, nil) + if err != nil { + t.Fatal(err) + } + + // When + pc, err := c.ConsumePartition("my_topic", 0, OffsetOldest) + + // Then + if pc != nil || err != ErrUnknownTopicOrPartition { + t.Errorf("Should fail with, err=%v", err) + } + + safeClose(t, c) + broker0.Close() +} + +// Nothing bad happens if a partition consumer that has no leader assigned at +// the moment is closed. +func TestConsumerClosePartitionWithoutLeader(t *testing.T) { + // Given + broker0 := NewMockBroker(t, 100) + broker0.SetHandlerByMap(map[string]MockResponse{ + "MetadataRequest": NewMockMetadataResponse(t). + SetBroker(broker0.Addr(), broker0.BrokerID()). + SetLeader("my_topic", 0, broker0.BrokerID()), + "OffsetRequest": NewMockOffsetResponse(t). + SetOffset("my_topic", 0, OffsetOldest, 123). + SetOffset("my_topic", 0, OffsetNewest, 1000), + "FetchRequest": NewMockFetchResponse(t, 1). + SetMessage("my_topic", 0, 123, testMsg), + }) + + config := NewConfig() + config.Net.ReadTimeout = 100 * time.Millisecond + config.Consumer.Retry.Backoff = 100 * time.Millisecond + config.Consumer.Return.Errors = true + config.Metadata.Retry.Max = 0 + c, err := NewConsumer([]string{broker0.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + pc, err := c.ConsumePartition("my_topic", 0, OffsetOldest) + if err != nil { + t.Fatal(err) + } + + assertMessageOffset(t, <-pc.Messages(), 123) + + // broker0 says that it is no longer the leader for my_topic/0, but the + // requests to retrieve metadata fail with network timeout. + fetchResponse2 := &FetchResponse{} + fetchResponse2.AddError("my_topic", 0, ErrNotLeaderForPartition) + + broker0.SetHandlerByMap(map[string]MockResponse{ + "FetchRequest": NewMockWrapper(fetchResponse2), + }) + + // When + if consErr := <-pc.Errors(); consErr.Err != ErrOutOfBrokers { + t.Errorf("Unexpected error: %v", consErr.Err) + } + + // Then: the partition consumer can be closed without any problem. + safeClose(t, pc) + safeClose(t, c) + broker0.Close() +} + +// If the initial offset passed on partition consumer creation is out of the +// actual offset range for the partition, then the partition consumer stops +// immediately closing its output channels. +func TestConsumerShutsDownOutOfRange(t *testing.T) { + // Given + broker0 := NewMockBroker(t, 0) + fetchResponse := new(FetchResponse) + fetchResponse.AddError("my_topic", 0, ErrOffsetOutOfRange) + broker0.SetHandlerByMap(map[string]MockResponse{ + "MetadataRequest": NewMockMetadataResponse(t). + SetBroker(broker0.Addr(), broker0.BrokerID()). + SetLeader("my_topic", 0, broker0.BrokerID()), + "OffsetRequest": NewMockOffsetResponse(t). + SetOffset("my_topic", 0, OffsetNewest, 1234). + SetOffset("my_topic", 0, OffsetOldest, 7), + "FetchRequest": NewMockWrapper(fetchResponse), + }) + + master, err := NewConsumer([]string{broker0.Addr()}, nil) + if err != nil { + t.Fatal(err) + } + + // When + consumer, err := master.ConsumePartition("my_topic", 0, 101) + if err != nil { + t.Fatal(err) + } + + // Then: consumer should shut down closing its messages and errors channels. + if _, ok := <-consumer.Messages(); ok { + t.Error("Expected the consumer to shut down") + } + safeClose(t, consumer) + + safeClose(t, master) + broker0.Close() +} + +// If a fetch response contains messages with offsets that are smaller then +// requested, then such messages are ignored. +func TestConsumerExtraOffsets(t *testing.T) { + // Given + broker0 := NewMockBroker(t, 0) + fetchResponse1 := &FetchResponse{} + fetchResponse1.AddMessage("my_topic", 0, nil, testMsg, 1) + fetchResponse1.AddMessage("my_topic", 0, nil, testMsg, 2) + fetchResponse1.AddMessage("my_topic", 0, nil, testMsg, 3) + fetchResponse1.AddMessage("my_topic", 0, nil, testMsg, 4) + fetchResponse2 := &FetchResponse{} + fetchResponse2.AddError("my_topic", 0, ErrNoError) + broker0.SetHandlerByMap(map[string]MockResponse{ + "MetadataRequest": NewMockMetadataResponse(t). + SetBroker(broker0.Addr(), broker0.BrokerID()). + SetLeader("my_topic", 0, broker0.BrokerID()), + "OffsetRequest": NewMockOffsetResponse(t). + SetOffset("my_topic", 0, OffsetNewest, 1234). + SetOffset("my_topic", 0, OffsetOldest, 0), + "FetchRequest": NewMockSequence(fetchResponse1, fetchResponse2), + }) + + master, err := NewConsumer([]string{broker0.Addr()}, nil) + if err != nil { + t.Fatal(err) + } + + // When + consumer, err := master.ConsumePartition("my_topic", 0, 3) + if err != nil { + t.Fatal(err) + } + + // Then: messages with offsets 1 and 2 are not returned even though they + // are present in the response. + assertMessageOffset(t, <-consumer.Messages(), 3) + assertMessageOffset(t, <-consumer.Messages(), 4) + + safeClose(t, consumer) + safeClose(t, master) + broker0.Close() +} + +// It is fine if offsets of fetched messages are not sequential (although +// strictly increasing!). +func TestConsumerNonSequentialOffsets(t *testing.T) { + // Given + broker0 := NewMockBroker(t, 0) + fetchResponse1 := &FetchResponse{} + fetchResponse1.AddMessage("my_topic", 0, nil, testMsg, 5) + fetchResponse1.AddMessage("my_topic", 0, nil, testMsg, 7) + fetchResponse1.AddMessage("my_topic", 0, nil, testMsg, 11) + fetchResponse2 := &FetchResponse{} + fetchResponse2.AddError("my_topic", 0, ErrNoError) + broker0.SetHandlerByMap(map[string]MockResponse{ + "MetadataRequest": NewMockMetadataResponse(t). + SetBroker(broker0.Addr(), broker0.BrokerID()). + SetLeader("my_topic", 0, broker0.BrokerID()), + "OffsetRequest": NewMockOffsetResponse(t). + SetOffset("my_topic", 0, OffsetNewest, 1234). + SetOffset("my_topic", 0, OffsetOldest, 0), + "FetchRequest": NewMockSequence(fetchResponse1, fetchResponse2), + }) + + master, err := NewConsumer([]string{broker0.Addr()}, nil) + if err != nil { + t.Fatal(err) + } + + // When + consumer, err := master.ConsumePartition("my_topic", 0, 3) + if err != nil { + t.Fatal(err) + } + + // Then: messages with offsets 1 and 2 are not returned even though they + // are present in the response. + assertMessageOffset(t, <-consumer.Messages(), 5) + assertMessageOffset(t, <-consumer.Messages(), 7) + assertMessageOffset(t, <-consumer.Messages(), 11) + + safeClose(t, consumer) + safeClose(t, master) + broker0.Close() +} + +// If leadership for a partition is changing then consumer resolves the new +// leader and switches to it. +func TestConsumerRebalancingMultiplePartitions(t *testing.T) { + // initial setup + seedBroker := NewMockBroker(t, 10) + leader0 := NewMockBroker(t, 0) + leader1 := NewMockBroker(t, 1) + + seedBroker.SetHandlerByMap(map[string]MockResponse{ + "MetadataRequest": NewMockMetadataResponse(t). + SetBroker(leader0.Addr(), leader0.BrokerID()). + SetBroker(leader1.Addr(), leader1.BrokerID()). + SetLeader("my_topic", 0, leader0.BrokerID()). + SetLeader("my_topic", 1, leader1.BrokerID()), + }) + + mockOffsetResponse1 := NewMockOffsetResponse(t). + SetOffset("my_topic", 0, OffsetOldest, 0). + SetOffset("my_topic", 0, OffsetNewest, 1000). + SetOffset("my_topic", 1, OffsetOldest, 0). + SetOffset("my_topic", 1, OffsetNewest, 1000) + leader0.SetHandlerByMap(map[string]MockResponse{ + "OffsetRequest": mockOffsetResponse1, + "FetchRequest": NewMockFetchResponse(t, 1), + }) + leader1.SetHandlerByMap(map[string]MockResponse{ + "OffsetRequest": mockOffsetResponse1, + "FetchRequest": NewMockFetchResponse(t, 1), + }) + + // launch test goroutines + config := NewConfig() + config.Consumer.Retry.Backoff = 50 + master, err := NewConsumer([]string{seedBroker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + // we expect to end up (eventually) consuming exactly ten messages on each partition + var wg sync.WaitGroup + for i := int32(0); i < 2; i++ { + consumer, err := master.ConsumePartition("my_topic", i, 0) + if err != nil { + t.Error(err) + } + + go func(c PartitionConsumer) { + for err := range c.Errors() { + t.Error(err) + } + }(consumer) + + wg.Add(1) + go func(partition int32, c PartitionConsumer) { + for i := 0; i < 10; i++ { + message := <-consumer.Messages() + if message.Offset != int64(i) { + t.Error("Incorrect message offset!", i, partition, message.Offset) + } + if message.Partition != partition { + t.Error("Incorrect message partition!") + } + } + safeClose(t, consumer) + wg.Done() + }(i, consumer) + } + + time.Sleep(50 * time.Millisecond) + Logger.Printf(" STAGE 1") + // Stage 1: + // * my_topic/0 -> leader0 serves 4 messages + // * my_topic/1 -> leader1 serves 0 messages + + mockFetchResponse := NewMockFetchResponse(t, 1) + for i := 0; i < 4; i++ { + mockFetchResponse.SetMessage("my_topic", 0, int64(i), testMsg) + } + leader0.SetHandlerByMap(map[string]MockResponse{ + "FetchRequest": mockFetchResponse, + }) + + time.Sleep(50 * time.Millisecond) + Logger.Printf(" STAGE 2") + // Stage 2: + // * leader0 says that it is no longer serving my_topic/0 + // * seedBroker tells that leader1 is serving my_topic/0 now + + // seed broker tells that the new partition 0 leader is leader1 + seedBroker.SetHandlerByMap(map[string]MockResponse{ + "MetadataRequest": NewMockMetadataResponse(t). + SetLeader("my_topic", 0, leader1.BrokerID()). + SetLeader("my_topic", 1, leader1.BrokerID()), + }) + + // leader0 says no longer leader of partition 0 + fetchResponse := new(FetchResponse) + fetchResponse.AddError("my_topic", 0, ErrNotLeaderForPartition) + leader0.SetHandlerByMap(map[string]MockResponse{ + "FetchRequest": NewMockWrapper(fetchResponse), + }) + + time.Sleep(50 * time.Millisecond) + Logger.Printf(" STAGE 3") + // Stage 3: + // * my_topic/0 -> leader1 serves 3 messages + // * my_topic/1 -> leader1 server 8 messages + + // leader1 provides 3 message on partition 0, and 8 messages on partition 1 + mockFetchResponse2 := NewMockFetchResponse(t, 2) + for i := 4; i < 7; i++ { + mockFetchResponse2.SetMessage("my_topic", 0, int64(i), testMsg) + } + for i := 0; i < 8; i++ { + mockFetchResponse2.SetMessage("my_topic", 1, int64(i), testMsg) + } + leader1.SetHandlerByMap(map[string]MockResponse{ + "FetchRequest": mockFetchResponse2, + }) + + time.Sleep(50 * time.Millisecond) + Logger.Printf(" STAGE 4") + // Stage 4: + // * my_topic/0 -> leader1 serves 3 messages + // * my_topic/1 -> leader1 tells that it is no longer the leader + // * seedBroker tells that leader0 is a new leader for my_topic/1 + + // metadata assigns 0 to leader1 and 1 to leader0 + seedBroker.SetHandlerByMap(map[string]MockResponse{ + "MetadataRequest": NewMockMetadataResponse(t). + SetLeader("my_topic", 0, leader1.BrokerID()). + SetLeader("my_topic", 1, leader0.BrokerID()), + }) + + // leader1 provides three more messages on partition0, says no longer leader of partition1 + mockFetchResponse3 := NewMockFetchResponse(t, 3). + SetMessage("my_topic", 0, int64(7), testMsg). + SetMessage("my_topic", 0, int64(8), testMsg). + SetMessage("my_topic", 0, int64(9), testMsg) + fetchResponse4 := new(FetchResponse) + fetchResponse4.AddError("my_topic", 1, ErrNotLeaderForPartition) + leader1.SetHandlerByMap(map[string]MockResponse{ + "FetchRequest": NewMockSequence(mockFetchResponse3, fetchResponse4), + }) + + // leader0 provides two messages on partition 1 + mockFetchResponse4 := NewMockFetchResponse(t, 2) + for i := 8; i < 10; i++ { + mockFetchResponse4.SetMessage("my_topic", 1, int64(i), testMsg) + } + leader0.SetHandlerByMap(map[string]MockResponse{ + "FetchRequest": mockFetchResponse4, + }) + + wg.Wait() + safeClose(t, master) + leader1.Close() + leader0.Close() + seedBroker.Close() +} + +// When two partitions have the same broker as the leader, if one partition +// consumer channel buffer is full then that does not affect the ability to +// read messages by the other consumer. +func TestConsumerInterleavedClose(t *testing.T) { + // Given + broker0 := NewMockBroker(t, 0) + broker0.SetHandlerByMap(map[string]MockResponse{ + "MetadataRequest": NewMockMetadataResponse(t). + SetBroker(broker0.Addr(), broker0.BrokerID()). + SetLeader("my_topic", 0, broker0.BrokerID()). + SetLeader("my_topic", 1, broker0.BrokerID()), + "OffsetRequest": NewMockOffsetResponse(t). + SetOffset("my_topic", 0, OffsetOldest, 1000). + SetOffset("my_topic", 0, OffsetNewest, 1100). + SetOffset("my_topic", 1, OffsetOldest, 2000). + SetOffset("my_topic", 1, OffsetNewest, 2100), + "FetchRequest": NewMockFetchResponse(t, 1). + SetMessage("my_topic", 0, 1000, testMsg). + SetMessage("my_topic", 0, 1001, testMsg). + SetMessage("my_topic", 0, 1002, testMsg). + SetMessage("my_topic", 1, 2000, testMsg), + }) + + config := NewConfig() + config.ChannelBufferSize = 0 + master, err := NewConsumer([]string{broker0.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + c0, err := master.ConsumePartition("my_topic", 0, 1000) + if err != nil { + t.Fatal(err) + } + + c1, err := master.ConsumePartition("my_topic", 1, 2000) + if err != nil { + t.Fatal(err) + } + + // When/Then: we can read from partition 0 even if nobody reads from partition 1 + assertMessageOffset(t, <-c0.Messages(), 1000) + assertMessageOffset(t, <-c0.Messages(), 1001) + assertMessageOffset(t, <-c0.Messages(), 1002) + + safeClose(t, c1) + safeClose(t, c0) + safeClose(t, master) + broker0.Close() +} + +func TestConsumerBounceWithReferenceOpen(t *testing.T) { + broker0 := NewMockBroker(t, 0) + broker0Addr := broker0.Addr() + broker1 := NewMockBroker(t, 1) + + mockMetadataResponse := NewMockMetadataResponse(t). + SetBroker(broker0.Addr(), broker0.BrokerID()). + SetBroker(broker1.Addr(), broker1.BrokerID()). + SetLeader("my_topic", 0, broker0.BrokerID()). + SetLeader("my_topic", 1, broker1.BrokerID()) + + mockOffsetResponse := NewMockOffsetResponse(t). + SetOffset("my_topic", 0, OffsetOldest, 1000). + SetOffset("my_topic", 0, OffsetNewest, 1100). + SetOffset("my_topic", 1, OffsetOldest, 2000). + SetOffset("my_topic", 1, OffsetNewest, 2100) + + mockFetchResponse := NewMockFetchResponse(t, 1) + for i := 0; i < 10; i++ { + mockFetchResponse.SetMessage("my_topic", 0, int64(1000+i), testMsg) + mockFetchResponse.SetMessage("my_topic", 1, int64(2000+i), testMsg) + } + + broker0.SetHandlerByMap(map[string]MockResponse{ + "OffsetRequest": mockOffsetResponse, + "FetchRequest": mockFetchResponse, + }) + broker1.SetHandlerByMap(map[string]MockResponse{ + "MetadataRequest": mockMetadataResponse, + "OffsetRequest": mockOffsetResponse, + "FetchRequest": mockFetchResponse, + }) + + config := NewConfig() + config.Consumer.Return.Errors = true + config.Consumer.Retry.Backoff = 100 * time.Millisecond + config.ChannelBufferSize = 1 + master, err := NewConsumer([]string{broker1.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + c0, err := master.ConsumePartition("my_topic", 0, 1000) + if err != nil { + t.Fatal(err) + } + + c1, err := master.ConsumePartition("my_topic", 1, 2000) + if err != nil { + t.Fatal(err) + } + + // read messages from both partition to make sure that both brokers operate + // normally. + assertMessageOffset(t, <-c0.Messages(), 1000) + assertMessageOffset(t, <-c1.Messages(), 2000) + + // Simulate broker shutdown. Note that metadata response does not change, + // that is the leadership does not move to another broker. So partition + // consumer will keep retrying to restore the connection with the broker. + broker0.Close() + + // Make sure that while the partition/0 leader is down, consumer/partition/1 + // is capable of pulling messages from broker1. + for i := 1; i < 7; i++ { + offset := (<-c1.Messages()).Offset + if offset != int64(2000+i) { + t.Errorf("Expected offset %d from consumer/partition/1", int64(2000+i)) + } + } + + // Bring broker0 back to service. + broker0 = NewMockBrokerAddr(t, 0, broker0Addr) + broker0.SetHandlerByMap(map[string]MockResponse{ + "FetchRequest": mockFetchResponse, + }) + + // Read the rest of messages from both partitions. + for i := 7; i < 10; i++ { + assertMessageOffset(t, <-c1.Messages(), int64(2000+i)) + } + for i := 1; i < 10; i++ { + assertMessageOffset(t, <-c0.Messages(), int64(1000+i)) + } + + select { + case <-c0.Errors(): + default: + t.Errorf("Partition consumer should have detected broker restart") + } + + safeClose(t, c1) + safeClose(t, c0) + safeClose(t, master) + broker0.Close() + broker1.Close() +} + +func TestConsumerOffsetOutOfRange(t *testing.T) { + // Given + broker0 := NewMockBroker(t, 2) + broker0.SetHandlerByMap(map[string]MockResponse{ + "MetadataRequest": NewMockMetadataResponse(t). + SetBroker(broker0.Addr(), broker0.BrokerID()). + SetLeader("my_topic", 0, broker0.BrokerID()), + "OffsetRequest": NewMockOffsetResponse(t). + SetOffset("my_topic", 0, OffsetNewest, 1234). + SetOffset("my_topic", 0, OffsetOldest, 2345), + }) + + master, err := NewConsumer([]string{broker0.Addr()}, nil) + if err != nil { + t.Fatal(err) + } + + // When/Then + if _, err := master.ConsumePartition("my_topic", 0, 0); err != ErrOffsetOutOfRange { + t.Fatal("Should return ErrOffsetOutOfRange, got:", err) + } + if _, err := master.ConsumePartition("my_topic", 0, 3456); err != ErrOffsetOutOfRange { + t.Fatal("Should return ErrOffsetOutOfRange, got:", err) + } + if _, err := master.ConsumePartition("my_topic", 0, -3); err != ErrOffsetOutOfRange { + t.Fatal("Should return ErrOffsetOutOfRange, got:", err) + } + + safeClose(t, master) + broker0.Close() +} + +func assertMessageOffset(t *testing.T, msg *ConsumerMessage, expectedOffset int64) { + if msg.Offset != expectedOffset { + t.Errorf("Incorrect message offset: expected=%d, actual=%d", expectedOffset, msg.Offset) + } +} + +// This example shows how to use the consumer to read messages +// from a single partition. +func ExampleConsumer() { + consumer, err := NewConsumer([]string{"localhost:9092"}, nil) + if err != nil { + panic(err) + } + + defer func() { + if err := consumer.Close(); err != nil { + log.Fatalln(err) + } + }() + + partitionConsumer, err := consumer.ConsumePartition("my_topic", 0, OffsetNewest) + if err != nil { + panic(err) + } + + defer func() { + if err := partitionConsumer.Close(); err != nil { + log.Fatalln(err) + } + }() + + // Trap SIGINT to trigger a shutdown. + signals := make(chan os.Signal, 1) + signal.Notify(signals, os.Interrupt) + + consumed := 0 +ConsumerLoop: + for { + select { + case msg := <-partitionConsumer.Messages(): + log.Printf("Consumed message offset %d\n", msg.Offset) + consumed++ + case <-signals: + break ConsumerLoop + } + } + + log.Printf("Consumed: %d\n", consumed) +} diff --git a/vendor/github.com/Shopify/sarama/crc32_field.go b/vendor/github.com/Shopify/sarama/crc32_field.go new file mode 100644 index 00000000..f4fde18a --- /dev/null +++ b/vendor/github.com/Shopify/sarama/crc32_field.go @@ -0,0 +1,35 @@ +package sarama + +import ( + "encoding/binary" + "hash/crc32" +) + +// crc32Field implements the pushEncoder and pushDecoder interfaces for calculating CRC32s. +type crc32Field struct { + startOffset int +} + +func (c *crc32Field) saveOffset(in int) { + c.startOffset = in +} + +func (c *crc32Field) reserveLength() int { + return 4 +} + +func (c *crc32Field) run(curOffset int, buf []byte) error { + crc := crc32.ChecksumIEEE(buf[c.startOffset+4 : curOffset]) + binary.BigEndian.PutUint32(buf[c.startOffset:], crc) + return nil +} + +func (c *crc32Field) check(curOffset int, buf []byte) error { + crc := crc32.ChecksumIEEE(buf[c.startOffset+4 : curOffset]) + + if crc != binary.BigEndian.Uint32(buf[c.startOffset:]) { + return PacketDecodingError{"CRC didn't match"} + } + + return nil +} diff --git a/vendor/github.com/Shopify/sarama/describe_groups_request.go b/vendor/github.com/Shopify/sarama/describe_groups_request.go new file mode 100644 index 00000000..1fb35677 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/describe_groups_request.go @@ -0,0 +1,30 @@ +package sarama + +type DescribeGroupsRequest struct { + Groups []string +} + +func (r *DescribeGroupsRequest) encode(pe packetEncoder) error { + return pe.putStringArray(r.Groups) +} + +func (r *DescribeGroupsRequest) decode(pd packetDecoder, version int16) (err error) { + r.Groups, err = pd.getStringArray() + return +} + +func (r *DescribeGroupsRequest) key() int16 { + return 15 +} + +func (r *DescribeGroupsRequest) version() int16 { + return 0 +} + +func (r *DescribeGroupsRequest) requiredVersion() KafkaVersion { + return V0_9_0_0 +} + +func (r *DescribeGroupsRequest) AddGroup(group string) { + r.Groups = append(r.Groups, group) +} diff --git a/vendor/github.com/Shopify/sarama/describe_groups_request_test.go b/vendor/github.com/Shopify/sarama/describe_groups_request_test.go new file mode 100644 index 00000000..7d45f3fe --- /dev/null +++ b/vendor/github.com/Shopify/sarama/describe_groups_request_test.go @@ -0,0 +1,34 @@ +package sarama + +import "testing" + +var ( + emptyDescribeGroupsRequest = []byte{0, 0, 0, 0} + + singleDescribeGroupsRequest = []byte{ + 0, 0, 0, 1, // 1 group + 0, 3, 'f', 'o', 'o', // group name: foo + } + + doubleDescribeGroupsRequest = []byte{ + 0, 0, 0, 2, // 2 groups + 0, 3, 'f', 'o', 'o', // group name: foo + 0, 3, 'b', 'a', 'r', // group name: foo + } +) + +func TestDescribeGroupsRequest(t *testing.T) { + var request *DescribeGroupsRequest + + request = new(DescribeGroupsRequest) + testRequest(t, "no groups", request, emptyDescribeGroupsRequest) + + request = new(DescribeGroupsRequest) + request.AddGroup("foo") + testRequest(t, "one group", request, singleDescribeGroupsRequest) + + request = new(DescribeGroupsRequest) + request.AddGroup("foo") + request.AddGroup("bar") + testRequest(t, "two groups", request, doubleDescribeGroupsRequest) +} diff --git a/vendor/github.com/Shopify/sarama/describe_groups_response.go b/vendor/github.com/Shopify/sarama/describe_groups_response.go new file mode 100644 index 00000000..542b3a97 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/describe_groups_response.go @@ -0,0 +1,187 @@ +package sarama + +type DescribeGroupsResponse struct { + Groups []*GroupDescription +} + +func (r *DescribeGroupsResponse) encode(pe packetEncoder) error { + if err := pe.putArrayLength(len(r.Groups)); err != nil { + return err + } + + for _, groupDescription := range r.Groups { + if err := groupDescription.encode(pe); err != nil { + return err + } + } + + return nil +} + +func (r *DescribeGroupsResponse) decode(pd packetDecoder, version int16) (err error) { + n, err := pd.getArrayLength() + if err != nil { + return err + } + + r.Groups = make([]*GroupDescription, n) + for i := 0; i < n; i++ { + r.Groups[i] = new(GroupDescription) + if err := r.Groups[i].decode(pd); err != nil { + return err + } + } + + return nil +} + +func (r *DescribeGroupsResponse) key() int16 { + return 15 +} + +func (r *DescribeGroupsResponse) version() int16 { + return 0 +} + +func (r *DescribeGroupsResponse) requiredVersion() KafkaVersion { + return V0_9_0_0 +} + +type GroupDescription struct { + Err KError + GroupId string + State string + ProtocolType string + Protocol string + Members map[string]*GroupMemberDescription +} + +func (gd *GroupDescription) encode(pe packetEncoder) error { + pe.putInt16(int16(gd.Err)) + + if err := pe.putString(gd.GroupId); err != nil { + return err + } + if err := pe.putString(gd.State); err != nil { + return err + } + if err := pe.putString(gd.ProtocolType); err != nil { + return err + } + if err := pe.putString(gd.Protocol); err != nil { + return err + } + + if err := pe.putArrayLength(len(gd.Members)); err != nil { + return err + } + + for memberId, groupMemberDescription := range gd.Members { + if err := pe.putString(memberId); err != nil { + return err + } + if err := groupMemberDescription.encode(pe); err != nil { + return err + } + } + + return nil +} + +func (gd *GroupDescription) decode(pd packetDecoder) (err error) { + kerr, err := pd.getInt16() + if err != nil { + return err + } + + gd.Err = KError(kerr) + + if gd.GroupId, err = pd.getString(); err != nil { + return + } + if gd.State, err = pd.getString(); err != nil { + return + } + if gd.ProtocolType, err = pd.getString(); err != nil { + return + } + if gd.Protocol, err = pd.getString(); err != nil { + return + } + + n, err := pd.getArrayLength() + if err != nil { + return err + } + if n == 0 { + return nil + } + + gd.Members = make(map[string]*GroupMemberDescription) + for i := 0; i < n; i++ { + memberId, err := pd.getString() + if err != nil { + return err + } + + gd.Members[memberId] = new(GroupMemberDescription) + if err := gd.Members[memberId].decode(pd); err != nil { + return err + } + } + + return nil +} + +type GroupMemberDescription struct { + ClientId string + ClientHost string + MemberMetadata []byte + MemberAssignment []byte +} + +func (gmd *GroupMemberDescription) encode(pe packetEncoder) error { + if err := pe.putString(gmd.ClientId); err != nil { + return err + } + if err := pe.putString(gmd.ClientHost); err != nil { + return err + } + if err := pe.putBytes(gmd.MemberMetadata); err != nil { + return err + } + if err := pe.putBytes(gmd.MemberAssignment); err != nil { + return err + } + + return nil +} + +func (gmd *GroupMemberDescription) decode(pd packetDecoder) (err error) { + if gmd.ClientId, err = pd.getString(); err != nil { + return + } + if gmd.ClientHost, err = pd.getString(); err != nil { + return + } + if gmd.MemberMetadata, err = pd.getBytes(); err != nil { + return + } + if gmd.MemberAssignment, err = pd.getBytes(); err != nil { + return + } + + return nil +} + +func (gmd *GroupMemberDescription) GetMemberAssignment() (*ConsumerGroupMemberAssignment, error) { + assignment := new(ConsumerGroupMemberAssignment) + err := decode(gmd.MemberAssignment, assignment) + return assignment, err +} + +func (gmd *GroupMemberDescription) GetMemberMetadata() (*ConsumerGroupMemberMetadata, error) { + metadata := new(ConsumerGroupMemberMetadata) + err := decode(gmd.MemberMetadata, metadata) + return metadata, err +} diff --git a/vendor/github.com/Shopify/sarama/describe_groups_response_test.go b/vendor/github.com/Shopify/sarama/describe_groups_response_test.go new file mode 100644 index 00000000..dd397319 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/describe_groups_response_test.go @@ -0,0 +1,91 @@ +package sarama + +import ( + "reflect" + "testing" +) + +var ( + describeGroupsResponseEmpty = []byte{ + 0, 0, 0, 0, // no groups + } + + describeGroupsResponsePopulated = []byte{ + 0, 0, 0, 2, // 2 groups + + 0, 0, // no error + 0, 3, 'f', 'o', 'o', // Group ID + 0, 3, 'b', 'a', 'r', // State + 0, 8, 'c', 'o', 'n', 's', 'u', 'm', 'e', 'r', // ConsumerProtocol type + 0, 3, 'b', 'a', 'z', // Protocol name + 0, 0, 0, 1, // 1 member + 0, 2, 'i', 'd', // Member ID + 0, 6, 's', 'a', 'r', 'a', 'm', 'a', // Client ID + 0, 9, 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't', // Client Host + 0, 0, 0, 3, 0x01, 0x02, 0x03, // MemberMetadata + 0, 0, 0, 3, 0x04, 0x05, 0x06, // MemberAssignment + + 0, 30, // ErrGroupAuthorizationFailed + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, 0, 0, + } +) + +func TestDescribeGroupsResponse(t *testing.T) { + var response *DescribeGroupsResponse + + response = new(DescribeGroupsResponse) + testVersionDecodable(t, "empty", response, describeGroupsResponseEmpty, 0) + if len(response.Groups) != 0 { + t.Error("Expected no groups") + } + + response = new(DescribeGroupsResponse) + testVersionDecodable(t, "populated", response, describeGroupsResponsePopulated, 0) + if len(response.Groups) != 2 { + t.Error("Expected two groups") + } + + group0 := response.Groups[0] + if group0.Err != ErrNoError { + t.Error("Unxpected groups[0].Err, found", group0.Err) + } + if group0.GroupId != "foo" { + t.Error("Unxpected groups[0].GroupId, found", group0.GroupId) + } + if group0.State != "bar" { + t.Error("Unxpected groups[0].State, found", group0.State) + } + if group0.ProtocolType != "consumer" { + t.Error("Unxpected groups[0].ProtocolType, found", group0.ProtocolType) + } + if group0.Protocol != "baz" { + t.Error("Unxpected groups[0].Protocol, found", group0.Protocol) + } + if len(group0.Members) != 1 { + t.Error("Unxpected groups[0].Members, found", group0.Members) + } + if group0.Members["id"].ClientId != "sarama" { + t.Error("Unxpected groups[0].Members[id].ClientId, found", group0.Members["id"].ClientId) + } + if group0.Members["id"].ClientHost != "localhost" { + t.Error("Unxpected groups[0].Members[id].ClientHost, found", group0.Members["id"].ClientHost) + } + if !reflect.DeepEqual(group0.Members["id"].MemberMetadata, []byte{0x01, 0x02, 0x03}) { + t.Error("Unxpected groups[0].Members[id].MemberMetadata, found", group0.Members["id"].MemberMetadata) + } + if !reflect.DeepEqual(group0.Members["id"].MemberAssignment, []byte{0x04, 0x05, 0x06}) { + t.Error("Unxpected groups[0].Members[id].MemberAssignment, found", group0.Members["id"].MemberAssignment) + } + + group1 := response.Groups[1] + if group1.Err != ErrGroupAuthorizationFailed { + t.Error("Unxpected groups[1].Err, found", group0.Err) + } + if len(group1.Members) != 0 { + t.Error("Unxpected groups[1].Members, found", group0.Members) + } +} diff --git a/vendor/github.com/Shopify/sarama/dev.yml b/vendor/github.com/Shopify/sarama/dev.yml new file mode 100644 index 00000000..adcf9421 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/dev.yml @@ -0,0 +1,14 @@ +name: sarama + +up: + - go: + version: '1.8' + +commands: + test: + run: make test + desc: 'run unit tests' + +packages: + - git@github.com:Shopify/dev-shopify.git + diff --git a/vendor/github.com/Shopify/sarama/encoder_decoder.go b/vendor/github.com/Shopify/sarama/encoder_decoder.go new file mode 100644 index 00000000..7ce3bc0f --- /dev/null +++ b/vendor/github.com/Shopify/sarama/encoder_decoder.go @@ -0,0 +1,89 @@ +package sarama + +import ( + "fmt" + + "github.com/rcrowley/go-metrics" +) + +// Encoder is the interface that wraps the basic Encode method. +// Anything implementing Encoder can be turned into bytes using Kafka's encoding rules. +type encoder interface { + encode(pe packetEncoder) error +} + +// Encode takes an Encoder and turns it into bytes while potentially recording metrics. +func encode(e encoder, metricRegistry metrics.Registry) ([]byte, error) { + if e == nil { + return nil, nil + } + + var prepEnc prepEncoder + var realEnc realEncoder + + err := e.encode(&prepEnc) + if err != nil { + return nil, err + } + + if prepEnc.length < 0 || prepEnc.length > int(MaxRequestSize) { + return nil, PacketEncodingError{fmt.Sprintf("invalid request size (%d)", prepEnc.length)} + } + + realEnc.raw = make([]byte, prepEnc.length) + realEnc.registry = metricRegistry + err = e.encode(&realEnc) + if err != nil { + return nil, err + } + + return realEnc.raw, nil +} + +// Decoder is the interface that wraps the basic Decode method. +// Anything implementing Decoder can be extracted from bytes using Kafka's encoding rules. +type decoder interface { + decode(pd packetDecoder) error +} + +type versionedDecoder interface { + decode(pd packetDecoder, version int16) error +} + +// Decode takes bytes and a Decoder and fills the fields of the decoder from the bytes, +// interpreted using Kafka's encoding rules. +func decode(buf []byte, in decoder) error { + if buf == nil { + return nil + } + + helper := realDecoder{raw: buf} + err := in.decode(&helper) + if err != nil { + return err + } + + if helper.off != len(buf) { + return PacketDecodingError{"invalid length"} + } + + return nil +} + +func versionedDecode(buf []byte, in versionedDecoder, version int16) error { + if buf == nil { + return nil + } + + helper := realDecoder{raw: buf} + err := in.decode(&helper, version) + if err != nil { + return err + } + + if helper.off != len(buf) { + return PacketDecodingError{"invalid length"} + } + + return nil +} diff --git a/vendor/github.com/Shopify/sarama/errors.go b/vendor/github.com/Shopify/sarama/errors.go new file mode 100644 index 00000000..e6800ed4 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/errors.go @@ -0,0 +1,221 @@ +package sarama + +import ( + "errors" + "fmt" +) + +// ErrOutOfBrokers is the error returned when the client has run out of brokers to talk to because all of them errored +// or otherwise failed to respond. +var ErrOutOfBrokers = errors.New("kafka: client has run out of available brokers to talk to (Is your cluster reachable?)") + +// ErrClosedClient is the error returned when a method is called on a client that has been closed. +var ErrClosedClient = errors.New("kafka: tried to use a client that was closed") + +// ErrIncompleteResponse is the error returned when the server returns a syntactically valid response, but it does +// not contain the expected information. +var ErrIncompleteResponse = errors.New("kafka: response did not contain all the expected topic/partition blocks") + +// ErrInvalidPartition is the error returned when a partitioner returns an invalid partition index +// (meaning one outside of the range [0...numPartitions-1]). +var ErrInvalidPartition = errors.New("kafka: partitioner returned an invalid partition index") + +// ErrAlreadyConnected is the error returned when calling Open() on a Broker that is already connected or connecting. +var ErrAlreadyConnected = errors.New("kafka: broker connection already initiated") + +// ErrNotConnected is the error returned when trying to send or call Close() on a Broker that is not connected. +var ErrNotConnected = errors.New("kafka: broker not connected") + +// ErrInsufficientData is returned when decoding and the packet is truncated. This can be expected +// when requesting messages, since as an optimization the server is allowed to return a partial message at the end +// of the message set. +var ErrInsufficientData = errors.New("kafka: insufficient data to decode packet, more bytes expected") + +// ErrShuttingDown is returned when a producer receives a message during shutdown. +var ErrShuttingDown = errors.New("kafka: message received by producer in process of shutting down") + +// ErrMessageTooLarge is returned when the next message to consume is larger than the configured Consumer.Fetch.Max +var ErrMessageTooLarge = errors.New("kafka: message is larger than Consumer.Fetch.Max") + +// PacketEncodingError is returned from a failure while encoding a Kafka packet. This can happen, for example, +// if you try to encode a string over 2^15 characters in length, since Kafka's encoding rules do not permit that. +type PacketEncodingError struct { + Info string +} + +func (err PacketEncodingError) Error() string { + return fmt.Sprintf("kafka: error encoding packet: %s", err.Info) +} + +// PacketDecodingError is returned when there was an error (other than truncated data) decoding the Kafka broker's response. +// This can be a bad CRC or length field, or any other invalid value. +type PacketDecodingError struct { + Info string +} + +func (err PacketDecodingError) Error() string { + return fmt.Sprintf("kafka: error decoding packet: %s", err.Info) +} + +// ConfigurationError is the type of error returned from a constructor (e.g. NewClient, or NewConsumer) +// when the specified configuration is invalid. +type ConfigurationError string + +func (err ConfigurationError) Error() string { + return "kafka: invalid configuration (" + string(err) + ")" +} + +// KError is the type of error that can be returned directly by the Kafka broker. +// See https://cwiki.apache.org/confluence/display/KAFKA/A+Guide+To+The+Kafka+Protocol#AGuideToTheKafkaProtocol-ErrorCodes +type KError int16 + +// Numeric error codes returned by the Kafka server. +const ( + ErrNoError KError = 0 + ErrUnknown KError = -1 + ErrOffsetOutOfRange KError = 1 + ErrInvalidMessage KError = 2 + ErrUnknownTopicOrPartition KError = 3 + ErrInvalidMessageSize KError = 4 + ErrLeaderNotAvailable KError = 5 + ErrNotLeaderForPartition KError = 6 + ErrRequestTimedOut KError = 7 + ErrBrokerNotAvailable KError = 8 + ErrReplicaNotAvailable KError = 9 + ErrMessageSizeTooLarge KError = 10 + ErrStaleControllerEpochCode KError = 11 + ErrOffsetMetadataTooLarge KError = 12 + ErrNetworkException KError = 13 + ErrOffsetsLoadInProgress KError = 14 + ErrConsumerCoordinatorNotAvailable KError = 15 + ErrNotCoordinatorForConsumer KError = 16 + ErrInvalidTopic KError = 17 + ErrMessageSetSizeTooLarge KError = 18 + ErrNotEnoughReplicas KError = 19 + ErrNotEnoughReplicasAfterAppend KError = 20 + ErrInvalidRequiredAcks KError = 21 + ErrIllegalGeneration KError = 22 + ErrInconsistentGroupProtocol KError = 23 + ErrInvalidGroupId KError = 24 + ErrUnknownMemberId KError = 25 + ErrInvalidSessionTimeout KError = 26 + ErrRebalanceInProgress KError = 27 + ErrInvalidCommitOffsetSize KError = 28 + ErrTopicAuthorizationFailed KError = 29 + ErrGroupAuthorizationFailed KError = 30 + ErrClusterAuthorizationFailed KError = 31 + ErrInvalidTimestamp KError = 32 + ErrUnsupportedSASLMechanism KError = 33 + ErrIllegalSASLState KError = 34 + ErrUnsupportedVersion KError = 35 + ErrTopicAlreadyExists KError = 36 + ErrInvalidPartitions KError = 37 + ErrInvalidReplicationFactor KError = 38 + ErrInvalidReplicaAssignment KError = 39 + ErrInvalidConfig KError = 40 + ErrNotController KError = 41 + ErrInvalidRequest KError = 42 + ErrUnsupportedForMessageFormat KError = 43 + ErrPolicyViolation KError = 44 +) + +func (err KError) Error() string { + // Error messages stolen/adapted from + // https://kafka.apache.org/protocol#protocol_error_codes + switch err { + case ErrNoError: + return "kafka server: Not an error, why are you printing me?" + case ErrUnknown: + return "kafka server: Unexpected (unknown?) server error." + case ErrOffsetOutOfRange: + return "kafka server: The requested offset is outside the range of offsets maintained by the server for the given topic/partition." + case ErrInvalidMessage: + return "kafka server: Message contents does not match its CRC." + case ErrUnknownTopicOrPartition: + return "kafka server: Request was for a topic or partition that does not exist on this broker." + case ErrInvalidMessageSize: + return "kafka server: The message has a negative size." + case ErrLeaderNotAvailable: + return "kafka server: In the middle of a leadership election, there is currently no leader for this partition and hence it is unavailable for writes." + case ErrNotLeaderForPartition: + return "kafka server: Tried to send a message to a replica that is not the leader for some partition. Your metadata is out of date." + case ErrRequestTimedOut: + return "kafka server: Request exceeded the user-specified time limit in the request." + case ErrBrokerNotAvailable: + return "kafka server: Broker not available. Not a client facing error, we should never receive this!!!" + case ErrReplicaNotAvailable: + return "kafka server: Replica information not available, one or more brokers are down." + case ErrMessageSizeTooLarge: + return "kafka server: Message was too large, server rejected it to avoid allocation error." + case ErrStaleControllerEpochCode: + return "kafka server: StaleControllerEpochCode (internal error code for broker-to-broker communication)." + case ErrOffsetMetadataTooLarge: + return "kafka server: Specified a string larger than the configured maximum for offset metadata." + case ErrNetworkException: + return "kafka server: The server disconnected before a response was received." + case ErrOffsetsLoadInProgress: + return "kafka server: The broker is still loading offsets after a leader change for that offset's topic partition." + case ErrConsumerCoordinatorNotAvailable: + return "kafka server: Offset's topic has not yet been created." + case ErrNotCoordinatorForConsumer: + return "kafka server: Request was for a consumer group that is not coordinated by this broker." + case ErrInvalidTopic: + return "kafka server: The request attempted to perform an operation on an invalid topic." + case ErrMessageSetSizeTooLarge: + return "kafka server: The request included message batch larger than the configured segment size on the server." + case ErrNotEnoughReplicas: + return "kafka server: Messages are rejected since there are fewer in-sync replicas than required." + case ErrNotEnoughReplicasAfterAppend: + return "kafka server: Messages are written to the log, but to fewer in-sync replicas than required." + case ErrInvalidRequiredAcks: + return "kafka server: The number of required acks is invalid (should be either -1, 0, or 1)." + case ErrIllegalGeneration: + return "kafka server: The provided generation id is not the current generation." + case ErrInconsistentGroupProtocol: + return "kafka server: The provider group protocol type is incompatible with the other members." + case ErrInvalidGroupId: + return "kafka server: The provided group id was empty." + case ErrUnknownMemberId: + return "kafka server: The provided member is not known in the current generation." + case ErrInvalidSessionTimeout: + return "kafka server: The provided session timeout is outside the allowed range." + case ErrRebalanceInProgress: + return "kafka server: A rebalance for the group is in progress. Please re-join the group." + case ErrInvalidCommitOffsetSize: + return "kafka server: The provided commit metadata was too large." + case ErrTopicAuthorizationFailed: + return "kafka server: The client is not authorized to access this topic." + case ErrGroupAuthorizationFailed: + return "kafka server: The client is not authorized to access this group." + case ErrClusterAuthorizationFailed: + return "kafka server: The client is not authorized to send this request type." + case ErrInvalidTimestamp: + return "kafka server: The timestamp of the message is out of acceptable range." + case ErrUnsupportedSASLMechanism: + return "kafka server: The broker does not support the requested SASL mechanism." + case ErrIllegalSASLState: + return "kafka server: Request is not valid given the current SASL state." + case ErrUnsupportedVersion: + return "kafka server: The version of API is not supported." + case ErrTopicAlreadyExists: + return "kafka server: Topic with this name already exists." + case ErrInvalidPartitions: + return "kafka server: Number of partitions is invalid." + case ErrInvalidReplicationFactor: + return "kafka server: Replication-factor is invalid." + case ErrInvalidReplicaAssignment: + return "kafka server: Replica assignment is invalid." + case ErrInvalidConfig: + return "kafka server: Configuration is invalid." + case ErrNotController: + return "kafka server: This is not the correct controller for this cluster." + case ErrInvalidRequest: + return "kafka server: This most likely occurs because of a request being malformed by the client library or the message was sent to an incompatible broker. See the broker logs for more details." + case ErrUnsupportedForMessageFormat: + return "kafka server: The requested operation is not supported by the message format version." + case ErrPolicyViolation: + return "kafka server: Request parameters do not satisfy the configured policy." + } + + return fmt.Sprintf("Unknown error, how did this happen? Error code = %d", err) +} diff --git a/vendor/github.com/Shopify/sarama/examples/README.md b/vendor/github.com/Shopify/sarama/examples/README.md new file mode 100644 index 00000000..85fecefd --- /dev/null +++ b/vendor/github.com/Shopify/sarama/examples/README.md @@ -0,0 +1,9 @@ +# Sarama examples + +This folder contains example applications to demonstrate the use of Sarama. For code snippet examples on how to use the different types in Sarama, see [Sarama's API documentation on godoc.org](https://godoc.org/github.com/Shopify/sarama) + +In these examples, we use `github.com/Shopify/sarama` as import path. We do this to ensure all the examples are up to date with the latest changes in Sarama. For your own applications, you may want to use `gopkg.in/Shopify/sarama.v1` to lock into a stable API version. + +#### HTTP server + +[http_server](./http_server) is a simple HTTP server uses both the sync producer to produce data as part of the request handling cycle, as well as the async producer to maintain an access log. It also uses the [mocks subpackage](https://godoc.org/github.com/Shopify/sarama/mocks) to test both. diff --git a/vendor/github.com/Shopify/sarama/examples/http_server/.gitignore b/vendor/github.com/Shopify/sarama/examples/http_server/.gitignore new file mode 100644 index 00000000..9f6ed425 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/examples/http_server/.gitignore @@ -0,0 +1,2 @@ +http_server +http_server.test diff --git a/vendor/github.com/Shopify/sarama/examples/http_server/README.md b/vendor/github.com/Shopify/sarama/examples/http_server/README.md new file mode 100644 index 00000000..5ff2bc25 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/examples/http_server/README.md @@ -0,0 +1,7 @@ +# HTTP server example + +This HTTP server example shows you how to use the AsyncProducer and SyncProducer, and how to test them using mocks. The server simply sends the data of the HTTP request's query string to Kafka, and send a 200 result if that succeeds. For every request, it will send an access log entry to Kafka as well in the background. + +If you need to know whether a message was successfully sent to the Kafka cluster before you can send your HTTP response, using the `SyncProducer` is probably the simplest way to achieve this. If you don't care, e.g. for the access log, using the `AsyncProducer` will let you fire and forget. You can send the HTTP response, while the message is being produced in the background. + +One important thing to note is that both the `SyncProducer` and `AsyncProducer` are **thread-safe**. Go's `http.Server` handles requests concurrently in different goroutines, but you can use a single producer safely. This will actually achieve efficiency gains as the producer will be able to batch messages from concurrent requests together. diff --git a/vendor/github.com/Shopify/sarama/examples/http_server/http_server.go b/vendor/github.com/Shopify/sarama/examples/http_server/http_server.go new file mode 100644 index 00000000..b6d83c5d --- /dev/null +++ b/vendor/github.com/Shopify/sarama/examples/http_server/http_server.go @@ -0,0 +1,247 @@ +package main + +import ( + "github.com/Shopify/sarama" + + "crypto/tls" + "crypto/x509" + "encoding/json" + "flag" + "fmt" + "io/ioutil" + "log" + "net/http" + "os" + "strings" + "time" +) + +var ( + addr = flag.String("addr", ":8080", "The address to bind to") + brokers = flag.String("brokers", os.Getenv("KAFKA_PEERS"), "The Kafka brokers to connect to, as a comma separated list") + verbose = flag.Bool("verbose", false, "Turn on Sarama logging") + certFile = flag.String("certificate", "", "The optional certificate file for client authentication") + keyFile = flag.String("key", "", "The optional key file for client authentication") + caFile = flag.String("ca", "", "The optional certificate authority file for TLS client authentication") + verifySsl = flag.Bool("verify", false, "Optional verify ssl certificates chain") +) + +func main() { + flag.Parse() + + if *verbose { + sarama.Logger = log.New(os.Stdout, "[sarama] ", log.LstdFlags) + } + + if *brokers == "" { + flag.PrintDefaults() + os.Exit(1) + } + + brokerList := strings.Split(*brokers, ",") + log.Printf("Kafka brokers: %s", strings.Join(brokerList, ", ")) + + server := &Server{ + DataCollector: newDataCollector(brokerList), + AccessLogProducer: newAccessLogProducer(brokerList), + } + defer func() { + if err := server.Close(); err != nil { + log.Println("Failed to close server", err) + } + }() + + log.Fatal(server.Run(*addr)) +} + +func createTlsConfiguration() (t *tls.Config) { + if *certFile != "" && *keyFile != "" && *caFile != "" { + cert, err := tls.LoadX509KeyPair(*certFile, *keyFile) + if err != nil { + log.Fatal(err) + } + + caCert, err := ioutil.ReadFile(*caFile) + if err != nil { + log.Fatal(err) + } + + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + + t = &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: caCertPool, + InsecureSkipVerify: *verifySsl, + } + } + // will be nil by default if nothing is provided + return t +} + +type Server struct { + DataCollector sarama.SyncProducer + AccessLogProducer sarama.AsyncProducer +} + +func (s *Server) Close() error { + if err := s.DataCollector.Close(); err != nil { + log.Println("Failed to shut down data collector cleanly", err) + } + + if err := s.AccessLogProducer.Close(); err != nil { + log.Println("Failed to shut down access log producer cleanly", err) + } + + return nil +} + +func (s *Server) Handler() http.Handler { + return s.withAccessLog(s.collectQueryStringData()) +} + +func (s *Server) Run(addr string) error { + httpServer := &http.Server{ + Addr: addr, + Handler: s.Handler(), + } + + log.Printf("Listening for requests on %s...\n", addr) + return httpServer.ListenAndServe() +} + +func (s *Server) collectQueryStringData() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/" { + http.NotFound(w, r) + return + } + + // We are not setting a message key, which means that all messages will + // be distributed randomly over the different partitions. + partition, offset, err := s.DataCollector.SendMessage(&sarama.ProducerMessage{ + Topic: "important", + Value: sarama.StringEncoder(r.URL.RawQuery), + }) + + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintf(w, "Failed to store your data:, %s", err) + } else { + // The tuple (topic, partition, offset) can be used as a unique identifier + // for a message in a Kafka cluster. + fmt.Fprintf(w, "Your data is stored with unique identifier important/%d/%d", partition, offset) + } + }) +} + +type accessLogEntry struct { + Method string `json:"method"` + Host string `json:"host"` + Path string `json:"path"` + IP string `json:"ip"` + ResponseTime float64 `json:"response_time"` + + encoded []byte + err error +} + +func (ale *accessLogEntry) ensureEncoded() { + if ale.encoded == nil && ale.err == nil { + ale.encoded, ale.err = json.Marshal(ale) + } +} + +func (ale *accessLogEntry) Length() int { + ale.ensureEncoded() + return len(ale.encoded) +} + +func (ale *accessLogEntry) Encode() ([]byte, error) { + ale.ensureEncoded() + return ale.encoded, ale.err +} + +func (s *Server) withAccessLog(next http.Handler) http.Handler { + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + started := time.Now() + + next.ServeHTTP(w, r) + + entry := &accessLogEntry{ + Method: r.Method, + Host: r.Host, + Path: r.RequestURI, + IP: r.RemoteAddr, + ResponseTime: float64(time.Since(started)) / float64(time.Second), + } + + // We will use the client's IP address as key. This will cause + // all the access log entries of the same IP address to end up + // on the same partition. + s.AccessLogProducer.Input() <- &sarama.ProducerMessage{ + Topic: "access_log", + Key: sarama.StringEncoder(r.RemoteAddr), + Value: entry, + } + }) +} + +func newDataCollector(brokerList []string) sarama.SyncProducer { + + // For the data collector, we are looking for strong consistency semantics. + // Because we don't change the flush settings, sarama will try to produce messages + // as fast as possible to keep latency low. + config := sarama.NewConfig() + config.Producer.RequiredAcks = sarama.WaitForAll // Wait for all in-sync replicas to ack the message + config.Producer.Retry.Max = 10 // Retry up to 10 times to produce the message + config.Producer.Return.Successes = true + tlsConfig := createTlsConfiguration() + if tlsConfig != nil { + config.Net.TLS.Config = tlsConfig + config.Net.TLS.Enable = true + } + + // On the broker side, you may want to change the following settings to get + // stronger consistency guarantees: + // - For your broker, set `unclean.leader.election.enable` to false + // - For the topic, you could increase `min.insync.replicas`. + + producer, err := sarama.NewSyncProducer(brokerList, config) + if err != nil { + log.Fatalln("Failed to start Sarama producer:", err) + } + + return producer +} + +func newAccessLogProducer(brokerList []string) sarama.AsyncProducer { + + // For the access log, we are looking for AP semantics, with high throughput. + // By creating batches of compressed messages, we reduce network I/O at a cost of more latency. + config := sarama.NewConfig() + tlsConfig := createTlsConfiguration() + if tlsConfig != nil { + config.Net.TLS.Enable = true + config.Net.TLS.Config = tlsConfig + } + config.Producer.RequiredAcks = sarama.WaitForLocal // Only wait for the leader to ack + config.Producer.Compression = sarama.CompressionSnappy // Compress messages + config.Producer.Flush.Frequency = 500 * time.Millisecond // Flush batches every 500ms + + producer, err := sarama.NewAsyncProducer(brokerList, config) + if err != nil { + log.Fatalln("Failed to start Sarama producer:", err) + } + + // We will just log to STDOUT if we're not able to produce messages. + // Note: messages will only be returned here after all retry attempts are exhausted. + go func() { + for err := range producer.Errors() { + log.Println("Failed to write access log entry:", err) + } + }() + + return producer +} diff --git a/vendor/github.com/Shopify/sarama/examples/http_server/http_server_test.go b/vendor/github.com/Shopify/sarama/examples/http_server/http_server_test.go new file mode 100644 index 00000000..7b2451e2 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/examples/http_server/http_server_test.go @@ -0,0 +1,109 @@ +package main + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Shopify/sarama" + "github.com/Shopify/sarama/mocks" +) + +// In normal operation, we expect one access log entry, +// and one data collector entry. Let's assume both will succeed. +// We should return a HTTP 200 status. +func TestCollectSuccessfully(t *testing.T) { + dataCollectorMock := mocks.NewSyncProducer(t, nil) + dataCollectorMock.ExpectSendMessageAndSucceed() + + accessLogProducerMock := mocks.NewAsyncProducer(t, nil) + accessLogProducerMock.ExpectInputAndSucceed() + + // Now, use dependency injection to use the mocks. + s := &Server{ + DataCollector: dataCollectorMock, + AccessLogProducer: accessLogProducerMock, + } + + // The Server's Close call is important; it will call Close on + // the two mock producers, which will then validate whether all + // expectations are resolved. + defer safeClose(t, s) + + req, err := http.NewRequest("GET", "http://example.com/?data", nil) + if err != nil { + t.Fatal(err) + } + res := httptest.NewRecorder() + s.Handler().ServeHTTP(res, req) + + if res.Code != 200 { + t.Errorf("Expected HTTP status 200, found %d", res.Code) + } + + if string(res.Body.Bytes()) != "Your data is stored with unique identifier important/0/1" { + t.Error("Unexpected response body", res.Body) + } +} + +// Now, let's see if we handle the case of not being able to produce +// to the data collector properly. In this case we should return a 500 status. +func TestCollectionFailure(t *testing.T) { + dataCollectorMock := mocks.NewSyncProducer(t, nil) + dataCollectorMock.ExpectSendMessageAndFail(sarama.ErrRequestTimedOut) + + accessLogProducerMock := mocks.NewAsyncProducer(t, nil) + accessLogProducerMock.ExpectInputAndSucceed() + + s := &Server{ + DataCollector: dataCollectorMock, + AccessLogProducer: accessLogProducerMock, + } + defer safeClose(t, s) + + req, err := http.NewRequest("GET", "http://example.com/?data", nil) + if err != nil { + t.Fatal(err) + } + res := httptest.NewRecorder() + s.Handler().ServeHTTP(res, req) + + if res.Code != 500 { + t.Errorf("Expected HTTP status 500, found %d", res.Code) + } +} + +// We don't expect any data collector calls because the path is wrong, +// so we are not setting any expectations on the dataCollectorMock. It +// will still generate an access log entry though. +func TestWrongPath(t *testing.T) { + dataCollectorMock := mocks.NewSyncProducer(t, nil) + + accessLogProducerMock := mocks.NewAsyncProducer(t, nil) + accessLogProducerMock.ExpectInputAndSucceed() + + s := &Server{ + DataCollector: dataCollectorMock, + AccessLogProducer: accessLogProducerMock, + } + defer safeClose(t, s) + + req, err := http.NewRequest("GET", "http://example.com/wrong?data", nil) + if err != nil { + t.Fatal(err) + } + res := httptest.NewRecorder() + + s.Handler().ServeHTTP(res, req) + + if res.Code != 404 { + t.Errorf("Expected HTTP status 404, found %d", res.Code) + } +} + +func safeClose(t *testing.T, o io.Closer) { + if err := o.Close(); err != nil { + t.Error(err) + } +} diff --git a/vendor/github.com/Shopify/sarama/fetch_request.go b/vendor/github.com/Shopify/sarama/fetch_request.go new file mode 100644 index 00000000..ab817a06 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/fetch_request.go @@ -0,0 +1,136 @@ +package sarama + +type fetchRequestBlock struct { + fetchOffset int64 + maxBytes int32 +} + +func (b *fetchRequestBlock) encode(pe packetEncoder) error { + pe.putInt64(b.fetchOffset) + pe.putInt32(b.maxBytes) + return nil +} + +func (b *fetchRequestBlock) decode(pd packetDecoder) (err error) { + if b.fetchOffset, err = pd.getInt64(); err != nil { + return err + } + if b.maxBytes, err = pd.getInt32(); err != nil { + return err + } + return nil +} + +type FetchRequest struct { + MaxWaitTime int32 + MinBytes int32 + Version int16 + blocks map[string]map[int32]*fetchRequestBlock +} + +func (r *FetchRequest) encode(pe packetEncoder) (err error) { + pe.putInt32(-1) // replica ID is always -1 for clients + pe.putInt32(r.MaxWaitTime) + pe.putInt32(r.MinBytes) + err = pe.putArrayLength(len(r.blocks)) + if err != nil { + return err + } + for topic, blocks := range r.blocks { + err = pe.putString(topic) + if err != nil { + return err + } + err = pe.putArrayLength(len(blocks)) + if err != nil { + return err + } + for partition, block := range blocks { + pe.putInt32(partition) + err = block.encode(pe) + if err != nil { + return err + } + } + } + return nil +} + +func (r *FetchRequest) decode(pd packetDecoder, version int16) (err error) { + r.Version = version + if _, err = pd.getInt32(); err != nil { + return err + } + if r.MaxWaitTime, err = pd.getInt32(); err != nil { + return err + } + if r.MinBytes, err = pd.getInt32(); err != nil { + return err + } + topicCount, err := pd.getArrayLength() + if err != nil { + return err + } + if topicCount == 0 { + return nil + } + r.blocks = make(map[string]map[int32]*fetchRequestBlock) + for i := 0; i < topicCount; i++ { + topic, err := pd.getString() + if err != nil { + return err + } + partitionCount, err := pd.getArrayLength() + if err != nil { + return err + } + r.blocks[topic] = make(map[int32]*fetchRequestBlock) + for j := 0; j < partitionCount; j++ { + partition, err := pd.getInt32() + if err != nil { + return err + } + fetchBlock := &fetchRequestBlock{} + if err = fetchBlock.decode(pd); err != nil { + return err + } + r.blocks[topic][partition] = fetchBlock + } + } + return nil +} + +func (r *FetchRequest) key() int16 { + return 1 +} + +func (r *FetchRequest) version() int16 { + return r.Version +} + +func (r *FetchRequest) requiredVersion() KafkaVersion { + switch r.Version { + case 1: + return V0_9_0_0 + case 2: + return V0_10_0_0 + default: + return minVersion + } +} + +func (r *FetchRequest) AddBlock(topic string, partitionID int32, fetchOffset int64, maxBytes int32) { + if r.blocks == nil { + r.blocks = make(map[string]map[int32]*fetchRequestBlock) + } + + if r.blocks[topic] == nil { + r.blocks[topic] = make(map[int32]*fetchRequestBlock) + } + + tmp := new(fetchRequestBlock) + tmp.maxBytes = maxBytes + tmp.fetchOffset = fetchOffset + + r.blocks[topic][partitionID] = tmp +} diff --git a/vendor/github.com/Shopify/sarama/fetch_request_test.go b/vendor/github.com/Shopify/sarama/fetch_request_test.go new file mode 100644 index 00000000..32c083c7 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/fetch_request_test.go @@ -0,0 +1,34 @@ +package sarama + +import "testing" + +var ( + fetchRequestNoBlocks = []byte{ + 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00} + + fetchRequestWithProperties = []byte{ + 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0xEF, + 0x00, 0x00, 0x00, 0x00} + + fetchRequestOneBlock = []byte{ + 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x05, 't', 'o', 'p', 'i', 'c', + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x56} +) + +func TestFetchRequest(t *testing.T) { + request := new(FetchRequest) + testRequest(t, "no blocks", request, fetchRequestNoBlocks) + + request.MaxWaitTime = 0x20 + request.MinBytes = 0xEF + testRequest(t, "with properties", request, fetchRequestWithProperties) + + request.MaxWaitTime = 0 + request.MinBytes = 0 + request.AddBlock("topic", 0x12, 0x34, 0x56) + testRequest(t, "one block", request, fetchRequestOneBlock) +} diff --git a/vendor/github.com/Shopify/sarama/fetch_response.go b/vendor/github.com/Shopify/sarama/fetch_response.go new file mode 100644 index 00000000..b56b166c --- /dev/null +++ b/vendor/github.com/Shopify/sarama/fetch_response.go @@ -0,0 +1,210 @@ +package sarama + +import "time" + +type FetchResponseBlock struct { + Err KError + HighWaterMarkOffset int64 + MsgSet MessageSet +} + +func (b *FetchResponseBlock) decode(pd packetDecoder) (err error) { + tmp, err := pd.getInt16() + if err != nil { + return err + } + b.Err = KError(tmp) + + b.HighWaterMarkOffset, err = pd.getInt64() + if err != nil { + return err + } + + msgSetSize, err := pd.getInt32() + if err != nil { + return err + } + + msgSetDecoder, err := pd.getSubset(int(msgSetSize)) + if err != nil { + return err + } + err = (&b.MsgSet).decode(msgSetDecoder) + + return err +} + +func (b *FetchResponseBlock) encode(pe packetEncoder) (err error) { + pe.putInt16(int16(b.Err)) + + pe.putInt64(b.HighWaterMarkOffset) + + pe.push(&lengthField{}) + err = b.MsgSet.encode(pe) + if err != nil { + return err + } + return pe.pop() +} + +type FetchResponse struct { + Blocks map[string]map[int32]*FetchResponseBlock + ThrottleTime time.Duration + Version int16 // v1 requires 0.9+, v2 requires 0.10+ +} + +func (r *FetchResponse) decode(pd packetDecoder, version int16) (err error) { + r.Version = version + + if r.Version >= 1 { + throttle, err := pd.getInt32() + if err != nil { + return err + } + r.ThrottleTime = time.Duration(throttle) * time.Millisecond + } + + numTopics, err := pd.getArrayLength() + if err != nil { + return err + } + + r.Blocks = make(map[string]map[int32]*FetchResponseBlock, numTopics) + for i := 0; i < numTopics; i++ { + name, err := pd.getString() + if err != nil { + return err + } + + numBlocks, err := pd.getArrayLength() + if err != nil { + return err + } + + r.Blocks[name] = make(map[int32]*FetchResponseBlock, numBlocks) + + for j := 0; j < numBlocks; j++ { + id, err := pd.getInt32() + if err != nil { + return err + } + + block := new(FetchResponseBlock) + err = block.decode(pd) + if err != nil { + return err + } + r.Blocks[name][id] = block + } + } + + return nil +} + +func (r *FetchResponse) encode(pe packetEncoder) (err error) { + if r.Version >= 1 { + pe.putInt32(int32(r.ThrottleTime / time.Millisecond)) + } + + err = pe.putArrayLength(len(r.Blocks)) + if err != nil { + return err + } + + for topic, partitions := range r.Blocks { + err = pe.putString(topic) + if err != nil { + return err + } + + err = pe.putArrayLength(len(partitions)) + if err != nil { + return err + } + + for id, block := range partitions { + pe.putInt32(id) + err = block.encode(pe) + if err != nil { + return err + } + } + + } + return nil +} + +func (r *FetchResponse) key() int16 { + return 1 +} + +func (r *FetchResponse) version() int16 { + return r.Version +} + +func (r *FetchResponse) requiredVersion() KafkaVersion { + switch r.Version { + case 1: + return V0_9_0_0 + case 2: + return V0_10_0_0 + default: + return minVersion + } +} + +func (r *FetchResponse) GetBlock(topic string, partition int32) *FetchResponseBlock { + if r.Blocks == nil { + return nil + } + + if r.Blocks[topic] == nil { + return nil + } + + return r.Blocks[topic][partition] +} + +func (r *FetchResponse) AddError(topic string, partition int32, err KError) { + if r.Blocks == nil { + r.Blocks = make(map[string]map[int32]*FetchResponseBlock) + } + partitions, ok := r.Blocks[topic] + if !ok { + partitions = make(map[int32]*FetchResponseBlock) + r.Blocks[topic] = partitions + } + frb, ok := partitions[partition] + if !ok { + frb = new(FetchResponseBlock) + partitions[partition] = frb + } + frb.Err = err +} + +func (r *FetchResponse) AddMessage(topic string, partition int32, key, value Encoder, offset int64) { + if r.Blocks == nil { + r.Blocks = make(map[string]map[int32]*FetchResponseBlock) + } + partitions, ok := r.Blocks[topic] + if !ok { + partitions = make(map[int32]*FetchResponseBlock) + r.Blocks[topic] = partitions + } + frb, ok := partitions[partition] + if !ok { + frb = new(FetchResponseBlock) + partitions[partition] = frb + } + var kb []byte + var vb []byte + if key != nil { + kb, _ = key.Encode() + } + if value != nil { + vb, _ = value.Encode() + } + msg := &Message{Key: kb, Value: vb} + msgBlock := &MessageBlock{Msg: msg, Offset: offset} + frb.MsgSet.Messages = append(frb.MsgSet.Messages, msgBlock) +} diff --git a/vendor/github.com/Shopify/sarama/fetch_response_test.go b/vendor/github.com/Shopify/sarama/fetch_response_test.go new file mode 100644 index 00000000..52fb5a74 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/fetch_response_test.go @@ -0,0 +1,84 @@ +package sarama + +import ( + "bytes" + "testing" +) + +var ( + emptyFetchResponse = []byte{ + 0x00, 0x00, 0x00, 0x00} + + oneMessageFetchResponse = []byte{ + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x05, 't', 'o', 'p', 'i', 'c', + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x05, + 0x00, 0x01, + 0x00, 0x00, 0x00, 0x00, 0x10, 0x10, 0x10, 0x10, + 0x00, 0x00, 0x00, 0x1C, + // messageSet + 0x00, 0x00, 0x00, 0x00, 0x00, 0x55, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x10, + // message + 0x23, 0x96, 0x4a, 0xf7, // CRC + 0x00, + 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x02, 0x00, 0xEE} +) + +func TestEmptyFetchResponse(t *testing.T) { + response := FetchResponse{} + testVersionDecodable(t, "empty", &response, emptyFetchResponse, 0) + + if len(response.Blocks) != 0 { + t.Error("Decoding produced topic blocks where there were none.") + } + +} + +func TestOneMessageFetchResponse(t *testing.T) { + response := FetchResponse{} + testVersionDecodable(t, "one message", &response, oneMessageFetchResponse, 0) + + if len(response.Blocks) != 1 { + t.Fatal("Decoding produced incorrect number of topic blocks.") + } + + if len(response.Blocks["topic"]) != 1 { + t.Fatal("Decoding produced incorrect number of partition blocks for topic.") + } + + block := response.GetBlock("topic", 5) + if block == nil { + t.Fatal("GetBlock didn't return block.") + } + if block.Err != ErrOffsetOutOfRange { + t.Error("Decoding didn't produce correct error code.") + } + if block.HighWaterMarkOffset != 0x10101010 { + t.Error("Decoding didn't produce correct high water mark offset.") + } + if block.MsgSet.PartialTrailingMessage { + t.Error("Decoding detected a partial trailing message where there wasn't one.") + } + + if len(block.MsgSet.Messages) != 1 { + t.Fatal("Decoding produced incorrect number of messages.") + } + msgBlock := block.MsgSet.Messages[0] + if msgBlock.Offset != 0x550000 { + t.Error("Decoding produced incorrect message offset.") + } + msg := msgBlock.Msg + if msg.Codec != CompressionNone { + t.Error("Decoding produced incorrect message compression.") + } + if msg.Key != nil { + t.Error("Decoding produced message key where there was none.") + } + if !bytes.Equal(msg.Value, []byte{0x00, 0xEE}) { + t.Error("Decoding produced incorrect message value.") + } +} diff --git a/vendor/github.com/Shopify/sarama/functional_client_test.go b/vendor/github.com/Shopify/sarama/functional_client_test.go new file mode 100644 index 00000000..2bf99d25 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/functional_client_test.go @@ -0,0 +1,90 @@ +package sarama + +import ( + "fmt" + "testing" + "time" +) + +func TestFuncConnectionFailure(t *testing.T) { + setupFunctionalTest(t) + defer teardownFunctionalTest(t) + + Proxies["kafka1"].Enabled = false + SaveProxy(t, "kafka1") + + config := NewConfig() + config.Metadata.Retry.Max = 1 + + _, err := NewClient([]string{kafkaBrokers[0]}, config) + if err != ErrOutOfBrokers { + t.Fatal("Expected returned error to be ErrOutOfBrokers, but was: ", err) + } +} + +func TestFuncClientMetadata(t *testing.T) { + setupFunctionalTest(t) + defer teardownFunctionalTest(t) + + config := NewConfig() + config.Metadata.Retry.Max = 1 + config.Metadata.Retry.Backoff = 10 * time.Millisecond + client, err := NewClient(kafkaBrokers, config) + if err != nil { + t.Fatal(err) + } + + if err := client.RefreshMetadata("unknown_topic"); err != ErrUnknownTopicOrPartition { + t.Error("Expected ErrUnknownTopicOrPartition, got", err) + } + + if _, err := client.Leader("unknown_topic", 0); err != ErrUnknownTopicOrPartition { + t.Error("Expected ErrUnknownTopicOrPartition, got", err) + } + + if _, err := client.Replicas("invalid/topic", 0); err != ErrUnknownTopicOrPartition { + t.Error("Expected ErrUnknownTopicOrPartition, got", err) + } + + partitions, err := client.Partitions("test.4") + if err != nil { + t.Error(err) + } + if len(partitions) != 4 { + t.Errorf("Expected test.4 topic to have 4 partitions, found %v", partitions) + } + + partitions, err = client.Partitions("test.1") + if err != nil { + t.Error(err) + } + if len(partitions) != 1 { + t.Errorf("Expected test.1 topic to have 1 partitions, found %v", partitions) + } + + safeClose(t, client) +} + +func TestFuncClientCoordinator(t *testing.T) { + checkKafkaVersion(t, "0.8.2") + setupFunctionalTest(t) + defer teardownFunctionalTest(t) + + client, err := NewClient(kafkaBrokers, nil) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 10; i++ { + broker, err := client.Coordinator(fmt.Sprintf("another_new_consumer_group_%d", i)) + if err != nil { + t.Fatal(err) + } + + if connected, err := broker.Connected(); !connected || err != nil { + t.Errorf("Expected to coordinator %s broker to be properly connected.", broker.Addr()) + } + } + + safeClose(t, client) +} diff --git a/vendor/github.com/Shopify/sarama/functional_consumer_test.go b/vendor/github.com/Shopify/sarama/functional_consumer_test.go new file mode 100644 index 00000000..ab843310 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/functional_consumer_test.go @@ -0,0 +1,61 @@ +package sarama + +import ( + "math" + "testing" +) + +func TestFuncConsumerOffsetOutOfRange(t *testing.T) { + setupFunctionalTest(t) + defer teardownFunctionalTest(t) + + consumer, err := NewConsumer(kafkaBrokers, nil) + if err != nil { + t.Fatal(err) + } + + if _, err := consumer.ConsumePartition("test.1", 0, -10); err != ErrOffsetOutOfRange { + t.Error("Expected ErrOffsetOutOfRange, got:", err) + } + + if _, err := consumer.ConsumePartition("test.1", 0, math.MaxInt64); err != ErrOffsetOutOfRange { + t.Error("Expected ErrOffsetOutOfRange, got:", err) + } + + safeClose(t, consumer) +} + +func TestConsumerHighWaterMarkOffset(t *testing.T) { + setupFunctionalTest(t) + defer teardownFunctionalTest(t) + + p, err := NewSyncProducer(kafkaBrokers, nil) + if err != nil { + t.Fatal(err) + } + defer safeClose(t, p) + + _, offset, err := p.SendMessage(&ProducerMessage{Topic: "test.1", Value: StringEncoder("Test")}) + if err != nil { + t.Fatal(err) + } + + c, err := NewConsumer(kafkaBrokers, nil) + if err != nil { + t.Fatal(err) + } + defer safeClose(t, c) + + pc, err := c.ConsumePartition("test.1", 0, OffsetOldest) + if err != nil { + t.Fatal(err) + } + + <-pc.Messages() + + if hwmo := pc.HighWaterMarkOffset(); hwmo != offset+1 { + t.Logf("Last produced offset %d; high water mark should be one higher but found %d.", offset, hwmo) + } + + safeClose(t, pc) +} diff --git a/vendor/github.com/Shopify/sarama/functional_offset_manager_test.go b/vendor/github.com/Shopify/sarama/functional_offset_manager_test.go new file mode 100644 index 00000000..436f35ef --- /dev/null +++ b/vendor/github.com/Shopify/sarama/functional_offset_manager_test.go @@ -0,0 +1,47 @@ +package sarama + +import ( + "testing" +) + +func TestFuncOffsetManager(t *testing.T) { + checkKafkaVersion(t, "0.8.2") + setupFunctionalTest(t) + defer teardownFunctionalTest(t) + + client, err := NewClient(kafkaBrokers, nil) + if err != nil { + t.Fatal(err) + } + + offsetManager, err := NewOffsetManagerFromClient("sarama.TestFuncOffsetManager", client) + if err != nil { + t.Fatal(err) + } + + pom1, err := offsetManager.ManagePartition("test.1", 0) + if err != nil { + t.Fatal(err) + } + + pom1.MarkOffset(10, "test metadata") + safeClose(t, pom1) + + pom2, err := offsetManager.ManagePartition("test.1", 0) + if err != nil { + t.Fatal(err) + } + + offset, metadata := pom2.NextOffset() + + if offset != 10 { + t.Errorf("Expected the next offset to be 10, found %d.", offset) + } + if metadata != "test metadata" { + t.Errorf("Expected metadata to be 'test metadata', found %s.", metadata) + } + + safeClose(t, pom2) + safeClose(t, offsetManager) + safeClose(t, client) +} diff --git a/vendor/github.com/Shopify/sarama/functional_producer_test.go b/vendor/github.com/Shopify/sarama/functional_producer_test.go new file mode 100644 index 00000000..91bf3b5e --- /dev/null +++ b/vendor/github.com/Shopify/sarama/functional_producer_test.go @@ -0,0 +1,323 @@ +package sarama + +import ( + "fmt" + "os" + "sync" + "testing" + "time" + + toxiproxy "github.com/Shopify/toxiproxy/client" + "github.com/rcrowley/go-metrics" +) + +const TestBatchSize = 1000 + +func TestFuncProducing(t *testing.T) { + config := NewConfig() + testProducingMessages(t, config) +} + +func TestFuncProducingGzip(t *testing.T) { + config := NewConfig() + config.Producer.Compression = CompressionGZIP + testProducingMessages(t, config) +} + +func TestFuncProducingSnappy(t *testing.T) { + config := NewConfig() + config.Producer.Compression = CompressionSnappy + testProducingMessages(t, config) +} + +func TestFuncProducingNoResponse(t *testing.T) { + config := NewConfig() + config.Producer.RequiredAcks = NoResponse + testProducingMessages(t, config) +} + +func TestFuncProducingFlushing(t *testing.T) { + config := NewConfig() + config.Producer.Flush.Messages = TestBatchSize / 8 + config.Producer.Flush.Frequency = 250 * time.Millisecond + testProducingMessages(t, config) +} + +func TestFuncMultiPartitionProduce(t *testing.T) { + setupFunctionalTest(t) + defer teardownFunctionalTest(t) + + config := NewConfig() + config.ChannelBufferSize = 20 + config.Producer.Flush.Frequency = 50 * time.Millisecond + config.Producer.Flush.Messages = 200 + config.Producer.Return.Successes = true + producer, err := NewSyncProducer(kafkaBrokers, config) + if err != nil { + t.Fatal(err) + } + + var wg sync.WaitGroup + wg.Add(TestBatchSize) + + for i := 1; i <= TestBatchSize; i++ { + go func(i int) { + defer wg.Done() + msg := &ProducerMessage{Topic: "test.64", Key: nil, Value: StringEncoder(fmt.Sprintf("hur %d", i))} + if _, _, err := producer.SendMessage(msg); err != nil { + t.Error(i, err) + } + }(i) + } + + wg.Wait() + if err := producer.Close(); err != nil { + t.Error(err) + } +} + +func TestFuncProducingToInvalidTopic(t *testing.T) { + setupFunctionalTest(t) + defer teardownFunctionalTest(t) + + producer, err := NewSyncProducer(kafkaBrokers, nil) + if err != nil { + t.Fatal(err) + } + + if _, _, err := producer.SendMessage(&ProducerMessage{Topic: "in/valid"}); err != ErrUnknownTopicOrPartition { + t.Error("Expected ErrUnknownTopicOrPartition, found", err) + } + + if _, _, err := producer.SendMessage(&ProducerMessage{Topic: "in/valid"}); err != ErrUnknownTopicOrPartition { + t.Error("Expected ErrUnknownTopicOrPartition, found", err) + } + + safeClose(t, producer) +} + +func testProducingMessages(t *testing.T, config *Config) { + setupFunctionalTest(t) + defer teardownFunctionalTest(t) + + // Configure some latency in order to properly validate the request latency metric + for _, proxy := range Proxies { + if _, err := proxy.AddToxic("", "latency", "", 1, toxiproxy.Attributes{"latency": 10}); err != nil { + t.Fatal("Unable to configure latency toxicity", err) + } + } + + config.Producer.Return.Successes = true + config.Consumer.Return.Errors = true + + client, err := NewClient(kafkaBrokers, config) + if err != nil { + t.Fatal(err) + } + + // Keep in mind the current offset + initialOffset, err := client.GetOffset("test.1", 0, OffsetNewest) + if err != nil { + t.Fatal(err) + } + + producer, err := NewAsyncProducerFromClient(client) + if err != nil { + t.Fatal(err) + } + + expectedResponses := TestBatchSize + for i := 1; i <= TestBatchSize; { + msg := &ProducerMessage{Topic: "test.1", Key: nil, Value: StringEncoder(fmt.Sprintf("testing %d", i))} + select { + case producer.Input() <- msg: + i++ + case ret := <-producer.Errors(): + t.Fatal(ret.Err) + case <-producer.Successes(): + expectedResponses-- + } + } + for expectedResponses > 0 { + select { + case ret := <-producer.Errors(): + t.Fatal(ret.Err) + case <-producer.Successes(): + expectedResponses-- + } + } + safeClose(t, producer) + + // Validate producer metrics before using the consumer minus the offset request + validateMetrics(t, client) + + master, err := NewConsumerFromClient(client) + if err != nil { + t.Fatal(err) + } + consumer, err := master.ConsumePartition("test.1", 0, initialOffset) + if err != nil { + t.Fatal(err) + } + + for i := 1; i <= TestBatchSize; i++ { + select { + case <-time.After(10 * time.Second): + t.Fatal("Not received any more events in the last 10 seconds.") + + case err := <-consumer.Errors(): + t.Error(err) + + case message := <-consumer.Messages(): + if string(message.Value) != fmt.Sprintf("testing %d", i) { + t.Fatalf("Unexpected message with index %d: %s", i, message.Value) + } + } + + } + safeClose(t, consumer) + safeClose(t, client) +} + +func validateMetrics(t *testing.T, client Client) { + // Get the broker used by test1 topic + var broker *Broker + if partitions, err := client.Partitions("test.1"); err != nil { + t.Error(err) + } else { + for _, partition := range partitions { + if b, err := client.Leader("test.1", partition); err != nil { + t.Error(err) + } else { + if broker != nil && b != broker { + t.Fatal("Expected only one broker, got at least 2") + } + broker = b + } + } + } + + metricValidators := newMetricValidators() + noResponse := client.Config().Producer.RequiredAcks == NoResponse + compressionEnabled := client.Config().Producer.Compression != CompressionNone + + // We are adding 10ms of latency to all requests with toxiproxy + minRequestLatencyInMs := 10 + if noResponse { + // but when we do not wait for a response it can be less than 1ms + minRequestLatencyInMs = 0 + } + + // We read at least 1 byte from the broker + metricValidators.registerForAllBrokers(broker, minCountMeterValidator("incoming-byte-rate", 1)) + // in at least 3 global requests (1 for metadata request, 1 for offset request and N for produce request) + metricValidators.register(minCountMeterValidator("request-rate", 3)) + metricValidators.register(minCountHistogramValidator("request-size", 3)) + metricValidators.register(minValHistogramValidator("request-size", 1)) + metricValidators.register(minValHistogramValidator("request-latency-in-ms", minRequestLatencyInMs)) + // and at least 2 requests to the registered broker (offset + produces) + metricValidators.registerForBroker(broker, minCountMeterValidator("request-rate", 2)) + metricValidators.registerForBroker(broker, minCountHistogramValidator("request-size", 2)) + metricValidators.registerForBroker(broker, minValHistogramValidator("request-size", 1)) + metricValidators.registerForBroker(broker, minValHistogramValidator("request-latency-in-ms", minRequestLatencyInMs)) + + // We send at least 1 batch + metricValidators.registerForGlobalAndTopic("test_1", minCountHistogramValidator("batch-size", 1)) + metricValidators.registerForGlobalAndTopic("test_1", minValHistogramValidator("batch-size", 1)) + if compressionEnabled { + // We record compression ratios between [0.50,-10.00] (50-1000 with a histogram) for at least one "fake" record + metricValidators.registerForGlobalAndTopic("test_1", minCountHistogramValidator("compression-ratio", 1)) + metricValidators.registerForGlobalAndTopic("test_1", minValHistogramValidator("compression-ratio", 50)) + metricValidators.registerForGlobalAndTopic("test_1", maxValHistogramValidator("compression-ratio", 1000)) + } else { + // We record compression ratios of 1.00 (100 with a histogram) for every TestBatchSize record + metricValidators.registerForGlobalAndTopic("test_1", countHistogramValidator("compression-ratio", TestBatchSize)) + metricValidators.registerForGlobalAndTopic("test_1", minValHistogramValidator("compression-ratio", 100)) + metricValidators.registerForGlobalAndTopic("test_1", maxValHistogramValidator("compression-ratio", 100)) + } + + // We send exactly TestBatchSize messages + metricValidators.registerForGlobalAndTopic("test_1", countMeterValidator("record-send-rate", TestBatchSize)) + // We send at least one record per request + metricValidators.registerForGlobalAndTopic("test_1", minCountHistogramValidator("records-per-request", 1)) + metricValidators.registerForGlobalAndTopic("test_1", minValHistogramValidator("records-per-request", 1)) + + // We receive at least 1 byte from the broker + metricValidators.registerForAllBrokers(broker, minCountMeterValidator("outgoing-byte-rate", 1)) + if noResponse { + // in exactly 2 global responses (metadata + offset) + metricValidators.register(countMeterValidator("response-rate", 2)) + metricValidators.register(minCountHistogramValidator("response-size", 2)) + metricValidators.register(minValHistogramValidator("response-size", 1)) + // and exactly 1 offset response for the registered broker + metricValidators.registerForBroker(broker, countMeterValidator("response-rate", 1)) + metricValidators.registerForBroker(broker, minCountHistogramValidator("response-size", 1)) + metricValidators.registerForBroker(broker, minValHistogramValidator("response-size", 1)) + } else { + // in at least 3 global responses (metadata + offset + produces) + metricValidators.register(minCountMeterValidator("response-rate", 3)) + metricValidators.register(minCountHistogramValidator("response-size", 3)) + metricValidators.register(minValHistogramValidator("response-size", 1)) + // and at least 2 for the registered broker + metricValidators.registerForBroker(broker, minCountMeterValidator("response-rate", 2)) + metricValidators.registerForBroker(broker, minCountHistogramValidator("response-size", 2)) + metricValidators.registerForBroker(broker, minValHistogramValidator("response-size", 1)) + } + + // Run the validators + metricValidators.run(t, client.Config().MetricRegistry) +} + +// Benchmarks + +func BenchmarkProducerSmall(b *testing.B) { + benchmarkProducer(b, nil, "test.64", ByteEncoder(make([]byte, 128))) +} +func BenchmarkProducerMedium(b *testing.B) { + benchmarkProducer(b, nil, "test.64", ByteEncoder(make([]byte, 1024))) +} +func BenchmarkProducerLarge(b *testing.B) { + benchmarkProducer(b, nil, "test.64", ByteEncoder(make([]byte, 8192))) +} +func BenchmarkProducerSmallSinglePartition(b *testing.B) { + benchmarkProducer(b, nil, "test.1", ByteEncoder(make([]byte, 128))) +} +func BenchmarkProducerMediumSnappy(b *testing.B) { + conf := NewConfig() + conf.Producer.Compression = CompressionSnappy + benchmarkProducer(b, conf, "test.1", ByteEncoder(make([]byte, 1024))) +} + +func benchmarkProducer(b *testing.B, conf *Config, topic string, value Encoder) { + setupFunctionalTest(b) + defer teardownFunctionalTest(b) + + metricsDisable := os.Getenv("METRICS_DISABLE") + if metricsDisable != "" { + previousUseNilMetrics := metrics.UseNilMetrics + Logger.Println("Disabling metrics using no-op implementation") + metrics.UseNilMetrics = true + // Restore previous setting + defer func() { + metrics.UseNilMetrics = previousUseNilMetrics + }() + } + + producer, err := NewAsyncProducer(kafkaBrokers, conf) + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + + for i := 1; i <= b.N; { + msg := &ProducerMessage{Topic: topic, Key: StringEncoder(fmt.Sprintf("%d", i)), Value: value} + select { + case producer.Input() <- msg: + i++ + case ret := <-producer.Errors(): + b.Fatal(ret.Err) + } + } + safeClose(b, producer) +} diff --git a/vendor/github.com/Shopify/sarama/functional_test.go b/vendor/github.com/Shopify/sarama/functional_test.go new file mode 100644 index 00000000..846eb29f --- /dev/null +++ b/vendor/github.com/Shopify/sarama/functional_test.go @@ -0,0 +1,148 @@ +package sarama + +import ( + "log" + "math/rand" + "net" + "os" + "strconv" + "strings" + "testing" + "time" + + toxiproxy "github.com/Shopify/toxiproxy/client" +) + +const ( + VagrantToxiproxy = "http://192.168.100.67:8474" + VagrantKafkaPeers = "192.168.100.67:9091,192.168.100.67:9092,192.168.100.67:9093,192.168.100.67:9094,192.168.100.67:9095" + VagrantZookeeperPeers = "192.168.100.67:2181,192.168.100.67:2182,192.168.100.67:2183,192.168.100.67:2184,192.168.100.67:2185" +) + +var ( + kafkaAvailable, kafkaRequired bool + kafkaBrokers []string + + proxyClient *toxiproxy.Client + Proxies map[string]*toxiproxy.Proxy + ZKProxies = []string{"zk1", "zk2", "zk3", "zk4", "zk5"} + KafkaProxies = []string{"kafka1", "kafka2", "kafka3", "kafka4", "kafka5"} +) + +func init() { + if os.Getenv("DEBUG") == "true" { + Logger = log.New(os.Stdout, "[sarama] ", log.LstdFlags) + } + + seed := time.Now().UTC().UnixNano() + if tmp := os.Getenv("TEST_SEED"); tmp != "" { + seed, _ = strconv.ParseInt(tmp, 0, 64) + } + Logger.Println("Using random seed:", seed) + rand.Seed(seed) + + proxyAddr := os.Getenv("TOXIPROXY_ADDR") + if proxyAddr == "" { + proxyAddr = VagrantToxiproxy + } + proxyClient = toxiproxy.NewClient(proxyAddr) + + kafkaPeers := os.Getenv("KAFKA_PEERS") + if kafkaPeers == "" { + kafkaPeers = VagrantKafkaPeers + } + kafkaBrokers = strings.Split(kafkaPeers, ",") + + if c, err := net.DialTimeout("tcp", kafkaBrokers[0], 5*time.Second); err == nil { + if err = c.Close(); err == nil { + kafkaAvailable = true + } + } + + kafkaRequired = os.Getenv("CI") != "" +} + +func checkKafkaAvailability(t testing.TB) { + if !kafkaAvailable { + if kafkaRequired { + t.Fatalf("Kafka broker is not available on %s. Set KAFKA_PEERS to connect to Kafka on a different location.", kafkaBrokers[0]) + } else { + t.Skipf("Kafka broker is not available on %s. Set KAFKA_PEERS to connect to Kafka on a different location.", kafkaBrokers[0]) + } + } +} + +func checkKafkaVersion(t testing.TB, requiredVersion string) { + kafkaVersion := os.Getenv("KAFKA_VERSION") + if kafkaVersion == "" { + t.Logf("No KAFKA_VERSION set. This test requires Kafka version %s or higher. Continuing...", requiredVersion) + } else { + available := parseKafkaVersion(kafkaVersion) + required := parseKafkaVersion(requiredVersion) + if !available.satisfies(required) { + t.Skipf("Kafka version %s is required for this test; you have %s. Skipping...", requiredVersion, kafkaVersion) + } + } +} + +func resetProxies(t testing.TB) { + if err := proxyClient.ResetState(); err != nil { + t.Error(err) + } + Proxies = nil +} + +func fetchProxies(t testing.TB) { + var err error + Proxies, err = proxyClient.Proxies() + if err != nil { + t.Fatal(err) + } +} + +func SaveProxy(t *testing.T, px string) { + if err := Proxies[px].Save(); err != nil { + t.Fatal(err) + } +} + +func setupFunctionalTest(t testing.TB) { + checkKafkaAvailability(t) + resetProxies(t) + fetchProxies(t) +} + +func teardownFunctionalTest(t testing.TB) { + resetProxies(t) +} + +type kafkaVersion []int + +func (kv kafkaVersion) satisfies(other kafkaVersion) bool { + var ov int + for index, v := range kv { + if len(other) <= index { + ov = 0 + } else { + ov = other[index] + } + + if v < ov { + return false + } else if v > ov { + return true + } + } + return true +} + +func parseKafkaVersion(version string) kafkaVersion { + numbers := strings.Split(version, ".") + result := make(kafkaVersion, 0, len(numbers)) + for _, number := range numbers { + nr, _ := strconv.Atoi(number) + result = append(result, nr) + } + + return result +} diff --git a/vendor/github.com/Shopify/sarama/heartbeat_request.go b/vendor/github.com/Shopify/sarama/heartbeat_request.go new file mode 100644 index 00000000..ce49c473 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/heartbeat_request.go @@ -0,0 +1,47 @@ +package sarama + +type HeartbeatRequest struct { + GroupId string + GenerationId int32 + MemberId string +} + +func (r *HeartbeatRequest) encode(pe packetEncoder) error { + if err := pe.putString(r.GroupId); err != nil { + return err + } + + pe.putInt32(r.GenerationId) + + if err := pe.putString(r.MemberId); err != nil { + return err + } + + return nil +} + +func (r *HeartbeatRequest) decode(pd packetDecoder, version int16) (err error) { + if r.GroupId, err = pd.getString(); err != nil { + return + } + if r.GenerationId, err = pd.getInt32(); err != nil { + return + } + if r.MemberId, err = pd.getString(); err != nil { + return + } + + return nil +} + +func (r *HeartbeatRequest) key() int16 { + return 12 +} + +func (r *HeartbeatRequest) version() int16 { + return 0 +} + +func (r *HeartbeatRequest) requiredVersion() KafkaVersion { + return V0_9_0_0 +} diff --git a/vendor/github.com/Shopify/sarama/heartbeat_request_test.go b/vendor/github.com/Shopify/sarama/heartbeat_request_test.go new file mode 100644 index 00000000..da6cd18f --- /dev/null +++ b/vendor/github.com/Shopify/sarama/heartbeat_request_test.go @@ -0,0 +1,21 @@ +package sarama + +import "testing" + +var ( + basicHeartbeatRequest = []byte{ + 0, 3, 'f', 'o', 'o', // Group ID + 0x00, 0x01, 0x02, 0x03, // Generatiuon ID + 0, 3, 'b', 'a', 'z', // Member ID + } +) + +func TestHeartbeatRequest(t *testing.T) { + var request *HeartbeatRequest + + request = new(HeartbeatRequest) + request.GroupId = "foo" + request.GenerationId = 66051 + request.MemberId = "baz" + testRequest(t, "basic", request, basicHeartbeatRequest) +} diff --git a/vendor/github.com/Shopify/sarama/heartbeat_response.go b/vendor/github.com/Shopify/sarama/heartbeat_response.go new file mode 100644 index 00000000..766f5fde --- /dev/null +++ b/vendor/github.com/Shopify/sarama/heartbeat_response.go @@ -0,0 +1,32 @@ +package sarama + +type HeartbeatResponse struct { + Err KError +} + +func (r *HeartbeatResponse) encode(pe packetEncoder) error { + pe.putInt16(int16(r.Err)) + return nil +} + +func (r *HeartbeatResponse) decode(pd packetDecoder, version int16) error { + kerr, err := pd.getInt16() + if err != nil { + return err + } + r.Err = KError(kerr) + + return nil +} + +func (r *HeartbeatResponse) key() int16 { + return 12 +} + +func (r *HeartbeatResponse) version() int16 { + return 0 +} + +func (r *HeartbeatResponse) requiredVersion() KafkaVersion { + return V0_9_0_0 +} diff --git a/vendor/github.com/Shopify/sarama/heartbeat_response_test.go b/vendor/github.com/Shopify/sarama/heartbeat_response_test.go new file mode 100644 index 00000000..5bcbec98 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/heartbeat_response_test.go @@ -0,0 +1,18 @@ +package sarama + +import "testing" + +var ( + heartbeatResponseNoError = []byte{ + 0x00, 0x00} +) + +func TestHeartbeatResponse(t *testing.T) { + var response *HeartbeatResponse + + response = new(HeartbeatResponse) + testVersionDecodable(t, "no error", response, heartbeatResponseNoError, 0) + if response.Err != ErrNoError { + t.Error("Decoding error failed: no error expected but found", response.Err) + } +} diff --git a/vendor/github.com/Shopify/sarama/join_group_request.go b/vendor/github.com/Shopify/sarama/join_group_request.go new file mode 100644 index 00000000..3a7ba171 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/join_group_request.go @@ -0,0 +1,143 @@ +package sarama + +type GroupProtocol struct { + Name string + Metadata []byte +} + +func (p *GroupProtocol) decode(pd packetDecoder) (err error) { + p.Name, err = pd.getString() + if err != nil { + return err + } + p.Metadata, err = pd.getBytes() + return err +} + +func (p *GroupProtocol) encode(pe packetEncoder) (err error) { + if err := pe.putString(p.Name); err != nil { + return err + } + if err := pe.putBytes(p.Metadata); err != nil { + return err + } + return nil +} + +type JoinGroupRequest struct { + GroupId string + SessionTimeout int32 + MemberId string + ProtocolType string + GroupProtocols map[string][]byte // deprecated; use OrderedGroupProtocols + OrderedGroupProtocols []*GroupProtocol +} + +func (r *JoinGroupRequest) encode(pe packetEncoder) error { + if err := pe.putString(r.GroupId); err != nil { + return err + } + pe.putInt32(r.SessionTimeout) + if err := pe.putString(r.MemberId); err != nil { + return err + } + if err := pe.putString(r.ProtocolType); err != nil { + return err + } + + if len(r.GroupProtocols) > 0 { + if len(r.OrderedGroupProtocols) > 0 { + return PacketDecodingError{"cannot specify both GroupProtocols and OrderedGroupProtocols on JoinGroupRequest"} + } + + if err := pe.putArrayLength(len(r.GroupProtocols)); err != nil { + return err + } + for name, metadata := range r.GroupProtocols { + if err := pe.putString(name); err != nil { + return err + } + if err := pe.putBytes(metadata); err != nil { + return err + } + } + } else { + if err := pe.putArrayLength(len(r.OrderedGroupProtocols)); err != nil { + return err + } + for _, protocol := range r.OrderedGroupProtocols { + if err := protocol.encode(pe); err != nil { + return err + } + } + } + + return nil +} + +func (r *JoinGroupRequest) decode(pd packetDecoder, version int16) (err error) { + if r.GroupId, err = pd.getString(); err != nil { + return + } + + if r.SessionTimeout, err = pd.getInt32(); err != nil { + return + } + + if r.MemberId, err = pd.getString(); err != nil { + return + } + + if r.ProtocolType, err = pd.getString(); err != nil { + return + } + + n, err := pd.getArrayLength() + if err != nil { + return err + } + if n == 0 { + return nil + } + + r.GroupProtocols = make(map[string][]byte) + for i := 0; i < n; i++ { + protocol := &GroupProtocol{} + if err := protocol.decode(pd); err != nil { + return err + } + r.GroupProtocols[protocol.Name] = protocol.Metadata + r.OrderedGroupProtocols = append(r.OrderedGroupProtocols, protocol) + } + + return nil +} + +func (r *JoinGroupRequest) key() int16 { + return 11 +} + +func (r *JoinGroupRequest) version() int16 { + return 0 +} + +func (r *JoinGroupRequest) requiredVersion() KafkaVersion { + return V0_9_0_0 +} + +func (r *JoinGroupRequest) AddGroupProtocol(name string, metadata []byte) { + r.OrderedGroupProtocols = append(r.OrderedGroupProtocols, &GroupProtocol{ + Name: name, + Metadata: metadata, + }) +} + +func (r *JoinGroupRequest) AddGroupProtocolMetadata(name string, metadata *ConsumerGroupMemberMetadata) error { + bin, err := encode(metadata, nil) + if err != nil { + return err + } + + r.AddGroupProtocol(name, bin) + return nil +} diff --git a/vendor/github.com/Shopify/sarama/join_group_request_test.go b/vendor/github.com/Shopify/sarama/join_group_request_test.go new file mode 100644 index 00000000..1ba3308b --- /dev/null +++ b/vendor/github.com/Shopify/sarama/join_group_request_test.go @@ -0,0 +1,57 @@ +package sarama + +import "testing" + +var ( + joinGroupRequestNoProtocols = []byte{ + 0, 9, 'T', 'e', 's', 't', 'G', 'r', 'o', 'u', 'p', // Group ID + 0, 0, 0, 100, // Session timeout + 0, 0, // Member ID + 0, 8, 'c', 'o', 'n', 's', 'u', 'm', 'e', 'r', // Protocol Type + 0, 0, 0, 0, // 0 protocol groups + } + + joinGroupRequestOneProtocol = []byte{ + 0, 9, 'T', 'e', 's', 't', 'G', 'r', 'o', 'u', 'p', // Group ID + 0, 0, 0, 100, // Session timeout + 0, 11, 'O', 'n', 'e', 'P', 'r', 'o', 't', 'o', 'c', 'o', 'l', // Member ID + 0, 8, 'c', 'o', 'n', 's', 'u', 'm', 'e', 'r', // Protocol Type + 0, 0, 0, 1, // 1 group protocol + 0, 3, 'o', 'n', 'e', // Protocol name + 0, 0, 0, 3, 0x01, 0x02, 0x03, // protocol metadata + } +) + +func TestJoinGroupRequest(t *testing.T) { + request := new(JoinGroupRequest) + request.GroupId = "TestGroup" + request.SessionTimeout = 100 + request.ProtocolType = "consumer" + testRequest(t, "no protocols", request, joinGroupRequestNoProtocols) +} + +func TestJoinGroupRequestOneProtocol(t *testing.T) { + request := new(JoinGroupRequest) + request.GroupId = "TestGroup" + request.SessionTimeout = 100 + request.MemberId = "OneProtocol" + request.ProtocolType = "consumer" + request.AddGroupProtocol("one", []byte{0x01, 0x02, 0x03}) + packet := testRequestEncode(t, "one protocol", request, joinGroupRequestOneProtocol) + request.GroupProtocols = make(map[string][]byte) + request.GroupProtocols["one"] = []byte{0x01, 0x02, 0x03} + testRequestDecode(t, "one protocol", request, packet) +} + +func TestJoinGroupRequestDeprecatedEncode(t *testing.T) { + request := new(JoinGroupRequest) + request.GroupId = "TestGroup" + request.SessionTimeout = 100 + request.MemberId = "OneProtocol" + request.ProtocolType = "consumer" + request.GroupProtocols = make(map[string][]byte) + request.GroupProtocols["one"] = []byte{0x01, 0x02, 0x03} + packet := testRequestEncode(t, "one protocol", request, joinGroupRequestOneProtocol) + request.AddGroupProtocol("one", []byte{0x01, 0x02, 0x03}) + testRequestDecode(t, "one protocol", request, packet) +} diff --git a/vendor/github.com/Shopify/sarama/join_group_response.go b/vendor/github.com/Shopify/sarama/join_group_response.go new file mode 100644 index 00000000..6d35fe36 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/join_group_response.go @@ -0,0 +1,115 @@ +package sarama + +type JoinGroupResponse struct { + Err KError + GenerationId int32 + GroupProtocol string + LeaderId string + MemberId string + Members map[string][]byte +} + +func (r *JoinGroupResponse) GetMembers() (map[string]ConsumerGroupMemberMetadata, error) { + members := make(map[string]ConsumerGroupMemberMetadata, len(r.Members)) + for id, bin := range r.Members { + meta := new(ConsumerGroupMemberMetadata) + if err := decode(bin, meta); err != nil { + return nil, err + } + members[id] = *meta + } + return members, nil +} + +func (r *JoinGroupResponse) encode(pe packetEncoder) error { + pe.putInt16(int16(r.Err)) + pe.putInt32(r.GenerationId) + + if err := pe.putString(r.GroupProtocol); err != nil { + return err + } + if err := pe.putString(r.LeaderId); err != nil { + return err + } + if err := pe.putString(r.MemberId); err != nil { + return err + } + + if err := pe.putArrayLength(len(r.Members)); err != nil { + return err + } + + for memberId, memberMetadata := range r.Members { + if err := pe.putString(memberId); err != nil { + return err + } + + if err := pe.putBytes(memberMetadata); err != nil { + return err + } + } + + return nil +} + +func (r *JoinGroupResponse) decode(pd packetDecoder, version int16) (err error) { + kerr, err := pd.getInt16() + if err != nil { + return err + } + + r.Err = KError(kerr) + + if r.GenerationId, err = pd.getInt32(); err != nil { + return + } + + if r.GroupProtocol, err = pd.getString(); err != nil { + return + } + + if r.LeaderId, err = pd.getString(); err != nil { + return + } + + if r.MemberId, err = pd.getString(); err != nil { + return + } + + n, err := pd.getArrayLength() + if err != nil { + return err + } + if n == 0 { + return nil + } + + r.Members = make(map[string][]byte) + for i := 0; i < n; i++ { + memberId, err := pd.getString() + if err != nil { + return err + } + + memberMetadata, err := pd.getBytes() + if err != nil { + return err + } + + r.Members[memberId] = memberMetadata + } + + return nil +} + +func (r *JoinGroupResponse) key() int16 { + return 11 +} + +func (r *JoinGroupResponse) version() int16 { + return 0 +} + +func (r *JoinGroupResponse) requiredVersion() KafkaVersion { + return V0_9_0_0 +} diff --git a/vendor/github.com/Shopify/sarama/join_group_response_test.go b/vendor/github.com/Shopify/sarama/join_group_response_test.go new file mode 100644 index 00000000..ba7f71f2 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/join_group_response_test.go @@ -0,0 +1,98 @@ +package sarama + +import ( + "reflect" + "testing" +) + +var ( + joinGroupResponseNoError = []byte{ + 0x00, 0x00, // No error + 0x00, 0x01, 0x02, 0x03, // Generation ID + 0, 8, 'p', 'r', 'o', 't', 'o', 'c', 'o', 'l', // Protocol name chosen + 0, 3, 'f', 'o', 'o', // Leader ID + 0, 3, 'b', 'a', 'r', // Member ID + 0, 0, 0, 0, // No member info + } + + joinGroupResponseWithError = []byte{ + 0, 23, // Error: inconsistent group protocol + 0x00, 0x00, 0x00, 0x00, // Generation ID + 0, 0, // Protocol name chosen + 0, 0, // Leader ID + 0, 0, // Member ID + 0, 0, 0, 0, // No member info + } + + joinGroupResponseLeader = []byte{ + 0x00, 0x00, // No error + 0x00, 0x01, 0x02, 0x03, // Generation ID + 0, 8, 'p', 'r', 'o', 't', 'o', 'c', 'o', 'l', // Protocol name chosen + 0, 3, 'f', 'o', 'o', // Leader ID + 0, 3, 'f', 'o', 'o', // Member ID == Leader ID + 0, 0, 0, 1, // 1 member + 0, 3, 'f', 'o', 'o', // Member ID + 0, 0, 0, 3, 0x01, 0x02, 0x03, // Member metadata + } +) + +func TestJoinGroupResponse(t *testing.T) { + var response *JoinGroupResponse + + response = new(JoinGroupResponse) + testVersionDecodable(t, "no error", response, joinGroupResponseNoError, 0) + if response.Err != ErrNoError { + t.Error("Decoding Err failed: no error expected but found", response.Err) + } + if response.GenerationId != 66051 { + t.Error("Decoding GenerationId failed, found:", response.GenerationId) + } + if response.LeaderId != "foo" { + t.Error("Decoding LeaderId failed, found:", response.LeaderId) + } + if response.MemberId != "bar" { + t.Error("Decoding MemberId failed, found:", response.MemberId) + } + if len(response.Members) != 0 { + t.Error("Decoding Members failed, found:", response.Members) + } + + response = new(JoinGroupResponse) + testVersionDecodable(t, "with error", response, joinGroupResponseWithError, 0) + if response.Err != ErrInconsistentGroupProtocol { + t.Error("Decoding Err failed: ErrInconsistentGroupProtocol expected but found", response.Err) + } + if response.GenerationId != 0 { + t.Error("Decoding GenerationId failed, found:", response.GenerationId) + } + if response.LeaderId != "" { + t.Error("Decoding LeaderId failed, found:", response.LeaderId) + } + if response.MemberId != "" { + t.Error("Decoding MemberId failed, found:", response.MemberId) + } + if len(response.Members) != 0 { + t.Error("Decoding Members failed, found:", response.Members) + } + + response = new(JoinGroupResponse) + testVersionDecodable(t, "with error", response, joinGroupResponseLeader, 0) + if response.Err != ErrNoError { + t.Error("Decoding Err failed: ErrNoError expected but found", response.Err) + } + if response.GenerationId != 66051 { + t.Error("Decoding GenerationId failed, found:", response.GenerationId) + } + if response.LeaderId != "foo" { + t.Error("Decoding LeaderId failed, found:", response.LeaderId) + } + if response.MemberId != "foo" { + t.Error("Decoding MemberId failed, found:", response.MemberId) + } + if len(response.Members) != 1 { + t.Error("Decoding Members failed, found:", response.Members) + } + if !reflect.DeepEqual(response.Members["foo"], []byte{0x01, 0x02, 0x03}) { + t.Error("Decoding foo member failed, found:", response.Members["foo"]) + } +} diff --git a/vendor/github.com/Shopify/sarama/leave_group_request.go b/vendor/github.com/Shopify/sarama/leave_group_request.go new file mode 100644 index 00000000..e1774274 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/leave_group_request.go @@ -0,0 +1,40 @@ +package sarama + +type LeaveGroupRequest struct { + GroupId string + MemberId string +} + +func (r *LeaveGroupRequest) encode(pe packetEncoder) error { + if err := pe.putString(r.GroupId); err != nil { + return err + } + if err := pe.putString(r.MemberId); err != nil { + return err + } + + return nil +} + +func (r *LeaveGroupRequest) decode(pd packetDecoder, version int16) (err error) { + if r.GroupId, err = pd.getString(); err != nil { + return + } + if r.MemberId, err = pd.getString(); err != nil { + return + } + + return nil +} + +func (r *LeaveGroupRequest) key() int16 { + return 13 +} + +func (r *LeaveGroupRequest) version() int16 { + return 0 +} + +func (r *LeaveGroupRequest) requiredVersion() KafkaVersion { + return V0_9_0_0 +} diff --git a/vendor/github.com/Shopify/sarama/leave_group_request_test.go b/vendor/github.com/Shopify/sarama/leave_group_request_test.go new file mode 100644 index 00000000..c1fed6d2 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/leave_group_request_test.go @@ -0,0 +1,19 @@ +package sarama + +import "testing" + +var ( + basicLeaveGroupRequest = []byte{ + 0, 3, 'f', 'o', 'o', + 0, 3, 'b', 'a', 'r', + } +) + +func TestLeaveGroupRequest(t *testing.T) { + var request *LeaveGroupRequest + + request = new(LeaveGroupRequest) + request.GroupId = "foo" + request.MemberId = "bar" + testRequest(t, "basic", request, basicLeaveGroupRequest) +} diff --git a/vendor/github.com/Shopify/sarama/leave_group_response.go b/vendor/github.com/Shopify/sarama/leave_group_response.go new file mode 100644 index 00000000..d60c626d --- /dev/null +++ b/vendor/github.com/Shopify/sarama/leave_group_response.go @@ -0,0 +1,32 @@ +package sarama + +type LeaveGroupResponse struct { + Err KError +} + +func (r *LeaveGroupResponse) encode(pe packetEncoder) error { + pe.putInt16(int16(r.Err)) + return nil +} + +func (r *LeaveGroupResponse) decode(pd packetDecoder, version int16) (err error) { + kerr, err := pd.getInt16() + if err != nil { + return err + } + r.Err = KError(kerr) + + return nil +} + +func (r *LeaveGroupResponse) key() int16 { + return 13 +} + +func (r *LeaveGroupResponse) version() int16 { + return 0 +} + +func (r *LeaveGroupResponse) requiredVersion() KafkaVersion { + return V0_9_0_0 +} diff --git a/vendor/github.com/Shopify/sarama/leave_group_response_test.go b/vendor/github.com/Shopify/sarama/leave_group_response_test.go new file mode 100644 index 00000000..9207c666 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/leave_group_response_test.go @@ -0,0 +1,24 @@ +package sarama + +import "testing" + +var ( + leaveGroupResponseNoError = []byte{0x00, 0x00} + leaveGroupResponseWithError = []byte{0, 25} +) + +func TestLeaveGroupResponse(t *testing.T) { + var response *LeaveGroupResponse + + response = new(LeaveGroupResponse) + testVersionDecodable(t, "no error", response, leaveGroupResponseNoError, 0) + if response.Err != ErrNoError { + t.Error("Decoding error failed: no error expected but found", response.Err) + } + + response = new(LeaveGroupResponse) + testVersionDecodable(t, "with error", response, leaveGroupResponseWithError, 0) + if response.Err != ErrUnknownMemberId { + t.Error("Decoding error failed: ErrUnknownMemberId expected but found", response.Err) + } +} diff --git a/vendor/github.com/Shopify/sarama/length_field.go b/vendor/github.com/Shopify/sarama/length_field.go new file mode 100644 index 00000000..70078be5 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/length_field.go @@ -0,0 +1,29 @@ +package sarama + +import "encoding/binary" + +// LengthField implements the PushEncoder and PushDecoder interfaces for calculating 4-byte lengths. +type lengthField struct { + startOffset int +} + +func (l *lengthField) saveOffset(in int) { + l.startOffset = in +} + +func (l *lengthField) reserveLength() int { + return 4 +} + +func (l *lengthField) run(curOffset int, buf []byte) error { + binary.BigEndian.PutUint32(buf[l.startOffset:], uint32(curOffset-l.startOffset-4)) + return nil +} + +func (l *lengthField) check(curOffset int, buf []byte) error { + if uint32(curOffset-l.startOffset-4) != binary.BigEndian.Uint32(buf[l.startOffset:]) { + return PacketDecodingError{"length field invalid"} + } + + return nil +} diff --git a/vendor/github.com/Shopify/sarama/list_groups_request.go b/vendor/github.com/Shopify/sarama/list_groups_request.go new file mode 100644 index 00000000..3b16abf7 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/list_groups_request.go @@ -0,0 +1,24 @@ +package sarama + +type ListGroupsRequest struct { +} + +func (r *ListGroupsRequest) encode(pe packetEncoder) error { + return nil +} + +func (r *ListGroupsRequest) decode(pd packetDecoder, version int16) (err error) { + return nil +} + +func (r *ListGroupsRequest) key() int16 { + return 16 +} + +func (r *ListGroupsRequest) version() int16 { + return 0 +} + +func (r *ListGroupsRequest) requiredVersion() KafkaVersion { + return V0_9_0_0 +} diff --git a/vendor/github.com/Shopify/sarama/list_groups_request_test.go b/vendor/github.com/Shopify/sarama/list_groups_request_test.go new file mode 100644 index 00000000..2e977d9a --- /dev/null +++ b/vendor/github.com/Shopify/sarama/list_groups_request_test.go @@ -0,0 +1,7 @@ +package sarama + +import "testing" + +func TestListGroupsRequest(t *testing.T) { + testRequest(t, "ListGroupsRequest", &ListGroupsRequest{}, []byte{}) +} diff --git a/vendor/github.com/Shopify/sarama/list_groups_response.go b/vendor/github.com/Shopify/sarama/list_groups_response.go new file mode 100644 index 00000000..56115d4c --- /dev/null +++ b/vendor/github.com/Shopify/sarama/list_groups_response.go @@ -0,0 +1,69 @@ +package sarama + +type ListGroupsResponse struct { + Err KError + Groups map[string]string +} + +func (r *ListGroupsResponse) encode(pe packetEncoder) error { + pe.putInt16(int16(r.Err)) + + if err := pe.putArrayLength(len(r.Groups)); err != nil { + return err + } + for groupId, protocolType := range r.Groups { + if err := pe.putString(groupId); err != nil { + return err + } + if err := pe.putString(protocolType); err != nil { + return err + } + } + + return nil +} + +func (r *ListGroupsResponse) decode(pd packetDecoder, version int16) error { + kerr, err := pd.getInt16() + if err != nil { + return err + } + + r.Err = KError(kerr) + + n, err := pd.getArrayLength() + if err != nil { + return err + } + if n == 0 { + return nil + } + + r.Groups = make(map[string]string) + for i := 0; i < n; i++ { + groupId, err := pd.getString() + if err != nil { + return err + } + protocolType, err := pd.getString() + if err != nil { + return err + } + + r.Groups[groupId] = protocolType + } + + return nil +} + +func (r *ListGroupsResponse) key() int16 { + return 16 +} + +func (r *ListGroupsResponse) version() int16 { + return 0 +} + +func (r *ListGroupsResponse) requiredVersion() KafkaVersion { + return V0_9_0_0 +} diff --git a/vendor/github.com/Shopify/sarama/list_groups_response_test.go b/vendor/github.com/Shopify/sarama/list_groups_response_test.go new file mode 100644 index 00000000..41ab822f --- /dev/null +++ b/vendor/github.com/Shopify/sarama/list_groups_response_test.go @@ -0,0 +1,58 @@ +package sarama + +import ( + "testing" +) + +var ( + listGroupsResponseEmpty = []byte{ + 0, 0, // no error + 0, 0, 0, 0, // no groups + } + + listGroupsResponseError = []byte{ + 0, 31, // no error + 0, 0, 0, 0, // ErrClusterAuthorizationFailed + } + + listGroupsResponseWithConsumer = []byte{ + 0, 0, // no error + 0, 0, 0, 1, // 1 group + 0, 3, 'f', 'o', 'o', // group name + 0, 8, 'c', 'o', 'n', 's', 'u', 'm', 'e', 'r', // protocol type + } +) + +func TestListGroupsResponse(t *testing.T) { + var response *ListGroupsResponse + + response = new(ListGroupsResponse) + testVersionDecodable(t, "no error", response, listGroupsResponseEmpty, 0) + if response.Err != ErrNoError { + t.Error("Expected no gerror, found:", response.Err) + } + if len(response.Groups) != 0 { + t.Error("Expected no groups") + } + + response = new(ListGroupsResponse) + testVersionDecodable(t, "no error", response, listGroupsResponseError, 0) + if response.Err != ErrClusterAuthorizationFailed { + t.Error("Expected no gerror, found:", response.Err) + } + if len(response.Groups) != 0 { + t.Error("Expected no groups") + } + + response = new(ListGroupsResponse) + testVersionDecodable(t, "no error", response, listGroupsResponseWithConsumer, 0) + if response.Err != ErrNoError { + t.Error("Expected no gerror, found:", response.Err) + } + if len(response.Groups) != 1 { + t.Error("Expected one group") + } + if response.Groups["foo"] != "consumer" { + t.Error("Expected foo group to use consumer protocol") + } +} diff --git a/vendor/github.com/Shopify/sarama/message.go b/vendor/github.com/Shopify/sarama/message.go new file mode 100644 index 00000000..327c5fa2 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/message.go @@ -0,0 +1,196 @@ +package sarama + +import ( + "bytes" + "compress/gzip" + "fmt" + "io/ioutil" + "time" + + "github.com/eapache/go-xerial-snappy" + "github.com/pierrec/lz4" +) + +// CompressionCodec represents the various compression codecs recognized by Kafka in messages. +type CompressionCodec int8 + +// only the last two bits are really used +const compressionCodecMask int8 = 0x03 + +const ( + CompressionNone CompressionCodec = 0 + CompressionGZIP CompressionCodec = 1 + CompressionSnappy CompressionCodec = 2 + CompressionLZ4 CompressionCodec = 3 +) + +type Message struct { + Codec CompressionCodec // codec used to compress the message contents + Key []byte // the message key, may be nil + Value []byte // the message contents + Set *MessageSet // the message set a message might wrap + Version int8 // v1 requires Kafka 0.10 + Timestamp time.Time // the timestamp of the message (version 1+ only) + + compressedCache []byte + compressedSize int // used for computing the compression ratio metrics +} + +func (m *Message) encode(pe packetEncoder) error { + pe.push(&crc32Field{}) + + pe.putInt8(m.Version) + + attributes := int8(m.Codec) & compressionCodecMask + pe.putInt8(attributes) + + if m.Version >= 1 { + pe.putInt64(m.Timestamp.UnixNano() / int64(time.Millisecond)) + } + + err := pe.putBytes(m.Key) + if err != nil { + return err + } + + var payload []byte + + if m.compressedCache != nil { + payload = m.compressedCache + m.compressedCache = nil + } else if m.Value != nil { + switch m.Codec { + case CompressionNone: + payload = m.Value + case CompressionGZIP: + var buf bytes.Buffer + writer := gzip.NewWriter(&buf) + if _, err = writer.Write(m.Value); err != nil { + return err + } + if err = writer.Close(); err != nil { + return err + } + m.compressedCache = buf.Bytes() + payload = m.compressedCache + case CompressionSnappy: + tmp := snappy.Encode(m.Value) + m.compressedCache = tmp + payload = m.compressedCache + case CompressionLZ4: + var buf bytes.Buffer + writer := lz4.NewWriter(&buf) + if _, err = writer.Write(m.Value); err != nil { + return err + } + if err = writer.Close(); err != nil { + return err + } + m.compressedCache = buf.Bytes() + payload = m.compressedCache + + default: + return PacketEncodingError{fmt.Sprintf("unsupported compression codec (%d)", m.Codec)} + } + // Keep in mind the compressed payload size for metric gathering + m.compressedSize = len(payload) + } + + if err = pe.putBytes(payload); err != nil { + return err + } + + return pe.pop() +} + +func (m *Message) decode(pd packetDecoder) (err error) { + err = pd.push(&crc32Field{}) + if err != nil { + return err + } + + m.Version, err = pd.getInt8() + if err != nil { + return err + } + + attribute, err := pd.getInt8() + if err != nil { + return err + } + m.Codec = CompressionCodec(attribute & compressionCodecMask) + + if m.Version >= 1 { + millis, err := pd.getInt64() + if err != nil { + return err + } + m.Timestamp = time.Unix(millis/1000, (millis%1000)*int64(time.Millisecond)) + } + + m.Key, err = pd.getBytes() + if err != nil { + return err + } + + m.Value, err = pd.getBytes() + if err != nil { + return err + } + + // Required for deep equal assertion during tests but might be useful + // for future metrics about the compression ratio in fetch requests + m.compressedSize = len(m.Value) + + switch m.Codec { + case CompressionNone: + // nothing to do + case CompressionGZIP: + if m.Value == nil { + break + } + reader, err := gzip.NewReader(bytes.NewReader(m.Value)) + if err != nil { + return err + } + if m.Value, err = ioutil.ReadAll(reader); err != nil { + return err + } + if err := m.decodeSet(); err != nil { + return err + } + case CompressionSnappy: + if m.Value == nil { + break + } + if m.Value, err = snappy.Decode(m.Value); err != nil { + return err + } + if err := m.decodeSet(); err != nil { + return err + } + case CompressionLZ4: + if m.Value == nil { + break + } + reader := lz4.NewReader(bytes.NewReader(m.Value)) + if m.Value, err = ioutil.ReadAll(reader); err != nil { + return err + } + if err := m.decodeSet(); err != nil { + return err + } + + default: + return PacketDecodingError{fmt.Sprintf("invalid compression specified (%d)", m.Codec)} + } + + return pd.pop() +} + +// decodes a message set from a previousy encoded bulk-message +func (m *Message) decodeSet() (err error) { + pd := realDecoder{raw: m.Value} + m.Set = &MessageSet{} + return m.Set.decode(&pd) +} diff --git a/vendor/github.com/Shopify/sarama/message_set.go b/vendor/github.com/Shopify/sarama/message_set.go new file mode 100644 index 00000000..f028784e --- /dev/null +++ b/vendor/github.com/Shopify/sarama/message_set.go @@ -0,0 +1,89 @@ +package sarama + +type MessageBlock struct { + Offset int64 + Msg *Message +} + +// Messages convenience helper which returns either all the +// messages that are wrapped in this block +func (msb *MessageBlock) Messages() []*MessageBlock { + if msb.Msg.Set != nil { + return msb.Msg.Set.Messages + } + return []*MessageBlock{msb} +} + +func (msb *MessageBlock) encode(pe packetEncoder) error { + pe.putInt64(msb.Offset) + pe.push(&lengthField{}) + err := msb.Msg.encode(pe) + if err != nil { + return err + } + return pe.pop() +} + +func (msb *MessageBlock) decode(pd packetDecoder) (err error) { + if msb.Offset, err = pd.getInt64(); err != nil { + return err + } + + if err = pd.push(&lengthField{}); err != nil { + return err + } + + msb.Msg = new(Message) + if err = msb.Msg.decode(pd); err != nil { + return err + } + + if err = pd.pop(); err != nil { + return err + } + + return nil +} + +type MessageSet struct { + PartialTrailingMessage bool // whether the set on the wire contained an incomplete trailing MessageBlock + Messages []*MessageBlock +} + +func (ms *MessageSet) encode(pe packetEncoder) error { + for i := range ms.Messages { + err := ms.Messages[i].encode(pe) + if err != nil { + return err + } + } + return nil +} + +func (ms *MessageSet) decode(pd packetDecoder) (err error) { + ms.Messages = nil + + for pd.remaining() > 0 { + msb := new(MessageBlock) + err = msb.decode(pd) + switch err { + case nil: + ms.Messages = append(ms.Messages, msb) + case ErrInsufficientData: + // As an optimization the server is allowed to return a partial message at the + // end of the message set. Clients should handle this case. So we just ignore such things. + ms.PartialTrailingMessage = true + return nil + default: + return err + } + } + + return nil +} + +func (ms *MessageSet) addMessage(msg *Message) { + block := new(MessageBlock) + block.Msg = msg + ms.Messages = append(ms.Messages, block) +} diff --git a/vendor/github.com/Shopify/sarama/message_test.go b/vendor/github.com/Shopify/sarama/message_test.go new file mode 100644 index 00000000..d4a37c22 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/message_test.go @@ -0,0 +1,181 @@ +package sarama + +import ( + "runtime" + "testing" + "time" +) + +var ( + emptyMessage = []byte{ + 167, 236, 104, 3, // CRC + 0x00, // magic version byte + 0x00, // attribute flags + 0xFF, 0xFF, 0xFF, 0xFF, // key + 0xFF, 0xFF, 0xFF, 0xFF} // value + + emptyGzipMessage = []byte{ + 97, 79, 149, 90, //CRC + 0x00, // magic version byte + 0x01, // attribute flags + 0xFF, 0xFF, 0xFF, 0xFF, // key + // value + 0x00, 0x00, 0x00, 0x17, + 0x1f, 0x8b, + 0x08, + 0, 0, 9, 110, 136, 0, 255, 1, 0, 0, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0} + + emptyGzipMessage18 = []byte{ + 132, 99, 80, 148, //CRC + 0x00, // magic version byte + 0x01, // attribute flags + 0xFF, 0xFF, 0xFF, 0xFF, // key + // value + 0x00, 0x00, 0x00, 0x17, + 0x1f, 0x8b, + 0x08, + 0, 0, 0, 0, 0, 0, 255, 1, 0, 0, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0} + + emptyLZ4Message = []byte{ + 132, 219, 238, 101, // CRC + 0x01, // version byte + 0x03, // attribute flags: lz4 + 0, 0, 1, 88, 141, 205, 89, 56, // timestamp + 0xFF, 0xFF, 0xFF, 0xFF, // key + 0x00, 0x00, 0x00, 0x0f, // len + 0x04, 0x22, 0x4D, 0x18, // LZ4 magic number + 100, // LZ4 flags: version 01, block indepedant, content checksum + 112, 185, 0, 0, 0, 0, // LZ4 data + 5, 93, 204, 2, // LZ4 checksum + } + + emptyBulkSnappyMessage = []byte{ + 180, 47, 53, 209, //CRC + 0x00, // magic version byte + 0x02, // attribute flags + 0xFF, 0xFF, 0xFF, 0xFF, // key + 0, 0, 0, 42, + 130, 83, 78, 65, 80, 80, 89, 0, // SNAPPY magic + 0, 0, 0, 1, // min version + 0, 0, 0, 1, // default version + 0, 0, 0, 22, 52, 0, 0, 25, 1, 16, 14, 227, 138, 104, 118, 25, 15, 13, 1, 8, 1, 0, 0, 62, 26, 0} + + emptyBulkGzipMessage = []byte{ + 139, 160, 63, 141, //CRC + 0x00, // magic version byte + 0x01, // attribute flags + 0xFF, 0xFF, 0xFF, 0xFF, // key + 0x00, 0x00, 0x00, 0x27, // len + 0x1f, 0x8b, // Gzip Magic + 0x08, // deflate compressed + 0, 0, 0, 0, 0, 0, 0, 99, 96, 128, 3, 190, 202, 112, 143, 7, 12, 12, 255, 129, 0, 33, 200, 192, 136, 41, 3, 0, 199, 226, 155, 70, 52, 0, 0, 0} + + emptyBulkLZ4Message = []byte{ + 246, 12, 188, 129, // CRC + 0x01, // Version + 0x03, // attribute flags (LZ4) + 255, 255, 249, 209, 212, 181, 73, 201, // timestamp + 0xFF, 0xFF, 0xFF, 0xFF, // key + 0x00, 0x00, 0x00, 0x47, // len + 0x04, 0x22, 0x4D, 0x18, // magic number lz4 + 100, // lz4 flags 01100100 + // version: 01, block indep: 1, block checksum: 0, content size: 0, content checksum: 1, reserved: 00 + 112, 185, 52, 0, 0, 128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 14, 121, 87, 72, 224, 0, 0, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 14, 121, 87, 72, 224, 0, 0, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0, + 71, 129, 23, 111, // LZ4 checksum + } +) + +func TestMessageEncoding(t *testing.T) { + message := Message{} + testEncodable(t, "empty", &message, emptyMessage) + + message.Value = []byte{} + message.Codec = CompressionGZIP + if runtime.Version() == "go1.8" { + testEncodable(t, "empty gzip", &message, emptyGzipMessage18) + } else { + testEncodable(t, "empty gzip", &message, emptyGzipMessage) + } + + message.Value = []byte{} + message.Codec = CompressionLZ4 + message.Timestamp = time.Unix(1479847795, 0) + message.Version = 1 + testEncodable(t, "empty lz4", &message, emptyLZ4Message) +} + +func TestMessageDecoding(t *testing.T) { + message := Message{} + testDecodable(t, "empty", &message, emptyMessage) + if message.Codec != CompressionNone { + t.Error("Decoding produced compression codec where there was none.") + } + if message.Key != nil { + t.Error("Decoding produced key where there was none.") + } + if message.Value != nil { + t.Error("Decoding produced value where there was none.") + } + if message.Set != nil { + t.Error("Decoding produced set where there was none.") + } + + testDecodable(t, "empty gzip", &message, emptyGzipMessage) + if message.Codec != CompressionGZIP { + t.Error("Decoding produced incorrect compression codec (was gzip).") + } + if message.Key != nil { + t.Error("Decoding produced key where there was none.") + } + if message.Value == nil || len(message.Value) != 0 { + t.Error("Decoding produced nil or content-ful value where there was an empty array.") + } +} + +func TestMessageDecodingBulkSnappy(t *testing.T) { + message := Message{} + testDecodable(t, "bulk snappy", &message, emptyBulkSnappyMessage) + if message.Codec != CompressionSnappy { + t.Errorf("Decoding produced codec %d, but expected %d.", message.Codec, CompressionSnappy) + } + if message.Key != nil { + t.Errorf("Decoding produced key %+v, but none was expected.", message.Key) + } + if message.Set == nil { + t.Error("Decoding produced no set, but one was expected.") + } else if len(message.Set.Messages) != 2 { + t.Errorf("Decoding produced a set with %d messages, but 2 were expected.", len(message.Set.Messages)) + } +} + +func TestMessageDecodingBulkGzip(t *testing.T) { + message := Message{} + testDecodable(t, "bulk gzip", &message, emptyBulkGzipMessage) + if message.Codec != CompressionGZIP { + t.Errorf("Decoding produced codec %d, but expected %d.", message.Codec, CompressionGZIP) + } + if message.Key != nil { + t.Errorf("Decoding produced key %+v, but none was expected.", message.Key) + } + if message.Set == nil { + t.Error("Decoding produced no set, but one was expected.") + } else if len(message.Set.Messages) != 2 { + t.Errorf("Decoding produced a set with %d messages, but 2 were expected.", len(message.Set.Messages)) + } +} + +func TestMessageDecodingBulkLZ4(t *testing.T) { + message := Message{} + testDecodable(t, "bulk lz4", &message, emptyBulkLZ4Message) + if message.Codec != CompressionLZ4 { + t.Errorf("Decoding produced codec %d, but expected %d.", message.Codec, CompressionLZ4) + } + if message.Key != nil { + t.Errorf("Decoding produced key %+v, but none was expected.", message.Key) + } + if message.Set == nil { + t.Error("Decoding produced no set, but one was expected.") + } else if len(message.Set.Messages) != 2 { + t.Errorf("Decoding produced a set with %d messages, but 2 were expected.", len(message.Set.Messages)) + } +} diff --git a/vendor/github.com/Shopify/sarama/metadata_request.go b/vendor/github.com/Shopify/sarama/metadata_request.go new file mode 100644 index 00000000..9a26b55f --- /dev/null +++ b/vendor/github.com/Shopify/sarama/metadata_request.go @@ -0,0 +1,52 @@ +package sarama + +type MetadataRequest struct { + Topics []string +} + +func (r *MetadataRequest) encode(pe packetEncoder) error { + err := pe.putArrayLength(len(r.Topics)) + if err != nil { + return err + } + + for i := range r.Topics { + err = pe.putString(r.Topics[i]) + if err != nil { + return err + } + } + return nil +} + +func (r *MetadataRequest) decode(pd packetDecoder, version int16) error { + topicCount, err := pd.getArrayLength() + if err != nil { + return err + } + if topicCount == 0 { + return nil + } + + r.Topics = make([]string, topicCount) + for i := range r.Topics { + topic, err := pd.getString() + if err != nil { + return err + } + r.Topics[i] = topic + } + return nil +} + +func (r *MetadataRequest) key() int16 { + return 3 +} + +func (r *MetadataRequest) version() int16 { + return 0 +} + +func (r *MetadataRequest) requiredVersion() KafkaVersion { + return minVersion +} diff --git a/vendor/github.com/Shopify/sarama/metadata_request_test.go b/vendor/github.com/Shopify/sarama/metadata_request_test.go new file mode 100644 index 00000000..44f3146e --- /dev/null +++ b/vendor/github.com/Shopify/sarama/metadata_request_test.go @@ -0,0 +1,29 @@ +package sarama + +import "testing" + +var ( + metadataRequestNoTopics = []byte{ + 0x00, 0x00, 0x00, 0x00} + + metadataRequestOneTopic = []byte{ + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x06, 't', 'o', 'p', 'i', 'c', '1'} + + metadataRequestThreeTopics = []byte{ + 0x00, 0x00, 0x00, 0x03, + 0x00, 0x03, 'f', 'o', 'o', + 0x00, 0x03, 'b', 'a', 'r', + 0x00, 0x03, 'b', 'a', 'z'} +) + +func TestMetadataRequest(t *testing.T) { + request := new(MetadataRequest) + testRequest(t, "no topics", request, metadataRequestNoTopics) + + request.Topics = []string{"topic1"} + testRequest(t, "one topic", request, metadataRequestOneTopic) + + request.Topics = []string{"foo", "bar", "baz"} + testRequest(t, "three topics", request, metadataRequestThreeTopics) +} diff --git a/vendor/github.com/Shopify/sarama/metadata_response.go b/vendor/github.com/Shopify/sarama/metadata_response.go new file mode 100644 index 00000000..f9d6a427 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/metadata_response.go @@ -0,0 +1,239 @@ +package sarama + +type PartitionMetadata struct { + Err KError + ID int32 + Leader int32 + Replicas []int32 + Isr []int32 +} + +func (pm *PartitionMetadata) decode(pd packetDecoder) (err error) { + tmp, err := pd.getInt16() + if err != nil { + return err + } + pm.Err = KError(tmp) + + pm.ID, err = pd.getInt32() + if err != nil { + return err + } + + pm.Leader, err = pd.getInt32() + if err != nil { + return err + } + + pm.Replicas, err = pd.getInt32Array() + if err != nil { + return err + } + + pm.Isr, err = pd.getInt32Array() + if err != nil { + return err + } + + return nil +} + +func (pm *PartitionMetadata) encode(pe packetEncoder) (err error) { + pe.putInt16(int16(pm.Err)) + pe.putInt32(pm.ID) + pe.putInt32(pm.Leader) + + err = pe.putInt32Array(pm.Replicas) + if err != nil { + return err + } + + err = pe.putInt32Array(pm.Isr) + if err != nil { + return err + } + + return nil +} + +type TopicMetadata struct { + Err KError + Name string + Partitions []*PartitionMetadata +} + +func (tm *TopicMetadata) decode(pd packetDecoder) (err error) { + tmp, err := pd.getInt16() + if err != nil { + return err + } + tm.Err = KError(tmp) + + tm.Name, err = pd.getString() + if err != nil { + return err + } + + n, err := pd.getArrayLength() + if err != nil { + return err + } + tm.Partitions = make([]*PartitionMetadata, n) + for i := 0; i < n; i++ { + tm.Partitions[i] = new(PartitionMetadata) + err = tm.Partitions[i].decode(pd) + if err != nil { + return err + } + } + + return nil +} + +func (tm *TopicMetadata) encode(pe packetEncoder) (err error) { + pe.putInt16(int16(tm.Err)) + + err = pe.putString(tm.Name) + if err != nil { + return err + } + + err = pe.putArrayLength(len(tm.Partitions)) + if err != nil { + return err + } + + for _, pm := range tm.Partitions { + err = pm.encode(pe) + if err != nil { + return err + } + } + + return nil +} + +type MetadataResponse struct { + Brokers []*Broker + Topics []*TopicMetadata +} + +func (r *MetadataResponse) decode(pd packetDecoder, version int16) (err error) { + n, err := pd.getArrayLength() + if err != nil { + return err + } + + r.Brokers = make([]*Broker, n) + for i := 0; i < n; i++ { + r.Brokers[i] = new(Broker) + err = r.Brokers[i].decode(pd) + if err != nil { + return err + } + } + + n, err = pd.getArrayLength() + if err != nil { + return err + } + + r.Topics = make([]*TopicMetadata, n) + for i := 0; i < n; i++ { + r.Topics[i] = new(TopicMetadata) + err = r.Topics[i].decode(pd) + if err != nil { + return err + } + } + + return nil +} + +func (r *MetadataResponse) encode(pe packetEncoder) error { + err := pe.putArrayLength(len(r.Brokers)) + if err != nil { + return err + } + for _, broker := range r.Brokers { + err = broker.encode(pe) + if err != nil { + return err + } + } + + err = pe.putArrayLength(len(r.Topics)) + if err != nil { + return err + } + for _, tm := range r.Topics { + err = tm.encode(pe) + if err != nil { + return err + } + } + + return nil +} + +func (r *MetadataResponse) key() int16 { + return 3 +} + +func (r *MetadataResponse) version() int16 { + return 0 +} + +func (r *MetadataResponse) requiredVersion() KafkaVersion { + return minVersion +} + +// testing API + +func (r *MetadataResponse) AddBroker(addr string, id int32) { + r.Brokers = append(r.Brokers, &Broker{id: id, addr: addr}) +} + +func (r *MetadataResponse) AddTopic(topic string, err KError) *TopicMetadata { + var tmatch *TopicMetadata + + for _, tm := range r.Topics { + if tm.Name == topic { + tmatch = tm + goto foundTopic + } + } + + tmatch = new(TopicMetadata) + tmatch.Name = topic + r.Topics = append(r.Topics, tmatch) + +foundTopic: + + tmatch.Err = err + return tmatch +} + +func (r *MetadataResponse) AddTopicPartition(topic string, partition, brokerID int32, replicas, isr []int32, err KError) { + tmatch := r.AddTopic(topic, ErrNoError) + var pmatch *PartitionMetadata + + for _, pm := range tmatch.Partitions { + if pm.ID == partition { + pmatch = pm + goto foundPartition + } + } + + pmatch = new(PartitionMetadata) + pmatch.ID = partition + tmatch.Partitions = append(tmatch.Partitions, pmatch) + +foundPartition: + + pmatch.Leader = brokerID + pmatch.Replicas = replicas + pmatch.Isr = isr + pmatch.Err = err + +} diff --git a/vendor/github.com/Shopify/sarama/metadata_response_test.go b/vendor/github.com/Shopify/sarama/metadata_response_test.go new file mode 100644 index 00000000..ea62a4f1 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/metadata_response_test.go @@ -0,0 +1,139 @@ +package sarama + +import "testing" + +var ( + emptyMetadataResponse = []byte{ + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00} + + brokersNoTopicsMetadataResponse = []byte{ + 0x00, 0x00, 0x00, 0x02, + + 0x00, 0x00, 0xab, 0xff, + 0x00, 0x09, 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't', + 0x00, 0x00, 0x00, 0x33, + + 0x00, 0x01, 0x02, 0x03, + 0x00, 0x0a, 'g', 'o', 'o', 'g', 'l', 'e', '.', 'c', 'o', 'm', + 0x00, 0x00, 0x01, 0x11, + + 0x00, 0x00, 0x00, 0x00} + + topicsNoBrokersMetadataResponse = []byte{ + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x02, + + 0x00, 0x00, + 0x00, 0x03, 'f', 'o', 'o', + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x04, + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x07, + 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x03, + 0x00, 0x00, 0x00, 0x00, + + 0x00, 0x00, + 0x00, 0x03, 'b', 'a', 'r', + 0x00, 0x00, 0x00, 0x00} +) + +func TestEmptyMetadataResponse(t *testing.T) { + response := MetadataResponse{} + + testVersionDecodable(t, "empty", &response, emptyMetadataResponse, 0) + if len(response.Brokers) != 0 { + t.Error("Decoding produced", len(response.Brokers), "brokers where there were none!") + } + if len(response.Topics) != 0 { + t.Error("Decoding produced", len(response.Topics), "topics where there were none!") + } +} + +func TestMetadataResponseWithBrokers(t *testing.T) { + response := MetadataResponse{} + + testVersionDecodable(t, "brokers, no topics", &response, brokersNoTopicsMetadataResponse, 0) + if len(response.Brokers) != 2 { + t.Fatal("Decoding produced", len(response.Brokers), "brokers where there were two!") + } + + if response.Brokers[0].id != 0xabff { + t.Error("Decoding produced invalid broker 0 id.") + } + if response.Brokers[0].addr != "localhost:51" { + t.Error("Decoding produced invalid broker 0 address.") + } + if response.Brokers[1].id != 0x010203 { + t.Error("Decoding produced invalid broker 1 id.") + } + if response.Brokers[1].addr != "google.com:273" { + t.Error("Decoding produced invalid broker 1 address.") + } + + if len(response.Topics) != 0 { + t.Error("Decoding produced", len(response.Topics), "topics where there were none!") + } +} + +func TestMetadataResponseWithTopics(t *testing.T) { + response := MetadataResponse{} + + testVersionDecodable(t, "topics, no brokers", &response, topicsNoBrokersMetadataResponse, 0) + if len(response.Brokers) != 0 { + t.Error("Decoding produced", len(response.Brokers), "brokers where there were none!") + } + + if len(response.Topics) != 2 { + t.Fatal("Decoding produced", len(response.Topics), "topics where there were two!") + } + + if response.Topics[0].Err != ErrNoError { + t.Error("Decoding produced invalid topic 0 error.") + } + + if response.Topics[0].Name != "foo" { + t.Error("Decoding produced invalid topic 0 name.") + } + + if len(response.Topics[0].Partitions) != 1 { + t.Fatal("Decoding produced invalid partition count for topic 0.") + } + + if response.Topics[0].Partitions[0].Err != ErrInvalidMessageSize { + t.Error("Decoding produced invalid topic 0 partition 0 error.") + } + + if response.Topics[0].Partitions[0].ID != 0x01 { + t.Error("Decoding produced invalid topic 0 partition 0 id.") + } + + if response.Topics[0].Partitions[0].Leader != 0x07 { + t.Error("Decoding produced invalid topic 0 partition 0 leader.") + } + + if len(response.Topics[0].Partitions[0].Replicas) != 3 { + t.Fatal("Decoding produced invalid topic 0 partition 0 replicas.") + } + for i := 0; i < 3; i++ { + if response.Topics[0].Partitions[0].Replicas[i] != int32(i+1) { + t.Error("Decoding produced invalid topic 0 partition 0 replica", i) + } + } + + if len(response.Topics[0].Partitions[0].Isr) != 0 { + t.Error("Decoding produced invalid topic 0 partition 0 isr length.") + } + + if response.Topics[1].Err != ErrNoError { + t.Error("Decoding produced invalid topic 1 error.") + } + + if response.Topics[1].Name != "bar" { + t.Error("Decoding produced invalid topic 0 name.") + } + + if len(response.Topics[1].Partitions) != 0 { + t.Error("Decoding produced invalid partition count for topic 1.") + } +} diff --git a/vendor/github.com/Shopify/sarama/metrics.go b/vendor/github.com/Shopify/sarama/metrics.go new file mode 100644 index 00000000..4869708e --- /dev/null +++ b/vendor/github.com/Shopify/sarama/metrics.go @@ -0,0 +1,51 @@ +package sarama + +import ( + "fmt" + "strings" + + "github.com/rcrowley/go-metrics" +) + +// Use exponentially decaying reservoir for sampling histograms with the same defaults as the Java library: +// 1028 elements, which offers a 99.9% confidence level with a 5% margin of error assuming a normal distribution, +// and an alpha factor of 0.015, which heavily biases the reservoir to the past 5 minutes of measurements. +// See https://github.com/dropwizard/metrics/blob/v3.1.0/metrics-core/src/main/java/com/codahale/metrics/ExponentiallyDecayingReservoir.java#L38 +const ( + metricsReservoirSize = 1028 + metricsAlphaFactor = 0.015 +) + +func getOrRegisterHistogram(name string, r metrics.Registry) metrics.Histogram { + return r.GetOrRegister(name, func() metrics.Histogram { + return metrics.NewHistogram(metrics.NewExpDecaySample(metricsReservoirSize, metricsAlphaFactor)) + }).(metrics.Histogram) +} + +func getMetricNameForBroker(name string, broker *Broker) string { + // Use broker id like the Java client as it does not contain '.' or ':' characters that + // can be interpreted as special character by monitoring tool (e.g. Graphite) + return fmt.Sprintf(name+"-for-broker-%d", broker.ID()) +} + +func getOrRegisterBrokerMeter(name string, broker *Broker, r metrics.Registry) metrics.Meter { + return metrics.GetOrRegisterMeter(getMetricNameForBroker(name, broker), r) +} + +func getOrRegisterBrokerHistogram(name string, broker *Broker, r metrics.Registry) metrics.Histogram { + return getOrRegisterHistogram(getMetricNameForBroker(name, broker), r) +} + +func getMetricNameForTopic(name string, topic string) string { + // Convert dot to _ since reporters like Graphite typically use dot to represent hierarchy + // cf. KAFKA-1902 and KAFKA-2337 + return fmt.Sprintf(name+"-for-topic-%s", strings.Replace(topic, ".", "_", -1)) +} + +func getOrRegisterTopicMeter(name string, topic string, r metrics.Registry) metrics.Meter { + return metrics.GetOrRegisterMeter(getMetricNameForTopic(name, topic), r) +} + +func getOrRegisterTopicHistogram(name string, topic string, r metrics.Registry) metrics.Histogram { + return getOrRegisterHistogram(getMetricNameForTopic(name, topic), r) +} diff --git a/vendor/github.com/Shopify/sarama/metrics_test.go b/vendor/github.com/Shopify/sarama/metrics_test.go new file mode 100644 index 00000000..789c0ff3 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/metrics_test.go @@ -0,0 +1,172 @@ +package sarama + +import ( + "testing" + + "github.com/rcrowley/go-metrics" +) + +func TestGetOrRegisterHistogram(t *testing.T) { + metricRegistry := metrics.NewRegistry() + histogram := getOrRegisterHistogram("name", metricRegistry) + + if histogram == nil { + t.Error("Unexpected nil histogram") + } + + // Fetch the metric + foundHistogram := metricRegistry.Get("name") + + if foundHistogram != histogram { + t.Error("Unexpected different histogram", foundHistogram, histogram) + } + + // Try to register the metric again + sameHistogram := getOrRegisterHistogram("name", metricRegistry) + + if sameHistogram != histogram { + t.Error("Unexpected different histogram", sameHistogram, histogram) + } +} + +func TestGetMetricNameForBroker(t *testing.T) { + metricName := getMetricNameForBroker("name", &Broker{id: 1}) + + if metricName != "name-for-broker-1" { + t.Error("Unexpected metric name", metricName) + } +} + +// Common type and functions for metric validation +type metricValidator struct { + name string + validator func(*testing.T, interface{}) +} + +type metricValidators []*metricValidator + +func newMetricValidators() metricValidators { + return make([]*metricValidator, 0, 32) +} + +func (m *metricValidators) register(validator *metricValidator) { + *m = append(*m, validator) +} + +func (m *metricValidators) registerForBroker(broker *Broker, validator *metricValidator) { + m.register(&metricValidator{getMetricNameForBroker(validator.name, broker), validator.validator}) +} + +func (m *metricValidators) registerForGlobalAndTopic(topic string, validator *metricValidator) { + m.register(&metricValidator{validator.name, validator.validator}) + m.register(&metricValidator{getMetricNameForTopic(validator.name, topic), validator.validator}) +} + +func (m *metricValidators) registerForAllBrokers(broker *Broker, validator *metricValidator) { + m.register(validator) + m.registerForBroker(broker, validator) +} + +func (m metricValidators) run(t *testing.T, r metrics.Registry) { + for _, metricValidator := range m { + metric := r.Get(metricValidator.name) + if metric == nil { + t.Error("No metric named", metricValidator.name) + } else { + metricValidator.validator(t, metric) + } + } +} + +func meterValidator(name string, extraValidator func(*testing.T, metrics.Meter)) *metricValidator { + return &metricValidator{ + name: name, + validator: func(t *testing.T, metric interface{}) { + if meter, ok := metric.(metrics.Meter); !ok { + t.Errorf("Expected meter metric for '%s', got %T", name, metric) + } else { + extraValidator(t, meter) + } + }, + } +} + +func countMeterValidator(name string, expectedCount int) *metricValidator { + return meterValidator(name, func(t *testing.T, meter metrics.Meter) { + count := meter.Count() + if count != int64(expectedCount) { + t.Errorf("Expected meter metric '%s' count = %d, got %d", name, expectedCount, count) + } + }) +} + +func minCountMeterValidator(name string, minCount int) *metricValidator { + return meterValidator(name, func(t *testing.T, meter metrics.Meter) { + count := meter.Count() + if count < int64(minCount) { + t.Errorf("Expected meter metric '%s' count >= %d, got %d", name, minCount, count) + } + }) +} + +func histogramValidator(name string, extraValidator func(*testing.T, metrics.Histogram)) *metricValidator { + return &metricValidator{ + name: name, + validator: func(t *testing.T, metric interface{}) { + if histogram, ok := metric.(metrics.Histogram); !ok { + t.Errorf("Expected histogram metric for '%s', got %T", name, metric) + } else { + extraValidator(t, histogram) + } + }, + } +} + +func countHistogramValidator(name string, expectedCount int) *metricValidator { + return histogramValidator(name, func(t *testing.T, histogram metrics.Histogram) { + count := histogram.Count() + if count != int64(expectedCount) { + t.Errorf("Expected histogram metric '%s' count = %d, got %d", name, expectedCount, count) + } + }) +} + +func minCountHistogramValidator(name string, minCount int) *metricValidator { + return histogramValidator(name, func(t *testing.T, histogram metrics.Histogram) { + count := histogram.Count() + if count < int64(minCount) { + t.Errorf("Expected histogram metric '%s' count >= %d, got %d", name, minCount, count) + } + }) +} + +func minMaxHistogramValidator(name string, expectedMin int, expectedMax int) *metricValidator { + return histogramValidator(name, func(t *testing.T, histogram metrics.Histogram) { + min := int(histogram.Min()) + if min != expectedMin { + t.Errorf("Expected histogram metric '%s' min = %d, got %d", name, expectedMin, min) + } + max := int(histogram.Max()) + if max != expectedMax { + t.Errorf("Expected histogram metric '%s' max = %d, got %d", name, expectedMax, max) + } + }) +} + +func minValHistogramValidator(name string, minMin int) *metricValidator { + return histogramValidator(name, func(t *testing.T, histogram metrics.Histogram) { + min := int(histogram.Min()) + if min < minMin { + t.Errorf("Expected histogram metric '%s' min >= %d, got %d", name, minMin, min) + } + }) +} + +func maxValHistogramValidator(name string, maxMax int) *metricValidator { + return histogramValidator(name, func(t *testing.T, histogram metrics.Histogram) { + max := int(histogram.Max()) + if max > maxMax { + t.Errorf("Expected histogram metric '%s' max <= %d, got %d", name, maxMax, max) + } + }) +} diff --git a/vendor/github.com/Shopify/sarama/mockbroker.go b/vendor/github.com/Shopify/sarama/mockbroker.go new file mode 100644 index 00000000..0734d34f --- /dev/null +++ b/vendor/github.com/Shopify/sarama/mockbroker.go @@ -0,0 +1,324 @@ +package sarama + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "net" + "reflect" + "strconv" + "sync" + "time" + + "github.com/davecgh/go-spew/spew" +) + +const ( + expectationTimeout = 500 * time.Millisecond +) + +type requestHandlerFunc func(req *request) (res encoder) + +// RequestNotifierFunc is invoked when a mock broker processes a request successfully +// and will provides the number of bytes read and written. +type RequestNotifierFunc func(bytesRead, bytesWritten int) + +// MockBroker is a mock Kafka broker that is used in unit tests. It is exposed +// to facilitate testing of higher level or specialized consumers and producers +// built on top of Sarama. Note that it does not 'mimic' the Kafka API protocol, +// but rather provides a facility to do that. It takes care of the TCP +// transport, request unmarshaling, response marshaling, and makes it the test +// writer responsibility to program correct according to the Kafka API protocol +// MockBroker behaviour. +// +// MockBroker is implemented as a TCP server listening on a kernel-selected +// localhost port that can accept many connections. It reads Kafka requests +// from that connection and returns responses programmed by the SetHandlerByMap +// function. If a MockBroker receives a request that it has no programmed +// response for, then it returns nothing and the request times out. +// +// A set of MockRequest builders to define mappings used by MockBroker is +// provided by Sarama. But users can develop MockRequests of their own and use +// them along with or instead of the standard ones. +// +// When running tests with MockBroker it is strongly recommended to specify +// a timeout to `go test` so that if the broker hangs waiting for a response, +// the test panics. +// +// It is not necessary to prefix message length or correlation ID to your +// response bytes, the server does that automatically as a convenience. +type MockBroker struct { + brokerID int32 + port int32 + closing chan none + stopper chan none + expectations chan encoder + listener net.Listener + t TestReporter + latency time.Duration + handler requestHandlerFunc + notifier RequestNotifierFunc + history []RequestResponse + lock sync.Mutex +} + +// RequestResponse represents a Request/Response pair processed by MockBroker. +type RequestResponse struct { + Request protocolBody + Response encoder +} + +// SetLatency makes broker pause for the specified period every time before +// replying. +func (b *MockBroker) SetLatency(latency time.Duration) { + b.latency = latency +} + +// SetHandlerByMap defines mapping of Request types to MockResponses. When a +// request is received by the broker, it looks up the request type in the map +// and uses the found MockResponse instance to generate an appropriate reply. +// If the request type is not found in the map then nothing is sent. +func (b *MockBroker) SetHandlerByMap(handlerMap map[string]MockResponse) { + b.setHandler(func(req *request) (res encoder) { + reqTypeName := reflect.TypeOf(req.body).Elem().Name() + mockResponse := handlerMap[reqTypeName] + if mockResponse == nil { + return nil + } + return mockResponse.For(req.body) + }) +} + +// SetNotifier set a function that will get invoked whenever a request has been +// processed successfully and will provide the number of bytes read and written +func (b *MockBroker) SetNotifier(notifier RequestNotifierFunc) { + b.lock.Lock() + b.notifier = notifier + b.lock.Unlock() +} + +// BrokerID returns broker ID assigned to the broker. +func (b *MockBroker) BrokerID() int32 { + return b.brokerID +} + +// History returns a slice of RequestResponse pairs in the order they were +// processed by the broker. Note that in case of multiple connections to the +// broker the order expected by a test can be different from the order recorded +// in the history, unless some synchronization is implemented in the test. +func (b *MockBroker) History() []RequestResponse { + b.lock.Lock() + history := make([]RequestResponse, len(b.history)) + copy(history, b.history) + b.lock.Unlock() + return history +} + +// Port returns the TCP port number the broker is listening for requests on. +func (b *MockBroker) Port() int32 { + return b.port +} + +// Addr returns the broker connection string in the form "
:". +func (b *MockBroker) Addr() string { + return b.listener.Addr().String() +} + +// Close terminates the broker blocking until it stops internal goroutines and +// releases all resources. +func (b *MockBroker) Close() { + close(b.expectations) + if len(b.expectations) > 0 { + buf := bytes.NewBufferString(fmt.Sprintf("mockbroker/%d: not all expectations were satisfied! Still waiting on:\n", b.BrokerID())) + for e := range b.expectations { + _, _ = buf.WriteString(spew.Sdump(e)) + } + b.t.Error(buf.String()) + } + close(b.closing) + <-b.stopper +} + +// setHandler sets the specified function as the request handler. Whenever +// a mock broker reads a request from the wire it passes the request to the +// function and sends back whatever the handler function returns. +func (b *MockBroker) setHandler(handler requestHandlerFunc) { + b.lock.Lock() + b.handler = handler + b.lock.Unlock() +} + +func (b *MockBroker) serverLoop() { + defer close(b.stopper) + var err error + var conn net.Conn + + go func() { + <-b.closing + err := b.listener.Close() + if err != nil { + b.t.Error(err) + } + }() + + wg := &sync.WaitGroup{} + i := 0 + for conn, err = b.listener.Accept(); err == nil; conn, err = b.listener.Accept() { + wg.Add(1) + go b.handleRequests(conn, i, wg) + i++ + } + wg.Wait() + Logger.Printf("*** mockbroker/%d: listener closed, err=%v", b.BrokerID(), err) +} + +func (b *MockBroker) handleRequests(conn net.Conn, idx int, wg *sync.WaitGroup) { + defer wg.Done() + defer func() { + _ = conn.Close() + }() + Logger.Printf("*** mockbroker/%d/%d: connection opened", b.BrokerID(), idx) + var err error + + abort := make(chan none) + defer close(abort) + go func() { + select { + case <-b.closing: + _ = conn.Close() + case <-abort: + } + }() + + resHeader := make([]byte, 8) + for { + req, bytesRead, err := decodeRequest(conn) + if err != nil { + Logger.Printf("*** mockbroker/%d/%d: invalid request: err=%+v, %+v", b.brokerID, idx, err, spew.Sdump(req)) + b.serverError(err) + break + } + + if b.latency > 0 { + time.Sleep(b.latency) + } + + b.lock.Lock() + res := b.handler(req) + b.history = append(b.history, RequestResponse{req.body, res}) + b.lock.Unlock() + + if res == nil { + Logger.Printf("*** mockbroker/%d/%d: ignored %v", b.brokerID, idx, spew.Sdump(req)) + continue + } + Logger.Printf("*** mockbroker/%d/%d: served %v -> %v", b.brokerID, idx, req, res) + + encodedRes, err := encode(res, nil) + if err != nil { + b.serverError(err) + break + } + if len(encodedRes) == 0 { + b.lock.Lock() + if b.notifier != nil { + b.notifier(bytesRead, 0) + } + b.lock.Unlock() + continue + } + + binary.BigEndian.PutUint32(resHeader, uint32(len(encodedRes)+4)) + binary.BigEndian.PutUint32(resHeader[4:], uint32(req.correlationID)) + if _, err = conn.Write(resHeader); err != nil { + b.serverError(err) + break + } + if _, err = conn.Write(encodedRes); err != nil { + b.serverError(err) + break + } + + b.lock.Lock() + if b.notifier != nil { + b.notifier(bytesRead, len(resHeader)+len(encodedRes)) + } + b.lock.Unlock() + } + Logger.Printf("*** mockbroker/%d/%d: connection closed, err=%v", b.BrokerID(), idx, err) +} + +func (b *MockBroker) defaultRequestHandler(req *request) (res encoder) { + select { + case res, ok := <-b.expectations: + if !ok { + return nil + } + return res + case <-time.After(expectationTimeout): + return nil + } +} + +func (b *MockBroker) serverError(err error) { + isConnectionClosedError := false + if _, ok := err.(*net.OpError); ok { + isConnectionClosedError = true + } else if err == io.EOF { + isConnectionClosedError = true + } else if err.Error() == "use of closed network connection" { + isConnectionClosedError = true + } + + if isConnectionClosedError { + return + } + + b.t.Errorf(err.Error()) +} + +// NewMockBroker launches a fake Kafka broker. It takes a TestReporter as provided by the +// test framework and a channel of responses to use. If an error occurs it is +// simply logged to the TestReporter and the broker exits. +func NewMockBroker(t TestReporter, brokerID int32) *MockBroker { + return NewMockBrokerAddr(t, brokerID, "localhost:0") +} + +// NewMockBrokerAddr behaves like newMockBroker but listens on the address you give +// it rather than just some ephemeral port. +func NewMockBrokerAddr(t TestReporter, brokerID int32, addr string) *MockBroker { + var err error + + broker := &MockBroker{ + closing: make(chan none), + stopper: make(chan none), + t: t, + brokerID: brokerID, + expectations: make(chan encoder, 512), + } + broker.handler = broker.defaultRequestHandler + + broker.listener, err = net.Listen("tcp", addr) + if err != nil { + t.Fatal(err) + } + Logger.Printf("*** mockbroker/%d listening on %s\n", brokerID, broker.listener.Addr().String()) + _, portStr, err := net.SplitHostPort(broker.listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + tmp, err := strconv.ParseInt(portStr, 10, 32) + if err != nil { + t.Fatal(err) + } + broker.port = int32(tmp) + + go broker.serverLoop() + + return broker +} + +func (b *MockBroker) Returns(e encoder) { + b.expectations <- e +} diff --git a/vendor/github.com/Shopify/sarama/mockresponses.go b/vendor/github.com/Shopify/sarama/mockresponses.go new file mode 100644 index 00000000..a2031420 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/mockresponses.go @@ -0,0 +1,455 @@ +package sarama + +import ( + "fmt" +) + +// TestReporter has methods matching go's testing.T to avoid importing +// `testing` in the main part of the library. +type TestReporter interface { + Error(...interface{}) + Errorf(string, ...interface{}) + Fatal(...interface{}) + Fatalf(string, ...interface{}) +} + +// MockResponse is a response builder interface it defines one method that +// allows generating a response based on a request body. MockResponses are used +// to program behavior of MockBroker in tests. +type MockResponse interface { + For(reqBody versionedDecoder) (res encoder) +} + +// MockWrapper is a mock response builder that returns a particular concrete +// response regardless of the actual request passed to the `For` method. +type MockWrapper struct { + res encoder +} + +func (mw *MockWrapper) For(reqBody versionedDecoder) (res encoder) { + return mw.res +} + +func NewMockWrapper(res encoder) *MockWrapper { + return &MockWrapper{res: res} +} + +// MockSequence is a mock response builder that is created from a sequence of +// concrete responses. Every time when a `MockBroker` calls its `For` method +// the next response from the sequence is returned. When the end of the +// sequence is reached the last element from the sequence is returned. +type MockSequence struct { + responses []MockResponse +} + +func NewMockSequence(responses ...interface{}) *MockSequence { + ms := &MockSequence{} + ms.responses = make([]MockResponse, len(responses)) + for i, res := range responses { + switch res := res.(type) { + case MockResponse: + ms.responses[i] = res + case encoder: + ms.responses[i] = NewMockWrapper(res) + default: + panic(fmt.Sprintf("Unexpected response type: %T", res)) + } + } + return ms +} + +func (mc *MockSequence) For(reqBody versionedDecoder) (res encoder) { + res = mc.responses[0].For(reqBody) + if len(mc.responses) > 1 { + mc.responses = mc.responses[1:] + } + return res +} + +// MockMetadataResponse is a `MetadataResponse` builder. +type MockMetadataResponse struct { + leaders map[string]map[int32]int32 + brokers map[string]int32 + t TestReporter +} + +func NewMockMetadataResponse(t TestReporter) *MockMetadataResponse { + return &MockMetadataResponse{ + leaders: make(map[string]map[int32]int32), + brokers: make(map[string]int32), + t: t, + } +} + +func (mmr *MockMetadataResponse) SetLeader(topic string, partition, brokerID int32) *MockMetadataResponse { + partitions := mmr.leaders[topic] + if partitions == nil { + partitions = make(map[int32]int32) + mmr.leaders[topic] = partitions + } + partitions[partition] = brokerID + return mmr +} + +func (mmr *MockMetadataResponse) SetBroker(addr string, brokerID int32) *MockMetadataResponse { + mmr.brokers[addr] = brokerID + return mmr +} + +func (mmr *MockMetadataResponse) For(reqBody versionedDecoder) encoder { + metadataRequest := reqBody.(*MetadataRequest) + metadataResponse := &MetadataResponse{} + for addr, brokerID := range mmr.brokers { + metadataResponse.AddBroker(addr, brokerID) + } + if len(metadataRequest.Topics) == 0 { + for topic, partitions := range mmr.leaders { + for partition, brokerID := range partitions { + metadataResponse.AddTopicPartition(topic, partition, brokerID, nil, nil, ErrNoError) + } + } + return metadataResponse + } + for _, topic := range metadataRequest.Topics { + for partition, brokerID := range mmr.leaders[topic] { + metadataResponse.AddTopicPartition(topic, partition, brokerID, nil, nil, ErrNoError) + } + } + return metadataResponse +} + +// MockOffsetResponse is an `OffsetResponse` builder. +type MockOffsetResponse struct { + offsets map[string]map[int32]map[int64]int64 + t TestReporter +} + +func NewMockOffsetResponse(t TestReporter) *MockOffsetResponse { + return &MockOffsetResponse{ + offsets: make(map[string]map[int32]map[int64]int64), + t: t, + } +} + +func (mor *MockOffsetResponse) SetOffset(topic string, partition int32, time, offset int64) *MockOffsetResponse { + partitions := mor.offsets[topic] + if partitions == nil { + partitions = make(map[int32]map[int64]int64) + mor.offsets[topic] = partitions + } + times := partitions[partition] + if times == nil { + times = make(map[int64]int64) + partitions[partition] = times + } + times[time] = offset + return mor +} + +func (mor *MockOffsetResponse) For(reqBody versionedDecoder) encoder { + offsetRequest := reqBody.(*OffsetRequest) + offsetResponse := &OffsetResponse{} + for topic, partitions := range offsetRequest.blocks { + for partition, block := range partitions { + offset := mor.getOffset(topic, partition, block.time) + offsetResponse.AddTopicPartition(topic, partition, offset) + } + } + return offsetResponse +} + +func (mor *MockOffsetResponse) getOffset(topic string, partition int32, time int64) int64 { + partitions := mor.offsets[topic] + if partitions == nil { + mor.t.Errorf("missing topic: %s", topic) + } + times := partitions[partition] + if times == nil { + mor.t.Errorf("missing partition: %d", partition) + } + offset, ok := times[time] + if !ok { + mor.t.Errorf("missing time: %d", time) + } + return offset +} + +// MockFetchResponse is a `FetchResponse` builder. +type MockFetchResponse struct { + messages map[string]map[int32]map[int64]Encoder + highWaterMarks map[string]map[int32]int64 + t TestReporter + batchSize int +} + +func NewMockFetchResponse(t TestReporter, batchSize int) *MockFetchResponse { + return &MockFetchResponse{ + messages: make(map[string]map[int32]map[int64]Encoder), + highWaterMarks: make(map[string]map[int32]int64), + t: t, + batchSize: batchSize, + } +} + +func (mfr *MockFetchResponse) SetMessage(topic string, partition int32, offset int64, msg Encoder) *MockFetchResponse { + partitions := mfr.messages[topic] + if partitions == nil { + partitions = make(map[int32]map[int64]Encoder) + mfr.messages[topic] = partitions + } + messages := partitions[partition] + if messages == nil { + messages = make(map[int64]Encoder) + partitions[partition] = messages + } + messages[offset] = msg + return mfr +} + +func (mfr *MockFetchResponse) SetHighWaterMark(topic string, partition int32, offset int64) *MockFetchResponse { + partitions := mfr.highWaterMarks[topic] + if partitions == nil { + partitions = make(map[int32]int64) + mfr.highWaterMarks[topic] = partitions + } + partitions[partition] = offset + return mfr +} + +func (mfr *MockFetchResponse) For(reqBody versionedDecoder) encoder { + fetchRequest := reqBody.(*FetchRequest) + res := &FetchResponse{} + for topic, partitions := range fetchRequest.blocks { + for partition, block := range partitions { + initialOffset := block.fetchOffset + offset := initialOffset + maxOffset := initialOffset + int64(mfr.getMessageCount(topic, partition)) + for i := 0; i < mfr.batchSize && offset < maxOffset; { + msg := mfr.getMessage(topic, partition, offset) + if msg != nil { + res.AddMessage(topic, partition, nil, msg, offset) + i++ + } + offset++ + } + fb := res.GetBlock(topic, partition) + if fb == nil { + res.AddError(topic, partition, ErrNoError) + fb = res.GetBlock(topic, partition) + } + fb.HighWaterMarkOffset = mfr.getHighWaterMark(topic, partition) + } + } + return res +} + +func (mfr *MockFetchResponse) getMessage(topic string, partition int32, offset int64) Encoder { + partitions := mfr.messages[topic] + if partitions == nil { + return nil + } + messages := partitions[partition] + if messages == nil { + return nil + } + return messages[offset] +} + +func (mfr *MockFetchResponse) getMessageCount(topic string, partition int32) int { + partitions := mfr.messages[topic] + if partitions == nil { + return 0 + } + messages := partitions[partition] + if messages == nil { + return 0 + } + return len(messages) +} + +func (mfr *MockFetchResponse) getHighWaterMark(topic string, partition int32) int64 { + partitions := mfr.highWaterMarks[topic] + if partitions == nil { + return 0 + } + return partitions[partition] +} + +// MockConsumerMetadataResponse is a `ConsumerMetadataResponse` builder. +type MockConsumerMetadataResponse struct { + coordinators map[string]interface{} + t TestReporter +} + +func NewMockConsumerMetadataResponse(t TestReporter) *MockConsumerMetadataResponse { + return &MockConsumerMetadataResponse{ + coordinators: make(map[string]interface{}), + t: t, + } +} + +func (mr *MockConsumerMetadataResponse) SetCoordinator(group string, broker *MockBroker) *MockConsumerMetadataResponse { + mr.coordinators[group] = broker + return mr +} + +func (mr *MockConsumerMetadataResponse) SetError(group string, kerror KError) *MockConsumerMetadataResponse { + mr.coordinators[group] = kerror + return mr +} + +func (mr *MockConsumerMetadataResponse) For(reqBody versionedDecoder) encoder { + req := reqBody.(*ConsumerMetadataRequest) + group := req.ConsumerGroup + res := &ConsumerMetadataResponse{} + v := mr.coordinators[group] + switch v := v.(type) { + case *MockBroker: + res.Coordinator = &Broker{id: v.BrokerID(), addr: v.Addr()} + case KError: + res.Err = v + } + return res +} + +// MockOffsetCommitResponse is a `OffsetCommitResponse` builder. +type MockOffsetCommitResponse struct { + errors map[string]map[string]map[int32]KError + t TestReporter +} + +func NewMockOffsetCommitResponse(t TestReporter) *MockOffsetCommitResponse { + return &MockOffsetCommitResponse{t: t} +} + +func (mr *MockOffsetCommitResponse) SetError(group, topic string, partition int32, kerror KError) *MockOffsetCommitResponse { + if mr.errors == nil { + mr.errors = make(map[string]map[string]map[int32]KError) + } + topics := mr.errors[group] + if topics == nil { + topics = make(map[string]map[int32]KError) + mr.errors[group] = topics + } + partitions := topics[topic] + if partitions == nil { + partitions = make(map[int32]KError) + topics[topic] = partitions + } + partitions[partition] = kerror + return mr +} + +func (mr *MockOffsetCommitResponse) For(reqBody versionedDecoder) encoder { + req := reqBody.(*OffsetCommitRequest) + group := req.ConsumerGroup + res := &OffsetCommitResponse{} + for topic, partitions := range req.blocks { + for partition := range partitions { + res.AddError(topic, partition, mr.getError(group, topic, partition)) + } + } + return res +} + +func (mr *MockOffsetCommitResponse) getError(group, topic string, partition int32) KError { + topics := mr.errors[group] + if topics == nil { + return ErrNoError + } + partitions := topics[topic] + if partitions == nil { + return ErrNoError + } + kerror, ok := partitions[partition] + if !ok { + return ErrNoError + } + return kerror +} + +// MockProduceResponse is a `ProduceResponse` builder. +type MockProduceResponse struct { + errors map[string]map[int32]KError + t TestReporter +} + +func NewMockProduceResponse(t TestReporter) *MockProduceResponse { + return &MockProduceResponse{t: t} +} + +func (mr *MockProduceResponse) SetError(topic string, partition int32, kerror KError) *MockProduceResponse { + if mr.errors == nil { + mr.errors = make(map[string]map[int32]KError) + } + partitions := mr.errors[topic] + if partitions == nil { + partitions = make(map[int32]KError) + mr.errors[topic] = partitions + } + partitions[partition] = kerror + return mr +} + +func (mr *MockProduceResponse) For(reqBody versionedDecoder) encoder { + req := reqBody.(*ProduceRequest) + res := &ProduceResponse{} + for topic, partitions := range req.msgSets { + for partition := range partitions { + res.AddTopicPartition(topic, partition, mr.getError(topic, partition)) + } + } + return res +} + +func (mr *MockProduceResponse) getError(topic string, partition int32) KError { + partitions := mr.errors[topic] + if partitions == nil { + return ErrNoError + } + kerror, ok := partitions[partition] + if !ok { + return ErrNoError + } + return kerror +} + +// MockOffsetFetchResponse is a `OffsetFetchResponse` builder. +type MockOffsetFetchResponse struct { + offsets map[string]map[string]map[int32]*OffsetFetchResponseBlock + t TestReporter +} + +func NewMockOffsetFetchResponse(t TestReporter) *MockOffsetFetchResponse { + return &MockOffsetFetchResponse{t: t} +} + +func (mr *MockOffsetFetchResponse) SetOffset(group, topic string, partition int32, offset int64, metadata string, kerror KError) *MockOffsetFetchResponse { + if mr.offsets == nil { + mr.offsets = make(map[string]map[string]map[int32]*OffsetFetchResponseBlock) + } + topics := mr.offsets[group] + if topics == nil { + topics = make(map[string]map[int32]*OffsetFetchResponseBlock) + mr.offsets[group] = topics + } + partitions := topics[topic] + if partitions == nil { + partitions = make(map[int32]*OffsetFetchResponseBlock) + topics[topic] = partitions + } + partitions[partition] = &OffsetFetchResponseBlock{offset, metadata, kerror} + return mr +} + +func (mr *MockOffsetFetchResponse) For(reqBody versionedDecoder) encoder { + req := reqBody.(*OffsetFetchRequest) + group := req.ConsumerGroup + res := &OffsetFetchResponse{} + for topic, partitions := range mr.offsets[group] { + for partition, block := range partitions { + res.AddBlock(topic, partition, block) + } + } + return res +} diff --git a/vendor/github.com/Shopify/sarama/mocks/README.md b/vendor/github.com/Shopify/sarama/mocks/README.md new file mode 100644 index 00000000..55a6c2e6 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/mocks/README.md @@ -0,0 +1,13 @@ +# sarama/mocks + +The `mocks` subpackage includes mock implementations that implement the interfaces of the major sarama types. +You can use them to test your sarama applications using dependency injection. + +The following mock objects are available: + +- [Consumer](https://godoc.org/github.com/Shopify/sarama/mocks#Consumer), which will create [PartitionConsumer](https://godoc.org/github.com/Shopify/sarama/mocks#PartitionConsumer) mocks. +- [AsyncProducer](https://godoc.org/github.com/Shopify/sarama/mocks#AsyncProducer) +- [SyncProducer](https://godoc.org/github.com/Shopify/sarama/mocks#SyncProducer) + +The mocks allow you to set expectations on them. When you close the mocks, the expectations will be verified, +and the results will be reported to the `*testing.T` object you provided when creating the mock. diff --git a/vendor/github.com/Shopify/sarama/mocks/async_producer.go b/vendor/github.com/Shopify/sarama/mocks/async_producer.go new file mode 100644 index 00000000..24ae5c0d --- /dev/null +++ b/vendor/github.com/Shopify/sarama/mocks/async_producer.go @@ -0,0 +1,174 @@ +package mocks + +import ( + "sync" + + "github.com/Shopify/sarama" +) + +// AsyncProducer implements sarama's Producer interface for testing purposes. +// Before you can send messages to it's Input channel, you have to set expectations +// so it knows how to handle the input; it returns an error if the number of messages +// received is bigger then the number of expectations set. You can also set a +// function in each expectation so that the message value is checked by this function +// and an error is returned if the match fails. +type AsyncProducer struct { + l sync.Mutex + t ErrorReporter + expectations []*producerExpectation + closed chan struct{} + input chan *sarama.ProducerMessage + successes chan *sarama.ProducerMessage + errors chan *sarama.ProducerError + lastOffset int64 +} + +// NewAsyncProducer instantiates a new Producer mock. The t argument should +// be the *testing.T instance of your test method. An error will be written to it if +// an expectation is violated. The config argument is used to determine whether it +// should ack successes on the Successes channel. +func NewAsyncProducer(t ErrorReporter, config *sarama.Config) *AsyncProducer { + if config == nil { + config = sarama.NewConfig() + } + mp := &AsyncProducer{ + t: t, + closed: make(chan struct{}, 0), + expectations: make([]*producerExpectation, 0), + input: make(chan *sarama.ProducerMessage, config.ChannelBufferSize), + successes: make(chan *sarama.ProducerMessage, config.ChannelBufferSize), + errors: make(chan *sarama.ProducerError, config.ChannelBufferSize), + } + + go func() { + defer func() { + close(mp.successes) + close(mp.errors) + }() + + for msg := range mp.input { + mp.l.Lock() + if mp.expectations == nil || len(mp.expectations) == 0 { + mp.expectations = nil + mp.t.Errorf("No more expectation set on this mock producer to handle the input message.") + } else { + expectation := mp.expectations[0] + mp.expectations = mp.expectations[1:] + if expectation.CheckFunction != nil { + if val, err := msg.Value.Encode(); err != nil { + mp.t.Errorf("Input message encoding failed: %s", err.Error()) + mp.errors <- &sarama.ProducerError{Err: err, Msg: msg} + } else { + err = expectation.CheckFunction(val) + if err != nil { + mp.t.Errorf("Check function returned an error: %s", err.Error()) + mp.errors <- &sarama.ProducerError{Err: err, Msg: msg} + } + } + } + if expectation.Result == errProduceSuccess { + mp.lastOffset++ + if config.Producer.Return.Successes { + msg.Offset = mp.lastOffset + mp.successes <- msg + } + } else { + if config.Producer.Return.Errors { + mp.errors <- &sarama.ProducerError{Err: expectation.Result, Msg: msg} + } + } + } + mp.l.Unlock() + } + + mp.l.Lock() + if len(mp.expectations) > 0 { + mp.t.Errorf("Expected to exhaust all expectations, but %d are left.", len(mp.expectations)) + } + mp.l.Unlock() + + close(mp.closed) + }() + + return mp +} + +//////////////////////////////////////////////// +// Implement Producer interface +//////////////////////////////////////////////// + +// AsyncClose corresponds with the AsyncClose method of sarama's Producer implementation. +// By closing a mock producer, you also tell it that no more input will be provided, so it will +// write an error to the test state if there's any remaining expectations. +func (mp *AsyncProducer) AsyncClose() { + close(mp.input) +} + +// Close corresponds with the Close method of sarama's Producer implementation. +// By closing a mock producer, you also tell it that no more input will be provided, so it will +// write an error to the test state if there's any remaining expectations. +func (mp *AsyncProducer) Close() error { + mp.AsyncClose() + <-mp.closed + return nil +} + +// Input corresponds with the Input method of sarama's Producer implementation. +// You have to set expectations on the mock producer before writing messages to the Input +// channel, so it knows how to handle them. If there is no more remaining expectations and +// a messages is written to the Input channel, the mock producer will write an error to the test +// state object. +func (mp *AsyncProducer) Input() chan<- *sarama.ProducerMessage { + return mp.input +} + +// Successes corresponds with the Successes method of sarama's Producer implementation. +func (mp *AsyncProducer) Successes() <-chan *sarama.ProducerMessage { + return mp.successes +} + +// Errors corresponds with the Errors method of sarama's Producer implementation. +func (mp *AsyncProducer) Errors() <-chan *sarama.ProducerError { + return mp.errors +} + +//////////////////////////////////////////////// +// Setting expectations +//////////////////////////////////////////////// + +// ExpectInputWithCheckerFunctionAndSucceed sets an expectation on the mock producer that a message +// will be provided on the input channel. The mock producer will call the given function to check +// the message value. If an error is returned it will be made available on the Errors channel +// otherwise the mock will handle the message as if it produced successfully, i.e. it will make +// it available on the Successes channel if the Producer.Return.Successes setting is set to true. +func (mp *AsyncProducer) ExpectInputWithCheckerFunctionAndSucceed(cf ValueChecker) { + mp.l.Lock() + defer mp.l.Unlock() + mp.expectations = append(mp.expectations, &producerExpectation{Result: errProduceSuccess, CheckFunction: cf}) +} + +// ExpectInputWithCheckerFunctionAndFail sets an expectation on the mock producer that a message +// will be provided on the input channel. The mock producer will first call the given function to +// check the message value. If an error is returned it will be made available on the Errors channel +// otherwise the mock will handle the message as if it failed to produce successfully. This means +// it will make a ProducerError available on the Errors channel. +func (mp *AsyncProducer) ExpectInputWithCheckerFunctionAndFail(cf ValueChecker, err error) { + mp.l.Lock() + defer mp.l.Unlock() + mp.expectations = append(mp.expectations, &producerExpectation{Result: err, CheckFunction: cf}) +} + +// ExpectInputAndSucceed sets an expectation on the mock producer that a message will be provided +// on the input channel. The mock producer will handle the message as if it is produced successfully, +// i.e. it will make it available on the Successes channel if the Producer.Return.Successes setting +// is set to true. +func (mp *AsyncProducer) ExpectInputAndSucceed() { + mp.ExpectInputWithCheckerFunctionAndSucceed(nil) +} + +// ExpectInputAndFail sets an expectation on the mock producer that a message will be provided +// on the input channel. The mock producer will handle the message as if it failed to produce +// successfully. This means it will make a ProducerError available on the Errors channel. +func (mp *AsyncProducer) ExpectInputAndFail(err error) { + mp.ExpectInputWithCheckerFunctionAndFail(nil, err) +} diff --git a/vendor/github.com/Shopify/sarama/mocks/async_producer_test.go b/vendor/github.com/Shopify/sarama/mocks/async_producer_test.go new file mode 100644 index 00000000..b5d92aad --- /dev/null +++ b/vendor/github.com/Shopify/sarama/mocks/async_producer_test.go @@ -0,0 +1,132 @@ +package mocks + +import ( + "errors" + "fmt" + "regexp" + "strings" + "testing" + + "github.com/Shopify/sarama" +) + +func generateRegexpChecker(re string) func([]byte) error { + return func(val []byte) error { + matched, err := regexp.MatchString(re, string(val)) + if err != nil { + return errors.New("Error while trying to match the input message with the expected pattern: " + err.Error()) + } + if !matched { + return fmt.Errorf("No match between input value \"%s\" and expected pattern \"%s\"", val, re) + } + return nil + } +} + +type testReporterMock struct { + errors []string +} + +func newTestReporterMock() *testReporterMock { + return &testReporterMock{errors: make([]string, 0)} +} + +func (trm *testReporterMock) Errorf(format string, args ...interface{}) { + trm.errors = append(trm.errors, fmt.Sprintf(format, args...)) +} + +func TestMockAsyncProducerImplementsAsyncProducerInterface(t *testing.T) { + var mp interface{} = &AsyncProducer{} + if _, ok := mp.(sarama.AsyncProducer); !ok { + t.Error("The mock producer should implement the sarama.Producer interface.") + } +} + +func TestProducerReturnsExpectationsToChannels(t *testing.T) { + config := sarama.NewConfig() + config.Producer.Return.Successes = true + mp := NewAsyncProducer(t, config) + + mp.ExpectInputAndSucceed() + mp.ExpectInputAndSucceed() + mp.ExpectInputAndFail(sarama.ErrOutOfBrokers) + + mp.Input() <- &sarama.ProducerMessage{Topic: "test 1"} + mp.Input() <- &sarama.ProducerMessage{Topic: "test 2"} + mp.Input() <- &sarama.ProducerMessage{Topic: "test 3"} + + msg1 := <-mp.Successes() + msg2 := <-mp.Successes() + err1 := <-mp.Errors() + + if msg1.Topic != "test 1" { + t.Error("Expected message 1 to be returned first") + } + + if msg2.Topic != "test 2" { + t.Error("Expected message 2 to be returned second") + } + + if err1.Msg.Topic != "test 3" || err1.Err != sarama.ErrOutOfBrokers { + t.Error("Expected message 3 to be returned as error") + } + + if err := mp.Close(); err != nil { + t.Error(err) + } +} + +func TestProducerWithTooFewExpectations(t *testing.T) { + trm := newTestReporterMock() + mp := NewAsyncProducer(trm, nil) + mp.ExpectInputAndSucceed() + + mp.Input() <- &sarama.ProducerMessage{Topic: "test"} + mp.Input() <- &sarama.ProducerMessage{Topic: "test"} + + if err := mp.Close(); err != nil { + t.Error(err) + } + + if len(trm.errors) != 1 { + t.Error("Expected to report an error") + } +} + +func TestProducerWithTooManyExpectations(t *testing.T) { + trm := newTestReporterMock() + mp := NewAsyncProducer(trm, nil) + mp.ExpectInputAndSucceed() + mp.ExpectInputAndFail(sarama.ErrOutOfBrokers) + + mp.Input() <- &sarama.ProducerMessage{Topic: "test"} + if err := mp.Close(); err != nil { + t.Error(err) + } + + if len(trm.errors) != 1 { + t.Error("Expected to report an error") + } +} + +func TestProducerWithCheckerFunction(t *testing.T) { + trm := newTestReporterMock() + mp := NewAsyncProducer(trm, nil) + mp.ExpectInputWithCheckerFunctionAndSucceed(generateRegexpChecker("^tes")) + mp.ExpectInputWithCheckerFunctionAndSucceed(generateRegexpChecker("^tes$")) + + mp.Input() <- &sarama.ProducerMessage{Topic: "test", Value: sarama.StringEncoder("test")} + mp.Input() <- &sarama.ProducerMessage{Topic: "test", Value: sarama.StringEncoder("test")} + if err := mp.Close(); err != nil { + t.Error(err) + } + + if len(mp.Errors()) != 1 { + t.Error("Expected to report an error") + } + + err1 := <-mp.Errors() + if !strings.HasPrefix(err1.Err.Error(), "No match") { + t.Error("Expected to report a value check error, found: ", err1.Err) + } +} diff --git a/vendor/github.com/Shopify/sarama/mocks/consumer.go b/vendor/github.com/Shopify/sarama/mocks/consumer.go new file mode 100644 index 00000000..12aff0d4 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/mocks/consumer.go @@ -0,0 +1,315 @@ +package mocks + +import ( + "sync" + "sync/atomic" + + "github.com/Shopify/sarama" +) + +// Consumer implements sarama's Consumer interface for testing purposes. +// Before you can start consuming from this consumer, you have to register +// topic/partitions using ExpectConsumePartition, and set expectations on them. +type Consumer struct { + l sync.Mutex + t ErrorReporter + config *sarama.Config + partitionConsumers map[string]map[int32]*PartitionConsumer + metadata map[string][]int32 +} + +// NewConsumer returns a new mock Consumer instance. The t argument should +// be the *testing.T instance of your test method. An error will be written to it if +// an expectation is violated. The config argument is currently unused and can be set to nil. +func NewConsumer(t ErrorReporter, config *sarama.Config) *Consumer { + if config == nil { + config = sarama.NewConfig() + } + + c := &Consumer{ + t: t, + config: config, + partitionConsumers: make(map[string]map[int32]*PartitionConsumer), + } + return c +} + +/////////////////////////////////////////////////// +// Consumer interface implementation +/////////////////////////////////////////////////// + +// ConsumePartition implements the ConsumePartition method from the sarama.Consumer interface. +// Before you can start consuming a partition, you have to set expectations on it using +// ExpectConsumePartition. You can only consume a partition once per consumer. +func (c *Consumer) ConsumePartition(topic string, partition int32, offset int64) (sarama.PartitionConsumer, error) { + c.l.Lock() + defer c.l.Unlock() + + if c.partitionConsumers[topic] == nil || c.partitionConsumers[topic][partition] == nil { + c.t.Errorf("No expectations set for %s/%d", topic, partition) + return nil, errOutOfExpectations + } + + pc := c.partitionConsumers[topic][partition] + if pc.consumed { + return nil, sarama.ConfigurationError("The topic/partition is already being consumed") + } + + if pc.offset != AnyOffset && pc.offset != offset { + c.t.Errorf("Unexpected offset when calling ConsumePartition for %s/%d. Expected %d, got %d.", topic, partition, pc.offset, offset) + } + + pc.consumed = true + return pc, nil +} + +// Topics returns a list of topics, as registered with SetMetadata +func (c *Consumer) Topics() ([]string, error) { + c.l.Lock() + defer c.l.Unlock() + + if c.metadata == nil { + c.t.Errorf("Unexpected call to Topics. Initialize the mock's topic metadata with SetMetadata.") + return nil, sarama.ErrOutOfBrokers + } + + var result []string + for topic := range c.metadata { + result = append(result, topic) + } + return result, nil +} + +// Partitions returns the list of parititons for the given topic, as registered with SetMetadata +func (c *Consumer) Partitions(topic string) ([]int32, error) { + c.l.Lock() + defer c.l.Unlock() + + if c.metadata == nil { + c.t.Errorf("Unexpected call to Partitions. Initialize the mock's topic metadata with SetMetadata.") + return nil, sarama.ErrOutOfBrokers + } + if c.metadata[topic] == nil { + return nil, sarama.ErrUnknownTopicOrPartition + } + + return c.metadata[topic], nil +} + +func (c *Consumer) HighWaterMarks() map[string]map[int32]int64 { + c.l.Lock() + defer c.l.Unlock() + + hwms := make(map[string]map[int32]int64, len(c.partitionConsumers)) + for topic, partitionConsumers := range c.partitionConsumers { + hwm := make(map[int32]int64, len(partitionConsumers)) + for partition, pc := range partitionConsumers { + hwm[partition] = pc.HighWaterMarkOffset() + } + hwms[topic] = hwm + } + + return hwms +} + +// Close implements the Close method from the sarama.Consumer interface. It will close +// all registered PartitionConsumer instances. +func (c *Consumer) Close() error { + c.l.Lock() + defer c.l.Unlock() + + for _, partitions := range c.partitionConsumers { + for _, partitionConsumer := range partitions { + _ = partitionConsumer.Close() + } + } + + return nil +} + +/////////////////////////////////////////////////// +// Expectation API +/////////////////////////////////////////////////// + +// SetTopicMetadata sets the clusters topic/partition metadata, +// which will be returned by Topics() and Partitions(). +func (c *Consumer) SetTopicMetadata(metadata map[string][]int32) { + c.l.Lock() + defer c.l.Unlock() + + c.metadata = metadata +} + +// ExpectConsumePartition will register a topic/partition, so you can set expectations on it. +// The registered PartitionConsumer will be returned, so you can set expectations +// on it using method chaining. Once a topic/partition is registered, you are +// expected to start consuming it using ConsumePartition. If that doesn't happen, +// an error will be written to the error reporter once the mock consumer is closed. It will +// also expect that the +func (c *Consumer) ExpectConsumePartition(topic string, partition int32, offset int64) *PartitionConsumer { + c.l.Lock() + defer c.l.Unlock() + + if c.partitionConsumers[topic] == nil { + c.partitionConsumers[topic] = make(map[int32]*PartitionConsumer) + } + + if c.partitionConsumers[topic][partition] == nil { + c.partitionConsumers[topic][partition] = &PartitionConsumer{ + t: c.t, + topic: topic, + partition: partition, + offset: offset, + messages: make(chan *sarama.ConsumerMessage, c.config.ChannelBufferSize), + errors: make(chan *sarama.ConsumerError, c.config.ChannelBufferSize), + } + } + + return c.partitionConsumers[topic][partition] +} + +/////////////////////////////////////////////////// +// PartitionConsumer mock type +/////////////////////////////////////////////////// + +// PartitionConsumer implements sarama's PartitionConsumer interface for testing purposes. +// It is returned by the mock Consumers ConsumePartitionMethod, but only if it is +// registered first using the Consumer's ExpectConsumePartition method. Before consuming the +// Errors and Messages channel, you should specify what values will be provided on these +// channels using YieldMessage and YieldError. +type PartitionConsumer struct { + highWaterMarkOffset int64 // must be at the top of the struct because https://golang.org/pkg/sync/atomic/#pkg-note-BUG + l sync.Mutex + t ErrorReporter + topic string + partition int32 + offset int64 + messages chan *sarama.ConsumerMessage + errors chan *sarama.ConsumerError + singleClose sync.Once + consumed bool + errorsShouldBeDrained bool + messagesShouldBeDrained bool +} + +/////////////////////////////////////////////////// +// PartitionConsumer interface implementation +/////////////////////////////////////////////////// + +// AsyncClose implements the AsyncClose method from the sarama.PartitionConsumer interface. +func (pc *PartitionConsumer) AsyncClose() { + pc.singleClose.Do(func() { + close(pc.messages) + close(pc.errors) + }) +} + +// Close implements the Close method from the sarama.PartitionConsumer interface. It will +// verify whether the partition consumer was actually started. +func (pc *PartitionConsumer) Close() error { + if !pc.consumed { + pc.t.Errorf("Expectations set on %s/%d, but no partition consumer was started.", pc.topic, pc.partition) + return errPartitionConsumerNotStarted + } + + if pc.errorsShouldBeDrained && len(pc.errors) > 0 { + pc.t.Errorf("Expected the errors channel for %s/%d to be drained on close, but found %d errors.", pc.topic, pc.partition, len(pc.errors)) + } + + if pc.messagesShouldBeDrained && len(pc.messages) > 0 { + pc.t.Errorf("Expected the messages channel for %s/%d to be drained on close, but found %d messages.", pc.topic, pc.partition, len(pc.messages)) + } + + pc.AsyncClose() + + var ( + closeErr error + wg sync.WaitGroup + ) + + wg.Add(1) + go func() { + defer wg.Done() + + var errs = make(sarama.ConsumerErrors, 0) + for err := range pc.errors { + errs = append(errs, err) + } + + if len(errs) > 0 { + closeErr = errs + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + for range pc.messages { + // drain + } + }() + + wg.Wait() + return closeErr +} + +// Errors implements the Errors method from the sarama.PartitionConsumer interface. +func (pc *PartitionConsumer) Errors() <-chan *sarama.ConsumerError { + return pc.errors +} + +// Messages implements the Messages method from the sarama.PartitionConsumer interface. +func (pc *PartitionConsumer) Messages() <-chan *sarama.ConsumerMessage { + return pc.messages +} + +func (pc *PartitionConsumer) HighWaterMarkOffset() int64 { + return atomic.LoadInt64(&pc.highWaterMarkOffset) + 1 +} + +/////////////////////////////////////////////////// +// Expectation API +/////////////////////////////////////////////////// + +// YieldMessage will yield a messages Messages channel of this partition consumer +// when it is consumed. By default, the mock consumer will not verify whether this +// message was consumed from the Messages channel, because there are legitimate +// reasons forthis not to happen. ou can call ExpectMessagesDrainedOnClose so it will +// verify that the channel is empty on close. +func (pc *PartitionConsumer) YieldMessage(msg *sarama.ConsumerMessage) { + pc.l.Lock() + defer pc.l.Unlock() + + msg.Topic = pc.topic + msg.Partition = pc.partition + msg.Offset = atomic.AddInt64(&pc.highWaterMarkOffset, 1) + + pc.messages <- msg +} + +// YieldError will yield an error on the Errors channel of this partition consumer +// when it is consumed. By default, the mock consumer will not verify whether this error was +// consumed from the Errors channel, because there are legitimate reasons for this +// not to happen. You can call ExpectErrorsDrainedOnClose so it will verify that +// the channel is empty on close. +func (pc *PartitionConsumer) YieldError(err error) { + pc.errors <- &sarama.ConsumerError{ + Topic: pc.topic, + Partition: pc.partition, + Err: err, + } +} + +// ExpectMessagesDrainedOnClose sets an expectation on the partition consumer +// that the messages channel will be fully drained when Close is called. If this +// expectation is not met, an error is reported to the error reporter. +func (pc *PartitionConsumer) ExpectMessagesDrainedOnClose() { + pc.messagesShouldBeDrained = true +} + +// ExpectErrorsDrainedOnClose sets an expectation on the partition consumer +// that the errors channel will be fully drained when Close is called. If this +// expectation is not met, an error is reported to the error reporter. +func (pc *PartitionConsumer) ExpectErrorsDrainedOnClose() { + pc.errorsShouldBeDrained = true +} diff --git a/vendor/github.com/Shopify/sarama/mocks/consumer_test.go b/vendor/github.com/Shopify/sarama/mocks/consumer_test.go new file mode 100644 index 00000000..311cfa02 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/mocks/consumer_test.go @@ -0,0 +1,249 @@ +package mocks + +import ( + "sort" + "testing" + + "github.com/Shopify/sarama" +) + +func TestMockConsumerImplementsConsumerInterface(t *testing.T) { + var c interface{} = &Consumer{} + if _, ok := c.(sarama.Consumer); !ok { + t.Error("The mock consumer should implement the sarama.Consumer interface.") + } + + var pc interface{} = &PartitionConsumer{} + if _, ok := pc.(sarama.PartitionConsumer); !ok { + t.Error("The mock partitionconsumer should implement the sarama.PartitionConsumer interface.") + } +} + +func TestConsumerHandlesExpectations(t *testing.T) { + consumer := NewConsumer(t, nil) + defer func() { + if err := consumer.Close(); err != nil { + t.Error(err) + } + }() + + consumer.ExpectConsumePartition("test", 0, sarama.OffsetOldest).YieldMessage(&sarama.ConsumerMessage{Value: []byte("hello world")}) + consumer.ExpectConsumePartition("test", 0, sarama.OffsetOldest).YieldError(sarama.ErrOutOfBrokers) + consumer.ExpectConsumePartition("test", 1, sarama.OffsetOldest).YieldMessage(&sarama.ConsumerMessage{Value: []byte("hello world again")}) + consumer.ExpectConsumePartition("other", 0, AnyOffset).YieldMessage(&sarama.ConsumerMessage{Value: []byte("hello other")}) + + pc_test0, err := consumer.ConsumePartition("test", 0, sarama.OffsetOldest) + if err != nil { + t.Fatal(err) + } + test0_msg := <-pc_test0.Messages() + if test0_msg.Topic != "test" || test0_msg.Partition != 0 || string(test0_msg.Value) != "hello world" { + t.Error("Message was not as expected:", test0_msg) + } + test0_err := <-pc_test0.Errors() + if test0_err.Err != sarama.ErrOutOfBrokers { + t.Error("Expected sarama.ErrOutOfBrokers, found:", test0_err.Err) + } + + pc_test1, err := consumer.ConsumePartition("test", 1, sarama.OffsetOldest) + if err != nil { + t.Fatal(err) + } + test1_msg := <-pc_test1.Messages() + if test1_msg.Topic != "test" || test1_msg.Partition != 1 || string(test1_msg.Value) != "hello world again" { + t.Error("Message was not as expected:", test1_msg) + } + + pc_other0, err := consumer.ConsumePartition("other", 0, sarama.OffsetNewest) + if err != nil { + t.Fatal(err) + } + other0_msg := <-pc_other0.Messages() + if other0_msg.Topic != "other" || other0_msg.Partition != 0 || string(other0_msg.Value) != "hello other" { + t.Error("Message was not as expected:", other0_msg) + } +} + +func TestConsumerReturnsNonconsumedErrorsOnClose(t *testing.T) { + consumer := NewConsumer(t, nil) + consumer.ExpectConsumePartition("test", 0, sarama.OffsetOldest).YieldError(sarama.ErrOutOfBrokers) + consumer.ExpectConsumePartition("test", 0, sarama.OffsetOldest).YieldError(sarama.ErrOutOfBrokers) + + pc, err := consumer.ConsumePartition("test", 0, sarama.OffsetOldest) + if err != nil { + t.Fatal(err) + } + + select { + case <-pc.Messages(): + t.Error("Did not epxect a message on the messages channel.") + case err := <-pc.Errors(): + if err.Err != sarama.ErrOutOfBrokers { + t.Error("Expected sarama.ErrOutOfBrokers, found", err) + } + } + + errs := pc.Close().(sarama.ConsumerErrors) + if len(errs) != 1 && errs[0].Err != sarama.ErrOutOfBrokers { + t.Error("Expected Close to return the remaining sarama.ErrOutOfBrokers") + } +} + +func TestConsumerWithoutExpectationsOnPartition(t *testing.T) { + trm := newTestReporterMock() + consumer := NewConsumer(trm, nil) + + _, err := consumer.ConsumePartition("test", 1, sarama.OffsetOldest) + if err != errOutOfExpectations { + t.Error("Expected ConsumePartition to return errOutOfExpectations") + } + + if err := consumer.Close(); err != nil { + t.Error("No error expected on close, but found:", err) + } + + if len(trm.errors) != 1 { + t.Errorf("Expected an expectation failure to be set on the error reporter.") + } +} + +func TestConsumerWithExpectationsOnUnconsumedPartition(t *testing.T) { + trm := newTestReporterMock() + consumer := NewConsumer(trm, nil) + consumer.ExpectConsumePartition("test", 0, sarama.OffsetOldest).YieldMessage(&sarama.ConsumerMessage{Value: []byte("hello world")}) + + if err := consumer.Close(); err != nil { + t.Error("No error expected on close, but found:", err) + } + + if len(trm.errors) != 1 { + t.Errorf("Expected an expectation failure to be set on the error reporter.") + } +} + +func TestConsumerWithWrongOffsetExpectation(t *testing.T) { + trm := newTestReporterMock() + consumer := NewConsumer(trm, nil) + consumer.ExpectConsumePartition("test", 0, sarama.OffsetOldest) + + _, err := consumer.ConsumePartition("test", 0, sarama.OffsetNewest) + if err != nil { + t.Error("Did not expect error, found:", err) + } + + if len(trm.errors) != 1 { + t.Errorf("Expected an expectation failure to be set on the error reporter.") + } + + if err := consumer.Close(); err != nil { + t.Error(err) + } +} + +func TestConsumerViolatesMessagesDrainedExpectation(t *testing.T) { + trm := newTestReporterMock() + consumer := NewConsumer(trm, nil) + pcmock := consumer.ExpectConsumePartition("test", 0, sarama.OffsetOldest) + pcmock.YieldMessage(&sarama.ConsumerMessage{Value: []byte("hello")}) + pcmock.YieldMessage(&sarama.ConsumerMessage{Value: []byte("hello")}) + pcmock.ExpectMessagesDrainedOnClose() + + pc, err := consumer.ConsumePartition("test", 0, sarama.OffsetOldest) + if err != nil { + t.Error(err) + } + + // consume first message, not second one + <-pc.Messages() + + if err := consumer.Close(); err != nil { + t.Error(err) + } + + if len(trm.errors) != 1 { + t.Errorf("Expected an expectation failure to be set on the error reporter.") + } +} + +func TestConsumerMeetsErrorsDrainedExpectation(t *testing.T) { + trm := newTestReporterMock() + consumer := NewConsumer(trm, nil) + + pcmock := consumer.ExpectConsumePartition("test", 0, sarama.OffsetOldest) + pcmock.YieldError(sarama.ErrInvalidMessage) + pcmock.YieldError(sarama.ErrInvalidMessage) + pcmock.ExpectErrorsDrainedOnClose() + + pc, err := consumer.ConsumePartition("test", 0, sarama.OffsetOldest) + if err != nil { + t.Error(err) + } + + // consume first and second error, + <-pc.Errors() + <-pc.Errors() + + if err := consumer.Close(); err != nil { + t.Error(err) + } + + if len(trm.errors) != 0 { + t.Errorf("Expected no expectation failures to be set on the error reporter.") + } +} + +func TestConsumerTopicMetadata(t *testing.T) { + trm := newTestReporterMock() + consumer := NewConsumer(trm, nil) + + consumer.SetTopicMetadata(map[string][]int32{ + "test1": {0, 1, 2, 3}, + "test2": {0, 1, 2, 3, 4, 5, 6, 7}, + }) + + topics, err := consumer.Topics() + if err != nil { + t.Error(t) + } + + sortedTopics := sort.StringSlice(topics) + sortedTopics.Sort() + if len(sortedTopics) != 2 || sortedTopics[0] != "test1" || sortedTopics[1] != "test2" { + t.Error("Unexpected topics returned:", sortedTopics) + } + + partitions1, err := consumer.Partitions("test1") + if err != nil { + t.Error(t) + } + + if len(partitions1) != 4 { + t.Error("Unexpected partitions returned:", len(partitions1)) + } + + partitions2, err := consumer.Partitions("test2") + if err != nil { + t.Error(t) + } + + if len(partitions2) != 8 { + t.Error("Unexpected partitions returned:", len(partitions2)) + } + + if len(trm.errors) != 0 { + t.Errorf("Expected no expectation failures to be set on the error reporter.") + } +} + +func TestConsumerUnexpectedTopicMetadata(t *testing.T) { + trm := newTestReporterMock() + consumer := NewConsumer(trm, nil) + + if _, err := consumer.Topics(); err != sarama.ErrOutOfBrokers { + t.Error("Expected sarama.ErrOutOfBrokers, found", err) + } + + if len(trm.errors) != 1 { + t.Errorf("Expected an expectation failure to be set on the error reporter.") + } +} diff --git a/vendor/github.com/Shopify/sarama/mocks/mocks.go b/vendor/github.com/Shopify/sarama/mocks/mocks.go new file mode 100644 index 00000000..4adb838d --- /dev/null +++ b/vendor/github.com/Shopify/sarama/mocks/mocks.go @@ -0,0 +1,48 @@ +/* +Package mocks provides mocks that can be used for testing applications +that use Sarama. The mock types provided by this package implement the +interfaces Sarama exports, so you can use them for dependency injection +in your tests. + +All mock instances require you to set expectations on them before you +can use them. It will determine how the mock will behave. If an +expectation is not met, it will make your test fail. + +NOTE: this package currently does not fall under the API stability +guarantee of Sarama as it is still considered experimental. +*/ +package mocks + +import ( + "errors" + + "github.com/Shopify/sarama" +) + +// ErrorReporter is a simple interface that includes the testing.T methods we use to report +// expectation violations when using the mock objects. +type ErrorReporter interface { + Errorf(string, ...interface{}) +} + +// ValueChecker is a function type to be set in each expectation of the producer mocks +// to check the value passed. +type ValueChecker func(val []byte) error + +var ( + errProduceSuccess error = nil + errOutOfExpectations = errors.New("No more expectations set on mock") + errPartitionConsumerNotStarted = errors.New("The partition consumer was never started") +) + +const AnyOffset int64 = -1000 + +type producerExpectation struct { + Result error + CheckFunction ValueChecker +} + +type consumerExpectation struct { + Err error + Msg *sarama.ConsumerMessage +} diff --git a/vendor/github.com/Shopify/sarama/mocks/sync_producer.go b/vendor/github.com/Shopify/sarama/mocks/sync_producer.go new file mode 100644 index 00000000..5de79cce --- /dev/null +++ b/vendor/github.com/Shopify/sarama/mocks/sync_producer.go @@ -0,0 +1,146 @@ +package mocks + +import ( + "sync" + + "github.com/Shopify/sarama" +) + +// SyncProducer implements sarama's SyncProducer interface for testing purposes. +// Before you can use it, you have to set expectations on the mock SyncProducer +// to tell it how to handle calls to SendMessage, so you can easily test success +// and failure scenarios. +type SyncProducer struct { + l sync.Mutex + t ErrorReporter + expectations []*producerExpectation + lastOffset int64 +} + +// NewSyncProducer instantiates a new SyncProducer mock. The t argument should +// be the *testing.T instance of your test method. An error will be written to it if +// an expectation is violated. The config argument is currently unused, but is +// maintained to be compatible with the async Producer. +func NewSyncProducer(t ErrorReporter, config *sarama.Config) *SyncProducer { + return &SyncProducer{ + t: t, + expectations: make([]*producerExpectation, 0), + } +} + +//////////////////////////////////////////////// +// Implement SyncProducer interface +//////////////////////////////////////////////// + +// SendMessage corresponds with the SendMessage method of sarama's SyncProducer implementation. +// You have to set expectations on the mock producer before calling SendMessage, so it knows +// how to handle them. You can set a function in each expectation so that the message value +// checked by this function and an error is returned if the match fails. +// If there is no more remaining expectation when SendMessage is called, +// the mock producer will write an error to the test state object. +func (sp *SyncProducer) SendMessage(msg *sarama.ProducerMessage) (partition int32, offset int64, err error) { + sp.l.Lock() + defer sp.l.Unlock() + + if len(sp.expectations) > 0 { + expectation := sp.expectations[0] + sp.expectations = sp.expectations[1:] + if expectation.CheckFunction != nil { + val, err := msg.Value.Encode() + if err != nil { + sp.t.Errorf("Input message encoding failed: %s", err.Error()) + return -1, -1, err + } + + errCheck := expectation.CheckFunction(val) + if errCheck != nil { + sp.t.Errorf("Check function returned an error: %s", errCheck.Error()) + return -1, -1, errCheck + } + } + if expectation.Result == errProduceSuccess { + sp.lastOffset++ + msg.Offset = sp.lastOffset + return 0, msg.Offset, nil + } + return -1, -1, expectation.Result + } + sp.t.Errorf("No more expectation set on this mock producer to handle the input message.") + return -1, -1, errOutOfExpectations +} + +// SendMessages corresponds with the SendMessages method of sarama's SyncProducer implementation. +// You have to set expectations on the mock producer before calling SendMessages, so it knows +// how to handle them. If there is no more remaining expectations when SendMessages is called, +// the mock producer will write an error to the test state object. +func (sp *SyncProducer) SendMessages(msgs []*sarama.ProducerMessage) error { + sp.l.Lock() + defer sp.l.Unlock() + + if len(sp.expectations) >= len(msgs) { + expectations := sp.expectations[0 : len(msgs)-1] + sp.expectations = sp.expectations[len(msgs):] + + for _, expectation := range expectations { + if expectation.Result != errProduceSuccess { + return expectation.Result + } + + } + return nil + } + sp.t.Errorf("Insufficient expectations set on this mock producer to handle the input messages.") + return errOutOfExpectations +} + +// Close corresponds with the Close method of sarama's SyncProducer implementation. +// By closing a mock syncproducer, you also tell it that no more SendMessage calls will follow, +// so it will write an error to the test state if there's any remaining expectations. +func (sp *SyncProducer) Close() error { + sp.l.Lock() + defer sp.l.Unlock() + + if len(sp.expectations) > 0 { + sp.t.Errorf("Expected to exhaust all expectations, but %d are left.", len(sp.expectations)) + } + + return nil +} + +//////////////////////////////////////////////// +// Setting expectations +//////////////////////////////////////////////// + +// ExpectSendMessageWithCheckerFunctionAndSucceed sets an expectation on the mock producer that SendMessage +// will be called. The mock producer will first call the given function to check the message value. +// It will cascade the error of the function, if any, or handle the message as if it produced +// successfully, i.e. by returning a valid partition, and offset, and a nil error. +func (sp *SyncProducer) ExpectSendMessageWithCheckerFunctionAndSucceed(cf ValueChecker) { + sp.l.Lock() + defer sp.l.Unlock() + sp.expectations = append(sp.expectations, &producerExpectation{Result: errProduceSuccess, CheckFunction: cf}) +} + +// ExpectSendMessageWithCheckerFunctionAndFail sets an expectation on the mock producer that SendMessage will be +// called. The mock producer will first call the given function to check the message value. +// It will cascade the error of the function, if any, or handle the message as if it failed +// to produce successfully, i.e. by returning the provided error. +func (sp *SyncProducer) ExpectSendMessageWithCheckerFunctionAndFail(cf ValueChecker, err error) { + sp.l.Lock() + defer sp.l.Unlock() + sp.expectations = append(sp.expectations, &producerExpectation{Result: err, CheckFunction: cf}) +} + +// ExpectSendMessageAndSucceed sets an expectation on the mock producer that SendMessage will be +// called. The mock producer will handle the message as if it produced successfully, i.e. by +// returning a valid partition, and offset, and a nil error. +func (sp *SyncProducer) ExpectSendMessageAndSucceed() { + sp.ExpectSendMessageWithCheckerFunctionAndSucceed(nil) +} + +// ExpectSendMessageAndFail sets an expectation on the mock producer that SendMessage will be +// called. The mock producer will handle the message as if it failed to produce +// successfully, i.e. by returning the provided error. +func (sp *SyncProducer) ExpectSendMessageAndFail(err error) { + sp.ExpectSendMessageWithCheckerFunctionAndFail(nil, err) +} diff --git a/vendor/github.com/Shopify/sarama/mocks/sync_producer_test.go b/vendor/github.com/Shopify/sarama/mocks/sync_producer_test.go new file mode 100644 index 00000000..0fdc9987 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/mocks/sync_producer_test.go @@ -0,0 +1,124 @@ +package mocks + +import ( + "strings" + "testing" + + "github.com/Shopify/sarama" +) + +func TestMockSyncProducerImplementsSyncProducerInterface(t *testing.T) { + var mp interface{} = &SyncProducer{} + if _, ok := mp.(sarama.SyncProducer); !ok { + t.Error("The mock async producer should implement the sarama.SyncProducer interface.") + } +} + +func TestSyncProducerReturnsExpectationsToSendMessage(t *testing.T) { + sp := NewSyncProducer(t, nil) + defer func() { + if err := sp.Close(); err != nil { + t.Error(err) + } + }() + + sp.ExpectSendMessageAndSucceed() + sp.ExpectSendMessageAndSucceed() + sp.ExpectSendMessageAndFail(sarama.ErrOutOfBrokers) + + msg := &sarama.ProducerMessage{Topic: "test", Value: sarama.StringEncoder("test")} + + _, offset, err := sp.SendMessage(msg) + if err != nil { + t.Errorf("The first message should have been produced successfully, but got %s", err) + } + if offset != 1 || offset != msg.Offset { + t.Errorf("The first message should have been assigned offset 1, but got %d", msg.Offset) + } + + _, offset, err = sp.SendMessage(msg) + if err != nil { + t.Errorf("The second message should have been produced successfully, but got %s", err) + } + if offset != 2 || offset != msg.Offset { + t.Errorf("The second message should have been assigned offset 2, but got %d", offset) + } + + _, _, err = sp.SendMessage(msg) + if err != sarama.ErrOutOfBrokers { + t.Errorf("The third message should not have been produced successfully") + } + + if err := sp.Close(); err != nil { + t.Error(err) + } +} + +func TestSyncProducerWithTooManyExpectations(t *testing.T) { + trm := newTestReporterMock() + + sp := NewSyncProducer(trm, nil) + sp.ExpectSendMessageAndSucceed() + sp.ExpectSendMessageAndFail(sarama.ErrOutOfBrokers) + + msg := &sarama.ProducerMessage{Topic: "test", Value: sarama.StringEncoder("test")} + if _, _, err := sp.SendMessage(msg); err != nil { + t.Error("No error expected on first SendMessage call", err) + } + + if err := sp.Close(); err != nil { + t.Error(err) + } + + if len(trm.errors) != 1 { + t.Error("Expected to report an error") + } +} + +func TestSyncProducerWithTooFewExpectations(t *testing.T) { + trm := newTestReporterMock() + + sp := NewSyncProducer(trm, nil) + sp.ExpectSendMessageAndSucceed() + + msg := &sarama.ProducerMessage{Topic: "test", Value: sarama.StringEncoder("test")} + if _, _, err := sp.SendMessage(msg); err != nil { + t.Error("No error expected on first SendMessage call", err) + } + if _, _, err := sp.SendMessage(msg); err != errOutOfExpectations { + t.Error("errOutOfExpectations expected on second SendMessage call, found:", err) + } + + if err := sp.Close(); err != nil { + t.Error(err) + } + + if len(trm.errors) != 1 { + t.Error("Expected to report an error") + } +} + +func TestSyncProducerWithCheckerFunction(t *testing.T) { + trm := newTestReporterMock() + + sp := NewSyncProducer(trm, nil) + sp.ExpectSendMessageWithCheckerFunctionAndSucceed(generateRegexpChecker("^tes")) + sp.ExpectSendMessageWithCheckerFunctionAndSucceed(generateRegexpChecker("^tes$")) + + msg := &sarama.ProducerMessage{Topic: "test", Value: sarama.StringEncoder("test")} + if _, _, err := sp.SendMessage(msg); err != nil { + t.Error("No error expected on first SendMessage call, found: ", err) + } + msg = &sarama.ProducerMessage{Topic: "test", Value: sarama.StringEncoder("test")} + if _, _, err := sp.SendMessage(msg); err == nil || !strings.HasPrefix(err.Error(), "No match") { + t.Error("Error during value check expected on second SendMessage call, found:", err) + } + + if err := sp.Close(); err != nil { + t.Error(err) + } + + if len(trm.errors) != 1 { + t.Error("Expected to report an error") + } +} diff --git a/vendor/github.com/Shopify/sarama/offset_commit_request.go b/vendor/github.com/Shopify/sarama/offset_commit_request.go new file mode 100644 index 00000000..b21ea634 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/offset_commit_request.go @@ -0,0 +1,190 @@ +package sarama + +// ReceiveTime is a special value for the timestamp field of Offset Commit Requests which +// tells the broker to set the timestamp to the time at which the request was received. +// The timestamp is only used if message version 1 is used, which requires kafka 0.8.2. +const ReceiveTime int64 = -1 + +// GroupGenerationUndefined is a special value for the group generation field of +// Offset Commit Requests that should be used when a consumer group does not rely +// on Kafka for partition management. +const GroupGenerationUndefined = -1 + +type offsetCommitRequestBlock struct { + offset int64 + timestamp int64 + metadata string +} + +func (b *offsetCommitRequestBlock) encode(pe packetEncoder, version int16) error { + pe.putInt64(b.offset) + if version == 1 { + pe.putInt64(b.timestamp) + } else if b.timestamp != 0 { + Logger.Println("Non-zero timestamp specified for OffsetCommitRequest not v1, it will be ignored") + } + + return pe.putString(b.metadata) +} + +func (b *offsetCommitRequestBlock) decode(pd packetDecoder, version int16) (err error) { + if b.offset, err = pd.getInt64(); err != nil { + return err + } + if version == 1 { + if b.timestamp, err = pd.getInt64(); err != nil { + return err + } + } + b.metadata, err = pd.getString() + return err +} + +type OffsetCommitRequest struct { + ConsumerGroup string + ConsumerGroupGeneration int32 // v1 or later + ConsumerID string // v1 or later + RetentionTime int64 // v2 or later + + // Version can be: + // - 0 (kafka 0.8.1 and later) + // - 1 (kafka 0.8.2 and later) + // - 2 (kafka 0.9.0 and later) + Version int16 + blocks map[string]map[int32]*offsetCommitRequestBlock +} + +func (r *OffsetCommitRequest) encode(pe packetEncoder) error { + if r.Version < 0 || r.Version > 2 { + return PacketEncodingError{"invalid or unsupported OffsetCommitRequest version field"} + } + + if err := pe.putString(r.ConsumerGroup); err != nil { + return err + } + + if r.Version >= 1 { + pe.putInt32(r.ConsumerGroupGeneration) + if err := pe.putString(r.ConsumerID); err != nil { + return err + } + } else { + if r.ConsumerGroupGeneration != 0 { + Logger.Println("Non-zero ConsumerGroupGeneration specified for OffsetCommitRequest v0, it will be ignored") + } + if r.ConsumerID != "" { + Logger.Println("Non-empty ConsumerID specified for OffsetCommitRequest v0, it will be ignored") + } + } + + if r.Version >= 2 { + pe.putInt64(r.RetentionTime) + } else if r.RetentionTime != 0 { + Logger.Println("Non-zero RetentionTime specified for OffsetCommitRequest version <2, it will be ignored") + } + + if err := pe.putArrayLength(len(r.blocks)); err != nil { + return err + } + for topic, partitions := range r.blocks { + if err := pe.putString(topic); err != nil { + return err + } + if err := pe.putArrayLength(len(partitions)); err != nil { + return err + } + for partition, block := range partitions { + pe.putInt32(partition) + if err := block.encode(pe, r.Version); err != nil { + return err + } + } + } + return nil +} + +func (r *OffsetCommitRequest) decode(pd packetDecoder, version int16) (err error) { + r.Version = version + + if r.ConsumerGroup, err = pd.getString(); err != nil { + return err + } + + if r.Version >= 1 { + if r.ConsumerGroupGeneration, err = pd.getInt32(); err != nil { + return err + } + if r.ConsumerID, err = pd.getString(); err != nil { + return err + } + } + + if r.Version >= 2 { + if r.RetentionTime, err = pd.getInt64(); err != nil { + return err + } + } + + topicCount, err := pd.getArrayLength() + if err != nil { + return err + } + if topicCount == 0 { + return nil + } + r.blocks = make(map[string]map[int32]*offsetCommitRequestBlock) + for i := 0; i < topicCount; i++ { + topic, err := pd.getString() + if err != nil { + return err + } + partitionCount, err := pd.getArrayLength() + if err != nil { + return err + } + r.blocks[topic] = make(map[int32]*offsetCommitRequestBlock) + for j := 0; j < partitionCount; j++ { + partition, err := pd.getInt32() + if err != nil { + return err + } + block := &offsetCommitRequestBlock{} + if err := block.decode(pd, r.Version); err != nil { + return err + } + r.blocks[topic][partition] = block + } + } + return nil +} + +func (r *OffsetCommitRequest) key() int16 { + return 8 +} + +func (r *OffsetCommitRequest) version() int16 { + return r.Version +} + +func (r *OffsetCommitRequest) requiredVersion() KafkaVersion { + switch r.Version { + case 1: + return V0_8_2_0 + case 2: + return V0_9_0_0 + default: + return minVersion + } +} + +func (r *OffsetCommitRequest) AddBlock(topic string, partitionID int32, offset int64, timestamp int64, metadata string) { + if r.blocks == nil { + r.blocks = make(map[string]map[int32]*offsetCommitRequestBlock) + } + + if r.blocks[topic] == nil { + r.blocks[topic] = make(map[int32]*offsetCommitRequestBlock) + } + + r.blocks[topic][partitionID] = &offsetCommitRequestBlock{offset, timestamp, metadata} +} diff --git a/vendor/github.com/Shopify/sarama/offset_commit_request_test.go b/vendor/github.com/Shopify/sarama/offset_commit_request_test.go new file mode 100644 index 00000000..afc25b7b --- /dev/null +++ b/vendor/github.com/Shopify/sarama/offset_commit_request_test.go @@ -0,0 +1,90 @@ +package sarama + +import "testing" + +var ( + offsetCommitRequestNoBlocksV0 = []byte{ + 0x00, 0x06, 'f', 'o', 'o', 'b', 'a', 'r', + 0x00, 0x00, 0x00, 0x00} + + offsetCommitRequestNoBlocksV1 = []byte{ + 0x00, 0x06, 'f', 'o', 'o', 'b', 'a', 'r', + 0x00, 0x00, 0x11, 0x22, + 0x00, 0x04, 'c', 'o', 'n', 's', + 0x00, 0x00, 0x00, 0x00} + + offsetCommitRequestNoBlocksV2 = []byte{ + 0x00, 0x06, 'f', 'o', 'o', 'b', 'a', 'r', + 0x00, 0x00, 0x11, 0x22, + 0x00, 0x04, 'c', 'o', 'n', 's', + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x44, 0x33, + 0x00, 0x00, 0x00, 0x00} + + offsetCommitRequestOneBlockV0 = []byte{ + 0x00, 0x06, 'f', 'o', 'o', 'b', 'a', 'r', + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x05, 't', 'o', 'p', 'i', 'c', + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x52, 0x21, + 0x00, 0x00, 0x00, 0x00, 0xDE, 0xAD, 0xBE, 0xEF, + 0x00, 0x08, 'm', 'e', 't', 'a', 'd', 'a', 't', 'a'} + + offsetCommitRequestOneBlockV1 = []byte{ + 0x00, 0x06, 'f', 'o', 'o', 'b', 'a', 'r', + 0x00, 0x00, 0x11, 0x22, + 0x00, 0x04, 'c', 'o', 'n', 's', + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x05, 't', 'o', 'p', 'i', 'c', + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x52, 0x21, + 0x00, 0x00, 0x00, 0x00, 0xDE, 0xAD, 0xBE, 0xEF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x08, 'm', 'e', 't', 'a', 'd', 'a', 't', 'a'} + + offsetCommitRequestOneBlockV2 = []byte{ + 0x00, 0x06, 'f', 'o', 'o', 'b', 'a', 'r', + 0x00, 0x00, 0x11, 0x22, + 0x00, 0x04, 'c', 'o', 'n', 's', + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x44, 0x33, + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x05, 't', 'o', 'p', 'i', 'c', + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x52, 0x21, + 0x00, 0x00, 0x00, 0x00, 0xDE, 0xAD, 0xBE, 0xEF, + 0x00, 0x08, 'm', 'e', 't', 'a', 'd', 'a', 't', 'a'} +) + +func TestOffsetCommitRequestV0(t *testing.T) { + request := new(OffsetCommitRequest) + request.Version = 0 + request.ConsumerGroup = "foobar" + testRequest(t, "no blocks v0", request, offsetCommitRequestNoBlocksV0) + + request.AddBlock("topic", 0x5221, 0xDEADBEEF, 0, "metadata") + testRequest(t, "one block v0", request, offsetCommitRequestOneBlockV0) +} + +func TestOffsetCommitRequestV1(t *testing.T) { + request := new(OffsetCommitRequest) + request.ConsumerGroup = "foobar" + request.ConsumerID = "cons" + request.ConsumerGroupGeneration = 0x1122 + request.Version = 1 + testRequest(t, "no blocks v1", request, offsetCommitRequestNoBlocksV1) + + request.AddBlock("topic", 0x5221, 0xDEADBEEF, ReceiveTime, "metadata") + testRequest(t, "one block v1", request, offsetCommitRequestOneBlockV1) +} + +func TestOffsetCommitRequestV2(t *testing.T) { + request := new(OffsetCommitRequest) + request.ConsumerGroup = "foobar" + request.ConsumerID = "cons" + request.ConsumerGroupGeneration = 0x1122 + request.RetentionTime = 0x4433 + request.Version = 2 + testRequest(t, "no blocks v2", request, offsetCommitRequestNoBlocksV2) + + request.AddBlock("topic", 0x5221, 0xDEADBEEF, 0, "metadata") + testRequest(t, "one block v2", request, offsetCommitRequestOneBlockV2) +} diff --git a/vendor/github.com/Shopify/sarama/offset_commit_response.go b/vendor/github.com/Shopify/sarama/offset_commit_response.go new file mode 100644 index 00000000..7f277e77 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/offset_commit_response.go @@ -0,0 +1,85 @@ +package sarama + +type OffsetCommitResponse struct { + Errors map[string]map[int32]KError +} + +func (r *OffsetCommitResponse) AddError(topic string, partition int32, kerror KError) { + if r.Errors == nil { + r.Errors = make(map[string]map[int32]KError) + } + partitions := r.Errors[topic] + if partitions == nil { + partitions = make(map[int32]KError) + r.Errors[topic] = partitions + } + partitions[partition] = kerror +} + +func (r *OffsetCommitResponse) encode(pe packetEncoder) error { + if err := pe.putArrayLength(len(r.Errors)); err != nil { + return err + } + for topic, partitions := range r.Errors { + if err := pe.putString(topic); err != nil { + return err + } + if err := pe.putArrayLength(len(partitions)); err != nil { + return err + } + for partition, kerror := range partitions { + pe.putInt32(partition) + pe.putInt16(int16(kerror)) + } + } + return nil +} + +func (r *OffsetCommitResponse) decode(pd packetDecoder, version int16) (err error) { + numTopics, err := pd.getArrayLength() + if err != nil || numTopics == 0 { + return err + } + + r.Errors = make(map[string]map[int32]KError, numTopics) + for i := 0; i < numTopics; i++ { + name, err := pd.getString() + if err != nil { + return err + } + + numErrors, err := pd.getArrayLength() + if err != nil { + return err + } + + r.Errors[name] = make(map[int32]KError, numErrors) + + for j := 0; j < numErrors; j++ { + id, err := pd.getInt32() + if err != nil { + return err + } + + tmp, err := pd.getInt16() + if err != nil { + return err + } + r.Errors[name][id] = KError(tmp) + } + } + + return nil +} + +func (r *OffsetCommitResponse) key() int16 { + return 8 +} + +func (r *OffsetCommitResponse) version() int16 { + return 0 +} + +func (r *OffsetCommitResponse) requiredVersion() KafkaVersion { + return minVersion +} diff --git a/vendor/github.com/Shopify/sarama/offset_commit_response_test.go b/vendor/github.com/Shopify/sarama/offset_commit_response_test.go new file mode 100644 index 00000000..074ec923 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/offset_commit_response_test.go @@ -0,0 +1,24 @@ +package sarama + +import ( + "testing" +) + +var ( + emptyOffsetCommitResponse = []byte{ + 0x00, 0x00, 0x00, 0x00} +) + +func TestEmptyOffsetCommitResponse(t *testing.T) { + response := OffsetCommitResponse{} + testResponse(t, "empty", &response, emptyOffsetCommitResponse) +} + +func TestNormalOffsetCommitResponse(t *testing.T) { + response := OffsetCommitResponse{} + response.AddError("t", 0, ErrNotLeaderForPartition) + response.Errors["m"] = make(map[int32]KError) + // The response encoded form cannot be checked for it varies due to + // unpredictable map traversal order. + testResponse(t, "normal", &response, nil) +} diff --git a/vendor/github.com/Shopify/sarama/offset_fetch_request.go b/vendor/github.com/Shopify/sarama/offset_fetch_request.go new file mode 100644 index 00000000..b19fe79b --- /dev/null +++ b/vendor/github.com/Shopify/sarama/offset_fetch_request.go @@ -0,0 +1,81 @@ +package sarama + +type OffsetFetchRequest struct { + ConsumerGroup string + Version int16 + partitions map[string][]int32 +} + +func (r *OffsetFetchRequest) encode(pe packetEncoder) (err error) { + if r.Version < 0 || r.Version > 1 { + return PacketEncodingError{"invalid or unsupported OffsetFetchRequest version field"} + } + + if err = pe.putString(r.ConsumerGroup); err != nil { + return err + } + if err = pe.putArrayLength(len(r.partitions)); err != nil { + return err + } + for topic, partitions := range r.partitions { + if err = pe.putString(topic); err != nil { + return err + } + if err = pe.putInt32Array(partitions); err != nil { + return err + } + } + return nil +} + +func (r *OffsetFetchRequest) decode(pd packetDecoder, version int16) (err error) { + r.Version = version + if r.ConsumerGroup, err = pd.getString(); err != nil { + return err + } + partitionCount, err := pd.getArrayLength() + if err != nil { + return err + } + if partitionCount == 0 { + return nil + } + r.partitions = make(map[string][]int32) + for i := 0; i < partitionCount; i++ { + topic, err := pd.getString() + if err != nil { + return err + } + partitions, err := pd.getInt32Array() + if err != nil { + return err + } + r.partitions[topic] = partitions + } + return nil +} + +func (r *OffsetFetchRequest) key() int16 { + return 9 +} + +func (r *OffsetFetchRequest) version() int16 { + return r.Version +} + +func (r *OffsetFetchRequest) requiredVersion() KafkaVersion { + switch r.Version { + case 1: + return V0_8_2_0 + default: + return minVersion + } +} + +func (r *OffsetFetchRequest) AddPartition(topic string, partitionID int32) { + if r.partitions == nil { + r.partitions = make(map[string][]int32) + } + + r.partitions[topic] = append(r.partitions[topic], partitionID) +} diff --git a/vendor/github.com/Shopify/sarama/offset_fetch_request_test.go b/vendor/github.com/Shopify/sarama/offset_fetch_request_test.go new file mode 100644 index 00000000..025d725c --- /dev/null +++ b/vendor/github.com/Shopify/sarama/offset_fetch_request_test.go @@ -0,0 +1,31 @@ +package sarama + +import "testing" + +var ( + offsetFetchRequestNoGroupNoPartitions = []byte{ + 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00} + + offsetFetchRequestNoPartitions = []byte{ + 0x00, 0x04, 'b', 'l', 'a', 'h', + 0x00, 0x00, 0x00, 0x00} + + offsetFetchRequestOnePartition = []byte{ + 0x00, 0x04, 'b', 'l', 'a', 'h', + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x0D, 't', 'o', 'p', 'i', 'c', 'T', 'h', 'e', 'F', 'i', 'r', 's', 't', + 0x00, 0x00, 0x00, 0x01, + 0x4F, 0x4F, 0x4F, 0x4F} +) + +func TestOffsetFetchRequest(t *testing.T) { + request := new(OffsetFetchRequest) + testRequest(t, "no group, no partitions", request, offsetFetchRequestNoGroupNoPartitions) + + request.ConsumerGroup = "blah" + testRequest(t, "no partitions", request, offsetFetchRequestNoPartitions) + + request.AddPartition("topicTheFirst", 0x4F4F4F4F) + testRequest(t, "one partition", request, offsetFetchRequestOnePartition) +} diff --git a/vendor/github.com/Shopify/sarama/offset_fetch_response.go b/vendor/github.com/Shopify/sarama/offset_fetch_response.go new file mode 100644 index 00000000..323220ea --- /dev/null +++ b/vendor/github.com/Shopify/sarama/offset_fetch_response.go @@ -0,0 +1,143 @@ +package sarama + +type OffsetFetchResponseBlock struct { + Offset int64 + Metadata string + Err KError +} + +func (b *OffsetFetchResponseBlock) decode(pd packetDecoder) (err error) { + b.Offset, err = pd.getInt64() + if err != nil { + return err + } + + b.Metadata, err = pd.getString() + if err != nil { + return err + } + + tmp, err := pd.getInt16() + if err != nil { + return err + } + b.Err = KError(tmp) + + return nil +} + +func (b *OffsetFetchResponseBlock) encode(pe packetEncoder) (err error) { + pe.putInt64(b.Offset) + + err = pe.putString(b.Metadata) + if err != nil { + return err + } + + pe.putInt16(int16(b.Err)) + + return nil +} + +type OffsetFetchResponse struct { + Blocks map[string]map[int32]*OffsetFetchResponseBlock +} + +func (r *OffsetFetchResponse) encode(pe packetEncoder) error { + if err := pe.putArrayLength(len(r.Blocks)); err != nil { + return err + } + for topic, partitions := range r.Blocks { + if err := pe.putString(topic); err != nil { + return err + } + if err := pe.putArrayLength(len(partitions)); err != nil { + return err + } + for partition, block := range partitions { + pe.putInt32(partition) + if err := block.encode(pe); err != nil { + return err + } + } + } + return nil +} + +func (r *OffsetFetchResponse) decode(pd packetDecoder, version int16) (err error) { + numTopics, err := pd.getArrayLength() + if err != nil || numTopics == 0 { + return err + } + + r.Blocks = make(map[string]map[int32]*OffsetFetchResponseBlock, numTopics) + for i := 0; i < numTopics; i++ { + name, err := pd.getString() + if err != nil { + return err + } + + numBlocks, err := pd.getArrayLength() + if err != nil { + return err + } + + if numBlocks == 0 { + r.Blocks[name] = nil + continue + } + r.Blocks[name] = make(map[int32]*OffsetFetchResponseBlock, numBlocks) + + for j := 0; j < numBlocks; j++ { + id, err := pd.getInt32() + if err != nil { + return err + } + + block := new(OffsetFetchResponseBlock) + err = block.decode(pd) + if err != nil { + return err + } + r.Blocks[name][id] = block + } + } + + return nil +} + +func (r *OffsetFetchResponse) key() int16 { + return 9 +} + +func (r *OffsetFetchResponse) version() int16 { + return 0 +} + +func (r *OffsetFetchResponse) requiredVersion() KafkaVersion { + return minVersion +} + +func (r *OffsetFetchResponse) GetBlock(topic string, partition int32) *OffsetFetchResponseBlock { + if r.Blocks == nil { + return nil + } + + if r.Blocks[topic] == nil { + return nil + } + + return r.Blocks[topic][partition] +} + +func (r *OffsetFetchResponse) AddBlock(topic string, partition int32, block *OffsetFetchResponseBlock) { + if r.Blocks == nil { + r.Blocks = make(map[string]map[int32]*OffsetFetchResponseBlock) + } + partitions := r.Blocks[topic] + if partitions == nil { + partitions = make(map[int32]*OffsetFetchResponseBlock) + r.Blocks[topic] = partitions + } + partitions[partition] = block +} diff --git a/vendor/github.com/Shopify/sarama/offset_fetch_response_test.go b/vendor/github.com/Shopify/sarama/offset_fetch_response_test.go new file mode 100644 index 00000000..7614ae42 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/offset_fetch_response_test.go @@ -0,0 +1,22 @@ +package sarama + +import "testing" + +var ( + emptyOffsetFetchResponse = []byte{ + 0x00, 0x00, 0x00, 0x00} +) + +func TestEmptyOffsetFetchResponse(t *testing.T) { + response := OffsetFetchResponse{} + testResponse(t, "empty", &response, emptyOffsetFetchResponse) +} + +func TestNormalOffsetFetchResponse(t *testing.T) { + response := OffsetFetchResponse{} + response.AddBlock("t", 0, &OffsetFetchResponseBlock{0, "md", ErrRequestTimedOut}) + response.Blocks["m"] = nil + // The response encoded form cannot be checked for it varies due to + // unpredictable map traversal order. + testResponse(t, "normal", &response, nil) +} diff --git a/vendor/github.com/Shopify/sarama/offset_manager.go b/vendor/github.com/Shopify/sarama/offset_manager.go new file mode 100644 index 00000000..5e15cdaf --- /dev/null +++ b/vendor/github.com/Shopify/sarama/offset_manager.go @@ -0,0 +1,542 @@ +package sarama + +import ( + "sync" + "time" +) + +// Offset Manager + +// OffsetManager uses Kafka to store and fetch consumed partition offsets. +type OffsetManager interface { + // ManagePartition creates a PartitionOffsetManager on the given topic/partition. + // It will return an error if this OffsetManager is already managing the given + // topic/partition. + ManagePartition(topic string, partition int32) (PartitionOffsetManager, error) + + // Close stops the OffsetManager from managing offsets. It is required to call + // this function before an OffsetManager object passes out of scope, as it + // will otherwise leak memory. You must call this after all the + // PartitionOffsetManagers are closed. + Close() error +} + +type offsetManager struct { + client Client + conf *Config + group string + + lock sync.Mutex + poms map[string]map[int32]*partitionOffsetManager + boms map[*Broker]*brokerOffsetManager +} + +// NewOffsetManagerFromClient creates a new OffsetManager from the given client. +// It is still necessary to call Close() on the underlying client when finished with the partition manager. +func NewOffsetManagerFromClient(group string, client Client) (OffsetManager, error) { + // Check that we are not dealing with a closed Client before processing any other arguments + if client.Closed() { + return nil, ErrClosedClient + } + + om := &offsetManager{ + client: client, + conf: client.Config(), + group: group, + poms: make(map[string]map[int32]*partitionOffsetManager), + boms: make(map[*Broker]*brokerOffsetManager), + } + + return om, nil +} + +func (om *offsetManager) ManagePartition(topic string, partition int32) (PartitionOffsetManager, error) { + pom, err := om.newPartitionOffsetManager(topic, partition) + if err != nil { + return nil, err + } + + om.lock.Lock() + defer om.lock.Unlock() + + topicManagers := om.poms[topic] + if topicManagers == nil { + topicManagers = make(map[int32]*partitionOffsetManager) + om.poms[topic] = topicManagers + } + + if topicManagers[partition] != nil { + return nil, ConfigurationError("That topic/partition is already being managed") + } + + topicManagers[partition] = pom + return pom, nil +} + +func (om *offsetManager) Close() error { + return nil +} + +func (om *offsetManager) refBrokerOffsetManager(broker *Broker) *brokerOffsetManager { + om.lock.Lock() + defer om.lock.Unlock() + + bom := om.boms[broker] + if bom == nil { + bom = om.newBrokerOffsetManager(broker) + om.boms[broker] = bom + } + + bom.refs++ + + return bom +} + +func (om *offsetManager) unrefBrokerOffsetManager(bom *brokerOffsetManager) { + om.lock.Lock() + defer om.lock.Unlock() + + bom.refs-- + + if bom.refs == 0 { + close(bom.updateSubscriptions) + if om.boms[bom.broker] == bom { + delete(om.boms, bom.broker) + } + } +} + +func (om *offsetManager) abandonBroker(bom *brokerOffsetManager) { + om.lock.Lock() + defer om.lock.Unlock() + + delete(om.boms, bom.broker) +} + +func (om *offsetManager) abandonPartitionOffsetManager(pom *partitionOffsetManager) { + om.lock.Lock() + defer om.lock.Unlock() + + delete(om.poms[pom.topic], pom.partition) + if len(om.poms[pom.topic]) == 0 { + delete(om.poms, pom.topic) + } +} + +// Partition Offset Manager + +// PartitionOffsetManager uses Kafka to store and fetch consumed partition offsets. You MUST call Close() +// on a partition offset manager to avoid leaks, it will not be garbage-collected automatically when it passes +// out of scope. +type PartitionOffsetManager interface { + // NextOffset returns the next offset that should be consumed for the managed + // partition, accompanied by metadata which can be used to reconstruct the state + // of the partition consumer when it resumes. NextOffset() will return + // `config.Consumer.Offsets.Initial` and an empty metadata string if no offset + // was committed for this partition yet. + NextOffset() (int64, string) + + // MarkOffset marks the provided offset, alongside a metadata string + // that represents the state of the partition consumer at that point in time. The + // metadata string can be used by another consumer to restore that state, so it + // can resume consumption. + // + // To follow upstream conventions, you are expected to mark the offset of the + // next message to read, not the last message read. Thus, when calling `MarkOffset` + // you should typically add one to the offset of the last consumed message. + // + // Note: calling MarkOffset does not necessarily commit the offset to the backend + // store immediately for efficiency reasons, and it may never be committed if + // your application crashes. This means that you may end up processing the same + // message twice, and your processing should ideally be idempotent. + MarkOffset(offset int64, metadata string) + + // Errors returns a read channel of errors that occur during offset management, if + // enabled. By default, errors are logged and not returned over this channel. If + // you want to implement any custom error handling, set your config's + // Consumer.Return.Errors setting to true, and read from this channel. + Errors() <-chan *ConsumerError + + // AsyncClose initiates a shutdown of the PartitionOffsetManager. This method will + // return immediately, after which you should wait until the 'errors' channel has + // been drained and closed. It is required to call this function, or Close before + // a consumer object passes out of scope, as it will otherwise leak memory. You + // must call this before calling Close on the underlying client. + AsyncClose() + + // Close stops the PartitionOffsetManager from managing offsets. It is required to + // call this function (or AsyncClose) before a PartitionOffsetManager object + // passes out of scope, as it will otherwise leak memory. You must call this + // before calling Close on the underlying client. + Close() error +} + +type partitionOffsetManager struct { + parent *offsetManager + topic string + partition int32 + + lock sync.Mutex + offset int64 + metadata string + dirty bool + clean sync.Cond + broker *brokerOffsetManager + + errors chan *ConsumerError + rebalance chan none + dying chan none +} + +func (om *offsetManager) newPartitionOffsetManager(topic string, partition int32) (*partitionOffsetManager, error) { + pom := &partitionOffsetManager{ + parent: om, + topic: topic, + partition: partition, + errors: make(chan *ConsumerError, om.conf.ChannelBufferSize), + rebalance: make(chan none, 1), + dying: make(chan none), + } + pom.clean.L = &pom.lock + + if err := pom.selectBroker(); err != nil { + return nil, err + } + + if err := pom.fetchInitialOffset(om.conf.Metadata.Retry.Max); err != nil { + return nil, err + } + + pom.broker.updateSubscriptions <- pom + + go withRecover(pom.mainLoop) + + return pom, nil +} + +func (pom *partitionOffsetManager) mainLoop() { + for { + select { + case <-pom.rebalance: + if err := pom.selectBroker(); err != nil { + pom.handleError(err) + pom.rebalance <- none{} + } else { + pom.broker.updateSubscriptions <- pom + } + case <-pom.dying: + if pom.broker != nil { + select { + case <-pom.rebalance: + case pom.broker.updateSubscriptions <- pom: + } + pom.parent.unrefBrokerOffsetManager(pom.broker) + } + pom.parent.abandonPartitionOffsetManager(pom) + close(pom.errors) + return + } + } +} + +func (pom *partitionOffsetManager) selectBroker() error { + if pom.broker != nil { + pom.parent.unrefBrokerOffsetManager(pom.broker) + pom.broker = nil + } + + var broker *Broker + var err error + + if err = pom.parent.client.RefreshCoordinator(pom.parent.group); err != nil { + return err + } + + if broker, err = pom.parent.client.Coordinator(pom.parent.group); err != nil { + return err + } + + pom.broker = pom.parent.refBrokerOffsetManager(broker) + return nil +} + +func (pom *partitionOffsetManager) fetchInitialOffset(retries int) error { + request := new(OffsetFetchRequest) + request.Version = 1 + request.ConsumerGroup = pom.parent.group + request.AddPartition(pom.topic, pom.partition) + + response, err := pom.broker.broker.FetchOffset(request) + if err != nil { + return err + } + + block := response.GetBlock(pom.topic, pom.partition) + if block == nil { + return ErrIncompleteResponse + } + + switch block.Err { + case ErrNoError: + pom.offset = block.Offset + pom.metadata = block.Metadata + return nil + case ErrNotCoordinatorForConsumer: + if retries <= 0 { + return block.Err + } + if err := pom.selectBroker(); err != nil { + return err + } + return pom.fetchInitialOffset(retries - 1) + case ErrOffsetsLoadInProgress: + if retries <= 0 { + return block.Err + } + time.Sleep(pom.parent.conf.Metadata.Retry.Backoff) + return pom.fetchInitialOffset(retries - 1) + default: + return block.Err + } +} + +func (pom *partitionOffsetManager) handleError(err error) { + cErr := &ConsumerError{ + Topic: pom.topic, + Partition: pom.partition, + Err: err, + } + + if pom.parent.conf.Consumer.Return.Errors { + pom.errors <- cErr + } else { + Logger.Println(cErr) + } +} + +func (pom *partitionOffsetManager) Errors() <-chan *ConsumerError { + return pom.errors +} + +func (pom *partitionOffsetManager) MarkOffset(offset int64, metadata string) { + pom.lock.Lock() + defer pom.lock.Unlock() + + if offset > pom.offset { + pom.offset = offset + pom.metadata = metadata + pom.dirty = true + } +} + +func (pom *partitionOffsetManager) updateCommitted(offset int64, metadata string) { + pom.lock.Lock() + defer pom.lock.Unlock() + + if pom.offset == offset && pom.metadata == metadata { + pom.dirty = false + pom.clean.Signal() + } +} + +func (pom *partitionOffsetManager) NextOffset() (int64, string) { + pom.lock.Lock() + defer pom.lock.Unlock() + + if pom.offset >= 0 { + return pom.offset, pom.metadata + } + + return pom.parent.conf.Consumer.Offsets.Initial, "" +} + +func (pom *partitionOffsetManager) AsyncClose() { + go func() { + pom.lock.Lock() + defer pom.lock.Unlock() + + for pom.dirty { + pom.clean.Wait() + } + + close(pom.dying) + }() +} + +func (pom *partitionOffsetManager) Close() error { + pom.AsyncClose() + + var errors ConsumerErrors + for err := range pom.errors { + errors = append(errors, err) + } + + if len(errors) > 0 { + return errors + } + return nil +} + +// Broker Offset Manager + +type brokerOffsetManager struct { + parent *offsetManager + broker *Broker + timer *time.Ticker + updateSubscriptions chan *partitionOffsetManager + subscriptions map[*partitionOffsetManager]none + refs int +} + +func (om *offsetManager) newBrokerOffsetManager(broker *Broker) *brokerOffsetManager { + bom := &brokerOffsetManager{ + parent: om, + broker: broker, + timer: time.NewTicker(om.conf.Consumer.Offsets.CommitInterval), + updateSubscriptions: make(chan *partitionOffsetManager), + subscriptions: make(map[*partitionOffsetManager]none), + } + + go withRecover(bom.mainLoop) + + return bom +} + +func (bom *brokerOffsetManager) mainLoop() { + for { + select { + case <-bom.timer.C: + if len(bom.subscriptions) > 0 { + bom.flushToBroker() + } + case s, ok := <-bom.updateSubscriptions: + if !ok { + bom.timer.Stop() + return + } + if _, ok := bom.subscriptions[s]; ok { + delete(bom.subscriptions, s) + } else { + bom.subscriptions[s] = none{} + } + } + } +} + +func (bom *brokerOffsetManager) flushToBroker() { + request := bom.constructRequest() + if request == nil { + return + } + + response, err := bom.broker.CommitOffset(request) + + if err != nil { + bom.abort(err) + return + } + + for s := range bom.subscriptions { + if request.blocks[s.topic] == nil || request.blocks[s.topic][s.partition] == nil { + continue + } + + var err KError + var ok bool + + if response.Errors[s.topic] == nil { + s.handleError(ErrIncompleteResponse) + delete(bom.subscriptions, s) + s.rebalance <- none{} + continue + } + if err, ok = response.Errors[s.topic][s.partition]; !ok { + s.handleError(ErrIncompleteResponse) + delete(bom.subscriptions, s) + s.rebalance <- none{} + continue + } + + switch err { + case ErrNoError: + block := request.blocks[s.topic][s.partition] + s.updateCommitted(block.offset, block.metadata) + case ErrNotLeaderForPartition, ErrLeaderNotAvailable, + ErrConsumerCoordinatorNotAvailable, ErrNotCoordinatorForConsumer: + // not a critical error, we just need to redispatch + delete(bom.subscriptions, s) + s.rebalance <- none{} + case ErrOffsetMetadataTooLarge, ErrInvalidCommitOffsetSize: + // nothing we can do about this, just tell the user and carry on + s.handleError(err) + case ErrOffsetsLoadInProgress: + // nothing wrong but we didn't commit, we'll get it next time round + break + case ErrUnknownTopicOrPartition: + // let the user know *and* try redispatching - if topic-auto-create is + // enabled, redispatching should trigger a metadata request and create the + // topic; if not then re-dispatching won't help, but we've let the user + // know and it shouldn't hurt either (see https://github.com/Shopify/sarama/issues/706) + fallthrough + default: + // dunno, tell the user and try redispatching + s.handleError(err) + delete(bom.subscriptions, s) + s.rebalance <- none{} + } + } +} + +func (bom *brokerOffsetManager) constructRequest() *OffsetCommitRequest { + var r *OffsetCommitRequest + var perPartitionTimestamp int64 + if bom.parent.conf.Consumer.Offsets.Retention == 0 { + perPartitionTimestamp = ReceiveTime + r = &OffsetCommitRequest{ + Version: 1, + ConsumerGroup: bom.parent.group, + ConsumerGroupGeneration: GroupGenerationUndefined, + } + } else { + r = &OffsetCommitRequest{ + Version: 2, + RetentionTime: int64(bom.parent.conf.Consumer.Offsets.Retention / time.Millisecond), + ConsumerGroup: bom.parent.group, + ConsumerGroupGeneration: GroupGenerationUndefined, + } + + } + + for s := range bom.subscriptions { + s.lock.Lock() + if s.dirty { + r.AddBlock(s.topic, s.partition, s.offset, perPartitionTimestamp, s.metadata) + } + s.lock.Unlock() + } + + if len(r.blocks) > 0 { + return r + } + + return nil +} + +func (bom *brokerOffsetManager) abort(err error) { + _ = bom.broker.Close() // we don't care about the error this might return, we already have one + bom.parent.abandonBroker(bom) + + for pom := range bom.subscriptions { + pom.handleError(err) + pom.rebalance <- none{} + } + + for s := range bom.updateSubscriptions { + if _, ok := bom.subscriptions[s]; !ok { + s.handleError(err) + s.rebalance <- none{} + } + } + + bom.subscriptions = make(map[*partitionOffsetManager]none) +} diff --git a/vendor/github.com/Shopify/sarama/offset_manager_test.go b/vendor/github.com/Shopify/sarama/offset_manager_test.go new file mode 100644 index 00000000..c111a5a6 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/offset_manager_test.go @@ -0,0 +1,369 @@ +package sarama + +import ( + "testing" + "time" +) + +func initOffsetManager(t *testing.T) (om OffsetManager, + testClient Client, broker, coordinator *MockBroker) { + + config := NewConfig() + config.Metadata.Retry.Max = 1 + config.Consumer.Offsets.CommitInterval = 1 * time.Millisecond + config.Version = V0_9_0_0 + + broker = NewMockBroker(t, 1) + coordinator = NewMockBroker(t, 2) + + seedMeta := new(MetadataResponse) + seedMeta.AddBroker(coordinator.Addr(), coordinator.BrokerID()) + seedMeta.AddTopicPartition("my_topic", 0, 1, []int32{}, []int32{}, ErrNoError) + seedMeta.AddTopicPartition("my_topic", 1, 1, []int32{}, []int32{}, ErrNoError) + broker.Returns(seedMeta) + + var err error + testClient, err = NewClient([]string{broker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + broker.Returns(&ConsumerMetadataResponse{ + CoordinatorID: coordinator.BrokerID(), + CoordinatorHost: "127.0.0.1", + CoordinatorPort: coordinator.Port(), + }) + + om, err = NewOffsetManagerFromClient("group", testClient) + if err != nil { + t.Fatal(err) + } + + return om, testClient, broker, coordinator +} + +func initPartitionOffsetManager(t *testing.T, om OffsetManager, + coordinator *MockBroker, initialOffset int64, metadata string) PartitionOffsetManager { + + fetchResponse := new(OffsetFetchResponse) + fetchResponse.AddBlock("my_topic", 0, &OffsetFetchResponseBlock{ + Err: ErrNoError, + Offset: initialOffset, + Metadata: metadata, + }) + coordinator.Returns(fetchResponse) + + pom, err := om.ManagePartition("my_topic", 0) + if err != nil { + t.Fatal(err) + } + + return pom +} + +func TestNewOffsetManager(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + seedBroker.Returns(new(MetadataResponse)) + + testClient, err := NewClient([]string{seedBroker.Addr()}, nil) + if err != nil { + t.Fatal(err) + } + + _, err = NewOffsetManagerFromClient("group", testClient) + if err != nil { + t.Error(err) + } + + safeClose(t, testClient) + + _, err = NewOffsetManagerFromClient("group", testClient) + if err != ErrClosedClient { + t.Errorf("Error expected for closed client; actual value: %v", err) + } + + seedBroker.Close() +} + +// Test recovery from ErrNotCoordinatorForConsumer +// on first fetchInitialOffset call +func TestOffsetManagerFetchInitialFail(t *testing.T) { + om, testClient, broker, coordinator := initOffsetManager(t) + + // Error on first fetchInitialOffset call + responseBlock := OffsetFetchResponseBlock{ + Err: ErrNotCoordinatorForConsumer, + Offset: 5, + Metadata: "test_meta", + } + + fetchResponse := new(OffsetFetchResponse) + fetchResponse.AddBlock("my_topic", 0, &responseBlock) + coordinator.Returns(fetchResponse) + + // Refresh coordinator + newCoordinator := NewMockBroker(t, 3) + broker.Returns(&ConsumerMetadataResponse{ + CoordinatorID: newCoordinator.BrokerID(), + CoordinatorHost: "127.0.0.1", + CoordinatorPort: newCoordinator.Port(), + }) + + // Second fetchInitialOffset call is fine + fetchResponse2 := new(OffsetFetchResponse) + responseBlock2 := responseBlock + responseBlock2.Err = ErrNoError + fetchResponse2.AddBlock("my_topic", 0, &responseBlock2) + newCoordinator.Returns(fetchResponse2) + + pom, err := om.ManagePartition("my_topic", 0) + if err != nil { + t.Error(err) + } + + broker.Close() + coordinator.Close() + newCoordinator.Close() + safeClose(t, pom) + safeClose(t, om) + safeClose(t, testClient) +} + +// Test fetchInitialOffset retry on ErrOffsetsLoadInProgress +func TestOffsetManagerFetchInitialLoadInProgress(t *testing.T) { + om, testClient, broker, coordinator := initOffsetManager(t) + + // Error on first fetchInitialOffset call + responseBlock := OffsetFetchResponseBlock{ + Err: ErrOffsetsLoadInProgress, + Offset: 5, + Metadata: "test_meta", + } + + fetchResponse := new(OffsetFetchResponse) + fetchResponse.AddBlock("my_topic", 0, &responseBlock) + coordinator.Returns(fetchResponse) + + // Second fetchInitialOffset call is fine + fetchResponse2 := new(OffsetFetchResponse) + responseBlock2 := responseBlock + responseBlock2.Err = ErrNoError + fetchResponse2.AddBlock("my_topic", 0, &responseBlock2) + coordinator.Returns(fetchResponse2) + + pom, err := om.ManagePartition("my_topic", 0) + if err != nil { + t.Error(err) + } + + broker.Close() + coordinator.Close() + safeClose(t, pom) + safeClose(t, om) + safeClose(t, testClient) +} + +func TestPartitionOffsetManagerInitialOffset(t *testing.T) { + om, testClient, broker, coordinator := initOffsetManager(t) + testClient.Config().Consumer.Offsets.Initial = OffsetOldest + + // Kafka returns -1 if no offset has been stored for this partition yet. + pom := initPartitionOffsetManager(t, om, coordinator, -1, "") + + offset, meta := pom.NextOffset() + if offset != OffsetOldest { + t.Errorf("Expected offset 5. Actual: %v", offset) + } + if meta != "" { + t.Errorf("Expected metadata to be empty. Actual: %q", meta) + } + + safeClose(t, pom) + safeClose(t, om) + broker.Close() + coordinator.Close() + safeClose(t, testClient) +} + +func TestPartitionOffsetManagerNextOffset(t *testing.T) { + om, testClient, broker, coordinator := initOffsetManager(t) + pom := initPartitionOffsetManager(t, om, coordinator, 5, "test_meta") + + offset, meta := pom.NextOffset() + if offset != 5 { + t.Errorf("Expected offset 5. Actual: %v", offset) + } + if meta != "test_meta" { + t.Errorf("Expected metadata \"test_meta\". Actual: %q", meta) + } + + safeClose(t, pom) + safeClose(t, om) + broker.Close() + coordinator.Close() + safeClose(t, testClient) +} + +func TestPartitionOffsetManagerMarkOffset(t *testing.T) { + om, testClient, broker, coordinator := initOffsetManager(t) + pom := initPartitionOffsetManager(t, om, coordinator, 5, "original_meta") + + ocResponse := new(OffsetCommitResponse) + ocResponse.AddError("my_topic", 0, ErrNoError) + coordinator.Returns(ocResponse) + + pom.MarkOffset(100, "modified_meta") + offset, meta := pom.NextOffset() + + if offset != 100 { + t.Errorf("Expected offset 100. Actual: %v", offset) + } + if meta != "modified_meta" { + t.Errorf("Expected metadata \"modified_meta\". Actual: %q", meta) + } + + safeClose(t, pom) + safeClose(t, om) + safeClose(t, testClient) + broker.Close() + coordinator.Close() +} + +func TestPartitionOffsetManagerMarkOffsetWithRetention(t *testing.T) { + om, testClient, broker, coordinator := initOffsetManager(t) + testClient.Config().Consumer.Offsets.Retention = time.Hour + + pom := initPartitionOffsetManager(t, om, coordinator, 5, "original_meta") + + ocResponse := new(OffsetCommitResponse) + ocResponse.AddError("my_topic", 0, ErrNoError) + handler := func(req *request) (res encoder) { + if req.body.version() != 2 { + t.Errorf("Expected to be using version 2. Actual: %v", req.body.version()) + } + offsetCommitRequest := req.body.(*OffsetCommitRequest) + if offsetCommitRequest.RetentionTime != (60 * 60 * 1000) { + t.Errorf("Expected an hour retention time. Actual: %v", offsetCommitRequest.RetentionTime) + } + return ocResponse + } + coordinator.setHandler(handler) + + pom.MarkOffset(100, "modified_meta") + offset, meta := pom.NextOffset() + + if offset != 100 { + t.Errorf("Expected offset 100. Actual: %v", offset) + } + if meta != "modified_meta" { + t.Errorf("Expected metadata \"modified_meta\". Actual: %q", meta) + } + + safeClose(t, pom) + safeClose(t, om) + safeClose(t, testClient) + broker.Close() + coordinator.Close() +} + +func TestPartitionOffsetManagerCommitErr(t *testing.T) { + om, testClient, broker, coordinator := initOffsetManager(t) + pom := initPartitionOffsetManager(t, om, coordinator, 5, "meta") + + // Error on one partition + ocResponse := new(OffsetCommitResponse) + ocResponse.AddError("my_topic", 0, ErrOffsetOutOfRange) + ocResponse.AddError("my_topic", 1, ErrNoError) + coordinator.Returns(ocResponse) + + newCoordinator := NewMockBroker(t, 3) + + // For RefreshCoordinator() + broker.Returns(&ConsumerMetadataResponse{ + CoordinatorID: newCoordinator.BrokerID(), + CoordinatorHost: "127.0.0.1", + CoordinatorPort: newCoordinator.Port(), + }) + + // Nothing in response.Errors at all + ocResponse2 := new(OffsetCommitResponse) + newCoordinator.Returns(ocResponse2) + + // For RefreshCoordinator() + broker.Returns(&ConsumerMetadataResponse{ + CoordinatorID: newCoordinator.BrokerID(), + CoordinatorHost: "127.0.0.1", + CoordinatorPort: newCoordinator.Port(), + }) + + // Error on the wrong partition for this pom + ocResponse3 := new(OffsetCommitResponse) + ocResponse3.AddError("my_topic", 1, ErrNoError) + newCoordinator.Returns(ocResponse3) + + // For RefreshCoordinator() + broker.Returns(&ConsumerMetadataResponse{ + CoordinatorID: newCoordinator.BrokerID(), + CoordinatorHost: "127.0.0.1", + CoordinatorPort: newCoordinator.Port(), + }) + + // ErrUnknownTopicOrPartition/ErrNotLeaderForPartition/ErrLeaderNotAvailable block + ocResponse4 := new(OffsetCommitResponse) + ocResponse4.AddError("my_topic", 0, ErrUnknownTopicOrPartition) + newCoordinator.Returns(ocResponse4) + + // For RefreshCoordinator() + broker.Returns(&ConsumerMetadataResponse{ + CoordinatorID: newCoordinator.BrokerID(), + CoordinatorHost: "127.0.0.1", + CoordinatorPort: newCoordinator.Port(), + }) + + // Normal error response + ocResponse5 := new(OffsetCommitResponse) + ocResponse5.AddError("my_topic", 0, ErrNoError) + newCoordinator.Returns(ocResponse5) + + pom.MarkOffset(100, "modified_meta") + + err := pom.Close() + if err != nil { + t.Error(err) + } + + broker.Close() + coordinator.Close() + newCoordinator.Close() + safeClose(t, om) + safeClose(t, testClient) +} + +// Test of recovery from abort +func TestAbortPartitionOffsetManager(t *testing.T) { + om, testClient, broker, coordinator := initOffsetManager(t) + pom := initPartitionOffsetManager(t, om, coordinator, 5, "meta") + + // this triggers an error in the CommitOffset request, + // which leads to the abort call + coordinator.Close() + + // Response to refresh coordinator request + newCoordinator := NewMockBroker(t, 3) + broker.Returns(&ConsumerMetadataResponse{ + CoordinatorID: newCoordinator.BrokerID(), + CoordinatorHost: "127.0.0.1", + CoordinatorPort: newCoordinator.Port(), + }) + + ocResponse := new(OffsetCommitResponse) + ocResponse.AddError("my_topic", 0, ErrNoError) + newCoordinator.Returns(ocResponse) + + pom.MarkOffset(100, "modified_meta") + + safeClose(t, pom) + safeClose(t, om) + broker.Close() + safeClose(t, testClient) +} diff --git a/vendor/github.com/Shopify/sarama/offset_request.go b/vendor/github.com/Shopify/sarama/offset_request.go new file mode 100644 index 00000000..6c269601 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/offset_request.go @@ -0,0 +1,132 @@ +package sarama + +type offsetRequestBlock struct { + time int64 + maxOffsets int32 // Only used in version 0 +} + +func (b *offsetRequestBlock) encode(pe packetEncoder, version int16) error { + pe.putInt64(int64(b.time)) + if version == 0 { + pe.putInt32(b.maxOffsets) + } + + return nil +} + +func (b *offsetRequestBlock) decode(pd packetDecoder, version int16) (err error) { + if b.time, err = pd.getInt64(); err != nil { + return err + } + if version == 0 { + if b.maxOffsets, err = pd.getInt32(); err != nil { + return err + } + } + return nil +} + +type OffsetRequest struct { + Version int16 + blocks map[string]map[int32]*offsetRequestBlock +} + +func (r *OffsetRequest) encode(pe packetEncoder) error { + pe.putInt32(-1) // replica ID is always -1 for clients + err := pe.putArrayLength(len(r.blocks)) + if err != nil { + return err + } + for topic, partitions := range r.blocks { + err = pe.putString(topic) + if err != nil { + return err + } + err = pe.putArrayLength(len(partitions)) + if err != nil { + return err + } + for partition, block := range partitions { + pe.putInt32(partition) + if err = block.encode(pe, r.Version); err != nil { + return err + } + } + } + return nil +} + +func (r *OffsetRequest) decode(pd packetDecoder, version int16) error { + r.Version = version + + // Ignore replica ID + if _, err := pd.getInt32(); err != nil { + return err + } + blockCount, err := pd.getArrayLength() + if err != nil { + return err + } + if blockCount == 0 { + return nil + } + r.blocks = make(map[string]map[int32]*offsetRequestBlock) + for i := 0; i < blockCount; i++ { + topic, err := pd.getString() + if err != nil { + return err + } + partitionCount, err := pd.getArrayLength() + if err != nil { + return err + } + r.blocks[topic] = make(map[int32]*offsetRequestBlock) + for j := 0; j < partitionCount; j++ { + partition, err := pd.getInt32() + if err != nil { + return err + } + block := &offsetRequestBlock{} + if err := block.decode(pd, version); err != nil { + return err + } + r.blocks[topic][partition] = block + } + } + return nil +} + +func (r *OffsetRequest) key() int16 { + return 2 +} + +func (r *OffsetRequest) version() int16 { + return r.Version +} + +func (r *OffsetRequest) requiredVersion() KafkaVersion { + switch r.Version { + case 1: + return V0_10_1_0 + default: + return minVersion + } +} + +func (r *OffsetRequest) AddBlock(topic string, partitionID int32, time int64, maxOffsets int32) { + if r.blocks == nil { + r.blocks = make(map[string]map[int32]*offsetRequestBlock) + } + + if r.blocks[topic] == nil { + r.blocks[topic] = make(map[int32]*offsetRequestBlock) + } + + tmp := new(offsetRequestBlock) + tmp.time = time + if r.Version == 0 { + tmp.maxOffsets = maxOffsets + } + + r.blocks[topic][partitionID] = tmp +} diff --git a/vendor/github.com/Shopify/sarama/offset_request_test.go b/vendor/github.com/Shopify/sarama/offset_request_test.go new file mode 100644 index 00000000..9ce562c9 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/offset_request_test.go @@ -0,0 +1,43 @@ +package sarama + +import "testing" + +var ( + offsetRequestNoBlocks = []byte{ + 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00} + + offsetRequestOneBlock = []byte{ + 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x03, 'f', 'o', 'o', + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x04, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x02} + + offsetRequestOneBlockV1 = []byte{ + 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x03, 'b', 'a', 'r', + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x04, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01} +) + +func TestOffsetRequest(t *testing.T) { + request := new(OffsetRequest) + testRequest(t, "no blocks", request, offsetRequestNoBlocks) + + request.AddBlock("foo", 4, 1, 2) + testRequest(t, "one block", request, offsetRequestOneBlock) +} + +func TestOffsetRequestV1(t *testing.T) { + request := new(OffsetRequest) + request.Version = 1 + testRequest(t, "no blocks", request, offsetRequestNoBlocks) + + request.AddBlock("bar", 4, 1, 2) // Last argument is ignored for V1 + testRequest(t, "one block", request, offsetRequestOneBlockV1) +} diff --git a/vendor/github.com/Shopify/sarama/offset_response.go b/vendor/github.com/Shopify/sarama/offset_response.go new file mode 100644 index 00000000..9a9cfe96 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/offset_response.go @@ -0,0 +1,174 @@ +package sarama + +type OffsetResponseBlock struct { + Err KError + Offsets []int64 // Version 0 + Offset int64 // Version 1 + Timestamp int64 // Version 1 +} + +func (b *OffsetResponseBlock) decode(pd packetDecoder, version int16) (err error) { + tmp, err := pd.getInt16() + if err != nil { + return err + } + b.Err = KError(tmp) + + if version == 0 { + b.Offsets, err = pd.getInt64Array() + + return err + } + + b.Timestamp, err = pd.getInt64() + if err != nil { + return err + } + + b.Offset, err = pd.getInt64() + if err != nil { + return err + } + + // For backwards compatibility put the offset in the offsets array too + b.Offsets = []int64{b.Offset} + + return nil +} + +func (b *OffsetResponseBlock) encode(pe packetEncoder, version int16) (err error) { + pe.putInt16(int16(b.Err)) + + if version == 0 { + return pe.putInt64Array(b.Offsets) + } + + pe.putInt64(b.Timestamp) + pe.putInt64(b.Offset) + + return nil +} + +type OffsetResponse struct { + Version int16 + Blocks map[string]map[int32]*OffsetResponseBlock +} + +func (r *OffsetResponse) decode(pd packetDecoder, version int16) (err error) { + numTopics, err := pd.getArrayLength() + if err != nil { + return err + } + + r.Blocks = make(map[string]map[int32]*OffsetResponseBlock, numTopics) + for i := 0; i < numTopics; i++ { + name, err := pd.getString() + if err != nil { + return err + } + + numBlocks, err := pd.getArrayLength() + if err != nil { + return err + } + + r.Blocks[name] = make(map[int32]*OffsetResponseBlock, numBlocks) + + for j := 0; j < numBlocks; j++ { + id, err := pd.getInt32() + if err != nil { + return err + } + + block := new(OffsetResponseBlock) + err = block.decode(pd, version) + if err != nil { + return err + } + r.Blocks[name][id] = block + } + } + + return nil +} + +func (r *OffsetResponse) GetBlock(topic string, partition int32) *OffsetResponseBlock { + if r.Blocks == nil { + return nil + } + + if r.Blocks[topic] == nil { + return nil + } + + return r.Blocks[topic][partition] +} + +/* +// [0 0 0 1 ntopics +0 8 109 121 95 116 111 112 105 99 topic +0 0 0 1 npartitions +0 0 0 0 id +0 0 + +0 0 0 1 0 0 0 0 +0 1 1 1 0 0 0 1 +0 8 109 121 95 116 111 112 +105 99 0 0 0 1 0 0 +0 0 0 0 0 0 0 1 +0 0 0 0 0 1 1 1] + +*/ +func (r *OffsetResponse) encode(pe packetEncoder) (err error) { + if err = pe.putArrayLength(len(r.Blocks)); err != nil { + return err + } + + for topic, partitions := range r.Blocks { + if err = pe.putString(topic); err != nil { + return err + } + if err = pe.putArrayLength(len(partitions)); err != nil { + return err + } + for partition, block := range partitions { + pe.putInt32(partition) + if err = block.encode(pe, r.version()); err != nil { + return err + } + } + } + + return nil +} + +func (r *OffsetResponse) key() int16 { + return 2 +} + +func (r *OffsetResponse) version() int16 { + return r.Version +} + +func (r *OffsetResponse) requiredVersion() KafkaVersion { + switch r.Version { + case 1: + return V0_10_1_0 + default: + return minVersion + } +} + +// testing API + +func (r *OffsetResponse) AddTopicPartition(topic string, partition int32, offset int64) { + if r.Blocks == nil { + r.Blocks = make(map[string]map[int32]*OffsetResponseBlock) + } + byTopic, ok := r.Blocks[topic] + if !ok { + byTopic = make(map[int32]*OffsetResponseBlock) + r.Blocks[topic] = byTopic + } + byTopic[partition] = &OffsetResponseBlock{Offsets: []int64{offset}, Offset: offset} +} diff --git a/vendor/github.com/Shopify/sarama/offset_response_test.go b/vendor/github.com/Shopify/sarama/offset_response_test.go new file mode 100644 index 00000000..0df6c9f3 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/offset_response_test.go @@ -0,0 +1,111 @@ +package sarama + +import "testing" + +var ( + emptyOffsetResponse = []byte{ + 0x00, 0x00, 0x00, 0x00} + + normalOffsetResponse = []byte{ + 0x00, 0x00, 0x00, 0x02, + + 0x00, 0x01, 'a', + 0x00, 0x00, 0x00, 0x00, + + 0x00, 0x01, 'z', + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x02, + 0x00, 0x00, + 0x00, 0x00, 0x00, 0x02, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06} + + normalOffsetResponseV1 = []byte{ + 0x00, 0x00, 0x00, 0x02, + + 0x00, 0x01, 'a', + 0x00, 0x00, 0x00, 0x00, + + 0x00, 0x01, 'z', + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x02, + 0x00, 0x00, + 0x00, 0x00, 0x01, 0x58, 0x1A, 0xE6, 0x48, 0x86, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06} +) + +func TestEmptyOffsetResponse(t *testing.T) { + response := OffsetResponse{} + + testVersionDecodable(t, "empty", &response, emptyOffsetResponse, 0) + if len(response.Blocks) != 0 { + t.Error("Decoding produced", len(response.Blocks), "topics where there were none.") + } + + response = OffsetResponse{} + + testVersionDecodable(t, "empty", &response, emptyOffsetResponse, 1) + if len(response.Blocks) != 0 { + t.Error("Decoding produced", len(response.Blocks), "topics where there were none.") + } +} + +func TestNormalOffsetResponse(t *testing.T) { + response := OffsetResponse{} + + testVersionDecodable(t, "normal", &response, normalOffsetResponse, 0) + + if len(response.Blocks) != 2 { + t.Fatal("Decoding produced", len(response.Blocks), "topics where there were two.") + } + + if len(response.Blocks["a"]) != 0 { + t.Fatal("Decoding produced", len(response.Blocks["a"]), "partitions for topic 'a' where there were none.") + } + + if len(response.Blocks["z"]) != 1 { + t.Fatal("Decoding produced", len(response.Blocks["z"]), "partitions for topic 'z' where there was one.") + } + + if response.Blocks["z"][2].Err != ErrNoError { + t.Fatal("Decoding produced invalid error for topic z partition 2.") + } + + if len(response.Blocks["z"][2].Offsets) != 2 { + t.Fatal("Decoding produced invalid number of offsets for topic z partition 2.") + } + + if response.Blocks["z"][2].Offsets[0] != 5 || response.Blocks["z"][2].Offsets[1] != 6 { + t.Fatal("Decoding produced invalid offsets for topic z partition 2.") + } +} + +func TestNormalOffsetResponseV1(t *testing.T) { + response := OffsetResponse{} + + testVersionDecodable(t, "normal", &response, normalOffsetResponseV1, 1) + + if len(response.Blocks) != 2 { + t.Fatal("Decoding produced", len(response.Blocks), "topics where there were two.") + } + + if len(response.Blocks["a"]) != 0 { + t.Fatal("Decoding produced", len(response.Blocks["a"]), "partitions for topic 'a' where there were none.") + } + + if len(response.Blocks["z"]) != 1 { + t.Fatal("Decoding produced", len(response.Blocks["z"]), "partitions for topic 'z' where there was one.") + } + + if response.Blocks["z"][2].Err != ErrNoError { + t.Fatal("Decoding produced invalid error for topic z partition 2.") + } + + if response.Blocks["z"][2].Timestamp != 1477920049286 { + t.Fatal("Decoding produced invalid timestamp for topic z partition 2.", response.Blocks["z"][2].Timestamp) + } + + if response.Blocks["z"][2].Offset != 6 { + t.Fatal("Decoding produced invalid offsets for topic z partition 2.") + } +} diff --git a/vendor/github.com/Shopify/sarama/packet_decoder.go b/vendor/github.com/Shopify/sarama/packet_decoder.go new file mode 100644 index 00000000..28670c0e --- /dev/null +++ b/vendor/github.com/Shopify/sarama/packet_decoder.go @@ -0,0 +1,45 @@ +package sarama + +// PacketDecoder is the interface providing helpers for reading with Kafka's encoding rules. +// Types implementing Decoder only need to worry about calling methods like GetString, +// not about how a string is represented in Kafka. +type packetDecoder interface { + // Primitives + getInt8() (int8, error) + getInt16() (int16, error) + getInt32() (int32, error) + getInt64() (int64, error) + getArrayLength() (int, error) + + // Collections + getBytes() ([]byte, error) + getString() (string, error) + getInt32Array() ([]int32, error) + getInt64Array() ([]int64, error) + getStringArray() ([]string, error) + + // Subsets + remaining() int + getSubset(length int) (packetDecoder, error) + + // Stacks, see PushDecoder + push(in pushDecoder) error + pop() error +} + +// PushDecoder is the interface for decoding fields like CRCs and lengths where the validity +// of the field depends on what is after it in the packet. Start them with PacketDecoder.Push() where +// the actual value is located in the packet, then PacketDecoder.Pop() them when all the bytes they +// depend upon have been decoded. +type pushDecoder interface { + // Saves the offset into the input buffer as the location to actually read the calculated value when able. + saveOffset(in int) + + // Returns the length of data to reserve for the input of this encoder (eg 4 bytes for a CRC32). + reserveLength() int + + // Indicates that all required data is now available to calculate and check the field. + // SaveOffset is guaranteed to have been called first. The implementation should read ReserveLength() bytes + // of data from the saved offset, and verify it based on the data between the saved offset and curOffset. + check(curOffset int, buf []byte) error +} diff --git a/vendor/github.com/Shopify/sarama/packet_encoder.go b/vendor/github.com/Shopify/sarama/packet_encoder.go new file mode 100644 index 00000000..27a10f6d --- /dev/null +++ b/vendor/github.com/Shopify/sarama/packet_encoder.go @@ -0,0 +1,50 @@ +package sarama + +import "github.com/rcrowley/go-metrics" + +// PacketEncoder is the interface providing helpers for writing with Kafka's encoding rules. +// Types implementing Encoder only need to worry about calling methods like PutString, +// not about how a string is represented in Kafka. +type packetEncoder interface { + // Primitives + putInt8(in int8) + putInt16(in int16) + putInt32(in int32) + putInt64(in int64) + putArrayLength(in int) error + + // Collections + putBytes(in []byte) error + putRawBytes(in []byte) error + putString(in string) error + putStringArray(in []string) error + putInt32Array(in []int32) error + putInt64Array(in []int64) error + + // Provide the current offset to record the batch size metric + offset() int + + // Stacks, see PushEncoder + push(in pushEncoder) + pop() error + + // To record metrics when provided + metricRegistry() metrics.Registry +} + +// PushEncoder is the interface for encoding fields like CRCs and lengths where the value +// of the field depends on what is encoded after it in the packet. Start them with PacketEncoder.Push() where +// the actual value is located in the packet, then PacketEncoder.Pop() them when all the bytes they +// depend upon have been written. +type pushEncoder interface { + // Saves the offset into the input buffer as the location to actually write the calculated value when able. + saveOffset(in int) + + // Returns the length of data to reserve for the output of this encoder (eg 4 bytes for a CRC32). + reserveLength() int + + // Indicates that all required data is now available to calculate and write the field. + // SaveOffset is guaranteed to have been called first. The implementation should write ReserveLength() bytes + // of data to the saved offset, based on the data between the saved offset and curOffset. + run(curOffset int, buf []byte) error +} diff --git a/vendor/github.com/Shopify/sarama/partitioner.go b/vendor/github.com/Shopify/sarama/partitioner.go new file mode 100644 index 00000000..97293272 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/partitioner.go @@ -0,0 +1,135 @@ +package sarama + +import ( + "hash" + "hash/fnv" + "math/rand" + "time" +) + +// Partitioner is anything that, given a Kafka message and a number of partitions indexed [0...numPartitions-1], +// decides to which partition to send the message. RandomPartitioner, RoundRobinPartitioner and HashPartitioner are provided +// as simple default implementations. +type Partitioner interface { + // Partition takes a message and partition count and chooses a partition + Partition(message *ProducerMessage, numPartitions int32) (int32, error) + + // RequiresConsistency indicates to the user of the partitioner whether the + // mapping of key->partition is consistent or not. Specifically, if a + // partitioner requires consistency then it must be allowed to choose from all + // partitions (even ones known to be unavailable), and its choice must be + // respected by the caller. The obvious example is the HashPartitioner. + RequiresConsistency() bool +} + +// PartitionerConstructor is the type for a function capable of constructing new Partitioners. +type PartitionerConstructor func(topic string) Partitioner + +type manualPartitioner struct{} + +// NewManualPartitioner returns a Partitioner which uses the partition manually set in the provided +// ProducerMessage's Partition field as the partition to produce to. +func NewManualPartitioner(topic string) Partitioner { + return new(manualPartitioner) +} + +func (p *manualPartitioner) Partition(message *ProducerMessage, numPartitions int32) (int32, error) { + return message.Partition, nil +} + +func (p *manualPartitioner) RequiresConsistency() bool { + return true +} + +type randomPartitioner struct { + generator *rand.Rand +} + +// NewRandomPartitioner returns a Partitioner which chooses a random partition each time. +func NewRandomPartitioner(topic string) Partitioner { + p := new(randomPartitioner) + p.generator = rand.New(rand.NewSource(time.Now().UTC().UnixNano())) + return p +} + +func (p *randomPartitioner) Partition(message *ProducerMessage, numPartitions int32) (int32, error) { + return int32(p.generator.Intn(int(numPartitions))), nil +} + +func (p *randomPartitioner) RequiresConsistency() bool { + return false +} + +type roundRobinPartitioner struct { + partition int32 +} + +// NewRoundRobinPartitioner returns a Partitioner which walks through the available partitions one at a time. +func NewRoundRobinPartitioner(topic string) Partitioner { + return &roundRobinPartitioner{} +} + +func (p *roundRobinPartitioner) Partition(message *ProducerMessage, numPartitions int32) (int32, error) { + if p.partition >= numPartitions { + p.partition = 0 + } + ret := p.partition + p.partition++ + return ret, nil +} + +func (p *roundRobinPartitioner) RequiresConsistency() bool { + return false +} + +type hashPartitioner struct { + random Partitioner + hasher hash.Hash32 +} + +// NewCustomHashPartitioner is a wrapper around NewHashPartitioner, allowing the use of custom hasher. +// The argument is a function providing the instance, implementing the hash.Hash32 interface. This is to ensure that +// each partition dispatcher gets its own hasher, to avoid concurrency issues by sharing an instance. +func NewCustomHashPartitioner(hasher func() hash.Hash32) PartitionerConstructor { + return func(topic string) Partitioner { + p := new(hashPartitioner) + p.random = NewRandomPartitioner(topic) + p.hasher = hasher() + return p + } +} + +// NewHashPartitioner returns a Partitioner which behaves as follows. If the message's key is nil then a +// random partition is chosen. Otherwise the FNV-1a hash of the encoded bytes of the message key is used, +// modulus the number of partitions. This ensures that messages with the same key always end up on the +// same partition. +func NewHashPartitioner(topic string) Partitioner { + p := new(hashPartitioner) + p.random = NewRandomPartitioner(topic) + p.hasher = fnv.New32a() + return p +} + +func (p *hashPartitioner) Partition(message *ProducerMessage, numPartitions int32) (int32, error) { + if message.Key == nil { + return p.random.Partition(message, numPartitions) + } + bytes, err := message.Key.Encode() + if err != nil { + return -1, err + } + p.hasher.Reset() + _, err = p.hasher.Write(bytes) + if err != nil { + return -1, err + } + partition := int32(p.hasher.Sum32()) % numPartitions + if partition < 0 { + partition = -partition + } + return partition, nil +} + +func (p *hashPartitioner) RequiresConsistency() bool { + return true +} diff --git a/vendor/github.com/Shopify/sarama/partitioner_test.go b/vendor/github.com/Shopify/sarama/partitioner_test.go new file mode 100644 index 00000000..83376431 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/partitioner_test.go @@ -0,0 +1,265 @@ +package sarama + +import ( + "crypto/rand" + "hash/fnv" + "log" + "testing" +) + +func assertPartitioningConsistent(t *testing.T, partitioner Partitioner, message *ProducerMessage, numPartitions int32) { + choice, err := partitioner.Partition(message, numPartitions) + if err != nil { + t.Error(partitioner, err) + } + if choice < 0 || choice >= numPartitions { + t.Error(partitioner, "returned partition", choice, "outside of range for", message) + } + for i := 1; i < 50; i++ { + newChoice, err := partitioner.Partition(message, numPartitions) + if err != nil { + t.Error(partitioner, err) + } + if newChoice != choice { + t.Error(partitioner, "returned partition", newChoice, "inconsistent with", choice, ".") + } + } +} + +func TestRandomPartitioner(t *testing.T) { + partitioner := NewRandomPartitioner("mytopic") + + choice, err := partitioner.Partition(nil, 1) + if err != nil { + t.Error(partitioner, err) + } + if choice != 0 { + t.Error("Returned non-zero partition when only one available.") + } + + for i := 1; i < 50; i++ { + choice, err := partitioner.Partition(nil, 50) + if err != nil { + t.Error(partitioner, err) + } + if choice < 0 || choice >= 50 { + t.Error("Returned partition", choice, "outside of range.") + } + } +} + +func TestRoundRobinPartitioner(t *testing.T) { + partitioner := NewRoundRobinPartitioner("mytopic") + + choice, err := partitioner.Partition(nil, 1) + if err != nil { + t.Error(partitioner, err) + } + if choice != 0 { + t.Error("Returned non-zero partition when only one available.") + } + + var i int32 + for i = 1; i < 50; i++ { + choice, err := partitioner.Partition(nil, 7) + if err != nil { + t.Error(partitioner, err) + } + if choice != i%7 { + t.Error("Returned partition", choice, "expecting", i%7) + } + } +} + +func TestNewHashPartitionerWithHasher(t *testing.T) { + // use the current default hasher fnv.New32a() + partitioner := NewCustomHashPartitioner(fnv.New32a)("mytopic") + + choice, err := partitioner.Partition(&ProducerMessage{}, 1) + if err != nil { + t.Error(partitioner, err) + } + if choice != 0 { + t.Error("Returned non-zero partition when only one available.") + } + + for i := 1; i < 50; i++ { + choice, err := partitioner.Partition(&ProducerMessage{}, 50) + if err != nil { + t.Error(partitioner, err) + } + if choice < 0 || choice >= 50 { + t.Error("Returned partition", choice, "outside of range for nil key.") + } + } + + buf := make([]byte, 256) + for i := 1; i < 50; i++ { + if _, err := rand.Read(buf); err != nil { + t.Error(err) + } + assertPartitioningConsistent(t, partitioner, &ProducerMessage{Key: ByteEncoder(buf)}, 50) + } +} + +func TestHashPartitionerWithHasherMinInt32(t *testing.T) { + // use the current default hasher fnv.New32a() + partitioner := NewCustomHashPartitioner(fnv.New32a)("mytopic") + + msg := ProducerMessage{} + // "1468509572224" generates 2147483648 (uint32) result from Sum32 function + // which is -2147483648 or int32's min value + msg.Key = StringEncoder("1468509572224") + + choice, err := partitioner.Partition(&msg, 50) + if err != nil { + t.Error(partitioner, err) + } + if choice < 0 || choice >= 50 { + t.Error("Returned partition", choice, "outside of range for nil key.") + } +} + +func TestHashPartitioner(t *testing.T) { + partitioner := NewHashPartitioner("mytopic") + + choice, err := partitioner.Partition(&ProducerMessage{}, 1) + if err != nil { + t.Error(partitioner, err) + } + if choice != 0 { + t.Error("Returned non-zero partition when only one available.") + } + + for i := 1; i < 50; i++ { + choice, err := partitioner.Partition(&ProducerMessage{}, 50) + if err != nil { + t.Error(partitioner, err) + } + if choice < 0 || choice >= 50 { + t.Error("Returned partition", choice, "outside of range for nil key.") + } + } + + buf := make([]byte, 256) + for i := 1; i < 50; i++ { + if _, err := rand.Read(buf); err != nil { + t.Error(err) + } + assertPartitioningConsistent(t, partitioner, &ProducerMessage{Key: ByteEncoder(buf)}, 50) + } +} + +func TestHashPartitionerMinInt32(t *testing.T) { + partitioner := NewHashPartitioner("mytopic") + + msg := ProducerMessage{} + // "1468509572224" generates 2147483648 (uint32) result from Sum32 function + // which is -2147483648 or int32's min value + msg.Key = StringEncoder("1468509572224") + + choice, err := partitioner.Partition(&msg, 50) + if err != nil { + t.Error(partitioner, err) + } + if choice < 0 || choice >= 50 { + t.Error("Returned partition", choice, "outside of range for nil key.") + } +} + +func TestManualPartitioner(t *testing.T) { + partitioner := NewManualPartitioner("mytopic") + + choice, err := partitioner.Partition(&ProducerMessage{}, 1) + if err != nil { + t.Error(partitioner, err) + } + if choice != 0 { + t.Error("Returned non-zero partition when only one available.") + } + + for i := int32(1); i < 50; i++ { + choice, err := partitioner.Partition(&ProducerMessage{Partition: i}, 50) + if err != nil { + t.Error(partitioner, err) + } + if choice != i { + t.Error("Returned partition not the same as the input partition") + } + } +} + +// By default, Sarama uses the message's key to consistently assign a partition to +// a message using hashing. If no key is set, a random partition will be chosen. +// This example shows how you can partition messages randomly, even when a key is set, +// by overriding Config.Producer.Partitioner. +func ExamplePartitioner_random() { + config := NewConfig() + config.Producer.Partitioner = NewRandomPartitioner + + producer, err := NewSyncProducer([]string{"localhost:9092"}, config) + if err != nil { + log.Fatal(err) + } + defer func() { + if err := producer.Close(); err != nil { + log.Println("Failed to close producer:", err) + } + }() + + msg := &ProducerMessage{Topic: "test", Key: StringEncoder("key is set"), Value: StringEncoder("test")} + partition, offset, err := producer.SendMessage(msg) + if err != nil { + log.Fatalln("Failed to produce message to kafka cluster.") + } + + log.Printf("Produced message to partition %d with offset %d", partition, offset) +} + +// This example shows how to assign partitions to your messages manually. +func ExamplePartitioner_manual() { + config := NewConfig() + + // First, we tell the producer that we are going to partition ourselves. + config.Producer.Partitioner = NewManualPartitioner + + producer, err := NewSyncProducer([]string{"localhost:9092"}, config) + if err != nil { + log.Fatal(err) + } + defer func() { + if err := producer.Close(); err != nil { + log.Println("Failed to close producer:", err) + } + }() + + // Now, we set the Partition field of the ProducerMessage struct. + msg := &ProducerMessage{Topic: "test", Partition: 6, Value: StringEncoder("test")} + + partition, offset, err := producer.SendMessage(msg) + if err != nil { + log.Fatalln("Failed to produce message to kafka cluster.") + } + + if partition != 6 { + log.Fatal("Message should have been produced to partition 6!") + } + + log.Printf("Produced message to partition %d with offset %d", partition, offset) +} + +// This example shows how to set a different partitioner depending on the topic. +func ExamplePartitioner_per_topic() { + config := NewConfig() + config.Producer.Partitioner = func(topic string) Partitioner { + switch topic { + case "access_log", "error_log": + return NewRandomPartitioner(topic) + + default: + return NewHashPartitioner(topic) + } + } + + // ... +} diff --git a/vendor/github.com/Shopify/sarama/prep_encoder.go b/vendor/github.com/Shopify/sarama/prep_encoder.go new file mode 100644 index 00000000..fd5ea0f9 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/prep_encoder.go @@ -0,0 +1,121 @@ +package sarama + +import ( + "fmt" + "math" + + "github.com/rcrowley/go-metrics" +) + +type prepEncoder struct { + length int +} + +// primitives + +func (pe *prepEncoder) putInt8(in int8) { + pe.length++ +} + +func (pe *prepEncoder) putInt16(in int16) { + pe.length += 2 +} + +func (pe *prepEncoder) putInt32(in int32) { + pe.length += 4 +} + +func (pe *prepEncoder) putInt64(in int64) { + pe.length += 8 +} + +func (pe *prepEncoder) putArrayLength(in int) error { + if in > math.MaxInt32 { + return PacketEncodingError{fmt.Sprintf("array too long (%d)", in)} + } + pe.length += 4 + return nil +} + +// arrays + +func (pe *prepEncoder) putBytes(in []byte) error { + pe.length += 4 + if in == nil { + return nil + } + if len(in) > math.MaxInt32 { + return PacketEncodingError{fmt.Sprintf("byteslice too long (%d)", len(in))} + } + pe.length += len(in) + return nil +} + +func (pe *prepEncoder) putRawBytes(in []byte) error { + if len(in) > math.MaxInt32 { + return PacketEncodingError{fmt.Sprintf("byteslice too long (%d)", len(in))} + } + pe.length += len(in) + return nil +} + +func (pe *prepEncoder) putString(in string) error { + pe.length += 2 + if len(in) > math.MaxInt16 { + return PacketEncodingError{fmt.Sprintf("string too long (%d)", len(in))} + } + pe.length += len(in) + return nil +} + +func (pe *prepEncoder) putStringArray(in []string) error { + err := pe.putArrayLength(len(in)) + if err != nil { + return err + } + + for _, str := range in { + if err := pe.putString(str); err != nil { + return err + } + } + + return nil +} + +func (pe *prepEncoder) putInt32Array(in []int32) error { + err := pe.putArrayLength(len(in)) + if err != nil { + return err + } + pe.length += 4 * len(in) + return nil +} + +func (pe *prepEncoder) putInt64Array(in []int64) error { + err := pe.putArrayLength(len(in)) + if err != nil { + return err + } + pe.length += 8 * len(in) + return nil +} + +func (pe *prepEncoder) offset() int { + return pe.length +} + +// stackable + +func (pe *prepEncoder) push(in pushEncoder) { + pe.length += in.reserveLength() +} + +func (pe *prepEncoder) pop() error { + return nil +} + +// we do not record metrics during the prep encoder pass +func (pe *prepEncoder) metricRegistry() metrics.Registry { + return nil +} diff --git a/vendor/github.com/Shopify/sarama/produce_request.go b/vendor/github.com/Shopify/sarama/produce_request.go new file mode 100644 index 00000000..40dc8015 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/produce_request.go @@ -0,0 +1,209 @@ +package sarama + +import "github.com/rcrowley/go-metrics" + +// RequiredAcks is used in Produce Requests to tell the broker how many replica acknowledgements +// it must see before responding. Any of the constants defined here are valid. On broker versions +// prior to 0.8.2.0 any other positive int16 is also valid (the broker will wait for that many +// acknowledgements) but in 0.8.2.0 and later this will raise an exception (it has been replaced +// by setting the `min.isr` value in the brokers configuration). +type RequiredAcks int16 + +const ( + // NoResponse doesn't send any response, the TCP ACK is all you get. + NoResponse RequiredAcks = 0 + // WaitForLocal waits for only the local commit to succeed before responding. + WaitForLocal RequiredAcks = 1 + // WaitForAll waits for all in-sync replicas to commit before responding. + // The minimum number of in-sync replicas is configured on the broker via + // the `min.insync.replicas` configuration key. + WaitForAll RequiredAcks = -1 +) + +type ProduceRequest struct { + RequiredAcks RequiredAcks + Timeout int32 + Version int16 // v1 requires Kafka 0.9, v2 requires Kafka 0.10 + msgSets map[string]map[int32]*MessageSet +} + +func (r *ProduceRequest) encode(pe packetEncoder) error { + pe.putInt16(int16(r.RequiredAcks)) + pe.putInt32(r.Timeout) + err := pe.putArrayLength(len(r.msgSets)) + if err != nil { + return err + } + metricRegistry := pe.metricRegistry() + var batchSizeMetric metrics.Histogram + var compressionRatioMetric metrics.Histogram + if metricRegistry != nil { + batchSizeMetric = getOrRegisterHistogram("batch-size", metricRegistry) + compressionRatioMetric = getOrRegisterHistogram("compression-ratio", metricRegistry) + } + + totalRecordCount := int64(0) + for topic, partitions := range r.msgSets { + err = pe.putString(topic) + if err != nil { + return err + } + err = pe.putArrayLength(len(partitions)) + if err != nil { + return err + } + topicRecordCount := int64(0) + var topicCompressionRatioMetric metrics.Histogram + if metricRegistry != nil { + topicCompressionRatioMetric = getOrRegisterTopicHistogram("compression-ratio", topic, metricRegistry) + } + for id, msgSet := range partitions { + startOffset := pe.offset() + pe.putInt32(id) + pe.push(&lengthField{}) + err = msgSet.encode(pe) + if err != nil { + return err + } + err = pe.pop() + if err != nil { + return err + } + if metricRegistry != nil { + for _, messageBlock := range msgSet.Messages { + // Is this a fake "message" wrapping real messages? + if messageBlock.Msg.Set != nil { + topicRecordCount += int64(len(messageBlock.Msg.Set.Messages)) + } else { + // A single uncompressed message + topicRecordCount++ + } + // Better be safe than sorry when computing the compression ratio + if messageBlock.Msg.compressedSize != 0 { + compressionRatio := float64(len(messageBlock.Msg.Value)) / + float64(messageBlock.Msg.compressedSize) + // Histogram do not support decimal values, let's multiple it by 100 for better precision + intCompressionRatio := int64(100 * compressionRatio) + compressionRatioMetric.Update(intCompressionRatio) + topicCompressionRatioMetric.Update(intCompressionRatio) + } + } + batchSize := int64(pe.offset() - startOffset) + batchSizeMetric.Update(batchSize) + getOrRegisterTopicHistogram("batch-size", topic, metricRegistry).Update(batchSize) + } + } + if topicRecordCount > 0 { + getOrRegisterTopicMeter("record-send-rate", topic, metricRegistry).Mark(topicRecordCount) + getOrRegisterTopicHistogram("records-per-request", topic, metricRegistry).Update(topicRecordCount) + totalRecordCount += topicRecordCount + } + } + if totalRecordCount > 0 { + metrics.GetOrRegisterMeter("record-send-rate", metricRegistry).Mark(totalRecordCount) + getOrRegisterHistogram("records-per-request", metricRegistry).Update(totalRecordCount) + } + + return nil +} + +func (r *ProduceRequest) decode(pd packetDecoder, version int16) error { + requiredAcks, err := pd.getInt16() + if err != nil { + return err + } + r.RequiredAcks = RequiredAcks(requiredAcks) + if r.Timeout, err = pd.getInt32(); err != nil { + return err + } + topicCount, err := pd.getArrayLength() + if err != nil { + return err + } + if topicCount == 0 { + return nil + } + r.msgSets = make(map[string]map[int32]*MessageSet) + for i := 0; i < topicCount; i++ { + topic, err := pd.getString() + if err != nil { + return err + } + partitionCount, err := pd.getArrayLength() + if err != nil { + return err + } + r.msgSets[topic] = make(map[int32]*MessageSet) + for j := 0; j < partitionCount; j++ { + partition, err := pd.getInt32() + if err != nil { + return err + } + messageSetSize, err := pd.getInt32() + if err != nil { + return err + } + msgSetDecoder, err := pd.getSubset(int(messageSetSize)) + if err != nil { + return err + } + msgSet := &MessageSet{} + err = msgSet.decode(msgSetDecoder) + if err != nil { + return err + } + r.msgSets[topic][partition] = msgSet + } + } + return nil +} + +func (r *ProduceRequest) key() int16 { + return 0 +} + +func (r *ProduceRequest) version() int16 { + return r.Version +} + +func (r *ProduceRequest) requiredVersion() KafkaVersion { + switch r.Version { + case 1: + return V0_9_0_0 + case 2: + return V0_10_0_0 + default: + return minVersion + } +} + +func (r *ProduceRequest) AddMessage(topic string, partition int32, msg *Message) { + if r.msgSets == nil { + r.msgSets = make(map[string]map[int32]*MessageSet) + } + + if r.msgSets[topic] == nil { + r.msgSets[topic] = make(map[int32]*MessageSet) + } + + set := r.msgSets[topic][partition] + + if set == nil { + set = new(MessageSet) + r.msgSets[topic][partition] = set + } + + set.addMessage(msg) +} + +func (r *ProduceRequest) AddSet(topic string, partition int32, set *MessageSet) { + if r.msgSets == nil { + r.msgSets = make(map[string]map[int32]*MessageSet) + } + + if r.msgSets[topic] == nil { + r.msgSets[topic] = make(map[int32]*MessageSet) + } + + r.msgSets[topic][partition] = set +} diff --git a/vendor/github.com/Shopify/sarama/produce_request_test.go b/vendor/github.com/Shopify/sarama/produce_request_test.go new file mode 100644 index 00000000..21f4ba5b --- /dev/null +++ b/vendor/github.com/Shopify/sarama/produce_request_test.go @@ -0,0 +1,47 @@ +package sarama + +import ( + "testing" +) + +var ( + produceRequestEmpty = []byte{ + 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00} + + produceRequestHeader = []byte{ + 0x01, 0x23, + 0x00, 0x00, 0x04, 0x44, + 0x00, 0x00, 0x00, 0x00} + + produceRequestOneMessage = []byte{ + 0x01, 0x23, + 0x00, 0x00, 0x04, 0x44, + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x05, 't', 'o', 'p', 'i', 'c', + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0xAD, + 0x00, 0x00, 0x00, 0x1C, + // messageSet + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x10, + // message + 0x23, 0x96, 0x4a, 0xf7, // CRC + 0x00, + 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x02, 0x00, 0xEE} +) + +func TestProduceRequest(t *testing.T) { + request := new(ProduceRequest) + testRequest(t, "empty", request, produceRequestEmpty) + + request.RequiredAcks = 0x123 + request.Timeout = 0x444 + testRequest(t, "header", request, produceRequestHeader) + + request.AddMessage("topic", 0xAD, &Message{Codec: CompressionNone, Key: nil, Value: []byte{0x00, 0xEE}}) + testRequest(t, "one message", request, produceRequestOneMessage) +} diff --git a/vendor/github.com/Shopify/sarama/produce_response.go b/vendor/github.com/Shopify/sarama/produce_response.go new file mode 100644 index 00000000..3f05dd9f --- /dev/null +++ b/vendor/github.com/Shopify/sarama/produce_response.go @@ -0,0 +1,159 @@ +package sarama + +import "time" + +type ProduceResponseBlock struct { + Err KError + Offset int64 + // only provided if Version >= 2 and the broker is configured with `LogAppendTime` + Timestamp time.Time +} + +func (b *ProduceResponseBlock) decode(pd packetDecoder, version int16) (err error) { + tmp, err := pd.getInt16() + if err != nil { + return err + } + b.Err = KError(tmp) + + b.Offset, err = pd.getInt64() + if err != nil { + return err + } + + if version >= 2 { + if millis, err := pd.getInt64(); err != nil { + return err + } else if millis != -1 { + b.Timestamp = time.Unix(millis/1000, (millis%1000)*int64(time.Millisecond)) + } + } + + return nil +} + +type ProduceResponse struct { + Blocks map[string]map[int32]*ProduceResponseBlock + Version int16 + ThrottleTime time.Duration // only provided if Version >= 1 +} + +func (r *ProduceResponse) decode(pd packetDecoder, version int16) (err error) { + r.Version = version + + numTopics, err := pd.getArrayLength() + if err != nil { + return err + } + + r.Blocks = make(map[string]map[int32]*ProduceResponseBlock, numTopics) + for i := 0; i < numTopics; i++ { + name, err := pd.getString() + if err != nil { + return err + } + + numBlocks, err := pd.getArrayLength() + if err != nil { + return err + } + + r.Blocks[name] = make(map[int32]*ProduceResponseBlock, numBlocks) + + for j := 0; j < numBlocks; j++ { + id, err := pd.getInt32() + if err != nil { + return err + } + + block := new(ProduceResponseBlock) + err = block.decode(pd, version) + if err != nil { + return err + } + r.Blocks[name][id] = block + } + } + + if r.Version >= 1 { + millis, err := pd.getInt32() + if err != nil { + return err + } + + r.ThrottleTime = time.Duration(millis) * time.Millisecond + } + + return nil +} + +func (r *ProduceResponse) encode(pe packetEncoder) error { + err := pe.putArrayLength(len(r.Blocks)) + if err != nil { + return err + } + for topic, partitions := range r.Blocks { + err = pe.putString(topic) + if err != nil { + return err + } + err = pe.putArrayLength(len(partitions)) + if err != nil { + return err + } + for id, prb := range partitions { + pe.putInt32(id) + pe.putInt16(int16(prb.Err)) + pe.putInt64(prb.Offset) + } + } + if r.Version >= 1 { + pe.putInt32(int32(r.ThrottleTime / time.Millisecond)) + } + return nil +} + +func (r *ProduceResponse) key() int16 { + return 0 +} + +func (r *ProduceResponse) version() int16 { + return r.Version +} + +func (r *ProduceResponse) requiredVersion() KafkaVersion { + switch r.Version { + case 1: + return V0_9_0_0 + case 2: + return V0_10_0_0 + default: + return minVersion + } +} + +func (r *ProduceResponse) GetBlock(topic string, partition int32) *ProduceResponseBlock { + if r.Blocks == nil { + return nil + } + + if r.Blocks[topic] == nil { + return nil + } + + return r.Blocks[topic][partition] +} + +// Testing API + +func (r *ProduceResponse) AddTopicPartition(topic string, partition int32, err KError) { + if r.Blocks == nil { + r.Blocks = make(map[string]map[int32]*ProduceResponseBlock) + } + byTopic, ok := r.Blocks[topic] + if !ok { + byTopic = make(map[int32]*ProduceResponseBlock) + r.Blocks[topic] = byTopic + } + byTopic[partition] = &ProduceResponseBlock{Err: err} +} diff --git a/vendor/github.com/Shopify/sarama/produce_response_test.go b/vendor/github.com/Shopify/sarama/produce_response_test.go new file mode 100644 index 00000000..f71709fe --- /dev/null +++ b/vendor/github.com/Shopify/sarama/produce_response_test.go @@ -0,0 +1,67 @@ +package sarama + +import "testing" + +var ( + produceResponseNoBlocks = []byte{ + 0x00, 0x00, 0x00, 0x00} + + produceResponseManyBlocks = []byte{ + 0x00, 0x00, 0x00, 0x02, + + 0x00, 0x03, 'f', 'o', 'o', + 0x00, 0x00, 0x00, 0x00, + + 0x00, 0x03, 'b', 'a', 'r', + 0x00, 0x00, 0x00, 0x02, + + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, + + 0x00, 0x00, 0x00, 0x02, + 0x00, 0x02, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} +) + +func TestProduceResponse(t *testing.T) { + response := ProduceResponse{} + + testVersionDecodable(t, "no blocks", &response, produceResponseNoBlocks, 0) + if len(response.Blocks) != 0 { + t.Error("Decoding produced", len(response.Blocks), "topics where there were none") + } + + testVersionDecodable(t, "many blocks", &response, produceResponseManyBlocks, 0) + if len(response.Blocks) != 2 { + t.Error("Decoding produced", len(response.Blocks), "topics where there were 2") + } + if len(response.Blocks["foo"]) != 0 { + t.Error("Decoding produced", len(response.Blocks["foo"]), "partitions for 'foo' where there were none") + } + if len(response.Blocks["bar"]) != 2 { + t.Error("Decoding produced", len(response.Blocks["bar"]), "partitions for 'bar' where there were two") + } + block := response.GetBlock("bar", 1) + if block == nil { + t.Error("Decoding did not produce a block for bar/1") + } else { + if block.Err != ErrNoError { + t.Error("Decoding failed for bar/1/Err, got:", int16(block.Err)) + } + if block.Offset != 0xFF { + t.Error("Decoding failed for bar/1/Offset, got:", block.Offset) + } + } + block = response.GetBlock("bar", 2) + if block == nil { + t.Error("Decoding did not produce a block for bar/2") + } else { + if block.Err != ErrInvalidMessage { + t.Error("Decoding failed for bar/2/Err, got:", int16(block.Err)) + } + if block.Offset != 0 { + t.Error("Decoding failed for bar/2/Offset, got:", block.Offset) + } + } +} diff --git a/vendor/github.com/Shopify/sarama/produce_set.go b/vendor/github.com/Shopify/sarama/produce_set.go new file mode 100644 index 00000000..158d9c47 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/produce_set.go @@ -0,0 +1,176 @@ +package sarama + +import "time" + +type partitionSet struct { + msgs []*ProducerMessage + setToSend *MessageSet + bufferBytes int +} + +type produceSet struct { + parent *asyncProducer + msgs map[string]map[int32]*partitionSet + + bufferBytes int + bufferCount int +} + +func newProduceSet(parent *asyncProducer) *produceSet { + return &produceSet{ + msgs: make(map[string]map[int32]*partitionSet), + parent: parent, + } +} + +func (ps *produceSet) add(msg *ProducerMessage) error { + var err error + var key, val []byte + + if msg.Key != nil { + if key, err = msg.Key.Encode(); err != nil { + return err + } + } + + if msg.Value != nil { + if val, err = msg.Value.Encode(); err != nil { + return err + } + } + + partitions := ps.msgs[msg.Topic] + if partitions == nil { + partitions = make(map[int32]*partitionSet) + ps.msgs[msg.Topic] = partitions + } + + set := partitions[msg.Partition] + if set == nil { + set = &partitionSet{setToSend: new(MessageSet)} + partitions[msg.Partition] = set + } + + set.msgs = append(set.msgs, msg) + msgToSend := &Message{Codec: CompressionNone, Key: key, Value: val} + if ps.parent.conf.Version.IsAtLeast(V0_10_0_0) { + if msg.Timestamp.IsZero() { + msgToSend.Timestamp = time.Now() + } else { + msgToSend.Timestamp = msg.Timestamp + } + msgToSend.Version = 1 + } + set.setToSend.addMessage(msgToSend) + + size := producerMessageOverhead + len(key) + len(val) + set.bufferBytes += size + ps.bufferBytes += size + ps.bufferCount++ + + return nil +} + +func (ps *produceSet) buildRequest() *ProduceRequest { + req := &ProduceRequest{ + RequiredAcks: ps.parent.conf.Producer.RequiredAcks, + Timeout: int32(ps.parent.conf.Producer.Timeout / time.Millisecond), + } + if ps.parent.conf.Version.IsAtLeast(V0_10_0_0) { + req.Version = 2 + } + + for topic, partitionSet := range ps.msgs { + for partition, set := range partitionSet { + if ps.parent.conf.Producer.Compression == CompressionNone { + req.AddSet(topic, partition, set.setToSend) + } else { + // When compression is enabled, the entire set for each partition is compressed + // and sent as the payload of a single fake "message" with the appropriate codec + // set and no key. When the server sees a message with a compression codec, it + // decompresses the payload and treats the result as its message set. + payload, err := encode(set.setToSend, ps.parent.conf.MetricRegistry) + if err != nil { + Logger.Println(err) // if this happens, it's basically our fault. + panic(err) + } + compMsg := &Message{ + Codec: ps.parent.conf.Producer.Compression, + Key: nil, + Value: payload, + Set: set.setToSend, // Provide the underlying message set for accurate metrics + } + if ps.parent.conf.Version.IsAtLeast(V0_10_0_0) { + compMsg.Version = 1 + compMsg.Timestamp = set.setToSend.Messages[0].Msg.Timestamp + } + req.AddMessage(topic, partition, compMsg) + } + } + } + + return req +} + +func (ps *produceSet) eachPartition(cb func(topic string, partition int32, msgs []*ProducerMessage)) { + for topic, partitionSet := range ps.msgs { + for partition, set := range partitionSet { + cb(topic, partition, set.msgs) + } + } +} + +func (ps *produceSet) dropPartition(topic string, partition int32) []*ProducerMessage { + if ps.msgs[topic] == nil { + return nil + } + set := ps.msgs[topic][partition] + if set == nil { + return nil + } + ps.bufferBytes -= set.bufferBytes + ps.bufferCount -= len(set.msgs) + delete(ps.msgs[topic], partition) + return set.msgs +} + +func (ps *produceSet) wouldOverflow(msg *ProducerMessage) bool { + switch { + // Would we overflow our maximum possible size-on-the-wire? 10KiB is arbitrary overhead for safety. + case ps.bufferBytes+msg.byteSize() >= int(MaxRequestSize-(10*1024)): + return true + // Would we overflow the size-limit of a compressed message-batch for this partition? + case ps.parent.conf.Producer.Compression != CompressionNone && + ps.msgs[msg.Topic] != nil && ps.msgs[msg.Topic][msg.Partition] != nil && + ps.msgs[msg.Topic][msg.Partition].bufferBytes+msg.byteSize() >= ps.parent.conf.Producer.MaxMessageBytes: + return true + // Would we overflow simply in number of messages? + case ps.parent.conf.Producer.Flush.MaxMessages > 0 && ps.bufferCount >= ps.parent.conf.Producer.Flush.MaxMessages: + return true + default: + return false + } +} + +func (ps *produceSet) readyToFlush() bool { + switch { + // If we don't have any messages, nothing else matters + case ps.empty(): + return false + // If all three config values are 0, we always flush as-fast-as-possible + case ps.parent.conf.Producer.Flush.Frequency == 0 && ps.parent.conf.Producer.Flush.Bytes == 0 && ps.parent.conf.Producer.Flush.Messages == 0: + return true + // If we've passed the message trigger-point + case ps.parent.conf.Producer.Flush.Messages > 0 && ps.bufferCount >= ps.parent.conf.Producer.Flush.Messages: + return true + // If we've passed the byte trigger-point + case ps.parent.conf.Producer.Flush.Bytes > 0 && ps.bufferBytes >= ps.parent.conf.Producer.Flush.Bytes: + return true + default: + return false + } +} + +func (ps *produceSet) empty() bool { + return ps.bufferCount == 0 +} diff --git a/vendor/github.com/Shopify/sarama/produce_set_test.go b/vendor/github.com/Shopify/sarama/produce_set_test.go new file mode 100644 index 00000000..d016a10b --- /dev/null +++ b/vendor/github.com/Shopify/sarama/produce_set_test.go @@ -0,0 +1,185 @@ +package sarama + +import ( + "testing" + "time" +) + +func makeProduceSet() (*asyncProducer, *produceSet) { + parent := &asyncProducer{ + conf: NewConfig(), + } + return parent, newProduceSet(parent) +} + +func safeAddMessage(t *testing.T, ps *produceSet, msg *ProducerMessage) { + if err := ps.add(msg); err != nil { + t.Error(err) + } +} + +func TestProduceSetInitial(t *testing.T) { + _, ps := makeProduceSet() + + if !ps.empty() { + t.Error("New produceSet should be empty") + } + + if ps.readyToFlush() { + t.Error("Empty produceSet must never be ready to flush") + } +} + +func TestProduceSetAddingMessages(t *testing.T) { + parent, ps := makeProduceSet() + parent.conf.Producer.Flush.MaxMessages = 1000 + + msg := &ProducerMessage{Key: StringEncoder(TestMessage), Value: StringEncoder(TestMessage)} + safeAddMessage(t, ps, msg) + + if ps.empty() { + t.Error("set shouldn't be empty when a message is added") + } + + if !ps.readyToFlush() { + t.Error("by default set should be ready to flush when any message is in place") + } + + for i := 0; i < 999; i++ { + if ps.wouldOverflow(msg) { + t.Error("set shouldn't fill up after only", i+1, "messages") + } + safeAddMessage(t, ps, msg) + } + + if !ps.wouldOverflow(msg) { + t.Error("set should be full after 1000 messages") + } +} + +func TestProduceSetPartitionTracking(t *testing.T) { + _, ps := makeProduceSet() + + m1 := &ProducerMessage{Topic: "t1", Partition: 0} + m2 := &ProducerMessage{Topic: "t1", Partition: 1} + m3 := &ProducerMessage{Topic: "t2", Partition: 0} + safeAddMessage(t, ps, m1) + safeAddMessage(t, ps, m2) + safeAddMessage(t, ps, m3) + + seenT1P0 := false + seenT1P1 := false + seenT2P0 := false + + ps.eachPartition(func(topic string, partition int32, msgs []*ProducerMessage) { + if len(msgs) != 1 { + t.Error("Wrong message count") + } + + if topic == "t1" && partition == 0 { + seenT1P0 = true + } else if topic == "t1" && partition == 1 { + seenT1P1 = true + } else if topic == "t2" && partition == 0 { + seenT2P0 = true + } + }) + + if !seenT1P0 { + t.Error("Didn't see t1p0") + } + if !seenT1P1 { + t.Error("Didn't see t1p1") + } + if !seenT2P0 { + t.Error("Didn't see t2p0") + } + + if len(ps.dropPartition("t1", 1)) != 1 { + t.Error("Got wrong messages back from dropping partition") + } + + if ps.bufferCount != 2 { + t.Error("Incorrect buffer count after dropping partition") + } +} + +func TestProduceSetRequestBuilding(t *testing.T) { + parent, ps := makeProduceSet() + parent.conf.Producer.RequiredAcks = WaitForAll + parent.conf.Producer.Timeout = 10 * time.Second + + msg := &ProducerMessage{ + Topic: "t1", + Partition: 0, + Key: StringEncoder(TestMessage), + Value: StringEncoder(TestMessage), + } + for i := 0; i < 10; i++ { + safeAddMessage(t, ps, msg) + } + msg.Partition = 1 + for i := 0; i < 10; i++ { + safeAddMessage(t, ps, msg) + } + msg.Topic = "t2" + for i := 0; i < 10; i++ { + safeAddMessage(t, ps, msg) + } + + req := ps.buildRequest() + + if req.RequiredAcks != WaitForAll { + t.Error("RequiredAcks not set properly") + } + + if req.Timeout != 10000 { + t.Error("Timeout not set properly") + } + + if len(req.msgSets) != 2 { + t.Error("Wrong number of topics in request") + } +} + +func TestProduceSetCompressedRequestBuilding(t *testing.T) { + parent, ps := makeProduceSet() + parent.conf.Producer.RequiredAcks = WaitForAll + parent.conf.Producer.Timeout = 10 * time.Second + parent.conf.Producer.Compression = CompressionGZIP + parent.conf.Version = V0_10_0_0 + + msg := &ProducerMessage{ + Topic: "t1", + Partition: 0, + Key: StringEncoder(TestMessage), + Value: StringEncoder(TestMessage), + Timestamp: time.Now(), + } + for i := 0; i < 10; i++ { + safeAddMessage(t, ps, msg) + } + + req := ps.buildRequest() + + if req.Version != 2 { + t.Error("Wrong request version") + } + + for _, msgBlock := range req.msgSets["t1"][0].Messages { + msg := msgBlock.Msg + err := msg.decodeSet() + if err != nil { + t.Error("Failed to decode set from payload") + } + for _, compMsgBlock := range msg.Set.Messages { + compMsg := compMsgBlock.Msg + if compMsg.Version != 1 { + t.Error("Wrong compressed message version") + } + } + if msg.Version != 1 { + t.Error("Wrong compressed parent message version") + } + } +} diff --git a/vendor/github.com/Shopify/sarama/real_decoder.go b/vendor/github.com/Shopify/sarama/real_decoder.go new file mode 100644 index 00000000..3cf93533 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/real_decoder.go @@ -0,0 +1,260 @@ +package sarama + +import ( + "encoding/binary" + "math" +) + +var errInvalidArrayLength = PacketDecodingError{"invalid array length"} +var errInvalidByteSliceLength = PacketDecodingError{"invalid byteslice length"} +var errInvalidStringLength = PacketDecodingError{"invalid string length"} +var errInvalidSubsetSize = PacketDecodingError{"invalid subset size"} + +type realDecoder struct { + raw []byte + off int + stack []pushDecoder +} + +// primitives + +func (rd *realDecoder) getInt8() (int8, error) { + if rd.remaining() < 1 { + rd.off = len(rd.raw) + return -1, ErrInsufficientData + } + tmp := int8(rd.raw[rd.off]) + rd.off++ + return tmp, nil +} + +func (rd *realDecoder) getInt16() (int16, error) { + if rd.remaining() < 2 { + rd.off = len(rd.raw) + return -1, ErrInsufficientData + } + tmp := int16(binary.BigEndian.Uint16(rd.raw[rd.off:])) + rd.off += 2 + return tmp, nil +} + +func (rd *realDecoder) getInt32() (int32, error) { + if rd.remaining() < 4 { + rd.off = len(rd.raw) + return -1, ErrInsufficientData + } + tmp := int32(binary.BigEndian.Uint32(rd.raw[rd.off:])) + rd.off += 4 + return tmp, nil +} + +func (rd *realDecoder) getInt64() (int64, error) { + if rd.remaining() < 8 { + rd.off = len(rd.raw) + return -1, ErrInsufficientData + } + tmp := int64(binary.BigEndian.Uint64(rd.raw[rd.off:])) + rd.off += 8 + return tmp, nil +} + +func (rd *realDecoder) getArrayLength() (int, error) { + if rd.remaining() < 4 { + rd.off = len(rd.raw) + return -1, ErrInsufficientData + } + tmp := int(binary.BigEndian.Uint32(rd.raw[rd.off:])) + rd.off += 4 + if tmp > rd.remaining() { + rd.off = len(rd.raw) + return -1, ErrInsufficientData + } else if tmp > 2*math.MaxUint16 { + return -1, errInvalidArrayLength + } + return tmp, nil +} + +// collections + +func (rd *realDecoder) getBytes() ([]byte, error) { + tmp, err := rd.getInt32() + + if err != nil { + return nil, err + } + + n := int(tmp) + + switch { + case n < -1: + return nil, errInvalidByteSliceLength + case n == -1: + return nil, nil + case n == 0: + return make([]byte, 0), nil + case n > rd.remaining(): + rd.off = len(rd.raw) + return nil, ErrInsufficientData + } + + tmpStr := rd.raw[rd.off : rd.off+n] + rd.off += n + return tmpStr, nil +} + +func (rd *realDecoder) getString() (string, error) { + tmp, err := rd.getInt16() + + if err != nil { + return "", err + } + + n := int(tmp) + + switch { + case n < -1: + return "", errInvalidStringLength + case n == -1: + return "", nil + case n == 0: + return "", nil + case n > rd.remaining(): + rd.off = len(rd.raw) + return "", ErrInsufficientData + } + + tmpStr := string(rd.raw[rd.off : rd.off+n]) + rd.off += n + return tmpStr, nil +} + +func (rd *realDecoder) getInt32Array() ([]int32, error) { + if rd.remaining() < 4 { + rd.off = len(rd.raw) + return nil, ErrInsufficientData + } + n := int(binary.BigEndian.Uint32(rd.raw[rd.off:])) + rd.off += 4 + + if rd.remaining() < 4*n { + rd.off = len(rd.raw) + return nil, ErrInsufficientData + } + + if n == 0 { + return nil, nil + } + + if n < 0 { + return nil, errInvalidArrayLength + } + + ret := make([]int32, n) + for i := range ret { + ret[i] = int32(binary.BigEndian.Uint32(rd.raw[rd.off:])) + rd.off += 4 + } + return ret, nil +} + +func (rd *realDecoder) getInt64Array() ([]int64, error) { + if rd.remaining() < 4 { + rd.off = len(rd.raw) + return nil, ErrInsufficientData + } + n := int(binary.BigEndian.Uint32(rd.raw[rd.off:])) + rd.off += 4 + + if rd.remaining() < 8*n { + rd.off = len(rd.raw) + return nil, ErrInsufficientData + } + + if n == 0 { + return nil, nil + } + + if n < 0 { + return nil, errInvalidArrayLength + } + + ret := make([]int64, n) + for i := range ret { + ret[i] = int64(binary.BigEndian.Uint64(rd.raw[rd.off:])) + rd.off += 8 + } + return ret, nil +} + +func (rd *realDecoder) getStringArray() ([]string, error) { + if rd.remaining() < 4 { + rd.off = len(rd.raw) + return nil, ErrInsufficientData + } + n := int(binary.BigEndian.Uint32(rd.raw[rd.off:])) + rd.off += 4 + + if n == 0 { + return nil, nil + } + + if n < 0 { + return nil, errInvalidArrayLength + } + + ret := make([]string, n) + for i := range ret { + str, err := rd.getString() + if err != nil { + return nil, err + } + + ret[i] = str + } + return ret, nil +} + +// subsets + +func (rd *realDecoder) remaining() int { + return len(rd.raw) - rd.off +} + +func (rd *realDecoder) getSubset(length int) (packetDecoder, error) { + if length < 0 { + return nil, errInvalidSubsetSize + } else if length > rd.remaining() { + rd.off = len(rd.raw) + return nil, ErrInsufficientData + } + + start := rd.off + rd.off += length + return &realDecoder{raw: rd.raw[start:rd.off]}, nil +} + +// stacks + +func (rd *realDecoder) push(in pushDecoder) error { + in.saveOffset(rd.off) + + reserve := in.reserveLength() + if rd.remaining() < reserve { + rd.off = len(rd.raw) + return ErrInsufficientData + } + + rd.stack = append(rd.stack, in) + + rd.off += reserve + + return nil +} + +func (rd *realDecoder) pop() error { + // this is go's ugly pop pattern (the inverse of append) + in := rd.stack[len(rd.stack)-1] + rd.stack = rd.stack[:len(rd.stack)-1] + + return in.check(rd.off, rd.raw) +} diff --git a/vendor/github.com/Shopify/sarama/real_encoder.go b/vendor/github.com/Shopify/sarama/real_encoder.go new file mode 100644 index 00000000..ced4267c --- /dev/null +++ b/vendor/github.com/Shopify/sarama/real_encoder.go @@ -0,0 +1,129 @@ +package sarama + +import ( + "encoding/binary" + + "github.com/rcrowley/go-metrics" +) + +type realEncoder struct { + raw []byte + off int + stack []pushEncoder + registry metrics.Registry +} + +// primitives + +func (re *realEncoder) putInt8(in int8) { + re.raw[re.off] = byte(in) + re.off++ +} + +func (re *realEncoder) putInt16(in int16) { + binary.BigEndian.PutUint16(re.raw[re.off:], uint16(in)) + re.off += 2 +} + +func (re *realEncoder) putInt32(in int32) { + binary.BigEndian.PutUint32(re.raw[re.off:], uint32(in)) + re.off += 4 +} + +func (re *realEncoder) putInt64(in int64) { + binary.BigEndian.PutUint64(re.raw[re.off:], uint64(in)) + re.off += 8 +} + +func (re *realEncoder) putArrayLength(in int) error { + re.putInt32(int32(in)) + return nil +} + +// collection + +func (re *realEncoder) putRawBytes(in []byte) error { + copy(re.raw[re.off:], in) + re.off += len(in) + return nil +} + +func (re *realEncoder) putBytes(in []byte) error { + if in == nil { + re.putInt32(-1) + return nil + } + re.putInt32(int32(len(in))) + copy(re.raw[re.off:], in) + re.off += len(in) + return nil +} + +func (re *realEncoder) putString(in string) error { + re.putInt16(int16(len(in))) + copy(re.raw[re.off:], in) + re.off += len(in) + return nil +} + +func (re *realEncoder) putStringArray(in []string) error { + err := re.putArrayLength(len(in)) + if err != nil { + return err + } + + for _, val := range in { + if err := re.putString(val); err != nil { + return err + } + } + + return nil +} + +func (re *realEncoder) putInt32Array(in []int32) error { + err := re.putArrayLength(len(in)) + if err != nil { + return err + } + for _, val := range in { + re.putInt32(val) + } + return nil +} + +func (re *realEncoder) putInt64Array(in []int64) error { + err := re.putArrayLength(len(in)) + if err != nil { + return err + } + for _, val := range in { + re.putInt64(val) + } + return nil +} + +func (re *realEncoder) offset() int { + return re.off +} + +// stacks + +func (re *realEncoder) push(in pushEncoder) { + in.saveOffset(re.off) + re.off += in.reserveLength() + re.stack = append(re.stack, in) +} + +func (re *realEncoder) pop() error { + // this is go's ugly pop pattern (the inverse of append) + in := re.stack[len(re.stack)-1] + re.stack = re.stack[:len(re.stack)-1] + + return in.run(re.off, re.raw) +} + +// we do record metrics during the real encoder pass +func (re *realEncoder) metricRegistry() metrics.Registry { + return re.registry +} diff --git a/vendor/github.com/Shopify/sarama/request.go b/vendor/github.com/Shopify/sarama/request.go new file mode 100644 index 00000000..73310ca8 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/request.go @@ -0,0 +1,119 @@ +package sarama + +import ( + "encoding/binary" + "fmt" + "io" +) + +type protocolBody interface { + encoder + versionedDecoder + key() int16 + version() int16 + requiredVersion() KafkaVersion +} + +type request struct { + correlationID int32 + clientID string + body protocolBody +} + +func (r *request) encode(pe packetEncoder) (err error) { + pe.push(&lengthField{}) + pe.putInt16(r.body.key()) + pe.putInt16(r.body.version()) + pe.putInt32(r.correlationID) + err = pe.putString(r.clientID) + if err != nil { + return err + } + err = r.body.encode(pe) + if err != nil { + return err + } + return pe.pop() +} + +func (r *request) decode(pd packetDecoder) (err error) { + var key int16 + if key, err = pd.getInt16(); err != nil { + return err + } + var version int16 + if version, err = pd.getInt16(); err != nil { + return err + } + if r.correlationID, err = pd.getInt32(); err != nil { + return err + } + r.clientID, err = pd.getString() + + r.body = allocateBody(key, version) + if r.body == nil { + return PacketDecodingError{fmt.Sprintf("unknown request key (%d)", key)} + } + return r.body.decode(pd, version) +} + +func decodeRequest(r io.Reader) (req *request, bytesRead int, err error) { + lengthBytes := make([]byte, 4) + if _, err := io.ReadFull(r, lengthBytes); err != nil { + return nil, bytesRead, err + } + bytesRead += len(lengthBytes) + + length := int32(binary.BigEndian.Uint32(lengthBytes)) + if length <= 4 || length > MaxRequestSize { + return nil, bytesRead, PacketDecodingError{fmt.Sprintf("message of length %d too large or too small", length)} + } + + encodedReq := make([]byte, length) + if _, err := io.ReadFull(r, encodedReq); err != nil { + return nil, bytesRead, err + } + bytesRead += len(encodedReq) + + req = &request{} + if err := decode(encodedReq, req); err != nil { + return nil, bytesRead, err + } + return req, bytesRead, nil +} + +func allocateBody(key, version int16) protocolBody { + switch key { + case 0: + return &ProduceRequest{} + case 1: + return &FetchRequest{} + case 2: + return &OffsetRequest{Version: version} + case 3: + return &MetadataRequest{} + case 8: + return &OffsetCommitRequest{Version: version} + case 9: + return &OffsetFetchRequest{} + case 10: + return &ConsumerMetadataRequest{} + case 11: + return &JoinGroupRequest{} + case 12: + return &HeartbeatRequest{} + case 13: + return &LeaveGroupRequest{} + case 14: + return &SyncGroupRequest{} + case 15: + return &DescribeGroupsRequest{} + case 16: + return &ListGroupsRequest{} + case 17: + return &SaslHandshakeRequest{} + case 18: + return &ApiVersionsRequest{} + } + return nil +} diff --git a/vendor/github.com/Shopify/sarama/request_test.go b/vendor/github.com/Shopify/sarama/request_test.go new file mode 100644 index 00000000..bd9cef4e --- /dev/null +++ b/vendor/github.com/Shopify/sarama/request_test.go @@ -0,0 +1,98 @@ +package sarama + +import ( + "bytes" + "reflect" + "testing" + + "github.com/davecgh/go-spew/spew" +) + +type testRequestBody struct { +} + +func (s *testRequestBody) key() int16 { + return 0x666 +} + +func (s *testRequestBody) version() int16 { + return 0xD2 +} + +func (s *testRequestBody) encode(pe packetEncoder) error { + return pe.putString("abc") +} + +// not specific to request tests, just helper functions for testing structures that +// implement the encoder or decoder interfaces that needed somewhere to live + +func testEncodable(t *testing.T, name string, in encoder, expect []byte) { + packet, err := encode(in, nil) + if err != nil { + t.Error(err) + } else if !bytes.Equal(packet, expect) { + t.Error("Encoding", name, "failed\ngot ", packet, "\nwant", expect) + } +} + +func testDecodable(t *testing.T, name string, out decoder, in []byte) { + err := decode(in, out) + if err != nil { + t.Error("Decoding", name, "failed:", err) + } +} + +func testVersionDecodable(t *testing.T, name string, out versionedDecoder, in []byte, version int16) { + err := versionedDecode(in, out, version) + if err != nil { + t.Error("Decoding", name, "version", version, "failed:", err) + } +} + +func testRequest(t *testing.T, name string, rb protocolBody, expected []byte) { + packet := testRequestEncode(t, name, rb, expected) + testRequestDecode(t, name, rb, packet) +} + +func testRequestEncode(t *testing.T, name string, rb protocolBody, expected []byte) []byte { + req := &request{correlationID: 123, clientID: "foo", body: rb} + packet, err := encode(req, nil) + headerSize := 14 + len("foo") + if err != nil { + t.Error(err) + } else if !bytes.Equal(packet[headerSize:], expected) { + t.Error("Encoding", name, "failed\ngot ", packet[headerSize:], "\nwant", expected) + } + return packet +} + +func testRequestDecode(t *testing.T, name string, rb protocolBody, packet []byte) { + decoded, n, err := decodeRequest(bytes.NewReader(packet)) + if err != nil { + t.Error("Failed to decode request", err) + } else if decoded.correlationID != 123 || decoded.clientID != "foo" { + t.Errorf("Decoded header %q is not valid: %+v", name, decoded) + } else if !reflect.DeepEqual(rb, decoded.body) { + t.Error(spew.Sprintf("Decoded request %q does not match the encoded one\nencoded: %+v\ndecoded: %+v", name, rb, decoded.body)) + } else if n != len(packet) { + t.Errorf("Decoded request %q bytes: %d does not match the encoded one: %d\n", name, n, len(packet)) + } +} + +func testResponse(t *testing.T, name string, res protocolBody, expected []byte) { + encoded, err := encode(res, nil) + if err != nil { + t.Error(err) + } else if expected != nil && !bytes.Equal(encoded, expected) { + t.Error("Encoding", name, "failed\ngot ", encoded, "\nwant", expected) + } + + decoded := reflect.New(reflect.TypeOf(res).Elem()).Interface().(versionedDecoder) + if err := versionedDecode(encoded, decoded, res.version()); err != nil { + t.Error("Decoding", name, "failed:", err) + } + + if !reflect.DeepEqual(decoded, res) { + t.Errorf("Decoded response does not match the encoded one\nencoded: %#v\ndecoded: %#v", res, decoded) + } +} diff --git a/vendor/github.com/Shopify/sarama/response_header.go b/vendor/github.com/Shopify/sarama/response_header.go new file mode 100644 index 00000000..f3f4d27d --- /dev/null +++ b/vendor/github.com/Shopify/sarama/response_header.go @@ -0,0 +1,21 @@ +package sarama + +import "fmt" + +type responseHeader struct { + length int32 + correlationID int32 +} + +func (r *responseHeader) decode(pd packetDecoder) (err error) { + r.length, err = pd.getInt32() + if err != nil { + return err + } + if r.length <= 4 || r.length > MaxResponseSize { + return PacketDecodingError{fmt.Sprintf("message of length %d too large or too small", r.length)} + } + + r.correlationID, err = pd.getInt32() + return err +} diff --git a/vendor/github.com/Shopify/sarama/response_header_test.go b/vendor/github.com/Shopify/sarama/response_header_test.go new file mode 100644 index 00000000..8f9fdb80 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/response_header_test.go @@ -0,0 +1,21 @@ +package sarama + +import "testing" + +var ( + responseHeaderBytes = []byte{ + 0x00, 0x00, 0x0f, 0x00, + 0x0a, 0xbb, 0xcc, 0xff} +) + +func TestResponseHeader(t *testing.T) { + header := responseHeader{} + + testDecodable(t, "response header", &header, responseHeaderBytes) + if header.length != 0xf00 { + t.Error("Decoding header length failed, got", header.length) + } + if header.correlationID != 0x0abbccff { + t.Error("Decoding header correlation id failed, got", header.correlationID) + } +} diff --git a/vendor/github.com/Shopify/sarama/sarama.go b/vendor/github.com/Shopify/sarama/sarama.go new file mode 100644 index 00000000..7d5dc60d --- /dev/null +++ b/vendor/github.com/Shopify/sarama/sarama.go @@ -0,0 +1,99 @@ +/* +Package sarama is a pure Go client library for dealing with Apache Kafka (versions 0.8 and later). It includes a high-level +API for easily producing and consuming messages, and a low-level API for controlling bytes on the wire when the high-level +API is insufficient. Usage examples for the high-level APIs are provided inline with their full documentation. + +To produce messages, use either the AsyncProducer or the SyncProducer. The AsyncProducer accepts messages on a channel +and produces them asynchronously in the background as efficiently as possible; it is preferred in most cases. +The SyncProducer provides a method which will block until Kafka acknowledges the message as produced. This can be +useful but comes with two caveats: it will generally be less efficient, and the actual durability guarantees +depend on the configured value of `Producer.RequiredAcks`. There are configurations where a message acknowledged by the +SyncProducer can still sometimes be lost. + +To consume messages, use the Consumer. Note that Sarama's Consumer implementation does not currently support automatic +consumer-group rebalancing and offset tracking. For Zookeeper-based tracking (Kafka 0.8.2 and earlier), the +https://github.com/wvanbergen/kafka library builds on Sarama to add this support. For Kafka-based tracking (Kafka 0.9 +and later), the https://github.com/bsm/sarama-cluster library builds on Sarama to add this support. + +For lower-level needs, the Broker and Request/Response objects permit precise control over each connection +and message sent on the wire; the Client provides higher-level metadata management that is shared between +the producers and the consumer. The Request/Response objects and properties are mostly undocumented, as they line up +exactly with the protocol fields documented by Kafka at +https://cwiki.apache.org/confluence/display/KAFKA/A+Guide+To+The+Kafka+Protocol + +Metrics are exposed through https://github.com/rcrowley/go-metrics library in a local registry. + +Broker related metrics: + + +----------------------------------------------+------------+---------------------------------------------------------------+ + | Name | Type | Description | + +----------------------------------------------+------------+---------------------------------------------------------------+ + | incoming-byte-rate | meter | Bytes/second read off all brokers | + | incoming-byte-rate-for-broker- | meter | Bytes/second read off a given broker | + | outgoing-byte-rate | meter | Bytes/second written off all brokers | + | outgoing-byte-rate-for-broker- | meter | Bytes/second written off a given broker | + | request-rate | meter | Requests/second sent to all brokers | + | request-rate-for-broker- | meter | Requests/second sent to a given broker | + | request-size | histogram | Distribution of the request size in bytes for all brokers | + | request-size-for-broker- | histogram | Distribution of the request size in bytes for a given broker | + | request-latency-in-ms | histogram | Distribution of the request latency in ms for all brokers | + | request-latency-in-ms-for-broker- | histogram | Distribution of the request latency in ms for a given broker | + | response-rate | meter | Responses/second received from all brokers | + | response-rate-for-broker- | meter | Responses/second received from a given broker | + | response-size | histogram | Distribution of the response size in bytes for all brokers | + | response-size-for-broker- | histogram | Distribution of the response size in bytes for a given broker | + +----------------------------------------------+------------+---------------------------------------------------------------+ + +Note that we do not gather specific metrics for seed brokers but they are part of the "all brokers" metrics. + +Producer related metrics: + + +-------------------------------------------+------------+--------------------------------------------------------------------------------------+ + | Name | Type | Description | + +-------------------------------------------+------------+--------------------------------------------------------------------------------------+ + | batch-size | histogram | Distribution of the number of bytes sent per partition per request for all topics | + | batch-size-for-topic- | histogram | Distribution of the number of bytes sent per partition per request for a given topic | + | record-send-rate | meter | Records/second sent to all topics | + | record-send-rate-for-topic- | meter | Records/second sent to a given topic | + | records-per-request | histogram | Distribution of the number of records sent per request for all topics | + | records-per-request-for-topic- | histogram | Distribution of the number of records sent per request for a given topic | + | compression-ratio | histogram | Distribution of the compression ratio times 100 of record batches for all topics | + | compression-ratio-for-topic- | histogram | Distribution of the compression ratio times 100 of record batches for a given topic | + +-------------------------------------------+------------+--------------------------------------------------------------------------------------+ + +*/ +package sarama + +import ( + "io/ioutil" + "log" +) + +// Logger is the instance of a StdLogger interface that Sarama writes connection +// management events to. By default it is set to discard all log messages via ioutil.Discard, +// but you can set it to redirect wherever you want. +var Logger StdLogger = log.New(ioutil.Discard, "[Sarama] ", log.LstdFlags) + +// StdLogger is used to log error messages. +type StdLogger interface { + Print(v ...interface{}) + Printf(format string, v ...interface{}) + Println(v ...interface{}) +} + +// PanicHandler is called for recovering from panics spawned internally to the library (and thus +// not recoverable by the caller's goroutine). Defaults to nil, which means panics are not recovered. +var PanicHandler func(interface{}) + +// MaxRequestSize is the maximum size (in bytes) of any request that Sarama will attempt to send. Trying +// to send a request larger than this will result in an PacketEncodingError. The default of 100 MiB is aligned +// with Kafka's default `socket.request.max.bytes`, which is the largest request the broker will attempt +// to process. +var MaxRequestSize int32 = 100 * 1024 * 1024 + +// MaxResponseSize is the maximum size (in bytes) of any response that Sarama will attempt to parse. If +// a broker returns a response message larger than this value, Sarama will return a PacketDecodingError to +// protect the client from running out of memory. Please note that brokers do not have any natural limit on +// the size of responses they send. In particular, they can send arbitrarily large fetch responses to consumers +// (see https://issues.apache.org/jira/browse/KAFKA-2063). +var MaxResponseSize int32 = 100 * 1024 * 1024 diff --git a/vendor/github.com/Shopify/sarama/sasl_handshake_request.go b/vendor/github.com/Shopify/sarama/sasl_handshake_request.go new file mode 100644 index 00000000..fbbc8947 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/sasl_handshake_request.go @@ -0,0 +1,33 @@ +package sarama + +type SaslHandshakeRequest struct { + Mechanism string +} + +func (r *SaslHandshakeRequest) encode(pe packetEncoder) error { + if err := pe.putString(r.Mechanism); err != nil { + return err + } + + return nil +} + +func (r *SaslHandshakeRequest) decode(pd packetDecoder, version int16) (err error) { + if r.Mechanism, err = pd.getString(); err != nil { + return err + } + + return nil +} + +func (r *SaslHandshakeRequest) key() int16 { + return 17 +} + +func (r *SaslHandshakeRequest) version() int16 { + return 0 +} + +func (r *SaslHandshakeRequest) requiredVersion() KafkaVersion { + return V0_10_0_0 +} diff --git a/vendor/github.com/Shopify/sarama/sasl_handshake_request_test.go b/vendor/github.com/Shopify/sarama/sasl_handshake_request_test.go new file mode 100644 index 00000000..806e628f --- /dev/null +++ b/vendor/github.com/Shopify/sarama/sasl_handshake_request_test.go @@ -0,0 +1,17 @@ +package sarama + +import "testing" + +var ( + baseSaslRequest = []byte{ + 0, 3, 'f', 'o', 'o', // Mechanism + } +) + +func TestSaslHandshakeRequest(t *testing.T) { + var request *SaslHandshakeRequest + + request = new(SaslHandshakeRequest) + request.Mechanism = "foo" + testRequest(t, "basic", request, baseSaslRequest) +} diff --git a/vendor/github.com/Shopify/sarama/sasl_handshake_response.go b/vendor/github.com/Shopify/sarama/sasl_handshake_response.go new file mode 100644 index 00000000..ef290d4b --- /dev/null +++ b/vendor/github.com/Shopify/sarama/sasl_handshake_response.go @@ -0,0 +1,38 @@ +package sarama + +type SaslHandshakeResponse struct { + Err KError + EnabledMechanisms []string +} + +func (r *SaslHandshakeResponse) encode(pe packetEncoder) error { + pe.putInt16(int16(r.Err)) + return pe.putStringArray(r.EnabledMechanisms) +} + +func (r *SaslHandshakeResponse) decode(pd packetDecoder, version int16) error { + kerr, err := pd.getInt16() + if err != nil { + return err + } + + r.Err = KError(kerr) + + if r.EnabledMechanisms, err = pd.getStringArray(); err != nil { + return err + } + + return nil +} + +func (r *SaslHandshakeResponse) key() int16 { + return 17 +} + +func (r *SaslHandshakeResponse) version() int16 { + return 0 +} + +func (r *SaslHandshakeResponse) requiredVersion() KafkaVersion { + return V0_10_0_0 +} diff --git a/vendor/github.com/Shopify/sarama/sasl_handshake_response_test.go b/vendor/github.com/Shopify/sarama/sasl_handshake_response_test.go new file mode 100644 index 00000000..1fd4c79e --- /dev/null +++ b/vendor/github.com/Shopify/sarama/sasl_handshake_response_test.go @@ -0,0 +1,24 @@ +package sarama + +import "testing" + +var ( + saslHandshakeResponse = []byte{ + 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x03, 'f', 'o', 'o', + } +) + +func TestSaslHandshakeResponse(t *testing.T) { + var response *SaslHandshakeResponse + + response = new(SaslHandshakeResponse) + testVersionDecodable(t, "no error", response, saslHandshakeResponse, 0) + if response.Err != ErrNoError { + t.Error("Decoding error failed: no error expected but found", response.Err) + } + if response.EnabledMechanisms[0] != "foo" { + t.Error("Decoding error failed: expected 'foo' but found", response.EnabledMechanisms) + } +} diff --git a/vendor/github.com/Shopify/sarama/sync_group_request.go b/vendor/github.com/Shopify/sarama/sync_group_request.go new file mode 100644 index 00000000..fe207080 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/sync_group_request.go @@ -0,0 +1,100 @@ +package sarama + +type SyncGroupRequest struct { + GroupId string + GenerationId int32 + MemberId string + GroupAssignments map[string][]byte +} + +func (r *SyncGroupRequest) encode(pe packetEncoder) error { + if err := pe.putString(r.GroupId); err != nil { + return err + } + + pe.putInt32(r.GenerationId) + + if err := pe.putString(r.MemberId); err != nil { + return err + } + + if err := pe.putArrayLength(len(r.GroupAssignments)); err != nil { + return err + } + for memberId, memberAssignment := range r.GroupAssignments { + if err := pe.putString(memberId); err != nil { + return err + } + if err := pe.putBytes(memberAssignment); err != nil { + return err + } + } + + return nil +} + +func (r *SyncGroupRequest) decode(pd packetDecoder, version int16) (err error) { + if r.GroupId, err = pd.getString(); err != nil { + return + } + if r.GenerationId, err = pd.getInt32(); err != nil { + return + } + if r.MemberId, err = pd.getString(); err != nil { + return + } + + n, err := pd.getArrayLength() + if err != nil { + return err + } + if n == 0 { + return nil + } + + r.GroupAssignments = make(map[string][]byte) + for i := 0; i < n; i++ { + memberId, err := pd.getString() + if err != nil { + return err + } + memberAssignment, err := pd.getBytes() + if err != nil { + return err + } + + r.GroupAssignments[memberId] = memberAssignment + } + + return nil +} + +func (r *SyncGroupRequest) key() int16 { + return 14 +} + +func (r *SyncGroupRequest) version() int16 { + return 0 +} + +func (r *SyncGroupRequest) requiredVersion() KafkaVersion { + return V0_9_0_0 +} + +func (r *SyncGroupRequest) AddGroupAssignment(memberId string, memberAssignment []byte) { + if r.GroupAssignments == nil { + r.GroupAssignments = make(map[string][]byte) + } + + r.GroupAssignments[memberId] = memberAssignment +} + +func (r *SyncGroupRequest) AddGroupAssignmentMember(memberId string, memberAssignment *ConsumerGroupMemberAssignment) error { + bin, err := encode(memberAssignment, nil) + if err != nil { + return err + } + + r.AddGroupAssignment(memberId, bin) + return nil +} diff --git a/vendor/github.com/Shopify/sarama/sync_group_request_test.go b/vendor/github.com/Shopify/sarama/sync_group_request_test.go new file mode 100644 index 00000000..3f537ef9 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/sync_group_request_test.go @@ -0,0 +1,38 @@ +package sarama + +import "testing" + +var ( + emptySyncGroupRequest = []byte{ + 0, 3, 'f', 'o', 'o', // Group ID + 0x00, 0x01, 0x02, 0x03, // Generation ID + 0, 3, 'b', 'a', 'z', // Member ID + 0, 0, 0, 0, // no assignments + } + + populatedSyncGroupRequest = []byte{ + 0, 3, 'f', 'o', 'o', // Group ID + 0x00, 0x01, 0x02, 0x03, // Generation ID + 0, 3, 'b', 'a', 'z', // Member ID + 0, 0, 0, 1, // one assignment + 0, 3, 'b', 'a', 'z', // Member ID + 0, 0, 0, 3, 'f', 'o', 'o', // Member assignment + } +) + +func TestSyncGroupRequest(t *testing.T) { + var request *SyncGroupRequest + + request = new(SyncGroupRequest) + request.GroupId = "foo" + request.GenerationId = 66051 + request.MemberId = "baz" + testRequest(t, "empty", request, emptySyncGroupRequest) + + request = new(SyncGroupRequest) + request.GroupId = "foo" + request.GenerationId = 66051 + request.MemberId = "baz" + request.AddGroupAssignment("baz", []byte("foo")) + testRequest(t, "populated", request, populatedSyncGroupRequest) +} diff --git a/vendor/github.com/Shopify/sarama/sync_group_response.go b/vendor/github.com/Shopify/sarama/sync_group_response.go new file mode 100644 index 00000000..194b382b --- /dev/null +++ b/vendor/github.com/Shopify/sarama/sync_group_response.go @@ -0,0 +1,41 @@ +package sarama + +type SyncGroupResponse struct { + Err KError + MemberAssignment []byte +} + +func (r *SyncGroupResponse) GetMemberAssignment() (*ConsumerGroupMemberAssignment, error) { + assignment := new(ConsumerGroupMemberAssignment) + err := decode(r.MemberAssignment, assignment) + return assignment, err +} + +func (r *SyncGroupResponse) encode(pe packetEncoder) error { + pe.putInt16(int16(r.Err)) + return pe.putBytes(r.MemberAssignment) +} + +func (r *SyncGroupResponse) decode(pd packetDecoder, version int16) (err error) { + kerr, err := pd.getInt16() + if err != nil { + return err + } + + r.Err = KError(kerr) + + r.MemberAssignment, err = pd.getBytes() + return +} + +func (r *SyncGroupResponse) key() int16 { + return 14 +} + +func (r *SyncGroupResponse) version() int16 { + return 0 +} + +func (r *SyncGroupResponse) requiredVersion() KafkaVersion { + return V0_9_0_0 +} diff --git a/vendor/github.com/Shopify/sarama/sync_group_response_test.go b/vendor/github.com/Shopify/sarama/sync_group_response_test.go new file mode 100644 index 00000000..6fb70885 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/sync_group_response_test.go @@ -0,0 +1,40 @@ +package sarama + +import ( + "reflect" + "testing" +) + +var ( + syncGroupResponseNoError = []byte{ + 0x00, 0x00, // No error + 0, 0, 0, 3, 0x01, 0x02, 0x03, // Member assignment data + } + + syncGroupResponseWithError = []byte{ + 0, 27, // ErrRebalanceInProgress + 0, 0, 0, 0, // No member assignment data + } +) + +func TestSyncGroupResponse(t *testing.T) { + var response *SyncGroupResponse + + response = new(SyncGroupResponse) + testVersionDecodable(t, "no error", response, syncGroupResponseNoError, 0) + if response.Err != ErrNoError { + t.Error("Decoding Err failed: no error expected but found", response.Err) + } + if !reflect.DeepEqual(response.MemberAssignment, []byte{0x01, 0x02, 0x03}) { + t.Error("Decoding MemberAssignment failed, found:", response.MemberAssignment) + } + + response = new(SyncGroupResponse) + testVersionDecodable(t, "no error", response, syncGroupResponseWithError, 0) + if response.Err != ErrRebalanceInProgress { + t.Error("Decoding Err failed: ErrRebalanceInProgress expected but found", response.Err) + } + if !reflect.DeepEqual(response.MemberAssignment, []byte{}) { + t.Error("Decoding MemberAssignment failed, found:", response.MemberAssignment) + } +} diff --git a/vendor/github.com/Shopify/sarama/sync_producer.go b/vendor/github.com/Shopify/sarama/sync_producer.go new file mode 100644 index 00000000..dd096b6d --- /dev/null +++ b/vendor/github.com/Shopify/sarama/sync_producer.go @@ -0,0 +1,164 @@ +package sarama + +import "sync" + +// SyncProducer publishes Kafka messages, blocking until they have been acknowledged. It routes messages to the correct +// broker, refreshing metadata as appropriate, and parses responses for errors. You must call Close() on a producer +// to avoid leaks, it may not be garbage-collected automatically when it passes out of scope. +// +// The SyncProducer comes with two caveats: it will generally be less efficient than the AsyncProducer, and the actual +// durability guarantee provided when a message is acknowledged depend on the configured value of `Producer.RequiredAcks`. +// There are configurations where a message acknowledged by the SyncProducer can still sometimes be lost. +// +// For implementation reasons, the SyncProducer requires `Producer.Return.Errors` and `Producer.Return.Successes` to +// be set to true in its configuration. +type SyncProducer interface { + + // SendMessage produces a given message, and returns only when it either has + // succeeded or failed to produce. It will return the partition and the offset + // of the produced message, or an error if the message failed to produce. + SendMessage(msg *ProducerMessage) (partition int32, offset int64, err error) + + // SendMessages produces a given set of messages, and returns only when all + // messages in the set have either succeeded or failed. Note that messages + // can succeed and fail individually; if some succeed and some fail, + // SendMessages will return an error. + SendMessages(msgs []*ProducerMessage) error + + // Close shuts down the producer and waits for any buffered messages to be + // flushed. You must call this function before a producer object passes out of + // scope, as it may otherwise leak memory. You must call this before calling + // Close on the underlying client. + Close() error +} + +type syncProducer struct { + producer *asyncProducer + wg sync.WaitGroup +} + +// NewSyncProducer creates a new SyncProducer using the given broker addresses and configuration. +func NewSyncProducer(addrs []string, config *Config) (SyncProducer, error) { + if config == nil { + config = NewConfig() + config.Producer.Return.Successes = true + } + + if err := verifyProducerConfig(config); err != nil { + return nil, err + } + + p, err := NewAsyncProducer(addrs, config) + if err != nil { + return nil, err + } + return newSyncProducerFromAsyncProducer(p.(*asyncProducer)), nil +} + +// NewSyncProducerFromClient creates a new SyncProducer using the given client. It is still +// necessary to call Close() on the underlying client when shutting down this producer. +func NewSyncProducerFromClient(client Client) (SyncProducer, error) { + if err := verifyProducerConfig(client.Config()); err != nil { + return nil, err + } + + p, err := NewAsyncProducerFromClient(client) + if err != nil { + return nil, err + } + return newSyncProducerFromAsyncProducer(p.(*asyncProducer)), nil +} + +func newSyncProducerFromAsyncProducer(p *asyncProducer) *syncProducer { + sp := &syncProducer{producer: p} + + sp.wg.Add(2) + go withRecover(sp.handleSuccesses) + go withRecover(sp.handleErrors) + + return sp +} + +func verifyProducerConfig(config *Config) error { + if !config.Producer.Return.Errors { + return ConfigurationError("Producer.Return.Errors must be true to be used in a SyncProducer") + } + if !config.Producer.Return.Successes { + return ConfigurationError("Producer.Return.Successes must be true to be used in a SyncProducer") + } + return nil +} + +func (sp *syncProducer) SendMessage(msg *ProducerMessage) (partition int32, offset int64, err error) { + oldMetadata := msg.Metadata + defer func() { + msg.Metadata = oldMetadata + }() + + expectation := make(chan *ProducerError, 1) + msg.Metadata = expectation + sp.producer.Input() <- msg + + if err := <-expectation; err != nil { + return -1, -1, err.Err + } + + return msg.Partition, msg.Offset, nil +} + +func (sp *syncProducer) SendMessages(msgs []*ProducerMessage) error { + savedMetadata := make([]interface{}, len(msgs)) + for i := range msgs { + savedMetadata[i] = msgs[i].Metadata + } + defer func() { + for i := range msgs { + msgs[i].Metadata = savedMetadata[i] + } + }() + + expectations := make(chan chan *ProducerError, len(msgs)) + go func() { + for _, msg := range msgs { + expectation := make(chan *ProducerError, 1) + msg.Metadata = expectation + sp.producer.Input() <- msg + expectations <- expectation + } + close(expectations) + }() + + var errors ProducerErrors + for expectation := range expectations { + if err := <-expectation; err != nil { + errors = append(errors, err) + } + } + + if len(errors) > 0 { + return errors + } + return nil +} + +func (sp *syncProducer) handleSuccesses() { + defer sp.wg.Done() + for msg := range sp.producer.Successes() { + expectation := msg.Metadata.(chan *ProducerError) + expectation <- nil + } +} + +func (sp *syncProducer) handleErrors() { + defer sp.wg.Done() + for err := range sp.producer.Errors() { + expectation := err.Msg.Metadata.(chan *ProducerError) + expectation <- err + } +} + +func (sp *syncProducer) Close() error { + sp.producer.AsyncClose() + sp.wg.Wait() + return nil +} diff --git a/vendor/github.com/Shopify/sarama/sync_producer_test.go b/vendor/github.com/Shopify/sarama/sync_producer_test.go new file mode 100644 index 00000000..cb97548d --- /dev/null +++ b/vendor/github.com/Shopify/sarama/sync_producer_test.go @@ -0,0 +1,199 @@ +package sarama + +import ( + "log" + "sync" + "testing" +) + +func TestSyncProducer(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + leader := NewMockBroker(t, 2) + + metadataResponse := new(MetadataResponse) + metadataResponse.AddBroker(leader.Addr(), leader.BrokerID()) + metadataResponse.AddTopicPartition("my_topic", 0, leader.BrokerID(), nil, nil, ErrNoError) + seedBroker.Returns(metadataResponse) + + prodSuccess := new(ProduceResponse) + prodSuccess.AddTopicPartition("my_topic", 0, ErrNoError) + for i := 0; i < 10; i++ { + leader.Returns(prodSuccess) + } + + producer, err := NewSyncProducer([]string{seedBroker.Addr()}, nil) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 10; i++ { + msg := &ProducerMessage{ + Topic: "my_topic", + Value: StringEncoder(TestMessage), + Metadata: "test", + } + + partition, offset, err := producer.SendMessage(msg) + + if partition != 0 || msg.Partition != partition { + t.Error("Unexpected partition") + } + if offset != 0 || msg.Offset != offset { + t.Error("Unexpected offset") + } + if str, ok := msg.Metadata.(string); !ok || str != "test" { + t.Error("Unexpected metadata") + } + if err != nil { + t.Error(err) + } + } + + safeClose(t, producer) + leader.Close() + seedBroker.Close() +} + +func TestSyncProducerBatch(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + leader := NewMockBroker(t, 2) + + metadataResponse := new(MetadataResponse) + metadataResponse.AddBroker(leader.Addr(), leader.BrokerID()) + metadataResponse.AddTopicPartition("my_topic", 0, leader.BrokerID(), nil, nil, ErrNoError) + seedBroker.Returns(metadataResponse) + + prodSuccess := new(ProduceResponse) + prodSuccess.AddTopicPartition("my_topic", 0, ErrNoError) + leader.Returns(prodSuccess) + + config := NewConfig() + config.Producer.Flush.Messages = 3 + config.Producer.Return.Successes = true + producer, err := NewSyncProducer([]string{seedBroker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + err = producer.SendMessages([]*ProducerMessage{ + { + Topic: "my_topic", + Value: StringEncoder(TestMessage), + Metadata: "test", + }, + { + Topic: "my_topic", + Value: StringEncoder(TestMessage), + Metadata: "test", + }, + { + Topic: "my_topic", + Value: StringEncoder(TestMessage), + Metadata: "test", + }, + }) + + if err != nil { + t.Error(err) + } + + safeClose(t, producer) + leader.Close() + seedBroker.Close() +} + +func TestConcurrentSyncProducer(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + leader := NewMockBroker(t, 2) + + metadataResponse := new(MetadataResponse) + metadataResponse.AddBroker(leader.Addr(), leader.BrokerID()) + metadataResponse.AddTopicPartition("my_topic", 0, leader.BrokerID(), nil, nil, ErrNoError) + seedBroker.Returns(metadataResponse) + + prodSuccess := new(ProduceResponse) + prodSuccess.AddTopicPartition("my_topic", 0, ErrNoError) + leader.Returns(prodSuccess) + + config := NewConfig() + config.Producer.Flush.Messages = 100 + config.Producer.Return.Successes = true + producer, err := NewSyncProducer([]string{seedBroker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + wg := sync.WaitGroup{} + + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + msg := &ProducerMessage{Topic: "my_topic", Value: StringEncoder(TestMessage)} + partition, _, err := producer.SendMessage(msg) + if partition != 0 { + t.Error("Unexpected partition") + } + if err != nil { + t.Error(err) + } + wg.Done() + }() + } + wg.Wait() + + safeClose(t, producer) + leader.Close() + seedBroker.Close() +} + +func TestSyncProducerToNonExistingTopic(t *testing.T) { + broker := NewMockBroker(t, 1) + + metadataResponse := new(MetadataResponse) + metadataResponse.AddBroker(broker.Addr(), broker.BrokerID()) + metadataResponse.AddTopicPartition("my_topic", 0, broker.BrokerID(), nil, nil, ErrNoError) + broker.Returns(metadataResponse) + + config := NewConfig() + config.Metadata.Retry.Max = 0 + config.Producer.Retry.Max = 0 + config.Producer.Return.Successes = true + + producer, err := NewSyncProducer([]string{broker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + metadataResponse = new(MetadataResponse) + metadataResponse.AddTopic("unknown", ErrUnknownTopicOrPartition) + broker.Returns(metadataResponse) + + _, _, err = producer.SendMessage(&ProducerMessage{Topic: "unknown"}) + if err != ErrUnknownTopicOrPartition { + t.Error("Uxpected ErrUnknownTopicOrPartition, found:", err) + } + + safeClose(t, producer) + broker.Close() +} + +// This example shows the basic usage pattern of the SyncProducer. +func ExampleSyncProducer() { + producer, err := NewSyncProducer([]string{"localhost:9092"}, nil) + if err != nil { + log.Fatalln(err) + } + defer func() { + if err := producer.Close(); err != nil { + log.Fatalln(err) + } + }() + + msg := &ProducerMessage{Topic: "my_topic", Value: StringEncoder("testing 123")} + partition, offset, err := producer.SendMessage(msg) + if err != nil { + log.Printf("FAILED to send message: %s\n", err) + } else { + log.Printf("> message sent to partition %d at offset %d\n", partition, offset) + } +} diff --git a/vendor/github.com/Shopify/sarama/tools/README.md b/vendor/github.com/Shopify/sarama/tools/README.md new file mode 100644 index 00000000..3464c4ad --- /dev/null +++ b/vendor/github.com/Shopify/sarama/tools/README.md @@ -0,0 +1,10 @@ +# Sarama tools + +This folder contains applications that are useful for exploration of your Kafka cluster, or instrumentation. +Some of these tools mirror tools that ship with Kafka, but these tools won't require installing the JVM to function. + +- [kafka-console-producer](./kafka-console-producer): a command line tool to produce a single message to your Kafka custer. +- [kafka-console-partitionconsumer](./kafka-console-partitionconsumer): (deprecated) a command line tool to consume a single partition of a topic on your Kafka cluster. +- [kafka-console-consumer](./kafka-console-consumer): a command line tool to consume arbitrary partitions of a topic on your Kafka cluster. + +To install all tools, run `go get github.com/Shopify/sarama/tools/...` diff --git a/vendor/github.com/Shopify/sarama/tools/kafka-console-consumer/.gitignore b/vendor/github.com/Shopify/sarama/tools/kafka-console-consumer/.gitignore new file mode 100644 index 00000000..67da9dfa --- /dev/null +++ b/vendor/github.com/Shopify/sarama/tools/kafka-console-consumer/.gitignore @@ -0,0 +1,2 @@ +kafka-console-consumer +kafka-console-consumer.test diff --git a/vendor/github.com/Shopify/sarama/tools/kafka-console-consumer/README.md b/vendor/github.com/Shopify/sarama/tools/kafka-console-consumer/README.md new file mode 100644 index 00000000..4e77f0b7 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/tools/kafka-console-consumer/README.md @@ -0,0 +1,29 @@ +# kafka-console-consumer + +A simple command line tool to consume partitions of a topic and print the +messages on the standard output. + +### Installation + + go get github.com/Shopify/sarama/tools/kafka-console-consumer + +### Usage + + # Minimum invocation + kafka-console-consumer -topic=test -brokers=kafka1:9092 + + # It will pick up a KAFKA_PEERS environment variable + export KAFKA_PEERS=kafka1:9092,kafka2:9092,kafka3:9092 + kafka-console-consumer -topic=test + + # You can specify the offset you want to start at. It can be either + # `oldest`, `newest`. The default is `newest`. + kafka-console-consumer -topic=test -offset=oldest + kafka-console-consumer -topic=test -offset=newest + + # You can specify the partition(s) you want to consume as a comma-separated + # list. The default is `all`. + kafka-console-consumer -topic=test -partitions=1,2,3 + + # Display all command line options + kafka-console-consumer -help diff --git a/vendor/github.com/Shopify/sarama/tools/kafka-console-consumer/kafka-console-consumer.go b/vendor/github.com/Shopify/sarama/tools/kafka-console-consumer/kafka-console-consumer.go new file mode 100644 index 00000000..0f1eb89a --- /dev/null +++ b/vendor/github.com/Shopify/sarama/tools/kafka-console-consumer/kafka-console-consumer.go @@ -0,0 +1,145 @@ +package main + +import ( + "flag" + "fmt" + "log" + "os" + "os/signal" + "strconv" + "strings" + "sync" + + "github.com/Shopify/sarama" +) + +var ( + brokerList = flag.String("brokers", os.Getenv("KAFKA_PEERS"), "The comma separated list of brokers in the Kafka cluster") + topic = flag.String("topic", "", "REQUIRED: the topic to consume") + partitions = flag.String("partitions", "all", "The partitions to consume, can be 'all' or comma-separated numbers") + offset = flag.String("offset", "newest", "The offset to start with. Can be `oldest`, `newest`") + verbose = flag.Bool("verbose", false, "Whether to turn on sarama logging") + bufferSize = flag.Int("buffer-size", 256, "The buffer size of the message channel.") + + logger = log.New(os.Stderr, "", log.LstdFlags) +) + +func main() { + flag.Parse() + + if *brokerList == "" { + printUsageErrorAndExit("You have to provide -brokers as a comma-separated list, or set the KAFKA_PEERS environment variable.") + } + + if *topic == "" { + printUsageErrorAndExit("-topic is required") + } + + if *verbose { + sarama.Logger = logger + } + + var initialOffset int64 + switch *offset { + case "oldest": + initialOffset = sarama.OffsetOldest + case "newest": + initialOffset = sarama.OffsetNewest + default: + printUsageErrorAndExit("-offset should be `oldest` or `newest`") + } + + c, err := sarama.NewConsumer(strings.Split(*brokerList, ","), nil) + if err != nil { + printErrorAndExit(69, "Failed to start consumer: %s", err) + } + + partitionList, err := getPartitions(c) + if err != nil { + printErrorAndExit(69, "Failed to get the list of partitions: %s", err) + } + + var ( + messages = make(chan *sarama.ConsumerMessage, *bufferSize) + closing = make(chan struct{}) + wg sync.WaitGroup + ) + + go func() { + signals := make(chan os.Signal, 1) + signal.Notify(signals, os.Kill, os.Interrupt) + <-signals + logger.Println("Initiating shutdown of consumer...") + close(closing) + }() + + for _, partition := range partitionList { + pc, err := c.ConsumePartition(*topic, partition, initialOffset) + if err != nil { + printErrorAndExit(69, "Failed to start consumer for partition %d: %s", partition, err) + } + + go func(pc sarama.PartitionConsumer) { + <-closing + pc.AsyncClose() + }(pc) + + wg.Add(1) + go func(pc sarama.PartitionConsumer) { + defer wg.Done() + for message := range pc.Messages() { + messages <- message + } + }(pc) + } + + go func() { + for msg := range messages { + fmt.Printf("Partition:\t%d\n", msg.Partition) + fmt.Printf("Offset:\t%d\n", msg.Offset) + fmt.Printf("Key:\t%s\n", string(msg.Key)) + fmt.Printf("Value:\t%s\n", string(msg.Value)) + fmt.Println() + } + }() + + wg.Wait() + logger.Println("Done consuming topic", *topic) + close(messages) + + if err := c.Close(); err != nil { + logger.Println("Failed to close consumer: ", err) + } +} + +func getPartitions(c sarama.Consumer) ([]int32, error) { + if *partitions == "all" { + return c.Partitions(*topic) + } + + tmp := strings.Split(*partitions, ",") + var pList []int32 + for i := range tmp { + val, err := strconv.ParseInt(tmp[i], 10, 32) + if err != nil { + return nil, err + } + pList = append(pList, int32(val)) + } + + return pList, nil +} + +func printErrorAndExit(code int, format string, values ...interface{}) { + fmt.Fprintf(os.Stderr, "ERROR: %s\n", fmt.Sprintf(format, values...)) + fmt.Fprintln(os.Stderr) + os.Exit(code) +} + +func printUsageErrorAndExit(format string, values ...interface{}) { + fmt.Fprintf(os.Stderr, "ERROR: %s\n", fmt.Sprintf(format, values...)) + fmt.Fprintln(os.Stderr) + fmt.Fprintln(os.Stderr, "Available command line options:") + flag.PrintDefaults() + os.Exit(64) +} diff --git a/vendor/github.com/Shopify/sarama/tools/kafka-console-partitionconsumer/.gitignore b/vendor/github.com/Shopify/sarama/tools/kafka-console-partitionconsumer/.gitignore new file mode 100644 index 00000000..5837fe8c --- /dev/null +++ b/vendor/github.com/Shopify/sarama/tools/kafka-console-partitionconsumer/.gitignore @@ -0,0 +1,2 @@ +kafka-console-partitionconsumer +kafka-console-partitionconsumer.test diff --git a/vendor/github.com/Shopify/sarama/tools/kafka-console-partitionconsumer/README.md b/vendor/github.com/Shopify/sarama/tools/kafka-console-partitionconsumer/README.md new file mode 100644 index 00000000..646dd5f5 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/tools/kafka-console-partitionconsumer/README.md @@ -0,0 +1,28 @@ +# kafka-console-partitionconsumer + +NOTE: this tool is deprecated in favour of the more general and more powerful +`kafka-console-consumer`. + +A simple command line tool to consume a partition of a topic and print the messages +on the standard output. + +### Installation + + go get github.com/Shopify/sarama/tools/kafka-console-partitionconsumer + +### Usage + + # Minimum invocation + kafka-console-partitionconsumer -topic=test -partition=4 -brokers=kafka1:9092 + + # It will pick up a KAFKA_PEERS environment variable + export KAFKA_PEERS=kafka1:9092,kafka2:9092,kafka3:9092 + kafka-console-partitionconsumer -topic=test -partition=4 + + # You can specify the offset you want to start at. It can be either + # `oldest`, `newest`, or a specific offset number + kafka-console-partitionconsumer -topic=test -partition=3 -offset=oldest + kafka-console-partitionconsumer -topic=test -partition=2 -offset=1337 + + # Display all command line options + kafka-console-partitionconsumer -help diff --git a/vendor/github.com/Shopify/sarama/tools/kafka-console-partitionconsumer/kafka-console-partitionconsumer.go b/vendor/github.com/Shopify/sarama/tools/kafka-console-partitionconsumer/kafka-console-partitionconsumer.go new file mode 100644 index 00000000..d5e4464d --- /dev/null +++ b/vendor/github.com/Shopify/sarama/tools/kafka-console-partitionconsumer/kafka-console-partitionconsumer.go @@ -0,0 +1,102 @@ +package main + +import ( + "flag" + "fmt" + "log" + "os" + "os/signal" + "strconv" + "strings" + + "github.com/Shopify/sarama" +) + +var ( + brokerList = flag.String("brokers", os.Getenv("KAFKA_PEERS"), "The comma separated list of brokers in the Kafka cluster") + topic = flag.String("topic", "", "REQUIRED: the topic to consume") + partition = flag.Int("partition", -1, "REQUIRED: the partition to consume") + offset = flag.String("offset", "newest", "The offset to start with. Can be `oldest`, `newest`, or an actual offset") + verbose = flag.Bool("verbose", false, "Whether to turn on sarama logging") + + logger = log.New(os.Stderr, "", log.LstdFlags) +) + +func main() { + flag.Parse() + + if *brokerList == "" { + printUsageErrorAndExit("You have to provide -brokers as a comma-separated list, or set the KAFKA_PEERS environment variable.") + } + + if *topic == "" { + printUsageErrorAndExit("-topic is required") + } + + if *partition == -1 { + printUsageErrorAndExit("-partition is required") + } + + if *verbose { + sarama.Logger = logger + } + + var ( + initialOffset int64 + offsetError error + ) + switch *offset { + case "oldest": + initialOffset = sarama.OffsetOldest + case "newest": + initialOffset = sarama.OffsetNewest + default: + initialOffset, offsetError = strconv.ParseInt(*offset, 10, 64) + } + + if offsetError != nil { + printUsageErrorAndExit("Invalid initial offset: %s", *offset) + } + + c, err := sarama.NewConsumer(strings.Split(*brokerList, ","), nil) + if err != nil { + printErrorAndExit(69, "Failed to start consumer: %s", err) + } + + pc, err := c.ConsumePartition(*topic, int32(*partition), initialOffset) + if err != nil { + printErrorAndExit(69, "Failed to start partition consumer: %s", err) + } + + go func() { + signals := make(chan os.Signal, 1) + signal.Notify(signals, os.Kill, os.Interrupt) + <-signals + pc.AsyncClose() + }() + + for msg := range pc.Messages() { + fmt.Printf("Offset:\t%d\n", msg.Offset) + fmt.Printf("Key:\t%s\n", string(msg.Key)) + fmt.Printf("Value:\t%s\n", string(msg.Value)) + fmt.Println() + } + + if err := c.Close(); err != nil { + logger.Println("Failed to close consumer: ", err) + } +} + +func printErrorAndExit(code int, format string, values ...interface{}) { + fmt.Fprintf(os.Stderr, "ERROR: %s\n", fmt.Sprintf(format, values...)) + fmt.Fprintln(os.Stderr) + os.Exit(code) +} + +func printUsageErrorAndExit(format string, values ...interface{}) { + fmt.Fprintf(os.Stderr, "ERROR: %s\n", fmt.Sprintf(format, values...)) + fmt.Fprintln(os.Stderr) + fmt.Fprintln(os.Stderr, "Available command line options:") + flag.PrintDefaults() + os.Exit(64) +} diff --git a/vendor/github.com/Shopify/sarama/tools/kafka-console-producer/.gitignore b/vendor/github.com/Shopify/sarama/tools/kafka-console-producer/.gitignore new file mode 100644 index 00000000..2b9e563a --- /dev/null +++ b/vendor/github.com/Shopify/sarama/tools/kafka-console-producer/.gitignore @@ -0,0 +1,2 @@ +kafka-console-producer +kafka-console-producer.test diff --git a/vendor/github.com/Shopify/sarama/tools/kafka-console-producer/README.md b/vendor/github.com/Shopify/sarama/tools/kafka-console-producer/README.md new file mode 100644 index 00000000..6b3a65f2 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/tools/kafka-console-producer/README.md @@ -0,0 +1,34 @@ +# kafka-console-producer + +A simple command line tool to produce a single message to Kafka. + +### Installation + + go get github.com/Shopify/sarama/tools/kafka-console-producer + + +### Usage + + # Minimum invocation + kafka-console-producer -topic=test -value=value -brokers=kafka1:9092 + + # It will pick up a KAFKA_PEERS environment variable + export KAFKA_PEERS=kafka1:9092,kafka2:9092,kafka3:9092 + kafka-console-producer -topic=test -value=value + + # It will read the value from stdin by using pipes + echo "hello world" | kafka-console-producer -topic=test + + # Specify a key: + echo "hello world" | kafka-console-producer -topic=test -key=key + + # Partitioning: by default, kafka-console-producer will partition as follows: + # - manual partitioning if a -partition is provided + # - hash partitioning by key if a -key is provided + # - random partioning otherwise. + # + # You can override this using the -partitioner argument: + echo "hello world" | kafka-console-producer -topic=test -key=key -partitioner=random + + # Display all command line options + kafka-console-producer -help diff --git a/vendor/github.com/Shopify/sarama/tools/kafka-console-producer/kafka-console-producer.go b/vendor/github.com/Shopify/sarama/tools/kafka-console-producer/kafka-console-producer.go new file mode 100644 index 00000000..83054ed7 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/tools/kafka-console-producer/kafka-console-producer.go @@ -0,0 +1,124 @@ +package main + +import ( + "flag" + "fmt" + "io/ioutil" + "log" + "os" + "strings" + + "github.com/Shopify/sarama" + "github.com/rcrowley/go-metrics" +) + +var ( + brokerList = flag.String("brokers", os.Getenv("KAFKA_PEERS"), "The comma separated list of brokers in the Kafka cluster. You can also set the KAFKA_PEERS environment variable") + topic = flag.String("topic", "", "REQUIRED: the topic to produce to") + key = flag.String("key", "", "The key of the message to produce. Can be empty.") + value = flag.String("value", "", "REQUIRED: the value of the message to produce. You can also provide the value on stdin.") + partitioner = flag.String("partitioner", "", "The partitioning scheme to use. Can be `hash`, `manual`, or `random`") + partition = flag.Int("partition", -1, "The partition to produce to.") + verbose = flag.Bool("verbose", false, "Turn on sarama logging to stderr") + showMetrics = flag.Bool("metrics", false, "Output metrics on successful publish to stderr") + silent = flag.Bool("silent", false, "Turn off printing the message's topic, partition, and offset to stdout") + + logger = log.New(os.Stderr, "", log.LstdFlags) +) + +func main() { + flag.Parse() + + if *brokerList == "" { + printUsageErrorAndExit("no -brokers specified. Alternatively, set the KAFKA_PEERS environment variable") + } + + if *topic == "" { + printUsageErrorAndExit("no -topic specified") + } + + if *verbose { + sarama.Logger = logger + } + + config := sarama.NewConfig() + config.Producer.RequiredAcks = sarama.WaitForAll + config.Producer.Return.Successes = true + + switch *partitioner { + case "": + if *partition >= 0 { + config.Producer.Partitioner = sarama.NewManualPartitioner + } else { + config.Producer.Partitioner = sarama.NewHashPartitioner + } + case "hash": + config.Producer.Partitioner = sarama.NewHashPartitioner + case "random": + config.Producer.Partitioner = sarama.NewRandomPartitioner + case "manual": + config.Producer.Partitioner = sarama.NewManualPartitioner + if *partition == -1 { + printUsageErrorAndExit("-partition is required when partitioning manually") + } + default: + printUsageErrorAndExit(fmt.Sprintf("Partitioner %s not supported.", *partitioner)) + } + + message := &sarama.ProducerMessage{Topic: *topic, Partition: int32(*partition)} + + if *key != "" { + message.Key = sarama.StringEncoder(*key) + } + + if *value != "" { + message.Value = sarama.StringEncoder(*value) + } else if stdinAvailable() { + bytes, err := ioutil.ReadAll(os.Stdin) + if err != nil { + printErrorAndExit(66, "Failed to read data from the standard input: %s", err) + } + message.Value = sarama.ByteEncoder(bytes) + } else { + printUsageErrorAndExit("-value is required, or you have to provide the value on stdin") + } + + producer, err := sarama.NewSyncProducer(strings.Split(*brokerList, ","), config) + if err != nil { + printErrorAndExit(69, "Failed to open Kafka producer: %s", err) + } + defer func() { + if err := producer.Close(); err != nil { + logger.Println("Failed to close Kafka producer cleanly:", err) + } + }() + + partition, offset, err := producer.SendMessage(message) + if err != nil { + printErrorAndExit(69, "Failed to produce message: %s", err) + } else if !*silent { + fmt.Printf("topic=%s\tpartition=%d\toffset=%d\n", *topic, partition, offset) + } + if *showMetrics { + metrics.WriteOnce(config.MetricRegistry, os.Stderr) + } +} + +func printErrorAndExit(code int, format string, values ...interface{}) { + fmt.Fprintf(os.Stderr, "ERROR: %s\n", fmt.Sprintf(format, values...)) + fmt.Fprintln(os.Stderr) + os.Exit(code) +} + +func printUsageErrorAndExit(message string) { + fmt.Fprintln(os.Stderr, "ERROR:", message) + fmt.Fprintln(os.Stderr) + fmt.Fprintln(os.Stderr, "Available command line options:") + flag.PrintDefaults() + os.Exit(64) +} + +func stdinAvailable() bool { + stat, _ := os.Stdin.Stat() + return (stat.Mode() & os.ModeCharDevice) == 0 +} diff --git a/vendor/github.com/Shopify/sarama/utils.go b/vendor/github.com/Shopify/sarama/utils.go new file mode 100644 index 00000000..d36db921 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/utils.go @@ -0,0 +1,153 @@ +package sarama + +import ( + "bufio" + "net" + "sort" +) + +type none struct{} + +// make []int32 sortable so we can sort partition numbers +type int32Slice []int32 + +func (slice int32Slice) Len() int { + return len(slice) +} + +func (slice int32Slice) Less(i, j int) bool { + return slice[i] < slice[j] +} + +func (slice int32Slice) Swap(i, j int) { + slice[i], slice[j] = slice[j], slice[i] +} + +func dupeAndSort(input []int32) []int32 { + ret := make([]int32, 0, len(input)) + for _, val := range input { + ret = append(ret, val) + } + + sort.Sort(int32Slice(ret)) + return ret +} + +func withRecover(fn func()) { + defer func() { + handler := PanicHandler + if handler != nil { + if err := recover(); err != nil { + handler(err) + } + } + }() + + fn() +} + +func safeAsyncClose(b *Broker) { + tmp := b // local var prevents clobbering in goroutine + go withRecover(func() { + if connected, _ := tmp.Connected(); connected { + if err := tmp.Close(); err != nil { + Logger.Println("Error closing broker", tmp.ID(), ":", err) + } + } + }) +} + +// Encoder is a simple interface for any type that can be encoded as an array of bytes +// in order to be sent as the key or value of a Kafka message. Length() is provided as an +// optimization, and must return the same as len() on the result of Encode(). +type Encoder interface { + Encode() ([]byte, error) + Length() int +} + +// make strings and byte slices encodable for convenience so they can be used as keys +// and/or values in kafka messages + +// StringEncoder implements the Encoder interface for Go strings so that they can be used +// as the Key or Value in a ProducerMessage. +type StringEncoder string + +func (s StringEncoder) Encode() ([]byte, error) { + return []byte(s), nil +} + +func (s StringEncoder) Length() int { + return len(s) +} + +// ByteEncoder implements the Encoder interface for Go byte slices so that they can be used +// as the Key or Value in a ProducerMessage. +type ByteEncoder []byte + +func (b ByteEncoder) Encode() ([]byte, error) { + return b, nil +} + +func (b ByteEncoder) Length() int { + return len(b) +} + +// bufConn wraps a net.Conn with a buffer for reads to reduce the number of +// reads that trigger syscalls. +type bufConn struct { + net.Conn + buf *bufio.Reader +} + +func newBufConn(conn net.Conn) *bufConn { + return &bufConn{ + Conn: conn, + buf: bufio.NewReader(conn), + } +} + +func (bc *bufConn) Read(b []byte) (n int, err error) { + return bc.buf.Read(b) +} + +// KafkaVersion instances represent versions of the upstream Kafka broker. +type KafkaVersion struct { + // it's a struct rather than just typing the array directly to make it opaque and stop people + // generating their own arbitrary versions + version [4]uint +} + +func newKafkaVersion(major, minor, veryMinor, patch uint) KafkaVersion { + return KafkaVersion{ + version: [4]uint{major, minor, veryMinor, patch}, + } +} + +// IsAtLeast return true if and only if the version it is called on is +// greater than or equal to the version passed in: +// V1.IsAtLeast(V2) // false +// V2.IsAtLeast(V1) // true +func (v KafkaVersion) IsAtLeast(other KafkaVersion) bool { + for i := range v.version { + if v.version[i] > other.version[i] { + return true + } else if v.version[i] < other.version[i] { + return false + } + } + return true +} + +// Effective constants defining the supported kafka versions. +var ( + V0_8_2_0 = newKafkaVersion(0, 8, 2, 0) + V0_8_2_1 = newKafkaVersion(0, 8, 2, 1) + V0_8_2_2 = newKafkaVersion(0, 8, 2, 2) + V0_9_0_0 = newKafkaVersion(0, 9, 0, 0) + V0_9_0_1 = newKafkaVersion(0, 9, 0, 1) + V0_10_0_0 = newKafkaVersion(0, 10, 0, 0) + V0_10_0_1 = newKafkaVersion(0, 10, 0, 1) + V0_10_1_0 = newKafkaVersion(0, 10, 1, 0) + V0_10_2_0 = newKafkaVersion(0, 10, 2, 0) + minVersion = V0_8_2_0 +) diff --git a/vendor/github.com/Shopify/sarama/utils_test.go b/vendor/github.com/Shopify/sarama/utils_test.go new file mode 100644 index 00000000..a9e09502 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/utils_test.go @@ -0,0 +1,21 @@ +package sarama + +import "testing" + +func TestVersionCompare(t *testing.T) { + if V0_8_2_0.IsAtLeast(V0_8_2_1) { + t.Error("0.8.2.0 >= 0.8.2.1") + } + if !V0_8_2_1.IsAtLeast(V0_8_2_0) { + t.Error("! 0.8.2.1 >= 0.8.2.0") + } + if !V0_8_2_0.IsAtLeast(V0_8_2_0) { + t.Error("! 0.8.2.0 >= 0.8.2.0") + } + if !V0_9_0_0.IsAtLeast(V0_8_2_1) { + t.Error("! 0.9.0.0 >= 0.8.2.1") + } + if V0_8_2_1.IsAtLeast(V0_10_0_0) { + t.Error("0.8.2.1 >= 0.10.0.0") + } +} diff --git a/vendor/github.com/Shopify/sarama/vagrant/boot_cluster.sh b/vendor/github.com/Shopify/sarama/vagrant/boot_cluster.sh new file mode 100644 index 00000000..95e47dde --- /dev/null +++ b/vendor/github.com/Shopify/sarama/vagrant/boot_cluster.sh @@ -0,0 +1,22 @@ +#!/bin/sh + +set -ex + +# Launch and wait for toxiproxy +${REPOSITORY_ROOT}/vagrant/run_toxiproxy.sh & +while ! nc -q 1 localhost 2181 ${KAFKA_INSTALL_ROOT}/zookeeper-${ZK_PORT}/myid +done diff --git a/vendor/github.com/Shopify/sarama/vagrant/kafka.conf b/vendor/github.com/Shopify/sarama/vagrant/kafka.conf new file mode 100644 index 00000000..25101df5 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/vagrant/kafka.conf @@ -0,0 +1,9 @@ +start on started zookeeper-ZK_PORT +stop on stopping zookeeper-ZK_PORT + +# Use a script instead of exec (using env stanza leaks KAFKA_HEAP_OPTS from zookeeper) +script + sleep 2 + export KAFKA_HEAP_OPTS="-Xmx320m" + exec /opt/kafka-KAFKAID/bin/kafka-server-start.sh /opt/kafka-KAFKAID/config/server.properties +end script diff --git a/vendor/github.com/Shopify/sarama/vagrant/provision.sh b/vendor/github.com/Shopify/sarama/vagrant/provision.sh new file mode 100644 index 00000000..ace768f4 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/vagrant/provision.sh @@ -0,0 +1,15 @@ +#!/bin/sh + +set -ex + +apt-get update +yes | apt-get install default-jre + +export KAFKA_INSTALL_ROOT=/opt +export KAFKA_HOSTNAME=192.168.100.67 +export KAFKA_VERSION=0.9.0.1 +export REPOSITORY_ROOT=/vagrant + +sh /vagrant/vagrant/install_cluster.sh +sh /vagrant/vagrant/setup_services.sh +sh /vagrant/vagrant/create_topics.sh diff --git a/vendor/github.com/Shopify/sarama/vagrant/run_toxiproxy.sh b/vendor/github.com/Shopify/sarama/vagrant/run_toxiproxy.sh new file mode 100644 index 00000000..e52c00e7 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/vagrant/run_toxiproxy.sh @@ -0,0 +1,22 @@ +#!/bin/sh + +set -ex + +${KAFKA_INSTALL_ROOT}/toxiproxy -port 8474 -host 0.0.0.0 & +PID=$! + +while ! nc -q 1 localhost 8474 + +# The number of threads handling network requests +num.network.threads=2 + +# The number of threads doing disk I/O +num.io.threads=8 + +# The send buffer (SO_SNDBUF) used by the socket server +socket.send.buffer.bytes=1048576 + +# The receive buffer (SO_RCVBUF) used by the socket server +socket.receive.buffer.bytes=1048576 + +# The maximum size of a request that the socket server will accept (protection against OOM) +socket.request.max.bytes=104857600 + + +############################# Log Basics ############################# + +# A comma seperated list of directories under which to store log files +log.dirs=KAFKA_DATADIR + +# The default number of log partitions per topic. More partitions allow greater +# parallelism for consumption, but this will also result in more files across +# the brokers. +num.partitions=2 + +# Create new topics with a replication factor of 2 so failover can be tested +# more easily. +default.replication.factor=2 + +auto.create.topics.enable=false +delete.topic.enable=true + +############################# Log Flush Policy ############################# + +# Messages are immediately written to the filesystem but by default we only fsync() to sync +# the OS cache lazily. The following configurations control the flush of data to disk. +# There are a few important trade-offs here: +# 1. Durability: Unflushed data may be lost if you are not using replication. +# 2. Latency: Very large flush intervals may lead to latency spikes when the flush does occur as there will be a lot of data to flush. +# 3. Throughput: The flush is generally the most expensive operation, and a small flush interval may lead to exceessive seeks. +# The settings below allow one to configure the flush policy to flush data after a period of time or +# every N messages (or both). This can be done globally and overridden on a per-topic basis. + +# The number of messages to accept before forcing a flush of data to disk +#log.flush.interval.messages=10000 + +# The maximum amount of time a message can sit in a log before we force a flush +#log.flush.interval.ms=1000 + +############################# Log Retention Policy ############################# + +# The following configurations control the disposal of log segments. The policy can +# be set to delete segments after a period of time, or after a given size has accumulated. +# A segment will be deleted whenever *either* of these criteria are met. Deletion always happens +# from the end of the log. + +# The minimum age of a log file to be eligible for deletion +log.retention.hours=168 + +# A size-based retention policy for logs. Segments are pruned from the log as long as the remaining +# segments don't drop below log.retention.bytes. +log.retention.bytes=268435456 + +# The maximum size of a log segment file. When this size is reached a new log segment will be created. +log.segment.bytes=268435456 + +# The interval at which log segments are checked to see if they can be deleted according +# to the retention policies +log.retention.check.interval.ms=60000 + +# By default the log cleaner is disabled and the log retention policy will default to just delete segments after their retention expires. +# If log.cleaner.enable=true is set the cleaner will be enabled and individual logs can then be marked for log compaction. +log.cleaner.enable=false + +############################# Zookeeper ############################# + +# Zookeeper connection string (see zookeeper docs for details). +# This is a comma separated host:port pairs, each corresponding to a zk +# server. e.g. "127.0.0.1:3000,127.0.0.1:3001,127.0.0.1:3002". +# You can also append an optional chroot string to the urls to specify the +# root directory for all kafka znodes. +zookeeper.connect=localhost:ZK_PORT + +# Timeout in ms for connecting to zookeeper +zookeeper.session.timeout.ms=3000 +zookeeper.connection.timeout.ms=3000 diff --git a/vendor/github.com/Shopify/sarama/vagrant/setup_services.sh b/vendor/github.com/Shopify/sarama/vagrant/setup_services.sh new file mode 100644 index 00000000..81d8ea05 --- /dev/null +++ b/vendor/github.com/Shopify/sarama/vagrant/setup_services.sh @@ -0,0 +1,29 @@ +#!/bin/sh + +set -ex + +stop toxiproxy || true +cp ${REPOSITORY_ROOT}/vagrant/toxiproxy.conf /etc/init/toxiproxy.conf +cp ${REPOSITORY_ROOT}/vagrant/run_toxiproxy.sh ${KAFKA_INSTALL_ROOT}/ +start toxiproxy + +for i in 1 2 3 4 5; do + ZK_PORT=`expr $i + 2180` + KAFKA_PORT=`expr $i + 9090` + + stop zookeeper-${ZK_PORT} || true + + # set up zk service + cp ${REPOSITORY_ROOT}/vagrant/zookeeper.conf /etc/init/zookeeper-${ZK_PORT}.conf + sed -i s/KAFKAID/${KAFKA_PORT}/g /etc/init/zookeeper-${ZK_PORT}.conf + + # set up kafka service + cp ${REPOSITORY_ROOT}/vagrant/kafka.conf /etc/init/kafka-${KAFKA_PORT}.conf + sed -i s/KAFKAID/${KAFKA_PORT}/g /etc/init/kafka-${KAFKA_PORT}.conf + sed -i s/ZK_PORT/${ZK_PORT}/g /etc/init/kafka-${KAFKA_PORT}.conf + + start zookeeper-${ZK_PORT} +done + +# Wait for the last kafka node to finish booting +while ! nc -q 1 localhost 29095 0 { + msg += fmt.Sprintf(" (%s)", strings.Join(p.Choices, "/")) + } + if p.Default != "" { + msg += " [" + p.Default + "]" + } + return msg + ": " +} + +func (p *Prompter) errorMsg() string { + if p.Regexp != nil { + return fmt.Sprintf("# Answer should match /%s/", p.Regexp) + } + if p.Choices != nil && len(p.Choices) > 0 { + if len(p.Choices) == 1 { + return fmt.Sprintf("# Enter `%s`", p.Choices[0]) + } + choices := make([]string, len(p.Choices)-1) + for i, v := range p.Choices[:len(p.Choices)-1] { + choices[i] = "`" + v + "`" + } + return fmt.Sprintf("# Enter %s or `%s`", strings.Join(choices, ", "), p.Choices[len(p.Choices)-1]) + } + return "" +} + +func (p *Prompter) inputIsValid(input string) bool { + return p.regexp().MatchString(input) +} + +var allReg = regexp.MustCompile(`.*`) + +func (p *Prompter) regexp() *regexp.Regexp { + if p.Regexp != nil { + return p.Regexp + } + if p.reg != nil { + return p.reg + } + if p.Choices == nil || len(p.Choices) == 0 { + p.reg = allReg + return p.reg + } + + choices := make([]string, len(p.Choices)) + for i, v := range p.Choices { + choices[i] = regexp.QuoteMeta(v) + } + ignoreReg := "" + if p.IgnoreCase { + ignoreReg = "(?i)" + } + p.reg = regexp.MustCompile(fmt.Sprintf(`%s\A(?:%s)\z`, ignoreReg, strings.Join(choices, "|"))) + return p.reg +} diff --git a/vendor/github.com/Songmu/prompter/prompter_test.go b/vendor/github.com/Songmu/prompter/prompter_test.go new file mode 100644 index 00000000..dc069cef --- /dev/null +++ b/vendor/github.com/Songmu/prompter/prompter_test.go @@ -0,0 +1,36 @@ +package prompter + +import ( + "fmt" + "testing" +) + +func TestMsg(t *testing.T) { + p := &Prompter{ + Choices: []string{"aa", "bb", "cc"}, + Default: "aa", + Message: "plaase select", + } + + if p.msg() != "plaase select (aa/bb/cc) [aa]: " { + t.Errorf("something went wrong") + } + + if p.errorMsg() != "# Enter `aa`, `bb` or `cc`" { + t.Errorf("something went wrong") + } + + if !p.inputIsValid("aa") { + t.Errorf("something went wrong") + } + + if p.inputIsValid("AA") { + t.Errorf("something went wrong") + } + + input := p.Prompt() + if input != "aa" { + fmt.Printf("%s\n", input) + t.Errorf("something went wrong") + } +} diff --git a/vendor/github.com/alicebob/miniredis/.travis.yml b/vendor/github.com/alicebob/miniredis/.travis.yml new file mode 100644 index 00000000..94033b1b --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/.travis.yml @@ -0,0 +1,10 @@ +language: go + +install: go get -t + +script: make test + +sudo: false + +go: + - 1.8 diff --git a/vendor/github.com/alicebob/miniredis/LICENSE b/vendor/github.com/alicebob/miniredis/LICENSE new file mode 100644 index 00000000..bb02657c --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2014 Harmen + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/alicebob/miniredis/Makefile b/vendor/github.com/alicebob/miniredis/Makefile new file mode 100644 index 00000000..c3b5340f --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/Makefile @@ -0,0 +1,13 @@ +.PHONY: all install test vet + +all: test vet + +install: + go install + +test: + go test ./... + +vet: + go vet ./... + golint ./... diff --git a/vendor/github.com/alicebob/miniredis/README.md b/vendor/github.com/alicebob/miniredis/README.md new file mode 100644 index 00000000..ff88ac78 --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/README.md @@ -0,0 +1,276 @@ +# Miniredis + +Pure Go Redis test server, used in Go unittests. + + +## + +Sometimes you want to test code which uses Redis, without making it a full-blown +integration test. +Miniredis implements (parts of) the Redis server, to be used in unittests. It +enables a simple, cheap, in-memory, Redis replacement, with a real TCP interface. Think of it as the Redis version of `net/http/httptest`. + +It saves you from using mock code, and since the redis server lives in the +test process you can query for values directly, without going through the server +stack. + +There are no dependencies on external binaries, so you can easily integrate it in automated build processes. + +## 1.0.0 incompatiliby notice + +2.0.0 improves TTLs to be `time.Duration` values. `.Expire()` is removed and +replaced by `.TTL()`, which returns the TTL as a `time.Duration`. +This should be the change needed to upgrade: + +1.0.0: + + m.Expire() == 4 + +2.0.0: + + m.TTL() == 4 * time.Second + +Furthermore, `.SetTime()` is added to help with `EXPIREAT` commands, and `.FastForward()` is introduced to test keys expiration. + + +## Commands + +Implemented commands: + + - Connection (complete) + - AUTH -- see RequireAuth() + - ECHO + - PING + - SELECT + - QUIT + - Key + - DEL + - EXISTS + - EXPIRE + - EXPIREAT + - KEYS + - MOVE + - PERSIST + - PEXPIRE + - PEXPIREAT + - PTTL + - RENAME + - RENAMENX + - RANDOMKEY -- call math.rand.Seed(...) once before using. + - TTL + - TYPE + - SCAN + - Transactions (complete) + - DISCARD + - EXEC + - MULTI + - UNWATCH + - WATCH + - Server + - DBSIZE + - FLUSHALL + - FLUSHDB + - String keys (complete) + - APPEND + - BITCOUNT + - BITOP + - BITPOS + - DECR + - DECRBY + - GET + - GETBIT + - GETRANGE + - GETSET + - INCR + - INCRBY + - INCRBYFLOAT + - MGET + - MSET + - MSETNX + - PSETEX + - SET + - SETBIT + - SETEX + - SETNX + - SETRANGE + - STRLEN + - Hash keys (complete) + - HDEL + - HEXISTS + - HGET + - HGETALL + - HINCRBY + - HINCRBYFLOAT + - HKEYS + - HLEN + - HMGET + - HMSET + - HSET + - HSETNX + - HVALS + - HSCAN + - List keys (complete) + - BLPOP + - BRPOP + - BRPOPLPUSH + - LINDEX + - LINSERT + - LLEN + - LPOP + - LPUSH + - LPUSHX + - LRANGE + - LREM + - LSET + - LTRIM + - RPOP + - RPOPLPUSH + - RPUSH + - RPUSHX + - Set keys (complete) + - SADD + - SCARD + - SDIFF + - SDIFFSTORE + - SINTER + - SINTERSTORE + - SISMEMBER + - SMEMBERS + - SMOVE + - SPOP -- call math.rand.Seed(...) once before using. + - SRANDMEMBER -- call math.rand.Seed(...) once before using. + - SREM + - SUNION + - SUNIONSTORE + - SSCAN + - Sorted Set keys (complete) + - ZADD + - ZCARD + - ZCOUNT + - ZINCRBY + - ZINTERSTORE + - ZLEXCOUNT + - ZRANGE + - ZRANGEBYLEX + - ZRANGEBYSCORE + - ZRANK + - ZREM + - ZREMRANGEBYLEX + - ZREMRANGEBYRANK + - ZREMRANGEBYSCORE + - ZREVRANGE + - ZREVRANGEBYSCORE + - ZREVRANK + - ZSCORE + - ZUNIONSTORE + - ZSCAN + + +Since miniredis is intended to be used in unittests TTLs don't decrease +automatically. You can use `TTL()` to get the TTL (as a time.Duration) of a +key. It will return 0 when no TTL is set. EXPIREAT and PEXPIREAT values will be +converted to a duration. For that you can either set m.SetTime(t) to use that +time as the base for the (P)EXPIREAT conversion, or don't call SetTime(), in +which case time.Now() will be used. +`m.FastForward(d)` can be used to decrement all TTLs. All TTLs which become <= +0 will be removed. + +## Example + +``` Go +func TestSomething(t *testing.T) { + s, err := miniredis.Run() + if err != nil { + panic(err) + } + defer s.Close() + + // Optionally set some keys your code expects: + s.Set("foo", "bar") + s.HSet("some", "other", "key") + + // Run your code and see if it behaves. + // An example using the redigo library from "github.com/garyburd/redigo/redis": + c, err := redis.Dial("tcp", s.Addr()) + _, err = c.Do("SET", "foo", "bar") + + // Optionally check values in redis... + if got, err := s.Get("foo"); err != nil || got != "bar" { + t.Error("'foo' has the wrong value") + } + // ... or use a helper for that: + s.CheckGet(t, "foo", "bar") + + // TTL and expiration: + s.Set("foo", "bar") + s.SetTTL("foo", 10 * time.Second) + s.FastForward(11 * time.Second) + if s.Exists("foo") { + t.Fatal("'foo' should not have existed anymore") + } +} +``` + +## Not supported + +Commands which will probably not be implemented: + + - CLUSTER (all) + - ~~CLUSTER *~~ + - ~~READONLY~~ + - ~~READWRITE~~ + - GEO (all) -- unless someone needs these + - ~~GEOADD~~ + - ~~GEODIST~~ + - ~~GEOHASH~~ + - ~~GEOPOS~~ + - ~~GEORADIUS~~ + - ~~GEORADIUSBYMEMBER~~ + - HyperLogLog (all) -- unless someone needs these + - ~~PFADD~~ + - ~~PFCOUNT~~ + - ~~PFMERGE~~ + - Key + - ~~DUMP~~ + - ~~MIGRATE~~ + - ~~OBJECT~~ + - ~~RESTORE~~ + - ~~WAIT~~ + - Pub/Sub (all) + - ~~PSUBSCRIBE~~ + - ~~PUBLISH~~ + - ~~PUBSUB~~ + - ~~PUNSUBSCRIBE~~ + - ~~SUBSCRIBE~~ + - ~~UNSUBSCRIBE~~ + - Scripting (all) + - ~~EVAL~~ + - ~~EVALSHA~~ + - ~~SCRIPT *~~ + - Server + - ~~BGSAVE~~ + - ~~BGWRITEAOF~~ + - ~~CLIENT *~~ + - ~~COMMAND *~~ + - ~~CONFIG *~~ + - ~~DEBUG *~~ + - ~~INFO~~ + - ~~LASTSAVE~~ + - ~~MONITOR~~ + - ~~ROLE~~ + - ~~SAVE~~ + - ~~SHUTDOWN~~ + - ~~SLAVEOF~~ + - ~~SLOWLOG~~ + - ~~SYNC~~ + - ~~TIME~~ + + +## &c. + +See https://github.com/alicebob/miniredis_vs_redis for tests comparing +miniredis against the real thing. Tests are run against Redis 3.2.5 (Debian). + + +[![Build Status](https://travis-ci.org/alicebob/miniredis.svg?branch=master)](https://travis-ci.org/alicebob/miniredis) +[![GoDoc](https://godoc.org/github.com/alicebob/miniredis?status.svg)](https://godoc.org/github.com/alicebob/miniredis) diff --git a/vendor/github.com/alicebob/miniredis/check.go b/vendor/github.com/alicebob/miniredis/check.go new file mode 100644 index 00000000..8b42b2e0 --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/check.go @@ -0,0 +1,68 @@ +package miniredis + +// 'Fail' methods. + +import ( + "fmt" + "path/filepath" + "reflect" + "runtime" + "sort" +) + +// T is implemented by Testing.T +type T interface { + Fail() +} + +// CheckGet does not call Errorf() iff there is a string key with the +// expected value. Normal use case is `m.CheckGet(t, "username", "theking")`. +func (m *Miniredis) CheckGet(t T, key, expected string) { + found, err := m.Get(key) + if err != nil { + lError(t, "GET error, key %#v: %v", key, err) + return + } + if found != expected { + lError(t, "GET error, key %#v: Expected %#v, got %#v", key, expected, found) + return + } +} + +// CheckList does not call Errorf() iff there is a list key with the +// expected values. +// Normal use case is `m.CheckGet(t, "favorite_colors", "red", "green", "infrared")`. +func (m *Miniredis) CheckList(t T, key string, expected ...string) { + found, err := m.List(key) + if err != nil { + lError(t, "List error, key %#v: %v", key, err) + return + } + if !reflect.DeepEqual(expected, found) { + lError(t, "List error, key %#v: Expected %#v, got %#v", key, expected, found) + return + } +} + +// CheckSet does not call Errorf() iff there is a set key with the +// expected values. +// Normal use case is `m.CheckSet(t, "visited", "Rome", "Stockholm", "Dublin")`. +func (m *Miniredis) CheckSet(t T, key string, expected ...string) { + found, err := m.Members(key) + if err != nil { + lError(t, "Set error, key %#v: %v", key, err) + return + } + sort.Strings(expected) + if !reflect.DeepEqual(expected, found) { + lError(t, "Set error, key %#v: Expected %#v, got %#v", key, expected, found) + return + } +} + +func lError(t T, format string, args ...interface{}) { + _, file, line, _ := runtime.Caller(2) + prefix := fmt.Sprintf("%s:%d: ", filepath.Base(file), line) + fmt.Printf(prefix+format+"\n", args...) + t.Fail() +} diff --git a/vendor/github.com/alicebob/miniredis/cmd_connection.go b/vendor/github.com/alicebob/miniredis/cmd_connection.go new file mode 100644 index 00000000..dbe1f897 --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/cmd_connection.go @@ -0,0 +1,96 @@ +// Commands from http://redis.io/commands#connection + +package miniredis + +import ( + "strconv" + + "github.com/alicebob/miniredis/server" +) + +func commandsConnection(m *Miniredis) { + m.srv.Register("AUTH", m.cmdAuth) + m.srv.Register("ECHO", m.cmdEcho) + m.srv.Register("PING", m.cmdPing) + m.srv.Register("SELECT", m.cmdSelect) + m.srv.Register("QUIT", m.cmdQuit) +} + +// PING +func (m *Miniredis) cmdPing(c *server.Peer, cmd string, args []string) { + if !m.handleAuth(c) { + return + } + c.WriteInline("PONG") +} + +// AUTH +func (m *Miniredis) cmdAuth(c *server.Peer, cmd string, args []string) { + if len(args) != 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + pw := args[0] + + m.Lock() + defer m.Unlock() + if m.password == "" { + c.WriteError("ERR Client sent AUTH, but no password is set") + return + } + if m.password != pw { + c.WriteError("ERR invalid password") + return + } + + setAuthenticated(c) + c.WriteOK() +} + +// ECHO +func (m *Miniredis) cmdEcho(c *server.Peer, cmd string, args []string) { + if len(args) != 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + msg := args[0] + c.WriteBulk(msg) +} + +// SELECT +func (m *Miniredis) cmdSelect(c *server.Peer, cmd string, args []string) { + if len(args) != 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + id, err := strconv.Atoi(args[0]) + if err != nil { + id = 0 + } + + m.Lock() + defer m.Unlock() + + ctx := getCtx(c) + ctx.selectedDB = id + + c.WriteOK() +} + +// QUIT +func (m *Miniredis) cmdQuit(c *server.Peer, cmd string, args []string) { + // QUIT isn't transactionfied and accepts any arguments. + c.WriteOK() + c.Close() +} diff --git a/vendor/github.com/alicebob/miniredis/cmd_connection_test.go b/vendor/github.com/alicebob/miniredis/cmd_connection_test.go new file mode 100644 index 00000000..72108df9 --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/cmd_connection_test.go @@ -0,0 +1,92 @@ +package miniredis + +import ( + "testing" + + "github.com/garyburd/redigo/redis" +) + +func TestAuth(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + _, err = c.Do("AUTH", "foo", "bar") + assert(t, err != nil, "no password set") + + s.RequireAuth("nocomment") + _, err = c.Do("PING", "foo", "bar") + assert(t, err != nil, "need AUTH") + + _, err = c.Do("AUTH", "wrongpasswd") + assert(t, err != nil, "wrong password") + + _, err = c.Do("AUTH", "nocomment") + ok(t, err) + + _, err = c.Do("PING") + ok(t, err) +} + +func TestEcho(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + r, err := redis.String(c.Do("ECHO", "hello\nworld")) + ok(t, err) + equals(t, "hello\nworld", r) +} + +func TestSelect(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + _, err = redis.String(c.Do("SET", "foo", "bar")) + ok(t, err) + + _, err = redis.String(c.Do("SELECT", "5")) + ok(t, err) + + _, err = redis.String(c.Do("SET", "foo", "baz")) + ok(t, err) + + // Direct access. + got, err := s.Get("foo") + ok(t, err) + equals(t, "bar", got) + s.Select(5) + got, err = s.Get("foo") + ok(t, err) + equals(t, "baz", got) + + // Another connection should have its own idea of the db: + c2, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + v, err := redis.String(c2.Do("GET", "foo")) + ok(t, err) + equals(t, "bar", v) +} + +func TestQuit(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + v, err := redis.String(c.Do("QUIT")) + ok(t, err) + equals(t, "OK", v) + + v, err = redis.String(c.Do("PING")) + assert(t, err != nil, "QUIT closed the client") + equals(t, "", v) +} diff --git a/vendor/github.com/alicebob/miniredis/cmd_generic.go b/vendor/github.com/alicebob/miniredis/cmd_generic.go new file mode 100644 index 00000000..44b21711 --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/cmd_generic.go @@ -0,0 +1,479 @@ +// Commands from http://redis.io/commands#generic + +package miniredis + +import ( + "math/rand" + "strconv" + "strings" + "time" + + "github.com/alicebob/miniredis/server" +) + +// commandsGeneric handles EXPIRE, TTL, PERSIST, &c. +func commandsGeneric(m *Miniredis) { + m.srv.Register("DEL", m.cmdDel) + // DUMP + m.srv.Register("EXISTS", m.cmdExists) + m.srv.Register("EXPIRE", makeCmdExpire(m, false, time.Second)) + m.srv.Register("EXPIREAT", makeCmdExpire(m, true, time.Second)) + m.srv.Register("KEYS", m.cmdKeys) + // MIGRATE + m.srv.Register("MOVE", m.cmdMove) + // OBJECT + m.srv.Register("PERSIST", m.cmdPersist) + m.srv.Register("PEXPIRE", makeCmdExpire(m, false, time.Millisecond)) + m.srv.Register("PEXPIREAT", makeCmdExpire(m, true, time.Millisecond)) + m.srv.Register("PTTL", m.cmdPTTL) + m.srv.Register("RANDOMKEY", m.cmdRandomkey) + m.srv.Register("RENAME", m.cmdRename) + m.srv.Register("RENAMENX", m.cmdRenamenx) + // RESTORE + // SORT + m.srv.Register("TTL", m.cmdTTL) + m.srv.Register("TYPE", m.cmdType) + m.srv.Register("SCAN", m.cmdScan) +} + +// generic expire command for EXPIRE, PEXPIRE, EXPIREAT, PEXPIREAT +// d is the time unit. If unix is set it'll be seen as a unixtimestamp and +// converted to a duration. +func makeCmdExpire(m *Miniredis, unix bool, d time.Duration) func(*server.Peer, string, []string) { + return func(c *server.Peer, cmd string, args []string) { + if len(args) != 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + value := args[1] + i, err := strconv.Atoi(value) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + // Key must be present. + if _, ok := db.keys[key]; !ok { + c.WriteInt(0) + return + } + if unix { + var ts time.Time + switch d { + case time.Millisecond: + ts = time.Unix(0, int64(i)) + case time.Second: + ts = time.Unix(int64(i), 0) + default: + panic("invalid time unit (d). Fixme!") + } + now := m.now + if now.IsZero() { + now = time.Now().UTC() + } + db.ttl[key] = ts.Sub(now) + } else { + db.ttl[key] = time.Duration(i) * d + } + db.keyVersion[key]++ + db.checkTTL(key) + c.WriteInt(1) + }) + } +} + +// TTL +func (m *Miniredis) cmdTTL(c *server.Peer, cmd string, args []string) { + if len(args) != 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + key := args[0] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if _, ok := db.keys[key]; !ok { + // No such key + c.WriteInt(-2) + return + } + + v, ok := db.ttl[key] + if !ok { + // no expire value + c.WriteInt(-1) + return + } + c.WriteInt(int(v.Seconds())) + }) +} + +// PTTL +func (m *Miniredis) cmdPTTL(c *server.Peer, cmd string, args []string) { + if len(args) != 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + key := args[0] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if _, ok := db.keys[key]; !ok { + // no such key + c.WriteInt(-2) + return + } + + v, ok := db.ttl[key] + if !ok { + // no expire value + c.WriteInt(-1) + return + } + c.WriteInt(int(v.Nanoseconds() / 1000000)) + }) +} + +// PERSIST +func (m *Miniredis) cmdPersist(c *server.Peer, cmd string, args []string) { + if len(args) != 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + key := args[0] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if _, ok := db.keys[key]; !ok { + // no such key + c.WriteInt(0) + return + } + + if _, ok := db.ttl[key]; !ok { + // no expire value + c.WriteInt(0) + return + } + delete(db.ttl, key) + db.keyVersion[key]++ + c.WriteInt(1) + }) +} + +// DEL +func (m *Miniredis) cmdDel(c *server.Peer, cmd string, args []string) { + if !m.handleAuth(c) { + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + count := 0 + for _, key := range args { + if db.exists(key) { + count++ + } + db.del(key, true) // delete expire + } + c.WriteInt(count) + }) +} + +// TYPE +func (m *Miniredis) cmdType(c *server.Peer, cmd string, args []string) { + if len(args) != 1 { + setDirty(c) + c.WriteError("usage error") + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + t, ok := db.keys[key] + if !ok { + c.WriteInline("none") + return + } + + c.WriteInline(t) + }) +} + +// EXISTS +func (m *Miniredis) cmdExists(c *server.Peer, cmd string, args []string) { + if len(args) < 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + found := 0 + for _, k := range args { + if db.exists(k) { + found++ + } + } + c.WriteInt(found) + }) +} + +// MOVE +func (m *Miniredis) cmdMove(c *server.Peer, cmd string, args []string) { + if len(args) != 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + targetDB, err := strconv.Atoi(args[1]) + if err != nil { + targetDB = 0 + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + if ctx.selectedDB == targetDB { + c.WriteError("ERR source and destination objects are the same") + return + } + db := m.db(ctx.selectedDB) + targetDB := m.db(targetDB) + + if !db.move(key, targetDB) { + c.WriteInt(0) + return + } + c.WriteInt(1) + }) +} + +// KEYS +func (m *Miniredis) cmdKeys(c *server.Peer, cmd string, args []string) { + if len(args) != 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + keys := matchKeys(db.allKeys(), key) + c.WriteLen(len(keys)) + for _, s := range keys { + c.WriteBulk(s) + } + }) +} + +// RANDOMKEY +func (m *Miniredis) cmdRandomkey(c *server.Peer, cmd string, args []string) { + if len(args) != 0 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if len(db.keys) == 0 { + c.WriteNull() + return + } + nr := rand.Intn(len(db.keys)) + for k := range db.keys { + if nr == 0 { + c.WriteBulk(k) + return + } + nr-- + } + }) +} + +// RENAME +func (m *Miniredis) cmdRename(c *server.Peer, cmd string, args []string) { + if len(args) != 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + from, to := args[0], args[1] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(from) { + c.WriteError(msgKeyNotFound) + return + } + + db.rename(from, to) + c.WriteOK() + }) +} + +// RENAMENX +func (m *Miniredis) cmdRenamenx(c *server.Peer, cmd string, args []string) { + if len(args) != 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + from, to := args[0], args[1] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(from) { + c.WriteError(msgKeyNotFound) + return + } + + if db.exists(to) { + c.WriteInt(0) + return + } + + db.rename(from, to) + c.WriteInt(1) + }) +} + +// SCAN +func (m *Miniredis) cmdScan(c *server.Peer, cmd string, args []string) { + if len(args) < 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + cursor, err := strconv.Atoi(args[0]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidCursor) + return + } + args = args[1:] + + // MATCH and COUNT options + var withMatch bool + var match string + for len(args) > 0 { + if strings.ToLower(args[0]) == "count" { + // we do nothing with count + if len(args) < 2 { + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + if _, err := strconv.Atoi(args[1]); err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + args = args[2:] + continue + } + if strings.ToLower(args[0]) == "match" { + if len(args) < 2 { + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + withMatch = true + match, args = args[1], args[2:] + continue + } + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + // We return _all_ (matched) keys every time. + + if cursor != 0 { + // Invalid cursor. + c.WriteLen(2) + c.WriteBulk("0") // no next cursor + c.WriteLen(0) // no elements + return + } + + keys := db.allKeys() + if withMatch { + keys = matchKeys(keys, match) + } + + c.WriteLen(2) + c.WriteBulk("0") // no next cursor + c.WriteLen(len(keys)) + for _, k := range keys { + c.WriteBulk(k) + } + }) +} diff --git a/vendor/github.com/alicebob/miniredis/cmd_generic_test.go b/vendor/github.com/alicebob/miniredis/cmd_generic_test.go new file mode 100644 index 00000000..5adb4b89 --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/cmd_generic_test.go @@ -0,0 +1,649 @@ +package miniredis + +import ( + "testing" + "time" + + "github.com/garyburd/redigo/redis" +) + +// Test EXPIRE. Keys with an expiration are called volatile in Redis parlance. +func TestTTL(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + // Not volatile yet + { + equals(t, time.Duration(0), s.TTL("foo")) + b, err := redis.Int(c.Do("TTL", "foo")) + ok(t, err) + equals(t, -2, b) + } + + // Set something + { + _, err := c.Do("SET", "foo", "bar") + ok(t, err) + // key exists, but no Expire set yet + b, err := redis.Int(c.Do("TTL", "foo")) + ok(t, err) + equals(t, -1, b) + + n, err := redis.Int(c.Do("EXPIRE", "foo", "1200")) + ok(t, err) + equals(t, 1, n) // EXPIRE returns 1 on success + + equals(t, 1200*time.Second, s.TTL("foo")) + b, err = redis.Int(c.Do("TTL", "foo")) + ok(t, err) + equals(t, 1200, b) + } + + // A SET resets the expire. + { + _, err := c.Do("SET", "foo", "bar") + ok(t, err) + b, err := redis.Int(c.Do("TTL", "foo")) + ok(t, err) + equals(t, -1, b) + } + + // Set a non-existing key + { + n, err := redis.Int(c.Do("EXPIRE", "nokey", "1200")) + ok(t, err) + equals(t, 0, n) // EXPIRE returns 0 on failure. + } + + // Remove an expire + { + + // No key yet + n, err := redis.Int(c.Do("PERSIST", "exkey")) + ok(t, err) + equals(t, 0, n) + + _, err = c.Do("SET", "exkey", "bar") + ok(t, err) + + // No timeout yet + n, err = redis.Int(c.Do("PERSIST", "exkey")) + ok(t, err) + equals(t, 0, n) + + _, err = redis.Int(c.Do("EXPIRE", "exkey", "1200")) + ok(t, err) + + // All fine now + n, err = redis.Int(c.Do("PERSIST", "exkey")) + ok(t, err) + equals(t, 1, n) + + // No TTL left + b, err := redis.Int(c.Do("TTL", "exkey")) + ok(t, err) + equals(t, -1, b) + } + + // Hash key works fine, too + { + _, err := c.Do("HSET", "wim", "zus", "jet") + ok(t, err) + b, err := redis.Int(c.Do("EXPIRE", "wim", "1234")) + ok(t, err) + equals(t, 1, b) + } + + { + _, err = c.Do("SET", "wim", "zus") + ok(t, err) + _, err = redis.Int(c.Do("EXPIRE", "wim", -1200)) + ok(t, err) + equals(t, false, s.Exists("wim")) + } +} + +func TestExpireat(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + // Not volatile yet + { + equals(t, time.Duration(0), s.TTL("foo")) + b, err := redis.Int(c.Do("TTL", "foo")) + ok(t, err) + equals(t, -2, b) + } + + // Set something + { + _, err := c.Do("SET", "foo", "bar") + ok(t, err) + // Key exists, but no ttl set. + b, err := redis.Int(c.Do("TTL", "foo")) + ok(t, err) + equals(t, -1, b) + + s.SetTime(time.Unix(1234567890, 0)) + n, err := redis.Int(c.Do("EXPIREAT", "foo", 1234567890+100)) + ok(t, err) + equals(t, 1, n) // EXPIREAT returns 1 on success. + + equals(t, 100*time.Second, s.TTL("foo")) + b, err = redis.Int(c.Do("TTL", "foo")) + ok(t, err) + equals(t, 100, b) + equals(t, 100*time.Second, s.TTL("foo")) + } +} + +func TestPexpire(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + // Key exists + { + ok(t, s.Set("foo", "bar")) + b, err := redis.Int(c.Do("PEXPIRE", "foo", 12)) + ok(t, err) + equals(t, 1, b) + + e, err := redis.Int(c.Do("PTTL", "foo")) + ok(t, err) + equals(t, 12, e) + + equals(t, 12*time.Millisecond, s.TTL("foo")) + } + // Key doesn't exist + { + b, err := redis.Int(c.Do("PEXPIRE", "nosuch", 12)) + ok(t, err) + equals(t, 0, b) + + e, err := redis.Int(c.Do("PTTL", "nosuch")) + ok(t, err) + equals(t, -2, e) + } + + // No expire + { + s.Set("aap", "noot") + e, err := redis.Int(c.Do("PTTL", "aap")) + ok(t, err) + equals(t, -1, e) + } +} + +func TestDel(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.Set("foo", "bar") + s.HSet("aap", "noot", "mies") + s.Set("one", "two") + s.SetTTL("one", time.Second*1234) + s.Set("three", "four") + r, err := redis.Int(c.Do("DEL", "one", "aap", "nosuch")) + ok(t, err) + equals(t, 2, r) + equals(t, time.Duration(0), s.TTL("one")) + + // Direct also works: + s.Set("foo", "bar") + s.Del("foo") + got, err := s.Get("foo") + equals(t, ErrKeyNotFound, err) + equals(t, "", got) +} + +func TestType(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + // String key + { + s.Set("foo", "bar!") + v, err := redis.String(c.Do("TYPE", "foo")) + ok(t, err) + equals(t, "string", v) + } + + // Hash key + { + s.HSet("aap", "noot", "mies") + v, err := redis.String(c.Do("TYPE", "aap")) + ok(t, err) + equals(t, "hash", v) + } + + // New key + { + v, err := redis.String(c.Do("TYPE", "nosuch")) + ok(t, err) + equals(t, "none", v) + } + + // Wrong usage + { + _, err := redis.Int(c.Do("TYPE")) + assert(t, err != nil, "do TYPE error") + _, err = redis.Int(c.Do("TYPE", "spurious", "arguments")) + assert(t, err != nil, "do TYPE error") + } + + // Direct usage: + { + equals(t, "hash", s.Type("aap")) + equals(t, "", s.Type("nokey")) + } +} + +func TestExists(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + // String key + { + s.Set("foo", "bar!") + v, err := redis.Int(c.Do("EXISTS", "foo")) + ok(t, err) + equals(t, 1, v) + } + + // Hash key + { + s.HSet("aap", "noot", "mies") + v, err := redis.Int(c.Do("EXISTS", "aap")) + ok(t, err) + equals(t, 1, v) + } + + // Multiple keys + { + v, err := redis.Int(c.Do("EXISTS", "foo", "aap")) + ok(t, err) + equals(t, 2, v) + + v, err = redis.Int(c.Do("EXISTS", "foo", "noot", "aap")) + ok(t, err) + equals(t, 2, v) + } + + // New key + { + v, err := redis.Int(c.Do("EXISTS", "nosuch")) + ok(t, err) + equals(t, 0, v) + } + + // Wrong usage + { + _, err := redis.Int(c.Do("EXISTS")) + assert(t, err != nil, "do EXISTS error") + } + + // Direct usage: + { + equals(t, true, s.Exists("aap")) + equals(t, false, s.Exists("nokey")) + } +} + +func TestMove(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + // No problem. + { + s.Set("foo", "bar!") + v, err := redis.Int(c.Do("MOVE", "foo", 1)) + ok(t, err) + equals(t, 1, v) + } + + // Src key doesn't exists. + { + v, err := redis.Int(c.Do("MOVE", "nosuch", 1)) + ok(t, err) + equals(t, 0, v) + } + + // Target key already exists. + { + s.DB(0).Set("two", "orig") + s.DB(1).Set("two", "taken") + v, err := redis.Int(c.Do("MOVE", "two", 1)) + ok(t, err) + equals(t, 0, v) + s.CheckGet(t, "two", "orig") + } + + // TTL is also moved + { + s.DB(0).Set("one", "two") + s.DB(0).SetTTL("one", time.Second*4242) + v, err := redis.Int(c.Do("MOVE", "one", 1)) + ok(t, err) + equals(t, 1, v) + equals(t, s.DB(1).TTL("one"), time.Second*4242) + } + + // Wrong usage + { + _, err := redis.Int(c.Do("MOVE")) + assert(t, err != nil, "do MOVE error") + _, err = redis.Int(c.Do("MOVE", "foo")) + assert(t, err != nil, "do MOVE error") + _, err = redis.Int(c.Do("MOVE", "foo", "noint")) + assert(t, err != nil, "do MOVE error") + _, err = redis.Int(c.Do("MOVE", "foo", 2, "toomany")) + assert(t, err != nil, "do MOVE error") + } +} + +func TestKeys(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.Set("foo", "bar!") + s.Set("foobar", "bar!") + s.Set("barfoo", "bar!") + s.Set("fooooo", "bar!") + + { + v, err := redis.Strings(c.Do("KEYS", "foo")) + ok(t, err) + equals(t, []string{"foo"}, v) + } + + // simple '*' + { + v, err := redis.Strings(c.Do("KEYS", "foo*")) + ok(t, err) + equals(t, []string{"foo", "foobar", "fooooo"}, v) + } + // simple '?' + { + v, err := redis.Strings(c.Do("KEYS", "fo?")) + ok(t, err) + equals(t, []string{"foo"}, v) + } + + // Don't die on never-matching pattern. + { + v, err := redis.Strings(c.Do("KEYS", `f\`)) + ok(t, err) + equals(t, []string{}, v) + } + + // Wrong usage + { + _, err := redis.Int(c.Do("KEYS")) + assert(t, err != nil, "do KEYS error") + _, err = redis.Int(c.Do("KEYS", "foo", "noint")) + assert(t, err != nil, "do KEYS error") + } +} + +func TestRandom(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + // Empty db. + { + v, err := c.Do("RANDOMKEY") + ok(t, err) + equals(t, nil, v) + } + + s.Set("one", "bar!") + s.Set("two", "bar!") + s.Set("three", "bar!") + + // No idea which key will be returned. + { + v, err := redis.String(c.Do("RANDOMKEY")) + ok(t, err) + assert(t, v == "one" || v == "two" || v == "three", "RANDOMKEY looks sane") + } + + // Wrong usage + { + _, err = redis.Int(c.Do("RANDOMKEY", "spurious")) + assert(t, err != nil, "do RANDOMKEY error") + } +} + +func TestRename(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + // Non-existing key + { + _, err := redis.Int(c.Do("RENAME", "nosuch", "to")) + assert(t, err != nil, "do RENAME error") + } + + // Same key + { + _, err := redis.Int(c.Do("RENAME", "from", "from")) + assert(t, err != nil, "do RENAME error") + } + + // Move a string key + { + s.Set("from", "value") + str, err := redis.String(c.Do("RENAME", "from", "to")) + ok(t, err) + equals(t, "OK", str) + equals(t, false, s.Exists("from")) + equals(t, true, s.Exists("to")) + s.CheckGet(t, "to", "value") + } + + // Move a hash key + { + s.HSet("from", "key", "value") + str, err := redis.String(c.Do("RENAME", "from", "to")) + ok(t, err) + equals(t, "OK", str) + equals(t, false, s.Exists("from")) + equals(t, true, s.Exists("to")) + equals(t, "value", s.HGet("to", "key")) + } + + // Move over something which exists + { + s.Set("from", "string value") + s.HSet("to", "key", "value") + s.SetTTL("from", time.Second*999999) + + str, err := redis.String(c.Do("RENAME", "from", "to")) + ok(t, err) + equals(t, "OK", str) + equals(t, false, s.Exists("from")) + equals(t, true, s.Exists("to")) + s.CheckGet(t, "to", "string value") + equals(t, time.Duration(0), s.TTL("from")) + equals(t, time.Second*999999, s.TTL("to")) + } + + // Wrong usage + { + _, err := redis.Int(c.Do("RENAME")) + assert(t, err != nil, "do RENAME error") + _, err = redis.Int(c.Do("RENAME", "too few")) + assert(t, err != nil, "do RENAME error") + _, err = redis.Int(c.Do("RENAME", "some", "spurious", "arguments")) + assert(t, err != nil, "do RENAME error") + } +} + +func TestScan(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + // We cheat with scan. It always returns everything. + + s.Set("key", "value") + + // No problem + { + res, err := redis.Values(c.Do("SCAN", 0)) + ok(t, err) + equals(t, 2, len(res)) + + var c int + var keys []string + _, err = redis.Scan(res, &c, &keys) + ok(t, err) + equals(t, 0, c) + equals(t, []string{"key"}, keys) + } + + // Invalid cursor + { + res, err := redis.Values(c.Do("SCAN", 42)) + ok(t, err) + equals(t, 2, len(res)) + + var c int + var keys []string + _, err = redis.Scan(res, &c, &keys) + ok(t, err) + equals(t, 0, c) + equals(t, []string(nil), keys) + } + + // COUNT (ignored) + { + res, err := redis.Values(c.Do("SCAN", 0, "COUNT", 200)) + ok(t, err) + equals(t, 2, len(res)) + + var c int + var keys []string + _, err = redis.Scan(res, &c, &keys) + ok(t, err) + equals(t, 0, c) + equals(t, []string{"key"}, keys) + } + + // MATCH + { + s.Set("aap", "noot") + s.Set("mies", "wim") + res, err := redis.Values(c.Do("SCAN", 0, "MATCH", "mi*")) + ok(t, err) + equals(t, 2, len(res)) + + var c int + var keys []string + _, err = redis.Scan(res, &c, &keys) + ok(t, err) + equals(t, 0, c) + equals(t, []string{"mies"}, keys) + } + + // Wrong usage + { + _, err := redis.Int(c.Do("SCAN")) + assert(t, err != nil, "do SCAN error") + _, err = redis.Int(c.Do("SCAN", "noint")) + assert(t, err != nil, "do SCAN error") + _, err = redis.Int(c.Do("SCAN", 1, "MATCH")) + assert(t, err != nil, "do SCAN error") + _, err = redis.Int(c.Do("SCAN", 1, "COUNT")) + assert(t, err != nil, "do SCAN error") + _, err = redis.Int(c.Do("SCAN", 1, "COUNT", "noint")) + assert(t, err != nil, "do SCAN error") + } +} + +func TestRenamenx(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + // Non-existing key + { + _, err := redis.Int(c.Do("RENAMENX", "nosuch", "to")) + assert(t, err != nil, "do RENAMENX error") + } + + // Same key + { + _, err := redis.Int(c.Do("RENAMENX", "from", "from")) + assert(t, err != nil, "do RENAMENX error") + } + + // Move a string key + { + s.Set("from", "value") + n, err := redis.Int(c.Do("RENAMENX", "from", "to")) + ok(t, err) + equals(t, 1, n) + equals(t, false, s.Exists("from")) + equals(t, true, s.Exists("to")) + s.CheckGet(t, "to", "value") + } + + // Move over something which exists + { + s.Set("from", "string value") + s.Set("to", "value") + + n, err := redis.Int(c.Do("RENAMENX", "from", "to")) + ok(t, err) + equals(t, 0, n) + equals(t, true, s.Exists("from")) + equals(t, true, s.Exists("to")) + s.CheckGet(t, "from", "string value") + s.CheckGet(t, "to", "value") + } + + // Wrong usage + { + _, err := redis.Int(c.Do("RENAMENX")) + assert(t, err != nil, "do RENAMENX error") + _, err = redis.Int(c.Do("RENAMENX", "too few")) + assert(t, err != nil, "do RENAMENX error") + _, err = redis.Int(c.Do("RENAMENX", "some", "spurious", "arguments")) + assert(t, err != nil, "do RENAMENX error") + } +} diff --git a/vendor/github.com/alicebob/miniredis/cmd_hash.go b/vendor/github.com/alicebob/miniredis/cmd_hash.go new file mode 100644 index 00000000..6db872f0 --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/cmd_hash.go @@ -0,0 +1,571 @@ +// Commands from http://redis.io/commands#hash + +package miniredis + +import ( + "strconv" + "strings" + + "github.com/alicebob/miniredis/server" +) + +// commandsHash handles all hash value operations. +func commandsHash(m *Miniredis) { + m.srv.Register("HDEL", m.cmdHdel) + m.srv.Register("HEXISTS", m.cmdHexists) + m.srv.Register("HGET", m.cmdHget) + m.srv.Register("HGETALL", m.cmdHgetall) + m.srv.Register("HINCRBY", m.cmdHincrby) + m.srv.Register("HINCRBYFLOAT", m.cmdHincrbyfloat) + m.srv.Register("HKEYS", m.cmdHkeys) + m.srv.Register("HLEN", m.cmdHlen) + m.srv.Register("HMGET", m.cmdHmget) + m.srv.Register("HMSET", m.cmdHmset) + m.srv.Register("HSET", m.cmdHset) + m.srv.Register("HSETNX", m.cmdHsetnx) + m.srv.Register("HVALS", m.cmdHvals) + m.srv.Register("HSCAN", m.cmdHscan) +} + +// HSET +func (m *Miniredis) cmdHset(c *server.Peer, cmd string, args []string) { + if len(args) != 3 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key, field, value := args[0], args[1], args[2] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if t, ok := db.keys[key]; ok && t != "hash" { + c.WriteError(msgWrongType) + return + } + + if db.hashSet(key, field, value) { + c.WriteInt(0) + } else { + c.WriteInt(1) + } + }) +} + +// HSETNX +func (m *Miniredis) cmdHsetnx(c *server.Peer, cmd string, args []string) { + if len(args) != 3 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key, field, value := args[0], args[1], args[2] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if t, ok := db.keys[key]; ok && t != "hash" { + c.WriteError(msgWrongType) + return + } + + if _, ok := db.hashKeys[key]; !ok { + db.hashKeys[key] = map[string]string{} + db.keys[key] = "hash" + } + _, ok := db.hashKeys[key][field] + if ok { + c.WriteInt(0) + return + } + db.hashKeys[key][field] = value + db.keyVersion[key]++ + c.WriteInt(1) + }) +} + +// HMSET +func (m *Miniredis) cmdHmset(c *server.Peer, cmd string, args []string) { + if len(args) < 3 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key, args := args[0], args[1:] + if len(args)%2 != 0 { + setDirty(c) + // non-default error message + c.WriteError("ERR wrong number of arguments for HMSET") + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if t, ok := db.keys[key]; ok && t != "hash" { + c.WriteError(msgWrongType) + return + } + + for len(args) > 0 { + field, value := args[0], args[1] + args = args[2:] + db.hashSet(key, field, value) + } + c.WriteOK() + }) +} + +// HGET +func (m *Miniredis) cmdHget(c *server.Peer, cmd string, args []string) { + if len(args) != 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key, field := args[0], args[1] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + t, ok := db.keys[key] + if !ok { + c.WriteNull() + return + } + if t != "hash" { + c.WriteError(msgWrongType) + return + } + value, ok := db.hashKeys[key][field] + if !ok { + c.WriteNull() + return + } + c.WriteBulk(value) + }) +} + +// HDEL +func (m *Miniredis) cmdHdel(c *server.Peer, cmd string, args []string) { + if len(args) < 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key, fields := args[0], args[1:] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + t, ok := db.keys[key] + if !ok { + // No key is zero deleted + c.WriteInt(0) + return + } + if t != "hash" { + c.WriteError(msgWrongType) + return + } + + deleted := 0 + for _, f := range fields { + _, ok := db.hashKeys[key][f] + if !ok { + continue + } + delete(db.hashKeys[key], f) + deleted++ + } + c.WriteInt(deleted) + + // Nothing left. Remove the whole key. + if len(db.hashKeys[key]) == 0 { + db.del(key, true) + } + }) +} + +// HEXISTS +func (m *Miniredis) cmdHexists(c *server.Peer, cmd string, args []string) { + if len(args) != 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key, field := args[0], args[1] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + t, ok := db.keys[key] + if !ok { + c.WriteInt(0) + return + } + if t != "hash" { + c.WriteError(msgWrongType) + return + } + + if _, ok := db.hashKeys[key][field]; !ok { + c.WriteInt(0) + return + } + c.WriteInt(1) + }) +} + +// HGETALL +func (m *Miniredis) cmdHgetall(c *server.Peer, cmd string, args []string) { + if len(args) != 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + t, ok := db.keys[key] + if !ok { + c.WriteLen(0) + return + } + if t != "hash" { + c.WriteError(msgWrongType) + return + } + + c.WriteLen(len(db.hashKeys[key]) * 2) + for _, k := range db.hashFields(key) { + c.WriteBulk(k) + c.WriteBulk(db.hashGet(key, k)) + } + }) +} + +// HKEYS +func (m *Miniredis) cmdHkeys(c *server.Peer, cmd string, args []string) { + if len(args) != 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(key) { + c.WriteLen(0) + return + } + if db.t(key) != "hash" { + c.WriteError(msgWrongType) + return + } + + fields := db.hashFields(key) + c.WriteLen(len(fields)) + for _, f := range fields { + c.WriteBulk(f) + } + }) +} + +// HVALS +func (m *Miniredis) cmdHvals(c *server.Peer, cmd string, args []string) { + if len(args) != 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + t, ok := db.keys[key] + if !ok { + c.WriteLen(0) + return + } + if t != "hash" { + c.WriteError(msgWrongType) + return + } + + c.WriteLen(len(db.hashKeys[key])) + for _, v := range db.hashKeys[key] { + c.WriteBulk(v) + } + }) +} + +// HLEN +func (m *Miniredis) cmdHlen(c *server.Peer, cmd string, args []string) { + if len(args) != 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + t, ok := db.keys[key] + if !ok { + c.WriteInt(0) + return + } + if t != "hash" { + c.WriteError(msgWrongType) + return + } + + c.WriteInt(len(db.hashKeys[key])) + }) +} + +// HMGET +func (m *Miniredis) cmdHmget(c *server.Peer, cmd string, args []string) { + if len(args) < 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if t, ok := db.keys[key]; ok && t != "hash" { + c.WriteError(msgWrongType) + return + } + + f, ok := db.hashKeys[key] + if !ok { + f = map[string]string{} + } + + c.WriteLen(len(args) - 1) + for _, k := range args[1:] { + v, ok := f[k] + if !ok { + c.WriteNull() + continue + } + c.WriteBulk(v) + } + }) +} + +// HINCRBY +func (m *Miniredis) cmdHincrby(c *server.Peer, cmd string, args []string) { + if len(args) != 3 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key, field, deltas := args[0], args[1], args[2] + + delta, err := strconv.Atoi(deltas) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if t, ok := db.keys[key]; ok && t != "hash" { + c.WriteError(msgWrongType) + return + } + + v, err := db.hashIncr(key, field, delta) + if err != nil { + c.WriteError(err.Error()) + return + } + c.WriteInt(v) + }) +} + +// HINCRBYFLOAT +func (m *Miniredis) cmdHincrbyfloat(c *server.Peer, cmd string, args []string) { + if len(args) != 3 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key, field, deltas := args[0], args[1], args[2] + + delta, err := strconv.ParseFloat(deltas, 64) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidFloat) + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if t, ok := db.keys[key]; ok && t != "hash" { + c.WriteError(msgWrongType) + return + } + + v, err := db.hashIncrfloat(key, field, delta) + if err != nil { + c.WriteError(err.Error()) + return + } + c.WriteBulk(formatFloat(v)) + }) +} + +// HSCAN +func (m *Miniredis) cmdHscan(c *server.Peer, cmd string, args []string) { + if len(args) < 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + cursor, err := strconv.Atoi(args[1]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidCursor) + return + } + args = args[2:] + + // MATCH and COUNT options + var withMatch bool + var match string + for len(args) > 0 { + if strings.ToLower(args[0]) == "count" { + // we do nothing with count + if len(args) < 2 { + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + _, err := strconv.Atoi(args[1]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + args = args[2:] + continue + } + if strings.ToLower(args[0]) == "match" { + if len(args) < 2 { + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + withMatch = true + match, args = args[1], args[2:] + continue + } + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + // return _all_ (matched) keys every time + + if cursor != 0 { + // Invalid cursor. + c.WriteLen(2) + c.WriteBulk("0") // no next cursor + c.WriteLen(0) // no elements + return + } + if db.exists(key) && db.t(key) != "hash" { + c.WriteError(ErrWrongType.Error()) + return + } + + members := db.hashFields(key) + if withMatch { + members = matchKeys(members, match) + } + + c.WriteLen(2) + c.WriteBulk("0") // no next cursor + // HSCAN gives key, values. + c.WriteLen(len(members) * 2) + for _, k := range members { + c.WriteBulk(k) + c.WriteBulk(db.hashGet(key, k)) + } + }) +} diff --git a/vendor/github.com/alicebob/miniredis/cmd_hash_test.go b/vendor/github.com/alicebob/miniredis/cmd_hash_test.go new file mode 100644 index 00000000..bdbd0187 --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/cmd_hash_test.go @@ -0,0 +1,582 @@ +package miniredis + +import ( + "sort" + "testing" + "time" + + "github.com/garyburd/redigo/redis" +) + +// Test Hash. +func TestHash(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + { + b, err := redis.Int(c.Do("HSET", "aap", "noot", "mies")) + ok(t, err) + equals(t, 1, b) // New field. + } + + { + v, err := redis.String(c.Do("HGET", "aap", "noot")) + ok(t, err) + equals(t, "mies", v) + equals(t, "mies", s.HGet("aap", "noot")) + } + + { + b, err := redis.Int(c.Do("HSET", "aap", "noot", "mies")) + ok(t, err) + equals(t, 0, b) // Existing field. + } + + // Wrong type of key + { + _, err := redis.String(c.Do("SET", "foo", "bar")) + ok(t, err) + _, err = redis.Int(c.Do("HSET", "foo", "noot", "mies")) + assert(t, err != nil, "HSET error") + } + + // hash exists, key doesn't. + { + b, err := c.Do("HGET", "aap", "nosuch") + ok(t, err) + equals(t, nil, b) + } + + // hash doesn't exists. + { + b, err := c.Do("HGET", "nosuch", "nosuch") + ok(t, err) + equals(t, nil, b) + equals(t, "", s.HGet("nosuch", "nosuch")) + } + + // HGET on wrong type + { + _, err := redis.Int(c.Do("HGET", "aap")) + assert(t, err != nil, "HGET error") + } + + // Direct HSet() + { + s.HSet("wim", "zus", "jet") + v, err := redis.String(c.Do("HGET", "wim", "zus")) + ok(t, err) + equals(t, "jet", v) + } +} + +func TestHashSetNX(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + // New Hash + v, err := redis.Int(c.Do("HSETNX", "wim", "zus", "jet")) + ok(t, err) + equals(t, 1, v) + + v, err = redis.Int(c.Do("HSETNX", "wim", "zus", "jet")) + ok(t, err) + equals(t, 0, v) + + // Just a new key + v, err = redis.Int(c.Do("HSETNX", "wim", "aap", "noot")) + ok(t, err) + equals(t, 1, v) + + // Wrong key type + s.Set("foo", "bar") + _, err = redis.Int(c.Do("HSETNX", "foo", "nosuch", "nosuch")) + assert(t, err != nil, "no HSETNX error") +} + +func TestHashMSet(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + // New Hash + { + v, err := redis.String(c.Do("HMSET", "hash", "wim", "zus", "jet", "vuur")) + ok(t, err) + equals(t, "OK", v) + + equals(t, "zus", s.HGet("hash", "wim")) + equals(t, "vuur", s.HGet("hash", "jet")) + } + + // Doesn't touch ttl. + { + s.SetTTL("hash", time.Second*999) + v, err := redis.String(c.Do("HMSET", "hash", "gijs", "lam")) + ok(t, err) + equals(t, "OK", v) + equals(t, time.Second*999, s.TTL("hash")) + } + + { + // Wrong key type + s.Set("str", "value") + _, err = redis.Int(c.Do("HMSET", "str", "key", "value")) + assert(t, err != nil, "no HSETerror") + // Usage error + _, err = redis.Int(c.Do("HMSET", "str")) + assert(t, err != nil, "no HSETerror") + _, err = redis.Int(c.Do("HMSET", "str", "odd")) + assert(t, err != nil, "no HSETerror") + _, err = redis.Int(c.Do("HMSET", "str", "key", "value", "odd")) + assert(t, err != nil, "no HSETerror") + } +} + +func TestHashDel(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.HSet("wim", "zus", "jet") + s.HSet("wim", "teun", "vuur") + s.HSet("wim", "gijs", "lam") + s.HSet("wim", "kees", "bok") + v, err := redis.Int(c.Do("HDEL", "wim", "zus", "gijs")) + ok(t, err) + equals(t, 2, v) + + v, err = redis.Int(c.Do("HDEL", "wim", "nosuch")) + ok(t, err) + equals(t, 0, v) + + // Deleting all makes the key disappear + v, err = redis.Int(c.Do("HDEL", "wim", "teun", "kees")) + ok(t, err) + equals(t, 2, v) + assert(t, !s.Exists("wim"), "no more wim key") + + // Key doesn't exists. + v, err = redis.Int(c.Do("HDEL", "nosuch", "nosuch")) + ok(t, err) + equals(t, 0, v) + + // Wrong key type + s.Set("foo", "bar") + _, err = redis.Int(c.Do("HDEL", "foo", "nosuch")) + assert(t, err != nil, "no HDEL error") + + // Direct HDel() + s.HSet("aap", "noot", "mies") + s.HDel("aap", "noot") + equals(t, "", s.HGet("aap", "noot")) +} + +func TestHashExists(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.HSet("wim", "zus", "jet") + s.HSet("wim", "teun", "vuur") + v, err := redis.Int(c.Do("HEXISTS", "wim", "zus")) + ok(t, err) + equals(t, 1, v) + + v, err = redis.Int(c.Do("HEXISTS", "wim", "nosuch")) + ok(t, err) + equals(t, 0, v) + + v, err = redis.Int(c.Do("HEXISTS", "nosuch", "nosuch")) + ok(t, err) + equals(t, 0, v) + + // Wrong key type + s.Set("foo", "bar") + _, err = redis.Int(c.Do("HEXISTS", "foo", "nosuch")) + assert(t, err != nil, "no HDEL error") +} + +func TestHashGetall(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.HSet("wim", "zus", "jet") + s.HSet("wim", "teun", "vuur") + s.HSet("wim", "gijs", "lam") + s.HSet("wim", "kees", "bok") + v, err := redis.Strings(c.Do("HGETALL", "wim")) + ok(t, err) + equals(t, 8, len(v)) + d := map[string]string{} + for len(v) > 0 { + d[v[0]] = v[1] + v = v[2:] + } + equals(t, map[string]string{ + "zus": "jet", + "teun": "vuur", + "gijs": "lam", + "kees": "bok", + }, d) + + v, err = redis.Strings(c.Do("HGETALL", "nosuch")) + ok(t, err) + equals(t, 0, len(v)) + + // Wrong key type + s.Set("foo", "bar") + _, err = redis.Int(c.Do("HGETALL", "foo")) + assert(t, err != nil, "no HGETALL error") +} + +func TestHashKeys(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.HSet("wim", "zus", "jet") + s.HSet("wim", "teun", "vuur") + s.HSet("wim", "gijs", "lam") + s.HSet("wim", "kees", "bok") + { + v, err := redis.Strings(c.Do("HKEYS", "wim")) + ok(t, err) + equals(t, 4, len(v)) + sort.Strings(v) + equals(t, []string{ + "gijs", + "kees", + "teun", + "zus", + }, v) + } + + // Direct command + { + direct, err := s.HKeys("wim") + ok(t, err) + sort.Strings(direct) + equals(t, []string{ + "gijs", + "kees", + "teun", + "zus", + }, direct) + _, err = s.HKeys("nosuch") + equals(t, err, ErrKeyNotFound) + } + + v, err := redis.Strings(c.Do("HKEYS", "nosuch")) + ok(t, err) + equals(t, 0, len(v)) + + // Wrong key type + s.Set("foo", "bar") + _, err = redis.Int(c.Do("HKEYS", "foo")) + assert(t, err != nil, "no HKEYS error") +} + +func TestHashValues(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.HSet("wim", "zus", "jet") + s.HSet("wim", "teun", "vuur") + s.HSet("wim", "gijs", "lam") + s.HSet("wim", "kees", "bok") + v, err := redis.Strings(c.Do("HVALS", "wim")) + ok(t, err) + equals(t, 4, len(v)) + sort.Strings(v) + equals(t, []string{ + "bok", + "jet", + "lam", + "vuur", + }, v) + + v, err = redis.Strings(c.Do("HVALS", "nosuch")) + ok(t, err) + equals(t, 0, len(v)) + + // Wrong key type + s.Set("foo", "bar") + _, err = redis.Int(c.Do("HVALS", "foo")) + assert(t, err != nil, "no HVALS error") +} + +func TestHashLen(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.HSet("wim", "zus", "jet") + s.HSet("wim", "teun", "vuur") + s.HSet("wim", "gijs", "lam") + s.HSet("wim", "kees", "bok") + v, err := redis.Int(c.Do("HLEN", "wim")) + ok(t, err) + equals(t, 4, v) + + v, err = redis.Int(c.Do("HLEN", "nosuch")) + ok(t, err) + equals(t, 0, v) + + // Wrong key type + s.Set("foo", "bar") + _, err = redis.Int(c.Do("HLEN", "foo")) + assert(t, err != nil, "no HLEN error") +} + +func TestHashMget(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.HSet("wim", "zus", "jet") + s.HSet("wim", "teun", "vuur") + s.HSet("wim", "gijs", "lam") + s.HSet("wim", "kees", "bok") + v, err := redis.Values(c.Do("HMGET", "wim", "zus", "nosuch", "kees")) + ok(t, err) + equals(t, 3, len(v)) + equals(t, "jet", string(v[0].([]byte))) + equals(t, nil, v[1]) + equals(t, "bok", string(v[2].([]byte))) + + v, err = redis.Values(c.Do("HMGET", "nosuch", "zus", "kees")) + ok(t, err) + equals(t, 2, len(v)) + equals(t, nil, v[0]) + equals(t, nil, v[1]) + + // Wrong key type + s.Set("foo", "bar") + _, err = redis.Int(c.Do("HMGET", "foo", "bar")) + assert(t, err != nil, "no HMGET error") +} + +func TestHashIncrby(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + // New key + { + v, err := redis.Int(c.Do("HINCRBY", "hash", "field", 1)) + ok(t, err) + equals(t, 1, v) + } + + // Existing key + { + v, err := redis.Int(c.Do("HINCRBY", "hash", "field", 100)) + ok(t, err) + equals(t, 101, v) + } + + // Minus works. + { + v, err := redis.Int(c.Do("HINCRBY", "hash", "field", -12)) + ok(t, err) + equals(t, 101-12, v) + } + + // Direct usage + s.HIncr("hash", "field", -3) + equals(t, "86", s.HGet("hash", "field")) + + // Error cases. + { + // Wrong key type + s.Set("str", "cake") + _, err = redis.Values(c.Do("HINCRBY", "str", "case", 4)) + assert(t, err != nil, "no HINCRBY error") + + _, err = redis.Values(c.Do("HINCRBY", "str", "case", "foo")) + assert(t, err != nil, "no HINCRBY error") + + _, err = redis.Values(c.Do("HINCRBY", "str")) + assert(t, err != nil, "no HINCRBY error") + } +} + +func TestHashIncrbyfloat(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + // Existing key + { + s.HSet("hash", "field", "12") + v, err := redis.Float64(c.Do("HINCRBYFLOAT", "hash", "field", "400.12")) + ok(t, err) + equals(t, 412.12, v) + equals(t, "412.12", s.HGet("hash", "field")) + } + + // Existing key, not a number + { + s.HSet("hash", "field", "noint") + _, err := redis.Float64(c.Do("HINCRBYFLOAT", "hash", "field", "400")) + assert(t, err != nil, "do HINCRBYFLOAT error") + } + + // New key + { + v, err := redis.Float64(c.Do("HINCRBYFLOAT", "hash", "newfield", "40.33")) + ok(t, err) + equals(t, 40.33, v) + equals(t, "40.33", s.HGet("hash", "newfield")) + } + + // Direct usage + { + s.HSet("hash", "field", "500.1") + f, err := s.HIncrfloat("hash", "field", 12) + ok(t, err) + equals(t, 512.1, f) + equals(t, "512.1", s.HGet("hash", "field")) + } + + // Wrong type of existing key + { + s.Set("wrong", "type") + _, err := redis.Int(c.Do("HINCRBYFLOAT", "wrong", "type", "400")) + assert(t, err != nil, "do HINCRBYFLOAT error") + } + + // Wrong usage + { + _, err := redis.Int(c.Do("HINCRBYFLOAT")) + assert(t, err != nil, "do HINCRBYFLOAT error") + _, err = redis.Int(c.Do("HINCRBYFLOAT", "wrong")) + assert(t, err != nil, "do HINCRBYFLOAT error") + _, err = redis.Int(c.Do("HINCRBYFLOAT", "wrong", "value")) + assert(t, err != nil, "do HINCRBYFLOAT error") + _, err = redis.Int(c.Do("HINCRBYFLOAT", "wrong", "value", "noint")) + assert(t, err != nil, "do HINCRBYFLOAT error") + _, err = redis.Int(c.Do("HINCRBYFLOAT", "foo", "bar", 12, "tomanye")) + assert(t, err != nil, "do HINCRBYFLOAT error") + } +} + +func TestHscan(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + // We cheat with hscan. It always returns everything. + + s.HSet("h", "field1", "value1") + s.HSet("h", "field2", "value2") + + // No problem + { + res, err := redis.Values(c.Do("HSCAN", "h", 0)) + ok(t, err) + equals(t, 2, len(res)) + + var c int + var keys []string + _, err = redis.Scan(res, &c, &keys) + ok(t, err) + equals(t, 0, c) + equals(t, []string{"field1", "value1", "field2", "value2"}, keys) + } + + // Invalid cursor + { + res, err := redis.Values(c.Do("HSCAN", "h", 42)) + ok(t, err) + equals(t, 2, len(res)) + + var c int + var keys []string + _, err = redis.Scan(res, &c, &keys) + ok(t, err) + equals(t, 0, c) + equals(t, []string(nil), keys) + } + + // COUNT (ignored) + { + res, err := redis.Values(c.Do("HSCAN", "h", 0, "COUNT", 200)) + ok(t, err) + equals(t, 2, len(res)) + + var c int + var keys []string + _, err = redis.Scan(res, &c, &keys) + ok(t, err) + equals(t, 0, c) + equals(t, []string{"field1", "value1", "field2", "value2"}, keys) + } + + // MATCH + { + s.HSet("h", "aap", "a") + s.HSet("h", "noot", "b") + s.HSet("h", "mies", "m") + res, err := redis.Values(c.Do("HSCAN", "h", 0, "MATCH", "mi*")) + ok(t, err) + equals(t, 2, len(res)) + + var c int + var keys []string + _, err = redis.Scan(res, &c, &keys) + ok(t, err) + equals(t, 0, c) + equals(t, []string{"mies", "m"}, keys) + } + + // Wrong usage + { + _, err := redis.Int(c.Do("HSCAN")) + assert(t, err != nil, "do HSCAN error") + _, err = redis.Int(c.Do("HSCAN", "set")) + assert(t, err != nil, "do HSCAN error") + _, err = redis.Int(c.Do("HSCAN", "set", "noint")) + assert(t, err != nil, "do HSCAN error") + _, err = redis.Int(c.Do("HSCAN", "set", 1, "MATCH")) + assert(t, err != nil, "do HSCAN error") + _, err = redis.Int(c.Do("HSCAN", "set", 1, "COUNT")) + assert(t, err != nil, "do HSCAN error") + _, err = redis.Int(c.Do("HSCAN", "set", 1, "COUNT", "noint")) + assert(t, err != nil, "do HSCAN error") + s.Set("str", "value") + _, err = redis.Int(c.Do("HSCAN", "str", 1)) + assert(t, err != nil, "do HSCAN error") + } +} diff --git a/vendor/github.com/alicebob/miniredis/cmd_list.go b/vendor/github.com/alicebob/miniredis/cmd_list.go new file mode 100644 index 00000000..1c22b821 --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/cmd_list.go @@ -0,0 +1,687 @@ +// Commands from http://redis.io/commands#list + +package miniredis + +import ( + "strconv" + "strings" + "time" + + "github.com/alicebob/miniredis/server" +) + +type leftright int + +const ( + left leftright = iota + right +) + +// commandsList handles list commands (mostly L*) +func commandsList(m *Miniredis) { + m.srv.Register("BLPOP", m.cmdBlpop) + m.srv.Register("BRPOP", m.cmdBrpop) + m.srv.Register("BRPOPLPUSH", m.cmdBrpoplpush) + m.srv.Register("LINDEX", m.cmdLindex) + m.srv.Register("LINSERT", m.cmdLinsert) + m.srv.Register("LLEN", m.cmdLlen) + m.srv.Register("LPOP", m.cmdLpop) + m.srv.Register("LPUSH", m.cmdLpush) + m.srv.Register("LPUSHX", m.cmdLpushx) + m.srv.Register("LRANGE", m.cmdLrange) + m.srv.Register("LREM", m.cmdLrem) + m.srv.Register("LSET", m.cmdLset) + m.srv.Register("LTRIM", m.cmdLtrim) + m.srv.Register("RPOP", m.cmdRpop) + m.srv.Register("RPOPLPUSH", m.cmdRpoplpush) + m.srv.Register("RPUSH", m.cmdRpush) + m.srv.Register("RPUSHX", m.cmdRpushx) +} + +// BLPOP +func (m *Miniredis) cmdBlpop(c *server.Peer, cmd string, args []string) { + m.cmdBXpop(c, cmd, args, left) +} + +// BRPOP +func (m *Miniredis) cmdBrpop(c *server.Peer, cmd string, args []string) { + m.cmdBXpop(c, cmd, args, right) +} + +func (m *Miniredis) cmdBXpop(c *server.Peer, cmd string, args []string, lr leftright) { + if len(args) < 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + timeoutS := args[len(args)-1] + keys := args[:len(args)-1] + + timeout, err := strconv.Atoi(timeoutS) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidTimeout) + return + } + if timeout < 0 { + setDirty(c) + c.WriteError(msgNegTimeout) + return + } + + blocking( + m, + c, + time.Duration(timeout)*time.Second, + func(c *server.Peer, ctx *connCtx) bool { + db := m.db(ctx.selectedDB) + for _, key := range keys { + if !db.exists(key) { + continue + } + if db.t(key) != "list" { + c.WriteError(msgWrongType) + return true + } + + if len(db.listKeys[key]) == 0 { + continue + } + c.WriteLen(2) + c.WriteBulk(key) + var v string + switch lr { + case left: + v = db.listLpop(key) + case right: + v = db.listPop(key) + } + c.WriteBulk(v) + return true + } + return false + }, + func(c *server.Peer) { + // timeout + c.WriteNull() + }, + ) + return +} + +// LINDEX +func (m *Miniredis) cmdLindex(c *server.Peer, cmd string, args []string) { + if len(args) != 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key, offsets := args[0], args[1] + + offset, err := strconv.Atoi(offsets) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + t, ok := db.keys[key] + if !ok { + // No such key + c.WriteNull() + return + } + if t != "list" { + c.WriteError(msgWrongType) + return + } + + l := db.listKeys[key] + if offset < 0 { + offset = len(l) + offset + } + if offset < 0 || offset > len(l)-1 { + c.WriteNull() + return + } + c.WriteBulk(l[offset]) + }) +} + +// LINSERT +func (m *Miniredis) cmdLinsert(c *server.Peer, cmd string, args []string) { + if len(args) != 4 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + where := 0 + switch strings.ToLower(args[1]) { + case "before": + where = -1 + case "after": + where = +1 + default: + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + pivot := args[2] + value := args[3] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + t, ok := db.keys[key] + if !ok { + // No such key + c.WriteInt(0) + return + } + if t != "list" { + c.WriteError(msgWrongType) + return + } + + l := db.listKeys[key] + for i, el := range l { + if el != pivot { + continue + } + + if where < 0 { + l = append(l[:i], append(listKey{value}, l[i:]...)...) + } else { + if i == len(l)-1 { + l = append(l, value) + } else { + l = append(l[:i+1], append(listKey{value}, l[i+1:]...)...) + } + } + db.listKeys[key] = l + db.keyVersion[key]++ + c.WriteInt(len(l)) + return + } + c.WriteInt(-1) + }) +} + +// LLEN +func (m *Miniredis) cmdLlen(c *server.Peer, cmd string, args []string) { + if len(args) != 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + t, ok := db.keys[key] + if !ok { + // No such key. That's zero length. + c.WriteInt(0) + return + } + if t != "list" { + c.WriteError(msgWrongType) + return + } + + c.WriteInt(len(db.listKeys[key])) + }) +} + +// LPOP +func (m *Miniredis) cmdLpop(c *server.Peer, cmd string, args []string) { + m.cmdXpop(c, cmd, args, left) +} + +// RPOP +func (m *Miniredis) cmdRpop(c *server.Peer, cmd string, args []string) { + m.cmdXpop(c, cmd, args, right) +} + +func (m *Miniredis) cmdXpop(c *server.Peer, cmd string, args []string, lr leftright) { + if len(args) != 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(key) { + // non-existing key is fine + c.WriteNull() + return + } + if db.t(key) != "list" { + c.WriteError(msgWrongType) + return + } + + var elem string + switch lr { + case left: + elem = db.listLpop(key) + case right: + elem = db.listPop(key) + } + c.WriteBulk(elem) + }) +} + +// LPUSH +func (m *Miniredis) cmdLpush(c *server.Peer, cmd string, args []string) { + m.cmdXpush(c, cmd, args, left) +} + +// RPUSH +func (m *Miniredis) cmdRpush(c *server.Peer, cmd string, args []string) { + m.cmdXpush(c, cmd, args, right) +} + +func (m *Miniredis) cmdXpush(c *server.Peer, cmd string, args []string, lr leftright) { + if len(args) < 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key, args := args[0], args[1:] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if db.exists(key) && db.t(key) != "list" { + c.WriteError(msgWrongType) + return + } + + var newLen int + for _, value := range args { + switch lr { + case left: + newLen = db.listLpush(key, value) + case right: + newLen = db.listPush(key, value) + } + } + c.WriteInt(newLen) + }) +} + +// LPUSHX +func (m *Miniredis) cmdLpushx(c *server.Peer, cmd string, args []string) { + m.cmdXpushx(c, cmd, args, left) +} + +// RPUSHX +func (m *Miniredis) cmdRpushx(c *server.Peer, cmd string, args []string) { + m.cmdXpushx(c, cmd, args, right) +} + +func (m *Miniredis) cmdXpushx(c *server.Peer, cmd string, args []string, lr leftright) { + if len(args) != 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key, value := args[0], args[1] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(key) { + c.WriteInt(0) + return + } + if db.t(key) != "list" { + c.WriteError(msgWrongType) + return + } + + var newLen int + switch lr { + case left: + newLen = db.listLpush(key, value) + case right: + newLen = db.listPush(key, value) + } + c.WriteInt(newLen) + }) +} + +// LRANGE +func (m *Miniredis) cmdLrange(c *server.Peer, cmd string, args []string) { + if len(args) != 3 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + start, err := strconv.Atoi(args[1]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + end, err := strconv.Atoi(args[2]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if t, ok := db.keys[key]; ok && t != "list" { + c.WriteError(msgWrongType) + return + } + + l := db.listKeys[key] + if len(l) == 0 { + c.WriteLen(0) + return + } + + rs, re := redisRange(len(l), start, end, false) + c.WriteLen(re - rs) + for _, el := range l[rs:re] { + c.WriteBulk(el) + } + }) +} + +// LREM +func (m *Miniredis) cmdLrem(c *server.Peer, cmd string, args []string) { + if len(args) != 3 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + count, err := strconv.Atoi(args[1]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + value := args[2] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(key) { + c.WriteInt(0) + return + } + if db.t(key) != "list" { + c.WriteError(msgWrongType) + return + } + + l := db.listKeys[key] + if count < 0 { + reverseSlice(l) + } + deleted := 0 + newL := []string{} + toDelete := len(l) + if count < 0 { + toDelete = -count + } + if count > 0 { + toDelete = count + } + for _, el := range l { + if el == value { + if toDelete > 0 { + deleted++ + toDelete-- + continue + } + } + newL = append(newL, el) + } + if count < 0 { + reverseSlice(newL) + } + if len(newL) == 0 { + db.del(key, true) + } else { + db.listKeys[key] = newL + db.keyVersion[key]++ + } + + c.WriteInt(deleted) + }) +} + +// LSET +func (m *Miniredis) cmdLset(c *server.Peer, cmd string, args []string) { + if len(args) != 3 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + index, err := strconv.Atoi(args[1]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + value := args[2] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(key) { + c.WriteError(msgKeyNotFound) + return + } + if db.t(key) != "list" { + c.WriteError(msgWrongType) + return + } + + l := db.listKeys[key] + if index < 0 { + index = len(l) + index + } + if index < 0 || index > len(l)-1 { + c.WriteError(msgOutOfRange) + return + } + l[index] = value + db.keyVersion[key]++ + + c.WriteOK() + }) +} + +// LTRIM +func (m *Miniredis) cmdLtrim(c *server.Peer, cmd string, args []string) { + if len(args) != 3 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + start, err := strconv.Atoi(args[1]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + end, err := strconv.Atoi(args[2]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + t, ok := db.keys[key] + if !ok { + c.WriteOK() + return + } + if t != "list" { + c.WriteError(msgWrongType) + return + } + + l := db.listKeys[key] + rs, re := redisRange(len(l), start, end, false) + l = l[rs:re] + if len(l) == 0 { + db.del(key, true) + } else { + db.listKeys[key] = l + db.keyVersion[key]++ + } + c.WriteOK() + }) +} + +// RPOPLPUSH +func (m *Miniredis) cmdRpoplpush(c *server.Peer, cmd string, args []string) { + if len(args) != 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + src, dst := args[0], args[1] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(src) { + c.WriteNull() + return + } + if db.t(src) != "list" || (db.exists(dst) && db.t(dst) != "list") { + c.WriteError(msgWrongType) + return + } + elem := db.listPop(src) + db.listLpush(dst, elem) + c.WriteBulk(elem) + }) +} + +// BRPOPLPUSH +func (m *Miniredis) cmdBrpoplpush(c *server.Peer, cmd string, args []string) { + if len(args) != 3 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + src := args[0] + dst := args[1] + timeout, err := strconv.Atoi(args[2]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidTimeout) + return + } + if timeout < 0 { + setDirty(c) + c.WriteError(msgNegTimeout) + return + } + + blocking( + m, + c, + time.Duration(timeout)*time.Second, + func(c *server.Peer, ctx *connCtx) bool { + db := m.db(ctx.selectedDB) + + if !db.exists(src) { + return false + } + if db.t(src) != "list" || (db.exists(dst) && db.t(dst) != "list") { + c.WriteError(msgWrongType) + return true + } + if len(db.listKeys[src]) == 0 { + return false + } + elem := db.listPop(src) + db.listLpush(dst, elem) + c.WriteBulk(elem) + return true + }, + func(c *server.Peer) { + // timeout + c.WriteNull() + }, + ) + return +} diff --git a/vendor/github.com/alicebob/miniredis/cmd_list_test.go b/vendor/github.com/alicebob/miniredis/cmd_list_test.go new file mode 100644 index 00000000..cec28fab --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/cmd_list_test.go @@ -0,0 +1,1036 @@ +package miniredis + +import ( + "testing" + "time" + + "github.com/garyburd/redigo/redis" +) + +func setup(t *testing.T) (*Miniredis, redis.Conn, func()) { + s, err := Run() + ok(t, err) + c1, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + return s, c1, func() { s.Close() } +} +func setup2(t *testing.T) (*Miniredis, redis.Conn, redis.Conn, func()) { + s, err := Run() + ok(t, err) + c1, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + c2, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + return s, c1, c2, func() { s.Close() } +} + +func TestLpush(t *testing.T) { + s, c, done := setup(t) + defer done() + + { + b, err := redis.Int(c.Do("LPUSH", "l", "aap", "noot", "mies")) + ok(t, err) + equals(t, 3, b) // New length. + + r, err := redis.Strings(c.Do("LRANGE", "l", "0", "0")) + ok(t, err) + equals(t, []string{"mies"}, r) + + r, err = redis.Strings(c.Do("LRANGE", "l", "-1", "-1")) + ok(t, err) + equals(t, []string{"aap"}, r) + } + + // Push more. + { + b, err := redis.Int(c.Do("LPUSH", "l", "aap2", "noot2", "mies2")) + ok(t, err) + equals(t, 6, b) // New length. + + r, err := redis.Strings(c.Do("LRANGE", "l", "0", "0")) + ok(t, err) + equals(t, []string{"mies2"}, r) + + r, err = redis.Strings(c.Do("LRANGE", "l", "-1", "-1")) + ok(t, err) + equals(t, []string{"aap"}, r) + } + + // Direct usage + { + l, err := s.Lpush("l2", "a") + ok(t, err) + equals(t, 1, l) + l, err = s.Lpush("l2", "b") + ok(t, err) + equals(t, 2, l) + list, err := s.List("l2") + ok(t, err) + equals(t, []string{"b", "a"}, list) + + el, err := s.Lpop("l2") + ok(t, err) + equals(t, "b", el) + el, err = s.Lpop("l2") + ok(t, err) + equals(t, "a", el) + // Key is removed on pop-empty. + equals(t, false, s.Exists("l2")) + } + + // Various errors + { + _, err := redis.Int(c.Do("LPUSH")) + assert(t, err != nil, "LPUSH error") + _, err = redis.Int(c.Do("LPUSH", "l")) + assert(t, err != nil, "LPUSH error") + _, err = redis.String(c.Do("SET", "str", "value")) + ok(t, err) + _, err = redis.Int(c.Do("LPUSH", "str", "noot", "mies")) + assert(t, err != nil, "LPUSH error") + } + +} + +func TestLpushx(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + { + b, err := redis.Int(c.Do("LPUSHX", "l", "aap")) + ok(t, err) + equals(t, 0, b) + equals(t, false, s.Exists("l")) + + // Create the list with a normal LPUSH + b, err = redis.Int(c.Do("LPUSH", "l", "noot")) + ok(t, err) + equals(t, 1, b) + equals(t, true, s.Exists("l")) + + b, err = redis.Int(c.Do("LPUSHX", "l", "mies")) + ok(t, err) + equals(t, 2, b) + equals(t, true, s.Exists("l")) + } + + // Errors + { + _, err = redis.Int(c.Do("LPUSHX")) + assert(t, err != nil, "LPUSHX error") + _, err = redis.Int(c.Do("LPUSHX", "l")) + assert(t, err != nil, "LPUSHX error") + _, err = redis.Int(c.Do("LPUSHX", "l", "too", "many")) + assert(t, err != nil, "LPUSHX error") + _, err := redis.String(c.Do("SET", "str", "value")) + ok(t, err) + _, err = redis.Int(c.Do("LPUSHX", "str", "mies")) + assert(t, err != nil, "LPUSHX error") + } + +} + +func TestLpop(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + b, err := redis.Int(c.Do("LPUSH", "l", "aap", "noot", "mies")) + ok(t, err) + equals(t, 3, b) // New length. + + // Simple pops. + { + el, err := redis.String(c.Do("LPOP", "l")) + ok(t, err) + equals(t, "mies", el) + + el, err = redis.String(c.Do("LPOP", "l")) + ok(t, err) + equals(t, "noot", el) + + el, err = redis.String(c.Do("LPOP", "l")) + ok(t, err) + equals(t, "aap", el) + + // Last element has been popped. Key is gone. + i, err := redis.Int(c.Do("EXISTS", "l")) + ok(t, err) + equals(t, 0, i) + + // Can pop non-existing keys just fine. + v, err := c.Do("LPOP", "l") + ok(t, err) + equals(t, nil, v) + } +} + +func TestRPushPop(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + { + b, err := redis.Int(c.Do("RPUSH", "l", "aap", "noot", "mies")) + ok(t, err) + equals(t, 3, b) // New length. + + r, err := redis.Strings(c.Do("LRANGE", "l", "0", "0")) + ok(t, err) + equals(t, []string{"aap"}, r) + + r, err = redis.Strings(c.Do("LRANGE", "l", "-1", "-1")) + ok(t, err) + equals(t, []string{"mies"}, r) + } + + // Push more. + { + b, err := redis.Int(c.Do("RPUSH", "l", "aap2", "noot2", "mies2")) + ok(t, err) + equals(t, 6, b) // New length. + + r, err := redis.Strings(c.Do("LRANGE", "l", "0", "0")) + ok(t, err) + equals(t, []string{"aap"}, r) + + r, err = redis.Strings(c.Do("LRANGE", "l", "-1", "-1")) + ok(t, err) + equals(t, []string{"mies2"}, r) + } + + // Direct usage + { + l, err := s.Push("l2", "a") + ok(t, err) + equals(t, 1, l) + l, err = s.Push("l2", "b") + ok(t, err) + equals(t, 2, l) + list, err := s.List("l2") + ok(t, err) + equals(t, []string{"a", "b"}, list) + + el, err := s.Pop("l2") + ok(t, err) + equals(t, "b", el) + el, err = s.Pop("l2") + ok(t, err) + equals(t, "a", el) + // Key is removed on pop-empty. + equals(t, false, s.Exists("l2")) + } + + // Wrong type of key + { + _, err := redis.String(c.Do("SET", "str", "value")) + ok(t, err) + _, err = redis.Int(c.Do("RPUSH", "str", "noot", "mies")) + assert(t, err != nil, "RPUSH error") + } + +} + +func TestRpop(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.Push("l", "aap", "noot", "mies") + + // Simple pops. + { + el, err := redis.String(c.Do("RPOP", "l")) + ok(t, err) + equals(t, "mies", el) + + el, err = redis.String(c.Do("RPOP", "l")) + ok(t, err) + equals(t, "noot", el) + + el, err = redis.String(c.Do("RPOP", "l")) + ok(t, err) + equals(t, "aap", el) + + // Last element has been popped. Key is gone. + i, err := redis.Int(c.Do("EXISTS", "l")) + ok(t, err) + equals(t, 0, i) + + // Can pop non-existing keys just fine. + v, err := c.Do("RPOP", "l") + ok(t, err) + equals(t, nil, v) + } +} + +func TestLindex(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.Push("l", "aap", "noot", "mies", "vuur") + + { + el, err := redis.String(c.Do("LINDEX", "l", "0")) + ok(t, err) + equals(t, "aap", el) + } + { + el, err := redis.String(c.Do("LINDEX", "l", "1")) + ok(t, err) + equals(t, "noot", el) + } + { + el, err := redis.String(c.Do("LINDEX", "l", "3")) + ok(t, err) + equals(t, "vuur", el) + } + // Too many + { + el, err := c.Do("LINDEX", "l", "3000") + ok(t, err) + equals(t, nil, el) + } + { + el, err := redis.String(c.Do("LINDEX", "l", "-1")) + ok(t, err) + equals(t, "vuur", el) + } + { + el, err := redis.String(c.Do("LINDEX", "l", "-2")) + ok(t, err) + equals(t, "mies", el) + } + // Too big + { + el, err := c.Do("LINDEX", "l", "-400") + ok(t, err) + equals(t, nil, el) + } + // Non exising key + { + el, err := c.Do("LINDEX", "nonexisting", "400") + ok(t, err) + equals(t, nil, el) + } + + // Wrong type of key + { + _, err := redis.String(c.Do("SET", "str", "value")) + ok(t, err) + _, err = redis.Int(c.Do("LINDEX", "str", "1")) + assert(t, err != nil, "LINDEX error") + // Not an integer + _, err = redis.String(c.Do("LINDEX", "l", "noint")) + assert(t, err != nil, "LINDEX error") + // Too many arguments + _, err = redis.String(c.Do("LINDEX", "str", "l", "foo")) + assert(t, err != nil, "LINDEX error") + } +} + +func TestLlen(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.Push("l", "aap", "noot", "mies", "vuur") + + { + el, err := redis.Int(c.Do("LLEN", "l")) + ok(t, err) + equals(t, 4, el) + } + + // Non exising key + { + el, err := redis.Int(c.Do("LLEN", "nonexisting")) + ok(t, err) + equals(t, 0, el) + } + + // Wrong type of key + { + _, err := redis.String(c.Do("SET", "str", "value")) + ok(t, err) + _, err = redis.Int(c.Do("LLEN", "str")) + assert(t, err != nil, "LLEN error") + // Too many arguments + _, err = redis.String(c.Do("LLEN", "too", "many")) + assert(t, err != nil, "LLEN error") + } +} + +func TestLtrim(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.Push("l", "aap", "noot", "mies", "vuur") + + { + el, err := redis.String(c.Do("LTRIM", "l", 0, 2)) + ok(t, err) + equals(t, "OK", el) + l, err := s.List("l") + ok(t, err) + equals(t, []string{"aap", "noot", "mies"}, l) + } + + // Delete key on empty list + { + el, err := redis.String(c.Do("LTRIM", "l", 0, -99)) + ok(t, err) + equals(t, "OK", el) + equals(t, false, s.Exists("l")) + } + + // Non exising key + { + el, err := redis.String(c.Do("LTRIM", "nonexisting", 0, 1)) + ok(t, err) + equals(t, "OK", el) + } + + // Wrong type of key + { + s.Set("str", "string!") + _, err = redis.Int(c.Do("LTRIM", "str", 0, 1)) + assert(t, err != nil, "LTRIM error") + // Too many/little/wrong arguments + _, err = redis.String(c.Do("LTRIM", "l", 1, 2, "toomany")) + assert(t, err != nil, "LTRIM error") + _, err = redis.String(c.Do("LTRIM", "l", 1, "noint")) + assert(t, err != nil, "LTRIM error") + _, err = redis.String(c.Do("LTRIM", "l", "noint", 1)) + assert(t, err != nil, "LTRIM error") + _, err = redis.String(c.Do("LTRIM", "l", 1)) + assert(t, err != nil, "LTRIM error") + _, err = redis.String(c.Do("LTRIM", "l")) + assert(t, err != nil, "LTRIM error") + _, err = redis.String(c.Do("LTRIM")) + assert(t, err != nil, "LTRIM error") + } +} + +func TestLrem(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + // Reverse + { + s.Push("l", "aap", "noot", "mies", "vuur", "noot", "noot") + n, err := redis.Int(c.Do("LREM", "l", -1, "noot")) + ok(t, err) + equals(t, 1, n) + l, err := s.List("l") + ok(t, err) + equals(t, []string{"aap", "noot", "mies", "vuur", "noot"}, l) + } + // Normal + { + s.Push("l2", "aap", "noot", "mies", "vuur", "noot", "noot") + n, err := redis.Int(c.Do("LREM", "l2", 2, "noot")) + ok(t, err) + equals(t, 2, n) + l, err := s.List("l2") + ok(t, err) + equals(t, []string{"aap", "mies", "vuur", "noot"}, l) + } + + // All + { + s.Push("l3", "aap", "noot", "mies", "vuur", "noot", "noot") + n, err := redis.Int(c.Do("LREM", "l3", 0, "noot")) + ok(t, err) + equals(t, 3, n) + l, err := s.List("l3") + ok(t, err) + equals(t, []string{"aap", "mies", "vuur"}, l) + } + + // All + { + s.Push("l4", "aap", "noot", "mies", "vuur", "noot", "noot") + n, err := redis.Int(c.Do("LREM", "l4", 200, "noot")) + ok(t, err) + equals(t, 3, n) + l, err := s.List("l4") + ok(t, err) + equals(t, []string{"aap", "mies", "vuur"}, l) + } + + // Delete key on empty list + { + s.Push("l5", "noot", "noot", "noot") + n, err := redis.Int(c.Do("LREM", "l5", 99, "noot")) + ok(t, err) + equals(t, 3, n) + equals(t, false, s.Exists("l5")) + } + + // Non exising key + { + n, err := redis.Int(c.Do("LREM", "nonexisting", 0, "aap")) + ok(t, err) + equals(t, 0, n) + } + + // Error cases + { + _, err = redis.String(c.Do("LREM")) + assert(t, err != nil, "LREM error") + _, err = redis.String(c.Do("LREM", "l")) + assert(t, err != nil, "LREM error") + _, err = redis.String(c.Do("LREM", "l", 1)) + assert(t, err != nil, "LREM error") + _, err = redis.String(c.Do("LREM", "l", "noint", "aap")) + assert(t, err != nil, "LREM error") + _, err = redis.String(c.Do("LREM", "l", 1, "aap", "toomany")) + assert(t, err != nil, "LREM error") + s.Set("str", "string!") + _, err = redis.Int(c.Do("LREM", "str", 0, "aap")) + assert(t, err != nil, "LREM error") + } +} + +func TestLset(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.Push("l", "aap", "noot", "mies", "vuur", "noot", "noot") + // Simple LSET + { + n, err := redis.String(c.Do("LSET", "l", 1, "noot!")) + ok(t, err) + equals(t, "OK", n) + l, err := s.List("l") + ok(t, err) + equals(t, []string{"aap", "noot!", "mies", "vuur", "noot", "noot"}, l) + } + + { + n, err := redis.String(c.Do("LSET", "l", -1, "noot?")) + ok(t, err) + equals(t, "OK", n) + l, err := s.List("l") + ok(t, err) + equals(t, []string{"aap", "noot!", "mies", "vuur", "noot", "noot?"}, l) + } + + // Out of range + { + _, err := c.Do("LSET", "l", 10000, "aap") + assert(t, err != nil, "LSET error") + + _, err = c.Do("LSET", "l", -10000, "aap") + assert(t, err != nil, "LSET error") + } + + // Non exising key + { + _, err := c.Do("LSET", "nonexisting", 0, "aap") + assert(t, err != nil, "LSET error") + } + + // Error cases + { + _, err = redis.String(c.Do("LSET")) + assert(t, err != nil, "LSET error") + _, err = redis.String(c.Do("LSET", "l")) + assert(t, err != nil, "LSET error") + _, err = redis.String(c.Do("LSET", "l", 1)) + assert(t, err != nil, "LSET error") + _, err = redis.String(c.Do("LSET", "l", "noint", "aap")) + assert(t, err != nil, "SET error") + _, err = redis.String(c.Do("LSET", "l", 1, "aap", "toomany")) + assert(t, err != nil, "LSET error") + s.Set("str", "string!") + _, err = redis.Int(c.Do("LSET", "str", 0, "aap")) + assert(t, err != nil, "LSET error") + } +} + +func TestLinsert(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.Push("l", "aap", "noot", "mies", "vuur", "noot", "end") + // Before + { + n, err := redis.Int(c.Do("LINSERT", "l", "BEFORE", "noot", "!")) + ok(t, err) + equals(t, 7, n) + l, err := s.List("l") + ok(t, err) + equals(t, []string{"aap", "!", "noot", "mies", "vuur", "noot", "end"}, l) + } + + // After + { + n, err := redis.Int(c.Do("LINSERT", "l", "AFTER", "noot", "?")) + ok(t, err) + equals(t, 8, n) + l, err := s.List("l") + ok(t, err) + equals(t, []string{"aap", "!", "noot", "?", "mies", "vuur", "noot", "end"}, l) + } + + // Edge case before + { + n, err := redis.Int(c.Do("LINSERT", "l", "BEFORE", "aap", "[")) + ok(t, err) + equals(t, 9, n) + l, err := s.List("l") + ok(t, err) + equals(t, []string{"[", "aap", "!", "noot", "?", "mies", "vuur", "noot", "end"}, l) + } + + // Edge case after + { + n, err := redis.Int(c.Do("LINSERT", "l", "AFTER", "end", "]")) + ok(t, err) + equals(t, 10, n) + l, err := s.List("l") + ok(t, err) + equals(t, []string{"[", "aap", "!", "noot", "?", "mies", "vuur", "noot", "end", "]"}, l) + } + + // Non exising pivot + { + n, err := redis.Int(c.Do("LINSERT", "l", "before", "nosuch", "noot")) + ok(t, err) + equals(t, -1, n) + } + + // Non exising key + { + n, err := redis.Int(c.Do("LINSERT", "nonexisting", "before", "aap", + "noot")) + ok(t, err) + equals(t, 0, n) + } + + // Error cases + { + _, err = redis.String(c.Do("LINSERT")) + assert(t, err != nil, "LINSERT error") + _, err = redis.String(c.Do("LINSERT", "l")) + assert(t, err != nil, "LINSERT error") + _, err = redis.String(c.Do("LINSERT", "l", "before")) + assert(t, err != nil, "LINSERT error") + _, err = redis.String(c.Do("LINSERT", "l", "before", "value")) + assert(t, err != nil, "LINSERT error") + _, err = redis.String(c.Do("LINSERT", "l", "wrong", "value", "value")) + assert(t, err != nil, "LINSERT error") + _, err = redis.String(c.Do("LINSERT", "l", "wrong", "value", "value", + "toomany")) + assert(t, err != nil, "LINSERT error") + s.Set("str", "string!") + _, err = redis.String(c.Do("LINSERT", "str", "before", "value", "value")) + assert(t, err != nil, "LINSERT error") + } +} + +func TestRpoplpush(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.Push("l", "aap", "noot", "mies") + s.Push("l2", "vuur", "noot", "end") + { + n, err := redis.String(c.Do("RPOPLPUSH", "l", "l2")) + ok(t, err) + equals(t, "mies", n) + s.CheckList(t, "l", "aap", "noot") + s.CheckList(t, "l2", "mies", "vuur", "noot", "end") + } + // Again! + { + n, err := redis.String(c.Do("RPOPLPUSH", "l", "l2")) + ok(t, err) + equals(t, "noot", n) + s.CheckList(t, "l", "aap") + s.CheckList(t, "l2", "noot", "mies", "vuur", "noot", "end") + } + // Again! + { + n, err := redis.String(c.Do("RPOPLPUSH", "l", "l2")) + ok(t, err) + equals(t, "aap", n) + assert(t, !s.Exists("l"), "l exists") + s.CheckList(t, "l2", "aap", "noot", "mies", "vuur", "noot", "end") + } + + // Non exising lists + { + s.Push("ll", "aap", "noot", "mies") + + n, err := redis.String(c.Do("RPOPLPUSH", "ll", "nosuch")) + ok(t, err) + equals(t, "mies", n) + assert(t, s.Exists("nosuch"), "nosuch exists") + s.CheckList(t, "ll", "aap", "noot") + s.CheckList(t, "nosuch", "mies") + + nada, err := c.Do("RPOPLPUSH", "nosuch2", "ll") + ok(t, err) + equals(t, nil, nada) + } + + // Cycle + { + s.Push("cycle", "aap", "noot", "mies") + + n, err := redis.String(c.Do("RPOPLPUSH", "cycle", "cycle")) + ok(t, err) + equals(t, "mies", n) + s.CheckList(t, "cycle", "mies", "aap", "noot") + } + + // Error cases + { + s.Push("src", "aap", "noot", "mies") + _, err = redis.String(c.Do("RPOPLPUSH")) + assert(t, err != nil, "RPOPLPUSH error") + _, err = redis.String(c.Do("RPOPLPUSH", "l")) + assert(t, err != nil, "RPOPLPUSH error") + _, err = redis.String(c.Do("RPOPLPUSH", "too", "many", "arguments")) + assert(t, err != nil, "RPOPLPUSH error") + s.Set("str", "string!") + _, err = redis.String(c.Do("RPOPLPUSH", "str", "src")) + assert(t, err != nil, "RPOPLPUSH error") + _, err = redis.String(c.Do("RPOPLPUSH", "src", "str")) + assert(t, err != nil, "RPOPLPUSH error") + } +} + +func TestRpushx(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + // Simple cases + { + // No key key + i, err := redis.Int(c.Do("RPUSHX", "l", "value")) + ok(t, err) + equals(t, 0, i) + assert(t, !s.Exists("l"), "l doesn't exist") + + s.Push("l", "aap", "noot") + + i, err = redis.Int(c.Do("RPUSHX", "l", "mies")) + ok(t, err) + equals(t, 3, i) + + s.CheckList(t, "l", "aap", "noot", "mies") + } + + // Error cases + { + s.Push("src", "aap", "noot", "mies") + _, err = redis.String(c.Do("RPUSHX")) + assert(t, err != nil, "RPUSHX error") + _, err = redis.String(c.Do("RPUSHX", "l")) + assert(t, err != nil, "RPUSHX error") + _, err = redis.String(c.Do("RPUSHX", "too", "many", "arguments")) + assert(t, err != nil, "RPUSHX error") + s.Set("str", "string!") + _, err = redis.String(c.Do("RPUSHX", "str", "value")) + assert(t, err != nil, "RPUSHX error") + } +} + +// execute command in a go routine. Used to test blocking commands. +func goStrings(t *testing.T, c redis.Conn, cmds ...interface{}) <-chan []string { + var ( + got = make(chan []string, 1) + ) + go func() { + res, err := c.Do(cmds[0].(string), cmds[1:]...) + if err != nil { + got <- []string{err.Error()} + return + } + if res == nil { + got <- nil + } else { + st, _ := redis.Strings(res, err) + got <- st + } + }() + return got +} + +func TestBrpop(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + // Simple cases + { + s.Push("ll", "aap", "noot", "mies") + v, err := redis.Strings(c.Do("BRPOP", "ll", 1)) + ok(t, err) + equals(t, []string{"ll", "mies"}, v) + } + + // Error cases + { + _, err = redis.String(c.Do("BRPOP")) + assert(t, err != nil, "BRPOP error") + _, err = redis.String(c.Do("BRPOP", "key")) + assert(t, err != nil, "BRPOP error") + _, err = redis.String(c.Do("BRPOP", "key", -1)) + assert(t, err != nil, "BRPOP error") + _, err = redis.String(c.Do("BRPOP", "key", "inf")) + assert(t, err != nil, "BRPOP error") + } +} + +func TestBrpopSimple(t *testing.T) { + _, c1, c2, done := setup2(t) + defer done() + + got := goStrings(t, c2, "BRPOP", "mylist", "0") + time.Sleep(30 * time.Millisecond) + + b, err := redis.Int(c1.Do("RPUSH", "mylist", "e1", "e2", "e3")) + ok(t, err) + equals(t, 3, b) + + select { + case have := <-got: + equals(t, []string{"mylist", "e3"}, have) + case <-time.After(500 * time.Millisecond): + t.Error("BRPOP took too long") + } +} + +func TestBrpopMulti(t *testing.T) { + _, c1, c2, done := setup2(t) + defer done() + + got := goStrings(t, c2, "BRPOP", "l1", "l2", "l3", 0) + _, err := redis.Int(c1.Do("RPUSH", "l0", "e01")) + ok(t, err) + _, err = redis.Int(c1.Do("RPUSH", "l2", "e21")) + ok(t, err) + _, err = redis.Int(c1.Do("RPUSH", "l3", "e31")) + ok(t, err) + + select { + case have := <-got: + equals(t, []string{"l2", "e21"}, have) + case <-time.After(500 * time.Millisecond): + t.Error("BRPOP took too long") + } + + got = goStrings(t, c2, "BRPOP", "l1", "l2", "l3", 0) + select { + case have := <-got: + equals(t, []string{"l3", "e31"}, have) + case <-time.After(500 * time.Millisecond): + t.Error("BRPOP took too long") + } +} + +func TestBrpopTimeout(t *testing.T) { + _, c, done := setup(t) + defer done() + + got := goStrings(t, c, "BRPOP", "l1", 1) + select { + case have := <-got: + equals(t, []string(nil), have) + case <-time.After(1500 * time.Millisecond): + t.Error("BRPOP took too long") + } +} + +func TestBrpopTx(t *testing.T) { + // BRPOP in a transaction behaves as if the timeout triggers right away + m, c, done := setup(t) + defer done() + + { + _, err := c.Do("MULTI") + ok(t, err) + s, err := redis.String(c.Do("BRPOP", "l1", 3)) + ok(t, err) + equals(t, "QUEUED", s) + s, err = redis.String(c.Do("SET", "foo", "bar")) + ok(t, err) + equals(t, "QUEUED", s) + + v, err := redis.Values(c.Do("EXEC")) + ok(t, err) + equals(t, 2, len(redis.Args(v))) + equals(t, nil, v[0]) + equals(t, "OK", v[1]) + } + + // Now set something + m.Push("l1", "e1") + + { + _, err := c.Do("MULTI") + ok(t, err) + s, err := redis.String(c.Do("BRPOP", "l1", 3)) + ok(t, err) + equals(t, "QUEUED", s) + s, err = redis.String(c.Do("SET", "foo", "bar")) + ok(t, err) + equals(t, "QUEUED", s) + + v, err := redis.Values(c.Do("EXEC")) + ok(t, err) + equals(t, 2, len(redis.Args(v))) + equals(t, "l1", string(v[0].([]interface{})[0].([]uint8))) + equals(t, "e1", string(v[0].([]interface{})[1].([]uint8))) + equals(t, "OK", v[1]) + } +} + +func TestBlpop(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + // Simple cases + { + s.Push("ll", "aap", "noot", "mies") + v, err := redis.Strings(c.Do("BLPOP", "ll", 1)) + ok(t, err) + equals(t, []string{"ll", "aap"}, v) + } + + // Error cases + { + _, err = redis.String(c.Do("BLPOP")) + assert(t, err != nil, "BLPOP error") + _, err = redis.String(c.Do("BLPOP", "key")) + assert(t, err != nil, "BLPOP error") + _, err = redis.String(c.Do("BLPOP", "key", -1)) + assert(t, err != nil, "BLPOP error") + _, err = redis.String(c.Do("BLPOP", "key", "inf")) + assert(t, err != nil, "BLPOP error") + } +} + +func TestBrpoplpush(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + // Simple cases + { + s.Push("l1", "aap", "noot", "mies") + v, err := redis.String(c.Do("BRPOPLPUSH", "l1", "l2", "1")) + ok(t, err) + equals(t, "mies", v) + + lv, err := s.List("l2") + ok(t, err) + equals(t, []string{"mies"}, lv) + } + + // Error cases + { + _, err = redis.String(c.Do("BRPOPLPUSH")) + assert(t, err != nil, "BRPOPLPUSH error") + _, err = redis.String(c.Do("BRPOPLPUSH", "key")) + assert(t, err != nil, "BRPOPLPUSH error") + _, err = redis.String(c.Do("BRPOPLPUSH", "key", "bar")) + assert(t, err != nil, "BRPOPLPUSH error") + _, err = redis.String(c.Do("BRPOPLPUSH", "key", "foo", -1)) + assert(t, err != nil, "BRPOPLPUSH error") + _, err = redis.String(c.Do("BRPOPLPUSH", "key", "foo", "inf")) + assert(t, err != nil, "BRPOPLPUSH error") + _, err = redis.String(c.Do("BRPOPLPUSH", "key", "foo", 1, "baz")) + assert(t, err != nil, "BRPOPLPUSH error") + } +} + +func TestBrpoplpushSimple(t *testing.T) { + s, c1, c2, done := setup2(t) + defer done() + + got := make(chan string, 1) + go func() { + b, err := redis.String(c2.Do("BRPOPLPUSH", "from", "to", "1")) + ok(t, err) + got <- b + }() + + time.Sleep(30 * time.Millisecond) + + b, err := redis.Int(c1.Do("RPUSH", "from", "e1", "e2", "e3")) + ok(t, err) + equals(t, 3, b) + + select { + case have := <-got: + equals(t, "e3", have) + case <-time.After(500 * time.Millisecond): + t.Error("BRPOP took too long") + } + + lv, err := s.List("from") + ok(t, err) + equals(t, []string{"e1", "e2"}, lv) + lv, err = s.List("to") + ok(t, err) + equals(t, []string{"e3"}, lv) +} + +func TestBrpoplpushTimeout(t *testing.T) { + _, c, done := setup(t) + defer done() + + got := goStrings(t, c, "BRPOPLPUSH", "l1", "l2", 1) + select { + case have := <-got: + equals(t, []string(nil), have) + case <-time.After(1500 * time.Millisecond): + t.Error("BRPOPLPUSH took too long") + } +} diff --git a/vendor/github.com/alicebob/miniredis/cmd_server.go b/vendor/github.com/alicebob/miniredis/cmd_server.go new file mode 100644 index 00000000..67dea5f9 --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/cmd_server.go @@ -0,0 +1,65 @@ +// Commands from http://redis.io/commands#server + +package miniredis + +import ( + "github.com/alicebob/miniredis/server" +) + +func commandsServer(m *Miniredis) { + m.srv.Register("DBSIZE", m.cmdDbsize) + m.srv.Register("FLUSHALL", m.cmdFlushall) + m.srv.Register("FLUSHDB", m.cmdFlushdb) +} + +// DBSIZE +func (m *Miniredis) cmdDbsize(c *server.Peer, cmd string, args []string) { + if len(args) > 0 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + c.WriteInt(len(db.keys)) + }) +} + +// FLUSHALL +func (m *Miniredis) cmdFlushall(c *server.Peer, cmd string, args []string) { + if len(args) > 0 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + m.flushAll() + c.WriteOK() + }) +} + +// FLUSHDB +func (m *Miniredis) cmdFlushdb(c *server.Peer, cmd string, args []string) { + if len(args) > 0 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + m.db(ctx.selectedDB).flush() + c.WriteOK() + }) +} diff --git a/vendor/github.com/alicebob/miniredis/cmd_server_test.go b/vendor/github.com/alicebob/miniredis/cmd_server_test.go new file mode 100644 index 00000000..a4198b2b --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/cmd_server_test.go @@ -0,0 +1,71 @@ +package miniredis + +import ( + "testing" + + "github.com/garyburd/redigo/redis" +) + +// Test DBSIZE, FLUSHDB, and FLUSHALL. +func TestCmdServer(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + // Set something + { + s.Set("aap", "niet") + s.Set("roos", "vuur") + s.DB(1).Set("noot", "mies") + } + + { + n, err := redis.Int(c.Do("DBSIZE")) + ok(t, err) + equals(t, 2, n) + + b, err := redis.String(c.Do("FLUSHDB")) + ok(t, err) + equals(t, "OK", b) + + n, err = redis.Int(c.Do("DBSIZE")) + ok(t, err) + equals(t, 0, n) + + _, err = c.Do("SELECT", 1) + ok(t, err) + + n, err = redis.Int(c.Do("DBSIZE")) + ok(t, err) + equals(t, 1, n) + + b, err = redis.String(c.Do("FLUSHALL")) + ok(t, err) + equals(t, "OK", b) + + n, err = redis.Int(c.Do("DBSIZE")) + ok(t, err) + equals(t, 0, n) + + _, err = c.Do("SELECT", 4) + ok(t, err) + + n, err = redis.Int(c.Do("DBSIZE")) + ok(t, err) + equals(t, 0, n) + + } + + { + _, err := redis.Int(c.Do("DBSIZE", "FOO")) + assert(t, err != nil, "no DBSIZE error") + + _, err = redis.Int(c.Do("FLUSHDB", "FOO")) + assert(t, err != nil, "no FLUSHDB error") + + _, err = redis.Int(c.Do("FLUSHALL", "FOO")) + assert(t, err != nil, "no FLUSHALL error") + } +} diff --git a/vendor/github.com/alicebob/miniredis/cmd_set.go b/vendor/github.com/alicebob/miniredis/cmd_set.go new file mode 100644 index 00000000..76c4995f --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/cmd_set.go @@ -0,0 +1,639 @@ +// Commands from http://redis.io/commands#set + +package miniredis + +import ( + "math/rand" + "strconv" + "strings" + + "github.com/alicebob/miniredis/server" +) + +// commandsSet handles all set value operations. +func commandsSet(m *Miniredis) { + m.srv.Register("SADD", m.cmdSadd) + m.srv.Register("SCARD", m.cmdScard) + m.srv.Register("SDIFF", m.cmdSdiff) + m.srv.Register("SDIFFSTORE", m.cmdSdiffstore) + m.srv.Register("SINTER", m.cmdSinter) + m.srv.Register("SINTERSTORE", m.cmdSinterstore) + m.srv.Register("SISMEMBER", m.cmdSismember) + m.srv.Register("SMEMBERS", m.cmdSmembers) + m.srv.Register("SMOVE", m.cmdSmove) + m.srv.Register("SPOP", m.cmdSpop) + m.srv.Register("SRANDMEMBER", m.cmdSrandmember) + m.srv.Register("SREM", m.cmdSrem) + m.srv.Register("SUNION", m.cmdSunion) + m.srv.Register("SUNIONSTORE", m.cmdSunionstore) + m.srv.Register("SSCAN", m.cmdSscan) +} + +// SADD +func (m *Miniredis) cmdSadd(c *server.Peer, cmd string, args []string) { + if len(args) < 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key, elems := args[0], args[1:] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if db.exists(key) && db.t(key) != "set" { + c.WriteError(ErrWrongType.Error()) + return + } + + added := db.setAdd(key, elems...) + c.WriteInt(added) + }) +} + +// SCARD +func (m *Miniredis) cmdScard(c *server.Peer, cmd string, args []string) { + if len(args) != 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(key) { + c.WriteInt(0) + return + } + + if db.t(key) != "set" { + c.WriteError(ErrWrongType.Error()) + return + } + + members := db.setMembers(key) + c.WriteInt(len(members)) + }) +} + +// SDIFF +func (m *Miniredis) cmdSdiff(c *server.Peer, cmd string, args []string) { + if len(args) < 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + keys := args + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + set, err := db.setDiff(keys) + if err != nil { + c.WriteError(err.Error()) + return + } + + c.WriteLen(len(set)) + for k := range set { + c.WriteBulk(k) + } + }) +} + +// SDIFFSTORE +func (m *Miniredis) cmdSdiffstore(c *server.Peer, cmd string, args []string) { + if len(args) < 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + dest, keys := args[0], args[1:] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + set, err := db.setDiff(keys) + if err != nil { + c.WriteError(err.Error()) + return + } + + db.del(dest, true) + db.setSet(dest, set) + c.WriteInt(len(set)) + }) +} + +// SINTER +func (m *Miniredis) cmdSinter(c *server.Peer, cmd string, args []string) { + if len(args) < 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + keys := args + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + set, err := db.setInter(keys) + if err != nil { + c.WriteError(err.Error()) + return + } + + c.WriteLen(len(set)) + for k := range set { + c.WriteBulk(k) + } + }) +} + +// SINTERSTORE +func (m *Miniredis) cmdSinterstore(c *server.Peer, cmd string, args []string) { + if len(args) < 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + dest, keys := args[0], args[1:] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + set, err := db.setInter(keys) + if err != nil { + c.WriteError(err.Error()) + return + } + + db.del(dest, true) + db.setSet(dest, set) + c.WriteInt(len(set)) + }) +} + +// SISMEMBER +func (m *Miniredis) cmdSismember(c *server.Peer, cmd string, args []string) { + if len(args) != 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key, value := args[0], args[1] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(key) { + c.WriteInt(0) + return + } + + if db.t(key) != "set" { + c.WriteError(ErrWrongType.Error()) + return + } + + if db.setIsMember(key, value) { + c.WriteInt(1) + return + } + c.WriteInt(0) + }) +} + +// SMEMBERS +func (m *Miniredis) cmdSmembers(c *server.Peer, cmd string, args []string) { + if len(args) != 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(key) { + c.WriteLen(0) + return + } + + if db.t(key) != "set" { + c.WriteError(ErrWrongType.Error()) + return + } + + members := db.setMembers(key) + + c.WriteLen(len(members)) + for _, elem := range members { + c.WriteBulk(elem) + } + }) +} + +// SMOVE +func (m *Miniredis) cmdSmove(c *server.Peer, cmd string, args []string) { + if len(args) != 3 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + src, dst, member := args[0], args[1], args[2] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(src) { + c.WriteInt(0) + return + } + + if db.t(src) != "set" { + c.WriteError(ErrWrongType.Error()) + return + } + + if db.exists(dst) && db.t(dst) != "set" { + c.WriteError(ErrWrongType.Error()) + return + } + + if !db.setIsMember(src, member) { + c.WriteInt(0) + return + } + db.setRem(src, member) + db.setAdd(dst, member) + c.WriteInt(1) + }) +} + +// SPOP +func (m *Miniredis) cmdSpop(c *server.Peer, cmd string, args []string) { + if len(args) == 0 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key, args := args[0], args[1:] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + withCount := false + count := 1 + if len(args) > 0 { + v, err := strconv.Atoi(args[0]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + count = v + withCount = true + args = args[1:] + } + if len(args) > 0 { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + + if !db.exists(key) { + if !withCount { + c.WriteNull() + return + } + c.WriteLen(0) + return + } + + if db.t(key) != "set" { + c.WriteError(ErrWrongType.Error()) + return + } + + var deleted []string + for i := 0; i < count; i++ { + members := db.setMembers(key) + if len(members) == 0 { + break + } + member := members[rand.Intn(len(members))] + db.setRem(key, member) + deleted = append(deleted, member) + } + // without `count` return a single value... + if !withCount { + if len(deleted) == 0 { + c.WriteNull() + return + } + c.WriteBulk(deleted[0]) + return + } + // ... with `count` return a list + c.WriteLen(len(deleted)) + for _, v := range deleted { + c.WriteBulk(v) + } + }) +} + +// SRANDMEMBER +func (m *Miniredis) cmdSrandmember(c *server.Peer, cmd string, args []string) { + if len(args) < 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if len(args) > 2 { + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + count := 0 + withCount := false + if len(args) == 2 { + var err error + count, err = strconv.Atoi(args[1]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + withCount = true + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(key) { + c.WriteNull() + return + } + + if db.t(key) != "set" { + c.WriteError(ErrWrongType.Error()) + return + } + + members := db.setMembers(key) + if count < 0 { + // Non-unique elements is allowed with negative count. + c.WriteLen(-count) + for count != 0 { + member := members[rand.Intn(len(members))] + c.WriteBulk(member) + count++ + } + return + } + + // Must be unique elements. + shuffle(members) + if count > len(members) { + count = len(members) + } + if !withCount { + c.WriteBulk(members[0]) + return + } + c.WriteLen(count) + for i := range make([]struct{}, count) { + c.WriteBulk(members[i]) + } + }) +} + +// SREM +func (m *Miniredis) cmdSrem(c *server.Peer, cmd string, args []string) { + if len(args) < 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key, fields := args[0], args[1:] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(key) { + c.WriteInt(0) + return + } + + if db.t(key) != "set" { + c.WriteError(ErrWrongType.Error()) + return + } + + c.WriteInt(db.setRem(key, fields...)) + }) +} + +// SUNION +func (m *Miniredis) cmdSunion(c *server.Peer, cmd string, args []string) { + if len(args) < 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + keys := args + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + set, err := db.setUnion(keys) + if err != nil { + c.WriteError(err.Error()) + return + } + + c.WriteLen(len(set)) + for k := range set { + c.WriteBulk(k) + } + }) +} + +// SUNIONSTORE +func (m *Miniredis) cmdSunionstore(c *server.Peer, cmd string, args []string) { + if len(args) < 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + dest, keys := args[0], args[1:] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + set, err := db.setUnion(keys) + if err != nil { + c.WriteError(err.Error()) + return + } + + db.del(dest, true) + db.setSet(dest, set) + c.WriteInt(len(set)) + }) +} + +// SSCAN +func (m *Miniredis) cmdSscan(c *server.Peer, cmd string, args []string) { + if len(args) < 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + cursor, err := strconv.Atoi(args[1]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidCursor) + return + } + args = args[2:] + // MATCH and COUNT options + var withMatch bool + var match string + for len(args) > 0 { + if strings.ToLower(args[0]) == "count" { + if len(args) < 2 { + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + _, err := strconv.Atoi(args[1]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + // We do nothing with count. + args = args[2:] + continue + } + if strings.ToLower(args[0]) == "match" { + if len(args) < 2 { + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + withMatch = true + match = args[1] + args = args[2:] + continue + } + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + // return _all_ (matched) keys every time + + if cursor != 0 { + // invalid cursor + c.WriteLen(2) + c.WriteBulk("0") // no next cursor + c.WriteLen(0) // no elements + return + } + if db.exists(key) && db.t(key) != "set" { + c.WriteError(ErrWrongType.Error()) + return + } + + members := db.setMembers(key) + if withMatch { + members = matchKeys(members, match) + } + + c.WriteLen(2) + c.WriteBulk("0") // no next cursor + c.WriteLen(len(members)) + for _, k := range members { + c.WriteBulk(k) + } + }) +} + +// shuffle shuffles a string. Kinda. +func shuffle(m []string) { + for _ = range m { + i := rand.Intn(len(m)) + j := rand.Intn(len(m)) + m[i], m[j] = m[j], m[i] + } +} diff --git a/vendor/github.com/alicebob/miniredis/cmd_set_test.go b/vendor/github.com/alicebob/miniredis/cmd_set_test.go new file mode 100644 index 00000000..ec27e83f --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/cmd_set_test.go @@ -0,0 +1,735 @@ +package miniredis + +import ( + "sort" + "testing" + + "github.com/garyburd/redigo/redis" +) + +// Test SADD / SMEMBERS. +func TestSadd(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + { + b, err := redis.Int(c.Do("SADD", "s", "aap", "noot", "mies")) + ok(t, err) + equals(t, 3, b) // New elements. + + members, err := s.Members("s") + ok(t, err) + equals(t, []string{"aap", "mies", "noot"}, members) + + m, err := redis.Strings(c.Do("SMEMBERS", "s")) + ok(t, err) + equals(t, []string{"aap", "mies", "noot"}, m) + } + + { + b, err := redis.String(c.Do("TYPE", "s")) + ok(t, err) + equals(t, "set", b) + } + + // SMEMBERS on an nonexisting key + { + m, err := redis.Strings(c.Do("SMEMBERS", "nosuch")) + ok(t, err) + equals(t, []string{}, m) + } + + { + b, err := redis.Int(c.Do("SADD", "s", "new", "noot", "mies")) + ok(t, err) + equals(t, 1, b) // Only one new field. + + members, err := s.Members("s") + ok(t, err) + equals(t, []string{"aap", "mies", "new", "noot"}, members) + } + + // Direct usage + { + added, err := s.SetAdd("s1", "aap") + ok(t, err) + equals(t, 1, added) + + members, err := s.Members("s1") + ok(t, err) + equals(t, []string{"aap"}, members) + } + + // Wrong type of key + { + _, err := redis.String(c.Do("SET", "str", "value")) + ok(t, err) + _, err = redis.Int(c.Do("SADD", "str", "hi")) + assert(t, err != nil, "SADD error") + _, err = redis.Int(c.Do("SMEMBERS", "str")) + assert(t, err != nil, "MEMBERS error") + // Wrong argument counts + _, err = redis.String(c.Do("SADD")) + assert(t, err != nil, "SADD error") + _, err = redis.String(c.Do("SADD", "set")) + assert(t, err != nil, "SADD error") + _, err = redis.String(c.Do("SMEMBERS")) + assert(t, err != nil, "SMEMBERS error") + _, err = redis.String(c.Do("SMEMBERS", "set", "spurious")) + assert(t, err != nil, "SMEMBERS error") + } + +} + +// Test SISMEMBER +func TestSismember(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.SetAdd("s", "aap", "noot", "mies") + + { + b, err := redis.Int(c.Do("SISMEMBER", "s", "aap")) + ok(t, err) + equals(t, 1, b) + + b, err = redis.Int(c.Do("SISMEMBER", "s", "nosuch")) + ok(t, err) + equals(t, 0, b) + } + + // a nonexisting key + { + b, err := redis.Int(c.Do("SISMEMBER", "nosuch", "nosuch")) + ok(t, err) + equals(t, 0, b) + } + + // Direct usage + { + isMember, err := s.IsMember("s", "noot") + ok(t, err) + equals(t, true, isMember) + } + + // Wrong type of key + { + _, err := redis.String(c.Do("SET", "str", "value")) + ok(t, err) + _, err = redis.Int(c.Do("SISMEMBER", "str")) + assert(t, err != nil, "SISMEMBER error") + // Wrong argument counts + _, err = redis.String(c.Do("SISMEMBER")) + assert(t, err != nil, "SISMEMBER error") + _, err = redis.String(c.Do("SISMEMBER", "set")) + assert(t, err != nil, "SISMEMBER error") + _, err = redis.String(c.Do("SISMEMBER", "set", "spurious", "args")) + assert(t, err != nil, "SISMEMBER error") + } + +} + +// Test SREM +func TestSrem(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.SetAdd("s", "aap", "noot", "mies", "vuur") + + { + b, err := redis.Int(c.Do("SREM", "s", "aap", "noot")) + ok(t, err) + equals(t, 2, b) + + members, err := s.Members("s") + ok(t, err) + equals(t, []string{"mies", "vuur"}, members) + } + + // a nonexisting field + { + b, err := redis.Int(c.Do("SREM", "s", "nosuch")) + ok(t, err) + equals(t, 0, b) + } + + // a nonexisting key + { + b, err := redis.Int(c.Do("SREM", "nosuch", "nosuch")) + ok(t, err) + equals(t, 0, b) + } + + // Direct usage + { + b, err := s.SRem("s", "mies") + ok(t, err) + equals(t, 1, b) + + members, err := s.Members("s") + ok(t, err) + equals(t, []string{"vuur"}, members) + } + + // Wrong type of key + { + _, err := redis.String(c.Do("SET", "str", "value")) + ok(t, err) + _, err = redis.Int(c.Do("SREM", "str", "value")) + assert(t, err != nil, "SREM error") + // Wrong argument counts + _, err = redis.String(c.Do("SREM")) + assert(t, err != nil, "SREM error") + _, err = redis.String(c.Do("SREM", "set")) + assert(t, err != nil, "SREM error") + _, err = redis.String(c.Do("SREM", "set", "spurious", "args")) + assert(t, err != nil, "SREM error") + } +} + +// Test SMOVE +func TestSmove(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.SetAdd("s", "aap", "noot") + + { + b, err := redis.Int(c.Do("SMOVE", "s", "s2", "aap")) + ok(t, err) + equals(t, 1, b) + + m, err := s.IsMember("s", "aap") + ok(t, err) + equals(t, false, m) + m, err = s.IsMember("s2", "aap") + ok(t, err) + equals(t, true, m) + } + + // Move away the last member + { + b, err := redis.Int(c.Do("SMOVE", "s", "s2", "noot")) + ok(t, err) + equals(t, 1, b) + + equals(t, false, s.Exists("s")) + + m, err := s.IsMember("s2", "noot") + ok(t, err) + equals(t, true, m) + } + + // a nonexisting member + { + b, err := redis.Int(c.Do("SMOVE", "s", "s2", "nosuch")) + ok(t, err) + equals(t, 0, b) + } + + // a nonexisting key + { + b, err := redis.Int(c.Do("SMOVE", "nosuch", "nosuch2", "nosuch")) + ok(t, err) + equals(t, 0, b) + } + + // Wrong type of key + { + _, err := redis.String(c.Do("SET", "str", "value")) + ok(t, err) + _, err = redis.Int(c.Do("SMOVE", "str", "dst", "value")) + assert(t, err != nil, "SMOVE error") + _, err = redis.Int(c.Do("SMOVE", "s2", "str", "value")) + assert(t, err != nil, "SMOVE error") + // Wrong argument counts + _, err = redis.String(c.Do("SMOVE")) + assert(t, err != nil, "SMOVE error") + _, err = redis.String(c.Do("SMOVE", "set")) + assert(t, err != nil, "SMOVE error") + _, err = redis.String(c.Do("SMOVE", "set", "set2")) + assert(t, err != nil, "SMOVE error") + _, err = redis.String(c.Do("SMOVE", "set", "set2", "spurious", "args")) + assert(t, err != nil, "SMOVE error") + } +} + +// Test SPOP +func TestSpop(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.SetAdd("s", "aap", "noot") + + { + el, err := redis.String(c.Do("SPOP", "s")) + ok(t, err) + assert(t, el == "aap" || el == "noot", "spop got something") + + el, err = redis.String(c.Do("SPOP", "s")) + ok(t, err) + assert(t, el == "aap" || el == "noot", "spop got something") + + assert(t, !s.Exists("s"), "all spopped away") + } + + // a nonexisting key + { + b, err := c.Do("SPOP", "nosuch") + ok(t, err) + equals(t, nil, b) + } + + // various errors + { + s.SetAdd("chk", "aap", "noot") + s.Set("str", "value") + + _, err = redis.String(c.Do("SMOVE")) + assert(t, err != nil, "SMOVE error") + _, err = redis.String(c.Do("SMOVE", "chk", "set2")) + assert(t, err != nil, "SMOVE error") + + _, err = c.Do("SPOP", "str") + assert(t, err != nil, "SPOP error") + } + + // count argument + { + s.SetAdd("s", "aap", "noot", "mies", "vuur") + el, err := redis.Strings(c.Do("SPOP", "s", 2)) + ok(t, err) + assert(t, len(el) == 2, "SPOP s 2") + members, err := s.Members("s") + ok(t, err) + assert(t, len(members) == 2, "SPOP s 2") + } +} + +// Test SRANDMEMBER +func TestSrandmember(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.SetAdd("s", "aap", "noot", "mies") + + // No count + { + el, err := redis.String(c.Do("SRANDMEMBER", "s")) + ok(t, err) + assert(t, el == "aap" || el == "noot" || el == "mies", "srandmember got something") + } + + // Positive count + { + els, err := redis.Strings(c.Do("SRANDMEMBER", "s", 2)) + ok(t, err) + equals(t, 2, len(els)) + } + + // Negative count + { + els, err := redis.Strings(c.Do("SRANDMEMBER", "s", -2)) + ok(t, err) + equals(t, 2, len(els)) + } + + // a nonexisting key + { + b, err := c.Do("SRANDMEMBER", "nosuch") + ok(t, err) + equals(t, nil, b) + } + + // Various errors + { + s.SetAdd("chk", "aap", "noot") + s.Set("str", "value") + + _, err = redis.String(c.Do("SRANDMEMBER")) + assert(t, err != nil, "SRANDMEMBER error") + _, err = redis.String(c.Do("SRANDMEMBER", "chk", "noint")) + assert(t, err != nil, "SRANDMEMBER error") + _, err = redis.String(c.Do("SRANDMEMBER", "chk", 1, "toomanu")) + assert(t, err != nil, "SRANDMEMBER error") + + _, err = c.Do("SRANDMEMBER", "str") + assert(t, err != nil, "SRANDMEMBER error") + } +} + +// Test SDIFF +func TestSdiff(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.SetAdd("s1", "aap", "noot", "mies") + s.SetAdd("s2", "noot", "mies", "vuur") + s.SetAdd("s3", "aap", "mies", "wim") + + // Simple case + { + els, err := redis.Strings(c.Do("SDIFF", "s1", "s2")) + ok(t, err) + equals(t, []string{"aap"}, els) + } + + // No other set + { + els, err := redis.Strings(c.Do("SDIFF", "s1")) + ok(t, err) + sort.Strings(els) + equals(t, []string{"aap", "mies", "noot"}, els) + } + + // 3 sets + { + els, err := redis.Strings(c.Do("SDIFF", "s1", "s2", "s3")) + ok(t, err) + equals(t, []string{}, els) + } + + // A nonexisting key + { + els, err := redis.Strings(c.Do("SDIFF", "s9")) + ok(t, err) + equals(t, []string{}, els) + } + + // Various errors + { + s.SetAdd("chk", "aap", "noot") + s.Set("str", "value") + + _, err = redis.String(c.Do("SDIFF")) + assert(t, err != nil, "SDIFF error") + _, err = redis.String(c.Do("SDIFF", "str")) + assert(t, err != nil, "SDIFF error") + _, err = redis.String(c.Do("SDIFF", "chk", "str")) + assert(t, err != nil, "SDIFF error") + } +} + +// Test SDIFFSTORE +func TestSdiffstore(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.SetAdd("s1", "aap", "noot", "mies") + s.SetAdd("s2", "noot", "mies", "vuur") + s.SetAdd("s3", "aap", "mies", "wim") + + // Simple case + { + i, err := redis.Int(c.Do("SDIFFSTORE", "res", "s1", "s3")) + ok(t, err) + equals(t, 1, i) + s.CheckSet(t, "res", "noot") + } + + // Various errors + { + s.SetAdd("chk", "aap", "noot") + s.Set("str", "value") + + _, err = redis.String(c.Do("SDIFFSTORE")) + assert(t, err != nil, "SDIFFSTORE error") + _, err = redis.String(c.Do("SDIFFSTORE", "t")) + assert(t, err != nil, "SDIFFSTORE error") + _, err = redis.String(c.Do("SDIFFSTORE", "t", "str")) + assert(t, err != nil, "SDIFFSTORE error") + } +} + +// Test SINTER +func TestSinter(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.SetAdd("s1", "aap", "noot", "mies") + s.SetAdd("s2", "noot", "mies", "vuur") + s.SetAdd("s3", "aap", "mies", "wim") + + // Simple case + { + els, err := redis.Strings(c.Do("SINTER", "s1", "s2")) + ok(t, err) + sort.Strings(els) + equals(t, []string{"mies", "noot"}, els) + } + + // No other set + { + els, err := redis.Strings(c.Do("SINTER", "s1")) + ok(t, err) + sort.Strings(els) + equals(t, []string{"aap", "mies", "noot"}, els) + } + + // 3 sets + { + els, err := redis.Strings(c.Do("SINTER", "s1", "s2", "s3")) + ok(t, err) + equals(t, []string{"mies"}, els) + } + + // A nonexisting key + { + els, err := redis.Strings(c.Do("SINTER", "s9")) + ok(t, err) + equals(t, []string{}, els) + } + + // Various errors + { + s.SetAdd("chk", "aap", "noot") + s.Set("str", "value") + + _, err = redis.String(c.Do("SINTER")) + assert(t, err != nil, "SINTER error") + _, err = redis.String(c.Do("SINTER", "str")) + assert(t, err != nil, "SINTER error") + _, err = redis.String(c.Do("SINTER", "chk", "str")) + assert(t, err != nil, "SINTER error") + } +} + +// Test SINTERSTORE +func TestSinterstore(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.SetAdd("s1", "aap", "noot", "mies") + s.SetAdd("s2", "noot", "mies", "vuur") + s.SetAdd("s3", "aap", "mies", "wim") + + // Simple case + { + i, err := redis.Int(c.Do("SINTERSTORE", "res", "s1", "s3")) + ok(t, err) + equals(t, 2, i) + s.CheckSet(t, "res", "aap", "mies") + } + + // Various errors + { + s.SetAdd("chk", "aap", "noot") + s.Set("str", "value") + + _, err = redis.String(c.Do("SINTERSTORE")) + assert(t, err != nil, "SINTERSTORE error") + _, err = redis.String(c.Do("SINTERSTORE", "t")) + assert(t, err != nil, "SINTERSTORE error") + _, err = redis.String(c.Do("SINTERSTORE", "t", "str")) + assert(t, err != nil, "SINTERSTORE error") + } +} + +// Test SUNION +func TestSunion(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.SetAdd("s1", "aap", "noot", "mies") + s.SetAdd("s2", "noot", "mies", "vuur") + s.SetAdd("s3", "aap", "mies", "wim") + + // Simple case + { + els, err := redis.Strings(c.Do("SUNION", "s1", "s2")) + ok(t, err) + sort.Strings(els) + equals(t, []string{"aap", "mies", "noot", "vuur"}, els) + } + + // No other set + { + els, err := redis.Strings(c.Do("SUNION", "s1")) + ok(t, err) + sort.Strings(els) + equals(t, []string{"aap", "mies", "noot"}, els) + } + + // 3 sets + { + els, err := redis.Strings(c.Do("SUNION", "s1", "s2", "s3")) + ok(t, err) + sort.Strings(els) + equals(t, []string{"aap", "mies", "noot", "vuur", "wim"}, els) + } + + // A nonexisting key + { + els, err := redis.Strings(c.Do("SUNION", "s9")) + ok(t, err) + equals(t, []string{}, els) + } + + // Various errors + { + s.SetAdd("chk", "aap", "noot") + s.Set("str", "value") + + _, err = redis.String(c.Do("SUNION")) + assert(t, err != nil, "SUNION error") + _, err = redis.String(c.Do("SUNION", "str")) + assert(t, err != nil, "SUNION error") + _, err = redis.String(c.Do("SUNION", "chk", "str")) + assert(t, err != nil, "SUNION error") + } +} + +// Test SUNIONSTORE +func TestSunionstore(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.SetAdd("s1", "aap", "noot", "mies") + s.SetAdd("s2", "noot", "mies", "vuur") + s.SetAdd("s3", "aap", "mies", "wim") + + // Simple case + { + i, err := redis.Int(c.Do("SUNIONSTORE", "res", "s1", "s3")) + ok(t, err) + equals(t, 4, i) + s.CheckSet(t, "res", "aap", "mies", "noot", "wim") + } + + // Various errors + { + s.SetAdd("chk", "aap", "noot") + s.Set("str", "value") + + _, err = redis.String(c.Do("SUNIONSTORE")) + assert(t, err != nil, "SUNIONSTORE error") + _, err = redis.String(c.Do("SUNIONSTORE", "t")) + assert(t, err != nil, "SUNIONSTORE error") + _, err = redis.String(c.Do("SUNIONSTORE", "t", "str")) + assert(t, err != nil, "SUNIONSTORE error") + } +} + +func TestSscan(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + // We cheat with sscan. It always returns everything. + + s.SetAdd("set", "value1", "value2") + + // No problem + { + res, err := redis.Values(c.Do("SSCAN", "set", 0)) + ok(t, err) + equals(t, 2, len(res)) + + var c int + var keys []string + _, err = redis.Scan(res, &c, &keys) + ok(t, err) + equals(t, 0, c) + equals(t, []string{"value1", "value2"}, keys) + } + + // Invalid cursor + { + res, err := redis.Values(c.Do("SSCAN", "set", 42)) + ok(t, err) + equals(t, 2, len(res)) + + var c int + var keys []string + _, err = redis.Scan(res, &c, &keys) + ok(t, err) + equals(t, 0, c) + equals(t, []string(nil), keys) + } + + // COUNT (ignored) + { + res, err := redis.Values(c.Do("SSCAN", "set", 0, "COUNT", 200)) + ok(t, err) + equals(t, 2, len(res)) + + var c int + var keys []string + _, err = redis.Scan(res, &c, &keys) + ok(t, err) + equals(t, 0, c) + equals(t, []string{"value1", "value2"}, keys) + } + + // MATCH + { + s.SetAdd("set", "aap", "noot", "mies") + res, err := redis.Values(c.Do("SSCAN", "set", 0, "MATCH", "mi*")) + ok(t, err) + equals(t, 2, len(res)) + + var c int + var keys []string + _, err = redis.Scan(res, &c, &keys) + ok(t, err) + equals(t, 0, c) + equals(t, []string{"mies"}, keys) + } + + // Wrong usage + { + _, err := redis.Int(c.Do("SSCAN")) + assert(t, err != nil, "do SSCAN error") + _, err = redis.Int(c.Do("SSCAN", "set")) + assert(t, err != nil, "do SSCAN error") + _, err = redis.Int(c.Do("SSCAN", "set", "noint")) + assert(t, err != nil, "do SSCAN error") + _, err = redis.Int(c.Do("SSCAN", "set", 1, "MATCH")) + assert(t, err != nil, "do SSCAN error") + _, err = redis.Int(c.Do("SSCAN", "set", 1, "COUNT")) + assert(t, err != nil, "do SSCAN error") + _, err = redis.Int(c.Do("SSCAN", "set", 1, "COUNT", "noint")) + assert(t, err != nil, "do SSCAN error") + s.Set("str", "value") + _, err = redis.Int(c.Do("SSCAN", "str", 1)) + assert(t, err != nil, "do SSCAN error") + } +} diff --git a/vendor/github.com/alicebob/miniredis/cmd_sorted_set.go b/vendor/github.com/alicebob/miniredis/cmd_sorted_set.go new file mode 100644 index 00000000..564b9e25 --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/cmd_sorted_set.go @@ -0,0 +1,1283 @@ +// Commands from http://redis.io/commands#sorted_set + +package miniredis + +import ( + "errors" + "sort" + "strconv" + "strings" + + "github.com/alicebob/miniredis/server" +) + +var ( + errInvalidRangeItem = errors.New(msgInvalidRangeItem) +) + +// commandsSortedSet handles all sorted set operations. +func commandsSortedSet(m *Miniredis) { + m.srv.Register("ZADD", m.cmdZadd) + m.srv.Register("ZCARD", m.cmdZcard) + m.srv.Register("ZCOUNT", m.cmdZcount) + m.srv.Register("ZINCRBY", m.cmdZincrby) + m.srv.Register("ZINTERSTORE", m.cmdZinterstore) + m.srv.Register("ZLEXCOUNT", m.cmdZlexcount) + m.srv.Register("ZRANGE", m.makeCmdZrange(false)) + m.srv.Register("ZRANGEBYLEX", m.cmdZrangebylex) + m.srv.Register("ZRANGEBYSCORE", m.makeCmdZrangebyscore(false)) + m.srv.Register("ZRANK", m.makeCmdZrank(false)) + m.srv.Register("ZREM", m.cmdZrem) + m.srv.Register("ZREMRANGEBYLEX", m.cmdZremrangebylex) + m.srv.Register("ZREMRANGEBYRANK", m.cmdZremrangebyrank) + m.srv.Register("ZREMRANGEBYSCORE", m.cmdZremrangebyscore) + m.srv.Register("ZREVRANGE", m.makeCmdZrange(true)) + m.srv.Register("ZREVRANGEBYSCORE", m.makeCmdZrangebyscore(true)) + m.srv.Register("ZREVRANK", m.makeCmdZrank(true)) + m.srv.Register("ZSCORE", m.cmdZscore) + m.srv.Register("ZUNIONSTORE", m.cmdZunionstore) + m.srv.Register("ZSCAN", m.cmdZscan) +} + +// ZADD +func (m *Miniredis) cmdZadd(c *server.Peer, cmd string, args []string) { + if len(args) < 3 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key, args := args[0], args[1:] + var ( + nx = false + xx = false + ch = false + elems = map[string]float64{} + ) + for len(args) > 0 { + switch strings.ToUpper(args[0]) { + case "NX": + nx = true + args = args[1:] + continue + case "XX": + xx = true + args = args[1:] + continue + case "CH": + ch = true + args = args[1:] + continue + default: + if len(args) < 2 { + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + score, err := strconv.ParseFloat(args[0], 64) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidFloat) + return + } + elems[args[1]] = score + args = args[2:] + } + } + + if xx && nx { + setDirty(c) + c.WriteError(msgXXandNX) + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if db.exists(key) && db.t(key) != "zset" { + c.WriteError(ErrWrongType.Error()) + return + } + + res := 0 + for member, score := range elems { + if nx && db.ssetExists(key, member) { + continue + } + if xx && !db.ssetExists(key, member) { + continue + } + old := db.ssetScore(key, member) + if db.ssetAdd(key, score, member) { + res++ + } else { + if ch && old != score { + // if 'CH' is specified, only count changed keys + res++ + } + } + } + c.WriteInt(res) + }) +} + +// ZCARD +func (m *Miniredis) cmdZcard(c *server.Peer, cmd string, args []string) { + if len(args) != 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(key) { + c.WriteInt(0) + return + } + + if db.t(key) != "zset" { + c.WriteError(ErrWrongType.Error()) + return + } + + c.WriteInt(db.ssetCard(key)) + }) +} + +// ZCOUNT +func (m *Miniredis) cmdZcount(c *server.Peer, cmd string, args []string) { + if len(args) != 3 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + min, minIncl, err := parseFloatRange(args[1]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidMinMax) + return + } + max, maxIncl, err := parseFloatRange(args[2]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidMinMax) + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(key) { + c.WriteInt(0) + return + } + + if db.t(key) != "zset" { + c.WriteError(ErrWrongType.Error()) + return + } + + members := db.ssetElements(key) + members = withSSRange(members, min, minIncl, max, maxIncl) + c.WriteInt(len(members)) + }) +} + +// ZINCRBY +func (m *Miniredis) cmdZincrby(c *server.Peer, cmd string, args []string) { + if len(args) != 3 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + delta, err := strconv.ParseFloat(args[1], 64) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidFloat) + return + } + member := args[2] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if db.exists(key) && db.t(key) != "zset" { + c.WriteError(msgWrongType) + return + } + newScore := db.ssetIncrby(key, member, delta) + c.WriteBulk(formatFloat(newScore)) + }) +} + +// ZINTERSTORE +func (m *Miniredis) cmdZinterstore(c *server.Peer, cmd string, args []string) { + if len(args) < 3 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + destination := args[0] + numKeys, err := strconv.Atoi(args[1]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + args = args[2:] + if len(args) < numKeys { + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + if numKeys <= 0 { + setDirty(c) + c.WriteError("ERR at least 1 input key is needed for ZUNIONSTORE/ZINTERSTORE") + return + } + keys := args[:numKeys] + args = args[numKeys:] + + withWeights := false + weights := []float64{} + aggregate := "sum" + for len(args) > 0 { + if strings.ToLower(args[0]) == "weights" { + if len(args) < numKeys+1 { + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + for i := 0; i < numKeys; i++ { + f, err := strconv.ParseFloat(args[i+1], 64) + if err != nil { + setDirty(c) + c.WriteError("ERR weight value is not a float") + return + } + weights = append(weights, f) + } + withWeights = true + args = args[numKeys+1:] + continue + } + if strings.ToLower(args[0]) == "aggregate" { + if len(args) < 2 { + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + aggregate = strings.ToLower(args[1]) + switch aggregate { + default: + setDirty(c) + c.WriteError(msgSyntaxError) + return + case "sum", "min", "max": + } + args = args[2:] + continue + } + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + db.del(destination, true) + + // We collect everything and remove all keys which turned out not to be + // present in every set. + sset := map[string]float64{} + counts := map[string]int{} + for i, key := range keys { + if !db.exists(key) { + continue + } + if db.t(key) != "zset" { + c.WriteError(msgWrongType) + return + } + for _, el := range db.ssetElements(key) { + score := el.score + if withWeights { + score *= weights[i] + } + counts[el.member]++ + old, ok := sset[el.member] + if !ok { + sset[el.member] = score + continue + } + switch aggregate { + default: + panic("Invalid aggregate") + case "sum": + sset[el.member] += score + case "min": + if score < old { + sset[el.member] = score + } + case "max": + if score > old { + sset[el.member] = score + } + } + } + } + for key, count := range counts { + if count != numKeys { + delete(sset, key) + } + } + db.ssetSet(destination, sset) + c.WriteInt(len(sset)) + }) +} + +// ZLEXCOUNT +func (m *Miniredis) cmdZlexcount(c *server.Peer, cmd string, args []string) { + if len(args) != 3 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + min, minIncl, err := parseLexrange(args[1]) + if err != nil { + setDirty(c) + c.WriteError(err.Error()) + return + } + max, maxIncl, err := parseLexrange(args[2]) + if err != nil { + setDirty(c) + c.WriteError(err.Error()) + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(key) { + c.WriteInt(0) + return + } + + if db.t(key) != "zset" { + c.WriteError(ErrWrongType.Error()) + return + } + + members := db.ssetMembers(key) + // Just key sort. If scores are not the same we don't care. + sort.Strings(members) + members = withLexRange(members, min, minIncl, max, maxIncl) + + c.WriteInt(len(members)) + }) +} + +// ZRANGE and ZREVRANGE +func (m *Miniredis) makeCmdZrange(reverse bool) server.Cmd { + return func(c *server.Peer, cmd string, args []string) { + if len(args) < 3 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + start, err := strconv.Atoi(args[1]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + end, err := strconv.Atoi(args[2]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + + withScores := false + if len(args) > 4 { + c.WriteError(msgSyntaxError) + return + } + if len(args) == 4 { + if strings.ToLower(args[3]) != "withscores" { + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + withScores = true + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(key) { + c.WriteLen(0) + return + } + + if db.t(key) != "zset" { + c.WriteError(ErrWrongType.Error()) + return + } + + members := db.ssetMembers(key) + if reverse { + reverseSlice(members) + } + rs, re := redisRange(len(members), start, end, false) + if withScores { + c.WriteLen((re - rs) * 2) + } else { + c.WriteLen(re - rs) + } + for _, el := range members[rs:re] { + c.WriteBulk(el) + if withScores { + c.WriteBulk(formatFloat(db.ssetScore(key, el))) + } + } + }) + } +} + +// ZRANGEBYLEX +func (m *Miniredis) cmdZrangebylex(c *server.Peer, cmd string, args []string) { + if len(args) < 3 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + min, minIncl, err := parseLexrange(args[1]) + if err != nil { + setDirty(c) + c.WriteError(err.Error()) + return + } + max, maxIncl, err := parseLexrange(args[2]) + if err != nil { + setDirty(c) + c.WriteError(err.Error()) + return + } + args = args[3:] + + withLimit := false + limitStart := 0 + limitEnd := 0 + for len(args) > 0 { + if strings.ToLower(args[0]) == "limit" { + withLimit = true + args = args[1:] + if len(args) < 2 { + c.WriteError(msgSyntaxError) + return + } + limitStart, err = strconv.Atoi(args[0]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + limitEnd, err = strconv.Atoi(args[1]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + args = args[2:] + continue + } + // Syntax error + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(key) { + c.WriteLen(0) + return + } + + if db.t(key) != "zset" { + c.WriteError(ErrWrongType.Error()) + return + } + + members := db.ssetMembers(key) + // Just key sort. If scores are not the same we don't care. + sort.Strings(members) + members = withLexRange(members, min, minIncl, max, maxIncl) + + // Apply LIMIT ranges. That's . Unlike RANGE. + if withLimit { + if limitStart < 0 { + members = nil + } else { + if limitStart < len(members) { + members = members[limitStart:] + } else { + // out of range + members = nil + } + if limitEnd >= 0 { + if len(members) > limitEnd { + members = members[:limitEnd] + } + } + } + } + + c.WriteLen(len(members)) + for _, el := range members { + c.WriteBulk(el) + } + }) +} + +// ZRANGEBYSCORE and ZREVRANGEBYSCORE +func (m *Miniredis) makeCmdZrangebyscore(reverse bool) server.Cmd { + return func(c *server.Peer, cmd string, args []string) { + if len(args) < 3 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + min, minIncl, err := parseFloatRange(args[1]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidMinMax) + return + } + max, maxIncl, err := parseFloatRange(args[2]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidMinMax) + return + } + args = args[3:] + + withScores := false + withLimit := false + limitStart := 0 + limitEnd := 0 + for len(args) > 0 { + if strings.ToLower(args[0]) == "limit" { + withLimit = true + args = args[1:] + if len(args) < 2 { + c.WriteError(msgSyntaxError) + return + } + limitStart, err = strconv.Atoi(args[0]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + limitEnd, err = strconv.Atoi(args[1]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + args = args[2:] + continue + } + if strings.ToLower(args[0]) == "withscores" { + withScores = true + args = args[1:] + continue + } + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(key) { + c.WriteLen(0) + return + } + + if db.t(key) != "zset" { + c.WriteError(ErrWrongType.Error()) + return + } + + members := db.ssetElements(key) + if reverse { + min, max = max, min + minIncl, maxIncl = maxIncl, minIncl + } + members = withSSRange(members, min, minIncl, max, maxIncl) + if reverse { + reverseElems(members) + } + + // Apply LIMIT ranges. That's . Unlike RANGE. + if withLimit { + if limitStart < 0 { + members = ssElems{} + } else { + if limitStart < len(members) { + members = members[limitStart:] + } else { + // out of range + members = ssElems{} + } + if limitEnd >= 0 { + if len(members) > limitEnd { + members = members[:limitEnd] + } + } + } + } + + if withScores { + c.WriteLen(len(members) * 2) + } else { + c.WriteLen(len(members)) + } + for _, el := range members { + c.WriteBulk(el.member) + if withScores { + c.WriteBulk(formatFloat(el.score)) + } + } + }) + } +} + +// ZRANK and ZREVRANK +func (m *Miniredis) makeCmdZrank(reverse bool) server.Cmd { + return func(c *server.Peer, cmd string, args []string) { + if len(args) != 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key, member := args[0], args[1] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(key) { + c.WriteNull() + return + } + + if db.t(key) != "zset" { + c.WriteError(ErrWrongType.Error()) + return + } + + direction := asc + if reverse { + direction = desc + } + rank, ok := db.ssetRank(key, member, direction) + if !ok { + c.WriteNull() + return + } + c.WriteInt(rank) + }) + } +} + +// ZREM +func (m *Miniredis) cmdZrem(c *server.Peer, cmd string, args []string) { + if len(args) < 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key, members := args[0], args[1:] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(key) { + c.WriteInt(0) + return + } + + if db.t(key) != "zset" { + c.WriteError(ErrWrongType.Error()) + return + } + + deleted := 0 + for _, member := range members { + if db.ssetRem(key, member) { + deleted++ + } + } + c.WriteInt(deleted) + }) +} + +// ZREMRANGEBYLEX +func (m *Miniredis) cmdZremrangebylex(c *server.Peer, cmd string, args []string) { + if len(args) != 3 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + min, minIncl, err := parseLexrange(args[1]) + if err != nil { + setDirty(c) + c.WriteError(err.Error()) + return + } + max, maxIncl, err := parseLexrange(args[2]) + if err != nil { + setDirty(c) + c.WriteError(err.Error()) + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(key) { + c.WriteInt(0) + return + } + + if db.t(key) != "zset" { + c.WriteError(ErrWrongType.Error()) + return + } + + members := db.ssetMembers(key) + // Just key sort. If scores are not the same we don't care. + sort.Strings(members) + members = withLexRange(members, min, minIncl, max, maxIncl) + + for _, el := range members { + db.ssetRem(key, el) + } + c.WriteInt(len(members)) + }) +} + +// ZREMRANGEBYRANK +func (m *Miniredis) cmdZremrangebyrank(c *server.Peer, cmd string, args []string) { + if len(args) != 3 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + start, err := strconv.Atoi(args[1]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + end, err := strconv.Atoi(args[2]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(key) { + c.WriteInt(0) + return + } + + if db.t(key) != "zset" { + c.WriteError(ErrWrongType.Error()) + return + } + + members := db.ssetMembers(key) + rs, re := redisRange(len(members), start, end, false) + for _, el := range members[rs:re] { + db.ssetRem(key, el) + } + c.WriteInt(re - rs) + }) +} + +// ZREMRANGEBYSCORE +func (m *Miniredis) cmdZremrangebyscore(c *server.Peer, cmd string, args []string) { + if len(args) != 3 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + min, minIncl, err := parseFloatRange(args[1]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidMinMax) + return + } + max, maxIncl, err := parseFloatRange(args[2]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidMinMax) + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(key) { + c.WriteInt(0) + return + } + + if db.t(key) != "zset" { + c.WriteError(ErrWrongType.Error()) + return + } + + members := db.ssetElements(key) + members = withSSRange(members, min, minIncl, max, maxIncl) + + for _, el := range members { + db.ssetRem(key, el.member) + } + c.WriteInt(len(members)) + }) +} + +// ZSCORE +func (m *Miniredis) cmdZscore(c *server.Peer, cmd string, args []string) { + if len(args) != 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key, member := args[0], args[1] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(key) { + c.WriteNull() + return + } + + if db.t(key) != "zset" { + c.WriteError(ErrWrongType.Error()) + return + } + + if !db.ssetExists(key, member) { + c.WriteNull() + return + } + + c.WriteBulk(formatFloat(db.ssetScore(key, member))) + }) +} + +// parseFloatRange handles ZRANGEBYSCORE floats. They are inclusive unless the +// string starts with '(' +func parseFloatRange(s string) (float64, bool, error) { + if len(s) == 0 { + return 0, false, nil + } + inclusive := true + if s[0] == '(' { + s = s[1:] + inclusive = false + } + f, err := strconv.ParseFloat(s, 64) + return f, inclusive, err +} + +// parseLexrange handles ZRANGEBYLEX ranges. They start with '[', '(', or are +// '+' or '-'. +// Returns range, inclusive, error. +// On '+' or '-' that's just returned. +func parseLexrange(s string) (string, bool, error) { + if len(s) == 0 { + return "", false, errInvalidRangeItem + } + if s == "+" || s == "-" { + return s, false, nil + } + switch s[0] { + case '(': + return s[1:], false, nil + case '[': + return s[1:], true, nil + default: + return "", false, errInvalidRangeItem + } +} + +// withSSRange limits a list of sorted set elements by the ZRANGEBYSCORE range +// logic. +func withSSRange(members ssElems, min float64, minIncl bool, max float64, maxIncl bool) ssElems { + gt := func(a, b float64) bool { return a > b } + gteq := func(a, b float64) bool { return a >= b } + + mincmp := gt + if minIncl { + mincmp = gteq + } + for i, m := range members { + if mincmp(m.score, min) { + members = members[i:] + goto checkmax + } + } + // all elements were smaller + return nil + +checkmax: + maxcmp := gteq + if maxIncl { + maxcmp = gt + } + for i, m := range members { + if maxcmp(m.score, max) { + members = members[:i] + break + } + } + + return members +} + +// withLexRange limits a list of sorted set elements. +func withLexRange(members []string, min string, minIncl bool, max string, maxIncl bool) []string { + if max == "-" || min == "+" { + return nil + } + if min != "-" { + if minIncl { + for i, m := range members { + if m >= min { + members = members[i:] + break + } + } + } else { + // Excluding min + for i, m := range members { + if m > min { + members = members[i:] + break + } + } + } + } + if max != "+" { + if maxIncl { + for i, m := range members { + if m > max { + members = members[:i] + break + } + } + } else { + // Excluding max + for i, m := range members { + if m >= max { + members = members[:i] + break + } + } + } + } + return members +} + +// ZUNIONSTORE +func (m *Miniredis) cmdZunionstore(c *server.Peer, cmd string, args []string) { + if len(args) < 3 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + destination := args[0] + numKeys, err := strconv.Atoi(args[1]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + args = args[2:] + if len(args) < numKeys { + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + if numKeys <= 0 { + setDirty(c) + c.WriteError("ERR at least 1 input key is needed for ZUNIONSTORE/ZINTERSTORE") + return + } + keys := args[:numKeys] + args = args[numKeys:] + + withWeights := false + weights := []float64{} + aggregate := "sum" + for len(args) > 0 { + if strings.ToLower(args[0]) == "weights" { + if len(args) < numKeys+1 { + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + for i := 0; i < numKeys; i++ { + f, err := strconv.ParseFloat(args[i+1], 64) + if err != nil { + setDirty(c) + c.WriteError("ERR weight value is not a float") + return + } + weights = append(weights, f) + } + withWeights = true + args = args[numKeys+1:] + continue + } + if strings.ToLower(args[0]) == "aggregate" { + if len(args) < 2 { + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + aggregate = strings.ToLower(args[1]) + switch aggregate { + default: + setDirty(c) + c.WriteError(msgSyntaxError) + return + case "sum", "min", "max": + } + args = args[2:] + continue + } + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + db.del(destination, true) + + sset := sortedSet{} + for i, key := range keys { + if !db.exists(key) { + continue + } + if db.t(key) != "zset" { + c.WriteError(msgWrongType) + return + } + for _, el := range db.ssetElements(key) { + score := el.score + if withWeights { + score *= weights[i] + } + old, ok := sset[el.member] + if !ok { + sset[el.member] = score + continue + } + switch aggregate { + default: + panic("Invalid aggregate") + case "sum": + sset[el.member] += score + case "min": + if score < old { + sset[el.member] = score + } + case "max": + if score > old { + sset[el.member] = score + } + } + } + } + db.ssetSet(destination, sset) + c.WriteInt(sset.card()) + }) +} + +// ZSCAN +func (m *Miniredis) cmdZscan(c *server.Peer, cmd string, args []string) { + if len(args) < 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + cursor, err := strconv.Atoi(args[1]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidCursor) + return + } + args = args[2:] + // MATCH and COUNT options + var withMatch bool + var match string + for len(args) > 0 { + if strings.ToLower(args[0]) == "count" { + if len(args) < 2 { + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + _, err := strconv.Atoi(args[1]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + // We do nothing with count. + args = args[2:] + continue + } + if strings.ToLower(args[0]) == "match" { + if len(args) < 2 { + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + withMatch = true + match = args[1] + args = args[2:] + continue + } + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + // We return _all_ (matched) keys every time. + + if cursor != 0 { + // Invalid cursor. + c.WriteLen(2) + c.WriteBulk("0") // no next cursor + c.WriteLen(0) // no elements + return + } + if db.exists(key) && db.t(key) != "zset" { + c.WriteError(ErrWrongType.Error()) + return + } + + members := db.ssetMembers(key) + if withMatch { + members = matchKeys(members, match) + } + + c.WriteLen(2) + c.WriteBulk("0") // no next cursor + // HSCAN gives key, values. + c.WriteLen(len(members) * 2) + for _, k := range members { + c.WriteBulk(k) + c.WriteBulk(formatFloat(db.ssetScore(key, k))) + } + }) +} diff --git a/vendor/github.com/alicebob/miniredis/cmd_sorted_set_test.go b/vendor/github.com/alicebob/miniredis/cmd_sorted_set_test.go new file mode 100644 index 00000000..8211d833 --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/cmd_sorted_set_test.go @@ -0,0 +1,1329 @@ +package miniredis + +import ( + "math" + "testing" + + "github.com/garyburd/redigo/redis" +) + +// Test ZADD / ZCARD / ZRANK / ZREVRANK. +func TestSortedSet(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + { + b, err := redis.Int(c.Do("ZADD", "z", 1, "one", 2, "two", 3, "three")) + ok(t, err) + equals(t, 3, b) // New elements. + + b, err = redis.Int(c.Do("ZCARD", "z")) + ok(t, err) + equals(t, 3, b) + + m, err := redis.Int(c.Do("ZRANK", "z", "one")) + ok(t, err) + equals(t, 0, m) + m, err = redis.Int(c.Do("ZRANK", "z", "three")) + ok(t, err) + equals(t, 2, m) + + m, err = redis.Int(c.Do("ZREVRANK", "z", "one")) + ok(t, err) + equals(t, 2, m) + m, err = redis.Int(c.Do("ZREVRANK", "z", "three")) + ok(t, err) + equals(t, 0, m) + } + + // TYPE of our zset + { + s, err := redis.String(c.Do("TYPE", "z")) + ok(t, err) + equals(t, "zset", s) + } + + // Replace a key + { + b, err := redis.Int(c.Do("ZADD", "z", 2.1, "two")) + ok(t, err) + equals(t, 0, b) // No new elements. + + b, err = redis.Int(c.Do("ZCARD", "z")) + ok(t, err) + equals(t, 3, b) + } + + // To infinity! + { + b, err := redis.Int(c.Do("ZADD", "zinf", "inf", "plus inf", "-inf", "minus inf", 10, "ten")) + ok(t, err) + equals(t, 3, b) + + b, err = redis.Int(c.Do("ZCARD", "zinf")) + ok(t, err) + equals(t, 3, b) + + smap, err := s.SortedSet("zinf") + ok(t, err) + equals(t, map[string]float64{ + "plus inf": math.Inf(+1), + "minus inf": math.Inf(-1), + "ten": 10.0, + }, smap) + } + + // Invalid score + { + _, err := c.Do("ZADD", "z", "noint", "two") + assert(t, err != nil, "ZADD err") + } + + // ZRANK on non-existing key/member + { + m, err := c.Do("ZRANK", "z", "nosuch") + ok(t, err) + equals(t, nil, m) + + m, err = c.Do("ZRANK", "nosuch", "nosuch") + ok(t, err) + equals(t, nil, m) + } + + // Direct usage + { + added, err := s.ZAdd("s1", 12.4, "aap") + ok(t, err) + equals(t, true, added) + added, err = s.ZAdd("s1", 3.4, "noot") + ok(t, err) + equals(t, true, added) + added, err = s.ZAdd("s1", 3.5, "noot") + ok(t, err) + equals(t, false, added) + + members, err := s.ZMembers("s1") + ok(t, err) + equals(t, []string{"noot", "aap"}, members) + } + + // Error cases + { + // Wrong type of key + _, err := redis.String(c.Do("SET", "str", "value")) + ok(t, err) + + _, err = redis.Int(c.Do("ZRANK", "str")) + assert(t, err != nil, "ZRANK error") + _, err = redis.String(c.Do("ZRANK")) + assert(t, err != nil, "ZRANK error") + _, err = redis.String(c.Do("ZRANK", "set", "spurious")) + assert(t, err != nil, "ZRANK error") + + _, err = redis.String(c.Do("ZDEVRANK")) + assert(t, err != nil, "ZDEVRANK error") + + _, err = redis.Int(c.Do("ZCARD", "str")) + assert(t, err != nil, "ZCARD error") + _, err = redis.String(c.Do("ZCARD")) + assert(t, err != nil, "ZCARD error") + _, err = redis.String(c.Do("ZCARD", "set", "spurious")) + assert(t, err != nil, "ZCARD error") + } +} + +// Test ZADD +func TestSortedSetAdd(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + { + b, err := redis.Int(c.Do("ZADD", "z", 1, "one", 2, "two", 3, "three")) + ok(t, err) + equals(t, 3, b) // New elements. + + b, err = redis.Int(c.Do("ZADD", "z", 1, "one", 2.1, "two", 3, "three")) + ok(t, err) + equals(t, 0, b) // no new elements + + b, err = redis.Int(c.Do("ZADD", "z", "CH", 1, "one", 2.2, "two", 3, "three")) + ok(t, err) + equals(t, 1, b) + + b, err = redis.Int(c.Do("ZADD", "z", "NX", 1, "one", 2.2, "two", 3, "three")) + ok(t, err) + equals(t, 0, b) + + b, err = redis.Int(c.Do("ZADD", "z", "NX", 1, "one", 4, "four")) + ok(t, err) + equals(t, 1, b) + + b, err = redis.Int(c.Do("ZADD", "z", "XX", 1.1, "one", 4, "four")) + ok(t, err) + equals(t, 0, b) + + b, err = redis.Int(c.Do("ZADD", "z", "XX", "CH", 1.2, "one", 4, "four")) + ok(t, err) + equals(t, 1, b) + + } + + // Error cases + { + // Wrong type of key + _, err := redis.String(c.Do("SET", "str", "value")) + ok(t, err) + + _, err = redis.Int(c.Do("ZADD", "str", 1.0, "hi")) + assert(t, err != nil, "ZADD error") + _, err = redis.String(c.Do("ZADD")) + assert(t, err != nil, "ZADD error") + _, err = redis.String(c.Do("ZADD", "set")) + assert(t, err != nil, "ZADD error") + _, err = redis.String(c.Do("ZADD", "set", 1.0)) + assert(t, err != nil, "ZADD error") + _, err = redis.String(c.Do("ZADD", "set", 1.0, "foo", 1.0)) // odd + assert(t, err != nil, "ZADD error") + _, err = redis.String(c.Do("ZADD", "set", "MX", 1.0)) + assert(t, err != nil, "ZADD error") + _, err = redis.String(c.Do("ZADD", "set", "MX", "XX", 1.0, "foo")) + assert(t, err != nil, "ZADD error") + } +} + +// Test ZRANGE and ZREVRANGE +func TestSortedSetRange(t *testing.T) { + // ZREVRANGE is the same code as ZRANGE + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.ZAdd("z", 1, "one") + s.ZAdd("z", 2, "two") + s.ZAdd("z", 2, "zwei") + s.ZAdd("z", 3, "three") + s.ZAdd("z", 3, "drei") + s.ZAdd("z", math.Inf(+1), "inf") + + { + b, err := redis.Strings(c.Do("ZRANGE", "z", 0, -1)) + ok(t, err) + equals(t, []string{"one", "two", "zwei", "drei", "three", "inf"}, b) + + b, err = redis.Strings(c.Do("ZREVRANGE", "z", 0, -1)) + ok(t, err) + equals(t, []string{"inf", "three", "drei", "zwei", "two", "one"}, b) + } + { + b, err := redis.Strings(c.Do("ZRANGE", "z", 0, 1)) + ok(t, err) + equals(t, []string{"one", "two"}, b) + + b, err = redis.Strings(c.Do("ZREVRANGE", "z", 0, 1)) + ok(t, err) + equals(t, []string{"inf", "three"}, b) + } + { + b, err := redis.Strings(c.Do("ZRANGE", "z", -1, -1)) + ok(t, err) + equals(t, []string{"inf"}, b) + + b, err = redis.Strings(c.Do("ZREVRANGE", "z", -1, -1)) + ok(t, err) + equals(t, []string{"one"}, b) + } + + // weird cases. + { + b, err := redis.Strings(c.Do("ZRANGE", "z", -100, -100)) + ok(t, err) + equals(t, []string{}, b) + } + { + b, err := redis.Strings(c.Do("ZRANGE", "z", 100, 400)) + ok(t, err) + equals(t, []string{}, b) + } + // Nonexistent key + { + b, err := redis.Strings(c.Do("ZRANGE", "nosuch", 1, 4)) + ok(t, err) + equals(t, []string{}, b) + } + + // With scores + { + b, err := redis.Strings(c.Do("ZRANGE", "z", 1, 2, "WITHSCORES")) + ok(t, err) + equals(t, []string{"two", "2", "zwei", "2"}, b) + + b, err = redis.Strings(c.Do("ZREVRANGE", "z", 1, 2, "WITHSCORES")) + ok(t, err) + equals(t, []string{"three", "3", "drei", "3"}, b) + } + // INF in WITHSCORES + { + b, err := redis.Strings(c.Do("ZRANGE", "z", 4, -1, "WITHSCORES")) + ok(t, err) + equals(t, []string{"three", "3", "inf", "inf"}, b) + } + + // Error cases + { + _, err = redis.String(c.Do("ZRANGE")) + assert(t, err != nil, "ZRANGE error") + _, err = redis.String(c.Do("ZREVRANGE")) + assert(t, err != nil, "ZREVRANGE error") + _, err = redis.String(c.Do("ZRANGE", "set")) + assert(t, err != nil, "ZRANGE error") + _, err = redis.String(c.Do("ZRANGE", "set", 1)) + assert(t, err != nil, "ZRANGE error") + _, err = redis.String(c.Do("ZRANGE", "set", "noint", 1)) + assert(t, err != nil, "ZRANGE error") + _, err = redis.String(c.Do("ZRANGE", "set", 1, "noint")) + assert(t, err != nil, "ZRANGE error") + _, err = redis.String(c.Do("ZRANGE", "set", 1, 2, "toomany")) + assert(t, err != nil, "ZRANGE error") + // Wrong type of key + s.Set("str", "value") + _, err = redis.Int(c.Do("ZRANGE", "str", 1, 2)) + assert(t, err != nil, "ZRANGE error") + } +} + +// Test ZRANGEBYSCORE, ZREVRANGEBYSCORE, and ZCOUNT +func TestSortedSetRangeByScore(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.ZAdd("z", -273.15, "zero kelvin") + s.ZAdd("z", -4, "minusfour") + s.ZAdd("z", 1, "one") + s.ZAdd("z", 2, "two") + s.ZAdd("z", 2, "zwei") + s.ZAdd("z", 3, "three") + s.ZAdd("z", 3, "drei") + s.ZAdd("z", math.Inf(+1), "inf") + + // Normal cases + { + b, err := redis.Strings(c.Do("ZRANGEBYSCORE", "z", "-inf", "inf")) + ok(t, err) + equals(t, []string{"zero kelvin", "minusfour", "one", "two", "zwei", "drei", "three", "inf"}, b) + + b, err = redis.Strings(c.Do("ZREVRANGEBYSCORE", "z", "inf", "-inf")) + ok(t, err) + equals(t, []string{"inf", "three", "drei", "zwei", "two", "one", "minusfour", "zero kelvin"}, b) + + i, err := redis.Int(c.Do("ZCOUNT", "z", "-inf", "inf")) + ok(t, err) + equals(t, 8, i) + } + { + b, err := redis.Strings(c.Do("ZRANGEBYSCORE", "z", "2", "3")) + ok(t, err) + equals(t, []string{"two", "zwei", "drei", "three"}, b) + + b, err = redis.Strings(c.Do("ZRANGEBYSCORE", "z", "4", "4")) + ok(t, err) + equals(t, []string{}, b) + + b, err = redis.Strings(c.Do("ZREVRANGEBYSCORE", "z", "3", "2")) + ok(t, err) + equals(t, []string{"three", "drei", "zwei", "two"}, b) + + b, err = redis.Strings(c.Do("ZREVRANGEBYSCORE", "z", "4", "4")) + ok(t, err) + equals(t, []string{}, b) + + i, err := redis.Int(c.Do("ZCOUNT", "z", "2", "3")) + ok(t, err) + equals(t, 4, i) + } + // Exclusive min + { + b, err := redis.Strings(c.Do("ZRANGEBYSCORE", "z", "(2", "3")) + ok(t, err) + equals(t, []string{"drei", "three"}, b) + + i, err := redis.Int(c.Do("ZCOUNT", "z", "(2", "3")) + ok(t, err) + equals(t, 2, i) + } + // Exclusive max + { + b, err := redis.Strings(c.Do("ZRANGEBYSCORE", "z", "2", "(3")) + ok(t, err) + equals(t, []string{"two", "zwei"}, b) + } + // Exclusive both + { + b, err := redis.Strings(c.Do("ZRANGEBYSCORE", "z", "(2", "(3")) + ok(t, err) + equals(t, []string{}, b) + } + // Wrong ranges + { + b, err := redis.Strings(c.Do("ZRANGEBYSCORE", "z", "+inf", "-inf")) + ok(t, err) + equals(t, []string{}, b) + + b, err = redis.Strings(c.Do("ZREVRANGEBYSCORE", "z", "-inf", "+inf")) + ok(t, err) + equals(t, []string{}, b) + } + + // No such key + { + b, err := redis.Strings(c.Do("ZRANGEBYSCORE", "nosuch", "-inf", "inf")) + ok(t, err) + equals(t, []string{}, b) + } + + // With scores + { + b, err := redis.Strings(c.Do("ZRANGEBYSCORE", "z", "(1", 2, "WITHSCORES")) + ok(t, err) + equals(t, []string{"two", "2", "zwei", "2"}, b) + } + + // With LIMIT + // (note, this is SQL like logic, not the redis RANGE logic) + { + b, err := redis.Strings(c.Do("ZRANGEBYSCORE", "z", "-inf", "inf", "LIMIT", 1, 2)) + ok(t, err) + equals(t, []string{"minusfour", "one"}, b) + + b, err = redis.Strings(c.Do("ZREVRANGEBYSCORE", "z", "inf", "-inf", "LIMIT", 1, 2)) + ok(t, err) + equals(t, []string{"three", "drei"}, b) + + b, err = redis.Strings(c.Do("ZRANGEBYSCORE", "z", "1", "inf", "LIMIT", 1, 2000)) + ok(t, err) + equals(t, []string{"two", "zwei", "drei", "three", "inf"}, b) + + b, err = redis.Strings(c.Do("ZREVRANGEBYSCORE", "z", "inf", "1", "LIMIT", 1, 2000)) + ok(t, err) + equals(t, []string{"three", "drei", "zwei", "two", "one"}, b) + + // Negative start limit. No go. + b, err = redis.Strings(c.Do("ZRANGEBYSCORE", "z", "-inf", "inf", "LIMIT", -1, 2)) + ok(t, err) + equals(t, []string{}, b) + + // Negative end limit. Is fine but ignored. + b, err = redis.Strings(c.Do("ZRANGEBYSCORE", "z", "-inf", "inf", "LIMIT", 1, -2)) + ok(t, err) + equals(t, []string{"minusfour", "one", "two", "zwei", "drei", "three", "inf"}, b) + } + // Everything + { + b, err := redis.Strings(c.Do("ZRANGEBYSCORE", "z", "-inf", "inf", "WITHSCORES", "LIMIT", 1, 2)) + ok(t, err) + equals(t, []string{"minusfour", "-4", "one", "1"}, b) + + b, err = redis.Strings(c.Do("ZRANGEBYSCORE", "z", "-inf", "inf", "LIMIT", 1, 2, "WITHSCORES")) + ok(t, err) + equals(t, []string{"minusfour", "-4", "one", "1"}, b) + } + + // Error cases + { + _, err = redis.String(c.Do("ZRANGEBYSCORE")) + assert(t, err != nil, "ZRANGEBYSCORE error") + _, err = redis.String(c.Do("ZRANGEBYSCORE", "set")) + assert(t, err != nil, "ZRANGEBYSCORE error") + _, err = redis.String(c.Do("ZRANGEBYSCORE", "set", 1)) + assert(t, err != nil, "ZRANGEBYSCORE error") + _, err = redis.String(c.Do("ZRANGEBYSCORE", "set", "nofloat", 1)) + assert(t, err != nil, "ZRANGEBYSCORE error") + _, err = redis.String(c.Do("ZRANGEBYSCORE", "set", 1, "nofloat")) + assert(t, err != nil, "ZRANGEBYSCORE error") + _, err = redis.String(c.Do("ZRANGEBYSCORE", "set", 1, 2, "toomany")) + assert(t, err != nil, "ZRANGEBYSCORE error") + _, err = redis.String(c.Do("ZRANGEBYSCORE", "set", "[1", 2, "toomany")) + assert(t, err != nil, "ZRANGEBYSCORE error") + _, err = redis.String(c.Do("ZRANGEBYSCORE", "set", 1, "[2", "toomany")) + assert(t, err != nil, "ZRANGEBYSCORE error") + _, err = redis.String(c.Do("ZRANGEBYSCORE", "set", "[1", 2, "LIMIT", "noint", 1)) + assert(t, err != nil, "ZRANGEBYSCORE error") + _, err = redis.String(c.Do("ZRANGEBYSCORE", "set", "[1", 2, "LIMIT", 1, "noint")) + assert(t, err != nil, "ZRANGEBYSCORE error") + // Wrong type of key + s.Set("str", "value") + _, err = redis.Int(c.Do("ZRANGEBYSCORE", "str", 1, 2)) + assert(t, err != nil, "ZRANGEBYSCORE error") + + _, err = redis.String(c.Do("ZREVRANGEBYSCORE")) + assert(t, err != nil, "ZREVRANGEBYSCORE error") + + _, err = redis.String(c.Do("ZCOUNT")) + assert(t, err != nil, "ZCOUNT error") + } +} + +func TestIssue10(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.ZAdd("key", 3.3, "element") + + b, err := redis.Strings(c.Do("ZRANGEBYSCORE", "key", "3.3", "3.3")) + ok(t, err) + equals(t, []string{"element"}, b) + + b, err = redis.Strings(c.Do("ZRANGEBYSCORE", "key", "4.3", "4.3")) + ok(t, err) + equals(t, []string{}, b) +} + +// Test ZREM +func TestSortedSetRem(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.ZAdd("z", 1, "one") + s.ZAdd("z", 2, "two") + s.ZAdd("z", 2, "zwei") + + // Simple delete + { + b, err := redis.Int(c.Do("ZREM", "z", "two", "zwei", "nosuch")) + ok(t, err) + equals(t, 2, b) + assert(t, s.Exists("z"), "key is there") + } + // Delete the last member + { + b, err := redis.Int(c.Do("ZREM", "z", "one")) + ok(t, err) + equals(t, 1, b) + assert(t, !s.Exists("z"), "key is gone") + } + // Nonexistent key + { + b, err := redis.Int(c.Do("ZREM", "nosuch", "member")) + ok(t, err) + equals(t, 0, b) + } + + // Direct + { + s.ZAdd("z2", 1, "one") + s.ZAdd("z2", 2, "two") + s.ZAdd("z2", 2, "zwei") + gone, err := s.ZRem("z2", "two") + ok(t, err) + assert(t, gone, "member gone") + members, err := s.ZMembers("z2") + ok(t, err) + equals(t, []string{"one", "zwei"}, members) + } + + // Error cases + { + _, err = redis.String(c.Do("ZREM")) + assert(t, err != nil, "ZREM error") + _, err = redis.String(c.Do("ZREM", "set")) + assert(t, err != nil, "ZREM error") + // Wrong type of key + s.Set("str", "value") + _, err = redis.Int(c.Do("ZREM", "str", "aap")) + assert(t, err != nil, "ZREM error") + } +} + +// Test ZREMRANGEBYLEX +func TestSortedSetRemRangeByLex(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.ZAdd("z", 12, "zero kelvin") + s.ZAdd("z", 12, "minusfour") + s.ZAdd("z", 12, "one") + s.ZAdd("z", 12, "oneone") + s.ZAdd("z", 12, "two") + s.ZAdd("z", 12, "zwei") + s.ZAdd("z", 12, "three") + s.ZAdd("z", 12, "drei") + s.ZAdd("z", 12, "inf") + + // Inclusive range + { + b, err := redis.Int(c.Do("ZREMRANGEBYLEX", "z", "[o", "[three")) + ok(t, err) + equals(t, 3, b) + + members, err := s.ZMembers("z") + ok(t, err) + equals(t, + []string{"drei", "inf", "minusfour", "two", "zero kelvin", "zwei"}, + members, + ) + } + + // Wrong ranges + { + b, err := redis.Int(c.Do("ZREMRANGEBYLEX", "z", "+", "(z")) + ok(t, err) + equals(t, 0, b) + } + + // No such key + { + b, err := redis.Int(c.Do("ZREMRANGEBYLEX", "nosuch", "-", "+")) + ok(t, err) + equals(t, 0, b) + } + + // Error cases + { + _, err = c.Do("ZREMRANGEBYLEX") + assert(t, err != nil, "ZREMRANGEBYLEX error") + _, err = c.Do("ZREMRANGEBYLEX", "set") + assert(t, err != nil, "ZREMRANGEBYLEX error") + _, err = c.Do("ZREMRANGEBYLEX", "set", "1", "[a") + assert(t, err != nil, "ZREMRANGEBYLEX error") + _, err = c.Do("ZREMRANGEBYLEX", "set", "[a", "1") + assert(t, err != nil, "ZREMRANGEBYLEX error") + _, err = c.Do("ZREMRANGEBYLEX", "set", "[a", "!a") + assert(t, err != nil, "ZREMRANGEBYLEX error") + _, err = c.Do("ZREMRANGEBYLEX", "set", "-", "+", "toomany") + assert(t, err != nil, "ZREMRANGEBYLEX error") + // Wrong type of key + s.Set("str", "value") + _, err = c.Do("ZREMRANGEBYLEX", "str", "-", "+") + assert(t, err != nil, "ZREMRANGEBYLEX error") + } +} + +// Test ZREMRANGEBYRANK +func TestSortedSetRemRangeByRank(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.ZAdd("z", 1, "one") + s.ZAdd("z", 2, "two") + s.ZAdd("z", 2, "zwei") + s.ZAdd("z", 3, "three") + s.ZAdd("z", 3, "drei") + s.ZAdd("z", math.Inf(+1), "inf") + + { + n, err := redis.Int(c.Do("ZREMRANGEBYRANK", "z", -2, -1)) + ok(t, err) + equals(t, 2, n) + + b, err := redis.Strings(c.Do("ZRANGE", "z", 0, -1)) + ok(t, err) + equals(t, []string{"one", "two", "zwei", "drei"}, b) + } + + // weird cases. + { + n, err := redis.Int(c.Do("ZREMRANGEBYRANK", "z", -100, -100)) + ok(t, err) + equals(t, 0, n) + } + { + n, err := redis.Int(c.Do("ZREMRANGEBYRANK", "z", 100, 400)) + ok(t, err) + equals(t, 0, n) + } + // Nonexistent key + { + n, err := redis.Int(c.Do("ZREMRANGEBYRANK", "nosuch", 1, 4)) + ok(t, err) + equals(t, 0, n) + } + + // Delete all. Key should be gone. + { + n, err := redis.Int(c.Do("ZREMRANGEBYRANK", "z", 0, -1)) + ok(t, err) + equals(t, 4, n) + equals(t, false, s.Exists("z")) + } + + // Error cases + { + _, err = redis.String(c.Do("ZREMRANGEBYRANK")) + assert(t, err != nil, "ZREMRANGEBYRANK error") + _, err = redis.String(c.Do("ZREMRANGEBYRANK", "set")) + assert(t, err != nil, "ZREMRANGEBYRANK error") + _, err = redis.String(c.Do("ZREMRANGEBYRANK", "set", 1)) + assert(t, err != nil, "ZREMRANGEBYRANK error") + _, err = redis.String(c.Do("ZREMRANGEBYRANK", "set", "noint", 1)) + assert(t, err != nil, "ZREMRANGEBYRANK error") + _, err = redis.String(c.Do("ZREMRANGEBYRANK", "set", 1, "noint")) + assert(t, err != nil, "ZREMRANGEBYRANK error") + _, err = redis.String(c.Do("ZREMRANGEBYRANK", "set", 1, 2, "toomany")) + assert(t, err != nil, "ZREMRANGEBYRANK error") + // Wrong type of key + s.Set("str", "value") + _, err = redis.Int(c.Do("ZREMRANGEBYRANK", "str", 1, 2)) + assert(t, err != nil, "ZREMRANGEBYRANK error") + } +} + +// Test ZREMRANGEBYSCORE +func TestSortedSetRangeRemByScore(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.ZAdd("z", -273.15, "zero kelvin") + s.ZAdd("z", -4, "minusfour") + s.ZAdd("z", 1, "one") + s.ZAdd("z", 2, "two") + s.ZAdd("z", 2, "zwei") + s.ZAdd("z", 3, "three") + s.ZAdd("z", 3, "drei") + s.ZAdd("z", math.Inf(+1), "inf") + + // Normal cases + { + n, err := redis.Int(c.Do("ZREMRANGEBYSCORE", "z", "-inf", 1)) + ok(t, err) + equals(t, 3, n) + + b, err := redis.Strings(c.Do("ZRANGE", "z", 0, -1)) + ok(t, err) + equals(t, []string{"two", "zwei", "drei", "three", "inf"}, b) + } + // Exclusive min + { + n, err := redis.Int(c.Do("ZREMRANGEBYSCORE", "z", "(2", "(4")) + ok(t, err) + equals(t, 2, n) + + b, err := redis.Strings(c.Do("ZRANGE", "z", 0, -1)) + ok(t, err) + equals(t, []string{"two", "zwei", "inf"}, b) + } + + // Wrong ranges + { + n, err := redis.Int(c.Do("ZREMRANGEBYSCORE", "z", "+inf", "-inf")) + ok(t, err) + equals(t, 0, n) + } + + // No such key + { + n, err := redis.Int(c.Do("ZREMRANGEBYSCORE", "nosuch", "-inf", "inf")) + ok(t, err) + equals(t, 0, n) + } + + // Error cases + { + _, err = redis.String(c.Do("ZREMRANGEBYSCORE")) + assert(t, err != nil, "ZREMRANGEBYSCORE error") + _, err = redis.String(c.Do("ZREMRANGEBYSCORE", "set")) + assert(t, err != nil, "ZREMRANGEBYSCORE error") + _, err = redis.String(c.Do("ZREMRANGEBYSCORE", "set", 1)) + assert(t, err != nil, "ZREMRANGEBYSCORE error") + _, err = redis.String(c.Do("ZREMRANGEBYSCORE", "set", "nofloat", 1)) + assert(t, err != nil, "ZREMRANGEBYSCORE error") + _, err = redis.String(c.Do("ZREMRANGEBYSCORE", "set", 1, "nofloat")) + assert(t, err != nil, "ZREMRANGEBYSCORE error") + _, err = redis.String(c.Do("ZREMRANGEBYSCORE", "set", 1, 2, "toomany")) + assert(t, err != nil, "ZREMRANGEBYSCORE error") + // Wrong type of key + s.Set("str", "value") + _, err = redis.Int(c.Do("ZREMRANGEBYSCORE", "str", 1, 2)) + assert(t, err != nil, "ZREMRANGEBYSCORE error") + } +} + +// Test ZSCORE +func TestSortedSetScore(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.ZAdd("z", 1, "one") + s.ZAdd("z", 2, "two") + s.ZAdd("z", 2, "zwei") + + // Simple case + { + b, err := redis.Float64(c.Do("ZSCORE", "z", "two")) + ok(t, err) + equals(t, 2.0, b) + } + // no such member + { + b, err := c.Do("ZSCORE", "z", "nosuch") + ok(t, err) + equals(t, nil, b) + } + // no such key + { + b, err := c.Do("ZSCORE", "nosuch", "nosuch") + ok(t, err) + equals(t, nil, b) + } + + // Direct + { + s.ZAdd("z2", 1, "one") + s.ZAdd("z2", 2, "two") + score, err := s.ZScore("z2", "two") + ok(t, err) + equals(t, 2.0, score) + } + + // Error cases + { + _, err = redis.String(c.Do("ZSCORE")) + assert(t, err != nil, "ZSCORE error") + _, err = redis.String(c.Do("ZSCORE", "key")) + assert(t, err != nil, "ZSCORE error") + _, err = redis.String(c.Do("ZSCORE", "too", "many", "arguments")) + assert(t, err != nil, "ZSCORE error") + // Wrong type of key + s.Set("str", "value") + _, err = redis.Int(c.Do("ZSCORE", "str", "aap")) + assert(t, err != nil, "ZSCORE error") + } +} + +// Test ZRANGEBYLEX, ZLEXCOUNT +func TestSortedSetRangeByLex(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.ZAdd("z", 12, "zero kelvin") + s.ZAdd("z", 12, "minusfour") + s.ZAdd("z", 12, "one") + s.ZAdd("z", 12, "oneone") + s.ZAdd("z", 12, "two") + s.ZAdd("z", 12, "zwei") + s.ZAdd("z", 12, "three") + s.ZAdd("z", 12, "drei") + s.ZAdd("z", 12, "inf") + + // Normal cases + { + b, err := redis.Strings(c.Do("ZRANGEBYLEX", "z", "-", "+")) + ok(t, err) + equals(t, []string{"drei", "inf", "minusfour", "one", "oneone", "three", "two", "zero kelvin", "zwei"}, b) + + i, err := redis.Int(c.Do("ZLEXCOUNT", "z", "-", "+")) + ok(t, err) + equals(t, 9, i) + } + // Inclusive range + { + b, err := redis.Strings(c.Do("ZRANGEBYLEX", "z", "[o", "[three")) + ok(t, err) + equals(t, []string{"one", "oneone", "three"}, b) + + i, err := redis.Int(c.Do("ZLEXCOUNT", "z", "[o", "[three")) + ok(t, err) + equals(t, 3, i) + } + // Inclusive range + { + b, err := redis.Strings(c.Do("ZRANGEBYLEX", "z", "(o", "(z")) + ok(t, err) + equals(t, []string{"one", "oneone", "three", "two"}, b) + + i, err := redis.Int(c.Do("ZLEXCOUNT", "z", "(o", "(z")) + ok(t, err) + equals(t, 4, i) + } + // Wrong ranges + { + b, err := redis.Strings(c.Do("ZRANGEBYLEX", "z", "+", "(z")) + ok(t, err) + equals(t, []string{}, b) + + b, err = redis.Strings(c.Do("ZRANGEBYLEX", "z", "(a", "-")) + ok(t, err) + equals(t, []string{}, b) + + b, err = redis.Strings(c.Do("ZRANGEBYLEX", "z", "(z", "(a")) + ok(t, err) + equals(t, []string{}, b) + + i, err := redis.Int(c.Do("ZLEXCOUNT", "z", "(z", "(z")) + ok(t, err) + equals(t, 0, i) + } + + // No such key + { + b, err := redis.Strings(c.Do("ZRANGEBYLEX", "nosuch", "-", "+")) + ok(t, err) + equals(t, []string{}, b) + + i, err := redis.Int(c.Do("ZLEXCOUNT", "nosuch", "-", "+")) + ok(t, err) + equals(t, 0, i) + } + + // With LIMIT + // (note, this is SQL like logic, not the redis RANGE logic) + { + b, err := redis.Strings(c.Do("ZRANGEBYLEX", "z", "-", "+", "LIMIT", 1, 2)) + ok(t, err) + equals(t, []string{"inf", "minusfour"}, b) + + // Negative start limit. No go. + b, err = redis.Strings(c.Do("ZRANGEBYLEX", "z", "-", "+", "LIMIT", -1, 2)) + ok(t, err) + equals(t, []string{}, b) + + // Negative end limit. Is fine but ignored. + b, err = redis.Strings(c.Do("ZRANGEBYLEX", "z", "-", "+", "LIMIT", 1, -2)) + ok(t, err) + equals(t, []string{"inf", "minusfour", "one", "oneone", "three", "two", "zero kelvin", "zwei"}, b) + } + + // Error cases + { + _, err = c.Do("ZRANGEBYLEX") + assert(t, err != nil, "ZRANGEBYLEX error") + _, err = c.Do("ZRANGEBYLEX", "set") + assert(t, err != nil, "ZRANGEBYLEX error") + _, err = c.Do("ZRANGEBYLEX", "set", "1", "[a") + assert(t, err != nil, "ZRANGEBYLEX error") + _, err = c.Do("ZRANGEBYLEX", "set", "[a", "1") + assert(t, err != nil, "ZRANGEBYLEX error") + _, err = c.Do("ZRANGEBYLEX", "set", "[a", "!a") + assert(t, err != nil, "ZRANGEBYLEX error") + _, err = c.Do("ZRANGEBYLEX", "set", "-", "+", "toomany") + assert(t, err != nil, "ZRANGEBYLEX error") + _, err = c.Do("ZRANGEBYLEX", "set", "[1", "(1", "LIMIT", "noint", 1) + assert(t, err != nil, "ZRANGEBYLEX error") + _, err = c.Do("ZRANGEBYLEX", "set", "[1", "(1", "LIMIT", 1, "noint") + assert(t, err != nil, "ZRANGEBYLEX error") + // Wrong type of key + s.Set("str", "value") + _, err = c.Do("ZRANGEBYLEX", "str", "-", "+") + assert(t, err != nil, "ZRANGEBYLEX error") + + _, err = c.Do("ZLEXCOUNT") + assert(t, err != nil, "ZLEXCOUNT error") + _, err = c.Do("ZLEXCOUNT", "k") + assert(t, err != nil, "ZLEXCOUNT error") + _, err = c.Do("ZLEXCOUNT", "k", "[a", "a") + assert(t, err != nil, "ZLEXCOUNT error") + _, err = c.Do("ZLEXCOUNT", "k", "a", "(a") + assert(t, err != nil, "ZLEXCOUNT error") + _, err = c.Do("ZLEXCOUNT", "k", "(a", "(a", "toomany") + assert(t, err != nil, "ZLEXCOUNT error") + } +} + +// Test ZINCRBY +func TestSortedSetIncrby(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + // Normal cases + { + // New key + b, err := redis.Float64(c.Do("ZINCRBY", "z", 1, "member")) + ok(t, err) + equals(t, 1.0, b) + + // Existing key + b, err = redis.Float64(c.Do("ZINCRBY", "z", 2.5, "member")) + ok(t, err) + equals(t, 3.5, b) + + // New member + b, err = redis.Float64(c.Do("ZINCRBY", "z", 1, "othermember")) + ok(t, err) + equals(t, 1.0, b) + } + + // Error cases + { + _, err = redis.String(c.Do("ZINCRBY")) + assert(t, err != nil, "ZINCRBY error") + _, err = redis.String(c.Do("ZINCRBY", "set")) + assert(t, err != nil, "ZINCRBY error") + _, err = redis.String(c.Do("ZINCRBY", "set", "nofloat", "a")) + assert(t, err != nil, "ZINCRBY error") + _, err = redis.String(c.Do("ZINCRBY", "set", 1.0, "too", "many")) + assert(t, err != nil, "ZINCRBY error") + // Wrong type of key + s.Set("str", "value") + _, err = c.Do("ZINCRBY", "str", 1.0, "member") + assert(t, err != nil, "ZINCRBY error") + } +} + +func TestZscan(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + // We cheat with zscan. It always returns everything. + + s.ZAdd("h", 1.0, "field1") + s.ZAdd("h", 2.0, "field2") + + // No problem + { + res, err := redis.Values(c.Do("ZSCAN", "h", 0)) + ok(t, err) + equals(t, 2, len(res)) + + var c int + var keys []string + _, err = redis.Scan(res, &c, &keys) + ok(t, err) + equals(t, 0, c) + equals(t, []string{"field1", "1", "field2", "2"}, keys) + } + + // Invalid cursor + { + res, err := redis.Values(c.Do("ZSCAN", "h", 42)) + ok(t, err) + equals(t, 2, len(res)) + + var c int + var keys []string + _, err = redis.Scan(res, &c, &keys) + ok(t, err) + equals(t, 0, c) + equals(t, []string(nil), keys) + } + + // COUNT (ignored) + { + res, err := redis.Values(c.Do("ZSCAN", "h", 0, "COUNT", 200)) + ok(t, err) + equals(t, 2, len(res)) + + var c int + var keys []string + _, err = redis.Scan(res, &c, &keys) + ok(t, err) + equals(t, 0, c) + equals(t, []string{"field1", "1", "field2", "2"}, keys) + } + + // MATCH + { + s.ZAdd("h", 3.0, "aap") + s.ZAdd("h", 4.0, "noot") + s.ZAdd("h", 5.0, "mies") + res, err := redis.Values(c.Do("ZSCAN", "h", 0, "MATCH", "mi*")) + ok(t, err) + equals(t, 2, len(res)) + + var c int + var keys []string + _, err = redis.Scan(res, &c, &keys) + ok(t, err) + equals(t, 0, c) + equals(t, []string{"mies", "5"}, keys) + } + + // Wrong usage + { + _, err := redis.Int(c.Do("ZSCAN")) + assert(t, err != nil, "do ZSCAN error") + _, err = redis.Int(c.Do("ZSCAN", "set")) + assert(t, err != nil, "do ZSCAN error") + _, err = redis.Int(c.Do("ZSCAN", "set", "noint")) + assert(t, err != nil, "do ZSCAN error") + _, err = redis.Int(c.Do("ZSCAN", "set", 1, "MATCH")) + assert(t, err != nil, "do ZSCAN error") + _, err = redis.Int(c.Do("ZSCAN", "set", 1, "COUNT")) + assert(t, err != nil, "do ZSCAN error") + _, err = redis.Int(c.Do("ZSCAN", "set", 1, "COUNT", "noint")) + assert(t, err != nil, "do ZSCAN error") + s.Set("str", "value") + _, err = redis.Int(c.Do("ZSCAN", "str", 1)) + assert(t, err != nil, "do ZSCAN error") + } +} + +func TestZunionstore(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.ZAdd("h1", 1.0, "field1") + s.ZAdd("h1", 2.0, "field2") + s.ZAdd("h2", 1.0, "field1") + s.ZAdd("h2", 2.0, "field2") + + // Simple case + { + res, err := redis.Int(c.Do("ZUNIONSTORE", "new", 2, "h1", "h2")) + ok(t, err) + equals(t, 2, res) + + ss, err := s.SortedSet("new") + ok(t, err) + equals(t, map[string]float64{"field1": 2, "field2": 4}, ss) + } + + // WEIGHTS + { + res, err := redis.Int(c.Do("ZUNIONSTORE", "weighted", 2, "h1", "h2", "WeIgHtS", "4.5", "12")) + ok(t, err) + equals(t, 2, res) + + ss, err := s.SortedSet("weighted") + ok(t, err) + equals(t, map[string]float64{"field1": 16.5, "field2": 33}, ss) + } + + // AGGREGATE + { + res, err := redis.Int(c.Do("ZUNIONSTORE", "aggr", 2, "h1", "h2", "AgGrEgAtE", "min")) + ok(t, err) + equals(t, 2, res) + + ss, err := s.SortedSet("aggr") + ok(t, err) + equals(t, map[string]float64{"field1": 1.0, "field2": 2.0}, ss) + } + + // Wrong usage + { + _, err := redis.Int(c.Do("ZUNIONSTORE")) + assert(t, err != nil, "do ZUNIONSTORE error") + _, err = redis.Int(c.Do("ZUNIONSTORE", "set")) + assert(t, err != nil, "do ZUNIONSTORE error") + _, err = redis.Int(c.Do("ZUNIONSTORE", "set", "noint")) + assert(t, err != nil, "do ZUNIONSTORE error") + _, err = redis.Int(c.Do("ZUNIONSTORE", "set", 0, "key")) + assert(t, err != nil, "do ZUNIONSTORE error") + _, err = redis.Int(c.Do("ZUNIONSTORE", "set", -1, "key")) + assert(t, err != nil, "do ZUNIONSTORE error") + _, err = redis.Int(c.Do("ZUNIONSTORE", "set", 1, "too", "many")) + assert(t, err != nil, "do ZUNIONSTORE error") + _, err = redis.Int(c.Do("ZUNIONSTORE", "set", 2, "key")) + assert(t, err != nil, "do ZUNIONSTORE error") + + _, err = redis.Int(c.Do("ZUNIONSTORE", "set", 2, "k1", "k2", "WEIGHTS")) + assert(t, err != nil, "do ZUNIONSTORE error") + _, err = redis.Int(c.Do("ZUNIONSTORE", "set", 2, "k1", "k2", "WEIGHTS", 1, 2, 3)) + assert(t, err != nil, "do ZUNIONSTORE error") + _, err = redis.Int(c.Do("ZUNIONSTORE", "set", 2, "k1", "k2", "WEIGHTS", 1, "nof")) + assert(t, err != nil, "do ZUNIONSTORE error") + + _, err = redis.Int(c.Do("ZUNIONSTORE", "set", 2, "k1", "k2", "AGGREGATE")) + assert(t, err != nil, "do ZUNIONSTORE error") + _, err = redis.Int(c.Do("ZUNIONSTORE", "set", 2, "k1", "k2", "AGGREGATE", "foo")) + assert(t, err != nil, "do ZUNIONSTORE error") + _, err = redis.Int(c.Do("ZUNIONSTORE", "set", 2, "k1", "k2", "AGGREGATE", "sum", "foo")) + assert(t, err != nil, "do ZUNIONSTORE error") + + s.Set("str", "value") + _, err = redis.Int(c.Do("ZUNIONSTORE", "set", 1, "str")) + assert(t, err != nil, "do ZUNIONSTORE error") + } +} + +func TestZinterstore(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.ZAdd("h1", 1.0, "field1") + s.ZAdd("h1", 2.0, "field2") + s.ZAdd("h1", 3.0, "field3") + s.ZAdd("h2", 1.0, "field1") + s.ZAdd("h2", 2.0, "field2") + s.ZAdd("h2", 4.0, "field4") + + // Simple case + { + res, err := redis.Int(c.Do("ZINTERSTORE", "new", 2, "h1", "h2")) + ok(t, err) + equals(t, 2, res) + + ss, err := s.SortedSet("new") + ok(t, err) + equals(t, map[string]float64{"field1": 2, "field2": 4}, ss) + } + + // WEIGHTS + { + res, err := redis.Int(c.Do("ZINTERSTORE", "weighted", 2, "h1", "h2", "WeIgHtS", "4.5", "12")) + ok(t, err) + equals(t, 2, res) + + ss, err := s.SortedSet("weighted") + ok(t, err) + equals(t, map[string]float64{"field1": 16.5, "field2": 33}, ss) + } + + // AGGREGATE + { + res, err := redis.Int(c.Do("ZINTERSTORE", "aggr", 2, "h1", "h2", "AgGrEgAtE", "min")) + ok(t, err) + equals(t, 2, res) + + ss, err := s.SortedSet("aggr") + ok(t, err) + equals(t, map[string]float64{"field1": 1.0, "field2": 2.0}, ss) + } + + // Wrong usage + { + _, err := redis.Int(c.Do("ZINTERSTORE")) + assert(t, err != nil, "do ZINTERSTORE error") + _, err = redis.Int(c.Do("ZINTERSTORE", "set")) + assert(t, err != nil, "do ZINTERSTORE error") + _, err = redis.Int(c.Do("ZINTERSTORE", "set", "noint")) + assert(t, err != nil, "do ZINTERSTORE error") + _, err = redis.Int(c.Do("ZINTERSTORE", "set", 0, "key")) + assert(t, err != nil, "do ZINTERSTORE error") + _, err = redis.Int(c.Do("ZINTERSTORE", "set", -1, "key")) + assert(t, err != nil, "do ZINTERSTORE error") + _, err = redis.Int(c.Do("ZINTERSTORE", "set", 1, "too", "many")) + assert(t, err != nil, "do ZINTERSTORE error") + _, err = redis.Int(c.Do("ZINTERSTORE", "set", 2, "key")) + assert(t, err != nil, "do ZINTERSTORE error") + + _, err = redis.Int(c.Do("ZINTERSTORE", "set", 2, "k1", "k2", "WEIGHTS")) + assert(t, err != nil, "do ZINTERSTORE error") + _, err = redis.Int(c.Do("ZINTERSTORE", "set", 2, "k1", "k2", "WEIGHTS", 1, 2, 3)) + assert(t, err != nil, "do ZINTERSTORE error") + _, err = redis.Int(c.Do("ZINTERSTORE", "set", 2, "k1", "k2", "WEIGHTS", 1, "nof")) + assert(t, err != nil, "do ZINTERSTORE error") + + _, err = redis.Int(c.Do("ZINTERSTORE", "set", 2, "k1", "k2", "AGGREGATE")) + assert(t, err != nil, "do ZINTERSTORE error") + _, err = redis.Int(c.Do("ZINTERSTORE", "set", 2, "k1", "k2", "AGGREGATE", "foo")) + assert(t, err != nil, "do ZINTERSTORE error") + _, err = redis.Int(c.Do("ZINTERSTORE", "set", 2, "k1", "k2", "AGGREGATE", "sum", "foo")) + assert(t, err != nil, "do ZINTERSTORE error") + + s.Set("str", "value") + _, err = redis.Int(c.Do("ZINTERSTORE", "set", 1, "str")) + assert(t, err != nil, "do ZINTERSTORE error") + _, err = redis.Int(c.Do("ZINTERSTORE", "set", 2, "set", "str")) + assert(t, err != nil, "do ZINTERSTORE error") + } +} + +func TestSSRange(t *testing.T) { + ss := newSortedSet() + ss.set(1.0, "key1") + ss.set(5.0, "key5") + elems := ss.byScore(asc) + type cas struct { + min, max float64 + minInc, maxInc bool + want []string + } + for _, c := range []cas{ + { + min: 2.0, + minInc: true, + max: 3.0, + maxInc: true, + want: []string(nil), + }, + { + min: -2.0, + minInc: true, + max: -3.0, + maxInc: true, + want: []string(nil), + }, + { + min: 12.0, + minInc: true, + max: 13.0, + maxInc: true, + want: []string(nil), + }, + { + min: 1.0, + minInc: false, + max: 3.0, + maxInc: true, + want: []string(nil), + }, + { + min: 2.0, + minInc: true, + max: 5.0, + maxInc: false, + want: []string(nil), + }, + { + min: 0.0, + max: 2.0, + want: []string{"key1"}, + }, + { + min: 2.0, + max: 7.0, + want: []string{"key5"}, + }, + { + min: 0.0, + max: 7.0, + want: []string{"key1", "key5"}, + }, + { + min: 1.0, + minInc: false, + max: 5.0, + maxInc: false, + want: []string(nil), + }, + } { + var have []string + for _, v := range withSSRange(elems, c.min, c.minInc, c.max, c.maxInc) { + have = append(have, v.member) + } + equals(t, have, c.want) + } +} diff --git a/vendor/github.com/alicebob/miniredis/cmd_string.go b/vendor/github.com/alicebob/miniredis/cmd_string.go new file mode 100644 index 00000000..732cfb7f --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/cmd_string.go @@ -0,0 +1,1076 @@ +// Commands from http://redis.io/commands#string + +package miniredis + +import ( + "strconv" + "strings" + "time" + + "github.com/alicebob/miniredis/server" +) + +// commandsString handles all string value operations. +func commandsString(m *Miniredis) { + m.srv.Register("APPEND", m.cmdAppend) + m.srv.Register("BITCOUNT", m.cmdBitcount) + m.srv.Register("BITOP", m.cmdBitop) + m.srv.Register("BITPOS", m.cmdBitpos) + m.srv.Register("DECRBY", m.cmdDecrby) + m.srv.Register("DECR", m.cmdDecr) + m.srv.Register("GETBIT", m.cmdGetbit) + m.srv.Register("GET", m.cmdGet) + m.srv.Register("GETRANGE", m.cmdGetrange) + m.srv.Register("GETSET", m.cmdGetset) + m.srv.Register("INCRBYFLOAT", m.cmdIncrbyfloat) + m.srv.Register("INCRBY", m.cmdIncrby) + m.srv.Register("INCR", m.cmdIncr) + m.srv.Register("MGET", m.cmdMget) + m.srv.Register("MSET", m.cmdMset) + m.srv.Register("MSETNX", m.cmdMsetnx) + m.srv.Register("PSETEX", m.cmdPsetex) + m.srv.Register("SETBIT", m.cmdSetbit) + m.srv.Register("SETEX", m.cmdSetex) + m.srv.Register("SET", m.cmdSet) + m.srv.Register("SETNX", m.cmdSetnx) + m.srv.Register("SETRANGE", m.cmdSetrange) + m.srv.Register("STRLEN", m.cmdStrlen) +} + +// SET +func (m *Miniredis) cmdSet(c *server.Peer, cmd string, args []string) { + if len(args) < 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + var ( + nx = false // set iff not exists + xx = false // set iff exists + ttl time.Duration + ) + + key, value, args := args[0], args[1], args[2:] + for len(args) > 0 { + timeUnit := time.Second + switch strings.ToUpper(args[0]) { + case "NX": + nx = true + args = args[1:] + continue + case "XX": + xx = true + args = args[1:] + continue + case "PX": + timeUnit = time.Millisecond + fallthrough + case "EX": + if len(args) < 2 { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + expire, err := strconv.Atoi(args[1]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + ttl = time.Duration(expire) * timeUnit + if ttl <= 0 { + setDirty(c) + c.WriteError(msgInvalidSETime) + return + } + + args = args[2:] + continue + default: + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if nx { + if db.exists(key) { + c.WriteNull() + return + } + } + if xx { + if !db.exists(key) { + c.WriteNull() + return + } + } + + db.del(key, true) // be sure to remove existing values of other type keys. + // a vanilla SET clears the expire + db.stringSet(key, value) + if ttl != 0 { + db.ttl[key] = ttl + } + c.WriteOK() + }) +} + +// SETEX +func (m *Miniredis) cmdSetex(c *server.Peer, cmd string, args []string) { + if len(args) != 3 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + ttl, err := strconv.Atoi(args[1]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + if ttl <= 0 { + setDirty(c) + c.WriteError(msgInvalidSETEXTime) + return + } + value := args[2] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + db.del(key, true) // Clear any existing keys. + db.stringSet(key, value) + db.ttl[key] = time.Duration(ttl) * time.Second + c.WriteOK() + }) +} + +// PSETEX +func (m *Miniredis) cmdPsetex(c *server.Peer, cmd string, args []string) { + if len(args) != 3 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + ttl, err := strconv.Atoi(args[1]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + if ttl <= 0 { + setDirty(c) + c.WriteError(msgInvalidPSETEXTime) + return + } + value := args[2] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + db.del(key, true) // Clear any existing keys. + db.stringSet(key, value) + db.ttl[key] = time.Duration(ttl) * time.Millisecond + c.WriteOK() + }) +} + +// SETNX +func (m *Miniredis) cmdSetnx(c *server.Peer, cmd string, args []string) { + if len(args) != 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key, value := args[0], args[1] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if _, ok := db.keys[key]; ok { + c.WriteInt(0) + return + } + + db.stringSet(key, value) + c.WriteInt(1) + }) +} + +// MSET +func (m *Miniredis) cmdMset(c *server.Peer, cmd string, args []string) { + if len(args) < 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + if len(args)%2 != 0 { + setDirty(c) + // non-default error message + c.WriteError("ERR wrong number of arguments for MSET") + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + for len(args) > 0 { + key, value := args[0], args[1] + args = args[2:] + + db.del(key, true) // clear TTL + db.stringSet(key, value) + } + c.WriteOK() + }) +} + +// MSETNX +func (m *Miniredis) cmdMsetnx(c *server.Peer, cmd string, args []string) { + if len(args) < 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + if len(args)%2 != 0 { + setDirty(c) + // non-default error message (yes, with 'MSET'). + c.WriteError("ERR wrong number of arguments for MSET") + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + keys := map[string]string{} + existing := false + for len(args) > 0 { + key := args[0] + value := args[1] + args = args[2:] + keys[key] = value + if _, ok := db.keys[key]; ok { + existing = true + } + } + + res := 0 + if !existing { + res = 1 + for k, v := range keys { + // Nothing to delete. That's the whole point. + db.stringSet(k, v) + } + } + c.WriteInt(res) + }) +} + +// GET +func (m *Miniredis) cmdGet(c *server.Peer, cmd string, args []string) { + if len(args) != 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(key) { + c.WriteNull() + return + } + if db.t(key) != "string" { + c.WriteError(msgWrongType) + return + } + + c.WriteBulk(db.stringGet(key)) + }) +} + +// GETSET +func (m *Miniredis) cmdGetset(c *server.Peer, cmd string, args []string) { + if len(args) != 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key, value := args[0], args[1] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if t, ok := db.keys[key]; ok && t != "string" { + c.WriteError(msgWrongType) + return + } + + old, ok := db.stringKeys[key] + db.stringSet(key, value) + // a GETSET clears the ttl + delete(db.ttl, key) + + if !ok { + c.WriteNull() + return + } + c.WriteBulk(old) + return + }) +} + +// MGET +func (m *Miniredis) cmdMget(c *server.Peer, cmd string, args []string) { + if len(args) < 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + c.WriteLen(len(args)) + for _, k := range args { + if t, ok := db.keys[k]; !ok || t != "string" { + c.WriteNull() + continue + } + v, ok := db.stringKeys[k] + if !ok { + // Should not happen, we just checked keys[] + c.WriteNull() + continue + } + c.WriteBulk(v) + } + }) +} + +// INCR +func (m *Miniredis) cmdIncr(c *server.Peer, cmd string, args []string) { + if len(args) != 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + key := args[0] + if t, ok := db.keys[key]; ok && t != "string" { + c.WriteError(msgWrongType) + return + } + v, err := db.stringIncr(key, +1) + if err != nil { + c.WriteError(err.Error()) + return + } + // Don't touch TTL + c.WriteInt(v) + }) +} + +// INCRBY +func (m *Miniredis) cmdIncrby(c *server.Peer, cmd string, args []string) { + if len(args) != 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + delta, err := strconv.Atoi(args[1]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if t, ok := db.keys[key]; ok && t != "string" { + c.WriteError(msgWrongType) + return + } + + v, err := db.stringIncr(key, delta) + if err != nil { + c.WriteError(err.Error()) + return + } + // Don't touch TTL + c.WriteInt(v) + }) +} + +// INCRBYFLOAT +func (m *Miniredis) cmdIncrbyfloat(c *server.Peer, cmd string, args []string) { + if len(args) != 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + delta, err := strconv.ParseFloat(args[1], 64) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidFloat) + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if t, ok := db.keys[key]; ok && t != "string" { + c.WriteError(msgWrongType) + return + } + + v, err := db.stringIncrfloat(key, delta) + if err != nil { + c.WriteError(err.Error()) + return + } + // Don't touch TTL + c.WriteBulk(formatFloat(v)) + }) +} + +// DECR +func (m *Miniredis) cmdDecr(c *server.Peer, cmd string, args []string) { + if len(args) != 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + key := args[0] + if t, ok := db.keys[key]; ok && t != "string" { + c.WriteError(msgWrongType) + return + } + v, err := db.stringIncr(key, -1) + if err != nil { + c.WriteError(err.Error()) + return + } + // Don't touch TTL + c.WriteInt(v) + }) +} + +// DECRBY +func (m *Miniredis) cmdDecrby(c *server.Peer, cmd string, args []string) { + if len(args) != 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + delta, err := strconv.Atoi(args[1]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if t, ok := db.keys[key]; ok && t != "string" { + c.WriteError(msgWrongType) + return + } + + v, err := db.stringIncr(key, -delta) + if err != nil { + c.WriteError(err.Error()) + return + } + // Don't touch TTL + c.WriteInt(v) + }) +} + +// STRLEN +func (m *Miniredis) cmdStrlen(c *server.Peer, cmd string, args []string) { + if len(args) != 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if t, ok := db.keys[key]; ok && t != "string" { + c.WriteError(msgWrongType) + return + } + + c.WriteInt(len(db.stringKeys[key])) + }) +} + +// APPEND +func (m *Miniredis) cmdAppend(c *server.Peer, cmd string, args []string) { + if len(args) != 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key, value := args[0], args[1] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if t, ok := db.keys[key]; ok && t != "string" { + c.WriteError(msgWrongType) + return + } + + newValue := db.stringKeys[key] + value + db.stringSet(key, newValue) + + c.WriteInt(len(newValue)) + }) +} + +// GETRANGE +func (m *Miniredis) cmdGetrange(c *server.Peer, cmd string, args []string) { + if len(args) != 3 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + start, err := strconv.Atoi(args[1]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + end, err := strconv.Atoi(args[2]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if t, ok := db.keys[key]; ok && t != "string" { + c.WriteError(msgWrongType) + return + } + + v := db.stringKeys[key] + c.WriteBulk(withRange(v, start, end)) + }) +} + +// SETRANGE +func (m *Miniredis) cmdSetrange(c *server.Peer, cmd string, args []string) { + if len(args) != 3 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + pos, err := strconv.Atoi(args[1]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + if pos < 0 { + setDirty(c) + c.WriteError("ERR offset is out of range") + return + } + subst := args[2] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if t, ok := db.keys[key]; ok && t != "string" { + c.WriteError(msgWrongType) + return + } + + v := []byte(db.stringKeys[key]) + if len(v) < pos+len(subst) { + newV := make([]byte, pos+len(subst)) + copy(newV, v) + v = newV + } + copy(v[pos:pos+len(subst)], subst) + db.stringSet(key, string(v)) + c.WriteInt(len(v)) + }) +} + +// BITCOUNT +func (m *Miniredis) cmdBitcount(c *server.Peer, cmd string, args []string) { + if len(args) < 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + var ( + useRange = false + start, end = 0, 0 + key = args[0] + ) + args = args[1:] + if len(args) >= 2 { + useRange = true + var err error + start, err = strconv.Atoi(args[0]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + end, err = strconv.Atoi(args[1]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + args = args[2:] + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(key) { + c.WriteInt(0) + return + } + if db.t(key) != "string" { + c.WriteError(msgWrongType) + return + } + + // Real redis only checks after it knows the key is there and a string. + if len(args) != 0 { + c.WriteError(msgSyntaxError) + return + } + + v := db.stringKeys[key] + if useRange { + v = withRange(v, start, end) + } + + c.WriteInt(countBits([]byte(v))) + }) +} + +// BITOP +func (m *Miniredis) cmdBitop(c *server.Peer, cmd string, args []string) { + if len(args) < 3 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + var ( + op = strings.ToUpper(args[0]) + target = args[1] + input = args[2:] + ) + + // 'op' is tested when the transaction is executed. + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + switch op { + case "AND", "OR", "XOR": + first := input[0] + if t, ok := db.keys[first]; ok && t != "string" { + c.WriteError(msgWrongType) + return + } + res := []byte(db.stringKeys[first]) + for _, vk := range input[1:] { + if t, ok := db.keys[vk]; ok && t != "string" { + c.WriteError(msgWrongType) + return + } + v := db.stringKeys[vk] + cb := map[string]func(byte, byte) byte{ + "AND": func(a, b byte) byte { return a & b }, + "OR": func(a, b byte) byte { return a | b }, + "XOR": func(a, b byte) byte { return a ^ b }, + }[op] + res = sliceBinOp(cb, res, []byte(v)) + } + db.del(target, false) // Keep TTL + if len(res) == 0 { + db.del(target, true) + } else { + db.stringSet(target, string(res)) + } + c.WriteInt(len(res)) + case "NOT": + // NOT only takes a single argument. + if len(input) != 1 { + c.WriteError("ERR BITOP NOT must be called with a single source key.") + return + } + key := input[0] + if t, ok := db.keys[key]; ok && t != "string" { + c.WriteError(msgWrongType) + return + } + value := []byte(db.stringKeys[key]) + for i := range value { + value[i] = ^value[i] + } + db.del(target, false) // Keep TTL + if len(value) == 0 { + db.del(target, true) + } else { + db.stringSet(target, string(value)) + } + c.WriteInt(len(value)) + default: + c.WriteError(msgSyntaxError) + } + }) +} + +// BITPOS +func (m *Miniredis) cmdBitpos(c *server.Peer, cmd string, args []string) { + if len(args) < 2 || len(args) > 4 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + bit, err := strconv.Atoi(args[1]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + var start, end int + withEnd := false + if len(args) > 2 { + start, err = strconv.Atoi(args[2]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + } + if len(args) > 3 { + end, err = strconv.Atoi(args[3]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + withEnd = true + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if t, ok := db.keys[key]; ok && t != "string" { + c.WriteError(msgWrongType) + return + } + value := db.stringKeys[key] + if start != 0 { + if start > len(value) { + start = len(value) + } + } + if withEnd { + end++ // redis end semantics. + if end < 0 { + end = len(value) + end + } + if end > len(value) { + end = len(value) + } + } else { + end = len(value) + } + if start != 0 || withEnd { + if end < start { + value = "" + } else { + value = value[start:end] + } + } + pos := bitPos([]byte(value), bit == 1) + if pos >= 0 { + pos += start * 8 + } + // Special case when looking for 0, but not when start and end are + // given. + if bit == 0 && pos == -1 && !withEnd { + pos = start*8 + len(value)*8 + } + c.WriteInt(pos) + }) +} + +// GETBIT +func (m *Miniredis) cmdGetbit(c *server.Peer, cmd string, args []string) { + if len(args) != 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + bit, err := strconv.Atoi(args[1]) + if err != nil { + setDirty(c) + c.WriteError("ERR bit offset is not an integer or out of range") + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if t, ok := db.keys[key]; ok && t != "string" { + c.WriteError(msgWrongType) + return + } + value := db.stringKeys[key] + + ourByteNr := bit / 8 + var ourByte byte + if ourByteNr > len(value)-1 { + ourByte = '\x00' + } else { + ourByte = value[ourByteNr] + } + res := 0 + if toBits(ourByte)[bit%8] { + res = 1 + } + c.WriteInt(res) + }) +} + +// SETBIT +func (m *Miniredis) cmdSetbit(c *server.Peer, cmd string, args []string) { + if len(args) != 3 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + key := args[0] + bit, err := strconv.Atoi(args[1]) + if err != nil || bit < 0 { + setDirty(c) + c.WriteError("ERR bit offset is not an integer or out of range") + return + } + newBit, err := strconv.Atoi(args[2]) + if err != nil || (newBit != 0 && newBit != 1) { + setDirty(c) + c.WriteError("ERR bit is not an integer or out of range") + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if t, ok := db.keys[key]; ok && t != "string" { + c.WriteError(msgWrongType) + return + } + value := []byte(db.stringKeys[key]) + + ourByteNr := bit / 8 + ourBitNr := bit % 8 + if ourByteNr > len(value)-1 { + // Too short. Expand. + newValue := make([]byte, ourByteNr+1) + copy(newValue, value) + value = newValue + } + old := 0 + if toBits(value[ourByteNr])[ourBitNr] { + old = 1 + } + if newBit == 0 { + value[ourByteNr] &^= 1 << uint8(7-ourBitNr) + } else { + value[ourByteNr] |= 1 << uint8(7-ourBitNr) + } + db.stringSet(key, string(value)) + + c.WriteInt(old) + }) +} + +// Redis range. both start and end can be negative. +func withRange(v string, start, end int) string { + s, e := redisRange(len(v), start, end, true /* string getrange symantics */) + return v[s:e] +} + +func countBits(v []byte) int { + count := 0 + for _, b := range []byte(v) { + for b > 0 { + count += int((b % uint8(2))) + b = b >> 1 + } + } + return count +} + +// sliceBinOp applies an operator to all slice elements, with Redis string +// padding logic. +func sliceBinOp(f func(a, b byte) byte, a, b []byte) []byte { + maxl := len(a) + if len(b) > maxl { + maxl = len(b) + } + lA := make([]byte, maxl) + copy(lA, a) + lB := make([]byte, maxl) + copy(lB, b) + res := make([]byte, maxl) + for i := range res { + res[i] = f(lA[i], lB[i]) + } + return res +} + +// Return the number of the first bit set/unset. +func bitPos(s []byte, bit bool) int { + for i, b := range s { + for j, set := range toBits(b) { + if set == bit { + return i*8 + j + } + } + } + return -1 +} + +// toBits changes a byte in 8 bools. +func toBits(s byte) [8]bool { + r := [8]bool{} + for i := range r { + if s&(uint8(1)< version { + // Abort! Abort! + stopTx(ctx) + c.WriteLen(0) + return + } + } + + c.WriteLen(len(ctx.transaction)) + for _, cb := range ctx.transaction { + cb(c, ctx) + } + // wake up anyone who waits on anything. + m.signal.Broadcast() + + stopTx(ctx) +} + +// DISCARD +func (m *Miniredis) cmdDiscard(c *server.Peer, cmd string, args []string) { + if len(args) != 0 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + ctx := getCtx(c) + if !inTx(ctx) { + c.WriteError("ERR DISCARD without MULTI") + return + } + + stopTx(ctx) + c.WriteOK() +} + +// WATCH +func (m *Miniredis) cmdWatch(c *server.Peer, cmd string, args []string) { + if len(args) == 0 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + ctx := getCtx(c) + if inTx(ctx) { + c.WriteError("ERR WATCH in MULTI") + return + } + + m.Lock() + defer m.Unlock() + db := m.db(ctx.selectedDB) + + for _, key := range args { + watch(db, ctx, key) + } + c.WriteOK() +} + +// UNWATCH +func (m *Miniredis) cmdUnwatch(c *server.Peer, cmd string, args []string) { + if len(args) != 0 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + // Doesn't matter if UNWATCH is in a TX or not. Looks like a Redis bug to me. + unwatch(getCtx(c)) + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + // Do nothing if it's called in a transaction. + c.WriteOK() + }) +} diff --git a/vendor/github.com/alicebob/miniredis/cmd_transactions_test.go b/vendor/github.com/alicebob/miniredis/cmd_transactions_test.go new file mode 100644 index 00000000..ca46a213 --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/cmd_transactions_test.go @@ -0,0 +1,261 @@ +package miniredis + +import ( + "testing" + + "github.com/garyburd/redigo/redis" +) + +func TestMulti(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + // Do accept MULTI, but use it as a no-op + r, err := redis.String(c.Do("MULTI")) + ok(t, err) + equals(t, "OK", r) +} + +func TestExec(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + // Exec without MULTI. + _, err = c.Do("EXEC") + assert(t, err != nil, "do EXEC error") +} + +func TestDiscard(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + // DISCARD without MULTI. + _, err = c.Do("DISCARD") + assert(t, err != nil, "do DISCARD error") +} + +func TestWatch(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + // Simple WATCH + r, err := redis.String(c.Do("WATCH", "foo")) + ok(t, err) + equals(t, "OK", r) + + // Can't do WATCH in a MULTI + { + _, err = redis.String(c.Do("MULTI")) + ok(t, err) + _, err = redis.String(c.Do("WATCH", "foo")) + assert(t, err != nil, "do WATCH error") + } +} + +// Test simple multi/exec block. +func TestSimpleTransaction(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + b, err := redis.String(c.Do("MULTI")) + ok(t, err) + equals(t, "OK", b) + + b, err = redis.String(c.Do("SET", "aap", 1)) + ok(t, err) + equals(t, "QUEUED", b) + + // Not set yet. + equals(t, false, s.Exists("aap")) + + v, err := redis.Values(c.Do("EXEC")) + ok(t, err) + equals(t, 1, len(redis.Args(v))) + equals(t, "OK", v[0]) + + // SET should be back to normal mode + b, err = redis.String(c.Do("SET", "aap", 1)) + ok(t, err) + equals(t, "OK", b) +} + +func TestDiscardTransaction(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.Set("aap", "noot") + + b, err := redis.String(c.Do("MULTI")) + ok(t, err) + equals(t, "OK", b) + + b, err = redis.String(c.Do("SET", "aap", "mies")) + ok(t, err) + equals(t, "QUEUED", b) + + // Not committed + s.CheckGet(t, "aap", "noot") + + v, err := redis.String(c.Do("DISCARD")) + ok(t, err) + equals(t, "OK", v) + + // TX didn't get executed + s.CheckGet(t, "aap", "noot") +} + +func TestTxQueueErr(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + b, err := redis.String(c.Do("MULTI")) + ok(t, err) + equals(t, "OK", b) + + b, err = redis.String(c.Do("SET", "aap", "mies")) + ok(t, err) + equals(t, "QUEUED", b) + + // That's an error! + _, err = redis.String(c.Do("SET", "aap")) + assert(t, err != nil, "do SET error") + + // Thisone is ok again + b, err = redis.String(c.Do("SET", "noot", "vuur")) + ok(t, err) + equals(t, "QUEUED", b) + + _, err = redis.String(c.Do("EXEC")) + assert(t, err != nil, "do EXEC error") + + // Didn't get EXECed + equals(t, false, s.Exists("aap")) +} + +func TestTxWatch(t *testing.T) { + // Watch with no error. + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.Set("one", "two") + b, err := redis.String(c.Do("WATCH", "one")) + ok(t, err) + equals(t, "OK", b) + + b, err = redis.String(c.Do("MULTI")) + ok(t, err) + equals(t, "OK", b) + + b, err = redis.String(c.Do("GET", "one")) + ok(t, err) + equals(t, "QUEUED", b) + + v, err := redis.Values(c.Do("EXEC")) + ok(t, err) + equals(t, 1, len(v)) + equals(t, []byte("two"), v[0]) +} + +func TestTxWatchErr(t *testing.T) { + // Watch with en error. + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + c2, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.Set("one", "two") + b, err := redis.String(c.Do("WATCH", "one")) + ok(t, err) + equals(t, "OK", b) + + // Here comes client 2 + b, err = redis.String(c2.Do("SET", "one", "three")) + ok(t, err) + equals(t, "OK", b) + + b, err = redis.String(c.Do("MULTI")) + ok(t, err) + equals(t, "OK", b) + + b, err = redis.String(c.Do("GET", "one")) + ok(t, err) + equals(t, "QUEUED", b) + + v, err := redis.Values(c.Do("EXEC")) + ok(t, err) + equals(t, 0, len(v)) + + // It did get updated, and we're not in a transaction anymore. + b, err = redis.String(c.Do("GET", "one")) + ok(t, err) + equals(t, "three", b) +} + +func TestUnwatch(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + c2, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + s.Set("one", "two") + b, err := redis.String(c.Do("WATCH", "one")) + ok(t, err) + equals(t, "OK", b) + + b, err = redis.String(c.Do("UNWATCH")) + ok(t, err) + equals(t, "OK", b) + + // Here comes client 2 + b, err = redis.String(c2.Do("SET", "one", "three")) + ok(t, err) + equals(t, "OK", b) + + b, err = redis.String(c.Do("MULTI")) + ok(t, err) + equals(t, "OK", b) + + b, err = redis.String(c.Do("SET", "one", "four")) + ok(t, err) + equals(t, "QUEUED", b) + + v, err := redis.Values(c.Do("EXEC")) + ok(t, err) + equals(t, 1, len(v)) + equals(t, "OK", v[0]) + + // It did get updated by our TX + b, err = redis.String(c.Do("GET", "one")) + ok(t, err) + equals(t, "four", b) +} diff --git a/vendor/github.com/alicebob/miniredis/db.go b/vendor/github.com/alicebob/miniredis/db.go new file mode 100644 index 00000000..57edbf6e --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/db.go @@ -0,0 +1,550 @@ +package miniredis + +import ( + "sort" + "strconv" + "time" +) + +func (db *RedisDB) exists(k string) bool { + _, ok := db.keys[k] + return ok +} + +// t gives the type of a key, or "" +func (db *RedisDB) t(k string) string { + return db.keys[k] +} + +// allKeys returns all keys. Sorted. +func (db *RedisDB) allKeys() []string { + res := make([]string, 0, len(db.keys)) + for k := range db.keys { + res = append(res, k) + } + sort.Strings(res) // To make things deterministic. + return res +} + +// flush removes all keys and values. +func (db *RedisDB) flush() { + db.keys = map[string]string{} + db.stringKeys = map[string]string{} + db.hashKeys = map[string]hashKey{} + db.listKeys = map[string]listKey{} + db.setKeys = map[string]setKey{} + db.sortedsetKeys = map[string]sortedSet{} + db.ttl = map[string]time.Duration{} +} + +// move something to another db. Will return ok. Or not. +func (db *RedisDB) move(key string, to *RedisDB) bool { + if _, ok := to.keys[key]; ok { + return false + } + + t, ok := db.keys[key] + if !ok { + return false + } + to.keys[key] = db.keys[key] + switch t { + case "string": + to.stringKeys[key] = db.stringKeys[key] + case "hash": + to.hashKeys[key] = db.hashKeys[key] + case "list": + to.listKeys[key] = db.listKeys[key] + case "set": + to.setKeys[key] = db.setKeys[key] + case "zset": + to.sortedsetKeys[key] = db.sortedsetKeys[key] + default: + panic("unhandled key type") + } + to.keyVersion[key]++ + if v, ok := db.ttl[key]; ok { + to.ttl[key] = v + } + db.del(key, true) + return true +} + +func (db *RedisDB) rename(from, to string) { + db.del(to, true) + switch db.t(from) { + case "string": + db.stringKeys[to] = db.stringKeys[from] + case "hash": + db.hashKeys[to] = db.hashKeys[from] + case "list": + db.listKeys[to] = db.listKeys[from] + case "set": + db.setKeys[to] = db.setKeys[from] + case "zset": + db.sortedsetKeys[to] = db.sortedsetKeys[from] + default: + panic("missing case") + } + db.keys[to] = db.keys[from] + db.keyVersion[to]++ + db.ttl[to] = db.ttl[from] + + db.del(from, true) +} + +func (db *RedisDB) del(k string, delTTL bool) { + if !db.exists(k) { + return + } + t := db.t(k) + delete(db.keys, k) + db.keyVersion[k]++ + if delTTL { + delete(db.ttl, k) + } + switch t { + case "string": + delete(db.stringKeys, k) + case "hash": + delete(db.hashKeys, k) + case "list": + delete(db.listKeys, k) + case "set": + delete(db.setKeys, k) + case "zset": + delete(db.sortedsetKeys, k) + default: + panic("Unknown key type: " + t) + } +} + +// stringGet returns the string key or "" on error/nonexists. +func (db *RedisDB) stringGet(k string) string { + if t, ok := db.keys[k]; !ok || t != "string" { + return "" + } + return db.stringKeys[k] +} + +// stringSet force set()s a key. Does not touch expire. +func (db *RedisDB) stringSet(k, v string) { + db.del(k, false) + db.keys[k] = "string" + db.stringKeys[k] = v + db.keyVersion[k]++ +} + +// change int key value +func (db *RedisDB) stringIncr(k string, delta int) (int, error) { + v := 0 + if sv, ok := db.stringKeys[k]; ok { + var err error + v, err = strconv.Atoi(sv) + if err != nil { + return 0, ErrIntValueError + } + } + v += delta + db.stringSet(k, strconv.Itoa(v)) + return v, nil +} + +// change float key value +func (db *RedisDB) stringIncrfloat(k string, delta float64) (float64, error) { + v := 0.0 + if sv, ok := db.stringKeys[k]; ok { + var err error + v, err = strconv.ParseFloat(sv, 64) + if err != nil { + return 0, ErrFloatValueError + } + } + v += delta + db.stringSet(k, formatFloat(v)) + return v, nil +} + +// listLpush is 'left push', aka unshift. Returns the new length. +func (db *RedisDB) listLpush(k, v string) int { + l, ok := db.listKeys[k] + if !ok { + db.keys[k] = "list" + } + l = append([]string{v}, l...) + db.listKeys[k] = l + db.keyVersion[k]++ + return len(l) +} + +// 'left pop', aka shift. +func (db *RedisDB) listLpop(k string) string { + l := db.listKeys[k] + el := l[0] + l = l[1:] + if len(l) == 0 { + db.del(k, true) + } else { + db.listKeys[k] = l + } + db.keyVersion[k]++ + return el +} + +func (db *RedisDB) listPush(k string, v ...string) int { + l, ok := db.listKeys[k] + if !ok { + db.keys[k] = "list" + } + l = append(l, v...) + db.listKeys[k] = l + db.keyVersion[k]++ + return len(l) +} + +func (db *RedisDB) listPop(k string) string { + l := db.listKeys[k] + el := l[len(l)-1] + l = l[:len(l)-1] + if len(l) == 0 { + db.del(k, true) + } else { + db.listKeys[k] = l + db.keyVersion[k]++ + } + return el +} + +// setset replaces a whole set. +func (db *RedisDB) setSet(k string, set setKey) { + db.keys[k] = "set" + db.setKeys[k] = set + db.keyVersion[k]++ +} + +// setadd adds members to a set. Returns nr of new keys. +func (db *RedisDB) setAdd(k string, elems ...string) int { + s, ok := db.setKeys[k] + if !ok { + s = setKey{} + db.keys[k] = "set" + } + added := 0 + for _, e := range elems { + if _, ok := s[e]; !ok { + added++ + } + s[e] = struct{}{} + } + db.setKeys[k] = s + db.keyVersion[k]++ + return added +} + +// setrem removes members from a set. Returns nr of deleted keys. +func (db *RedisDB) setRem(k string, fields ...string) int { + s, ok := db.setKeys[k] + if !ok { + return 0 + } + removed := 0 + for _, f := range fields { + if _, ok := s[f]; ok { + removed++ + delete(s, f) + } + } + if len(s) == 0 { + db.del(k, true) + } else { + db.setKeys[k] = s + } + db.keyVersion[k]++ + return removed +} + +// All members of a set. +func (db *RedisDB) setMembers(k string) []string { + set := db.setKeys[k] + members := make([]string, 0, len(set)) + for k := range set { + members = append(members, k) + } + sort.Strings(members) + return members +} + +// Is a SET value present? +func (db *RedisDB) setIsMember(k, v string) bool { + set, ok := db.setKeys[k] + if !ok { + return false + } + _, ok = set[v] + return ok +} + +// hashFields returns all (sorted) keys ('fields') for a hash key. +func (db *RedisDB) hashFields(k string) []string { + v := db.hashKeys[k] + r := make([]string, 0, len(v)) + for k := range v { + r = append(r, k) + } + sort.Strings(r) + return r +} + +// hashGet a value +func (db *RedisDB) hashGet(key, field string) string { + return db.hashKeys[key][field] +} + +// hashSet returns whether the key already existed +func (db *RedisDB) hashSet(k, f, v string) bool { + if t, ok := db.keys[k]; ok && t != "hash" { + db.del(k, true) + } + db.keys[k] = "hash" + if _, ok := db.hashKeys[k]; !ok { + db.hashKeys[k] = map[string]string{} + } + _, ok := db.hashKeys[k][f] + db.hashKeys[k][f] = v + db.keyVersion[k]++ + return ok +} + +// hashIncr changes int key value +func (db *RedisDB) hashIncr(key, field string, delta int) (int, error) { + v := 0 + if h, ok := db.hashKeys[key]; ok { + if f, ok := h[field]; ok { + var err error + v, err = strconv.Atoi(f) + if err != nil { + return 0, ErrIntValueError + } + } + } + v += delta + db.hashSet(key, field, strconv.Itoa(v)) + return v, nil +} + +// hashIncrfloat changes float key value +func (db *RedisDB) hashIncrfloat(key, field string, delta float64) (float64, error) { + v := 0.0 + if h, ok := db.hashKeys[key]; ok { + if f, ok := h[field]; ok { + var err error + v, err = strconv.ParseFloat(f, 64) + if err != nil { + return 0, ErrFloatValueError + } + } + } + v += delta + db.hashSet(key, field, formatFloat(v)) + return v, nil +} + +// sortedSet set returns a sortedSet as map +func (db *RedisDB) sortedSet(key string) map[string]float64 { + ss := db.sortedsetKeys[key] + return map[string]float64(ss) +} + +// ssetSet sets a complete sorted set. +func (db *RedisDB) ssetSet(key string, sset sortedSet) { + db.keys[key] = "zset" + db.keyVersion[key]++ + db.sortedsetKeys[key] = sset +} + +// ssetAdd adds member to a sorted set. Returns whether this was a new member. +func (db *RedisDB) ssetAdd(key string, score float64, member string) bool { + ss, ok := db.sortedsetKeys[key] + if !ok { + ss = newSortedSet() + db.keys[key] = "zset" + } + _, ok = ss[member] + ss[member] = score + db.sortedsetKeys[key] = ss + db.keyVersion[key]++ + return !ok +} + +// All members from a sorted set, ordered by score. +func (db *RedisDB) ssetMembers(key string) []string { + ss, ok := db.sortedsetKeys[key] + if !ok { + return nil + } + elems := ss.byScore(asc) + members := make([]string, 0, len(elems)) + for _, e := range elems { + members = append(members, e.member) + } + return members +} + +// All members+scores from a sorted set, ordered by score. +func (db *RedisDB) ssetElements(key string) ssElems { + ss, ok := db.sortedsetKeys[key] + if !ok { + return nil + } + return ss.byScore(asc) +} + +// ssetCard is the sorted set cardinality. +func (db *RedisDB) ssetCard(key string) int { + ss := db.sortedsetKeys[key] + return ss.card() +} + +// ssetRank is the sorted set rank. +func (db *RedisDB) ssetRank(key, member string, d direction) (int, bool) { + ss := db.sortedsetKeys[key] + return ss.rankByScore(member, d) +} + +// ssetScore is sorted set score. +func (db *RedisDB) ssetScore(key, member string) float64 { + ss := db.sortedsetKeys[key] + return ss[member] +} + +// ssetRem is sorted set key delete. +func (db *RedisDB) ssetRem(key, member string) bool { + ss := db.sortedsetKeys[key] + _, ok := ss[member] + delete(ss, member) + if len(ss) == 0 { + // Delete key on removal of last member + db.del(key, true) + } + return ok +} + +// ssetExists tells if a member exists in a sorted set. +func (db *RedisDB) ssetExists(key, member string) bool { + ss := db.sortedsetKeys[key] + _, ok := ss[member] + return ok +} + +// ssetIncrby changes float sorted set score. +func (db *RedisDB) ssetIncrby(k, m string, delta float64) float64 { + ss, ok := db.sortedsetKeys[k] + if !ok { + ss = newSortedSet() + db.keys[k] = "zset" + db.sortedsetKeys[k] = ss + } + + v, _ := ss.get(m) + v += delta + ss.set(v, m) + db.keyVersion[k]++ + return v +} + +// setDiff implements the logic behind SDIFF* +func (db *RedisDB) setDiff(keys []string) (setKey, error) { + key := keys[0] + keys = keys[1:] + if db.exists(key) && db.t(key) != "set" { + return nil, ErrWrongType + } + s := setKey{} + for k := range db.setKeys[key] { + s[k] = struct{}{} + } + for _, sk := range keys { + if !db.exists(sk) { + continue + } + if db.t(sk) != "set" { + return nil, ErrWrongType + } + for e := range db.setKeys[sk] { + delete(s, e) + } + } + return s, nil +} + +// setInter implements the logic behind SINTER* +func (db *RedisDB) setInter(keys []string) (setKey, error) { + key := keys[0] + keys = keys[1:] + if db.exists(key) && db.t(key) != "set" { + return nil, ErrWrongType + } + s := setKey{} + for k := range db.setKeys[key] { + s[k] = struct{}{} + } + for _, sk := range keys { + if !db.exists(sk) { + continue + } + if db.t(sk) != "set" { + // Bug(?) in redis 2.8.14, it just skips the key. + continue + // return nil, ErrWrongType + } + other := db.setKeys[sk] + for e := range s { + if _, ok := other[e]; ok { + continue + } + delete(s, e) + } + } + return s, nil +} + +// setUnion implements the logic behind SUNION* +func (db *RedisDB) setUnion(keys []string) (setKey, error) { + key := keys[0] + keys = keys[1:] + if db.exists(key) && db.t(key) != "set" { + return nil, ErrWrongType + } + s := setKey{} + for k := range db.setKeys[key] { + s[k] = struct{}{} + } + for _, sk := range keys { + if !db.exists(sk) { + continue + } + if db.t(sk) != "set" { + return nil, ErrWrongType + } + for e := range db.setKeys[sk] { + s[e] = struct{}{} + } + } + return s, nil +} + +// fastForward proceeds the current timestamp with duration, works as a time machine +func (db *RedisDB) fastForward(duration time.Duration) { + for _, key := range db.allKeys() { + if value, ok := db.ttl[key]; ok { + db.ttl[key] = value - duration + db.checkTTL(key) + } + } +} + +func (db *RedisDB) checkTTL(key string) { + if v, ok := db.ttl[key]; ok && v <= 0 { + db.del(key, true) + } +} diff --git a/vendor/github.com/alicebob/miniredis/direct.go b/vendor/github.com/alicebob/miniredis/direct.go new file mode 100644 index 00000000..8c89bdd7 --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/direct.go @@ -0,0 +1,549 @@ +package miniredis + +// Commands to modify and query our databases directly. + +import ( + "errors" + "time" +) + +var ( + // ErrKeyNotFound is returned when a key doesn't exist. + ErrKeyNotFound = errors.New(msgKeyNotFound) + // ErrWrongType when a key is not the right type. + ErrWrongType = errors.New(msgWrongType) + // ErrIntValueError can returned by INCRBY + ErrIntValueError = errors.New(msgInvalidInt) + // ErrFloatValueError can returned by INCRBYFLOAT + ErrFloatValueError = errors.New(msgInvalidFloat) +) + +// Select sets the DB id for all direct commands. +func (m *Miniredis) Select(i int) { + m.Lock() + defer m.Unlock() + m.selectedDB = i +} + +// Keys returns all keys from the selected database, sorted. +func (m *Miniredis) Keys() []string { + return m.DB(m.selectedDB).Keys() +} + +// Keys returns all keys, sorted. +func (db *RedisDB) Keys() []string { + db.master.Lock() + defer db.master.Unlock() + return db.allKeys() +} + +// FlushAll removes all keys from all databases. +func (m *Miniredis) FlushAll() { + m.Lock() + defer m.Unlock() + m.flushAll() +} + +func (m *Miniredis) flushAll() { + for _, db := range m.dbs { + db.flush() + } +} + +// FlushDB removes all keys from the selected database. +func (m *Miniredis) FlushDB() { + m.DB(m.selectedDB).FlushDB() +} + +// FlushDB removes all keys. +func (db *RedisDB) FlushDB() { + db.master.Lock() + defer db.master.Unlock() + db.flush() +} + +// Get returns string keys added with SET. +func (m *Miniredis) Get(k string) (string, error) { + return m.DB(m.selectedDB).Get(k) +} + +// Get returns a string key. +func (db *RedisDB) Get(k string) (string, error) { + db.master.Lock() + defer db.master.Unlock() + if !db.exists(k) { + return "", ErrKeyNotFound + } + if db.t(k) != "string" { + return "", ErrWrongType + } + return db.stringGet(k), nil +} + +// Set sets a string key. Removes expire. +func (m *Miniredis) Set(k, v string) error { + return m.DB(m.selectedDB).Set(k, v) +} + +// Set sets a string key. Removes expire. +// Unlike redis the key can't be an existing non-string key. +func (db *RedisDB) Set(k, v string) error { + db.master.Lock() + defer db.master.Unlock() + + if db.exists(k) && db.t(k) != "string" { + return ErrWrongType + } + db.del(k, true) // Remove expire + db.stringSet(k, v) + return nil +} + +// Incr changes a int string value by delta. +func (m *Miniredis) Incr(k string, delta int) (int, error) { + return m.DB(m.selectedDB).Incr(k, delta) +} + +// Incr changes a int string value by delta. +func (db *RedisDB) Incr(k string, delta int) (int, error) { + db.master.Lock() + defer db.master.Unlock() + + if db.exists(k) && db.t(k) != "string" { + return 0, ErrWrongType + } + + return db.stringIncr(k, delta) +} + +// Incrfloat changes a float string value by delta. +func (m *Miniredis) Incrfloat(k string, delta float64) (float64, error) { + return m.DB(m.selectedDB).Incrfloat(k, delta) +} + +// Incrfloat changes a float string value by delta. +func (db *RedisDB) Incrfloat(k string, delta float64) (float64, error) { + db.master.Lock() + defer db.master.Unlock() + + if db.exists(k) && db.t(k) != "string" { + return 0, ErrWrongType + } + + return db.stringIncrfloat(k, delta) +} + +// List returns the list k, or an error if it's not there or something else. +// This is the same as the Redis command `LRANGE 0 -1`, but you can do your own +// range-ing. +func (m *Miniredis) List(k string) ([]string, error) { + return m.DB(m.selectedDB).List(k) +} + +// List returns the list k, or an error if it's not there or something else. +// This is the same as the Redis command `LRANGE 0 -1`, but you can do your own +// range-ing. +func (db *RedisDB) List(k string) ([]string, error) { + db.master.Lock() + defer db.master.Unlock() + + if !db.exists(k) { + return nil, ErrKeyNotFound + } + if db.t(k) != "list" { + return nil, ErrWrongType + } + return db.listKeys[k], nil +} + +// Lpush is an unshift. Returns the new length. +func (m *Miniredis) Lpush(k, v string) (int, error) { + return m.DB(m.selectedDB).Lpush(k, v) +} + +// Lpush is an unshift. Returns the new length. +func (db *RedisDB) Lpush(k, v string) (int, error) { + db.master.Lock() + defer db.master.Unlock() + + if db.exists(k) && db.t(k) != "list" { + return 0, ErrWrongType + } + return db.listLpush(k, v), nil +} + +// Lpop is a shift. Returns the popped element. +func (m *Miniredis) Lpop(k string) (string, error) { + return m.DB(m.selectedDB).Lpop(k) +} + +// Lpop is a shift. Returns the popped element. +func (db *RedisDB) Lpop(k string) (string, error) { + db.master.Lock() + defer db.master.Unlock() + + if !db.exists(k) { + return "", ErrKeyNotFound + } + if db.t(k) != "list" { + return "", ErrWrongType + } + return db.listLpop(k), nil +} + +// Push add element at the end. Is called RPUSH in redis. Returns the new length. +func (m *Miniredis) Push(k string, v ...string) (int, error) { + return m.DB(m.selectedDB).Push(k, v...) +} + +// Push add element at the end. Is called RPUSH in redis. Returns the new length. +func (db *RedisDB) Push(k string, v ...string) (int, error) { + db.master.Lock() + defer db.master.Unlock() + + if db.exists(k) && db.t(k) != "list" { + return 0, ErrWrongType + } + return db.listPush(k, v...), nil +} + +// Pop removes and returns the last element. Is called RPOP in Redis. +func (m *Miniredis) Pop(k string) (string, error) { + return m.DB(m.selectedDB).Pop(k) +} + +// Pop removes and returns the last element. Is called RPOP in Redis. +func (db *RedisDB) Pop(k string) (string, error) { + db.master.Lock() + defer db.master.Unlock() + + if !db.exists(k) { + return "", ErrKeyNotFound + } + if db.t(k) != "list" { + return "", ErrWrongType + } + + return db.listPop(k), nil +} + +// SetAdd adds keys to a set. Returns the number of new keys. +func (m *Miniredis) SetAdd(k string, elems ...string) (int, error) { + return m.DB(m.selectedDB).SetAdd(k, elems...) +} + +// SetAdd adds keys to a set. Returns the number of new keys. +func (db *RedisDB) SetAdd(k string, elems ...string) (int, error) { + db.master.Lock() + defer db.master.Unlock() + if db.exists(k) && db.t(k) != "set" { + return 0, ErrWrongType + } + return db.setAdd(k, elems...), nil +} + +// Members gives all set keys. Sorted. +func (m *Miniredis) Members(k string) ([]string, error) { + return m.DB(m.selectedDB).Members(k) +} + +// Members gives all set keys. Sorted. +func (db *RedisDB) Members(k string) ([]string, error) { + db.master.Lock() + defer db.master.Unlock() + if !db.exists(k) { + return nil, ErrKeyNotFound + } + if db.t(k) != "set" { + return nil, ErrWrongType + } + return db.setMembers(k), nil +} + +// IsMember tells if value is in the set. +func (m *Miniredis) IsMember(k, v string) (bool, error) { + return m.DB(m.selectedDB).IsMember(k, v) +} + +// IsMember tells if value is in the set. +func (db *RedisDB) IsMember(k, v string) (bool, error) { + db.master.Lock() + defer db.master.Unlock() + if !db.exists(k) { + return false, ErrKeyNotFound + } + if db.t(k) != "set" { + return false, ErrWrongType + } + return db.setIsMember(k, v), nil +} + +// HKeys returns all (sorted) keys ('fields') for a hash key. +func (m *Miniredis) HKeys(k string) ([]string, error) { + return m.DB(m.selectedDB).HKeys(k) +} + +// HKeys returns all (sorted) keys ('fields') for a hash key. +func (db *RedisDB) HKeys(key string) ([]string, error) { + db.master.Lock() + defer db.master.Unlock() + if !db.exists(key) { + return nil, ErrKeyNotFound + } + if db.t(key) != "hash" { + return nil, ErrWrongType + } + return db.hashFields(key), nil +} + +// Del deletes a key and any expiration value. Returns whether there was a key. +func (m *Miniredis) Del(k string) bool { + return m.DB(m.selectedDB).Del(k) +} + +// Del deletes a key and any expiration value. Returns whether there was a key. +func (db *RedisDB) Del(k string) bool { + db.master.Lock() + defer db.master.Unlock() + if !db.exists(k) { + return false + } + db.del(k, true) + return true +} + +// TTL is the left over time to live. As set via EXPIRE, PEXPIRE, EXPIREAT, +// PEXPIREAT. +// 0 if not set. +func (m *Miniredis) TTL(k string) time.Duration { + return m.DB(m.selectedDB).TTL(k) +} + +// TTL is the left over time to live. As set via EXPIRE, PEXPIRE, EXPIREAT, +// PEXPIREAT. +// 0 if not set. +func (db *RedisDB) TTL(k string) time.Duration { + db.master.Lock() + defer db.master.Unlock() + return db.ttl[k] +} + +// SetTTL sets the TTL of a key. +func (m *Miniredis) SetTTL(k string, ttl time.Duration) { + m.DB(m.selectedDB).SetTTL(k, ttl) +} + +// SetTTL sets the time to live of a key. +func (db *RedisDB) SetTTL(k string, ttl time.Duration) { + db.master.Lock() + defer db.master.Unlock() + db.ttl[k] = ttl + db.keyVersion[k]++ +} + +// Type gives the type of a key, or "" +func (m *Miniredis) Type(k string) string { + return m.DB(m.selectedDB).Type(k) +} + +// Type gives the type of a key, or "" +func (db *RedisDB) Type(k string) string { + db.master.Lock() + defer db.master.Unlock() + return db.t(k) +} + +// Exists tells whether a key exists. +func (m *Miniredis) Exists(k string) bool { + return m.DB(m.selectedDB).Exists(k) +} + +// Exists tells whether a key exists. +func (db *RedisDB) Exists(k string) bool { + db.master.Lock() + defer db.master.Unlock() + return db.exists(k) +} + +// HGet returns hash keys added with HSET. +// This will return an empty string if the key is not set. Redis would return +// a nil. +// Returns empty string when the key is of a different type. +func (m *Miniredis) HGet(k, f string) string { + return m.DB(m.selectedDB).HGet(k, f) +} + +// HGet returns hash keys added with HSET. +// Returns empty string when the key is of a different type. +func (db *RedisDB) HGet(k, f string) string { + db.master.Lock() + defer db.master.Unlock() + h, ok := db.hashKeys[k] + if !ok { + return "" + } + return h[f] +} + +// HSet sets a hash key. +// If there is another key by the same name it will be gone. +func (m *Miniredis) HSet(k, f, v string) { + m.DB(m.selectedDB).HSet(k, f, v) +} + +// HSet sets a hash key. +// If there is another key by the same name it will be gone. +func (db *RedisDB) HSet(k, f, v string) { + db.master.Lock() + defer db.master.Unlock() + db.hashSet(k, f, v) +} + +// HDel deletes a hash key. +func (m *Miniredis) HDel(k, f string) { + m.DB(m.selectedDB).HDel(k, f) +} + +// HDel deletes a hash key. +func (db *RedisDB) HDel(k, f string) { + db.master.Lock() + defer db.master.Unlock() + db.hdel(k, f) +} + +func (db *RedisDB) hdel(k, f string) { + if _, ok := db.hashKeys[k]; !ok { + return + } + delete(db.hashKeys[k], f) + db.keyVersion[k]++ +} + +// HIncr increases a key/field by delta (int). +func (m *Miniredis) HIncr(k, f string, delta int) (int, error) { + return m.DB(m.selectedDB).HIncr(k, f, delta) +} + +// HIncr increases a key/field by delta (int). +func (db *RedisDB) HIncr(k, f string, delta int) (int, error) { + db.master.Lock() + defer db.master.Unlock() + return db.hashIncr(k, f, delta) +} + +// HIncrfloat increases a key/field by delta (float). +func (m *Miniredis) HIncrfloat(k, f string, delta float64) (float64, error) { + return m.DB(m.selectedDB).HIncrfloat(k, f, delta) +} + +// HIncrfloat increases a key/field by delta (float). +func (db *RedisDB) HIncrfloat(k, f string, delta float64) (float64, error) { + db.master.Lock() + defer db.master.Unlock() + return db.hashIncrfloat(k, f, delta) +} + +// SRem removes fields from a set. Returns number of deleted fields. +func (m *Miniredis) SRem(k string, fields ...string) (int, error) { + return m.DB(m.selectedDB).SRem(k, fields...) +} + +// SRem removes fields from a set. Returns number of deleted fields. +func (db *RedisDB) SRem(k string, fields ...string) (int, error) { + db.master.Lock() + defer db.master.Unlock() + if !db.exists(k) { + return 0, ErrKeyNotFound + } + if db.t(k) != "set" { + return 0, ErrWrongType + } + return db.setRem(k, fields...), nil +} + +// ZAdd adds a score,member to a sorted set. +func (m *Miniredis) ZAdd(k string, score float64, member string) (bool, error) { + return m.DB(m.selectedDB).ZAdd(k, score, member) +} + +// ZAdd adds a score,member to a sorted set. +func (db *RedisDB) ZAdd(k string, score float64, member string) (bool, error) { + db.master.Lock() + defer db.master.Unlock() + if db.exists(k) && !db.exists(k) && db.t(k) != "zset" { + return false, ErrWrongType + } + return db.ssetAdd(k, score, member), nil +} + +// ZMembers returns all members by score +func (m *Miniredis) ZMembers(k string) ([]string, error) { + return m.DB(m.selectedDB).ZMembers(k) +} + +// ZMembers returns all members by score +func (db *RedisDB) ZMembers(k string) ([]string, error) { + db.master.Lock() + defer db.master.Unlock() + if !db.exists(k) { + return nil, ErrKeyNotFound + } + if db.t(k) != "zset" { + return nil, ErrWrongType + } + return db.ssetMembers(k), nil +} + +// SortedSet returns a raw string->float64 map. +func (m *Miniredis) SortedSet(k string) (map[string]float64, error) { + return m.DB(m.selectedDB).SortedSet(k) +} + +// SortedSet returns a raw string->float64 map. +func (db *RedisDB) SortedSet(k string) (map[string]float64, error) { + db.master.Lock() + defer db.master.Unlock() + if !db.exists(k) { + return nil, ErrKeyNotFound + } + if db.t(k) != "zset" { + return nil, ErrWrongType + } + return db.sortedSet(k), nil +} + +// ZRem deletes a member. Returns whether the was a key. +func (m *Miniredis) ZRem(k, member string) (bool, error) { + return m.DB(m.selectedDB).ZRem(k, member) +} + +// ZRem deletes a member. Returns whether the was a key. +func (db *RedisDB) ZRem(k, member string) (bool, error) { + db.master.Lock() + defer db.master.Unlock() + if !db.exists(k) { + return false, ErrKeyNotFound + } + if db.t(k) != "zset" { + return false, ErrWrongType + } + return db.ssetRem(k, member), nil +} + +// ZScore gives the score of a sorted set member. +func (m *Miniredis) ZScore(k, member string) (float64, error) { + return m.DB(m.selectedDB).ZScore(k, member) +} + +// ZScore gives the score of a sorted set member. +func (db *RedisDB) ZScore(k, member string) (float64, error) { + db.master.Lock() + defer db.master.Unlock() + if !db.exists(k) { + return 0, ErrKeyNotFound + } + if db.t(k) != "zset" { + return 0, ErrWrongType + } + return db.ssetScore(k, member), nil +} diff --git a/vendor/github.com/alicebob/miniredis/example_test.go b/vendor/github.com/alicebob/miniredis/example_test.go new file mode 100644 index 00000000..c3127b4a --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/example_test.go @@ -0,0 +1,54 @@ +package miniredis_test + +import ( + "time" + + "github.com/alicebob/miniredis" + "github.com/garyburd/redigo/redis" +) + +func Example() { + s, err := miniredis.Run() + if err != nil { + panic(err) + } + defer s.Close() + + // Configure you application to connect to redis at s.Addr() + // Any redis client should work, as long as you use redis commands which + // miniredis implements. + c, err := redis.Dial("tcp", s.Addr()) + if err != nil { + panic(err) + } + if _, err = c.Do("SET", "foo", "bar"); err != nil { + panic(err) + } + + // You can ask miniredis about keys directly, without going over the network. + if got, err := s.Get("foo"); err != nil || got != "bar" { + panic("Didn't get 'bar' back") + } + // Or with a DB id + if _, err := s.DB(42).Get("foo"); err != miniredis.ErrKeyNotFound { + panic("didn't use a different database") + } + + // Test key with expiration + s.SetTTL("foo", 60*time.Second) + s.FastForward(60 * time.Second) + if s.Exists("foo") { + panic("expect key to be expired") + } + + // Or use a Check* function which Fail()s if the key is not what we expect + // (checks for existence, key type and the value) + // s.CheckGet(t, "foo", "bar") + + // Check if there really was only one connection. + if s.TotalConnectionCount() != 1 { + panic("too many connections made") + } + + // Output: +} diff --git a/vendor/github.com/alicebob/miniredis/keys.go b/vendor/github.com/alicebob/miniredis/keys.go new file mode 100644 index 00000000..b7cd98fb --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/keys.go @@ -0,0 +1,65 @@ +package miniredis + +// Translate the 'KEYS' argument ('foo*', 'f??', &c.) into a regexp. + +import ( + "bytes" + "regexp" +) + +// patternRE compiles a KEYS argument to a regexp. Returns nil if the given +// pattern will never match anything. +// The general strategy is to sandwich all non-meta characters between \Q...\E. +func patternRE(k string) *regexp.Regexp { + re := bytes.Buffer{} + re.WriteString(`^\Q`) + for i := 0; i < len(k); i++ { + p := k[i] + switch p { + case '*': + re.WriteString(`\E.*\Q`) + case '?': + re.WriteString(`\E.\Q`) + case '[': + charClass := bytes.Buffer{} + i++ + for ; i < len(k); i++ { + if k[i] == ']' { + break + } + if k[i] == '\\' { + if i == len(k)-1 { + // Ends with a '\'. U-huh. + return nil + } + charClass.WriteByte(k[i]) + i++ + charClass.WriteByte(k[i]) + continue + } + charClass.WriteByte(k[i]) + } + if charClass.Len() == 0 { + // '[]' is valid in Redis, but matches nothing. + return nil + } + re.WriteString(`\E[`) + re.Write(charClass.Bytes()) + re.WriteString(`]\Q`) + + case '\\': + if i == len(k)-1 { + // Ends with a '\'. U-huh. + return nil + } + // Forget the \, keep the next char. + i++ + re.WriteByte(k[i]) + continue + default: + re.WriteByte(p) + } + } + re.WriteString(`\E$`) + return regexp.MustCompile(re.String()) +} diff --git a/vendor/github.com/alicebob/miniredis/keys_test.go b/vendor/github.com/alicebob/miniredis/keys_test.go new file mode 100644 index 00000000..31f2ad00 --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/keys_test.go @@ -0,0 +1,123 @@ +package miniredis + +import ( + "testing" +) + +func TestKeysSel(t *testing.T) { + // Helper to test the selection behind KEYS + // pattern -> cases -> should match? + for pat, chk := range map[string]map[string]bool{ + "aap": { + "aap": true, + "aapnoot": false, + "nootaap": false, + "nootaapnoot": false, + "AAP": false, + }, + "aap*": { + "aap": true, + "aapnoot": true, + "nootaap": false, + "nootaapnoot": false, + "AAP": false, + }, + // No problem with regexp meta chars? + "(?:a)ap*": { + "(?:a)ap!": true, + "aap": false, + }, + "*aap*": { + "aap": true, + "aapnoot": true, + "nootaap": true, + "nootaapnoot": true, + "AAP": false, + "a_a_p": false, + }, + `\*aap*`: { + "*aap": true, + "aap": false, + "*aapnoot": true, + "aapnoot": false, + }, + `aa?`: { + "aap": true, + "aal": true, + "aaf": true, + "aa?": true, + "aap!": false, + }, + `aa\?`: { + "aap": false, + "aa?": true, + "aa?!": false, + }, + "aa[pl]": { + "aap": true, + "aal": true, + "aaf": false, + "aa?": false, + "aap!": false, + }, + "[ab]a[pl]": { + "aap": true, + "aal": true, + "bap": true, + "bal": true, + "aaf": false, + "cap": false, + "aa?": false, + "aap!": false, + }, + `\[ab\]`: { + "[ab]": true, + "a": false, + }, + `[\[ab]`: { + "[": true, + "a": true, + "b": true, + "c": false, + "]": false, + }, + `[\[\]]`: { + "[": true, + "]": true, + "c": false, + }, + `\\ap`: { + `\ap`: true, + `\\ap`: false, + }, + // Escape a normal char + `\foo`: { + `foo`: true, + `\foo`: false, + }, + } { + patRe := patternRE(pat) + if patRe == nil { + t.Errorf("'%v' won't match anything. Didn't expect that.\n", pat) + continue + } + for key, expected := range chk { + match := patRe.MatchString(key) + if expected != match { + t.Errorf("'%v' -> '%v'. Matches %v, should %v\n", pat, key, match, expected) + } + } + } + + // Patterns which won't match anything. + for _, pat := range []string{ + `ap[\`, // trailing \ in char class + `ap[`, // open char class + `[]ap`, // empty char class + `ap\`, // trailing \ + } { + if patternRE(pat) != nil { + t.Errorf("'%v' will match something. Didn't expect that.\n", pat) + } + } +} diff --git a/vendor/github.com/alicebob/miniredis/miniredis.go b/vendor/github.com/alicebob/miniredis/miniredis.go new file mode 100644 index 00000000..dff760c1 --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/miniredis.go @@ -0,0 +1,343 @@ +// Package miniredis is a pure Go Redis test server, for use in Go unittests. +// There are no dependencies on system binaries, and every server you start +// will be empty. +// +// Start a server with `s, err := miniredis.Run()`. +// Stop it with `defer s.Close()`. +// +// Point your Redis client to `s.Addr()` or `s.Host(), s.Port()`. +// +// Set keys directly via s.Set(...) and similar commands, or use a Redis client. +// +// For direct use you can select a Redis database with either `s.Select(12); +// s.Get("foo")` or `s.DB(12).Get("foo")`. +// +package miniredis + +import ( + "fmt" + "net" + "strconv" + "sync" + "time" + + "github.com/alicebob/miniredis/server" +) + +type hashKey map[string]string +type listKey []string +type setKey map[string]struct{} + +// RedisDB holds a single (numbered) Redis database. +type RedisDB struct { + master *sync.Mutex // pointer to the lock in Miniredis + id int // db id + keys map[string]string // Master map of keys with their type + stringKeys map[string]string // GET/SET &c. keys + hashKeys map[string]hashKey // MGET/MSET &c. keys + listKeys map[string]listKey // LPUSH &c. keys + setKeys map[string]setKey // SADD &c. keys + sortedsetKeys map[string]sortedSet // ZADD &c. keys + ttl map[string]time.Duration // effective TTL values + keyVersion map[string]uint // used to watch values +} + +// Miniredis is a Redis server implementation. +type Miniredis struct { + sync.Mutex + srv *server.Server + port int + password string + listen net.Listener + dbs map[int]*RedisDB + selectedDB int // DB id used in the direct Get(), Set() &c. + signal *sync.Cond + now time.Time // used to make a duration from EXPIREAT. time.Now() if not set. +} + +type txCmd func(*server.Peer, *connCtx) + +// database id + key combo +type dbKey struct { + db int + key string +} + +// connCtx has all state for a single connection. +type connCtx struct { + selectedDB int // selected DB + authenticated bool // auth enabled and a valid AUTH seen + transaction []txCmd // transaction callbacks. Or nil. + dirtyTransaction bool // any error during QUEUEing. + watch map[dbKey]uint // WATCHed keys. +} + +// NewMiniRedis makes a new, non-started, Miniredis object. +func NewMiniRedis() *Miniredis { + m := Miniredis{ + dbs: map[int]*RedisDB{}, + } + m.signal = sync.NewCond(&m) + return &m +} + +func newRedisDB(id int, l *sync.Mutex) RedisDB { + return RedisDB{ + id: id, + master: l, + keys: map[string]string{}, + stringKeys: map[string]string{}, + hashKeys: map[string]hashKey{}, + listKeys: map[string]listKey{}, + setKeys: map[string]setKey{}, + sortedsetKeys: map[string]sortedSet{}, + ttl: map[string]time.Duration{}, + keyVersion: map[string]uint{}, + } +} + +// Run creates and Start()s a Miniredis. +func Run() (*Miniredis, error) { + m := NewMiniRedis() + return m, m.Start() +} + +// Start starts a server. It listens on a random port on localhost. See also +// Addr(). +func (m *Miniredis) Start() error { + m.Lock() + defer m.Unlock() + + s, err := server.NewServer(fmt.Sprintf("127.0.0.1:%d", m.port)) + if err != nil { + return err + } + m.srv = s + m.port = s.Addr().Port + + commandsConnection(m) + commandsGeneric(m) + commandsServer(m) + commandsString(m) + commandsHash(m) + commandsList(m) + commandsSet(m) + commandsSortedSet(m) + commandsTransaction(m) + + return nil +} + +// Restart restarts a Close()d server on the same port. Values will be +// preserved. +func (m *Miniredis) Restart() error { + return m.Start() +} + +// Close shuts down a Miniredis. +func (m *Miniredis) Close() { + m.Lock() + defer m.Unlock() + if m.srv == nil { + return + } + m.srv.Close() + m.srv = nil +} + +// RequireAuth makes every connection need to AUTH first. Disable again by +// setting an empty string. +func (m *Miniredis) RequireAuth(pw string) { + m.Lock() + defer m.Unlock() + m.password = pw +} + +// DB returns a DB by ID. +func (m *Miniredis) DB(i int) *RedisDB { + m.Lock() + defer m.Unlock() + return m.db(i) +} + +// get DB. No locks! +func (m *Miniredis) db(i int) *RedisDB { + if db, ok := m.dbs[i]; ok { + return db + } + db := newRedisDB(i, &m.Mutex) // the DB has our lock. + m.dbs[i] = &db + return &db +} + +// Addr returns '127.0.0.1:12345'. Can be given to a Dial(). See also Host() +// and Port(), which return the same things. +func (m *Miniredis) Addr() string { + m.Lock() + defer m.Unlock() + return m.srv.Addr().String() +} + +// Host returns the host part of Addr(). +func (m *Miniredis) Host() string { + m.Lock() + defer m.Unlock() + return m.srv.Addr().IP.String() +} + +// Port returns the (random) port part of Addr(). +func (m *Miniredis) Port() string { + m.Lock() + defer m.Unlock() + return strconv.Itoa(m.srv.Addr().Port) +} + +// CommandCount returns the number of processed commands. +func (m *Miniredis) CommandCount() int { + m.Lock() + defer m.Unlock() + return int(m.srv.TotalCommands()) +} + +// CurrentConnectionCount returns the number of currently connected clients. +func (m *Miniredis) CurrentConnectionCount() int { + m.Lock() + defer m.Unlock() + return m.srv.ClientsLen() +} + +// TotalConnectionCount returns the number of client connections since server start. +func (m *Miniredis) TotalConnectionCount() int { + m.Lock() + defer m.Unlock() + return int(m.srv.TotalConnections()) +} + +// FastForward decreases all TTLs by the given duration. All TTLs <= 0 will be +// expired. +func (m *Miniredis) FastForward(duration time.Duration) { + m.Lock() + defer m.Unlock() + for _, db := range m.dbs { + db.fastForward(duration) + } +} + +// Dump returns a text version of the selected DB, usable for debugging. +func (m *Miniredis) Dump() string { + m.Lock() + defer m.Unlock() + + var ( + maxLen = 60 + indent = " " + db = m.db(m.selectedDB) + r = "" + v = func(s string) string { + suffix := "" + if len(s) > maxLen { + suffix = fmt.Sprintf("...(%d)", len(s)) + s = s[:maxLen-len(suffix)] + } + return fmt.Sprintf("%q%s", s, suffix) + } + ) + for _, k := range db.allKeys() { + r += fmt.Sprintf("- %s\n", k) + t := db.t(k) + switch t { + case "string": + r += fmt.Sprintf("%s%s\n", indent, v(db.stringKeys[k])) + case "hash": + for _, hk := range db.hashFields(k) { + r += fmt.Sprintf("%s%s: %s\n", indent, hk, v(db.hashGet(k, hk))) + } + case "list": + for _, lk := range db.listKeys[k] { + r += fmt.Sprintf("%s%s\n", indent, v(lk)) + } + case "set": + for _, mk := range db.setMembers(k) { + r += fmt.Sprintf("%s%s\n", indent, v(mk)) + } + case "zset": + for _, el := range db.ssetElements(k) { + r += fmt.Sprintf("%s%f: %s\n", indent, el.score, v(el.member)) + } + default: + r += fmt.Sprintf("%s(a %s, fixme!)\n", indent, t) + } + } + return r +} + +// SetTime sets the time against which EXPIREAT values are compared. EXPIREAT +// will use time.Now() if this is not set. +func (m *Miniredis) SetTime(t time.Time) { + m.Lock() + defer m.Unlock() + m.now = t +} + +// handleAuth returns false if connection has no access. It sends the reply. +func (m *Miniredis) handleAuth(c *server.Peer) bool { + m.Lock() + defer m.Unlock() + if m.password == "" { + return true + } + if !getCtx(c).authenticated { + c.WriteError("NOAUTH Authentication required.") + return false + } + return true +} + +func getCtx(c *server.Peer) *connCtx { + if c.Ctx == nil { + c.Ctx = &connCtx{} + } + return c.Ctx.(*connCtx) +} + +func startTx(ctx *connCtx) { + ctx.transaction = []txCmd{} + ctx.dirtyTransaction = false +} + +func stopTx(ctx *connCtx) { + ctx.transaction = nil + unwatch(ctx) +} + +func inTx(ctx *connCtx) bool { + return ctx.transaction != nil +} + +func addTxCmd(ctx *connCtx, cb txCmd) { + ctx.transaction = append(ctx.transaction, cb) +} + +func watch(db *RedisDB, ctx *connCtx, key string) { + if ctx.watch == nil { + ctx.watch = map[dbKey]uint{} + } + ctx.watch[dbKey{db: db.id, key: key}] = db.keyVersion[key] // Can be 0. +} + +func unwatch(ctx *connCtx) { + ctx.watch = nil +} + +// setDirty can be called even when not in an tx. Is an no-op then. +func setDirty(c *server.Peer) { + if c.Ctx == nil { + // No transaction. Not relevant. + return + } + getCtx(c).dirtyTransaction = true +} + +func setAuthenticated(c *server.Peer) { + getCtx(c).authenticated = true +} diff --git a/vendor/github.com/alicebob/miniredis/miniredis_test.go b/vendor/github.com/alicebob/miniredis/miniredis_test.go new file mode 100644 index 00000000..7ad811dc --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/miniredis_test.go @@ -0,0 +1,200 @@ +package miniredis + +import ( + "testing" + "time" + + "github.com/garyburd/redigo/redis" +) + +// Test starting/stopping a server +func TestServer(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + _, err = c.Do("PING") + ok(t, err) + + // A single client + equals(t, 1, s.CurrentConnectionCount()) + equals(t, 1, s.TotalConnectionCount()) + equals(t, 1, s.CommandCount()) + _, err = c.Do("PING") + ok(t, err) + equals(t, 2, s.CommandCount()) +} + +func TestMultipleServers(t *testing.T) { + s1, err := Run() + ok(t, err) + s2, err := Run() + ok(t, err) + if s1.Addr() == s2.Addr() { + t.Fatal("Non-unique addresses", s1.Addr(), s2.Addr()) + } + + s2.Close() + s1.Close() + // Closing multiple times is fine + go s1.Close() + go s1.Close() + s1.Close() +} + +func TestRestart(t *testing.T) { + s, err := Run() + ok(t, err) + addr := s.Addr() + + s.Set("color", "red") + + s.Close() + err = s.Restart() + ok(t, err) + if have, want := s.Addr(), addr; have != want { + t.Fatalf("have: %s, want: %s", have, want) + } + + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + _, err = c.Do("PING") + ok(t, err) + + red, err := redis.String(c.Do("GET", "color")) + ok(t, err) + if have, want := red, "red"; have != want { + t.Errorf("have: %s, want: %s", have, want) + } +} + +func TestDump(t *testing.T) { + s, err := Run() + ok(t, err) + s.Set("aap", "noot") + s.Set("vuur", "mies") + s.HSet("ahash", "aap", "noot") + s.HSet("ahash", "vuur", "mies") + if have, want := s.Dump(), `- aap + "noot" +- ahash + aap: "noot" + vuur: "mies" +- vuur + "mies" +`; have != want { + t.Errorf("have: %q, want: %q", have, want) + } + + // Tricky whitespace + s.Select(1) + s.Set("whitespace", "foo\nbar\tbaz!") + if have, want := s.Dump(), `- whitespace + "foo\nbar\tbaz!" +`; have != want { + t.Errorf("have: %q, want: %q", have, want) + } + + // Long key + s.Select(2) + s.Set("long", "This is a rather long key, with some fox jumping over a fence or something.") + s.Set("countonme", "0123456789012345678901234567890123456789012345678901234567890123456789") + s.HSet("hlong", "long", "This is another rather long key, with some fox jumping over a fence or something.") + if have, want := s.Dump(), `- countonme + "01234567890123456789012345678901234567890123456789012"...(70) +- hlong + long: "This is another rather long key, with some fox jumpin"...(81) +- long + "This is a rather long key, with some fox jumping over"...(75) +`; have != want { + t.Errorf("have: %q, want: %q", have, want) + } +} + +func TestDumpList(t *testing.T) { + s, err := Run() + ok(t, err) + s.Push("elements", "earth") + s.Push("elements", "wind") + s.Push("elements", "fire") + if have, want := s.Dump(), `- elements + "earth" + "wind" + "fire" +`; have != want { + t.Errorf("have: %q, want: %q", have, want) + } +} + +func TestDumpSet(t *testing.T) { + s, err := Run() + ok(t, err) + s.SetAdd("elements", "earth") + s.SetAdd("elements", "wind") + s.SetAdd("elements", "fire") + if have, want := s.Dump(), `- elements + "earth" + "fire" + "wind" +`; have != want { + t.Errorf("have: %q, want: %q", have, want) + } +} + +func TestDumpSortedSet(t *testing.T) { + s, err := Run() + ok(t, err) + s.ZAdd("elements", 2.0, "wind") + s.ZAdd("elements", 3.0, "earth") + s.ZAdd("elements", 1.0, "fire") + if have, want := s.Dump(), `- elements + 1.000000: "fire" + 2.000000: "wind" + 3.000000: "earth" +`; have != want { + t.Errorf("have: %q, want: %q", have, want) + } +} + +func TestKeysAndFlush(t *testing.T) { + s, err := Run() + ok(t, err) + s.Set("aap", "noot") + s.Set("vuur", "mies") + s.Set("muur", "oom") + s.HSet("hash", "key", "value") + equals(t, []string{"aap", "hash", "muur", "vuur"}, s.Keys()) + + s.Select(1) + s.Set("1aap", "1noot") + equals(t, []string{"1aap"}, s.Keys()) + + s.Select(0) + s.FlushDB() + equals(t, []string{}, s.Keys()) + s.Select(1) + equals(t, []string{"1aap"}, s.Keys()) + + s.Select(0) + s.FlushAll() + equals(t, []string{}, s.Keys()) + s.Select(1) + equals(t, []string{}, s.Keys()) +} + +func TestExpireWithFastForward(t *testing.T) { + s, err := Run() + ok(t, err) + + s.Set("aap", "noot") + s.Set("noot", "aap") + s.SetTTL("aap", 10*time.Second) + + s.FastForward(5 * time.Second) + equals(t, 2, len(s.Keys())) + + s.FastForward(5 * time.Second) + equals(t, 1, len(s.Keys())) +} diff --git a/vendor/github.com/alicebob/miniredis/redis.go b/vendor/github.com/alicebob/miniredis/redis.go new file mode 100644 index 00000000..80affc36 --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/redis.go @@ -0,0 +1,199 @@ +package miniredis + +import ( + "fmt" + "math" + "strings" + "sync" + "time" + + "github.com/alicebob/miniredis/server" +) + +const ( + msgWrongType = "WRONGTYPE Operation against a key holding the wrong kind of value" + msgInvalidInt = "ERR value is not an integer or out of range" + msgInvalidFloat = "ERR value is not a valid float" + msgInvalidMinMax = "ERR min or max is not a float" + msgInvalidRangeItem = "ERR min or max not valid string range item" + msgInvalidTimeout = "ERR timeout is not an integer or out of range" + msgSyntaxError = "ERR syntax error" + msgKeyNotFound = "ERR no such key" + msgOutOfRange = "ERR index out of range" + msgInvalidCursor = "ERR invalid cursor" + msgXXandNX = "ERR XX and NX options at the same time are not compatible" + msgNegTimeout = "ERR timeout is negative" + msgInvalidSETime = "ERR invalid expire time in set" + msgInvalidSETEXTime = "ERR invalid expire time in setex" + msgInvalidPSETEXTime = "ERR invalid expire time in psetex" +) + +func errWrongNumber(cmd string) string { + return fmt.Sprintf("ERR wrong number of arguments for '%s' command", strings.ToLower(cmd)) +} + +// withTx wraps the non-argument-checking part of command handling code in +// transaction logic. +func withTx( + m *Miniredis, + c *server.Peer, + cb txCmd, +) { + ctx := getCtx(c) + if inTx(ctx) { + addTxCmd(ctx, cb) + c.WriteInline("QUEUED") + return + } + m.Lock() + cb(c, ctx) + // done, wake up anyone who waits on anything. + m.signal.Broadcast() + m.Unlock() +} + +// blockCmd is executed returns whether it is done +type blockCmd func(*server.Peer, *connCtx) bool + +// blocking keeps trying a command until the callback returns true. Calls +// onTimeout after the timeout (or when we call this in a transaction). +func blocking( + m *Miniredis, + c *server.Peer, + timeout time.Duration, + cb blockCmd, + onTimeout func(*server.Peer), +) { + var ( + ctx = getCtx(c) + dl *time.Timer + dlc <-chan time.Time + ) + if inTx(ctx) { + addTxCmd(ctx, func(c *server.Peer, ctx *connCtx) { + if !cb(c, ctx) { + onTimeout(c) + } + }) + c.WriteInline("QUEUED") + return + } + if timeout != 0 { + dl = time.NewTimer(timeout) + defer dl.Stop() + dlc = dl.C + } + + m.Lock() + defer m.Unlock() + for { + done := cb(c, ctx) + if done { + return + } + // there is no cond.WaitTimeout(), so hence the the goroutine to wait + // for a timeout + var ( + wg sync.WaitGroup + wakeup = make(chan struct{}, 1) + ) + wg.Add(1) + go func() { + m.signal.Wait() + wakeup <- struct{}{} + wg.Done() + }() + select { + case <-wakeup: + case <-dlc: + onTimeout(c) + m.signal.Broadcast() // to kill the wakeup go routine + wg.Wait() + return + } + wg.Wait() + } +} + +// formatFloat formats a float the way redis does (sort-of) +func formatFloat(v float64) string { + // Format with %f and strip trailing 0s. This is the most like Redis does + // it :( + // .12 is the magic number where most output is the same as Redis. + if math.IsInf(v, +1) { + return "inf" + } + if math.IsInf(v, -1) { + return "-inf" + } + sv := fmt.Sprintf("%.12f", v) + for strings.Contains(sv, ".") { + if sv[len(sv)-1] != '0' { + break + } + // Remove trailing 0s. + sv = sv[:len(sv)-1] + // Ends with a '.'. + if sv[len(sv)-1] == '.' { + sv = sv[:len(sv)-1] + break + } + } + return sv +} + +// redisRange gives Go offsets for something l long with start/end in +// Redis semantics. Both start and end can be negative. +// Used for string range and list range things. +// The results can be used as: v[start:end] +// Note that GETRANGE (on a string key) never returns an empty string when end +// is a large negative number. +func redisRange(l, start, end int, stringSymantics bool) (int, int) { + if start < 0 { + start = l + start + if start < 0 { + start = 0 + } + } + if start > l { + start = l + } + + if end < 0 { + end = l + end + if end < 0 { + end = -1 + if stringSymantics { + end = 0 + } + } + } + end++ // end argument is inclusive in Redis. + if end > l { + end = l + } + + if end < start { + return 0, 0 + } + return start, end +} + +// matchKeys filters only matching keys. +// Will return an empty list on invalid match expression. +func matchKeys(keys []string, match string) []string { + re := patternRE(match) + if re == nil { + // Special case, the given pattern won't match anything / is + // invalid. + return nil + } + res := []string{} + for _, k := range keys { + if !re.MatchString(k) { + continue + } + res = append(res, k) + } + return res +} diff --git a/vendor/github.com/alicebob/miniredis/server/Makefile b/vendor/github.com/alicebob/miniredis/server/Makefile new file mode 100644 index 00000000..c82e336f --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/server/Makefile @@ -0,0 +1,9 @@ +.PHONY: all build test + +all: build test + +build: + go build + +test: + go test diff --git a/vendor/github.com/alicebob/miniredis/server/proto.go b/vendor/github.com/alicebob/miniredis/server/proto.go new file mode 100644 index 00000000..208813cb --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/server/proto.go @@ -0,0 +1,79 @@ +package server + +import ( + "bufio" + "errors" + "strconv" +) + +// ErrProtocol is the general error for unexpected input +var ErrProtocol = errors.New("invalid request") + +// client always sends arrays with bulk strings +func readArray(rd *bufio.Reader) ([]string, error) { + line, err := rd.ReadString('\n') + if err != nil { + return nil, err + } + if len(line) < 3 { + return nil, ErrProtocol + } + + switch line[0] { + default: + return nil, ErrProtocol + case '*': + l, err := strconv.Atoi(line[1 : len(line)-2]) + if err != nil { + return nil, err + } + // l can be -1 + var fields []string + for ; l > 0; l-- { + s, err := readString(rd) + if err != nil { + return nil, err + } + fields = append(fields, s) + } + return fields, nil + } +} + +func readString(rd *bufio.Reader) (string, error) { + line, err := rd.ReadString('\n') + if err != nil { + return "", err + } + if len(line) < 3 { + return "", ErrProtocol + } + + switch line[0] { + default: + return "", ErrProtocol + case '+', '-', ':': + // +: simple string + // -: errors + // :: integer + // Simple line based replies. + return string(line[1 : len(line)-2]), nil + case '$': + // bulk strings are: `$5\r\nhello\r\n` + length, err := strconv.Atoi(line[1 : len(line)-2]) + if err != nil { + return "", err + } + if length < 0 { + // -1 is a nil response + return "", nil + } + buf := make([]byte, length+2) + if n, err := rd.Read(buf); err != nil { + return "", err + } else if n != length+2 { + return "", ErrProtocol + } + return string(buf[:len(buf)-2]), nil + } +} diff --git a/vendor/github.com/alicebob/miniredis/server/proto_test.go b/vendor/github.com/alicebob/miniredis/server/proto_test.go new file mode 100644 index 00000000..9cf89cb1 --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/server/proto_test.go @@ -0,0 +1,50 @@ +package server + +import ( + "bufio" + "bytes" + "io" + "reflect" + "testing" +) + +func TestReadArray(t *testing.T) { + type cas struct { + payload string + err error + res []string + } + for i, c := range []cas{ + { + payload: "*1\r\n$4\r\nPING\r\n", + res: []string{"PING"}, + }, + { + payload: "*2\r\n$4\r\nLLEN\r\n$6\r\nmylist\r\n", + res: []string{"LLEN", "mylist"}, + }, + { + payload: "*2\r\n$4\r\nLLEN\r\n$6\r\nmyl", + err: ErrProtocol, + }, + { + payload: "PING", + err: io.EOF, + }, + { + payload: "*0\r\n", + }, + { + payload: "*-1\r\n", // not sure this is legal in a request + }, + } { + res, err := readArray(bufio.NewReader(bytes.NewBufferString(c.payload))) + if have, want := err, c.err; have != want { + t.Errorf("err %d: have %v, want %v", i, have, want) + continue + } + if have, want := res, c.res; !reflect.DeepEqual(have, want) { + t.Errorf("case %d: have %v, want %v", i, have, want) + } + } +} diff --git a/vendor/github.com/alicebob/miniredis/server/server.go b/vendor/github.com/alicebob/miniredis/server/server.go new file mode 100644 index 00000000..9123017a --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/server/server.go @@ -0,0 +1,219 @@ +package server + +import ( + "bufio" + "fmt" + "net" + "strings" + "sync" +) + +func errUnknownCommand(cmd string) string { + return fmt.Sprintf("ERR unknown command '%s'", strings.ToLower(cmd)) +} + +// Cmd is what Register expects +type Cmd func(c *Peer, cmd string, args []string) + +// Server is a simple redis server +type Server struct { + l net.Listener + cmds map[string]Cmd + peers map[net.Conn]struct{} + mu sync.Mutex + wg sync.WaitGroup + infoConns int + infoCmds int +} + +// NewServer makes a server listening on addr. Close with .Close(). +func NewServer(addr string) (*Server, error) { + s := Server{ + cmds: map[string]Cmd{}, + peers: map[net.Conn]struct{}{}, + } + + l, err := net.Listen("tcp", addr) + if err != nil { + return nil, err + } + s.l = l + + s.wg.Add(1) + go func() { + defer s.wg.Done() + s.serve(l) + }() + return &s, nil +} + +func (s *Server) serve(l net.Listener) { + for { + conn, err := l.Accept() + if err != nil { + return + } + s.wg.Add(1) + go func() { + defer s.wg.Done() + defer conn.Close() + s.mu.Lock() + s.peers[conn] = struct{}{} + s.infoConns++ + s.mu.Unlock() + + s.servePeer(conn) + + s.mu.Lock() + delete(s.peers, conn) + s.mu.Unlock() + }() + } +} + +// Addr has the net.Addr struct +func (s *Server) Addr() *net.TCPAddr { + s.mu.Lock() + defer s.mu.Unlock() + if s.l == nil { + return nil + } + return s.l.Addr().(*net.TCPAddr) +} + +// Close a server started with NewServer. It will wait until all clients are +// closed. +func (s *Server) Close() { + s.mu.Lock() + if s.l != nil { + s.l.Close() + } + s.l = nil + for c := range s.peers { + c.Close() + } + s.mu.Unlock() + s.wg.Wait() +} + +// Register a command. It can't have been registered before. Safe to call on a +// running server. +func (s *Server) Register(cmd string, f Cmd) error { + s.mu.Lock() + defer s.mu.Unlock() + cmd = strings.ToUpper(cmd) + if _, ok := s.cmds[cmd]; ok { + return fmt.Errorf("command already registered: %s", cmd) + } + s.cmds[cmd] = f + return nil +} + +func (s *Server) servePeer(c net.Conn) { + r := bufio.NewReader(c) + cl := &Peer{ + w: bufio.NewWriter(c), + } + for { + args, err := readArray(r) + if err != nil { + return + } + s.dispatch(cl, args) + cl.w.Flush() + if cl.closed { + c.Close() + return + } + } +} + +func (s *Server) dispatch(c *Peer, args []string) { + cmd, args := strings.ToUpper(args[0]), args[1:] + s.mu.Lock() + cb, ok := s.cmds[cmd] + s.mu.Unlock() + if !ok { + c.WriteError(errUnknownCommand(cmd)) + return + } + + s.mu.Lock() + s.infoCmds++ + s.mu.Unlock() + cb(c, cmd, args) +} + +// TotalCommands is total (known) commands since this the server started +func (s *Server) TotalCommands() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.infoCmds +} + +// ClientsLen gives the number of connected clients right now +func (s *Server) ClientsLen() int { + s.mu.Lock() + defer s.mu.Unlock() + return len(s.peers) +} + +// TotalConnections give the number of clients connected since the server +// started, including the currently connected ones +func (s *Server) TotalConnections() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.infoConns +} + +// Peer is a client connected to the server +type Peer struct { + w *bufio.Writer + closed bool + Ctx interface{} // anything goes, server won't touch this +} + +// Flush the write buffer. Called automatically after every redis command +func (c *Peer) Flush() { + c.w.Flush() +} + +// Close the client connection after the current command is done. +func (c *Peer) Close() { + c.closed = true +} + +// WriteError writes a redis 'Error' +func (c *Peer) WriteError(e string) { + fmt.Fprintf(c.w, "-%s\r\n", e) +} + +// WriteInline writes a redis inline string +func (c *Peer) WriteInline(s string) { + fmt.Fprintf(c.w, "+%s\r\n", s) +} + +// WriteOK write the inline string `OK` +func (c *Peer) WriteOK() { + c.WriteInline("OK") +} + +// WriteBulk writes a bulk string +func (c *Peer) WriteBulk(s string) { + fmt.Fprintf(c.w, "$%d\r\n%s\r\n", len(s), s) +} + +// WriteNull writes a redis Null element +func (c *Peer) WriteNull() { + fmt.Fprintf(c.w, "$-1\r\n") +} + +// WriteLen starts an array with the given length +func (c *Peer) WriteLen(n int) { + fmt.Fprintf(c.w, "*%d\r\n", n) +} + +// WriteInt writes an integer +func (c *Peer) WriteInt(i int) { + fmt.Fprintf(c.w, ":%d\r\n", i) +} diff --git a/vendor/github.com/alicebob/miniredis/server/server_test.go b/vendor/github.com/alicebob/miniredis/server/server_test.go new file mode 100644 index 00000000..064662c5 --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/server/server_test.go @@ -0,0 +1,147 @@ +package server + +import ( + "fmt" + "reflect" + "strconv" + "testing" + + "github.com/garyburd/redigo/redis" +) + +const ( + errWrongNumberOfArgs = "ERR Wrong number of args" +) + +func Test(t *testing.T) { + s, err := NewServer(":0") + if err != nil { + t.Fatal(err) + } + defer s.Close() + + if have := s.Addr().Port; have <= 0 { + t.Fatalf("have %v, want > 0", have) + } + + s.Register("PING", func(c *Peer, cmd string, args []string) { + c.WriteInline("PONG") + }) + s.Register("ECHO", func(c *Peer, cmd string, args []string) { + if len(args) != 1 { + c.WriteError(errWrongNumberOfArgs) + return + } + c.WriteBulk(args[0]) + }) + s.Register("dWaRfS", func(c *Peer, cmd string, args []string) { + if len(args) != 0 { + c.WriteError(errWrongNumberOfArgs) + return + } + c.WriteLen(7) + c.WriteBulk("Blick") + c.WriteBulk("Flick") + c.WriteBulk("Glick") + c.WriteBulk("Plick") + c.WriteBulk("Quee") + c.WriteBulk("Snick") + c.WriteBulk("Whick") + }) + s.Register("PLUS", func(c *Peer, cmd string, args []string) { + if len(args) != 2 { + c.WriteError(errWrongNumberOfArgs) + return + } + a, err := strconv.Atoi(args[0]) + if err != nil { + c.WriteError(fmt.Sprintf("ERR not an int: %q", args[0])) + return + } + b, err := strconv.Atoi(args[1]) + if err != nil { + c.WriteError(fmt.Sprintf("ERR not an int: %q", args[1])) + return + } + c.WriteInt(a + b) + }) + s.Register("NULL", func(c *Peer, cmd string, args []string) { + c.WriteNull() + }) + + c, err := redis.Dial("tcp", s.Addr().String()) + if err != nil { + t.Fatal(err) + } + + { + res, err := redis.String(c.Do("PING")) + if err != nil { + t.Fatal(err) + } + if have, want := res, "PONG"; have != want { + t.Errorf("have: %s, want: %s", have, want) + } + } + + { + _, err := c.Do("NOSUCH") + if have, want := err.Error(), "ERR unknown command 'nosuch'"; have != want { + t.Errorf("have: %s, want: %s", have, want) + } + } + + { + res, err := redis.String(c.Do("pInG")) + if err != nil { + t.Fatal(err) + } + if have, want := res, "PONG"; have != want { + t.Errorf("have: %s, want: %s", have, want) + } + } + + { + echo, err := redis.String(c.Do("ECHO", "hello\nworld")) + if err != nil { + t.Fatal(err) + } + if have, want := echo, "hello\nworld"; have != want { + t.Errorf("have: %s, want: %s", have, want) + } + } + + { + _, err := c.Do("ECHO") + if have, want := err.Error(), errWrongNumberOfArgs; have != want { + t.Errorf("have: %s, want: %s", have, want) + } + } + + { + dwarfs, err := redis.Strings(c.Do("dwaRFS")) + if err != nil { + t.Fatal(err) + } + if have, want := dwarfs, []string{"Blick", + "Flick", + "Glick", + "Plick", + "Quee", + "Snick", + "Whick", + }; !reflect.DeepEqual(have, want) { + t.Errorf("have: %s, want: %s", have, want) + } + } + + { + res, err := c.Do("NULL") + if err != nil { + t.Fatal(err) + } + if have, want := res, interface{}(nil); have != want { + t.Errorf("have: %s, want: %s", have, want) + } + } +} diff --git a/vendor/github.com/alicebob/miniredis/sorted_set.go b/vendor/github.com/alicebob/miniredis/sorted_set.go new file mode 100644 index 00000000..9b1894d8 --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/sorted_set.go @@ -0,0 +1,97 @@ +package miniredis + +// The most KISS way to implement a sorted set. Luckily we don't care about +// performance that much. + +import ( + "sort" +) + +type direction int + +const ( + asc direction = iota + desc +) + +type sortedSet map[string]float64 + +type ssElem struct { + score float64 + member string +} +type ssElems []ssElem + +type byScore ssElems + +func (sse byScore) Len() int { return len(sse) } +func (sse byScore) Swap(i, j int) { sse[i], sse[j] = sse[j], sse[i] } +func (sse byScore) Less(i, j int) bool { + if sse[i].score != sse[j].score { + return sse[i].score < sse[j].score + } + return sse[i].member < sse[j].member +} + +func newSortedSet() sortedSet { + return sortedSet{} +} + +func (ss *sortedSet) card() int { + return len(*ss) +} + +func (ss *sortedSet) set(score float64, member string) { + (*ss)[member] = score +} + +func (ss *sortedSet) get(member string) (float64, bool) { + v, ok := (*ss)[member] + return v, ok +} + +// elems gives the list of ssElem, ready to sort. +func (ss *sortedSet) elems() ssElems { + elems := make(ssElems, 0, len(*ss)) + for e, s := range *ss { + elems = append(elems, ssElem{s, e}) + } + return elems +} + +func (ss *sortedSet) byScore(d direction) ssElems { + elems := ss.elems() + sort.Sort(byScore(elems)) + if d == desc { + reverseElems(elems) + } + return ssElems(elems) +} + +// rankByScore gives the (0-based) index of member, or returns false. +func (ss *sortedSet) rankByScore(member string, d direction) (int, bool) { + if _, ok := (*ss)[member]; !ok { + return 0, false + } + for i, e := range ss.byScore(d) { + if e.member == member { + return i, true + } + } + // Can't happen + return 0, false +} + +func reverseSlice(o []string) { + for i := range make([]struct{}, len(o)/2) { + other := len(o) - 1 - i + o[i], o[other] = o[other], o[i] + } +} + +func reverseElems(o ssElems) { + for i := range make([]struct{}, len(o)/2) { + other := len(o) - 1 - i + o[i], o[other] = o[other], o[i] + } +} diff --git a/vendor/github.com/alicebob/miniredis/sorted_set_test.go b/vendor/github.com/alicebob/miniredis/sorted_set_test.go new file mode 100644 index 00000000..a29101bd --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/sorted_set_test.go @@ -0,0 +1,86 @@ +package miniredis + +import ( + "testing" +) + +func TestSortedSetImpl(t *testing.T) { + s := newSortedSet() + equals(t, 0, s.card()) + s.set(3.1415, "pi") + s.set(2*3.1415, "2pi") + s.set(3*3.1415, "3pi") + equals(t, 3, s.card()) + // replace works? + s.set(3.141592, "pi") + equals(t, 3, s.card()) + + // Get a key + { + pi, ok := s.get("pi") + assert(t, ok, "got pi") + equals(t, 3.141592, pi) + } + + // Set ordered by score + { + elems := s.byScore(asc) + equals(t, 3, len(elems)) + equals(t, ssElems{ + {3.141592, "pi"}, + {2 * 3.1415, "2pi"}, + {3 * 3.1415, "3pi"}, + }, elems) + } + + // Rank of a key + { + rank, found := s.rankByScore("pi", asc) + assert(t, found, "Found pi") + equals(t, 0, rank) + + rank, found = s.rankByScore("3pi", desc) + assert(t, found, "Found 3pi") + equals(t, 0, rank) + + rank, found = s.rankByScore("3pi", asc) + assert(t, found, "Found 3pi") + equals(t, 2, rank) + + _, found = s.rankByScore("nosuch", asc) + assert(t, !found, "Did not find nosuch") + } +} + +func TestSortOrder(t *testing.T) { + // Keys with the same key should be sorted lexicographically + s := newSortedSet() + equals(t, 0, s.card()) + s.set(1, "one") + s.set(1, "1") + s.set(1, "eins") + s.set(2, "two") + s.set(2, "2") + s.set(2, "zwei") + s.set(3, "three") + s.set(3, "3") + s.set(3, "drei") + equals(t, 9, s.card()) + + // Set ordered by score, member + { + elems := s.byScore(asc) + equals(t, 9, len(elems)) + equals(t, ssElems{ + {1, "1"}, + {1, "eins"}, + {1, "one"}, + {2, "2"}, + {2, "two"}, + {2, "zwei"}, + {3, "3"}, + {3, "drei"}, + {3, "three"}, + }, elems) + } +} diff --git a/vendor/github.com/alicebob/miniredis/test.go b/vendor/github.com/alicebob/miniredis/test.go new file mode 100644 index 00000000..06d0ffbd --- /dev/null +++ b/vendor/github.com/alicebob/miniredis/test.go @@ -0,0 +1,36 @@ +package miniredis + +import ( + "fmt" + "path/filepath" + "reflect" + "runtime" + "testing" +) + +// assert fails the test if the condition is false. +func assert(tb testing.TB, condition bool, msg string, v ...interface{}) { + if !condition { + _, file, line, _ := runtime.Caller(1) + fmt.Printf("%s:%d: "+msg+"\n", append([]interface{}{filepath.Base(file), line}, v...)...) + tb.FailNow() + } +} + +// ok fails the test if an err is not nil. +func ok(tb testing.TB, err error) { + if err != nil { + _, file, line, _ := runtime.Caller(1) + fmt.Printf("%s:%d: unexpected error: %s\n", filepath.Base(file), line, err.Error()) + tb.FailNow() + } +} + +// equals fails the test if exp is not equal to act. +func equals(tb testing.TB, exp, act interface{}) { + if !reflect.DeepEqual(exp, act) { + _, file, line, _ := runtime.Caller(1) + fmt.Printf("%s:%d: expected: %#v got: %#v\n", filepath.Base(file), line, exp, act) + tb.FailNow() + } +} diff --git a/vendor/github.com/beorn7/perks/.gitignore b/vendor/github.com/beorn7/perks/.gitignore new file mode 100644 index 00000000..1bd9209a --- /dev/null +++ b/vendor/github.com/beorn7/perks/.gitignore @@ -0,0 +1,2 @@ +*.test +*.prof diff --git a/vendor/github.com/beorn7/perks/LICENSE b/vendor/github.com/beorn7/perks/LICENSE new file mode 100644 index 00000000..339177be --- /dev/null +++ b/vendor/github.com/beorn7/perks/LICENSE @@ -0,0 +1,20 @@ +Copyright (C) 2013 Blake Mizerany + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/beorn7/perks/README.md b/vendor/github.com/beorn7/perks/README.md new file mode 100644 index 00000000..fc057777 --- /dev/null +++ b/vendor/github.com/beorn7/perks/README.md @@ -0,0 +1,31 @@ +# Perks for Go (golang.org) + +Perks contains the Go package quantile that computes approximate quantiles over +an unbounded data stream within low memory and CPU bounds. + +For more information and examples, see: +http://godoc.org/github.com/bmizerany/perks + +A very special thank you and shout out to Graham Cormode (Rutgers University), +Flip Korn (AT&T Labs–Research), S. Muthukrishnan (Rutgers University), and +Divesh Srivastava (AT&T Labs–Research) for their research and publication of +[Effective Computation of Biased Quantiles over Data Streams](http://www.cs.rutgers.edu/~muthu/bquant.pdf) + +Thank you, also: +* Armon Dadgar (@armon) +* Andrew Gerrand (@nf) +* Brad Fitzpatrick (@bradfitz) +* Keith Rarick (@kr) + +FAQ: + +Q: Why not move the quantile package into the project root? +A: I want to add more packages to perks later. + +Copyright (C) 2013 Blake Mizerany + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/beorn7/perks/histogram/bench_test.go b/vendor/github.com/beorn7/perks/histogram/bench_test.go new file mode 100644 index 00000000..56c7e551 --- /dev/null +++ b/vendor/github.com/beorn7/perks/histogram/bench_test.go @@ -0,0 +1,26 @@ +package histogram + +import ( + "math/rand" + "testing" +) + +func BenchmarkInsert10Bins(b *testing.B) { + b.StopTimer() + h := New(10) + b.StartTimer() + for i := 0; i < b.N; i++ { + f := rand.ExpFloat64() + h.Insert(f) + } +} + +func BenchmarkInsert100Bins(b *testing.B) { + b.StopTimer() + h := New(100) + b.StartTimer() + for i := 0; i < b.N; i++ { + f := rand.ExpFloat64() + h.Insert(f) + } +} diff --git a/vendor/github.com/beorn7/perks/histogram/histogram.go b/vendor/github.com/beorn7/perks/histogram/histogram.go new file mode 100644 index 00000000..bef05c70 --- /dev/null +++ b/vendor/github.com/beorn7/perks/histogram/histogram.go @@ -0,0 +1,108 @@ +// Package histogram provides a Go implementation of BigML's histogram package +// for Clojure/Java. It is currently experimental. +package histogram + +import ( + "container/heap" + "math" + "sort" +) + +type Bin struct { + Count int + Sum float64 +} + +func (b *Bin) Update(x *Bin) { + b.Count += x.Count + b.Sum += x.Sum +} + +func (b *Bin) Mean() float64 { + return b.Sum / float64(b.Count) +} + +type Bins []*Bin + +func (bs Bins) Len() int { return len(bs) } +func (bs Bins) Less(i, j int) bool { return bs[i].Mean() < bs[j].Mean() } +func (bs Bins) Swap(i, j int) { bs[i], bs[j] = bs[j], bs[i] } + +func (bs *Bins) Push(x interface{}) { + *bs = append(*bs, x.(*Bin)) +} + +func (bs *Bins) Pop() interface{} { + return bs.remove(len(*bs) - 1) +} + +func (bs *Bins) remove(n int) *Bin { + if n < 0 || len(*bs) < n { + return nil + } + x := (*bs)[n] + *bs = append((*bs)[:n], (*bs)[n+1:]...) + return x +} + +type Histogram struct { + res *reservoir +} + +func New(maxBins int) *Histogram { + return &Histogram{res: newReservoir(maxBins)} +} + +func (h *Histogram) Insert(f float64) { + h.res.insert(&Bin{1, f}) + h.res.compress() +} + +func (h *Histogram) Bins() Bins { + return h.res.bins +} + +type reservoir struct { + n int + maxBins int + bins Bins +} + +func newReservoir(maxBins int) *reservoir { + return &reservoir{maxBins: maxBins} +} + +func (r *reservoir) insert(bin *Bin) { + r.n += bin.Count + i := sort.Search(len(r.bins), func(i int) bool { + return r.bins[i].Mean() >= bin.Mean() + }) + if i < 0 || i == r.bins.Len() { + // TODO(blake): Maybe use an .insert(i, bin) instead of + // performing the extra work of a heap.Push. + heap.Push(&r.bins, bin) + return + } + r.bins[i].Update(bin) +} + +func (r *reservoir) compress() { + for r.bins.Len() > r.maxBins { + minGapIndex := -1 + minGap := math.MaxFloat64 + for i := 0; i < r.bins.Len()-1; i++ { + gap := gapWeight(r.bins[i], r.bins[i+1]) + if minGap > gap { + minGap = gap + minGapIndex = i + } + } + prev := r.bins[minGapIndex] + next := r.bins.remove(minGapIndex + 1) + prev.Update(next) + } +} + +func gapWeight(prev, next *Bin) float64 { + return next.Mean() - prev.Mean() +} diff --git a/vendor/github.com/beorn7/perks/histogram/histogram_test.go b/vendor/github.com/beorn7/perks/histogram/histogram_test.go new file mode 100644 index 00000000..0575ebee --- /dev/null +++ b/vendor/github.com/beorn7/perks/histogram/histogram_test.go @@ -0,0 +1,38 @@ +package histogram + +import ( + "math/rand" + "testing" +) + +func TestHistogram(t *testing.T) { + const numPoints = 1e6 + const maxBins = 3 + + h := New(maxBins) + for i := 0; i < numPoints; i++ { + f := rand.ExpFloat64() + h.Insert(f) + } + + bins := h.Bins() + if g := len(bins); g > maxBins { + t.Fatalf("got %d bins, wanted <= %d", g, maxBins) + } + + for _, b := range bins { + t.Logf("%+v", b) + } + + if g := count(h.Bins()); g != numPoints { + t.Fatalf("binned %d points, wanted %d", g, numPoints) + } +} + +func count(bins Bins) int { + binCounts := 0 + for _, b := range bins { + binCounts += b.Count + } + return binCounts +} diff --git a/vendor/github.com/beorn7/perks/quantile/bench_test.go b/vendor/github.com/beorn7/perks/quantile/bench_test.go new file mode 100644 index 00000000..0bd0e4e7 --- /dev/null +++ b/vendor/github.com/beorn7/perks/quantile/bench_test.go @@ -0,0 +1,63 @@ +package quantile + +import ( + "testing" +) + +func BenchmarkInsertTargeted(b *testing.B) { + b.ReportAllocs() + + s := NewTargeted(Targets) + b.ResetTimer() + for i := float64(0); i < float64(b.N); i++ { + s.Insert(i) + } +} + +func BenchmarkInsertTargetedSmallEpsilon(b *testing.B) { + s := NewTargeted(TargetsSmallEpsilon) + b.ResetTimer() + for i := float64(0); i < float64(b.N); i++ { + s.Insert(i) + } +} + +func BenchmarkInsertBiased(b *testing.B) { + s := NewLowBiased(0.01) + b.ResetTimer() + for i := float64(0); i < float64(b.N); i++ { + s.Insert(i) + } +} + +func BenchmarkInsertBiasedSmallEpsilon(b *testing.B) { + s := NewLowBiased(0.0001) + b.ResetTimer() + for i := float64(0); i < float64(b.N); i++ { + s.Insert(i) + } +} + +func BenchmarkQuery(b *testing.B) { + s := NewTargeted(Targets) + for i := float64(0); i < 1e6; i++ { + s.Insert(i) + } + b.ResetTimer() + n := float64(b.N) + for i := float64(0); i < n; i++ { + s.Query(i / n) + } +} + +func BenchmarkQuerySmallEpsilon(b *testing.B) { + s := NewTargeted(TargetsSmallEpsilon) + for i := float64(0); i < 1e6; i++ { + s.Insert(i) + } + b.ResetTimer() + n := float64(b.N) + for i := float64(0); i < n; i++ { + s.Query(i / n) + } +} diff --git a/vendor/github.com/beorn7/perks/quantile/example_test.go b/vendor/github.com/beorn7/perks/quantile/example_test.go new file mode 100644 index 00000000..ab3293aa --- /dev/null +++ b/vendor/github.com/beorn7/perks/quantile/example_test.go @@ -0,0 +1,121 @@ +// +build go1.1 + +package quantile_test + +import ( + "bufio" + "fmt" + "log" + "os" + "strconv" + "time" + + "github.com/beorn7/perks/quantile" +) + +func Example_simple() { + ch := make(chan float64) + go sendFloats(ch) + + // Compute the 50th, 90th, and 99th percentile. + q := quantile.NewTargeted(map[float64]float64{ + 0.50: 0.005, + 0.90: 0.001, + 0.99: 0.0001, + }) + for v := range ch { + q.Insert(v) + } + + fmt.Println("perc50:", q.Query(0.50)) + fmt.Println("perc90:", q.Query(0.90)) + fmt.Println("perc99:", q.Query(0.99)) + fmt.Println("count:", q.Count()) + // Output: + // perc50: 5 + // perc90: 16 + // perc99: 223 + // count: 2388 +} + +func Example_mergeMultipleStreams() { + // Scenario: + // We have multiple database shards. On each shard, there is a process + // collecting query response times from the database logs and inserting + // them into a Stream (created via NewTargeted(0.90)), much like the + // Simple example. These processes expose a network interface for us to + // ask them to serialize and send us the results of their + // Stream.Samples so we may Merge and Query them. + // + // NOTES: + // * These sample sets are small, allowing us to get them + // across the network much faster than sending the entire list of data + // points. + // + // * For this to work correctly, we must supply the same quantiles + // a priori the process collecting the samples supplied to NewTargeted, + // even if we do not plan to query them all here. + ch := make(chan quantile.Samples) + getDBQuerySamples(ch) + q := quantile.NewTargeted(map[float64]float64{0.90: 0.001}) + for samples := range ch { + q.Merge(samples) + } + fmt.Println("perc90:", q.Query(0.90)) +} + +func Example_window() { + // Scenario: We want the 90th, 95th, and 99th percentiles for each + // minute. + + ch := make(chan float64) + go sendStreamValues(ch) + + tick := time.NewTicker(1 * time.Minute) + q := quantile.NewTargeted(map[float64]float64{ + 0.90: 0.001, + 0.95: 0.0005, + 0.99: 0.0001, + }) + for { + select { + case t := <-tick.C: + flushToDB(t, q.Samples()) + q.Reset() + case v := <-ch: + q.Insert(v) + } + } +} + +func sendStreamValues(ch chan float64) { + // Use your imagination +} + +func flushToDB(t time.Time, samples quantile.Samples) { + // Use your imagination +} + +// This is a stub for the above example. In reality this would hit the remote +// servers via http or something like it. +func getDBQuerySamples(ch chan quantile.Samples) {} + +func sendFloats(ch chan<- float64) { + f, err := os.Open("exampledata.txt") + if err != nil { + log.Fatal(err) + } + sc := bufio.NewScanner(f) + for sc.Scan() { + b := sc.Bytes() + v, err := strconv.ParseFloat(string(b), 64) + if err != nil { + log.Fatal(err) + } + ch <- v + } + if sc.Err() != nil { + log.Fatal(sc.Err()) + } + close(ch) +} diff --git a/vendor/github.com/beorn7/perks/quantile/exampledata.txt b/vendor/github.com/beorn7/perks/quantile/exampledata.txt new file mode 100644 index 00000000..1602287d --- /dev/null +++ b/vendor/github.com/beorn7/perks/quantile/exampledata.txt @@ -0,0 +1,2388 @@ +8 +5 +26 +12 +5 +235 +13 +6 +28 +30 +3 +3 +3 +3 +5 +2 +33 +7 +2 +4 +7 +12 +14 +5 +8 +3 +10 +4 +5 +3 +6 +6 +209 +20 +3 +10 +14 +3 +4 +6 +8 +5 +11 +7 +3 +2 +3 +3 +212 +5 +222 +4 +10 +10 +5 +6 +3 +8 +3 +10 +254 +220 +2 +3 +5 +24 +5 +4 +222 +7 +3 +3 +223 +8 +15 +12 +14 +14 +3 +2 +2 +3 +13 +3 +11 +4 +4 +6 +5 +7 +13 +5 +3 +5 +2 +5 +3 +5 +2 +7 +15 +17 +14 +3 +6 +6 +3 +17 +5 +4 +7 +6 +4 +4 +8 +6 +8 +3 +9 +3 +6 +3 +4 +5 +3 +3 +660 +4 +6 +10 +3 +6 +3 +2 +5 +13 +2 +4 +4 +10 +4 +8 +4 +3 +7 +9 +9 +3 +10 +37 +3 +13 +4 +12 +3 +6 +10 +8 +5 +21 +2 +3 +8 +3 +2 +3 +3 +4 +12 +2 +4 +8 +8 +4 +3 +2 +20 +1 +6 +32 +2 +11 +6 +18 +3 +8 +11 +3 +212 +3 +4 +2 +6 +7 +12 +11 +3 +2 +16 +10 +6 +4 +6 +3 +2 +7 +3 +2 +2 +2 +2 +5 +6 +4 +3 +10 +3 +4 +6 +5 +3 +4 +4 +5 +6 +4 +3 +4 +4 +5 +7 +5 +5 +3 +2 +7 +2 +4 +12 +4 +5 +6 +2 +4 +4 +8 +4 +15 +13 +7 +16 +5 +3 +23 +5 +5 +7 +3 +2 +9 +8 +7 +5 +8 +11 +4 +10 +76 +4 +47 +4 +3 +2 +7 +4 +2 +3 +37 +10 +4 +2 +20 +5 +4 +4 +10 +10 +4 +3 +7 +23 +240 +7 +13 +5 +5 +3 +3 +2 +5 +4 +2 +8 +7 +19 +2 +23 +8 +7 +2 +5 +3 +8 +3 +8 +13 +5 +5 +5 +2 +3 +23 +4 +9 +8 +4 +3 +3 +5 +220 +2 +3 +4 +6 +14 +3 +53 +6 +2 +5 +18 +6 +3 +219 +6 +5 +2 +5 +3 +6 +5 +15 +4 +3 +17 +3 +2 +4 +7 +2 +3 +3 +4 +4 +3 +2 +664 +6 +3 +23 +5 +5 +16 +5 +8 +2 +4 +2 +24 +12 +3 +2 +3 +5 +8 +3 +5 +4 +3 +14 +3 +5 +8 +2 +3 +7 +9 +4 +2 +3 +6 +8 +4 +3 +4 +6 +5 +3 +3 +6 +3 +19 +4 +4 +6 +3 +6 +3 +5 +22 +5 +4 +4 +3 +8 +11 +4 +9 +7 +6 +13 +4 +4 +4 +6 +17 +9 +3 +3 +3 +4 +3 +221 +5 +11 +3 +4 +2 +12 +6 +3 +5 +7 +5 +7 +4 +9 +7 +14 +37 +19 +217 +16 +3 +5 +2 +2 +7 +19 +7 +6 +7 +4 +24 +5 +11 +4 +7 +7 +9 +13 +3 +4 +3 +6 +28 +4 +4 +5 +5 +2 +5 +6 +4 +4 +6 +10 +5 +4 +3 +2 +3 +3 +6 +5 +5 +4 +3 +2 +3 +7 +4 +6 +18 +16 +8 +16 +4 +5 +8 +6 +9 +13 +1545 +6 +215 +6 +5 +6 +3 +45 +31 +5 +2 +2 +4 +3 +3 +2 +5 +4 +3 +5 +7 +7 +4 +5 +8 +5 +4 +749 +2 +31 +9 +11 +2 +11 +5 +4 +4 +7 +9 +11 +4 +5 +4 +7 +3 +4 +6 +2 +15 +3 +4 +3 +4 +3 +5 +2 +13 +5 +5 +3 +3 +23 +4 +4 +5 +7 +4 +13 +2 +4 +3 +4 +2 +6 +2 +7 +3 +5 +5 +3 +29 +5 +4 +4 +3 +10 +2 +3 +79 +16 +6 +6 +7 +7 +3 +5 +5 +7 +4 +3 +7 +9 +5 +6 +5 +9 +6 +3 +6 +4 +17 +2 +10 +9 +3 +6 +2 +3 +21 +22 +5 +11 +4 +2 +17 +2 +224 +2 +14 +3 +4 +4 +2 +4 +4 +4 +4 +5 +3 +4 +4 +10 +2 +6 +3 +3 +5 +7 +2 +7 +5 +6 +3 +218 +2 +2 +5 +2 +6 +3 +5 +222 +14 +6 +33 +3 +2 +5 +3 +3 +3 +9 +5 +3 +3 +2 +7 +4 +3 +4 +3 +5 +6 +5 +26 +4 +13 +9 +7 +3 +221 +3 +3 +4 +4 +4 +4 +2 +18 +5 +3 +7 +9 +6 +8 +3 +10 +3 +11 +9 +5 +4 +17 +5 +5 +6 +6 +3 +2 +4 +12 +17 +6 +7 +218 +4 +2 +4 +10 +3 +5 +15 +3 +9 +4 +3 +3 +6 +29 +3 +3 +4 +5 +5 +3 +8 +5 +6 +6 +7 +5 +3 +5 +3 +29 +2 +31 +5 +15 +24 +16 +5 +207 +4 +3 +3 +2 +15 +4 +4 +13 +5 +5 +4 +6 +10 +2 +7 +8 +4 +6 +20 +5 +3 +4 +3 +12 +12 +5 +17 +7 +3 +3 +3 +6 +10 +3 +5 +25 +80 +4 +9 +3 +2 +11 +3 +3 +2 +3 +8 +7 +5 +5 +19 +5 +3 +3 +12 +11 +2 +6 +5 +5 +5 +3 +3 +3 +4 +209 +14 +3 +2 +5 +19 +4 +4 +3 +4 +14 +5 +6 +4 +13 +9 +7 +4 +7 +10 +2 +9 +5 +7 +2 +8 +4 +6 +5 +5 +222 +8 +7 +12 +5 +216 +3 +4 +4 +6 +3 +14 +8 +7 +13 +4 +3 +3 +3 +3 +17 +5 +4 +3 +33 +6 +6 +33 +7 +5 +3 +8 +7 +5 +2 +9 +4 +2 +233 +24 +7 +4 +8 +10 +3 +4 +15 +2 +16 +3 +3 +13 +12 +7 +5 +4 +207 +4 +2 +4 +27 +15 +2 +5 +2 +25 +6 +5 +5 +6 +13 +6 +18 +6 +4 +12 +225 +10 +7 +5 +2 +2 +11 +4 +14 +21 +8 +10 +3 +5 +4 +232 +2 +5 +5 +3 +7 +17 +11 +6 +6 +23 +4 +6 +3 +5 +4 +2 +17 +3 +6 +5 +8 +3 +2 +2 +14 +9 +4 +4 +2 +5 +5 +3 +7 +6 +12 +6 +10 +3 +6 +2 +2 +19 +5 +4 +4 +9 +2 +4 +13 +3 +5 +6 +3 +6 +5 +4 +9 +6 +3 +5 +7 +3 +6 +6 +4 +3 +10 +6 +3 +221 +3 +5 +3 +6 +4 +8 +5 +3 +6 +4 +4 +2 +54 +5 +6 +11 +3 +3 +4 +4 +4 +3 +7 +3 +11 +11 +7 +10 +6 +13 +223 +213 +15 +231 +7 +3 +7 +228 +2 +3 +4 +4 +5 +6 +7 +4 +13 +3 +4 +5 +3 +6 +4 +6 +7 +2 +4 +3 +4 +3 +3 +6 +3 +7 +3 +5 +18 +5 +6 +8 +10 +3 +3 +3 +2 +4 +2 +4 +4 +5 +6 +6 +4 +10 +13 +3 +12 +5 +12 +16 +8 +4 +19 +11 +2 +4 +5 +6 +8 +5 +6 +4 +18 +10 +4 +2 +216 +6 +6 +6 +2 +4 +12 +8 +3 +11 +5 +6 +14 +5 +3 +13 +4 +5 +4 +5 +3 +28 +6 +3 +7 +219 +3 +9 +7 +3 +10 +6 +3 +4 +19 +5 +7 +11 +6 +15 +19 +4 +13 +11 +3 +7 +5 +10 +2 +8 +11 +2 +6 +4 +6 +24 +6 +3 +3 +3 +3 +6 +18 +4 +11 +4 +2 +5 +10 +8 +3 +9 +5 +3 +4 +5 +6 +2 +5 +7 +4 +4 +14 +6 +4 +4 +5 +5 +7 +2 +4 +3 +7 +3 +3 +6 +4 +5 +4 +4 +4 +3 +3 +3 +3 +8 +14 +2 +3 +5 +3 +2 +4 +5 +3 +7 +3 +3 +18 +3 +4 +4 +5 +7 +3 +3 +3 +13 +5 +4 +8 +211 +5 +5 +3 +5 +2 +5 +4 +2 +655 +6 +3 +5 +11 +2 +5 +3 +12 +9 +15 +11 +5 +12 +217 +2 +6 +17 +3 +3 +207 +5 +5 +4 +5 +9 +3 +2 +8 +5 +4 +3 +2 +5 +12 +4 +14 +5 +4 +2 +13 +5 +8 +4 +225 +4 +3 +4 +5 +4 +3 +3 +6 +23 +9 +2 +6 +7 +233 +4 +4 +6 +18 +3 +4 +6 +3 +4 +4 +2 +3 +7 +4 +13 +227 +4 +3 +5 +4 +2 +12 +9 +17 +3 +7 +14 +6 +4 +5 +21 +4 +8 +9 +2 +9 +25 +16 +3 +6 +4 +7 +8 +5 +2 +3 +5 +4 +3 +3 +5 +3 +3 +3 +2 +3 +19 +2 +4 +3 +4 +2 +3 +4 +4 +2 +4 +3 +3 +3 +2 +6 +3 +17 +5 +6 +4 +3 +13 +5 +3 +3 +3 +4 +9 +4 +2 +14 +12 +4 +5 +24 +4 +3 +37 +12 +11 +21 +3 +4 +3 +13 +4 +2 +3 +15 +4 +11 +4 +4 +3 +8 +3 +4 +4 +12 +8 +5 +3 +3 +4 +2 +220 +3 +5 +223 +3 +3 +3 +10 +3 +15 +4 +241 +9 +7 +3 +6 +6 +23 +4 +13 +7 +3 +4 +7 +4 +9 +3 +3 +4 +10 +5 +5 +1 +5 +24 +2 +4 +5 +5 +6 +14 +3 +8 +2 +3 +5 +13 +13 +3 +5 +2 +3 +15 +3 +4 +2 +10 +4 +4 +4 +5 +5 +3 +5 +3 +4 +7 +4 +27 +3 +6 +4 +15 +3 +5 +6 +6 +5 +4 +8 +3 +9 +2 +6 +3 +4 +3 +7 +4 +18 +3 +11 +3 +3 +8 +9 +7 +24 +3 +219 +7 +10 +4 +5 +9 +12 +2 +5 +4 +4 +4 +3 +3 +19 +5 +8 +16 +8 +6 +22 +3 +23 +3 +242 +9 +4 +3 +3 +5 +7 +3 +3 +5 +8 +3 +7 +5 +14 +8 +10 +3 +4 +3 +7 +4 +6 +7 +4 +10 +4 +3 +11 +3 +7 +10 +3 +13 +6 +8 +12 +10 +5 +7 +9 +3 +4 +7 +7 +10 +8 +30 +9 +19 +4 +3 +19 +15 +4 +13 +3 +215 +223 +4 +7 +4 +8 +17 +16 +3 +7 +6 +5 +5 +4 +12 +3 +7 +4 +4 +13 +4 +5 +2 +5 +6 +5 +6 +6 +7 +10 +18 +23 +9 +3 +3 +6 +5 +2 +4 +2 +7 +3 +3 +2 +5 +5 +14 +10 +224 +6 +3 +4 +3 +7 +5 +9 +3 +6 +4 +2 +5 +11 +4 +3 +3 +2 +8 +4 +7 +4 +10 +7 +3 +3 +18 +18 +17 +3 +3 +3 +4 +5 +3 +3 +4 +12 +7 +3 +11 +13 +5 +4 +7 +13 +5 +4 +11 +3 +12 +3 +6 +4 +4 +21 +4 +6 +9 +5 +3 +10 +8 +4 +6 +4 +4 +6 +5 +4 +8 +6 +4 +6 +4 +4 +5 +9 +6 +3 +4 +2 +9 +3 +18 +2 +4 +3 +13 +3 +6 +6 +8 +7 +9 +3 +2 +16 +3 +4 +6 +3 +2 +33 +22 +14 +4 +9 +12 +4 +5 +6 +3 +23 +9 +4 +3 +5 +5 +3 +4 +5 +3 +5 +3 +10 +4 +5 +5 +8 +4 +4 +6 +8 +5 +4 +3 +4 +6 +3 +3 +3 +5 +9 +12 +6 +5 +9 +3 +5 +3 +2 +2 +2 +18 +3 +2 +21 +2 +5 +4 +6 +4 +5 +10 +3 +9 +3 +2 +10 +7 +3 +6 +6 +4 +4 +8 +12 +7 +3 +7 +3 +3 +9 +3 +4 +5 +4 +4 +5 +5 +10 +15 +4 +4 +14 +6 +227 +3 +14 +5 +216 +22 +5 +4 +2 +2 +6 +3 +4 +2 +9 +9 +4 +3 +28 +13 +11 +4 +5 +3 +3 +2 +3 +3 +5 +3 +4 +3 +5 +23 +26 +3 +4 +5 +6 +4 +6 +3 +5 +5 +3 +4 +3 +2 +2 +2 +7 +14 +3 +6 +7 +17 +2 +2 +15 +14 +16 +4 +6 +7 +13 +6 +4 +5 +6 +16 +3 +3 +28 +3 +6 +15 +3 +9 +2 +4 +6 +3 +3 +22 +4 +12 +6 +7 +2 +5 +4 +10 +3 +16 +6 +9 +2 +5 +12 +7 +5 +5 +5 +5 +2 +11 +9 +17 +4 +3 +11 +7 +3 +5 +15 +4 +3 +4 +211 +8 +7 +5 +4 +7 +6 +7 +6 +3 +6 +5 +6 +5 +3 +4 +4 +26 +4 +6 +10 +4 +4 +3 +2 +3 +3 +4 +5 +9 +3 +9 +4 +4 +5 +5 +8 +2 +4 +2 +3 +8 +4 +11 +19 +5 +8 +6 +3 +5 +6 +12 +3 +2 +4 +16 +12 +3 +4 +4 +8 +6 +5 +6 +6 +219 +8 +222 +6 +16 +3 +13 +19 +5 +4 +3 +11 +6 +10 +4 +7 +7 +12 +5 +3 +3 +5 +6 +10 +3 +8 +2 +5 +4 +7 +2 +4 +4 +2 +12 +9 +6 +4 +2 +40 +2 +4 +10 +4 +223 +4 +2 +20 +6 +7 +24 +5 +4 +5 +2 +20 +16 +6 +5 +13 +2 +3 +3 +19 +3 +2 +4 +5 +6 +7 +11 +12 +5 +6 +7 +7 +3 +5 +3 +5 +3 +14 +3 +4 +4 +2 +11 +1 +7 +3 +9 +6 +11 +12 +5 +8 +6 +221 +4 +2 +12 +4 +3 +15 +4 +5 +226 +7 +218 +7 +5 +4 +5 +18 +4 +5 +9 +4 +4 +2 +9 +18 +18 +9 +5 +6 +6 +3 +3 +7 +3 +5 +4 +4 +4 +12 +3 +6 +31 +5 +4 +7 +3 +6 +5 +6 +5 +11 +2 +2 +11 +11 +6 +7 +5 +8 +7 +10 +5 +23 +7 +4 +3 +5 +34 +2 +5 +23 +7 +3 +6 +8 +4 +4 +4 +2 +5 +3 +8 +5 +4 +8 +25 +2 +3 +17 +8 +3 +4 +8 +7 +3 +15 +6 +5 +7 +21 +9 +5 +6 +6 +5 +3 +2 +3 +10 +3 +6 +3 +14 +7 +4 +4 +8 +7 +8 +2 +6 +12 +4 +213 +6 +5 +21 +8 +2 +5 +23 +3 +11 +2 +3 +6 +25 +2 +3 +6 +7 +6 +6 +4 +4 +6 +3 +17 +9 +7 +6 +4 +3 +10 +7 +2 +3 +3 +3 +11 +8 +3 +7 +6 +4 +14 +36 +3 +4 +3 +3 +22 +13 +21 +4 +2 +7 +4 +4 +17 +15 +3 +7 +11 +2 +4 +7 +6 +209 +6 +3 +2 +2 +24 +4 +9 +4 +3 +3 +3 +29 +2 +2 +4 +3 +3 +5 +4 +6 +3 +3 +2 +4 diff --git a/vendor/github.com/beorn7/perks/quantile/stream.go b/vendor/github.com/beorn7/perks/quantile/stream.go new file mode 100644 index 00000000..f4cabd66 --- /dev/null +++ b/vendor/github.com/beorn7/perks/quantile/stream.go @@ -0,0 +1,292 @@ +// Package quantile computes approximate quantiles over an unbounded data +// stream within low memory and CPU bounds. +// +// A small amount of accuracy is traded to achieve the above properties. +// +// Multiple streams can be merged before calling Query to generate a single set +// of results. This is meaningful when the streams represent the same type of +// data. See Merge and Samples. +// +// For more detailed information about the algorithm used, see: +// +// Effective Computation of Biased Quantiles over Data Streams +// +// http://www.cs.rutgers.edu/~muthu/bquant.pdf +package quantile + +import ( + "math" + "sort" +) + +// Sample holds an observed value and meta information for compression. JSON +// tags have been added for convenience. +type Sample struct { + Value float64 `json:",string"` + Width float64 `json:",string"` + Delta float64 `json:",string"` +} + +// Samples represents a slice of samples. It implements sort.Interface. +type Samples []Sample + +func (a Samples) Len() int { return len(a) } +func (a Samples) Less(i, j int) bool { return a[i].Value < a[j].Value } +func (a Samples) Swap(i, j int) { a[i], a[j] = a[j], a[i] } + +type invariant func(s *stream, r float64) float64 + +// NewLowBiased returns an initialized Stream for low-biased quantiles +// (e.g. 0.01, 0.1, 0.5) where the needed quantiles are not known a priori, but +// error guarantees can still be given even for the lower ranks of the data +// distribution. +// +// The provided epsilon is a relative error, i.e. the true quantile of a value +// returned by a query is guaranteed to be within (1±Epsilon)*Quantile. +// +// See http://www.cs.rutgers.edu/~muthu/bquant.pdf for time, space, and error +// properties. +func NewLowBiased(epsilon float64) *Stream { + ƒ := func(s *stream, r float64) float64 { + return 2 * epsilon * r + } + return newStream(ƒ) +} + +// NewHighBiased returns an initialized Stream for high-biased quantiles +// (e.g. 0.01, 0.1, 0.5) where the needed quantiles are not known a priori, but +// error guarantees can still be given even for the higher ranks of the data +// distribution. +// +// The provided epsilon is a relative error, i.e. the true quantile of a value +// returned by a query is guaranteed to be within 1-(1±Epsilon)*(1-Quantile). +// +// See http://www.cs.rutgers.edu/~muthu/bquant.pdf for time, space, and error +// properties. +func NewHighBiased(epsilon float64) *Stream { + ƒ := func(s *stream, r float64) float64 { + return 2 * epsilon * (s.n - r) + } + return newStream(ƒ) +} + +// NewTargeted returns an initialized Stream concerned with a particular set of +// quantile values that are supplied a priori. Knowing these a priori reduces +// space and computation time. The targets map maps the desired quantiles to +// their absolute errors, i.e. the true quantile of a value returned by a query +// is guaranteed to be within (Quantile±Epsilon). +// +// See http://www.cs.rutgers.edu/~muthu/bquant.pdf for time, space, and error properties. +func NewTargeted(targets map[float64]float64) *Stream { + ƒ := func(s *stream, r float64) float64 { + var m = math.MaxFloat64 + var f float64 + for quantile, epsilon := range targets { + if quantile*s.n <= r { + f = (2 * epsilon * r) / quantile + } else { + f = (2 * epsilon * (s.n - r)) / (1 - quantile) + } + if f < m { + m = f + } + } + return m + } + return newStream(ƒ) +} + +// Stream computes quantiles for a stream of float64s. It is not thread-safe by +// design. Take care when using across multiple goroutines. +type Stream struct { + *stream + b Samples + sorted bool +} + +func newStream(ƒ invariant) *Stream { + x := &stream{ƒ: ƒ} + return &Stream{x, make(Samples, 0, 500), true} +} + +// Insert inserts v into the stream. +func (s *Stream) Insert(v float64) { + s.insert(Sample{Value: v, Width: 1}) +} + +func (s *Stream) insert(sample Sample) { + s.b = append(s.b, sample) + s.sorted = false + if len(s.b) == cap(s.b) { + s.flush() + } +} + +// Query returns the computed qth percentiles value. If s was created with +// NewTargeted, and q is not in the set of quantiles provided a priori, Query +// will return an unspecified result. +func (s *Stream) Query(q float64) float64 { + if !s.flushed() { + // Fast path when there hasn't been enough data for a flush; + // this also yields better accuracy for small sets of data. + l := len(s.b) + if l == 0 { + return 0 + } + i := int(math.Ceil(float64(l) * q)) + if i > 0 { + i -= 1 + } + s.maybeSort() + return s.b[i].Value + } + s.flush() + return s.stream.query(q) +} + +// Merge merges samples into the underlying streams samples. This is handy when +// merging multiple streams from separate threads, database shards, etc. +// +// ATTENTION: This method is broken and does not yield correct results. The +// underlying algorithm is not capable of merging streams correctly. +func (s *Stream) Merge(samples Samples) { + sort.Sort(samples) + s.stream.merge(samples) +} + +// Reset reinitializes and clears the list reusing the samples buffer memory. +func (s *Stream) Reset() { + s.stream.reset() + s.b = s.b[:0] +} + +// Samples returns stream samples held by s. +func (s *Stream) Samples() Samples { + if !s.flushed() { + return s.b + } + s.flush() + return s.stream.samples() +} + +// Count returns the total number of samples observed in the stream +// since initialization. +func (s *Stream) Count() int { + return len(s.b) + s.stream.count() +} + +func (s *Stream) flush() { + s.maybeSort() + s.stream.merge(s.b) + s.b = s.b[:0] +} + +func (s *Stream) maybeSort() { + if !s.sorted { + s.sorted = true + sort.Sort(s.b) + } +} + +func (s *Stream) flushed() bool { + return len(s.stream.l) > 0 +} + +type stream struct { + n float64 + l []Sample + ƒ invariant +} + +func (s *stream) reset() { + s.l = s.l[:0] + s.n = 0 +} + +func (s *stream) insert(v float64) { + s.merge(Samples{{v, 1, 0}}) +} + +func (s *stream) merge(samples Samples) { + // TODO(beorn7): This tries to merge not only individual samples, but + // whole summaries. The paper doesn't mention merging summaries at + // all. Unittests show that the merging is inaccurate. Find out how to + // do merges properly. + var r float64 + i := 0 + for _, sample := range samples { + for ; i < len(s.l); i++ { + c := s.l[i] + if c.Value > sample.Value { + // Insert at position i. + s.l = append(s.l, Sample{}) + copy(s.l[i+1:], s.l[i:]) + s.l[i] = Sample{ + sample.Value, + sample.Width, + math.Max(sample.Delta, math.Floor(s.ƒ(s, r))-1), + // TODO(beorn7): How to calculate delta correctly? + } + i++ + goto inserted + } + r += c.Width + } + s.l = append(s.l, Sample{sample.Value, sample.Width, 0}) + i++ + inserted: + s.n += sample.Width + r += sample.Width + } + s.compress() +} + +func (s *stream) count() int { + return int(s.n) +} + +func (s *stream) query(q float64) float64 { + t := math.Ceil(q * s.n) + t += math.Ceil(s.ƒ(s, t) / 2) + p := s.l[0] + var r float64 + for _, c := range s.l[1:] { + r += p.Width + if r+c.Width+c.Delta > t { + return p.Value + } + p = c + } + return p.Value +} + +func (s *stream) compress() { + if len(s.l) < 2 { + return + } + x := s.l[len(s.l)-1] + xi := len(s.l) - 1 + r := s.n - 1 - x.Width + + for i := len(s.l) - 2; i >= 0; i-- { + c := s.l[i] + if c.Width+x.Width+x.Delta <= s.ƒ(s, r) { + x.Width += c.Width + s.l[xi] = x + // Remove element at i. + copy(s.l[i:], s.l[i+1:]) + s.l = s.l[:len(s.l)-1] + xi -= 1 + } else { + x = c + xi = i + } + r -= c.Width + } +} + +func (s *stream) samples() Samples { + samples := make(Samples, len(s.l)) + copy(samples, s.l) + return samples +} diff --git a/vendor/github.com/beorn7/perks/quantile/stream_test.go b/vendor/github.com/beorn7/perks/quantile/stream_test.go new file mode 100644 index 00000000..85519509 --- /dev/null +++ b/vendor/github.com/beorn7/perks/quantile/stream_test.go @@ -0,0 +1,215 @@ +package quantile + +import ( + "math" + "math/rand" + "sort" + "testing" +) + +var ( + Targets = map[float64]float64{ + 0.01: 0.001, + 0.10: 0.01, + 0.50: 0.05, + 0.90: 0.01, + 0.99: 0.001, + } + TargetsSmallEpsilon = map[float64]float64{ + 0.01: 0.0001, + 0.10: 0.001, + 0.50: 0.005, + 0.90: 0.001, + 0.99: 0.0001, + } + LowQuantiles = []float64{0.01, 0.1, 0.5} + HighQuantiles = []float64{0.99, 0.9, 0.5} +) + +const RelativeEpsilon = 0.01 + +func verifyPercsWithAbsoluteEpsilon(t *testing.T, a []float64, s *Stream) { + sort.Float64s(a) + for quantile, epsilon := range Targets { + n := float64(len(a)) + k := int(quantile * n) + if k < 1 { + k = 1 + } + lower := int((quantile - epsilon) * n) + if lower < 1 { + lower = 1 + } + upper := int(math.Ceil((quantile + epsilon) * n)) + if upper > len(a) { + upper = len(a) + } + w, min, max := a[k-1], a[lower-1], a[upper-1] + if g := s.Query(quantile); g < min || g > max { + t.Errorf("q=%f: want %v [%f,%f], got %v", quantile, w, min, max, g) + } + } +} + +func verifyLowPercsWithRelativeEpsilon(t *testing.T, a []float64, s *Stream) { + sort.Float64s(a) + for _, qu := range LowQuantiles { + n := float64(len(a)) + k := int(qu * n) + + lowerRank := int((1 - RelativeEpsilon) * qu * n) + upperRank := int(math.Ceil((1 + RelativeEpsilon) * qu * n)) + w, min, max := a[k-1], a[lowerRank-1], a[upperRank-1] + if g := s.Query(qu); g < min || g > max { + t.Errorf("q=%f: want %v [%f,%f], got %v", qu, w, min, max, g) + } + } +} + +func verifyHighPercsWithRelativeEpsilon(t *testing.T, a []float64, s *Stream) { + sort.Float64s(a) + for _, qu := range HighQuantiles { + n := float64(len(a)) + k := int(qu * n) + + lowerRank := int((1 - (1+RelativeEpsilon)*(1-qu)) * n) + upperRank := int(math.Ceil((1 - (1-RelativeEpsilon)*(1-qu)) * n)) + w, min, max := a[k-1], a[lowerRank-1], a[upperRank-1] + if g := s.Query(qu); g < min || g > max { + t.Errorf("q=%f: want %v [%f,%f], got %v", qu, w, min, max, g) + } + } +} + +func populateStream(s *Stream) []float64 { + a := make([]float64, 0, 1e5+100) + for i := 0; i < cap(a); i++ { + v := rand.NormFloat64() + // Add 5% asymmetric outliers. + if i%20 == 0 { + v = v*v + 1 + } + s.Insert(v) + a = append(a, v) + } + return a +} + +func TestTargetedQuery(t *testing.T) { + rand.Seed(42) + s := NewTargeted(Targets) + a := populateStream(s) + verifyPercsWithAbsoluteEpsilon(t, a, s) +} + +func TestTargetedQuerySmallSampleSize(t *testing.T) { + rand.Seed(42) + s := NewTargeted(TargetsSmallEpsilon) + a := []float64{1, 2, 3, 4, 5} + for _, v := range a { + s.Insert(v) + } + verifyPercsWithAbsoluteEpsilon(t, a, s) + // If not yet flushed, results should be precise: + if !s.flushed() { + for φ, want := range map[float64]float64{ + 0.01: 1, + 0.10: 1, + 0.50: 3, + 0.90: 5, + 0.99: 5, + } { + if got := s.Query(φ); got != want { + t.Errorf("want %f for φ=%f, got %f", want, φ, got) + } + } + } +} + +func TestLowBiasedQuery(t *testing.T) { + rand.Seed(42) + s := NewLowBiased(RelativeEpsilon) + a := populateStream(s) + verifyLowPercsWithRelativeEpsilon(t, a, s) +} + +func TestHighBiasedQuery(t *testing.T) { + rand.Seed(42) + s := NewHighBiased(RelativeEpsilon) + a := populateStream(s) + verifyHighPercsWithRelativeEpsilon(t, a, s) +} + +// BrokenTestTargetedMerge is broken, see Merge doc comment. +func BrokenTestTargetedMerge(t *testing.T) { + rand.Seed(42) + s1 := NewTargeted(Targets) + s2 := NewTargeted(Targets) + a := populateStream(s1) + a = append(a, populateStream(s2)...) + s1.Merge(s2.Samples()) + verifyPercsWithAbsoluteEpsilon(t, a, s1) +} + +// BrokenTestLowBiasedMerge is broken, see Merge doc comment. +func BrokenTestLowBiasedMerge(t *testing.T) { + rand.Seed(42) + s1 := NewLowBiased(RelativeEpsilon) + s2 := NewLowBiased(RelativeEpsilon) + a := populateStream(s1) + a = append(a, populateStream(s2)...) + s1.Merge(s2.Samples()) + verifyLowPercsWithRelativeEpsilon(t, a, s2) +} + +// BrokenTestHighBiasedMerge is broken, see Merge doc comment. +func BrokenTestHighBiasedMerge(t *testing.T) { + rand.Seed(42) + s1 := NewHighBiased(RelativeEpsilon) + s2 := NewHighBiased(RelativeEpsilon) + a := populateStream(s1) + a = append(a, populateStream(s2)...) + s1.Merge(s2.Samples()) + verifyHighPercsWithRelativeEpsilon(t, a, s2) +} + +func TestUncompressed(t *testing.T) { + q := NewTargeted(Targets) + for i := 100; i > 0; i-- { + q.Insert(float64(i)) + } + if g := q.Count(); g != 100 { + t.Errorf("want count 100, got %d", g) + } + // Before compression, Query should have 100% accuracy. + for quantile := range Targets { + w := quantile * 100 + if g := q.Query(quantile); g != w { + t.Errorf("want %f, got %f", w, g) + } + } +} + +func TestUncompressedSamples(t *testing.T) { + q := NewTargeted(map[float64]float64{0.99: 0.001}) + for i := 1; i <= 100; i++ { + q.Insert(float64(i)) + } + if g := q.Samples().Len(); g != 100 { + t.Errorf("want count 100, got %d", g) + } +} + +func TestUncompressedOne(t *testing.T) { + q := NewTargeted(map[float64]float64{0.99: 0.01}) + q.Insert(3.14) + if g := q.Query(0.90); g != 3.14 { + t.Error("want PI, got", g) + } +} + +func TestDefaults(t *testing.T) { + if g := NewTargeted(map[float64]float64{0.99: 0.001}).Query(0.99); g != 0 { + t.Errorf("want 0, got %f", g) + } +} diff --git a/vendor/github.com/beorn7/perks/topk/topk.go b/vendor/github.com/beorn7/perks/topk/topk.go new file mode 100644 index 00000000..5ac3d990 --- /dev/null +++ b/vendor/github.com/beorn7/perks/topk/topk.go @@ -0,0 +1,90 @@ +package topk + +import ( + "sort" +) + +// http://www.cs.ucsb.edu/research/tech_reports/reports/2005-23.pdf + +type Element struct { + Value string + Count int +} + +type Samples []*Element + +func (sm Samples) Len() int { + return len(sm) +} + +func (sm Samples) Less(i, j int) bool { + return sm[i].Count < sm[j].Count +} + +func (sm Samples) Swap(i, j int) { + sm[i], sm[j] = sm[j], sm[i] +} + +type Stream struct { + k int + mon map[string]*Element + + // the minimum Element + min *Element +} + +func New(k int) *Stream { + s := new(Stream) + s.k = k + s.mon = make(map[string]*Element) + s.min = &Element{} + + // Track k+1 so that less frequenet items contended for that spot, + // resulting in k being more accurate. + return s +} + +func (s *Stream) Insert(x string) { + s.insert(&Element{x, 1}) +} + +func (s *Stream) Merge(sm Samples) { + for _, e := range sm { + s.insert(e) + } +} + +func (s *Stream) insert(in *Element) { + e := s.mon[in.Value] + if e != nil { + e.Count++ + } else { + if len(s.mon) < s.k+1 { + e = &Element{in.Value, in.Count} + s.mon[in.Value] = e + } else { + e = s.min + delete(s.mon, e.Value) + e.Value = in.Value + e.Count += in.Count + s.min = e + } + } + if e.Count < s.min.Count { + s.min = e + } +} + +func (s *Stream) Query() Samples { + var sm Samples + for _, e := range s.mon { + sm = append(sm, e) + } + sort.Sort(sort.Reverse(sm)) + + if len(sm) < s.k { + return sm + } + + return sm[:s.k] +} diff --git a/vendor/github.com/beorn7/perks/topk/topk_test.go b/vendor/github.com/beorn7/perks/topk/topk_test.go new file mode 100644 index 00000000..c24f0f72 --- /dev/null +++ b/vendor/github.com/beorn7/perks/topk/topk_test.go @@ -0,0 +1,57 @@ +package topk + +import ( + "fmt" + "math/rand" + "sort" + "testing" +) + +func TestTopK(t *testing.T) { + stream := New(10) + ss := []*Stream{New(10), New(10), New(10)} + m := make(map[string]int) + for _, s := range ss { + for i := 0; i < 1e6; i++ { + v := fmt.Sprintf("%x", int8(rand.ExpFloat64())) + s.Insert(v) + m[v]++ + } + stream.Merge(s.Query()) + } + + var sm Samples + for x, s := range m { + sm = append(sm, &Element{x, s}) + } + sort.Sort(sort.Reverse(sm)) + + g := stream.Query() + if len(g) != 10 { + t.Fatalf("got %d, want 10", len(g)) + } + for i, e := range g { + if sm[i].Value != e.Value { + t.Errorf("at %d: want %q, got %q", i, sm[i].Value, e.Value) + } + } +} + +func TestQuery(t *testing.T) { + queryTests := []struct { + value string + expected int + }{ + {"a", 1}, + {"b", 2}, + {"c", 2}, + } + + stream := New(2) + for _, tt := range queryTests { + stream.Insert(tt.value) + if n := len(stream.Query()); n != tt.expected { + t.Errorf("want %d, got %d", tt.expected, n) + } + } +} diff --git a/vendor/github.com/boltdb/bolt/.gitignore b/vendor/github.com/boltdb/bolt/.gitignore new file mode 100644 index 00000000..c7bd2b7a --- /dev/null +++ b/vendor/github.com/boltdb/bolt/.gitignore @@ -0,0 +1,4 @@ +*.prof +*.test +*.swp +/bin/ diff --git a/vendor/github.com/boltdb/bolt/LICENSE b/vendor/github.com/boltdb/bolt/LICENSE new file mode 100644 index 00000000..004e77fe --- /dev/null +++ b/vendor/github.com/boltdb/bolt/LICENSE @@ -0,0 +1,20 @@ +The MIT License (MIT) + +Copyright (c) 2013 Ben Johnson + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/boltdb/bolt/Makefile b/vendor/github.com/boltdb/bolt/Makefile new file mode 100644 index 00000000..e035e63a --- /dev/null +++ b/vendor/github.com/boltdb/bolt/Makefile @@ -0,0 +1,18 @@ +BRANCH=`git rev-parse --abbrev-ref HEAD` +COMMIT=`git rev-parse --short HEAD` +GOLDFLAGS="-X main.branch $(BRANCH) -X main.commit $(COMMIT)" + +default: build + +race: + @go test -v -race -test.run="TestSimulate_(100op|1000op)" + +# go get github.com/kisielk/errcheck +errcheck: + @errcheck -ignorepkg=bytes -ignore=os:Remove github.com/boltdb/bolt + +test: + @go test -v -cover . + @go test -v ./cmd/bolt + +.PHONY: fmt test diff --git a/vendor/github.com/boltdb/bolt/README.md b/vendor/github.com/boltdb/bolt/README.md new file mode 100644 index 00000000..8523e337 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/README.md @@ -0,0 +1,852 @@ +Bolt [![Coverage Status](https://coveralls.io/repos/boltdb/bolt/badge.svg?branch=master)](https://coveralls.io/r/boltdb/bolt?branch=master) [![GoDoc](https://godoc.org/github.com/boltdb/bolt?status.svg)](https://godoc.org/github.com/boltdb/bolt) ![Version](https://img.shields.io/badge/version-1.2.1-green.svg) +==== + +Bolt is a pure Go key/value store inspired by [Howard Chu's][hyc_symas] +[LMDB project][lmdb]. The goal of the project is to provide a simple, +fast, and reliable database for projects that don't require a full database +server such as Postgres or MySQL. + +Since Bolt is meant to be used as such a low-level piece of functionality, +simplicity is key. The API will be small and only focus on getting values +and setting values. That's it. + +[hyc_symas]: https://twitter.com/hyc_symas +[lmdb]: http://symas.com/mdb/ + +## Project Status + +Bolt is stable and the API is fixed. Full unit test coverage and randomized +black box testing are used to ensure database consistency and thread safety. +Bolt is currently in high-load production environments serving databases as +large as 1TB. Many companies such as Shopify and Heroku use Bolt-backed +services every day. + +## Table of Contents + +- [Getting Started](#getting-started) + - [Installing](#installing) + - [Opening a database](#opening-a-database) + - [Transactions](#transactions) + - [Read-write transactions](#read-write-transactions) + - [Read-only transactions](#read-only-transactions) + - [Batch read-write transactions](#batch-read-write-transactions) + - [Managing transactions manually](#managing-transactions-manually) + - [Using buckets](#using-buckets) + - [Using key/value pairs](#using-keyvalue-pairs) + - [Autoincrementing integer for the bucket](#autoincrementing-integer-for-the-bucket) + - [Iterating over keys](#iterating-over-keys) + - [Prefix scans](#prefix-scans) + - [Range scans](#range-scans) + - [ForEach()](#foreach) + - [Nested buckets](#nested-buckets) + - [Database backups](#database-backups) + - [Statistics](#statistics) + - [Read-Only Mode](#read-only-mode) + - [Mobile Use (iOS/Android)](#mobile-use-iosandroid) +- [Resources](#resources) +- [Comparison with other databases](#comparison-with-other-databases) + - [Postgres, MySQL, & other relational databases](#postgres-mysql--other-relational-databases) + - [LevelDB, RocksDB](#leveldb-rocksdb) + - [LMDB](#lmdb) +- [Caveats & Limitations](#caveats--limitations) +- [Reading the Source](#reading-the-source) +- [Other Projects Using Bolt](#other-projects-using-bolt) + +## Getting Started + +### Installing + +To start using Bolt, install Go and run `go get`: + +```sh +$ go get github.com/boltdb/bolt/... +``` + +This will retrieve the library and install the `bolt` command line utility into +your `$GOBIN` path. + + +### Opening a database + +The top-level object in Bolt is a `DB`. It is represented as a single file on +your disk and represents a consistent snapshot of your data. + +To open your database, simply use the `bolt.Open()` function: + +```go +package main + +import ( + "log" + + "github.com/boltdb/bolt" +) + +func main() { + // Open the my.db data file in your current directory. + // It will be created if it doesn't exist. + db, err := bolt.Open("my.db", 0600, nil) + if err != nil { + log.Fatal(err) + } + defer db.Close() + + ... +} +``` + +Please note that Bolt obtains a file lock on the data file so multiple processes +cannot open the same database at the same time. Opening an already open Bolt +database will cause it to hang until the other process closes it. To prevent +an indefinite wait you can pass a timeout option to the `Open()` function: + +```go +db, err := bolt.Open("my.db", 0600, &bolt.Options{Timeout: 1 * time.Second}) +``` + + +### Transactions + +Bolt allows only one read-write transaction at a time but allows as many +read-only transactions as you want at a time. Each transaction has a consistent +view of the data as it existed when the transaction started. + +Individual transactions and all objects created from them (e.g. buckets, keys) +are not thread safe. To work with data in multiple goroutines you must start +a transaction for each one or use locking to ensure only one goroutine accesses +a transaction at a time. Creating transaction from the `DB` is thread safe. + +Read-only transactions and read-write transactions should not depend on one +another and generally shouldn't be opened simultaneously in the same goroutine. +This can cause a deadlock as the read-write transaction needs to periodically +re-map the data file but it cannot do so while a read-only transaction is open. + + +#### Read-write transactions + +To start a read-write transaction, you can use the `DB.Update()` function: + +```go +err := db.Update(func(tx *bolt.Tx) error { + ... + return nil +}) +``` + +Inside the closure, you have a consistent view of the database. You commit the +transaction by returning `nil` at the end. You can also rollback the transaction +at any point by returning an error. All database operations are allowed inside +a read-write transaction. + +Always check the return error as it will report any disk failures that can cause +your transaction to not complete. If you return an error within your closure +it will be passed through. + + +#### Read-only transactions + +To start a read-only transaction, you can use the `DB.View()` function: + +```go +err := db.View(func(tx *bolt.Tx) error { + ... + return nil +}) +``` + +You also get a consistent view of the database within this closure, however, +no mutating operations are allowed within a read-only transaction. You can only +retrieve buckets, retrieve values, and copy the database within a read-only +transaction. + + +#### Batch read-write transactions + +Each `DB.Update()` waits for disk to commit the writes. This overhead +can be minimized by combining multiple updates with the `DB.Batch()` +function: + +```go +err := db.Batch(func(tx *bolt.Tx) error { + ... + return nil +}) +``` + +Concurrent Batch calls are opportunistically combined into larger +transactions. Batch is only useful when there are multiple goroutines +calling it. + +The trade-off is that `Batch` can call the given +function multiple times, if parts of the transaction fail. The +function must be idempotent and side effects must take effect only +after a successful return from `DB.Batch()`. + +For example: don't display messages from inside the function, instead +set variables in the enclosing scope: + +```go +var id uint64 +err := db.Batch(func(tx *bolt.Tx) error { + // Find last key in bucket, decode as bigendian uint64, increment + // by one, encode back to []byte, and add new key. + ... + id = newValue + return nil +}) +if err != nil { + return ... +} +fmt.Println("Allocated ID %d", id) +``` + + +#### Managing transactions manually + +The `DB.View()` and `DB.Update()` functions are wrappers around the `DB.Begin()` +function. These helper functions will start the transaction, execute a function, +and then safely close your transaction if an error is returned. This is the +recommended way to use Bolt transactions. + +However, sometimes you may want to manually start and end your transactions. +You can use the `Tx.Begin()` function directly but **please** be sure to close +the transaction. + +```go +// Start a writable transaction. +tx, err := db.Begin(true) +if err != nil { + return err +} +defer tx.Rollback() + +// Use the transaction... +_, err := tx.CreateBucket([]byte("MyBucket")) +if err != nil { + return err +} + +// Commit the transaction and check for error. +if err := tx.Commit(); err != nil { + return err +} +``` + +The first argument to `DB.Begin()` is a boolean stating if the transaction +should be writable. + + +### Using buckets + +Buckets are collections of key/value pairs within the database. All keys in a +bucket must be unique. You can create a bucket using the `DB.CreateBucket()` +function: + +```go +db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("MyBucket")) + if err != nil { + return fmt.Errorf("create bucket: %s", err) + } + return nil +}) +``` + +You can also create a bucket only if it doesn't exist by using the +`Tx.CreateBucketIfNotExists()` function. It's a common pattern to call this +function for all your top-level buckets after you open your database so you can +guarantee that they exist for future transactions. + +To delete a bucket, simply call the `Tx.DeleteBucket()` function. + + +### Using key/value pairs + +To save a key/value pair to a bucket, use the `Bucket.Put()` function: + +```go +db.Update(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("MyBucket")) + err := b.Put([]byte("answer"), []byte("42")) + return err +}) +``` + +This will set the value of the `"answer"` key to `"42"` in the `MyBucket` +bucket. To retrieve this value, we can use the `Bucket.Get()` function: + +```go +db.View(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("MyBucket")) + v := b.Get([]byte("answer")) + fmt.Printf("The answer is: %s\n", v) + return nil +}) +``` + +The `Get()` function does not return an error because its operation is +guaranteed to work (unless there is some kind of system failure). If the key +exists then it will return its byte slice value. If it doesn't exist then it +will return `nil`. It's important to note that you can have a zero-length value +set to a key which is different than the key not existing. + +Use the `Bucket.Delete()` function to delete a key from the bucket. + +Please note that values returned from `Get()` are only valid while the +transaction is open. If you need to use a value outside of the transaction +then you must use `copy()` to copy it to another byte slice. + + +### Autoincrementing integer for the bucket +By using the `NextSequence()` function, you can let Bolt determine a sequence +which can be used as the unique identifier for your key/value pairs. See the +example below. + +```go +// CreateUser saves u to the store. The new user ID is set on u once the data is persisted. +func (s *Store) CreateUser(u *User) error { + return s.db.Update(func(tx *bolt.Tx) error { + // Retrieve the users bucket. + // This should be created when the DB is first opened. + b := tx.Bucket([]byte("users")) + + // Generate ID for the user. + // This returns an error only if the Tx is closed or not writeable. + // That can't happen in an Update() call so I ignore the error check. + id, _ := b.NextSequence() + u.ID = int(id) + + // Marshal user data into bytes. + buf, err := json.Marshal(u) + if err != nil { + return err + } + + // Persist bytes to users bucket. + return b.Put(itob(u.ID), buf) + }) +} + +// itob returns an 8-byte big endian representation of v. +func itob(v int) []byte { + b := make([]byte, 8) + binary.BigEndian.PutUint64(b, uint64(v)) + return b +} + +type User struct { + ID int + ... +} +``` + +### Iterating over keys + +Bolt stores its keys in byte-sorted order within a bucket. This makes sequential +iteration over these keys extremely fast. To iterate over keys we'll use a +`Cursor`: + +```go +db.View(func(tx *bolt.Tx) error { + // Assume bucket exists and has keys + b := tx.Bucket([]byte("MyBucket")) + + c := b.Cursor() + + for k, v := c.First(); k != nil; k, v = c.Next() { + fmt.Printf("key=%s, value=%s\n", k, v) + } + + return nil +}) +``` + +The cursor allows you to move to a specific point in the list of keys and move +forward or backward through the keys one at a time. + +The following functions are available on the cursor: + +``` +First() Move to the first key. +Last() Move to the last key. +Seek() Move to a specific key. +Next() Move to the next key. +Prev() Move to the previous key. +``` + +Each of those functions has a return signature of `(key []byte, value []byte)`. +When you have iterated to the end of the cursor then `Next()` will return a +`nil` key. You must seek to a position using `First()`, `Last()`, or `Seek()` +before calling `Next()` or `Prev()`. If you do not seek to a position then +these functions will return a `nil` key. + +During iteration, if the key is non-`nil` but the value is `nil`, that means +the key refers to a bucket rather than a value. Use `Bucket.Bucket()` to +access the sub-bucket. + + +#### Prefix scans + +To iterate over a key prefix, you can combine `Seek()` and `bytes.HasPrefix()`: + +```go +db.View(func(tx *bolt.Tx) error { + // Assume bucket exists and has keys + c := tx.Bucket([]byte("MyBucket")).Cursor() + + prefix := []byte("1234") + for k, v := c.Seek(prefix); bytes.HasPrefix(k, prefix); k, v = c.Next() { + fmt.Printf("key=%s, value=%s\n", k, v) + } + + return nil +}) +``` + +#### Range scans + +Another common use case is scanning over a range such as a time range. If you +use a sortable time encoding such as RFC3339 then you can query a specific +date range like this: + +```go +db.View(func(tx *bolt.Tx) error { + // Assume our events bucket exists and has RFC3339 encoded time keys. + c := tx.Bucket([]byte("Events")).Cursor() + + // Our time range spans the 90's decade. + min := []byte("1990-01-01T00:00:00Z") + max := []byte("2000-01-01T00:00:00Z") + + // Iterate over the 90's. + for k, v := c.Seek(min); k != nil && bytes.Compare(k, max) <= 0; k, v = c.Next() { + fmt.Printf("%s: %s\n", k, v) + } + + return nil +}) +``` + +Note that, while RFC3339 is sortable, the Golang implementation of RFC3339Nano does not use a fixed number of digits after the decimal point and is therefore not sortable. + + +#### ForEach() + +You can also use the function `ForEach()` if you know you'll be iterating over +all the keys in a bucket: + +```go +db.View(func(tx *bolt.Tx) error { + // Assume bucket exists and has keys + b := tx.Bucket([]byte("MyBucket")) + + b.ForEach(func(k, v []byte) error { + fmt.Printf("key=%s, value=%s\n", k, v) + return nil + }) + return nil +}) +``` + + +### Nested buckets + +You can also store a bucket in a key to create nested buckets. The API is the +same as the bucket management API on the `DB` object: + +```go +func (*Bucket) CreateBucket(key []byte) (*Bucket, error) +func (*Bucket) CreateBucketIfNotExists(key []byte) (*Bucket, error) +func (*Bucket) DeleteBucket(key []byte) error +``` + + +### Database backups + +Bolt is a single file so it's easy to backup. You can use the `Tx.WriteTo()` +function to write a consistent view of the database to a writer. If you call +this from a read-only transaction, it will perform a hot backup and not block +your other database reads and writes. + +By default, it will use a regular file handle which will utilize the operating +system's page cache. See the [`Tx`](https://godoc.org/github.com/boltdb/bolt#Tx) +documentation for information about optimizing for larger-than-RAM datasets. + +One common use case is to backup over HTTP so you can use tools like `cURL` to +do database backups: + +```go +func BackupHandleFunc(w http.ResponseWriter, req *http.Request) { + err := db.View(func(tx *bolt.Tx) error { + w.Header().Set("Content-Type", "application/octet-stream") + w.Header().Set("Content-Disposition", `attachment; filename="my.db"`) + w.Header().Set("Content-Length", strconv.Itoa(int(tx.Size()))) + _, err := tx.WriteTo(w) + return err + }) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} +``` + +Then you can backup using this command: + +```sh +$ curl http://localhost/backup > my.db +``` + +Or you can open your browser to `http://localhost/backup` and it will download +automatically. + +If you want to backup to another file you can use the `Tx.CopyFile()` helper +function. + + +### Statistics + +The database keeps a running count of many of the internal operations it +performs so you can better understand what's going on. By grabbing a snapshot +of these stats at two points in time we can see what operations were performed +in that time range. + +For example, we could start a goroutine to log stats every 10 seconds: + +```go +go func() { + // Grab the initial stats. + prev := db.Stats() + + for { + // Wait for 10s. + time.Sleep(10 * time.Second) + + // Grab the current stats and diff them. + stats := db.Stats() + diff := stats.Sub(&prev) + + // Encode stats to JSON and print to STDERR. + json.NewEncoder(os.Stderr).Encode(diff) + + // Save stats for the next loop. + prev = stats + } +}() +``` + +It's also useful to pipe these stats to a service such as statsd for monitoring +or to provide an HTTP endpoint that will perform a fixed-length sample. + + +### Read-Only Mode + +Sometimes it is useful to create a shared, read-only Bolt database. To this, +set the `Options.ReadOnly` flag when opening your database. Read-only mode +uses a shared lock to allow multiple processes to read from the database but +it will block any processes from opening the database in read-write mode. + +```go +db, err := bolt.Open("my.db", 0666, &bolt.Options{ReadOnly: true}) +if err != nil { + log.Fatal(err) +} +``` + +### Mobile Use (iOS/Android) + +Bolt is able to run on mobile devices by leveraging the binding feature of the +[gomobile](https://github.com/golang/mobile) tool. Create a struct that will +contain your database logic and a reference to a `*bolt.DB` with a initializing +constructor that takes in a filepath where the database file will be stored. +Neither Android nor iOS require extra permissions or cleanup from using this method. + +```go +func NewBoltDB(filepath string) *BoltDB { + db, err := bolt.Open(filepath+"/demo.db", 0600, nil) + if err != nil { + log.Fatal(err) + } + + return &BoltDB{db} +} + +type BoltDB struct { + db *bolt.DB + ... +} + +func (b *BoltDB) Path() string { + return b.db.Path() +} + +func (b *BoltDB) Close() { + b.db.Close() +} +``` + +Database logic should be defined as methods on this wrapper struct. + +To initialize this struct from the native language (both platforms now sync +their local storage to the cloud. These snippets disable that functionality for the +database file): + +#### Android + +```java +String path; +if (android.os.Build.VERSION.SDK_INT >=android.os.Build.VERSION_CODES.LOLLIPOP){ + path = getNoBackupFilesDir().getAbsolutePath(); +} else{ + path = getFilesDir().getAbsolutePath(); +} +Boltmobiledemo.BoltDB boltDB = Boltmobiledemo.NewBoltDB(path) +``` + +#### iOS + +```objc +- (void)demo { + NSString* path = [NSSearchPathForDirectoriesInDomains(NSLibraryDirectory, + NSUserDomainMask, + YES) objectAtIndex:0]; + GoBoltmobiledemoBoltDB * demo = GoBoltmobiledemoNewBoltDB(path); + [self addSkipBackupAttributeToItemAtPath:demo.path]; + //Some DB Logic would go here + [demo close]; +} + +- (BOOL)addSkipBackupAttributeToItemAtPath:(NSString *) filePathString +{ + NSURL* URL= [NSURL fileURLWithPath: filePathString]; + assert([[NSFileManager defaultManager] fileExistsAtPath: [URL path]]); + + NSError *error = nil; + BOOL success = [URL setResourceValue: [NSNumber numberWithBool: YES] + forKey: NSURLIsExcludedFromBackupKey error: &error]; + if(!success){ + NSLog(@"Error excluding %@ from backup %@", [URL lastPathComponent], error); + } + return success; +} + +``` + +## Resources + +For more information on getting started with Bolt, check out the following articles: + +* [Intro to BoltDB: Painless Performant Persistence](http://npf.io/2014/07/intro-to-boltdb-painless-performant-persistence/) by [Nate Finch](https://github.com/natefinch). +* [Bolt -- an embedded key/value database for Go](https://www.progville.com/go/bolt-embedded-db-golang/) by Progville + + +## Comparison with other databases + +### Postgres, MySQL, & other relational databases + +Relational databases structure data into rows and are only accessible through +the use of SQL. This approach provides flexibility in how you store and query +your data but also incurs overhead in parsing and planning SQL statements. Bolt +accesses all data by a byte slice key. This makes Bolt fast to read and write +data by key but provides no built-in support for joining values together. + +Most relational databases (with the exception of SQLite) are standalone servers +that run separately from your application. This gives your systems +flexibility to connect multiple application servers to a single database +server but also adds overhead in serializing and transporting data over the +network. Bolt runs as a library included in your application so all data access +has to go through your application's process. This brings data closer to your +application but limits multi-process access to the data. + + +### LevelDB, RocksDB + +LevelDB and its derivatives (RocksDB, HyperLevelDB) are similar to Bolt in that +they are libraries bundled into the application, however, their underlying +structure is a log-structured merge-tree (LSM tree). An LSM tree optimizes +random writes by using a write ahead log and multi-tiered, sorted files called +SSTables. Bolt uses a B+tree internally and only a single file. Both approaches +have trade-offs. + +If you require a high random write throughput (>10,000 w/sec) or you need to use +spinning disks then LevelDB could be a good choice. If your application is +read-heavy or does a lot of range scans then Bolt could be a good choice. + +One other important consideration is that LevelDB does not have transactions. +It supports batch writing of key/values pairs and it supports read snapshots +but it will not give you the ability to do a compare-and-swap operation safely. +Bolt supports fully serializable ACID transactions. + + +### LMDB + +Bolt was originally a port of LMDB so it is architecturally similar. Both use +a B+tree, have ACID semantics with fully serializable transactions, and support +lock-free MVCC using a single writer and multiple readers. + +The two projects have somewhat diverged. LMDB heavily focuses on raw performance +while Bolt has focused on simplicity and ease of use. For example, LMDB allows +several unsafe actions such as direct writes for the sake of performance. Bolt +opts to disallow actions which can leave the database in a corrupted state. The +only exception to this in Bolt is `DB.NoSync`. + +There are also a few differences in API. LMDB requires a maximum mmap size when +opening an `mdb_env` whereas Bolt will handle incremental mmap resizing +automatically. LMDB overloads the getter and setter functions with multiple +flags whereas Bolt splits these specialized cases into their own functions. + + +## Caveats & Limitations + +It's important to pick the right tool for the job and Bolt is no exception. +Here are a few things to note when evaluating and using Bolt: + +* Bolt is good for read intensive workloads. Sequential write performance is + also fast but random writes can be slow. You can use `DB.Batch()` or add a + write-ahead log to help mitigate this issue. + +* Bolt uses a B+tree internally so there can be a lot of random page access. + SSDs provide a significant performance boost over spinning disks. + +* Try to avoid long running read transactions. Bolt uses copy-on-write so + old pages cannot be reclaimed while an old transaction is using them. + +* Byte slices returned from Bolt are only valid during a transaction. Once the + transaction has been committed or rolled back then the memory they point to + can be reused by a new page or can be unmapped from virtual memory and you'll + see an `unexpected fault address` panic when accessing it. + +* Be careful when using `Bucket.FillPercent`. Setting a high fill percent for + buckets that have random inserts will cause your database to have very poor + page utilization. + +* Use larger buckets in general. Smaller buckets causes poor page utilization + once they become larger than the page size (typically 4KB). + +* Bulk loading a lot of random writes into a new bucket can be slow as the + page will not split until the transaction is committed. Randomly inserting + more than 100,000 key/value pairs into a single new bucket in a single + transaction is not advised. + +* Bolt uses a memory-mapped file so the underlying operating system handles the + caching of the data. Typically, the OS will cache as much of the file as it + can in memory and will release memory as needed to other processes. This means + that Bolt can show very high memory usage when working with large databases. + However, this is expected and the OS will release memory as needed. Bolt can + handle databases much larger than the available physical RAM, provided its + memory-map fits in the process virtual address space. It may be problematic + on 32-bits systems. + +* The data structures in the Bolt database are memory mapped so the data file + will be endian specific. This means that you cannot copy a Bolt file from a + little endian machine to a big endian machine and have it work. For most + users this is not a concern since most modern CPUs are little endian. + +* Because of the way pages are laid out on disk, Bolt cannot truncate data files + and return free pages back to the disk. Instead, Bolt maintains a free list + of unused pages within its data file. These free pages can be reused by later + transactions. This works well for many use cases as databases generally tend + to grow. However, it's important to note that deleting large chunks of data + will not allow you to reclaim that space on disk. + + For more information on page allocation, [see this comment][page-allocation]. + +[page-allocation]: https://github.com/boltdb/bolt/issues/308#issuecomment-74811638 + + +## Reading the Source + +Bolt is a relatively small code base (<3KLOC) for an embedded, serializable, +transactional key/value database so it can be a good starting point for people +interested in how databases work. + +The best places to start are the main entry points into Bolt: + +- `Open()` - Initializes the reference to the database. It's responsible for + creating the database if it doesn't exist, obtaining an exclusive lock on the + file, reading the meta pages, & memory-mapping the file. + +- `DB.Begin()` - Starts a read-only or read-write transaction depending on the + value of the `writable` argument. This requires briefly obtaining the "meta" + lock to keep track of open transactions. Only one read-write transaction can + exist at a time so the "rwlock" is acquired during the life of a read-write + transaction. + +- `Bucket.Put()` - Writes a key/value pair into a bucket. After validating the + arguments, a cursor is used to traverse the B+tree to the page and position + where they key & value will be written. Once the position is found, the bucket + materializes the underlying page and the page's parent pages into memory as + "nodes". These nodes are where mutations occur during read-write transactions. + These changes get flushed to disk during commit. + +- `Bucket.Get()` - Retrieves a key/value pair from a bucket. This uses a cursor + to move to the page & position of a key/value pair. During a read-only + transaction, the key and value data is returned as a direct reference to the + underlying mmap file so there's no allocation overhead. For read-write + transactions, this data may reference the mmap file or one of the in-memory + node values. + +- `Cursor` - This object is simply for traversing the B+tree of on-disk pages + or in-memory nodes. It can seek to a specific key, move to the first or last + value, or it can move forward or backward. The cursor handles the movement up + and down the B+tree transparently to the end user. + +- `Tx.Commit()` - Converts the in-memory dirty nodes and the list of free pages + into pages to be written to disk. Writing to disk then occurs in two phases. + First, the dirty pages are written to disk and an `fsync()` occurs. Second, a + new meta page with an incremented transaction ID is written and another + `fsync()` occurs. This two phase write ensures that partially written data + pages are ignored in the event of a crash since the meta page pointing to them + is never written. Partially written meta pages are invalidated because they + are written with a checksum. + +If you have additional notes that could be helpful for others, please submit +them via pull request. + + +## Other Projects Using Bolt + +Below is a list of public, open source projects that use Bolt: + +* [BoltDbWeb](https://github.com/evnix/boltdbweb) - A web based GUI for BoltDB files. +* [Operation Go: A Routine Mission](http://gocode.io) - An online programming game for Golang using Bolt for user accounts and a leaderboard. +* [Bazil](https://bazil.org/) - A file system that lets your data reside where it is most convenient for it to reside. +* [DVID](https://github.com/janelia-flyem/dvid) - Added Bolt as optional storage engine and testing it against Basho-tuned leveldb. +* [Skybox Analytics](https://github.com/skybox/skybox) - A standalone funnel analysis tool for web analytics. +* [Scuttlebutt](https://github.com/benbjohnson/scuttlebutt) - Uses Bolt to store and process all Twitter mentions of GitHub projects. +* [Wiki](https://github.com/peterhellberg/wiki) - A tiny wiki using Goji, BoltDB and Blackfriday. +* [ChainStore](https://github.com/pressly/chainstore) - Simple key-value interface to a variety of storage engines organized as a chain of operations. +* [MetricBase](https://github.com/msiebuhr/MetricBase) - Single-binary version of Graphite. +* [Gitchain](https://github.com/gitchain/gitchain) - Decentralized, peer-to-peer Git repositories aka "Git meets Bitcoin". +* [event-shuttle](https://github.com/sclasen/event-shuttle) - A Unix system service to collect and reliably deliver messages to Kafka. +* [ipxed](https://github.com/kelseyhightower/ipxed) - Web interface and api for ipxed. +* [BoltStore](https://github.com/yosssi/boltstore) - Session store using Bolt. +* [photosite/session](https://godoc.org/bitbucket.org/kardianos/photosite/session) - Sessions for a photo viewing site. +* [LedisDB](https://github.com/siddontang/ledisdb) - A high performance NoSQL, using Bolt as optional storage. +* [ipLocator](https://github.com/AndreasBriese/ipLocator) - A fast ip-geo-location-server using bolt with bloom filters. +* [cayley](https://github.com/google/cayley) - Cayley is an open-source graph database using Bolt as optional backend. +* [bleve](http://www.blevesearch.com/) - A pure Go search engine similar to ElasticSearch that uses Bolt as the default storage backend. +* [tentacool](https://github.com/optiflows/tentacool) - REST api server to manage system stuff (IP, DNS, Gateway...) on a linux server. +* [Seaweed File System](https://github.com/chrislusf/seaweedfs) - Highly scalable distributed key~file system with O(1) disk read. +* [InfluxDB](https://influxdata.com) - Scalable datastore for metrics, events, and real-time analytics. +* [Freehold](http://tshannon.bitbucket.org/freehold/) - An open, secure, and lightweight platform for your files and data. +* [Prometheus Annotation Server](https://github.com/oliver006/prom_annotation_server) - Annotation server for PromDash & Prometheus service monitoring system. +* [Consul](https://github.com/hashicorp/consul) - Consul is service discovery and configuration made easy. Distributed, highly available, and datacenter-aware. +* [Kala](https://github.com/ajvb/kala) - Kala is a modern job scheduler optimized to run on a single node. It is persistent, JSON over HTTP API, ISO 8601 duration notation, and dependent jobs. +* [drive](https://github.com/odeke-em/drive) - drive is an unofficial Google Drive command line client for \*NIX operating systems. +* [stow](https://github.com/djherbis/stow) - a persistence manager for objects + backed by boltdb. +* [buckets](https://github.com/joyrexus/buckets) - a bolt wrapper streamlining + simple tx and key scans. +* [mbuckets](https://github.com/abhigupta912/mbuckets) - A Bolt wrapper that allows easy operations on multi level (nested) buckets. +* [Request Baskets](https://github.com/darklynx/request-baskets) - A web service to collect arbitrary HTTP requests and inspect them via REST API or simple web UI, similar to [RequestBin](http://requestb.in/) service +* [Go Report Card](https://goreportcard.com/) - Go code quality report cards as a (free and open source) service. +* [Boltdb Boilerplate](https://github.com/bobintornado/boltdb-boilerplate) - Boilerplate wrapper around bolt aiming to make simple calls one-liners. +* [lru](https://github.com/crowdriff/lru) - Easy to use Bolt-backed Least-Recently-Used (LRU) read-through cache with chainable remote stores. +* [Storm](https://github.com/asdine/storm) - Simple and powerful ORM for BoltDB. +* [GoWebApp](https://github.com/josephspurrier/gowebapp) - A basic MVC web application in Go using BoltDB. +* [SimpleBolt](https://github.com/xyproto/simplebolt) - A simple way to use BoltDB. Deals mainly with strings. +* [Algernon](https://github.com/xyproto/algernon) - A HTTP/2 web server with built-in support for Lua. Uses BoltDB as the default database backend. +* [MuLiFS](https://github.com/dankomiocevic/mulifs) - Music Library Filesystem creates a filesystem to organise your music files. +* [GoShort](https://github.com/pankajkhairnar/goShort) - GoShort is a URL shortener written in Golang and BoltDB for persistent key/value storage and for routing it's using high performent HTTPRouter. + +If you are using Bolt in a project please send a pull request to add it to the list. diff --git a/vendor/github.com/boltdb/bolt/appveyor.yml b/vendor/github.com/boltdb/bolt/appveyor.yml new file mode 100644 index 00000000..6e26e941 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/appveyor.yml @@ -0,0 +1,18 @@ +version: "{build}" + +os: Windows Server 2012 R2 + +clone_folder: c:\gopath\src\github.com\boltdb\bolt + +environment: + GOPATH: c:\gopath + +install: + - echo %PATH% + - echo %GOPATH% + - go version + - go env + - go get -v -t ./... + +build_script: + - go test -v ./... diff --git a/vendor/github.com/boltdb/bolt/bolt_386.go b/vendor/github.com/boltdb/bolt/bolt_386.go new file mode 100644 index 00000000..e659bfb9 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/bolt_386.go @@ -0,0 +1,7 @@ +package bolt + +// maxMapSize represents the largest mmap size supported by Bolt. +const maxMapSize = 0x7FFFFFFF // 2GB + +// maxAllocSize is the size used when creating array pointers. +const maxAllocSize = 0xFFFFFFF diff --git a/vendor/github.com/boltdb/bolt/bolt_amd64.go b/vendor/github.com/boltdb/bolt/bolt_amd64.go new file mode 100644 index 00000000..cca6b7eb --- /dev/null +++ b/vendor/github.com/boltdb/bolt/bolt_amd64.go @@ -0,0 +1,7 @@ +package bolt + +// maxMapSize represents the largest mmap size supported by Bolt. +const maxMapSize = 0xFFFFFFFFFFFF // 256TB + +// maxAllocSize is the size used when creating array pointers. +const maxAllocSize = 0x7FFFFFFF diff --git a/vendor/github.com/boltdb/bolt/bolt_arm.go b/vendor/github.com/boltdb/bolt/bolt_arm.go new file mode 100644 index 00000000..e659bfb9 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/bolt_arm.go @@ -0,0 +1,7 @@ +package bolt + +// maxMapSize represents the largest mmap size supported by Bolt. +const maxMapSize = 0x7FFFFFFF // 2GB + +// maxAllocSize is the size used when creating array pointers. +const maxAllocSize = 0xFFFFFFF diff --git a/vendor/github.com/boltdb/bolt/bolt_arm64.go b/vendor/github.com/boltdb/bolt/bolt_arm64.go new file mode 100644 index 00000000..6d230935 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/bolt_arm64.go @@ -0,0 +1,9 @@ +// +build arm64 + +package bolt + +// maxMapSize represents the largest mmap size supported by Bolt. +const maxMapSize = 0xFFFFFFFFFFFF // 256TB + +// maxAllocSize is the size used when creating array pointers. +const maxAllocSize = 0x7FFFFFFF diff --git a/vendor/github.com/boltdb/bolt/bolt_linux.go b/vendor/github.com/boltdb/bolt/bolt_linux.go new file mode 100644 index 00000000..2b676661 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/bolt_linux.go @@ -0,0 +1,10 @@ +package bolt + +import ( + "syscall" +) + +// fdatasync flushes written data to a file descriptor. +func fdatasync(db *DB) error { + return syscall.Fdatasync(int(db.file.Fd())) +} diff --git a/vendor/github.com/boltdb/bolt/bolt_openbsd.go b/vendor/github.com/boltdb/bolt/bolt_openbsd.go new file mode 100644 index 00000000..7058c3d7 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/bolt_openbsd.go @@ -0,0 +1,27 @@ +package bolt + +import ( + "syscall" + "unsafe" +) + +const ( + msAsync = 1 << iota // perform asynchronous writes + msSync // perform synchronous writes + msInvalidate // invalidate cached data +) + +func msync(db *DB) error { + _, _, errno := syscall.Syscall(syscall.SYS_MSYNC, uintptr(unsafe.Pointer(db.data)), uintptr(db.datasz), msInvalidate) + if errno != 0 { + return errno + } + return nil +} + +func fdatasync(db *DB) error { + if db.data != nil { + return msync(db) + } + return db.file.Sync() +} diff --git a/vendor/github.com/boltdb/bolt/bolt_ppc.go b/vendor/github.com/boltdb/bolt/bolt_ppc.go new file mode 100644 index 00000000..645ddc3e --- /dev/null +++ b/vendor/github.com/boltdb/bolt/bolt_ppc.go @@ -0,0 +1,9 @@ +// +build ppc + +package bolt + +// maxMapSize represents the largest mmap size supported by Bolt. +const maxMapSize = 0x7FFFFFFF // 2GB + +// maxAllocSize is the size used when creating array pointers. +const maxAllocSize = 0xFFFFFFF diff --git a/vendor/github.com/boltdb/bolt/bolt_ppc64.go b/vendor/github.com/boltdb/bolt/bolt_ppc64.go new file mode 100644 index 00000000..2dc6be02 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/bolt_ppc64.go @@ -0,0 +1,9 @@ +// +build ppc64 + +package bolt + +// maxMapSize represents the largest mmap size supported by Bolt. +const maxMapSize = 0xFFFFFFFFFFFF // 256TB + +// maxAllocSize is the size used when creating array pointers. +const maxAllocSize = 0x7FFFFFFF diff --git a/vendor/github.com/boltdb/bolt/bolt_ppc64le.go b/vendor/github.com/boltdb/bolt/bolt_ppc64le.go new file mode 100644 index 00000000..8351e129 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/bolt_ppc64le.go @@ -0,0 +1,9 @@ +// +build ppc64le + +package bolt + +// maxMapSize represents the largest mmap size supported by Bolt. +const maxMapSize = 0xFFFFFFFFFFFF // 256TB + +// maxAllocSize is the size used when creating array pointers. +const maxAllocSize = 0x7FFFFFFF diff --git a/vendor/github.com/boltdb/bolt/bolt_s390x.go b/vendor/github.com/boltdb/bolt/bolt_s390x.go new file mode 100644 index 00000000..f4dd26bb --- /dev/null +++ b/vendor/github.com/boltdb/bolt/bolt_s390x.go @@ -0,0 +1,9 @@ +// +build s390x + +package bolt + +// maxMapSize represents the largest mmap size supported by Bolt. +const maxMapSize = 0xFFFFFFFFFFFF // 256TB + +// maxAllocSize is the size used when creating array pointers. +const maxAllocSize = 0x7FFFFFFF diff --git a/vendor/github.com/boltdb/bolt/bolt_unix.go b/vendor/github.com/boltdb/bolt/bolt_unix.go new file mode 100644 index 00000000..cad62dda --- /dev/null +++ b/vendor/github.com/boltdb/bolt/bolt_unix.go @@ -0,0 +1,89 @@ +// +build !windows,!plan9,!solaris + +package bolt + +import ( + "fmt" + "os" + "syscall" + "time" + "unsafe" +) + +// flock acquires an advisory lock on a file descriptor. +func flock(db *DB, mode os.FileMode, exclusive bool, timeout time.Duration) error { + var t time.Time + for { + // If we're beyond our timeout then return an error. + // This can only occur after we've attempted a flock once. + if t.IsZero() { + t = time.Now() + } else if timeout > 0 && time.Since(t) > timeout { + return ErrTimeout + } + flag := syscall.LOCK_SH + if exclusive { + flag = syscall.LOCK_EX + } + + // Otherwise attempt to obtain an exclusive lock. + err := syscall.Flock(int(db.file.Fd()), flag|syscall.LOCK_NB) + if err == nil { + return nil + } else if err != syscall.EWOULDBLOCK { + return err + } + + // Wait for a bit and try again. + time.Sleep(50 * time.Millisecond) + } +} + +// funlock releases an advisory lock on a file descriptor. +func funlock(db *DB) error { + return syscall.Flock(int(db.file.Fd()), syscall.LOCK_UN) +} + +// mmap memory maps a DB's data file. +func mmap(db *DB, sz int) error { + // Map the data file to memory. + b, err := syscall.Mmap(int(db.file.Fd()), 0, sz, syscall.PROT_READ, syscall.MAP_SHARED|db.MmapFlags) + if err != nil { + return err + } + + // Advise the kernel that the mmap is accessed randomly. + if err := madvise(b, syscall.MADV_RANDOM); err != nil { + return fmt.Errorf("madvise: %s", err) + } + + // Save the original byte slice and convert to a byte array pointer. + db.dataref = b + db.data = (*[maxMapSize]byte)(unsafe.Pointer(&b[0])) + db.datasz = sz + return nil +} + +// munmap unmaps a DB's data file from memory. +func munmap(db *DB) error { + // Ignore the unmap if we have no mapped data. + if db.dataref == nil { + return nil + } + + // Unmap using the original byte slice. + err := syscall.Munmap(db.dataref) + db.dataref = nil + db.data = nil + db.datasz = 0 + return err +} + +// NOTE: This function is copied from stdlib because it is not available on darwin. +func madvise(b []byte, advice int) (err error) { + _, _, e1 := syscall.Syscall(syscall.SYS_MADVISE, uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)), uintptr(advice)) + if e1 != 0 { + err = e1 + } + return +} diff --git a/vendor/github.com/boltdb/bolt/bolt_unix_solaris.go b/vendor/github.com/boltdb/bolt/bolt_unix_solaris.go new file mode 100644 index 00000000..307bf2b3 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/bolt_unix_solaris.go @@ -0,0 +1,90 @@ +package bolt + +import ( + "fmt" + "os" + "syscall" + "time" + "unsafe" + + "golang.org/x/sys/unix" +) + +// flock acquires an advisory lock on a file descriptor. +func flock(db *DB, mode os.FileMode, exclusive bool, timeout time.Duration) error { + var t time.Time + for { + // If we're beyond our timeout then return an error. + // This can only occur after we've attempted a flock once. + if t.IsZero() { + t = time.Now() + } else if timeout > 0 && time.Since(t) > timeout { + return ErrTimeout + } + var lock syscall.Flock_t + lock.Start = 0 + lock.Len = 0 + lock.Pid = 0 + lock.Whence = 0 + lock.Pid = 0 + if exclusive { + lock.Type = syscall.F_WRLCK + } else { + lock.Type = syscall.F_RDLCK + } + err := syscall.FcntlFlock(db.file.Fd(), syscall.F_SETLK, &lock) + if err == nil { + return nil + } else if err != syscall.EAGAIN { + return err + } + + // Wait for a bit and try again. + time.Sleep(50 * time.Millisecond) + } +} + +// funlock releases an advisory lock on a file descriptor. +func funlock(db *DB) error { + var lock syscall.Flock_t + lock.Start = 0 + lock.Len = 0 + lock.Type = syscall.F_UNLCK + lock.Whence = 0 + return syscall.FcntlFlock(uintptr(db.file.Fd()), syscall.F_SETLK, &lock) +} + +// mmap memory maps a DB's data file. +func mmap(db *DB, sz int) error { + // Map the data file to memory. + b, err := unix.Mmap(int(db.file.Fd()), 0, sz, syscall.PROT_READ, syscall.MAP_SHARED|db.MmapFlags) + if err != nil { + return err + } + + // Advise the kernel that the mmap is accessed randomly. + if err := unix.Madvise(b, syscall.MADV_RANDOM); err != nil { + return fmt.Errorf("madvise: %s", err) + } + + // Save the original byte slice and convert to a byte array pointer. + db.dataref = b + db.data = (*[maxMapSize]byte)(unsafe.Pointer(&b[0])) + db.datasz = sz + return nil +} + +// munmap unmaps a DB's data file from memory. +func munmap(db *DB) error { + // Ignore the unmap if we have no mapped data. + if db.dataref == nil { + return nil + } + + // Unmap using the original byte slice. + err := unix.Munmap(db.dataref) + db.dataref = nil + db.data = nil + db.datasz = 0 + return err +} diff --git a/vendor/github.com/boltdb/bolt/bolt_windows.go b/vendor/github.com/boltdb/bolt/bolt_windows.go new file mode 100644 index 00000000..d538e6af --- /dev/null +++ b/vendor/github.com/boltdb/bolt/bolt_windows.go @@ -0,0 +1,144 @@ +package bolt + +import ( + "fmt" + "os" + "syscall" + "time" + "unsafe" +) + +// LockFileEx code derived from golang build filemutex_windows.go @ v1.5.1 +var ( + modkernel32 = syscall.NewLazyDLL("kernel32.dll") + procLockFileEx = modkernel32.NewProc("LockFileEx") + procUnlockFileEx = modkernel32.NewProc("UnlockFileEx") +) + +const ( + lockExt = ".lock" + + // see https://msdn.microsoft.com/en-us/library/windows/desktop/aa365203(v=vs.85).aspx + flagLockExclusive = 2 + flagLockFailImmediately = 1 + + // see https://msdn.microsoft.com/en-us/library/windows/desktop/ms681382(v=vs.85).aspx + errLockViolation syscall.Errno = 0x21 +) + +func lockFileEx(h syscall.Handle, flags, reserved, locklow, lockhigh uint32, ol *syscall.Overlapped) (err error) { + r, _, err := procLockFileEx.Call(uintptr(h), uintptr(flags), uintptr(reserved), uintptr(locklow), uintptr(lockhigh), uintptr(unsafe.Pointer(ol))) + if r == 0 { + return err + } + return nil +} + +func unlockFileEx(h syscall.Handle, reserved, locklow, lockhigh uint32, ol *syscall.Overlapped) (err error) { + r, _, err := procUnlockFileEx.Call(uintptr(h), uintptr(reserved), uintptr(locklow), uintptr(lockhigh), uintptr(unsafe.Pointer(ol)), 0) + if r == 0 { + return err + } + return nil +} + +// fdatasync flushes written data to a file descriptor. +func fdatasync(db *DB) error { + return db.file.Sync() +} + +// flock acquires an advisory lock on a file descriptor. +func flock(db *DB, mode os.FileMode, exclusive bool, timeout time.Duration) error { + // Create a separate lock file on windows because a process + // cannot share an exclusive lock on the same file. This is + // needed during Tx.WriteTo(). + f, err := os.OpenFile(db.path+lockExt, os.O_CREATE, mode) + if err != nil { + return err + } + db.lockfile = f + + var t time.Time + for { + // If we're beyond our timeout then return an error. + // This can only occur after we've attempted a flock once. + if t.IsZero() { + t = time.Now() + } else if timeout > 0 && time.Since(t) > timeout { + return ErrTimeout + } + + var flag uint32 = flagLockFailImmediately + if exclusive { + flag |= flagLockExclusive + } + + err := lockFileEx(syscall.Handle(db.lockfile.Fd()), flag, 0, 1, 0, &syscall.Overlapped{}) + if err == nil { + return nil + } else if err != errLockViolation { + return err + } + + // Wait for a bit and try again. + time.Sleep(50 * time.Millisecond) + } +} + +// funlock releases an advisory lock on a file descriptor. +func funlock(db *DB) error { + err := unlockFileEx(syscall.Handle(db.lockfile.Fd()), 0, 1, 0, &syscall.Overlapped{}) + db.lockfile.Close() + os.Remove(db.path+lockExt) + return err +} + +// mmap memory maps a DB's data file. +// Based on: https://github.com/edsrzf/mmap-go +func mmap(db *DB, sz int) error { + if !db.readOnly { + // Truncate the database to the size of the mmap. + if err := db.file.Truncate(int64(sz)); err != nil { + return fmt.Errorf("truncate: %s", err) + } + } + + // Open a file mapping handle. + sizelo := uint32(sz >> 32) + sizehi := uint32(sz) & 0xffffffff + h, errno := syscall.CreateFileMapping(syscall.Handle(db.file.Fd()), nil, syscall.PAGE_READONLY, sizelo, sizehi, nil) + if h == 0 { + return os.NewSyscallError("CreateFileMapping", errno) + } + + // Create the memory map. + addr, errno := syscall.MapViewOfFile(h, syscall.FILE_MAP_READ, 0, 0, uintptr(sz)) + if addr == 0 { + return os.NewSyscallError("MapViewOfFile", errno) + } + + // Close mapping handle. + if err := syscall.CloseHandle(syscall.Handle(h)); err != nil { + return os.NewSyscallError("CloseHandle", err) + } + + // Convert to a byte array. + db.data = ((*[maxMapSize]byte)(unsafe.Pointer(addr))) + db.datasz = sz + + return nil +} + +// munmap unmaps a pointer from a file. +// Based on: https://github.com/edsrzf/mmap-go +func munmap(db *DB) error { + if db.data == nil { + return nil + } + + addr := (uintptr)(unsafe.Pointer(&db.data[0])) + if err := syscall.UnmapViewOfFile(addr); err != nil { + return os.NewSyscallError("UnmapViewOfFile", err) + } + return nil +} diff --git a/vendor/github.com/boltdb/bolt/boltsync_unix.go b/vendor/github.com/boltdb/bolt/boltsync_unix.go new file mode 100644 index 00000000..f5044252 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/boltsync_unix.go @@ -0,0 +1,8 @@ +// +build !windows,!plan9,!linux,!openbsd + +package bolt + +// fdatasync flushes written data to a file descriptor. +func fdatasync(db *DB) error { + return db.file.Sync() +} diff --git a/vendor/github.com/boltdb/bolt/bucket.go b/vendor/github.com/boltdb/bolt/bucket.go new file mode 100644 index 00000000..d2f8c524 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/bucket.go @@ -0,0 +1,748 @@ +package bolt + +import ( + "bytes" + "fmt" + "unsafe" +) + +const ( + // MaxKeySize is the maximum length of a key, in bytes. + MaxKeySize = 32768 + + // MaxValueSize is the maximum length of a value, in bytes. + MaxValueSize = (1 << 31) - 2 +) + +const ( + maxUint = ^uint(0) + minUint = 0 + maxInt = int(^uint(0) >> 1) + minInt = -maxInt - 1 +) + +const bucketHeaderSize = int(unsafe.Sizeof(bucket{})) + +const ( + minFillPercent = 0.1 + maxFillPercent = 1.0 +) + +// DefaultFillPercent is the percentage that split pages are filled. +// This value can be changed by setting Bucket.FillPercent. +const DefaultFillPercent = 0.5 + +// Bucket represents a collection of key/value pairs inside the database. +type Bucket struct { + *bucket + tx *Tx // the associated transaction + buckets map[string]*Bucket // subbucket cache + page *page // inline page reference + rootNode *node // materialized node for the root page. + nodes map[pgid]*node // node cache + + // Sets the threshold for filling nodes when they split. By default, + // the bucket will fill to 50% but it can be useful to increase this + // amount if you know that your write workloads are mostly append-only. + // + // This is non-persisted across transactions so it must be set in every Tx. + FillPercent float64 +} + +// bucket represents the on-file representation of a bucket. +// This is stored as the "value" of a bucket key. If the bucket is small enough, +// then its root page can be stored inline in the "value", after the bucket +// header. In the case of inline buckets, the "root" will be 0. +type bucket struct { + root pgid // page id of the bucket's root-level page + sequence uint64 // monotonically incrementing, used by NextSequence() +} + +// newBucket returns a new bucket associated with a transaction. +func newBucket(tx *Tx) Bucket { + var b = Bucket{tx: tx, FillPercent: DefaultFillPercent} + if tx.writable { + b.buckets = make(map[string]*Bucket) + b.nodes = make(map[pgid]*node) + } + return b +} + +// Tx returns the tx of the bucket. +func (b *Bucket) Tx() *Tx { + return b.tx +} + +// Root returns the root of the bucket. +func (b *Bucket) Root() pgid { + return b.root +} + +// Writable returns whether the bucket is writable. +func (b *Bucket) Writable() bool { + return b.tx.writable +} + +// Cursor creates a cursor associated with the bucket. +// The cursor is only valid as long as the transaction is open. +// Do not use a cursor after the transaction is closed. +func (b *Bucket) Cursor() *Cursor { + // Update transaction statistics. + b.tx.stats.CursorCount++ + + // Allocate and return a cursor. + return &Cursor{ + bucket: b, + stack: make([]elemRef, 0), + } +} + +// Bucket retrieves a nested bucket by name. +// Returns nil if the bucket does not exist. +// The bucket instance is only valid for the lifetime of the transaction. +func (b *Bucket) Bucket(name []byte) *Bucket { + if b.buckets != nil { + if child := b.buckets[string(name)]; child != nil { + return child + } + } + + // Move cursor to key. + c := b.Cursor() + k, v, flags := c.seek(name) + + // Return nil if the key doesn't exist or it is not a bucket. + if !bytes.Equal(name, k) || (flags&bucketLeafFlag) == 0 { + return nil + } + + // Otherwise create a bucket and cache it. + var child = b.openBucket(v) + if b.buckets != nil { + b.buckets[string(name)] = child + } + + return child +} + +// Helper method that re-interprets a sub-bucket value +// from a parent into a Bucket +func (b *Bucket) openBucket(value []byte) *Bucket { + var child = newBucket(b.tx) + + // If this is a writable transaction then we need to copy the bucket entry. + // Read-only transactions can point directly at the mmap entry. + if b.tx.writable { + child.bucket = &bucket{} + *child.bucket = *(*bucket)(unsafe.Pointer(&value[0])) + } else { + child.bucket = (*bucket)(unsafe.Pointer(&value[0])) + } + + // Save a reference to the inline page if the bucket is inline. + if child.root == 0 { + child.page = (*page)(unsafe.Pointer(&value[bucketHeaderSize])) + } + + return &child +} + +// CreateBucket creates a new bucket at the given key and returns the new bucket. +// Returns an error if the key already exists, if the bucket name is blank, or if the bucket name is too long. +// The bucket instance is only valid for the lifetime of the transaction. +func (b *Bucket) CreateBucket(key []byte) (*Bucket, error) { + if b.tx.db == nil { + return nil, ErrTxClosed + } else if !b.tx.writable { + return nil, ErrTxNotWritable + } else if len(key) == 0 { + return nil, ErrBucketNameRequired + } + + // Move cursor to correct position. + c := b.Cursor() + k, _, flags := c.seek(key) + + // Return an error if there is an existing key. + if bytes.Equal(key, k) { + if (flags & bucketLeafFlag) != 0 { + return nil, ErrBucketExists + } else { + return nil, ErrIncompatibleValue + } + } + + // Create empty, inline bucket. + var bucket = Bucket{ + bucket: &bucket{}, + rootNode: &node{isLeaf: true}, + FillPercent: DefaultFillPercent, + } + var value = bucket.write() + + // Insert into node. + key = cloneBytes(key) + c.node().put(key, key, value, 0, bucketLeafFlag) + + // Since subbuckets are not allowed on inline buckets, we need to + // dereference the inline page, if it exists. This will cause the bucket + // to be treated as a regular, non-inline bucket for the rest of the tx. + b.page = nil + + return b.Bucket(key), nil +} + +// CreateBucketIfNotExists creates a new bucket if it doesn't already exist and returns a reference to it. +// Returns an error if the bucket name is blank, or if the bucket name is too long. +// The bucket instance is only valid for the lifetime of the transaction. +func (b *Bucket) CreateBucketIfNotExists(key []byte) (*Bucket, error) { + child, err := b.CreateBucket(key) + if err == ErrBucketExists { + return b.Bucket(key), nil + } else if err != nil { + return nil, err + } + return child, nil +} + +// DeleteBucket deletes a bucket at the given key. +// Returns an error if the bucket does not exists, or if the key represents a non-bucket value. +func (b *Bucket) DeleteBucket(key []byte) error { + if b.tx.db == nil { + return ErrTxClosed + } else if !b.Writable() { + return ErrTxNotWritable + } + + // Move cursor to correct position. + c := b.Cursor() + k, _, flags := c.seek(key) + + // Return an error if bucket doesn't exist or is not a bucket. + if !bytes.Equal(key, k) { + return ErrBucketNotFound + } else if (flags & bucketLeafFlag) == 0 { + return ErrIncompatibleValue + } + + // Recursively delete all child buckets. + child := b.Bucket(key) + err := child.ForEach(func(k, v []byte) error { + if v == nil { + if err := child.DeleteBucket(k); err != nil { + return fmt.Errorf("delete bucket: %s", err) + } + } + return nil + }) + if err != nil { + return err + } + + // Remove cached copy. + delete(b.buckets, string(key)) + + // Release all bucket pages to freelist. + child.nodes = nil + child.rootNode = nil + child.free() + + // Delete the node if we have a matching key. + c.node().del(key) + + return nil +} + +// Get retrieves the value for a key in the bucket. +// Returns a nil value if the key does not exist or if the key is a nested bucket. +// The returned value is only valid for the life of the transaction. +func (b *Bucket) Get(key []byte) []byte { + k, v, flags := b.Cursor().seek(key) + + // Return nil if this is a bucket. + if (flags & bucketLeafFlag) != 0 { + return nil + } + + // If our target node isn't the same key as what's passed in then return nil. + if !bytes.Equal(key, k) { + return nil + } + return v +} + +// Put sets the value for a key in the bucket. +// If the key exist then its previous value will be overwritten. +// Supplied value must remain valid for the life of the transaction. +// Returns an error if the bucket was created from a read-only transaction, if the key is blank, if the key is too large, or if the value is too large. +func (b *Bucket) Put(key []byte, value []byte) error { + if b.tx.db == nil { + return ErrTxClosed + } else if !b.Writable() { + return ErrTxNotWritable + } else if len(key) == 0 { + return ErrKeyRequired + } else if len(key) > MaxKeySize { + return ErrKeyTooLarge + } else if int64(len(value)) > MaxValueSize { + return ErrValueTooLarge + } + + // Move cursor to correct position. + c := b.Cursor() + k, _, flags := c.seek(key) + + // Return an error if there is an existing key with a bucket value. + if bytes.Equal(key, k) && (flags&bucketLeafFlag) != 0 { + return ErrIncompatibleValue + } + + // Insert into node. + key = cloneBytes(key) + c.node().put(key, key, value, 0, 0) + + return nil +} + +// Delete removes a key from the bucket. +// If the key does not exist then nothing is done and a nil error is returned. +// Returns an error if the bucket was created from a read-only transaction. +func (b *Bucket) Delete(key []byte) error { + if b.tx.db == nil { + return ErrTxClosed + } else if !b.Writable() { + return ErrTxNotWritable + } + + // Move cursor to correct position. + c := b.Cursor() + _, _, flags := c.seek(key) + + // Return an error if there is already existing bucket value. + if (flags & bucketLeafFlag) != 0 { + return ErrIncompatibleValue + } + + // Delete the node if we have a matching key. + c.node().del(key) + + return nil +} + +// NextSequence returns an autoincrementing integer for the bucket. +func (b *Bucket) NextSequence() (uint64, error) { + if b.tx.db == nil { + return 0, ErrTxClosed + } else if !b.Writable() { + return 0, ErrTxNotWritable + } + + // Materialize the root node if it hasn't been already so that the + // bucket will be saved during commit. + if b.rootNode == nil { + _ = b.node(b.root, nil) + } + + // Increment and return the sequence. + b.bucket.sequence++ + return b.bucket.sequence, nil +} + +// ForEach executes a function for each key/value pair in a bucket. +// If the provided function returns an error then the iteration is stopped and +// the error is returned to the caller. The provided function must not modify +// the bucket; this will result in undefined behavior. +func (b *Bucket) ForEach(fn func(k, v []byte) error) error { + if b.tx.db == nil { + return ErrTxClosed + } + c := b.Cursor() + for k, v := c.First(); k != nil; k, v = c.Next() { + if err := fn(k, v); err != nil { + return err + } + } + return nil +} + +// Stat returns stats on a bucket. +func (b *Bucket) Stats() BucketStats { + var s, subStats BucketStats + pageSize := b.tx.db.pageSize + s.BucketN += 1 + if b.root == 0 { + s.InlineBucketN += 1 + } + b.forEachPage(func(p *page, depth int) { + if (p.flags & leafPageFlag) != 0 { + s.KeyN += int(p.count) + + // used totals the used bytes for the page + used := pageHeaderSize + + if p.count != 0 { + // If page has any elements, add all element headers. + used += leafPageElementSize * int(p.count-1) + + // Add all element key, value sizes. + // The computation takes advantage of the fact that the position + // of the last element's key/value equals to the total of the sizes + // of all previous elements' keys and values. + // It also includes the last element's header. + lastElement := p.leafPageElement(p.count - 1) + used += int(lastElement.pos + lastElement.ksize + lastElement.vsize) + } + + if b.root == 0 { + // For inlined bucket just update the inline stats + s.InlineBucketInuse += used + } else { + // For non-inlined bucket update all the leaf stats + s.LeafPageN++ + s.LeafInuse += used + s.LeafOverflowN += int(p.overflow) + + // Collect stats from sub-buckets. + // Do that by iterating over all element headers + // looking for the ones with the bucketLeafFlag. + for i := uint16(0); i < p.count; i++ { + e := p.leafPageElement(i) + if (e.flags & bucketLeafFlag) != 0 { + // For any bucket element, open the element value + // and recursively call Stats on the contained bucket. + subStats.Add(b.openBucket(e.value()).Stats()) + } + } + } + } else if (p.flags & branchPageFlag) != 0 { + s.BranchPageN++ + lastElement := p.branchPageElement(p.count - 1) + + // used totals the used bytes for the page + // Add header and all element headers. + used := pageHeaderSize + (branchPageElementSize * int(p.count-1)) + + // Add size of all keys and values. + // Again, use the fact that last element's position equals to + // the total of key, value sizes of all previous elements. + used += int(lastElement.pos + lastElement.ksize) + s.BranchInuse += used + s.BranchOverflowN += int(p.overflow) + } + + // Keep track of maximum page depth. + if depth+1 > s.Depth { + s.Depth = (depth + 1) + } + }) + + // Alloc stats can be computed from page counts and pageSize. + s.BranchAlloc = (s.BranchPageN + s.BranchOverflowN) * pageSize + s.LeafAlloc = (s.LeafPageN + s.LeafOverflowN) * pageSize + + // Add the max depth of sub-buckets to get total nested depth. + s.Depth += subStats.Depth + // Add the stats for all sub-buckets + s.Add(subStats) + return s +} + +// forEachPage iterates over every page in a bucket, including inline pages. +func (b *Bucket) forEachPage(fn func(*page, int)) { + // If we have an inline page then just use that. + if b.page != nil { + fn(b.page, 0) + return + } + + // Otherwise traverse the page hierarchy. + b.tx.forEachPage(b.root, 0, fn) +} + +// forEachPageNode iterates over every page (or node) in a bucket. +// This also includes inline pages. +func (b *Bucket) forEachPageNode(fn func(*page, *node, int)) { + // If we have an inline page or root node then just use that. + if b.page != nil { + fn(b.page, nil, 0) + return + } + b._forEachPageNode(b.root, 0, fn) +} + +func (b *Bucket) _forEachPageNode(pgid pgid, depth int, fn func(*page, *node, int)) { + var p, n = b.pageNode(pgid) + + // Execute function. + fn(p, n, depth) + + // Recursively loop over children. + if p != nil { + if (p.flags & branchPageFlag) != 0 { + for i := 0; i < int(p.count); i++ { + elem := p.branchPageElement(uint16(i)) + b._forEachPageNode(elem.pgid, depth+1, fn) + } + } + } else { + if !n.isLeaf { + for _, inode := range n.inodes { + b._forEachPageNode(inode.pgid, depth+1, fn) + } + } + } +} + +// spill writes all the nodes for this bucket to dirty pages. +func (b *Bucket) spill() error { + // Spill all child buckets first. + for name, child := range b.buckets { + // If the child bucket is small enough and it has no child buckets then + // write it inline into the parent bucket's page. Otherwise spill it + // like a normal bucket and make the parent value a pointer to the page. + var value []byte + if child.inlineable() { + child.free() + value = child.write() + } else { + if err := child.spill(); err != nil { + return err + } + + // Update the child bucket header in this bucket. + value = make([]byte, unsafe.Sizeof(bucket{})) + var bucket = (*bucket)(unsafe.Pointer(&value[0])) + *bucket = *child.bucket + } + + // Skip writing the bucket if there are no materialized nodes. + if child.rootNode == nil { + continue + } + + // Update parent node. + var c = b.Cursor() + k, _, flags := c.seek([]byte(name)) + if !bytes.Equal([]byte(name), k) { + panic(fmt.Sprintf("misplaced bucket header: %x -> %x", []byte(name), k)) + } + if flags&bucketLeafFlag == 0 { + panic(fmt.Sprintf("unexpected bucket header flag: %x", flags)) + } + c.node().put([]byte(name), []byte(name), value, 0, bucketLeafFlag) + } + + // Ignore if there's not a materialized root node. + if b.rootNode == nil { + return nil + } + + // Spill nodes. + if err := b.rootNode.spill(); err != nil { + return err + } + b.rootNode = b.rootNode.root() + + // Update the root node for this bucket. + if b.rootNode.pgid >= b.tx.meta.pgid { + panic(fmt.Sprintf("pgid (%d) above high water mark (%d)", b.rootNode.pgid, b.tx.meta.pgid)) + } + b.root = b.rootNode.pgid + + return nil +} + +// inlineable returns true if a bucket is small enough to be written inline +// and if it contains no subbuckets. Otherwise returns false. +func (b *Bucket) inlineable() bool { + var n = b.rootNode + + // Bucket must only contain a single leaf node. + if n == nil || !n.isLeaf { + return false + } + + // Bucket is not inlineable if it contains subbuckets or if it goes beyond + // our threshold for inline bucket size. + var size = pageHeaderSize + for _, inode := range n.inodes { + size += leafPageElementSize + len(inode.key) + len(inode.value) + + if inode.flags&bucketLeafFlag != 0 { + return false + } else if size > b.maxInlineBucketSize() { + return false + } + } + + return true +} + +// Returns the maximum total size of a bucket to make it a candidate for inlining. +func (b *Bucket) maxInlineBucketSize() int { + return b.tx.db.pageSize / 4 +} + +// write allocates and writes a bucket to a byte slice. +func (b *Bucket) write() []byte { + // Allocate the appropriate size. + var n = b.rootNode + var value = make([]byte, bucketHeaderSize+n.size()) + + // Write a bucket header. + var bucket = (*bucket)(unsafe.Pointer(&value[0])) + *bucket = *b.bucket + + // Convert byte slice to a fake page and write the root node. + var p = (*page)(unsafe.Pointer(&value[bucketHeaderSize])) + n.write(p) + + return value +} + +// rebalance attempts to balance all nodes. +func (b *Bucket) rebalance() { + for _, n := range b.nodes { + n.rebalance() + } + for _, child := range b.buckets { + child.rebalance() + } +} + +// node creates a node from a page and associates it with a given parent. +func (b *Bucket) node(pgid pgid, parent *node) *node { + _assert(b.nodes != nil, "nodes map expected") + + // Retrieve node if it's already been created. + if n := b.nodes[pgid]; n != nil { + return n + } + + // Otherwise create a node and cache it. + n := &node{bucket: b, parent: parent} + if parent == nil { + b.rootNode = n + } else { + parent.children = append(parent.children, n) + } + + // Use the inline page if this is an inline bucket. + var p = b.page + if p == nil { + p = b.tx.page(pgid) + } + + // Read the page into the node and cache it. + n.read(p) + b.nodes[pgid] = n + + // Update statistics. + b.tx.stats.NodeCount++ + + return n +} + +// free recursively frees all pages in the bucket. +func (b *Bucket) free() { + if b.root == 0 { + return + } + + var tx = b.tx + b.forEachPageNode(func(p *page, n *node, _ int) { + if p != nil { + tx.db.freelist.free(tx.meta.txid, p) + } else { + n.free() + } + }) + b.root = 0 +} + +// dereference removes all references to the old mmap. +func (b *Bucket) dereference() { + if b.rootNode != nil { + b.rootNode.root().dereference() + } + + for _, child := range b.buckets { + child.dereference() + } +} + +// pageNode returns the in-memory node, if it exists. +// Otherwise returns the underlying page. +func (b *Bucket) pageNode(id pgid) (*page, *node) { + // Inline buckets have a fake page embedded in their value so treat them + // differently. We'll return the rootNode (if available) or the fake page. + if b.root == 0 { + if id != 0 { + panic(fmt.Sprintf("inline bucket non-zero page access(2): %d != 0", id)) + } + if b.rootNode != nil { + return nil, b.rootNode + } + return b.page, nil + } + + // Check the node cache for non-inline buckets. + if b.nodes != nil { + if n := b.nodes[id]; n != nil { + return nil, n + } + } + + // Finally lookup the page from the transaction if no node is materialized. + return b.tx.page(id), nil +} + +// BucketStats records statistics about resources used by a bucket. +type BucketStats struct { + // Page count statistics. + BranchPageN int // number of logical branch pages + BranchOverflowN int // number of physical branch overflow pages + LeafPageN int // number of logical leaf pages + LeafOverflowN int // number of physical leaf overflow pages + + // Tree statistics. + KeyN int // number of keys/value pairs + Depth int // number of levels in B+tree + + // Page size utilization. + BranchAlloc int // bytes allocated for physical branch pages + BranchInuse int // bytes actually used for branch data + LeafAlloc int // bytes allocated for physical leaf pages + LeafInuse int // bytes actually used for leaf data + + // Bucket statistics + BucketN int // total number of buckets including the top bucket + InlineBucketN int // total number on inlined buckets + InlineBucketInuse int // bytes used for inlined buckets (also accounted for in LeafInuse) +} + +func (s *BucketStats) Add(other BucketStats) { + s.BranchPageN += other.BranchPageN + s.BranchOverflowN += other.BranchOverflowN + s.LeafPageN += other.LeafPageN + s.LeafOverflowN += other.LeafOverflowN + s.KeyN += other.KeyN + if s.Depth < other.Depth { + s.Depth = other.Depth + } + s.BranchAlloc += other.BranchAlloc + s.BranchInuse += other.BranchInuse + s.LeafAlloc += other.LeafAlloc + s.LeafInuse += other.LeafInuse + + s.BucketN += other.BucketN + s.InlineBucketN += other.InlineBucketN + s.InlineBucketInuse += other.InlineBucketInuse +} + +// cloneBytes returns a copy of a given slice. +func cloneBytes(v []byte) []byte { + var clone = make([]byte, len(v)) + copy(clone, v) + return clone +} diff --git a/vendor/github.com/boltdb/bolt/bucket_test.go b/vendor/github.com/boltdb/bolt/bucket_test.go new file mode 100644 index 00000000..528fec24 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/bucket_test.go @@ -0,0 +1,1867 @@ +package bolt_test + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "log" + "math/rand" + "os" + "strconv" + "strings" + "testing" + "testing/quick" + + "github.com/boltdb/bolt" +) + +// Ensure that a bucket that gets a non-existent key returns nil. +func TestBucket_Get_NonExistent(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if v := b.Get([]byte("foo")); v != nil { + t.Fatal("expected nil value") + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that a bucket can read a value that is not flushed yet. +func TestBucket_Get_FromNode(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } + if v := b.Get([]byte("foo")); !bytes.Equal(v, []byte("bar")) { + t.Fatalf("unexpected value: %v", v) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that a bucket retrieved via Get() returns a nil. +func TestBucket_Get_IncompatibleValue(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + _, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + + if _, err := tx.Bucket([]byte("widgets")).CreateBucket([]byte("foo")); err != nil { + t.Fatal(err) + } + + if tx.Bucket([]byte("widgets")).Get([]byte("foo")) != nil { + t.Fatal("expected nil value") + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that a slice returned from a bucket has a capacity equal to its length. +// This also allows slices to be appended to since it will require a realloc by Go. +// +// https://github.com/boltdb/bolt/issues/544 +func TestBucket_Get_Capacity(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + // Write key to a bucket. + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("bucket")) + if err != nil { + return err + } + return b.Put([]byte("key"), []byte("val")) + }); err != nil { + t.Fatal(err) + } + + // Retrieve value and attempt to append to it. + if err := db.Update(func(tx *bolt.Tx) error { + k, v := tx.Bucket([]byte("bucket")).Cursor().First() + + // Verify capacity. + if len(k) != cap(k) { + t.Fatalf("unexpected key slice capacity: %d", cap(k)) + } else if len(v) != cap(v) { + t.Fatalf("unexpected value slice capacity: %d", cap(v)) + } + + // Ensure slice can be appended to without a segfault. + k = append(k, []byte("123")...) + v = append(v, []byte("123")...) + + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that a bucket can write a key/value. +func TestBucket_Put(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } + + v := tx.Bucket([]byte("widgets")).Get([]byte("foo")) + if !bytes.Equal([]byte("bar"), v) { + t.Fatalf("unexpected value: %v", v) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that a bucket can rewrite a key in the same transaction. +func TestBucket_Put_Repeat(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("baz")); err != nil { + t.Fatal(err) + } + + value := tx.Bucket([]byte("widgets")).Get([]byte("foo")) + if !bytes.Equal([]byte("baz"), value) { + t.Fatalf("unexpected value: %v", value) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that a bucket can write a bunch of large values. +func TestBucket_Put_Large(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + count, factor := 100, 200 + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + for i := 1; i < count; i++ { + if err := b.Put([]byte(strings.Repeat("0", i*factor)), []byte(strings.Repeat("X", (count-i)*factor))); err != nil { + t.Fatal(err) + } + } + return nil + }); err != nil { + t.Fatal(err) + } + + if err := db.View(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("widgets")) + for i := 1; i < count; i++ { + value := b.Get([]byte(strings.Repeat("0", i*factor))) + if !bytes.Equal(value, []byte(strings.Repeat("X", (count-i)*factor))) { + t.Fatalf("unexpected value: %v", value) + } + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that a database can perform multiple large appends safely. +func TestDB_Put_VeryLarge(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode.") + } + + n, batchN := 400000, 200000 + ksize, vsize := 8, 500 + + db := MustOpenDB() + defer db.MustClose() + + for i := 0; i < n; i += batchN { + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucketIfNotExists([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + for j := 0; j < batchN; j++ { + k, v := make([]byte, ksize), make([]byte, vsize) + binary.BigEndian.PutUint32(k, uint32(i+j)) + if err := b.Put(k, v); err != nil { + t.Fatal(err) + } + } + return nil + }); err != nil { + t.Fatal(err) + } + } +} + +// Ensure that a setting a value on a key with a bucket value returns an error. +func TestBucket_Put_IncompatibleValue(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + b0, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + + if _, err := tx.Bucket([]byte("widgets")).CreateBucket([]byte("foo")); err != nil { + t.Fatal(err) + } + if err := b0.Put([]byte("foo"), []byte("bar")); err != bolt.ErrIncompatibleValue { + t.Fatalf("unexpected error: %s", err) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that a setting a value while the transaction is closed returns an error. +func TestBucket_Put_Closed(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + tx, err := db.Begin(true) + if err != nil { + t.Fatal(err) + } + + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } + + if err := b.Put([]byte("foo"), []byte("bar")); err != bolt.ErrTxClosed { + t.Fatalf("unexpected error: %s", err) + } +} + +// Ensure that setting a value on a read-only bucket returns an error. +func TestBucket_Put_ReadOnly(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + + if err := db.View(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("widgets")) + if err := b.Put([]byte("foo"), []byte("bar")); err != bolt.ErrTxNotWritable { + t.Fatalf("unexpected error: %s", err) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that a bucket can delete an existing key. +func TestBucket_Delete(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } + if err := b.Delete([]byte("foo")); err != nil { + t.Fatal(err) + } + if v := b.Get([]byte("foo")); v != nil { + t.Fatalf("unexpected value: %v", v) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that deleting a large set of keys will work correctly. +func TestBucket_Delete_Large(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 100; i++ { + if err := b.Put([]byte(strconv.Itoa(i)), []byte(strings.Repeat("*", 1024))); err != nil { + t.Fatal(err) + } + } + + return nil + }); err != nil { + t.Fatal(err) + } + + if err := db.Update(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("widgets")) + for i := 0; i < 100; i++ { + if err := b.Delete([]byte(strconv.Itoa(i))); err != nil { + t.Fatal(err) + } + } + return nil + }); err != nil { + t.Fatal(err) + } + + if err := db.View(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("widgets")) + for i := 0; i < 100; i++ { + if v := b.Get([]byte(strconv.Itoa(i))); v != nil { + t.Fatalf("unexpected value: %v, i=%d", v, i) + } + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Deleting a very large list of keys will cause the freelist to use overflow. +func TestBucket_Delete_FreelistOverflow(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode.") + } + + db := MustOpenDB() + defer db.MustClose() + + k := make([]byte, 16) + for i := uint64(0); i < 10000; i++ { + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucketIfNotExists([]byte("0")) + if err != nil { + t.Fatalf("bucket error: %s", err) + } + + for j := uint64(0); j < 1000; j++ { + binary.BigEndian.PutUint64(k[:8], i) + binary.BigEndian.PutUint64(k[8:], j) + if err := b.Put(k, nil); err != nil { + t.Fatalf("put error: %s", err) + } + } + + return nil + }); err != nil { + t.Fatal(err) + } + } + + // Delete all of them in one large transaction + if err := db.Update(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("0")) + c := b.Cursor() + for k, _ := c.First(); k != nil; k, _ = c.Next() { + if err := c.Delete(); err != nil { + t.Fatal(err) + } + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that accessing and updating nested buckets is ok across transactions. +func TestBucket_Nested(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + // Create a widgets bucket. + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + + // Create a widgets/foo bucket. + _, err = b.CreateBucket([]byte("foo")) + if err != nil { + t.Fatal(err) + } + + // Create a widgets/bar key. + if err := b.Put([]byte("bar"), []byte("0000")); err != nil { + t.Fatal(err) + } + + return nil + }); err != nil { + t.Fatal(err) + } + db.MustCheck() + + // Update widgets/bar. + if err := db.Update(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("widgets")) + if err := b.Put([]byte("bar"), []byte("xxxx")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + db.MustCheck() + + // Cause a split. + if err := db.Update(func(tx *bolt.Tx) error { + var b = tx.Bucket([]byte("widgets")) + for i := 0; i < 10000; i++ { + if err := b.Put([]byte(strconv.Itoa(i)), []byte(strconv.Itoa(i))); err != nil { + t.Fatal(err) + } + } + return nil + }); err != nil { + t.Fatal(err) + } + db.MustCheck() + + // Insert into widgets/foo/baz. + if err := db.Update(func(tx *bolt.Tx) error { + var b = tx.Bucket([]byte("widgets")) + if err := b.Bucket([]byte("foo")).Put([]byte("baz"), []byte("yyyy")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + db.MustCheck() + + // Verify. + if err := db.View(func(tx *bolt.Tx) error { + var b = tx.Bucket([]byte("widgets")) + if v := b.Bucket([]byte("foo")).Get([]byte("baz")); !bytes.Equal(v, []byte("yyyy")) { + t.Fatalf("unexpected value: %v", v) + } + if v := b.Get([]byte("bar")); !bytes.Equal(v, []byte("xxxx")) { + t.Fatalf("unexpected value: %v", v) + } + for i := 0; i < 10000; i++ { + if v := b.Get([]byte(strconv.Itoa(i))); !bytes.Equal(v, []byte(strconv.Itoa(i))) { + t.Fatalf("unexpected value: %v", v) + } + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that deleting a bucket using Delete() returns an error. +func TestBucket_Delete_Bucket(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if _, err := b.CreateBucket([]byte("foo")); err != nil { + t.Fatal(err) + } + if err := b.Delete([]byte("foo")); err != bolt.ErrIncompatibleValue { + t.Fatalf("unexpected error: %s", err) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that deleting a key on a read-only bucket returns an error. +func TestBucket_Delete_ReadOnly(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + + if err := db.View(func(tx *bolt.Tx) error { + if err := tx.Bucket([]byte("widgets")).Delete([]byte("foo")); err != bolt.ErrTxNotWritable { + t.Fatalf("unexpected error: %s", err) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that a deleting value while the transaction is closed returns an error. +func TestBucket_Delete_Closed(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + tx, err := db.Begin(true) + if err != nil { + t.Fatal(err) + } + + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } + if err := b.Delete([]byte("foo")); err != bolt.ErrTxClosed { + t.Fatalf("unexpected error: %s", err) + } +} + +// Ensure that deleting a bucket causes nested buckets to be deleted. +func TestBucket_DeleteBucket_Nested(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + widgets, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + + foo, err := widgets.CreateBucket([]byte("foo")) + if err != nil { + t.Fatal(err) + } + + bar, err := foo.CreateBucket([]byte("bar")) + if err != nil { + t.Fatal(err) + } + if err := bar.Put([]byte("baz"), []byte("bat")); err != nil { + t.Fatal(err) + } + if err := tx.Bucket([]byte("widgets")).DeleteBucket([]byte("foo")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that deleting a bucket causes nested buckets to be deleted after they have been committed. +func TestBucket_DeleteBucket_Nested2(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + widgets, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + + foo, err := widgets.CreateBucket([]byte("foo")) + if err != nil { + t.Fatal(err) + } + + bar, err := foo.CreateBucket([]byte("bar")) + if err != nil { + t.Fatal(err) + } + + if err := bar.Put([]byte("baz"), []byte("bat")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + + if err := db.Update(func(tx *bolt.Tx) error { + widgets := tx.Bucket([]byte("widgets")) + if widgets == nil { + t.Fatal("expected widgets bucket") + } + + foo := widgets.Bucket([]byte("foo")) + if foo == nil { + t.Fatal("expected foo bucket") + } + + bar := foo.Bucket([]byte("bar")) + if bar == nil { + t.Fatal("expected bar bucket") + } + + if v := bar.Get([]byte("baz")); !bytes.Equal(v, []byte("bat")) { + t.Fatalf("unexpected value: %v", v) + } + if err := tx.DeleteBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + + if err := db.View(func(tx *bolt.Tx) error { + if tx.Bucket([]byte("widgets")) != nil { + t.Fatal("expected bucket to be deleted") + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that deleting a child bucket with multiple pages causes all pages to get collected. +// NOTE: Consistency check in bolt_test.DB.Close() will panic if pages not freed properly. +func TestBucket_DeleteBucket_Large(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + widgets, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + + foo, err := widgets.CreateBucket([]byte("foo")) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 1000; i++ { + if err := foo.Put([]byte(fmt.Sprintf("%d", i)), []byte(fmt.Sprintf("%0100d", i))); err != nil { + t.Fatal(err) + } + } + return nil + }); err != nil { + t.Fatal(err) + } + + if err := db.Update(func(tx *bolt.Tx) error { + if err := tx.DeleteBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that a simple value retrieved via Bucket() returns a nil. +func TestBucket_Bucket_IncompatibleValue(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + widgets, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + + if err := widgets.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } + if b := tx.Bucket([]byte("widgets")).Bucket([]byte("foo")); b != nil { + t.Fatal("expected nil bucket") + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that creating a bucket on an existing non-bucket key returns an error. +func TestBucket_CreateBucket_IncompatibleValue(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + widgets, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + + if err := widgets.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } + if _, err := widgets.CreateBucket([]byte("foo")); err != bolt.ErrIncompatibleValue { + t.Fatalf("unexpected error: %s", err) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that deleting a bucket on an existing non-bucket key returns an error. +func TestBucket_DeleteBucket_IncompatibleValue(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + widgets, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := widgets.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } + if err := tx.Bucket([]byte("widgets")).DeleteBucket([]byte("foo")); err != bolt.ErrIncompatibleValue { + t.Fatalf("unexpected error: %s", err) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that a bucket can return an autoincrementing sequence. +func TestBucket_NextSequence(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + widgets, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + woojits, err := tx.CreateBucket([]byte("woojits")) + if err != nil { + t.Fatal(err) + } + + // Make sure sequence increments. + if seq, err := widgets.NextSequence(); err != nil { + t.Fatal(err) + } else if seq != 1 { + t.Fatalf("unexpecte sequence: %d", seq) + } + + if seq, err := widgets.NextSequence(); err != nil { + t.Fatal(err) + } else if seq != 2 { + t.Fatalf("unexpected sequence: %d", seq) + } + + // Buckets should be separate. + if seq, err := woojits.NextSequence(); err != nil { + t.Fatal(err) + } else if seq != 1 { + t.Fatalf("unexpected sequence: %d", 1) + } + + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that a bucket will persist an autoincrementing sequence even if its +// the only thing updated on the bucket. +// https://github.com/boltdb/bolt/issues/296 +func TestBucket_NextSequence_Persist(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.Bucket([]byte("widgets")).NextSequence(); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + + if err := db.Update(func(tx *bolt.Tx) error { + seq, err := tx.Bucket([]byte("widgets")).NextSequence() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } else if seq != 2 { + t.Fatalf("unexpected sequence: %d", seq) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that retrieving the next sequence on a read-only bucket returns an error. +func TestBucket_NextSequence_ReadOnly(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + + if err := db.View(func(tx *bolt.Tx) error { + _, err := tx.Bucket([]byte("widgets")).NextSequence() + if err != bolt.ErrTxNotWritable { + t.Fatalf("unexpected error: %s", err) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that retrieving the next sequence for a bucket on a closed database return an error. +func TestBucket_NextSequence_Closed(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + tx, err := db.Begin(true) + if err != nil { + t.Fatal(err) + } + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } + if _, err := b.NextSequence(); err != bolt.ErrTxClosed { + t.Fatal(err) + } +} + +// Ensure a user can loop over all key/value pairs in a bucket. +func TestBucket_ForEach(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("0000")); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("baz"), []byte("0001")); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("bar"), []byte("0002")); err != nil { + t.Fatal(err) + } + + var index int + if err := b.ForEach(func(k, v []byte) error { + switch index { + case 0: + if !bytes.Equal(k, []byte("bar")) { + t.Fatalf("unexpected key: %v", k) + } else if !bytes.Equal(v, []byte("0002")) { + t.Fatalf("unexpected value: %v", v) + } + case 1: + if !bytes.Equal(k, []byte("baz")) { + t.Fatalf("unexpected key: %v", k) + } else if !bytes.Equal(v, []byte("0001")) { + t.Fatalf("unexpected value: %v", v) + } + case 2: + if !bytes.Equal(k, []byte("foo")) { + t.Fatalf("unexpected key: %v", k) + } else if !bytes.Equal(v, []byte("0000")) { + t.Fatalf("unexpected value: %v", v) + } + } + index++ + return nil + }); err != nil { + t.Fatal(err) + } + + if index != 3 { + t.Fatalf("unexpected index: %d", index) + } + + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure a database can stop iteration early. +func TestBucket_ForEach_ShortCircuit(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("bar"), []byte("0000")); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("baz"), []byte("0000")); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("0000")); err != nil { + t.Fatal(err) + } + + var index int + if err := tx.Bucket([]byte("widgets")).ForEach(func(k, v []byte) error { + index++ + if bytes.Equal(k, []byte("baz")) { + return errors.New("marker") + } + return nil + }); err == nil || err.Error() != "marker" { + t.Fatalf("unexpected error: %s", err) + } + if index != 2 { + t.Fatalf("unexpected index: %d", index) + } + + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that looping over a bucket on a closed database returns an error. +func TestBucket_ForEach_Closed(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + tx, err := db.Begin(true) + if err != nil { + t.Fatal(err) + } + + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } + + if err := b.ForEach(func(k, v []byte) error { return nil }); err != bolt.ErrTxClosed { + t.Fatalf("unexpected error: %s", err) + } +} + +// Ensure that an error is returned when inserting with an empty key. +func TestBucket_Put_EmptyKey(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte(""), []byte("bar")); err != bolt.ErrKeyRequired { + t.Fatalf("unexpected error: %s", err) + } + if err := b.Put(nil, []byte("bar")); err != bolt.ErrKeyRequired { + t.Fatalf("unexpected error: %s", err) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that an error is returned when inserting with a key that's too large. +func TestBucket_Put_KeyTooLarge(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put(make([]byte, 32769), []byte("bar")); err != bolt.ErrKeyTooLarge { + t.Fatalf("unexpected error: %s", err) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that an error is returned when inserting a value that's too large. +func TestBucket_Put_ValueTooLarge(t *testing.T) { + // Skip this test on DroneCI because the machine is resource constrained. + if os.Getenv("DRONE") == "true" { + t.Skip("not enough RAM for test") + } + + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), make([]byte, bolt.MaxValueSize+1)); err != bolt.ErrValueTooLarge { + t.Fatalf("unexpected error: %s", err) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure a bucket can calculate stats. +func TestBucket_Stats(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + // Add bucket with fewer keys but one big value. + bigKey := []byte("really-big-value") + for i := 0; i < 500; i++ { + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucketIfNotExists([]byte("woojits")) + if err != nil { + t.Fatal(err) + } + + if err := b.Put([]byte(fmt.Sprintf("%03d", i)), []byte(strconv.Itoa(i))); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + } + if err := db.Update(func(tx *bolt.Tx) error { + if err := tx.Bucket([]byte("woojits")).Put(bigKey, []byte(strings.Repeat("*", 10000))); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + + db.MustCheck() + + if err := db.View(func(tx *bolt.Tx) error { + stats := tx.Bucket([]byte("woojits")).Stats() + if stats.BranchPageN != 1 { + t.Fatalf("unexpected BranchPageN: %d", stats.BranchPageN) + } else if stats.BranchOverflowN != 0 { + t.Fatalf("unexpected BranchOverflowN: %d", stats.BranchOverflowN) + } else if stats.LeafPageN != 7 { + t.Fatalf("unexpected LeafPageN: %d", stats.LeafPageN) + } else if stats.LeafOverflowN != 2 { + t.Fatalf("unexpected LeafOverflowN: %d", stats.LeafOverflowN) + } else if stats.KeyN != 501 { + t.Fatalf("unexpected KeyN: %d", stats.KeyN) + } else if stats.Depth != 2 { + t.Fatalf("unexpected Depth: %d", stats.Depth) + } + + branchInuse := 16 // branch page header + branchInuse += 7 * 16 // branch elements + branchInuse += 7 * 3 // branch keys (6 3-byte keys) + if stats.BranchInuse != branchInuse { + t.Fatalf("unexpected BranchInuse: %d", stats.BranchInuse) + } + + leafInuse := 7 * 16 // leaf page header + leafInuse += 501 * 16 // leaf elements + leafInuse += 500*3 + len(bigKey) // leaf keys + leafInuse += 1*10 + 2*90 + 3*400 + 10000 // leaf values + if stats.LeafInuse != leafInuse { + t.Fatalf("unexpected LeafInuse: %d", stats.LeafInuse) + } + + // Only check allocations for 4KB pages. + if os.Getpagesize() == 4096 { + if stats.BranchAlloc != 4096 { + t.Fatalf("unexpected BranchAlloc: %d", stats.BranchAlloc) + } else if stats.LeafAlloc != 36864 { + t.Fatalf("unexpected LeafAlloc: %d", stats.LeafAlloc) + } + } + + if stats.BucketN != 1 { + t.Fatalf("unexpected BucketN: %d", stats.BucketN) + } else if stats.InlineBucketN != 0 { + t.Fatalf("unexpected InlineBucketN: %d", stats.InlineBucketN) + } else if stats.InlineBucketInuse != 0 { + t.Fatalf("unexpected InlineBucketInuse: %d", stats.InlineBucketInuse) + } + + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure a bucket with random insertion utilizes fill percentage correctly. +func TestBucket_Stats_RandomFill(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode.") + } else if os.Getpagesize() != 4096 { + t.Skip("invalid page size for test") + } + + db := MustOpenDB() + defer db.MustClose() + + // Add a set of values in random order. It will be the same random + // order so we can maintain consistency between test runs. + var count int + rand := rand.New(rand.NewSource(42)) + for _, i := range rand.Perm(1000) { + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucketIfNotExists([]byte("woojits")) + if err != nil { + t.Fatal(err) + } + b.FillPercent = 0.9 + for _, j := range rand.Perm(100) { + index := (j * 10000) + i + if err := b.Put([]byte(fmt.Sprintf("%d000000000000000", index)), []byte("0000000000")); err != nil { + t.Fatal(err) + } + count++ + } + return nil + }); err != nil { + t.Fatal(err) + } + } + + db.MustCheck() + + if err := db.View(func(tx *bolt.Tx) error { + stats := tx.Bucket([]byte("woojits")).Stats() + if stats.KeyN != 100000 { + t.Fatalf("unexpected KeyN: %d", stats.KeyN) + } + + if stats.BranchPageN != 98 { + t.Fatalf("unexpected BranchPageN: %d", stats.BranchPageN) + } else if stats.BranchOverflowN != 0 { + t.Fatalf("unexpected BranchOverflowN: %d", stats.BranchOverflowN) + } else if stats.BranchInuse != 130984 { + t.Fatalf("unexpected BranchInuse: %d", stats.BranchInuse) + } else if stats.BranchAlloc != 401408 { + t.Fatalf("unexpected BranchAlloc: %d", stats.BranchAlloc) + } + + if stats.LeafPageN != 3412 { + t.Fatalf("unexpected LeafPageN: %d", stats.LeafPageN) + } else if stats.LeafOverflowN != 0 { + t.Fatalf("unexpected LeafOverflowN: %d", stats.LeafOverflowN) + } else if stats.LeafInuse != 4742482 { + t.Fatalf("unexpected LeafInuse: %d", stats.LeafInuse) + } else if stats.LeafAlloc != 13975552 { + t.Fatalf("unexpected LeafAlloc: %d", stats.LeafAlloc) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure a bucket can calculate stats. +func TestBucket_Stats_Small(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + // Add a bucket that fits on a single root leaf. + b, err := tx.CreateBucket([]byte("whozawhats")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } + + return nil + }); err != nil { + t.Fatal(err) + } + + db.MustCheck() + + if err := db.View(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("whozawhats")) + stats := b.Stats() + if stats.BranchPageN != 0 { + t.Fatalf("unexpected BranchPageN: %d", stats.BranchPageN) + } else if stats.BranchOverflowN != 0 { + t.Fatalf("unexpected BranchOverflowN: %d", stats.BranchOverflowN) + } else if stats.LeafPageN != 0 { + t.Fatalf("unexpected LeafPageN: %d", stats.LeafPageN) + } else if stats.LeafOverflowN != 0 { + t.Fatalf("unexpected LeafOverflowN: %d", stats.LeafOverflowN) + } else if stats.KeyN != 1 { + t.Fatalf("unexpected KeyN: %d", stats.KeyN) + } else if stats.Depth != 1 { + t.Fatalf("unexpected Depth: %d", stats.Depth) + } else if stats.BranchInuse != 0 { + t.Fatalf("unexpected BranchInuse: %d", stats.BranchInuse) + } else if stats.LeafInuse != 0 { + t.Fatalf("unexpected LeafInuse: %d", stats.LeafInuse) + } + + if os.Getpagesize() == 4096 { + if stats.BranchAlloc != 0 { + t.Fatalf("unexpected BranchAlloc: %d", stats.BranchAlloc) + } else if stats.LeafAlloc != 0 { + t.Fatalf("unexpected LeafAlloc: %d", stats.LeafAlloc) + } + } + + if stats.BucketN != 1 { + t.Fatalf("unexpected BucketN: %d", stats.BucketN) + } else if stats.InlineBucketN != 1 { + t.Fatalf("unexpected InlineBucketN: %d", stats.InlineBucketN) + } else if stats.InlineBucketInuse != 16+16+6 { + t.Fatalf("unexpected InlineBucketInuse: %d", stats.InlineBucketInuse) + } + + return nil + }); err != nil { + t.Fatal(err) + } +} + +func TestBucket_Stats_EmptyBucket(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + // Add a bucket that fits on a single root leaf. + if _, err := tx.CreateBucket([]byte("whozawhats")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + + db.MustCheck() + + if err := db.View(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("whozawhats")) + stats := b.Stats() + if stats.BranchPageN != 0 { + t.Fatalf("unexpected BranchPageN: %d", stats.BranchPageN) + } else if stats.BranchOverflowN != 0 { + t.Fatalf("unexpected BranchOverflowN: %d", stats.BranchOverflowN) + } else if stats.LeafPageN != 0 { + t.Fatalf("unexpected LeafPageN: %d", stats.LeafPageN) + } else if stats.LeafOverflowN != 0 { + t.Fatalf("unexpected LeafOverflowN: %d", stats.LeafOverflowN) + } else if stats.KeyN != 0 { + t.Fatalf("unexpected KeyN: %d", stats.KeyN) + } else if stats.Depth != 1 { + t.Fatalf("unexpected Depth: %d", stats.Depth) + } else if stats.BranchInuse != 0 { + t.Fatalf("unexpected BranchInuse: %d", stats.BranchInuse) + } else if stats.LeafInuse != 0 { + t.Fatalf("unexpected LeafInuse: %d", stats.LeafInuse) + } + + if os.Getpagesize() == 4096 { + if stats.BranchAlloc != 0 { + t.Fatalf("unexpected BranchAlloc: %d", stats.BranchAlloc) + } else if stats.LeafAlloc != 0 { + t.Fatalf("unexpected LeafAlloc: %d", stats.LeafAlloc) + } + } + + if stats.BucketN != 1 { + t.Fatalf("unexpected BucketN: %d", stats.BucketN) + } else if stats.InlineBucketN != 1 { + t.Fatalf("unexpected InlineBucketN: %d", stats.InlineBucketN) + } else if stats.InlineBucketInuse != 16 { + t.Fatalf("unexpected InlineBucketInuse: %d", stats.InlineBucketInuse) + } + + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure a bucket can calculate stats. +func TestBucket_Stats_Nested(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("foo")) + if err != nil { + t.Fatal(err) + } + for i := 0; i < 100; i++ { + if err := b.Put([]byte(fmt.Sprintf("%02d", i)), []byte(fmt.Sprintf("%02d", i))); err != nil { + t.Fatal(err) + } + } + + bar, err := b.CreateBucket([]byte("bar")) + if err != nil { + t.Fatal(err) + } + for i := 0; i < 10; i++ { + if err := bar.Put([]byte(strconv.Itoa(i)), []byte(strconv.Itoa(i))); err != nil { + t.Fatal(err) + } + } + + baz, err := bar.CreateBucket([]byte("baz")) + if err != nil { + t.Fatal(err) + } + for i := 0; i < 10; i++ { + if err := baz.Put([]byte(strconv.Itoa(i)), []byte(strconv.Itoa(i))); err != nil { + t.Fatal(err) + } + } + + return nil + }); err != nil { + t.Fatal(err) + } + + db.MustCheck() + + if err := db.View(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("foo")) + stats := b.Stats() + if stats.BranchPageN != 0 { + t.Fatalf("unexpected BranchPageN: %d", stats.BranchPageN) + } else if stats.BranchOverflowN != 0 { + t.Fatalf("unexpected BranchOverflowN: %d", stats.BranchOverflowN) + } else if stats.LeafPageN != 2 { + t.Fatalf("unexpected LeafPageN: %d", stats.LeafPageN) + } else if stats.LeafOverflowN != 0 { + t.Fatalf("unexpected LeafOverflowN: %d", stats.LeafOverflowN) + } else if stats.KeyN != 122 { + t.Fatalf("unexpected KeyN: %d", stats.KeyN) + } else if stats.Depth != 3 { + t.Fatalf("unexpected Depth: %d", stats.Depth) + } else if stats.BranchInuse != 0 { + t.Fatalf("unexpected BranchInuse: %d", stats.BranchInuse) + } + + foo := 16 // foo (pghdr) + foo += 101 * 16 // foo leaf elements + foo += 100*2 + 100*2 // foo leaf key/values + foo += 3 + 16 // foo -> bar key/value + + bar := 16 // bar (pghdr) + bar += 11 * 16 // bar leaf elements + bar += 10 + 10 // bar leaf key/values + bar += 3 + 16 // bar -> baz key/value + + baz := 16 // baz (inline) (pghdr) + baz += 10 * 16 // baz leaf elements + baz += 10 + 10 // baz leaf key/values + + if stats.LeafInuse != foo+bar+baz { + t.Fatalf("unexpected LeafInuse: %d", stats.LeafInuse) + } + + if os.Getpagesize() == 4096 { + if stats.BranchAlloc != 0 { + t.Fatalf("unexpected BranchAlloc: %d", stats.BranchAlloc) + } else if stats.LeafAlloc != 8192 { + t.Fatalf("unexpected LeafAlloc: %d", stats.LeafAlloc) + } + } + + if stats.BucketN != 3 { + t.Fatalf("unexpected BucketN: %d", stats.BucketN) + } else if stats.InlineBucketN != 1 { + t.Fatalf("unexpected InlineBucketN: %d", stats.InlineBucketN) + } else if stats.InlineBucketInuse != baz { + t.Fatalf("unexpected InlineBucketInuse: %d", stats.InlineBucketInuse) + } + + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure a large bucket can calculate stats. +func TestBucket_Stats_Large(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode.") + } + + db := MustOpenDB() + defer db.MustClose() + + var index int + for i := 0; i < 100; i++ { + // Add bucket with lots of keys. + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucketIfNotExists([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + for i := 0; i < 1000; i++ { + if err := b.Put([]byte(strconv.Itoa(index)), []byte(strconv.Itoa(index))); err != nil { + t.Fatal(err) + } + index++ + } + return nil + }); err != nil { + t.Fatal(err) + } + } + + db.MustCheck() + + if err := db.View(func(tx *bolt.Tx) error { + stats := tx.Bucket([]byte("widgets")).Stats() + if stats.BranchPageN != 13 { + t.Fatalf("unexpected BranchPageN: %d", stats.BranchPageN) + } else if stats.BranchOverflowN != 0 { + t.Fatalf("unexpected BranchOverflowN: %d", stats.BranchOverflowN) + } else if stats.LeafPageN != 1196 { + t.Fatalf("unexpected LeafPageN: %d", stats.LeafPageN) + } else if stats.LeafOverflowN != 0 { + t.Fatalf("unexpected LeafOverflowN: %d", stats.LeafOverflowN) + } else if stats.KeyN != 100000 { + t.Fatalf("unexpected KeyN: %d", stats.KeyN) + } else if stats.Depth != 3 { + t.Fatalf("unexpected Depth: %d", stats.Depth) + } else if stats.BranchInuse != 25257 { + t.Fatalf("unexpected BranchInuse: %d", stats.BranchInuse) + } else if stats.LeafInuse != 2596916 { + t.Fatalf("unexpected LeafInuse: %d", stats.LeafInuse) + } + + if os.Getpagesize() == 4096 { + if stats.BranchAlloc != 53248 { + t.Fatalf("unexpected BranchAlloc: %d", stats.BranchAlloc) + } else if stats.LeafAlloc != 4898816 { + t.Fatalf("unexpected LeafAlloc: %d", stats.LeafAlloc) + } + } + + if stats.BucketN != 1 { + t.Fatalf("unexpected BucketN: %d", stats.BucketN) + } else if stats.InlineBucketN != 0 { + t.Fatalf("unexpected InlineBucketN: %d", stats.InlineBucketN) + } else if stats.InlineBucketInuse != 0 { + t.Fatalf("unexpected InlineBucketInuse: %d", stats.InlineBucketInuse) + } + + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that a bucket can write random keys and values across multiple transactions. +func TestBucket_Put_Single(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode.") + } + + index := 0 + if err := quick.Check(func(items testdata) bool { + db := MustOpenDB() + defer db.MustClose() + + m := make(map[string][]byte) + + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + + for _, item := range items { + if err := db.Update(func(tx *bolt.Tx) error { + if err := tx.Bucket([]byte("widgets")).Put(item.Key, item.Value); err != nil { + panic("put error: " + err.Error()) + } + m[string(item.Key)] = item.Value + return nil + }); err != nil { + t.Fatal(err) + } + + // Verify all key/values so far. + if err := db.View(func(tx *bolt.Tx) error { + i := 0 + for k, v := range m { + value := tx.Bucket([]byte("widgets")).Get([]byte(k)) + if !bytes.Equal(value, v) { + t.Logf("value mismatch [run %d] (%d of %d):\nkey: %x\ngot: %x\nexp: %x", index, i, len(m), []byte(k), value, v) + db.CopyTempFile() + t.FailNow() + } + i++ + } + return nil + }); err != nil { + t.Fatal(err) + } + } + + index++ + return true + }, nil); err != nil { + t.Error(err) + } +} + +// Ensure that a transaction can insert multiple key/value pairs at once. +func TestBucket_Put_Multiple(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode.") + } + + if err := quick.Check(func(items testdata) bool { + db := MustOpenDB() + defer db.MustClose() + + // Bulk insert all values. + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + + if err := db.Update(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("widgets")) + for _, item := range items { + if err := b.Put(item.Key, item.Value); err != nil { + t.Fatal(err) + } + } + return nil + }); err != nil { + t.Fatal(err) + } + + // Verify all items exist. + if err := db.View(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("widgets")) + for _, item := range items { + value := b.Get(item.Key) + if !bytes.Equal(item.Value, value) { + db.CopyTempFile() + t.Fatalf("exp=%x; got=%x", item.Value, value) + } + } + return nil + }); err != nil { + t.Fatal(err) + } + + return true + }, qconfig()); err != nil { + t.Error(err) + } +} + +// Ensure that a transaction can delete all key/value pairs and return to a single leaf page. +func TestBucket_Delete_Quick(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode.") + } + + if err := quick.Check(func(items testdata) bool { + db := MustOpenDB() + defer db.MustClose() + + // Bulk insert all values. + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + + if err := db.Update(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("widgets")) + for _, item := range items { + if err := b.Put(item.Key, item.Value); err != nil { + t.Fatal(err) + } + } + return nil + }); err != nil { + t.Fatal(err) + } + + // Remove items one at a time and check consistency. + for _, item := range items { + if err := db.Update(func(tx *bolt.Tx) error { + return tx.Bucket([]byte("widgets")).Delete(item.Key) + }); err != nil { + t.Fatal(err) + } + } + + // Anything before our deletion index should be nil. + if err := db.View(func(tx *bolt.Tx) error { + if err := tx.Bucket([]byte("widgets")).ForEach(func(k, v []byte) error { + t.Fatalf("bucket should be empty; found: %06x", trunc(k, 3)) + return nil + }); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + + return true + }, qconfig()); err != nil { + t.Error(err) + } +} + +func ExampleBucket_Put() { + // Open the database. + db, err := bolt.Open(tempfile(), 0666, nil) + if err != nil { + log.Fatal(err) + } + defer os.Remove(db.Path()) + + // Start a write transaction. + if err := db.Update(func(tx *bolt.Tx) error { + // Create a bucket. + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + return err + } + + // Set the value "bar" for the key "foo". + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + return err + } + return nil + }); err != nil { + log.Fatal(err) + } + + // Read value back in a different read-only transaction. + if err := db.View(func(tx *bolt.Tx) error { + value := tx.Bucket([]byte("widgets")).Get([]byte("foo")) + fmt.Printf("The value of 'foo' is: %s\n", value) + return nil + }); err != nil { + log.Fatal(err) + } + + // Close database to release file lock. + if err := db.Close(); err != nil { + log.Fatal(err) + } + + // Output: + // The value of 'foo' is: bar +} + +func ExampleBucket_Delete() { + // Open the database. + db, err := bolt.Open(tempfile(), 0666, nil) + if err != nil { + log.Fatal(err) + } + defer os.Remove(db.Path()) + + // Start a write transaction. + if err := db.Update(func(tx *bolt.Tx) error { + // Create a bucket. + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + return err + } + + // Set the value "bar" for the key "foo". + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + return err + } + + // Retrieve the key back from the database and verify it. + value := b.Get([]byte("foo")) + fmt.Printf("The value of 'foo' was: %s\n", value) + + return nil + }); err != nil { + log.Fatal(err) + } + + // Delete the key in a different write transaction. + if err := db.Update(func(tx *bolt.Tx) error { + return tx.Bucket([]byte("widgets")).Delete([]byte("foo")) + }); err != nil { + log.Fatal(err) + } + + // Retrieve the key again. + if err := db.View(func(tx *bolt.Tx) error { + value := tx.Bucket([]byte("widgets")).Get([]byte("foo")) + if value == nil { + fmt.Printf("The value of 'foo' is now: nil\n") + } + return nil + }); err != nil { + log.Fatal(err) + } + + // Close database to release file lock. + if err := db.Close(); err != nil { + log.Fatal(err) + } + + // Output: + // The value of 'foo' was: bar + // The value of 'foo' is now: nil +} + +func ExampleBucket_ForEach() { + // Open the database. + db, err := bolt.Open(tempfile(), 0666, nil) + if err != nil { + log.Fatal(err) + } + defer os.Remove(db.Path()) + + // Insert data into a bucket. + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("animals")) + if err != nil { + return err + } + + if err := b.Put([]byte("dog"), []byte("fun")); err != nil { + return err + } + if err := b.Put([]byte("cat"), []byte("lame")); err != nil { + return err + } + if err := b.Put([]byte("liger"), []byte("awesome")); err != nil { + return err + } + + // Iterate over items in sorted key order. + if err := b.ForEach(func(k, v []byte) error { + fmt.Printf("A %s is %s.\n", k, v) + return nil + }); err != nil { + return err + } + + return nil + }); err != nil { + log.Fatal(err) + } + + // Close database to release file lock. + if err := db.Close(); err != nil { + log.Fatal(err) + } + + // Output: + // A cat is lame. + // A dog is fun. + // A liger is awesome. +} diff --git a/vendor/github.com/boltdb/bolt/cmd/bolt/main.go b/vendor/github.com/boltdb/bolt/cmd/bolt/main.go new file mode 100644 index 00000000..b96e6f73 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/cmd/bolt/main.go @@ -0,0 +1,1532 @@ +package main + +import ( + "bytes" + "encoding/binary" + "errors" + "flag" + "fmt" + "io" + "io/ioutil" + "math/rand" + "os" + "runtime" + "runtime/pprof" + "strconv" + "strings" + "time" + "unicode" + "unicode/utf8" + "unsafe" + + "github.com/boltdb/bolt" +) + +var ( + // ErrUsage is returned when a usage message was printed and the process + // should simply exit with an error. + ErrUsage = errors.New("usage") + + // ErrUnknownCommand is returned when a CLI command is not specified. + ErrUnknownCommand = errors.New("unknown command") + + // ErrPathRequired is returned when the path to a Bolt database is not specified. + ErrPathRequired = errors.New("path required") + + // ErrFileNotFound is returned when a Bolt database does not exist. + ErrFileNotFound = errors.New("file not found") + + // ErrInvalidValue is returned when a benchmark reads an unexpected value. + ErrInvalidValue = errors.New("invalid value") + + // ErrCorrupt is returned when a checking a data file finds errors. + ErrCorrupt = errors.New("invalid value") + + // ErrNonDivisibleBatchSize is returned when the batch size can't be evenly + // divided by the iteration count. + ErrNonDivisibleBatchSize = errors.New("number of iterations must be divisible by the batch size") + + // ErrPageIDRequired is returned when a required page id is not specified. + ErrPageIDRequired = errors.New("page id required") + + // ErrPageNotFound is returned when specifying a page above the high water mark. + ErrPageNotFound = errors.New("page not found") + + // ErrPageFreed is returned when reading a page that has already been freed. + ErrPageFreed = errors.New("page freed") +) + +// PageHeaderSize represents the size of the bolt.page header. +const PageHeaderSize = 16 + +func main() { + m := NewMain() + if err := m.Run(os.Args[1:]...); err == ErrUsage { + os.Exit(2) + } else if err != nil { + fmt.Println(err.Error()) + os.Exit(1) + } +} + +// Main represents the main program execution. +type Main struct { + Stdin io.Reader + Stdout io.Writer + Stderr io.Writer +} + +// NewMain returns a new instance of Main connect to the standard input/output. +func NewMain() *Main { + return &Main{ + Stdin: os.Stdin, + Stdout: os.Stdout, + Stderr: os.Stderr, + } +} + +// Run executes the program. +func (m *Main) Run(args ...string) error { + // Require a command at the beginning. + if len(args) == 0 || strings.HasPrefix(args[0], "-") { + fmt.Fprintln(m.Stderr, m.Usage()) + return ErrUsage + } + + // Execute command. + switch args[0] { + case "help": + fmt.Fprintln(m.Stderr, m.Usage()) + return ErrUsage + case "bench": + return newBenchCommand(m).Run(args[1:]...) + case "check": + return newCheckCommand(m).Run(args[1:]...) + case "dump": + return newDumpCommand(m).Run(args[1:]...) + case "info": + return newInfoCommand(m).Run(args[1:]...) + case "page": + return newPageCommand(m).Run(args[1:]...) + case "pages": + return newPagesCommand(m).Run(args[1:]...) + case "stats": + return newStatsCommand(m).Run(args[1:]...) + default: + return ErrUnknownCommand + } +} + +// Usage returns the help message. +func (m *Main) Usage() string { + return strings.TrimLeft(` +Bolt is a tool for inspecting bolt databases. + +Usage: + + bolt command [arguments] + +The commands are: + + bench run synthetic benchmark against bolt + check verifies integrity of bolt database + info print basic info + help print this screen + pages print list of pages with their types + stats iterate over all pages and generate usage stats + +Use "bolt [command] -h" for more information about a command. +`, "\n") +} + +// CheckCommand represents the "check" command execution. +type CheckCommand struct { + Stdin io.Reader + Stdout io.Writer + Stderr io.Writer +} + +// NewCheckCommand returns a CheckCommand. +func newCheckCommand(m *Main) *CheckCommand { + return &CheckCommand{ + Stdin: m.Stdin, + Stdout: m.Stdout, + Stderr: m.Stderr, + } +} + +// Run executes the command. +func (cmd *CheckCommand) Run(args ...string) error { + // Parse flags. + fs := flag.NewFlagSet("", flag.ContinueOnError) + help := fs.Bool("h", false, "") + if err := fs.Parse(args); err != nil { + return err + } else if *help { + fmt.Fprintln(cmd.Stderr, cmd.Usage()) + return ErrUsage + } + + // Require database path. + path := fs.Arg(0) + if path == "" { + return ErrPathRequired + } else if _, err := os.Stat(path); os.IsNotExist(err) { + return ErrFileNotFound + } + + // Open database. + db, err := bolt.Open(path, 0666, nil) + if err != nil { + return err + } + defer db.Close() + + // Perform consistency check. + return db.View(func(tx *bolt.Tx) error { + var count int + ch := tx.Check() + loop: + for { + select { + case err, ok := <-ch: + if !ok { + break loop + } + fmt.Fprintln(cmd.Stdout, err) + count++ + } + } + + // Print summary of errors. + if count > 0 { + fmt.Fprintf(cmd.Stdout, "%d errors found\n", count) + return ErrCorrupt + } + + // Notify user that database is valid. + fmt.Fprintln(cmd.Stdout, "OK") + return nil + }) +} + +// Usage returns the help message. +func (cmd *CheckCommand) Usage() string { + return strings.TrimLeft(` +usage: bolt check PATH + +Check opens a database at PATH and runs an exhaustive check to verify that +all pages are accessible or are marked as freed. It also verifies that no +pages are double referenced. + +Verification errors will stream out as they are found and the process will +return after all pages have been checked. +`, "\n") +} + +// InfoCommand represents the "info" command execution. +type InfoCommand struct { + Stdin io.Reader + Stdout io.Writer + Stderr io.Writer +} + +// NewInfoCommand returns a InfoCommand. +func newInfoCommand(m *Main) *InfoCommand { + return &InfoCommand{ + Stdin: m.Stdin, + Stdout: m.Stdout, + Stderr: m.Stderr, + } +} + +// Run executes the command. +func (cmd *InfoCommand) Run(args ...string) error { + // Parse flags. + fs := flag.NewFlagSet("", flag.ContinueOnError) + help := fs.Bool("h", false, "") + if err := fs.Parse(args); err != nil { + return err + } else if *help { + fmt.Fprintln(cmd.Stderr, cmd.Usage()) + return ErrUsage + } + + // Require database path. + path := fs.Arg(0) + if path == "" { + return ErrPathRequired + } else if _, err := os.Stat(path); os.IsNotExist(err) { + return ErrFileNotFound + } + + // Open the database. + db, err := bolt.Open(path, 0666, nil) + if err != nil { + return err + } + defer db.Close() + + // Print basic database info. + info := db.Info() + fmt.Fprintf(cmd.Stdout, "Page Size: %d\n", info.PageSize) + + return nil +} + +// Usage returns the help message. +func (cmd *InfoCommand) Usage() string { + return strings.TrimLeft(` +usage: bolt info PATH + +Info prints basic information about the Bolt database at PATH. +`, "\n") +} + +// DumpCommand represents the "dump" command execution. +type DumpCommand struct { + Stdin io.Reader + Stdout io.Writer + Stderr io.Writer +} + +// newDumpCommand returns a DumpCommand. +func newDumpCommand(m *Main) *DumpCommand { + return &DumpCommand{ + Stdin: m.Stdin, + Stdout: m.Stdout, + Stderr: m.Stderr, + } +} + +// Run executes the command. +func (cmd *DumpCommand) Run(args ...string) error { + // Parse flags. + fs := flag.NewFlagSet("", flag.ContinueOnError) + help := fs.Bool("h", false, "") + if err := fs.Parse(args); err != nil { + return err + } else if *help { + fmt.Fprintln(cmd.Stderr, cmd.Usage()) + return ErrUsage + } + + // Require database path and page id. + path := fs.Arg(0) + if path == "" { + return ErrPathRequired + } else if _, err := os.Stat(path); os.IsNotExist(err) { + return ErrFileNotFound + } + + // Read page ids. + pageIDs, err := atois(fs.Args()[1:]) + if err != nil { + return err + } else if len(pageIDs) == 0 { + return ErrPageIDRequired + } + + // Open database to retrieve page size. + pageSize, err := ReadPageSize(path) + if err != nil { + return err + } + + // Open database file handler. + f, err := os.Open(path) + if err != nil { + return err + } + defer func() { _ = f.Close() }() + + // Print each page listed. + for i, pageID := range pageIDs { + // Print a separator. + if i > 0 { + fmt.Fprintln(cmd.Stdout, "===============================================") + } + + // Print page to stdout. + if err := cmd.PrintPage(cmd.Stdout, f, pageID, pageSize); err != nil { + return err + } + } + + return nil +} + +// PrintPage prints a given page as hexidecimal. +func (cmd *DumpCommand) PrintPage(w io.Writer, r io.ReaderAt, pageID int, pageSize int) error { + const bytesPerLineN = 16 + + // Read page into buffer. + buf := make([]byte, pageSize) + addr := pageID * pageSize + if n, err := r.ReadAt(buf, int64(addr)); err != nil { + return err + } else if n != pageSize { + return io.ErrUnexpectedEOF + } + + // Write out to writer in 16-byte lines. + var prev []byte + var skipped bool + for offset := 0; offset < pageSize; offset += bytesPerLineN { + // Retrieve current 16-byte line. + line := buf[offset : offset+bytesPerLineN] + isLastLine := (offset == (pageSize - bytesPerLineN)) + + // If it's the same as the previous line then print a skip. + if bytes.Equal(line, prev) && !isLastLine { + if !skipped { + fmt.Fprintf(w, "%07x *\n", addr+offset) + skipped = true + } + } else { + // Print line as hexadecimal in 2-byte groups. + fmt.Fprintf(w, "%07x %04x %04x %04x %04x %04x %04x %04x %04x\n", addr+offset, + line[0:2], line[2:4], line[4:6], line[6:8], + line[8:10], line[10:12], line[12:14], line[14:16], + ) + + skipped = false + } + + // Save the previous line. + prev = line + } + fmt.Fprint(w, "\n") + + return nil +} + +// Usage returns the help message. +func (cmd *DumpCommand) Usage() string { + return strings.TrimLeft(` +usage: bolt dump -page PAGEID PATH + +Dump prints a hexidecimal dump of a single page. +`, "\n") +} + +// PageCommand represents the "page" command execution. +type PageCommand struct { + Stdin io.Reader + Stdout io.Writer + Stderr io.Writer +} + +// newPageCommand returns a PageCommand. +func newPageCommand(m *Main) *PageCommand { + return &PageCommand{ + Stdin: m.Stdin, + Stdout: m.Stdout, + Stderr: m.Stderr, + } +} + +// Run executes the command. +func (cmd *PageCommand) Run(args ...string) error { + // Parse flags. + fs := flag.NewFlagSet("", flag.ContinueOnError) + help := fs.Bool("h", false, "") + if err := fs.Parse(args); err != nil { + return err + } else if *help { + fmt.Fprintln(cmd.Stderr, cmd.Usage()) + return ErrUsage + } + + // Require database path and page id. + path := fs.Arg(0) + if path == "" { + return ErrPathRequired + } else if _, err := os.Stat(path); os.IsNotExist(err) { + return ErrFileNotFound + } + + // Read page ids. + pageIDs, err := atois(fs.Args()[1:]) + if err != nil { + return err + } else if len(pageIDs) == 0 { + return ErrPageIDRequired + } + + // Open database file handler. + f, err := os.Open(path) + if err != nil { + return err + } + defer func() { _ = f.Close() }() + + // Print each page listed. + for i, pageID := range pageIDs { + // Print a separator. + if i > 0 { + fmt.Fprintln(cmd.Stdout, "===============================================") + } + + // Retrieve page info and page size. + p, buf, err := ReadPage(path, pageID) + if err != nil { + return err + } + + // Print basic page info. + fmt.Fprintf(cmd.Stdout, "Page ID: %d\n", p.id) + fmt.Fprintf(cmd.Stdout, "Page Type: %s\n", p.Type()) + fmt.Fprintf(cmd.Stdout, "Total Size: %d bytes\n", len(buf)) + + // Print type-specific data. + switch p.Type() { + case "meta": + err = cmd.PrintMeta(cmd.Stdout, buf) + case "leaf": + err = cmd.PrintLeaf(cmd.Stdout, buf) + case "branch": + err = cmd.PrintBranch(cmd.Stdout, buf) + case "freelist": + err = cmd.PrintFreelist(cmd.Stdout, buf) + } + if err != nil { + return err + } + } + + return nil +} + +// PrintMeta prints the data from the meta page. +func (cmd *PageCommand) PrintMeta(w io.Writer, buf []byte) error { + m := (*meta)(unsafe.Pointer(&buf[PageHeaderSize])) + fmt.Fprintf(w, "Version: %d\n", m.version) + fmt.Fprintf(w, "Page Size: %d bytes\n", m.pageSize) + fmt.Fprintf(w, "Flags: %08x\n", m.flags) + fmt.Fprintf(w, "Root: \n", m.root.root) + fmt.Fprintf(w, "Freelist: \n", m.freelist) + fmt.Fprintf(w, "HWM: \n", m.pgid) + fmt.Fprintf(w, "Txn ID: %d\n", m.txid) + fmt.Fprintf(w, "Checksum: %016x\n", m.checksum) + fmt.Fprintf(w, "\n") + return nil +} + +// PrintLeaf prints the data for a leaf page. +func (cmd *PageCommand) PrintLeaf(w io.Writer, buf []byte) error { + p := (*page)(unsafe.Pointer(&buf[0])) + + // Print number of items. + fmt.Fprintf(w, "Item Count: %d\n", p.count) + fmt.Fprintf(w, "\n") + + // Print each key/value. + for i := uint16(0); i < p.count; i++ { + e := p.leafPageElement(i) + + // Format key as string. + var k string + if isPrintable(string(e.key())) { + k = fmt.Sprintf("%q", string(e.key())) + } else { + k = fmt.Sprintf("%x", string(e.key())) + } + + // Format value as string. + var v string + if (e.flags & uint32(bucketLeafFlag)) != 0 { + b := (*bucket)(unsafe.Pointer(&e.value()[0])) + v = fmt.Sprintf("", b.root, b.sequence) + } else if isPrintable(string(e.value())) { + k = fmt.Sprintf("%q", string(e.value())) + } else { + k = fmt.Sprintf("%x", string(e.value())) + } + + fmt.Fprintf(w, "%s: %s\n", k, v) + } + fmt.Fprintf(w, "\n") + return nil +} + +// PrintBranch prints the data for a leaf page. +func (cmd *PageCommand) PrintBranch(w io.Writer, buf []byte) error { + p := (*page)(unsafe.Pointer(&buf[0])) + + // Print number of items. + fmt.Fprintf(w, "Item Count: %d\n", p.count) + fmt.Fprintf(w, "\n") + + // Print each key/value. + for i := uint16(0); i < p.count; i++ { + e := p.branchPageElement(i) + + // Format key as string. + var k string + if isPrintable(string(e.key())) { + k = fmt.Sprintf("%q", string(e.key())) + } else { + k = fmt.Sprintf("%x", string(e.key())) + } + + fmt.Fprintf(w, "%s: \n", k, e.pgid) + } + fmt.Fprintf(w, "\n") + return nil +} + +// PrintFreelist prints the data for a freelist page. +func (cmd *PageCommand) PrintFreelist(w io.Writer, buf []byte) error { + p := (*page)(unsafe.Pointer(&buf[0])) + + // Print number of items. + fmt.Fprintf(w, "Item Count: %d\n", p.count) + fmt.Fprintf(w, "\n") + + // Print each page in the freelist. + ids := (*[maxAllocSize]pgid)(unsafe.Pointer(&p.ptr)) + for i := uint16(0); i < p.count; i++ { + fmt.Fprintf(w, "%d\n", ids[i]) + } + fmt.Fprintf(w, "\n") + return nil +} + +// PrintPage prints a given page as hexidecimal. +func (cmd *PageCommand) PrintPage(w io.Writer, r io.ReaderAt, pageID int, pageSize int) error { + const bytesPerLineN = 16 + + // Read page into buffer. + buf := make([]byte, pageSize) + addr := pageID * pageSize + if n, err := r.ReadAt(buf, int64(addr)); err != nil { + return err + } else if n != pageSize { + return io.ErrUnexpectedEOF + } + + // Write out to writer in 16-byte lines. + var prev []byte + var skipped bool + for offset := 0; offset < pageSize; offset += bytesPerLineN { + // Retrieve current 16-byte line. + line := buf[offset : offset+bytesPerLineN] + isLastLine := (offset == (pageSize - bytesPerLineN)) + + // If it's the same as the previous line then print a skip. + if bytes.Equal(line, prev) && !isLastLine { + if !skipped { + fmt.Fprintf(w, "%07x *\n", addr+offset) + skipped = true + } + } else { + // Print line as hexadecimal in 2-byte groups. + fmt.Fprintf(w, "%07x %04x %04x %04x %04x %04x %04x %04x %04x\n", addr+offset, + line[0:2], line[2:4], line[4:6], line[6:8], + line[8:10], line[10:12], line[12:14], line[14:16], + ) + + skipped = false + } + + // Save the previous line. + prev = line + } + fmt.Fprint(w, "\n") + + return nil +} + +// Usage returns the help message. +func (cmd *PageCommand) Usage() string { + return strings.TrimLeft(` +usage: bolt page -page PATH pageid [pageid...] + +Page prints one or more pages in human readable format. +`, "\n") +} + +// PagesCommand represents the "pages" command execution. +type PagesCommand struct { + Stdin io.Reader + Stdout io.Writer + Stderr io.Writer +} + +// NewPagesCommand returns a PagesCommand. +func newPagesCommand(m *Main) *PagesCommand { + return &PagesCommand{ + Stdin: m.Stdin, + Stdout: m.Stdout, + Stderr: m.Stderr, + } +} + +// Run executes the command. +func (cmd *PagesCommand) Run(args ...string) error { + // Parse flags. + fs := flag.NewFlagSet("", flag.ContinueOnError) + help := fs.Bool("h", false, "") + if err := fs.Parse(args); err != nil { + return err + } else if *help { + fmt.Fprintln(cmd.Stderr, cmd.Usage()) + return ErrUsage + } + + // Require database path. + path := fs.Arg(0) + if path == "" { + return ErrPathRequired + } else if _, err := os.Stat(path); os.IsNotExist(err) { + return ErrFileNotFound + } + + // Open database. + db, err := bolt.Open(path, 0666, nil) + if err != nil { + return err + } + defer func() { _ = db.Close() }() + + // Write header. + fmt.Fprintln(cmd.Stdout, "ID TYPE ITEMS OVRFLW") + fmt.Fprintln(cmd.Stdout, "======== ========== ====== ======") + + return db.Update(func(tx *bolt.Tx) error { + var id int + for { + p, err := tx.Page(id) + if err != nil { + return &PageError{ID: id, Err: err} + } else if p == nil { + break + } + + // Only display count and overflow if this is a non-free page. + var count, overflow string + if p.Type != "free" { + count = strconv.Itoa(p.Count) + if p.OverflowCount > 0 { + overflow = strconv.Itoa(p.OverflowCount) + } + } + + // Print table row. + fmt.Fprintf(cmd.Stdout, "%-8d %-10s %-6s %-6s\n", p.ID, p.Type, count, overflow) + + // Move to the next non-overflow page. + id += 1 + if p.Type != "free" { + id += p.OverflowCount + } + } + return nil + }) +} + +// Usage returns the help message. +func (cmd *PagesCommand) Usage() string { + return strings.TrimLeft(` +usage: bolt pages PATH + +Pages prints a table of pages with their type (meta, leaf, branch, freelist). +Leaf and branch pages will show a key count in the "items" column while the +freelist will show the number of free pages in the "items" column. + +The "overflow" column shows the number of blocks that the page spills over +into. Normally there is no overflow but large keys and values can cause +a single page to take up multiple blocks. +`, "\n") +} + +// StatsCommand represents the "stats" command execution. +type StatsCommand struct { + Stdin io.Reader + Stdout io.Writer + Stderr io.Writer +} + +// NewStatsCommand returns a StatsCommand. +func newStatsCommand(m *Main) *StatsCommand { + return &StatsCommand{ + Stdin: m.Stdin, + Stdout: m.Stdout, + Stderr: m.Stderr, + } +} + +// Run executes the command. +func (cmd *StatsCommand) Run(args ...string) error { + // Parse flags. + fs := flag.NewFlagSet("", flag.ContinueOnError) + help := fs.Bool("h", false, "") + if err := fs.Parse(args); err != nil { + return err + } else if *help { + fmt.Fprintln(cmd.Stderr, cmd.Usage()) + return ErrUsage + } + + // Require database path. + path, prefix := fs.Arg(0), fs.Arg(1) + if path == "" { + return ErrPathRequired + } else if _, err := os.Stat(path); os.IsNotExist(err) { + return ErrFileNotFound + } + + // Open database. + db, err := bolt.Open(path, 0666, nil) + if err != nil { + return err + } + defer db.Close() + + return db.View(func(tx *bolt.Tx) error { + var s bolt.BucketStats + var count int + if err := tx.ForEach(func(name []byte, b *bolt.Bucket) error { + if bytes.HasPrefix(name, []byte(prefix)) { + s.Add(b.Stats()) + count += 1 + } + return nil + }); err != nil { + return err + } + + fmt.Fprintf(cmd.Stdout, "Aggregate statistics for %d buckets\n\n", count) + + fmt.Fprintln(cmd.Stdout, "Page count statistics") + fmt.Fprintf(cmd.Stdout, "\tNumber of logical branch pages: %d\n", s.BranchPageN) + fmt.Fprintf(cmd.Stdout, "\tNumber of physical branch overflow pages: %d\n", s.BranchOverflowN) + fmt.Fprintf(cmd.Stdout, "\tNumber of logical leaf pages: %d\n", s.LeafPageN) + fmt.Fprintf(cmd.Stdout, "\tNumber of physical leaf overflow pages: %d\n", s.LeafOverflowN) + + fmt.Fprintln(cmd.Stdout, "Tree statistics") + fmt.Fprintf(cmd.Stdout, "\tNumber of keys/value pairs: %d\n", s.KeyN) + fmt.Fprintf(cmd.Stdout, "\tNumber of levels in B+tree: %d\n", s.Depth) + + fmt.Fprintln(cmd.Stdout, "Page size utilization") + fmt.Fprintf(cmd.Stdout, "\tBytes allocated for physical branch pages: %d\n", s.BranchAlloc) + var percentage int + if s.BranchAlloc != 0 { + percentage = int(float32(s.BranchInuse) * 100.0 / float32(s.BranchAlloc)) + } + fmt.Fprintf(cmd.Stdout, "\tBytes actually used for branch data: %d (%d%%)\n", s.BranchInuse, percentage) + fmt.Fprintf(cmd.Stdout, "\tBytes allocated for physical leaf pages: %d\n", s.LeafAlloc) + percentage = 0 + if s.LeafAlloc != 0 { + percentage = int(float32(s.LeafInuse) * 100.0 / float32(s.LeafAlloc)) + } + fmt.Fprintf(cmd.Stdout, "\tBytes actually used for leaf data: %d (%d%%)\n", s.LeafInuse, percentage) + + fmt.Fprintln(cmd.Stdout, "Bucket statistics") + fmt.Fprintf(cmd.Stdout, "\tTotal number of buckets: %d\n", s.BucketN) + percentage = 0 + if s.BucketN != 0 { + percentage = int(float32(s.InlineBucketN) * 100.0 / float32(s.BucketN)) + } + fmt.Fprintf(cmd.Stdout, "\tTotal number on inlined buckets: %d (%d%%)\n", s.InlineBucketN, percentage) + percentage = 0 + if s.LeafInuse != 0 { + percentage = int(float32(s.InlineBucketInuse) * 100.0 / float32(s.LeafInuse)) + } + fmt.Fprintf(cmd.Stdout, "\tBytes used for inlined buckets: %d (%d%%)\n", s.InlineBucketInuse, percentage) + + return nil + }) +} + +// Usage returns the help message. +func (cmd *StatsCommand) Usage() string { + return strings.TrimLeft(` +usage: bolt stats PATH + +Stats performs an extensive search of the database to track every page +reference. It starts at the current meta page and recursively iterates +through every accessible bucket. + +The following errors can be reported: + + already freed + The page is referenced more than once in the freelist. + + unreachable unfreed + The page is not referenced by a bucket or in the freelist. + + reachable freed + The page is referenced by a bucket but is also in the freelist. + + out of bounds + A page is referenced that is above the high water mark. + + multiple references + A page is referenced by more than one other page. + + invalid type + The page type is not "meta", "leaf", "branch", or "freelist". + +No errors should occur in your database. However, if for some reason you +experience corruption, please submit a ticket to the Bolt project page: + + https://github.com/boltdb/bolt/issues +`, "\n") +} + +var benchBucketName = []byte("bench") + +// BenchCommand represents the "bench" command execution. +type BenchCommand struct { + Stdin io.Reader + Stdout io.Writer + Stderr io.Writer +} + +// NewBenchCommand returns a BenchCommand using the +func newBenchCommand(m *Main) *BenchCommand { + return &BenchCommand{ + Stdin: m.Stdin, + Stdout: m.Stdout, + Stderr: m.Stderr, + } +} + +// Run executes the "bench" command. +func (cmd *BenchCommand) Run(args ...string) error { + // Parse CLI arguments. + options, err := cmd.ParseFlags(args) + if err != nil { + return err + } + + // Remove path if "-work" is not set. Otherwise keep path. + if options.Work { + fmt.Fprintf(cmd.Stdout, "work: %s\n", options.Path) + } else { + defer os.Remove(options.Path) + } + + // Create database. + db, err := bolt.Open(options.Path, 0666, nil) + if err != nil { + return err + } + db.NoSync = options.NoSync + defer db.Close() + + // Write to the database. + var results BenchResults + if err := cmd.runWrites(db, options, &results); err != nil { + return fmt.Errorf("write: %v", err) + } + + // Read from the database. + if err := cmd.runReads(db, options, &results); err != nil { + return fmt.Errorf("bench: read: %s", err) + } + + // Print results. + fmt.Fprintf(os.Stderr, "# Write\t%v\t(%v/op)\t(%v op/sec)\n", results.WriteDuration, results.WriteOpDuration(), results.WriteOpsPerSecond()) + fmt.Fprintf(os.Stderr, "# Read\t%v\t(%v/op)\t(%v op/sec)\n", results.ReadDuration, results.ReadOpDuration(), results.ReadOpsPerSecond()) + fmt.Fprintln(os.Stderr, "") + return nil +} + +// ParseFlags parses the command line flags. +func (cmd *BenchCommand) ParseFlags(args []string) (*BenchOptions, error) { + var options BenchOptions + + // Parse flagset. + fs := flag.NewFlagSet("", flag.ContinueOnError) + fs.StringVar(&options.ProfileMode, "profile-mode", "rw", "") + fs.StringVar(&options.WriteMode, "write-mode", "seq", "") + fs.StringVar(&options.ReadMode, "read-mode", "seq", "") + fs.IntVar(&options.Iterations, "count", 1000, "") + fs.IntVar(&options.BatchSize, "batch-size", 0, "") + fs.IntVar(&options.KeySize, "key-size", 8, "") + fs.IntVar(&options.ValueSize, "value-size", 32, "") + fs.StringVar(&options.CPUProfile, "cpuprofile", "", "") + fs.StringVar(&options.MemProfile, "memprofile", "", "") + fs.StringVar(&options.BlockProfile, "blockprofile", "", "") + fs.Float64Var(&options.FillPercent, "fill-percent", bolt.DefaultFillPercent, "") + fs.BoolVar(&options.NoSync, "no-sync", false, "") + fs.BoolVar(&options.Work, "work", false, "") + fs.StringVar(&options.Path, "path", "", "") + fs.SetOutput(cmd.Stderr) + if err := fs.Parse(args); err != nil { + return nil, err + } + + // Set batch size to iteration size if not set. + // Require that batch size can be evenly divided by the iteration count. + if options.BatchSize == 0 { + options.BatchSize = options.Iterations + } else if options.Iterations%options.BatchSize != 0 { + return nil, ErrNonDivisibleBatchSize + } + + // Generate temp path if one is not passed in. + if options.Path == "" { + f, err := ioutil.TempFile("", "bolt-bench-") + if err != nil { + return nil, fmt.Errorf("temp file: %s", err) + } + f.Close() + os.Remove(f.Name()) + options.Path = f.Name() + } + + return &options, nil +} + +// Writes to the database. +func (cmd *BenchCommand) runWrites(db *bolt.DB, options *BenchOptions, results *BenchResults) error { + // Start profiling for writes. + if options.ProfileMode == "rw" || options.ProfileMode == "w" { + cmd.startProfiling(options) + } + + t := time.Now() + + var err error + switch options.WriteMode { + case "seq": + err = cmd.runWritesSequential(db, options, results) + case "rnd": + err = cmd.runWritesRandom(db, options, results) + case "seq-nest": + err = cmd.runWritesSequentialNested(db, options, results) + case "rnd-nest": + err = cmd.runWritesRandomNested(db, options, results) + default: + return fmt.Errorf("invalid write mode: %s", options.WriteMode) + } + + // Save time to write. + results.WriteDuration = time.Since(t) + + // Stop profiling for writes only. + if options.ProfileMode == "w" { + cmd.stopProfiling() + } + + return err +} + +func (cmd *BenchCommand) runWritesSequential(db *bolt.DB, options *BenchOptions, results *BenchResults) error { + var i = uint32(0) + return cmd.runWritesWithSource(db, options, results, func() uint32 { i++; return i }) +} + +func (cmd *BenchCommand) runWritesRandom(db *bolt.DB, options *BenchOptions, results *BenchResults) error { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + return cmd.runWritesWithSource(db, options, results, func() uint32 { return r.Uint32() }) +} + +func (cmd *BenchCommand) runWritesSequentialNested(db *bolt.DB, options *BenchOptions, results *BenchResults) error { + var i = uint32(0) + return cmd.runWritesWithSource(db, options, results, func() uint32 { i++; return i }) +} + +func (cmd *BenchCommand) runWritesRandomNested(db *bolt.DB, options *BenchOptions, results *BenchResults) error { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + return cmd.runWritesWithSource(db, options, results, func() uint32 { return r.Uint32() }) +} + +func (cmd *BenchCommand) runWritesWithSource(db *bolt.DB, options *BenchOptions, results *BenchResults, keySource func() uint32) error { + results.WriteOps = options.Iterations + + for i := 0; i < options.Iterations; i += options.BatchSize { + if err := db.Update(func(tx *bolt.Tx) error { + b, _ := tx.CreateBucketIfNotExists(benchBucketName) + b.FillPercent = options.FillPercent + + for j := 0; j < options.BatchSize; j++ { + key := make([]byte, options.KeySize) + value := make([]byte, options.ValueSize) + + // Write key as uint32. + binary.BigEndian.PutUint32(key, keySource()) + + // Insert key/value. + if err := b.Put(key, value); err != nil { + return err + } + } + + return nil + }); err != nil { + return err + } + } + return nil +} + +func (cmd *BenchCommand) runWritesNestedWithSource(db *bolt.DB, options *BenchOptions, results *BenchResults, keySource func() uint32) error { + results.WriteOps = options.Iterations + + for i := 0; i < options.Iterations; i += options.BatchSize { + if err := db.Update(func(tx *bolt.Tx) error { + top, err := tx.CreateBucketIfNotExists(benchBucketName) + if err != nil { + return err + } + top.FillPercent = options.FillPercent + + // Create bucket key. + name := make([]byte, options.KeySize) + binary.BigEndian.PutUint32(name, keySource()) + + // Create bucket. + b, err := top.CreateBucketIfNotExists(name) + if err != nil { + return err + } + b.FillPercent = options.FillPercent + + for j := 0; j < options.BatchSize; j++ { + var key = make([]byte, options.KeySize) + var value = make([]byte, options.ValueSize) + + // Generate key as uint32. + binary.BigEndian.PutUint32(key, keySource()) + + // Insert value into subbucket. + if err := b.Put(key, value); err != nil { + return err + } + } + + return nil + }); err != nil { + return err + } + } + return nil +} + +// Reads from the database. +func (cmd *BenchCommand) runReads(db *bolt.DB, options *BenchOptions, results *BenchResults) error { + // Start profiling for reads. + if options.ProfileMode == "r" { + cmd.startProfiling(options) + } + + t := time.Now() + + var err error + switch options.ReadMode { + case "seq": + switch options.WriteMode { + case "seq-nest", "rnd-nest": + err = cmd.runReadsSequentialNested(db, options, results) + default: + err = cmd.runReadsSequential(db, options, results) + } + default: + return fmt.Errorf("invalid read mode: %s", options.ReadMode) + } + + // Save read time. + results.ReadDuration = time.Since(t) + + // Stop profiling for reads. + if options.ProfileMode == "rw" || options.ProfileMode == "r" { + cmd.stopProfiling() + } + + return err +} + +func (cmd *BenchCommand) runReadsSequential(db *bolt.DB, options *BenchOptions, results *BenchResults) error { + return db.View(func(tx *bolt.Tx) error { + t := time.Now() + + for { + var count int + + c := tx.Bucket(benchBucketName).Cursor() + for k, v := c.First(); k != nil; k, v = c.Next() { + if v == nil { + return errors.New("invalid value") + } + count++ + } + + if options.WriteMode == "seq" && count != options.Iterations { + return fmt.Errorf("read seq: iter mismatch: expected %d, got %d", options.Iterations, count) + } + + results.ReadOps += count + + // Make sure we do this for at least a second. + if time.Since(t) >= time.Second { + break + } + } + + return nil + }) +} + +func (cmd *BenchCommand) runReadsSequentialNested(db *bolt.DB, options *BenchOptions, results *BenchResults) error { + return db.View(func(tx *bolt.Tx) error { + t := time.Now() + + for { + var count int + var top = tx.Bucket(benchBucketName) + if err := top.ForEach(func(name, _ []byte) error { + c := top.Bucket(name).Cursor() + for k, v := c.First(); k != nil; k, v = c.Next() { + if v == nil { + return ErrInvalidValue + } + count++ + } + return nil + }); err != nil { + return err + } + + if options.WriteMode == "seq-nest" && count != options.Iterations { + return fmt.Errorf("read seq-nest: iter mismatch: expected %d, got %d", options.Iterations, count) + } + + results.ReadOps += count + + // Make sure we do this for at least a second. + if time.Since(t) >= time.Second { + break + } + } + + return nil + }) +} + +// File handlers for the various profiles. +var cpuprofile, memprofile, blockprofile *os.File + +// Starts all profiles set on the options. +func (cmd *BenchCommand) startProfiling(options *BenchOptions) { + var err error + + // Start CPU profiling. + if options.CPUProfile != "" { + cpuprofile, err = os.Create(options.CPUProfile) + if err != nil { + fmt.Fprintf(cmd.Stderr, "bench: could not create cpu profile %q: %v\n", options.CPUProfile, err) + os.Exit(1) + } + pprof.StartCPUProfile(cpuprofile) + } + + // Start memory profiling. + if options.MemProfile != "" { + memprofile, err = os.Create(options.MemProfile) + if err != nil { + fmt.Fprintf(cmd.Stderr, "bench: could not create memory profile %q: %v\n", options.MemProfile, err) + os.Exit(1) + } + runtime.MemProfileRate = 4096 + } + + // Start fatal profiling. + if options.BlockProfile != "" { + blockprofile, err = os.Create(options.BlockProfile) + if err != nil { + fmt.Fprintf(cmd.Stderr, "bench: could not create block profile %q: %v\n", options.BlockProfile, err) + os.Exit(1) + } + runtime.SetBlockProfileRate(1) + } +} + +// Stops all profiles. +func (cmd *BenchCommand) stopProfiling() { + if cpuprofile != nil { + pprof.StopCPUProfile() + cpuprofile.Close() + cpuprofile = nil + } + + if memprofile != nil { + pprof.Lookup("heap").WriteTo(memprofile, 0) + memprofile.Close() + memprofile = nil + } + + if blockprofile != nil { + pprof.Lookup("block").WriteTo(blockprofile, 0) + blockprofile.Close() + blockprofile = nil + runtime.SetBlockProfileRate(0) + } +} + +// BenchOptions represents the set of options that can be passed to "bolt bench". +type BenchOptions struct { + ProfileMode string + WriteMode string + ReadMode string + Iterations int + BatchSize int + KeySize int + ValueSize int + CPUProfile string + MemProfile string + BlockProfile string + StatsInterval time.Duration + FillPercent float64 + NoSync bool + Work bool + Path string +} + +// BenchResults represents the performance results of the benchmark. +type BenchResults struct { + WriteOps int + WriteDuration time.Duration + ReadOps int + ReadDuration time.Duration +} + +// Returns the duration for a single write operation. +func (r *BenchResults) WriteOpDuration() time.Duration { + if r.WriteOps == 0 { + return 0 + } + return r.WriteDuration / time.Duration(r.WriteOps) +} + +// Returns average number of write operations that can be performed per second. +func (r *BenchResults) WriteOpsPerSecond() int { + var op = r.WriteOpDuration() + if op == 0 { + return 0 + } + return int(time.Second) / int(op) +} + +// Returns the duration for a single read operation. +func (r *BenchResults) ReadOpDuration() time.Duration { + if r.ReadOps == 0 { + return 0 + } + return r.ReadDuration / time.Duration(r.ReadOps) +} + +// Returns average number of read operations that can be performed per second. +func (r *BenchResults) ReadOpsPerSecond() int { + var op = r.ReadOpDuration() + if op == 0 { + return 0 + } + return int(time.Second) / int(op) +} + +type PageError struct { + ID int + Err error +} + +func (e *PageError) Error() string { + return fmt.Sprintf("page error: id=%d, err=%s", e.ID, e.Err) +} + +// isPrintable returns true if the string is valid unicode and contains only printable runes. +func isPrintable(s string) bool { + if !utf8.ValidString(s) { + return false + } + for _, ch := range s { + if !unicode.IsPrint(ch) { + return false + } + } + return true +} + +// ReadPage reads page info & full page data from a path. +// This is not transactionally safe. +func ReadPage(path string, pageID int) (*page, []byte, error) { + // Find page size. + pageSize, err := ReadPageSize(path) + if err != nil { + return nil, nil, fmt.Errorf("read page size: %s", err) + } + + // Open database file. + f, err := os.Open(path) + if err != nil { + return nil, nil, err + } + defer f.Close() + + // Read one block into buffer. + buf := make([]byte, pageSize) + if n, err := f.ReadAt(buf, int64(pageID*pageSize)); err != nil { + return nil, nil, err + } else if n != len(buf) { + return nil, nil, io.ErrUnexpectedEOF + } + + // Determine total number of blocks. + p := (*page)(unsafe.Pointer(&buf[0])) + overflowN := p.overflow + + // Re-read entire page (with overflow) into buffer. + buf = make([]byte, (int(overflowN)+1)*pageSize) + if n, err := f.ReadAt(buf, int64(pageID*pageSize)); err != nil { + return nil, nil, err + } else if n != len(buf) { + return nil, nil, io.ErrUnexpectedEOF + } + p = (*page)(unsafe.Pointer(&buf[0])) + + return p, buf, nil +} + +// ReadPageSize reads page size a path. +// This is not transactionally safe. +func ReadPageSize(path string) (int, error) { + // Open database file. + f, err := os.Open(path) + if err != nil { + return 0, err + } + defer f.Close() + + // Read 4KB chunk. + buf := make([]byte, 4096) + if _, err := io.ReadFull(f, buf); err != nil { + return 0, err + } + + // Read page size from metadata. + m := (*meta)(unsafe.Pointer(&buf[PageHeaderSize])) + return int(m.pageSize), nil +} + +// atois parses a slice of strings into integers. +func atois(strs []string) ([]int, error) { + var a []int + for _, str := range strs { + i, err := strconv.Atoi(str) + if err != nil { + return nil, err + } + a = append(a, i) + } + return a, nil +} + +// DO NOT EDIT. Copied from the "bolt" package. +const maxAllocSize = 0xFFFFFFF + +// DO NOT EDIT. Copied from the "bolt" package. +const ( + branchPageFlag = 0x01 + leafPageFlag = 0x02 + metaPageFlag = 0x04 + freelistPageFlag = 0x10 +) + +// DO NOT EDIT. Copied from the "bolt" package. +const bucketLeafFlag = 0x01 + +// DO NOT EDIT. Copied from the "bolt" package. +type pgid uint64 + +// DO NOT EDIT. Copied from the "bolt" package. +type txid uint64 + +// DO NOT EDIT. Copied from the "bolt" package. +type meta struct { + magic uint32 + version uint32 + pageSize uint32 + flags uint32 + root bucket + freelist pgid + pgid pgid + txid txid + checksum uint64 +} + +// DO NOT EDIT. Copied from the "bolt" package. +type bucket struct { + root pgid + sequence uint64 +} + +// DO NOT EDIT. Copied from the "bolt" package. +type page struct { + id pgid + flags uint16 + count uint16 + overflow uint32 + ptr uintptr +} + +// DO NOT EDIT. Copied from the "bolt" package. +func (p *page) Type() string { + if (p.flags & branchPageFlag) != 0 { + return "branch" + } else if (p.flags & leafPageFlag) != 0 { + return "leaf" + } else if (p.flags & metaPageFlag) != 0 { + return "meta" + } else if (p.flags & freelistPageFlag) != 0 { + return "freelist" + } + return fmt.Sprintf("unknown<%02x>", p.flags) +} + +// DO NOT EDIT. Copied from the "bolt" package. +func (p *page) leafPageElement(index uint16) *leafPageElement { + n := &((*[0x7FFFFFF]leafPageElement)(unsafe.Pointer(&p.ptr)))[index] + return n +} + +// DO NOT EDIT. Copied from the "bolt" package. +func (p *page) branchPageElement(index uint16) *branchPageElement { + return &((*[0x7FFFFFF]branchPageElement)(unsafe.Pointer(&p.ptr)))[index] +} + +// DO NOT EDIT. Copied from the "bolt" package. +type branchPageElement struct { + pos uint32 + ksize uint32 + pgid pgid +} + +// DO NOT EDIT. Copied from the "bolt" package. +func (n *branchPageElement) key() []byte { + buf := (*[maxAllocSize]byte)(unsafe.Pointer(n)) + return buf[n.pos : n.pos+n.ksize] +} + +// DO NOT EDIT. Copied from the "bolt" package. +type leafPageElement struct { + flags uint32 + pos uint32 + ksize uint32 + vsize uint32 +} + +// DO NOT EDIT. Copied from the "bolt" package. +func (n *leafPageElement) key() []byte { + buf := (*[maxAllocSize]byte)(unsafe.Pointer(n)) + return buf[n.pos : n.pos+n.ksize] +} + +// DO NOT EDIT. Copied from the "bolt" package. +func (n *leafPageElement) value() []byte { + buf := (*[maxAllocSize]byte)(unsafe.Pointer(n)) + return buf[n.pos+n.ksize : n.pos+n.ksize+n.vsize] +} diff --git a/vendor/github.com/boltdb/bolt/cmd/bolt/main_test.go b/vendor/github.com/boltdb/bolt/cmd/bolt/main_test.go new file mode 100644 index 00000000..c378b790 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/cmd/bolt/main_test.go @@ -0,0 +1,185 @@ +package main_test + +import ( + "bytes" + "io/ioutil" + "os" + "strconv" + "testing" + + "github.com/boltdb/bolt" + "github.com/boltdb/bolt/cmd/bolt" +) + +// Ensure the "info" command can print information about a database. +func TestInfoCommand_Run(t *testing.T) { + db := MustOpen(0666, nil) + db.DB.Close() + defer db.Close() + + // Run the info command. + m := NewMain() + if err := m.Run("info", db.Path); err != nil { + t.Fatal(err) + } +} + +// Ensure the "stats" command executes correctly with an empty database. +func TestStatsCommand_Run_EmptyDatabase(t *testing.T) { + // Ignore + if os.Getpagesize() != 4096 { + t.Skip("system does not use 4KB page size") + } + + db := MustOpen(0666, nil) + defer db.Close() + db.DB.Close() + + // Generate expected result. + exp := "Aggregate statistics for 0 buckets\n\n" + + "Page count statistics\n" + + "\tNumber of logical branch pages: 0\n" + + "\tNumber of physical branch overflow pages: 0\n" + + "\tNumber of logical leaf pages: 0\n" + + "\tNumber of physical leaf overflow pages: 0\n" + + "Tree statistics\n" + + "\tNumber of keys/value pairs: 0\n" + + "\tNumber of levels in B+tree: 0\n" + + "Page size utilization\n" + + "\tBytes allocated for physical branch pages: 0\n" + + "\tBytes actually used for branch data: 0 (0%)\n" + + "\tBytes allocated for physical leaf pages: 0\n" + + "\tBytes actually used for leaf data: 0 (0%)\n" + + "Bucket statistics\n" + + "\tTotal number of buckets: 0\n" + + "\tTotal number on inlined buckets: 0 (0%)\n" + + "\tBytes used for inlined buckets: 0 (0%)\n" + + // Run the command. + m := NewMain() + if err := m.Run("stats", db.Path); err != nil { + t.Fatal(err) + } else if m.Stdout.String() != exp { + t.Fatalf("unexpected stdout:\n\n%s", m.Stdout.String()) + } +} + +// Ensure the "stats" command can execute correctly. +func TestStatsCommand_Run(t *testing.T) { + // Ignore + if os.Getpagesize() != 4096 { + t.Skip("system does not use 4KB page size") + } + + db := MustOpen(0666, nil) + defer db.Close() + + if err := db.Update(func(tx *bolt.Tx) error { + // Create "foo" bucket. + b, err := tx.CreateBucket([]byte("foo")) + if err != nil { + return err + } + for i := 0; i < 10; i++ { + if err := b.Put([]byte(strconv.Itoa(i)), []byte(strconv.Itoa(i))); err != nil { + return err + } + } + + // Create "bar" bucket. + b, err = tx.CreateBucket([]byte("bar")) + if err != nil { + return err + } + for i := 0; i < 100; i++ { + if err := b.Put([]byte(strconv.Itoa(i)), []byte(strconv.Itoa(i))); err != nil { + return err + } + } + + // Create "baz" bucket. + b, err = tx.CreateBucket([]byte("baz")) + if err != nil { + return err + } + if err := b.Put([]byte("key"), []byte("value")); err != nil { + return err + } + + return nil + }); err != nil { + t.Fatal(err) + } + db.DB.Close() + + // Generate expected result. + exp := "Aggregate statistics for 3 buckets\n\n" + + "Page count statistics\n" + + "\tNumber of logical branch pages: 0\n" + + "\tNumber of physical branch overflow pages: 0\n" + + "\tNumber of logical leaf pages: 1\n" + + "\tNumber of physical leaf overflow pages: 0\n" + + "Tree statistics\n" + + "\tNumber of keys/value pairs: 111\n" + + "\tNumber of levels in B+tree: 1\n" + + "Page size utilization\n" + + "\tBytes allocated for physical branch pages: 0\n" + + "\tBytes actually used for branch data: 0 (0%)\n" + + "\tBytes allocated for physical leaf pages: 4096\n" + + "\tBytes actually used for leaf data: 1996 (48%)\n" + + "Bucket statistics\n" + + "\tTotal number of buckets: 3\n" + + "\tTotal number on inlined buckets: 2 (66%)\n" + + "\tBytes used for inlined buckets: 236 (11%)\n" + + // Run the command. + m := NewMain() + if err := m.Run("stats", db.Path); err != nil { + t.Fatal(err) + } else if m.Stdout.String() != exp { + t.Fatalf("unexpected stdout:\n\n%s", m.Stdout.String()) + } +} + +// Main represents a test wrapper for main.Main that records output. +type Main struct { + *main.Main + Stdin bytes.Buffer + Stdout bytes.Buffer + Stderr bytes.Buffer +} + +// NewMain returns a new instance of Main. +func NewMain() *Main { + m := &Main{Main: main.NewMain()} + m.Main.Stdin = &m.Stdin + m.Main.Stdout = &m.Stdout + m.Main.Stderr = &m.Stderr + return m +} + +// MustOpen creates a Bolt database in a temporary location. +func MustOpen(mode os.FileMode, options *bolt.Options) *DB { + // Create temporary path. + f, _ := ioutil.TempFile("", "bolt-") + f.Close() + os.Remove(f.Name()) + + db, err := bolt.Open(f.Name(), mode, options) + if err != nil { + panic(err.Error()) + } + return &DB{DB: db, Path: f.Name()} +} + +// DB is a test wrapper for bolt.DB. +type DB struct { + *bolt.DB + Path string +} + +// Close closes and removes the database. +func (db *DB) Close() error { + defer os.Remove(db.Path) + return db.DB.Close() +} diff --git a/vendor/github.com/boltdb/bolt/cursor.go b/vendor/github.com/boltdb/bolt/cursor.go new file mode 100644 index 00000000..1be9f35e --- /dev/null +++ b/vendor/github.com/boltdb/bolt/cursor.go @@ -0,0 +1,400 @@ +package bolt + +import ( + "bytes" + "fmt" + "sort" +) + +// Cursor represents an iterator that can traverse over all key/value pairs in a bucket in sorted order. +// Cursors see nested buckets with value == nil. +// Cursors can be obtained from a transaction and are valid as long as the transaction is open. +// +// Keys and values returned from the cursor are only valid for the life of the transaction. +// +// Changing data while traversing with a cursor may cause it to be invalidated +// and return unexpected keys and/or values. You must reposition your cursor +// after mutating data. +type Cursor struct { + bucket *Bucket + stack []elemRef +} + +// Bucket returns the bucket that this cursor was created from. +func (c *Cursor) Bucket() *Bucket { + return c.bucket +} + +// First moves the cursor to the first item in the bucket and returns its key and value. +// If the bucket is empty then a nil key and value are returned. +// The returned key and value are only valid for the life of the transaction. +func (c *Cursor) First() (key []byte, value []byte) { + _assert(c.bucket.tx.db != nil, "tx closed") + c.stack = c.stack[:0] + p, n := c.bucket.pageNode(c.bucket.root) + c.stack = append(c.stack, elemRef{page: p, node: n, index: 0}) + c.first() + + // If we land on an empty page then move to the next value. + // https://github.com/boltdb/bolt/issues/450 + if c.stack[len(c.stack)-1].count() == 0 { + c.next() + } + + k, v, flags := c.keyValue() + if (flags & uint32(bucketLeafFlag)) != 0 { + return k, nil + } + return k, v + +} + +// Last moves the cursor to the last item in the bucket and returns its key and value. +// If the bucket is empty then a nil key and value are returned. +// The returned key and value are only valid for the life of the transaction. +func (c *Cursor) Last() (key []byte, value []byte) { + _assert(c.bucket.tx.db != nil, "tx closed") + c.stack = c.stack[:0] + p, n := c.bucket.pageNode(c.bucket.root) + ref := elemRef{page: p, node: n} + ref.index = ref.count() - 1 + c.stack = append(c.stack, ref) + c.last() + k, v, flags := c.keyValue() + if (flags & uint32(bucketLeafFlag)) != 0 { + return k, nil + } + return k, v +} + +// Next moves the cursor to the next item in the bucket and returns its key and value. +// If the cursor is at the end of the bucket then a nil key and value are returned. +// The returned key and value are only valid for the life of the transaction. +func (c *Cursor) Next() (key []byte, value []byte) { + _assert(c.bucket.tx.db != nil, "tx closed") + k, v, flags := c.next() + if (flags & uint32(bucketLeafFlag)) != 0 { + return k, nil + } + return k, v +} + +// Prev moves the cursor to the previous item in the bucket and returns its key and value. +// If the cursor is at the beginning of the bucket then a nil key and value are returned. +// The returned key and value are only valid for the life of the transaction. +func (c *Cursor) Prev() (key []byte, value []byte) { + _assert(c.bucket.tx.db != nil, "tx closed") + + // Attempt to move back one element until we're successful. + // Move up the stack as we hit the beginning of each page in our stack. + for i := len(c.stack) - 1; i >= 0; i-- { + elem := &c.stack[i] + if elem.index > 0 { + elem.index-- + break + } + c.stack = c.stack[:i] + } + + // If we've hit the end then return nil. + if len(c.stack) == 0 { + return nil, nil + } + + // Move down the stack to find the last element of the last leaf under this branch. + c.last() + k, v, flags := c.keyValue() + if (flags & uint32(bucketLeafFlag)) != 0 { + return k, nil + } + return k, v +} + +// Seek moves the cursor to a given key and returns it. +// If the key does not exist then the next key is used. If no keys +// follow, a nil key is returned. +// The returned key and value are only valid for the life of the transaction. +func (c *Cursor) Seek(seek []byte) (key []byte, value []byte) { + k, v, flags := c.seek(seek) + + // If we ended up after the last element of a page then move to the next one. + if ref := &c.stack[len(c.stack)-1]; ref.index >= ref.count() { + k, v, flags = c.next() + } + + if k == nil { + return nil, nil + } else if (flags & uint32(bucketLeafFlag)) != 0 { + return k, nil + } + return k, v +} + +// Delete removes the current key/value under the cursor from the bucket. +// Delete fails if current key/value is a bucket or if the transaction is not writable. +func (c *Cursor) Delete() error { + if c.bucket.tx.db == nil { + return ErrTxClosed + } else if !c.bucket.Writable() { + return ErrTxNotWritable + } + + key, _, flags := c.keyValue() + // Return an error if current value is a bucket. + if (flags & bucketLeafFlag) != 0 { + return ErrIncompatibleValue + } + c.node().del(key) + + return nil +} + +// seek moves the cursor to a given key and returns it. +// If the key does not exist then the next key is used. +func (c *Cursor) seek(seek []byte) (key []byte, value []byte, flags uint32) { + _assert(c.bucket.tx.db != nil, "tx closed") + + // Start from root page/node and traverse to correct page. + c.stack = c.stack[:0] + c.search(seek, c.bucket.root) + ref := &c.stack[len(c.stack)-1] + + // If the cursor is pointing to the end of page/node then return nil. + if ref.index >= ref.count() { + return nil, nil, 0 + } + + // If this is a bucket then return a nil value. + return c.keyValue() +} + +// first moves the cursor to the first leaf element under the last page in the stack. +func (c *Cursor) first() { + for { + // Exit when we hit a leaf page. + var ref = &c.stack[len(c.stack)-1] + if ref.isLeaf() { + break + } + + // Keep adding pages pointing to the first element to the stack. + var pgid pgid + if ref.node != nil { + pgid = ref.node.inodes[ref.index].pgid + } else { + pgid = ref.page.branchPageElement(uint16(ref.index)).pgid + } + p, n := c.bucket.pageNode(pgid) + c.stack = append(c.stack, elemRef{page: p, node: n, index: 0}) + } +} + +// last moves the cursor to the last leaf element under the last page in the stack. +func (c *Cursor) last() { + for { + // Exit when we hit a leaf page. + ref := &c.stack[len(c.stack)-1] + if ref.isLeaf() { + break + } + + // Keep adding pages pointing to the last element in the stack. + var pgid pgid + if ref.node != nil { + pgid = ref.node.inodes[ref.index].pgid + } else { + pgid = ref.page.branchPageElement(uint16(ref.index)).pgid + } + p, n := c.bucket.pageNode(pgid) + + var nextRef = elemRef{page: p, node: n} + nextRef.index = nextRef.count() - 1 + c.stack = append(c.stack, nextRef) + } +} + +// next moves to the next leaf element and returns the key and value. +// If the cursor is at the last leaf element then it stays there and returns nil. +func (c *Cursor) next() (key []byte, value []byte, flags uint32) { + for { + // Attempt to move over one element until we're successful. + // Move up the stack as we hit the end of each page in our stack. + var i int + for i = len(c.stack) - 1; i >= 0; i-- { + elem := &c.stack[i] + if elem.index < elem.count()-1 { + elem.index++ + break + } + } + + // If we've hit the root page then stop and return. This will leave the + // cursor on the last element of the last page. + if i == -1 { + return nil, nil, 0 + } + + // Otherwise start from where we left off in the stack and find the + // first element of the first leaf page. + c.stack = c.stack[:i+1] + c.first() + + // If this is an empty page then restart and move back up the stack. + // https://github.com/boltdb/bolt/issues/450 + if c.stack[len(c.stack)-1].count() == 0 { + continue + } + + return c.keyValue() + } +} + +// search recursively performs a binary search against a given page/node until it finds a given key. +func (c *Cursor) search(key []byte, pgid pgid) { + p, n := c.bucket.pageNode(pgid) + if p != nil && (p.flags&(branchPageFlag|leafPageFlag)) == 0 { + panic(fmt.Sprintf("invalid page type: %d: %x", p.id, p.flags)) + } + e := elemRef{page: p, node: n} + c.stack = append(c.stack, e) + + // If we're on a leaf page/node then find the specific node. + if e.isLeaf() { + c.nsearch(key) + return + } + + if n != nil { + c.searchNode(key, n) + return + } + c.searchPage(key, p) +} + +func (c *Cursor) searchNode(key []byte, n *node) { + var exact bool + index := sort.Search(len(n.inodes), func(i int) bool { + // TODO(benbjohnson): Optimize this range search. It's a bit hacky right now. + // sort.Search() finds the lowest index where f() != -1 but we need the highest index. + ret := bytes.Compare(n.inodes[i].key, key) + if ret == 0 { + exact = true + } + return ret != -1 + }) + if !exact && index > 0 { + index-- + } + c.stack[len(c.stack)-1].index = index + + // Recursively search to the next page. + c.search(key, n.inodes[index].pgid) +} + +func (c *Cursor) searchPage(key []byte, p *page) { + // Binary search for the correct range. + inodes := p.branchPageElements() + + var exact bool + index := sort.Search(int(p.count), func(i int) bool { + // TODO(benbjohnson): Optimize this range search. It's a bit hacky right now. + // sort.Search() finds the lowest index where f() != -1 but we need the highest index. + ret := bytes.Compare(inodes[i].key(), key) + if ret == 0 { + exact = true + } + return ret != -1 + }) + if !exact && index > 0 { + index-- + } + c.stack[len(c.stack)-1].index = index + + // Recursively search to the next page. + c.search(key, inodes[index].pgid) +} + +// nsearch searches the leaf node on the top of the stack for a key. +func (c *Cursor) nsearch(key []byte) { + e := &c.stack[len(c.stack)-1] + p, n := e.page, e.node + + // If we have a node then search its inodes. + if n != nil { + index := sort.Search(len(n.inodes), func(i int) bool { + return bytes.Compare(n.inodes[i].key, key) != -1 + }) + e.index = index + return + } + + // If we have a page then search its leaf elements. + inodes := p.leafPageElements() + index := sort.Search(int(p.count), func(i int) bool { + return bytes.Compare(inodes[i].key(), key) != -1 + }) + e.index = index +} + +// keyValue returns the key and value of the current leaf element. +func (c *Cursor) keyValue() ([]byte, []byte, uint32) { + ref := &c.stack[len(c.stack)-1] + if ref.count() == 0 || ref.index >= ref.count() { + return nil, nil, 0 + } + + // Retrieve value from node. + if ref.node != nil { + inode := &ref.node.inodes[ref.index] + return inode.key, inode.value, inode.flags + } + + // Or retrieve value from page. + elem := ref.page.leafPageElement(uint16(ref.index)) + return elem.key(), elem.value(), elem.flags +} + +// node returns the node that the cursor is currently positioned on. +func (c *Cursor) node() *node { + _assert(len(c.stack) > 0, "accessing a node with a zero-length cursor stack") + + // If the top of the stack is a leaf node then just return it. + if ref := &c.stack[len(c.stack)-1]; ref.node != nil && ref.isLeaf() { + return ref.node + } + + // Start from root and traverse down the hierarchy. + var n = c.stack[0].node + if n == nil { + n = c.bucket.node(c.stack[0].page.id, nil) + } + for _, ref := range c.stack[:len(c.stack)-1] { + _assert(!n.isLeaf, "expected branch node") + n = n.childAt(int(ref.index)) + } + _assert(n.isLeaf, "expected leaf node") + return n +} + +// elemRef represents a reference to an element on a given page/node. +type elemRef struct { + page *page + node *node + index int +} + +// isLeaf returns whether the ref is pointing at a leaf page/node. +func (r *elemRef) isLeaf() bool { + if r.node != nil { + return r.node.isLeaf + } + return (r.page.flags & leafPageFlag) != 0 +} + +// count returns the number of inodes or page elements. +func (r *elemRef) count() int { + if r.node != nil { + return len(r.node.inodes) + } + return int(r.page.count) +} diff --git a/vendor/github.com/boltdb/bolt/cursor_test.go b/vendor/github.com/boltdb/bolt/cursor_test.go new file mode 100644 index 00000000..562d60f9 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/cursor_test.go @@ -0,0 +1,817 @@ +package bolt_test + +import ( + "bytes" + "encoding/binary" + "fmt" + "log" + "os" + "reflect" + "sort" + "testing" + "testing/quick" + + "github.com/boltdb/bolt" +) + +// Ensure that a cursor can return a reference to the bucket that created it. +func TestCursor_Bucket(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if cb := b.Cursor().Bucket(); !reflect.DeepEqual(cb, b) { + t.Fatal("cursor bucket mismatch") + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that a Tx cursor can seek to the appropriate keys. +func TestCursor_Seek(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("0001")); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("bar"), []byte("0002")); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("baz"), []byte("0003")); err != nil { + t.Fatal(err) + } + + if _, err := b.CreateBucket([]byte("bkt")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + + if err := db.View(func(tx *bolt.Tx) error { + c := tx.Bucket([]byte("widgets")).Cursor() + + // Exact match should go to the key. + if k, v := c.Seek([]byte("bar")); !bytes.Equal(k, []byte("bar")) { + t.Fatalf("unexpected key: %v", k) + } else if !bytes.Equal(v, []byte("0002")) { + t.Fatalf("unexpected value: %v", v) + } + + // Inexact match should go to the next key. + if k, v := c.Seek([]byte("bas")); !bytes.Equal(k, []byte("baz")) { + t.Fatalf("unexpected key: %v", k) + } else if !bytes.Equal(v, []byte("0003")) { + t.Fatalf("unexpected value: %v", v) + } + + // Low key should go to the first key. + if k, v := c.Seek([]byte("")); !bytes.Equal(k, []byte("bar")) { + t.Fatalf("unexpected key: %v", k) + } else if !bytes.Equal(v, []byte("0002")) { + t.Fatalf("unexpected value: %v", v) + } + + // High key should return no key. + if k, v := c.Seek([]byte("zzz")); k != nil { + t.Fatalf("expected nil key: %v", k) + } else if v != nil { + t.Fatalf("expected nil value: %v", v) + } + + // Buckets should return their key but no value. + if k, v := c.Seek([]byte("bkt")); !bytes.Equal(k, []byte("bkt")) { + t.Fatalf("unexpected key: %v", k) + } else if v != nil { + t.Fatalf("expected nil value: %v", v) + } + + return nil + }); err != nil { + t.Fatal(err) + } +} + +func TestCursor_Delete(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + const count = 1000 + + // Insert every other key between 0 and $count. + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + for i := 0; i < count; i += 1 { + k := make([]byte, 8) + binary.BigEndian.PutUint64(k, uint64(i)) + if err := b.Put(k, make([]byte, 100)); err != nil { + t.Fatal(err) + } + } + if _, err := b.CreateBucket([]byte("sub")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + + if err := db.Update(func(tx *bolt.Tx) error { + c := tx.Bucket([]byte("widgets")).Cursor() + bound := make([]byte, 8) + binary.BigEndian.PutUint64(bound, uint64(count/2)) + for key, _ := c.First(); bytes.Compare(key, bound) < 0; key, _ = c.Next() { + if err := c.Delete(); err != nil { + t.Fatal(err) + } + } + + c.Seek([]byte("sub")) + if err := c.Delete(); err != bolt.ErrIncompatibleValue { + t.Fatalf("unexpected error: %s", err) + } + + return nil + }); err != nil { + t.Fatal(err) + } + + if err := db.View(func(tx *bolt.Tx) error { + stats := tx.Bucket([]byte("widgets")).Stats() + if stats.KeyN != count/2+1 { + t.Fatalf("unexpected KeyN: %d", stats.KeyN) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that a Tx cursor can seek to the appropriate keys when there are a +// large number of keys. This test also checks that seek will always move +// forward to the next key. +// +// Related: https://github.com/boltdb/bolt/pull/187 +func TestCursor_Seek_Large(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + var count = 10000 + + // Insert every other key between 0 and $count. + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < count; i += 100 { + for j := i; j < i+100; j += 2 { + k := make([]byte, 8) + binary.BigEndian.PutUint64(k, uint64(j)) + if err := b.Put(k, make([]byte, 100)); err != nil { + t.Fatal(err) + } + } + } + return nil + }); err != nil { + t.Fatal(err) + } + + if err := db.View(func(tx *bolt.Tx) error { + c := tx.Bucket([]byte("widgets")).Cursor() + for i := 0; i < count; i++ { + seek := make([]byte, 8) + binary.BigEndian.PutUint64(seek, uint64(i)) + + k, _ := c.Seek(seek) + + // The last seek is beyond the end of the the range so + // it should return nil. + if i == count-1 { + if k != nil { + t.Fatal("expected nil key") + } + continue + } + + // Otherwise we should seek to the exact key or the next key. + num := binary.BigEndian.Uint64(k) + if i%2 == 0 { + if num != uint64(i) { + t.Fatalf("unexpected num: %d", num) + } + } else { + if num != uint64(i+1) { + t.Fatalf("unexpected num: %d", num) + } + } + } + + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that a cursor can iterate over an empty bucket without error. +func TestCursor_EmptyBucket(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + _, err := tx.CreateBucket([]byte("widgets")) + return err + }); err != nil { + t.Fatal(err) + } + + if err := db.View(func(tx *bolt.Tx) error { + c := tx.Bucket([]byte("widgets")).Cursor() + k, v := c.First() + if k != nil { + t.Fatalf("unexpected key: %v", k) + } else if v != nil { + t.Fatalf("unexpected value: %v", v) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that a Tx cursor can reverse iterate over an empty bucket without error. +func TestCursor_EmptyBucketReverse(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + _, err := tx.CreateBucket([]byte("widgets")) + return err + }); err != nil { + t.Fatal(err) + } + if err := db.View(func(tx *bolt.Tx) error { + c := tx.Bucket([]byte("widgets")).Cursor() + k, v := c.Last() + if k != nil { + t.Fatalf("unexpected key: %v", k) + } else if v != nil { + t.Fatalf("unexpected value: %v", v) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that a Tx cursor can iterate over a single root with a couple elements. +func TestCursor_Iterate_Leaf(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("baz"), []byte{}); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte{0}); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("bar"), []byte{1}); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + tx, err := db.Begin(false) + if err != nil { + t.Fatal(err) + } + defer func() { _ = tx.Rollback() }() + + c := tx.Bucket([]byte("widgets")).Cursor() + + k, v := c.First() + if !bytes.Equal(k, []byte("bar")) { + t.Fatalf("unexpected key: %v", k) + } else if !bytes.Equal(v, []byte{1}) { + t.Fatalf("unexpected value: %v", v) + } + + k, v = c.Next() + if !bytes.Equal(k, []byte("baz")) { + t.Fatalf("unexpected key: %v", k) + } else if !bytes.Equal(v, []byte{}) { + t.Fatalf("unexpected value: %v", v) + } + + k, v = c.Next() + if !bytes.Equal(k, []byte("foo")) { + t.Fatalf("unexpected key: %v", k) + } else if !bytes.Equal(v, []byte{0}) { + t.Fatalf("unexpected value: %v", v) + } + + k, v = c.Next() + if k != nil { + t.Fatalf("expected nil key: %v", k) + } else if v != nil { + t.Fatalf("expected nil value: %v", v) + } + + k, v = c.Next() + if k != nil { + t.Fatalf("expected nil key: %v", k) + } else if v != nil { + t.Fatalf("expected nil value: %v", v) + } + + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } +} + +// Ensure that a Tx cursor can iterate in reverse over a single root with a couple elements. +func TestCursor_LeafRootReverse(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("baz"), []byte{}); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte{0}); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("bar"), []byte{1}); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + tx, err := db.Begin(false) + if err != nil { + t.Fatal(err) + } + c := tx.Bucket([]byte("widgets")).Cursor() + + if k, v := c.Last(); !bytes.Equal(k, []byte("foo")) { + t.Fatalf("unexpected key: %v", k) + } else if !bytes.Equal(v, []byte{0}) { + t.Fatalf("unexpected value: %v", v) + } + + if k, v := c.Prev(); !bytes.Equal(k, []byte("baz")) { + t.Fatalf("unexpected key: %v", k) + } else if !bytes.Equal(v, []byte{}) { + t.Fatalf("unexpected value: %v", v) + } + + if k, v := c.Prev(); !bytes.Equal(k, []byte("bar")) { + t.Fatalf("unexpected key: %v", k) + } else if !bytes.Equal(v, []byte{1}) { + t.Fatalf("unexpected value: %v", v) + } + + if k, v := c.Prev(); k != nil { + t.Fatalf("expected nil key: %v", k) + } else if v != nil { + t.Fatalf("expected nil value: %v", v) + } + + if k, v := c.Prev(); k != nil { + t.Fatalf("expected nil key: %v", k) + } else if v != nil { + t.Fatalf("expected nil value: %v", v) + } + + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } +} + +// Ensure that a Tx cursor can restart from the beginning. +func TestCursor_Restart(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("bar"), []byte{}); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte{}); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + + tx, err := db.Begin(false) + if err != nil { + t.Fatal(err) + } + c := tx.Bucket([]byte("widgets")).Cursor() + + if k, _ := c.First(); !bytes.Equal(k, []byte("bar")) { + t.Fatalf("unexpected key: %v", k) + } + if k, _ := c.Next(); !bytes.Equal(k, []byte("foo")) { + t.Fatalf("unexpected key: %v", k) + } + + if k, _ := c.First(); !bytes.Equal(k, []byte("bar")) { + t.Fatalf("unexpected key: %v", k) + } + if k, _ := c.Next(); !bytes.Equal(k, []byte("foo")) { + t.Fatalf("unexpected key: %v", k) + } + + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } +} + +// Ensure that a cursor can skip over empty pages that have been deleted. +func TestCursor_First_EmptyPages(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + // Create 1000 keys in the "widgets" bucket. + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 1000; i++ { + if err := b.Put(u64tob(uint64(i)), []byte{}); err != nil { + t.Fatal(err) + } + } + + return nil + }); err != nil { + t.Fatal(err) + } + + // Delete half the keys and then try to iterate. + if err := db.Update(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("widgets")) + for i := 0; i < 600; i++ { + if err := b.Delete(u64tob(uint64(i))); err != nil { + t.Fatal(err) + } + } + + c := b.Cursor() + var n int + for k, _ := c.First(); k != nil; k, _ = c.Next() { + n++ + } + if n != 400 { + t.Fatalf("unexpected key count: %d", n) + } + + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that a Tx can iterate over all elements in a bucket. +func TestCursor_QuickCheck(t *testing.T) { + f := func(items testdata) bool { + db := MustOpenDB() + defer db.MustClose() + + // Bulk insert all values. + tx, err := db.Begin(true) + if err != nil { + t.Fatal(err) + } + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + for _, item := range items { + if err := b.Put(item.Key, item.Value); err != nil { + t.Fatal(err) + } + } + if err := tx.Commit(); err != nil { + t.Fatal(err) + } + + // Sort test data. + sort.Sort(items) + + // Iterate over all items and check consistency. + var index = 0 + tx, err = db.Begin(false) + if err != nil { + t.Fatal(err) + } + + c := tx.Bucket([]byte("widgets")).Cursor() + for k, v := c.First(); k != nil && index < len(items); k, v = c.Next() { + if !bytes.Equal(k, items[index].Key) { + t.Fatalf("unexpected key: %v", k) + } else if !bytes.Equal(v, items[index].Value) { + t.Fatalf("unexpected value: %v", v) + } + index++ + } + if len(items) != index { + t.Fatalf("unexpected item count: %v, expected %v", len(items), index) + } + + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } + + return true + } + if err := quick.Check(f, qconfig()); err != nil { + t.Error(err) + } +} + +// Ensure that a transaction can iterate over all elements in a bucket in reverse. +func TestCursor_QuickCheck_Reverse(t *testing.T) { + f := func(items testdata) bool { + db := MustOpenDB() + defer db.MustClose() + + // Bulk insert all values. + tx, err := db.Begin(true) + if err != nil { + t.Fatal(err) + } + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + for _, item := range items { + if err := b.Put(item.Key, item.Value); err != nil { + t.Fatal(err) + } + } + if err := tx.Commit(); err != nil { + t.Fatal(err) + } + + // Sort test data. + sort.Sort(revtestdata(items)) + + // Iterate over all items and check consistency. + var index = 0 + tx, err = db.Begin(false) + if err != nil { + t.Fatal(err) + } + c := tx.Bucket([]byte("widgets")).Cursor() + for k, v := c.Last(); k != nil && index < len(items); k, v = c.Prev() { + if !bytes.Equal(k, items[index].Key) { + t.Fatalf("unexpected key: %v", k) + } else if !bytes.Equal(v, items[index].Value) { + t.Fatalf("unexpected value: %v", v) + } + index++ + } + if len(items) != index { + t.Fatalf("unexpected item count: %v, expected %v", len(items), index) + } + + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } + + return true + } + if err := quick.Check(f, qconfig()); err != nil { + t.Error(err) + } +} + +// Ensure that a Tx cursor can iterate over subbuckets. +func TestCursor_QuickCheck_BucketsOnly(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if _, err := b.CreateBucket([]byte("foo")); err != nil { + t.Fatal(err) + } + if _, err := b.CreateBucket([]byte("bar")); err != nil { + t.Fatal(err) + } + if _, err := b.CreateBucket([]byte("baz")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + + if err := db.View(func(tx *bolt.Tx) error { + var names []string + c := tx.Bucket([]byte("widgets")).Cursor() + for k, v := c.First(); k != nil; k, v = c.Next() { + names = append(names, string(k)) + if v != nil { + t.Fatalf("unexpected value: %v", v) + } + } + if !reflect.DeepEqual(names, []string{"bar", "baz", "foo"}) { + t.Fatalf("unexpected names: %+v", names) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that a Tx cursor can reverse iterate over subbuckets. +func TestCursor_QuickCheck_BucketsOnly_Reverse(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if _, err := b.CreateBucket([]byte("foo")); err != nil { + t.Fatal(err) + } + if _, err := b.CreateBucket([]byte("bar")); err != nil { + t.Fatal(err) + } + if _, err := b.CreateBucket([]byte("baz")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + + if err := db.View(func(tx *bolt.Tx) error { + var names []string + c := tx.Bucket([]byte("widgets")).Cursor() + for k, v := c.Last(); k != nil; k, v = c.Prev() { + names = append(names, string(k)) + if v != nil { + t.Fatalf("unexpected value: %v", v) + } + } + if !reflect.DeepEqual(names, []string{"foo", "baz", "bar"}) { + t.Fatalf("unexpected names: %+v", names) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +func ExampleCursor() { + // Open the database. + db, err := bolt.Open(tempfile(), 0666, nil) + if err != nil { + log.Fatal(err) + } + defer os.Remove(db.Path()) + + // Start a read-write transaction. + if err := db.Update(func(tx *bolt.Tx) error { + // Create a new bucket. + b, err := tx.CreateBucket([]byte("animals")) + if err != nil { + return err + } + + // Insert data into a bucket. + if err := b.Put([]byte("dog"), []byte("fun")); err != nil { + log.Fatal(err) + } + if err := b.Put([]byte("cat"), []byte("lame")); err != nil { + log.Fatal(err) + } + if err := b.Put([]byte("liger"), []byte("awesome")); err != nil { + log.Fatal(err) + } + + // Create a cursor for iteration. + c := b.Cursor() + + // Iterate over items in sorted key order. This starts from the + // first key/value pair and updates the k/v variables to the + // next key/value on each iteration. + // + // The loop finishes at the end of the cursor when a nil key is returned. + for k, v := c.First(); k != nil; k, v = c.Next() { + fmt.Printf("A %s is %s.\n", k, v) + } + + return nil + }); err != nil { + log.Fatal(err) + } + + if err := db.Close(); err != nil { + log.Fatal(err) + } + + // Output: + // A cat is lame. + // A dog is fun. + // A liger is awesome. +} + +func ExampleCursor_reverse() { + // Open the database. + db, err := bolt.Open(tempfile(), 0666, nil) + if err != nil { + log.Fatal(err) + } + defer os.Remove(db.Path()) + + // Start a read-write transaction. + if err := db.Update(func(tx *bolt.Tx) error { + // Create a new bucket. + b, err := tx.CreateBucket([]byte("animals")) + if err != nil { + return err + } + + // Insert data into a bucket. + if err := b.Put([]byte("dog"), []byte("fun")); err != nil { + log.Fatal(err) + } + if err := b.Put([]byte("cat"), []byte("lame")); err != nil { + log.Fatal(err) + } + if err := b.Put([]byte("liger"), []byte("awesome")); err != nil { + log.Fatal(err) + } + + // Create a cursor for iteration. + c := b.Cursor() + + // Iterate over items in reverse sorted key order. This starts + // from the last key/value pair and updates the k/v variables to + // the previous key/value on each iteration. + // + // The loop finishes at the beginning of the cursor when a nil key + // is returned. + for k, v := c.Last(); k != nil; k, v = c.Prev() { + fmt.Printf("A %s is %s.\n", k, v) + } + + return nil + }); err != nil { + log.Fatal(err) + } + + // Close the database to release the file lock. + if err := db.Close(); err != nil { + log.Fatal(err) + } + + // Output: + // A liger is awesome. + // A dog is fun. + // A cat is lame. +} diff --git a/vendor/github.com/boltdb/bolt/db.go b/vendor/github.com/boltdb/bolt/db.go new file mode 100644 index 00000000..1223493c --- /dev/null +++ b/vendor/github.com/boltdb/bolt/db.go @@ -0,0 +1,1036 @@ +package bolt + +import ( + "errors" + "fmt" + "hash/fnv" + "log" + "os" + "runtime" + "runtime/debug" + "strings" + "sync" + "time" + "unsafe" +) + +// The largest step that can be taken when remapping the mmap. +const maxMmapStep = 1 << 30 // 1GB + +// The data file format version. +const version = 2 + +// Represents a marker value to indicate that a file is a Bolt DB. +const magic uint32 = 0xED0CDAED + +// IgnoreNoSync specifies whether the NoSync field of a DB is ignored when +// syncing changes to a file. This is required as some operating systems, +// such as OpenBSD, do not have a unified buffer cache (UBC) and writes +// must be synchronized using the msync(2) syscall. +const IgnoreNoSync = runtime.GOOS == "openbsd" + +// Default values if not set in a DB instance. +const ( + DefaultMaxBatchSize int = 1000 + DefaultMaxBatchDelay = 10 * time.Millisecond + DefaultAllocSize = 16 * 1024 * 1024 +) + +// default page size for db is set to the OS page size. +var defaultPageSize = os.Getpagesize() + +// DB represents a collection of buckets persisted to a file on disk. +// All data access is performed through transactions which can be obtained through the DB. +// All the functions on DB will return a ErrDatabaseNotOpen if accessed before Open() is called. +type DB struct { + // When enabled, the database will perform a Check() after every commit. + // A panic is issued if the database is in an inconsistent state. This + // flag has a large performance impact so it should only be used for + // debugging purposes. + StrictMode bool + + // Setting the NoSync flag will cause the database to skip fsync() + // calls after each commit. This can be useful when bulk loading data + // into a database and you can restart the bulk load in the event of + // a system failure or database corruption. Do not set this flag for + // normal use. + // + // If the package global IgnoreNoSync constant is true, this value is + // ignored. See the comment on that constant for more details. + // + // THIS IS UNSAFE. PLEASE USE WITH CAUTION. + NoSync bool + + // When true, skips the truncate call when growing the database. + // Setting this to true is only safe on non-ext3/ext4 systems. + // Skipping truncation avoids preallocation of hard drive space and + // bypasses a truncate() and fsync() syscall on remapping. + // + // https://github.com/boltdb/bolt/issues/284 + NoGrowSync bool + + // If you want to read the entire database fast, you can set MmapFlag to + // syscall.MAP_POPULATE on Linux 2.6.23+ for sequential read-ahead. + MmapFlags int + + // MaxBatchSize is the maximum size of a batch. Default value is + // copied from DefaultMaxBatchSize in Open. + // + // If <=0, disables batching. + // + // Do not change concurrently with calls to Batch. + MaxBatchSize int + + // MaxBatchDelay is the maximum delay before a batch starts. + // Default value is copied from DefaultMaxBatchDelay in Open. + // + // If <=0, effectively disables batching. + // + // Do not change concurrently with calls to Batch. + MaxBatchDelay time.Duration + + // AllocSize is the amount of space allocated when the database + // needs to create new pages. This is done to amortize the cost + // of truncate() and fsync() when growing the data file. + AllocSize int + + path string + file *os.File + lockfile *os.File // windows only + dataref []byte // mmap'ed readonly, write throws SEGV + data *[maxMapSize]byte + datasz int + filesz int // current on disk file size + meta0 *meta + meta1 *meta + pageSize int + opened bool + rwtx *Tx + txs []*Tx + freelist *freelist + stats Stats + + pagePool sync.Pool + + batchMu sync.Mutex + batch *batch + + rwlock sync.Mutex // Allows only one writer at a time. + metalock sync.Mutex // Protects meta page access. + mmaplock sync.RWMutex // Protects mmap access during remapping. + statlock sync.RWMutex // Protects stats access. + + ops struct { + writeAt func(b []byte, off int64) (n int, err error) + } + + // Read only mode. + // When true, Update() and Begin(true) return ErrDatabaseReadOnly immediately. + readOnly bool +} + +// Path returns the path to currently open database file. +func (db *DB) Path() string { + return db.path +} + +// GoString returns the Go string representation of the database. +func (db *DB) GoString() string { + return fmt.Sprintf("bolt.DB{path:%q}", db.path) +} + +// String returns the string representation of the database. +func (db *DB) String() string { + return fmt.Sprintf("DB<%q>", db.path) +} + +// Open creates and opens a database at the given path. +// If the file does not exist then it will be created automatically. +// Passing in nil options will cause Bolt to open the database with the default options. +func Open(path string, mode os.FileMode, options *Options) (*DB, error) { + var db = &DB{opened: true} + + // Set default options if no options are provided. + if options == nil { + options = DefaultOptions + } + db.NoGrowSync = options.NoGrowSync + db.MmapFlags = options.MmapFlags + + // Set default values for later DB operations. + db.MaxBatchSize = DefaultMaxBatchSize + db.MaxBatchDelay = DefaultMaxBatchDelay + db.AllocSize = DefaultAllocSize + + flag := os.O_RDWR + if options.ReadOnly { + flag = os.O_RDONLY + db.readOnly = true + } + + // Open data file and separate sync handler for metadata writes. + db.path = path + var err error + if db.file, err = os.OpenFile(db.path, flag|os.O_CREATE, mode); err != nil { + _ = db.close() + return nil, err + } + + // Lock file so that other processes using Bolt in read-write mode cannot + // use the database at the same time. This would cause corruption since + // the two processes would write meta pages and free pages separately. + // The database file is locked exclusively (only one process can grab the lock) + // if !options.ReadOnly. + // The database file is locked using the shared lock (more than one process may + // hold a lock at the same time) otherwise (options.ReadOnly is set). + if err := flock(db, mode, !db.readOnly, options.Timeout); err != nil { + _ = db.close() + return nil, err + } + + // Default values for test hooks + db.ops.writeAt = db.file.WriteAt + + // Initialize the database if it doesn't exist. + if info, err := db.file.Stat(); err != nil { + return nil, err + } else if info.Size() == 0 { + // Initialize new files with meta pages. + if err := db.init(); err != nil { + return nil, err + } + } else { + // Read the first meta page to determine the page size. + var buf [0x1000]byte + if _, err := db.file.ReadAt(buf[:], 0); err == nil { + m := db.pageInBuffer(buf[:], 0).meta() + if err := m.validate(); err != nil { + // If we can't read the page size, we can assume it's the same + // as the OS -- since that's how the page size was chosen in the + // first place. + // + // If the first page is invalid and this OS uses a different + // page size than what the database was created with then we + // are out of luck and cannot access the database. + db.pageSize = os.Getpagesize() + } else { + db.pageSize = int(m.pageSize) + } + } + } + + // Initialize page pool. + db.pagePool = sync.Pool{ + New: func() interface{} { + return make([]byte, db.pageSize) + }, + } + + // Memory map the data file. + if err := db.mmap(options.InitialMmapSize); err != nil { + _ = db.close() + return nil, err + } + + // Read in the freelist. + db.freelist = newFreelist() + db.freelist.read(db.page(db.meta().freelist)) + + // Mark the database as opened and return. + return db, nil +} + +// mmap opens the underlying memory-mapped file and initializes the meta references. +// minsz is the minimum size that the new mmap can be. +func (db *DB) mmap(minsz int) error { + db.mmaplock.Lock() + defer db.mmaplock.Unlock() + + info, err := db.file.Stat() + if err != nil { + return fmt.Errorf("mmap stat error: %s", err) + } else if int(info.Size()) < db.pageSize*2 { + return fmt.Errorf("file size too small") + } + + // Ensure the size is at least the minimum size. + var size = int(info.Size()) + if size < minsz { + size = minsz + } + size, err = db.mmapSize(size) + if err != nil { + return err + } + + // Dereference all mmap references before unmapping. + if db.rwtx != nil { + db.rwtx.root.dereference() + } + + // Unmap existing data before continuing. + if err := db.munmap(); err != nil { + return err + } + + // Memory-map the data file as a byte slice. + if err := mmap(db, size); err != nil { + return err + } + + // Save references to the meta pages. + db.meta0 = db.page(0).meta() + db.meta1 = db.page(1).meta() + + // Validate the meta pages. We only return an error if both meta pages fail + // validation, since meta0 failing validation means that it wasn't saved + // properly -- but we can recover using meta1. And vice-versa. + err0 := db.meta0.validate() + err1 := db.meta1.validate() + if err0 != nil && err1 != nil { + return err0 + } + + return nil +} + +// munmap unmaps the data file from memory. +func (db *DB) munmap() error { + if err := munmap(db); err != nil { + return fmt.Errorf("unmap error: " + err.Error()) + } + return nil +} + +// mmapSize determines the appropriate size for the mmap given the current size +// of the database. The minimum size is 32KB and doubles until it reaches 1GB. +// Returns an error if the new mmap size is greater than the max allowed. +func (db *DB) mmapSize(size int) (int, error) { + // Double the size from 32KB until 1GB. + for i := uint(15); i <= 30; i++ { + if size <= 1< maxMapSize { + return 0, fmt.Errorf("mmap too large") + } + + // If larger than 1GB then grow by 1GB at a time. + sz := int64(size) + if remainder := sz % int64(maxMmapStep); remainder > 0 { + sz += int64(maxMmapStep) - remainder + } + + // Ensure that the mmap size is a multiple of the page size. + // This should always be true since we're incrementing in MBs. + pageSize := int64(db.pageSize) + if (sz % pageSize) != 0 { + sz = ((sz / pageSize) + 1) * pageSize + } + + // If we've exceeded the max size then only grow up to the max size. + if sz > maxMapSize { + sz = maxMapSize + } + + return int(sz), nil +} + +// init creates a new database file and initializes its meta pages. +func (db *DB) init() error { + // Set the page size to the OS page size. + db.pageSize = os.Getpagesize() + + // Create two meta pages on a buffer. + buf := make([]byte, db.pageSize*4) + for i := 0; i < 2; i++ { + p := db.pageInBuffer(buf[:], pgid(i)) + p.id = pgid(i) + p.flags = metaPageFlag + + // Initialize the meta page. + m := p.meta() + m.magic = magic + m.version = version + m.pageSize = uint32(db.pageSize) + m.freelist = 2 + m.root = bucket{root: 3} + m.pgid = 4 + m.txid = txid(i) + m.checksum = m.sum64() + } + + // Write an empty freelist at page 3. + p := db.pageInBuffer(buf[:], pgid(2)) + p.id = pgid(2) + p.flags = freelistPageFlag + p.count = 0 + + // Write an empty leaf page at page 4. + p = db.pageInBuffer(buf[:], pgid(3)) + p.id = pgid(3) + p.flags = leafPageFlag + p.count = 0 + + // Write the buffer to our data file. + if _, err := db.ops.writeAt(buf, 0); err != nil { + return err + } + if err := fdatasync(db); err != nil { + return err + } + + return nil +} + +// Close releases all database resources. +// All transactions must be closed before closing the database. +func (db *DB) Close() error { + db.rwlock.Lock() + defer db.rwlock.Unlock() + + db.metalock.Lock() + defer db.metalock.Unlock() + + db.mmaplock.RLock() + defer db.mmaplock.RUnlock() + + return db.close() +} + +func (db *DB) close() error { + if !db.opened { + return nil + } + + db.opened = false + + db.freelist = nil + + // Clear ops. + db.ops.writeAt = nil + + // Close the mmap. + if err := db.munmap(); err != nil { + return err + } + + // Close file handles. + if db.file != nil { + // No need to unlock read-only file. + if !db.readOnly { + // Unlock the file. + if err := funlock(db); err != nil { + log.Printf("bolt.Close(): funlock error: %s", err) + } + } + + // Close the file descriptor. + if err := db.file.Close(); err != nil { + return fmt.Errorf("db file close: %s", err) + } + db.file = nil + } + + db.path = "" + return nil +} + +// Begin starts a new transaction. +// Multiple read-only transactions can be used concurrently but only one +// write transaction can be used at a time. Starting multiple write transactions +// will cause the calls to block and be serialized until the current write +// transaction finishes. +// +// Transactions should not be dependent on one another. Opening a read +// transaction and a write transaction in the same goroutine can cause the +// writer to deadlock because the database periodically needs to re-mmap itself +// as it grows and it cannot do that while a read transaction is open. +// +// If a long running read transaction (for example, a snapshot transaction) is +// needed, you might want to set DB.InitialMmapSize to a large enough value +// to avoid potential blocking of write transaction. +// +// IMPORTANT: You must close read-only transactions after you are finished or +// else the database will not reclaim old pages. +func (db *DB) Begin(writable bool) (*Tx, error) { + if writable { + return db.beginRWTx() + } + return db.beginTx() +} + +func (db *DB) beginTx() (*Tx, error) { + // Lock the meta pages while we initialize the transaction. We obtain + // the meta lock before the mmap lock because that's the order that the + // write transaction will obtain them. + db.metalock.Lock() + + // Obtain a read-only lock on the mmap. When the mmap is remapped it will + // obtain a write lock so all transactions must finish before it can be + // remapped. + db.mmaplock.RLock() + + // Exit if the database is not open yet. + if !db.opened { + db.mmaplock.RUnlock() + db.metalock.Unlock() + return nil, ErrDatabaseNotOpen + } + + // Create a transaction associated with the database. + t := &Tx{} + t.init(db) + + // Keep track of transaction until it closes. + db.txs = append(db.txs, t) + n := len(db.txs) + + // Unlock the meta pages. + db.metalock.Unlock() + + // Update the transaction stats. + db.statlock.Lock() + db.stats.TxN++ + db.stats.OpenTxN = n + db.statlock.Unlock() + + return t, nil +} + +func (db *DB) beginRWTx() (*Tx, error) { + // If the database was opened with Options.ReadOnly, return an error. + if db.readOnly { + return nil, ErrDatabaseReadOnly + } + + // Obtain writer lock. This is released by the transaction when it closes. + // This enforces only one writer transaction at a time. + db.rwlock.Lock() + + // Once we have the writer lock then we can lock the meta pages so that + // we can set up the transaction. + db.metalock.Lock() + defer db.metalock.Unlock() + + // Exit if the database is not open yet. + if !db.opened { + db.rwlock.Unlock() + return nil, ErrDatabaseNotOpen + } + + // Create a transaction associated with the database. + t := &Tx{writable: true} + t.init(db) + db.rwtx = t + + // Free any pages associated with closed read-only transactions. + var minid txid = 0xFFFFFFFFFFFFFFFF + for _, t := range db.txs { + if t.meta.txid < minid { + minid = t.meta.txid + } + } + if minid > 0 { + db.freelist.release(minid - 1) + } + + return t, nil +} + +// removeTx removes a transaction from the database. +func (db *DB) removeTx(tx *Tx) { + // Release the read lock on the mmap. + db.mmaplock.RUnlock() + + // Use the meta lock to restrict access to the DB object. + db.metalock.Lock() + + // Remove the transaction. + for i, t := range db.txs { + if t == tx { + db.txs = append(db.txs[:i], db.txs[i+1:]...) + break + } + } + n := len(db.txs) + + // Unlock the meta pages. + db.metalock.Unlock() + + // Merge statistics. + db.statlock.Lock() + db.stats.OpenTxN = n + db.stats.TxStats.add(&tx.stats) + db.statlock.Unlock() +} + +// Update executes a function within the context of a read-write managed transaction. +// If no error is returned from the function then the transaction is committed. +// If an error is returned then the entire transaction is rolled back. +// Any error that is returned from the function or returned from the commit is +// returned from the Update() method. +// +// Attempting to manually commit or rollback within the function will cause a panic. +func (db *DB) Update(fn func(*Tx) error) error { + t, err := db.Begin(true) + if err != nil { + return err + } + + // Make sure the transaction rolls back in the event of a panic. + defer func() { + if t.db != nil { + t.rollback() + } + }() + + // Mark as a managed tx so that the inner function cannot manually commit. + t.managed = true + + // If an error is returned from the function then rollback and return error. + err = fn(t) + t.managed = false + if err != nil { + _ = t.Rollback() + return err + } + + return t.Commit() +} + +// View executes a function within the context of a managed read-only transaction. +// Any error that is returned from the function is returned from the View() method. +// +// Attempting to manually rollback within the function will cause a panic. +func (db *DB) View(fn func(*Tx) error) error { + t, err := db.Begin(false) + if err != nil { + return err + } + + // Make sure the transaction rolls back in the event of a panic. + defer func() { + if t.db != nil { + t.rollback() + } + }() + + // Mark as a managed tx so that the inner function cannot manually rollback. + t.managed = true + + // If an error is returned from the function then pass it through. + err = fn(t) + t.managed = false + if err != nil { + _ = t.Rollback() + return err + } + + if err := t.Rollback(); err != nil { + return err + } + + return nil +} + +// Batch calls fn as part of a batch. It behaves similar to Update, +// except: +// +// 1. concurrent Batch calls can be combined into a single Bolt +// transaction. +// +// 2. the function passed to Batch may be called multiple times, +// regardless of whether it returns error or not. +// +// This means that Batch function side effects must be idempotent and +// take permanent effect only after a successful return is seen in +// caller. +// +// The maximum batch size and delay can be adjusted with DB.MaxBatchSize +// and DB.MaxBatchDelay, respectively. +// +// Batch is only useful when there are multiple goroutines calling it. +func (db *DB) Batch(fn func(*Tx) error) error { + errCh := make(chan error, 1) + + db.batchMu.Lock() + if (db.batch == nil) || (db.batch != nil && len(db.batch.calls) >= db.MaxBatchSize) { + // There is no existing batch, or the existing batch is full; start a new one. + db.batch = &batch{ + db: db, + } + db.batch.timer = time.AfterFunc(db.MaxBatchDelay, db.batch.trigger) + } + db.batch.calls = append(db.batch.calls, call{fn: fn, err: errCh}) + if len(db.batch.calls) >= db.MaxBatchSize { + // wake up batch, it's ready to run + go db.batch.trigger() + } + db.batchMu.Unlock() + + err := <-errCh + if err == trySolo { + err = db.Update(fn) + } + return err +} + +type call struct { + fn func(*Tx) error + err chan<- error +} + +type batch struct { + db *DB + timer *time.Timer + start sync.Once + calls []call +} + +// trigger runs the batch if it hasn't already been run. +func (b *batch) trigger() { + b.start.Do(b.run) +} + +// run performs the transactions in the batch and communicates results +// back to DB.Batch. +func (b *batch) run() { + b.db.batchMu.Lock() + b.timer.Stop() + // Make sure no new work is added to this batch, but don't break + // other batches. + if b.db.batch == b { + b.db.batch = nil + } + b.db.batchMu.Unlock() + +retry: + for len(b.calls) > 0 { + var failIdx = -1 + err := b.db.Update(func(tx *Tx) error { + for i, c := range b.calls { + if err := safelyCall(c.fn, tx); err != nil { + failIdx = i + return err + } + } + return nil + }) + + if failIdx >= 0 { + // take the failing transaction out of the batch. it's + // safe to shorten b.calls here because db.batch no longer + // points to us, and we hold the mutex anyway. + c := b.calls[failIdx] + b.calls[failIdx], b.calls = b.calls[len(b.calls)-1], b.calls[:len(b.calls)-1] + // tell the submitter re-run it solo, continue with the rest of the batch + c.err <- trySolo + continue retry + } + + // pass success, or bolt internal errors, to all callers + for _, c := range b.calls { + if c.err != nil { + c.err <- err + } + } + break retry + } +} + +// trySolo is a special sentinel error value used for signaling that a +// transaction function should be re-run. It should never be seen by +// callers. +var trySolo = errors.New("batch function returned an error and should be re-run solo") + +type panicked struct { + reason interface{} +} + +func (p panicked) Error() string { + if err, ok := p.reason.(error); ok { + return err.Error() + } + return fmt.Sprintf("panic: %v", p.reason) +} + +func safelyCall(fn func(*Tx) error, tx *Tx) (err error) { + defer func() { + if p := recover(); p != nil { + err = panicked{p} + } + }() + return fn(tx) +} + +// Sync executes fdatasync() against the database file handle. +// +// This is not necessary under normal operation, however, if you use NoSync +// then it allows you to force the database file to sync against the disk. +func (db *DB) Sync() error { return fdatasync(db) } + +// Stats retrieves ongoing performance stats for the database. +// This is only updated when a transaction closes. +func (db *DB) Stats() Stats { + db.statlock.RLock() + defer db.statlock.RUnlock() + return db.stats +} + +// This is for internal access to the raw data bytes from the C cursor, use +// carefully, or not at all. +func (db *DB) Info() *Info { + return &Info{uintptr(unsafe.Pointer(&db.data[0])), db.pageSize} +} + +// page retrieves a page reference from the mmap based on the current page size. +func (db *DB) page(id pgid) *page { + pos := id * pgid(db.pageSize) + return (*page)(unsafe.Pointer(&db.data[pos])) +} + +// pageInBuffer retrieves a page reference from a given byte array based on the current page size. +func (db *DB) pageInBuffer(b []byte, id pgid) *page { + return (*page)(unsafe.Pointer(&b[id*pgid(db.pageSize)])) +} + +// meta retrieves the current meta page reference. +func (db *DB) meta() *meta { + // We have to return the meta with the highest txid which doesn't fail + // validation. Otherwise, we can cause errors when in fact the database is + // in a consistent state. metaA is the one with the higher txid. + metaA := db.meta0 + metaB := db.meta1 + if db.meta1.txid > db.meta0.txid { + metaA = db.meta1 + metaB = db.meta0 + } + + // Use higher meta page if valid. Otherwise fallback to previous, if valid. + if err := metaA.validate(); err == nil { + return metaA + } else if err := metaB.validate(); err == nil { + return metaB + } + + // This should never be reached, because both meta1 and meta0 were validated + // on mmap() and we do fsync() on every write. + panic("bolt.DB.meta(): invalid meta pages") +} + +// allocate returns a contiguous block of memory starting at a given page. +func (db *DB) allocate(count int) (*page, error) { + // Allocate a temporary buffer for the page. + var buf []byte + if count == 1 { + buf = db.pagePool.Get().([]byte) + } else { + buf = make([]byte, count*db.pageSize) + } + p := (*page)(unsafe.Pointer(&buf[0])) + p.overflow = uint32(count - 1) + + // Use pages from the freelist if they are available. + if p.id = db.freelist.allocate(count); p.id != 0 { + return p, nil + } + + // Resize mmap() if we're at the end. + p.id = db.rwtx.meta.pgid + var minsz = int((p.id+pgid(count))+1) * db.pageSize + if minsz >= db.datasz { + if err := db.mmap(minsz); err != nil { + return nil, fmt.Errorf("mmap allocate error: %s", err) + } + } + + // Move the page id high water mark. + db.rwtx.meta.pgid += pgid(count) + + return p, nil +} + +// grow grows the size of the database to the given sz. +func (db *DB) grow(sz int) error { + // Ignore if the new size is less than available file size. + if sz <= db.filesz { + return nil + } + + // If the data is smaller than the alloc size then only allocate what's needed. + // Once it goes over the allocation size then allocate in chunks. + if db.datasz < db.AllocSize { + sz = db.datasz + } else { + sz += db.AllocSize + } + + // Truncate and fsync to ensure file size metadata is flushed. + // https://github.com/boltdb/bolt/issues/284 + if !db.NoGrowSync && !db.readOnly { + if runtime.GOOS != "windows" { + if err := db.file.Truncate(int64(sz)); err != nil { + return fmt.Errorf("file resize error: %s", err) + } + } + if err := db.file.Sync(); err != nil { + return fmt.Errorf("file sync error: %s", err) + } + } + + db.filesz = sz + return nil +} + +func (db *DB) IsReadOnly() bool { + return db.readOnly +} + +// Options represents the options that can be set when opening a database. +type Options struct { + // Timeout is the amount of time to wait to obtain a file lock. + // When set to zero it will wait indefinitely. This option is only + // available on Darwin and Linux. + Timeout time.Duration + + // Sets the DB.NoGrowSync flag before memory mapping the file. + NoGrowSync bool + + // Open database in read-only mode. Uses flock(..., LOCK_SH |LOCK_NB) to + // grab a shared lock (UNIX). + ReadOnly bool + + // Sets the DB.MmapFlags flag before memory mapping the file. + MmapFlags int + + // InitialMmapSize is the initial mmap size of the database + // in bytes. Read transactions won't block write transaction + // if the InitialMmapSize is large enough to hold database mmap + // size. (See DB.Begin for more information) + // + // If <=0, the initial map size is 0. + // If initialMmapSize is smaller than the previous database size, + // it takes no effect. + InitialMmapSize int +} + +// DefaultOptions represent the options used if nil options are passed into Open(). +// No timeout is used which will cause Bolt to wait indefinitely for a lock. +var DefaultOptions = &Options{ + Timeout: 0, + NoGrowSync: false, +} + +// Stats represents statistics about the database. +type Stats struct { + // Freelist stats + FreePageN int // total number of free pages on the freelist + PendingPageN int // total number of pending pages on the freelist + FreeAlloc int // total bytes allocated in free pages + FreelistInuse int // total bytes used by the freelist + + // Transaction stats + TxN int // total number of started read transactions + OpenTxN int // number of currently open read transactions + + TxStats TxStats // global, ongoing stats. +} + +// Sub calculates and returns the difference between two sets of database stats. +// This is useful when obtaining stats at two different points and time and +// you need the performance counters that occurred within that time span. +func (s *Stats) Sub(other *Stats) Stats { + if other == nil { + return *s + } + var diff Stats + diff.FreePageN = s.FreePageN + diff.PendingPageN = s.PendingPageN + diff.FreeAlloc = s.FreeAlloc + diff.FreelistInuse = s.FreelistInuse + diff.TxN = other.TxN - s.TxN + diff.TxStats = s.TxStats.Sub(&other.TxStats) + return diff +} + +func (s *Stats) add(other *Stats) { + s.TxStats.add(&other.TxStats) +} + +type Info struct { + Data uintptr + PageSize int +} + +type meta struct { + magic uint32 + version uint32 + pageSize uint32 + flags uint32 + root bucket + freelist pgid + pgid pgid + txid txid + checksum uint64 +} + +// validate checks the marker bytes and version of the meta page to ensure it matches this binary. +func (m *meta) validate() error { + if m.magic != magic { + return ErrInvalid + } else if m.version != version { + return ErrVersionMismatch + } else if m.checksum != 0 && m.checksum != m.sum64() { + return ErrChecksum + } + return nil +} + +// copy copies one meta object to another. +func (m *meta) copy(dest *meta) { + *dest = *m +} + +// write writes the meta onto a page. +func (m *meta) write(p *page) { + if m.root.root >= m.pgid { + panic(fmt.Sprintf("root bucket pgid (%d) above high water mark (%d)", m.root.root, m.pgid)) + } else if m.freelist >= m.pgid { + panic(fmt.Sprintf("freelist pgid (%d) above high water mark (%d)", m.freelist, m.pgid)) + } + + // Page id is either going to be 0 or 1 which we can determine by the transaction ID. + p.id = pgid(m.txid % 2) + p.flags |= metaPageFlag + + // Calculate the checksum. + m.checksum = m.sum64() + + m.copy(p.meta()) +} + +// generates the checksum for the meta. +func (m *meta) sum64() uint64 { + var h = fnv.New64a() + _, _ = h.Write((*[unsafe.Offsetof(meta{}.checksum)]byte)(unsafe.Pointer(m))[:]) + return h.Sum64() +} + +// _assert will panic with a given formatted message if the given condition is false. +func _assert(condition bool, msg string, v ...interface{}) { + if !condition { + panic(fmt.Sprintf("assertion failed: "+msg, v...)) + } +} + +func warn(v ...interface{}) { fmt.Fprintln(os.Stderr, v...) } +func warnf(msg string, v ...interface{}) { fmt.Fprintf(os.Stderr, msg+"\n", v...) } + +func printstack() { + stack := strings.Join(strings.Split(string(debug.Stack()), "\n")[2:], "\n") + fmt.Fprintln(os.Stderr, stack) +} diff --git a/vendor/github.com/boltdb/bolt/db_test.go b/vendor/github.com/boltdb/bolt/db_test.go new file mode 100644 index 00000000..74ff93a9 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/db_test.go @@ -0,0 +1,1706 @@ +package bolt_test + +import ( + "bytes" + "encoding/binary" + "errors" + "flag" + "fmt" + "hash/fnv" + "io/ioutil" + "log" + "os" + "path/filepath" + "regexp" + "runtime" + "sort" + "strings" + "sync" + "testing" + "time" + "unsafe" + + "github.com/boltdb/bolt" +) + +var statsFlag = flag.Bool("stats", false, "show performance stats") + +// version is the data file format version. +const version = 2 + +// magic is the marker value to indicate that a file is a Bolt DB. +const magic uint32 = 0xED0CDAED + +// pageSize is the size of one page in the data file. +const pageSize = 4096 + +// pageHeaderSize is the size of a page header. +const pageHeaderSize = 16 + +// meta represents a simplified version of a database meta page for testing. +type meta struct { + magic uint32 + version uint32 + _ uint32 + _ uint32 + _ [16]byte + _ uint64 + pgid uint64 + _ uint64 + checksum uint64 +} + +// Ensure that a database can be opened without error. +func TestOpen(t *testing.T) { + path := tempfile() + db, err := bolt.Open(path, 0666, nil) + if err != nil { + t.Fatal(err) + } else if db == nil { + t.Fatal("expected db") + } + + if s := db.Path(); s != path { + t.Fatalf("unexpected path: %s", s) + } + + if err := db.Close(); err != nil { + t.Fatal(err) + } +} + +// Ensure that opening a database with a blank path returns an error. +func TestOpen_ErrPathRequired(t *testing.T) { + _, err := bolt.Open("", 0666, nil) + if err == nil { + t.Fatalf("expected error") + } +} + +// Ensure that opening a database with a bad path returns an error. +func TestOpen_ErrNotExists(t *testing.T) { + _, err := bolt.Open(filepath.Join(tempfile(), "bad-path"), 0666, nil) + if err == nil { + t.Fatal("expected error") + } +} + +// Ensure that opening a file that is not a Bolt database returns ErrInvalid. +func TestOpen_ErrInvalid(t *testing.T) { + path := tempfile() + + f, err := os.Create(path) + if err != nil { + t.Fatal(err) + } + if _, err := fmt.Fprintln(f, "this is not a bolt database"); err != nil { + t.Fatal(err) + } + if err := f.Close(); err != nil { + t.Fatal(err) + } + defer os.Remove(path) + + if _, err := bolt.Open(path, 0666, nil); err != bolt.ErrInvalid { + t.Fatalf("unexpected error: %s", err) + } +} + +// Ensure that opening a file with two invalid versions returns ErrVersionMismatch. +func TestOpen_ErrVersionMismatch(t *testing.T) { + if pageSize != os.Getpagesize() { + t.Skip("page size mismatch") + } + + // Create empty database. + db := MustOpenDB() + path := db.Path() + defer db.MustClose() + + // Close database. + if err := db.DB.Close(); err != nil { + t.Fatal(err) + } + + // Read data file. + buf, err := ioutil.ReadFile(path) + if err != nil { + t.Fatal(err) + } + + // Rewrite meta pages. + meta0 := (*meta)(unsafe.Pointer(&buf[pageHeaderSize])) + meta0.version++ + meta1 := (*meta)(unsafe.Pointer(&buf[pageSize+pageHeaderSize])) + meta1.version++ + if err := ioutil.WriteFile(path, buf, 0666); err != nil { + t.Fatal(err) + } + + // Reopen data file. + if _, err := bolt.Open(path, 0666, nil); err != bolt.ErrVersionMismatch { + t.Fatalf("unexpected error: %s", err) + } +} + +// Ensure that opening a file with two invalid checksums returns ErrChecksum. +func TestOpen_ErrChecksum(t *testing.T) { + if pageSize != os.Getpagesize() { + t.Skip("page size mismatch") + } + + // Create empty database. + db := MustOpenDB() + path := db.Path() + defer db.MustClose() + + // Close database. + if err := db.DB.Close(); err != nil { + t.Fatal(err) + } + + // Read data file. + buf, err := ioutil.ReadFile(path) + if err != nil { + t.Fatal(err) + } + + // Rewrite meta pages. + meta0 := (*meta)(unsafe.Pointer(&buf[pageHeaderSize])) + meta0.pgid++ + meta1 := (*meta)(unsafe.Pointer(&buf[pageSize+pageHeaderSize])) + meta1.pgid++ + if err := ioutil.WriteFile(path, buf, 0666); err != nil { + t.Fatal(err) + } + + // Reopen data file. + if _, err := bolt.Open(path, 0666, nil); err != bolt.ErrChecksum { + t.Fatalf("unexpected error: %s", err) + } +} + +// Ensure that opening an already open database file will timeout. +func TestOpen_Timeout(t *testing.T) { + if runtime.GOOS == "solaris" { + t.Skip("solaris fcntl locks don't support intra-process locking") + } + + path := tempfile() + + // Open a data file. + db0, err := bolt.Open(path, 0666, nil) + if err != nil { + t.Fatal(err) + } else if db0 == nil { + t.Fatal("expected database") + } + + // Attempt to open the database again. + start := time.Now() + db1, err := bolt.Open(path, 0666, &bolt.Options{Timeout: 100 * time.Millisecond}) + if err != bolt.ErrTimeout { + t.Fatalf("unexpected timeout: %s", err) + } else if db1 != nil { + t.Fatal("unexpected database") + } else if time.Since(start) <= 100*time.Millisecond { + t.Fatal("expected to wait at least timeout duration") + } + + if err := db0.Close(); err != nil { + t.Fatal(err) + } +} + +// Ensure that opening an already open database file will wait until its closed. +func TestOpen_Wait(t *testing.T) { + if runtime.GOOS == "solaris" { + t.Skip("solaris fcntl locks don't support intra-process locking") + } + + path := tempfile() + + // Open a data file. + db0, err := bolt.Open(path, 0666, nil) + if err != nil { + t.Fatal(err) + } + + // Close it in just a bit. + time.AfterFunc(100*time.Millisecond, func() { _ = db0.Close() }) + + // Attempt to open the database again. + start := time.Now() + db1, err := bolt.Open(path, 0666, &bolt.Options{Timeout: 200 * time.Millisecond}) + if err != nil { + t.Fatal(err) + } else if time.Since(start) <= 100*time.Millisecond { + t.Fatal("expected to wait at least timeout duration") + } + + if err := db1.Close(); err != nil { + t.Fatal(err) + } +} + +// Ensure that opening a database does not increase its size. +// https://github.com/boltdb/bolt/issues/291 +func TestOpen_Size(t *testing.T) { + // Open a data file. + db := MustOpenDB() + path := db.Path() + defer db.MustClose() + + pagesize := db.Info().PageSize + + // Insert until we get above the minimum 4MB size. + if err := db.Update(func(tx *bolt.Tx) error { + b, _ := tx.CreateBucketIfNotExists([]byte("data")) + for i := 0; i < 10000; i++ { + if err := b.Put([]byte(fmt.Sprintf("%04d", i)), make([]byte, 1000)); err != nil { + t.Fatal(err) + } + } + return nil + }); err != nil { + t.Fatal(err) + } + + // Close database and grab the size. + if err := db.DB.Close(); err != nil { + t.Fatal(err) + } + sz := fileSize(path) + if sz == 0 { + t.Fatalf("unexpected new file size: %d", sz) + } + + // Reopen database, update, and check size again. + db0, err := bolt.Open(path, 0666, nil) + if err != nil { + t.Fatal(err) + } + if err := db0.Update(func(tx *bolt.Tx) error { + if err := tx.Bucket([]byte("data")).Put([]byte{0}, []byte{0}); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + if err := db0.Close(); err != nil { + t.Fatal(err) + } + newSz := fileSize(path) + if newSz == 0 { + t.Fatalf("unexpected new file size: %d", newSz) + } + + // Compare the original size with the new size. + // db size might increase by a few page sizes due to the new small update. + if sz < newSz-5*int64(pagesize) { + t.Fatalf("unexpected file growth: %d => %d", sz, newSz) + } +} + +// Ensure that opening a database beyond the max step size does not increase its size. +// https://github.com/boltdb/bolt/issues/303 +func TestOpen_Size_Large(t *testing.T) { + if testing.Short() { + t.Skip("short mode") + } + + // Open a data file. + db := MustOpenDB() + path := db.Path() + defer db.MustClose() + + pagesize := db.Info().PageSize + + // Insert until we get above the minimum 4MB size. + var index uint64 + for i := 0; i < 10000; i++ { + if err := db.Update(func(tx *bolt.Tx) error { + b, _ := tx.CreateBucketIfNotExists([]byte("data")) + for j := 0; j < 1000; j++ { + if err := b.Put(u64tob(index), make([]byte, 50)); err != nil { + t.Fatal(err) + } + index++ + } + return nil + }); err != nil { + t.Fatal(err) + } + } + + // Close database and grab the size. + if err := db.DB.Close(); err != nil { + t.Fatal(err) + } + sz := fileSize(path) + if sz == 0 { + t.Fatalf("unexpected new file size: %d", sz) + } else if sz < (1 << 30) { + t.Fatalf("expected larger initial size: %d", sz) + } + + // Reopen database, update, and check size again. + db0, err := bolt.Open(path, 0666, nil) + if err != nil { + t.Fatal(err) + } + if err := db0.Update(func(tx *bolt.Tx) error { + return tx.Bucket([]byte("data")).Put([]byte{0}, []byte{0}) + }); err != nil { + t.Fatal(err) + } + if err := db0.Close(); err != nil { + t.Fatal(err) + } + + newSz := fileSize(path) + if newSz == 0 { + t.Fatalf("unexpected new file size: %d", newSz) + } + + // Compare the original size with the new size. + // db size might increase by a few page sizes due to the new small update. + if sz < newSz-5*int64(pagesize) { + t.Fatalf("unexpected file growth: %d => %d", sz, newSz) + } +} + +// Ensure that a re-opened database is consistent. +func TestOpen_Check(t *testing.T) { + path := tempfile() + + db, err := bolt.Open(path, 0666, nil) + if err != nil { + t.Fatal(err) + } + if err := db.View(func(tx *bolt.Tx) error { return <-tx.Check() }); err != nil { + t.Fatal(err) + } + if err := db.Close(); err != nil { + t.Fatal(err) + } + + db, err = bolt.Open(path, 0666, nil) + if err != nil { + t.Fatal(err) + } + if err := db.View(func(tx *bolt.Tx) error { return <-tx.Check() }); err != nil { + t.Fatal(err) + } + if err := db.Close(); err != nil { + t.Fatal(err) + } +} + +// Ensure that write errors to the meta file handler during initialization are returned. +func TestOpen_MetaInitWriteError(t *testing.T) { + t.Skip("pending") +} + +// Ensure that a database that is too small returns an error. +func TestOpen_FileTooSmall(t *testing.T) { + path := tempfile() + + db, err := bolt.Open(path, 0666, nil) + if err != nil { + t.Fatal(err) + } + if err := db.Close(); err != nil { + t.Fatal(err) + } + + // corrupt the database + if err := os.Truncate(path, int64(os.Getpagesize())); err != nil { + t.Fatal(err) + } + + db, err = bolt.Open(path, 0666, nil) + if err == nil || err.Error() != "file size too small" { + t.Fatalf("unexpected error: %s", err) + } +} + +// Ensure that a database can be opened in read-only mode by multiple processes +// and that a database can not be opened in read-write mode and in read-only +// mode at the same time. +func TestOpen_ReadOnly(t *testing.T) { + if runtime.GOOS == "solaris" { + t.Skip("solaris fcntl locks don't support intra-process locking") + } + + bucket, key, value := []byte(`bucket`), []byte(`key`), []byte(`value`) + + path := tempfile() + + // Open in read-write mode. + db, err := bolt.Open(path, 0666, nil) + if err != nil { + t.Fatal(err) + } else if db.IsReadOnly() { + t.Fatal("db should not be in read only mode") + } + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket(bucket) + if err != nil { + return err + } + if err := b.Put(key, value); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + if err := db.Close(); err != nil { + t.Fatal(err) + } + + // Open in read-only mode. + db0, err := bolt.Open(path, 0666, &bolt.Options{ReadOnly: true}) + if err != nil { + t.Fatal(err) + } + + // Opening in read-write mode should return an error. + if _, err = bolt.Open(path, 0666, &bolt.Options{Timeout: time.Millisecond * 100}); err == nil { + t.Fatal("expected error") + } + + // And again (in read-only mode). + db1, err := bolt.Open(path, 0666, &bolt.Options{ReadOnly: true}) + if err != nil { + t.Fatal(err) + } + + // Verify both read-only databases are accessible. + for _, db := range []*bolt.DB{db0, db1} { + // Verify is is in read only mode indeed. + if !db.IsReadOnly() { + t.Fatal("expected read only mode") + } + + // Read-only databases should not allow updates. + if err := db.Update(func(*bolt.Tx) error { + panic(`should never get here`) + }); err != bolt.ErrDatabaseReadOnly { + t.Fatalf("unexpected error: %s", err) + } + + // Read-only databases should not allow beginning writable txns. + if _, err := db.Begin(true); err != bolt.ErrDatabaseReadOnly { + t.Fatalf("unexpected error: %s", err) + } + + // Verify the data. + if err := db.View(func(tx *bolt.Tx) error { + b := tx.Bucket(bucket) + if b == nil { + return fmt.Errorf("expected bucket `%s`", string(bucket)) + } + + got := string(b.Get(key)) + expected := string(value) + if got != expected { + return fmt.Errorf("expected `%s`, got `%s`", expected, got) + } + return nil + }); err != nil { + t.Fatal(err) + } + } + + if err := db0.Close(); err != nil { + t.Fatal(err) + } + if err := db1.Close(); err != nil { + t.Fatal(err) + } +} + +// TestDB_Open_InitialMmapSize tests if having InitialMmapSize large enough +// to hold data from concurrent write transaction resolves the issue that +// read transaction blocks the write transaction and causes deadlock. +// This is a very hacky test since the mmap size is not exposed. +func TestDB_Open_InitialMmapSize(t *testing.T) { + path := tempfile() + defer os.Remove(path) + + initMmapSize := 1 << 31 // 2GB + testWriteSize := 1 << 27 // 134MB + + db, err := bolt.Open(path, 0666, &bolt.Options{InitialMmapSize: initMmapSize}) + if err != nil { + t.Fatal(err) + } + + // create a long-running read transaction + // that never gets closed while writing + rtx, err := db.Begin(false) + if err != nil { + t.Fatal(err) + } + + // create a write transaction + wtx, err := db.Begin(true) + if err != nil { + t.Fatal(err) + } + + b, err := wtx.CreateBucket([]byte("test")) + if err != nil { + t.Fatal(err) + } + + // and commit a large write + err = b.Put([]byte("foo"), make([]byte, testWriteSize)) + if err != nil { + t.Fatal(err) + } + + done := make(chan struct{}) + + go func() { + if err := wtx.Commit(); err != nil { + t.Fatal(err) + } + done <- struct{}{} + }() + + select { + case <-time.After(5 * time.Second): + t.Errorf("unexpected that the reader blocks writer") + case <-done: + } + + if err := rtx.Rollback(); err != nil { + t.Fatal(err) + } +} + +// Ensure that a database cannot open a transaction when it's not open. +func TestDB_Begin_ErrDatabaseNotOpen(t *testing.T) { + var db bolt.DB + if _, err := db.Begin(false); err != bolt.ErrDatabaseNotOpen { + t.Fatalf("unexpected error: %s", err) + } +} + +// Ensure that a read-write transaction can be retrieved. +func TestDB_BeginRW(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + tx, err := db.Begin(true) + if err != nil { + t.Fatal(err) + } else if tx == nil { + t.Fatal("expected tx") + } + + if tx.DB() != db.DB { + t.Fatal("unexpected tx database") + } else if !tx.Writable() { + t.Fatal("expected writable tx") + } + + if err := tx.Commit(); err != nil { + t.Fatal(err) + } +} + +// Ensure that opening a transaction while the DB is closed returns an error. +func TestDB_BeginRW_Closed(t *testing.T) { + var db bolt.DB + if _, err := db.Begin(true); err != bolt.ErrDatabaseNotOpen { + t.Fatalf("unexpected error: %s", err) + } +} + +func TestDB_Close_PendingTx_RW(t *testing.T) { testDB_Close_PendingTx(t, true) } +func TestDB_Close_PendingTx_RO(t *testing.T) { testDB_Close_PendingTx(t, false) } + +// Ensure that a database cannot close while transactions are open. +func testDB_Close_PendingTx(t *testing.T, writable bool) { + db := MustOpenDB() + defer db.MustClose() + + // Start transaction. + tx, err := db.Begin(true) + if err != nil { + t.Fatal(err) + } + + // Open update in separate goroutine. + done := make(chan struct{}) + go func() { + if err := db.Close(); err != nil { + t.Fatal(err) + } + close(done) + }() + + // Ensure database hasn't closed. + time.Sleep(100 * time.Millisecond) + select { + case <-done: + t.Fatal("database closed too early") + default: + } + + // Commit transaction. + if err := tx.Commit(); err != nil { + t.Fatal(err) + } + + // Ensure database closed now. + time.Sleep(100 * time.Millisecond) + select { + case <-done: + default: + t.Fatal("database did not close") + } +} + +// Ensure a database can provide a transactional block. +func TestDB_Update(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("baz"), []byte("bat")); err != nil { + t.Fatal(err) + } + if err := b.Delete([]byte("foo")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + if err := db.View(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("widgets")) + if v := b.Get([]byte("foo")); v != nil { + t.Fatalf("expected nil value, got: %v", v) + } + if v := b.Get([]byte("baz")); !bytes.Equal(v, []byte("bat")) { + t.Fatalf("unexpected value: %v", v) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure a closed database returns an error while running a transaction block +func TestDB_Update_Closed(t *testing.T) { + var db bolt.DB + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } + return nil + }); err != bolt.ErrDatabaseNotOpen { + t.Fatalf("unexpected error: %s", err) + } +} + +// Ensure a panic occurs while trying to commit a managed transaction. +func TestDB_Update_ManualCommit(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + var panicked bool + if err := db.Update(func(tx *bolt.Tx) error { + func() { + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + + if err := tx.Commit(); err != nil { + t.Fatal(err) + } + }() + return nil + }); err != nil { + t.Fatal(err) + } else if !panicked { + t.Fatal("expected panic") + } +} + +// Ensure a panic occurs while trying to rollback a managed transaction. +func TestDB_Update_ManualRollback(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + var panicked bool + if err := db.Update(func(tx *bolt.Tx) error { + func() { + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } + }() + return nil + }); err != nil { + t.Fatal(err) + } else if !panicked { + t.Fatal("expected panic") + } +} + +// Ensure a panic occurs while trying to commit a managed transaction. +func TestDB_View_ManualCommit(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + var panicked bool + if err := db.View(func(tx *bolt.Tx) error { + func() { + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + + if err := tx.Commit(); err != nil { + t.Fatal(err) + } + }() + return nil + }); err != nil { + t.Fatal(err) + } else if !panicked { + t.Fatal("expected panic") + } +} + +// Ensure a panic occurs while trying to rollback a managed transaction. +func TestDB_View_ManualRollback(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + var panicked bool + if err := db.View(func(tx *bolt.Tx) error { + func() { + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } + }() + return nil + }); err != nil { + t.Fatal(err) + } else if !panicked { + t.Fatal("expected panic") + } +} + +// Ensure a write transaction that panics does not hold open locks. +func TestDB_Update_Panic(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + // Panic during update but recover. + func() { + defer func() { + if r := recover(); r != nil { + t.Log("recover: update", r) + } + }() + + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } + panic("omg") + }); err != nil { + t.Fatal(err) + } + }() + + // Verify we can update again. + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + + // Verify that our change persisted. + if err := db.Update(func(tx *bolt.Tx) error { + if tx.Bucket([]byte("widgets")) == nil { + t.Fatal("expected bucket") + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure a database can return an error through a read-only transactional block. +func TestDB_View_Error(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + if err := db.View(func(tx *bolt.Tx) error { + return errors.New("xxx") + }); err == nil || err.Error() != "xxx" { + t.Fatalf("unexpected error: %s", err) + } +} + +// Ensure a read transaction that panics does not hold open locks. +func TestDB_View_Panic(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + + // Panic during view transaction but recover. + func() { + defer func() { + if r := recover(); r != nil { + t.Log("recover: view", r) + } + }() + + if err := db.View(func(tx *bolt.Tx) error { + if tx.Bucket([]byte("widgets")) == nil { + t.Fatal("expected bucket") + } + panic("omg") + }); err != nil { + t.Fatal(err) + } + }() + + // Verify that we can still use read transactions. + if err := db.View(func(tx *bolt.Tx) error { + if tx.Bucket([]byte("widgets")) == nil { + t.Fatal("expected bucket") + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that DB stats can be returned. +func TestDB_Stats(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + _, err := tx.CreateBucket([]byte("widgets")) + return err + }); err != nil { + t.Fatal(err) + } + + stats := db.Stats() + if stats.TxStats.PageCount != 2 { + t.Fatalf("unexpected TxStats.PageCount: %d", stats.TxStats.PageCount) + } else if stats.FreePageN != 0 { + t.Fatalf("unexpected FreePageN != 0: %d", stats.FreePageN) + } else if stats.PendingPageN != 2 { + t.Fatalf("unexpected PendingPageN != 2: %d", stats.PendingPageN) + } +} + +// Ensure that database pages are in expected order and type. +func TestDB_Consistency(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + _, err := tx.CreateBucket([]byte("widgets")) + return err + }); err != nil { + t.Fatal(err) + } + + for i := 0; i < 10; i++ { + if err := db.Update(func(tx *bolt.Tx) error { + if err := tx.Bucket([]byte("widgets")).Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + } + + if err := db.Update(func(tx *bolt.Tx) error { + if p, _ := tx.Page(0); p == nil { + t.Fatal("expected page") + } else if p.Type != "meta" { + t.Fatalf("unexpected page type: %s", p.Type) + } + + if p, _ := tx.Page(1); p == nil { + t.Fatal("expected page") + } else if p.Type != "meta" { + t.Fatalf("unexpected page type: %s", p.Type) + } + + if p, _ := tx.Page(2); p == nil { + t.Fatal("expected page") + } else if p.Type != "free" { + t.Fatalf("unexpected page type: %s", p.Type) + } + + if p, _ := tx.Page(3); p == nil { + t.Fatal("expected page") + } else if p.Type != "free" { + t.Fatalf("unexpected page type: %s", p.Type) + } + + if p, _ := tx.Page(4); p == nil { + t.Fatal("expected page") + } else if p.Type != "leaf" { + t.Fatalf("unexpected page type: %s", p.Type) + } + + if p, _ := tx.Page(5); p == nil { + t.Fatal("expected page") + } else if p.Type != "freelist" { + t.Fatalf("unexpected page type: %s", p.Type) + } + + if p, _ := tx.Page(6); p != nil { + t.Fatal("unexpected page") + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that DB stats can be subtracted from one another. +func TestDBStats_Sub(t *testing.T) { + var a, b bolt.Stats + a.TxStats.PageCount = 3 + a.FreePageN = 4 + b.TxStats.PageCount = 10 + b.FreePageN = 14 + diff := b.Sub(&a) + if diff.TxStats.PageCount != 7 { + t.Fatalf("unexpected TxStats.PageCount: %d", diff.TxStats.PageCount) + } + + // free page stats are copied from the receiver and not subtracted + if diff.FreePageN != 14 { + t.Fatalf("unexpected FreePageN: %d", diff.FreePageN) + } +} + +// Ensure two functions can perform updates in a single batch. +func TestDB_Batch(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + + // Iterate over multiple updates in separate goroutines. + n := 2 + ch := make(chan error) + for i := 0; i < n; i++ { + go func(i int) { + ch <- db.Batch(func(tx *bolt.Tx) error { + return tx.Bucket([]byte("widgets")).Put(u64tob(uint64(i)), []byte{}) + }) + }(i) + } + + // Check all responses to make sure there's no error. + for i := 0; i < n; i++ { + if err := <-ch; err != nil { + t.Fatal(err) + } + } + + // Ensure data is correct. + if err := db.View(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("widgets")) + for i := 0; i < n; i++ { + if v := b.Get(u64tob(uint64(i))); v == nil { + t.Errorf("key not found: %d", i) + } + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +func TestDB_Batch_Panic(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + var sentinel int + var bork = &sentinel + var problem interface{} + var err error + + // Execute a function inside a batch that panics. + func() { + defer func() { + if p := recover(); p != nil { + problem = p + } + }() + err = db.Batch(func(tx *bolt.Tx) error { + panic(bork) + }) + }() + + // Verify there is no error. + if g, e := err, error(nil); g != e { + t.Fatalf("wrong error: %v != %v", g, e) + } + // Verify the panic was captured. + if g, e := problem, bork; g != e { + t.Fatalf("wrong error: %v != %v", g, e) + } +} + +func TestDB_BatchFull(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + _, err := tx.CreateBucket([]byte("widgets")) + return err + }); err != nil { + t.Fatal(err) + } + + const size = 3 + // buffered so we never leak goroutines + ch := make(chan error, size) + put := func(i int) { + ch <- db.Batch(func(tx *bolt.Tx) error { + return tx.Bucket([]byte("widgets")).Put(u64tob(uint64(i)), []byte{}) + }) + } + + db.MaxBatchSize = size + // high enough to never trigger here + db.MaxBatchDelay = 1 * time.Hour + + go put(1) + go put(2) + + // Give the batch a chance to exhibit bugs. + time.Sleep(10 * time.Millisecond) + + // not triggered yet + select { + case <-ch: + t.Fatalf("batch triggered too early") + default: + } + + go put(3) + + // Check all responses to make sure there's no error. + for i := 0; i < size; i++ { + if err := <-ch; err != nil { + t.Fatal(err) + } + } + + // Ensure data is correct. + if err := db.View(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("widgets")) + for i := 1; i <= size; i++ { + if v := b.Get(u64tob(uint64(i))); v == nil { + t.Errorf("key not found: %d", i) + } + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +func TestDB_BatchTime(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + _, err := tx.CreateBucket([]byte("widgets")) + return err + }); err != nil { + t.Fatal(err) + } + + const size = 1 + // buffered so we never leak goroutines + ch := make(chan error, size) + put := func(i int) { + ch <- db.Batch(func(tx *bolt.Tx) error { + return tx.Bucket([]byte("widgets")).Put(u64tob(uint64(i)), []byte{}) + }) + } + + db.MaxBatchSize = 1000 + db.MaxBatchDelay = 0 + + go put(1) + + // Batch must trigger by time alone. + + // Check all responses to make sure there's no error. + for i := 0; i < size; i++ { + if err := <-ch; err != nil { + t.Fatal(err) + } + } + + // Ensure data is correct. + if err := db.View(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("widgets")) + for i := 1; i <= size; i++ { + if v := b.Get(u64tob(uint64(i))); v == nil { + t.Errorf("key not found: %d", i) + } + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +func ExampleDB_Update() { + // Open the database. + db, err := bolt.Open(tempfile(), 0666, nil) + if err != nil { + log.Fatal(err) + } + defer os.Remove(db.Path()) + + // Execute several commands within a read-write transaction. + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + return err + } + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + return err + } + return nil + }); err != nil { + log.Fatal(err) + } + + // Read the value back from a separate read-only transaction. + if err := db.View(func(tx *bolt.Tx) error { + value := tx.Bucket([]byte("widgets")).Get([]byte("foo")) + fmt.Printf("The value of 'foo' is: %s\n", value) + return nil + }); err != nil { + log.Fatal(err) + } + + // Close database to release the file lock. + if err := db.Close(); err != nil { + log.Fatal(err) + } + + // Output: + // The value of 'foo' is: bar +} + +func ExampleDB_View() { + // Open the database. + db, err := bolt.Open(tempfile(), 0666, nil) + if err != nil { + log.Fatal(err) + } + defer os.Remove(db.Path()) + + // Insert data into a bucket. + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("people")) + if err != nil { + return err + } + if err := b.Put([]byte("john"), []byte("doe")); err != nil { + return err + } + if err := b.Put([]byte("susy"), []byte("que")); err != nil { + return err + } + return nil + }); err != nil { + log.Fatal(err) + } + + // Access data from within a read-only transactional block. + if err := db.View(func(tx *bolt.Tx) error { + v := tx.Bucket([]byte("people")).Get([]byte("john")) + fmt.Printf("John's last name is %s.\n", v) + return nil + }); err != nil { + log.Fatal(err) + } + + // Close database to release the file lock. + if err := db.Close(); err != nil { + log.Fatal(err) + } + + // Output: + // John's last name is doe. +} + +func ExampleDB_Begin_ReadOnly() { + // Open the database. + db, err := bolt.Open(tempfile(), 0666, nil) + if err != nil { + log.Fatal(err) + } + defer os.Remove(db.Path()) + + // Create a bucket using a read-write transaction. + if err := db.Update(func(tx *bolt.Tx) error { + _, err := tx.CreateBucket([]byte("widgets")) + return err + }); err != nil { + log.Fatal(err) + } + + // Create several keys in a transaction. + tx, err := db.Begin(true) + if err != nil { + log.Fatal(err) + } + b := tx.Bucket([]byte("widgets")) + if err := b.Put([]byte("john"), []byte("blue")); err != nil { + log.Fatal(err) + } + if err := b.Put([]byte("abby"), []byte("red")); err != nil { + log.Fatal(err) + } + if err := b.Put([]byte("zephyr"), []byte("purple")); err != nil { + log.Fatal(err) + } + if err := tx.Commit(); err != nil { + log.Fatal(err) + } + + // Iterate over the values in sorted key order. + tx, err = db.Begin(false) + if err != nil { + log.Fatal(err) + } + c := tx.Bucket([]byte("widgets")).Cursor() + for k, v := c.First(); k != nil; k, v = c.Next() { + fmt.Printf("%s likes %s\n", k, v) + } + + if err := tx.Rollback(); err != nil { + log.Fatal(err) + } + + if err := db.Close(); err != nil { + log.Fatal(err) + } + + // Output: + // abby likes red + // john likes blue + // zephyr likes purple +} + +func BenchmarkDBBatchAutomatic(b *testing.B) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + _, err := tx.CreateBucket([]byte("bench")) + return err + }); err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + start := make(chan struct{}) + var wg sync.WaitGroup + + for round := 0; round < 1000; round++ { + wg.Add(1) + + go func(id uint32) { + defer wg.Done() + <-start + + h := fnv.New32a() + buf := make([]byte, 4) + binary.LittleEndian.PutUint32(buf, id) + _, _ = h.Write(buf[:]) + k := h.Sum(nil) + insert := func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("bench")) + return b.Put(k, []byte("filler")) + } + if err := db.Batch(insert); err != nil { + b.Error(err) + return + } + }(uint32(round)) + } + close(start) + wg.Wait() + } + + b.StopTimer() + validateBatchBench(b, db) +} + +func BenchmarkDBBatchSingle(b *testing.B) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + _, err := tx.CreateBucket([]byte("bench")) + return err + }); err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + start := make(chan struct{}) + var wg sync.WaitGroup + + for round := 0; round < 1000; round++ { + wg.Add(1) + go func(id uint32) { + defer wg.Done() + <-start + + h := fnv.New32a() + buf := make([]byte, 4) + binary.LittleEndian.PutUint32(buf, id) + _, _ = h.Write(buf[:]) + k := h.Sum(nil) + insert := func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("bench")) + return b.Put(k, []byte("filler")) + } + if err := db.Update(insert); err != nil { + b.Error(err) + return + } + }(uint32(round)) + } + close(start) + wg.Wait() + } + + b.StopTimer() + validateBatchBench(b, db) +} + +func BenchmarkDBBatchManual10x100(b *testing.B) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + _, err := tx.CreateBucket([]byte("bench")) + return err + }); err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + start := make(chan struct{}) + var wg sync.WaitGroup + + for major := 0; major < 10; major++ { + wg.Add(1) + go func(id uint32) { + defer wg.Done() + <-start + + insert100 := func(tx *bolt.Tx) error { + h := fnv.New32a() + buf := make([]byte, 4) + for minor := uint32(0); minor < 100; minor++ { + binary.LittleEndian.PutUint32(buf, uint32(id*100+minor)) + h.Reset() + _, _ = h.Write(buf[:]) + k := h.Sum(nil) + b := tx.Bucket([]byte("bench")) + if err := b.Put(k, []byte("filler")); err != nil { + return err + } + } + return nil + } + if err := db.Update(insert100); err != nil { + b.Fatal(err) + } + }(uint32(major)) + } + close(start) + wg.Wait() + } + + b.StopTimer() + validateBatchBench(b, db) +} + +func validateBatchBench(b *testing.B, db *DB) { + var rollback = errors.New("sentinel error to cause rollback") + validate := func(tx *bolt.Tx) error { + bucket := tx.Bucket([]byte("bench")) + h := fnv.New32a() + buf := make([]byte, 4) + for id := uint32(0); id < 1000; id++ { + binary.LittleEndian.PutUint32(buf, id) + h.Reset() + _, _ = h.Write(buf[:]) + k := h.Sum(nil) + v := bucket.Get(k) + if v == nil { + b.Errorf("not found id=%d key=%x", id, k) + continue + } + if g, e := v, []byte("filler"); !bytes.Equal(g, e) { + b.Errorf("bad value for id=%d key=%x: %s != %q", id, k, g, e) + } + if err := bucket.Delete(k); err != nil { + return err + } + } + // should be empty now + c := bucket.Cursor() + for k, v := c.First(); k != nil; k, v = c.Next() { + b.Errorf("unexpected key: %x = %q", k, v) + } + return rollback + } + if err := db.Update(validate); err != nil && err != rollback { + b.Error(err) + } +} + +// DB is a test wrapper for bolt.DB. +type DB struct { + *bolt.DB +} + +// MustOpenDB returns a new, open DB at a temporary location. +func MustOpenDB() *DB { + db, err := bolt.Open(tempfile(), 0666, nil) + if err != nil { + panic(err) + } + return &DB{db} +} + +// Close closes the database and deletes the underlying file. +func (db *DB) Close() error { + // Log statistics. + if *statsFlag { + db.PrintStats() + } + + // Check database consistency after every test. + db.MustCheck() + + // Close database and remove file. + defer os.Remove(db.Path()) + return db.DB.Close() +} + +// MustClose closes the database and deletes the underlying file. Panic on error. +func (db *DB) MustClose() { + if err := db.Close(); err != nil { + panic(err) + } +} + +// PrintStats prints the database stats +func (db *DB) PrintStats() { + var stats = db.Stats() + fmt.Printf("[db] %-20s %-20s %-20s\n", + fmt.Sprintf("pg(%d/%d)", stats.TxStats.PageCount, stats.TxStats.PageAlloc), + fmt.Sprintf("cur(%d)", stats.TxStats.CursorCount), + fmt.Sprintf("node(%d/%d)", stats.TxStats.NodeCount, stats.TxStats.NodeDeref), + ) + fmt.Printf(" %-20s %-20s %-20s\n", + fmt.Sprintf("rebal(%d/%v)", stats.TxStats.Rebalance, truncDuration(stats.TxStats.RebalanceTime)), + fmt.Sprintf("spill(%d/%v)", stats.TxStats.Spill, truncDuration(stats.TxStats.SpillTime)), + fmt.Sprintf("w(%d/%v)", stats.TxStats.Write, truncDuration(stats.TxStats.WriteTime)), + ) +} + +// MustCheck runs a consistency check on the database and panics if any errors are found. +func (db *DB) MustCheck() { + if err := db.Update(func(tx *bolt.Tx) error { + // Collect all the errors. + var errors []error + for err := range tx.Check() { + errors = append(errors, err) + if len(errors) > 10 { + break + } + } + + // If errors occurred, copy the DB and print the errors. + if len(errors) > 0 { + var path = tempfile() + if err := tx.CopyFile(path, 0600); err != nil { + panic(err) + } + + // Print errors. + fmt.Print("\n\n") + fmt.Printf("consistency check failed (%d errors)\n", len(errors)) + for _, err := range errors { + fmt.Println(err) + } + fmt.Println("") + fmt.Println("db saved to:") + fmt.Println(path) + fmt.Print("\n\n") + os.Exit(-1) + } + + return nil + }); err != nil && err != bolt.ErrDatabaseNotOpen { + panic(err) + } +} + +// CopyTempFile copies a database to a temporary file. +func (db *DB) CopyTempFile() { + path := tempfile() + if err := db.View(func(tx *bolt.Tx) error { + return tx.CopyFile(path, 0600) + }); err != nil { + panic(err) + } + fmt.Println("db copied to: ", path) +} + +// tempfile returns a temporary file path. +func tempfile() string { + f, err := ioutil.TempFile("", "bolt-") + if err != nil { + panic(err) + } + if err := f.Close(); err != nil { + panic(err) + } + if err := os.Remove(f.Name()); err != nil { + panic(err) + } + return f.Name() +} + +// mustContainKeys checks that a bucket contains a given set of keys. +func mustContainKeys(b *bolt.Bucket, m map[string]string) { + found := make(map[string]string) + if err := b.ForEach(func(k, _ []byte) error { + found[string(k)] = "" + return nil + }); err != nil { + panic(err) + } + + // Check for keys found in bucket that shouldn't be there. + var keys []string + for k, _ := range found { + if _, ok := m[string(k)]; !ok { + keys = append(keys, k) + } + } + if len(keys) > 0 { + sort.Strings(keys) + panic(fmt.Sprintf("keys found(%d): %s", len(keys), strings.Join(keys, ","))) + } + + // Check for keys not found in bucket that should be there. + for k, _ := range m { + if _, ok := found[string(k)]; !ok { + keys = append(keys, k) + } + } + if len(keys) > 0 { + sort.Strings(keys) + panic(fmt.Sprintf("keys not found(%d): %s", len(keys), strings.Join(keys, ","))) + } +} + +func trunc(b []byte, length int) []byte { + if length < len(b) { + return b[:length] + } + return b +} + +func truncDuration(d time.Duration) string { + return regexp.MustCompile(`^(\d+)(\.\d+)`).ReplaceAllString(d.String(), "$1") +} + +func fileSize(path string) int64 { + fi, err := os.Stat(path) + if err != nil { + return 0 + } + return fi.Size() +} + +func warn(v ...interface{}) { fmt.Fprintln(os.Stderr, v...) } +func warnf(msg string, v ...interface{}) { fmt.Fprintf(os.Stderr, msg+"\n", v...) } + +// u64tob converts a uint64 into an 8-byte slice. +func u64tob(v uint64) []byte { + b := make([]byte, 8) + binary.BigEndian.PutUint64(b, v) + return b +} + +// btou64 converts an 8-byte slice into an uint64. +func btou64(b []byte) uint64 { return binary.BigEndian.Uint64(b) } diff --git a/vendor/github.com/boltdb/bolt/doc.go b/vendor/github.com/boltdb/bolt/doc.go new file mode 100644 index 00000000..cc937845 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/doc.go @@ -0,0 +1,44 @@ +/* +Package bolt implements a low-level key/value store in pure Go. It supports +fully serializable transactions, ACID semantics, and lock-free MVCC with +multiple readers and a single writer. Bolt can be used for projects that +want a simple data store without the need to add large dependencies such as +Postgres or MySQL. + +Bolt is a single-level, zero-copy, B+tree data store. This means that Bolt is +optimized for fast read access and does not require recovery in the event of a +system crash. Transactions which have not finished committing will simply be +rolled back in the event of a crash. + +The design of Bolt is based on Howard Chu's LMDB database project. + +Bolt currently works on Windows, Mac OS X, and Linux. + + +Basics + +There are only a few types in Bolt: DB, Bucket, Tx, and Cursor. The DB is +a collection of buckets and is represented by a single file on disk. A bucket is +a collection of unique keys that are associated with values. + +Transactions provide either read-only or read-write access to the database. +Read-only transactions can retrieve key/value pairs and can use Cursors to +iterate over the dataset sequentially. Read-write transactions can create and +delete buckets and can insert and remove keys. Only one read-write transaction +is allowed at a time. + + +Caveats + +The database uses a read-only, memory-mapped data file to ensure that +applications cannot corrupt the database, however, this means that keys and +values returned from Bolt cannot be changed. Writing to a read-only byte slice +will cause Go to panic. + +Keys and values retrieved from the database are only valid for the life of +the transaction. When used outside the transaction, these byte slices can +point to different data or can point to invalid memory which will cause a panic. + + +*/ +package bolt diff --git a/vendor/github.com/boltdb/bolt/errors.go b/vendor/github.com/boltdb/bolt/errors.go new file mode 100644 index 00000000..a3620a3e --- /dev/null +++ b/vendor/github.com/boltdb/bolt/errors.go @@ -0,0 +1,71 @@ +package bolt + +import "errors" + +// These errors can be returned when opening or calling methods on a DB. +var ( + // ErrDatabaseNotOpen is returned when a DB instance is accessed before it + // is opened or after it is closed. + ErrDatabaseNotOpen = errors.New("database not open") + + // ErrDatabaseOpen is returned when opening a database that is + // already open. + ErrDatabaseOpen = errors.New("database already open") + + // ErrInvalid is returned when both meta pages on a database are invalid. + // This typically occurs when a file is not a bolt database. + ErrInvalid = errors.New("invalid database") + + // ErrVersionMismatch is returned when the data file was created with a + // different version of Bolt. + ErrVersionMismatch = errors.New("version mismatch") + + // ErrChecksum is returned when either meta page checksum does not match. + ErrChecksum = errors.New("checksum error") + + // ErrTimeout is returned when a database cannot obtain an exclusive lock + // on the data file after the timeout passed to Open(). + ErrTimeout = errors.New("timeout") +) + +// These errors can occur when beginning or committing a Tx. +var ( + // ErrTxNotWritable is returned when performing a write operation on a + // read-only transaction. + ErrTxNotWritable = errors.New("tx not writable") + + // ErrTxClosed is returned when committing or rolling back a transaction + // that has already been committed or rolled back. + ErrTxClosed = errors.New("tx closed") + + // ErrDatabaseReadOnly is returned when a mutating transaction is started on a + // read-only database. + ErrDatabaseReadOnly = errors.New("database is in read-only mode") +) + +// These errors can occur when putting or deleting a value or a bucket. +var ( + // ErrBucketNotFound is returned when trying to access a bucket that has + // not been created yet. + ErrBucketNotFound = errors.New("bucket not found") + + // ErrBucketExists is returned when creating a bucket that already exists. + ErrBucketExists = errors.New("bucket already exists") + + // ErrBucketNameRequired is returned when creating a bucket with a blank name. + ErrBucketNameRequired = errors.New("bucket name required") + + // ErrKeyRequired is returned when inserting a zero-length key. + ErrKeyRequired = errors.New("key required") + + // ErrKeyTooLarge is returned when inserting a key that is larger than MaxKeySize. + ErrKeyTooLarge = errors.New("key too large") + + // ErrValueTooLarge is returned when inserting a value that is larger than MaxValueSize. + ErrValueTooLarge = errors.New("value too large") + + // ErrIncompatibleValue is returned when trying create or delete a bucket + // on an existing non-bucket key or when trying to create or delete a + // non-bucket key on an existing bucket key. + ErrIncompatibleValue = errors.New("incompatible value") +) diff --git a/vendor/github.com/boltdb/bolt/freelist.go b/vendor/github.com/boltdb/bolt/freelist.go new file mode 100644 index 00000000..1b7ba91b --- /dev/null +++ b/vendor/github.com/boltdb/bolt/freelist.go @@ -0,0 +1,248 @@ +package bolt + +import ( + "fmt" + "sort" + "unsafe" +) + +// freelist represents a list of all pages that are available for allocation. +// It also tracks pages that have been freed but are still in use by open transactions. +type freelist struct { + ids []pgid // all free and available free page ids. + pending map[txid][]pgid // mapping of soon-to-be free page ids by tx. + cache map[pgid]bool // fast lookup of all free and pending page ids. +} + +// newFreelist returns an empty, initialized freelist. +func newFreelist() *freelist { + return &freelist{ + pending: make(map[txid][]pgid), + cache: make(map[pgid]bool), + } +} + +// size returns the size of the page after serialization. +func (f *freelist) size() int { + return pageHeaderSize + (int(unsafe.Sizeof(pgid(0))) * f.count()) +} + +// count returns count of pages on the freelist +func (f *freelist) count() int { + return f.free_count() + f.pending_count() +} + +// free_count returns count of free pages +func (f *freelist) free_count() int { + return len(f.ids) +} + +// pending_count returns count of pending pages +func (f *freelist) pending_count() int { + var count int + for _, list := range f.pending { + count += len(list) + } + return count +} + +// all returns a list of all free ids and all pending ids in one sorted list. +func (f *freelist) all() []pgid { + m := make(pgids, 0) + + for _, list := range f.pending { + m = append(m, list...) + } + + sort.Sort(m) + return pgids(f.ids).merge(m) +} + +// allocate returns the starting page id of a contiguous list of pages of a given size. +// If a contiguous block cannot be found then 0 is returned. +func (f *freelist) allocate(n int) pgid { + if len(f.ids) == 0 { + return 0 + } + + var initial, previd pgid + for i, id := range f.ids { + if id <= 1 { + panic(fmt.Sprintf("invalid page allocation: %d", id)) + } + + // Reset initial page if this is not contiguous. + if previd == 0 || id-previd != 1 { + initial = id + } + + // If we found a contiguous block then remove it and return it. + if (id-initial)+1 == pgid(n) { + // If we're allocating off the beginning then take the fast path + // and just adjust the existing slice. This will use extra memory + // temporarily but the append() in free() will realloc the slice + // as is necessary. + if (i + 1) == n { + f.ids = f.ids[i+1:] + } else { + copy(f.ids[i-n+1:], f.ids[i+1:]) + f.ids = f.ids[:len(f.ids)-n] + } + + // Remove from the free cache. + for i := pgid(0); i < pgid(n); i++ { + delete(f.cache, initial+i) + } + + return initial + } + + previd = id + } + return 0 +} + +// free releases a page and its overflow for a given transaction id. +// If the page is already free then a panic will occur. +func (f *freelist) free(txid txid, p *page) { + if p.id <= 1 { + panic(fmt.Sprintf("cannot free page 0 or 1: %d", p.id)) + } + + // Free page and all its overflow pages. + var ids = f.pending[txid] + for id := p.id; id <= p.id+pgid(p.overflow); id++ { + // Verify that page is not already free. + if f.cache[id] { + panic(fmt.Sprintf("page %d already freed", id)) + } + + // Add to the freelist and cache. + ids = append(ids, id) + f.cache[id] = true + } + f.pending[txid] = ids +} + +// release moves all page ids for a transaction id (or older) to the freelist. +func (f *freelist) release(txid txid) { + m := make(pgids, 0) + for tid, ids := range f.pending { + if tid <= txid { + // Move transaction's pending pages to the available freelist. + // Don't remove from the cache since the page is still free. + m = append(m, ids...) + delete(f.pending, tid) + } + } + sort.Sort(m) + f.ids = pgids(f.ids).merge(m) +} + +// rollback removes the pages from a given pending tx. +func (f *freelist) rollback(txid txid) { + // Remove page ids from cache. + for _, id := range f.pending[txid] { + delete(f.cache, id) + } + + // Remove pages from pending list. + delete(f.pending, txid) +} + +// freed returns whether a given page is in the free list. +func (f *freelist) freed(pgid pgid) bool { + return f.cache[pgid] +} + +// read initializes the freelist from a freelist page. +func (f *freelist) read(p *page) { + // If the page.count is at the max uint16 value (64k) then it's considered + // an overflow and the size of the freelist is stored as the first element. + idx, count := 0, int(p.count) + if count == 0xFFFF { + idx = 1 + count = int(((*[maxAllocSize]pgid)(unsafe.Pointer(&p.ptr)))[0]) + } + + // Copy the list of page ids from the freelist. + if count == 0 { + f.ids = nil + } else { + ids := ((*[maxAllocSize]pgid)(unsafe.Pointer(&p.ptr)))[idx:count] + f.ids = make([]pgid, len(ids)) + copy(f.ids, ids) + + // Make sure they're sorted. + sort.Sort(pgids(f.ids)) + } + + // Rebuild the page cache. + f.reindex() +} + +// write writes the page ids onto a freelist page. All free and pending ids are +// saved to disk since in the event of a program crash, all pending ids will +// become free. +func (f *freelist) write(p *page) error { + // Combine the old free pgids and pgids waiting on an open transaction. + ids := f.all() + + // Update the header flag. + p.flags |= freelistPageFlag + + // The page.count can only hold up to 64k elements so if we overflow that + // number then we handle it by putting the size in the first element. + if len(ids) == 0 { + p.count = uint16(len(ids)) + } else if len(ids) < 0xFFFF { + p.count = uint16(len(ids)) + copy(((*[maxAllocSize]pgid)(unsafe.Pointer(&p.ptr)))[:], ids) + } else { + p.count = 0xFFFF + ((*[maxAllocSize]pgid)(unsafe.Pointer(&p.ptr)))[0] = pgid(len(ids)) + copy(((*[maxAllocSize]pgid)(unsafe.Pointer(&p.ptr)))[1:], ids) + } + + return nil +} + +// reload reads the freelist from a page and filters out pending items. +func (f *freelist) reload(p *page) { + f.read(p) + + // Build a cache of only pending pages. + pcache := make(map[pgid]bool) + for _, pendingIDs := range f.pending { + for _, pendingID := range pendingIDs { + pcache[pendingID] = true + } + } + + // Check each page in the freelist and build a new available freelist + // with any pages not in the pending lists. + var a []pgid + for _, id := range f.ids { + if !pcache[id] { + a = append(a, id) + } + } + f.ids = a + + // Once the available list is rebuilt then rebuild the free cache so that + // it includes the available and pending free pages. + f.reindex() +} + +// reindex rebuilds the free cache based on available and pending free lists. +func (f *freelist) reindex() { + f.cache = make(map[pgid]bool) + for _, id := range f.ids { + f.cache[id] = true + } + for _, pendingIDs := range f.pending { + for _, pendingID := range pendingIDs { + f.cache[pendingID] = true + } + } +} diff --git a/vendor/github.com/boltdb/bolt/freelist_test.go b/vendor/github.com/boltdb/bolt/freelist_test.go new file mode 100644 index 00000000..4e9b3a8d --- /dev/null +++ b/vendor/github.com/boltdb/bolt/freelist_test.go @@ -0,0 +1,158 @@ +package bolt + +import ( + "math/rand" + "reflect" + "sort" + "testing" + "unsafe" +) + +// Ensure that a page is added to a transaction's freelist. +func TestFreelist_free(t *testing.T) { + f := newFreelist() + f.free(100, &page{id: 12}) + if !reflect.DeepEqual([]pgid{12}, f.pending[100]) { + t.Fatalf("exp=%v; got=%v", []pgid{12}, f.pending[100]) + } +} + +// Ensure that a page and its overflow is added to a transaction's freelist. +func TestFreelist_free_overflow(t *testing.T) { + f := newFreelist() + f.free(100, &page{id: 12, overflow: 3}) + if exp := []pgid{12, 13, 14, 15}; !reflect.DeepEqual(exp, f.pending[100]) { + t.Fatalf("exp=%v; got=%v", exp, f.pending[100]) + } +} + +// Ensure that a transaction's free pages can be released. +func TestFreelist_release(t *testing.T) { + f := newFreelist() + f.free(100, &page{id: 12, overflow: 1}) + f.free(100, &page{id: 9}) + f.free(102, &page{id: 39}) + f.release(100) + f.release(101) + if exp := []pgid{9, 12, 13}; !reflect.DeepEqual(exp, f.ids) { + t.Fatalf("exp=%v; got=%v", exp, f.ids) + } + + f.release(102) + if exp := []pgid{9, 12, 13, 39}; !reflect.DeepEqual(exp, f.ids) { + t.Fatalf("exp=%v; got=%v", exp, f.ids) + } +} + +// Ensure that a freelist can find contiguous blocks of pages. +func TestFreelist_allocate(t *testing.T) { + f := &freelist{ids: []pgid{3, 4, 5, 6, 7, 9, 12, 13, 18}} + if id := int(f.allocate(3)); id != 3 { + t.Fatalf("exp=3; got=%v", id) + } + if id := int(f.allocate(1)); id != 6 { + t.Fatalf("exp=6; got=%v", id) + } + if id := int(f.allocate(3)); id != 0 { + t.Fatalf("exp=0; got=%v", id) + } + if id := int(f.allocate(2)); id != 12 { + t.Fatalf("exp=12; got=%v", id) + } + if id := int(f.allocate(1)); id != 7 { + t.Fatalf("exp=7; got=%v", id) + } + if id := int(f.allocate(0)); id != 0 { + t.Fatalf("exp=0; got=%v", id) + } + if id := int(f.allocate(0)); id != 0 { + t.Fatalf("exp=0; got=%v", id) + } + if exp := []pgid{9, 18}; !reflect.DeepEqual(exp, f.ids) { + t.Fatalf("exp=%v; got=%v", exp, f.ids) + } + + if id := int(f.allocate(1)); id != 9 { + t.Fatalf("exp=9; got=%v", id) + } + if id := int(f.allocate(1)); id != 18 { + t.Fatalf("exp=18; got=%v", id) + } + if id := int(f.allocate(1)); id != 0 { + t.Fatalf("exp=0; got=%v", id) + } + if exp := []pgid{}; !reflect.DeepEqual(exp, f.ids) { + t.Fatalf("exp=%v; got=%v", exp, f.ids) + } +} + +// Ensure that a freelist can deserialize from a freelist page. +func TestFreelist_read(t *testing.T) { + // Create a page. + var buf [4096]byte + page := (*page)(unsafe.Pointer(&buf[0])) + page.flags = freelistPageFlag + page.count = 2 + + // Insert 2 page ids. + ids := (*[3]pgid)(unsafe.Pointer(&page.ptr)) + ids[0] = 23 + ids[1] = 50 + + // Deserialize page into a freelist. + f := newFreelist() + f.read(page) + + // Ensure that there are two page ids in the freelist. + if exp := []pgid{23, 50}; !reflect.DeepEqual(exp, f.ids) { + t.Fatalf("exp=%v; got=%v", exp, f.ids) + } +} + +// Ensure that a freelist can serialize into a freelist page. +func TestFreelist_write(t *testing.T) { + // Create a freelist and write it to a page. + var buf [4096]byte + f := &freelist{ids: []pgid{12, 39}, pending: make(map[txid][]pgid)} + f.pending[100] = []pgid{28, 11} + f.pending[101] = []pgid{3} + p := (*page)(unsafe.Pointer(&buf[0])) + if err := f.write(p); err != nil { + t.Fatal(err) + } + + // Read the page back out. + f2 := newFreelist() + f2.read(p) + + // Ensure that the freelist is correct. + // All pages should be present and in reverse order. + if exp := []pgid{3, 11, 12, 28, 39}; !reflect.DeepEqual(exp, f2.ids) { + t.Fatalf("exp=%v; got=%v", exp, f2.ids) + } +} + +func Benchmark_FreelistRelease10K(b *testing.B) { benchmark_FreelistRelease(b, 10000) } +func Benchmark_FreelistRelease100K(b *testing.B) { benchmark_FreelistRelease(b, 100000) } +func Benchmark_FreelistRelease1000K(b *testing.B) { benchmark_FreelistRelease(b, 1000000) } +func Benchmark_FreelistRelease10000K(b *testing.B) { benchmark_FreelistRelease(b, 10000000) } + +func benchmark_FreelistRelease(b *testing.B, size int) { + ids := randomPgids(size) + pending := randomPgids(len(ids) / 400) + b.ResetTimer() + for i := 0; i < b.N; i++ { + f := &freelist{ids: ids, pending: map[txid][]pgid{1: pending}} + f.release(1) + } +} + +func randomPgids(n int) []pgid { + rand.Seed(42) + pgids := make(pgids, n) + for i := range pgids { + pgids[i] = pgid(rand.Int63()) + } + sort.Sort(pgids) + return pgids +} diff --git a/vendor/github.com/boltdb/bolt/node.go b/vendor/github.com/boltdb/bolt/node.go new file mode 100644 index 00000000..159318b2 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/node.go @@ -0,0 +1,604 @@ +package bolt + +import ( + "bytes" + "fmt" + "sort" + "unsafe" +) + +// node represents an in-memory, deserialized page. +type node struct { + bucket *Bucket + isLeaf bool + unbalanced bool + spilled bool + key []byte + pgid pgid + parent *node + children nodes + inodes inodes +} + +// root returns the top-level node this node is attached to. +func (n *node) root() *node { + if n.parent == nil { + return n + } + return n.parent.root() +} + +// minKeys returns the minimum number of inodes this node should have. +func (n *node) minKeys() int { + if n.isLeaf { + return 1 + } + return 2 +} + +// size returns the size of the node after serialization. +func (n *node) size() int { + sz, elsz := pageHeaderSize, n.pageElementSize() + for i := 0; i < len(n.inodes); i++ { + item := &n.inodes[i] + sz += elsz + len(item.key) + len(item.value) + } + return sz +} + +// sizeLessThan returns true if the node is less than a given size. +// This is an optimization to avoid calculating a large node when we only need +// to know if it fits inside a certain page size. +func (n *node) sizeLessThan(v int) bool { + sz, elsz := pageHeaderSize, n.pageElementSize() + for i := 0; i < len(n.inodes); i++ { + item := &n.inodes[i] + sz += elsz + len(item.key) + len(item.value) + if sz >= v { + return false + } + } + return true +} + +// pageElementSize returns the size of each page element based on the type of node. +func (n *node) pageElementSize() int { + if n.isLeaf { + return leafPageElementSize + } + return branchPageElementSize +} + +// childAt returns the child node at a given index. +func (n *node) childAt(index int) *node { + if n.isLeaf { + panic(fmt.Sprintf("invalid childAt(%d) on a leaf node", index)) + } + return n.bucket.node(n.inodes[index].pgid, n) +} + +// childIndex returns the index of a given child node. +func (n *node) childIndex(child *node) int { + index := sort.Search(len(n.inodes), func(i int) bool { return bytes.Compare(n.inodes[i].key, child.key) != -1 }) + return index +} + +// numChildren returns the number of children. +func (n *node) numChildren() int { + return len(n.inodes) +} + +// nextSibling returns the next node with the same parent. +func (n *node) nextSibling() *node { + if n.parent == nil { + return nil + } + index := n.parent.childIndex(n) + if index >= n.parent.numChildren()-1 { + return nil + } + return n.parent.childAt(index + 1) +} + +// prevSibling returns the previous node with the same parent. +func (n *node) prevSibling() *node { + if n.parent == nil { + return nil + } + index := n.parent.childIndex(n) + if index == 0 { + return nil + } + return n.parent.childAt(index - 1) +} + +// put inserts a key/value. +func (n *node) put(oldKey, newKey, value []byte, pgid pgid, flags uint32) { + if pgid >= n.bucket.tx.meta.pgid { + panic(fmt.Sprintf("pgid (%d) above high water mark (%d)", pgid, n.bucket.tx.meta.pgid)) + } else if len(oldKey) <= 0 { + panic("put: zero-length old key") + } else if len(newKey) <= 0 { + panic("put: zero-length new key") + } + + // Find insertion index. + index := sort.Search(len(n.inodes), func(i int) bool { return bytes.Compare(n.inodes[i].key, oldKey) != -1 }) + + // Add capacity and shift nodes if we don't have an exact match and need to insert. + exact := (len(n.inodes) > 0 && index < len(n.inodes) && bytes.Equal(n.inodes[index].key, oldKey)) + if !exact { + n.inodes = append(n.inodes, inode{}) + copy(n.inodes[index+1:], n.inodes[index:]) + } + + inode := &n.inodes[index] + inode.flags = flags + inode.key = newKey + inode.value = value + inode.pgid = pgid + _assert(len(inode.key) > 0, "put: zero-length inode key") +} + +// del removes a key from the node. +func (n *node) del(key []byte) { + // Find index of key. + index := sort.Search(len(n.inodes), func(i int) bool { return bytes.Compare(n.inodes[i].key, key) != -1 }) + + // Exit if the key isn't found. + if index >= len(n.inodes) || !bytes.Equal(n.inodes[index].key, key) { + return + } + + // Delete inode from the node. + n.inodes = append(n.inodes[:index], n.inodes[index+1:]...) + + // Mark the node as needing rebalancing. + n.unbalanced = true +} + +// read initializes the node from a page. +func (n *node) read(p *page) { + n.pgid = p.id + n.isLeaf = ((p.flags & leafPageFlag) != 0) + n.inodes = make(inodes, int(p.count)) + + for i := 0; i < int(p.count); i++ { + inode := &n.inodes[i] + if n.isLeaf { + elem := p.leafPageElement(uint16(i)) + inode.flags = elem.flags + inode.key = elem.key() + inode.value = elem.value() + } else { + elem := p.branchPageElement(uint16(i)) + inode.pgid = elem.pgid + inode.key = elem.key() + } + _assert(len(inode.key) > 0, "read: zero-length inode key") + } + + // Save first key so we can find the node in the parent when we spill. + if len(n.inodes) > 0 { + n.key = n.inodes[0].key + _assert(len(n.key) > 0, "read: zero-length node key") + } else { + n.key = nil + } +} + +// write writes the items onto one or more pages. +func (n *node) write(p *page) { + // Initialize page. + if n.isLeaf { + p.flags |= leafPageFlag + } else { + p.flags |= branchPageFlag + } + + if len(n.inodes) >= 0xFFFF { + panic(fmt.Sprintf("inode overflow: %d (pgid=%d)", len(n.inodes), p.id)) + } + p.count = uint16(len(n.inodes)) + + // Stop here if there are no items to write. + if p.count == 0 { + return + } + + // Loop over each item and write it to the page. + b := (*[maxAllocSize]byte)(unsafe.Pointer(&p.ptr))[n.pageElementSize()*len(n.inodes):] + for i, item := range n.inodes { + _assert(len(item.key) > 0, "write: zero-length inode key") + + // Write the page element. + if n.isLeaf { + elem := p.leafPageElement(uint16(i)) + elem.pos = uint32(uintptr(unsafe.Pointer(&b[0])) - uintptr(unsafe.Pointer(elem))) + elem.flags = item.flags + elem.ksize = uint32(len(item.key)) + elem.vsize = uint32(len(item.value)) + } else { + elem := p.branchPageElement(uint16(i)) + elem.pos = uint32(uintptr(unsafe.Pointer(&b[0])) - uintptr(unsafe.Pointer(elem))) + elem.ksize = uint32(len(item.key)) + elem.pgid = item.pgid + _assert(elem.pgid != p.id, "write: circular dependency occurred") + } + + // If the length of key+value is larger than the max allocation size + // then we need to reallocate the byte array pointer. + // + // See: https://github.com/boltdb/bolt/pull/335 + klen, vlen := len(item.key), len(item.value) + if len(b) < klen+vlen { + b = (*[maxAllocSize]byte)(unsafe.Pointer(&b[0]))[:] + } + + // Write data for the element to the end of the page. + copy(b[0:], item.key) + b = b[klen:] + copy(b[0:], item.value) + b = b[vlen:] + } + + // DEBUG ONLY: n.dump() +} + +// split breaks up a node into multiple smaller nodes, if appropriate. +// This should only be called from the spill() function. +func (n *node) split(pageSize int) []*node { + var nodes []*node + + node := n + for { + // Split node into two. + a, b := node.splitTwo(pageSize) + nodes = append(nodes, a) + + // If we can't split then exit the loop. + if b == nil { + break + } + + // Set node to b so it gets split on the next iteration. + node = b + } + + return nodes +} + +// splitTwo breaks up a node into two smaller nodes, if appropriate. +// This should only be called from the split() function. +func (n *node) splitTwo(pageSize int) (*node, *node) { + // Ignore the split if the page doesn't have at least enough nodes for + // two pages or if the nodes can fit in a single page. + if len(n.inodes) <= (minKeysPerPage*2) || n.sizeLessThan(pageSize) { + return n, nil + } + + // Determine the threshold before starting a new node. + var fillPercent = n.bucket.FillPercent + if fillPercent < minFillPercent { + fillPercent = minFillPercent + } else if fillPercent > maxFillPercent { + fillPercent = maxFillPercent + } + threshold := int(float64(pageSize) * fillPercent) + + // Determine split position and sizes of the two pages. + splitIndex, _ := n.splitIndex(threshold) + + // Split node into two separate nodes. + // If there's no parent then we'll need to create one. + if n.parent == nil { + n.parent = &node{bucket: n.bucket, children: []*node{n}} + } + + // Create a new node and add it to the parent. + next := &node{bucket: n.bucket, isLeaf: n.isLeaf, parent: n.parent} + n.parent.children = append(n.parent.children, next) + + // Split inodes across two nodes. + next.inodes = n.inodes[splitIndex:] + n.inodes = n.inodes[:splitIndex] + + // Update the statistics. + n.bucket.tx.stats.Split++ + + return n, next +} + +// splitIndex finds the position where a page will fill a given threshold. +// It returns the index as well as the size of the first page. +// This is only be called from split(). +func (n *node) splitIndex(threshold int) (index, sz int) { + sz = pageHeaderSize + + // Loop until we only have the minimum number of keys required for the second page. + for i := 0; i < len(n.inodes)-minKeysPerPage; i++ { + index = i + inode := n.inodes[i] + elsize := n.pageElementSize() + len(inode.key) + len(inode.value) + + // If we have at least the minimum number of keys and adding another + // node would put us over the threshold then exit and return. + if i >= minKeysPerPage && sz+elsize > threshold { + break + } + + // Add the element size to the total size. + sz += elsize + } + + return +} + +// spill writes the nodes to dirty pages and splits nodes as it goes. +// Returns an error if dirty pages cannot be allocated. +func (n *node) spill() error { + var tx = n.bucket.tx + if n.spilled { + return nil + } + + // Spill child nodes first. Child nodes can materialize sibling nodes in + // the case of split-merge so we cannot use a range loop. We have to check + // the children size on every loop iteration. + sort.Sort(n.children) + for i := 0; i < len(n.children); i++ { + if err := n.children[i].spill(); err != nil { + return err + } + } + + // We no longer need the child list because it's only used for spill tracking. + n.children = nil + + // Split nodes into appropriate sizes. The first node will always be n. + var nodes = n.split(tx.db.pageSize) + for _, node := range nodes { + // Add node's page to the freelist if it's not new. + if node.pgid > 0 { + tx.db.freelist.free(tx.meta.txid, tx.page(node.pgid)) + node.pgid = 0 + } + + // Allocate contiguous space for the node. + p, err := tx.allocate((node.size() / tx.db.pageSize) + 1) + if err != nil { + return err + } + + // Write the node. + if p.id >= tx.meta.pgid { + panic(fmt.Sprintf("pgid (%d) above high water mark (%d)", p.id, tx.meta.pgid)) + } + node.pgid = p.id + node.write(p) + node.spilled = true + + // Insert into parent inodes. + if node.parent != nil { + var key = node.key + if key == nil { + key = node.inodes[0].key + } + + node.parent.put(key, node.inodes[0].key, nil, node.pgid, 0) + node.key = node.inodes[0].key + _assert(len(node.key) > 0, "spill: zero-length node key") + } + + // Update the statistics. + tx.stats.Spill++ + } + + // If the root node split and created a new root then we need to spill that + // as well. We'll clear out the children to make sure it doesn't try to respill. + if n.parent != nil && n.parent.pgid == 0 { + n.children = nil + return n.parent.spill() + } + + return nil +} + +// rebalance attempts to combine the node with sibling nodes if the node fill +// size is below a threshold or if there are not enough keys. +func (n *node) rebalance() { + if !n.unbalanced { + return + } + n.unbalanced = false + + // Update statistics. + n.bucket.tx.stats.Rebalance++ + + // Ignore if node is above threshold (25%) and has enough keys. + var threshold = n.bucket.tx.db.pageSize / 4 + if n.size() > threshold && len(n.inodes) > n.minKeys() { + return + } + + // Root node has special handling. + if n.parent == nil { + // If root node is a branch and only has one node then collapse it. + if !n.isLeaf && len(n.inodes) == 1 { + // Move root's child up. + child := n.bucket.node(n.inodes[0].pgid, n) + n.isLeaf = child.isLeaf + n.inodes = child.inodes[:] + n.children = child.children + + // Reparent all child nodes being moved. + for _, inode := range n.inodes { + if child, ok := n.bucket.nodes[inode.pgid]; ok { + child.parent = n + } + } + + // Remove old child. + child.parent = nil + delete(n.bucket.nodes, child.pgid) + child.free() + } + + return + } + + // If node has no keys then just remove it. + if n.numChildren() == 0 { + n.parent.del(n.key) + n.parent.removeChild(n) + delete(n.bucket.nodes, n.pgid) + n.free() + n.parent.rebalance() + return + } + + _assert(n.parent.numChildren() > 1, "parent must have at least 2 children") + + // Destination node is right sibling if idx == 0, otherwise left sibling. + var target *node + var useNextSibling = (n.parent.childIndex(n) == 0) + if useNextSibling { + target = n.nextSibling() + } else { + target = n.prevSibling() + } + + // If both this node and the target node are too small then merge them. + if useNextSibling { + // Reparent all child nodes being moved. + for _, inode := range target.inodes { + if child, ok := n.bucket.nodes[inode.pgid]; ok { + child.parent.removeChild(child) + child.parent = n + child.parent.children = append(child.parent.children, child) + } + } + + // Copy over inodes from target and remove target. + n.inodes = append(n.inodes, target.inodes...) + n.parent.del(target.key) + n.parent.removeChild(target) + delete(n.bucket.nodes, target.pgid) + target.free() + } else { + // Reparent all child nodes being moved. + for _, inode := range n.inodes { + if child, ok := n.bucket.nodes[inode.pgid]; ok { + child.parent.removeChild(child) + child.parent = target + child.parent.children = append(child.parent.children, child) + } + } + + // Copy over inodes to target and remove node. + target.inodes = append(target.inodes, n.inodes...) + n.parent.del(n.key) + n.parent.removeChild(n) + delete(n.bucket.nodes, n.pgid) + n.free() + } + + // Either this node or the target node was deleted from the parent so rebalance it. + n.parent.rebalance() +} + +// removes a node from the list of in-memory children. +// This does not affect the inodes. +func (n *node) removeChild(target *node) { + for i, child := range n.children { + if child == target { + n.children = append(n.children[:i], n.children[i+1:]...) + return + } + } +} + +// dereference causes the node to copy all its inode key/value references to heap memory. +// This is required when the mmap is reallocated so inodes are not pointing to stale data. +func (n *node) dereference() { + if n.key != nil { + key := make([]byte, len(n.key)) + copy(key, n.key) + n.key = key + _assert(n.pgid == 0 || len(n.key) > 0, "dereference: zero-length node key on existing node") + } + + for i := range n.inodes { + inode := &n.inodes[i] + + key := make([]byte, len(inode.key)) + copy(key, inode.key) + inode.key = key + _assert(len(inode.key) > 0, "dereference: zero-length inode key") + + value := make([]byte, len(inode.value)) + copy(value, inode.value) + inode.value = value + } + + // Recursively dereference children. + for _, child := range n.children { + child.dereference() + } + + // Update statistics. + n.bucket.tx.stats.NodeDeref++ +} + +// free adds the node's underlying page to the freelist. +func (n *node) free() { + if n.pgid != 0 { + n.bucket.tx.db.freelist.free(n.bucket.tx.meta.txid, n.bucket.tx.page(n.pgid)) + n.pgid = 0 + } +} + +// dump writes the contents of the node to STDERR for debugging purposes. +/* +func (n *node) dump() { + // Write node header. + var typ = "branch" + if n.isLeaf { + typ = "leaf" + } + warnf("[NODE %d {type=%s count=%d}]", n.pgid, typ, len(n.inodes)) + + // Write out abbreviated version of each item. + for _, item := range n.inodes { + if n.isLeaf { + if item.flags&bucketLeafFlag != 0 { + bucket := (*bucket)(unsafe.Pointer(&item.value[0])) + warnf("+L %08x -> (bucket root=%d)", trunc(item.key, 4), bucket.root) + } else { + warnf("+L %08x -> %08x", trunc(item.key, 4), trunc(item.value, 4)) + } + } else { + warnf("+B %08x -> pgid=%d", trunc(item.key, 4), item.pgid) + } + } + warn("") +} +*/ + +type nodes []*node + +func (s nodes) Len() int { return len(s) } +func (s nodes) Swap(i, j int) { s[i], s[j] = s[j], s[i] } +func (s nodes) Less(i, j int) bool { return bytes.Compare(s[i].inodes[0].key, s[j].inodes[0].key) == -1 } + +// inode represents an internal node inside of a node. +// It can be used to point to elements in a page or point +// to an element which hasn't been added to a page yet. +type inode struct { + flags uint32 + pgid pgid + key []byte + value []byte +} + +type inodes []inode diff --git a/vendor/github.com/boltdb/bolt/node_test.go b/vendor/github.com/boltdb/bolt/node_test.go new file mode 100644 index 00000000..fa5d10f9 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/node_test.go @@ -0,0 +1,156 @@ +package bolt + +import ( + "testing" + "unsafe" +) + +// Ensure that a node can insert a key/value. +func TestNode_put(t *testing.T) { + n := &node{inodes: make(inodes, 0), bucket: &Bucket{tx: &Tx{meta: &meta{pgid: 1}}}} + n.put([]byte("baz"), []byte("baz"), []byte("2"), 0, 0) + n.put([]byte("foo"), []byte("foo"), []byte("0"), 0, 0) + n.put([]byte("bar"), []byte("bar"), []byte("1"), 0, 0) + n.put([]byte("foo"), []byte("foo"), []byte("3"), 0, leafPageFlag) + + if len(n.inodes) != 3 { + t.Fatalf("exp=3; got=%d", len(n.inodes)) + } + if k, v := n.inodes[0].key, n.inodes[0].value; string(k) != "bar" || string(v) != "1" { + t.Fatalf("exp=; got=<%s,%s>", k, v) + } + if k, v := n.inodes[1].key, n.inodes[1].value; string(k) != "baz" || string(v) != "2" { + t.Fatalf("exp=; got=<%s,%s>", k, v) + } + if k, v := n.inodes[2].key, n.inodes[2].value; string(k) != "foo" || string(v) != "3" { + t.Fatalf("exp=; got=<%s,%s>", k, v) + } + if n.inodes[2].flags != uint32(leafPageFlag) { + t.Fatalf("not a leaf: %d", n.inodes[2].flags) + } +} + +// Ensure that a node can deserialize from a leaf page. +func TestNode_read_LeafPage(t *testing.T) { + // Create a page. + var buf [4096]byte + page := (*page)(unsafe.Pointer(&buf[0])) + page.flags = leafPageFlag + page.count = 2 + + // Insert 2 elements at the beginning. sizeof(leafPageElement) == 16 + nodes := (*[3]leafPageElement)(unsafe.Pointer(&page.ptr)) + nodes[0] = leafPageElement{flags: 0, pos: 32, ksize: 3, vsize: 4} // pos = sizeof(leafPageElement) * 2 + nodes[1] = leafPageElement{flags: 0, pos: 23, ksize: 10, vsize: 3} // pos = sizeof(leafPageElement) + 3 + 4 + + // Write data for the nodes at the end. + data := (*[4096]byte)(unsafe.Pointer(&nodes[2])) + copy(data[:], []byte("barfooz")) + copy(data[7:], []byte("helloworldbye")) + + // Deserialize page into a leaf. + n := &node{} + n.read(page) + + // Check that there are two inodes with correct data. + if !n.isLeaf { + t.Fatal("expected leaf") + } + if len(n.inodes) != 2 { + t.Fatalf("exp=2; got=%d", len(n.inodes)) + } + if k, v := n.inodes[0].key, n.inodes[0].value; string(k) != "bar" || string(v) != "fooz" { + t.Fatalf("exp=; got=<%s,%s>", k, v) + } + if k, v := n.inodes[1].key, n.inodes[1].value; string(k) != "helloworld" || string(v) != "bye" { + t.Fatalf("exp=; got=<%s,%s>", k, v) + } +} + +// Ensure that a node can serialize into a leaf page. +func TestNode_write_LeafPage(t *testing.T) { + // Create a node. + n := &node{isLeaf: true, inodes: make(inodes, 0), bucket: &Bucket{tx: &Tx{db: &DB{}, meta: &meta{pgid: 1}}}} + n.put([]byte("susy"), []byte("susy"), []byte("que"), 0, 0) + n.put([]byte("ricki"), []byte("ricki"), []byte("lake"), 0, 0) + n.put([]byte("john"), []byte("john"), []byte("johnson"), 0, 0) + + // Write it to a page. + var buf [4096]byte + p := (*page)(unsafe.Pointer(&buf[0])) + n.write(p) + + // Read the page back in. + n2 := &node{} + n2.read(p) + + // Check that the two pages are the same. + if len(n2.inodes) != 3 { + t.Fatalf("exp=3; got=%d", len(n2.inodes)) + } + if k, v := n2.inodes[0].key, n2.inodes[0].value; string(k) != "john" || string(v) != "johnson" { + t.Fatalf("exp=; got=<%s,%s>", k, v) + } + if k, v := n2.inodes[1].key, n2.inodes[1].value; string(k) != "ricki" || string(v) != "lake" { + t.Fatalf("exp=; got=<%s,%s>", k, v) + } + if k, v := n2.inodes[2].key, n2.inodes[2].value; string(k) != "susy" || string(v) != "que" { + t.Fatalf("exp=; got=<%s,%s>", k, v) + } +} + +// Ensure that a node can split into appropriate subgroups. +func TestNode_split(t *testing.T) { + // Create a node. + n := &node{inodes: make(inodes, 0), bucket: &Bucket{tx: &Tx{db: &DB{}, meta: &meta{pgid: 1}}}} + n.put([]byte("00000001"), []byte("00000001"), []byte("0123456701234567"), 0, 0) + n.put([]byte("00000002"), []byte("00000002"), []byte("0123456701234567"), 0, 0) + n.put([]byte("00000003"), []byte("00000003"), []byte("0123456701234567"), 0, 0) + n.put([]byte("00000004"), []byte("00000004"), []byte("0123456701234567"), 0, 0) + n.put([]byte("00000005"), []byte("00000005"), []byte("0123456701234567"), 0, 0) + + // Split between 2 & 3. + n.split(100) + + var parent = n.parent + if len(parent.children) != 2 { + t.Fatalf("exp=2; got=%d", len(parent.children)) + } + if len(parent.children[0].inodes) != 2 { + t.Fatalf("exp=2; got=%d", len(parent.children[0].inodes)) + } + if len(parent.children[1].inodes) != 3 { + t.Fatalf("exp=3; got=%d", len(parent.children[1].inodes)) + } +} + +// Ensure that a page with the minimum number of inodes just returns a single node. +func TestNode_split_MinKeys(t *testing.T) { + // Create a node. + n := &node{inodes: make(inodes, 0), bucket: &Bucket{tx: &Tx{db: &DB{}, meta: &meta{pgid: 1}}}} + n.put([]byte("00000001"), []byte("00000001"), []byte("0123456701234567"), 0, 0) + n.put([]byte("00000002"), []byte("00000002"), []byte("0123456701234567"), 0, 0) + + // Split. + n.split(20) + if n.parent != nil { + t.Fatalf("expected nil parent") + } +} + +// Ensure that a node that has keys that all fit on a page just returns one leaf. +func TestNode_split_SinglePage(t *testing.T) { + // Create a node. + n := &node{inodes: make(inodes, 0), bucket: &Bucket{tx: &Tx{db: &DB{}, meta: &meta{pgid: 1}}}} + n.put([]byte("00000001"), []byte("00000001"), []byte("0123456701234567"), 0, 0) + n.put([]byte("00000002"), []byte("00000002"), []byte("0123456701234567"), 0, 0) + n.put([]byte("00000003"), []byte("00000003"), []byte("0123456701234567"), 0, 0) + n.put([]byte("00000004"), []byte("00000004"), []byte("0123456701234567"), 0, 0) + n.put([]byte("00000005"), []byte("00000005"), []byte("0123456701234567"), 0, 0) + + // Split. + n.split(4096) + if n.parent != nil { + t.Fatalf("expected nil parent") + } +} diff --git a/vendor/github.com/boltdb/bolt/page.go b/vendor/github.com/boltdb/bolt/page.go new file mode 100644 index 00000000..7651a6bf --- /dev/null +++ b/vendor/github.com/boltdb/bolt/page.go @@ -0,0 +1,178 @@ +package bolt + +import ( + "fmt" + "os" + "sort" + "unsafe" +) + +const pageHeaderSize = int(unsafe.Offsetof(((*page)(nil)).ptr)) + +const minKeysPerPage = 2 + +const branchPageElementSize = int(unsafe.Sizeof(branchPageElement{})) +const leafPageElementSize = int(unsafe.Sizeof(leafPageElement{})) + +const ( + branchPageFlag = 0x01 + leafPageFlag = 0x02 + metaPageFlag = 0x04 + freelistPageFlag = 0x10 +) + +const ( + bucketLeafFlag = 0x01 +) + +type pgid uint64 + +type page struct { + id pgid + flags uint16 + count uint16 + overflow uint32 + ptr uintptr +} + +// typ returns a human readable page type string used for debugging. +func (p *page) typ() string { + if (p.flags & branchPageFlag) != 0 { + return "branch" + } else if (p.flags & leafPageFlag) != 0 { + return "leaf" + } else if (p.flags & metaPageFlag) != 0 { + return "meta" + } else if (p.flags & freelistPageFlag) != 0 { + return "freelist" + } + return fmt.Sprintf("unknown<%02x>", p.flags) +} + +// meta returns a pointer to the metadata section of the page. +func (p *page) meta() *meta { + return (*meta)(unsafe.Pointer(&p.ptr)) +} + +// leafPageElement retrieves the leaf node by index +func (p *page) leafPageElement(index uint16) *leafPageElement { + n := &((*[0x7FFFFFF]leafPageElement)(unsafe.Pointer(&p.ptr)))[index] + return n +} + +// leafPageElements retrieves a list of leaf nodes. +func (p *page) leafPageElements() []leafPageElement { + if p.count == 0 { + return nil + } + return ((*[0x7FFFFFF]leafPageElement)(unsafe.Pointer(&p.ptr)))[:] +} + +// branchPageElement retrieves the branch node by index +func (p *page) branchPageElement(index uint16) *branchPageElement { + return &((*[0x7FFFFFF]branchPageElement)(unsafe.Pointer(&p.ptr)))[index] +} + +// branchPageElements retrieves a list of branch nodes. +func (p *page) branchPageElements() []branchPageElement { + if p.count == 0 { + return nil + } + return ((*[0x7FFFFFF]branchPageElement)(unsafe.Pointer(&p.ptr)))[:] +} + +// dump writes n bytes of the page to STDERR as hex output. +func (p *page) hexdump(n int) { + buf := (*[maxAllocSize]byte)(unsafe.Pointer(p))[:n] + fmt.Fprintf(os.Stderr, "%x\n", buf) +} + +type pages []*page + +func (s pages) Len() int { return len(s) } +func (s pages) Swap(i, j int) { s[i], s[j] = s[j], s[i] } +func (s pages) Less(i, j int) bool { return s[i].id < s[j].id } + +// branchPageElement represents a node on a branch page. +type branchPageElement struct { + pos uint32 + ksize uint32 + pgid pgid +} + +// key returns a byte slice of the node key. +func (n *branchPageElement) key() []byte { + buf := (*[maxAllocSize]byte)(unsafe.Pointer(n)) + return (*[maxAllocSize]byte)(unsafe.Pointer(&buf[n.pos]))[:n.ksize] +} + +// leafPageElement represents a node on a leaf page. +type leafPageElement struct { + flags uint32 + pos uint32 + ksize uint32 + vsize uint32 +} + +// key returns a byte slice of the node key. +func (n *leafPageElement) key() []byte { + buf := (*[maxAllocSize]byte)(unsafe.Pointer(n)) + return (*[maxAllocSize]byte)(unsafe.Pointer(&buf[n.pos]))[:n.ksize:n.ksize] +} + +// value returns a byte slice of the node value. +func (n *leafPageElement) value() []byte { + buf := (*[maxAllocSize]byte)(unsafe.Pointer(n)) + return (*[maxAllocSize]byte)(unsafe.Pointer(&buf[n.pos+n.ksize]))[:n.vsize:n.vsize] +} + +// PageInfo represents human readable information about a page. +type PageInfo struct { + ID int + Type string + Count int + OverflowCount int +} + +type pgids []pgid + +func (s pgids) Len() int { return len(s) } +func (s pgids) Swap(i, j int) { s[i], s[j] = s[j], s[i] } +func (s pgids) Less(i, j int) bool { return s[i] < s[j] } + +// merge returns the sorted union of a and b. +func (a pgids) merge(b pgids) pgids { + // Return the opposite slice if one is nil. + if len(a) == 0 { + return b + } else if len(b) == 0 { + return a + } + + // Create a list to hold all elements from both lists. + merged := make(pgids, 0, len(a)+len(b)) + + // Assign lead to the slice with a lower starting value, follow to the higher value. + lead, follow := a, b + if b[0] < a[0] { + lead, follow = b, a + } + + // Continue while there are elements in the lead. + for len(lead) > 0 { + // Merge largest prefix of lead that is ahead of follow[0]. + n := sort.Search(len(lead), func(i int) bool { return lead[i] > follow[0] }) + merged = append(merged, lead[:n]...) + if n >= len(lead) { + break + } + + // Swap lead and follow. + lead, follow = follow, lead[n:] + } + + // Append what's left in follow. + merged = append(merged, follow...) + + return merged +} diff --git a/vendor/github.com/boltdb/bolt/page_test.go b/vendor/github.com/boltdb/bolt/page_test.go new file mode 100644 index 00000000..59f4a30e --- /dev/null +++ b/vendor/github.com/boltdb/bolt/page_test.go @@ -0,0 +1,72 @@ +package bolt + +import ( + "reflect" + "sort" + "testing" + "testing/quick" +) + +// Ensure that the page type can be returned in human readable format. +func TestPage_typ(t *testing.T) { + if typ := (&page{flags: branchPageFlag}).typ(); typ != "branch" { + t.Fatalf("exp=branch; got=%v", typ) + } + if typ := (&page{flags: leafPageFlag}).typ(); typ != "leaf" { + t.Fatalf("exp=leaf; got=%v", typ) + } + if typ := (&page{flags: metaPageFlag}).typ(); typ != "meta" { + t.Fatalf("exp=meta; got=%v", typ) + } + if typ := (&page{flags: freelistPageFlag}).typ(); typ != "freelist" { + t.Fatalf("exp=freelist; got=%v", typ) + } + if typ := (&page{flags: 20000}).typ(); typ != "unknown<4e20>" { + t.Fatalf("exp=unknown<4e20>; got=%v", typ) + } +} + +// Ensure that the hexdump debugging function doesn't blow up. +func TestPage_dump(t *testing.T) { + (&page{id: 256}).hexdump(16) +} + +func TestPgids_merge(t *testing.T) { + a := pgids{4, 5, 6, 10, 11, 12, 13, 27} + b := pgids{1, 3, 8, 9, 25, 30} + c := a.merge(b) + if !reflect.DeepEqual(c, pgids{1, 3, 4, 5, 6, 8, 9, 10, 11, 12, 13, 25, 27, 30}) { + t.Errorf("mismatch: %v", c) + } + + a = pgids{4, 5, 6, 10, 11, 12, 13, 27, 35, 36} + b = pgids{8, 9, 25, 30} + c = a.merge(b) + if !reflect.DeepEqual(c, pgids{4, 5, 6, 8, 9, 10, 11, 12, 13, 25, 27, 30, 35, 36}) { + t.Errorf("mismatch: %v", c) + } +} + +func TestPgids_merge_quick(t *testing.T) { + if err := quick.Check(func(a, b pgids) bool { + // Sort incoming lists. + sort.Sort(a) + sort.Sort(b) + + // Merge the two lists together. + got := a.merge(b) + + // The expected value should be the two lists combined and sorted. + exp := append(a, b...) + sort.Sort(exp) + + if !reflect.DeepEqual(exp, got) { + t.Errorf("\nexp=%+v\ngot=%+v\n", exp, got) + return false + } + + return true + }, nil); err != nil { + t.Fatal(err) + } +} diff --git a/vendor/github.com/boltdb/bolt/quick_test.go b/vendor/github.com/boltdb/bolt/quick_test.go new file mode 100644 index 00000000..4da58177 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/quick_test.go @@ -0,0 +1,79 @@ +package bolt_test + +import ( + "bytes" + "flag" + "fmt" + "math/rand" + "os" + "reflect" + "testing/quick" + "time" +) + +// testing/quick defaults to 5 iterations and a random seed. +// You can override these settings from the command line: +// +// -quick.count The number of iterations to perform. +// -quick.seed The seed to use for randomizing. +// -quick.maxitems The maximum number of items to insert into a DB. +// -quick.maxksize The maximum size of a key. +// -quick.maxvsize The maximum size of a value. +// + +var qcount, qseed, qmaxitems, qmaxksize, qmaxvsize int + +func init() { + flag.IntVar(&qcount, "quick.count", 5, "") + flag.IntVar(&qseed, "quick.seed", int(time.Now().UnixNano())%100000, "") + flag.IntVar(&qmaxitems, "quick.maxitems", 1000, "") + flag.IntVar(&qmaxksize, "quick.maxksize", 1024, "") + flag.IntVar(&qmaxvsize, "quick.maxvsize", 1024, "") + flag.Parse() + fmt.Fprintln(os.Stderr, "seed:", qseed) + fmt.Fprintf(os.Stderr, "quick settings: count=%v, items=%v, ksize=%v, vsize=%v\n", qcount, qmaxitems, qmaxksize, qmaxvsize) +} + +func qconfig() *quick.Config { + return &quick.Config{ + MaxCount: qcount, + Rand: rand.New(rand.NewSource(int64(qseed))), + } +} + +type testdata []testdataitem + +func (t testdata) Len() int { return len(t) } +func (t testdata) Swap(i, j int) { t[i], t[j] = t[j], t[i] } +func (t testdata) Less(i, j int) bool { return bytes.Compare(t[i].Key, t[j].Key) == -1 } + +func (t testdata) Generate(rand *rand.Rand, size int) reflect.Value { + n := rand.Intn(qmaxitems-1) + 1 + items := make(testdata, n) + for i := 0; i < n; i++ { + item := &items[i] + item.Key = randByteSlice(rand, 1, qmaxksize) + item.Value = randByteSlice(rand, 0, qmaxvsize) + } + return reflect.ValueOf(items) +} + +type revtestdata []testdataitem + +func (t revtestdata) Len() int { return len(t) } +func (t revtestdata) Swap(i, j int) { t[i], t[j] = t[j], t[i] } +func (t revtestdata) Less(i, j int) bool { return bytes.Compare(t[i].Key, t[j].Key) == 1 } + +type testdataitem struct { + Key []byte + Value []byte +} + +func randByteSlice(rand *rand.Rand, minSize, maxSize int) []byte { + n := rand.Intn(maxSize-minSize) + minSize + b := make([]byte, n) + for i := 0; i < n; i++ { + b[i] = byte(rand.Intn(255)) + } + return b +} diff --git a/vendor/github.com/boltdb/bolt/simulation_test.go b/vendor/github.com/boltdb/bolt/simulation_test.go new file mode 100644 index 00000000..38310165 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/simulation_test.go @@ -0,0 +1,329 @@ +package bolt_test + +import ( + "bytes" + "fmt" + "math/rand" + "sync" + "testing" + + "github.com/boltdb/bolt" +) + +func TestSimulate_1op_1p(t *testing.T) { testSimulate(t, 1, 1) } +func TestSimulate_10op_1p(t *testing.T) { testSimulate(t, 10, 1) } +func TestSimulate_100op_1p(t *testing.T) { testSimulate(t, 100, 1) } +func TestSimulate_1000op_1p(t *testing.T) { testSimulate(t, 1000, 1) } +func TestSimulate_10000op_1p(t *testing.T) { testSimulate(t, 10000, 1) } + +func TestSimulate_10op_10p(t *testing.T) { testSimulate(t, 10, 10) } +func TestSimulate_100op_10p(t *testing.T) { testSimulate(t, 100, 10) } +func TestSimulate_1000op_10p(t *testing.T) { testSimulate(t, 1000, 10) } +func TestSimulate_10000op_10p(t *testing.T) { testSimulate(t, 10000, 10) } + +func TestSimulate_100op_100p(t *testing.T) { testSimulate(t, 100, 100) } +func TestSimulate_1000op_100p(t *testing.T) { testSimulate(t, 1000, 100) } +func TestSimulate_10000op_100p(t *testing.T) { testSimulate(t, 10000, 100) } + +func TestSimulate_10000op_1000p(t *testing.T) { testSimulate(t, 10000, 1000) } + +// Randomly generate operations on a given database with multiple clients to ensure consistency and thread safety. +func testSimulate(t *testing.T, threadCount, parallelism int) { + if testing.Short() { + t.Skip("skipping test in short mode.") + } + + rand.Seed(int64(qseed)) + + // A list of operations that readers and writers can perform. + var readerHandlers = []simulateHandler{simulateGetHandler} + var writerHandlers = []simulateHandler{simulateGetHandler, simulatePutHandler} + + var versions = make(map[int]*QuickDB) + versions[1] = NewQuickDB() + + db := MustOpenDB() + defer db.MustClose() + + var mutex sync.Mutex + + // Run n threads in parallel, each with their own operation. + var wg sync.WaitGroup + var threads = make(chan bool, parallelism) + var i int + for { + threads <- true + wg.Add(1) + writable := ((rand.Int() % 100) < 20) // 20% writers + + // Choose an operation to execute. + var handler simulateHandler + if writable { + handler = writerHandlers[rand.Intn(len(writerHandlers))] + } else { + handler = readerHandlers[rand.Intn(len(readerHandlers))] + } + + // Execute a thread for the given operation. + go func(writable bool, handler simulateHandler) { + defer wg.Done() + + // Start transaction. + tx, err := db.Begin(writable) + if err != nil { + t.Fatal("tx begin: ", err) + } + + // Obtain current state of the dataset. + mutex.Lock() + var qdb = versions[tx.ID()] + if writable { + qdb = versions[tx.ID()-1].Copy() + } + mutex.Unlock() + + // Make sure we commit/rollback the tx at the end and update the state. + if writable { + defer func() { + mutex.Lock() + versions[tx.ID()] = qdb + mutex.Unlock() + + if err := tx.Commit(); err != nil { + t.Fatal(err) + } + }() + } else { + defer func() { _ = tx.Rollback() }() + } + + // Ignore operation if we don't have data yet. + if qdb == nil { + return + } + + // Execute handler. + handler(tx, qdb) + + // Release a thread back to the scheduling loop. + <-threads + }(writable, handler) + + i++ + if i > threadCount { + break + } + } + + // Wait until all threads are done. + wg.Wait() +} + +type simulateHandler func(tx *bolt.Tx, qdb *QuickDB) + +// Retrieves a key from the database and verifies that it is what is expected. +func simulateGetHandler(tx *bolt.Tx, qdb *QuickDB) { + // Randomly retrieve an existing exist. + keys := qdb.Rand() + if len(keys) == 0 { + return + } + + // Retrieve root bucket. + b := tx.Bucket(keys[0]) + if b == nil { + panic(fmt.Sprintf("bucket[0] expected: %08x\n", trunc(keys[0], 4))) + } + + // Drill into nested buckets. + for _, key := range keys[1 : len(keys)-1] { + b = b.Bucket(key) + if b == nil { + panic(fmt.Sprintf("bucket[n] expected: %v -> %v\n", keys, key)) + } + } + + // Verify key/value on the final bucket. + expected := qdb.Get(keys) + actual := b.Get(keys[len(keys)-1]) + if !bytes.Equal(actual, expected) { + fmt.Println("=== EXPECTED ===") + fmt.Println(expected) + fmt.Println("=== ACTUAL ===") + fmt.Println(actual) + fmt.Println("=== END ===") + panic("value mismatch") + } +} + +// Inserts a key into the database. +func simulatePutHandler(tx *bolt.Tx, qdb *QuickDB) { + var err error + keys, value := randKeys(), randValue() + + // Retrieve root bucket. + b := tx.Bucket(keys[0]) + if b == nil { + b, err = tx.CreateBucket(keys[0]) + if err != nil { + panic("create bucket: " + err.Error()) + } + } + + // Create nested buckets, if necessary. + for _, key := range keys[1 : len(keys)-1] { + child := b.Bucket(key) + if child != nil { + b = child + } else { + b, err = b.CreateBucket(key) + if err != nil { + panic("create bucket: " + err.Error()) + } + } + } + + // Insert into database. + if err := b.Put(keys[len(keys)-1], value); err != nil { + panic("put: " + err.Error()) + } + + // Insert into in-memory database. + qdb.Put(keys, value) +} + +// QuickDB is an in-memory database that replicates the functionality of the +// Bolt DB type except that it is entirely in-memory. It is meant for testing +// that the Bolt database is consistent. +type QuickDB struct { + sync.RWMutex + m map[string]interface{} +} + +// NewQuickDB returns an instance of QuickDB. +func NewQuickDB() *QuickDB { + return &QuickDB{m: make(map[string]interface{})} +} + +// Get retrieves the value at a key path. +func (db *QuickDB) Get(keys [][]byte) []byte { + db.RLock() + defer db.RUnlock() + + m := db.m + for _, key := range keys[:len(keys)-1] { + value := m[string(key)] + if value == nil { + return nil + } + switch value := value.(type) { + case map[string]interface{}: + m = value + case []byte: + return nil + } + } + + // Only return if it's a simple value. + if value, ok := m[string(keys[len(keys)-1])].([]byte); ok { + return value + } + return nil +} + +// Put inserts a value into a key path. +func (db *QuickDB) Put(keys [][]byte, value []byte) { + db.Lock() + defer db.Unlock() + + // Build buckets all the way down the key path. + m := db.m + for _, key := range keys[:len(keys)-1] { + if _, ok := m[string(key)].([]byte); ok { + return // Keypath intersects with a simple value. Do nothing. + } + + if m[string(key)] == nil { + m[string(key)] = make(map[string]interface{}) + } + m = m[string(key)].(map[string]interface{}) + } + + // Insert value into the last key. + m[string(keys[len(keys)-1])] = value +} + +// Rand returns a random key path that points to a simple value. +func (db *QuickDB) Rand() [][]byte { + db.RLock() + defer db.RUnlock() + if len(db.m) == 0 { + return nil + } + var keys [][]byte + db.rand(db.m, &keys) + return keys +} + +func (db *QuickDB) rand(m map[string]interface{}, keys *[][]byte) { + i, index := 0, rand.Intn(len(m)) + for k, v := range m { + if i == index { + *keys = append(*keys, []byte(k)) + if v, ok := v.(map[string]interface{}); ok { + db.rand(v, keys) + } + return + } + i++ + } + panic("quickdb rand: out-of-range") +} + +// Copy copies the entire database. +func (db *QuickDB) Copy() *QuickDB { + db.RLock() + defer db.RUnlock() + return &QuickDB{m: db.copy(db.m)} +} + +func (db *QuickDB) copy(m map[string]interface{}) map[string]interface{} { + clone := make(map[string]interface{}, len(m)) + for k, v := range m { + switch v := v.(type) { + case map[string]interface{}: + clone[k] = db.copy(v) + default: + clone[k] = v + } + } + return clone +} + +func randKey() []byte { + var min, max = 1, 1024 + n := rand.Intn(max-min) + min + b := make([]byte, n) + for i := 0; i < n; i++ { + b[i] = byte(rand.Intn(255)) + } + return b +} + +func randKeys() [][]byte { + var keys [][]byte + var count = rand.Intn(2) + 2 + for i := 0; i < count; i++ { + keys = append(keys, randKey()) + } + return keys +} + +func randValue() []byte { + n := rand.Intn(8192) + b := make([]byte, n) + for i := 0; i < n; i++ { + b[i] = byte(rand.Intn(255)) + } + return b +} diff --git a/vendor/github.com/boltdb/bolt/tx.go b/vendor/github.com/boltdb/bolt/tx.go new file mode 100644 index 00000000..1cfb4cde --- /dev/null +++ b/vendor/github.com/boltdb/bolt/tx.go @@ -0,0 +1,682 @@ +package bolt + +import ( + "fmt" + "io" + "os" + "sort" + "strings" + "time" + "unsafe" +) + +// txid represents the internal transaction identifier. +type txid uint64 + +// Tx represents a read-only or read/write transaction on the database. +// Read-only transactions can be used for retrieving values for keys and creating cursors. +// Read/write transactions can create and remove buckets and create and remove keys. +// +// IMPORTANT: You must commit or rollback transactions when you are done with +// them. Pages can not be reclaimed by the writer until no more transactions +// are using them. A long running read transaction can cause the database to +// quickly grow. +type Tx struct { + writable bool + managed bool + db *DB + meta *meta + root Bucket + pages map[pgid]*page + stats TxStats + commitHandlers []func() + + // WriteFlag specifies the flag for write-related methods like WriteTo(). + // Tx opens the database file with the specified flag to copy the data. + // + // By default, the flag is unset, which works well for mostly in-memory + // workloads. For databases that are much larger than available RAM, + // set the flag to syscall.O_DIRECT to avoid trashing the page cache. + WriteFlag int +} + +// init initializes the transaction. +func (tx *Tx) init(db *DB) { + tx.db = db + tx.pages = nil + + // Copy the meta page since it can be changed by the writer. + tx.meta = &meta{} + db.meta().copy(tx.meta) + + // Copy over the root bucket. + tx.root = newBucket(tx) + tx.root.bucket = &bucket{} + *tx.root.bucket = tx.meta.root + + // Increment the transaction id and add a page cache for writable transactions. + if tx.writable { + tx.pages = make(map[pgid]*page) + tx.meta.txid += txid(1) + } +} + +// ID returns the transaction id. +func (tx *Tx) ID() int { + return int(tx.meta.txid) +} + +// DB returns a reference to the database that created the transaction. +func (tx *Tx) DB() *DB { + return tx.db +} + +// Size returns current database size in bytes as seen by this transaction. +func (tx *Tx) Size() int64 { + return int64(tx.meta.pgid) * int64(tx.db.pageSize) +} + +// Writable returns whether the transaction can perform write operations. +func (tx *Tx) Writable() bool { + return tx.writable +} + +// Cursor creates a cursor associated with the root bucket. +// All items in the cursor will return a nil value because all root bucket keys point to buckets. +// The cursor is only valid as long as the transaction is open. +// Do not use a cursor after the transaction is closed. +func (tx *Tx) Cursor() *Cursor { + return tx.root.Cursor() +} + +// Stats retrieves a copy of the current transaction statistics. +func (tx *Tx) Stats() TxStats { + return tx.stats +} + +// Bucket retrieves a bucket by name. +// Returns nil if the bucket does not exist. +// The bucket instance is only valid for the lifetime of the transaction. +func (tx *Tx) Bucket(name []byte) *Bucket { + return tx.root.Bucket(name) +} + +// CreateBucket creates a new bucket. +// Returns an error if the bucket already exists, if the bucket name is blank, or if the bucket name is too long. +// The bucket instance is only valid for the lifetime of the transaction. +func (tx *Tx) CreateBucket(name []byte) (*Bucket, error) { + return tx.root.CreateBucket(name) +} + +// CreateBucketIfNotExists creates a new bucket if it doesn't already exist. +// Returns an error if the bucket name is blank, or if the bucket name is too long. +// The bucket instance is only valid for the lifetime of the transaction. +func (tx *Tx) CreateBucketIfNotExists(name []byte) (*Bucket, error) { + return tx.root.CreateBucketIfNotExists(name) +} + +// DeleteBucket deletes a bucket. +// Returns an error if the bucket cannot be found or if the key represents a non-bucket value. +func (tx *Tx) DeleteBucket(name []byte) error { + return tx.root.DeleteBucket(name) +} + +// ForEach executes a function for each bucket in the root. +// If the provided function returns an error then the iteration is stopped and +// the error is returned to the caller. +func (tx *Tx) ForEach(fn func(name []byte, b *Bucket) error) error { + return tx.root.ForEach(func(k, v []byte) error { + if err := fn(k, tx.root.Bucket(k)); err != nil { + return err + } + return nil + }) +} + +// OnCommit adds a handler function to be executed after the transaction successfully commits. +func (tx *Tx) OnCommit(fn func()) { + tx.commitHandlers = append(tx.commitHandlers, fn) +} + +// Commit writes all changes to disk and updates the meta page. +// Returns an error if a disk write error occurs, or if Commit is +// called on a read-only transaction. +func (tx *Tx) Commit() error { + _assert(!tx.managed, "managed tx commit not allowed") + if tx.db == nil { + return ErrTxClosed + } else if !tx.writable { + return ErrTxNotWritable + } + + // TODO(benbjohnson): Use vectorized I/O to write out dirty pages. + + // Rebalance nodes which have had deletions. + var startTime = time.Now() + tx.root.rebalance() + if tx.stats.Rebalance > 0 { + tx.stats.RebalanceTime += time.Since(startTime) + } + + // spill data onto dirty pages. + startTime = time.Now() + if err := tx.root.spill(); err != nil { + tx.rollback() + return err + } + tx.stats.SpillTime += time.Since(startTime) + + // Free the old root bucket. + tx.meta.root.root = tx.root.root + + opgid := tx.meta.pgid + + // Free the freelist and allocate new pages for it. This will overestimate + // the size of the freelist but not underestimate the size (which would be bad). + tx.db.freelist.free(tx.meta.txid, tx.db.page(tx.meta.freelist)) + p, err := tx.allocate((tx.db.freelist.size() / tx.db.pageSize) + 1) + if err != nil { + tx.rollback() + return err + } + if err := tx.db.freelist.write(p); err != nil { + tx.rollback() + return err + } + tx.meta.freelist = p.id + + // If the high water mark has moved up then attempt to grow the database. + if tx.meta.pgid > opgid { + if err := tx.db.grow(int(tx.meta.pgid+1) * tx.db.pageSize); err != nil { + tx.rollback() + return err + } + } + + // Write dirty pages to disk. + startTime = time.Now() + if err := tx.write(); err != nil { + tx.rollback() + return err + } + + // If strict mode is enabled then perform a consistency check. + // Only the first consistency error is reported in the panic. + if tx.db.StrictMode { + ch := tx.Check() + var errs []string + for { + err, ok := <-ch + if !ok { + break + } + errs = append(errs, err.Error()) + } + if len(errs) > 0 { + panic("check fail: " + strings.Join(errs, "\n")) + } + } + + // Write meta to disk. + if err := tx.writeMeta(); err != nil { + tx.rollback() + return err + } + tx.stats.WriteTime += time.Since(startTime) + + // Finalize the transaction. + tx.close() + + // Execute commit handlers now that the locks have been removed. + for _, fn := range tx.commitHandlers { + fn() + } + + return nil +} + +// Rollback closes the transaction and ignores all previous updates. Read-only +// transactions must be rolled back and not committed. +func (tx *Tx) Rollback() error { + _assert(!tx.managed, "managed tx rollback not allowed") + if tx.db == nil { + return ErrTxClosed + } + tx.rollback() + return nil +} + +func (tx *Tx) rollback() { + if tx.db == nil { + return + } + if tx.writable { + tx.db.freelist.rollback(tx.meta.txid) + tx.db.freelist.reload(tx.db.page(tx.db.meta().freelist)) + } + tx.close() +} + +func (tx *Tx) close() { + if tx.db == nil { + return + } + if tx.writable { + // Grab freelist stats. + var freelistFreeN = tx.db.freelist.free_count() + var freelistPendingN = tx.db.freelist.pending_count() + var freelistAlloc = tx.db.freelist.size() + + // Remove transaction ref & writer lock. + tx.db.rwtx = nil + tx.db.rwlock.Unlock() + + // Merge statistics. + tx.db.statlock.Lock() + tx.db.stats.FreePageN = freelistFreeN + tx.db.stats.PendingPageN = freelistPendingN + tx.db.stats.FreeAlloc = (freelistFreeN + freelistPendingN) * tx.db.pageSize + tx.db.stats.FreelistInuse = freelistAlloc + tx.db.stats.TxStats.add(&tx.stats) + tx.db.statlock.Unlock() + } else { + tx.db.removeTx(tx) + } + + // Clear all references. + tx.db = nil + tx.meta = nil + tx.root = Bucket{tx: tx} + tx.pages = nil +} + +// Copy writes the entire database to a writer. +// This function exists for backwards compatibility. Use WriteTo() instead. +func (tx *Tx) Copy(w io.Writer) error { + _, err := tx.WriteTo(w) + return err +} + +// WriteTo writes the entire database to a writer. +// If err == nil then exactly tx.Size() bytes will be written into the writer. +func (tx *Tx) WriteTo(w io.Writer) (n int64, err error) { + // Attempt to open reader with WriteFlag + f, err := os.OpenFile(tx.db.path, os.O_RDONLY|tx.WriteFlag, 0) + if err != nil { + return 0, err + } + defer func() { _ = f.Close() }() + + // Generate a meta page. We use the same page data for both meta pages. + buf := make([]byte, tx.db.pageSize) + page := (*page)(unsafe.Pointer(&buf[0])) + page.flags = metaPageFlag + *page.meta() = *tx.meta + + // Write meta 0. + page.id = 0 + page.meta().checksum = page.meta().sum64() + nn, err := w.Write(buf) + n += int64(nn) + if err != nil { + return n, fmt.Errorf("meta 0 copy: %s", err) + } + + // Write meta 1 with a lower transaction id. + page.id = 1 + page.meta().txid -= 1 + page.meta().checksum = page.meta().sum64() + nn, err = w.Write(buf) + n += int64(nn) + if err != nil { + return n, fmt.Errorf("meta 1 copy: %s", err) + } + + // Move past the meta pages in the file. + if _, err := f.Seek(int64(tx.db.pageSize*2), os.SEEK_SET); err != nil { + return n, fmt.Errorf("seek: %s", err) + } + + // Copy data pages. + wn, err := io.CopyN(w, f, tx.Size()-int64(tx.db.pageSize*2)) + n += wn + if err != nil { + return n, err + } + + return n, f.Close() +} + +// CopyFile copies the entire database to file at the given path. +// A reader transaction is maintained during the copy so it is safe to continue +// using the database while a copy is in progress. +func (tx *Tx) CopyFile(path string, mode os.FileMode) error { + f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode) + if err != nil { + return err + } + + err = tx.Copy(f) + if err != nil { + _ = f.Close() + return err + } + return f.Close() +} + +// Check performs several consistency checks on the database for this transaction. +// An error is returned if any inconsistency is found. +// +// It can be safely run concurrently on a writable transaction. However, this +// incurs a high cost for large databases and databases with a lot of subbuckets +// because of caching. This overhead can be removed if running on a read-only +// transaction, however, it is not safe to execute other writer transactions at +// the same time. +func (tx *Tx) Check() <-chan error { + ch := make(chan error) + go tx.check(ch) + return ch +} + +func (tx *Tx) check(ch chan error) { + // Check if any pages are double freed. + freed := make(map[pgid]bool) + for _, id := range tx.db.freelist.all() { + if freed[id] { + ch <- fmt.Errorf("page %d: already freed", id) + } + freed[id] = true + } + + // Track every reachable page. + reachable := make(map[pgid]*page) + reachable[0] = tx.page(0) // meta0 + reachable[1] = tx.page(1) // meta1 + for i := uint32(0); i <= tx.page(tx.meta.freelist).overflow; i++ { + reachable[tx.meta.freelist+pgid(i)] = tx.page(tx.meta.freelist) + } + + // Recursively check buckets. + tx.checkBucket(&tx.root, reachable, freed, ch) + + // Ensure all pages below high water mark are either reachable or freed. + for i := pgid(0); i < tx.meta.pgid; i++ { + _, isReachable := reachable[i] + if !isReachable && !freed[i] { + ch <- fmt.Errorf("page %d: unreachable unfreed", int(i)) + } + } + + // Close the channel to signal completion. + close(ch) +} + +func (tx *Tx) checkBucket(b *Bucket, reachable map[pgid]*page, freed map[pgid]bool, ch chan error) { + // Ignore inline buckets. + if b.root == 0 { + return + } + + // Check every page used by this bucket. + b.tx.forEachPage(b.root, 0, func(p *page, _ int) { + if p.id > tx.meta.pgid { + ch <- fmt.Errorf("page %d: out of bounds: %d", int(p.id), int(b.tx.meta.pgid)) + } + + // Ensure each page is only referenced once. + for i := pgid(0); i <= pgid(p.overflow); i++ { + var id = p.id + i + if _, ok := reachable[id]; ok { + ch <- fmt.Errorf("page %d: multiple references", int(id)) + } + reachable[id] = p + } + + // We should only encounter un-freed leaf and branch pages. + if freed[p.id] { + ch <- fmt.Errorf("page %d: reachable freed", int(p.id)) + } else if (p.flags&branchPageFlag) == 0 && (p.flags&leafPageFlag) == 0 { + ch <- fmt.Errorf("page %d: invalid type: %s", int(p.id), p.typ()) + } + }) + + // Check each bucket within this bucket. + _ = b.ForEach(func(k, v []byte) error { + if child := b.Bucket(k); child != nil { + tx.checkBucket(child, reachable, freed, ch) + } + return nil + }) +} + +// allocate returns a contiguous block of memory starting at a given page. +func (tx *Tx) allocate(count int) (*page, error) { + p, err := tx.db.allocate(count) + if err != nil { + return nil, err + } + + // Save to our page cache. + tx.pages[p.id] = p + + // Update statistics. + tx.stats.PageCount++ + tx.stats.PageAlloc += count * tx.db.pageSize + + return p, nil +} + +// write writes any dirty pages to disk. +func (tx *Tx) write() error { + // Sort pages by id. + pages := make(pages, 0, len(tx.pages)) + for _, p := range tx.pages { + pages = append(pages, p) + } + // Clear out page cache early. + tx.pages = make(map[pgid]*page) + sort.Sort(pages) + + // Write pages to disk in order. + for _, p := range pages { + size := (int(p.overflow) + 1) * tx.db.pageSize + offset := int64(p.id) * int64(tx.db.pageSize) + + // Write out page in "max allocation" sized chunks. + ptr := (*[maxAllocSize]byte)(unsafe.Pointer(p)) + for { + // Limit our write to our max allocation size. + sz := size + if sz > maxAllocSize-1 { + sz = maxAllocSize - 1 + } + + // Write chunk to disk. + buf := ptr[:sz] + if _, err := tx.db.ops.writeAt(buf, offset); err != nil { + return err + } + + // Update statistics. + tx.stats.Write++ + + // Exit inner for loop if we've written all the chunks. + size -= sz + if size == 0 { + break + } + + // Otherwise move offset forward and move pointer to next chunk. + offset += int64(sz) + ptr = (*[maxAllocSize]byte)(unsafe.Pointer(&ptr[sz])) + } + } + + // Ignore file sync if flag is set on DB. + if !tx.db.NoSync || IgnoreNoSync { + if err := fdatasync(tx.db); err != nil { + return err + } + } + + // Put small pages back to page pool. + for _, p := range pages { + // Ignore page sizes over 1 page. + // These are allocated using make() instead of the page pool. + if int(p.overflow) != 0 { + continue + } + + buf := (*[maxAllocSize]byte)(unsafe.Pointer(p))[:tx.db.pageSize] + + // See https://go.googlesource.com/go/+/f03c9202c43e0abb130669852082117ca50aa9b1 + for i := range buf { + buf[i] = 0 + } + tx.db.pagePool.Put(buf) + } + + return nil +} + +// writeMeta writes the meta to the disk. +func (tx *Tx) writeMeta() error { + // Create a temporary buffer for the meta page. + buf := make([]byte, tx.db.pageSize) + p := tx.db.pageInBuffer(buf, 0) + tx.meta.write(p) + + // Write the meta page to file. + if _, err := tx.db.ops.writeAt(buf, int64(p.id)*int64(tx.db.pageSize)); err != nil { + return err + } + if !tx.db.NoSync || IgnoreNoSync { + if err := fdatasync(tx.db); err != nil { + return err + } + } + + // Update statistics. + tx.stats.Write++ + + return nil +} + +// page returns a reference to the page with a given id. +// If page has been written to then a temporary buffered page is returned. +func (tx *Tx) page(id pgid) *page { + // Check the dirty pages first. + if tx.pages != nil { + if p, ok := tx.pages[id]; ok { + return p + } + } + + // Otherwise return directly from the mmap. + return tx.db.page(id) +} + +// forEachPage iterates over every page within a given page and executes a function. +func (tx *Tx) forEachPage(pgid pgid, depth int, fn func(*page, int)) { + p := tx.page(pgid) + + // Execute function. + fn(p, depth) + + // Recursively loop over children. + if (p.flags & branchPageFlag) != 0 { + for i := 0; i < int(p.count); i++ { + elem := p.branchPageElement(uint16(i)) + tx.forEachPage(elem.pgid, depth+1, fn) + } + } +} + +// Page returns page information for a given page number. +// This is only safe for concurrent use when used by a writable transaction. +func (tx *Tx) Page(id int) (*PageInfo, error) { + if tx.db == nil { + return nil, ErrTxClosed + } else if pgid(id) >= tx.meta.pgid { + return nil, nil + } + + // Build the page info. + p := tx.db.page(pgid(id)) + info := &PageInfo{ + ID: id, + Count: int(p.count), + OverflowCount: int(p.overflow), + } + + // Determine the type (or if it's free). + if tx.db.freelist.freed(pgid(id)) { + info.Type = "free" + } else { + info.Type = p.typ() + } + + return info, nil +} + +// TxStats represents statistics about the actions performed by the transaction. +type TxStats struct { + // Page statistics. + PageCount int // number of page allocations + PageAlloc int // total bytes allocated + + // Cursor statistics. + CursorCount int // number of cursors created + + // Node statistics + NodeCount int // number of node allocations + NodeDeref int // number of node dereferences + + // Rebalance statistics. + Rebalance int // number of node rebalances + RebalanceTime time.Duration // total time spent rebalancing + + // Split/Spill statistics. + Split int // number of nodes split + Spill int // number of nodes spilled + SpillTime time.Duration // total time spent spilling + + // Write statistics. + Write int // number of writes performed + WriteTime time.Duration // total time spent writing to disk +} + +func (s *TxStats) add(other *TxStats) { + s.PageCount += other.PageCount + s.PageAlloc += other.PageAlloc + s.CursorCount += other.CursorCount + s.NodeCount += other.NodeCount + s.NodeDeref += other.NodeDeref + s.Rebalance += other.Rebalance + s.RebalanceTime += other.RebalanceTime + s.Split += other.Split + s.Spill += other.Spill + s.SpillTime += other.SpillTime + s.Write += other.Write + s.WriteTime += other.WriteTime +} + +// Sub calculates and returns the difference between two sets of transaction stats. +// This is useful when obtaining stats at two different points and time and +// you need the performance counters that occurred within that time span. +func (s *TxStats) Sub(other *TxStats) TxStats { + var diff TxStats + diff.PageCount = s.PageCount - other.PageCount + diff.PageAlloc = s.PageAlloc - other.PageAlloc + diff.CursorCount = s.CursorCount - other.CursorCount + diff.NodeCount = s.NodeCount - other.NodeCount + diff.NodeDeref = s.NodeDeref - other.NodeDeref + diff.Rebalance = s.Rebalance - other.Rebalance + diff.RebalanceTime = s.RebalanceTime - other.RebalanceTime + diff.Split = s.Split - other.Split + diff.Spill = s.Spill - other.Spill + diff.SpillTime = s.SpillTime - other.SpillTime + diff.Write = s.Write - other.Write + diff.WriteTime = s.WriteTime - other.WriteTime + return diff +} diff --git a/vendor/github.com/boltdb/bolt/tx_test.go b/vendor/github.com/boltdb/bolt/tx_test.go new file mode 100644 index 00000000..2201e792 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/tx_test.go @@ -0,0 +1,716 @@ +package bolt_test + +import ( + "bytes" + "errors" + "fmt" + "log" + "os" + "testing" + + "github.com/boltdb/bolt" +) + +// Ensure that committing a closed transaction returns an error. +func TestTx_Commit_ErrTxClosed(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + tx, err := db.Begin(true) + if err != nil { + t.Fatal(err) + } + + if _, err := tx.CreateBucket([]byte("foo")); err != nil { + t.Fatal(err) + } + + if err := tx.Commit(); err != nil { + t.Fatal(err) + } + + if err := tx.Commit(); err != bolt.ErrTxClosed { + t.Fatalf("unexpected error: %s", err) + } +} + +// Ensure that rolling back a closed transaction returns an error. +func TestTx_Rollback_ErrTxClosed(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + tx, err := db.Begin(true) + if err != nil { + t.Fatal(err) + } + + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } + if err := tx.Rollback(); err != bolt.ErrTxClosed { + t.Fatalf("unexpected error: %s", err) + } +} + +// Ensure that committing a read-only transaction returns an error. +func TestTx_Commit_ErrTxNotWritable(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + tx, err := db.Begin(false) + if err != nil { + t.Fatal(err) + } + if err := tx.Commit(); err != bolt.ErrTxNotWritable { + t.Fatal(err) + } +} + +// Ensure that a transaction can retrieve a cursor on the root bucket. +func TestTx_Cursor(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } + + if _, err := tx.CreateBucket([]byte("woojits")); err != nil { + t.Fatal(err) + } + + c := tx.Cursor() + if k, v := c.First(); !bytes.Equal(k, []byte("widgets")) { + t.Fatalf("unexpected key: %v", k) + } else if v != nil { + t.Fatalf("unexpected value: %v", v) + } + + if k, v := c.Next(); !bytes.Equal(k, []byte("woojits")) { + t.Fatalf("unexpected key: %v", k) + } else if v != nil { + t.Fatalf("unexpected value: %v", v) + } + + if k, v := c.Next(); k != nil { + t.Fatalf("unexpected key: %v", k) + } else if v != nil { + t.Fatalf("unexpected value: %v", k) + } + + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that creating a bucket with a read-only transaction returns an error. +func TestTx_CreateBucket_ErrTxNotWritable(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.View(func(tx *bolt.Tx) error { + _, err := tx.CreateBucket([]byte("foo")) + if err != bolt.ErrTxNotWritable { + t.Fatalf("unexpected error: %s", err) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that creating a bucket on a closed transaction returns an error. +func TestTx_CreateBucket_ErrTxClosed(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + tx, err := db.Begin(true) + if err != nil { + t.Fatal(err) + } + if err := tx.Commit(); err != nil { + t.Fatal(err) + } + + if _, err := tx.CreateBucket([]byte("foo")); err != bolt.ErrTxClosed { + t.Fatalf("unexpected error: %s", err) + } +} + +// Ensure that a Tx can retrieve a bucket. +func TestTx_Bucket(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } + if tx.Bucket([]byte("widgets")) == nil { + t.Fatal("expected bucket") + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that a Tx retrieving a non-existent key returns nil. +func TestTx_Get_NotFound(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } + if b.Get([]byte("no_such_key")) != nil { + t.Fatal("expected nil value") + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that a bucket can be created and retrieved. +func TestTx_CreateBucket(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + // Create a bucket. + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } else if b == nil { + t.Fatal("expected bucket") + } + return nil + }); err != nil { + t.Fatal(err) + } + + // Read the bucket through a separate transaction. + if err := db.View(func(tx *bolt.Tx) error { + if tx.Bucket([]byte("widgets")) == nil { + t.Fatal("expected bucket") + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that a bucket can be created if it doesn't already exist. +func TestTx_CreateBucketIfNotExists(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + // Create bucket. + if b, err := tx.CreateBucketIfNotExists([]byte("widgets")); err != nil { + t.Fatal(err) + } else if b == nil { + t.Fatal("expected bucket") + } + + // Create bucket again. + if b, err := tx.CreateBucketIfNotExists([]byte("widgets")); err != nil { + t.Fatal(err) + } else if b == nil { + t.Fatal("expected bucket") + } + + return nil + }); err != nil { + t.Fatal(err) + } + + // Read the bucket through a separate transaction. + if err := db.View(func(tx *bolt.Tx) error { + if tx.Bucket([]byte("widgets")) == nil { + t.Fatal("expected bucket") + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure transaction returns an error if creating an unnamed bucket. +func TestTx_CreateBucketIfNotExists_ErrBucketNameRequired(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucketIfNotExists([]byte{}); err != bolt.ErrBucketNameRequired { + t.Fatalf("unexpected error: %s", err) + } + + if _, err := tx.CreateBucketIfNotExists(nil); err != bolt.ErrBucketNameRequired { + t.Fatalf("unexpected error: %s", err) + } + + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that a bucket cannot be created twice. +func TestTx_CreateBucket_ErrBucketExists(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + // Create a bucket. + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + + // Create the same bucket again. + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != bolt.ErrBucketExists { + t.Fatalf("unexpected error: %s", err) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that a bucket is created with a non-blank name. +func TestTx_CreateBucket_ErrBucketNameRequired(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket(nil); err != bolt.ErrBucketNameRequired { + t.Fatalf("unexpected error: %s", err) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that a bucket can be deleted. +func TestTx_DeleteBucket(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + // Create a bucket and add a value. + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + + // Delete the bucket and make sure we can't get the value. + if err := db.Update(func(tx *bolt.Tx) error { + if err := tx.DeleteBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } + if tx.Bucket([]byte("widgets")) != nil { + t.Fatal("unexpected bucket") + } + return nil + }); err != nil { + t.Fatal(err) + } + + if err := db.Update(func(tx *bolt.Tx) error { + // Create the bucket again and make sure there's not a phantom value. + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if v := b.Get([]byte("foo")); v != nil { + t.Fatalf("unexpected phantom value: %v", v) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that deleting a bucket on a closed transaction returns an error. +func TestTx_DeleteBucket_ErrTxClosed(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + tx, err := db.Begin(true) + if err != nil { + t.Fatal(err) + } + if err := tx.Commit(); err != nil { + t.Fatal(err) + } + if err := tx.DeleteBucket([]byte("foo")); err != bolt.ErrTxClosed { + t.Fatalf("unexpected error: %s", err) + } +} + +// Ensure that deleting a bucket with a read-only transaction returns an error. +func TestTx_DeleteBucket_ReadOnly(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.View(func(tx *bolt.Tx) error { + if err := tx.DeleteBucket([]byte("foo")); err != bolt.ErrTxNotWritable { + t.Fatalf("unexpected error: %s", err) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that nothing happens when deleting a bucket that doesn't exist. +func TestTx_DeleteBucket_NotFound(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + if err := tx.DeleteBucket([]byte("widgets")); err != bolt.ErrBucketNotFound { + t.Fatalf("unexpected error: %s", err) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that no error is returned when a tx.ForEach function does not return +// an error. +func TestTx_ForEach_NoError(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } + + if err := tx.ForEach(func(name []byte, b *bolt.Bucket) error { + return nil + }); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that an error is returned when a tx.ForEach function returns an error. +func TestTx_ForEach_WithError(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } + + marker := errors.New("marker") + if err := tx.ForEach(func(name []byte, b *bolt.Bucket) error { + return marker + }); err != marker { + t.Fatalf("unexpected error: %s", err) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that Tx commit handlers are called after a transaction successfully commits. +func TestTx_OnCommit(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + var x int + if err := db.Update(func(tx *bolt.Tx) error { + tx.OnCommit(func() { x += 1 }) + tx.OnCommit(func() { x += 2 }) + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } else if x != 3 { + t.Fatalf("unexpected x: %d", x) + } +} + +// Ensure that Tx commit handlers are NOT called after a transaction rolls back. +func TestTx_OnCommit_Rollback(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + var x int + if err := db.Update(func(tx *bolt.Tx) error { + tx.OnCommit(func() { x += 1 }) + tx.OnCommit(func() { x += 2 }) + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } + return errors.New("rollback this commit") + }); err == nil || err.Error() != "rollback this commit" { + t.Fatalf("unexpected error: %s", err) + } else if x != 0 { + t.Fatalf("unexpected x: %d", x) + } +} + +// Ensure that the database can be copied to a file path. +func TestTx_CopyFile(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + path := tempfile() + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("baz"), []byte("bat")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + + if err := db.View(func(tx *bolt.Tx) error { + return tx.CopyFile(path, 0600) + }); err != nil { + t.Fatal(err) + } + + db2, err := bolt.Open(path, 0600, nil) + if err != nil { + t.Fatal(err) + } + + if err := db2.View(func(tx *bolt.Tx) error { + if v := tx.Bucket([]byte("widgets")).Get([]byte("foo")); !bytes.Equal(v, []byte("bar")) { + t.Fatalf("unexpected value: %v", v) + } + if v := tx.Bucket([]byte("widgets")).Get([]byte("baz")); !bytes.Equal(v, []byte("bat")) { + t.Fatalf("unexpected value: %v", v) + } + return nil + }); err != nil { + t.Fatal(err) + } + + if err := db2.Close(); err != nil { + t.Fatal(err) + } +} + +type failWriterError struct{} + +func (failWriterError) Error() string { + return "error injected for tests" +} + +type failWriter struct { + // fail after this many bytes + After int +} + +func (f *failWriter) Write(p []byte) (n int, err error) { + n = len(p) + if n > f.After { + n = f.After + err = failWriterError{} + } + f.After -= n + return n, err +} + +// Ensure that Copy handles write errors right. +func TestTx_CopyFile_Error_Meta(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("baz"), []byte("bat")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + + if err := db.View(func(tx *bolt.Tx) error { + return tx.Copy(&failWriter{}) + }); err == nil || err.Error() != "meta 0 copy: error injected for tests" { + t.Fatalf("unexpected error: %v", err) + } +} + +// Ensure that Copy handles write errors right. +func TestTx_CopyFile_Error_Normal(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("baz"), []byte("bat")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + + if err := db.View(func(tx *bolt.Tx) error { + return tx.Copy(&failWriter{3 * db.Info().PageSize}) + }); err == nil || err.Error() != "error injected for tests" { + t.Fatalf("unexpected error: %v", err) + } +} + +func ExampleTx_Rollback() { + // Open the database. + db, err := bolt.Open(tempfile(), 0666, nil) + if err != nil { + log.Fatal(err) + } + defer os.Remove(db.Path()) + + // Create a bucket. + if err := db.Update(func(tx *bolt.Tx) error { + _, err := tx.CreateBucket([]byte("widgets")) + return err + }); err != nil { + log.Fatal(err) + } + + // Set a value for a key. + if err := db.Update(func(tx *bolt.Tx) error { + return tx.Bucket([]byte("widgets")).Put([]byte("foo"), []byte("bar")) + }); err != nil { + log.Fatal(err) + } + + // Update the key but rollback the transaction so it never saves. + tx, err := db.Begin(true) + if err != nil { + log.Fatal(err) + } + b := tx.Bucket([]byte("widgets")) + if err := b.Put([]byte("foo"), []byte("baz")); err != nil { + log.Fatal(err) + } + if err := tx.Rollback(); err != nil { + log.Fatal(err) + } + + // Ensure that our original value is still set. + if err := db.View(func(tx *bolt.Tx) error { + value := tx.Bucket([]byte("widgets")).Get([]byte("foo")) + fmt.Printf("The value for 'foo' is still: %s\n", value) + return nil + }); err != nil { + log.Fatal(err) + } + + // Close database to release file lock. + if err := db.Close(); err != nil { + log.Fatal(err) + } + + // Output: + // The value for 'foo' is still: bar +} + +func ExampleTx_CopyFile() { + // Open the database. + db, err := bolt.Open(tempfile(), 0666, nil) + if err != nil { + log.Fatal(err) + } + defer os.Remove(db.Path()) + + // Create a bucket and a key. + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + return err + } + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + return err + } + return nil + }); err != nil { + log.Fatal(err) + } + + // Copy the database to another file. + toFile := tempfile() + if err := db.View(func(tx *bolt.Tx) error { + return tx.CopyFile(toFile, 0666) + }); err != nil { + log.Fatal(err) + } + defer os.Remove(toFile) + + // Open the cloned database. + db2, err := bolt.Open(toFile, 0666, nil) + if err != nil { + log.Fatal(err) + } + + // Ensure that the key exists in the copy. + if err := db2.View(func(tx *bolt.Tx) error { + value := tx.Bucket([]byte("widgets")).Get([]byte("foo")) + fmt.Printf("The value for 'foo' in the clone is: %s\n", value) + return nil + }); err != nil { + log.Fatal(err) + } + + // Close database to release file lock. + if err := db.Close(); err != nil { + log.Fatal(err) + } + + if err := db2.Close(); err != nil { + log.Fatal(err) + } + + // Output: + // The value for 'foo' in the clone is: bar +} diff --git a/vendor/github.com/bsm/sarama-cluster/.gitignore b/vendor/github.com/bsm/sarama-cluster/.gitignore new file mode 100644 index 00000000..88113c5b --- /dev/null +++ b/vendor/github.com/bsm/sarama-cluster/.gitignore @@ -0,0 +1,4 @@ +*.log +*.pid +kafka*/ +vendor/ diff --git a/vendor/github.com/bsm/sarama-cluster/.travis.yml b/vendor/github.com/bsm/sarama-cluster/.travis.yml new file mode 100644 index 00000000..1911a0ba --- /dev/null +++ b/vendor/github.com/bsm/sarama-cluster/.travis.yml @@ -0,0 +1,18 @@ +sudo: false +language: go +go: + - 1.8.1 + - 1.7.5 +install: + - go get -u github.com/Masterminds/glide + - glide install +env: + - SCALA_VERSION=2.11 KAFKA_VERSION=0.9.0.1 + - SCALA_VERSION=2.11 KAFKA_VERSION=0.10.1.1 + - SCALA_VERSION=2.12 KAFKA_VERSION=0.10.2.0 +script: + - make default test-race +addons: + apt: + packages: + - oracle-java8-set-default diff --git a/vendor/github.com/bsm/sarama-cluster/LICENSE b/vendor/github.com/bsm/sarama-cluster/LICENSE new file mode 100644 index 00000000..127751c4 --- /dev/null +++ b/vendor/github.com/bsm/sarama-cluster/LICENSE @@ -0,0 +1,22 @@ +(The MIT License) + +Copyright (c) 2017 Black Square Media Ltd + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +'Software'), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/bsm/sarama-cluster/Makefile b/vendor/github.com/bsm/sarama-cluster/Makefile new file mode 100644 index 00000000..449de4dd --- /dev/null +++ b/vendor/github.com/bsm/sarama-cluster/Makefile @@ -0,0 +1,35 @@ +SCALA_VERSION?= 2.12 +KAFKA_VERSION?= 0.10.2.0 +KAFKA_DIR= kafka_$(SCALA_VERSION)-$(KAFKA_VERSION) +KAFKA_SRC= http://www.mirrorservice.org/sites/ftp.apache.org/kafka/$(KAFKA_VERSION)/$(KAFKA_DIR).tgz +KAFKA_ROOT= testdata/$(KAFKA_DIR) +PKG=$(shell glide nv) + +default: vet test + +vet: + go vet $(PKG) + +test: testdeps + KAFKA_DIR=$(KAFKA_DIR) go test $(PKG) -ginkgo.slowSpecThreshold=60 + +test-verbose: testdeps + KAFKA_DIR=$(KAFKA_DIR) go test $(PKG) -ginkgo.slowSpecThreshold=60 -v + +test-race: testdeps + KAFKA_DIR=$(KAFKA_DIR) go test $(PKG) -ginkgo.slowSpecThreshold=60 -v -race + +testdeps: $(KAFKA_ROOT) + +doc: README.md + +.PHONY: test testdeps vet doc + +# --------------------------------------------------------------------- + +$(KAFKA_ROOT): + @mkdir -p $(dir $@) + cd $(dir $@) && curl -sSL $(KAFKA_SRC) | tar xz + +README.md: README.md.tpl $(wildcard *.go) + becca -package $(subst $(GOPATH)/src/,,$(PWD)) diff --git a/vendor/github.com/bsm/sarama-cluster/README.md b/vendor/github.com/bsm/sarama-cluster/README.md new file mode 100644 index 00000000..42f9030a --- /dev/null +++ b/vendor/github.com/bsm/sarama-cluster/README.md @@ -0,0 +1,86 @@ +# Sarama Cluster + +[![GoDoc](https://godoc.org/github.com/bsm/sarama-cluster?status.svg)](https://godoc.org/github.com/bsm/sarama-cluster) +[![Build Status](https://travis-ci.org/bsm/sarama-cluster.svg?branch=master)](https://travis-ci.org/bsm/sarama-cluster) +[![Go Report Card](https://goreportcard.com/badge/github.com/bsm/sarama-cluster)](https://goreportcard.com/report/github.com/bsm/sarama-cluster) +[![License](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT) + +Cluster extensions for [Sarama](https://github.com/Shopify/sarama), the Go client library for Apache Kafka 0.9 (and later). + +## Documentation + +Documentation and example are available via godoc at http://godoc.org/github.com/bsm/sarama-cluster + +## Example + +```go +package main + +import ( + "fmt" + "log" + "os" + "os/signal" + + cluster "github.com/bsm/sarama-cluster" +) + +func main() { + + // init (custom) config, enable errors and notifications + config := cluster.NewConfig() + config.Consumer.Return.Errors = true + config.Group.Return.Notifications = true + + // init consumer + brokers := []string{"127.0.0.1:9092"} + topics := []string{"my_topic", "other_topic"} + consumer, err := cluster.NewConsumer(brokers, "my-consumer-group", topics, config) + if err != nil { + panic(err) + } + defer consumer.Close() + + // trap SIGINT to trigger a shutdown. + signals := make(chan os.Signal, 1) + signal.Notify(signals, os.Interrupt) + + // consume messages, watch errors and notifications + for { + select { + case msg, more := <-consumer.Messages(): + if more { + fmt.Fprintf(os.Stdout, "%s/%d/%d\t%s\t%s\n", msg.Topic, msg.Partition, msg.Offset, msg.Key, msg.Value) + consumer.MarkOffset(msg, "") // mark message as processed + } + case err, more := <-consumer.Errors(): + if more { + log.Printf("Error: %s\n", err.Error()) + } + case ntf, more := <-consumer.Notifications(): + if more { + log.Printf("Rebalanced: %+v\n", ntf) + } + case <-signals: + return + } + } +} +``` + +## Running tests + +You need to install Ginkgo & Gomega to run tests. Please see +http://onsi.github.io/ginkgo for more details. + +To run tests, call: + + $ make test + +## Troubleshooting + +### Consumer not receiving any messages? + +By default, sarama's `Config.Consumer.Offsets.Initial` is set to `sarama.OffsetNewest`. This means that in the event that a brand new consumer is created, and it has never committed any offsets to kafka, it will only receive messages starting from the message after the current one that was written. + +If you wish to receive all messages (from the start of all messages in the topic) in the event that a consumer does not have any offsets committed to kafka, you need to set `Config.Consumer.Offsets.Initial` to `sarama.OffsetOldest`. diff --git a/vendor/github.com/bsm/sarama-cluster/README.md.tpl b/vendor/github.com/bsm/sarama-cluster/README.md.tpl new file mode 100644 index 00000000..3576941e --- /dev/null +++ b/vendor/github.com/bsm/sarama-cluster/README.md.tpl @@ -0,0 +1,46 @@ +# Sarama Cluster + +[![GoDoc](https://godoc.org/github.com/bsm/sarama-cluster?status.svg)](https://godoc.org/github.com/bsm/sarama-cluster) +[![Build Status](https://travis-ci.org/bsm/sarama-cluster.svg?branch=master)](https://travis-ci.org/bsm/sarama-cluster) +[![Go Report Card](https://goreportcard.com/badge/github.com/bsm/sarama-cluster)](https://goreportcard.com/report/github.com/bsm/sarama-cluster) +[![License](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT) + +Cluster extensions for [Sarama](https://github.com/Shopify/sarama), the Go client library for Apache Kafka 0.9 (and later). + +## Documentation + +Documentation and example are available via godoc at http://godoc.org/github.com/bsm/sarama-cluster + +## Example + +```go +package main + +import ( + "fmt" + "log" + "os" + "os/signal" + + cluster "github.com/bsm/sarama-cluster" +) + +func main() {{ "ExampleConsumer" | code }} +``` + +## Running tests + +You need to install Ginkgo & Gomega to run tests. Please see +http://onsi.github.io/ginkgo for more details. + +To run tests, call: + + $ make test + +## Troubleshooting + +### Consumer not receiving any messages? + +By default, sarama's `Config.Consumer.Offsets.Initial` is set to `sarama.OffsetNewest`. This means that in the event that a brand new consumer is created, and it has never committed any offsets to kafka, it will only receive messages starting from the message after the current one that was written. + +If you wish to receive all messages (from the start of all messages in the topic) in the event that a consumer does not have any offsets committed to kafka, you need to set `Config.Consumer.Offsets.Initial` to `sarama.OffsetOldest`. diff --git a/vendor/github.com/bsm/sarama-cluster/balancer.go b/vendor/github.com/bsm/sarama-cluster/balancer.go new file mode 100644 index 00000000..d66ef71a --- /dev/null +++ b/vendor/github.com/bsm/sarama-cluster/balancer.go @@ -0,0 +1,144 @@ +package cluster + +import ( + "math" + "sort" + + "github.com/Shopify/sarama" +) + +// Notification events are emitted by the consumers on rebalancing +type Notification struct { + // Claimed contains topic/partitions that were claimed by this rebalance cycle + Claimed map[string][]int32 + + // Released contains topic/partitions that were released as part of this rebalance cycle + Released map[string][]int32 + + // Current are topic/partitions that are currently claimed to the consumer + Current map[string][]int32 +} + +func newNotification(released map[string][]int32) *Notification { + return &Notification{ + Claimed: make(map[string][]int32), + Released: released, + Current: make(map[string][]int32), + } +} + +func (n *Notification) claim(current map[string][]int32) { + previous := n.Released + for topic, partitions := range current { + n.Claimed[topic] = int32Slice(partitions).Diff(int32Slice(previous[topic])) + } + for topic, partitions := range previous { + n.Released[topic] = int32Slice(partitions).Diff(int32Slice(current[topic])) + } + n.Current = current +} + +// -------------------------------------------------------------------- + +type topicInfo struct { + Partitions []int32 + MemberIDs []string +} + +func (info topicInfo) Perform(s Strategy) map[string][]int32 { + if s == StrategyRoundRobin { + return info.RoundRobin() + } + return info.Ranges() +} + +func (info topicInfo) Ranges() map[string][]int32 { + sort.Strings(info.MemberIDs) + + mlen := len(info.MemberIDs) + plen := len(info.Partitions) + res := make(map[string][]int32, mlen) + + for pos, memberID := range info.MemberIDs { + n, i := float64(plen)/float64(mlen), float64(pos) + min := int(math.Floor(i*n + 0.5)) + max := int(math.Floor((i+1)*n + 0.5)) + sub := info.Partitions[min:max] + if len(sub) > 0 { + res[memberID] = sub + } + } + return res +} + +func (info topicInfo) RoundRobin() map[string][]int32 { + sort.Strings(info.MemberIDs) + + mlen := len(info.MemberIDs) + res := make(map[string][]int32, mlen) + for i, pnum := range info.Partitions { + memberID := info.MemberIDs[i%mlen] + res[memberID] = append(res[memberID], pnum) + } + return res +} + +// -------------------------------------------------------------------- + +type balancer struct { + client sarama.Client + topics map[string]topicInfo +} + +func newBalancerFromMeta(client sarama.Client, members map[string]sarama.ConsumerGroupMemberMetadata) (*balancer, error) { + balancer := newBalancer(client) + for memberID, meta := range members { + for _, topic := range meta.Topics { + if err := balancer.Topic(topic, memberID); err != nil { + return nil, err + } + } + } + return balancer, nil +} + +func newBalancer(client sarama.Client) *balancer { + return &balancer{ + client: client, + topics: make(map[string]topicInfo), + } +} + +func (r *balancer) Topic(name string, memberID string) error { + topic, ok := r.topics[name] + if !ok { + nums, err := r.client.Partitions(name) + if err != nil { + return err + } + topic = topicInfo{ + Partitions: nums, + MemberIDs: make([]string, 0, 1), + } + } + topic.MemberIDs = append(topic.MemberIDs, memberID) + r.topics[name] = topic + return nil +} + +func (r *balancer) Perform(s Strategy) map[string]map[string][]int32 { + if r == nil { + return nil + } + + res := make(map[string]map[string][]int32, 1) + for topic, info := range r.topics { + for memberID, partitions := range info.Perform(s) { + if _, ok := res[memberID]; !ok { + res[memberID] = make(map[string][]int32, 1) + } + res[memberID][topic] = partitions + } + } + return res +} diff --git a/vendor/github.com/bsm/sarama-cluster/balancer_test.go b/vendor/github.com/bsm/sarama-cluster/balancer_test.go new file mode 100644 index 00000000..969dc0c8 --- /dev/null +++ b/vendor/github.com/bsm/sarama-cluster/balancer_test.go @@ -0,0 +1,124 @@ +package cluster + +import ( + "github.com/Shopify/sarama" + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +var _ = Describe("Notification", func() { + + It("should init and update", func() { + n := newNotification(map[string][]int32{ + "a": {1, 2, 3}, + "b": {4, 5}, + "c": {1, 2}, + }) + n.claim(map[string][]int32{ + "a": {3, 4}, + "b": {1, 2, 3, 4}, + "d": {3, 4}, + }) + Expect(n).To(Equal(&Notification{ + Claimed: map[string][]int32{"a": {4}, "b": {1, 2, 3}, "d": {3, 4}}, + Released: map[string][]int32{"a": {1, 2}, "b": {5}, "c": {1, 2}}, + Current: map[string][]int32{"a": {3, 4}, "b": {1, 2, 3, 4}, "d": {3, 4}}, + })) + }) + +}) + +var _ = Describe("balancer", func() { + var subject *balancer + + BeforeEach(func() { + client := &mockClient{ + topics: map[string][]int32{ + "one": {0, 1, 2, 3}, + "two": {0, 1, 2}, + "three": {0, 1}, + }, + } + + var err error + subject, err = newBalancerFromMeta(client, map[string]sarama.ConsumerGroupMemberMetadata{ + "b": {Topics: []string{"three", "one"}}, + "a": {Topics: []string{"one", "two"}}, + }) + Expect(err).NotTo(HaveOccurred()) + }) + + It("should parse from meta data", func() { + Expect(subject.topics).To(HaveLen(3)) + }) + + It("should perform", func() { + Expect(subject.Perform(StrategyRange)).To(Equal(map[string]map[string][]int32{ + "a": {"one": {0, 1}, "two": {0, 1, 2}}, + "b": {"one": {2, 3}, "three": {0, 1}}, + })) + + Expect(subject.Perform(StrategyRoundRobin)).To(Equal(map[string]map[string][]int32{ + "a": {"one": {0, 2}, "two": {0, 1, 2}}, + "b": {"one": {1, 3}, "three": {0, 1}}, + })) + }) + +}) + +var _ = Describe("topicInfo", func() { + + DescribeTable("Ranges", + func(memberIDs []string, partitions []int32, expected map[string][]int32) { + info := topicInfo{MemberIDs: memberIDs, Partitions: partitions} + Expect(info.Ranges()).To(Equal(expected)) + }, + + Entry("three members, three partitions", []string{"M1", "M2", "M3"}, []int32{0, 1, 2}, map[string][]int32{ + "M1": {0}, "M2": {1}, "M3": {2}, + }), + Entry("member ID order", []string{"M3", "M1", "M2"}, []int32{0, 1, 2}, map[string][]int32{ + "M1": {0}, "M2": {1}, "M3": {2}, + }), + Entry("more members than partitions", []string{"M1", "M2", "M3"}, []int32{0, 1}, map[string][]int32{ + "M1": {0}, "M3": {1}, + }), + Entry("far more members than partitions", []string{"M1", "M2", "M3"}, []int32{0}, map[string][]int32{ + "M2": {0}, + }), + Entry("fewer members than partitions", []string{"M1", "M2", "M3"}, []int32{0, 1, 2, 3}, map[string][]int32{ + "M1": {0}, "M2": {1, 2}, "M3": {3}, + }), + Entry("uneven members/partitions ratio", []string{"M1", "M2", "M3"}, []int32{0, 2, 4, 6, 8}, map[string][]int32{ + "M1": {0, 2}, "M2": {4}, "M3": {6, 8}, + }), + ) + + DescribeTable("RoundRobin", + func(memberIDs []string, partitions []int32, expected map[string][]int32) { + info := topicInfo{MemberIDs: memberIDs, Partitions: partitions} + Expect(info.RoundRobin()).To(Equal(expected)) + }, + + Entry("three members, three partitions", []string{"M1", "M2", "M3"}, []int32{0, 1, 2}, map[string][]int32{ + "M1": {0}, "M2": {1}, "M3": {2}, + }), + Entry("member ID order", []string{"M3", "M1", "M2"}, []int32{0, 1, 2}, map[string][]int32{ + "M1": {0}, "M2": {1}, "M3": {2}, + }), + Entry("more members than partitions", []string{"M1", "M2", "M3"}, []int32{0, 1}, map[string][]int32{ + "M1": {0}, "M2": {1}, + }), + Entry("far more members than partitions", []string{"M1", "M2", "M3"}, []int32{0}, map[string][]int32{ + "M1": {0}, + }), + Entry("fewer members than partitions", []string{"M1", "M2", "M3"}, []int32{0, 1, 2, 3}, map[string][]int32{ + "M1": {0, 3}, "M2": {1}, "M3": {2}, + }), + Entry("uneven members/partitions ratio", []string{"M1", "M2", "M3"}, []int32{0, 2, 4, 6, 8}, map[string][]int32{ + "M1": {0, 6}, "M2": {2, 8}, "M3": {4}, + }), + ) + +}) diff --git a/vendor/github.com/bsm/sarama-cluster/client.go b/vendor/github.com/bsm/sarama-cluster/client.go new file mode 100644 index 00000000..2cfac5d6 --- /dev/null +++ b/vendor/github.com/bsm/sarama-cluster/client.go @@ -0,0 +1,28 @@ +package cluster + +import "github.com/Shopify/sarama" + +// Client is a group client +type Client struct { + sarama.Client + config Config + own bool +} + +// NewClient creates a new client instance +func NewClient(addrs []string, config *Config) (*Client, error) { + if config == nil { + config = NewConfig() + } + + if err := config.Validate(); err != nil { + return nil, err + } + + client, err := sarama.NewClient(addrs, &config.Config) + if err != nil { + return nil, err + } + + return &Client{Client: client, config: *config}, nil +} diff --git a/vendor/github.com/bsm/sarama-cluster/cluster.go b/vendor/github.com/bsm/sarama-cluster/cluster.go new file mode 100644 index 00000000..760d0c73 --- /dev/null +++ b/vendor/github.com/bsm/sarama-cluster/cluster.go @@ -0,0 +1,73 @@ +package cluster + +import ( + "fmt" + "sort" +) + +// Strategy for partition to consumer assignement +type Strategy string + +const ( + // StrategyRange is the default and assigns partition ranges to consumers. + // Example with six partitions and two consumers: + // C1: [0, 1, 2] + // C2: [3, 4, 5] + StrategyRange Strategy = "range" + + // StrategyRoundRobin assigns partitions by alternating over consumers. + // Example with six partitions and two consumers: + // C1: [0, 2, 4] + // C2: [1, 3, 5] + StrategyRoundRobin Strategy = "roundrobin" +) + +// Error instances are wrappers for internal errors with a context and +// may be returned through the consumer's Errors() channel +type Error struct { + Ctx string + error +} + +// -------------------------------------------------------------------- + +type none struct{} + +type topicPartition struct { + Topic string + Partition int32 +} + +func (tp *topicPartition) String() string { + return fmt.Sprintf("%s-%d", tp.Topic, tp.Partition) +} + +type offsetInfo struct { + Offset int64 + Metadata string +} + +func (i offsetInfo) NextOffset(fallback int64) int64 { + if i.Offset > -1 { + return i.Offset + } + return fallback +} + +type int32Slice []int32 + +func (p int32Slice) Len() int { return len(p) } +func (p int32Slice) Less(i, j int) bool { return p[i] < p[j] } +func (p int32Slice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } + +func (p int32Slice) Diff(o int32Slice) (res []int32) { + on := len(o) + for _, x := range p { + n := sort.Search(on, func(i int) bool { return o[i] >= x }) + if n < on && o[n] == x { + continue + } + res = append(res, x) + } + return +} diff --git a/vendor/github.com/bsm/sarama-cluster/cluster_test.go b/vendor/github.com/bsm/sarama-cluster/cluster_test.go new file mode 100644 index 00000000..e55f7292 --- /dev/null +++ b/vendor/github.com/bsm/sarama-cluster/cluster_test.go @@ -0,0 +1,198 @@ +package cluster + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "testing" + + "github.com/Shopify/sarama" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +const ( + testGroup = "sarama-cluster-group" + testKafkaData = "/tmp/sarama-cluster-test" +) + +var ( + testKafkaRoot = "kafka_2.12-0.10.2.0" + testKafkaAddrs = []string{"127.0.0.1:29092"} + testTopics = []string{"topic-a", "topic-b"} + + testClient sarama.Client + testKafkaCmd, testZkCmd *exec.Cmd +) + +func init() { + if dir := os.Getenv("KAFKA_DIR"); dir != "" { + testKafkaRoot = dir + } +} + +var _ = Describe("offsetInfo", func() { + + It("should calculate next offset", func() { + Expect(offsetInfo{-2, ""}.NextOffset(sarama.OffsetOldest)).To(Equal(sarama.OffsetOldest)) + Expect(offsetInfo{-2, ""}.NextOffset(sarama.OffsetNewest)).To(Equal(sarama.OffsetNewest)) + Expect(offsetInfo{-1, ""}.NextOffset(sarama.OffsetOldest)).To(Equal(sarama.OffsetOldest)) + Expect(offsetInfo{-1, ""}.NextOffset(sarama.OffsetNewest)).To(Equal(sarama.OffsetNewest)) + Expect(offsetInfo{0, ""}.NextOffset(sarama.OffsetOldest)).To(Equal(int64(0))) + Expect(offsetInfo{100, ""}.NextOffset(sarama.OffsetOldest)).To(Equal(int64(100))) + }) + +}) + +var _ = Describe("int32Slice", func() { + + It("should diff", func() { + Expect(((int32Slice)(nil)).Diff(int32Slice{1, 3, 5})).To(BeNil()) + Expect(int32Slice{1, 3, 5}.Diff((int32Slice)(nil))).To(Equal([]int32{1, 3, 5})) + Expect(int32Slice{1, 3, 5}.Diff(int32Slice{1, 3, 5})).To(BeNil()) + Expect(int32Slice{1, 3, 5}.Diff(int32Slice{1, 2, 3, 4, 5})).To(BeNil()) + Expect(int32Slice{1, 3, 5}.Diff(int32Slice{2, 3, 4})).To(Equal([]int32{1, 5})) + Expect(int32Slice{1, 3, 5}.Diff(int32Slice{1, 4})).To(Equal([]int32{3, 5})) + Expect(int32Slice{1, 3, 5}.Diff(int32Slice{2, 5})).To(Equal([]int32{1, 3})) + }) + +}) + +// -------------------------------------------------------------------- + +var _ = BeforeSuite(func() { + testZkCmd = exec.Command( + testDataDir(testKafkaRoot, "bin", "kafka-run-class.sh"), + "org.apache.zookeeper.server.quorum.QuorumPeerMain", + testDataDir("zookeeper.properties"), + ) + testZkCmd.Env = []string{"KAFKA_HEAP_OPTS=-Xmx512M -Xms512M"} + if testing.Verbose() || os.Getenv("CI") != "" { + testZkCmd.Stderr = os.Stderr + testZkCmd.Stdout = os.Stdout + } + + testKafkaCmd = exec.Command( + testDataDir(testKafkaRoot, "bin", "kafka-run-class.sh"), + "-name", "kafkaServer", "kafka.Kafka", + testDataDir("server.properties"), + ) + testKafkaCmd.Env = []string{"KAFKA_HEAP_OPTS=-Xmx1G -Xms1G"} + if testing.Verbose() || os.Getenv("CI") != "" { + testKafkaCmd.Stderr = os.Stderr + testKafkaCmd.Stdout = os.Stdout + } + + Expect(os.MkdirAll(testKafkaData, 0777)).NotTo(HaveOccurred()) + Expect(testZkCmd.Start()).NotTo(HaveOccurred()) + Expect(testKafkaCmd.Start()).NotTo(HaveOccurred()) + + // Wait for client + Eventually(func() error { + var err error + + // sync-producer requires Return.Successes set to true + testConf := sarama.NewConfig() + testConf.Producer.Return.Successes = true + testClient, err = sarama.NewClient(testKafkaAddrs, testConf) + return err + }, "30s", "1s").ShouldNot(HaveOccurred()) + + // Ensure we can retrieve partition info + Eventually(func() error { + _, err := testClient.Partitions(testTopics[0]) + return err + }, "30s", "1s").ShouldNot(HaveOccurred()) + + // Seed a few messages + Expect(testSeed(1000)).NotTo(HaveOccurred()) +}) + +var _ = AfterSuite(func() { + if testClient != nil { + _ = testClient.Close() + } + + _ = testKafkaCmd.Process.Kill() + _ = testZkCmd.Process.Kill() + _ = testKafkaCmd.Wait() + _ = testZkCmd.Wait() + _ = os.RemoveAll(testKafkaData) +}) + +// -------------------------------------------------------------------- + +func TestSuite(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "sarama/cluster") +} + +func testDataDir(tokens ...string) string { + tokens = append([]string{"testdata"}, tokens...) + return filepath.Join(tokens...) +} + +// Seed messages +func testSeed(n int) error { + producer, err := sarama.NewSyncProducerFromClient(testClient) + if err != nil { + return err + } + + for i := 0; i < n; i++ { + kv := sarama.StringEncoder(fmt.Sprintf("PLAINDATA-%08d", i)) + for _, t := range testTopics { + msg := &sarama.ProducerMessage{Topic: t, Key: kv, Value: kv} + if _, _, err := producer.SendMessage(msg); err != nil { + return err + } + } + } + return producer.Close() +} + +type testConsumerMessage struct { + sarama.ConsumerMessage + ConsumerID string +} + +// -------------------------------------------------------------------- + +var _ sarama.Consumer = &mockConsumer{} +var _ sarama.PartitionConsumer = &mockPartitionConsumer{} + +type mockClient struct { + sarama.Client + + topics map[string][]int32 +} +type mockConsumer struct{ sarama.Consumer } +type mockPartitionConsumer struct { + sarama.PartitionConsumer + + Topic string + Partition int32 + Offset int64 +} + +func (m *mockClient) Partitions(t string) ([]int32, error) { + pts, ok := m.topics[t] + if !ok { + return nil, sarama.ErrInvalidTopic + } + return pts, nil +} + +func (*mockConsumer) ConsumePartition(topic string, partition int32, offset int64) (sarama.PartitionConsumer, error) { + if offset > -1 && offset < 1000 { + return nil, sarama.ErrOffsetOutOfRange + } + return &mockPartitionConsumer{ + Topic: topic, + Partition: partition, + Offset: offset, + }, nil +} + +func (*mockPartitionConsumer) Close() error { return nil } diff --git a/vendor/github.com/bsm/sarama-cluster/cmd/sarama-cluster-cli/main.go b/vendor/github.com/bsm/sarama-cluster/cmd/sarama-cluster-cli/main.go new file mode 100644 index 00000000..59a55fba --- /dev/null +++ b/vendor/github.com/bsm/sarama-cluster/cmd/sarama-cluster-cli/main.go @@ -0,0 +1,97 @@ +package main + +import ( + "flag" + "fmt" + "log" + "os" + "os/signal" + "strings" + "syscall" + + "github.com/Shopify/sarama" + "github.com/bsm/sarama-cluster" +) + +var ( + groupID = flag.String("group", "", "REQUIRED: The shared consumer group name") + brokerList = flag.String("brokers", os.Getenv("KAFKA_PEERS"), "The comma separated list of brokers in the Kafka cluster") + topicList = flag.String("topics", "", "REQUIRED: The comma separated list of topics to consume") + offset = flag.String("offset", "newest", "The offset to start with. Can be `oldest`, `newest`") + verbose = flag.Bool("verbose", false, "Whether to turn on sarama logging") + + logger = log.New(os.Stderr, "", log.LstdFlags) +) + +func main() { + flag.Parse() + + if *groupID == "" { + printUsageErrorAndExit("You have to provide a -group name.") + } else if *brokerList == "" { + printUsageErrorAndExit("You have to provide -brokers as a comma-separated list, or set the KAFKA_PEERS environment variable.") + } else if *topicList == "" { + printUsageErrorAndExit("You have to provide -topics as a comma-separated list.") + } + + // Init config + config := cluster.NewConfig() + if *verbose { + sarama.Logger = logger + } else { + config.Consumer.Return.Errors = true + config.Group.Return.Notifications = true + } + + switch *offset { + case "oldest": + config.Consumer.Offsets.Initial = sarama.OffsetOldest + case "newest": + config.Consumer.Offsets.Initial = sarama.OffsetNewest + default: + printUsageErrorAndExit("-offset should be `oldest` or `newest`") + } + + // Init consumer, consume errors & messages + consumer, err := cluster.NewConsumer(strings.Split(*brokerList, ","), *groupID, strings.Split(*topicList, ","), config) + if err != nil { + printErrorAndExit(69, "Failed to start consumer: %s", err) + } + defer consumer.Close() + + // Create signal channel + sigchan := make(chan os.Signal, 1) + signal.Notify(sigchan, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM) + + // Consume all channels, wait for signal to exit + for { + select { + case msg, more := <-consumer.Messages(): + if more { + fmt.Fprintf(os.Stdout, "%s/%d/%d\t%s\n", msg.Topic, msg.Partition, msg.Offset, msg.Value) + consumer.MarkOffset(msg, "") + } + case ntf, more := <-consumer.Notifications(): + if more { + logger.Printf("Rebalanced: %+v\n", ntf) + } + case err, more := <-consumer.Errors(): + if more { + logger.Printf("Error: %s\n", err.Error()) + } + case <-sigchan: + return + } + } +} + +func printErrorAndExit(code int, format string, values ...interface{}) { + fmt.Fprintf(os.Stderr, "ERROR: "+format+"\n\n", values...) + os.Exit(code) +} + +func printUsageErrorAndExit(format string, values ...interface{}) { + fmt.Fprintf(os.Stderr, "ERROR: "+format+"\n\n", values...) + flag.Usage() + os.Exit(64) +} diff --git a/vendor/github.com/bsm/sarama-cluster/config.go b/vendor/github.com/bsm/sarama-cluster/config.go new file mode 100644 index 00000000..0208a880 --- /dev/null +++ b/vendor/github.com/bsm/sarama-cluster/config.go @@ -0,0 +1,112 @@ +package cluster + +import ( + "regexp" + "time" + + "github.com/Shopify/sarama" +) + +var minVersion = sarama.V0_9_0_0 + +// Config extends sarama.Config with Group specific namespace +type Config struct { + sarama.Config + + // Group is the namespace for group management properties + Group struct { + // The strategy to use for the allocation of partitions to consumers (defaults to StrategyRange) + PartitionStrategy Strategy + Offsets struct { + Retry struct { + // The numer retries when committing offsets (defaults to 3). + Max int + } + } + Session struct { + // The allowed session timeout for registered consumers (defaults to 30s). + // Must be within the allowed server range. + Timeout time.Duration + } + Heartbeat struct { + // Interval between each heartbeat (defaults to 3s). It should be no more + // than 1/3rd of the Group.Session.Timout setting + Interval time.Duration + } + // Return specifies which group channels will be populated. If they are set to true, + // you must read from the respective channels to prevent deadlock. + Return struct { + // If enabled, rebalance notification will be returned on the + // Notifications channel (default disabled). + Notifications bool + } + + Topics struct { + // An additional whitelist of topics to subscribe to. + Whitelist *regexp.Regexp + // An additional blacklist of topics to avoid. If set, this will precede over + // the Whitelist setting. + Blacklist *regexp.Regexp + } + + Member struct { + // Custom metadata to include when joining the group. The user data for all joined members + // can be retrieved by sending a DescribeGroupRequest to the broker that is the + // coordinator for the group. + UserData []byte + } + } +} + +// NewConfig returns a new configuration instance with sane defaults. +func NewConfig() *Config { + c := &Config{ + Config: *sarama.NewConfig(), + } + c.Group.PartitionStrategy = StrategyRange + c.Group.Offsets.Retry.Max = 3 + c.Group.Session.Timeout = 30 * time.Second + c.Group.Heartbeat.Interval = 3 * time.Second + c.Config.Version = minVersion + return c +} + +// Validate checks a Config instance. It will return a +// sarama.ConfigurationError if the specified values don't make sense. +func (c *Config) Validate() error { + if c.Group.Heartbeat.Interval%time.Millisecond != 0 { + sarama.Logger.Println("Group.Heartbeat.Interval only supports millisecond precision; nanoseconds will be truncated.") + } + if c.Group.Session.Timeout%time.Millisecond != 0 { + sarama.Logger.Println("Group.Session.Timeout only supports millisecond precision; nanoseconds will be truncated.") + } + if c.Group.PartitionStrategy != StrategyRange && c.Group.PartitionStrategy != StrategyRoundRobin { + sarama.Logger.Println("Group.PartitionStrategy is not supported; range will be assumed.") + } + if !c.Version.IsAtLeast(minVersion) { + sarama.Logger.Println("Version is not supported; 0.9. will be assumed.") + c.Version = minVersion + } + if err := c.Config.Validate(); err != nil { + return err + } + + // validate the Group values + switch { + case c.Group.Offsets.Retry.Max < 0: + return sarama.ConfigurationError("Group.Offsets.Retry.Max must be >= 0") + case c.Group.Heartbeat.Interval <= 0: + return sarama.ConfigurationError("Group.Heartbeat.Interval must be > 0") + case c.Group.Session.Timeout <= 0: + return sarama.ConfigurationError("Group.Session.Timeout must be > 0") + } + + // ensure offset is correct + switch c.Consumer.Offsets.Initial { + case sarama.OffsetOldest, sarama.OffsetNewest: + default: + return sarama.ConfigurationError("Consumer.Offsets.Initial must be either OffsetOldest or OffsetNewest") + } + + return nil +} diff --git a/vendor/github.com/bsm/sarama-cluster/config_test.go b/vendor/github.com/bsm/sarama-cluster/config_test.go new file mode 100644 index 00000000..558cd90f --- /dev/null +++ b/vendor/github.com/bsm/sarama-cluster/config_test.go @@ -0,0 +1,25 @@ +package cluster + +import ( + "time" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Config", func() { + var subject *Config + + BeforeEach(func() { + subject = NewConfig() + }) + + It("should init", func() { + Expect(subject.Group.Session.Timeout).To(Equal(30 * time.Second)) + Expect(subject.Group.Heartbeat.Interval).To(Equal(3 * time.Second)) + Expect(subject.Group.Return.Notifications).To(BeFalse()) + Expect(subject.Metadata.Retry.Max).To(Equal(3)) + // Expect(subject.Config.Version).To(Equal(sarama.V0_9_0_0)) + }) + +}) diff --git a/vendor/github.com/bsm/sarama-cluster/consumer.go b/vendor/github.com/bsm/sarama-cluster/consumer.go new file mode 100644 index 00000000..97dd47fc --- /dev/null +++ b/vendor/github.com/bsm/sarama-cluster/consumer.go @@ -0,0 +1,787 @@ +package cluster + +import ( + "sort" + "sync" + "sync/atomic" + "time" + + "github.com/Shopify/sarama" +) + +// Consumer is a cluster group consumer +type Consumer struct { + client *Client + + csmr sarama.Consumer + subs *partitionMap + + consumerID string + generationID int32 + groupID string + memberID string + + coreTopics []string + extraTopics []string + + dying, dead chan none + + consuming int32 + errors chan error + messages chan *sarama.ConsumerMessage + notifications chan *Notification + + commitMu sync.Mutex +} + +// NewConsumerFromClient initializes a new consumer from an existing client +func NewConsumerFromClient(client *Client, groupID string, topics []string) (*Consumer, error) { + csmr, err := sarama.NewConsumerFromClient(client.Client) + if err != nil { + return nil, err + } + + sort.Strings(topics) + c := &Consumer{ + client: client, + csmr: csmr, + subs: newPartitionMap(), + groupID: groupID, + + coreTopics: topics, + + dying: make(chan none), + dead: make(chan none), + + errors: make(chan error, client.config.ChannelBufferSize), + messages: make(chan *sarama.ConsumerMessage), + notifications: make(chan *Notification, 1), + } + if err := c.client.RefreshCoordinator(groupID); err != nil { + return nil, err + } + + go c.mainLoop() + return c, nil +} + +// NewConsumer initializes a new consumer +func NewConsumer(addrs []string, groupID string, topics []string, config *Config) (*Consumer, error) { + client, err := NewClient(addrs, config) + if err != nil { + return nil, err + } + + consumer, err := NewConsumerFromClient(client, groupID, topics) + if err != nil { + _ = client.Close() + return nil, err + } + consumer.client.own = true + return consumer, nil +} + +// Messages returns the read channel for the messages that are returned by +// the broker. +func (c *Consumer) Messages() <-chan *sarama.ConsumerMessage { return c.messages } + +// Errors returns a read channel of errors that occur during offset management, if +// enabled. By default, errors are logged and not returned over this channel. If +// you want to implement any custom error handling, set your config's +// Consumer.Return.Errors setting to true, and read from this channel. +func (c *Consumer) Errors() <-chan error { return c.errors } + +// Notifications returns a channel of Notifications that occur during consumer +// rebalancing. Notifications will only be emitted over this channel, if your config's +// Group.Return.Notifications setting to true. +func (c *Consumer) Notifications() <-chan *Notification { return c.notifications } + +// HighWaterMarks returns the current high water marks for each topic and partition +// Consistency between partitions is not guaranteed since high water marks are updated separately. +func (c *Consumer) HighWaterMarks() map[string]map[int32]int64 { return c.csmr.HighWaterMarks() } + +// MarkOffset marks the provided message as processed, alongside a metadata string +// that represents the state of the partition consumer at that point in time. The +// metadata string can be used by another consumer to restore that state, so it +// can resume consumption. +// +// Note: calling MarkOffset does not necessarily commit the offset to the backend +// store immediately for efficiency reasons, and it may never be committed if +// your application crashes. This means that you may end up processing the same +// message twice, and your processing should ideally be idempotent. +func (c *Consumer) MarkOffset(msg *sarama.ConsumerMessage, metadata string) { + c.subs.Fetch(msg.Topic, msg.Partition).MarkOffset(msg.Offset+1, metadata) +} + +// MarkPartitionOffset marks an offset of the provided topic/partition as processed. +// See MarkOffset for additional explanation. +func (c *Consumer) MarkPartitionOffset(topic string, partition int32, offset int64, metadata string) { + c.subs.Fetch(topic, partition).MarkOffset(offset+1, metadata) +} + +// MarkOffsets marks stashed offsets as processed. +// See MarkOffset for additional explanation. +func (c *Consumer) MarkOffsets(s *OffsetStash) { + s.mu.Lock() + defer s.mu.Unlock() + + for tp, info := range s.offsets { + c.subs.Fetch(tp.Topic, tp.Partition).MarkOffset(info.Offset+1, info.Metadata) + delete(s.offsets, tp) + } +} + +// Subscriptions returns the consumed topics and partitions +func (c *Consumer) Subscriptions() map[string][]int32 { + return c.subs.Info() +} + +// CommitOffsets manually commits marked offsets. +func (c *Consumer) CommitOffsets() error { + c.commitMu.Lock() + defer c.commitMu.Unlock() + + req := &sarama.OffsetCommitRequest{ + Version: 2, + ConsumerGroup: c.groupID, + ConsumerGroupGeneration: c.generationID, + ConsumerID: c.memberID, + RetentionTime: -1, + } + + if ns := c.client.config.Consumer.Offsets.Retention; ns != 0 { + req.RetentionTime = int64(ns / time.Millisecond) + } + + snap := c.subs.Snapshot() + dirty := false + for tp, state := range snap { + if state.Dirty { + dirty = true + req.AddBlock(tp.Topic, tp.Partition, state.Info.Offset, 0, state.Info.Metadata) + } + } + if !dirty { + return nil + } + + broker, err := c.client.Coordinator(c.groupID) + if err != nil { + c.closeCoordinator(broker, err) + return err + } + + resp, err := broker.CommitOffset(req) + if err != nil { + c.closeCoordinator(broker, err) + return err + } + + for topic, errs := range resp.Errors { + for partition, kerr := range errs { + if kerr != sarama.ErrNoError { + err = kerr + } else if state, ok := snap[topicPartition{topic, partition}]; ok { + c.subs.Fetch(topic, partition).MarkCommitted(state.Info.Offset) + } + } + } + return err +} + +// Close safely closes the consumer and releases all resources +func (c *Consumer) Close() (err error) { + select { + case <-c.dying: + return + default: + close(c.dying) + } + <-c.dead + + if e := c.release(); e != nil { + err = e + } + if e := c.csmr.Close(); e != nil { + err = e + } + close(c.messages) + close(c.errors) + + if e := c.leaveGroup(); e != nil { + err = e + } + close(c.notifications) + + if c.client.own { + if e := c.client.Close(); e != nil { + err = e + } + } + + return +} + +func (c *Consumer) mainLoop() { + defer close(c.dead) + defer atomic.StoreInt32(&c.consuming, 0) + + for { + atomic.StoreInt32(&c.consuming, 0) + + // Check if close was requested + select { + case <-c.dying: + return + default: + } + + // Remember previous subscriptions + var notification *Notification + if c.client.config.Group.Return.Notifications { + notification = newNotification(c.subs.Info()) + } + + // Rebalance, fetch new subscriptions + subs, err := c.rebalance() + if err != nil { + c.rebalanceError(err, notification) + continue + } + + // Start the heartbeat + hbStop, hbDone := make(chan none), make(chan none) + go c.hbLoop(hbStop, hbDone) + + // Subscribe to topic/partitions + if err := c.subscribe(subs); err != nil { + close(hbStop) + <-hbDone + c.rebalanceError(err, notification) + continue + } + + // Start topic watcher loop + twStop, twDone := make(chan none), make(chan none) + go c.twLoop(twStop, twDone) + + // Start consuming and committing offsets + cmStop, cmDone := make(chan none), make(chan none) + go c.cmLoop(cmStop, cmDone) + atomic.StoreInt32(&c.consuming, 1) + + // Update notification with new claims + if c.client.config.Group.Return.Notifications { + notification.claim(subs) + c.notifications <- notification + } + + // Wait for signals + select { + case <-hbDone: + close(cmStop) + close(twStop) + <-cmDone + <-twDone + case <-twDone: + close(cmStop) + close(hbStop) + <-cmDone + <-hbDone + case <-cmDone: + close(twStop) + close(hbStop) + <-twDone + <-hbDone + case <-c.dying: + close(cmStop) + close(twStop) + close(hbStop) + <-cmDone + <-twDone + <-hbDone + return + } + } +} + +// heartbeat loop, triggered by the mainLoop +func (c *Consumer) hbLoop(stop <-chan none, done chan<- none) { + defer close(done) + + ticker := time.NewTicker(c.client.config.Group.Heartbeat.Interval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + switch err := c.heartbeat(); err { + case nil, sarama.ErrNoError: + case sarama.ErrNotCoordinatorForConsumer, sarama.ErrRebalanceInProgress: + return + default: + c.handleError(&Error{Ctx: "heartbeat", error: err}) + return + } + case <-stop: + return + } + } +} + +// topic watcher loop, triggered by the mainLoop +func (c *Consumer) twLoop(stop <-chan none, done chan<- none) { + defer close(done) + + ticker := time.NewTicker(c.client.config.Metadata.RefreshFrequency / 2) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + topics, err := c.client.Topics() + if err != nil { + c.handleError(&Error{Ctx: "topics", error: err}) + return + } + + for _, topic := range topics { + if !c.isKnownCoreTopic(topic) && + !c.isKnownExtraTopic(topic) && + c.isPotentialExtraTopic(topic) { + return + } + } + case <-stop: + return + } + } +} + +// commit loop, triggered by the mainLoop +func (c *Consumer) cmLoop(stop <-chan none, done chan<- none) { + defer close(done) + + ticker := time.NewTicker(c.client.config.Consumer.Offsets.CommitInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if err := c.commitOffsetsWithRetry(c.client.config.Group.Offsets.Retry.Max); err != nil { + c.handleError(&Error{Ctx: "commit", error: err}) + return + } + case <-stop: + return + } + } +} + +func (c *Consumer) rebalanceError(err error, notification *Notification) { + if c.client.config.Group.Return.Notifications { + c.notifications <- notification + } + + switch err { + case sarama.ErrRebalanceInProgress: + default: + c.handleError(&Error{Ctx: "rebalance", error: err}) + } + + select { + case <-c.dying: + case <-time.After(c.client.config.Metadata.Retry.Backoff): + } +} + +func (c *Consumer) handleError(e *Error) { + if c.client.config.Consumer.Return.Errors { + select { + case c.errors <- e: + case <-c.dying: + return + } + } else { + sarama.Logger.Printf("%s error: %s\n", e.Ctx, e.Error()) + } +} + +// Releases the consumer and commits offsets, called from rebalance() and Close() +func (c *Consumer) release() (err error) { + // Stop all consumers + c.subs.Stop() + + // Clear subscriptions on exit + defer c.subs.Clear() + + // Wait for messages to be processed + time.Sleep(c.client.config.Consumer.MaxProcessingTime) + + // Commit offsets, continue on errors + if e := c.commitOffsetsWithRetry(c.client.config.Group.Offsets.Retry.Max); e != nil { + err = e + } + + return +} + +// -------------------------------------------------------------------- + +// Performs a heartbeat, part of the mainLoop() +func (c *Consumer) heartbeat() error { + broker, err := c.client.Coordinator(c.groupID) + if err != nil { + c.closeCoordinator(broker, err) + return err + } + + resp, err := broker.Heartbeat(&sarama.HeartbeatRequest{ + GroupId: c.groupID, + MemberId: c.memberID, + GenerationId: c.generationID, + }) + if err != nil { + c.closeCoordinator(broker, err) + return err + } + return resp.Err +} + +// Performs a rebalance, part of the mainLoop() +func (c *Consumer) rebalance() (map[string][]int32, error) { + sarama.Logger.Printf("cluster/consumer %s rebalance\n", c.memberID) + + if err := c.refreshMetadata(); err != nil { + return nil, err + } + + if err := c.client.RefreshCoordinator(c.groupID); err != nil { + return nil, err + } + + allTopics, err := c.client.Topics() + if err != nil { + return nil, err + } + c.extraTopics = c.selectExtraTopics(allTopics) + sort.Strings(c.extraTopics) + + // Release subscriptions + if err := c.release(); err != nil { + return nil, err + } + + // Re-join consumer group + strategy, err := c.joinGroup() + switch { + case err == sarama.ErrUnknownMemberId: + c.memberID = "" + return nil, err + case err != nil: + return nil, err + } + // sarama.Logger.Printf("cluster/consumer %s/%d joined group %s\n", c.memberID, c.generationID, c.groupID) + + // Sync consumer group state, fetch subscriptions + subs, err := c.syncGroup(strategy) + switch { + case err == sarama.ErrRebalanceInProgress: + return nil, err + case err != nil: + _ = c.leaveGroup() + return nil, err + } + return subs, nil +} + +// Performs the subscription, part of the mainLoop() +func (c *Consumer) subscribe(subs map[string][]int32) error { + // fetch offsets + offsets, err := c.fetchOffsets(subs) + if err != nil { + _ = c.leaveGroup() + return err + } + + // create consumers in parallel + var mu sync.Mutex + var wg sync.WaitGroup + + for topic, partitions := range subs { + for _, partition := range partitions { + wg.Add(1) + + info := offsets[topic][partition] + go func(t string, p int32) { + if e := c.createConsumer(t, p, info); e != nil { + mu.Lock() + err = e + mu.Unlock() + } + wg.Done() + }(topic, partition) + } + } + wg.Wait() + + if err != nil { + _ = c.release() + _ = c.leaveGroup() + } + return err +} + +// -------------------------------------------------------------------- + +// Send a request to the broker to join group on rebalance() +func (c *Consumer) joinGroup() (*balancer, error) { + req := &sarama.JoinGroupRequest{ + GroupId: c.groupID, + MemberId: c.memberID, + SessionTimeout: int32(c.client.config.Group.Session.Timeout / time.Millisecond), + ProtocolType: "consumer", + } + + meta := &sarama.ConsumerGroupMemberMetadata{ + Version: 1, + Topics: append(c.coreTopics, c.extraTopics...), + UserData: c.client.config.Group.Member.UserData, + } + err := req.AddGroupProtocolMetadata(string(StrategyRange), meta) + if err != nil { + return nil, err + } + err = req.AddGroupProtocolMetadata(string(StrategyRoundRobin), meta) + if err != nil { + return nil, err + } + + broker, err := c.client.Coordinator(c.groupID) + if err != nil { + c.closeCoordinator(broker, err) + return nil, err + } + + resp, err := broker.JoinGroup(req) + if err != nil { + c.closeCoordinator(broker, err) + return nil, err + } else if resp.Err != sarama.ErrNoError { + c.closeCoordinator(broker, resp.Err) + return nil, resp.Err + } + + var strategy *balancer + if resp.LeaderId == resp.MemberId { + members, err := resp.GetMembers() + if err != nil { + return nil, err + } + + strategy, err = newBalancerFromMeta(c.client, members) + if err != nil { + return nil, err + } + } + + c.memberID = resp.MemberId + c.generationID = resp.GenerationId + + return strategy, nil +} + +// Send a request to the broker to sync the group on rebalance(). +// Returns a list of topics and partitions to consume. +func (c *Consumer) syncGroup(strategy *balancer) (map[string][]int32, error) { + req := &sarama.SyncGroupRequest{ + GroupId: c.groupID, + MemberId: c.memberID, + GenerationId: c.generationID, + } + + for memberID, topics := range strategy.Perform(c.client.config.Group.PartitionStrategy) { + if err := req.AddGroupAssignmentMember(memberID, &sarama.ConsumerGroupMemberAssignment{ + Version: 1, + Topics: topics, + }); err != nil { + return nil, err + } + } + + broker, err := c.client.Coordinator(c.groupID) + if err != nil { + c.closeCoordinator(broker, err) + return nil, err + } + + resp, err := broker.SyncGroup(req) + if err != nil { + c.closeCoordinator(broker, err) + return nil, err + } else if resp.Err != sarama.ErrNoError { + c.closeCoordinator(broker, resp.Err) + return nil, resp.Err + } + + // Return if there is nothing to subscribe to + if len(resp.MemberAssignment) == 0 { + return nil, nil + } + + // Get assigned subscriptions + members, err := resp.GetMemberAssignment() + if err != nil { + return nil, err + } + + // Sort partitions, for each topic + for topic := range members.Topics { + sort.Sort(int32Slice(members.Topics[topic])) + } + return members.Topics, nil +} + +// Fetches latest committed offsets for all subscriptions +func (c *Consumer) fetchOffsets(subs map[string][]int32) (map[string]map[int32]offsetInfo, error) { + offsets := make(map[string]map[int32]offsetInfo, len(subs)) + req := &sarama.OffsetFetchRequest{ + Version: 1, + ConsumerGroup: c.groupID, + } + + for topic, partitions := range subs { + offsets[topic] = make(map[int32]offsetInfo, len(partitions)) + for _, partition := range partitions { + offsets[topic][partition] = offsetInfo{Offset: -1} + req.AddPartition(topic, partition) + } + } + + // Wait for other cluster consumers to process, release and commit + time.Sleep(c.client.config.Consumer.MaxProcessingTime * 2) + + broker, err := c.client.Coordinator(c.groupID) + if err != nil { + c.closeCoordinator(broker, err) + return nil, err + } + + resp, err := broker.FetchOffset(req) + if err != nil { + c.closeCoordinator(broker, err) + return nil, err + } + + for topic, partitions := range subs { + for _, partition := range partitions { + block := resp.GetBlock(topic, partition) + if block == nil { + return nil, sarama.ErrIncompleteResponse + } + + if block.Err == sarama.ErrNoError { + offsets[topic][partition] = offsetInfo{Offset: block.Offset, Metadata: block.Metadata} + } else { + return nil, block.Err + } + } + } + return offsets, nil +} + +// Send a request to the broker to leave the group on failes rebalance() and on Close() +func (c *Consumer) leaveGroup() error { + broker, err := c.client.Coordinator(c.groupID) + if err != nil { + c.closeCoordinator(broker, err) + return err + } + + if _, err = broker.LeaveGroup(&sarama.LeaveGroupRequest{ + GroupId: c.groupID, + MemberId: c.memberID, + }); err != nil { + c.closeCoordinator(broker, err) + } + return err +} + +// -------------------------------------------------------------------- + +func (c *Consumer) createConsumer(topic string, partition int32, info offsetInfo) error { + sarama.Logger.Printf("cluster/consumer %s consume %s/%d from %d\n", c.memberID, topic, partition, info.NextOffset(c.client.config.Consumer.Offsets.Initial)) + + // Create partitionConsumer + pc, err := newPartitionConsumer(c.csmr, topic, partition, info, c.client.config.Consumer.Offsets.Initial) + if err != nil { + return err + } + + // Store in subscriptions + c.subs.Store(topic, partition, pc) + + // Start partition consumer goroutine + go pc.Loop(c.messages, c.errors) + + return nil +} + +func (c *Consumer) commitOffsetsWithRetry(retries int) error { + err := c.CommitOffsets() + if err != nil && retries > 0 { + return c.commitOffsetsWithRetry(retries - 1) + } + return err +} + +func (c *Consumer) closeCoordinator(broker *sarama.Broker, err error) { + if broker != nil { + _ = broker.Close() + } + + switch err { + case sarama.ErrConsumerCoordinatorNotAvailable, sarama.ErrNotCoordinatorForConsumer: + _ = c.client.RefreshCoordinator(c.groupID) + } +} + +func (c *Consumer) selectExtraTopics(allTopics []string) []string { + extra := allTopics[:0] + for _, topic := range allTopics { + if !c.isKnownCoreTopic(topic) && c.isPotentialExtraTopic(topic) { + extra = append(extra, topic) + } + } + return extra +} + +func (c *Consumer) isKnownCoreTopic(topic string) bool { + pos := sort.SearchStrings(c.coreTopics, topic) + return pos < len(c.coreTopics) && c.coreTopics[pos] == topic +} + +func (c *Consumer) isKnownExtraTopic(topic string) bool { + pos := sort.SearchStrings(c.extraTopics, topic) + return pos < len(c.extraTopics) && c.extraTopics[pos] == topic +} + +func (c *Consumer) isPotentialExtraTopic(topic string) bool { + rx := c.client.config.Group.Topics + if rx.Blacklist != nil && rx.Blacklist.MatchString(topic) { + return false + } + if rx.Whitelist != nil && rx.Whitelist.MatchString(topic) { + return true + } + return false +} + +func (c *Consumer) refreshMetadata() error { + err := c.client.RefreshMetadata() + if err == sarama.ErrTopicAuthorizationFailed { + // maybe we didn't have authorization to describe all topics + err = c.client.RefreshMetadata(c.coreTopics...) + } + return err +} diff --git a/vendor/github.com/bsm/sarama-cluster/consumer_test.go b/vendor/github.com/bsm/sarama-cluster/consumer_test.go new file mode 100644 index 00000000..cbadcc84 --- /dev/null +++ b/vendor/github.com/bsm/sarama-cluster/consumer_test.go @@ -0,0 +1,214 @@ +package cluster + +import ( + "fmt" + "regexp" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Consumer", func() { + + var newConsumer = func(group string) (*Consumer, error) { + config := NewConfig() + config.Consumer.Return.Errors = true + return NewConsumer(testKafkaAddrs, group, testTopics, config) + } + + var newConsumerOf = func(group string, topics ...string) (*Consumer, error) { + config := NewConfig() + config.Consumer.Return.Errors = true + return NewConsumer(testKafkaAddrs, group, topics, config) + } + + var subscriptionsOf = func(c *Consumer) GomegaAsyncAssertion { + return Eventually(func() map[string][]int32 { + return c.Subscriptions() + }, "10s", "100ms") + } + + var consume = func(consumerID, group string, max int, out chan *testConsumerMessage) { + go func() { + defer GinkgoRecover() + + cs, err := newConsumer(group) + Expect(err).NotTo(HaveOccurred()) + defer cs.Close() + cs.consumerID = consumerID + + for msg := range cs.Messages() { + out <- &testConsumerMessage{*msg, consumerID} + cs.MarkOffset(msg, "") + + if max--; max == 0 { + return + } + } + }() + } + + It("should init and share", func() { + // start CS1 + cs1, err := newConsumer(testGroup) + Expect(err).NotTo(HaveOccurred()) + + // CS1 should consume all 8 partitions + subscriptionsOf(cs1).Should(Equal(map[string][]int32{ + "topic-a": {0, 1, 2, 3}, + "topic-b": {0, 1, 2, 3}, + })) + + // start CS2 + cs2, err := newConsumer(testGroup) + Expect(err).NotTo(HaveOccurred()) + defer cs2.Close() + + // CS1 and CS2 should consume 4 partitions each + subscriptionsOf(cs1).Should(HaveLen(2)) + subscriptionsOf(cs1).Should(HaveKeyWithValue("topic-a", HaveLen(2))) + subscriptionsOf(cs1).Should(HaveKeyWithValue("topic-b", HaveLen(2))) + + subscriptionsOf(cs2).Should(HaveLen(2)) + subscriptionsOf(cs2).Should(HaveKeyWithValue("topic-a", HaveLen(2))) + subscriptionsOf(cs2).Should(HaveKeyWithValue("topic-b", HaveLen(2))) + + // shutdown CS1, now CS2 should consume all 8 partitions + Expect(cs1.Close()).NotTo(HaveOccurred()) + subscriptionsOf(cs2).Should(Equal(map[string][]int32{ + "topic-a": {0, 1, 2, 3}, + "topic-b": {0, 1, 2, 3}, + })) + }) + + It("should allow more consumers than partitions", func() { + cs1, err := newConsumerOf(testGroup, "topic-a") + Expect(err).NotTo(HaveOccurred()) + defer cs1.Close() + cs2, err := newConsumerOf(testGroup, "topic-a") + Expect(err).NotTo(HaveOccurred()) + defer cs2.Close() + cs3, err := newConsumerOf(testGroup, "topic-a") + Expect(err).NotTo(HaveOccurred()) + defer cs3.Close() + cs4, err := newConsumerOf(testGroup, "topic-a") + Expect(err).NotTo(HaveOccurred()) + + // start 4 consumers, one for each partition + subscriptionsOf(cs1).Should(HaveKeyWithValue("topic-a", HaveLen(1))) + subscriptionsOf(cs2).Should(HaveKeyWithValue("topic-a", HaveLen(1))) + subscriptionsOf(cs3).Should(HaveKeyWithValue("topic-a", HaveLen(1))) + subscriptionsOf(cs4).Should(HaveKeyWithValue("topic-a", HaveLen(1))) + + // add a 5th consumer + cs5, err := newConsumerOf(testGroup, "topic-a") + Expect(err).NotTo(HaveOccurred()) + defer cs5.Close() + + // make sure no errors occurred + Expect(cs1.Errors()).ShouldNot(Receive()) + Expect(cs2.Errors()).ShouldNot(Receive()) + Expect(cs3.Errors()).ShouldNot(Receive()) + Expect(cs4.Errors()).ShouldNot(Receive()) + Expect(cs5.Errors()).ShouldNot(Receive()) + + // close 4th, make sure the 5th takes over + Expect(cs4.Close()).To(Succeed()) + subscriptionsOf(cs1).Should(HaveKeyWithValue("topic-a", HaveLen(1))) + subscriptionsOf(cs2).Should(HaveKeyWithValue("topic-a", HaveLen(1))) + subscriptionsOf(cs3).Should(HaveKeyWithValue("topic-a", HaveLen(1))) + subscriptionsOf(cs4).Should(BeEmpty()) + subscriptionsOf(cs5).Should(HaveKeyWithValue("topic-a", HaveLen(1))) + + // there should still be no errors + Expect(cs1.Errors()).ShouldNot(Receive()) + Expect(cs2.Errors()).ShouldNot(Receive()) + Expect(cs3.Errors()).ShouldNot(Receive()) + Expect(cs4.Errors()).ShouldNot(Receive()) + Expect(cs5.Errors()).ShouldNot(Receive()) + }) + + It("should be allowed to subscribe to partitions via white/black-lists", func() { + config := NewConfig() + config.Consumer.Return.Errors = true + config.Group.Topics.Whitelist = regexp.MustCompile(`topic-\w+`) + config.Group.Topics.Blacklist = regexp.MustCompile(`[bcd]$`) + + cs, err := NewConsumer(testKafkaAddrs, testGroup, nil, config) + Expect(err).NotTo(HaveOccurred()) + defer cs.Close() + + subscriptionsOf(cs).Should(Equal(map[string][]int32{ + "topic-a": {0, 1, 2, 3}, + })) + }) + + It("should support manual mark/commit", func() { + cs, err := newConsumerOf(testGroup, "topic-a") + Expect(err).NotTo(HaveOccurred()) + defer cs.Close() + + subscriptionsOf(cs).Should(Equal(map[string][]int32{ + "topic-a": {0, 1, 2, 3}}, + )) + + cs.MarkPartitionOffset("topic-a", 1, 3, "") + cs.MarkPartitionOffset("topic-a", 2, 4, "") + Expect(cs.CommitOffsets()).NotTo(HaveOccurred()) + + offsets, err := cs.fetchOffsets(cs.Subscriptions()) + Expect(err).NotTo(HaveOccurred()) + Expect(offsets).To(Equal(map[string]map[int32]offsetInfo{ + "topic-a": {0: {Offset: -1}, 1: {Offset: 4}, 2: {Offset: 5}, 3: {Offset: -1}}, + })) + }) + + It("should consume/commit/resume", func() { + acc := make(chan *testConsumerMessage, 150000) + consume("A", "fuzzing", 1500, acc) + consume("B", "fuzzing", 2000, acc) + consume("C", "fuzzing", 1500, acc) + consume("D", "fuzzing", 200, acc) + consume("E", "fuzzing", 100, acc) + + Expect(testSeed(5000)).NotTo(HaveOccurred()) + Eventually(func() int { return len(acc) }, "30s", "100ms").Should(BeNumerically(">=", 5000)) + + consume("F", "fuzzing", 300, acc) + consume("G", "fuzzing", 400, acc) + consume("H", "fuzzing", 1000, acc) + consume("I", "fuzzing", 2000, acc) + Expect(testSeed(5000)).NotTo(HaveOccurred()) + Eventually(func() int { return len(acc) }, "30s", "100ms").Should(BeNumerically(">=", 8000)) + + consume("J", "fuzzing", 1000, acc) + Expect(testSeed(5000)).NotTo(HaveOccurred()) + Eventually(func() int { return len(acc) }, "30s", "100ms").Should(BeNumerically(">=", 9000)) + + consume("K", "fuzzing", 1000, acc) + consume("L", "fuzzing", 3000, acc) + Expect(testSeed(5000)).NotTo(HaveOccurred()) + Eventually(func() int { return len(acc) }, "30s", "100ms").Should(BeNumerically(">=", 12000)) + + consume("M", "fuzzing", 1000, acc) + Expect(testSeed(5000)).NotTo(HaveOccurred()) + Eventually(func() int { return len(acc) }, "30s", "100ms").Should(BeNumerically(">=", 15000)) + + close(acc) + + uniques := make(map[string][]string) + for msg := range acc { + key := fmt.Sprintf("%s/%d/%d", msg.Topic, msg.Partition, msg.Offset) + uniques[key] = append(uniques[key], msg.ConsumerID) + } + Expect(uniques).To(HaveLen(15000)) + }) + + It("should allow close to be called multiple times", func() { + cs, err := newConsumer(testGroup) + Expect(err).NotTo(HaveOccurred()) + Expect(cs.Close()).NotTo(HaveOccurred()) + Expect(cs.Close()).NotTo(HaveOccurred()) + }) + +}) diff --git a/vendor/github.com/bsm/sarama-cluster/doc.go b/vendor/github.com/bsm/sarama-cluster/doc.go new file mode 100644 index 00000000..9c8ff16a --- /dev/null +++ b/vendor/github.com/bsm/sarama-cluster/doc.go @@ -0,0 +1,8 @@ +/* +Package cluster provides cluster extensions for Sarama, enabing users +to consume topics across from multiple, balanced nodes. + +It requires Kafka v0.9+ and follows the steps guide, described in: +https://cwiki.apache.org/confluence/display/KAFKA/Kafka+0.9+Consumer+Rewrite+Design +*/ +package cluster diff --git a/vendor/github.com/bsm/sarama-cluster/examples_test.go b/vendor/github.com/bsm/sarama-cluster/examples_test.go new file mode 100644 index 00000000..13c30f91 --- /dev/null +++ b/vendor/github.com/bsm/sarama-cluster/examples_test.go @@ -0,0 +1,54 @@ +package cluster_test + +import ( + "fmt" + "log" + "os" + "os/signal" + + cluster "github.com/bsm/sarama-cluster" +) + +// This example shows how to use the consumer can read messages +// from a multiple topics. +func ExampleConsumer() { + + // init (custom) config, enable errors and notifications + config := cluster.NewConfig() + config.Consumer.Return.Errors = true + config.Group.Return.Notifications = true + + // init consumer + brokers := []string{"127.0.0.1:9092"} + topics := []string{"my_topic", "other_topic"} + consumer, err := cluster.NewConsumer(brokers, "my-consumer-group", topics, config) + if err != nil { + panic(err) + } + defer consumer.Close() + + // trap SIGINT to trigger a shutdown. + signals := make(chan os.Signal, 1) + signal.Notify(signals, os.Interrupt) + + // consume messages, watch errors and notifications + for { + select { + case msg, more := <-consumer.Messages(): + if more { + fmt.Fprintf(os.Stdout, "%s/%d/%d\t%s\t%s\n", msg.Topic, msg.Partition, msg.Offset, msg.Key, msg.Value) + consumer.MarkOffset(msg, "") // mark message as processed + } + case err, more := <-consumer.Errors(): + if more { + log.Printf("Error: %s\n", err.Error()) + } + case ntf, more := <-consumer.Notifications(): + if more { + log.Printf("Rebalanced: %+v\n", ntf) + } + case <-signals: + return + } + } +} diff --git a/vendor/github.com/bsm/sarama-cluster/glide.yaml b/vendor/github.com/bsm/sarama-cluster/glide.yaml new file mode 100644 index 00000000..2bca63d2 --- /dev/null +++ b/vendor/github.com/bsm/sarama-cluster/glide.yaml @@ -0,0 +1,4 @@ +package: github.com/bsm/sarama-cluster +import: +- package: github.com/Shopify/sarama + version: ^1.9.0 diff --git a/vendor/github.com/bsm/sarama-cluster/offsets.go b/vendor/github.com/bsm/sarama-cluster/offsets.go new file mode 100644 index 00000000..b2abe355 --- /dev/null +++ b/vendor/github.com/bsm/sarama-cluster/offsets.go @@ -0,0 +1,49 @@ +package cluster + +import ( + "sync" + + "github.com/Shopify/sarama" +) + +// OffsetStash allows to accumulate offsets and +// mark them as processed in a bulk +type OffsetStash struct { + offsets map[topicPartition]offsetInfo + mu sync.Mutex +} + +// NewOffsetStash inits a blank stash +func NewOffsetStash() *OffsetStash { + return &OffsetStash{offsets: make(map[topicPartition]offsetInfo)} +} + +// MarkOffset stashes the provided message offset +func (s *OffsetStash) MarkOffset(msg *sarama.ConsumerMessage, metadata string) { + s.MarkPartitionOffset(msg.Topic, msg.Partition, msg.Offset, metadata) +} + +// MarkPartitionOffset stashes the offset for the provided topic/partition combination +func (s *OffsetStash) MarkPartitionOffset(topic string, partition int32, offset int64, metadata string) { + s.mu.Lock() + defer s.mu.Unlock() + + key := topicPartition{Topic: topic, Partition: partition} + if info := s.offsets[key]; offset >= info.Offset { + info.Offset = offset + info.Metadata = metadata + s.offsets[key] = info + } +} + +// Offsets returns the latest stashed offsets by topic-partition +func (s *OffsetStash) Offsets() map[string]int64 { + s.mu.Lock() + defer s.mu.Unlock() + + res := make(map[string]int64, len(s.offsets)) + for tp, info := range s.offsets { + res[tp.String()] = info.Offset + } + return res +} diff --git a/vendor/github.com/bsm/sarama-cluster/offsets_test.go b/vendor/github.com/bsm/sarama-cluster/offsets_test.go new file mode 100644 index 00000000..1a15144e --- /dev/null +++ b/vendor/github.com/bsm/sarama-cluster/offsets_test.go @@ -0,0 +1,47 @@ +package cluster + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("OffsetStash", func() { + var subject *OffsetStash + + BeforeEach(func() { + subject = NewOffsetStash() + }) + + It("should update", func() { + Expect(subject.offsets).To(HaveLen(0)) + + subject.MarkPartitionOffset("topic", 0, 0, "m3ta") + Expect(subject.offsets).To(HaveLen(1)) + Expect(subject.offsets).To(HaveKeyWithValue( + topicPartition{Topic: "topic", Partition: 0}, + offsetInfo{Offset: 0, Metadata: "m3ta"}, + )) + + subject.MarkPartitionOffset("topic", 0, 200, "m3ta") + Expect(subject.offsets).To(HaveLen(1)) + Expect(subject.offsets).To(HaveKeyWithValue( + topicPartition{Topic: "topic", Partition: 0}, + offsetInfo{Offset: 200, Metadata: "m3ta"}, + )) + + subject.MarkPartitionOffset("topic", 0, 199, "m3t@") + Expect(subject.offsets).To(HaveLen(1)) + Expect(subject.offsets).To(HaveKeyWithValue( + topicPartition{Topic: "topic", Partition: 0}, + offsetInfo{Offset: 200, Metadata: "m3ta"}, + )) + + subject.MarkPartitionOffset("topic", 1, 300, "") + Expect(subject.offsets).To(HaveLen(2)) + Expect(subject.offsets).To(HaveKeyWithValue( + topicPartition{Topic: "topic", Partition: 1}, + offsetInfo{Offset: 300, Metadata: ""}, + )) + }) + +}) diff --git a/vendor/github.com/bsm/sarama-cluster/partitions.go b/vendor/github.com/bsm/sarama-cluster/partitions.go new file mode 100644 index 00000000..8266a879 --- /dev/null +++ b/vendor/github.com/bsm/sarama-cluster/partitions.go @@ -0,0 +1,214 @@ +package cluster + +import ( + "sort" + "sync" + "time" + + "github.com/Shopify/sarama" +) + +type partitionConsumer struct { + pcm sarama.PartitionConsumer + + state partitionState + mu sync.Mutex + + closed bool + dying, dead chan none +} + +func newPartitionConsumer(manager sarama.Consumer, topic string, partition int32, info offsetInfo, defaultOffset int64) (*partitionConsumer, error) { + pcm, err := manager.ConsumePartition(topic, partition, info.NextOffset(defaultOffset)) + + // Resume from default offset, if requested offset is out-of-range + if err == sarama.ErrOffsetOutOfRange { + info.Offset = -1 + pcm, err = manager.ConsumePartition(topic, partition, defaultOffset) + } + if err != nil { + return nil, err + } + + return &partitionConsumer{ + pcm: pcm, + state: partitionState{Info: info}, + + dying: make(chan none), + dead: make(chan none), + }, nil +} + +func (c *partitionConsumer) Loop(messages chan<- *sarama.ConsumerMessage, errors chan<- error) { + defer close(c.dead) + + for { + select { + case msg, ok := <-c.pcm.Messages(): + if !ok { + return + } + select { + case messages <- msg: + case <-c.dying: + return + } + case err, ok := <-c.pcm.Errors(): + if !ok { + return + } + select { + case errors <- err: + case <-c.dying: + return + } + case <-c.dying: + return + } + } +} + +func (c *partitionConsumer) Close() error { + if c.closed { + return nil + } + + err := c.pcm.Close() + c.closed = true + close(c.dying) + <-c.dead + + return err +} + +func (c *partitionConsumer) State() partitionState { + if c == nil { + return partitionState{} + } + + c.mu.Lock() + state := c.state + c.mu.Unlock() + + return state +} + +func (c *partitionConsumer) MarkCommitted(offset int64) { + if c == nil { + return + } + + c.mu.Lock() + if offset == c.state.Info.Offset { + c.state.Dirty = false + } + c.mu.Unlock() +} + +func (c *partitionConsumer) MarkOffset(offset int64, metadata string) { + if c == nil { + return + } + + c.mu.Lock() + if offset > c.state.Info.Offset { + c.state.Info.Offset = offset + c.state.Info.Metadata = metadata + c.state.Dirty = true + } + c.mu.Unlock() +} + +// -------------------------------------------------------------------- + +type partitionState struct { + Info offsetInfo + Dirty bool + LastCommit time.Time +} + +// -------------------------------------------------------------------- + +type partitionMap struct { + data map[topicPartition]*partitionConsumer + mu sync.RWMutex +} + +func newPartitionMap() *partitionMap { + return &partitionMap{ + data: make(map[topicPartition]*partitionConsumer), + } +} + +func (m *partitionMap) IsSubscribedTo(topic string) bool { + m.mu.RLock() + defer m.mu.RUnlock() + + for tp := range m.data { + if tp.Topic == topic { + return true + } + } + return false +} + +func (m *partitionMap) Fetch(topic string, partition int32) *partitionConsumer { + m.mu.RLock() + pc, _ := m.data[topicPartition{topic, partition}] + m.mu.RUnlock() + return pc +} + +func (m *partitionMap) Store(topic string, partition int32, pc *partitionConsumer) { + m.mu.Lock() + m.data[topicPartition{topic, partition}] = pc + m.mu.Unlock() +} + +func (m *partitionMap) Snapshot() map[topicPartition]partitionState { + m.mu.RLock() + defer m.mu.RUnlock() + + snap := make(map[topicPartition]partitionState, len(m.data)) + for tp, pc := range m.data { + snap[tp] = pc.State() + } + return snap +} + +func (m *partitionMap) Stop() { + m.mu.RLock() + defer m.mu.RUnlock() + + var wg sync.WaitGroup + for tp := range m.data { + wg.Add(1) + go func(p *partitionConsumer) { + _ = p.Close() + wg.Done() + }(m.data[tp]) + } + wg.Wait() +} + +func (m *partitionMap) Clear() { + m.mu.Lock() + for tp := range m.data { + delete(m.data, tp) + } + m.mu.Unlock() +} + +func (m *partitionMap) Info() map[string][]int32 { + info := make(map[string][]int32) + m.mu.RLock() + for tp := range m.data { + info[tp.Topic] = append(info[tp.Topic], tp.Partition) + } + m.mu.RUnlock() + + for topic := range info { + sort.Sort(int32Slice(info[topic])) + } + return info +} diff --git a/vendor/github.com/bsm/sarama-cluster/partitions_test.go b/vendor/github.com/bsm/sarama-cluster/partitions_test.go new file mode 100644 index 00000000..945913f4 --- /dev/null +++ b/vendor/github.com/bsm/sarama-cluster/partitions_test.go @@ -0,0 +1,129 @@ +package cluster + +import ( + "github.com/Shopify/sarama" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("partitionConsumer", func() { + var subject *partitionConsumer + + BeforeEach(func() { + var err error + subject, err = newPartitionConsumer(&mockConsumer{}, "topic", 0, offsetInfo{2000, "m3ta"}, sarama.OffsetOldest) + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + close(subject.dead) + Expect(subject.Close()).NotTo(HaveOccurred()) + }) + + It("should set state", func() { + Expect(subject.State()).To(Equal(partitionState{ + Info: offsetInfo{2000, "m3ta"}, + })) + }) + + It("should recover from default offset if requested offset is out of bounds", func() { + pc, err := newPartitionConsumer(&mockConsumer{}, "topic", 0, offsetInfo{200, "m3ta"}, sarama.OffsetOldest) + Expect(err).NotTo(HaveOccurred()) + defer pc.Close() + close(pc.dead) + + state := pc.State() + Expect(state.Info.Offset).To(Equal(int64(-1))) + Expect(state.Info.Metadata).To(Equal("m3ta")) + }) + + It("should update state", func() { + subject.MarkOffset(2001, "met@") // should set state + Expect(subject.State()).To(Equal(partitionState{ + Info: offsetInfo{2001, "met@"}, + Dirty: true, + })) + + subject.MarkCommitted(2001) // should reset dirty status + Expect(subject.State()).To(Equal(partitionState{ + Info: offsetInfo{2001, "met@"}, + })) + + subject.MarkOffset(2001, "me7a") // should not update state + Expect(subject.State()).To(Equal(partitionState{ + Info: offsetInfo{2001, "met@"}, + })) + + subject.MarkOffset(2002, "me7a") // should bump state + Expect(subject.State()).To(Equal(partitionState{ + Info: offsetInfo{2002, "me7a"}, + Dirty: true, + })) + + subject.MarkCommitted(2001) // should not unset state + Expect(subject.State()).To(Equal(partitionState{ + Info: offsetInfo{2002, "me7a"}, + Dirty: true, + })) + }) + + It("should not fail when nil", func() { + blank := (*partitionConsumer)(nil) + Expect(func() { + _ = blank.State() + blank.MarkOffset(2001, "met@") + blank.MarkCommitted(2001) + }).NotTo(Panic()) + }) + +}) + +var _ = Describe("partitionMap", func() { + var subject *partitionMap + + BeforeEach(func() { + subject = newPartitionMap() + }) + + It("should fetch/store", func() { + Expect(subject.Fetch("topic", 0)).To(BeNil()) + + pc, err := newPartitionConsumer(&mockConsumer{}, "topic", 0, offsetInfo{2000, "m3ta"}, sarama.OffsetNewest) + Expect(err).NotTo(HaveOccurred()) + + subject.Store("topic", 0, pc) + Expect(subject.Fetch("topic", 0)).To(Equal(pc)) + Expect(subject.Fetch("topic", 1)).To(BeNil()) + Expect(subject.Fetch("other", 0)).To(BeNil()) + }) + + It("should return info", func() { + pc0, err := newPartitionConsumer(&mockConsumer{}, "topic", 0, offsetInfo{2000, "m3ta"}, sarama.OffsetNewest) + Expect(err).NotTo(HaveOccurred()) + pc1, err := newPartitionConsumer(&mockConsumer{}, "topic", 1, offsetInfo{2000, "m3ta"}, sarama.OffsetNewest) + Expect(err).NotTo(HaveOccurred()) + subject.Store("topic", 0, pc0) + subject.Store("topic", 1, pc1) + + info := subject.Info() + Expect(info).To(HaveLen(1)) + Expect(info).To(HaveKeyWithValue("topic", []int32{0, 1})) + }) + + It("should create snapshots", func() { + pc0, err := newPartitionConsumer(&mockConsumer{}, "topic", 0, offsetInfo{2000, "m3ta"}, sarama.OffsetNewest) + Expect(err).NotTo(HaveOccurred()) + pc1, err := newPartitionConsumer(&mockConsumer{}, "topic", 1, offsetInfo{2000, "m3ta"}, sarama.OffsetNewest) + Expect(err).NotTo(HaveOccurred()) + + subject.Store("topic", 0, pc0) + subject.Store("topic", 1, pc1) + subject.Fetch("topic", 1).MarkOffset(2001, "met@") + + Expect(subject.Snapshot()).To(Equal(map[topicPartition]partitionState{ + {"topic", 0}: {Info: offsetInfo{2000, "m3ta"}, Dirty: false}, + {"topic", 1}: {Info: offsetInfo{2001, "met@"}, Dirty: true}, + })) + }) + +}) diff --git a/vendor/github.com/bsm/sarama-cluster/testdata/server.properties b/vendor/github.com/bsm/sarama-cluster/testdata/server.properties new file mode 100644 index 00000000..972b2aaf --- /dev/null +++ b/vendor/github.com/bsm/sarama-cluster/testdata/server.properties @@ -0,0 +1,16 @@ +broker.id=0 +port=29092 +num.network.threads=2 +num.io.threads=8 +socket.send.buffer.bytes=1048576 +socket.receive.buffer.bytes=1048576 +socket.request.max.bytes=104857600 +log.dirs=/tmp/sarama-cluster-test/kafka +num.partitions=4 +log.flush.interval.ms=10 +log.retention.hours=1 +log.segment.bytes=536870912 +log.retention.check.interval.ms=60000 +log.cleaner.enable=true +zookeeper.connect=localhost:22181 +zookeeper.connection.timeout.ms=1000000 diff --git a/vendor/github.com/bsm/sarama-cluster/testdata/zookeeper.properties b/vendor/github.com/bsm/sarama-cluster/testdata/zookeeper.properties new file mode 100644 index 00000000..3646903d --- /dev/null +++ b/vendor/github.com/bsm/sarama-cluster/testdata/zookeeper.properties @@ -0,0 +1,3 @@ +dataDir=/tmp/sarama-cluster-test/zookeeper +clientPort=22181 +maxClientCnxns=0 diff --git a/vendor/github.com/cockroachdb/cmux/.gitignore b/vendor/github.com/cockroachdb/cmux/.gitignore new file mode 100644 index 00000000..daf913b1 --- /dev/null +++ b/vendor/github.com/cockroachdb/cmux/.gitignore @@ -0,0 +1,24 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +*.test +*.prof diff --git a/vendor/github.com/cockroachdb/cmux/.travis.yml b/vendor/github.com/cockroachdb/cmux/.travis.yml new file mode 100644 index 00000000..e73780f2 --- /dev/null +++ b/vendor/github.com/cockroachdb/cmux/.travis.yml @@ -0,0 +1,22 @@ +language: go + +go: + - 1.3 + - 1.4 + - 1.5 + - 1.6 + +gobuild_args: -race + +before_install: + - go get -u github.com/golang/lint/golint + - if [[ $TRAVIS_GO_VERSION == 1.5* ]]; then go get -u github.com/kisielk/errcheck; fi + - go get -u golang.org/x/tools/cmd/vet + +before_script: + - '! gofmt -s -l . | read' + - golint ./... + - echo $TRAVIS_GO_VERSION + - if [[ $TRAVIS_GO_VERSION == 1.5* ]]; then errcheck ./...; fi + - go vet . + - go tool vet --shadow . diff --git a/vendor/github.com/cockroachdb/cmux/LICENSE b/vendor/github.com/cockroachdb/cmux/LICENSE new file mode 100644 index 00000000..d6456956 --- /dev/null +++ b/vendor/github.com/cockroachdb/cmux/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vendor/github.com/cockroachdb/cmux/README.md b/vendor/github.com/cockroachdb/cmux/README.md new file mode 100644 index 00000000..b3713da5 --- /dev/null +++ b/vendor/github.com/cockroachdb/cmux/README.md @@ -0,0 +1,65 @@ +# cmux: Connection Mux [![Build Status](https://travis-ci.org/cockroachdb/cmux.svg?branch=master)](https://travis-ci.org/cockroachdb/cmux) [![GoDoc](https://godoc.org/github.com/cockroachdb/cmux?status.svg)](https://godoc.org/github.com/cockroachdb/cmux) + +cmux is a generic Go library to multiplex connections based on their payload. +Using cmux, you can serve gRPC, SSH, HTTPS, HTTP, Go RPC, and pretty much any +other protocol on the same TCP listener. + +## How-To +Simply create your main listener, create a cmux for that listener, +and then match connections: +```go +// Create the main listener. +l, err := net.Listen("tcp", ":23456") +if err != nil { + log.Fatal(err) +} + +// Create a cmux. +m := cmux.New(l) + +// Match connections in order: +// First grpc, then HTTP, and otherwise Go RPC/TCP. +grpcL := m.Match(cmux.HTTP2HeaderField("content-type", "application/grpc")) +httpL := m.Match(cmux.HTTP1Fast()) +trpcL := m.Match(cmux.Any()) // Any means anything that is not yet matched. + +// Create your protocol servers. +grpcS := grpc.NewServer() +grpchello.RegisterGreeterServer(grpcs, &server{}) + +httpS := &http.Server{ + Handler: &helloHTTP1Handler{}, +} + +trpcS := rpc.NewServer() +s.Register(&ExampleRPCRcvr{}) + +// Use the muxed listeners for your servers. +go grpcS.Serve(grpcL) +go httpS.Serve(httpL) +go trpcS.Accept(trpcL) + +// Start serving! +m.Serve() +``` + +There are [more examples on GoDoc](https://godoc.org/github.com/cockroachdb/cmux#pkg-examples). + +## Performance +Since we are only matching the very first bytes of a connection, the +performance overhead on long-lived connections (i.e., RPCs and pipelined HTTP +streams) is negligible. + +## Limitations +* *TLS*: `net/http` uses a [type assertion](https://github.com/golang/go/issues/14221) +to identify TLS connections; since cmux's lookahead-implementing connection +wraps the underlying TLS connection, this type assertion fails. This means you +can serve HTTPS using cmux but `http.Request.TLS` will not be set in your +handlers. If you are able to wrap TLS around cmux, you can work around this +limitation. See https://github.com/cockroachdb/cockroach/commit/83caba2 for an +example of this approach. + +* *Different Protocols on The Same Connection*: `cmux` matches the connection +when it's accepted. For example, one connection can be either gRPC or REST, but +not both. That is, we assume that a client connection is either used for gRPC +or REST. diff --git a/vendor/github.com/cockroachdb/cmux/bench_test.go b/vendor/github.com/cockroachdb/cmux/bench_test.go new file mode 100644 index 00000000..2351cd0c --- /dev/null +++ b/vendor/github.com/cockroachdb/cmux/bench_test.go @@ -0,0 +1,46 @@ +package cmux + +import ( + "bytes" + "io" + "net" + "sync" + "testing" +) + +type mockConn struct { + net.Conn + r io.Reader +} + +func (c *mockConn) Read(b []byte) (n int, err error) { + return c.r.Read(b) +} + +func BenchmarkCMuxConn(b *testing.B) { + benchHTTPPayload := make([]byte, 4096) + copy(benchHTTPPayload, []byte("GET http://www.w3.org/ HTTP/1.1")) + + m := New(nil).(*cMux) + l := m.Match(HTTP1Fast()) + + go func() { + for { + if _, err := l.Accept(); err != nil { + return + } + } + }() + + donec := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(b.N) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + c := &mockConn{ + r: bytes.NewReader(benchHTTPPayload), + } + m.serve(c, donec, &wg) + } +} diff --git a/vendor/github.com/cockroachdb/cmux/buffer.go b/vendor/github.com/cockroachdb/cmux/buffer.go new file mode 100644 index 00000000..5c178585 --- /dev/null +++ b/vendor/github.com/cockroachdb/cmux/buffer.go @@ -0,0 +1,35 @@ +package cmux + +import ( + "bytes" + "io" +) + +// bufferedReader is an optimized implementation of io.Reader that behaves like +// ``` +// io.MultiReader(bytes.NewReader(buffer.Bytes()), io.TeeReader(source, buffer)) +// ``` +// without allocating. +type bufferedReader struct { + source io.Reader + buffer *bytes.Buffer + bufferRead int + bufferSize int +} + +func (s *bufferedReader) Read(p []byte) (int, error) { + // Functionality of bytes.Reader. + bn := copy(p, s.buffer.Bytes()[s.bufferRead:s.bufferSize]) + s.bufferRead += bn + + p = p[bn:] + + // Funtionality of io.TeeReader. + sn, sErr := s.source.Read(p) + if sn > 0 { + if wn, wErr := s.buffer.Write(p[:sn]); wErr != nil { + return bn + wn, wErr + } + } + return bn + sn, sErr +} diff --git a/vendor/github.com/cockroachdb/cmux/cmux.go b/vendor/github.com/cockroachdb/cmux/cmux.go new file mode 100644 index 00000000..89cc910b --- /dev/null +++ b/vendor/github.com/cockroachdb/cmux/cmux.go @@ -0,0 +1,210 @@ +package cmux + +import ( + "bytes" + "fmt" + "io" + "net" + "sync" +) + +// Matcher matches a connection based on its content. +type Matcher func(io.Reader) bool + +// ErrorHandler handles an error and returns whether +// the mux should continue serving the listener. +type ErrorHandler func(error) bool + +var _ net.Error = ErrNotMatched{} + +// ErrNotMatched is returned whenever a connection is not matched by any of +// the matchers registered in the multiplexer. +type ErrNotMatched struct { + c net.Conn +} + +func (e ErrNotMatched) Error() string { + return fmt.Sprintf("mux: connection %v not matched by an matcher", + e.c.RemoteAddr()) +} + +// Temporary implements the net.Error interface. +func (e ErrNotMatched) Temporary() bool { return true } + +// Timeout implements the net.Error interface. +func (e ErrNotMatched) Timeout() bool { return false } + +type errListenerClosed string + +func (e errListenerClosed) Error() string { return string(e) } +func (e errListenerClosed) Temporary() bool { return false } +func (e errListenerClosed) Timeout() bool { return false } + +// ErrListenerClosed is returned from muxListener.Accept when the underlying +// listener is closed. +var ErrListenerClosed = errListenerClosed("mux: listener closed") + +// New instantiates a new connection multiplexer. +func New(l net.Listener) CMux { + return &cMux{ + root: l, + bufLen: 1024, + errh: func(_ error) bool { return true }, + donec: make(chan struct{}), + } +} + +// CMux is a multiplexer for network connections. +type CMux interface { + // Match returns a net.Listener that sees (i.e., accepts) only + // the connections matched by at least one of the matcher. + // + // The order used to call Match determines the priority of matchers. + Match(...Matcher) net.Listener + // Serve starts multiplexing the listener. Serve blocks and perhaps + // should be invoked concurrently within a go routine. + Serve() error + // HandleError registers an error handler that handles listener errors. + HandleError(ErrorHandler) +} + +type matchersListener struct { + ss []Matcher + l muxListener +} + +type cMux struct { + root net.Listener + bufLen int + errh ErrorHandler + donec chan struct{} + sls []matchersListener +} + +func (m *cMux) Match(matchers ...Matcher) net.Listener { + ml := muxListener{ + Listener: m.root, + connc: make(chan net.Conn, m.bufLen), + } + m.sls = append(m.sls, matchersListener{ss: matchers, l: ml}) + return ml +} + +func (m *cMux) Serve() error { + var wg sync.WaitGroup + + defer func() { + close(m.donec) + wg.Wait() + + for _, sl := range m.sls { + close(sl.l.connc) + // Drain the connections enqueued for the listener. + for c := range sl.l.connc { + _ = c.Close() + } + } + }() + + for { + c, err := m.root.Accept() + if err != nil { + if !m.handleErr(err) { + return err + } + continue + } + + wg.Add(1) + go m.serve(c, m.donec, &wg) + } +} + +func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) { + defer wg.Done() + + muc := newMuxConn(c) + for _, sl := range m.sls { + for _, s := range sl.ss { + matched := s(muc.getSniffer()) + if matched { + select { + case sl.l.connc <- muc: + case <-donec: + _ = c.Close() + } + return + } + } + } + + _ = c.Close() + err := ErrNotMatched{c: c} + if !m.handleErr(err) { + _ = m.root.Close() + } +} + +func (m *cMux) HandleError(h ErrorHandler) { + m.errh = h +} + +func (m *cMux) handleErr(err error) bool { + if !m.errh(err) { + return false + } + + if ne, ok := err.(net.Error); ok { + return ne.Temporary() + } + + return false +} + +type muxListener struct { + net.Listener + connc chan net.Conn +} + +func (l muxListener) Accept() (net.Conn, error) { + c, ok := <-l.connc + if !ok { + return nil, ErrListenerClosed + } + return c, nil +} + +// MuxConn wraps a net.Conn and provides transparent sniffing of connection data. +type MuxConn struct { + net.Conn + buf bytes.Buffer + sniffer bufferedReader +} + +func newMuxConn(c net.Conn) *MuxConn { + return &MuxConn{ + Conn: c, + } +} + +// From the io.Reader documentation: +// +// When Read encounters an error or end-of-file condition after +// successfully reading n > 0 bytes, it returns the number of +// bytes read. It may return the (non-nil) error from the same call +// or return the error (and n == 0) from a subsequent call. +// An instance of this general case is that a Reader returning +// a non-zero number of bytes at the end of the input stream may +// return either err == EOF or err == nil. The next Read should +// return 0, EOF. +func (m *MuxConn) Read(p []byte) (int, error) { + if n, err := m.buf.Read(p); err != io.EOF { + return n, err + } + return m.Conn.Read(p) +} + +func (m *MuxConn) getSniffer() io.Reader { + m.sniffer = bufferedReader{source: m.Conn, buffer: &m.buf, bufferSize: m.buf.Len()} + return &m.sniffer +} diff --git a/vendor/github.com/cockroachdb/cmux/cmux_test.go b/vendor/github.com/cockroachdb/cmux/cmux_test.go new file mode 100644 index 00000000..90aaab9c --- /dev/null +++ b/vendor/github.com/cockroachdb/cmux/cmux_test.go @@ -0,0 +1,470 @@ +package cmux + +import ( + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/rpc" + "runtime" + "sort" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "golang.org/x/net/http2" +) + +const ( + testHTTP1Resp = "http1" + rpcVal = 1234 +) + +func safeServe(errCh chan<- error, muxl CMux) { + if err := muxl.Serve(); !strings.Contains(err.Error(), "use of closed network connection") { + errCh <- err + } +} + +func safeDial(t *testing.T, addr net.Addr) (*rpc.Client, func()) { + c, err := rpc.Dial(addr.Network(), addr.String()) + if err != nil { + t.Fatal(err) + } + return c, func() { + if err := c.Close(); err != nil { + t.Fatal(err) + } + } +} + +type chanListener struct { + net.Listener + connCh chan net.Conn +} + +func newChanListener() *chanListener { + return &chanListener{connCh: make(chan net.Conn, 1)} +} + +func (l *chanListener) Accept() (net.Conn, error) { + if c, ok := <-l.connCh; ok { + return c, nil + } + return nil, errors.New("use of closed network connection") +} + +func testListener(t *testing.T) (net.Listener, func()) { + l, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + return l, func() { + if err := l.Close(); err != nil { + t.Fatal(err) + } + } +} + +type testHTTP1Handler struct{} + +func (h *testHTTP1Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, testHTTP1Resp) +} + +func runTestHTTPServer(errCh chan<- error, l net.Listener) { + var mu sync.Mutex + conns := make(map[net.Conn]struct{}) + + defer func() { + mu.Lock() + for c := range conns { + if err := c.Close(); err != nil { + errCh <- err + } + } + mu.Unlock() + }() + + s := &http.Server{ + Handler: &testHTTP1Handler{}, + ConnState: func(c net.Conn, state http.ConnState) { + mu.Lock() + switch state { + case http.StateNew: + conns[c] = struct{}{} + case http.StateClosed: + delete(conns, c) + } + mu.Unlock() + }, + } + if err := s.Serve(l); err != ErrListenerClosed { + errCh <- err + } +} + +func runTestHTTP1Client(t *testing.T, addr net.Addr) { + if r, err := http.Get("http://" + addr.String()); err != nil { + t.Fatal(err) + } else { + defer func() { + if err := r.Body.Close(); err != nil { + t.Fatal(err) + } + }() + if b, err := ioutil.ReadAll(r.Body); err != nil { + t.Fatal(err) + } else { + if string(b) != testHTTP1Resp { + t.Fatalf("invalid response: want=%s got=%s", testHTTP1Resp, b) + } + } + } +} + +type TestRPCRcvr struct{} + +func (r TestRPCRcvr) Test(i int, j *int) error { + *j = i + return nil +} + +func runTestRPCServer(errCh chan<- error, l net.Listener) { + s := rpc.NewServer() + if err := s.Register(TestRPCRcvr{}); err != nil { + errCh <- err + } + for { + c, err := l.Accept() + if err != nil { + if err != ErrListenerClosed { + errCh <- err + } + return + } + go s.ServeConn(c) + } +} + +func runTestRPCClient(t *testing.T, addr net.Addr) { + c, cleanup := safeDial(t, addr) + defer cleanup() + + var num int + if err := c.Call("TestRPCRcvr.Test", rpcVal, &num); err != nil { + t.Fatal(err) + } + + if num != rpcVal { + t.Errorf("wrong rpc response: want=%d got=%v", rpcVal, num) + } +} + +func TestRead(t *testing.T) { + defer leakCheck(t)() + errCh := make(chan error) + defer func() { + select { + case err := <-errCh: + t.Fatal(err) + default: + } + }() + const payload = "hello world\r\n" + const mult = 2 + + writer, reader := net.Pipe() + go func() { + if _, err := io.WriteString(writer, strings.Repeat(payload, mult)); err != nil { + t.Fatal(err) + } + if err := writer.Close(); err != nil { + t.Fatal(err) + } + }() + + l := newChanListener() + defer close(l.connCh) + l.connCh <- reader + muxl := New(l) + // Register a bogus matcher to force buffering exactly the right amount. + // Before this fix, this would trigger a bug where `Read` would incorrectly + // report `io.EOF` when only the buffer had been consumed. + muxl.Match(func(r io.Reader) bool { + var b [len(payload)]byte + _, _ = r.Read(b[:]) + return false + }) + anyl := muxl.Match(Any()) + go safeServe(errCh, muxl) + muxedConn, err := anyl.Accept() + if err != nil { + t.Fatal(err) + } + for i := 0; i < mult; i++ { + var b [len(payload)]byte + if n, err := muxedConn.Read(b[:]); err != nil { + t.Error(err) + } else if e := len(b); n != e { + t.Errorf("expected to read %d bytes, but read %d bytes", e, n) + } + } + var b [1]byte + if _, err := muxedConn.Read(b[:]); err != io.EOF { + t.Errorf("unexpected error %v, expected %v", err, io.EOF) + } +} + +func TestAny(t *testing.T) { + defer leakCheck(t)() + errCh := make(chan error) + defer func() { + select { + case err := <-errCh: + t.Fatal(err) + default: + } + }() + l, cleanup := testListener(t) + defer cleanup() + + muxl := New(l) + httpl := muxl.Match(Any()) + + go runTestHTTPServer(errCh, httpl) + go safeServe(errCh, muxl) + + runTestHTTP1Client(t, l.Addr()) +} + +func TestHTTP2(t *testing.T) { + defer leakCheck(t)() + errCh := make(chan error) + defer func() { + select { + case err := <-errCh: + t.Fatal(err) + default: + } + }() + writer, reader := net.Pipe() + go func() { + if _, err := io.WriteString(writer, http2.ClientPreface); err != nil { + t.Fatal(err) + } + if err := writer.Close(); err != nil { + t.Fatal(err) + } + }() + + l := newChanListener() + l.connCh <- reader + muxl := New(l) + // Register a bogus matcher that only reads one byte. + muxl.Match(func(r io.Reader) bool { + var b [1]byte + _, _ = r.Read(b[:]) + return false + }) + h2l := muxl.Match(HTTP2()) + go safeServe(errCh, muxl) + muxedConn, err := h2l.Accept() + close(l.connCh) + if err != nil { + t.Fatal(err) + } + { + var b [len(http2.ClientPreface)]byte + if _, err := muxedConn.Read(b[:]); err != nil { + t.Fatal(err) + } + if string(b[:]) != http2.ClientPreface { + t.Errorf("got unexpected read %s, expected %s", b, http2.ClientPreface) + } + } + { + var b [1]byte + if _, err := muxedConn.Read(b[:]); err != io.EOF { + t.Errorf("unexpected error %v, expected %v", err, io.EOF) + } + } +} + +func TestHTTPGoRPC(t *testing.T) { + defer leakCheck(t)() + errCh := make(chan error) + defer func() { + select { + case err := <-errCh: + t.Fatal(err) + default: + } + }() + l, cleanup := testListener(t) + defer cleanup() + + muxl := New(l) + httpl := muxl.Match(HTTP2(), HTTP1Fast()) + rpcl := muxl.Match(Any()) + + go runTestHTTPServer(errCh, httpl) + go runTestRPCServer(errCh, rpcl) + go safeServe(errCh, muxl) + + runTestHTTP1Client(t, l.Addr()) + runTestRPCClient(t, l.Addr()) +} + +func TestErrorHandler(t *testing.T) { + defer leakCheck(t)() + errCh := make(chan error) + defer func() { + select { + case err := <-errCh: + t.Fatal(err) + default: + } + }() + l, cleanup := testListener(t) + defer cleanup() + + muxl := New(l) + httpl := muxl.Match(HTTP2(), HTTP1Fast()) + + go runTestHTTPServer(errCh, httpl) + go safeServe(errCh, muxl) + + var errCount uint32 + muxl.HandleError(func(err error) bool { + if atomic.AddUint32(&errCount, 1) == 1 { + if _, ok := err.(ErrNotMatched); !ok { + t.Errorf("unexpected error: %v", err) + } + } + return true + }) + + c, cleanup := safeDial(t, l.Addr()) + defer cleanup() + + var num int + for atomic.LoadUint32(&errCount) == 0 { + if err := c.Call("TestRPCRcvr.Test", rpcVal, &num); err == nil { + // The connection is simply closed. + t.Errorf("unexpected rpc success after %d errors", atomic.LoadUint32(&errCount)) + } + } +} + +func TestClose(t *testing.T) { + defer leakCheck(t)() + errCh := make(chan error) + defer func() { + select { + case err := <-errCh: + t.Fatal(err) + default: + } + }() + l := newChanListener() + + c1, c2 := net.Pipe() + + muxl := New(l) + anyl := muxl.Match(Any()) + + go safeServe(errCh, muxl) + + l.connCh <- c1 + + // First connection goes through. + if _, err := anyl.Accept(); err != nil { + t.Fatal(err) + } + + // Second connection is sent + l.connCh <- c2 + + // Listener is closed. + close(l.connCh) + + // Second connection either goes through or it is closed. + if _, err := anyl.Accept(); err != nil { + if err != ErrListenerClosed { + t.Fatal(err) + } + if _, err := c2.Read([]byte{}); err != io.ErrClosedPipe { + t.Fatalf("connection is not closed and is leaked: %v", err) + } + } +} + +// Cribbed from google.golang.org/grpc/test/end2end_test.go. + +// interestingGoroutines returns all goroutines we care about for the purpose +// of leak checking. It excludes testing or runtime ones. +func interestingGoroutines() (gs []string) { + buf := make([]byte, 2<<20) + buf = buf[:runtime.Stack(buf, true)] + for _, g := range strings.Split(string(buf), "\n\n") { + sl := strings.SplitN(g, "\n", 2) + if len(sl) != 2 { + continue + } + stack := strings.TrimSpace(sl[1]) + if strings.HasPrefix(stack, "testing.RunTests") { + continue + } + + if stack == "" || + strings.Contains(stack, "testing.Main(") || + strings.Contains(stack, "runtime.goexit") || + strings.Contains(stack, "created by runtime.gc") || + strings.Contains(stack, "interestingGoroutines") || + strings.Contains(stack, "runtime.MHeap_Scavenger") { + continue + } + gs = append(gs, g) + } + sort.Strings(gs) + return +} + +// leakCheck snapshots the currently-running goroutines and returns a +// function to be run at the end of tests to see whether any +// goroutines leaked. +func leakCheck(t testing.TB) func() { + orig := map[string]bool{} + for _, g := range interestingGoroutines() { + orig[g] = true + } + return func() { + // Loop, waiting for goroutines to shut down. + // Wait up to 5 seconds, but finish as quickly as possible. + deadline := time.Now().Add(5 * time.Second) + for { + var leaked []string + for _, g := range interestingGoroutines() { + if !orig[g] { + leaked = append(leaked, g) + } + } + if len(leaked) == 0 { + return + } + if time.Now().Before(deadline) { + time.Sleep(50 * time.Millisecond) + continue + } + for _, g := range leaked { + t.Errorf("Leaked goroutine: %v", g) + } + return + } + } +} diff --git a/vendor/github.com/cockroachdb/cmux/example_recursive_test.go b/vendor/github.com/cockroachdb/cmux/example_recursive_test.go new file mode 100644 index 00000000..dbf58b32 --- /dev/null +++ b/vendor/github.com/cockroachdb/cmux/example_recursive_test.go @@ -0,0 +1,109 @@ +package cmux_test + +import ( + "crypto/rand" + "crypto/tls" + "fmt" + "log" + "net" + "net/http" + "net/rpc" + "strings" + + "github.com/cockroachdb/cmux" +) + +type recursiveHTTPHandler struct{} + +func (h *recursiveHTTPHandler) ServeHTTP(w http.ResponseWriter, + r *http.Request) { + + fmt.Fprintf(w, "example http response") +} + +func recursiveServeHTTP(l net.Listener) { + s := &http.Server{ + Handler: &recursiveHTTPHandler{}, + } + if err := s.Serve(l); err != cmux.ErrListenerClosed { + panic(err) + } +} + +func tlsListener(l net.Listener) net.Listener { + // Load certificates. + certificate, err := tls.LoadX509KeyPair("cert.pem", "key.pem") + if err != nil { + log.Panic(err) + } + + config := &tls.Config{ + Certificates: []tls.Certificate{certificate}, + Rand: rand.Reader, + } + + // Create TLS listener. + tlsl := tls.NewListener(l, config) + return tlsl +} + +type RecursiveRPCRcvr struct{} + +func (r *RecursiveRPCRcvr) Cube(i int, j *int) error { + *j = i * i + return nil +} + +func recursiveServeRPC(l net.Listener) { + s := rpc.NewServer() + if err := s.Register(&RecursiveRPCRcvr{}); err != nil { + panic(err) + } + for { + conn, err := l.Accept() + if err != nil { + if err != cmux.ErrListenerClosed { + panic(err) + } + return + } + go s.ServeConn(conn) + } +} + +// This is an example for serving HTTP, HTTPS, and GoRPC/TLS on the same port. +func Example_recursiveCmux() { + // Create the TCP listener. + l, err := net.Listen("tcp", "127.0.0.1:50051") + if err != nil { + log.Panic(err) + } + + // Create a mux. + tcpm := cmux.New(l) + + // We first match on HTTP 1.1 methods. + httpl := tcpm.Match(cmux.HTTP1Fast()) + + // If not matched, we assume that its TLS. + tlsl := tcpm.Match(cmux.Any()) + tlsl = tlsListener(tlsl) + + // Now, we build another mux recursively to match HTTPS and GoRPC. + // You can use the same trick for SSH. + tlsm := cmux.New(tlsl) + httpsl := tlsm.Match(cmux.HTTP1Fast()) + gorpcl := tlsm.Match(cmux.Any()) + go recursiveServeHTTP(httpl) + go recursiveServeHTTP(httpsl) + go recursiveServeRPC(gorpcl) + + go func() { + if err := tlsm.Serve(); err != cmux.ErrListenerClosed { + panic(err) + } + }() + if err := tcpm.Serve(); !strings.Contains(err.Error(), "use of closed network connection") { + panic(err) + } +} diff --git a/vendor/github.com/cockroachdb/cmux/example_test.go b/vendor/github.com/cockroachdb/cmux/example_test.go new file mode 100644 index 00000000..69e89f4e --- /dev/null +++ b/vendor/github.com/cockroachdb/cmux/example_test.go @@ -0,0 +1,119 @@ +package cmux_test + +import ( + "fmt" + "io" + "log" + "net" + "net/http" + "net/rpc" + "strings" + + "golang.org/x/net/context" + "golang.org/x/net/websocket" + "google.golang.org/grpc" + "google.golang.org/grpc/examples/helloworld/helloworld" + + "github.com/cockroachdb/cmux" +) + +type exampleHTTPHandler struct{} + +func (h *exampleHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "example http response") +} + +func serveHTTP(l net.Listener) { + s := &http.Server{ + Handler: &exampleHTTPHandler{}, + } + if err := s.Serve(l); err != cmux.ErrListenerClosed { + panic(err) + } +} + +func EchoServer(ws *websocket.Conn) { + if _, err := io.Copy(ws, ws); err != nil { + panic(err) + } +} + +func serveWS(l net.Listener) { + s := &http.Server{ + Handler: websocket.Handler(EchoServer), + } + if err := s.Serve(l); err != cmux.ErrListenerClosed { + panic(err) + } +} + +type ExampleRPCRcvr struct{} + +func (r *ExampleRPCRcvr) Cube(i int, j *int) error { + *j = i * i + return nil +} + +func serveRPC(l net.Listener) { + s := rpc.NewServer() + if err := s.Register(&ExampleRPCRcvr{}); err != nil { + panic(err) + } + for { + conn, err := l.Accept() + if err != nil { + if err != cmux.ErrListenerClosed { + panic(err) + } + return + } + go s.ServeConn(conn) + } +} + +type grpcServer struct{} + +func (s *grpcServer) SayHello(ctx context.Context, in *helloworld.HelloRequest) ( + *helloworld.HelloReply, error) { + + return &helloworld.HelloReply{Message: "Hello " + in.Name + " from cmux"}, nil +} + +func serveGRPC(l net.Listener) { + grpcs := grpc.NewServer() + helloworld.RegisterGreeterServer(grpcs, &grpcServer{}) + if err := grpcs.Serve(l); err != cmux.ErrListenerClosed { + panic(err) + } +} + +func Example() { + l, err := net.Listen("tcp", "127.0.0.1:50051") + if err != nil { + log.Panic(err) + } + + m := cmux.New(l) + + // We first match the connection against HTTP2 fields. If matched, the + // connection will be sent through the "grpcl" listener. + grpcl := m.Match(cmux.HTTP2HeaderField("content-type", "application/grpc")) + //Otherwise, we match it againts a websocket upgrade request. + wsl := m.Match(cmux.HTTP1HeaderField("Upgrade", "websocket")) + + // Otherwise, we match it againts HTTP1 methods. If matched, + // it is sent through the "httpl" listener. + httpl := m.Match(cmux.HTTP1Fast()) + // If not matched by HTTP, we assume it is an RPC connection. + rpcl := m.Match(cmux.Any()) + + // Then we used the muxed listeners. + go serveGRPC(grpcl) + go serveWS(wsl) + go serveHTTP(httpl) + go serveRPC(rpcl) + + if err := m.Serve(); !strings.Contains(err.Error(), "use of closed network connection") { + panic(err) + } +} diff --git a/vendor/github.com/cockroachdb/cmux/example_tls_test.go b/vendor/github.com/cockroachdb/cmux/example_tls_test.go new file mode 100644 index 00000000..4d34e0b1 --- /dev/null +++ b/vendor/github.com/cockroachdb/cmux/example_tls_test.go @@ -0,0 +1,75 @@ +package cmux_test + +import ( + "crypto/rand" + "crypto/tls" + "fmt" + "log" + "net" + "net/http" + "strings" + + "github.com/cockroachdb/cmux" +) + +type anotherHTTPHandler struct{} + +func (h *anotherHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "example http response") +} + +func serveHTTP1(l net.Listener) { + s := &http.Server{ + Handler: &anotherHTTPHandler{}, + } + if err := s.Serve(l); err != cmux.ErrListenerClosed { + panic(err) + } +} + +func serveHTTPS(l net.Listener) { + // Load certificates. + certificate, err := tls.LoadX509KeyPair("cert.pem", "key.pem") + if err != nil { + log.Panic(err) + } + + config := &tls.Config{ + Certificates: []tls.Certificate{certificate}, + Rand: rand.Reader, + } + + // Create TLS listener. + tlsl := tls.NewListener(l, config) + + // Serve HTTP over TLS. + serveHTTP1(tlsl) +} + +// This is an example for serving HTTP and HTTPS on the same port. +func Example_bothHTTPAndHTTPS() { + // Create the TCP listener. + l, err := net.Listen("tcp", "127.0.0.1:50051") + if err != nil { + log.Panic(err) + } + + // Create a mux. + m := cmux.New(l) + + // We first match on HTTP 1.1 methods. + httpl := m.Match(cmux.HTTP1Fast()) + + // If not matched, we assume that its TLS. + // + // Note that you can take this listener, do TLS handshake and + // create another mux to multiplex the connections over TLS. + tlsl := m.Match(cmux.Any()) + + go serveHTTP1(httpl) + go serveHTTPS(tlsl) + + if err := m.Serve(); !strings.Contains(err.Error(), "use of closed network connection") { + panic(err) + } +} diff --git a/vendor/github.com/cockroachdb/cmux/matchers.go b/vendor/github.com/cockroachdb/cmux/matchers.go new file mode 100644 index 00000000..abc30f6e --- /dev/null +++ b/vendor/github.com/cockroachdb/cmux/matchers.go @@ -0,0 +1,150 @@ +package cmux + +import ( + "bufio" + "io" + "io/ioutil" + "net/http" + "strings" + + "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" +) + +// Any is a Matcher that matches any connection. +func Any() Matcher { + return func(r io.Reader) bool { return true } +} + +// PrefixMatcher returns a matcher that matches a connection if it +// starts with any of the strings in strs. +func PrefixMatcher(strs ...string) Matcher { + pt := newPatriciaTreeString(strs...) + return pt.matchPrefix +} + +var defaultHTTPMethods = []string{ + "OPTIONS", + "GET", + "HEAD", + "POST", + "PUT", + "DELETE", + "TRACE", + "CONNECT", +} + +// HTTP1Fast only matches the methods in the HTTP request. +// +// This matcher is very optimistic: if it returns true, it does not mean that +// the request is a valid HTTP response. If you want a correct but slower HTTP1 +// matcher, use HTTP1 instead. +func HTTP1Fast(extMethods ...string) Matcher { + return PrefixMatcher(append(defaultHTTPMethods, extMethods...)...) +} + +const maxHTTPRead = 4096 + +// HTTP1 parses the first line or upto 4096 bytes of the request to see if +// the conection contains an HTTP request. +func HTTP1() Matcher { + return func(r io.Reader) bool { + br := bufio.NewReader(&io.LimitedReader{R: r, N: maxHTTPRead}) + l, part, err := br.ReadLine() + if err != nil || part { + return false + } + + _, _, proto, ok := parseRequestLine(string(l)) + if !ok { + return false + } + + v, _, ok := http.ParseHTTPVersion(proto) + return ok && v == 1 + } +} + +// grabbed from net/http. +func parseRequestLine(line string) (method, uri, proto string, ok bool) { + s1 := strings.Index(line, " ") + s2 := strings.Index(line[s1+1:], " ") + if s1 < 0 || s2 < 0 { + return + } + s2 += s1 + 1 + return line[:s1], line[s1+1 : s2], line[s2+1:], true +} + +// HTTP2 parses the frame header of the first frame to detect whether the +// connection is an HTTP2 connection. +func HTTP2() Matcher { + return hasHTTP2Preface +} + +// HTTP1HeaderField returns a matcher matching the header fields of the first +// request of an HTTP 1 connection. +func HTTP1HeaderField(name, value string) Matcher { + return func(r io.Reader) bool { + return matchHTTP1Field(r, name, value) + } +} + +// HTTP2HeaderField resturns a matcher matching the header fields of the first +// headers frame. +func HTTP2HeaderField(name, value string) Matcher { + return func(r io.Reader) bool { + return matchHTTP2Field(r, name, value) + } +} + +func hasHTTP2Preface(r io.Reader) bool { + var b [len(http2.ClientPreface)]byte + if _, err := io.ReadFull(r, b[:]); err != nil { + return false + } + + return string(b[:]) == http2.ClientPreface +} + +func matchHTTP1Field(r io.Reader, name, value string) (matched bool) { + req, err := http.ReadRequest(bufio.NewReader(r)) + if err != nil { + return false + } + + return req.Header.Get(name) == value +} + +func matchHTTP2Field(r io.Reader, name, value string) (matched bool) { + if !hasHTTP2Preface(r) { + return false + } + + framer := http2.NewFramer(ioutil.Discard, r) + hdec := hpack.NewDecoder(uint32(4<<10), func(hf hpack.HeaderField) { + if hf.Name == name && hf.Value == value { + matched = true + } + }) + for { + f, err := framer.ReadFrame() + if err != nil { + return false + } + + switch f := f.(type) { + case *http2.HeadersFrame: + if _, err := hdec.Write(f.HeaderBlockFragment()); err != nil { + return false + } + if matched { + return true + } + + if f.FrameHeader.Flags&http2.FlagHeadersEndHeaders != 0 { + return false + } + } + } +} diff --git a/vendor/github.com/cockroachdb/cmux/patricia.go b/vendor/github.com/cockroachdb/cmux/patricia.go new file mode 100644 index 00000000..56ec4e7b --- /dev/null +++ b/vendor/github.com/cockroachdb/cmux/patricia.go @@ -0,0 +1,173 @@ +package cmux + +import ( + "bytes" + "io" +) + +// patriciaTree is a simple patricia tree that handles []byte instead of string +// and cannot be changed after instantiation. +type patriciaTree struct { + root *ptNode +} + +func newPatriciaTree(b ...[]byte) *patriciaTree { + return &patriciaTree{ + root: newNode(b), + } +} + +func newPatriciaTreeString(strs ...string) *patriciaTree { + b := make([][]byte, len(strs)) + for i, s := range strs { + b[i] = []byte(s) + } + return &patriciaTree{ + root: newNode(b), + } +} + +func (t *patriciaTree) matchPrefix(r io.Reader) bool { + return t.root.match(r, true) +} + +func (t *patriciaTree) match(r io.Reader) bool { + return t.root.match(r, false) +} + +type ptNode struct { + prefix []byte + next map[byte]*ptNode + terminal bool +} + +func newNode(strs [][]byte) *ptNode { + if len(strs) == 0 { + return &ptNode{ + prefix: []byte{}, + terminal: true, + } + } + + if len(strs) == 1 { + return &ptNode{ + prefix: strs[0], + terminal: true, + } + } + + p, strs := splitPrefix(strs) + n := &ptNode{ + prefix: p, + } + + nexts := make(map[byte][][]byte) + for _, s := range strs { + if len(s) == 0 { + n.terminal = true + continue + } + nexts[s[0]] = append(nexts[s[0]], s[1:]) + } + + n.next = make(map[byte]*ptNode) + for first, rests := range nexts { + n.next[first] = newNode(rests) + } + + return n +} + +func splitPrefix(bss [][]byte) (prefix []byte, rest [][]byte) { + if len(bss) == 0 || len(bss[0]) == 0 { + return prefix, bss + } + + if len(bss) == 1 { + return bss[0], [][]byte{{}} + } + + for i := 0; ; i++ { + var cur byte + eq := true + for j, b := range bss { + if len(b) <= i { + eq = false + break + } + + if j == 0 { + cur = b[i] + continue + } + + if cur != b[i] { + eq = false + break + } + } + + if !eq { + break + } + + prefix = append(prefix, cur) + } + + rest = make([][]byte, 0, len(bss)) + for _, b := range bss { + rest = append(rest, b[len(prefix):]) + } + + return prefix, rest +} + +func readBytes(r io.Reader, n int) (b []byte, err error) { + b = make([]byte, n) + o := 0 + for o < n { + nr, err := r.Read(b[o:]) + if err != nil && err != io.EOF { + return b, err + } + + o += nr + + if err == io.EOF { + break + } + } + return b[:o], nil +} + +func (n *ptNode) match(r io.Reader, prefix bool) bool { + if l := len(n.prefix); l > 0 { + b, err := readBytes(r, l) + if err != nil || len(b) != l || !bytes.Equal(b, n.prefix) { + return false + } + } + + if prefix && n.terminal { + return true + } + + b := make([]byte, 1) + for { + nr, err := r.Read(b) + if nr != 0 { + break + } + + if err == io.EOF { + return n.terminal + } + + if err != nil { + return false + } + } + + nextN, ok := n.next[b[0]] + return ok && nextN.match(r, prefix) +} diff --git a/vendor/github.com/cockroachdb/cmux/patricia_test.go b/vendor/github.com/cockroachdb/cmux/patricia_test.go new file mode 100644 index 00000000..16b0f406 --- /dev/null +++ b/vendor/github.com/cockroachdb/cmux/patricia_test.go @@ -0,0 +1,35 @@ +package cmux + +import ( + "strings" + "testing" +) + +func testPTree(t *testing.T, strs ...string) { + pt := newPatriciaTreeString(strs...) + for _, s := range strs { + if !pt.match(strings.NewReader(s)) { + t.Errorf("%s is not matched by %s", s, s) + } + + if !pt.matchPrefix(strings.NewReader(s + s)) { + t.Errorf("%s is not matched as a prefix by %s", s+s, s) + } + + if pt.match(strings.NewReader(s + s)) { + t.Errorf("%s matches %s", s+s, s) + } + } +} + +func TestPatriciaOnePrefix(t *testing.T) { + testPTree(t, "prefix") +} + +func TestPatriciaNonOverlapping(t *testing.T) { + testPTree(t, "foo", "bar", "dummy") +} + +func TestPatriciaOverlapping(t *testing.T) { + testPTree(t, "foo", "far", "farther", "boo", "bar") +} diff --git a/vendor/github.com/coreos/go-semver/.travis.yml b/vendor/github.com/coreos/go-semver/.travis.yml new file mode 100644 index 00000000..05f548c9 --- /dev/null +++ b/vendor/github.com/coreos/go-semver/.travis.yml @@ -0,0 +1,8 @@ +language: go +sudo: false +go: + - 1.4 + - 1.5 + - 1.6 + - tip +script: cd semver && go test diff --git a/vendor/github.com/coreos/go-semver/LICENSE b/vendor/github.com/coreos/go-semver/LICENSE new file mode 100644 index 00000000..d6456956 --- /dev/null +++ b/vendor/github.com/coreos/go-semver/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vendor/github.com/coreos/go-semver/README.md b/vendor/github.com/coreos/go-semver/README.md new file mode 100644 index 00000000..5bc9263c --- /dev/null +++ b/vendor/github.com/coreos/go-semver/README.md @@ -0,0 +1,28 @@ +# go-semver - Semantic Versioning Library + +[![Build Status](https://travis-ci.org/coreos/go-semver.svg?branch=master)](https://travis-ci.org/coreos/go-semver) +[![GoDoc](https://godoc.org/github.com/coreos/go-semver/semver?status.svg)](https://godoc.org/github.com/coreos/go-semver/semver) + +go-semver is a [semantic versioning][semver] library for Go. It lets you parse +and compare two semantic version strings. + +[semver]: http://semver.org/ + +## Usage + +```go +vA := semver.New("1.2.3") +vB := semver.New("3.2.1") + +fmt.Printf("%s < %s == %t\n", vA, vB, vA.LessThan(*vB)) +``` + +## Example Application + +``` +$ go run example.go 1.2.3 3.2.1 +1.2.3 < 3.2.1 == true + +$ go run example.go 5.2.3 3.2.1 +5.2.3 < 3.2.1 == false +``` diff --git a/vendor/github.com/coreos/go-semver/example.go b/vendor/github.com/coreos/go-semver/example.go new file mode 100644 index 00000000..fd2ee5af --- /dev/null +++ b/vendor/github.com/coreos/go-semver/example.go @@ -0,0 +1,20 @@ +package main + +import ( + "fmt" + "github.com/coreos/go-semver/semver" + "os" +) + +func main() { + vA, err := semver.NewVersion(os.Args[1]) + if err != nil { + fmt.Println(err.Error()) + } + vB, err := semver.NewVersion(os.Args[2]) + if err != nil { + fmt.Println(err.Error()) + } + + fmt.Printf("%s < %s == %t\n", vA, vB, vA.LessThan(*vB)) +} diff --git a/vendor/github.com/coreos/go-semver/semver/semver.go b/vendor/github.com/coreos/go-semver/semver/semver.go new file mode 100644 index 00000000..110fc23e --- /dev/null +++ b/vendor/github.com/coreos/go-semver/semver/semver.go @@ -0,0 +1,268 @@ +// Copyright 2013-2015 CoreOS, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Semantic Versions http://semver.org +package semver + +import ( + "bytes" + "errors" + "fmt" + "strconv" + "strings" +) + +type Version struct { + Major int64 + Minor int64 + Patch int64 + PreRelease PreRelease + Metadata string +} + +type PreRelease string + +func splitOff(input *string, delim string) (val string) { + parts := strings.SplitN(*input, delim, 2) + + if len(parts) == 2 { + *input = parts[0] + val = parts[1] + } + + return val +} + +func New(version string) *Version { + return Must(NewVersion(version)) +} + +func NewVersion(version string) (*Version, error) { + v := Version{} + + if err := v.Set(version); err != nil { + return nil, err + } + + return &v, nil +} + +// Must is a helper for wrapping NewVersion and will panic if err is not nil. +func Must(v *Version, err error) *Version { + if err != nil { + panic(err) + } + return v +} + +// Set parses and updates v from the given version string. Implements flag.Value +func (v *Version) Set(version string) error { + metadata := splitOff(&version, "+") + preRelease := PreRelease(splitOff(&version, "-")) + dotParts := strings.SplitN(version, ".", 3) + + if len(dotParts) != 3 { + return fmt.Errorf("%s is not in dotted-tri format", version) + } + + parsed := make([]int64, 3, 3) + + for i, v := range dotParts[:3] { + val, err := strconv.ParseInt(v, 10, 64) + parsed[i] = val + if err != nil { + return err + } + } + + v.Metadata = metadata + v.PreRelease = preRelease + v.Major = parsed[0] + v.Minor = parsed[1] + v.Patch = parsed[2] + return nil +} + +func (v Version) String() string { + var buffer bytes.Buffer + + fmt.Fprintf(&buffer, "%d.%d.%d", v.Major, v.Minor, v.Patch) + + if v.PreRelease != "" { + fmt.Fprintf(&buffer, "-%s", v.PreRelease) + } + + if v.Metadata != "" { + fmt.Fprintf(&buffer, "+%s", v.Metadata) + } + + return buffer.String() +} + +func (v *Version) UnmarshalYAML(unmarshal func(interface{}) error) error { + var data string + if err := unmarshal(&data); err != nil { + return err + } + return v.Set(data) +} + +func (v Version) MarshalJSON() ([]byte, error) { + return []byte(`"` + v.String() + `"`), nil +} + +func (v *Version) UnmarshalJSON(data []byte) error { + l := len(data) + if l == 0 || string(data) == `""` { + return nil + } + if l < 2 || data[0] != '"' || data[l-1] != '"' { + return errors.New("invalid semver string") + } + return v.Set(string(data[1 : l-1])) +} + +// Compare tests if v is less than, equal to, or greater than versionB, +// returning -1, 0, or +1 respectively. +func (v Version) Compare(versionB Version) int { + if cmp := recursiveCompare(v.Slice(), versionB.Slice()); cmp != 0 { + return cmp + } + return preReleaseCompare(v, versionB) +} + +// Equal tests if v is equal to versionB. +func (v Version) Equal(versionB Version) bool { + return v.Compare(versionB) == 0 +} + +// LessThan tests if v is less than versionB. +func (v Version) LessThan(versionB Version) bool { + return v.Compare(versionB) < 0 +} + +// Slice converts the comparable parts of the semver into a slice of integers. +func (v Version) Slice() []int64 { + return []int64{v.Major, v.Minor, v.Patch} +} + +func (p PreRelease) Slice() []string { + preRelease := string(p) + return strings.Split(preRelease, ".") +} + +func preReleaseCompare(versionA Version, versionB Version) int { + a := versionA.PreRelease + b := versionB.PreRelease + + /* Handle the case where if two versions are otherwise equal it is the + * one without a PreRelease that is greater */ + if len(a) == 0 && (len(b) > 0) { + return 1 + } else if len(b) == 0 && (len(a) > 0) { + return -1 + } + + // If there is a prerelease, check and compare each part. + return recursivePreReleaseCompare(a.Slice(), b.Slice()) +} + +func recursiveCompare(versionA []int64, versionB []int64) int { + if len(versionA) == 0 { + return 0 + } + + a := versionA[0] + b := versionB[0] + + if a > b { + return 1 + } else if a < b { + return -1 + } + + return recursiveCompare(versionA[1:], versionB[1:]) +} + +func recursivePreReleaseCompare(versionA []string, versionB []string) int { + // A larger set of pre-release fields has a higher precedence than a smaller set, + // if all of the preceding identifiers are equal. + if len(versionA) == 0 { + if len(versionB) > 0 { + return -1 + } + return 0 + } else if len(versionB) == 0 { + // We're longer than versionB so return 1. + return 1 + } + + a := versionA[0] + b := versionB[0] + + aInt := false + bInt := false + + aI, err := strconv.Atoi(versionA[0]) + if err == nil { + aInt = true + } + + bI, err := strconv.Atoi(versionB[0]) + if err == nil { + bInt = true + } + + // Handle Integer Comparison + if aInt && bInt { + if aI > bI { + return 1 + } else if aI < bI { + return -1 + } + } + + // Handle String Comparison + if a > b { + return 1 + } else if a < b { + return -1 + } + + return recursivePreReleaseCompare(versionA[1:], versionB[1:]) +} + +// BumpMajor increments the Major field by 1 and resets all other fields to their default values +func (v *Version) BumpMajor() { + v.Major += 1 + v.Minor = 0 + v.Patch = 0 + v.PreRelease = PreRelease("") + v.Metadata = "" +} + +// BumpMinor increments the Minor field by 1 and resets all other fields to their default values +func (v *Version) BumpMinor() { + v.Minor += 1 + v.Patch = 0 + v.PreRelease = PreRelease("") + v.Metadata = "" +} + +// BumpPatch increments the Patch field by 1 and resets all other fields to their default values +func (v *Version) BumpPatch() { + v.Patch += 1 + v.PreRelease = PreRelease("") + v.Metadata = "" +} diff --git a/vendor/github.com/coreos/go-semver/semver/semver_test.go b/vendor/github.com/coreos/go-semver/semver/semver_test.go new file mode 100644 index 00000000..876c68e0 --- /dev/null +++ b/vendor/github.com/coreos/go-semver/semver/semver_test.go @@ -0,0 +1,370 @@ +// Copyright 2013-2015 CoreOS, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package semver + +import ( + "bytes" + "encoding/json" + "errors" + "flag" + "fmt" + "math/rand" + "reflect" + "testing" + "time" + + "gopkg.in/yaml.v2" +) + +type fixture struct { + GreaterVersion string + LesserVersion string +} + +var fixtures = []fixture{ + fixture{"0.0.0", "0.0.0-foo"}, + fixture{"0.0.1", "0.0.0"}, + fixture{"1.0.0", "0.9.9"}, + fixture{"0.10.0", "0.9.0"}, + fixture{"0.99.0", "0.10.0"}, + fixture{"2.0.0", "1.2.3"}, + fixture{"0.0.0", "0.0.0-foo"}, + fixture{"0.0.1", "0.0.0"}, + fixture{"1.0.0", "0.9.9"}, + fixture{"0.10.0", "0.9.0"}, + fixture{"0.99.0", "0.10.0"}, + fixture{"2.0.0", "1.2.3"}, + fixture{"0.0.0", "0.0.0-foo"}, + fixture{"0.0.1", "0.0.0"}, + fixture{"1.0.0", "0.9.9"}, + fixture{"0.10.0", "0.9.0"}, + fixture{"0.99.0", "0.10.0"}, + fixture{"2.0.0", "1.2.3"}, + fixture{"1.2.3", "1.2.3-asdf"}, + fixture{"1.2.3", "1.2.3-4"}, + fixture{"1.2.3", "1.2.3-4-foo"}, + fixture{"1.2.3-5-foo", "1.2.3-5"}, + fixture{"1.2.3-5", "1.2.3-4"}, + fixture{"1.2.3-5-foo", "1.2.3-5-Foo"}, + fixture{"3.0.0", "2.7.2+asdf"}, + fixture{"3.0.0+foobar", "2.7.2"}, + fixture{"1.2.3-a.10", "1.2.3-a.5"}, + fixture{"1.2.3-a.b", "1.2.3-a.5"}, + fixture{"1.2.3-a.b", "1.2.3-a"}, + fixture{"1.2.3-a.b.c.10.d.5", "1.2.3-a.b.c.5.d.100"}, + fixture{"1.0.0", "1.0.0-rc.1"}, + fixture{"1.0.0-rc.2", "1.0.0-rc.1"}, + fixture{"1.0.0-rc.1", "1.0.0-beta.11"}, + fixture{"1.0.0-beta.11", "1.0.0-beta.2"}, + fixture{"1.0.0-beta.2", "1.0.0-beta"}, + fixture{"1.0.0-beta", "1.0.0-alpha.beta"}, + fixture{"1.0.0-alpha.beta", "1.0.0-alpha.1"}, + fixture{"1.0.0-alpha.1", "1.0.0-alpha"}, +} + +func TestCompare(t *testing.T) { + for _, v := range fixtures { + gt, err := NewVersion(v.GreaterVersion) + if err != nil { + t.Error(err) + } + + lt, err := NewVersion(v.LesserVersion) + if err != nil { + t.Error(err) + } + + if gt.LessThan(*lt) { + t.Errorf("%s should not be less than %s", gt, lt) + } + if gt.Equal(*lt) { + t.Errorf("%s should not be equal to %s", gt, lt) + } + if gt.Compare(*lt) <= 0 { + t.Errorf("%s should be greater than %s", gt, lt) + } + if !lt.LessThan(*gt) { + t.Errorf("%s should be less than %s", lt, gt) + } + if !lt.Equal(*lt) { + t.Errorf("%s should be equal to %s", lt, lt) + } + if lt.Compare(*gt) > 0 { + t.Errorf("%s should not be greater than %s", lt, gt) + } + } +} + +func testString(t *testing.T, orig string, version *Version) { + if orig != version.String() { + t.Errorf("%s != %s", orig, version) + } +} + +func TestString(t *testing.T) { + for _, v := range fixtures { + gt, err := NewVersion(v.GreaterVersion) + if err != nil { + t.Error(err) + } + testString(t, v.GreaterVersion, gt) + + lt, err := NewVersion(v.LesserVersion) + if err != nil { + t.Error(err) + } + testString(t, v.LesserVersion, lt) + } +} + +func shuffleStringSlice(src []string) []string { + dest := make([]string, len(src)) + rand.Seed(time.Now().Unix()) + perm := rand.Perm(len(src)) + for i, v := range perm { + dest[v] = src[i] + } + return dest +} + +func TestSort(t *testing.T) { + sortedVersions := []string{"1.0.0", "1.0.2", "1.2.0", "3.1.1"} + unsortedVersions := shuffleStringSlice(sortedVersions) + + semvers := []*Version{} + for _, v := range unsortedVersions { + sv, err := NewVersion(v) + if err != nil { + t.Fatal(err) + } + semvers = append(semvers, sv) + } + + Sort(semvers) + + for idx, sv := range semvers { + if sv.String() != sortedVersions[idx] { + t.Fatalf("incorrect sort at index %v", idx) + } + } +} + +func TestBumpMajor(t *testing.T) { + version, _ := NewVersion("1.0.0") + version.BumpMajor() + if version.Major != 2 { + t.Fatalf("bumping major on 1.0.0 resulted in %v", version) + } + + version, _ = NewVersion("1.5.2") + version.BumpMajor() + if version.Minor != 0 && version.Patch != 0 { + t.Fatalf("bumping major on 1.5.2 resulted in %v", version) + } + + version, _ = NewVersion("1.0.0+build.1-alpha.1") + version.BumpMajor() + if version.PreRelease != "" && version.PreRelease != "" { + t.Fatalf("bumping major on 1.0.0+build.1-alpha.1 resulted in %v", version) + } +} + +func TestBumpMinor(t *testing.T) { + version, _ := NewVersion("1.0.0") + version.BumpMinor() + + if version.Major != 1 { + t.Fatalf("bumping minor on 1.0.0 resulted in %v", version) + } + + if version.Minor != 1 { + t.Fatalf("bumping major on 1.0.0 resulted in %v", version) + } + + version, _ = NewVersion("1.0.0+build.1-alpha.1") + version.BumpMinor() + if version.PreRelease != "" && version.PreRelease != "" { + t.Fatalf("bumping major on 1.0.0+build.1-alpha.1 resulted in %v", version) + } +} + +func TestBumpPatch(t *testing.T) { + version, _ := NewVersion("1.0.0") + version.BumpPatch() + + if version.Major != 1 { + t.Fatalf("bumping minor on 1.0.0 resulted in %v", version) + } + + if version.Minor != 0 { + t.Fatalf("bumping major on 1.0.0 resulted in %v", version) + } + + if version.Patch != 1 { + t.Fatalf("bumping major on 1.0.0 resulted in %v", version) + } + + version, _ = NewVersion("1.0.0+build.1-alpha.1") + version.BumpPatch() + if version.PreRelease != "" && version.PreRelease != "" { + t.Fatalf("bumping major on 1.0.0+build.1-alpha.1 resulted in %v", version) + } +} + +func TestMust(t *testing.T) { + tests := []struct { + versionStr string + + version *Version + recov interface{} + }{ + { + versionStr: "1.0.0", + version: &Version{Major: 1}, + }, + { + versionStr: "version number", + recov: errors.New("version number is not in dotted-tri format"), + }, + } + + for _, tt := range tests { + func() { + defer func() { + recov := recover() + if !reflect.DeepEqual(tt.recov, recov) { + t.Fatalf("incorrect panic for %q: want %v, got %v", tt.versionStr, tt.recov, recov) + } + }() + + version := Must(NewVersion(tt.versionStr)) + if !reflect.DeepEqual(tt.version, version) { + t.Fatalf("incorrect version for %q: want %+v, got %+v", tt.versionStr, tt.version, version) + } + }() + } +} + +type fixtureJSON struct { + GreaterVersion *Version + LesserVersion *Version +} + +func TestJSON(t *testing.T) { + fj := make([]fixtureJSON, len(fixtures)) + for i, v := range fixtures { + var err error + fj[i].GreaterVersion, err = NewVersion(v.GreaterVersion) + if err != nil { + t.Fatal(err) + } + fj[i].LesserVersion, err = NewVersion(v.LesserVersion) + if err != nil { + t.Fatal(err) + } + } + + fromStrings, err := json.Marshal(fixtures) + if err != nil { + t.Fatal(err) + } + fromVersions, err := json.Marshal(fj) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(fromStrings, fromVersions) { + t.Errorf("Expected: %s", fromStrings) + t.Errorf("Unexpected: %s", fromVersions) + } + + fromJson := make([]fixtureJSON, 0, len(fj)) + err = json.Unmarshal(fromStrings, &fromJson) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(fromJson, fj) { + t.Error("Expected: ", fj) + t.Error("Unexpected: ", fromJson) + } +} + +func TestYAML(t *testing.T) { + document, err := yaml.Marshal(fixtures) + if err != nil { + t.Fatal(err) + } + + expected := make([]fixtureJSON, len(fixtures)) + for i, v := range fixtures { + var err error + expected[i].GreaterVersion, err = NewVersion(v.GreaterVersion) + if err != nil { + t.Fatal(err) + } + expected[i].LesserVersion, err = NewVersion(v.LesserVersion) + if err != nil { + t.Fatal(err) + } + } + + fromYAML := make([]fixtureJSON, 0, len(fixtures)) + err = yaml.Unmarshal(document, &fromYAML) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(fromYAML, expected) { + t.Error("Expected: ", expected) + t.Error("Unexpected: ", fromYAML) + } +} + +func TestBadInput(t *testing.T) { + bad := []string{ + "1.2", + "1.2.3x", + "0x1.3.4", + "-1.2.3", + "1.2.3.4", + } + for _, b := range bad { + if _, err := NewVersion(b); err == nil { + t.Error("Improperly accepted value: ", b) + } + } +} + +func TestFlag(t *testing.T) { + v := Version{} + f := flag.NewFlagSet("version", flag.ContinueOnError) + f.Var(&v, "version", "set version") + + if err := f.Set("version", "1.2.3"); err != nil { + t.Fatal(err) + } + + if v.String() != "1.2.3" { + t.Errorf("Set wrong value %q", v) + } +} + +func ExampleVersion_LessThan() { + vA := New("1.2.3") + vB := New("3.2.1") + + fmt.Printf("%s < %s == %t\n", vA, vB, vA.LessThan(*vB)) + // Output: + // 1.2.3 < 3.2.1 == true +} diff --git a/vendor/github.com/coreos/go-semver/semver/sort.go b/vendor/github.com/coreos/go-semver/semver/sort.go new file mode 100644 index 00000000..e256b41a --- /dev/null +++ b/vendor/github.com/coreos/go-semver/semver/sort.go @@ -0,0 +1,38 @@ +// Copyright 2013-2015 CoreOS, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package semver + +import ( + "sort" +) + +type Versions []*Version + +func (s Versions) Len() int { + return len(s) +} + +func (s Versions) Swap(i, j int) { + s[i], s[j] = s[j], s[i] +} + +func (s Versions) Less(i, j int) bool { + return s[i].LessThan(*s[j]) +} + +// Sort sorts the given slice of Version +func Sort(versions []*Version) { + sort.Sort(Versions(versions)) +} diff --git a/vendor/github.com/coreos/go-systemd/.travis.yml b/vendor/github.com/coreos/go-systemd/.travis.yml new file mode 100644 index 00000000..aa346efd --- /dev/null +++ b/vendor/github.com/coreos/go-systemd/.travis.yml @@ -0,0 +1,27 @@ +sudo: required + +services: + - docker + +env: + global: + - GOPATH=/opt + - BUILD_DIR=/opt/src/github.com/coreos/go-systemd + matrix: + - DOCKER_BASE=ubuntu:16.04 + - DOCKER_BASE=debian:stretch + +before_install: + - docker pull ${DOCKER_BASE} + - docker run --privileged -e GOPATH=${GOPATH} --cidfile=/tmp/cidfile ${DOCKER_BASE} /bin/bash -c "apt-get update && apt-get install -y build-essential git golang dbus libsystemd-dev libpam-systemd && go get github.com/coreos/pkg/dlopen && go get github.com/godbus/dbus" + - docker commit `cat /tmp/cidfile` go-systemd/container-tests + - rm -f /tmp/cidfile + +install: + - docker run -d --cidfile=/tmp/cidfile --privileged -e GOPATH=${GOPATH} -v ${PWD}:${BUILD_DIR} go-systemd/container-tests /bin/systemd --system + +script: + - docker exec `cat /tmp/cidfile` /bin/bash -c "cd ${BUILD_DIR} && ./test" + +after_script: + - docker kill `cat /tmp/cidfile` diff --git a/vendor/github.com/coreos/go-systemd/CONTRIBUTING.md b/vendor/github.com/coreos/go-systemd/CONTRIBUTING.md new file mode 100644 index 00000000..0551ed53 --- /dev/null +++ b/vendor/github.com/coreos/go-systemd/CONTRIBUTING.md @@ -0,0 +1,77 @@ +# How to Contribute + +CoreOS projects are [Apache 2.0 licensed](LICENSE) and accept contributions via +GitHub pull requests. This document outlines some of the conventions on +development workflow, commit message formatting, contact points and other +resources to make it easier to get your contribution accepted. + +# Certificate of Origin + +By contributing to this project you agree to the Developer Certificate of +Origin (DCO). This document was created by the Linux Kernel community and is a +simple statement that you, as a contributor, have the legal right to make the +contribution. See the [DCO](DCO) file for details. + +# Email and Chat + +The project currently uses the general CoreOS email list and IRC channel: +- Email: [coreos-dev](https://groups.google.com/forum/#!forum/coreos-dev) +- IRC: #[coreos](irc://irc.freenode.org:6667/#coreos) IRC channel on freenode.org + +Please avoid emailing maintainers found in the MAINTAINERS file directly. They +are very busy and read the mailing lists. + +## Getting Started + +- Fork the repository on GitHub +- Read the [README](README.md) for build and test instructions +- Play with the project, submit bugs, submit patches! + +## Contribution Flow + +This is a rough outline of what a contributor's workflow looks like: + +- Create a topic branch from where you want to base your work (usually master). +- Make commits of logical units. +- Make sure your commit messages are in the proper format (see below). +- Push your changes to a topic branch in your fork of the repository. +- Make sure the tests pass, and add any new tests as appropriate. +- Submit a pull request to the original repository. + +Thanks for your contributions! + +### Coding Style + +CoreOS projects written in Go follow a set of style guidelines that we've documented +[here](https://github.com/coreos/docs/tree/master/golang). Please follow them when +working on your contributions. + +### Format of the Commit Message + +We follow a rough convention for commit messages that is designed to answer two +questions: what changed and why. The subject line should feature the what and +the body of the commit should describe the why. + +``` +scripts: add the test-cluster command + +this uses tmux to setup a test cluster that you can easily kill and +start for debugging. + +Fixes #38 +``` + +The format can be described more formally as follows: + +``` +: + + + +