In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap, hessian, jvp
from jax import random
import numpy as np

In [2]:
def pu_sl(y_observed: np.ndarray,
          y_pred: np.ndarray,
          prior: float):
    """Squared-loss PU objective: summed loss plus per-sample gradient and Hessian.

    With N samples, NP of which are observed positives (``y_observed == 1``),
    the per-sample loss is::

        (y_pred + 1)**2 / (4 * N)  -  prior / NP * y_pred * [y_observed == 1]

    Args:
        y_observed: 1-D array-like of observed labels. Entries equal to 1 are
            the labelled positives; any other value (0, -1, ...) is treated as
            unlabelled.
        y_pred: 1-D array-like of real-valued scores, same length as
            ``y_observed``.
        prior: Weight of the labelled-positive term (presumably the
            positive-class prior -- confirm with the caller).

    Returns:
        Tuple ``(loss, grad, hess)``: the loss summed over all samples, and the
        per-sample first and second derivatives of the loss w.r.t. ``y_pred``.

    Raises:
        ValueError: If ``y_observed`` contains no entry equal to 1, since the
            positive term would otherwise divide by zero.
    """
    y_observed = np.asarray(y_observed)
    y_pred = np.asarray(y_pred)

    n_samples = y_observed.shape[0]
    # Weight by the same predicate that defines NP, so the result does not
    # depend on how unlabelled samples happen to be encoded.
    is_positive = y_observed == 1
    n_positive = is_positive.sum()
    if n_positive == 0:
        raise ValueError(
            "y_observed must contain at least one positive (== 1) sample")
    positive_weight = prior / n_positive

    # loss
    loss = (1 / (4 * n_samples) * (y_pred + 1) ** 2
            - positive_weight * y_pred * is_positive)
    # grad
    grad = 1 / (2 * n_samples) * (y_pred + 1) - positive_weight * is_positive
    # hess: the loss is quadratic in y_pred, so the second derivative is constant
    hess = np.full_like(grad, 1 / (2 * n_samples))
    return loss.sum(), grad, hess

In [3]:
# Synthetic benchmark data: the label pattern [0, 1, 1, 0] repeated N_REPEATS
# times, with scores drawn uniformly from [-1, 1).
SEED = 0            # fixed seed so the outputs below are reproducible on re-run
N_REPEATS = 10_000
rng = np.random.default_rng(SEED)

y_obs = np.tile([0, 1, 1, 0], N_REPEATS)
y_pred = rng.random(y_obs.shape[0]) * 2 - 1
prior = 0.5

In [4]:
# Run the NumPy implementation on the synthetic data; the result is
# (summed loss, per-sample grad, per-sample hess).
pu_sl(y_obs, y_pred, prior)

(0.33327379671714524,
 array([ 6.13474264e-06, -2.37834287e-05, -2.88265640e-06, ...,
        -1.03494747e-05, -3.81883959e-06,  6.92244749e-06]),
 array([1.25e-05, 1.25e-05, 1.25e-05, ..., 1.25e-05, 1.25e-05, 1.25e-05]))

In [5]:
# Benchmark the NumPy implementation.
%timeit pu_sl(y_obs, y_pred, prior)

549 µs ± 59.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [6]:
def sl(y: int, g: float):
    """Squared loss ``1/4 * (y*g - 1)**2`` for a label ``y`` and a score ``g``."""
    residual = y * g - 1
    return 1 / 4 * (residual ** 2)


# First and second derivative of `sl` with respect to the score `g` (argument 1).
grad_sl = grad(sl, argnums=1)
hess_sl = grad(grad_sl, argnums=1)

# Vectorised versions that map over the leading axis of both arguments.
sl_map, grad_sl_map, hess_sl_map = (vmap(fn) for fn in (sl, grad_sl, hess_sl))

