In [None]:
import importlib

In [None]:
from pyvista import imred, tv, spectra, stars, slitmask, image
import numpy as np
import pdb
import copy
import matplotlib.pyplot as plt
import os
from astropy.table import vstack
import pandas as pd

In [None]:
# Use these lines if you are running the notebook yourself. Matplotlib
# window will open outside the notebook, which is the desired behavior so
# you can have a single display tool, which you should leave open. Other
# plot windows will also appear outside the notebook, which you can close
# as desired
%matplotlib qt
t=tv.TV()
plotinter=True

# following lines only for fully non-interactive demo of notebook
#%matplotlib inline
#plotinter=False
#t=None

In [None]:
# Name of the directory that contains the images to reduce
indir = 'UT230909'
red = imred.Reducer('KOSMOS', dir=indir)

In [None]:
# Browse the table returned by red.log() inside the notebook (display_length=10)
red.log().show_in_notebook(display_length=10)

In [None]:
# Frame number(s) passed to mkflat to build the flat
flatims = [22]
flat = red.mkflat(flatims, spec=True, display=None, littrow=False)

In [None]:
# Show the flat only when a display tool exists (t is None in the non-interactive setup)
if t is not None:
    t.tv(flat)

In [None]:
# Frame numbers passed to mkbias to build the bias
biastims = [74, 75, 76, 77, 78]
bias = red.mkbias(biastims, display=None)

In [None]:
# Show the bias only when a display tool exists (t is None in the non-interactive setup)
if t is not None:
    t.tv(bias)

In [None]:
# Frame numbers passed to mkdark to build the dark.
# NOTE(review): frame 96 is listed twice and 95 is absent - confirm whether
# 95 was intended; a repeated frame is counted twice in the combination.
darktims=[94,96,96]
dark=red.mkdark(darktims,display=None)

In [None]:
# Show the dark only when a display tool exists (t is None in the non-interactive setup)
if t is not None:
    t.tv(dark)

In [None]:
# Reduce frame 20 with the bias built earlier and crbox='lacosmic';
# display=t may be None in the non-interactive setup
star1 = red.reduce(
    20,
    crbox='lacosmic',
    bias=bias,
    display=t,
)

In [None]:
# Reduce frame 22 (the same frame number used in flatims) on its own; this
# reduced frame, not the mkflat() output, is what findslits is given.
flat1 = red.reduce(22)

In [None]:
# Find the slit edges on the reduced flat frame. findslits is given the
# individually reduced frame (flat1); this does not take the mkflat() output.
trace=spectra.Trace(transpose=True)
# t is None in the non-interactive setup, so only clear the display overlay
# when a display tool actually exists (display=None is passed through below).
if t is not None:
    t.tvclear()
bottom,top = trace.findslits(flat1,display=t,thresh=0.5,sn=True)

In [None]:
# Inspect every attribute currently stored on the trace object
vars(trace)

In [None]:
# Show the rows attribute of the trace; a later cell selects entries from it by position
trace.rows

In [None]:
# Slit-mask .kms file; the targets are read with sort='YMM'
kmsfile1 = 'kms/kosmos.23.seg3g2Copy_2.kms'
targets1 = slitmask.read_kms(kmsfile1, sort='YMM')

In [None]:
# Convert the target table to a pandas DataFrame so it can be browsed and edited
df = targets1.to_pandas()

In [None]:
df # A look at your table in pandas format

In [None]:
# Index labels of the target rows that should be discarded
rows_to_remove = [0, 1, 2, 3, 4, 5, 8, 10, 12, 14, 15]

# Build a new frame without those rows; df itself is left unchanged
df_cleaned = df.drop(index=rows_to_remove)

In [None]:
# Zero-based positions, within trace.rows, of the slits to keep
in_dex = [6, 7, 9, 11, 13]

# This cell overwrites trace.rows with the shortened list. Running it a
# second time (or running it when fewer slits were found than expected) would
# look these positions up in a list that no longer contains them and silently
# drop slits, so stop with an explicit error instead.
if max(in_dex) >= len(trace.rows):
    raise IndexError(
        'trace.rows has only {:d} entries, cannot select positions {}; '
        're-run the findslits cell before filtering again'.format(len(trace.rows), in_dex))

