# Imports

In [None]:
import hickle
import pandas as pd
import scipy.signal
import shap
import numpy as np
import matplotlib.pyplot as plt
from astropy.table import Table
from scipy import ndimage
import tqdm
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import seaborn as sns
import umap

# Directory holding the hickle catalogues loaded below. It must end with a
# slash because file names are appended by plain string concatenation.
work_dir = (
    '/home/juanpabloalfonzo/Documents/'
    'Manga CNNs/Catalogues/'
)

In [None]:
# Scalar catalogue (one row per galaxy: mangaid, split, log_mstar, log_sfr,
# d4000, their predictions, ...). Built from work_dir like the other
# catalogue loads instead of repeating the absolute path.
filepath = work_dir + 'scalars_extra.cat'
d = hickle.load(filepath)
manga_scalars = pd.DataFrame(d)

In [None]:
# Galaxy images; d_im['images'][k] is drawn for catalogue row k.
filepath = f'{work_dir}images_extra.cat'
d_im = hickle.load(filepath)

In [None]:
# SHAP maps used for the 'shap_mass' (stellar-mass) panels.
filepath = f'{work_dir}shap_mass_extra.cat'
d_shapmass = hickle.load(filepath)

In [None]:
# SHAP maps used for the 'shap_sfr' (log SFR) panels.
filepath = f'{work_dir}shap_sfr_extra.cat'
d_shapsfr = hickle.load(filepath)

In [None]:
# SHAP maps used for the 'shap_age' (d4000) panels.
filepath = f'{work_dir}shap_d4000_extra.cat'
d_shapage = hickle.load(filepath)

