## Neural Residuals

This notebook calculates and pickles residuals for spectral fits using a neural network. The results are then visualized in notebook 03 and can be used to reproduce a figure in the paper.

Set `architecture = "fader"` or `architecture = "factor"` to run the faderDis or the factorDis method, respectively.

In [None]:
# Imports: data handling, plotting, PyTorch, and the project's own modules.
import pandas as pd
import sqlite3  # NOTE(review): not used by any visible cell
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
from torch.autograd import Variable  # NOTE(review): not used by any visible cell
import torch.nn as nn
import torch.nn.functional as F
import numpy as np  # NOTE(review): duplicate of the numpy import above
from scipy import spatial
import sys
import pickle  # NOTE(review): imported, but no visible cell writes or reads a pickle with it


# Alternative import path, used when the project is importable as the `tagging` package:
#from tagging.src.datasets import ApogeeDataset
#from tagging.src.networks import ConditioningAutoencoder,Embedding_Decoder,Feedforward

# Put the project root first on the module search path so `src.*` resolves below.
# NOTE(review): absolute, machine-specific path; adjust for your environment.
sys.path.insert(0,'/share/splinter/ddm/taggingProject/taggingClean/')
from src.datasets import ApogeeDataset
# Presumably also needed so that torch.load can unpickle the saved model objects — confirm.
from src.networks import ConditioningAutoencoder,Embedding_Decoder,Feedforward
# Use the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")



# NOTE(review): this helper import goes through the `tagging.src` package path,
# unlike the `src.*` imports above, so both roots must be importable for this cell to run.
from tagging.src.utils import get_batch, invert_x,get_xdata


In [None]:
# Run configuration.
#   n_batch       - number of stars per batch drawn from the dataset
#   n_bins        - spectrum length handed to ApogeeDataset (presumably the number of spectral bins)
#   n_conditioned - how many leading entries of u are fed to the network as conditioning
#   architecture  - "fader" loads the faderDis model, "factor" loads the factorDis model
n_batch, n_bins, n_conditioned = 9, 7751, 3
architecture = "fader"

We next load the data and the trained model.

In [None]:
# Read the noiseless training spectra and wrap them in the project's Dataset.
# pd.read_pickle unpickles arbitrary objects, so only point it at trusted files.
data = pd.read_pickle("/share/splinter/ddm/taggingProject/taggingClean/data/final/train/spectra_noiseless.pd")
dataset = ApogeeDataset(data, n_bins)

# Ordered, fixed-size loader over the first 18 dataset items (two batches when n_batch = 9).
# NOTE(review): evaluation_loader is not used by any visible cell.
evaluation_loader = DataLoader(
    dataset=dataset[0:18],
    batch_size=n_batch,
    shuffle=False,
    drop_last=True,
)



In [None]:
# Load the trained conditioning autoencoder that matches `architecture`.
# torch.load unpickles the saved object, so the file must be trusted and the
# project classes (src.networks) must be importable.
model_paths = {
    "fader": "/share/splinter/ddm/taggingProject/taggingClean/models/faderDis4wZ/runs/100/adN7214I4000",
    "factor": "/share/splinter/ddm/taggingProject/taggingClean/models/wasDiswZ/runs/0/wganI6000",
}
if architecture not in model_paths:
    # Fail fast here; otherwise the model stays undefined and the evaluation
    # cells fail later with a confusing NameError.
    raise ValueError(
        f"architecture must be one of {sorted(model_paths)}, got {architecture!r}"
    )
conditioning_autoencoder = torch.load(model_paths[architecture], map_location=device)
# This notebook only runs inference, so switch training-mode layers
# (e.g. dropout, batch-norm statistics) to evaluation behaviour.
if isinstance(conditioning_autoencoder, nn.Module):
    conditioning_autoencoder.eval()



We load the spectra and associated parameters for the stars we will visualize.

In [None]:
# Draw two batches of n_batch stars: one with start argument 0 and one with
# start argument 25000. The second batch supplies the "star x2" parameters
# for the swap. get_batch(start, n, dataset) presumably returns n consecutive
# items beginning at `start` — confirm.
star1_start, star2_start = 0, 25000
x_test1, u_test1, v_test1, idx_test1 = get_batch(star1_start, n_batch, dataset)
x_test2, u_test2, v_test2, idx_test2 = get_batch(star2_start, n_batch, dataset)

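We next evaluate the model: both batches are encoded, and the latent code of star $x_1$ is decoded once with its own conditioning parameters and once with those of star $x_2$.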

In [None]:
# Run the model in two modes:
#   train_decoder=False - the second output is used as the latent code z;
#   train_encoder=False - the model is fed z, and the first output is the decoded spectrum.
# (Output roles are inferred from how they are used here.)
# Star 1's latent code z1 is decoded twice: once with its own conditioning
# parameters and once with star 2's (the parameter swap).
# Gradients are never needed for this evaluation, so no autograd graph is built.
u1_cond = u_test1[:, 0:n_conditioned]
u2_cond = u_test2[:, 0:n_conditioned]
with torch.no_grad():
    _, z1 = conditioning_autoencoder(x_test1, u1_cond, train_decoder=False)
    _, z2 = conditioning_autoencoder(x_test2, u2_cond, train_decoder=False)
    x1_pred, _ = conditioning_autoencoder(z1, u1_cond, train_encoder=False)
    x1_pred_swp, _ = conditioning_autoencoder(z1, u2_cond, train_encoder=False)

In [None]:
# Apply the project's invert_x to the inputs and the model outputs before
# plotting. It presumably undoes the input normalisation, since the results
# are plotted as flux.
# WARNING: this cell overwrites its own inputs. Run it exactly once after the
# evaluation cell; re-running it would apply invert_x twice.
x_test1, x_test2, x1_pred, x1_pred_swp = map(
    invert_x, (x_test1, x_test2, x1_pred, x1_pred_swp)
)

## Plotting

We can now plot the selected stars, the parameter-swapped reconstruction $D(E(x_1,u_1),u_2)$, and its residuals with respect to $x_2$.

In [None]:
# Wavelength grid used as the x-axis of the plots below (from the project's get_xdata).
xdata = get_xdata()


In [None]:

# Three stacked panels sharing the wavelength axis:
#   ax1 - the two input spectra x1 and x2;
#   ax2 - x1 decoded with x2's conditioning parameters, overlaid on x2;
#   ax3 - the residual between the two curves in ax2.
colors = ['#377eb8', '#ff7f00', '#4daf4a',
          '#f781bf', '#a65628', '#984ea3',
          '#999999', '#e41a1c', '#dede00']  # NOTE(review): currently unused

lw = 1
ls = (0, (5, 5))  # NOTE(review): currently unused

i = 0                 # which star of each batch to plot
n_start = 0           # first spectral bin shown
n_end = 256           # one past the last spectral bin shown
star2_offset = 25000  # start argument of the second get_batch call (the x2 batch)

# NOTE(review): the panel letter for "factor" is assumed; confirm against the paper figure.
panel_titles = {"fader": "a) FaderDis", "factor": "b) FactorDis"}


def _to_numpy(t):
    """Return a detached CPU NumPy copy of tensor ``t``."""
    return t.detach().cpu().numpy()


window = slice(n_start, n_end)
wavelengths = xdata[window]

fig, (ax1, ax2, ax3) = plt.subplots(3, 1, sharex=True, gridspec_kw={'hspace': 0, 'wspace': 0})

x1_flux = _to_numpy(x_test1[i])
x2_flux = _to_numpy(x_test2[i])
swap_flux = _to_numpy(x1_pred_swp[i])

ax1.plot(wavelengths, x1_flux[window], linewidth=lw, label="$x_{1}$", c="b")
ax1.plot(wavelengths, x2_flux[window], linewidth=lw, label="$x_{2}$", c="darkorange")

ax2.plot(wavelengths, swap_flux[window], linewidth=lw, label="$D(E(x_{1},u_{1}),u_{2})$", c="b")
ax2.plot(wavelengths, x2_flux[window], linewidth=lw, label="$x_{2}$", c="darkorange")

fig.text(0.05, 0.62, 'flux', va='center', rotation='vertical', fontsize=20)

res1 = _to_numpy(x1_pred_swp[i] - x_test2[i])
ax3.plot(wavelengths, res1[window], linewidth=lw, label="$D(E(x_{1},u_{1}),u_{2})-x_{2}$", c="b")

fig.text(0.05, 0.25, 'residuals', va='center', rotation='vertical', fontsize=16)

# Panel title follows the selected architecture instead of always reading "FaderDis".
fig.text(0.13, 0.915, panel_titles.get(architecture, architecture),
         va='center', rotation='horizontal', fontsize=16)

# First three parameters of each star, concatenated. The list(...) calls make
# this a concatenation for lists and NumPy arrays alike: with arrays, `+` would
# add element-wise and leave only 3 values for the 6 placeholders below.
# NOTE(review): data["params"][k] is a label lookup on the Series index;
# it assumes the index matches the dataset positions used by get_batch.
star_params = list(data["params"][i][0:3]) + list(data["params"][i + star2_offset][0:3])
fig.text(0.13, 0.965,
         "Star $x_1$: Teff= {} , logg = {}, [Fe/H]= {:.2g}         Star $x_2$: Teff= {} , logg = {}, [Fe/H]= {:.2g}".format(*star_params),
         va='center', rotation='horizontal', fontsize=16)

ax1.set_ylim(0.6, 1.0)
ax2.set_ylim(0.6, 1.0)
ax3.set_ylim(-0.015, 0.015)

fig.set_size_inches(14.5, 6.5)
# The last plotted point is xdata[n_end - 1]. Using xdata[n_end] would leave an
# empty bin at the right edge and raise IndexError when n_end == len(xdata).
ax3.set_xlim(xdata[n_start], xdata[n_end - 1])
ax3.set_xlabel(r"Wavelength($\AA$)", fontsize=24)

# Hide the lowest y tick of the upper panels so it does not collide with the panel below.
yticks1 = ax1.yaxis.get_major_ticks()
yticks1[0].set_visible(False)

yticks2 = ax2.yaxis.get_major_ticks()
yticks2[0].set_visible(False)

ax1.legend()
ax2.legend()
ax3.legend()