Skip to content

Commit

Permalink
Temp comment on search endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
munrojm committed May 28, 2021
1 parent 470c613 commit 75d5035
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions src/mp_api/routes/search/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
SearchDielectricPiezoQuery,
SearchIsTheoreticalQuery,
)
from mp_api.routes.surface_properties.query_operators import SurfaceMinMaxQuery

# from mp_api.routes.surface_properties.query_operators import SurfaceMinMaxQuery
from mp_api.routes.electronic_structure.query_operators import ESSummaryDataQuery
from mp_api.routes.thermo.query_operators import ThermoEnergyQuery

Expand All @@ -34,15 +35,19 @@ def generate_stats_prep(self):
model_name = self.model.__name__

# we can only generate statistics for fields that return numbers
valid_numeric_fields = tuple(sorted(k for k, v in SearchDoc().__fields__.items() if v.type_ == float))
valid_numeric_fields = tuple(
sorted(k for k, v in SearchDoc().__fields__.items() if v.type_ == float)
)

async def generate_stats(
field: Literal[valid_numeric_fields] = Query(
valid_numeric_fields[0],
title=f"SearchDoc field to query on, must be a numerical field, "
f"choose from: {', '.join(valid_numeric_fields)}",
),
num_samples: Optional[int] = Query(None, title="If specified, will only sample this number of documents.",),
num_samples: Optional[int] = Query(
None, title="If specified, will only sample this number of documents.",
),
min_val: Optional[float] = Query(
None,
title="If specified, will only consider documents with field values "
Expand All @@ -53,7 +58,9 @@ async def generate_stats(
title="If specified, will only consider documents with field values "
"less than or equal to this minimum value.",
),
num_points: int = Query(100, title="The number of values in the returned distribution."),
num_points: int = Query(
100, title="The number of values in the returned distribution."
),
):
"""
Generate statistics for a given numerical field specified in SearchDoc.
Expand All @@ -78,7 +85,10 @@ async def generate_stats(

pipeline.append({"$project": {field: 1}})

values = [d[field] for d in self.store._collection.aggregate(pipeline, allowDiskUse=True)]
values = [
d[field]
for d in self.store._collection.aggregate(pipeline, allowDiskUse=True)
]
if not min_val:
min_val = min(values)
if not max_val:
Expand Down Expand Up @@ -130,7 +140,7 @@ async def generate_stats(
ESSummaryDataQuery(),
SearchElasticityQuery(),
SearchDielectricPiezoQuery(),
SurfaceMinMaxQuery(),
# SurfaceMinMaxQuery(),
SearchMagneticQuery(),
HasPropsQuery(),
DeprecationQuery(),
Expand Down

0 comments on commit 75d5035

Please sign in to comment.