In [1]:
import graphviz
from equiflow import TableFlows, TableCharacteristics, TableDrifts, FlowDiagram, EquifFlow
import pandas as pd
import numpy as np
from IPython.display import Image, display
from matplotlib import pyplot as plt
from IPython.display import SVG
import os
import matplotlib.pyplot as plt


ImportError: cannot import name 'EquifFlow' from 'equiflow' (/Users/joaomatos/Documents/equiflow/equiflow/__init__.py)

In [None]:
np.random.seed(42)
n = 100000
data = {
    'age': np.random.randint(18, 80, size=n),
    'sofa': np.random.choice([0,1,2,3,4,5,6,7,8,9,10,15, np.nan], size=n),
    'race': np.random.choice(['White', 'Black', 'Asian', 'Hispanic', None],
                             size=n),
    'sex': np.random.choice(['Male', 'Female'], size=n),
    'english': np.random.choice(['Fluent', 'Limited', np.nan, None], size=n),
}

for i in range(1, 11):
    data[f'var{i}'] = np.random.randn(n)

df = pd.DataFrame(data)

# apply exclusions to make a list
data_0 = df.copy()
data_1 = data_0.loc[data_0.english.notnull()]
data_2 = data_1.loc[data_1.sofa.notnull()]

# data_3a = data_2.loc[data_2.sofa <= 5]
# data_3b = data_2.loc[data_2.sofa > 5]
# data4a = data_3a.loc[data_3a.age < 50]
# data4b = data_3a.loc[data_3a.age >= 50]

