### **NOTE**:
Due to limited Ed resources, the kernel may stop while training GRU/LSTMs.

After completing this exercise, you can use the Colab notebook below to train bigger networks.


<a href="https://colab.research.google.com/drive/1Wvwn44WaoPdrB8V2K2xQlqylnEP-U4WP?usp=sharing" target="_blank">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>


## LSTM vs GRU

We will use both GRU and LSTM to perform sentiment analysis in tensorflow.keras and compare their performance using the custom IMDB dataset.

! pip install astroquery

In [2]:
import math
import os

import matplotlib.pyplot as plt
from astroquery.gaia import Gaia

In [3]:
def extract_dl_ind(datalink_dict, key, figsize = [15,5], fontsize = 12, linewidth = 2, show_legend = True, show_grid = True):
    """
    Extract one DataLink product, plot it, and return it as a table.

    Parameters
    ----------
    datalink_dict : mapping
        Maps product keys to sequences of products. The first element of
        ``datalink_dict[key]`` must provide a ``to_table()`` method.
    key : hashable
        Key of the product to extract.
    figsize, fontsize, linewidth, show_legend, show_grid
        Styling options forwarded to ``plot_e_phot`` / ``plot_sampled_spec``.

    Returns
    -------
    The object returned by ``to_table()``.

    Notes
    -----
    A table with a 'time' column is plotted as epoch photometry. A table with a
    'wavelength' column is plotted as a sampled spectrum.
    """
    dl_out  = datalink_dict[key][0].to_table()
    columns = dl_out.keys()
    if 'time' in columns:
        plot_e_phot(dl_out, colours  = ['green', 'red', 'blue'], title = 'Epoch photometry', fontsize = fontsize, show_legend = show_legend, show_grid = show_grid, figsize = figsize)
    if 'wavelength' in columns:
        # The spectrum type is recognised by its number of samples. Any other length
        # gets a generic title instead of leaving `title` unbound (UnboundLocalError).
        title = {343: 'XP Sampled', 2401: 'RVS'}.get(len(dl_out), 'Sampled spectrum')
        plot_sampled_spec(dl_out, color = 'blue', title = title, fontsize = fontsize, show_legend = False, show_grid = show_grid, linewidth = linewidth, legend = '', figsize = figsize)
    return dl_out


def plot_e_phot(inp_table, colours  = ['green', 'red', 'blue'], title = 'Epoch photometry', fontsize = 12, show_legend = True, show_grid = True, figsize = [15,5]):
    """
    Epoch photometry plotter. 'inp_table' MUST be an Astropy-table object.

    Parameters
    ----------
    inp_table : table
        Must expose 'time', 'mag' and 'band' columns. The 'time' and 'mag' columns
        need a ``unit`` attribute, which is used in the axis labels.
    colours : iterable
        One colour per band, in the order G, RP, BP. Extra entries are ignored.
    title, fontsize, show_legend, show_grid
        Forwarded to ``make_canvas``.
    figsize : sequence
        Size passed to ``plt.figure``.

    Raises
    ------
    ValueError
        If fewer colours than bands are supplied.
    """
    gbands  = ['G', 'RP', 'BP']
    palette = list(colours)
    # Fail before any figure is opened. Previously a short colour list surfaced as
    # a bare StopIteration in the middle of plotting.
    if len(palette) < len(gbands):
        raise ValueError(f"'colours' must provide at least {len(gbands)} entries (one per band), got {len(palette)}")

    xlabel   = f'JD date [{inp_table["time"].unit}]'
    ylabel   = f'magnitude [{inp_table["mag"].unit}]'

    plt.figure(figsize=figsize)
    # Invert the y axis so that smaller magnitude values appear at the top.
    plt.gca().invert_yaxis()
    for band, colour in zip(gbands, palette):
        phot_set = inp_table[inp_table['band'] == band]
        plt.plot(phot_set['time'], phot_set['mag'], 'o', label = band, color = colour)
    make_canvas(title = title, xlabel = xlabel, ylabel = ylabel, fontsize= fontsize, show_legend=show_legend, show_grid = show_grid)
    plt.show()


