In [None]:
import html
import json
import uuid

import numpy as np
import requests
from IPython.display import HTML, display

from pysrc.papers.data import AnalysisData

def filter_by_connectivity(df, graph, percentile=75, max_count=None):
    """Select the rows of ``df`` whose node has many neighbours in ``graph``.

    Args:
        df: DataFrame with an ``id`` column of node identifiers. It is not modified.
        graph: Graph-like object supporting ``node in graph`` and
            ``graph.neighbors(node)`` (for example a ``networkx.Graph``).
        percentile: Percentile (0-100) of the per-row neighbour counts used as an
            inclusive cut-off.
        max_count: Optional cap. When more rows pass the cut-off than this, only
            the ``max_count`` rows with the most neighbours are kept.

    Returns:
        A copy of the selected rows with an additional integer ``connections``
        column holding each row's neighbour count. Rows keep their original
        order unless the ``max_count`` cap applies, in which case they are
        ordered by ``connections`` descending (ties keep their original order).
    """
    def count_neighbors(pid):
        # An id that is not a node of the graph has no edges; count it as 0
        # rather than letting graph.neighbors() fail on the unknown node.
        if pid not in graph:
            return 0
        return len(list(graph.neighbors(pid)))

    connectivity = np.array([count_neighbors(pid) for pid in df['id']], dtype=np.int64)

    # A percentile of no data is undefined, so hand back an empty result
    # that still has the column callers expect.
    if connectivity.size == 0:
        empty_df = df.copy()
        empty_df['connections'] = connectivity
        return empty_df

    threshold = np.percentile(connectivity, percentile)

    # Positional boolean mask: independent of df's index labels.
    keep = connectivity >= threshold
    filtered_df = df[keep].copy()
    filtered_df['connections'] = connectivity[keep]

    if max_count is not None and len(filtered_df) > max_count:
        # Stable sort so the cut between equally-connected rows is reproducible.
        filtered_df = filtered_df.sort_values(
            'connections', ascending=False, kind='stable'
        ).head(max_count)

    return filtered_df

def render_table(entities):
    """Display extracted entities as an HTML table with a collapsible paper list per row.

    Rows are ordered by ``total_connections`` descending and numbered from 1.

    Args:
        entities: Iterable of dicts. Each dict must provide ``name``,
            ``context``, ``total_connections`` and ``cited_in`` (an iterable
            of paper ids).

    Returns:
        None. The table is rendered through ``IPython.display.display``.
    """
    # Per-call prefix so element ids stay unique when several tables are
    # rendered on the same page (getElementById returns the first match).
    table_uid = uuid.uuid4().hex[:8]

    # The collapsed state comes from the CSS class, so the element's inline
    # style.display is initially "" - toggle on "block" so the first click opens it.
    header = """
    <style>
        .collapse-content { display: none; margin-top: 5px; }
        .toggle-button { cursor: pointer; color: blue; text-decoration: underline; }
        th { text-align: left; }
    </style>
    <script>
        function toggleCollapse(id) {
            var x = document.getElementById(id);
            x.style.display = (x.style.display === "block") ? "none" : "block";
        }
    </script>
    <table border="1" style="border-collapse: collapse; width: 100%;">
        <thead>
            <tr>
                <th>#</th>
                <th>Name</th>
                <th>Context</th>
                <th>Total Connections</th>
                <th>Papers</th>
            </tr>
        </thead>
        <tbody>
    """

    ranked = sorted(entities, key=lambda e: e['total_connections'], reverse=True)

    rows = []
    for idx, entity in enumerate(ranked, start=1):
        collapse_id = f"collapse-{table_uid}-{idx}"

        # Entity fields come from an external service response: escape them so
        # markup characters in the text cannot alter or inject into the page.
        name = html.escape(str(entity['name']))
        context = html.escape(str(entity['context']))
        total = html.escape(str(entity['total_connections']))
        paper_ids = [html.escape(str(pid)) for pid in entity["cited_in"]]

        paper_links = "<br>".join(
            f'<a href="/paper/{pid}" target="_blank">{pid}</a>' for pid in paper_ids
        )
        rows.append(f"""
        <tr>
            <td>{idx}</td>
            <td>{name}</td>
            <td>{context}</td>
            <td>{total}</td>
            <td>
                <span class="toggle-button" onclick="toggleCollapse('{collapse_id}')">
                    Show Papers ({len(paper_ids)})
                </span>
                <div id="{collapse_id}" class="collapse-content">{paper_links}</div>
            </td>
        </tr>
        """)

    display(HTML(header + "".join(rows) + "</tbody></table>"))

###### Remove this part if you're getting json from pubtrends API

# Path to the analysis JSON file; replace with the actual path to your file
file_path = 'pubmed-drug-resistance-in-cancer.json'

# Load the file under its own name so it cannot be mistaken for the service
# response, which is stored in `data` further down.
with open(file_path, 'r', encoding='utf-8') as file:
    analysis_json = json.load(file)

ex = AnalysisData.from_json(analysis_json)

######

# Keep the best-connected papers (top 10% by neighbour count, at most 50)
highly_connected_df = filter_by_connectivity(
    ex.df,
    ex.papers_graph,
    percentile=90,
    max_count=50  # cap the result if it's too large
)

