In [None]:
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np
import seaborn as sns
from speclearn.deep_learning.model_utils import (get_colorbar,
                                                 load_beta_VAE_model)
from speclearn.deep_learning.predict import get_data
from speclearn.plot.map import *
from speclearn.tools.cache import check_file
from speclearn.tools.data_tools import *
# Wavelength axis used as the x-values of every spectrum plotted below.
# `select_wavelength` is not imported by name; it reaches this scope through
# one of the wildcard imports above.
# NOTE(review): presumably s_0 / s_1 are start/stop indices that drop the last
# 12 bands — confirm against select_wavelength's definition.
local_wavelength = select_wavelength(s_0=0, s_1=-12)

# NOTE(review): CACHE_CLUSTER and FIGURE_DIR (used further down) presumably come
# from this wildcard import — confirm, and prefer importing them by name.
from speclearn.tools.constants import *
from speclearn.io.data.aoi import get_full_map_aoi_longitude

# Global seaborn style/context applied to every figure in this notebook.
sns.set_style('whitegrid')
sns.set_context('notebook')

In [2]:
# --- Run configuration -------------------------------------------------------
# k    : number of clusters; selects the colour list and the cached cluster map.
# crs  : forwarded to load_beta_VAE_model() and get_data().
# norm : forwarded to load_beta_VAE_model() and get_data().
# area : not referenced by any of the cells visible in this notebook.
# full : toggles the '_full' suffix of the saved figure's file name.
k = 5
crs = False
area = 'large'
norm = False
full = True

# Suffix appended to the output figure name.
full_name = '_full' if full else ''

# One colour per cluster (clist) plus the matching colormap (cmap).
clist, cmap = get_colorbar(k)

In [None]:
# Areas of interest passed to get_data() in the next cell.
# NOTE(review): step_size=20 is presumably a longitude step in degrees — confirm.
aoi_list = get_full_map_aoi_longitude(
    step_size=20,
)

# The loader returns the model together with the name that is later used to
# build cache and figure file names.
model, model_name = load_beta_VAE_model(
    crs=crs,
    norm=norm,
)
print('model:', model_name)

In [4]:
# Fetch the inputs and model outputs for all areas of interest.
#   data_2d : 3-D array; its shape is reused below to size the reconstruction map
#   coord   : sequence of (x, y) index pairs, one per entry of `recon`
#   latent  : not used further in this notebook (released in the next cell)
#   recon   : per-pixel reconstructions, indexed in step with `coord`
data_2d, coord, latent, recon = get_data(
    aoi_list,
    model_name=model_name,
    crs=crs,
    periods=[],
    norm=norm,
)

In [6]:
# Scatter the flat list of reconstructed spectra back onto the map grid.
# The grid has the same first three dimensions as data_2d; cells without a
# reconstruction stay NaN so the nan-aware statistics further down skip them.
grid_shape = (data_2d.shape[0], data_2d.shape[1], data_2d.shape[2])
recon_2d = np.full(grid_shape, np.nan)

for i, xy in enumerate(coord):
    x, y = xy
    recon_2d[x, y, :] = recon[i]

# Drop the references to the large arrays that are no longer needed.
# From here on `data_2d` is None — later cells must work from `recon_2d`.
data_2d = latent = None

In [7]:
# Width (in grid cells along axis 1) of each polar band.
# NOTE(review): the figure legend describes these bands as |latitude| > 70,
# so 400 cells presumably span 20 degrees — confirm against the map resolution.
POLAR_BAND_ROWS = 400

# Load the cached cluster map for this model / k. Fail right here with a clear
# message when the cache is missing; otherwise `cluster_2d` would be undefined
# and the slicing below would raise an unrelated NameError.
cluster_path = os.path.join(CACHE_CLUSTER, f'{model_name}_{k}_cluster_2d.npy')
if not check_file(cluster_path):
    raise FileNotFoundError(
        f'Cached cluster map not found: {cluster_path} '
        f'(model_name={model_name!r}, k={k}).'
    )
cluster_2d = np.load(cluster_path)

# Split the cluster map and the reconstruction map into three bands along
# axis 1: first POLAR_BAND_ROWS cells (south), last POLAR_BAND_ROWS cells
# (north) and everything in between (central).
cluster_2d_s = cluster_2d[:, :POLAR_BAND_ROWS]
cluster_2d_n = cluster_2d[:, -POLAR_BAND_ROWS:]
cluster_2d_c = cluster_2d[:, POLAR_BAND_ROWS:-POLAR_BAND_ROWS]

