# Libraries

In [1]:
import os
import sys
import time
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt

import torch
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv

from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
from sentence_transformers import SentenceTransformer

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Location of the drug-interaction CSV. Set the DDI_DATA_PATH environment
# variable to point the notebook at another copy of the file; the fallback is
# the original machine-specific path, so existing setups keep working.
DATA_PATH = os.environ.get(
    "DDI_DATA_PATH",
    r"C:\Users\Kevin Nathanael\Music\DDI Prediction\data\drug_interactions.csv",
)
df = pd.read_csv(DATA_PATH)
df

Unnamed: 0,drug_id,drug_name,interacting_drug_id,interacting_drug_name,description
0,DB00001,Lepirudin,DB06605,Apixaban,Apixaban may increase the anticoagulant activi...
1,DB00001,Lepirudin,DB06695,Dabigatran etexilate,Dabigatran etexilate may increase the anticoag...
2,DB00001,Lepirudin,DB01254,Dasatinib,The risk or severity of bleeding and hemorrhag...
3,DB00001,Lepirudin,DB01609,Deferasirox,The risk or severity of gastrointestinal bleed...
4,DB00001,Lepirudin,DB01586,Ursodeoxycholic acid,The risk or severity of bleeding and bruising ...
...,...,...,...,...,...
2855843,DB19413,Influenza A Virus A/Thailand/8/2022 IVR-237 (H...,DB13509,Aloxiprin,The risk or severity of Reye's syndrome can be...
2855844,DB19413,Influenza A Virus A/Thailand/8/2022 IVR-237 (H...,DB13538,Guacetisal,The risk or severity of Reye's syndrome can be...
2855845,DB19413,Influenza A Virus A/Thailand/8/2022 IVR-237 (H...,DB13612,Carbaspirin calcium,The risk or severity of Reye's syndrome can be...
2855846,DB19413,Influenza A Virus A/Thailand/8/2022 IVR-237 (H...,DB14006,Choline salicylate,The risk or severity of Reye's syndrome can be...


In [3]:
# Build the interaction graph directly from the edge list.
# from_pandas_edgelist returns an undirected nx.Graph when create_using is not
# given, so a row (A, B) and its mirror row (B, A) end up as a single edge whose
# "description" attribute comes from whichever row was added last.
G = nx.from_pandas_edgelist(
    df,
    source="drug_id",
    target="interacting_drug_id",
    edge_attr="description",
)
# G.nodes()

In [4]:
# Set up the layout. spring_layout starts from random positions, so a fixed
# seed is passed to make the resulting coordinates reproducible between runs.
pos = nx.spring_layout(G, seed=42)

In [5]:
# Gather the "description" attribute value of every edge that has one.
# NOTE: despite the name, these are the description values, not numeric weights.
weights = [*nx.get_edge_attributes(G, "description").values()]
# weights

In [None]:
# plt.figure(figsize=(22,10))
# nx.draw_networkx_nodes(G, pos, node_size=800, alpha=0.5)
# nx.draw_networkx_edges(G, pos, edge_color="green")
# nx.draw_networkx_labels(G, pos)
# plt.show()

In [None]:
def get_local_embedding(texts, model_name='NeuML/pubmedbert-base-embeddings'):
    """
    Embed one or more texts with a locally loaded SentenceTransformer model.

    Models are loaded lazily and cached on the function object, keyed by
    ``model_name``, so asking for a different model actually loads that model
    instead of silently reusing whichever one was loaded first.

    Args:
        texts (str | list[str]): A single text or a list of texts to embed.
        model_name (str): Name of the SentenceTransformer model to load.

    Returns:
        list: The result of ``tolist()`` on the float32 CPU copy of the tensor
        returned by ``encode`` (embeddings are requested normalized).
    """
    # Fall back to the CPU when no GPU is present instead of crashing on 'cuda'.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Per-model cache stored on the function object (lazy initialization).
    cache = get_local_embedding.__dict__.setdefault('_models', {})
    model = cache.get(model_name)
    if model is None:
        model = SentenceTransformer(model_name).to(device)
        if device == 'cuda':
            # Cast the weights to fp16; only done on GPU, where half precision
            # is the intended fast path.
            model.half()
        cache[model_name] = model

    # Keep the historical `.model` attribute pointing at the model in use so
    # code that inspects it keeps working.
    get_local_embedding.model = model

    # Ensure input is a list
    if isinstance(texts, str):
        texts = [texts]

    # Automatic mixed precision is only enabled when running on the GPU.
    with torch.amp.autocast(device, enabled=(device == 'cuda')):
        embeddings = model.encode(
            texts,
            convert_to_tensor=True,
            device=device,
            normalize_embeddings=True
        )

    # Move to CPU and convert to list for storage
    return embeddings.cpu().float().tolist()

In [7]:
def process_graph_embeddings(G, batch_size=32):
    """
    Embed every edge description of a graph in batches and store the result
    on the edge under the 'embedding' key.

    Edges are processed in order; if a batch fails, the error is printed and
    processing stops, leaving the already-processed edges with their
    embeddings. The final summary reports how many edges were actually
    embedded, not the total.

    Args:
        G (nx.Graph): Input graph; it is modified in place.
        batch_size (int): Number of edges to process in each batch (>= 1).

    Returns:
        The same graph object ``G``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # Start time for performance tracking
    start_time = time.time()

    # Materialize the edges once so they can be sliced into batches
    edge_list = list(G.edges(data=True))
    total_edges = len(edge_list)
    processed_edges = 0

    print(f"Processing {total_edges} edges in batches of {batch_size}")

    for i in range(0, total_edges, batch_size):
        batch = edge_list[i:i+batch_size]

        # Extract descriptions from batch (missing descriptions become '')
        descriptions = [data.get('description', '') for _, _, data in batch]

        try:
            embeddings = get_local_embedding(descriptions)

            # Assign embeddings to graph edges
            for j, (u, v, _) in enumerate(batch):
                G.edges[u, v]['embedding'] = embeddings[j]
        except Exception as e:
            # Best-effort: report the failing batch and stop, keeping what
            # has been embedded so far.
            print(f"\nError processing batch {i}: {e}")
            break

        # Only count a batch once all of its edges have been updated
        processed_edges += len(batch)
        progress = min(100, int(processed_edges / total_edges * 100))
        print(f"\rProgress: {progress}% ({processed_edges}/{total_edges} edges)", end='', flush=True)

    # Report the number of edges that were really processed
    elapsed_time = (time.time() - start_time) / 60
    print(f"\n\nProcessed {processed_edges} edges in {elapsed_time:.2f} minutes")

    return G

In [8]:
processed_graph = process_graph_embeddings(G)

Processing 1428193 edges in batches of 32
Progress: 0% (160/1428193 edges)

  attn_output = torch.nn.functional.scaled_dot_product_attention(


Progress: 4% (69248/1428193 edges)

KeyboardInterrupt: 

In [None]:
# PyTorch Geometric data preparation
# Node identifiers come from the DataFrame as strings (e.g. "DB00001"), so they
# are mapped to contiguous integer indices before edge_index is built;
# torch.tensor cannot turn string pairs into a long tensor.
node_to_idx = {node: idx for idx, node in enumerate(G.nodes())}
edges = list(G.edges())

# Fail with a clear message if the embedding step did not cover every edge
# (for example because it was interrupted part-way through).
missing_embeddings = sum(1 for u, v in edges if 'embedding' not in G.edges[u, v])
if missing_embeddings:
    raise RuntimeError(
        f"{missing_embeddings} of {len(edges)} edges have no 'embedding' attribute; "
        "run process_graph_embeddings(G) to completion first"
    )

# Shape (2, num_edges): row 0 holds source indices, row 1 target indices
edge_index = torch.tensor(
    [[node_to_idx[u], node_to_idx[v]] for u, v in edges],
    dtype=torch.long,
).t().contiguous()
edge_attr = torch.stack([torch.tensor(G.edges[u, v]['embedding']) for u, v in edges])
x = torch.ones(len(G), 1)  # Dummy node features

graph_data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

In [None]:
# Three-way edge split for training, validation, and testing:
# 20% is held out for testing first, then a quarter of the remaining 80%
# (i.e. 20% of all edges) becomes the validation set, leaving 60% for training.
edges = list(G.edges(data=True))
train_val_edges, test_edges = train_test_split(edges, test_size=0.2, random_state=42)
train_edges, val_edges = train_test_split(
    train_val_edges,
    test_size=0.25,
    random_state=42,
)

In [None]:
class GCNModel(torch.nn.Module):
    """
    Edge scorer: one GCN layer over the nodes followed by a linear layer that
    scores each edge from its source-node representation and its edge features.

    Args:
        node_dim (int): Size of the input node features.
        edge_dim (int | None): Size of the edge features. When None, the input
            width of the final layer is inferred on the first forward pass
            (the default previously failed with a TypeError).
        hidden_dim (int): Output size of the GCN layer.
    """

    def __init__(self, node_dim=1, edge_dim=None, hidden_dim=64):
        super().__init__()
        self.gcn = GCNConv(node_dim, hidden_dim)
        if edge_dim is None:
            # Input width is resolved from the first batch seen in forward()
            self.fc = torch.nn.LazyLinear(1)
        else:
            self.fc = torch.nn.Linear(hidden_dim + edge_dim, 1)

    def forward(self, x, edge_index, edge_attr):
        """
        Return one sigmoid score per edge, as a tensor with one row per column
        of ``edge_index`` and a single output column.
        """
        h = self.gcn(x, edge_index).relu()
        # Pair each edge's features with the hidden state of its source node
        # (row 0 of edge_index); the target node is not used here.
        combined = torch.cat([h[edge_index[0]], edge_attr], dim=1)
        return torch.sigmoid(self.fc(combined))