In [None]:
def plot_images_maps(xim = 5, yim = 5, imsize = 4, key='images', ids = None, sigma = 2,random_mode = False):
    """Draw a ``yim`` x ``xim`` grid of galaxy images or SHAP-map overlays.

    Reads the notebook-level globals ``d_im``, ``manga_scalars``,
    ``d_shapmass``, ``d_shapsfr`` and ``d_shapage``; galaxy ``ids[i]`` is
    looked up by the same index in all of them.

    Parameters
    ----------
    xim, yim : int
        Number of panel columns and rows.
    imsize : float
        Size of one panel; the figure is ``(imsize*xim, imsize*yim)``.
    key : str
        ``'images'`` draws the galaxy image labelled with Manga-id, split,
        Sersic n, A_V and redshift. ``'shap_mass'``, ``'shap_sfr'`` and
        ``'shap_age'`` draw the smoothed, thresholded SHAP map over a faded
        image, labelled with the observed log M*, log SFR or d4000 and with
        observed minus predicted. Any other value leaves the panels blank.
    ids : sequence of int, optional
        Galaxy indices, one per panel. When ``None`` or empty they are
        chosen according to ``random_mode``. If fewer ids than panels are
        given, the surplus panels are left blank.
    sigma : float
        Width passed to ``scipy.ndimage.gaussian_filter`` for the SHAP map.
        NOTE(review): a scalar sigma also smooths the last axis, which is
        summed before plotting — confirm this is intended.
    random_mode : bool
        Only used when ``ids`` is not given: ``True`` draws ``xim*yim``
        random indices (with replacement), ``False`` uses ``0..xim*yim-1``.

    Returns
    -------
    list
        Unsmoothed copies of the SHAP maps of the selected target, one per
        successfully drawn panel (empty for ``key='images'``).
    """
    n_panels = xim * yim
    if ids is None or len(ids) == 0:
        if random_mode:
            ids = np.random.randint(len(d_im['images']), size=n_panels)
        else:
            ids = np.arange(n_panels)

    def _draw_image(idx):
        """Galaxy image with identification and structural labels."""
        plt.imshow(d_im['images'][idx], origin='lower')
        xlim, ylim = plt.xlim(), plt.ylim()
        plt.text(xlim[1]*0.042, ylim[1]*0.9, 'Manga-id: '+manga_scalars['mangaid'][idx], color='white', fontsize=14)
        plt.text(xlim[1]*0.042, ylim[1]*0.81, manga_scalars['split'][idx], color='white', fontsize=12)
        plt.text(xlim[1]*0.04, ylim[1]*0.09, ' n: %.1f \n A$_V$: %.1f \n z: %.3f' %(manga_scalars['sersic_n'][idx], manga_scalars['Av'][idx], manga_scalars['redshift'][idx]), color='white', fontsize=14)

    def _draw_shap(idx, shap_maps, obs_col, pred_col, obs_fmt, delta_fmt):
        """SHAP overlay for one galaxy; returns an unsmoothed copy of its map."""
        raw = shap_maps[idx].copy()
        # gaussian_filter allocates its output, so ``raw`` stays unsmoothed.
        smoothed = ndimage.gaussian_filter(raw, sigma, mode='nearest')
        # Blank every pixel whose absolute value is below the larger of the
        # 2nd/98th percentiles (i.e. the 98th), keeping only the strongest.
        smoothed[np.abs(smoothed) < np.amax(np.nanpercentile(smoothed, [2, 98]))] = np.nan
        plt.imshow(d_im['images'][idx], alpha=0.4, origin='lower')
        plt.pcolor(np.sum(smoothed, 2), cmap='bwr', alpha=0.9)
        xlim, ylim = plt.xlim(), plt.ylim()
        observed = manga_scalars[obs_col][idx]
        plt.text(xlim[1]*0.042, ylim[1]*0.9, obs_fmt % observed, color='k', fontsize=14)
        plt.text(xlim[1]*0.042, ylim[1]*0.81, delta_fmt % (observed - manga_scalars[pred_col][idx]), color='k', fontsize=12)
        return raw

    plt.subplots(figsize=(imsize*xim,imsize*yim))
    plt.subplots_adjust(hspace=0.01,wspace=0.01)

    shapmaps = []
    for i in range(n_panels):
        plt.subplot(yim,xim,i+1)
        if i >= len(ids):
            # Fewer ids than panels: leave the rest blank instead of raising.
            plt.axis('off')
            continue
        idx = ids[i]
        try:
            if key == 'images':
                _draw_image(idx)
            elif key == 'shap_mass':
                shapmaps.append(_draw_shap(idx, d_shapmass['shap_map_mass'], 'log_mstar', 'pred_mstar',
                                           'log M$_*$: %.2f', r'$\Delta$M$_{*}^{\mathrm{pred}}$: %.1f'))
            elif key == 'shap_sfr':
                shapmaps.append(_draw_shap(idx, d_shapsfr['shap_map_mass'], 'log_sfr', 'pred_sfr',
                                           'log SFR: %.2f', r'$\Delta$SFR$^{\mathrm{pred}}$: %.1f'))
            elif key == 'shap_age':
                shapmaps.append(_draw_shap(idx, d_shapage['shap_map_mass'], 'd4000', 'pred_d4000',
                                           'd4000: %.2f', r'$\Delta$d4000$^{\mathrm{pred}}$: %.1f'))
            plt.axis('off')
        except Exception:
            # Best effort: a galaxy without a usable map/entry gets a blank
            # panel (Exception, not bare except, so Ctrl-C still works).
            plt.axis('off')
            print('no shap for id = ',idx)
    return shapmaps

