# Import the packages

In [None]:
import numpy as np
from math import sqrt
import skimage.io as io
import lmfit
from lmfit.lineshapes import gaussian2d, lorentzian
from skimage import transform
from skimage.feature import blob_log, blob_doh, blob_dog
from scipy import ndimage
from skimage.measure import ransac
import pandas as pd
from sklearn.linear_model import RANSACRegressor
import pandas as pd
from scipy import spatial
import scipy.misc as sp
from skimage.filters import gaussian
from PIL import Image
from sklearn.metrics import mean_squared_error
from cmath import inf

import matplotlib.pyplot as plt
from matplotlib import axes
from matplotlib import figure

plt.rcParams["figure.dpi"] = 300

# All the functions needed

In [None]:
def residual_cal(src, dst):
    """Return the mean squared error between two matched point sets.

    Args:
        src (numpy array): Source points, shape (N, 2).
        dst (numpy array): Destination points, shape (N, 2), listed in the
            same order as src.

    Returns:
        mse: The mean squared error between the two arrays, as computed by
        sklearn.metrics.mean_squared_error.
    """
    return mean_squared_error(src, dst)


def blob_detection(img_path, min_sigma, max_sigma, threshold, method=0):
    """Detect beads (blobs) in an image and refine each one to its center of mass.

    Args:
        img_path (string): The absolute path of the input image. The image must
            be 2-D (single channel).
        min_sigma (int): The minimum sigma; the lower it is, the smaller the blobs
            that can be detected.
        max_sigma (int): The maximum sigma; the higher it is, the bigger the blobs
            that can be detected.
        threshold (float): The higher it is, the brighter a blob must be to be kept.
        method (int, optional): 0 for Difference of Gaussian (DoG); any other value
            selects Determinant of Hessian (DoH). They should be applied with
            different combinations of parameters. DoG is more suitable for fret
            movies, while DoH is more suitable for sequencing images. Defaults to 0.

    Returns:
        centers: A numpy array of shape (N, 2) holding the refined (x, y) =
        (column, row) coordinates of every blob that lies further than its own
        radius from the image border. N may be 0.
    """
    img = io.imread(img_path)
    detector = blob_dog if method == 0 else blob_doh
    blobs = detector(
        img, min_sigma=min_sigma, max_sigma=max_sigma, threshold=threshold
    )

    h, w = img.shape
    centers = []
    for y, x, r in blobs:
        # Skip blobs whose refinement window would cross the image border.
        if not (r < y < h - r and r < x < w - r):
            continue
        # The window really starts at these integer indices, so they (and not
        # the fractional y - r / x - r) are the offsets to add back.
        top = int(y - r)
        left = int(x - r)
        window = img[top : int(y + r + 1), left : int(x + r + 1)]
        # center_of_mass reports (row, column) relative to the window; convert
        # to image-wide (x, y) by pairing the column with x and the row with y.
        com_row, com_col = ndimage.center_of_mass(window)
        centers.append([left + com_col, top + com_row])
    # reshape keeps the (N, 2) layout even when nothing was detected.
    return np.array(centers, dtype=float).reshape(-1, 2)


""" 
    Fixed-order polynomial transformations, wrapped for use with RANSAC.
"""
class QuadPolyTrans(transform.PolynomialTransform):
    """PolynomialTransform whose estimate() always fits a second-order polynomial."""

    def estimate(self, *data):
        # Forward the point sets unchanged and pin the polynomial order to 2.
        return super().estimate(*data, order=2)


class CubicPolyTrans(transform.PolynomialTransform):
    """PolynomialTransform whose estimate() always fits a third-order polynomial."""

    def estimate(self, *data):
        # Forward the point sets unchanged and pin the polynomial order to 3.
        return super().estimate(*data, order=3)


def count_nearest_pts(src, dst, radius):
    """Find, for each point in dst, its nearest neighbor in src within a radius.

    Matches are kept only when they are unambiguous: if several dst points share
    the same nearest src point, all of them are marked as unmatched.

    Args:
        src (numpy array): (N, 2) shape array. The kd tree is built on this.
        dst (numpy array): (M, 2) shape array. For each point in this array, the
            nearest neighbor in the src array is searched.
        radius (int): The maximum searching radius.

    Returns:
        res, idx: res is the distance between each dst point and its neighbor;
        'inf' means no neighbor within the search radius, or a neighbor that is
        shared with another dst point. idx is the index of the neighbor in the
        src array (for points without any neighbor in range the kd tree reports
        len(src), so always filter on res before indexing src with idx).
    """
    tree = spatial.KDTree(src)
    res, idx = tree.query(dst, k=1, distance_upper_bound=radius)
    # Count how often each src index was chosen in one pass, instead of
    # re-scanning idx for every element.
    hits = np.bincount(idx)
    res[hits[idx] > 1] = np.inf
    return res, idx


