In [1]:
import torch
import torchtuples as tt # Some useful functions
import numpy as np
import pandas as pd
from sksurv.metrics import concordance_index_censored
from sklearn.model_selection import KFold
from Mink import science_template
import plotly.graph_objects as go
import plotly.io as pio
from pycox.models import LogisticHazard, loss, CoxPH
from pycox.evaluation import EvalSurv
from multi_pipe import load_all
from util import n_equal_slices, t_test_feature_selection, pearson_feature_selection, TruSight170
from clinical_pipe import impute_scale
from ray import tune, train

pio.templates['science'] = science_template
pio.templates.default = 'science'



In [2]:
pip list

Package                   Version
------------------------- --------------
absl-py                   2.1.0
aiosignal                 1.3.1
altgraph                  0.17
anyio                     3.6.1
appdirs                   1.4.4
argon2-cffi               23.1.0
argon2-cffi-bindings      21.2.0
arrow                     1.3.0
astroid                   2.6.6
asttokens                 2.4.1
async-generator           1.10
async-lru                 2.0.4
attrs                     23.2.0
Babel                     2.10.3
backcall                  0.2.0
beautifulsoup4            4.11.1
bleach                    3.3.1
blinker                   1.7.0
Brotli                    1.0.9
brotlicffi                1.1.0.0
cached-property           1.5.2
celluloid                 0.2.0
certifi                   2020.12.5
cffi                      1.17.1
chardet                   4.0.0
charset-normalizer        3.3.2
chebpy                    0.2
click                     7.1.2
colorama             

In [3]:
res_dict = load_all()

2it [00:00, 13.11it/s]

520it [00:37, 13.81it/s]
523it [00:01, 471.32it/s]


clinical_M not in keys
clinical_N not in keys
clinical_T not in keys


In [4]:

def make_net(in_features, out_features, dims, dropout=0.2, seed=54):
    non_zero_dims = []
    # zero layers should be dropped
    for dim in dims:
        if dim != 0:
            non_zero_dims.append(dim)

    dims = non_zero_dims

    torch.manual_seed(seed)
    if len(dims) == 0:
        raise ValueError('Neural network must have at least two layers')

    # connect the in_features to the first dimension in the hidden layers
    modules = [torch.nn.Linear(in_features, dims[0]),
               torch.nn.ReLU(),
               torch.nn.BatchNorm1d(dims[0]),
               torch.nn.Dropout(dropout)]
    
    # connect the hidden layers
    for i,dim in enumerate(dims[1:]):

        modules.append(torch.nn.Linear(dims[i], dim))
        modules.append(torch.nn.ReLU())
        modules.append(torch.nn.BatchNorm1d(dim))
        modules.append(torch.nn.Dropout(dropout))

    # didn't add an activation function here
    # connect the final hidden layer to the out_features
    

    modules = modules + [torch.nn.Linear(dims[-1], out_features)]

    net = torch.nn.Sequential(*modules)
    
    net.apply(init_weights)
    return net

def init_weights(m):
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.xavier_normal_(m.weight)
        m.bias.data.fill_(0.01)

net = make_net(50, 10, [20])

print(net)

Sequential(
  (0): Linear(in_features=50, out_features=20, bias=True)
  (1): ReLU()
  (2): BatchNorm1d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): Dropout(p=0.2, inplace=False)
  (4): Linear(in_features=20, out_features=10, bias=True)
)


In [5]:
def c_index_score(model, data, labels, times):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    X_test = torch.tensor(data, dtype=torch.float32).to(device)

    # Get the survival predictions for the test set
    surv_preds = model.predict_surv_df(X_test)
    
    # Estimate median survival times for each individual
    median_surv_times = surv_preds.apply(lambda col: col[col <= 0.5].index[0] if any(col <= 0.5) else col.index[-1], axis=0)

    # Compute C-index using true times, event indicators, and predicted times
    c_index = concordance_index_censored(labels.astype(bool), times, -median_surv_times)[0]
    return c_index

