# Evaluation Notebook

This notebook reproduces the figures from the paper using already trained models. Before running it, adjust the checkpoint paths so that they point to your trained models.

In [None]:
# Imports and notebook-wide settings shared by every cell below.
from acousticnn.plate.configs.main_dir import main_dir
import wandb, time, os, torch
# Run from the repository root so that relative paths used later in the notebook
# (e.g. the plot output directory) resolve against main_dir.
os.chdir(main_dir)
import numpy as np
# Print float arrays with three decimals when metrics are displayed.
np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})
from acousticnn.plate.dataset import get_dataloader
from acousticnn.plate.train import evaluate, _evaluate, _generate_preds
from acousticnn.plate.train_fsm import evaluate as evaluate_fsm
from acousticnn.plate.train_fsm import _generate_preds as _generate_preds_fsm
from acousticnn.utils.model_utils import get_net
from acousticnn.plate.train_fsm import extract_mean_std, get_mean_from_field_solution
from acousticnn.utils.argparser import get_args, get_config
from acousticnn.utils.plot import plot_loss
import matplotlib.pyplot as plt
from acousticnn.plate.model import model_factory
from matplotlib import rcParams
verbose = False  # forwarded to the evaluate functions through get_results
base_path = os.path.join(main_dir, "experiments/vibrating_plates")  # checkpoint root used by the data-variation cells
experiment_path = os.path.join(main_dir, "experiments")  # parent of vibrating_plates/, used to build checkpoint paths

In [None]:

# Reload edited project modules automatically without restarting the kernel.
%reload_ext autoreload
%autoreload 2
# Number of frequencies evaluated/plotted; f is the matching axis 1..max_frequency.
max_frequency = 250
f = np.arange(1, max_frequency +1)
# Model config file passed to get_args by the setup cells; the evaluated model is chosen per call.
model_cfg = "query_rn18.yaml"

In [None]:
def get_results(model, path=None, fsm=False, verbose=False):
    """Load a trained checkpoint, evaluate it on the global ``dataloader`` and print a table row.

    Relies on the notebook globals ``args``, ``config`` and ``dataloader`` (and
    ``difficulty`` when ``path`` is None); run the matching setup cell first.

    Args:
        model: Model name understood by ``get_net`` (e.g. ``"localnet"``).
        path: Checkpoint file. Defaults to
            ``{experiment_path}/vibrating_plates/{difficulty}/{model}/checkpoint_best``.
        fsm: If truthy, use the field-solution helpers from ``train_fsm``;
            otherwise use the direct helpers from ``train``.
        verbose: Forwarded to the evaluate function.

    Returns:
        The metrics dict returned by ``evaluate``/``evaluate_fsm`` with an extra
        ``"prediction"`` entry holding the generated predictions.
    """
    len_conditional = len(config.mean_conditional_param) if config.conditional else None
    net = get_net(model, conditional=config.conditional, len_conditional=len_conditional).cuda()
    if path is None:
        path = f"{experiment_path}/vibrating_plates/{difficulty}/{model}/checkpoint_best"
    net.load_state_dict(torch.load(path)["model_state_dict"])
    # Branch on truthiness so flags such as 1/0 or numpy booleans always select a
    # pipeline instead of leaving `prediction`/`results` unbound.
    if fsm:
        prediction, _, _ = _generate_preds_fsm(args, config, net, dataloader)
        results = evaluate_fsm(args, config, net, dataloader, report_peak_error=True, epoch=None, report_wasserstein=True, verbose=verbose)
    else:
        prediction, _ = _generate_preds(args, config, net, dataloader)
        results = evaluate(args, config, net, dataloader, report_peak_error=True, epoch=None, report_wasserstein=True, verbose=verbose)
    results.update({"prediction": prediction})
    # Column order of the printed row: loss, wasserstein, save_rmean, frequency_distance.
    a, b, c, rmean = results["loss (test/val)"], results["wasserstein"], results["frequency_distance"], results["save_rmean"]
    print(f"{a:4.2f} & {b:4.2f} & {rmean:4.2f} & {c:3.1f}")
    return results

