In [268]:
#!pip install --upgrade pip
#import warnings
#warnings.filterwarnings('ignore')

from fastai.vision.all import *
from fastai.vision.widgets import *

#monitoring wandb
!pip install -qqq wandb
import wandb
from fastai.callback.wandb import *

In [295]:
!pip install --quiet --upgrade specutils
!pip install --quiet --upgrade astropy 
!pip install --quiet --upgrade pyts

#Maths and visualisation
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
from pyts.image import GramianAngularField
#from pyts.datasets import load_gunpoint

# astronomy import
from specutils import Spectrum1D, SpectralRegion
from specutils.manipulation import extract_region
from astropy.io import fits
from astropy.units import Quantity
from astropy.visualization import quantity_support
quantity_support()  # for getting units on the axes below  
from astropy.time import Time
from astropy import units as u
from astropy.table import Table
from astropy import units as u

import os

In [325]:
import astropy.wcs as fitswcs #wcs
from specutils import Spectrum1D, SpectralRegion #spectrum1D (specutils)
from astropy.wcs import WCS

def generate_spec1Ds_bess(fits_file_path, halpha_min=6525, halpha_max=6595):
    """
    Build two Spectrum1D objects from a BeSS VOTable FITS file: the entire
    spectrum and a zoom on the H-alpha line (6525 <> 6595 A by default).

    Parameters
    ----------
    fits_file_path : str
        The path of the FITS spectrum file.
    halpha_min, halpha_max : float, optional
        Bounds of the H-alpha zoom window, in Angstrom (default 6525 and 6595).

    Returns
    -------
    Spectrum1D
        The Spectrum1D object of the entire spectrum.
    Spectrum1D
        The Spectrum1D object of the H-alpha zoom.
    """
    # The context manager closes the file even if reading or extraction fails.
    with fits.open(fits_file_path) as hdul:
        # The spectrum is read as a table from HDU 1, using its 'WAVE' and
        # 'FLUX' columns.
        evt_data = Table(hdul[1].data)
        # NOTE(review): flux is tagged as Jy whatever unit the file really uses.
        spec1D_global = Spectrum1D(spectral_axis=evt_data['WAVE'] * u.AA,
                                   flux=evt_data['FLUX'] * u.Jy)

        # Spectral region for the H-alpha line zoom
        sr = SpectralRegion(halpha_min * u.AA, halpha_max * u.AA)

        # New spectrum restricted to the selected region (used for plot and GAF)
        sub_spectrum = extract_region(spec1D_global, sr)
        spec1D_Ha = Spectrum1D(flux=sub_spectrum.flux,
                               spectral_axis=sub_spectrum.spectral_axis)
    return spec1D_global, spec1D_Ha



def generate_spec1Ds(fits_file_path, halpha_min=6525, halpha_max=6595):
    """
    Build two Spectrum1D objects from a FITS spectrum whose flux is stored in
    the primary HDU: the entire spectrum and a zoom on the H-alpha line
    (6525 <> 6595 A by default).

    Parameters
    ----------
    fits_file_path : str
        The path of the FITS spectrum file.
    halpha_min, halpha_max : float, optional
        Bounds of the H-alpha zoom window, in Angstrom (default 6525 and 6595).

    Returns
    -------
    Spectrum1D
        The Spectrum1D object of the entire spectrum.
    Spectrum1D
        The Spectrum1D object of the H-alpha zoom.

    Raises
    ------
    KeyError
        If one of the CDELT1/CRVAL1/CUNIT1/CTYPE1/CRPIX1 header keywords is missing.
    """
    # The context manager closes the file even if a header keyword is missing
    # or the region extraction fails.
    with fits.open(fits_file_path) as hdul:
        # Flux array and header come from the primary HDU
        specdata = hdul[0].data
        header = hdul[0].header

        # Build a 1-D WCS from just the axis-1 keywords of the header
        wcs_data = fitswcs.WCS(header={'CDELT1': header['CDELT1'], 'CRVAL1': header['CRVAL1'],
                                       'CUNIT1': header['CUNIT1'], 'CTYPE1': header['CTYPE1'],
                                       'CRPIX1': header['CRPIX1']})

        # NOTE(review): flux is tagged as Jy whatever unit the file really uses.
        flux = specdata * u.Jy
        spec1D_global = Spectrum1D(wcs=wcs_data, flux=flux)

        # Spectral region for the H-alpha line zoom
        sr = SpectralRegion(halpha_min * u.AA, halpha_max * u.AA)

        # New spectrum restricted to the selected region (used for plot and GAF)
        sub_spectrum = extract_region(spec1D_global, sr)
        spec1D_Ha = Spectrum1D(flux=sub_spectrum.flux,
                               spectral_axis=sub_spectrum.spectral_axis)
    return spec1D_global, spec1D_Ha



