diff --git a/security/advancedtls/advancedtls_test.go b/security/advancedtls/advancedtls_test.go index 4d9d0fb871a..c4122490a8b 100644 --- a/security/advancedtls/advancedtls_test.go +++ b/security/advancedtls/advancedtls_test.go @@ -32,11 +32,12 @@ import ( "testing" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/security/advancedtls/testdata" ) func TestClientServerHandshake(t *testing.T) { // ------------------Load Client Trust Cert and Peer Cert------------------- - clientTrustPool, err := readTrustCert("testdata/client_trust_cert_1.pem") + clientTrustPool, err := readTrustCert(testdata.Path("client_trust_cert_1.pem")) if err != nil { t.Fatalf("Client is unable to load trust certs. Error: %v", err) } @@ -50,21 +51,21 @@ func TestClientServerHandshake(t *testing.T) { } return results, fmt.Errorf("custom verification function failed") } - clientPeerCert, err := tls.LoadX509KeyPair("testdata/client_cert_1.pem", - "testdata/client_key_1.pem") + clientPeerCert, err := tls.LoadX509KeyPair(testdata.Path("client_cert_1.pem"), + testdata.Path("client_key_1.pem")) if err != nil { t.Fatalf("Client is unable to parse peer certificates. Error: %v", err) } // ------------------Load Server Trust Cert and Peer Cert------------------- - serverTrustPool, err := readTrustCert("testdata/server_trust_cert_1.pem") + serverTrustPool, err := readTrustCert(testdata.Path("server_trust_cert_1.pem")) if err != nil { t.Fatalf("Server is unable to load trust certs. Error: %v", err) } getRootCAsForServer := func(params *GetRootCAsParams) (*GetRootCAsResults, error) { return &GetRootCAsResults{TrustCerts: serverTrustPool}, nil } - serverPeerCert, err := tls.LoadX509KeyPair("testdata/server_cert_1.pem", - "testdata/server_key_1.pem") + serverPeerCert, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"), + testdata.Path("server_key_1.pem")) if err != nil { t.Fatalf("Server is unable to parse peer certificates. Error: %v", err) } @@ -538,7 +539,7 @@ func compare(a1, a2 credentials.AuthInfo) bool { func TestAdvancedTLSOverrideServerName(t *testing.T) { expectedServerName := "server.name" - clientTrustPool, err := readTrustCert("testdata/client_trust_cert_1.pem") + clientTrustPool, err := readTrustCert(testdata.Path("client_trust_cert_1.pem")) if err != nil { t.Fatalf("Client is unable to load trust certs. Error: %v", err) } @@ -560,7 +561,7 @@ func TestAdvancedTLSOverrideServerName(t *testing.T) { func TestTLSClone(t *testing.T) { expectedServerName := "server.name" - clientTrustPool, err := readTrustCert("testdata/client_trust_cert_1.pem") + clientTrustPool, err := readTrustCert(testdata.Path("client_trust_cert_1.pem")) if err != nil { t.Fatalf("Client is unable to load trust certs. Error: %v", err) } @@ -571,6 +572,9 @@ func TestTLSClone(t *testing.T) { ServerNameOverride: expectedServerName, } c, err := NewClient(clientOptions) + if err != nil { + t.Fatalf("Failed to create new client: %v", err) + } cc := c.Clone() if cc.Info().ServerName != expectedServerName { t.Fatalf("cc.Info().ServerName = %v, want %v", cc.Info().ServerName, expectedServerName) diff --git a/security/advancedtls/testdata/testdata.go b/security/advancedtls/testdata/testdata.go new file mode 100644 index 00000000000..c7d2481c4a0 --- /dev/null +++ b/security/advancedtls/testdata/testdata.go @@ -0,0 +1,42 @@ +/* + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package testdata + +import ( + "path/filepath" + "runtime" +) + +// basepath is the root directory of this package. +var basepath string + +func init() { + _, currentFile, _, _ := runtime.Caller(0) + basepath = filepath.Dir(currentFile) +} + +// Path returns the absolute path the given relative file or directory path, +// relative to the google.golang.org/grpc/testdata directory in the user's GOPATH. +// If rel is already absolute, it is returned unmodified. +func Path(rel string) string { + if filepath.IsAbs(rel) { + return rel + } + + return filepath.Join(basepath, rel) +}