# Prepare Data

## Physionet Challenge Dataset
- https://medicalai.atlassian.net/wiki/spaces/AT/pages/252379174/Physionet+Label

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

In [None]:
import sys
sys.path.append('.')

In [None]:
%load_ext autoreload
%autoreload 2

from tcav import *

In [None]:
physionet_df = pd.read_csv("/bfai/nfs_export/workspace/share/labels/physionet2021/physionet2021_total.csv")

In [None]:
physionet_df.source.value_counts()

In [None]:
# Human-readable names of the Physionet labels.
# The position of a name in this list is used throughout the notebook as the
# label index; physionet_df stores the matching 0/1 flags in the column whose
# name is that index as a string ("0", "1", ...).
label_list = [
    "atrial fibrillation",
    "atrial flutter",
    "bundle branch block",
    "bradycardia",
    "complete left bundle branch block, left bundle...", #
    "complete right bundle branch block, right bund...",
    "1st degree av block",
    "incomplete right bundle branch block",
    "left axis deviation", 
    "left anterior fascicular block",
    "prolonged pr interval",
    "low qrs voltages",
    "prolonged qt interval",
    "nonspecific intraventricular conduction disorder",
    "sinus rhythm", #
    "premature atrial contraction, supraventricular...",
    "pacing rhythm",
    "poor R wave Progression",
    "premature ventricular contractions, ventricula...",
    "qwave abnormal", #
    "right axis deviation",
    "sinus arrhythmia",
    "sinus bradycardia",
    "sinus tachycardia",
    "t wave abnormal", #
    "t wave inversion"
]



## add sublabel

In [None]:
# Derived labels: a row is positive when either of the two source labels is.
# Column "26" merges labels 0 and 1, column "27" merges labels 24 and 25.
physionet_df['26'] = physionet_df['0'] | physionet_df['1']
physionet_df['27'] = physionet_df['24'] | physionet_df['25']

# Register the names in the same order as the new columns so that the list
# position keeps matching the column name.
label_list.extend([
    "atrial fibrillation+atrial flutter",
    "t wave abnormal + t wave inversion ",
])


In [None]:
# Convert to an ndarray so later cells can select names with integer lists
# and argsort results.
label_list = np.array(label_list)
# Print each label index next to its name for reference.
for idx,label_name in enumerate(label_list):
    print(idx, label_name)

In [None]:
label_list[[26,4,5,8,20,12,13,19,27]]

In [None]:
hist_label_list = label_list[:]

In [None]:
# Number of positive rows per label: sum every label column ("0", "1", ...)
# over the whole frame, in label-index order.
label_dist = np.array(
    physionet_df[list(map(str, range(len(hist_label_list))))].sum().tolist()
)

In [None]:
order_idx_list = np.argsort(label_dist)

In [None]:
import matplotlib.pyplot as plt

# Label names and counts, both arranged in ascending order of count.
labels = hist_label_list[order_idx_list]
counts = label_dist[order_idx_list]

# One bar per label; names are written vertically under the bars.
plt.figure(figsize=(10, 6))
plt.bar(range(len(labels)), counts, color='skyblue', edgecolor='black')
plt.xticks(range(len(labels)), labels, rotation=270)

plt.xlabel('Labels')
plt.ylabel('Counts')
plt.title('Physionet label distribution')

plt.show()


## get Concept
- multilabel 데이터 중에서 label이 명확히 1개인 것을 우선으로 추출하도록
- 각 label 내에서 source_id의 분포가 최대한 동일하도록
- random control에서는 나머지 data에서 source_id 분포만 같도록 random 추출

In [None]:
from random import shuffle
import random

In [None]:
# Label indices (positions in label_list / column names of physionet_df) that
# become TCAV concepts.
selected_idx_list = sorted([4,26,5,8,20,12,13,19,27,21,22,23])
#selected_idx_list = sorted([6,9,15])#

random_concept_n:int = 10 # how many random concepts to build
sample_n = 200 # target number of rows per concept
random_seed = 777

# concept name -> DataFrame holding the rows sampled for that concept
concept_oid_dict = dict()