def combine_transform_data(DataFrames, train_bool, ) -> list:
    '''Take in all the preprocessed data '''

    combined_data = DataFrames[0]

    for df in DataFrames[1:]:
        combined_data = pd.merge(combined_data, df.drop(['time', 'label'], axis=1), on='case_id', how='left')


    data = combined_data.loc[:, ~combined_data.columns.isin(['case_id', 'label', 'time','primary_diagnosis_Squamous cell carcinoma, spindle cell'])].values
    data = np.array(data, dtype=np.float32)
    

    labels = np.array(combined_data['label'].values, dtype=int)
    times = np.array(combined_data['time'].values, dtype=np.float32)

    train_set = data[train_bool]
    train_labels = labels[train_bool]
    train_times = times[train_bool]
    test_set = data[~train_bool]
    test_times = times[~train_bool]
    test_labels = labels[~train_bool]
  
    
    

    return train_set, train_labels, train_times, test_set, test_labels, test_times

def cross_validate(train_set, y_train, dropout=0.3, num_nodes=[50,50], batch_norm=True, batch_size=50, epochs=10, optimizer=None):
    in_features = train_set.shape[1]
    out_features = 1

    
    
    kf = KFold(n_splits=5, shuffle=True, random_state=42)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    scores = []
    

    for fold, (train_idx, val_idx) in enumerate(kf.split(train_set)):
        # print(f"Fold {fold + 1}")
        
        
        

        callbacks = [tt.cb.EarlyStopping(patience=5)]
        # Define the model and move to GPU
        net = make_net(in_features, out_features, num_nodes, dropout=dropout)
        
        model = CoxPH(net, optimizer)
    

        # Move data to GPU
        X_train, X_val = torch.tensor(train_set[train_idx], dtype=torch.float32).to(device), torch.tensor(train_set[val_idx], dtype=torch.float32).to(device)

        # Ensure correct types for durations and events
        y_train_durations, y_train_events = y_train  # Unpack y_train into durations and events

        # Convert durations to float32 and events to int64
        y_train_fold = (
            torch.tensor(y_train_durations[train_idx], dtype=torch.float32).to(device),  # Durations as float32
            torch.tensor(y_train_events[train_idx], dtype=torch.float32).to(device)       # Events as int64
        )

        y_val_fold = (
            torch.tensor(y_train_durations[val_idx], dtype=torch.float32).to(device),    # Durations as float32
            torch.tensor(y_train_events[val_idx], dtype=torch.float32).to(device)         # Events as int64
        )

        # Training
        log = model.fit(X_train, y_train_fold, batch_size, epochs, callbacks, val_data=(X_val, y_val_fold), verbose=False)
        
        scores.append(model.score_in_batches((X_val, y_val_fold))['loss'])
    
    mean_score = sum(scores) / len(scores)

    return mean_score, model




