In [None]:
# Re-import modules whose source changed (e.g. del_qsar) before running each cell.
%load_ext autoreload
%autoreload 2

In [None]:
import os, sys
import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt

# Render all text in Arial. With family='sans-serif', matplotlib chooses fonts from
# the 'font.sans-serif' list, so Arial has to be put there (setting 'font.serif'
# had no effect). The previous list is kept as a fallback in case Arial is missing.
matplotlib.rc('font', family='sans-serif')
matplotlib.rc('font', **{'sans-serif': ['Arial'] + [
    name for name in matplotlib.rcParams['font.sans-serif'] if name != 'Arial'
]})
# Use matplotlib's built-in text rendering rather than an external LaTeX install.
matplotlib.rc('text', usetex=False)

In [None]:
PLOT_DIR = 'loss_fn_plots'
# exist_ok makes the cell safe to re-run and avoids a check-then-create race.
os.makedirs(PLOT_DIR, exist_ok=True)


def pathify(fname):
    """Return the path of ``fname`` inside the plot output directory."""
    return os.path.join(PLOT_DIR, fname)

In [None]:
# Make the del_qsar package importable without installing it: append a directory
# relative to the notebook's working directory to sys.path. The appended entry is
# getcwd() + '/../..//../', i.e. three levels above the working directory.
# NOTE(review): assumes the notebook is run from its own directory and that the
# del_qsar package sits under that location — confirm against the repo layout.
DELQSAR_ROOT = os.getcwd() + '/../../'
sys.path += [DELQSAR_ROOT + '/../']

from del_qsar import losses
from del_qsar.enrichments import R_from_z

In [None]:
# MSE loss, high counts (150 / 50)
def make_MSE_plot_low_uncertainty(labels=False, tick_labelsize=12, out=None):
    """Plot per-point MSE versus hypothesized enrichment R for counts 150 / 50.

    A target enrichment is computed with ``R_from_z`` from the fixed counts and
    compared (MSE, no reduction) with every hypothesized R on a 0-10 grid.

    Parameters
    ----------
    labels : bool
        If True, draw axis labels and use 7 pt tick labels; otherwise omit axis
        labels and use ``tick_labelsize`` for the tick labels.
    tick_labelsize : float, optional
        Tick-label size used when ``labels`` is False (default 12).
    out : str, optional
        File name inside the plot directory. Defaults to
        ``MSE_loss_plot_low_uncertainty_with_labels.png`` when ``labels`` is True
        and ``MSE_loss_plot_low_uncertainty.png`` otherwise.
    """
    n_points = 300
    x = np.linspace(0, 10, n_points)
    k1 = np.full((n_points,), 150)
    k2 = np.full((n_points,), 50)
    # NOTE(review): R_from_z receives k2 first and k1 third — presumably its
    # signature order; confirm in del_qsar.enrichments.
    # np.full broadcasts the R_from_z result into a plain NumPy array.
    target = np.full((n_points,), R_from_z(torch.Tensor(k2), 1e6, torch.Tensor(k1), 1e6, 0))
    # reduction='none' keeps one loss value per grid point so it plots as a curve.
    loss = torch.nn.MSELoss(reduction='none')(torch.Tensor(x), torch.Tensor(target))

    fig, ax = plt.subplots(figsize=(2, 2), dpi=300)
    ax.plot(x, loss, zorder=2)
    ax.tick_params(labelsize=7 if labels else tick_labelsize)
    ax.set_xticks([0, 2, 4, 6, 8, 10])
    ax.set_yticks([0, 1, 2, 3, 4, 5])
    ax.set_xlim([0, 10])
    ax.set_ylim([0, 5])
    if labels:
        ax.set_xlabel('hypothesized enrichment R', fontsize=7)
        ax.set_ylabel('MSE', fontsize=7)
    # Grid drawn beneath the loss curve (zorder 1 < 2).
    ax.grid(zorder=1)
    fig.tight_layout()

    if out is None:
        out = ('MSE_loss_plot_low_uncertainty_with_labels.png' if labels
               else 'MSE_loss_plot_low_uncertainty.png')
    fig.savefig(pathify(out))
    plt.show()
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)

In [None]:
# Labeled version (axis labels, 7 pt ticks); saved as MSE_loss_plot_low_uncertainty_with_labels.png.
make_MSE_plot_low_uncertainty(labels=True)

In [None]:
# Unlabeled version with 12 pt ticks; saved as MSE_loss_plot_low_uncertainty.png.
make_MSE_plot_low_uncertainty()

