In [None]:
import os
from pathlib import Path

import holoviews as hv
import matplotlib
import numpy as np
import pandas as pd
import panel as pn
import seaborn as sns
from dotenv import load_dotenv
from matplotlib import pyplot as plt
from scipy import stats

from vital.data.cardinal.config import ClinicalAttribute
from vital.data.cardinal.config import View as ViewEnum
from vital.data.cardinal.records2yaml import read_records
from vital.data.cardinal.utils.itertools import Patients

# Load the Bokeh plotting extension for HoloViews
hv.extension("bokeh")

load_dotenv()  # Load environment variables from `.env` file if it exists

In [None]:
# Load IPython's `autoreload` extension and enable it in mode 2, so that edits to imported modules are picked up
# without having to restart the kernel
%load_ext autoreload
%autoreload 2

## Load the patient records from a CSV file

In [None]:
# Path to the CSV file containing the patient records
# NOTE(review): the leading `~` is not expanded here -- confirm that `read_records` handles it
records_csv = "~/dataset/cardinal/hcl/patient_records.csv"

In [None]:
# Read the patient records from the CSV file
csv_records = read_records(records_csv)

# Display the records read from the CSV file
csv_records

## Load the patient attributes from the dataset

In [None]:
# Root folder(s) of the dataset, read from the `CARDINAL_DATA_PATH` environment variable
# (indexing `os.environ` directly raises a `KeyError` if the variable is not set)
data_roots = [Path(os.environ["CARDINAL_DATA_PATH"])]

### (Optional) Hard-coded lists of patients that we might want to discard

In [None]:
from typing import List


def load_list_from_file(filepath: str | Path) -> List[str]:
    return Path(filepath).read_text().splitlines()

In [None]:
# Patients flagged as missing
missing_patients = ["0063"]
# Patients flagged as having unusable masks
unusable_masks = [
    "0119",
    "0126",
    "0135",
    "0147",
    "0153",
    "0158",
    "0165",
    "0228",
]

# All the patients to discard, regardless of the reason
exclude_patients = [*missing_patients, *unusable_masks]

### (Optional) Hard-coded lists of patients that we want to choose from

In [None]:
# Folder containing the results of the clustering
clustering_results_path = Path("~/data/didactic/results/cardiac_multimodal_representation_clustering").expanduser()
# Name of the model, used as the prefix of the `-diag` folder where to look for its clusters
model = ""
# Read one list of patients for each `*.txt` file in the model's folder, going through the files in sorted order
cluster_files = sorted((clustering_results_path / f"{model}-diag").glob("*.txt"))
patients_by_cluster = [load_list_from_file(cluster_file) for cluster_file in cluster_files]

# Patients to include (`None` means that no specific list of patients is requested)
include_patients = None

### Load the patients attributes using the custom collections API

In [None]:
# Filter on the patients to include if they were specified, and on the patients to exclude otherwise
if include_patients is not None:
    patients_filter_kwargs = {'include_patients': include_patients}
else:
    patients_filter_kwargs = {'exclude_patients': exclude_patients}
# Collect the patients from the dataset, requesting the A4C and A2C views
patients = Patients(data_roots, views=[ViewEnum.A4C, ViewEnum.A2C], **patients_filter_kwargs)

# Convert the collection of patients to a table, and display it
dataset_records = patients.to_dataframe()
dataset_records

## Select which records (from the CSV or from the dataset) to use for further analysis

In [None]:
# Use the records loaded from the dataset for the rest of the notebook
records = dataset_records

# Display the selected records
records

## Compute statistics on the data

### Differentiate between numerical and categorical attributes


In [None]:
# Lists of the categorical and numerical attributes, as provided by `ClinicalAttribute`
categorical_variables = ClinicalAttribute.categorical_attrs()
numerical_variables = ClinicalAttribute.numerical_attrs()

#### Describe numerical variables

In [None]:
# Summarize the numerical attributes, leaving out the `count` row of the summary
numerical_summary = records[numerical_variables].describe().drop(index=["count"])
# Lift the limit on the number of columns displayed, so that no attribute gets truncated
with pd.option_context("display.max_columns", None):
    display(numerical_summary.round(decimals=1))

