In [1]:
import networkx as nx
import numpy as np
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

In [None]:
def _build_drug_effect_graph(model, threshold):
    """Build the bipartite drug / side-effect graph shared by the plotting functions.

    Adds one node per drug (``Drug_i``, ``bipartite=0``) and one per side effect
    (``Effect_j``, ``bipartite=1``). Then adds an edge ``Drug_i -- Effect_j``,
    weighted by the largest entry of ``model.probit()[i, j]``, whenever that value
    is strictly greater than ``threshold``.

    Parameters
    ----------
    model : object
        Must expose ``n_drugs``, ``n_effects`` and a ``probit()`` method.
        NOTE(review): ``probit()`` presumably returns a torch tensor indexed as
        ``[drug, effect, ...]``. Anything supporting ``[i, j].max().item()``
        works; confirm against the model class.
    threshold : float
        Exclusive lower bound on the pair's probability for an edge to be added.

    Returns
    -------
    networkx.Graph
        The graph; each edge carries its probability in the ``weight`` attribute.
    """
    graph = nx.Graph()
    graph.add_nodes_from((f"Drug_{i}" for i in range(model.n_drugs)), bipartite=0)
    graph.add_nodes_from((f"Effect_{j}" for j in range(model.n_effects)), bipartite=1)

    probit_matrix = model.probit()
    for i in range(model.n_drugs):
        for j in range(model.n_effects):
            # Collapse any trailing dimensions to a single score for this pair.
            prob = probit_matrix[i, j].max().item()
            if prob > threshold:
                graph.add_edge(f"Drug_{i}", f"Effect_{j}", weight=prob)
    return graph


def _embedding_positions(embeddings, prefix, count):
    """Place node ``f"{prefix}_{k}"`` at the first two coordinates of ``embeddings[k]``.

    Only the first ``count`` rows are used. Embeddings with fewer than two
    columns are zero-padded, so 1-D embeddings still give a valid 2-D layout
    instead of raising ``IndexError``.
    """
    coords = np.asarray(embeddings)
    if coords.ndim == 1:
        coords = coords[:, np.newaxis]
    if coords.shape[1] < 2:
        padding = np.zeros((coords.shape[0], 2 - coords.shape[1]))
        coords = np.hstack([coords, padding])
    return {f"{prefix}_{k}": (coords[k, 0], coords[k, 1]) for k in range(count)}


def _node_colors(graph):
    """Colour drug nodes blue and side-effect nodes red, in ``graph.nodes`` order."""
    return ["blue" if node.startswith("Drug") else "red" for node in graph.nodes]


def _draw_weight_labels(graph, pos, ax):
    """Annotate every edge of ``graph`` with its ``weight``, rounded to two decimals."""
    labels = {
        edge: f"{weight:.2f}"
        for edge, weight in nx.get_edge_attributes(graph, 'weight').items()
    }
    nx.draw_networkx_edge_labels(graph, pos, edge_labels=labels, ax=ax)


def plot_network(self, threshold=0.01, show_edge_labels=False):
    """Draw the drug / side-effect graph, laid out by the learned embeddings.

    Drug ``i`` is placed at the first two coordinates of ``self.w[i]``, and side
    effect ``j`` at the first two coordinates of ``self.v[j]``. An edge is drawn
    wherever the max of ``self.probit()[i, j]`` exceeds ``threshold``. The
    function takes ``self`` so it can be bound as a method of the model.

    Parameters
    ----------
    threshold : float, default 0.01
        Exclusive lower bound on a pair's probability for its edge to be drawn.
    show_edge_labels : bool, default False
        If True, annotate each edge with its probability (2 decimals).
    """
    # self.w / self.v are detached and moved to CPU, so they are assumed to be
    # torch tensors.
    drug_embeddings = self.w.detach().cpu().numpy()    # (n_drugs, embedding_dim)
    effect_embeddings = self.v.detach().cpu().numpy()  # (n_effects, embedding_dim)

    graph = _build_drug_effect_graph(self, threshold)

    pos = _embedding_positions(drug_embeddings, "Drug", self.n_drugs)
    pos.update(_embedding_positions(effect_embeddings, "Effect", self.n_effects))

    _, ax = plt.subplots(figsize=(12, 8))
    nx.draw(graph, pos, ax=ax, with_labels=True, node_size=500,
            node_color=_node_colors(graph), font_size=10, font_weight='bold',
            edge_color='gray')
    if show_edge_labels:
        _draw_weight_labels(graph, pos, ax)

    ax.set_title('Drug-Side Effect Network based on Embeddings and Probit Output')
    plt.show()


def plot_links(self, threshold=0.35, show_edge_labels=True):
    """Draw the drug / side-effect graph as two columns: drugs left, effects right.

    Drug ``i`` sits at ``(1, i)`` and side effect ``j`` at ``(2, j)``. An edge is
    drawn wherever the max of ``self.probit()[i, j]`` exceeds ``threshold``. The
    function takes ``self`` so it can be bound as a method of the model.

    Parameters
    ----------
    threshold : float, default 0.35
        Exclusive lower bound on a pair's probability for its edge to be drawn.
    show_edge_labels : bool, default True
        If True, annotate each edge with its probability (2 decimals).
    """
    graph = _build_drug_effect_graph(self, threshold)

    # Two-column layout; nodes are stacked vertically by index.
    pos = {f"Drug_{i}": (1, i) for i in range(self.n_drugs)}
    pos.update({f"Effect_{j}": (2, j) for j in range(self.n_effects)})

    _, ax = plt.subplots(figsize=(12, 8))
    nx.draw(graph, pos, ax=ax, with_labels=True, node_size=500,
            node_color=_node_colors(graph), font_size=10, font_weight='bold',
            edge_color='gray')
    if show_edge_labels:
        _draw_weight_labels(graph, pos, ax)

    ax.set_title('Drug-Side Effect Network')
    plt.show()