In [1]:
from pycox.models.loss import rank_loss_deephit_single
#from pycox.models.data import pair_rank_mat
import numpy as np
import torch
from torch import Tensor
import random
from pycox.models import utils
from pycox.models.loss import _rank_loss_deephit,  _diff_cdf_at_time_i


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def _reduction(loss: Tensor, reduction: str = 'mean') -> Tensor:
    if reduction == 'none':
        return loss
    elif reduction == 'mean':
        return loss.mean()
    elif reduction == 'sum':
        return loss.sum()
    raise ValueError(f"`reduction` = {reduction} is not valid. Use 'none', 'mean' or 'sum'.")
    
def _pair_rank_mat(mat, idx_durations, events, dtype='float32'):
    n = len(idx_durations)
    for i in range(n):
        dur_i = idx_durations[i]
        ev_i = events[i]
        if ev_i == 0:
            continue
        for j in range(n):
            dur_j = idx_durations[j]
            ev_j = events[j]
            if (dur_i < dur_j) or ((dur_i == dur_j) and (ev_j == 0)):
                mat[i, j] = 1
    return mat

def pair_rank_mat(idx_durations, events, dtype='float32'):
    """Indicator matrix R with R_ij = 1{T_i < T_j and D_i = 1}.
    So it takes value 1 if we observe that i has an event before j and zero otherwise.
    
    Arguments:
        idx_durations {np.array} -- Array with durations.
        events {np.array} -- Array with event indicators.
    
    Keyword Arguments:
        dtype {str} -- dtype of array (default: {'float32'})
    
    Returns:
        np.array -- n x n matrix indicating if i has an observerd event before j.
    """
    idx_durations = idx_durations.reshape(-1)
    events = events.reshape(-1)
    n = len(idx_durations)
    mat = np.zeros((n, n), dtype=dtype)
    mat = _pair_rank_mat(mat, idx_durations, events, dtype)
    return mat

def rank_loss_deephit_single(phi: Tensor, idx_durations: Tensor, events: Tensor, rank_mat: Tensor,
                             sigma: Tensor, reduction: str = 'mean') -> Tensor:
    
    idx_durations = idx_durations.view(-1, 1)
    # events = events.float().view(-1)
    pmf = utils.pad_col(phi) #.softmax(1)
    #print('softmax pmf', pmf)
    y = torch.zeros_like(pmf).scatter(1, idx_durations, 1.) # one-hot
    print('y', y)
    rank_loss = _rank_loss_deephit(pmf, y, rank_mat, sigma, reduction)
    return rank_loss

def _rank_loss_deephit(pmf: Tensor, y: Tensor, rank_mat: Tensor, sigma: float,
                       reduction: str = 'mean') -> Tensor:
    """Ranking loss from DeepHit.
    
    Arguments:
        pmf {torch.tensor} -- Matrix with probability mass function pmf_ij = f_i(t_j)
        y {torch.tensor} -- Matrix with indicator of duration and censoring time. 
        rank_mat {torch.tensor} -- See pair_rank_mat function.
        sigma {float} -- Sigma from DeepHit paper, chosen by you.
    
    Returns:
        torch.tensor -- loss
    """
    r = _diff_cdf_at_time_i(pmf, y)
    print('intermediary result', torch.exp(-r/sigma))
    loss = rank_mat * torch.exp(-r/sigma)
    loss = loss.mean(1, keepdim=True)
    return _reduction(loss, reduction)

def _diff_cdf_at_time_i(pmf: Tensor, y: Tensor) -> Tensor:
    """R is the matrix from the DeepHit code giving the difference in CDF between individual
    i and j, at the event time of j. 
    I.e: R_ij = F_i(T_i) - F_j(T_i)
    
    Arguments:
        pmf {torch.tensor} -- Matrix with probability mass function pmf_ij = f_i(t_j)
        y {torch.tensor} -- Matrix with indicator of duration/censor time.
    
    Returns:
        torch.tensor -- R_ij = F_i(T_i) - F_j(T_i)
    """
    n = pmf.shape[0]
    ones = torch.ones((n, 1), device=pmf.device)
    r = pmf.cumsum(1).matmul(y.transpose(0, 1))
    diag_r = r.diag().view(1, -1)
    r = ones.matmul(diag_r) - r
    #print('r',r)
    return r.transpose(0, 1)

In [26]:
# let's do easy to follow example
torch.manual_seed(42)
phi = torch.tensor([
              [1.0, 2.0],
              [3.0, 3.0],
              [2.0, 2.0]], requires_grad=True)
