In [1]:
import os
import copy
import wget
import time
import asyncio
import warnings
import logging

import numpy as np
import pandas as pd
from pathlib import Path

from astropy.io import fits
from astropy.modeling import models, fitting

from scipy import ndimage
from scipy.signal import medfilt
from scipy.ndimage.filters import gaussian_filter

from matplotlib import pyplot as plt

from lsst import cwfs
from lsst.cwfs.instrument import Instrument
from lsst.cwfs.algorithm import Algorithm
from lsst.cwfs.image import Image, readFile, aperture2image, showProjection
import lsst.cwfs.plots as plots

plt.rcParams['figure.figsize'] = [7, 6]

%matplotlib inline

In [2]:
# Diagonal 3x3 matrix applied (after rotation) to the 3-vector of selected
# Zernike coefficients to produce `hex_corr`; each component is scaled
# independently because all off-diagonal terms are zero.
sensitivity_matrix = np.diag([-1. / 131., 1. / 131., -1. / 4200.])

In [3]:
def rotation_matrix(angle):
    """Return a 3x3 matrix with a 2-D rotation block for ``angle`` in the top-left.

    ``angle`` is in degrees. The third row/column is the identity, so the
    third component of a vector multiplied by this matrix is left unchanged.
    """
    theta = np.radians(angle)
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    return np.array([
        [cos_t, -sin_t, 0.],
        [sin_t, cos_t, 0.],
        [0., 0., 1.],
    ])

