In [None]:
import pandas as pd
import numpy as np
from ChromaVDB.chroma import ChromaFramework
from DeepGraphDB import DeepGraphDB
from tqdm.notebook import tqdm

# Knowledge graph loaded from primekg.bin.
# NOTE(review): absolute, machine-specific path; `gdb` is not used in the cells shown here.
gdb = DeepGraphDB()
gdb.load_graph("/home/cc/PHD/dglframework/DeepKG/DeepGraphDB/graphs/primekg.bin")

# Gene embeddings stored in the local Chroma vector DB, unpacked into three parallel
# lists (name, embedding vector, record id) that share the same record order.
vdb = ChromaFramework(persist_directory="./ChromaVDB/chroma_db")
records = vdb.list_records()
names, embs, ids = (
    [record[field] for record in records]
    for field in ('name', 'embeddings', 'id')
)

# Clinical + mutation table. TODO: also try using advanced stage ("stadio-avanzato"), IPI and Log10hGE.
data = pd.read_excel('data/2025_03_29.xlsx')
# Cohort: DLBCL (Diffuse Large B-cell Lymphoma)

In [None]:
# Sample set and measurement to model (other values tried: gene_set = "LNF", gene_measure = "MUT").
gene_set = "plasma"
gene_measure = "VAF"

# Keep every column whose name contains both tags; columns are presumably named
# "<GENE>_<gene_set>_<gene_measure>", which is the name the matching cell rebuilds.
matching_columns = [
    col for col in data.columns
    if gene_set in col and gene_measure in col
]
gene_data = data[matching_columns]

In [None]:
# Gene symbol = prefix before the first '_' of each selected column.
# sorted() makes the gene order (and therefore column / embedding order) identical across
# kernel restarts; bare set iteration order depends on per-process string hashing.
genes = sorted({col.split('_')[0] for col in gene_data.columns})

# name -> index of its FIRST record (same semantics as list.index), built once instead of
# scanning `names` for every gene.
name_to_idx = {}
for idx, name in enumerate(names):
    name_to_idx.setdefault(name, idx)

final_columns = []
embeddings = []
record_ids = []

for gene in genes:
    idx = name_to_idx.get(gene)
    if idx is None:
        # No embedding for this gene in the vector DB: its column is dropped.
        print(gene)
        continue
    final_columns.append(gene + "_" + gene_set + "_" + gene_measure)
    embeddings.append(embs[idx])
    record_ids.append(ids[idx])

print(len(final_columns))

In [None]:
# Restrict to genes that have an embedding and attach the outcome label (index-aligned with `data`).
gene_data = gene_data[final_columns].assign(pfs=data['PFS_Cens_updated'])
# A missing outcome cannot be imputed: drop those patients instead of silently labelling
# them 0 through the fillna below.
gene_data = gene_data.dropna(subset=['pfs'])
# NOTE(review): missing feature values become 0 (presumably "variant not detected") — confirm.
gene_data = gene_data.fillna(0)

In [None]:
# Model inputs: PFS label, per-patient gene feature matrix, and the gene-embedding matrix.
# Feature column j and embedding row j refer to the same gene (both follow final_columns).
pfs_binary = gene_data['pfs'].values
mutants = gene_data.drop(columns=['pfs'])
patient_mutations = mutants.values
gene_embeddings = np.array(embeddings)

# Optional row-wise L2 normalisation of the gene embeddings (disabled by default).
NORMALIZE_EMBEDDINGS = False
if NORMALIZE_EMBEDDINGS:
    norms = np.linalg.norm(gene_embeddings, axis=1, keepdims=True)
    norms = np.where(norms == 0, 1, norms)  # avoid division by zero
    gene_embeddings = gene_embeddings / norms

# Fail fast on inconsistent inputs: the models pair gene columns with embedding rows by position.
assert gene_embeddings.ndim == 2, "gene embeddings must form a 2-D (n_genes, dim) matrix"
assert patient_mutations.shape[1] == gene_embeddings.shape[0], "gene columns and embeddings are misaligned"
assert set(np.unique(pfs_binary)) <= {0, 1}, "PFS label must be binary (0/1)"

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score, balanced_accuracy_score

class PFSClassificationMLP(nn.Module):
    """PFS classifier on a mutation-weighted sum of knowledge-graph gene embeddings.

    A patient is represented as ``sum_g (w * x_g) * e_g``, where ``x_g`` is the patient's value
    for gene ``g``, ``e_g`` that gene's embedding and ``w`` a single learned scalar. An MLP then
    maps this representation to one logit.

    Args:
        n_genes: number of genes (columns of ``patient_mutations`` / rows of ``gene_embeddings``).
        embedding_dim: size of each gene embedding.
        hidden_dims: widths of the hidden ``Linear -> ReLU -> Dropout(0.3)`` stages.
    """

    def __init__(self, n_genes, embedding_dim, hidden_dims=(512, 256)):
        super().__init__()

        self.n_genes = n_genes
        self.embedding_dim = embedding_dim

        # One learned scalar applied to every mutation value. Kept as Linear(1, 1) so the
        # parameter name ``gene_weight.weight`` stays the same in state dicts.
        self.gene_weight = nn.Linear(1, 1, bias=False)

        # Patient-level MLP
        layers = []
        prev_dim = embedding_dim
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.3)
            ])
            prev_dim = hidden_dim
        layers.append(nn.Linear(prev_dim, 1))
        self.patient_mlp = nn.Sequential(*layers)

    def forward(self, patient_mutations, gene_embeddings):
        """Compute one logit per patient.

        Args:
            patient_mutations: (batch, n_genes) mutation values.
            gene_embeddings: (n_genes, embedding_dim) gene embeddings.

        Returns:
            (batch, 1) logits.
        """
        # (batch, n_genes, 1) -> scaled by the learned scalar -> (batch, n_genes)
        weights = self.gene_weight(patient_mutations.unsqueeze(-1)).squeeze(-1)

        # Weighted sum over genes as a matmul, avoiding a (batch, n_genes, embedding_dim) temporary.
        patient_repr = weights @ gene_embeddings

        return self.patient_mlp(patient_repr)
    
class PFSAttentionMLP(nn.Module):
    """PFS classifier that attends from per-gene mutation features to KG gene embeddings.

    Each gene's mutation value is projected to ``embedding_dim`` and used as an attention query
    over all gene embeddings (keys/values). The attended per-gene features are pooled with
    weights ``mutation + 0.1`` (a weighted average) and classified by an MLP.

    Args:
        n_genes: number of genes.
        embedding_dim: size of each gene embedding.
        hidden_dims: widths of the classifier's hidden ``Linear -> ReLU -> Dropout(0.3)`` stages.
        num_heads: requested attention heads. The actual count is the largest value
            <= ``min(num_heads, embedding_dim // 16)`` that divides ``embedding_dim`` (at least 1).
    """

    def __init__(self, n_genes, embedding_dim, hidden_dims=(512, 256), num_heads=2):
        super().__init__()

        self.n_genes = n_genes
        self.embedding_dim = embedding_dim
        self.num_heads = self._choose_num_heads(embedding_dim, num_heads)

        # Project each scalar mutation value to the embedding dimension
        self.mutation_proj = nn.Sequential(
            nn.Linear(1, embedding_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(embedding_dim // 2, embedding_dim),
            nn.LayerNorm(embedding_dim)
        )

        # Multi-head attention: mutation features (queries) over gene embeddings (keys/values)
        self.attention = nn.MultiheadAttention(
            embed_dim=embedding_dim,
            num_heads=self.num_heads,
            batch_first=True,
            dropout=0.3
        )

        # Output projection after attention
        self.attention_proj = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        # Final MLP for classification
        layers = []
        prev_dim = embedding_dim
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.3)
            ])
            prev_dim = hidden_dim
        layers.append(nn.Linear(prev_dim, 1))
        self.classifier = nn.Sequential(*layers)

    @staticmethod
    def _choose_num_heads(embedding_dim, requested):
        """Largest head count <= min(requested, embedding_dim // 16) that divides embedding_dim, min 1.

        nn.MultiheadAttention requires ``embedding_dim % num_heads == 0`` and ``num_heads >= 1``.
        """
        upper = max(1, min(requested, embedding_dim // 16))
        for heads in range(upper, 0, -1):
            if embedding_dim % heads == 0:
                return heads
        return 1

    def forward(self, patient_mutations, gene_embeddings):
        """Compute one logit per patient.

        Args:
            patient_mutations: (batch, n_genes) mutation values.
            gene_embeddings: (n_genes, embedding_dim) gene embeddings.

        Returns:
            (batch, 1) logits.
        """
        batch_size = patient_mutations.shape[0]

        # (batch, n_genes, 1) -> (batch, n_genes, embedding_dim)
        mutation_features = self.mutation_proj(patient_mutations.unsqueeze(-1))

        # Broadcast view of the shared embeddings for every patient (no copy, unlike repeat)
        gene_emb_batch = gene_embeddings.unsqueeze(0).expand(batch_size, -1, -1)

        # Mutations as queries, gene embeddings as keys and values
        attended_features, _ = self.attention(
            query=mutation_features,
            key=gene_emb_batch,
            value=gene_emb_batch
        )  # (batch, n_genes, embedding_dim)
        attended_features = self.attention_proj(attended_features)

        # Pooling weights: mutation value plus a 0.1 floor so non-mutated genes still contribute.
        pool_weights = patient_mutations.unsqueeze(-1) + 0.1  # (batch, n_genes, 1)

        # Weighted average: normalise by the total weight used in the numerator (floor included).
        # Dividing by the mutation sum alone would blow up for patients with no mutations.
        total_weight = pool_weights.sum(dim=1).clamp_min(1e-6)  # (batch, 1)
        patient_repr = (attended_features * pool_weights).sum(dim=1) / total_weight  # (batch, embedding_dim)

        return self.classifier(patient_repr)

class PFSMutationOnlyMLP(nn.Module):
    """Baseline PFS classifier: an MLP applied directly to the per-gene mutation values.

    Uses no knowledge-graph embeddings.

    Args:
        n_genes: number of input features (genes) per patient.
        hidden_dims: widths of the hidden ``Linear -> ReLU -> Dropout(0.3)`` stages.
    """

    def __init__(self, n_genes, hidden_dims=[512, 256]):
        super().__init__()
        self.n_genes = n_genes

        widths = [n_genes, *hidden_dims]
        stages = []
        for in_dim, out_dim in zip(widths[:-1], widths[1:]):
            stages += [nn.Linear(in_dim, out_dim), nn.ReLU(), nn.Dropout(0.3)]
        stages.append(nn.Linear(widths[-1], 1))
        self.mlp = nn.Sequential(*stages)

    def forward(self, patient_mutations):
        """Map (batch, n_genes) mutation values to (batch, 1) logits."""
        return self.mlp(patient_mutations)

class PFSClassifier:
    """Train, cross-validate and evaluate the PFS classifiers.

    With ``use_kg=True`` the model is a ``PFSClassificationMLP`` fed with mutation values and
    gene embeddings. With ``use_kg=False`` it is a ``PFSMutationOnlyMLP`` fed with mutation
    values only. Training is full-batch Adam (weight_decay=1e-4) on ``BCEWithLogitsLoss``,
    whose ``pos_weight`` rebalances the two classes.

    Args:
        use_kg: whether the model consumes the gene embeddings.
        device: torch device string; defaults to CUDA when available.
    """

    def __init__(self, use_kg=True, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        self.model = None
        self.use_kg = use_kg

    def _forward(self, mutations, gene_emb):
        """Run ``self.model`` and return one logit per patient as a 1-D tensor.

        Gene embeddings are passed only to the KG model; the mutation-only model's ``forward``
        takes a single argument. ``squeeze(-1)`` drops only the output dimension, so a
        single-patient batch still yields shape ``(1,)``.
        """
        if self.use_kg:
            logits = self.model(mutations, gene_emb)
        else:
            logits = self.model(mutations)
        return logits.squeeze(-1)

    def _pos_weight(self, labels):
        """Return ``n_negative / n_positive`` as a 1-element tensor on ``self.device``.

        Falls back to 1.0 when there are no positives or labels fall outside {0, 1}.
        """
        class_counts = np.bincount(np.asarray(labels).astype(int))
        if len(class_counts) == 2 and class_counts[1] > 0:
            ratio = class_counts[0] / class_counts[1]
        else:
            ratio = 1.0
        return torch.FloatTensor([ratio]).to(self.device)

    def create_model(self, n_genes, embedding_dim, **kwargs):
        """Build a fresh model on ``self.device`` and Xavier-initialise its Linear layers.

        Args:
            n_genes: number of gene features per patient.
            embedding_dim: gene-embedding size (unused by the mutation-only model).
            **kwargs: forwarded to the model constructor (e.g. ``hidden_dims``).
        """
        if self.use_kg:
            # Replace with PFSAttentionMLP to try the attention-based variant.
            self.model = PFSClassificationMLP(
                n_genes=n_genes,
                embedding_dim=embedding_dim,
                **kwargs
            ).to(self.device)
        else:
            self.model = PFSMutationOnlyMLP(
                n_genes=n_genes,
                **kwargs
            ).to(self.device)

        # Xavier-uniform weights, zero biases for every Linear layer
        for module in self.model.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def train_fold(self, train_mutations, train_labels, val_mutations, val_labels,
                   gene_embeddings, epochs=150, lr=0.001, max_patience=50):
        """Train ``self.model`` on one fold with early stopping on validation F1.

        The model state from the epoch with the highest validation F1 is restored at the end.

        Returns:
            ``(val_auc, val_f1, val_balanced_acc, epochs_run)``. The first three are the
            validation metrics of the best-F1 epoch. ``epochs_run`` is the number of epochs
            executed before stopping.

        NOTE(review): the same validation fold both selects the best epoch and supplies the
        reported metrics, so cross-validated scores are optimistically biased. An inner
        train/validation split would be needed for an unbiased estimate.
        """
        train_mut = torch.FloatTensor(train_mutations).to(self.device)
        train_pfs = torch.FloatTensor(train_labels).to(self.device)
        val_mut = torch.FloatTensor(val_mutations).to(self.device)
        gene_emb = torch.FloatTensor(gene_embeddings).to(self.device)

        optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-4)
        criterion = nn.BCEWithLogitsLoss(pos_weight=self._pos_weight(train_labels))

        # Start below any attainable F1 so the first epoch always sets the best-* values.
        # Otherwise a fold whose F1 never exceeds 0 would leave them undefined.
        best_val_f1 = -1.0
        best_auc = 0.5
        best_balanced_acc = 0.5
        best_state = None
        patience = 0
        epochs_run = 0

        for epoch in range(epochs):
            epochs_run = epoch + 1

            # Training (full batch)
            self.model.train()
            optimizer.zero_grad()
            loss = criterion(self._forward(train_mut, gene_emb), train_pfs)
            loss.backward()
            optimizer.step()

            # Validation
            self.model.eval()
            with torch.no_grad():
                val_probs = torch.sigmoid(self._forward(val_mut, gene_emb)).cpu().numpy()

            val_preds = (val_probs > 0.5).astype(int)
            if len(np.unique(val_labels)) > 1:
                val_auc = roc_auc_score(val_labels, val_probs)
            else:
                val_auc = 0.5  # AUC is undefined with a single class
            val_f1 = f1_score(val_labels, val_preds, zero_division=0)
            val_balanced_acc = balanced_accuracy_score(val_labels, val_preds)

            # Early stopping based on F1 score
            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                best_auc = val_auc
                best_balanced_acc = val_balanced_acc
                patience = 0
                best_state = {k: v.clone() for k, v in self.model.state_dict().items()}
            else:
                patience += 1
                if patience >= max_patience:
                    break

        if best_state is not None:
            self.model.load_state_dict(best_state)

        return best_auc, max(best_val_f1, 0.0), best_balanced_acc, epochs_run

    def cross_validate(self, patient_mutations, gene_embeddings, pfs_binary,
                       n_splits=5, n_repeats=3, epochs=100, lr=0.001, **model_kwargs):
        """Run repeated stratified k-fold CV, training a fresh model per fold.

        ``gene_embeddings`` is required even when ``use_kg=False`` because its shape supplies
        ``n_genes`` and ``embedding_dim``. Splits are seeded (``random_state=42``); model
        initialisation is not. Prints per-fold and summary metrics and returns them in a dict.
        """
        n_genes, embedding_dim = gene_embeddings.shape

        rskf = RepeatedStratifiedKFold(n_splits=n_splits, n_repeats=n_repeats, random_state=42)

        fold_aucs = []
        fold_f1s = []
        fold_balanced_accs = []

        fold_count = 0
        for train_idx, val_idx in rskf.split(patient_mutations, pfs_binary):
            fold_count += 1

            # Fresh model for each fold
            self.create_model(n_genes, embedding_dim, **model_kwargs)

            val_auc, val_f1, val_balanced_acc, epochs_run = self.train_fold(
                patient_mutations[train_idx], pfs_binary[train_idx],
                patient_mutations[val_idx], pfs_binary[val_idx],
                gene_embeddings, epochs=epochs, lr=lr,
            )

            fold_aucs.append(val_auc)
            fold_f1s.append(val_f1)
            fold_balanced_accs.append(val_balanced_acc)

            print(f"Fold {fold_count}: AUC = {val_auc:.3f}, F1 = {val_f1:.3f}, Balanced Acc = {val_balanced_acc:.3f}, Max epoch = {epochs_run}")

        mean_auc = np.mean(fold_aucs)
        std_auc = np.std(fold_aucs)
        mean_f1 = np.mean(fold_f1s)
        std_f1 = np.std(fold_f1s)
        mean_balanced_acc = np.mean(fold_balanced_accs)
        std_balanced_acc = np.std(fold_balanced_accs)

        print(f"Cross-validation results ({n_splits}-fold, {n_repeats} repeats):")
        print(f"AUC:          {mean_auc:.3f} ± {std_auc:.3f}")
        print(f"F1:           {mean_f1:.3f} ± {std_f1:.3f}")
        print(f"Balanced Acc: {mean_balanced_acc:.3f} ± {std_balanced_acc:.3f}")

        return {
            'fold_aucs': fold_aucs,
            'fold_f1s': fold_f1s,
            'fold_balanced_accs': fold_balanced_accs,
            'mean_auc': mean_auc,
            'std_auc': std_auc,
            'mean_f1': mean_f1,
            'std_f1': std_f1,
            'mean_balanced_acc': mean_balanced_acc,
            'std_balanced_acc': std_balanced_acc,
            'n_splits': n_splits,
            'n_repeats': n_repeats,
            'total_folds': fold_count
        }

    def fit(self, patient_mutations, gene_embeddings, pfs_binary,
            epochs=100, lr=0.001, **model_kwargs):
        """Train a fresh model on all given data for a fixed number of epochs (no early stopping)."""
        n_genes, embedding_dim = gene_embeddings.shape
        self.create_model(n_genes, embedding_dim, **model_kwargs)

        mut_tensor = torch.FloatTensor(patient_mutations).to(self.device)
        pfs_tensor = torch.FloatTensor(pfs_binary).to(self.device)
        gene_emb = torch.FloatTensor(gene_embeddings).to(self.device)

        optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-4)
        criterion = nn.BCEWithLogitsLoss(pos_weight=self._pos_weight(pfs_binary))

        for epoch in range(epochs):
            self.model.train()
            optimizer.zero_grad()
            loss = criterion(self._forward(mut_tensor, gene_emb), pfs_tensor)
            loss.backward()
            optimizer.step()

    def predict(self, patient_mutations, gene_embeddings):
        """Return ``(predictions, probabilities)`` as 1-D arrays; predictions threshold at 0.5."""
        self.model.eval()

        mut_tensor = torch.FloatTensor(patient_mutations).to(self.device)
        gene_emb = torch.FloatTensor(gene_embeddings).to(self.device)

        with torch.no_grad():
            probabilities = torch.sigmoid(self._forward(mut_tensor, gene_emb)).cpu().numpy()

        predictions = (probabilities > 0.5).astype(int)
        return predictions, probabilities

    def evaluate(self, patient_mutations, gene_embeddings, true_labels):
        """Return accuracy, F1, balanced accuracy and AUC (0.5 if only one class is present)."""
        predictions, probabilities = self.predict(patient_mutations, gene_embeddings)

        return {
            'accuracy': accuracy_score(true_labels, predictions),
            'f1': f1_score(true_labels, predictions, zero_division=0),
            'balanced_accuracy': balanced_accuracy_score(true_labels, predictions),
            'auc': roc_auc_score(true_labels, probabilities) if len(np.unique(true_labels)) > 1 else 0.5
        }

In [None]:

# Seed torch so model initialisation and dropout are reproducible
# (the fold splits themselves use random_state=42 inside cross_validate).
torch.manual_seed(42)

classifier = PFSClassifier()  # use_kg=True: mutation values + KG gene embeddings
# 10-fold stratified CV repeated 5 times = 50 folds in total.
cv_results_kg = cv_results = classifier.cross_validate(
    patient_mutations, gene_embeddings, pfs_binary,
    n_splits=10, n_repeats=5, epochs=250,
)

# Final fit + held-out evaluation (needs a test set, which this notebook does not define):
# classifier.fit(patient_mutations, gene_embeddings, pfs_binary)
# metrics = classifier.evaluate(test_mutations, gene_embeddings, test_labels)
# print(f"Test AUC: {metrics['auc']:.3f}, Test F1: {metrics['f1']:.3f}")

In [None]:
# Seed torch so model initialisation and dropout are reproducible.
torch.manual_seed(42)

# Baseline: mutation values only. gene_embeddings is still passed because cross_validate
# reads n_genes / embedding_dim from its shape.
classifier = PFSClassifier(use_kg=False)
cv_results_nokg = cv_results = classifier.cross_validate(
    patient_mutations, gene_embeddings, pfs_binary,
    n_splits=10, n_repeats=5, epochs=250,
)

In [None]:
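# Recorded results from an earlier run with n_splits=20 and n_repeats=5. The CV cells above
# currently use n_splits=10, so rerun them before comparing against these numbers.
# Note: each fold reports the metrics of its best-F1 epoch, chosen on that same validation
# fold, so the scores are optimistic for both models.
#
# Cross-validation results (20-fold, 5 repeats, usekg, mlp):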
# AUC:          0.730 ± 0.205
# F1:           0.740 ± 0.139
# Balanced Acc: 0.775 ± 0.156

# Cross-validation results (20-fold, 5 repeats, nokg, mlp):
# AUC:          0.639 ± 0.207
# F1:           0.683 ± 0.118
# Balanced Acc: 0.719 ± 0.132