In [None]:
class TableCharacteristics:
  def __init__(
      self,
      dfs: list,
      categorical: Optional[list] = None,
      normal: Optional[list] = None,
      nonnormal: Optional[list] = None,
      decimals: Optional[int] = 1,
      format_cat: Optional[str] = 'N (%)',
      format_normal: Optional[str] = 'Mean ± SD',
      format_nonnormal: Optional[str] = 'Median [IQR]',
      thousands_sep: Optional[bool] = True,
      missingness: Optional[bool] = True,
      label_suffix: Optional[bool] = True,
      rename: Optional[dict] = None,
  ) -> None:
        
    if not isinstance(dfs, list) or len(dfs) < 2:
        raise ValueError("dfs must be a list with length ≥ 2")
    
    if (categorical is None) & (normal is None) & (nonnormal is None):
        raise ValueError("At least one of categorical, normal, or nonnormal must be provided")
       
    if (categorical is not None) & (not isinstance(categorical, list)):
        raise ValueError("categorical must be a list")

    if (normal is not None) & (not isinstance(normal, list)):
        raise ValueError("normal must be a list")
    
    if (nonnormal is not None) & (not isinstance(nonnormal, list)):
        raise ValueError("nonnormal must be a list")
    
    if not isinstance(decimals, int) or decimals < 0:
        raise ValueError("decimals must be a non-negative integer")
    
    if format_cat not in ['%', 'N', 'N (%)']:
        raise ValueError("format must be '%', 'N', or 'N (%)'")
    
    if format_normal not in ['Mean ± SD', 'Mean', 'SD']:
        raise ValueError("format must be 'Mean ± SD' or 'Mean' or 'SD'")
    
    if format_nonnormal not in ['Median [IQR]', 'Mean', 'SD']:
        raise ValueError("format must be 'Median [IQR]' or 'Mean' or 'SD'")
    
    if not isinstance(thousands_sep, bool):
        raise ValueError("thousands_sep must be a boolean")
    
    if not isinstance(missingness, bool):
        raise ValueError("missingness must be a boolean")
    
    if not isinstance(label_suffix, bool):
        raise ValueError("label_suffix must be a boolean")
    
    if (rename is not None) & (not isinstance(rename, dict)):
      raise ValueError("rename must be a dictionary")
    
    self._dfs = dfs

    if categorical is None:
      self._categorical = []
    else:
       self._categorical = categorical
      
    if normal is None:
      self._normal = []
    else:
      self._normal = normal
    
    if nonnormal is None:
      self._nonnormal = []
    else:
      self._nonnormal = nonnormal

    self._decimals = decimals
    self._missingness = missingness
    self._format_cat = format_cat
    self._format_normal = format_normal
    self._format_nonnormal = format_nonnormal
    self._thousands_sep = thousands_sep
    self._label_suffix = label_suffix
    
    if rename is not None:
      self._rename = rename
    else:
       self._rename = dict()

    if rename is not None:
      if self._label_suffix:
          self._renamed_categorical = [
             self._rename[c] + ', ' + self._format_cat if c in self._rename.keys() \
              else c + ', ' + self._format_cat for c in self._categorical
          ]
          
          self._renamed_normal = [
              self._rename[n] + ', ' + self._format_normal if n in self._rename.keys() \
              else n + ', ' + self._format_normal for n in self._normal
          ]

          self._renamed_nonnormal = [
              self._rename[nn] + ', ' + self._format_nonnormal if nn in self._rename.keys() \
              else nn + ', ' + self._format_nonnormal for nn in self._nonnormal
          ]


      else:
        self._renamed_categorical = [
            self._rename[c] if c in self._rename.keys() else c for c in self._categorical
        ]

        self._renamed_normal = [
            self._rename[n] if n in self._rename.keys() else n for n in self._normal
        ]

        self._renamed_nonnormal = [
            self._rename[nn] if nn in self._rename.keys() else nn for nn in self._nonnormal
        ]

    else:
      if self._label_suffix:
        self._renamed_categorical = [c + ', ' + self._format_cat for c in self._categorical]
        self._renamed_normal = [n + ', ' + self._format_normal for n in self._normal]
        self._renamed_nonnormal = [nn + ', ' + self._format_nonnormal for nn in self._nonnormal]
      else:
        self._renamed_categorical = self._categorical
        self._renamed_normal = self._normal
        self._renamed_nonnormal = self._nonnormal


  # method to get the unique values, before any exclusion (at i=0)
  def _get_original_uniques(self, cols):

    original_uniques = dict()

    # get uniques values ignoring NaNs
    for c in cols:
      original_uniques[c] = self._dfs[0][c].dropna().unique()

    return original_uniques


  # method to get the value counts for a given column
  def _my_value_counts(self,
                       df: pd.DataFrame(),
                       original_uniques: dict,
                       col: str,
                      ) -> pd.DataFrame(): # type: ignore

    o_uniques = original_uniques[col]
    counts = pd.DataFrame(columns=[col], index=o_uniques)

    # get the number of observations, based on whether we want to include missingness
    if self._missingness:
      n = len(df)
    else:
      n = len(df) - df[col].isnull().sum() # denominator will be the number of non-missing observations

    for o in o_uniques:
      if self._format_cat == '%':
        counts.loc[o,col] = ((df[col] == o).sum() / n * 100).round(self._decimals)
  
      elif self._format_cat == 'N':
        if self._thousands_sep:
          counts.loc[o,col] = f"{(df[col] == o).sum():,}"
        else:
          counts.loc[o,col] = (df[col] == o).sum()
   
      elif self._format_cat == 'N (%)':
        n_counts = (df[col] == o).sum()
        perc_counts = (n_counts / n * 100).round(self._decimals)
        if self._thousands_sep:
          counts.loc[o,col] = f"{n_counts:,} ({perc_counts})"
        else:
          counts.loc[o,col] = f"{n_counts} ({perc_counts})"

      else:
        raise ValueError("format must be '%', 'N', or 'N (%)'")

    return counts 
  
  # method to report distribution of normal variables
  def _normal_vars_dist(self,
                        df: pd.DataFrame(),
                        col: str,
                        df_dists: pd.DataFrame(),
                        ) -> pd.DataFrame():
    
    df.loc[:,col] = pd.to_numeric(df[col], errors='raise')
    
    if self._format_normal == 'Mean ± SD':
      col_mean = np.round(df[col].mean(), self._decimals)
      col_std = np.round(df[col].std(), self._decimals)
      df_dists.loc[(col, ' '), 'value'] = f"{col_mean} ± {col_std}"

    elif self._format_normal == 'Mean':
      col_mean = np.round(df[col].mean(), self._decimals)
      df_dists.loc[(col, ' '), 'value'] = col_mean
    
    elif self._format_normal == 'SD':
      col_std = np.round(df[col].std(), self._decimals)
      df_dists.loc[(col, ' '), 'value'] = col_std

    return df_dists
  
  def _nonnormal_vars_dist(self,
                           df: pd.DataFrame(),
                           col: str,
                           df_dists: pd.DataFrame(),
                          ) -> pd.DataFrame():
     
    df.loc[:,col] = pd.to_numeric(df[col], errors='raise')

    if self._format_nonnormal == 'Mean':
      col_mean = np.round(df[col].mean(), self._decimals)
      df_dists.loc[(col, ' '), 'value'] = col_mean

    elif self._format_nonnormal == 'Median [IQR]':
      col_median = np.round(df[col].median(), self._decimals)
      col_q1 = np.round(df[col].quantile(0.25), self._decimals)
      col_q3 = np.round(df[col].quantile(0.75), self._decimals)

      df_dists.loc[(col, ' '), 'value'] = f"{col_median} [{col_q1}, {col_q3}]"

    elif self._format_nonnormal == 'SD':
      col_std = np.round(df[col].std(), self._decimals)
      df_dists.loc[(col, ' '), 'value'] = col_std

    return df_dists
  

  # method to add missing counts to the table
  def _add_missing_counts(self,
                           df: pd.DataFrame(),
                           col: str,
                           df_dists: pd.DataFrame(),
                           ) -> pd.DataFrame(): # type: ignore

    n = len(df)

    if self._format_cat == '%':
      df_dists.loc[(col,'Missing'),'value'] = (df[col].isnull().sum() / n * 100).round(self._decimals)
    
    elif self._format_cat == 'N':
      if self._thousands_sep:
        df_dists.loc[(col,'Missing'),'value'] = f"{df[col].isnull().sum():,}"
      else:
        df_dists.loc[(col,'Missing'),'value'] = df[col].isnull().sum()

    elif self._format_cat == 'N (%)':
      n_missing = df[col].isnull().sum()
      perc_missing = df[col].isnull().sum() / n * 100
      if self._thousands_sep:
        df_dists.loc[(col,'Missing'),'value'] = f"{n_missing:,} ({(perc_missing).round(self._decimals)})"
      else: 
        df_dists.loc[(col,'Missing'),'value'] = f"{n_missing} ({(perc_missing).round(self._decimals)})"

    else:
      raise ValueError("format must be '%', 'N', or 'N (%)'")

    return df_dists
  
  
  # method to add overall counts to the table
  def _add_overall_counts(self,
                           df,
                           df_dists
                           ) -> pd.DataFrame(): # type: ignore

    if self._thousands_sep:
      df_dists.loc[('Overall', ' '), 'value'] = f"{len(df):,}"
    else:
      df_dists.loc[('Overall', ' '), 'value'] = len(df)


    return df_dists
  
  # method to add label_suffix to the table
  def _add_label_suffix(self,
                         col: str,
                         df_dists: pd.DataFrame(),
                         suffix: str,
                         ) -> pd.DataFrame(): # type: ignore

    new_col = col + suffix
    df_dists = df_dists.rename(index={col: new_col}) 

    return df_dists
  
  # method to rename columns
  def _rename_columns(self,
                       df_dists: pd.DataFrame(),
                       col: str,
                      ) -> pd.DataFrame():
    
    return self._rename[col], df_dists.rename(index={col: self._rename[col]})
  
  def view(self):

    table = pd.DataFrame()

    # get the unique values, before any exclusion, for categorical variables
    original_uniques = self._get_original_uniques(self._categorical)

    for i, df in enumerate(self._dfs):

      df_dists = pd.DataFrame()

      # get distribution for categorical variables
      for col in self._categorical:

        counts = self._my_value_counts(df, original_uniques, col)

        melted_counts = pd.melt(counts.reset_index(), id_vars=['index']) \
                          .set_index(['variable','index'])

        df_dists = pd.concat([df_dists, melted_counts], axis=0)

        if self._missingness:
          df_dists = self._add_missing_counts(df, col, df_dists)

        # rename if applicable
        if col in self._rename.keys():
          col, df_dists = self._rename_columns(df_dists, col)

        if self._label_suffix:
            df_dists = self._add_label_suffix(col, df_dists, ', ' + self._format_cat)
          

      # get distribution for normal variables
      for col in self._normal:
  
          df_dists = self._normal_vars_dist(df, col, df_dists)
          
          if self._missingness:
            df_dists = self._add_missing_counts(df, col, df_dists)

          if col in self._rename.keys():
            col, df_dists = self._rename_columns(df_dists, col)

          if self._label_suffix:
            df_dists = self._add_label_suffix(col, df_dists, ', ' + self._format_normal)
        
      # get distribution for nonnormal variables
      for col in self._nonnormal:

        df_dists = self._nonnormal_vars_dist(df, col, df_dists)

        if self._missingness:
          df_dists = self._add_missing_counts(df, col, df_dists)
        
        if col in self._rename.keys():
          col, df_dists = self._rename_columns(df_dists, col)

        if self._label_suffix:
          df_dists = self._add_label_suffix(col, df_dists, ', ' + self._format_nonnormal)


      df_dists = self._add_overall_counts(df, df_dists)
    
      df_dists.rename(columns={'value': i}, inplace=True)
      table = pd.concat([table, df_dists], axis=1)

    # add super header
    table = table.set_axis(
        pd.MultiIndex.from_product([['Cohort'], table.columns]),
        axis=1)

    # renames indexes
    table.index.names = ['Variable', 'Value']

    # reorder values of "Variable" (level 0) such that 'Overall' comes first
    table = table.sort_index(level=0, key=lambda x: x == 'Overall',
                             ascending=False, sort_remaining=False)

    return table

