In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap, hessian, jvp
from jax import random
import numpy as np

In [2]:
def pu_sl(y_observed: np.ndarray,
          y_pred: np.ndarray,
          prior: float):
    """Unbiased PU squared-loss objective with closed-form gradient and Hessian.

    The per-sample loss is ``1/(4N) * (y_pred + 1)**2`` for every sample,
    minus ``prior/NP * y_pred`` for the samples whose observed label is 1.

    Parameters
    ----------
    y_observed : np.ndarray
        Observed labels; entries equal to 1 are treated as labeled positives,
        every other value as unlabeled.
    y_pred : np.ndarray
        Raw scores, one per sample.
    prior : float
        Class prior used to weight the positive term.

    Returns
    -------
    tuple
        ``(loss, grad, hess)``: the summed loss, and the per-sample first and
        second derivatives with respect to ``y_pred``.
    """
    N = y_observed.shape[0]
    # Use the same "== 1" indicator for the count and for the weighting, so a
    # label encoding other than 0/1 (e.g. -1 for unlabeled) cannot leak the raw
    # label value into the loss or gradient.
    positive = (y_observed == 1)
    NP = positive.sum()
    # loss
    loss = 1 / (4 * N) * (y_pred + 1)**2
    loss += -prior / NP * y_pred * positive
    # grad
    grad = 1 / (2 * N) * (y_pred + 1)
    grad -= prior / NP * positive
    # hess: the loss is quadratic in y_pred, so the second derivative is constant
    hess = 1 / (2 * N) * np.ones_like(grad)
    return loss.sum(), grad, hess

In [3]:
# Synthetic benchmark inputs: 40,000 observed labels repeating the pattern
# 0, 1, 1, 0, and as many scores drawn uniformly from [-1, 1).
y_obs = np.tile([0, 1, 1, 0], 10000)
y_pred = 2 * np.random.random(4 * 10000) - 1
prior = 0.5

In [4]:
pu_sl(y_obs, y_pred, prior)

(0.33569573984036194,
 array([ 6.73749298e-06, -9.80903618e-06, -4.25412076e-06, ...,
        -3.84448674e-06, -1.12459272e-05,  1.02419521e-05]),
 array([1.25e-05, 1.25e-05, 1.25e-05, ..., 1.25e-05, 1.25e-05, 1.25e-05]))

In [5]:
%timeit pu_sl(y_obs, y_pred, prior)

581 µs ± 105 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [6]:
def sl(y: int, g: float):
    """Squared loss ``1/4 * (y*g - 1)**2`` for a label ``y`` and a score ``g``."""
    residual = y * g - 1
    return 1 / 4 * (residual ** 2)


# First and second derivatives of `sl` with respect to the score g (argument 1).
grad_sl = jit(grad(sl, argnums=1))
hess_sl = jit(grad(grad_sl, argnums=1))

# Versions mapped over paired (y, g) arrays.
sl_map = jit(vmap(sl))
grad_sl_map = jit(vmap(grad_sl))
hess_sl_map = jit(vmap(hess_sl))

In [7]:
@jit
def pu_sl(y_observed: np.ndarray,  # 0/1 labels only
          y_pred: np.ndarray,
          prior: float):
    """PU squared-loss risk with a clipped negative-class term, plus derivatives.

    The positive-class part is always included. The negative-class part
    (its loss, gradient and Hessian contributions alike) is multiplied by the
    boolean ``lossN > 0``, so it is dropped whenever that estimated risk is
    not positive.

    Returns ``(loss, grad, hess)``: a scalar loss and per-sample first and
    second derivatives with respect to ``y_pred``.
    """
    y_observed = jnp.asarray(y_observed, dtype=jnp.int8)
    y_pred = jnp.asarray(y_pred, dtype=jnp.float32)

    N = y_observed.shape[0]
    NP = (y_observed == 1).sum()

    as_positive = jnp.ones_like(y_pred)
    as_negative = -jnp.ones_like(y_pred)

    # Terms evaluated with target y* = +1, restricted to observed positives.
    lossP = prior/NP * jnp.sum(sl_map(as_positive, y_pred) * y_observed)
    gradP = prior/NP * grad_sl_map(as_positive, y_pred) * y_observed
    hessP = prior/NP * hess_sl_map(as_positive, y_pred) * y_observed

    # Terms evaluated with target y* = -1: mean over all samples minus the
    # prior-weighted share attributed to the observed positives.
    sl_neg = sl_map(as_negative, y_pred)
    grad_neg = grad_sl_map(as_negative, y_pred)
    hess_neg = hess_sl_map(as_negative, y_pred)
    lossN = jnp.mean(sl_neg) - prior/NP*jnp.sum(sl_neg * y_observed)
    gradN = 1/N * grad_neg - prior/NP * grad_neg * y_observed
    hessN = 1/N * hess_neg - prior/NP * hess_neg * y_observed

    # A Python `if` on a traced value is not usable under jit, so the branch
    # is expressed as multiplication by a boolean flag.
    is_non_negative = (lossN > 0)
    total_loss = lossP + lossN * is_non_negative
    total_grad = gradP + gradN * is_non_negative
    total_hess = hessP + hessN * is_non_negative
    return total_loss, total_grad, total_hess

