In [None]:
import math
import matplotlib.pyplot as plt
import numpy as np
import cv2

In [None]:
def mark_profile_line(img: np.ndarray, direction: str, index: int):
    '''
    Extract one column/row profile from ``img`` and build a marked copy of it.

    Parameters
    ----------
    img : np.ndarray
        2-D image (e.g. a predicted phase map or a grayscale PNG).
        It is NOT modified.
    direction : str
        'col' / 'vertical' to take column ``index``;
        'row' / 'horizontal' to take row ``index``.
    index : int
        Column or row index to extract.

    Returns
    -------
    (profile, marked_img)
        ``profile`` is a copy of the selected column/row. ``marked_img`` is a
        copy of ``img`` with that column/row overwritten by a bright marker.

    Raises
    ------
    ValueError
        If ``direction`` is not one of the accepted names.
    '''
    if direction in ('col', 'vertical'):
        selector = (slice(None), index)
    elif direction in ('row', 'horizontal'):
        selector = (index, slice(None))
    else:
        raise ValueError(
            "direction must be 'col'/'vertical' or 'row'/'horizontal', got %r" % (direction,)
        )

    # Marker: 2*pi (presumably the top of the phase range) or the image's own
    # maximum, whichever is larger. On uint8 images 2*pi alone truncates to 6,
    # which renders as an almost black, barely visible line.
    marker_value = max(2 * np.pi, float(np.max(img)))

    profile = img[selector].copy()
    marked_img = img.copy()  # never draw on the caller's array
    marked_img[selector] = marker_value
    return profile, marked_img


def profile_plot(pred_img: np.ndarray, direction: str, index: int, plot_save_path: str) -> None:
    '''
    Prediction only (no GT, usually real data) profile plot.

    Shows ``pred_img`` with the profiled column/row highlighted, next to a line
    plot of that column/row, and saves the figure to ``plot_save_path``.
    The caller's ``pred_img`` is left unmodified.

    Parameters
    ----------
    pred_img : np.ndarray
        2-D predicted image (phase map or grayscale image).
    direction : str
        'col' / 'vertical' or 'row' / 'horizontal'.
    index : int
        Column (vertical) or row (horizontal) to profile. Earlier notes used
        column 250 / row 350 for real data and 25 for test data.
    plot_save_path : str
        Destination file passed to ``fig.savefig``.

    Raises
    ------
    ValueError
        If ``direction`` is not recognised (raised before any figure is created).
    '''
    # Validate and extract first so a bad direction fails before plotting.
    pred_profile, marked_img = mark_profile_line(pred_img, direction, index)
    is_vertical = direction in ('col', 'vertical')

    # plot
    fig, axs = plt.subplots(1, 2, figsize=(12, 6))
    p0 = axs[0].imshow(marked_img, cmap='gray')
    fig.colorbar(p0, ax=axs[0])

    if is_vertical:
        axs[0].set_title('Vertical Phase Profile (Column=' + str(index) + ')')
    else:
        axs[0].set_title('Horizontal Phase Profile (Row=' + str(index) + ')')

    axs[1].plot(pred_profile, label='Predicted', alpha=0.7)
    axs[1].set_title('Predicted Profile')
    axs[1].set_xlabel('Pixel')
    axs[1].set_ylabel('Intensity')
    axs[1].legend()

    fig.tight_layout()  # to avoid axes overlapping
    fig.savefig(plot_save_path)  # save
    # The figure is deliberately not closed here so it can still be shown
    # inline in the notebook.

In [None]:
# Sample image index; selects which img_<idx>.png the following cells load.
img_idx = "2"

In [None]:
pred = './dl_data_set_12_20/imgs/00000_0.1/img_' + img_idx + '.png'
pred_img = cv2.imread(pred, cv2.IMREAD_GRAYSCALE)
# cv2.imread does not raise on a missing/unreadable file; it returns None.
if pred_img is None:
    raise FileNotFoundError('Could not read image: ' + pred)
print(pred_img.shape)

profile_plot(pred_img=pred_img, direction="vertical", index=250, plot_save_path="./profile_plot_0.1.png")

In [None]:
pred = './dl_data_set_12_20/imgs/00000_0.5/img_' + img_idx + '.png'
pred_img = cv2.imread(pred, cv2.IMREAD_GRAYSCALE)
# cv2.imread does not raise on a missing/unreadable file; it returns None.
if pred_img is None:
    raise FileNotFoundError('Could not read image: ' + pred)
print(pred_img.shape)

profile_plot(pred_img=pred_img, direction="vertical", index=250, plot_save_path="./profile_plot_0.5.png")

In [None]:
pred = './dl_data_set_12_20/imgs/00000_1.0/img_' + img_idx + '.png'
pred_img = cv2.imread(pred, cv2.IMREAD_GRAYSCALE)
# cv2.imread does not raise on a missing/unreadable file; it returns None.
if pred_img is None:
    raise FileNotFoundError('Could not read image: ' + pred)
print(pred_img.shape)

profile_plot(pred_img=pred_img, direction="vertical", index=250, plot_save_path="./profile_plot_1.0.png")

In [None]:
# Sample image index; selects which img_<idx>.png the following cells load.
img_idx = "7"

In [None]:
pred = './dl_data_set_12_20/imgs/00000_0.1/img_' + img_idx + '.png'
pred_img = cv2.imread(pred, cv2.IMREAD_GRAYSCALE)
# cv2.imread does not raise on a missing/unreadable file; it returns None.
if pred_img is None:
    raise FileNotFoundError('Could not read image: ' + pred)
print(pred_img.shape)

profile_plot(pred_img=pred_img, direction="horizontal", index=180, plot_save_path="./profile_plot_0.1_h.png")

In [None]:
pred = './dl_data_set_12_20/imgs/00000_0.5/img_' + img_idx + '.png'
pred_img = cv2.imread(pred, cv2.IMREAD_GRAYSCALE)
# cv2.imread does not raise on a missing/unreadable file; it returns None.
if pred_img is None:
    raise FileNotFoundError('Could not read image: ' + pred)
print(pred_img.shape)

profile_plot(pred_img=pred_img, direction="horizontal", index=180, plot_save_path="./profile_plot_0.5_h.png")

In [None]:
pred = './dl_data_set_12_20/imgs/00000_1.0/img_' + img_idx + '.png'
pred_img = cv2.imread(pred, cv2.IMREAD_GRAYSCALE)
# cv2.imread does not raise on a missing/unreadable file; it returns None.
if pred_img is None:
    raise FileNotFoundError('Could not read image: ' + pred)
print(pred_img.shape)

profile_plot(pred_img=pred_img, direction="horizontal", index=180, plot_save_path="./profile_plot_1.0_h.png")