# Plot NMF maps from XRD flyscans

### Load data and plot average XRD for the map

In [None]:
%matplotlib widget
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from sklearn.decomposition import NMF
#To import DanMAX from the folder above:
sys.path.append('../')
import DanMAX as DM
style = DM.darkMode(style_dic={'figure.figsize':'small'})

In [None]:
# Define scan location:
# Note that scans must be a list of scan numbers!
scans = []
proposal,visit=DM.getCurrentProposal()

# Select ranges to load (in the unit the data were integrated with)
# useful for reducing the size of large datasets
xrd_range = ()
azi_range = None

# Fail early with a clear message: an empty or non-list `scans` would otherwise
# surface as an obscure error further down (several cells index scans[0])
if not isinstance(scans, list) or len(scans) == 0:
    raise ValueError('scans must be a non-empty list of scan numbers, e.g. scans = [1234]')

# Load the XRD mapping data (XRF data is not loaded)
maps = DM.mapping.stitchScans(scans,
                              XRF=False,
                              proposal=proposal,
                              visit=visit,
                              xrd_range=xrd_range,
                              azi_range=azi_range,
                             )
# Apply I0 correction: divide every pattern by the I0 value of its map pixel.
# Appending a trailing axis to I0 broadcasts it along the last (2theta) axis;
# this gives the same values as transposing, dividing and transposing back.
xrd_map = maps['xrd_map'] / maps['I0_map'][..., np.newaxis]

# ### ALTERNATIVE load XRD CT data ###
# fname = DM.findScan(scans[0])
# maps = DM.mapping.getXRDctMap(fname,xrd_range=xrd_range)
# xrd_map = maps['xrd_map']-maps['xrd_map'].min()
# ###

# (map dim 0, map dim 1, number of 2theta points)
map_shape = xrd_map.shape
flat_shape = (map_shape[0]*map_shape[1],map_shape[2])
# 2theta data
tth = maps['x_xrd']

### Create a mask
Use the interactive sliders to select a reasonable threshold for the per-pixel standard deviation. Only pixels above the lower bound are included in the mask.

In [None]:
# standard deviation of every pixel's pattern along the last (2theta) axis;
# it is thresholded in the next cell to build the pixel mask
im_std = xrd_map.std(axis=-1)
# interactive image + histogram; the selected range is read from widget.result
widget = DM.interactiveImageHist(im_std)
plt.gcf().set_figwidth(8)

In [None]:
# the widget reports the selected range as a whitespace-separated "lower upper"
# string; only the lower bound is used as the threshold (upper is unused here)
lower, upper = np.asarray(widget.result.split(), dtype=float)
# keep the pixels whose pattern spread exceeds the threshold
mask = np.greater(im_std, lower)
# mean and median pattern over the selected pixels
xrd_avg = xrd_map[mask].mean(axis=0)
xrd_med = np.median(xrd_map[mask], axis=0)

### Plot mean and median XRD patterns

In [None]:
# plot average XRD pattern
plt.figure()
plt.title(f'scan-{scans[0]}')
plt.plot(tth,xrd_avg,label='mean')
plt.plot(tth,xrd_med,label='median')
plt.xlabel('2theta (deg)')
plt.ylabel('Intensity')
plt.legend()

### Non-negative Matrix Factorization
Find two non-negative matrices *W* and *H* (i.e. matrices with no negative elements) whose product approximates the non-negative matrix *X*:  
$$W\times H \approx X$$
*X* has shape (*n*,*m*), *W* has shape (*n*,*o*), and *H* has shape (*o*,*m*).  

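If *X* is an XRD map reshaped to $(h\cdot w, n_{2\theta})$ (one diffraction pattern per pixel), *W* holds the contribution of each NMF component at every pixel, with shape $(h\cdot w, n_{components})$, and *H* holds the "diffraction pattern" of each component, with shape $(n_{components}, n_{2\theta})$. Only the masked pixels are passed to the model; pixels outside the mask get zero weight.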

In [None]:
###########################
# initialize the NMF model
n_components = 3
# seed for the random part of the 'nndsvdar' initialization;
# a fixed value makes repeated runs reproducible (None = different each run)
random_state = 0
###########################

model = NMF(n_components=n_components,  # number of components
            init='nndsvdar',            # NNDSVD, zeros filled with small random values
            max_iter=1000,              # maximum number of iterations
            random_state=random_state,  # seed for the initialization
           )

# NMF rejects negative input, but the I0-normalized data may contain negative
# values. Clip them to zero (a no-op for non-negative data) instead of failing.
xrd_masked = xrd_map[mask]
n_negative = np.count_nonzero(xrd_masked < 0)
if n_negative:
    print(f'Clipping {n_negative} negative intensities to zero before NMF')
    xrd_masked = np.clip(xrd_masked, 0, None)

# Learn an NMF model for the masked patterns and get the per-pixel weights.
# W: shape (map dim 0, map dim 1, n_components); unmasked pixels keep weight 0
W = np.zeros((*map_shape[0:2],n_components),dtype=xrd_map.dtype)
W[mask] = model.fit_transform(xrd_masked)
# H: pattern of each component, shape (n_components, number of 2theta points)
H = model.components_
# free the temporary copy of the masked data
del xrd_masked

In [None]:
# scale of each component: its mean weight over all map pixels
# (pixels outside the mask contribute a weight of zero)
weight = W.mean(axis=(0,1))
plt.figure()
plt.title(f'scan-{scans[0]}')
# stack the scaled component patterns: each trace is shifted up so that its
# baseline sits at 1.05x the peak of the previous (already shifted) trace
y_off = 0
for i, (w_i, h_i) in enumerate(zip(weight, H)):
    y = w_i*h_i + y_off
    y_off = 1.05*y.max()
    plt.plot(tth, y, label=f'#{i+1}')
plt.xlabel('2theta (deg)')
plt.ylabel('Intensity')
plt.legend()

### Plot NMF component maps

In [None]:
# Set the (maximum) number of columns for the figure; it is reduced
# automatically when there are fewer components than columns
cols = 3

n_cols = min(cols, n_components)
# ceiling division: just enough rows to hold every component
rows = (n_components + n_cols - 1) // n_cols
# initialize subplots with shared x- and y-axes.
# squeeze=False always returns a 2-D array of axes, so flattening also works
# for a 1x1 grid (e.g. cols=1 or n_components=1)
fig,axs = plt.subplots(rows,n_cols,sharex=True,sharey=True,squeeze=False)
axs = axs.flatten() # flatten the axes grid to make it easier to index

for i in range(n_components):
    # plot the component weight map as a pseudo colormesh
    ax = axs[i]
    ax.set_title(f'component #{i+1}')
    pcm = ax.pcolormesh(maps['x_map'],
                        maps['y_map'],
                        W[:,:,i],
                        shading='nearest',
                       )
    fig.colorbar(pcm,ax=ax,aspect=40,pad=0.05)
    ax.set_xlabel('x mm')
    ax.set_ylabel('y mm')
    # set the aspect ratio to equal to give square pixels
    ax.set_aspect('equal')

# delete the surplus (empty) axes at the end of the grid
for ax in axs[n_components:]:
    fig.delaxes(ax)

fig.tight_layout()

### Plot RGB map
Plot up to three components as an RGB overlay map

In [None]:
##################################################################

# Select up to three components, change the order of the indices
# to change the order of the colors (index start at zero)
RGB_components = [0,1,2]
# False: scale all channels by the largest weight among the selected components
# True:  scale each channel by its own maximum (each uses the full color range)
normalize_per_component = False

##################################################################

n_rgb = len(RGB_components)
if not 1 <= n_rgb <= 3:
    raise ValueError(f'RGB_components must contain 1 to 3 component indices, got {RGB_components}')

# weights of the selected components, one channel per component
selected = W[:,:,RGB_components]
if normalize_per_component:
    w_max = selected.max(axis=(0,1))
else:
    w_max = selected.max()
# avoid 0/0 -> NaN for selections (or components) that are zero everywhere
w_max = np.where(w_max > 0, w_max, 1)

# initialize array; channels without a selected component stay black
rgb = np.zeros((*map_shape[0:2],3),dtype=float)
# normalize to 0-1
rgb[:,:,:n_rgb] = selected / w_max

# plot; pcolormesh accepts an (M, N, 3) RGB array (matplotlib >= 3.8)
fig = plt.figure(layout='constrained')
plt.title(f'scan-{scans[0]}')
plt.pcolormesh(maps['x_map'],
               maps['y_map'],
               rgb,
               shading='nearest',
              )
plt.xlabel('x mm')
plt.ylabel('y mm')
# set the aspect ratio to equal to give square pixels
plt.gca().set_aspect('equal')

# legend patches in the exact channel colors: matplotlib's 'g' is a dark
# green (0, 0.5, 0), not the pure green channel shown in the image
channel_colors = [(1,0,0),(0,1,0),(0,0,1)]
legend_handles = [Patch(facecolor=color) for color in channel_colors[:n_rgb]]
legend_labels = [f'#{i+1}' for i in RGB_components]
fig.legend(legend_handles,legend_labels,loc='outside lower center',ncol=n_rgb)