In [3]:
import os
import numpy as np
from rinet.models import RINetV1Pipeline, RINetV2Pipeline, GMMPipeline, RefineRPipeline, ReflimRPipeline
import helpers as lib

# ---- prepare indirect estimation pipelines -------
# Registry of pipelines, keyed by the model name used throughout this notebook
# (the same key is used in cache file names and in adapter compatibility checks).
models = {
    # 'rinet_v1' and 'rinet_v1_log' are two separate RINetV1Pipeline instances;
    # both load the same weights below.
    # NOTE(review): presumably they differ only in how the dataset adapter
    # transforms the input for each key -- confirm in helpers.
    'rinet_v1': RINetV1Pipeline(),
    'rinet_v1_log': RINetV1Pipeline(),
    # 'rinet_v2' is a nested dict, not a pipeline: run_or_load_model picks the
    # '1d' or '2d' entry depending on the dataset key.
    'rinet_v2': {
        '1d': RINetV2Pipeline(),
        '2d': RINetV2Pipeline(ndim=2)
    },
    # No load() call is made for the three pipelines below.
    'gmm': GMMPipeline(),
    'reflimR': ReflimRPipeline(),
    'refineR': RefineRPipeline()
}
# Locations of the trained RINet models, relative to the notebook's working directory.
trained_model_path_v1 = '../../modeling/v1/model/'
trained_model_path_v2_1d = '../../modeling/v2/model_1d_tuned/'
trained_model_path_v2_2d = '../../modeling/v2/model_2d_tuned/'
models['rinet_v1'].load(trained_model_path_v1)  # load the trained v1 model for the 'rinet_v1' pipeline
models['rinet_v1_log'].load(trained_model_path_v1)  # same v1 model path, loaded into the 'rinet_v1_log' pipeline
# The v2 pipelines are loaded from an explicitly named model file inside each model directory.
models['rinet_v2']['1d'].load(trained_model_path_v2_1d, model_file='best_model.keras')
models['rinet_v2']['2d'].load(trained_model_path_v2_2d, model_file='best_model.keras')

# Root of the data tree, relative to the notebook's working directory.
# NOTE: data_dir already ends with '/', and every f-string below inserts another
# '/', so the resulting paths contain '//' (e.g. '../../data//liver/...').
data_dir = '../../data/'
# Dataset key -> pickle file to evaluate on. The keys combine the data source
# (liver / simulated), the dimensionality (1d / 2d) and, for the liver data,
# the outlier-removal variant (ornone / orsample / orpanel) matching the
# 'outlier_removal_none' / '_samplewise' / '_panelwise' directories.
# The '1d' / '2d' substring of the key is significant: run_or_load_model uses it
# to pick the rinet_v2 sub-model and to decide whether to post-process outputs.
# The insertion order of this dict is the order in which datasets are processed.
datasets = {
    'liver_1d_ornone': f"{data_dir}/liver/outlier_removal_none/1d/samples.pkl",
    'liver_2d_ornone': f"{data_dir}/liver/outlier_removal_none/2d/samples.pkl",
    'liver_1d_orsample': f"{data_dir}/liver/outlier_removal_samplewise/1d/samples.pkl",
    'liver_2d_orsample': f"{data_dir}/liver/outlier_removal_samplewise/2d/samples.pkl",
    'liver_1d_orpanel': f"{data_dir}/liver/outlier_removal_panelwise/1d/samples.pkl",
    'liver_2d_orpanel': f"{data_dir}/liver/outlier_removal_panelwise/2d/samples.pkl",
    'simulated_1d': f"{data_dir}/simulated/simulated_1d/test/original_data.pkl",
    'simulated_2d': f"{data_dir}/simulated/simulated_2d/test/original_data.pkl"
}

# Dataset key -> adapter object from helpers. The keys mirror those of `datasets`.
# In this notebook an adapter is asked whether a model applies to the dataset
# (is_compatible) and is used to convert samples before prediction
# (transform_in) and predictions afterwards (transform_out).
# Each key gets its own adapter instance, even where the adapter class is the same.
dataset_adapters = {
    "liver_1d_ornone": lib.Liver1DAdapter(),
    "liver_1d_orsample": lib.Liver1DAdapter(),
    "liver_1d_orpanel": lib.Liver1DAdapter(),
    "simulated_1d": lib.Simulated1DAdapter(),
    "liver_2d_ornone": lib.Liver2DAdapter(),
    "liver_2d_orsample": lib.Liver2DAdapter(),
    "liver_2d_orpanel": lib.Liver2DAdapter(),
    "simulated_2d": lib.Simulated2DAdapter(),
}


