# Measure Conditional Average Treatment Effect for 27 Typologies

In [1]:
import os, glob
import numpy as np
import astropy.table as aTable
from tqdm.notebook import tqdm, trange

In [2]:
import torch
from sbi import utils as Ut
from sbi import inference as Inference
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

In [3]:
import corner as DFM
# --- plotting ---
import matplotlib as mpl
import matplotlib.pyplot as plt
# Notebook-wide matplotlib style, applied in one call: LaTeX text rendering,
# serif fonts, thicker axes and ticks, and legends drawn without a frame.
mpl.rcParams.update({
    'text.usetex': True,
    'font.family': 'serif',
    'axes.linewidth': 1.5,
    'axes.xmargin': 1,
    'xtick.labelsize': 'x-large',
    'xtick.major.size': 5,
    'xtick.major.width': 1.5,
    'ytick.labelsize': 'x-large',
    'ytick.major.size': 5,
    'ytick.major.width': 1.5,
    'legend.frameon': False,
})

In [4]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## load compiled data set
See the data section of the paper for details.

In [5]:
dat_dir = '/scratch/gpfs/chhahn/noah/' # directory with data

# Read the compiled CSV table, then report how many rows it has and how many
# distinct zipcodes those rows cover.
fdata = os.path.join(dat_dir, 'zipcode.fema.fsf.acs.rainfall.v2.csv')
data = aTable.Table.read(fdata, format='csv')

n_entries = len(data)
n_zipcodes = len(np.unique(data['reportedZipcode']))
print('%i entries; %i unique zipcodes' % (n_entries, n_zipcodes))

74288 entries; 14729 unique zipcodes


In [6]:
# Control group: rows whose `communityRatingSystemDiscount` equals 11.
control = (data['communityRatingSystemDiscount'] == 11.)
print(f"{np.sum(control)} entries in control; {len(np.unique(data['reportedZipcode'][control]))} unique zipcodes")

# Treatment group: rows whose `communityRatingSystemDiscount` is below 11.
# (The summary previously labelled this group "control" as well.)
treat = (data['communityRatingSystemDiscount'] < 11.)
print(f"{np.sum(treat)} entries in treatment; {len(np.unique(data['reportedZipcode'][treat]))} unique zipcodes")

43644 entries in control; 11258 unique zipcodes
30644 entries in control; 5040 unique zipcodes


In [7]:
# Columns pulled from the table. Index 0 is the outcome; indices 1-7 are the
# covariates the outcome is conditioned on.
columns = [
    'amountPaidOnTotalClaim_per_policy',
    'mean_rainfall',
    'avg_risk_score_all',
    'median_household_income',
    'population',
    'renter_fraction',
    'educated_fraction',
    'white_fraction',
]

In [8]:
def _outcome_covariate_matrix(selection):
    """Stack the `columns` of `data` for the selected rows into an (N, len(columns)) array.

    Columns 0, 3 and 4 (claim amount per policy, median household income and
    population) are replaced by their log10.
    """
    matrix = np.vstack([np.ma.getdata(data[name][selection].data) for name in columns]).T
    for i_col in (0, 3, 4):
        matrix[:, i_col] = np.log10(matrix[:, i_col])
    return matrix


control_data = _outcome_covariate_matrix(control)
treat_data = _outcome_covariate_matrix(treat)

In [9]:
def _linear_covariates(logged):
    """Return a copy of the covariate columns (everything but column 0) with
    income and population (covariate columns 2 and 3) converted from log10
    back to linear units."""
    covariates = logged[:, 1:].copy()
    covariates[:, 2:4] = 10**covariates[:, 2:4]
    return covariates


_control_data = _linear_covariates(control_data)
_treat_data = _linear_covariates(treat_data)

## load $\mathcal{Q}^C_\phi(X)$ and $\mathcal{Q}^T_\phi(X)$ for testing covariate support
See Appendix B for details on $\mathcal{Q}^C_\phi(X)$ and $\mathcal{Q}^T_\phi(X)$

In [10]:
# Models used to test covariate support for the control and treated groups.
# They live in the same data directory as the table, so build the paths from
# `dat_dir` instead of repeating the absolute path.
# NOTE: torch.load unpickles the stored objects -- only load trusted files.
support_control = torch.load(os.path.join(dat_dir, 'qphi_support.v2.log.control.pt'),
                             map_location=torch.device(device))
