In [41]:
import numpy as np
from tensorflow.keras.models import load_model
from ml4h.models.model_factory import get_custom_objects
from ml4h.tensormap.ukb.survival import mgb_afib_wrt_instance2
from ml4h.tensormap.ukb.demographics import age_2_wide, af_dummy, sex_dummy3

In [42]:
# Output TensorMaps of the 12-lead multi-task model. Listed once so the name -> TensorMap
# lookup and the Keras custom-object registry (needed to deserialize the .h5) stay in sync.
model_tensormaps = [mgb_afib_wrt_instance2, age_2_wide, af_dummy, sex_dummy3]
output_tensormaps = {tm.output_name(): tm for tm in model_tensormaps}
custom_dict = get_custom_objects(model_tensormaps)
model = load_model('./ecg_5000_survival_curve_af_quadruple_task_mgh_v2021_05_21.h5', custom_objects=custom_dict)

# Placeholder input: uniform random noise, not a real ECG, so the predictions below only
# demonstrate the inference plumbing. A fixed seed makes the printed outputs reproducible.
# Shape (1, 5000, 12) is presumably (batch, time samples, leads) — confirm against model.input_shape.
rng = np.random.default_rng(0)
ecg = rng.random((1, 5000, 12))
prediction = model(ecg)

In [43]:
# Summarise every head of the 12-lead model. Heads whose TensorMap is a survival curve
# are collapsed to a single cumulative AF-risk number; every other head is printed as
# returned by the model.
for output_name, head_prediction in zip(model.output_names, prediction):
    tensor_map = output_tensormaps[output_name]
    if not tensor_map.is_survival_curve():
        print(f'{tensor_map} prediction is {head_prediction}')
        continue
    # The code treats the first half of the survival-curve vector as per-interval survival
    # values; their running product along axis 1 is the survival through each interval.
    intervals = tensor_map.shape[-1] // 2
    days_per_bin = 1 + tensor_map.days_window // intervals  # NOTE(review): computed but not used in this cell
    predicted_survivals = np.cumprod(head_prediction[:, :intervals], axis=1)
    # One minus survival at the final interval (first batch item) = risk over the whole window.
    print(f'AF Risk {tensor_map} prediction is: {str(1 - predicted_survivals[0, -1])}')
        

AF Risk TensorMap(survival_curve_af, (50,), survival_curve) prediction is: 0.3244773745536804
TensorMap(sex_from_wide, (2,), categorical) prediction is [[0.909781   0.09021906]]
TensorMap(age_from_wide_csv, (1,), continuous) prediction is [[-0.2924141]]
TensorMap(af_in_read, (2,), categorical) prediction is [[0.9474023  0.05259773]]


In [37]:
# Single-lead (lead I strip) AF survival model. It is deserialized with the custom objects
# built for the 12-lead model above and its heads are looked up in the same
# output_tensormaps dict — NOTE(review): assumes both models share those output TensorMaps.
lead_I_model = load_model('./strip_I_survival_curve_af_v2021_06_15.h5', custom_objects=custom_dict)

# Placeholder single-lead input (uniform random noise, not a real ECG), seeded for
# reproducible output. Shape (1, 5000, 1) is presumably (batch, time samples, leads) —
# confirm against lead_I_model.input_shape.
rng = np.random.default_rng(0)
ecg = rng.random((1, 5000, 1))
prediction = lead_I_model(ecg)
print('Lead I inference:')
# Pair predictions with *this* model's output names: pairing them with another model's
# output ordering would silently attach results to the wrong heads.
for name, pred in zip(lead_I_model.output_names, prediction):
    otm = output_tensormaps[name]

    if otm.is_survival_curve():
        # The code treats the first half of the survival-curve vector as per-interval
        # survival values; their running product is the survival through each interval.
        intervals = otm.shape[-1] // 2
        predicted_survivals = np.cumprod(pred[:, :intervals], axis=1)
        # 1 - survival at the final interval is the cumulative risk over the window,
        # not a survival probability, so label it as a risk (matching the 12-lead cell).
        af_risk = 1 - predicted_survivals[0, -1]
        print(f'AF Risk {otm.name} prediction is: {af_risk}')
    else:
        print(f'{otm.name} prediction is {pred}')

Lead I inference:
Predicted survival 0.1239469051361084
sex_from_wide prediction is [[0.92011654 0.0798834 ]]
age_from_wide_csv prediction is [[0.2401576]]
af_in_read prediction is [[0.9965096  0.00349043]]
