<a href="https://colab.research.google.com/github/mostly-sunny/digital-health-hackathon/blob/main/2.%20coxph_find_best_network.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Pycox - CoxPH Model
- Network : Grid-search network sizes and learning rates, then select the configuration with the lowest integrated Brier score
- Input Variables : G1 ~ G300, Var1 ~ Var10, Treatment
- Output Variables : time, event
- Scaler : MinMaxScaler applied to Var1 ~ Var10



In [None]:
%pip install pycox

Collecting pycox
  Downloading pycox-0.2.2-py3-none-any.whl (73 kB)
Collecting torchtuples>=0.2.0
  Downloading torchtuples-0.2.2-py3-none-any.whl (41 kB)
Collecting py7zr>=0.11.3
  Downloading py7zr-0.16.1-py3-none-any.whl (65 kB)
Collecting brotli>=1.0.9
  Downloading Brotli-1.0.9-cp37-cp37m-manylinux1_x86_64.whl

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import MinMaxScaler
from sklearn_pandas import DataFrameMapper
import pandas as pd

import torch
import torchtuples as tt

from pycox.models import CoxPH
from pycox.evaluation import EvalSurv

In [None]:
np.random.seed(123456)
_ = torch.manual_seed(123456)

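- `all-in-one.csv` has columns for gene mutation status, clinical variables, survival time, death event, and treatment status.
- `test-data-treat-and-untreat.csv` has the same columns as `all-in-one.csv` and contains 602 rows (0-indexed below).
  - Row 0 : all gene mutations 0, treatment 0
  - Row 1 : all gene mutations 0, treatment 1
  - Rows 2 ~ 301 : in row n, only gene mutation n-1 is 1, treatment 0
  - Rows 302 ~ 601 : in row n, only gene mutation n-301 is 1, treatment 1
- The files are read with the pandas `read_csv` function, which converts a CSV file into a DataFrame.
- A DataFrame is a data type that represents a table.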
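
The row layout above can be sketched in plain Python. This is only an illustration under stated assumptions: the clinical variable columns are omitted, and `build_probe_rows`, `paired_rows`, and `treatment_hazard_ratio` are hypothetical helper names that do not appear in the notebook. The pairing mirrors how the later training cell divides the partial hazard of row g+302 by that of row g+2.

```python
import math

N_GENES = 300


def build_probe_rows(n_genes=N_GENES):
    """Return (gene_vector, treatment) rows in the order described above."""
    rows = []
    # Rows 0 and 1: no mutation at all, untreated then treated.
    for treatment in (0, 1):
        rows.append(([0] * n_genes, treatment))
    # Rows 2 ~ 301 (untreated) and 302 ~ 601 (treated): exactly one mutated gene.
    for treatment in (0, 1):
        for g in range(n_genes):
            genes = [0] * n_genes
            genes[g] = 1
            rows.append((genes, treatment))
    return rows


def paired_rows(gene_number):
    """Row indices (untreated, treated) for gene G<gene_number>, 1-based."""
    return gene_number + 1, gene_number + 301


def treatment_hazard_ratio(log_hazards, gene_number):
    """Hazard ratio treated / untreated from per-row log partial hazards."""
    untreated, treated = paired_rows(gene_number)
    return math.exp(log_hazards[treated] - log_hazards[untreated])


rows = build_probe_rows()
print(len(rows), paired_rows(1), paired_rows(300))
```

A ratio below 1 for a gene suggests that, according to the model, treatment lowers the hazard for patients carrying only that mutation.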

In [None]:
dataset = pd.read_csv('/content/all-in-one.csv')
dataset_for_hr = pd.read_csv('/content/test-data-treat-and-untreat.csv')

- 20% of the `dataset` read above is sampled for validation (`_val`).
- From the remaining 80%, another 20% is sampled for testing (`_test`), so the final split is roughly 64% train, 20% validation, and 16% test.
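
To make the resulting proportions concrete, here is a minimal sketch of the same two-step split on row indices. It uses the standard library `random` module as a stand-in for `DataFrame.sample`, and `split_indices` is a hypothetical helper, so the selected rows will differ from what pandas picks; only the sizes and the disjointness carry over.

```python
import random


def split_indices(n_rows, val_frac=0.2, test_frac=0.2, seed=123456):
    """Two-step split: validation first, then test from the remainder."""
    rng = random.Random(seed)
    all_idx = list(range(n_rows))

    # Step 1: take val_frac of all rows for validation.
    val = set(rng.sample(all_idx, round(n_rows * val_frac)))
    rest = [i for i in all_idx if i not in val]

    # Step 2: take test_frac of the REMAINING rows for testing.
    test = set(rng.sample(rest, round(len(rest) * test_frac)))
    train = [i for i in rest if i not in test]

    return train, sorted(val), sorted(test)


train_idx, val_idx, test_idx = split_indices(1000)
print(len(train_idx), len(val_idx), len(test_idx))
```

Because the test fraction applies to the remainder rather than to the full dataset, the test set ends up smaller than the validation set.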

In [None]:
dataset_val = dataset.sample(frac=0.2)
dataset_train = dataset.drop(dataset_val.index)
dataset_test = dataset_train.sample(frac=0.2)
dataset_train = dataset_train.drop(dataset_test.index)

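- columns_standardize : clinical variables, with values between 0 and 9, which are rescaled to the [0, 1] range with MinMaxScaler.
- columns_leave : gene mutation status + treatment status. These are already encoded as 0 or 1, so no scaling is needed.
- DataFrameMapper selects the specified columns from a pandas DataFrame and assembles them into an array.
- While building the array, columns paired with a scaler (here MinMaxScaler()) are scaled first, and columns paired with None are passed through unchanged.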
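
As a rough illustration of what min-max scaling does to one clinical variable, the sketch below fits the bounds on training values and reuses them on another split. The helper names `fit_min_max` and `apply_min_max` and the sample values are hypothetical; the notebook itself relies on scikit-learn's `MinMaxScaler` through `DataFrameMapper`.

```python
def fit_min_max(train_values):
    """Learn the column minimum and maximum from the training values only."""
    return min(train_values), max(train_values)


def apply_min_max(values, lo, hi):
    """Rescale values with (v - lo) / (hi - lo) using the fitted bounds."""
    span = hi - lo
    if span == 0:
        span = 1.0  # constant column: avoid dividing by zero
    return [(v - lo) / span for v in values]


train_var = [0, 3, 9]   # hypothetical training values for one clinical variable
other_var = [4.5, 12]   # hypothetical values from another split

lo, hi = fit_min_max(train_var)
scaled_train = apply_min_max(train_var, lo, hi)
scaled_other = apply_min_max(other_var, lo, hi)
print(scaled_train, scaled_other)
```

Values outside the training range map outside [0, 1], which is expected when the bounds are fitted on the training set only.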

In [None]:
columns_standardize = ['Var' + str(i) for i in range(1,11)]
columns_leave = ['G' + str(i) for i in range(1,301)]
columns_leave += ['Treatment']

# standardize = [([col], StandardScaler()) for col in columns_standardize]
standardize = [([col], MinMaxScaler()) for col in columns_standardize]

leave = [(col, None) for col in columns_leave]

x_mapper = DataFrameMapper(leave + standardize)

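- The DataFrameMapper created above converts the x (input) columns of each DataFrame into float32 arrays that the model can train on. The scaler is fitted on the training set only and then reused for the other sets.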



In [None]:
x_train = x_mapper.fit_transform(dataset_train).astype('float32')
x_val = x_mapper.transform(dataset_val).astype('float32')
x_test = x_mapper.transform(dataset_test).astype('float32')
x_for_hr = x_mapper.transform(dataset_for_hr).astype('float32')

- From each DataFrame (table), extract the Y (output) data: time (survival time) and event (death indicator).
- Build the input-output pair `val` used for validation.

In [None]:
get_target = lambda df: (df['time'].values, df['event'].values)
y_train = get_target(dataset_train)
y_val = get_target(dataset_val)

durations_test, events_test = get_target(dataset_test)
val = x_val, y_val

Function make_net : creates and returns a network
- The number of input and output nodes, the number of hidden layers, and the number of nodes per hidden layer are configurable

In [None]:
def make_net(in_features, out_features, hidden, nodes):
  # Build an MLP with `hidden` hidden blocks of `nodes` units each.
  # Each block is Linear -> ReLU -> BatchNorm1d -> Dropout(0.1).
  if hidden < 1:
    raise ValueError("hidden must be at least 1")

  layers = []
  width = in_features
  for _ in range(hidden):
    layers += [
      torch.nn.Linear(width, nodes),
      torch.nn.ReLU(),
      torch.nn.BatchNorm1d(nodes),
      torch.nn.Dropout(0.1),
    ]
    width = nodes

  # Final linear layer maps the last hidden block to the output.
  layers.append(torch.nn.Linear(width, out_features))
  return torch.nn.Sequential(*layers)

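- in_features : number of input features (x_train.shape[1] : 311 = 300 (genes) + 10 (clinical variables) + 1 (treatment))
- out_features : number of output nodes

- hidden_layers : list of hidden-layer counts to try
- number_nodes : list of nodes-per-hidden-layer values to try
- learning_rates : list of learning rates to test; 0 means the rate is chosen by `lr_finder`
- brier_scores : list to which each run's integrated Brier score and settings are appended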
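
The three lists define a full grid of 4 x 7 x 5 = 140 training runs. The sketch below enumerates that grid with `itertools.product`, which visits the combinations in the same order as three nested loops; `describe` is a hypothetical helper added only for readable printing.

```python
from itertools import product

hidden_layers = [1, 2, 3, 4]
number_nodes = [32, 64, 128, 256, 512, 1024, 2048]
learning_rates = [0.0001, 0.001, 0.01, 0.1, 0]  # 0 is a sentinel, not a real rate

# Every (hidden, nodes, lr) combination, outermost list varying slowest.
configs = list(product(hidden_layers, number_nodes, learning_rates))


def describe(config):
    """Human-readable summary of one grid point."""
    hidden, nodes, lr = config
    lr_text = "lr chosen by lr_finder" if lr == 0 else f"lr={lr}"
    return f"{hidden} hidden layer(s), {nodes} nodes, {lr_text}"


print(len(configs))
print(describe(configs[0]))
print(describe(configs[-1]))
```

Counting the grid up front gives a sense of how long the search will take before any model is trained.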

In [None]:
in_features = x_train.shape[1]
out_features = 1

hidden_layers = [1,2,3,4]
number_nodes = [32, 64, 128, 256, 512, 1024, 2048]
learning_rates = [0.0001, 0.001, 0.01, 0.1, 0]
brier_scores = []

for i in hidden_layers:
  for j in number_nodes:
    for k in learning_rates:
      net = make_net(in_features, out_features, i, j)
      model = CoxPH(net, tt.optim.Adam)
      batch_size = 256

      if k == 0:
        lrfinder = model.lr_finder(x_train, y_train, batch_size, tolerance = 10)
        model.optimizer.set_lr(lrfinder.get_best_lr())
      else:
        model.optimizer.set_lr(k)
      
      epochs = 512
      callbacks = [tt.callbacks.EarlyStopping()]
      verbose = True

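      # %%time is a cell magic and times nothing when used mid-cell; time the fit call with the %time line magic instead
      %time model.fit(x_train, y_train, batch_size, epochs, callbacks, verbose, val_data=val, val_batch_size=batch_size)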
      _ = model.compute_baseline_hazards()
      surv = model.predict_surv_df(x_test)
      
      # calculate ratio
      log_partial_hazard = model.predict(x_for_hr)
      partial_hazard = [np.exp(lph) for lph in log_partial_hazard]

      treat_hr = []
      # ratio with treated and untreated
      for g in range(300):
        treat_hr.append([partial_hazard[g+302]/partial_hazard[g+2],'G' + str(g+1)])
      treat_hr.sort()

      # evaluation
      ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
      time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)
      score = ev.integrated_brier_score(time_grid)

      if k == 0:
        brier_scores.append([score, i, j, lrfinder.get_best_lr(), treat_hr[:10]])
      else:
        brier_scores.append([score, i, j, k, treat_hr[:10]])

CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 7.87 µs
0:	[0s / 0s],		train_loss: 4.4867,	val_loss: 4.2817
1:	[0s / 0s],		train_loss: 4.4710,	val_loss: 4.2825
2:	[0s / 0s],		train_loss: 4.4537,	val_loss: 4.2836
3:	[0s / 0s],		train_loss: 4.4543,	val_loss: 4.2850
4:	[0s / 0s],		train_loss: 4.4115,	val_loss: 4.2869
5:	[0s / 0s],		train_loss: 4.4026,	val_loss: 4.2893
6:	[0s / 0s],		train_loss: 4.3801,	val_loss: 4.2924
7:	[0s / 0s],		train_loss: 4.3459,	val_loss: 4.2963
8:	[0s / 0s],		train_loss: 4.3715,	val_loss: 4.3013
9:	[0s / 0s],		train_loss: 4.3589,	val_loss: 4.3072
10:	[0s / 0s],		train_loss: 4.3557,	val_loss: 4.3143




CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 7.39 µs
0:	[0s / 0s],		train_loss: 4.5282,	val_loss: 4.2778
1:	[0s / 0s],		train_loss: 4.3406,	val_loss: 4.2756
2:	[0s / 0s],		train_loss: 4.2064,	val_loss: 4.2732
3:	[0s / 0s],		train_loss: 4.1309,	val_loss: 4.2700
4:	[0s / 0s],		train_loss: 4.0439,	val_loss: 4.2663
5:	[0s / 0s],		train_loss: 3.9997,	val_loss: 4.2620
6:	[0s / 0s],		train_loss: 3.9227,	val_loss: 4.2577
7:	[0s / 0s],		train_loss: 3.8799,	val_loss: 4.2541
8:	[0s / 0s],		train_loss: 3.8316,	val_loss: 4.2520
9:	[0s / 0s],		train_loss: 3.7697,	val_loss: 4.2518
10:	[0s / 0s],		train_loss: 3.7214,	val_loss: 4.2533
11:	[0s / 0s],		train_loss: 3.6847,	val_loss: 4.2577
12:	[0s / 0s],		train_loss: 3.6490,	val_loss: 4.2657
13:	[0s / 0s],		train_loss: 3.5871,	val_loss: 4.2771
14:	[0s / 0s],		train_loss: 3.5661,	val_loss: 4.2919
15:	[0s / 0s],		train_loss: 3.5259,	val_loss: 4.3087
16:	[0s / 0s],		train_loss: 3.4694,	val_loss: 4.3301
17:	[0s / 0s],		train_loss: 3.4550,	val_loss:



CPU times: user 4 µs, sys: 0 ns, total: 4 µs
Wall time: 7.15 µs
0:	[0s / 0s],		train_loss: 4.4609,	val_loss: 4.2517
1:	[0s / 0s],		train_loss: 3.9654,	val_loss: 4.2333
2:	[0s / 0s],		train_loss: 3.5320,	val_loss: 4.1996
3:	[0s / 0s],		train_loss: 3.3352,	val_loss: 4.1589
4:	[0s / 0s],		train_loss: 3.0968,	val_loss: 4.1230
5:	[0s / 0s],		train_loss: 2.9354,	val_loss: 4.0960
6:	[0s / 0s],		train_loss: 2.7761,	val_loss: 4.0860
7:	[0s / 0s],		train_loss: 2.7115,	val_loss: 4.1042
8:	[0s / 0s],		train_loss: 2.6505,	val_loss: 4.1424
9:	[0s / 0s],		train_loss: 2.6989,	val_loss: 4.1606
10:	[0s / 0s],		train_loss: 2.6799,	val_loss: 4.1870
11:	[0s / 0s],		train_loss: 2.6906,	val_loss: 4.2413
12:	[0s / 1s],		train_loss: 2.6716,	val_loss: 4.3581
13:	[0s / 1s],		train_loss: 2.6479,	val_loss: 4.5171
14:	[0s / 1s],		train_loss: 2.6719,	val_loss: 4.7034
15:	[0s / 1s],		train_loss: 2.6331,	val_loss: 4.8244
16:	[0s / 1s],		train_loss: 2.6177,	val_loss: 4.8749
CPU times: user 3 µs, sys: 0 ns, total: 3 µs




CPU times: user 3 µs, sys: 1e+03 ns, total: 4 µs
Wall time: 7.39 µs
0:	[0s / 0s],		train_loss: 4.4717,	val_loss: 4.2635
1:	[0s / 0s],		train_loss: 3.9726,	val_loss: 4.2476
2:	[0s / 0s],		train_loss: 3.6345,	val_loss: 4.2228
3:	[0s / 0s],		train_loss: 3.4121,	val_loss: 4.1935
4:	[0s / 0s],		train_loss: 3.2014,	val_loss: 4.1670
5:	[0s / 0s],		train_loss: 3.0278,	val_loss: 4.1412
6:	[0s / 1s],		train_loss: 2.8801,	val_loss: 4.1194
7:	[0s / 1s],		train_loss: 2.7644,	val_loss: 4.1030
8:	[0s / 1s],		train_loss: 2.6675,	val_loss: 4.1047
9:	[0s / 1s],		train_loss: 2.6511,	val_loss: 4.1327
10:	[0s / 1s],		train_loss: 2.6830,	val_loss: 4.1853
11:	[0s / 1s],		train_loss: 2.6543,	val_loss: 4.2860
12:	[0s / 1s],		train_loss: 2.6836,	val_loss: 4.4067
13:	[0s / 2s],		train_loss: 2.6436,	val_loss: 4.5229
14:	[0s / 2s],		train_loss: 2.6260,	val_loss: 4.6635
15:	[0s / 2s],		train_loss: 2.6403,	val_loss: 4.7733
16:	[0s / 2s],		train_loss: 2.6682,	val_loss: 4.8758
17:	[0s / 2s],		train_loss: 2.6407,	val_l



CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 8.11 µs
0:	[0s / 0s],		train_loss: 4.4044,	val_loss: 4.2688
1:	[0s / 0s],		train_loss: 3.9073,	val_loss: 4.2581
2:	[0s / 0s],		train_loss: 3.6333,	val_loss: 4.2441
3:	[0s / 0s],		train_loss: 3.4467,	val_loss: 4.2274
4:	[0s / 1s],		train_loss: 3.2665,	val_loss: 4.2044
5:	[0s / 1s],		train_loss: 3.1510,	val_loss: 4.1764
6:	[0s / 1s],		train_loss: 3.0572,	val_loss: 4.1460
7:	[0s / 1s],		train_loss: 2.9218,	val_loss: 4.1170
8:	[0s / 2s],		train_loss: 2.8282,	val_loss: 4.1016
9:	[0s / 2s],		train_loss: 2.8097,	val_loss: 4.1211
10:	[0s / 2s],		train_loss: 2.8177,	val_loss: 4.1564
11:	[0s / 2s],		train_loss: 2.7848,	val_loss: 4.1935
12:	[0s / 3s],		train_loss: 2.7858,	val_loss: 4.2434
13:	[0s / 3s],		train_loss: 2.7302,	val_loss: 4.3523
14:	[0s / 3s],		train_loss: 2.7771,	val_loss: 4.5181
15:	[0s / 3s],		train_loss: 2.7450,	val_loss: 4.6796
16:	[0s / 4s],		train_loss: 2.7432,	val_loss: 4.8092
17:	[0s / 4s],		train_loss: 2.7266,	val_loss:



CPU times: user 0 ns, sys: 3 µs, total: 3 µs
Wall time: 6.68 µs
0:	[1s / 1s],		train_loss: 4.4738,	val_loss: 4.2771
1:	[1s / 3s],		train_loss: 4.0004,	val_loss: 4.2753
2:	[1s / 4s],		train_loss: 3.6667,	val_loss: 4.2711
3:	[1s / 6s],		train_loss: 3.5050,	val_loss: 4.2635
4:	[1s / 8s],		train_loss: 3.2934,	val_loss: 4.2502
5:	[1s / 9s],		train_loss: 3.2170,	val_loss: 4.2314
6:	[1s / 11s],		train_loss: 3.1120,	val_loss: 4.2070
7:	[1s / 13s],		train_loss: 3.0887,	val_loss: 4.1796
8:	[1s / 14s],		train_loss: 2.9658,	val_loss: 4.1619
9:	[1s / 16s],		train_loss: 2.9548,	val_loss: 4.1516
10:	[1s / 17s],		train_loss: 2.8966,	val_loss: 4.1563
11:	[1s / 19s],		train_loss: 2.9207,	val_loss: 4.1809
12:	[1s / 20s],		train_loss: 2.9359,	val_loss: 4.2128
13:	[1s / 22s],		train_loss: 2.9146,	val_loss: 4.2450
14:	[1s / 23s],		train_loss: 2.8726,	val_loss: 4.2861
15:	[1s / 25s],		train_loss: 2.8705,	val_loss: 4.3238
16:	[1s / 26s],		train_loss: 2.8563,	val_loss: 4.3538
17:	[1s / 28s],		train_loss: 2.868

- Sort by brier_score, smallest first


In [None]:
brier_scores.sort()
selected_genes = []
for i in range(10):
  selected_genes.append(brier_scores[0][4][i][1])

- Print the case with the best brier_score

In [None]:
print("Brier Score :", brier_scores[0][0])
print("Hidden Layers :", brier_scores[0][1])
print("Number of Nodes :", brier_scores[0][2])
print("Learning Rate :", brier_scores[0][3])
print("Selection :", selected_genes)

Brier Score : 0.05571701742791964
Hidden Layers : 1
Number of Nodes : 2048
Learning Rate : 0.0010235310218990308
Selection : ['G110', 'G57', 'G28', 'G292', 'G221', 'G260', 'G155', 'G269', 'G111', 'G293']


Found a problem: the sort order breaks at NaN values
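
One possible way around this, sketched under the assumption that each entry keeps its score in position 0 as in `brier_scores`: every comparison involving NaN evaluates to False, so a plain `sort()` cannot place those entries consistently. Sorting with a key that pushes NaN scores to the end avoids the issue. `sort_runs` is a hypothetical helper and the sample runs are made up for illustration.

```python
import math


def sort_runs(runs):
    """Sort runs by score ascending, placing NaN scores last."""
    def key(run):
        score = run[0]
        is_nan = math.isnan(score)
        # NaN runs get (True, 0.0) so they never take part in a NaN comparison.
        return (is_nan, 0.0 if is_nan else score)
    return sorted(runs, key=key)


# Hypothetical [score, hidden layers, nodes, learning rate] entries.
sample_runs = [
    [0.06, 1, 32, 0.01],
    [float("nan"), 2, 64, 0.1],
    [0.05, 3, 128, 0.001],
]

for run in sort_runs(sample_runs):
    print(run)
```

With this ordering the first entry is always the best valid run, and diverged runs remain visible at the end of the list instead of being silently dropped.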

In [None]:
print("brier_score, hidden layer, number of nodes, learning rate")
for i in brier_scores:
  print(i[:-1])

[0.05571701742791964, 1, 2048, 0.0010235310218990308]
[0.056438513393290965, 1, 1024, 0.002154434690031894]
[0.056515697640799695, 1, 512, 0.1]
[0.05672940461368982, 1, 256, 0.005462277217684369]
[0.05708716930762827, 1, 512, 0.005462277217684369]
[0.05731754383404234, 3, 512, 0.0005857020818056691]
[0.0574636790375501, 3, 512, 0.001]
[0.0574724830747122, 3, 64, 0.01]
[0.05766754393769138, 4, 1024, 0.001]
[0.05779229452457315, 4, 256, 0.005462277217684369]
[0.05809877132432969, 1, 2048, 0.001]
[0.05866294281957923, 1, 1024, 0.001]
[0.05866442591341538, 4, 2048, 0.0001]
[0.05869208365514611, 3, 2048, 0.0001]
[0.058719959287107625, 4, 512, 0.001]
[0.059110107790455266, 1, 256, 0.001]
[0.05928807586006142, 1, 512, 0.001]
[0.059325984097976646, 4, 256, 0.001]
[0.059370933866415135, 3, 1024, 0.0001917910261672496]
[0.05938501120251226, 1, 512, 0.01]
[0.05947690619874226, 4, 128, 0.01]
[0.059533134978945655, 1, 128, 0.01]
[0.059582599878136855, 3, 512, 0.0001]
[0.05981721119310052, 3, 1024, 