In [None]:
# %pip install pykeen

In [None]:
import re
import time

import pandas as pd
import requests
import torch
from bs4 import BeautifulSoup
from pykeen.pipeline import pipeline
from pykeen.triples import TriplesFactory


In [None]:
# Tab-separated input files of the knowledge graph: one with edges, one with nodes.
edges_path, nodes_path = 'tsv_files/edges.tsv', 'tsv_files/nodes.tsv'

In [None]:
# Load only the three columns that make up a (start, relation, end) triple, keeping
# both endpoint IDs as strings. `usecols` picks columns but does not dictate their
# order, so the frame is explicitly re-ordered to start_id / type / end_id.
edges_df = pd.read_csv(
    edges_path,
    sep='\t',
    usecols=["start_id", "type", "end_id"],
    dtype=dict.fromkeys(("start_id", "end_id"), str),
).loc[:, ["start_id", "type", "end_id"]]

In [None]:
# Preview the first rows of the edge table as a sanity check on the load.
edges_df.head()

In [None]:
# Read the complete node table (tab-separated) and preview it.
# Later cells read the 'identifier', 'properties' and 'node_id' columns from this frame.
nodes_df = pd.read_csv(nodes_path, sep='\t', low_memory=False)
nodes_df.head()

In [None]:
# Print the first identifier that starts with 'DB' (if any), then stop scanning.
for identifier in nodes_df['identifier']:
    if not identifier.startswith('DB'):
        continue
    print(identifier)
    break

In [None]:
import json

def get_replaced_identifier(row):
    """
    Return the identifier to use for a node row, preferring an OMIM cross-reference.

    The row's 'properties' value is parsed as JSON. If it is an object whose
    "xrefs" list contains a string starting with "OMIM:", the first such entry is
    returned with the colon removed (e.g. "OMIM:102100" -> "OMIM102100").
    In every other case the row's own 'identifier' is returned unchanged.

    Parameters:
    - row: mapping-like with 'properties' and 'identifier' keys (e.g. a DataFrame row)

    Returns:
    - the OMIM-style identifier, or row['identifier'] as the fallback
    """
    properties = row['properties']
    try:
        properties_dict = json.loads(properties) if isinstance(properties, str) else {}
    except json.JSONDecodeError:
        properties_dict = {}

    # Valid JSON is not necessarily an object: "null", "[]" or "3" all parse fine
    # but have no .get(), so anything that is not a dict is treated as "no properties".
    if not isinstance(properties_dict, dict):
        properties_dict = {}

    # "xrefs" may be missing, null, or not a list at all; only a real list/tuple is scanned.
    xrefs = properties_dict.get("xrefs")
    if not isinstance(xrefs, (list, tuple)):
        xrefs = []

    for xref in xrefs:
        # Skip non-string entries instead of failing on .startswith
        if isinstance(xref, str) and xref.startswith("OMIM:"):
            return xref.replace("OMIM:", "OMIM")

    return row['identifier']

In [None]:
# Overwrite each node's identifier with its OMIM cross-reference when it has one
# (row-wise application of get_replaced_identifier); other identifiers stay as they are.
nodes_df['identifier'] = nodes_df.apply(get_replaced_identifier, axis=1)

# Map identifier -> node_id. If several rows end up with the same identifier,
# dict() keeps the node_id of the last such row.
# NOTE(review): this mapping is passed to PyKEEN as `entity_to_id` further down;
# presumably node_id values are contiguous integers starting at 0 — confirm.
entity_to_id = dict(zip(nodes_df['identifier'], nodes_df['node_id']))

In [None]:
# Persist the identifier -> node_id mapping as a two-column CSV without an index column.
entity2id_df = pd.DataFrame(
    {
        'Entity': list(entity_to_id.keys()),
        'Node_ID': list(entity_to_id.values()),
    }
)
entity2id_df.to_csv('entity2node.csv', index=False)

