# Transfer learning - Survival prediction
This notebook initializes the survival prediction model with the pre-trained weights of the "foundation model" and then fine-tunes it on the survival prediction task. 70% of the data is used for training and the remaining 30% for validation.

It also guides you through the use of the Clinical Transformer API.

In [1]:
import sys
sys.path.append('/root/capsule/environment/clinical_transformer/')

In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # hide all GPUs so TensorFlow runs on the CPU

In [3]:
from xai.models import Trainer
from xai.models import SurvivalTransformer
from xai.models import OptimizedSurvivalDataGenerator as SurvivalDataGenerator
from xai.losses.survival import cIndex_SigmoidApprox as cindex_loss
from xai.metrics.survival import sigmoid_concordance as cindex_metric

2024-05-28 18:34:37.911531: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-05-28 18:34:38.138140: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
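
The loss and metric imported above are both based on the concordance index (C-index). As a rough illustration of the idea, and not the `xai` implementation, the sketch below computes the exact C-index over comparable pairs and a smooth variant that replaces the hard comparison with a sigmoid so it can be optimized by gradient descent. The function names, the `sigma` parameter, and the convention that a higher score means higher risk are assumptions made for this example.

```python
import math

def comparable_pairs(times, events):
    """Return pairs (i, j) where sample i had an observed event before time j."""
    pairs = []
    for i in range(len(times)):
        if events[i] != 1:
            continue  # a censored sample cannot be the earlier member of a pair
        for j in range(len(times)):
            if times[i] < times[j]:
                pairs.append((i, j))
    return pairs

def hard_concordance(times, events, risks):
    """Exact C-index: fraction of comparable pairs ranked correctly (ties count 0.5)."""
    pairs = comparable_pairs(times, events)
    score = 0.0
    for i, j in pairs:
        if risks[i] > risks[j]:
            score += 1.0
        elif risks[i] == risks[j]:
            score += 0.5
    return score / len(pairs)

def soft_concordance(times, events, risks, sigma=0.1):
    """Smooth C-index: the indicator is replaced by sigmoid((r_i - r_j) / sigma)."""
    pairs = comparable_pairs(times, events)
    total = sum(1.0 / (1.0 + math.exp(-(risks[i] - risks[j]) / sigma)) for i, j in pairs)
    return total / len(pairs)

# Toy data (made up for illustration): the third sample is censored.
times = [2, 4, 6, 8]
events = [1, 1, 0, 1]
risks = [0.9, 0.7, 0.8, 0.1]

hard = hard_concordance(times, events, risks)
soft = soft_concordance(times, events, risks)
print(hard, soft)
```

A loss built on this idea would typically minimize `1 - soft_concordance(...)`; whether `cIndex_SigmoidApprox` uses exactly this form is not confirmed by the notebook.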


In [4]:
import pandas as pd
from samecode.random import set_seed

## Dataset

In [5]:
data = pd.read_csv('../data/dataset-train.data.csv')
features = ["f_{}".format(i) for i in range(0, 10)]
features

['f_0', 'f_1', 'f_2', 'f_3', 'f_4', 'f_5', 'f_6', 'f_7', 'f_8', 'f_9']
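
Later cells pass these feature names together with the `time` and `event` target columns to the trainer, so it can help to confirm that every one of them is present in the file before training. The helper below is a hypothetical convenience written for this notebook, not part of the Clinical Transformer API; it works on any iterable of column names, such as `data.columns`.

```python
def missing_columns(columns, features, target=("time", "event")):
    """Return the required column names that are absent from `columns`."""
    present = set(columns)
    required = list(features) + list(target)
    return [name for name in required if name not in present]

# Stand-in for data.columns, used so the example runs without the CSV file.
example_columns = ["f_{}".format(i) for i in range(0, 10)] + ["time"]
example_features = ["f_{}".format(i) for i in range(0, 10)]

missing = missing_columns(example_columns, example_features)
print(missing)  # ['event']
```

In the notebook itself the call would be `missing_columns(data.columns, features)`, and an empty list means the dataset has all required columns.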

## Parameters

In [6]:
max_features_percentile=100
test_size=0.3 # fraction of samples used for validation
repetitions=1 # number of replicate random (training / validation) splits used to evaluate variability

mode='survival'
learning_rate=0.0001
epochs=100
verbose=0
seed=0
embedding_size = 128
num_heads = 2
num_layers = 2
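
To make `test_size` concrete, here is a minimal sketch of one random training/validation split over sample indices. It only illustrates the proportions; how the `Trainer` actually draws its splits is not shown in this notebook, and the `random_split` helper is hypothetical.

```python
import random

def random_split(n_samples, test_size, seed):
    """Shuffle sample indices and hold out a `test_size` fraction for validation."""
    rng = random.Random(seed)
    indices = list(range(n_samples))
    rng.shuffle(indices)
    n_val = round(n_samples * test_size)
    return indices[n_val:], indices[:n_val]

# 700 samples with test_size=0.3 gives 490 training and 210 validation samples.
train_idx, val_idx = random_split(n_samples=700, test_size=0.3, seed=0)
print(len(train_idx), len(val_idx))  # 490 210
```

With `repetitions` greater than 1, one would repeat this with a different seed per replicate and compare the validation metric across splits.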

## Training

In [9]:
!rm -rf ../results/runs/TransferLearningSurvival/

In [10]:
outdir = '../results/runs/TransferLearningSurvival/'
set_seed(seed)

trainer = Trainer(
    from_pretrained='../results/runs/FoundationModel/fold-0_id-0/model.E001000.h5',
    out_dir = outdir,
    max_features_percentile=max_features_percentile,
    test_size=test_size,
    mode=mode,
    model=SurvivalTransformer, 
    dataloader=SurvivalDataGenerator,
    loss=cindex_loss,
    metrics=[cindex_metric]
)

trainer.setup_data(
    data, 
    discrete_features = [],
    continuous_features = features,
    target=['time', 'event']
)

trainer.setup_model(
    learning_rate=learning_rate,
    embedding_size=embedding_size,
    num_heads=num_heads,
    num_layers=num_layers,
    batch_size_max=True,
    save_best_only=False
)

trainer.fit(repetitions=repetitions, epochs=epochs, verbose=verbose, seed=seed)

INFO	2024-05-28 18:37:34,988	Setting up working directory: ../results/runs/TransferLearningSurvival/
INFO	2024-05-28 18:37:34,990	Setting up transfer learning directory: ../results/runs/TransferLearningSurvival//model.E001000.h5/
INFO	2024-05-28 18:37:35,185	Number of continuous features: 10
INFO	2024-05-28 18:37:35,186	Number of discrete features: 0
INFO	2024-05-28 18:37:35,186	Number of samples: 700
INFO	2024-05-28 18:37:35,192	Number of classes: 1
INFO	2024-05-28 18:37:35,195	RUN ID: fold-0_id-0
INFO	2024-05-28 18:37:35,197	RUN ID out directory: ../results/runs/TransferLearningSurvival//model.E001000.h5//fold-0_id-0/
INFO	2024-05-28 18:37:35,262	Training samples: 490
INFO	2024-05-28 18:37:35,263	Testing samples: 210
INFO	2024-05-28 18:37:35,267	Number of features at 100th percentile: 10 that are non nans
2024-05-28 18:37:40.562599: E tensorflow/core/framework/node_def_util.cc:675] NodeDef mentions attribute epsilon which is not in the op definition: Op<name=_MklFusedBatchMatMulV2; s