In [1]:
import os
import tensorflow as tf
import plotly.graph_objects as go

from heartkit.tasks import TaskFactory
from typing import Type, TypeVar
from argdantic import ArgField, ArgParser
from pydantic import BaseModel
from heartkit.utils import env_flag, set_random_seed, setup_logger
from plotly.subplots import make_subplots


from heartkit.tasks.AFIB_Ident.utils import (
    create_model,
    load_datasets,
    load_test_datasets,
    load_train_datasets,
    prepare,
)

from heartkit.defines import (
    HKDemoParams
)
from heartkit.tasks.AFIB_Ident.defines import (
    get_class_mapping,
    get_class_names,
    get_class_shape,
    get_classes,
    get_feat_shape,
)

# NOTE(review): `cli` is not used by any visible cell; it looks like a leftover
# from a command-line entry point and can probably be removed.
cli = ArgParser()
# Type variable for `parse_content`: any Pydantic model subclass.
B = TypeVar("B", bound=BaseModel)


def parse_content(cls: Type[B], content: str) -> B:
    """Validate JSON into a Pydantic model, reading it from disk when given a path.

    Args:
        cls (Type[B]): Pydantic model class used for validation.
        content (str): Path to an existing JSON file, or the JSON text itself.

    Returns:
        B: The validated instance of ``cls``.
    """
    # Anything that is not an existing file is treated as inline JSON text.
    if not os.path.isfile(content):
        return cls.model_validate_json(json_data=content)

    with open(content, "r", encoding="utf-8") as handle:
        file_text = handle.read()
    return cls.model_validate_json(json_data=file_text)


# Demo configuration; the path is relative to the notebook's working directory.
config = '../configs/arrhythmia-100class-2.json'
params = parse_content(HKDemoParams, config)


# Seed the RNGs and store the value returned by `set_random_seed` back on params.
params.seed = set_random_seed(params.seed)
# NOTE(review): hard-coded override of the config value; consider moving it into the JSON config.
params.data_parallelism = 8

class_names = get_class_names(params.num_classes)
class_map = get_class_mapping(params.num_classes)
# Tensor specs for one example: a float32 feature tensor and an int32 class tensor,
# both shaped from a single frame of `params.frame_size` samples.
# NOTE(review): the datasets below are loaded with `frame_size * 15`, while this spec
# uses `params.frame_size`; confirm that `load_datasets` does not rely on the spec shape.
input_spec = (
    tf.TensorSpec(shape=get_feat_shape(params.frame_size), dtype=tf.float32),
    tf.TensorSpec(shape=get_class_shape(params.frame_size, params.num_classes), dtype=tf.int32),
)
# Each loaded example spans 15 frames. Per the original note (frame_size=400 samples at
# sampling_rate=100), that is 400 / 100 * 15 = 60 seconds of signal per example.
datasets = load_datasets(
    ds_path=params.ds_path,
    frame_size=params.frame_size * 15,  # NOTE(review): magic 15 = frames per example
    sampling_rate=params.sampling_rate,
    class_map=class_map,
    spec=input_spec,
    datasets=params.datasets,
)

2024-03-20 11:14:09.392334: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-03-20 11:14:09.395969: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2024-03-20 11:14:09.434496: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-20 11:14:09.434540: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-03-20 11:14:09.435853: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to

In [2]:
from typing import List, Dict

import datetime
import random
import numpy as np
from tqdm import tqdm
from heartkit.rpc.backends import EvbBackend, PcBackend
from IPython.display import clear_output
from heartkit.tasks.AFIB_Ident.utils import (
    create_model,
    load_datasets,
    load_test_datasets,
    load_train_datasets,
    prepare,
)

from enum import IntEnum
from heartkit.defines import (
    HKDemoParams, HeartBeat, HeartRate, HeartRhythm, HeartSegment
)

from heartkit.tasks.AFIB_Ident.defines import (
    get_class_mapping,
    get_class_names,
    get_class_shape,
    get_classes,
    get_feat_shape,
)

