In [None]:
import numpy as np
import pandas as pd
import cvxpy as cp
import itertools as it

import pickle
import glob

from sksurv.nonparametric import kaplan_meier_estimator
from sksurv.nonparametric import nelson_aalen_estimator

from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.ensemble import RandomSurvivalForest

import matplotlib.pyplot as plt

from matplotlib.colors import Normalize, to_hex
from matplotlib.cm import get_cmap

# Utils

In [None]:
def time_gen(x, beta, l, k, max_time=1000, death_prob=0.1):
    """Draw one synthetic (status, time) pair for the feature vector ``x``.

    The time is obtained from a uniform draw ``v`` as
    ``(-ln(v) / (l * exp(x . beta))) ** (1 / k)``; it is rounded to the
    nearest integer, or replaced by ``max_time`` when it exceeds that cap.

    Parameters
    ----------
    x, beta : array-like
        Feature vector and coefficient vector (combined with ``np.dot``).
    l, k : float
        Scale and exponent parameters of the formula above.
    max_time : float
        Upper cap applied to the generated time.
    death_prob : float
        Probability that the returned status flag is ``False``; the flag is
        ``True`` with probability ``1 - death_prob``.

    Returns
    -------
    tuple
        ``(event, time)`` with ``event`` a Python ``bool``.
    """
    # The uniform draw must stay first to keep the random stream order.
    uniform_draw = np.random.rand()
    hazard_scale = l * np.exp(np.dot(x, beta))
    raw_time = np.power(-np.log(uniform_draw) / hazard_scale, 1 / k)

    if raw_time > max_time:
        time = max_time
    else:
        time = np.round(raw_time)

    status_draw = np.random.choice([0, 1], 1, p=[death_prob, 1 - death_prob])
    event = bool(status_draw)

    return (event, time)


def data_gen(n_points, n_features, l, k, max_time):
    """Generate a synthetic survival data set.

    Coefficients are drawn uniformly from [-1, 1] and rounded to two
    decimals, features are drawn uniformly from [0, 1], and every row gets a
    (status, time) pair from ``time_gen`` (called with its default
    ``death_prob``).

    Returns
    -------
    beta_true : ndarray of shape (n_features,)
    X : ndarray of shape (n_points, n_features)
    y : structured ndarray with fields ``Status`` and ``Survival_in_days``
    """
    record_dtype = [('Status', '?'), ('Survival_in_days', '<f8')]

    beta_true = np.round(np.random.uniform(-1, 1, n_features), 2)
    X = np.random.uniform(0, 1, size=(n_points, n_features))

    records = []
    for row in X:
        records.append(time_gen(row, beta_true, l, k, max_time))
    y = np.array(records, dtype=record_dtype)

    return beta_true, X, y


def uniform_within_ball(origin, radius, n_points):
    """Sample ``n_points`` points uniformly from a ball.

    Directions come from normalised standard-normal vectors and radii from
    ``U ** (1 / dim)`` with ``U`` uniform on [0, 1).

    Parameters
    ----------
    origin : array-like
        Centre of the ball; its number of elements fixes the dimension.
    radius : float
        Radius of the ball.
    n_points : int
        Number of points to draw; ``0`` yields an empty ``(0, dim)`` array.

    Returns
    -------
    ndarray of shape (n_points, dim)
    """
    origin = np.asarray(origin)
    dim = origin.size

    # Draw order (normals first, then uniforms) matches the per-point loop
    # this replaces, so seeded runs give the same points.
    gauss_points = np.random.normal(loc=0, scale=1, size=(n_points, dim))
    norms = np.linalg.norm(gauss_points, axis=1, keepdims=True)
    unf_points = np.power(np.random.uniform(low=0, high=1, size=(n_points, 1)), 1 / dim)

    # Broadcasting replaces a Python loop over every point.
    return origin + radius * (gauss_points / norms * unf_points)


def dichotomy(func, x_min=-1000, x_max=1000, tol=1e-9):
    """Locate a root of ``func`` on ``[x_min, x_max]`` by bisection.

    Parameters
    ----------
    func : callable
        Scalar function of one argument.
    x_min, x_max : float
        Ends of the search interval.
    tol : float
        The search stops once the interval is no wider than ``tol``.

    Returns
    -------
    float
        The midpoint at which ``func`` is exactly zero if one is hit,
        otherwise the midpoint of the final interval. If ``func`` has the
        same sign over the whole interval the result collapses towards
        ``x_min`` (no error is raised).
    """
    f_max = None  # func(x_max), evaluated lazily and then kept in sync

    while x_max - x_min > tol:
        point = 0.5 * (x_min + x_max)

        # The interval can no longer shrink (tol is below float resolution).
        if point <= x_min or point >= x_max:
            break

        f_val = func(point)

        # An exact hit must end the search: keeping it as the upper end
        # makes every later product zero and loses the root.
        if f_val == 0:
            return point

        if f_max is None:
            f_max = func(x_max)

        if f_val * f_max < 0:
            x_min = point
        else:
            x_max = point
            f_max = f_val  # reuse the value instead of re-evaluating func

    return 0.5 * (x_min + x_max)


