In this notebook, we demonstrate how to apply our method to the SS3 examples from our paper.

## Example

In this example, a high-dimensional dataset with 300 covariates and 600 observations is generated using the following functions:

$$
m_L(\boldsymbol{z} * \boldsymbol{w}_L) = 0.8  \sum_{j=1}^{10} \sin(\boldsymbol{z}_{j}) + 0.8 (\boldsymbol{z}_1 \boldsymbol{z}_2 + \boldsymbol{z}_9 \boldsymbol{z}_{10})
$$
and 
$$
m_C(\boldsymbol{z} * \boldsymbol{w}_C) = 0.8 \sum_{j=1}^{10} \sin(\boldsymbol{z}_{j}) + 0.8 (\boldsymbol{z}_1 \boldsymbol{z}_2 + \boldsymbol{z}_9 \boldsymbol{z}_{10}).
$$

That is, among the 300 covariates, only the first 10 variables in the 'L' and 'C' parts actually contribute to the response. Our task is to correctly identify the important variables.


In [1]:

import numpy as np
from numpy.random import normal, rand
import math
import torch
from torch import nn
import itertools
from itertools import product
from modelsn import Net_nonlinear
import torch.optim as optim
from MFS import FS_epoch, total_loss, training_n
from tqdm import tqdm
from numpy.random import gamma
from dt_g import generate_data,generate_Z
torch.cuda.is_available()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import csv
from main import metric
from Cindex_AUC import cindex_AUC


## Data Preparation
- $Z$ is the $n\times p$ covariate matrix;
- $T$ represents the observation time;
- `delta` indicates the censoring status;
- `tau` is the set of event times in sorted order;
- `Rj` denotes the set of samples at risk;
- `beta` represents the probability of cure;
- `alpha` is the total censoring rate minus the cure probability;
- `ll` $= \lambda_3 \times p$ controls the similarity between the variable selection results of the L and C components.


For detailed data generation procedures, please refer to the paper.

## Fixed hyper-parameters

- `s1` $= s_L$ and `s2` $= s_C$: the numbers of variables to be selected in part L and part C, respectively;
- `epochs`: the number of iterations to run;
- `n_hidden1` & `n_hidden2`: the numbers of neurons in the fully connected network;
- `learning_rate`: the learning rate for the optimizer;
- `Ts` & `step`: the parameters that control the optimization on a given support.


In [2]:
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
save_file_name = 'result.csv'
seed = 1234
alpha = 0.05
beta = 0.3
f = 'S4'
n = 600
p = 300
s1 = 11
s2 = 9
learning_rate = 0.0005
n_hidden1 = 50
n_hidden2 = 10
epochs = 5  # deliberately small so the notebook finishes quickly
Ts = 25
step = 5
ll = 20
c = 1

# ---------------------------------------------------------------------------
# Simulated training data
# ---------------------------------------------------------------------------
Z = generate_Z(seed, n, p)
T, delta, tau, d, Rj, idx, y_cure = generate_data(
    device, seed, f, Z, n, p, alpha, beta
)

# ---------------------------------------------------------------------------
# Networks: `model` is trained, `best_model` receives the best snapshot
# ---------------------------------------------------------------------------
model = Net_nonlinear(
    n_feature=p,
    n_hidden1=n_hidden1,
    n_hidden2=n_hidden2,
    n_output=1,
).to(device=device)
best_model = Net_nonlinear(
    n_feature=p,
    n_hidden1=n_hidden1,
    n_hidden2=n_hidden2,
    n_output=1,
).to(device=device)

# Optimizer over every parameter of `model`.
optimizer = torch.optim.Adam(
    model.parameters(), lr=learning_rate, weight_decay=0.0025
)
# Separate optimizers for the two first-layer modules; they are used for the
# optimization on a given support.
optimizer0_1 = torch.optim.Adam(
    model.hidden0_1.parameters(), lr=learning_rate, weight_decay=0.0005
)
optimizer0_2 = torch.optim.Adam(
    model.hidden0_2.parameters(), lr=learning_rate, weight_decay=0.0005
)

# ---------------------------------------------------------------------------
# Bookkeeping: loss history and support history for each part
# ---------------------------------------------------------------------------
hist = []
supp_x1 = list(range(p))  # initial support of part L: all covariates
supp_x2 = list(range(p))  # initial support of part C: all covariates
supp_x = [supp_x1, supp_x2]
SUPP1 = [supp_x1]
SUPP2 = [supp_x2]

data = [Z, T, delta, tau, d, Rj, idx]
n, p = Z.shape

# Per-sample weights: uniform random draws, forced to 1 wherever delta == 1.
eta = torch.rand(n).to(device=device)
eta[delta == 1] = 1
k = len(tau)


