Skip to content

Commit

Permalink
Fixed mypy code smells and added mypy to the CI/CD pipeline (#143)
Browse files Browse the repository at this point in the history
* fix basic mypy command errors

* fix errors for base defs

* add new support tests

* added ignore statements

* fix flake8 errors

* add workflow command

* fix mypy imports

* fix mypy errors
  • Loading branch information
kushagra189 committed Sep 29, 2022
1 parent 05be4df commit 5bac488
Show file tree
Hide file tree
Showing 14 changed files with 156 additions and 88 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ jobs:
pip install -r requirements_dev.txt
- name: Lint with flake8
run: flake8
- name: Run mypy
run: mypy foca
- name: Start MongoDB
uses: supercharge/mongodb-github-action@1.7.0
with:
Expand Down
124 changes: 72 additions & 52 deletions foca/access_control/access_control_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from werkzeug.exceptions import (InternalServerError, NotFound)

from foca.utils.logging import log_traffic
from foca.errors.exceptions import BadRequest

logger = logging.getLogger(__name__)

Expand All @@ -20,27 +21,31 @@ def postPermission() -> str:
Returns:
Identifier of the new permission added.
"""
try:
access_control_adapter = current_app.config["casbin_adapter"]
request_json = request.json
rule = request_json.get("rule", {})
permission_data = [
rule.get("v0", None),
rule.get("v1", None),
rule.get("v2", None),
rule.get("v3", None),
rule.get("v4", None),
rule.get("v5", None)
]
permission_id = access_control_adapter.save_policy_line(
ptype=request_json.get("policy_type", None),
rule=permission_data
)
logger.info("New policy added.")
return permission_id
except Exception as e:
logger.error(f"{type(e).__name__}: {e}")
raise InternalServerError
request_json = request.json
if isinstance(request_json, dict):
try:
access_control_adapter = current_app.config["casbin_adapter"]
rule = request_json.get("rule", {})
permission_data = [
rule.get("v0", None),
rule.get("v1", None),
rule.get("v2", None),
rule.get("v3", None),
rule.get("v4", None),
rule.get("v5", None)
]
permission_id = access_control_adapter.save_policy_line(
ptype=request_json.get("policy_type", None),
rule=permission_data
)
logger.info("New policy added.")
return permission_id
except Exception as e:
logger.error(f"{type(e).__name__}: {e}")
raise InternalServerError
else:
logger.error("Invalid request payload.")
raise BadRequest


@log_traffic
Expand All @@ -55,27 +60,34 @@ def putPermission(
Returns:
Identifier of updated permission.
"""
try:
request_json = request.json
access_control_config = current_app.config.foca.access_control
db_coll_permission: Collection = (
current_app.config.foca.db.dbs[access_control_config.db_name]
.collections[access_control_config.collection_name].client
)

permission_data = request_json.get("rule", {})
permission_data["id"] = id
permission_data["ptype"] = request_json.get("policy_type", None)
db_coll_permission.replace_one(
filter={"id": id},
replacement=permission_data,
upsert=True
)
logger.info("Policy updated.")
return id
except Exception as e:
logger.error(f"{type(e).__name__}: {e}")
raise InternalServerError
request_json = request.json
if isinstance(request_json, dict):
app_config = current_app.config
try:
access_control_config = \
app_config.foca.access_control # type: ignore[attr-defined]
db_coll_permission: Collection = (
app_config.foca.db.dbs[ # type: ignore[attr-defined]
access_control_config.db_name]
.collections[access_control_config.collection_name].client
)

permission_data = request_json.get("rule", {})
permission_data["id"] = id
permission_data["ptype"] = request_json.get("policy_type", None)
db_coll_permission.replace_one(
filter={"id": id},
replacement=permission_data,
upsert=True
)
logger.info("Policy updated.")
return id
except Exception as e:
logger.error(f"{type(e).__name__}: {e}")
raise InternalServerError
else:
logger.error("Invalid request payload.")
raise BadRequest


@log_traffic
Expand All @@ -88,11 +100,13 @@ def getAllPermissions(limit=None) -> List[Dict]:
Returns:
List of permission dicts.
"""
logger.info(f"test {current_app.config}")
access_control_config = current_app.config.foca.access_control
app_config = current_app.config
access_control_config = \
app_config.foca.access_control # type: ignore[attr-defined]
db_coll_permission: Collection = (
current_app.config.foca.db.dbs[access_control_config.db_name]
.collections[access_control_config.collection_name].client
app_config.foca.db.dbs[ # type: ignore[attr-defined]
access_control_config.db_name
].collections[access_control_config.collection_name].client
)

if not limit:
Expand Down Expand Up @@ -129,10 +143,13 @@ def getPermission(
Returns:
Permission data for the given id.
"""
access_control_config = current_app.config.foca.access_control
app_config = current_app.config
access_control_config = \
app_config.foca.access_control # type: ignore[attr-defined]
db_coll_permission: Collection = (
current_app.config.foca.db.dbs[access_control_config.db_name]
.collections[access_control_config.collection_name].client
app_config.foca.db.dbs[ # type: ignore[attr-defined]
access_control_config.db_name
].collections[access_control_config.collection_name].client
)

permission = db_coll_permission.find_one(filter={"id": id})
Expand Down Expand Up @@ -162,10 +179,13 @@ def deletePermission(
Returns:
Delete permission identifier.
"""
access_control_config = current_app.config.foca.access_control
app_config = current_app.config
access_control_config = \
app_config.foca.access_control # type: ignore[attr-defined]
db_coll_permission: Collection = (
current_app.config.foca.db.dbs[access_control_config.db_name]
.collections[access_control_config.collection_name].client
app_config.foca.db.dbs[ # type: ignore[attr-defined]
access_control_config.db_name
].collections[access_control_config.collection_name].client
)

del_obj_permission = db_coll_permission.delete_one({'id': id})
Expand Down
18 changes: 9 additions & 9 deletions foca/access_control/foca_casbin_adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ def save_policy_line(self, ptype: str, rule: List[str]):
line = CasbinRule(ptype=ptype)
for index, value in enumerate(rule):
setattr(line, f"v{index}", value)
line = line.dict()
line["id"] = generate_id()
self._collection.insert_one(line)
return line["id"]
rule_dict: dict = line.dict()
rule_dict["id"] = generate_id()
self._collection.insert_one(rule_dict)
return rule_dict["id"]

def _delete_policy_lines(self, ptype: str, rule: List[str]) -> int:
"""Method to find a delete policies given a list of policy attributes.
Expand Down Expand Up @@ -118,27 +118,27 @@ def save_policy(self, model: Model) -> bool:
self.save_policy_line(ptype, rule)
return True

def add_policy(self, sec: str, ptype: str, rule: CasbinRule) -> bool:
def add_policy(self, sec: str, ptype: str, rule: List[str]) -> bool:
"""Add policy rules to mongodb
Args:
sec: Section corresponding which the rule will be added.
ptype: Policy type for which casbin rule will be added.
rule: Casbin rule to be added.
rule: Casbin rule list definition to be added.
Returns:
True if succeed else False.
"""
self.save_policy_line(ptype, rule)
return True

def remove_policy(self, sec: str, ptype: str, rule: CasbinRule):
def remove_policy(self, sec: str, ptype: str, rule: List[str]):
"""Remove policy rules from mongodb(duplicate rules are also removed).
Args:
sec: Section corresponding which the rule will be added.
ptype: Policy type for which casbin rule will be removed.
rule: Casbin rule to be removed.
rule: Casbin rule list definition to be removed.
Returns:
Number of policies removed.
Expand Down Expand Up @@ -172,6 +172,6 @@ def remove_filtered_policy(
for index, value in enumerate(field_values):
query[f"v{index + field_index}"] = value

query["ptype"] = ptype
query["ptype"] = ptype # type: ignore[assignment]
results = self._collection.delete_many(query)
return results.deleted_count > 0
18 changes: 10 additions & 8 deletions foca/access_control/register_access_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

def register_access_control(
cnx_app: App,
mongo_config: MongoConfig,
mongo_config: Optional[MongoConfig],
access_control_config: AccessControlConfig
) -> App:
"""Register access control configuration with flask app.
Expand Down Expand Up @@ -58,10 +58,11 @@ def register_access_control(
if mongo_config is None:
mongo_config = MongoConfig()

access_control_db = str(access_control_config.db_name)
if mongo_config.dbs is None:
mongo_config.dbs = {access_control_config.db_name: access_db_conf}
mongo_config.dbs = {access_control_db: access_db_conf}
else:
mongo_config.dbs[access_control_config.db_name] = access_db_conf
mongo_config.dbs[access_control_db] = access_db_conf

cnx_app.app.config.foca.db = mongo_config

Expand All @@ -70,7 +71,7 @@ def register_access_control(
app=cnx_app.app,
conf=mongo_config,
db_conf=access_db_conf,
db_name=access_control_config.db_name
db_name=access_control_db
)

# Register access control api specs and corresponding controller.
Expand Down Expand Up @@ -129,7 +130,7 @@ def register_permission_specs(
)

app.add_api(
specification=spec.path[0],
specification=spec.path[0], # type: ignore[index]
**spec.dict().get("connexion", {}),
)
return app
Expand All @@ -154,12 +155,13 @@ def register_casbin_enforcer(
Connexion application instance with registered casbin configuration.
"""
# Check if default, get package path variables for model.
access_control_config_model = str(access_control_config.model)
if access_control_config.api_specs is None:
casbin_model = str(resource_filename(
ACCESS_CONTROL_BASE_PATH, access_control_config.model
ACCESS_CONTROL_BASE_PATH, access_control_config_model
))
else:
casbin_model = access_control_config.model
casbin_model = access_control_config_model

logger.info("Setting casbin model.")
app.app.config["CASBIN_MODEL"] = casbin_model
Expand All @@ -177,7 +179,7 @@ def register_casbin_enforcer(
logger.info("Setting up casbin enforcer.")
adapter = Adapter(
uri=f"mongodb://{mongo_config.host}:{mongo_config.port}/",
dbname=access_control_config.db_name,
dbname=str(access_control_config.db_name),
collection=access_control_config.collection_name
)
app.app.config["casbin_adapter"] = adapter
Expand Down
8 changes: 5 additions & 3 deletions foca/api/register_openapi.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Register and modify OpenAPI specifications."""

import logging
from typing import List
from pathlib import Path
from typing import Dict, List

from connexion import App
import yaml
Expand Down Expand Up @@ -37,8 +38,9 @@ def register_openapi(
for spec in specs:

# Merge specs
spec_parsed = ConfigParser.merge_yaml(*spec.path)
logger.debug(f"Parsed spec: {spec.path}")
list_specs = [spec.path] if isinstance(spec.path, Path) else spec.path
spec_parsed: Dict = ConfigParser.merge_yaml(*list_specs)
logger.debug(f"Parsed spec: {list_specs}")

# Add/replace root objects
if spec.append is not None:
Expand Down
7 changes: 4 additions & 3 deletions foca/config/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def parse_yaml(conf: Path) -> Dict:
) from exc

@staticmethod
def merge_yaml(*args: Path) -> Optional[Dict]:
def merge_yaml(*args: Path) -> Dict:
"""Parse and merge a set of YAML files.
Merging is done iteratively, from the first, second to the n-th
Expand All @@ -126,7 +126,7 @@ def merge_yaml(*args: Path) -> Optional[Dict]:
"""
args_list = list(args)
if not args_list:
return None
return {}
yaml_dict = Addict(ConfigParser.parse_yaml(args_list.pop(0)))

for arg in args_list:
Expand Down Expand Up @@ -169,7 +169,8 @@ def parse_custom_config(self, model: str) -> BaseModel:
f"has no class {model_class} or could not be imported"
)
try:
custom_config = model_class(**self.config.custom)
custom_config = model_class( # type: ignore[operator]
**self.config.custom) # type: ignore[attr-defined]
except Exception as exc:
raise ValueError(
"failed validating custom configuration: provided custom "
Expand Down
2 changes: 1 addition & 1 deletion foca/errors/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def _problem_handler_json(exception: Exception) -> Response:
JSON-formatted error response.
"""
# Look up exception & get status code
conf = current_app.config.foca.exceptions
conf = current_app.config.foca.exceptions # type: ignore[attr-defined]
exc = type(exception)
if exc not in conf.mapping:
exc = Exception
Expand Down
4 changes: 2 additions & 2 deletions foca/factories/celery_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def create_celery_app(app: Flask) -> Celery:
Returns:
Celery application instance.
"""
conf = app.config.foca.jobs
conf = app.config.foca.jobs # type: ignore[attr-defined]

# Instantiate Celery app
celery = Celery(
Expand All @@ -32,7 +32,7 @@ def create_celery_app(app: Flask) -> Celery:
logger.debug(f"Celery app created from '{calling_module}'.")

# Update Celery app configuration with Flask app configuration
setattr(celery.conf, 'foca', app.config.foca)
setattr(celery.conf, 'foca', app.config.foca) # type: ignore[attr-defined]
logger.debug('Celery app configured.')

class ContextTask(celery.Task): # type: ignore
Expand Down
1 change: 0 additions & 1 deletion foca/foca.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ def create_app(self) -> App:
conf.access_control.api_specs or
conf.access_control.api_controllers
):
conf.access_control = None
logger.error(
"Please enable security config to register "
"access control."
Expand Down

0 comments on commit 5bac488

Please sign in to comment.