In [None]:
from acousticnn.plate.configs.main_dir import main_dir
import os
# Make main_dir the working directory so relative paths used later in this
# notebook (e.g. save_dir) resolve against it.
os.chdir(main_dir)

# Reload edited modules automatically before each cell executes.
%reload_ext autoreload
%autoreload 2
# NOTE(review): base_path is not referenced in the cells visible here — confirm it is still needed.
base_path = os.path.join(main_dir, "experiments/arch")

In [None]:
import numpy as np
from acousticnn.plate.dataset import get_dataloader, HDF5Dataset
from acousticnn.plate.model import model_factory
from acousticnn.plate.train_fsm import extract_mean_std, get_mean_from_field_solution
from acousticnn.utils.builder import build_opti_sche
from acousticnn.utils.logger import init_train_logger, print_log
from acousticnn.utils.argparser import get_args, get_config
from acousticnn.plate.train_fsm import evaluate, _generate_preds
from acousticnn.plate.train import evaluate as evaluate_implicit, _generate_preds as generate_preds_implicit
from acousticnn.plate.train import _evaluate
from torchinfo import summary
import wandb, time, torch
from torch.utils.data import ConcatDataset

# Print numpy floats with three decimals.
np.set_printoptions(formatter={'float': lambda value: "{0:0.3f}".format(value)})

# Load the G5000 and V5000 plate datasets (same model config, normalized) and
# concatenate them. After the loop, `difficulty`, `args` and `config` refer to the
# last entry ("V5000"), just like the previous sequential version.
_datasets_by_difficulty = {}
for difficulty in ("G5000", "V5000"):
    args = get_args(["--config", f"{difficulty}.yaml", "--model_cfg", "fno_conditional.yaml"])
    config = get_config(args.config)
    _datasets_by_difficulty[difficulty] = HDF5Dataset(args, config, config.data_paths, normalization=True)
G5000_dataset = _datasets_by_difficulty["G5000"]
V5000_dataset = _datasets_by_difficulty["V5000"]
del _datasets_by_difficulty
dataset = ConcatDataset([G5000_dataset, V5000_dataset])

In [None]:
from matplotlib import rcParams
import matplotlib.pyplot as plt
# Shared figure style for every plot in this notebook.
rcParams['axes.labelsize'] = 12
rcParams['axes.titlesize'] = 12
# figsize is in inches: 90% x 75% of a 10 cm x 8 cm panel (2.54 cm per inch).
rcParams["figure.figsize"] = (10 / 2.54 * 0.90, 8 / 2.54 * 0.75)
# Colour cycle: two colours sampled from the two ends of the Set2 colormap.
# (rcParams is the same object as plt.rcParams; use one name consistently.)
rcParams["axes.prop_cycle"] = plt.cycler("color", plt.cm.Set2(np.linspace(0, 1, 2)))
# Output directory for all saved figures, relative to the working directory.
# Create it up front so the savefig calls below do not fail on a fresh checkout.
save_dir = "plots/dataset_figures/"
os.makedirs(save_dir, exist_ok=True)

In [None]:
from scipy.signal import find_peaks, find_peaks_cwt

def detect_peaks(dataset, prominence=.5, wlen=100):
    """Detect peaks in every sample's response and measure its bead-pixel ratio.

    For each sample, ``scipy.signal.find_peaks`` is run on ``z_vel_mean_sq`` and the
    fraction of ``bead_patterns`` entries that are > 0 is computed over the last two
    (spatial) axes.

    Args:
        dataset: indexable sequence (supports ``len`` and ``[i]``) whose items are
            dicts with keys ``"bead_patterns"`` (torch tensor; presumably
            (C, H, W) or (H, W) — TODO confirm) and ``"z_vel_mean_sq"``
            (1-D response accepted by ``find_peaks``).
        prominence: minimum peak prominence, passed to ``find_peaks``.
        wlen: window length for the prominence computation, passed to ``find_peaks``.

    Returns:
        Tuple ``(len_peaks, pixel_sums, peaks, prominences, peaks_list)``:
            len_peaks: int array, number of peaks found per sample.
            pixel_sums: float array, count of bead entries > 0 divided by H*W.
            peaks: peak indices of all samples concatenated (empty if none).
            prominences: prominences matching ``peaks``, same order.
            peaks_list: list with one array of peak indices per sample.
    """
    peaks_list = []
    pixel_sums = []
    prominence_list = []
    for i in range(len(dataset)):
        # Fetch the sample once; indexing twice would load it from storage twice.
        sample = dataset[i]
        img, response = sample["bead_patterns"], sample["z_vel_mean_sq"]
        # Normalise by the spatial size (last two axes), so both (C, H, W) and
        # (H, W) patterns work; for (C, H, W) this matches shape[1] * shape[2].
        n_pixels = img.shape[-2] * img.shape[-1]
        pixel_sums.append(float(torch.sum(img > 0)) / n_pixels)
        actual_peaks, properties = find_peaks(response, prominence=prominence, wlen=wlen)
        prominence_list.append(properties["prominences"])
        peaks_list.append(actual_peaks)

    len_peaks = np.array([len(p) for p in peaks_list], dtype=int)
    pixel_sums = np.array(pixel_sums, dtype=float)
    # np.hstack/np.concatenate reject an empty list, so handle an empty dataset explicitly.
    if peaks_list:
        peaks = np.concatenate(peaks_list)
        prominences = np.concatenate(prominence_list)
    else:
        peaks = np.array([], dtype=np.intp)
        prominences = np.array([], dtype=float)
    return len_peaks, pixel_sums, peaks, prominences, peaks_list

