In [1]:
import os
import random
import sys
import time
import warnings

import matplotlib.colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# rich
from rich import box
from rich.console import Console
from rich.layout import Layout
from rich.live import Live
from rich.panel import Panel
from rich.table import Column,Table
# end rich

from sklearn.metrics import roc_auc_score, accuracy_score, f1_score, confusion_matrix
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import KBinsDiscretizer, StandardScaler, MinMaxScaler, Normalizer, OrdinalEncoder
from sklearn.naive_bayes import CategoricalNB
from sklearn.tree import DecisionTreeClassifier
from sklearn.dummy import DummyClassifier

from tableshift import get_dataset
from tableshift.core.features import PreprocessorConfig
from tableshift.core.tasks import get_task_config

from ud_bagging import UDBaggingClassifier, balanced_weight_vector
from ud_naive_bayes import InterpretableBernoulliNB, InterpretableMultinomialNB, InterpretableCategoricalNB

In [2]:
# Silence every warning for the rest of the session.
warnings.filterwarnings("ignore")

# Benchmark tasks as [display label, dataset identifier] pairs.
# Entries that are commented out are excluded from the run.
data = [
    ['ASSISTments', 'assistments'],
    ['Childhood Lead', 'nhanes_lead'],
    ['College Scorecard', 'college_scorecard'],
    ['Diabetes', 'brfss_diabetes'],
    ['FICO HELOC', 'heloc'],
    ['Food Stamps', 'acsfoodstamps'],
    ['Hospital Readmission', 'diabetes_readmission'],
    ['Hypertension', 'brfss_blood_pressure'],
    # ['ICU Length of Stay', 'mimic_extract_los_3'],
    # ['ICU Mortality', 'mimic_extract_mort_hosp'],
    ['Income', 'acsincome'],
    # ['Public Health Insurance', 'acspubcov'],
    ['Sepsis', 'physionet'],
    ['Unemployment', 'acsunemployment'],
    ['Voting', 'anes'],
]

# Pick the task to work on; index 2 is the 'College Scorecard' entry.
dataset, identifier = data[2]

In [3]:

# Fetch the task chosen via `identifier`.
# NOTE(review): use_cached=True with initialize_data=False presumably reads a
# previously cached copy from cache_dir rather than rebuilding it -- confirm
# against the tableshift documentation.
dset = get_dataset(
    name=identifier,
    cache_dir='../tableshift/tmp',
    use_cached=True,
    initialize_data=False,
)

# get_pandas() yields four values per split; only the first two (features and
# labels) are kept, the remaining two are discarded.
# 'id_test' / 'ood_test' are presumably the in- and out-of-distribution test splits.
X_a, y_a, _, _ = dset.get_pandas('train')
X_id, y_id, _, _ = dset.get_pandas('id_test')
X_b, y_b, _, _ = dset.get_pandas('ood_test')


In [15]:
# Display labels (first column) of every enabled entry in `data`.
np.array(data)[:, 0]

array(['ASSISTments', 'Childhood Lead', 'College Scorecard', 'Diabetes',
       'FICO HELOC', 'Food Stamps', 'Hospital Readmission',
       'Hypertension', 'Income', 'Sepsis', 'Unemployment', 'Voting'],
      dtype='<U20')

In [11]:
cols = ['Dataset', 'ID', 'OOD', 'S1', 'S2', 'S3', 'S4']


def _bold_row_max(row):
    """Return one CSS string per cell, bolding the largest numeric value in `row`.

    Non-numeric and missing cells are ignored. A row with no numeric value at
    all (e.g. before any result has been filled in) gets no styling.
    """
    scores = pd.to_numeric(row, errors='coerce')
    if scores.isna().all():
        return [''] * len(row)
    best = scores.max()
    return ['font-weight: bold' if value == best else '' for value in scores]


# One row per task in `data`: indexed by dataset identifier, labelled with the
# display name, and with empty (NaN) score columns to be filled in later.
tasks = np.array(data)
df_result = pd.DataFrame(index=tasks[:, 1], columns=cols)
df_result['Dataset'] = tasks[:, 0]

# The original call passed no styling function, which is a SyntaxError.
# NOTE(review): bolding the best score per row across the OOD/S1..S4 columns is
# an assumed intent -- swap in a different function if another highlight was meant.
df_result.style.apply(_bold_row_max, axis=1, subset=cols[2:])

In [21]:
# Columns after 'Dataset' and 'ID' -- the same slice the results-table styling uses as its subset.
cols[2:]

['OOD', 'S1', 'S2', 'S3', 'S4']

In [5]:
# Print the path of the script that started this Python process
# (`sys` is imported with the other modules in the first cell).
print(sys.argv[0])

/home/fslab/github/mixed/.venv/lib/python3.9/site-packages/ipykernel_launcher.py


In [20]:
# Render the results table as LaTeX source: index column dropped, missing
# scores left blank, floats shown with three decimals, first column
# left-aligned and the remaining six right-aligned.
latex_table = df_result.to_latex(
    index=False,
    na_rep='',
    float_format="%.03f",
    column_format='lrrrrrr',
)
print(latex_table)

\begin{table}
\centering
\caption[table]{Accuracy for Model Adaptation with balanced ensemble of Categorical \textit{Na\"ive} Bayes}
\begin{tabular}{lrrrrrr}
\toprule
             Dataset & ID & OOD & S1 & S2 & S3 & S4 \\
\midrule
         ASSISTments &    &     &    &    &    &    \\
      Childhood Lead &    &     &    &    &    &    \\
   College Scorecard &    &     &    &    &    &    \\
            Diabetes &    &     &    &    &    &    \\
          FICO HELOC &    &     &    &    &    &    \\
         Food Stamps &    &     &    &    &    &    \\
Hospital Readmission &    &     &    &    &    &    \\
        Hypertension &    &     &    &    &    &    \\
              Income &    &     &    &    &    &    \\
              Sepsis &    &     &    &    &    &    \\
        Unemployment &    &     &    &    &    &    \\
              Voting &    &     &    &    &    &    \\
\bottomrule
\end{tabular}
\end{table}



In [12]:
# Summary statistics for three feature columns of the training split.
X_a.loc[:, ['locale2', 'ACTWRMID', 'age_entry']].describe()

Unnamed: 0,locale2,ACTWRMID,age_entry
count,98556.0,98556.0,98556.0
mean,-1.0,-0.92936,11.863509
std,0.0,0.817806,26.395969
min,-1.0,-1.0,-1.0
25%,-1.0,-1.0,-1.0
50%,-1.0,-1.0,-1.0
75%,-1.0,-1.0,1.0
max,-1.0,20.0,99.0