In [None]:
a = data_0.english.notnull()

dtype('bool')

In [None]:
# build table flows
table_flows = TableFlows(
    dfs=[data_0, data_1, data_2],
    label_suffix=True,
    thousands_sep=True,
)

table_flows.view()

Cohort Flow,0 to 1,1 to 2
,,
"Initial, n",100000.0,50022.0
"Removed, n",49978.0,3874.0
"Result, n",50022.0,46148.0


In [None]:
table_characteristics = TableCharacteristics(
    dfs = [data_0, data_1, data_2],
    categorical = ['race','sex', 'english'],
    # normal = [],
    # nonnormal = [],   
    nonnormal = ['sofa'],
    normal = ['age'],
    format_cat = '%',
    # format_cont = 'Mean ± SD',
    format_normal = 'Mean',
    format_nonnormal='Median [IQR]',
    missingness = True,
    decimals = 1,
    label_suffix = True,
    thousands_sep = False,
    # rename = {}
    rename={'race': 'Race and Ethnicity',
            'english': 'English Proficiency',
            'sex':'Sex',
            'sofa': 'SOFA',
            'age': 'Age',  
            }
)

table_characteristics.view()

Unnamed: 0_level_0,Unnamed: 1_level_0,Cohort,Cohort,Cohort
Unnamed: 0_level_1,Unnamed: 1_level_1,0,1,2
Variable,Value,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
Overall,,100000,50022,46148
"Race and Ethnicity, %",Hispanic,20.0,20.0,19.9
"Race and Ethnicity, %",Asian,20.0,19.8,19.8
"Race and Ethnicity, %",Black,20.1,20.2,20.2
"Race and Ethnicity, %",White,19.9,19.9,20.0
"Race and Ethnicity, %",Missing,20.0,20.0,20.1
"Sex, %",Male,50.1,50.1,50.0
"Sex, %",Female,49.9,49.9,50.0
"Sex, %",Missing,0.0,0.0,0.0
"English Proficiency, %",Fluent,25.1,50.2,50.3