#### Describe categorical variables

In [None]:
# Summarize the categorical attributes, leaving out the `count` row of the summary
categorical_summary = records[categorical_variables].describe().drop(index=["count"])
# Lift the limit on the number of columns displayed, so that no attribute gets truncated
with pd.option_context("display.max_columns", None):
    display(categorical_summary)

### List variables, in descending order of missing data

In [None]:
# Count the missing values of each attribute, both in absolute numbers and as a percentage of the records
missing_data_by_attr = records.isna().sum(axis="index").to_frame(name="num")
missing_data_by_attr["%"] = missing_data_by_attr["num"] * 100 / len(records)
# List the attributes w/ the most missing values first
missing_data_by_attr = missing_data_by_attr.sort_values(by="num", ascending=False).round(decimals=1)

# Only display the attributes that have missing values, w/o truncating the table
has_missing_data = missing_data_by_attr["%"] > 0
with pd.option_context("display.max_rows", None, "display.max_columns", None):
    display(missing_data_by_attr[has_missing_data])

### Count patients with all the requested data

#### Dynamically determine which columns to discard based on availability of data

In [None]:
# Alternative: do not ignore any column
# ignore_cols = []
# Ignore the columns for which more than 10% of the values are missing
ignore_cols = missing_data_by_attr[missing_data_by_attr["%"] > 10].index.tolist()
# Alternative: ignore a hard-coded list of columns
# ignore_cols = ["la_area", "vmax_tr", "s_prime", "tapse", "dbp_tte", "sbp_tte", "dpb_day", "sbp_day", "sbp_night", "dpb_night", "septal_e_prime", "lateral_e_prime"]
# Display the columns that will be ignored
ignore_cols

In [None]:
# Only keep the columns that were not selected to be ignored
cols = records.columns.difference(ignore_cols)
records_subset = records.loc[:, cols]

In [None]:
# Count the patients for which no value is missing, both w/ all the variables and w/ the subset of variables
data_by_tag = {"all data": records, "subset of variables": records_subset}
for data_tag, data in data_by_tag.items():
    has_full_data = data.notna().all(axis="columns")
    nb_remaining_patients = has_full_data.sum()
    print(f"Nb patients w/ full data when using {data_tag}: {nb_remaining_patients}")

### Analyze distribution of variables

#### Analyze distribution of numerical variables

In [None]:
# Widget to select the numerical attribute for which to plot the distribution
num_attr_select = pn.widgets.Select(name="Numerical attribute", value=numerical_variables[0], options=numerical_variables)


# `watch=True` is deliberately not passed to `pn.depends`: the `hv.DynamicMap` wrapping the callback already calls it
# whenever the widget changes, so watching the widget as well would only trigger a second call whose result is discarded
@pn.depends(num_attr=num_attr_select)
def _update_distribution(num_attr: str) -> hv.Distribution:
    """Builds the distribution of the selected numerical attribute.

    Args:
        num_attr: Name of the numerical attribute selected in the widget.

    Returns:
        Distribution of the values of the attribute across the records.
    """
    # Cast numerical dtypes to float to avoid problems when computing a distribution from integer columns
    return hv.Distribution(records[num_attr].astype(float))


# Display the widget next to the plot it controls
pn.Row(pn.Column(num_attr_select), hv.DynamicMap(_update_distribution).opts(width=600, height=600, framewise=True))

#### Analyze distribution of categorical variables

In [None]:
# Print the number and the percentage of records for each label of each categorical attribute
for cat_attr in categorical_variables:
    print(cat_attr)
    cat_attr_data = records[cat_attr]

    # Count the labels w/ `value_counts(dropna=False)` so that missing values get their own, correct, count.
    # Comparing the data to each unique label w/ `==` would always count 0 missing values, since a missing value
    # never compares equal to anything (including itself)
    labels_counts = cat_attr_data.value_counts(sort=False, dropna=False)
    nb_records = len(cat_attr_data)

    for label, label_count in labels_counts.items():
        print(f"{label}: {label_count} ({label_count * 100 / nb_records:.1f}%)")

    print()

