In [None]:
from acousticnn.plate.configs.main_dir import main_dir
import numpy as np
np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})
from acousticnn.plate.dataset import get_dataloader
from acousticnn.plate.train import evaluate, _evaluate, _generate_preds
from acousticnn.plate.train_fsm import evaluate as evaluate_fsm
from acousticnn.plate.train_fsm import extract_mean_std, get_mean_from_field_solution
from acousticnn.utils.argparser import get_args, get_config
from acousticnn.utils.plot import plot_loss
import wandb, time, os, torch
import matplotlib.pyplot as plt
from acousticnn.plate.model import model_factory
from matplotlib import rcParams

# Pin this kernel to one physical GPU; PCI_BUS_ID ordering makes the index match nvidia-smi (issue #152).
# NOTE(review): torch is imported above, so this only takes effect if CUDA has not been initialised yet.
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID", "CUDA_VISIBLE_DEVICES": "2"})
# Checkpoint roots: per-architecture runs and all experiments.
base_path = os.path.join(main_dir, "experiments/arch")
experiment_path = os.path.join(main_dir, "experiments")

In [None]:

# Figure defaults shared by all plots below (sizes given in cm and converted to inches).
rcParams.update({
    "axes.labelsize": 12,
    "axes.titlesize": 12,
    "figure.figsize": (10 / 2.54, 8 / 2.54),
    # Three colours sampled from the Set2 colormap at positions 0, 0.5 and 1.
    "axes.prop_cycle": plt.cycler("color", plt.cm.Set2([0, 0.5, 1])),
    "text.usetex": False,
})

# x-axis positions for per-frequency curves; must match the number of frequency points (250).
f = np.arange(0, 250)
# Output directory for saved figures.
save_dir = "plots/results"

# Re-import edited project modules automatically before each cell executes.
%reload_ext autoreload
%autoreload 2
# Default model config passed to get_args by the evaluation cells below.
model_cfg = "query_rn18.yaml"

In [None]:
def get_net(model, conditional):
    """Build an untrained network from the model config ``<model>.yaml``.

    "0toy.yaml" is only supplied so get_args can parse; just the resolved
    model config is used. The model name is printed for progress tracking.
    """
    print(model)
    parsed = get_args(["--config", "0toy.yaml", "--model_cfg", model + ".yaml"])
    model_kwargs = get_config(parsed.model_cfg)
    return model_factory(**model_kwargs, conditional=conditional)

def get_results(model, path=None, fsm=False, verbose=False):
    """Load a trained checkpoint, evaluate it on the current test loader and print a table row.

    Relies on the notebook globals ``config``, ``args``, ``dataloader``,
    ``difficulty`` and ``base_path``.

    Args:
        model: model config name without ".yaml" (passed to get_net).
        path: checkpoint file; defaults to f"{base_path}/{model}/{difficulty}/checkpoint_best".
        fsm: if truthy, score with ``evaluate_fsm`` (train_fsm) instead of ``evaluate``.
        verbose: forwarded to the evaluator.

    Returns:
        The evaluator's metrics dict, extended with "prediction" and "output"
        (both as returned by ``_generate_preds``).
    """
    net = get_net(model, conditional=config.conditional).cuda()
    if path is None:
        path = f"{base_path}/{model}/{difficulty}/checkpoint_best"
    net.load_state_dict(torch.load(path)["model_state_dict"])
    # NOTE(review): predictions come from train._generate_preds even when fsm is set — confirm
    # it handles field-solution models.
    prediction, output = _generate_preds(args, config, net, dataloader)
    # Dispatch on truthiness: the old `is False` / `is True` pair left `results` unbound otherwise.
    evaluator = evaluate_fsm if fsm else evaluate
    results = evaluator(args, config, net, dataloader, report_peak_error=True, epoch=None,
                        report_wasserstein=True, verbose=verbose)
    # Keep the reference curves too so later cells do not depend on a stray global `output`.
    results.update({"prediction": prediction, "output": output})
    # NaN-aware quantiles: samples without a valid peak ratio must not poison the IQR.
    r25, r75 = np.nanquantile(results["peak_ratio"], 0.25), np.nanquantile(results["peak_ratio"], 0.75)
    loss, wasserstein, freq_dist = results["loss (test/val)"], results["wasserstein"], results["frequency_distance"]
    print(f"{loss:4.2f} & {wasserstein:4.2f} & [{r25:4.2f}, {r75:4.2f}] & {freq_dist:3.1f}")
    return results