def expand(args, values, new_args, zero_val, increasing=True):
    """Evaluate a right-continuous step function on a new grid.

    ``values[i]`` is the value taken from ``args[i]`` onwards. Every entry of
    ``new_args`` receives the value of the last ``args`` entry not exceeding
    it, or ``zero_val`` when it lies before ``args[0]``.

    Parameters
    ----------
    args : array-like, sorted
        Jump positions of the step function.
    values : ndarray
        Values attached to ``args``.
    new_args : array-like
        Positions at which the function is evaluated.
    zero_val : scalar
        Value used before the first jump.
    increasing : bool
        When true, results below ``zero_val`` are raised to ``zero_val``.

    Returns
    -------
    ndarray
    """
    positions = np.searchsorted(args, new_args, side='right')
    before_first_jump = positions == 0

    # positions - 1 is -1 before the first jump; that lookup is discarded.
    new_values = np.where(before_first_jump, zero_val, values[positions - 1])

    if not increasing:
        return new_values

    below_floor = new_values < zero_val
    return np.where(below_floor, zero_val, new_values)


def time_unification(pred_times, pred_data, pred_type, y_train, zero_val):
    """Re-evaluate predicted step functions on the training time grid.

    The grid is ``0`` followed by the sorted unique second components of
    ``y_train``. Each row of ``pred_data`` is resampled with ``expand``;
    cumulative hazards ('chf') are additionally floored at ``zero_val``.

    Parameters
    ----------
    pred_times : array-like
        Time points the predictions are defined on.
    pred_data : ndarray
        One prediction (1-D) or one prediction per row (2-D).
    pred_type : {'sf', 'chf'}
    y_train : iterable of (status, time) pairs
    zero_val : scalar
        Value of the function before the first prediction time.

    Returns
    -------
    times : ndarray
    pred_data : ndarray
        1-D when exactly one prediction was resampled, otherwise 2-D.

    Raises
    ------
    ValueError
        If ``pred_type`` is neither 'sf' nor 'chf'.
    """
    if pred_type not in ('sf', 'chf'):
        raise ValueError('Models can produce either a survival function (sf) or a cumulative hazard function (chf)')

    observed_times = np.unique([t for _, t in y_train])
    times = np.array([0] + observed_times.tolist())

    rows = pred_data.reshape(1, -1) if pred_data.ndim == 1 else pred_data

    # Only cumulative hazards get the lower floor at zero_val.
    floor_at_zero_val = pred_type != 'sf'
    unified = np.array([
        expand(pred_times, row, times, zero_val, increasing=floor_at_zero_val)
        for row in rows
    ])

    # A single prediction is handed back as a flat vector.
    if unified.shape[0] == 1:
        unified = unified[0]

    return times, unified


def non_param_predictions(y_train, pred_type):
    """Fit a non-parametric estimator to the training outcomes.

    Uses ``kaplan_meier_estimator`` for a survival function ('sf') and
    ``nelson_aalen_estimator`` for a cumulative hazard function ('chf').

    Parameters
    ----------
    y_train : iterable of (status, time) pairs
    pred_type : {'sf', 'chf'}

    Returns
    -------
    pred_times, pred_data
        The two arrays returned by the chosen estimator.

    Raises
    ------
    ValueError
        If ``pred_type`` is neither 'sf' nor 'chf'.
    """
    if pred_type not in ('sf', 'chf'):
        raise ValueError('Models can produce either a survival function (sf) or a cumulative hazard function (chf)')

    train_events = np.array([status for status, _ in y_train])
    train_times = np.array([duration for _, duration in y_train])

    estimator = kaplan_meier_estimator if pred_type == 'sf' else nelson_aalen_estimator
    pred_times, pred_data = estimator(train_events, train_times)

    return pred_times, pred_data


