# Preprocessing of linebeam 3DXRD data from ID11
This notebook performs all preprocessing steps up to (but not including) grain indexing:
* Create sparse representation of the data
* Label the spots in the sparse data

This code is largely based on notebooks written by Haixing Fang and Indrajeet Tambe.

In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt
%matplotlib widget

### Set parameters here

In [None]:
# Sample name; the datasets and outputs are stored under <analysisroot>/<sample>
sample='SiMo1000_6_v2'
# Scans to process; each one gets its own DataSet in the segmentation loop below
scan_names = ['05_0N_3DXRD']
# Location of the raw data (machine-specific path - edit for your setup)
dataroot = '/Users/al8720/Library/CloudStorage/OneDrive-Malmöuniversitet/projects/castIron/ESRF22/analysis_23/tdxrd/linebeam/rawdata'
# Root folder for all analysis output (machine-specific path - edit for your setup)
analysisroot ='/Users/al8720/Library/CloudStorage/OneDrive-Malmöuniversitet/projects/castIron/ESRF22/analysis_23/tdxrd/linebeam'

In [None]:
# Export the SLURM_CPUS_PER_TASK environment variable (value 1) for this kernel.
# NOTE(review): presumably read by ImageD11 to decide how many CPUs to use when
# not running under SLURM - confirm against the ImageD11 version in use.
%env SLURM_CPUS_PER_TASK=1

#### Functions

In [None]:
import fabio
import ImageD11.sinograms.dataset
import ImageD11.sinograms.properties
import ImageD11.sinograms.lima_segmenter
import ImageD11.sinograms.assemble_label
import h5py, hdf5plugin
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib.colors

def setup_ds(scan_name, dataroot, analysisroot, sample,
             detector='frelon3', omegamotor='diffrz', dtymotor='diffty'):
    """Create an ImageD11 DataSet for one scan, populate it and save it.

    Parameters
    ----------
    scan_name : str
        Name of the scan (dataset) to load.
    dataroot, analysisroot : str
        Raw-data root and analysis-output root passed to ``DataSet``.
    sample : str
        Sample name passed to ``DataSet``.
    detector, omegamotor, dtymotor : str
        Detector and motor names passed to ``DataSet``.

    Returns
    -------
    (ds, outname)
        The populated DataSet and the bare file name ``ds_<scan_name>.h5``.
        The file is saved inside ``<ds.analysisroot>/<ds.sample>``, and that
        directory is created if needed.
    """
    dataset = ImageD11.sinograms.dataset.DataSet(
        dataroot, analysisroot, sample, scan_name,
        detector=detector, omegamotor=omegamotor, dtymotor=dtymotor,
    )
    # Fill in the scan list, image files and motor positions.
    dataset.import_scans()
    dataset.import_imagefiles()
    dataset.import_motors_from_master()
    print(dataset)

    target_dir = os.path.join(dataset.analysisroot, dataset.sample)
    os.makedirs(target_dir, exist_ok=True)
    print(target_dir)

    filename = f'ds_{scan_name}.h5'
    dataset.save(os.path.join(target_dir, filename))
    return dataset, filename