def get_field_prediction(batch, dataloader, net, normalize=True):
    """Predict the field solution for one batch and derive the frequency response from it.

    Uses the notebook global ``args.device`` as the compute device.

    Args:
        batch: Dataloader batch; only ``"bead_patterns"`` and ``"sample_mat"`` are read.
        dataloader: Loader whose dataset provides the normalization statistics
            via ``extract_mean_std``.
        net: Field-solution model, called as ``net(image, condition)``.
        normalize: If truthy, normalize the derived response with the dataset's
            output mean/std before returning it.

    Returns:
        ``(prediction, pred_field)`` on the CPU: the response computed by
        ``get_mean_from_field_solution`` and a copy of the raw network output.
    """
    net.eval()
    device = args.device
    with torch.no_grad():
        out_mean, out_std, field_mean, field_std = extract_mean_std(dataloader.dataset)
        out_mean, out_std = torch.tensor(out_mean).to(device), torch.tensor(out_std).to(device)
        field_mean, field_std = torch.tensor(field_mean).to(device), torch.tensor(field_std).to(device)
        # Only the model inputs are moved to the device; the reference field and
        # response stored in the batch are not needed for prediction.
        image = batch["bead_patterns"].to(device)
        condition = batch["sample_mat"].to(device)
        prediction_field = net(image, condition)
        # Keep an untouched copy of the raw output in case get_mean_from_field_solution
        # modifies its input in place (NOTE(review): confirm whether it does).
        pred_field = prediction_field.clone()
        prediction = get_mean_from_field_solution(prediction_field, field_mean, field_std)
        if normalize:
            prediction.sub_(out_mean).div_(out_std)
    return prediction.cpu(), pred_field.cpu()

In [None]:
# Disabled by default: measures the mean forward-pass latency of every model config.
if False:
    from torch.cuda.amp import autocast

    n_warmup, n_timed = 10, 100
    cfgs = os.listdir(os.path.join(main_dir, "configs/model_cfg/"))
    # Dedicated loop name so the global `model_cfg` used by later cells is not overwritten.
    for bench_cfg in cfgs:
        print(bench_cfg)
        if bench_cfg == "query_unet_1.yaml":
            continue
        args = get_args(["--config", "V5000.yaml", "--model_cfg", bench_cfg])

        config = get_config(args.config)
        net = get_net(bench_cfg.split(".")[0], conditional=config.conditional).cuda().eval()
        batch = torch.ones((32, 1, 81, 121)).cuda().float()
        with autocast(), torch.no_grad():  # Enable 16-bit casting
            # Warm-up absorbs one-off costs (kernel selection, lazy allocations).
            for _ in range(n_warmup):
                net(batch)
            # CUDA launches are asynchronous: drain queued work *before* starting the clock.
            torch.cuda.synchronize()
            start_time = time.perf_counter()
            for _ in range(n_timed):
                net(batch)
                torch.cuda.synchronize()
            end_time = time.perf_counter()
        print(f"Forward pass took {(end_time - start_time) / n_timed:.6f} seconds.")

## Evaluate

### Results Transfer

In [None]:
# Bead-ratio transfer: on each test split, evaluate the checkpoint trained on the
# same split first and the checkpoint trained on the other split second.
model = "localnet" # query_unet
fsm = True
transfer_dir = os.path.join(experiment_path, "vibrating_plates/transfer/bead_ratio")
evaluation_plan = (
    ("V5000_larger", ("larger", "smaller")),  # G5000, fsm_V5000
    ("V5000_smaller", ("smaller", "larger")),
)
for difficulty, checkpoint_splits in evaluation_plan:
    args = get_args(["--config", f"{difficulty}.yaml", "--model_cfg", model_cfg])
    config = get_config(args.config)
    args.batch_size = 6
    dataloader = get_dataloader(args, config, logger=None)[2]
    for split in checkpoint_splits:
        _ = get_results(model, path=os.path.join(transfer_dir, split, model, "checkpoint_best"), fsm=fsm)

