# RGB correlation maps
This notebook creates RGB correlation maps from the `.h5` master file (motor positions and metadata), the `_pilatus_integrated.h5` file (XRD data), and the `fitted_elements_scan_XXX.h5` or `fitted_elements_scans_xxx_to_xxx.h5` file (XRF data).

The user defines one or more sets of three maps, based on XRD regions of interest (ROIs) and XRF element fits.
Each set is combined into an RGB color map, with the three maps used as the red, green and blue channels.

It loads and treats the XRD data in the same way as the `XRD_maps` script.
For *XRF* data, it assumes that `mapping_XRF_fitting` has been run to create XRF fit maps.
It is therefore important that the `scans` parameter is set to the same value here as in `mapping_XRF_fitting`.



In [None]:
#Import necessary libraries
# Interactive matplotlib widget backend for Jupyter (IPython magic, not plain Python)
%matplotlib widget
import os
import sys
import h5py
import numpy as np
import matplotlib.pyplot as plt

#To import DanMAX from the folder above:
sys.path.append('../')
import DanMAX as DM
# Apply the DanMAX plotting style (darkMode) with a 'large' default figure size
style = DM.darkMode(style_dic={'figure.figsize':'large'})

In [None]:
# ---------------------------------------------------------------------------
# Parameters - edit this cell
# ---------------------------------------------------------------------------

# Data location. DM.getCurrentProposal() derives proposal and visit from the
# current working path; overwrite them to use data from a previous visit.
proposal, visit = DM.getCurrentProposal()

# Scans are selected by group and sample name. `scans` must be a list and must
# be the same as the value used in mapping_XRF_fitting.
groups = DM.getProposalScans(proposal=proposal, visit=visit)
group, sample = 'group_name', 'sample_name'
scans = groups[group][sample]

# Ranges to load, in the unit the data were integrated with.
# Useful for reducing the size of large datasets.
xrd_range, azi_range = None, None

# XRD regions of interest: label -> (lower, upper) limit, given in the
# scattering unit of the integrated data (2theta or Q).
xrd_rois = {
    '002':       (8., 8.5),
    'multiplet': (9.9, 10.4),
    '310':       (11.8, 13.1),
    'peak_2':    (1.95, 2.15),
}

# XRF element maps to load: fit name -> [lower, upper] color-scale limit used
# when the element appears in a correlation map.
xrf_maps = {
    'Ca_K': [20, 700],
    'Sr_K': [20, 220],
    'Zn_K': [10, 100],
}

# Correlation maps: one entry per RGB image. Each entry lists the
# [source, key] pair for the red, green and blue channel, in that order.
# source is 'XRD' (key from xrd_rois) or 'XRF' (key from xrf_maps).
correlation_maps = [
    # red              green                 blue
    [['XRF', 'Ca_K'], ['XRF', 'Sr_K'],      ['XRF', 'Zn_K']],
    [['XRD', '002'],  ['XRD', 'multiplet'], ['XRD', '310']],
]

# Group inside the XRF fit file that holds the fitted element maps
xrf_h5_fit_path = 'xrf_fits/xrf_fit/results/parameters/'

In [None]:
#Do not change this code
# Load and stitch the XRD data of all scans. This is the "slow" step, which is
# why it is isolated in its own cell.
maps = DM.mapping.stitchScans(
    scans,
    XRF=False,
    proposal=proposal,
    visit=visit,
    xrd_range=xrd_range,
    azi_range=azi_range,
)

# I0 correction: divide every pattern (last axis of xrd_map) by the I0 value of
# its map position. Adding a trailing axis broadcasts I0 along the pattern axis
# (assumes I0_map has the map's (rows, cols) shape - TODO confirm).
maps['xrd_map'] = maps['xrd_map'] / np.asarray(maps['I0_map'])[..., np.newaxis]

# Average pattern over all map positions
xrd_avg = maps['xrd_map'].mean(axis=(0, 1))

In [None]:
#Calculating Maps Do not change this code!
#This is separated from the above cell to allow rerunning if the map settings are changed

# np.trapz was renamed np.trapezoid in NumPy 2.0, where trapz is deprecated
_trapezoid = getattr(np, 'trapezoid', None) or np.trapz