In [None]:
table_drifts = TableDrifts(
    dfs=[data_0, data_1, data_2],
    categorical = ['race','sex', 'english'],
    # categorical=[],
    nonnormal = ['sofa'],
    normal = ['age'],
    # nonnormal = [],
    # normal = [],
    # missingness = True,
    decimals = 3,
    # label_suffix = True,
    # thousands_sep = False,
    rename={'race': 'Race and Ethnicity',
            'english': 'English Proficiency',
            'sex':'Sex',
            'sofa': 'SOFA',
            'age': 'Age',  
            }
)
table_drifts.view()

Unnamed: 0_level_0,Cohort Flow,0 to 1,1 to 2
Variable,Value,Unnamed: 2_level_1,Unnamed: 3_level_1
Overall,,,
Race and Ethnicity,Hispanic,0.0,0.002
Race and Ethnicity,Asian,0.003,0.0
Race and Ethnicity,Black,0.003,0.001
Race and Ethnicity,White,0.0,0.001
Sex,Male,0.0,0.001
Sex,Female,0.0,0.001
English Proficiency,Fluent,0.537,0.002
English Proficiency,Limited,0.532,0.002
Age,,-0.003,-0.001


In [None]:
table_drifts.view_simple()


Cohort Flow,0 to 1,1 to 2
Race and Ethnicity,0.006,0.005
Sex,0.0,0.0
English Proficiency,1.414,0.0
Age,-0.003,-0.001
SOFA,0.009,0.0