In [None]:
# MSE loss, low counts (3 / 1)
def make_MSE_plot_high_uncertainty(labels=False, tick_labelsize=12, out=None):
    """Plot per-point MSE versus hypothesized enrichment R for counts 3 / 1.

    A target enrichment is computed with ``R_from_z`` from the fixed counts and
    compared (MSE, no reduction) with every hypothesized R on a 0-10 grid.

    Parameters
    ----------
    labels : bool
        If True, draw axis labels and use 7 pt tick labels; otherwise omit axis
        labels and use ``tick_labelsize`` for the tick labels.
    tick_labelsize : float, optional
        Tick-label size used when ``labels`` is False (default 12).
    out : str, optional
        File name inside the plot directory. Defaults to
        ``MSE_loss_plot_high_uncertainty_with_labels.png`` when ``labels`` is True
        and ``MSE_loss_plot_high_uncertainty.png`` otherwise.
    """
    n_points = 300
    x = np.linspace(0, 10, n_points)
    k1 = np.full((n_points,), 3)
    k2 = np.full((n_points,), 1)
    # NOTE(review): R_from_z receives k2 first and k1 third — presumably its
    # signature order; confirm in del_qsar.enrichments.
    # np.full broadcasts the R_from_z result into a plain NumPy array.
    target = np.full((n_points,), R_from_z(torch.Tensor(k2), 1e6, torch.Tensor(k1), 1e6, 0))
    # reduction='none' keeps one loss value per grid point so it plots as a curve.
    loss = torch.nn.MSELoss(reduction='none')(torch.Tensor(x), torch.Tensor(target))

    fig, ax = plt.subplots(figsize=(2, 2), dpi=300)
    ax.plot(x, loss, zorder=2)
    ax.tick_params(labelsize=7 if labels else tick_labelsize)
    ax.set_xticks([0, 2, 4, 6, 8, 10])
    ax.set_yticks([0, 1, 2, 3, 4, 5])
    ax.set_xlim([0, 10])
    ax.set_ylim([0, 5])
    if labels:
        ax.set_xlabel('hypothesized enrichment R', fontsize=7)
        ax.set_ylabel('MSE', fontsize=7)
    # Grid drawn beneath the loss curve (zorder 1 < 2).
    ax.grid(zorder=1)
    fig.tight_layout()

    if out is None:
        out = ('MSE_loss_plot_high_uncertainty_with_labels.png' if labels
               else 'MSE_loss_plot_high_uncertainty.png')
    fig.savefig(pathify(out))
    plt.show()
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)

In [None]:
# Labeled version (axis labels, 7 pt ticks); saved as MSE_loss_plot_high_uncertainty_with_labels.png.
make_MSE_plot_high_uncertainty(labels=True)

In [None]:
# Unlabeled version with 12 pt ticks; saved as MSE_loss_plot_high_uncertainty.png.
make_MSE_plot_high_uncertainty()

In [None]:
# NLL loss, high counts (150 / 50)
def make_Poisson_plot_low_uncertainty(out, labels=False, tick_labelsize=None):
    """Plot the Poisson-enrichment NLL versus hypothesized enrichment R for counts 150 / 50.

    Parameters
    ----------
    out : str
        File name, resolved inside the plot directory via ``pathify``.
    labels : bool
        If True, draw axis labels and force 7 pt tick labels.
    tick_labelsize : float or None
        Tick-label size used when ``labels`` is False (None leaves it to
        matplotlib's default).
    """
    fig, ax = plt.subplots(figsize=(2, 2), dpi=300)

    grid = np.linspace(0, 10, 300)
    counts_a = torch.Tensor(np.full((300,), 150))
    counts_b = torch.Tensor(np.full((300,), 50))
    # The two trailing 1e-6 arguments go straight to loss_fn_nlogprob; their
    # meaning is defined in del_qsar.losses (not visible here).
    nll = losses.loss_fn_nlogprob(torch.Tensor(grid), counts_a, counts_b, 1e-6, 1e-6)
    ax.plot(grid, nll, zorder=2)
    fig.canvas.draw()

    ax.tick_params(labelsize=7 if labels else tick_labelsize)
    ax.set_xticks([0, 2, 4, 6, 8, 10])
    ax.set_yticks([0, 1, 2, 3, 4, 5])
    ax.set_xlim([0, 10])
    ax.set_ylim([0, 5])
    if labels:
        ax.set_xlabel('hypothesized enrichment R', fontsize=7)
        ax.set_ylabel('NLL', fontsize=7)
    ax.grid(zorder=1)

    fig.tight_layout()
    fig.savefig(pathify(out))
    plt.show()

In [None]:
# Labeled version (axis labels, 7 pt ticks).
make_Poisson_plot_low_uncertainty('Poisson_enrichment_loss_plot_low_uncertainty_with_labels.png', labels=True)

In [None]:
# Unlabeled version with 12 pt tick labels.
make_Poisson_plot_low_uncertainty('Poisson_enrichment_loss_plot_low_uncertainty_labelsize_12.png', tick_labelsize=12)

In [None]:
# Unlabeled version with 9 pt tick labels.
make_Poisson_plot_low_uncertainty('Poisson_enrichment_loss_plot_low_uncertainty_labelsize_9.png', tick_labelsize=9)

