In [22]:
import numpy as np

In [23]:
import sys
sys.path.append('./utils')

from data import Data
from models import Models
from tags import Tags
# Module-level Tags instance; threshold() calls tags.idx_to_tag(i) to name
# each tag column in its progress output.
tags = Tags()

In [18]:
# Upper bound on the number of tag columns that threshold() will process.
N_TAGS = 17

def threshold(pred, true):
    """Compute one decision threshold per tag column.

    Parameters
    ----------
    pred : 2-D array, indexed as ``pred[:, i]`` for tag ``i``
        Predicted scores, one column per tag.
    true : 2-D array, indexed as ``true[:, i]`` for tag ``i``
        Ground-truth labels laid out like ``pred``.

    Returns
    -------
    list
        ``tag_threshold(pred[:, i], true[:, i])`` for each processed column,
        in column order.

    Notes
    -----
    The number of processed columns is ``min(N_TAGS, pred.shape[1])``.
    Inputs with at least ``N_TAGS`` columns therefore behave exactly as
    before, while narrower inputs (e.g. a single-tag fixture) no longer
    raise ``IndexError`` on the first missing column.
    """
    n_cols = min(N_TAGS, pred.shape[1])
    thres = []
    for i in range(n_cols):
        print('Calc thresh for tag {}'.format(tags.idx_to_tag(i)))
        thres.append(tag_threshold(pred[:, i], true[:, i]))
    return thres

def amazon_score(tp, fp, fn):
    """Return the F-beta score with beta = 2 for the given confusion counts.

    Precision is ``tp / (tp + fp)`` and recall is ``tp / (tp + fn)``; either
    one is taken as 0 when its denominator is not positive. The result is
    ``5 * p * r / (4 * p + r)``, or 0 when that denominator is not positive.
    """
    predicted_positive = tp + fp
    actual_positive = tp + fn

    if predicted_positive > 0:
        precision = tp / predicted_positive
    else:
        precision = 0

    if actual_positive > 0:
        recall = tp / actual_positive
    else:
        recall = 0

    denominator = 4 * precision + recall
    if denominator > 0:
        return 5 * precision * recall / denominator
    return 0

def tag_threshold(pred, true):
    """Find the score cut-off that maximises ``amazon_score`` for one tag.

    Samples are sorted by predicted score. The sweep starts with every
    sample counted as predicted-positive and moves samples to the negative
    side one at a time, lowest score first. After moving sample ``i`` the
    candidate threshold is that sample's score, i.e. the rule being scored
    is "positive if score is strictly greater than the threshold".

    Parameters
    ----------
    pred : array
        Predicted scores for one tag; flattened to length ``len(pred)``.
    true : array
        Ground-truth labels for the same samples; a label equal to 0 is
        treated as negative, anything else as positive.

    Returns
    -------
    The best threshold found, or 0 when no cut scores strictly higher than
    the all-positive baseline.

    Notes
    -----
    A cut is only evaluated at the last sample of a run of equal scores.
    Cutting inside such a run would count some of the tied samples as
    positive and others as negative, which no single threshold can
    reproduce, and would report a score the returned threshold cannot
    actually achieve.
    """
    pred = pred.reshape((len(pred)))
    true = true.reshape((len(true)))
    n_samples = len(pred)
    n_pos = true.sum()
    n_neg = n_samples - n_pos

    pair = list(zip(pred, true))
    pair.sort()

    # Baseline: everything predicted positive.
    tp, fp, tn, fn = n_pos, n_neg, 0, 0
    tag_thres = 0
    max_score = amazon_score(tp, fp, fn)
    # Start from the baseline counts so the log line below is meaningful
    # even when no cut improves on the baseline.
    best_counts = (tp, fp, tn, fn)

    for i in range(n_samples):
        # Sample i moves from the predicted-positive to the negative side.
        if pair[i][1] == 0:
            fp -= 1
            tn += 1
        else:
            tp -= 1
            fn += 1

        # Still inside a group of identical scores: not a realisable cut.
        if i + 1 < n_samples and pair[i + 1][0] == pair[i][0]:
            continue

        current_score = amazon_score(tp, fp, fn)
        if current_score > max_score:
            max_score = current_score
            tag_thres = pair[i][0]
            best_counts = (tp, fp, tn, fn)

    print('Best (tp, fp, tn, fn) is {}'.format(best_counts))
    return tag_thres

