In [47]:
import pandas as pd
import numpy as np
import os
import time
import copy
import pathlib, tempfile

import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
from graphviz import Digraph
from joblib import Parallel, delayed
from scipy import stats

from survivors import metrics as metr
from survivors import constants as cnt
from survivors import criteria as crit
from numba import njit, jit

%load_ext line_profiler

The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler


In [928]:
@njit(cache = True)
def count_N_O(dur_1, dur_2, cens_1, cens_2, times_range, weightings):
    bins = times_range[1] - times_range[0] + 1
    n_1_j = np.histogram(dur_1, bins=bins, range=times_range)[0]
    n_2_j = np.histogram(dur_2, bins=bins, range=times_range)[0]
    O_1_j = np.histogram(dur_1*cens_1, bins=bins, range=times_range)[0]
    O_2_j = np.histogram(dur_2*cens_2, bins=bins, range=times_range)[0]
    
    N_1_j = np.cumsum(n_1_j[::-1])[::-1]
    N_2_j = np.cumsum(n_2_j[::-1])[::-1]
    ind = np.where(N_1_j * N_2_j != 0)
    N_1_j = N_1_j[ind]
    N_2_j = N_2_j[ind]
    O_1_j = O_1_j[ind]
    O_2_j = O_2_j[ind]

    N_j = N_1_j + N_2_j
    O_j = O_1_j + O_2_j
    E_1_j = N_1_j*O_j/N_j
    res = np.zeros((N_j.shape[0], 3), dtype=np.float64)
    res[:, 1] = O_1_j - E_1_j
    res[:, 2] = E_1_j*(N_j - O_j) * N_2_j/(N_j*(N_j - 1))
    return N_j, O_j, res

@njit(cache = True)
def get_lr(res):
    return np.power((res[:, 0]*res[:, 1]).sum(), 2) / ((res[:, 0]*res[:, 0]*res[:, 2]).sum())

# @njit
# def lr_statistic(dur_1, dur_2, cens_1, cens_2, times_range, weightings):
#     bins = times_range[1] - times_range[0] + 1
#     n_1_j = np.histogram(dur_1, bins=bins, range=times_range)[0]
#     n_2_j = np.histogram(dur_2, bins=bins, range=times_range)[0]
#     O_1_j = np.histogram(dur_1*cens_1, bins=bins, range=times_range)[0]
#     O_2_j = np.histogram(dur_2*cens_2, bins=bins, range=times_range)[0]
    
#     N_1_j = np.cumsum(n_1_j[::-1])[::-1]
#     N_2_j = np.cumsum(n_2_j[::-1])[::-1]
#     ind = np.where(N_1_j * N_2_j != 0)
#     N_1_j = N_1_j[ind]
#     N_2_j = N_2_j[ind]
#     O_1_j = O_1_j[ind]
#     O_2_j = O_2_j[ind]

#     N_j = N_1_j + N_2_j
#     O_j = O_1_j + O_2_j
#     E_1_j = N_1_j*O_j/N_j
#     res = np.zeros((N_j.shape[0], 3), dtype=np.float64)
#     res[:, 0] = 1.0
#     if weightings == "wilcoxon":
#         res[:, 0] = N_j
#     elif weightings == "tarone-ware":
#         res[:, 0] = np.sqrt(N_j)
#     elif weightings == "peto":
#         res[:, 0] = np.cumprod((1.0 - O_j/(N_j+1)))
#     print(res[:, 0])
#     res[:, 1] = O_1_j - E_1_j
#     res[:, 2] = E_1_j*(N_j - O_j) * N_2_j/(N_j*(N_j - 1))
#     logrank = np.power((res[:, 0]*res[:, 1]).sum(), 2) / ((res[:, 0]*res[:, 0]*res[:, 2]).sum())
#     return logrank

def lr_statistic(dur_1, dur_2, cens_1, cens_2, times_range, weightings):
    N_j, O_j, res = count_N_O(dur_1, dur_2, cens_1, cens_2, times_range, weightings)
    res[:, 0] = 1.0
    if weightings == "wilcoxon":
        res[:, 0] = N_j
    elif weightings == "tarone-ware":
        res[:, 0] = np.sqrt(N_j)
    elif weightings == "peto":
        res[:, 0] = np.cumprod((1.0 - O_j/(N_j+1)))

    logrank = get_lr(res)
    return logrank