In [None]:
class DonutHandler:
    """Centre, cut and analyse intra/extra-focal donut pairs with cwfs.

    Expected call order (each step rebuilds the per-pair results it produces,
    so re-running a step does not append to stale lists)::

        dh = DonutHandler()
        dh.set_intra_extra(intra_fnames, extra_fnames)
        dh.set_dz(dz_list)
        dh.create_images()
        dh.run_algo()

    Per-pair results, indexed like the input file lists:

    * ``I1`` / ``I2`` -- cwfs ``Image`` objects (``None`` if the pair failed).
    * ``to_process`` -- whether the pair's stamps were cut successfully.
    * ``zern`` -- first 9 entries of the algorithm's ``zer4UpNm`` (zeros if skipped).
    * ``pMask`` -- the algorithm's ``pMask`` (``None`` if skipped).
    """

    def __init__(self):
        self.data_folder = Path("data/")
        self.output_folder = Path("data/output")

        self.intra_fnames = []
        self.extra_fnames = []

        # Image arrays read from the primary HDU of each FITS file.
        self.intra_exposures = []
        self.extra_exposures = []

        # Defocus value per pair, and {unique dz: cwfs instrument config name}.
        self.dz = []
        self.dz_dict = {}

        # Per-pair outputs, filled by create_images() / run_algo().
        self.I1 = []
        self.I2 = []
        self.pMask = []
        self.to_process = []
        self.zern = []

        # Field position given to the cwfs Image; the object is on-axis.
        self.fieldXY = [0.0, 0.0]

        # Start centre (semi-automatic) or fixed centre (manual) in pixels.
        self.ceny = 350
        self.cenx = 450

        # Half-width of the search box used while refining centroids.
        self.pre_side = 300
        # Stamp half-width for dz=1.5; scaled linearly with dz.
        self.side = 192

    @property
    def ndata(self):
        """Number of registered intra/extra pairs."""
        return len(self.intra_fnames)

    def set_intra_extra(self, intra, extra):
        """Register matching intra/extra file lists and load their image data.

        Parameters
        ----------
        intra, extra : sequence of str or Path
            File names relative to ``data_folder``; ``intra[i]`` pairs with
            ``extra[i]``.

        Raises
        ------
        ValueError
            If the two sequences have different lengths.
        """
        if len(intra) != len(extra):
            raise ValueError(
                f"intra and extra must have the same length ({len(intra)} != {len(extra)})."
            )
        self.intra_fnames = list(intra)
        self.extra_fnames = list(extra)
        # Rebuild (rather than append to) the exposure lists so they stay
        # aligned with the file lists if this method is called again.
        self.intra_exposures = [self._read_primary(name) for name in self.intra_fnames]
        self.extra_exposures = [self._read_primary(name) for name in self.extra_fnames]

    def _read_primary(self, fname):
        """Return the primary-HDU data of ``data_folder / fname`` and close the file."""
        # memmap=False reads the array into memory, so it remains valid after
        # the file handle is closed by the context manager.
        with fits.open(self.data_folder / fname, memmap=False) as hdul:
            return hdul[0].data

    def set_dz(self, dz_list):
        """Store per-pair defocus values and write a cwfs instrument file per unique value.

        For the k-th unique value, the file
        ``<cwfs package>.parents[3]/data/auxtel_<k>/auxtel_<k>.param`` is
        (over)written with ``Offset (m) = dz * 0.041``, and ``dz_dict[dz]`` is
        set to ``"auxtel_<k>"``.

        Raises
        ------
        RuntimeError
            If ``len(dz_list)`` differs from the number of registered pairs.
        """
        if len(dz_list) != self.ndata:
            raise RuntimeError("dz_list and ndata must have the same length.")

        self.dz = np.copy(dz_list)
        cwfs_config_template = """#Auxiliary Telescope parameters:
Obscuration                             0.3525
Focal_length (m)                        21.6
Aperture_diameter (m)                   1.2
Offset (m)                              {}
Pixel_size (m)                          1.44e-5
"""
        # Rebuilt from scratch: config names are reused per call, so stale
        # entries would point at files that now hold a different offset.
        self.dz_dict = {}
        config_root = Path(cwfs.__file__).resolve().parents[3].joinpath("data")
        for k, dz in enumerate(np.unique(self.dz)):
            config_index = f"auxtel_{k}"
            path = config_root.joinpath(config_index)
            path.mkdir(parents=True, exist_ok=True)
            # 0.041 presumably = 41 (hexapod-to-focus scale, see run_algo) * 1e-3 (mm -> m).
            with open(path.joinpath(f"{config_index}.param"), "w") as fp:
                fp.write(cwfs_config_template.format(dz * 0.041))
            self.dz_dict[dz] = config_index

    @staticmethod
    def _binary_centroid(stamp):
        """Return the integer (y, x) centre of mass of the above-mean pixels of ``stamp``.

        The stamp is 3x3 median filtered (e.g. to suppress isolated hot
        pixels). Pixels brighter than the filtered mean get weight 1 and all
        others weight 0, so the centroid follows the donut's footprint rather
        than its brightness distribution.

        Raises
        ------
        ValueError
            If the stamp is empty or no pixel lies above the mean.
        """
        stamp = np.asarray(stamp, dtype=float)
        if stamp.size == 0:
            raise ValueError("Empty search box; the donut centre is outside the image.")
        filtered = medfilt(stamp, [3, 3])
        mask = filtered > filtered.mean()
        if not mask.any():
            raise ValueError("No pixel above the mean; cannot locate the donut.")
        cy, cx = ndimage.center_of_mass(mask.astype(float))
        return int(cy), int(cx)

    def _refine_center(self, image, ceny, cenx):
        """Re-centre on the donut footprint inside a box around (ceny, cenx).

        The box has half-width ``self.pre_side`` and is clipped to the image
        bounds. Returns the refined integer (y, x) in full-image coordinates.
        """
        y0 = max(ceny - self.pre_side, 0)
        x0 = max(cenx - self.pre_side, 0)
        box = image[y0:ceny + self.pre_side, x0:cenx + self.pre_side]
        cy, cx = self._binary_centroid(box)
        return y0 + cy, x0 + cx

    @staticmethod
    def _cut_stamp(image, ceny, cenx, half):
        """Return ``image[ceny-half:ceny+half, cenx-half:cenx+half]``.

        Raises
        ------
        ValueError
            If the stamp would extend past the image edge. A negative slice
            start would otherwise silently wrap around to the other side.
        """
        ny, nx = image.shape
        if ceny - half < 0 or cenx - half < 0 or ceny + half > ny or cenx + half > nx:
            raise ValueError(
                f"Stamp of half-width {half} at [y,x] = [{ceny},{cenx}] "
                f"does not fit in image of shape {image.shape}."
            )
        return image[ceny - half:ceny + half, cenx - half:cenx + half]

    def center_and_cut_image(self, index, side=400, semi_auto=False, manual=False):
        """Cut intra- and extra-focal stamps centred on the donuts of pair ``index``.

        Parameters
        ----------
        index : int
            Pair index into ``intra_exposures`` / ``extra_exposures`` / ``dz``.
        side : int, optional
            Ignored (kept for backward compatibility). The stamp half-width is
            always ``int(self.side * self.dz[index] / 1.5)``.
        semi_auto : bool
            Start from (``self.ceny``, ``self.cenx``) instead of the centre of
            mass of the summed pair, then refine as in automatic mode.
        manual : bool
            Use (``self.ceny``, ``self.cenx``) directly for both stamps, with
            no refinement.

        Returns
        -------
        tuple of ndarray
            ``(intra_square, extra_square)``, each ``(2*half, 2*half)``.

        Raises
        ------
        ValueError
            If no donut is found or a stamp does not fit inside the image.
        """
        intra_exp = self.intra_exposures[index]
        extra_exp = self.extra_exposures[index]

        if manual:
            print("Manual centering")
            ceny_intra, cenx_intra = self.ceny, self.cenx
            ceny_extra, cenx_extra = self.ceny, self.cenx
        else:
            # Work in float so the sum and median subtraction cannot overflow
            # or wrap around for unsigned-integer image data.
            combined = np.asarray(intra_exp, dtype=float) + extra_exp
            if semi_auto:
                ceny, cenx = self.ceny, self.cenx
            else:
                print("Automatic centering")
                # Coarse guess: brightness-weighted centre of mass of the
                # summed, median-filtered, median-subtracted pair.
                smoothed = medfilt(combined, [3, 3])
                smoothed -= np.median(smoothed)
                cy, cx = ndimage.center_of_mass(smoothed)
                ceny, cenx = int(cy), int(cx)
            print(ceny, cenx)

            # Refine on the joint footprint of both donuts, then centre each
            # donut individually inside a box around that common centre.
            ceny, cenx = self._refine_center(combined, ceny, cenx)
            print(ceny, cenx)
            ceny_intra, cenx_intra = self._refine_center(intra_exp, ceny, cenx)
            ceny_extra, cenx_extra = self._refine_center(extra_exp, ceny, cenx)

        half = int(self.side * self.dz[index] / 1.5)
        print(f"Side is {half}")
        print(f"Creating stamps of centroid intra [y,x] = [{ceny_intra},{cenx_intra}], "
              f"extra [y,x] = [{ceny_extra},{cenx_extra}] with a half-width of {half} pixels")
        intra_square = self._cut_stamp(intra_exp, ceny_intra, cenx_intra, half)
        extra_square = self._cut_stamp(extra_exp, ceny_extra, cenx_extra, half)
        return intra_square, extra_square

    def create_images(self):
        """Cut stamps for every pair and wrap them as cwfs ``Image`` objects.

        A pair whose centring or cutting raises is reported and recorded with
        ``to_process[i] = False`` and ``None`` images, so indices stay aligned
        with the input file lists.
        """
        # Reset so re-running does not append a second set of results.
        self.I1, self.I2, self.to_process = [], [], []
        for i in range(self.ndata):
            try:
                i1, i2 = self.center_and_cut_image(i)
            except Exception as exc:
                print(f"Could not process pair {i}: {self.intra_fnames[i]} x "
                      f"{self.extra_fnames[i]} ({exc!r}). Consider removing from data")
                self.to_process.append(False)
                self.I1.append(None)
                self.I2.append(None)
            else:
                self.to_process.append(True)
                self.I1.append(Image(i1, self.fieldXY, Image.INTRA))
                self.I2.append(Image(i2, self.fieldXY, Image.EXTRA))

    def run_algo(self):
        """Run the cwfs ``'exp'`` algorithm on every prepared pair.

        For each pair, stores the first 9 entries of ``algo.zer4UpNm`` in
        ``zern`` and ``algo.pMask`` in ``pMask``. Pairs not marked in
        ``to_process`` get ``np.zeros(9)`` and ``None``. Requires ``set_dz``
        and ``create_images`` to have been called first.
        """
        # Reset so re-running does not append a second set of results.
        self.zern, self.pMask = [], []
        hex_to_focus_scale = 41.0  # multiply hexapod dz by this magnification factor
        pixelsize = 3.6e-6 * 4  # 4 is the binning
        # FIXME: put an assertion here and calculate binning above based on change in image size, also pull offset from filename!

        for i in range(self.ndata):
            if not self.to_process[i]:
                print(f"Skipping pair {i}")
                self.zern.append(np.zeros(9))
                self.pMask.append(None)
                continue

            dz = self.dz[i]
            offset = dz * hex_to_focus_scale  # [mm] at the focus for this pair's dz

            print(f"dz:{dz} - {self.dz_dict[dz]}")
            inst = Instrument(self.dz_dict[dz], self.I1[i].sizeinPix)

            print('Offset should be :{} [mm] at the focus, {} [mm] at the hexapod'.format(offset, dz))
            print('Offset in file is :{} [mm] at the focus'.format(1e3 * inst.offset))
            print('pixelSize should be: {}'.format(pixelsize))

            algo = Algorithm('exp', inst, 1)
            algo.runIt(inst, self.I1[i], self.I2[i], 'onAxis')

            self.pMask.append(algo.pMask)
            self.zern.append(algo.zer4UpNm[0:9])