In [None]:
table_drifts = TableDrifts(
    dfs=[data_0, data_1, data_2],
    categorical = ['race','sex', 'english'],
    # categorical=[],
    nonnormal = ['sofa'],
    normal = ['age'],
    # nonnormal = [],
    # normal = [],
    # missingness = True,
    decimals = 3,
    # label_suffix = True,
    # thousands_sep = False,
    rename={'race': 'Race and Ethnicity',
            'english': 'English Proficiency',
            'sex':'Sex',
            'sofa': 'SOFA',
            'age': 'Age',  
            }
)
table_drifts.view_simple()

Cohort Flow,0 to 1,1 to 2
Race and Ethnicity,0.006,0.005
Sex,0.0,0.0
English Proficiency,1.414,0.0
Age,-0.003,-0.001
SOFA,0.009,0.0


In [None]:


# build flow diagram
flow_diagram = FlowDiagram(
    table_flows,
    table_characteristics,
    table_drifts,
    cohort_labels=[
        '___ patients \nin MIMIC-IV',
        '___ patients with \nEnglish proficiency data',
        '___ patients with \nSOFA data',
        '___ patients in \nthe final cohort',
    ],
    exclusion_labels=[
        '___ patients excluded for\nmissing English proficiency',
        '___ patients excluded \nfor missing SOFA score',
    ],
    legend=True,
    smds=True,
    plot_dists=True, 
    legend_with_vars=True,
    box_width=2.5,
    box_height=1,
    )

flow_diagram.view()

# from wand.image import Image as WImage
# img = WImage(filename='temp/patient_flow.pdf', resolution=120) # bigger
# img


# # delete temp folder with os
# import shutil
# shutil.rmtree('temp')

In [None]:
# fix bugs with categorical needing to be always there
# create equiflow class
# add text reports

In [None]:
equiflow.add_exclusion(mask=df0.english.notnull(), label)
equiflow.add_exclusion(mask=df0.english.notnull(), label)
equiflow.add_exclusion(mask=df0.english.notnull(), label)

#kwargs or args

SyntaxError: positional argument follows keyword argument (1287175066.py, line 1)