<a href="https://colab.research.google.com/github/mostly-sunny/digital-health-hackathon/blob/main/2.%20coxph_find_best_network(hyeonbin_ver).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Pycox - CoxPH Model
- Network : Try every combination of network shape and learning rate in a grid, and keep the one with the lowest integrated Brier score
- Input Variables : G1 ~ G300, Var1 ~ Var10, Treatment
- Output Variables : time, event
- Scaler : MinMaxScaler -> Var1 ~ Var10



In [5]:
%pip install pycox sklearn-pandas



In [6]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import MinMaxScaler
from sklearn_pandas import DataFrameMapper
import pandas as pd

import torch
import torchtuples as tt

from pycox.models import CoxPH
from pycox.evaluation import EvalSurv

In [7]:
np.random.seed(123456)
_ = torch.manual_seed(123456)

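- all-in-one.csv : its columns hold gene mutation status (G1 ~ G300), clinical variables (Var1 ~ Var10), survival time (time), death indicator (event) and treatment status (Treatment).
- test-data-treat-and-untreat.csv : 602 rows with the same columns as all-in-one.csv (n = row index).
  - (row 0) : all gene mutations 0, treatment 0
  - (row 1) : all gene mutations 0, treatment 1
  - (rows 2~301) : only gene G(n-1) is 1, treatment 0
  - (rows 302~601) : only gene G(n-301) is 1, treatment 1
- Both files are read with pandas' read_csv function, which loads a CSV file into a DataFrame.
- A DataFrame is pandas' data type for a table.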
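The row layout of test-data-treat-and-untreat.csv can be reproduced for the gene and treatment columns with a few lines of pandas. This is only a sketch: the helper name `make_hr_grid` is hypothetical, and the clinical columns Var1 ~ Var10 are left out because the description above does not say what values they take in this file.

```python
import numpy as np
import pandas as pd

def make_hr_grid(n_genes=300):
  """Hypothetical helper: gene + Treatment columns of the 602-row hazard-ratio grid."""
  gene_cols = ['G' + str(i) for i in range(1, n_genes + 1)]
  # rows 0-1: no mutation; then one single-gene-mutation block per treatment arm
  genes = np.vstack([np.zeros((2, n_genes)), np.eye(n_genes), np.eye(n_genes)]).astype(int)
  treatment = np.array([0, 1] + [0] * n_genes + [1] * n_genes)
  grid = pd.DataFrame(genes, columns=gene_cols)
  grid['Treatment'] = treatment
  return grid

grid = make_hr_grid()
print(grid.shape)  # (602, 301)
```

Row n (2 ≤ n ≤ 301) and row n + 300 then differ only in Treatment, which is the pair the hazard-ratio loop further down compares.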

In [8]:
dataset = pd.read_csv('/content/all-in-one.csv')
dataset_for_hr = pd.read_csv('/content/test-data-treat-and-untreat.csv')

- From the dataset loaded above, 20% is sampled for validation (_val).
- From the remaining 80%, another 20% is sampled for testing (_test); what is left (about 64% of the data) is the training set.

In [9]:
dataset_val = dataset.sample(frac=0.2)
dataset_train = dataset.drop(dataset_val.index)
dataset_test = dataset_train.sample(frac=0.2)
dataset_train = dataset_train.drop(dataset_test.index)
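The two-stage split leaves about 20% of the rows for validation, 0.8 × 0.2 = 16% for testing and 0.8 × 0.8 = 64% for training. A quick check on a made-up 1,000-row DataFrame (not the hackathon data):

```python
import numpy as np
import pandas as pd

np.random.seed(123456)  # DataFrame.sample uses NumPy's global RNG when random_state is None
toy = pd.DataFrame({'x': range(1000)})

toy_val = toy.sample(frac=0.2)
toy_train = toy.drop(toy_val.index)
toy_test = toy_train.sample(frac=0.2)
toy_train = toy_train.drop(toy_test.index)

print(len(toy_train), len(toy_val), len(toy_test))  # 640 200 160
```

Because each sample is drawn from rows that were not yet used and then dropped, the three splits never share an index.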

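- columns_standardize : clinical variables Var1 ~ Var10, whose values lie between 0 and 9, so they are rescaled. The active line uses MinMaxScaler (each column mapped to [0, 1]); the StandardScaler version is commented out.
- columns_leave : gene mutation flags + treatment flag. They are already 0 or 1, so no scaling is needed.
- DataFrameMapper takes the listed columns from a pandas DataFrame and stacks them into a single NumPy array.
- Columns paired with a scaler (here MinMaxScaler()) are scaled first; columns paired with None are copied unchanged.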

In [10]:
columns_standardize = ['Var' + str(i) for i in range(1,11)]
columns_leave = ['G' + str(i) for i in range(1,301)]
columns_leave += ['Treatment']

# standardize = [([col], StandardScaler()) for col in columns_standardize]
standardize = [([col], MinMaxScaler()) for col in columns_standardize]

leave = [(col, None) for col in columns_leave]

x_mapper = DataFrameMapper(leave + standardize)
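DataFrameMapper concatenates its outputs in the order of the feature list, so with `leave + standardize` the model input holds the 300 gene flags first, then Treatment, then the 10 scaled clinical variables. A plain-Python sketch of that index layout (no sklearn_pandas needed):

```python
columns_leave = ['G' + str(i) for i in range(1, 301)] + ['Treatment']
columns_standardize = ['Var' + str(i) for i in range(1, 11)]

# column order of the mapper output: passthrough columns first, then the scaled ones
feature_order = columns_leave + columns_standardize
print(len(feature_order))                # 311
print(feature_order.index('Treatment'))  # 300
print(feature_order.index('Var1'))       # 301
```

This is the layout behind `in_features = 311` below, and it tells you which columns of x_train to inspect when checking the scaling.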

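- The DataFrameMapper above converts the input (x) columns of each DataFrame into a float32 array the model can train on. It is fitted on the training split only (fit_transform); the validation, test and hazard-ratio sets reuse the training statistics (transform).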



In [11]:
x_train = x_mapper.fit_transform(dataset_train).astype('float32')
x_val = x_mapper.transform(dataset_val).astype('float32')
x_test = x_mapper.transform(dataset_test).astype('float32')
x_for_hr = x_mapper.transform(dataset_for_hr).astype('float32')
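MinMaxScaler maps each column with x' = (x - min) / (max - min), where min and max come from the data passed to fit. Because the mapper is fitted on dataset_train only, validation or test values outside the training range land outside [0, 1]. A NumPy sketch of the same arithmetic with made-up numbers:

```python
import numpy as np

train_col = np.array([1.0, 3.0, 5.0, 9.0])  # hypothetical Var values in the training split
test_col = np.array([0.0, 4.0, 9.0])        # hypothetical values in another split

col_min, col_max = train_col.min(), train_col.max()  # what fit / fit_transform learns

def min_max(x):
  # the formula MinMaxScaler applies with its default feature_range=(0, 1)
  return (x - col_min) / (col_max - col_min)

print(min_max(train_col))  # -> 0.0, 0.25, 0.5, 1.0
print(min_max(test_col))   # -> -0.125, 0.375, 1.0
```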

- Extract the output (y) data from each DataFrame: time (survival time) and event (death indicator).
- Build val, the input-output pair used for validation.

In [12]:
get_target = lambda df: (df['time'].values, df['event'].values)
y_train = get_target(dataset_train)
y_val = get_target(dataset_val)

durations_test, events_test = get_target(dataset_test)
val = x_val, y_val
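pycox takes the target as a tuple (durations, events) and validation data as (x, target), which is what get_target and val build. A toy illustration with made-up values (event = 1: death observed; event = 0: censored, i.e. still alive at the last follow-up):

```python
import numpy as np
import pandas as pd

toy = pd.DataFrame({'time': [5.0, 12.0, 7.5], 'event': [1, 0, 1]})
get_target = lambda df: (df['time'].values, df['event'].values)

durations, events = get_target(toy)
x_toy = np.zeros((len(toy), 311), dtype='float32')  # placeholder inputs with the real feature count
val_toy = x_toy, (durations, events)                # same nesting as `val = x_val, y_val`
print(durations, events)
```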

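Function make_net : builds and returns the network
- The number of input and output nodes, the number of hidden layers and the number of nodes per hidden layer are configurable; each hidden layer is Linear -> ReLU -> BatchNorm1d -> Dropout(0.1)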

In [13]:
def make_net(in_features, out_features, hidden, nodes, dropout=0.1):
  # `hidden` blocks of Linear -> ReLU -> BatchNorm1d -> Dropout,
  # followed by a linear output layer (the log partial hazard for CoxPH)
  if hidden < 1:
    raise ValueError("hidden must be at least 1")
  layers = []
  width = in_features
  for _ in range(hidden):
    layers += [
      torch.nn.Linear(width, nodes),
      torch.nn.ReLU(),
      torch.nn.BatchNorm1d(nodes),
      torch.nn.Dropout(dropout),
    ]
    width = nodes
  layers.append(torch.nn.Linear(nodes, out_features))
  return torch.nn.Sequential(*layers)
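Since the grid goes up to 4096 nodes, model size grows quickly. The trainable-parameter count of the network make_net builds can be worked out by hand: each Linear(a, b) has a·b + b parameters and each BatchNorm1d(b) has 2b (weight and bias; its running statistics are buffers, not parameters). The helper below is hypothetical, not part of the notebook; `sum(p.numel() for p in net.parameters())` on a real network should give the same number.

```python
def count_params(in_features, out_features, hidden, nodes):
  """Trainable parameters of make_net(in_features, out_features, hidden, nodes)."""
  total = 0
  width = in_features
  for _ in range(hidden):
    total += width * nodes + nodes  # Linear(width, nodes)
    total += 2 * nodes              # BatchNorm1d(nodes): weight + bias
    width = nodes
  total += nodes * out_features + out_features  # output Linear
  return total

print(count_params(311, 1, 1, 1792))  # 564481
print(count_params(311, 1, 3, 4096))  # 34869249
```

The largest grid point (3 hidden layers, 4096 nodes) therefore has about 35 million parameters.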

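- in_features : number of input features (x_train.shape[1] = 311 = 300 genes + 10 clinical variables + 1 treatment)
- out_features : number of output nodes (1, the log partial hazard used by CoxPH)

- hidden_layers : list of hidden-layer counts to try
- number_nodes : nodes per hidden layer to try (256 to 4096 in steps of 256)
- learning_rates : learning rates to try; 0 means the rate is picked with model.lr_finder
- brier_scores : one entry per run, [integrated Brier score, hidden layers, nodes, learning rate, top 10 genes by treated/untreated hazard ratio]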
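The selection metric is the integrated Brier score. At time t, the Brier score is the mean squared difference between the observed status 1{T_i > t} and the predicted survival S(t | x_i); integrating over the time grid and dividing by its length gives one number, and lower is better. The sketch below assumes no censoring, so it omits the inverse-probability-of-censoring weights EvalSurv applies with censor_surv='km', and it uses the trapezoidal rule. Treat it as an illustration, not a replacement for ev.integrated_brier_score.

```python
import numpy as np

def simple_integrated_brier(surv, durations, time_grid):
  """Uncensored integrated Brier score (illustration only, no IPCW).

  surv: shape (len(time_grid), n_samples), S(t | x_i) per grid time
        (rows = times, like the DataFrame from predict_surv_df).
  durations: observed event times, shape (n_samples,).
  """
  alive = (durations[None, :] > time_grid[:, None]).astype(float)  # 1{T_i > t}
  bs = ((alive - surv) ** 2).mean(axis=1)                          # Brier score at each t
  area = np.sum((bs[1:] + bs[:-1]) / 2 * np.diff(time_grid))        # trapezoidal rule
  return area / (time_grid[-1] - time_grid[0])

time_grid = np.linspace(0.0, 10.0, 101)
durations = np.array([2.0, 5.0, 8.0])
coin_flip = np.full((len(time_grid), 3), 0.5)  # predicts S = 0.5 everywhere
print(simple_integrated_brier(coin_flip, durations, time_grid))  # ~0.25
```

A constant prediction of 0.5 scores 0.25 at every time point, which is a useful reference when reading the scores of about 0.055 reported below.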

In [14]:
in_features = x_train.shape[1]
out_features = 1

hidden_layers = [1,2,3]
number_nodes = range(256, 4100, 256)
learning_rates = [0.0001, 0.001, 0.01, 0]
brier_scores = []

for i in hidden_layers:
  for j in number_nodes:
    for k in learning_rates:
      net = make_net(in_features, out_features, i, j)
      model = CoxPH(net, tt.optim.Adam)
      batch_size = 639

      if k == 0:
        lrfinder = model.lr_finder(x_train, y_train, batch_size, tolerance = 10)
        model.optimizer.set_lr(lrfinder.get_best_lr())
      else:
        model.optimizer.set_lr(k)
      
      epochs = 512
      callbacks = [tt.callbacks.EarlyStopping()]
      verbose = True

      %%time
      model.fit(x_train, y_train, batch_size, epochs, callbacks, verbose, val_data=val, val_batch_size=batch_size)
      _ = model.compute_baseline_hazards()
      surv = model.predict_surv_df(x_test)
      
      # calculate ratio
      log_partial_hazard = model.predict(x_for_hr)
      partial_hazard = [np.exp(lph) for lph in log_partial_hazard]

      treat_hr = []
      # ratio with treated and untreated
      for g in range(300):
        treat_hr.append([partial_hazard[g+302]/partial_hazard[g+2],'G' + str(g+1)])
      treat_hr.sort()

      # evaluation
      ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
      time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)
      score = ev.integrated_brier_score(time_grid)

      if k == 0:
        brier_scores.append([score, i, j, lrfinder.get_best_lr(), treat_hr[:10]])
      else:
        brier_scores.append([score, i, j, k, treat_hr[:10]])

CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 6.2 µs
0:	[0s / 0s],		train_loss: 5.6224,	val_loss: 4.2873
1:	[0s / 0s],		train_loss: 5.5912,	val_loss: 4.2875
2:	[0s / 0s],		train_loss: 5.5557,	val_loss: 4.2876
3:	[0s / 0s],		train_loss: 5.5534,	val_loss: 4.2877
4:	[0s / 0s],		train_loss: 5.5366,	val_loss: 4.2878
5:	[0s / 0s],		train_loss: 5.5085,	val_loss: 4.2879
6:	[0s / 0s],		train_loss: 5.4929,	val_loss: 4.2880
7:	[0s / 0s],		train_loss: 5.4754,	val_loss: 4.2881
8:	[0s / 0s],		train_loss: 5.4642,	val_loss: 4.2882
9:	[0s / 0s],		train_loss: 5.4262,	val_loss: 4.2883
10:	[0s / 0s],		train_loss: 5.4152,	val_loss: 4.2883




Streaming output truncated to the last 5000 lines.
2:	[0s / 0s],		train_loss: 5.3697,	val_loss: 4.2653
3:	[0s / 0s],		train_loss: 5.5152,	val_loss: 4.2517
4:	[0s / 0s],		train_loss: 5.2735,	val_loss: 4.2259
5:	[0s / 0s],		train_loss: 4.8135,	val_loss: 4.2041
6:	[0s / 0s],		train_loss: 4.5464,	val_loss: 4.1926
7:	[0s / 0s],		train_loss: 4.3697,	val_loss: 4.1876
8:	[0s / 0s],		train_loss: 4.1821,	val_loss: 4.1825
9:	[0s / 0s],		train_loss: 4.0144,	val_loss: 4.1745
10:	[0s / 1s],		train_loss: 3.9163,	val_loss: 4.1602
11:	[0s / 1s],		train_loss: 3.9166,	val_loss: 4.1372
12:	[0s / 1s],		train_loss: 3.9571,	val_loss: 4.1244
13:	[0s / 1s],		train_loss: 3.9753,	val_loss: 4.1249
14:	[0s / 1s],		train_loss: 3.9521,	val_loss: 4.1382
15:	[0s / 1s],		train_loss: 3.9678,	val_loss: 4.1585
16:	[0s / 1s],		train_loss: 3.9367,	val_loss: 4.1851
17:	[0s / 1s],		train_loss: 3.9256,	val_loss: 4.2261
18:	[0s / 1s],		train_loss: 3.9218,	val_loss: 4.2788
19:	[0s / 1s],		train_loss: 3.9095,	val_loss: 4.3414
20:	
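Because the partial hazard is exp(log partial hazard), the treated/untreated ratio for a gene equals exp(h_treated - h_untreated), so the 300 ratios from the loop can also be computed in one vectorized step. A sketch with random stand-in values in place of model.predict(x_for_hr); the rows follow the layout of test-data-treat-and-untreat.csv (rows 2~301 untreated, rows 302~601 treated):

```python
import numpy as np

rng = np.random.default_rng(0)
log_partial_hazard = rng.normal(size=(602, 1))  # stand-in for model.predict(x_for_hr)

h = log_partial_hazard.ravel()
ratios = np.exp(h[302:602] - h[2:302])          # ratio for G1 ~ G300
genes = np.array(['G' + str(g + 1) for g in range(300)])

order = np.argsort(ratios)                      # smallest ratio = largest hazard reduction
top10 = list(zip(ratios[order[:10]], genes[order[:10]]))

# same numbers as the loop: partial_hazard[g+302] / partial_hazard[g+2]
loop = [np.exp(h[g + 302]) / np.exp(h[g + 2]) for g in range(300)]
print(np.allclose(ratios, loop))  # True
```

Apart from how exact ties are ordered, `top10` corresponds to `treat_hr[:10]` in the training loop.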

- Sort by brier_score, smallest (best) first


In [15]:
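# NaN scores (diverged runs) break a plain sort(); sort on the score only and push NaN to the end
brier_scores.sort(key=lambda r: (np.isnan(r[0]), r[0]))
selected_genes = [name for _, name in brier_scores[0][4]]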

In [19]:
from collections import Counter

# how often each gene appears in the top-10 lists across all configurations
gene_count = Counter()
for row in brier_scores:
  genes = [name for _, name in row[4]]
  print(genes)
  gene_count.update(genes)
sorted(gene_count.items(), key=lambda x: x[1], reverse=True)

['G191', 'G137', 'G35', 'G148', 'G204', 'G203', 'G88', 'G27', 'G130', 'G193']
['G59', 'G101', 'G193', 'G174', 'G120', 'G243', 'G180', 'G136', 'G298', 'G242']
['G159', 'G6', 'G101', 'G247', 'G26', 'G31', 'G119', 'G188', 'G218', 'G279']
['G297', 'G168', 'G196', 'G43', 'G86', 'G186', 'G62', 'G98', 'G223', 'G131']
['G149', 'G229', 'G158', 'G26', 'G120', 'G64', 'G156', 'G140', 'G6', 'G54']
['G48', 'G276', 'G195', 'G198', 'G87', 'G186', 'G171', 'G246', 'G120', 'G94']
['G217', 'G69', 'G196', 'G60', 'G213', 'G123', 'G132', 'G8', 'G35', 'G204']
['G134', 'G221', 'G137', 'G54', 'G27', 'G284', 'G87', 'G175', 'G63', 'G139']
['G161', 'G127', 'G188', 'G149', 'G223', 'G133', 'G197', 'G191', 'G35', 'G69']
['G286', 'G292', 'G217', 'G282', 'G145', 'G259', 'G123', 'G80', 'G284', 'G296']
['G47', 'G161', 'G185', 'G157', 'G227', 'G281', 'G282', 'G293', 'G265', 'G137']
['G242', 'G213', 'G262', 'G211', 'G121', 'G238', 'G109', 'G253', 'G195', 'G24']
['G28', 'G29', 'G119', 'G152', 'G73', 'G155', 'G247', 'G74', '

[('G35', 34),
 ('G29', 23),
 ('G193', 21),
 ('G161', 21),
 ('G88', 20),
 ('G137', 18),
 ('G243', 18),
 ('G149', 18),
 ('G69', 18),
 ('G148', 17),
 ('G27', 17),
 ('G298', 17),
 ('G259', 17),
 ('G34', 16),
 ('G204', 15),
 ('G196', 15),
 ('G43', 15),
 ('G242', 14),
 ('G31', 14),
 ('G158', 14),
 ('G47', 14),
 ('G191', 13),
 ('G180', 13),
 ('G223', 13),
 ('G140', 13),
 ('G214', 13),
 ('G101', 12),
 ('G136', 12),
 ('G54', 12),
 ('G48', 12),
 ('G60', 12),
 ('G123', 12),
 ('G292', 12),
 ('G262', 12),
 ('G253', 12),
 ('G264', 12),
 ('G15', 12),
 ('G164', 12),
 ('G1', 12),
 ('G85', 12),
 ('G120', 11),
 ('G159', 11),
 ('G6', 11),
 ('G247', 11),
 ('G186', 11),
 ('G211', 11),
 ('G238', 11),
 ('G28', 11),
 ('G82', 11),
 ('G131', 10),
 ('G195', 10),
 ('G284', 10),
 ('G127', 10),
 ('G227', 10),
 ('G112', 10),
 ('G116', 10),
 ('G72', 10),
 ('G67', 10),
 ('G234', 10),
 ('G119', 9),
 ('G62', 9),
 ('G94', 9),
 ('G221', 9),
 ('G197', 9),
 ('G157', 9),
 ('G73', 9),
 ('G97', 9),
 ('G147', 9),
 ('G142', 9),
 

- Print the configuration with the best (lowest) brier_score

In [17]:
print("Brier Score :", brier_scores[0][0])
print("Hidden Layers :", brier_scores[0][1])
print("Number of Nodes :", brier_scores[0][2])
print("Learning Rate :", brier_scores[0][3])
print("Selection :", selected_genes)

Brier Score : 0.05462315934099361
Hidden Layers : 1
Number of Nodes : 1792
Learning Rate : 0.01
Selection : ['G191', 'G137', 'G35', 'G148', 'G204', 'G203', 'G88', 'G27', 'G130', 'G193']


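Problem found: some runs end with a NaN Brier score, and because every comparison with NaN is False, the sort order breaks around those entries.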
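One way around this, sketched on made-up rows shaped like the [score, hidden, nodes, lr, genes] entries of brier_scores: sort on a key that puts NaN scores last and compares only the score, so list.sort() never has to compare NaN with a finite value.

```python
import math

rows = [
  [0.056, 2, 2048, 0.001, []],
  [float('nan'), 3, 4096, 0.01, []],
  [0.054, 1, 1792, 0.01, []],
  [0.055, 2, 2304, 0.001, []],
]

# key (is_nan, score): finite scores ascending, NaN scores at the end
rows.sort(key=lambda r: (math.isnan(r[0]), r[0]))
print([r[0] for r in rows])  # [0.054, 0.055, 0.056, nan]
```

In the notebook the same key works with `np.isnan`, and dropping the NaN rows before sorting is an equally valid choice if diverged runs should not count at all.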

In [18]:
print("brier_score, hidden layer, number of nodes, learning rate")
for i in brier_scores:
  print(i[:-1])

brier_score, hidden layer, number of nodes, learning rate
[0.05462315934099361, 1, 1792, 0.01]
[0.05491309760429704, 2, 2304, 0.0012328467394420717]
[0.05492758802057797, 2, 3328, 0.0012328467394420717]
[0.05509887882972444, 2, 2304, 0.001]
[0.05538589027789313, 3, 2304, 0.001]
[0.0555082716147738, 3, 1536, 0.0012328467394420717]
[0.05556824920870504, 2, 3328, 0.0001]
[0.055675568575760295, 2, 2560, 0.001]
[0.05580345223310109, 2, 2048, 0.001]
[0.05584126765771997, 3, 2048, 0.001]
[0.05617465616106026, 2, 3328, 0.001]
[0.0562795574486702, 2, 768, 0.0031257158496882514]
[0.05639318371493179, 2, 2048, 0.0001]
[0.05642856512679102, 2, 3072, 0.0001]
[0.05643240392734869, 2, 1792, 0.001]
[0.0565276738091141, 2, 1792, 0.001484968262254472]
[0.05677068445880029, 2, 1536, 0.001484968262254472]
[0.05677124528964219, 2, 3072, 0.001]
[0.05681839928331673, 2, 2560, 0.0001]
[0.05689668759890826, 2, 1792, 0.0001]
[0.05705349977971013, 2, 4096, 0.0001]
[0.05710380251710436, 2, 3584, 0.001]
[0.0571841