In [13]:
import math
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable
import torch.optim as optim
import copy
from scipy.interpolate import interp1d
dtype = torch.FloatTensor

In [14]:
def sort_data(path):
    data = pd.read_csv(path)
    data.sort_values("OS.time",ascending = False, inplace = True)
    x = data.drop(["Patient_ID", "race_black", "race_white", "age", "stageh","gradeh", "OS", "OS.time"], axis = 1).values
    ytime = data.loc[:, ["OS.time"]].values
    yevent = data.loc[:, ["OS"]].values
    age = data.loc[:, ["age"]].values
    cstage = data.loc[:, ["stageh"]].values
    hgrade = data.loc[:, ["gradeh"]].values
    race_black = data.loc[:, ["race_black"]].values
    race_white = data.loc[:, ["race_white"]].values
    return(x, ytime, yevent, age, cstage, hgrade, race_black, race_white)

def load_data(path, dtype):
    x, ytime, yevent, age, cstage, hgrade, race_black, race_white = sort_data(path)
    X = torch.from_numpy(x).type(dtype)
    YTIME = torch.from_numpy(ytime).type(dtype)
    YEVENT = torch.from_numpy(yevent).type(dtype)
    AGE = torch.from_numpy(age).type(dtype)
    CSTAGE = torch.from_numpy(cstage).type(dtype)
    HGRADE = torch.from_numpy(hgrade).type(dtype)
    RACE_BLACK = torch.from_numpy(race_black).type(dtype)
    RACE_WHITE = torch.from_numpy(race_white).type(dtype)
    if torch.cuda.is_available():
        X = X.cuda()
        YTIME = YTIME.cuda()
        YEVENT = YEVENT.cuda()
        AGE = AGE.cuda()
        CSTAGE = CSTAGE.cuda()
        HGRADE = HGRADE.cuda()
        RACE_BLACK = RACE_BLACK.cuda()
        RACE_WHITE = RACE_WHITE.cuda()
    return(X, YTIME, YEVENT, AGE, CSTAGE, HGRADE, RACE_BLACK, RACE_WHITE)

In [15]:
class EarlyStopping:
    def __init__(self, patience, verbose=False, delta=0):
        
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.counter % 20 == 0:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), 'checkpoint.pt')
        self.val_loss_min = val_loss

In [16]:
def reconstruction_loss(x, x_recon):
    batch_size = x.size(0)
    assert batch_size != 0
    
    recon_loss = F.mse_loss(x_recon, x, size_average=False).div(batch_size)

    return recon_loss

def kl_divergence(mu, logvar):
    batch_size = mu.size(0)
    assert batch_size != 0
    
    klds = -0.5*(1 + logvar - mu.pow(2) - logvar.exp())
    total_kld = klds.sum(1).mean(0, True)
    dimension_wise_kld = klds.mean(0)
    mean_kld = klds.mean(1).mean(0, True)

    return total_kld, dimension_wise_kld, mean_kld

In [17]:
def R_set(x):
    n_sample = x.size(0)
    matrix_ones = torch.ones(n_sample, n_sample)
    indicator_matrix = torch.tril(matrix_ones)
    return(indicator_matrix)

def merged_loss_function(pred, ytime, yevent):
    n_observed = yevent.sum(0)
    ytime_indicator = R_set(ytime)
    if torch.cuda.is_available():
        ytime_indicator = ytime_indicator.cuda()
    risk_set_sum = ytime_indicator.mm(torch.exp(pred)) 
    diff = pred - torch.log(risk_set_sum)
    sum_diff_in_observed = torch.transpose(diff, 0, 1).mm(yevent)
    cost = (- (sum_diff_in_observed / n_observed)).reshape((-1,))
    return(cost)

def c_index(pred, ytime, yevent):
    n_sample = len(ytime)
    ytime_indicator = R_set(ytime)
    ytime_matrix = ytime_indicator - torch.diag(torch.diag(ytime_indicator))
    censor_idx = (yevent == 0).nonzero()
    zeros = torch.zeros(n_sample)
    ytime_matrix[censor_idx, :] = zeros
    pred_matrix = torch.zeros_like(ytime_matrix)
    for j in range(n_sample):
        for i in range(n_sample):
            if pred[i] < pred[j]:
                pred_matrix[j, i]  = 1
            elif pred[i] == pred[j]: 
                pred_matrix[j, i] = 0.5
    concord_matrix = pred_matrix.mul(ytime_matrix)
    concord = torch.sum(concord_matrix)
    epsilon = torch.sum(ytime_matrix)
    concordance_index = torch.div(concord, epsilon)
    if torch.cuda.is_available():
        concordance_index = concordance_index.cuda()
    return(concordance_index)

In [18]:
def reparametrize(mu, logvar):
    std = logvar.div(2).exp()
    eps = Variable(std.data.new(std.size()).normal_())
    return mu + std*eps