def weight_lr_fast(dur_A, dur_B, cens_A = None, cens_B = None, weightings = ""):
#     times = np.unique(np.concatenate([dur_A, dur_B]))
    times = np.unique(np.hstack((dur_A, dur_B)))
#     times = np.union1d(np.unique(dur_A), np.unique(dur_B))
    dur_A = np.searchsorted(times, dur_A) + 1
    dur_B = np.searchsorted(times, dur_B) + 1
    times_range = (1, times.shape[0])
    if cens_A is None:
        cens_A = np.ones(dur_A.shape[0])
    if cens_B is None:
        cens_B = np.ones(dur_B.shape[0])
    logrank = lr_statistic(dur_A, dur_B, cens_A, cens_B, times_range, weightings)
    pvalue = stats.chi2.sf(logrank, df=1)
    return pvalue

In [65]:
from numba.pycc import CC
from numba import cuda

# cc = CC('lr_crit')
# cc._source_module = "lr_crit.code2compile" 
# cc.output_dir='{}\\dist'.format(os.path.abspath('..'))

# @cc.export('lr_statistic', 'f8(i8[:], i8[:], i8[:], i8[:], i8[:], i8)')

@njit(parallel=True)
def lr_statistic(dur_1, dur_2, cens_1, cens_2, times_range, weightings):
    bins = times_range[1] - times_range[0] + 1
    n_1_j = np.histogram(dur_1, bins=bins, range=times_range)[0]
    n_2_j = np.histogram(dur_2, bins=bins, range=times_range)[0]
    O_1_j = np.histogram(dur_1 * cens_1, bins=bins, range=times_range)[0]
    O_2_j = np.histogram(dur_2 * cens_2, bins=bins, range=times_range)[0]

    N_1_j = np.cumsum(n_1_j[::-1])[::-1]
    N_2_j = np.cumsum(n_2_j[::-1])[::-1]
    ind = np.where(N_1_j * N_2_j != 0)
    N_1_j = N_1_j[ind]
    N_2_j = N_2_j[ind]
    O_1_j = O_1_j[ind]
    O_2_j = O_2_j[ind]

    N_j = N_1_j + N_2_j
    O_j = O_1_j + O_2_j
    E_1_j = N_1_j * O_j / N_j
    res = np.zeros((N_j.shape[0], 3), dtype=np.float32)
    res[:, 1] = O_1_j - E_1_j
    res[:, 2] = E_1_j * (N_j - O_j) * N_2_j / (N_j * (N_j - 1))
    res[:, 0] = 1.0
    if weightings == 2:
        res[:, 0] = N_j
    elif weightings == 3:
        res[:, 0] = np.sqrt(N_j)
    elif weightings == 4:
        res[:, 0] = np.cumprod((1.0 - O_j / (N_j + 1)))
    logrank = np.power((res[:, 0] * res[:, 1]).sum(), 2) / ((res[:, 0] * res[:, 0] * res[:, 2]).sum())
    return logrank


def weight_lr_fast(dur_A, dur_B, cens_A=None, cens_B=None, weightings=""):
    times = np.unique(np.hstack((dur_A, dur_B)))
    dur_A = np.searchsorted(times, dur_A) + 1
    dur_B = np.searchsorted(times, dur_B) + 1
    times_range = np.array([1, times.shape[0]])
    if cens_A is None:
        cens_A = np.ones(dur_A.shape[0])
    if cens_B is None:
        cens_B = np.ones(dur_B.shape[0])
    d = {"logrank": 1, "wilcoxon": 2, "tarone-ware": 3, "peto": 4}
    weightings = d.get(weightings, 1)
    logrank = lr_statistic(dur_A, dur_B, cens_A, cens_B, times_range, weightings)
    pvalue = stats.chi2.sf(logrank, df=1)
    return pvalue
    
# cc.compile()

In [32]:
from lr_crit import lr_statistic
def weight_lr_fast(dur_A, dur_B, cens_A=None, cens_B=None, weightings=""):
    try:
        times = np.unique(np.hstack((dur_A, dur_B)))
        dur_A = np.searchsorted(times, dur_A) + 1
        dur_B = np.searchsorted(times, dur_B) + 1
        times_range = np.array([1, times.shape[0]])
        if cens_A is None:
            cens_A = np.ones(dur_A.shape[0])
        if cens_B is None:
            cens_B = np.ones(dur_B.shape[0])
        d = {"logrank": 1, "wilcoxon": 2, "tarone-ware": 3, "peto": 4}
        weightings = d.get(weightings, 1)
        logrank = lr_statistic(dur_A.astype("int64"), 
                               dur_B.astype("int64"), 
                               cens_A.astype("int64"), 
                               cens_B.astype("int64"), 
                               times_range.astype("int64"), 
                               np.int64(weightings))
        pvalue = stats.chi2.sf(logrank, df=1)
        return pvalue
    except:
        return 1.0