total_oid_set= set(physionet_df.objectid.tolist()) # object ids not yet assigned to a concept
control_oid_set = set(physionet_df.objectid.tolist()) # pool of object ids for the random controls

for idx in np.argsort(label_dist[selected_idx_list]): # start with the label that has the fewest positives
    select_idx = selected_idx_list[idx]
    name = label_list[select_idx]
    print(name)
    
    # Rows that are positive for this label.
    target_df = physionet_df[physionet_df[str(select_idx)]==1].copy()
    # Number of positive label columns per row; sorting on it below prefers
    # rows that carry few labels.
    # NOTE(review): the range covers every entry of label_list, so the derived
    # columns appended earlier are counted as well - confirm this is intended.
    target_df['count'] = target_df[[str(i) for i in range(len(label_list))]].sum(axis=1).tolist()
        
    # Keep only rows whose object id has not been taken by an earlier concept.
    exist_oid_df = pd.DataFrame(total_oid_set,columns=['objectid'])
    target_df = pd.merge(target_df,exist_oid_df,on='objectid',how='inner')

    # Seeded shuffled rank, used as the tie-breaker among rows with equal 'count'.
    random.seed(random_seed)
    random_number = list(range(len(target_df)))
    shuffle(random_number)
    target_df['random_seed'] = random_number
    
    # Sources ordered from the one with the fewest rows to the one with the most.
    source_list = target_df.source.value_counts(ascending=True).index.tolist()
    source_sample_list = list()
    
    remain_n = sample_n
    each_n = int(sample_n/len(source_list)) # equal share per source
    
    for i,source in enumerate(source_list):
        
        source_sample_df = target_df[target_df.source==source]
        if i==len(source_list)-1:
            # The last (largest) source fills whatever is still missing.
            target_sample_df = source_sample_df.sort_values(['count','random_seed']).head(remain_n)
        else:
            target_sample_df = source_sample_df.sort_values(['count','random_seed']).head(each_n)
            remain_n -=len(target_sample_df)
        print(name,source,target_sample_df.shape)
        source_sample_list.append(target_sample_df)
    
    target_sample_df = pd.concat(source_sample_list)
    oid_list= target_sample_df.objectid.tolist()
    concept_oid_dict[name] = target_sample_df
    
    # Sampled rows become unavailable to later concepts; every remaining row
    # positive for this label (sampled or not) leaves the control pool.
    total_oid_set = total_oid_set-set(oid_list)
    control_oid_set = control_oid_set-set(target_df.objectid.tolist())
    
    
    if len(oid_list)<sample_n:
        print(f"[Caution]{name} label is insufficient, file_n: {len(oid_list)}")
    else:
        print(f"[Success]{name} label is prepared, sample_n: {len(oid_list)}")

# Build the random control concepts from the rows left in the control pool,
# drawing the same number of rows from every source.
#
# The pool is selected with isin() so its row order follows physionet_df
# instead of set iteration order, which can change between interpreter
# sessions and would make a seeded positional sample irreproducible.
random_pool_df = physionet_df[physionet_df.objectid.isin(control_oid_set)]
if random_pool_df.empty:
    raise ValueError("control_oid_set is empty: no rows left to build random concepts")

each_n = int(sample_n/len(random_pool_df.source.unique()))

for random_idx in range(random_concept_n):
    # Offset the seed per concept: sampling the same pool with one shared
    # random_state returns the very same rows every time, which would make
    # all random concepts copies of each other.
    random_sample_df = random_pool_df.groupby('source').sample(
        each_n, random_state=random_seed + random_idx
    )

    concept_oid_dict[f"random_concept_{random_idx}"] = random_sample_df

    # Report the size of this random concept (not of the last named concept).
    print(f"[Success] random{random_idx} label is prepared, sample_n: {len(random_sample_df)}")
    

In [None]:
# Rebuild the full control pool with its label columns and verify that none of
# the selected labels is present in it (each printed sum should be 0).
random_sample_df = pd.DataFrame(list(control_oid_set),columns=['objectid']).merge(
    physionet_df,on='objectid',how='inner'
)
print(random_sample_df.shape)
for col in selected_idx_list:
    print(random_sample_df[str(col)].sum())

