Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
mher committed May 13, 2023
2 parents 4f5838a + 70872ef commit 4ca1546
Show file tree
Hide file tree
Showing 11 changed files with 102 additions and 42 deletions.
6 changes: 6 additions & 0 deletions docs/auth.rst
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ environment variables. ::

.. _gitlab-oauth:

**NOTE:** If you need a custom GitHub Domain, please export it using environment variable:
`export FLOWER_GITHUB_OAUTH_DOMAIN=github.foobar.com`

GitLab OAuth
------------

Expand Down Expand Up @@ -128,3 +131,6 @@ See `Group and project members API`_ for details.

.. _GitLab OAuth2 API: https://docs.gitlab.com/ee/api/oauth2.html
.. _Group and project members API: https://docs.gitlab.com/ee/api/members.html

**NOTE:** If you need a custom GitHub Domain, please export it using environment variable:
`export FLOWER_GITLAB_OAUTH_DOMAIN=gitlab.foobar.com`
4 changes: 2 additions & 2 deletions examples/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
app = Celery("tasks",
broker=os.environ.get('CELERY_BROKER_URL', 'redis://'),
backend=os.environ.get('CELERY_RESULT_BACKEND', 'redis'))
app.conf.CELERY_ACCEPT_CONTENT = ['pickle', 'json', 'msgpack', 'yaml']
app.conf.CELERY_WORKER_SEND_TASK_EVENTS = True
app.conf.accept_content = ['pickle', 'json', 'msgpack', 'yaml']
app.conf.worker_send_task_events = True


