In [None]:
import warnings

import numpy as np
import torch
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt

In [None]:
# Validation-split prediction arrays (files suffixed `_quant`) for each
# sensor/year, loaded in a fixed order, followed by the validation masks.
ls5_2010_val, ls7_2010_val, ls7_2017_val, ls8_2017_val = (
    np.load(path)
    for path in (
        './data/preds/ls5_2010_preds_val_quant.npy',
        './data/preds/ls7_2010_preds_val_quant.npy',
        './data/preds/ls7_2017_preds_val_quant.npy',
        './data/preds/ls8_2017_preds_val_quant.npy',
    )
)
val_masks = np.load('./data/preds/val_masks.npy')

In [None]:
# Cutoff-selection strategy used by the Step 1-3 selection cells:
#   True  -> take the cutoff with the highest IoU
#   False -> take the cutoff where precision and recall are closest together
MAX_IOU_MODE = False

In [None]:
def compute_stats(tp, fp, fn, tn):
    """Summarize confusion-matrix counts as four segmentation metrics.

    Args:
        tp, fp, fn, tn: true/false positive/negative counts, as produced by
            ``smp.metrics.get_stats``.

    Returns:
        np.ndarray ordered ``[iou, f1, precision, recall]``; every metric is
        computed with ``reduction="micro"``.
    """
    metric_fns = (
        smp.metrics.iou_score,
        smp.metrics.f1_score,
        smp.metrics.precision,
        smp.metrics.recall,
    )
    return np.array([metric(tp, fp, fn, tn, reduction="micro") for metric in metric_fns])

## Step 1: Find the LS8 2017 cutoff against the validation masks (balanced precision/recall, or maximum IoU when `MAX_IOU_MODE` is set)

In [None]:
# Sweep LS8 2017 cutoffs over [0, 1) against the validation masks.
# Row i of all_stats is [iou, f1, precision, recall] at cutoffs[i].
cutoffs = np.arange(0,1.0,0.001)
preds = ls8_2017_val
masks = torch.Tensor(val_masks).long()
stats_arrays = [
    compute_stats(*smp.metrics.get_stats(torch.Tensor(preds > c).long(), masks, mode="binary"))
    for c in cutoffs
]
all_stats = np.vstack(stats_arrays)

In [None]:
# Precision and recall of the Step 1 sweep as functions of the cutoff; in the
# balanced mode the chosen cutoff is where these two curves are closest.
fig, ax = plt.subplots()
ax.plot(cutoffs, all_stats[:, 3], label="recall")
ax.plot(cutoffs, all_stats[:, 2], label="precision")
ax.set(title="LS8 2017 vs val masks", xlabel="cutoff", ylabel="score")
ax.legend()
plt.show()

In [None]:
# Choose the LS8 2017 cutoff from the Step 1 sweep (columns: iou, f1, precision, recall).
if MAX_IOU_MODE:
    best_index = int(np.nanargmax(all_stats[:, 0]))  # highest IoU
else:
    # precision and recall closest together
    best_index = int(np.nanargmin(np.abs(all_stats[:, 2] - all_stats[:, 3])))
print(all_stats[best_index])

# Neighbouring cutoffs that give identical stats form a plateau; take the median
# of the contiguous plateau containing best_index. Matching the full stats row
# within a contiguous run avoids pulling in distant cutoffs that merely share
# the same IoU value.
same_stats = np.all(all_stats == all_stats[best_index], axis=1)
breaks = np.flatnonzero(~same_stats)
plateau_start = breaks[breaks < best_index].max(initial=-1) + 1
plateau_stop = breaks[breaks > best_index].min(initial=len(cutoffs))
best_cutoff_ls8 = np.median(cutoffs[plateau_start:plateau_stop])

# The grid is hard-coded; if the plateau reaches its edge the optimum may lie outside it.
if plateau_start == 0 or plateau_stop == len(cutoffs):
    warnings.warn(
        f"LS8 cutoff plateau touches the edge of the search grid "
        f"[{cutoffs[0]}, {cutoffs[-1]}]; the optimum may lie outside it."
    )
print(best_cutoff_ls8)

## Step 2: Find the LS7 cutoff, using binarized LS8 2017 predictions as the baseline