In [None]:
# Names of the labels that still occur at least once in the control pool.
random_label_list = label_list[
    random_sample_df[[str(i) for i in range(28)]].sum(axis=0).ne(0)
]
random_label_list

## check concept dist

In [None]:
# For every concept: collect a one-row table of rows-per-source and show the
# source histogram.
count_df_list = list()
for name, target_df in concept_oid_dict.items():
    # Name the counts Series after the concept before transposing. Relabelling
    # the row afterwards via rename({'source': name}) only works while
    # value_counts() names its result after the column; since pandas 2.0 the
    # result is named "count" and the row would keep that label.
    count_df = target_df.source.value_counts().rename(name).to_frame().T
    count_df_list.append(count_df)
    target_df.source.hist(label=name)
    plt.legend()
    plt.show()

In [None]:
pd.concat(count_df_list)

# load model and dataset

In [None]:
from aitiautils.model_loader import ModelLoader
import torch
from torch.utils.data import DataLoader


from captum.attr import LayerGradientXActivation, LayerIntegratedGradients

from captum.concept import TCAV
from captum.concept import Concept

from captum.concept._utils.data_iterator import dataset_to_dataloader, CustomIterableDataset
from captum.concept._utils.common import concepts_to_str

In [None]:
def get_ecg(objectid, timeout=30):
    """Fetch the ECG record for ``objectid`` from the ECG HTTP service.

    Args:
        objectid: Identifier sent as the ``objectId`` query parameter.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout a stalled connection would block the caller forever.

    Returns:
        The decoded JSON body of the response.

    Raises:
        requests.RequestException: On connection problems, timeouts, or a
            non-success HTTP status.
    """
    import requests
    res = requests.get(
        "http://192.168.80.28:30081/ecgs",
        params={"objectId": objectid},
        timeout=timeout,
    )
    # Fail here with a clear HTTP error instead of handing an error payload
    # to the preprocessing step.
    res.raise_for_status()
    return res.json()

def get_ecg_tensor(objectid,dataset):
    """Download one ECG and return it as a float tensor with a leading axis.

    The record fetched by ``get_ecg`` is passed to ``dataset.preprocessing``;
    the first element of its result is converted with ``torch.from_numpy``,
    cast to float and given an extra dimension at position 0.
    """
    import numpy as np
    import torch

    first_output = dataset.preprocessing(get_ecg(objectid))[0]
    return torch.from_numpy(first_output).float().unsqueeze(0)

In [None]:
class TCAV_dataset(torch.utils.data.Dataset):
    """Dataset that serves the ECG tensors of one concept.

    Each item is fetched on demand through ``get_ecg_tensor`` using the
    ``objectid`` column of ``sample_df`` and moved to ``device``.

    Args:
        sample_df: DataFrame with an ``objectid`` column, one row per sample.
        model_dataset: Object handed to ``get_ecg_tensor`` for preprocessing.
        device: Device the returned tensors are moved to.
    """
    def __init__(self,sample_df,model_dataset,device):
        # Positional 0..n-1 index so __getitem__ can use .loc[index].
        self.sample_df = sample_df.reset_index(drop=True)
        self.model_dataset = model_dataset
        # Keep the device passed in instead of relying on a notebook global.
        self.device = device
        self.sample_len = len(sample_df)
        # Last object id that loaded successfully; used as a stand-in when a
        # later sample cannot be loaded.
        self.success_id = None

    def __getitem__(self,index):
        """Return the tensor for row ``index``.

        If loading fails, the most recent successfully loaded sample is
        returned instead so that iteration can continue.

        Raises:
            RuntimeError: If loading fails before any sample has succeeded,
                because there is nothing to fall back to.
        """
        objectid = self.sample_df.loc[index].objectid
        try:
            output = get_ecg_tensor(objectid,self.model_dataset).squeeze()
        except Exception as err:  # not a bare except: Ctrl-C must still interrupt
            if self.success_id is None:
                raise RuntimeError(
                    f"failed to load {objectid} and no earlier sample is available as fallback"
                ) from err
            print(f'error {objectid}->{self.success_id}')
            output = get_ecg_tensor(self.success_id,self.model_dataset).squeeze()
        else:
            self.success_id = objectid
        return output.to(self.device)

    def __len__(self):
        """Return the number of rows in the concept DataFrame."""
        return self.sample_len