@app.task
Expand Down
10 changes: 5 additions & 5 deletions flower/api/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,24 +396,24 @@ async def get(self):
:statuscode 503: result backend is not configured
"""
app = self.application
broker_options = self.capp.conf.BROKER_TRANSPORT_OPTIONS
broker_options = self.capp.conf.broker_transport_options

http_api = None
if app.transport == 'amqp' and app.options.broker_api:
http_api = app.options.broker_api

broker_use_ssl = None
if self.capp.conf.BROKER_USE_SSL:
broker_use_ssl = self.capp.conf.BROKER_USE_SSL
if self.capp.conf.broker_use_ssl:
broker_use_ssl = self.capp.conf.broker_use_ssl

broker = Broker(app.capp.connection().as_uri(include_password=True),
http_api=http_api, broker_options=broker_options, broker_use_ssl=broker_use_ssl)

queue_names = self.get_active_queue_names()

if not queue_names:
queue_names = set([self.capp.conf.CELERY_DEFAULT_QUEUE]) |\
set([q.name for q in self.capp.conf.CELERY_QUEUES or [] if q.name])
queue_names = set([self.capp.conf.task_default_queue]) |\
set([q.name for q in self.capp.conf.task_queues or [] if q.name])

queues = await broker.queues(sorted(queue_names))
self.write({'active_queues': queues})
Expand Down
10 changes: 7 additions & 3 deletions flower/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from .app import Flower
from .urls import settings
from .utils import abs_path, prepend_url
from .utils import abs_path, prepend_url, strtobool
from .options import DEFAULT_CONFIG_FILE, default_options
from .views.auth import validate_auth_option

Expand Down Expand Up @@ -69,7 +69,11 @@ def apply_env_options():
if option.multiple:
value = [option.type(i) for i in value.split(',')]
else:
value = option.type(value)
if option.type is bool:
value = bool(strtobool(value))
else:
value = option.type(value)
print(name, type(value), value)
setattr(options, name, value)


Expand Down Expand Up @@ -161,7 +165,7 @@ def print_banner(app, ssl):

logger.info(
"Visit me at http%s://%s:%s%s", 's' if ssl else '',
options.address or 'localhost', options.port,
options.address or '0.0.0.0', options.port,
prefix_str
)
else:
Expand Down
5 changes: 1 addition & 4 deletions flower/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,10 @@
define("tasks_columns", type=str,
default="name,uuid,state,args,kwargs,result,received,started,runtime,worker",
help="slugs of columns on /tasks/ page, delimited by comma")
define("auth_provider", default='flower.views.auth.GoogleAuth2LoginHandler',
help="auth handler class")
define("auth_provider", default=None, type=str, help="auth handler class")
define("url_prefix", type=str, help="base url prefix")
define("task_runtime_metric_buckets", type=float, default=Histogram.DEFAULT_BUCKETS,
multiple=True, help="histogram latency bucket value")

# deprecated options
define("inspect", default=False, help="inspect workers", type=bool)

default_options = options
16 changes: 16 additions & 0 deletions flower/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,19 @@ def abs_path(path):

def prepend_url(url, prefix):
return '/' + prefix.strip('/') + url


def strtobool(val):
"""Convert a string representation of truth to true (1) or false (0).
True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
'val' is anything else.
"""
val = val.lower()
if val in ('y', 'yes', 't', 'true', 'on', '1'):
return 1
elif val in ('n', 'no', 'f', 'false', 'off', '0'):
return 0
else:
raise ValueError("invalid truth value %r" % (val,))
24 changes: 15 additions & 9 deletions flower/views/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import traceback
import copy
import logging
import hmac

from base64 import b64decode

import tornado

from ..utils import template, bugreport, prepend_url, strtobool
from ..utils import template, bugreport, strtobool

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -46,10 +47,10 @@ def write_error(self, status_code, **kwargs):
error_trace += line

self.render('error.html',
debug=self.application.options.debug,
status_code=status_code,
error_trace=error_trace,
bugreport=bugreport())
debug=self.application.options.debug,
status_code=status_code,
error_trace=error_trace,
bugreport=bugreport())
elif status_code == 401:
self.set_status(status_code)
self.set_header('WWW-Authenticate', 'Basic realm="flower"')
Expand All @@ -71,7 +72,12 @@ def get_current_user(self):
try:
basic, credentials = auth_header.split()
credentials = b64decode(credentials.encode()).decode()
if basic != 'Basic' or credentials not in basic_auth:
if basic != 'Basic':
raise tornado.web.HTTPError(401)
for stored_credential in basic_auth:
if hmac.compare_digest(stored_credential, credentials):
break
else:
raise tornado.web.HTTPError(401)
except ValueError:
raise tornado.web.HTTPError(401)
Expand Down Expand Up @@ -101,9 +107,9 @@ def get_argument(self, name, default=[], strip=True, type=None):
if arg is None and default is None:
return arg
raise tornado.web.HTTPError(
400,
"Invalid argument '%s' of type '%s'" % (
arg, type.__name__))
400,
"Invalid argument '%s' of type '%s'" % (
arg, type.__name__))
return arg

@property
Expand Down
21 changes: 13 additions & 8 deletions flower/views/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from celery.utils.imports import instantiate

from ..views import BaseHandler
from ..views.error import NotFoundErrorHandler


def authenticate(pattern, email):
Expand Down Expand Up @@ -85,13 +86,15 @@ async def _on_auth(self, user):

class LoginHandler(BaseHandler):
def __new__(cls, *args, **kwargs):
return instantiate(options.auth_provider, *args, **kwargs)
return instantiate(options.auth_provider or NotFoundErrorHandler, *args, **kwargs)


class GithubLoginHandler(BaseHandler, tornado.auth.OAuth2Mixin):

_OAUTH_AUTHORIZE_URL = "https://github.com/login/oauth/authorize"
_OAUTH_ACCESS_TOKEN_URL = "https://github.com/login/oauth/access_token"
_OAUTH_DOMAIN = os.getenv(
"FLOWER_GITLAB_OAUTH_DOMAIN", "github.com")
_OAUTH_AUTHORIZE_URL = f'https://{_OAUTH_DOMAIN}/login/oauth/authorize'
_OAUTH_ACCESS_TOKEN_URL = f'https://{_OAUTH_DOMAIN}/login/oauth/access_token'
_OAUTH_NO_CALLBACKS = False
_OAUTH_SETTINGS_KEY = 'oauth'

Expand Down Expand Up @@ -139,7 +142,7 @@ async def _on_auth(self, user):
access_token = user['access_token']

response = await self.get_auth_http_client().fetch(
'https://api.github.com/user/emails',
f'https://api.{self._OAUTH_DOMAIN}/user/emails',
headers={'Authorization': 'token ' + access_token,
'User-agent': 'Tornado auth'})

Expand All @@ -163,8 +166,10 @@ async def _on_auth(self, user):

class GitLabLoginHandler(BaseHandler, tornado.auth.OAuth2Mixin):

_OAUTH_AUTHORIZE_URL = 'https://gitlab.com/oauth/authorize'
_OAUTH_ACCESS_TOKEN_URL = 'https://gitlab.com/oauth/token'
_OAUTH_GITLAB_DOMAIN = os.getenv(
"FLOWER_GITLAB_AUTH_DOMAIN", "gitlab.com")
_OAUTH_AUTHORIZE_URL = f'https://{_OAUTH_GITLAB_DOMAIN}/oauth/authorize'
_OAUTH_ACCESS_TOKEN_URL = f'https://{_OAUTH_GITLAB_DOMAIN}/oauth/token'
_OAUTH_NO_CALLBACKS = False

async def get_authenticated_user(self, redirect_uri, code):
Expand Down Expand Up @@ -213,7 +218,7 @@ async def _on_auth(self, user):
# Check user email address against regexp
try:
response = await self.get_auth_http_client().fetch(
'https://gitlab.com/api/v4/user',
f'https://{self._OAUTH_GITLAB_DOMAIN}/api/v4/user',
headers={'Authorization': 'Bearer ' + access_token,
'User-agent': 'Tornado auth'}
)
Expand All @@ -228,7 +233,7 @@ async def _on_auth(self, user):
if allowed_groups:
min_access_level = os.environ.get('FLOWER_GITLAB_MIN_ACCESS_LEVEL', '20')
response = await self.get_auth_http_client().fetch(
'https://gitlab.com/api/v4/groups?min_access_level=%s' % (min_access_level,),
f'https://{self._OAUTH_GITLAB_DOMAIN}/api/v4/groups?min_access_level={min_access_level}',
headers={
'Authorization': 'Bearer ' + access_token,
'User-agent': 'Tornado auth'
Expand Down
10 changes: 5 additions & 5 deletions flower/views/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ class BrokerView(BaseHandler):
@web.authenticated
async def get(self):
app = self.application
broker_options = self.capp.conf.BROKER_TRANSPORT_OPTIONS
broker_options = self.capp.conf.broker_transport_options

http_api = None
if app.transport == 'amqp' and app.options.broker_api:
http_api = app.options.broker_api

broker_use_ssl = None
if self.capp.conf.BROKER_USE_SSL:
broker_use_ssl = self.capp.conf.BROKER_USE_SSL
if self.capp.conf.broker_use_ssl:
broker_use_ssl = self.capp.conf.broker_use_ssl

try:
broker = Broker(app.capp.connection(connect_timeout=1.0).as_uri(include_password=True),
Expand All @@ -36,8 +36,8 @@ async def get(self):
try:
queue_names = self.get_active_queue_names()
if not queue_names:
queue_names = set([self.capp.conf.CELERY_DEFAULT_QUEUE]) |\
set([q.name for q in self.capp.conf.CELERY_QUEUES or [] if q.name])
queue_names = set([self.capp.conf.task_default_queue]) |\
set([q.name for q in self.capp.conf.task_queues or [] if q.name])

queues = await broker.queues(sorted(queue_names))
except Exception as e:
Expand Down
4 changes: 2 additions & 2 deletions flower/views/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ def get(self):
capp = self.application.capp

time = 'natural-time' if app.options.natural_time else 'time'
if capp.conf.CELERY_TIMEZONE:
time += '-' + str(capp.conf.CELERY_TIMEZONE)
if capp.conf.timezone:
time += '-' + str(capp.conf.timezone)

self.render(
"tasks.html",
Expand Down
34 changes: 30 additions & 4 deletions tests/unit/test_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ def test_task_runtime_metric_buckets_no_cmd_line_arg(self):
self.assertEqual(Histogram.DEFAULT_BUCKETS, options.task_runtime_metric_buckets)

def test_task_runtime_metric_buckets_read_from_env(self):
os.environ["FLOWER_TASK_RUNTIME_METRIC_BUCKETS"] = "2,5,inf"
apply_env_options()
self.assertEqual([2.0, 5.0, float('inf')], options.task_runtime_metric_buckets)
with patch.dict(os.environ, {"FLOWER_TASK_RUNTIME_METRIC_BUCKETS": "2,5,inf"}):
apply_env_options()
self.assertEqual([2.0, 5.0, float('inf')], options.task_runtime_metric_buckets)

def test_task_runtime_metric_buckets_no_env_value_provided(self):
apply_env_options()
Expand All @@ -40,6 +40,32 @@ def test_address(self):
apply_options('flower', argv=['--address=foo'])
self.assertEqual('foo', options.address)

def test_auto_refresh(self):
with patch.dict(os.environ, {"FLOWER_AUTO_REFRESH": "false"}):
apply_env_options()
self.assertFalse(options.auto_refresh)

with patch.dict(os.environ, {"FLOWER_AUTO_REFRESH": "true"}):
apply_env_options()
self.assertTrue(options.auto_refresh)

with patch.dict(os.environ, {"FLOWER_AUTO_REFRESH": "0"}):
apply_env_options()
self.assertFalse(options.auto_refresh)

with patch.dict(os.environ, {"FLOWER_AUTO_REFRESH": "1"}):
apply_env_options()
self.assertTrue(options.auto_refresh)

with patch.dict(os.environ, {"FLOWER_AUTO_REFRESH": "False"}):
apply_env_options()
self.assertFalse(options.auto_refresh)

with patch.dict(os.environ, {"FLOWER_AUTO_REFRESH": "True"}):
apply_env_options()
self.assertTrue(options.auto_refresh)


def test_autodiscovery(self):
"""
Simulate basic Django setup:
Expand Down Expand Up @@ -134,7 +160,7 @@ def grep(patter, filename):
return int(subprocess.check_output(
'grep "%s" %s|wc -l' % (patter, filename), shell=True))

defined = grep('^define(', 'flower/options.py') - 4
defined = grep('^define(', 'flower/options.py') - 3
documented = grep('^~~', 'docs/config.rst')
self.assertEqual(defined, documented,
msg='Missing option documentation. Make sure all options '
Expand Down

0 comments on commit 4ca1546

Please sign in to comment.