# DNSSEC-Aware Resolver Downgrade Attacks

In [None]:
import logging
import random
import string
from datetime import datetime
import itertools
import concurrent
import math

import numpy as np
import dns.message, dns.query, dns.rdataclass, dns.rdatatype, dns.flags, dns.exception, dns.name, dns.dnssec
from tqdm import tqdm
import pandas as pd

# DNS class and record-type constants, resolved once through dnspython.
IN = dns.rdataclass.from_text("IN")
NS = dns.rdatatype.from_text("NS")
SOA = dns.rdatatype.from_text("SOA")
DS = dns.rdatatype.from_text("DS")
A = dns.rdatatype.from_text("A")
TXT = dns.rdatatype.from_text("TXT")
AAAA = dns.rdatatype.from_text("AAAA")
RRSIG = dns.rdatatype.from_text("RRSIG")

# DNSSEC algorithms covered by the experiments.
ALGORITHMS = [
    dns.dnssec.RSASHA1,
    dns.dnssec.RSASHA256,
    dns.dnssec.RSASHA512,
    dns.dnssec.ECDSAP256SHA256,
    dns.dnssec.ECDSAP384SHA384,
    dns.dnssec.ED25519,
    dns.dnssec.ED448,
]

# Parent zone under which every test name is built.
ZONE = dns.name.from_text('downgrade.dedyn.io')

# `import concurrent` on its own does not load the `futures` submodule; import
# it explicitly so the attribute access below cannot raise AttributeError when
# no other library happens to have imported it first.
import concurrent.futures

# Shared thread pool (at most 50 workers) that run() submits its jobs to.
executor = concurrent.futures.ThreadPoolExecutor(50)

def query(qname, resolver, cd, rdtype=A):
    """Send one DNS query to `resolver` and return the response message.

    Parameters:
        qname: name to query.
        resolver: resolver address. A value starting with 'https' is sent via
            dns.query.https, a value starting with 'tls' via dns.query.tls
            (with the leading 'tls://' characters stripped), anything else
            via dns.query.udp.
        cd: when truthy, the CD flag is set on the query.
        rdtype: record type to ask for (defaults to A).

    The first attempt uses a 2 second timeout. If it times out, the query is
    retried once with a 5 second timeout; a timeout on the retry propagates
    to the caller as dns.exception.Timeout.
    """
    request = dns.message.make_query(qname, rdtype)
    if cd:
        request.flags |= dns.flags.CD

    # Pick transport function and destination from the address prefix.
    if resolver.startswith('https'):
        send, target = dns.query.https, resolver
    elif resolver.startswith('tls'):
        send, target = dns.query.tls, resolver[len('tls://'):]
    else:
        send, target = dns.query.udp, resolver

    try:
        return send(request, where=target, timeout=2)
    except dns.exception.Timeout:
        # One retry with a more generous timeout.
        return send(request, where=target, timeout=5)
    
def run(f, args_list):
    """Call `f(*args)` for every entry of `args_list` on the shared executor.

    A tqdm progress bar is advanced once per finished call. Return values are
    collected in completion order (not submission order). If a call raised,
    the exception is logged as a warning and `{'status': <exception>}` is
    collected in place of a return value.

    Returns:
        list with one entry per submitted call.
    """
    pending = [executor.submit(f, *args) for args in args_list]
    collected = []
    with tqdm(total=len(pending)) as progress:
        for done in concurrent.futures.as_completed(pending):
            progress.update(1)
            error = done.exception()
            if error:
                logging.warning(f"{error}")
                collected.append({'status': error})
            else:
                collected.append(done.result())
    return collected

## Define Test Zones with Different Combinations of DS and DNSKEY Records

