In [None]:
import pandas as pd
import numpy as np
import os
import SimpleITK as sitk
from IPython.display import Markdown
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots


from ipywidgets import VBox, HTML
from ipyevents import Event


In [None]:
# Root directory that the CSV paths and image paths in this notebook are joined onto.
mount_point = "/mnt/raid/C1_ML_Analysis/"
# Alternative input, presumably the table before the train split was taken (verify before switching):
# df_frames = pd.read_csv(os.path.join(mount_point, "CSV_files/c3_blindsweep_annotation_labels_merged.csv"))
# Annotation rows from the "_train" CSV; later cells read its id, file path, tag, pid,
# frame index and annotation label columns.
df_frames = pd.read_csv(os.path.join(mount_point, "CSV_files/c3_blindsweep_annotation_labels_merged_train.csv"))

In [None]:
# Show the available columns; the next cell picks the ones it needs by name.
df_frames.columns

In [None]:
# Names of the df_frames columns used by the rest of the notebook.
id_column, img_column, tag_column = "annotation_id", "file_path", "tag"
frame_column, frame_label = "frame_index", "annotation_label"
pid = "pid"

# One row per distinct (annotation id, image file, tag, pid) combination.
df = (
    df_frames.loc[:, [id_column, img_column, tag_column, pid]]
    .drop_duplicates()
    .reset_index(drop=True)
)

# Integer code for each annotation label: the position in this tuple is the code.
frame_label_dict = dict(
    zip(
        ('reject', 'low_visible', 'high_visible', 'low_measurable', 'high_measurable'),
        range(5),
    )
)

# Reverse lookup: integer code -> label name.
frame_label_dict_inv = dict(zip(frame_label_dict.values(), frame_label_dict.keys()))

In [None]:
def _count_frames(path):
    """Return the number of frames stored in the image file at `path`.

    Only the file header is read, so no pixel data is decoded. SimpleITK reports sizes
    fastest-axis first, which makes the last entry of `GetSize()` the axis that becomes
    axis 0 of `sitk.GetArrayFromImage(...)` -- the axis the annotations index into.
    """
    reader = sitk.ImageFileReader()
    reader.SetFileName(path)
    reader.ReadImageInformation()
    return int(reader.GetSize()[-1])


# Expand every annotated sweep into one row per frame. Frames that carry no annotation
# keep label code 0, which frame_label_dict maps to 'reject'.
annotated_frame_tables = []
for _, row in df.iterrows():

    uid = row[id_column]
    frames = df_frames[df_frames[id_column] == uid].sort_values(by=frame_column).reset_index(drop=True)

    n_frames = _count_frames(os.path.join(mount_point, row[img_column]))

    # Annotated indices that fall outside the sweep are clamped onto its first/last frame.
    frame_idx = np.clip(frames[frame_column].values.tolist(), 0, n_frames - 1)
    frame_labels_idx = [frame_label_dict[lbl] for lbl in frames[frame_label].values.tolist()]

    img_labels_np = np.zeros(n_frames)
    img_labels_np[frame_idx] = np.array(frame_labels_idx)

    # Path of each extracted frame: extract_frames_blind_sweeps/<image path without extension>/<n>.nrrd
    img_name = os.path.splitext(row[img_column])[0]
    file_path_frames = [
        os.path.join("extract_frames_blind_sweeps", img_name, f"{frame_no}.nrrd")
        for frame_no in range(n_frames)
    ]

    annotated_frame_tables.append(pd.DataFrame({
        "annotation_id": uid,
        "file_path": file_path_frames,
        "annotation_label": img_labels_np.astype(int).tolist()
    }))

# The same annotation id can appear on several rows of `df`, so identical rows are dropped.
df_annotated_frames = pd.concat(annotated_frame_tables).reset_index(drop=True)
df_annotated_frames = df_annotated_frames.drop_duplicates().reset_index(drop=True)

In [None]:
# Histogram of the integer label codes (0 = 'reject' ... 4 = 'high_measurable') over all frame rows.
df_annotated_frames.hist(column="annotation_label")

In [None]:
# Keep only the frames whose extracted file exists under mount_point. `.copy()` makes the
# result an independent frame, so assigning a column to it later does not go through
# pandas' chained-assignment path (SettingWithCopyWarning).
df_annotated_frames_filtered = df_annotated_frames[
    df_annotated_frames['file_path'].apply(lambda x: os.path.exists(os.path.join(mount_point, x)))
].copy()

