In [None]:
import numpy as np
from tensorflow.keras.models import load_model
from ml4h.models.model_factory import get_custom_objects
from ml4h.tensormap.ukb.demographics import age_in_days, af_dummy2, sex_dummy1
from ml4h.tensormap.ukb.survival import mgb_afib_wrt_instance2, mgb_death_wrt_instance2

In [None]:
# Map each model output name to the TensorMap describing it; the next cell uses
# this to decide how to decode each prediction head.
output_tensormaps = {
    tm.output_name(): tm
    for tm in [mgb_afib_wrt_instance2, mgb_death_wrt_instance2, age_in_days, af_dummy2, sex_dummy1]
}

MODEL_PATH = './ecg2af_quintuplet_v2024_01_13.keras'
SEED = 42  # fixed so the demo predictions are reproducible under Restart & Run All

# Supply the ml4h custom objects (losses/metrics/layers derived from the output
# TensorMaps) so Keras can deserialize the saved model; `get_custom_objects` is
# imported for exactly this purpose.
model = load_model(
    MODEL_PATH,
    custom_objects=get_custom_objects(list(output_tensormaps.values())),
)

# Placeholder input with shape (1, 5000, 12) -- presumably one ECG of
# 5000 time samples x 12 leads; TODO confirm against model.input_shape.
rng = np.random.default_rng(SEED)
ecg = rng.random((1, 5000, 12))
prediction = model(ecg)

In [None]:
# Decode every model head with its TensorMap: survival-curve heads are reduced
# to a single cumulative risk, all other heads are printed as-is.
for name, pred in zip(model.output_names, prediction):
    otm = output_tensormaps[name]
    if otm.is_survival_curve():
        # Only the first shape[-1] // 2 columns are used. They are treated as
        # per-interval survival probabilities (assumed -- confirm against the
        # ml4h survival TensorMap encoding), so their running product is the
        # survival curve and 1 - its final value is the risk over the window.
        intervals = otm.shape[-1] // 2
        predicted_survivals = np.cumprod(pred[:, :intervals], axis=1)
        # Label with the TensorMap instead of a hard-coded "AF" so the death
        # head is not reported as an AF risk.
        print(f'{otm} risk prediction is: {1 - predicted_survivals[0, -1]}')
    else:
        # np.asarray shows plain values rather than a backend tensor repr.
        print(f'{otm} prediction is {np.asarray(pred)}')
        