def calc_threshold(path):
    """Run a saved model on fold 0 and derive per-tag thresholds.

    ``path`` is handed to ``Models.load_resnet50``; presumably it is a
    weights/checkpoint file for that model (TODO confirm against
    utils/models). The inputs and labels come from
    ``Data(train=[0]).get_fold(0)``, and the model's predictions on those
    inputs are passed together with the labels to ``threshold``.

    NOTE(review): the data object is built with ``train=[0]`` and the same
    fold index 0 is then used for tuning; confirm this is the intended
    split.

    Returns whatever ``threshold`` returns for (predictions, labels).
    """
    model = Models.load_resnet50(path)
    dataset = Data(train=[0])
    fold_inputs, fold_labels = dataset.get_fold(0)
    fold_predictions = model.predict(fold_inputs, verbose=1)
    return threshold(fold_predictions, fold_labels)

In [24]:
def test():
    """Run ``threshold`` on a tiny hand-made single-tag fixture.

    The fixture has four samples and one tag column: scores
    0.2, 0.6, 0.4, 0.8 with labels 0, 1, 0, 1. Returns the list produced by
    ``threshold``.

    The fixture has a single column, so ``threshold`` must accept inputs
    with fewer than ``N_TAGS`` columns for this to run.
    """
    scores = np.array([0.2, 0.6, 0.4, 0.8]).reshape(-1, 1)
    labels = np.array([0, 1, 0, 1]).reshape(-1, 1)
    return threshold(scores, labels)

In [None]:
# Compute thresholds using the file './weights-v9.hdf5'; `t` holds the value
# returned by calc_threshold (the list built by threshold()).
t = calc_threshold('./weights-v9.hdf5')

In [25]:
# Hard-coded list of 17 values (same count as N_TAGS). Presumably the per-tag
# thresholds from a calc_threshold run, pasted here for reference, in tag
# index order -- TODO confirm which weights file produced them.
[0.23067564,
 0.27402788,
 0.15499838,
 0.18645976,
 0.12418672,
 0.093219191,
 0.14909597,
 0.13256209,
 0.041971382,
 0.17731731,
 0.10376091,
 0.25468382,
 0.090709485,
 0.13336645,
 0.13344041,
 0.10004906,
 0.036582272]

[0.23067564,
 0.27402788,
 0.15499838,
 0.18645976,
 0.12418672,
 0.093219191,
 0.14909597,
 0.13256209,
 0.041971382,
 0.17731731,
 0.10376091,
 0.25468382,
 0.090709485,
 0.13336645,
 0.13344041,
 0.10004906,
 0.036582272]

Calc thresh for tag haze
Best (tp, fp, tn, fn) is (446.0, 313.0, 7261, 76)
Calc thresh for tag primary
Best (tp, fp, tn, fn) is (7464.0, 242.0, 359, 31)
Calc thresh for tag agriculture
Best (tp, fp, tn, fn) is (2347.0, 960.0, 4655, 134)
Calc thresh for tag clear
Best (tp, fp, tn, fn) is (5642.0, 421.0, 1982, 51)
Calc thresh for tag water
Best (tp, fp, tn, fn) is (1321.0, 986.0, 5604, 185)
Calc thresh for tag habitation
Best (tp, fp, tn, fn) is (664.0, 737.0, 6606, 89)
Calc thresh for tag road
Best (tp, fp, tn, fn) is (1463.0, 782.0, 5702, 149)
Calc thresh for tag cultivation
Best (tp, fp, tn, fn) is (716.0, 1222.0, 5997, 161)
Calc thresh for tag slash_burn
Best (tp, fp, tn, fn) is (21.0, 236.0, 7822, 17)
Calc thresh for tag cloudy
Best (tp, fp, tn, fn) is (393.0, 182.0, 7500, 21)
Calc thresh for tag partly_cloudy
Best (tp, fp, tn, fn) is (1427.0, 259.0, 6370, 40)
Calc thresh for tag conventional_mine
Best (tp, fp, tn, fn) is (5.0, 11.0, 8072, 8)
Calc thresh for tag bare_ground
Best (tp, fp, tn, fn) is (119.0, 347.0, 7562, 68)
Calc thresh for tag artisinal_mine
Best (tp, fp, tn, fn) is (60.0, 37.0, 7993, 6)
Calc thresh for tag blooming
Best (tp, fp, tn, fn) is (24.0, 57.0, 7981, 34)
Calc thresh for tag selective_logging
Best (tp, fp, tn, fn) is (32.0, 117.0, 7921, 26)
Calc thresh for tag blow_down
Best (tp, fp, tn, fn) is (12.0, 61.0, 8010, 13)