In [None]:
# Turn the integer label codes back into label names. A new frame is built with `.assign`
# instead of writing a column into what may be a slice of df_annotated_frames.
# NOTE: not idempotent -- running this cell twice raises KeyError, because the column then
# already holds names, which are not keys of frame_label_dict_inv.
df_annotated_frames_filtered = df_annotated_frames_filtered.assign(
    annotation_label=df_annotated_frames_filtered['annotation_label'].apply(lambda x: frame_label_dict_inv[x])
)

In [None]:
# df_annotated_frames_filtered.to_csv(os.path.join(mount_point, "CSV_files/c3_blindsweep_annotation_labels_merged_frames_train.csv"), index=False)

In [None]:
# Per-frame predictions. The path is resolved under mount_point like the other
# test_output reads in this notebook, so the cell no longer depends on the working directory.
df_annotated_frames_filtered_pred = pd.read_csv(os.path.join(mount_point, "test_output/classification/c3_blindsweep_annotation_labels_merged_frames/epoch=9-val_loss=0.27/c3_blindsweep_annotation_labels_merged_frames_prediction.csv"))
len(df_annotated_frames_filtered_pred)

In [None]:
# Attach tag and pid to every prediction row. The CSV is resolved under mount_point like
# the other inputs, so the cell no longer depends on the working directory.
df_annotated = pd.read_csv(os.path.join(mount_point, "CSV_files/c3_blindsweep_annotation_labels_merged.csv"))[['annotation_id', 'tag', 'pid']].drop_duplicates().reset_index(drop=True)
df_annotated_frames_filtered_pred = df_annotated_frames_filtered_pred.merge(df_annotated, on="annotation_id", how="left")
# Compare with the length printed before the merge: a larger number means some
# annotation_id matched more than one (tag, pid) row and its frames were duplicated.
len(df_annotated_frames_filtered_pred)

In [None]:
# Columns available after the merge; later cells filter on annotation_label, pred_class and tag.
df_annotated_frames_filtered_pred.columns

In [None]:
def read_seq(df, sample=-1):
    """Load the images listed in ``df['file_path']`` and stack them into one array.

    Parameters
    ----------
    df : pandas.DataFrame
        Must have a ``file_path`` column; each entry is joined onto ``mount_point``.
    sample : int, default -1
        When positive and smaller than the number of rows, only ``sample`` rows are
        loaded. They are drawn without replacement with ``np.random.choice`` (global
        NumPy random state) and returned in drawn order, not in the order of ``df``.
        Any other value loads every row.

    Returns
    -------
    numpy.ndarray
        ``np.array`` of the per-file arrays in load order. For images with more than
        one component per pixel only component 0 is kept.
    """
    file_paths = df['file_path'].tolist()

    # Pick the subset before touching the disk. The drawn indices depend only on the
    # row count, so the result is the same as sampling after loading everything, but
    # only the selected files are read.
    if sample > 0 and len(file_paths) > sample:
        ridx = np.random.choice(len(file_paths), sample, replace=False)
        file_paths = [file_paths[i] for i in ridx]

    img_seq = []
    for fn in file_paths:
        img = sitk.ReadImage(os.path.join(mount_point, fn))
        img_np = sitk.GetArrayFromImage(img)
        if img.GetNumberOfComponentsPerPixel() > 1:
            img_np = img_np[..., 0]
        img_seq.append(img_np)

    return np.array(img_seq)