def make_bg(ds, detector='frelon3', every=20, ngroups=5):
    """Estimate a background image and write background and mask EDF files.

    Every ``every``-th frame of each scan in ``ds.scans`` is read from the
    master file. The sampled frames are split into ``ngroups`` interleaved
    subsets (frames ``i, i + ngroups, i + 2*ngroups, ...``). The background is
    the mean over subsets of each subset's per-pixel minimum.

    Parameters
    ----------
    ds : ImageD11.sinograms.dataset.DataSet
        Dataset providing ``datapath``, ``masterfile``, ``scans``,
        ``analysisroot``, ``sample`` and ``dset``. The output directory
        ``<analysisroot>/<sample>`` is assumed to exist (``setup_ds``
        creates it).
    detector : str
        Name of the detector dataset under ``<scan>/measurement/``. Only
        ``'frelon3'`` currently has a mask defined.
    every : int
        Stride between the frames used for the estimate (default 20).
    ngroups : int
        Number of interleaved subsets whose minima are averaged (default 5).
        It is reduced to the number of sampled frames when fewer are
        available.

    Returns
    -------
    bkgim, maskim : str
        Paths of the written background (float32) and mask (uint8) files.
    frm : numpy.ndarray
        The first sampled frame minus the background, for checking the
        segmentation.

    Raises
    ------
    NotImplementedError
        If no mask is defined for ``detector``.
    ValueError
        If no frames could be sampled.
    """
    # Validate before the (slow) file reading. The original code raised a str
    # here, which is itself a TypeError in Python 3.
    if detector != 'frelon3':
        raise NotImplementedError(f'Mask not implemented for detector {detector}')

    imgs = []  # selected frames from all scans
    with h5py.File(os.path.join(ds.datapath, ds.masterfile), 'r') as hin:
        for scan in ds.scans:
            frames = hin[f'{scan}/measurement/{detector}']
            for i in range(0, frames.shape[0], every):
                imgs.append(frames[i])
    if not imgs:
        raise ValueError('No frames found for background estimation')
    imgs = np.asarray(imgs)

    # With fewer than ngroups frames some subsets would be empty and .min()
    # would fail, so never use more groups than there are frames.
    ngroups = min(ngroups, len(imgs))
    bg = np.mean([imgs[i::ngroups].min(axis=0) for i in range(ngroups)], axis=0)

    # All-ones mask; per the author's note this masks nothing.
    mask = np.ones(bg.shape, bool)

    outdir = os.path.join(ds.analysisroot, ds.sample)
    bkgim = os.path.join(outdir, f'{ds.dset}_bkg.edf')
    maskim = os.path.join(outdir, f'{ds.dset}_{detector}_mask.edf')
    fabio.edfimage.edfimage(bg.astype(np.float32)).write(bkgim)
    fabio.edfimage.edfimage(mask.astype(np.uint8)).write(maskim)

    frm = imgs[0] - bg  # background-subtracted frame for checking segmentation
    return bkgim, maskim, frm

def segment_spots(ds, outname, bkgim, maskim, frm,
                  cut_value=25,
                  check_segmentation=True,
                  run_all=False,
                  parallel=True):
    """Configure the lima segmenter for a dataset, optionally preview it and run it.

    Parameters
    ----------
    ds : DataSet
        Dataset returned by ``setup_ds``.
    outname : str
        File name of the saved dataset inside ``<analysisroot>/<sample>``.
    bkgim, maskim : str
        Paths of the background and mask EDF files (see ``make_bg``).
    frm : numpy.ndarray
        A background-subtracted frame used only for the preview plot.
    cut_value : number
        Stored as the segmenter's ``cut`` attribute. It is also the threshold
        used for the preview.
    check_segmentation : bool
        If True, segment ``frm`` and plot it next to the result.
    run_all : bool
        If True, run ``lima_segmenter.main`` on all data.
    parallel : bool
        Passed to ``lima_segmenter.main``. When False, ``files_per_core`` is
        set to the number of scans.
    """
    seg = ImageD11.sinograms.lima_segmenter
    workdir = os.path.join(ds.analysisroot, ds.sample)
    ds_path = os.path.join(workdir, outname)

    # Prepare the segmenter setup for this dataset file. The original comment
    # says this writes a SLURM script; the returned value is not used.
    seg.setup(ds_path)

    # Record background, mask and cut as attributes of the 'lima_segmenter' group.
    with h5py.File(ds_path, 'r+') as hout:
        attrs = hout['lima_segmenter'].attrs
        attrs['bgfile'] = bkgim
        attrs['maskfile'] = maskim
        attrs['cut'] = cut_value

    options = seg.SegmenterOptions()
    options.load(ds_path, 'lima_segmenter')
    options.jobid = 0
    mask = fabio.open(maskim).data
    options.mask = mask
    options.analysispath = os.path.join(workdir, f'{ds.sample}_{ds.dset}')
    if not parallel:
        options.files_per_core = len(ds.scans)
    # The segmenter module reads its settings from this module-level global.
    seg.OPTIONS = options
    print(options)

    if check_segmentation:
        to_sparse = seg.frmtosparse(mask, np.uint16)
        npx, row, col, val = to_sparse(frm, cut_value)  # segment a single frame
        # NOTE(review): the result of top_pixels is discarded, as in the original;
        # confirm whether it should feed into clean().
        seg.top_pixels(npx, row, col, val, options.howmany, options.thresholds)
        spf = seg.clean(npx, row, col, val)

        _, (ax_raw, ax_seg) = plt.subplots(1, 2, sharex=True, sharey=True)
        ax_raw.imshow(frm, norm=matplotlib.colors.LogNorm())
        ax_seg.imshow(spf.to_dense('intensity'), norm=matplotlib.colors.LogNorm())

    if run_all:
        seg.main(options, parallel=parallel)

