<a href="https://colab.research.google.com/github/mostly-sunny/digital-health-hackathon/blob/main/temporary.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Pycox - CoxPH Model
- Network : Grid-search the number of hidden layers, layer width, and learning rate, and keep the configuration with the lowest integrated Brier score
- Input Variables : G1 ~ G300, Var1 ~ Var10, Treatment
- Output Variables : time, event
- Scaler : MinMaxScaler -> Var1 ~ Var10
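
The selection criterion above is the Brier score. As a rough illustration only, the sketch below computes the Brier score at a single time point for fully observed (uncensored) subjects: the mean squared difference between the predicted survival probability and whether the subject was actually still alive at that time. The function name `brier_at_time` and the toy numbers are assumptions for illustration; pycox's `EvalSurv.integrated_brier_score` additionally accounts for censoring and integrates over a time grid.

```python
def brier_at_time(surv_probs, durations, t):
    """Brier score at time t, assuming every subject's event time is observed.

    surv_probs: predicted P(T > t) for each subject
    durations:  observed event time for each subject (no censoring assumed)
    """
    if len(surv_probs) != len(durations):
        raise ValueError("surv_probs and durations must have the same length")
    total = 0.0
    for p, d in zip(surv_probs, durations):
        alive = 1.0 if d > t else 0.0  # true survival status at time t
        total += (p - alive) ** 2
    return total / len(surv_probs)


# Toy example: two subjects, evaluated at t = 50.
toy_score = brier_at_time([0.9, 0.2], [80.0, 30.0], t=50.0)
```

A lower value means the predicted survival probabilities sit closer to what actually happened, which is why the grid search keeps the configuration with the smallest score.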



In [151]:
pip install pycox



In [152]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import MinMaxScaler
from sklearn_pandas import DataFrameMapper
import pandas as pd

import torch
import torchvision
import torchtuples as tt

from pycox.models import CoxPH
from pycox.evaluation import EvalSurv

In [153]:
torch.cuda.is_available()

True

In [154]:
np.random.seed(123456)
_ = torch.manual_seed(123456)

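- all-in-one.csv has one column each for gene mutation status, the clinical variables, survival time, death event, and treatment status.
- test-data-treat-and-untreat.csv has the same columns as all-in-one.csv and contains 602 rows.
  - (row 0) : all gene mutations 0, Treatment 0
  - (row 1) : all gene mutations 0, Treatment 1
  - (rows 2~301) : row n has only gene mutation G(n-1) set to 1, Treatment 0
  - (rows 302~601) : row n has only gene mutation G(n-301) set to 1, Treatment 1
- The files are read with pandas' read_csv function, which loads a CSV file into a DataFrame.
- A DataFrame is a data type that represents a table.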
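
The row layout described above can be sketched as follows. This is an illustrative reconstruction from the description, not the script that produced test-data-treat-and-untreat.csv; `make_probe_rows` is a hypothetical helper, and the clinical variables are left out because the description does not say how they are filled.

```python
def make_probe_rows(n_genes=300):
    """Return (gene_vector, treatment) pairs in the order described above."""
    rows = []
    # Rows 0 and 1: no mutation at all, untreated then treated.
    for treatment in (0, 1):
        rows.append(([0] * n_genes, treatment))
    # Rows 2..301 (untreated) and 302..601 (treated): exactly one mutated gene.
    for treatment in (0, 1):
        for g in range(n_genes):
            genes = [0] * n_genes
            genes[g] = 1  # gene G(g+1) is the only mutated one
            rows.append((genes, treatment))
    return rows


probe = make_probe_rows()
```

With this layout, rows g+2 and g+302 carry the same single mutation G(g+1) and differ only in Treatment, which is what makes a per-gene treated/untreated comparison possible later on.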

In [155]:
dataset = pd.read_csv('/content/all-in-one.csv')
dataset_for_hr = pd.read_csv('/content/test-data-treat-and-untreat.csv')

- 30% of the dataset read above is sampled for testing (_test).
- Of the remaining 70%, 20% is sampled for validation (_val -> validation); the rest is used for training.

In [156]:
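# np.random.seed() returns None, so random_state is effectively None here;
# the call only re-seeds NumPy's global RNG, which sample() then falls back on.
dataset_test = dataset.sample(frac=0.3, random_state = np.random.seed(123456))
dataset_train = dataset.drop(dataset_test.index)
dataset_val = dataset_train.sample(frac=0.2, random_state = np.random.seed(123456))
dataset_train = dataset_train.drop(dataset_val.index)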
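
The cell above samples the test rows first (frac=0.3) and then samples the validation rows from what is left (frac=0.2). The sketch below reproduces that two-step logic on plain row indices to show the resulting proportions. It uses the standard library's `random` rather than pandas, `three_way_split` is a hypothetical helper, and the 1000-row size is an assumption made only for the example.

```python
import random


def three_way_split(n_rows, test_frac=0.3, val_frac=0.2, seed=123456):
    """Sample test rows first, then sample val rows from the remainder."""
    rng = random.Random(seed)
    indices = list(range(n_rows))
    test = set(rng.sample(indices, round(n_rows * test_frac)))
    remaining = [i for i in indices if i not in test]
    val = set(rng.sample(remaining, round(len(remaining) * val_frac)))
    train = [i for i in remaining if i not in val]
    return train, sorted(val), sorted(test)


train_idx, val_idx, test_idx = three_way_split(1000)
print(len(train_idx), len(val_idx), len(test_idx))  # 560 140 300
```

In other words, the effective split is 56% train, 14% validation, and 30% test.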

In [157]:
dataset_test

Unnamed: 0,G1,G2,G3,G4,G5,G6,G7,G8,G9,G10,G11,G12,G13,G14,G15,G16,G17,G18,G19,G20,G21,G22,G23,G24,G25,G26,G27,G28,G29,G30,G31,G32,G33,G34,G35,G36,G37,G38,G39,G40,...,G274,G275,G276,G277,G278,G279,G280,G281,G282,G283,G284,G285,G286,G287,G288,G289,G290,G291,G292,G293,G294,G295,G296,G297,G298,G299,G300,Var1,Var2,Var3,Var4,Var5,Var6,Var7,Var8,Var9,Var10,time,event,Treatment
601,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,1,1,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,3,5,4,6,4,3,3,5,7,1,71.955766,1,1
357,0,0,0,0,0,0,1,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,...,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,1,1,0,2,2,2,4,2,4,0,2,1,39.587305,1,0
945,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,3,4,1,5,6,2,1,2,2,3,99.459796,1,1
828,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,1,0,0,...,0,0,0,1,0,1,0,1,0,0,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,2,1,5,3,6,6,3,3,1,2,67.195341,1,1
649,1,0,0,0,0,0,0,0,0,1,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,...,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,4,3,1,2,4,0,1,1,0,6,35.031129,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
289,0,0,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,1,0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,4,0,3,3,1,0,3,0,9,4,66.602950,1,0
913,0,1,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,4,4,3,9,5,3,5,0,2,2,57.497019,1,1
483,0,0,0,0,0,1,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,...,0,1,1,0,0,0,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,1,0,0,5,3,3,1,1,2,3,2,0,3,40.966792,0,0
276,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,4,2,0,1,5,3,5,5,4,1,107.690607,1,1


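- columns_standardize : clinical variables (values between 0 and 9), rescaled to the range 0~1 with MinMaxScaler
- columns_leave : gene mutation status + treatment status; already encoded as 0 and 1, so no scaling is needed.
- DataFrameMapper picks the requested columns out of a pandas DataFrame and returns them as an array.
- Columns paired with a scaler (MinMaxScaler() here; StandardScaler() is left commented out) are scaled first, while columns paired with None are passed through unchanged.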
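
The notebook's active scaler is MinMaxScaler, and it is fitted on the training split only and then reused for the other splits. The sketch below shows that idea on a single column in plain Python; `fit_min_max` and `apply_min_max` are hypothetical helpers and the numbers are made up.

```python
def fit_min_max(train_values):
    """Learn the min and max from the training column only."""
    return min(train_values), max(train_values)


def apply_min_max(values, lo, hi):
    """Rescale values relative to the training range: (x - lo) / (hi - lo)."""
    if hi == lo:
        return [0.0 for _ in values]  # constant column: nothing to scale
    return [(v - lo) / (hi - lo) for v in values]


lo, hi = fit_min_max([0, 3, 9, 6])  # toy training column
scaled_val = apply_min_max([0, 4.5, 9], lo, hi)
print(scaled_val)  # [0.0, 0.5, 1.0]
```

Reusing the training min and max for validation and test data keeps all splits on the same scale and avoids leaking information from held-out rows into preprocessing.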

In [158]:
columns_standardize = ['Var' + str(i) for i in range(1,11)]
columns_leave = ['G' + str(i) for i in range(1,301)]
columns_leave += ['Treatment']

# standardize = [([col], StandardScaler()) for col in columns_standardize]
standardize = [([col], MinMaxScaler()) for col in columns_standardize]

leave = [(col, None) for col in columns_leave]

x_mapper = DataFrameMapper(leave + standardize)

- The DataFrameMapper built above converts the x (input) columns of each DataFrame into an array the model can train on.



In [159]:
x_train = x_mapper.fit_transform(dataset_train).astype('float32')
x_val = x_mapper.transform(dataset_val).astype('float32')
x_test = x_mapper.transform(dataset_test).astype('float32')
x_for_hr = x_mapper.transform(dataset_for_hr).astype('float32')

- Extract the Y (output) data, time (survival time) and event (death indicator), from the DataFrame.
- Build the input-output pair val used for validation.

In [160]:
get_target = lambda df: (df['time'].values, df['event'].values)
y_train = get_target(dataset_train)
y_val = get_target(dataset_val)

durations_test, events_test = get_target(dataset_test)
val = x_val, y_val

Function make_net : builds and returns a network
- The number of input and output nodes, the number of hidden layers, and the number of nodes per hidden layer are configurable

In [161]:
def make_net(in_features, out_features, hidden, nodes):
  if hidden == 1:
    network =  torch.nn.Sequential(
      torch.nn.Linear(in_features, nodes),
      torch.nn.ReLU(),
      torch.nn.BatchNorm1d(nodes),
      torch.nn.Dropout(0.1),
            
      torch.nn.Linear(nodes, out_features)
    )
  elif hidden == 2:
    network =  torch.nn.Sequential(
      torch.nn.Linear(in_features, nodes),
      torch.nn.ReLU(),
      torch.nn.BatchNorm1d(nodes),
      torch.nn.Dropout(0.1),

      torch.nn.Linear(nodes, int(nodes/2)),
      torch.nn.ReLU(),
      torch.nn.BatchNorm1d(int(nodes/2)),
      torch.nn.Dropout(0.1),
            
      torch.nn.Linear(int(nodes/2), out_features)
    )
  elif hidden == 3:
    network =  torch.nn.Sequential(
      torch.nn.Linear(in_features, nodes),
      torch.nn.ReLU(),
      torch.nn.BatchNorm1d(nodes),
      torch.nn.Dropout(0.1),

      torch.nn.Linear(nodes, int(nodes/2)),
      torch.nn.ReLU(),
      torch.nn.BatchNorm1d(int(nodes/2)),
      torch.nn.Dropout(0.1),

      torch.nn.Linear(int(nodes/2), int(nodes/4)),
      torch.nn.ReLU(),
      torch.nn.BatchNorm1d(int(nodes/4)),
      torch.nn.Dropout(0.1),
            
      torch.nn.Linear(int(nodes/4), out_features)
    )
  elif hidden == 4:
    network =  torch.nn.Sequential(
      torch.nn.Linear(in_features, nodes),
      torch.nn.ReLU(),
      torch.nn.BatchNorm1d(nodes),
      torch.nn.Dropout(0.1),

      torch.nn.Linear(nodes, nodes),
      torch.nn.ReLU(),
      torch.nn.BatchNorm1d(nodes),
      torch.nn.Dropout(0.1),

      torch.nn.Linear(nodes, nodes),
      torch.nn.ReLU(),
      torch.nn.BatchNorm1d(nodes),
      torch.nn.Dropout(0.1),
      
      torch.nn.Linear(nodes, nodes),
      torch.nn.ReLU(),
      torch.nn.BatchNorm1d(nodes),
      torch.nn.Dropout(0.1),

      torch.nn.Linear(nodes, out_features)
    )
  return network
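
The hidden-layer widths that make_net wires together follow a simple pattern, summarised in the sketch below. `hidden_widths` is a hypothetical helper written for this note; it only returns the widths and does not build a torch network. It also makes explicit that make_net defines no network for other values of hidden.

```python
def hidden_widths(nodes, hidden):
    """Widths of the hidden layers that make_net builds for hidden in 1..4."""
    if hidden in (1, 2, 3):
        # Each further layer halves the previous width: nodes, nodes/2, nodes/4.
        return [nodes // (2 ** depth) for depth in range(hidden)]
    if hidden == 4:
        # The four-layer variant keeps every layer at full width.
        return [nodes] * 4
    raise ValueError("make_net only defines networks for hidden in 1..4")


print(hidden_widths(1024, 3))  # [1024, 512, 256]
print(hidden_widths(2048, 4))  # [2048, 2048, 2048, 2048]
```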

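- in_features : number of input features (x_train.shape[1] : 311 = 300 (genes) + 10 (clinical variables) + 1 (treatment))
- out_features : number of output nodes

- hidden_layers : list of hidden-layer counts to test
- number_nodes : list of hidden-layer widths to test
- learning_rates : list of learning rates to test (0 means use the rate suggested by lr_finder)
- brier_scores : list to which the scores of each run are appended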
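
The three lists define a full grid, and every combination is trained once. As a small sketch of the enumeration only (no training), the nested loops in the next cell are equivalent to iterating over `itertools.product`; the values are the ones used in that cell.

```python
from itertools import product

hidden_layers = [2, 3]
number_nodes = [1024, 2048, 3072]
learning_rates = [0.0001, 0.001, 0]  # 0 stands for "use the learning-rate finder"

grid = list(product(hidden_layers, number_nodes, learning_rates))
for count, (hidden, nodes, lr) in enumerate(grid, start=1):
    label = "lr_finder" if lr == 0 else lr
    print(f"{count} / {len(grid)}: hidden={hidden}, nodes={nodes}, lr={label}")
```

The grid has 2 x 3 x 3 = 18 entries, which matches the "1 / 18" progress line printed by the training cell.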

In [162]:
in_features = x_train.shape[1]
out_features = 1

hidden_layers = [2,3]
number_nodes = [1024, 2048, 3072]
learning_rates = [0.0001, 0.001, 0]
brier_scores = []

total_num = len(hidden_layers) * len(number_nodes) * len(learning_rates)
count = 1

for i in hidden_layers:
  for j in number_nodes:
    for k in learning_rates:
      print(count, '/' , total_num)
      net = make_net(in_features, out_features, i, j)
      model = CoxPH(net, tt.optim.Adam)
      batch_size = 559

      if k == 0:
        lrfinder = model.lr_finder(x_train, y_train, batch_size, tolerance = 10)
        model.optimizer.set_lr(lrfinder.get_best_lr())
      else:
        model.optimizer.set_lr(k)
      
      epochs = 512
      callbacks = [tt.callbacks.EarlyStopping()]
      verbose = True

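      # %time is a line magic and must prefix the statement it times;
      # on a line of its own it only times an empty statement.
      %time model.fit(x_train, y_train, batch_size, epochs, callbacks, verbose, val_data=val, val_batch_size=batch_size)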
      _ = model.compute_baseline_hazards()
      surv = model.predict_surv_df(x_test)

      count += 1

      # calculate ratio
      log_partial_hazard = model.predict(x_for_hr)
      partial_hazard = [np.exp(lph) for lph in log_partial_hazard]

      treat_hr = []
      # ratio with treated and untreated
      for g in range(300):
        treat_hr.append([partial_hazard[g+302]/partial_hazard[g+2],'G' + str(g+1)])
      treat_hr.sort()

      # evaluation
      ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
      time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)
      score = ev.integrated_brier_score(time_grid)
      nbll = ev.integrated_nbll(time_grid)

      if k == 0:
        brier_scores.append([score, nbll, i, j, lrfinder.get_best_lr(), treat_hr[:10]])
      else:
        brier_scores.append([score, nbll, i, j, k, treat_hr[:10]])

1 / 18
CPU times: user 2 µs, sys: 1e+03 ns, total: 3 µs
Wall time: 5.48 µs
0:	[0s / 0s],		train_loss: 5.4531,	val_loss: 3.9831
1:	[0s / 0s],		train_loss: 5.3562,	val_loss: 3.9826
2:	[0s / 0s],		train_loss: 5.2451,	val_loss: 3.9819
3:	[0s / 0s],		train_loss: 5.1681,	val_loss: 3.9811
4:	[0s / 0s],		train_loss: 5.1005,	val_loss: 3.9801
5:	[0s / 0s],		train_loss: 5.0247,	val_loss: 3.9787
6:	[0s / 0s],		train_loss: 4.9494,	val_loss: 3.9770
7:	[0s / 0s],		train_loss: 4.8954,	val_loss: 3.9751
8:	[0s / 0s],		train_loss: 4.8589,	val_loss: 3.9729
9:	[0s / 0s],		train_loss: 4.8005,	val_loss: 3.9704
10:	[0s / 0s],		train_loss: 4.7482,	val_loss: 3.9677
11:	[0s / 0s],		train_loss: 4.7116,	val_loss: 3.9646
12:	[0s / 1s],		train_loss: 4.6622,	val_loss: 3.9612
13:	[0s / 1s],		train_loss: 4.6208,	val_loss: 3.9575
14:	[0s / 1s],		train_loss: 4.5727,	val_loss: 3.9535
15:	[0s / 1s],		train_loss: 4.5313,	val_loss: 3.9492
16:	[0s / 1s],		train_loss: 4.4878,	val_loss: 3.9445
17:	[0s / 1s],		train_loss: 4.4558
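
The treat_hr step in the loop compares, for each gene, the treated probe row against the untreated one. Because the network predicts a log partial hazard, the ratio of the two exponentiated predictions equals the exponential of their difference. The sketch below shows this with made-up log partial hazards; the helper name and the numbers are assumptions for illustration.

```python
import math


def treatment_hazard_ratios(log_partial_hazard, n_genes):
    """Hazard ratio treated / untreated for each gene.

    Assumes the probe layout used above: rows 2..n_genes+1 are untreated,
    rows n_genes+2..2*n_genes+1 are treated, one mutated gene per row.
    """
    ratios = []
    for g in range(n_genes):
        untreated = log_partial_hazard[g + 2]
        treated = log_partial_hazard[g + 2 + n_genes]
        ratios.append((math.exp(treated - untreated), f"G{g + 1}"))
    return sorted(ratios)  # lowest treated-to-untreated ratio first


# Made-up predictions for a 2-gene probe set (6 rows).
toy_lph = [0.0, -0.1, 0.5, 0.2, 0.0, 0.1]
ranked = treatment_hazard_ratios(toy_lph, n_genes=2)
```

A ratio below 1 means the predicted hazard is lower with treatment for carriers of that mutation, so the first entries of the sorted list are the genes where the model predicts the largest treatment effect.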

- Sort by brier_score, smallest first


In [168]:
partial_hazard

[array([2.01333], dtype=float32),
 array([1.40585], dtype=float32),
 array([2.0170047], dtype=float32),
 array([1.8114334], dtype=float32),
 array([2.4194937], dtype=float32),
 array([1.8435086], dtype=float32),
 array([2.1422], dtype=float32),
 array([1.7707067], dtype=float32),
 array([2.694901], dtype=float32),
 array([2.3923402], dtype=float32),
 array([1.7425401], dtype=float32),
 array([2.2913945], dtype=float32),
 array([1.977036], dtype=float32),
 array([2.8720875], dtype=float32),
 array([1.8129772], dtype=float32),
 array([2.8182337], dtype=float32),
 array([1.8575889], dtype=float32),
 array([1.9845912], dtype=float32),
 array([1.7574089], dtype=float32),
 array([2.4805834], dtype=float32),
 array([1.9884703], dtype=float32),
 array([2.3683498], dtype=float32),
 array([2.0749445], dtype=float32),
 array([2.1360443], dtype=float32),
 array([2.20998], dtype=float32),
 array([2.2923267], dtype=float32),
 array([2.4124298], dtype=float32),
 array([1.8190355], dtype=float32),
 ar

In [163]:
treat_hr

[[array([0.5948483], dtype=float32), 'G108'],
 [array([0.60413426], dtype=float32), 'G298'],
 [array([0.60743606], dtype=float32), 'G230'],
 [array([0.6091651], dtype=float32), 'G253'],
 [array([0.6098423], dtype=float32), 'G48'],
 [array([0.6102403], dtype=float32), 'G21'],
 [array([0.61127937], dtype=float32), 'G81'],
 [array([0.6139316], dtype=float32), 'G119'],
 [array([0.61489636], dtype=float32), 'G209'],
 [array([0.61598635], dtype=float32), 'G127'],
 [array([0.61775553], dtype=float32), 'G6'],
 [array([0.619035], dtype=float32), 'G266'],
 [array([0.6192105], dtype=float32), 'G198'],
 [array([0.6206895], dtype=float32), 'G264'],
 [array([0.622727], dtype=float32), 'G87'],
 [array([0.6239251], dtype=float32), 'G238'],
 [array([0.6268986], dtype=float32), 'G279'],
 [array([0.6279714], dtype=float32), 'G62'],
 [array([0.6284381], dtype=float32), 'G148'],
 [array([0.6300542], dtype=float32), 'G19'],
 [array([0.6301491], dtype=float32), 'G71'],
 [array([0.63032365], dtype=float32), '

In [164]:
brier_scores.sort()
selected_genes = []
for i in range(10):
  selected_genes.append(brier_scores[0][5][i][1])

In [165]:
gene_count= {}
for i in brier_scores:
  lst = []
  for j in i[5]:
    lst.append(j[1])
  print(lst)
  for k in lst:
    if k in gene_count.keys():
      gene_count[k] += 1
    else:
      gene_count[k] = 1
sorted(gene_count.items(), key=lambda x: x[1], reverse=True)

['G108', 'G298', 'G230', 'G253', 'G48', 'G21', 'G81', 'G119', 'G209', 'G127']
['G35', 'G193', 'G136', 'G149', 'G137', 'G94', 'G186', 'G25', 'G120', 'G152']
['G298', 'G136', 'G193', 'G140', 'G35', 'G131', 'G161', 'G214', 'G121', 'G48']
['G273', 'G121', 'G162', 'G35', 'G194', 'G136', 'G299', 'G59', 'G191', 'G34']
['G161', 'G28', 'G35', 'G248', 'G24', 'G202', 'G38', 'G54', 'G214', 'G179']
['G35', 'G137', 'G88', 'G132', 'G98', 'G34', 'G47', 'G265', 'G220', 'G251']
['G268', 'G147', 'G226', 'G197', 'G118', 'G15', 'G119', 'G258', 'G143', 'G204']
['G35', 'G152', 'G123', 'G54', 'G141', 'G145', 'G204', 'G41', 'G191', 'G140']
['G35', 'G193', 'G9', 'G99', 'G164', 'G114', 'G168', 'G124', 'G45', 'G271']
['G156', 'G195', 'G111', 'G1', 'G149', 'G92', 'G152', 'G40', 'G19', 'G209']
['G88', 'G265', 'G82', 'G224', 'G150', 'G11', 'G107', 'G70', 'G297', 'G3']
['G35', 'G48', 'G127', 'G30', 'G296', 'G33', 'G98', 'G220', 'G256', 'G94']
['G194', 'G226', 'G136', 'G186', 'G203', 'G83', 'G130', 'G269', 'G193', 'G9

[('G35', 11),
 ('G136', 5),
 ('G193', 4),
 ('G48', 3),
 ('G149', 3),
 ('G152', 3),
 ('G194', 3),
 ('G34', 3),
 ('G88', 3),
 ('G41', 3),
 ('G298', 2),
 ('G119', 2),
 ('G209', 2),
 ('G127', 2),
 ('G137', 2),
 ('G94', 2),
 ('G186', 2),
 ('G140', 2),
 ('G161', 2),
 ('G214', 2),
 ('G121', 2),
 ('G59', 2),
 ('G191', 2),
 ('G24', 2),
 ('G54', 2),
 ('G98', 2),
 ('G265', 2),
 ('G220', 2),
 ('G226', 2),
 ('G197', 2),
 ('G204', 2),
 ('G9', 2),
 ('G99', 2),
 ('G164', 2),
 ('G45', 2),
 ('G1', 2),
 ('G19', 2),
 ('G70', 2),
 ('G3', 2),
 ('G236', 2),
 ('G69', 2),
 ('G108', 1),
 ('G230', 1),
 ('G253', 1),
 ('G21', 1),
 ('G81', 1),
 ('G25', 1),
 ('G120', 1),
 ('G131', 1),
 ('G273', 1),
 ('G162', 1),
 ('G299', 1),
 ('G28', 1),
 ('G248', 1),
 ('G202', 1),
 ('G38', 1),
 ('G179', 1),
 ('G132', 1),
 ('G47', 1),
 ('G251', 1),
 ('G268', 1),
 ('G147', 1),
 ('G118', 1),
 ('G15', 1),
 ('G258', 1),
 ('G143', 1),
 ('G123', 1),
 ('G141', 1),
 ('G145', 1),
 ('G114', 1),
 ('G168', 1),
 ('G124', 1),
 ('G271', 1),
 ('G1
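
The tally above counts how often each gene appears in the top-10 lists across configurations. The same count can be written more compactly with `collections.Counter`; the sketch below uses a shortened, made-up stand-in for the per-run lists.

```python
from collections import Counter

# Toy stand-in for the top-10 lists collected per configuration.
top_genes_per_run = [
    ["G35", "G193", "G136"],
    ["G298", "G136", "G35"],
    ["G35", "G137", "G88"],
]

gene_count = Counter(gene for run in top_genes_per_run for gene in run)
print(gene_count.most_common(2))  # [('G35', 3), ('G136', 2)]
```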

- Print the case with the best brier_score

In [166]:
print("Brier Score :", brier_scores[0][0])
print("NBLL :", brier_scores[0][1])
print("Hidden Layers :", brier_scores[0][2])
print("Number of Nodes :", brier_scores[0][3])
print("Learning Rate :", brier_scores[0][4])
print("Selection :", selected_genes)

Brier Score : 0.06997940154243748
NBLL : 0.36066973841494526
Hidden Layers : 3
Number of Nodes : 3072
Learning Rate : 0.0017886495290574435
Selection : ['G108', 'G298', 'G230', 'G253', 'G48', 'G21', 'G81', 'G119', 'G209', 'G127']


Found a problem: the sort order breaks at NaN values
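
list.sort() orders elements with `<`, and every comparison against NaN is False, so a row with a NaN score can land anywhere and split the ordering. One possible workaround, sketched here with made-up rows, is to sort on a key that pushes NaN scores to the end.

```python
import math

# Made-up [brier_score, hidden, nodes] rows; one run produced NaN.
rows = [
    [0.0735, 2, 1024],
    [float("nan"), 3, 1024],
    [0.0700, 3, 3072],
    [0.0721, 2, 2048],
]

# Key: (is_nan, score). False sorts before True, so NaN rows go last and
# NaN itself never takes part in a comparison.
rows.sort(key=lambda r: (math.isnan(r[0]), 0.0 if math.isnan(r[0]) else r[0]))
print([r[0] for r in rows[:3]])  # [0.07, 0.0721, 0.0735]
```

Sorting on an explicit key also avoids comparing the remaining list elements (such as the nested gene lists) when two scores happen to be equal.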

In [167]:
print("brier_score, nbll, hidden layer, number of nodes, learning rate")
for i in brier_scores:
  print(i[:-1])

brier_score, nbll, hidden layer, number of nodes, learning rate
[0.06997940154243748, 0.36066973841494526, 3, 3072, 0.0017886495290574435]
[0.07004986796887101, 0.3884178216528487, 2, 2048, 0.001484968262254472]
[0.07029672273891087, 0.37242316007949877, 2, 3072, 0.0012328467394420717]
[0.07057359261796259, 0.37780080250924014, 2, 3072, 0.001]
[0.07069836954118826, 0.3534956857454184, 3, 3072, 0.001]
[0.07211909013474464, 0.3472352785176396, 2, 2048, 0.001]
[0.07269280002298988, 0.24084021092130467, 2, 1024, 0.007924828983539215]
[0.07274546343193349, 0.27978050077295014, 3, 1024, 0.013848863713938809]
[0.07291973325399, 0.37927466825847966, 3, 2048, 0.001]
[0.07351309593542278, 0.445738634092458, 2, 1024, 0.0001]
[0.07365977158766675, 0.393980213863942, 2, 3072, 0.0001]
[0.07385855964788601, 0.38774490269022155, 3, 2048, 0.0010235310218990308]
[0.07450338079845782, 0.4676717678169885, 2, 2048, 0.0001]
[0.07457627173872768, 0.39792369346995443, 2, 1024, 0.001]
[0.0748148992689731, 0.37420271