In [None]:
# Sweep LS7 2017 cutoffs over the narrow grid [0, 0.005), using LS8 2017
# binarized at best_cutoff_ls8 as the reference instead of the val masks.
# Row i of all_stats is [iou, f1, precision, recall] at cutoffs[i].
cutoffs = np.arange(0,0.005,0.0001)
preds = ls7_2017_val
masks = torch.Tensor(ls8_2017_val>best_cutoff_ls8).long()
stats_arrays = [
    compute_stats(*smp.metrics.get_stats(torch.Tensor(preds > c).long(), masks, mode="binary"))
    for c in cutoffs
]
all_stats = np.vstack(stats_arrays)

In [None]:
# Precision and recall of the Step 2 sweep (LS7 2017 scored against binarized
# LS8 2017) as functions of the cutoff.
fig, ax = plt.subplots()
ax.plot(cutoffs, all_stats[:, 3], label="recall")
ax.plot(cutoffs, all_stats[:, 2], label="precision")
ax.set(title="LS7 2017 vs binarized LS8 2017", xlabel="cutoff", ylabel="score")
ax.legend()
plt.show()

In [None]:
# Choose the LS7 cutoff from the Step 2 sweep (columns: iou, f1, precision, recall).
if MAX_IOU_MODE:
    best_index = int(np.nanargmax(all_stats[:, 0]))  # highest IoU
else:
    # precision and recall closest together
    best_index = int(np.nanargmin(np.abs(all_stats[:, 2] - all_stats[:, 3])))
print(all_stats[best_index])

# Neighbouring cutoffs that give identical stats form a plateau; take the median
# of the contiguous plateau containing best_index. Matching the full stats row
# within a contiguous run avoids pulling in distant cutoffs that merely share
# the same IoU value.
same_stats = np.all(all_stats == all_stats[best_index], axis=1)
breaks = np.flatnonzero(~same_stats)
plateau_start = breaks[breaks < best_index].max(initial=-1) + 1
plateau_stop = breaks[breaks > best_index].min(initial=len(cutoffs))
best_cutoff_ls7 = np.median(cutoffs[plateau_start:plateau_stop])

# The grid is hard-coded (and narrow here); if the plateau reaches its edge the
# optimum may lie outside it.
if plateau_start == 0 or plateau_stop == len(cutoffs):
    warnings.warn(
        f"LS7 cutoff plateau touches the edge of the search grid "
        f"[{cutoffs[0]}, {cutoffs[-1]}]; the optimum may lie outside it."
    )
print(best_cutoff_ls7)

## Step 3: Find the LS5 cutoff, using binarized LS7 2010 predictions as the baseline

In [None]:
# Sweep LS5 2010 cutoffs over [0, 0.1), using LS7 2010 binarized at
# best_cutoff_ls7 as the reference instead of the val masks.
# Row i of all_stats is [iou, f1, precision, recall] at cutoffs[i].
cutoffs = np.arange(0,0.1,0.001)
preds = ls5_2010_val
masks = torch.Tensor(ls7_2010_val>best_cutoff_ls7).long()
stats_arrays = [
    compute_stats(*smp.metrics.get_stats(torch.Tensor(preds > c).long(), masks, mode="binary"))
    for c in cutoffs
]
all_stats = np.vstack(stats_arrays)

In [None]:
# Precision and recall of the Step 3 sweep (LS5 2010 scored against binarized
# LS7 2010) as functions of the cutoff.
fig, ax = plt.subplots()
ax.plot(cutoffs, all_stats[:, 3], label="recall")
ax.plot(cutoffs, all_stats[:, 2], label="precision")
ax.set(title="LS5 2010 vs binarized LS7 2010", xlabel="cutoff", ylabel="score")
ax.legend()
plt.show()

In [None]:
# Choose the LS5 cutoff from the Step 3 sweep (columns: iou, f1, precision, recall).
if MAX_IOU_MODE:
    best_index = int(np.nanargmax(all_stats[:, 0]))  # highest IoU
else:
    # precision and recall closest together
    best_index = int(np.nanargmin(np.abs(all_stats[:, 2] - all_stats[:, 3])))
print(all_stats[best_index])

# Neighbouring cutoffs that give identical stats form a plateau; take the median
# of the contiguous plateau containing best_index. Matching the full stats row
# within a contiguous run avoids pulling in distant cutoffs that merely share
# the same IoU value.
same_stats = np.all(all_stats == all_stats[best_index], axis=1)
breaks = np.flatnonzero(~same_stats)
plateau_start = breaks[breaks < best_index].max(initial=-1) + 1
plateau_stop = breaks[breaks > best_index].min(initial=len(cutoffs))
best_cutoff_ls5 = np.median(cutoffs[plateau_start:plateau_stop])

