In [None]:
import sys
import torch
import os
os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES']='0'
from sklearn.metrics import auc
from sklearn.metrics import recall_score
from sklearn.metrics import roc_curve
sys.path.append('../..')
sys.path.append('..')
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, accuracy_score
import torch.nn.functional as F
from train_setup import setup_directories, setup_tensorboard, setup_logging
# From datasets import get_attr_max_min
from utils import EMA, seed_all
from vae import HVAE
from pgm.train_pgm import preprocess, sup_epoch, eval_epoch
from pgm.utils_pgm import plot_cf, check_nan, update_stats
from pgm.layers import TraceStorage_ELBO
from pgm.flow_pgm import FlowPGM
# from _flow_pgm_legacy import FlowPGM
from pgm.train_cf import DSCM
from PIL import Image
from matplotlib import colors
from chexploration.mimic_multitask import DenseNet
from sklearn.manifold import TSNE
import seaborn as sns
import pandas as pd
import copy

In [None]:
def norm(batch):
    """Normalise a raw batch dict in place and return the same dict.

    Per key:
      * 'x'    -> cast to float and mapped with (x - 127.5) / 127.5
                  (i.e. [-1, 1] assuming 8-bit pixel values).
      * 'age'  -> float, trailing feature axis added, divided by 100 and
                  mapped to [-1, 1].
      * 'race' -> integer class indices turned into a float one-hot matrix
                  with 3 columns.
      * anything else (including 'finding') -> float with a trailing axis.

    Args:
        batch: dict mapping variable names to tensors.

    Returns:
        The same dict object, with its values replaced.
    """
    # Iterate over a snapshot of the keys since values are reassigned in the loop.
    for k in list(batch):
        v = batch[k]
        if k == 'x':
            batch[k] = (v.float() - 127.5) / 127.5  # [-1,1]
        elif k == 'age':
            batch[k] = (v.float().unsqueeze(-1) / 100.) * 2 - 1  # [-1,1]
        elif k == 'race':
            # reshape(-1, 3) rather than a bare squeeze(): squeeze() also removes
            # the batch axis when the batch holds a single sample, turning the
            # expected (1, 3) matrix into a (3,) vector.
            batch[k] = F.one_hot(v, num_classes=3).reshape(-1, 3).float()
        else:
            batch[k] = v.float().unsqueeze(-1)
    return batch

def loginfo(title, logger, stats):
    """Emit one INFO line: the title followed by every stat at 4 decimals.

    Args:
        title: prefix printed before the ' | ' separator.
        logger: any object exposing an ``info(message)`` method.
        stats: mapping of stat name -> number.
    """
    formatted = [f'{name}: {value:.4f}' for name, value in stats.items()]
    message = f'{title} | ' + ' - '.join(formatted)
    logger.info(message)

def inv_preprocess(pa):
    """Map the 'age' entry of ``pa`` from [-1, 1] back to [0, 100], in place.

    Entries other than 'age' are left untouched. The same dict object is
    returned.
    """
    if 'age' in pa:
        pa['age'] = (pa['age'] + 1) / 2 * 100
    return pa


def vae_preprocess(args, pa):
    """Concatenate the image's parent variables and tile them spatially.

    The tensors ``pa[k]`` for every ``k`` in ``args.parents_x`` are joined
    along dim 1, then broadcast over an ``args.input_res`` x
    ``args.input_res`` grid.

    Returns:
        Float tensor with two trailing spatial axes of size
        ``args.input_res`` appended to the concatenated parents.
    """
    res = args.input_res
    parents = [pa[name] for name in args.parents_x]
    stacked = torch.cat(parents, dim=1)
    tiled = stacked.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, res, res)
    return tiled.float()


def get_metrics(preds, targets):
    """Stack per-batch predictions/targets and compute the age MAE.

    Both dicts are modified in place: every list of tensors is replaced by a
    single stacked, squeezed CPU tensor.

    Args:
        preds: dict of variable name -> list of prediction tensors.
        targets: dict with (at least) the same keys -> list of target tensors.

    Returns:
        dict containing 'age_mae' when 'age' is among the predictions,
        otherwise an empty dict. Age predictions are mapped from [-1, 1] to
        [0, 100] before the error is taken.
    """
    for name in list(preds):
        preds[name] = torch.stack(preds[name]).squeeze().cpu()
        targets[name] = torch.stack(targets[name]).squeeze().cpu()

    stats = {}
    if "age" in preds:
        rescaled = (preds["age"] + 1) / 2 * 100  # [-1,1] -> [0,100]
        stats["age" + '_mae'] = (targets["age"] - rescaled).abs().mean().item()
    return stats

