In [None]:
import pandas as pd
import numpy as np
from ChromaVDB.chroma import ChromaFramework
from DeepGraphDB import DeepGraphDB
from tqdm.notebook import tqdm
import torch
import dgl
from collections import defaultdict
from sklearn.model_selection import train_test_split
from DeepGraphDB.gnns.heteroSAGEattn import AdvancedHeteroLinkPredictor, compute_loss

# --- Knowledge graph -------------------------------------------------------
# Open the graph database wrapper and load the serialized graph from disk.
GRAPH_PATH = "/home/cc/PHD/dglframework/DeepKG/DeepGraphDB/graphs/primekg.bin"
gdb = DeepGraphDB()
gdb.load_graph(GRAPH_PATH)

# --- Vector store ----------------------------------------------------------
# Fetch every stored record, then split the records into three parallel lists.
CHROMA_DIR = "./ChromaVDB/chroma_db"
vdb = ChromaFramework(persist_directory=CHROMA_DIR)
records = vdb.list_records()

names = [rec['name'] for rec in records]
embs = [rec['embeddings'] for rec in records]
ids = [rec['id'] for rec in records]

# --- Clinical spreadsheet --------------------------------------------------
# TODO: also try using stadio-avanzato (advanced stage), IPI and Log10hGE.
# Disease of interest: DLBCL (Diffuse Large B-cell Lymphoma) - G_id: 35884
CLINICAL_XLSX = 'data/2025_03_29.xlsx'
data = pd.read_excel(CLINICAL_XLSX)

In [None]:
# gdb.query_nodes_by_nodedata(entity_type="disease", feature_name="name", condition="in", value="B-cell lymphoma")
# nodes = gdb.get_k_hop_neighbors([35884], 2)
# print(len(nodes[1])+len(nodes[2])+1)

# node_names = set([gdb.node_data[gdb.global_to_local_mapping[idd][0]]['name'][gdb.global_to_local_mapping[idd][1]] for idd in nodes[1]] + \
#                 [gdb.node_data[gdb.global_to_local_mapping[idd][0]]['name'][gdb.global_to_local_mapping[idd][1]] for idd in nodes[2]] + [35884])

In [None]:
# Extract the neighbourhood of global node id 35884 (the DLBCL node noted in the
# setup cell). The second argument is presumably the hop count, with one
# neighbour cap per hop -- TODO confirm against DeepGraphDB.extract_subgraph.
seed_gids = [35884]
neighbor_caps = [100, 50, 25]
sub_g, neigh, map_gids = gdb.extract_subgraph(seed_gids, 3, max_neighbors_per_hop=neighbor_caps)

# Size of the extracted subgraph: node count, then edge count.
print(sub_g.number_of_nodes(), sub_g.number_of_edges(), sep="\n")

In [None]:
# Train on the extracted subgraph only.
# NOTE: this rebinds gdb.graph, so every later gdb call that reads gdb.graph
# sees the subgraph instead of the full graph.
gdb.graph = sub_g

# Input feature width per node type. Each type reads the width of its own 'x'
# tensor; a type that carries no 'x' falls back to the 'disease' width.
# Assumes 'x' tensors are (num_nodes, feat_dim) -- TODO confirm.
node_feats = sub_g.ndata['x']
in_feats = {
    ntype: node_feats.get(ntype, node_feats['disease']).shape[1]
    for ntype in gdb.graph.ntypes
}

# target_entities = ['drug', 'disease', 'geneprotein', 'effectphenotype']
target_entities = ['geneprotein', 'disease', 'pathway', 'cellular_component', 'molecular_function']

# Choose multiple edge types for prediction.
# Alternative selections, kept for reference (note the `==`: a substring test
# such as `ctype[2] in "geneprotein"` would also match partial type names):
# target_etypes = [ctype for ctype in gdb.graph.canonical_etypes if ctype[0] == "geneprotein" or ctype[2] == "geneprotein"]
# target_etypes = [ctype for ctype in gdb.graph.canonical_etypes if ctype[0] in target_entities and ctype[2] in target_entities]
target_etypes = [('disease', 'disease_protein', 'geneprotein'), ('geneprotein', 'protein_protein', 'geneprotein')]

print(f"Target edge types for prediction: {target_etypes}")

hidden_feats = 512
out_feats = 512

