Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type annotations to secure_tempfile.py and models.py #5534

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion securedrop/journalist_app/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def update_submission_preferences():
# The UI prompt ("prevent") is the opposite of the setting ("allow"):
flash(gettext("Preferences saved."), "submission-preferences-success")
value = not bool(request.form.get('prevent_document_uploads'))
InstanceConfig.set('allow_document_uploads', value)
InstanceConfig.set_allow_document_uploads(value)
return redirect(url_for('admin.manage_config'))

@view.route('/add', methods=('GET', 'POST'))
Expand Down
50 changes: 16 additions & 34 deletions securedrop/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,29 +117,13 @@ def collection(self) -> 'List[Union[Submission, Reply]]':
return collection

@property
def fingerprint(self):
def fingerprint(self) -> 'Optional[str]':
return current_app.crypto_util.get_fingerprint(self.filesystem_id)

@fingerprint.setter
def fingerprint(self, value):
raise NotImplementedError

@fingerprint.deleter
def fingerprint(self):
raise NotImplementedError
rmol marked this conversation as resolved.
Show resolved Hide resolved

@property
def public_key(self) -> str:
def public_key(self) -> 'Optional[str]':
return current_app.crypto_util.get_pubkey(self.filesystem_id)

@public_key.setter
def public_key(self, value: str) -> None:
raise NotImplementedError

@public_key.deleter
def public_key(self) -> None:
raise NotImplementedError

def to_json(self) -> 'Dict[str, Union[str, bool, int, str]]':
docs_msg_count = self.documents_messages_count()

Expand Down Expand Up @@ -219,7 +203,7 @@ def is_file(self) -> bool:
def is_message(self) -> bool:
return self.filename.endswith("msg.gpg")

