In [None]:
import sys
sys.path.append("..")

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

from sleeprnn.nn.losses import get_border_weights
from sleeprnn.nn.losses import weighted_cross_entropy_loss_v5
from sleeprnn.common import viz

## Synthetic Labels and Predictions

In [None]:
# Synthetic example, expressed at the model's output sampling rate.
fs = 25  # in output
n_samples = 250
annots = [[60, 86], [110, 152], [164, 180]]  # [start, end] sample indices, end inclusive
preds = [[64, 90], [120, 200], [230, 240]]   # same [start, end] convention as annots

# Extra samples added on each side before smoothing, so the implicit zero
# boundary of the "same"-mode convolution does not distort the kept signal.
pad_width = 25

time_axis = np.arange(n_samples) / fs

# Labels: 1 inside each annotated interval (end inclusive), 0 elsewhere.
labels = np.zeros(n_samples)
for start, end in annots:
    labels[start:end + 1] = 1

# Predictions: boxcar over the predicted intervals, clipped to [0.1, 0.9] so the
# log-odds below stay finite, then smoothed with a normalized Hann window.
probabilities = np.zeros(n_samples + 2 * pad_width)
for start, end in preds:
    probabilities[start + pad_width:end + pad_width + 1] = 1
probabilities = np.clip(probabilities, 0.1, 0.9)
smooth_kernel = np.hanning(8)
smooth_kernel /= smooth_kernel.sum()
probabilities = np.convolve(probabilities, smooth_kernel, mode="same")
# Slice by length rather than `[pad:-pad]` so the trim also works for pad_width == 0.
probabilities = probabilities[pad_width:pad_width + n_samples]

# Two-class logits whose softmax over the last axis gives back `probabilities`
# for class 1: class-1 logit = log-odds, class-0 logit = 0 (softmax depends only
# on the difference between the two logits).
# NOTE(review): assumes weighted_cross_entropy_loss_v5 applies a softmax over the
# last axis of its logits — confirm in sleeprnn.nn.losses.
log_odds = np.log(probabilities) - np.log(1 - probabilities)
logits = np.stack([np.zeros_like(log_odds), log_odds], axis=-1)

# Plot the synthetic labels, the smoothed probabilities and the class-1 logit
# on a shared time axis.
line_args = dict(marker='o', markersize=2, linewidth=0.6)
input_panels = [
    ("Labels", labels),
    ("Probabilities", probabilities),
    ("Logits", logits[..., 1]),
]
fig, ax = plt.subplots(len(input_panels), 1, figsize=(8, 4), dpi=100, sharex=True)
for s_ax, (title, series) in zip(ax, input_panels):
    s_ax.plot(time_axis, series, **line_args)
    s_ax.set_title(title)
ax[-1].set_xlabel("Time [s]")
plt.tight_layout()
plt.show()

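## Loss Weights

Per-sample weights returned by `weighted_cross_entropy_loss_v5` for the synthetic example: class weight, error (focal) weight, border weight, and their product.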

In [None]:
# Loss hyper-parameters for this visualisation.
class_weights = [1.0, 0.25]  # [class 0 (negative), class 1 (positive)]; negative weight kept at one
focal_gamma = 3  # shown as gamma in the "Weight of Error" panel
focal_eps = 0.5  # [0, 1]
anti_border_amplitude = 1.0  # [0, 1]
anti_border_halft_width = 6  # n samples

# Build the loss graph on the synthetic batch (batch size 1) and fetch the
# weight tensors it returns with `return_weights=True`.
tf.reset_default_graph()
loss, loss_summ, weights_dict = weighted_cross_entropy_loss_v5(
    logits.reshape(1, -1, 2).astype(np.float32), labels.reshape(1, -1).astype(np.int32),
    class_weights,
    focal_gamma, focal_eps,
    anti_border_amplitude, anti_border_halft_width,
    return_weights=True)
# Context manager closes the session (releasing its resources) even if run()
# raises. The initializer is actually executed in case the loss graph creates
# variables; merely constructing the op does not initialize anything.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    my_weights = sess.run(weights_dict)

# Plot each weight factor under the synthetic labels/probabilities, then save
# the figure with the hyper-parameters encoded in the file name.
title_fontsize = 9
line_args = dict(linewidth=1.2)
class_ratio = class_weights[1] / class_weights[0]  # w_1 : w_0
w_total = my_weights["w_total"][0]

# One (title, series, extra plot kwargs) entry per figure row. Raw strings keep
# the LaTeX backslashes (\gamma, \epsilon, \cdot) from being read as invalid
# string escape sequences; the resulting text is unchanged.
weight_panels = [
    ("Labels", labels, {}),
    ("Probabilities", probabilities, {}),
    ("Weight of Class, $w_1:w_0$ = %1.2f" % class_ratio,
     my_weights["w_class"][0], {}),
    (r"Weight of Error, $\gamma$ = %1.1f, $\epsilon$ = %1.1f" % (focal_gamma, focal_eps),
     my_weights["w_focal"][0], {}),
    ("Weight of Border, $a$ = %1.1f, $L$ = %d" % (anti_border_amplitude, anti_border_halft_width),
     my_weights["w_border"][0], {}),
    # Total weight is rescaled so its maximum is 1 and it fits the shared y-range.
    (r"Total Weight $w_{class}\cdot w_{error}\cdot w_{border}$",
     w_total / w_total.max(), dict(color=viz.PALETTE['red'])),
]

fig, ax = plt.subplots(len(weight_panels), 1, figsize=(6, 5), dpi=200, sharex=True)
for s_ax, (title, series, extra_args) in zip(ax, weight_panels):
    s_ax.set_title(title, loc="left", fontsize=title_fontsize)
    s_ax.plot(time_axis, series, **line_args, **extra_args)
    s_ax.set_ylim([-0.2, 1.2])
    s_ax.tick_params(labelsize=8)
    s_ax.set_xlim([0, n_samples / fs])
ax[-1].set_xlabel("Time [s]", fontsize=8)
plt.tight_layout()
plt.savefig(
    "weights_class%1.2f_g%1.1f_eps%1.1f_a%1.1f_hw%d.png" % (
        class_ratio,
        focal_gamma, focal_eps,
        anti_border_amplitude, anti_border_halft_width
    )
)
plt.show()