forked from sloria/environs
/
environs.py
277 lines (226 loc) · 9.67 KB
/
environs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
# -*- coding: utf-8 -*-
import contextlib
import inspect
import functools
import json as pyjson
import os
import re
try:
import urllib.parse as urlparse
except ImportError:
# Python 2
import urlparse
import marshmallow as ma
from dotenv import load_dotenv
from dotenv.main import _walk_to_root
# Version of this module.
__version__ = "4.1.0"
# Names exported by ``from environs import *``.
__all__ = ["EnvError", "Env"]
# Tuple of the purely numeric dot-separated components of the installed
# marshmallow version. Components that are not all digits are dropped.
# Element 0 (the major version) is what the rest of this module branches on.
MARSHMALLOW_VERSION_INFO = tuple(
[int(part) for part in ma.__version__.split(".") if part.isdigit()]
)
class EnvError(ValueError):
    """Raised when an environment variable is not set or fails validation."""
# Recognizes a proxied value of the form ``{{NAME}}``, with optional whitespace
# around the braces and around NAME. Group 1 captures NAME (a run of
# non-whitespace characters).
_PROXIED_PATTERN = re.compile(r"\s*{{\s*(\S*)\s*}}\s*")
def _field2method(field_or_factory, method_name, preprocess=None):
    """Build an ``Env`` parser method backed by a marshmallow field.

    :param field_or_factory: either a ``marshmallow.fields.Field`` subclass
        (instantiated with ``missing=`` plus the caller's keyword arguments)
        or a callable invoked with ``subcast=``, ``missing=`` and the caller's
        keyword arguments that returns a field instance.
    :param method_name: value assigned to the generated function's ``__name__``.
    :param preprocess: optional callable invoked as
        ``preprocess(value, subcast=subcast, **kwargs)`` before the value is
        handed to ``field.deserialize``.
    :return: a function ``method(self, name, default, subcast, **kwargs)``
        that reads ``name`` through ``self._get_from_environ``, records the
        field and parsed value on ``self`` and returns the parsed value.
    :raises EnvError: (from the returned method) when the variable is unset
        and no default is available, or when deserialization fails validation.
    """

    def method(self, name, default=ma.missing, subcast=None, **kwargs):
        # An explicit ``missing=`` keyword takes precedence over ``default``.
        # Test for None explicitly instead of relying on truthiness so that
        # falsy values such as 0, False or "" are honoured rather than
        # silently replaced by ``default``.
        missing = kwargs.pop("missing", None)
        if missing is None:
            missing = default

        is_field_class = isinstance(field_or_factory, type) and issubclass(
            field_or_factory, ma.fields.Field
        )
        if is_field_class:
            field = field_or_factory(missing=missing, **kwargs)
        else:
            field = field_or_factory(subcast=subcast, missing=missing, **kwargs)

        parsed_key, raw_value, proxied_key = self._get_from_environ(name, ma.missing)
        # Name to report in error messages: the proxied variable if there is
        # one, otherwise the resolved key.
        source_key = proxied_key or parsed_key
        self._fields[parsed_key] = field

        if raw_value is ma.missing and field.missing is ma.missing:
            raise EnvError('Environment variable "{}" not set'.format(source_key))

        # Use the field's default only when the variable is unset; an empty
        # string read from the environment is passed through as a value.
        value = field.missing if raw_value is ma.missing else raw_value

        if preprocess:
            value = preprocess(value, subcast=subcast, **kwargs)
        try:
            value = field.deserialize(value)
        except ma.ValidationError as err:
            raise EnvError(
                'Environment variable "{}" invalid: {}'.format(source_key, err.args[0])
            )
        else:
            self._values[parsed_key] = value
            return value

    method.__name__ = str(method_name)  # cast to str for Py2 compat
    return method
def _func2method(func, method_name):
    """Wrap a plain parsing function as an ``Env`` parser method.

    The returned method looks ``name`` up via ``self._get_from_environ``
    (falling back to ``default``), passes the raw value and any extra keyword
    arguments to ``func``, records a generic marshmallow ``Field`` and the
    parsed value on ``self``, and returns the parsed value. ``EnvError`` is
    raised when no value is available.
    """

    def method(self, name, default=ma.missing, subcast=None, **kwargs):
        env_key, raw, proxy_key = self._get_from_environ(name, default)
        if raw is ma.missing:
            unset_name = proxy_key or env_key
            raise EnvError('Environment variable "{}" not set'.format(unset_name))
        parsed = func(raw, **kwargs)
        self._fields[env_key] = ma.fields.Field(**kwargs)
        self._values[env_key] = parsed
        return parsed

    # cast to str for Py2 compat
    method.__name__ = str(method_name)
    return method
