In [None]:
from acousticnn.plate.configs.main_dir import main_dir

import os
# Restrict this notebook to GPU 0, with CUDA enumerating devices by PCI bus ID
# (see issue #152); both variables are set before any model is built.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})
# Root directory holding the trained checkpoints: <base_path>/<model>/<dataset>/...
base_path = os.path.join(main_dir, "experiments/arch")

In [None]:
import numpy as np
from acousticnn.plate.dataset import get_dataloader
from acousticnn.plate.model import model_factory
from acousticnn.plate.train_fsm import extract_mean_std, get_mean_from_field_solution
from acousticnn.utils.builder import build_opti_sche
from acousticnn.utils.logger import init_train_logger, print_log
from acousticnn.utils.argparser import get_args, get_config
from acousticnn.plate.train_fsm import evaluate, _generate_preds
from acousticnn.plate.train import evaluate as evaluate_implicit, _generate_preds as generate_preds_implicit
from acousticnn.plate.train import _evaluate
from torchinfo import summary
import wandb, time, torch


np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})


difficulty = "fsm_V5000" # fsm_V5000, G5000
args = get_args(["--config", f"{difficulty}.yaml", "--model_cfg", "fno_conditional.yaml"])
args.batch_size = 16

config = get_config(args.config)
trainloader, valloader, testloader, trainset, valset, testset = get_dataloader(args, config, logger=None, num_workers=0, shuffle=False)
dataloader = testloader
batch = next(iter(dataloader))
actual_frequency, field_solution, image =  batch["z_vel_mean_sq"], batch["z_abs_velocity"], batch["bead_patterns"][:, 0]

%reload_ext autoreload
%autoreload 2


out_mean, out_std, field_mean, field_std = extract_mean_std(trainloader.dataset)
out_mean, out_std = torch.tensor(out_mean).to(args.device), torch.tensor(out_std).to(args.device)
field_mean, field_std = torch.tensor(field_mean).to(args.device), torch.tensor(field_std).to(args.device)

In [None]:
def get_net(model, dataset):
    """Build an untrained network for `model`, configured for `dataset`.

    Args:
        model: architecture name; one of "fno_fsm", "unet", "query_rn18",
            "query_unet". The model config file is ``<model>.yaml``.
        dataset: dataset config name without the ``.yaml`` suffix.

    Returns:
        The network produced by ``model_factory`` from the model config, with
        ``conditional`` taken from the dataset config.

    Raises:
        NotImplementedError: if `model` is not one of the supported names.
    """
    print(model, dataset)
    supported = ("fno_fsm", "unet", "query_rn18", "query_unet")
    if model not in supported:
        raise NotImplementedError(f"unsupported model {model!r}; expected one of {supported}")

    # Local names, so the notebook-level `args`/`config` are left untouched.
    model_args = get_args(["--config", dataset + ".yaml", "--model_cfg", model + ".yaml"])
    dataset_config = get_config(model_args.config)
    return model_factory(**get_config(model_args.model_cfg), conditional=dataset_config.conditional)

def get_field_prediction(batch, dataloader, net, device=None):
    """Run `net` on one batch and derive the normalised mean response from the predicted field.

    Args:
        batch: dict with "bead_patterns" and "sample_mat" tensors, as yielded
            by the loaders built above.
        dataloader: loader whose ``dataset`` supplies the normalisation
            statistics via ``extract_mean_std``.
        net: model called as ``net(image, condition)``; switched to eval mode.
        device: device to run on; defaults to the notebook-level ``args.device``.

    Returns:
        (prediction, pred_field), both on CPU. ``prediction`` is the mean
        response computed by ``get_mean_from_field_solution`` and normalised
        with the dataset's ``out_mean``/``out_std``; ``pred_field`` is a copy of
        the raw network output.
    """
    if device is None:
        device = args.device
    net.eval()
    with torch.no_grad():
        out_mean, out_std, field_mean, field_std = (
            torch.tensor(stat).to(device) for stat in extract_mean_std(dataloader.dataset)
        )
        image = batch["bead_patterns"].to(device)
        condition = batch["sample_mat"].to(device)
        prediction_field = net(image, condition)
        # Keep a separate copy of the raw field (kept from the original code;
        # presumably guards against in-place changes downstream — TODO confirm).
        pred_field = prediction_field.clone()
        prediction = get_mean_from_field_solution(prediction_field, field_mean, field_std)
        prediction.sub_(out_mean).div_(out_std)
    return prediction.cpu(), pred_field.cpu()