### All results

In [None]:
# Dataset split for the results table and the figures below ("G5000" or "V5000").
difficulty = "G5000" # G5000, V5000
config_args = ["--config", f"{difficulty}.yaml", "--model_cfg", model_cfg]
args = get_args(config_args)
args.batch_size = 6
config = get_config(args.config)
config.max_frequency = max_frequency
dataloader = get_dataloader(args, config, logger=None)[2]  # third loader returned; used for evaluation throughout

In [None]:
# Each call prints one '&'-separated table row, in the order below.
# Models evaluated with the direct pipeline (fsm=False) and the FSM FNO.
results4 = get_results("fno_decoder", fsm=False, verbose=verbose)
results5 = get_results("deeponet", fsm=False, verbose=verbose)
results7 = get_results("fno_fsm", fsm=True, verbose=verbose)

results2 = get_results("grid_rn18", fsm=False, verbose=verbose)
results3 = get_results("query_rn18", fsm=False, verbose=verbose)
results1 = get_results("vit_implicit", fsm=False, verbose=verbose)

# Models evaluated through their predicted field solutions (fsm=True).
results8 = get_results("unet", fsm=True, verbose=verbose)
results6 = get_results("localnet", fsm=True, verbose=verbose)

## Figures

In [None]:
# Figure styling shared by all plots below: 5 pt text and compact figure sizes.
rcParams['axes.labelsize'] = 5
rcParams['axes.titlesize'] = 5
plt.rcParams.update({'font.size': 5})

# Widths are fractions of 6.75 (presumably the paper's text width in inches).
figsize = (6.75/4, 1.35)
figsize_large = (6.75/3, 1.35)
plt.rcParams["axes.prop_cycle"] = plt.cycler("color", plt.cm.Set2([0, 0.5,1]))
plt.rcParams['text.usetex'] = False

# Relative to main_dir (the notebook chdir's there at start-up).
save_dir = "../../plots/results"
# savefig does not create missing directories, so make sure the target exists.
os.makedirs(save_dir, exist_ok=True)
from scipy.ndimage import zoom
from matplotlib.ticker import ScalarFormatter
import seaborn as sns


### MSE over frequency

In [None]:
# Per-frequency losses of the localnet field model on both datasets.
# Afterwards args/config/dataloader refer to V5000, the last dataset loaded.
losses_by_dataset = {}
for dataset_name in ("G5000", "V5000"):
    args = get_args(["--config", f"{dataset_name}.yaml", "--model_cfg", model_cfg])
    args.batch_size = 6
    config = get_config(args.config)
    config.max_frequency = max_frequency
    dataloader = get_dataloader(args, config, logger=None)[2]
    checkpoint = f"{experiment_path}/vibrating_plates/{dataset_name}/localnet/checkpoint_best"
    losses_by_dataset[dataset_name] = get_results("localnet", verbose=verbose, fsm=True, path=checkpoint)["losses_per_f"]
G5000_losses = losses_by_dataset["G5000"]
V5000_losses = losses_by_dataset["V5000"]

In [None]:
# Median per-frequency loss; V5000 is drawn first so it takes the first cycle colour.
fig, ax = plt.subplots(1, 1, figsize=figsize)
for losses in (V5000_losses, G5000_losses):
    plot = plot_loss(losses, f, ax, quantile=0.5)

# Labels starting with "_" are skipped by matplotlib's legend; they blank out the
# second artist that each plot_loss call appears to add (presumably the quantile band).
legend_labels = ["V-5000", "_", "G-5000", "_"]
ax.legend(legend_labels, loc='upper left')

ax.grid(lw=0.2)
ax.set_ylim(0, 0.3)
ax.set_yticks(np.arange(0, 0.4, 0.1))
sns.despine(ax=ax, offset=5)