def _roi_intensity_map(xrd_map, x, roi, n_bkg=5):
    """Integrate one ROI of every pattern in the map after a flat background subtraction.

    Parameters
    ----------
    xrd_map : (rows, cols, n_x) array of patterns, scattering axis last
    x : (n_x,) monotonic array of scattering values (2theta or Q)
    roi : (lower, upper) limits in the unit of `x`
    n_bkg : number of points on each side of the ROI averaged for the background

    Returns
    -------
    (rows, cols) array of integrated ROI intensities, shifted so that the
    smallest (non-NaN) value is 0.
    """
    # Index of the first point above each limit; the ROI is x[lower:upper]
    lower, upper = np.digitize(roi, x)

    # Background windows: n_bkg points directly below the ROI and n_bkg points
    # starting one point above its end. The lower start is clamped at 0: a
    # negative start index would select the wrong (usually empty) range and
    # turn the whole background into NaN for ROIs near the start of the data.
    below = xrd_map[:, :, max(lower - n_bkg, 0):lower]
    above = xrd_map[:, :, upper + 1:upper + 1 + n_bkg]

    # Average the per-side means of the windows that contain data (identical to
    # the mean of both sides when both are available). If neither side has
    # data, no background is subtracted.
    side_means = [np.mean(window, axis=2) for window in (below, above) if window.shape[2] > 0]
    if side_means:
        bkg = np.mean(side_means, axis=0)
    else:
        bkg = np.zeros(xrd_map.shape[:2])

    roi_slice = slice(lower, upper)
    signal = xrd_map[:, :, roi_slice] - bkg[:, :, np.newaxis]
    intensity = _trapezoid(signal, x[roi_slice], axis=-1)
    # Shift so that the weakest map position is 0 (NaN positions are ignored)
    return intensity - np.nanmin(intensity)


#Create a dictionary of background-subtracted ROI maps for plotting
I_xrd = {peak: _roi_intensity_map(maps['xrd_map'], maps['x_xrd'], roi)
         for peak, roi in xrd_rois.items()}

#Loading XRF data
#Get the XRF fit file, based on the scans.
# ----   REQUIRES mapping_XRF_fitting TO HAVE BEEN RUN ---
xrf_fit_dir, xrf_fit_file = DM.mapping.getXRFFitFilename(scans,proposal=proposal,visit=visit)
fits_filename = f'{xrf_fit_dir}/elements/{xrf_fit_file}'

#Collect the fitted XRF maps for every element listed in xrf_maps
xrf = dict()
with h5py.File(fits_filename,'r') as fit_file:
    for elem in xrf_maps:
        dset_path = f'{xrf_h5_fit_path}{elem}'
        if dset_path not in fit_file:
            raise KeyError(f'{dset_path!r} not found in {fits_filename}; check xrf_h5_fit_path '
                           f'and that {elem!r} was fitted in mapping_XRF_fitting')
        xrf[elem] = fit_file[dset_path][:]

## Plot average XRD pattern

In [None]:
# plot the average XRD pattern and mark the limits of every ROI in xrd_rois
plt.figure()
plt.title('Average XRD pattern')
plt.plot(maps['x_xrd'],xrd_avg,label='average')
# maps['Q'] selects the axis label for the integration unit
plt.xlabel('Q (A-1)' if maps['Q'] else '2theta (deg)')
plt.ylabel('Intensity')
plt.yscale('log')

#Add dashed vertical lines at the ROI limits and highlight the points inside each ROI
x_xrd = maps['x_xrd']
for key, (lower, upper) in xrd_rois.items():
    in_roi = (x_xrd > lower) & (x_xrd < upper)
    plt.axvline(lower,c='k',ls='--',lw=1)
    plt.axvline(upper,c='k',ls='--',lw=1)
    plt.plot(x_xrd[in_roi],xrd_avg[in_roi], '.',ms=2,label=key)
plt.legend()

## Plot correlation maps
The following cell creates and plots the correlation maps.

In [None]:
# Plot settings
# scale_place = [[row_start, row_stop], [col_start, col_stop]]: pixel block painted
# white in every RGB image (presumably a scale bar)
scale_place = [[10,15],[10,49]]
cols = 2                # number of subplot columns
save_figures = True     # also save each RGB image as a PNG
xrd_max_fraction = 3/4  # XRD channels saturate at this fraction of the ROI map maximum


def _scale_channel(data, cmin=None, cmax=None):
    """Clip a map to [cmin, cmax] and rescale it to the range 0-1.

    Parameters
    ----------
    data : array-like map; it is never modified (a float copy is used, which
        also avoids casting errors for integer maps)
    cmin, cmax : color limits; None falls back to the NaN-ignoring minimum or
        maximum of `data`

    Returns
    -------
    Float array of the same shape. After clipping and subtracting `cmin`,
    the map is divided by its own (NaN-ignoring) maximum, so the brightest
    pixel is 1. A map with zero range stays 0 instead of becoming NaN (0/0),
    and NaN pixels are returned as 0 (black in that channel).
    """
    scaled = np.array(data, dtype=float)
    if cmin is None:
        cmin = np.nanmin(scaled)
    if cmax is None:
        cmax = np.nanmax(scaled)
    scaled = np.clip(scaled, cmin, cmax) - cmin
    peak = np.nanmax(scaled)
    if peak != 0:
        scaled /= peak
    return np.nan_to_num(scaled, nan=0.0)