def multi_training(data_dict, seed: int, fusion_type: str, datatypes: list) -> list:
    
    
    miRNA_id, miRNA_data, miRNA_gene_names, miRNA_matched_labels, miRNA_matched_times = data_dict['mirna']
    mRNA_id, mRNA_data, mRNA_gene_names, mRNA_matched_labels, mRNA_matched_times = data_dict['mrna']
    clinical_df = data_dict['clinical']
    
    rs = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(seed)))
    random_order = np.arange(0,len(clinical_df))
    rs.shuffle(random_order)
    slices = n_equal_slices(len(clinical_df), 5)

    c_vals = []

    for bound in slices:

        
        test_set = random_order[bound[0]:bound[1]]
        

        test_bool = np.array([True if i in test_set else False for i in range(len(random_order))], dtype=bool)
        train_bool = ~test_bool
        


        mRNA_df = t_test_feature_selection(train_bool, mRNA_id, mRNA_data, mRNA_gene_names, mRNA_matched_labels, mRNA_matched_times, 20)
        miRNA_df = t_test_feature_selection(train_bool, miRNA_id, miRNA_data, miRNA_gene_names, miRNA_matched_labels, miRNA_matched_times, 100)
        scaled = impute_scale(clinical_df.loc[:, ~clinical_df.columns.isin(['case_id', 'label', 'time'])], train_bool)
        
        # print(clinical_df[['case_id', 'label', 'time']].values.shape)

        labeled = np.concatenate((scaled, clinical_df[['case_id', 'label', 'time']].values), axis=1)
        clinical_df = pd.DataFrame(data=labeled, columns=clinical_df.columns)

        processed_data = {'mrna': mRNA_df, 'mirna': miRNA_df, 'clinical': clinical_df}

        arg = []
        for datatype in datatypes:
            arg.append(processed_data[datatype])

        train_set, train_labels, train_times, test_set, test_labels, test_times = combine_transform_data(arg, train_bool)

       
        
        y_train = (train_times, train_labels)
        y_test = (test_times, test_labels)
        
        optimizer = tt.optim.RMSprop(lr=0.007)
        res, model = cross_validate(train_set, y_train,  dropout=0.5, num_nodes=[50,50], batch_size=32, optimizer=optimizer)
        # I don't understand this but it has to be done first
        _ = model.compute_baseline_hazards()
        c = c_index_score(model, test_set, test_labels, test_times)
        print(c)
        c_vals.append(c)
        

    return c_vals

In [6]:


c_vals = multi_training(res_dict, 42, 'early', ['clinical'])
print(f"{np.mean(c_vals)} +/- {2*np.std(c_vals, ddof=1)}")

  z = ((controls_feature_mean - cases_feature_mean) - mudiff)/pooledSE
  z = ((controls_feature_mean - cases_feature_mean) - mudiff)/pooledSE
  self.net.load_state_dict(torch.load(path, **kwargs))


0.538961038961039


  z = ((controls_feature_mean - cases_feature_mean) - mudiff)/pooledSE
  z = ((controls_feature_mean - cases_feature_mean) - mudiff)/pooledSE
  self.net.load_state_dict(torch.load(path, **kwargs))


0.6089546502690238


  z = ((controls_feature_mean - cases_feature_mean) - mudiff)/pooledSE
  z = ((controls_feature_mean - cases_feature_mean) - mudiff)/pooledSE
  self.net.load_state_dict(torch.load(path, **kwargs))


0.5254208754208755


  z = ((controls_feature_mean - cases_feature_mean) - mudiff)/pooledSE
  z = ((controls_feature_mean - cases_feature_mean) - mudiff)/pooledSE
  self.net.load_state_dict(torch.load(path, **kwargs))


0.5426425099425541


  z = ((controls_feature_mean - cases_feature_mean) - mudiff)/pooledSE
  z = ((controls_feature_mean - cases_feature_mean) - mudiff)/pooledSE
  self.net.load_state_dict(torch.load(path, **kwargs))


0.5815824765310684
0.5595123102249122 +/- 0.06928786379262097


In [7]:
def retrain(train_set, y_train, test_set, y_test, optimal_dict):
    X_train= torch.tensor(train_set, dtype=torch.float32)
    # Ensure correct types for durations and events
    y_train_durations, y_train_events = y_train  # Unpack y_train into durations and events

    # Convert durations to float32 and events to int64
    y_train_fold = (
        torch.tensor(y_train_durations, dtype=torch.float32),  
        torch.tensor(y_train_events, dtype=torch.float32)      
    )

    callbacks = [tt.cb.EarlyStopping(patience=5, dataset='train')]

    optimizer = tt.optim.RMSprop(lr=optimal_dict["lr"], momentum=optimal_dict["momentum"])
    # print(X_train)
    net = make_net(X_train.shape[1], 1, [int(optimal_dict['dim1']), int(optimal_dict['dim2'])], dropout=optimal_dict['dropout'])

    model = CoxPH(net, optimizer)
    
    log = model.fit(X_train, y_train_fold, int(optimal_dict['batch']), 100, callbacks, verbose=False)

    risk = model.predict(test_set).squeeze()

    return concordance_index_censored(y_test[1]==1, y_test[0], risk)