def get_field_prediction(batch, dataloader, net, device=None):
    """Predict field solutions for one batch and derive normalized frequency responses.

    Args:
        batch: dict from the dataloader; "bead_patterns" is the network input and
            "sample_mat" the condition passed to ``net(image, condition)``.
        dataloader: loader whose dataset supplies normalization statistics via extract_mean_std.
        net: field-solution model; switched to eval mode.
        device: torch device for the computation; defaults to the notebook-global ``args.device``.

    Returns:
        (prediction, pred_field), both on CPU. ``prediction`` is the response obtained from
        get_mean_from_field_solution, normalized with the dataset's output mean/std, and
        ``pred_field`` the raw predicted field.
    """
    if device is None:
        device = args.device
    net.eval()
    with torch.no_grad():
        out_mean, out_std, field_mean, field_std = extract_mean_std(dataloader.dataset)
        out_mean, out_std = torch.tensor(out_mean).to(device), torch.tensor(out_std).to(device)
        field_mean, field_std = torch.tensor(field_mean).to(device), torch.tensor(field_std).to(device)
        image = batch["bead_patterns"].to(device)
        condition = batch["sample_mat"].to(device)
        prediction_field = net(image, condition)
        # Keep an untouched copy of the raw field before handing the tensor to the reducer.
        pred_field = prediction_field.clone()
        prediction = get_mean_from_field_solution(prediction_field, field_mean, field_std)
        # Normalize the derived response in place with the dataset's output statistics.
        prediction.sub_(out_mean).div_(out_std)
    return prediction.cpu(), pred_field.cpu()
# Empty containers for per-model predictions (frequency responses and fields).
preds, preds_field = {}, {}

In [None]:
# Forward-pass timing for every model config: one untimed warm-up pass, then one timed pass.
run_timing = True
if run_timing:
    cfgs = sorted(os.listdir(os.path.join(main_dir, "configs/model_cfg/")))
    # Loop variable must not be `model_cfg`: that global is the default config used by later cells.
    for cfg_name in cfgs:
        print(cfg_name)
        args = get_args(["--config", "fsm_V5000.yaml", "--model_cfg", cfg_name])
        net = model_factory(**get_config(args.model_cfg)).cuda()
        net.eval()
        batch = torch.ones((32, 1, 81, 121)).cuda().float()

        with torch.no_grad():
            net(batch)  # warm-up
            # CUDA kernels run asynchronously: synchronize so the interval covers the whole pass.
            torch.cuda.synchronize()
            start_time = time.perf_counter()
            net(batch)
            torch.cuda.synchronize()
            end_time = time.perf_counter()
        print(f"Forward pass took {end_time - start_time:.6f} seconds.")

## Evaluate

In [None]:
# Evaluation setup: pick the dataset config and build the test loader.
torch.set_default_tensor_type('torch.FloatTensor')
difficulty = "G5000"  # options: G5000, fsm_V5000
args = get_args(["--config", f"{difficulty}.yaml", "--model_cfg", model_cfg])
args.batch_size = 4
config = get_config(args.config)
dataloader = get_dataloader(args, config, logger=None)[2]  # third loader = test split
verbose = False

In [None]:
# run only for transfer
if False:
    torch.set_default_tensor_type('torch.FloatTensor')
    args = get_args(["--config", "G5000.yaml", "--model_cfg", "fno_conditional.yaml"])
    config_transfer = get_config(args.config)
    config_transfer.data_path_ref = config.data_path_ref
    trainloader, valloader, testloader, trainset, valset, testset = get_dataloader(args, config_transfer, logger=None, num_workers=0)
    dataloader = testloader

In [None]:
# Evaluate every architecture on the current test loader; entries flagged True use evaluate_fsm.
_eval_runs = [
    ("vit_implicit", False),
    ("grid_rn18", False),
    ("query_rn18", False),
    ("fno_decoder", False),
    ("deeponet", False),
    ("query_unet", True),
    ("fno_fsm", True),
    ("unet", True),
]
results1, results2, results3, results4, results5, results6, results7, results8 = (
    get_results(name, verbose=verbose, fsm=use_fsm) for name, use_fsm in _eval_runs
)

In [None]:
# Transfer evaluation of query_unet on the two bead-ratio splits (one loop instead of two copied cells).
model = "query_unet"
fsm = True
transfer_runs = [
    ("fsm_V5000b", "smaller"),
    ("fsm_V5000a", "larger"),
]
for difficulty, ratio in transfer_runs:
    args = get_args(["--config", f"{difficulty}.yaml", "--model_cfg", model_cfg])
    config = get_config(args.config)
    args.batch_size = 4
    dataloader = get_dataloader(args, config, logger=None)[2]
    # Alternative layout with per-model checkpoints: transfer/bead_ratio/{ratio}/{model}/checkpoint_best
    _ = get_results(model, path=os.path.join(experiment_path, f"transfer/bead_ratio/{ratio}/checkpoint_best"), fsm=fsm)

