diff --git a/README.md b/README.md index bc85a00..efdbe16 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ [![Codacy grade](https://img.shields.io/codacy/grade/1456603c25764b14b441ed509e938154?style=for-the-badge)](https://www.codacy.com/gh/danielnsilva/semanticscholar/dashboard?utm_source=github.com&utm_medium=referral&utm_content=danielnsilva/semanticscholar&utm_campaign=Badge_Grade) [![Codacy coverage](https://img.shields.io/codacy/coverage/1456603c25764b14b441ed509e938154?style=for-the-badge)](https://www.codacy.com/gh/danielnsilva/semanticscholar/dashboard?utm_source=github.com&utm_medium=referral&utm_content=danielnsilva/semanticscholar&utm_campaign=Badge_Coverage) -Unofficial [Semantic Scholar Academic Graph API](https://api.semanticscholar.org/) client library for Python. +Unofficial Python client library for [Semantic Scholar APIs](https://api.semanticscholar.org/), currently supporting the Academic Graph API and Recommendations API. ![](search_paper.gif) @@ -185,6 +185,56 @@ On Computing Machinery and Intelligence Building Thinking Machines by Solving Animal Cognition Tasks ``` +### Recommended papers + +To get recommended papers for a given paper: + +```python +from semanticscholar import SemanticScholar +sch = SemanticScholar() +results = sch.get_recommended_papers('10.1145/3544585.3544600') +for item in results: + print(item.title) +``` + +Output: + +```console +INDUCED SUBGRAPHS AND TREE DECOMPOSITIONS +On the Deque and Rique Numbers of Complete and Complete Bipartite Graphs +Exact and Parameterized Algorithms for the Independent Cutset Problem +On (in)tractability of connection and cut problems +A survey on constructive methods for the Oberwolfach problem and its variants +... 
+Approximation Algorithms for Directed Weighted Spanners +``` + +To get recommended papers based on a list of positive and negative paper examples: + +```python +from semanticscholar import SemanticScholar +sch = SemanticScholar() +positive_paper_ids = ['10.1145/3544585.3544600'] +negative_paper_ids = ['10.1145/301250.301271'] +results = sch.get_recommended_papers_from_lists(positive_paper_ids, negative_paper_ids) +for item in results: + print(item.title) +``` + +Output: + +```console +BUILDING MINIMUM SPANNING TREES BY LIMITED NUMBER OF NODES OVER TRIANGULATED SET OF INITIAL NODES +Recognition of chordal graphs and cographs which are Cover-Incomparability graphs +Minimizing Maximum Unmet Demand by Transportations Between Adjacent Nodes Characterized by Supplies and Demands +Optimal Near-Linear Space Heaviest Induced Ancestors +Diameter-2-critical graphs with at most 13 nodes +... +Advanced Heuristic and Approximation Algorithms (M2) +``` + +You can also omit the list of negative paper IDs, in which case the API will return recommended papers based only on the list of positive paper IDs. 
+ ### Query parameters for all methods #### ```fields: list``` diff --git a/semanticscholar/SemanticScholar.py b/semanticscholar/SemanticScholar.py index 9ff688d..bebe45f 100644 --- a/semanticscholar/SemanticScholar.py +++ b/semanticscholar/SemanticScholar.py @@ -1,509 +1,615 @@ -import warnings -from typing import List - -from semanticscholar.ApiRequester import ApiRequester -from semanticscholar.Author import Author -from semanticscholar.BaseReference import BaseReference -from semanticscholar.Citation import Citation -from semanticscholar.PaginatedResults import PaginatedResults -from semanticscholar.Paper import Paper -from semanticscholar.Reference import Reference - - -class SemanticScholar: - ''' - Main class to retrieve data from Semantic Scholar Graph API - ''' - - DEFAULT_API_URL = 'https://api.semanticscholar.org/graph/v1' - DEFAULT_PARTNER_API_URL = 'https://partner.semanticscholar.org/graph/v1' - - auth_header = {} - - def __init__( - self, - timeout: int = 10, - api_key: str = None, - api_url: str = None, - graph_api: bool = True - ) -> None: - ''' - :param float timeout: (optional) an exception is raised\ - if the server has not issued a response for timeout seconds. - :param str api_key: (optional) private API key. - :param str api_url: (optional) custom API url. - :param bool graph_api: (optional) whether use new Graph API. 
- ''' - - if api_url: - self.api_url = api_url - else: - self.api_url = self.DEFAULT_API_URL - - if api_key: - self.auth_header = {'x-api-key': api_key} - if not api_url: - self.api_url = self.DEFAULT_PARTNER_API_URL - - if not graph_api: - warnings.warn( - 'graph_api parameter is deprecated and will be disabled ' + - 'in the future', DeprecationWarning) - self.api_url = self.api_url.replace('/graph', '') - - self._timeout = timeout - self._requester = ApiRequester(self._timeout) - - @property - def timeout(self) -> int: - ''' - :type: :class:`int` - ''' - return self._timeout - - @timeout.setter - def timeout(self, timeout: int) -> None: - ''' - :param int timeout: - ''' - self._timeout = timeout - self._requester.timeout = timeout - - def get_paper( - self, - paper_id: str, - include_unknown_refs: bool = False, - fields: list = None - ) -> Paper: - '''Paper lookup - - :calls: `GET /paper/{paper_id} `_ - - :param str paper_id: S2PaperId, CorpusId, DOI, ArXivId, MAG, ACL, \ - PMID, PMCID, or URL from: - - - semanticscholar.org - - arxiv.org - - aclweb.org - - acm.org - - biorxiv.org - - :param bool include_unknown_refs: (optional) include non referenced \ - paper. - :param list fields: (optional) list of the fields to be returned. - :returns: paper data - :rtype: :class:`semanticscholar.Paper.Paper` - :raises: ObjectNotFoundException: if Paper ID not found. 
- ''' - - if not fields: - fields = Paper.FIELDS - - url = f'{self.api_url}/paper/{paper_id}' - - fields = ','.join(fields) - parameters = f'&fields={fields}' - if include_unknown_refs: - warnings.warn( - 'include_unknown_refs parameter is deprecated and will be disabled ' + - 'in the future', DeprecationWarning) - parameters += '&include_unknown_references=true' - - data = self._requester.get_data(url, parameters, self.auth_header) - paper = Paper(data) - - return paper - - def get_papers( - self, - paper_ids: List[str], - fields: list = None - ) -> List[Paper]: - '''Get details for multiple papers at once - - :calls: `POST /paper/batch `_ - - :param str paper_ids: list of IDs (must be <= 1000) - S2PaperId,\ - CorpusId, DOI, ArXivId, MAG, ACL, PMID, PMCID, or URL from: - - - semanticscholar.org - - arxiv.org - - aclweb.org - - acm.org - - biorxiv.org - - :param list fields: (optional) list of the fields to be returned. - :returns: papers data - :rtype: :class:`List` of :class:`semanticscholar.Paper.Paper` - :raises: BadQueryParametersException: if no paper was found. - ''' - - if not fields: - fields = Paper.SEARCH_FIELDS - - url = f'{self.api_url}/paper/batch' - - fields = ','.join(fields) - parameters = f'&fields={fields}' - - payload = { "ids": paper_ids } - - data = self._requester.get_data( - url, parameters, self.auth_header, payload) - papers = [Paper(item) for item in data] - - return papers - - def get_paper_authors( - self, - paper_id: str, - fields: list = None, - limit: int = 1000 - ) -> PaginatedResults: - '''Get details about a paper's authors - - :calls: `POST /paper/{paper_id}/authors \ - `_ - - :param str paper_id: S2PaperId, CorpusId, DOI, ArXivId, MAG, ACL,\ - PMID, PMCID, or URL from: - - - semanticscholar.org - - arxiv.org - - aclweb.org - - acm.org - - biorxiv.org - - :param list fields: (optional) list of the fields to be returned. - :param int limit: (optional) maximum number of results to return\ - (must be <= 1000). 
- ''' - - if limit < 1 or limit > 1000: - raise ValueError( - 'The limit parameter must be between 1 and 1000 inclusive.') - - if not fields: - fields = [item for item in Author.SEARCH_FIELDS - if not item.startswith('papers')] - - url = f'{self.api_url}/paper/{paper_id}/authors' - - results = PaginatedResults( - requester=self._requester, - data_type=Author, - url=url, - fields=fields, - limit=limit - ) - - return results - - def get_paper_citations( - self, - paper_id: str, - fields: list = None, - limit: int = 1000 - ) -> PaginatedResults: - '''Get details about a paper's citations - - :calls: `POST /paper/{paper_id}/citations \ - `_ - - :param str paper_id: S2PaperId, CorpusId, DOI, ArXivId, MAG, ACL,\ - PMID, PMCID, or URL from: - - - semanticscholar.org - - arxiv.org - - aclweb.org - - acm.org - - biorxiv.org - - :param list fields: (optional) list of the fields to be returned. - :param int limit: (optional) maximum number of results to return\ - (must be <= 1000). - ''' - - if limit < 1 or limit > 1000: - raise ValueError( - 'The limit parameter must be between 1 and 1000 inclusive.') - - if not fields: - fields = BaseReference.FIELDS + Paper.SEARCH_FIELDS - - url = f'{self.api_url}/paper/{paper_id}/citations' - - results = PaginatedResults( - requester=self._requester, - data_type=Citation, - url=url, - fields=fields, - limit=limit - ) - - return results - - def get_paper_references( - self, - paper_id: str, - fields: list = None, - limit: int = 1000 - ) -> PaginatedResults: - '''Get details about a paper's references - - :calls: `POST /paper/{paper_id}/references \ - `_ - - :param str paper_id: S2PaperId, CorpusId, DOI, ArXivId, MAG, ACL,\ - PMID, PMCID, or URL from: - - - semanticscholar.org - - arxiv.org - - aclweb.org - - acm.org - - biorxiv.org - - :param list fields: (optional) list of the fields to be returned. - :param int limit: (optional) maximum number of results to return\ - (must be <= 1000). 
- ''' - - if limit < 1 or limit > 1000: - raise ValueError( - 'The limit parameter must be between 1 and 1000 inclusive.') - - if not fields: - fields = BaseReference.FIELDS + Paper.SEARCH_FIELDS - - url = f'{self.api_url}/paper/{paper_id}/references' - - results = PaginatedResults( - requester=self._requester, - data_type=Reference, - url=url, - fields=fields, - limit=limit - ) - - return results - - def search_paper( - self, - query: str, - year: str = None, - publication_types: list = None, - open_access_pdf: bool = None, - venue: list = None, - fields_of_study: list = None, - fields: list = None, - limit: int = 100 - ) -> PaginatedResults: - '''Search for papers by keyword - - :calls: `GET /paper/search `_ - - :param str query: plain-text search query string. - :param str year: (optional) restrict results to the given range of \ - publication year. - :param list publication_type: (optional) restrict results to the given \ - publication type list. - :param bool open_access_pdf: (optional) restrict results to papers \ - with public PDFs. - :param list venue: (optional) restrict results to the given venue list. - :param list fields_of_study: (optional) restrict results to given \ - field-of-study list, using the s2FieldsOfStudy paper field. - :param list fields: (optional) list of the fields to be returned. - :param int limit: (optional) maximum number of results to return \ - (must be <= 100). - :returns: query results. 
- :rtype: :class:`semanticscholar.PaginatedResults.PaginatedResults` - ''' - - if limit < 1 or limit > 100: - raise ValueError( - 'The limit parameter must be between 1 and 100 inclusive.') - - if not fields: - fields = Paper.SEARCH_FIELDS - - url = f'{self.api_url}/paper/search' - - query += f'&year={year}' if year else '' - - if publication_types: - publication_types = ','.join(publication_types) - query += f'&publicationTypes={publication_types}' - - query += '&openAccessPdf' if open_access_pdf else '' - - if venue: - venue = ','.join(venue) - query += f'&venue={venue}' - - if fields_of_study: - fields_of_study = ','.join(fields_of_study) - query += f'&fieldsOfStudy={fields_of_study}' - - results = PaginatedResults( - self._requester, - Paper, - url, - query, - fields, - limit, - self.auth_header - ) - - return results - - def get_author( - self, - author_id: str, - fields: list = None - ) -> Author: - '''Author lookup - - :calls: `GET /author/{author_id} `_ - - :param str author_id: S2AuthorId. - :returns: author data - :rtype: :class:`semanticscholar.Author.Author` - :raises: ObjectNotFoundException: if Author ID not found. - ''' - - if not fields: - fields = Author.FIELDS - - url = f'{self.api_url}/author/{author_id}' - - fields = ','.join(fields) - parameters = f'&fields={fields}' - - data = self._requester.get_data(url, parameters, self.auth_header) - author = Author(data) - - return author - - def get_authors( - self, - author_ids: List[str], - fields: list = None - ) -> List[Author]: - '''Get details for multiple authors at once - - :calls: `POST /author/batch `_ - - :param str author_ids: list of S2AuthorId (must be <= 1000). - :returns: author data - :rtype: :class:`List` of :class:`semanticscholar.Author.Author` - :raises: BadQueryParametersException: if no author was found. 
- ''' - - if not fields: - fields = Author.SEARCH_FIELDS - - url = f'{self.api_url}/author/batch' - - fields = ','.join(fields) - parameters = f'&fields={fields}' - - payload = { "ids": author_ids } - - data = self._requester.get_data( - url, parameters, self.auth_header, payload) - authors = [Author(item) for item in data] - - return authors - - def get_author_papers( - self, - author_id: str, - fields: list = None, - limit: int = 1000 - ) -> PaginatedResults: - '''Get details about a author's papers - - :calls: `POST /paper/{author_id}/papers \ - `_ - - :param str paper_id: S2PaperId, CorpusId, DOI, ArXivId, MAG, ACL,\ - PMID, PMCID, or URL from: - - - semanticscholar.org - - arxiv.org - - aclweb.org - - acm.org - - biorxiv.org - - :param list fields: (optional) list of the fields to be returned. - :param int limit: (optional) maximum number of results to return\ - (must be <= 1000). - ''' - - if limit < 1 or limit > 1000: - raise ValueError( - 'The limit parameter must be between 1 and 1000 inclusive.') - - if not fields: - fields = Paper.SEARCH_FIELDS - - url = f'{self.api_url}/author/{author_id}/papers' - - results = PaginatedResults( - requester=self._requester, - data_type=Paper, - url=url, - fields=fields, - limit=limit - ) - - return results - - def search_author( - self, - query: str, - fields: list = None, - limit: int = 1000 - ) -> PaginatedResults: - '''Search for authors by name - - :calls: `GET /author/search `_ - - :param str query: plain-text search query string. - :param list fields: (optional) list of the fields to be returned. - :param int limit: (optional) maximum number of results to return \ - (must be <= 1000). - :returns: query results. 
- :rtype: :class:`semanticscholar.PaginatedResults.PaginatedResults` - ''' - - if limit < 1 or limit > 1000: - raise ValueError( - 'The limit parameter must be between 1 and 1000 inclusive.') - - if not fields: - fields = Author.SEARCH_FIELDS - - url = f'{self.api_url}/author/search' - - results = PaginatedResults( - self._requester, - Author, - url, - query, - fields, - limit, - self.auth_header - ) - - return results +import warnings +from typing import List + +from semanticscholar.ApiRequester import ApiRequester +from semanticscholar.Author import Author +from semanticscholar.BaseReference import BaseReference +from semanticscholar.Citation import Citation +from semanticscholar.PaginatedResults import PaginatedResults +from semanticscholar.Paper import Paper +from semanticscholar.Reference import Reference + + +class SemanticScholar: + ''' + Main class to retrieve data from Semantic Scholar Graph API + ''' + + DEFAULT_API_URL = 'https://api.semanticscholar.org' + DEFAULT_PARTNER_API_URL = 'https://partner.semanticscholar.org' + + BASE_PATH_GRAPH = '/graph/v1' + BASE_PATH_RECOMMENDATIONS = '/recommendations/v1' + + auth_header = {} + + def __init__( + self, + timeout: int = 10, + api_key: str = None, + api_url: str = None + ) -> None: + ''' + :param float timeout: (optional) an exception is raised\ + if the server has not issued a response for timeout seconds. + :param str api_key: (optional) private API key. + :param str api_url: (optional) custom API url. + :param bool graph_api: (optional) whether use new Graph API. 
+ ''' + + if api_url: + self.api_url = api_url + else: + self.api_url = self.DEFAULT_API_URL + + if api_key: + self.auth_header = {'x-api-key': api_key} + if not api_url: + self.api_url = self.DEFAULT_PARTNER_API_URL + + if not graph_api: + warnings.warn( + 'graph_api parameter is deprecated and will be disabled ' + + 'in the future', DeprecationWarning) + self.api_url = self.api_url.replace('/graph', '') + + self._timeout = timeout + self._requester = ApiRequester(self._timeout) + + @property + def timeout(self) -> int: + ''' + :type: :class:`int` + ''' + return self._timeout + + @timeout.setter + def timeout(self, timeout: int) -> None: + ''' + :param int timeout: + ''' + self._timeout = timeout + self._requester.timeout = timeout + + def get_paper( + self, + paper_id: str, + include_unknown_refs: bool = False, + fields: list = None + ) -> Paper: + '''Paper lookup + + :calls: `GET /paper/{paper_id} `_ + + :param str paper_id: S2PaperId, CorpusId, DOI, ArXivId, MAG, ACL, \ + PMID, PMCID, or URL from: + + - semanticscholar.org + - arxiv.org + - aclweb.org + - acm.org + - biorxiv.org + + :param bool include_unknown_refs: (optional) include non referenced \ + paper. + :param list fields: (optional) list of the fields to be returned. + :returns: paper data + :rtype: :class:`semanticscholar.Paper.Paper` + :raises: ObjectNotFoundException: if Paper ID not found. 
+ ''' + + if not fields: + fields = Paper.FIELDS + + base_url = self.api_url + self.BASE_PATH_GRAPH + url = f'{base_url}/paper/{paper_id}' + + fields = ','.join(fields) + parameters = f'&fields={fields}' + if include_unknown_refs: + warnings.warn( + 'include_unknown_refs parameter is deprecated and will be disabled ' + + 'in the future', DeprecationWarning) + parameters += '&include_unknown_references=true' + + data = self._requester.get_data(url, parameters, self.auth_header) + paper = Paper(data) + + return paper + + def get_papers( + self, + paper_ids: List[str], + fields: list = None + ) -> List[Paper]: + '''Get details for multiple papers at once + + :calls: `POST /paper/batch `_ + + :param str paper_ids: list of IDs (must be <= 1000) - S2PaperId,\ + CorpusId, DOI, ArXivId, MAG, ACL, PMID, PMCID, or URL from: + + - semanticscholar.org + - arxiv.org + - aclweb.org + - acm.org + - biorxiv.org + + :param list fields: (optional) list of the fields to be returned. + :returns: papers data + :rtype: :class:`List` of :class:`semanticscholar.Paper.Paper` + :raises: BadQueryParametersException: if no paper was found. + ''' + + if not fields: + fields = Paper.SEARCH_FIELDS + + base_url = self.api_url + self.BASE_PATH_GRAPH + url = f'{base_url}/paper/batch' + + fields = ','.join(fields) + parameters = f'&fields={fields}' + + payload = { "ids": paper_ids } + + data = self._requester.get_data( + url, parameters, self.auth_header, payload) + papers = [Paper(item) for item in data] + + return papers + + def get_paper_authors( + self, + paper_id: str, + fields: list = None, + limit: int = 1000 + ) -> PaginatedResults: + '''Get details about a paper's authors + + :calls: `POST /paper/{paper_id}/authors \ + `_ + + :param str paper_id: S2PaperId, CorpusId, DOI, ArXivId, MAG, ACL,\ + PMID, PMCID, or URL from: + + - semanticscholar.org + - arxiv.org + - aclweb.org + - acm.org + - biorxiv.org + + :param list fields: (optional) list of the fields to be returned. 
+ :param int limit: (optional) maximum number of results to return\ + (must be <= 1000). + ''' + + if limit < 1 or limit > 1000: + raise ValueError( + 'The limit parameter must be between 1 and 1000 inclusive.') + + if not fields: + fields = [item for item in Author.SEARCH_FIELDS + if not item.startswith('papers')] + + base_url = self.api_url + self.BASE_PATH_GRAPH + url = f'{base_url}/paper/{paper_id}/authors' + + results = PaginatedResults( + requester=self._requester, + data_type=Author, + url=url, + fields=fields, + limit=limit + ) + + return results + + def get_paper_citations( + self, + paper_id: str, + fields: list = None, + limit: int = 1000 + ) -> PaginatedResults: + '''Get details about a paper's citations + + :calls: `POST /paper/{paper_id}/citations \ + `_ + + :param str paper_id: S2PaperId, CorpusId, DOI, ArXivId, MAG, ACL,\ + PMID, PMCID, or URL from: + + - semanticscholar.org + - arxiv.org + - aclweb.org + - acm.org + - biorxiv.org + + :param list fields: (optional) list of the fields to be returned. + :param int limit: (optional) maximum number of results to return\ + (must be <= 1000). 
+ ''' + + if limit < 1 or limit > 1000: + raise ValueError( + 'The limit parameter must be between 1 and 1000 inclusive.') + + if not fields: + fields = BaseReference.FIELDS + Paper.SEARCH_FIELDS + + base_url = self.api_url + self.BASE_PATH_GRAPH + url = f'{base_url}/paper/{paper_id}/citations' + + results = PaginatedResults( + requester=self._requester, + data_type=Citation, + url=url, + fields=fields, + limit=limit + ) + + return results + + def get_paper_references( + self, + paper_id: str, + fields: list = None, + limit: int = 1000 + ) -> PaginatedResults: + '''Get details about a paper's references + + :calls: `POST /paper/{paper_id}/references \ + `_ + + :param str paper_id: S2PaperId, CorpusId, DOI, ArXivId, MAG, ACL,\ + PMID, PMCID, or URL from: + + - semanticscholar.org + - arxiv.org + - aclweb.org + - acm.org + - biorxiv.org + + :param list fields: (optional) list of the fields to be returned. + :param int limit: (optional) maximum number of results to return\ + (must be <= 1000). + ''' + + if limit < 1 or limit > 1000: + raise ValueError( + 'The limit parameter must be between 1 and 1000 inclusive.') + + if not fields: + fields = BaseReference.FIELDS + Paper.SEARCH_FIELDS + + base_url = self.api_url + self.BASE_PATH_GRAPH + url = f'{base_url}/paper/{paper_id}/references' + + results = PaginatedResults( + requester=self._requester, + data_type=Reference, + url=url, + fields=fields, + limit=limit + ) + + return results + + def search_paper( + self, + query: str, + year: str = None, + publication_types: list = None, + open_access_pdf: bool = None, + venue: list = None, + fields_of_study: list = None, + fields: list = None, + limit: int = 100 + ) -> PaginatedResults: + '''Search for papers by keyword + + :calls: `GET /paper/search `_ + + :param str query: plain-text search query string. + :param str year: (optional) restrict results to the given range of \ + publication year. 
+ :param list publication_type: (optional) restrict results to the given \ + publication type list. + :param bool open_access_pdf: (optional) restrict results to papers \ + with public PDFs. + :param list venue: (optional) restrict results to the given venue list. + :param list fields_of_study: (optional) restrict results to given \ + field-of-study list, using the s2FieldsOfStudy paper field. + :param list fields: (optional) list of the fields to be returned. + :param int limit: (optional) maximum number of results to return \ + (must be <= 100). + :returns: query results. + :rtype: :class:`semanticscholar.PaginatedResults.PaginatedResults` + ''' + + if limit < 1 or limit > 100: + raise ValueError( + 'The limit parameter must be between 1 and 100 inclusive.') + + if not fields: + fields = Paper.SEARCH_FIELDS + + base_url = self.api_url + self.BASE_PATH_GRAPH + url = f'{base_url}/paper/search' + + query += f'&year={year}' if year else '' + + if publication_types: + publication_types = ','.join(publication_types) + query += f'&publicationTypes={publication_types}' + + query += '&openAccessPdf' if open_access_pdf else '' + + if venue: + venue = ','.join(venue) + query += f'&venue={venue}' + + if fields_of_study: + fields_of_study = ','.join(fields_of_study) + query += f'&fieldsOfStudy={fields_of_study}' + + results = PaginatedResults( + self._requester, + Paper, + url, + query, + fields, + limit, + self.auth_header + ) + + return results + + def get_author( + self, + author_id: str, + fields: list = None + ) -> Author: + '''Author lookup + + :calls: `GET /author/{author_id} `_ + + :param str author_id: S2AuthorId. + :returns: author data + :rtype: :class:`semanticscholar.Author.Author` + :raises: ObjectNotFoundException: if Author ID not found. 
+ ''' + + if not fields: + fields = Author.FIELDS + + base_url = self.api_url + self.BASE_PATH_GRAPH + url = f'{base_url}/author/{author_id}' + + fields = ','.join(fields) + parameters = f'&fields={fields}' + + data = self._requester.get_data(url, parameters, self.auth_header) + author = Author(data) + + return author + + def get_authors( + self, + author_ids: List[str], + fields: list = None + ) -> List[Author]: + '''Get details for multiple authors at once + + :calls: `POST /author/batch `_ + + :param str author_ids: list of S2AuthorId (must be <= 1000). + :returns: author data + :rtype: :class:`List` of :class:`semanticscholar.Author.Author` + :raises: BadQueryParametersException: if no author was found. + ''' + + if not fields: + fields = Author.SEARCH_FIELDS + + base_url = self.api_url + self.BASE_PATH_GRAPH + url = f'{base_url}/author/batch' + + fields = ','.join(fields) + parameters = f'&fields={fields}' + + payload = { "ids": author_ids } + + data = self._requester.get_data( + url, parameters, self.auth_header, payload) + authors = [Author(item) for item in data] + + return authors + + def get_author_papers( + self, + author_id: str, + fields: list = None, + limit: int = 1000 + ) -> PaginatedResults: + '''Get details about a author's papers + + :calls: `POST /paper/{author_id}/papers \ + `_ + + :param str paper_id: S2PaperId, CorpusId, DOI, ArXivId, MAG, ACL,\ + PMID, PMCID, or URL from: + + - semanticscholar.org + - arxiv.org + - aclweb.org + - acm.org + - biorxiv.org + + :param list fields: (optional) list of the fields to be returned. + :param int limit: (optional) maximum number of results to return\ + (must be <= 1000). 
+ ''' + + if limit < 1 or limit > 1000: + raise ValueError( + 'The limit parameter must be between 1 and 1000 inclusive.') + + if not fields: + fields = Paper.SEARCH_FIELDS + + base_url = self.api_url + self.BASE_PATH_GRAPH + url = f'{base_url}/author/{author_id}/papers' + + results = PaginatedResults( + requester=self._requester, + data_type=Paper, + url=url, + fields=fields, + limit=limit + ) + + return results + + def search_author( + self, + query: str, + fields: list = None, + limit: int = 1000 + ) -> PaginatedResults: + '''Search for authors by name + + :calls: `GET /author/search `_ + + :param str query: plain-text search query string. + :param list fields: (optional) list of the fields to be returned. + :param int limit: (optional) maximum number of results to return \ + (must be <= 1000). + :returns: query results. + :rtype: :class:`semanticscholar.PaginatedResults.PaginatedResults` + ''' + + if limit < 1 or limit > 1000: + raise ValueError( + 'The limit parameter must be between 1 and 1000 inclusive.') + + if not fields: + fields = Author.SEARCH_FIELDS + + base_url = self.api_url + self.BASE_PATH_GRAPH + url = f'{base_url}/author/search' + + results = PaginatedResults( + self._requester, + Author, + url, + query, + fields, + limit, + self.auth_header + ) + + return results + + def get_recommended_papers( + self, + paper_id: str, + fields: list = None, + limit: int = 100 + ) -> List[Paper]: + '''Get recommended papers for a single positive example. + + :calls: `GET /recommendations/v1/papers/forpaper/{paper_id} \ + `_ + + :param str paper_id: S2PaperId, CorpusId, DOI, ArXivId, MAG, ACL,\ + PMID, PMCID, or URL from: + + - semanticscholar.org + - arxiv.org + - aclweb.org + - acm.org + - biorxiv.org + + :param list fields: (optional) list of the fields to be returned. + :param int limit: (optional) maximum number of recommendations to \ + return (must be <= 500). + :returns: list of recommendations. 
+ :rtype: :class:`List` of :class:`semanticscholar.Paper.Paper` + ''' + + if limit < 1 or limit > 500: + raise ValueError( + 'The limit parameter must be between 1 and 500 inclusive.') + + if not fields: + fields = Paper.SEARCH_FIELDS + + base_url = self.api_url + self.BASE_PATH_RECOMMENDATIONS + url = f'{base_url}/papers/forpaper/{paper_id}' + + fields = ','.join(fields) + parameters = f'&fields={fields}&limit={limit}' + + data = self._requester.get_data(url, parameters, self.auth_header) + papers = [Paper(item) for item in data['recommendedPapers']] + + return papers + + def get_recommended_papers_from_lists( + self, + positive_paper_ids: List[str], + negative_paper_ids: List[str] = None, + fields: list = None, + limit: int = 100 + ) -> List[Paper]: + '''Get recommended papers for lists of positive and negative examples. + + :calls: `POST /recommendations/v1/papers/ \ + `_ + + :param list positive_paper_ids: list of paper IDs \ + that the returned papers should be related to. + :param list negative_paper_ids: (optional) list of paper IDs \ + that the returned papers should not be related to. + :param list fields: (optional) list of the fields to be returned. + :param int limit: (optional) maximum number of recommendations to \ + return (must be <= 500). + :returns: list of recommendations. 
+ :rtype: :class:`List` of :class:`semanticscholar.Paper.Paper` + ''' + + if limit < 1 or limit > 500: + raise ValueError( + 'The limit parameter must be between 1 and 500 inclusive.') + + if not fields: + fields = Paper.SEARCH_FIELDS + + base_url = self.api_url + self.BASE_PATH_RECOMMENDATIONS + url = f'{base_url}/papers/' + + fields = ','.join(fields) + parameters = f'&fields={fields}&limit={limit}' + + payload = { + "positivePaperIds": positive_paper_ids, + "negativePaperIds": negative_paper_ids + } + + data = self._requester.get_data( + url, parameters, self.auth_header, payload) + papers = [Paper(item) for item in data['recommendedPapers']] + + return papers diff --git a/tests/test_semanticscholar.py b/tests/test_semanticscholar.py index 25228f8..fe9cfc8 100644 --- a/tests/test_semanticscholar.py +++ b/tests/test_semanticscholar.py @@ -1,346 +1,375 @@ -import json -import unittest -from datetime import datetime - -import vcr -from requests.exceptions import Timeout - -from semanticscholar.Author import Author -from semanticscholar.Citation import Citation -from semanticscholar.Journal import Journal -from semanticscholar.Paper import Paper -from semanticscholar.PublicationVenue import PublicationVenue -from semanticscholar.Reference import Reference -from semanticscholar.SemanticScholar import SemanticScholar -from semanticscholar.SemanticScholarException import ( - BadQueryParametersException, ObjectNotFoundException) -from semanticscholar.Tldr import Tldr - -test_vcr = vcr.VCR( - cassette_library_dir='tests/data', - path_transformer=vcr.VCR.ensure_suffix('.yaml') -) - - -class SemanticScholarTest(unittest.TestCase): - - def setUp(self) -> None: - self.sch = SemanticScholar() - - def test_author(self) -> None: - file = open('tests/data/Author.json', encoding='utf-8') - data = json.loads(file.read()) - item = Author(data) - self.assertEqual(item.affiliations, data['affiliations']) - self.assertEqual(item.aliases, data['aliases']) - 
self.assertEqual(item.authorId, data['authorId']) - self.assertEqual(item.citationCount, data['citationCount']) - self.assertEqual(item.externalIds, data['externalIds']) - self.assertEqual(item.hIndex, data['hIndex']) - self.assertEqual(item.homepage, data['homepage']) - self.assertEqual(item.name, data['name']) - self.assertEqual(item.paperCount, data['paperCount']) - self.assertEqual(str(item.papers), str(data['papers'])) - self.assertEqual(item.url, data['url']) - self.assertEqual(item.raw_data, data) - self.assertEqual(str(item), str(data)) - self.assertEqual(item['name'], data['name']) - self.assertEqual(item.keys(), data.keys()) - file.close() - - def test_citation(self): - file = open('tests/data/Citation.json', encoding='utf-8') - data = json.loads(file.read()) - item = Citation(data) - self.assertEqual(item.contexts, data['contexts']) - self.assertEqual(item.intents, data['intents']) - self.assertEqual(item.isInfluential, data['isInfluential']) - self.assertEqual(str(item.paper), str(data['citingPaper'])) - self.assertEqual(item.raw_data, data) - self.assertEqual(str(item), str(data)) - self.assertEqual(item['contexts'], data['contexts']) - self.assertEqual(item.keys(), data.keys()) - file.close() - - def test_journal(self) -> None: - file = open('tests/data/Paper.json', encoding='utf-8') - data = json.loads(file.read())['journal'] - item = Journal(data) - self.assertEqual(item.name, data['name']) - self.assertEqual(item.pages, data['pages']) - self.assertEqual(item.volume, data['volume']) - self.assertEqual(item.raw_data, data) - self.assertEqual(str(item), data['name']) - self.assertEqual(item['name'], data['name']) - self.assertEqual(item.keys(), data.keys()) - file.close() - - def test_paper(self) -> None: - file = open('tests/data/Paper.json', encoding='utf-8') - data = json.loads(file.read()) - item = Paper(data) - self.assertEqual(item.abstract, data['abstract']) - self.assertEqual(str(item.authors), str(data['authors'])) - 
self.assertEqual(item.citationCount, data['citationCount']) - self.assertEqual(str(item.citations), str(data['citations'])) - self.assertEqual(item.corpusId, data['corpusId']) - self.assertEqual(item.embedding, data['embedding']) - self.assertEqual(item.externalIds, data['externalIds']) - self.assertEqual(item.fieldsOfStudy, data['fieldsOfStudy']) - self.assertEqual(item.influentialCitationCount, - data['influentialCitationCount']) - self.assertEqual(item.isOpenAccess, data['isOpenAccess']) - self.assertEqual(str(item.journal), str(data['journal']['name'])) - self.assertEqual(item.openAccessPdf, data['openAccessPdf']) - self.assertEqual(item.paperId, data['paperId']) - self.assertEqual(item.publicationDate, datetime.strptime( - data['publicationDate'], '%Y-%m-%d')) - self.assertEqual(item.publicationTypes, data['publicationTypes']) - self.assertEqual(item.publicationVenue, data['publicationVenue']) - self.assertEqual(item.referenceCount, data['referenceCount']) - self.assertEqual(str(item.references), str(data['references'])) - self.assertEqual(item.s2FieldsOfStudy, data['s2FieldsOfStudy']) - self.assertEqual(item.title, data['title']) - self.assertEqual(str(item.tldr), data['tldr']['text']) - self.assertEqual(item.url, data['url']) - self.assertEqual(item.venue, data['venue']) - self.assertEqual(item.year, data['year']) - self.assertEqual(item.raw_data, data) - self.assertEqual(str(item), str(data)) - self.assertEqual(item['title'], data['title']) - self.assertEqual(item.keys(), data.keys()) - file.close() - - def test_pubication_venue(self): - file = open('tests/data/Paper.json', encoding='utf-8') - data = json.loads(file.read())['citations'][0]['publicationVenue'] - item = PublicationVenue(data) - self.assertEqual(item.alternate_names, data['alternate_names']) - self.assertEqual(item.alternate_urls, data['alternate_urls']) - self.assertEqual(item.id, data['id']) - self.assertEqual(item.issn, data['issn']) - self.assertEqual(item.name, data['name']) - 
self.assertEqual(item.type, data['type']) - self.assertEqual(item.url, data['url']) - self.assertEqual(item.raw_data, data) - self.assertEqual(str(item), str(data)) - self.assertEqual(item['name'], data['name']) - self.assertEqual(item.keys(), data.keys()) - file.close() - - def test_reference(self): - file = open('tests/data/Reference.json', encoding='utf-8') - data = json.loads(file.read()) - item = Reference(data) - self.assertEqual(item.contexts, data['contexts']) - self.assertEqual(item.intents, data['intents']) - self.assertEqual(item.isInfluential, data['isInfluential']) - self.assertEqual(str(item.paper), str(data['citedPaper'])) - self.assertEqual(item.raw_data, data) - self.assertEqual(str(item), str(data)) - self.assertEqual(item['contexts'], data['contexts']) - self.assertEqual(item.keys(), data.keys()) - file.close() - - def test_tldr(self) -> None: - file = open('tests/data/Paper.json', encoding='utf-8') - data = json.loads(file.read())['tldr'] - item = Tldr(data) - self.assertEqual(item.model, data['model']) - self.assertEqual(item.text, data['text']) - self.assertEqual(item.raw_data, data) - self.assertEqual(str(item), data['text']) - self.assertEqual(item['model'], data['model']) - self.assertEqual(item.keys(), data.keys()) - file.close() - - @test_vcr.use_cassette - def test_get_paper(self): - data = self.sch.get_paper('10.1093/mind/lix.236.433') - self.assertEqual(data.title, - 'Computing Machinery and Intelligence') - self.assertEqual(data.raw_data['title'], - 'Computing Machinery and Intelligence') - - @test_vcr.use_cassette - def test_get_papers(self): - list_of_paper_ids = [ - 'CorpusId:470667', - '10.2139/ssrn.2250500', - '0f40b1f08821e22e859c6050916cec3667778613'] - data = self.sch.get_papers(list_of_paper_ids) - for item in data: - with self.subTest(subtest=item.paperId): - self.assertIn( - 'E. 
Duflo', [author.name for author in item.authors]) - - @test_vcr.use_cassette - def test_get_paper_authors(self): - data = self.sch.get_paper_authors('CorpusID:54599684') - self.assertEqual(data.offset, 0) - self.assertEqual(data.next, 1000) - self.assertEqual(len([item for item in data]), 2870) - self.assertEqual(data[0].name, 'G. Aad') - - @test_vcr.use_cassette - def test_get_paper_citations(self): - data = self.sch.get_paper_citations('CorpusID:49313245') - self.assertEqual(data.offset, 0) - self.assertEqual(data.next, 1000) - self.assertEqual(len([item.paper.title for item in data]), 4563) - self.assertEqual( - data[0].paper.title, 'Learning to Throw With a Handful of Samples ' - 'Using Decision Transformers') - - @test_vcr.use_cassette - def test_get_paper_references(self): - data = self.sch.get_paper_references('CorpusID:1033682') - self.assertEqual(data.offset, 0) - self.assertEqual(data.next, 0) - self.assertEqual(len(data), 35) - self.assertEqual( - data[0].paper.title, 'Neural Variational Inference and Learning ' - 'in Belief Networks') - - @test_vcr.use_cassette - def test_timeout(self): - self.sch.timeout = 0.01 - self.assertEqual(self.sch.timeout, 0.01) - self.assertRaises(Timeout, - self.sch.get_paper, - '10.1093/mind/lix.236.433') - - @test_vcr.use_cassette - def test_get_author(self): - data = self.sch.get_author(2262347) - self.assertEqual(data.name, 'A. Turing') - - @test_vcr.use_cassette - def test_get_authors(self): - list_of_author_ids = ['3234559', '1726629', '1711844'] - data = self.sch.get_authors(list_of_author_ids) - list_of_author_names = ['E. Dijkstra', 'D. Parnas', 'I. 
Sommerville'] - self.assertCountEqual( - [item.name for item in data], list_of_author_names) - - @test_vcr.use_cassette - def test_get_author_papers(self): - data = self.sch.get_author_papers(1723755, limit=100) - self.assertEqual(data.offset, 0) - self.assertEqual(data.next, 100) - self.assertEqual(len([item for item in data]), 925) - self.assertEqual(data[0].title, 'Genetic heterogeneity and ' - 'tissue-specific patterns of tumors with multiple ' - 'PIK3CA mutations.') - - @test_vcr.use_cassette - def test_not_found(self): - methods = [self.sch.get_paper, self.sch.get_author] - for method in methods: - with self.subTest(subtest=method.__name__): - self.assertRaises(ObjectNotFoundException, method, 0) - - @test_vcr.use_cassette - def test_bad_query_parameters(self): - self.assertRaises(BadQueryParametersException, - self.sch.get_paper, - '10.1093/mind/lix.236.433', - fields=['unknown']) - - @test_vcr.use_cassette - def test_search_paper(self): - data = self.sch.search_paper('turing') - self.assertGreater(data.total, 0) - self.assertEqual(data.offset, 0) - self.assertEqual(data.next, 100) - self.assertEqual(len(data.items), 100) - self.assertEqual( - data.raw_data[0]['title'], - 'Quantum theory, the Church–Turing principle and the universal ' - 'quantum computer') - - @test_vcr.use_cassette - def test_search_paper_next_page(self): - data = self.sch.search_paper('turing') - data.next_page() - self.assertGreater(len(data), 100) - - @test_vcr.use_cassette - def test_search_paper_traversing_results(self): - data = self.sch.search_paper('turing') - all_results = [item.title for item in data] - self.assertRaises(BadQueryParametersException, data.next_page) - self.assertEqual(len(all_results), len(data.items)) - - @test_vcr.use_cassette - def test_search_paper_fields_of_study(self): - data = self.sch.search_paper('turing', fields_of_study=['Mathematics']) - self.assertEqual(data[0].s2FieldsOfStudy[0]['category'], 'Mathematics') - - @test_vcr.use_cassette - def 
test_search_paper_year(self): - data = self.sch.search_paper('turing', year=1936) - self.assertEqual(data[0].year, 1936) - - @test_vcr.use_cassette - def test_search_paper_year_range(self): - data = self.sch.search_paper('turing', year='1936-1937') - # assert that all results are in the range - self.assertTrue(all([1936 <= item.year <= 1937 for item in data])) - - @test_vcr.use_cassette - def test_search_paper_publication_types(self): - data = self.sch.search_paper( - 'turing', publication_types=['JournalArticle']) - self.assertTrue('JournalArticle' in data[0].publicationTypes) - data = self.sch.search_paper( - 'turing', publication_types=['Book', 'Conference']) - self.assertTrue( - 'Book' in data[0].publicationTypes or - 'Conference' in data[0].publicationTypes) - - @test_vcr.use_cassette - def test_search_paper_venue(self): - data = self.sch.search_paper('turing', venue=['ArXiv']) - self.assertEqual(data[0].venue, 'ArXiv') - - @test_vcr.use_cassette - def test_search_paper_open_access_pdf(self): - data = self.sch.search_paper('turing', open_access_pdf=True) - self.assertTrue(data[0].openAccessPdf) - - @test_vcr.use_cassette - def test_search_author(self): - data = self.sch.search_author('turing') - self.assertGreater(data.total, 0) - self.assertEqual(data.next, 0) - - @test_vcr.use_cassette - def test_limit_value_exceeded(self): - test_cases = [ - (self.sch.get_paper_authors, '10.1093/mind/lix.236.433', 1001, - 'The limit parameter must be between 1 and 1000 inclusive.'), - (self.sch.get_paper_citations, '10.1093/mind/lix.236.433', 1001, - 'The limit parameter must be between 1 and 1000 inclusive.'), - (self.sch.get_paper_references, '10.1093/mind/lix.236.433', 1001, - 'The limit parameter must be between 1 and 1000 inclusive.'), - (self.sch.get_author_papers, 1723755, 1001, - 'The limit parameter must be between 1 and 1000 inclusive.'), - (self.sch.search_author, 'turing', 1001, - 'The limit parameter must be between 1 and 1000 inclusive.'), - 
(self.sch.search_paper, 'turing', 101, - 'The limit parameter must be between 1 and 100 inclusive.') - ] - for method, query, upper_limit, error_message in test_cases: - with self.subTest(method=method.__name__, limit=upper_limit): - with self.assertRaises(ValueError) as context: - method(query, limit=upper_limit) - self.assertEqual(str(context.exception), error_message) - with self.subTest(method=method.__name__, limit=0): - with self.assertRaises(ValueError) as context: - method(query, limit=0) - self.assertEqual(str(context.exception), error_message) - - -if __name__ == '__main__': - unittest.main() +import json +import unittest +from datetime import datetime + +import vcr +from requests.exceptions import Timeout + +from semanticscholar.Author import Author +from semanticscholar.Citation import Citation +from semanticscholar.Journal import Journal +from semanticscholar.Paper import Paper +from semanticscholar.PublicationVenue import PublicationVenue +from semanticscholar.Reference import Reference +from semanticscholar.SemanticScholar import SemanticScholar +from semanticscholar.SemanticScholarException import ( + BadQueryParametersException, ObjectNotFoundException) +from semanticscholar.Tldr import Tldr + +test_vcr = vcr.VCR( + cassette_library_dir='tests/data', + path_transformer=vcr.VCR.ensure_suffix('.yaml') +) + + +class SemanticScholarTest(unittest.TestCase): + + def setUp(self) -> None: + self.sch = SemanticScholar() + + def test_author(self) -> None: + file = open('tests/data/Author.json', encoding='utf-8') + data = json.loads(file.read()) + item = Author(data) + self.assertEqual(item.affiliations, data['affiliations']) + self.assertEqual(item.aliases, data['aliases']) + self.assertEqual(item.authorId, data['authorId']) + self.assertEqual(item.citationCount, data['citationCount']) + self.assertEqual(item.externalIds, data['externalIds']) + self.assertEqual(item.hIndex, data['hIndex']) + self.assertEqual(item.homepage, data['homepage']) + 
self.assertEqual(item.name, data['name']) + self.assertEqual(item.paperCount, data['paperCount']) + self.assertEqual(str(item.papers), str(data['papers'])) + self.assertEqual(item.url, data['url']) + self.assertEqual(item.raw_data, data) + self.assertEqual(str(item), str(data)) + self.assertEqual(item['name'], data['name']) + self.assertEqual(item.keys(), data.keys()) + file.close() + + def test_citation(self): + file = open('tests/data/Citation.json', encoding='utf-8') + data = json.loads(file.read()) + item = Citation(data) + self.assertEqual(item.contexts, data['contexts']) + self.assertEqual(item.intents, data['intents']) + self.assertEqual(item.isInfluential, data['isInfluential']) + self.assertEqual(str(item.paper), str(data['citingPaper'])) + self.assertEqual(item.raw_data, data) + self.assertEqual(str(item), str(data)) + self.assertEqual(item['contexts'], data['contexts']) + self.assertEqual(item.keys(), data.keys()) + file.close() + + def test_journal(self) -> None: + file = open('tests/data/Paper.json', encoding='utf-8') + data = json.loads(file.read())['journal'] + item = Journal(data) + self.assertEqual(item.name, data['name']) + self.assertEqual(item.pages, data['pages']) + self.assertEqual(item.volume, data['volume']) + self.assertEqual(item.raw_data, data) + self.assertEqual(str(item), data['name']) + self.assertEqual(item['name'], data['name']) + self.assertEqual(item.keys(), data.keys()) + file.close() + + def test_paper(self) -> None: + file = open('tests/data/Paper.json', encoding='utf-8') + data = json.loads(file.read()) + item = Paper(data) + self.assertEqual(item.abstract, data['abstract']) + self.assertEqual(str(item.authors), str(data['authors'])) + self.assertEqual(item.citationCount, data['citationCount']) + self.assertEqual(str(item.citations), str(data['citations'])) + self.assertEqual(item.corpusId, data['corpusId']) + self.assertEqual(item.embedding, data['embedding']) + self.assertEqual(item.externalIds, data['externalIds']) + 
self.assertEqual(item.fieldsOfStudy, data['fieldsOfStudy']) + self.assertEqual(item.influentialCitationCount, + data['influentialCitationCount']) + self.assertEqual(item.isOpenAccess, data['isOpenAccess']) + self.assertEqual(str(item.journal), str(data['journal']['name'])) + self.assertEqual(item.openAccessPdf, data['openAccessPdf']) + self.assertEqual(item.paperId, data['paperId']) + self.assertEqual(item.publicationDate, datetime.strptime( + data['publicationDate'], '%Y-%m-%d')) + self.assertEqual(item.publicationTypes, data['publicationTypes']) + self.assertEqual(item.publicationVenue, data['publicationVenue']) + self.assertEqual(item.referenceCount, data['referenceCount']) + self.assertEqual(str(item.references), str(data['references'])) + self.assertEqual(item.s2FieldsOfStudy, data['s2FieldsOfStudy']) + self.assertEqual(item.title, data['title']) + self.assertEqual(str(item.tldr), data['tldr']['text']) + self.assertEqual(item.url, data['url']) + self.assertEqual(item.venue, data['venue']) + self.assertEqual(item.year, data['year']) + self.assertEqual(item.raw_data, data) + self.assertEqual(str(item), str(data)) + self.assertEqual(item['title'], data['title']) + self.assertEqual(item.keys(), data.keys()) + file.close() + + def test_pubication_venue(self): + file = open('tests/data/Paper.json', encoding='utf-8') + data = json.loads(file.read())['citations'][0]['publicationVenue'] + item = PublicationVenue(data) + self.assertEqual(item.alternate_names, data['alternate_names']) + self.assertEqual(item.alternate_urls, data['alternate_urls']) + self.assertEqual(item.id, data['id']) + self.assertEqual(item.issn, data['issn']) + self.assertEqual(item.name, data['name']) + self.assertEqual(item.type, data['type']) + self.assertEqual(item.url, data['url']) + self.assertEqual(item.raw_data, data) + self.assertEqual(str(item), str(data)) + self.assertEqual(item['name'], data['name']) + self.assertEqual(item.keys(), data.keys()) + file.close() + + def test_reference(self): 
+ file = open('tests/data/Reference.json', encoding='utf-8') + data = json.loads(file.read()) + item = Reference(data) + self.assertEqual(item.contexts, data['contexts']) + self.assertEqual(item.intents, data['intents']) + self.assertEqual(item.isInfluential, data['isInfluential']) + self.assertEqual(str(item.paper), str(data['citedPaper'])) + self.assertEqual(item.raw_data, data) + self.assertEqual(str(item), str(data)) + self.assertEqual(item['contexts'], data['contexts']) + self.assertEqual(item.keys(), data.keys()) + file.close() + + def test_tldr(self) -> None: + file = open('tests/data/Paper.json', encoding='utf-8') + data = json.loads(file.read())['tldr'] + item = Tldr(data) + self.assertEqual(item.model, data['model']) + self.assertEqual(item.text, data['text']) + self.assertEqual(item.raw_data, data) + self.assertEqual(str(item), data['text']) + self.assertEqual(item['model'], data['model']) + self.assertEqual(item.keys(), data.keys()) + file.close() + + @test_vcr.use_cassette + def test_get_paper(self): + data = self.sch.get_paper('10.1093/mind/lix.236.433') + self.assertEqual(data.title, + 'Computing Machinery and Intelligence') + self.assertEqual(data.raw_data['title'], + 'Computing Machinery and Intelligence') + + @test_vcr.use_cassette + def test_get_papers(self): + list_of_paper_ids = [ + 'CorpusId:470667', + '10.2139/ssrn.2250500', + '0f40b1f08821e22e859c6050916cec3667778613'] + data = self.sch.get_papers(list_of_paper_ids) + for item in data: + with self.subTest(subtest=item.paperId): + self.assertIn( + 'E. Duflo', [author.name for author in item.authors]) + + @test_vcr.use_cassette + def test_get_paper_authors(self): + data = self.sch.get_paper_authors('CorpusID:54599684') + self.assertEqual(data.offset, 0) + self.assertEqual(data.next, 1000) + self.assertEqual(len([item for item in data]), 2870) + self.assertEqual(data[0].name, 'G. 
Aad') + + @test_vcr.use_cassette + def test_get_paper_citations(self): + data = self.sch.get_paper_citations('CorpusID:49313245') + self.assertEqual(data.offset, 0) + self.assertEqual(data.next, 1000) + self.assertEqual(len([item.paper.title for item in data]), 4563) + self.assertEqual( + data[0].paper.title, 'Learning to Throw With a Handful of Samples ' + 'Using Decision Transformers') + + @test_vcr.use_cassette + def test_get_paper_references(self): + data = self.sch.get_paper_references('CorpusID:1033682') + self.assertEqual(data.offset, 0) + self.assertEqual(data.next, 0) + self.assertEqual(len(data), 35) + self.assertEqual( + data[0].paper.title, 'Neural Variational Inference and Learning ' + 'in Belief Networks') + + @test_vcr.use_cassette + def test_timeout(self): + self.sch.timeout = 0.01 + self.assertEqual(self.sch.timeout, 0.01) + self.assertRaises(Timeout, + self.sch.get_paper, + '10.1093/mind/lix.236.433') + + @test_vcr.use_cassette + def test_get_author(self): + data = self.sch.get_author(2262347) + self.assertEqual(data.name, 'A. Turing') + + @test_vcr.use_cassette + def test_get_authors(self): + list_of_author_ids = ['3234559', '1726629', '1711844'] + data = self.sch.get_authors(list_of_author_ids) + list_of_author_names = ['E. Dijkstra', 'D. Parnas', 'I. 
Sommerville'] + self.assertCountEqual( + [item.name for item in data], list_of_author_names) + + @test_vcr.use_cassette + def test_get_author_papers(self): + data = self.sch.get_author_papers(1723755, limit=100) + self.assertEqual(data.offset, 0) + self.assertEqual(data.next, 100) + self.assertEqual(len([item for item in data]), 925) + self.assertEqual(data[0].title, 'Genetic heterogeneity and ' + 'tissue-specific patterns of tumors with multiple ' + 'PIK3CA mutations.') + + @test_vcr.use_cassette + def test_not_found(self): + methods = [self.sch.get_paper, self.sch.get_author] + for method in methods: + with self.subTest(subtest=method.__name__): + self.assertRaises(ObjectNotFoundException, method, 0) + + @test_vcr.use_cassette + def test_bad_query_parameters(self): + self.assertRaises(BadQueryParametersException, + self.sch.get_paper, + '10.1093/mind/lix.236.433', + fields=['unknown']) + + @test_vcr.use_cassette + def test_search_paper(self): + data = self.sch.search_paper('turing') + self.assertGreater(data.total, 0) + self.assertEqual(data.offset, 0) + self.assertEqual(data.next, 100) + self.assertEqual(len(data.items), 100) + self.assertEqual( + data.raw_data[0]['title'], + 'Quantum theory, the Church–Turing principle and the universal ' + 'quantum computer') + + @test_vcr.use_cassette + def test_search_paper_next_page(self): + data = self.sch.search_paper('turing') + data.next_page() + self.assertGreater(len(data), 100) + + @test_vcr.use_cassette + def test_search_paper_traversing_results(self): + data = self.sch.search_paper('turing') + all_results = [item.title for item in data] + self.assertRaises(BadQueryParametersException, data.next_page) + self.assertEqual(len(all_results), len(data.items)) + + @test_vcr.use_cassette + def test_search_paper_fields_of_study(self): + data = self.sch.search_paper('turing', fields_of_study=['Mathematics']) + self.assertEqual(data[0].s2FieldsOfStudy[0]['category'], 'Mathematics') + + @test_vcr.use_cassette + def 
test_search_paper_year(self): + data = self.sch.search_paper('turing', year=1936) + self.assertEqual(data[0].year, 1936) + + @test_vcr.use_cassette + def test_search_paper_year_range(self): + data = self.sch.search_paper('turing', year='1936-1937') + # assert that all results are in the range + self.assertTrue(all([1936 <= item.year <= 1937 for item in data])) + + @test_vcr.use_cassette + def test_search_paper_publication_types(self): + data = self.sch.search_paper( + 'turing', publication_types=['JournalArticle']) + self.assertTrue('JournalArticle' in data[0].publicationTypes) + data = self.sch.search_paper( + 'turing', publication_types=['Book', 'Conference']) + self.assertTrue( + 'Book' in data[0].publicationTypes or + 'Conference' in data[0].publicationTypes) + + @test_vcr.use_cassette + def test_search_paper_venue(self): + data = self.sch.search_paper('turing', venue=['ArXiv']) + self.assertEqual(data[0].venue, 'ArXiv') + + @test_vcr.use_cassette + def test_search_paper_open_access_pdf(self): + data = self.sch.search_paper('turing', open_access_pdf=True) + self.assertTrue(data[0].openAccessPdf) + + @test_vcr.use_cassette + def test_search_author(self): + data = self.sch.search_author('turing') + self.assertGreater(data.total, 0) + self.assertEqual(data.next, 0) + + @test_vcr.use_cassette + def test_get_recommended_papers(self): + data = self.sch.get_recommended_papers('10.1145/3544585.3544600') + self.assertEqual(len(data), 100) + + @test_vcr.use_cassette + def test_get_recommended_papers_from_lists(self): + data = self.sch.get_recommended_papers_from_lists( + ['10.1145/3544585.3544600'], ['10.1145/301250.301271']) + self.assertEqual(len(data), 100) + + @test_vcr.use_cassette + def test_get_recommended_papers_from_lists_positive_only(self): + data = self.sch.get_recommended_papers_from_lists( + ['10.1145/3544585.3544600', '10.1145/301250.301271']) + self.assertEqual(len(data), 100) + + @test_vcr.use_cassette + def 
test_get_recommended_papers_from_lists_negative_only(self): + self.assertRaises(BadQueryParametersException, + self.sch.get_recommended_papers_from_lists, + [], + ['10.1145/3544585.3544600']) + + @test_vcr.use_cassette + def test_limit_value_exceeded(self): + test_cases = [ + (self.sch.get_paper_authors, '10.1093/mind/lix.236.433', 1001, + 'The limit parameter must be between 1 and 1000 inclusive.'), + (self.sch.get_paper_citations, '10.1093/mind/lix.236.433', 1001, + 'The limit parameter must be between 1 and 1000 inclusive.'), + (self.sch.get_paper_references, '10.1093/mind/lix.236.433', 1001, + 'The limit parameter must be between 1 and 1000 inclusive.'), + (self.sch.get_author_papers, 1723755, 1001, + 'The limit parameter must be between 1 and 1000 inclusive.'), + (self.sch.search_author, 'turing', 1001, + 'The limit parameter must be between 1 and 1000 inclusive.'), + (self.sch.search_paper, 'turing', 101, + 'The limit parameter must be between 1 and 100 inclusive.'), + (self.sch.get_recommended_papers, '10.1145/3544585.3544600', 501, + 'The limit parameter must be between 1 and 500 inclusive.'), + (self.sch.get_recommended_papers_from_lists, + ['10.1145/3544585.3544600'], 501, + 'The limit parameter must be between 1 and 500 inclusive.'), + ] + for method, query, upper_limit, error_message in test_cases: + with self.subTest(method=method.__name__, limit=upper_limit): + with self.assertRaises(ValueError) as context: + method(query, limit=upper_limit) + self.assertEqual(str(context.exception), error_message) + with self.subTest(method=method.__name__, limit=0): + with self.assertRaises(ValueError) as context: + method(query, limit=0) + self.assertEqual(str(context.exception), error_message) + + +if __name__ == '__main__': + unittest.main()