# Create network graph and analysis

# 1) Set up libraries and working directories

In [None]:
# Set up libraries
import os
import random
import time
from itertools import combinations

import matplotlib.cm as cm
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import plotly.graph_objs as go
from ipysigma import Sigma
from pyvis.network import Network
from rich.console import Console
from rich.text import Text
from tqdm import tqdm

# Working directories (placeholders -- replace with real paths before running)
input_directory = "INPUT_DIRECTORY"
variantscape_directory = "VARIANTSCAPE_DIRECTORY"
variantscape_llm_coas_directory = "VARIANTSCAPE_LLM_COAS_DIRECTORY"

# Later cells read/write files by relative name (e.g. 'network_graph.gml'),
# so the working directory stays pinned to the VariantScape directory.
os.chdir(variantscape_directory)

# Load the entity metadata dictionary (Entity -> Category) and the per-paper
# table whose non-metadata columns are entity-presence flags.
metadata_mapping = pd.read_csv(
    os.path.join(variantscape_directory, "metadata_mapping_transposed.csv"),
    low_memory=False,
)
variant_analysis_df = pd.read_csv(
    os.path.join(variantscape_directory, "cleaned_df_v4.csv"),
    low_memory=False,
)
print(metadata_mapping.head(5))

# Load the variant-treatment consensus predictions. Read by absolute path so
# the working directory never has to change; note the directory variable is
# the lower-case name defined above.
df_consensus = pd.read_csv(
    os.path.join(variantscape_llm_coas_directory, "final_variant_treatment_consensus.csv")
)
print("\n\n")
print(df_consensus.head(5))

In [None]:
##### Explore variant dictionary
# List every metadata entity tagged Category == 'Variant', numbered from 1.
# (The printed count includes any duplicate Entity rows; no de-duplication.)
variant_mask = metadata_mapping['Category'] == 'Variant'
variant_entities = metadata_mapping.loc[variant_mask, 'Entity'].tolist()

print(f"\nTotal unique entities with Category == 'Variant': {len(variant_entities)}")
print("ScoreLabel List for Entity = Variant:\n")
for position, variant_name in enumerate(variant_entities, start=1):
    print(f"{position}. {variant_name}")

In [None]:
# Cancer dictionary
# List every metadata entity tagged Category == 'Cancer', numbered from 1 and
# capitalised for display. (Count includes any duplicate Entity rows.)
cancer_mask = metadata_mapping['Category'] == 'Cancer'
cancer_entities = metadata_mapping.loc[cancer_mask, 'Entity'].tolist()

print(f"\nTotal unique entities with Category == 'Cancer': {len(cancer_entities)}")
print("ScoreLabel List for Entity = Cancer:\n")
for rank, cancer_name in enumerate(cancer_entities, start=1):
    print(f"{rank}. {cancer_name.capitalize()}")

In [None]:
# Treatment dictionary
# List every metadata entity tagged Category == 'Treatment', numbered from 1;
# headings are wrapped in ANSI bold escape codes. (Count includes duplicates.)
treatment_mask = metadata_mapping['Category'] == 'Treatment'
treatment_entities = metadata_mapping.loc[treatment_mask, 'Entity'].tolist()

n_treatments = len(treatment_entities)
print(f"\n\033[1mTotal unique entities with Category == 'Treatment': {n_treatments}\033[0m")
print("\n\033[1mScoreLabel List for Entity = Treatment:\033[0m")
for rank, treatment_name in enumerate(treatment_entities, start=1):
    print(f"{rank}. {treatment_name}")

# ========================================================

# 2) Network graph analysis

## Create network graph

In [None]:
# Build the unweighted co-occurrence network.
# Every column except the paper-level metadata columns below is treated as an
# entity-presence flag; a value of 1 means the entity occurs in that paper.
non_entity_cols = ['PaperId', 'Study_design', 'Abstract', 'Study_weight', 'PaperTitle']
entity_columns = [col for col in variant_analysis_df.columns if col not in non_entity_cols]