In [8]:
pu_sl(y_obs, y_pred, prior)



(DeviceArray(0.33569568, dtype=float32),
 DeviceArray([ 6.73749264e-06, -9.80903587e-06, -4.25412145e-06, ...,
              -3.84448776e-06, -1.12459265e-05,  1.02419526e-05],            dtype=float32),
 DeviceArray([1.25e-05, 1.25e-05, 1.25e-05, ..., 1.25e-05, 1.25e-05,
              1.25e-05], dtype=float32))

In [9]:
%timeit pu_sl(y_obs, y_pred, prior)

505 µs ± 101 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [62]:
# vmap of grad_sl, built here without an outer jit.
tmp = vmap(grad_sl)
# Micro-benchmark inputs: 40,000 labels fixed at 1 and scores cycling -1, 0, 1, 2.
a = np.ones(4 * 10000)
b = np.tile(np.array([-1, 0, 1, 2], dtype='float'), 10000)

In [72]:
%timeit sl_map(a,b)

323 µs ± 31.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [73]:
%timeit grad_sl_map(a,b)

408 µs ± 72 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [74]:
%timeit hess_sl_map(a,b)

331 µs ± 41.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [41]:
jit(vmap(hess_sl))(np.ones(4), np.array([-1,0,1,2], dtype='float') )

DeviceArray([2., 2., 2., 2.], dtype=float32)

In [6]:
class PU_SL:
    """Unbiased PU squared-loss objective whose derivatives come from JAX autodiff.

    Calling an instance with ``(y_observed, y_pred)`` returns
    ``(loss, grad, hess)``: the scalar loss and its per-sample first and
    second derivatives with respect to ``y_pred``.
    """

    def __init__(self, prior):
        self.prior = prior
        self._loss = jit(self._lossfunc, backend='cpu')
        self._set_grad()
        self._set_hess()

    def __call__(self, y_observed: np.ndarray, y_pred: np.ndarray):
        """Evaluate loss, gradient and Hessian diagonal for one batch."""
        loss_value = self._loss(y_observed, y_pred)
        grad_value = self._grad(y_observed, y_pred)
        hess_value = self._hess(y_observed, y_pred)
        return loss_value, grad_value, hess_value

    def _lossfunc(self,
                  y_observed: np.ndarray,
                  y_pred: np.ndarray):
        """Summed loss: ``1/(4N) * (y_pred + 1)**2 - prior/NP * y_pred * y_observed``."""
        N = y_observed.shape[0]
        NP = (y_observed == 1).sum()
        # loss
        loss = 1 / (4 * N) * jnp.square((y_pred + 1))
        tmp = -self.prior / NP * y_pred * y_observed
        loss += tmp
        return loss.sum()

    def _set_grad(self):
        """Build the jitted gradient of the loss with respect to ``y_pred``."""
        self._grad = jit(grad(self._loss, argnums=1), backend='cpu')

    def _set_hess(self):
        """Build the jitted per-sample second derivative.

        Each gradient entry depends only on its own ``y_pred`` entry, so
        differentiating the sum of the gradient gives the Hessian diagonal.
        """
        def summed_grad(y_observed: np.ndarray,
                        y_pred: np.ndarray):
            # Differentiate against the labels passed to this call; reading a
            # notebook-level variable here would tie the Hessian to whatever
            # data happens to be defined globally.
            return self._grad(y_observed, y_pred).sum()
        self._hess = jit(grad(summed_grad, argnums=1), backend='cpu')

In [7]:
pu_sl = PU_SL(0.5)

In [8]:
%timeit pu_sl(y_obs, y_pred)