#### Analyze relationships between variables

In [None]:
# Widgets to select the numerical attributes to plot along the x and y axes, and the categorical attribute used to color
# the points. They default to the first two numerical attributes and to the first categorical attribute
num_attr_1_select = pn.widgets.Select(name="Numerical attribute (x)", value=numerical_variables[0], options=numerical_variables)
num_attr_2_select = pn.widgets.Select(name="Numerical attribute (y)", value=numerical_variables[1], options=numerical_variables)
cat_attr_select = pn.widgets.Select(name="Categorical attribute (color)", value=categorical_variables[0], options=categorical_variables)

# `watch=True` is deliberately not passed to `pn.depends`: the `hv.DynamicMap` wrapping the callback already calls it
# whenever the widget changes, so watching the widget as well would only trigger a second call whose result is discarded
@pn.depends(num_attr=num_attr_1_select)
def _update_x_dist(num_attr: str) -> hv.Distribution:
    """Builds the distribution of the numerical attribute selected for the x-axis.

    Args:
        num_attr: Name of the numerical attribute selected in the widget.

    Returns:
        Distribution of the values of the attribute across the records.
    """
    # Cast numerical dtypes to float to avoid problems when computing a distribution from integer columns
    return hv.Distribution(records[num_attr].astype(float))

# `watch=True` is deliberately not passed to `pn.depends`: the `hv.DynamicMap` wrapping the callback already calls it
# whenever the widget changes, so watching the widget as well would only trigger a second call whose result is discarded
@pn.depends(num_attr=num_attr_2_select)
def _update_y_dist(num_attr: str) -> hv.Distribution:
    """Builds the distribution of the numerical attribute selected for the y-axis.

    Args:
        num_attr: Name of the numerical attribute selected in the widget.

    Returns:
        Distribution of the values of the attribute across the records.
    """
    # Cast numerical dtypes to float to avoid problems when computing a distribution from integer columns
    return hv.Distribution(records[num_attr].astype(float))

# `watch=True` is deliberately not passed to `pn.depends`: the `hv.DynamicMap` wrapping the callback already calls it
# whenever a widget changes, so watching the widgets as well would only trigger a second call whose result is discarded
@pn.depends(num_attr_1=num_attr_1_select, num_attr_2=num_attr_2_select, cat_attr=cat_attr_select)
def _update_points(num_attr_1: str, num_attr_2: str, cat_attr: str) -> hv.Points:
    """Builds the scatter plot of two numerical attributes, colored by a categorical attribute.

    Args:
        num_attr_1: Name of the numerical attribute to plot along the x-axis.
        num_attr_2: Name of the numerical attribute to plot along the y-axis.
        cat_attr: Name of the categorical attribute used to color the points.

    Returns:
        Points located by the two numerical attributes, w/ the categorical attribute as their value dimension.
    """
    # Cast numerical dtypes to float to avoid problems with serializing pd.NA (missing values of pandas' integer dtypes)
    # when processing the data for the scatter plot
    points_data = records[[num_attr_1, num_attr_2, cat_attr]].astype({num_attr_1: float, num_attr_2: float})
    return hv.Points(points_data, kdims=[num_attr_1, num_attr_2], vdims=[cat_attr]).opts(color=cat_attr, cmap="Set1")

# Stack the three selection widgets in a column
widgets_layout = pn.Column(num_attr_1_select, num_attr_2_select, cat_attr_select)
# Attach the distributions of the y and x attributes to the scatter plot w/ HoloViews' `<<` operator
plots_layout = (
    hv.DynamicMap(_update_points).opts(width=600, height=600, framewise=True, size=4) <<
    hv.DynamicMap(_update_y_dist).opts(width=150, framewise=True) <<
    hv.DynamicMap(_update_x_dist).opts(height=150, framewise=True)
)

# Display the widgets next to the plots they control
pn.Row(widgets_layout, plots_layout)

### Correlation between attributes

#### Encode categorical attributes as numerical attributes, to give option to include them when comparing target to numerical attributes

