In [None]:
# Show the current working directory; relative paths used below
# (e.g. the "label/..." global-mask file) are resolved against it.
!pwd

In [None]:
# Load paths for using psana.
# These environment variables are set before the next cell imports peaknet
# (which provides PsanaImg). NOTE(review): presumably psana reads them when it
# is imported or when a dataset is opened — confirm.
%env SIT_ROOT=/reg/g/psdm/
%env SIT_DATA=/cds/group/psdm/data/
%env SIT_PSDM_DATA=/cds/data/psdm/

In [None]:
import os
import torch
import random
import numpy as np
import h5py
import time

from peaknet.methods.unet       import UNet
from peaknet.model              import ConfigPeakFinderModel, PeakFinderModel
from peaknet.datasets.utils     import PsanaImg
from peaknet.datasets.transform import center_crop, coord_crop_to_img

from cupyx.scipy import ndimage
import cupy as cp

# Global random seed. It is applied to every RNG this notebook touches
# (Python, NumPy, PyTorch) so that stochastic steps are reproducible across
# kernel restarts — e.g. the random weights from model.init_params(), assuming
# it draws from torch's RNG. torch.manual_seed seeds all devices.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed);

In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors       as mcolors
import matplotlib.patches      as mpatches
import matplotlib.transforms   as mtransforms
%matplotlib inline

## Load psana for accessing image data

In [None]:
# Sample Rayonix dataset: experiment mfxp22820, run 13, calibrated images
# accessed by event index.
exp, run      = 'mfxp22820', 13
img_load_mode = 'calib'
access_mode   = 'idx'
detector_name = 'Rayonix'
photon_energy, encoder_value = 9.54e3, -196    # photon_energy in eV

psana_img = PsanaImg(exp, run, access_mode, detector_name)

In [None]:
# Sample Rayonix dataset: experiment mfx13016, run 28, calibrated images
# accessed by event index.
# Photon energy / encoder value are intentionally not set for this run; the
# values from the mfxp22820 cell are kept here commented out:
# photon_energy = 9.54e3    # eV
# encoder_value = -196
exp, run      = 'mfx13016', 28
img_load_mode = 'calib'
access_mode   = 'idx'
detector_name = 'Rayonix'

psana_img = PsanaImg(exp, run, access_mode, detector_name)

In [None]:
# Load the global detector mask (path is relative to the working directory).
path_mask_global = "label/global_mask.Rayonix.2023_0328_1117_28.v2.npy"
# Backward-compatible alias for the original, misspelled variable name.
path_mask_gloabl = path_mask_global
mask_global = np.load(path_mask_global)


## Load Model

In [None]:
# Checkpoint selection: run 2023_0329_1716_06, epoch 225.
timestamp = "2023_0329_1716_06"
epoch     = 225
fl_chkpt  = f"{timestamp}.epoch_{epoch}.chkpt" if timestamp is not None else None

In [None]:
# Checkpoint selection: run 2023_0329_1716_38, epoch 192.
timestamp = "2023_0329_1716_38"
epoch     = 192
fl_chkpt  = f"{timestamp}.epoch_{epoch}.chkpt" if timestamp is not None else None

In [None]:
# Checkpoint selection: run 2023_0330_1743_37, epoch 196.
timestamp = "2023_0330_1743_37"
epoch     = 196
fl_chkpt  = f"{timestamp}.epoch_{epoch}.chkpt" if timestamp is not None else None

In [None]:
# Checkpoint selection: run 2023_0331_1700_54, epoch 88.
timestamp = "2023_0331_1700_54"
epoch     = 88
fl_chkpt  = f"{timestamp}.epoch_{epoch}.chkpt" if timestamp is not None else None

In [None]:
# Checkpoint selection: run 2023_0402_2312_43, epoch 118.
timestamp = "2023_0402_2312_43"
epoch     = 118
fl_chkpt  = f"{timestamp}.epoch_{epoch}.chkpt" if timestamp is not None else None

In [None]:
# Checkpoint selection: run 2023_0403_1251_20, epoch 58.
timestamp = "2023_0403_1251_20"
epoch     = 58
fl_chkpt  = f"{timestamp}.epoch_{epoch}.chkpt" if timestamp is not None else None

In [None]:
# Checkpoint selection: run 2023_0403_1219_09, epoch 30.
timestamp = "2023_0403_1219_09"
epoch     = 30
fl_chkpt  = f"{timestamp}.epoch_{epoch}.chkpt" if timestamp is not None else None

In [None]:
# Checkpoint selection: run 2023_0403_1251_20, epoch 64.
timestamp = "2023_0403_1251_20"
epoch     = 64
fl_chkpt  = f"{timestamp}.epoch_{epoch}.chkpt" if timestamp is not None else None

In [None]:
# Checkpoint selection: run 2023_0403_1216_17, epoch 131.
timestamp = "2023_0403_1216_17"
epoch     = 131
fl_chkpt  = f"{timestamp}.epoch_{epoch}.chkpt" if timestamp is not None else None

In [None]:
# Checkpoint selection: run 2023_0403_1219_09, epoch 98.
timestamp = "2023_0403_1219_09"
epoch     = 98
fl_chkpt  = f"{timestamp}.epoch_{epoch}.chkpt" if timestamp is not None else None

In [None]:
# Checkpoint selection: run 2023_0403_1300_34 (lam = 10.0), epoch 120.
timestamp = "2023_0403_1300_34"
epoch     = 120
fl_chkpt  = f"{timestamp}.epoch_{epoch}.chkpt" if timestamp is not None else None

In [None]:
# Checkpoint selection: run 2023_0404_1133_38 (lam = 10.0), epoch 154.
timestamp = "2023_0404_1133_38"
epoch     = 154
fl_chkpt  = f"{timestamp}.epoch_{epoch}.chkpt" if timestamp is not None else None

In [None]:
# Build the peak finder: a UNet backbone (1 input channel -> 2 output maps)
# wrapped in PeakFinderModel. focal_alpha / focal_gamma / lam are passed
# through to ConfigPeakFinderModel (names suggest loss hyperparameters).
base_channels = 8
focal_alpha   = 1.2
focal_gamma   = 2.0
lam           = 10.0

method = UNet(in_channels=1, out_channels=2, base_channels=base_channels)
config_peakfinder = ConfigPeakFinderModel(
    method=method,
    focal_alpha=focal_alpha,
    focal_gamma=focal_gamma,
    lam=lam,
)
model = PeakFinderModel(config_peakfinder)
model.init_params()    # ..., load random weights

In [None]:
# Replace the random weights with the checkpoint file selected above (fl_chkpt).
# Older alternative: model.init_params(from_timestamp = timestamp)
model.init_params(fl_chkpt=fl_chkpt)

In [None]:
# Load model to gpus if available...
# NOTE: from here on `model` is the DataParallel-wrapped UNet (model.method),
# not the PeakFinderModel. The isinstance guard makes the cell safe to re-run:
# without it, a second run would look up `.method` on the DataParallel wrapper
# and raise AttributeError.
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
if not isinstance(model, torch.nn.DataParallel):
    model = torch.nn.DataParallel(model.method)
model = model.to(device)

## Define hooks (Optional)

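This section registers forward hooks for inspecting the underlying neural network: it captures the inputs/outputs of selected layers and prints the output shape of each layer.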

In [None]:
# Forward hooks that record what selected layers receive and produce:
#   preactivation_dict[tag][name] -> inputs passed to the layer's forward
#   activation_dict[tag][name]    -> the layer's output
activation_dict    = {}
preactivation_dict = {}

def get_activation(name, tag = ''):
    """Return a forward hook that stores a layer's input/output under (tag, name)."""
    preactivation_dict.setdefault(tag, {})
    activation_dict.setdefault(tag, {})

    def hook(module, inputs, outputs):
        preactivation_dict[tag][name] = inputs
        activation_dict[tag][name]    = outputs

    return hook

# Attach hooks to every ReLU and to any module whose name contains "final_conv".
for name, layer in model.named_modules():
    if isinstance(layer, torch.nn.ReLU):
        layer.register_forward_hook(get_activation(name, 'relu'))
    if "final_conv" in name:
        layer.register_forward_hook(get_activation(name, 'final_conv'))

In [None]:
# Check out the shape of the output in each layer...
# Prints the output shape of every non-ReLU module on each forward pass.
# The hook handles are stored so that re-running this cell first removes the
# hooks it registered last time instead of stacking duplicates.
class NonReLUShapePrinter:
    """Forward hook: print a module's class name and output shape (skips ReLU)."""

    def __call__(self, module, input, output):
        if isinstance(module, torch.nn.ReLU):
            return
        # Not every module is guaranteed to return a tensor (e.g. tuple/list);
        # report the type name instead of raising inside forward().
        shape = getattr(output, 'shape', type(output).__name__)
        print(f"{module.__class__.__name__} output shape: {shape}")

# Remove shape printers left over from a previous run of this cell, if any.
for handle in globals().get('shape_printer_handles', []):
    handle.remove()

# Register the shape printer on each layer
shape_printer_handles = []
for name, module in model.named_modules():
    shape_printer_handles.append(module.register_forward_hook(NonReLUShapePrinter()))

In [None]:
# `img` is created by the inference cells further down; on a fresh kernel it
# does not exist yet, so only report its shape once it has been defined.
img.shape if 'img' in globals() else None

### Example of finding peaks in one image (access by event)

In [None]:
# Load images by event...
# Load one event, trim a border, standardize it and run the network on it.
# event = 5735
# event = 3101
event = 1907
# event = 7619
# Use the load mode configured in the dataset cell (instead of hard-coding it).
img   = psana_img.get(event, None, img_load_mode)

# img *= mask_global[0]

# img = remove_outliers(img)
# Trim `offset` pixels from every edge before inference.
offset = 10
size_y, size_x = img.shape
xmin = 0 + offset
xmax = size_x - offset
ymin = 0 + offset
ymax = size_y - offset
img = img[ymin:ymax, xmin:xmax]
# Add batch and channel axes -> (1, 1, H, W), then standardize to zero mean / unit std.
img = torch.tensor(img).type(dtype=torch.float)[None,None,].to(device)
img = (img - img.mean()) / img.std()

model.eval()
time_start = time.monotonic()
with torch.no_grad():
    fmap = model.forward(img)
mask_predicted = fmap.sigmoid()
# CUDA kernels run asynchronously; wait for them so the elapsed time covers the
# actual forward pass rather than only the kernel launches.
if torch.cuda.is_available():
    torch.cuda.synchronize()
time_end = time.monotonic()
print(f"Elapsed: {(time_end - time_start) * 1e3} ms.")

# The two output channels are unpacked as the peak ("label") and noise maps.
label_predicted, noise_predicted = mask_predicted[0, :]

label_predicted = label_predicted.cpu().detach().numpy()
noise_predicted = noise_predicted.cpu().detach().numpy()

In [None]:
# Load images by event...
# Load one event, run the network, binarize the peak/noise maps, locate peaks,
# and overlay the binarized peak map (red) and per-peak boxes (yellow).
# event = 5735
# event = 3101
event = 1907
# event = 7619
# Use the load mode configured in the dataset cell (instead of hard-coding it).
img   = psana_img.get(event, None, img_load_mode)

# img *= mask_global[0]

# img = remove_outliers(img)
# Trim `offset` pixels from every edge before inference.
offset = 10
size_y, size_x = img.shape
xmin = 0 + offset
xmax = size_x - offset
ymin = 0 + offset
ymax = size_y - offset
img = img[ymin:ymax, xmin:xmax]
# Add batch and channel axes -> (1, 1, H, W), then standardize to zero mean / unit std.
img = torch.tensor(img).type(dtype=torch.float)[None,None,].to(device)
img = (img - img.mean()) / img.std()

model.eval()
time_start = time.monotonic()
with torch.no_grad():
    fmap = model.forward(img)
mask_predicted = fmap.sigmoid()
# CUDA kernels run asynchronously; wait for them so the elapsed time covers the
# actual forward pass rather than only the kernel launches.
if torch.cuda.is_available():
    torch.cuda.synchronize()
time_end = time.monotonic()
print(f"Elapsed: {(time_end - time_start) * 1e3} ms.")

# label_predicted, noise_predicted, bg_predicted = mask_predicted[0, :]
# NOTE: these are views into mask_predicted, so the in-place thresholding
# below also modifies mask_predicted.
label_predicted, noise_predicted = mask_predicted[0, :]

# Binarize each probability map in place with its own threshold.
threshold_prob = 0.2
label_predicted[  label_predicted < threshold_prob ] = 0
label_predicted[~(label_predicted < threshold_prob)] = 1

threshold_prob = 0.4
noise_predicted[  noise_predicted < threshold_prob ] = 0
noise_predicted[~(noise_predicted < threshold_prob)] = 1

# Drop peak pixels that are also flagged as noise (noise_predicted is 0/1 here).
label_predicted[noise_predicted > threshold_prob] = 0

# Crop the original image to the prediction's spatial size.
size_y, size_x = mask_predicted.shape[-2:]
img_crop, offset_tuple = center_crop(img, size_y, size_x, returns_offset_tuple = True)

img_crop       = img_crop[0, 0].cpu().detach().numpy()

# Locate peaks: connected components of the binary peak map (3x3 structuring
# element, i.e. 8-connectivity), then the center of mass of each component.
structure = np.ones((3, 3), dtype=bool)
peak_predicted, num_peak_predicted = ndimage.label(cp.asarray(label_predicted), structure)
peak_pos_predicted_list = ndimage.center_of_mass(cp.asarray(label_predicted), peak_predicted, cp.asarray(range(1, num_peak_predicted+1)))

label_predicted = label_predicted.cpu().detach().numpy()
noise_predicted = noise_predicted.cpu().detach().numpy()
# bg_predicted    = bg_predicted.cpu().detach().numpy()

# [[[ Visual ]]]
# Set up the visual
ncols = 1
nrows = 1
fig   = plt.figure(figsize = (16*2,14*2))
gspec = fig.add_gridspec( nrows, ncols,
                          width_ratios  = [1],
                          height_ratios = [1,],
                        )
ax_list = [ fig.add_subplot(gspec[0, 0], aspect = 1), ]

# Plot image
data = img_crop
vmin = np.mean(data) - 1 * data.std()
vmax = np.mean(data) + 6 * data.std()
im = ax_list[0].imshow(data, vmin = vmin, vmax = vmax)
im.set_clim(vmin, vmax)

# Plot mask overlay
data = label_predicted
vmin = 0
vmax = 1
im2 = ax_list[0].imshow(data, vmin = vmin, vmax = vmax, alpha = 1.)
im2.set_clim(vmin, vmax)
cmap1 = mcolors.ListedColormap(['none', 'red'])
im2.set_cmap(cmap1)

# Place a box on a peak
offset = 4          # half-width of each box, in pixels
b_offset = 2        # margin around the image when setting axis limits
for y, x in peak_pos_predicted_list:
    if np.isnan(y) or np.isnan(x): continue

    x_bottom_left = x.get() - offset
    y_bottom_left = y.get() - offset

    rec_obj = mpatches.Rectangle((x_bottom_left, y_bottom_left),
                                 2 * offset, 2 * offset,
                                 linewidth = 1.0,
                                 edgecolor = 'yellow',
                                 facecolor='none')
    ax_list[0].add_patch(rec_obj)

# Axis limits: full prediction extent plus a margin. Computed once, outside the
# peak loop, so they are set (and these names exist for later cells) even when
# no peak is found.
y_bmin, x_bmin = 0, 0
y_bmax, x_bmax = size_y, size_x
ax_list[0].set_xlim([x_bmin - b_offset, x_bmax + b_offset])
ax_list[0].set_ylim([y_bmin - b_offset, y_bmax + b_offset]);

In [None]:
# [[[ Visual ]]]
# Peak map on its own. Axis bounds are derived from the map itself so this cell
# does not depend on the peak loop of another cell having run.
# Set up the visual
ncols = 1
nrows = 1
fig   = plt.figure(figsize = (16*1,14*1))
gspec = fig.add_gridspec( nrows, ncols,
                          width_ratios  = [1],
                          height_ratios = [1,],
                        )
ax_list = [ fig.add_subplot(gspec[0, 0], aspect = 1),  ]

# Plot image
data = label_predicted
vmin = np.mean(data) - 1 * data.std()
vmax = np.mean(data) + 6 * data.std()
im = ax_list[0].imshow(data, vmin = vmin, vmax = vmax)
im.set_clim(vmin, vmax)

b_offset = 2        # margin (pixels) around the image
y_bmin, x_bmin = 0, 0
y_bmax, x_bmax = data.shape[-2:]
ax_list[0].set_xlim([x_bmin - b_offset, x_bmax + b_offset])
ax_list[0].set_ylim([y_bmin - b_offset, y_bmax + b_offset])

In [None]:
# [[[ Visual ]]]
# Noise map on its own. Axis bounds are derived from the map itself so this
# cell does not depend on the peak loop of another cell having run.
# Set up the visual
ncols = 1
nrows = 1
fig   = plt.figure(figsize = (16*1,14*1))
gspec = fig.add_gridspec( nrows, ncols,
                          width_ratios  = [1],
                          height_ratios = [1,],
                        )
ax_list = [ fig.add_subplot(gspec[0, 0], aspect = 1),  ]

# Plot image
data = noise_predicted
vmin = np.mean(data) - 1 * data.std()
vmax = np.mean(data) + 6 * data.std()
im = ax_list[0].imshow(data, vmin = vmin, vmax = vmax)
im.set_clim(vmin, vmax)

b_offset = 2        # margin (pixels) around the image
y_bmin, x_bmin = 0, 0
y_bmax, x_bmax = data.shape[-2:]
ax_list[0].set_xlim([x_bmin - b_offset, x_bmax + b_offset])
ax_list[0].set_ylim([y_bmin - b_offset, y_bmax + b_offset])

In [None]:
# [[[ Visual ]]]
# Background map with a colorbar.
# NOTE: `bg_predicted` is only produced by the commented-out 3-channel code in
# the inference cell; the current UNet has out_channels = 2. Skip the plot
# instead of raising NameError when it is not defined.
if 'bg_predicted' not in globals():
    print("bg_predicted is not defined; skipping the background plot.")
else:
    # Set up the visual
    ncols = 2
    nrows = 1
    fig   = plt.figure(figsize = (16*1,14*1))
    gspec = fig.add_gridspec( nrows, ncols,
                              width_ratios  = [1, 1/21],
                              height_ratios = [1,],
                            )
    ax_list = [ fig.add_subplot(gspec[0, 0], aspect = 1), fig.add_subplot(gspec[0, 1], box_aspect = 20) ]

    # Plot image on a fixed [0, 1] color scale; do not reuse vmin/vmax left
    # over from earlier cells, which would override this range.
    data = bg_predicted
    vmin = 0
    vmax = 1
    im = ax_list[0].imshow(data, vmin = vmin, vmax = vmax)
    im.set_clim(vmin, vmax)

    b_offset = 2        # margin (pixels) around the image
    y_bmin, x_bmin = 0, 0
    y_bmax, x_bmax = data.shape[-2:]
    ax_list[0].set_xlim([x_bmin - b_offset, x_bmax + b_offset])
    ax_list[0].set_ylim([y_bmin - b_offset, y_bmax + b_offset])
    fig.colorbar(im, cax = ax_list[1])

In [None]:
# [[[ Visual ]]]
# Large overlay: image, peak map (red), noise map (green), and a yellow box
# around each detected peak.
# Set up the visual
ncols = 1
nrows = 1
fig   = plt.figure(figsize = (16*5,14*5))
gspec = fig.add_gridspec( nrows, ncols,
                          width_ratios  = [1],
                          height_ratios = [1,],
                        )
ax_list = [ fig.add_subplot(gspec[0, 0], aspect = 1), ]

# Plot image
data = img_crop
vmin = np.mean(data) - 1 * data.std()
vmax = np.mean(data) + 6 * data.std()
im = ax_list[0].imshow(data, vmin = vmin, vmax = vmax)
im.set_clim(vmin, vmax)

# Plot mask overlay (peaks, red)
data = label_predicted
vmin = 0
vmax = 1
im2 = ax_list[0].imshow(data, vmin = vmin, vmax = vmax, alpha = 1.)
im2.set_clim(vmin, vmax)
cmap1 = mcolors.ListedColormap(['none', 'red'])
im2.set_cmap(cmap1)

# Plot mask overlay (noise, green)
data = noise_predicted
vmin = 0
vmax = 1
im2 = ax_list[0].imshow(data, vmin = vmin, vmax = vmax, alpha = 1.)
im2.set_clim(vmin, vmax)
cmap1 = mcolors.ListedColormap(['none', 'green'])
im2.set_cmap(cmap1)

# Place a box on a peak
offset = 4          # half-width of each box, in pixels
b_offset = 2        # margin around the image when setting axis limits
for y, x in peak_pos_predicted_list:
    if np.isnan(y) or np.isnan(x): continue

    x_bottom_left = x.get() - offset
    y_bottom_left = y.get() - offset

    rec_obj = mpatches.Rectangle((x_bottom_left, y_bottom_left),
                                 2 * offset, 2 * offset,
                                 linewidth = 1.0,
                                 edgecolor = 'yellow',
                                 facecolor='none')
    ax_list[0].add_patch(rec_obj)

# Axis limits: full image extent plus a margin, computed here rather than
# relying on bounds left over from another cell's peak loop.
y_bmin, x_bmin = 0, 0
y_bmax, x_bmax = img_crop.shape[-2:]
ax_list[0].set_xlim([x_bmin - b_offset, x_bmax + b_offset])
ax_list[0].set_ylim([y_bmin - b_offset, y_bmax + b_offset]);
In [None]:
# [[[ Visual ]]]
# Pixels where the peak map exceeds the noise map. With the binarized maps from
# the thresholding cell this is exactly "peak = 1 and noise = 0".
# Set up the visual
ncols = 1
nrows = 1
fig   = plt.figure(figsize = (16*1,14*1))
gspec = fig.add_gridspec( nrows, ncols,
                          width_ratios  = [1,],
                          height_ratios = [1,],
                        )
ax_list = [ fig.add_subplot(gspec[0, 0], aspect = 1),  ]

# Plot image
data = noise_predicted < label_predicted
vmin = np.mean(data) - 1 * data.std()
vmax = np.mean(data) + 6 * data.std()
im = ax_list[0].imshow(data, vmin = vmin, vmax = vmax)

b_offset = 2        # margin (pixels) around the image
y_bmin, x_bmin = 0, 0
y_bmax, x_bmax = data.shape[-2:]
ax_list[0].set_xlim([x_bmin - b_offset, x_bmax + b_offset])
ax_list[0].set_ylim([y_bmin - b_offset, y_bmax + b_offset])

`%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% DIVIDER %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%`
<br/>
<br/>
<br/>
<br/>