In [None]:
import sys
import argparse
import random
import copy
import pyro
import torch
import os
os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES']='1'
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, accuracy_score
import torch.nn.functional as F
sys.path.append('..')
from train_setup import setup_directories, setup_tensorboard, setup_logging
from train_pgm import setup_dataloaders
# From datasets import get_attr_max_min
from utils import EMA, seed_all
from vae import HVAE
from train_pgm import preprocess, sup_epoch, eval_epoch
from utils_pgm import plot_cf, check_nan, update_stats
from layers import TraceStorage_ELBO

# from flow_pgm import FlowPGM_full as FlowPGM
from flow_pgm import FlowPGM

from train_cf import DSCM
# from flow_pgm import FlowPGM_without_finding as FlowPGM

def norm(batch, race_num_classes=3):
    """Normalise a batch dict of raw images/attributes in place.

    Per key:
      - 'x': cast to float and mapped from [0, 255] to [-1, 1].
      - 'age': cast to float, given a trailing unit dim, mapped from [0, 100] to [-1, 1].
      - 'race': integer labels one-hot encoded to shape (batch, race_num_classes).
      - 'finding': given a trailing unit dim, then cast to float.
      - any other key: cast to float with a trailing unit dim.

    Args:
        batch: dict mapping attribute name -> tensor. Modified in place.
        race_num_classes: number of race categories for the one-hot encoding.

    Returns:
        The same dict, with every value replaced by its normalised tensor.
    """
    for k in list(batch):
        v = batch[k]
        if k == 'x':
            batch[k] = (v.float() - 127.5) / 127.5  # [0,255] -> [-1,1]
        elif k == 'age':
            batch[k] = (v.float().unsqueeze(-1) / 100.) * 2 - 1  # [0,100] -> [-1,1]
        elif k == 'race':
            # Flatten (bs,) or (bs, 1) labels to (bs,) so one_hot yields (bs, C).
            # A bare `.squeeze()` would also drop the batch dim when bs == 1.
            labels = v.reshape(-1).long()  # one_hot requires int64 input
            batch[k] = F.one_hot(labels, num_classes=race_num_classes).float()
        elif k == 'finding':
            batch[k] = v.unsqueeze(-1).float()
        else:
            batch[k] = v.float().unsqueeze(-1)
    return batch

def loginfo(title, logger, stats):
    """Log `stats` on a single line as '<title> | k1: v1 - k2: v2' (4 decimals)."""
    parts = [f'{name}: {value:.4f}' for name, value in stats.items()]
    logger.info(f'{title} | ' + ' - '.join(parts))

def inv_preprocess(pa):
    """Map a normalised 'age' entry of `pa` from [-1, 1] back to [0, 100], in place.

    All other keys are left untouched. Returns the same dict.
    """
    if 'age' in pa:
        pa['age'] = (pa['age'] + 1) / 2 * 100
    return pa


def vae_preprocess(args, pa, device='cuda'):
    """Turn a dict of parent tensors into a spatial conditioning map for the HVAE.

    Concatenates `pa[k]` for every k in `args.parents_x` along dim 1, then tiles each
    channel over an `args.input_res` x `args.input_res` grid.

    Args:
        args: object exposing `parents_x` (ordered parent names) and `input_res`.
        pa: dict of parent tensors; each selected value must be 2-D (batch, dim_k)
            so that the 4-way `repeat` below is valid.
        device: target device. Defaults to 'cuda', the previously hard-coded device.

    Returns:
        Float tensor of shape (batch, sum_k dim_k, input_res, input_res) on `device`.
    """
    stacked = torch.cat([pa[k] for k in args.parents_x], dim=1)
    res = args.input_res
    return stacked[..., None, None].repeat(1, 1, res, res).to(device).float()


def get_metrics(preds, targets):
    """Stack per-batch predictions/targets and compute the age MAE in years.

    Mutates both dicts in place: each list in `preds`, and the matching entry in
    `targets`, is replaced by a stacked, squeezed CPU tensor.

    Returns:
        dict containing 'age_mae' when 'age' is among the predicted keys. Age
        predictions are mapped from [-1, 1] to [0, 100] before comparison; targets
        are used as-is.
    """
    for name in list(preds):
        preds[name] = torch.stack(preds[name]).squeeze().cpu()
        targets[name] = torch.stack(targets[name]).squeeze().cpu()

    stats = {}
    if 'age' in preds:
        age_pred_years = (preds['age'] + 1) / 2 * 100  # [-1,1] -> [0,100]
        stats['age_mae'] = torch.mean(torch.abs(targets['age'] - age_pred_years)).item()
    return stats