def model_predictions(X_test, pred_type, model, model_type):
    """Predict survival or cumulative hazard functions as plain arrays.

    Parameters
    ----------
    X_test : array-like of shape (n_samples, n_features)
    pred_type : {'sf', 'chf'}
        Selects ``predict_survival_function`` or
        ``predict_cumulative_hazard_function`` of ``model``.
    model : fitted estimator exposing those two methods
    model_type : {'cox', 'rsf'}
        Validated only; the output format is detected from the prediction
        itself.

    Returns
    -------
    pred_times : ndarray
        Time points the predictions are defined on.
    pred_data : ndarray of shape (n_samples, n_times)

    Raises
    ------
    NotImplementedError
        If ``model_type`` is neither 'cox' nor 'rsf'.
    ValueError
        If ``pred_type`` is neither 'sf' nor 'chf'.
    """
    if model_type not in ('cox', 'rsf'):
        raise NotImplementedError('Cox and RSF models are only implemented')

    if pred_type not in ('sf', 'chf'):
        raise ValueError('Models can produce either a survival function (sf) or a cumulative hazard function (chf)')

    if pred_type == 'sf':
        predictions = model.predict_survival_function(X_test)
    else:
        predictions = model.predict_cumulative_hazard_function(X_test)

    # Some model versions return a numeric (n_samples, n_times) array, others
    # an object array of step functions with attributes x, y, a (and b).
    # NOTE(review): this matches the scikit-survival API history as far as I
    # know (return_array default, event_times_ -> unique_times_) - confirm
    # against the installed version.
    is_raw_array = isinstance(predictions, np.ndarray) and predictions.dtype != object

    if is_raw_array:
        pred_times = getattr(model, 'unique_times_', None)
        if pred_times is None:
            pred_times = model.event_times_
        pred_data = predictions.copy()
    else:
        pred_times = predictions[0].x
        # A step function evaluates to a * y + b; b is 0 unless set.
        pred_data = np.array([
            prediction.a * prediction.y + getattr(prediction, 'b', 0.0)
            for prediction in predictions
        ])

    return pred_times, pred_data


def mtime_computation(X_test, y_train, model, model_type, t_gamma):
    """Integrate the predicted survival functions over the training time grid.

    For every row of ``X_test`` the predicted survival function is resampled
    on the grid produced by ``time_unification`` and summed as
    ``sum_i S(t_i) * (t_{i+1} - t_i)``, with ``t_gamma`` as the width of the
    last interval (presumably an estimate of the mean time to event).

    Returns
    -------
    ndarray or scalar
        One value per row of ``X_test``; a scalar when ``time_unification``
        returns a single flat prediction.
    """
    surv_times, surv_values = model_predictions(X_test, 'sf', model, model_type)

    # zero_val = 1: a survival function equals one before the first time.
    grid, surv_on_grid = time_unification(surv_times, surv_values, 'sf', y_train, 1)

    interval_widths = np.array((grid[1:] - grid[:-1]).tolist() + [t_gamma])

    return np.dot(surv_on_grid, interval_widths)


def limiter(particles, instance, r_clo, hcube):
    """Force particles into a ball and then into a bounding box.

    Each particle farther than ``r_clo`` from ``instance`` is pulled back
    along its ray onto the sphere of radius ``r_clo``; afterwards every
    coordinate is clipped to ``[hcube[0], hcube[1]]``.

    Parameters
    ----------
    particles : ndarray of shape (n_particles, dim)
    instance : ndarray of shape (dim,)
        Centre of the ball.
    r_clo : float
        Radius of the ball.
    hcube : ndarray of shape (2, dim)
        Row 0 holds the lower bounds, row 1 the upper bounds.

    Returns
    -------
    ndarray of shape (n_particles, dim)
    """
    offsets = particles - instance
    r_part = np.linalg.norm(offsets, axis=1)
    r_lim = np.minimum(r_part, r_clo)

    # A particle sitting exactly on `instance` has zero length; keep its
    # scale at 1 instead of computing 0 / 0, which would turn it into NaN.
    scale = np.ones_like(r_part)
    np.divide(r_lim, r_part, out=scale, where=r_part > 0)

    projected = instance + scale[:, np.newaxis] * offsets

    return np.clip(projected, hcube[0], hcube[1])