In [19]:
def kaiming_init(m):
    if isinstance(m, nn.Linear):
        init.kaiming_normal_(m.weight)
        if m.bias is not None:
            m.bias.data.fill_(0)

In [34]:
class BetaVAE_H_sup(nn.Module):
    """Model proposed in original beta-VAE paper(Higgins et al, ICLR, 2017). Modifications made to best accommodate our data"""

    def __init__(self, z_dim, input_n):
        super(BetaVAE_H_sup, self).__init__()
        self.z_dim = z_dim
        self.nc = input_n
        self.encoder = nn.Sequential(
            nn.Linear(input_n, 200),          
            nn.ReLU(True),
            nn.Linear(200, 50),         
            nn.ReLU(True),
            nn.Linear(50, z_dim*2)            
        )
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, 50),                             
            nn.ReLU(True),
            nn.Linear(50, 200),      
            nn.ReLU(True),
            nn.Linear(200, input_n)
        )
        self.clilayer = nn.Linear(z_dim + 5, 1, bias = False)
        
        self.weight_init()

    def weight_init(self):
        for block in self._modules:
            if block == "clilayer":
                init.uniform_(self._modules[block].weight, -0.001, 0.001)
            else:
                for m in self._modules[block]:
                    kaiming_init(m)

    def forward(self, x, x_2, x_3, x_4, x_5, x_6):
        distributions = self._encode(x)
        mu = distributions[:, :self.z_dim]
        logvar = distributions[:, self.z_dim:]
        z = reparametrize(mu, logvar)
        x_recon = self._decode(z)
        x_cat = torch.cat((z, x_2, x_3, x_4, x_5, x_6), 1)
        lin_pred = self.clilayer(x_cat)

        return x_recon, mu, logvar, lin_pred

    def _encode(self, x):
        return self.encoder(x)

    def _decode(self, z):
        return self.decoder(z)

In [35]:
def trainBetaVAE_H_sup(train_x, train_age, train_cstage, train_hgrade, train_race_black, train_race_white, train_ytime, train_yevent,
                       eval_x, eval_age, eval_cstage, eval_hgrade, eval_race_black, eval_race_white, eval_ytime, eval_yevent,
                       z_dim, input_n, Learning_Rate, L2, Num_Epochs, patience, beta):
    net = BetaVAE_H_sup(z_dim, input_n)
    
    early_stopping = EarlyStopping(patience = patience, verbose = False)
    
    if torch.cuda.is_available():
        net.cuda()
    opt = optim.Adam(net.parameters(), lr=Learning_Rate, weight_decay = L2)
    for epoch in range(Num_Epochs+1):
        net.train()
        opt.zero_grad()
        
        x_recon, mu, logvar, lin_pred = net(train_x, train_age, train_cstage, train_hgrade, train_race_black, train_race_white)
        recon_loss = reconstruction_loss(train_x, x_recon)
        total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
        cox_cost = merged_loss_function(lin_pred, train_ytime, train_yevent)
        beta_vae_loss = recon_loss + beta*total_kld + cox_cost
        
        beta_vae_loss.backward()
        opt.step()
        
        net.eval()
        val_x_recon, val_mu, val_logvar, val_pred = net(eval_x, eval_age, eval_cstage, eval_hgrade, eval_race_black, eval_race_white)
        val_recon_loss = reconstruction_loss(eval_x, val_x_recon)
        val_total_kld, val_dim_wise_kld, val_mean_kld = kl_divergence(val_mu, val_logvar)
        val_cox_cost = merged_loss_function(val_pred, eval_ytime, eval_yevent)
        val_loss = val_recon_loss + beta*val_total_kld + val_cox_cost
        
        early_stopping(val_loss, net)
        if early_stopping.early_stop:
            net.train()
            tr_x_recon, tr_mu, tr_logvar, tr_pred = net(train_x, train_age, train_cstage, train_hgrade, train_race_black, train_race_white)
            tr_recon_loss = reconstruction_loss(train_x, tr_x_recon)
            tr_total_kld, tr_dim_wise_kld, tr_mean_kld = kl_divergence(tr_mu, tr_logvar)
            tr_cox_cost = merged_loss_function(tr_pred, train_ytime, train_yevent)
            tr_loss = tr_recon_loss + beta*tr_total_kld + tr_cox_cost
            
            tr_cindex = c_index(tr_pred, train_ytime, train_yevent)
            val_cindex = c_index(val_pred, eval_ytime, eval_yevent)
            print("Early stopping, Number of epochs: ", epoch, ", Loss in Validation: ", val_loss, ", Loss in Training: ", tr_loss)
            break
        if epoch % 200 == 0:
            net.train()
            tr_x_recon, tr_mu, tr_logvar, tr_pred = net(train_x, train_age, train_cstage, train_hgrade, train_race_black, train_race_white)
            tr_recon_loss = reconstruction_loss(train_x, tr_x_recon)
            tr_total_kld, tr_dim_wise_kld, tr_mean_kld = kl_divergence(tr_mu, tr_logvar)
            tr_cox_cost = merged_loss_function(tr_pred, train_ytime, train_yevent)
            tr_loss = tr_recon_loss + beta*tr_total_kld + tr_cox_cost
            
            tr_cindex = c_index(tr_pred, train_ytime, train_yevent)
            val_cindex = c_index(val_pred, eval_ytime, eval_yevent)
            print("Loss in Train: ", tr_loss)
    return (tr_loss, val_loss, tr_mu, tr_logvar, val_mu, val_logvar, tr_cindex, val_cindex)