In [None]:
def plot_3d_array_with_arrows(img_np, axis=0, title="3D volume viewer"):
    """Build a keyboard-driven slice browser for a 3D array.

    Parameters
    ----------
    img_np : numpy.ndarray
        Array to browse. Each slice taken along ``axis`` is drawn as a heatmap, so the
        slices are expected to be 2D (i.e. ``img_np`` is 3D).
    axis : int, default 0
        Axis to step through.
    title : str
        Prefix of the figure title; the current slice number is appended.

    Keys (click inside the output area first so it receives key events):
      ArrowRight / ArrowLeft -> next / previous slice
      ArrowUp / ArrowDown    -> jump +10 / -10 slices
      Home / End             -> first / last slice

    Returns
    -------
    ipywidgets.VBox
        A status line above the plotly ``FigureWidget``.

    Raises
    ------
    ValueError
        If ``img_np`` has no elements.
    """
    n_slices = img_np.shape[axis]
    if img_np.size == 0:
        # Without this check the min()/max() below fails with an opaque reduction error.
        raise ValueError("img_np is empty; there is nothing to display")

    # One fixed colour range for the whole volume so brightness is comparable across slices.
    vmin, vmax = float(img_np.min()), float(img_np.max())

    idx = 0
    slice2d = img_np.take(indices=idx, axis=axis)

    # FigureWidget so the heatmap can be updated in place.
    # Rows are flipped because plotly draws heatmap row 0 at the bottom by default.
    fig = go.FigureWidget(
        data=[
            go.Heatmap(
                z=np.flip(slice2d, axis=0),
                colorscale="Gray",
                zmin=vmin,
                zmax=vmax,
                showscale=True,
            )
        ],
        layout=go.Layout(
            title=f"{title} — slice {idx+1}/{n_slices} (axis={axis})",
            width=450,
            height=450,
            margin=dict(l=10, r=10, t=50, b=10),
        ),
    )

    status = HTML(value="Click inside the output area once, then use arrow keys.")
    box = VBox([status, fig])

    # Keyboard event capture on the whole box.
    ev = Event(
        source=box,
        watched_events=["keydown"],
        prevent_default_action=True,
        bubbles=True,
    )

    # Mutable holder so the nested callbacks can share the current slice index.
    state = {"idx": idx}

    def update(new_idx):
        # Requests past either end are clamped onto the first/last slice.
        new_idx = int(np.clip(new_idx, 0, n_slices - 1))
        state["idx"] = new_idx
        with fig.batch_update():
            fig.data[0].z = np.flip(np.take(img_np, indices=new_idx, axis=axis), axis=0)
            fig.layout.title = f"{title} — slice {new_idx+1}/{n_slices} (axis={axis})"
        status.value = (
            "Click inside the output area once, then use arrow keys. "
            f"Current slice: {new_idx+1}/{n_slices}"
        )

    relative_steps = {"ArrowRight": 1, "ArrowLeft": -1, "ArrowUp": 10, "ArrowDown": -10}
    absolute_targets = {"Home": 0, "End": n_slices - 1}

    def handle_event(event):
        key = event.get("key", "")
        if key in relative_steps:
            update(state["idx"] + relative_steps[key])
        elif key in absolute_targets:
            update(absolute_targets[key])
        # Any other key is ignored.

    ev.on_dom_event(handle_event)
    return box

