Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/bibx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"read_wos",
]

__version__ = "0.4.0"
__version__ = "0.4.1"


def query_openalex(
Expand Down
18 changes: 15 additions & 3 deletions src/bibx/builders/openalex.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from collections import Counter
from enum import Enum
from typing import Optional
from urllib.parse import urlparse
Expand All @@ -11,11 +12,14 @@

logger = logging.getLogger(__name__)

MAX_REFERENCES = 400


class HandleReferences(Enum):
    """How to handle references when building an openalex collection."""

    # Use only the works returned by the query; no extra reference
    # fetching (presumably — confirm against CollectionBuilder.build).
    BASIC = "basic"
    # Also fetch the most frequently cited missing references,
    # capped at MAX_REFERENCES via Counter.most_common.
    COMMON = "common"
    # Fetch every referenced work that is not already in the cache.
    FULL = "full"


Expand All @@ -39,14 +43,22 @@ def build(self) -> Collection:
logger.info("building collection for query %s", self.query)
works = self.client.list_recent_articles(self.query, self.limit)
cache = {work.id: work for work in works}
references: list[str] = []
for work in works:
references.extend(work.referenced_works)
if self.references == HandleReferences.COMMON:
counter = Counter(references)
most_common = {key for key, _ in counter.most_common(MAX_REFERENCES)}
missing = most_common - set(cache.keys())
logger.info("fetching %d missing references", len(missing))
missing_works = self.client.list_articles_by_openalex_id(list(missing))
cache.update({work.id: work for work in missing_works})
if self.references == HandleReferences.FULL:
references: list[str] = []
for work in works:
references.extend(work.referenced_works)
missing = set(references) - set(cache.keys())
logger.info("fetching %d missing references", len(missing))
missing_works = self.client.list_articles_by_openalex_id(list(missing))
cache.update({work.id: work for work in missing_works})

article_cache = {
openalexid: self._work_to_article(work)
for openalexid, work in cache.items()
Expand Down
7 changes: 7 additions & 0 deletions src/bibx/cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from collections.abc import Callable
from enum import Enum
from typing import TextIO
Expand Down Expand Up @@ -83,8 +84,14 @@ def openalex(
help="how to handle references",
default=HandleReferences.BASIC,
),
verbose: bool = typer.Option(
help="be more verbose",
default=False,
),
) -> None:
"""Run the sap algorithm on a seed file of any supported format."""
if verbose:
logging.basicConfig(level=logging.INFO)
c = query_openalex(" ".join(query), references=references)
s = Sap()
graph = s.create_graph(c)
Expand Down
89 changes: 47 additions & 42 deletions src/bibx/clients/openalex.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed, wait
from enum import Enum
from typing import Optional, Union

Expand Down Expand Up @@ -122,6 +123,18 @@ def __init__(
}
)

def _fetch_works(self, params: dict[str, Union[str, int]]) -> WorkResponse:
    """Fetch a single page of works from the openalex `/works` endpoint.

    :param params: query parameters (select/filter/sort/per_page/page).
    :return: the validated ``WorkResponse`` for the page.
    :raises OpenAlexError: on any HTTP/network failure or if the payload
        does not validate against ``WorkResponse``.
    """
    try:
        # The GET itself must be inside the try: connection errors and
        # timeouts are RequestException subclasses and must also be
        # surfaced as OpenAlexError, not leak through unwrapped.
        response = self.session.get(
            f"{self.base_url}/works",
            params=params,
        )
        response.raise_for_status()
        data = response.json()
        return WorkResponse.model_validate(data)
    except (requests.RequestException, ValidationError) as error:
        raise OpenAlexError(str(error)) from error

def list_recent_articles(self, query: str, limit: int = 600) -> list[Work]:
"""List recent articles from the openalex API."""
select = ",".join(Work.model_fields.keys())
Expand All @@ -134,56 +147,48 @@ def list_recent_articles(self, query: str, limit: int = 600) -> list[Work]:
)
pages = (limit // MAX_WORKS_PER_PAGE) + 1
results: list[Work] = []
for page in range(1, pages + 1):
logger.info("fetching page %d with filter %s", page, filter_)
params: dict[str, Union[str, int]] = {
"select": select,
"filter": filter_,
"sort": "publication_year:desc",
"per_page": MAX_WORKS_PER_PAGE,
"page": page,
}
response = self.session.get(
f"{self.base_url}/works",
params=params,
)
try:
response.raise_for_status()
data = response.json()
work_response = WorkResponse.model_validate(data)
logger.info(
"fetched %d works in page %d", len(work_response.results), page
with ThreadPoolExecutor(max_workers=min(pages, 25)) as executor:
futures = [
executor.submit(
self._fetch_works,
{
"select": select,
"filter": filter_,
"sort": "publication_year:desc",
"per_page": MAX_WORKS_PER_PAGE,
"page": page,
},
)
for page in range(1, pages + 1)
]
wait(futures)
for future in futures:
work_response = future.result()
results.extend(work_response.results)
if page * MAX_WORKS_PER_PAGE >= min(work_response.meta.count, limit):
if len(results) >= limit:
break
except (requests.RequestException, ValidationError) as error:
raise OpenAlexError(str(error)) from error
return results[:limit]

def list_articles_by_openalex_id(self, ids: list[str]) -> list[Work]:
"""List articles by openalex id."""
select = ",".join(Work.model_fields.keys())
filter_ = ",".join([f"ids.openalex:{id_}" for id_ in ids])
results: list[Work] = []
for ids_ in chunks(ids, MAX_IDS_PER_REQUEST):
value = "|".join(ids_)
filter_ = f"ids.openalex:{value},type:types/article"
logger.info("fetching %d ids from openalex", len(ids_))
params: dict[str, Union[str, int]] = {
"select": select,
"filter": filter_,
"per_page": MAX_IDS_PER_REQUEST,
}
response = self.session.get(
f"{self.base_url}/works",
params=params,
)
try:
response.raise_for_status()
data = response.json()
work_response = WorkResponse.model_validate(data)
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
self._fetch_works,
{
"select": select,
"filter": f"ids.openalex:{'|'.join(ids)},type:types/article",
"per_page": MAX_IDS_PER_REQUEST,
},
)
for ids in chunks(ids, MAX_IDS_PER_REQUEST)
]
for future in as_completed(futures):
work_response = future.result()
logger.info(
"got %s works from the openalex api", len(work_response.results)
)
results.extend(work_response.results)
except (requests.RequestException, ValidationError) as error:
raise OpenAlexError(str(error)) from error
return results