In [None]:
import pandas as pd
import numpy as np
from ChromaVDB.chroma import ChromaFramework
from DeepGraphDB import DeepGraphDB
from tqdm.notebook import tqdm
import torch

# Locations of the inputs this notebook reads.
GRAPH_PATH = "/home/cc/PHD/dglframework/DeepKG/DeepGraphDB/graphs/primekg.bin"
VDB_DIR = "./ChromaVDB/chroma_db"
DATA_PATH = 'data/2025_03_29.xlsx'

# Knowledge graph.
gdb = DeepGraphDB()
gdb.load_graph(GRAPH_PATH)

# Vector store and all of its records.
vdb = ChromaFramework(persist_directory=VDB_DIR)
records = vdb.list_records()

# Parallel per-record lists: position k in each list refers to the same record.
names = [rec['name'] for rec in records]
embs = [rec['embeddings'] for rec in records]
ids = [rec['id'] for rec in records]

# Patient table. TODO: also try using the advanced-stage, IPI and Log10hGE columns.
data = pd.read_excel(DATA_PATH)

In [None]:
# Sample source and measurement used to pick the gene columns.
gene_set = "plasma"      # alternative: "LNF"
gene_measure = "VAF"     # alternative: "MUT"

# Keep every column whose name mentions both the sample source and the measurement.
selected_cols = [col for col in data.columns if gene_set in col and gene_measure in col]
gene_data = data[selected_cols]

In [None]:
# Genes present in the selected columns (column names start with "<GENE>_").
# Sorted so that column / embedding / record-id order is reproducible across
# kernel restarts (set iteration order of strings depends on the hash seed).
genes = sorted({col.split('_')[0] for col in gene_data.columns})

# Position of the first record with each name (same semantics as names.index),
# built once instead of scanning `names` twice per gene.
name_to_pos = {}
for pos, name in enumerate(names):
    name_to_pos.setdefault(name, pos)

final_columns = []
embeddings = []
record_ids = []

for gene in genes:
    pos = name_to_pos.get(gene)
    if pos is None:
        # No vector-store record for this gene: report it and leave it out.
        print(gene)
        continue
    final_columns.append(gene + "_" + gene_set + "_" + gene_measure)
    embeddings.append(embs[pos])
    record_ids.append(ids[pos])

print(len(final_columns))

# .copy() makes an independent frame, so adding the 'pfs' column below does not
# write into a slice of `data` (avoids pandas' SettingWithCopyWarning).
gene_data = gene_data[final_columns].copy()
gene_data['pfs'] = data['PFS_Cens_updated']
# Alternative: gene_data = gene_data.dropna() to drop incomplete patients.
gene_data = gene_data.fillna(0)

