Skip to content
Permalink
Browse files

Bug 1041584 - move balrog whitelist checks out of AUSDatabase (#121).…

… r=bhearsum
  • Loading branch information...
nurav authored and mozbhearsum committed Sep 15, 2016
1 parent 9fdf16f commit 53f48d4519df5fc709cbd6b28097af2d57ebef37
@@ -16,6 +16,8 @@ def isSpecialURL(url, specialForceHosts):


def isForbiddenUrl(url, product, whitelistedDomains):
if whitelistedDomains is None:
whitelistedDomains = []
domain = urlparse(url)[1]
if domain not in whitelistedDomains:
logging.warning("Forbidden domain: %s", domain)
@@ -173,6 +173,39 @@ def shouldServeUpdate(self, updateQuery):

return True

def containsForbiddenDomain(self, product, whitelistedDomains):
"""Returns True if the blob contains any file URLs that contain a
domain that we're not allowed to serve updates to."""
# Check the top level URLs, if the exist.
for c in self.get('fileUrls', {}).values():
# New-style
if isinstance(c, dict):
for from_ in c.values():
for url in from_.values():
if isForbiddenUrl(url, product, whitelistedDomains):
return True
# Old-style
else:
if isForbiddenUrl(c, product, whitelistedDomains):
return True

# And also the locale-level URLs.
for platform in self.get('platforms', {}).values():
for locale in platform.get('locales', {}).values():
for type_ in ('partial', 'complete'):
if type_ in locale and 'fileUrl' in locale[type_]:
if isForbiddenUrl(locale[type_]['fileUrl'], product,
whitelistedDomains):
return True
for type_ in ('partials', 'completes'):
for update in locale.get(type_, {}):
if 'fileUrl' in update:
if isForbiddenUrl(update["fileUrl"], product,
whitelistedDomains):
return True

return False


class SeparatedFileUrlsMixin(object):

@@ -706,3 +739,8 @@ def getInnerXML(self, updateQuery, update_type, whitelistedDomains, specialForce

def getFooterXML(self, updateQuery, update_type, whitelistedDomains, specialForceHosts):
return '</update>'

def containsForbiddenDomain(self, product, whitelistedDomains):
if isForbiddenUrl(self.get('detailsUrl', None), product, whitelistedDomains):
return True
return False
@@ -67,7 +67,7 @@ def __init__(self, *args, **kwargs):
logger_name = "{0}.{1}".format(self.__class__.__module__, self.__class__.__name__)
self.__class__.log = logging.getLogger(logger_name)

def validate(self):
def validate(self, product, whitelistedDomains):
"""Raises a BlobValidationError if the blob is invalid."""
self.log.debug('Validating blob %s' % self)
validator = jsonschema.Draft4Validator(self.getSchema())
@@ -80,6 +80,9 @@ def validate(self):
if errors:
raise BlobValidationError("Invalid blob! See 'errors' for details.", errors)

if self.containsForbiddenDomain(product, whitelistedDomains):
raise ValueError("Blob contains forbidden domain(s)")

def getResponseProducts(self):
# Usually returns None. If the Blob is a SuperBlob, it returns the list
# of return products.
@@ -123,3 +126,6 @@ def getFooterXML(self, updateQuery, update_type, whitelistedDomains, specialForc

def getInnerXML(self, updateQuery, update_type, whitelistedDomains, specialForceHosts):
raise NotImplementedError()

def containsForbiddenDomain(self, product, whitelistedDomains):
raise NotImplementedError()
@@ -60,3 +60,13 @@ def getInnerXML(self, updateQuery, update_type, whitelistedDomains, specialForce
platformData["filesize"], vendorInfo["version"]))

return vendorXML

def containsForbiddenDomain(self, product, whitelistedDomains):
"""Returns True if the blob contains any file URLs that contain a
domain that we're not allowed to serve updates to."""
for vendor in self.get('vendors', {}).values():
for platform in vendor.get('platforms', {}).values():
if 'fileUrl' in platform:
if isForbiddenUrl(platform["fileUrl"], product, whitelistedDomains):
return True
return False
@@ -15,3 +15,7 @@ def getResponseProducts(self):
def shouldServeUpdate(self, updateQuery):
# Since a superblob update will always be returned.
return True

def containsForbiddenDomain(self, product, whitelistedDomains):
# Since SuperBlobs don't have any URLs
return False
@@ -87,3 +87,15 @@ def getFooterXML(self, updateQuery, update_type, whitelistedDomains, specialForc
return ' </addons>'
else:
return None

def containsForbiddenDomain(self, product, whitelistedDomains):
"""Returns True if the blob contains any file URLs that contain a
domain that we're not allowed to serve updates to."""

for addon in self.get('addons', {}).values():
for platform in addon.get('platforms', {}).values():
if 'fileUrl' in platform:
if isForbiddenUrl(platform["fileUrl"], product, whitelistedDomains):
return True

return False
@@ -23,3 +23,7 @@ def shouldServeUpdate(self, updateQuery):
if requestIMEI is not None:
return self.isWhitelisted(requestIMEI)
return False

def containsForbiddenDomain(self, product, whitelistedDomains):
# Since WhitelistBlobs have no URLs
return False
@@ -21,7 +21,6 @@
import dictdiffer.merge

from auslib.global_state import cache, dbo
from auslib.AUS import isForbiddenUrl
from auslib.blobs.base import createBlob
from auslib.util.comparison import string_compare, version_compare
from auslib.util.timestamp import getMillisecondTimestamp
@@ -1278,39 +1277,6 @@ def __init__(self, db, metadata, dialect):
def setDomainWhitelist(self, domainWhitelist):
self.domainWhitelist = domainWhitelist

# TODO: This should really be part of the blob class(es) because it depends
# on a lot of blob schema specific stuff.
def containsForbiddenDomain(self, data, product):
"""Returns True if "data" contains any file URLs that contain a
domain that we're not allowed to serve updates to."""
# Check the top level URLs, if the exist.
for c in data.get('fileUrls', {}).values():
# New-style
if isinstance(c, dict):
for from_ in c.values():
for url in from_.values():
if isForbiddenUrl(url, product, self.domainWhitelist):
return True
# Old-style
else:
if isForbiddenUrl(c, product, self.domainWhitelist):
return True

# And also the locale-level URLs.
for platform in data.get('platforms', {}).values():
for locale in platform.get('locales', {}).values():
for type_ in ('partial', 'complete'):
if type_ in locale and 'fileUrl' in locale[type_]:
if isForbiddenUrl(locale[type_]['fileUrl'], product, self.domainWhitelist):
return True
for type_ in ('partials', 'completes'):
for update in locale.get(type_, {}):
if 'fileUrl' in update:
if isForbiddenUrl(update["fileUrl"], product, self.domainWhitelist):
return True

return False

def getReleases(self, name=None, product=None, limit=None, transaction=None):
self.log.debug("Looking for releases with:")
self.log.debug("name: %s", name)
@@ -1421,11 +1387,9 @@ def insert(self, changed_by, transaction=None, dryrun=False, **columns):
if not self.db.hasPermission(changed_by, "release", "create", columns["product"], transaction):
raise PermissionDeniedError("%s is not allowed to create releases for product %s" % (changed_by, columns["product"]))

blob.validate()
blob.validate(columns["product"], self.domainWhitelist)
if columns["name"] != blob["name"]:
raise ValueError("name in database (%s) does not match name in blob (%s)" % (columns["name"], blob["name"]))
if self.containsForbiddenDomain(blob, columns["product"]):
raise ValueError("Release blob contains forbidden domain.")
columns["data"] = blob.getJSON()

if not dryrun:
@@ -1471,12 +1435,11 @@ def update(self, where, what, changed_by, old_data_version, transaction=None, dr
raise PermissionDeniedError("%s is not allowed to mark %s products read only" % (changed_by, what.get("product")))

if blob:
blob.validate()
blob.validate(what.get("product", current_release["product"]),
self.domainWhitelist)
name = what.get("name", name)
if name != blob["name"]:
raise ValueError("name in database (%s) does not match name in blob (%s)" % (name, blob.get("name")))
if self.containsForbiddenDomain(blob, what.get("product", current_release["product"])):
raise ValueError("Release blob contains forbidden domain.")
what['data'] = blob.getJSON()
if not dryrun:
for release in current_releases:
@@ -1559,9 +1522,7 @@ def addLocaleToRelease(self, name, product, platform, locale, data, old_data_ver
if a not in releaseBlob['platforms']:
releaseBlob['platforms'][a] = {'alias': platform}

releaseBlob.validate()
if self.containsForbiddenDomain(releaseBlob, product):
raise ValueError("Release blob contains forbidden domain.")
releaseBlob.validate(product, self.domainWhitelist)
what = dict(data=releaseBlob.getJSON())

super(Releases, self).update(where=where, what=what, changed_by=changed_by, old_data_version=old_data_version,
@@ -899,7 +899,7 @@ def testPutExistingRelease(self):

def testGMPReleasePut(self):

ret = self._put('/releases/gmprel', data=dict(name='gmprel', product='GMP',
ret = self._put('/releases/gmprel', data=dict(name='gmprel', product='a',
blob="""
{
"name": "gmprel",
@@ -927,7 +927,7 @@ def testGMPReleasePut(self):
r = dbo.releases.t.select().where(dbo.releases.name == 'gmprel').execute().fetchall()
self.assertEquals(len(r), 1)
self.assertEquals(r[0]['name'], 'gmprel')
self.assertEquals(r[0]['product'], 'GMP')
self.assertEquals(r[0]['product'], 'a')
self.assertEquals(json.loads(r[0]['data']), json.loads("""
{
"name": "gmprel",

1 comment on commit 53f48d4

@TaskClusterRobot

This comment has been minimized.

Please sign in to comment.
You can’t perform that action at this time.