In [None]:
device = "cuda:3"
checkpoint_path = "/bfai/nfs_export/workspace/share/result/jjh/models/lvsd/checkpoint.pth"
# checkpoint_path = "/bfai/nfs_export/workspace/share/result/series/lvsd/lvsd2-220916_163905/checkpoint.pth"

In [None]:
label_df = pd.read_csv("/bfai/nfs_export/workspace/share/result/bytaklee/test/lvsd2-220916_163905/train_probability.csv")
# label_df = pd.read_csv("/bfai/nfs_export/workspace/share/result/jjh/test/hyperkalemia-230404_080756/train_probability.csv")

In [None]:
# Load the network and the dataset object that belong to the checkpoint.
print(checkpoint_path)
loader = ModelLoader(checkpoint_path,device=device)
classifier = loader.get_network()
# Put the network in evaluation mode before it is used for attribution.
classifier.eval()
# Passed to get_ecg_tensor, which calls its .preprocessing() on each ECG.
dataset_cls = loader.get_dataset()


In [None]:
# import pickle
# from aitiautils.calibration import Calibrator
# calibrator = Calibrator(model_pickle="/bfai/nfs_export/workspace/share/result/jjh/models/lvsd/calib_model.pkl")
# output_calib = calibrator.transform(label_df,prob_col_name='1')
# label_df['calib'] = output_calib

In [None]:
# Inputs to explain: 1000 rows whose value in column '1' is above 0.5.
label_df_target = label_df[label_df['1']>0.5].sample(1000,random_state=random_seed)

target_tensor_list = list()
failed_oid_list = list()  # (object id, error) for every ECG that could not be loaded

for oid in label_df_target.objectid:
    try:
        out = get_ecg_tensor(oid,dataset_cls)
    except Exception as err:
        # Best effort: skip ECGs that cannot be fetched or preprocessed, but
        # remember them. A bare except would also swallow KeyboardInterrupt
        # and make this long download loop impossible to stop.
        failed_oid_list.append((oid, repr(err)))
        continue
    target_tensor_list.append(out)

print(f"loaded {len(target_tensor_list)} target ECGs, skipped {len(failed_oid_list)}")
    

In [None]:
target_tensor = torch.stack(target_tensor_list).squeeze()
target_tensor.shape

# TCAV with captum

In [None]:
from aitiautils.dot_dict import DotDict

In [None]:
tcav_concept_dict = dict()

In [None]:
# Turn every sampled concept DataFrame (named and random) into a captum Concept.
for idx, (name,concept_df) in enumerate(concept_oid_dict.items()):
    # Dataset that fetches the concept's ECGs on demand.
    tcav_dataset = TCAV_dataset(concept_df,dataset_cls,device)
    concept_iter = dataset_to_dataloader(tcav_dataset)
    # The enumeration index serves as the concept id.
    tcav_concept = Concept(idx,name,concept_iter)
    tcav_concept_dict[name] = tcav_concept

In [None]:
tcav_concept_dict = DotDict(tcav_concept_dict)

In [None]:
tcav_concept_dict.keys()

In [None]:
# tcav_concept_dict.pop('right axis deviation')
# tcav_concept_dict.pop('nonspecific intraventricular conduction disorder')
# tcav_concept_dict.pop('prolonged qt interval')
# tcav_concept_dict.pop('qwave abnormal')
# tcav_concept_dict.pop('complete right bundle branch block, right bund...')
# tcav_concept_dict.pop('left axis deviation')
# tcav_concept_dict.pop('atrial fibrillation+atrial flutter')
# tcav_concept_dict.pop('t wave abnormal + t wave inversion ')