In [7]:
@jit
def pu_sl(y_observed: np.ndarray,  # 0/1 only
          y_pred: np.ndarray,
          prior: float):
    """Non-negative squared-loss PU objective evaluated with JAX autodiff.

    The loss is the positive-class risk plus the negative-class risk, where the
    negative-class risk (and its derivatives) only contributes when it is
    strictly positive.

    Args:
        y_observed: Observed labels, 0 or 1; cast to int8.
        y_pred: Scores; cast to float32.
        prior: Weight applied to the labelled-positive terms.

    Returns:
        Tuple ``(loss, grad, hess)``: scalar loss and the per-sample first and
        second derivatives of the loss terms w.r.t. ``y_pred``.
    """
    labels = jnp.asarray(y_observed, dtype=jnp.int8)
    scores = jnp.asarray(y_pred, dtype=jnp.float32)

    n_total = labels.shape[0]
    n_pos = (labels == 1).sum()
    pos_weight = prior / n_pos

    as_positive = jnp.ones_like(scores)
    as_negative = -jnp.ones_like(scores)

    # Risk of the labelled positives evaluated with target y* = +1.
    risk_pos = pos_weight * jnp.sum(sl_map(as_positive, scores) * labels)

    # Risk for target y* = -1: mean over all samples minus the weighted
    # contribution of the labelled positives.
    neg_losses = sl_map(as_negative, scores)
    risk_neg = jnp.mean(neg_losses) - pos_weight * jnp.sum(neg_losses * labels)

    # jax.jit cannot branch on a traced value with a Python `if`, so the
    # "only when risk_neg > 0" gate is applied as a 0/1 multiplier instead.
    keep_neg = risk_neg > 0

    neg_grad = grad_sl_map(as_negative, scores)
    neg_hess = hess_sl_map(as_negative, scores)

    total = risk_pos + risk_neg * keep_neg

    d1 = pos_weight * grad_sl_map(as_positive, scores) * labels
    d1 = d1 + (1 / n_total * neg_grad - pos_weight * neg_grad * labels) * keep_neg

    d2 = pos_weight * hess_sl_map(as_positive, scores) * labels
    d2 = d2 + (1 / n_total * neg_hess - pos_weight * neg_hess * labels) * keep_neg
    return total, d1, d2

In [8]:
# Run the jitted JAX implementation on the same inputs, for comparison with
# the NumPy output above.
pu_sl(y_obs, y_pred, prior)



(DeviceArray(0.3332738, dtype=float32),
 DeviceArray([ 6.1347423e-06, -2.3783428e-05, -2.8826566e-06, ...,
              -1.0349475e-05, -3.8188400e-06,  6.9224475e-06],            dtype=float32),
 DeviceArray([1.25e-05, 1.25e-05, 1.25e-05, ..., 1.25e-05, 1.25e-05,
              1.25e-05], dtype=float32))

In [9]:
# Benchmark the jitted JAX implementation.
%timeit pu_sl(y_obs, y_pred, prior)

518 µs ± 88.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


PyTorch 実装が一番キレイだけど、3〜4倍遅い。

In [10]:
from torch import tensor, Tensor
import torch

In [11]:
def pu_sl(y_observed: np.ndarray,
          y_pred: np.ndarray,
          prior: float):
    """Squared-loss PU objective with derivatives obtained from PyTorch autograd.

    With N samples and NP = ``y_observed.sum()`` labelled positives, the
    per-sample loss is::

        (y_pred + 1)**2 / (4 * N)  -  prior / NP * y_pred * y_observed

    Args:
        y_observed: 1-D array-like of 0/1 labels (1 = labelled positive).
        y_pred: 1-D array-like of scores, same length as ``y_observed``.
            Non-floating input is converted to float64.
        prior: Weight of the labelled-positive term (presumably the
            positive-class prior -- confirm with the caller).

    Returns:
        Tuple of NumPy arrays ``(loss, grad, hess)``: the 0-d summed loss and
        the per-sample first and second derivatives w.r.t. ``y_pred``.

    Raises:
        ValueError: If ``y_observed`` sums to zero (no labelled positives).
    """
    y_observed = np.asarray(y_observed)
    y_pred = np.asarray(y_pred)
    if not np.issubdtype(y_pred.dtype, np.floating):
        # Only floating-point tensors can require gradients.
        y_pred = y_pred.astype(np.float64)

    n_samples = y_observed.shape[0]
    n_positive = y_observed.sum()
    if n_positive == 0:
        raise ValueError(
            "y_observed must contain at least one positive (== 1) sample")
    # Plain Python float so the arithmetic below is tensor-with-scalar only.
    positive_weight = float(prior / n_positive)

    labels = tensor(y_observed)
    scores = tensor(y_pred, requires_grad=True)

    # loss
    loss = (1 / (4 * n_samples) * (scores + 1) ** 2
            - positive_weight * scores * labels).sum()

    # First derivative. autograd.grad returns the gradient directly instead of
    # accumulating into scores.grad, which avoids the parameter<->grad
    # reference cycle created by backward(create_graph=True).
    (grad,) = torch.autograd.grad(loss, scores, create_graph=True)

    # Second derivative. Differentiating grad.sum() yields the Hessian row
    # sums, which equal its diagonal here because the loss is a sum of
    # independent per-sample terms.
    (hess,) = torch.autograd.grad(grad.sum(), scores)
    return loss.detach().numpy(), grad.detach().numpy(), hess.detach().numpy()

In [12]:
# Benchmark the PyTorch autograd implementation.
%timeit loss, grad, hess=pu_sl(y_obs, y_pred, prior)