In [None]:
def encode_patient(vdb, gdb, record_ids, gene_data):
    """
    Collect the knowledge-graph node IDs that describe one patient.

    Two 1-hop expansions are performed on ``gdb``:
      1. From the genes whose measurement is > 0, follow only edge types whose
         source and destination node types are both "geneprotein". Genes whose
         measurement is <= 0 are then removed from the reached set.
      2. From that gene set, follow every edge type with a "geneprotein" node at
         either end, except ('geneprotein', 'protein_protein', 'geneprotein').
         If the graph has no such etype, nothing is excluded.

    Args:
        vdb: vector store exposing ``global_to_vids_mapping`` (graph ID -> vector ID).
        gdb: graph wrapper exposing ``graph.canonical_etypes`` and
            ``get_k_hop_neighbors(seeds, k, edge_types=...)`` returning a dict of
            iterables of node IDs.
        record_ids: vector-store IDs of the genes, in the same order as the gene
            entries of ``gene_data``.
        gene_data: one patient's row (pandas Series): one value per gene plus a
            'pfs' entry, which is ignored here.

    Returns:
        np.ndarray of graph node IDs reached by the second expansion, flattened
        over all groups returned by ``get_k_hop_neighbors``.
    """
    # Gene measurements only; 'pfs' is the outcome, not a gene.
    gene_values = gene_data.drop(labels=['pfs']).values

    # record_ids are vector-store IDs; invert the mapping to get graph node IDs
    # (if several graph IDs share a vector ID, the last one wins, as before).
    vid_to_global = {vid: gid for gid, vid in vdb.global_to_vids_mapping.items()}
    g_ids = np.array([vid_to_global[vid] for vid in record_ids])

    # Not simply ~present: NaN measurements belong to neither group.
    present = gene_values > 0
    absent = gene_values <= 0

    etypes = gdb.graph.canonical_etypes

    # Hop 1: gene/protein -> gene/protein edges only.
    gp_to_gp = [c for c in etypes if c[0] == "geneprotein" and c[2] == "geneprotein"]
    first_hop = gdb.get_k_hop_neighbors(g_ids[present], 1, edge_types=gp_to_gp)
    flat_ids = np.array([node for group in first_hop.values() for node in group])

    start_ids = np.setdiff1d(flat_ids, g_ids[absent])
    print(f"Genes included in graph: {start_ids.shape}")

    # Hop 2: every etype touching a gene/protein node, minus the protein-protein
    # relation already used in hop 1. Filtering (instead of list.remove) does
    # not raise when that etype is missing from the graph.
    excluded = ('geneprotein', 'protein_protein', 'geneprotein')
    gp_any = [
        c for c in etypes
        if (c[0] == "geneprotein" or c[2] == "geneprotein") and c != excluded
    ]
    second_hop = gdb.get_k_hop_neighbors(start_ids, 1, edge_types=gp_any)
    flat_ids_final = np.array([node for group in second_hop.values() for node in group])
    print(f"Total nodes to embed: {flat_ids_final.shape}")

    return flat_ids_final

In [None]:
# Encode the first patient by position (works for any index, not only a RangeIndex).
final_ids = encode_patient(vdb, gdb, record_ids, gene_data.iloc[0])

In [None]:
from collections import defaultdict

# Translate each global graph ID into its (node type, local ID) pair.
local_ids = [gdb.global_to_local_mapping[global_id] for global_id in final_ids]

# Bucket the local IDs by node type, keeping first-seen type order.
grouped = defaultdict(list)
for ntype, local_id in local_ids:
    grouped[ntype].append(local_id)

# Plain dict: {node type: [local IDs]}.
entity_dict = dict(grouped)

print(entity_dict)

In [None]:
import dgl
import torch
import numpy as np
from collections import defaultdict
from typing import Dict, List, Tuple, Set
        