model = AdvancedHeteroLinkPredictor(
    node_types=gdb.graph.ntypes,  # All node types in the graph
    edge_types=gdb.graph.etypes,  # All edge types for GNN layers
    in_feats=in_feats,
    hidden_feats=hidden_feats,
    out_feats=out_feats,
    num_layers=3,
    use_attention=True,
    predictor_type='mlp',
    target_etypes=target_etypes  # Only target edge types for prediction
)

print(f"Model created with {sum(p.numel() for p in model.parameters())} parameters")

# Use the GPU when one is visible, otherwise fall back to the CPU so the cell
# still runs on a machine without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE: this rebinds `embs`, which the setup cell used for the vector-store
# embeddings; after this line it holds whatever train_model returns.
embs = gdb.train_model(model, compute_loss, target_etypes, target_entities, device, bs=1000000, num_epochs=200)

In [None]:
# Load the saved geneprotein embeddings directly onto the CPU. map_location
# remaps storages at load time, so the file can be read on a machine without
# the device it was saved from (a trailing .cpu() would only run after the
# load had already succeeded).
# NOTE(review): the path is absolute and machine-specific; presumably this file
# is written by the training cell -- confirm before relying on it.
embeddings = torch.load('/home/cc/PHD/dglframework/DeepKG/embeddings/embeddings_geneprotein.pth', map_location='cpu')

In [None]:
# Which gene panel and which measurement to analyse (alternatives commented out).
# gene_set = "LNF"
gene_set = "plasma"
# gene_measure = "MUT"
gene_measure = "VAF"

# Keep the spreadsheet columns whose name mentions both the chosen gene set and
# the chosen measure.
gene_data = data[[col for col in data.columns if gene_set in col and gene_measure in col]]

# Gene symbol = text before the first underscore of the column name.
# Sorted so the order (and every index derived from it in later cells) is the
# same on every kernel restart; the iteration order of a set of strings depends
# on per-session hash randomisation.
genes = sorted({gene.split('_')[0] for gene in gene_data.columns})

# Lower-cased names of the subgraph's geneprotein nodes, in the order of
# map_gids['geneprotein'].
genes_in_subg = [ gdb.node_data['geneprotein']['name'][gdb.global_to_local_mapping[gid][1]].lower() for gid in map_gids['geneprotein'] ]

In [None]:
# Map each subgraph gene name to its position in genes_in_subg. The first
# occurrence wins, mirroring what list.index would return.
subg_position_by_name = {}
for position, subg_name in enumerate(genes_in_subg):
    subg_position_by_name.setdefault(subg_name, position)

match_genes = []           # spreadsheet gene symbols that also exist in the subgraph
index_matched_genes = []   # for each matched gene, its position among the subgraph's geneprotein nodes
final_columns = []         # spreadsheet column name for each matched gene

for gene in genes:
    position = subg_position_by_name.get(gene.lower())
    if position is None:
        continue
    match_genes.append(gene)
    # index_matched_genes is used later to pick rows of the geneprotein
    # embedding tensor, so it must be the gene's position among the subgraph
    # nodes, not its position in the spreadsheet-derived `genes` list.
    # Assumes the embedding rows follow the order of map_gids['geneprotein']
    # (the order genes_in_subg was built in) -- TODO confirm with train_model.
    index_matched_genes.append(position)
    final_columns.append(gene+"_"+gene_set+"_"+gene_measure)

In [None]:
# Keep only the matched gene columns and attach the outcome column in a single
# chain. `assign` returns a new frame, which avoids setting a column on a slice
# of another DataFrame (the source of pandas' SettingWithCopyWarning); the
# outcome values are aligned on the row index exactly as a column assignment
# would align them.
# fillna(0) is applied to every column, including 'pfs'.
# gene_data = gene_data.dropna()
gene_data = (
    gene_data[final_columns]
    .assign(pfs=data['PFS_Cens_updated'])
    .fillna(0)
)

In [None]:
# Embedding rows selected for the matched genes.
gene_embeddings = embeddings[index_matched_genes]

# Everything except the outcome column becomes the per-patient input matrix.
mutants = gene_data.drop('pfs', axis=1)
patient_mutations = torch.tensor(mutants.values)

# Outcome column as a plain array, one entry per row of gene_data.
labels = gene_data.loc[:, 'pfs'].values

In [None]:
import torch.nn as nn