The categorical attributes are encoded as the mean/median of the target attribute w.r.t. each class.

However, this could possibly overestimate the correlation of categorical attributes w.r.t. the target attribute (e.g. the mean/median by class is coincidentally correlated with the target attribute, even if the variance within each class is as high as the global variance). Therefore, this measure is presented to provide a comparable perspective between categorical and numerical attributes, but it should be interpreted cautiously.

In [None]:
# Define the function that encodes the categorical attributes as stats of the target attribute w.r.t. each class

def encode_cat_attr(records: pd.DataFrame, target_attr: ClinicalAttribute, cat_attr: ClinicalAttribute, stat: str = 'mean') -> pd.Series:
    """Encodes a categorical attribute as a statistic of the target attribute w.r.t. each class.

    Args:
        records: Table of records, w/ (at least) one column for each of `target_attr` and `cat_attr`.
        target_attr: Attribute whose values are aggregated by class of the categorical attribute.
        cat_attr: Categorical attribute to encode.
        stat: Name of the groupby aggregation method (e.g. 'mean', 'median') used to summarize the target attribute
            within each class. It is expected to produce numerical values.

    Returns:
        Float series, indexed like `records`, where each record is assigned the statistic of the target attribute for
        the class the record belongs to. Records that match none of the classes (e.g. missing values) are left as NaN.
    """
    target_stats = getattr(records.groupby(cat_attr)[target_attr], stat)()

    # Explicitly request a float dtype. Otherwise, depending on the pandas version, creating the series w/o data either
    # raises a deprecation warning or results in an `object` series, which is not suited to hold numerical encodings
    cat_encodings = pd.Series(index=records.index, dtype=float)
    for label, target_stat in target_stats.items():
        cat_encodings.loc[records[cat_attr] == label] = target_stat

    return cat_encodings

In [None]:
# Define the target attribute and the attributes to encode
target = ClinicalAttribute.sbp_24
cat_attrs = ClinicalAttribute.categorical_attrs()
is_target_categorical = target in cat_attrs
if is_target_categorical:
    # The target gets its own encoding below, so leave it out of the categorical attributes to encode w.r.t. the target
    cat_attrs.remove(target)

encoded_records = records.copy()
if is_target_categorical:
    # Convert the categorical target to numerical values (assuming the cast will fail if the attribute is not suitable)
    encoded_records[target] = encoded_records[target].astype(int)
    # Also give the target an encoding (w.r.t. itself), so that it can be used like the others in downstream operations
    encoded_records[target + "_E"] = encoded_records[target]

# Insert the encoding of each categorical attribute right after the column of the attribute itself
for cat_attr in cat_attrs:
    encoding = encode_cat_attr(encoded_records, target, cat_attr)
    col_idx = list(encoded_records.columns).index(cat_attr)
    encoded_records.insert(col_idx + 1, cat_attr + "_E", encoding)

#### Correlation between categorical attributes (Cramér's V)

In [None]:
# Define the computation of the statistic used to measure correlation between categorical variables

def cramers_corrected_stat(confusion_matrix: pd.DataFrame) -> float:
    """Computes the bias-corrected version of Cramér's V (Bergsma, 2013) for categorical-categorical association.

    Args:
        confusion_matrix: Contingency table of the two categorical variables, i.e. the number of samples observed for
            each pair of labels.

    Returns:
        Bias-corrected Cramér's V, or NaN when the statistic is undefined for the table (i.e. one of the variables has
        fewer than two labels, or there are too few samples for the correction to be computed).
    """
    r, k = confusion_matrix.shape
    n = confusion_matrix.sum().sum()  # Sum over both rows and columns, to get the total number of samples
    if min(r, k) < 2 or n < 2:
        # No association can be measured w/ a single label or w/ less than two samples
        return float("nan")

    res = stats.chi2_contingency(confusion_matrix)
    phi2 = res.statistic / n
    phi2corr = max(0, phi2 - ((k - 1) * (r - 1)) / (n - 1))
    rcorr = r - ((r - 1) ** 2) / (n - 1)
    kcorr = k - ((k - 1) ** 2) / (n - 1)

    denominator = min(kcorr - 1, rcorr - 1)
    if denominator <= 0:
        # The correction leaves nothing to normalize by when a variable has as many labels as there are samples.
        # Explicitly return NaN rather than dividing by zero
        return float("nan")
    return np.sqrt(phi2corr / denominator)