class Hparams:
    """Plain attribute container for hyper-parameters restored from checkpoints."""

    def update(self, dict):
        """Set every key of `dict` as an attribute holding the corresponding value."""
        for key in dict:
            setattr(self, key, dict[key])

### Load the predictors

In [None]:
# Placeholder checkpoint paths for the two FlowPGM predictors loaded in the next
# cell; replace them with real checkpoint paths before running.
PREDICTOR_PATH= 'PREDICTOR_PATH'
PREDICTOR_FOR_EVALUATION_PATH = 'PREDICTOR_FOR_EVALUATION_PATH'


In [None]:
# Load the two FlowPGM attribute predictors. Each model is restored from the path
# constant named after it:
#   * `predictor_for_evaluation` <- PREDICTOR_FOR_EVALUATION_PATH (EMA weights);
#     used to score images/counterfactuals.
#   * `predictor` <- PREDICTOR_PATH (raw weights); handed to the DSCM further down.
#     Its hparams (plus the dataset overrides below) also configure the dataloaders.
# `Hparams` is the class defined in the first cell; it is not redefined here.

# --- Evaluation predictor ---
eval_args = Hparams()
eval_args.predictor_path = PREDICTOR_FOR_EVALUATION_PATH
eval_checkpoint = torch.load(eval_args.predictor_path)
eval_args.update(eval_checkpoint['hparams'])
predictor_for_evaluation = FlowPGM(eval_args).cuda()
predictor_for_evaluation.load_state_dict(eval_checkpoint['ema_model_state_dict'])
del eval_checkpoint  # weights are now in the model; free the checkpoint copy

# --- Predictor used by the DSCM ---
args = Hparams()
args.predictor_path = PREDICTOR_PATH
predictor_checkpoint = torch.load(args.predictor_path)
args.update(predictor_checkpoint['hparams'])

# Dataset overrides applied on top of the checkpoint hparams.
args.use_dataset = 'mimic'
args.csv_dir = "../mimic_meta"
args.data_dir = "DATA_DIR"
args.loss_norm = "l2"
predictor = FlowPGM(args).cuda()
predictor.load_state_dict(predictor_checkpoint['model_state_dict'])

# Batch size / image resolution used by the evaluation dataloaders.
args.bs = 20
args.input_res = 224
print(args.use_dataset)

dataloaders = setup_dataloaders(args)
elbo_fn = TraceStorage_ELBO(num_particles=1)

# Reference scores of the evaluation predictor on the validation split; later
# counterfactual metrics are reported relative to these.
test_stats = eval_epoch(
    predictor_for_evaluation,
    dataloaders['valid']
)

for k, v in test_stats.items():
    print(f"{k}: {v:.3f} ")


### Load the PGM

In [None]:
# Load PGM: rebuild a FlowPGM from the hparams stored in its checkpoint and
# restore the checkpoint's EMA weights.
# NOTE(review): placeholder path; replace with the real PGM checkpoint.
args.pgm_path = 'PATH OF SAVED PGM'
print(f'\nLoading PGM checkpoint: {args.pgm_path}')
pgm_checkpoint = torch.load(args.pgm_path)
# The architecture comes from the checkpoint's own hparams, not from `args`.
pgm_args = Hparams()
pgm_args.update(pgm_checkpoint['hparams'])
pgm = FlowPGM(pgm_args).cuda()
pgm.load_state_dict(pgm_checkpoint['ema_model_state_dict'])


### Load the VAE

In [None]:
# Load the HVAE: rebuild it from the hparams stored in its checkpoint and
# restore the checkpoint's EMA weights.
# NOTE(review): placeholder path; replace with the real VAE checkpoint.
args.vae_path = 'VAE_PATH'

print(f'\nLoading VAE checkpoint: {args.vae_path}')
vae_checkpoint = torch.load(args.vae_path)
# The architecture comes from the checkpoint's own hparams, not from `args`.
vae_args = Hparams()
vae_args.update(vae_checkpoint['hparams'])
vae = HVAE(vae_args).cuda()
vae.load_state_dict(vae_checkpoint['ema_model_state_dict'])

### Load DSCM