def save_np_img(npimg, path):
    """Save a numpy array as a grayscale image.

    The array is wrapped in a PIL image, converted to mode "L" and written
    with a DPI of 300 x 300.

    Args:
        npimg (numpy array): The numpy array that needs to be saved as an image.
        path (string): The save path.
    """
    Image.fromarray(npimg).convert("L").save(path, dpi=(300.0, 300.0))


def ploting_res(canvas_size, res, pts_radius, save_path, gaussian_sigma=3):
    """Render a point set as blurred square dots and save it as an image.

    Args:
        canvas_size (tuple): A two-element tuple as (x_range, y_range). It is used
            directly as the shape of the generated image array, so x_range is the
            extent of axis 0 and y_range the extent of axis 1.
        res (numpy array): The point set needed to generate the image. For each
            point, element 0 is used as the index along axis 1 and element 1 as
            the index along axis 0.
        pts_radius (int): The radius of the points. Each point is drawn as a
            filled square of side 2 * pts_radius - 1; nothing is drawn when
            pts_radius is 0.
        save_path (string): The save path of the generated image.
        gaussian_sigma (int, optional): The sigma of the gaussian blurring.
            Defaults to 3.
    """
    x_range, y_range = canvas_size
    pk_img = np.zeros(canvas_size)
    if pts_radius > 0:
        reach = pts_radius - 1
        for ele in res:
            col = int(ele[0])
            row = int(ele[1])
            # Clamp both ends of the square to the canvas; points outside the
            # canvas therefore collapse onto the nearest border pixels.
            row_lo = min(max(0, row - reach), x_range - 1)
            row_hi = min(max(0, row + reach), x_range - 1)
            col_lo = min(max(0, col - reach), y_range - 1)
            col_hi = min(max(0, col + reach), y_range - 1)
            pk_img[row_lo : row_hi + 1, col_lo : col_hi + 1] = 1024 * 2
    # The canvas is always 2-D, so no channel argument is needed; the former
    # `multichannel` keyword is no longer accepted by recent scikit-image.
    blurred_img = gaussian(pk_img, sigma=gaussian_sigma)
    save_np_img(blurred_img, save_path)


def count_pairs(ref_coord, target_coord, radius):
    """Count how many target points have a reference point closer than radius.

    Args:
        ref_coord (numpy array): The source point set; the kd tree is built on it.
        target_coord (numpy array): The target point set; each of its points is
            matched to its nearest neighbor in ref_coord.
        radius (int): Searching radius. A pair is counted only when the distance
            is strictly smaller than this value.

    Returns:
       pairs: (int) The number of pairs.
    """
    tree = spatial.KDTree(ref_coord)
    distances, _ = tree.query(target_coord, k=1)
    return int(np.count_nonzero(distances < radius))


def read_coord(csv_path):
    """Load the "X" and "Y" columns of a csv file as an (N, 2) numpy array.

    Args:
        csv_path (string): CSV file path. The file must contain columns named
            "X" and "Y".

    Returns:
        pts: The numpy array containing all the coordinates in the given csv
        file, one (X, Y) row per csv row.
    """
    table = pd.read_csv(csv_path)
    xs = np.array(table["X"])
    ys = np.array(table["Y"])
    return np.column_stack((xs, ys))