abstract_entries = highly_connected_df[['id', 'abstract']].to_dict(orient='records')

# Convert to formatted string for LLM
abstracts_json = json.dumps(abstract_entries, ensure_ascii=False, indent=2)

# Paper id -> neighbour count. Built before the request because it only
# depends on local data and later cells read it as well.
connections_by_pid = dict(zip(highly_connected_df['id'], highly_connected_df['connections']))

# 1. Set your function URL
FUNCTION_URL = "URL_OF_YOUR_FUNCTION_HERE"

# Upper bound (seconds) on waiting for the service, so the cell cannot hang forever.
# Increase it if extraction legitimately takes longer.
REQUEST_TIMEOUT_SECONDS = 600

# 2. System prompt enum (must match server-side allowed value), here are represented all types
si_mode = "GENES_EXTRACTION"
# si_mode = "SUBSTANCES_EXTRACTION"
# si_mode = "CONDITIONS_EXTRACTION"
# si_mode = "PROTEINS_EXTRACTION"

# 3. Make the POST request with abstracts and si_mode.
# `params` lets requests build the query string, which also works when
# FUNCTION_URL already carries its own query parameters.
# NOTE(review): `abstracts_json` is already a JSON string, so `json=` sends it
# as a JSON string literal; presumably the function expects exactly that - confirm.
response = requests.post(
    FUNCTION_URL,
    params={"si_mode": si_mode},
    json=abstracts_json,
    headers={"Content-Type": "application/json"},
    timeout=REQUEST_TIMEOUT_SECONDS,
)

# 4. Handle response
if response.status_code == 200:
    data = response.json()
    for entity in data:
        # Normalise a missing `cited_in` to an empty list so the table renderer
        # and later cells can rely on the key being present.
        cited_in = entity.setdefault("cited_in", [])
        entity["total_connections"] = sum(
            connections_by_pid.get(pid, 0) for pid in cited_in
        )

    print("✅ Entities Extracted:\n")
    render_table(data)
else:
    print(f"❌ Error: {response.status_code}")
    # Error bodies are not guaranteed to be JSON; print the raw text instead
    # of parsing it, which could raise and hide the actual error.
    print(response.text)

In [None]:
from collections import defaultdict
from itertools import combinations
from collections import Counter
import networkx as nx
import matplotlib.pyplot as plt

# Reverse mapping: paper id -> set of entity names mentioned in that paper
paper_to_entities = defaultdict(set)
for entity in data:
    # `cited_in` is treated as optional when the totals are computed, so
    # tolerate its absence here too instead of failing with KeyError.
    for pid in entity.get("cited_in", []):
        paper_to_entities[pid].add(entity["name"])

# Plain co-occurrence counts: number of papers mentioning both entities
entities_pair_weights = Counter()
# Co-occurrence weighted by the connectivity of each paper mentioning the pair
weighted_entities_pairs = defaultdict(float)

# Single pass over the papers fills both tallies
for pid, entities_set in paper_to_entities.items():
    weight = connections_by_pid.get(pid, 1)  # fallback to 1
    # Sorting gives every unordered pair one canonical (g1, g2) key
    for pair in combinations(sorted(entities_set), 2):
        entities_pair_weights[pair] += 1
        weighted_entities_pairs[pair] += weight

# Build the graph from the connectivity-weighted pairs
G = nx.Graph()
for (g1, g2), weight in weighted_entities_pairs.items():
    G.add_edge(g1, g2, weight=weight)

# Reset styling BEFORE the figure and axes are created so the white
# background applies to them (resets to white bg if you were in dark mode)
plt.style.use('default')
plt.rcParams['axes.facecolor'] = 'white'
plt.rcParams['figure.facecolor'] = 'white'

fig, ax = plt.subplots(figsize=(16, 16), facecolor='white')
pos = nx.spring_layout(G, k=13, iterations=300, seed=42)  # seeded: reproducible layout

edges = list(G.edges(data=True))
weights = [d['weight'] for (_, _, d) in edges]

# Draw nodes and edges
# NOTE(review): weights are sums of paper connectivity, not raw co-occurrence
# counts, so the 0.3 width scale and the "> 2" label cut-off below may be too
# generous for well-connected papers - confirm against real data.
nx.draw_networkx_nodes(G, pos, ax=ax, node_size=600, node_color="#6BAED6", edgecolors='black')
nx.draw_networkx_edges(G, pos, ax=ax,
                       edgelist=[(u, v) for u, v, _ in edges],
                       width=[w * 0.3 for w in weights],
                       alpha=0.5, edge_color="gray")
nx.draw_networkx_labels(G, pos, ax=ax, font_size=12, font_color='black')

# Optional edge labels for strong edges
edge_labels = {(u, v): f"{d['weight']}" for u, v, d in edges if d['weight'] > 2}
nx.draw_networkx_edge_labels(G, pos, ax=ax, edge_labels=edge_labels, font_size=8, font_color='gray')

ax.set_title("Gene-Gene Co-occurrence Graph", fontsize=16)
ax.axis('off')
fig.savefig("gene_graph.png", facecolor='white', dpi=300)
plt.show()