class Hparams:
    """Plain attribute bag for hyper-parameters."""

    def update(self, dict):
        """Copy every key/value pair of ``dict`` onto this object as attributes."""
        for name in dict:
            setattr(self, name, dict[name])

In [None]:
# NOTE(review): num_classes is not read anywhere in the visible notebook;
# presumably the number of disease labels (the DenseNet below is built with
# num_classes_disease=14) — confirm or remove.
num_classes = 14
# Batch keys that test() collects alongside predictions.
attribute_names = ['sex', 'race', 'finding']
# Subgroup names; the position in each list is the class index noted on the right.
sex_categories = ['Male', 'Female']  # 0,1
race_categories = ['White', 'Asian', 'Black']  # 0,1,2
finding_categories = ['No disease', 'Pleural Effusion']

target_label = "Pleural Effusion"    # NOTE(review): not referenced in the visible notebook

def test(model, data_loader, device):
    """Run ``model`` over ``data_loader`` and gather its disease outputs.

    Each batch's 'x' is rescaled with (x + 1) * 127.5 and repeated to three
    channels before being passed to the model; only the first element of the
    model output is used.

    Args:
        model: module whose call returns an indexable whose item 0 holds the
            disease logits.
        data_loader: iterable of dict batches containing 'x', 'finding' and
            every key listed in ``attribute_names``.
        device: device the images and labels are moved to.

    Returns:
        Tuple ``(preds, targets, logits, attributes)`` of numpy arrays, where
        ``preds`` are sigmoid probabilities and ``attributes`` is a dict of
        concatenated per-attribute arrays.
    """
    model.eval()
    logit_chunks, prob_chunks, label_chunks = [], [], []
    attribute_chunks = {name: [] for name in attribute_names}

    with torch.no_grad():
        for batch in tqdm(data_loader, desc='Test-loop'):
            images = batch['x'].to(device)
            labels = batch['finding'].to(device)
            # Undo the [-1, 1] scaling and replicate the single channel three times.
            images = ((images + 1) * 127.5).repeat(1, 3, 1, 1)
            batch_logits = model(images)[0]
            logit_chunks.append(batch_logits)
            prob_chunks.append(torch.sigmoid(batch_logits))
            label_chunks.append(labels)
            for name, chunks in attribute_chunks.items():
                chunks.append(batch[name])

        logits = torch.cat(logit_chunks, dim=0)
        preds = torch.cat(prob_chunks, dim=0)
        targets = torch.cat(label_chunks, dim=0)
        attributes = {
            name: torch.cat(chunks, dim=0).cpu().numpy()
            for name, chunks in attribute_chunks.items()
        }

    return preds.cpu().numpy(), targets.cpu().numpy(), logits.cpu().numpy(), attributes




### Set predictor

In [None]:
# Load predictors
# Rebuild the predictor FlowPGM from a checkpoint: the hyper-parameters stored
# in the checkpoint are copied onto `args`, a few dataset settings are then
# overridden, and the EMA weights are loaded.
args = Hparams()
args.predictor_path = 'PREDICTOR_PATH'  # placeholder: set to the predictor checkpoint file
# NOTE(review): torch.load is called without map_location, so the checkpoint
# is restored onto the devices it was saved from — confirm this is intended.
predictor_checkpoint = torch.load(args.predictor_path)
args.update(predictor_checkpoint['hparams'])
# The assignments below take precedence over the values stored in the checkpoint.
args.use_dataset = 'mimic'
args.csv_dir =  "META_DATA_DIR"  # placeholder path
args.data_dir = "DATA_DIR"  # placeholder path
args.loss_norm = "l2"
predictor = FlowPGM(args)
predictor.load_state_dict(predictor_checkpoint['ema_model_state_dict'])

### Set PGM

In [None]:
# Load PGM
# The PGM gets its own Hparams object built purely from its checkpoint, so it
# does not pick up the overrides applied to `args`.
args.pgm_path = 'PGM_PATH'  # placeholder: set to the PGM checkpoint file
print(f'\nLoading PGM checkpoint: {args.pgm_path}')
pgm_checkpoint = torch.load(args.pgm_path)
pgm_args = Hparams()
pgm_args.update(pgm_checkpoint['hparams'])
pgm = FlowPGM(pgm_args)
# EMA weights are used, as for the predictor.
pgm.load_state_dict(pgm_checkpoint['ema_model_state_dict'])

### Set VAE

In [None]:
# Rebuild the HVAE from its checkpoint using the hyper-parameters stored with it.
args.vae_path = 'VAE_PATH'  # placeholder: set to the VAE checkpoint file