def create_subgraph_from_node_lists(graph, node_lists: Dict[str, List[int]]) -> dgl.DGLGraph:
    """
    Build the heterogeneous subgraph induced by per-type lists of node IDs.

    Nodes of each type are renumbered 0..k-1 in ascending order of their
    original ID. An edge is kept only when both of its endpoints are requested.
    Existing node and edge features are copied for the kept nodes and edges.

    Args:
        graph: source DGL heterograph.
        node_lists: mapping node type -> node IDs (duplicates are ignored),
                    e.g. {'user': [0, 1, 5], 'item': [2, 7, 9], 'category': [1, 3]}.
                    Types absent from ``graph.ntypes`` are ignored.

    Returns:
        DGL heterograph with every node and edge type of ``graph``. Types with
        no requested nodes get zero nodes and zero edges.
    """
    # De-duplicated requested IDs per type (the feature helpers take these sets).
    target_nodes = {ntype: set(nodes) for ntype, nodes in node_lists.items()}

    old_to_new_mapping = {}
    new_node_counts = {}
    # Sorted original IDs per type: the index in this array is the new node ID.
    sorted_old_ids = {}

    for ntype in graph.ntypes:
        old_ids = sorted(target_nodes.get(ntype, ()))
        old_to_new_mapping[ntype] = {old: new for new, old in enumerate(old_ids)}
        new_node_counts[ntype] = len(old_ids)
        sorted_old_ids[ntype] = np.asarray(old_ids, dtype=np.int64)

    subgraph_edges = {}

    for canonical_etype in graph.canonical_etypes:
        src_type, _, dst_type = canonical_etype
        src_ids = sorted_old_ids[src_type]
        dst_ids = sorted_old_ids[dst_type]

        # No requested nodes on one side -> no edges of this type survive.
        if src_ids.size == 0 or dst_ids.size == 0:
            subgraph_edges[canonical_etype] = (
                torch.empty(0, dtype=torch.int64),
                torch.empty(0, dtype=torch.int64),
            )
            continue

        src_nodes, dst_nodes = graph.edges(etype=canonical_etype)
        # .cpu() so this also works when the graph lives on a GPU.
        src_np = src_nodes.cpu().numpy()
        dst_np = dst_nodes.cpu().numpy()

        # Vectorised filter: same edges, in the same (original edge-ID) order,
        # as checking each edge against the sets one by one.
        keep = np.isin(src_np, src_ids) & np.isin(dst_np, dst_ids)
        # searchsorted on the sorted ID arrays maps an original ID to its new ID.
        new_src = np.searchsorted(src_ids, src_np[keep])
        new_dst = np.searchsorted(dst_ids, dst_np[keep])
        subgraph_edges[canonical_etype] = (
            torch.from_numpy(new_src).long(),
            torch.from_numpy(new_dst).long(),
        )

    subgraph = dgl.heterograph(subgraph_edges, num_nodes_dict=new_node_counts)

    _copy_node_features(graph, subgraph, old_to_new_mapping, target_nodes)
    _copy_edge_features(graph, subgraph, subgraph_edges, old_to_new_mapping, target_nodes)

    return subgraph

def _copy_node_features(graph: dgl.DGLGraph, subgraph: dgl.DGLGraph, 
                        old_to_new_mapping: Dict[str, Dict[int, int]],
                        target_nodes: Dict[str, Set[int]]):
    """
    Copy every node feature of ``graph`` onto the matching nodes of ``subgraph``.

    Rows are gathered in ascending original-ID order, matching how the subgraph
    numbers its nodes. Node types that are missing from ``target_nodes``, or
    have an empty set there, are skipped. ``old_to_new_mapping`` is not used.
    """
    for ntype in graph.ntypes:
        wanted = target_nodes.get(ntype)
        if not wanted:
            continue

        features = graph.nodes[ntype].data
        rows = sorted(wanted)

        for name, values in features.items():
            subgraph.nodes[ntype].data[name] = values[rows]

def _copy_edge_features(graph: dgl.DGLGraph, subgraph: dgl.DGLGraph,
                        subgraph_edges: Dict[Tuple[str, str, str], Tuple[List[int], List[int]]],
                        old_to_new_mapping: Dict[str, Dict[int, int]],
                        target_nodes: Dict[str, Set[int]]):
    """
    Copy edge features of the kept edges from ``graph`` into ``subgraph``.

    An edge is kept when both of its endpoints are in ``target_nodes``. Kept
    edges are taken in ascending original edge-ID order, which is the order in
    which ``create_subgraph_from_node_lists`` adds them to the subgraph.

    Args:
        graph: source DGL heterograph.
        subgraph: induced subgraph that receives the features.
        subgraph_edges: per canonical etype, the (src, dst) edges placed in
            ``subgraph``; etypes with no edges are skipped.
        old_to_new_mapping: accepted for signature compatibility; not used.
        target_nodes: requested original node IDs per node type.
    """
    for canonical_etype in graph.canonical_etypes:
        src_type, _, dst_type = canonical_etype

        # Skip if no edges of this type made it into the subgraph.
        if len(subgraph_edges[canonical_etype][0]) == 0:
            continue

        edge_data = graph.edges[canonical_etype].data
        if not edge_data:
            continue

        orig_src, orig_dst = graph.edges(etype=canonical_etype)
        src_targets = np.fromiter(target_nodes[src_type], dtype=np.int64)
        dst_targets = np.fromiter(target_nodes[dst_type], dtype=np.int64)

        # Vectorised replacement for the per-edge membership loop; .cpu() keeps
        # this working for graphs stored on a GPU.
        keep = (np.isin(orig_src.cpu().numpy(), src_targets)
                & np.isin(orig_dst.cpu().numpy(), dst_targets))
        edge_indices = np.flatnonzero(keep)
        if edge_indices.size == 0:
            continue

        for feat_name, feat_tensor in edge_data.items():
            index = torch.as_tensor(edge_indices, device=feat_tensor.device)
            subgraph.edges[canonical_etype].data[feat_name] = feat_tensor[index]


