In [None]:
import sys
import argparse
import copy
import torchvision
import torch
import os
os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES']='0'
import pyro
from sklearn.metrics import auc, recall_score, roc_curve, roc_auc_score
sys.path.append('../..')
from tqdm import tqdm
import torch.nn.functional as F
sys.path.append('..')
from train_setup import setup_directories, setup_logging
from utils import EMA
from train_pgm import preprocess
import pandas as pd
from layers import TraceStorage_ELBO
from torch.utils.data import DataLoader
from mimic import MimicDataset_with_cfs, MimicDataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
from train_pgm import setup_dataloaders
import scipy
import matplotlib.pyplot as plt
import torch.nn as nn
from flow_pgm import FlowPGM
from train_cls_cf_mimic import eval_epoch, sup_epoch

# Device string passed to the `.to(...)` calls when the model and EMA are built.
_DEVICE = "cuda:0"
# Value forwarded as `use_data` to `eval_epoch` in the evaluation cell.
_TEST_DATA = "real"

In [None]:
def norm(batch):
    """Normalise the entries of a batch dict in place and return the same dict.

    Per-key handling:
      * 'x', 'cf_x': cast to float and apply (v - 127.5) / 127.5, so that
        values in [0, 255] land in [-1, 1].
      * 'age': cast to float, append a trailing dim, divide by 100 and rescale
        to [-1, 1].
      * 'race': one-hot encode with 3 classes, squeeze, cast to float.
      * 'finding': append a trailing dim and cast to float.
      * any other key: best-effort float cast plus trailing dim; values that
        cannot be converted (e.g. lists of strings) are left untouched.

    Args:
        batch: dict mapping field names to batched values.

    Returns:
        The same `batch` object, with its values replaced as described above.
    """
    for k, v in batch.items():
        if k in ('x', 'cf_x'):
            batch[k] = (v.float() - 127.5) / 127.5  # [-1,1]
        elif k == 'age':
            batch[k] = v.float().unsqueeze(-1) / 100. * 2 - 1  # [-1,1]
        elif k == 'race':
            # NOTE(review): .squeeze() removes every size-1 dim, so a batch
            # holding a single sample also loses its batch dim here - confirm
            # callers never pass batch size 1.
            batch[k] = F.one_hot(v, num_classes=3).squeeze().float()
        elif k == 'finding':
            batch[k] = v.unsqueeze(-1).float()
        else:
            try:
                batch[k] = v.float().unsqueeze(-1)
            except Exception:
                # Best effort: non-tensor metadata stays as-is. Catching
                # Exception (not a bare except) lets KeyboardInterrupt and
                # SystemExit propagate instead of being silently swallowed.
                pass
    return batch

def loginfo(title, logger, stats):
    """Log one line of the form '<title> | k1: v1 - k2: v2 ...'.

    Args:
        title: prefix placed before the ' | ' separator.
        logger: object exposing an `info(message)` method.
        stats: mapping of metric name to a number; each value is formatted
            with four decimal places.
    """
    formatted = [f'{name}: {value:.4f}' for name, value in stats.items()]
    message = f'{title} | ' + ' - '.join(formatted)
    logger.info(message)

def inv_preprocess(pa):
    """Undo the [-1, 1] scaling of the 'age' parent, in place.

    Only the 'age' entry is touched: it is mapped back with
    (v + 1) / 2 * 100. All other entries are left unchanged.

    Args:
        pa: dict of parent variables.

    Returns:
        The same `pa` dict.
    """
    if 'age' in pa:
        pa['age'] = (pa['age'] + 1) / 2 * 100
    return pa


def vae_preprocess(args, pa):
    """Concatenate the selected parents and tile them over a square grid.

    The tensors named in `args.parents_x` are concatenated along dim 1, two
    trailing singleton dims are appended, and the result is repeated
    `args.input_res` times along each of those two dims before being cast
    to float.

    Args:
        args: object providing `parents_x` (list of keys) and `input_res` (int).
        pa: dict mapping parent names to tensors that can be concatenated
            along dim 1.

    Returns:
        Float tensor whose last two dims both have size `args.input_res`.
    """
    side = args.input_res
    stacked = torch.cat([pa[name] for name in args.parents_x], dim=1)
    tiled = stacked[..., None, None].repeat(1, 1, side, side)
    return tiled.float()