In [None]:
def plot_images_maps_mask(mask, xim = 5, yim = 5, imsize = 4, key='images', ids = None, sigma = 2,):
    """Draw galaxy images or SHAP overlays for a masked subset of the catalogue.

    Function written to be able to take mask (filters) into account when
    plotting images and shap maps. ``mask`` selects catalogue rows, and
    ``ids[i]`` is a *position within that subset*, not a catalogue row.
    Every lookup is positional (NumPy indexing, ``Series.iloc``), which
    sidesteps the difference between NumPy positions and pandas index labels.

    Reads the notebook-level globals ``d_im``, ``manga_scalars``,
    ``d_shapmass``, ``d_shapsfr`` and ``d_shapage``. The image and SHAP
    catalogues are assumed row-aligned with ``manga_scalars``.

    Parameters
    ----------
    mask : list of int, index array, or boolean array
        Subset of catalogue rows (e.g. ``no_nan`` or an ``np.where`` result).
    xim, yim : int
        Number of panel columns and rows.
    imsize : float
        Size of one panel; the figure is ``(imsize*xim, imsize*yim)``.
    key : str
        ``'images'``, ``'shap_mass'``, ``'shap_sfr'`` or ``'shap_age'``;
        any other value leaves the panels blank.
    ids : sequence of int, optional
        Positions within the masked subset, one per panel; defaults to
        ``0..xim*yim-1``. Surplus panels are left blank.
    sigma : float
        Width passed to ``scipy.ndimage.gaussian_filter`` for the SHAP map.

    Returns
    -------
    list
        Unsmoothed copies of the SHAP maps of the selected target, one per
        successfully drawn panel (empty for ``key='images'``).

    Raises
    ------
    IndexError
        If ``mask`` does not fit the catalogue (raised before plotting).
    """
    n_panels = xim * yim
    if ids is None or len(ids) == 0:
        ids = np.arange(n_panels)

    # Catalogue row of every galaxy in the subset, resolved once instead of
    # converting and masking full-catalogue arrays (images included) per panel.
    rows = np.arange(len(manga_scalars))[mask]

    def _value(column, row):
        """Positional lookup of one scalar-catalogue entry."""
        return manga_scalars[column].iloc[row]

    def _draw_image(row):
        """Galaxy image with identification and structural labels."""
        plt.imshow(d_im['images'][row], origin='lower')
        xlim, ylim = plt.xlim(), plt.ylim()
        plt.text(xlim[1]*0.042, ylim[1]*0.9, 'Manga-id: '+_value('mangaid', row), color='white', fontsize=14)
        plt.text(xlim[1]*0.042, ylim[1]*0.81, _value('split', row), color='white', fontsize=12)
        plt.text(xlim[1]*0.04, ylim[1]*0.09, ' n: %.1f \n A$_V$: %.1f \n z: %.3f' %(_value('sersic_n', row), _value('Av', row), _value('redshift', row)), color='white', fontsize=14)

    def _draw_shap(row, shap_maps, obs_col, pred_col, obs_fmt, delta_fmt):
        """SHAP overlay for one galaxy; returns an unsmoothed copy of its map."""
        raw = shap_maps[row].copy()
        # gaussian_filter allocates its output, so ``raw`` stays unsmoothed.
        smoothed = ndimage.gaussian_filter(raw, sigma, mode='nearest')
        # Blank every pixel whose absolute value is below the larger of the
        # 2nd/98th percentiles (i.e. the 98th), keeping only the strongest.
        smoothed[np.abs(smoothed) < np.amax(np.nanpercentile(smoothed, [2, 98]))] = np.nan
        plt.imshow(d_im['images'][row], alpha=0.4, origin='lower')
        plt.pcolor(np.sum(smoothed, 2), cmap='bwr', alpha=0.9)
        xlim, ylim = plt.xlim(), plt.ylim()
        observed = _value(obs_col, row)
        plt.text(xlim[1]*0.042, ylim[1]*0.9, obs_fmt % observed, color='k', fontsize=14)
        plt.text(xlim[1]*0.042, ylim[1]*0.81, delta_fmt % (observed - _value(pred_col, row)), color='k', fontsize=12)
        return raw

    plt.subplots(figsize=(imsize*xim,imsize*yim))
    plt.subplots_adjust(hspace=0.01,wspace=0.01)

    shapmaps = []
    for i in range(n_panels):
        plt.subplot(yim,xim,i+1)
        if i >= len(ids):
            # Fewer ids than panels: leave the rest blank instead of raising.
            plt.axis('off')
            continue
        try:
            row = rows[ids[i]]
            if key == 'images':
                _draw_image(row)
            elif key == 'shap_mass':
                shapmaps.append(_draw_shap(row, d_shapmass['shap_map_mass'], 'log_mstar', 'pred_mstar',
                                           'log M$_*$: %.2f', r'$\Delta$M$_{*}^{\mathrm{pred}}$: %.1f'))
            elif key == 'shap_sfr':
                shapmaps.append(_draw_shap(row, d_shapsfr['shap_map_mass'], 'log_sfr', 'pred_sfr',
                                           'log SFR: %.2f', r'$\Delta$SFR$^{\mathrm{pred}}$: %.1f'))
            elif key == 'shap_age':
                shapmaps.append(_draw_shap(row, d_shapage['shap_map_mass'], 'd4000', 'pred_d4000',
                                           'd4000: %.2f', r'$\Delta$d4000$^{\mathrm{pred}}$: %.1f'))
            plt.axis('off')
        except Exception:
            # Best effort: a galaxy without a usable map/entry gets a blank
            # panel (Exception, not bare except, so Ctrl-C still works).
            plt.axis('off')
            print('no shap for id = ',ids[i])
    return shapmaps