# Per-model mean-response and field predictions for the visualised batch.
preds, preds_field = {}, {}

In [None]:
# run only for transfer: set RUN_TRANSFER = True to evaluate on the transfer dataset.
RUN_TRANSFER = False
if RUN_TRANSFER:
    torch.set_default_tensor_type('torch.FloatTensor')
    # get_args builds a new namespace, so carry over the batch size chosen in the setup cell.
    previous_batch_size = args.batch_size
    args = get_args(["--config", "fsm_V5000.yaml", "--model_cfg", "fno_conditional.yaml"])
    args.batch_size = previous_batch_size
    config_transfer = get_config(args.config)
    # Reuse the reference data path of the original config.
    config_transfer.data_path_ref = config.data_path_ref
    trainloader, valloader, testloader, trainset, valset, testset = get_dataloader(
        args, config_transfer, logger=None, num_workers=0, shuffle=False
    )
    dataloader = testloader
    # Refresh the sample batch so later cells visualise data from the new loader.
    batch = next(iter(dataloader))

In [None]:
# Architectures whose best checkpoints are evaluated, in printed-row order.
EVAL_MODELS = ["fno_fsm", "unet", "query_unet"]


def evaluate_checkpoint(model, dataset, loader):
    """Load ``<base_path>/<model>/<dataset>/checkpoint_best`` and evaluate it on `loader`.

    Returns:
        (results, pred, pred_field): ``results`` is the metrics dict from
        ``evaluate`` with the full-loader prediction from ``_generate_preds``
        stored under "prediction"; ``pred``/``pred_field`` are the predictions
        for the notebook-level ``batch`` from ``get_field_prediction``.
    """
    net = get_net(model, dataset).cuda()
    checkpoint_path = f"{base_path}/{model}/{dataset}/checkpoint_best"
    net.load_state_dict(torch.load(checkpoint_path)["model_state_dict"])
    prediction, _, _ = _generate_preds(args, config, net, loader)
    results = evaluate(args, config, net, loader, report_peak_error=True, epoch=None, report_wasserstein=True)
    results.update({"prediction": prediction})
    pred, pred_field = get_field_prediction(batch, loader, net)
    return results, pred, pred_field


def format_results_row(results):
    """Format loss, Wasserstein distance, [r25, r75] and frequency distance as `&`-separated table cells."""
    return (
        f"{results['loss (test/val)']:4.2f} & {results['wasserstein']:4.2f} & "
        f"[{results['r25']:3.2f}, {results['r75']:3.2f}] & {results['frequency_distance']:3.1f}"
    )


results_by_model = {}
for model in EVAL_MODELS:
    results, preds[model], preds_field[model] = evaluate_checkpoint(model, difficulty, dataloader)
    results_by_model[model] = results
    print(format_results_row(results))

# Names the previous copy-pasted version of this cell defined.
results1 = results_by_model.get("fno_fsm")
results2 = results_by_model.get("query_unet")


# Generate visualization

In [None]:
N_FREQ_PLOT = 250  # number of leading frequencies kept for plotting


def _to_plot_array(x, n_freq=N_FREQ_PLOT):
    """Return `x` (tensor or array) as a NumPy array cut to the first `n_freq` entries of axis 1.

    Accepting arrays as well as tensors makes this cell safe to re-run: the
    original ``.numpy()`` call failed on values that were already converted.
    """
    if torch.is_tensor(x):
        x = x.detach().cpu().numpy()
    return np.asarray(x)[:, :n_freq]


# Convert every stored model prediction, so all entries share one type and length.
for store in (preds, preds_field):
    for key in list(store):
        store[key] = _to_plot_array(store[key])
preds["actual"] = _to_plot_array(batch["z_vel_mean_sq"])
preds_field["actual"] = _to_plot_array(batch["z_abs_velocity"])
bead_patterns = batch["bead_patterns"].numpy()


path = os.path.expanduser("~/network_folder/tmp/visualize_results.pt")
data = {
    "preds": preds,
    "preds_field": preds_field,
    "bead_patterns": bead_patterns
}
field_solution = data["preds_field"]["actual"]
prediction_field = data["preds_field"]["unet"]
prediction_field.shape

#torch.save(data, path)