def get_metrics(preds, targets):
    """Stack per-batch outputs and compute the age mean absolute error.

    Every list in `preds` (and the list under the same key in `targets`) is
    replaced in place by its stacked, squeezed CPU tensor. Only 'age' yields
    a metric: predictions are mapped from [-1, 1] to [0, 100] and compared
    with the targets as given.

    Args:
        preds: dict mapping variable name to a list of equally shaped tensors.
        targets: dict with at least the same keys and the same layout.

    Returns:
        dict holding 'age_mae' (a Python float) when 'age' is in `preds`,
        otherwise an empty dict.
    """
    for key in preds:
        preds[key] = torch.stack(preds[key]).squeeze().cpu()
        targets[key] = torch.stack(targets[key]).squeeze().cpu()

    stats = {}
    if "age" in preds:
        rescaled = (preds["age"] + 1) / 2 * 100  # [-1,1] -> [0,100]
        abs_err = torch.abs(targets["age"] - rescaled)
        stats["age_mae"] = torch.mean(abs_err).item()
    return stats

class Hparams:
    """Plain attribute container for hyper-parameters."""

    def update(self, dict):
        """Set one attribute on this object per key/value pair in `dict`.

        The parameter name shadows the builtin `dict`; it is kept so existing
        keyword callers keep working.
        """
        for attr_name in dict:
            setattr(self, attr_name, dict[attr_name])

### Set parameters

In [None]:

# Placeholders: fill in the DSCM run and checkpoint whose counterfactuals were used.
dscm_dir = "WHICH_DSCM_WAS_USED_TO_GENERATE_CFS"
which_checkpoint = "WHICH_CHECKPOINT"

# No CLI options are declared; parse_known_args just yields an empty namespace
# (ignoring whatever argv the kernel was launched with) that is filled in below.
parser = argparse.ArgumentParser()
args = parser.parse_known_args()[0]

# Experiment identity. `which_cf` is the counterfactual type the classifier was
# trained on; the alternatives are 'sex' and 'finding'.
vars(args).update(
    use_data="cf",
    eval_data="real",
    which_cf='race',
    setup='sup_determ',
)

elbo_fn = TraceStorage_ELBO(num_particles=2)

# Data, optimisation and architecture settings.
vars(args).update(
    data_dir='/vol/biodata/data/chest_xray/mimic-cxr-jpg-224/data/',
    lr=1e-4,
    bs=64,
    wd=0.05,
    csv_dir=f"CF_DATA_DIR/{dscm_dir}/{which_checkpoint}",
    parents_x=['age', 'race', 'sex', 'finding'],
    enc_net="resnet18",
    epochs=1000,
    input_res=224,
    eval_freq=1,
    use_dataset='mimic_cfs',
    input_channels=1,
    loss_norm="l2",
)

# Name of the trained-on-CFs classifier run; used to build the checkpoint path.
args.exp_name = f"mimic_train_{args.use_data}_{args.which_cf}_val_{args.eval_data}_classifier_resnet18_lr4_slurm_{args.setup}"


### Load predictors

In [None]:
# Reset Pyro's global parameter store before building the model.
pyro.clear_param_store()
model = FlowPGM(args)
# NOTE(review): `ema` is created and moved to the device but not referenced
# again in this notebook; the EMA weights are loaded into `model` below.
ema = EMA(model, beta=0.999)
model.to(_DEVICE)
ema.to(_DEVICE)

# Alternative checkpoint file from the same run:
# model_path = f"checkpoints/{dscm_dir}/{which_checkpoint}/a_r_s_f/{args.exp_name}/checkpoint_current.pt"
model_path = f"checkpoints/{dscm_dir}/{which_checkpoint}/a_r_s_f/{args.exp_name}/checkpoint.pt"

# map_location remaps the stored tensors onto the device used here. Without it
# torch.load restores each tensor to the device it was saved from, which fails
# when that CUDA index is not visible (only device 0 is exposed above).
model_checkpoint = torch.load(model_path, map_location=_DEVICE)
model.load_state_dict(model_checkpoint['ema_model_state_dict'])
model = model.to(_DEVICE)
print(model_path)

### Set dataloaders

In [None]:
# Override the batch size set in the parameters cell (64) before the loaders
# are built - presumably a smaller batch for evaluation; confirm.
args.bs = 20
dataloaders = setup_dataloaders(args)

In [None]:
# Show the full configuration namespace, including the batch-size override.
print(args)

### Evaluate on test data

In [None]:
# Run one evaluation pass of the loaded classifier.
# NOTE(review): the heading and variable names say "test", but the loader used
# here is dataloaders['valid'] - confirm which split is intended.
test_stats = eval_epoch(
                model = model, 
                dataloader = dataloaders['valid'],
                use_data=_TEST_DATA,
            )

In [None]:
# Report which DSCM run and classifier experiment were evaluated, followed by
# every metric returned by eval_epoch, formatted to three decimals.
print(f"{dscm_dir}: ")
print(f"{args.exp_name}: ")
print(f"{args.which_cf} {args.use_data} | "+" - ".join(f'{k}: {v:.3f}' for k,v in test_stats.items()))