In [None]:
# Snapshot the query_rn18 results and locate its best- and worst-predicted samples.
results_first = results3.copy()
# Mean loss per sample (axis 1 presumably indexes frequency — confirm).
loss_per_sample = np.mean(results3["losses_per_f"], axis=1)
best_idx = np.argmin(loss_per_sample)
worst_idx = np.argmax(loss_per_sample)
print(best_idx, worst_idx)
print(loss_per_sample[best_idx], loss_per_sample[worst_idx])

In [None]:
# Single-sample comparison: two models' predicted responses vs the reference for test sample `num`.
label_grid = "RN18 + FNO"
label_query = "Query-based RN18"
num = 11  # test-sample index (8 and 11 were used)
# NOTE(review): results2 is grid_rn18, while "RN18 + FNO" may rather describe fno_decoder (results4)
# — confirm which model this curve should show.
prediction1 = results2["prediction"]
# label_query describes query_rn18, which is evaluated into results3.
prediction2 = results3["prediction"]
# Reference curves: take them from the results dict when stored there; otherwise fall back to a
# global `output` (no earlier cell defines one at module level on a fresh kernel).
reference = results2["output"] if "output" in results2 else output

a = _evaluate(prediction1[num:num+1], reference[num:num+1], None, config, args, epoch=0, report_peak_error=True, report_wasserstein=True, dataloader=dataloader)
rmse, emd = a["loss (test/val)"], a["wasserstein"]
eval_grid = label_grid + ", MSE: " + f"{rmse:4.2}" + ", EMD: " + f"{emd:4.3}"
a = _evaluate(prediction2[num:num+1], reference[num:num+1], None, config, args, epoch=0, report_peak_error=True, report_wasserstein=True, dataloader=dataloader)
rmse, emd = a["loss (test/val)"], a["wasserstein"]
eval_query = label_query + ", MSE: " + f"{rmse:4.2}" + ", EMD: " + f"{emd:4.3}"

fig, ax = plt.subplots(1, 1, figsize=(10 / 2.54 * 1.5, 8 / 2.54))
ax.plot(f, reference[num], label="Reference", color="#909090", lw=2.5, linestyle='dashed', dashes=[1, 1])
ax.plot(f, prediction1[num], alpha=0.8, color="#e19c2c", label=eval_grid, lw=2.5)
ax.plot(f, prediction2[num], alpha=0.8, color="#55a78c", label=eval_query, lw=2.5)
ax.set_yticks([-4, -2, 0, 2, 4])
ax.grid(which="major")
ax.set_ylim(-4, 3.5)
ax.set_xlabel('frequency')
ax.set_ylabel('normalized amplitude')
ax.legend(fontsize=11, loc="lower left", frameon=False)
plt.tight_layout()
plt.show()
# plt.savefig(os.path.join(save_dir, f"prediction_{difficulty}.pdf"), format='pdf', dpi=600, transparent=True)

In [None]:
# Snapshot of the query_unet metrics (results6), later plotted under the label "V-5000".
# NOTE(review): results6 reflects the `difficulty` active when the evaluation cell ran ("G5000" on a
# fresh Run All); this cell is presumably meant to run while results6 holds V-5000 results.
results_v5000 = results6.copy()


In [None]:
# Per-frequency loss of two query_unet result sets (plot_loss called with quantile=0.5).
# NOTE(review): results_v5000 is a copy of results6, so on a fresh Run All both curves coincide.
fig, ax = plt.subplots(1, 1, figsize=(8 / 2.54, 7.5 / 2.54 * 0.9))
plot = plot_loss(results_v5000["losses_per_f"], f, ax, quantile=0.5)
plot = plot_loss(results6["losses_per_f"], f, ax, quantile=0.5)
# Entries labelled "_" are skipped by the legend; presumably plot_loss adds two artists per call.
legend_labels = ["V-5000", "_", "G-5000", "_"]
ax.legend(legend_labels, fontsize=10, loc='upper left')
ax.grid()
ax.set_ylim(0, 0.3)
ax.set_yticks(np.arange(0, 0.4, 0.1))
for item in [ax.title, ax.xaxis.label, ax.yaxis.label] + ax.get_xticklabels() + ax.get_yticklabels():
    item.set_fontsize(10)
