Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 34 additions & 17 deletions dspace_rest_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
"""
import json
import logging
import os
from uuid import UUID

import requests
from requests import Request
import os
from uuid import UUID

from .models import *

__all__ = ['DSpaceClient']
Expand All @@ -37,9 +38,13 @@ def parse_json(response):
"""
response_json = None
try:
response_json = response.json()
if response is not None:
response_json = response.json()
except ValueError as err:
_logger.error(f'Error parsing response JSON: {err}. Body text: {response.text}')
if response is not None:
_logger.error(f'Error parsing response JSON: {err}. Body text: {response.text}')
else:
_logger.error(f'Error parsing response JSON: {err}. Response is None')
return response_json


Expand Down Expand Up @@ -73,6 +78,8 @@ class DSpaceClient:
if 'USER_AGENT' in os.environ:
USER_AGENT = os.environ['USER_AGENT']
verbose = False
ITER_PAGE_SIZE = 20
PROXY_DICT = dict(http=os.environ["PROXY_URL"],https=os.environ["PROXY_URL"]) if "PROXY_URL" in os.environ else dict()

# Simple enum for patch operation types
class PatchOperation:
Expand All @@ -82,7 +89,7 @@ class PatchOperation:
MOVE = 'move'

def __init__(self, api_endpoint=API_ENDPOINT, username=USERNAME, password=PASSWORD, solr_endpoint=SOLR_ENDPOINT,
solr_auth=SOLR_AUTH, fake_user_agent=False):
solr_auth=SOLR_AUTH, fake_user_agent=False, proxies=PROXY_DICT):
"""
Accept optional API endpoint, username, password arguments using the OS environment variables as defaults
:param api_endpoint: base path to DSpace REST API, eg. http://localhost:8080/server/api
Expand All @@ -95,6 +102,7 @@ def __init__(self, api_endpoint=API_ENDPOINT, username=USERNAME, password=PASSWO
self.USERNAME = username
self.PASSWORD = password
self.SOLR_ENDPOINT = solr_endpoint
self.proxies = proxies
self.solr = None
self._last_err = None
try:
Expand Down Expand Up @@ -128,7 +136,8 @@ def authenticate(self, retry=False):
# Set headers for requests made during authentication
# Get and update CSRF token
r = self.session.post(self.LOGIN_URL, data={'user': self.USERNAME, 'password': self.PASSWORD},
headers=self.auth_request_headers)
headers=self.auth_request_headers,
proxies=self.proxies)
self.update_token(r)

if r.status_code == 403:
Expand All @@ -154,7 +163,8 @@ def authenticate(self, retry=False):
self.session.headers.update({'Authorization': r.headers.get('Authorization')})

# Get and check authentication status
r = self.session.get(f'{self.API_ENDPOINT}/authn/status', headers=self.request_headers)
r = self.session.get(f'{self.API_ENDPOINT}/authn/status', headers=self.request_headers,
proxies=self.proxies)
if r.status_code == 200:
r_json = parse_json(r)
if 'authenticated' in r_json and r_json['authenticated'] is True:
Expand Down Expand Up @@ -203,7 +213,8 @@ def api_get(self, url, params=None, data=None, headers=None):
self._last_err = None
if headers is None:
headers = self.request_headers
r = self.session.get(url, params=params, data=data, headers=headers)
r = self.session.get(url, params=params, data=data, headers=headers,
proxies=self.proxies)
self.update_token(r)
return r

Expand All @@ -218,7 +229,8 @@ def api_post(self, url, params, json, retry=False):
@return: Response from API
"""
self._last_err = None
r = self.session.post(url, json=json, params=params, headers=self.request_headers)
r = self.session.post(url, json=json, params=params, headers=self.request_headers,
proxies=self.proxies)
self.update_token(r)

if r.status_code == 403:
Expand Down Expand Up @@ -262,7 +274,8 @@ def api_post_uri(self, url, params, uri_list, retry=False):
@return: Response from API
"""
self._last_err = None
r = self.session.post(url, data=uri_list, params=params, headers=self.list_request_headers)
r = self.session.post(url, data=uri_list, params=params, headers=self.list_request_headers,
proxies=self.proxies)
self.update_token(r)

if r.status_code == 403:
Expand Down Expand Up @@ -291,7 +304,8 @@ def api_put(self, url, params, json, retry=False):
@return: Response from API
"""
self._last_err = None
r = self.session.put(url, params=params, json=json, headers=self.request_headers)
r = self.session.put(url, params=params, json=json, headers=self.request_headers,
proxies=self.proxies)
self.update_token(r)

