In [1]:
import numpy as np
import pandas as pd
from keras.models import Sequential
from keras.layers import Dense
import keras.optimizers
from keras import activations
from sklearn.preprocessing import MinMaxScaler

Using TensorFlow backend.


In [2]:
def pseudomean_ipcw(time, delta, weight, tau):
    '''
    Usage:
        Compute IPCW jackknife (leave-one-out) pseudo values of the restricted mean
        survival time (RMST) at a single tau, using a weighted Nelson-Aalen estimate
        of the cumulative hazard; tau < the last event time
    Input:
        1. time: observed time, length-n array-like
        2. delta: event indicator (1 = event, 0 = censored), length-n array-like
        3. weight: weight matrix, n*ns array-like, where ns is the number of unique event times.
                    a length-n vector is applied to every event time;
                    if a scalar is passed, unit weights are used
        4. tau: landmark time, scalar
    Output:
        pandas Series named 'pseudomean' holding the n pseudo values in input order
        (it carries the index of `time` when `time` is a pandas Series)
    Raises:
        ValueError if the input sizes are inconsistent, or if there is no event
        time <= tau (the estimate is undefined in that case)
    '''
    # Keep the caller's index (if any) so the result lines up with the input rows.
    index = time.index if isinstance(time, pd.Series) else None

    # Work on plain arrays: rows are aligned by position, never by pandas index labels.
    time_arr = np.asarray(time, dtype=float)
    delta_arr = np.asarray(delta, dtype=float)
    n = len(time_arr)
    if len(delta_arr) != n:
        raise ValueError('time and delta must have the same length')

    # Unique, sorted event times (censored observations contribute time*delta == 0).
    event_times = time_arr * delta_arr
    s = np.unique(event_times[event_times > 0])
    ns = len(s)  # the number of intervals
    if ns == 0 or s[0] > tau:
        raise ValueError('tau must be greater than or equal to the first event time')

    # A scalar weight means "unweighted": a single column of ones that broadcasts over s.
    if np.isscalar(weight):
        w = np.ones((n, 1))
    else:
        w = np.asarray(weight, dtype=float)
        if w.ndim == 1:
            w = w.reshape(-1, 1)
    if w.ndim != 2 or w.shape[0] != n or w.shape[1] not in (1, ns):
        raise ValueError('weight must be a scalar, a length-n vector or an n*ns matrix')

    # Index of the last event time <= tau. Only columns 0..inx enter the integral,
    # so later event times are never evaluated (this also keeps the 0/0 that can
    # appear at the very last event time out of the computation when tau precedes it).
    inx = int(np.searchsorted(s, tau, side='right')) - 1
    k = inx + 1
    s_k = s[:k]
    w_k = w if w.shape[1] == 1 else w[:, :k]

    # Interval widths on [0, tau]; the final piece (tau - s[inx]) may be zero.
    dt = np.diff(np.concatenate([[0], s_k, [tau]]))

    # Y[i, j] = 1 if subject i is at risk at s[j]; D[i, j] = 1 if i has its event at s[j].
    Y = (s_k[None, :] <= time_arr[:, None]).astype(int)
    D = ((time_arr[:, None] == s_k[None, :]) * delta_arr[:, None]).astype(int)
    Yw = Y * w_k
    Dw = D * w_k
    denominator = Yw.sum(axis=0)
    numerator = Dw.sum(axis=0)

    # Full-sample estimate: weighted Nelson-Aalen -> survival -> step-function area.
    # The survival value at s[inx] is repeated to cover the last piece up to tau.
    IPCW_surv = np.exp(-np.cumsum(numerator / denominator))
    surv = np.append(IPCW_surv, IPCW_surv[-1])
    IPCW_RM = np.sum(surv * dt)

    # Leave-one-out estimates: remove subject i's own contribution from both sums.
    IPCW_CHi = np.cumsum((numerator - Dw) / (denominator - Yw), axis=1)
    IPCW_survi = np.exp(-IPCW_CHi)
    survi = np.column_stack((IPCW_survi, IPCW_survi[:, -1]))
    IPCW_RMi = np.sum(survi * dt, axis=1)

    # Jackknife pseudo value for each subject.
    pseudo = n * IPCW_RM - (n - 1) * IPCW_RMi
    return pd.Series(pseudo, index=index, name='pseudomean')

