Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix checks if a client host is allowed in case it's the localhost. #8342

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
126 changes: 74 additions & 52 deletions dbms/src/Access/AllowedClientHosts.cpp
Expand Up @@ -9,6 +9,7 @@
#include <ext/scope_guard.h>
#include <boost/range/algorithm/find.hpp>
#include <boost/range/algorithm/find_first_of.hpp>
#include <boost/algorithm/string/predicate.hpp>
#include <ifaddrs.h>


Expand All @@ -23,29 +24,64 @@ namespace ErrorCodes
namespace
{
using IPAddress = Poco::Net::IPAddress;
using IPSubnet = AllowedClientHosts::IPSubnet;
const IPSubnet ALL_ADDRESSES{IPAddress{IPAddress::IPv6}, IPAddress{IPAddress::IPv6}};

const AllowedClientHosts::IPSubnet ALL_ADDRESSES = AllowedClientHosts::IPSubnet{IPAddress{IPAddress::IPv6}, IPAddress{IPAddress::IPv6}};
const IPAddress & getIPV6Loopback()
{
static const IPAddress ip("::1");
return ip;
}

bool isIPV4LoopbackMappedToIPV6(const IPAddress & ip)
{
static const IPAddress prefix("::ffff:127.0.0.0");
/// 104 == 128 - 24, we have to reset the lowest 24 bits of 128 before comparing with `prefix`
/// (IPv4 loopback means any IP from 127.0.0.0 to 127.255.255.255).
return (ip & IPAddress(104, IPAddress::IPv6)) == prefix;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

104?

Copy link
Member Author

@vitlibar vitlibar Dec 22, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have to reset the lowest 24 bits in ::ffff:127.x.y.z to be able to check the first 128-24=104 bits. I've added a comment to the code.

}

IPAddress toIPv6(const IPAddress & addr)
/// Converts an address to IPv6.
/// The loopback address "127.0.0.1" (or any "127.x.y.z") is converted to "::1".
IPAddress toIPv6(const IPAddress & ip)
{
if (addr.family() == IPAddress::IPv6)
return addr;
IPAddress v6;
if (ip.family() == IPAddress::IPv6)
v6 = ip;
else
v6 = IPAddress("::ffff:" + ip.toString());

if (addr.isLoopback())
return IPAddress("::1");
// ::ffff:127.XX.XX.XX -> ::1
if (isIPV4LoopbackMappedToIPV6(v6))
v6 = getIPV6Loopback();

return IPAddress("::FFFF:" + addr.toString());
return v6;
}

IPAddress maskToIPv6(const IPAddress & mask)
/// Converts a subnet to IPv6.
IPSubnet toIPv6(const IPSubnet & subnet)
{
if (mask.family() == IPAddress::IPv6)
return mask;
IPSubnet v6;
if (subnet.prefix.family() == IPAddress::IPv6)
v6.prefix = subnet.prefix;
else
v6.prefix = IPAddress("::ffff:" + subnet.prefix.toString());

return IPAddress(96, IPAddress::IPv6) | toIPv6(mask);
}
if (subnet.mask.family() == IPAddress::IPv6)
v6.mask = subnet.mask;
else
v6.mask = IPAddress(96, IPAddress::IPv6) | IPAddress("::ffff:" + subnet.mask.toString());

v6.prefix = v6.prefix & v6.mask;

// ::ffff:127.XX.XX.XX -> ::1
if (isIPV4LoopbackMappedToIPV6(v6.prefix))
v6 = {getIPV6Loopback(), IPAddress(128, IPAddress::IPv6)};

return v6;
}

/// Helper function for isAddressOfHost().
bool isAddressOfHostImpl(const IPAddress & address, const String & host)
{
IPAddress addr_v6 = toIPv6(address);
Expand Down Expand Up @@ -93,15 +129,15 @@ namespace
return false;
}


/// Cached version of isAddressOfHostImpl(). We need to cache DNS requests.
/// Whether a specified address is one of the addresses of a specified host.
bool isAddressOfHost(const IPAddress & address, const String & host)
{
/// We need to cache DNS requests.
static SimpleCache<decltype(isAddressOfHostImpl), isAddressOfHostImpl> cache;
return cache(address, host);
}


