Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VCS support #40

Merged
merged 9 commits into from Apr 16, 2019
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 13 additions & 2 deletions catkin_virtualenv/scripts/combine_requirements
Expand Up @@ -24,7 +24,8 @@ import re
import sys

from collections import namedtuple
from packaging.requirements import Requirement
from packaging.requirements import Requirement, InvalidRequirement
from catkin_virtualenv.requirements import VcsRequirement

comment_regex = re.compile(r'\s*#.*')

Expand All @@ -40,7 +41,17 @@ def combine_requirements(requirements_list, output_file):
contents = comment_regex.sub('', contents)
for requirement_string in contents.splitlines():
if requirement_string and not requirement_string.isspace():
requirement = Requirement(requirement_string)
# Support for VCS requirements. First try to match a SemVer requirement then a VCS requirement.
try:
requirement = Requirement(requirement_string)
except InvalidRequirement as semver_err:
try:
requirement = VcsRequirement(requirement_string)
except InvalidRequirement as vcs_err:
raise RuntimeError(
"Could not match requirement {} for VCS ({}) or SemVer ({})".format(
requirement_string, str(vcs_err), str(semver_err)))

if requirement.name not in combined_requirements:
combined_requirements[requirement.name] = CombinedRequirement(
requirement=requirement,
Expand Down
153 changes: 44 additions & 109 deletions catkin_virtualenv/src/catkin_virtualenv/requirements.py
Expand Up @@ -18,115 +18,50 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import re

from copy import copy
from enum import Enum
from functools import total_ordering


@total_ordering
class SemVer(object):
version_regex = re.compile("^[0-9\.]+$")
from packaging.requirements import InvalidRequirement

class VcsRequirement(object):
'''A non-semver requirement from a version control system.
eg. svn+http://myrepo/svn/MyApp#egg=MyApp
'''

# Borrowing https://github.com/pypa/pipenv/tree/dde2e52cb8bc9bfca7af6c6b1a4576faf00e84f1/pipenv/vendor/requirements
VCS_SCHEMES = [
'git',
'git+https',
'git+ssh',
'git+git',
'hg+http',
'hg+https',
'hg+static-http',
'hg+ssh',
'svn',
'svn+svn',
'svn+http',
'svn+https',
'svn+ssh',
'bzr+http',
'bzr+https',
'bzr+ssh',
'bzr+sftp',
'bzr+ftp',
'bzr+lp',
]

name_regex = re.compile(
r'^(?P<scheme>{0})://'.format(r'|'.join(
[scheme.replace('+', r'\+') for scheme in VCS_SCHEMES])) +
r'((?P<login>[^/@]+)@)?'
r'(?P<path>[^#@]+)'
r'(@(?P<revision>[^#]+))?'
r'(#egg=(?P<name>[^&]+))?$'
)

def __init__(self, string):
# type: (str) -> None
if not self.version_regex.match(string):
raise RuntimeError("Invalid requirement version {}, must match {}".format(
string, self.version_regex.pattern))

self._version = [int(v) for v in string.split('.')]

def __eq__(self, other):
# type: (SemVer, SemVer) -> bool
return self._version == other._version

def __lt__(self, other):
# type: (SemVer, SemVer) -> bool
return self._version < other._version

def __str__(self):
# type: (SemVer) -> str
return '.'.join([str(v) for v in self._version])


class ReqType(Enum):
GREATER = ">="
EXACT = "=="
ANY = None


class ReqMergeException(RuntimeError):
def __init__(self, req, other):
# type: (Requirement, Requirement) -> None
self.req = req
self.other = other

def __str__(self):
# type: () -> str
return "Cannot merge requirements {} and {}".format(self.req, self.other)


class Requirement(object):
name_regex = re.compile("^[][A-Za-z0-9._-]+$")

def __init__(self, string):
# type: (str) -> None
for operation in [ReqType.GREATER, ReqType.EXACT, ReqType.ANY]:
fields = string.split(operation.value)
if len(fields) > 1:
break

self.name = fields[0].lower()
if not self.name_regex.match(self.name):
raise RuntimeError("Invalid requirement name {}, must match {}".format(
string, self.name_regex.pattern))

self.operation = operation
try:
self.version = SemVer(fields[1])
except IndexError:
self.version = None

def __str__(self):
# type: () -> str
return "{}{}{}".format(
self.name,
self.operation.value if self.operation.value else "",
self.version if self.version else ""
)

def __add__(self, other):
# type: (Requirement) -> Requirement
if self.name != other.name:
raise ReqMergeException(self, other)

operation_map = {
self.operation: self,
other.operation: other,
}
operation_set = set(operation_map)

if operation_set == {ReqType.EXACT}:
if self.version == other.version:
return copy(self)
else:
raise ReqMergeException(self, other)

elif operation_set == {ReqType.EXACT, ReqType.GREATER}:
if operation_map[ReqType.EXACT].version >= operation_map[ReqType.GREATER].version:
return copy(operation_map[ReqType.EXACT])
else:
raise ReqMergeException(self, other)

elif operation_set == {ReqType.GREATER}:
out = copy(operation_map[ReqType.GREATER])
out.version = max(self.version, other.version)
return out
match = self.name_regex.search(string)
if match is None:
raise InvalidRequirement("No match for {}".format(self.name_regex.pattern))

elif ReqType.ANY in operation_set:
if len(operation_set) == 1:
return copy(self)
else:
out = copy(self)
out.operation = (operation_set - {ReqType.ANY}).pop()
out.version = operation_map[out.operation].version
return out
self.name = match.group('name')
if self.name is None:
raise InvalidRequirement("No project name '#egg=<name>' was provided")
89 changes: 14 additions & 75 deletions catkin_virtualenv/test/test_requirements.py
Expand Up @@ -18,84 +18,23 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest

from catkin_virtualenv.requirements import Requirement, SemVer
from catkin_virtualenv.requirements import VcsRequirement
from packaging.requirements import InvalidRequirement


class TestRequirements(unittest.TestCase):

def test_string_transform(self):
reqs = [
"module",
"module>=0.0.5",
"module==1.0.5",
]
def test_vcs_requirement_parse(self):
string = "git+git://github.com/pytransitions/transitions@dev-async#egg=transitions"
req = VcsRequirement(string)
self.assertEqual(req.name, "transitions")

for req in reqs:
self.assertEqual(str(Requirement(req)), req)
def test_vcs_requirement_parse_no_name(self):
string = "git+git://github.com/pytransitions/transitions@dev-async"
with self.assertRaises(InvalidRequirement):
_ = VcsRequirement(string)

def test_failed_transform(self):
reqs = [
"$$asdfasdf",
"module$==1.0.5",
"module>=0.a.5",
"module===1.0.5",
"module=1.0.5",
]

for req in reqs:
with self.assertRaises(RuntimeError) as cm:
print(Requirement(req))
print(cm.exception)

def test_addition(self):
reqs = [
("module==1.0.0", "module", "module==1.0.0"),
("module==1.1.0", "module>=0.4", "module==1.1.0"),
("module==1.2.0", "module>=0.8", "module==1.2.0"),
("module", "module", "module"),
("module>=0.5", "module", "module>=0.5"),
("module>=0.3", "module>=10.0.8", "module>=10.0.8"),
]

for req in reqs:
# Check addition both ways for commutation
for direction in ((0, 1), (1, 0)):
left = Requirement(req[direction[0]])
right = Requirement(req[direction[1]])
result = left + right
self.assertEqual(str(result), req[2])

# Make sure we're returning a new object from the addition method
self.assertIsNot(right, result)
self.assertIsNot(left, result)

def test_failed_addition(self):
reqs = [
("module==1.0.0", "module==2.0.0"),
("module==1.0.0", "module>=1.0.4"),
("module==1.0.0", "other_module"),
("module", "other_module"),
]

for req in reqs:
with self.assertRaises(RuntimeError) as cm:
print(Requirement(req[0]) + Requirement(req[1]))
print(cm.exception)


class TestSemVer(unittest.TestCase):

def test_comparison(self):
versions = [
("1.0.0", "0"),
("3.0.0", "0.1"),
("1.0.0", "0.1.1.1.1.1"),
("4.0.0", "0.1234.1"),
("44.0.0", "003.12"),
("0.5", "0.0.4"),
("0.0.5", "0.0.4"),
("1.22.3", "1.002.3"),
("1.10.0", "1.9.0"),
]
for version in versions:
self.assertTrue(SemVer(version[0]) > SemVer(version[1]))
def test_vcs_requirement_parse_invalid(self):
string = "asdf"
with self.assertRaises(InvalidRequirement):
_ = VcsRequirement(string)