# Train linear NCEM

In [None]:
%load_ext autoreload
%autoreload 2

import ncem
import numpy as np
import seaborn as sns
import tensorflow as tf



from scipy.stats import ttest_rel, ttest_ind

# Make seaborn's "colorblind" palette the default for any plots drawn later.
sns.set_palette("colorblind")

# paths
# Every directory string ends with "/", so sub-directories are appended
# WITHOUT a leading slash (appending "/results/" would yield ".../Hartmann-2021//results/").
data_path_base = "../input-data/raw-data/"
out_path = "../output-data/Hartmann-2021/"
fn_out_cv = out_path + "results/"

## If errors occur with CUDA

In [None]:
# If `InternalError: libdevice not found at ./libdevice.10.bc [Op:__inference_one_e_step_2806]`
# --> try setting this environment variable in the shell that launches the notebook
#     (it is an environment variable of its own, not an entry of $PATH):
#        `export XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/lib/cuda`

# If `Attempting to perform BLAS operation using StreamExecutor without BLAS support`
# --> try setting a dedicated amount of GPU vram on the first GPU:
# Limit in MB. NOTE(review): an earlier inline note mentioned 1500 [MB]; confirm the intended value.
gpu_memory_limit_mb = 500
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        tf.config.experimental.set_virtual_device_configuration(
            gpus[0],
            [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=gpu_memory_limit_mb)],
        )
    except RuntimeError as e:
        # Surface the actual reason instead of a bare 'ERROR'.
        # NOTE(review): TensorFlow documents this error for devices that are already
        # initialized; if so, restart the kernel and run this cell before any other TF call.
        print(f"ERROR: could not set the GPU memory limit ({gpu_memory_limit_mb} MB): {e}")

# Dataset-specific inputs

In [None]:
# Identifier handed to the data loader as `data_origin`.
data_set = 'hartmann'
# rstrip('/') avoids a doubled slash whether or not data_path_base ends with "/".
data_path = data_path_base.rstrip('/') + '/Hartmann-2021/'
log_transform = False  # Hartmann DS is already arcsinh transformed
use_domain = True
scale_node_size = False
# NOTE(review): not forwarded anywhere visible in this notebook; the matching
# get_data argument is commented out. Confirm whether it is still needed.
merge_node_types_predefined = True
# Passed to get_data as `graph_covar_selection`; empty means no graph covariates are selected here.
covar_selection = []
output_layer = 'linear'


# Manual inputs

In [None]:
# --- model / optimiser choice ---
domain_type = 'patient'
model_class = 'interactions'
optimizer = 'adam'

# --- optimisation hyper-parameters ---
l1 = l2 = 0.  # forwarded to init_model as l1_coef / l2_coef
learning_rate = 0.05

# --- graph construction and batching ---
radius = 35
n_eval_nodes = 10
batch_size = 58

# Run identifier. The "tutorial" prefix is what the epoch setting in the
# training-schedule cell checks for to select a short run.
gs_id = "_".join(["tutorial", model_class, str(radius), data_set, domain_type])



# Model and training


In [None]:
# --- run length ---
ncv = 3
# A gs_id containing "tutorial" selects the short 10-epoch run; otherwise 2000 epochs.
epochs = 10 if "tutorial" in gs_id else 2000
epochs_warmup = 0
max_steps_per_epoch = 20

# --- early stopping and learning-rate schedule ---
patience = 100
lr_schedule_patience = 50
lr_schedule_factor = 0.5
lr_schedule_min_lr = 1e-10

# --- validation and shuffling ---
val_bs = 16
max_val_steps_per_epoch = 10
shuffle_buffer_size = None

# --- feature spaces handed to get_data ---
cond_feature_space_id = "type"
feature_space_id = "standard"

# --- optional covariates: all switched off ---
use_covar_node_label = use_covar_node_position = use_covar_graph_covar = False

In [None]:
# Create the interactions trainer and initialise its estimator.
trainer = ncem.train.TrainModelInteractions()
trainer.init_estim(log_transform=log_transform)

# Load the dataset into the estimator.
# Optional arguments deliberately left disabled in this notebook:
#   feature_transformation=transformation_dict[transform_key]
#   hold_out_covariate=hold_out_covariate
#   merge_node_types_predefined=merge_node_types_predefined
trainer.estimator.get_data(
    # where the data comes from
    data_path=data_path,
    data_origin=data_set,
    domain_type=domain_type,
    # graph construction
    radius=radius,
    # feature spaces
    node_feature_space_id=feature_space_id,
    node_label_space_id=cond_feature_space_id,
    # covariates
    graph_covar_selection=covar_selection,
    use_covar_graph_covar=use_covar_graph_covar,
    use_covar_node_label=use_covar_node_label,
    use_covar_node_position=use_covar_node_position,
)

In [None]:
# Split via split_data_node: 10 % validation, 10 % test, fixed seed (0).
trainer.estimator.split_data_node(seed=0, test_split=0.1, validation_split=0.1)

In [None]:
# Build the model from the settings chosen in the configuration cells.
trainer.estimator.init_model(
    # optimisation
    optimizer=optimizer,
    learning_rate=learning_rate,
    l1_coef=l1,
    l2_coef=l2,
    # model structure
    use_interactions=True,
    use_domain=use_domain,
    output_layer=output_layer,
    scale_node_size=scale_node_size,
    # evaluation
    n_eval_nodes_per_graph=n_eval_nodes,
)

# Print the layer summary of the training model.
trainer.estimator.model.training_model.summary()

In [None]:
# Fit the model; early stopping and LR reduction both monitor the validation loss.
trainer.estimator.train(
    # run length
    epochs=epochs,
    epochs_warmup=epochs_warmup,
    max_steps_per_epoch=max_steps_per_epoch,
    # batching
    batch_size=batch_size,
    shuffle_buffer_size=shuffle_buffer_size,
    validation_batch_size=val_bs,
    max_validation_steps=max_val_steps_per_epoch,
    # monitored quantity
    monitor_partition="val",
    monitor_metric="loss",
    # early stopping
    early_stopping=True,
    patience=patience,
    # learning-rate reduction on plateau
    reduce_lr_plateau=True,
    lr_schedule_patience=lr_schedule_patience,
    lr_schedule_factor=lr_schedule_factor,
    lr_schedule_min_lr=lr_schedule_min_lr,
)

In [None]:
# Evaluate on the estimator's test image keys and test node indices.
evaluation_test = trainer.estimator.evaluate_any(
    node_idx=trainer.estimator.nodes_idx_test,
    img_keys=trainer.estimator.img_keys_test,
)

In [None]:
# Display the test-partition evaluation returned by evaluate_any.
evaluation_test

In [None]:
# evaluate_per_node_type returns two objects, unpacked here as the per-node-type
# split and the per-node-type evaluation.
split_per_node_type, evaluation_per_node_type = trainer.estimator.evaluate_per_node_type()

In [None]:
# Display the evaluation stored under the 'Fibroblast' key.
# NOTE(review): assumes 'Fibroblast' is one of the node types in the result; confirm for this dataset.
evaluation_per_node_type['Fibroblast']