In [None]:
# Build the PyKEEN triples factory from the (start_id, type, end_id) rows, reusing the
# identifier -> node_id mapping as the entity index and requesting inverse triples.
# NOTE(review): start_id / end_id values need to match the keys of entity_to_id
# (i.e. identifiers after the OMIM replacement) — confirm the edge file uses the same IDs.
triples_factory = TriplesFactory.from_labeled_triples(edges_df.values, entity_to_id=entity_to_id, create_inverse_triples=True)

In [None]:
# Split the triples so that virtually everything (ratio 0.999999) goes to training and
# only a tiny remainder forms the test set; random_state pins the split.
training_factory, testing_factory = triples_factory.split(
    ratios=[0.999999, 0.000001],
    random_state=42
)

In [None]:
# Train a CompGCN model (embedding_dim=64) for 50 epochs through the PyKEEN pipeline,
# with a fixed random seed and no validation set.
# NOTE(review): device='cuda:2' hard-codes one specific GPU; this will presumably fail on
# machines that do not expose a third CUDA device — confirm before sharing the notebook.
# NOTE(review): batch_size=4096000 is very large; presumably sized for the target GPU — confirm.
result = pipeline(
    model='CompGCN',
    training=training_factory,
    testing=testing_factory,
    validation=None,
    model_kwargs=dict(embedding_dim=64),
    training_kwargs=dict(
        num_epochs=50,
        batch_size=4096000
    ),
    random_seed=42,
    device='cuda:2'
)

In [None]:
# result.save_to_directory('pykeen_result')

In [None]:
# Pull the trained model and the training triples factory out of the pipeline result.
# This rebinds `training_factory` to the factory stored on the result.
model = result.model
training_factory = result.training

In [None]:
# Take the model's first entity representation and evaluate it with indices=None.
# NOTE(review): indices=None presumably yields the representations of all entities,
# one row per entity ID — confirm against the PyKEEN Representation docs.
entity_embeddings = model.entity_representations[0]
entity_tensor = entity_embeddings(indices=None)

In [None]:
# Inspect the shape of the extracted entity tensor.
entity_tensor.shape

In [None]:
# Rebind entity_to_id to the mapping stored on the training factory; from here on this
# replaces the dict that was built from nodes_df earlier in the notebook.
entity_to_id = training_factory.entity_to_id

In [None]:
# Peek at the first 10 entries of the entity mapping.
# entity_to_id is keyed by entity label and holds the integer ID as the value, so each
# line printed is "<label> <id>".
omim_ids = []
drugbank_ids = []
for position, (entity_label, entity_id) in enumerate(entity_to_id.items()):
    # Stop before printing an 11th entry.
    if position >= 10:
        break
    print(entity_label, entity_id)

In [None]:
def get_entity_embedding(entity_label, entity_to_id, entity_tensor):
    """
    Look up the embedding of a single entity by its label.

    Parameters:
    - entity_label: str, label of the entity (e.g., 'OMIM602371' or 'DB00014')
    - entity_to_id: dict, maps entity labels to their index in `entity_tensor`
    - entity_tensor: torch.Tensor, entity embeddings indexed by entity ID along the first axis

    Returns:
    - numpy.ndarray with the entity's embedding (detached and moved to CPU), or
      None — after printing a notice — when the label is not in `entity_to_id`
    """
    if entity_label in entity_to_id:
        row_index = entity_to_id[entity_label]
        # Detach from the autograd graph and move to CPU so the result is a plain NumPy array.
        return entity_tensor[row_index].detach().cpu().numpy()

    print(f"Entity '{entity_label}' not found in the mapping.")
    return None

In [None]:
# Smoke test: fetch the embedding of one known OMIM label.
# get_entity_embedding returns None if the label is missing from entity_to_id.
example_entity = 'OMIM102100'
embedding = get_entity_embedding(example_entity, entity_to_id, entity_tensor)

In [None]:
# Shape of the example embedding (raises AttributeError if the lookup above returned None).
embedding.shape

