In [None]:
import sys
import torchvision
import torch
import os
import numpy as np
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, accuracy_score
import torch.nn.functional as F
sys.path.append('..')
from train_setup import setup_directories, setup_tensorboard, setup_logging
# From datasets import get_attr_max_min
from vae import HVAE
from train_pgm import preprocess, sup_epoch, eval_epoch
from utils_pgm import plot_cf, check_nan, update_stats
from layers import TraceStorage_ELBO
from flow_pgm import FlowPGM
from train_cf import DSCM
from PIL import Image
import pandas as pd
from skimage.io import imread, imsave


In [None]:
def norm(batch):
    """Normalise a raw data batch in place so it matches the models' input scales.

    Per key:
      - ``x`` / ``cf_x``: pixels (assumed in [0, 255]) -> float in [-1, 1].
      - ``age``: value presumably in [0, 100] -> float in [-1, 1], with a trailing unit dim.
      - ``race``: integer label -> one-hot float over 3 classes.
      - ``finding``: label -> float with a trailing unit dim.
      - anything else: tensors are cast to float with a trailing unit dim;
        non-tensor values (e.g. lists of id strings) are left untouched.

    Args:
        batch: dict mapping keys to tensors (or other collated values).

    Returns:
        The same dict, mutated in place.
    """
    for k, v in batch.items():
        if k in ('x', 'cf_x'):
            batch[k] = (v.float() - 127.5) / 127.5  # [-1,1]
        elif k == 'age':
            batch[k] = v.float().unsqueeze(-1) / 100. * 2 - 1  # [0,100] -> [-1,1]
        elif k == 'race':
            # NOTE(review): .squeeze() also drops the batch dim when the batch holds a
            # single sample, giving shape (3,) instead of (1, 3) — confirm callers expect this.
            batch[k] = F.one_hot(v, num_classes=3).squeeze().float()
        elif k == 'finding':
            batch[k] = v.unsqueeze(-1).float()
        else:
            try:
                batch[k] = v.float().unsqueeze(-1)
            except (AttributeError, TypeError):
                # Non-tensor values (e.g. lists of id strings) have no .float();
                # keep them unchanged. Other errors are real failures and propagate.
                pass
    return batch

def loginfo(title, logger, stats):
    """Emit ``stats`` as a single ``logger.info`` line.

    The message looks like ``'<title> | k1: v1 - k2: v2'``, each value printed
    with four decimal places.
    """
    parts = [f'{name}: {value:.4f}' for name, value in stats.items()]
    message = f'{title} | ' + ' - '.join(parts)
    logger.info(message)

def inv_preprocess(pa):
    """Undo the [-1, 1] parent scaling, in place, for the entries that need it.

    Only ``age`` is rescaled (to [0, 100]); every other entry is left as is.
    Returns the same, mutated, dict.
    """
    if 'age' in pa:
        pa['age'] = (pa['age'] + 1) / 2 * 100
    return pa


def vae_preprocess(args, pa):
    """Turn parent attributes into spatial conditioning maps for the VAE.

    Concatenates ``pa[name]`` for every ``name`` in ``args.parents_x`` along
    dim 1, then tiles each channel over an ``input_res x input_res`` grid.
    Assuming each ``pa[name]`` is (B, c), the result is a float tensor of shape
    (B, sum(c), args.input_res, args.input_res).
    """
    res = args.input_res
    stacked = torch.cat([pa[name] for name in args.parents_x], dim=1)
    maps = stacked.unsqueeze(-1).unsqueeze(-1)
    return maps.repeat(1, 1, res, res).float()


def get_metrics(preds, targets):
    """Compute evaluation metrics for predicted parent attributes.

    Side effect (kept for compatibility): every list in ``preds``, and the
    matching entry of ``targets``, is replaced in place by a squeezed CPU tensor
    built with ``torch.stack``.

    Only ``age`` currently produces a metric, ``age_mae``. Predictions are mapped
    from [-1, 1] to [0, 100] before taking the mean absolute error against
    ``targets['age']``. The targets are therefore assumed to already be on the
    [0, 100] scale (TODO confirm against the caller).
    """
    for key in list(preds):
        preds[key] = torch.stack(preds[key]).squeeze().cpu()
        targets[key] = torch.stack(targets[key]).squeeze().cpu()

    stats = {}
    if 'age' in preds:
        age_pred = (preds['age'] + 1) / 2 * 100  # [-1,1] -> [0,100]
        stats['age_mae'] = (targets['age'] - age_pred).abs().mean().item()
    return stats

class Hparams:
    """Plain attribute container for hyper-parameters (e.g. a checkpoint's ``'hparams'``)."""

    def update(self, dict):
        """Set one attribute per ``(name, value)`` pair of the mapping ``dict``.

        Existing attributes with the same name are overwritten. The parameter
        name shadows the builtin ``dict``; it is kept for backward compatibility.
        """
        for name, value in dict.items():
            setattr(self, name, value)

