Skip to content

Commit

Permalink
Add generate_sitemap script and unit tests
Browse files Browse the repository at this point in the history
This change adds generate_sitemap.py script that accepts
base URL as an argument. The script generates a sitemap
index file including individual sitemaps for each ecosystem.

The sitemaps are then converted to a compressed version.

It also implements unit tests to ensure proper functionality
and coverage of generate_sitemap script.
  • Loading branch information
zahraaalizadeh committed Jun 7, 2024
1 parent 4014062 commit f9603c4
Show file tree
Hide file tree
Showing 2 changed files with 319 additions and 0 deletions.
140 changes: 140 additions & 0 deletions gcp/appengine/generate_sitemap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generate sitemap."""
import gzip
import shutil
import sys
import os
import osv
import osv.logs
import datetime
import argparse
from google.cloud import ndb

from xml.etree.ElementTree import Element, SubElement, ElementTree

_SITEMAPS_DIRECTORY = './sitemap'
_SITEMAP_INDEX_PATH = f'{_SITEMAPS_DIRECTORY}/index.xml'
_SITEMAP_URL_LIMIT = 49999


def fetch_vulnerability_ids(ecosystem: str) -> list[str]:
"""Fetch vulnerabilities' id for the given ecosystem."""
bugs = osv.Bug.query(
osv.Bug.status == osv.BugStatus.PROCESSED,
osv.Bug.public == True, # pylint: disable=singleton-comparison
osv.Bug.ecosystem == ecosystem).order(-osv.Bug.timestamp)
bug_ids = [bug.db_id for bug in bugs]
return bug_ids


def osv_get_ecosystems():
"""Get list of ecosystems."""
query = osv.Bug.query(projection=[osv.Bug.ecosystem], distinct=True)
return sorted([bug.ecosystem[0] for bug in query if bug.ecosystem],
key=str.lower)


def get_sitemap_filename_for_ecosystem(ecosystem: str) -> str:
ecosystem_name = ecosystem.replace(' ', '_').replace('.', '__').strip()
return f'{_SITEMAPS_DIRECTORY}/{ecosystem_name}.xml'


def get_sitemap_url_for_ecosystem(ecosystem: str, base_url: str) -> str:
ecosystem_name = ecosystem.replace(' ', '_').replace('.', '__').strip()
return f'{base_url}/sitemap/{ecosystem_name}.xml'


def generate_sitemap_for_ecosystem(ecosystem: str, base_url: str) -> None:
"""Generate a sitemap for the give n ecosystem."""
os.makedirs(_SITEMAPS_DIRECTORY, exist_ok=True)

vulnerability_ids = fetch_vulnerability_ids(ecosystem)
filename = get_sitemap_filename_for_ecosystem(ecosystem)
urlset = Element(
'urlset', xmlns="http://www.sitemaps.org/schemas/sitemap/0.9")

# TODO: For large ecosystems with over 50,000 vulnerabilities, generate
# multiple sitemaps.
for vuln in vulnerability_ids[:_SITEMAP_URL_LIMIT]:
url = SubElement(urlset, 'url')
loc = SubElement(url, 'loc')
loc.text = f"{base_url}/vulnerability/{vuln}"
lastmod = SubElement(url, 'lastmod')
lastmod.text = datetime.datetime.now().isoformat()

tree = ElementTree(urlset)
tree.write(filename, encoding='utf-8', xml_declaration=True)


def compress_file(file_path: str) -> str:
"""Compress the file using gzip and return the path to the compressed file."""
base, _ = os.path.splitext(file_path)
compressed_file_path = f"{base}.gz"
with open(file_path, 'rb') as f_in:
with gzip.open(compressed_file_path, 'wb') as f_out:
shutil.copyfileobj(f_in, f_out)
# Remove the original uncompressed file
os.remove(file_path)
return compressed_file_path


def generate_sitemap_index(ecosystems: set[str], base_url: str) -> None:
"""Generate a sitemap index."""
os.makedirs(_SITEMAPS_DIRECTORY, exist_ok=True)

