In [None]:
import os

import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib
import pickle
import errno
from ray_nn.data.lightning_data_module import DefaultDataModule
from ray_nn.nn.xy_hist_data_models import Model, MetrixXYHistSurrogate, StandardizeXYHist
#from ray_nn.nn.xy_hist_data_models_old import MetrixXYHistSurrogateOld
import glob
import torch
from tqdm.auto import tqdm
from ray_nn.nn.xy_hist_data_models import StandardizeXYHist
from ray_tools.simulation.torch_datasets import (
    BalancedMemoryDataset,
    HistDataset,
)
from scipy.stats import ttest_rel
%load_ext autoreload
%autoreload 2
import re

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Quick preview of a 4.3 x 1.7 inch figure filled with random data.
# A seeded generator keeps the preview identical on every Restart & Run All.
rng = np.random.default_rng(0)
x = np.linspace(0, 10, 100)
y = rng.random(100)

fig, ax = plt.subplots(figsize=(4.3, 1.7))
ax.plot(x, y, color='teal', linewidth=2)
ax.set_title("Random Plot", fontsize=10)
ax.set_xlabel("X-axis", fontsize=8)
ax.set_ylabel("Y-axis", fontsize=8)
fig.tight_layout()
plt.show()


In [None]:
def get_checkpoint_path(identifier, prefix="outputs/xy_hist", suffix="checkpoints"):
    """
    Get the checkpoint file path with the highest step for a given identifier.

    Only files whose name ends in ``step=<int>.ckpt`` (e.g.
    ``epoch=235-step=70000000.ckpt``) are considered; anything else in the
    directory is ignored.

    Args:
        identifier (str): The identifier for the checkpoint folder.
        prefix (str): The base directory where checkpoints are stored.
        suffix (str): The subdirectory containing checkpoint files.

    Returns:
        str: Full path to the checkpoint file with the highest step, or None if
        the directory contains no matching checkpoint file.

    Raises:
        FileNotFoundError: If ``<prefix>/<identifier>/<suffix>`` is not an existing directory.
    """
    base_path = os.path.join(prefix, identifier, suffix)
    # isdir() is also False for paths that do not exist, so one check covers both cases.
    if not os.path.isdir(base_path):
        raise FileNotFoundError(
            errno.ENOENT, os.strerror(errno.ENOENT), base_path)

    # Raw string: "\d" and "\." are regex escapes, not Python string escapes.
    step_pattern = re.compile(r"step=(\d+)\.ckpt$")

    highest_step = -1
    highest_ckpt = None
    # Sorted so that ties on the step number resolve deterministically.
    for file_name in sorted(os.listdir(base_path)):
        match = step_pattern.search(file_name)
        if match is None:
            continue
        step = int(match.group(1))
        if step > highest_step:
            highest_step = step
            highest_ckpt = file_name

    if highest_ckpt is None:
        return None
    return os.path.join(base_path, highest_ckpt)

def result_dict_to_latex(statistics_dict, reference_key="ref"):
    r"""
    Render per-scenario model statistics as a LaTeX table.

    Rows are models, columns are scenarios. Each cell shows the mean (bold when
    ``is_best``) followed by ``$\pm$`` and the standard deviation. Rows other than
    ``reference_key`` also show the confidence interval and a ``$\dagger$`` when
    ``is_significant`` is set.

    Args:
        statistics_dict (dict): ``{scenario: {model_key: (mean, std_dev, is_best,
            is_significant, (ci_low, ci_high))}}``. Every scenario is expected to
            hold the same model keys; row order follows the first scenario.
        reference_key (str): Model key whose row is printed without CI / dagger.

    Returns:
        str: LaTeX source of a ``tabular`` environment (fewer than four scenarios)
        or a ``tabularx`` environment spanning ``\textwidth`` (four or more).

    Raises:
        ValueError: If ``statistics_dict`` is empty.
    """
    if not statistics_dict:
        raise ValueError("statistics_dict must contain at least one scenario")

    def _sci(value):
        # Two-digit mantissa with the exponent's leading zero dropped: 1.00e-05 -> 1.00e-5.
        return f"{value:.2e}".replace("e+0", "e+").replace("e-0", "e-")

    # The column count comes from the argument, not from any notebook-level variable.
    n_scenarios = len(statistics_dict)
    if n_scenarios < 4:
        table_environment = "tabular"
        text_width = ""
        alignment = "l" * n_scenarios
    else:
        # Wide tables stretch to \textwidth with evenly sized, centered X columns.
        table_environment = "tabularx"
        text_width = r"{\textwidth}"
        alignment = r"*{" + str(n_scenarios) + r"}{>{\centering\arraybackslash}X}"

    output_string = (
        "\n    " + r"\begin{" + table_environment + "}" + text_width
        + "{p{2.5cm}|" + alignment + "}\n"
        + "    " + r"\hline" + "\n"
    )
    scenarios = [k + r" $\pm\sigma$ (\acs{CI})" for k in statistics_dict.keys()]
    output_string += " & ".join(["Metric"] + scenarios) + r" \\" + "\n"
    output_string += r"\hline" + "\n"

    model_keys = list(list(statistics_dict.values())[0].keys())

    for model_key in model_keys:
        model_row = [model_key]
        for scenario_stats in statistics_dict.values():
            mean, std_dev, is_best, is_significant, confidence_interval = scenario_stats[model_key]
            cell = _sci(mean)
            if is_best:
                cell = r"\textbf{" + cell + r"}"
            cell += r" $\pm$ " + _sci(std_dev)
            if model_key != reference_key:
                cell += f" ({_sci(confidence_interval[0])}, {_sci(confidence_interval[1])})"
                if is_significant:
                    cell += r"$\dagger$"
            model_row.append(cell)
        output_string += " & ".join(model_row) + r" \\" + "\n"

    output_string += r"\hline" + "\n    " + r"\end{" + table_environment + "}"
    return output_string
