Skip to content

Commit

Permalink
net/tcp: Calculate TCP-AO traffic keys
Browse files Browse the repository at this point in the history
Add traffic key calculation the way it's described in RFC5926.
Wire it up to tcp_finish_connect() and cache the new keys straight away
on already established TCP connections.

Co-developed-by: Francesco Ruggeri <fruggeri@arista.com>
Signed-off-by: Francesco Ruggeri <fruggeri@arista.com>
Co-developed-by: Salam Noureddine <noureddine@arista.com>
Signed-off-by: Salam Noureddine <noureddine@arista.com>
Signed-off-by: Dmitry Safonov <dima@arista.com>
  • Loading branch information
0x7f454c46 authored and intel-lab-lkp committed Oct 27, 2022
1 parent 00718ec commit 980990b
Show file tree
Hide file tree
Showing 8 changed files with 272 additions and 1 deletion.
5 changes: 5 additions & 0 deletions include/net/tcp.h
Expand Up @@ -2088,6 +2088,11 @@ struct tcp_sock_af_ops {
struct tcp_ao_key *(*ao_lookup)(const struct sock *sk,
struct sock *addr_sk,
int sndid, int rcvid);
int (*ao_calc_key_sk)(struct tcp_ao_key *mkt,
u8 *key,
const struct sock *sk,
__be32 sisn, __be32 disn,
bool send);
#endif
};

Expand Down
44 changes: 43 additions & 1 deletion include/net/tcp_ao.h
Expand Up @@ -88,23 +88,57 @@ struct tcp_ao_info {
};

#ifdef CONFIG_TCP_AO
/* TCP-AO structures and functions */

struct tcp4_ao_context {
__be32 saddr;
__be32 daddr;
__be16 sport;
__be16 dport;
__be32 sisn;
__be32 disn;
};

struct tcp6_ao_context {
struct in6_addr saddr;
struct in6_addr daddr;
__be16 sport;
__be16 dport;
__be32 sisn;
__be32 disn;
};

int tcp_parse_ao(struct sock *sk, int cmd, unsigned short int family,
sockptr_t optval, int optlen);
int tcp_ao_calc_traffic_key(struct tcp_ao_key *mkt, u8 *key, void *ctx,
unsigned int len);
void tcp_ao_destroy_sock(struct sock *sk);
int tcp_ao_cache_traffic_keys(const struct sock *sk, struct tcp_ao_info *ao,
struct tcp_ao_key *ao_key);
struct tcp_ao_key *tcp_ao_do_lookup(const struct sock *sk,
const union tcp_ao_addr *addr,
int family, int sndid, int rcvid, u16 port);
/* ipv4 specific functions */
int tcp_v4_parse_ao(struct sock *sk, int optname, sockptr_t optval, int optlen);
struct tcp_ao_key *tcp_v4_ao_lookup(const struct sock *sk, struct sock *addr_sk,
int sndid, int rcvid);
int tcp_v4_ao_calc_key_sk(struct tcp_ao_key *mkt, u8 *key,
const struct sock *sk,
__be32 sisn, __be32 disn, bool send);
/* ipv6 specific functions */
int tcp_v6_ao_calc_key_sk(struct tcp_ao_key *mkt, u8 *key,
const struct sock *sk, __be32 sisn,
__be32 disn, bool send);
struct tcp_ao_key *tcp_v6_ao_lookup(const struct sock *sk,
struct sock *addr_sk,
int sndid, int rcvid);
int tcp_v6_parse_ao(struct sock *sk, int cmd,
sockptr_t optval, int optlen);
#else
void tcp_ao_finish_connect(struct sock *sk, struct sk_buff *skb);
void tcp_ao_connect_init(struct sock *sk);

#else /* CONFIG_TCP_AO */