In [57]:
# for i in range(10):
#     dur_A_ = np.random.uniform(0, 10000, 10000)
#     cens_A_ = np.random.choice(2, 10000)
#     dur_B_ = np.random.uniform(0, 10000, 10000)
#     cens_B_ = np.random.choice(2, 10000)
#     print(weight_lr_fast(dur_A_, dur_B_, cens_A_, cens_B_))
#     print(crit.weight_lr_fast(dur_A_, dur_B_, cens_A_, cens_B_))

In [865]:
%lprun -f weight_lr_fast weight_lr_fast(dur_A_, dur_B_, cens_A_, cens_B_)

In [69]:
dur_A_ = np.random.choice(10000, 10000)
cens_A_ = np.random.choice(2, 10000)
dur_B_ = np.random.choice(10000, 10000)
cens_B_ = np.random.choice(2, 10000)

In [70]:
%timeit weight_lr_fast(dur_A_, dur_B_, cens_A_, cens_B_, "peto")

2.39 ms ± 22 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [71]:
%timeit crit.weight_lr_fast(dur_A_, dur_B_, cens_A_, cens_B_, "peto")

75.1 ms ± 333 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [910]:
times = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
dur_A = np.array([0, 2, 3, 4, 5, 3, 9, 3])
cens_A = np.array([1, 1, 0, 1, 0, 0, 0, 0])
dur_B = np.array([4, 5, 6, 7, 8, 1, 3])
cens_B = np.array([1, 1, 1, 0, 1, 1, 1])

In [918]:
@jit
def numb_coeffs_t_j(dur_A, dur_B, cens_A, cens_B, t_j, weightings):
    N_1_j = (dur_A >= t_j).sum()
    N_2_j = (dur_B >= t_j).sum()
    if N_1_j == 0 or N_2_j == 0:
        return 0, 0, 0
    O_1_j = ((dur_A == t_j) * cens_A).sum()  # np.where(dur_A == t_j, cens_A,0).sum()
    O_2_j = ((dur_B == t_j) * cens_B).sum()  # np.where(dur_B == t_j, cens_B,0).sum()
    
    N_j = N_1_j + N_2_j
    O_j = O_1_j + O_2_j
    E_1_j = N_1_j*O_j/N_j
    w_j = 1
    if weightings == "wilcoxon":
        w_j = N_j
    elif weightings == "tarone-ware":
        w_j = np.sqrt(N_j)
    elif weightings == "peto":
        w_j = (1.0 - float(O_j)/(N_j+1))
    
    num = O_1_j - E_1_j
    denom = E_1_j*(N_j - O_j) * N_2_j/(N_j*(N_j - 1))
    return w_j, num, denom

@jit
def numb_lr_statistic(dur_A, dur_B, cens_A, cens_B, times, weightings):
    res = np.zeros((times.shape[0], 3), dtype=np.float32)
    for j, t_j in enumerate(times):
        res[j] = numb_coeffs_t_j(dur_A, dur_B, cens_A, cens_B, t_j, weightings)
    
    if weightings == "peto":
        res[:, 0] = np.cumprod(res[:, 0])
    print(res[:, 0])
    # logrank = np.dot(res[:, 0], res[:, 1])**2 / np.dot(res[:, 0]*res[:, 0], res[:, 2])
    logrank = np.power((res[:, 0]*res[:, 1]).sum(), 2) / ((res[:, 0]*res[:, 0]*res[:, 2]).sum())
#     print(np.power((res[:, 0]*res[:, 1]).sum(), 2), ((res[:, 0]*res[:, 0]*res[:, 2]).sum()))
#     print(logrank)
    return logrank

def numb_weight_lr_fast(dur_A, dur_B, cens_A = None, cens_B = None, weightings = ""):
    try:
        if cens_A is None:
            cens_A = np.ones(dur_A.shape[0])
        if cens_B is None:
            cens_B = np.ones(dur_B.shape[0])

        #     a1 = np.unique(dur_A)
        #     a2 = np.unique(dur_B)
        #     times = np.unique(np.clip(np.union1d(a1,a2), 0, np.min([a1.max(), a2.max()])))
        times = np.union1d(np.unique(dur_A), np.unique(dur_B))
        logrank = numb_lr_statistic(dur_A, dur_B, cens_A, cens_B, times, weightings)
        pvalue = stats.chi2.sf(logrank, df=1)
        return pvalue
    except:
        return 1.0