# Entity -> Category lookup, built once instead of scanning metadata_mapping for
# every column. drop_duplicates keeps the first row per Entity, matching a
# first-match lookup when an entity is listed more than once.
category_lookup = (
    metadata_mapping.drop_duplicates('Entity')
    .set_index('Entity')['Category']
    .to_dict()
)
missing_entities = [col for col in entity_columns if col not in category_lookup]
if missing_entities:
    raise KeyError(f"Entity columns missing from metadata_mapping: {missing_entities}")

# Create network graph: one node per entity, tagged with its category
G = nx.Graph()
for col in entity_columns:
    G.add_node(col, category=category_lookup[col])

# Add edges based on co-occurrence: the weight counts the papers in which both
# entities appear. combinations() yields each unordered pair once -- visiting
# both (a, b) and (b, a) on an undirected graph would add 2 per paper.
presence = variant_analysis_df[entity_columns].to_numpy() == 1
for row_mask in tqdm(presence, total=len(presence), desc="Adding edges"):
    present_columns = [entity_columns[i] for i in np.flatnonzero(row_mask)]
    for col1, col2 in combinations(present_columns, 2):
        if G.has_edge(col1, col2):
            G[col1][col2]['weight'] += 1
        else:
            G.add_edge(col1, col2, weight=1)

# Explore network size
print("Number of nodes:", len(G.nodes))
print("Number of edges:", len(G.edges))
print("Network successfully created!")

# Save the network (relative to the current working directory)
nx.write_gml(G, 'network_graph.gml')

In [None]:
# Sanity-check the saved network: reload it from disk and run example queries.
G = nx.read_gml('network_graph.gml')

# Example query: which treatments and cancers co-occur with a given variant?
variant_of_interest = "s768i_EGFR"
variant_neighbors = set(G.neighbors(variant_of_interest))

# Split the variant's neighbours by the 'category' node attribute
treatments = [nbr for nbr in variant_neighbors if G.nodes[nbr]['category'] == 'Treatment']
cancers = [nbr for nbr in variant_neighbors if G.nodes[nbr]['category'] == 'Cancer']

print(f"Treatments associated with variant '{variant_of_interest}':", treatments, sep="\n")
print(f"\nCancers associated with variant '{variant_of_interest}':", cancers, sep="\n")

# Top 5 nodes by degree centrality (neighbour count / (n - 1); ignores weights)
centrality = nx.degree_centrality(G)
sorted_centrality = sorted(centrality.items(), key=lambda pair: pair[1], reverse=True)
print("\nTop 5 most connected nodes (based on degree centrality):")
for node, score in sorted_centrality[:5]:
    print(f"{node}: {score:.4f}")

## Create weighted network graph based on study design

In [None]:
# Create weighted network graph based on study design.
# Two graphs are built over exactly the same entity pairs:
#   G   - each paper adds 1 to every co-mentioned pair (plain co-occurrence)
#   G_w - each paper adds its study-design evidence weight instead
# Only G_w is saved to disk.

# Study-design evidence weights; unlisted designs get DEFAULT_DESIGN_WEIGHT
study_design_weights = {
    'Systematic review study':       1.0,
    'Clinical study':                1.0,
    'Observational/RWE study':       0.9,
    'Case report study':             0.9,
    'In vivo/Animal study':          0.8,
    'In vitro study':                0.7,
    'In silico study':               0.6,
    'Undefined':                     0.1,
    'Other':                         0.1,
}
DEFAULT_DESIGN_WEIGHT = 0.5

# Entity columns = every column except the paper-level metadata columns
non_entity_cols = ['PaperId', 'Study_design', 'Abstract', 'Study_weight', 'PaperTitle']
entity_columns = [c for c in variant_analysis_df.columns if c not in non_entity_cols]

# Entity -> Category lookup, built once; first row per Entity wins
category_lookup = (
    metadata_mapping.drop_duplicates('Entity')
    .set_index('Entity')['Category']
    .to_dict()
)
missing_entities = [c for c in entity_columns if c not in category_lookup]
if missing_entities:
    raise KeyError(f"Entity columns missing from metadata_mapping: {missing_entities}")

# Initialize two graphs
G   = nx.Graph()  # unweighted (increments by 1)
G_w = nx.Graph()  # weighted by evidence level