def outer_validation(data_dict, seed: int, fusion_type: str, datatypes: list) -> list:
    
    def train_omics(config):

        if config['optimizer'] == 0:
            optimizer = tt.optim.AdamW(lr=config["lr"], betas=(config["B1"], config["B2"]))
        elif config['optimizer'] == 1:
            optimizer = tt.optim.RMSprop(lr=config["lr"], momentum=config["momentum"])
            
        train_set = torch.load("/home/elliotw/Elliot_NN_optimizing/data_tmp/train.pt")
        y_train = torch.load("/home/elliotw/Elliot_NN_optimizing/data_tmp/val.pt")
        
        
        
        dims = [config['dim1'], config['dim2']]
        

        res, _ = cross_validate(train_set, y_train, 
                            dropout=config['dropout'], 
                            num_nodes=dims, 
                            batch_size=config['batch'], 
                            optimizer=optimizer, 
                            )

        train.report({"score": res})
    
    miRNA_id, miRNA_data, miRNA_gene_names, miRNA_matched_labels, miRNA_matched_times = data_dict['mirna']
    mRNA_id, mRNA_data, mRNA_gene_names, mRNA_matched_labels, mRNA_matched_times = data_dict['mrna']
    clinical_df = data_dict['clinical']
    
    rs = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(seed)))
    random_order = np.arange(0,len(clinical_df))
    rs.shuffle(random_order)
    slices = n_equal_slices(len(clinical_df), 5)

    c_vals = []

    for bound in slices:

        # this is inside the for loop because it is overwritten at the end to avoid missing arguments not tuned in raytuner
        grid_space = {
            "lr": tune.grid_search([0.001, 0.005, 0.0075, 0.01, 0.03, 0.05]),
            "B1": 0.9,
            "B2": 0.99,
            "dim1": tune.grid_search([50,100]),
            "dim2": tune.grid_search([50,100]),
            "dropout": tune.grid_search([0.2, 0.5]),
            "batch": 32,
            "momentum": 0,
            "optimizer": 1
        }

        
        test_set = random_order[bound[0]:bound[1]]
        test_bool = np.array([True if i in test_set else False for i in range(len(random_order))], dtype=bool)
        train_bool = ~test_bool
        
        mRNA_df = TruSight170(train_bool, mRNA_id, mRNA_data, mRNA_gene_names, mRNA_matched_labels, mRNA_matched_times)
        miRNA_df = pearson_feature_selection(train_bool, miRNA_id, miRNA_data, miRNA_gene_names, miRNA_matched_labels, miRNA_matched_times, 100)
        scaled = impute_scale(clinical_df.loc[:, ~clinical_df.columns.isin(['case_id', 'label', 'time'])], train_bool)

        labeled = np.concatenate((scaled, clinical_df[['case_id', 'label', 'time']].values), axis=1)
        clinical_df = pd.DataFrame(data=labeled, columns=clinical_df.columns)
        processed_data = {'mrna': mRNA_df, 'mirna': miRNA_df, 'clinical': clinical_df}

        arg = []
        for datatype in datatypes:
            arg.append(processed_data[datatype])

        train_set, train_labels, train_times, test_set, test_labels, test_times = combine_transform_data(arg, train_bool)
        y_train = (train_times, train_labels)
        y_test = (test_times, test_labels)

        torch.save(train_set, "./data_tmp/train.pt")
        torch.save(y_train,"./data_tmp/val.pt")
        train_omics = tune.with_resources(train_omics, {"gpu": 0.25})

        tuner = tune.Tuner(
            train_omics,
            param_space=grid_space,
            tune_config=tune.TuneConfig(
                # num_samples=100,
                metric="score",
                mode="min"
            )
        )

        results = tuner.fit()
        best_result = min(results, key= lambda x: x.metrics['score'])
        print(best_result)
        parameters = best_result.path.split('_')[-3].split(',')

        for param in parameters:
            key, value = param.split('=')
            value = float(value)
            grid_space[key] = value

        c = retrain(train_set, y_train, test_set, y_test, grid_space)[0]
        
        c_vals.append(c)

    return c_vals