In [None]:
### Algorithm
# Run up to `epochs` feature-selection epochs, snapshot the model/support with
# the lowest total loss, then evaluate that snapshot on a fresh test set.

# Track the best loss explicitly: a NaN loss can then never become (or block)
# the best value, which `hist[-1] == min(hist)` could not guarantee.
best_loss = float('inf')
# Always defined, even if no epoch produces a comparable loss.
best_supp = [list(supp_x[0]), list(supp_x[1])]

for i in range(epochs):
    print('epoch:',i)
    # One DFS epoch
    model, supp_x, LOSS, eta = FS_epoch(
        model, s1, s2, supp_x, data, optimizer, optimizer0_1, optimizer0_2,
        eta, Ts, step, ll,
    )
    _, loss = total_loss(data, model, eta, ll)
    print('loss',loss)
    loss_value = loss.item()
    hist.append(loss_value)
    SUPP1.append(supp_x[0])
    SUPP2.append(supp_x[1])
    # Prevent divergence of optimization over support: keep the best model.
    # `<=` keeps the original tie behaviour (the latest of equal losses wins).
    if loss_value <= best_loss:
        best_loss = loss_value
        best_model.load_state_dict(model.state_dict())
        # Copy, so later changes to `supp_x` cannot alter the snapshot.
        best_supp = [list(supp_x[0]), list(supp_x[1])]
    # Early stop: both supports are unchanged (same size and same members)
    # compared with the previous epoch.
    same_size = (len(SUPP1[-1]) == len(SUPP1[-2])
                 and len(SUPP2[-1]) == len(SUPP2[-2]))
    if (same_size
            and set(SUPP1[-1]) == set(SUPP1[-2])
            and set(SUPP2[-1]) == set(SUPP2[-2])):
        break
print(loss)
print('best_supp',best_supp)

# Ground truth used for scoring: the first 10 covariates in each part.
correct_set1 = list(range(10))
correct_set2 = list(range(10))

# Test set with n//10 samples.
# NOTE(review): Z_test is drawn with seed+1, but generate_data is called with
# the training `seed` — confirm this is intended.
Z_test = generate_Z(seed+1, n//10, p)
T_test, delta_test, _, _, _, _, y_cure_test = generate_data(
    device, seed, f, Z_test, n//10, p, alpha, beta
)
cindex, AUC = cindex_AUC(T_test, Z_test, delta_test, best_model, y_cure_test, 'True')
TPRL, FPRL = metric(correct_set1, best_supp[0], seed, f, alpha, beta, n, p, 'f1')
TPRC, FPRC = metric(correct_set2, best_supp[1], seed, f, alpha, beta, n, p, 'f2')
print('TPRC:',TPRC,'FPRC:',FPRC,'TPRL:',TPRL,'FPRL:',FPRL,'C-index:',cindex,'AUC:',AUC)

epoch: 0
L-part: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 108] C-part: [0, 1, 2, 3, 4, 6, 7, 8, 9]
loss tensor(3.8017, grad_fn=<AddBackward0>)
epoch: 1
L-part: [0, 1, 2, 3, 4, 6, 7, 8, 9, 47, 295] C-part: [0, 1, 2, 3, 4, 6, 7, 8, 9]
loss tensor(3.5346, grad_fn=<AddBackward0>)
epoch: 2
L-part: [0, 1, 2, 3, 4, 6, 7, 8, 9, 215, 233] C-part: [0, 1, 2, 3, 4, 6, 7, 8, 9]
loss tensor(3.5657, grad_fn=<AddBackward0>)
epoch: 3
L-part: [0, 1, 2, 3, 4, 6, 7, 8, 9, 107, 219] C-part: [0, 1, 2, 3, 4, 6, 7, 8, 9]
loss tensor(3.5724, grad_fn=<AddBackward0>)
epoch: 4
L-part: [0, 1, 2, 3, 4, 6, 7, 8, 9, 67, 298] C-part: [0, 1, 2, 3, 4, 6, 7, 8, 9]
loss tensor(3.5283, grad_fn=<AddBackward0>)
tensor(3.5283, grad_fn=<AddBackward0>)
best_supp [[9, 4, 6, 3, 0, 1, 2, 7, 8, 67, 298], [9, 0, 8, 1, 4, 2, 6, 7, 3]]
TPRC: 0.9 FPRC: 0.0 TPRL: 0.9 FPRL: 0.006896551724137931 C-index: 0.8193456614509246 AUC: 0.8433889602053916


We generated a separate test set with a sample size of n/10 to calculate the AUC and C-index. In terms of variable selection, the algorithm correctly identified 9 of the 10 influential variables for both Part L and Part C. It made 2 false selections for Part L and none for Part C. 

