In [1]:
from acousticnn.main_dir import main_dir
import torch, os, json
import seaborn as sns
os.chdir(main_dir)

%reload_ext autoreload
%autoreload 2

import numpy as np
from acousticnn.plate.dataset import get_dataloader
from acousticnn.model import model_factory
from acousticnn.utils.logger import print_log
from acousticnn.utils.argparser import get_args, get_config
from torchinfo import summary
import seaborn as sns

import matplotlib.pyplot as plt
from acousticnn.plate.train_fsm import evaluate

np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})
from acousticnn.plate.train_fsm import extract_mean_std, get_mean_from_field_solution

In [2]:
# Compact typography for small, paper-sized figures; LaTeX text rendering stays off.
plt.rcParams.update({
    'axes.labelsize': 5,
    'axes.titlesize': 5,
    'font.size': 5,
    'text.usetex': False,
})

# Figure sizes in inches (6.75 in is presumably the full text width of the paper — confirm).
figsize = (6.75/4, 1.35)
figsize_large = (6.75/3, 1.35)
# Two-colour cycle sampled from the two ends of the Set2 colormap.
plt.rcParams["axes.prop_cycle"] = plt.cycler("color", plt.cm.Set2(np.linspace(0, 1, 2)))
# Output directory for saved figures, relative to the working directory set in the first cell.
save_dir = "plots/results/"

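## Isocontour evaluations

Test metrics for the runs with a budget of 150k FEM evaluations, plotted against the number of frequencies per plate geometry. The 1.5 Mio and 750k FEM-evaluation runs are shown for reference.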

In [3]:
def load_results(filename):
    """Load and return the parsed contents of a JSON file in ``<main_dir>/results``.

    Args:
        filename: Name of the JSON file inside the ``results`` directory.

    Returns:
        Whatever ``json.load`` parses from the file (a list of run dicts for the
        150k sweep, a single metrics dict for the reference runs, as used below).
    """
    # `with` guarantees the file handle is closed after parsing.
    with open(os.path.join(main_dir, 'results', filename)) as f:
        return json.load(f)


results_150k = load_results('results_150k.json')
results_baseline = load_results('results_baseline.json')
results_best = load_results('results_best.json')

In [None]:
# Axis labels for metric keys whose raw JSON name is not a good label; other keys are used verbatim.
METRIC_LABELS = {'loss (test/val)': 'MSE'}
# x positions (frequencies per plate geometry) at which the reference runs are drawn.
BASELINE_N_FREQ = 300  # 1.5 Mio FEM evals run
BEST_N_FREQ = 15       # 750k FEM evals run

# Sort the 150k-budget runs by frequencies per geometry once; the order is the same for every metric.
runs_150k = sorted(results_150k, key=lambda run: run['n_freq'])
n_freqs = [run['n_freq'] for run in runs_150k]

for key in results_150k[0].keys():
    if key == 'n_freq':
        continue
    values = [run[key] for run in runs_150k]

    fig, ax = plt.subplots(figsize=figsize)
    ax.plot(n_freqs, values, marker='o', lw=0.5, ms=2, label='150k FEM evals')
    ax.set_xscale('log', base=2)
    # Label every measured n_freq explicitly instead of the default power-of-two log ticks.
    ax.set_xticks(n_freqs)
    ax.set_xticklabels([str(freq) for freq in n_freqs])

    ax.scatter(BASELINE_N_FREQ, results_baseline[key], color='r', marker='x', s=17, label='1.5 Mio FEM evals ')
    ax.scatter(BEST_N_FREQ, results_best[key], color='g', marker='*', s=17, label='750k FEM evals')

    ax.set_xlabel('Frequencies per plate geometry')
    # Only the loss is an MSE; label the remaining metrics by their own key.
    ax.set_ylabel(METRIC_LABELS.get(key, key))
    ax.set_ylim(bottom=0)
    ax.grid(True, lw=0.2)
    ax.legend(fontsize=4.5)
    sns.despine(ax=ax, offset=5)
    fig.tight_layout()
    print(key)
    if key == 'loss (test/val)':
        # Only the loss plot is exported; make sure the output directory exists first.
        os.makedirs(save_dir, exist_ok=True)
        fig.savefig(os.path.join(save_dir, 'data_efficiency.svg'), transparent=True)
    plt.show()
    # Release the figure so repeated runs of this loop do not accumulate open figures.
    plt.close(fig)

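## Compute results

Evaluate a trained checkpoint on the test set and save the headline metrics to `results/results.json`.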

In [None]:
# Parse the same command-line arguments a training run would receive.
args = get_args([
    '--model_cfg', 'fqo_unet.yaml',
    '--config', 'cfg/V5000_no_sampling.yaml',
    '--dir', 'debug',
    '--batch_size', '2',
])
# No logger object: print_log / evaluate are called with logger=None throughout this notebook.
logger = None
config = get_config(args.config)
model_cfg = get_config(args.model_cfg)
# get_dataloader returns six items; only the first three (train/val/test loaders) are used here.
trainloader, valloader, testloader, _, _, _ = get_dataloader(args, config, logger)

# Build the model from its config; the conditioning length is only needed for conditional models.
net = model_factory(
    **model_cfg,
    conditional=config.conditional,
    rmfreqs=hasattr(config, "rmfreqs"),
    len_conditional=len(config.mean_conditional_param) if config.conditional else None,
).to(args.device)

# Take one training batch, move every entry to the target device, and summarize the model on it.
batch = {
    name: value.to(args.device)
    for name, value in next(iter(trainloader)).items()
}
summary(
    net,
    input_data=(batch["bead_patterns"], batch["phy_para"], batch["frequencies"]),
)

def evaluate_checkpoint(path, checkpoint_name="checkpoint_best"):
    """Load a checkpoint into the notebook-global ``net`` and evaluate it on ``testloader``.

    Relies on the globals ``net``, ``args``, ``config``, ``testloader`` and ``logger``
    defined earlier in this cell.

    Args:
        path: Directory that contains the checkpoint file.
        checkpoint_name: File name of the checkpoint inside ``path``
            (default ``"checkpoint_best"``).

    Returns:
        The metrics dict returned by ``evaluate``. It is read here for the keys
        ``"loss (test/val)"``, ``"wasserstein"``, ``"frequency_distance"`` and ``"save_rmean"``.

    Side effects:
        Overwrites the weights of ``net`` and prints/logs the metrics as an
        ``&``-separated table row.
    """
    # map_location lets a checkpoint saved on one device (e.g. GPU) load on the current args.device.
    data = torch.load(os.path.join(path, checkpoint_name), map_location=args.device)
    # Strip the "_orig_mod." prefix (presumably added when the model was wrapped with
    # torch.compile) so the keys match the plain, uncompiled `net`.
    new_state_dict = {
        key.replace("_orig_mod.", ""): value
        for key, value in data["model_state_dict"].items()
    }

    net.load_state_dict(new_state_dict)
    print('evaluate:', path)
    results = evaluate(args, config, net, testloader, logger=logger, report_peak_error=True, epoch=None, report_wasserstein=True)
    a, b, c, save_rmean = results["loss (test/val)"], results["wasserstein"], results["frequency_distance"], results["save_rmean"]
    print_log(f"{a:4.2f} & {b:4.2f} & {save_rmean:3.2f} & {c:3.1f}", logger=logger)
    return results

# TODO: point this at the run directory that contains "checkpoint_best".
CHECKPOINT_DIR = "path/to/checkpoint"

full_results = evaluate_checkpoint(CHECKPOINT_DIR)

# Metrics persisted to disk.
keys = ['loss (test/val)', 'wasserstein', 'save_rmean']
# float() accepts 0-dim tensors, NumPy scalars and plain Python numbers alike
# (plain floats have no .item()), and yields JSON-serializable values.
results = {key: float(full_results[key]) for key in keys}

# Save next to the other result files; resolve from main_dir rather than the current working directory.
results_path = os.path.join(main_dir, 'results', 'results.json')
os.makedirs(os.path.dirname(results_path), exist_ok=True)
with open(results_path, 'w') as f:
    json.dump(results, f, indent=4)