In [1]:
# Import necessary libraries
from infohdp.estimators import MulticlassFullInfoHDPEstimator
from sklearn.datasets import load_breast_cancer, load_diabetes
import seaborn as sns
import pandas as pd
import numpy as np

# Function to estimate mutual information
def estimate_mi(samples):
    """Estimate the mutual information of ``samples`` and print it.

    A fresh ``MulticlassFullInfoHDPEstimator`` is created on every call and
    ``samples`` is handed unchanged to its ``estimate_mutual_information``
    method. The two values that method returns are printed as
    ``first ± second`` (the second is presumably the uncertainty of the
    estimate), each with four decimal places.

    Args:
        samples: Passed straight through to the estimator; the notebook
            supplies a list of ``(x, y)`` pairs.

    Returns:
        None. The result is only written to stdout.
    """
    mi_value, mi_spread = MulticlassFullInfoHDPEstimator().estimate_mutual_information(samples)
    print(
        "Ihdp full multiclass, mutual information estimation [nats]: "
        f"{mi_value:.4f} ± {mi_spread:.4f}"
    )
    
# Function to display dataset statistics
def display_dataset_statistics(samples, dataset_name):
    """Print the sample count and the number of distinct X and Y values.

    Args:
        samples: Iterable of ``(x, y)`` pairs. It is materialised exactly
            once, so a one-shot iterable (a bare ``zip(...)`` or a generator)
            works as well as a list. Each ``x`` and ``y`` must be hashable
            because distinct values are counted with sets.
        dataset_name: Label printed in front of " Dataset Statistics:" in the
            heading. Pass the bare name (e.g. ``"Breast Cancer"``); a name
            that already ends in "Dataset" produces a doubled word.

    Returns:
        None. Four summary lines followed by one blank line go to stdout.
    """
    # Snapshot the input first: iterating `samples` more than once (or calling
    # len() on it) would break for iterators that can only be consumed once.
    pairs = list(samples)
    distinct_x = {x for x, _ in pairs}
    distinct_y = {y for _, y in pairs}

    print(f"{dataset_name} Dataset Statistics:")
    print(f"Number of samples: {len(pairs)}")
    print(f"Number of distinct elements in X: {len(distinct_x)}")
    print(f"Number of distinct elements in Y: {len(distinct_y)}")
    print()

In [2]:
# Breast Cancer: X is the 'mean radius' feature cast to int, Y is the class label.
data_bc = load_breast_cancer()
df_bc = pd.DataFrame(
    data_bc.data,
    columns=data_bc.feature_names,
).assign(target=data_bc.target)

# Build the (x, y) pairs the estimator consumes, then report on them.
samples_bc = list(
    zip(
        df_bc['mean radius'].astype(int),
        df_bc['target'],
    )
)
display_dataset_statistics(samples_bc, "Breast Cancer")
estimate_mi(samples_bc)

Breast Cancer Dataset Statistics:
Number of samples: 569
Number of distinct elements in X: 22
Number of distinct elements in Y: 2

Ihdp full multiclass, mutual information estimation [nats]: 0.3649 ± 0.0847


In [3]:
# Load and preprocess Diabetes dataset
data_diabetes = load_diabetes()
df_diabetes = pd.DataFrame(data_diabetes.data, columns=data_diabetes.feature_names)
df_diabetes['target'] = data_diabetes.target

# Discretize BMI into 40 quantile groups (these are not percentiles, which
# would be 100 groups). labels=False makes qcut return integer bin indices
# instead of interval labels. The column name is kept as-is because other
# cells may read it.
df_diabetes['bmi_percentile'] = pd.qcut(df_diabetes['bmi'], 40, labels=False)
# Discretize the target variable into quartiles (4 quantile groups).
df_diabetes['target_quartile'] = pd.qcut(df_diabetes['target'], 4, labels=False)

# Pair each BMI bin with its target quartile: one (x, y) tuple per row.
samples_diabetes = list(zip(df_diabetes['bmi_percentile'], df_diabetes['target_quartile']))
# Pass the bare dataset name: display_dataset_statistics appends
# " Dataset Statistics:" itself, so "Diabetes Dataset" would print the word twice.
display_dataset_statistics(samples_diabetes, "Diabetes")
estimate_mi(samples_diabetes)

Diabetes Dataset Dataset Statistics:
Number of samples: 442
Number of distinct elements in X: 40
Number of distinct elements in Y: 4

Ihdp full multiclass, mutual information estimation [nats]: 0.1801 ± 0.0993