In [None]:
# Integer-label -> display-name mappings (list index == label value).
# NOTE(review): the counterfactual-saving cell redefines these three names with
# different spellings ('Male', 'No_disease', ...), overwriting these values.
sex_categories = ['male', 'female']  # 0,1
race_categories = ['White', 'Asian', 'Black']  # 0,1,2
finding_categories = ['No disease', 'Pleural Effusion']

### Load predictor

In [None]:
# Load the attribute predictor (a FlowPGM) from its checkpoint.
# Its hparams are restored from the checkpoint first; the dataset-related fields
# set afterwards override them, so the statement order matters.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
args = Hparams()
# Placeholder: replace with the path to the predictor checkpoint.
args.predictor_path = 'PREDICTOR_PATH'
# args.predictor_path = '../../checkpoints/a_r_s_f/mimic_classifier_resnet34_l2_lr4_slurm/checkpoint.pt'
predictor_checkpoint = torch.load(args.predictor_path)
args.update(predictor_checkpoint['hparams'])
# Overrides applied on top of the restored hparams.
args.use_dataset = 'mimic'
args.csv_dir =  "../mimic_meta"
# Placeholder: replace with the image data directory.
args.data_dir = "DATA_DIR"
args.loss_norm = "l2"
predictor = FlowPGM(args)
# Restore the EMA weights stored in the checkpoint.
predictor.load_state_dict(predictor_checkpoint['ema_model_state_dict'])

### Load PGM

In [None]:
# Load the PGM (a FlowPGM), configured solely from the hparams stored in its own
# checkpoint (`pgm_args`), and restore its EMA weights.
# The checkpoint path is only recorded on the shared `args` object.
args.pgm_path = '../../checkpoints/a_r_s_f/sup_pgm_mimic/checkpoint.pt'
print(f'\nLoading PGM checkpoint: {args.pgm_path}')
pgm_checkpoint = torch.load(args.pgm_path)
pgm_args = Hparams()
pgm_args.update(pgm_checkpoint['hparams'])
pgm = FlowPGM(pgm_args)
pgm.load_state_dict(pgm_checkpoint['ema_model_state_dict'])

### Load VAE

In [None]:
# Load the HVAE, configured from the hparams stored in its checkpoint (`vae_args`),
# and restore its EMA weights.
args.vae_path = '../../checkpoints/a_r_s_f/mimic_beta9_gelu_dgauss_1_lr3/checkpoint.pt'

print(f'\nLoading VAE checkpoint: {args.vae_path}')
vae_checkpoint = torch.load(args.vae_path)
vae_args = Hparams()
vae_args.update(vae_checkpoint['hparams'])
vae = HVAE(vae_args)
vae.load_state_dict(vae_checkpoint['ema_model_state_dict'])

### Load DSCM

In [None]:
# Build the DSCM from the already-loaded PGM, predictor and VAE, restore its EMA
# weights from the chosen checkpoint, and prepare it for inference only.
args = Hparams()
dscm_dir = "DSCM_DIR"  # placeholder: sub-directory of the DSCM run
which_checkpoint="WHICH_CHECKPOINT_TO_USE"  # placeholder: checkpoint file name without '.pt'

args.load_path = f'../../checkpoints/a_r_s_f/{dscm_dir}/{which_checkpoint}.pt'
dscm_checkpoint = torch.load(args.load_path)
args.update(dscm_checkpoint['hparams'])
model = DSCM(args, pgm, predictor, vae)
# Overrides any cf_particles value restored from the checkpoint hparams.
args.cf_particles =1
model.load_state_dict(dscm_checkpoint['ema_model_state_dict'])
model.cuda()
# Inference only: eval() recurses into all registered submodules so that any
# dropout / batch-norm layers behave deterministically instead of in training mode.
model.eval()
elbo_fn = TraceStorage_ELBO(num_particles=1)

# Freeze all parameters; counterfactual generation needs no gradients.
for p in model.parameters():
    p.requires_grad = False

### Load subgroup data

In [None]:
# Load subgroup
# Build the train/valid/test dataloaders from the DSCM hparams plus the overrides below.
# bs must stay 1: the counterfactual loop reads per-sample values with .item() and [0].
args.bs = 1
# Model input resolution; saved counterfactuals are resized from this to 224x224.
args.input_res =192
from train_pgm import setup_dataloaders

dataloaders = setup_dataloaders(args)

print(len(dataloaders['train'].dataset), len(dataloaders['valid'].dataset), len(dataloaders['test'].dataset))

### Generate and save counterfactuals

In [None]:
def save_cf(save_path_cf, cf_x):
    """Write one counterfactual image tensor to disk as an 8-bit image.

    Args:
        save_path_cf: destination file path.
        cf_x: image tensor with values nominally in [-1, 1]; assumed to have two
            leading singleton dims, e.g. (1, 1, H, W) — TODO confirm (squeeze(0)
            is a no-op on non-singleton dims).
    """
    img = (cf_x.squeeze(0).squeeze(0).detach().cpu().numpy() + 1) * 127.5
    # Round to nearest and clip before casting: a bare astype(np.uint8) truncates
    # (99.9999 -> 99) and wraps out-of-range values (256 -> 0, -1 -> 255).
    img = np.clip(np.rint(img), 0, 255).astype(np.uint8)
    imsave(save_path_cf, img)

