In [None]:
import pandas as pd
import numpy as np
from ChromaVDB.chroma import ChromaFramework
from DeepGraphDB import DeepGraphDB
from tqdm.notebook import tqdm
import torch

# Input locations (absolute paths are machine-specific; adjust when running elsewhere).
GRAPH_PATH = "/home/cc/PHD/dglframework/DeepKG/DeepGraphDB/graphs/primekg.bin"
CHROMA_DIR = "./ChromaVDB/chroma_db"
PATIENT_MATRIX_CSV = "/home/cc/PHD/dglframework/cptac/patient_gene_matrix_brca.csv"

# Diagnoses kept for the comparison below.
KEPT_DIAGNOSES = ['Infiltrating duct carcinoma, NOS', 'Lobular carcinoma, NOS']

gdb = DeepGraphDB()
gdb.load_graph(GRAPH_PATH)

vdb = ChromaFramework(persist_directory=CHROMA_DIR)
records = vdb.list_records()

# Parallel lists: position i of each list describes records[i].
names = [rec['name'] for rec in records]
embs = [rec['embeddings'] for rec in records]
ids = [rec['id'] for rec in records]

# Keep only 'Breast, NOS' resections/biopsies with one of the kept diagnoses.
data = pd.read_csv(PATIENT_MATRIX_CSV, low_memory=False)
in_breast = data['site_of_resection_or_biopsy'] == 'Breast, NOS'
has_kept_diagnosis = data['primary_diagnosis'].isin(KEPT_DIAGNOSES)
data = data[in_breast & has_kept_diagnosis]

In [None]:
def find_intersection_numpy(genes, names, suffix="_mutated", verbose=True):
    """Find which record names also appear in ``genes`` and build their matrix column names.

    Args:
        genes: Gene symbols that have a mutation column in the patient matrix.
        names: Candidate record names; returned indices are positions in this sequence.
        suffix: Appended to each matched name to rebuild the patient-matrix column name
            (default ``"_mutated"``).
        verbose: If True, print how many names matched.

    Returns:
        ``(columns, indices)``, two lists of equal length in the order of ``names``.
        ``columns[k] == names[indices[k]] + suffix``. Duplicated names in ``names`` yield
        duplicated entries, and empty inputs yield ``([], [])``.
    """
    gene_set = set(genes)  # O(1) membership instead of an array-wide isin
    matches = [(i, name) for i, name in enumerate(names) if name in gene_set]

    indices = [i for i, _ in matches]
    columns = [f"{name}{suffix}" for _, name in matches]

    if verbose:
        print(f"{len(indices)} record names have a matching mutation column")

    return columns, indices

In [None]:
# Gene symbols that have a "<GENE>_mutated" column in the patient matrix. Stripping the exact
# suffix (rather than splitting on the first "_") keeps symbols that contain "_" intact, so the
# column names rebuilt as name + "_mutated" by find_intersection_numpy exist in `data`.
MUTATION_SUFFIX = "_mutated"
genes = [col[: -len(MUTATION_SUFFIX)] for col in data.columns if col.endswith(MUTATION_SUFFIX)]

final_columns, target_indices = find_intersection_numpy(genes, names)

In [None]:
# Split off the label columns, then keep only the mutation columns whose gene has a
# matching record in the vector DB.
labels_df = data[['vital_status', 'primary_diagnosis']]
data = data.loc[:, final_columns]

# Drop rows with no non-zero value in any kept mutation column; keep labels aligned.
has_any_mutation = data.ne(0).any(axis=1)
data = data.loc[has_any_mutation]
labels_df = labels_df.loc[data.index]

len(data)

In [None]:
# Restrict the candidate pool to rows whose vital_status is "Alive".
# (A previously tried alternative restricted primary_diagnosis to
#  "Infiltrating duct carcinoma, NOS" instead; it is not applied.)
labels_df = labels_df.loc[labels_df['vital_status'].eq("Alive")]

In [None]:
 # Label table that the per-diagnosis sampling step draws patients from.
labels_df

In [None]:
def encode_patient(vdb, gdb, record_ids, gene_data, inv, ctypes, ctypes2):
    """Embed one patient as the mean graph embedding of a two-step graph neighbourhood.

    Steps:
      1. Map ``record_ids`` to graph node ids via ``inv``. Genes with a positive value in
         ``gene_data`` are the seeds.
      2. Take the 1-hop neighbourhood of the seeds over ``ctypes`` and remove the genes whose
         value is <= 0.
      3. Take the 1-hop neighbourhood of that set over ``ctypes2``.
      4. Average the ``['embeddings']['graph']`` vectors of all resulting nodes read from ``vdb``.

    Args:
        vdb: Vector store exposing ``global_to_vids_mapping`` and ``read_record``.
        gdb: Graph store exposing ``get_k_hop_neighbors``.
        record_ids: Vector-DB ids of the genes, positionally aligned with ``gene_data``.
        gene_data: Per-gene values for this patient (``.values`` is used as a mask).
        inv: Mapping from vector-DB id to graph node id.
        ctypes: Edge types for the first hop.
        ctypes2: Edge types for the second hop.

    Returns:
        ``(embedding, label)``. ``embedding`` is a tensor averaged over the retrieved nodes
        (NaN when no nodes are retrieved). ``label`` is always 0; callers track real labels.
    """
    label = 0

    graph_ids = np.array([inv[rid] for rid in record_ids])
    values = gene_data.values
    # Both comparisons are kept explicit: NaN values are neither seeds nor excluded.
    seeds = graph_ids[values > 0]
    not_mutated = graph_ids[values <= 0]

    first_hop = np.array(list(gdb.get_k_hop_neighbors(seeds, 1, edge_types=ctypes, flat=True)))
    start_ids = np.setdiff1d(first_hop, not_mutated)
    print(f"Genes included in graph: {start_ids.shape}")

    second_hop = np.array(list(gdb.get_k_hop_neighbors(start_ids, 1, edge_types=ctypes2, flat=True)))
    print(f"Total nodes to embed: {second_hop.shape}")

    vector_ids = [vdb.global_to_vids_mapping[node] for node in second_hop]
    fetched = vdb.read_record(vector_ids, include_embeddings=True)
    stacked = np.array([rec['embeddings']['graph'] for rec in fetched])

    return torch.tensor(stacked).mean(dim=0), label

# Encode a balanced sample of patients from the two diagnoses as graph embeddings.
N_PER_CLASS = 50   # patients sampled per diagnosis (sample() raises if fewer are available)
SAMPLE_SEED = 42   # random_state for reproducible sampling

# Reverse lookup: vector-DB id -> graph node id (inverse of vdb.global_to_vids_mapping).
inv = {vid: gid for gid, vid in vdb.global_to_vids_mapping.items()}

# First hop: edge types whose source and destination are both 'geneprotein' nodes.
ctypes = [ct for ct in gdb.graph.canonical_etypes if ct[0] == "geneprotein" and ct[2] == "geneprotein"]

# Second hop: any edge type touching a 'geneprotein' node, except the protein_protein relation.
# Filtering (instead of list.remove) does not raise if that relation is absent.
ctypes2 = [
    ct for ct in gdb.graph.canonical_etypes
    if (ct[0] == "geneprotein" or ct[2] == "geneprotein")
    and ct != ('geneprotein', 'protein_protein', 'geneprotein')
]

sample_0 = data.loc[labels_df[labels_df['primary_diagnosis'] == 'Infiltrating duct carcinoma, NOS']
                    .sample(n=N_PER_CLASS, random_state=SAMPLE_SEED).index]
sample_1 = data.loc[labels_df[labels_df['primary_diagnosis'] == 'Lobular carcinoma, NOS']
                    .sample(n=N_PER_CLASS, random_state=SAMPLE_SEED).index]

# Combine both samples; label 0 = infiltrating duct carcinoma, label 1 = lobular carcinoma.
samples = pd.concat([sample_0, sample_1], ignore_index=True)
sample_labels = [0] * len(sample_0) + [1] * len(sample_1)

# Loop-invariant: vector-DB ids of the genes that have a mutation column.
gene_record_ids = np.array(ids)[target_indices]

patient_embs = []
labels = []      # label of each entry in patient_embs (kept aligned with it)
kept_rows = []   # samples index of each entry in patient_embs
for row, label in tqdm(zip(samples.index, sample_labels), total=len(samples)):
    print(f"--- Patient {row} ---")
    p_emb, _ = encode_patient(vdb, gdb, gene_record_ids, samples.loc[row], inv, ctypes, ctypes2)

    # A NaN embedding (e.g. the mean over zero retrieved nodes) is dropped. Its label and
    # sample row are dropped too, so patient_embs, labels and samples stay aligned.
    if torch.any(p_emb.isnan()):
        print(f"Skipping patient {row}: embedding contains NaN")
        continue
    patient_embs.append(p_emb)
    labels.append(label)
    kept_rows.append(row)

samples = samples.loc[kept_rows]

In [None]:
# Sanity check: index labels of the sampled patients (used to name points in the t-SNE plot).
samples.index

In [None]:
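# Disabled alternative (nothing below uses it): a mutation-weighted sum of gene embeddings per patient.
# import torch.nn as nn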

# patient_mutations = torch.tensor(data.values, dtype=torch.float64)
# gene_embeddings = torch.tensor(np.array(embs)[target_indices])

# # Weight gene embeddings by mutation status
# mutations_expanded = patient_mutations.unsqueeze(-1)
# gene_weight = nn.Linear(1, 1, bias=False, dtype=torch.float64)
# weighted_mutations = gene_weight(mutations_expanded)

# gene_emb_expanded = gene_embeddings.unsqueeze(0)
# weighted_gene_embs = weighted_mutations * gene_emb_expanded

# # Sum across genes
# patient_repr = weighted_gene_embs.sum(dim=1).detach()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import torch

def _to_2d_numpy(embeddings):
    """Return ``embeddings`` as a 2-D numpy array with one flattened row per embedding.

    Accepts a list of tensors, a list of arrays/sequences, a single tensor, or an array-like.

    Raises:
        ValueError: if ``embeddings`` is an empty list (there is nothing to stack).
    """
    if isinstance(embeddings, list):
        if not embeddings:
            raise ValueError("No embeddings to visualize (empty list)")
        if torch.is_tensor(embeddings[0]):
            embs_np = torch.stack(embeddings).detach().cpu().numpy()
        else:
            embs_np = np.array(embeddings)
    elif torch.is_tensor(embeddings):
        embs_np = embeddings.detach().cpu().numpy()
    else:
        embs_np = np.asarray(embeddings)

    # t-SNE expects (n_samples, n_features): flatten any trailing dimensions.
    if embs_np.ndim > 2:
        embs_np = embs_np.reshape(embs_np.shape[0], -1)
    return embs_np


def _prepare_inputs(embeddings, labels, strings):
    """Convert inputs to numpy and check that embeddings, labels and strings have equal length."""
    embs_np = _to_2d_numpy(embeddings)
    labels_np = np.array(labels)
    strings_np = np.array(strings)
    if len(embs_np) != len(labels_np) or len(embs_np) != len(strings_np):
        raise ValueError("Embeddings, labels, and strings must have the same length")
    return embs_np, labels_np, strings_np


def _fit_tsne(embs_np, perplexity, n_iter, random_state):
    """Fit a 2-D t-SNE on ``embs_np`` and return ``(tsne, embeddings_2d)``.

    scikit-learn 1.5 renamed ``n_iter`` to ``max_iter`` and later removed ``n_iter``. The
    iteration budget is therefore passed under whichever keyword the installed version accepts.
    """
    import inspect

    iter_kw = 'max_iter' if 'max_iter' in inspect.signature(TSNE).parameters else 'n_iter'
    print("Applying t-SNE...")
    tsne = TSNE(n_components=2, init='pca', perplexity=perplexity,
                random_state=random_state, **{iter_kw: n_iter})
    return tsne, tsne.fit_transform(embs_np)


def _scatter_by_label(ax, embeddings_2d, labels_np):
    """Scatter 2-D points on ``ax`` with one colour and legend entry per distinct label.

    Labels 0 and 1 keep the red/blue scheme. Any other label value gets a tab10 colour
    instead of being silently left off the plot.
    """
    fixed_colors = {0: 'red', 1: 'blue'}
    palette = plt.get_cmap('tab10')
    for k, label in enumerate(np.unique(labels_np)):
        mask = labels_np == label
        color = fixed_colors.get(label, palette(k % 10))
        ax.scatter(embeddings_2d[mask, 0], embeddings_2d[mask, 1],
                   color=color, label=f'Label {label}', alpha=0.7, s=50)


def _try_enable_widget_backend():
    """Best-effort switch to the ipympl ('widget') backend when running inside a Jupyter kernel."""
    try:
        shell = get_ipython()  # noqa: F821 - provided as a builtin by IPython
    except NameError:
        return  # plain Python: keep the current backend
    if 'ipykernel' not in str(type(shell)):
        return
    print("Notebook detected. Setting up interactive backend...")
    try:
        import matplotlib
        matplotlib.use('widget')  # needs ipympl; 'nbagg' is an alternative
        plt.ioff()  # Turn off interactive mode temporarily
    except Exception as exc:  # optional feature: fall back instead of failing the plot
        print(f"Could not switch to the widget backend ({exc}); using the current backend.")


def visualize_embeddings_tsne(embeddings, labels, strings, perplexity=30, n_iter=1000, random_state=42):
    """
    Visualize embeddings with t-SNE and show a point's string when it is clicked.

    Clicking (requires an interactive backend, e.g. ipympl) highlights the nearest point,
    shows its string in a box below the plot, and also prints it.

    Args:
        embeddings: List of tensors or numpy arrays, or a single tensor/array
        labels: Labels per embedding (0/1 are red/blue; other values get extra colours)
        strings: List of strings corresponding to each embedding
        perplexity: t-SNE perplexity parameter (default: 30)
        n_iter: Number of t-SNE iterations (default: 1000)
        random_state: Random state for reproducibility (default: 42)

    Returns:
        ``(embeddings_2d, tsne, fig)``

    Raises:
        ValueError: if the embeddings list is empty or the input lengths differ.
    """
    _try_enable_widget_backend()

    embs_np, labels_np, strings_np = _prepare_inputs(embeddings, labels, strings)
    print(f"Embedding shape: {embs_np.shape}")
    print(f"Labels shape: {labels_np.shape}")
    print(f"Strings count: {len(strings_np)}")
    print(f"Unique labels: {np.unique(labels_np)}")

    tsne, embeddings_2d = _fit_tsne(embs_np, perplexity, n_iter, random_state)

    fig, ax = plt.subplots(figsize=(12, 8))
    _scatter_by_label(ax, embeddings_2d, labels_np)
    ax.set_title('Interactive t-SNE Visualization\n(Click on points to see text below)', fontsize=16)
    ax.set_xlabel('t-SNE Component 1', fontsize=12)
    ax.set_ylabel('t-SNE Component 2', fontsize=12)
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Text box below the plot, updated on each click.
    text_display = fig.text(0.1, 0.02, 'Click on a point to see its text here...',
                            fontsize=10, ha='left', va='bottom',
                            bbox=dict(boxstyle='round,pad=0.5', facecolor='lightblue', alpha=0.8),
                            wrap=True)

    # A single reusable highlight marker, moved on each click instead of stacking new scatters.
    highlight = ax.scatter([], [], c='yellow', s=150, alpha=0.8, marker='o',
                           edgecolors='black', linewidth=2, zorder=5)

    def on_click(event):
        """Show, highlight and print the point nearest to a click inside the axes."""
        if event.inaxes != ax or event.xdata is None or event.ydata is None:
            return

        distances = np.hypot(embeddings_2d[:, 0] - event.xdata, embeddings_2d[:, 1] - event.ydata)
        closest_idx = int(np.argmin(distances))
        selected_string = str(strings_np[closest_idx])
        selected_label = labels_np[closest_idx]

        # Break long strings into 80-character lines for the text box.
        wrapped_text = '\n'.join(selected_string[i:i + 80] for i in range(0, len(selected_string), 80))
        text_display.set_text(f"Selected Point {closest_idx} (Label {selected_label}):\n{wrapped_text}")
        highlight.set_offsets([embeddings_2d[closest_idx]])
        fig.canvas.draw_idle()

        # Also print to console as backup
        print(f"\nSelected Point {closest_idx} (Label {selected_label}):")
        print(f"Text: {selected_string}")
        print("-" * 50)

    # Keep the connection id on the figure so the handler can be disconnected later.
    fig._click_connection = fig.canvas.mpl_connect('button_press_event', on_click)

    plt.ion()
    plt.tight_layout()
    plt.show()

    return embeddings_2d, tsne, fig


# Static alternative: works on any backend; look up a point's text via results.show_text(i).
def visualize_embeddings_tsne_notebook(embeddings, labels, strings, perplexity=30, n_iter=1000, random_state=42):
    """
    Plot a static t-SNE of the embeddings with every point annotated by its index.

    Args:
        embeddings, labels, strings, perplexity, n_iter, random_state: as for
            ``visualize_embeddings_tsne``.

    Returns:
        A results object with ``embeddings_2d``, ``tsne_model``, ``strings``, ``labels`` and
        ``show_text(point_idx)``, which prints the string and label of a numbered point.

    Raises:
        ValueError: if the embeddings list is empty or the input lengths differ.
    """
    embs_np, labels_np, strings_np = _prepare_inputs(embeddings, labels, strings)
    print(f"Embedding shape: {embs_np.shape}")
    print(f"Labels shape: {labels_np.shape}")
    print(f"Strings count: {len(strings_np)}")

    tsne, embeddings_2d = _fit_tsne(embs_np, perplexity, n_iter, random_state)

    fig, ax = plt.subplots(figsize=(12, 8))
    _scatter_by_label(ax, embeddings_2d, labels_np)

    # Number every point so its text can be looked up with show_text().
    for i, (x, y) in enumerate(embeddings_2d):
        ax.annotate(str(i), (x, y), xytext=(2, 2), textcoords='offset points',
                    fontsize=6, alpha=0.7)

    ax.set_title('t-SNE Visualization (Points are numbered)\nUse show_point_text(point_number) to see text', fontsize=14)
    ax.set_xlabel('t-SNE Component 1')
    ax.set_ylabel('t-SNE Component 2')
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    plt.show()

    def show_point_text(point_idx):
        """Print the string and label of point ``point_idx`` (or the valid range if out of range)."""
        if 0 <= point_idx < len(strings_np):
            print(f"\nPoint {point_idx} (Label {labels_np[point_idx]}):")
            print(f"Text: {strings_np[point_idx]}")
            print("-" * 50)
        else:
            print(f"Point {point_idx} not found. Valid range: 0-{len(strings_np)-1}")

    class TSNEResults:
        """Container for the t-SNE output plus a ``show_text`` lookup helper."""

        def __init__(self, embeddings_2d, tsne_model, strings, labels):
            self.embeddings_2d = embeddings_2d
            self.tsne_model = tsne_model
            self.strings = strings
            self.labels = labels
            self.show_text = show_point_text

    return TSNEResults(embeddings_2d, tsne, strings_np, labels_np)

In [None]:
# Point names: one "pat_<row>" string per sampled patient, in samples order.
point_names = [f"pat_{row}" for row in samples.index]
results = visualize_embeddings_tsne_notebook(patient_embs, labels, point_names, perplexity=15)

In [None]:
# torch.save(torch.stack(patient_embs), '/home/cc/PHD/dglframework/DeepKG/brca-embs.pt')