def show_blob_detection_res(img_path, min_sigma, max_sigma, threshold, method=0):
    """Show the result of the 'blob_detection' function on top of the image.

    Takes the same arguments as 'blob_detection' and delegates the detection to
    it, so the red circles mark exactly the centers that function returns.

    Args:
        img_path (string): The absolute path of the input image.
        min_sigma (int): The minimum sigma passed to 'blob_detection'.
        max_sigma (int): The maximum sigma passed to 'blob_detection'.
        threshold (float): The detection threshold passed to 'blob_detection'.
        method (int, optional): 0 for Difference of Gaussian (DoG), anything else
            for Determinant of Hessian (DoH). Defaults to 0.
    """
    # Detect first so that a failing detection does not leave an empty figure.
    centers = blob_detection(img_path, min_sigma, max_sigma, threshold, method=method)

    _, ax = plt.subplots()
    ax.imshow(io.imread(img_path))
    # Each center is an (x, y) pair, which is also the order Circle expects.
    for cx, cy in centers:
        ax.add_patch(plt.Circle((cx, cy), 3, color="red", linewidth=1, fill=False))
    ax.set_axis_off()
    plt.show()


# Implement the Alignment
#### Using manually labelled beads as the first rough alignment

> $res\_tb$ holds the points clicked in ImageJ

In [None]:
# Manually clicked landmarks: the first half of the rows forms the source
# point set, the second half the matching destination point set.
res_tb = pd.read_csv(
    "/Users/qinhanhou/Desktop/DeindlLab/Git/MUSCLE/ExampleData/Test/Results.csv"
)
sample_num = len(res_tb) // 2
xm = np.array(res_tb["XM"])
ym = np.array(res_tb["YM"])
src = np.column_stack((xm[:sample_num], ym[:sample_num]))
dst = np.column_stack((xm[sample_num:], ym[sample_num:]))

# Rough affine mapping from the source landmarks onto the destination ones.
rough_tf = transform.estimate_transform("affine", src=src, dst=dst)

#### Getting centers of automatically labelled beads (using center of mass)

The parameters can be adjusted:

 1. $min\_sigma$ determines how small the detected blobs can be
 2. $max\_sigma$ determines how big the detected blobs can be
 3. $threshold$ determines how intense a blob must be to be detected.


> Uncomment $show\_blob\_detection\_res$ to show the result. It could also be used before getting the final result.

In [None]:
# Bead centers in the fret-channel image, detected with the default method
# (method=0, Difference of Gaussian).
movie_centers = blob_detection(
    "/Users/qinhanhou/Desktop/DeindlLab/Git/MUSCLE/ExampleData/Pos246/beads_246_fret_channel.tif",
    min_sigma=1,
    max_sigma=10,
    threshold=0.01,
)

# Bead centers in the min-projection image, detected with method=1
# (Determinant of Hessian) and a much larger sigma range.
seq_centers = blob_detection(
    "/Users/qinhanhou/Desktop/DeindlLab/Git/MUSCLE/ExampleData/Pos246/min_projection_246.tif",
    min_sigma=20,
    max_sigma=500,
    threshold=0.001,
    method=1,
)

# Optional visual check of the two detections above (same arguments).
# show_blob_detection_res(
#     "/Users/qinhanhou/Desktop/DeindlLab/Git/MUSCLE/ExampleData/Pos246/beads_246_fret_channel.tif",
#     min_sigma=1,
#     max_sigma=10,
#     threshold=0.01,
# )

# show_blob_detection_res(
#     "/Users/qinhanhou/Desktop/DeindlLab/Git/MUSCLE/ExampleData/Pos246/min_projection_246.tif",
#     min_sigma=20,
#     max_sigma=500,
#     threshold=0.001,
#     method=1,
# )

#### Doing the combination of two channels, to get all the peak locations from both channels.

In [None]:
# Detect the beads in both channels with identical settings.
green_channel_centers = blob_detection(
    "/Users/qinhanhou/Desktop/DeindlLab/0729Poly/0822_Combine2Channel/grenn_channel_beads.tif",
    min_sigma=1,
    max_sigma=8,
    threshold=0.0001,
)
red_channel_centers = blob_detection(
    "/Users/qinhanhou/Desktop/DeindlLab/0729Poly/0822_Combine2Channel/red_channel_beads.tif",
    min_sigma=1,
    max_sigma=8,
    threshold=0.0001,
)

# show_blob_detection_res('/Users/qinhanhou/Desktop/DeindlLab/0729Poly/0822_Combine2Channel/grenn_channel_beads.tif')
# show_blob_detection_res('/Users/qinhanhou/Desktop/DeindlLab/0729Poly/0822_Combine2Channel/red_channel_beads.tif')