In [None]:
def plot_3d_arrays_with_arrows_side_by_side(
    img_np_a,
    img_np_b,
    axis=0,
    title_a="Sweep A",
    title_b="Sweep B",
    main_title="Sweep viewer",
    share_contrast=True,
    width=900,
    height=450,
):
    """
    Display two 3D numpy arrays side-by-side with synchronized arrow-key slicing.

    Both panels always show the same slice index, so only the first
    ``min(n_slices_a, n_slices_b)`` slices along ``axis`` can be reached.

    Keys (click inside the output area first so it receives key events):
      ArrowRight / ArrowLeft -> next / previous slice
      ArrowUp / ArrowDown    -> jump +10 / -10 slices
      Home / End             -> first / last reachable slice

    share_contrast:
      True  -> common vmin/vmax for both viewers, one colorbar
      False -> each viewer uses its own vmin/vmax and gets its own colorbar,
               labelled with the panel title

    Returns an ipywidgets VBox holding a status line and the plotly FigureWidget.

    Raises AssertionError if an input is not 3D or ``axis`` is not 0, 1 or 2, and
    ValueError if either input has no elements.
    """
    assert img_np_a.ndim == 3 and img_np_b.ndim == 3, "Both inputs must be 3D numpy arrays"
    assert axis in (0, 1, 2), "axis must be 0, 1, or 2"
    if img_np_a.size == 0 or img_np_b.size == 0:
        # Without this check the min()/max() below fails with an opaque reduction error.
        raise ValueError("img_np_a and img_np_b must not be empty; there is nothing to display")

    n_slices_a = img_np_a.shape[axis]
    n_slices_b = img_np_b.shape[axis]
    n_slices = min(n_slices_a, n_slices_b)

    if share_contrast:
        vmin = float(min(img_np_a.min(), img_np_b.min()))
        vmax = float(max(img_np_a.max(), img_np_b.max()))
        vmin_a = vmin_b = vmin
        vmax_a = vmax_b = vmax
        # Both traces use the same range, so one colorbar describes both panels.
        scale_opts_a = dict(showscale=False)
        scale_opts_b = dict(showscale=True)
    else:
        vmin_a, vmax_a = float(img_np_a.min()), float(img_np_a.max())
        vmin_b, vmax_b = float(img_np_b.min()), float(img_np_b.max())
        # plotly places every colorbar at the same default position, so two of them
        # would be drawn on top of each other. Put them next to each other instead.
        scale_opts_a = dict(showscale=True, colorbar=dict(x=1.02, title=dict(text=title_a, side="right")))
        scale_opts_b = dict(showscale=True, colorbar=dict(x=1.17, title=dict(text=title_b, side="right")))

    idx = 0
    slice_a = np.take(img_np_a, indices=idx, axis=axis)
    slice_b = np.take(img_np_b, indices=idx, axis=axis)

    # Subplots: 1 row, 2 columns
    fig = go.FigureWidget(
        make_subplots(
            rows=1, cols=2,
            subplot_titles=(title_a, title_b),
            horizontal_spacing=0.05
        )
    )

    # Rows are flipped because plotly draws heatmap row 0 at the bottom by default.
    fig.add_trace(
        go.Heatmap(
            z=np.flip(slice_a, axis=0),
            colorscale="Gray",
            zmin=vmin_a, zmax=vmax_a,
            **scale_opts_a
        ),
        row=1, col=1
    )

    fig.add_trace(
        go.Heatmap(
            z=np.flip(slice_b, axis=0),
            colorscale="Gray",
            zmin=vmin_b, zmax=vmax_b,
            **scale_opts_b
        ),
        row=1, col=2
    )

    fig.update_layout(
        title=f"{main_title} — slice {idx+1}/{n_slices} (axis={axis})",
        width=width,
        height=height,
        margin=dict(l=10, r=10, t=60, b=10),
    )

    status = HTML(value="Click inside the output area once, then use arrow keys.")
    box = VBox([status, fig])

    ev = Event(
        source=box,
        watched_events=["keydown"],
        prevent_default_action=True,
        bubbles=True,
    )

    # Mutable holder so the nested callbacks can share the current slice index.
    state = {"idx": idx}

    def update(new_idx):
        # Requests past either end are clamped onto the first/last reachable slice.
        new_idx = int(np.clip(new_idx, 0, n_slices - 1))
        state["idx"] = new_idx

        new_slice_a = np.flip(np.take(img_np_a, indices=new_idx, axis=axis), axis=0)
        new_slice_b = np.flip(np.take(img_np_b, indices=new_idx, axis=axis), axis=0)

        with fig.batch_update():
            fig.data[0].z = new_slice_a  # left heatmap
            fig.data[1].z = new_slice_b  # right heatmap
            fig.layout.title = f"{main_title} — slice {new_idx+1}/{n_slices} (axis={axis})"

        status.value = (
            "Click inside the output area once, then use arrow keys. "
            f"Current slice: {new_idx+1}/{n_slices}"
        )

    relative_steps = {"ArrowRight": 1, "ArrowLeft": -1, "ArrowUp": 10, "ArrowDown": -10}
    absolute_targets = {"Home": 0, "End": n_slices - 1}

    def handle_event(event):
        key = event.get("key", "")
        if key in relative_steps:
            update(state["idx"] + relative_steps[key])
        elif key in absolute_targets:
            update(absolute_targets[key])
        # Any other key is ignored.

    ev.on_dom_event(handle_event)
    return box

In [None]:
# Count of each pred_class within every annotation_label, rendered as a Markdown table.
Markdown(df_annotated_frames_filtered_pred.groupby("annotation_label")['pred_class'].value_counts().to_markdown())