In [None]:
# Create the handler that holds the file lists, stamps and Zernike results for this run.
dh = DonutHandler()

In [None]:
# Inputs for this run: left blank in the notebook, so fill them in before running.
intra_files = []  # TODO: intra-focal FITS file names, relative to data/
extra_files = []  # TODO: matching extra-focal FITS file names
dz_values = [0.8]  # one defocus value per pair
if not intra_files or len(intra_files) != len(extra_files):
    raise ValueError("Set intra_files and extra_files (same, non-zero length) before running this cell.")

# Fresh handler so re-running this cell does not append to previous results.
dh = DonutHandler()
dh.set_intra_extra(intra_files, extra_files)
dh.set_dz(dz_values)
dh.create_images()
dh.run_algo()
print(dh.zern[0])
print("==========================================")
print(f"[{dh.zern[0][3]},{dh.zern[0][4]},{dh.zern[0][0]}]")

In [None]:
# Rotation angle in degrees for rotation_matrix(). It was left blank in the notebook.
ang = None  # TODO: set before running
if ang is None:
    raise ValueError("Set `ang` (degrees) before computing the hexapod correction.")

# Elements 3, 4 and 0 of the first pair's Zernike vector, rotated by `ang`
# and mapped through the sensitivity matrix.
zern = [dh.zern[0][3], dh.zern[0][4], dh.zern[0][0]]
zern_rotated = np.matmul(zern, rotation_matrix(ang))
hex_corr = np.matmul(zern_rotated, sensitivity_matrix)
print("==========================================")
print(zern_rotated)
print("==========================================")
print(f"{hex_corr}")

