In [None]:
import numpy as np
from astropy.io import fits
import astropy.units as u
import warnings
from scipy.interpolate import interp1d
from astropy.modeling import models, fitting
from specutils import Spectrum1D
from specutils.fitting import fit_continuum
from scipy.signal import correlate

def _ccf_peak_lag(reference, shifted):
    """Return the lag, in pixels, at which `shifted` best matches `reference`.

    Both inputs must be 1-D arrays sampled on the same wavelength grid.
    A positive lag means `shifted` is displaced towards higher pixel index
    than `reference` (scipy ``correlate(shifted, reference, mode="full")``
    convention, with lags running from ``-len(reference) + 1`` to
    ``len(shifted) - 1``). The integer peak is refined with a three-point
    parabolic fit so that sub-pixel shifts can be recovered.
    """
    # Subtract the mean so the continuum level does not create a triangular
    # overlap envelope that drags the peak towards zero lag.
    ref = np.asarray(reference, dtype=float) - np.mean(reference)
    sft = np.asarray(shifted, dtype=float) - np.mean(shifted)
    ccf = correlate(sft, ref, mode="full")
    lags = np.arange(-len(ref) + 1, len(sft))
    k = int(np.argmax(ccf))
    lag = float(lags[k])
    if 0 < k < len(ccf) - 1:
        y_prev, y_peak, y_next = ccf[k - 1], ccf[k], ccf[k + 1]
        curvature = y_prev - 2.0 * y_peak + y_next
        if curvature < 0:  # genuine local maximum: move to the parabola vertex
            lag += 0.5 * (y_prev - y_next) / curvature
    return lag


def _plot_line_comparison(name, wavelength, flux_ref, flux_shifted, radvel_kms, wlmin, wlmax):
    """Overlay the original and the artificially shifted spectrum for one line window."""
    plt.figure(figsize=(8, 5))
    plt.plot(wavelength, flux_ref, label="Original spectrum line", color="darkgrey")
    plt.plot(
        wavelength,
        flux_shifted,
        label=f"Spectrum line with an artificial {radvel_kms:g} km/s radial velocity shift",
        color="indianred",
    )
    plt.title(name)
    plt.xlabel("Wavelength (Å)")
    plt.ylabel("Flux")
    # Zoom on the window actually being correlated (previously fixed to Line 1).
    plt.xlim(wlmin, wlmax)
    plt.ylim(bottom=0)
    plt.legend(loc="lower left", frameon=False)
    plt.tight_layout()
    plt.show()


def measure_radvel_check(file1, file2, radvel):
    """Check that an injected radial-velocity shift is recovered by cross-correlation.

    The spectrum in `file2` is Doppler-shifted by `radvel`, resampled onto the
    wavelength grid of `file1`, and cross-correlated with `file1` inside three
    40 Å windows centred on 4388, 4922 and 5016 Å. The velocities measured in
    the windows are averaged and compared with the injected value. A
    diagnostic plot is shown for every window.

    Parameters
    ----------
    file1, file2 : path-like
        Spectra passed to ``read_fits_spectrum`` (defined elsewhere), which is
        expected to return ``(wavelength, flux)`` as astropy Quantities.
    radvel : astropy Quantity
        Injected radial velocity, in any velocity unit.

    Returns
    -------
    vrad : float
        Mean recovered velocity in km/s (also printed).
    error : float
        Percentage deviation ``(radvel - vrad) / radvel * 100`` (also
        printed); NaN when `radvel` is zero.

    Raises
    ------
    ValueError
        If no line window contains enough pixels to be measured.
    """
    c = 299792.458 * u.km / u.s
    radvel_kms = radvel.to_value(u.km / u.s)
    lines = {
        "Line 1": {"c": 4388 * u.AA, "wlmin": 4368 * u.AA, "wlmax": 4408 * u.AA},
        "Line 2": {"c": 4922 * u.AA, "wlmin": 4902 * u.AA, "wlmax": 4942 * u.AA},
        "Line 3": {"c": 5016 * u.AA, "wlmin": 4996 * u.AA, "wlmax": 5036 * u.AA},
    }

    # The spectra do not depend on the line being measured: read them once.
    wl1, flux1 = read_fits_spectrum(file1)
    wl2, flux2 = read_fits_spectrum(file2)

    vrads = []
    for name, line in lines.items():
        mask1 = (wl1 > line["wlmin"]) & (wl1 < line["wlmax"])
        mask2 = (wl2 > line["wlmin"]) & (wl2 < line["wlmax"])
        wl1_line = wl1[mask1].to_value(u.AA)
        wl2_line = wl2[mask2].to_value(u.AA)
        if wl1_line.size < 3 or wl2_line.size < 2:
            warnings.warn(
                f"{name}: not enough pixels between {line['wlmin']} and "
                f"{line['wlmax']}; line skipped."
            )
            continue
        f1 = flux1[mask1].value
        f2 = flux2[mask2].value

        # Doppler offset of this line in Å (unit-safe for any radvel unit).
        # f2_shifted(λ) = f2(λ - Δλ) moves the line redward for positive radvel.
        # Resample onto spectrum 1's grid so both arrays share one pixel
        # scale; np.interp holds the edge values instead of padding with 0.
        dlam = (line["c"] * radvel / c).to_value(u.AA)
        order = np.argsort(wl2_line)
        f2_shifted = np.interp(wl1_line, wl2_line[order] + dlam, f2[order])

        _plot_line_comparison(
            name,
            wl1_line,
            f1,
            f2_shifted,
            radvel_kms,
            line["wlmin"].to_value(u.AA),
            line["wlmax"].to_value(u.AA),
        )

        lag = _ccf_peak_lag(f1, f2_shifted)
        # Pixel size taken from the grid both arrays are now sampled on.
        step = np.median(np.diff(wl1_line))
        shift_lambda = lag * step * u.AA
        vrads.append((shift_lambda / line["c"] * c).to_value(u.km / u.s))

    if not vrads:
        raise ValueError("No line window contained enough pixels to measure a velocity.")

    vrad = float(np.mean(vrads))
    error = (radvel_kms - vrad) / radvel_kms * 100 if radvel_kms != 0 else float("nan")
    print(vrad, error)
    return vrad, error