In [8]:
miRNA_id, miRNA_data, miRNA_gene_names, miRNA_matched_labels, miRNA_matched_times = res_dict['mirna']
mRNA_id, mRNA_data, mRNA_gene_names, mRNA_matched_labels, mRNA_matched_times = res_dict['mrna']
clinical_df = res_dict['clinical']

rs = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(4)))
random_order = np.arange(0,len(clinical_df))
rs.shuffle(random_order)
slices = n_equal_slices(len(clinical_df), 5)

test_set = random_order[slices[0][0]:slices[0][1]]
test_bool = np.array([True if i in test_set else False for i in range(len(random_order))], dtype=bool)
train_bool = ~test_bool

mRNA_df = t_test_feature_selection(train_bool, mRNA_id, mRNA_data, mRNA_gene_names, mRNA_matched_labels, mRNA_matched_times, 20)
miRNA_df = t_test_feature_selection(train_bool, miRNA_id, miRNA_data, miRNA_gene_names, miRNA_matched_labels, miRNA_matched_times, 100)
scaled = impute_scale(clinical_df.loc[:, ~clinical_df.columns.isin(['case_id', 'label', 'time'])], train_bool)

labeled = np.concatenate((scaled, clinical_df[['case_id', 'label', 'time']].values), axis=1)
clinical_df = pd.DataFrame(data=labeled, columns=clinical_df.columns)

processed_data = {'mrna': mRNA_df, 'mirna': miRNA_df, 'clinical': clinical_df}

arg = []
for datatype in ['mrna']:
    arg.append(processed_data[datatype])


train_set, train_labels, train_times, test_set, test_labels, test_times = combine_transform_data(arg, train_bool)
print(train_set.shape)
num_durations = 100

y_train = (train_times, train_labels)
y_test = (test_times, test_labels)
    
torch.save(train_set, "./data_tmp/train.pt")
torch.save(y_train,"./data_tmp/val.pt")


def train_omics(config):

    if config['optimizer'] == 0:
        optimizer = tt.optim.AdamW(lr=config["lr"], betas=(config["B1"], config["B2"]))
    elif config['optimizer'] == 1:
        optimizer = tt.optim.RMSprop(lr=config["lr"], momentum=config["momentum"])
        
    train_set = torch.load("/home/elliotw/Elliot_NN_optimizing/data_tmp/train.pt")
    y_train = torch.load("/home/elliotw/Elliot_NN_optimizing/data_tmp/val.pt")
    
    
    
    dims = [config['dim1'], config['dim2']]
    

    res, _ = cross_validate(train_set, y_train, 
                        dropout=config['dropout'], 
                        num_nodes=dims, 
                        batch_size=config['batch'], 
                        optimizer=optimizer, 
                        )

    train.report({"score": res})



  z = ((controls_feature_mean - cases_feature_mean) - mudiff)/pooledSE


(408, 20)


  z = ((controls_feature_mean - cases_feature_mean) - mudiff)/pooledSE


In [15]:
res = outer_validation(res_dict, 6, '', ['mirna','mrna','clinical'])

0,1
Current time:,2024-12-16 10:13:42
Running for:,00:00:47.41
Memory:,13.4/62.5 GiB

