In [None]:
# Imports and global notebook configuration.
# NOTE: the relative order of the imports below is kept on purpose — the star imports
# can shadow names bound earlier, so reordering could change what a name resolves to.
import os  # used by later cells (os.path.join); do not rely on a star import providing it

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np
import seaborn as sns
from speclearn.deep_learning.model_utils import (get_colorbar,
                                                 load_beta_VAE_model)
from speclearn.deep_learning.predict import (get_full_data, get_data,
                                             process_full_map, read_area)
from speclearn.io.data.aoi import get_full_map_aoi
from speclearn.plot.map import *
from speclearn.tools.cache import check_file
from speclearn.tools.data_tools import *
from pysptools.spectro import FeaturesConvexHullQuotient
# Wavelength grid used as the x axis of every spectrum plot in this notebook
# (select_wavelength presumably comes from one of the star imports above).
local_wavelength = select_wavelength(s_0=0, s_1=-12)
import datetime

print('Current time: ', datetime.datetime.now())

from speclearn.tools.constants import *
from speclearn.io.data.aoi import get_full_map_aoi_longitude

sns.set_style('whitegrid')
sns.set_context('notebook')

In [2]:
# Run configuration shared by every cell below.
k = 6          # cluster count; also part of the cluster-cache and figure file names
crs = False    # forwarded to load_beta_VAE_model / get_data
area = 'large'
norm = False   # forwarded to load_beta_VAE_model / get_data
full = True

# Suffix appended to saved file names when working on the full map.
full_name = '_full' if full else ''

clist, cmap = get_colorbar(k)

In [None]:
# Areas of interest stepped along longitude (step_size=20) and the beta-VAE that matches
# the crs / norm settings of the config cell.
aoi_list = get_full_map_aoi_longitude(step_size=20)
model, model_name = load_beta_VAE_model(norm=norm, crs=crs)
print(f'model: {model_name}')

In [4]:
# Data cube (indexed [x, y, band] in the next cell), pixel coordinates, latent codes and
# VAE reconstructions for all areas of interest.
data_2d, coord, latent, recon = get_data(
    aoi_list,
    model_name=model_name,
    crs=crs,
    periods=[],
    norm=norm,
)

In [6]:
# Scatter the per-pixel reconstructions back onto the (x, y) grid of the data cube;
# pixels without a reconstruction stay NaN.
recon_2d = np.full(data_2d.shape[:3], np.nan)
# Vectorised scatter instead of a Python loop over every pixel. Assumes `coord` holds
# integer (x, y) pairs in the same order as the rows of `recon` (as the loop form did).
coord_xy = np.asarray(coord)
if coord_xy.size:
    recon_2d[coord_xy[:, 0], coord_xy[:, 1], :] = np.asarray(recon)

# Release the large inputs that are no longer needed.
# NOTE(review): later cells (per-latitude plot, k-means) still read `data_2d` / `latent`
# and will receive None — confirm that is intended.
data_2d = None
latent = None

In [7]:
# Load the cached cluster map for this model / k, then split the cluster map and the
# gridded reconstructions into three bands along axis 1 (first / last POLE_WIDTH columns
# and the remainder).
# NOTE(review): the per-latitude figure labels these bands South / North / Central
# (|latitude| >/< 70) — confirm that 400 columns corresponds to that cut.
POLE_WIDTH = 400  # columns at each end of axis 1 treated as a polar band

cluster_file = os.path.join(CACHE_CLUSTER, f'{model_name}_{k}_cluster_2d.npy')
if not check_file(cluster_file):
    # Fail loudly instead of hitting a NameError (or silently reusing a stale cluster_2d
    # left in the kernel from an earlier run).
    raise FileNotFoundError(f'Cluster cache not found: {cluster_file}')
cluster_2d = np.load(cluster_file)

cluster_2d_s = cluster_2d[:, :POLE_WIDTH]
cluster_2d_n = cluster_2d[:, -POLE_WIDTH:]
cluster_2d_c = cluster_2d[:, POLE_WIDTH:-POLE_WIDTH]

recon_2d_s = recon_2d[:, :POLE_WIDTH, :]
recon_2d_n = recon_2d[:, -POLE_WIDTH:, :]
recon_2d_c = recon_2d[:, POLE_WIDTH:-POLE_WIDTH, :]

In [8]:
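# NOTE(review): this cell is disabled, but the mineral-matching and crs_mean plot cells
# below read `crs_mean` / `crs_std`, which are only defined here — re-enable it (or load
# the cached .npy files) before running those cells on a fresh kernel.
# crs_mean_file = f'crs_mean_{model_name}_{k}{full_name}.npy'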
# crs_std_file = f'crs_std_{model_name}_{k}{full_name}.npy'
# crs_exists = False

# if check_file(crs_mean_file):
#     crs_mean = np.load(crs_mean_file)
#     crs_std = np.load(crs_std_file)
#     crs_exists = True
# else:
#     crs_mean = []
#     crs_std = []

#     for c in range(0,5):
#         if c < 5:
#             if not crs_exists:
#                 spectra = recon_2d[cluster_2d==c]
#                 spectra = process_spectra(spectra, norm=False, crs=True)
#                 crs_mean.append(np.nanmean(spectra, axis=0))
#                 crs_std.append(np.nanstd(spectra, axis=0))

In [8]:
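# NOTE(review): this variant reads `data_2d`, which is set to None right after the
# reconstructions are gridded; re-enabling it as written would fail (`recon_2d` is the
# populated array at this point).
# crs_mean_file = f'crs_mean_{model_name}_{k}{full_name}.npy'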
# crs_std_file = f'crs_std_{model_name}_{k}{full_name}.npy'

# crs_exists = False
# if check_file(crs_mean_file):
#     crs_mean = np.load(crs_mean_file)
#     crs_std = np.load(crs_std_file)
#     crs_exists = True
# else:
#     crs_mean = []
#     crs_std = []

#     for c in range(0,5):
#         if c < 5:
#             if not crs_exists:
#                 spectra = data_2d[cluster_2d==c]
#                 spectra = process_spectra(spectra, norm, crs=True)
#                 crs_mean.append(np.nanmean(spectra, axis=0))
#                 crs_std.append(np.nanstd(spectra, axis=0))

# np.save(f'crs_mean_{model_name}_{k}{full_name}.npy', np.array(crs_mean))
# np.save(f'crs_std_{model_name}_{k}{full_name}.npy', np.array(crs_std))

In [None]:
# Median reconstructed spectrum per cluster (line) with its interquartile range (band).
# Panels 0-4 hold clusters 0-4; the sixth panel only hosts the shared legend.
fig, axs = plt.subplots(ncols=3, nrows=2, figsize=(12, 5), sharex=True, sharey=True, gridspec_kw={'hspace': 0.1, 'wspace': 0.1})

# Handles and labels collected while plotting, drawn together in the spare panel.
legend_handles = []
legend_labels = []
# NOTE(review): the label says "Norm. CRS", but recon_2d comes from a model loaded with
# crs=False / norm=False and is not normalised in this cell — confirm the label.
axs[0,0].set_ylabel('Norm. CRS')
axs[1,0].set_ylabel('Norm. CRS')
for ax in axs[1, :]:
    ax.set_xlabel('Wavelength [nm]')

for c, ax in enumerate(axs.ravel()):
    if c < 5:
        spectra = recon_2d[cluster_2d == c]
        # Keep only pixels whose spectrum has no NaN band.
        spectra = spectra[~np.isnan(spectra).any(axis=1)]
        if len(spectra) == 0:
            # Nothing to summarise (np.quantile raises on an empty sample); leave panel blank.
            continue

        y = np.median(spectra, axis=0)
        y_q1, y_q3 = np.quantile(spectra, [0.25, 0.75], axis=0)

        # Plain line plot: no error values are passed, so errorbar would draw just this line.
        (median_line,) = ax.plot(local_wavelength, y, c=clist[c], lw=2.0, label=f'Cluster {c+1}')
        ax.fill_between(local_wavelength, y_q1, y_q3, color=clist[c], alpha=0.4, edgecolor=None)

        ax.set_ylim(-0.15,1.1)
        ax.set_xlim(450,2550)

        legend_handles.append(median_line)
        legend_labels.append(f'Cluster {c + 1}')
    else:
        legend = ax.legend(legend_handles, legend_labels,loc='center left')

#plt.savefig(f'spectra_clusters_crs_{model_name}_{k}{full_name}_crs.png', bbox_inches='tight', dpi=200)
plt.show()

In [None]:
# Quick look at individual spectra on the diagonal pixels (10..19, 10..19): each NaN-free
# spectrum is drawn in its own figure and its standard deviation is printed.
# NOTE(review): `data_2d_full` is not defined in the visible notebook — this cell relies
# on state from another/deleted cell and fails on a fresh kernel.
for idx in range(10, 20):
    pixel_spectrum = data_2d_full[idx, idx, :]
    if not np.isnan(pixel_spectrum).any():
        plt.errorbar(local_wavelength, pixel_spectrum, alpha=0.7)
        print(np.std(pixel_spectrum))
        plt.show()

In [None]:
# Normalised median spectrum per cluster of `data_2d_full` (line) with its interquartile
# range (band). Panels 0-4 hold clusters 0-4; the sixth panel only hosts the legend.
# NOTE(review): `data_2d_full` and `cluster_2d_full` are not defined in the visible
# notebook, so this cell depends on state from another/deleted cell.
fig, axs = plt.subplots(ncols=3, nrows=2, figsize=(12, 5), sharex=True, sharey=True, gridspec_kw={'hspace': 0.1, 'wspace': 0.1})

# Handles and labels collected while plotting, drawn together in the spare panel.
legend_handles = []
legend_labels = []
axs[0,0].set_ylabel('Norm. CRS')
axs[1,0].set_ylabel('Norm. CRS')
for ax in axs[1, :]:
    ax.set_xlabel('Wavelength [nm]')

for c, ax in enumerate(axs.ravel()):
    if c < 5:
        spectra = data_2d_full[cluster_2d_full == c]
        # Keep only pixels whose spectrum has no NaN band.
        spectra = spectra[~np.isnan(spectra).any(axis=1)]
        if len(spectra) == 0:
            # Nothing to summarise (np.quantile raises on an empty sample); leave panel blank.
            continue

        y = normalize_data(np.median(spectra, axis=0))
        # FIXME: only the median goes through normalize_data; the quartiles stay on the raw
        # scale, so the shaded band is not on the same scale as the line it surrounds.
        y_q1, y_q3 = np.quantile(spectra, [0.25, 0.75], axis=0)

        (median_line,) = ax.plot(local_wavelength, y, c=clist[c], lw=2.0, label=f'Cluster {c+1}')
        ax.fill_between(local_wavelength, y_q1, y_q3, color=clist[c], alpha=0.4, edgecolor=None)

        ax.set_ylim(-0.15,1.1)
        ax.set_xlim(450,2550)

        legend_handles.append(median_line)
        legend_labels.append(f'Cluster {c + 1}')
    else:
        legend = ax.legend(legend_handles, legend_labels,loc='center left')

#plt.savefig(f'spectra_clusters_crs_{model_name}_{k}{full_name}_crs.png', bbox_inches='tight', dpi=200)
plt.show()

In [None]:
# Normalised mean spectrum per cluster of `data_2d_full` with a +/-1 std band; panels 0-4
# hold clusters 0-4 and the remaining panel only hosts the shared legend.
# NOTE(review): `data_2d_full` / `cluster_2d_full` are not defined in the visible notebook.
# FIXME: the mean is passed through normalize_data but the std is not, so the band widths
# are on the raw scale while the line is normalised.
fig, axs = plt.subplots(ncols=3, nrows=2, figsize=(12, 5), sharex=True, sharey=True, gridspec_kw={'hspace': 0.1, 'wspace': 0.1})

legend_handles, legend_labels = [], []
for row in (0, 1):
    axs[row, 0].set_ylabel('Norm. CRS')
for col in range(3):
    axs[1, col].set_xlabel('Wavelength [nm]')

flat_axes = axs.ravel()
for c, ax in enumerate(flat_axes[:5]):
    cluster_spectra = data_2d_full[cluster_2d_full == c]
    mean_spectrum = normalize_data(np.nanmean(cluster_spectra, axis=0))
    std_spectrum = np.nanstd(cluster_spectra, axis=0)

    ax.errorbar(local_wavelength, mean_spectrum, c=clist[c], lw=2.0, label=f'Cluster {c+1}')
    ax.fill_between(local_wavelength, mean_spectrum - std_spectrum, mean_spectrum + std_spectrum,
                    color=clist[c], alpha=0.4, edgecolor=None)
    ax.set_ylim(-0.15, 1.1)
    ax.set_xlim(450, 2550)

    legend_handles.append(ax.lines[-1])
    legend_labels.append(f'Cluster {c + 1}')

for ax in flat_axes[5:]:
    legend = ax.legend(legend_handles, legend_labels, loc='center left')

#plt.savefig(f'spectra_clusters_crs_{model_name}_{k}{full_name}_crs.png', bbox_inches='tight', dpi=200)
plt.show()

In [None]:
# Fixed five-colour palette (replaces the clist/cmap from get_colorbar for the cells
# below); values under the colormap range are drawn in black.
import matplotlib

clist = [
    '#432e6b',
    '#4580ba',
    '#b3b3b3',
    '#7cd250',
    '#fbeb37',  # alternative: #fced69
]
cmap = matplotlib.colors.ListedColormap(clist, name="")
cmap.set_under('black')
cmap

In [None]:
# Normalised mean reconstructed spectrum per cluster, split by latitude band:
# central (solid, with +/-1 std band), south (dashed) and north (dotted).
# Panel [0, 0] is deleted and its space is used for the legend.
fig, axs = plt.subplots(ncols=3, nrows=2, figsize=(12, 5), sharex=True, sharey=True, gridspec_kw={'hspace': 0.1, 'wspace': 0.1})

legend_handles = []
legend_labels = []
axs[0,1].set_ylabel('Norm. CRS')
# sharey hides y tick labels outside the first column; re-enable them on [0, 1], the
# left-most remaining panel of the top row.
axs[0,1].yaxis.set_tick_params(labelleft=True)
axs[1,0].set_ylabel('Norm. CRS')
for ax in axs[1, :]:
    ax.set_xlabel('Wavelength [nm]')

# remove subplot
fig.delaxes(axs[0,0])


def _band_mean_std(spectra_2d, clusters_2d, cluster_id):
    """Return (normalised mean, std) over bands of the spectra assigned to `cluster_id`."""
    spectra = spectra_2d[clusters_2d == cluster_id]
    return normalize_data(np.nanmean(spectra, axis=0)), np.nanstd(spectra, axis=0)


# Spectra come from the reconstruction slices prepared alongside the cluster slices:
# `data_2d` is None by this point and `data_2d_s` / `data_2d_n` were never defined.
# The "Central" line uses the central slice so it matches its legend entry.
# FIXME: the std band is on the raw scale while the mean is passed through normalize_data.
for c, ax in enumerate(axs.ravel()):
    if c == 0:
        continue  # deleted panel
    cluster_id = c - 1
    color = clist[cluster_id]

    y, y_error = _band_mean_std(recon_2d_c, cluster_2d_c, cluster_id)
    (central_line,) = ax.plot(local_wavelength, y, c=color, lw=2.0, label=f'Cluster {c}')
    ax.fill_between(local_wavelength, y-y_error, y+y_error, color=color, alpha=0.3, edgecolor=None)
    legend_handles.append(central_line)
    legend_labels.append(f'Cluster {c}')

    y, _ = _band_mean_std(recon_2d_s, cluster_2d_s, cluster_id)
    ax.plot(local_wavelength, y, c=color, lw=2.0, label=f'Cluster {c}', ls='dashed')

    y, _ = _band_mean_std(recon_2d_n, cluster_2d_n, cluster_id)
    ax.plot(local_wavelength, y, c=color, lw=2.0, label=f'Cluster {c}', ls='dotted')

    ax.set_ylim(-0.2,1.1)
    ax.set_xlim(450,2550)

# Proxy artists explaining the line styles.
line = Line2D([0], [0], label='Central', color='black')
line_n = Line2D([0], [0], label='North', color='black', ls='dotted')
line_s = Line2D([0], [0], label='South', color='black',ls='dashed')
legend_handles += [line, line_s, line_n]
legend_labels += ['Central (|latitude| < 70)', 'South (latitude < -70)', 'North (latitude > 70)']
legend = axs[0,1].legend(legend_handles, legend_labels, bbox_to_anchor=(-.69, 0.5), loc='center', ncol=1)

# os.path.join works whether or not FIGURE_DIR ends with a separator.
plt.savefig(os.path.join(FIGURE_DIR, f'spectra_clusters_crs_{model_name}_{k}{full_name}_per_latitude.png'), bbox_inches='tight', dpi=400)
plt.show()

# Poles map

In [None]:
# K-means over the latent space; the .pkl name is presumably the cache file used by
# cluster_with_kmeans.
# NOTE(review): `data_2d` and `latent` were set to None after the reconstructions were
# gridded, so this only works if cluster_with_kmeans can load a cached model by name.
kmeans_name = "kmeans_{}_{}{}.pkl".format(k, model_name, full_name)
kmeans = cluster_with_kmeans(kmeans_name, data_2d, latent, k)

In [None]:
# Assign each reference mineral to the cluster whose mean CRS spectrum is closest, once
# by Euclidean distance and once by cosine similarity, and record both in mineral_groups.
# NOTE(review): `mineral_groups`, `crs_mean` and `cosine_similarity` are not defined or
# imported in the visible notebook.
eu_cluster = []
cos_cluster = []

cluster_means = [np.array(crs_mean[c]) for c in range(k)]

for mineral_name in mineral_groups['Mineral'].unique():
    print(mineral_name)
    is_mineral = mineral_groups['Mineral'] == mineral_name
    mineral_crs = np.array(mineral_groups[is_mineral]['CRS'])

    eu_dist = [np.linalg.norm(mineral_crs - mean) for mean in cluster_means]
    cos_dist = [cosine_similarity(np.array([mineral_crs]), np.array([mean])) for mean in cluster_means]

    best_eu = np.argmin(eu_dist)
    best_cos = np.argmax(cos_dist)
    eu_cluster.append(best_eu)
    cos_cluster.append(best_cos)

    mineral_groups.loc[is_mineral, 'Ec Cluster'] = best_eu
    mineral_groups.loc[is_mineral, 'Cos Cluster'] = best_cos

print(eu_cluster)
print(cos_cluster)

In [None]:
# Normalised mean CRS spectrum of each cluster, one panel per cluster, with dashed guides
# at 1000 nm and 2000 nm; the last panel only hosts the legend.
from speclearn.deep_learning.ml_tools import normalize_data
sns.set_context('notebook')
fig, axs = plt.subplots(ncols=4, nrows=2, figsize=(15, 6), sharex=True, sharey=True, gridspec_kw={'hspace': 0.1, 'wspace': 0.1})

# Handles and labels collected while plotting, drawn together in the last panel.
legend_handles = []
legend_labels = []
axs[0,0].set_ylabel('Norm. CRS')
axs[1,0].set_ylabel('Norm. CRS')
axs[1,0].set_xlabel('Wavelength [nm]')
axs[1,1].set_xlabel('Wavelength [nm]')
axs[1,2].set_xlabel('Wavelength [nm]')

flat_axes = axs.ravel()
# One panel per available cluster mean, never more than the panels before the legend
# panel (a hard-coded 7 raised IndexError whenever crs_mean held fewer entries).
n_panels = min(len(crs_mean), len(flat_axes) - 1)
# NOTE(review): `colors` is not defined in the visible notebook (other figures use
# `clist`) — presumably it comes from one of the star imports; confirm.
for c, ax in enumerate(flat_axes[:n_panels]):
    (mean_line,) = ax.plot(local_wavelength, normalize_data(crs_mean[c]), c=colors[c], lw=1.5)
    legend_handles.append(mean_line)
    #sns.lineplot(data=mineral_groups[mineral_groups['Cos Cluster']==c], x='Wavelength', y='CRS', hue='Mineral', ax=ax, palette='flare')

    ax.set_ylim(-0.25,1.05)
    ax.set_xlim(450,2550)

    legend_labels.append(f'Cluster {c + 1}')

    ax.vlines(1000, -0.25, 1.05, color='grey', linestyle='--')
    ax.vlines(2000, -0.25, 1.05, color='grey', linestyle='--')

legend = flat_axes[-1].legend(legend_handles, legend_labels,loc='center left')
plt.savefig(f'spectra_clusters_{k}_crs_wo_error.png', bbox_inches='tight', dpi=200)
plt.show()

In [None]:
# Mean CRS spectrum (+/- std error bars) of each cluster, overlaid with the reference
# minerals assigned to that cluster by Euclidean distance ('Ec Cluster'); dashed guides at
# 1000 nm and 2000 nm. The last panel only hosts the legend.
sns.set_context('notebook')
fig, axs = plt.subplots(ncols=4, nrows=2, figsize=(15, 6), sharex=True, sharey=True, gridspec_kw={'hspace': 0.1, 'wspace': 0.1})

# Handles and labels collected while plotting, drawn together in the last panel.
legend_handles = []
legend_labels = []
axs[0,0].set_ylabel('Norm. Reflectance')
axs[1,0].set_ylabel('Norm. Reflectance')
axs[1,0].set_xlabel('Wavelength [nm]')
axs[1,1].set_xlabel('Wavelength [nm]')
axs[1,2].set_xlabel('Wavelength [nm]')

flat_axes = axs.ravel()
# One panel per available cluster, never more than the panels before the legend panel
# (a hard-coded 7 raised IndexError whenever crs_mean / crs_std held fewer entries).
n_panels = min(len(crs_mean), len(crs_std), len(flat_axes) - 1)
# NOTE(review): `colors` is not defined in the visible notebook — presumably provided by
# one of the star imports; confirm.
for c, ax in enumerate(flat_axes[:n_panels]):
    container = ax.errorbar(local_wavelength, crs_mean[c], yerr=crs_std[c], c=colors[c],lw=1.5, label=f'Cluster {c+1}')
    # container[0] is the data line; ax.lines[-1] would pick up an errorbar cap line
    # whenever a non-zero capsize is in effect.
    legend_handles.append(container[0])
    sns.lineplot(data=mineral_groups[mineral_groups['Ec Cluster']==c], x='Wavelength', y='CRS', hue='Mineral', ax=ax, palette='flare')

    ax.set_ylim(-0.05,1.05)
    ax.set_xlim(450,2550)

    legend_labels.append(f'Cluster {c + 1}')

    ax.vlines(1000, 0, 1., color='grey', linestyle='--')
    ax.vlines(2000, 0, 1., color='grey', linestyle='--')

legend = flat_axes[-1].legend(legend_handles, legend_labels,loc='center left')
plt.savefig(f'spectra_clusters_{k}_crs.png', bbox_inches='tight', dpi=200)
plt.show()