4.02 ms ± 820 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [13]:
# Keep the PyTorch results for inspection below.
# NOTE(review): this rebinds the global name `grad`, shadowing the `grad`
# imported from jax in the first cell; re-running the JAX cells after this one
# would fail until the imports are re-executed.
loss, grad, hess=pu_sl(y_obs, y_pred, prior)

In [14]:
# Display loss and gradient for comparison with the NumPy / JAX outputs above.
loss,grad

(array(0.3332738),
 array([ 6.13474264e-06, -2.37834287e-05, -2.88265640e-06, ...,
        -1.03494747e-05, -3.81883959e-06,  6.92244749e-06]))

In [15]:
def pu_sl(y_observed: np.ndarray,
          y_pred: np.ndarray,
          prior: float,
          non_negative_risk=True):
    """Squared-loss PU objective with optional non-negative risk, via PyTorch autograd.

    With N samples and NP = ``y_observed.sum()`` labelled positives::

        positive risk = prior / (4 NP) * sum((y_pred - 1)**2 * y_observed)
        negative risk = sum((y_pred + 1)**2 / (4 N)
                            - prior / (4 NP) * (y_pred + 1)**2 * y_observed)

    The loss is ``positive risk + negative risk``; when ``non_negative_risk``
    is true the negative risk is clamped below at 0 first.

    Args:
        y_observed: 1-D array-like of 0/1 labels (1 = labelled positive).
        y_pred: 1-D array-like of scores, same length as ``y_observed``.
            Non-floating input is converted to float64.
        prior: Weight of the labelled-positive terms (presumably the
            positive-class prior -- confirm with the caller).
        non_negative_risk: Clamp the negative risk at zero when true.

    Returns:
        Tuple of NumPy arrays ``(loss, grad, hess)``: the 0-d loss and the
        per-sample first and second derivatives w.r.t. ``y_pred``.

    Raises:
        ValueError: If ``y_observed`` sums to zero (no labelled positives).
    """
    y_observed = np.asarray(y_observed)
    y_pred = np.asarray(y_pred)
    if not np.issubdtype(y_pred.dtype, np.floating):
        # Only floating-point tensors can require gradients.
        y_pred = y_pred.astype(np.float64)

    n_samples = y_observed.shape[0]
    n_positive = y_observed.sum()
    if n_positive == 0:
        raise ValueError(
            "y_observed must contain at least one positive (== 1) sample")
    # Plain Python float so the arithmetic below is tensor-with-scalar only.
    positive_weight = float(prior / (4 * n_positive))

    labels = tensor(y_observed)
    scores = tensor(y_pred, requires_grad=True)

    # risk for positives
    positive_risk = (positive_weight * ((scores - 1) ** 2) * labels).sum()

    # risk for negatives, optionally clamped to be non-negative
    negative_risk = (1 / (4 * n_samples) * ((scores + 1) ** 2)
                     - positive_weight * ((scores + 1) ** 2) * labels).sum()
    if non_negative_risk:
        # torch.clamp keeps the value a tensor on both branches, unlike the
        # builtin max(), which hands back a Python int when the risk is negative.
        negative_risk = torch.clamp(negative_risk, min=0)
    loss = positive_risk + negative_risk

    # First derivative. autograd.grad returns the gradient directly instead of
    # accumulating into scores.grad, which avoids the parameter<->grad
    # reference cycle created by backward(create_graph=True).
    (grad,) = torch.autograd.grad(loss, scores, create_graph=True)

    # Second derivative. Differentiating grad.sum() yields the Hessian row
    # sums, which equal its diagonal here because the loss is a sum of
    # independent per-sample terms.
    (hess,) = torch.autograd.grad(grad.sum(), scores)
    return loss.detach().numpy(), grad.detach().numpy(), hess.detach().numpy()

In [16]:
# Benchmark the PyTorch implementation with the non-negative risk option
# left at its default (enabled).
%timeit loss, grad, hess=pu_sl(y_obs, y_pred, prior)

6.62 ms ± 2.42 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [17]:
# Keep the results of the non-negative-risk PyTorch implementation for
# inspection below.
# NOTE(review): this rebinds the global name `grad`, shadowing the `grad`
# imported from jax in the first cell.
loss, grad, hess=pu_sl(y_obs, y_pred, prior)

In [18]:
# Display loss and gradient for comparison with the earlier implementations.
loss,grad

(array(0.3332738),
 array([ 6.13474264e-06, -2.37834287e-05, -2.88265640e-06, ...,
        -1.03494747e-05, -3.81883959e-06,  6.92244749e-06]))

In [19]:
# Display the per-sample second derivatives.
hess

array([1.25e-05, 1.25e-05, 1.25e-05, ..., 1.25e-05, 1.25e-05, 1.25e-05])