# Weight gene embeddings by mutation status.
# The weight is a single scalar applied to every mutation value. The layer is
# never trained in this cell, so its value is whatever the initialiser draws.
# To make patient_repr reproducible the scalar is re-drawn from a private,
# fixed-seed generator using the same U(-1, 1) range that nn.Linear's default
# initialiser uses for fan_in = 1.
gene_weight = nn.Linear(1, 1, bias=False, dtype=torch.float64)
weight_rng = torch.Generator().manual_seed(42)
with torch.no_grad():
    gene_weight.weight.copy_(
        torch.empty(1, 1, dtype=torch.float64).uniform_(-1.0, 1.0, generator=weight_rng)
    )

# Add a trailing axis so the linear layer sees one feature per entry.
mutations_expanded = patient_mutations.unsqueeze(-1)
weighted_mutations = gene_weight(mutations_expanded)

# Add a leading axis so the gene embeddings broadcast against every patient.
gene_emb_expanded = gene_embeddings.unsqueeze(0)
weighted_gene_embs = weighted_mutations * gene_emb_expanded

# Sum across genes
patient_repr = weighted_gene_embs.sum(dim=1).detach()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import torch

def _prepare_tsne_inputs(embeddings, labels, strings):
    """Convert the raw inputs to NumPy arrays of equal length (2-D embeddings)."""
    if isinstance(embeddings, list):
        # An empty list has no first element to inspect; let NumPy handle it.
        if embeddings and torch.is_tensor(embeddings[0]):
            embs_np = torch.stack(embeddings).detach().cpu().numpy()
        else:
            embs_np = np.array(embeddings)
    elif torch.is_tensor(embeddings):
        embs_np = embeddings.detach().cpu().numpy()
    else:
        embs_np = np.asarray(embeddings)

    # Flatten each embedding so the array is (n_samples, n_features).
    if embs_np.ndim > 2:
        embs_np = embs_np.reshape(embs_np.shape[0], -1)

    labels_np = np.array(labels)
    strings_np = np.array(strings)

    if len(embs_np) != len(labels_np) or len(embs_np) != len(strings_np):
        raise ValueError("Embeddings, labels, and strings must have the same length")

    return embs_np, labels_np, strings_np


def _fit_tsne_2d(embs_np, perplexity, n_iter, random_state):
    """Fit a 2-component t-SNE and return (projected points, fitted model)."""
    common = dict(n_components=2, init='pca', perplexity=perplexity, random_state=random_state)
    try:
        # Newer scikit-learn releases name the iteration budget `max_iter`.
        tsne = TSNE(max_iter=n_iter, **common)
    except TypeError:
        # Older releases only know `n_iter`.
        tsne = TSNE(n_iter=n_iter, **common)
    return tsne.fit_transform(embs_np), tsne


def _scatter_by_binary_label(ax, points_2d, labels_np):
    """Draw label 0 in red, label 1 in blue and any other label value in gray."""
    for value, color, legend_text in ((0, 'red', 'Label 0'), (1, 'blue', 'Label 1')):
        mask = labels_np == value
        if np.any(mask):
            ax.scatter(points_2d[mask, 0], points_2d[mask, 1],
                       c=color, label=legend_text, alpha=0.7, s=50)

    # Points whose label is neither 0 nor 1 are still drawn instead of vanishing.
    other = ~np.isin(labels_np, [0, 1])
    if np.any(other):
        ax.scatter(points_2d[other, 0], points_2d[other, 1],
                   c='gray', label='Other', alpha=0.7, s=50)