# From webargs
def _dict2schema(dct):
    """Create a `marshmallow.Schema` subclass from a dictionary mapping
    names to `Fields <marshmallow.fields.Field>`.

    The input dictionary is copied, not modified. When the installed
    marshmallow major version is below 3, a ``Meta`` class with
    ``strict = True`` is added to the generated schema.
    """
    schema_attrs = dct.copy()
    needs_strict_meta = MARSHMALLOW_VERSION_INFO[0] < 3
    if needs_strict_meta:

        class Meta(object):
            strict = True

        schema_attrs["Meta"] = Meta
    schema_cls = type(str(""), (ma.Schema,), schema_attrs)
    return schema_cls
def _make_list_field(**kwargs):
    """Return a marshmallow ``List`` field.

    The optional ``subcast`` keyword selects the inner field class through
    ``ma.Schema.TYPE_MAPPING``; without it a generic ``Field`` is used. All
    remaining keyword arguments are forwarded to ``ma.fields.List``.
    """
    subcast = kwargs.pop("subcast", None)
    if subcast:
        element_field = ma.Schema.TYPE_MAPPING[subcast]
    else:
        element_field = ma.fields.Field
    return ma.fields.List(element_field, **kwargs)
def _preprocess_list(value, **kwargs):
    """Turn a comma-separated string into a list of strings.

    Values that are already non-string iterables are returned unchanged.
    """
    if ma.utils.is_iterable_but_not_string(value):
        return value
    return value.split(",")
def _preprocess_dict(value, **kwargs):
subcast = kwargs.get("subcast")
return {
key.strip(): subcast(val.strip()) if subcast else val.strip()
for key, val in (item.split("=") for item in value.split(",") if value)
}
def _preprocess_json(value, **kwargs):
return pyjson.loads(value)
def _dj_db_url_parser(value, **kwargs):
try:
import dj_database_url
except ImportError:
raise RuntimeError(
"The dj_db_url parser requires the dj-database-url package. "
"You can install it with: pip install dj-database-url"
)
return dj_database_url.parse(value, **kwargs)
def _dj_email_url_parser(value, **kwargs):
try:
import dj_email_url
except ImportError:
raise RuntimeError(
"The dj_email_url parser requires the dj-email-url package. "
"You can install it with: pip install dj-email-url"
)
return dj_email_url.parse(value, **kwargs)
class URLField(ma.fields.URL):
    """URL field whose deserialized value is the result of ``urlparse``.

    Serialization converts the parsed value back to a string with ``geturl()``.
    Both hooks accept and ignore/forward extra keyword arguments so that they
    remain compatible with callers that pass additional options.
    """

    def _serialize(self, value, attr, obj, **kwargs):
        return value.geturl()

    # Override deserialize rather than _deserialize because we need
    # to call urlparse *after* validation has occurred
    def deserialize(self, value, attr=None, data=None, **kwargs):
        ret = super(URLField, self).deserialize(value, attr, data, **kwargs)
        return urlparse.urlparse(ret)