def merge_sparse(ds, outname=None):
    """Assemble the segmented output of ``ds`` into a single sparse HDF5 file.

    Parameters
    ----------
    ds : DataSet
        Dataset whose segmentation output should be harvested.
    outname : str, optional
        Output path relative to ``ds.analysisroot``. When falsy, the default
        ``<analysisroot>/<sample>/<sample>_<dset>_sparse.h5`` is used.
        NOTE(review): a custom name is joined to ``analysisroot`` while the
        default goes into the ``<sample>`` sub-folder - intentional?

    Returns
    -------
    The value returned by ``assemble_label.harvest_masterfile``, which the
    notebook uses as the sparse file path.
    """
    if outname:
        target = os.path.join(ds.analysisroot, outname)
    else:
        target = os.path.join(ds.analysisroot, ds.sample,
                              f'{ds.sample}_{ds.dset}_sparse.h5')
    return ImageD11.sinograms.assemble_label.harvest_masterfile(ds, target)
    
def label_pixels(ds_file, sparse_file, pks_file=None):
    """Build a peak table for every scan of a dataset and save each to its own file.

    Parameters
    ----------
    ds_file : str
        Path of the saved dataset file (as written by ``setup_ds``).
    sparse_file : str
        Path of the sparse-pixel file (as returned by ``merge_sparse``).
    pks_file : str, optional
        Base name for the output tables. When falsy, it defaults to
        ``ds_file`` with the leading ``ds_`` of the *file name* replaced by
        ``pks_``. If the file name has no ``ds_`` prefix, ``pks_`` is
        prepended instead. Directory names are never altered.

    Returns
    -------
    str
        The base name ``pks_file``. That file itself is not written. One
        table is saved per scan as ``<stem>_<scan><ext>``, where ``stem`` and
        ``ext`` come from ``os.path.splitext(pks_file)``.
    """
    if not pks_file:
        # Only touch the file name: a plain str.replace on the full path
        # would also rewrite any directory that happens to contain 'ds_'.
        folder, name = os.path.split(ds_file)
        if name.startswith('ds_'):
            name = name[len('ds_'):]
        pks_file = os.path.join(folder, 'pks_' + name)

    # splitext gives a distinct name per scan for any extension. Replacing
    # '.h5' left non-.h5 names unchanged, so every scan overwrote one file.
    stem, ext = os.path.splitext(pks_file)

    ds = ImageD11.sinograms.dataset.load(ds_file)
    # Process each scan separately; `row` is the scan's index in ds.scans.
    for row, scan in enumerate(ds.scans):
        print(f'----- {scan} -----')
        pkst = ImageD11.sinograms.properties.pks_table_from_scan(sparse_file, ds, row)
        pkst.save(f'{stem}_{scan}{ext}')

    return pks_file


### Loop to segment all scans

In [None]:
# Automatically reload edited modules (e.g. ImageD11) before each cell executes
%load_ext autoreload
%autoreload 2
ds_all=[]  # full paths of the saved ds_*.h5 files, one per scan (needed for labeling)
sparse_all = []  # sparse file paths returned by merge_sparse, one per scan (needed for labeling)
for scan in scan_names:
    print(f'******** {scan} *********')
    # Create the dataset and save it as ds_<scan>.h5 under <analysisroot>/<sample>
    ds,outname=setup_ds(scan,dataroot,analysisroot,sample)
    ds_all.append(os.path.join(ds.analysisroot,ds.sample,outname)) # full path, needed for labeling
    # Estimate a background image and write the background/mask EDF files
    bkgim,maskim, frm = make_bg(ds)
    # Segment spots: cut_value=200 overrides the default of 25, the preview
    # plot is skipped, and the full segmentation is run. With parallel=False,
    # segment_spots sets files_per_core to the number of scans.
    segment_spots(ds,outname,bkgim,maskim,frm,
                  cut_value=200,
                  check_segmentation=False,
                  run_all=True,
                  parallel=False)
    # Collect the segmented output into a single sparse file
    sparse = merge_sparse(ds)
    sparse_all.append(sparse) # needed for labeling
    print(f'Saved sparse representation in {sparse}')

### Time to label the peaks

In [None]:
# Build and save per-scan peak tables from each dataset/sparse file pair.
# ds_all and sparse_all were filled in the same order as scan_names.
for (scan,dsf,sparsef) in zip(scan_names,ds_all,sparse_all):
    print(f'******** {scan} *********')
    label_pixels(dsf,sparsef)
    


In [None]:
# Sanity check: shape of the omega array of the last dataset (`ds`) processed
# in the segmentation loop.
np.asarray(ds.omega).shape