support_treat = torch.load(os.path.join(dat_dir, 'qphi_support.v2.log.treat.pt'),
                           map_location=torch.device(device))

In [11]:
def within_support(covar, thresh=-10):
    """Check whether a covariate vector is supported by both groups.

    Parameters
    ----------
    covar : numpy array
        Covariate vector with entries 2 and 3 in linear units; they are
        log10-transformed on a copy before evaluation, so `covar` itself is
        not modified.
    thresh : float
        Log-probability threshold that both `support_control` and
        `support_treat` must exceed.

    Returns
    -------
    bool
        True only if the log-probability of the covariates exceeds `thresh`
        under both the control and the treated support models.
    """
    logcond = covar.copy()
    logcond[2:4] = np.log10(logcond[2:4])

    # build the input tensor once and evaluate both support models on it
    x = torch.tensor(logcond.astype(np.float32)).to(device)
    logprob_control = support_control.log_prob(x).detach().cpu()[0]
    logprob_treat = support_treat.log_prob(x).detach().cpu()[0]

    # explicit bool() instead of indexing a [False, True] list with a tensor
    return bool((logprob_control > thresh) & (logprob_treat > thresh))

## read $q_\phi$

In [12]:
def read_best_models(study_name):
    """Load the five models of a study with the highest best-validation log-probability.

    Scans the tensorboard event files under `<dat_dir>/nde/<study_name>/*/`,
    reads the first `best_validation_log_prob` scalar of each run, and loads
    `<dat_dir>/nde/<study_name>/<study_name>.<i>.pt` for the top runs, where
    `<i>` is the integer suffix of the run directory name.

    Parameters
    ----------
    study_name : str
        Name of the study sub-directory under `<dat_dir>/nde/`.

    Returns
    -------
    list
        Loaded models, ordered from best to worst validation score
        (at most five).

    Raises
    ------
    ValueError
        If no run with a recorded `best_validation_log_prob` is found.
    """
    fevents = glob.glob(os.path.join(dat_dir, 'nde/%s/*/events*' % study_name))

    model_ids, best_valid = [], []
    for fevent in fevents:
        ea = EventAccumulator(fevent)
        ea.Reload()

        try:
            score = ea.Scalars('best_validation_log_prob')[0].value
        except (KeyError, IndexError):
            # this run never logged a best validation score; skip it
            continue
        best_valid.append(score)
        # the run directory name ends in '.<model index>'
        model_ids.append(int(os.path.dirname(fevent).split('.')[-1]))

    if not model_ids:
        raise ValueError(
            'no runs with best_validation_log_prob found for study %r' % study_name)
    best_valid = np.array(best_valid)

    # reports the largest model index among the runs that have a score
    print('%i models trained' % np.max(model_ids))

    i_models = [model_ids[i] for i in np.argsort(best_valid)[-5:][::-1]]
    print(i_models)

    qphis = []
    for i_model in i_models:
        fqphi = os.path.join(dat_dir, 'nde/%s/%s.%i.pt' % (study_name, study_name, i_model))
        # NOTE: torch.load unpickles the stored object -- only load trusted files.
        qphis.append(torch.load(fqphi, map_location=device))
    return qphis

In [13]:
# Load the best-validated model ensembles for the control and treated groups.
study_control, study_treat = 'control.v2.made', 'treat.v2.made'
qphis_control = read_best_models(study_control)
qphis_treat = read_best_models(study_treat)

2313 models trained
[680, 1365, 1087, 78, 7]
2609 models trained
[462, 1287, 1547, 2243, 269]


# CATE for 27 typologies
Compute the CATE for each typology and save the results to file.

In [14]:
# Low / mid / high levels that define the typologies; each set is roughly the
# 16th, 50th and 84th percentile of the corresponding quantity.
income_low = 4e4
income_mid = 6e4
income_high = 9e4

population_low = 2.5e3
population_mid = 1.2e4
population_high = 3e4

diversity_low = 0.6
diversity_mid = 0.85
diversity_high = 0.95

