Skip to content

Commit

Permalink
replace requests get function to allow timeout #95
Browse files Browse the repository at this point in the history
  • Loading branch information
Herfort committed Mar 5, 2019
1 parent cb2a092 commit 877dcfb
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 0 deletions.
22 changes: 22 additions & 0 deletions mapswipe_workers/basic/BaseFunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,19 @@
########################################################################################################################
# INIT #
########################################################################################################################
class myRequestsSession(requests.Session):
"""
The class to replace the get function to allow a timeout parameter to be set.
"""

def __init__(self):
super(myRequestsSession, self).__init__()

def get(self, request_ref, headers, timeout=30):
print('Using customized get request with a timeout of 30 seconds.')
return super(myRequestsSession, self).get(request_ref, headers=headers, timeout=timeout)


def get_environment(modus='development'):
"""
The function to get the firebase and postgres configuration
Expand Down Expand Up @@ -145,6 +158,7 @@ def get_projects(firebase, postgres, filter='all'):
projects_list = []

fb_db = firebase.database()
fb_db.requests.get = myRequestsSession().get
all_projects = fb_db.child("projects").get().val()

# return empty list if there are no projects in firebase
Expand Down Expand Up @@ -230,6 +244,7 @@ def project_exists_firebase(project_id, firebase):
"""

fb_db = firebase.database()
fb_db.requests.get = myRequestsSession().get
project_data = fb_db.child("projects").child(project_id).get().val()

if not project_data:
Expand Down Expand Up @@ -320,6 +335,7 @@ def get_new_imports(firebase):
"""

fb_db = firebase.database()
fb_db.requests.get = myRequestsSession().get
all_imports = fb_db.child("imports").get().val()

new_imports = {}
Expand Down Expand Up @@ -554,6 +570,7 @@ def get_results_from_firebase(firebase):
"""

fb_db = firebase.database()
fb_db.requests.get = myRequestsSession().get
results = fb_db.child("results").get().val()
return results

Expand Down Expand Up @@ -587,6 +604,7 @@ def delete_firebase_results(firebase, all_results):
"""

fb_db = firebase.database()
fb_db.requests.get = myRequestsSession().get
# we will use multilocation update to deete the entries
# therefore we crate an dict with the items we want to delete
data = {}
Expand Down Expand Up @@ -781,6 +799,7 @@ def run_transfer_results(modus):
logging.warning('ALL - run_transfer_results - removed "results.json" file')

fb_db = firebase.database()
fb_db.requests.get = myRequestsSession().get

# this tries to set the max pool connections to 100
adapter = requests.adapters.HTTPAdapter(max_retries=5, pool_connections=100, pool_maxsize=100)
Expand Down Expand Up @@ -831,6 +850,7 @@ def export_all_projects(firebase):
os.mkdir(DATA_PATH)

fb_db = firebase.database()
fb_db.requests.get = myRequestsSession().get
all_projects = fb_db.child("projects").get().val()

if not all_projects:
Expand Down Expand Up @@ -872,6 +892,7 @@ def export_users_and_stats(firebase):
os.mkdir(DATA_PATH)

fb_db = firebase.database()
fb_db.requests.get = myRequestsSession().get
all_users = fb_db.child("users").get().val()

if not all_users:
Expand Down Expand Up @@ -1015,6 +1036,7 @@ def delete_project_firebase(project_id, import_key, firebase):
"""

fb_db = firebase.database()
fb_db.requests.get = myRequestsSession().get
# we create this element to do a multi location update
data = {
"projects/{}/".format(project_id): None,
Expand Down
27 changes: 27 additions & 0 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import requests
from mapswipe_workers.basic import BaseFunctions


class myRequestsSession(requests.Session):

def __init__(self):
super(myRequestsSession, self).__init__()

def get(self, request_ref, headers, timeout=30):
print('Using customized get request with a timeout of 30 seconds.')
return super(myRequestsSession, self).get(request_ref, headers=headers, timeout=timeout)


def test_firebase_connection():

firebase, postgres = BaseFunctions.get_environment('development')
fb_db = firebase.database()
fb_db.requests.get = myRequestsSession().get


request_object = fb_db.child("groups").child("1002").get().val()


if __name__ == '__main__':
test_firebase_connection()
print("Everything passed")
4 changes: 4 additions & 0 deletions tests/test_initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ def upload_sample_data_to_firebase():
firebase, postgres = BaseFunctions.get_environment('development')
fb_db = firebase.database()

adapter = requests.adapters.HTTPAdapter(max_retries=5, pool_connections=100, pool_maxsize=100)
for scheme in ('http://', 'https://'):
fb_db.requests.mount(scheme, adapter)

with open('sample_data.json') as f:
sample_data = json.load(f)

Expand Down

0 comments on commit 877dcfb

Please sign in to comment.