In [1]:
import numpy as np
import pandas as pd
import plotly.io as pio
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from Mink import science_template # just a template I have setup on my github can delete and will only change the plot formatting
from train import make_kaplan_plot, log_rank
from sksurv.nonparametric import kaplan_meier_estimator
from glob import glob
import scipy.stats as stats
pio.templates['science'] = science_template
pio.templates.default = 'science'
pio.renderers.default = "vscode"

In [2]:
features = np.genfromtxt('../data/miRNA_feature_importance.csv', delimiter=',', skip_header=1, usecols=1, dtype=str)
importance = np.genfromtxt('../data/miRNA_feature_importance.csv', delimiter=',', skip_header=1, usecols=2)

def rank_importance(features, importance):
    # features from different 

    dic = {}
    for i, key in enumerate(features):
        if key not in dic.keys():
            dic[key] = [importance[i]]

        else:
            dic[key].append(importance[i])

    for key in dic.keys():
        dic[key] = np.mean(dic[key])

    key = np.array(list(dic.keys()))
    values = np.array(list(dic.values()))

    i = np.argsort(values)[::-1]

    key = key[i]
    values = values[i]

    return key, values

key, values= rank_importance(features, importance)
print(len(key))
for i in range(len(key)):
    print(f"{key[i]} : {values[i]}")




215
hsa-mir-320c-2 : 0.01212604706820899
hsa-mir-146a : 0.011323388765538074
hsa-mir-548k : 0.0110603095346479
hsa-mir-30e : 0.010804788336317955
hsa-mir-27b : 0.008202443280977323
hsa-mir-1270 : 0.008057439170323066
hsa-mir-3913-2 : 0.007622606208425126
hsa-mir-142 : 0.007535273302917844
hsa-mir-23c : 0.007153581299666246
hsa-mir-23b : 0.007087056855284898
hsa-mir-150 : 0.00664939104697643
hsa-mir-338 : 0.006561954624781885
hsa-mir-6891 : 0.0063821300358994585
hsa-mir-4442 : 0.00586387434554978
hsa-mir-3609 : 0.0056536323843898
hsa-mir-3653 : 0.005584363781412005
hsa-mir-632 : 0.0051541850220264696
hsa-mir-4745 : 0.005099245661685625
hsa-mir-4657 : 0.00490235744564936
hsa-mir-5010 : 0.004347826086956485
hsa-mir-4676 : 0.004319685440132245
hsa-mir-33a : 0.004188481675392708
hsa-mir-1294 : 0.003948942959712776
hsa-mir-206 : 0.0037785880013258555
hsa-mir-4522 : 0.0036697247706421908
hsa-mir-125b-2 : 0.0036360005132428873
hsa-mir-216a : 0.003621483418590593
hsa-mir-326 : 0.003211169284467

In [3]:
dir = '../data/kaplan/'
files = glob(dir+'*')
print(files)

['../data/kaplan/RSF_clinical_test_set_0_group2.csv', '../data/kaplan/RSF_clinical_test_set_3_group2.csv', '../data/kaplan/RSF_clinical_test_set_1_group2.csv', '../data/kaplan/RSF_clinical_test_set_4_group2.csv', '../data/kaplan/RSF_clinical_test_set_0_group1.csv', '../data/kaplan/RSF_clinical_test_set_2_group2.csv', '../data/kaplan/RSF_clinical_test_set_2_group1.csv', '../data/kaplan/RSF_clinical_test_set_4_group1.csv', '../data/kaplan/RSF_clinical_test_set_1_group1.csv', '../data/kaplan/RSF_clinical_test_set_3_group1.csv']


In [None]:
fig = make_subplots(rows=3, cols=2, vertical_spacing=0.085, horizontal_spacing=0.055, shared_yaxes=True, shared_xaxes=True)