In [None]:
# Induced subgraph over the patient's nodes, grouped by node type.
subgraph = create_subgraph_from_node_lists(graph=gdb.graph, node_lists=entity_dict)

In [None]:
from DeepGraphDB.gnns.patientTuning import PatientSpecificFineTuner

BASE_MODEL_PATH = 'model.pt'
SAVE_DIR = '/home/cc/PHD/dglframework/DeepKG/embeddings/patient_specific_models'

# Edge types with a gene/protein node at either end. A canonical etype is
# (src_type, relation, dst_type), so the destination node type is index 2;
# index 1 is the relation name and must not be compared with a node type.
ctypes = [ctype for ctype in subgraph.canonical_etypes if ctype[0] == "geneprotein" or ctype[2] == "geneprotein"]

node_types = subgraph.ntypes
edge_types = ctypes
target_etypes = ctypes

# Initialize fine-tuner with pretrained model
fine_tuner = PatientSpecificFineTuner(
    base_model_path=BASE_MODEL_PATH,
    node_types=node_types,
    edge_types=edge_types,
    device='cuda' if torch.cuda.is_available() else 'cpu'
)

# Example patient data (you would load your actual patient subgraphs)
patient_graphs = {
    'patient_1': subgraph
}

# Fine-tune models for all patients
results = fine_tuner.batch_fine_tune_patients(
    patient_data_dict=patient_graphs,
    target_etypes=target_etypes,
    save_dir=SAVE_DIR
)

# Print results
for patient_id, result in results.items():
    print(f"Patient {patient_id}: AUC = {result['best_auc']:.4f}")

In [None]:
patient_embs = []
labels = []

for i, row in tqdm(gene_data.iterrows(), total=len(gene_data)):
    print(f"--- Patient {i} ---")
    # encode_patient returns the patient's graph node IDs (an array),
    # not an (embedding, label) pair, so the embedding is built here.
    node_ids = encode_patient(vdb, gdb, record_ids, row)
    if len(node_ids) == 0:
        print("No nodes to embed; skipping patient")
        continue

    # Graph IDs -> vector-store IDs, then average the stored 'graph' embeddings.
    # NOTE(review): assumes each record's 'embeddings' dict has a 'graph' entry
    # (as in the previously commented-out code) — confirm against ChromaFramework.
    vector_ids = [vdb.global_to_vids_mapping[g_id] for g_id in node_ids]
    node_records = vdb.read_record(vector_ids, include_embeddings=True)
    p_emb = torch.tensor(
        np.array([rec['embeddings']['graph'] for rec in node_records])
    ).mean(dim=0)

    # Patients whose mean embedding contains NaN are left out.
    if not torch.any(p_emb.isnan()):
        patient_embs.append(p_emb)
        labels.append(int(row['pfs']))

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import torch