In [None]:
# Figure size in inches (2.54 cm per inch): 8 cm wide, 90 % of 7.5 cm tall.
# The height, 6.75 / 2.54 ≈ 2.65748 in, matches the figure height of the prediction plot below.
(8 / 2.54, 0.9 * (7.5 / 2.54))


In [None]:
import matplotlib.pyplot as plt
save_dir = "plots/results"
# savefig does not create missing directories, so make sure the output folder exists.
os.makedirs(save_dir, exist_ok=True)

# Sample index within the stored batch and the frequency index to highlight.
idx, freq = 10, 147
# field_solution = batch["z_abs_velocity"].numpy()[:, :250]
# prediction_field = data["preds_field"]["query_unet"]
# eps = 1e-9
# field_solution = field_solution * field_std.cpu().numpy() + field_mean.cpu().numpy()[:, :250]
# field_solution = np.sqrt(np.exp(field_solution) )
# prediction_field = prediction_field * field_std.cpu().numpy() + field_mean.cpu().numpy()[:, :250]
# prediction_field = np.sqrt(np.exp(prediction_field))
# vmin = np.min((np.min(prediction_field[idx][freq]), np.min(field_solution[idx][freq])))
# vmax = np.max((np.max(prediction_field[idx][freq]), np.max(field_solution[idx][freq])))
# plt.imshow(bead_patterns[idx][0], cmap=plt.cm.gray)
# plt.tight_layout()
# plt.axis("off")
# plt.savefig(save_dir + "/bead_pattern.png", format='png', transparent=True)
# plt.show()
# plt.imshow(field_solution[idx][freq], cmap=plt.cm.gray, vmin=vmin, vmax=vmax)
# plt.tight_layout()
# plt.axis("off")
# plt.savefig(save_dir + "/solution.png", format='png', transparent=True)
# plt.show()
# plt.imshow(prediction_field[idx][freq], cmap=plt.cm.gray)
# plt.tight_layout()
# plt.axis("off")
# plt.savefig(save_dir + "/pred_field.png", format='png', transparent=True)
# plt.show()

fig, ax = plt.subplots(1, 1, figsize=(2.7 / 2 * 2.65748, 2.65748))
freq_axis = np.arange(0, 250)
fontsize = 10

# Dashed, thicker reference curve with the two semi-transparent predictions on top.
ax.plot(freq_axis, data["preds"]["actual"][idx], lw=2.5, color="#909090", linestyle='dashed', label="Reference")
ax.plot(freq_axis, data["preds"]["query_unet"][idx], color="#55a78c", lw=1.5, alpha=0.8, label="Query-UNet")
ax.plot(freq_axis, data["preds"]["unet"][idx], color="#e19c2c", lw=1.5, alpha=0.8, label="Grid-UNet")
#ax.plot(freq, data["preds"]["query_unet"][idx][freq], 'x', mew=3, markersize=40, color="r")
ax.set_yticks(np.arange(-1, 6, 2))
ax.tick_params(labelsize=fontsize)
ax.grid(which="major")
ax.set_ylim(-1, 6.5)
# Raw string: "\i" is an invalid escape sequence in a regular string literal.
ax.axvline(x=freq, color='red', linestyle='--', label=r'Frequency $\it{f}$')
ax.legend(fontsize=fontsize, loc="upper right")

ax.set_xlabel('Frequency', fontsize=fontsize)
#ax.set_ylabel('Amplitude', fontsize=40)
fig.tight_layout()

fig.savefig(os.path.join(save_dir, "pred.svg"), format='svg', dpi=600, transparent=True)

plt.show()

In [None]:

import torch
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, fixed

