In [None]:
import numpy as np 
import pandas as pd
import cv2

import os
from os.path import join

from src import config
from src.metrics import compute_eval_metric

from src.transforms import CenterCrop

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Root of the out-of-fold predictions; each fold has
# fold_<k>/val/<id>.png (predicted masks) and fold_<k>/val/probs.csv.
PREDICTION_DIR = '/workdir/data/predictions/mos-fpn-lovasz-se-resnext50-001'
# CSV with one row per training sample: its `id`, `fold` and `mask_path`.
TRAIN_FOLDS_PATH = '/workdir/data/train_folds_148_mos_emb_1.csv'

# Threshold applied to the predicted mask after scaling pixels to [0, 1].
SEGM_THRESH = 0.5
# Per-image threshold on the `prob` column of probs.csv; images whose
# probability does not exceed it get an all-zero predicted mask.
PROB_THRESH = 0.6

In [None]:
# Score every out-of-fold validation prediction against its ground-truth mask.

# Rows from this fold are not scored (presumably a hold-out fold with no
# validation predictions under PREDICTION_DIR -- TODO confirm).
SKIP_FOLD = 5
# Set to a float (e.g. 0.2) to plot every sample scoring below it; None disables.
DEBUG_SCORE_THRESH = None

folds_df = pd.read_csv(TRAIN_FOLDS_PATH)
score_lst = []          # (id, score) pairs, summarised in the next cell
probs_df_dict = dict()  # fold -> probs.csv DataFrame indexed by id, loaded lazily

crop = CenterCrop((101, 101))


def read_grayscale(path):
    """Read an image as single-channel grayscale, failing loudly if unreadable."""
    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        # cv2.imread reports a missing/unreadable file by returning None,
        # which would otherwise surface later as an opaque TypeError.
        raise FileNotFoundError(f"Could not read image: {path}")
    return image


def fold_probs(fold):
    """Return the per-image probability table for `fold`, caching it in probs_df_dict."""
    if fold not in probs_df_dict:
        probs_path = join(PREDICTION_DIR, f'fold_{fold}', 'val', 'probs.csv')
        probs_df_dict[fold] = pd.read_csv(probs_path, index_col='id')
    return probs_df_dict[fold]


def show_sample(row, score, true_mask, pred_mask, prob_mask):
    """Plot the input image, ground truth, thresholded and raw prediction side by side."""
    print(score, row.id, fold_probs(row.fold).loc[row.id, 'prob'])
    image = read_grayscale(join(config.TRAIN_DIR, 'images', row.id + '.png'))
    panels = [(image, 'image'), (true_mask, 'true'),
              (pred_mask, 'pred'), (prob_mask, 'prob mask')]
    fig, axarr = plt.subplots(1, len(panels), figsize=(12, 4))
    for ax, (img, title) in zip(axarr, panels):
        ax.imshow(img)
        ax.set_title(title)
    plt.show()


for row in folds_df.itertuples(index=False):
    if row.fold == SKIP_FOLD:
        continue

    # Any non-zero pixel in the ground-truth PNG counts as foreground.
    true_mask = read_grayscale(row.mask_path).astype(bool).astype(np.uint8)

    # Image-level gate: when the per-image probability does not exceed
    # PROB_THRESH, the whole predicted mask is zeroed out.
    keep_mask = fold_probs(row.fold).loc[row.id, 'prob'] > PROB_THRESH

    pred_mask_path = join(PREDICTION_DIR, f'fold_{row.fold}', 'val', row.id + '.png')
    prob_mask = read_grayscale(pred_mask_path)
    pred_mask = ((prob_mask / 255.0 > SEGM_THRESH) & keep_mask).astype(np.uint8)

    # Only the ground truth is centre-cropped to 101x101; the saved prediction
    # is assumed to already be that size -- TODO confirm.
    score = compute_eval_metric(crop(true_mask), pred_mask)
    score_lst.append((row.id, score))

    if DEBUG_SCORE_THRESH is not None and score < DEBUG_SCORE_THRESH:
        show_sample(row, score, true_mask, pred_mask, prob_mask)


In [None]:
# Summarise the per-image scores collected above: mean metric and its distribution.
scores = np.asarray([score for _, score in score_lst], dtype=float)
if scores.size == 0:
    # np.mean([]) would silently give nan; make the empty case explicit.
    raise ValueError('No validation scores were collected; check PREDICTION_DIR and the fold filter.')
print(f'Mean eval metric over {scores.size} images: {scores.mean():.4f}')

fig, ax = plt.subplots()
ax.hist(scores, bins=20)
ax.set(title='Per-image eval metric (out-of-fold)', xlabel='score', ylabel='number of images');