Trial name,status,loc,dim1,dim2,dropout,lr,iter,total time (s),score
train_omics_3a124_00000,TERMINATED,172.21.21.61:3389074,50,50,0.2,0.001,1,1.55427,3.93063
train_omics_3a124_00001,TERMINATED,172.21.21.61:3389075,100,50,0.2,0.001,1,1.49816,4.06255
train_omics_3a124_00002,TERMINATED,172.21.21.61:3389076,50,100,0.2,0.001,1,1.49251,3.98954
train_omics_3a124_00003,TERMINATED,172.21.21.61:3389077,100,100,0.2,0.001,1,1.52986,3.98892
train_omics_3a124_00004,TERMINATED,172.21.21.61:3389378,50,50,0.5,0.001,1,1.54301,3.77677
train_omics_3a124_00005,TERMINATED,172.21.21.61:3389379,100,50,0.5,0.001,1,1.60378,3.96691
train_omics_3a124_00006,TERMINATED,172.21.21.61:3389380,50,100,0.5,0.001,1,1.51834,3.93053
train_omics_3a124_00007,TERMINATED,172.21.21.61:3389381,100,100,0.5,0.001,1,1.52216,3.87021
train_omics_3a124_00008,TERMINATED,172.21.21.61:3389678,50,50,0.2,0.005,1,1.5631,3.9148
train_omics_3a124_00009,TERMINATED,172.21.21.61:3389679,100,50,0.2,0.005,1,1.58349,3.94003


2024-12-16 10:13:42,831	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/elliotw/ray_results/train_omics_2024-12-16_10-12-55' in 0.0059s.
2024-12-16 10:13:42,835	INFO tune.py:1041 -- Total run time: 47.42 seconds (47.41 seconds for the tuning loop).


Result(
  metrics={'score': 3.75340576171875},
  path='/home/elliotw/ray_results/train_omics_2024-12-16_10-12-55/train_omics_3a124_00047_47_dim1=100,dim2=100,dropout=0.5000,lr=0.0500_2024-12-16_10-12-55',
  filesystem='local',
  checkpoint=None
)


  self.net.load_state_dict(torch.load(path, **kwargs))


In [16]:
print(res)

print(f"{np.mean(res)} +/- {2*np.std(res)/np.sqrt(5)}")

[0.5988601490574309, 0.5208845208845209, 0.536328125, 0.6875525651808242, 0.659264399722415]
0.6005779519690382 +/- 0.058638849325227095


In [49]:
grid_space = {
    "lr": tune.grid_search([0.001, 0.005, 0.01, 0.05]),
    "B1": 0.9,
    "B2": 0.99,
    "dim1": tune.grid_search([50,100]),
    "dim2": tune.grid_search([50,100]),
    "dropout": tune.grid_search([0.2, 0.5]),
    "batch": 64,
    "momentum": tune.grid_search([0,0.2]),
    "optimizer": 1
}
# grid_search

train_omics = tune.with_resources(train_omics, {"gpu": 0.25})

tuner = tune.Tuner(
    train_omics,
    param_space=grid_space,
    tune_config=tune.TuneConfig(
        # num_samples=100,
        metric="score",
        mode="min"
    )
)

results = tuner.fit()

0,1
Current time:,2024-11-18 12:57:57
Running for:,00:00:48.40
Memory:,10.4/62.5 GiB

Trial name,status,loc,dim1,dim2,dropout,lr,momentum,iter,total time (s),score
train_omics_87f7a_00000,TERMINATED,172.21.21.61:2162159,50,50,0.2,0.001,0.0,1,1.14727,3.75474
train_omics_87f7a_00001,TERMINATED,172.21.21.61:2162160,100,50,0.2,0.001,0.0,1,1.05859,3.82707
train_omics_87f7a_00002,TERMINATED,172.21.21.61:2162161,50,100,0.2,0.001,0.0,1,1.14967,3.77997
train_omics_87f7a_00003,TERMINATED,172.21.21.61:2162162,100,100,0.2,0.001,0.0,1,1.10195,3.77514
train_omics_87f7a_00004,TERMINATED,172.21.21.61:2162460,50,50,0.5,0.001,0.0,1,1.06738,3.81092
train_omics_87f7a_00005,TERMINATED,172.21.21.61:2162461,100,50,0.5,0.001,0.0,1,1.03584,3.84528
train_omics_87f7a_00006,TERMINATED,172.21.21.61:2162462,50,100,0.5,0.001,0.0,1,1.16596,3.76416
train_omics_87f7a_00007,TERMINATED,172.21.21.61:2162463,100,100,0.5,0.001,0.0,1,1.11425,3.7805
train_omics_87f7a_00008,TERMINATED,172.21.21.61:2162760,50,50,0.2,0.005,0.0,1,1.12594,3.71318
train_omics_87f7a_00009,TERMINATED,172.21.21.61:2162761,100,50,0.2,0.005,0.0,1,1.05461,3.84049