sitemapindex = Element(
'sitemapindex', xmlns="http://www.sitemaps.org/schemas/sitemap/0.9")

for ecosystem in ecosystems:
sitemap = SubElement(sitemapindex, "sitemap")
loc = SubElement(sitemap, 'loc')
loc.text = get_sitemap_url_for_ecosystem(ecosystem, base_url)
lastmod = SubElement(sitemap, 'lastmod')
lastmod.text = datetime.datetime.now().isoformat()

tree = ElementTree(sitemapindex)
tree.write(_SITEMAP_INDEX_PATH, encoding='utf-8', xml_declaration=True)


def generate_sitemaps(base_url: str) -> None:
"""Generate sitemaps including all vulnerabilities, split by ecosystem."""

# Go over the base ecosystems index. Otherwise we'll have duplicated
# vulnerabilities in the sitemap.
base_ecosystems = {
ecosystem for ecosystem in osv_get_ecosystems() if ':' not in ecosystem
}
for ecosystem in base_ecosystems:
generate_sitemap_for_ecosystem(ecosystem, base_url)
compress_file(get_sitemap_filename_for_ecosystem(ecosystem))

generate_sitemap_index(base_ecosystems, base_url)
compress_file(_SITEMAP_INDEX_PATH)


def main() -> int:
parser = argparse.ArgumentParser(description='Generate sitemaps.')
parser.add_argument(
'--base_url', required=True, help='The base URL for the sitemap entries.')
args = parser.parse_args()
generate_sitemaps(args.base_url)
return 0


if __name__ == '__main__':
_ndb_client = ndb.Client()
osv.logs.setup_gcp_logging('generate_sitemap')
with _ndb_client.context():
sys.exit(main())
179 changes: 179 additions & 0 deletions gcp/appengine/generate_sitemap_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
"""Sitemap generator tests."""

# limitations under the License.
import unittest
import tempfile
import os
import gzip
from unittest.mock import patch, MagicMock
import generate_sitemap
import osv


class TestSitemapGeneration(unittest.TestCase):
"""Tests to verify the functionality of the sitemap generator script"""

def temp_file(self):
# Create a temporary file for testing
self.test_file = tempfile.NamedTemporaryFile(delete=False)
self.test_file.write(b'This is a test file.')
self.test_file.close()
return self.test_file.name

def test_compress_file(self):
"""Test it compresses the file and removes the original file."""
input_filename = self.temp_file()

# Call the compress_file function
compressed_file_path = generate_sitemap.compress_file(input_filename)

# Verify that the original file is removed
self.assertFalse(os.path.exists(input_filename))

# Verify that the compressed file is created
self.assertTrue(os.path.exists(compressed_file_path))

# Verify the contents of the compressed file
with gzip.open(compressed_file_path, 'rb') as f:
content = f.read()
self.assertEqual(content, b'This is a test file.')

# Clean up compressed file created during the test
os.remove(compressed_file_path)

@patch.object(osv.Bug, 'query')
def test_fetch_vulnerability_ids(self, mock_query):
"""Test it returns the vulnerability ids for ecosystem"""
# Mock the returned query
mock_query.return_value.order.return_value = [
MagicMock(db_id='vuln1'),
MagicMock(db_id='vuln2')
]

result = generate_sitemap.fetch_vulnerability_ids('Go')
self.assertEqual(result, ['vuln1', 'vuln2'])

@patch.object(osv.Bug, 'query')
def test_osv_get_ecosystems(self, mock_query):
"""Test it returns the ecosystems"""
# Mock the returned query
mock_query.return_value = [
MagicMock(ecosystem=['UVI']),
MagicMock(ecosystem=['Go'])
]

result = generate_sitemap.osv_get_ecosystems()
self.assertEqual(result, ['Go', 'UVI'])

@patch('generate_sitemap.fetch_vulnerability_ids')
@patch('generate_sitemap.ElementTree')
@patch('generate_sitemap.os.makedirs')
def test_generate_sitemap_for_ecosystem(self, mock_makedirs,
mock_element_tree, mock_fetch_vulns):
"""Check it generates the sitemap for ecosystem"""
mock_fetch_vulns.return_value = ['vuln1', 'vuln2']
mock_tree = MagicMock()
mock_element_tree.return_value = mock_tree