In [46]:
z_dim = 10
input_n = 929
Initial_Learning_Rate = [0.03, 0.01, 0.001, 0.00075]
L2_Lambda = [0.1, 0.01, 0.005, 0.001]
patience = 100
beta = 1000
num_epochs = 600
Num_EPOCHS = 2000
x_train, ytime_train, yevent_train, age_train, cstage_train, hgrade_train, race_black_train, race_white_train = load_data("D:/DL/Variational autoencoder/Tryout_12_22_2020/mir_train_normalized.csv", dtype)
x_valid, ytime_valid, yevent_valid, age_valid, cstage_valid, hgrade_valid, race_black_valid, race_white_valid = load_data("D:/DL/Variational autoencoder/Tryout_12_22_2020/mir_validation_normalized.csv", dtype)
x_test, ytime_test, yevent_test, age_test, cstage_test, hgrade_test, race_black_test, race_white_test = load_data("D:/DL/Variational autoencoder/Tryout_12_22_2020/mir_test_normalized.csv", dtype)
opt_l2_loss = 0
opt_lr_loss = 0
opt_loss = torch.Tensor([float("Inf")])
if torch.cuda.is_available():
    opt_loss = opt_loss.cuda()
for l2 in L2_Lambda:
    for lr in Initial_Learning_Rate:
        loss_train, loss_valid, tr_mu, tr_logvar, val_mu, val_logvar, tr_cindex, val_cindex = trainBetaVAE_H_sup(x_train, age_train, cstage_train, hgrade_train, race_black_train, race_white_train, ytime_train, yevent_train,
                                                                                                             x_valid, age_valid, cstage_valid, hgrade_valid, race_black_valid, race_white_valid, ytime_valid, yevent_valid,
                                                                                                             z_dim, input_n, lr, l2, num_epochs, patience, beta)
        if loss_valid < opt_loss:
            opt_l2_loss = l2
            opt_lr_loss = lr
            opt_loss = loss_valid
        print ("L2: ", l2, ", LR: ", lr, ", Loss in Validation: ", loss_valid)
loss_train, loss_test, tr_mu, tr_logvar, tes_mu, tes_logvar, tr_cindex, tes_cindex = trainBetaVAE_H_sup(x_train, age_train, cstage_train, hgrade_train, race_black_train, race_white_train, ytime_train, yevent_train,
                                                                                                    x_test, age_test, cstage_test, hgrade_test, race_black_test, race_white_test, ytime_test, yevent_test,
                                                                                                    z_dim, input_n, opt_lr_loss, opt_l2_loss, Num_EPOCHS, patience, beta)
print ("Optimal L2: ", opt_l2_loss, ", Optimal LR: ", opt_lr_loss)
print ("Training C-index: ", tr_cindex, ", testing C-index: ", tes_cindex)

Loss in Train:  tensor([18426298.], grad_fn=<AddBackward0>)
Loss in Train:  tensor([nan], grad_fn=<AddBackward0>)
Loss in Train:  tensor([nan], grad_fn=<AddBackward0>)
Loss in Train:  tensor([nan], grad_fn=<AddBackward0>)
L2:  0.1 , LR:  0.03 , Loss in Validation:  tensor([nan], grad_fn=<AddBackward0>)
Loss in Train:  tensor([26487.6426], grad_fn=<AddBackward0>)
EarlyStopping counter: 20 out of 100
EarlyStopping counter: 20 out of 100
Loss in Train:  tensor([931.1347], grad_fn=<AddBackward0>)
EarlyStopping counter: 20 out of 100
EarlyStopping counter: 20 out of 100
EarlyStopping counter: 40 out of 100
EarlyStopping counter: 60 out of 100
EarlyStopping counter: 80 out of 100
EarlyStopping counter: 20 out of 100
EarlyStopping counter: 40 out of 100
Loss in Train:  tensor([930.4819], grad_fn=<AddBackward0>)
EarlyStopping counter: 60 out of 100
EarlyStopping counter: 80 out of 100
EarlyStopping counter: 100 out of 100
Early stopping, Number of epochs:  454 , Loss in Validation:  tensor([88