# For every red bead, look up its nearest green bead within a radius of 2;
# beads without a usable partner come back with an infinite distance.
res, idx = count_nearest_pts(green_channel_centers, red_channel_centers, 2)
# print(idx)
# print(res)

matched = res != inf
usable_green_beads = green_channel_centers[idx[matched]]
usable_red_beads = red_channel_centers[matched]
print(usable_green_beads)
print(usable_red_beads)

# Affine transformation that maps the green beads onto the red beads.
tform = transform.AffineTransform()
tform.estimate(usable_green_beads, usable_red_beads)
tsformed_green_beads = tform(usable_green_beads)
print(tsformed_green_beads)

# Render both point sets on the same canvas size for visual comparison.
for beads, out_path in (
    (
        usable_red_beads,
        "/Users/qinhanhou/Desktop/DeindlLab/0729Poly/0822_Combine2Channel/red_beads.png",
    ),
    (
        tsformed_green_beads,
        "/Users/qinhanhou/Desktop/DeindlLab/0729Poly/0822_Combine2Channel/tform_green_beads.png",
    ),
):
    ploting_res((256, 512), res=beads, pts_radius=2, save_path=out_path)


#### Getting the local maximum from ImageJ, for both channels
And then apply the affine transformation found in the last cell.

In [None]:
green_channel_peaks = read_coord(
    "/Users/qinhanhou/Desktop/DeindlLab/0729Poly/0822_Combine2Channel/Green_channel_peaks.csv"
)
red_channel_peaks = read_coord(
    "/Users/qinhanhou/Desktop/DeindlLab/0729Poly/0822_Combine2Channel/Red_channel_peaks.csv"
)

# Map the green-channel peaks with tform, then pool them with the red-channel
# peaks (transformed green peaks first, red peaks after).
tformed_green_peaks = tform(green_channel_peaks)
print(tformed_green_peaks)
combiened_peaks = np.concatenate((tformed_green_peaks, red_channel_peaks))
print(combiened_peaks)


## Plotting simple linear transformed beads and beads from seq image
Then we will use them to do the RANSAC test and get the reliable pairs

In [None]:
# Canvas size used for all renderings of the sequencing-image coordinates.
x_range, y_range = 1566, 3240

# Movie bead centers after the rough affine alignment.
rough_movie_centers = rough_tf(movie_centers)
print(rough_movie_centers)
print(seq_centers)

ploting_res(
    (x_range, y_range),
    res=rough_movie_centers,
    pts_radius=3,
    save_path="/Users/qinhanhou/Desktop/DeindlLab/0729Poly/0822_Combine2Channel/movie_center.png",
)
ploting_res(
    (x_range, y_range),
    res=seq_centers,
    pts_radius=3,
    save_path="/Users/qinhanhou/Desktop/DeindlLab/0729Poly/0822_Combine2Channel/seq_center.png",
)


#### Using ransac to get more precise beads matching.

In [None]:
# Pair every sequencing bead with its nearest roughly-aligned movie bead
# (search radius 8); unmatched or ambiguous beads have an infinite distance.
res, idx = count_nearest_pts(rough_tf(movie_centers), seq_centers, 8)
print(res)
print(idx)

# Keep only the matched pairs. Note that this overwrites movie_centers and
# seq_centers, so re-running this cell alone filters the already filtered sets.
matched = res != inf
movie_centers = movie_centers[idx[matched]]
seq_centers = seq_centers[matched]

print(movie_centers)
print(seq_centers)

# A second-order polynomial transform has 6 coefficients per coordinate, so at
# least 6 point pairs are needed for a determined fit; with fewer, the fits on
# random minimal samples are under-determined.
model, inliers = ransac(
    (movie_centers, seq_centers),
    QuadPolyTrans,
    min_samples=6,
    residual_threshold=5,
    initial_inliers=np.ones(len(movie_centers), dtype=bool),
    stop_probability=0.8,
)
fitted_movie_centers = model(movie_centers)
print(inliers)
print(fitted_movie_centers)
print(seq_centers)
# Residual over all matched pairs (inliers and outliers alike).
print(residual_cal(fitted_movie_centers, seq_centers))

ploting_res(
    (x_range, y_range),
    fitted_movie_centers,
    3,
    "/Users/qinhanhou/Desktop/DeindlLab/0729Poly/0822_Combine2Channel/movie_center_f.png",
)
ploting_res(
    (x_range, y_range),
    seq_centers,
    3,
    "/Users/qinhanhou/Desktop/DeindlLab/0729Poly/0822_Combine2Channel/seq_center_f.png",
)