In [None]:
# Zernike coefficients (index 4 onwards) for each pair, with a +/-50 nm band.
# An explicit figure keeps the plot off axes left over from another cell.
fig, ax = plt.subplots()
zernike_index = np.arange(9) + 4
for coeffs, dz in zip(dh.zern, dh.dz):
    ax.plot(zernike_index, coeffs, 'o-', label=f'{dz}')
for level in (50, -50):
    ax.axhline(level, color='b', linestyle='--')
ax.set_ylabel("Zernike coeff (nm)")
ax.set_xlabel("Zernike index")
ax.grid()
ax.legend();

In [None]:
# Intra/extra stamps of the first pair with the algorithm's pMask contour.
# subplots() always creates a new figure, whereas plt.figure(1) would reuse
# an existing figure number 1.
if dh.I1[0] is None:
    raise ValueError("Pair 0 was not processed; no stamps to show.")
fig, (ax_intra, ax_extra) = plt.subplots(1, 2, figsize=(12, 8))
for ax, image, kind in ((ax_intra, dh.I1[0], "intra"), (ax_extra, dh.I2[0], "extra")):
    ax.set_title(f"defocus {dh.dz[0]} - {kind}")
    ax.imshow(image.image0)
    ax.contour(dh.pMask[0])

In [None]:
# Inputs for this run: left blank in the notebook, so fill them in before running.
intra_files = []  # TODO: intra-focal FITS file names, relative to data/
extra_files = []  # TODO: matching extra-focal FITS file names
dz_values = [0.8]  # one defocus value per pair
ang = None  # TODO: rotation angle in degrees, passed to rotation_matrix()
if not intra_files or len(intra_files) != len(extra_files):
    raise ValueError("Set intra_files and extra_files (same, non-zero length) before running this cell.")