## Selection of $s_L$, $s_C$

In [4]:

# Grid search over (s_L, s_C): train once per candidate pair and keep the
# model/support with the smallest BIC.
epochs = 5
hist = []
# Two independent lists, so the L and C supports never share one object.
supp_x1 = list(range(p))  # initial support of part L
supp_x2 = list(range(p))  # initial support of part C
supp_x = [supp_x1, supp_x2]

# Common starting weights for every candidate: random, forced to 1 wherever
# delta == 1.
eta_o = torch.rand(n).to(device=device)
eta_o[delta == 1] = 1

original_list = [9, 10, 11, 12, 13]  # shortened candidate list for the notebook
Ss = list(itertools.product(original_list, repeat=2))
BIC = []  # BIC of each candidate pair, in the order of `Ss`
S_num = []
# Keep the snapshot model on the same device as the models defined earlier.
best_model = Net_nonlinear(
    n_feature=p, n_hidden1=n_hidden1, n_hidden2=n_hidden2, n_output=1
).to(device=device)

for s in Ss:
    # Train on the dataset with the given (s1, s2).
    s1 = s[0]
    s2 = s[1]
    # Clone so that any in-place update during training cannot leak into the
    # starting point of the next candidate.
    eta = eta_o.clone()
    loss, model, supp, bic = training_n(
        data, s1, s2, eta,
        epochs=epochs, n_hidden1=n_hidden1, n_hidden2=n_hidden2,
        learning_rate=learning_rate, Ts=Ts, step=step, ll=ll,
    )
    # Store bic values
    BIC.append(bic)
    # NOTE(review): `supp` is indexed as supp[0] / supp[1] below, so len(supp)
    # looks like the number of parts rather than the number of selected
    # variables — confirm what S_num is meant to hold.
    S_num.append(len(supp))
    if bic == min(BIC):
        best_model.load_state_dict(model.state_dict())
        best_supp = supp
    # Row for the optional per-candidate CSV log below.
    mid_result = [seed, s1, s2, loss, n, len(supp[0]), len(supp[1])]
    # with open('hist'+save_file_name, 'a', newline='') as file:
    #     writer = csv.writer(file)
    #     if file.tell() == 0:
    #         writer.writerow(["seed", "s1","s2","loss","n","Ss1","Ss2"])
    #     writer.writerow(mid_result)

# NOTE(review): this rebinds `idx`, which earlier held an output of
# generate_data; the `data` list keeps its own reference to that object.
idx = np.argmin(BIC)
best_s1 = Ss[idx][0]
best_s2 = Ss[idx][1]
print('C-part:Sselected:',best_s2,'L-part:Sselected:',best_s1)


epoch: 0
L-part: [0, 1, 2, 3, 5, 7, 8, 13, 58] C-part: [0, 1, 2, 3, 5, 7, 8, 13, 58]
epoch: 1
L-part: [0, 1, 2, 3, 5, 7, 8, 13, 58] C-part: [0, 1, 2, 3, 5, 7, 8, 13, 58]
L-part: [0, 1, 2, 3, 5, 7, 8, 13, 58] C-part: [0, 1, 2, 3, 5, 7, 8, 13, 58]
tensor(3.4050, grad_fn=<AddBackward0>)
epoch: 0
L-part: [0, 1, 2, 3, 5, 7, 8, 13, 58] C-part: [0, 1, 2, 3, 5, 6, 7, 8, 13, 58]
epoch: 1
L-part: [0, 1, 2, 3, 5, 7, 8, 13, 58] C-part: [0, 1, 2, 3, 5, 7, 8, 9, 13, 58]
epoch: 2
L-part: [0, 1, 2, 3, 5, 7, 8, 13, 58] C-part: [0, 1, 2, 3, 5, 7, 8, 9, 13, 58]
L-part: [0, 1, 2, 3, 5, 7, 8, 13, 58] C-part: [0, 1, 2, 3, 5, 7, 8, 9, 13, 58]
tensor(3.2874, grad_fn=<AddBackward0>)
epoch: 0
L-part: [0, 1, 2, 3, 5, 7, 8, 13, 58] C-part: [0, 1, 2, 3, 5, 6, 7, 8, 13, 58, 167]
epoch: 1
L-part: [0, 1, 2, 3, 5, 7, 8, 13, 58] C-part: [0, 1, 2, 3, 5, 7, 8, 9, 13, 58, 295]
epoch: 2
L-part: [0, 1, 2, 3, 5, 7, 8, 13, 58] C-part: [0, 1, 2, 3, 5, 7, 8, 9, 13, 58, 162]
epoch: 3
L-part: [0, 1, 2, 3, 5, 7, 8, 13, 58] C-part: