Skip to content

Commit

Permalink
Fix custom PostgreSQL domains not being used in table definitions
Browse files Browse the repository at this point in the history
  • Loading branch information
lw committed Dec 7, 2018
1 parent 54f7a5c commit 9da2d14
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 21 deletions.
2 changes: 1 addition & 1 deletion cms/db/__init__.py
Expand Up @@ -80,7 +80,7 @@

# Instantiate or import these objects.

version = 40
version = 41

engine = create_engine(config.database, echo=config.database_debug,
pool_timeout=60, pool_recycle=120)
Expand Down
46 changes: 26 additions & 20 deletions cms/db/types.py
Expand Up @@ -25,6 +25,7 @@
import sqlalchemy
from sqlalchemy import DDL, event, TypeDecorator, Unicode
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.ext.compiler import compiles

from . import metadata

Expand Down Expand Up @@ -58,10 +59,6 @@ class Codename(TypeDecorator):
domain_name = "CODENAME"
impl = Unicode

@classmethod
def compile(cls, dialect=None):
return cls.domain_name

@classmethod
def get_create_command(cls):
return DDL("CREATE DOMAIN %(domain)s VARCHAR "
Expand All @@ -78,6 +75,11 @@ def get_drop_command(cls):
event.listen(metadata, "after_drop", Codename.get_drop_command())


@compiles(Codename)
def compile_codename(element, compiler, **kw):
return Codename.domain_name


class Filename(TypeDecorator):
"""Check that the column is a filename using a simple alphabet.
Expand All @@ -91,10 +93,6 @@ class Filename(TypeDecorator):
domain_name = "FILENAME"
impl = Unicode

@classmethod
def compile(cls, dialect=None):
return cls.domain_name

@classmethod
def get_create_command(cls):
return DDL("CREATE DOMAIN %(domain)s VARCHAR "
Expand All @@ -113,6 +111,11 @@ def get_drop_command(cls):
event.listen(metadata, "after_drop", Filename.get_drop_command())


@compiles(Filename)
def compile_filename(element, compiler, **kw):
return Filename.domain_name


class FilenameSchema(TypeDecorator):
"""Check that the column is a filename schema using a simple alphabet.
Expand All @@ -132,10 +135,6 @@ class FilenameSchema(TypeDecorator):
domain_name = "FILENAME_SCHEMA"
impl = Unicode

@classmethod
def compile(cls, dialect=None):
return cls.domain_name

@classmethod
def get_create_command(cls):
return DDL("CREATE DOMAIN %(domain)s VARCHAR "
Expand All @@ -150,6 +149,11 @@ def get_drop_command(cls):
context={"domain": cls.domain_name})


@compiles(FilenameSchema)
def compile_filename_schema(element, compiler, **kw):
return FilenameSchema.domain_name


event.listen(metadata, "before_create", FilenameSchema.get_create_command())
event.listen(metadata, "after_drop", FilenameSchema.get_drop_command())

Expand All @@ -170,10 +174,6 @@ class FilenameSchemaArray(TypeDecorator):
domain_name = "FILENAME_SCHEMA_ARRAY"
impl = CastingArray(Unicode)

@classmethod
def compile(cls, dialect=None):
return cls.domain_name

@classmethod
def get_create_command(cls):
# Postgres allows the condition "<sth> <op> ALL (<array>)" that
Expand Down Expand Up @@ -203,6 +203,11 @@ def get_drop_command(cls):
event.listen(metadata, "after_drop", FilenameSchemaArray.get_drop_command())


@compiles(FilenameSchemaArray)
def compile_filename_schema_array(element, compiler, **kw):
return FilenameSchemaArray.domain_name


class Digest(TypeDecorator):
"""Check that the column is a valid SHA1 hex digest.
Expand All @@ -221,10 +226,6 @@ class Digest(TypeDecorator):
# The fake digest used to mark a file as deleted in the backend.
TOMBSTONE = "x"

@classmethod
def compile(cls, dialect=None):
return cls.domain_name

@classmethod
def get_create_command(cls):
return DDL("CREATE DOMAIN %(domain)s VARCHAR "
Expand All @@ -240,3 +241,8 @@ def get_drop_command(cls):

event.listen(metadata, "before_create", Digest.get_create_command())
event.listen(metadata, "after_drop", Digest.get_drop_command())


@compiles(Digest)
def compile_digest(element, compiler, **kw):
return Digest.domain_name
163 changes: 163 additions & 0 deletions cmscontrib/updaters/update_41.py
@@ -0,0 +1,163 @@
#!/usr/bin/env python3

# Contest Management System - http://cms-dev.github.io/
# Copyright © 2018 Luca Wehrstedt <luca.wehrstedt@gmail.com>

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""A class to update a dump created by CMS.
Used by DumpImporter and DumpUpdater.
This updater makes sure that the constraints on codenames, filenames,
filename schemas and digests hold.
"""

import logging
import re


logger = logging.getLogger(__name__)

# Fields that contain codenames.
CODENAME_FIELDS = {
"Contest": ["name"],
"Task": ["name"],
"Testcase": ["codename"],
"Admin": ["username"],
"User": ["username"],
"Team": ["code"]
}
# Fields that contain filenames.
FILENAME_FIELDS = {
"Executable": ["filename"],
"UserTestManager": ["filename"],
"UserTestExecutable": ["filename"],
"PrintJob": ["filename"],
"Attachment": ["filename"],
"Manager": ["filename"]
}
# Fields that contain filename schemas.
FILENAME_SCHEMA_FIELDS = {
"File": ["filename"],
"UserTestFile": ["filename"]
}
# Fields that contain arrays of filename schemas.
FILENAME_SCHEMA_ARRAY_FIELDS = {
"Task": ["submission_format"]
}
# Fields that contain digests.
DIGEST_FIELDS = {
"Statement": ["digest"],
"Attachment": ["digest"],
"Manager": ["digest"],
"Testcase": ["input", "output"],
"UserTest": ["input"],
"UserTestFile": ["digest"],
"UserTestManager": ["digest"],
"UserTestResult": ["output"],
"UserTestExecutable": ["digest"],
"File": ["digest"],
"Executable": ["digest"],
"PrintJob": ["digest"]
}


class Updater:

def __init__(self, data):
assert data["_version"] == 40
self.objs = data

self.bad_codenames = []
self.bad_filenames = []
self.bad_filename_schemas = []
self.bad_digests = []

def check_codename(self, class_, attr, codename):
if not re.match("^[A-Za-z0-9_-]+$", codename):
self.bad_codenames.append("%s.%s" % (class_, attr))

def check_filename(self, class_, attr, filename):
if not re.match('^[A-Za-z0-9_.-]+$', filename) \
or filename in {',', '..'}:
self.bad_filenames.append("%s.%s" % (class_, attr))

def check_filename_schema(self, class_, attr, schema):
if not re.match('^[A-Za-z0-9_.-]+(\.%%l)?$', schema) \
or schema in {'.', '..'}:
self.bad_filename_schemas.append("%s.%s" % (class_, attr))

def check_digest(self, class_, attr, digest):
if not re.match('^([0-9a-f]{40}|x)$', digest):
self.bad_digests.append("%s.%s" % (class_, attr))

def run(self):
for k, v in self.objs.items():
if k.startswith("_"):
continue
if v["_class"] in CODENAME_FIELDS:
for attr in CODENAME_FIELDS[v["_class"]]:
self.check_codename(v["_class"], attr, v[attr])
if v["_class"] in FILENAME_FIELDS:
for attr in FILENAME_FIELDS[v["_class"]]:
self.check_filename(v["_class"], attr, v[attr])
if v["_class"] in FILENAME_SCHEMA_FIELDS:
for attr in FILENAME_SCHEMA_FIELDS[v["_class"]]:
self.check_filename_schema(v["_class"], attr, v[attr])
if v["_class"] in FILENAME_SCHEMA_ARRAY_FIELDS:
for attr in FILENAME_SCHEMA_ARRAY_FIELDS[v["_class"]]:
for schema in v[attr]:
self.check_filename_schema(v["_class"], attr, schema)
if v["_class"] in DIGEST_FIELDS:
for attr in DIGEST_FIELDS[v["_class"]]:
self.check_digest(v["_class"], attr, v[attr])

bad = False

if self.bad_codenames:
logger.error(
"The following fields contained invalid codenames: %s. "
"They can only contain letters, digits, underscores and dashes."
% ", ".join(self.bad_codenames))
bad = True

if self.bad_filenames:
logger.error(
"The following fields contained invalid filenames: %s. "
"They can only contain letters, digits, underscores, dashes "
"and periods and cannot be '.' or '..'."
% ", ".join(self.bad_filenames))
bad = True

if self.bad_filename_schemas:
logger.error(
"The following fields contained invalid filename schemas: %s. "
"They can only contain letters, digits, underscores, dashes "
"and periods, end with '.%%l' and cannot be '.' or '..'."
% ", ".join(self.bad_filename_schemas))
bad = True

if self.bad_digests:
logger.error(
"The following fields contained invalid digests: %s. "
"They must be 40-character long lowercase hex values, or 'x'."
% ", ".join(self.bad_digests))
bad = True

if bad:
raise ValueError("Some data was invalid.")

return self.objs

0 comments on commit 9da2d14

Please sign in to comment.