def ver_solution(instance, theta, margin, y_train, model, model_type, z_clo, hcube, t_gamma, n_rpoints=1000, n_batch=1000):
    """Compute a reference ("verification") counterfactual for ``instance``.

    A point ``z`` is feasible when
    ``margin - theta * (mtime(z) - mtime(instance)) <= 0``.

    * 'cox': the feasibility condition is turned into a linear constraint on
      ``model.coef_ @ z`` (threshold found with ``dichotomy``) and the point
      closest to ``instance`` inside ``hcube`` is obtained with cvxpy.
    * 'rsf': ``n_rpoints`` random points are drawn in the ball around
      ``instance`` that passes through ``z_clo``, restricted to ``hcube``,
      and the feasible one closest to ``instance`` is returned.

    Parameters
    ----------
    instance : ndarray of shape (dim,)
    theta : {1, -1}
        Sign of the requested change of the mean time.
    margin : float
        Required size of that change.
    y_train : structured array of training outcomes
    model : fitted survival model
    model_type : {'cox', 'rsf'}
    z_clo : ndarray of shape (dim,)
        Fallback point; also defines the search radius for 'rsf'.
    hcube : ndarray of shape (2, dim)
        Lower and upper feature bounds.
    t_gamma : float
        Width of the last integration interval (see ``mtime_computation``).
    n_rpoints, n_batch : int
        Number of random points and prediction batch size ('rsf' only).

    Returns
    -------
    ndarray of shape (dim,)
        ``z_clo`` when no feasible point is found or the solver returns no
        solution.

    Raises
    ------
    NotImplementedError
        If ``model_type`` is neither 'cox' nor 'rsf'.
    """
    inst_mtime = mtime_computation(instance.reshape(1, -1), y_train, model, model_type, t_gamma)

    if model_type == 'cox':

        print(f'Computing verification solution: {model_type}')

        model_beta = model.coef_
        model_times = model.baseline_survival_.x
        model_base_surv = model.baseline_survival_.y * model.baseline_survival_.a

        model_measures = np.array(list(model_times[1:] - model_times[:-1]) + [t_gamma])
        # Mean time as a function of the linear predictor u = beta . z.
        model_mtime = lambda x: model_times[0] + np.dot(model_measures, model_base_surv ** np.exp(x))

        # Linear-predictor value at which the margin is met exactly.
        u_zero = dichotomy(lambda u: margin - theta * (model_mtime(u) - inst_mtime))

        z = cp.Variable(instance.size)

        obj = cp.norm(z - instance)
        cons = [
            z >= hcube[0],
            z <= hcube[1],
            theta * (model_beta @ z - u_zero) <= 0
        ]

        prob = cp.Problem(cp.Minimize(obj), cons)
        prob.solve()

        # z.value is None when the solver found no solution (e.g. the
        # problem is infeasible); use the same fallback as the RSF branch.
        if z.value is None:
            print(f'\tsolver status: {prob.status}; falling back to z_clo')
            return z_clo

        return z.value

    elif model_type == 'rsf':

        print(f'Computing verification solution: {model_type}')

        r_clo = np.linalg.norm(z_clo - instance)

        rpoints = uniform_within_ball(instance, r_clo, n_rpoints)
        rpoints = limiter(rpoints, instance, r_clo, hcube)

        mask_rp = []

        for start in range(0, n_rpoints, n_batch):

            end = min(start + n_batch, n_rpoints)
            print(f'\tbatch: {start}:{end}')

            # A one-row batch yields a scalar; atleast_1d keeps extend() valid.
            rp_mtimes = np.atleast_1d(
                mtime_computation(rpoints[start:end], y_train, model, model_type, t_gamma)
            )
            mask_rp.extend(margin - theta * (rp_mtimes - inst_mtime) <= 0)

        mask_rp = np.array(mask_rp, dtype=bool)

        if np.any(mask_rp):
            feasible = rpoints[mask_rp]
            z_ver = feasible[np.argmin(np.linalg.norm(feasible - instance, axis=1))]
        else:
            z_ver = z_clo

        return z_ver

    else:

        raise NotImplementedError('Cox and RSF models are only implemented')

# Particle Swarm Optimization

In [None]:
def objective(z, z_mtime, instance, inst_mtime, theta, margin, mcoef=1e+6):
    """Penalised distance used as the swarm fitness.

    The value is ``||z - instance||`` plus ``mcoef`` times the constraint
    violation ``margin - theta * (z_mtime - inst_mtime)``; the penalty weight
    is zero when the violation is negative.

    Parameters
    ----------
    z, instance : ndarray
        Candidate point and explained instance.
    z_mtime, inst_mtime : float
        Mean-time values of the candidate and of the instance.
    theta : {1, -1}
        Sign of the requested change.
    margin : float
        Required size of the change.
    mcoef : float
        Penalty weight for violated constraints.

    Returns
    -------
    float
    """
    distance = np.linalg.norm(z - instance)
    violation = margin - theta * (z_mtime - inst_mtime)

    if violation < 0:
        weight = 0
    else:
        weight = mcoef

    return distance + weight * violation