In [None]:
# One row per test zone. For every set of one or two algorithms used as DS,
# every subset of those algorithms (including the empty one) is used as DNSKEY.
# The zone label lists the DS algorithms followed by the DNSKEY algorithms
# that remain, e.g. "ds<a>-ds<b>-dnskey<a>", and is placed below ZONE.
zones = pd.DataFrame([
    {
        'ds': algos,
        'dnskey': tuple(kept),
        'name': dns.name.from_text(
            "-".join(
                [f"ds{a}" for a in sorted(algos)] +
                [f"dnskey{int(a)}" for a in kept]
            ),
            origin=ZONE,
        ),
    }
    for size in (1, 2)
    for algos in itertools.combinations(ALGORITHMS, size)
    # Each mask marks which of the DS algorithms lose their DNSKEY.
    for drop_mask in itertools.product([True, False], repeat=len(algos))
    for kept in [sorted(a for a, dropped in zip(algos, drop_mask) if not dropped)]
]).set_index('name')
zones

## Define Resolvers to be Studied: In the Lab (UDP/TCP) and In the Wild (UDP/TCP, TLS, HTTPS)

In [None]:
# Resolvers are read from CSV files; the 'IPv4' and 'Handle' columns become
# the address and the name, and the group records which file a row came from.
open_resolvers = [
    {'resolver_addr': entry['IPv4'], 'resolver_name': entry['Handle'], 'resolver_group': 'open-named'}
    for _, entry in pd.read_csv("open-resolvers.csv").iterrows()
]
lab_resolvers = [
    {'resolver_addr': entry['IPv4'], 'resolver_name': entry['Handle'], 'resolver_group': 'lab'}
    for _, entry in pd.read_csv("lab-resolvers.csv").iterrows()
]

In [None]:
# DNS-over-HTTPS endpoints, keyed by handle; all filed under 'open-named'.
doh_resolvers = [
    {'resolver_addr': endpoint, 'resolver_name': label, 'resolver_group': 'open-named'}
    for label, endpoint in {
        'cloudflare-doh': 'https://cloudflare-dns.com/dns-query',
        'cloudflare-mozilla-doh': 'https://mozilla.cloudflare-dns.com/dns-query',
        'google-doh': 'https://dns.google/dns-query',
        'quad9-doh': 'https://dns.quad9.net/dns-query',
        #'clean-browsing-doh': 'https://security-filter-dns.cleanbrowsing.org/dns-query',
        'adguard-doh': 'https://dns.adguard.com/dns-query',
        'comcast-doh': 'https://doh.xfinity.com/dns-query',
    }.items()
]

# DNS-over-TLS endpoints; the 'tls://' prefix selects the TLS transport in query().
dot_resolvers = [
    {'resolver_addr': endpoint, 'resolver_name': label, 'resolver_group': 'open-named'}
    for label, endpoint in {
        'cloudflare-dot': 'tls://1.1.1.1',
        'google-dot': 'tls://8.8.8.8',
        'quad9-dot': 'tls://9.9.9.9',
        'clean-browsing-dot': 'tls://185.228.168.9',
        'adguard-dot': 'tls://94.140.14.14',
    }.items()
]

In [None]:
def resolver_transport(row):
    """Classify a resolver row by the prefix of its 'resolver_addr'.

    Returns 'DoT' for addresses starting with 'tls', 'DoH' for addresses
    starting with 'https', and 'UDP/TCP' for everything else.
    """
    address = row['resolver_addr']
    for scheme, label in (('tls', 'DoT'), ('https', 'DoH')):
        if address.startswith(scheme):
            return label
    return 'UDP/TCP'

# Merge all resolver lists into one table and tag every row with its transport.
resolvers = pd.DataFrame([*open_resolvers, *lab_resolvers, *doh_resolvers, *dot_resolvers])
resolvers = resolvers.assign(resolver_transport=resolvers.apply(resolver_transport, axis=1))
resolvers.head()

### Determine Resolver Algorithm Support