if ang is None:
    raise ValueError("Set `ang` (degrees) before running this cell.")

dh = DonutHandler()
dh.set_intra_extra(intra_files, extra_files)
dh.set_dz(dz_values)
dh.create_images()
dh.run_algo()
print(dh.zern[0])
print("==========================================")
print(f"[{dh.zern[0][3]},{dh.zern[0][4]},{dh.zern[0][0]}]")

# Elements 3, 4 and 0 of the first pair's Zernike vector, rotated by `ang`
# and mapped through the sensitivity matrix.
zern = [dh.zern[0][3], dh.zern[0][4], dh.zern[0][0]]
zern_rotated = np.matmul(zern, rotation_matrix(ang))
hex_corr = np.matmul(zern_rotated, sensitivity_matrix)
print("==========================================")
print(zern_rotated)
print("==========================================")
print(f"{hex_corr}")

In [None]:
# Zernike coefficients (index 4 onwards) for each pair, with a +/-50 nm band.
# An explicit figure keeps the plot off axes left over from another cell.
fig, ax = plt.subplots()
zernike_index = np.arange(9) + 4
for coeffs, dz in zip(dh.zern, dh.dz):
    ax.plot(zernike_index, coeffs, 'o-', label=f'{dz}')
for level in (50, -50):
    ax.axhline(level, color='b', linestyle='--')
ax.set_ylabel("Zernike coeff (nm)")
ax.set_xlabel("Zernike index")
ax.grid()
ax.legend();

In [None]:
# Intra/extra stamps of the first pair with the algorithm's pMask contour.
# subplots() always creates a new figure, whereas plt.figure(1) would reuse
# an existing figure number 1.
if dh.I1[0] is None:
    raise ValueError("Pair 0 was not processed; no stamps to show.")
fig, (ax_intra, ax_extra) = plt.subplots(1, 2, figsize=(12, 8))
for ax, image, kind in ((ax_intra, dh.I1[0], "intra"), (ax_extra, dh.I2[0], "extra")):
    ax.set_title(f"defocus {dh.dz[0]} - {kind}")
    ax.imshow(image.image0)
    ax.contour(dh.pMask[0])

