In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from sksurv.metrics import concordance_index_censored
import matplotlib.pyplot as plt
import pandas as pd
# ---------------------------------------------
# 1. Load your own survival dataset
# Replace this section with your real dataset
# ---------------------------------------------
mydata = pd.read_csv("data_ready_45.csv")
# Example of simulated data (you should replace this with your actual data)
np.random.seed(42)  # Remove this line if you don't need random seeding

# --- Replace the following three variables with your actual data ---
X = mydata.drop(columns=["GRF_STAT_PA", "time_frame"], axis=1)
T = mydata.pop("time_frame")
E = mydata.pop("GRF_STAT_PA")

X = X.to_numpy()
T = T.to_numpy()
E = E.to_numpy()
# ------------------------------------------------------------------

# Convert your data into PyTorch tensors
X_tensor = torch.tensor(X, dtype=torch.float32)  # Covariates (features)
T_tensor = torch.tensor(T, dtype=torch.float32)  # Time-to-event (survival times)
E_tensor = torch.tensor(E, dtype=torch.float32)  # Event/censoring indicator

# --------------------------------------------------
# 2. No need to change the models (Generator, Discriminator)
# You can leave the architecture as it is unless you want to tweak it.
# --------------------------------------------------

# Define the Generator model
class Generator(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Generator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim)  # Output: time-to-event and event/censoring indicator
        )

    def forward(self, z):
        output = self.fc(z)
        time = torch.relu(output[:, 0])  # Ensure non-negative times
        event = torch.sigmoid(output[:, 1])  # Probability of event (between 0 and 1)
        return time, event

# Define the Discriminator model
class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super(Discriminator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()  # Output: probability of being real
        )

    def forward(self, x):
        return self.fc(x)

# Initialize the generator and discriminator
generator = Generator(input_dim=100, output_dim=2)  # Input: noise vector, Output: time-to-event and event/censoring
discriminator = Discriminator(input_dim=2)  # Input: time-to-event and event/censoring pairs

# Define the loss function and optimizers
criterion = nn.BCELoss()  # Binary cross-entropy for real/fake classification
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002)
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002)

# ---------------------------------------------------------
# 3. Training loop - No need to change unless you'd like to tweak epochs or batch size
# ---------------------------------------------------------

# Training loop for GAN
epochs = 10000  # You can reduce this if you don't need that many epochs
batch_size = 64
real_label = 1
fake_label = 0

for epoch in range(epochs):
    # Train Discriminator with real data
    discriminator.zero_grad()
    
    # --- Replace this sampling section if needed ---
    # Sample a batch of real data (time-to-event and event/censoring)
    idx = np.random.randint(0, len(T), batch_size)  # Adjust this to your dataset's size
    real_time = T_tensor[idx].view(-1, 1)
    real_event = E_tensor[idx].view(-1, 1)
    real_data = torch.cat([real_time, real_event], dim=1)
    real_labels = torch.full((batch_size, 1), real_label, dtype=torch.float32)     # -----------------------------------------------
    
    output_real = discriminator(real_data)
    loss_real = criterion(output_real, real_labels)
    
    # Generate fake data using the Generator
    z = torch.randn(batch_size, 100)  # Random noise input
    fake_time, fake_event = generator(z)
    fake_data = torch.cat([fake_time.view(-1, 1), fake_event.view(-1, 1)], dim=1)
    fake_labels = torch.full((batch_size, 1), fake_label, dtype=torch.float32)    
    # Train Discriminator on fake data
    output_fake = discriminator(fake_data.detach())
    loss_fake = criterion(output_fake, fake_labels)
    
    # Combine discriminator loss
    loss_D = loss_real + loss_fake
    loss_D.backward()
    optimizer_D.step()
    
    # Train Generator to fool the discriminator
    generator.zero_grad()
    output_fake = discriminator(fake_data)
    loss_G = criterion(output_fake, real_labels)  # Goal: make discriminator classify fake data as real
    loss_G.backward()
    optimizer_G.step()
    
    # Print loss every 1000 epochs
    if epoch % 1000 == 0:
        print(f'Epoch {epoch} | Loss D: {loss_D.item()} | Loss G: {loss_G.item()}')

# -----------------------------------------------------------
# 4. Evaluation: Adjust this section based on your dataset size and evaluation requirements
# -----------------------------------------------------------

# Generate synthetic data using the trained Generator
z = torch.randn(len(T), 100)  # Adjust this to the number of synthetic samples you want to generate
synthetic_time, synthetic_event = generator(z)
synthetic_time = synthetic_time.detach().numpy()
synthetic_event = (synthetic_event.detach().numpy() > 0.5).astype(int)  # Convert event probabilities to binary

# Calculate the Concordance Index (C-index) on the synthetic data compared to the real data
real_cindex = concordance_index_censored(E == 1, T, T)[0]
synthetic_cindex = concordance_index_censored(synthetic_event == 1, synthetic_time, synthetic_time)[0]

print(f"Real Data C-index: {real_cindex}")
print(f"Synthetic Data C-index: {synthetic_cindex}")

# Visualization: Compare the distribution of real vs. synthetic survival times
plt.hist(T, bins=50, alpha=0.5, label='Real Data')  # Real time-to-event data
plt.hist(synthetic_time, bins=50, alpha=0.5, label='Synthetic Data')  # Synthetic time-to-event data
plt.legend()
plt.title('Comparison of Real vs. Synthetic Survival Times')
plt.xlabel('Time-to-Event')
plt.ylabel('Frequency')
plt.show()


Epoch 0 | Loss D: 90.41207122802734 | Loss G: 0.9191253185272217
Epoch 1000 | Loss D: 100.03914642333984 | Loss G: 0.0
Epoch 2000 | Loss D: 100.00851440429688 | Loss G: 0.0
Epoch 3000 | Loss D: 100.0 | Loss G: 0.0
Epoch 4000 | Loss D: 100.0 | Loss G: 0.0
Epoch 5000 | Loss D: 100.0 | Loss G: 0.0
Epoch 6000 | Loss D: 100.0 | Loss G: 0.0
Epoch 7000 | Loss D: 100.0 | Loss G: 0.0
Epoch 8000 | Loss D: 100.0 | Loss G: 0.0
Epoch 9000 | Loss D: 100.0 | Loss G: 0.0


ValueError: All samples are censored