n_maps = len(correlation_maps)
rows = -(-n_maps // cols)  # ceiling division: enough rows for all maps
# initialize subplots with shared x- and y-axes. squeeze=False always returns a 2D
# array of axes, so flatten() also works when there is only a single subplot.
fig, axs = plt.subplots(rows, cols, sharex=True, sharey=True, squeeze=False)
axs = axs.flatten()

if save_figures:
    # Save folder: the part of the first scan's path before "raw", then
    # process/rgb_correlation/<group>
    base_path = DM.findScan(scans[0])
    save_folder = f'{base_path.split("raw")[0]}/process/rgb_correlation/{group}'
    if not os.path.isdir(save_folder):
        os.makedirs(save_folder)
        os.chmod(save_folder,0o770)

channel_colors = ('red', 'green', 'blue')
label_positions = (0, 0.4, 0.8)  # x position (axes fraction) of each channel label

for ax, channels in zip(axs, correlation_maps):
    #Define the RGB map matrix on the same grid as maps['x_map']
    rgb_map = np.zeros(maps['x_map'].shape + (3,))

    #Populate the red, green and blue channels
    for c in range(3):
        map_type, map_key = channels[c][0], channels[c][1]
        if map_type == 'XRF':
            cmin, cmax = xrf_maps[map_key][0], xrf_maps[map_key][1]
            rgb_map[:, :, c] = _scale_channel(xrf[map_key], cmin, cmax)
        elif map_type == 'XRD':
            roi_map = I_xrd[map_key]
            rgb_map[:, :, c] = _scale_channel(roi_map, np.nanmin(roi_map),
                                              np.nanmax(roi_map) * xrd_max_fraction)
        else:
            raise ValueError(f"Unknown map type {map_type!r} for {map_key!r}: use 'XRF' or 'XRD'")

    # Paint the scale_place block white in all three channels
    (row_start, row_stop), (col_start, col_stop) = scale_place
    rgb_map[row_start:row_stop, col_start:col_stop, :] = 1

    # plot the map as an image
    # NOTE(review): the horizontal extent is taken from y_map and the vertical one
    # from x_map, while the axes are labelled 'x mm' / 'y mm' - confirm orientation
    ax.imshow(rgb_map, extent=[np.min(maps['y_map']), np.max(maps['y_map']),
                               np.min(maps['x_map']), np.max(maps['x_map'])])
    #Annotate each channel with its map key, written in the channel's color
    for c, (color, xpos) in enumerate(zip(channel_colors, label_positions)):
        ax.annotate(f'{channels[c][1]}',
                    (xpos, 1.01),
                    xycoords='axes fraction',
                    color=color)
    ax.set_aspect('equal')
    ax.set_xlabel('x mm')
    ax.set_ylabel('y mm')

    if save_figures:
        # File name: <group>_<sample>_<red key>_<green key>_<blue key>.png
        keys = '_'.join(f'{channels[c][1]}' for c in range(3))
        save_file = f'{save_folder}/{group}_{sample}_{keys}.png'
        plt.imsave(save_file, rgb_map)
        os.chmod(save_file, 0o770)

# Lay out after all panels are drawn so the labels and annotations are taken into account
fig.tight_layout()

In [None]:
# Scratch check: display the folder the correlation maps are saved to.
# Double quotes are used inside the f-string; nesting the same quote type is a
# SyntaxError before Python 3.12, and the original string was also unterminated.
base_path = DM.findScan(scans[0])
f'{base_path.split("raw")[0]}/process/rgb_correlation/{group}'

In [None]:
# Scratch check: build the save-folder path from the first scan's location
# (the part of the path before "raw", then process/rgb_correlation/<group>)
base_path = DM.findScan(scans[0])
save_folder = f'{base_path.split("raw")[0]}/process/rgb_correlation/{group}'

In [None]:
# Scratch inspection: shape of maps['cake_map'] as returned by DM.mapping.stitchScans
maps['cake_map'].shape

In [None]:
# Scratch inspection: show the help for DM.mapping.stitchScans (the trailing `?` is IPython syntax)
DM.mapping.stitchScans?

In [None]:
# Scratch inspection: total number of map positions (product of the maps['x_map'] dimensions)
np.prod(maps['x_map'].shape)

In [None]:
# Scratch inspection: shape of the map grid (maps['x_map'])
maps['x_map'].shape

In [None]:
# NOTE(review): unexplained scratch calculation (37 * 0.32); document what it represents or remove
37*0.32