In [None]:
# Display the raw values of the example embedding.
embedding

In [None]:
# Preview the headerless Wdname.csv file.
# NOTE(review): this reads 'Wdname.csv' from the working directory, while a later cell
# reloads the same name from 'data/other/Wdname.csv' — confirm which location is correct.
wdname = pd.read_csv('Wdname.csv', header=None)
wdname.head()

In [None]:
# Preview the headerless Wrname.csv file.
# NOTE(review): this reads 'Wrname.csv' from the working directory, while a later cell
# reloads the same name from 'data/other/Wrname.csv' — confirm which location is correct.
wrname = pd.read_csv('Wrname.csv', header=None)
wrname.head()

In [None]:
# Load both headerless ID lists into single-column frames named "original_id":
# the first file (Wdname) carries the OMIM IDs, the second (Wrname) the DrugBank IDs.
wdname, wrname = (
    pd.read_csv(csv_path, header=None, names=["original_id"])
    for csv_path in ('data/other/Wdname.csv', 'data/other/Wrname.csv')
)

# Stack them (OMIM rows first, DrugBank rows after) under a fresh 0..N-1 index.
id_df = pd.concat((wdname, wrname), ignore_index=True)

In [None]:
# Create the `entity_label` column
def convert_to_entity_label(original_id):
    """
    Convert the original ID to the `entity_label` format used by the model.

    Rules, checked in this order:
    - "D" followed only by digits (e.g. "D102100") is an OMIM ID -> "OMIM102100"
    - an ID starting with "DB" (DrugBank) is returned unchanged
    - anything else, including non-string values such as NaN or None -> None

    Parameters:
    - original_id: the raw ID as read from the input file

    Returns:
    - str entity label, or None when the ID matches no known format
    """
    # Blank CSV rows come back from pandas as float NaN, which has no .startswith();
    # treat every non-string value as "no usable ID" instead of raising.
    if not isinstance(original_id, str):
        return None

    digits = original_id[1:]
    # "D" + digits is checked first; "DB..." can never match because "B" is not a digit.
    if original_id.startswith("D") and digits.isdigit():
        return "OMIM" + digits  # Convert to `OMIMxxxx` format
    if original_id.startswith("DB"):
        return original_id
    return None

# Apply the conversion function to create the new `entity_label` column
# (IDs in an unrecognised format become None).
id_df['entity_label'] = id_df['original_id'].apply(convert_to_entity_label)

# Keep only the rows where `entity_label` is not missing.
# The surviving rows keep their original index labels, so the index may have gaps afterwards.
id_df = id_df.dropna(subset=['entity_label'])

In [None]:
# Preview the ID table with its new `entity_label` column.
id_df.head()

In [None]:
# First lookup pass: attach an embedding to every label. Labels missing from entity_to_id
# get None, and get_entity_embedding prints a notice for each of them.
id_df['embedding'] = id_df['entity_label'].apply(lambda x: get_entity_embedding(x, entity_to_id, entity_tensor))

In [None]:
# Step 1: Create a dictionary for ID replacements
# Keys are DrugBank IDs as they appear in the input list; values are the IDs to look up instead.
# NOTE(review): presumably these are IDs that were merged or renamed in DrugBank — confirm the source.
id_replacement_mapping = {
    "DB00510": "DB00313",
    "DB01258": "DB09026",
    "DB01402": "DB01294",
    "DB01904": "DB11195",
    "DB05073": "DB02709"
}

def update_entity_label(entity_label, mapping):
    """
    Translate `entity_label` through `mapping`.

    Returns the mapped replacement when `entity_label` is a key of `mapping`;
    labels without an entry are returned untouched.
    """
    if entity_label in mapping:
        return mapping[entity_label]
    return entity_label

In [None]:
# Swap the DrugBank IDs listed in id_replacement_mapping for their replacements.
id_df['entity_label'] = id_df['entity_label'].apply(lambda x: update_entity_label(x, id_replacement_mapping))

