## In this notebook I show how to use the Type II supernova surrogate model from Sarin et al. (2025).

First, let's do some imports.

In [None]:
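import redback_surrogates as rs
import pandas as pd
# %pylab is deprecated in IPython; import matplotlib explicitly instead
import matplotlib.pyplot as plt
%matplotlib inline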

### This model requires some extra data, which can easily be downloaded with the utility functions in redback_surrogates. The call below downloads, extracts, and saves the necessary surrogate data to the directory where your packages are installed.

In [None]:
data_dir = rs.data_management.download_surrogate_data()
if data_dir:
    print(f"Surrogate data is available at: {data_dir}")

### Now you are ready to use the model. Note that this is a general interface to the model; the interface for fitting is, as usual, available directly via redback (https://github.com/nikhil-sarin/redback).

### This model uses TensorFlow, which is GPU/CPU agnostic, but I recommend turning off the GPU for most use cases unless you are actually going to take advantage of it.

In [None]:
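import tensorflow as tf
# Hide all GPUs from TensorFlow so the model runs on the CPU.
# This must run before TensorFlow initialises any GPU, otherwise it raises a RuntimeError.
tf.config.set_visible_devices([], "GPU")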
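If you would rather not touch TensorFlow's device configuration, a common alternative is to hide the GPUs through the `CUDA_VISIBLE_DEVICES` environment variable. This is a generic CUDA mechanism, not something provided by redback_surrogates, and it only takes effect if it is set before the GPU is initialised, so the safest place for it is at the very top of the notebook, before `import tensorflow`. A minimal sketch:

```python
import os

# Hide every CUDA GPU from this Python process. CUDA reads this variable when it
# initialises, so set it before TensorFlow touches the GPU (ideally before
# `import tensorflow`).
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
```

After this, `import tensorflow as tf` as usual and TensorFlow should only see the CPU.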

In [None]:
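# Now you can just call the model. The six positional arguments are the model parameters;
# see help(rs.supernovamodels.typeII_lbol) or the redback documentation for their names and ranges.
tts, lbols = rs.supernovamodels.typeII_lbol(10.2, 0.03, -3., 4., 2, 2.5)
plt.loglog(tts, lbols)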

In [None]:
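tts, temp, rad = rs.supernovamodels.typeII_photosphere(10.2, 0.03, -3., 4., 2, 2.5)
fig, (ax_temp, ax_rad) = plt.subplots(1, 2, figsize=(10, 4))
ax_temp.loglog(tts, temp)
ax_temp.set_title("Photospheric temperature")
ax_rad.loglog(tts, rad)
ax_rad.set_title("Photospheric radius")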
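For a blackbody photosphere, the temperature and radius fix the bolometric luminosity through the Stefan-Boltzmann law, $L = 4\pi R^2 \sigma_{\rm SB} T^4$, which gives a quick consistency check between `typeII_photosphere` and `typeII_lbol`. The sketch below assumes temperatures in kelvin and radii in centimetres (cgs); those units are an assumption, so check what the surrogate returns before comparing, and keep in mind that the surrogate's bolometric luminosity need not match a pure blackbody exactly. The helper name `blackbody_luminosity` is hypothetical.

```python
import numpy as np

# Stefan-Boltzmann constant in cgs units (erg cm^-2 s^-1 K^-4)
SIGMA_SB = 5.670374419e-5


def blackbody_luminosity(temperature, radius):
    """Return L = 4 pi R^2 sigma T^4 in erg/s, assuming T in K and R in cm.

    Works element-wise on scalars or NumPy arrays of matching shape.
    """
    temperature = np.asarray(temperature, dtype=float)
    radius = np.asarray(radius, dtype=float)
    return 4.0 * np.pi * radius**2 * SIGMA_SB * temperature**4


# Sanity check with solar values (R ~ 6.957e10 cm, T_eff ~ 5772 K):
# the result should be close to the nominal solar luminosity, about 3.83e33 erg/s.
print(f"{blackbody_luminosity(5772.0, 6.957e10):.3e}")
```

With the surrogate outputs, `blackbody_luminosity(temp, rad)` could then be overplotted on `lbols` against `tts`, provided the units line up.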

In [None]:
out = rs.supernovamodels.typeII_spectra(10.2, 0.03, -3., 4., 2, 2.5)
# Each row of out.spectrum is one epoch, so plot the transpose: one grey line per epoch
plt.loglog(out.frequency, out.spectrum.T, color='black', alpha=0.1)
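Since the spectrum is a density per unit frequency, integrating it over frequency at each epoch gives the luminosity emitted over the frequency range covered by the grid. The plotting call above implies that `out.spectrum` has one row per epoch and one column per frequency (its transpose is plotted against `out.frequency`); the hypothetical helper below assumes that layout and an ascending frequency grid, and uses a hand-written trapezoid rule so it does not depend on whether your NumPy version provides `trapz` or `trapezoid`.

```python
import numpy as np


def integrate_over_frequency(frequency, spectrum):
    """Trapezoid-rule integral of a spectral density over frequency.

    frequency: 1D array of length n_freq, sorted in ascending order.
    spectrum:  array whose LAST axis has length n_freq, e.g. shape (n_times, n_freq).
    Returns the array with the last axis integrated out, e.g. shape (n_times,).
    """
    frequency = np.asarray(frequency, dtype=float)
    spectrum = np.asarray(spectrum, dtype=float)
    widths = np.diff(frequency)                                # shape (n_freq - 1,)
    mid_heights = 0.5 * (spectrum[..., 1:] + spectrum[..., :-1])
    return np.sum(mid_heights * widths, axis=-1)


# Toy check: a flat spectrum of height 1 on [0, 10] integrates to 10 for every epoch.
freq = np.linspace(0.0, 10.0, 101)
print(integrate_over_frequency(freq, np.ones((3, 101))))
```

Applied to the surrogate, `integrate_over_frequency(out.frequency, out.spectrum)` would return one value per epoch, which can be compared with `lbols` if the units agree.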

### The models can evaluate big batches at once at the same computational cost. Note that you will need bilby installed for the example below, but it is not necessary; I am only using it to draw samples from the prior.
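Batching is cheap because the whole set of parameter vectors can be pushed through the network as one array operation instead of one Python call per sample. The toy below illustrates the same idea with NumPy broadcasting. `toy_light_curve` is a hypothetical closed-form stand-in, not the surrogate, and its `(n_samples, n_times)` output layout is an assumption made for the illustration.

```python
import numpy as np


def toy_light_curve(times, amplitude, decay):
    """Hypothetical stand-in for a surrogate: L(t) = amplitude * exp(-t / decay).

    `amplitude` and `decay` may be scalars or 1D arrays of length n_samples;
    broadcasting makes the result shape (n_samples, n_times).
    """
    times = np.asarray(times, dtype=float)
    amplitude = np.atleast_1d(np.asarray(amplitude, dtype=float))[:, None]
    decay = np.atleast_1d(np.asarray(decay, dtype=float))[:, None]
    return amplitude * np.exp(-times[None, :] / decay)


rng = np.random.default_rng(42)
times = np.linspace(1.0, 100.0, 50)
amplitudes = rng.uniform(1.0, 2.0, size=1000)
decays = rng.uniform(10.0, 50.0, size=1000)

# One vectorised call for the whole batch...
batched = toy_light_curve(times, amplitudes, decays)
# ...gives the same numbers as looping over the samples one at a time.
looped = np.vstack([toy_light_curve(times, a, d) for a, d in zip(amplitudes, decays)])
print(batched.shape, np.allclose(batched, looped))
```

Timing both versions (e.g. with `%timeit`) shows the loop paying Python overhead per sample, which is the overhead a batched call avoids.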

In [None]:
# Load the prior, which defines the parameter ranges the surrogate model was trained on
prior = rs.utils.get_priors('typeII_spectra')
samples = pd.DataFrame(prior.sample(1000))
# samples is now just a DataFrame with one column per parameter and one row per draw
samples
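Assuming `rs.utils.get_priors` returns a bilby `PriorDict`, its `sample(size)` method gives a dictionary mapping each parameter name to an array of draws, which is why wrapping it in `pd.DataFrame` yields one column per parameter. Without bilby you can build the same kind of table yourself. The sketch below draws uniformly within hypothetical bounds for two hypothetical parameters; the real priors define the actual names and ranges and need not be uniform. It writes to `toy_samples` so it does not overwrite `samples`.

```python
import numpy as np
import pandas as pd

# HYPOTHETICAL parameter names and bounds, for illustration only; the real
# names and training ranges come from rs.utils.get_priors('typeII_spectra').
bounds = {
    "param_a": (0.0, 1.0),
    "param_b": (-3.0, 0.0),
}

rng = np.random.default_rng(0)
n_draws = 1000
# Uniform draws, one column per parameter and one row per draw
toy_samples = pd.DataFrame(
    {name: rng.uniform(low, high, size=n_draws) for name, (low, high) in bounds.items()}
)
print(toy_samples.shape)
```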

In [None]:
# Now evaluate lbol for all samples at once.
tts, lbols = rs.supernovamodels.typeII_lbol(**samples)
print(lbols.shape)
print(tts.shape)
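The call `typeII_lbol(**samples)` works because a pandas DataFrame behaves like a mapping from column names to columns, so `**` passes each column as a keyword argument named after it. That means the column names have to match the function's keyword arguments. The sketch below uses a hypothetical function, `hypothetical_model`, to show what actually arrives, and how to pass plain NumPy arrays instead if a function needs them.

```python
import numpy as np
import pandas as pd


def hypothetical_model(param_a, param_b):
    """Stand-in for a surrogate call; it just reports what it received."""
    return type(param_a).__name__, np.shape(param_a), np.shape(param_b)


df = pd.DataFrame({"param_a": np.linspace(0.0, 1.0, 5), "param_b": np.zeros(5)})

# ** unpacking passes each column as a keyword argument, as a pandas Series
print(hypothetical_model(**df))

# If a function expects plain NumPy arrays, convert the columns explicitly first
as_arrays = {name: column.to_numpy() for name, column in df.items()}
print(hypothetical_model(**as_arrays))
```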