In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error
import torch
import torch.utils.data as data
import pandas as pd
import math
from scipy.stats import pearsonr  # Import pearsonr function

In [9]:
TNBC_C= pd.read_csv(r"/Users/xinwang/Dropbox (Choate)/Isabella Dropbox/Topology_ST/TNBC_Slides/TNBC_ST_C/TNBC_C_ITF800_HLAB.csv")
TNBC_C= TNBC_C.iloc[:, 1:]
X = TNBC_C.iloc[:, :-1]
y = TNBC_C.iloc[:,-1]
# Step 1: Feature Scaling/Normalization
scaler = MinMaxScaler()
X_scaled = scaler.fit_transform(X)
# Step 2: Target Scaling
target_scaler = MinMaxScaler()
y_scaled = target_scaler.fit_transform(y.values.reshape(-1,1))
X_train, X_val, y_train, y_val = train_test_split(X_scaled, y_scaled, test_size=0.2, random_state=42)
batch_size = 32

train_data = data.TensorDataset(torch.tensor(X_train, dtype=torch.float32), torch.tensor(y_train, dtype=torch.float32))
train_loader = data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

val_data = data.TensorDataset(torch.tensor(X_val, dtype=torch.float32), torch.tensor(y_val, dtype=torch.float32))
val_loader = data.DataLoader(val_data, batch_size=batch_size, shuffle=False)

In [11]:
#from ignite.handlers import EarlyStopping
from torch.optim.lr_scheduler import ExponentialLR

In [26]:
# Define early stopping parameters
patience = 10
min_delta = 0.00001 # Minimum change in validation loss to be considered an improvement
cumulative_delta = False  # Set to True if min_delta defines increase since last patience reset

best_val_loss = float('inf')
epochs_no_improve = 0

# Define learning rate scheduler
scheduler = ExponentialLR(optimizer, gamma=0.9)

In [13]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_probs, V)
        return output
        
    def split_heads(self, x):
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
        
    def combine_heads(self, x):
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
        
    def forward(self, Q, K, V, mask=None):
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.W_o(self.combine_heads(attn_output))
        return output
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()
        
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe.unsqueeze(0))
        
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask):
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_output, src_mask, tgt_mask):
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x

In [29]:
class TransformerRegression(nn.Module):
    def __init__(self, input_dim, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super(TransformerRegression, self).__init__()
        self.embedding = nn.Linear(input_dim, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)

        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

        self.fc = nn.Linear(d_model, 1)  # Regression output
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        embedded = self.dropout(self.positional_encoding(self.embedding(x)))
        enc_output = embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, None)

        output = self.fc(enc_output)
        return output.squeeze(2)  # Remove the last dimension (sequence length)

# Assuming you have your data ready
input_dim = 800
d_model = 256
num_heads = 8
num_layers = 4
d_ff = 1024
max_seq_length = 1  # Since this is not sequential data
dropout = 0.1

transformer = TransformerRegression(input_dim, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)

criterion = nn.MSELoss()  # Use MSE loss for regression
optimizer = optim.Adam(transformer.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9)
scheduler = ExponentialLR(optimizer, gamma=0.9)
epochs_no_improve = 0
num_epochs = 100
for epoch in range(num_epochs):
    transformer.train()
    for batch_X, batch_y in train_loader:
        optimizer.zero_grad()
        output = transformer(batch_X)
        output = output.view(-1,1)
        loss = criterion(output, batch_y)
        loss.backward()
        optimizer.step()
    # Apply learning rate scheduler
    scheduler.step()
    # Print learning rate
    current_lr = optimizer.param_groups[0]['lr']  # Get the current learning rate from the optimizer
    print(f"Epoch: {epoch+1}, Learning Rate: {current_lr:.6f}")

    transformer.eval()  # Switch to evaluation mode for validation
    with torch.no_grad():
        val_mse_sum = 0.0
        val_mae_sum = 0.0
        val_pearson_sum = 0.0  # Initialize Pearson sum
        num_val_batches = 0
        #all_val_attention_scores = []
        
        for batch_X, batch_y in val_loader:
            val_output = transformer(batch_X)
            val_output = val_output.view(-1, 1)
            #print(val_output.shape)
            #val_output = val_output.squeeze()  # Reshape to match batch_y shape
            #batch_y = batch_y.view(val_output.shape)
            val_mse_sum += mean_squared_error(batch_y, val_output)
            val_mae_sum += mean_absolute_error(batch_y, val_output)
            val_pearson, _ = pearsonr(batch_y.view(-1), val_output.view(-1))  # Calculate Pearson correlation
            val_pearson_sum += val_pearson
            num_val_batches += 1
            #all_val_attention_scores.extend(attn_scores_list)
            

        # Calculate average metrics over all validation batches
        avg_val_mse = val_mse_sum / num_val_batches
        avg_val_mae = val_mae_sum / num_val_batches
        avg_val_pearson = val_pearson_sum / num_val_batches
    # Check for early stopping
    if avg_val_mse < best_val_loss:
        best_val_loss = avg_val_mse
        epochs_no_improve = 0
    else:
       # best_val_loss = avg_val_mse
        epochs_no_improve += 1

    if epochs_no_improve == patience:
        print("Early stopping triggered.")
        break
    print(f"Epoch: {epoch+1}, Avg. Val MSE: {avg_val_mse:.4f}, Avg. Val MAE: {avg_val_mae:.4f}, Avg. Val Pearson: {avg_val_pearson:.4f}")
   


Epoch: 1, Learning Rate: 0.000900
Epoch: 1, Avg. Val MSE: 0.0122, Avg. Val MAE: 0.0743, Avg. Val Pearson: 0.3371
Epoch: 2, Learning Rate: 0.000810
Epoch: 2, Avg. Val MSE: 0.0133, Avg. Val MAE: 0.0770, Avg. Val Pearson: 0.3414
Epoch: 3, Learning Rate: 0.000729
Epoch: 3, Avg. Val MSE: 0.0099, Avg. Val MAE: 0.0715, Avg. Val Pearson: 0.3399
Epoch: 4, Learning Rate: 0.000656
Epoch: 4, Avg. Val MSE: 0.0101, Avg. Val MAE: 0.0711, Avg. Val Pearson: 0.3390
Epoch: 5, Learning Rate: 0.000590
Epoch: 5, Avg. Val MSE: 0.0102, Avg. Val MAE: 0.0852, Avg. Val Pearson: 0.3342
Epoch: 6, Learning Rate: 0.000531
Epoch: 6, Avg. Val MSE: 0.0159, Avg. Val MAE: 0.1127, Avg. Val Pearson: 0.3258
Epoch: 7, Learning Rate: 0.000478
Epoch: 7, Avg. Val MSE: 0.0187, Avg. Val MAE: 0.1230, Avg. Val Pearson: 0.3371
Epoch: 8, Learning Rate: 0.000430
Epoch: 8, Avg. Val MSE: 0.0088, Avg. Val MAE: 0.0683, Avg. Val Pearson: 0.3387
Epoch: 9, Learning Rate: 0.000387
Epoch: 9, Avg. Val MSE: 0.0085, Avg. Val MAE: 0.0723, Avg. Val