In [None]:
# Inputs for this run: left blank in the notebook, so fill them in before running.
intra_files = []  # TODO: intra-focal FITS file names, relative to data/
extra_files = []  # TODO: matching extra-focal FITS file names
dz_values = [0.8]  # one defocus value per pair
ang = None  # TODO: rotation angle in degrees, passed to rotation_matrix()
if not intra_files or len(intra_files) != len(extra_files):
    raise ValueError("Set intra_files and extra_files (same, non-zero length) before running this cell.")
if ang is None:
    raise ValueError("Set `ang` (degrees) before running this cell.")

dh = DonutHandler()
dh.set_intra_extra(intra_files, extra_files)
dh.set_dz(dz_values)
dh.create_images()
dh.run_algo()
print(dh.zern[0])
print("==========================================")
print(f"[{dh.zern[0][3]},{dh.zern[0][4]},{dh.zern[0][0]}]")

# Elements 3, 4 and 0 of the first pair's Zernike vector, rotated by `ang`
# and mapped through the sensitivity matrix.
zern = [dh.zern[0][3], dh.zern[0][4], dh.zern[0][0]]
zern_rotated = np.matmul(zern, rotation_matrix(ang))
hex_corr = np.matmul(zern_rotated, sensitivity_matrix)
print("==========================================")
print(zern_rotated)
print("==========================================")
print(f"{hex_corr}")

In [None]:
# Zernike coefficients (index 4 onwards) for each pair, with a +/-50 nm band.
# An explicit figure keeps the plot off axes left over from another cell.
fig, ax = plt.subplots()
zernike_index = np.arange(9) + 4
for coeffs, dz in zip(dh.zern, dh.dz):
    ax.plot(zernike_index, coeffs, 'o-', label=f'{dz}')
for level in (50, -50):
    ax.axhline(level, color='b', linestyle='--')
ax.set_ylabel("Zernike coeff (nm)")
ax.set_xlabel("Zernike index")
ax.grid()
ax.legend();

In [None]:
# Intra/extra stamps of the first pair with the algorithm's pMask contour.
# subplots() always creates a new figure, whereas plt.figure(1) would reuse
# an existing figure number 1.
if dh.I1[0] is None:
    raise ValueError("Pair 0 was not processed; no stamps to show.")
fig, (ax_intra, ax_extra) = plt.subplots(1, 2, figsize=(12, 8))
for ax, image, kind in ((ax_intra, dh.I1[0], "intra"), (ax_extra, dh.I2[0], "extra")):
    ax.set_title(f"defocus {dh.dz[0]} - {kind}")
    ax.imshow(image.image0)
    ax.contour(dh.pMask[0])

In [None]:
# Inputs for this run: left blank in the notebook, so fill them in before running.
intra_files = []  # TODO: intra-focal FITS file names, relative to data/
extra_files = []  # TODO: matching extra-focal FITS file names
dz_values = [0.8]  # one defocus value per pair
ang = None  # TODO: rotation angle in degrees, passed to rotation_matrix()
if not intra_files or len(intra_files) != len(extra_files):
    raise ValueError("Set intra_files and extra_files (same, non-zero length) before running this cell.")
if ang is None:
    raise ValueError("Set `ang` (degrees) before running this cell.")

dh = DonutHandler()
dh.set_intra_extra(intra_files, extra_files)
dh.set_dz(dz_values)
dh.create_images()
dh.run_algo()
print(dh.zern[0])
print("==========================================")
print(f"[{dh.zern[0][3]},{dh.zern[0][4]},{dh.zern[0][0]}]")

# Elements 3, 4 and 0 of the first pair's Zernike vector, rotated by `ang`
# and mapped through the sensitivity matrix.
zern = [dh.zern[0][3], dh.zern[0][4], dh.zern[0][0]]
zern_rotated = np.matmul(zern, rotation_matrix(ang))
hex_corr = np.matmul(zern_rotated, sensitivity_matrix)
print("==========================================")
print(zern_rotated)
print("==========================================")
print(f"{hex_corr}")

