diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py index 2057ea8aa9f8..940cc446ee87 100644 --- a/superset/connectors/base/models.py +++ b/superset/connectors/base/models.py @@ -115,6 +115,13 @@ def column_formats(self): if m.d3format } + def add_missing_metrics(self, metrics): + exisiting_metrics = {m.metric_name for m in self.metrics} + for metric in metrics: + if metric.metric_name not in exisiting_metrics: + metric.table_id = self.id + self.metrics += [metric] + @property def metrics_combo(self): return sorted( diff --git a/superset/connectors/druid/views.py b/superset/connectors/druid/views.py index d73331bc5e8c..ca407fb05ede 100644 --- a/superset/connectors/druid/views.py +++ b/superset/connectors/druid/views.py @@ -28,6 +28,8 @@ class DruidColumnInlineView(CompactCRUDMixin, SupersetModelView): # noqa add_title = _('Add Druid Column') edit_title = _('Edit Druid Column') + list_widget = ListWidgetWithCheckboxes + edit_columns = [ 'column_name', 'description', 'dimension_spec_json', 'datasource', 'groupby', 'filterable', 'count_distinct', 'sum', 'min', 'max'] @@ -197,7 +199,6 @@ class DruidDatasourceModelView(DatasourceModelView, DeleteMixin, YamlExportMixin add_title = _('Add Druid Datasource') edit_title = _('Edit Druid Datasource') - list_widget = ListWidgetWithCheckboxes list_columns = [ 'datasource_link', 'cluster', 'changed_by_', 'modified'] order_columns = ['datasource_link', 'modified'] diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 8ac6e8289a34..6ccddbe79c4b 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -97,6 +97,10 @@ def sqla_col(self): col = literal_column(self.expression).label(name) return col + @property + def datasource(self): + return self.table + def get_time_filter(self, start_dttm, end_dttm): col = self.sqla_col.label('__time') l = [] # noqa: E741 @@ -155,6 +159,42 @@ def dttm_sql_literal(self, dttm): self.type or '', dttm) return s or "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S.%f')) + def get_metrics(self): + metrics = [] + M = SqlMetric # noqa + quoted = self.column_name + if self.sum: + metrics.append(M( + metric_name='sum__' + self.column_name, + metric_type='sum', + expression='SUM({})'.format(quoted), + )) + if self.avg: + metrics.append(M( + metric_name='avg__' + self.column_name, + metric_type='avg', + expression='AVG({})'.format(quoted), + )) + if self.max: + metrics.append(M( + metric_name='max__' + self.column_name, + metric_type='max', + expression='MAX({})'.format(quoted), + )) + if self.min: + metrics.append(M( + metric_name='min__' + self.column_name, + metric_type='min', + expression='MIN({})'.format(quoted), + )) + if self.count_distinct: + metrics.append(M( + metric_name='count_distinct__' + self.column_name, + metric_type='count_distinct', + expression='COUNT(DISTINCT {})'.format(quoted), + )) + return {m.metric_name: m for m in metrics} + class SqlMetric(Model, BaseMetric): @@ -702,47 +742,12 @@ def fetch_metadata(self): dbcol.sum = dbcol.is_num dbcol.avg = dbcol.is_num dbcol.is_dttm = dbcol.is_time + else: + dbcol.type = datatype self.columns.append(dbcol) if not any_date_col and dbcol.is_time: any_date_col = col.name - - quoted = col.name - if dbcol.sum: - metrics.append(M( - metric_name='sum__' + dbcol.column_name, - verbose_name='sum__' + dbcol.column_name, - metric_type='sum', - expression='SUM({})'.format(quoted), - )) - if dbcol.avg: - metrics.append(M( - metric_name='avg__' + dbcol.column_name, - verbose_name='avg__' + dbcol.column_name, - metric_type='avg', - expression='AVG({})'.format(quoted), - )) - if dbcol.max: - metrics.append(M( - metric_name='max__' + dbcol.column_name, - verbose_name='max__' + dbcol.column_name, - metric_type='max', - expression='MAX({})'.format(quoted), - )) - if dbcol.min: - metrics.append(M( - metric_name='min__' + dbcol.column_name, - verbose_name='min__' + dbcol.column_name, - metric_type='min', - expression='MIN({})'.format(quoted), - )) - if dbcol.count_distinct: - metrics.append(M( - metric_name='count_distinct__' + dbcol.column_name, - verbose_name='count_distinct__' + dbcol.column_name, - metric_type='count_distinct', - expression='COUNT(DISTINCT {})'.format(quoted), - )) - dbcol.type = datatype + metrics += dbcol.get_metrics().values() metrics.append(M( metric_name='count', @@ -750,16 +755,9 @@ def fetch_metadata(self): metric_type='count', expression='COUNT(*)', )) - - dbmetrics = db.session.query(M).filter(M.table_id == self.id).filter( - or_(M.metric_name == metric.metric_name for metric in metrics)) - dbmetrics = {metric.metric_name: metric for metric in dbmetrics} - for metric in metrics: - metric.table_id = self.id - if not dbmetrics.get(metric.metric_name, None): - db.session.add(metric) if not self.main_dttm_col: self.main_dttm_col = any_date_col + self.add_missing_metrics(metrics) db.session.merge(self) db.session.commit() diff --git a/superset/views/core.py b/superset/views/core.py index b23596436ba5..3e18cf3006f4 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -1351,11 +1351,17 @@ def checkbox(self, model_view, id_, attr, value): modelview_to_model = { 'TableColumnInlineView': ConnectorRegistry.sources['table'].column_class, + 'DruidColumnInlineView': + ConnectorRegistry.sources['druid'].column_class, } model = modelview_to_model[model_view] - obj = db.session.query(model).filter_by(id=id_).first() - if obj: - setattr(obj, attr, value == 'true') + col = db.session.query(model).filter_by(id=id_).first() + checked = value == 'true' + if col: + setattr(col, attr, checked) + if checked: + metrics = col.get_metrics().values() + col.datasource.add_missing_metrics(metrics) db.session.commit() return json_success('OK')