In [None]:
def do_boxplot(len_peaks, pixel_sums, peaks):
    """Draw horizontal boxplots of the bead-pixel ratio grouped by peak count.

    One box is drawn per distinct value of ``len_peaks``, in ascending order from
    bottom to top. Each box summarises the ``pixel_sums`` entries of the samples
    with that peak count; outliers are hidden.

    Args:
        len_peaks: array with the number of peaks per sample.
        pixel_sums: array with the bead-pixel ratio per sample (same length).
        peaks: accepted for call-site compatibility; not used.

    Returns:
        The newly created matplotlib figure.
    """
    peak_counts = sorted(set(len_peaks))
    groups = [pixel_sums[len_peaks == count] for count in peak_counts]

    fig, ax = plt.subplots()
    parts = ax.boxplot(groups, showfliers=False, patch_artist=True, vert=False)

    # Light-blue filled boxes with black outlines, dashed grey whiskers, red medians.
    for box in parts['boxes']:
        box.set_facecolor('lightblue')
        box.set_edgecolor('black')
    for whisker in parts['whiskers']:
        whisker.set(color='gray', linewidth=1.5, linestyle='--')
    for median_line in parts['medians']:
        median_line.set(color='red', linewidth=2)

    ax.set_yticklabels(peak_counts, fontsize=10)
    ax.tick_params(axis='both', which='major', labelsize=10)
    ax.set_xlabel('ratio of pixels with beads')
    ax.set_ylabel('peak count per image')
    fig.tight_layout()
    return fig

# Peak statistics over the combined (G5000 + V5000) dataset.
len_peaks, pixel_sums, peaks, prominences, peaks_list = detect_peaks(dataset)

# Figure 1: bead-pixel ratio grouped by the number of peaks per sample.
# do_boxplot returns its own figure (and already applies tight_layout).
fig = do_boxplot(len_peaks, pixel_sums, peaks)
fig.savefig(save_dir + "peak_count_beadratio.svg", format='svg', dpi=600)
plt.show()

# Figure 2: histogram of peak positions (response index) over all samples.
# Use an explicit new figure. Relying on plt.show() having closed the boxplot only
# holds for the inline backend; elsewhere plt.hist would draw onto the boxplot axes.
# (A single dataset needs no stacked=True.)
fig_hist, ax_hist = plt.subplots()
ret = ax_hist.hist(peaks, bins=np.linspace(-0.5, 300.5, 30), density=False, color='steelblue', edgecolor='black')
ax_hist.set_xlabel('frequency')
ax_hist.set_ylabel("peak count")

fig_hist.tight_layout()
fig_hist.savefig(save_dir + "peak_count_frequency.svg", format='svg', dpi=600, transparent=True)

In [None]:
# Dataset overview: mean bead pattern (top) and mean response (bottom) of the combined dataset.
fig, ax = plt.subplots(2, 1, figsize=(10 / 2.54 * 0.9, 8 / 2.54 * 1.5))
# NOTE(review): this stacks every bead pattern of the combined dataset into one tensor,
# so memory grows linearly with dataset size.
all_patterns = torch.stack([dataset[i]["bead_patterns"] for i in range(len(dataset))])
# Average over samples, then show the first channel.
ax[0].imshow(torch.mean(all_patterns, axis=0)[0], cmap=plt.cm.gray)
ax[0].axis('off')  # image panel: no ticks or axis labels

# Response statistics stored on the two sub-datasets of the ConcatDataset.
out_mean1, out_mean2 = dataset.datasets[0].out_mean, dataset.datasets[1].out_mean
out_std1, out_std2 = dataset.datasets[0].out_std, dataset.datasets[1].out_std
# Size-weighted (pooled) mean of the two per-dataset means. It equals the plain
# average (out_mean1 + out_mean2) / 2 when both datasets have the same length.
# NOTE(review): assumes each out_mean was computed over all samples of its dataset — confirm.
n1, n2 = len(dataset.datasets[0]), len(dataset.datasets[1])
pooled_mean = (n1 * out_mean1 + n2 * out_mean2) / (n1 + n2)
ax[1].plot(pooled_mean, lw=3)
ax[1].set_ylim(-25, 75)
ax[1].set_xticks([0, 100, 200, 300])
ax[1].set_yticks([-25, 0, 25, 50, 75])
ax[1].grid(which="major")
# Label the response panel explicitly instead of relying on pyplot's current axes.
ax[1].set_xlabel('frequency')
ax[1].set_ylabel('amplitude')

