-
Notifications
You must be signed in to change notification settings - Fork 2
/
tls.go
149 lines (120 loc) · 3.96 KB
/
tls.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
package httpc
import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"net/http"
"os"
"path/filepath"
)
// defaultTransport is this package's own copy of net/http's
// DefaultTransport, produced with (*http.Transport).Clone. Working on a clone
// (presumably so that TLS or other settings can be adjusted elsewhere in the
// package) keeps changes from leaking into the process-wide
// http.DefaultTransport. The type assertion panics at package init if
// http.DefaultTransport has been replaced with a non-*http.Transport.
var defaultTransport = func() *http.Transport {
	base := http.DefaultTransport.(*http.Transport)
	return base.Clone()
}()
// setupClientCertificateFromBytes configures mutual TLS from in-memory PEM
// data: a client certificate, its private key and one or more CA certificates.
//
// If tlsConfig is nil, a new config with MinVersion TLS 1.2 is created and
// returned. Otherwise the supplied config is modified and returned.
//
// When the config has no RootCAs, the CA certificates are added to the pool
// returned by x509.SystemCertPool, which then becomes RootCAs. When the config
// already has a pool, the CA certificates are appended to that pool in place.
// The client key pair is appended to tlsConfig.Certificates.
//
// An error is returned in three cases:
//   - the key pair cannot be loaded;
//   - the system pool cannot be obtained;
//   - caCert yields no certificate that AppendCertsFromPEM accepts.
//
// tlsConfig is written to only after all of these steps have succeeded. An
// error therefore never leaves a caller-supplied config half-configured.
func setupClientCertificateFromBytes(clientCert, clientKey, caCert []byte, tlsConfig *tls.Config) (*tls.Config, error) {
	// Load the key pair
	clientKeyCert, err := tls.X509KeyPair(clientCert, clientKey)
	if err != nil {
		return nil, fmt.Errorf("failed to load client key / certificate: %w", err)
	}

	// Pick the pool the CA will be added to without assigning anything to
	// tlsConfig yet.
	var caCertPool *x509.CertPool
	if tlsConfig != nil && tlsConfig.RootCAs != nil {
		caCertPool = tlsConfig.RootCAs
	} else {
		caCertPool, err = x509.SystemCertPool()
		if err != nil {
			return nil, fmt.Errorf("failed to obtain system CA pool: %w", err)
		}
	}

	// AppendCertsFromPEM only reports false when nothing was added, so a
	// failure here leaves an existing pool untouched.
	if !caCertPool.AppendCertsFromPEM(caCert) {
		return nil, fmt.Errorf("failed to add CA certificate to pool")
	}

	// Every fallible step is done; now it is safe to create/modify the config.
	if tlsConfig == nil {
		tlsConfig = &tls.Config{
			MinVersion: tls.VersionTLS12,
		}
	}
	tlsConfig.RootCAs = caCertPool
	tlsConfig.Certificates = append(tlsConfig.Certificates, clientKeyCert)
	return tlsConfig, nil
}
// readClientCertificateFiles loads the raw PEM material for mutual TLS from
// disk. It reads the client certificate and key via readclientKeyCertificate,
// then reads the CA certificate file (after filepath.Clean).
//
// It returns the contents in the order clientCert, clientKey, caCert. If any
// read fails, all three slices are nil and the error is wrapped with a message
// naming the step that failed.
func readClientCertificateFiles(certFile, keyFile, caFile string) ([]byte, []byte, []byte, error) {
	cert, key, readErr := readclientKeyCertificate(certFile, keyFile)
	if readErr != nil {
		return nil, nil, nil, fmt.Errorf("failed to read / decode client key / certificate file: %w", readErr)
	}

	ca, readErr := os.ReadFile(filepath.Clean(caFile))
	if readErr != nil {
		return nil, nil, nil, fmt.Errorf("failed to read CA certificate: %w", readErr)
	}

	return cert, key, ca, nil
}
// readclientKeyCertificate returns the raw contents of the client certificate
// file and the client key file, in that order. Each path is passed through
// filepath.Clean before reading.
//
// The certificate is read first. The first read error is returned unwrapped,
// and both byte slices are nil in that case.
func readclientKeyCertificate(certFile, keyFile string) ([]byte, []byte, error) {
	var contents [2][]byte
	for idx, path := range [...]string{certFile, keyFile} {
		data, readErr := os.ReadFile(filepath.Clean(path))
		if readErr != nil {
			return nil, nil, readErr
		}
		contents[idx] = data
	}
	return contents[0], contents[1], nil
}
// setupClientCertificate configures mutual TLS from an already-parsed client
// certificate (which must carry a private key) and a chain of CA certificates.
//
// If tlsConfig is nil, a new config with MinVersion TLS 1.2 is created and
// returned. Otherwise the supplied config is modified and returned.
//
// When the config has no RootCAs, it is seeded from x509.SystemCertPool
// before the CA certificates are added. When it already has a pool, the CA
// certificates are added to that pool in place. The client certificate is
// appended to tlsConfig.Certificates.
//
// An error is returned in four cases:
//   - the certificate has no private key;
//   - caChain is empty;
//   - caChain contains a nil entry;
//   - the system pool cannot be obtained.
//
// All of these checks happen before tlsConfig is written to.
func setupClientCertificate(clientCertWithKey tls.Certificate, caChain []*x509.Certificate, tlsConfig *tls.Config) (*tls.Config, error) {
	if clientCertWithKey.PrivateKey == nil {
		return nil, fmt.Errorf("supplied certificate does not have a private key")
	}
	if len(caChain) == 0 {
		return nil, fmt.Errorf("no ca certificate(s) supplied")
	}
	// (*x509.CertPool).AddCert panics on a nil certificate; reject nil
	// entries up front so callers get an error instead of a crash.
	for i, cert := range caChain {
		if cert == nil {
			return nil, fmt.Errorf("ca certificate at index %d is nil", i)
		}
	}

	if tlsConfig == nil {
		tlsConfig = &tls.Config{
			MinVersion: tls.VersionTLS12,
		}
	}

	// If required, instantiate CA certificate pool
	if tlsConfig.RootCAs == nil {
		caCertPool, err := x509.SystemCertPool()
		if err != nil {
			return nil, fmt.Errorf("failed to obtain system CA pool: %w", err)
		}
		tlsConfig.RootCAs = caCertPool
	}
	for _, cert := range caChain {
		tlsConfig.RootCAs.AddCert(cert)
	}

	// Append client certificate to config
	tlsConfig.Certificates = append(tlsConfig.Certificates, clientCertWithKey)
	return tlsConfig, nil
}
// pemTypeCertificate is the PEM block type that ParseCAChain accepts.
const pemTypeCertificate = "CERTIFICATE"

