In [None]:
#import py4DSTEM
import math
import os
# import cv2
import matplotlib.pyplot as plt

import numpy as np
import math
from pyxem.utils.io_utils import _parse_hdr
import h5py
import hyperspy.api as hs
from converter_nord import save_signal
import pyxem as pxm
%matplotlib qt
#%matplotlib inline
from pathlib import Path

# Folders settings

In [None]:
# To set Hyperspy preferences with the interface
#hs.preferences.gui()

Choose the config file depending on the computer (local machine, workstation, etc.).

In [None]:
# Select the YAML config for this machine / dataset (alternative kept for reference):
# config_file = 'configs/config_federico.yml'
config_file = './configs/config_gulnaz_new_dataset.yml'

# Relative paths (this config file, and any relative paths inside it) resolve against this directory.
print(os.getcwd())

## Load Config file

In [None]:
# Load the YAML config file into a yacs CfgNode (`cfg`), which every later cell reads.
import os
import yaml
from yacs.config import CfgNode as CN


with open(config_file, "r", encoding="utf-8") as stream:
    try:
        cfg_dict = yaml.safe_load(stream)
    except yaml.YAMLError as exc:
        # Stop here. Printing the error and continuing let the cell fail with a confusing
        # NameError, or (on a re-run) silently keep the stale `cfg` from a previous run.
        raise RuntimeError(f"Invalid YAML in config file {config_file!r}") from exc

# An empty file parses to None; fail early instead of building an empty config.
if not isinstance(cfg_dict, dict):
    raise ValueError(
        f"Config file {config_file!r} must contain a YAML mapping, got {type(cfg_dict).__name__}")
cfg = CN(cfg_dict)
print(cfg)

# Load Data into an "Electron Diffraction" 2D Signal

## Diffraction pattern calibration

In [None]:
# To (re)calibrate, set SETUP.PATTERN_CALIBRATED to False in the config file.
if not cfg.SETUP.PATTERN_CALIBRATED:
    # Load the calibration acquisition and show the pattern at the current navigation position.
    dp = pxm.load_mib(cfg.SETUP.CALIB_DATA_PATH)
    dp.set_signal_type('electron_diffraction')
    dp.compute()
    dp_selected = dp.inav[dp.axes_manager.indices]
    dp_selected.plot()
    # Start from a default circle and attach it to the plot so it can be dragged into place
    # (its radius is later used as the pixel size of the convergence semiangle).
    roi_calib = hs.roi.CircleROI(cx=130, cy=130, r=30)
    roi_calib.interactive(dp_selected, color="red")
else:
    # Rebuild the circle from the values stored in the config as [cx, cy, r].
    roi_calib = hs.roi.CircleROI(
        cx=cfg.SETUP.CAlIB_COOR[0],
        cy=cfg.SETUP.CAlIB_COOR[1],
        r=cfg.SETUP.CAlIB_COOR[2],
    )


In [None]:
# Fine tune the roi values
# Open the GUI to fine-tune the calibration circle (cx, cy, r) numerically.
roi_calib.gui()

Once the ROI is selected, print its values and copy them into the YAML config file.

In [None]:
# Print the current circle as [cx, cy, r]. To store it, paste the values under the config key
# SETUP.CAlIB_COOR (spelled with a lowercase "l", exactly as the calibration cell reads it;
# the old message said CALIB_COOR, which the code never reads).
calib_values = f"[cx, cy, r]: [{roi_calib.cx}, {roi_calib.cy}, {roi_calib.r:.2f}]"
if cfg.SETUP.PATTERN_CALIBRATED:
    print(f"These are the calibration values: {calib_values}")
else:
    print(f"Values to paste in the Yaml file CAlIB_COOR: {calib_values}")

### Loading dataset:

In [None]:
dataset_path = Path(cfg.SETUP.DATA_PATH)
if dataset_path.suffix == ".mib":
    dp = pxm.load_mib(cfg.SETUP.DATA_PATH)
    # Parsed MIB header, kept so it can be inspected (its return value used to be discarded).
    mib_header = _parse_hdr(cfg.SETUP.DATA_PATH)
    # load_mib returns a lazy signal: compute it before reshaping.
    dp.compute()
    # Reshape the flat scan to the navigation shape from the config. The detector (signal)
    # shape comes from the data itself instead of being hard-coded to 256 x 256.
    dp = hs.signals.Signal2D(
        dp.data.reshape(cfg.SETUP.DATA_SHAPE[0], cfg.SETUP.DATA_SHAPE[1], *dp.data.shape[-2:]))
else:
    # Other formats (e.g. the .zspy written by the crop cell) are opened lazily.
    dp = hs.load(cfg.SETUP.DATA_PATH, lazy=True)

dp.set_signal_type('electron_diffraction')
dp.plot()

In [None]:
# TODO: use if to open file according to the file extension
#dp = pxm.load_mib(cfg.SETUP.DATA_PATH)
# dp = hs.load(cfg.SETUP.DATA_PATH, lazy=True)
#dp.plot()

In [None]:
#dp.axes_manager #.gui()
#hs.print_known_signal_types()

### Crop dataset (optional)

Choose the crop region:

In [None]:
# Rectangular ROI on the detector (signal) axes; drag/resize it on the plot to choose the crop.
# Initial bounds are in signal-axis units (presumably pixels for uncalibrated data).
crop_roi = hs.roi.RectangularROI(left=50, top=60, right=90, bottom=100)
dp.plot()
crop_roi.add_widget(dp, axes=dp.axes_manager.signal_axes)

Once you are happy with the region, crop the data and plot it:

In [None]:
# Crop every diffraction pattern to the ROI (only the signal axes are sliced) and show the result.
crop = crop_roi(dp, axes=dp.axes_manager.signal_axes)
crop.plot()

Save the cropped data to a file (this takes a few minutes):

In [None]:
# Save the cropped dataset. overwrite=True lets the cell be re-run without HyperSpy stopping
# at an interactive "file exists" prompt.
crop.save("croped.zspy", overwrite=True)

In [None]:
# Reload the cropped dataset lazily; from here on `dp` is the cropped signal.
dp = hs.load("croped.zspy", lazy=True)

In [None]:
# Show the (cropped) dataset.
dp.plot()

# Conversion to the right shape

In [None]:
# Reshape dp to (257,256 | 256,256)
# im = hs.signals.Signal2D(dp.data.reshape(256,257,256,256))
#im = hs.signals.Signal2D(dp.data.reshape(cfg.SETUP.DATA_SHAPE[0],cfg.SETUP.DATA_SHAPE[1],256,256))
#im.data.shape
#im.set_signal_type('electron_diffraction')
#im.plot()

(Use the + button to make the point selector bigger)