In [None]:
# Browse frames annotated "high_measurable" with pred_class 12; read_seq draws at most
# 1000 of them at random.
# Earlier selection (view_seq is not defined in the visible part of this notebook):
# img_seq_np = view_seq('annotation_label == "reject" and pred_class == 29')
img_seq_np = read_seq(df_annotated_frames_filtered_pred.query('annotation_label == "high_measurable" and pred_class == 12'), sample=1000)    

# Reference: (annotation_label, pred_class) and row count, presumably copied from the
# value_counts table in the previous cell:
# ('high_measurable', 12)	709
# ('high_measurable', 29)	522
# ('high_measurable', 28)	419
# ('high_measurable', 13)	333
# ('high_measurable', 4)	79
# ('high_measurable', 23)	2
# ('high_measurable', 18)
# NOTE(review): the line below passes a query string where read_seq expects a DataFrame,
# so it would not run as written.
# img_seq_np = read_seq('annotation_label == "high_visible" and pred_class == 10')

plot_3d_array_with_arrows(img_seq_np, axis=0)


In [None]:
# Prediction CSV from the "c3_ac_only" output folder; later cells use its file_path and pred_class columns.
df_ac = pd.read_csv(os.path.join(mount_point, "test_output/classification/c3_ac_only/epoch=9-val_loss=0.27/extract_frames_blind_sweeps_Dataset_C3_masked_resampled_256_spc075_merged_balanced_ac_only_file_path_prediction.csv"))

In [None]:
# Number of rows per predicted class in df_ac.
df_ac['pred_class'].value_counts()

In [None]:

# Compare "high_measurable" frames with pred_class 13 (left) against df_ac frames with
# pred_class 13 (right).
# NOTE(review): whenever read_seq subsamples (more rows than `sample`), it returns frames
# in randomly drawn order, so the sort_values(by=['annotation_id']) ordering is not kept.
img_seq_np = read_seq(df_annotated_frames_filtered_pred.query('annotation_label == "high_measurable" and pred_class == 13').sort_values(by=['annotation_id']), sample=200)
# img_seq_np_b = read_seq(df_annotated_frames_filtered_pred.query('annotation_label == "low_measurable" and pred_class == 13'), sample=200)
img_ac_np = read_seq(df_ac.query('pred_class == 13'), sample=500)

# The two panels step through the same slice index; frames sharing an index are not
# otherwise related, and only min(len(left), len(right)) indices are reachable.
# plot_3d_array_with_arrows(img_ac_np, axis=0)
plot_3d_arrays_with_arrows_side_by_side(
    img_seq_np, img_ac_np, axis=0,
    title_a="high_measurable 13", title_b="ac 13",
    main_title="Compare"
)


In [None]:
# Prediction CSV from the "c3_ac_non_ac_balanced" output folder; later cells use its
# file_path, pid, tag and pred_class columns.
df_c3_ac_nonac = pd.read_csv(os.path.join(mount_point, "test_output/classification/c3_ac_non_ac_balanced/epoch=9-val_loss=0.27/extract_frames_blind_sweeps_Dataset_C3_masked_resampled_256_spc075_merged_balanced_ac_nonac_prediction.csv"))

In [None]:
# Distinct tag values present in df_c3_ac_nonac.
df_c3_ac_nonac['tag'].unique()

In [None]:

# Restrict to one tag at a time and list how many rows fall into each pred_class.
# Wider selection used previously:
# df_c3_ac_nonac_filterd = df_c3_ac_nonac[df_c3_ac_nonac['tag'].isin(['AC', 'BPD', 'TCD', 'FL', 'HL', 'CRL'])]
df_c3_ac_nonac_filterd = df_c3_ac_nonac[df_c3_ac_nonac['tag'].isin(['FL'])]

# Author's note on pred_class values seen for the FL tag:
# FL: 20, 14, 2, 5, 7

Markdown(df_c3_ac_nonac_filterd[['tag', 'pred_class']].value_counts().to_markdown())

In [None]:
# Browse every frame tagged "FL" with pred_class 2 (no `sample` cap, so all matching files are read).
# Earlier selection (view_seq is not defined in the visible part of this notebook):
# img_seq_np = view_seq('annotation_label == "reject" and pred_class == 29')
img_seq_np = read_seq(df_c3_ac_nonac.query('tag == "FL" and pred_class == 2'))