print(f'\nLoading VAE checkpoint: {args.vae_path}')
vae_checkpoint = torch.load(args.vae_path)
vae_args = Hparams()
vae_args.update(vae_checkpoint['hparams'])
vae = HVAE(vae_args)
# EMA weights are used.
vae.load_state_dict(vae_checkpoint['ema_model_state_dict'])

### Load DSCM

In [None]:
# Assemble the DSCM from the PGM, predictor and VAE loaded above and restore
# its weights. `args` is rebound to a fresh Hparams here, so the predictor
# settings stored on the previous object are no longer reachable through it.
args = Hparams()
dscm_dir = "mimic_dscm_new_classifier_lr_1e4_lagrange_lr_1_damping_10"
which_checkpoint="12000_checkpoint"


args.load_path = f'../checkpoints/a_r_s_f/{dscm_dir}/{which_checkpoint}.pt'
# Root directory for every figure written by the plotting helpers below.
_save_fig_dir = f'SAVE_DIR/{dscm_dir}/{which_checkpoint}'
os.makedirs(_save_fig_dir, exist_ok = True)
dscm_checkpoint = torch.load(args.load_path)
args.update(dscm_checkpoint['hparams'])
model = DSCM(args, pgm, predictor, vae)
# Set after the checkpoint hparams are applied, so this value wins.
args.cf_particles =1
# The non-EMA weights are loaded here; the EMA alternative is kept for reference.
# model.load_state_dict(dscm_checkpoint['ema_model_state_dict'])
model.load_state_dict(dscm_checkpoint['model_state_dict'])
model.cuda()
elbo_fn = TraceStorage_ELBO(num_particles=1)

print(args.input_res)
# Freeze every DSCM parameter (inference only).
# NOTE(review): model.eval() is never called in the visible notebook — confirm
# whether the DSCM sub-modules need evaluation mode.
for p in model.parameters():
    p.requires_grad = False

### Load pre-trained DenseNet

In [None]:

# Multi-task DenseNet (disease / sex / race heads) used to score original and
# counterfactual images.
densenet = DenseNet(num_classes_disease=14, num_classes_sex=2, num_classes_race=3, class_weights_race=(1.0, 1.0, 1.0))

# Restore the trained weights from the checkpoint's 'state_dict' entry and
# move the network to the GPU.
densenet_path = "../chexploration/multi_task_ckpt/epoch=7-step=5887.ckpt"
densenet_checkpoint = torch.load(densenet_path)
densenet.load_state_dict(densenet_checkpoint['state_dict'])
densenet = densenet.cuda()

### Load data

In [None]:
# Build one test-split loader per subgroup, keyed by the subgroup name.
args.bs = 1
from train_pgm import setup_dataloaders
subloaders = {}

# For race
for _race in ("White", "Black", "Asian"):
    subloaders[_race] = setup_dataloaders(args, select_subgroup=True, race_choice=_race)['test']

print(f"White: {len(subloaders['White'].dataset)}, Black: {len(subloaders['Black'].dataset)}, Asian: {len(subloaders['Asian'].dataset)}")


# For sex
for _sex in ("Male", "Female"):
    subloaders[_sex] = setup_dataloaders(args, select_subgroup=True, sex_choice=_sex)['test']
print(f"male: {len(subloaders['Male'].dataset)}, female: {len(subloaders['Female'].dataset)}")

# For finding (the loader key and the dataset's finding label differ for the healthy group)
for _key, _finding in (("No disease", 'No Finding'), ("Pleural Effusion", 'Pleural Effusion')):
    subloaders[_key] = setup_dataloaders(args, select_subgroup=True, finding_choice=_finding)['test']
print(f"No disease: {len(subloaders['No disease'].dataset)}, Pleural Effusion: {len(subloaders['Pleural Effusion'].dataset)}")

# Race/Sex/Finding categories
race_categories = ['White', 'Asian', 'Black']
sex_categories = ['Male', 'Female']
finding_categories = ['No disease', 'Pleural Effusion']



### Check logits

In [None]:
# Put the DenseNet in evaluation mode before it is used to score images.
densenet = densenet.eval()

In [None]:
# Batch-count cap for the sampling loops below: they break once their counter
# exceeds _N_sample, i.e. after 1000 // args.bs batches per subgroup.
_N_sample = 1000 // args.bs-1