# The grid is hard-coded; if the plateau reaches its edge the optimum may lie outside it.
if plateau_start == 0 or plateau_stop == len(cutoffs):
    warnings.warn(
        f"LS5 cutoff plateau touches the edge of the search grid "
        f"[{cutoffs[0]}, {cutoffs[-1]}]; the optimum may lie outside it."
    )
print(best_cutoff_ls5)

## Evaluate against validation masks

In [None]:
# Number of validation elements above each sensor's chosen cutoff, followed by
# the sum of val_masks (the positive count, assuming 0/1 masks) for comparison.
for arr, thr in (
    (ls8_2017_val, best_cutoff_ls8),
    (ls7_2017_val, best_cutoff_ls7),
    (ls7_2010_val, best_cutoff_ls7),
    (ls5_2010_val, best_cutoff_ls5),
):
    print(np.sum(arr > thr))
print(np.sum(val_masks))

In [None]:
# Score each sensor's binarized validation predictions against the validation
# masks; each line prints [iou, f1, precision, recall].
val_masks_t = torch.Tensor(val_masks).long()  # convert once, reused for every sensor
for label, val_preds, thr in (
    ('LS8 2017: ', ls8_2017_val, best_cutoff_ls8),
    ('LS7 2017: ', ls7_2017_val, best_cutoff_ls7),
    ('LS7 2010: ', ls7_2010_val, best_cutoff_ls7),
    ('LS5 2010: ', ls5_2010_val, best_cutoff_ls5),
):
    preds_t = torch.Tensor(val_preds > thr).long()
    print(label, compute_stats(*smp.metrics.get_stats(preds_t, val_masks_t, mode="binary")))

## Evaluate on the test set

In [None]:
# Test-split prediction arrays (files suffixed `_quant`) for each sensor/year,
# loaded in a fixed order, followed by the test masks.
ls5_2010_test, ls7_2010_test, ls7_2017_test, ls8_2017_test = (
    np.load(path)
    for path in (
        './data/preds/ls5_2010_preds_test_quant.npy',
        './data/preds/ls7_2010_preds_test_quant.npy',
        './data/preds/ls7_2017_preds_test_quant.npy',
        './data/preds/ls8_2017_preds_test_quant.npy',
    )
)
test_masks = np.load('./data/preds/test_masks.npy')

In [None]:

# Cross-sensor agreement on the test split: LS7 2017 is scored against binarized
# LS8 2017, and LS5 2010 against binarized LS7 2010 (the second array in each
# pair is the reference). Each line prints [iou, f1, precision, recall].
for label, pred_arr, pred_thr, ref_arr, ref_thr in (
    ('LS8 vs LS7 2017: ', ls7_2017_test, best_cutoff_ls7, ls8_2017_test, best_cutoff_ls8),
    ('LS7 vs LS5 2010: ', ls5_2010_test, best_cutoff_ls5, ls7_2010_test, best_cutoff_ls7),
):
    pred_t = torch.Tensor(pred_arr > pred_thr).long()
    ref_t = torch.Tensor(ref_arr > ref_thr).long()
    print(label, compute_stats(*smp.metrics.get_stats(pred_t, ref_t, mode="binary")))

In [None]:
# Score each sensor's binarized test predictions against the test masks; each
# line prints [iou, f1, precision, recall].
test_masks_t = torch.Tensor(test_masks).long()  # convert once, reused for every sensor
for label, test_preds, thr in (
    ('LS8 2017: ', ls8_2017_test, best_cutoff_ls8),
    ('LS7 2017: ', ls7_2017_test, best_cutoff_ls7),
    ('LS7 2010: ', ls7_2010_test, best_cutoff_ls7),
    ('LS5 2010: ', ls5_2010_test, best_cutoff_ls5),
):
    preds_t = torch.Tensor(test_preds > thr).long()
    print(label, compute_stats(*smp.metrics.get_stats(preds_t, test_masks_t, mode="binary")))

In [None]:
# Number of test elements above each sensor's chosen cutoff, followed by the
# sum of test_masks (the positive count, assuming 0/1 masks) for comparison.
for arr, thr in (
    (ls8_2017_test, best_cutoff_ls8),
    (ls7_2017_test, best_cutoff_ls7),
    (ls7_2010_test, best_cutoff_ls7),
    (ls5_2010_test, best_cutoff_ls5),
):
    print(np.sum(arr > thr))
print(np.sum(test_masks))