#### For the final result.

In [None]:
peak_locations = combiened_peaks

# Map all peaks with the fitted polynomial model.
res = model(np.array(peak_locations))

save_path = "/Users/qinhanhou/Desktop/DeindlLab/0729Poly/0822_Combine2Channel/"
canvas_size = (1566, 3240)
ploting_res(canvas_size, res, 3, f"{save_path}transformed_res.png")

# Match the library coordinates against the transformed peaks (radius 20) and
# keep only the pairs that found a partner.
fastq = read_coord("/Users/qinhanhou/Desktop/DeindlLab/0729Poly/0817/coord_lib.csv")
res_1, idx_1 = count_nearest_pts(res, fastq, 20)
has_partner = res_1 != inf
polished_movie = res[idx_1[has_partner]]
polished_fastq = fastq[has_partner]
print(len(polished_fastq))

ploting_res(
    (x_range, y_range),
    polished_fastq,
    10,
    save_path=f"{save_path}fastq.png",
    gaussian_sigma=10,
)
ploting_res((x_range, y_range), polished_movie, 3, save_path=f"{save_path}fret.png")


#### This part is for random checking.
Random datasets are generated and stored in $rd\_set$ (the loop currently creates one dataset, and only the first one is used).

In [None]:
# Defined first so this cell does not depend on save_path left over from an
# earlier cell.
save_path = "/Users/qinhanhou/Desktop/DeindlLab/0729Poly/0822_Combine2Channel/"
canvas_size = (1566, 3240)

# Random point sets: first column scaled to [0, 512), second column to [0, 256).
rd_set = []
for _ in range(0, 1):
    rd = np.random.random((1300, 2))
    rd[:, 0] = rd[:, 0] * 512
    rd[:, 1] = rd[:, 1] * 256
    rd_set.append(rd)

rd_arr = []
arti_arr = []
phix_arr = []

peak_locations = combiened_peaks

# Map the real peaks and the random points with the same fitted model.
res = model(np.array(peak_locations))
rd_res = model(np.array(rd_set[0]))
ploting_res(canvas_size, rd_res, 3, save_path + "rd1.png")

# ploting_res(canvas_size, res, 3, save_path + "transformed_res.png")
fastq = read_coord("/Users/qinhanhou/Desktop/DeindlLab/0729Poly/0817/coord_lib.csv")
phix = read_coord("/Users/qinhanhou/Desktop/DeindlLab/0729Poly/0817/coord_phix.csv")

# For every search radius 0..99, count how many reference coordinates find a
# unique partner; only the counts are needed inside the loops.
for i in range(0, 100):
    res_rd, idx_rd = count_nearest_pts(rd_res, fastq, i)
    rd_arr.append(int(np.count_nonzero(res_rd != inf)))
# Matched pairs for the largest radius, kept for the optional plots below.
polished_fretPts_k = rd_res[idx_rd[res_rd != inf]]
polished_rd_k = fastq[res_rd != inf]

for i in range(0, 100):
    res_final, idx_final = count_nearest_pts(res, fastq, i)
    arti_arr.append(int(np.count_nonzero(res_final != inf)))
polished_fretPts_final = res[idx_final[res_final != inf]]
polished_fastq_final = fastq[res_final != inf]

for i in range(0, 100):
    res_phix, idx_phix = count_nearest_pts(res, phix, i)
    phix_arr.append(int(np.count_nonzero(res_phix != inf)))
polished_fretPts_phix = res[idx_phix[res_phix != inf]]
polished_fastq_phix = phix[res_phix != inf]

# Matched counts versus search radius for the three cases, plus the excess of
# the library matches over the random baseline.
plt.figure()
plt.plot(rd_arr, label='random')
plt.plot(arti_arr, label='lib')
plt.plot(phix_arr, label='phix')
plt.plot(np.array(arti_arr) - np.array(rd_arr), label='Difference')
plt.legend()
plt.show()
# ploting_res((x_range, y_range), polished_rd_k, 3, save_path + "polished_rd.png")
# ploting_res(
#     (x_range, y_range), polished_fretPts_k, 3, save_path + "polished_fretPts_rd.png"
# )
