Skip to content

Commit

Permalink
tcp: Add l3index to tcp_md5sig_key and md5 functions
Browse files Browse the repository at this point in the history
Add l3index to tcp_md5sig_key to represent the L3 domain of a key, and
add l3index to tcp_md5_do_add and tcp_md5_do_del to fill in the key.

With the key now based on an l3index, add the new parameter to the
lookup functions and consider the l3index when looking for a match.

The l3index comes from the skb when processing ingress packets leveraging
the helpers created for socket lookups, tcp_v4_sdif and inet_iif (and the
v6 variants). When the sdif index is set it means the packet ingressed a
device that is part of an L3 domain and inet_iif points to the VRF device.
For egress, the L3 domain is determined from the socket binding and
sk_bound_dev_if.

Signed-off-by: David Ahern <dsahern@gmail.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
  • Loading branch information
dsahern authored and davem330 committed Jan 2, 2020
1 parent 534322c commit dea53bb
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 51 deletions.
24 changes: 12 additions & 12 deletions include/net/tcp.h
Original file line number Diff line number Diff line change
Expand Up @@ -1532,8 +1532,9 @@ struct tcp_md5sig_key {
struct hlist_node node;
u8 keylen;
u8 family; /* AF_INET or AF_INET6 */
union tcp_md5_addr addr;
u8 prefixlen;
union tcp_md5_addr addr;
int l3index; /* set if key added with L3 scope */
u8 key[TCP_MD5SIG_MAXKEYLEN];
struct rcu_head rcu;
};
Expand Down Expand Up @@ -1577,34 +1578,33 @@ struct tcp_md5sig_pool {
int tcp_v4_md5_hash_skb(char *md5_hash, const struct tcp_md5sig_key *key,
const struct sock *sk, const struct sk_buff *skb);
int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
int family, u8 prefixlen, const u8 *newkey, u8 newkeylen,
gfp_t gfp);
int family, u8 prefixlen, int l3index,
const u8 *newkey, u8 newkeylen, gfp_t gfp);
int tcp_md5_do_del(struct sock *sk, const union tcp_md5_addr *addr,
int family, u8 prefixlen);
int family, u8 prefixlen, int l3index);
struct tcp_md5sig_key *tcp_v4_md5_lookup(const struct sock *sk,
const struct sock *addr_sk);