# 1) Add nodes to both
for col in entity_columns:
    cat = category_lookup[col]
    G.add_node(col,   category=cat)
    G_w.add_node(col, category=cat)

# 2) Add edges: each unordered co-mentioned pair once per paper
presence = variant_analysis_df[entity_columns].to_numpy() == 1
for row_mask, design in tqdm(
    zip(presence, variant_analysis_df['Study_design']),
    total=len(variant_analysis_df),
    desc="Building graphs",
):
    ents = [entity_columns[i] for i in np.flatnonzero(row_mask)]
    w = study_design_weights.get(str(design).strip(), DEFAULT_DESIGN_WEIGHT)

    for e1, e2 in combinations(ents, 2):
        # --- unweighted: +1 per co-occurrence ---
        if G.has_edge(e1, e2):
            G[e1][e2]['weight'] += 1
        else:
            G.add_edge(e1, e2, weight=1)

        # --- weighted: +w per co-occurrence ---
        if G_w.has_edge(e1, e2):
            G_w[e1][e2]['weight'] += w
        else:
            G_w.add_edge(e1, e2, weight=w)

# 3) Compare basic metrics
print("=== Unweighted graph")
print(" Nodes:", G.number_of_nodes(), "Edges:", G.number_of_edges())
deg_unw = nx.degree_centrality(G)
top_unw = sorted(deg_unw.items(), key=lambda x: x[1], reverse=True)[:5]
print(" Top 5 nodes by degree centrality:", top_unw)
print("\n=== Weighted graph")
print(" Nodes:", G_w.number_of_nodes(), "Edges:", G_w.number_of_edges())
deg_w = nx.degree_centrality(G_w)
top_w = sorted(deg_w.items(), key=lambda x: x[1], reverse=True)[:5]
print(" Top 5 nodes by degree centrality:", top_w)

# Degree centrality only counts neighbours and ignores edge weights, so it is
# identical for G and G_w (same edge set). The evidence weighting shows up in
# weighted degree ("strength"). Compare each node's share of its graph's total
# strength so the two weight scales are comparable.
strength_unw = dict(G.degree(weight='weight'))
strength_w = dict(G_w.degree(weight='weight'))
total_unw = sum(strength_unw.values()) or 1  # avoid 0-division on an edgeless graph
total_w = sum(strength_w.values()) or 1
delta = {n: strength_w[n] / total_w - strength_unw[n] / total_unw for n in G.nodes()}

top_str_unw = sorted(strength_unw.items(), key=lambda x: x[1], reverse=True)[:5]
top_str_w = sorted(strength_w.items(), key=lambda x: x[1], reverse=True)[:5]
print("\nTop 5 nodes by weighted degree (unweighted graph):", top_str_unw)
print("Top 5 nodes by weighted degree (weighted graph):", top_str_w)
gainers = sorted(delta.items(), key=lambda x: x[1], reverse=True)[:5]
losers = sorted(delta.items(), key=lambda x: x[1])[:5]
print("Largest gains in strength share after evidence weighting:",
      [(n, round(d, 4)) for n, d in gainers])
print("Largest losses in strength share after evidence weighting:",
      [(n, round(d, 4)) for n, d in losers])

# Compare edge-weight distributions. G and G_w received their edges in the
# same order, so the two arrays are aligned edge-for-edge.
weights_unw = np.array([d for _, _, d in G.edges(data='weight')])
weights_w   = np.array([d for _, _, d in G_w.edges(data='weight')])

if weights_unw.size:
    print("Edge weights (unweighted): mean=%.2f, std=%.2f" % (weights_unw.mean(), weights_unw.std()))
    print("Edge weights (weighted):   mean=%.2f, std=%.2f" % (weights_w.mean(),   weights_w.std()))
    print("Number of edges re‑weighted:", np.sum(weights_w != weights_unw))

    # Pearson correlation is undefined when either vector is constant
    if weights_unw.std() > 0 and weights_w.std() > 0:
        corr = np.corrcoef(weights_unw, weights_w)[0, 1]
        print("Pearson corr between unweighted/weighted edge‐weights: %.3f" % corr)
    else:
        print("Pearson corr between unweighted/weighted edge-weights: undefined (constant weights)")