/// Helper function for isAddressOfLocalhost().
std::vector<IPAddress> getAddressesOfLocalhostImpl()
{
std::vector<IPAddress> addresses;
Expand All @@ -114,7 +150,7 @@ namespace

int err = getifaddrs(&ifa_begin);
if (err)
return {IPAddress{"127.0.0.1"}, IPAddress{"::1"}};
return {getIPV6Loopback()};

for (const ifaddrs * ifa = ifa_begin; ifa; ifa = ifa->ifa_next)
{
Expand All @@ -134,15 +170,15 @@ namespace
return addresses;
}


/// Checks if a specified address pointers to the localhost.
bool isLocalAddress(const IPAddress & address)
/// Whether a specified address is one of the addresses of the localhost.
bool isAddressOfLocalhost(const IPAddress & address)
{
/// We need to cache DNS requests.
static const std::vector<IPAddress> local_addresses = getAddressesOfLocalhostImpl();
return boost::range::find(local_addresses, address) != local_addresses.end();
return boost::range::find(local_addresses, toIPv6(address)) != local_addresses.end();
}


/// Helper function for getHostByAddress().
String getHostByAddressImpl(const IPAddress & address)
{
Poco::Net::SocketAddress sock_addr(address, 0);
Expand All @@ -160,10 +196,10 @@ namespace
return host;
}


/// Cached version of getHostByAddressImpl(). We need to cache DNS requests.
/// Returns the host name by its address.
String getHostByAddress(const IPAddress & address)
{
/// We need to cache DNS requests.
static SimpleCache<decltype(getHostByAddressImpl), &getHostByAddressImpl> cache;
return cache(address);
}
Expand Down Expand Up @@ -203,7 +239,7 @@ AllowedClientHosts::AllowedClientHosts(const AllowedClientHosts & src)
AllowedClientHosts & AllowedClientHosts::operator =(const AllowedClientHosts & src)
{
addresses = src.addresses;
loopback = src.loopback;
localhost = src.localhost;
subnets = src.subnets;
host_names = src.host_names;
host_regexps = src.host_regexps;
Expand All @@ -212,28 +248,14 @@ AllowedClientHosts & AllowedClientHosts::operator =(const AllowedClientHosts & s
}


AllowedClientHosts::AllowedClientHosts(AllowedClientHosts && src)
{
*this = src;
}


AllowedClientHosts & AllowedClientHosts::operator =(AllowedClientHosts && src)
{
addresses = std::move(src.addresses);
loopback = src.loopback;
subnets = std::move(src.subnets);
host_names = std::move(src.host_names);
host_regexps = std::move(src.host_regexps);
compiled_host_regexps = std::move(src.compiled_host_regexps);
return *this;
}
AllowedClientHosts::AllowedClientHosts(AllowedClientHosts && src) = default;
AllowedClientHosts & AllowedClientHosts::operator =(AllowedClientHosts && src) = default;


void AllowedClientHosts::clear()
{
addresses.clear();
loopback = false;
localhost = false;
subnets.clear();
host_names.clear();
host_regexps.clear();
Expand All @@ -250,10 +272,11 @@ bool AllowedClientHosts::empty() const
void AllowedClientHosts::addAddress(const IPAddress & address)
{
IPAddress addr_v6 = toIPv6(address);
if (boost::range::find(addresses, addr_v6) == addresses.end())
addresses.push_back(addr_v6);
if (boost::range::find(addresses, addr_v6) != addresses.end())
return;
addresses.push_back(addr_v6);
if (addr_v6.isLoopback())
loopback = true;
localhost = true;
}


Expand All @@ -265,18 +288,14 @@ void AllowedClientHosts::addAddress(const String & address)

void AllowedClientHosts::addSubnet(const IPSubnet & subnet)
{
IPSubnet subnet_v6;
subnet_v6.prefix = toIPv6(subnet.prefix);
subnet_v6.mask = maskToIPv6(subnet.mask);
IPSubnet subnet_v6 = toIPv6(subnet);

if (subnet_v6.mask == IPAddress(128, IPAddress::IPv6))
{
addAddress(subnet_v6.prefix);
return;
}

subnet_v6.prefix = subnet_v6.prefix & subnet_v6.mask;

if (boost::range::find(subnets, subnet_v6) == subnets.end())
subnets.push_back(subnet_v6);
}
Expand Down Expand Up @@ -314,8 +333,11 @@ void AllowedClientHosts::addSubnet(const String & subnet)

void AllowedClientHosts::addHostName(const String & host_name)
{
if (boost::range::find(host_names, host_name) == host_names.end())
host_names.push_back(host_name);
if (boost::range::find(host_names, host_name) != host_names.end())
return;
host_names.push_back(host_name);
if (boost::iequals(host_name, "localhost"))
localhost = true;
}


Expand Down Expand Up @@ -360,7 +382,7 @@ bool AllowedClientHosts::contains(const IPAddress & address) const
if (boost::range::find(addresses, addr_v6) != addresses.end())
return true;

if (loopback && isLocalAddress(addr_v6))
if (localhost && isAddressOfLocalhost(addr_v6))
return true;

/// Check `ip_subnets`.
Expand Down
2 changes: 1 addition & 1 deletion dbms/src/Access/AllowedClientHosts.h
Expand Up @@ -94,7 +94,7 @@ class AllowedClientHosts
void compileRegexps() const;

std::vector<IPAddress> addresses;
bool loopback = false;
bool localhost = false;
std::vector<IPSubnet> subnets;
std::vector<String> host_names;
std::vector<String> host_regexps;
Expand Down