In [None]:
#im.axes_manager.gui_navigation_sliders()
# Move the navigation pointer to a fixed scan position; the next cells take the pattern from here.
# NOTE(review): hard-coded position, presumably in HyperSpy's (x, y) navigation index order;
# adjust per dataset.
dp.axes_manager.indices = (130, 198)

# Virtual Diffraction Imaging & Selecting Regions
## Interactive VDF Imaging

In [None]:
# Display the diffraction pattern at the current navigation position (the red pointer).
%matplotlib qt
dp_selected = dp.inav[dp.axes_manager.indices]
dp_selected.plot()

In [None]:
# Guide line across the pattern; the VDF ROIs are later placed at evenly spaced points along it.
roi_line = hs.roi.Line2DROI(x1= 21, y1=132, x2=200, y2=55)
roi_line.interactive(dp_selected, color="yellow") # connect the roi with the plot
# Also draw the calibration circle on the same plot for reference.
# with roi_calib.events.suppress():
roi_calib.interactive(dp_selected, color='red')
# NOTE(review): replacing the ROI's `events` object with None will likely break HyperSpy's change
# notifications (e.g. when the red circle is dragged). Confirm this is intended; the commented
# `events.suppress()` context above may be the safer way to silence it.
roi_calib.events = None
#roi_line.gui()

In [None]:
# Fine-tune the guide-line endpoints numerically.
roi_line.gui()


In [None]:
def calculate_angles(coordinates, semiangle_mrad=None, calib_radius_px=None):
    """Angle (in degrees) assigned to each ROI placed along the guide line.

    ``coordinates`` are the points built by ``get_list_of_lines`` / ``get_list_of_circles``:
    ``n + 1`` points, of which only the first ``n`` carry an ROI. The pixel spacing between
    consecutive points is converted to degrees using the calibration circle, whose radius in
    pixels corresponds to the convergence semiangle.

    The angles are symmetric around 0, e.g. ``(-d, 0, d)`` for 3 ROIs. Previously exactly
    three values were returned whatever the number of ROIs. The middle ROI is therefore
    assumed to sit on the optical axis (NOTE(review): confirm).

    Parameters
    ----------
    coordinates : array-like of shape (n + 1, 2)
        Integer (x, y) pixel positions; at least 2 rows.
    semiangle_mrad : float, optional
        Convergence semiangle in mrad; defaults to ``cfg.SETUP.SEMIANGLE``.
    calib_radius_px : float, optional
        Radius in pixels of the calibration circle; defaults to ``roi_calib.r``.

    Returns
    -------
    tuple of n floats
        Signed angle in degrees for each of the first ``n`` points.
    """
    if semiangle_mrad is None:
        semiangle_mrad = cfg.SETUP.SEMIANGLE
    if calib_radius_px is None:
        calib_radius_px = roi_calib.r

    n_points = len(coordinates) - 1
    # Pixel distance between two consecutive ROI positions (they are evenly spaced).
    step_pix = math.hypot(coordinates[0][0] - coordinates[1][0],
                          coordinates[0][1] - coordinates[1][1])
    # Convert the semiangle from mrad to degrees; the calibration radius gives degrees per pixel.
    semiangle_deg = math.degrees(semiangle_mrad / 1000.0)
    deg_per_pix = semiangle_deg / calib_radius_px
    step_deg = step_pix * deg_per_pix

    centre = (n_points - 1) / 2
    return tuple((i - centre) * step_deg for i in range(n_points))

def get_list_of_lines(p1, p2, n_lines, line_widths):
    """Build line ROIs at ``n_lines`` evenly spaced points from ``p1`` to ``p2``.

    One ROI is created per (line width, point) pair, ordered by width first. Each ROI is a
    zero-length ``Line2DROI`` (start == end) whose ``linewidth`` sets its size.

    Parameters
    ----------
    p1, p2 : (x, y)
        Endpoints of the guide line (the notebook passes the yellow ``roi_line`` endpoints).
    n_lines : int
        Number of points along the line, both endpoints included; must be >= 2.
    line_widths : iterable
        One full set of ROIs is created per width.

    Returns
    -------
    (rois_line, meta_data)
        ``meta_data[k] == ("Line", x, y, line_width, angle_deg)`` describes ``rois_line[k]``.

    Raises
    ------
    ValueError
        If ``n_lines < 2`` (the spacing ``1 / (n_lines - 1)`` is undefined).
    """
    if n_lines < 2:
        raise ValueError(f"n_lines must be >= 2, got {n_lines}")
    line_widths = list(line_widths)  # iterated twice below
    start = np.asarray(p1, dtype=float)
    end = np.asarray(p2, dtype=float)
    # Interpolate between the endpoints directly. The previous atan(slope) formulation divided
    # by zero for vertical lines and walked the wrong way for lines drawn right-to-left.
    # As before, n_lines + 1 points are generated (the last lies beyond p2 and carries no ROI)
    # and handed to calculate_angles; coordinates are truncated to integer pixels.
    fractions = np.arange(n_lines + 1) / (n_lines - 1)
    coordinates = (start + fractions[:, np.newaxis] * (end - start)).astype(np.int32)
    roi_points = coordinates[:n_lines]

    rois_line = [
        hs.roi.Line2DROI(x1=x, y1=y, x2=x, y2=y, linewidth=line_width)
        for line_width in line_widths
        for x, y in roi_points]

    # Metadata used for the plot labels, in the same order as rois_line.
    angles = calculate_angles(coordinates)
    meta_data = [
        ("Line", x, y, line_width, angles[i])
        for line_width in line_widths
        for i, (x, y) in enumerate(roi_points)]

    return rois_line, meta_data

def get_list_of_circles(p1, p2, n_circles, circles_radius):
    """Build circular ROIs at ``n_circles`` evenly spaced points from ``p1`` to ``p2``.

    One ROI is created per (radius, point) pair, ordered by radius first.

    Parameters
    ----------
    p1, p2 : (x, y)
        Endpoints of the guide line (the notebook passes the yellow ``roi_line`` endpoints).
    n_circles : int
        Number of points along the line, both endpoints included; must be >= 2.
    circles_radius : iterable
        One full set of ROIs is created per radius.

    Returns
    -------
    (rois_circle, meta_data)
        ``meta_data[k] == ("Circle", x, y, radius, angle_deg)`` describes ``rois_circle[k]``.

    Raises
    ------
    ValueError
        If ``n_circles < 2`` (the spacing ``1 / (n_circles - 1)`` is undefined).
    """
    if n_circles < 2:
        raise ValueError(f"n_circles must be >= 2, got {n_circles}")
    circles_radius = list(circles_radius)  # iterated twice below
    start = np.asarray(p1, dtype=float)
    end = np.asarray(p2, dtype=float)
    # Interpolate between the endpoints directly. The previous atan(slope) formulation divided
    # by zero for vertical lines and walked the wrong way for lines drawn right-to-left.
    # As before, n_circles + 1 points are generated (the last lies beyond p2 and carries no ROI)
    # and handed to calculate_angles; coordinates are truncated to integer pixels.
    fractions = np.arange(n_circles + 1) / (n_circles - 1)
    coordinates = (start + fractions[:, np.newaxis] * (end - start)).astype(np.int32)
    roi_points = coordinates[:n_circles]

    rois_circle = [
        hs.roi.CircleROI(cx=x, cy=y, r=circle_radius)
        for circle_radius in circles_radius
        for x, y in roi_points]

    # Metadata used for the plot labels, in the same order as rois_circle.
    angles = calculate_angles(coordinates)
    meta_data = [
        ("Circle", x, y, circle_radius, angles[i])
        for circle_radius in circles_radius
        for i, (x, y) in enumerate(roi_points)]
    return rois_circle, meta_data


