Skip to content
Permalink
Browse files

rgetct: cleanup the code significantly

- remove lots of printing at random
- refactor TLS and X509 SCT checking into separate functions
- only check x509 SCTs in rget for now
  • Loading branch information
philips committed Jul 22, 2019
1 parent 38f5147 commit 1a2ca169900630eaca56d0bd6a48f034f3c093d2
Showing with 89 additions and 53 deletions.
  1. +33 −11 rget/cmd/root.go
  2. +56 −42 rgetct/ct.go
@@ -17,11 +17,13 @@ package cmd
import (
"context"
"crypto/sha256"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"strings"
"time"

homedir "github.com/mitchellh/go-homedir"
@@ -104,17 +106,35 @@ func initConfig() {
}
}

func validSCTs(valid, invalid int, cturl string, logs []loglist.Log) string {
var names []string
for _, l := range logs {
names = append(names, l.Description)
}
return fmt.Sprintf("validated %d/%d SCTs in logs %q ", valid, (valid + invalid), strings.Join(names, ", "))
}

func levelSCTs(valid, invalid int) (string, error) {
switch {
case valid != 0 && invalid == 0:
return "OK", nil
case valid == 0:
return "Error", errors.New("no valid SCTs")
default:
return "Warning", nil
}
}

func get(cmd *cobra.Command, args []string) {
var chain []*x509.Certificate
var valid, invalid int
var totalInvalid int

durl := args[0]

// Step 1: Download the SHA256SUMS that is correct for the URL
prefix, err := rgetwellknown.SumPrefix(durl)
sumsURL := prefix + "SHA256SUMS"
fmt.Printf("Downloading sums file: %v\n", sumsURL)
fmt.Printf("downloading sums: %v\n", sumsURL)
response, err := http.Get(sumsURL)
var sha256file []byte
if err != nil {
@@ -140,6 +160,8 @@ func get(cmd *cobra.Command, args []string) {
sums := rgethash.FromSHA256SumFile(string(sha256file))
cturl := "https://" + sums.Domain() + "." + domain + "." + rgetwellknown.PublicServiceHost

fmt.Printf("validating transparency URL: %v\n", cturl)

hc := &http.Client{Timeout: 30 * time.Second}
ctx := context.Background()
lf := ctutil.NewLogInfo
@@ -156,20 +178,20 @@ func get(cmd *cobra.Command, args []string) {
os.Exit(1)
}

// Get chain served online for TLS connection to site, and check any SCTs
// provided alongside on the connection along the way.
chain, valid, invalid, err = rgetct.GetAndCheckSiteChain(ctx, lf, cturl, ll, hc)
// _ to skip TLS extension SCTs, rget doesn't use those yet
chain, _, err = rgetct.GetSiteSCTs(ctx, cturl, hc)
if err != nil {
fmt.Printf("%s: failed to get cert chain: %v\n", cturl, err)
os.Exit(1)
}
fmt.Printf("Found %d external SCTs for %q, of which %d were validated\n", (valid + invalid), cturl, valid)
totalInvalid += invalid

// Check the chain for embedded SCTs.
valid, invalid = rgetct.CheckChain(ctx, lf, chain, ll, hc)
fmt.Printf("Found %d embedded SCTs for %q, of which %d were validated\n", (valid + invalid), domain, valid)
totalInvalid += invalid
// Check x509 chain SCTs
valid, invalid, logs := rgetct.CheckX509(ctx, lf, chain, ll, hc)
lvl, err := levelSCTs(valid, invalid)
fmt.Printf("%s: x509 SCTs: %s\n", lvl, validSCTs(valid, invalid, cturl, logs))
if err != nil {
os.Exit(1)
}

// create download request
req, err := grab.NewRequest("", durl)
@@ -22,12 +22,12 @@ import (

type logInfoFactory func(*loglist.Log, *http.Client) (*ctutil.LogInfo, error)

// CheckChain iterates over any embedded SCTs in the leaf certificate of the chain
// CheckX509 iterates over any X509 extension SCTs in the leaf certificate of the chain
// and checks those SCTs. Returns the counts of valid and invalid embedded SCTs found.
func CheckChain(ctx context.Context, lf logInfoFactory, chain []*x509.Certificate, ll *loglist.LogList, hc *http.Client) (int, int) {
func CheckX509(ctx context.Context, lf logInfoFactory, chain []*x509.Certificate, ll *loglist.LogList, hc *http.Client) (valid int, invalid int, logs []loglist.Log) {
leaf := chain[0]
if len(leaf.SCTList.SCTList) == 0 {
return 0, 0
return
}

var issuer *x509.Certificate
@@ -47,56 +47,59 @@ func CheckChain(ctx context.Context, lf logInfoFactory, chain []*x509.Certificat
merkleLeaf, err := ct.MerkleTreeLeafForEmbeddedSCT([]*x509.Certificate{leaf, issuer}, 0)
if err != nil {
fmt.Printf("Failed to build Merkle leaf: %v\n", err)
return 0, len(leaf.SCTList.SCTList)
invalid = len(leaf.SCTList.SCTList)
return
}

var valid, invalid int
for i, sctData := range leaf.SCTList.SCTList {
subject := fmt.Sprintf("embedded SCT[%d]", i)
if checkSCT(ctx, lf, subject, merkleLeaf, &sctData, ll, hc) {
ok, log := checkSCT(ctx, lf, subject, merkleLeaf, &sctData, ll, hc)
logs = append(logs, *log)
if ok {
valid++
} else {
invalid++
}
}
return valid, invalid
return
}

// GetAndCheckSiteChain retrieves and returns the chain of certificates presented
// for an HTTPS site. Along the way it checks any external SCTs that are served
// up on the connection alongside the chain. Returns the chain and counts of
// valid and invalid external SCTs found.
func GetAndCheckSiteChain(ctx context.Context, lf logInfoFactory, target string, ll *loglist.LogList, hc *http.Client) ([]*x509.Certificate, int, int, error) {
// GetSiteSCTs retrieves and returns the x509 chain and TLS SCTs presented
// for an HTTPS site.
func GetSiteSCTs(ctx context.Context, target string, hc *http.Client) (chain []*x509.Certificate, tlsSCTs [][]byte, err error) {
u, err := url.Parse(target)
if err != nil {
return nil, 0, 0, fmt.Errorf("failed to parse URL: %v", err)
err = fmt.Errorf("failed to parse URL: %v", err)
return
}
if u.Scheme != "https" {
return nil, 0, 0, errors.New("non-https URL provided")
err = errors.New("non-https URL provided")
return
}
host := u.Host
if !strings.Contains(host, ":") {
host += ":443"
}

fmt.Printf("Retrieve certificate chain from TLS connection to %q\n", host)
dialer := net.Dialer{Timeout: hc.Timeout}
conn, err := tls.DialWithDialer(&dialer, "tcp", host, &tls.Config{InsecureSkipVerify: true})
if err != nil {
return nil, 0, 0, fmt.Errorf("failed to dial %q: %v", host, err)
err = fmt.Errorf("failed to dial %q: %v", host, err)
return
}
defer conn.Close()

goChain := conn.ConnectionState().PeerCertificates
fmt.Printf("Found chain of length %d\n", len(goChain))

// Convert base crypto/x509.Certificates to our forked x509.Certificate type.
chain := make([]*x509.Certificate, len(goChain))
chain = make([]*x509.Certificate, len(goChain))
var verifiedHostname bool
for i, goCert := range goChain {
cert, err := x509.ParseCertificate(goCert.Raw)
var cert *x509.Certificate
cert, err = x509.ParseCertificate(goCert.Raw)
if err != nil {
return nil, 0, 0, fmt.Errorf("failed to convert Go Certificate [%d]: %v", i, err)
err = fmt.Errorf("failed to convert Go Certificate [%d]: %v", i, err)
return
}

if err := cert.VerifyHostname(u.Host); err == nil {
@@ -107,21 +110,33 @@ func GetAndCheckSiteChain(ctx context.Context, lf logInfoFactory, target string,
}

if verifiedHostname == false {
return nil, 0, 0, errors.New("cannot verify host for target")
err = errors.New("cannot verify host for target")
return
}

// Check externally-provided SCTs.
var valid, invalid int
scts := conn.ConnectionState().SignedCertificateTimestamps
tlsSCTs = conn.ConnectionState().SignedCertificateTimestamps

return
}

// CheckTLS iterates over any TLS extension SCTs presented from a connection
// and checks those SCTs. Returns the counts of valid and invalid
// SCTs found.
func CheckTLS(ctx context.Context, scts [][]byte, chain []*x509.Certificate, lf logInfoFactory, target string, ll *loglist.LogList, hc *http.Client) (valid int, invalid int, logs []loglist.Log) {
if len(scts) > 0 {
var merkleLeaf *ct.MerkleTreeLeaf
merkleLeaf, err := ct.MerkleTreeLeafFromChain(chain, ct.X509LogEntryType, 0 /* timestamp added later */)
if err != nil {
fmt.Printf("Failed to build Merkle tree leaf: %v\n", err)
return chain, 0, len(scts), nil
invalid = len(scts)
return
}
for i, sctData := range scts {
subject := fmt.Sprintf("external SCT[%d]", i)
if checkSCT(ctx, lf, subject, merkleLeaf, &x509.SerializedSCT{Val: sctData}, ll, hc) {
ok, log := checkSCT(ctx, lf, subject, merkleLeaf, &x509.SerializedSCT{Val: sctData}, ll, hc)
logs = append(logs, *log)
if ok {
valid++
} else {
invalid++
@@ -130,54 +145,53 @@ func GetAndCheckSiteChain(ctx context.Context, lf logInfoFactory, target string,
}
}

return chain, valid, invalid, nil
return
}

// checkSCT performs checks on an SCT and Merkle tree leaf, performing both
// signature validation and online log inclusion checking. Returns whether
// the SCT is valid.
func checkSCT(ctx context.Context, liFactory logInfoFactory, subject string, merkleLeaf *ct.MerkleTreeLeaf, sctData *x509.SerializedSCT, ll *loglist.LogList, hc *http.Client) bool {
func checkSCT(ctx context.Context, liFactory logInfoFactory, subject string, merkleLeaf *ct.MerkleTreeLeaf, sctData *x509.SerializedSCT, ll *loglist.LogList, hc *http.Client) (result bool, log *loglist.Log) {
sct, err := x509util.ExtractSCT(sctData)
if err != nil {
fmt.Printf("Failed to deserialize %s data: %v\n", subject, err)
fmt.Printf("Data: %x\n", sctData.Val)
return false
return
}
fmt.Printf("Examine %s with timestamp: %d (%v) from logID: %x\n", subject, sct.Timestamp, ct.TimestampToTime(sct.Timestamp), sct.LogID.KeyID[:])
log := ll.FindLogByKeyHash(sct.LogID.KeyID)

// TODO(philips): add verbose logging
// fmt.Printf("Examine %s with timestamp: %d (%v) from logID: %x\n", subject, sct.Timestamp, ct.TimestampToTime(sct.Timestamp), sct.LogID.KeyID[:])
log = ll.FindLogByKeyHash(sct.LogID.KeyID)
if log == nil {
fmt.Printf("Unknown logID: %x, cannot validate %s\n", sct.LogID, subject)
return false
return
}
logInfo, err := liFactory(log, hc)
if err != nil {
fmt.Printf("Failed to build log info for %q log: %v\n", log.Description, err)
return false
return
}

result := true
fmt.Printf("Validate %s against log %q...", subject, logInfo.Description)
result = true
if err := logInfo.VerifySCTSignature(*sct, *merkleLeaf); err != nil {
fmt.Printf("Failed to verify %s signature from log %q: %v\n", subject, log.Description, err)
result = false
} else {
fmt.Printf("Validate %s against log %q... validated\n", subject, log.Description)
}

fmt.Printf("Check %s inclusion against log %q...\n", subject, log.Description)
index, err := logInfo.VerifyInclusion(ctx, *merkleLeaf, sct.Timestamp)
_, err = logInfo.VerifyInclusion(ctx, *merkleLeaf, sct.Timestamp)
if err != nil {
age := time.Since(ct.TimestampToTime(sct.Timestamp))
if age < logInfo.MMD {
fmt.Printf("Failed to verify inclusion proof (%v) but %s timestamp is only %v old, less than log's MMD of %d seconds\n", err, subject, age, log.MaximumMergeDelay)
// TODO(philips): fix this case.
return true
result = true
return
} else {
fmt.Printf("Failed to verify inclusion proof for %s: %v\n", subject, err)
}
return false
result = false
return
}
fmt.Printf("Check %s inclusion against log %q... included at %d\n", subject, log.Description, index)

return result
return
}

0 comments on commit 1a2ca16

Please sign in to comment.
You can’t perform that action at this time.