In [None]:
def check_resolver(resolver, algorithm):
    """Probe one resolver with a pair of TXT queries for one algorithm.

    Two names below ZONE are queried without the CD flag:
    'mitm-at.ds<alg>-dnskey<alg>' (marked "signature invalid" by the original
    author) first, then 'ds<alg>-dnskey<alg>' ("signature valid").

    Returns:
        dict holding all fields of `resolver`, the 'algorithm', and a
        'status' that is 'ok', 'timeout', or the exception that was raised.
        With status 'ok', 'rcode1' is the rcode of the valid-signature query
        and 'rcode2' the rcode of the invalid-signature query.
    """
    try:
        label = f'ds{algorithm}-dnskey{algorithm}'
        addr = resolver['resolver_addr']
        # order of queries to avoid caching problems?
        tampered = query(dns.name.from_text(f'mitm-at.{label}', origin=ZONE), addr, cd=False, rdtype=TXT)
        genuine = query(dns.name.from_text(label, origin=ZONE), addr, cd=False, rdtype=TXT)
        outcome = {
            'status': 'ok',
            'rcode1': genuine.rcode(),
            'rcode2': tampered.rcode(),
        }
    except dns.exception.Timeout:
        outcome = {'status': 'timeout'}
    except Exception as e:
        outcome = {'status': e}
    return {**resolver, 'algorithm': algorithm, **outcome}

In [None]:
# Probe every resolver with every algorithm and tabulate the outcomes.
resolver_support_results = pd.DataFrame(
    run(
        check_resolver,
        [(resolver, a) for (_, resolver), a in itertools.product(resolvers.iterrows(), ALGORITHMS)],
    )
)

In [None]:
def support(row):
    """Turn the two rcodes of a probe row into a support verdict.

    'rcode1' belongs to the query with a valid signature, 'rcode2' to the
    query with an invalid one.

    Returns:
        True  - valid answer accepted (NOERROR), invalid one rejected (SERVFAIL).
        False - both answers accepted (NOERROR).
        None  - the probe did not finish with status 'ok', or the rcode
                combination is anything else (a warning is logged).
    """
    if row['status'] != 'ok':
        return None

    noerror = dns.rcode.Rcode.NOERROR
    servfail = dns.rcode.Rcode.SERVFAIL
    valid, invalid = row['rcode1'], row['rcode2']

    if valid == noerror:
        if invalid == noerror:
            return False
        if invalid == servfail:
            return True
    elif valid == servfail and invalid in (noerror, servfail):
        # The correctly signed name failed, so no conclusion can be drawn.
        logging.warning(f'Weird resolver behavior for {row["resolver_name"]}')
        return None

    # Any other rcode combination.
    logging.warning(f'Resolver {row["resolver_name"]} returned {row["rcode1"]} and {row["rcode2"]}')
    return None
    
# Derive a per-row support verdict (True / False / None) from the two rcodes.
resolver_support_results['supported'] = resolver_support_results.apply(support, axis=1)

In [None]:
def uncertain_any(s):
    """Three-valued `any`: None if any entry is None, otherwise `any(s)`.

    Membership is tested on a list copy because, per the original author's
    note, `None in s` is always false for a pandas Series.
    """
    has_unknown = None in list(s)
    return None if has_unknown else any(s)
    
# Collapse the probe verdicts into one value per (resolver, algorithm) using
# uncertain_any, then pivot so that each resolver is one row with one column
# per algorithm.
resolver_support = resolver_support_results.groupby(['resolver_addr', 'algorithm'], dropna=False)[['supported']].agg({
    'supported': [uncertain_any]
}).reset_index().pivot(index='resolver_addr', columns='algorithm', values=('supported', 'uncertain_any')).reset_index()
# NOTE(review): the renaming assumes the pivot produced exactly one column per
# entry of ALGORITHMS, in the same order as that list - confirm.
resolver_support.columns = ['resolver_addr'] + [f'supports_{a}' for a in ALGORITHMS]
# Tuple of the algorithms whose verdict is exactly True (False and None are left out).
resolver_support['support'] = resolver_support.apply(lambda row: tuple(a for a in ALGORITHMS if row[f'supports_{a}'] is True), axis=1)
# Re-index the resolver table by address and attach the support columns.
resolvers = resolvers.set_index('resolver_addr').join(resolver_support.set_index('resolver_addr'))

In [None]:
def row_style(row):
    """Return one CSS string per cell: green for True, red for False.

    Cells holding any other value map to None.
    """
    palette = {True: 'color: green;', False: 'color: red;'}
    return list(map(palette.get, row))
    
# Render the resolver table row by row, colouring True cells green and False cells red.
resolvers.style.apply(row_style, axis=1)

## Define Attack Strategies