# Endpoints of the yellow guide line drawn above.
line_start = (roi_line.x1, roi_line.y1)
line_end = (roi_line.x2, roi_line.y2)

# Each result is a (list_of_rois, list_of_metadata) pair.
roi_lines = get_list_of_lines(p1=line_start, p2=line_end,
                              n_lines=cfg.ROI.N_LINES,
                              line_widths=cfg.ROI.LINE_WIDTHS)
roi_circles = get_list_of_circles(p1=line_start, p2=line_end,
                                  n_circles=cfg.ROI.N_CIRCLES,
                                  circles_radius=cfg.ROI.CIRCLES_RADIUS)


## (Optional) - graph ROIs

In [None]:
# Close the previously opened plot first, otherwise ROIs drawn earlier remain on it.
%matplotlib qt
dp_selected.plot()
# Line ROIs in green, circle ROIs in blue (element 0 of each pair is the ROI list).
for roi in roi_lines[0]:
    roi.interactive(dp_selected, color="green")
for roi in roi_circles[0]:
    roi.interactive(dp_selected, color="blue")

## Image stack of Integrated intensity on ROIs

In [None]:
def correct_pixels_and_normalize(img):
    """Correct dead/hot pixels of a VDF image and min-max scale it to [0, 1].

    Parameters
    ----------
    img : 2D array-like
        One virtual dark-field image.

    Returns
    -------
    Signal2D with signal type 'electron_diffraction'
        The corrected image scaled to [0, 1]. A constant image (zero dynamic range) is
        returned as all zeros instead of the NaNs a 0/0 division used to produce.
    """
    # Bad-pixel correction; the hot-pixel threshold comes from the config
    # (key spelled THRESHOLD_MULTUPLIER, as in the config file).
    img = hs.signals.Signal2D(img)
    img.set_signal_type('electron_diffraction')
    s_dead_pixels = img.find_dead_pixels(lazy_result=True, show_progressbar=True)
    s_hot_pixels = img.find_hot_pixels(show_progressbar=True, threshold_multiplier=cfg.SETUP.THRESHOLD_MULTUPLIER)
    img_corrected = img.correct_bad_pixels(s_dead_pixels + s_hot_pixels, show_progressbar=True,
                                           inplace=False, lazy_result=True)

    # Min-max normalisation (materialises the lazy result).
    values = np.asarray(img_corrected.data, dtype=float)
    value_range = np.ptp(values)
    if value_range == 0:
        values = np.zeros_like(values)
    else:
        values = (values - np.min(values)) / value_range
    img_normalized = hs.signals.Signal2D(values)
    img_normalized.set_signal_type('electron_diffraction')
    return img_normalized

In [None]:
def _normalized_vdf(roi):
    """VDF image of `dp` integrated over `roi`, reshaped to cfg.SETUP.DATA_SHAPE,
    then bad-pixel corrected and normalised."""
    integrated = dp.get_integrated_intensity(roi).data.reshape(cfg.SETUP.DATA_SHAPE)
    return correct_pixels_and_normalize(integrated)


# Circles are computed first, as before. The stack itself is ordered lines then circles,
# matching the label order built from roi_lines[1] + roi_circles[1].
vdfs_circles = [_normalized_vdf(roi) for roi in roi_circles[0]]
vdfs_lines = [_normalized_vdf(roi) for roi in roi_lines[0]]

vdf_stack = hs.stack(vdfs_lines + vdfs_circles, new_axis_name="ROI lines", show_progressbar=True)

vdf_stack.set_signal_type("virtual_dark_field")

### With the Qt pop-up interface

In [None]:
# for plotting outside jupyter (separate Qt window)
%matplotlib qt
# for plotting inside jupyter
# %matplotlib inline
# %matplotlib widget
# Interactive browser over the stack of normalised VDF images (one per ROI).
vdf_stack.plot()

### With the inline interface of Jupyter notebook

In [30]:
# Create the list of labels
# One label per image of vdf_stack. The stack holds the line VDFs followed by the circle VDFs,
# so the metadata lists are concatenated in the same order.
meta_data = roi_lines[1] + roi_circles[1]


def _roi_label(entry):
    """Format a ("Line"|"Circle", x, y, size, angle) metadata tuple as a plot label."""
    kind, _, _, size, angle = entry
    if kind == "Line":
        return f"{kind} width={size} angle={angle:.2f} deg"
    return f"{kind} radius={size} angle={angle:.2f} deg"


labels = [_roi_label(entry) for entry in meta_data]

In [31]:
%matplotlib inline
import matplotlib.pyplot as plt
# Static grid of all VDF images (3 per row), each titled with its ROI label.
fig = plt.figure(figsize = cfg.SETUP.PLOT_SIZE)
_ = hs.plot.plot_images(vdf_stack, per_row=3, axes_decor="off", colorbar=False, label= labels, fig= fig)


In [None]:
# Reopen the interactive stack browser.
vdf_stack.plot()

### Save images stack into a tiff file

In [None]:
# def increment_path(path, exist_ok=False, sep='', mkdir=False):
# Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
#path = Path(path)  # os-agnostic
#if path.exists() and not exist_ok:
#    path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
#    dirs = glob.glob(f"{path}{sep}*")  # similar paths
#    matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
#    i = [int(m.groups()[0]) for m in matches if m]  # indices
#    n = max(i) + 1 if i else 2  # increment number
#    path = Path(f"{path}{sep}{n}{suffix}")  # increment path
#if mkdir:
#    path.mkdir(parents=True, exist_ok=True)  # make directory
#return path

In [None]:
# float64 is not always compatible with tiff readers
#vdf_stack.change_dtype('float32')
#vdf_stack.save("output_images/new_dataset.tif")