else:
    print("No edges were created; skipping edge-weight comparison.")

nx.write_gml(G_w, 'network_graph_weighted.gml')

# =================================================

# 3) Query the network to find associations

In [None]:
# Load the inputs needed for the network query. Objects already in the kernel
# from earlier cells are reused; otherwise they are read from disk.
try:
    G_w
    print("Weighted graph already loaded. Proceeding with analysis.")
except NameError:
    print("Loading the weighted network graph...")
    path = os.path.join(variantscape_directory, "network_graph_weighted.gml")
    G_w = nx.read_gml(path)
# Downstream code queries `G`; give it a copy so G_w itself is never mutated
G = G_w.copy()

try:
    df_consensus
    print("df_consensus already loaded. Proceeding with analysis.")
except NameError:
    print("Loading the consensus file...")
    # The setup cell defines the lower-case `variantscape_llm_coas_directory`
    df_consensus_path = os.path.join(
        variantscape_llm_coas_directory, "final_variant_treatment_consensus.csv"
    )
    df_consensus = pd.read_csv(df_consensus_path)

# Cancer of interest as typed by the user; normalised via cancer_alias_map below
user_input_cancer = "NSCLC"
#variant_of_interest = "v600e_BRAF"

# Variant of interest -- keep exactly one assignment uncommented.
############## Variants of interest for publication ##############
#variant_of_interest = 'l858r_EGFR' # druggable use case
#variant_of_interest = 't790m_EGFR' # resistance use case

######################## Rare variants #############################
#variant_of_interest = 'g469v_BRAF'
#variant_of_interest = 's768i_EGFR'
variant_of_interest = 'l861q_EGFR'
#variant_of_interest = 'l747p_EGFR' # no associations in the network


# Alias mapping from user-facing cancer names to network node names
cancer_alias_map = {
    "nsclc": "lung cancer",
    "non-small cell lung cancer": "lung cancer",
    "tnbc": "breast cancer",
    "her2+ breast cancer": "breast cancer"
}

# Normalize input: match aliases case-insensitively; unknown names pass through lower-cased
clean_input = user_input_cancer.strip().lower()
cancer_of_interest = cancer_alias_map.get(clean_input, clean_input)
display_cancer_name = user_input_cancer

print(f"\n\n\033[1mCancer of interest set to:\033[0m {display_cancer_name} (cancer type:'{cancer_of_interest}')")
print(f"\033[1mVariant of interest set to:\033[0m {variant_of_interest}")

In [None]:
#### Updated weighted network graph with automated threshold, based on qualitative analysis
# Ranks treatments and cancers linked to the chosen cancer/variant by summed
# edge weight, highlights entries that reach BOTH a percentile threshold and
# an absolute minimum weight, and saves the ranked lists to CSV and Excel.

# Adjustable thresholds
TREATMENT_THRESHOLD_PERCENTILE = 80    # highlight top X% of treatment weights
TREATMENT_MIN_HIGHLIGHT        = 300   # and require ≥X total weight
CANCER_THRESHOLD_PERCENTILE    = 80    # highlight top X% of cancer–variant weights
CANCER_MIN_HIGHLIGHT           = 80    # and require ≥X total weight
TOP_N                          = 6     # entries kept per ranked list

# Prepare consensus lookup: lower-cased "<variant> + <treatment>" -> prediction
df_consensus["Variant_Treatment_Pair"] = (
    df_consensus["Variant_Treatment_Pair"]
    .str.strip()
    .str.lower()
)
consensus_dict = dict(
    zip(df_consensus["Variant_Treatment_Pair"], df_consensus["Resolved_Prediction"])
)

# Generic treatment classes left out of the rankings
excluded_treatments = {
    'chemotherapy', 'tyrosine kinase inhibitor', 'radiotherapy', 'hormone therapy',
    'adjuvant chemotherapy', 'immunotherapy', 'immune checkpoint inhibitor',
    'mrna vaccine', 'mtor inhibitor', 'radiation ionizing radiotherapy'
}

