In [1]:
import rasterio
import matplotlib.pyplot as plt
import numpy as np
import torch
import time
import os
import sys
sys.path.append('/home/esther/qr_for_landcover/scripts')
import landcover_definitions as lc
import util
import pickle

In [2]:
# Root directories for saved model predictions and for the input tiles.
# Override with the TORCHGEO_PRED_DIR / TORCHGEO_DATA_DIR environment variables
# to run this notebook outside the original machine.
torchgeo_pred_dir = os.environ.get('TORCHGEO_PRED_DIR', '/home/esther/torchgeo_predictions')
torchgeo_data_dir = os.environ.get('TORCHGEO_DATA_DIR', '/home/esther/torchgeo_data')

In [3]:
# Class indices (as produced by lc.map_raw_lc_to_idx['enviroatlas']) that are kept
# for evaluation; classes_keep[i] becomes the contiguous index i.
classes_keep = [1,2,3,4,6]
# Every other class is mapped to this index; it is excluded from the accuracy
# and is not among the classes passed to util.per_class_iou.
ignore_index = len(classes_keep)

def reindex_ea(array_in, classes_keep):
    """Map the kept class indices to contiguous indices 0..len(classes_keep)-1.

    Parameters
    ----------
    array_in : np.ndarray
        Array of class indices.
    classes_keep : sequence of int
        Classes to keep; ``classes_keep[i]`` is mapped to ``i``. If a value is
        listed more than once, its last position wins.

    Returns
    -------
    np.ndarray
        float64 array with the same shape as ``array_in``. Any value not listed in
        ``classes_keep`` becomes ``len(classes_keep)`` (the ignore index).

    NOTE(review): an earlier comment said shrub should be merged into tree, but no
    such merge is implemented. Classes missing from ``classes_keep`` (e.g. 5 with
    the default list) simply become the ignore index.
    """
    ignore_index = len(classes_keep)
    # Start with everything ignored; float64 dtype kept for compatibility with
    # the previous implementation's output.
    reindexed_mask = np.full(array_in.shape, ignore_index, dtype=np.float64)
    for new_idx, old_idx in enumerate(classes_keep):
        # Compare against the untouched input, so the order of classes cannot
        # cause an already-remapped pixel to be remapped again.
        reindexed_mask[array_in == old_idx] = new_idx
    return reindexed_mask

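# Run the evaluation for q and r

Here **q** is the per-pixel argmax of the saved per-class prediction rasters. **r** is the argmax after normalising each class band by its total over the tile and multiplying by the smoothed prior.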

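In the cell below, uncomment exactly one of the experiment blocks 1–5 (and leave the others commented out) to evaluate that experiment. Experiment 5 is currently active.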


In [4]:
set_this = 'test'
results_by_state_q = {}
results_by_state_r = {}

states_to_eval = [
    'pittsburgh_pa-2010_1m', 
    'durham_nc-2012_1m',
    'austin_tx-2012_1m',
    'phoenix_az-2010_1m'
]

loss_to_eval_options = [
    'qr_forward',
  #  'qr_reverse'
]

# Default prior (used by experiments 1 and 2); experiment 5 overrides
# prior_version / prior_name inside the loop.
prior_version = 'from_cooccurrences_101_31'
prior_name = f'prior_{prior_version}'
# Additive smoothing applied to the prior (after scaling 0-255 to 0-1) before renormalising.
p_add_smooth = 1e-4

compute_r = True
for loss in loss_to_eval_options:
    results_by_state_q[loss] = {}
    results_by_state_r[loss] = {}
    
    for state in states_to_eval:

        data_dir = f'{torchgeo_data_dir}/enviroatlas'
        data_dir_this_set = os.path.join(data_dir,f'{state}-{set_this}_tiles-debuffered')
        
#         # 1. qr from pa checkpoint
#         run_name = f'pa_checkpoint_{state}_fcn_1e-05_{loss}_{prior_version}_additive_smooth_0.0001_prior_smooth_0.0001'
        
#         # 2. qr from scratch
#         if loss == 'qr_forward':
#             run_name = f'{state}_fcn_0.0001_{loss}_{prior_version}_additive_smooth_0.0001_prior_smooth_0.0001'
#         elif loss == 'qr_reverse':
#             run_name = f'{state}_fcn_0.001_{loss}_{prior_version}_additive_smooth_0.0001_prior_smooth_0.0001'
        
#         # 3. highres with prior as input
#         run_name = 'pittsburgh_pa-2010_1m_fcn_0.001_nll_with_prior'
        
#         # 4. highres without prior as input
#         run_name = 'pittsburgh_pa-2010_1m_fcn_0.001_nll'  
        
        # 5. output from learned priors
        prior_version = 'learned_101_31'
        prior_name = f'prior_{prior_version}'

        run_name = f'pa_checkpoint_{state}_fcn_1e-05_{loss}_{prior_version}_additive_smooth_0.0001_prior_smooth_0.0001'
        compute_r = True
        
        pred_dir = f'{torchgeo_pred_dir}/{run_name}/enviroatlas'
        pred_dir_this_set = os.path.join(pred_dir,f'{state}-{set_this}_tiles-debuffered')

        fns = os.listdir(data_dir_this_set)
        # NOTE(review): assumes the first 10 characters of every filename are the tile id.
        tile_ids = np.unique([x[:10] for x in fns])
        print(len(tile_ids))

        accs_q = []
        ious_q = []
        accs_r = []
        ious_r = []
        num_pix = []
        # Per-tile counts of correctly classified labelled pixels. Summing these
        # gives the pixel-weighted accuracy without the NaN that a tile with no
        # labelled pixels would otherwise inject (NaN * 0 == NaN).
        n_correct_q = []
        n_correct_r = []

        for tile_id in tile_ids:
            fn_this = os.path.join(data_dir_this_set, f'{tile_id}_h_highres_labels.tif')
            pred_fn_this = os.path.join(pred_dir_this_set, f'{tile_id}_{loss}_pred_last.tif')
            t1 = time.time()

            # Gather the high-res labels and reindex them to 0..len(classes_keep)-1;
            # every other class becomes ignore_index.
            with rasterio.open(fn_this) as f:
                hr_lc = f.read()[0]
            hr_lc = reindex_ea(lc.map_raw_lc_to_idx['enviroatlas'][hr_lc], classes_keep)
            valid = hr_lc != ignore_index
            n_valid = valid.sum()

            # Soft predictions: band index == class index; q is the per-pixel argmax.
            with rasterio.open(pred_fn_this) as f:
                preds_this_soft = f.read()
            preds_this = preds_this_soft.argmax(0)

            correct_q = hr_lc[valid] == preds_this[valid]
            acc_this_q = correct_q.mean() if n_valid > 0 else np.nan

            # IoU over the kept classes only (ignore_index is not in the class list).
            iou_this_q = util.per_class_iou(hr_lc, preds_this, np.arange(0,len(classes_keep)))
            accs_q.append(acc_this_q)
            ious_q.append(iou_this_q)
            num_pix.append(n_valid)
            n_correct_q.append(correct_q.sum())

            if compute_r:
                # r: reweight q by the prior. The prior raster sits next to the
                # labels; scale it from 0-255, smooth it and renormalise over classes.
                prior_fn = fn_this.replace('h_highres_labels.tif',f'{prior_name}.tif')
                with rasterio.open(prior_fn) as f:
                    prior_this = f.read()
                prior_smoothed = prior_this / 255. + p_add_smooth
                prior = prior_smoothed / prior_smoothed.sum(axis=0)

                # Normalise each class band by its total over the tile. A band summing
                # to 0 would give 0/0 = NaN, and argmax would then return that class
                # for every pixel; such bands get zero weight instead.
                class_mass = preds_this_soft.sum(axis=(1,2), keepdims=True)
                z = np.divide(preds_this_soft, class_mass,
                              out=np.zeros(preds_this_soft.shape, dtype=np.float64),
                              where=class_mass != 0)
                preds_r = (prior*z).argmax(0)

                correct_r = hr_lc[valid] == preds_r[valid]
                acc_this_r = correct_r.mean() if n_valid > 0 else np.nan
                iou_this_r = util.per_class_iou(hr_lc, preds_r, np.arange(0,len(classes_keep)))

                accs_r.append(acc_this_r)
                ious_r.append(iou_this_r)
                n_correct_r.append(correct_r.sum())
            
            t2 = time.time()
            print(f'{t2-t1:.2f} seconds')

        # NOTE(review): elements 1 and 2 of each per_class_iou result are pooled by
        # util.aggregate_ious (presumably per-class intersections / unions) — confirm in util.
        ious_aggregated_q = util.aggregate_ious([x[1] for x in ious_q], [x[2] for x in ious_q])
        # Pixel-weighted accuracy over all labelled pixels in the set.
        acc_aggregated_q = np.sum(n_correct_q) / np.sum(num_pix)
        print(f'For {state} {set_this} set with {loss} loss:')
        print(f'acc q: {acc_aggregated_q}')
        print(f'mean iou q: {np.mean(ious_aggregated_q[0])}')
        
        if compute_r:
            acc_aggregated_r = np.sum(n_correct_r) / np.sum(num_pix)
            ious_aggregated_r = util.aggregate_ious([x[1] for x in ious_r], [x[2] for x in ious_r])
            print(f'acc r: {acc_aggregated_r}')
            print(f'mean iou r: {np.mean(ious_aggregated_r[0])}')
        
        print('IoU per class over the tiles (q) is: ')
        print(ious_aggregated_q[0])
        if compute_r:
            print('IoU per class over the tiles (r) is: ')
            print(ious_aggregated_r[0])
        

        results_by_state_q[loss][state] = {'accs': accs_q,
                                       'ious': ious_q,
                                       'num_pix':num_pix,
                                       'ious_aggregated': ious_aggregated_q,
                                       'acc_aggregated':acc_aggregated_q}
        
        if compute_r:
            results_by_state_r[loss][state] = {'accs': accs_r,
                                               'ious': ious_r,
                                               'num_pix':num_pix,
                                               'ious_aggregated': ious_aggregated_r,
                                               'acc_aggregated':acc_aggregated_r}
        
    # The output name is the last state's run_name with that state removed.
    # NOTE(review): for fixed run names (experiments 3/4) nothing is removed.
    out_fn = f'{torchgeo_pred_dir}/{run_name.replace("_"+state,"")}.pkl'
    with open(out_fn, 'wb') as f:
        print(f'writing results to {out_fn}')
        pickle.dump({'results_by_state_q':results_by_state_q[loss],
                     'results_by_state_r':results_by_state_r[loss]}, f)

10
11.10 seconds
10.23 seconds
8.68 seconds
8.71 seconds
9.19 seconds
8.99 seconds
9.39 seconds
8.62 seconds
9.01 seconds
9.73 seconds
For pittsburgh_pa-2010_1m test set with qr_forward loss:
acc q: 0.8721401499450011
mean iou q: 0.6853562146260689
acc r: 0.8696790696664919
mean iou r: 0.6926713239230677
IoU per class over the tiles (q) is: 
[0.9267498183164448, 0.8519798920676908, 0.21507305644842745, 0.8220769565431544, 0.6109013497546272]
IoU per class over the tiles (r) is: 
[0.962180914356022, 0.8522077721595219, 0.2265058110404857, 0.817221996824602, 0.605240125234707]
10
9.79 seconds
10.11 seconds
9.37 seconds
9.68 seconds
10.69 seconds
9.93 seconds
10.44 seconds
9.63 seconds
9.31 seconds
9.53 seconds
For durham_nc-2012_1m test set with qr_forward loss:
acc q: 0.7874003944402881
mean iou q: 0.47348728996387796
acc r: 0.7892546493024487
mean iou r: 0.4956671053221271
IoU per class over the tiles (q) is: 
[0.5020130899766014, 0.634276644402449, 0.11497267411400797, 0.7885755656114

# Print the results as LaTeX table rows (accuracy and mean IoU, in percent, per state)

In [5]:
# Column order of the LaTeX results rows printed in the next cell.
states_in_reporting_order = [
    'pittsburgh_pa-2010_1m', 'durham_nc-2012_1m',
    'austin_tx-2012_1m', 'phoenix_az-2010_1m',
]

In [6]:
def format_results_row(results_by_state, states):
    """Build one LaTeX table-row fragment from per-state results.

    For each state, in the given order, emits ``& <acc> & <mean IoU> `` with both
    values in percent and formatted to two decimals.

    Parameters
    ----------
    results_by_state : dict
        Maps state name -> results dict with 'acc_aggregated' and
        'ious_aggregated' (element 0 holds the per-class IoUs).
    states : sequence of str
        States to include, in column order.

    Returns
    -------
    str
        Concatenated row fragment.
    """
    row = ""
    for state in states:
        results = results_by_state[state]
        acc_pct = results['acc_aggregated'] * 100
        miou_pct = np.mean(results['ious_aggregated'][0]) * 100
        # A single rounding step via the format spec (no prior np.round).
        row += f"& {acc_pct:.2f} & {miou_pct:.2f} "
    return row


# Only qr_forward is reported; add 'qr_reverse' once it has been evaluated.
for loss in [
    'qr_forward',
   #'qr_reverse'
]:
    for model_name, results_all in (('q', results_by_state_q), ('r', results_by_state_r)):
        print(loss + " " + model_name + " ")
        results_this_loss = results_all.get(loss, {})
        missing = [s for s in states_in_reporting_order if s not in results_this_loss]
        if missing:
            # e.g. r results are absent when the evaluation ran with compute_r = False
            print(f"missing results for: {missing}")
            continue
        print(format_results_row(results_this_loss, states_in_reporting_order))

qr_forward q 
& 87.21 & 68.54 & 78.74 & 47.35 & 79.83 & 51.10 & 74.57 & 39.59 
qr_forward r 
& 86.97  & 69.27 & 78.93  & 49.57 & 80.27  & 52.60 & 75.31  & 42.22 