In [None]:
# NLL loss, low counts (3 / 1)
def make_Poisson_plot_high_uncertainty(out, labels=False, tick_labelsize=None):
    """Plot the Poisson-enrichment NLL versus hypothesized enrichment R for counts 3 / 1.

    The figure is saved exactly once, to ``pathify(out)``.

    Parameters
    ----------
    out : str
        File name, resolved inside the plot directory via ``pathify``.
    labels : bool
        If True, draw axis labels and force 7 pt tick labels.
    tick_labelsize : float or None
        Tick-label size used when ``labels`` is False (None leaves it to
        matplotlib's default).
    """
    n_points = 300
    preds = np.linspace(0, 10, n_points)
    k1 = np.full((n_points,), 3)
    k2 = np.full((n_points,), 1)
    # The two trailing 1e-6 arguments go straight to loss_fn_nlogprob; their
    # meaning is defined in del_qsar.losses (not visible here).
    nll = losses.loss_fn_nlogprob(torch.Tensor(preds), torch.Tensor(k1), torch.Tensor(k2), 1e-6, 1e-6)

    fig, ax = plt.subplots(figsize=(2, 2), dpi=300)
    ax.plot(preds, nll, zorder=2)
    ax.tick_params(labelsize=7 if labels else tick_labelsize)
    ax.set_xticks([0, 2, 4, 6, 8, 10])
    ax.set_yticks([0, 1, 2, 3, 4, 5])
    ax.set_xlim([0, 10])
    ax.set_ylim([0, 5])
    if labels:
        ax.set_xlabel('hypothesized enrichment R', fontsize=7)
        ax.set_ylabel('NLL', fontsize=7)
    # Grid drawn beneath the loss curve (zorder 1 < 2).
    ax.grid(zorder=1)
    fig.tight_layout()
    # Save only to the caller-chosen name; a fixed fallback name here would be
    # overwritten by every unlabeled call regardless of tick size.
    fig.savefig(pathify(out))
    plt.show()
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)

In [None]:
# Labeled version (axis labels, 7 pt ticks).
make_Poisson_plot_high_uncertainty('Poisson_enrichment_loss_plot_high_uncertainty_with_labels.png', labels=True)

In [None]:
# Unlabeled version with 12 pt tick labels.
make_Poisson_plot_high_uncertainty('Poisson_enrichment_loss_plot_high_uncertainty_labelsize_12.png', tick_labelsize=12)

In [None]:
# Unlabeled version with 9 pt tick labels.
make_Poisson_plot_high_uncertainty('Poisson_enrichment_loss_plot_high_uncertainty_labelsize_9.png', tick_labelsize=9)

In [None]:
# NLL loss, medium counts (15 / 5)
def make_Poisson_plot_med_uncertainty(labels=False, tick_labelsize=9, out=None):
    """Plot the Poisson-enrichment NLL versus hypothesized enrichment R for counts 15 / 5.

    Parameters
    ----------
    labels : bool
        If True, draw axis labels and use 7 pt tick labels; otherwise omit axis
        labels and use ``tick_labelsize`` for the tick labels.
    tick_labelsize : float, optional
        Tick-label size used when ``labels`` is False (default 9).
    out : str, optional
        File name inside the plot directory. Defaults to
        ``Poisson_enrichment_loss_plot_medium_uncertainty_with_labels.png`` when
        ``labels`` is True and ``Poisson_enrichment_loss_plot_medium_uncertainty.png``
        otherwise.
    """
    n_points = 300
    preds = np.linspace(0, 10, n_points)
    k1 = np.full((n_points,), 15)
    k2 = np.full((n_points,), 5)
    # The two trailing 1e-6 arguments go straight to loss_fn_nlogprob; their
    # meaning is defined in del_qsar.losses (not visible here).
    nll = losses.loss_fn_nlogprob(torch.Tensor(preds), torch.Tensor(k1), torch.Tensor(k2), 1e-6, 1e-6)

    fig, ax = plt.subplots(figsize=(2, 2), dpi=300)
    ax.plot(preds, nll, zorder=2)
    ax.tick_params(labelsize=7 if labels else tick_labelsize)
    ax.set_xticks([0, 2, 4, 6, 8, 10])
    ax.set_yticks([0, 1, 2, 3, 4, 5])
    ax.set_xlim([0, 10])
    ax.set_ylim([0, 5])
    if labels:
        ax.set_xlabel('hypothesized enrichment R', fontsize=7)
        ax.set_ylabel('NLL', fontsize=7)
    # Grid drawn beneath the loss curve (zorder 1 < 2).
    ax.grid(zorder=1)
    fig.tight_layout()

    if out is None:
        out = ('Poisson_enrichment_loss_plot_medium_uncertainty_with_labels.png' if labels
               else 'Poisson_enrichment_loss_plot_medium_uncertainty.png')
    fig.savefig(pathify(out))
    plt.show()
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)

In [None]:
# Labeled version (axis labels, 7 pt ticks); saved as Poisson_enrichment_loss_plot_medium_uncertainty_with_labels.png.
make_Poisson_plot_med_uncertainty(labels=True)

In [None]:
# Unlabeled version with 9 pt ticks; saved as Poisson_enrichment_loss_plot_medium_uncertainty.png.
make_Poisson_plot_med_uncertainty()