In [None]:
# This script creates Figure 7 of the paper.
# Run TIMEVIEW_interface_only.sh or Table_3.sh (from ../run_scripts) first to generate the results it loads.

%matplotlib inline
import sys
sys.path.append('../../')
from timeview.visualize import expert_tts_plot, grid_tts_plot
from experiments.datasets import load_dataset
from experiments.benchmark import load_column_transformer, create_benchmark_datasets_if_not_exist
from timeview.lit_module import load_model
from experiments.baselines import YNormalizer
from experiments.analysis.analysis_utils import find_results

dataset_name = "synthetic_tumor_wilkerson_1"
model_name = "TTS"

# Make sure the benchmark datasets are available before looking up results.
create_benchmark_datasets_if_not_exist(dataset_description_path="../dataset_descriptions")

# Previous experiment results for this dataset/model pair; the last entry is
# used below as the timestamp that identifies which run to load.
results = find_results(dataset_name, model_name)

if not results:
    # Without a result there is nothing to plot. Stop here with an actionable
    # message instead of failing later with an opaque IndexError on results[-1].
    raise RuntimeError(
        f"No results found for {dataset_name} and {model_name}. "
        "Make sure you run your experiments from ../run_scripts"
    )
if len(results) > 1:
    print("Multiple results found for the given dataset and model")
    print("We take the last one but it may produce unexpected results")

timestamp = results[-1]

# Everything needed for the plot is loaded from the benchmark run selected by
# ``timestamp``; all benchmark artefacts live under the same directory.
BENCHMARKS_DIR = "../benchmarks"

# NOTE(review): 661058651 is presumably one of the seeds the run was trained
# with — confirm against the run script that produced ``timestamp``.
litmodel = load_model(timestamp, seed=661058651, benchmarks_folder=BENCHMARKS_DIR)
dataset = load_dataset(dataset_name, dataset_description_path="../dataset_descriptions")
column_transformer = load_column_transformer(timestamp, benchmarks_dir=BENCHMARKS_DIR)
y_normalizer = YNormalizer.load_from_benchmark(
    timestamp, model_name, benchmark_dir=BENCHMARKS_DIR
)

# Values used for the features that are not being varied in the plot.
# These are keyed by the dataset's feature names, whereas
# ``display_feature_names`` below only supplies the shorter labels.
FIGURE_DEFAULT_VALUES = {
    'age': 52.4,
    'weight': 89.2,
    'initial_tumor_volume': 0.38,
    'dosage': 0.08,
}

# Kept as the cell's final expression so the notebook displays its result.
# (0.0, 2.0) is presumably the time interval that is plotted — TODO confirm.
expert_tts_plot(
    litmodel,
    dataset,
    (0.0, 2.0),
    n_points=100,
    figsize=(3.5, 3.5),
    column_transformer=column_transformer,
    y_normalizer=y_normalizer,
    display_feature_names=['age', 'weight', 'initial', 'dose'],
    default_values=FIGURE_DEFAULT_VALUES,
)