Skip to content

Commit

Permalink
Refactor some common utilities which will be used by the upcoming TSO…
Browse files Browse the repository at this point in the history
… microservice change (tikv#5985)

ref tikv#5836, ref tikv#5949

Refactor some common utilities which will be used by tso mcs

Signed-off-by: Bin Shi <binshi.bing@gmail.com>
  • Loading branch information
binshi-bing authored and nolouch committed Feb 24, 2023
1 parent e853333 commit 1141fbc
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 84 deletions.
60 changes: 60 additions & 0 deletions pkg/utils/etcdutil/etcdutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"crypto/tls"
"fmt"
"math/rand"
"net/http"
"net/url"
"testing"
Expand All @@ -28,6 +29,7 @@ import (
"github.com/pingcap/log"
"github.com/tikv/pd/pkg/errs"
"github.com/tikv/pd/pkg/utils/tempurl"
"github.com/tikv/pd/pkg/utils/typeutil"
"go.etcd.io/etcd/clientv3"
"go.etcd.io/etcd/embed"
"go.etcd.io/etcd/etcdserver"
Expand Down Expand Up @@ -230,3 +232,61 @@ func CreateClients(tlsConfig *tls.Config, acUrls []url.URL) (*clientv3.Client, *
log.Info("create etcd v3 client", zap.Strings("endpoints", endpoints))
return client, httpClient, nil
}

// InitClusterID creates a cluster ID for the given key if it hasn't existed.
// This function assumes the cluster ID has already existed and always use a
// cheaper read to retrieve it; if it doesn't exist, invoke the more expensive
// operation InitOrGetClusterID().
func InitClusterID(c *clientv3.Client, key string) (clusterID uint64, err error) {
// Get any cluster key to parse the cluster ID.
resp, err := EtcdKVGet(c, key)
if err != nil {
return 0, err
}
// If no key exist, generate a random cluster ID.
if len(resp.Kvs) == 0 {
return InitOrGetClusterID(c, key)
}
return typeutil.BytesToUint64(resp.Kvs[0].Value)
}

// InitOrGetClusterID creates a cluster ID for the given key with a CAS operation,
// if the cluster ID doesn't exist.
func InitOrGetClusterID(c *clientv3.Client, key string) (uint64, error) {
ctx, cancel := context.WithTimeout(c.Ctx(), DefaultRequestTimeout)
defer cancel()

// Generate a random cluster ID.
ts := uint64(time.Now().Unix())
clusterID := (ts << 32) + uint64(rand.Uint32())
value := typeutil.Uint64ToBytes(clusterID)

// Multiple servers may try to init the cluster ID at the same time.
// Only one server can commit this transaction, then other servers
// can get the committed cluster ID.
resp, err := c.Txn(ctx).
If(clientv3.Compare(clientv3.CreateRevision(key), "=", 0)).
Then(clientv3.OpPut(key, string(value))).
Else(clientv3.OpGet(key)).
Commit()
if err != nil {
return 0, errs.ErrEtcdTxnInternal.Wrap(err).GenWithStackByCause()
}

// Txn commits ok, return the generated cluster ID.
if resp.Succeeded {
return clusterID, nil
}

// Otherwise, parse the committed cluster ID.
if len(resp.Responses) == 0 {
return 0, errs.ErrEtcdTxnConflict.FastGenByArgs()
}

response := resp.Responses[0].GetResponseRange()
if response == nil || len(response.Kvs) != 1 {
return 0, errs.ErrEtcdTxnConflict.FastGenByArgs()
}

return typeutil.BytesToUint64(response.Kvs[0].Value)
}
33 changes: 33 additions & 0 deletions pkg/utils/etcdutil/etcdutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,36 @@ func TestEtcdKVPutWithTTL(t *testing.T) {
re.NoError(err)
re.Equal(int64(0), resp.Count)
}

func TestInitClusterID(t *testing.T) {
t.Parallel()
re := require.New(t)
cfg := NewTestSingleConfig(t)
etcd, err := embed.StartEtcd(cfg)
defer func() {
etcd.Close()
}()
re.NoError(err)

ep := cfg.LCUrls[0].String()
client, err := clientv3.New(clientv3.Config{
Endpoints: []string{ep},
})
re.NoError(err)

<-etcd.Server.ReadyNotify()

pdClusterIDPath := "test/TestInitClusterID/pd/cluster_id"
// Get any cluster key to parse the cluster ID.
resp, err := EtcdKVGet(client, pdClusterIDPath)
re.NoError(err)
re.Equal(0, len(resp.Kvs))

clusterID, err := InitClusterID(client, pdClusterIDPath)
re.NoError(err)
re.NotEqual(0, clusterID)

clusterID1, err := InitClusterID(client, pdClusterIDPath)
re.NoError(err)
re.Equal(clusterID, clusterID1)
}
12 changes: 12 additions & 0 deletions pkg/utils/grpcutil/grpcutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,15 @@ func ResetForwardContext(ctx context.Context) context.Context {
md.Set(ForwardMetadataKey, "")
return metadata.NewOutgoingContext(ctx, md)
}

// GetForwardedHost returns the forwarded host in metadata.
func GetForwardedHost(ctx context.Context) string {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
log.Debug("failed to get forwarding metadata")
}
if t, ok := md[ForwardMetadataKey]; ok {
return t[0]
}
return ""
}
26 changes: 7 additions & 19 deletions server/grpc_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ import (
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)

Expand Down Expand Up @@ -74,7 +73,7 @@ func (s *GrpcServer) unaryMiddleware(ctx context.Context, header *pdpb.RequestHe
failpoint.Inject("customTimeout", func() {
time.Sleep(5 * time.Second)
})
forwardedHost := getForwardedHost(ctx)
forwardedHost := grpcutil.GetForwardedHost(ctx)
if !s.isLocalRequest(forwardedHost) {
client, err := s.getDelegateClient(ctx, forwardedHost)
if err != nil {
Expand Down Expand Up @@ -167,7 +166,7 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error {
}

streamCtx := stream.Context()
forwardedHost := getForwardedHost(streamCtx)
forwardedHost := grpcutil.GetForwardedHost(streamCtx)
if !s.isLocalRequest(forwardedHost) {
if errCh == nil {
doneCh = make(chan struct{})
Expand Down Expand Up @@ -766,7 +765,7 @@ func (s *GrpcServer) ReportBuckets(stream pdpb.PD_ReportBucketsServer) error {
if err != nil {
return errors.WithStack(err)
}
forwardedHost := getForwardedHost(stream.Context())
forwardedHost := grpcutil.GetForwardedHost(stream.Context())
failpoint.Inject("grpcClientClosed", func() {
forwardedHost = s.GetMember().Member().GetClientUrls()[0]
})
Expand Down Expand Up @@ -861,7 +860,7 @@ func (s *GrpcServer) RegionHeartbeat(stream pdpb.PD_RegionHeartbeatServer) error
return errors.WithStack(err)
}

forwardedHost := getForwardedHost(stream.Context())
forwardedHost := grpcutil.GetForwardedHost(stream.Context())
if !s.isLocalRequest(forwardedHost) {
if forwardStream == nil || lastForwardedHost != forwardedHost {
if cancel != nil {
Expand Down Expand Up @@ -1786,17 +1785,6 @@ func (s *GrpcServer) getDelegateClient(ctx context.Context, forwardedHost string
return client.(*grpc.ClientConn), nil
}

func getForwardedHost(ctx context.Context) string {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
log.Debug("failed to get forwarding metadata")
}
if t, ok := md[grpcutil.ForwardMetadataKey]; ok {
return t[0]
}
return ""
}

func (s *GrpcServer) isLocalRequest(forwardedHost string) bool {
failpoint.Inject("useForwardRequest", func() {
failpoint.Return(false)
Expand Down Expand Up @@ -2044,7 +2032,7 @@ func (s *GrpcServer) handleDamagedStore(stats *pdpb.StoreStats) {

// ReportMinResolvedTS implements gRPC PDServer.
func (s *GrpcServer) ReportMinResolvedTS(ctx context.Context, request *pdpb.ReportMinResolvedTsRequest) (*pdpb.ReportMinResolvedTsResponse, error) {
forwardedHost := getForwardedHost(ctx)
forwardedHost := grpcutil.GetForwardedHost(ctx)
if !s.isLocalRequest(forwardedHost) {
client, err := s.getDelegateClient(ctx, forwardedHost)
if err != nil {
Expand Down Expand Up @@ -2078,7 +2066,7 @@ func (s *GrpcServer) ReportMinResolvedTS(ctx context.Context, request *pdpb.Repo

// SetExternalTimestamp implements gRPC PDServer.
func (s *GrpcServer) SetExternalTimestamp(ctx context.Context, request *pdpb.SetExternalTimestampRequest) (*pdpb.SetExternalTimestampResponse, error) {
forwardedHost := getForwardedHost(ctx)
forwardedHost := grpcutil.GetForwardedHost(ctx)
if !s.isLocalRequest(forwardedHost) {
client, err := s.getDelegateClient(ctx, forwardedHost)
if err != nil {
Expand All @@ -2105,7 +2093,7 @@ func (s *GrpcServer) SetExternalTimestamp(ctx context.Context, request *pdpb.Set

// GetExternalTimestamp implements gRPC PDServer.
func (s *GrpcServer) GetExternalTimestamp(ctx context.Context, request *pdpb.GetExternalTimestampRequest) (*pdpb.GetExternalTimestampResponse, error) {
forwardedHost := getForwardedHost(ctx)
forwardedHost := grpcutil.GetForwardedHost(ctx)
if !s.isLocalRequest(forwardedHost) {
client, err := s.getDelegateClient(ctx, forwardedHost)
if err != nil {
Expand Down
18 changes: 1 addition & 17 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ func (s *Server) AddStartCallback(callbacks ...func()) {

func (s *Server) startServer(ctx context.Context) error {
var err error
if err = s.initClusterID(); err != nil {
if s.clusterID, err = etcdutil.InitClusterID(s.client, pdClusterIDPath); err != nil {
return err
}
log.Info("init cluster id", zap.Uint64("cluster-id", s.clusterID))
Expand Down Expand Up @@ -408,22 +408,6 @@ func (s *Server) startServer(ctx context.Context) error {
return nil
}

func (s *Server) initClusterID() error {
// Get any cluster key to parse the cluster ID.
resp, err := etcdutil.EtcdKVGet(s.client, pdClusterIDPath)
if err != nil {
return err
}

// If no key exist, generate a random cluster ID.
if len(resp.Kvs) == 0 {
s.clusterID, err = initOrGetClusterID(s.client, pdClusterIDPath)
return err
}
s.clusterID, err = typeutil.BytesToUint64(resp.Kvs[0].Value)
return err
}

// AddCloseCallback adds a callback in the Close phase.
func (s *Server) AddCloseCallback(callbacks ...func()) {
s.closeCallbacks = append(s.closeCallbacks, callbacks...)
Expand Down
48 changes: 0 additions & 48 deletions server/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,21 @@ package server
import (
"context"
"fmt"
"math/rand"
"net/http"
"strings"
"time"

"github.com/gorilla/mux"
"github.com/pingcap/errors"
"github.com/pingcap/kvproto/pkg/pdpb"
"github.com/pingcap/log"
"github.com/tikv/pd/pkg/errs"
"github.com/tikv/pd/pkg/utils/apiutil"
"github.com/tikv/pd/pkg/utils/etcdutil"
"github.com/tikv/pd/pkg/utils/typeutil"
"github.com/tikv/pd/pkg/versioninfo"
"github.com/tikv/pd/server/config"
"github.com/urfave/negroni"
"go.etcd.io/etcd/clientv3"
"go.uber.org/zap"
)

const (
requestTimeout = etcdutil.DefaultRequestTimeout
)

// LogPDInfo prints the PD version information.
func LogPDInfo() {
log.Info("Welcome to Placement Driver (PD)")
Expand Down Expand Up @@ -88,45 +79,6 @@ func CheckPDVersion(opt *config.PersistOptions) {
}
}

func initOrGetClusterID(c *clientv3.Client, key string) (uint64, error) {
ctx, cancel := context.WithTimeout(c.Ctx(), requestTimeout)
defer cancel()

// Generate a random cluster ID.
ts := uint64(time.Now().Unix())
clusterID := (ts << 32) + uint64(rand.Uint32())
value := typeutil.Uint64ToBytes(clusterID)

// Multiple PDs may try to init the cluster ID at the same time.
// Only one PD can commit this transaction, then other PDs can get
// the committed cluster ID.
resp, err := c.Txn(ctx).
If(clientv3.Compare(clientv3.CreateRevision(key), "=", 0)).
Then(clientv3.OpPut(key, string(value))).
Else(clientv3.OpGet(key)).
Commit()
if err != nil {
return 0, errs.ErrEtcdTxnInternal.Wrap(err).GenWithStackByCause()
}

// Txn commits ok, return the generated cluster ID.
if resp.Succeeded {
return clusterID, nil
}

// Otherwise, parse the committed cluster ID.
if len(resp.Responses) == 0 {
return 0, errs.ErrEtcdTxnConflict.FastGenByArgs()
}

response := resp.Responses[0].GetResponseRange()
if response == nil || len(response.Kvs) != 1 {
return 0, errs.ErrEtcdTxnConflict.FastGenByArgs()
}

return typeutil.BytesToUint64(response.Kvs[0].Value)
}

func checkBootstrapRequest(clusterID uint64, req *pdpb.BootstrapRequest) error {
// TODO: do more check for request fields validation.

Expand Down

0 comments on commit 1141fbc

Please sign in to comment.