Skip to content

Commit

Permalink
Updates, and update testing files. Remove some deprecations
Browse files Browse the repository at this point in the history
  • Loading branch information
cmeyer18 committed Jan 28, 2024
1 parent 6fa80de commit 282d804
Show file tree
Hide file tree
Showing 37 changed files with 110 additions and 2,375 deletions.
49 changes: 24 additions & 25 deletions certificate/certificate_test.go
Original file line number Diff line number Diff line change
@@ -1,108 +1,107 @@
package certificate_test
package certificate

import (
"crypto/tls"
"errors"
"io/ioutil"
"os"
"testing"

"github.com/cmeyer18/apns2/certificate"
"github.com/stretchr/testify/assert"
)

// PKCS#12

func TestValidCertificateFromP12File(t *testing.T) {
cer, err := certificate.FromP12File("_fixtures/certificate-valid.p12", "")
cer, err := FromP12File("_fixtures/certificate-valid.p12", "")
assert.Nil(t, err)
assert.NotEqual(t, tls.Certificate{}, cer)
}

func TestValidCertificateFromP12Bytes(t *testing.T) {
bytes, _ := ioutil.ReadFile("_fixtures/certificate-valid.p12")
cer, err := certificate.FromP12Bytes(bytes, "")
bytes, _ := os.ReadFile("_fixtures/certificate-valid.p12")
cer, err := FromP12Bytes(bytes, "")
assert.NoError(t, err)
assert.NotEqual(t, tls.Certificate{}, cer)
}

func TestEncryptedValidCertificateFromP12File(t *testing.T) {
cer, err := certificate.FromP12File("_fixtures/certificate-valid-encrypted.p12", "password")
cer, err := FromP12File("_fixtures/certificate-valid-encrypted.p12", "password")
assert.NoError(t, err)
assert.NotEqual(t, tls.Certificate{}, cer)
}

func TestNoSuchFileP12File(t *testing.T) {
cer, err := certificate.FromP12File("", "")
cer, err := FromP12File("", "")
assert.Equal(t, errors.New("open : no such file or directory").Error(), err.Error())
assert.Equal(t, tls.Certificate{}, cer)
}

func TestBadPasswordP12File(t *testing.T) {
cer, err := certificate.FromP12File("_fixtures/certificate-valid-encrypted.p12", "")
cer, err := FromP12File("_fixtures/certificate-valid-encrypted.p12", "")
assert.Equal(t, tls.Certificate{}, cer)
assert.Equal(t, errors.New("pkcs12: decryption password incorrect").Error(), err.Error())
}

// PEM

func TestValidCertificateFromPemFile(t *testing.T) {
cer, err := certificate.FromPemFile("_fixtures/certificate-valid.pem", "")
cer, err := FromPemFile("_fixtures/certificate-valid.pem", "")
assert.NoError(t, err)
assert.NotEqual(t, tls.Certificate{}, cer)
}

func TestValidCertificateFromPemBytes(t *testing.T) {
bytes, _ := ioutil.ReadFile("_fixtures/certificate-valid.pem")
cer, err := certificate.FromPemBytes(bytes, "")
bytes, _ := os.ReadFile("_fixtures/certificate-valid.pem")
cer, err := FromPemBytes(bytes, "")
assert.NoError(t, err)
assert.NotEqual(t, tls.Certificate{}, cer)
}

func TestValidCertificateFromPemFileWithPKCS8PrivateKey(t *testing.T) {
cer, err := certificate.FromPemFile("_fixtures/certificate-valid-pkcs8.pem", "")
cer, err := FromPemFile("_fixtures/certificate-valid-pkcs8.pem", "")
assert.NoError(t, err)
assert.NotEqual(t, tls.Certificate{}, cer)
}

func TestValidCertificateFromPemBytesWithPKCS8PrivateKey(t *testing.T) {
bytes, _ := ioutil.ReadFile("_fixtures/certificate-valid-pkcs8.pem")
cer, err := certificate.FromPemBytes(bytes, "")
bytes, _ := os.ReadFile("_fixtures/certificate-valid-pkcs8.pem")
cer, err := FromPemBytes(bytes, "")
assert.NoError(t, err)
assert.NotEqual(t, tls.Certificate{}, cer)
}