plt.tight_layout()
# savefig fails if the target directory does not exist yet.
os.makedirs(save_dir, exist_ok=True)
plt.savefig(os.path.join(save_dir, "mse_over_f.svg"), format='svg', dpi=600, transparent=True)

## k-nearest neighbours (kNN)

In [None]:
from acousticnn.plate.knn.knn_train import AutoEncoder, generate_encoding, get_checker, generate_plots, get_predictions, get_output, pred_fn, get_pred_img, eval_knn

In [None]:
# kNN setup: unshuffled loaders for the chosen setting and the AutoEncoder restored from its best checkpoint.
torch.set_default_tensor_type('torch.FloatTensor')
setting = "fsm_V5000"
args = get_args(["--config", f"{setting}.yaml", "--model_cfg", model_cfg])
config = get_config(args.config)
(trainloader, valloader, testloader,
 trainset, valset, testset) = get_dataloader(args, config, logger=None, shuffle=False)
dataloader = valloader
dataset = valset
k_max = 25  # maximum number of neighbours passed to eval_knn
path = os.path.join(base_path, "knn", setting, "checkpoint_best")
net = AutoEncoder().cuda()
net.load_state_dict(torch.load(path))

In [None]:
# Encode training (reference) and validation (query) sets, evaluate kNN up to k_max and pick the best k.
reference = generate_encoding(trainloader, net)
queries = generate_encoding(valloader, net)
losses = eval_knn(reference, queries, k_max, config, logger=None, query_set=valset, reference_set=trainset)
n_neighbors = np.argmin(losses) + 1  # losses[0] corresponds to k = 1
print(n_neighbors)

In [None]:
# Alias (not a copy) of the current kNN losses, plotted as the "V-5000" curve below.
# NOTE(review): on a fresh Run All this is the same array as `losses`; it is presumably meant to be
# captured before `losses` is recomputed for another setting.
loss_hard = losses

In [None]:
# kNN validation MSE against the number of neighbours k.
# NOTE(review): loss_hard aliases losses, so on a fresh Run All both curves are identical.
fig, ax = plt.subplots(1, 1, figsize=(10 / 2.54 * 1, 8 / 2.54))
# losses[i] belongs to k = i + 1, so the x axis starts at 1 rather than at the array index 0.
ax.plot(np.arange(1, len(losses) + 1), losses, label="F-2500", lw=2.5)
ax.plot(np.arange(1, len(loss_hard) + 1), loss_hard, label="V-5000", lw=2.5)
ax.set_yticks(np.arange(0.5, 1.1, 0.2))
ax.grid(which="major")
ax.set_xlabel('k')
ax.set_ylabel('MSE')
ax.legend(fontsize=11)
plt.tight_layout()
# Join with a separator (plain concatenation produced "plots/resultsknn_k_sweep.pdf").
os.makedirs(save_dir, exist_ok=True)
plt.savefig(os.path.join(save_dir, "knn_k_sweep.pdf"), format='pdf', dpi=600, transparent=True)


In [None]:
# Grid of the first n_examples validation bead patterns (column 0) next to the images returned by
# get_pred_img (presumably their n_neighbors nearest training patterns).
n_neighbors = 3
n_examples = 3
prediction = get_pred_img(n_neighbors, trainset, trainloader, dataloader, net).squeeze(2)
# squeeze=False keeps `ax` 2-D even when n_examples or n_neighbors + 1 equals 1.
fig, ax = plt.subplots(n_examples, n_neighbors + 1,
                       figsize=(10 / 2.54 * n_examples, 8 / 2.54 * n_neighbors * 0.6), squeeze=False)
for axis in ax.flat:
    axis.axis('off')

for i in range(n_examples):
    ax[i, 0].imshow(valset[i]["bead_patterns"].squeeze(), cmap=plt.cm.gray)
    for j in range(n_neighbors):
        ax[i, j + 1].imshow(prediction[i, j].squeeze(), cmap=plt.cm.gray)
plt.tight_layout()
# Join with a separator (plain concatenation dropped the "/" after save_dir).
os.makedirs(save_dir, exist_ok=True)
plt.savefig(os.path.join(save_dir, "knn_nearest_neigbor_images.pdf"), format='pdf', dpi=600, transparent=True)


#### Test-set results