def get_logits_for_attribute_a_on_CFs_intervened_on_attribute_b(attribute_a, attribute_b):
    """Collect DenseNet logits for original images and their counterfactuals.

    For every subgroup of ``attribute_a`` the function first scores the
    original test images, then scores counterfactuals produced by the DSCM
    (``model``) under an intervention on ``attribute_b``. At most
    ``_N_sample + 1`` batches are consumed per subgroup and pass.

    Interventions:
      * 'sex' / 'finding': the binary value is flipped (1 - value).
      * 'race': the class index is shifted by 1 or 2 (modulo 3) with equal
        probability, so the new race always differs from the observed one.
      * 'age': ages are drawn at random from the subgroup's own samples and
        mapped to [-1, 1].

    Args:
        attribute_a: 'sex', 'finding' or 'race' — which subgroup loaders to use.
        attribute_b: 'sex', 'finding', 'race' or 'age' — variable intervened on.

    Returns:
        pandas.DataFrame with seven 'logit - ...' columns and a column named
        ``attribute_a`` holding either the subgroup name (original images) or
        "<subgroup> do(<attribute>)" (counterfactuals; 'finding' is spelled
        'disease' in that label).

    Raises:
        ValueError: if ``attribute_a`` or ``attribute_b`` is not supported.

    NOTE(review): the race/age interventions draw from torch's global RNG and
    no seed is set in the visible notebook, so results vary between runs.
    """
    # Fail fast instead of hitting an UnboundLocalError after data was sampled.
    if attribute_a not in ("sex", "finding", "race"):
        raise ValueError(
            f"attribute_a must be 'sex', 'finding' or 'race', got {attribute_a!r}")
    if attribute_b not in ("race", "sex", "finding", "age"):
        raise ValueError(
            f"attribute_b must be 'race', 'sex', 'finding' or 'age', got {attribute_b!r}")

    attr_a_category = {
        "sex": sex_categories,
        "finding": finding_categories,
        "race": race_categories,
    }[attribute_a]
    do_label = "disease" if attribute_b == "finding" else attribute_b

    cf_logits = {
        'logit - Pleural Effusion':[],
        'logit - No Disease':[],
        'logit - Male':[],
        'logit - Female':[],
        'logit - White':[],
        'logit - Black':[],
        'logit - Asian':[],
        attribute_a:[],
    }

    def _record_logits(label, img):
        """Score ``img`` (in [-1, 1]) with the DenseNet and append one row per sample."""
        img = (img + 1) * 127.5
        img = img.repeat(1, 3, 1, 1)
        logits_disease, logits_sex, logits_race = densenet.forward(img)
        # One device-to-host copy per head instead of one per scalar.
        logits_disease = logits_disease.cpu()
        logits_sex = logits_sex.cpu()
        logits_race = logits_race.cpu()
        for i in range(len(logits_disease)):
            cf_logits[attribute_a].append(label)
            cf_logits['logit - Pleural Effusion'].append(logits_disease[i][10].item())
            cf_logits['logit - No Disease'].append(logits_disease[i][0].item())
            cf_logits['logit - Male'].append(logits_sex[i][0].item())
            cf_logits['logit - Female'].append(logits_sex[i][1].item())
            cf_logits['logit - White'].append(logits_race[i][0].item())
            cf_logits['logit - Asian'].append(logits_race[i][1].item())
            cf_logits['logit - Black'].append(logits_race[i][2].item())

    def _build_intervention(batch, pa, age_pool, n_pool, bs):
        """Return the ``do`` dict for one preprocessed batch."""
        if attribute_b in ("sex", "finding"):
            return {attribute_b: 1 - pa[attribute_b]}
        if attribute_b == "race":
            current = torch.argmax(batch[attribute_b], dim=-1)
            shift = torch.bernoulli(1/2*torch.ones_like(current))
            new_race = (current + shift + 1) % 3
            # reshape(-1, 3) keeps the batch axis for single-sample batches,
            # where a bare squeeze() would return a (3,) vector.
            return {attribute_b: F.one_hot(new_race.long(), num_classes=3).reshape(-1, 3).float()}
        # attribute_b == "age"
        drawn = age_pool[torch.randperm(n_pool)][:bs].unsqueeze(1)
        return {attribute_b: (drawn / 100)*2 - 1}

    dag_variables = list(model.pgm.variables.keys())

    with torch.no_grad():
        # Pass 1: logits on the original images of every subgroup.
        for attr_a in attr_a_category:
            for count, batch in enumerate(tqdm(subloaders[attr_a])):
                if count > _N_sample:
                    break
                batch = preprocess(batch)
                _record_logits(attr_a, batch['x'])

        # Pass 2: logits on counterfactuals intervened on attribute_b.
        for attr_a in attr_a_category:
            cf_title = f"{attr_a} do({do_label})"
            n_pool = len(subloaders[attr_a].dataset)
            age_pool = None
            if attribute_b == "age":
                # Only the age column is needed; np.array already yields a
                # private copy, so the dataset's samples are not modified.
                age_pool = torch.from_numpy(
                    np.array(subloaders[attr_a].dataset.samples["age"]))

            for count, batch in enumerate(tqdm(subloaders[attr_a])):
                if count > _N_sample:
                    break
                bs = batch['x'].shape[0]
                batch = preprocess(batch)
                pa = {k: v for k, v in batch.items() if k in dag_variables}

                do = preprocess(_build_intervention(batch, pa, age_pool, n_pool, bs))

                # get counterfactual x
                out_cf = model.forward(batch, do, elbo_fn, cf_particles=args.cf_particles)
                _record_logits(cf_title, out_cf['cfs']['x'])

    return pd.DataFrame.from_dict(cf_logits)