// ParseCAChain extracts X.509 certificates from PEM-encoded data and returns
// them in the order they appear.
//
// PEM blocks whose type is not "CERTIFICATE", or that carry PEM headers, are
// skipped. Decoding stops once no further PEM block can be found in the
// remaining data. If the input contains no usable certificate blocks, the
// result is a nil slice and a nil error.
//
// If a CERTIFICATE block fails to parse, an error is returned. It wraps the
// underlying x509 error and names the 0-based index of the failing block among
// the certificate blocks considered so far.
//
// Taken and adapted from crypto/tls.
func ParseCAChain(caCert []byte) ([]*x509.Certificate, error) {
	var caChain []*x509.Certificate
	for len(caCert) > 0 {
		var block *pem.Block
		block, caCert = pem.Decode(caCert)
		if block == nil {
			break
		}
		if block.Type != pemTypeCertificate || len(block.Headers) != 0 {
			continue
		}
		cert, err := x509.ParseCertificate(block.Bytes)
		if err != nil {
			// Wrap (rather than return bare) to match the rest of this file and
			// tell the caller which certificate in the bundle is broken.
			return nil, fmt.Errorf("failed to parse CA certificate at index %d: %w", len(caChain), err)
		}
		caChain = append(caChain, cert)
	}
	return caChain, nil
}