def pso_optimization(obj_func, mtime_func,
                     instance, z_clo, r_clo, hcube,
                     max_iter=1000, n_particles=2000,
                     w=0.729, c1=1.4945, c2=1.4945,
                     verbose=20, theta=None, inst_mtime=None):
    """Minimise ``obj_func`` with a particle swarm confined by ``limiter``.

    The swarm starts from ``z_clo`` plus ``n_particles - 1`` random points of
    the ball of radius ``r_clo`` around ``instance`` and is pushed back into
    that ball and into ``hcube`` after every move.

    Parameters
    ----------
    obj_func : callable(particle, particle_mtime) -> float
    mtime_func : callable(particles) -> mean-time values, one per row
    instance, z_clo : ndarray of shape (dim,)
    r_clo : float
    hcube : ndarray of shape (2, dim)
    max_iter : int
        Number of swarm updates after the initial evaluation.
    n_particles : int
    w, c1, c2 : float
        Inertia, cognitive and social coefficients.
    verbose : int
        Progress is printed every ``verbose`` iterations (0 disables it).
    theta, inst_mtime : optional
        Sign of the requested change and mean time of ``instance``; used only
        to report the achieved margin. When omitted they are read from the
        module-level variables of the same name, which is what this function
        originally relied on implicitly.

    Returns
    -------
    best_margin, best_value, best_global
        Achieved margin, objective value and position of the best particle.

    Raises
    ------
    NameError
        If ``theta`` or ``inst_mtime`` is neither passed nor defined globally.
    """
    scope = globals()
    for key, given in (('theta', theta), ('inst_mtime', inst_mtime)):
        if given is None and key not in scope:
            raise NameError(f"name '{key}' is not defined; pass it to pso_optimization explicitly")
    theta = scope['theta'] if theta is None else theta
    inst_mtime = scope['inst_mtime'] if inst_mtime is None else inst_mtime

    def evaluate(points):
        # Objective value of every particle.
        return np.array([obj_func(p, m) for p, m in zip(points, mtime_func(points))])

    velocities = np.zeros((n_particles, instance.size))
    particles = np.vstack([z_clo, uniform_within_ball(instance, r_clo, n_particles - 1)])
    particles = limiter(particles, instance, r_clo, hcube)

    best_particles = particles.copy()
    best_values = evaluate(particles)

    for iter_idx in range(max_iter + 1):

        if iter_idx > 0:
            # One random factor per particle (shared by all its coordinates),
            # drawn in the same order as a per-particle loop would draw them.
            rand_cognitive = np.random.rand(n_particles, 1)
            rand_social = np.random.rand(n_particles, 1)

            terms_cognitive = rand_cognitive * c1 * (best_particles - particles)
            terms_social = rand_social * c2 * (best_global - particles)

            velocities = w * velocities + terms_cognitive + terms_social
            particles = limiter(particles + velocities, instance, r_clo, hcube)

            values = evaluate(particles)

            improved = values < best_values
            best_particles[improved] = particles[improved]
            best_values[improved] = values[improved]

        best_idx = np.argmin(best_values)
        best_value = best_values[best_idx]
        best_global = best_particles[best_idx]

        report = bool(verbose) and (iter_idx % verbose == 0)

        # The margin costs one extra model evaluation, so it is computed only
        # when it is printed or returned.
        if report or iter_idx == max_iter:
            best_margin = theta * (mtime_func(best_global.reshape(1, -1)) - inst_mtime)

        if report:
            print('iter: {:04} | obj: {:.6e} | r: {:2.4f} | z_opt: {}'.format(iter_idx, best_value, best_margin, best_global))

    return best_margin, best_value, best_global

# Visualisation