def visualize_embeddings_tsne(embeddings, labels, strings, perplexity=30, n_iter=1000, random_state=42):
    """
    Visualize tensor embeddings using t-SNE with binary labels and interactive string display

    Args:
        embeddings: List of tensors or numpy arrays, or a single tensor/array
        labels: List of binary labels (0s and 1s)
        strings: List of strings corresponding to each embedding
        perplexity: t-SNE perplexity parameter (default: 30)
        n_iter: Number of iterations for t-SNE (default: 1000)
        random_state: Random state for reproducibility (default: 42)

    Returns:
        Tuple (embeddings_2d, tsne, fig): the projected points, the fitted TSNE
        object and the matplotlib figure.

    Raises:
        ValueError: if embeddings, labels and strings differ in length.
    """
    # Validate before touching the plotting backend.
    embs_np, labels_np, strings_np = _prepare_tsne_inputs(embeddings, labels, strings)

    # Enable an interactive backend when running inside a Jupyter kernel.
    try:
        shell = get_ipython()
    except NameError:
        # Plain Python interpreter: IPython's get_ipython is not defined.
        shell = None
    if shell is not None and 'ipykernel' in str(type(shell)):
        print("Notebook detected. Setting up interactive backend...")
        try:
            import matplotlib
            matplotlib.use('widget')  # or 'nbagg'
            plt.ioff()  # Turn off interactive mode temporarily
        except Exception as exc:
            # Best effort: keep whatever backend is active, but say why.
            print(f"Could not enable the interactive backend ({exc}); using the current one.")

    print(f"Embedding shape: {embs_np.shape}")
    print(f"Labels shape: {labels_np.shape}")
    print(f"Strings count: {len(strings_np)}")
    print(f"Unique labels: {np.unique(labels_np)}")

    # Apply t-SNE
    print("Applying t-SNE...")
    embeddings_2d, tsne = _fit_tsne_2d(embs_np, perplexity, n_iter, random_state)

    # Create the interactive plot
    fig, ax = plt.subplots(figsize=(12, 8))
    _scatter_by_binary_label(ax, embeddings_2d, labels_np)

    ax.set_title('Interactive t-SNE Visualization\n(Click on points to see text below)', fontsize=16)
    ax.set_xlabel('t-SNE Component 1', fontsize=12)
    ax.set_ylabel('t-SNE Component 2', fontsize=12)
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Text area below the plot; its content is replaced on every click.
    text_display = fig.text(0.1, 0.02, 'Click on a point to see its text here...',
                            fontsize=10, ha='left', va='bottom',
                            bbox=dict(boxstyle='round,pad=0.5', facecolor='lightblue', alpha=0.8),
                            wrap=True)

    # Marker of the currently selected point; replaced on each click.
    highlight = None

    def on_click(event):
        """Handle mouse click events"""
        nonlocal highlight

        if event.inaxes != ax:
            return

        click_x, click_y = event.xdata, event.ydata
        if click_x is None or click_y is None:
            return

        # Find the closest point
        distances = np.sqrt((embeddings_2d[:, 0] - click_x)**2 +
                            (embeddings_2d[:, 1] - click_y)**2)
        closest_idx = np.argmin(distances)

        # str() so that non-string entries (e.g. numeric ids) can be sliced too.
        selected_string = str(strings_np[closest_idx])
        selected_label = labels_np[closest_idx]

        # Wrap text for better display
        wrapped_text = '\n'.join([selected_string[i:i+80] for i in range(0, len(selected_string), 80)])
        display_text = f"Selected Point {closest_idx} (Label {selected_label}):\n{wrapped_text}"

        text_display.set_text(display_text)

        # Drop the previous highlight so markers do not accumulate on the axes.
        if highlight is not None:
            highlight.remove()
        highlight = ax.scatter(embeddings_2d[closest_idx, 0], embeddings_2d[closest_idx, 1],
                               c='yellow', s=150, alpha=0.8, marker='o',
                               edgecolors='black', linewidth=2, zorder=5)

        fig.canvas.draw()

        # Also print to console as backup
        print(f"\nSelected Point {closest_idx} (Label {selected_label}):")
        print(f"Text: {selected_string}")
        print("-" * 50)

    # Connect the click event
    cid = fig.canvas.mpl_connect('button_press_event', on_click)

    # Enable interactive mode
    plt.ion()
    plt.tight_layout()

    # Show with proper backend
    plt.show()

    # Store connection ID for cleanup if needed
    fig._click_connection = cid

    return embeddings_2d, tsne, fig