# NOTE(review): the notes below are about annotation_label == "high_measurable" and look
# like leftovers from the earlier browsing cell; they do not describe this FL selection.
# ('high_measurable', 12)	709
# ('high_measurable', 29)	522
# ('high_measurable', 28)	419
# ('high_measurable', 13)	333
# ('high_measurable', 4)	79
# ('high_measurable', 23)	2
# ('high_measurable', 18)
# img_seq_np = read_seq('annotation_label == "high_visible" and pred_class == 10')

plot_3d_array_with_arrows(img_seq_np, axis=0)

In [None]:


# Score -> selection rules over (tag, pred_class). Every row matching a rule receives the
# score under which the rule is listed. These rules have no 'samples' entry, so all
# matching rows are kept.
ac_rank_ac = {
    1.0: [{'tag': 'AC', 'pred_class': cls} for cls in (29, 28)],
    0.9: [{'tag': 'AC', 'pred_class': 13}],
    0.8: [{'tag': 'AC', 'pred_class': 12}],
    0.0: (
        [{'tag': 'BPD', 'pred_class': cls} for cls in (6, 11, 25, 26, 32, 16)]
        + [{'tag': tag, 'pred_class': cls} for tag in ('FL', 'HL') for cls in (20, 14, 2)]
    ),
}

# Score -> selection rules over (annotation_label, pred_class). 'samples' is the number of
# matching rows to draw for that rule. Each tuple below is (label, pred_class, samples).
ac_rank_annot = {
    score: [
        {'annotation_label': label, 'pred_class': cls, 'samples': count}
        for label, cls, count in rules
    ]
    for score, rules in (
        (0.8, [('high_measurable', 12, 200), ('high_measurable', 13, 100)]),
        (0.75, [('low_measurable', 13, 500)]),
        (0.7, [('low_measurable', 12, 500)]),
        (0.6, [('high_visible', 13, 1000)]),
        (0.5, [('high_visible', 12, 1000), ('low_visible', 13, 1000), ('low_measurable', 4, 50)]),
        (0.4, [('high_visible', 4, 500), ('low_measurable', 10, 10)]),
        (0.3, [('low_visible', 4, 2000), ('low_visible', 12, 500), ('high_visible', 10, 100)]),
        (0.2, [('low_visible', 10, 1000), ('reject', 4, 1000)]),
        (0.0, [
            ('reject', cls, count)
            for cls, count in (
                (14, 1000), (5, 1000), (21, 1000), (24, 100), (23, 100), (1, 200),
                (33, 1000), (15, 100), (22, 100), (10, 1000), (30, 100), (20, 1000),
                (18, 500), (7, 1000), (0, 1000), (31, 100), (8, 100), (9, 100),
                (27, 1000), (3, 100), (2, 100), (32, 1000), (17, 100), (6, 1000),
                (11, 954), (34, 100), (16, 700), (26, 556), (19, 374), (25, 338),
            )
        ]),
    )
}

In [None]:


def _rule_to_query(rule):
    """Turn a rule dict into a DataFrame.query string, skipping the 'samples' entry.

    String values are single-quoted and other values are inserted as-is, e.g.
    {'tag': 'AC', 'pred_class': 29} -> "tag == 'AC' and pred_class == 29".
    """
    clauses = []
    for column, wanted in rule.items():
        if column == 'samples':
            continue
        if isinstance(wanted, str):
            clauses.append(f"{column} == '{wanted}'")
        else:
            clauses.append(f"{column} == {wanted}")
    return " and ".join(clauses)


def _select_scored_rows(frame, rank_table, columns):
    """Return one scored frame per rule of `rank_table` (score -> list of rule dicts).

    Each frame holds `columns` of the rows of `frame` matching the rule, plus a 'score'
    column. If the rule has 'samples', that many rows are drawn with random_state=25.
    """
    scored = []
    for score, rules in rank_table.items():
        for rule in rules:
            matched = frame.query(_rule_to_query(rule))[columns]
            if 'samples' in rule:
                # Never ask for more rows than matched: DataFrame.sample raises in that
                # case, which used to abort the whole cell.
                matched = matched.sample(n=min(rule['samples'], len(matched)), random_state=25)
            # .assign returns a new frame, so the score is never written into a slice
            # of the source table (avoids SettingWithCopyWarning).
            scored.append(matched.assign(score=score))
    return scored