# The chosen cancer/variant may be absent from the network (e.g. a variant with
# no associations). Treat a missing node as having no neighbours instead of
# letting G.neighbors raise NetworkXError, and say so explicitly.
for node_label, node_name in (("Cancer", cancer_of_interest), ("Variant", variant_of_interest)):
    if node_name not in G:
        print(f"\033[1;33mWarning:\033[0m {node_label} '{node_name}' is not a node in the "
              "network; no associations can be reported for it.")

# Cancer-only treatments
canc_nei = set(G.neighbors(cancer_of_interest)) if cancer_of_interest in G else set()
treatments = [
    n for n in canc_nei
    if G.nodes[n]['category'] == 'Treatment'
    and n.lower() not in excluded_treatments
]
t_weights = {t: G[cancer_of_interest][t]['weight'] for t in treatments}
top_cancer_treats = sorted(t_weights.items(), key=lambda x: x[1], reverse=True)[:TOP_N]
c_w = list(t_weights.values())
# Percentile cut-off for cancer-only treatments (not used for highlighting in this cell)
treat_pct = np.percentile(c_w, TREATMENT_THRESHOLD_PERCENTILE) if c_w else 0


# Variant + cancer associations: treatments linked to both nodes, split by the
# consensus prediction for the variant-treatment pair
sensitive, resistant = [], []
for t in treatments:
    if not G.has_edge(variant_of_interest, t):
        continue  # treatment never co-occurs with the variant
    w = G[cancer_of_interest][t]['weight'] + G[variant_of_interest][t]['weight']
    pred = consensus_dict.get(f"{variant_of_interest} + {t}".lower())
    if pred == "Sensitive":
        sensitive.append((t, w))
    elif pred == "Resistant":
        resistant.append((t, w))

top_sens = sorted(sensitive, key=lambda x: x[1], reverse=True)[:TOP_N]
top_res  = sorted(resistant, key=lambda x: x[1], reverse=True)[:TOP_N]
sens_w = [w for _, w in sensitive]
res_w  = [w for _, w in resistant]
sens_pct = np.percentile(sens_w, TREATMENT_THRESHOLD_PERCENTILE) if sens_w else 0
res_pct  = np.percentile(res_w,   TREATMENT_THRESHOLD_PERCENTILE) if res_w else 0

print(f"\n\033[1mSensitive treatments for variant '{variant_of_interest}' "
      f"(≥{TREATMENT_THRESHOLD_PERCENTILE}th pct & ≥{TREATMENT_MIN_HIGHLIGHT}):\033[0m")
for t, w in top_sens:
    if w >= sens_pct and w >= TREATMENT_MIN_HIGHLIGHT:
        print(f"\033[1;32m{t}: {w:.0f}\033[0m")
    else:
        print(f"\033[2;37m{t}: {w:.0f}\033[0m")

print(f"\n\033[1mResistant treatments for variant '{variant_of_interest}' "
      f"(≥{TREATMENT_THRESHOLD_PERCENTILE}th pct & ≥{TREATMENT_MIN_HIGHLIGHT}):\033[0m")
for t, w in top_res:
    if w >= res_pct and w >= TREATMENT_MIN_HIGHLIGHT:
        print(f"\033[1;31m{t}: {w:.0f}\033[0m")
    else:
        print(f"\033[2;37m{t}: {w:.0f}\033[0m")


# Other cancers for the variant, scored by variant-cancer weight plus the
# weight to the cancer of interest (0 when those two cancers share no edge)
var_nei = set(G.neighbors(variant_of_interest)) if variant_of_interest in G else set()
var_cancers = [
    n for n in var_nei
    if G.nodes[n]['category'] == 'Cancer'
    and n != cancer_of_interest
]
vc_weights = {}
for c in var_cancers:
    w_v = G[variant_of_interest][c]['weight']
    w_c = G[cancer_of_interest][c]['weight'] if G.has_edge(cancer_of_interest, c) else 0
    vc_weights[c] = w_v + w_c

top_var_c = sorted(vc_weights.items(), key=lambda x: x[1], reverse=True)[:TOP_N]
vc_w = list(vc_weights.values())
cancer_pct = np.percentile(vc_w, CANCER_THRESHOLD_PERCENTILE) if vc_w else 0