# Alternative function for notebooks: a static, numbered plot plus a text lookup
def visualize_embeddings_tsne_notebook(embeddings, labels, strings, perplexity=30, n_iter=1000, random_state=42):
    """
    Notebook-friendly version that prints strings when you call show_point_text()

    Args:
        embeddings: List of tensors or numpy arrays, or a single tensor/array
        labels: List of binary labels (0s and 1s)
        strings: List of strings corresponding to each embedding
        perplexity: t-SNE perplexity parameter (default: 30)
        n_iter: Number of iterations for t-SNE (default: 1000)
        random_state: Random state for reproducibility (default: 42)

    Returns:
        An object with attributes embeddings_2d, tsne_model, strings, labels and
        show_text(point_idx), which prints the label and string of one point.

    Raises:
        ValueError: if embeddings, labels and strings differ in length.
    """
    embs_np, labels_np, strings_np = _prepare_tsne_inputs(embeddings, labels, strings)

    print(f"Embedding shape: {embs_np.shape}")
    print(f"Labels shape: {labels_np.shape}")
    print(f"Strings count: {len(strings_np)}")

    # Apply t-SNE
    print("Applying t-SNE...")
    embeddings_2d, tsne = _fit_tsne_2d(embs_np, perplexity, n_iter, random_state)

    # Create static plot with numbered points
    fig, ax = plt.subplots(figsize=(12, 8))
    _scatter_by_binary_label(ax, embeddings_2d, labels_np)

    # Add point numbers as text annotations
    for i, (x, y) in enumerate(embeddings_2d):
        ax.annotate(str(i), (x, y), xytext=(2, 2), textcoords='offset points',
                    fontsize=6, alpha=0.7)

    ax.set_title('t-SNE Visualization (Points are numbered)\nUse show_point_text(point_number) to see text', fontsize=14)
    ax.set_xlabel('t-SNE Component 1')
    ax.set_ylabel('t-SNE Component 2')
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    plt.show()

    # Lookup helper exposed on the returned object as `show_text`.
    def show_point_text(point_idx):
        if 0 <= point_idx < len(strings_np):
            print(f"\nPoint {point_idx} (Label {labels_np[point_idx]}):")
            print(f"Text: {strings_np[point_idx]}")
            print("-" * 50)
        else:
            print(f"Point {point_idx} not found. Valid range: 0-{len(strings_np)-1}")

    # Create a results object to return
    class TSNEResults:
        def __init__(self, embeddings_2d, tsne_model, strings, labels):
            self.embeddings_2d = embeddings_2d
            self.tsne_model = tsne_model
            self.strings = strings
            self.labels = labels
            self.show_text = show_point_text

    return TSNEResults(embeddings_2d, tsne, strings_np, labels_np)

In [None]:
# Name every point "pat_<row label of data>" so a number on the plot can be
# looked up with results.show_text(<number>).
patient_names = [f"pat_{row_label!s}" for row_label in data.index]
results = visualize_embeddings_tsne_notebook(
    patient_repr,
    labels,
    patient_names,
    perplexity=15,
)

In [None]:
# Hand-picked row labels of `data`.
# NOTE(review): presumably point numbers read off the numbered t-SNE plot;
# nothing in this cell derives them -- confirm where they came from.
indices = [
    28, 125, 144, 9, 99, 80, 142, 68, 114, 118,
    160, 14, 151, 76, 70, 86, 152, 84, 115, 117,
    165, 112, 72, 79, 85, 22, 131, 66, 129, 43,
    120, 164, 78, 38,
]

# Label-based selection of those rows, all columns kept.
cluster_pat = data.loc[indices, :]

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import seaborn as sns

