diff --git a/extensions/skyportal/skyportal/handlers/api/alert.py b/extensions/skyportal/skyportal/handlers/api/alert.py index f3b31ae5..1cf2e872 100644 --- a/extensions/skyportal/skyportal/handlers/api/alert.py +++ b/extensions/skyportal/skyportal/handlers/api/alert.py @@ -1,26 +1,79 @@ from astropy.io import fits from astropy.visualization import ZScaleInterval +import base64 import bson.json_util as bj import gzip import io +from marshmallow.exceptions import ValidationError import matplotlib.colors as mplc import matplotlib.pyplot as plt import numpy as np import os +import pandas as pd import pathlib -import requests +import tornado.escape +import tornado.httpclient import traceback from baselayer.app.access import auth_or_token +from baselayer.log import make_log from ..base import BaseHandler from ...models import ( DBSession, + Group, + GroupStream, + Obj, Stream, StreamUser, + Source, ) +from .photometry import PhotometryHandler +from .thumbnail import ThumbnailHandler -s = requests.Session() +log = make_log("alert") + + +c = tornado.httpclient.AsyncHTTPClient() + + +def make_thumbnail(a, ttype, ztftype): + + cutout_data = a[f'cutout{ztftype}']['stampData'] + with gzip.open(io.BytesIO(cutout_data), 'rb') as f: + with fits.open(io.BytesIO(f.read())) as hdu: + # header = hdu[0].header + data_flipped_y = np.flipud(hdu[0].data) + # fixme: png, switch to fits eventually + buff = io.BytesIO() + plt.close('all') + fig = plt.figure() + fig.set_size_inches(4, 4, forward=False) + ax = plt.Axes(fig, [0., 0., 1., 1.]) + ax.set_axis_off() + fig.add_axes(ax) + + # remove nans: + img = np.array(data_flipped_y) + img = np.nan_to_num(img) + + if ztftype != 'Difference': + img[img <= 0] = np.median(img) + plt.imshow(img, cmap="bone", norm=mplc.LogNorm(), origin='lower') + else: + plt.imshow(img, cmap="bone", origin='lower') + plt.savefig(buff, dpi=42) + + buff.seek(0) + plt.close('all') + + thumb = { + "obj_id": a["objectId"], + "data": base64.b64encode(buff.read()).decode("utf-8"), + "ttype": ttype, + } + + return thumb class ZTFAlertHandler(BaseHandler): @@ -30,7 +83,7 @@ def get_user_streams(self): DBSession() .query(Stream) .join(StreamUser) - .filter(StreamUser.user_id == self.current_user.id) + .filter(StreamUser.user_id == self.associated_user_object.id) .all() ) if streams is None: @@ -38,8 +91,23 @@ def get_user_streams(self): return streams + async def query_kowalski(self, query: dict, timeout=7): + base_url = f"{self.cfg['app.kowalski.protocol']}://" \ + f"{self.cfg['app.kowalski.host']}:{self.cfg['app.kowalski.port']}" + headers = {"Authorization": f"Bearer {self.cfg['app.kowalski.token']}"} + + resp = await c.fetch( + os.path.join(base_url, 'api/queries'), + method='POST', + body=tornado.escape.json_encode(query), + headers=headers, + request_timeout=timeout + ) + + return resp + @auth_or_token - def get(self, objectId: str = None): + async def get(self, objectId: str = None): """ --- single: @@ -107,11 +175,15 @@ def get(self, objectId: str = None): "candidate.jd": 1, "candidate.programid": 1, "candidate.fid": 1, + "candidate.ra": 1, + "candidate.dec": 1, "candidate.magpsf": 1, "candidate.sigmapsf": 1, "candidate.rb": 1, "candidate.drb": 1, "candidate.isdiffpos": 1, + "coordinates.l": 1, + "coordinates.b": 1, } }, ] @@ -121,19 +193,14 @@ def get(self, objectId: str = None): # } } - base_url = f"{self.cfg['app.kowalski.protocol']}://" \ - f"{self.cfg['app.kowalski.host']}:{self.cfg['app.kowalski.port']}" - headers = {"Authorization": f"Bearer {self.cfg['app.kowalski.token']}"} + candid = self.get_query_argument('candid', None) + if candid: + query["query"]["pipeline"][0]["$match"]["candid"] = int(candid) - resp = s.post( - os.path.join(base_url, 'api/queries'), - json=query, - headers=headers, - timeout=7 - ) + resp = await self.query_kowalski(query=query) - if resp.status_code == requests.codes.ok: - alert_data = bj.loads(resp.text).get('data') + if resp.code == 200: + alert_data = tornado.escape.json_decode(resp.body).get('data') return self.success(data=alert_data) else: return self.error(f"Failed to fetch data for {objectId} from Kowalski") @@ -142,10 +209,341 @@ def get(self, objectId: str = None): _err = traceback.format_exc() return self.error(f'failure: {_err}') + @auth_or_token + async def post(self, objectId): + """ + --- + description: Save ZTF objectId from Kowalski as source in SkyPortal + requestBody: + content: + application/json: + schema: + allOf: + - type: object + properties: + candid: + type: integer + description: "alert candid to use to pull thumbnails. defaults to latest alert" + minimum: 1 + group_ids: + type: array + items: + type: integer + description: "group ids to save source to. defaults to all user groups" + minItems: 1 + responses: + 200: + content: + application/json: + schema: Success + 400: + content: + application/json: + schema: Error + """ + streams = self.get_user_streams() + + # allow access to public data only by default + selector = {1} + + for stream in streams: + if "ztf" in stream.name.lower(): + selector.update(set(stream.altdata.get("selector", []))) + + selector = list(selector) + + data = self.get_json() + candid = data.get("candid", None) + group_ids = data.pop("group_ids", None) + + try: + query = { + "query_type": "aggregate", + "query": { + "catalog": "ZTF_alerts_aux", + "pipeline": [ + { + "$match": { + "_id": objectId + } + }, + { + "$project": { + "_id": 1, + "cross_matches": 1, + "prv_candidates": { + "$filter": { + "input": "$prv_candidates", + "as": "item", + "cond": { + "$in": [ + "$$item.programid", + selector + ] + } + } + }, + } + }, + { + "$project": { + "_id": 1, + "prv_candidates.magpsf": 1, + "prv_candidates.sigmapsf": 1, + "prv_candidates.diffmaglim": 1, + "prv_candidates.programid": 1, + "prv_candidates.fid": 1, + "prv_candidates.rb": 1, + "prv_candidates.ra": 1, + "prv_candidates.dec": 1, + "prv_candidates.candid": 1, + "prv_candidates.jd": 1, + } + } + ] + } + } + + resp = await self.query_kowalski(query=query) + + if resp.code == 200: + alert_data = tornado.escape.json_decode(resp.body).get('data', list(dict())) + if len(alert_data) > 0: + alert_data = alert_data[0] + else: + return self.error(f"{objectId} not found on Kowalski") + else: + return self.error(f"Failed to fetch data for {objectId} from Kowalski") + + # grab and append most recent candid as it should not be in prv_candidates + query = { + "query_type": "aggregate", + "query": { + "catalog": "ZTF_alerts", + "pipeline": [ + { + "$match": { + "objectId": objectId, + "candidate.programid": {"$in": selector} + } + }, + { + "$project": { + # grab only what's going to be rendered + "_id": 0, + "candidate.candid": {"$toString": "$candidate.candid"}, + "candidate.programid": 1, + "candidate.jd": 1, + "candidate.fid": 1, + "candidate.rb": 1, + "candidate.drb": 1, + "candidate.ra": 1, + "candidate.dec": 1, + "candidate.magpsf": 1, + "candidate.sigmapsf": 1, + "candidate.diffmaglim": 1, + } + }, + { + "$sort": { + "candidate.jd": -1 + } + }, + { + "$limit": 1 + } + ] + } + } + + resp = await self.query_kowalski(query=query) + + if resp.code == 200: + latest_alert_data = tornado.escape.json_decode(resp.body).get('data', list(dict())) + if len(latest_alert_data) > 0: + latest_alert_data = latest_alert_data[0] + else: + return self.error(f"Failed to fetch data for {objectId} from Kowalski") + + if len(latest_alert_data) > 0: + candids = {a.get('candid', None) for a in alert_data['prv_candidates']} + if latest_alert_data['candidate']["candid"] not in candids: + alert_data['prv_candidates'].append(latest_alert_data['candidate']) + + df = pd.DataFrame.from_records(alert_data["prv_candidates"]) + w = df["candid"] == str(candid) + + if candid is None or sum(w) == 0: + candids = {int(can) for can in df["candid"] if not pd.isnull(can)} + candid = max(candids) + alert = df.loc[df["candid"] == str(candid)].to_dict(orient="records")[0] + else: + alert = df.loc[w].to_dict(orient="records")[0] + + # post source + drb = alert.get('drb') + rb = alert.get('rb') + score = drb if drb is not None and not np.isnan(drb) else rb + alert_thin = { + "id": objectId, + "ra": alert.get('ra'), + "dec": alert.get('dec'), + "score": score, + "altdata": { + "passing_alert_id": candid, + }, + } + + schema = Obj.__schema__() + user_group_ids = [g.id for g in self.associated_user_object.groups if not g.single_user_group] + user_accessible_group_ids = [g.id for g in self.associated_user_object.accessible_groups] + if not user_group_ids: + return self.error( + "You must belong to one or more groups before you can add sources." + ) + if (group_ids is not None) and (len(set(group_ids) - set(user_accessible_group_ids)) > 0): + forbidden_groups = list(set(group_ids) - set(user_accessible_group_ids)) + return self.error( + "Insufficient group access permissions. Not a member of " + f"group IDs: {forbidden_groups}." + ) + try: + group_ids = [ + int(_id) + for _id in group_ids + if int(_id) in user_accessible_group_ids + ] + except Exception: + group_ids = user_group_ids + if not group_ids: + return self.error( + "Invalid group_ids field. Please specify at least " + "one valid group ID that you belong to." + ) + try: + obj = schema.load(alert_thin) + except ValidationError as e: + return self.error( + 'Invalid/missing parameters: ' f'{e.normalized_messages()}' + ) + groups = Group.query.filter(Group.id.in_(group_ids)).all() + if not groups: + return self.error( + "Invalid group_ids field. Please specify at least " + "one valid group ID that you belong to." + ) + + # check that all groups have access to same streams as user + for group in groups: + group_streams = ( + DBSession() + .query(Stream) + .join(GroupStream) + .filter(GroupStream.group_id == group.id) + .all() + ) + if group_streams is None: + group_streams = [] + + group_stream_selector = {1} + + for stream in group_streams: + if "ztf" in stream.name.lower(): + group_stream_selector.update(set(stream.altdata.get("selector", []))) + + if not set(selector).issubset(group_stream_selector): + return self.error(f"Cannot save to group {group.name}: " + "insufficient group alert stream permissions") + + DBSession().add(obj) + DBSession().add_all([Source(obj=obj, group=group) for group in groups]) + DBSession().commit() + + # post photometry + ztf_filters = {1: 'ztfg', 2: 'ztfr', 3: 'ztfi'} + df['ztf_filter'] = df['fid'].apply(lambda x: ztf_filters[x]) + df['magsys'] = "ab" + df['mjd'] = df['jd'] - 2400000.5 + + photometry = { + "obj_id": objectId, + "group_ids": group_ids, + "instrument_id": 1, # placeholder + "mjd": df.mjd.tolist(), + "mag": df.magpsf.tolist(), + "magerr": df.sigmapsf.tolist(), + "limiting_mag": df.diffmaglim.tolist(), + "magsys": df.magsys.tolist(), + "filter": df.ztf_filter.tolist(), + "ra": df.ra.tolist(), + "dec": df.dec.tolist(), + } + + photometry_handler = PhotometryHandler(request=self.request, application=self.application) + photometry_handler.request.body = tornado.escape.json_encode(photometry) + try: + photometry_handler.post() + except Exception: + log(f"Failed to post photometry of {objectId}") + # do not return anything yet + self.clear() + + # post cutouts + for ttype, ztftype in [('new', 'Science'), ('ref', 'Template'), ('sub', 'Difference')]: + query = { + "query_type": "find", + "query": { + "catalog": "ZTF_alerts", + "filter": { + "candid": candid, + "candidate.programid": { + "$in": selector + } + }, + "projection": { + "_id": 0, + "objectId": 1, + f"cutout{ztftype}": 1 + } + }, + "kwargs": { + "limit": 1, + } + } + + resp = await self.query_kowalski(query=query) + + if resp.code == 200: + cutout = bj.loads(bj.dumps(tornado.escape.json_decode(resp.body).get('data', list(dict()))[0])) + else: + cutout = dict() + + thumb = make_thumbnail(cutout, ttype, ztftype) + + try: + thumbnail_handler = ThumbnailHandler(request=self.request, application=self.application) + thumbnail_handler.request.body = tornado.escape.json_encode(thumb) + thumbnail_handler.post() + except Exception as e: + log(f"Failed to post thumbnails of {objectId} | {candid}") + log(str(e)) + self.clear() + + # todo: notify Kowalski so that it puts this objectId on tracking list + # (to post new photometry to SP when new alerts arrive) + + self.push_all(action="skyportal/FETCH_SOURCES") + self.push_all(action="skyportal/FETCH_RECENT_SOURCES") + return self.success(data={"id": objectId}) + + except Exception: + _err = traceback.format_exc() + return self.error(f'failure: {_err}') + class ZTFAlertAuxHandler(ZTFAlertHandler): @auth_or_token - def get(self, objectId: str = None): + async def get(self, objectId: str = None): """ --- single: @@ -226,6 +624,8 @@ def get(self, objectId: str = None): "prv_candidates.diffmaglim": 1, "prv_candidates.programid": 1, "prv_candidates.fid": 1, + "prv_candidates.ra": 1, + "prv_candidates.dec": 1, "prv_candidates.candid": 1, "prv_candidates.jd": 1, } @@ -234,19 +634,10 @@ def get(self, objectId: str = None): } } - base_url = f"{self.cfg['app.kowalski.protocol']}://" \ - f"{self.cfg['app.kowalski.host']}:{self.cfg['app.kowalski.port']}" - headers = {"Authorization": f"Bearer {self.cfg['app.kowalski.token']}"} - - resp = s.post( - os.path.join(base_url, 'api/queries'), - json=query, - headers=headers, - timeout=7 - ) + resp = await self.query_kowalski(query=query) - if resp.status_code == requests.codes.ok: - alert_data = bj.loads(resp.text).get('data', list(dict())) + if resp.code == 200: + alert_data = tornado.escape.json_decode(resp.body).get('data', list(dict())) if len(alert_data) > 0: alert_data = alert_data[0] else: @@ -268,13 +659,17 @@ def get(self, objectId: str = None): "$project": { # grab only what's going to be rendered "_id": 0, - "candidate.candid": 1, + "candidate.candid": {"$toString": "$candidate.candid"}, "candidate.programid": 1, "candidate.jd": 1, "candidate.fid": 1, + "candidate.ra": 1, + "candidate.dec": 1, "candidate.magpsf": 1, "candidate.sigmapsf": 1, "candidate.diffmaglim": 1, + "coordinates.l": 1, + "coordinates.b": 1, } }, { @@ -289,15 +684,10 @@ def get(self, objectId: str = None): } } - resp = s.post( - os.path.join(base_url, 'api/queries'), - json=query, - headers=headers, - timeout=7 - ) + resp = await self.query_kowalski(query=query) - if resp.status_code == requests.codes.ok: - latest_alert_data = bj.loads(resp.text).get('data', list(dict())) + if resp.code == 200: + latest_alert_data = tornado.escape.json_decode(resp.body).get('data', list(dict())) if len(latest_alert_data) > 0: latest_alert_data = latest_alert_data[0] else: @@ -317,7 +707,7 @@ def get(self, objectId: str = None): class ZTFAlertCutoutHandler(ZTFAlertHandler): @auth_or_token - def get(self, objectId: str = None): + async def get(self, objectId: str = None): """ --- summary: Serve ZTF alert cutout as fits or png @@ -442,19 +832,10 @@ def get(self, objectId: str = None): } } - base_url = f"{self.cfg['app.kowalski.protocol']}://" \ - f"{self.cfg['app.kowalski.host']}:{self.cfg['app.kowalski.port']}" - headers = {"Authorization": f"Bearer {self.cfg['app.kowalski.token']}"} - - resp = s.post( - os.path.join(base_url, 'api/queries'), - json=query, - headers=headers, - timeout=7 - ) + resp = await self.query_kowalski(query=query) - if resp.status_code == requests.codes.ok: - alert = bj.loads(resp.text).get('data', list(dict()))[0] + if resp.code == 200: + alert = bj.loads(bj.dumps(tornado.escape.json_decode(resp.body).get('data', list(dict()))[0])) else: alert = dict() @@ -507,9 +888,7 @@ def get(self, objectId: str = None): ax.imshow(img, origin='lower', cmap=cmap, vmin=limits[0], vmax=limits[1]) elif scaling == 'arcsinh': ax.imshow(np.arcsinh(img - np.median(img)), cmap=cmap, origin='lower') - plt.savefig(buff, dpi=42) - buff.seek(0) plt.close('all') self.set_header("Content-Type", 'image/png') diff --git a/extensions/skyportal/static/js/components/Filter.jsx b/extensions/skyportal/static/js/components/Filter.jsx index 5d76c3d7..bf736673 100644 --- a/extensions/skyportal/static/js/components/Filter.jsx +++ b/extensions/skyportal/static/js/components/Filter.jsx @@ -117,7 +117,7 @@ const Filter = () => { useEffect(() => { const fetchFilterVersion = async () => { const data = await dispatch(filterVersionActions.fetchFilterVersion(fid)); - if ((data.status === "error") && !(data.message.includes("not found"))) { + if (data.status === "error" && !data.message.includes("not found")) { setFilterVersionLoadError(data.message); if (filterVersionLoadError.length > 1) { dispatch(showNotification(filterVersionLoadError, "error")); @@ -127,7 +127,7 @@ const Filter = () => { if (loadedId !== fid) { fetchFilterVersion(); } - }, [fid, loadedId, dispatch]); + }, [fid, loadedId, dispatch, filterVersionLoadError]); const group_id = useSelector((state) => state.filter.group_id); @@ -142,7 +142,7 @@ const Filter = () => { } }; if (group_id) fetchGroup(); - }, [group_id, dispatch]); + }, [group_id, dispatch, groupLoadError]); const filter = useSelector((state) => state.filter); const filter_v = useSelector((state) => state.filter_v); diff --git a/extensions/skyportal/static/js/components/SaveAlertButton.jsx b/extensions/skyportal/static/js/components/SaveAlertButton.jsx new file mode 100644 index 00000000..02324fb4 --- /dev/null +++ b/extensions/skyportal/static/js/components/SaveAlertButton.jsx @@ -0,0 +1,221 @@ +import React, { useState, useEffect, useRef } from "react"; +import PropTypes from "prop-types"; +import { useDispatch } from "react-redux"; +import Dialog from "@material-ui/core/Dialog"; +import DialogContent from "@material-ui/core/DialogContent"; +import DialogTitle from "@material-ui/core/DialogTitle"; +import Checkbox from "@material-ui/core/Checkbox"; +import Button from "@material-ui/core/Button"; +import ButtonGroup from "@material-ui/core/ButtonGroup"; +import FormControlLabel from "@material-ui/core/FormControlLabel"; +import ArrowDropDownIcon from "@material-ui/icons/ArrowDropDown"; +import ClickAwayListener from "@material-ui/core/ClickAwayListener"; +import Grow from "@material-ui/core/Grow"; +import Paper from "@material-ui/core/Paper"; +import Popper from "@material-ui/core/Popper"; +import MenuItem from "@material-ui/core/MenuItem"; +import MenuList from "@material-ui/core/MenuList"; +import { useForm, Controller } from "react-hook-form"; + +import * as alertActions from "../ducks/alert"; +import * as sourceActions from "../ducks/source"; +import FormValidationError from "./FormValidationError"; + +const SaveAlertButton = ({ alert, userGroups }) => { + const [isSubmitting, setIsSubmitting] = useState(false); + // Dialog logic: + + const dispatch = useDispatch(); + const [dialogOpen, setDialogOpen] = useState(false); + + const { handleSubmit, errors, reset, control, getValues } = useForm(); + + useEffect(() => { + reset({ + group_ids: [] + }); + }, [reset, userGroups, alert]); + + const handleClickOpenDialog = () => { + setDialogOpen(true); + }; + + const handleCloseDialog = () => { + setDialogOpen(false); + }; + + const validateGroups = () => { + const formState = getValues({ nest: true }); + return formState.group_ids.filter((value) => Boolean(value)).length >= 1; + }; + + const onSubmitGroupSelectSave = async (data) => { + setIsSubmitting(true); + data.id = alert.id; + const groupIDs = userGroups.map((g) => g.id); + const selectedGroupIDs = groupIDs.filter((ID, idx) => data.group_ids[idx]); + + data.payload = {candid: alert.candid, group_ids: selectedGroupIDs}; + + const result = await dispatch(alertActions.saveAlertAsSource(data)); + if (result.status === "error") { + setIsSubmitting(false); + } else { + setDialogOpen(false); + reset(); + await dispatch(sourceActions.fetchSource(alert.id)); + } + }; + + // Split button logic (largely copied from + // https://material-ui.com/components/button-group/#split-button): + + const options = ["Select groups & save as a source"]; + // const options = ["Select groups & save as a source", "Select filters & save as a candidate"]; + + const [splitButtonMenuOpen, setSplitButtonMenuOpen] = useState(false); + const anchorRef = useRef(null); + const [selectedIndex, setSelectedIndex] = useState(0); + + const handleClickMainButton = async () => { + if (selectedIndex === 0) { + handleClickOpenDialog(); + } + }; + + const handleMenuItemClick = (event, index) => { + setSelectedIndex(index); + setSplitButtonMenuOpen(false); + }; + + const handleToggleSplitButtonMenu = () => { + setSplitButtonMenuOpen((prevOpen) => !prevOpen); + }; + + const handleCloseSplitButtonMenu = (event) => { + if (anchorRef.current && anchorRef.current.contains(event.target)) { + return; + } + setSplitButtonMenuOpen(false); + }; + + return ( +