def evaluate_model_dict_to_result_dict(model_dict, scenarios=None):
    """
    Evaluate every model on every scenario subset.

    Args:
        model_dict (dict): ``{model_key: model}`` to evaluate.
        scenarios (dict, optional): ``{scenario_name: subset}``; each ``subset`` is
            forwarded to ``evaluate``. When omitted, the notebook-level
            ``metrics_dict`` is used (the original behaviour). Pass it explicitly so
            the call does not depend on another cell having run first.

    Returns:
        dict: ``{scenario_name: {model_key: tensor returned by evaluate(model, subset)}}``.
    """
    if scenarios is None:
        scenarios = metrics_dict
    return {
        scenario_name: {model_key: evaluate(model, scenario_subset)
                        for model_key, model in model_dict.items()}
        for scenario_name, scenario_subset in tqdm(scenarios.items())
    }

def significant_confidence_levels(group_A, group_B, confidence=0.99):
    """
    Paired t-test between two element-wise paired groups of values.

    Both tensors are flattened, moved to the CPU and passed to
    ``scipy.stats.ttest_rel``. The interval is for the mean of ``group_A - group_B``.

    Args:
        group_A (torch.Tensor): First sample.
        group_B (torch.Tensor): Second sample, paired element-wise with ``group_A``.
        confidence (float): Confidence level of the interval.

    Returns:
        tuple[bool, tuple[float, float]]: ``(is_significant, (low, high))``, where
        ``is_significant`` is True unless 0 lies strictly inside ``(low, high)``.
    """
    sample_a = group_A.flatten().cpu()
    sample_b = group_B.flatten().cpu()
    ci = ttest_rel(sample_a, sample_b).confidence_interval(confidence_level=confidence)
    confidence_interval = (ci.low.item(), ci.high.item())
    is_significant = not (confidence_interval[0] < 0.0 < confidence_interval[1])
    return is_significant, confidence_interval


def statistics(result_dict, reference_key="ref"):
    """
    Summarise per-model result tensors and compare each model against a reference.

    Args:
        result_dict (dict): ``{model_key: torch.Tensor}`` of per-sample values.
        reference_key (str): Key of the model every other model is t-tested against.

    Returns:
        dict: ``{model_key: (mean, std, is_best, is_significant, (ci_low, ci_high))}``.
        ``is_best`` marks the model with the lowest mean. For the reference model
        itself, no test is run and ``(False, (0.0, 0.0))`` is stored, because
        comparing a group with itself is meaningless.

    Raises:
        ValueError: If ``result_dict`` is empty or no mean is below +inf (e.g. all NaN).
        KeyError: If ``reference_key`` is not in ``result_dict``.
    """
    if not result_dict:
        raise ValueError("result_dict must contain at least one model")

    means = {key: value.mean().item() for key, value in result_dict.items()}

    # Strict "<" keeps the first key on ties and never selects NaN means.
    best_key = None
    best_mean = float('inf')
    for key, mean in means.items():
        if mean < best_mean:
            best_key, best_mean = key, mean
    if best_key is None:
        raise ValueError("no model has a finite mean (all NaN or +inf)")

    reference = result_dict[reference_key]
    statistics_dict = {}
    for key, value in result_dict.items():
        if key == reference_key:
            significance = (False, (0.0, 0.0))
        else:
            significance = significant_confidence_levels(value, reference)
        statistics_dict[key] = (means[key], value.std().item(), key == best_key) + significance
    return statistics_dict