In [3]:
# Simulate right-censored survival data with a single binary covariate z.
# Event and censoring times are both exponential with rate c0 * exp(z).
SEED = 2020
np.random.seed(SEED)  # fixed seed so that Restart & Run All reproduces the same draws

n = 200
z = np.random.binomial(1, 0.5, n)
c0 = 0.01
times = np.random.exponential(1/(c0 *np.exp(1*z)), n)
time_censor = np.random.exponential(1/(c0 *np.exp(1*z)), n)

# d = 1 when the event is observed before censoring; t is the observed (earlier) time.
d = (times < time_censor).astype(int)
t = np.minimum(times, time_censor)

In [4]:
# Pre-computed simulation data; the path is relative to the notebook's working directory.
sim_path = '../Data/Pseudo_surv_simulation.csv'
sim = pd.read_csv(sim_path)

In [5]:
# Observed time, event indicator, and every remaining column (from the third on) as weights.
# NOTE(review): this rebinds t and d from the simulation cell, while z is left untouched --
# confirm that the CSV rows correspond to that z.
t = sim['t']
d = sim['d']
w = sim.iloc[:, 2:]

In [6]:
# Landmark times at which the pseudo RMST values are evaluated.
taus = [5, 12, 20, 30, 40, 55, 75, 108]
ntau = len(taus)

# One row of pseudo values per tau, then flipped so rows = subjects, columns = taus.
pseudo_by_tau = [pseudomean_ipcw(t, d, w, tau) for tau in taus]
xx = np.array(pseudo_by_tau).T

In [7]:
## xx normalization
scaler = MinMaxScaler()
scaler.fit(xx)
xx_norm = scaler.transform(xx)

In [8]:
# Covariate as a single-column design matrix (one row per subject).
# NOTE(review): z comes from the simulation cell, while the targets are built from the CSV --
# confirm that both describe the same subjects in the same order.
Z = np.asarray(z).reshape(-1, 1)

In [9]:
# Small feed-forward network: covariates -> 8 ReLU units -> one sigmoid output per tau.
# The sigmoid output matches the [0, 1] range of the min-max scaled targets.
n_features = Z.shape[1]
n_outputs = xx_norm.shape[1]

model = Sequential([
    Dense(8, input_dim=n_features, activation=activations.relu),
    Dense(n_outputs, activation=activations.sigmoid),
])

opt = keras.optimizers.Adam(learning_rate=0.01)
model.compile(optimizer=opt, loss='mse', metrics=['mae'])

# verbose=0 keeps the 1000-epoch training log out of the notebook output.
history = model.fit(Z, xx_norm, epochs=1000, batch_size=256, verbose=0)

In [10]:
# Predict for the two covariate values z = 0 and z = 1.
# The input is given as an explicit (n_samples, n_features) = (2, 1) matrix, the same layout
# as the training matrix Z, instead of relying on Keras to expand a 1-D array.
# Note: these predictions are still on the min-max scaled [0, 1] scale.
ypred_orig = model.predict(x=np.array([[0], [1]]))

In [11]:
# Map the scaled predictions back to the pseudo-value scale: one row per z, one column per tau.
ypred_rescaled = scaler.inverse_transform(ypred_orig)
ypred_rescaled

array([[ 4.7217064, 10.565295 , 16.258284 , 22.634243 , 27.943254 ,
        33.553883 , 39.770817 , 49.229332 ],
       [ 4.814658 , 10.883322 , 16.704382 , 23.341932 , 28.608892 ,
        34.82333  , 40.349617 , 45.296234 ]], dtype=float32)