plt.tight_layout()
plt.savefig(save_dir + "/mse_over_f.svg", format='svg', dpi = 600, transparent=True)

### Example Frequency Response Predictions

In [None]:
# Take one evaluation batch and predict it with the localnet field model.
normalize = False
batch = next(iter(dataloader))
image = batch["bead_patterns"][:, 0]
field_solution = batch["z_abs_velocity"]
actual_frequency_response = batch["z_vel_mean_sq"]
if not normalize:
    # Bring the reference back to the unnormalized scale, matching the prediction below.
    out_mean, out_std, field_mean, field_std = extract_mean_std(dataloader.dataset)
    actual_frequency_response = actual_frequency_response * out_std + out_mean
# NOTE(review): the checkpoint follows the global `difficulty`, which is not
# necessarily the dataset currently behind `dataloader` — confirm before plotting.
path = f"{experiment_path}/vibrating_plates/{difficulty}/localnet/checkpoint_best"
net = get_net("localnet", conditional=False).cuda()
net.load_state_dict(torch.load(path)["model_state_dict"])
prediction, prediction_field = get_field_prediction(batch, dataloader, net, normalize=normalize)

In [None]:
# Top row: the first five bead patterns of the batch.
fig, axes = plt.subplots(1, 5, figsize=(6.75, 1.2))

import seaborn as sns
for i, ax in enumerate(axes):
    ax.imshow(image[i], cmap='gray')
    ax.axis('off')
# plt.savefig(os.path.join(save_dir, f"example_plates.pdf"), bbox_inches='tight', transparent=True)


# Bottom row: reference vs. predicted frequency response for the same five plates.
fig, axes = plt.subplots(1, 5, figsize=(6.75, 1.2))
for i, ax in enumerate(axes):
    # Index k holds frequency k + 1 (same convention as f = np.arange(1, max_frequency + 1)),
    # so plot against that axis rather than the raw array index.
    freqs = np.arange(1, len(actual_frequency_response[i]) + 1)
    ax.plot(freqs, actual_frequency_response[i], lw=0.5, label="Reference", color="black", linestyle='dashed',)
    ax.plot(freqs, prediction[i], lw=0.5, label="Prediction", color="#55a78c")
    ax.set_ylim(-20, 80)
    ax.set_xlabel('Frequency', fontsize=5)
    ax.set_xticks([0, 100, 200, 300])
    if i == 0:
        ax.set_ylabel('Amplitude', fontsize=5)
        sns.despine(ax=ax, offset=5)
    else:
        # Only the first panel keeps a y-axis; the others share its fixed limits.
        ax.set_yticks([])
        sns.despine(ax=ax, offset=5, left=True)
    ax.grid(lw=0.2)


axes[-1].legend()
#plt.tight_layout()

#plt.savefig(os.path.join(save_dir, f"example_predictions.pdf"), bbox_inches='tight', transparent=True)
plt.show()

In [None]:
# Compare the grid-based UNet ("Grid-UNet") with localnet ("FQO-UNet") on one sample.
normalize = False
batch = next(iter(dataloader))
actual_frequency_response, field_solution, image = batch["z_vel_mean_sq"], batch["z_abs_velocity"], batch["bead_patterns"][:, 0]
f = np.arange(1, max_frequency+1)  # value at index k belongs to frequency f[k] = k + 1
net = get_net("localnet", conditional=False).cuda()
path = f"{experiment_path}/vibrating_plates/{difficulty}/localnet/checkpoint_best"
net.load_state_dict(torch.load(path)["model_state_dict"])
prediction_localnet, prediction_field = get_field_prediction(batch, dataloader, net, normalize=True)
net = get_net("unet", conditional=False).cuda()
path = f"{experiment_path}/vibrating_plates/{difficulty}/unet/checkpoint_best"
net.load_state_dict(torch.load(path)["model_state_dict"])
prediction_unet, _ = get_field_prediction(batch, dataloader, net, normalize=True)