In [None]:
# Placeholders identifying the DSCM run to evaluate; replace before running.
#   DSCM_DIR:         name of the run directory holding the DSCM checkpoints.
#   WHICH_CHECKPOINT: checkpoint file stem (without ".pt") inside that directory.
DSCM_DIR = "the path where DSCM is saved"
WHICH_CHECKPOINT = "which checkpoint to choose"
# Backward-compatible alias for the original (misspelled) constant name.
WHICH_CHECKPOIBT = WHICH_CHECKPOINT

In [None]:
# Load the trained DSCM, wiring in the PGM, predictor and VAE loaded above.
# `Hparams` is the class defined in the first cell; it is not redefined here.
args = Hparams()

# Take the run directory / checkpoint name from the configuration cell. Assigning
# string literals here would silently ignore whatever the user configured there.
dscm_dir = DSCM_DIR
which_checkpoint = WHICH_CHECKPOIBT

args.load_path = f'../../checkpoints/a_r_s_f/{dscm_dir}/{which_checkpoint}.pt'
print(args.load_path)
dscm_checkpoint = torch.load(args.load_path)
args.update(dscm_checkpoint['hparams'])
model = DSCM(args, pgm, predictor, vae)
# NOTE(review): set after DSCM construction; this only takes effect if DSCM reads
# args.cf_particles lazily. Confirm.
args.cf_particles = 1
model.load_state_dict(dscm_checkpoint['ema_model_state_dict'])
model.cuda()

# Inference only: freeze every parameter.
model.requires_grad_(False)


### Evaluate counterfactuals

In [None]:
import torchvision

# Resize transforms: eval_counterfactuals resizes inputs to 192x192 before the DSCM
# and resizes counterfactual images to 224x224 before the evaluation predictor.
transf_224 = torchvision.transforms.Resize((224, 224))
transf_192 = torchvision.transforms.Resize((192, 192))

# Put every sub-network of the DSCM in inference mode.
model.pgm.eval()
model.predictor.eval()
model.vae.eval()
dag_variables = list(model.pgm.variables.keys())

# Results directory for this DSCM run/checkpoint.
args.save_dir = f"../../results/{dscm_dir}/{which_checkpoint}"
os.makedirs(args.save_dir, exist_ok=True)

def _cf_prediction_stats(dag_variables, preds, targets):
    """Compute per-variable metrics from stacked predictions and targets.

    Args:
        dag_variables: variable names to score (other keys are ignored).
        preds: dict name -> stacked CPU prediction tensor.
        targets: dict name -> stacked CPU target tensor.

    Returns:
        dict with, per variable in ``dag_variables``:
          * 'sex' / 'finding': ``<k>_acc`` (rounded predictions vs targets) and
            ``<k>_rocauc``;
          * 'age': ``age`` = mean absolute error x 50, which converts to years
            assuming ages are on the [-1, 1] scale produced by ``norm``;
          * 'race': ``race_acc`` (argmax agreement) and ``race_rocauc``
            (one-vs-rest, macro-averaged).
    """
    stats = {}
    for k in dag_variables:
        if k in ('sex', 'finding'):
            stats[k + '_acc'] = (targets[k].squeeze(-1) == torch.round(preds[k])).sum().item() / targets[k].shape[0]
            stats[k + '_rocauc'] = roc_auc_score(
                targets[k].numpy(), preds[k].numpy(), average='macro')
        elif k == 'age':
            stats[k] = torch.mean(torch.abs(targets[k] - preds[k])).item() * 50
        elif k == 'race':
            num_correct = (targets[k].argmax(-1) == preds[k].argmax(-1)).sum()
            stats[k + '_acc'] = num_correct.item() / targets[k].shape[0]
            stats[k + '_rocauc'] = roc_auc_score(
                targets[k].numpy(),
                preds[k].numpy(),
                multi_class='ovr',
                average='macro',
            )
    return stats


