In [None]:
def _median_from_cmf(cmf, centers):
    # Median of one PMF: linear interpolation of its CMF (sampled at the bin
    # centers) around the first bin whose cumulative mass reaches 0.5.
    # Returns NaN when the CMF never reaches 0.5 (e.g. an all-zero PMF).
    reached = np.flatnonzero(cmf >= 0.5)
    if reached.size == 0:
        return np.nan
    idx = int(reached[0])
    # Clamp the neighbours to valid bins: idx - 1 == -1 would otherwise wrap to
    # the last bin (giving np.interp a non-increasing xp and a wrong median),
    # and idx + 1 == len(cmf) would raise IndexError.
    neighbours = [max(idx - 1, 0), idx, min(idx + 1, cmf.size - 1)]
    return np.interp(0.5, cmf[neighbours], centers[neighbours])


def f_extract_pmf_statistics(x_target_grid, y_target_grid, pmf_pred_nn, bin_centers_edges_z, *args):
    # function to extract PMF entropy, mean, median, mode and, optionally, the
    # probability of exceeding a threshold, for every predicted target PMF
    # -------------- Input --------------
    # - x_target_grid       x coordinates of the target set (not used in the
    #                       computation; kept for interface compatibility)
    # - y_target_grid       y coordinates of the target set; only its .shape is
    #                       used, to reshape every output (must hold T elements)
    # - pmf_pred_nn         {1,T}  predicted z PMF for each target, aligned
    #                       element-wise with bin_centers_edges_z
    # - bin_centers_edges_z [1,n]  bin centers of the z PMF, assumed ascending
    # - args[0] (optional)  h      threshold of z; omitted, None or NaN means
    #                       no exceedance probability is computed
    # -------------- Output --------------
    # - z_target_entropy_pred_plot     entropy (via f_entropy) of each PMF
    # - z_target_mean_pred_GRID_plot   mean of each PMF (normalized by its mass)
    # - z_target_median_pred_GRID_plot median of each PMF, linearly interpolated
    #                                  on the CMF; NaN if the CMF never reaches 0.5
    # - z_target_mode_pred_GRID_plot   bin center with the largest mass
    #   (all four have the shape of y_target_grid)
    # - varargout  list: [exceedance grid] when a threshold was given, else [].
    #              Exceedance = 1 - CMF at the first bin center >= h, i.e. the
    #              mass of the bins strictly above that center (0 when h is
    #              above every center).
    # -------------- Version --------------
    # - 2020/03/20 Stephanie Thiesen: initial version
    # -------------- Script --------------
    grid_shape = y_target_grid.shape
    n_targets = len(pmf_pred_nn)
    centers = np.asarray(bin_centers_edges_z).ravel()

    # Optional threshold; None is treated like "not given" instead of making
    # np.isnan raise TypeError.
    thres = args[0] if len(args) >= 1 else None
    has_thres = thres is not None and not np.isnan(thres)
    if has_thres:
        # Index of the first bin center >= thres (centers ascending). This is
        # what the step-up test (`center < thres -> next bin`) was meant to
        # produce; the previous code started one bin too far, so a threshold
        # lying exactly on a center skipped the next bin's mass, and
        # thres >= last center raised IndexError. Loop-invariant, so hoisted.
        idx_thres = int(np.sum(centers < thres))

    z_target_entropy_pred = np.empty(n_targets)
    z_target_mean_pred = np.empty(n_targets)
    z_target_median_pred = np.empty(n_targets)
    z_target_mode_pred = np.empty(n_targets)
    z_target_probability_pred = np.empty(n_targets)

    for target_ in range(n_targets):
        # f_entropy gets its own fresh copy, exactly as before, so anything it
        # does to its argument cannot leak into the statistics below.
        z_target_entropy_pred[target_] = f_entropy(np.array(pmf_pred_nn[target_]))

        pmf = np.array(pmf_pred_nn[target_]).ravel()
        cmf = np.cumsum(pmf)

        z_target_mean_pred[target_] = np.sum(centers * pmf) / np.sum(pmf)
        z_target_median_pred[target_] = _median_from_cmf(cmf, centers)
        z_target_mode_pred[target_] = centers[np.argmax(pmf)]

        if has_thres:
            if idx_thres >= cmf.size:
                # No bin center lies at or above the threshold.
                z_target_probability_pred[target_] = 0.0
            else:
                z_target_probability_pred[target_] = 1 - cmf[idx_thres]

    z_target_entropy_pred_plot = z_target_entropy_pred.reshape(grid_shape)
    z_target_mean_pred_GRID_plot = z_target_mean_pred.reshape(grid_shape)
    z_target_median_pred_GRID_plot = z_target_median_pred.reshape(grid_shape)
    z_target_mode_pred_GRID_plot = z_target_mode_pred.reshape(grid_shape)

    if has_thres:
        varargout = [z_target_probability_pred.reshape(grid_shape)]
    else:
        varargout = []

    return z_target_entropy_pred_plot, z_target_mean_pred_GRID_plot, z_target_median_pred_GRID_plot, z_target_mode_pred_GRID_plot, varargout