In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
#from jointmodel import sim
import pandas as pd
import patsy
import sys
import pystan
import random
random.seed(1234)
import survivalstan
from stancache import stancache
from stancache import config

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


INFO:stancache.seed:Setting seed to 1245502385


## prepare data

In [2]:
data = survivalstan.sim.sim_data_jointmodel(N=100)

In [3]:
data.keys()

dict_keys(['params', 'covars', 'events', 'biomarker'])

In [4]:
data['events'].head()

Unnamed: 0,subject_id,time,event_value,event_name
0,0,4.288769,1,death
1,1,1.006545,1,death
2,2,0.032628,1,death
3,3,5.5,0,death
4,4,0.55538,1,death


In [5]:
data['covars'].head()

Unnamed: 0,subject_id,X1,X2
0,0,1,0
1,1,0,0
2,2,1,1
3,3,1,1
4,4,0,0


In [6]:
df = pd.merge(data['events'].query('event_name == "death"'), data['covars'], on='subject_id')
df.head()

Unnamed: 0,subject_id,time,event_value,event_name,X1,X2
0,0,4.288769,1,death,1,0
1,1,1.006545,1,death,0,0
2,2,0.032628,1,death,1,1
3,3,5.5,0,death,1,1
4,4,0.55538,1,death,0,0


## standard patsy formula

In [7]:
formula = '(time + event_value) ~ 0 + X1'
y, X = patsy.dmatrices(formula_like=formula, data=df)

In [8]:
md = patsy.ModelDesc.from_formula(formula)

In [9]:
md.lhs_termlist

[Term([EvalFactor('time')]), Term([EvalFactor('event_value')])]

In [10]:
len(md.lhs_termlist)

2

## `Id` class

In [66]:
import numpy as np

class Id(object):
    ''' patsy stateful transform mapping arbitrary labels to 1-based integer ids.

    Used as ``as_id(col)`` in a formula: patsy first feeds every chunk of data
    to ``memorize_chunk``, then calls ``memorize_finish`` once, then calls
    ``transform`` to encode values with the memorized lookup table.

    Parameters
    ----------
    desc : str
        Free-text label describing what the ids identify (e.g. 'subject').
    '''
    def __init__(self, desc='id'):
        self.values = []
        self.desc = desc

    def memorize_chunk(self, x):
        ''' Record the distinct values present in one chunk of data. '''
        self.values.extend(np.unique(x))

    def memorize_finish(self):
        ''' Assign ids 1..K to the K distinct memorized values. '''
        # np.unique only de-duplicates within a single chunk; drop repeats seen
        # across chunks (keeping first-seen order) so each value gets one id.
        self.values = list(dict.fromkeys(self.values))
        self.ids = np.arange(len(self.values)) + 1
        self.lookup = dict(zip(self.values, self.ids))

    def transform(self, x):
        ''' Encode values in `x` as their memorized integer ids.

        Returns an int ``pd.Series`` sharing `x`'s index when `x` is a Series,
        otherwise a numpy array. Raises KeyError for values not memorized.
        '''
        if isinstance(x, pd.Series):
            return pd.Series([self.lookup[val] for val in x], index=x.index).astype(int)
        return np.array([self.lookup[val] for val in x])

    def len(self):
        ''' Number of distinct ids (requires memorize_finish to have run). '''
        return len(self.ids)

    def decode_df(self):
        ''' DataFrame mapping each integer `id` back to its original `value`. '''
        return pd.DataFrame({'id': self.ids, 'value': self.values})

as_id = patsy.stateful_transform(Id)

In [67]:
print(as_id(np.array(['a','b','a','c'])))

[1 2 1 3]


In [68]:
test_formula = 'event_value + as_id(time) + as_id(subject_id) ~ X1 + X2'

In [69]:
y, X = patsy.dmatrices(formula_like=test_formula, data=df)

pd.DataFrame(y).head()

Unnamed: 0,0,1,2
0,1.0,64.0,1.0
1,1.0,45.0,2.0
2,1.0,9.0,3.0
3,0.0,67.0,4.0
4,1.0,34.0,5.0


## `SurvData` class
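Helper `DataFrame` subclasses returned by `Surv`. Alongside the outcome columns, each one carries the Stan input data (`stan_data`) and the id-decoding metadata (`meta_data`).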

In [70]:
class SurvData(pd.DataFrame):
    ''' DataFrame holding a wide-format survival outcome plus model metadata.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to ``pd.DataFrame``.
    stan_data : dict, optional
        Data to pass on to the Stan model. A new empty dict when omitted.
    meta_data : dict, optional
        Tables used to decode integer ids. A new empty dict when omitted.
    '''
    survival_type = 'wide'
    # Declare the extra attributes as pandas metadata so assigning them is
    # treated as attribute storage rather than an attempted column creation.
    _metadata = ['stan_data', 'meta_data']

    def __init__(self, *args, stan_data=None, meta_data=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Build fresh dicts per instance; a mutable default would be shared
        # by (and mutated through) every frame created without these args.
        self.stan_data = {} if stan_data is None else stan_data
        self.meta_data = {} if meta_data is None else meta_data

class LongSurvData(SurvData):
    ''' SurvData in long format: timepoint_id, event_status & subject_id columns '''
    survival_type = 'long'

class NotValidId(ValueError):
    ''' Class of errors pertaining to invalid Id variables '''


## `Surv` class

In [71]:
class Surv(object):
    ''' patsy stateful transform encoding a survival outcome.

    Used as ``surv(time=..., event_status=..., [subject=...], [group=...])``
    inside a formula. With ``subject`` the result is a ``LongSurvData`` with
    integer ``timepoint_id``/``subject_id`` columns; without it the result is a
    wide ``SurvData`` with ``time``/``event_status`` columns. ``group`` adds an
    integer ``group_id`` column in either case.
    '''
    def __init__(self):
        self.subject_id = Id('subject')
        self.timepoint_id = Id('timepoint')
        self.group_id = Id('group')
        self._type = None      # 'long' or 'wide'; set by memorize_chunk
        self._grouped = None   # whether a `group` argument was supplied

    def _check_kwargs(self, **kwargs):
        ''' Return kwargs as a dict; raise ValueError naming any unsupported key. '''
        allowed_kwargs = ('subject', 'group')
        bad_keys = [key for key in kwargs if key not in allowed_kwargs]
        if bad_keys:
            raise ValueError('Invalid parameter: {}'.format(','.join(bad_keys)))
        return dict(kwargs)

    def memorize_chunk(self, time, event_status, **kwargs):
        ''' Record the ids (subject, timepoint, group) present in one chunk. '''
        kwargs = self._check_kwargs(**kwargs)
        if 'subject' in kwargs:
            self._type = 'long'
            self.subject_id.memorize_chunk(kwargs['subject'])
            self.timepoint_id.memorize_chunk(time)
        else:
            self._type = 'wide'
        if 'group' in kwargs:
            self._grouped = True
            self.group_id.memorize_chunk(kwargs['group'])
        else:
            self._grouped = False

    def memorize_finish(self):
        ''' Finalize the id lookup tables after all chunks have been seen. '''
        self.subject_id.memorize_finish()
        self.group_id.memorize_finish()
        self.timepoint_id.memorize_finish()

    def _prep_timepoint_standata(self, timepoint_df):
        ''' Build Stan inputs t_dur, t_obs and T from the timepoint decode table. '''
        unique_timepoints = survivalstan.survivalstan._prep_timepoint_dataframe(
            timepoint_df,
            timepoint_id_col='id',
            timepoint_end_col='value')
        timepoint_input_data = {
            't_dur': unique_timepoints['t_dur'],
            't_obs': unique_timepoints['value'],
            'T': len(unique_timepoints.index)
        }
        return timepoint_input_data

    def _build_frame(self, columns, group_id, index_source, frame_class, **kwargs):
        ''' Assemble outcome columns (plus optional group_id) into `frame_class`. '''
        if group_id is not None:
            columns['group_id'] = group_id
        dm = pd.DataFrame(columns)
        # Keep the caller's row labels so outcome rows stay aligned with the
        # input data; plain arrays have no index and keep the default one.
        if isinstance(index_source, pd.Series):
            dm.index = index_source.index
        return frame_class(dm, **kwargs)

    def _prep_long(self, timepoint_id, event_status, subject_id, group_id=None, **kwargs):
        ''' Long-format outcome: one row per subject/timepoint observation. '''
        columns = {'timepoint_id': timepoint_id,
                   'event_status': event_status,
                   'subject_id': subject_id}
        return self._build_frame(columns, group_id, event_status, LongSurvData, **kwargs)

    def _prep_wide(self, time, event_status, group_id=None, **kwargs):
        ''' Wide-format outcome: raw event time and status per row. '''
        columns = {'time': time,
                   'event_status': event_status}
        return self._build_frame(columns, group_id, time, SurvData, **kwargs)

    def transform(self, time, event_status, **kwargs):
        ''' Encode the outcome using the memorized id tables.

        Returns LongSurvData when memorized as 'long', SurvData when 'wide'.
        Both carry `stan_data` and `meta_data` dicts describing the ids used.
        '''
        kwargs = self._check_kwargs(**kwargs)
        meta_data = dict()
        stan_data = dict()
        subject_id = timepoint_id = None
        if 'subject' in kwargs:
            subject_id = self.subject_id.transform(kwargs['subject'])
            timepoint_id = self.timepoint_id.transform(time)
            meta_data.update({'timepoint_id': self.timepoint_id.decode_df(),
                              'subject_id': self.subject_id.decode_df()})
            stan_data.update(self._prep_timepoint_standata(self.timepoint_id.decode_df()))
            stan_data.update({'S': self.subject_id.len()})
        if 'group' in kwargs:
            group_id = self.group_id.transform(kwargs['group'])
            # Id.len() already returns the count; wrapping it in len() raised TypeError.
            stan_data.update({'G': self.group_id.len()})
            meta_data.update({'group_id': self.group_id.decode_df()})
        else:
            group_id = None

        if self._type == 'long':
            return self._prep_long(timepoint_id=timepoint_id, event_status=event_status,
                                   subject_id=subject_id, group_id=group_id,
                                   meta_data=meta_data, stan_data=stan_data)
        elif self._type == 'wide':
            return self._prep_wide(time=time, event_status=event_status, group_id=group_id,
                                   meta_data=meta_data, stan_data=stan_data)

surv = patsy.stateful_transform(Surv)

In [72]:
df.head()

Unnamed: 0,subject_id,time,event_value,event_name,X1,X2
0,0,4.288769,1,death,1,0
1,1,1.006545,1,death,0,0
2,2,0.032628,1,death,1,1
3,3,5.5,0,death,1,1
4,4,0.55538,1,death,0,0


In [73]:
surv(time=df['time'], event_status=df['event_value']).head()

Unnamed: 0,event_status,time
0,1,4.288769
1,1,1.006545
2,1,0.032628
3,0,5.5
4,1,0.55538


In [74]:
surv(time=df['time'], event_status=df['event_value'], subject=df['subject_id']).head()

Unnamed: 0,event_status,subject_id,timepoint_id
0,1,1,64
1,1,2,45
2,1,3,9
3,0,4,67
4,1,5,34


In [75]:
y, X = patsy.dmatrices('surv(time=time, event_status=event_value) ~ X1', data=df)

In [76]:
pd.DataFrame(y).head()

Unnamed: 0,0,1
0,1.0,4.288769
1,1.0,1.006545
2,1.0,0.032628
3,0.0,5.5
4,1.0,0.55538


In [77]:
y2, X = patsy.dmatrices('surv(time=time, event_status=event_value, subject=subject_id) ~ X1', data=df)

In [78]:
pd.DataFrame(y2).head()

Unnamed: 0,0,1,2
0,1.0,1.0,64.0
1,1.0,2.0,45.0
2,1.0,3.0,9.0
3,0.0,4.0,67.0
4,1.0,5.0,34.0


In [79]:
pd.DataFrame(y2).tail()

Unnamed: 0,0,1,2
95,0.0,96.0,67.0
96,0.0,97.0,67.0
97,1.0,98.0,40.0
98,1.0,99.0,21.0
99,1.0,100.0,58.0


Test whether the memorized class/id labels are retained when building design matrices for new data:

In [80]:
(y2.new, X.new) = patsy.build_design_matrices([y2.design_info, X.design_info], df.tail()) 

In [81]:
pd.DataFrame(y2.new).head()

Unnamed: 0,0,1,2
0,0.0,96.0,67.0
1,0.0,97.0,67.0
2,1.0,98.0,40.0
3,1.0,99.0,21.0
4,1.0,100.0,58.0


### Can we get meta-information about type?

In [82]:
y, X = patsy.dmatrices('surv(event_status=event_value, time=time) ~ X1', data=df)

In [83]:
pd.DataFrame(y).head()

Unnamed: 0,0,1
0,1.0,4.288769
1,1.0,1.006545
2,1.0,0.032628
3,0.0,5.5
4,1.0,0.55538


In [84]:
y, X = patsy.dmatrices('surv(event_status=event_value,time=as_id(time), subject=as_id(subject_id)) ~ bs(X1, 3)',
                       data=df)

In [85]:
isinstance(y, LongSurvData)

False

In [86]:
y.shape

(100, 3)

In [87]:
y.design_info.term_names

['surv(event_status=event_value, time=as_id(time), subject=as_id(subject_id))']

In [88]:
X.design_info.factor_infos

{EvalFactor('bs(X1, 3)'): FactorInfo(factor=EvalFactor('bs(X1, 3)'),
            type='numerical',
            state=<factor state>,
            num_columns=3)}

In [89]:
y.design_info

DesignInfo(['surv(event_status=event_value, time=as_id(time), subject=as_id(subject_id))[0]',
            'surv(event_status=event_value, time=as_id(time), subject=as_id(subject_id))[1]',
            'surv(event_status=event_value, time=as_id(time), subject=as_id(subject_id))[2]'],
           factor_infos={EvalFactor('surv(event_status=event_value, time=as_id(time), subject=as_id(subject_id))'): FactorInfo(factor=EvalFactor('surv(event_status=event_value, time=as_id(time), subject=as_id(subject_id))'),
                                    type='numerical',
                                    state=<factor state>,
                                    num_columns=3)},
           term_codings=OrderedDict([(Term([EvalFactor('surv(event_status=event_value, time=as_id(time), subject=as_id(subject_id))')]),
                                      [SubtermInfo(factors=(EvalFactor('surv(event_status=event_value, time=as_id(time), subject=as_id(subject_id))'),),
                                     

In [90]:
try:
    y.survival_type
except AttributeError as e:
    print(str(e))

'DesignMatrix' object has no attribute 'survival_type'


In [91]:
y.design_info

DesignInfo(['surv(event_status=event_value, time=as_id(time), subject=as_id(subject_id))[0]',
            'surv(event_status=event_value, time=as_id(time), subject=as_id(subject_id))[1]',
            'surv(event_status=event_value, time=as_id(time), subject=as_id(subject_id))[2]'],
           factor_infos={EvalFactor('surv(event_status=event_value, time=as_id(time), subject=as_id(subject_id))'): FactorInfo(factor=EvalFactor('surv(event_status=event_value, time=as_id(time), subject=as_id(subject_id))'),
                                    type='numerical',
                                    state=<factor state>,
                                    num_columns=3)},
           term_codings=OrderedDict([(Term([EvalFactor('surv(event_status=event_value, time=as_id(time), subject=as_id(subject_id))')]),
                                      [SubtermInfo(factors=(EvalFactor('surv(event_status=event_value, time=as_id(time), subject=as_id(subject_id))'),),
                                     

## SurvivalFactor class

In [92]:
import logging
logger = logging.getLogger(__name__)
class SurvivalFactor(patsy.EvalFactor):
    ''' A patsy factor that encodes the LHS of a survival model.

    Behaves like ``patsy.EvalFactor``, but each evaluation records details of
    the outcome object so they can be read back from
    ``design_info.terms[i].factors[j]`` after the design matrix is built:

    - ``_class``: class of the most recently evaluated outcome.
    - ``_type``, ``_meta_data``, ``_stan_data``: copied from the outcome's
      ``survival_type``, ``meta_data`` and ``stan_data`` when it is a
      ``SurvData``; ``None`` until such an outcome has been evaluated.
    '''
    _is_survival = True

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Initialize every recorded attribute so reading them before a
        # SurvData outcome is seen gives None instead of AttributeError.
        self._class = None
        self._type = None
        self._meta_data = None
        self._stan_data = None

    def eval(self, *args, **kwargs):
        ''' Evaluate via patsy, then record outcome class and survival metadata. '''
        result = super().eval(*args, **kwargs)
        self._class = type(result)
        if isinstance(result, SurvData):
            self._type = result.survival_type
            self._meta_data = result.meta_data
            self._stan_data = result.stan_data
        else:
            logger.debug('Survival outcome is %s, not SurvData; no metadata recorded',
                         self._class.__name__)
        return result
        
    

In [93]:
a = SurvivalFactor(code='surv(time=time, event_status=event_value)')

In [94]:
print(a._class)

None


In [95]:
md = patsy.ModelDesc([patsy.Term([a])],[])

In [96]:
y, X = patsy.dmatrices(md, data=df)

In [97]:
y.shape

(100, 2)

In [98]:
term = y.design_info.terms[0]

In [99]:
term.factors[0]._is_survival

True

In [100]:
term.factors[0]._class

__main__.SurvData

In [101]:
term.factors[0]._type

'wide'

## SurvivalModelDesc class

Next, we need a way to make sure that the `SurvivalFactor` type is used by default for the left-hand side of the formula.

In [102]:
import re

In [103]:
class SurvivalModelDesc(object):
    ''' Formula wrapper that makes patsy use SurvivalFactor for the LHS.

    The formula is split on the first ``~``: the whole left-hand side becomes a
    single ``SurvivalFactor`` term, while the right-hand side is parsed by
    ``patsy.ModelDesc.from_formula`` as usual.

    Raises
    ------
    ValueError
        If the formula has no ``~`` or an empty left-hand side.
    '''

    def __init__(self, formula):
        self.formula = formula
        lhs, sep, rhs = formula.partition('~')
        if not sep or not lhs.strip():
            raise ValueError(
                'Survival formula must have the form "<outcome> ~ <predictors>", got: {!r}'.format(formula))
        self.lhs, self.rhs = lhs, rhs
        self.lhs_termlist = [patsy.Term([SurvivalFactor(self.lhs)])]
        self.rhs_termlist = patsy.ModelDesc.from_formula(self.rhs).rhs_termlist

    def __patsy_get_model_desc__(self, eval_env):
        ''' patsy hook: return the ModelDesc built in __init__ (eval_env unused). '''
        return patsy.ModelDesc(self.lhs_termlist, self.rhs_termlist)

### confirm we can determine type

In [104]:
my_formula = SurvivalModelDesc('surv(time=time, event_status=event_value) ~ X1')
y, X = patsy.dmatrices(my_formula, data=df)

In [105]:
pd.DataFrame(y).head()

Unnamed: 0,0,1
0,1.0,4.288769
1,1.0,1.006545
2,1.0,0.032628
3,0.0,5.5
4,1.0,0.55538


In [106]:
## should only be one LHS term
assert(len(y.design_info.terms) == 1)

## should only be one LHS factor (within single term)
assert(len(y.design_info.terms[0].factors) == 1)

## LHS factor should be of type "survival"
assert(y.design_info.terms[0].factors[0]._is_survival == True)

In [107]:
# get type of LHS term
y.design_info.terms[0].factors[0]._class

__main__.SurvData

In [108]:
survival_type = y.design_info.terms[0].factors[0]._type
survival_type

'wide'

### confirm we can extract stan & meta-data

In [109]:
my_formula2 = SurvivalModelDesc('surv(time=time, event_status=event_value, subject=subject_id) ~ X1')
y2, X = patsy.dmatrices(my_formula2, data=df)

In [110]:
## should only be one LHS term
assert(len(y2.design_info.terms) == 1)

## should only be one LHS factor (within single term)
assert(len(y2.design_info.terms[0].factors) == 1)

## LHS factor should be of type "survival"
assert(y2.design_info.terms[0].factors[0]._is_survival == True)

In [111]:
y2.design_info.terms[0].factors[0]._class

__main__.LongSurvData

In [112]:
survival_type = y2.design_info.terms[0].factors[0]._type
survival_type

'long'

In [113]:
stan_data = y2.design_info.terms[0].factors[0]._stan_data
stan_data.keys()

dict_keys(['t_obs', 'S', 't_dur', 'T'])

In [114]:
meta_data = y2.design_info.terms[0].factors[0]._meta_data
meta_data.keys()

dict_keys(['timepoint_id', 'subject_id'])

In [116]:
meta_data['subject_id'].head()

Unnamed: 0,id,value
0,1,0
1,2,1
2,3,2
3,4,3
4,5,4


### test model-matrix design on newdata

In [117]:
y2.design_info

DesignInfo(['surv(time=time, event_status=event_value, subject=subject_id)[0]',
            'surv(time=time, event_status=event_value, subject=subject_id)[1]',
            'surv(time=time, event_status=event_value, subject=subject_id)[2]'],
           factor_infos={SurvivalFactor('surv(time=time, event_status=event_value, subject=subject_id)'): FactorInfo(factor=SurvivalFactor('surv(time=time, event_status=event_value, subject=subject_id)'),
                                    type='numerical',
                                    state=<factor state>,
                                    num_columns=3)},
           term_codings=OrderedDict([(Term([SurvivalFactor('surv(time=time, event_status=event_value, subject=subject_id)')]),
                                      [SubtermInfo(factors=(SurvivalFactor('surv(time=time, event_status=event_value, subject=subject_id)'),),
                                                   contrast_matrices={},
                                              

In [118]:
y2.new, X.new = patsy.build_design_matrices(design_infos=[y2.design_info, X.design_info], data=df.tail())

In [119]:
pd.DataFrame(y2.new).head()

Unnamed: 0,0,1,2
0,0.0,96.0,67.0
1,0.0,97.0,67.0
2,1.0,98.0,40.0
3,1.0,99.0,21.0
4,1.0,100.0,58.0