In [17]:
# (min, max) range scanned for each covariate, in the column order of `_control_data`
_ranges = [(0., 500.), (0.5, 5.), (2.5e4, 1e5), (0, 4.5e4), (0., 0.8), (0., 0.6), (0.0, 1.)]
n_sample = 40000  # total number of samples drawn from each ensemble per covariate value


def _sample_outcomes(qphis, x_cond, n_total):
    """Draw outcome samples conditioned on `x_cond` from every model in `qphis`.

    The `n_total` budget is split evenly across the models of *this* ensemble.
    The models sample log10 outcomes; the returned flat array is in linear units.
    """
    n_per_model = int(n_total / len(qphis))
    draws = [
        qphi.sample((n_per_model,), x=x_cond, show_progress_bars=False).detach().cpu().numpy()
        for qphi in qphis
    ]
    return 10**np.array(draws).flatten()


def _stack_rows(rows):
    """Turn a list of per-typology lists into an array that `np.save` can store.

    Rows can differ in length because out-of-support covariate values are
    skipped. Ragged input is stored as an object array (which must be read
    back with `np.load(..., allow_pickle=True)`); rectangular input is stored
    as a regular numeric array.
    """
    if len({len(row) for row in rows}) > 1:
        return np.array(rows, dtype=object)
    return np.array(rows)


cv_cates, cates, treats, controls, sig_cates = [], [], [], [], []
# scan covariates 0, 1, 4, 5 (mean_rainfall, avg_risk_score_all, renter_fraction,
# educated_fraction); covariates 2, 3 and -1 are pinned by the typology
for i_covar in [0, 1, 4, 5]:
    xs_covar = np.linspace(_ranges[i_covar][0], _ranges[i_covar][1], 10)

    for inc in [income_low, income_mid, income_high]:
        for pop in [population_low, population_mid, population_high]:
            for div in [diversity_low, diversity_mid, diversity_high]:
                # control entries close to this typology
                near = ((np.abs(_control_data[:,2] - inc) < 5e3) &
                        (np.abs(_control_data[:,3] - pop) < 5e3) &
                        (np.abs(_control_data[:,-1] - div) < 0.05))

                # fiducial covariates: median of the nearby control entries,
                # with income, population and diversity set to the typology
                fid = np.median(_control_data[near], axis=0)
                fid[2] = inc
                fid[3] = pop
                fid[-1] = div

                # per-typology accumulators; deliberately not named `treat` /
                # `control`, which are the group masks defined earlier
                cv_cate, cate, treat_mean, control_mean, sig_cate = [], [], [], [], []
                for _covar in xs_covar:
                    covars = fid.copy()
                    covars[i_covar] = _covar

                    # make sure it's within support
                    if not within_support(covars):
                        continue

                    # the models are conditioned on log10 income and population
                    covars[2] = np.log10(covars[2])
                    covars[3] = np.log10(covars[3])
                    x_cond = torch.tensor(covars, dtype=torch.float32).to(device)

                    treat_y = _sample_outcomes(qphis_treat, x_cond, n_sample)
                    control_y = _sample_outcomes(qphis_control, x_cond, n_sample)
                    mean_treat, mean_control = np.mean(treat_y), np.mean(control_y)

                    cv_cate.append(_covar)
                    cate.append(mean_treat - mean_control)
                    treat_mean.append(mean_treat)
                    control_mean.append(mean_control)

                    # uncertainty of CATE: standard errors of the two means,
                    # using the number of samples actually drawn, added in quadrature
                    sig_treat = np.std(treat_y) / np.sqrt(float(treat_y.size))
                    sig_control = np.std(control_y) / np.sqrt(float(control_y.size))
                    sig_cate.append(np.sqrt(sig_treat**2 + sig_control**2))

                cv_cates.append(cv_cate)
                cates.append(cate)
                treats.append(treat_mean)
                controls.append(control_mean)
                sig_cates.append(sig_cate)

np.save('cv_cates.npy', _stack_rows(cv_cates))
np.save('cates.npy', _stack_rows(cates))
np.save('treats.npy', _stack_rows(treats))
np.save('controls.npy', _stack_rows(controls))
np.save('sig_cates.npy', _stack_rows(sig_cates))