In [None]:
def plot_result(info_dom, info_model, info_task, save_path,
                theta=None, margin=None, inst_mtime=None,
                train_mtimes=None, mask_train=None):
    """Plot the mean-time landscape and the counterfactuals of a 2-D task.

    The left panel shows the mean time over a 200 x 200 grid of the feature
    box together with the training points; the right panel shows only the
    feasible region, the instance, the search ball and the two solutions.
    The figure is written to ``save_path``.

    Parameters
    ----------
    info_dom : (X_train, y_train, hcube)
    info_model : (model, model_type, t_gamma)
    info_task : (instance, z_clo, z_ver, z_opt, mtime_min, mtime_max)
    save_path : str
    theta, margin, inst_mtime, train_mtimes, mask_train : optional
        Task description and per-training-point quantities. When omitted they
        are read from the module-level variables of the same name, which is
        what this function originally relied on implicitly.

    Raises
    ------
    NameError
        If one of the optional values is neither passed nor defined globally.
    """
    scope = globals()
    supplied = {
        'theta': theta,
        'margin': margin,
        'inst_mtime': inst_mtime,
        'train_mtimes': train_mtimes,
        'mask_train': mask_train,
    }
    for key, given in supplied.items():
        if given is None:
            if key not in scope:
                raise NameError(f"name '{key}' is not defined; pass it to plot_result explicitly")
            supplied[key] = scope[key]
    theta = supplied['theta']
    margin = supplied['margin']
    inst_mtime = supplied['inst_mtime']
    train_mtimes = supplied['train_mtimes']
    mask_train = supplied['mask_train']

    X_train, y_train, hcube = info_dom
    model, model_type, t_gamma = info_model
    instance, z_clo, z_ver, z_opt, mtime_min, mtime_max = info_task

    #####################################################

    U_train = X_train
    U_edges = np.array(list(it.product(*hcube.T)))
    u_inst, u_clo, u_ver, u_opt = instance, z_clo, z_ver, z_opt

    #####################################################

    # Regular grid over the feature box and its feasibility mask.
    U_grid = np.array(list(it.product(*[np.linspace(a, b, 200) for a, b in hcube.T])))

    grid_mtimes = mtime_computation(U_grid, y_train, model, model_type, t_gamma)
    mask_grid = margin - theta * (grid_mtimes - inst_mtime) <= 0

    #####################################################

    def afunc(p):
        # Polar angle of p in [0, 2*pi), used to order the box corners.
        angle = np.arccos(p[0] / np.linalg.norm(p))
        return angle if p[1] >= 0 else 2 * np.pi - angle

    U_edges = U_edges[np.argsort([afunc(p) for p in (U_edges - np.mean(U_edges, axis=0))])]
    U_edges = np.array(U_edges.tolist() + [U_edges[0]])  # close the outline

    u_radius = np.linalg.norm(u_inst - u_clo)
    angles = np.linspace(0, 2 * np.pi, 101)
    border = u_inst + u_radius * np.array([[np.sin(angle), np.cos(angle)] for angle in angles])

    #####################################################

    # Colour scale shared by both panels.
    mtime_min = np.min([mtime_min, np.min(grid_mtimes)])
    mtime_max = np.max([mtime_max, np.max(grid_mtimes)])

    cnorm = Normalize(mtime_min, mtime_max)

    #####################################################

    plt.figure(figsize=(15, 8))
    plt.rcParams.update({'font.size': 14})

    plt.subplot(1, 2, 1)

    plt.title(r'Values of $m(*)$')
    plt.scatter(U_grid[:, 0], U_grid[:, 1], s=60,
                c=grid_mtimes, cmap='jet', norm=cnorm)
    plt.scatter(U_train[:, 0], U_train[:, 1], s=60, edgecolor='black',
                c=train_mtimes, cmap='jet', norm=cnorm)
    plt.plot(U_edges[:, 0], U_edges[:, 1], c='black', ls='--', label=r'$\partial \scrX$')
    plt.legend(loc='upper left', framealpha=1)

    plt.colorbar()
    plt.xlabel(r'$x_{1}$')
    plt.ylabel(r'$x_{2}$')

    plt.subplot(1, 2, 2)

    plt.title(r'$m(\mathbf{x}) = $' + f'{np.round(inst_mtime, 2)}' + r' | $\theta = $' + f'{theta}' + r' | $r = $' + f'{np.round(margin, 2)}')
    plt.scatter(U_grid[mask_grid, 0], U_grid[mask_grid, 1], s=60,
                c=grid_mtimes[mask_grid], cmap='jet', norm=cnorm)
    plt.scatter(U_train[mask_train, 0], U_train[mask_train, 1], s=60, edgecolor='black',
                c=train_mtimes[mask_train], cmap='jet', norm=cnorm)
    plt.colorbar()
    plt.scatter(u_inst[0], u_inst[1], s=90, c='black', label=r'$\mathbf{x}$')
    plt.plot(U_edges[:, 0], U_edges[:, 1], c='black', ls='--', label=r'$\partial \scrX$')
    plt.plot(border[:, 0], border[:, 1], c='black', ls='-.', label=r'$\partial \scrB$')
    plt.scatter(u_opt[0], u_opt[1], s=90, marker='s', edgecolor='black',
                c='white', label=r'$\mathbf{z}_{opt}$')
    plt.scatter(u_ver[0], u_ver[1], s=90, marker='^', edgecolor='black',
                c='white', label=r'$\mathbf{z}_{ver}$')
    plt.xlabel(r'$x_{1}$')
    plt.ylabel(r'$x_{2}$')
    plt.legend(loc='upper left', framealpha=1)

    plt.tight_layout()
    plt.savefig(save_path, dpi=500)

# Data

## Generate data

In [None]:
# Set to True to regenerate (and overwrite) the synthetic data files.
flag = False

if flag:
    n_points, l, k, max_time = 1000, 1e-5, 2, 100000

    # One file per feature dimension: data/sdata_d02.pkl and data/sdata_d20.pkl.
    for n_features in (2, 20):
        beta_true, X, y = data_gen(n_points, n_features, l, k, max_time)

        with open('data/sdata_d{:02}.pkl'.format(n_features), 'wb') as f:
            pickle.dump([beta_true, X, y], f)

## Load Data

In [None]:
# `data` collects one [X, y] pair per entry of `names`, in the same order.
data = []
names = ['sdata_d02', 'sdata_d20', 'stanford2', 'myeloid_(trt_A)', 'myeloid_(trt_B)']

