In [183]:
import pathlib

from heliostack.image import Image

import torch
import kornia

from astropy.io import fits
from astropy.wcs import WCS
from astropy.time import Time

import warnings
warnings.filterwarnings('ignore')

import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from scipy.ndimage import distance_transform_edt, center_of_mass, label, map_coordinates, convolve
import numpy as np

# Run tensor work on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def stack(images: "list[Image]", shape=(2500, 5000), dtype=torch.half, device=None):
    """Combine images on a fixed-size canvas as sum(image) / sum(weight).

    Each image's ``image`` and ``weight`` arrays are added into the top-left
    corner of a ``shape`` canvas. The two running sums are then divided pixelwise.

    Args:
        images: objects exposing 2-D ``image`` and ``weight`` tensors of equal shape.
        shape: (rows, cols) of the output canvas. Defaults to (2500, 5000).
        dtype: dtype of the returned tensor. Defaults to ``torch.half``. The sums
            themselves are always accumulated in float32.
        device: device for the canvas. Defaults to CUDA when available, else CPU.

    Returns:
        A ``shape`` tensor of ``dtype``. Pixels whose summed weight is zero
        (including canvas area no image covers) come out as NaN (0/0) or +/-inf.

    Raises:
        ValueError: if an image does not fit inside ``shape``.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    n_rows, n_cols = shape

    # Accumulate in float32: float16 overflows to inf above 65504 and loses
    # precision quickly when many frames are summed.
    phi = torch.zeros(n_rows, n_cols, device=device, dtype=torch.float32)
    psi = torch.zeros_like(phi)

    for im in images:
        rows, cols = im.image.shape
        if rows > n_rows or cols > n_cols:
            raise ValueError(
                f"image of shape {(rows, cols)} does not fit in stack canvas {tuple(shape)}"
            )
        phi[:rows, :cols] += im.image
        psi[:rows, :cols] += im.weight

    return (phi / psi).to(dtype)

### Read in all of the images

In [132]:
EDGE_CUT_PIXELS = 250
def ingest_image(image_path, weight_path, device, edge_cut=EDGE_CUT_PIXELS) -> Image:
    """Load a difference image and its weight map into a heliostack ``Image``.

    Reads HDU 1 of both files and trims ``edge_cut`` pixels from every border
    of both arrays. Attaches the WCS and the ``MJD-OBS`` epoch (UTC) taken from
    the image header.

    Args:
        image_path: path to the image FITS file.
        weight_path: path to the matching weight FITS file.
        device: passed through to ``Image`` as its ``device`` argument.
        edge_cut: pixels to drop from each edge. Defaults to ``EDGE_CUT_PIXELS``.
            0 keeps the full array.

    Returns:
        The constructed ``Image``.

    Raises:
        ValueError: if ``edge_cut`` is negative.
    """
    if edge_cut < 0:
        raise ValueError(f"edge_cut must be non-negative, got {edge_cut}")
    # ``x[0:-0]`` is empty, so use None as the stop when no trimming is requested.
    trim = slice(edge_cut, -edge_cut or None)

    # Context managers close both file handles, which matters when ingesting
    # many exposures in a loop.
    with fits.open(image_path) as im, fits.open(weight_path) as wt:
        header = im[1].header
        wcs = WCS(header)
        epoch = Time(header['MJD-OBS'], format='mjd', scale='utc')
        return Image(im[1].data[trim, trim],
                     wt[1].data[trim, trim],
                     wcs,
                     epoch,
                     device=device)

In [133]:
root = '/nfs/deep/diffim_distant/B1d/20201019/CCD1'
IMAGE_SUFFIX = '.diff.rescale.fits.fz'
WEIGHT_SUFFIX = '.diff.weight.fits.fz'

# Path.glob yields files in arbitrary filesystem order, so zipping two
# independent globs can pair an image with the wrong weight map. Instead, pair
# files explicitly by the filename prefix that precedes each suffix.
root_dir = pathlib.Path(root)
image_by_prefix = {p.name[:-len(IMAGE_SUFFIX)]: p for p in root_dir.glob('*' + IMAGE_SUFFIX)}
weight_by_prefix = {p.name[:-len(WEIGHT_SUFFIX)]: p for p in root_dir.glob('*' + WEIGHT_SUFFIX)}
assert image_by_prefix.keys() == weight_by_prefix.keys(), (
    'image and weight files in root do not match one-to-one')

prefixes = sorted(image_by_prefix)  # deterministic order across runs
image_paths = [image_by_prefix[p] for p in prefixes]
weight_paths = [weight_by_prefix[p] for p in prefixes]

# read in the images
images = [ingest_image(i, w, device) for i, w in zip(image_paths, weight_paths)]

In [None]:
def find_peaks(image, threshold, min_size=2, max_size=400):
    """Find connected regions above ``threshold`` and return their centroids.

    Pixels strictly greater than ``threshold`` are grouped into connected
    regions with ``scipy.ndimage.label`` (default cross-shaped connectivity).
    Only regions with ``min_size`` to ``max_size`` pixels (inclusive) are kept.

    Args:
        image: 2-D array to search. Presumably an SNR map, given the returned
            name ``snr`` (TODO confirm).
        threshold: detection threshold. Pixels > threshold are candidate pixels.
        min_size: inclusive lower bound on region pixel count.
        max_size: inclusive upper bound on region pixel count.

    Returns:
        ``(xs, ys, snr)``:
        - ``xs``, ``ys``: column and row centroids, weighted by ``image``
          (via ``scipy.ndimage.center_of_mass``).
        - ``snr``: the ``image`` value at the pixel nearest each centroid.
        Returns ``(None, None, None)`` when no region qualifies.
    """
    image = np.asarray(image)
    peaks = image > threshold
    peak_labels, _ = label(peaks)
    labels, counts = np.unique(peak_labels, return_counts=True)

    # Label 0 is the background (pixels at or below threshold), never a peak.
    keep = (labels != 0) & (counts >= min_size) & (counts <= max_size)
    l = labels[keep]

    if l.size == 0:
        return None, None, None

    ys, xs = np.array(center_of_mass(image, peak_labels, l)).T
    # Round to the nearest pixel (truncation biases toward lower indices).
    # Clip in case mixed-sign weights push a centroid off the image.
    row_idx = np.clip(np.rint(ys).astype(int), 0, image.shape[0] - 1)
    col_idx = np.clip(np.rint(xs).astype(int), 0, image.shape[1] - 1)
    snr = image[row_idx, col_idx]
    return xs, ys, snr

In [190]:
import sep

### Calculate the stack rates

In [224]:
# use the grid
# TODO: this cell's code is missing; only this comment and the stale timing output below remain.

CPU times: user 150 ms, sys: 8.94 ms, total: 159 ms
Wall time: 158 ms


### Stack at each rate, and extract sources

In [228]:
%%time
for _ in range(10):
    s = stack(images)
    sources = sep.extract(np.float32(s.cpu().numpy()), 2, minarea=2)

CPU times: user 1.51 s, sys: 10 µs, total: 1.51 s
Wall time: 1.51 s