In [None]:
# Zernike coefficients (index 4 onwards) for each pair, with a +/-50 nm band.
# An explicit figure keeps the plot off axes left over from another cell.
fig, ax = plt.subplots()
zernike_index = np.arange(9) + 4
for coeffs, dz in zip(dh.zern, dh.dz):
    ax.plot(zernike_index, coeffs, 'o-', label=f'{dz}')
for level in (50, -50):
    ax.axhline(level, color='b', linestyle='--')
ax.set_ylabel("Zernike coeff (nm)")
ax.set_xlabel("Zernike index")
ax.grid()
ax.legend();

In [None]:
# Intra/extra stamps of the first pair with the algorithm's pMask contour.
# subplots() always creates a new figure, whereas plt.figure(1) would reuse
# an existing figure number 1.
if dh.I1[0] is None:
    raise ValueError("Pair 0 was not processed; no stamps to show.")
fig, (ax_intra, ax_extra) = plt.subplots(1, 2, figsize=(12, 8))
for ax, image, kind in ((ax_intra, dh.I1[0], "intra"), (ax_extra, dh.I2[0], "extra")):
    ax.set_title(f"defocus {dh.dz[0]} - {kind}")
    ax.imshow(image.image0)
    ax.contour(dh.pMask[0])

# LAST ONE

In [None]:
# Inputs for this run: left blank in the notebook, so fill them in before running.
intra_files = []  # TODO: intra-focal FITS file names, relative to data/
extra_files = []  # TODO: matching extra-focal FITS file names
dz_values = [0.8]  # one defocus value per pair
ang = None  # TODO: rotation angle in degrees, passed to rotation_matrix()
if not intra_files or len(intra_files) != len(extra_files):
    raise ValueError("Set intra_files and extra_files (same, non-zero length) before running this cell.")
if ang is None:
    raise ValueError("Set `ang` (degrees) before running this cell.")

dh = DonutHandler()
dh.set_intra_extra(intra_files, extra_files)
dh.set_dz(dz_values)
dh.create_images()
dh.run_algo()
print(dh.zern[0])
print("==========================================")
print(f"[{dh.zern[0][3]},{dh.zern[0][4]},{dh.zern[0][0]}]")

# Elements 3, 4 and 0 of the first pair's Zernike vector, rotated by `ang`
# and mapped through the sensitivity matrix.
zern = [dh.zern[0][3], dh.zern[0][4], dh.zern[0][0]]
zern_rotated = np.matmul(zern, rotation_matrix(ang))
hex_corr = np.matmul(zern_rotated, sensitivity_matrix)
print("==========================================")
print(zern_rotated)
print("==========================================")
print(f"{hex_corr}")

In [None]:
# Zernike coefficients (index 4 onwards) for each pair, with a +/-50 nm band.
# An explicit figure keeps the plot off axes left over from another cell.
fig, ax = plt.subplots()
zernike_index = np.arange(9) + 4
for coeffs, dz in zip(dh.zern, dh.dz):
    ax.plot(zernike_index, coeffs, 'o-', label=f'{dz}')
for level in (50, -50):
    ax.axhline(level, color='b', linestyle='--')
ax.set_ylabel("Zernike coeff (nm)")
ax.set_xlabel("Zernike index")
ax.grid()
ax.legend();

In [None]:
# Intra/extra stamps of the first pair with the algorithm's pMask contour.
# subplots() always creates a new figure, whereas plt.figure(1) would reuse
# an existing figure number 1.
if dh.I1[0] is None:
    raise ValueError("Pair 0 was not processed; no stamps to show.")
fig, (ax_intra, ax_extra) = plt.subplots(1, 2, figsize=(12, 8))
for ax, image, kind in ((ax_intra, dh.I1[0], "intra"), (ax_extra, dh.I2[0], "extra")):
    ax.set_title(f"defocus {dh.dz[0]} - {kind}")
    ax.imshow(image.image0)
    ax.contour(dh.pMask[0])