###### Synthetic data: dim = 2 and dim = 20

# Each pickle holds [beta_true, X, y]; only trusted local files are loaded.
for synth_name in names[:2]:
    with open(f'data/{synth_name}.pkl', 'rb') as f:
        bt, X, y = pickle.load(f)

    data.append([X, y])

###### Real data: stanford2

# The last two columns are read as (time, status); the rest are features.
dtmp = pd.read_table('data/stanford2.csv', sep=';')
X = dtmp.values[:, :-2]
y = np.array([(e, t) for t, e in dtmp.values[:, -2:]], dtype=[('Status', '?'), ('Survival_in_days', '<f8')])

data.append([X, y])

###### Real data: myeloid (trt_A & trt_B)

dtmp = pd.read_table('data/myeloid.csv', sep=';')

# One data set per treatment arm; the arm indicator itself is dropped.
for trt_B in [0, 1]:
    dtmp_cut = dtmp[dtmp['trt_B'] == trt_B].drop(columns=['trt_B'])

    X = dtmp_cut.values[:, :-2]
    y = np.array([(e, t) for t, e in dtmp_cut.values[:, -2:]], dtype=[('Status', '?'), ('Survival_in_days', '<f8')])

    data.append([X, y])

# Experiment

## Model

In [None]:
# Choose the model family: 'cox' or 'rsf'.
#model_type = 'cox'
model_type = 'rsf'

model = (
    CoxPHSurvivalAnalysis()
    if model_type == 'cox'
    else RandomSurvivalForest(n_estimators=250, min_samples_leaf=20, n_jobs=-1, random_state=1234)
)

## Computation

In [None]:
# Prediction settings used when evaluating the mean time.
pred_type = 'sf'
zero_val = 1
t_gamma = 1

# Number of test instances per data set; each one is used for theta = 1 and -1.
n_test = 2
n_rpoints = 10 ** 6
n_batch = 10 ** 5

max_iter = 1000
verbose = 50

for (X_train, y_train), name in zip(data, names):

    print(f'\nDataset: {name}\n')

    hcube = np.array([np.min(X_train, axis=0), np.max(X_train, axis=0)])
    X_test = np.random.RandomState(seed=1234).uniform(low=hcube[0], high=hcube[1], size=(n_test, hcube.shape[1]))

    # Every instance appears twice in a row, once per direction. The theta
    # list follows n_test so that no task is silently dropped by zip().
    X_test = np.repeat(X_test, 2, axis=0)
    thetas = [1, -1] * n_test

    model.fit(X_train, y_train)

    for task_idx, (theta, instance) in enumerate(zip(thetas, X_test), start=1):

        print(f'task: {task_idx}\n')

        # None -> a random margin is drawn below.
        margin = None

        model_times, model_surv = model_predictions(np.vstack([X_train, instance]), pred_type, model, model_type)
        times, model_surv = time_unification(model_times, model_surv, pred_type, y_train, zero_val)

        train_surv, inst_surv = model_surv[:-1], model_surv[-1]
        measures = np.array(list(times[1:] - times[:-1]) + [t_gamma])

        train_mtimes = np.dot(train_surv, measures)
        inst_mtime = np.dot(inst_surv, measures)

        mtime_min = np.min(train_mtimes)
        mtime_max = np.max(train_mtimes)

        # Fall back to the observed time range when the instance lies outside
        # the range of predicted training mean times.
        if not (mtime_min <= inst_mtime <= mtime_max):
            mtime_min = np.min([t for _, t in y_train])
            mtime_max = np.max([t for _, t in y_train])

        # Largest reachable margin in the requested direction.
        margin_max = 0.5 * ((1 - theta) * (inst_mtime - mtime_min) + (1 + theta) * (mtime_max - inst_mtime))

        if not (margin is None):
            if not (0 < margin < margin_max):
                raise ValueError(f'margin is out of interval ({0.}, {margin_max})')
        else:
            margin = np.random.uniform(low=0.25, high=0.75) * margin_max

        mask_train = margin - theta * (train_mtimes - inst_mtime) <= 0

        if np.any(mask_train):
            idx_clo = np.argmin(np.linalg.norm(X_train[mask_train] - instance, axis=1))
            z_clo = X_train[mask_train][idx_clo]
        else:
            print('There is no the closest train point\n')
            # Farthest corner of the box: per coordinate, take the bound that
            # is farther from the instance. Enumerating all corners needs
            # 2 ** dim points and is not feasible for the 20-dimensional data.
            z_clo = np.where(np.abs(instance - hcube[0]) >= np.abs(hcube[1] - instance), hcube[0], hcube[1])
        r_clo = np.linalg.norm(instance - z_clo)

        z_ver = ver_solution(instance, theta, margin, y_train, model, model_type, z_clo, hcube, t_gamma, n_rpoints, n_batch)
        z_ver_mtime = mtime_computation(z_ver.reshape(1, -1), y_train, model, model_type, t_gamma)
        z_ver_margin = theta * (z_ver_mtime - inst_mtime)
        z_ver_dist = np.linalg.norm(z_ver - instance)

        print()

        obj_func = lambda particle, part_mtime: objective(particle, part_mtime, instance, inst_mtime, theta, margin)
        mtime_func = lambda particles: mtime_computation(particles, y_train, model, model_type, t_gamma)

        z_opt_margin, z_opt_dist, z_opt = pso_optimization(obj_func, mtime_func, instance, z_clo, r_clo, hcube,
                                                           max_iter=max_iter, verbose=verbose)

        print()

        with open(f'results/experiment/{name}_{model_type}_task_{task_idx}.pkl', 'wb') as f:
            task = [instance, inst_mtime, theta, margin]
            info_dom = [hcube, train_mtimes, mtime_min, mtime_max, margin_max, mask_train, z_clo]
            info_ver = [z_ver, z_ver_mtime, z_ver_margin, z_ver_dist]
            info_opt = [z_opt, z_opt_margin, z_opt_dist]
            pickle.dump([task, info_dom, info_ver, info_opt], f)