def plot_logits(cf_logits, select_keys, hue_order, attribute_a, xdat, ydat):
    """Save a joint scatter plot of two logit columns for selected groups.

    Args:
        cf_logits: DataFrame holding the logit columns and a column named
            ``attribute_a`` with the group label of every row.
        select_keys: group labels to keep; also used to build the file name.
        hue_order: order of the groups in the legend / colour cycle.
        attribute_a: 'race', 'sex' or 'finding' — the grouping column.
        xdat: column plotted on the x axis.
        ydat: column plotted on the y axis.

    Raises:
        ValueError: if ``attribute_a`` is not one of the supported columns.

    The figure is written to ``<_save_fig_dir>/logits/logits_<keys>.pdf`` and
    all open matplotlib figures are closed afterwards.

    NOTE(review): the file name does not contain xdat/ydat, so plotting a
    different pair of columns for the same keys overwrites the earlier file.
    """
    # Reject unsupported columns explicitly; previously this fell through the
    # match statement and failed later with an UnboundLocalError.
    if attribute_a not in ("race", "sex", "finding"):
        raise ValueError(
            f"attribute_a must be 'race', 'sex' or 'finding', got {attribute_a!r}")

    # Select subgroups
    select_cf_logits = cf_logits[cf_logits[attribute_a].isin(select_keys)]
    # Randomize index so no group is systematically drawn on top of the others
    select_cf_logits = select_cf_logits.sample(frac=1.0).reset_index(drop=True)

    alpha = 0.6
    style = 'o'
    markersize = 40
    kind = 'scatter'

    sns.set_theme(style="white", palette=None)
    fig = sns.jointplot(data=select_cf_logits,
                        x=xdat,
                        y=ydat,
                        hue=attribute_a,
                        hue_order=hue_order,
                        kind=kind,
                        marker=style,
                        s=markersize,
                        alpha=alpha,
                        joint_kws=dict(rasterized=True))
    fig.ax_joint.legend(loc='upper right')

    title = 'logits_'+'_'.join(select_keys)
    _fig_path = os.path.join(_save_fig_dir, 'logits')
    os.makedirs(_fig_path, exist_ok = True)
    fig.savefig(os.path.join(_fig_path, title+'.pdf'), bbox_inches='tight', dpi=300)
    plt.close('all')


### Check embeddings

In [None]:
# Put the DenseNet in evaluation mode before extracting embeddings
# (repeats the call made in the "Check logits" section).
densenet = densenet.eval()

In [None]:
# Batch-count cap for the embedding loops below: they break once their counter
# exceeds _N_sample, i.e. after 1000 // args.bs batches per subgroup.
# Same expression as in the "Check logits" section.
_N_sample = 1000 // args.bs-1

from sklearn import decomposition
import pickle as pk

def pca_for_embeddings(select_cf_info, pca_file=None):
    """Project the 'embeddings' column with PCA and attach the first four modes.

    If ``pca_file`` points to an existing file, the pickled PCA stored there
    is loaded and only applied (``transform``). Otherwise a new PCA keeping
    99% of the variance is fitted on the given embeddings and, when
    ``pca_file`` is given, pickled to that path.

    Args:
        select_cf_info: DataFrame with an 'embeddings' column of vectors.
        pca_file: optional cache path for the fitted PCA. ``None`` fits a
            fresh PCA without reading or writing any file.

    Returns:
        Tuple ``(frame, pca)`` where ``frame`` is a copy of the input with the
        added columns 'pca' (full projection as a list) and
        'pca mode 1' .. 'pca mode 4', and ``pca`` is the PCA object used.

    The projection must have at least four components; fewer raise an
    IndexError when the mode columns are extracted.
    """
    # Work on a copy: the caller usually passes a filtered slice, and adding
    # columns to it directly is what used to require silencing pandas'
    # chained-assignment warning globally.
    select_cf_info = select_cf_info.copy()
    embeddings = select_cf_info['embeddings'].tolist()

    if pca_file is not None and os.path.isfile(pca_file):
        print(f"We load {pca_file}")
        # NOTE(review): unpickling executes arbitrary code — only load cache
        # files produced by this notebook.
        with open(pca_file, 'rb') as handle:
            pca = pk.load(handle)
        pca_values = pca.transform(embeddings)
    else:
        pca = decomposition.PCA(n_components=0.99, whiten=False)
        pca_values = pca.fit_transform(embeddings)
        if pca_file is not None:
            with open(pca_file, "wb") as handle:
                pk.dump(pca, handle)
            print(f"Save {pca_file}")

    select_cf_info['pca'] = pca_values.tolist()
    for mode in range(4):
        select_cf_info[f'pca mode {mode + 1}'] = pca_values[:, mode]
    return select_cf_info, pca