@torch.no_grad()
def eval_counterfactuals(model, dataloader, predictor, do_pa=None):
    """Score DSCM counterfactuals with an external predictor.

    For every batch a single parent is intervened on: ``do_pa`` if given, otherwise
    a randomly chosen DAG variable. Its intervened values are a random subset of
    ``train_samples``. The DSCM produces counterfactual images/parents, the images
    are resized to 224x224, and ``predictor`` predicts every parent back from them.
    Targets are the intervened values for the do-variable and the counterfactual
    parents for all other variables. Batches whose counterfactuals contain NaNs are
    skipped.

    Relies on notebook globals: ``train_samples``, ``n_train``, ``elbo_fn``,
    ``transf_192``, ``transf_224``, ``preprocess`` and ``norm``.
    All predictions/targets are kept in memory, which can be large for big datasets.

    Args:
        model: DSCM exposing ``pgm``, ``predictor``, ``vae`` and ``forward``.
        dataloader: iterable of batch dicts with an ``'x'`` tensor plus parent keys.
        predictor: model with a ``predict(**cfs)`` method used for scoring.
        do_pa: parent to intervene on; ``None`` picks one at random per batch.

    Returns:
        (stats, preds, targets): metric dict plus the stacked CPU prediction and
        target tensors per variable.

    Raises:
        RuntimeError: if no batch produced valid counterfactuals.
    """
    model.pgm.eval()
    model.predictor.eval()
    predictor.eval()
    model.vae.eval()
    dag_variables = list(model.pgm.variables.keys())
    preds = {k: [] for k in dag_variables}
    targets = {k: [] for k in dag_variables}
    cf_particles = 1

    for batch in tqdm(dataloader):
        bs = batch['x'].shape[0]
        batch = preprocess(batch)
        batch['x'] = transf_192(batch['x'])

        # Intervene on a single parent, pa_k ~ p(pa_k), drawn from the training set.
        do_k = do_pa if do_pa else random.choice(dag_variables)
        # Same random rows as permuting the full column, but only `bs` rows are
        # copied (advanced indexing already returns a new tensor).
        idx = torch.randperm(n_train)[:bs]
        do = preprocess(norm({do_k: train_samples[do_k][idx]}))

        # Counterfactual parents and image from the DSCM.
        out = model.forward(batch, do, elbo_fn, cf_particles=cf_particles)
        cf_pa = out['cf_pa']

        nans = 0
        for k, v in out['cfs'].items():
            k_nans = torch.isnan(v).sum()
            nans += k_nans
            if k_nans > 0:
                print(f'\nFound {k_nans} nans in cf {k}.')
        if nans > 0:
            continue

        out['cfs']['x'] = transf_224(out['cfs']['x'])
        predict_out = predictor.predict(**out['cfs'])
        for k, v in predict_out.items():
            preds[k].extend(v)

        # Interventions are the targets for the do-variable; counterfactual parents
        # are the targets for every other variable.
        for k in targets:
            targets[k].extend(do[k] if k in do else cf_pa[k])

    if not any(preds.values()):
        raise RuntimeError(
            'No valid counterfactual batches (empty dataloader or all skipped due to NaNs).')

    for k, v in preds.items():
        preds[k] = torch.stack(v).squeeze().cpu()
        targets[k] = torch.stack(targets[k]).squeeze().cpu()

    return _cf_prediction_stats(dag_variables, preds, targets), preds, targets