def visualize_embeddings_tsne(embeddings, labels, perplexity=30, n_iter=1000, random_state=42):
    """
    Project embeddings to 2-D with t-SNE and scatter-plot them coloured by label.

    Args:
        embeddings: list of tensors / numpy arrays, or a single tensor / array,
            with one embedding per sample. Embeddings with more than one axis
            are flattened.
        labels: one label per embedding. 0 and 1 are drawn red / blue as
            'Label 0' / 'Label 1'; any other value uses the default colour cycle.
        perplexity: t-SNE perplexity (default: 30). Clamped to n_samples - 1
            because scikit-learn requires perplexity < n_samples.
        n_iter: maximum number of optimisation iterations (default: 1000).
            Passed as ``max_iter`` on scikit-learn versions that accept it
            (``n_iter`` was deprecated in 1.5 and later removed).
        random_state: random state for reproducibility (default: 42).

    Returns:
        (embeddings_2d, tsne): the (n_samples, 2) projection and the fitted TSNE.

    Raises:
        ValueError: if there are no embeddings, fewer than 2 samples, or the
            number of labels differs from the number of embeddings.
    """
    import inspect

    # Convert the input to a single numpy array.
    if isinstance(embeddings, list):
        if not embeddings:
            raise ValueError("embeddings is empty; nothing to visualize")
        if torch.is_tensor(embeddings[0]):
            embs_np = torch.stack(embeddings).detach().cpu().numpy()
        else:
            embs_np = np.array(embeddings)
    elif torch.is_tensor(embeddings):
        embs_np = embeddings.detach().cpu().numpy()
    else:
        embs_np = np.asarray(embeddings)

    # Flatten each embedding to one row.
    if embs_np.ndim > 2:
        embs_np = embs_np.reshape(embs_np.shape[0], -1)

    labels_np = np.asarray(labels)
    n_samples = embs_np.shape[0]
    if labels_np.shape[0] != n_samples:
        raise ValueError(
            f"got {labels_np.shape[0]} labels for {n_samples} embeddings"
        )
    if n_samples < 2:
        raise ValueError("t-SNE needs at least 2 samples")

    print(f"Embedding shape: {embs_np.shape}")
    print(f"Labels shape: {labels_np.shape}")
    print(f"Unique labels: {np.unique(labels_np)}")

    # scikit-learn rejects perplexity >= n_samples.
    effective_perplexity = min(perplexity, n_samples - 1)
    if effective_perplexity != perplexity:
        print(f"Perplexity reduced from {perplexity} to {effective_perplexity} (n_samples={n_samples})")

    # Use whichever iteration keyword this scikit-learn version accepts.
    tsne_params = inspect.signature(TSNE.__init__).parameters
    iter_kwarg = "max_iter" if "max_iter" in tsne_params else "n_iter"

    print("Applying t-SNE...")
    tsne = TSNE(
        n_components=2,
        init='pca',
        perplexity=effective_perplexity,
        random_state=random_state,
        **{iter_kwarg: n_iter},
    )
    embeddings_2d = tsne.fit_transform(embs_np)

    fig, ax = plt.subplots(figsize=(10, 8))

    # Fixed styling for the binary labels; other labels fall back to the cycle.
    known_styles = {0: ('red', 'Label 0'), 1: ('blue', 'Label 1')}
    for label in np.unique(labels_np):
        mask = labels_np == label
        color, text = known_styles.get(label, (None, f'Label {label}'))
        ax.scatter(embeddings_2d[mask, 0], embeddings_2d[mask, 1],
                   c=color, label=text, alpha=0.7, s=50)

    ax.set_title('t-SNE Visualization of Embeddings', fontsize=16)
    ax.set_xlabel('t-SNE Component 1', fontsize=12)
    ax.set_ylabel('t-SNE Component 2', fontsize=12)
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    plt.show()

    return embeddings_2d, tsne

In [None]:
# Project the patient embeddings and plot them coloured by PFS label.
embeddings_2d, tsne_model = visualize_embeddings_tsne(
    patient_embs,
    labels,
    perplexity=25,
)

print(f"Final 2D embeddings shape: {embeddings_2d.shape}")