# Stellar mass - massive and low-mass galaxies

In [None]:
# Single-galaxy SHAP plot for stellar mass using shap's own renderer.
# ``ids`` is not defined yet at this point of the notebook (NameError on a
# top-to-bottom run), so pick the galaxy explicitly: the first one of the
# low-mass sample plotted in the next cell (ascending log M* rank 6).
example_id = np.argsort(np.array(manga_scalars['log_mstar']))[6]
shap.image_plot(d_shapmass['shap_map_mass'][example_id], np.array(d_im['images'][example_id]), labels=None, show=False)

In [None]:
# Low-mass end: six galaxies at ranks 6-11 of the ascending stellar-mass
# ranking (windows [22:30] and [10:16] were also tried).
temp = np.array(manga_scalars['log_mstar'])
ids = np.argsort(temp)[6:12]

sigmaval = 2.0
xim, yim = 6, 1
# One row per view of the same galaxies; the d4000 map gets twice the smoothing.
for panel_key, panel_sigma in [('images', sigmaval),
                               ('shap_mass', sigmaval),
                               ('shap_sfr', sigmaval),
                               ('shap_age', sigmaval * 2)]:
    plot_images_maps(xim, yim, ids=ids, key=panel_key, sigma=panel_sigma)
plt.show()

In [None]:
# High-mass end: ranks -12..-7 of the ascending stellar-mass ranking, i.e.
# the six galaxies just below the top six (an earlier variant took [-12:]
# and dropped the third entry).
temp = np.array(manga_scalars['log_mstar'])
ids = np.argsort(temp)[-12:-6]

sigmaval = 2.0
xim, yim = 6, 1
# One row per view of the same galaxies; the d4000 map gets twice the smoothing.
for panel_key, panel_sigma in [('images', sigmaval),
                               ('shap_mass', sigmaval),
                               ('shap_sfr', sigmaval),
                               ('shap_age', sigmaval * 2)]:
    plot_images_maps(xim, yim, ids=ids, key=panel_key, sigma=panel_sigma)

# sSFR - star-forming and quiescent galaxies

In [None]:
# Star-forming end of the sSFR (log SFR - log M*) ranking.
# NOTE(review): np.argsort puts NaN last and log_sfr contains NaNs, so the
# -140 offset presumably skips them (and possibly outliers) — confirm.
# Only the first xim*yim = 6 of these ids are plotted (earlier: [-116:]).
temp = np.array(manga_scalars['log_sfr']) - np.array(manga_scalars['log_mstar'])
ids = np.argsort(temp)[-140:]
print(temp[ids[:6]])

sigmaval = 4.0
xim, yim = 6, 1
for panel_key in ('images', 'shap_mass', 'shap_sfr', 'shap_age'):
    plot_images_maps(xim, yim, ids=ids, key=panel_key, sigma=sigmaval)

