In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.append("..")
from sleeprnn.nn.losses import get_border_weights
import tensorflow as tf

In [None]:
# Border-weight parameters: target weight at a 0/1 transition and the
# half-width (in samples) of the Gaussian that spreads it.
weight_border = 10
kernel_size_half = 12

# Synthetic example: binary labels with three annotated segments.
n_samples = 500
annots = [[80, 106], [130, 172], [184, 200]]


def gaussian_border_kernel(weight_border, kernel_size_half):
    """Build the Gaussian kernel used to spread border weights.

    Args:
        weight_border: Target weight at a border. The kernel peak is
            ``weight_border - 1`` so that ``1 + kernel`` peaks at
            ``weight_border``.
        kernel_size_half: Half-width of the kernel in samples; the kernel has
            ``2 * kernel_size_half + 1`` taps.

    Returns:
        Tuple ``(kernel_steps, kernel_gauss)``: the integer offsets from the
        kernel center and the kernel values at those offsets.
    """
    kernel_steps = np.arange(2 * kernel_size_half + 1) - kernel_size_half
    # The kernel's full width spans 6 standard deviations (+-3 std).
    std_kernel = (2 * kernel_size_half + 1) / 6
    kernel_gauss = (weight_border - 1) * np.exp(
        -(kernel_steps ** 2) / (2 * (std_kernel ** 2)))
    return kernel_steps, kernel_gauss


def compute_border_weights(label, weight_border, kernel_size_half):
    """NumPy reference for per-sample border weights of a binary label sequence.

    Transitions between 0 and 1 are detected with a central difference. They
    are then spread with a Gaussian kernel and offset by 1, so samples far
    from any border get weight 1.

    Args:
        label: 1-D sequence of 0/1 values.
        weight_border: Target weight at a border (see ``gaussian_border_kernel``).
        kernel_size_half: Half-width of the Gaussian kernel in samples.

    Returns:
        Float array with the same length as ``label``.
    """
    label = np.asarray(label, dtype=np.float64)
    n = label.shape[0]
    if n == 0:
        # np.convolve rejects empty inputs; there is nothing to weight.
        return np.ones(0)
    # Repeat the end values so the sequence ends are not detected as borders.
    label_extended = np.concatenate([label[:1], label, label[-1:]])
    # Central difference: |edges| is 0.5 at the samples on either side of an
    # isolated 0/1 transition and 0 in flat regions. "valid" on the extended
    # sequence gives back exactly ``n`` samples.
    kernel_edge = np.array([-0.5, 0, 0.5])
    edges = np.abs(np.convolve(label_extended, kernel_edge, mode="valid"))
    _, kernel_gauss = gaussian_border_kernel(weight_border, kernel_size_half)
    # Take a centered slice of the full convolution instead of mode="same".
    # mode="same" returns max(len(a), len(v)) samples, which is longer than
    # ``label`` whenever ``label`` is shorter than the kernel.
    smoothed = np.convolve(edges, kernel_gauss, mode="full")
    smoothed = smoothed[kernel_size_half:kernel_size_half + n]
    return 1 + smoothed


# Labels (annotation end points are inclusive)
label = np.zeros(n_samples)
for start, end in annots:
    label[start:end + 1] = 1

# Kernel kept at top level for plotting; weights from the reference function.
kernel_steps, kernel_gauss = gaussian_border_kernel(weight_border, kernel_size_half)
output = compute_border_weights(label, weight_border, kernel_size_half)

# Visual check: labels, resulting border weights, and the smoothing kernel.
fig, ax = plt.subplots(3, 1, figsize=(15, 6), dpi=200)
ax[0].plot(label, '-o', markersize=3)
ax[0].set_xlim([-0.5, n_samples - 1 + 0.5])
ax[0].set_ylim([-0.1, 1.1])
ax[0].set_title("Binary labels")
ax[1].plot(output, '-o', markersize=3)
ax[1].set_xlim([-0.5, n_samples - 1 + 0.5])
# No fixed upper y-limit: contributions from nearby borders add up, so the
# weights are not bounded by weight_border in general.
ax[1].set_title(f"Border weights (weight_border={weight_border})")
ax[1].set_xlabel("Sample")
ax[2].plot(kernel_steps, kernel_gauss, '-o')
ax[2].set_title(f"Gaussian kernel (kernel_size_half={kernel_size_half})")
ax[2].set_xlabel("Offset (samples)")
plt.tight_layout()
plt.show()

In [None]:
# Replicate with TF and compare against the NumPy reference (`output`).
tf.reset_default_graph()
labels_ph = tf.placeholder(shape=[None, n_samples], dtype=tf.int32)
border_weights_tf = get_border_weights(labels_ph, weight_border, kernel_size_half)

labels_prepared = label.reshape(1, -1).astype(np.int32)
# The context manager closes the session once the result is fetched.
with tf.Session() as sess:
    # An initializer op only takes effect when it is run.
    sess.run(tf.global_variables_initializer())
    border_weights_np = sess.run(border_weights_tf, feed_dict={labels_ph: labels_prepared})
border_weights_np = border_weights_np[0, :]  # drop the batch axis
print(border_weights_np.dtype)
print("Max |TF - NumPy|:", np.max(np.abs(border_weights_np - output)))

fig, ax = plt.subplots(2, 1, figsize=(15, 4))
ax[0].plot(label, '-o', markersize=3)
ax[0].set_xlim([-0.5, n_samples - 1 + 0.5])
ax[0].set_ylim([-0.1, 1.1])
ax[0].set_title("Binary labels")
ax[1].plot(border_weights_np, '-o', markersize=3, label="TF")
ax[1].plot(output, '--', label="NumPy reference")
ax[1].set_xlim([-0.5, n_samples - 1 + 0.5])
# Upper limit follows the data so overlapping borders are not clipped.
ax[1].set_ylim([-0.1, max(weight_border, border_weights_np.max()) + 0.1])
ax[1].set_title("Border weights: TF vs NumPy")
ax[1].set_xlabel("Sample")
ax[1].legend()
plt.tight_layout()
plt.show()