# Rows chosen by tag first, then rows chosen by annotation label, in rule order.
df_ac_simn = pd.concat(
    _select_scored_rows(
        df_c3_ac_nonac, ac_rank_ac,
        ['file_path', 'pid', 'tag', 'pred_class'],
    )
    + _select_scored_rows(
        df_annotated_frames_filtered_pred, ac_rank_annot,
        ['file_path', 'annotation_id', 'annotation_label', 'pid', 'tag', 'pred_class'],
    )
).reset_index(drop=True)

In [None]:
# Columns of the combined table. The tag-based part has no annotation_id/annotation_label
# columns, so the concat leaves those missing (NaN) for its rows.
df_ac_simn.columns

In [None]:
# Number of selected rows at each score value.
df_ac_simn.hist(column="score")

In [None]:
# Spot-check one score level: load up to 500 randomly drawn frames scored 0.2 and browse them.
img_seq_simn = read_seq(df_ac_simn.query('score == 0.2'), sample=500)

plot_3d_array_with_arrows(img_seq_simn, axis=0)

In [None]:
# Write the scored selection to CSV; an existing file at this path is overwritten.
df_ac_simn.to_csv(os.path.join(mount_point, "CSV_files/c3_blindsweep_annotation_labels_merged_train_v0.1_ac_simn.csv"), index=False)

In [None]:
def build_test():
    """Score the test-split predictions and write the scored table to CSV.

    Steps:
      1. Read the test prediction CSV under ``mount_point``.
      2. Give every row matching an ``ac_rank_annot`` rule that rule's score. The
         'samples' entries are ignored here, so all matching rows are kept.
      3. Rows whose file_path got no score in step 2 are scored by pred_class alone,
         using the fallback table defined below.
      4. Write the combined table to
         ``CSV_files/c3_blindsweep_annotation_labels_merged_ac_simn_test.csv``.

    Rows matched by neither table are left out of the CSV. The fallback table does not
    list every pred_class (4, 6, 9, 10, 15, 16, 25 and 26 are absent, for example), so
    such rows can exist; they are returned for inspection.

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        The scored table that was written, and the test rows that received no score.
    """

    def rule_query(rule):
        # Build "col == value and ..." from a rule dict, skipping 'samples'.
        # String values are single-quoted; other values are inserted as-is.
        clauses = []
        for column, wanted in rule.items():
            if column == 'samples':
                continue
            if isinstance(wanted, str):
                clauses.append(f"{column} == '{wanted}'")
            else:
                clauses.append(f"{column} == {wanted}")
        return " and ".join(clauses)

    def score_rows(frame, rank_table):
        # One scored frame per rule, concatenated in rule order. .assign returns a new
        # frame, so the score is never written into a slice of `frame`.
        parts = [
            frame.query(rule_query(rule)).assign(score=score)
            for score, rules in rank_table.items()
            for rule in rules
        ]
        return pd.concat(parts).reset_index(drop=True)

    df_test = pd.read_csv(os.path.join(mount_point, "test_output/classification/c3_blindsweep_annotation_labels_merged_frames_test/epoch=9-val_loss=0.27/c3_blindsweep_annotation_labels_merged_frames_test_prediction.csv"))

    df_test_score = score_rows(df_test, ac_rank_annot)

    # Everything the annotation-based rules did not cover.
    df_missing = df_test[~df_test['file_path'].isin(df_test_score['file_path'])]

    # Fallback scores that depend on pred_class only.
    ac_rank_missing = {
        0.8: [{'pred_class': 29},
              {'pred_class': 28}],
        0.0: [{"pred_class": cls} for cls in (
            0, 1, 2, 3, 5, 7, 8, 11, 12, 13, 14, 17, 18, 19,
            20, 21, 22, 23, 24, 27, 30, 31, 32, 33, 34,
        )],
    }

    df_missing_score = score_rows(df_missing, ac_rank_missing)

    df_test_score = pd.concat([df_test_score, df_missing_score]).reset_index(drop=True)

    # Rows still without a score after both passes.
    df_unscored = df_test[~df_test['file_path'].isin(df_test_score['file_path'])]

    df_test_score.to_csv(os.path.join(mount_point, "CSV_files/c3_blindsweep_annotation_labels_merged_ac_simn_test.csv"), index=False)

    return df_test_score, df_unscored


    