class EmbeddingClusterer:
    """Cluster embeddings with KMeans, project them to 2-D and compare the
    clusters against binary labels."""

    def __init__(self, embeddings, labels, strings):
        """
        Initialize the clusterer with embeddings, binary labels, and associated strings.

        Args:
            embeddings: numpy array of shape (n_samples, n_features)
            labels: numpy array of binary labels (0 or 1)
            strings: list of strings associated with each embedding

        Raises:
            ValueError: if embeddings, labels and strings differ in length.
        """
        self.embeddings = np.array(embeddings)
        self.labels = np.array(labels)
        self.strings = strings
        self.clusters = None
        self.reduced_embeddings = None

        # Fail early: every later method indexes the three inputs in parallel.
        n_samples = len(self.embeddings)
        if len(self.labels) != n_samples or len(self.strings) != n_samples:
            raise ValueError("Embeddings, labels, and strings must have the same length")

    def perform_clustering(self, n_clusters=5, method='kmeans'):
        """
        Perform clustering on the embeddings.

        Args:
            n_clusters: number of clusters
            method: clustering method ('kmeans' supported)

        Returns:
            Array with one cluster id per sample (also stored in self.clusters).

        Raises:
            ValueError: if method is not 'kmeans'.
        """
        if method != 'kmeans':
            raise ValueError("Only 'kmeans' clustering is currently supported")

        clusterer = KMeans(n_clusters=n_clusters, random_state=42)
        self.clusters = clusterer.fit_predict(self.embeddings)
        return self.clusters

    def reduce_dimensions(self, method='pca', n_components=2):
        """
        Reduce dimensionality for visualization.

        Args:
            method: 'pca' or 'tsne'
            n_components: number of dimensions (2 for visualization)

        Returns:
            Array of reduced embeddings (also stored in self.reduced_embeddings).

        Raises:
            ValueError: if method is neither 'pca' nor 'tsne'.
        """
        if method == 'pca':
            reducer = PCA(n_components=n_components, random_state=42)
        elif method == 'tsne':
            # Cap the perplexity below the sample count.
            reducer = TSNE(n_components=n_components, random_state=42,
                           perplexity=min(25, len(self.embeddings)-1))
        else:
            raise ValueError("Supported methods: 'pca', 'tsne'")

        self.reduced_embeddings = reducer.fit_transform(self.embeddings)
        return self.reduced_embeddings

    def plot_clusters(self, figsize=(12, 5), show_strings=False):
        """
        Create visualizations colored by both clusters and binary labels.

        Args:
            figsize: figure size tuple
            show_strings: whether to annotate points with strings (only for small datasets)

        Raises:
            ValueError: if perform_clustering() or reduce_dimensions() has not
                been run yet.
        """
        if self.clusters is None:
            raise ValueError("Please run perform_clustering() first")
        if self.reduced_embeddings is None:
            raise ValueError("Please run reduce_dimensions() first")

        from matplotlib.patches import Patch

        xs = self.reduced_embeddings[:, 0]
        ys = self.reduced_embeddings[:, 1]

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

        # Plot 1: Colored by clusters
        scatter1 = ax1.scatter(xs, ys, c=self.clusters, cmap='tab10', alpha=0.7, s=50)
        ax1.set_title('Embeddings Colored by Clusters')
        ax1.set_xlabel('Component 1')
        ax1.set_ylabel('Component 2')
        plt.colorbar(scatter1, ax=ax1, label='Cluster')

        # Plot 2: Colored by binary labels (anything that is not 0 is drawn blue)
        colors = ['red' if label == 0 else 'blue' for label in self.labels]
        ax2.scatter(xs, ys, c=colors, alpha=0.7, s=50)
        ax2.set_title('Embeddings Colored by Binary Labels')
        ax2.set_xlabel('Component 1')
        ax2.set_ylabel('Component 2')

        # Add legend for binary labels
        legend_elements = [Patch(facecolor='red', label='Label 0'),
                           Patch(facecolor='blue', label='Label 1')]
        ax2.legend(handles=legend_elements)

        # Optionally add string annotations (only for small datasets)
        if show_strings and len(self.strings) <= 50:
            for i, txt in enumerate(self.strings):
                point = (self.reduced_embeddings[i, 0], self.reduced_embeddings[i, 1])
                ax1.annotate(txt, point, fontsize=8, alpha=0.7)
                ax2.annotate(txt, point, fontsize=8, alpha=0.7)

        plt.tight_layout()
        plt.show()

    def analyze_clusters(self):
        """
        Analyze the relationship between clusters and binary labels.

        Prints, for every cluster, its size, the count and share of label 0 and
        label 1, and up to three of its strings.

        Raises:
            ValueError: if perform_clustering() has not been run yet.
        """
        if self.clusters is None:
            raise ValueError("Please run perform_clustering() first")

        print("Cluster Analysis:")
        print("-" * 50)

        for cluster_id in np.unique(self.clusters):
            mask = self.clusters == cluster_id
            cluster_labels = self.labels[mask]
            cluster_strings = [self.strings[i] for i in np.flatnonzero(mask)]

            label_0_count = np.sum(cluster_labels == 0)
            label_1_count = np.sum(cluster_labels == 1)
            total_count = len(cluster_labels)

            print(f"Cluster {cluster_id}:")
            print(f"  Total points: {total_count}")
            print(f"  Label 0: {label_0_count} ({label_0_count/total_count*100:.1f}%)")
            print(f"  Label 1: {label_1_count} ({label_1_count/total_count*100:.1f}%)")
            print(f"  Sample strings: {cluster_strings[:3]}")  # Show first 3 strings
            print()

In [None]:
# One display name per row label of `data`, used to identify points in the report.
patient_names = ["pat_" + str(row_label) for row_label in data.index]
clusterer = EmbeddingClusterer(patient_repr.numpy(), labels, patient_names)

# Step 1 - group the patient representations into three clusters.
clusters = clusterer.perform_clustering(n_clusters=3)

# Step 2 - project to 2-D with t-SNE for plotting.
reduced_emb = clusterer.reduce_dimensions(method='tsne')

# Step 3 - side-by-side scatter plots (pass show_strings=True for small datasets).
clusterer.plot_clusters(show_strings=False)

# Step 4 - printed per-cluster breakdown of the binary labels.
clusterer.analyze_clusters()