label_grid, label_query = "Grid-UNet", "FQO-UNet"
idx, freq = 0, 132  # sample index and frequency *index* (11 was also used)

# Always load the statistics: the field-solution cell below needs field_mean/field_std.
out_mean, out_std, field_mean, field_std = extract_mean_std(dataloader.dataset)
if not normalize:
    prediction_localnet = prediction_localnet.mul(out_std).add_(out_mean)
    prediction_unet = prediction_unet.mul(out_std).add_(out_mean)
    actual_frequency_response = actual_frequency_response.mul(out_std).add_(out_mean)

fig, ax = plt.subplots(1, 1, figsize=figsize)
ax.plot(f, actual_frequency_response[idx][:max_frequency], label="Reference", color="#909090", lw=0.5, linestyle='dashed', dashes=[1, 1])
ax.plot(f, prediction_unet[idx][:max_frequency], alpha=0.8, color="#e19c2c", label=label_grid, lw=0.5)
ax.plot(f, prediction_localnet[idx][:max_frequency], alpha=0.8, color="#55a78c", label=label_query, lw=0.5)
# Place the marker at f[freq] so it sits on the curves (they are plotted against f).
# Raw string: "\i" in a regular literal is an invalid escape sequence.
ax.scatter(f[freq], actual_frequency_response[idx][freq], color="red", marker="x", s=4, label=r'Frequency $\it{f}$')
ax.grid(which="major", lw=0.2)
ax.set_xticks([0, 100, 200])
ax.set_yticks([-10, 10, 30, 50, 70])
ax.set_xlabel('Frequency')
ax.set_ylabel('Amplitude')
ax.legend(loc="lower left", frameon=False)
sns.despine(ax=ax, offset=5)
plt.tight_layout()
plt.savefig(save_dir + f"/prediction_{difficulty}.pdf", format='pdf', transparent=True, bbox_inches='tight')
plt.show()


In [None]:
#### FIELD SOLUTIONS #####
from matplotlib.collections import Collection


def _denormalize_field(normalized_field, sample, freq_index):
    """Undo the field normalization, apply sqrt(exp(.)) and return one upsampled 2-D slice.

    The transform is applied to the whole batch before slicing so that
    ``field_std``/``field_mean`` broadcast exactly as they do against the full tensor.
    """
    field = normalized_field * field_std + field_mean
    field = np.sqrt(np.exp(field))
    return zoom(field[sample][freq_index], 2, order=3)


def _face_colored_edges(contour_set):
    """Give contour patches edges in their face colour to avoid hairline seams in vector output."""
    # Matplotlib >= 3.8: a ContourSet is itself a Collection and `.collections` is
    # deprecated (later removed); older versions only expose the per-level artists.
    if isinstance(contour_set, Collection):
        contour_set.set_edgecolor("face")
    else:
        for collection in contour_set.collections:
            collection.set_edgecolor("face")


def _draw_contour(field, x, y, vmin, vmax):
    """Open a new figure, draw ``field`` as 20 grayscale filled contours and return the ContourSet."""
    plt.subplots(1, 1, figsize=figsize_large)
    contour_set = plt.contourf(x, y, field, levels=20, vmin=vmin, vmax=vmax, antialiased=True, cmap=plt.cm.gray)
    _face_colored_edges(contour_set)
    return contour_set


def _save_and_show(filename):
    """Hide the axes, save the current figure under ``save_dir`` and display it."""
    plt.tight_layout()
    plt.axis("off")
    plt.savefig(save_dir + filename, transparent=True, bbox_inches='tight', pad_inches=0)
    plt.show()


field_solution_trans = _denormalize_field(field_solution, idx, freq)
prediction_field_trans = _denormalize_field(prediction_field, idx, freq)
# Shared colour limits so reference and prediction are directly comparable.
vmin = min(np.min(prediction_field_trans), np.min(field_solution_trans))
vmax = max(np.max(prediction_field_trans), np.max(field_solution_trans))

