# The 21cm background with 21cmFAST

In this tutorial we will learn how to simulate the 21cm background with 21cmFAST, a popular semi-numerical code. We are going to:

1) Derive co-eval signal cubes

2) Derive lightcones of 21cm data

3) Analyse: 21cm power spectrum

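Item 3 refers to the 21cm power spectrum, the variance of the brightness-temperature fluctuations as a function of scale. The following is a minimal numpy sketch of a spherically averaged power spectrum for a cubic box. The helper name `spherical_power_spectrum`, the logarithmic binning and the normalisation (continuous-Fourier-transform convention, P in units of field² × volume) are our own assumptions. They are not necessarily the conventions used by 21cmFAST or its companion analysis tools.

```python
import numpy as np


def spherical_power_spectrum(field, box_len, n_bins=15):
    """Spherically averaged power spectrum of a cubic field (hypothetical helper).

    field   : 3D numpy array with equal side lengths, e.g. a brightness-temperature cube in mK
    box_len : comoving side length of the box, e.g. in Mpc
    Returns the mean |k| per bin, P(k) and Delta^2(k) = k^3 P(k) / (2 pi^2).
    """
    n = field.shape[0]
    cell_vol = (box_len / n) ** 3
    # Discrete FFT scaled to approximate the continuous Fourier transform
    field_k = np.fft.fftn(field - field.mean()) * cell_vol
    power = np.abs(field_k) ** 2 / box_len**3

    # |k| of every grid mode, in units of 1/box_len
    k1d = 2 * np.pi * np.fft.fftfreq(n, d=box_len / n)
    kx, ky, kz = np.meshgrid(k1d, k1d, k1d, indexing="ij")
    kmag = np.sqrt(kx**2 + ky**2 + kz**2).ravel()

    # Logarithmic bins from just below the fundamental mode to just above the largest |k|;
    # the k = 0 mode falls below the first edge and is ignored
    edges = np.logspace(np.log10(0.99 * 2 * np.pi / box_len), np.log10(1.01 * kmag.max()), n_bins + 1)
    idx = np.digitize(kmag, edges)
    counts = np.bincount(idx, minlength=n_bins + 2)[1:n_bins + 1]
    k_sum = np.bincount(idx, weights=kmag, minlength=n_bins + 2)[1:n_bins + 1]
    p_sum = np.bincount(idx, weights=power.ravel(), minlength=n_bins + 2)[1:n_bins + 1]

    # Drop empty bins and average |k| and P within each filled bin
    filled = counts > 0
    k = k_sum[filled] / counts[filled]
    pk = p_sum[filled] / counts[filled]
    return k, pk, k**3 * pk / (2 * np.pi**2)
```

As a sanity check, unit-variance white noise on a 32³ grid with a 64 Mpc side should give a roughly flat P(k) close to the cell volume, (64/32)³ = 8 Mpc³. The same function can be applied to any `brightness_temp` cube produced later in this tutorial.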
This tutorial follows parts of the official tutorials at https://github.com/21cmfast/21cmFAST/tree/master/docs/tutorials
as well as the descriptions at https://21cmfast.readthedocs.io/en/latest/tutorials

### Code repository
More information on the semi-numerical code 21cmFAST can be found here:

https://github.com/21cmfast/21cmFAST

https://21cmfast.readthedocs.io

### Publications
21cmFAST was introduced in: https://arxiv.org/pdf/1003.3878.pdf

Its Python-wrapped version is described in: https://doi.org/10.21105/joss.02582

In [None]:
# Install the packages we need.
# Uncomment on Google Colab:
#!pip install -q condacolab
#import condacolab
#condacolab.install()
#!conda --version
#!conda install -c conda-forge 21cmFAST

In [None]:
%matplotlib inline
import logging
import os

import matplotlib.pyplot as plt

# We change the default level of the logger so that
# we can see what's happening with caching.
logger = logging.getLogger('21cmFAST')
logger.setLevel(logging.INFO)

import py21cmfast as p21c

# For plotting the cubes, we use the plotting submodule:
from py21cmfast import plotting

# For interacting with the cache
from py21cmfast import cache_tools

print(f"Using 21cmFAST version {p21c.__version__}")

In [None]:
# Create the cache directory if needed, point 21cmFAST at it, and clear it
os.makedirs('_cache', exist_ok=True)

p21c.config['direc'] = '_cache'
cache_tools.clear_cache(direc="_cache")

## 1) Derive co-eval signal cubes

Co-eval cubes are boxes taken at a fixed time, or redshift. The `run_coeval` function performs all the simulation steps needed to derive the 21cm co-eval box.

In [None]:
# call run_coeval at three different redshifts
coeval8, coeval9, coeval10 = p21c.run_coeval(
    redshift = [8.0, 9.0, 10.0],
    user_params = {"HII_DIM": 100, "BOX_LEN": 100, "USE_INTERPOLATION_TABLES": True},
    cosmo_params = p21c.CosmoParams(SIGMA_8=0.8),
    astro_params = p21c.AstroParams({"HII_EFF_FACTOR":20.0}),
    random_seed=12345
)


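In the call above, `BOX_LEN` is the side length of the simulation box in (comoving) Mpc and `HII_DIM` is the number of cells per side of the output boxes. Together they fix the spatial resolution and the range of Fourier modes the box can represent. The following is a quick worked example using the values from the `run_coeval` call. The variable names below are just local variables mirroring those parameters.

```python
import numpy as np

# Values passed via user_params in the run_coeval call above
BOX_LEN = 100.0  # side length of the box, in Mpc
HII_DIM = 100    # number of cells per side of the output boxes

cell_size = BOX_LEN / HII_DIM          # Mpc per cell
k_fundamental = 2 * np.pi / BOX_LEN    # largest scale sampled, in 1/Mpc
k_nyquist = np.pi / cell_size          # smallest scale sampled, in 1/Mpc

print(f"cell size: {cell_size:.2f} Mpc")
print(f"k range:   {k_fundamental:.3f} - {k_nyquist:.3f} 1/Mpc")
```

Doubling `HII_DIM` at fixed `BOX_LEN` halves the cell size and doubles the Nyquist wavenumber, at roughly eight times the memory cost per box.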
We now have a look at the set of parameters we can access.

### Task 1: 
Explore `user_params`, `cosmo_params` and `astro_params`.

Then create more co-eval cubes with different settings.


In [None]:
print(coeval8.user_params)
help(p21c.AstroParams)

Let's now have a look at the cubes!

### Question: 
What are the dimensions of the co-eval cubes we created?

### Task 2:
Plot some co-eval cubes.

In [None]:
# base example
print(coeval8.brightness_temp.shape)

fig, ax = plt.subplots(1, 3, figsize=(14, 4))
for i, (coeval, redshift) in enumerate(zip([coeval8, coeval9, coeval10], [8, 9, 10])):
    plotting.coeval_sliceplot(coeval, ax=ax[i], fig=fig)
    ax[i].set_title(f"z = {redshift}")  # title the panel we just drew on (plt.title would target the current axes)
plt.tight_layout()

# which data boxes can exist?
print(p21c.wrapper.get_all_fieldnames(coeval8))

# choose another kind to plot
fig, ax = plt.subplots(1, 3, figsize=(14, 4))
for i, (coeval, redshift) in enumerate(zip([coeval8, coeval9, coeval10], [8, 9, 10])):
    plotting.coeval_sliceplot(coeval, kind='density', ax=ax[i], fig=fig)
    ax[i].set_title(f"z = {redshift}")
plt.tight_layout()

In [None]:
# Bonus: 21cmFAST offers further functionality, such as computing the average (global) 21cm signal.
# The *_struct attributes hold the underlying output objects, which contain the numpy arrays of the data cubes.
print(coeval8.brightness_temp_struct.global_Tb)
print(coeval9.brightness_temp_struct.global_Tb)
print(coeval10.brightness_temp_struct.global_Tb)

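The global signal is a spatial average over the box. Because every cell has the same comoving volume, it is simply the mean over all cells. The following is a minimal sketch, assuming (as we understand the 21cmFAST implementation) that `global_Tb` is the plain volume average of the brightness-temperature cube. The random arrays are hypothetical stand-ins for cubes such as `coeval8.brightness_temp` and the neutral-fraction box `xH_box` plotted further below.

```python
import numpy as np

rng = np.random.default_rng(1)
# Hypothetical stand-ins for a brightness-temperature cube (mK) and a neutral-fraction cube
brightness_temp = rng.normal(loc=-5.0, scale=10.0, size=(20, 20, 20))
xH_box = (rng.random(size=(20, 20, 20)) > 0.3).astype(float)

# Volume-averaged (global) quantities: all cells have equal volume, so a plain mean suffices
global_Tb = brightness_temp.mean()
global_xH = xH_box.mean()
print(f"<T_b> = {global_Tb:.2f} mK, <x_HI> = {global_xH:.2f}")
```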
### Question: 
What does run_coeval() calculate? 

Let's now run each simulation step individually and inspect the fields it produces!

In [None]:
# default cosmological parameters (a bare expression in the middle of a cell would not be displayed)
print(p21c.CosmoParams._defaults_)

# initial density and velocity fields
initial_conditions = p21c.initial_conditions(
    user_params = {"HII_DIM": 100, "BOX_LEN": 100},
    cosmo_params = p21c.CosmoParams(SIGMA_8=0.8),
    random_seed=54321
)


print(initial_conditions.cosmo_params)
print(initial_conditions.fieldnames)

plotting.coeval_sliceplot(initial_conditions, "hires_density")

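`initial_conditions` generates Gaussian random density (and velocity) fields whose statistics follow the linear matter power spectrum set by `cosmo_params` (e.g. `SIGMA_8`). The core idea can be sketched in numpy: filter white noise in Fourier space by √P(k). The helper name, the toy power-law spectrum and its amplitude are illustrative assumptions. They do not reproduce 21cmFAST's actual transfer function or normalisation.

```python
import numpy as np


def gaussian_random_field(n, box_len, power_spectrum, seed=None):
    """Draw a real Gaussian random field on an n^3 grid with power spectrum P(k) (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    k1d = 2 * np.pi * np.fft.fftfreq(n, d=box_len / n)
    kx, ky, kz = np.meshgrid(k1d, k1d, k1d, indexing="ij")
    kmag = np.sqrt(kx**2 + ky**2 + kz**2)

    # FFT of real-space white noise is Hermitian-symmetric, so the result stays real
    white_k = np.fft.fftn(rng.normal(size=(n, n, n)))

    # Evaluate P(k) on all non-zero modes; the k = 0 mode is set to zero (zero-mean field)
    pk = np.zeros_like(kmag)
    nonzero = kmag > 0
    pk[nonzero] = power_spectrum(kmag[nonzero])

    # Scale so that <|FFT(delta)|^2> = P(k) * n^3 / cell_vol (continuous-FT convention)
    cell_vol = (box_len / n) ** 3
    delta_k = white_k * np.sqrt(pk / cell_vol)
    return np.fft.ifftn(delta_k).real
```

For example, `gaussian_random_field(64, 100.0, lambda k: 50.0 * k**-2, seed=54321)` returns a zero-mean 64³ density contrast. As the `hires_density` field plotted above suggests, 21cmFAST also keeps a higher-resolution version of the initial box.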
In [None]:
# perturbed density and velocity fields (at given redshift)

perturbed_field = p21c.perturb_field(
    redshift = 8.0,
    init_boxes = initial_conditions
)

plotting.coeval_sliceplot(perturbed_field, "density")
plotting.coeval_sliceplot(perturbed_field, "velocity")


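`perturb_field` moves the initial density field to the requested redshift with Lagrangian perturbation theory: the Zel'dovich approximation, optionally with second-order (2LPT) corrections. The following is a minimal first-order sketch. The displacement field ψ satisfies δ = -∇·ψ, so in Fourier space ψ_k = i k δ_k / k². Mass elements then move from their initial positions q to x = q + D(z) ψ(q), with D(z) the linear growth factor, assuming ψ is computed from the linear density extrapolated to z = 0. The function name is hypothetical and the sketch omits the mass deposition step back onto the grid.

```python
import numpy as np


def zeldovich_displacement(delta, box_len):
    """First-order (Zel'dovich) displacement field psi, with delta = -div(psi).

    delta   : cubic density-contrast array of shape (n, n, n)
    box_len : side length of the box; psi is returned in the same units
    Returns an array of shape (3, n, n, n) holding psi_x, psi_y, psi_z.
    """
    n = delta.shape[0]
    k1d = 2 * np.pi * np.fft.fftfreq(n, d=box_len / n)
    kvec = np.meshgrid(k1d, k1d, k1d, indexing="ij")
    k2 = kvec[0]**2 + kvec[1]**2 + kvec[2]**2
    k2[0, 0, 0] = 1.0  # avoid division by zero; the k = 0 mode carries no displacement

    delta_k = np.fft.fftn(delta)
    delta_k[0, 0, 0] = 0.0
    # psi_k = i k delta_k / k^2 for each Cartesian component
    psi = [np.fft.ifftn(1j * ki * delta_k / k2).real for ki in kvec]
    return np.stack(psi)
```

For an odd grid size every Fourier mode has a distinct partner at -k, so taking the real part loses nothing. For even grids, the Nyquist modes are slightly distorted by that step.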
In [None]:
# ionization field, assuming a saturated spin temperature (T_S >> T_CMB, i.e. post-heating)

ionized_field = p21c.ionize_box(
    perturbed_field = perturbed_field
)

plotting.coeval_sliceplot(ionized_field, "xH_box")

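`ionize_box` is based on an excursion-set criterion. The density field is smoothed on a sequence of scales R, and a cell is flagged as ionized if ζ f_coll(R) ≥ 1 on any of them. Here ζ is the ionizing efficiency (`HII_EFF_FACTOR` in the default parametrisation) and f_coll the collapsed mass fraction. The sketch below implements only that final criterion. It assumes the collapsed fractions on each smoothing scale have already been computed and ignores the partial ionization 21cmFAST assigns to cells on the smallest scale. The input `fcoll_by_scale` and the function name are hypothetical.

```python
import numpy as np


def excursion_set_ionization(fcoll_by_scale, zeta):
    """Binary neutral-fraction box from collapsed fractions on several smoothing scales.

    fcoll_by_scale : array of shape (n_scales, nx, ny, nz)
    zeta           : ionizing efficiency (cf. HII_EFF_FACTOR)
    Returns 0 for ionized cells and 1 for neutral cells.
    """
    ionized = np.zeros(fcoll_by_scale.shape[1:], dtype=bool)
    for fcoll in fcoll_by_scale:
        # a cell is ionized if the criterion holds on any scale; the order of scales does not matter here
        ionized |= zeta * fcoll >= 1.0
    return np.where(ionized, 0.0, 1.0)
```

Raising ζ lowers the collapsed fraction needed to ionize a region, which is why a larger `HII_EFF_FACTOR` produces earlier and faster reionization.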
In [None]:
# 21cm differential brightness temperature (offset from the CMB temperature)

brightness_temp = p21c.brightness_temperature(
    ionized_box=ionized_field, perturbed_field=perturbed_field
)

plotting.coeval_sliceplot(brightness_temp)

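For orientation, the differential brightness temperature of a cell is commonly approximated, neglecting the peculiar-velocity term, as

$$\delta T_b \approx 27\, x_{\rm HI}\,(1+\delta)\left(1-\frac{T_{\rm CMB}}{T_S}\right)\left(\frac{\Omega_b h^2}{0.023}\right)\left(\frac{0.15}{\Omega_m h^2}\,\frac{1+z}{10}\right)^{1/2}\ {\rm mK},$$

with $T_{\rm CMB} = 2.725\,(1+z)$ K. The numpy sketch below evaluates this approximation cell by cell. The function name and the default values of Ω_b h² and Ω_m h² (roughly Planck-like) are illustrative assumptions. 21cmFAST's own calculation includes additional terms, so its output will not match this formula exactly.

```python
import numpy as np

T_CMB0 = 2.725  # CMB temperature today, in K


def brightness_temperature_approx(x_HI, delta, T_S, z, omega_b_h2=0.022, omega_m_h2=0.14):
    """Approximate 21cm differential brightness temperature in mK (no velocity term).

    x_HI, delta and T_S (spin temperature in K) may be scalars or numpy arrays of the same shape.
    """
    T_cmb = T_CMB0 * (1.0 + z)
    spin_factor = 1.0 - T_cmb / np.asarray(T_S, dtype=float)
    cosmo_factor = (omega_b_h2 / 0.023) * np.sqrt(0.15 / omega_m_h2 * (1.0 + z) / 10.0)
    return 27.0 * x_HI * (1.0 + delta) * spin_factor * cosmo_factor
```

With T_S ≫ T_CMB (the post-heating assumption used in the cells above), the spin-temperature factor tends to 1 and the signal is in emission. Gas colder than the CMB gives a negative δT_b, i.e. absorption, which is why the spin temperature matters at early times.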
In [None]:
# Bonus: calculate the brightness temperature without assuming post-heating,
# i.e. by evolving the spin temperature explicitly

spin_temp = p21c.spin_temperature(
    perturbed_field = perturbed_field,
    zprime_step_factor=1.05,
)

plotting.coeval_sliceplot(spin_temp, "Ts_box")

# ionization field computed from the spin-temperature output
ionized_box = p21c.ionize_box(
    spin_temp = spin_temp,
    zprime_step_factor=1.05,
)

# brightness temperature including the spin-temperature dependence
brightness_temp = p21c.brightness_temperature(
    ionized_box = ionized_box,
    perturbed_field = perturbed_field,
    spin_temp = spin_temp
)

## 2a) Compare saliency maps for different inputs

### Task 2a: 
Plot saliency maps for both the simulation-only and the mock lightcones.

### Question 2a: 
Does the network's attention shift with the inclusion of noise? If yes, how?

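A vanilla saliency map is the absolute gradient of one network output y_c (here, one predicted parameter) with respect to every input voxel, |∂y_c/∂x|. Independently of tf-keras-vis, the idea can be sketched in numpy with central finite differences on a hypothetical toy model. For a real network, the gradients are obtained with automatic differentiation instead, which is what the `Saliency` object does for us.

```python
import numpy as np


def toy_model(x):
    """Hypothetical stand-in for a network output: a weighted sum of squared voxels."""
    weights = np.linspace(0.0, 1.0, x.size).reshape(x.shape)
    return np.sum(weights * x**2)


def saliency_map(model, x, eps=1e-5):
    """|d model / d x| for every voxel, estimated with central finite differences."""
    grad = np.zeros_like(x, dtype=float)
    for idx in np.ndindex(x.shape):
        x_plus, x_minus = x.astype(float), x.astype(float)  # fresh copies for each voxel
        x_plus[idx] += eps
        x_minus[idx] -= eps
        grad[idx] = (model(x_plus) - model(x_minus)) / (2 * eps)
    return np.abs(grad)
```

Finite differences need two model evaluations per voxel, which is hopeless for a 140 × 140 × 2350 lightcone. That is why backpropagated gradients are used in practice.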
In [None]:
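# Task 2a: plot and compare here
import tensorflow as tf
from tensorflow import keras
from matplotlib import cm
from tf_keras_vis.saliency import Saliency

# `data` (the lightcones) is assumed to have been loaded in an earlier cell
model2 = keras.models.load_model('models/3D_21cmPIE_Net_optmock_par6.h5')
# Google Colab: one might have to adjust the path, e.g. to 'introspection-tutorial/models/..

# Change the sigmoid activation of the last layer to a linear one to not obstruct attention.
def model_modifier(m):
    m.layers[-1].activation = tf.keras.activations.linear
    return m

obj_saliency2 = Saliency(model2, model_modifier=model_modifier, clone=True)

# generate one saliency map per parameter (i.e. per output "class")
parameters = [0, 1, 2, 3, 4, 5]
lc = data[0, 0]
maps_saliency2 = {}
for para in parameters:
    def loss(output):
        return output[0][para]  # output shape: (samples, classes)
    maps_saliency2[para] = obj_saliency2(loss, lc.reshape(140, 140, 2350, 1))

map_saliency2 = maps_saliency2[parameters[-1]]  # choose which parameter's map to display
print(map_saliency2.shape)
nslice = 130
plt.figure()
plot_saliency = plt.imshow(map_saliency2[0, nslice, :, :], cmap=cm.hot)
plt.figure()
plot_lc = plt.imshow(lc[nslice, :, :], cmap="EoR", vmin=-150, vmax=30)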

## 2b) Beyond vanilla saliency 
### Question 2b: 
Can saliency-based attention maps be improved; in particular, can the noise in the maps be reduced? Tip: use SmoothGrad from tf-keras-vis. What does SmoothGrad exploit?

### Task 2b:
Explore SmoothGrad for different settings (`smooth_samples`, `smooth_noise`).

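SmoothGrad (Smilkov et al., 2017) reduces the visual noise of vanilla saliency by averaging gradient maps over several copies of the input perturbed with Gaussian noise. In tf-keras-vis, as we understand it, `smooth_samples` sets the number of noisy copies and `smooth_noise` sets the noise standard deviation relative to the input's value range (max minus min). The following is a minimal numpy sketch of that averaging, using a hypothetical toy model whose gradient is known analytically. We take the absolute value per sample before averaging; this ordering is our choice, and implementations may differ.

```python
import numpy as np


def toy_gradient(x):
    """Analytic gradient of the hypothetical toy score f(x) = sum(sin(3 x))."""
    return 3.0 * np.cos(3.0 * x)


def smoothgrad(grad_fn, x, n_samples=20, noise_level=0.2, seed=None):
    """Average |gradient| over noisy copies of x (SmoothGrad)."""
    rng = np.random.default_rng(seed)
    sigma = noise_level * (x.max() - x.min())   # noise scaled to the input's value range
    total = np.zeros_like(x, dtype=float)
    for _ in range(n_samples):
        noisy = x + rng.normal(0.0, sigma, size=x.shape)
        total += np.abs(grad_fn(noisy))
    return total / n_samples
```

With `noise_level=0` every copy is identical and the result reduces to the vanilla saliency map. Larger values smooth out sharp, voxel-to-voxel fluctuations of the gradient at the cost of spatial detail.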
In [None]:
# Question 2b: try out SmoothGrad here
import tensorflow as tf
from matplotlib import cm
from tf_keras_vis.saliency import Saliency

# Change the sigmoid activation of the last layer to a linear one to not obstruct attention.
def model_modifier(m):
    m.layers[-1].activation = tf.keras.activations.linear
    return m

# create saliency object (`model` and `data` are assumed to be loaded in an earlier cell)
obj_saliency = Saliency(model, model_modifier=model_modifier, clone=True)

nsamples = 5  # number of noisy copies averaged by SmoothGrad
noise = 0.2   # noise level passed as smooth_noise

# generate one saliency map per parameter (i.e. per output "class")
parameters = [0, 1, 2, 3, 4, 5]
lc = data[0, 0]
maps_saliency = {}
for para in parameters:
    def loss(output):
        return output[0][para]  # output shape: (samples, classes)
    maps_saliency[para] = obj_saliency(loss, lc.reshape(140, 140, 2350, 1),
                                       smooth_samples=nsamples, smooth_noise=noise)

map_saliency = maps_saliency[parameters[-1]]  # choose which parameter's map to display
print(map_saliency.shape)
nslice = 130
plt.figure()
plot_saliency = plt.imshow(map_saliency[0, nslice, :, :], cmap=cm.hot)
plt.figure()
plot_lc = plt.imshow(lc[nslice, :, :], cmap="EoR", vmin=-150, vmax=30)

## 3) (optional) Derive class activation maps and compare 

### Task 3: 
Create class activation maps (CAMs) and explore their spatial and temporal structure.

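Grad-CAM weights each feature map A^k of a chosen convolutional layer by the spatially averaged gradient of the score y_c, α_k = mean(∂y_c/∂A^k). It then keeps the positive part of the weighted sum, ReLU(Σ_k α_k A^k). The resulting low-resolution map is typically upsampled to the input size for display. The following is a numpy sketch of that combination step. It assumes the layer activations and their gradients (hypothetical arrays here) have already been extracted from the network.

```python
import numpy as np


def grad_cam(activations, gradients):
    """Combine feature maps into a Grad-CAM heatmap for one sample.

    activations, gradients : arrays of shape (*spatial, n_channels), i.e. a layer's
    output A^k and the gradient d(score)/dA^k.
    """
    spatial_axes = tuple(range(activations.ndim - 1))
    alpha = gradients.mean(axis=spatial_axes)  # one weight per channel
    cam = activations @ alpha                  # weighted sum over channels
    cam = np.maximum(cam, 0.0)                 # ReLU: keep only positive evidence
    if cam.max() > 0:
        cam = cam / cam.max()                  # normalise to [0, 1] for display
    return cam
```

Because the weights come from a single (usually coarse) convolutional layer, CAMs are smoother but lower-resolution than saliency maps. This is worth keeping in mind when comparing their spatial and temporal structure.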
In [None]:
# Note: CAMs generated with tf-keras-vis are currently incompatible with TensorFlow 2.9,
# so you may need to downgrade TensorFlow.
from matplotlib import cm
from tf_keras_vis.gradcam import Gradcam

# create CAM object (reuses `model`, `model_modifier` and `data` from the cells above)
obj_cam = Gradcam(model, model_modifier=model_modifier, clone=True)

# generate one attention heatmap per parameter
parameters = [0, 1, 2, 3, 4, 5]
lc = data[0, 0]
maps_cam = {}
for para in parameters:
    def loss(output):
        return output[0][para]  # output shape: (samples, classes)
    maps_cam[para] = obj_cam(loss, lc.reshape(140, 140, 2350, 1), penultimate_layer=-1)  # focus on last convolutional layer

map_cam = maps_cam[parameters[-1]]  # choose which parameter's map to display
print(map_cam.shape)
nslice = 130
plt.figure()
plot_cam = plt.imshow(map_cam[0, nslice, :, :], cmap=cm.hot)
plt.figure()
plot_lc = plt.imshow(lc[nslice, :, :], cmap="EoR", vmin=-150, vmax=30)