def plot_pca(cf_info, select_keys, hue_order, attribute_a):
    """Plot PCA projections of the embeddings for the selected groups.

    The rows of ``cf_info`` whose ``attribute_a`` label is in ``select_keys``
    are projected with ``pca_for_embeddings`` (cached in
    ``pca_<attribute_a>.pkl`` in the working directory) and six figures are
    written below ``_save_fig_dir``, each as .pdf and .png:

      * joint scatter plots of modes 1-2 ('pca12') and modes 3-4 ('pca34');
      * one KDE plot per mode ('pca1' .. 'pca4').

    File names are ``<subdir>_<keys joined by '_'>``.

    Args:
        cf_info: DataFrame with an 'embeddings' column and a column named
            ``attribute_a`` holding the group label of every row.
        select_keys: group labels to keep; also used in the file names.
        hue_order: order of the groups in legends / colour cycle.
        attribute_a: 'race', 'sex' or 'finding' — the grouping column.

    Returns:
        The PCA object used for the projection.

    Raises:
        ValueError: if ``attribute_a`` is not one of the supported columns.
    """
    # Reject unsupported columns explicitly; previously this fell through the
    # match statement and failed later with an UnboundLocalError.
    if attribute_a not in ("race", "sex", "finding"):
        raise ValueError(
            f"attribute_a must be 'race', 'sex' or 'finding', got {attribute_a!r}")

    # Select subgroups
    select_cf_info = cf_info[cf_info[attribute_a].isin(select_keys)]
    # Get PCA
    select_cf_info_with_pca, pca_trained = pca_for_embeddings(select_cf_info, pca_file=f"pca_{attribute_a}.pkl")
    # Randomize index so no group is systematically drawn on top of the others
    select_cf_info_with_pca = select_cf_info_with_pca.sample(frac=1.0).reset_index(drop=True)

    key_suffix = '_'.join(select_keys)

    def _save(figure, subdir):
        """Write ``figure`` as <subdir>/<subdir>_<keys>.pdf and .png."""
        out_dir = os.path.join(_save_fig_dir, subdir)
        os.makedirs(out_dir, exist_ok = True)
        stem = subdir + '_' + key_suffix
        for ext in ('.pdf', '.png'):
            figure.savefig(os.path.join(out_dir, stem + ext), bbox_inches='tight', dpi=800)

    # --- Joint scatter plots -------------------------------------------------
    alpha = 0.6
    style = 'o'
    markersize = 40
    kind = 'scatter'
    # Only the modes 1-2 plot rasterizes its scatter layer (kept as before).
    joint_specs = (
        ('pca mode 1', 'pca mode 2', 'pca12', dict(joint_kws=dict(rasterized=True))),
        ('pca mode 3', 'pca mode 4', 'pca34', {}),
    )
    for xdat, ydat, subdir, extra_kwargs in joint_specs:
        sns.set_theme(style="white", palette=None)
        grid = sns.jointplot(x=xdat, y=ydat, data=select_cf_info_with_pca,
                             hue=attribute_a,
                             hue_order=hue_order,
                             kind=kind,
                             marker=style,
                             s=markersize,
                             alpha=alpha,
                             **extra_kwargs)
        grid.ax_joint.legend(loc='upper right')
        _save(grid, subdir)

    # --- Per-mode density plots ----------------------------------------------
    fontscale = 1.6
    color_palette = 'tab10'
    for mode in (1, 2, 3, 4):
        sns.set_theme(style="white", palette=color_palette, font_scale=fontscale)
        fig, ax = plt.subplots(figsize=(10,3))
        g = sns.kdeplot(
            x=f'pca mode {mode}',
            hue=attribute_a,
            fill=True,
            hue_order=hue_order,
            data=select_cf_info_with_pca,
            ax=ax,
            common_norm=False,
            )
        g.get_legend().set_title(None)
        g.spines[['right', 'top']].set_visible(False)
        _save(fig, f'pca{mode}')

    plt.close('all')
    return pca_trained