def plot_frequency(frequency=104, model="unet", idx=48, data=data):
    """Interactive view of one sample, with fields mapped back from the normalised space.

    Panels: the bead pattern; the predicted and reference fields at `frequency`,
    sharing one grey-scale range; and the predicted vs. reference response
    curves with `frequency` marked.

    Args:
        frequency: frequency index into the stored (truncated) arrays.
        model: key into data["preds"] and data["preds_field"].
        idx: sample index within the stored batch.
        data: dict with "preds", "preds_field" and "bead_patterns" as built above.
    """
    fig, axs = plt.subplots(1, 4, figsize=(20, 5))
    # Undo the field normalisation with the notebook-level statistics, then apply
    # sqrt(exp(.)); field_mean is cut to the same 250 leading frequencies as the stored fields.
    std = field_std.cpu().numpy()
    mean = field_mean.cpu().numpy()[:, :250]
    field_solution = np.sqrt(np.exp(data["preds_field"]["actual"] * std + mean))
    actual_frequency = data["preds"]["actual"]
    prediction = data["preds"][model]
    prediction_field = np.sqrt(np.exp(data["preds_field"][model] * std + mean))

    image = data["bead_patterns"][:, 0]
    # Shared colour range over the two field panels so they are directly comparable.
    vmin = np.min((np.min(prediction_field[idx][frequency]), np.min(field_solution[idx][frequency])))
    vmax = np.max((np.max(prediction_field[idx][frequency]), np.max(field_solution[idx][frequency])))
    images = [image[idx], prediction_field[idx][frequency], field_solution[idx][frequency]]
    titles = ["images", "prediction", "actual"]
    for ax, img, title in zip(axs.flat[:-1], images, titles):
        if title == "images":
            cax = ax.imshow(img, cmap=plt.cm.gray)
        else:
            cax = ax.imshow(img, cmap=plt.cm.gray, vmin=vmin, vmax=vmax)
            fig.colorbar(cax, ax=ax, label='Value')

        ax.set_title(title)

    axs[-1].plot(np.arange(250), actual_frequency[idx], label='Actual')
    axs[-1].plot(np.arange(250), prediction[idx], label='Prediction')
    axs[-1].axvline(x=frequency, color='red', linestyle='--', label='Selected Frequency')

    axs[-1].legend()
    plt.tight_layout()
    plt.show()
    return None

plot = interact(
    plot_frequency,
    # Slider covers exactly the samples stored in `data`.
    idx=(0, len(data["bead_patterns"]) - 1, 1),
    frequency=(0, 249, 1),
    # Offer only models that have stored predictions; a missing key would raise KeyError.
    model=[m for m in ("unet", "query_unet", "query_rn18") if m in data["preds"]],
    data=fixed(data),
)

In [None]:

import torch
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, fixed

def plot_frequency(frequency=104, model="unet", idx=48, data=data):
    """Interactive view of one sample using the stored (not de-normalised) values.

    Unlike a de-normalised view, the grey-scale range of the two field panels
    is taken over the whole prediction and reference arrays, so panels keep a
    fixed scale while the sliders move.

    Args:
        frequency: frequency index into the stored (truncated) arrays.
        model: key into data["preds"] and data["preds_field"].
        idx: sample index within the stored batch.
        data: dict with "preds", "preds_field" and "bead_patterns" as built above.
    """
    fig, axs = plt.subplots(1, 4, figsize=(20, 5))
    field_solution = data["preds_field"]["actual"]
    actual_frequency = data["preds"]["actual"]
    prediction = data["preds"][model]
    prediction_field = data["preds_field"][model]

    image = data["bead_patterns"][:, 0]
    # Get absolute min and max over both prediction_field and field_solution
    vmin = np.min((np.min(prediction_field), np.min(field_solution)))
    vmax = np.max((np.max(prediction_field), np.max(field_solution)))
    images = [image[idx], prediction_field[idx][frequency], field_solution[idx][frequency]]
    titles = ["images", "prediction", "actual"]

    for ax, img, title in zip(axs.flat[:-1], images, titles):
        if title == "images":
            cax = ax.imshow(img, cmap=plt.cm.gray)
        else:
            cax = ax.imshow(img, cmap=plt.cm.gray, vmin=vmin, vmax=vmax)
            fig.colorbar(cax, ax=ax, label='Value')

        ax.set_title(title)

    axs[-1].plot(np.arange(250), actual_frequency[idx], label='Actual')
    axs[-1].plot(np.arange(250), prediction[idx], label='Prediction')
    axs[-1].axvline(x=frequency, color='red', linestyle='--', label='Selected Frequency')

    axs[-1].legend()
    plt.tight_layout()
    plt.show()
    return None

plot = interact(
    plot_frequency,
    # Slider covers exactly the samples stored in `data` (previously hard-coded to 0..7).
    idx=(0, len(data["bead_patterns"]) - 1, 1),
    frequency=(0, 249, 1),
    # Offer only models that have stored predictions; a missing key would raise KeyError.
    model=[m for m in ("unet", "query_unet", "query_rn18") if m in data["preds"]],
    data=fixed(data),
)