In [None]:
# Each attack is a tuple of instruction tokens. Joined with '-' and prefixed
# with "mitm-", they form the label that check_attack() puts in front of a
# zone name. NOTE(review): the tokens are presumably interpreted by the
# test infrastructure serving the zone - confirm against that setup.
attacks = pd.DataFrame([
    {
        'name': 'replace signature number with ed448 and fake content',
        'instructions': ('rs16', 'at'),
    },
    {
        'name': 'replace signature number with ed25519 and fake content',
        'instructions': ('rs15', 'at'),
    },
    {
        'name': 'remove all signatures except ed448 and fake content',
        'instructions': ('at', *(f'ds{a}' for a in ALGORITHMS if a < dns.dnssec.ED448)),
    },
    {
        'name': 'remove all signatures except ed25519 and ed448 and fake content',
        'instructions': ('at', *(f'ds{a}' for a in ALGORITHMS if a < dns.dnssec.ED25519)),
    },
])
attacks['prefix'] = [f"mitm-{'-'.join(steps)}" for steps in attacks['instructions']]
attacks = attacks.set_index('prefix')
attacks

## Run Attack Evaluation

In [None]:
def check_attack(addr, prefix, zone):
    """Query `<prefix>.<zone>` (TXT, CD flag not set) at one resolver.

    Returns:
        dict with 'resolver_addr', 'zone', 'attack' and a 'status' that is
        'ok', 'timeout', or the exception that was raised (which is also
        logged as a warning). With status 'ok' the dict additionally holds
        the response 'rcode' and 'evil_content', which is True when the text
        form of the response contains the substring 'evil'.
    """
    record = {'resolver_addr': addr, 'zone': zone, 'attack': prefix}
    try:
        response = query(dns.name.from_text(prefix, origin=zone), addr, cd=False, rdtype=TXT)
        # All three fields are computed before any is stored, so a failure
        # here leaves the record without a partial 'ok' result.
        record.update(
            status='ok',
            rcode=response.rcode(),
            evil_content='evil' in str(response),
        )
    except dns.exception.Timeout:
        record['status'] = 'timeout'
    except Exception as e:
        logging.warning(f"Exception: {type(e).__name__}: {e}")
        record['status'] = e
    return record

In [None]:
# Fan every (resolver, attack, zone) combination out through the thread pool.
# NOTE(review): the trailing [:1000] caps the run at the first 1000
# combinations; presumably a leftover from a trial run - confirm before
# relying on the aggregated results.
attack_results = run(check_attack, [(addr, prefix, zone) for addr, _ in resolvers.iterrows() for prefix, _ in attacks.iterrows() for zone, _ in zones.iterrows()][:1000])
attack_results = pd.DataFrame(attack_results)

In [None]:
# Rows whose response text contained 'evil'; rows without an 'evil_content'
# value compare unequal to True and are left out.
attack_results[attack_results['evil_content'] == True]

In [None]:
# Attach resolver, zone and attack metadata to every result row.
results = attack_results.join(resolvers, on='resolver_addr').join(zones, on='zone').join(attacks, on='attack')

In [None]:
# The joins must neither add nor drop rows.
assert len(attack_results) == len(results)

In [None]:
# Algorithms of the zone's DS / DNSKEY sets that the resolver was found to support.
results['supported_ds'] = results.apply(lambda row: tuple(set(row['ds']) & set(row['support'])), axis=1)
results['supported_dnskey'] = results.apply(lambda row: tuple(set(row['dnskey']) & set(row['support'])), axis=1)

In [None]:
def behavior_correct(row):
    """Judge whether the resolver reacted correctly to a tampered answer.

    Returns:
        True  - the resolver answered SERVFAIL (desired for invalid
                signatures), or it answered NOERROR while supporting none of
                the zone's DS algorithms or none of its DNSKEY algorithms.
        False - NOERROR although at least one DS and one DNSKEY algorithm of
                the zone are supported.
        None  - status is not 'ok', or the rcode is neither of the two.
    """
    if row['status'] != 'ok':
        return None

    rcode = row['rcode']
    if rcode == dns.rcode.Rcode.SERVFAIL:
        # desired behavior for invalid signatures
        return True
    if rcode == dns.rcode.Rcode.NOERROR:
        # signature invalid, but no error reported may only happen if no ds or no dnskey was supported
        return not (row['supported_ds'] and row['supported_dnskey'])
    return None