recon_2d_s = recon_2d[:, :POLAR_BAND_ROWS, :]
recon_2d_n = recon_2d[:, -POLAR_BAND_ROWS:, :]
recon_2d_c = recon_2d[:, POLAR_BAND_ROWS:-POLAR_BAND_ROWS, :]

In [None]:
def _normalised_mean(spectra):
    """Return the normalised NaN-aware mean spectrum of ``spectra``.

    ``spectra`` is the 2-D selection (pixels x bands) obtained by masking a
    map with a cluster label; the mean is taken over the pixel axis and then
    passed through ``normalize_data``.
    """
    return normalize_data(np.nanmean(spectra, axis=0))


fig, axs = plt.subplots(
    ncols=3, nrows=2, figsize=(12, 5), sharex=True, sharey=True,
    gridspec_kw={'hspace': 0.1, 'wspace': 0.1},
)

# The top-left panel is removed and its slot is used for the legend, which
# leaves 5 panels: the grid is sized for k = 5 clusters.
fig.delaxes(axs[0, 0])

# With sharey=True only the first column shows y tick labels. axs[0, 1] is now
# the left-most visible panel of the top row, so re-enable its labels.
axs[0, 1].set_ylabel('Norm. CRS')
axs[0, 1].yaxis.set_tick_params(labelleft=True)
axs[1, 0].set_ylabel('Norm. CRS')

for bottom_ax in axs[1, :]:
    bottom_ax.set_xlabel('Wavelength [nm]')

# Polar bands drawn on top of the central spectrum: (line style, spectra, labels).
polar_bands = (
    ('dashed', recon_2d_s, cluster_2d_s),  # south
    ('dotted', recon_2d_n, cluster_2d_n),  # north
)

# Handles and labels for the shared legend.
legend_handles = []
legend_labels = []

# NOTE(review): `data_2d` was set to None in an earlier cell and no
# `data_2d_s` / `data_2d_n` slices exist, so this cell works from the
# reconstructed spectra (`recon_2d_*`), the only per-band arrays available.
# Confirm that reconstructions (not the input spectra) are what should be shown.
for cluster_id, ax in enumerate(axs.ravel()[1:]):
    color = clist[cluster_id]
    cluster_name = f'Cluster {cluster_id + 1}'  # clusters are shown 1-based

    # Solid line: central band, matching the 'Central' legend entry.
    central = recon_2d_c[cluster_2d_c == cluster_id]
    y = _normalised_mean(central)
    # NOTE(review): the spread is the raw nanstd while `y` went through
    # normalize_data, so the band may not be on the same scale as the line.
    y_error = np.nanstd(central, axis=0)

    (central_line,) = ax.plot(local_wavelength, y, color=color, lw=2.0, label=cluster_name)
    ax.fill_between(local_wavelength, y - y_error, y + y_error,
                    color=color, alpha=0.3, edgecolor=None)
    legend_handles.append(central_line)
    legend_labels.append(cluster_name)

    # Dashed / dotted lines: same cluster restricted to the polar bands.
    for linestyle, band_spectra, band_labels in polar_bands:
        band_mean = _normalised_mean(band_spectra[band_labels == cluster_id])
        ax.plot(local_wavelength, band_mean, color=color, lw=2.0, ls=linestyle)

    ax.set_ylim(-0.2, 1.1)
    ax.set_xlim(450, 2550)

# Black proxy lines explaining the three line styles.
line = Line2D([0], [0], label='Central', color='black')
line_s = Line2D([0], [0], label='South', color='black', ls='dashed')
line_n = Line2D([0], [0], label='North', color='black', ls='dotted')
legend_handles.extend([line, line_s, line_n])
legend_labels.extend([
    'Central (|latitude| < 70)',
    'South (latitude < -70)',
    'North (latitude > 70)',
])

# Anchor the legend left of axs[0, 1], i.e. in the slot of the removed panel.
legend = axs[0, 1].legend(legend_handles, legend_labels,
                          bbox_to_anchor=(-.69, 0.5), loc='center', ncol=1)

plt.savefig(FIGURE_DIR + f'spectra_clusters_crs_{model_name}_{k}{full_name}_per_latitude.png',bbox_inches='tight', dpi=400)
plt.show()