func TestEncryptedValidCertificateFromPemFile(t *testing.T) {
cer, err := certificate.FromPemFile("_fixtures/certificate-valid-encrypted.pem", "password")
cer, err := FromPemFile("_fixtures/certificate-valid-encrypted.pem", "password")
assert.NoError(t, err)
assert.NotEqual(t, tls.Certificate{}, cer)
}

func TestNoSuchFilePemFile(t *testing.T) {
cer, err := certificate.FromPemFile("", "")
cer, err := FromPemFile("", "")
assert.Equal(t, tls.Certificate{}, cer)
assert.Equal(t, errors.New("open : no such file or directory").Error(), err.Error())
}

func TestBadPasswordPemFile(t *testing.T) {
cer, err := certificate.FromPemFile("_fixtures/certificate-valid-encrypted.pem", "badpassword")
cer, err := FromPemFile("_fixtures/certificate-valid-encrypted.pem", "badpassword")
assert.Equal(t, tls.Certificate{}, cer)
assert.Equal(t, certificate.ErrFailedToDecryptKey, err)
assert.Equal(t, ErrFailedToDecryptKey, err)
}

func TestBadKeyPemFile(t *testing.T) {
cer, err := certificate.FromPemFile("_fixtures/certificate-bad-key.pem", "")
cer, err := FromPemFile("_fixtures/certificate-bad-key.pem", "")
assert.Equal(t, tls.Certificate{}, cer)
assert.Equal(t, certificate.ErrFailedToParsePrivateKey, err)
assert.Equal(t, ErrFailedToParsePrivateKey, err)
}

func TestNoKeyPemFile(t *testing.T) {
cer, err := certificate.FromPemFile("_fixtures/certificate-no-key.pem", "")
cer, err := FromPemFile("_fixtures/certificate-no-key.pem", "")
assert.Equal(t, tls.Certificate{}, cer)
assert.Equal(t, certificate.ErrNoPrivateKey, err)
assert.Equal(t, ErrNoPrivateKey, err)
}

func TestNoCertificatePemFile(t *testing.T) {
cer, err := certificate.FromPemFile("_fixtures/certificate-no-certificate.pem", "")
cer, err := FromPemFile("_fixtures/certificate-no-pem", "")
assert.Equal(t, tls.Certificate{}, cer)
assert.Equal(t, certificate.ErrNoCertificate, err)
assert.Equal(t, ErrNoCertificate, err)
}
8 changes: 4 additions & 4 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ var (
TLSDialTimeout = 20 * time.Second
)