In [None]:
import tifffile

# Output folder for this run; edit per dataset.
folder_name = "22-06-27 curls"
output_dir = Path("output_images") / folder_name
output_dir.mkdir(parents=True, exist_ok=True)

# float64 is not always compatible with TIFF readers, so store the stack as float32.
# This converts `vdf_stack` in place, so later cells also see float32 data.
vdf_stack.change_dtype('float32')
# overwrite=True: re-running the cell must not stop at HyperSpy's "file exists" prompt.
vdf_stack.save(str(output_dir / "stack.tif"), overwrite=True)

# One ImageJ TIFF per ROI image, named after its plot label (labels follow the stack order).
for image, label in zip(vdf_stack.data, labels):
    image_path = output_dir / f"{label}.tif"
    print(image_path)
    # imwrite replaces the deprecated imsave alias.
    tifffile.imwrite(image_path, image, imagej=True)

### Load image stack from a tiff file

In [None]:
%matplotlib qt
# Reload a previously saved VDF stack for inspection.
# NOTE(review): this file name is not written anywhere in this notebook (the export cell writes
# output_images/<folder_name>/stack.tif); point it at the file you actually want.
vdf_load = hs.load('virtual diffraction image stack.tif', force_read_resolution=True)

# vdf_load.change_dtype('float32')
vdf_load.plot()

In [None]:
# Load stack from Tiff
# NOTE(review): this REPLACES the in-memory `vdf_stack` computed above; skip it when running top to
# bottom. 'output_images/new_dataset.tif' is only written by the commented-out save cell.
vdf_stack = hs.load('output_images/new_dataset.tif', force_read_resolution=True)

# Inputting images to the segmentation NN
## loading network

In [None]:
# Build the segmentation network described by the config (weights are loaded in the next cell).
from delineation.models import build_segmentation_model
seg_model = build_segmentation_model(cfg)


In [None]:
import torch

# With SETUP.LOCAL (presumably a machine without GPU), remap checkpoint tensors to the CPU;
# otherwise keep torch.load's default placement (map_location=None).
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint_location = torch.device('cpu') if cfg.SETUP.LOCAL else None
seg_model.load_state_dict(
    torch.load(cfg.TEST.MODEL_WEIGHTS, map_location=checkpoint_location)['state_dict'])

# Switch to evaluation mode before inference (the model summary is the cell output).
seg_model.eval()

### Converting stack of VDF images to tensors for inference

In [None]:
# Convert the VDF stack to a float32 tensor for the network. The commented 65535.0 is
# presumably an alternative divisor for 16-bit data; the images are already normalised to [0, 1].
vdf_stack_tensor = torch.from_numpy(vdf_stack.data/1).float() #65535.0

In [None]:
# Threshold turning the network output into a binary segmentation map.
SEG_THRESHOLD = 0.05

with torch.no_grad():
    # Run on whatever device holds the model's weights. Previously only the SETUP.LOCAL branch
    # existed (the other one was commented out), so with LOCAL = False `seg_res1` was never
    # defined and the thresholding line failed.
    model_device = next(seg_model.parameters()).device
    # Add a channel axis: one single-channel image per ROI.
    _, seg_res1 = seg_model(vdf_stack_tensor.unsqueeze(1).to(model_device))
seg_map = seg_res1.squeeze().data.cpu().numpy() > SEG_THRESHOLD