#phi = torch.rand(6, 4,requires_grad=True) * 2 - 1
survival_times = torch.tensor([5, 10, 10])
events = torch.tensor([1, 1, 1])
idx_durations = torch.tensor([0,1,1])
rank_mat = torch.tensor([
              [1.0, 1.0],
              [1.0, 1.0],
              [1.0, 1.0]])
#torch.from_numpy(pair_rank_mat(idx_durations, events))
sigma = torch.tensor([1.0])
var = rank_loss_deephit_single(phi, idx_durations, events, rank_mat, sigma, 'sum') 
print(var)
var.backward()
phi.grad

y tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.]])
ones tensor([[1.],
        [1.],
        [1.]])
y.transpose(0, 1) tensor([[1., 0., 0.],
        [0., 1., 1.],
        [0., 0., 0.]])
pmf.cumsum(1) tensor([[1., 3., 3.],
        [3., 6., 6.],
        [2., 4., 4.]], grad_fn=<CumsumBackward0>)
r tensor([[1., 3., 3.],
        [3., 6., 6.],
        [2., 4., 4.]], grad_fn=<MmBackward0>)
diag_r tensor([[1., 6., 4.]], grad_fn=<ViewBackward0>)
ones.matmul(diag_r) tensor([[1., 6., 4.],
        [1., 6., 4.],
        [1., 6., 4.]], grad_fn=<MmBackward0>)
r fin tensor([[ 0.,  3.,  1.],
        [-2.,  0., -2.],
        [-1.,  2.,  0.]], grad_fn=<SubBackward0>)
r fin transpose tensor([[ 0., -2., -1.],
        [ 3.,  0.,  2.],
        [ 1., -2.,  0.]], grad_fn=<TransposeBackward0>)
sum tensor(1., grad_fn=<SumBackward0>)
intermediary result tensor([0.3679], grad_fn=<ExpBackward0>)
tensor(1.1036, grad_fn=<SumBackward0>)


tensor([[ 0.0000,  2.2073],
        [ 0.0000, -1.1036],
        [ 0.0000, -1.1036]])

In [32]:
def _diff_cdf_at_time_i(pmf: Tensor, y: Tensor) -> Tensor:
    """R is the matrix from the DeepHit code giving the difference in CDF between individual
    i and j, at the event time of j. 
    I.e: R_ij = F_i(T_i) - F_j(T_i)
    
    Arguments:
        pmf {torch.tensor} -- Matrix with probability mass function pmf_ij = f_i(t_j)
        y {torch.tensor} -- Matrix with indicator of duration/censor time.
    
    Returns:
        torch.tensor -- R_ij = F_i(T_i) - F_j(T_i)
    """
    n = pmf.shape[0]
    ones = torch.ones((n, 1), device=pmf.device)
    print('ones',ones)
    print('y.transpose(0, 1)',y.transpose(0, 1))
    print('pmf.cumsum(1)',pmf.cumsum(1))
    r = pmf.cumsum(1).matmul(y.transpose(0, 1))
    print('diagonal r',r)
    diag_r = r.diag().view(1, -1)
    print('diag_r', diag_r)
    print('ones.matmul(diag_r)',ones.matmul(diag_r))
    r = ones.matmul(diag_r) - r
    print('r fin',r)
    print('r fin transpose',r.transpose(0, 1))
    print('sum',r.transpose(0, 1).sum())
    return r.transpose(0, 1).sum()

In [35]:
phi = torch.tensor([
              [1.0, 2.0],
              [3.0, 3.0],
              [2.0, 2.0]], requires_grad=True)
y = torch.tensor([ # has to be fully dimensional for all scenarios
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.]])
pmf = utils.pad_col(phi)
print('pmf', pmf)
var = _diff_cdf_at_time_i(pmf, y)
var.backward()
phi.grad

pmf tensor([[1., 2., 0.],
        [3., 3., 0.],
        [2., 2., 0.]], grad_fn=<CatBackward0>)
ones tensor([[1.],
        [1.],
        [1.]])
y.transpose(0, 1) tensor([[1., 0., 0.],
        [0., 1., 1.],
        [0., 0., 0.]])
pmf.cumsum(1) tensor([[1., 3., 3.],
        [3., 6., 6.],
        [2., 4., 4.]], grad_fn=<CumsumBackward0>)
diagonal r tensor([[1., 3., 3.],
        [3., 6., 6.],
        [2., 4., 4.]], grad_fn=<MmBackward0>)
diag_r tensor([[1., 6., 4.]], grad_fn=<ViewBackward0>)
ones.matmul(diag_r) tensor([[1., 6., 4.],
        [1., 6., 4.],
        [1., 6., 4.]], grad_fn=<MmBackward0>)