2024-11-18 12:57:57,858	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/elliotw/ray_results/train_omics_2024-11-18_12-57-09' in 0.0080s.
2024-11-18 12:57:57,865	INFO tune.py:1041 -- Total run time: 48.42 seconds (48.40 seconds for the tuning loop).


In [50]:
best_result = min(results, key= lambda x: x.metrics['score'])
print(best_result)
parameters = best_result.path.split('_')[-3].split(',')

for param in parameters:
    key, value = param.split('=')
    value = float(value)
    grid_space[key] = value

print(grid_space)

Result(
  metrics={'score': 3.713183784484863},
  path='/home/elliotw/ray_results/train_omics_2024-11-18_12-57-09/train_omics_87f7a_00008_8_dim1=50,dim2=50,dropout=0.2000,lr=0.0050,momentum=0_2024-11-18_12-57-09',
  filesystem='local',
  checkpoint=None
)
{'lr': 0.005, 'B1': 0.9, 'B2': 0.99, 'dim1': 50.0, 'dim2': 50.0, 'dropout': 0.2, 'batch': 64, 'momentum': 0.0, 'optimizer': 1}


In [51]:
X_train= torch.tensor(train_set, dtype=torch.float32)

# Ensure correct types for durations and events
y_train_durations, y_train_events = y_train  # Unpack y_train into durations and events

# Convert durations to float32 and events to int64
y_train_fold = (
    torch.tensor(y_train_durations, dtype=torch.float32),  # Durations as float32
    torch.tensor(y_train_events, dtype=torch.float32)      # Events as int64
)

callbacks = [tt.cb.EarlyStopping(patience=5, dataset='train')]

torch.manual_seed(3)
# Define the model and move to GPU
optimizer = tt.optim.RMSprop(lr=grid_space["lr"], momentum=grid_space["momentum"])
# print(X_train)
net = make_net(X_train.shape[1], 1, [int(grid_space['dim1']), int(grid_space['dim2'])], dropout=grid_space['dropout'])

model = CoxPH(net, optimizer)





# Training
log = model.fit(X_train, y_train_fold, int(grid_space['batch']), 100, callbacks, verbose=False)
print(log.to_pandas())

risk = model.predict(test_set).squeeze()

print(concordance_index_censored(y_test[1]==1, y_test[0], risk))
# ev = EvalSurv(surv, y_test[0], y_test[1], censor_surv='km')

# print(ev.concordance_td())


    train_loss  val_loss
0     4.759322       NaN
1     3.571152       NaN
2     3.451598       NaN
3     3.349008       NaN
4     3.236879       NaN
5     3.190073       NaN
6     3.257966       NaN
7     3.070701       NaN
8     3.095080       NaN
9     3.092274       NaN
10    3.090242       NaN
11    3.034384       NaN
12    3.027818       NaN
13    3.008570       NaN
14    2.971441       NaN
15    3.033185       NaN
16    2.940410       NaN
17    2.823087       NaN
18    2.914625       NaN
19    2.859583       NaN
20    2.851130       NaN
21    2.788332       NaN
22    2.893174       NaN
23    2.840199       NaN
24    2.740641       NaN
25    2.748704       NaN
26    2.801127       NaN
27    2.740373       NaN
28    2.738158       NaN
29    2.723378       NaN
30    2.697249       NaN
31    2.678042       NaN
32    2.617781       NaN
33    2.581289       NaN
34    2.603884       NaN
35    2.597042       NaN
36    2.682450       NaN
37    2.569641       NaN
38    2.543880       NaN


  self.net.load_state_dict(torch.load(path, **kwargs))