def model_paths_to_model_dict(model_paths):
    """
    Load the highest-step checkpoint of each training run.

    Args:
        model_paths (dict): ``{model_key: identifier}``; each identifier is resolved
            with ``get_checkpoint_path`` (default folder ``outputs/xy_hist/<identifier>/checkpoints``).

    Returns:
        dict: ``{model_key: MetrixXYHistSurrogate}`` loaded from those checkpoints.

    Raises:
        FileNotFoundError: If a run's checkpoint folder is missing or holds no
            matching ``step=<int>.ckpt`` file.
    """
    models_dict = {}
    for key, identifier in model_paths.items():
        path = get_checkpoint_path(identifier)
        if path is None:
            # get_checkpoint_path returns None when the folder has no matching
            # checkpoint. Fail here, naming the run, instead of passing None to the loader.
            raise FileNotFoundError(
                f"no checkpoint found for model {key!r} (run {identifier!r})")
        models_dict[key] = MetrixXYHistSurrogate.load_from_checkpoint(
            checkpoint_path=path,
            map_location=None,
        )
    return models_dict
def evaluate(model, subset='good', load_len=2000000,
             dataset_ids=(13, 14),
             data_path="datasets/metrix_simulation/ray_emergency_surrogate_50+50+z+-30",
             batch_size=1024, num_workers=0):
    """
    Run ``model.test_step`` over one balanced subset and collect the results.

    Args:
        model: Surrogate exposing ``standardizer``, ``criterion``, ``device`` and ``test_step``.
        subset (str): Forwarded to ``BalancedMemoryDataset(subset=...)``; this notebook
            uses ``"good"`` and ``"bad"``.
        load_len (int): ``load_len // 2`` is forwarded as ``HistDataset(load_max=...)``.
        dataset_ids (sequence[int]): Dataset ids passed to ``HistDataset``.
        data_path (str): Directory holding the ``histogram_*.h5`` files.
        batch_size (int): Evaluation batch size (``batch_size_val`` of the data module).
        num_workers (int): DataLoader worker count.

    Returns:
        torch.Tensor: Concatenation over batches of ``test_step(...).mean(dim=-1)``.

    Raises:
        RuntimeError: If the subset yields no test batches.
    """
    # Unreduced MSE; presumably test_step applies self.criterion so its output keeps
    # per-element losses (TODO confirm). NOTE: this overwrites model.criterion and
    # does not restore it.
    model.criterion = torch.nn.MSELoss(reduction='none')
    standardizer = model.standardizer
    sub_groups = ['parameters', 'histogram/ImagePlane', 'n_rays/ImagePlane']
    transforms = [
        lambda x: x[:, 1:].float(),
        lambda x: standardizer(x.flatten(start_dim=1).float()),
        lambda x: x.int(),
    ]
    dataset = HistDataset(list(dataset_ids), data_path, "histogram_*.h5", sub_groups, transforms,
                          normalize_sub_groups=['parameters'], load_max=load_len // 2)
    memory_dataset = BalancedMemoryDataset(dataset=dataset, min_n_rays=1, subset=subset)
    del dataset  # drop this function's reference to the source dataset
    datamodule = DefaultDataModule(test_dataset=memory_dataset, num_workers=num_workers,
                                   batch_size_val=batch_size)

    output_list = []
    with torch.no_grad():
        for x, y in tqdm(datamodule.test_dataloader(), leave=False):
            x = x.to(model.device)
            y = y.to(model.device)
            output_list.append(model.test_step((x, y)).mean(dim=-1))
    if not output_list:
        raise RuntimeError(f"evaluate(): subset {subset!r} produced no test batches")
    return torch.cat(output_list)

In [None]:
# Scenario column label (LaTeX) -> subset name handed to `evaluate`.
metrics_dict = {
    r"Nonempty \acs{MSE}": "good",
    r"Empty \acs{MSE}": "bad",
}

# Table row label -> run identifier (folder under outputs/xy_hist).
model_paths = {
    "ref": "s021yw7n",
    "adam_w": "sias70j8",
    # Additional runs; uncomment to add them as table rows:
    # "gbp_2": "cuvbjeb9",
    # "gbp_4": "dbv0tc5h",
    # "plat_10": "2qbmgk8z",
    # "5_l": "d42xqlbe",
    # "6_l": "zpd30wi1",
    # "blow_1": "mwko1ldx",
    # "blow_3": "y8gw1ebu",
    # "min_n_1": "x6pch3kp",
    # "min_n_20": "yjqv9nlo",
    # "ReLU": "tqyfx4y5",
}

model_dict = model_paths_to_model_dict(model_paths)
result_dict = evaluate_model_dict_to_result_dict(model_dict)

In [None]:
# Per scenario: {model_key: (mean, std, is_best, is_significant, ci)} -> LaTeX table source.
statistics_dict = {}
for scenario_name, per_model_results in result_dict.items():
    statistics_dict[scenario_name] = statistics(per_model_results)
latex_table = result_dict_to_latex(statistics_dict)
print(latex_table)

In [None]:
def test_dataloader(path, dataset_ids=None, load_len=None, batch_size=128, num_workers=0,
                    standardizer=None, sub_groups=None, normalize_sub_groups=None,
                    file_pattern='histogram_*.h5', subset=None):
    """
    Build a test DataLoader over the histogram files in ``path``.

    Args:
        path (str): Directory containing the simulation files.
        dataset_ids (list[int], optional): Ids passed to ``HistDataset``; defaults to ``[10, 11]``.
        load_len (int, optional): Forwarded as ``HistDataset(load_max=...)``.
        batch_size (int): Used for both ``batch_size_train`` and ``batch_size_val``.
        num_workers (int): DataLoader worker count.
        standardizer (callable, optional): Applied to the flattened histogram;
            defaults to a fresh ``StandardizeXYHist()`` per call.
        sub_groups (list[str], optional): Sub-groups read from each file; defaults to
            ``['parameters', 'histogram/ImagePlane', 'n_rays/ImagePlane']``.
        normalize_sub_groups (list[str], optional): Defaults to ``['parameters']``.
        file_pattern (str): Glob for the files inside ``path``.
        subset (str, optional): If given, wrap the dataset in
            ``BalancedMemoryDataset(subset=...)``.

    Returns:
        The data module's ``test_dataloader()``.
    """
    # Defaults are created per call: list/object defaults in the signature are
    # evaluated once and would be shared (and mutable) across all calls.
    if dataset_ids is None:
        dataset_ids = [10, 11]
    if standardizer is None:
        standardizer = StandardizeXYHist()
    if sub_groups is None:
        sub_groups = ['parameters', 'histogram/ImagePlane', 'n_rays/ImagePlane']
    if normalize_sub_groups is None:
        normalize_sub_groups = ['parameters']

    transforms = [
        lambda x: x[:, 1:].float(),
        lambda x: standardizer(x.flatten(start_dim=1).float()),
        lambda x: x.int(),
    ]
    hist_dataset = HistDataset(dataset_ids, path, file_pattern, sub_groups, transforms,
                               normalize_sub_groups, load_max=load_len)
    if subset is not None:
        hist_dataset = BalancedMemoryDataset(hist_dataset, subset=subset)
    datamodule = DefaultDataModule(train_dataset=None, val_dataset=None, test_dataset=hist_dataset,
                                   batch_size_train=batch_size, batch_size_val=batch_size,
                                   num_workers=num_workers)
    return datamodule.test_dataloader()

def calculate_mses(model, dataloader):
    """
    Compute MSEs of ``model.model_orig`` over every batch of ``dataloader`` on the CPU.

    NOTE: assuming ``model_orig`` is a ``torch.nn.Module``, ``.cpu()`` moves it in
    place, so the network stays on the CPU after this call.

    Args:
        model: Wrapper exposing the network as ``model_orig``.
        dataloader: Iterable of batches whose items 0 and 1 are input and target.

    Returns:
        torch.Tensor: ``torch.hstack`` of the per-batch squared errors averaged over
        the last dimension.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    # Loop-invariant: move the network once rather than on every batch.
    cpu_model = model.model_orig.cpu()
    mses = []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            inputs, targets = batch[0], batch[1]
            mses.append(((cpu_model(inputs) - targets) ** 2).mean(dim=-1))
    if not mses:
        raise ValueError("calculate_mses: dataloader yielded no batches")
    return torch.hstack(mses)

def mse_statistics(path, mses_tensor, threshold=1e-5):
    """
    Plot a histogram of MSE values (log count axis) and save it to ``<path>/mse_hist.pdf``.

    Prints the number of values below ``threshold`` and the maximum value, and
    marks ``threshold`` with a dashed red vertical line (shown in the legend).

    Args:
        path (str): Output directory; created if it does not exist.
        mses_tensor (torch.Tensor): MSE values to histogram.
        threshold (float): Reference MSE used for the printed count and the marker line.

    Returns:
        matplotlib.figure.Figure: The histogram figure.
    """
    print((mses_tensor < threshold).sum(), mses_tensor.max())

    # LaTeX label such as "$x=1 \times 10^{-5}$" built from the threshold value.
    mantissa, exponent = f"{threshold:e}".split("e")
    mantissa = mantissa.rstrip("0").rstrip(".")
    threshold_label = rf"$x={mantissa} \times 10^{{{int(exponent)}}}$"

    # A dedicated figure: no need to clear whatever figure happened to be current.
    fig, ax = plt.subplots(figsize=(4.7, 3.0))
    ax.hist(mses_tensor.cpu(), bins=100)
    ax.axvline(x=threshold, color='red', linestyle='--', linewidth=1, label=threshold_label)
    ax.legend()
    ax.set_ylabel("Count [#]")
    ax.set_yscale('log')
    ax.set_xlabel("MSE")
    fig.tight_layout()

    if path:
        os.makedirs(path, exist_ok=True)
    fig.savefig(os.path.join(path, "mse_hist.pdf"), bbox_inches="tight")
    return fig

In [None]:
# Every path below is relative to root_dir ('' = the current working directory).
root_dir = ''

# Reference surrogate checkpoint and the simulation dataset it is scored on.
model_path = os.path.join(
    root_dir, "outputs/xy_hist/s021yw7n/checkpoints/epoch=235-step=70000000.ckpt"
)
path = os.path.join(
    root_dir, 'datasets/metrix_simulation/ray_emergency_surrogate_50+50+z+-30/'
)
model = Model(path=model_path)

# MSEs over the test_dataloader defaults for this dataset.
dataloader = test_dataloader(path=path)
mses_tensor = calculate_mses(model=model, dataloader=dataloader)

In [None]:
# Bind the returned figure: a bare Figure as the cell's last value would be rendered a second time.
mse_fig = mse_statistics(os.path.join(root_dir, "outputs/"), mses_tensor)

In [None]:
# First batch of the 'good' subset, for a visual surrogate-vs-simulation comparison.
# Separate name so the full-dataset `dataloader` from the MSE cell is not shadowed.
good_dataloader = test_dataloader(path, subset='good')
first_batch = next(iter(good_dataloader))
ground_truth = first_batch[1]
# NOTE(review): calculate_mses moved model.model_orig to the CPU; confirm model.device matches.
with torch.no_grad():
    prediction = model.model_orig(first_batch[0].to(model.device))


In [None]:
# Plot the first three predicted histograms against the simulated ground truth.
fig = MetrixXYHistSurrogate.plot_data_3(
    prediction[:3].detach().cpu().numpy(), ground_truth[:3].detach().cpu().numpy()
)
os.makedirs(os.path.join(root_dir, "outputs"), exist_ok=True)
# Same cropping as the re-plot cell that rewrites this file from the cached tensors.
fig.savefig(os.path.join(root_dir, "outputs/surrogate_vs_ray_ui.pdf"), bbox_inches='tight')

# Cache the tensors so the figure can be regenerated without re-running the model.
with open(os.path.join(root_dir, "outputs/prediction.pkl"), 'wb') as handle:
    pickle.dump(prediction, handle)
with open(os.path.join(root_dir, "outputs/ground_truth.pkl"), 'wb') as handle:
    pickle.dump(ground_truth, handle)

In [None]:
def _load_pickle(relative_path):
    """Unpickle ``os.path.join(root_dir, relative_path)``.

    pickle can execute arbitrary code on load; only use it on files this notebook wrote.
    """
    with open(os.path.join(root_dir, relative_path), 'rb') as handle:
        return pickle.load(handle)


prediction = _load_pickle("outputs/prediction.pkl")
ground_truth = _load_pickle("outputs/ground_truth.pkl")

In [None]:
# Re-plot the first three cached samples and overwrite the comparison PDF.
predicted_hists = prediction[:3].detach().cpu().numpy()
true_hists = ground_truth[:3].detach().cpu().numpy()
fig = MetrixXYHistSurrogate.plot_data_3(predicted_hists, true_hists)
fig.savefig(
    os.path.join(root_dir, "outputs/surrogate_vs_ray_ui.pdf"),
    bbox_inches='tight',
)