1.11 ms ± 82.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
def pu_sl(y_observed: np.ndarray,
          y_pred: np.ndarray,
          prior: float):
    """Return the summed PU squared loss.

    Each sample contributes ``1/(4N) * (y_pred + 1)**2`` and, weighted by its
    observed label, ``-prior/NP * y_pred``, where ``N`` is the sample count and
    ``NP`` is the sum of ``y_observed``.
    """
    n_samples = y_observed.shape[0]
    n_positive = y_observed.sum()

    per_sample = 1 / (4 * n_samples) * (y_pred + 1)**2
    per_sample += -prior / n_positive * y_pred * y_observed

    return per_sample.sum()

def pu_sl_nonegative(y_observed: np.ndarray,
                     y_pred: np.ndarray,
                     prior: float):
    """Return the PU squared loss with the negative-class risk clipped at zero.

    The positive-class risk is always included; the negative-class risk is
    added only when it is strictly greater than zero.
    """
    N = y_observed.shape[0]
    NP = y_observed.sum()

    # Risk of the observed positives measured against target +1.
    positive_risk = (prior/(4*NP) * (y_pred-1)**2 * y_observed).sum()

    # Risk against target -1: all samples minus the prior-weighted positives.
    negative_risk = (1/(4*N)*(y_pred+1)**2
                     - prior/(4*NP)*(y_pred+1)**2*y_observed).sum()

    if negative_risk > 0:
        return positive_risk + negative_risk
    return positive_risk

In [10]:
class GradHess:
    """Gradient and Hessian-diagonal evaluator for a loss, jit-compiled on CPU.

    ``lossfunc(y_observed, y_pred, prior)`` must return a scalar. Calling an
    instance returns ``(grad, hess)`` with respect to ``y_pred``.

    Note: a second class named ``GradHess`` is defined right after this one in
    the same cell, so this jitted variant is shadowed once the cell has run.
    """

    def __init__(self,
                 lossfunc: callable,
                 prior: float):
        self.prior = prior
        self._loss = jit(lossfunc, backend='cpu')
        self._set_grad()
        self._set_hess()

    def __call__(self, y_observed: np.ndarray, y_pred: np.ndarray):
        """Return ``(grad, hess)`` for one batch using the stored prior."""
        first_order = self._grad(y_observed, y_pred, self.prior)
        second_order = self._hess(y_observed, y_pred, self.prior)
        return first_order, second_order

    def _set_grad(self):
        """Build the jitted derivative of the loss with respect to ``y_pred``."""
        self._grad = jit(grad(self._loss, argnums=1), backend='cpu')

    def _set_hess(self):
        """Build the jitted derivative of the summed gradient w.r.t. ``y_pred``."""
        def summed_gradient(y_observed: np.ndarray,
                            y_pred: np.ndarray,
                            prior: float):
            # The `prior` argument is accepted but the stored self.prior is used.
            return self._grad(y_observed, y_pred, self.prior).sum()
        self._hess = jit(grad(summed_gradient, argnums=1), backend='cpu')

        
class GradHess:
    """Gradient and Hessian-diagonal evaluator for a loss, without jit.

    ``lossfunc(y_observed, y_pred, prior)`` must return a scalar. Calling an
    instance returns ``(grad, hess)`` with respect to ``y_pred``.
    """

    def __init__(self,
                 lossfunc: callable,
                 prior: float):
        self.prior = prior
        self._loss = lossfunc
        self._set_grad()
        self._set_hess()

    def __call__(self, y_observed: np.ndarray, y_pred: np.ndarray):
        """Return ``(grad, hess)`` for one batch using the stored prior."""
        first_order = self._grad(y_observed, y_pred, self.prior)
        second_order = self._hess(y_observed, y_pred, self.prior)
        return first_order, second_order

    def _set_grad(self):
        """Build the derivative of the loss with respect to ``y_pred``."""
        self._grad = grad(self._loss, argnums=1)

    def _set_hess(self):
        """Build the derivative of the summed gradient w.r.t. ``y_pred``."""
        def summed_gradient(y_observed: np.ndarray,
                            y_pred: np.ndarray,
                            prior: float):
            # The `prior` argument is accepted but the stored self.prior is used.
            return self._grad(y_observed, y_pred, self.prior).sum()
        self._hess = grad(summed_gradient, argnums=1)

In [11]:
pu_sl = GradHess(pu_sl_nonegative, 0.5)

In [12]:
%timeit pu_sl(y_obs, y_pred)

36.2 ms ± 863 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


PyTorch 実装が一番キレイだけど、3〜4倍遅い

In [10]:
from torch import tensor, Tensor
import torch

