# pl_3d_plot
Plots and decomposes a 3D (z-stack) photoluminescence (PL) spectral image.

# Imports

In [1]:
%matplotlib qt
import numpy as np 
import matplotlib.pyplot as plt
from spec_im import gui_fname, plot_pl_summary, plot_pl_summary, plot_si_bands, plot_bss_results, plot_decomp_results
from spec_im import PLSpectralImage
import math
import hyperspy.api as hs
import seaborn as sns
import os

# Use seaborn's default theme, with tick marks on the axes.
sns.set()
sns.set_style('ticks')

# Select the traitsui toolkit for HyperSpy's GUI widgets.
hs.preferences.gui(toolkit="traitsui")

In [2]:
def plot_3d_bss_results(s, spec_im, title='', cmap='gray', fig_rows=5, **kwargs):
    """Plot the blind-source-separation components of ``s`` for a z-stack.

    Thin wrapper around ``plot_3d_hs_results`` that pulls the BSS loadings
    and factors out of ``s``.

    Parameters
    ----------
    s : hyperspy signal
        Signal expected to already hold BSS results.
    spec_im : spectral image
        Supplies the z/spectral axes and the map-plotting helper.
    title : str
        Prefix of every figure's suptitle.
    cmap : str
        Colormap, forwarded through ``plot_3d_hs_results`` to each map.
    fig_rows : int
        Number of component rows per figure.
    **kwargs
        Extra keyword arguments forwarded to ``plot_3d_hs_results``.

    Returns
    -------
    list
        The figures created by ``plot_3d_hs_results``.
    """
    bss_loadings = s.get_bss_loadings()
    bss_factors = s.get_bss_factors()
    return plot_3d_hs_results(bss_loadings, bss_factors, spec_im,
                              num_rows=fig_rows, title=title, cmap=cmap,
                              **kwargs)

def plot_3d_decomp_results(s, spec_im, title='', cmap='gray', fig_rows=5, **kwargs):
    """Plot the decomposition components of ``s`` for a z-stack.

    Thin wrapper around ``plot_3d_hs_results`` that pulls the decomposition
    loadings and factors out of ``s``.

    Parameters
    ----------
    s : hyperspy signal
        Signal expected to already hold decomposition results.
    spec_im : spectral image
        Supplies the z/spectral axes and the map-plotting helper.
    title : str
        Prefix of every figure's suptitle.
    cmap : str
        Colormap, forwarded through ``plot_3d_hs_results`` to each map.
    fig_rows : int
        Number of component rows per figure.
    **kwargs
        Extra keyword arguments forwarded to ``plot_3d_hs_results``.

    Returns
    -------
    list
        The figures created by ``plot_3d_hs_results``.
    """
    decomp_loadings = s.get_decomposition_loadings()
    decomp_factors = s.get_decomposition_factors()
    return plot_3d_hs_results(decomp_loadings, decomp_factors, spec_im,
                              num_rows=fig_rows, title=title, cmap=cmap,
                              **kwargs)

def plot_3d_hs_results(loadings, factors, spec_im, num_rows=6, title='',
                       **kwargs):
    """Plot HyperSpy component loadings and factors for a z-stack image.

    Each component gets one row. The row shows its loading map at every z
    slice of ``spec_im``, followed by its spectral factor. Components are
    paginated into figures of at most ``num_rows`` rows.

    Parameters
    ----------
    loadings, factors : hyperspy signals
        Component loadings and factors. Each is ``split()`` into one signal
        per component, and each loading's ``.data`` is indexed as
        ``[z, :, :]``.
    spec_im : spectral image
        Supplies ``z_array``, ``spec_x_array``, ``spec_units``,
        ``get_unit_scaling()`` and the ``_plot`` helper used for each map.
    num_rows : int
        Maximum number of component rows per figure.
    title : str
        Prefix of every figure's suptitle.
    **kwargs
        Forwarded to ``spec_im._plot`` for every loading map.

    Returns
    -------
    list of matplotlib.figure.Figure
        One figure per page, in page order.
    """
    units, scaling = spec_im.get_unit_scaling()
    # one signal per component
    loading_list = loadings.split()
    factor_list = factors.split()
    no_of_loadings = len(loading_list)

    nz = len(spec_im.z_array)
    # Each row has nz loading maps, one unused spacer column, and then a
    # spectrum spanning the last two columns.
    num_cols = 3 + nz
    # ceiling division: a partially filled last page still needs a figure
    no_of_figs = -(-no_of_loadings // num_rows)

    fig_list = []
    for jj in range(no_of_figs):
        f = plt.figure()
        l0 = jj*num_rows  # index of the first component on this page
        for ll in range(min(num_rows, no_of_loadings - l0)):
            lx = l0 + ll
            row_start = ll*num_cols  # subplot-index offset of this row
            for ii in range(nz):
                ax = plt.subplot(num_rows, num_cols, row_start + 1 + ii)
                spec_im._plot(loading_list[lx].data[ii, :, :], **kwargs)
                if ll == 0:
                    # only the top row is labelled with z (relative to slice 0)
                    zval = spec_im.z_array[ii] - spec_im.z_array[0]
                    ax.set_title('z = %0.1f %s' % (zval*scaling, units))

            plt.subplot(num_rows, num_cols,
                        (row_start + nz + 2, row_start + nz + 3))
            plt.plot(spec_im.spec_x_array, factor_list[lx].data)
            # tighten every spectrum axes, not only the last one on the page
            plt.axis('tight')
            plt.title('%d' % lx)
            plt.gca().ticklabel_format(axis='y', style='sci', scilimits=(0, 0))
            plt.xlabel(spec_im.spec_units)

        plt.suptitle('%s no %d of %d' % (title, jj+1, no_of_figs))
        f.set_size_inches(10, 10)
        plt.tight_layout(h_pad=1)
        # leave room above the subplots for the suptitle
        plt.subplots_adjust(top=0.92)
        fig_list.append(f)

    return fig_list

# Load file

In [3]:
# Pick the data file, either interactively (gui_fname) or hard-coded below.
#fname = gui_fname()
fname = 'D:/Chris/BBOD_Share/uv_microscope/190430 confocal GaN pyramid uvpl/190427_082002_oo_asi_hyperspec_3d_scan.h5'

In [4]:
# Derived path pieces. fbase (the path without its extension) names the PDF
# export further down; fpath and sample are not used in the cells shown here.
fpath = os.path.dirname(fname)
fbase = os.path.splitext(fname)[0]
sample = os.path.basename(fname)

In [5]:
# Load the spectral image from the .h5 file.
si = PLSpectralImage(fname=fname)

Load from D:/Chris/BBOD_Share/uv_microscope/190430 confocal GaN pyramid uvpl/190427_082002_oo_asi_hyperspec_3d_scan.h5 complete.
9 x 102 x 101 spatial x 1044 spectral points


## Cropping the spectrum

In [None]:
# Full spectrum before any cropping.
si.plot_spec()

In [None]:
# First crop of the spectral range. si is reassigned, so re-running this cell
# crops again. NOTE(review): whether the bounds are indices or spectral values
# depends on PLSpectralImage.__getitem__ -- confirm.
plt.figure()
si = si[200:950]
si.plot_spec()

In [None]:
# Set the background from the lims=(700, 900) range (presumably the same units
# as the slice bounds, and presumably subtracted from the data -- confirm in
# PLSpectralImage.set_background).
plt.figure()
si.set_background(lims=(700, 900))
si.plot_spec()

In [None]:
# Second, tighter crop of the already-cropped image. si is reassigned, so
# re-running this cell crops again.
si = si[354:500]
plt.figure()
si.plot_spec()

## Convert to eV

In [None]:
# Energy-axis version of the cropped image; most of the analysis below uses esi.
esi = si.to_energy()

# Visualize file

In [None]:
# Summary plots of the cropped image si (the trailing ';' suppresses the cell output).
plot_pl_summary(si, num_rows=4, show_axes=False, show_scalebar=True, scalebar_alpha=0.5);

In [None]:
# Keep only the first 7 z slices of esi. z_array and the first (z) axis of
# spec_im are sliced together so they stay consistent. si keeps all of its
# z slices.
esi.z_array = esi.z_array[0:7]
esi.spec_im = esi.spec_im[0:7,:,:,:]

In [None]:
plot_pl_summary(esi, num_rows=4, show_axes=False, show_scalebar=True, scalebar_alpha=0.5);

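The cell below fixes conflicts between `pyplot.suptitle` and `pyplot.axis('tight')`. It lowers the top of the current figure's subplot area so the suptitle no longer overlaps the plots.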

In [None]:
# Lower the top of the current figure's subplot area to make room for its suptitle.
plt.subplots_adjust(top=0.9)

## To Signal1D

In [None]:
# Convert the energy-axis image to a HyperSpy signal for the analysis below.
s = esi.to_signal()

In [None]:
s.plot(navigator_kwds=dict(cmap='viridis'))

In [None]:
# Interactive rectangular ROI on s. scrop is reassigned in the next cell.
roi = hs.roi.RectangularROI(left=0, right=100., top=0, bottom=100.)
scrop = roi.interactive(s)

In [None]:
# Interactive span ROI from 2 to 3.4 (presumably eV, since s comes from esi).
roi = hs.roi.SpanROI(left=2, right=3.4)
s.plot(navigator_kwds=dict(cmap='viridis'))
scrop = roi.interactive(s)

In [None]:
# HyperSpy's interactive spike-removal tool.
s.spikes_removal_tool()

## Signal2D
View the data as a Signal2D to inspect the image contributed by each individual spectral bin. This view is also used for the alignment below.

In [None]:
# Image view of s with axes 0 and 1 as the image (signal) axes; the remaining
# axes become navigation axes.
s2d = s.as_signal2D((0,1))

In [None]:
s2d.plot(cmap='viridis', navigator_kwds=dict(cmap='viridis'))

### Alignment

In [None]:
# Estimate the 2D shifts at index 95 of the first navigation axis. The result
# is only printed; the alignment below uses a hand-built shift array instead.
shifts = s2d.inav[95,:].estimate_shift2D()
print(shifts)

In [None]:
# Hand-built shifts for s2d.align2D: nf*nz rows, grouped in blocks of nf rows
# per z index kk (row nf*kk + jj). Every row in block kk is shifted by kk*val
# in its first component and 0 in its second.
# NOTE(review): this row order presumably matches s2d's flattened navigation
# order -- confirm against s2d.axes_manager.
val = -2
nf = np.size(esi.spec_x_array)
nz = np.size(esi.z_array)
stupid_shift = np.zeros((nf*nz, 2))
stupid_shift[:, 0] = np.repeat(np.arange(nz)*val, nf)

In [None]:
# Align the images with the hand-built shifts from the previous cell. The
# result is not reassigned, so this relies on align2D modifying s2d in place.
s2d.align2D(shifts=stupid_shift)

In [None]:
s2d.plot(cmap='viridis', navigator_kwds=dict(cmap='viridis'))

## Back to Signal1D
Convert the aligned maps back to a Signal1D so they can be used for decomposition.

In [None]:
# NOTE(review): this s1d is overwritten by the transpose in the next cell.
s1d = s2d.as_signal1D(3)

In [None]:
# Make axes 2, 3 and 1 the navigation axes; the remaining axis presumably
# becomes the signal axis.
s1d = s2d.transpose(navigation_axes=[2,3,1])

In [None]:
s1d.plot()

In [None]:
# Use the aligned signal for the decomposition below.
s = s1d

# Hyperspy decomposition

In [None]:
# perform principal component analysis, look at the explained variance
s.decomposition(algorithm='svd')
s.plot_explained_variance_ratio()

In [None]:
s.plot_decomposition_results()

In [None]:
# Number of components used for blind source separation.
COMPS = 8

In [None]:
s.blind_source_separation(number_of_components=COMPS)
s.plot_bss_results()

In [None]:
# Reconstruct s from hand-picked decomposition components: 0-7 plus 22 and 23.
# NOTE(review): this tuple is not tied to COMPS; update it by hand if COMPS changes.
sc = s.get_decomposition_model(components=(0,1,2,3,4,5,6,7,22,23))
sc.plot(navigator_kwds={'cmap': 'viridis'})

In [None]:
# Wrap the model sc in a copy of esi so the PL summary plot can be used.
# np.squeeze drops any length-1 axes from the copied array.
sc_si = esi.copy(signal=sc)
sc_si.spec_im = np.squeeze(sc_si.spec_im)
plot_pl_summary(sc_si, num_rows=6);

In [None]:
# Residual between the data and the reconstructed model.
(s-sc).plot()

In [None]:
# 6-component NMF of the reconstructed model sc.
sc.decomposition(algorithm='nmf', output_dimension=6)

In [None]:
sc.plot_decomposition_results()

In [None]:
# Paginated z-stack plot of the NMF loadings and factors. The cbar_* and
# show_axes options are forwarded to esi._plot.
plot_3d_decomp_results(sc, esi, cmap='viridis', fig_rows=4, cbar_orientation='horizontal',
                       cbar_position='bottom', show_axes=False, title='NMF')

In [None]:
# Image view of the model with axes 0 and 1 as the image axes, as for s2d above.
sc2d = sc.as_signal2D((0,1))

In [None]:
sc2d.plot(cmap='viridis',navigator_kwds={'cmap': 'viridis'})

In [None]:
# Explained-variance plot for the image view.
sc2d.decomposition(algorithm='svd')
sc2d.plot_explained_variance_ratio()

In [None]:
# 25-component NMF of the image view.
sc2d.decomposition(algorithm='nmf', output_dimension=25)
sc2d.plot_decomposition_results()

In [None]:
# Alternative image view using axes 2 and 3 as the signal axes.
sc2d_spat = sc.as_signal2D((2,3))

In [None]:
sc2d_spat.plot()

In [None]:
sc2d_spat.decomposition(algorithm='svd')
sc2d_spat.plot_explained_variance_ratio()

In [None]:
sc2d_spat.decomposition(algorithm='nmf', output_dimension=20)

In [None]:
sc2d_spat.plot_decomposition_results()

In [None]:
# Decomposition results of sc itself.
sc.plot_decomposition_results()

In [None]:
import matplotlib.backends.backend_pdf
# Save every open figure as one PDF page, then close them all.
# plt.get_fignums() lists only figures that actually exist. That avoids
# creating an empty placeholder figure and skips numbers of closed figures.
# The `with` block closes the PDF even if saving a page fails.
with matplotlib.backends.backend_pdf.PdfPages(fbase + "_decomp.pdf") as pdf:
    for fig_num in plt.get_fignums():
        pdf.savefig(fig_num)
plt.close(fig='all')

## A1. Native (sklearn-based) decomposition

In [None]:
# NOTE(review): a fractional output_dimension presumably means "keep enough
# components to explain this fraction of the variance"; confirm in
# PLSpectralImage.decomposition.
si.decomposition(algorithm='svd', output_dimension=0.9999)

In [None]:
si.plot_explained_variance_ratio()

In [None]:
si.blind_source_separation(number_of_components=10, max_iter=2000)

In [None]:
# Plot the native BSS results, num_rows components per figure. Each row shows
# the component's loading map at every z slice, then its spectral factor.
slice_list = si.get_slice_list()
loadings = si.get_bss_loadings()
factors = si.get_bss_factors()
num_factors = factors.shape[0]

num_rows = 5
num_figures = math.ceil(float(num_factors)/num_rows)
print(loadings.shape, factors.shape, si.spec_x_array.shape)
(nz, ny, nx, nf) = np.shape(si.spec_im)
num_cols = 3 + nz

for jj in range(num_figures):
    plt.figure()
    # the last page may hold fewer than num_rows components
    for kk in range(min(num_rows, num_factors - jj*num_rows)):
        index = kk + jj*num_rows
        for ii in range(nz):
            plt.subplot(num_rows, num_cols, num_cols*kk + 1 + ii)
            # Only the top row is labelled with z. The *1e3 factor assumes
            # z_array is in mm (label says um); TODO confirm. The raw string
            # keeps the '\mu' LaTeX escape without an invalid-escape warning.
            if kk == 0:
                z_pos = r'z = %0.2f $\mu$m' % ((si.z_array[ii] - si.z_array[0])*1e3)
            else:
                z_pos = ''
            # a single scalebar, on the top-left map of each figure
            show_scalebar = (kk == 0 and ii == 0)
            si._plot(loadings[index, ii, :, :], cbar_orientation='vertical',
                     cbar_position='right', title=z_pos,
                     show_scalebar=show_scalebar)
        plt.subplot(num_rows, num_cols, (num_cols*kk + nz + 1, num_cols*kk + nz + 2))
        si._plot_spec(factors[index, :])
        plt.title('%d' % index)
        

In [None]:
# Inspect sub-range [330:360] of slice_list[1].
slice_list = si.get_slice_list()
plt.figure()
si.get_slice(slice_list[1])[330:360].plot()

In [None]:
# Build a mask from pixels whose summed signal over sub-range [330:360] of
# slice_list[2] falls below a hand-tuned threshold, then apply it to si.
# NOTE(review): the cell above inspects slice_list[1], but the mask is built
# from slice_list[2]; confirm which is intended. Whether True marks pixels
# to keep or to drop depends on PLSpectralImage.apply_mask.
mask = si.get_slice(slice_list[2])[330:360].spec_im.sum(axis=-1) < 7.95e4
print(mask.shape)
plt.figure()
plt.imshow(mask)
si.apply_mask(mask)

In [None]:
# Apparently an abandoned attempt to broadcast the mask over every z slice.
# NOTE(review): `kk in range(...)` below is a boolean test, not a loop, so it
# would not index per slice as written.
# mask2 = np.empty(np.shape(si.spec_im.sum(axis=-1)))
# mask2[kk in range(np.size(si.z_array)), :, :, :] = mask
# #si.apply_mask(mask2)
# print('mask size', np.size(mask2), 'nonzero values in mask', np.count_nonzero(mask2.flatten()), '', np.size(si.spec_im.sum(axis=-1)))

In [None]:
# Re-run the native decomposition on the masked image.
si.decomposition(algorithm='svd', output_dimension=0.9999)
si.plot_explained_variance_ratio()

In [None]:
si.blind_source_separation(number_of_components=10)

In [None]:
# Plot the BSS results of the masked image, num_rows components per figure.
# Each row shows the component's loading map at every z slice, then its
# spectral factor.
slice_list = si.get_slice_list()
loadings = si.get_bss_loadings()
factors = si.get_bss_factors()
num_factors = factors.shape[0]

num_rows = 5
num_figures = math.ceil(float(num_factors)/num_rows)
print(loadings.shape, factors.shape, si.spec_x_array.shape)
(nz, ny, nx, nf) = np.shape(si.spec_im)
num_cols = 3 + nz

for jj in range(num_figures):
    plt.figure()
    # the last page may hold fewer than num_rows components
    for kk in range(min(num_rows, num_factors - jj*num_rows)):
        index = kk + jj*num_rows
        for ii in range(nz):
            plt.subplot(num_rows, num_cols, num_cols*kk + 1 + ii)
            # Only the top row is labelled with z. The *1e3 factor assumes
            # z_array is in mm (label says um); TODO confirm. The raw string
            # keeps the '\mu' LaTeX escape without an invalid-escape warning.
            if kk == 0:
                z_pos = r'z = %0.2f $\mu$m' % ((si.z_array[ii] - si.z_array[0])*1e3)
            else:
                z_pos = ''
            # a single scalebar, on the top-left map of each figure
            show_scalebar = (kk == 0 and ii == 0)
            si._plot(loadings[index, ii, :, :], cbar_orientation='vertical',
                     cbar_position='right', title=z_pos,
                     show_scalebar=show_scalebar)
        plt.subplot(num_rows, num_cols, (num_cols*kk + nz + 1, num_cols*kk + nz + 2))
        si._plot_spec(factors[index, :])
        plt.title('%d' % index)
        

In [None]:
# NOTE(review): this rebuilds the model from s (the aligned HyperSpy signal),
# not from si, whose decomposition was recomputed just above; confirm intended.
sc = s.get_decomposition_model(components=10)
sc.plot()

In [None]:
# Leftover debug output.
print('asdf')

In [None]:
# Plot against esi, not si. sc comes from s, which was built from esi, and esi
# was truncated to 7 z slices. si still has all of its z slices, so using si
# would index past sc's z axis. esi also matches sc's energy axis and the
# earlier plot_3d_decomp_results(sc, esi, ...) call.
plot_3d_decomp_results(sc, esi, cmap='viridis', fig_rows=5)

In [None]:
# Model of s from its decomposition with components=4 (an int presumably
# selects the first 4 components).
sc = s.get_decomposition_model(components=4)
sc.plot()

In [None]:
# NOTE(review): Signal2D is not used in the cells shown here.
from hyperspy.signals import Signal2D

In [None]:
from matplotlib.cm import ScalarMappable
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.axes_grid1 import make_axes_locatable

def _voxel_edges(positions):
    """Return len(positions)+1 evenly spaced voxel edges, starting at 0.

    The spacing is the first pixel step, so the edges span the full range of
    ``positions`` plus one extra step.
    """
    step = positions[1] - positions[0]
    return np.linspace(0, positions[-1] - positions[0] + step,
                       num=len(positions) + 1)


def voxelplot(spec_im, arr, cmap='viridis', alpha=0.2, pmin=5, pmax=95, tval=35):
    """Render a 3D array as semi-transparent, colour-mapped voxels.

    Voxel geometry comes from the x/y/z arrays of ``spec_im``, scaled by
    ``spec_im.get_unit_scaling()`` and offset so the first pixel starts at 0.

    Parameters
    ----------
    spec_im : PLSpectralImage
        Source of the coordinate arrays and display units. Must have more
        than one z slice.
    arr : ndarray, shape (nz, ny, nx)
        Values to render, e.g. a spectrally summed image or a component
        loading.
    cmap : str or Colormap
        Colormap used for the voxel face colours.
    alpha : float
        Opacity applied to every face colour.
    pmin, pmax : float
        Percentiles of ``arr`` mapped to the bottom and top of the colormap.
    tval : float
        Percentile threshold. Only voxels strictly above it are drawn.
    """
    assert isinstance(spec_im, PLSpectralImage), 'spec_im must be a PLSpectralImage'
    assert len(spec_im.z_array) > 1, 'voxelplot needs more than one z slice'

    units, scaling = spec_im.get_unit_scaling()
    x_coords = _voxel_edges(spec_im.x_array*scaling)
    y_coords = _voxel_edges(spec_im.y_array*scaling)
    z_coords = _voxel_edges(spec_im.z_array*scaling)
    nz = len(spec_im.z_array)

    # corner grids of shape (nx+1, ny+1, nz+1), i.e. indexed [x, y, z]
    y_corners, x_corners, z_corners = np.meshgrid(y_coords, x_coords, z_coords)

    # colour scale clipped to the [pmin, pmax] percentiles of arr
    mapper = ScalarMappable(cmap=cmap)
    mapper.set_array(arr)
    mapper.set_clim(vmin=np.percentile(arr, pmin), vmax=np.percentile(arr, pmax))
    tmin = np.percentile(arr, tval)

    # RGBA face colours, one 2-D z layer at a time (to_rgba interprets 3-D
    # input as an RGB(A) image)
    vol = np.zeros(arr.shape + (4,))
    for kk in range(nz):
        vol[kk, :, :, :] = mapper.to_rgba(arr[kk, :, :], alpha=alpha, bytes=False)
    filled = arr > tmin
    # all-zero RGBA: fully transparent voxel edges
    ec = np.zeros(arr.shape + (4,))

    fig = plt.figure()
    # Figure.gca() no longer accepts projection kwargs (removed in Matplotlib 3.6)
    ax = fig.add_subplot(projection='3d')
    # voxels() wants [x, y, z] ordering, while arr is [z, y, x]
    ax.voxels(x_corners, y_corners, z_corners, np.swapaxes(filled, 0, 2),
              facecolors=np.swapaxes(vol, 0, 2), edgecolors=np.swapaxes(ec, 0, 2))
    ax.set_xlabel(units)
    ax.set_ylabel(units)
    ax.set_zlabel(units)

In [None]:
# Voxel plot of the spectrally summed intensity (sum over the last axis).
# Only voxels above the 80th percentile are drawn.
voxelplot(si, si.spec_im.sum(axis=-1), tval=80)

In [None]:
# Voxel plot of the decomposition loading at index 2 of sc.
# NOTE(review): sc derives from s/esi (7 z slices after truncation), but the
# voxel geometry comes from si, which keeps all of its z slices. The grid may
# not match the loading's shape; confirm (esi would match, provided it passes
# voxelplot's PLSpectralImage check).
voxelplot(si, list(sc.get_decomposition_loadings())[2].data, alpha=0.1, pmin=75, pmax=95, tval=90)

In [None]:
# Leftover debug output.
print('asdf')