Skip to content

Commit

Permalink
fix: format TXT records based on specs
Browse files Browse the repository at this point in the history
- Chunk strings longer than 255
- Quote TXT/SPF strings and escape special characters

Fixes: #21
Fixes: #20
Fixes: caddy-dns/route53#29
Reference: https://docs.aws.amazon.com/Route53/latest/DeveloperGuide/ResourceRecordTypes.html#TXTFormat
  • Loading branch information
aymanbagabas committed Apr 14, 2023
1 parent b7898f7 commit 98401c8
Show file tree
Hide file tree
Showing 3 changed files with 521 additions and 59 deletions.
176 changes: 117 additions & 59 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"errors"
"fmt"
"log"
"strconv"
"strings"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
Expand Down Expand Up @@ -55,6 +55,99 @@ func (p *Provider) init(ctx context.Context) {
p.client = r53.NewFromConfig(cfg)
}

func chunkString(s string, chunkSize int) []string {
var chunks []string
for i := 0; i < len(s); i += chunkSize {
end := i + chunkSize
if end > len(s) {
end = len(s)
}
chunks = append(chunks, s[i:end])
}
return chunks
}

func parseRecordSet(set types.ResourceRecordSet) []libdns.Record {
records := make([]libdns.Record, 0)

// Route53 returns TXT & SPF records with quotes around them.
// https://docs.aws.amazon.com/Route53/latest/DeveloperGuide/ResourceRecordTypes.html#TXTFormat
var ttl int64
if set.TTL != nil {
ttl = *set.TTL
}

rtype := string(set.Type)
for _, record := range set.ResourceRecords {
value := *record.Value
switch rtype {
case "TXT", "SPF":
rows := strings.Split(value, "\n")
for i, row := range rows {
parts := strings.Split(row, `" "`)
if len(parts) > 0 {
parts[0] = strings.TrimPrefix(parts[0], `"`)
parts[len(parts)-1] = strings.TrimSuffix(parts[len(parts)-1], `"`)
}

// Join parts
row = strings.Join(parts, "")
row = unquote(row)
rows[i] = row

records = append(records, libdns.Record{
Name: *set.Name,
Value: row,
Type: rtype,
TTL: time.Duration(ttl) * time.Second,
})
}
default:
records = append(records, libdns.Record{
Name: *set.Name,
Value: value,
Type: rtype,
TTL: time.Duration(ttl) * time.Second,
})
}

}

return records
}

func marshalRecord(record libdns.Record) []types.ResourceRecord {
resourceRecords := make([]types.ResourceRecord, 0)

// Route53 requires TXT & SPF records to be quoted.
// https://docs.aws.amazon.com/Route53/latest/DeveloperGuide/ResourceRecordTypes.html#TXTFormat
switch record.Type {
case "TXT", "SPF":
strs := make([]string, 0)
if len(record.Value) > 255 {
strs = append(strs, chunkString(record.Value, 255)...)
} else {
strs = append(strs, record.Value)
}

// Quote strings
for i, str := range strs {
strs[i] = quote(str)
}

// Finally join chunks with spaces
resourceRecords = append(resourceRecords, types.ResourceRecord{
Value: aws.String(strings.Join(strs, " ")),
})
default:
resourceRecords = append(resourceRecords, types.ResourceRecord{
Value: aws.String(record.Value),
})
}

return resourceRecords
}