In [919]:
numb_weight_lr_fast(dur_A, dur_B, cens_A, cens_B)

[1. 1. 1. 1. 1. 1. 1. 1. 1. 0.]


0.5281722982828967

In [920]:
weight_lr_fast(dur_A, dur_B, cens_A, cens_B)

[1. 1. 1. 1. 1. 1. 1. 1. 1.]


0.5281723126024873

In [913]:
crit.weight_lr_fast(dur_A, dur_B, cens_A, cens_B)

0.5281722982828967

In [129]:
np.cumsum(np.array([1, 2, 3, 4])[::-1])[::-1]

array([10,  9,  7,  4])

In [215]:
res = weight_lr_fast(dur_A, dur_B, cens_A, cens_B)
((res[:, 0]*res[:, 0]*res[:, 2]).sum())

[[ 1.         -0.5         0.25      ]
 [ 1.          0.46153846  0.24852072]
 [ 1.          0.5         0.6136364 ]
 [ 1.          0.25        0.4017857 ]
 [ 1.         -0.33333334  0.22222222]
 [ 1.         -0.25        0.1875    ]
 [ 1.          0.          0.        ]
 [ 1.         -0.5         0.25      ]]
0.1382314766966033 2.173665
0.06359372376451176


2.173665

In [230]:
sum(res[:, 0]*res[:, 0]*res[:, 2])

2.173665016889572

In [560]:
times_range = (times.min(), times.max())
bins = times_range[1] - times_range[0] + 1
n_1_j = np.histogram(dur_A, bins=bins,
                     range=times_range)[0]
n_2_j = np.histogram(dur_B, bins=bins, 
                     range=times_range)[0]
O_1_j = np.histogram(dur_A*cens_A, bins=bins, #weights=cens_1,
                     range=times_range)[0]
O_2_j = np.histogram(dur_B*cens_B, bins=bins, #weights=cens_2,
                     range=times_range)[0]

In [561]:
print(n_1_j, n_2_j, O_1_j, O_2_j)

[0 0 1 ... 0 0 0] [0 1 0 ... 0 0 0] [5 0 1 ... 0 0 0] [1 1 0 ... 0 0 0]


In [555]:
dur_A

array([  2,   3,   4,   5,   3, 100,   3])

In [695]:
%timeit np.histogram(dur_A_, bins=bins, range=times_range)[0]

231 µs ± 7.52 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [696]:
def get_freq(dur, times):
    dur.sort()
    n_1_j = np.zeros(times.shape[0])
    ind = np.searchsorted(times, dur)
    n_1_j[ind] = np.bincount(dur)[ind]
    return n_1_j

In [697]:
n_1_j

array([1., 0., 1., 3., 1., 1., 0., 0., 0., 1.])

In [698]:
%timeit get_freq(dur_A_, times)

499 µs ± 2.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [2]:
from numba.pycc import CC

cc = CC('dist_my_module_1')
cc._source_module = "dist_my_module_1.code2compile" 
# # Uncomment the following line to print out the compilation steps
#cc.verbose = True

@cc.export('mult_arr', 'f8[:](f8[:], f8[:])')
@cc.export('multf', 'f8(f8, f8)')
@cc.export('multi', 'i4(i4, i4)')
def mult(a, b):
    return (a + b)*2

@cc.export('square', 'f8(f8)')
def square(a):
    return a ** 2

@cc.export('centdiff_1d', 'f8[:](f8[:], f8)')
def centdiff_1d(u, dx):
    D = np.empty_like(u)
    D[0] = 0
    D[-1] = 0
    for i in range(1, len(D) - 1):
        D[i] = (u[i+1] - 2 * u[i] + u[i-1]) / dx**2
    return D

cc.compile()

In [3]:
from dist_my_module_1 import centdiff_1d

In [4]:
centdiff_1d(np.array([1.0, 2.0, 3.0], dtype=np.float64), 4)

array([0., 0., 0.])

In [25]:
dur_A_ * cens_A_ == dur_A_ & cens_A_

array([False, False, False, ...,  True,  True,  True])

In [26]:
from numba import cuda