r fin tensor([[ 0.,  3.,  1.],
        [-2.,  0., -2.],
        [-1.,  2.,  0.]], grad_fn=<SubBackward0>)
r fin transpose tensor([[ 0., -2., -1.],
        [ 3.,  0.,  2.],
        [ 1., -2.,  0.]], grad_fn=<TransposeBackward0>)
sum tensor(1., grad_fn=<SumBackward0>)


tensor([[ 0., -2.],
        [ 0.,  1.],
        [ 0.,  1.]])

In [31]:
phi = torch.tensor([
              [1.0, 2.0],
              [3.0, 3.0],
              [2.0, 2.0]], requires_grad=True)
pmf = utils.pad_col(phi)
y = torch.tensor([
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.]])
var = _diff_cdf_at_time_i(pmf, y)
var.backward()
print('gradient',phi.grad)

ones tensor([[1.],
        [1.],
        [1.]])
y.transpose(0, 1) tensor([[1., 0., 0.],
        [0., 1., 1.],
        [0., 0., 0.]])
pmf.cumsum(1) tensor([[1., 3., 3.],
        [3., 6., 6.],
        [2., 4., 4.]], grad_fn=<CumsumBackward0>)
r tensor([[1., 3., 3.],
        [3., 6., 6.],
        [2., 4., 4.]], grad_fn=<MmBackward0>)
diag_r tensor([[1., 6., 4.]], grad_fn=<ViewBackward0>)
ones.matmul(diag_r) tensor([[1., 6., 4.],
        [1., 6., 4.],
        [1., 6., 4.]], grad_fn=<MmBackward0>)
r fin tensor([[ 0.,  3.,  1.],
        [-2.,  0., -2.],
        [-1.,  2.,  0.]], grad_fn=<SubBackward0>)
r fin transpose tensor([[ 0., -2., -1.],
        [ 3.,  0.,  2.],
        [ 1., -2.,  0.]], grad_fn=<TransposeBackward0>)
sum tensor(1., grad_fn=<SumBackward0>)
gradient tensor([[ 0., -2.],
        [ 0.,  1.],
        [ 0.,  1.]])


In [17]:
#torch.softmax(utils.pad_col(phi), axis=1) # correct
 

In [14]:
# now calculate above manually

import jax.numpy as jnp
from jax import grad

def test(x):
    
    x = jnp.cumsum(x, axis=1)
    s = 0
    mat = []
    #mat2 = jnp.zeros((3,3))
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            if i!=j:
                #print('i,j',i,j)
                #print(x[i,i] - x[i,j])
                diff = x[i,i] - x[i,j]
                #mat2[i,j] = diff
                #print(diff)
                #s = s + diff
                mat.append(diff)

    mat = jnp.array(mat)
    print(mat)
    return jnp.sum(mat)

x = jnp.array([[1.0, 2.0, 3.0],
              [3.0, 3.0, 3.0],
              [2.0, 2.0, 2.0]])
print('sum',test(x))
grad(test)(x)


