Skip to content

Commit

Permalink
lint 🧹
Browse files Browse the repository at this point in the history
  • Loading branch information
geohacker committed Jun 19, 2019
1 parent 96396ad commit 8b5cb87
Show file tree
Hide file tree
Showing 10 changed files with 37 additions and 23 deletions.
3 changes: 3 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[flake8]
max-line-length = 160
exclude = env/*
4 changes: 2 additions & 2 deletions ml_enabler/api/ml.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from flask_restful import Resource, request, current_app
from ml_enabler.models.dtos.ml_model_dto import MLModelDTO, MLModelVersionDTO, PredictionDTO
from ml_enabler.models.dtos.ml_model_dto import MLModelDTO, MLModelVersionDTO
from schematics.exceptions import DataError
from ml_enabler.services.ml_model_service import MLModelService, MLModelVersionService
from ml_enabler.services.prediction_service import PredictionService, PredictionTileService
Expand Down Expand Up @@ -262,7 +262,7 @@ def post(self, model_id):
ml_model_dto = MLModelService.get_ml_model_by_id(model_id)

# check if the version is registered
model_version = MLModelVersionService.get_version_by_model_version(model_id, version)
model_version = MLModelVersionService.get_version_by_model_version(ml_model_dto.model_id, version)
prediction_id = PredictionService.create(model_id, model_version.version_id, payload)
return {"prediction_id": prediction_id}, 200

Expand Down
2 changes: 1 addition & 1 deletion ml_enabler/models/dtos/ml_model_dto.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from schematics import Model
from schematics.types import StringType, BaseType, IntType, DateTimeType, ListType, FloatType
from schematics.types import StringType, IntType, DateTimeType, ListType, FloatType


class MLModelDTO(Model):
Expand Down
27 changes: 20 additions & 7 deletions ml_enabler/models/ml_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,15 @@ class PredictionTile(db.Model):

@staticmethod
def get_tiles_by_quadkey(prediction_id: int, quadkeys: tuple, zoom: int):
return db.session.query(func.substr(PredictionTile.quadkey, 1, zoom).label('qaudkey'), func.avg(cast(cast(PredictionTile.predictions['ml_prediction'], sqlalchemy.String), sqlalchemy.Float)).label('ml_prediction')).filter(PredictionTile.prediction_id == prediction_id).filter(func.substr(PredictionTile.quadkey, 1, zoom).in_(quadkeys)).group_by(func.substr(PredictionTile.quadkey, 1, zoom)).all()

return db.session.query(func.substr(PredictionTile.quadkey, 1, zoom).label('qaudkey'),
func.avg(cast(cast(PredictionTile.predictions['ml_prediction'], sqlalchemy.String),
sqlalchemy.Float)).label('ml_prediction')).filter(PredictionTile.prediction_id == prediction_id).filter(func.substr(
PredictionTile.quadkey, 1, zoom).in_(quadkeys)).group_by(func.substr(PredictionTile.quadkey, 1, zoom)).all()

@staticmethod
def get_aggregate_for_polygon(prediction_id: int, polygon: str):
return db.session.query(func.avg(cast(cast(PredictionTile.predictions['ml_prediction'], sqlalchemy.String), sqlalchemy.Float))).filter(PredictionTile.prediction_id == prediction_id).filter(ST_Within(PredictionTile.centroid, ST_GeomFromText(polygon)) == 'True').one()
return db.session.query(func.avg(cast(cast(PredictionTile.predictions['ml_prediction'], sqlalchemy.String), sqlalchemy.Float))).filter(
PredictionTile.prediction_id == prediction_id).filter(ST_Within(PredictionTile.centroid, ST_GeomFromText(polygon)) == 'True').one()


class Prediction(db.Model):
Expand Down Expand Up @@ -77,7 +80,9 @@ def get(prediction_id: int):
:param prediction_id
:return prediction if found otherwise None
"""
query = db.session.query(Prediction.id, Prediction.created, Prediction.dockerhub_hash, ST_AsGeoJSON(ST_Envelope(Prediction.bbox)).label('bbox'), Prediction.model_id, Prediction.tile_zoom, Prediction.version_id).filter(Prediction.id == prediction_id)
query = db.session.query(Prediction.id, Prediction.created, Prediction.dockerhub_hash,
ST_AsGeoJSON(ST_Envelope(Prediction.bbox)).label('bbox'), Prediction.model_id, Prediction.tile_zoom,
Prediction.version_id).filter(Prediction.id == prediction_id)
return query.one()

@staticmethod
Expand All @@ -99,7 +104,11 @@ def get_latest_predictions_in_bbox(model_id: int, version_id: int, bbox: list):
:return list of predictions
"""

query = db.session.query(Prediction.id, Prediction.created, Prediction.dockerhub_hash, ST_AsGeoJSON(ST_Envelope(Prediction.bbox)).label('bbox'), Prediction.model_id, Prediction.tile_zoom, Prediction.version_id).filter(Prediction.model_id == model_id).filter(Prediction.version_id==version_id).filter(ST_Intersects(Prediction.bbox, ST_MakeEnvelope(bbox[0], bbox[1], bbox[2], bbox[3], 4326))).order_by(Prediction.created.desc()).limit(1)
query = db.session.query(Prediction.id, Prediction.created, Prediction.dockerhub_hash, ST_AsGeoJSON(ST_Envelope(Prediction.bbox)).label('bbox'),
Prediction.model_id, Prediction.tile_zoom, Prediction.version_id).filter(Prediction.model_id == model_id).filter(
Prediction.version_id == version_id).filter(ST_Intersects(
Prediction.bbox, ST_MakeEnvelope(bbox[0], bbox[1], bbox[2], bbox[3], 4326))).order_by(
Prediction.created.desc()).limit(1)

return query.all()

Expand All @@ -110,7 +119,10 @@ def get_all_predictions_in_bbox(model_id: int, bbox: list):
:param model_id, bbox
:return list of predictions
"""
query = db.session.query(Prediction.id, Prediction.created, Prediction.dockerhub_hash, ST_AsGeoJSON(ST_Envelope(Prediction.bbox)).label('bbox'), Prediction.model_id, Prediction.tile_zoom, Prediction.version_id).filter(Prediction.model_id == model_id).filter(ST_Intersects(Prediction.bbox, ST_MakeEnvelope(bbox[0], bbox[1], bbox[2], bbox[3], 4326)))
query = db.session.query(Prediction.id, Prediction.created, Prediction.dockerhub_hash, ST_AsGeoJSON(ST_Envelope(Prediction.bbox)).label('bbox'),
Prediction.model_id, Prediction.tile_zoom, Prediction.version_id).filter(
Prediction.model_id == model_id).filter(
ST_Intersects(Prediction.bbox, ST_MakeEnvelope(bbox[0], bbox[1], bbox[2], bbox[3], 4326)))

return query.all()

Expand Down Expand Up @@ -262,7 +274,8 @@ def get_latest_version(model_id: int):
:param model_id
:return version or None
"""
return MLModelVersion.query.filter_by(model_id=model_id).order_by(MLModelVersion.version_major.desc(), MLModelVersion.version_minor.desc(), MLModelVersion.version_patch.desc()).first()
return MLModelVersion.query.filter_by(model_id=model_id).order_by(MLModelVersion.version_major.desc(), MLModelVersion.version_minor.desc(),
MLModelVersion.version_patch.desc()).first()

def as_dto(self):
"""
Expand Down
11 changes: 4 additions & 7 deletions ml_enabler/services/prediction_service.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from flask import current_app
from ml_enabler.models.ml_model import MLModel, MLModelVersion, Prediction, PredictionTile
from ml_enabler.models.dtos.ml_model_dto import MLModelDTO, MLModelVersionDTO, PredictionDTO
from ml_enabler.models.utils import NotFound, VersionNotFound,\
version_to_array, bbox_str_to_list, PredictionsNotFound, geojson_to_bbox,\
point_list_to_wkt, bbox_to_quadkeys, tuple_to_dict, polygon_to_wkt
from sqlalchemy.orm.exc import NoResultFound
from ml_enabler.models.ml_model import MLModelVersion, Prediction, PredictionTile
from ml_enabler.models.dtos.ml_model_dto import PredictionDTO
from ml_enabler.models.utils import bbox_str_to_list, PredictionsNotFound, geojson_to_bbox,\
bbox_to_quadkeys, tuple_to_dict, polygon_to_wkt
from ml_enabler import db


Expand Down
2 changes: 1 addition & 1 deletion ml_enabler/tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,4 @@ def tearDown(self):
self.db.session.rollback()
self.app_context.pop()

super(BaseTestCase, self).tearDown()
super(BaseTestCase, self).tearDown()
2 changes: 1 addition & 1 deletion ml_enabler/tests/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from factory.alchemy import SQLAlchemyModelFactory
from ml_enabler import db
from ml_enabler.models.ml_model import Prediction, MLModel, \
MLModelVersion, PredictionTile
MLModelVersion


class MLModelFactory(SQLAlchemyModelFactory):
Expand Down
1 change: 1 addition & 0 deletions ml_enabler/tests/fixtures/geojson.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# flake8:noqa
def get_geojson():
return {
"type": "FeatureCollection",
Expand Down
5 changes: 2 additions & 3 deletions ml_enabler/tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import json
from ml_enabler.tests.base import BaseTestCase
from ml_enabler.tests.factories import MLModelFactory, MLModelVersionFactory, \
PredictionFactory
from ml_enabler.models.ml_model import MLModel, Prediction
from ml_enabler.tests.factories import MLModelFactory
from ml_enabler.tests.fixtures import tiles, geojson
from ml_enabler.tests.utils import create_prediction, create_prediction_tiles


class StatusTest(BaseTestCase):
def test_status(self):
response = self.client.get('/')
Expand Down
3 changes: 2 additions & 1 deletion ml_enabler/tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from ml_enabler.tests.factories import MLModelFactory, MLModelVersionFactory, \
PredictionFactory
from ml_enabler.models.ml_model import MLModel, Prediction, PredictionTile
from ml_enabler.models.ml_model import PredictionTile
from ml_enabler import db
from ml_enabler.tests.fixtures import tiles


def create_prediction():
ml_model = MLModelFactory()
db.session.add(ml_model)
Expand Down

0 comments on commit 8b5cb87

Please sign in to comment.