#ifdef CONFIG_TCP_MD5SIG
#include <linux/jump_label.h>
extern struct static_key_false tcp_md5_needed;
struct tcp_md5sig_key *__tcp_md5_do_lookup(const struct sock *sk,
struct tcp_md5sig_key *__tcp_md5_do_lookup(const struct sock *sk, int l3index,
const union tcp_md5_addr *addr,
int family);
static inline struct tcp_md5sig_key *
tcp_md5_do_lookup(const struct sock *sk,
const union tcp_md5_addr *addr,
int family)
tcp_md5_do_lookup(const struct sock *sk, int l3index,
const union tcp_md5_addr *addr, int family)
{
if (!static_branch_unlikely(&tcp_md5_needed))
return NULL;
return __tcp_md5_do_lookup(sk, addr, family);
return __tcp_md5_do_lookup(sk, l3index, addr, family);
}

#define tcp_twsk_md5_key(twsk) ((twsk)->tw_md5_key)
#else
static inline struct tcp_md5sig_key *tcp_md5_do_lookup(const struct sock *sk,
const union tcp_md5_addr *addr,
int family)
static inline struct tcp_md5sig_key *
tcp_md5_do_lookup(const struct sock *sk, int l3index,
const union tcp_md5_addr *addr, int family)
{
return NULL;
}
Expand Down
68 changes: 48 additions & 20 deletions net/ipv4/tcp_ipv4.c
Original file line number Diff line number Diff line change
Expand Up @@ -702,13 +702,19 @@ static void tcp_v4_send_reset(const struct sock *sk, struct sk_buff *skb)
hash_location = tcp_parse_md5sig_option(th);
if (sk && sk_fullsock(sk)) {
const union tcp_md5_addr *addr;
int l3index;

/* sdif set, means packet ingressed via a device
* in an L3 domain and inet_iif is set to it.
*/
l3index = tcp_v4_sdif(skb) ? inet_iif(skb) : 0;
addr = (union tcp_md5_addr *)&ip_hdr(skb)->saddr;
key = tcp_md5_do_lookup(sk, addr, AF_INET);
key = tcp_md5_do_lookup(sk, l3index, addr, AF_INET);
} else if (hash_location) {
const union tcp_md5_addr *addr;
int sdif = tcp_v4_sdif(skb);
int dif = inet_iif(skb);
int l3index;

/*
* active side is lost. Try to find listening socket through
Expand All @@ -725,8 +731,12 @@ static void tcp_v4_send_reset(const struct sock *sk, struct sk_buff *skb)
if (!sk1)
goto out;

/* sdif set, means packet ingressed via a device
* in an L3 domain and dif is set to it.
*/
l3index = sdif ? dif : 0;
addr = (union tcp_md5_addr *)&ip_hdr(skb)->saddr;
key = tcp_md5_do_lookup(sk1, addr, AF_INET);
key = tcp_md5_do_lookup(sk1, l3index, addr, AF_INET);
if (!key)
goto out;

Expand Down Expand Up @@ -911,6 +921,7 @@ static void tcp_v4_reqsk_send_ack(const struct sock *sk, struct sk_buff *skb,
struct request_sock *req)
{
const union tcp_md5_addr *addr;
int l3index;

/* sk->sk_state == TCP_LISTEN -> for regular TCP_SYN_RECV
* sk->sk_state == TCP_SYN_RECV -> for Fast Open.
Expand All @@ -924,13 +935,14 @@ static void tcp_v4_reqsk_send_ack(const struct sock *sk, struct sk_buff *skb,
* Rcv.Wind.Shift bits:
*/
addr = (union tcp_md5_addr *)&ip_hdr(skb)->saddr;
l3index = tcp_v4_sdif(skb) ? inet_iif(skb) : 0;
tcp_v4_send_ack(sk, skb, seq,
tcp_rsk(req)->rcv_nxt,
req->rsk_rcv_wnd >> inet_rsk(req)->rcv_wscale,
tcp_time_stamp_raw() + tcp_rsk(req)->ts_off,
req->ts_recent,
0,
tcp_md5_do_lookup(sk, addr, AF_INET),
tcp_md5_do_lookup(sk, l3index, addr, AF_INET),
inet_rsk(req)->no_srccheck ? IP_REPLY_ARG_NOSRCCHECK : 0,
ip_hdr(skb)->tos);
}
Expand Down Expand Up @@ -990,7 +1002,7 @@ DEFINE_STATIC_KEY_FALSE(tcp_md5_needed);
EXPORT_SYMBOL(tcp_md5_needed);

/* Find the Key structure for an address. */
struct tcp_md5sig_key *__tcp_md5_do_lookup(const struct sock *sk,
struct tcp_md5sig_key *__tcp_md5_do_lookup(const struct sock *sk, int l3index,
const union tcp_md5_addr *addr,
int family)
{
Expand All @@ -1010,7 +1022,8 @@ struct tcp_md5sig_key *__tcp_md5_do_lookup(const struct sock *sk,
hlist_for_each_entry_rcu(key, &md5sig->head, node) {
if (key->family != family)
continue;

if (key->l3index && key->l3index != l3index)
continue;
if (family == AF_INET) {
mask = inet_make_mask(key->prefixlen);
match = (key->addr.a4.s_addr & mask) ==
Expand All @@ -1034,7 +1047,8 @@ EXPORT_SYMBOL(__tcp_md5_do_lookup);

static struct tcp_md5sig_key *tcp_md5_do_lookup_exact(const struct sock *sk,
const union tcp_md5_addr *addr,
int family, u8 prefixlen)
int family, u8 prefixlen,
int l3index)
{
const struct tcp_sock *tp = tcp_sk(sk);
struct tcp_md5sig_key *key;
Expand All @@ -1053,6 +1067,8 @@ static struct tcp_md5sig_key *tcp_md5_do_lookup_exact(const struct sock *sk,
hlist_for_each_entry_rcu(key, &md5sig->head, node) {
if (key->family != family)
continue;
if (key->l3index && key->l3index != l3index)
continue;
if (!memcmp(&key->addr, addr, size) &&
key->prefixlen == prefixlen)
return key;
Expand All @@ -1064,23 +1080,26 @@ struct tcp_md5sig_key *tcp_v4_md5_lookup(const struct sock *sk,
const struct sock *addr_sk)
{
const union tcp_md5_addr *addr;
int l3index;

l3index = l3mdev_master_ifindex_by_index(sock_net(sk),
addr_sk->sk_bound_dev_if);
addr = (const union tcp_md5_addr *)&addr_sk->sk_daddr;
return tcp_md5_do_lookup(sk, addr, AF_INET);
return tcp_md5_do_lookup(sk, l3index, addr, AF_INET);
}
EXPORT_SYMBOL(tcp_v4_md5_lookup);

/* This can be called on a newly created socket, from other files */
int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
int family, u8 prefixlen, const u8 *newkey, u8 newkeylen,
gfp_t gfp)
int family, u8 prefixlen, int l3index,
const u8 *newkey, u8 newkeylen, gfp_t gfp)
{
/* Add Key to the list */
struct tcp_md5sig_key *key;
struct tcp_sock *tp = tcp_sk(sk);
struct tcp_md5sig_info *md5sig;

key = tcp_md5_do_lookup_exact(sk, addr, family, prefixlen);
key = tcp_md5_do_lookup_exact(sk, addr, family, prefixlen, l3index);
if (key) {
/* Pre-existing entry - just update that one. */
memcpy(key->key, newkey, newkeylen);
Expand Down Expand Up @@ -1112,6 +1131,7 @@ int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
key->keylen = newkeylen;
key->family = family;
key->prefixlen = prefixlen;
key->l3index = l3index;
memcpy(&key->addr, addr,
(family == AF_INET6) ? sizeof(struct in6_addr) :
sizeof(struct in_addr));
Expand All @@ -1121,11 +1141,11 @@ int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
EXPORT_SYMBOL(tcp_md5_do_add);

int tcp_md5_do_del(struct sock *sk, const union tcp_md5_addr *addr, int family,
u8 prefixlen)
u8 prefixlen, int l3index)
{
struct tcp_md5sig_key *key;

key = tcp_md5_do_lookup_exact(sk, addr, family, prefixlen);
key = tcp_md5_do_lookup_exact(sk, addr, family, prefixlen, l3index);
if (!key)
return -ENOENT;
hlist_del_rcu(&key->node);
Expand Down Expand Up @@ -1158,6 +1178,7 @@ static int tcp_v4_parse_md5_keys(struct sock *sk, int optname,
struct sockaddr_in *sin = (struct sockaddr_in *)&cmd.tcpm_addr;
const union tcp_md5_addr *addr;
u8 prefixlen = 32;
int l3index = 0;

if (optlen < sizeof(cmd))
return -EINVAL;
Expand All @@ -1178,12 +1199,12 @@ static int tcp_v4_parse_md5_keys(struct sock *sk, int optname,
addr = (union tcp_md5_addr *)&sin->sin_addr.s_addr;

if (!cmd.tcpm_keylen)
return tcp_md5_do_del(sk, addr, AF_INET, prefixlen);
return tcp_md5_do_del(sk, addr, AF_INET, prefixlen, l3index);

if (cmd.tcpm_keylen > TCP_MD5SIG_MAXKEYLEN)
return -EINVAL;

return tcp_md5_do_add(sk, addr, AF_INET, prefixlen,
return tcp_md5_do_add(sk, addr, AF_INET, prefixlen, l3index,
cmd.tcpm_key, cmd.tcpm_keylen, GFP_KERNEL);
}

Expand Down Expand Up @@ -1311,11 +1332,16 @@ static bool tcp_v4_inbound_md5_hash(const struct sock *sk,
const struct iphdr *iph = ip_hdr(skb);
const struct tcphdr *th = tcp_hdr(skb);
const union tcp_md5_addr *addr;
int genhash;
unsigned char newhash[16];
int genhash, l3index;

/* sdif set, means packet ingressed via a device
* in an L3 domain and dif is set to the l3mdev
*/
l3index = sdif ? dif : 0;

addr = (union tcp_md5_addr *)&iph->saddr;
hash_expected = tcp_md5_do_lookup(sk, addr, AF_INET);
hash_expected = tcp_md5_do_lookup(sk, l3index, addr, AF_INET);
hash_location = tcp_parse_md5sig_option(th);

/* We've parsed the options - do we have a hash? */
Expand All @@ -1341,11 +1367,11 @@ static bool tcp_v4_inbound_md5_hash(const struct sock *sk,

if (genhash || memcmp(hash_location, newhash, 16) != 0) {
NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5FAILURE);
net_info_ratelimited("MD5 Hash failed for (%pI4, %d)->(%pI4, %d)%s\n",
net_info_ratelimited("MD5 Hash failed for (%pI4, %d)->(%pI4, %d)%s L3 index %d\n",
&iph->saddr, ntohs(th->source),
&iph->daddr, ntohs(th->dest),
genhash ? " tcp_v4_calc_md5_hash failed"
: "");
: "", l3index);
return true;
}
return false;
Expand Down Expand Up @@ -1431,6 +1457,7 @@ struct sock *tcp_v4_syn_recv_sock(const struct sock *sk, struct sk_buff *skb,
#ifdef CONFIG_TCP_MD5SIG
const union tcp_md5_addr *addr;
struct tcp_md5sig_key *key;
int l3index;
#endif
struct ip_options_rcu *inet_opt;

Expand Down Expand Up @@ -1478,17 +1505,18 @@ struct sock *tcp_v4_syn_recv_sock(const struct sock *sk, struct sk_buff *skb,
tcp_initialize_rcv_mss(newsk);

#ifdef CONFIG_TCP_MD5SIG
l3index = l3mdev_master_ifindex_by_index(sock_net(sk), ireq->ir_iif);
/* Copy over the MD5 key from the original socket */
addr = (union tcp_md5_addr *)&newinet->inet_daddr;
key = tcp_md5_do_lookup(sk, addr, AF_INET);
key = tcp_md5_do_lookup(sk, l3index, addr, AF_INET);
if (key) {
/*
* We're using one, so create a matching key
* on the newsk structure. If we fail to get
* memory, then we end up not copying the key
* across. Shucks.
*/
tcp_md5_do_add(newsk, addr, AF_INET, 32,
tcp_md5_do_add(newsk, addr, AF_INET, 32, l3index,
key->key, key->keylen, GFP_ATOMIC);
sk_nocaps_add(newsk, NETIF_F_GSO_MASK);
}
Expand Down

0 comments on commit dea53bb

Please sign in to comment.