@torch.no_grad()
def eval_random(model, dataloader, do_pa=None):
    """Score counterfactuals built from the PGM and the VAE decoder.

    For every batch a single parent is intervened on: ``do_pa`` if given, otherwise
    a random DAG variable. Its values are drawn from ``train_samples``. The PGM
    produces counterfactual parents; the VAE's sample location conditioned on them,
    clamped to [-1, 1], is the counterfactual image. ``model.predictor`` then
    predicts every parent back. Targets are the intervened values for the
    do-variable and the counterfactual parents for the rest. Batches with NaNs in
    the counterfactuals are skipped.

    Relies on notebook globals: ``train_samples``, ``n_train``, ``args``,
    ``preprocess``, ``norm`` and ``vae_preprocess``.
    All predictions/targets are kept in memory, which can be large for big datasets.

    Args:
        model: DSCM exposing ``pgm``, ``predictor`` and ``vae``.
        dataloader: iterable of batch dicts with an ``'x'`` tensor plus parent keys.
        do_pa: parent to intervene on; ``None`` picks one at random per batch.

    Returns:
        dict of metrics per DAG variable.

    Raises:
        RuntimeError: if no batch produced valid counterfactuals.
    """
    model.pgm.eval()
    model.predictor.eval()
    model.vae.eval()
    dag_variables = list(model.pgm.variables.keys())
    preds = {k: [] for k in dag_variables}
    targets = {k: [] for k in dag_variables}

    for batch in tqdm(dataloader):
        bs = batch['x'].shape[0]
        # Intervene on a single parent, pa_k ~ p(pa_k), drawn from the training set.
        do_k = do_pa if do_pa else random.choice(dag_variables)
        idx = torch.randperm(n_train)[:bs]
        do = {do_k: train_samples[do_k][idx]}

        # Counterfactual parents.
        batch = preprocess(batch)
        do = preprocess(norm(do))
        pa = {k: v for k, v in batch.items() if k != 'x'}
        cf_pa = model.pgm.counterfactual(obs=pa, intervention=do, num_particles=1)

        # Counterfactual image: VAE sample location conditioned on the cf parents.
        _cf_pa = vae_preprocess(args, {k: v.clone() for k, v in cf_pa.items()})
        cf_loc, _ = model.vae.sample(parents=_cf_pa, return_loc=True)
        cfs = {'x': cf_loc.clamp(min=-1, max=1)}
        cfs.update(cf_pa)

        nans = 0
        for k, v in cfs.items():
            k_nans = torch.isnan(v).sum()
            nans += k_nans
            if k_nans > 0:
                print(f'\nFound {k_nans} nans in cf {k}.')
        if nans > 0:
            continue

        out = model.predictor.predict(**cfs)
        for k, v in out.items():
            preds[k].extend(v)

        for k in targets:
            targets[k].extend(do[k] if k in do else cf_pa[k])

    if not any(preds.values()):
        raise RuntimeError(
            'No valid counterfactual batches (empty dataloader or all skipped due to NaNs).')

    for k, v in preds.items():
        preds[k] = torch.stack(v).squeeze().cpu()
        targets[k] = torch.stack(targets[k]).squeeze().cpu()

    stats = {}
    for k in dag_variables:
        if k in ('sex', 'finding'):
            stats[k + '_rocauc'] = roc_auc_score(
                targets[k].numpy(), preds[k].numpy(), average='macro')
            stats[k + '_acc'] = (targets[k].squeeze(-1) == torch.round(preds[k])).sum().item() / targets[k].shape[0]
        elif k == 'age':
            # MAE on the [-1, 1] scale x 50 -> years (norm maps ages [0,100] -> [-1,1]).
            stats[k] = torch.mean(torch.abs(targets[k] - preds[k])).item() * 50
        elif k == 'race':
            # Compare class indices directly. one_hot(argmax) without num_classes
            # could yield fewer columns than the one-hot targets, and casting soft
            # targets to int32 truncated them to zeros.
            correct = targets[k].argmax(-1) == preds[k].argmax(-1)
            stats[k + '_acc'] = correct.float().mean().item()
    return stats



# Snapshot of the training-set attributes: the evaluation functions draw
# intervention values from it by random row indexing.
train_samples = copy.deepcopy(dataloaders['train'].dataset.samples)
for k in train_samples.keys():
    if k == "x":
        continue
    try:
        train_samples[k] = torch.from_numpy(np.array(train_samples[k]))
    except (TypeError, ValueError):
        # Columns that cannot become a numeric tensor (e.g. strings/objects or
        # ragged lists) are kept as-is; other errors are no longer swallowed.
        pass
print(f"train_samples: {train_samples.keys()}")
n_train = len(dataloaders['train'].dataset)

for k, v in train_samples.items():
    print(f"train samples {k}: {len(v)}")


In [None]:
# Drop the standalone references; the DSCM was constructed with these objects
# (see model.pgm / model.vae / model.predictor).
del pgm, vae, predictor

# One do(.) evaluation per entry, in this order.
interventions = ['race', 'sex', 'finding']
stats_do = {do_v: {} for do_v in interventions}
preds_do = {do_v: {} for do_v in interventions}
targets_do = {do_v: {} for do_v in interventions}

for do_v in interventions:
    stats, preds, targets = eval_counterfactuals(
        model, dataloaders['valid'], predictor_for_evaluation, do_pa=do_v)
    stats_do[do_v] = stats
    # Keep each intervention's predictions/targets instead of overwriting them on
    # the next iteration.
    preds_do[do_v] = preds
    targets_do[do_v] = targets

# Difference to the reference `test_stats` of the evaluation predictor, x100.
for do_v in stats_do.keys():
    print(f'do_{do_v} | ' + ' - '.join(
        f'{k}: {(v - test_stats[k]) * 100:.1f}' for k, v in stats_do[do_v].items()))

In [None]:
# Absolute counterfactual metrics per intervention, 3 decimals.
for do_v, do_stats in stats_do.items():
    formatted = ' - '.join(f'{k}: {v:.3f}' for k, v in do_stats.items())
    print(f'do_{do_v} | ' + formatted)