In [None]:
# Repeat the embedding lookup now that the DrugBank labels have been updated.
id_df['embedding'] = id_df['entity_label'].apply(lambda x: get_entity_embedding(x, entity_to_id, entity_tensor))

In [None]:
# Hand-maintained list of OMIM labels.
# NOTE(review): no cell visible in this notebook reads this list (the OMIM lookup loop
# iterates `omim_ids` instead); presumably these are the labels that had no embedding —
# confirm whether the list is still needed.
extract_omim_ids = [
    "OMIM103470", "OMIM106400", "OMIM133600", "OMIM144400", "OMIM153480", 
    "OMIM153640", "OMIM159000", "OMIM159001", "OMIM175505", "OMIM182920", 
    "OMIM212110", "OMIM232240", "OMIM234580", "OMIM241500", "OMIM241510", 
    "OMIM259660", "OMIM300301", "OMIM300306", "OMIM300455", "OMIM300494", 
    "OMIM300497", "OMIM300504", "OMIM300584", "OMIM300640", "OMIM300706", 
    "OMIM300910", "OMIM304900", "OMIM305300", "OMIM306700", "OMIM600208", 
    "OMIM600309", "OMIM600634", "OMIM600996", "OMIM601696", "OMIM601884", 
    "OMIM602025", "OMIM603860", "OMIM605839", "OMIM607447", "OMIM607595", 
    "OMIM607636", "OMIM607655", "OMIM607801", "OMIM608355", "OMIM608622", 
    "OMIM608902", "OMIM609265", "OMIM609307", "OMIM609338", "OMIM609535", 
    "OMIM609886", "OMIM609887", "OMIM610762", "OMIM610799", "OMIM611277", 
    "OMIM612052", "OMIM612359", "OMIM612362", "OMIM612460", "OMIM612542", 
    "OMIM612556", "OMIM612560", "OMIM612671", "OMIM612797", "OMIM612874", 
    "OMIM612975", "OMIM613180", "OMIM613508", "OMIM613875", "OMIM614036", 
    "OMIM614157", "OMIM614192", "OMIM614401", "OMIM614408", "OMIM614546", 
    "OMIM614696", "OMIM615106", "OMIM615221", "OMIM615311", "OMIM615325", 
    "OMIM615457"
]

In [None]:
# Assuming `id_df` is the DataFrame generated earlier with the `entity_label` column to be updated

# 1. Collect the numeric part of every OMIM label that needs to be checked
omim_ids = [label[4:] for label in id_df['entity_label'] if label.startswith("OMIM")]

# 2. Define the OMIM website URL template for entry lookups
OMIM_URL_TEMPLATE = "https://www.omim.org/entry/{}"

# 3. Set headers to mimic a browser visit.
#    Accept-Encoding is intentionally NOT overridden: requests advertises only the
#    encodings it can decode, whereas forcing "br" can produce an unreadable body
#    when no brotli decoder is installed.
headers = {
    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
    "Referer": "https://www.omim.org/",
    "Accept-Language": "en-US,en;q=0.9",
    "Connection": "keep-alive"
}

# Use requests.Session() so the headers (and the connection) are reused across requests
session = requests.Session()
session.headers.update(headers)

# 4. Replacement dictionary: old OMIM label -> OMIM label the entry was moved to
omim_replacement_mapping = {}

# Captures the first entry number after "MOVED TO"
# (e.g. "MOVED TO 193510, 606952" -> "193510").
MOVED_TO_PATTERN = re.compile(r"MOVED TO\s+(\d+)")