i,j 0 1
i,j 0 2
i,j 1 0
i,j 1 2
i,j 2 0
i,j 2 1
[-2. -5.  3. -3.  4.  2.]
sum -1.0
i,j 0 1
i,j 0 2
i,j 1 0
i,j 1 2
i,j 2 0
i,j 2 1
Traced<ConcreteArray([-2. -5.  3. -3.  4.  2.], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray([-2., -5.,  3., -3.,  4.,  2.], dtype=float32)
  tangent = Traced<ShapedArray(float32[6])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[6]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x20157d2a0>, in_tracers=(Traced<ShapedArray(float32[1]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[1]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[1]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[1]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[1]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[1]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x2021fc810; to 'JaxprTracer' at 0x2021fc7c0>], out_avals=[ShapedArray(float32[6])], primitive=concatenate, params={'dimension': 0}, effects=set(),

DeviceArray([[ 0., -2., -1.],
             [ 0.,  1., -1.],
             [ 0.,  1.,  2.]], dtype=float32)

In [9]:
torch.manual_seed(42)
phi = torch.tensor([[ 0.7645,  0.8300, 0.2343,  0.9186],
        [0.2191,  0.2018, 0.4869,  0.5873],
        [ 0.8815, 0.7336,  0.8692,  0.1872],
        [ 0.7388,  0.1354,  0.4822, 0.1412],
        [ 0.7709,  0.1478, 0.4668,  0.2549],
        [0.4607, 0.1173, 0.4062,  0.6634]],requires_grad=True)
#phi = torch.rand(6, 4,requires_grad=True) * 2 - 1
survival_times = torch.tensor([5, 10, 10, 20, 20, 30])
events = torch.tensor([1, 1, 1, 1, 1, 1])
idx_durations = torch.tensor([0,1,1,2,2,3])
rank_mat = torch.from_numpy(pair_rank_mat(idx_durations, events))
sigma = torch.tensor([1.0])
var.backward()
phi.grad

phi

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

In [None]:
var = rank_loss_deephit_single(phi, idx_durations, events, rank_mat, sigma, 'sum') 
var

tensor(2.1637, grad_fn=<SumBackward0>)

In [None]:
var.backward()
phi.grad

tensor([[-0.1522,  0.0494,  0.0272,  0.0540],
        [-0.0392, -0.0666,  0.0389,  0.0430],
        [-0.0296, -0.0650,  0.0491,  0.0248],
        [ 0.0791,  0.0142, -0.0599, -0.0179],
        [ 0.0789,  0.0136, -0.0578, -0.0196],
        [ 0.0998,  0.0449, -0.0069, -0.0909]])

In [None]:
import torch

def sum_expression(x):
    # get the diagonal elements of x
    diagonal = torch.diag(x)
    print('diagonal', diagonal)
    # sum the diagonal elements
    sum_diagonal = torch.sum(diagonal)
    print('sum_diagonal', sum_diagonal)
    # subtract x[i,j] from sum_diagonal for all i != j
    print(x - torch.diag(diagonal))
    subtrahend = torch.sum(x - torch.diag(diagonal), dim=1)
    print('subtrahend', subtrahend)
    # sum the subtrahends to get the final result
    result = sum_diagonal - torch.sum(subtrahend)
    
    return result


In [None]:

def sum_expression(x):
    # get the diagonal elements of x
    diagonal = torch.diag(x)
    print('diagonal', diagonal)
    # subtract x[i,j] from diagonal for all i != j
    print(x - torch.diag(diagonal))
    subtrahend = torch.sum(x - torch.diag(diagonal), dim=1)
    print('subtrahend', subtrahend)
    # sum the subtrahends to get the final result
    result = torch.sum(diagonal) - torch.sum(subtrahend)
    
    return result

In [None]:
x = torch.tensor([
              [1.0, 3.0, 3.0],
              [3.0, 3.0, 3.0],
              [2.0, 3.0, 2.0]], requires_grad=True)

var = sum_expression(x)
var

diagonal tensor([1., 3., 2.], grad_fn=<DiagBackward0>)
tensor([[0., 3., 3.],
        [3., 0., 3.],
        [2., 3., 0.]], grad_fn=<SubBackward0>)
subtrahend tensor([6., 6., 5.], grad_fn=<SumBackward1>)


tensor(-11., grad_fn=<SubBackward0>)

In [None]:
x = torch.tensor([
              [1.0, 3.0, 3.0],
              [3.0, 3.0, 3.0],
              [2.0, 3.0, 2.0]],requires_grad=True)
def test(x):
    s = 0
    mat = torch.zeros(3,3)
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            if i!=j:
                #print(x[i,i] - x[i,j])
                diff = x[i,i] - x[i,j]
                s+=diff
                mat[i,j] = diff
    return mat.sum()

In [None]:
t = test(x)
print(t)
t.backward()
x.grad

tensor(-5., grad_fn=<SumBackward0>)


tensor([[ 2., -1., -1.],
        [-1.,  2., -1.],
        [-1., -1.,  2.]])

In [None]:
import jax.numpy as jnp
from jax import grad

def test(x):
    s = 0
    mat = []
    #mat = jnp.zeros((3,3))
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            if i!=j:
                #print(x[i,i] - x[i,j])
                diff = x[i,i] - x[i,j]
                #print(diff)
                s=s + diff
                #mat.append(diff)
    #mat = jnp.array(mat)
    return s#jnp.sum(mat)

x = jnp.array([[1.0, 2.0, 3.0],
              [3.0, 3.0, 3.0],
              [2.0, 2.0, 2.0]])
print(test(x))
grad(test)(x)

-3.0


DeviceArray([[ 2., -1., -1.],
             [-1.,  2., -1.],
             [-1., -1.,  2.]], dtype=float32)

In [None]:
torch.diag(x).reshape(3,1)-torch.cumsum(x, axis=1)-torch.diag(x).reshape(3,1)

TypeError: diag(): argument 'input' (position 1) must be Tensor, not DeviceArray

In [None]:
torch.cumsum(x, axis=1)

tensor([[1., 3., 6.],
        [3., 6., 9.],
        [2., 4., 6.]], grad_fn=<CumsumBackward0>)

In [None]:
var.backward()
x.grad

tensor([[ 1., -1., -1., -1.],
        [-1.,  1., -1., -1.],
        [-1., -1.,  1., -1.],
        [-1., -1., -1.,  1.]])