print(f"\n\033[1mOther cancers for variant '{variant_of_interest}' "
      f"(≥{CANCER_THRESHOLD_PERCENTILE}th pct & ≥{CANCER_MIN_HIGHLIGHT}):\033[0m")
for c, w in top_var_c:
    if w >= cancer_pct and w >= CANCER_MIN_HIGHLIGHT:
        print(f"\033[1;34m{c.capitalize()}: {w:.0f}\033[0m")
    else:
        print(f"\033[2;37m{c.capitalize()}: {w:.0f}\033[0m")

# Assemble one flat results table from the four ranked lists
results = []
for t, w in top_cancer_treats:
    results.append({
        "Cancer": display_cancer_name, "Variant": None,
        "Treatment": t, "Association_Type": "Cancer-Only",
        "Prediction": "NA", "Combined_Weight": w
    })
for t, w in top_sens:
    results.append({
        "Cancer": display_cancer_name, "Variant": variant_of_interest,
        "Treatment": t, "Association_Type": "Variant-Cancer",
        "Prediction": "Sensitive", "Combined_Weight": w
    })
for t, w in top_res:
    results.append({
        "Cancer": display_cancer_name, "Variant": variant_of_interest,
        "Treatment": t, "Association_Type": "Variant-Cancer",
        "Prediction": "Resistant", "Combined_Weight": w
    })
for c, w in top_var_c:
    results.append({
        "Cancer": c, "Variant": variant_of_interest,
        "Treatment": None, "Association_Type": "Cross-Cancer",
        "Prediction": "NA", "Combined_Weight": w
    })

df_out = pd.DataFrame(results)
safe_c = display_cancer_name.replace(" ", "_").lower()
safe_v = variant_of_interest.replace(" ", "_").lower()

csv_fname   = f"network_results_{safe_c}_{safe_v}.csv"
excel_fname = f"network_results_{safe_c}_{safe_v}.xlsx"

df_out.to_csv(csv_fname, index=False)
print(f"\nCSV saved to: {csv_fname}")

# Use a distinct name for the writer so it does not clobber the weight variable `w`
with pd.ExcelWriter(excel_fname, engine='xlsxwriter') as writer:
    df_out.to_excel(writer, sheet_name='Results', index=False)
print(f"Excel saved to: {excel_fname}")

# =================================================

# 4) Create network figure

In [None]:
# Create figure to display the full network.
# Keeps only large connected components, lays them out with a spring layout,
# and writes an interactive Plotly figure to variantscape_network_graph.html.

MIN_COMPONENT_SIZE = 50  # components with fewer nodes are left out of the figure

# Load the graph (reuse G_w from earlier cells when present)
try:
    G_w
    print("Weighted graph already loaded. Proceeding with analysis.")
except NameError:
    print("Loading the weighted network graph...")
    path = os.path.join(variantscape_directory, "network_graph_weighted.gml")
    G_w = nx.read_gml(path)
# Plot from a copy so the filtering below never touches G_w
G = G_w.copy()


# Filter large components only
print("Filtering graph...")
if nx.is_directed(G):
    G = G.to_undirected()
components = list(tqdm(nx.connected_components(G), desc="Finding Components", ncols=100, ascii=True))
components = [c for c in components if len(c) >= MIN_COMPONENT_SIZE]
G = G.subgraph(set().union(*components)).copy()
print(f"Filtered to {len(G.nodes())} nodes and {len(G.edges())} edges.")
if G.number_of_nodes() == 0:
    raise ValueError(
        f"No connected component has at least {MIN_COMPONENT_SIZE} nodes; nothing to plot."
    )

print("Computing spring layout...")
pos = nx.spring_layout(G, seed=42, k=0.15, iterations=50)

# Prepare node visuals (category color, degree size)
print("Processing node visuals...")
degrees = dict(G.degree())
node_sizes = [degrees[n] for n in G.nodes()]
max_degree = max(node_sizes) or 1  # `or 1` guards against all-isolated nodes
scaled_sizes = [5 + (deg / max_degree) * 20 for deg in node_sizes]  # 5..25 px