In [None]:
# Test-set metrics of the kNN predictor, printed in the same " & "-separated row format as get_results.
# n_neighbors = 20  # alternative k
use_net = True
dataloader, dataset = testloader, testset
output = get_output(dataset, config)
prediction = pred_fn(n_neighbors, trainset, trainloader, dataloader, net, config, use_net=use_net)
results = _evaluate(prediction, output, config=config, args=args, report_peak_error=True, report_wasserstein=True, dataloader=dataloader, epoch=None, logger=None)
# NaN-aware quantiles, consistent with get_results: a single NaN peak ratio must not turn the IQR into NaN.
r25, r75 = np.nanquantile(results["peak_ratio"], 0.25), np.nanquantile(results["peak_ratio"], 0.75)
loss, wasserstein, freq_dist = results["loss (test/val)"], results["wasserstein"], results["frequency_distance"]
# No comma after the closing bracket: it would leak into the LaTeX table cell.
print(f"{loss:4.2f} & {wasserstein:4.2f} & [{r25:4.2f}, {r75:4.2f}] & {freq_dist:3.1f}")

## MSE over data amount

In [None]:
# Data-variation study for query_rn18: the data_variation checkpoints plus the full arch checkpoint.
difficulty = "fsm_V5000"
args = get_args(["--config", f"{difficulty}.yaml", "--model_cfg", "query_rn18.yaml"])
config = get_config(args.config)
dataloader_hard = get_dataloader(args, config, logger=None)[2]
dataloader = dataloader_hard
experiments = ["10_percent", "25_percent", "50_percent", "75_percent"]
model = "query_rn18"
data_vary_path = os.path.join(experiment_path, "data_variation/", model, difficulty)
paths = [os.path.join(data_vary_path, exp_path, "checkpoint_best") for exp_path in experiments]
paths.append(os.path.join(experiment_path, f"arch/{model}/{difficulty}/checkpoint_best"))
# Plain loop: a list comprehension used for print side effects also displays a list of Nones.
for path in paths:
    print(path)


In [None]:
# Test loss of query_rn18 for each checkpoint in `paths`, in the same order.
loss_a = []
for path in paths:
    results = get_results(model, path=path, fsm=False, verbose=verbose)
    loss_a += [results["loss (test/val)"]]

In [None]:
# Data-variation study for query_unet on the same `difficulty` as the query_rn18 study.
args = get_args(["--config", f"{difficulty}.yaml", "--model_cfg", model_cfg])
config = get_config(args.config)
args.batch_size = 8
dataloader_hard = get_dataloader(args, config, logger=None)[2]
dataloader = dataloader_hard

experiment_paths = ["10_percent", "25_percent", "50_percent", "75_percent"]
model = "query_unet"
data_vary_path = os.path.join(experiment_path, "data_variation/", model, difficulty)
# Build from this cell's own list; previously it silently used `experiments` from another cell.
paths = [os.path.join(data_vary_path, exp_path, "checkpoint_best") for exp_path in experiment_paths]
paths.append(os.path.join(experiment_path, f"arch/{model}/{difficulty}/checkpoint_best"))
for path in paths:
    print(path)

In [None]:
# Test loss of query_unet (scored with evaluate_fsm) for each checkpoint in `paths`, in the same order.
loss_b = []
for path in paths:
    results = get_results(model, path=path, fsm=True, verbose=verbose)
    loss_b += [results["loss (test/val)"]]

In [None]:
# MSE over the amount of training data for both data-variation studies.
fig, ax = plt.subplots(1, 1, figsize=(8 / 2.54, 7.5 / 2.54 * 0.9))
# Presumably the full training-set size for this difficulty — confirm.
max_samples = 4500 if difficulty == "fsm_V5000" else 2000
# Fractions for the 10/25/50/75 % checkpoints plus the full-data checkpoint.
size = np.array([0.1, 0.25, 0.5, 0.75, 1])
n_samples = max_samples * size
assert len(loss_a) == len(loss_b) == len(n_samples), "expected one loss value per data fraction"
ax.plot(n_samples, loss_a, 'o-', color="#b38784", label="Query-RN18")
ax.plot(n_samples, loss_b, 'o-', color="#b5b564", label="Query-UNet")
# Use the Axes API throughout instead of mixing in pyplot's current-axes calls.
ax.set_xlabel('Number of samples')
ax.set_ylabel('MSE')
ax.legend(fontsize=10)
ax.grid()

ax.set_yticks(np.arange(0.2, 0.50, 0.2))
ax.set_xticks(np.arange(0, max_samples * 1.30, max_samples / 2))
for item in [ax.title, ax.xaxis.label, ax.yaxis.label] + ax.get_xticklabels() + ax.get_yticklabels():
    item.set_fontsize(10)
plt.tight_layout()  # Automatically adjusts margins and spacing

os.makedirs(save_dir, exist_ok=True)
plt.savefig(os.path.join(save_dir, f"data_variation_{difficulty}.svg"), format='svg', dpi=600, transparent=True)