def get_embbedings_for_attribute_a_on_CFs_intervened_on_attribute_b(attribute_a, attribute_b):
    """Collect DenseNet backbone embeddings for originals and counterfactuals.

    For every subgroup of ``attribute_a`` the function first embeds the
    original test images, then embeds counterfactuals produced by the DSCM
    (``model``) under an intervention on ``attribute_b``. At most
    ``_N_sample + 1`` batches are consumed per subgroup and pass.

    Interventions:
      * 'sex' / 'finding': the binary value is flipped (1 - value).
      * 'race': the class index is shifted by 1 or 2 (modulo 3) with equal
        probability, so the new race always differs from the observed one.
      * 'age': ages are drawn at random from the subgroup's own samples and
        mapped to [-1, 1].

    Args:
        attribute_a: 'sex', 'finding' or 'race' — which subgroup loaders to use.
        attribute_b: 'sex', 'finding', 'race' or 'age' — variable intervened on.

    Returns:
        pandas.DataFrame with an 'embeddings' column (one numpy vector per
        image) and a column named ``attribute_a`` holding either the subgroup
        name (original images) or "<subgroup> do(<attribute>)"
        (counterfactuals; 'finding' is spelled 'disease' in that label).

    Raises:
        ValueError: if ``attribute_a`` or ``attribute_b`` is not supported.

    NOTE(review): the race/age interventions draw from torch's global RNG and
    no seed is set in the visible notebook, so results vary between runs.
    """
    # Fail fast instead of hitting an UnboundLocalError after data was sampled.
    if attribute_a not in ("sex", "finding", "race"):
        raise ValueError(
            f"attribute_a must be 'sex', 'finding' or 'race', got {attribute_a!r}")
    if attribute_b not in ("race", "sex", "finding", "age"):
        raise ValueError(
            f"attribute_b must be 'race', 'sex', 'finding' or 'age', got {attribute_b!r}")

    attr_a_category = {
        "sex": sex_categories,
        "finding": finding_categories,
        "race": race_categories,
    }[attribute_a]
    do_label = "disease" if attribute_b == "finding" else attribute_b

    cf_info = {
        'embeddings':[],
        attribute_a:[],
    }

    def _record_embeddings(label, img):
        """Embed ``img`` (in [-1, 1]) with the DenseNet backbone, one row per sample."""
        img = (img + 1) * 127.5
        img = img.repeat(1, 3, 1, 1)
        # Single device-to-host copy for the whole batch.
        embeddings = densenet.backbone.forward(img).cpu()
        for i in range(len(embeddings)):
            cf_info[attribute_a].append(label)
            cf_info['embeddings'].append(embeddings[i].numpy())

    def _build_intervention(batch, pa, age_pool, n_pool, bs):
        """Return the ``do`` dict for one preprocessed batch."""
        if attribute_b in ("sex", "finding"):
            return {attribute_b: 1 - pa[attribute_b]}
        if attribute_b == "race":
            current = torch.argmax(batch[attribute_b], dim=-1)
            shift = torch.bernoulli(1/2*torch.ones_like(current))
            new_race = (current + shift + 1) % 3
            # reshape(-1, 3) keeps the batch axis for single-sample batches,
            # where a bare squeeze() would return a (3,) vector.
            return {attribute_b: F.one_hot(new_race.long(), num_classes=3).reshape(-1, 3).float()}
        # attribute_b == "age"
        drawn = age_pool[torch.randperm(n_pool)][:bs].unsqueeze(1)
        return {attribute_b: (drawn / 100)*2 - 1}

    dag_variables = list(model.pgm.variables.keys())

    with torch.no_grad():
        # Pass 1: embeddings of the original images of every subgroup.
        for attr_a in attr_a_category:
            for count, batch in enumerate(tqdm(subloaders[attr_a])):
                if count > _N_sample:
                    break
                batch = preprocess(batch)
                _record_embeddings(attr_a, batch['x'])

        # Pass 2: embeddings of counterfactuals intervened on attribute_b.
        for attr_a in attr_a_category:
            cf_title = f"{attr_a} do({do_label})"
            n_pool = len(subloaders[attr_a].dataset)
            age_pool = None
            if attribute_b == "age":
                # Only the age column is needed; np.array already yields a
                # private copy, so the dataset's samples are not modified.
                age_pool = torch.from_numpy(
                    np.array(subloaders[attr_a].dataset.samples["age"]))

            for count, batch in enumerate(tqdm(subloaders[attr_a])):
                if count > _N_sample:
                    break
                bs = batch['x'].shape[0]
                batch = preprocess(batch)
                pa = {k: v for k, v in batch.items() if k in dag_variables}

                do = preprocess(_build_intervention(batch, pa, age_pool, n_pool, bs))

                # get counterfactual x
                out_cf = model.forward(batch, do, elbo_fn, cf_particles=args.cf_particles)
                _record_embeddings(cf_title, out_cf['cfs']['x'])

    return pd.DataFrame.from_dict(cf_info)