for i in range(5):
    targets = []
    for file in files:
        if f'test_set_{i}' in file:
            targets.append(file)

    group1 = pd.read_csv(targets[0])
    group2 = pd.read_csv(targets[1])

    
    print(log_rank(group1, group2))

    xgroup1, sgroup1, cgroup1 = kaplan_meier_estimator(group1['censored']==1, group1['days_to_event'], conf_type='log-log')
    xgroup2, sgroup2, cgroup2 = kaplan_meier_estimator(group2['censored']==1, group2['days_to_event'], conf_type='log-log')


    # since no events happen in the first two years i have to add a point at x = 0 to make the plot look more continuous
    xgroup2 = np.concatenate(([0], xgroup2))
    sgroup2 = np.concatenate(([1], sgroup2))
    cgroup2_0 = np.concatenate(([1], cgroup2[0]))
    cgroup2_1 = np.concatenate(([1], cgroup2[1]))

    # since no events happen in the first two years i have to add a point at x = 0 to make the plot look more continuous
    xgroup1 = np.concatenate(([0], xgroup1))
    sgroup1 = np.concatenate(([1], sgroup1))
    cgroup1_0 = np.concatenate(([1], cgroup1[0]))
    cgroup1_1 = np.concatenate(([1], cgroup1[1]))

    # plotting
    row = i//2+1
    col = i%2+1
    print(row,col)

    fig.add_trace(go.Scatter(x=xgroup1/365, y=sgroup1, mode="lines", line_shape='hv', line_color='black', name='high risk', showlegend=False), row=row, col=col)
    fig.add_trace(go.Scatter(x=np.concatenate([xgroup1, xgroup1[::-1]])/365, y=np.concatenate((cgroup1_0, cgroup1_1[::-1])),
                            fill='toself', showlegend=False, line_shape='vh', line_color='skyblue'), row=row, col=col)

    fig.add_trace(go.Scatter(x=np.concatenate([xgroup2, xgroup2[::-1]])/365, y=np.concatenate((cgroup2_0, cgroup2_1[::-1])),
                            fill='toself', mode='lines', showlegend=False, line_shape='vh', line_color='#90ee90'), row=row, col=col)
    fig.add_trace(go.Scatter(x=xgroup2/365, y=sgroup2, mode="lines", line_shape='hv', line_color='green', name='low risk', showlegend=False), row=row, col=col)


fig.update_layout(legend_xanchor='right', 
                yaxis_showgrid=True, xaxis_showgrid=True, font_size=26, width=1000, height=650)
fig.for_each_xaxis(lambda x: x.update(showgrid=True))
fig.for_each_yaxis(lambda x: x.update(showgrid=True, range=(0,1), dtick=0.2))
fig.update_xaxes(tickvals=[0,5,10,15], showticklabels=True, row=2, col=2)
fig.show()


fig.write_image('figures/RSF_kaplans.pdf')

LogRankResult(statistic=-4.8670751216606245, pvalue=1.1326198534845787e-06)
1 1
LogRankResult(statistic=-2.392903647175802, pvalue=0.016715630152123347)
1 2
LogRankResult(statistic=-2.262429598987164, pvalue=0.02367087391275785)
2 1
LogRankResult(statistic=-1.0958205209903125, pvalue=0.2731573261941713)
2 2
LogRankResult(statistic=-3.036268145699497, pvalue=0.002395262445922015)
3 1


In [7]:
from scipy.optimize import curve_fit
from Mink import full_return
my_curve_fit = full_return(curve_fit)

In [17]:
def linear(x,a,b):
    return a*x + b

def ranking_error(true_risks, predicted_risks):
    arg_i = np.argsort(true_risks)
    # print(true_risks[arg_i]/np.max(true_risks))
    arg_j = np.argsort(predicted_risks)
    # print(predicted_risks[arg_i]/np.max(predicted_risks))
    mapping = [np.where(arg_j == arg)[0][0] for arg in arg_i]

    return mapping








fig = make_subplots(rows=3, cols=2, vertical_spacing=0.05, horizontal_spacing=0.055, shared_yaxes=True, shared_xaxes=True)

for i in range(5):
    targets = []
    for file in files:
        if f'test_set_{i}' in file:
            targets.append(file)

    group1 = pd.read_csv(targets[0])
    group2 = pd.read_csv(targets[1])

    
    df = pd.concat([group1, group2])
    # df = df.loc[df['censored']==0]

    map = ranking_error( 1/df['days_to_event'], df['risk'])
    # print(map)
    ideal = np.arange(0,len(df))

    p0 = [1,0]
    res = my_curve_fit(linear, ideal, map, p0=p0, sigma=22*np.ones(len(df)), absolute_sigma=True)
    # print(res)

    # print(np.mean(np.sqrt((map-ideal)**2)))

    # plotting
    row = i//2+1
    col = i%2+1
    print(row,col)

    fig.add_trace(go.Scatter(x=ideal, y=map, mode='markers', name='prediction', showlegend=False, marker_color='black'), row=row, col=col)
    fig.add_trace(go.Scatter(x=ideal, y=linear(ideal, *res.pOpt), mode='lines', line_color='red', name='fit', showlegend=False), row=row, col=col)
    fig.add_trace(go.Scatter(x=ideal, y=ideal, mode='lines', line_dash='dash', name='ideal', line_color='black', showlegend=False), row=row, col=col)

fig.update_layout(legend_xanchor='right', 
                yaxis_showgrid=True, xaxis_showgrid=True, font_size=26, width=1000, height=850)
fig.for_each_xaxis(lambda x: x.update(showgrid=True))
fig.for_each_yaxis(lambda x: x.update(showgrid=True))
fig.show()
fig.write_image('figures/c-indices_visualized.svg')

1 1
1 2
2 1
2 2
3 1