y, x = np.mgrid[0:field_solution_trans.shape[0], 0:field_solution_trans.shape[1]]

_draw_contour(field_solution_trans, x, y, vmin, vmax)
_save_and_show("/solution.svg")

_draw_contour(prediction_field_trans, x, y, vmin, vmax)
_save_and_show("/pred_field.svg")

#### Scale #####
# Prediction in units of 1e-2 (matching the colorbar label); both colour limits
# are scaled together with the data.
contour_set = _draw_contour(prediction_field_trans * 100, x, y, vmin * 100, vmax * 100)
cbar = plt.colorbar(contour_set)

formatter = ScalarFormatter(useMathText=True)
formatter.set_powerlimits((-2, 2))

cbar.set_ticks([0, 1, 2, 2.7])

cbar.formatter = formatter
cbar.update_ticks()
cbar.set_label(r'Velocity $\times 10^{-2}$ m/s', fontsize=5)
cbar.ax.tick_params(labelsize=5)
_save_and_show("/colorbar.svg")


plt.imshow(dataloader.dataset.files["bead_patterns"][idx, 0], cmap='gray')
plt.axis('off')

plt.savefig(save_dir + "/beading_pattern.svg", transparent=True, bbox_inches='tight', pad_inches=0)
plt.show()



## MSE over amount of training data

In [None]:
# Test loss of query_rn18 trained on 10–75 % of the data, followed by the full-data checkpoint.
experiments = ["10_percent", "25_percent", "50_percent", "75_percent"]
model = "query_rn18"
data_vary_path = os.path.join(base_path, "data_variation", model, difficulty)
paths = [os.path.join(data_vary_path, exp_path, "checkpoint_best") for exp_path in experiments]
paths.append(os.path.join(base_path, difficulty, model, "checkpoint_best"))
for path in paths:
    print(path)
loss_a = [get_results(model, verbose=verbose, path=path, fsm=False)["loss (test/val)"] for path in paths]


In [None]:
# Same sweep for the localnet field model (uses `experiments` from the previous cell).
model = "localnet"
data_vary_path = os.path.join(base_path, "data_variation", model, difficulty)
paths = [os.path.join(data_vary_path, exp_path, "checkpoint_best") for exp_path in experiments]
paths.append(os.path.join(base_path, difficulty, model, "checkpoint_best"))
for path in paths:
    print(path)

loss_b = [get_results(model, verbose=verbose, path=path, fsm=True)["loss (test/val)"] for path in paths]

In [None]:
# MSE over the number of training samples.
# Full training-set size per dataset split; add an entry before plotting another split.
train_set_sizes = {"V5000": 4500}
if difficulty not in train_set_sizes:
    # Previously this surfaced as a NameError on `max_samples` further down.
    raise ValueError(f"No training-set size known for difficulty {difficulty!r}; add it to train_set_sizes.")
max_samples = train_set_sizes[difficulty]

fig, ax = plt.subplots(1, 1, figsize=figsize)
size = np.array([0.1, 0.25, 0.5, 0.75, 1])  # fractions matching `experiments` plus the full set
n_samples = max_samples * size
ax.plot(n_samples, loss_a, 'o-', color="#b38784", label="FQO-RN18", lw=0.5, markersize=3)
ax.plot(n_samples, loss_b, 'o-', color="#b5b564", label="FQO-UNet", lw=0.5, markersize=3)
ax.set_xlabel('Number of samples')
ax.set_ylabel('MSE')
ax.legend()
ax.grid(lw=0.2)

ax.set_yticks(np.arange(0.1, 0.90, 0.2))
ax.set_ylim(0, 0.8)
ax.set_xticks(np.arange(0, max_samples*1.30, max_samples/2))
sns.despine(ax=ax, offset=5)

plt.tight_layout()  # Automatically adjusts margins and spacing

plt.savefig(save_dir + f"/data_variation_{difficulty}.svg", format='svg', dpi = 600, transparent=True)