# Keep only the selected slits, preserving their order in trace.rows
filtered_rows = [line for index, line in enumerate(trace.rows) if index in in_dex]

# Replace the original trace.rows with the filtered list
trace.rows = filtered_rows

In [None]:
# Build a fresh Trace from scratch, using the rows currently held in trace.rows
trace1 = spectra.Trace(
    sc0=2048,
    lags=range(-100, 100),
    rows=trace.rows,
    transpose=red.transpose,
    rad=5,
    degree=3,
)  # earlier hand-set row values noted here: [1585,1545],# 1372 #1545-1585
vars(trace1)

In [None]:
# Trace the spectra.
# srow is set by hand; it is a list so that several spectra on one image can
# be traced. Other values noted by the author: 1173,746,954,1041
srow = [1110, 1173, 1372, 1564, 1697]
# Alternatives to a hand-set list:
#srow,ids=trace.findpeak(crstar, thresh=50)  # find peak(s)
# trace.find(star) will find the highest peak by cross-correlation
# trace.find(star,inter=True,display=t)  will let you mark a trace location

# rad sets the width of the trace. The center position is taken to be the
# star position given, as read off DS9.
trace1.trace(
    star1,
    srow,
    skip=10,
    gaussian=True,
    display=t,
    rad=5,
)
vars(trace1)

In [None]:
# Sum the listed frame(s) into a single image.
# Author's note: Frame 15 is He, 16 is Ne, and 17 is Ar
# NOTE(review): only frame 23 is summed here, which the note above does not
# cover - confirm which lamp frame 23 is.
arcs = red.sum([23])
if t is not None:
    t.clear()
    t.tv(arcs)

In [None]:
# 2D extraction of the summed arc image using trace1; the result is iterated
# over slit by slit in the following cells
arcec1=trace1.extract2d(arcs, display=t)

In [None]:
# Copy each target's mask coordinates (XMM, YMM) into the header of its
# extracted arc.
# arcec1 was extracted with trace1, whose rows are only the slits selected by
# position through in_dex, so the target table has to be narrowed to the same
# positions before pairing. Zipping against the full table would hand the
# arcs the coordinates of the first few targets instead of their own.
if len(targets1) == len(bottom) :
    selected_targets = targets1[in_dex]
    if len(selected_targets) == len(arcec1) :
        for arc,target in zip(arcec1,selected_targets) :
            arc.header['XMM'] = target['XMM']
            arc.header['YMM'] = target['YMM']
    else :
        print('ERROR, number of extracted slits does not match number of selected targets')
else :
    print('ERROR, number of identified slits does not match number of targets')