# Judge every result row.
results['behavior_correct'] = results.apply(behavior_correct, axis=1)

In [None]:
# First label of the zone name, decoded to text.
results['zone_prefix'] = results.apply(lambda row: row['zone'][0].decode(), axis=1)
# Readable summary of the zone's DS and DNSKEY algorithm numbers.
results['zone_config'] = results.apply(lambda row: f"DS: {','.join(str(int(e)) for e in row['ds'])} DNSKEY: {','.join(str(int(e)) for e in row['dnskey'])}", axis=1)

## Show Successful Attacks

In [None]:
# Let pandas display up to one row per (resolver, attack) pair.
pd.options.display.max_rows = len(resolvers) * len(attacks)

def values(s):
    """Join the strings in `s` into one string, separated by '; '."""
    separator = '; '
    return separator.join(s)

def zone_proportion(s):
    """Return the length of `s` as a fraction of the number of test zones."""
    covered = len(s)
    return covered / len(zones)

# Per (attack, resolver) pair: number of result rows and the mean of
# behavior_correct. reset_index() turns the group keys back into columns,
# which leaves the frame with two-level column labels.
attack_success_rate = results.groupby(['attack', 'name', 'resolver_group', 'resolver_name', 'resolver_addr'], dropna=False).agg({
    'behavior_correct': [len, 'mean']
}).reset_index() #.style.apply(lambda row: ['', 'color: red' if row['behavior_correct']['mean'] < 1 else ''], axis=1)
# Show only the pairs whose mean is below 1, i.e. at least one zone was judged incorrect.
attack_success_rate[attack_success_rate[('behavior_correct', 'mean')] < 1]

## Generate Attack Reports

In [None]:
# Text form of each zone name; the report loop matches substrings against it.
results['zone_name'] = results.apply(lambda row: row['zone'].to_text(), axis=1)

In [None]:
# List the available columns.
results.keys()

In [None]:
# Print one report per (attack, resolver) pair whose mean correctness is below
# 1: for every algorithm, the share of zones carrying that algorithm as DS
# (and, separately, as DNSKEY) on which the resolver misbehaved.
for _, row in attack_success_rate[attack_success_rate[('behavior_correct', 'mean')] < 1].iterrows():
    # The frame has two-level column labels, so row["attack"] is a one-element
    # Series labelled by the second level. Use .iloc for positional access:
    # plain [0] on a non-integer index is deprecated in pandas 2.x.
    attack = row["attack"].iloc[0]
    attack_name = row["name"].iloc[0]
    resolver_addr = row["resolver_addr"].iloc[0]
    resolver_name = row["resolver_name"].iloc[0]
    details = results[(results['attack'] == attack) & (results['resolver_addr'] == resolver_addr)]
    headline = f'## Attack: "{attack_name}"/{attack} at {resolver_name}/{resolver_addr}'
    print(headline)
    print('-' * (len(headline) + 2))
    for a in ALGORITHMS:
        more_details = details[details['zone_name'].str.contains(f'ds{a}')]
        success_rate = 1-more_details['behavior_correct'].mean()
        # The mean is NaN when no zone matched or no verdict is available;
        # math.ceil(NaN) would raise ValueError, so print an empty bar instead.
        bar = '' if math.isnan(success_rate) else '+' * math.ceil(success_rate*10)
        affected_names = ', '.join(more_details[more_details['behavior_correct'] == False]['zone_prefix'])
        print(f"    Success if DS     with alg {a:2n} present: {success_rate:6.1%} {bar:10s} {affected_names}")
    for a in ALGORITHMS:
        success_rate = 1-details[details['zone_name'].str.contains(f'dnskey{a}')]['behavior_correct'].mean()
        bar = '' if math.isnan(success_rate) else '+' * math.ceil(success_rate*10)
        print(f"    Success if DNSKEY with alg {a:2n} present: {success_rate:6.1%} {bar}")
    print()