Skip to content

Commit

Permalink
merged vinodc branch for adding https feature
Browse files Browse the repository at this point in the history
  • Loading branch information
Alberto Paro committed Feb 15, 2012
2 parents ed796d9 + 967d157 commit e6a78c1
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 17 deletions.
41 changes: 31 additions & 10 deletions pyes/connection_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import random
import threading
import time
import base64
import requests
from pyes.exceptions import NoServerAvailable
from httplib import HTTPConnection
Expand All @@ -19,7 +20,7 @@
Work taken from pycassa
"""

DEFAULT_SERVER = ("127.0.0.1", 9200)
DEFAULT_SERVER = ("http", "127.0.0.1", 9200)
#API_VERSION = VERSION.split('.')

log = logging.getLogger('pyes')
Expand All @@ -28,28 +29,38 @@
class ClientTransport(object):
"""Encapsulation of a client session."""

def __init__(self, server, framed_transport, timeout, recycle):
self.host, self.port = server
def __init__(self, server, framed_transport, timeout, recycle, basic_auth):
self.connection_type, self.host, self.port = server
self.timeout = timeout
self.headers = {}
#self.client = TimeoutHttpConnectionPool(host, port, timeout)
#setattr(self.client, "execute", self.execute)
if recycle:
self.recycle = time.time() + recycle + random.uniform(0, recycle * 0.1)
else:
self.recycle = None

if basic_auth:
username = basic_auth.get('username')
password = basic_auth.get('password')
base64string = base64.encodestring('%s:%s' %
(username, password))[:-1]
self.headers["Authorization"] = ("Basic %s" % base64string)

def execute(self, request):
"""
Execute a request and return a response
"""
s = requests.session()
response = s.request(method=Method._VALUES_TO_NAMES[request.method],
headers = self.headers.copy()
headers.update(request.headers)
response = requests.request(method=Method._VALUES_TO_NAMES[request.method],
url="http://%s:%s%s" % (self.host, self.port, request.uri), params=request.parameters,
data=request.body, headers=request.headers)
return RestResponse(status=response.status_code, body=response.content, headers=response.headers)

def connect(servers=None, framed_transport=False, timeout=None,
retry_time=60, recycle=None, round_robin=None, max_retries=3):
retry_time=60, recycle=None, round_robin=None,
max_retries=3, basic_auth=None):
"""
Constructs a single ElastiSearch connection. Connects to a randomly chosen
server on the list.
Expand All @@ -66,7 +77,7 @@ def connect(servers=None, framed_transport=False, timeout=None,
servers : [server]
List of ES servers with format: "hostname:port"
Default: [("127.0.0.1", 9200)]
Default: [("http", "127.0.0.1", 9200)]
framed_transport: bool
If True, use a TFramedTransport instead of a TBufferedTransport
timeout: float
Expand All @@ -84,6 +95,13 @@ def connect(servers=None, framed_transport=False, timeout=None,
max_retries: int
Max retry time on connection down
basic_auth: dict
Use HTTP Basic Auth. Use ssl while using basic auth to keep the
password from being transmitted in the clear.
Expects keys:
* username
* password
round_robin: bool
*DEPRECATED*
Expand All @@ -95,7 +113,8 @@ def connect(servers=None, framed_transport=False, timeout=None,
if servers is None:
servers = [DEFAULT_SERVER]
return ThreadLocalConnection(servers, framed_transport, timeout,
retry_time, recycle, max_retries=max_retries)
retry_time, recycle, max_retries=max_retries,
basic_auth=basic_auth)

connect_thread_local = connect

Expand Down Expand Up @@ -139,12 +158,13 @@ def mark_dead(self, server):

class ThreadLocalConnection(object):
def __init__(self, servers, framed_transport=False, timeout=None,
retry_time=10, recycle=None, max_retries=3):
retry_time=10, recycle=None, max_retries=3, basic_auth=None):
self._servers = ServerSet(servers, retry_time)
self._framed_transport = framed_transport #not used in http
self._timeout = timeout
self._recycle = recycle
self._max_retries = max_retries
self._basic_auth = basic_auth
self._local = threading.local()

def __getattr__(self, attr):
Expand Down Expand Up @@ -181,7 +201,8 @@ def connect(self):
server = self._servers.get()
log.debug('Connecting to %s', server)
self._local.conn = ClientTransport(server, self._framed_transport,
self._timeout, self._recycle)
self._timeout, self._recycle,
self._basic_auth)
except (socket.timeout, socket.error):
log.warning('Connection to %s failed.', server)
self._servers.mark_dead(server)
Expand Down
21 changes: 14 additions & 7 deletions pyes/es.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,22 +229,25 @@ def __init__(self, server="localhost:9200", timeout=5.0, bulk_size=400,
default_types=None,
dump_curl=False,
model=ElasticSearchModel,
basic_auth=None,
raise_on_bulk_item_failure=False):
"""
Init a es object.
Servers can be defined in different forms:
- host:port with protocol guess (i.e. 127.0.0.1:9200 protocol -> http
127.0.0.1:9500 protocol -> thrift )
- type://host:port (i.e. http://127.0.0.1:9200 thrift://127.0.0.1:9500)
- type://host:port (i.e. http://127.0.0.1:9200 https://127.0.0.1:9200 thrift://127.0.0.1:9500)
- (type, host, port) (i.e. tuple ("http", "127.0.0.1", "9200") ("thrift", "127.0.0.1", "9500")). This is the prefered form.
- (type, host, port) (i.e. tuple ("http", "127.0.0.1", "9200") ("https", "127.0.0.1", "9200")
("thrift", "127.0.0.1", "9500")). This is the prefered form.
:param server: the server name, it can be a list of servers.
:param timeout: timeout for a call
:param bulk_size: size of bulk operation
:param encoder: tojson encoder
:param max_retries: number of max retries for server if a server is down
:param basic_auth: Dictionary with 'username' and 'password' keys for HTTP Basic Auth.
:param model: used to objectify the dictinary. If None, the raw dict is returned.
Expand All @@ -262,6 +265,7 @@ def __init__(self, server="localhost:9200", timeout=5.0, bulk_size=400,
self.cluster = None
self.debug_dump = False
self.cluster_name = "undefined"
self.basic_auth = basic_auth
self.connection = None

if model is None:
Expand Down Expand Up @@ -340,7 +344,7 @@ def check_format(host, port, _type=None):
else:
raise RuntimeError("Unable to recognize port-type: \"%s\"" % port)

if _type not in ["thrift", "http"]:
if _type not in ["thrift", "http", "https"]:
raise RuntimeError("Unable to recognize protocol: \"%s\"" % _type)

if _type == "thrift" and not thrift_enable:
Expand All @@ -355,7 +359,7 @@ def check_format(host, port, _type=None):
_type, host, port = server
check_format(host=host, port=port, _type=_type)
elif isinstance(server, basestring):
if server.startswith(("thrift:", "http:")):
if server.startswith(("thrift:", "http:", "https:")):
tokens = [t.strip("/") for t in server.split(":") if t.strip("/")]
if len(tokens) == 3:
check_format(tokens[1], tokens[2], tokens[0])
Expand All @@ -381,11 +385,14 @@ def _init_connection(self):
raise RuntimeError("No server defined")

_type, host, port = random.choice(self.servers)
if _type == "http":
self.connection = http_connect([(host, port) for _type, host, port in self.servers if _type == "http"], timeout=self.timeout, max_retries=self.max_retries)
if _type in ["http", "https"]:
self.connection = http_connect([(_type, host, port) for _type, host, port in self.servers if _type in ["http", "https"]],
timeout=self.timeout, basic_auth=self.basic_auth,
max_retries=self.max_retries)
return
elif _type == "thrift":
self.connection = thrift_connect([(host, port) for _type, host, port in self.servers if _type == "thrift"], timeout=self.timeout, max_retries=self.max_retries)
self.connection = thrift_connect([(host, port) for _type, host, port in self.servers if _type == "thrift"],
timeout=self.timeout, max_retries=self.max_retries)

def _discovery(self):
"""
Expand Down

0 comments on commit e6a78c1

Please sign in to comment.