// DialTLS is the default dial function for creating TLS connections for
// DialTLSContext is the default dial function for creating TLS connections for
// non-proxied HTTPS requests.
var DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) {
var DialTLSContext = func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
dialer := &net.Dialer{
Timeout: TLSDialTimeout,
KeepAlive: TCPKeepAlive,
Expand Down Expand Up @@ -97,7 +97,7 @@ func NewClient(certificate tls.Certificate) *Client {
}
transport := &http2.Transport{
TLSClientConfig: tlsConfig,
DialTLS: DialTLS,
DialTLSContext: DialTLSContext,
ReadIdleTimeout: ReadIdleTimeout,
}
return &Client{
Expand All @@ -120,7 +120,7 @@ func NewClient(certificate tls.Certificate) *Client {
// connection and disconnection as a denial-of-service attack.
func NewTokenClient(token *token.Token) *Client {
transport := &http2.Transport{
DialTLS: DialTLS,
DialTLSContext: DialTLSContext,
ReadIdleTimeout: ReadIdleTimeout,
}
return &Client{
Expand Down
47 changes: 23 additions & 24 deletions client_manager_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package apns2_test
package apns2

import (
"bytes"
Expand All @@ -8,22 +8,21 @@ import (
"testing"
"time"

"github.com/cmeyer18/apns2"
"github.com/cmeyer18/apns2/certificate"
"github.com/stretchr/testify/assert"
)

func TestNewClientManager(t *testing.T) {
manager := apns2.NewClientManager()
manager := NewClientManager()
assert.Equal(t, manager.MaxSize, 64)
assert.Equal(t, manager.MaxAge, 10*time.Minute)
}

func TestClientManagerGetWithoutNew(t *testing.T) {
manager := apns2.ClientManager{
manager := ClientManager{
MaxSize: 32,
MaxAge: 5 * time.Minute,
Factory: apns2.NewClient,
Factory: NewClient,
}

c1 := manager.Get(mockCert())
Expand All @@ -38,16 +37,16 @@ func TestClientManagerGetWithoutNew(t *testing.T) {
func TestClientManagerAddWithoutNew(t *testing.T) {
wg := sync.WaitGroup{}

manager := apns2.ClientManager{
manager := ClientManager{
MaxSize: 1,
MaxAge: 5 * time.Minute,
Factory: apns2.NewClient,
Factory: NewClient,
}

for i := 0; i < 2; i++ {
wg.Add(1)
go func() {
manager.Add(apns2.NewClient(mockCert()))
manager.Add(NewClient(mockCert()))
assert.Equal(t, 1, manager.Len())
wg.Done()
}()
Expand All @@ -56,17 +55,17 @@ func TestClientManagerAddWithoutNew(t *testing.T) {
}

func TestClientManagerLenWithoutNew(t *testing.T) {
manager := apns2.ClientManager{
manager := ClientManager{
MaxSize: 32,
MaxAge: 5 * time.Minute,
Factory: apns2.NewClient,
Factory: NewClient,
}

assert.Equal(t, 0, manager.Len())
}

func TestClientManagerGetDefaultOptions(t *testing.T) {
manager := apns2.NewClientManager()
manager := NewClientManager()
c1 := manager.Get(mockCert())
c2 := manager.Get(mockCert())
v1 := reflect.ValueOf(c1)
Expand All @@ -77,8 +76,8 @@ func TestClientManagerGetDefaultOptions(t *testing.T) {
}

func TestClientManagerGetNilClientFactory(t *testing.T) {
manager := apns2.NewClientManager()
manager.Factory = func(certificate tls.Certificate) *apns2.Client {
manager := NewClientManager()
manager.Factory = func(certificate tls.Certificate) *Client {
return nil
}
c1 := manager.Get(mockCert())
Expand All @@ -89,7 +88,7 @@ func TestClientManagerGetNilClientFactory(t *testing.T) {
}

func TestClientManagerGetMaxAgeExpiration(t *testing.T) {
manager := apns2.NewClientManager()
manager := NewClientManager()
manager.MaxAge = time.Nanosecond
c1 := manager.Get(mockCert())
time.Sleep(time.Microsecond)
Expand All @@ -102,12 +101,12 @@ func TestClientManagerGetMaxAgeExpiration(t *testing.T) {
}

func TestClientManagerGetMaxAgeExpirationWithNilFactory(t *testing.T) {
manager := apns2.NewClientManager()
manager.Factory = func(certificate tls.Certificate) *apns2.Client {
manager := NewClientManager()
manager.Factory = func(certificate tls.Certificate) *Client {
return nil
}
manager.MaxAge = time.Nanosecond
manager.Add(apns2.NewClient(mockCert()))
manager.Add(NewClient(mockCert()))
c1 := manager.Get(mockCert())
time.Sleep(time.Microsecond)
c2 := manager.Get(mockCert())
Expand All @@ -117,7 +116,7 @@ func TestClientManagerGetMaxAgeExpirationWithNilFactory(t *testing.T) {
}

func TestClientManagerGetMaxSizeExceeded(t *testing.T) {
manager := apns2.NewClientManager()
manager := NewClientManager()
manager.MaxSize = 1
cert1 := mockCert()
_ = manager.Get(cert1)
Expand All @@ -130,20 +129,20 @@ func TestClientManagerGetMaxSizeExceeded(t *testing.T) {
}

func TestClientManagerAdd(t *testing.T) {
fn := func(certificate tls.Certificate) *apns2.Client {
fn := func(certificate tls.Certificate) *Client {
t.Fatal("factory should not have been called")
return nil
}

manager := apns2.NewClientManager()
manager := NewClientManager()
manager.Factory = fn
manager.Add(apns2.NewClient(mockCert()))
manager.Add(NewClient(mockCert()))
manager.Get(mockCert())
}

func TestClientManagerAddTwice(t *testing.T) {
manager := apns2.NewClientManager()
manager.Add(apns2.NewClient(mockCert()))
manager.Add(apns2.NewClient(mockCert()))
manager := NewClientManager()
manager.Add(NewClient(mockCert()))
manager.Add(NewClient(mockCert()))
assert.Equal(t, 1, manager.Len())
}
Loading

0 comments on commit 282d804

Please sign in to comment.