func (p *Provider) getRecords(ctx context.Context, zoneID string, zone string) ([]libdns.Record, error) {
getRecordsInput := &r53.ListResourceRecordSetsInput{
HostedZoneId: aws.String(zoneID),
Expand All @@ -79,6 +172,10 @@ func (p *Provider) getRecords(ctx context.Context, zoneID string, zone string) (
}

recordSets = append(recordSets, getRecordResult.ResourceRecordSets...)
for _, s := range recordSets {
records = append(records, parseRecordSet(s)...)
}

if getRecordResult.IsTruncated {
getRecordsInput.StartRecordName = getRecordResult.NextRecordName
getRecordsInput.StartRecordType = getRecordResult.NextRecordType
Expand All @@ -88,31 +185,6 @@ func (p *Provider) getRecords(ctx context.Context, zoneID string, zone string) (
}
}

for _, rrset := range recordSets {
for _, rrsetRecord := range rrset.ResourceRecords {
rtype := rrset.Type
value := *rrsetRecord.Value
// Route53 returns TXT & SPF records with quotes around them.
// https://docs.aws.amazon.com/Route53/latest/DeveloperGuide/ResourceRecordTypes.html#TXTFormat
switch rtype {
case types.RRTypeTxt, types.RRTypeSpf:
var err error
value, err = strconv.Unquote(value)
if err != nil {
return records, fmt.Errorf("Error unquoting TXT/SPF record: %s", err)
}
}
record := libdns.Record{
Name: *rrset.Name,
Value: value,
Type: string(rtype),
TTL: time.Duration(*rrset.TTL) * time.Second,
}

records = append(records, record)
}
}

return records, nil
}

Expand Down Expand Up @@ -170,24 +242,19 @@ func (p *Provider) createRecord(ctx context.Context, zoneID string, record libdn
switch record.Type {
case "TXT":
return p.updateRecord(ctx, zoneID, record, zone)
case "SPF":
record.Value = strconv.Quote(record.Value)
}

resourceRecords := marshalRecord(record)
createInput := &r53.ChangeResourceRecordSetsInput{
ChangeBatch: &types.ChangeBatch{
Changes: []types.Change{
{
Action: types.ChangeActionCreate,
ResourceRecordSet: &types.ResourceRecordSet{
Name: aws.String(libdns.AbsoluteName(record.Name, zone)),
ResourceRecords: []types.ResourceRecord{
{
Value: aws.String(record.Value),
},
},
TTL: aws.Int64(int64(record.TTL.Seconds())),
Type: types.RRType(record.Type),
Name: aws.String(libdns.AbsoluteName(record.Name, zone)),
ResourceRecords: resourceRecords,
TTL: aws.Int64(int64(record.TTL.Seconds())),
Type: types.RRType(record.Type),
},
},
},
Expand All @@ -206,26 +273,19 @@ func (p *Provider) createRecord(ctx context.Context, zoneID string, record libdn
func (p *Provider) updateRecord(ctx context.Context, zoneID string, record libdns.Record, zone string) (libdns.Record, error) {
resourceRecords := make([]types.ResourceRecord, 0)
// AWS Route53 TXT record value must be enclosed in quotation marks on update
switch record.Type {
case "SPF", "TXT":
resourceRecords = append(resourceRecords, types.ResourceRecord{
Value: aws.String(strconv.Quote(record.Value)),
})
}
if record.Type == "TXT" {
txtRecords, err := p.getTxtRecordsFor(ctx, zoneID, zone, record.Name)
if err != nil {
return record, err
}
for _, r := range txtRecords {
if record.Value != r.Value {
resourceRecords = append(resourceRecords, types.ResourceRecord{
Value: aws.String(strconv.Quote(r.Value)),
})
resourceRecords = append(resourceRecords, marshalRecord(r)...)
}
}
}

resourceRecords = append(resourceRecords, marshalRecord(record)...)
updateInput := &r53.ChangeResourceRecordSetsInput{
ChangeBatch: &types.ChangeBatch{
Changes: []types.Change{
Expand Down Expand Up @@ -255,28 +315,24 @@ func (p *Provider) deleteRecord(ctx context.Context, zoneID string, record libdn
action := types.ChangeActionDelete
resourceRecords := make([]types.ResourceRecord, 0)
// AWS Route53 TXT record value must be enclosed in quotation marks on update
switch record.Type {
case "SPF", "TXT":
resourceRecords = append(resourceRecords, types.ResourceRecord{
Value: aws.String(strconv.Quote(record.Value)),
})
}
if record.Type == "TXT" {
txtRecords, err := p.getTxtRecordsFor(ctx, zoneID, zone, record.Name)
if err != nil {
return record, err
}

switch {
case len(txtRecords) > 0 && txtRecords[0].Value != record.Value,
len(txtRecords) > 1:
// If there is only one record, we can delete the entire record set.
case len(txtRecords) == 1:
resourceRecords = append(resourceRecords, marshalRecord(record)...)
// If there are multiple records, we need to upsert the remaining records.
case len(txtRecords) > 1:
action = types.ChangeActionUpsert
resourceRecords = make([]types.ResourceRecord, 0)
}
for _, r := range txtRecords {
if record.Value != r.Value {
resourceRecords = append(resourceRecords, types.ResourceRecord{
Value: aws.String(strconv.Quote(r.Value)),
})
for _, r := range txtRecords {
if record.Value != r.Value {
resourceRecords = append(resourceRecords, marshalRecord(r)...)
}
}
}
}
Expand All @@ -298,6 +354,8 @@ func (p *Provider) deleteRecord(ctx context.Context, zoneID string, record libdn
HostedZoneId: aws.String(zoneID),
}

log.Printf("deleteInput: %+v", deleteInput)

err := p.applyChange(ctx, deleteInput)
if err != nil {
var nfe *types.InvalidChangeBatch
Expand Down
Loading

0 comments on commit 98401c8

Please sign in to comment.