def to_json(self) -> "Dict[str, Union[str, int, bool]]":
def to_json(self) -> 'Dict[str, Any]':
rmol marked this conversation as resolved.
Show resolved Hide resolved
seen_by = {
f.journalist.uuid for f in SeenFile.query.filter(SeenFile.file_id == self.id)
if f.journalist
Expand Down Expand Up @@ -302,7 +286,7 @@ def __init__(self,
def __repr__(self) -> str:
return '<Reply %r>' % (self.filename)

def to_json(self) -> "Dict[str, Union[str, int, bool]]":
def to_json(self) -> 'Dict[str, Any]':
rmol marked this conversation as resolved.
Show resolved Hide resolved
journalist_username = "deleted"
journalist_first_name = ""
journalist_last_name = ""
Expand Down Expand Up @@ -360,18 +344,16 @@ class InvalidUsernameException(Exception):
class FirstOrLastNameError(Exception):
"""Generic error for names that are invalid."""

def __init__(self, msg):
msg = 'Invalid first or last name.'
def __init__(self, msg: str) -> None:
super(FirstOrLastNameError, self).__init__(msg)


class InvalidNameLength(FirstOrLastNameError):
"""Raised when attempting to create a Journalist with an invalid name length."""

def __init__(self, name):
self.name_len = len(name)
if self.name_len > Journalist.MAX_NAME_LEN:
msg = "Name too long (len={})".format(self.name_len)
def __init__(self, name: str) -> None:
name_len = len(name)
msg = "Name too long (len={})".format(name_len)
super(InvalidNameLength, self).__init__(msg)


Expand Down Expand Up @@ -527,7 +509,7 @@ def check_username_acceptable(cls, username: str) -> None:
"for internal use by the software.")

@classmethod
def check_name_acceptable(cls, name):
def check_name_acceptable(cls, name: str) -> None:
# Enforce a reasonable maximum length for names
if len(name) > cls.MAX_NAME_LEN:
raise InvalidNameLength(name)
Expand Down Expand Up @@ -721,12 +703,12 @@ def generate_api_token(self, expiration: int) -> str:
return s.dumps({'id': self.id}).decode('ascii') # type:ignore

@staticmethod
def validate_token_is_not_expired_or_invalid(token):
def validate_token_is_not_expired_or_invalid(token: str) -> bool:
s = TimedJSONWebSignatureSerializer(current_app.config['SECRET_KEY'])
try:
s.loads(token)
except BadData:
return None
return False

return True

Expand Down Expand Up @@ -842,10 +824,10 @@ class InstanceConfig(db.Model):
# updating the configuration.
metadata_cols = ['version', 'valid_until']

def __repr__(self):
def __repr__(self) -> str:
return "<InstanceConfig(version=%s, valid_until=%s)>" % (self.version, self.valid_until)

def copy(self):
def copy(self) -> "InstanceConfig":
'''Make a copy of only the configuration columns of the given
InstanceConfig object: i.e., excluding metadata_cols.
'''
Expand All @@ -860,7 +842,7 @@ def copy(self):
return new

@classmethod
def get_current(cls):
def get_current(cls) -> "InstanceConfig":
'''If the database was created via db.create_all(), data migrations
weren't run, and the "instance_config" table is empty. In this case,
save and return a base configuration derived from each setting's
Expand All @@ -876,7 +858,7 @@ def get_current(cls):
return current

@classmethod
def set(cls, name, value):
def set_allow_document_uploads(cls, value: bool) -> None:
'''Invalidate the current configuration and append a new one with the
requested change.
'''
Expand All @@ -886,7 +868,7 @@ def set(cls, name, value):
db.session.add(old)

new = old.copy()
setattr(new, name, value)
new.allow_document_uploads = value
db.session.add(new)

db.session.commit()
13 changes: 10 additions & 3 deletions securedrop/request_that_secures_file_uploads.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
from io import BytesIO
from typing import Optional, BinaryIO

from flask import wrappers
from werkzeug.formparser import FormDataParser

from secure_tempfile import SecureTemporaryFile


class RequestThatSecuresFileUploads(wrappers.Request):

def _secure_file_stream(self, total_content_length, content_type,
filename=None, content_length=None):
def _secure_file_stream(
self,
total_content_length: int,
content_type: Optional[str],
filename: Optional[str] = None,
content_length: Optional[int] = None,
) -> BinaryIO:
"""Storage class for data streamed in from requests.

If the data is relatively small (512KB), just store it in
Expand All @@ -27,7 +34,7 @@ def _secure_file_stream(self, total_content_length, content_type,
return SecureTemporaryFile('/tmp') # nosec
return BytesIO()

def make_form_data_parser(self):
def make_form_data_parser(self) -> FormDataParser:
return self.form_data_parser_class(self._secure_file_stream,
self.charset,
self.encoding_errors,
Expand Down
20 changes: 12 additions & 8 deletions securedrop/secure_tempfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import os
import io
from tempfile import _TemporaryFileWrapper # type: ignore
from typing import Optional
from typing import Union

from pretty_bad_protocol._util import _STREAMLIKE_TYPES
from cryptography.exceptions import AlreadyFinalized
Expand Down Expand Up @@ -34,7 +36,7 @@ class SecureTemporaryFile(_TemporaryFileWrapper, object):
AES_key_size = 256
AES_block_size = 128

def __init__(self, store_dir):
def __init__(self, store_dir: str) -> None:
"""Generates an AES key and an initialization vector, and opens
a file in the `store_dir` directory with a
pseudorandomly-generated filename.
Expand All @@ -56,7 +58,7 @@ def __init__(self, store_dir):
self.file = io.open(self.filepath, 'w+b')
super(SecureTemporaryFile, self).__init__(self.file, self.filepath)

def create_key(self):
def create_key(self) -> None:
"""Generates a unique, pseudorandom AES key, stored ephemerally in
memory as an instance attribute. Its destruction is ensured by the
automatic nightly reboots of the SecureDrop application server combined
Expand All @@ -68,15 +70,15 @@ def create_key(self):
self.iv = os.urandom(self.AES_block_size // 8)
self.initialize_cipher()

def initialize_cipher(self):
def initialize_cipher(self) -> None:
"""Creates the cipher-related objects needed for AES-CTR
encryption and decryption.
"""
self.cipher = Cipher(AES(self.key), CTR(self.iv), default_backend())
self.encryptor = self.cipher.encryptor()
self.decryptor = self.cipher.decryptor()

def write(self, data):
def write(self, data: Union[bytes, str]) -> None:
"""Write `data` to the secure temporary file. This method may be
called any number of times following instance initialization,
but after calling :meth:`read`, you cannot write to the file
Expand All @@ -87,11 +89,13 @@ def write(self, data):
self.last_action = 'write'

if isinstance(data, str):
data = data.encode('utf-8')
data_as_bytes = data.encode('utf-8')
else:
data_as_bytes = data

self.file.write(self.encryptor.update(data))
self.file.write(self.encryptor.update(data_as_bytes))

def read(self, count=None):
def read(self, count: Optional[int] = None) -> bytes:
"""Read `data` from the secure temporary file. This method may
be called any number of times following instance initialization
and once :meth:`write has been called at least once, but not
Expand Down Expand Up @@ -120,7 +124,7 @@ def read(self, count=None):
else:
return self.decryptor.update(self.file.read())

def close(self):
def close(self) -> None:
"""The __del__ method in tempfile._TemporaryFileWrapper (which
SecureTemporaryFile class inherits from) calls close() when the
temporary file is deleted.
Expand Down
16 changes: 1 addition & 15 deletions securedrop/tests/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,10 @@
from mock import MagicMock

from .utils import db_helper
from models import (Journalist, Submission, Reply, Source, get_one_or_else,
from models import (Journalist, Submission, Reply, get_one_or_else,
LoginThrottledException)


def test_source_public_key_setter_unimplemented(journalist_app, test_source):
rmol marked this conversation as resolved.
Show resolved Hide resolved
with journalist_app.app_context():
source = Source.query.first()
with pytest.raises(NotImplementedError):
source.public_key = 'a curious developer tries to set a pubkey!'


def test_source_public_key_delete_unimplemented(journalist_app, test_source):
with journalist_app.app_context():
source = Source.query.first()
with pytest.raises(NotImplementedError):
del source.public_key


def test_get_one_or_else_returns_one(journalist_app, test_journo):
with journalist_app.app_context():
# precondition: there must be one journalist
Expand Down
45 changes: 25 additions & 20 deletions securedrop/upload-screenshots.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
import sys

# Used to generate URLs for API endpoints and links; exposed as argument
from typing import List

from typing import Tuple

from typing import Dict

DEFAULT_BASE_URL = "https://weblate.securedrop.org"

# Where we look for screenshots: the page layout test results in English
Expand All @@ -23,7 +29,7 @@
# filename into the canonical title we give that screenshot in Weblate.
#
# Example conversion: "source-session_timeout.png" -> "source: session timeout"
CANONICALIZATION_RULES = ((r"\.png$", ""), (r"-", ": "), (r"_", " "))
CANONICALIZATION_RULES = [(r"\.png$", ""), (r"-", ": "), (r"_", " ")]
rmol marked this conversation as resolved.
Show resolved Hide resolved

# Weblate organizes internationalization work into projects and components,
# which are part of many URLs, and need to be referenced in some API requests.
Expand All @@ -34,7 +40,7 @@
REQUEST_LIMIT = 50


def main():
def main() -> None:
"""
Uses the generic WeblateUploader class below to run a SecureDrop screenshot
upload.
Expand Down Expand Up @@ -69,22 +75,22 @@ def main():
uploader.upload()


class WeblateUploader(object):
class WeblateUploader:
"""
Manages Weblate screenshot batch uploads, matching filenames against
titles of existing screenshots to create/update as appropriate.
"""

def __init__(
self,
token,
base_url,
project,
component,
files,
request_limit,
canonicalization_rules=(),
):
token: str,
base_url: str,
project: str,
component: str,
files: List[str],
request_limit: int,
canonicalization_rules: List[Tuple[str, str]],
) -> None:

if len(token) != 40:
raise BadOrMissingTokenError(
Expand All @@ -111,7 +117,7 @@ def __init__(
}
self.session.headers.update(headers)

def get_existing_screenshots(self):
def get_existing_screenshots(self) -> List[Dict[str, str]]:
"""
Obtains a list of all existing screenshots, and returns it as a list
in the API's format. Paginates up to the request limit.
Expand All @@ -120,7 +126,7 @@ def get_existing_screenshots(self):

# API results are paginated, so we must loop through a set of results and
# concatenate them.
screenshots = []
screenshots = [] # type: List[Dict[str, str]]
request_count = 0
while next_screenshots_url is not None:
response = self.session.get(next_screenshots_url)
Expand All @@ -136,7 +142,7 @@ def get_existing_screenshots(self):
raise RequestLimitError(msg)
return screenshots

def _canonicalize(self, filename):
def _canonicalize(self, filename: str) -> str:
"""
Derives a human-readable title from a filename using the defined
canonicalization rules, if any. This is used to later update the
Expand All @@ -146,7 +152,7 @@ def _canonicalize(self, filename):
filename = re.sub(pattern, repl, filename)
return filename

def upload(self, check_existing_screenshots=True):
def upload(self, check_existing_screenshots: bool = True) -> None:
"""
Uploads all files using the screenshots endpoint. Optionally, checks
files against a list of existing screenshots and replaces them rather
Expand Down Expand Up @@ -192,11 +198,10 @@ def upload(self, check_existing_screenshots=True):


class BadOrMissingTokenError(Exception):
def __init__(self, reason="Bad or missing token.", base_url=None):
if base_url is not None:
reason += " Obtain token via {}".format(
urljoin(base_url, "accounts/profile/#api")
)
def __init__(self, reason: str, base_url: str) -> None:
reason += " Obtain token via {}".format(
urljoin(base_url, "accounts/profile/#api")
)
super().__init__(reason)


Expand Down