In [None]:
mode = "test"

# Category names used for output folders, file names and CSV columns.
# The list index is the integer label used by the data and the interventions.
race_categories = ['White', 'Asian', 'Black']
sex_categories = ['Male', 'Female']
finding_categories = ['No_disease', 'Pleural_Effusion']

# Resize counterfactuals from the model resolution to 224x224 before saving.
transf = torchvision.transforms.Compose([
        torchvision.transforms.Resize((224, 224)),
    ])

save_dir = f"WHERE_TO_SAVE/CF_DATA/{dscm_dir}/{which_checkpoint}/{mode}"
os.makedirs(save_dir, exist_ok=True)

# One output folder (and CSV column) per counterfactual label. 'Null' holds the
# counterfactual obtained by intervening with the factual race.
cf_labels = race_categories + sex_categories + ['Null'] + finding_categories
for _label in cf_labels:
    os.makedirs(os.path.join(save_dir, f"cf_{_label}"), exist_ok=True)

save_dict = {k: [] for k in ['sex', 'race', 'finding', 'age', 'dicom_id',
                             'study_id', 'path_preproc']
             + [f'path_cf_{label}' for label in cf_labels]}


def _make_do(batch, do_k, idx):
    """Build the raw intervention value setting `do_k` to category `idx` for every sample.

    Race is one-hot encoded over `race_categories`; sex and finding use the integer
    label. Both are repeated to `len(batch[do_k])` rows.
    """
    n = len(batch[do_k])
    if do_k == 'race':
        return F.one_hot(torch.tensor(idx), num_classes=len(race_categories)).repeat(n, 1)
    return torch.tensor(idx).repeat(n, 1)


def _generate_and_save_cf(batch, do_k, idx, label, study_id, dicom_id):
    """Generate the counterfactual under do(`do_k` = category `idx`), save it to
    ``save_dir/cf_<label>/`` and return the written file path."""
    do = preprocess({do_k: _make_do(batch, do_k, idx)})  # move do to gpu
    out_cf = model.forward(batch, do, elbo_fn, cf_particles=args.cf_particles)
    cf_x = transf(out_cf['cfs']['x'])  # (1,1,192,192) -> (1,1,224,224)
    save_path_cf = os.path.join(save_dir, f"cf_{label}", f"s{study_id}_{dicom_id}_cf_{label}.jpg")
    save_cf(save_path_cf=save_path_cf, cf_x=cf_x)
    return save_path_cf


def _save_attribute_cfs(batch, do_k, categories, factual_idx, study_id, dicom_id):
    """Save one counterfactual per non-factual category of `do_k`.

    The factual category gets the placeholder path 'None' so every CSV column
    stays aligned with the per-sample rows.
    """
    for idx, label in enumerate(categories):
        if idx == factual_idx:
            path = 'None'
        else:
            path = _generate_and_save_cf(batch, do_k, idx, label, study_id, dicom_id)
        save_dict[f'path_cf_{label}'].append(path)


with torch.no_grad():
    for batch in tqdm(dataloaders[mode]):
        # Factual attributes, read before `preprocess` transforms the batch (bs == 1).
        _sex = int(batch['sex'].item())
        _finding = int(batch['finding'].item())
        # Age presumably arrives scaled to [-1, 1]; round instead of truncating,
        # since float32 error otherwise maps e.g. 35 -> 34.9999994 -> 34.
        _age = round((batch['age'].item() + 1) * 50)
        _race = int(batch['race'].argmax())
        _dicom_id = batch['dicom_id'][0]
        _study_id = batch['study_id'][0]

        save_dict['sex'].append(_sex)
        save_dict['race'].append(_race)
        save_dict['finding'].append(_finding)
        save_dict['age'].append(_age)
        save_dict['path_preproc'].append(batch['path_preproc'][0])
        save_dict['dicom_id'].append(_dicom_id)
        save_dict['study_id'].append(_study_id)

        batch = preprocess(batch)

        # Same generation order as before: race cfs, the null cf (do(race) = factual
        # race), then sex and finding cfs.
        _save_attribute_cfs(batch, 'race', race_categories, _race, _study_id, _dicom_id)
        save_dict['path_cf_Null'].append(
            _generate_and_save_cf(batch, 'race', _race, 'Null', _study_id, _dicom_id))
        _save_attribute_cfs(batch, 'sex', sex_categories, _sex, _study_id, _dicom_id)
        _save_attribute_cfs(batch, 'finding', finding_categories, _finding, _study_id, _dicom_id)

csv_file = os.path.join(save_dir, f'{mode}_cfs.csv')
df = pd.DataFrame.from_dict(save_dict)
df.to_csv(csv_file)
                