static inline struct tcp_ao_key *tcp_ao_do_lookup(const struct sock *sk,
const union tcp_ao_addr *addr,
int family, int sndid, int rcvid, u16 port)
Expand All @@ -115,6 +149,14 @@ static inline struct tcp_ao_key *tcp_ao_do_lookup(const struct sock *sk,
static inline void tcp_ao_destroy_sock(struct sock *sk)
{
}

static inline void tcp_ao_finish_connect(struct sock *sk, struct sk_buff *skb)
{
}

static inline void tcp_ao_connect_init(struct sock *sk)
{
}
#endif

#endif /* _TCP_AO_H */
180 changes: 180 additions & 0 deletions net/ipv4/tcp_ao.c
Expand Up @@ -16,6 +16,42 @@
#include <net/tcp.h>
#include <net/ipv6.h>

int tcp_ao_calc_traffic_key(struct tcp_ao_key *mkt, u8 *key, void *ctx,
unsigned int len)
{
struct crypto_pool_ahash hp;
struct scatterlist sg;
int ret;

if (crypto_pool_get(mkt->crypto_pool_id, (struct crypto_pool *)&hp))
goto clear_hash_noput;

if (crypto_ahash_setkey(crypto_ahash_reqtfm(hp.req),
mkt->key, mkt->keylen))
goto clear_hash;

ret = crypto_ahash_init(hp.req);
if (ret)
goto clear_hash;

sg_init_one(&sg, ctx, len);
ahash_request_set_crypt(hp.req, &sg, key, len);
crypto_ahash_update(hp.req);

/* TODO: Revisit on how to get different output length */
ret = crypto_ahash_final(hp.req);
if (ret)
goto clear_hash;

crypto_pool_put();
return 0;
clear_hash:
crypto_pool_put();
clear_hash_noput:
memset(key, 0, tcp_ao_digest_size(mkt));
return 1;
}

static struct tcp_ao_key *tcp_ao_do_lookup_rcvid(struct sock *sk, u8 keyid)
{
struct tcp_sock *tp = tcp_sk(sk);
Expand Down Expand Up @@ -181,6 +217,47 @@ void tcp_ao_destroy_sock(struct sock *sk)
kfree_rcu(ao, rcu);
}

/* 4 tuple and ISNs are expected in NBO */
static int tcp_v4_ao_calc_key(struct tcp_ao_key *mkt, u8 *key,
__be32 saddr, __be32 daddr,
__be16 sport, __be16 dport,
__be32 sisn, __be32 disn)
{
/* See RFC5926 3.1.1 */
struct kdf_input_block {
u8 counter;
u8 label[6];
struct tcp4_ao_context ctx;
__be16 outlen;
} __packed tmp;

tmp.counter = 1;
memcpy(tmp.label, "TCP-AO", 6);
tmp.ctx.saddr = saddr;
tmp.ctx.daddr = daddr;
tmp.ctx.sport = sport;
tmp.ctx.dport = dport;
tmp.ctx.sisn = sisn;
tmp.ctx.disn = disn;
tmp.outlen = htons(tcp_ao_digest_size(mkt) * 8); /* in bits */

return tcp_ao_calc_traffic_key(mkt, key, &tmp, sizeof(tmp));
}

int tcp_v4_ao_calc_key_sk(struct tcp_ao_key *mkt, u8 *key,
const struct sock *sk,
__be32 sisn, __be32 disn, bool send)
{
if (send)
return tcp_v4_ao_calc_key(mkt, key, sk->sk_rcv_saddr,
sk->sk_daddr, htons(sk->sk_num),
sk->sk_dport, sisn, disn);
else
return tcp_v4_ao_calc_key(mkt, key, sk->sk_daddr,
sk->sk_rcv_saddr, sk->sk_dport,
htons(sk->sk_num), disn, sisn);
}

struct tcp_ao_key *tcp_v4_ao_lookup(const struct sock *sk, struct sock *addr_sk,
int sndid, int rcvid)
{
Expand All @@ -189,6 +266,103 @@ struct tcp_ao_key *tcp_v4_ao_lookup(const struct sock *sk, struct sock *addr_sk,
return tcp_ao_do_lookup(sk, addr, AF_INET, sndid, rcvid, 0);
}

int tcp_ao_cache_traffic_keys(const struct sock *sk, struct tcp_ao_info *ao,
struct tcp_ao_key *ao_key)
{
u8 *traffic_key = snd_other_key(ao_key);
int ret;

ret = tcp_sk(sk)->af_specific->ao_calc_key_sk(ao_key, traffic_key, sk,
ao->lisn, ao->risn, true);
if (ret)
return ret;

traffic_key = rcv_other_key(ao_key);
return tcp_sk(sk)->af_specific->ao_calc_key_sk(ao_key, traffic_key, sk,
ao->lisn, ao->risn,
false);
}

void tcp_ao_connect_init(struct sock *sk)
{
struct tcp_ao_info *ao_info;
struct tcp_ao_key *key;
struct tcp_sock *tp = tcp_sk(sk);
union tcp_ao_addr *addr;
int family;

ao_info = rcu_dereference_protected(tp->ao_info,
lockdep_sock_is_held(sk));
if (!ao_info)
return;

/* Remove all keys that don't match the peer */
family = sk->sk_family;
if (family == AF_INET)
addr = (union tcp_ao_addr *)&sk->sk_daddr;
#if IS_ENABLED(CONFIG_IPV6)
else if (family == AF_INET6)
addr = (union tcp_ao_addr *)&sk->sk_v6_daddr;
#endif
else
return;

hlist_for_each_entry_rcu(key, &ao_info->head, node) {
if (tcp_ao_key_cmp(key, addr, key->prefixlen, family,
-1, -1, sk->sk_dport) == 0)
continue;

if (key == ao_info->current_key)
ao_info->current_key = NULL;
if (key == ao_info->rnext_key)
ao_info->rnext_key = NULL;
hlist_del_rcu(&key->node);
crypto_pool_release(key->crypto_pool_id);
atomic_sub(tcp_ao_sizeof_key(key), &sk->sk_omem_alloc);
kfree_rcu(key, rcu);
}

key = tp->af_specific->ao_lookup(sk, sk, -1, -1);
if (key) {
/* if current_key or rnext_key were not provided,
* use the first key matching the peer
*/
if (!ao_info->current_key)
ao_info->current_key = key;
if (!ao_info->rnext_key)
ao_info->rnext_key = key;
tp->tcp_header_len += tcp_ao_len(key);

ao_info->lisn = htonl(tp->write_seq);
ao_info->snd_sne = 0;
ao_info->snd_sne_seq = tp->write_seq;
} else {
WARN_ON_ONCE(1);
rcu_assign_pointer(tp->ao_info, NULL);
kfree(ao_info);
}
}

void tcp_ao_finish_connect(struct sock *sk, struct sk_buff *skb)
{
struct tcp_ao_info *ao;
struct tcp_ao_key *key;

ao = rcu_dereference_protected(tcp_sk(sk)->ao_info,
lockdep_sock_is_held(sk));
if (!ao)
return;

ao->risn = tcp_hdr(skb)->seq;

ao->rcv_sne = 0;
ao->rcv_sne_seq = ntohl(tcp_hdr(skb)->seq);

hlist_for_each_entry_rcu(key, &ao->head, node) {
tcp_ao_cache_traffic_keys(sk, ao, key);
}
}

static int tcp_ao_current_rnext(struct sock *sk, u16 tcpa_flags,
u8 tcpa_sndid, u8 tcpa_rcvid)
{
Expand Down Expand Up @@ -682,6 +856,12 @@ static int tcp_ao_add_cmd(struct sock *sk, unsigned short int family,
if (ret < 0)
goto err_free_sock;

/* Change this condition if we allow adding keys in states
* like close_wait, syn_sent or fin_wait...
*/
if (sk->sk_state == TCP_ESTABLISHED)
tcp_ao_cache_traffic_keys(sk, ao_info, key);

tcp_ao_link_mkt(ao_info, key);
if (first) {
sk_gso_disable(sk);
Expand Down
1 change: 1 addition & 0 deletions net/ipv4/tcp_input.c
Expand Up @@ -6052,6 +6052,7 @@ void tcp_finish_connect(struct sock *sk, struct sk_buff *skb)
struct tcp_sock *tp = tcp_sk(sk);
struct inet_connection_sock *icsk = inet_csk(sk);

tcp_ao_finish_connect(sk, skb);
tcp_set_state(sk, TCP_ESTABLISHED);
icsk->icsk_ack.lrcvtime = tcp_jiffies32;

Expand Down
1 change: 1 addition & 0 deletions net/ipv4/tcp_ipv4.c
Expand Up @@ -2275,6 +2275,7 @@ static const struct tcp_sock_af_ops tcp_sock_ipv4_specific = {
#ifdef CONFIG_TCP_AO
.ao_lookup = tcp_v4_ao_lookup,
.ao_parse = tcp_v4_parse_ao,
.ao_calc_key_sk = tcp_v4_ao_calc_key_sk,
#endif
};
#endif
Expand Down
1 change: 1 addition & 0 deletions net/ipv4/tcp_output.c
Expand Up @@ -3666,6 +3666,7 @@ static void tcp_connect_init(struct sock *sk)
if (tp->af_specific->md5_lookup(sk, sk))
tp->tcp_header_len += TCPOLEN_MD5SIG_ALIGNED;
#endif
tcp_ao_connect_init(sk);

/* If user gave his TCP_MAXSEG, record it to clamp */
if (tp->rx_opt.user_mss)
Expand Down
40 changes: 40 additions & 0 deletions net/ipv6/tcp_ao.c
Expand Up @@ -13,6 +13,46 @@
#include <net/tcp.h>
#include <net/ipv6.h>

int tcp_v6_ao_calc_key(struct tcp_ao_key *mkt, u8 *key,
const struct in6_addr *saddr,
const struct in6_addr *daddr,
__be16 sport, __be16 dport,
__be32 sisn, __be32 disn)
{
struct kdf_input_block {
u8 counter;
u8 label[6];
struct tcp6_ao_context ctx;
__be16 outlen;
} __packed tmp;

tmp.counter = 1;
memcpy(tmp.label, "TCP-AO", 6);
tmp.ctx.saddr = *saddr;
tmp.ctx.daddr = *daddr;
tmp.ctx.sport = sport;
tmp.ctx.dport = dport;
tmp.ctx.sisn = sisn;
tmp.ctx.disn = disn;
tmp.outlen = htons(tcp_ao_digest_size(mkt) * 8); /* in bits */

return tcp_ao_calc_traffic_key(mkt, key, &tmp, sizeof(tmp));
}

int tcp_v6_ao_calc_key_sk(struct tcp_ao_key *mkt, u8 *key,
const struct sock *sk, __be32 sisn,
__be32 disn, bool send)
{
if (send)
return tcp_v6_ao_calc_key(mkt, key, &sk->sk_v6_rcv_saddr,
&sk->sk_v6_daddr, htons(sk->sk_num),
sk->sk_dport, sisn, disn);
else
return tcp_v6_ao_calc_key(mkt, key, &sk->sk_v6_daddr,
&sk->sk_v6_rcv_saddr, sk->sk_dport,
htons(sk->sk_num), disn, sisn);
}

struct tcp_ao_key *tcp_v6_ao_do_lookup(const struct sock *sk,
const struct in6_addr *addr,
int sndid, int rcvid)
Expand Down
1 change: 1 addition & 0 deletions net/ipv6/tcp_ipv6.c
Expand Up @@ -1928,6 +1928,7 @@ static const struct tcp_sock_af_ops tcp_sock_ipv6_specific = {
#ifdef CONFIG_TCP_AO
.ao_lookup = tcp_v6_ao_lookup,
.ao_parse = tcp_v6_parse_ao,
.ao_calc_key_sk = tcp_v6_ao_calc_key_sk,
#endif
};
#endif
Expand Down

0 comments on commit 980990b

Please sign in to comment.