generate_sitemap.generate_sitemap_for_ecosystem('Go', 'http://example.com')

mock_makedirs.assert_called_once_with('./sitemap', exist_ok=True)
mock_tree.write.assert_called_once_with(
'./sitemap/Go.xml', encoding='utf-8', xml_declaration=True)

@patch('generate_sitemap.fetch_vulnerability_ids')
@patch('generate_sitemap.ElementTree')
@patch('generate_sitemap.os.makedirs')
def test_generate_sitemap_for_ecosystem_with_space(self, mock_makedirs,
mock_element_tree,
mock_fetch_vulns):
""""
Check it creates the sitemap correctly where there is a space in the
ecosystem name.
"""
mock_fetch_vulns.return_value = ['vuln1', 'vuln2']
mock_tree = MagicMock()
mock_element_tree.return_value = mock_tree

generate_sitemap.generate_sitemap_for_ecosystem('Rocky Linux',
'http://example.com')

mock_makedirs.assert_called_once_with('./sitemap', exist_ok=True)
mock_tree.write.assert_called_once_with(
'./sitemap/Rocky_Linux.xml', encoding='utf-8', xml_declaration=True)

@patch('generate_sitemap.fetch_vulnerability_ids')
@patch('generate_sitemap.ElementTree')
@patch('generate_sitemap.os.makedirs')
def test_generate_sitemap_for_ecosystem_with_period(self, mock_makedirs,
mock_element_tree,
mock_fetch_vulns):
""""
Check it creates the sitemap correctly where there is a period in the
ecosystem name.
"""
mock_fetch_vulns.return_value = ['vuln1', 'vuln2']
mock_tree = MagicMock()
mock_element_tree.return_value = mock_tree

generate_sitemap.generate_sitemap_for_ecosystem('crates.io',
'http://example.com')

mock_makedirs.assert_called_once_with('./sitemap', exist_ok=True)
mock_tree.write.assert_called_once_with(
'./sitemap/crates__io.xml', encoding='utf-8', xml_declaration=True)

@patch('generate_sitemap.ElementTree')
@patch('generate_sitemap.os.makedirs')
def test_generate_sitemap_index(self, mock_makedirs, mock_element_tree):
"""Check it generates the sitemap index as expected"""
mock_tree = MagicMock()
mock_element_tree.return_value = mock_tree

generate_sitemap.generate_sitemap_index({'Go', 'UVI'}, 'http://example.com')

mock_makedirs.assert_called_once_with('./sitemap', exist_ok=True)
mock_tree.write.assert_called_once_with(
'./sitemap/index.xml', encoding='utf-8', xml_declaration=True)

@patch('generate_sitemap.generate_sitemap_for_ecosystem')
@patch('generate_sitemap.generate_sitemap_index')
@patch('generate_sitemap.osv_get_ecosystems')
@patch('generate_sitemap.compress_file')
def test_generate_sitemap(self, mock_compress_file, mock_get_ecosystems,
mock_generate_index, mock_generate_sitemap):
"""
Check the outer wrapper generates the ecosystems' sitemaps as well as
sitemap index.
"""
mock_get_ecosystems.return_value = ['Go', 'UVI:Library', 'Android']

generate_sitemap.generate_sitemaps('http://example.com')

self.assertEqual(mock_generate_sitemap.call_count, 2)
mock_generate_sitemap.assert_any_call('Go', 'http://example.com')
mock_generate_sitemap.assert_any_call('Android', 'http://example.com')

self.assertEqual(mock_compress_file.call_count, 3)
mock_compress_file.assert_any_call('./sitemap/Go.xml')
mock_compress_file.assert_any_call('./sitemap/Android.xml')
mock_compress_file.assert_any_call('./sitemap/index.xml')

mock_generate_index.assert_called_once_with({'Android', 'Go'},
'http://example.com')


if __name__ == '__main__':
unittest.main()

0 comments on commit f9603c4

Please sign in to comment.