if r.status_code == 403:
Expand Down Expand Up @@ -321,7 +335,8 @@ def api_delete(self, url, params, retry=False):
@return: Response from API
"""
self._last_err = None
r = self.session.delete(url, params=params, headers=self.request_headers)
r = self.session.delete(url, params=params, headers=self.request_headers,
proxies=self.proxies)
self.update_token(r)

if r.status_code == 403:
Expand All @@ -341,12 +356,13 @@ def api_delete(self, url, params, retry=False):

return r

def api_patch(self, url, operation, path, value, retry=False):
def api_patch(self, url, operation, path, value, params=None, retry=False):
"""
@param url: DSpace REST API URL
@param operation: 'add', 'remove', 'replace', or 'move' (see PatchOperation enumeration)
@param path: path to perform operation - eg, metadata, withdrawn, etc.
@param value: new value for add or replace operations, or 'original' path for move operations
@param params: Optional parameters
@param retry: Has this method already been retried? Used if we need to refresh XSRF.
@return:
@see https://github.com/DSpace/RestContract/blob/main/metadata-patch.md
Expand Down Expand Up @@ -377,7 +393,8 @@ def api_patch(self, url, operation, path, value, retry=False):

# set headers
# perform patch request
r = self.session.patch(url, json=[data], headers=self.request_headers)
r = self.session.patch(url, json=[data], params=params, headers=self.request_headers,
proxies=self.proxies)
self.update_token(r)

if r.status_code == 403:
Expand All @@ -392,7 +409,7 @@ def api_patch(self, url, operation, path, value, retry=False):
_logger.warning(f'Too many retries updating token: {r.status_code}: {r.text}')
else:
_logger.debug("Retrying request with updated CSRF token")
return self.api_patch(url, operation, path, value, True)
return self.api_patch(url, operation, path, value, params, True)
elif r.status_code == 200:
# 200 Success
_logger.info(f'successful patch update to {r.json()["type"]} {r.json()["id"]}')
Expand Down Expand Up @@ -727,7 +744,7 @@ def create_bitstream(self, bundle=None, name=None, path=None, mime=None, metadat
h.update({'Content-Encoding': 'gzip', 'User-Agent': self.USER_AGENT})
req = Request('POST', url, data=payload, headers=h, files=files)
prepared_req = self.session.prepare_request(req)
r = self.session.send(prepared_req)
r = self.session.send(prepared_req, proxies=self.proxies)
if 'DSPACE-XSRF-TOKEN' in r.headers:
t = r.headers['DSPACE-XSRF-TOKEN']
_logger.debug('Updating token to ' + t)
Expand Down Expand Up @@ -922,7 +939,7 @@ def get_items(self, page=0, size=20):
r = self.api_get(url, params=params)
r_json = parse_json(response=r)
if '_embedded' in r_json:
if 'collections' in r_json['_embedded']:
if 'items' in r_json['_embedded']:
for item_resource in r_json['_embedded']['items']:
items.append(Item(item_resource))
elif 'uuid' in r_json:
Expand Down
22 changes: 13 additions & 9 deletions dspace_rest_client/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ def __init__(self, api_resource=None):
self.links = api_resource['_links'].copy()
else:
self.links = {'self': {'href': None}}
if '_embedded' in api_resource:
self.embedded = api_resource['_embedded'].copy()
else:
self.embedded = {}

class AddressableHALResource(HALResource):
id = None
Expand Down Expand Up @@ -421,12 +425,12 @@ class User(SimpleDSpaceObject):
Extends DSpaceObject to implement specific attributes and methods for users (aka. EPersons)
"""
type = 'user'
name = None,
netid = None,
lastActive = None,
canLogIn = False,
email = None,
requireCertificate = False,
name = None
netid = None
lastActive = None
canLogIn = False
email = None
requireCertificate = False
selfRegistered = False

def __init__(self, api_resource=None):
Expand Down Expand Up @@ -473,11 +477,11 @@ def __init__(self, api_resource):
if 'lastModified' in api_resource:
self.lastModified = api_resource['lastModified']
if 'step' in api_resource:
self.step = api_resource['lastModified']
self.step = api_resource['step']
if 'sections' in api_resource:
self.sections = api_resource['sections'].copy()
if 'type' in api_resource:
self.lastModified = api_resource['lastModified']
self.type = api_resource['type']

def as_dict(self):
parent_dict = super(InProgressSubmission, self).as_dict()
Expand Down Expand Up @@ -508,7 +512,7 @@ def __init__(self, api_resource):
if 'label' in api_resource:
self.label = api_resource['label']
if 'type' in api_resource:
self.label = api_resource['type']
self.type = api_resource['type']

class RelationshipType(AddressableHALResource):
"""
Expand Down