def run_or_load_model(out_dir, model_key, dataset_key, data):
    """Predict with one model on one dataset, caching the result on disk.

    The result is cached at ``{out_dir}/{model_key}_{dataset_key}.pkl``. If
    that file already exists it is loaded and the model is not run.

    Args:
        out_dir: Directory holding the cached prediction files. It is created
            if missing, but only after the model/dataset pair was validated.
        model_key: Key into the module-level ``models`` dict.
        dataset_key: Dataset identifier. Selects the adapter from
            ``dataset_adapters`` (falling back to ``lib.BaseDatasetAdapter()``
            for unknown keys) and, for ``rinet_v2``, the ``'1d'`` / ``'2d'``
            sub-model via a substring match.
        data: Iterable of samples. It is iterated once to build the model
            input and once more when post-processing 1D outputs, so it must be
            re-iterable (e.g. a list).

    Returns:
        A ``(p_ri, p_full)`` tuple, both when freshly computed and when read
        from the cache. ``p_full`` is the raw ``model.predict`` output.
        ``p_ri`` is ``None`` unless ``dataset_key`` contains ``'1d'``; in that
        case it is a list with one post-processed entry per prediction, with
        ``np.nan`` wherever the model returned ``None``.

    Raises:
        KeyError: If ``model_key`` is not in ``models``.
        ValueError: If ``model_key`` is ``"rinet_v2"`` and ``dataset_key``
            contains neither ``'1d'`` nor ``'2d'``, or if the adapter reports
            the model as incompatible with the dataset.
    """
    # Build the fallback adapter only when it is actually needed.
    adapter = dataset_adapters.get(dataset_key)
    if adapter is None:
        adapter = lib.BaseDatasetAdapter()
    model = models[model_key]

    # rinet_v2 is registered as a dict of sub-models keyed by dimensionality
    if model_key == "rinet_v2":
        if "1d" in dataset_key:
            model = model["1d"]
        elif "2d" in dataset_key:
            model = model["2d"]
        else:
            raise ValueError(f"Unexpected dataset key for rinet_v2: {dataset_key}")

    # check compatibility
    if not adapter.is_compatible(model_key):
        raise ValueError(f"Model {model_key} incompatible with {dataset_key}")

    # Touch the file system only after validation, so an invalid request
    # does not leave an empty output directory behind.
    os.makedirs(out_dir, exist_ok=True)
    out_path = f"{out_dir}/{model_key}_{dataset_key}.pkl"

    # load cache if exists
    # NOTE: the cache is a pickle; only point out_dir at files this notebook
    # wrote itself, since unpickling untrusted files can execute code.
    if os.path.exists(out_path):
        print("\t\tFound prediction file")
        cached = lib.load_pickle(out_path)
        # The cache is written as a list [p_ri, p_full]; convert it so a cache
        # hit returns the same tuple type as a fresh run.
        return tuple(cached) if isinstance(cached, list) else cached

    # prepare input
    data_in = [adapter.transform_in(sample, model_key) for sample in data]
    # adapters may optionally expose extra_kwargs(model_key) -> dict of predict kwargs
    kwargs = getattr(adapter, "extra_kwargs", lambda key: {})(model_key)

    # run model
    p_full = model.predict(data_in, **kwargs)
    if '1d' in dataset_key:
        # One entry per prediction; a None prediction is recorded as NaN.
        p_ri = [
            adapter.transform_out(lib.out_to_ri(model_key, pred), sample, model_key)
            if pred is not None else np.nan
            for pred, sample in zip(p_full, data)
        ]
    else:
        # no post-processed output is derived for non-1D datasets
        p_ri = None

    # save cache
    lib.save_pickle([p_ri, p_full], out_path)
    print("\t\tSaved prediction file")

    return p_ri, p_full


  saveable.load_own_variables(weights_store.get(inner_path))
  saveable.load_own_variables(weights_store.get(inner_path))


In [4]:
# Directory where run_or_load_model keeps its cached prediction files.
out_dir = 'predictions'

# Nested results: p[dataset_key][model_key] holds whatever run_or_load_model
# returned for that pair. Incompatible pairs are simply absent.
p = {}
for d in datasets:
    data = lib.load_pickle(datasets[d])
    p[d] = {}
    print(f"Processing... {d}")
    for m in models:
        # Skip models that this dataset's adapter reports as incompatible.
        if not dataset_adapters[d].is_compatible(m):
            continue
        print(f"\tRunning... {m}")
        p[d][m] = run_or_load_model(out_dir, m, d, data)


Processing... liver_1d_ornone
	Running... rinet_v1
		Found prediction file
	Running... rinet_v1_log
		Found prediction file
	Running... rinet_v2
		Saved prediction file
	Running... gmm
		Found prediction file
	Running... reflimR
		Found prediction file
	Running... refineR
		Found prediction file
Processing... liver_2d_ornone
	Running... rinet_v2
		Saved prediction file
	Running... gmm
		Found prediction file
Processing... liver_1d_orsample
	Running... rinet_v1
		Found prediction file
	Running... rinet_v1_log
		Found prediction file
	Running... rinet_v2
		Saved prediction file
	Running... gmm
		Found prediction file
	Running... reflimR
		Found prediction file
	Running... refineR
		Found prediction file
Processing... liver_2d_orsample
	Running... rinet_v2
		Saved prediction file
	Running... gmm
		Found prediction file
Processing... liver_1d_orpanel
	Running... rinet_v1
		Found prediction file
	Running... rinet_v1_log
		Found prediction file
	Running... rinet_v2
		Saved prediction file
	R