class Env(object):
    """An environment variable reader.

    Parser methods (``env.int``, ``env.bool``, ...) are resolved dynamically
    through ``__getattr__`` from a per-instance parser map. Every parsed
    variable is remembered in ``_fields`` / ``_values`` so that ``dump`` can
    serialize them later.
    """

    # Calling the instance directly parses with a generic marshmallow Field.
    __call__ = _field2method(ma.fields.Field, "__call__")

    def __init__(self):
        self._fields = {}  # resolved key -> marshmallow field used to parse it
        self._values = {}  # resolved key -> parsed value
        self._prefix = None  # active prefix set by ``prefixed``, or None
        self.__parser_map__ = dict(
            bool=_field2method(ma.fields.Bool, "bool"),
            str=_field2method(ma.fields.Str, "str"),
            int=_field2method(ma.fields.Int, "int"),
            float=_field2method(ma.fields.Float, "float"),
            decimal=_field2method(ma.fields.Decimal, "decimal"),
            list=_field2method(_make_list_field, "list", preprocess=_preprocess_list),
            dict=_field2method(ma.fields.Dict, "dict", preprocess=_preprocess_dict),
            json=_field2method(ma.fields.Field, "json", preprocess=_preprocess_json),
            datetime=_field2method(ma.fields.DateTime, "datetime"),
            date=_field2method(ma.fields.Date, "date"),
            timedelta=_field2method(ma.fields.TimeDelta, "timedelta"),
            uuid=_field2method(ma.fields.UUID, "uuid"),
            url=_field2method(URLField, "url"),
            dj_db_url=_func2method(_dj_db_url_parser, "dj_db_url"),
            dj_email_url=_func2method(_dj_email_url_parser, "dj_email_url"),
        )

    def __repr__(self):
        # Read ``_values`` from the instance dict so that repr() also works on
        # an instance whose __init__ has not run; __getattr__ formats ``self``
        # in its error message and must not recurse through here.
        values = self.__dict__.get("_values", {})
        return "<{} {}>".format(self.__class__.__name__, values)

    __str__ = __repr__

    @staticmethod
    def read_env(path=None, recurse=True, stream=None, verbose=False, override=False):
        """Read a .env file into os.environ.

        If .env is not found in the directory from which this method is called,
        the default behavior is to recurse up the directory tree until a .env
        file is found. If you do not wish to recurse up the tree, you may pass
        False as a second positional argument.

        :param path: where to start; when None, the directory of the file
            containing the calling code is used.
        :param recurse: walk up the directory tree looking for ``.env``.
        :param stream: forwarded to ``load_dotenv``.
        :param verbose: forwarded to ``load_dotenv``.
        :param override: forwarded to ``load_dotenv``.
        :return: the result of ``load_dotenv``, or None when recursing and no
            ``.env`` file exists in any visited directory.
        """
        # By default, start search from the same file this function is called
        if path is None:
            frame = inspect.currentframe().f_back
            caller_dir = os.path.dirname(frame.f_code.co_filename)
            start = os.path.abspath(caller_dir)
        else:
            start = path
        if recurse:
            for dirname in _walk_to_root(start):
                check_path = os.path.join(dirname, ".env")
                if os.path.exists(check_path):
                    return load_dotenv(
                        check_path, stream=stream, verbose=verbose, override=override
                    )
        else:
            # An explicit ``path`` is passed to load_dotenv as given; only the
            # implicit caller directory gets ".env" appended.
            if path is None:
                start = os.path.join(start, ".env")
            return load_dotenv(start, stream=stream, verbose=verbose, override=override)

    @contextlib.contextmanager
    def prefixed(self, prefix):
        """Context manager for parsing envvars with a common prefix.

        Nested uses concatenate their prefixes. The previous prefix is
        restored on exit, including when the body raises.
        """
        old_prefix = self._prefix
        if old_prefix is None:
            self._prefix = prefix
        else:
            self._prefix = "{}{}".format(old_prefix, prefix)
        try:
            yield self
        finally:
            # Without this, an exception inside the ``with`` body (for example
            # an EnvError) would leave the prefix permanently applied.
            self._prefix = old_prefix

    def __getattr__(self, name, **kwargs):
        """Resolve ``name`` to a registered parser bound to this instance.

        :raises AttributeError: if no parser named ``name`` is registered.
        """
        # __getattr__ only runs after normal attribute lookup has failed. Read
        # the parser map from the instance dict so that a missing map (e.g. an
        # instance created without __init__) results in AttributeError rather
        # than unbounded recursion back into this method.
        parser_map = self.__dict__.get("__parser_map__", {})
        try:
            return functools.partial(parser_map[name], self)
        except KeyError:
            raise AttributeError("{} has no attribute {}".format(self, name))

    def add_parser(self, name, func):
        """Register a new parser method with the name ``name``. ``func`` must
        receive the input value for an environment variable.
        """
        self.__parser_map__[name] = _func2method(func, method_name=name)
        return None

    def parser_for(self, name):
        """Decorator that registers a new parser method with the name ``name``.
        The decorated function must receive the input value for an environment variable.
        """

        def decorator(func):
            self.add_parser(name, func)
            return func

        return decorator

    def add_parser_from_field(self, name, field_cls):
        """Register a new parser method with name ``name``, given a marshmallow ``Field``."""
        self.__parser_map__[name] = _field2method(field_cls, method_name=name)

    def dump(self):
        """Dump parsed environment variables to a dictionary of simple data types (numbers
        and strings).
        """
        schema = _dict2schema(self._fields)()
        dump_result = schema.dump(self._values)
        # With marshmallow major version < 3 the payload is read from ``.data``.
        return dump_result.data if MARSHMALLOW_VERSION_INFO[0] < 3 else dump_result

    def _get_from_environ(self, key, default):
        """Access a value from os.environ. Handles proxied variables, e.g. SMTP_LOGIN={{MAILGUN_LOGIN}}.

        Returns a tuple (envvar_key, envvar_value, proxied_key). The ``envvar_key`` will be different from
        the passed key for proxied variables. proxied_key will be None if the envvar isn't proxied.
        """
        env_key = self._get_key(key)
        value = os.environ.get(env_key, default)
        # ``default`` may be a non-string object; only string-like values can
        # hold a proxy reference.
        if hasattr(value, "strip"):
            match = _PROXIED_PATTERN.match(value)
            if match:  # Proxied variable
                proxied_key = match.groups()[0]
                # The proxied name is resolved through this same method, so any
                # active prefix is applied to it as well. The first element
                # returned is ``key`` exactly as passed in.
                return key, self._get_from_environ(proxied_key, default)[1], proxied_key
        return env_key, value, None

    def _get_key(self, key):
        """Return ``key`` with the active prefix prepended, if one is set."""
        return self._prefix + key if self._prefix else key