In [None]:
# for tmp in classifier.named_parameters():
#     print(tmp[0])

In [None]:
layers = ["blk1d.0.2.conv2","blk1d.1.2.conv2","blk1d.2.2.conv2","blk1d.3.2.conv2"]
tcav_concept_dict.keys()

In [None]:
# TCAV explainer over the layers named in `layers`; the layer attribution
# method is LayerIntegratedGradients created with no layer (None) and
# multiply_by_inputs=False.
mytcav = TCAV(model=classifier,layers=layers,
              layer_attr_method =LayerIntegratedGradients(classifier, None, multiply_by_inputs=False) ) #

In [None]:
print(tcav_concept_dict.keys())
# NOTE(review): the next expression is not the last line of the cell, so its
# value is neither shown nor stored.
list(tcav_concept_dict.values())
# A single experimental set that contains every concept in tcav_concept_dict.
experimental_set_rand = [list(tcav_concept_dict.values())]

In [None]:
# experimental_set_rand = [[target,list(tcav_concept_dict.values())[-1]] for target in list(tcav_concept_dict.values())]
# experimental_set_rand = experimental_set_rand[:1]

In [None]:
#experimental_set_rand = [[tcav_concept_dict['t wave abnormal'],tcav_concept_dict['random_concept_0']]]

In [None]:
# Score the target inputs against the experimental set(s) for output index 1.
# n_steps=5 is presumably forwarded to the integrated-gradients attribution -
# TODO confirm against the captum TCAV.interpret documentation.
tcav_scores_w_random = mytcav.interpret(inputs=target_tensor, #.to(device)
                                        experimental_sets=experimental_set_rand,
                                        target=1,
                                        n_steps=5)

In [None]:
def format_float(f):
    """Return ``f`` shortened for display.

    Values with magnitude of at least 0.0005 are rounded to three decimals;
    smaller values keep three decimals of their scientific-notation mantissa.
    """
    if abs(f) >= 0.0005:
        text = '{:.3f}'.format(f)
    else:
        text = '{:.3e}'.format(f)
    return float(text)

def plot_tcav_scores(experimental_sets, tcav_scores,layer_list):
    """Plot grouped bars of the TCAV ``sign_count`` score per layer.

    One subplot is drawn per experimental set; inside it every concept gets
    one bar per layer. The ``magnitude`` scores of each concept are printed.

    Args:
        experimental_sets: List of experimental sets, each a list of concepts
            exposing a ``name`` attribute.
        tcav_scores: Mapping from ``concepts_to_str(concepts)`` to a mapping
            layer name -> scores, where scores has ``'sign_count'`` and
            ``'magnitude'`` entries indexable by concept position.
        layer_list: Layer names to plot, in plotting order. Bars, tick
            positions and tick labels all follow this argument.
    """
    fig, ax = plt.subplots(1, len(experimental_sets), figsize = (25, 7))

    barWidth = 1 / (len(experimental_sets[0]) + 1)

    for idx_es, concepts in enumerate(experimental_sets):
        layer_scores = tcav_scores[concepts_to_str(concepts)]

        # x positions: one group per layer, each concept shifted by one bar width.
        pos = [np.arange(len(layer_list))]
        for i in range(1, len(concepts)):
            pos.append([(x + barWidth) for x in pos[i-1]])

        # plt.subplots returns a single Axes (not an array) for one subplot.
        _ax = (ax[idx_es] if len(experimental_sets) > 1 else ax)
        for i, concept in enumerate(concepts):
            # Look the scores up by the requested layers so that bar heights
            # and tick labels cannot get out of step.
            val = [format_float(layer_scores[layer]['sign_count'][i]) for layer in layer_list]
            direction = [format_float(layer_scores[layer]['magnitude'][i]) for layer in layer_list]
            print(direction)
            _ax.bar(pos[i], val, width=barWidth, edgecolor='white', label=concept.name)

        # Add xticks on the middle of the group bars
        _ax.set_xlabel('Set {}'.format(str(idx_es)), fontweight='bold', fontsize=16)
        _ax.set_xticks([r + 0.3 * barWidth for r in range(len(layer_list))])
        _ax.set_xticklabels(layer_list, fontsize=16)

        # Create legend & Show graphic
        _ax.legend(fontsize=16,bbox_to_anchor=(1.3, 1),loc='upper right')

    plt.show()