In [None]:
%matplotlib inline
ncols = 3
fig, axs = plt.subplots(nrows=seg_map.shape[0]//ncols, ncols= ncols, figsize=cfg.SETUP.PLOT_SIZE)

for i, ax in enumerate(axs.flat):
    ax.imshow(seg_map[i],cmap = "gray" )
    ax.set_title(labels[i])

# Topological Analysis

imports

In [None]:
import itertools
import numpy as np
import json
import cv2
import glob

from gtda.images import HeightFiltration, DilationFiltration, ErosionFiltration, RadialFiltration, SignedDistanceFiltration, DensityFiltration
from gtda.images import Binarizer, Inverter
from gtda.homology import CubicalPersistence
from gtda.diagrams import PairwiseDistance
from gtda.diagrams import Amplitude, PersistenceEntropy
from gtda.diagrams import Filtering
from sklearn.pipeline import Pipeline, make_pipeline, FeatureUnion, make_union
from gtda.plotting import plot_diagram

from PIL import Image
from labelme import utils
from numpy import asarray
import matplotlib as mpl


## Functions definition

In [None]:

def plot_images(images):
    """Display a stack of 2D images side by side in one row.

    With more than one image, all panels share one grey scale (binary colormap) computed from
    the finite pixels only, and non-finite pixels (inf/NaN) are drawn in yellow.

    Parameters
    ----------
    images : array-like of shape (n_images, H, W)
    """
    images = np.asarray(images)
    if len(images) == 0:
        return
    if len(images) == 1:
        plt.imshow(images[0])
        plt.show()
        return

    fig, axes = plt.subplots(1, len(images), figsize=(len(images) * 3, 30))
    axes = axes.flatten()
    # Private colormap copy. Calling set_bad() on the shared plt.cm.binary mutates global state,
    # and the panels were drawn with cmap='binary' by name, so depending on the Matplotlib
    # version the yellow "bad" colour was never applied.
    cmap = plt.cm.binary.copy()
    cmap.set_bad('y')
    # Colour limits from finite pixels only. The old `!= np.inf` test let -inf/NaN through and
    # failed on an all-inf selection.
    finite = images[np.isfinite(images)]
    vmin, vmax = (finite.min(), finite.max()) if finite.size else (None, None)

    for ax, image in zip(axes, images):
        ax.imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        ax.axis('off')  # hide the axes ticks
    plt.show()


def plot_diagrams(X, names=None):
    """Scatter persistence diagrams, one subplot per sample.

    Parameters
    ----------
    X : dict
        Maps homology dimension (0, 1 or 2; key 0 must exist) to a sequence of per-sample
        arrays whose columns 0 and 1 are plotted as x and y (presumably birth and death).
    names : list of str, optional
        Subplot titles, one per sample (previously accepted but ignored).
    """
    n_samples = len(X[0])
    # squeeze=False keeps `axes` an array even for one sample (.flatten() on a lone Axes failed).
    fig, axes = plt.subplots(1, n_samples, figsize=(n_samples * 5, 5), squeeze=False)
    axes = axes.flatten()
    colors = {0: 'b', 1: 'r', 2: 'g'}

    for i, ax in enumerate(axes):
        for dimension in X.keys():
            points = X[dimension][i]
            ax.plot(points[:, 0], points[:, 1], 'o', color=colors[dimension])
            # Diagonal reference line up to the largest value of this dimension.
            diag_max = np.max(X[dimension])
            ax.plot([0, diag_max], [0, diag_max], color='k')
        if names is not None:
            ax.set_title(names[i], color='black', fontsize=12)
    plt.show()


def plot_matrices(X):
    """Show the square blocks of ``X`` side by side with one shared horizontal colorbar.

    ``X`` has shape (n, k * n): k pairwise-distance matrices of size n x n concatenated along
    the columns. All panels use the global min/max of ``X`` as colour limits.
    """
    side = X.shape[0]
    n_matrices = X.shape[1] // side
    vmin, vmax = np.min(X), np.max(X)

    if n_matrices <= 1:
        figure, axes = plt.subplots(1, n_matrices, figsize=(6, 9))
        axes.imshow(X[:, :side], vmin=vmin, vmax=vmax)
    else:
        figure, axes = plt.subplots(1, n_matrices, figsize=(18, 8))
        for i, ax in enumerate(np.ravel(axes)):
            ax.imshow(X[:, i * side:(i + 1) * side], vmin=vmin, vmax=vmax)

    # Shared colorbar below the panels.
    figure.subplots_adjust(bottom=0.2)
    cbar_ax = figure.add_axes([0.3, 0.2, 0.4, 0.03])
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    colorbar = mpl.colorbar.ColorbarBase(cbar_ax, norm=norm, orientation='horizontal')
    colorbar.set_label('Pairwise distances')
    plt.show()


def pipeline_steps(filtration_list, homology_dimensions_list):
    """Build scikit-learn step lists for three pipeline depths.

    Returns ``[filtration_only, with_persistence, with_distance]``:

    - ``filtration_only``: ``[('filtration', f)]`` for each filtration ``f``;
    - ``with_persistence``: filtration + cubical persistence, one per
      (filtration, homology dimensions) pair in filtration-major order;
    - ``with_distance``: the same plus a pairwise Wasserstein distance (p=2, delta=0.1).
    """
    filtration_only = [[('filtration', filtration)] for filtration in filtration_list]

    pairs = list(itertools.product(filtration_list, homology_dimensions_list))

    with_persistence = [
        [('filtration', filtration),
         ('persistence', CubicalPersistence(dims))]
        for filtration, dims in pairs]

    # The filtration step could be dropped if the inputs were already greyscale images.
    with_distance = [
        [('filtration', filtration),
         ('persistence', CubicalPersistence(dims)),
         ('distance', PairwiseDistance(metric='wasserstein', order=None,
                                       metric_params={'p': 2, 'delta': 0.1}))]
        for filtration, dims in pairs]

    return [filtration_only, with_persistence, with_distance]


def pipline_processing(images, step_list, mode='distance', plot=False):
    """Run every step list as a parallel pipeline and return their concatenated outputs.

    Each entry of ``step_list`` becomes a scikit-learn ``Pipeline`` named by its index. The
    pipelines are combined in a ``FeatureUnion`` and fit-transformed on ``images``. When
    ``plot`` is True the result is visualised according to ``mode`` ('filtration',
    'persistance' or 'distance'); any other mode plots nothing.
    """
    named_pipelines = [(str(index), Pipeline(steps)) for index, steps in enumerate(step_list)]
    diagrams = FeatureUnion(named_pipelines).fit_transform(images)

    # Same test as `plot == True`, inverted for an early return.
    if plot != True:
        return diagrams

    if mode == 'filtration':
        plot_images(diagrams)
    elif mode == 'persistance':
        for diagram in diagrams:
            plot_diagram(diagram).show()
    elif mode == 'distance':
        # One figure per homology channel.
        plot_matrices(diagrams[:, :, 0])
        plot_matrices(diagrams[:, :, 1])

    return diagrams


def combined_distance(X, weights=None, power=1):
    """Weighted sum of the distances from every sample to the first sample.

    ``X`` is the FeatureUnion output of the distance pipelines, shape (n, k * n, p): k
    pairwise-distance matrices of size n x n side by side, each with p channels. For every
    sample, the k * p distances to sample 0 are raised to ``power``, multiplied by ``weights``
    (default: all ones, ordered block-major) and summed.

    Returns
    -------
    ndarray of shape (1, n)
    """
    n_samples = X.shape[0]
    n_blocks = int(X.shape[1] / X.shape[0])
    n_channels = X.shape[-1]
    n_terms = n_blocks * n_channels

    if weights is None:
        weights = np.ones(n_terms)

    # (n, k*n, p) -> (n, k, n, p); keep the distances to sample 0 -> (n, k, p) -> (n, k*p).
    # Column ``block * p + channel`` lines up with ``weights[block * p + channel]``.
    to_first = np.reshape(X, (n_samples, n_blocks, n_samples, n_channels))[:, :, 0, :]
    weighted = np.zeros((n_samples, n_terms))
    weighted[:] = np.asarray(weights)[:n_terms] * to_first.reshape(n_samples, n_terms) ** power

    return np.sum(weighted, axis=1).reshape(1, n_samples)


def plot_combined_distance(X, X_type=None):
    """Show combined distances (one value per image) as a single-row image with a colorbar.

    Parameters
    ----------
    X : array-like, 1D or 2D
        E.g. the (1, n) output of ``combined_distance``, or the 1D ``total_error`` returned by
        ``compute_best_algo``. 1D input is promoted to one row because ``imshow`` needs 2D data.
    X_type : str, optional
        Plot title. The weight-optimisation cells pass ``X_type=...``, which used to raise a
        TypeError because the parameter did not exist.
    """
    X = np.atleast_2d(X)
    figure, axes = plt.subplots(1, 1, figsize=(10, 6))
    vmin, vmax = np.min(X), np.max(X)
    axes.imshow(X, vmin=vmin, vmax=vmax)
    if X_type is not None:
        axes.set_title(X_type)
    cbar_ax = figure.add_axes([0.3, 0.2, 0.4, 0.03])
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    colorbar = mpl.colorbar.ColorbarBase(cbar_ax, norm=norm, orientation='horizontal')
    colorbar.set_label('Pairwise sum of distances')
    axes.set_yticks([])
    plt.show()


def plot_distance(X, X_type="H0", dataset=" ", normalize_per_distance=True, error=None, weight=None):
    """Heat map of distances to the ground truth: one row per filtration, one column per image.

    Parameters
    ----------
    X : array of shape (n_images, n_filtrations)
        E.g. ``distance_H0`` built from ``get_first_row``. The 15 y tick labels below assume
        15 filtrations (dilation, density, 4 height directions, 9 radial centres).
    X_type : str
        Homology label used in the title (e.g. "H0").
    dataset : str
        "s-error" and "BF_CL" select fixed figure sizes; any other value sizes the figure from
        the number of images.
    normalize_per_distance : bool
        If True, scale each filtration (column of ``X``) by its own maximum; otherwise by the
        global maximum. All-zero columns stay 0 instead of becoming NaN (0/0).
    error : sequence, optional
        One value per image, written on a secondary x axis under the columns.
    weight : ndarray, optional
        One value per filtration, written on a secondary y axis on the right.
    """
    dilation = ["Dilation"]
    height = ["Direction [0, 1]", "Direction [0, -1]", "Direction [1, 0]", "Direction [-1, 0]"]
    density = ["Density"]
    radial = ["Position [0.5, 0.5]",
              "Position [0.5, 0.25]", "Position [0.5, 0.75]",
              "Position [0.25, 0.5]", "Position [0.75, 0.5]",
              "Position [0.25, 0.25]", "Position [0.25, 0.75]",
              "Position [0.75, 0.25]", "Position [0.75, 0.75]"]
    ylabel = dilation + density + height + radial

    # Normalise to [0, 1]; zero maxima are replaced by 1 so 0/0 does not produce NaN.
    scale = X.max(axis=0) if normalize_per_distance else X.max()
    scale = np.where(scale == 0, 1, scale)
    X_norm = X / scale

    preset_sizes = {"s-error": (3, 11), "BF_CL": (6, 10)}
    if dataset in preset_sizes:
        figure, axes = plt.subplots(1, 1, figsize=preset_sizes[dataset])
        title = X_type + " distance : dataset " + dataset
    else:
        # Width grows with the number of images. Clamp it, since with 1-2 columns the old
        # formula gave a zero or negative width and matplotlib refused to create the figure.
        width = max(1, int(np.round((X_norm.shape[0]) / 2) - 1))
        figure, axes = plt.subplots(1, 1, figsize=(width, 13))
        title = X_type + " distance"
    axes.imshow(X_norm.T, vmin=0, vmax=1)
    axes.set_title(title)

    axes.set_xticks(np.arange(X_norm.T.shape[1]))
    axes.set_yticks(np.arange(X_norm.T.shape[0]))
    axes.set_yticklabels(ylabel)
    # Image indices on top; the bottom axis is kept free for the optional error values.
    axes.tick_params(top=True, bottom=False,
                     labeltop=True, labelbottom=False)

    if error is not None:
        ax_bottom = axes.secondary_xaxis('bottom')
        ax_bottom.set_xticks(np.arange(X_norm.T.shape[1]))
        ax_bottom.set_xlabel('error *10^7', color='r')
        ax_bottom.set_xticklabels(np.round(error, 3), rotation=45)

    if weight is not None:
        ax_right = axes.secondary_yaxis('right')
        ax_right.set_yticks(np.arange(X_norm.T.shape[0]))
        ax_right.set_ylabel('weight', color='r')
        ax_right.set_yticklabels(np.round(weight.reshape(-1), 3))

    figure.subplots_adjust(bottom=0.1)
    cbar_ax = figure.add_axes([0.3, 0.2, 0.4, 0.03])
    norm = mpl.colors.Normalize(vmin=0, vmax=1)
    colorbar = mpl.colorbar.ColorbarBase(cbar_ax, norm=norm, orientation='horizontal')
    colorbar.set_label('Pairwise distances')
    plt.show()


def label_name():
    """Labels of the 15 filtrations, in the order their distances are stacked.

    Dilation, density, the four height-filtration directions, then the nine radial-filtration
    centres (given as fractions of the image size). The first two labels carry trailing spaces,
    presumably to line up the tab-separated weight printout in ``compute_best_algo``.
    """
    directions = ["Direction [0, 1]", "Direction [0, -1]", "Direction [1, 0]", "Direction [-1, 0]"]
    positions = ["Position [0.5, 0.5]",
                 "Position [0.5, 0.25]", "Position [0.5, 0.75]",
                 "Position [0.25, 0.5]", "Position [0.75, 0.5]",
                 "Position [0.25, 0.25]", "Position [0.25, 0.75]",
                 "Position [0.75, 0.25]", "Position [0.75, 0.75]"]
    return ["Dilation         "] + ["Density         "] + directions + positions


def get_first_row(X):
    """Distances from every other sample to sample 0, split by homology channel.

    ``X`` has shape (n, k * n, p): k side-by-side n x n distance matrices with p >= 2 channels.
    Returns ``(distance_H0, distance_H1)``, each of shape (n - 1, k): channels 0 and 1 of the
    distances from samples 1..n-1 to sample 0.
    """
    n = X.shape[0]
    blocks = np.reshape(X, (n, int(X.shape[1] / X.shape[0]), n, X.shape[-1]))
    to_first = blocks[1:, :, 0, :]
    return to_first[..., 0], to_first[..., 1]

## Load images

This step is only necessary if images are not in memory

Make image stack

In [None]:
# Dataset selection: folder name and index of the ground-truth image.
folder = 'Dataset_name'
ground_truth_number = 1

image_name = f'35 2BC6 CL=330-30-{ground_truth_number}_gt.png'
path_save = f'Results/dataset_{folder}-{ground_truth_number}/'
path = 'Dataset2_31072020/'

from pathlib import Path

# Output folder for the cached distance diagrams.
Path(path_save).mkdir(parents=True, exist_ok=True)

In [None]:
# with open(path + '35 2BC6 CL=330-30-1.json') as f:
#     data = json.load(f)
# imageData = data.get("imageData")
# ground_truth = utils.img_b64_to_arr(imageData)

# Ground truth first (grey channel of the LA conversion), then every aligned segmentation.
with Image.open(path + image_name) as gt_file:
    ground_truth = asarray(gt_file.convert('LA'))[:, :, 0]

# Sorted so the image order (and hence the column order of every distance plot) is
# deterministic; glob returns files in filesystem order.
all_images_path = sorted(glob.glob(path + folder + "/aligned/segmentation_aligned/*.png"))

image_layers = [ground_truth]
for img_path in all_images_path:
    with Image.open(img_path) as image_file:
        image_layers.append(asarray(image_file.convert('LA'))[:, :, 0])
# Stack once at the end; np.vstack inside the loop re-copied the whole array every iteration.
images = np.stack(image_layers)

In [None]:
# Use every third segmentation map as the image stack for the topological analysis
# (this replaces any `images` loaded from disk by the previous cells).
images = seg_map[::3,:,:]
print(images.shape)
plot_images(images)
# NOTE(review): hard-coded zoom window; it is empty if the images are smaller than
# 500 x 600 pixels, so adjust it to the data size.
plot_images(images[:4, 300:500, 400:600])

# Align Images

In [None]:
# If we are using the images of the same column, do we still need to align them?

# Define Filters

In [None]:
# Image centre (pixel indices along the two image axes).
x_center, y_center = int(images.shape[1]/2), int(images.shape[2]/2)

# Height filtrations sweep along these four directions.
direction_list = np.array([[0, 1], [0, -1], [1, 0], [-1, 0]])
n_neighbors_list = [[2, 4]]
# Change this value when calculating 3D
homology_dimensions_list = [[0, 1]]

# Radial filtration centres as multiples of the centre coordinates. With the centre at half the
# image size these are the "Position [...]" fractions of label_name(), in the same order.
center_factors = [(1, 1), (1, 0.5), (1, 1.5), (0.5, 1), (1.5, 1),
                  (0.5, 0.5), (0.5, 1.5), (1.5, 0.5), (1.5, 1.5)]
center_list = np.array([[int(fx * x_center), int(fy * y_center)] for fx, fy in center_factors])

filtration_list_height = [HeightFiltration(direction=direction) for direction in direction_list]
filtration_list_dilation = [DilationFiltration()]
filtration_list_density = [DensityFiltration()]
filtration_list_radial = [RadialFiltration(center=center) for center in center_list]

In [None]:
def generic_filtration(filtration_list, homology_dimensions_list, plot=False, image_stack=None):
    """Compute Wasserstein distance diagrams for a family of filtrations.

    Parameters
    ----------
    filtration_list : list
        gtda filtration transformers; one distance pipeline is built per
        (filtration, homology dimensions) pair.
    homology_dimensions_list : list of lists
        E.g. ``[[0, 1]]``.
    plot : bool
        If True, also plot the individual distance matrices. The combined distance is always
        plotted.
    image_stack : array of shape (n_images, H, W), optional
        Images to analyse; defaults to the notebook-level ``images`` (previously the only,
        implicit, input).

    Returns
    -------
    The FeatureUnion output of the distance pipelines (see ``pipline_processing``).
    """
    if image_stack is None:
        image_stack = images
    steps_distance = pipeline_steps(filtration_list, homology_dimensions_list)[2]
    diagrams_distance = pipline_processing(image_stack, steps_distance, mode='distance', plot=plot)
    distance = combined_distance(diagrams_distance, weights=None, power=1)
    plot_combined_distance(distance)
    return diagrams_distance

## Filtrations

In [None]:
# Compute the distance diagrams for each filtration family (dilation, density, height, radial,
# in that order) and cache each one in path_save as its own .npz file.
filtration_runs = [
    (filtration_list_dilation, "diagrams_distance_dilation.npz", "diagrams_distance_dilation"),
    (filtration_list_density, "diagrams_distance_density.npz", "diagrams_distance_density"),
    (filtration_list_height, "diagrams_distance_height.npz", "diagrams_distance_height"),
    (filtration_list_radial, "diagrams_distance_radial.npz", "diagrams_distance_radial"),
]
for filtrations, file_name, array_key in filtration_runs:
    diagrams_distance = generic_filtration(filtrations, homology_dimensions_list, plot = False)
    np.savez(path_save + file_name, **{array_key: diagrams_distance})

## Load distances from files (to avoid repeating the previous calculations)

In [None]:
def _load_npz_array(file_name, key):
    """Read one array from an .npz file in path_save, closing the file afterwards."""
    with np.load(path_save + file_name) as npz_file:
        return npz_file[key]


diagrams_distance_dilation = _load_npz_array("diagrams_distance_dilation.npz", 'diagrams_distance_dilation')
diagrams_distance_radial = _load_npz_array("diagrams_distance_radial.npz", 'diagrams_distance_radial')
diagrams_distance_height = _load_npz_array("diagrams_distance_height.npz", 'diagrams_distance_height')
diagrams_distance_density = _load_npz_array("diagrams_distance_density.npz", 'diagrams_distance_density')

# Same family order as label_name(): dilation, density, height, radial.
all_distance = [diagrams_distance_dilation, diagrams_distance_density,
                diagrams_distance_height, diagrams_distance_radial]

# Distances to the ground truth (sample 0), concatenated family by family along the columns.
first_rows = [get_first_row(dist) for dist in all_distance]
distance_H0 = np.hstack([h0 for h0, _ in first_rows])
distance_H1 = np.hstack([h1 for _, h1 in first_rows])

plot_distance(distance_H0, "H0", folder)
plot_distance(distance_H1, "H1", folder)

# Weights optimization

In [None]:
from scipy.optimize import minimize, Bounds

def compute_best_algo(distance_H0, distance_H1=None, type="stack"):
    """Optimise per-filtration weights, then rank the images by weighted distance.

    Parameters
    ----------
    distance_H0, distance_H1 : arrays of shape (n_images, n_filtrations)
        As built with ``get_first_row``.
    type : str
        How H0 and H1 are combined:

        - "stack_3d": stacked on a 3rd axis; one weight per filtration, shared by H0 and H1.
        - "stack": concatenated; first half of the weights for H0, second half for H1.
        - "single" / "H0" / "H1": only ``distance_H0`` is used (pass H1 there for "H1").

    Returns
    -------
    (all_distance, weights, index, total_error, partial_error)
        ``total_error`` holds one weighted-distance value per image (scaled by 1e-7), and
        ``index`` sorts it ascending. For "stack_3d" and "stack", ``partial_error`` is
        (n_images, 2) with H0 in column 0 and H1 in column 1; otherwise it is 0.

    Raises
    ------
    NameError
        For an unknown ``type``. The old ``type == "single" or "H0" or "H1"`` test was always
        true, so this branch was unreachable.
    """
    if type == "stack_3d":
        all_distance = np.swapaxes(np.dstack((distance_H0, distance_H1)), 0, 1)
    elif type == "stack":
        all_distance = np.swapaxes(np.hstack((distance_H0, distance_H1)), 0, 1)
    elif type in ("single", "H0", "H1"):
        all_distance = np.swapaxes(distance_H0, 0, 1)
    else:
        raise NameError('Wrong type')

    print(str(all_distance.shape))
    weights_dataset = optimization(all_distance)

    print("\n")
    distance_label = label_name()
    if type == "stack":
        # Twice as many weights: the first len(distance_label) are for H0, the rest for H1.
        index_middle = len(distance_label)

        print("weights for H0\n")
        for weight, label in zip(weights_dataset[:index_middle], distance_label):
            print("%s\t weight  %.3f" % (label, weight))
        print("\nweights for H1\n")
        for weight, label in zip(weights_dataset[index_middle:], distance_label):
            print("%s\t weight  %.3f" % (label, weight))

    else:
        # The old `type == "H0" or "H1"` test was always true, so the second message was unreachable.
        if type in ("H0", "H1"):
            print("weights for %s\n" % type)
        else:
            print("weights for H0 and H1\n")
        for weight, label in zip(weights_dataset, distance_label):
            print("%s\t weight  %.3f" % (label, weight))

    print("\n")

    # Weighted distance per image (scaled by 1e-7), plus per-homology partial sums.
    if type == "stack_3d":
        weights_dataset = weights_dataset.reshape(-1, 1, 1)
        error = weights_dataset * all_distance
        total_error = np.sum(np.sum(error, axis=0), axis=-1) * 1e-7
        partial_error = np.sum(error, axis=0) * 1e-7
    elif type == "stack":
        weights_dataset = weights_dataset.reshape(-1, 1)
        error = weights_dataset * all_distance
        total_error = np.sum(error, axis=0) * 1e-7
        # np.int was removed in NumPy 1.24; integer division gives the same index.
        mid_index = len(weights_dataset) // 2
        # Column 0 = H0 (first half of the rows), column 1 = H1: the same layout as
        # "stack_3d" and what the plotting cells expect. The halves used to be swapped.
        partial_error = np.vstack((np.sum(error[:mid_index], axis=0) * 1e-7,
                                   np.sum(error[mid_index:], axis=0) * 1e-7))
        partial_error = partial_error.T

    else:
        weights_dataset = weights_dataset.reshape(-1, 1)
        error = weights_dataset * all_distance
        total_error = np.sum(error, axis=0) * 1e-7
        partial_error = 0

    index = np.argsort(total_error)
    print("Sorted from best to worst\n")
    for i in index:
        print("Algorithm : %d\t error : %.3f" % (i, total_error[i]))

    return all_distance, weights_dataset, index, total_error, partial_error


def optimization(all_distance):
    """Find non-negative per-row weights with SciPy's trust-constr solver.

    Minimises ``f`` (``f_stack_3d`` for 3D input), the negated weighted sum of all distances.
    The search starts from all-ones weights, with weights in [0, inf) and the equality
    constraint prod(weights) == 1.

    NOTE(review): maximising a weighted sum under a product-of-weights constraint appears
    unbounded (one weight can grow while the others shrink). The markdown also describes a
    sum-to-one constraint (cf. the unused ``constraint`` helper). Confirm the intended objective.

    Parameters
    ----------
    all_distance : ndarray
        Shape (n_rows, n_images), or (n_rows, n_images, n_channels) when H0/H1 share weights.

    Returns
    -------
    ndarray
        One optimised weight per row of ``all_distance``.
    """
    x0 = np.ones(all_distance.shape[0])
    cons = {'type': 'eq', 'fun': lambda weights: 1.0 - np.prod(weights)}
    # np.inf: the np.Inf alias was removed in NumPy 2.0.
    bounds = Bounds(0, np.inf)
    # With 3D input (H0/H1 on the last axis) each weight applies to both homology channels.
    objective = f_stack_3d if all_distance.ndim == 3 else f
    result = minimize(objective, x0, args=(all_distance,), method='trust-constr',
                      bounds=bounds, constraints=cons)
    return result.x


def f(params, all_distance):
    """Objective for ``optimization`` on 2D distances: negated weighted sum.

    ``params`` holds one weight per row of ``all_distance`` (n_rows, n_cols).
    """
    row_totals = np.sum(np.reshape(params, (-1, 1)) * all_distance, axis=-1)
    return -1 * np.sum(row_totals)


def f_stack_3d(params, all_distance):
    """Objective for 3D distances (n_rows, n_cols, n_channels): negated weighted sum.

    One weight per row, shared across all columns and channels.
    """
    weighted = np.reshape(params, (-1, 1, 1)) * all_distance
    per_row = np.sum(np.sum(weighted, axis=1), axis=-1)
    return -1 * np.sum(per_row)


def constraint(params):
    """Sum-to-one equality constraint: zero when the weights sum to 1.

    Not referenced by ``optimization``, which constrains the product of the weights instead.
    """
    return np.sum(params) - 1

### best algorithm of dataset ***BF_CL***

- *H0* and *H1* are stacked in the 3rd dimension
- weights for each distance are equal for *H0* and *H1*

In [None]:

# H0 and H1 stacked on a 3rd axis: one weight per filtration, shared by both.
all_distance, weight, index, total_error, partial_error = compute_best_algo(distance_H0, distance_H1, type="stack_3d")

In [None]:
# Per-filtration heat maps, annotated with each image's partial error (column 0 = H0,
# column 1 = H1 for "stack_3d") and the shared weights.
plot_distance(distance_H0, X_type="H0", error=partial_error[:, 0], weight=weight)
plot_distance(distance_H1, X_type="H1", error=partial_error[:, 1], weight=weight)

In [None]:
# Combined (H0 + H1) weighted error per image.
# NOTE(review): requires plot_combined_distance to accept an `X_type` argument and a 1D array.
plot_combined_distance(total_error, X_type="H0 and H1 combined")

In [None]:
# save sorted algortihm and weights
# np.savez(path_save + "weights.npz", weights = weights_dataset)
# np.savez(path_save + "index_algorithm_sorted.npz", index_algorithm_sorted = index)

In [None]:
### best algorithm of dataset ***BF_CL***
# - *H0* and *H1* are stacked in the 2nd dimension
# - weights for each distance are different for *H0* and *H1* but the sum of all the weights is 1
#   NOTE(review): `optimization` constrains the *product* of the weights to 1, not the sum; confirm.
# (These notes used to be raw markdown inside this code cell, which made it a SyntaxError.)
all_distance, weight, index, total_error, partial_error = compute_best_algo(distance_H0, distance_H1, type="stack")
# The first half of the weights belongs to H0, the second half to H1.
# (np.int was removed in NumPy 1.24; integer division gives the same index.)
mid_index = len(weight) // 2
plot_distance(distance_H0, X_type="H0", error=partial_error[:, 0], weight=weight[:mid_index])
plot_distance(distance_H1, X_type="H1", error=partial_error[:, 1], weight=weight[mid_index:])
plot_combined_distance(total_error, X_type="H0 and H1 combined")

### best algorithm of dataset ***BF_CL***
# - *H0* and *H1* are computed separately
# - sum of weights for *H0* is equal to 1
# - sum of weights for *H1* is equal to 1
#   NOTE(review): as above, `optimization` actually constrains the product of the weights.
all_distance_H0, weight_H0, index_H0, total_error_H0, _ = compute_best_algo(distance_H0, type="H0")
all_distance_H1, weight_H1, index_H1, total_error_H1, _ = compute_best_algo(distance_H1, type="H1")
plot_distance(distance_H0, X_type="H0", error=total_error_H0, weight=weight_H0)
plot_distance(distance_H1, X_type="H1", error=total_error_H1, weight=weight_H1)
plot_combined_distance(total_error_H0 + total_error_H1, X_type="H0 and H1 combined")