def spec_plot(spec1D_to_plot):
    """
    Show a quick single-panel Matplotlib plot of a Spectrum1D.

    Parameters
    ----------
    spec1D_to_plot : Spectrum1D
        The spectrum to draw (flux against its spectral axis).
    """
    wavelengths = spec1D_to_plot.spectral_axis
    fluxes = spec1D_to_plot.flux

    fig, ax = plt.subplots(figsize=(9, 6))
    ax.plot(wavelengths, fluxes)
    ax.set_title("Global Spectrum")
    ax.set_ylabel('Flux')

    fig.tight_layout()
    plt.show()


def spec_plots(full_spec1D, region_spec1D, png_path='tmp/spec_plots.png'):
    """
    Show two stacked Matplotlib plots (full spectrum and H-alpha zoom) and
    save the figure as a PNG.

    Parameters
    ----------
    full_spec1D : Spectrum1D
        The full spectrum, drawn in the top panel.
    region_spec1D : Spectrum1D
        The H-alpha zoom spectrum, drawn in the bottom panel.
    png_path : str or Path, optional
        Where the figure is saved (default 'tmp/spec_plots.png'). The parent
        directory is created if it does not exist.
    """
    fig, axs = plt.subplots(2, 1, figsize=(16, 9))

    panels = ((full_spec1D, "Global Spectrum"), (region_spec1D, "Halpha Crop"))
    for ax, (spec, title) in zip(axs, panels):
        ax.plot(spec.spectral_axis, spec.flux)
        ax.set_ylabel('Flux')
        ax.set_title(title)

    fig.tight_layout()

    # savefig fails if the target directory is missing.
    parent_dir = os.path.dirname(os.fspath(png_path))
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    # Save before show(): some backends release the figure on show(), which
    # can produce an empty PNG.
    fig.savefig(png_path, dpi=100, bbox_inches='tight')
    plt.show()
    plt.close(fig)


def generate_GAF(spec_to_gaf, field_type, png_path):
    """
    Generate a Gramian Angular Field (GAF) image of a spectrum and save it as a PNG.

    Parameters
    ----------
    spec_to_gaf : Spectrum1D
        The spectrum to encode (in this notebook, the H-alpha zoom).
    field_type : str
        The ``method`` passed to pyts ``GramianAngularField``:
        'difference' (or 'd') or 'summation' (or 's').
    png_path : str or Path
        Where the PNG image is written. The parent directory is created if needed.
    """
    # Flux and spectral axis are stacked as two series; only the first image
    # (the flux series) is rendered below.
    X_spec = np.array([spec_to_gaf.flux, spec_to_gaf.spectral_axis])

    gaf = GramianAngularField(method=field_type)
    X_gaf = gaf.fit_transform(X_spec)

    # Axis-free image of the flux GAF
    fig_gaf, ax_gaf = plt.subplots(figsize=(5, 5))
    ax_gaf.axis('off')
    ax_gaf.imshow(X_gaf[0], cmap='viridis', origin='lower')

    # savefig fails if the target directory is missing.
    parent_dir = os.path.dirname(os.fspath(png_path))
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    fig_gaf.savefig(png_path, dpi=100, bbox_inches='tight')
    plt.close(fig_gaf)
 

In [326]:
def gen_data_for_pred(file_upload_path):
    """
    Prepare the prediction inputs from a personal FITS spectrum.

    Builds the full and H-alpha spectra with ``generate_spec1Ds``, shows and
    saves their plots with ``spec_plots``, then writes the H-alpha GAF
    ('difference' method) to 'tmp/gip.png'.

    Parameters
    ----------
    file_upload_path : str
        Path of the uploaded FITS spectrum.
    """
    full_spec, halpha_spec = generate_spec1Ds(file_upload_path)
    spec_plots(full_spec, halpha_spec)
    generate_GAF(halpha_spec, 'difference', 'tmp/gip.png')


In [327]:
def gen_data_for_pred_bess(file_upload_path):
    """
    Prepare the prediction inputs from a BeSS VOTable FITS spectrum.

    Builds the full and H-alpha spectra with ``generate_spec1Ds_bess``, shows
    and saves their plots with ``spec_plots``, then writes the H-alpha GAF
    ('difference' method) to 'tmp/gip.png'.

    Parameters
    ----------
    file_upload_path : str
        Path of the uploaded BeSS FITS spectrum.
    """
    full_spec, halpha_spec = generate_spec1Ds_bess(file_upload_path)
    spec_plots(full_spec, halpha_spec)
    generate_GAF(halpha_spec, 'difference', 'tmp/gip.png')

In [328]:
# Load the exported fastai learner for CPU inference.
# NOTE(review): load_learner unpickles export.pkl; only ever load a trusted file.
path = Path()
learn_inf = load_learner(path/'export.pkl', cpu=True)