In [None]:
#plot_tcav_scores(experimental_set_rand, tcav_scores_w_random,layers)

In [None]:
# tcav_scores_w_random['0-1-2-3-4-5-6-7-8-9-10-11-12-13-14-15-16-17-18']

In [None]:
# for layer, output_dict in tcav_scores_w_random['0-1-2-3-4-5-6-7-8-9-10-11-12-13-14-15-16-17-18'].items():
#     print(layer)
#     print(pd.DataFrame(output_dict)) #index=list(label_list[selected_idx_list])+['random'])

## Statistical significance test

In [None]:
tcav_concept_dict.keys()

In [None]:
# For every named concept, build one list of [concept, random control] pairs,
# pairing the concept with each of the random concepts in turn.
exp_sets_for_each = list()

for concept_name in tcav_concept_dict.keys():
    # Random concepts only ever appear as the comparison partner.
    if "random_concept" not in concept_name:
        target_concept = tcav_concept_dict[concept_name]
        random_concepts = [
            tcav_concept_dict[f"random_concept_{i}"] for i in range(random_concept_n)
        ]
        experimental_sets = [
            [target_concept, random_concept] for random_concept in random_concepts
        ]
        exp_sets_for_each.append(experimental_sets)
#     experimental_sets.append([random_0_concept, random_1_concept])
#     experimental_sets.extend([[random_0_concept, random_concept] for random_concept in random_concepts])



In [None]:
# Per block: [mean_list, h_list] over all named concepts.
block_tcav_result_list = list()
# Per block: one entry of random-control output per named concept.
block_tcav_random_score_list = list()
score_type = "sign_count" #'magnitude'
for block_n in [0,1,2,3]:
    target_layer = f'blk1d.{block_n}.2.conv2'
    
    p_val_out_list = list()
    random_score_each_block=list()
    for target_exp_set in exp_sets_for_each:
        # get_confidnece_plot is defined outside this notebook section.
        # NOTE(review): the code below reads out[1] as a (mean, interval) pair
        # for the concept and out[-1] as the scores of the random controls -
        # confirm against the helper's implementation.
        out = get_confidnece_plot(mytcav,target_exp_set,target_layer,score_type,target_tensor,device,label_name=target_exp_set[0][0].name)
        p_val_out_list.append(out)
        random_score_each_block.append(out[-1])


    # Concept name of each experimental-set group (first concept of the first pair).
    name_list = [target_exp_set[0][0].name for target_exp_set in exp_sets_for_each]
    #name_list.append('Random control')
    mean_list = [out[1][0] for out in p_val_out_list]
    h_list = [out[1][1] for out in p_val_out_list]
    block_tcav_result_list.append([mean_list,h_list])
    block_tcav_random_score_list.append(random_score_each_block)
    
# mean_list.append(out[2][0])
# h_list.append(out[2][1])




In [None]:
from itertools import chain
# For each of the four blocks, flatten the random-control entries collected
# for all concepts into one list and summarise it with
# mean_confidence_interval (defined outside this notebook section; the later
# cells read element 0 as the score and element 1 as the interval).
block_tcav_radom_result_list = [mean_confidence_interval(list(chain(*block_tcav_random_score_list[i]))) for i in  [0,1,2,3]]

In [None]:
block_tcav_radom_result_list#[0][0]

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib

def draw_heatmap(matrix,ci_matrix=None, row_names=None, col_names=None, cmap='bwr', cell_width=1, cell_height=1, vmin=None, vmax=None):
    """
    Draw an annotated heatmap for a matrix, show it and return the figure.

    Parameters:
    - matrix (list of lists or numpy array): The input N x M matrix.
    - ci_matrix (list of lists or numpy array, optional): N x M matrix of
      interval widths. When given, every cell shows its value followed by
      "(value - ci/2 - value + ci/2)".
    - row_names (sequence of str, optional): Names of rows.
    - col_names (sequence of str, optional): Names of columns.
    - cmap (str, optional): The colormap to use. Default is 'bwr' (blue-white-red).
    - cell_width (float, optional): Width of each cell in the heatmap. Default is 1.
    - cell_height (float, optional): Height of each cell in the heatmap. Default is 1.
    - vmin (float, optional): Minimum value for colormap scaling.
    - vmax (float, optional): Maximum value for colormap scaling.

    Returns:
    - matplotlib.figure.Figure: The figure that was drawn, e.g. for savefig.
    """
    n_rows = len(matrix)
    n_cols = len(matrix[0])

    fig, ax = plt.subplots(figsize=(n_cols * cell_width, n_rows * cell_height))

    cax = ax.matshow(matrix, cmap=cmap, aspect='auto', vmin=vmin, vmax=vmax,alpha=0.8)

    # Display the intensity values in each cell
    for i in range(n_rows):
        for j in range(n_cols):
            value = matrix[i][j]
            if ci_matrix is None:
                cell_text = str(round(value,3))
            else:
                # NOTE(review): ci is halved, i.e. treated as a full interval
                # width - confirm this matches what the caller passes in.
                half_width = ci_matrix[i][j] / 2
                # Bounds are computed from the unrounded value so rounding
                # happens only once, when formatting.
                cell_text = f"{value:.3f}\n({value - half_width:.3f}-{value + half_width:.3f})"

            ax.text(j, i, cell_text, ha='center', va='center', color='black',fontsize=12.5)

    # Set row and column names. Explicit None/length checks: the truth value
    # of a numpy array of names is ambiguous and would raise.
    if row_names is not None and len(row_names) > 0:
        ax.set_yticks(np.arange(len(row_names)))
        ax.set_yticklabels(row_names)
    if col_names is not None and len(col_names) > 0:
        ax.set_xticks(np.arange(len(col_names)))
        ax.set_xticklabels(col_names, rotation=45, ha='right')
        ax.xaxis.set_ticks_position('bottom')

    plt.colorbar(cax)
    plt.show()
    return fig


In [None]:
# Column order used in the figure; -1 picks the last column, which holds the
# pooled random-control result appended to every row below.
reindex_list = [6,0,1,5,7,4,8,3,2,-1]

# Scores: one row per block, made of the per-concept values followed by the
# random-control value, then reordered column-wise.
block_tav_score_list = np.array([
    block_tcav_result_list[block_i][0] + [block_tcav_radom_result_list[block_i][0]]
    for block_i in range(4)
])[:, reindex_list]

# Interval sizes, laid out exactly like the score matrix.
block_tav_ci_list = np.array([
    block_tcav_result_list[block_i][1] + [block_tcav_radom_result_list[block_i][1]]
    for block_i in range(4)
])[:, reindex_list]

In [None]:
# Heatmap of the TCAV scores: one row per block, one column per concept.
matrix = block_tav_score_list
rows = ["Block1", "Block2", "Block3","Block4"]
# Column labels come from the concept names, reordered like the data.
# NOTE(review): index -1 labels the last column with the name of the last
# key in tcav_concept_dict, while the data in that column is the pooled
# random-control result - confirm this label is what the figure should show.
cols = list(np.array(list(tcav_concept_dict)[:])[reindex_list])
fig = draw_heatmap(matrix,block_tav_ci_list, row_names=rows, col_names=cols, cell_width=2, cell_height=1, vmin=0.1, vmax=1,cmap='Reds')


In [None]:
fig.savefig('TCAV_block_figure.png',dpi=250)

In [None]:
# print(block_n)
# plt.figure(figsize=(8,8),tight_layout=True)
# plt.bar(
#         name_list,
#         mean_list,
#         yerr=h_list,
#         color=["gray"],
#         capsize=10,
#     )
# plt.ylim(0,1)
# plt.xticks(rotation=90)
# #plt.savefig(f'../../XAI_repo/notebooks/LVSD_TCAV_BLOCK{block_n}.png')