fig.tight_layout()
fig.savefig(save_dir + "dataset_mean.svg", format='svg', dpi=600, transparent=True)


In [None]:
# Peak detection on V5000 only. The following cell indexes peaks_list with
# V5000 sample indices, so these globals must come from V5000_dataset.
len_peaks, pixel_sums, peaks, prominences, peaks_list = detect_peaks(V5000_dataset)
# Sample index and value of the largest, then the smallest, peak count.
print(len_peaks.argmax(), len_peaks.max())
print(len_peaks.argmin(), len_peaks.min())

In [None]:
# Two V5000 samples side by side: bead patterns on top, frequency responses below.
# NOTE(review): indices presumably picked from the argmax/argmin printout of the
# previous cell — confirm. peaks_list must also come from that cell (V5000 indices).
i = 263
i2 = 132

# Fetch each sample once (indexing twice would load it twice).
sample1, sample2 = V5000_dataset[i], V5000_dataset[i2]
bead_pattern1, response1 = sample1["bead_patterns"][0], sample1["z_vel_mean_sq"]
bead_pattern2, response2 = sample2["bead_patterns"][0], sample2["z_vel_mean_sq"]

# Explicit 2x2 grid: two image panels on top, one response panel spanning the bottom row.
# Previously plt.subplots(2, 1) was mixed with plt.subplot(2, 2, k); in recent matplotlib
# that leaves an empty, overlapping axes behind the two images.
fig = plt.figure(figsize=(10 / 2.54 * 1.5, 8 / 2.54 * 1.5))
grid = fig.add_gridspec(2, 2)
ax_left = fig.add_subplot(grid[0, 0])
ax_right = fig.add_subplot(grid[0, 1])
ax_resp = fig.add_subplot(grid[1, :])

for ax_img, pattern in ((ax_left, bead_pattern1), (ax_right, bead_pattern2)):
    ax_img.imshow(pattern, cmap=plt.cm.gray)
    ax_img.axis('off')

ax_resp.plot(response1, label="left plate", lw=3)
ax_resp.plot(response2, label="right plate", lw=3)
ax_resp.set_xticks([0, 100, 200, 300])
ax_resp.grid(which="major")
# Mark detected peaks at their own indices. The markers for response1 were
# previously shifted one sample to the left (peaks_list[i] - 1).
ax_resp.plot(peaks_list[i], response1[peaks_list[i]], 'x', markersize=6, color="r")
ax_resp.plot(peaks_list[i2], response2[peaks_list[i2]], 'x', markersize=6, color="r")
ax_resp.set_xlabel('frequency')
ax_resp.set_ylabel('amplitude')
ax_resp.legend()

# tight_layout recomputes spacing, so no separate subplots_adjust is needed.
fig.tight_layout()
fig.savefig(save_dir + "two_single_examples.svg", format='svg', dpi=600, transparent=True)

## Compare peak counts of the V-5000 and G-5000 datasets

In [None]:
# Peak statistics for each dataset separately (detect_peaks is re-run on both).
len_peaks1, pixel_sums1, peaks1, prominence1, peaks_list1 = detect_peaks(V5000_dataset)
len_peaks2, pixel_sums2, peaks2, prominence2, peaks_list2 = detect_peaks(G5000_dataset)

# Compare the number of peaks per sample. Boxes are drawn bottom-to-top in list
# order, so the labels below follow the same order.
data = [len_peaks1, len_peaks2]
yticklabels = ['V-5000', "G-5000"]

# Horizontal boxplots without outliers.
fig, ax = plt.subplots()
boxplot = ax.boxplot(data, showfliers=False, patch_artist=True, vert=False)
box_color = 'lightblue'
whisker_color = 'gray'
for patch in boxplot['boxes']:
    patch.set_facecolor(box_color)
    patch.set_edgecolor('black')
for whisker in boxplot['whiskers']:
    whisker.set(color=whisker_color, linewidth=1.5, linestyle='--')
for median in boxplot['medians']:
    median.set(color='red', linewidth=2)

# Customize labels and ticks.
ax.set_xlabel('number of peaks')
ax.set_yticks(range(1, len(yticklabels) + 1))
ax.set_yticklabels(yticklabels)
ax.set_xticks([0, 2, 4, 6, 8, 10, 12])
ax.tick_params(axis='y', labelsize=12)
fig.tight_layout()  # Automatically adjusts margins and spacing
fig.savefig(save_dir + "dataset_comparison.svg", format='svg', dpi=600)
plt.show()
