<a href="https://colab.research.google.com/github/mhuertascompany/Saas-Fee/blob/main/hands-on/session3/VAE_SDSS_spectra.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# VAE FOR GALAXY SDSS SPECTRA

The goal of this tutorial is to see whether VAEs can capture the different galaxy types (e.g. star-forming, passive, AGN) from an unsupervised decomposition of their spectra. See [Portillo+20](https://ui.adsabs.harvard.edu/abs/2020arXiv200210464P/abstract).
---



#### Before we start, open this Colab notebook in "Playground mode" (top left) and switch the runtime to a GPU: in the toolbar click Runtime -> Change runtime type, then set Hardware accelerator to GPU.

---





## Import packages


In [None]:
import os
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
from sklearn import preprocessing
import matplotlib.pyplot as plt
import pdb
import pickle
from astropy.io import fits
from astropy.visualization.stretch import SqrtStretch
from astropy.visualization import ImageNormalize, MinMaxInterval

from keras.callbacks import EarlyStopping
from keras.callbacks import ModelCheckpoint
from keras.callbacks import TensorBoard

tfd = tfp.distributions
tfpl = tfp.layers
tfk = tf.keras
tfkl = tf.keras.layers

%pylab inline

## Data download and preparation

Before mounting the drive click on [this folder](https://drive.google.com/drive/folders/1PcftgBzBySo1Ync-Wdsp9arTCJ_MfEPE?usp=sharing) and add it to your google drive by following these steps:

*   Go to your drive 
*   Find shared folder ("Shared with me" link)
*   Right click it
*   Click Add to My Drive



Mount your drive into Colab:

In [None]:
from google.colab import drive

# Mount Google Drive so the shared data folder becomes reachable from this runtime.
MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

Then load the data for training. The dataset consists of three numpy arrays stored in the shared folder: `flux.npy` (the galaxy spectra, one spectrum per row), `wl.npy` (the common wavelength grid of those spectra) and `bpt_labels.npy` (the BPT class of each galaxy). The VAE is trained without labels; the BPT classes are only used afterwards, to check whether galaxies of different types separate in the latent space.

In [None]:
# Load the pre-extracted SDSS arrays from the shared Drive folder.
pathinData = "/content/drive/My Drive/EDE2019/spectra"

# Spectra: one spectrum per row (indexed later as spectra[i, :]).
spectra = np.load(os.path.join(pathinData, 'flux.npy'))

# Wavelength grid shared by every spectrum (plotted against spectra[i, :]).
wl = np.load(os.path.join(pathinData, 'wl.npy'))

# BPT class of each spectrum; only used to colour plots (the VAE is unsupervised).
label = np.load(os.path.join(pathinData, 'bpt_labels.npy'))

# Sanity checks: later cells pair wl with spectrum columns and label with spectrum rows.
assert spectra.ndim == 2, "expected a 2-D (n_spectra, n_wavelengths) flux array"
assert len(wl) == spectra.shape[1], "wavelength grid must match the flux columns"
assert len(label) == len(spectra), "expected one BPT label per spectrum"


## Visualize Spectra

In [None]:
# Show four randomly chosen spectra, each titled with its BPT class.
# A seeded generator keeps the selection identical across re-runs.
rng = np.random.default_rng(42)
randomized_inds_train = rng.permutation(len(spectra))

fig, axes = plt.subplots(2, 2)
for ax, i in zip(axes.ravel(), randomized_inds_train[:4]):
    # Non-positive flux values give NaN/-inf under log10 and appear as gaps;
    # silence the corresponding numpy warnings instead of flooding the output.
    with np.errstate(divide='ignore', invalid='ignore'):
        ax.plot(wl, np.log10(100 * spectra[i, :]))
    ax.set_title('BPT=' + str(label[i]))
    ax.set_xlabel('wavelength')
    ax.set_ylabel('log10(100 * flux)')
fig.tight_layout()


In [None]:
# Restrict every spectrum to the wavelength window used for training and add a
# trailing channel axis, giving the (batch, length, channels) layout Conv1D expects.
WL_MIN, WL_MAX = 4000, 7000
wl_mask = (wl > WL_MIN) & (wl < WL_MAX)
assert wl_mask.any(), "no wavelength bins fall inside the training window"

# Boolean-mask indexing keeps the array 2-D; indexing with the tuple returned by
# np.where would insert an extra length-1 axis that then has to be reshaped away.
sp2 = spectra[:, wl_mask][:, :, np.newaxis]
print(int(wl_mask.sum()))
print(sp2.shape)
assert sp2.shape == (len(spectra), int(wl_mask.sum()), 1)

## Model Setup

Your goal is to set up a VAE that takes as input the spectra plotted above and learns to reconstruct (generate) them. Then plot the embedded (latent) space and check whether galaxies of different BPT classes cluster. Feel free to try different approaches, and compare with PCA!

In [None]:
# VARIATIONAL AUTOENCODER (1-D convolutional encoder/decoder): configuration and prior.
original_dim = sp2.shape[1]  # number of wavelength bins per windowed spectrum
print(original_dim)

input_shape = (original_dim, 1)  # (length, channels), matching sp2's trailing axes
encoded_size = 5  # THIS IS THE SIZE OF THE BOTTLENECK --> FEEL FREE TO CHANGE AND EXPLORE
base_depth = 32   # filters in the first conv layers; later layers use multiples of it

# Standard-normal prior p(z) over the latent space. Independent reinterprets the
# encoded_size univariate Normals as a single multivariate event.
prior = tfd.Independent(tfd.Normal(loc=tf.zeros(encoded_size), scale=1),
                        reinterpreted_batch_ndims=1)

# Encoder q(z|x): 1-D convolutions that downsample the spectrum (two stride-4
# 'same' convs, then a kernel-4 / stride-4 'valid' conv), flattened into a Dense
# layer that outputs the parameters of a full-covariance Gaussian over the latent space.
encoder = tfk.Sequential([
    tfkl.InputLayer(input_shape=input_shape),
    # Optional input shift, currently disabled:
    #tfkl.Lambda(lambda x: tf.cast(x, tf.float32) - 0.5),
    tfkl.Conv1D(base_depth, 10, strides=1,
                padding='same', activation=tf.nn.leaky_relu),
    # First 4x downsampling along the wavelength axis.
    tfkl.Conv1D(base_depth, 10, strides=4,
                padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv1D(2 * base_depth, 5, strides=1,
                padding='same', activation=tf.nn.leaky_relu),
    # Second 4x downsampling.
    tfkl.Conv1D(2 * base_depth, 5, strides=4,
                padding='same', activation=tf.nn.leaky_relu),
    # Final 'valid' conv: further ~4x reduction, 4 * encoded_size channels.
    tfkl.Conv1D(4 * encoded_size, 4, strides=4,
                padding='valid', activation=tf.nn.leaky_relu),
    tfkl.Flatten(),
    # Exactly as many outputs as MultivariateNormalTriL needs (mean + lower-triangular scale).
    tfkl.Dense(tfpl.MultivariateNormalTriL.params_size(encoded_size),
               activation=None),
    # Output is a distribution; the activity regularizer adds the KL divergence to
    # the prior (weighted by 0.01) to the model loss.
    tfpl.MultivariateNormalTriL(
        encoded_size,
        activity_regularizer=tfpl.KLDivergenceRegularizer(prior,weight=0.01)),
])



# Decoder p(x|z): maps a latent vector back to a distribution over spectra.
# The latent vector is treated as a length-1 sequence with encoded_size channels and
# upsampled by 4*4*2*2*2*2 = 256. A final Dense layer maps those 256 values to the
# parameters of an IndependentNormal with event shape input_shape, so the output
# always matches original_dim regardless of the upsampling factor.
decoder = tfk.Sequential([
    tfkl.InputLayer(input_shape=[encoded_size]),
    tfkl.Reshape([1, encoded_size]),
    tfkl.Conv1D(2 * base_depth, 3, strides=1,
                         padding='same', activation=tf.nn.leaky_relu),
    tfkl.UpSampling1D(size=4),
    tfkl.Conv1D(2 * base_depth, 3, strides=1,
                         padding='same', activation=tf.nn.leaky_relu),
    tfkl.UpSampling1D(size=4),
    tfkl.Conv1D(2 * base_depth, 5, strides=1,
                         padding='same', activation=tf.nn.leaky_relu),
    tfkl.UpSampling1D(size=2),
    tfkl.Conv1D(base_depth, 5, strides=1,
                         padding='same', activation=tf.nn.leaky_relu),
    tfkl.UpSampling1D(size=2),
    tfkl.Conv1D(base_depth, 10, strides=1,
                         padding='same', activation=tf.nn.leaky_relu),
    tfkl.UpSampling1D(size=2),
    tfkl.Conv1D(base_depth, 10, strides=1,
                         padding='same', activation=tf.nn.leaky_relu),
    tfkl.UpSampling1D(size=2),
    # Collapse to a single channel (length 256 at this point), bounded by tanh.
    tfkl.Conv1D(filters=1, kernel_size=5, strides=1,
                padding='same', activation=tf.nn.tanh),
    tfkl.Flatten(),

    # Parameters (location and scale per element) of the output distribution over input_shape.
    tfkl.Dense(tfpl.IndependentNormal.params_size(input_shape),activation=None),
    # Alternative: convert the output to a tensor via the mean instead of a sample.
    #tfpl.IndependentNormal(input_shape, tfd.Normal.mean),
    # convert_to_tensor_fn=tfd.Normal.sample: using this output as a tensor draws a sample.
    tfpl.IndependentNormal(input_shape, tfd.Normal.sample),
    # Alternative likelihood for binary data (not used for spectra):
    #tfpl.IndependentBernoulli(input_shape, tfd.Bernoulli.logits),
])


# End-to-end VAE: spectrum -> encoder output (latent code) -> decoder likelihood p(x|z).
latent_z = encoder.outputs[0]
vae = tfk.Model(inputs=encoder.inputs,
                outputs=decoder(latent_z))














In [None]:
# Print layer-by-layer output shapes and parameter counts of both halves of the VAE.
for network in (encoder, decoder):
    network.summary()

In [None]:
def negloglik(x, rv_x):
    """Reconstruction loss: negative log-likelihood of `x` under the decoder output.

    The KL term of the ELBO is not computed here; the encoder's
    KLDivergenceRegularizer adds it to the model loss.
    """
    return -rv_x.log_prob(x)


LEARNING_RATE = 1e-4
EPOCHS = 10

vae.compile(optimizer=tf.optimizers.Adam(learning_rate=LEARNING_RATE),
            loss=negloglik)

# arcsinh compresses the dynamic range of the fluxes and, unlike log, is defined
# for negative values. Compute it once and use it as both input and target.
x_train = np.arcsinh(sp2)
hist = vae.fit(x_train, x_train, epochs=EPOCHS)


In [None]:
# Reconstruct spectra 10..19. Calling the model returns the decoder's output
# distribution p(x|z) rather than a plain array; row k corresponds to sp2[10 + k].
xhat = vae(np.arcsinh(sp2[10:20]))
assert isinstance(xhat, tfd.Distribution)

print(spectra.shape)
#pdb.set_trace()

## Plotting the results
Plot the results here: first the reconstructed spectra, then the latent-space embeddings.

In [None]:
# Compare arcsinh-stretched inputs with their VAE reconstructions for spectra 10..13.
# `xhat` is the decoder distribution for sp2[10:20], so its row k matches sp2[FIRST + k].
FIRST = 10
N_SHOW = 4
wl_window = wl[(wl > 4000) & (wl < 7000)]

# Draw the reconstruction statistics once for the whole batch (not once per panel).
# The mean is the deterministic reconstruction; a sample also includes the learned noise.
recon_mean = np.asarray(xhat.mean())
recon_sample = np.asarray(xhat.sample())

fig, axes = plt.subplots(2, 2)
for k, ax in enumerate(axes.ravel()[:N_SHOW]):
    ax.plot(wl_window, np.arcsinh(sp2[FIRST + k, :, 0]), color='black', lw=1, label='input')
    ax.plot(wl_window, recon_sample[k, :, 0], color='tab:blue', lw=0.5, alpha=0.5, label='sample')
    ax.plot(wl_window, recon_mean[k, :, 0], color='tab:red', lw=1, label='mean')
    ax.set_xlabel('wavelength')
    ax.set_ylabel('arcsinh(flux)')
axes.ravel()[0].legend(fontsize='small')
fig.tight_layout()

In [None]:
# Map every arcsinh-stretched spectrum to its approximate posterior q(z|x).
# NOTE: `xhat` is reused here for the latent posterior (not a reconstruction);
# the next cell samples latent vectors from it.
stretched_all = np.arcsinh(sp2)
xhat = encoder(stretched_all)
assert isinstance(xhat, tfd.Distribution)




In [None]:
# Draw one latent vector per spectrum from the encoder posterior computed above.
z = np.asarray(xhat.sample())
print(z.shape)

# Latent axes to plot. dim2 is 4 by default but is clipped to the last available
# axis so the cell still runs when encoded_size < 5.
dim1 = 0
dim2 = min(4, z.shape[1] - 1)

# (BPT label, colour, marker size) in drawing order; later entries are drawn on top.
BPT_STYLES = [(-1, 'red', 2), (5, 'blue', 2), (1, 'green', 2), (2, 'orange', 2),
              (3, 'pink', 2), (4, 'black', 2), (6, 'yellow', 5)]

fig, ax = plt.subplots()
for bpt, colour, size in BPT_STYLES:
    sel = label == bpt
    ax.scatter(z[sel, dim1], z[sel, dim2], c=colour, s=size, label=f'BPT={bpt}')
ax.set_xlabel(f'z[{dim1}]')
ax.set_ylabel(f'z[{dim2}]')
ax.legend(markerscale=4, fontsize='small')

In [None]:
import umap
from sklearn.preprocessing import StandardScaler

# Standardize each latent dimension, then project the latent vectors to 2-D with UMAP.
# A fixed random_state makes the embedding reproducible (UMAP then runs single-threaded).
UMAP_SEED = 42
scaled_z = StandardScaler().fit_transform(z)
reducer = umap.UMAP(random_state=UMAP_SEED)
embedding = reducer.fit_transform(scaled_z)
embedding.shape

In [None]:

import seaborn as sns

# 2-D UMAP projection of the VAE latent space, coloured by BPT class.
fig, ax = plt.subplots()
points = ax.scatter(embedding[:, 0], embedding[:, 1], c=label, s=5)
ax.set_xlabel('UMAP 1')
ax.set_ylabel('UMAP 2')
ax.set_title('VAE latent space (UMAP), coloured by BPT label')
fig.colorbar(points, ax=ax, label='BPT label')