In [None]:
# Wavelength-calibrate every extracted arc slitlet.
# NOTE(review): wav is rebuilt for each slitlet, so once the loop ends only
# the last slitlet's WaveCal object is still bound to the name wav.
for arc in arcec1 :

    wav=spectra.WaveCal('KOSMOS/KOSMOS_red_waves.fits')
    nrow=arc.shape[0] # number of rows in this slitlet (the slit width)
    # get initial guess at shift from reference using XMM (KOSMOS red low!)
    shift=int(arc.header['XMM'])#*-22.5) # 500 #-wav.pix0)
    lags=np.arange(shift-400,shift+400)

    # Identify lines on the central row, repeating for as long as identify()
    # returns a true value. The notebook-level plotinter flag is passed
    # through so the non-interactive setup (plotinter=False) is honored here
    # as well. The loop flag is named redo to avoid shadowing the builtin iter.
    redo = True
    while redo :
        redo = wav.identify(arc[nrow//2],plot=True,plotinter=plotinter,
                            lags=lags,thresh=10,file='new_wave_lamps/new_neon_red_center.dat')
        # after the first pass, search a narrower lag window around zero
        lags=np.arange(-150,150)
        plt.close()

    # Do the 2D wavelength solution, sampling 5 locations across slitlet
    wav.identify(arc,plot=True,nskip=nrow//5,thresh=10)
    plt.close()

In [None]:
# 2D extraction of the reduced star frame using trace1.
# Author's note: the sky window starts right on top of the spectra window.
# Options tried earlier and left out: rad = 3 back=[[9,3],[-9,-3]],
starec2 = trace1.extract2d(star1, display=t)

#plt.figure()
#plt.plot(starec2.data[0])
#plt.plot(starec2.data[0])
#vars(starec2)

In [None]:
# Attach each arc's wavelength array to the matching extracted star slitlet
# and overplot row 10 of every slitlet in one figure.
# t is None in the non-interactive setup, so only clear the display when a
# display tool actually exists.
if t is not None:
    t.clear()
plt.figure()
for i,(o,a) in enumerate(zip(starec2,arcec1)) :
    print(o.shape)
    print(a.wave)
    o.add_wave(a.wave)
    # base name: the FILE header value up to its first "."
    name = o.header["FILE"].split(".")[0]
    print(name)
    #t.tv(o)
    #t.tv(o.wave)
    plt.plot(o.wave[10],o.data[10])
    #o.write(name + "_{:d}.fits".format(i))

In [None]:
#trace2 = spectra.Trace(transpose=False)
#trace2.rows = [0,starec2[i].data.shape[0]]
#trace2.index = [0]

In [None]:
#vars(trace2)

In [None]:
# Extract a 1D spectrum, with sky subtraction, from each 2D slitlet and
# overplot all of them in a single figure.
importlib.reload(spectra)

def _flat_trace_model(row) :
    """Return a trace model that evaluates to `row` for every input x.

    The row is bound when the model is created, so the model does not change
    if the loop variable it came from is reassigned later.
    """
    def model(x) :
        return x*0. + row
    return model

fig=plt.figure()
for i in range(len(starec2)) :
    # Per-slitlet trace covering every row of the slitlet. A separate name is
    # used so the multi-slit trace1 built earlier in the notebook is not
    # overwritten (re-running the extract2d cells would otherwise pick up
    # this single-slit trace).
    slit_trace = spectra.Trace(transpose=False)
    slit_trace.rows = [0,starec2[i].data.shape[0]]
    slit_trace.index = [0]
    peak,ind = slit_trace.findpeak(starec2[i],thresh=10,sort=True)
    if len(peak) > 0:
        slit_trace.model = [_flat_trace_model(peak[0])]
        # This line is important because the sky is extracted here.
        # Each background window is written [start, end] with start < end;
        # the second window was previously [10,5], i.e. reversed.
        spec=slit_trace.extract(starec2[i],rad=5,back=[[-10,-5],[5,10]], display=t)
        plt.figure(fig)
        # wavelengths for the extracted spectrum: the wave rows at the peak position(s)
        spec.wave = starec2[i].wave[peak]
        print(spec.wave[0].shape,spec.data[0].shape)
        plt.plot(spec.wave[0],spec.data[0])
    else :
        print('no peak found for slit: ',i)
    #plt.draw()

In [None]:
import copy
importlib.reload(spectra)
def model(x) :
    return x*0.

fig=plt.figure()
for i in range(len(starec2)) :
    trace3 = spectra.Trace(transpose=False)
    trace3.rows = [0,starec2[i].data.shape[0]]
    trace3.index = [0]
    peak,ind = trace3.findpeak(starec2[i],thresh=10,sort=True)
    if len(peak) > 0:
        def model(x) :
            return x*0. + peak[0]
        trace3.model = [model]
        spec=trace3.extract(starec2[i],rad=4, display=None) #,display=t) # back=[[-10,-5],[5,10]
        plt.figure(fig)
        spec.wave = starec2[i].wave[peak]
        swav=copy.deepcopy(wav)
        swav.skyline( spec, thresh=0.5 , linear= False, inter=plotinter, file='pyvista/data/sky/skyline.dat')
        print(wav.model)
        print(swav.model)
        
        #print(spec.model)
        #print(spec.wave[0].shape,spec.data[0].shape)
        #plt.plot(spec.wave[0],spec.data[0])
    else :
        print('no peak found for slit: ',i)
    #plt.draw()