# Results

In [None]:
model_type = 'rsf'

if model_type == 'cox':
    model = CoxPHSurvivalAnalysis()
else:
    model = RandomSurvivalForest(n_estimators=250, min_samples_leaf=20, n_jobs=-1, random_state=1234)

# One row per stored task, filled in the order the result files are read.
tab = []

for name in names:

    if name in ('sdata_d02', 'sdata_d20'):

        # Each pickle holds [beta_true, X, y]; only trusted local files are loaded.
        with open(f'data/{name}.pkl', 'rb') as f:
            _, X_train, y_train = pickle.load(f)

    elif name == 'stanford2':

        dtmp = pd.read_table('data/stanford2.csv', sep=';')
        X_train = dtmp.values[:, :-2]
        y_train = np.array([(e, t) for t, e in dtmp.values[:, -2:]], dtype=[('Status', '?'), ('Survival_in_days', '<f8')])

    else:

        # Remaining names are the two myeloid treatment arms.
        dtmp = pd.read_table('data/myeloid.csv', sep=';')

        trt_B = 0 if name == 'myeloid_(trt_A)' else 1
        dtmp_cut = dtmp[dtmp['trt_B'] == trt_B]
        dtmp_cut = dtmp_cut.drop(columns=['trt_B'])

        X_train = dtmp_cut.values[:, :-2]
        y_train = np.array([(e, t) for t, e in dtmp_cut.values[:, -2:]], dtype=[('Status', '?'), ('Survival_in_days', '<f8')])

    model.fit(X_train, y_train)

    paths = sorted(glob.glob(f'results/experiment/{name}_{model_type}*'))

    for path in paths:

        with open(path, 'rb') as f:
            task, info_dom, info_ver, info_opt = pickle.load(f)

        instance, inst_mtime, theta, margin = task
        hcube, train_mtimes, mtime_min, mtime_max, margin_max, mask_train, z_clo = info_dom
        z_ver, _, z_ver_margin, z_ver_dist = info_ver
        z_opt, z_opt_margin, z_opt_dist = info_opt

        tab.append([theta, margin, z_ver_margin, z_opt_margin, z_ver_dist, z_opt_dist, np.linalg.norm(z_ver - z_opt)])

        # Only two-dimensional tasks can be drawn.
        if X_train.shape[1] == 2:
            info_dom = [X_train, y_train, hcube]
            info_model = [model, model_type, t_gamma]
            info_task = [instance, z_clo, z_ver, z_opt, mtime_min, mtime_max]

            # File name without directory and '.pkl'. Both separators are
            # handled: splitting on '\\' alone leaves the whole path in place
            # on POSIX systems.
            stem = path.replace('\\', '/').split('/')[-1][:-4]
            save_path = 'results/' + stem + '.png'

            plot_result(info_dom, info_model, info_task, save_path)

In [None]:
# Summary table: four tasks per data set, in the order of `names`.
tab_columns = np.array([
    'theta', 'task_margin', 'ver_margin', 'sol_margin',
    'ver_dist', 'sol_dist', 'dist(ver, sol)',
])
tab_index = np.repeat(names, 4)

tab_res = pd.DataFrame(tab, columns=tab_columns, index=tab_index)
tab_res.to_csv(f'results/results_{model_type}.tsv', sep='\t')
tab_res