# Upload buttons: personal spectra, then BeSS VOTable spectra
btn_upload, btn_upload_bess = widgets.FileUpload(), widgets.FileUpload()

# Output areas for the GAF thumbnail and spectrum plots, plus the prediction label
out_gadf, out_spectrums = widgets.Output(), widgets.Output()
lbl_pred = widgets.Label()

In [329]:
def on_data_change(change):
    """
    Observer for ``btn_upload``: run a prediction on an uploaded personal spectrum.

    Writes the upload to 'tmp/specip.fits', builds the spectrum plots and the
    GAF image, shows them in the output widgets and writes the prediction to
    ``lbl_pred``. The temporary files are removed even when a step fails.

    Parameters
    ----------
    change : dict
        The traitlets change notification (not used).
    """
    tmp_files = ("tmp/gip.png", "tmp/specip.fits", "tmp/spec_plots.png")
    lbl_pred.value = 'Loading... Perhaps...'
    os.makedirs('tmp', exist_ok=True)

    try:
        with open('tmp/specip.fits', 'wb') as output_file:
            for uploaded_filename in btn_upload.value:
                content = btn_upload.value[uploaded_filename]['content']
                output_file.write(content)

        # generate spec plots and GAF image
        gen_data_for_pred('tmp/specip.fits')

        # load the generated images
        img = PILImage.create('tmp/gip.png')
        specs_im = PILImage.create('tmp/spec_plots.png')

        out_gadf.clear_output()
        out_spectrums.clear_output()
        with out_gadf:
            display(img.to_thumb(256, 256))
        with out_spectrums:
            display(specs_im)

        # set prediction value on the label widget
        pred, pred_idx, probs = learn_inf.predict(img)
        lbl_pred.value = f'Prediction : {pred}; Probability : {probs[pred_idx]:.04f}'
    except Exception as exc:
        # Exceptions raised in a widget observer are not shown in the UI, so
        # replace the "Loading..." text with the error before propagating it.
        lbl_pred.value = f'Error : {exc}'
        raise
    finally:
        # Remove the temporary files so a later upload never reads stale data.
        for tmp_file in tmp_files:
            if os.path.exists(tmp_file):
                os.remove(tmp_file)

In [330]:
def on_data_change_bess(change):
    """
    Observer for ``btn_upload_bess``: run a prediction on an uploaded BeSS spectrum.

    Writes the upload to 'tmp/specip.fits', builds the spectrum plots and the
    GAF image, shows them in the output widgets and writes the prediction to
    ``lbl_pred``. The temporary files are removed even when a step fails.

    Parameters
    ----------
    change : dict
        The traitlets change notification (not used).
    """
    tmp_files = ("tmp/gip.png", "tmp/specip.fits", "tmp/spec_plots.png")
    lbl_pred.value = 'Loading... Perhaps...'
    os.makedirs('tmp', exist_ok=True)

    try:
        with open('tmp/specip.fits', 'wb') as output_file:
            for uploaded_filename in btn_upload_bess.value:
                content = btn_upload_bess.value[uploaded_filename]['content']
                output_file.write(content)

        # generate spec plots and GAF image
        gen_data_for_pred_bess('tmp/specip.fits')

        # load the generated images
        img = PILImage.create('tmp/gip.png')
        specs_im = PILImage.create('tmp/spec_plots.png')

        out_gadf.clear_output()
        out_spectrums.clear_output()
        with out_gadf:
            display(img.to_thumb(256, 256))
        with out_spectrums:
            display(specs_im)

        # set prediction value on the label widget
        pred, pred_idx, probs = learn_inf.predict(img)
        lbl_pred.value = f'Prediction : {pred}; Probability : {probs[pred_idx]:.04f}'
    except Exception as exc:
        # Exceptions raised in a widget observer are not shown in the UI, so
        # replace the "Loading..." text with the error before propagating it.
        lbl_pred.value = f'Error : {exc}'
        raise
    finally:
        # Remove the temporary files so a later upload never reads stale data.
        for tmp_file in tmp_files:
            if os.path.exists(tmp_file):
                os.remove(tmp_file)

In [331]:
# Run a prediction whenever a new file is uploaded on either button.
for upload_button, upload_handler in ((btn_upload, on_data_change),
                                      (btn_upload_bess, on_data_change_bess)):
    upload_button.observe(upload_handler, names=['data'])

In [332]:
# Assemble the app: the two upload buttons, then the GAF thumbnail, the
# prediction label and the spectrum plots.
display(VBox([widgets.Label('Select a personal spectrum (only reduced by ISIS, VSpec, Demetra)'),
              btn_upload,
              widgets.Label('Or select a BeSS VOTable FITS file'),
              btn_upload_bess,
              out_gadf,
              lbl_pred,
              out_spectrums]))

VBox(children=(Label(value='Select a personnal spectrum (only reduced by ISIS, VSpec, Demetra)'), FileUpload(v…