In [None]:
# Compute the correlation matrix between categorical attributes
cramers_v_matrix = pd.DataFrame(index=ClinicalAttribute.categorical_attrs())

# Use each attribute in turn as the target, and measure the association of the other attributes w/ it
for target in ClinicalAttribute.categorical_attrs():
    other_cat_attrs = [cat_attr for cat_attr in ClinicalAttribute.categorical_attrs() if cat_attr != target]

    # Compute Cramér's V of each of the other attributes w.r.t. the target, from their contingency table
    target_corr = {
        cat_attr: cramers_corrected_stat(pd.crosstab(records[target], records[cat_attr]))
        for cat_attr in other_cat_attrs
    }
    # The target is perfectly correlated w/ itself
    target_corr[target] = 1

    cramers_v_matrix[target] = pd.Series(target_corr)

In [None]:
# Plot the pairwise similarity matrix between categorical attributes
from vital.utils.plot import plot_heatmap

# Use a small font for the annotations of the heatmap
plot_heatmap(cramers_v_matrix, annot_kws={"fontsize": "small"})

#### Correlation between categorical attribute and numerical attributes (Kruskal-Wallis H)

In [None]:
# Define the computation of the test used to measure correlation between categorical and numerical attributes

def kruskal(records: pd.DataFrame, target_attr: ClinicalAttribute, num_attr: ClinicalAttribute) -> float:
    """Tests whether a numerical attribute differs between the classes of a categorical attribute.

    The values of the numerical attribute are grouped by label of the categorical attribute, and the groups are
    compared using the Kruskal-Wallis H-test.

    Args:
        records: Table of records, w/ (at least) one column for each of `target_attr` and `num_attr`.
        target_attr: Categorical attribute whose (non-missing) labels define the groups to compare.
        num_attr: Numerical attribute whose values are compared between the groups.

    Returns:
        p-value of the Kruskal-Wallis H-test.
    """
    labels = records[target_attr].dropna().unique()
    # Convert all numerical types to float so that scipy can properly handle NaNs
    samples = [records.loc[records[target_attr] == label, num_attr].values.astype(float) for label in labels]
    return stats.kruskal(*samples, nan_policy="omit").pvalue

In [None]:
# Compute the disparity matrix between categorical attributes (columns) and numerical attributes (rows)
kruskal_matrix = pd.DataFrame(index=ClinicalAttribute.numerical_attrs())

# Use each categorical attribute in turn as the target, and compare the numerical attributes to it
for target in ClinicalAttribute.categorical_attrs():

    # Perform the Kruskal-Wallis H test for each numerical attribute w.r.t. the target
    target_pvalues = {
        num_attr: kruskal(encoded_records, target, num_attr)
        for num_attr in ClinicalAttribute.numerical_attrs()
    }
    # Uncomment the line below to compare the target to other categorical attributes' using their numerical encodings and the Kruskal-Wallis H test
    # target_pvalues.update({cat_attr + "_E": kruskal(encoded_records, target, cat_attr + "_E") for cat_attr in other_cat_attrs if cat_attr != target})

    # Convert the p-values to a disparity score, log(1/p), that grows as the p-values get smaller
    disparity_by_attr = {attr: np.log(1 / pvalue) for attr, pvalue in target_pvalues.items()}
    target_disparity = pd.Series(data=disparity_by_attr)

    kruskal_matrix[target] = target_disparity

In [None]:
# Plot the similarity matrix between categorical and numerical attributes
# The matrix is transposed so that the categorical attributes end up along the rows of the heatmap
plot_heatmap(kruskal_matrix.T, annot_kws={"fontsize": "small"})

#### Correlation between numerical attributes (encoding + Spearman)