In [11]:
def pu_sl(y_observed: np.ndarray,
          y_pred: np.ndarray,
          prior: float):
    """Unbiased PU squared loss with autograd-derived gradient and Hessian.

    Parameters
    ----------
    y_observed : np.ndarray
        Observed labels, used directly as the weight of the positive term.
    y_pred : np.ndarray
        Raw scores, one per sample.
    prior : float
        Class prior used to weight the positive term.

    Returns
    -------
    tuple of np.ndarray
        ``(loss, grad, hess)``: the scalar loss and the per-sample first and
        second derivatives with respect to ``y_pred``.
    """
    N = y_observed.shape[0]
    NP = y_observed.sum()
    # Plain Python float so the tensor arithmetic below never mixes in a
    # NumPy scalar.
    positive_weight = float(prior / NP)
    y_observed = tensor(y_observed)
    y_pred = tensor(y_pred, requires_grad=True)

    # loss
    loss = 1 / (4 * N) * (y_pred + 1)**2
    loss = loss + (-positive_weight * y_pred * y_observed)
    loss = loss.sum()

    # First derivative; keep the graph so it can be differentiated again.
    (grad,) = torch.autograd.grad(loss, y_pred, create_graph=True)
    # Second derivative, obtained directly. Accumulating into y_pred.grad and
    # subtracting the first derivative afterwards would lose precision when
    # the gradient is large relative to the Hessian.
    (hess,) = torch.autograd.grad(grad.sum(), y_pred)
    return loss.detach().numpy(), grad.detach().numpy(), hess.detach().numpy()

In [12]:
%timeit loss, grad, hess=pu_sl(y_obs, y_pred, prior)

2.96 ms ± 737 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [16]:
loss, grad, hess=pu_sl(y_obs, y_pred, prior)

In [17]:
loss,grad

(array(0.334929),
 array([ 1.32353755e-05, -1.01408916e-05, -4.18350929e-06, ...,
        -1.45266920e-05, -1.91193260e-05,  1.33722051e-05]))

In [18]:
def pu_sl(y_observed: np.ndarray,
          y_pred: np.ndarray,
          prior: float,
          non_negative_risk=True):
    """PU squared loss, optionally with the negative-class risk clipped at zero.

    Parameters
    ----------
    y_observed : np.ndarray
        Observed labels, used directly as the weight of the positive terms.
    y_pred : np.ndarray
        Raw scores, one per sample.
    prior : float
        Class prior used to weight the positive terms.
    non_negative_risk : bool, default True
        When true, the negative-class risk is dropped (value and derivatives)
        whenever it is below zero; when false it is always added.

    Returns
    -------
    tuple of np.ndarray
        ``(loss, grad, hess)``: the scalar loss and the per-sample first and
        second derivatives with respect to ``y_pred``.
    """
    N = y_observed.shape[0]
    NP = y_observed.sum()
    # Plain Python float so the tensor arithmetic below never mixes in a
    # NumPy scalar.
    positive_weight = float(prior / (4 * NP))
    y_observed = tensor(y_observed)
    y_pred = tensor(y_pred, requires_grad=True)

    # risk for positives
    loss = (positive_weight * ((y_pred - 1)**2) * y_observed).sum()

    # risk for negatives: all samples minus the prior-weighted positives
    negative_risk = (1 / (4 * N) * ((y_pred + 1)**2)
                     - positive_weight * ((y_pred + 1)**2) * y_observed).sum()

    # Clip only when requested and only when the estimate is strictly negative;
    # a zero or NaN estimate is still added, as with max(negative_risk, 0).
    clipped = non_negative_risk and bool(negative_risk < 0)
    if not clipped:
        loss = loss + negative_risk

    # First derivative; keep the graph so it can be differentiated again.
    (grad,) = torch.autograd.grad(loss, y_pred, create_graph=True)
    # Second derivative, obtained directly. Accumulating into y_pred.grad and
    # subtracting the first derivative afterwards would lose precision when
    # the gradient is large relative to the Hessian.
    (hess,) = torch.autograd.grad(grad.sum(), y_pred)
    return loss.detach().numpy(), grad.detach().numpy(), hess.detach().numpy()

In [19]:
%timeit loss, grad, hess=pu_sl(y_obs, y_pred, prior)

3.06 ms ± 493 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [20]:
loss, grad, hess=pu_sl(y_obs, y_pred, prior)

In [21]:
loss,grad

(array(0.334929),
 array([ 1.32353755e-05, -1.01408916e-05, -4.18350929e-06, ...,
        -1.45266920e-05, -1.91193260e-05,  1.33722051e-05]))

In [22]:
hess

array([1.25e-05, 1.25e-05, 1.25e-05, ..., 1.25e-05, 1.25e-05, 1.25e-05])