Skip to content

Commit

Permalink
Merge pull request #302 from fablegroup/thread-safety
Browse files Browse the repository at this point in the history
Ensure requests.Session is not shared between threads
  • Loading branch information
olucurious committed Feb 16, 2022
2 parents 1b4e7d0 + bfe2a6f commit f5c4e4d
Showing 1 changed file with 21 additions and 9 deletions.
30 changes: 21 additions & 9 deletions pyfcm/baseapi.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os
import time
import threading

import requests
from requests.adapters import HTTPAdapter
Expand Down Expand Up @@ -51,13 +52,8 @@ def __init__(self, api_key=None, proxy_dict=None, env=None, json_encoder=None, a
raise AuthenticationError("Please provide the api_key in the google-services.json file")

self.FCM_REQ_PROXIES = None
self.requests_session = requests.Session()
retries = Retry(backoff_factor=1, status_forcelist=[502, 503],
allowed_methods=(Retry.DEFAULT_ALLOWED_METHODS | frozenset(['POST'])))
self.requests_session.mount('http://', adapter or HTTPAdapter(max_retries=retries))
self.requests_session.mount('https://', adapter or HTTPAdapter(max_retries=retries))
self.requests_session.headers.update(self.request_headers())
self.requests_session.mount(self.INFO_END_POINT, HTTPAdapter(max_retries=self.INFO_RETRIES))
self.custom_adapter = adapter
self.thread_local = threading.local()

if proxy_dict and isinstance(proxy_dict, dict) and (('http' in proxy_dict) or ('https' in proxy_dict)):
self.FCM_REQ_PROXIES = proxy_dict
Expand All @@ -73,6 +69,24 @@ def __init__(self, api_key=None, proxy_dict=None, env=None, json_encoder=None, a

self.json_encoder = json_encoder

@property
def requests_session(self):
if getattr(self.thread_local, "requests_session", None) is None:
retries = Retry(
backoff_factor=1,
status_forcelist=[502, 503],
allowed_methods=(Retry.DEFAULT_ALLOWED_METHODS | frozenset(["POST"])),
)
adapter = self.custom_adapter or HTTPAdapter(max_retries=retries)
self.thread_local.requests_session = requests.Session()
self.thread_local.requests_session.mount("http://", adapter)
self.thread_local.requests_session.mount("https://", adapter)
self.thread_local.requests_session.headers.update(self.request_headers())
self.thread_local.requests_session.mount(
self.INFO_END_POINT, HTTPAdapter(max_retries=self.INFO_RETRIES)
)
return self.thread_local.requests_session

def request_headers(self):
"""
Generates request headers including Content-Type and Authorization
Expand Down Expand Up @@ -484,5 +498,3 @@ def send_async_request(self,params_list,timeout):
responses = asyncio.new_event_loop().run_until_complete(fetch_tasks(end_point=self.FCM_END_POINT,headers=self.request_headers(),payloads=payloads,timeout=timeout))

return responses


0 comments on commit f5c4e4d

Please sign in to comment.