# 5. Query each distinct OMIM ID once; dict.fromkeys removes duplicates while keeping order
for omim_id in dict.fromkeys(omim_ids):
    url = OMIM_URL_TEMPLATE.format(omim_id)
    print(f"Checking OMIM ID: {omim_id}, URL: {url}")
    try:
        response = session.get(url, timeout=10)  # Set a timeout of 10 seconds per request

        if response.status_code == 200:
            # Parse the webpage using BeautifulSoup
            soup = BeautifulSoup(response.text, 'html.parser')

            # Scan the <span class="mim-font"> tags for the first usable "MOVED TO <number>" notice
            moved_to = None
            for span in soup.find_all('span', class_='mim-font'):
                moved_to = MOVED_TO_PATTERN.search(span.text)
                if moved_to:
                    break

            if moved_to:
                # Only the first target ID is used when several are listed
                new_id = moved_to.group(1)
                omim_replacement_mapping[f"OMIM{omim_id}"] = f"OMIM{new_id}"
                print(f"OMIM{omim_id} has been moved to OMIM{new_id}")
            else:
                print(f"No change record for OMIM{omim_id}")
        else:
            print(f"Failed to query OMIM{omim_id}, status code: {response.status_code}")
    except requests.exceptions.RequestException as e:
        # Best effort: report the failure and carry on with the next ID
        print(f"Request error for OMIM{omim_id}: {e}")

    # Add a delay to prevent triggering anti-scraping mechanisms
    time.sleep(1)  # Wait 1 second between requests

In [None]:
# Show the OMIM redirects collected by the lookup loop.
omim_replacement_mapping

In [None]:
# Replace moved OMIM labels with their new labels; labels without a redirect stay unchanged.
id_df['entity_label'] = id_df['entity_label'].apply(lambda x: omim_replacement_mapping.get(x, x))

# Repeat the embedding lookup with the redirected labels.
id_df['embedding'] = id_df['entity_label'].apply(lambda x: get_entity_embedding(x, entity_to_id, entity_tensor))

In [None]:
# Hand-curated overrides: OMIM label -> MONDO identifier to look up instead.
# NOTE(review): presumably these entities are stored in the graph under their MONDO ID
# rather than an OMIM cross-reference — confirm against the node table.
manual_mappings = {
    "OMIM106400": "MONDO:0007127",
    "OMIM232240": "MONDO:0009288",
    "OMIM241500": "MONDO:0016605",
    "OMIM306700": "MONDO:0010602"
}

In [None]:
# Merge the manual overrides into the redirect mapping (manual entries win on key clashes).
omim_replacement_mapping.update(manual_mappings)

In [None]:
# Apply the combined mapping (scraped redirects plus manual overrides) to the labels.
id_df['entity_label'] = id_df['entity_label'].apply(lambda x: omim_replacement_mapping.get(x, x))

# Final embedding lookup; labels still missing from entity_to_id keep a None embedding.
id_df['embedding'] = id_df['entity_label'].apply(lambda x: get_entity_embedding(x, entity_to_id, entity_tensor))

In [None]:
# Preview the ID table after all label replacements and the final embedding lookup.
id_df.head()

In [None]:
# Entities without an embedding get an all-zero placeholder. Its width is taken from
# entity_tensor so it always matches the real embeddings instead of a hard-coded 64.
embedding_width = entity_tensor.shape[-1]
id_df['embedding'] = id_df['embedding'].apply(lambda x: x if x is not None else [0.0] * embedding_width)

# Expand each embedding into one column per dimension, built on id_df's own index.
# id_df had rows removed by dropna earlier, so its index can have gaps; a fresh 0..N-1
# index here would make the column-wise concat below pair embeddings with the wrong rows.
embedding_df = pd.DataFrame(id_df['embedding'].tolist(), index=id_df.index)
embedding_df.columns = [f"embedding_{i}" for i in range(embedding_df.shape[1])]

# Side-by-side join of the ID columns and the embedding columns (aligned on the shared index).
id_df_expanded = pd.concat([id_df.drop('embedding', axis=1), embedding_df], axis=1)

In [None]:
# Preview the final table: ID columns followed by one column per embedding dimension.
id_df_expanded.head()

In [None]:
# Write the final table to kg_embeddings.csv without the index column.
id_df_expanded.to_csv('kg_embeddings.csv', index=None)