In [None]:
# Compute the Spearman correlation matrix between the numerical attributes and the encoded categorical attributes
num_attrs = ClinicalAttribute.numerical_attrs()
# Names of the columns that contain the numerical encodings of the categorical attributes
encoded_cat_attrs = [cat_attr + "_E" for cat_attr in ClinicalAttribute.categorical_attrs()]
corr_matrix = encoded_records.loc[:, num_attrs + encoded_cat_attrs].corr(method="spearman")

In [None]:
# Plot the pairwise similarity matrix between numerical attributes
include_encoded_cat_attrs = True

# Build a new list of variables, rather than extending `num_attrs` in place. Since `variables = num_attrs` only binds a
# second name to the same list, `variables += ...` would also add the encoded attributes to `num_attrs`, and would add
# them once more (leading to duplicated rows/columns in the plot) every time the cell is run again
variables = list(num_attrs)
if include_encoded_cat_attrs:
    variables = variables + encoded_cat_attrs

plot_heatmap(corr_matrix.loc[variables, variables], cmap="icefire", annot_kws={"fontsize": "small"})

#### Detailed correlation between target attributes and other attributes

In [None]:
# Define generic function to plot rows from similarity matrices as barplot
# This allows to more easily inspect correlation between specific attributes of interest to other attributes

def similary_matrix_row_barplot(matrix: pd.DataFrame, target: str, similarity_name: str, ascending: bool = True) -> plt.Axes:
    """Plots the similarity between a target attribute and the other attributes as a horizontal barplot.

    Args:
        matrix: Similarity matrix, w/ the target attribute among its columns and the attributes to compare it to as
            its index.
        target: Name of the column of the matrix that contains the similarities w.r.t. the target attribute.
        similarity_name: Name of the similarity measure, used to label the x-axis.
        ascending: Whether to sort the attributes by ascending (or descending) similarity.

    Returns:
        Axes on which the barplot was drawn.
    """
    # Extract the similarity of the target w.r.t. other attributes from the matrix
    # Name the index explicitly, so that its column is called "index" even if the index of the matrix has a name
    plot_data = matrix[target].rename_axis("index").reset_index()
    # Exclude similarity w/ itself
    plot_data = plot_data[plot_data["index"] != target]
    # Sort the values for more easily readable plots
    plot_data = plot_data.sort_values(target, ascending=ascending)

    with sns.axes_style("darkgrid"):
        # Orient the barplot horizontally so that it scales better w/ more attributes
        default_figsize = matplotlib.rcParams['figure.figsize']
        fig, ax = plt.subplots(figsize=(default_figsize[0], 0.25 * len(plot_data)))

        ax = sns.barplot(data=plot_data, y="index", x=target, orient="h", ax=ax)
        ax.set(ylabel=None, xlabel=similarity_name, title=f"Similarity between {target} and other attributes")

    # Return the axes so that callers can further customize the plot
    return ax

In [None]:
# Set target attributes for the detailed comparisons in the following cells
# NOTE: The cells below look up the targets as columns of the categorical similarity matrices (and of their encodings),
# so the targets are expected to be categorical attributes
targets = [
    ClinicalAttribute.nt_probnp_group,
    ClinicalAttribute.ht_severity,
]

Correlation with categorical attributes (Cramér's V)

In [None]:
# Plot, one figure per target, Cramér's V between the target and the other categorical attributes (highest first)
for target in targets:
    ax = similary_matrix_row_barplot(
        matrix=cramers_v_matrix,
        target=target,
        similarity_name="Cramér's V",
        ascending=False,
    )
    plt.show()

Correlation with numerical attributes (Kruskal-Wallis H)

In [None]:
# Plot, one figure per target, the disparity of the numerical attributes between the classes of the target
for target in targets:
    ax = similary_matrix_row_barplot(kruskal_matrix, target, "Disparity from Kruskal-Wallis H test")
    plt.show()

Correlation with all other attributes (encoding + Spearman)

In [None]:
# Plot, one figure per target, the Spearman correlation between the target's encoding and the other attributes
for target in targets:
    ax = similary_matrix_row_barplot(
        matrix=corr_matrix,
        target=target + "_E",
        similarity_name="Spearman Correlation",
    )
    plt.show()