In [None]:
# Next window down the star-forming end of the sSFR ranking.
# NOTE(review): as above, the -160 offset presumably skips the NaNs that
# np.argsort sorts last — confirm. Only the first 6 ids are plotted.
temp = np.array(manga_scalars['log_sfr']) - np.array(manga_scalars['log_mstar'])
ids = np.argsort(temp)[-160:]
print(temp[ids[:6]])

sigmaval = 4.0
xim, yim = 6, 1
for panel_key in ('images', 'shap_mass', 'shap_sfr', 'shap_age'):
    plot_images_maps(xim, yim, ids=ids, key=panel_key, sigma=sigmaval)

In [None]:
# Quiescent end: the six lowest sSFR values starting at rank i
# (i = 30 was also tried).
temp = np.array(manga_scalars['log_sfr']) - np.array(manga_scalars['log_mstar'])
i = 0
ids = np.argsort(temp)[i:i + 6]
print(temp[ids])

sigmaval = 4.0
xim, yim = 6, 1
for panel_key in ('images', 'shap_mass', 'shap_sfr', 'shap_age'):
    plot_images_maps(xim, yim, ids=ids, key=panel_key, sigma=sigmaval)

# Old and young galaxies (d4000)

In [None]:
# Lowest d4000 values: six galaxies starting at rank i (i = 3, a window of
# ten, and dropping entry 5 were also tried).
temp = np.array(manga_scalars['d4000'])
i = 0
ids = np.argsort(temp)[i:i + 6]
print(temp[ids])

sigmaval = 4.0
xim, yim = 6, 1
for panel_key in ('images', 'shap_mass', 'shap_sfr', 'shap_age'):
    plot_images_maps(xim, yim, ids=ids, key=panel_key, sigma=sigmaval)

In [None]:
# High-d4000 end of the ranking. NOTE(review): the -39 offset presumably
# skips NaN/outlier values that np.argsort leaves at the end — confirm.
# Only the first 6 ids are plotted (an offset of -50 was also tried).
temp = np.array(manga_scalars['d4000'])
ids = np.argsort(temp)[-39:]
print(temp[ids[:6]])

sigmaval = 4.0
xim, yim = 6, 1
for panel_key in ('images', 'shap_mass', 'shap_sfr', 'shap_age'):
    plot_images_maps(xim, yim, ids=ids, key=panel_key, sigma=sigmaval)

In [None]:
# A 5x5 grid of randomly drawn galaxies (with replacement), shown in every view.
sigmaval = 1.0
xim, yim = 5, 5
ids = np.random.randint(len(d_im['images']), size=xim*yim)
for panel_key, panel_sigma in [('images', sigmaval),
                               ('shap_mass', sigmaval),
                               ('shap_sfr', sigmaval),
                               ('shap_age', sigmaval * 2)]:
    plot_images_maps(xim, yim, ids=ids, key=panel_key, sigma=panel_sigma)

# Biggest/Smallest Error in Mass

In [None]:
# Six galaxies with the largest |observed - predicted| log stellar mass.
# NaN residuals are excluded first: np.argsort places NaN at the end, so
# they would otherwise be picked as the "biggest" errors. Without NaNs the
# selection is unchanged.
temp = np.array(np.abs(manga_scalars['log_mstar']-manga_scalars['pred_mstar']))
valid = np.flatnonzero(~np.isnan(temp))
ids = valid[np.argsort(temp[valid])[-6:]]

sigmaval = 2.0
xim = 6
yim = 1
plot_images_maps(xim,yim, ids=ids, key='images', sigma=sigmaval)
plot_images_maps(xim,yim, ids=ids, key='shap_mass', sigma=sigmaval)
plot_images_maps(xim,yim, ids=ids, key='shap_sfr', sigma=sigmaval)
plot_images_maps(xim,yim, ids=ids, key='shap_age', sigma=sigmaval * 2); plt.show()

