Skip to content

Commit

Permalink
advancedtls: add package for testdata (#3306)
Browse files Browse the repository at this point in the history
  • Loading branch information
menghanl committed Jan 10, 2020
1 parent 336cf8d commit 20bce9a
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 8 deletions.
20 changes: 12 additions & 8 deletions security/advancedtls/advancedtls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@ import (
"testing"

"google.golang.org/grpc/credentials"
"google.golang.org/grpc/security/advancedtls/testdata"
)

func TestClientServerHandshake(t *testing.T) {
// ------------------Load Client Trust Cert and Peer Cert-------------------
clientTrustPool, err := readTrustCert("testdata/client_trust_cert_1.pem")
clientTrustPool, err := readTrustCert(testdata.Path("client_trust_cert_1.pem"))
if err != nil {
t.Fatalf("Client is unable to load trust certs. Error: %v", err)
}
Expand All @@ -50,21 +51,21 @@ func TestClientServerHandshake(t *testing.T) {
}
return results, fmt.Errorf("custom verification function failed")
}
clientPeerCert, err := tls.LoadX509KeyPair("testdata/client_cert_1.pem",
"testdata/client_key_1.pem")
clientPeerCert, err := tls.LoadX509KeyPair(testdata.Path("client_cert_1.pem"),
testdata.Path("client_key_1.pem"))
if err != nil {
t.Fatalf("Client is unable to parse peer certificates. Error: %v", err)
}
// ------------------Load Server Trust Cert and Peer Cert-------------------
serverTrustPool, err := readTrustCert("testdata/server_trust_cert_1.pem")
serverTrustPool, err := readTrustCert(testdata.Path("server_trust_cert_1.pem"))
if err != nil {
t.Fatalf("Server is unable to load trust certs. Error: %v", err)
}
getRootCAsForServer := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
return &GetRootCAsResults{TrustCerts: serverTrustPool}, nil
}
serverPeerCert, err := tls.LoadX509KeyPair("testdata/server_cert_1.pem",
"testdata/server_key_1.pem")
serverPeerCert, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"),
testdata.Path("server_key_1.pem"))
if err != nil {
t.Fatalf("Server is unable to parse peer certificates. Error: %v", err)
}
Expand Down Expand Up @@ -538,7 +539,7 @@ func compare(a1, a2 credentials.AuthInfo) bool {

func TestAdvancedTLSOverrideServerName(t *testing.T) {
expectedServerName := "server.name"
clientTrustPool, err := readTrustCert("testdata/client_trust_cert_1.pem")
clientTrustPool, err := readTrustCert(testdata.Path("client_trust_cert_1.pem"))
if err != nil {
t.Fatalf("Client is unable to load trust certs. Error: %v", err)
}
Expand All @@ -560,7 +561,7 @@ func TestAdvancedTLSOverrideServerName(t *testing.T) {

func TestTLSClone(t *testing.T) {
expectedServerName := "server.name"
clientTrustPool, err := readTrustCert("testdata/client_trust_cert_1.pem")
clientTrustPool, err := readTrustCert(testdata.Path("client_trust_cert_1.pem"))
if err != nil {
t.Fatalf("Client is unable to load trust certs. Error: %v", err)
}
Expand All @@ -571,6 +572,9 @@ func TestTLSClone(t *testing.T) {
ServerNameOverride: expectedServerName,
}
c, err := NewClient(clientOptions)
if err != nil {
t.Fatalf("Failed to create new client: %v", err)
}
cc := c.Clone()
if cc.Info().ServerName != expectedServerName {
t.Fatalf("cc.Info().ServerName = %v, want %v", cc.Info().ServerName, expectedServerName)
Expand Down
42 changes: 42 additions & 0 deletions security/advancedtls/testdata/testdata.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package testdata

import (
"path/filepath"
"runtime"
)

// basepath is the root directory of this package.
var basepath string

func init() {
_, currentFile, _, _ := runtime.Caller(0)
basepath = filepath.Dir(currentFile)
}

// Path returns the absolute path the given relative file or directory path,
// relative to the google.golang.org/grpc/testdata directory in the user's GOPATH.
// If rel is already absolute, it is returned unmodified.
func Path(rel string) string {
if filepath.IsAbs(rel) {
return rel
}

return filepath.Join(basepath, rel)
}

0 comments on commit 20bce9a

Please sign in to comment.