class IcentiaRhythm(IntEnum):
    """Integer rhythm label codes for the Icentia dataset.

    Each member is translated to a project-wide ``HeartRhythm`` by ``HeartRhythmMap``.
    """
    noise = 0
    normal = 1
    afib = 2   # atrial fibrillation
    aflut = 3  # atrial flutter
    # NOTE(review): meaning of `end` is not shown in this notebook; HeartRhythmMap treats it as noise.
    end = 4

# Translate Icentia rhythm codes into the project-wide HeartRhythm labels.
# `IcentiaRhythm.end` is mapped to `HeartRhythm.noise`.
HeartRhythmMap = {
    IcentiaRhythm.noise: HeartRhythm.noise,
    IcentiaRhythm.normal: HeartRhythm.normal,
    IcentiaRhythm.afib: HeartRhythm.afib,
    IcentiaRhythm.aflut: HeartRhythm.aflut,
    IcentiaRhythm.end: HeartRhythm.noise,
}




def visualize_prediction(
        # patient_id,
        fig: go.Figure,
        ts,
        sub_x: np.ndarray,
        y_pred: np.ndarray,
        y_orig: np.ndarray,
        class_names: List[str],
        color_dict: Dict[int, str],
        row_idx: int = 1,
        frame_size: "int | None" = None,
        frames_per_row: int = 5,
):
    """Overlay per-frame predicted and reference labels on an ECG trace.

    ``sub_x`` is walked in consecutive windows of ``frame_size`` samples. For
    each window this adds to ``fig``:

    * a text label and a coloured band for the predicted class (``y_pred``),
    * a text label and a coloured band for the reference class (``y_orig``),
    * a red outlined band when the two disagree, and
    * the samples of that window as a line trace.

    Every ``frames_per_row`` windows the drawing moves down one subplot row.
    The row counter is advanced *before* the first window is drawn, so the
    first window lands in row ``row_idx + 1``. Pass ``row_idx=0`` to start at
    the first row of the figure.

    Args:
        fig: Figure with a subplot grid (e.g. from ``make_subplots``) that has
            enough rows for all windows.
        ts: Timestamps, one per sample of ``sub_x``; entries must support
            arithmetic with ``datetime.timedelta``.
        sub_x: 1-D signal to plot. It is also used to position the labels.
        y_pred: Predicted class index per *sample*; read at each window start.
        y_orig: Reference class index per *window*; read by window number.
        class_names: Display name per class index.
        color_dict: Colour per class index.
        row_idx: Row counter value before the first window (see above).
        frame_size: Samples per window. Defaults to ``params.frame_size``.
        frames_per_row: Windows drawn on each subplot row.

    Returns:
        The same ``fig`` object, with annotations, shapes and traces added.
    """
    if frame_size is None:
        frame_size = params.frame_size
    n_samples = sub_x.shape[0]
    if n_samples == 0:
        return fig

    primary_color = "#11acd5"
    # Label/band positions come from the plotted signal itself. These are
    # loop-invariant, so they are computed once rather than once per frame.
    x_min = np.min(sub_x)
    x_max = np.max(sub_x)
    band_gap = datetime.timedelta(seconds=0.1)  # visual gap between neighbouring bands

    for i in tqdm(range(0, n_samples, frame_size), desc="Inference"):
        if i % (frames_per_row * frame_size) == 0:
            # start a new row for the make_plots
            row_idx += 1

        # The final window is shifted left so it is always a full frame
        # (clamped at 0 when the whole signal is shorter than one frame).
        if i + frame_size > n_samples:
            start, stop = max(0, n_samples - frame_size), n_samples
        else:
            start, stop = i, i + frame_size

        pred_label = y_pred[start]
        orig_label = y_orig[i // frame_size]
        mid_t = ts[start] + (ts[stop - 1] - ts[start]) / 2
        band_end = ts[stop - 1] - band_gap

        # predicted results
        fig.add_annotation(
            x=mid_t,
            y=x_min * 0.8,
            text=class_names[pred_label],
            showarrow=False,
            row=row_idx,
            col=1,
            font=dict(color=color_dict[pred_label]),
        )
        fig.add_vrect(
            x0=ts[start],
            x1=band_end,
            y0=x_min / 3 + 0.2,
            y1=x_min / 3,
            fillcolor=color_dict[pred_label],
            opacity=0.25,
            line_width=0,
            row=row_idx,
            col=1,
            secondary_y=False,
        )

        # original (reference) results
        fig.add_annotation(
            x=mid_t,
            y=x_max,
            text=class_names[orig_label],
            showarrow=False,
            row=row_idx,
            col=1,
            font=dict(color=color_dict[orig_label]),
        )
        fig.add_vrect(
            x0=ts[start],
            x1=band_end,
            y0=0.9,
            y1=1.1,
            fillcolor=color_dict[orig_label],
            opacity=0.25,
            line_width=0,
            row=row_idx,
            col=1,
            secondary_y=False,
        )

        # prediction != original: outline the reference band in red
        if pred_label != orig_label:
            fig.add_vrect(
                x0=ts[start],
                x1=band_end,
                y0=0.9,
                y1=1.1,
                fillcolor="red",
                opacity=0.25,
                line_width=2,
                line_color="red",
                row=row_idx,
                col=1,
                secondary_y=False,
            )

        # finally add the ECG wave for this window
        fig.add_trace(
            go.Scatter(
                x=ts[start:stop],
                y=sub_x[start:stop],
                name="ECG",
                mode="lines",
                line=dict(color=primary_color, width=2),
                showlegend=False,
            ),
            row=row_idx,
            col=1,
            secondary_y=False,
        )
    return fig

In [3]:
# Pick the inference backend from the config: EvbBackend when params.backend == "evb",
# otherwise PcBackend.
BackendRunner = EvbBackend if params.backend == "evb" else PcBackend
runner = BackendRunner(params=params)

# input should be a SignalMetaGenerator
# Generator over the first test patient only, created with repeat=False and shuffle=False.
patient_ids = datasets[0].get_test_patient_ids()
single_pat_gen = datasets[0].uniform_patient_generator(patient_ids=[patient_ids[0]], repeat=False, shuffle=False)
single_pat_gen

<generator object IcentiaDataset.uniform_patient_generator at 0x1555511ed480>

In [4]:
# Select the first test patient (the same one `single_pat_gen` iterates over) and
# build a generator for the window between `x_start` and `x_end` of one of its segments.
patient_id = patient_ids[0]
# NOTE(review): `id` shadows the Python builtin; it is kept only so that any cell that
# still reads it keeps working. Prefer `patient_id` (used by the plotting cell's title).
id = patient_id
segment_id = 0
x_start = 0
x_end = 5000  # NOTE(review): presumably sample indices; confirm units and whether the end is exclusive


continuous_gen = datasets[0].signal_label_TimeFrame_generator(single_pat_gen, segment_id=segment_id, frame_start=x_start, frame_end=x_end)

In [6]:
# Pull the first item from the segment generator to inspect its structure.
# NOTE(review): the saved output shows a FileNotFoundError for the relative path
# 'datasets/icentia11k/p10000.h5'. The config is loaded from '../configs/...', so the
# notebook presumably runs one directory below the repo root while `params.ds_path` is
# relative to the root. TODO: confirm this and make ds_path absolute or cwd-independent.
# NOTE(review): `tmp` is a throwaway name; a descriptive name (e.g. `first_item`) would read better.
tmp = next(continuous_gen)
tmp

FileNotFoundError: [Errno 2] Unable to synchronously open file (unable to open file: name = 'datasets/icentia11k/p10000.h5', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)

In [None]:
# Plot colour for each class index, in insertion order -1, 0, 1, 2:
# grey, blue, purple, green.
color_dict = dict(zip(
    (-1, 0, 1, 2),
    ("#505050", "#11acd5", "#ce6cff", "#a1d34f"),
))



In [None]:
def segment_start_end_plot(id, seg_sig_gen, frame_start, frame_end):
    """Run frame-by-frame inference on one ECG segment and plot predictions vs. labels.

    The signal is split into windows of ``params.frame_size`` samples. Each
    window is passed through ``prepare`` and the global ``runner``, and the
    arg-max class is stored for every sample of that window. The per-frame
    agreement with the reference labels is printed. The result is drawn with
    ``visualize_prediction`` (five windows per subplot row), written to
    ``params.job_dir / "longer_demo.html"`` and shown.

    Args:
        id: Patient identifier; used only in the figure title.
        seg_sig_gen: Generator for the segment to plot (see the assumption
            noted where it is consumed below).
        frame_start: Not used by the body. The window is chosen when the
            generator is created; this argument is kept for interface compatibility.
        frame_end: Not used by the body; see ``frame_start``.

    Returns:
        The plotly figure that was written and shown.
    """
    import math

    color_dict = {
        -1: "#505050",  # Grey color for -1
        0: "#11acd5",  # Blue color for 0
        1: "#ce6cff",  # Purple color for 1
        2: "#a1d34f"   # Green color for 2
    }
    bg_color = "rgba(38,42,50,1.0)"
    plotly_template = "plotly_dark"
    # Must match the number of frames per row used inside `visualize_prediction`.
    frames_per_row = 5
    frame_size = params.frame_size

    # NOTE(review): assumes the generator yields `(signal, labels)` for the requested
    # window, with one reference label per frame of `params.frame_size` samples
    # (the indexing used by `visualize_prediction`). TODO confirm against
    # `signal_label_TimeFrame_generator`.
    sub_x, y_orig = next(seg_sig_gen)
    sub_x = np.asarray(sub_x)
    y_orig = np.asarray(y_orig)

    n_samples = sub_x.shape[0]
    n_frames = math.ceil(n_samples / frame_size)
    # Round up so the final partial row is not dropped (int() truncation overflowed the grid).
    nrow = max(1, math.ceil(n_frames / frames_per_row))

    tod = datetime.datetime(2024, 5, 24, random.randint(12, 23), 00)
    ts = np.array([tod + datetime.timedelta(seconds=i / params.sampling_rate) for i in range(n_samples)])

    y_pred = np.zeros(n_samples, dtype=int)
    matches = []  # per frame: 1 if the prediction matches the reference label, else 0
    for frame_idx, i in enumerate(tqdm(range(0, n_samples, frame_size), desc="Inference")):
        # The final window is shifted left so it is always a full frame.
        if i + frame_size > n_samples:
            start, stop = max(0, n_samples - frame_size), n_samples
        else:
            start, stop = i, i + frame_size

        xx = prepare(sub_x[start:stop], sample_rate=params.sampling_rate, preprocesses=params.preprocesses)
        runner.set_inputs(xx)
        runner.perform_inference()
        yy = runner.get_outputs()
        # this is the predicted label for current frame
        y_pred[start:stop] = np.argmax(yy, axis=-1).flatten()
        matches.append(int(y_pred[start] == y_orig[frame_idx]))

    if matches:
        print(f"Frame accuracy: {sum(matches) / len(matches):.3f} ({sum(matches)}/{len(matches)} frames)")

    fig = make_subplots(
            rows=nrow,
            cols=1,
            specs=[[{"colspan": 1, "type": "xy", "secondary_y": True}]] * nrow,
            subplot_titles=(None, None),
            horizontal_spacing=0.05,
            # make_subplots rejects vertical_spacing > 1 / (rows - 1); shrink it for tall figures.
            vertical_spacing=min(0.1, 0.5 / max(nrow - 1, 1)),
        )

    # row_idx=0: visualize_prediction advances the row before drawing the first frame,
    # so drawing starts at subplot row 1.
    fig = visualize_prediction(
        fig,
        ts,
        sub_x,
        y_pred,
        y_orig,
        class_names,
        color_dict,
        row_idx=0,
    )
    fig.update_layout(
        template=plotly_template,
        height=400*nrow,
        plot_bgcolor=bg_color,
        paper_bgcolor=bg_color,
        margin=dict(l=10, r=10, t=80, b=80),
        legend=dict(groupclick="toggleitem"),
        title=f"Patient ID: {id}, Segment ID: {segment_id}",
        title_x=0.5,
    )
    fig.write_html(params.job_dir / "longer_demo.html", include_plotlyjs="cdn", full_html=True)
    fig.show()
    return fig

In [None]:
# Frame-by-frame inference over the loaded signal, drawing predicted vs. reference
# labels onto the subplot figure `fig`.
# NOTE(review): this cell reads `x`, `sub_x`, `ts`, `y_pred`, `y_orig`, `fig`, `nrow`
# and `patient_id`, none of which is defined in a visible cell (the execution counts
# skip In [5]). TODO: restore that cell so Restart & Run All works.
# NOTE(review): the loop bounds use `sub_x`, while the model input and plotted trace
# use `x`; confirm both refer to the same signal.
row_idx = 0


bg_color = "rgba(38,42,50,1.0)"
primary_color = "#11acd5"
plotly_template = "plotly_dark"

frame_size = params.frame_size
frames_per_row = 5  # frames drawn on each subplot row before moving down a row

ratios = []  # per frame: 1 if the prediction matches the ground truth, else 0

# Loop-invariant label/band positions: computed once instead of once per frame.
x_min = np.min(x)
x_max = np.max(x)
band_gap = datetime.timedelta(seconds=0.1)  # small visual gap between neighbouring bands

try:
    for i in tqdm(range(0, sub_x.shape[0], frame_size), desc="Inference"):
        if i % (frames_per_row * frame_size) == 0:
            row_idx += 1
        # The final window is shifted left so it is always a full frame.
        if i + frame_size > sub_x.shape[0]:
            start, stop = max(0, sub_x.shape[0] - frame_size), sub_x.shape[0]
        else:
            start, stop = i, i + frame_size

        xx = prepare(x[start:stop], sample_rate=params.sampling_rate, preprocesses=params.preprocesses)
        runner.set_inputs(xx)
        runner.perform_inference()
        yy = runner.get_outputs()
        y_pred[start:stop] = np.argmax(yy, axis=-1).flatten()

        pred_label = y_pred[start]
        orig_label = y_orig[i // frame_size]
        ratios.append(int(pred_label == orig_label))
        mid_t = ts[start] + (ts[stop - 1] - ts[start]) / 2
        band_end = ts[stop - 1] - band_gap

        # predicted results
        fig.add_annotation(
            x=mid_t,
            y=x_min * 0.8,
            text=class_names[pred_label],
            showarrow=False,
            row=row_idx,
            col=1,
            font=dict(color=color_dict[pred_label]),
        )
        fig.add_vrect(
            x0=ts[start],
            x1=band_end,
            y0=x_min / 3 + 0.2,
            y1=x_min / 3,
            fillcolor=color_dict[pred_label],
            opacity=0.25,
            line_width=0,
            row=row_idx,
            col=1,
            secondary_y=False,
        )

        # original (reference) results
        fig.add_annotation(
            x=mid_t,
            y=x_max,
            text=class_names[orig_label],
            showarrow=False,
            row=row_idx,
            col=1,
            font=dict(color=color_dict[orig_label]),
        )
        fig.add_vrect(
            x0=ts[start],
            x1=band_end,
            y0=0.9,
            y1=1.1,
            fillcolor=color_dict[orig_label],
            opacity=0.25,
            line_width=0,
            row=row_idx,
            col=1,
            secondary_y=False,
        )

        # prediction != original: outline the reference band in red
        if pred_label != orig_label:
            fig.add_vrect(
                x0=ts[start],
                x1=band_end,
                y0=0.9,
                y1=1.1,
                fillcolor="red",
                opacity=0.25,
                line_width=2,
                line_color="red",
                row=row_idx,
                col=1,
                secondary_y=False,
            )

        fig.add_trace(
            go.Scatter(
                x=ts[start:stop],
                y=x[start:stop],
                name="ECG",
                mode="lines",
                line=dict(color=primary_color, width=2),
                showlegend=False,
            ),
            row=row_idx,
            col=1,
            secondary_y=False,
        )
finally:
    # Release the backend even if inference or plotting fails part-way through.
    runner.close()

if ratios:
    print(f"Frame accuracy: {sum(ratios) / len(ratios):.3f} ({sum(ratios)}/{len(ratios)} frames)")

fig.update_layout(
    template=plotly_template,
    height=400*nrow,
    plot_bgcolor=bg_color,
    paper_bgcolor=bg_color,
    margin=dict(l=10, r=10, t=80, b=80),
    legend=dict(groupclick="toggleitem"),
    title=f"Patient ID: {patient_id}, Segment ID: {segment_id}",
    title_x=0.5,
)

fig.write_html(params.job_dir / "longer_demo.html", include_plotlyjs="cdn", full_html=True)
fig.show()