In [None]:
# Six galaxies with the smallest |observed - predicted| log stellar mass.
temp = np.array(np.abs(manga_scalars['log_mstar']-manga_scalars['pred_mstar']))
ids = np.argsort(temp)[:6]

sigmaval = 2.0
xim, yim = 6, 1
for panel_key, panel_sigma in [('images', sigmaval),
                               ('shap_mass', sigmaval),
                               ('shap_sfr', sigmaval),
                               ('shap_age', sigmaval * 2)]:
    plot_images_maps(xim, yim, ids=ids, key=panel_key, sigma=panel_sigma)
plt.show()

# Biggest/Smallest Error in SFR

In [None]:
# Positions of galaxies with a measured (non-NaN) log SFR, used below as the
# mask for plot_images_maps_mask. Vectorised: the previous loop rebuilt the
# whole column as an array on every iteration (quadratic) and detected NaN
# by string comparison.
no_nan = np.flatnonzero(~np.isnan(np.array(manga_scalars['log_sfr'], dtype=float))).tolist()

In [None]:
# Largest |observed - predicted| log SFR among galaxies with a measured SFR.
# ``ids`` are positions within ``no_nan``, hence the mask-aware plotter.
sfr_obs = np.array(manga_scalars['log_sfr'])[no_nan]
sfr_pred = np.array(manga_scalars['pred_sfr'])[no_nan]
temp = np.abs(sfr_obs - sfr_pred)
ids = np.argsort(temp)[-6:]

sigmaval = 2.0
xim, yim = 6, 1
for panel_key, panel_sigma in [('images', sigmaval),
                               ('shap_mass', sigmaval),
                               ('shap_sfr', sigmaval),
                               ('shap_age', sigmaval * 2)]:
    plot_images_maps_mask(no_nan, xim, yim, ids=ids, key=panel_key, sigma=panel_sigma)
plt.show()

In [None]:
# Smallest |observed - predicted| log SFR among galaxies with a measured SFR.
# ``ids`` are positions within ``no_nan``, hence the mask-aware plotter.
sfr_obs = np.array(manga_scalars['log_sfr'])[no_nan]
sfr_pred = np.array(manga_scalars['pred_sfr'])[no_nan]
temp = np.abs(sfr_obs - sfr_pred)
ids = np.argsort(temp)[:6]

sigmaval = 2.0
xim, yim = 6, 1
for panel_key, panel_sigma in [('images', sigmaval),
                               ('shap_mass', sigmaval),
                               ('shap_sfr', sigmaval),
                               ('shap_age', sigmaval * 2)]:
    plot_images_maps_mask(no_nan, xim, yim, ids=ids, key=panel_key, sigma=panel_sigma)
plt.show()

# High-Mass Star-Forming Galaxies

Investigating the "turn-over" at the high-mass end of the star-forming main sequence (SFMS). Is the turn-over caused by the growth of a quiescent bulge component while only the disk continues to form stars, as reported in some of the literature?

In [None]:
# Test-split galaxies with log M* > 11.5 and log SFR > 0; of these, the six
# most massive (``ids`` are positions within ``q``).
is_massive = np.array(manga_scalars['log_mstar']) > 11.5
is_star_forming = np.array(manga_scalars['log_sfr']) > 0
is_test = np.array(manga_scalars['split'], dtype=str) == 'Test'
q = np.where(is_massive & is_star_forming & is_test)[0]
temp = np.array(manga_scalars['log_mstar'])[q]
ids = np.argsort(temp)[-6:]

sigmaval = 2.0
xim, yim = 6, 1
for panel_key, panel_sigma in [('images', sigmaval),
                               ('shap_mass', sigmaval),
                               ('shap_sfr', sigmaval),
                               ('shap_age', sigmaval * 2)]:
    plot_images_maps_mask(q, xim, yim, ids=ids, key=panel_key, sigma=panel_sigma)