category_colors = {
    "Variant": "#00b0f0",
    "Treatment": "#32cd32",
    "Cancer": "#ff4c4c"
}


node_x, node_y, node_text, node_colors = [], [], [], []
for node in tqdm(G.nodes(), desc="Placing Nodes", ncols=100, ascii=True):
    x, y = pos[node]
    node_x.append(x)
    node_y.append(y)
    category = G.nodes[node].get('category', 'Variant')
    node_colors.append(category_colors.get(category, '#888888'))
    node_text.append(f"{node}<br>Degree: {degrees[node]}")

node_trace = go.Scatter(
    x=node_x,
    y=node_y,
    mode='markers',
    text=node_text,
    hoverinfo='text',
    marker=dict(
        size=scaled_sizes,
        color=node_colors,
        opacity=0.85,
        line=dict(width=0.3, color='white')
    ),
    name="",
    showlegend=False
)

# Draw edges as one trace; None entries break the line between edges
print("Processing edge traces...")
edge_x, edge_y = [], []
for u, v in tqdm(G.edges(), desc="Drawing Edges", total=G.number_of_edges(), ncols=100, ascii=True):
    x0, y0 = pos[u]
    x1, y1 = pos[v]
    edge_x.extend([x0, x1, None])
    edge_y.extend([y0, y1, None])

edge_trace = go.Scatter(
    x=edge_x,
    y=edge_y,
    line=dict(width=0.2, color='rgba(200,200,200,0.15)'),
    hoverinfo='none',
    mode='lines',
    showlegend=False
)

# Category legend (dummy traces with no data) plus a note on node size
legend_entries = [
    go.Scatter(x=[None], y=[None], mode='markers',
               marker=dict(size=12, color=category_colors['Variant']),
               name="Variant"),
    go.Scatter(x=[None], y=[None], mode='markers',
               marker=dict(size=12, color=category_colors['Treatment']),
               name="Treatment"),
    go.Scatter(x=[None], y=[None], mode='markers',
               marker=dict(size=12, color=category_colors['Cancer']),
               name="Cancer"),
    go.Scatter(x=[None], y=[None], mode='markers',
               marker=dict(size=0.01, color='rgba(0,0,0,0)'),
               name="Size = Node degree")
]


# Create figure
print("Creating Plotly figure...")
fig = go.Figure(data=[edge_trace, node_trace] + legend_entries,
                layout=go.Layout(
                    title=dict(
                        text="Variantscape: Full network graph of molecular variants, treatments and cancer types",
                        font=dict(size=20, color='white'),
                        x=0.5
                    ),
                    showlegend=True,
                    legend=dict(
                        font=dict(color='white'),
                        title=dict(text="Legend", font=dict(size=14, color='white')),
                        bgcolor='rgba(0,0,0,0)',
                        x=0.01,
                        y=0.99
                    ),
                    hovermode='closest',
                    margin=dict(b=10, l=10, r=10, t=80),
                    xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                    yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                    plot_bgcolor='black',
                    paper_bgcolor='black'
                ))
print("Saving to HTML...")
fig.write_html("variantscape_network_graph.html")
print("Saved: variantscape_network_graph.html")

In [None]:
# Check whether any nodes fall outside the large connected components.
# (None are expected, as only articles with all 3 entity mentions were
# included in the analysis.)
# Compare against the full weighted graph G_w: `G` may already have been
# filtered by the figure cell, and copying it would always report 0 dropped.
MIN_COMPONENT_SIZE = 50

G_full = G_w.copy()

# Filter large components only
print("Filtering graph...")
if nx.is_directed(G_full):
    G_full = G_full.to_undirected()

components = list(tqdm(nx.connected_components(G_full), desc="Finding Components", ncols=100, ascii=True))
components = [c for c in components if len(c) >= MIN_COMPONENT_SIZE]
G = G_full.subgraph(set().union(*components)).copy()

dropped_nodes = G_full.number_of_nodes() - G.number_of_nodes()
print(f"Filtered to {len(G.nodes())} nodes and {len(G.edges())} edges.")
print(f"Dropped {dropped_nodes} nodes not in components ≥ 50.")