In [1]:
from astropy.io import fits
from astropy.visualization import ZScaleInterval, ImageNormalize
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np
from scipy import ndimage

import cProfile, pstats, io
import glob
import os
import re
import shutil
import time

In [2]:
class DESIImage():
    def __init__(self, pix, mask, ivar):
        self.pix = pix
        self.mask = mask
        self.ivar = ivar
        self.camera = "R3"

In [3]:
ROOT = os.curdir

In [4]:
def get_images(image_names):
    big = []
    big_mask = []
    
    # Loop over the list of image names and extract all the full names
    for n in image_names:
        # Checks to see if we compressed the files or not
        if len(n) > 6:
            image_name = os.path.join(ROOT, "data", f"{n}.fits")
        else:
            image_name = os.path.join(ROOT, "data", f"{n}.fits.fz")
        hdus = fits.open(image_name)
        
        # Loads the image, then loads the corresponding CR mask.
        working = hdus["IMAGE"].data
        working_mask = hdus["MASK"].data
        
        pretty_disp(working)
        
        # Loops over the image and slices it up into 256x256 squares.
        # This is the same slicing code I wrote that reproduced the deepCR hubble results
        for i in range(128, working.shape[0] - 128 - 255, 256):
            for j in range(128, working.shape[1] - 128 - 255, 256):
                big.append(working[i:i+256, j:j+256].reshape(1, 256, 256))
                big_mask.append(working_mask[i:i+256, j:j+256].reshape(1, 256, 256))
                
        for i in range(0, working.shape[0] - 255, 256):
            for j in range(0, working.shape[1] - 255, 256):
                big.append(working[i:i+256, j:j+256].reshape(1, 256, 256))
                big_mask.append(working_mask[i:i+256, j:j+256].reshape(1, 256, 256))

            
    # Concatenates the images into nice numpy arrays
    big = np.concatenate(big)
    big_mask = np.concatenate(big_mask)

    return (big, big_mask)

def pretty_disp(img):
    zscale = ZScaleInterval()
    norm = ImageNormalize(img, zscale)

    fig, ax = plt.subplots(1, 1, figsize=(14, 28), tight_layout=True)
    ax.imshow(img, cmap='gray', origin="lower", interpolation="none", norm=norm)

In [5]:
from deepCR import deepCR
from deepCR import train
from deepCR import roc

In [6]:
mdl = deepCR(mask="2021-01-15_manta_spectro_epoch60.pth", device='GPU', hidden=32)

n = "10-R9.fits.fz"
hdus = fits.open(os.path.join(os.curdir, "data", n))

working = hdus["IMAGE"].data
truth_mask = hdus["MASK"].data

# m = mdl.clean(working, inpaint=False)

# Enable Profiler
pr = cProfile.Profile()
pr.enable()

# m = mdl.clean(working, inpaint=False)#, segment=True)

m = np.zeros_like(working)

delta_x = m.shape[1] // 4
delta_y = m.shape[0] // 4

for x in np.arange(0, m.shape[1], delta_x):
    for y in np.arange(0, m.shape[0], delta_y):
        m_small = mdl.clean(working[y:y + delta_y, x:x + delta_x], inpaint=False)
        m[y:y + delta_y, x:x+delta_x] = m_small

# Complete profiling run
pr.disable()
s = io.StringIO()
ps = pstats.Stats(pr, stream=s).strip_dirs()
ps.sort_stats("tottime")
ps.print_stats()
print(s.getvalue())

hdus.close()

         29017 function calls (22655 primitive calls) in 2.127 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       20    1.427    0.071    1.427    0.071 {method 'cpu' of 'torch._C._TensorBase' objects}
      140    0.502    0.004    0.502    0.004 {built-in method conv2d}
       40    0.040    0.001    0.040    0.001 {method 'astype' of 'numpy.ndarray' objects}
    83/21    0.036    0.000    0.046    0.002 {built-in method numpy.core._multiarray_umath.implement_array_function}
        1    0.031    0.031    2.090    2.090 <ipython-input-6-acacc48dde83>:22(<module>)
       40    0.017    0.000    0.017    0.000 {method 'type' of 'torch._C._TensorBase' objects}
       20    0.012    0.001    0.012    0.001 {method 'copy' of 'numpy.ndarray' objects}
       20    0.009    0.000    2.059    0.103 model.py:91(clean)
      120    0.008    0.000    0.008    0.000 {built-in method batch_norm}
       20    0.006    0.000    0.006