def plot_sampled_spec(inp_table, color = 'blue', title = '', fontsize = 14, show_legend = True, show_grid = True, linewidth = 2, legend = '', figsize = [12,4], show_plot = True):
    """
    RVS & XP sampled spectrum plotter. 'inp_table' MUST be an Astropy-table object.

    Parameters
    ----------
    inp_table : table
        Must expose 'wavelength' and 'flux' columns, each with a ``unit`` attribute
        that is used in the axis labels.
    color : matplotlib colour
        Line colour of the spectrum.
    title, fontsize, show_legend, show_grid
        Forwarded to ``make_canvas``.
    linewidth : number
        Line width of the spectrum.
    legend : str
        Label attached to the plotted line.
    figsize : sequence
        Size passed to ``plt.figure`` (only used when ``show_plot`` is true).
    show_plot : bool
        When true, open a new figure and call ``plt.show()`` at the end. When false,
        draw on whatever figure is current and do not show it.
    """
    if show_plot:
        plt.figure(figsize=figsize)
    xlabel   = f'Wavelength [{inp_table["wavelength"].unit}]'
    ylabel   = f'Flux [{inp_table["flux"].unit}]'
    # `color` used to be accepted but silently ignored; it is now passed to the line.
    plt.plot(inp_table['wavelength'], inp_table['flux'], '-', linewidth = linewidth, label = legend, color = color)
    make_canvas(title = title, xlabel = xlabel, ylabel = ylabel, fontsize= fontsize, show_legend=show_legend, show_grid = show_grid)
    if show_plot:
        plt.show()


def make_canvas(title = '', xlabel = '', ylabel = '', show_grid = False, show_legend = False, fontsize = 12):
    """
    Create generic canvas for plots.

    Sets the title, axis labels and tick-label font size through the pyplot
    interface, then optionally adds a grid and a legend. The legend text is
    drawn at 75% of ``fontsize``.
    """
    text_style = {'fontsize': fontsize}

    plt.title(title, **text_style)
    plt.xlabel(xlabel, **text_style)
    plt.ylabel(ylabel, **text_style)
    for set_ticks in (plt.xticks, plt.yticks):
        set_ticks(**text_style)

    if show_grid:
        plt.grid()
    if show_legend:
        plt.legend(fontsize = fontsize*0.75)

In [4]:
# ADQL query: fetch the source_id column of the gaiadr3.vari_eclipsing_binary table.
# Alternative single-source query, kept for reference:
#   SELECT * FROM gaiadr3.gaia_source WHERE has_epoch_photometry = 'TRUE' and source_id = 1035533795140608
query = "SELECT source_id FROM gaiadr3.vari_eclipsing_binary"

# Submit the query, then fetch its result table.
job = Gaia.launch_job_async(query)
results = job.get_results()

print(f'Table size (rows): {len(results)}')

INFO: Query finished. [astroquery.utils.tap.core]
Table size (rows): 2184477


In [5]:
# Source identifiers returned by the query, and how many of them there are.
ids = results['source_id'].value.data
total_ids = len(ids)

# Number of source IDs sent per DataLink request.
batch = 5000

In [6]:
retrieval_type = 'EPOCH_PHOTOMETRY'          # Options are: 'EPOCH_PHOTOMETRY', 'MCMC_GSPPHOT', 'MCMC_MSC', 'XP_SAMPLED', 'XP_CONTINUOUS', 'RVS', 'ALL'
data_structure = 'INDIVIDUAL'   # Options are: 'INDIVIDUAL', 'COMBINED', 'RAW'
data_release   = 'Gaia DR3'     # Options are: 'Gaia DR3' (default), 'Gaia DR2'

# Directory that receives one CSV file per DataLink product.
output_dir  = "./vari_eclipsing_binary/"
# Index of the first batch to fetch.
# NOTE(review): 264 presumably resumes an earlier, interrupted run - set to 0 to start from the beginning.
start_batch = 264
n_batches   = math.ceil(total_ids/batch)

# Create the target directory up front so the first CSV write cannot fail on a missing folder.
os.makedirs(output_dir, exist_ok = True)

k = 0                  # number of source IDs requested so far in this run
downloaded_keys = []   # keys of every product written to disk, across all batches
for i in range(start_batch, n_batches):
    # The slice end is clamped to the array length, so the short final batch needs no special case.
    batch_ids = ids[i*batch:(i+1)*batch]
    datalink = Gaia.load_data(ids=batch_ids, data_release = data_release, retrieval_type=retrieval_type, data_structure = data_structure, verbose = False, output_file = None)
    dl_keys  = sorted(datalink.keys())
    for dl_key in dl_keys:
        # Only the first entry stored under each key is exported.
        datalink[dl_key][0].to_table().to_pandas().to_csv(output_dir+dl_key+".csv")
    downloaded_keys.extend(dl_keys)
    # Count the IDs actually requested; the final batch can be smaller than `batch`.
    k+=len(batch_ids)
    print(k," downloaded")

print()
print('The following Datalink products have been downloaded:')
# The original iterated over an undefined name `keys` (NameError); list every product saved in this run.
for dl_key in downloaded_keys:
    print(f' * {dl_key}')

5000  downloaded
10000  downloaded
15000  downloaded
20000  downloaded
25000  downloaded
30000  downloaded
35000  downloaded
40000  downloaded
45000  downloaded
50000  downloaded
55000  downloaded
60000  downloaded


KeyboardInterrupt: 

In [7]:
# Display the current value of `i`, the batch index used by the download loop.
i

264