### Get embeddings for Race/Sex/Finding subgroups for CFs intervened on Race/Finding/Sex/Age

In [None]:
def PCA_for_Attribute_A_on_CFs_intervened_on_Attribute_B(
    attribute_a,
    attribute_b
    ):
    """Run the full PCA plotting pipeline for one (attribute_a, attribute_b) pair.

    Steps:
      1. Collect embeddings of original images and of counterfactuals
         intervened on ``attribute_b`` for every subgroup of ``attribute_a``.
      2. Plot the PCA of the original subgroups only.
      3. Plot the original subgroups together with each counterfactual group
         in turn, and finally with all counterfactual groups at once.

    Args:
        attribute_a: 'race', 'sex' or 'finding' — subgroup attribute.
        attribute_b: 'race', 'sex', 'age' or 'finding' — intervened attribute.

    Raises:
        ValueError: if either argument is unsupported. Checked before any
            data is sampled.
    """
    if attribute_a not in ("race", "sex", "finding"):
        raise ValueError(
            f"attribute_a must be 'race', 'sex' or 'finding', got {attribute_a!r}")
    if attribute_b not in ("race", "sex", "age", "finding"):
        raise ValueError(
            f"attribute_b must be 'race', 'sex', 'age' or 'finding', got {attribute_b!r}")

    # Get embeddings
    cf_info = get_embbedings_for_attribute_a_on_CFs_intervened_on_attribute_b(
        attribute_a = attribute_a,
        attribute_b = attribute_b,
        )

    # Group orders per subgroup attribute. The orders differ on purpose between
    # the columns: they determine legend order and output file names.
    #   (original-only keys, base keys for CF plots,
    #    order of the single-CF plots, order inside the all-CF plot)
    original_keys, base_keys, single_cf_order, joint_cf_order = {
        "race": (['White', "Black", "Asian"],
                 ['White', 'Black', 'Asian'],
                 ['White', 'Asian', 'Black'],
                 ['White', 'Black', 'Asian']),
        "sex": (['Male', "Female"],
                ['Male', "Female"],
                ['Male', "Female"],
                ['Male', "Female"]),
        "finding": (['No disease', 'Pleural Effusion'],
                    ['Pleural Effusion', "No disease"],
                    ['Pleural Effusion', "No disease"],
                    ['Pleural Effusion', "No disease"]),
    }[attribute_a]

    # PCA for original groups
    _ = plot_pca(
        cf_info=cf_info,
        select_keys=original_keys,
        hue_order=original_keys,
        attribute_a=attribute_a,
        )

    # PCA for attribute a on CFs intervened on attribute b.
    # Counterfactual rows are labelled "<group> do(<name>)", with 'finding'
    # spelled 'disease'.
    _ATTRI_B = "disease" if attribute_b == "finding" else attribute_b
    select_keys_groups = [
        base_keys + [f'{group} do({_ATTRI_B})'] for group in single_cf_order
    ]
    select_keys_groups.append(
        base_keys + [f'{group} do({_ATTRI_B})' for group in joint_cf_order]
    )

    for _select_keys in select_keys_groups:
        _ = plot_pca(
            cf_info=cf_info,
            select_keys=_select_keys,
            hue_order=_select_keys,
            attribute_a=attribute_a,
        )

### PCAs on different attributes A on CFs of different attribute B

In [None]:

"""Race subrgoups"""
# Every PCA experiment as an (attribute_a, attribute_b) pair: subgroups of
# attribute_a, counterfactuals intervened on attribute_b. The order is kept
# because the first run for each attribute_a fits and caches its PCA.
_pca_experiments = [
    # Race subgroups
    ('race', 'sex'),
    ('race', 'finding'),
    # Sex subgroups
    ('sex', 'race'),
    ('sex', 'finding'),
    # Finding subgroups
    ('finding', 'race'),
    ('finding', 'sex'),
    # Subgroups intervened on their own attribute
    ('finding', 'finding'),
    ('sex', 'sex'),
]

for _attr_a, _attr_b in _pca_experiments:
    PCA_for_Attribute_A_on_CFs_intervened_on_Attribute_B(
        attribute_a=_attr_a,
        attribute_b=_attr_b
    )

