In [None]:
!pip install biopython

In [None]:
!pip install rapidfuzz

In [None]:
!pip install fuzzywuzzy

In [None]:
!pip install negspacy

In [None]:
!pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.1/en_core_sci_md-0.5.1.tar.gz --no-deps

In [None]:
import os, re, html, random, itertools
import time
import xml.etree.ElementTree as ET
from time import sleep
from Bio import Entrez
from collections import Counter
import spacy
from spacy.language import Language
from negspacy.negation import Negex
import numpy as np
import pandas as pd
import nltk
from nltk.tokenize import sent_tokenize
from tqdm import tqdm
import torch
from collections import defaultdict
from torch.utils.data import Dataset
#import matplotlib.pyplot as plt
import plotly.graph_objects as go
import networkx as nx
from transformers import (AutoTokenizer, AutoModel, AutoModelForTokenClassification,
    AutoModelForSequenceClassification, AutoConfig, pipeline,TrainingArguments,
    Trainer, DataCollatorWithPadding, DataCollatorForTokenClassification, AutoModelForCausalLM)
from datasets import Dataset, DatasetDict
from google.colab import files
from rapidfuzz import fuzz

In [None]:
Entrez.email = " "  # The author's address was removed; set your own email here before using the pipeline.
nltk.download('punkt_tab')  # presumably the Punkt data needed by sent_tokenize in the sentence-splitting cell

In [None]:
#from google.colab import files
uploaded = files.upload()


***
# Data Collection and Pre-processing
***

In [None]:
# Load KBs and seed pairs
# Basic KB/name lists
SEED_PAIRS_XLSX = "VHDictionary.xlsx"       # columns: virus, host

if not os.path.exists(SEED_PAIRS_XLSX):
    raise FileNotFoundError(f"Seed pairs file {SEED_PAIRS_XLSX} not found.")

df_kb = pd.read_excel(SEED_PAIRS_XLSX)
seed_list_interactions = list(df_kb[['virus', 'host']].itertuples(index=False, name=None))

# Each cell is split on commas into a list of lower-cased synonyms
# (despite the names, these are pandas Series of lists, not dicts).
virus_dict = df_kb["virus"].dropna().astype(str).str.lower().str.strip().str.split(",")
host_dict = df_kb["host"].dropna().astype(str).str.lower().str.strip().str.split(",")

# Flatten lists for direct lookup.
# Blank entries (from trailing or doubled commas) are dropped: an empty name
# would turn into the regex r"\b\b" in dictionary_match and match at every
# word boundary of every sentence.
virus_names = [name for sublist in virus_dict for name in (v.strip() for v in sublist) if name]
host_names = [name for sublist in host_dict for name in (h.strip() for h in sublist) if name]

START_YEAR = 1990          # lower bound of the publication-date filter
END_YEAR = 2025            # upper bound of the publication-date filter
MAX_RESULTS = 2000         # stop collecting once this many articles are stored
MAX_PER_PAIR  = 10         # retmax passed to each per-pair search
REQUEST_DELAY_S = 0.3      # pause between Entrez requests, in seconds

In [None]:
# File paths
SENTENCES_CSV = "processed_sentences.csv"
RE_RESULTS_CSV = "virus_host_relations.csv"
FINAL_RELATION_CSV = "final_relation_results.csv"

In [None]:
def _local(tag):
    # strip namespace: "{ns}name" -> "name"
    if "}" in tag:
        return tag.rsplit("}", 1)[1]
    return tag

def prune_jats_unwanted(root):
    """
    Remove unwanted JATS elements (figures, tables, formulas, supplements, graphics)
    and cross-refs to them, before extracting text.
    """
    banned_tags = {
        "fig", "fig-group", "table", "table-wrap", "table-wrap-foot", "graphic",
        "media", "supplementary-material", "disp-formula", "chem-struct-wrap",
        "boxed-text"
    }
    banned_xref_types = {"fig", "table", "supplementary-material"}

    # Remove unwanted element subtrees
    for parent in list(root.iter()):
        # iterate over a copy of children so we can remove safely
        for child in list(parent):
            lt = _local(child.tag)
            if lt in banned_tags:
                parent.remove(child)
                continue
            # remove xref elements that point to figs/tables/supplementary
            if lt == "xref":
                rt = (child.attrib.get("ref-type") or "").strip().lower()
                if rt in banned_xref_types:
                    parent.remove(child)


# Clean the Abstract/Full Text papers

def clean_text(text, lowercase=True):
    """
    Clean PMC text after XML pruning. Also strips stray inline references to
    figures/tables and typical artifacts.

    Args:
        text: Raw article text. Anything that is not a string (e.g. None or
            NaN from a DataFrame cell) is treated as empty.
        lowercase: When True (default) the cleaned text is lower-cased.

    Returns:
        The cleaned, whitespace-collapsed, ASCII-only string ("" for empty input).
    """
    if not isinstance(text, str):
        return ""

    text = html.unescape(text)
    # remove citation brackets like [1], [1,2], [1-3] and en/em-dash ranges
    # such as [1\u20133]. The dashes are written as escapes so the pattern does
    # not depend on how this file is encoded; this must run before the
    # non-ASCII scrub below turns the dash into a space.
    text = re.sub(r"\[(?:\s*\d+\s*(?:[,\u2013\u2014-]\s*\d+)?\s*)(?:,\s*\d+\s*)*\]", " ", text)
    # remove any parenthetical that contains a 19xx/20xx year (e.g., "(Smith 2019)")
    text = re.sub(r"\(([^()]*\b(19|20)\d{2}[^()]*)\)", " ", text)
    # remove inline "Fig. 1", "Figure S2a", "Tables 3", etc.
    text = re.sub(r"\b(fig(?:\.|ures?)?|figure(?:s)?)\s*[s]?\d+[A-Za-z\-]*\b", " ", text, flags=re.I)
    text = re.sub(r"\b(table(?:s)?)\s*[s]?\d+[A-Za-z\-]*\b", " ", text, flags=re.I)
    text = re.sub(r"\b(supplementary(?:\s+figure|\s+table)?)\s*[s]?\d+[A-Za-z\-]*\b", " ", text, flags=re.I)
    # scrub URLs/emails/DOIs
    text = re.sub(r"http\S+|www\.\S+|doi:\S+|\S+@\S+", " ", text, flags=re.I)
    # strip any leftover angle-bracket tags just in case
    text = re.sub(r"<[^>]+>", " ", text)
    # non-ascii -> space
    text = re.sub(r"[^\x00-\x7F]+", " ", text)
    # collapse whitespace
    text = re.sub(r"\s+", " ", text).strip()

    if lowercase:
        text = text.lower()
    return text

def dedup(seq):
    """Return the items of ``seq`` as a list, in first-seen order, without repeats."""
    # dict keys are unique and keep insertion order, which is exactly the
    # "first occurrence wins" behaviour needed here.
    return list(dict.fromkeys(seq))

virus_names = dedup(virus_names)
host_names  = dedup(host_names)

In [None]:
def make_pair_query(virus, host, start_year=START_YEAR, end_year=END_YEAR):
    """Build the Entrez query string for one virus/host pair.

    Both names are searched as quoted phrases tagged [TIAB], and the result is
    restricted to the [PDAT] range start_year..end_year. Double quotes inside a
    name are dropped so they cannot break the phrase quoting.
    """
    virus_term, host_term = (name.replace('"', '') for name in (virus, host))
    clauses = (
        f'"{virus_term}"[TIAB]',
        f'"{host_term}"[TIAB]',
        f'("{start_year}"[PDAT] : "{end_year}"[PDAT])',
    )
    return " AND ".join(clauses)

def pmc_esearch(query, retmax):
    """Search PMC and return a list of PMCID strings.

    Args:
        query: Entrez query string.
        retmax: Maximum number of IDs to request.

    Returns:
        The "IdList" of the search result, or an empty list when the request or
        the parsing fails (the error is printed, not raised).
    """
    try:
        handle = Entrez.esearch(db="pmc", term=query, retmax=retmax)
        try:
            result = Entrez.read(handle)
        finally:
            # close the handle even when parsing the response raises
            handle.close()
        return result.get("IdList", [])
    except Exception as e:
        print(f"[esearch] {e} for query: {query[:200]}...")
        return []

def fetch_article_info(pmc_id):
    """
    Fetch PMC XML and extract:
      - PMCID, Title, Journal, Year
      - FullText: Abstract + Results + Discussion or Results and Discussion

    Text is taken from <p> elements of the abstract(s) and of every body
    section whose title contains "results" or "discussion". Paragraphs that sit
    inside figure / table / supplementary / boxed-text containers (captions,
    footnotes) are skipped, and each paragraph is collected only once even when
    a matching section is nested inside another matching section.

    Returns:
        dict with keys PMCID, Title, Journal, Year, FullText, or None when the
        download/parsing fails or the response contains no <article>.
    """
    # Containers whose <p> descendants are captions or notes, not running text.
    non_prose_tags = {
        "fig", "fig-group", "table", "table-wrap", "table-wrap-foot",
        "supplementary-material", "boxed-text",
    }
    # A title containing either key is kept; this also covers merged
    # "results and discussion" sections.
    keep_keys = ("results", "discussion")

    try:
        handle = Entrez.efetch(db="pmc", id=str(pmc_id), rettype="full", retmode="xml")
        try:
            xml_data = handle.read()
        finally:
            # release the connection even when reading fails
            handle.close()
        root = ET.fromstring(xml_data)

        article = root.find(".//article")
        if article is None and root.tag == "article":
            # ".//article" only searches descendants; accept a bare <article> root too
            article = root
        if article is None:
            return None

        # Basic metadata
        pmcid = str(pmc_id)
        title_elem = article.find(".//article-title")
        title = "".join(title_elem.itertext()).strip() if title_elem is not None else ""

        journal_elem = article.find(".//journal-title")
        journal = "".join(journal_elem.itertext()).strip() if journal_elem is not None else ""

        # First <pub-date> that carries a <year>
        year = ""
        for pub_date in article.findall(".//pub-date"):
            y = pub_date.findtext("year")
            if y:
                year = y
                break

        # <p> elements living inside figure/table-like containers
        excluded = set()
        for elem in article.iter():
            if elem.tag in non_prose_tags:
                excluded.update(elem.iter("p"))

        # Collect text ONLY from: abstract + results + discussion + results and discussion
        text_parts = []
        collected = set()  # paragraphs already taken (guards against nested-section repeats)

        def _collect(container):
            for p in container.findall(".//p"):
                if p in collected or p in excluded:
                    continue
                collected.add(p)
                t = " ".join(p.itertext()).strip()
                if t:
                    text_parts.append(t)

        # Abstracts
        for abs_elem in article.findall(".//abstract"):
            _collect(abs_elem)

        # Body sections
        for sec in article.findall(".//body//sec"):
            # section title (case-insensitive; robust to missing)
            sec_title = ""
            sec_title_elem = sec.find("title")
            if sec_title_elem is not None:
                sec_title = "".join(sec_title_elem.itertext()).strip().lower()

            if any(k in sec_title for k in keep_keys):
                _collect(sec)

        full_text = " ".join(text_parts).strip()
        return {
            "PMCID": pmcid,
            "Title": title,
            "Journal": journal,
            "Year": year,
            "FullText": full_text
        }
    except Exception as e:
        print(f"[efetch PMC{pmc_id}] {e}")
        return None


# Balanced retrieval per pair

pairs = list(itertools.product(virus_names, host_names))
# fairness: mix across all pairs. A dedicated, seeded generator keeps the
# visiting order (and therefore the collected corpus) the same on every run;
# 42 is the same value the notebook later uses as SEED for training.
random.Random(42).shuffle(pairs)

seen_pmcids = set()    # every PMCID that was fetched successfully, with or without usable text
articles_data = []     # article dicts that actually contain text

for virus, host in pairs:
    if len(articles_data) >= MAX_RESULTS:
        break

    q = make_pair_query(virus, host)
    pmc_ids = pmc_esearch(q, retmax=MAX_PER_PAIR)
    # Small pause to respect NCBI rate limits
    sleep(REQUEST_DELAY_S)

    if not pmc_ids:
        continue

    for pmcid in pmc_ids:
        if pmcid in seen_pmcids:
            continue
        art = fetch_article_info(pmcid)
        sleep(REQUEST_DELAY_S)
        if not art:
            # failed download/parse: not remembered, so a later pair may retry it
            continue

        # Remember the ID before looking at the text, so an article without
        # usable sections is not downloaded again for every pair that returns it.
        seen_pmcids.add(pmcid)
        if not art.get("FullText"):
            continue
        articles_data.append(art)

        if len(articles_data) >= MAX_RESULTS:
            break

print(f"Collected {len(articles_data)} unique PMC articles!")

if not articles_data:
    raise RuntimeError("No PMC articles retrieved!.")

df_articles = pd.DataFrame(articles_data)


# Clean + sentence-split (on the reduced text)
# clean_text already collapses whitespace and lower-cases the text, so no
# further normalisation is applied before tokenizing. (An earlier version also
# tried to shield single capital initials with a "<prd>" placeholder, but on
# lower-cased text that pattern never matched and the placeholder was restored
# before tokenizing, so it had no effect and has been removed.)
SENTENCE_COLUMNS = ["PMCID", "Title", "Year", "Sentence"]

processed = []
for _, row in df_articles.iterrows():
    ft = clean_text(row.get("FullText", ""), lowercase=True)
    if not ft:
        continue

    article_meta = {
        "PMCID": row.get("PMCID", ""),
        "Title": row.get("Title", ""),
        "Year": row.get("Year", ""),
    }
    for s in sent_tokenize(ft):
        s = s.strip()
        if s:
            processed.append({**article_meta, "Sentence": s})

# Passing the columns explicitly keeps the expected schema even when no
# sentence survived cleaning.
df_sentences = pd.DataFrame(processed, columns=SENTENCE_COLUMNS)
df_sentences.to_csv(SENTENCES_CSV, index=False)
print(f"Saved sentences to {SENTENCES_CSV}")
print("Total sentences loaded:", len(df_sentences))


***
# Model Development 
***

In [None]:
# NER pipeline
# Builds a token-classification pipeline directly from the base checkpoint
# named in `model_name`, on GPU 0 when CUDA is available and on CPU otherwise.
# NOTE(review): `tokenizer`, `model` and `ner_pipeline` are all re-assigned in
# later cells (the fine-tuning setup and "Load fine-tuned model") before
# `ner_pipeline` is first called, so the objects created here appear to be
# unused - confirm whether this cell can be dropped.
model_name = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)
ner_pipeline = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple", device=0 if torch.cuda.is_available() else -1)

In [None]:
# Load seed dataset
df = pd.read_excel("trainingdataset.xlsx")

# Build Marked_Sentence by tagging virus and host mentions
def mark_entities(row):
    """Wrap the first virus and host mention of a training row in markup.

    Args:
        row: Mapping-like object (e.g. a DataFrame row) with a "Sentence" entry
            and optional "virus" / "host" entries.

    Returns:
        The sentence with the first occurrence of the virus wrapped in
        [VIRUS]...[/VIRUS] and the first host occurrence that does not overlap
        the virus mention wrapped in [HOST]...[/HOST]. Missing, empty or
        not-found entities leave the sentence untouched.

    Both mentions are located in the original sentence and the tags are
    inserted afterwards, so a host that is part of the virus name (e.g.
    "tobacco" in "tobacco mosaic virus") can never end up nested inside the
    virus tag.
    """
    sent = str(row["Sentence"])

    def _mention(key):
        value = row.get(key, None)
        if pd.isna(value):
            return ""
        return str(value).strip()

    virus = _mention("virus")
    host = _mention("host")

    spans = []  # (start, end, tag) in coordinates of the unmodified sentence

    v_start = sent.find(virus) if virus else -1
    if v_start >= 0:
        spans.append((v_start, v_start + len(virus), "VIRUS"))

    if host:
        h_start = sent.find(host)
        while h_start >= 0:
            h_end = h_start + len(host)
            overlaps = any(h_start < end and start < h_end for start, end, _ in spans)
            if not overlaps:
                spans.append((h_start, h_end, "HOST"))
                break
            h_start = sent.find(host, h_start + 1)

    # Insert from right to left so earlier offsets stay valid.
    for start, end, tag in sorted(spans, reverse=True):
        sent = f"{sent[:start]}[{tag}]{sent[start:end]}[/{tag}]{sent[end:]}"

    return sent

df["Marked_Sentence"] = df.apply(mark_entities, axis=1)

In [None]:
# Training hyperparams
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

MODEL_NAME = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
FINETUNED_DIR = "./ner_finetuned_lora"

In [None]:
# Extract tokens + BIO labels

def sentence_to_bio(marked_sentence):
    """
    Turn a sentence carrying [VIRUS]...[/VIRUS] / [HOST]...[/HOST] markup into
    whitespace tokens and their BIO labels.

    Returns:
        (tokens, labels): two parallel lists. The first token after an opening
        marker is labelled B-<TYPE>, following tokens I-<TYPE> until the
        closing marker; everything else is "O". The markers themselves are not
        emitted as tokens.
    """
    opening = {"[VIRUS]": "B-VIRUS", "[HOST]": "B-HOST"}
    closing = ("[/VIRUS]", "[/HOST]")

    # Pad every marker with spaces so it becomes its own whitespace token.
    spaced = marked_sentence
    for marker in ("[VIRUS]", "[/VIRUS]", "[HOST]", "[/HOST]"):
        spaced = spaced.replace(marker, f" {marker} ")

    tokens, labels = [], []
    state = "O"
    for piece in spaced.split():
        if piece in opening:
            state = opening[piece]
        elif piece in closing:
            state = "O"
        else:
            tokens.append(piece)
            labels.append(state)
            # after the first token of an entity, continue with the I- label
            if state.startswith("B-"):
                state = "I-" + state[2:]

    return tokens, labels


df["tokens_labels"] = df["Marked_Sentence"].apply(sentence_to_bio)
df["tokens"] = df["tokens_labels"].apply(lambda x: x[0])
df["ner_tags"] = df["tokens_labels"].apply(lambda x: x[1])

In [None]:
dataset = Dataset.from_pandas(df[["tokens", "ner_tags"]])

# Label mappings
label_list = ["O", "B-VIRUS", "I-VIRUS", "B-HOST", "I-HOST"]
label_to_id = {l: i for i, l in enumerate(label_list)}
id_to_label = {i: l for l, i in label_to_id.items()}

dataset = dataset.map(lambda ex: {"ner_tags": [label_to_id[l] for l in ex["ner_tags"]]})

# Tokenizer + alignment
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def tokenize_and_align_labels(examples):
    """Tokenize a batch of pre-split sentences and align word labels to sub-tokens.

    Only the first sub-token of each word keeps the word's label id; special
    tokens, padding and the remaining sub-tokens of a word get -100. The
    aligned labels are stored under "labels" in the returned encoding.
    """
    encoded = tokenizer(
        examples["tokens"],
        is_split_into_words=True,
        truncation=True,
        max_length=256,
        padding="max_length",
    )

    all_labels = []
    for row_idx, word_tags in enumerate(examples["ner_tags"]):
        word_ids = encoded.word_ids(batch_index=row_idx)
        # word id of the preceding position (None before the first one)
        previous_ids = [None] + word_ids[:-1]
        all_labels.append([
            -100 if (wid is None or wid == prev) else word_tags[wid]
            for wid, prev in zip(word_ids, previous_ids)
        ])

    encoded["labels"] = all_labels
    return encoded

dataset = dataset.map(tokenize_and_align_labels, batched=True)


In [None]:
# Fine-tuning the Model

# The label names are stored in the model config. Without id2label/label2id the
# config only knows generic names ("LABEL_0", "LABEL_1", ...), the saved model
# would report those names at inference time, and the later checks for
# "VIRUS" / "HOST" in the predicted label would never match.
model = AutoModelForTokenClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(label_list),
    id2label=id_to_label,
    label2id=label_to_id,
    ignore_mismatched_sizes=True,
)
data_collator = DataCollatorForTokenClassification(tokenizer)

# Training arguments
training_args = TrainingArguments(
    output_dir=f"{FINETUNED_DIR}/final_model",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    num_train_epochs=4,
    weight_decay=0.01,
    logging_dir=f"{FINETUNED_DIR}/logs/final",
    logging_steps=50,
    seed=SEED,
    report_to="none"
)

# The whole seed dataset is used for training; no evaluation split is configured.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

In [None]:
# Train on seed dataset
trainer.train()

# Save final model
trainer.save_model(f"{FINETUNED_DIR}/final_model")

In [None]:
# Load fine-tuned model
model_path = f"{FINETUNED_DIR}/final_model"
tokenizer = AutoTokenizer.from_pretrained(model_path, model_max_length=512)
# Pass the label names explicitly so predictions come back as B-/I-VIRUS and
# B-/I-HOST even when the saved config only carries generic "LABEL_n" names.
model = AutoModelForTokenClassification.from_pretrained(
    model_path,
    id2label=id_to_label,
    label2id=label_to_id,
)

# Create NER pipeline using fine-tuned model.
# aggregation_strategy="simple" is the current spelling of the deprecated
# grouped_entities=True, and the device selection matches the earlier pipeline
# cell (GPU 0 when CUDA is available, otherwise CPU).
ner_pipeline = pipeline(
    "ner",
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy="simple",
    device=0 if torch.cuda.is_available() else -1,
)

In [None]:
# Helper functions

def mark_entities_safe(sentence, virus, host):
    """
    Wrap every mention of ``virus`` and ``host`` in ``sentence`` with
    [VIRUS]...[/VIRUS] and [HOST]...[/HOST].

    Matching is case-insensitive, keeps the sentence's own spelling, and only
    accepts mentions that are not glued to a neighbouring word character. Both
    names are searched in a single pass (longer name first), so a mention is
    never tagged twice and inserted markup is never re-scanned - a host that is
    a substring of the virus name cannot end up nested inside the virus tag.

    Returns:
        The marked sentence, or None if either name is missing/blank, both
        names are identical (ignoring case), or either one has no stand-alone
        mention in the sentence.
    """
    if not virus or not host:
        return None
    virus_name, host_name = virus.strip(), host.strip()
    if not virus_name or not host_name or virus_name.lower() == host_name.lower():
        return None

    # Longer name first: when both could start at the same position the more
    # specific mention wins.
    candidates = sorted(
        [("VIRUS", virus_name), ("HOST", host_name)],
        key=lambda item: len(item[1]),
        reverse=True,
    )
    # (?<!\w) / (?!\w) act like \b but also work for names that begin or end
    # with punctuation, e.g. "influenza a (h1n1)".
    pattern = re.compile(
        "|".join(rf"(?P<{tag}>(?<!\w){re.escape(name)}(?!\w))" for tag, name in candidates),
        flags=re.IGNORECASE,
    )

    tags_found = set()

    def _wrap(match):
        tag = match.lastgroup
        tags_found.add(tag)
        return f"[{tag}]{match.group(0)}[/{tag}]"

    marked = pattern.sub(_wrap, sentence)
    if tags_found != {"VIRUS", "HOST"}:
        return None  # at least one entity has no stand-alone mention

    return marked


def dictionary_match(sentence):
    """Find dictionary virus and host names in a sentence.

    The sentence is lower-cased and every name from ``virus_names`` /
    ``host_names`` is searched as a whole word.

    Returns:
        (virus_matches, host_matches): two lists of dicts with the dictionary
        name under "text" and the character offsets of the hit under "start"
        and "end".
    """
    lowered = sentence.lower()

    def _scan(names):
        return [
            {"text": name, "start": hit.start(), "end": hit.end()}
            for name in names
            for hit in re.finditer(rf"\b{re.escape(name)}\b", lowered)
        ]

    return _scan(virus_names), _scan(host_names)


In [None]:
# Negation Detection
# Load spaCy model
nlp = spacy.load("en_core_sci_md")

# Load Excel file
df_neg_rules = pd.read_excel("customnegexrules.xlsx",sheet_name = "rules")
df_relations = pd.read_excel("customnegexrules.xlsx", sheet_name="relations")

# Convert DataFrame into dictionary
custom_rules = {
    col: [str(x) for x in df_neg_rules[col].dropna().tolist()]
    for col in df_neg_rules.columns
}

# Add entity ruler with host/virus dictionary
patterns = [{"label": "HOST", "pattern": name} for name in host_names] + \
           [{"label": "VIRUS", "pattern": name} for name in virus_names]


if "entity_ruler" not in nlp.pipe_names:
    ruler = nlp.add_pipe("entity_ruler", before="ner")
    ruler.add_patterns(patterns)

# Register custom Negex
# The component is added under its factory name "negex", so that is the name
# to look for; checking for any other name would try to add it a second time
# (and fail) whenever this cell is re-run on the same `nlp` object.
if "negex" not in nlp.pipe_names:
    nlp.add_pipe(
        "negex",
        config={
            "ent_types": ["VIRUS", "HOST"],
            "neg_termset": custom_rules,
            "extension_name": "is_negated"
        },
        last=True
    )

relation_negations = set(df_relations["relation_negations"].dropna().tolist())
relation_verbs = set(df_relations["relation_verbs"].dropna().tolist())
medium_verbs = set(df_relations["medium_verbs"].dropna().tolist())

# Relation verbs
relation_single = {v for v in relation_verbs if " " not in v}
relation_multi = {v.lower() for v in relation_verbs if " " in v}
relation_single_lemmas = {nlp(v)[0].lemma_.lower() for v in relation_single}

# Medium verbs
medium_single = {v for v in medium_verbs if " " not in v}
medium_multi = {v.lower() for v in medium_verbs if " " in v}
medium_single_lemmas = {nlp(v)[0].lemma_.lower() for v in medium_single}

# Precompile negation patterns for efficiency
_relation_neg_patterns = [re.compile(rf"\b{re.escape(neg)}\b") for neg in relation_negations]
_termination_pattern = re.compile(rf"\b({'|'.join(map(re.escape, custom_rules['termination']))})\b")


def is_relation_negated(doc, virus, host, relation_negations=None, termination=None):
    """
    Detect if a virus-host relation is negated.
    Uses regex, dependency parse, and termination cues.

    Args:
        doc: spaCy Doc of the sentence(s) to inspect.
        virus, host: Entity strings; compared case-insensitively as substrings.
        relation_negations: Optional iterable of negation cue phrases. When
            omitted, the module-level precompiled ``_relation_neg_patterns``
            (built from the spreadsheet's relation negations) are used.
        termination: Optional iterable of clause-terminating cues. When omitted
            or empty, the module-level ``_termination_pattern`` is used.

    Returns:
        True if some clause contains both entities, a relation verb whose
        subtree covers both entities, and either a negation cue in the clause
        or a "neg" dependency child on that verb; False otherwise.
    """
    # Negation cues. The parameter shadows the module-level set of the same
    # name, so the default goes through the precompiled patterns instead.
    if relation_negations is None:
        # NOTE(review): these patterns are case-sensitive and are applied to
        # lower-cased text - assumes the spreadsheet cues are lower-case.
        neg_patterns = _relation_neg_patterns
    else:
        neg_patterns = [
            re.compile(rf"\b{re.escape(str(neg))}\b", flags=re.IGNORECASE)
            for neg in relation_negations
        ]

    # Clause splitter, built once per call rather than once per sentence.
    if termination:
        clause_splitter = re.compile(
            rf"\b({'|'.join(map(re.escape, termination))})\b"
        )
    else:
        clause_splitter = _termination_pattern

    virus_lc, host_lc = virus.lower(), host.lower()

    for sent in doc.sents:
        sent_text = sent.text.lower()

        # Skip sentences missing virus/host
        if virus_lc not in sent_text or host_lc not in sent_text:
            continue

        # Split sentence into smaller clauses at termination cues
        for clause in clause_splitter.split(sent_text):
            clause = clause.strip()
            # Check only clauses containing both entities
            if not clause or virus_lc not in clause or host_lc not in clause:
                continue

            clause_doc = nlp(clause)

            # Relation verbs are recognised by lemma, the same lemma set the
            # confidence scoring uses (the raw spreadsheet entries may be
            # inflected or multi-word and would not equal a token lemma).
            relation_verbs_in_clause = [
                tok for tok in clause_doc
                if tok.lemma_.lower() in relation_single_lemmas
            ]
            if not relation_verbs_in_clause:
                continue

            cue_in_clause = any(p.search(clause) for p in neg_patterns)

            for verb in relation_verbs_in_clause:
                # Ensure virus and host appear in verb subtree
                subtree_text = " ".join(t.text.lower() for t in verb.subtree)
                if virus_lc not in subtree_text or host_lc not in subtree_text:
                    continue

                # Regex-based negation matches
                if cue_in_clause:
                    return True

                # Dependency-based: negation directly on relation verb
                if any(child.dep_ == "neg" for child in verb.children):
                    return True

    return False


label_map = {"B-VIRUS": "VIRUS","I-VIRUS": "VIRUS","B-HOST": "HOST",
              "I-HOST": "HOST", "O": "O","VIRUS": "VIRUS","HOST": "HOST"}


In [None]:
# Initialize results

RESULT_COLUMNS = [
    "PMCID", "Sentence", "Marked_Sentence", "Virus_Entity", "Host_Entity",
    "Negated", "Confidence", "Relation_Type", "Relation_Status",
]

relation_results = []
seen_pairs_by_paper = {}                      # PMCID -> pair keys already emitted for that paper
pair_counts_by_paper = defaultdict(Counter)   # PMCID -> pair key -> number of sentences with both entities
pending_records = []                          # (record, components, pair_key) waiting for the frequency score


def _entity_surface(ent, sentence):
    """Text of a pipeline entity, cut from the sentence when offsets are available.

    The pipeline's "word" is tokenizer-decoded text and may not be a literal
    substring of the sentence, which would make the later marking step fail;
    it is only used as a fallback when no offsets are reported.
    """
    start, end = ent.get("start"), ent.get("end")
    if start is not None and end is not None:
        text = sentence[start:end]
    else:
        text = ent.get("word", "").replace("##", "")
    return re.sub(r'\s+', ' ', text).strip()


def _distance_score(doc, v, h):
    """Score the token distance between two mentions (1.0 / 0.7 / 0.3, 0.0 if unaligned)."""
    v_span = doc.char_span(v["start"], v["end"])
    h_span = doc.char_span(h["start"], h["end"])
    if not (v_span and h_span):
        return 0.0
    token_distance = abs(v_span.start - h_span.start)
    sentence_len = len(doc)
    if token_distance <= max(12, sentence_len * 0.03):
        return 1.0
    if token_distance <= max(20, sentence_len * 0.06):
        return 0.7
    return 0.3


def _verb_score(doc):
    """1.0 if the sentence has a relation verb/phrase, 0.6 for a medium one, else 0.0."""
    tokens_lemmas = {t.lemma_.lower() for t in doc}
    doc_text = doc.text.lower()
    if tokens_lemmas & relation_single_lemmas or any(phrase in doc_text for phrase in relation_multi):
        return 1.0
    if tokens_lemmas & medium_single_lemmas or any(phrase in doc_text for phrase in medium_multi):
        return 0.6
    return 0.0


def _classify_relation(confidence, negated):
    """Map a confidence value and negation flag to (Relation_Type, Relation_Status)."""
    if confidence < 0.45:
        return "No Relation", "N/A"
    if confidence < 0.55:
        return "Uncertain", "Uncertain"
    return "Related", ("Fail" if negated else "Successful")


# Main processing loop
for idx, row in tqdm(df_sentences.iterrows(), total=len(df_sentences)):
    sentence = str(row["Sentence"])
    pmcid = row.get("PMCID", "")

    try:

        # Extract entities
        virus_entities, host_entities = [], []

        for ent in ner_pipeline(sentence):
            raw_label = ent.get("entity_group", ent.get("entity", "O"))
            label = label_map.get(raw_label, raw_label).upper()
            mention = {
                "text": _entity_surface(ent, sentence),
                "start": ent.get("start"),
                "end": ent.get("end"),
                "source": "ner",
            }

            if "VIRUS" in label:
                virus_entities.append(mention)
            elif "HOST" in label:
                host_entities.append(mention)

        # Add dictionary matches
        dict_viruses, dict_hosts = dictionary_match(sentence)
        for v in dict_viruses: v["source"] = "dict"
        for h in dict_hosts: h["source"] = "dict"
        virus_entities.extend(dict_viruses)
        host_entities.extend(dict_hosts)

        if not virus_entities or not host_entities:
            continue

        doc = nlp(sentence)
        spacy_entities = {(ent.start_char, ent.end_char): ent
                          for ent in doc.ents if ent.label_ in ["VIRUS", "HOST"]}

        # Verb evidence depends only on the sentence, so compute it once here.
        verb_score = _verb_score(doc)

        paper_seen = seen_pairs_by_paper.setdefault(pmcid, set())
        counted_in_sentence = set()   # each pair counts at most once per sentence

        # Pairwise analysis
        for v in virus_entities:
            for h in host_entities:
                pair_key = (v["text"].lower(), h["text"].lower())

                # Count the co-occurrence before the duplicate check, so the
                # frequency score reflects every sentence of the paper.
                if pair_key not in counted_in_sentence:
                    counted_in_sentence.add(pair_key)
                    pair_counts_by_paper[pmcid][pair_key] += 1

                if pair_key in paper_seen:
                    continue
                paper_seen.add(pair_key)

                # Negation detection: a negated spaCy entity overlapping either
                # mention, or a negated relation between the two.
                negated_flag = False
                for (start_char, end_char), spa_ent in spacy_entities.items():
                    if (start_char <= v["start"] < end_char or start_char < v["end"] <= end_char) or \
                       (start_char <= h["start"] < end_char or start_char < h["end"] <= end_char):
                        if getattr(spa_ent._, "is_negated", False):
                            negated_flag = True
                            break
                if is_relation_negated(doc, v["text"], h["text"]):
                    negated_flag = True

                # Confidence components known at this point
                # (the frequency component is added after the loop).
                v_conf = 1.0 if v["source"] == "ner" else 0.6
                h_conf = 1.0 if h["source"] == "ner" else 0.6
                components = [
                    (v_conf + h_conf) / 2,        # source reliability
                    _distance_score(doc, v, h),   # distance factor
                    verb_score,                   # verb evidence
                ]

                marked_sentence = mark_entities_safe(sentence, v["text"], h["text"])
                if marked_sentence:
                    record = {
                        "PMCID": pmcid,
                        "Sentence": sentence,
                        "Marked_Sentence": marked_sentence,
                        "Virus_Entity": v["text"],
                        "Host_Entity": h["text"],
                        "Negated": negated_flag,
                    }
                    pending_records.append((record, components, pair_key))

    except Exception as e:
        print(f"Error processing row {idx}: {e}")
        continue

# Finalize: add the frequency bonus now that every sentence has been counted,
# then combine the components and classify.
for record, components, pair_key in pending_records:
    pair_count = pair_counts_by_paper[record["PMCID"]][pair_key]
    components.append(min(pair_count / 5, 1.0))

    raw_score = sum(components) / len(components)
    confidence = max(0.0, min(1.0, raw_score ** 0.5))  # sqrt boost

    relation_type, relation_status = _classify_relation(confidence, record["Negated"])
    record["Confidence"] = round(confidence, 2)
    record["Relation_Type"] = relation_type
    record["Relation_Status"] = relation_status
    relation_results.append(record)

# Convert to DataFrame (explicit columns keep the schema when nothing was found)
df_relation_results = pd.DataFrame(relation_results, columns=RESULT_COLUMNS)
print("Total RE candidate pairs created:", len(df_relation_results))
print("Unique virus-host pairs:", df_relation_results[["Virus_Entity", "Host_Entity"]].drop_duplicates().shape[0])
df_relation_results.head()


In [None]:
# Results normalization

def _canonical_map(column):
    """Map every lower-cased synonym in ``df_kb[column]`` to the first name of its row."""
    mapping = {}
    for names in df_kb[column].dropna().astype(str).str.lower().str.split(","):
        head = names[0].strip()
        mapping.update({alias.strip(): head for alias in names})
    return mapping


# Build mapping
virus_map = _canonical_map("virus")
host_map = _canonical_map("host")

# Normalize case, then map to canonical names (unknown names stay as they are)
for column, mapping in (("Virus_Entity", virus_map), ("Host_Entity", host_map)):
    lowered = df_relation_results[column].str.lower().str.strip()
    df_relation_results[column] = lowered.map(mapping).fillna(lowered)

# Deduplicate by keeping only the row with the highest Confidence
best_row_labels = df_relation_results.groupby(
    ["PMCID", "Sentence", "Virus_Entity", "Host_Entity"]
)["Confidence"].idxmax()
df_relation_results = df_relation_results.loc[best_row_labels].reset_index(drop=True)

In [None]:
df_relation_results.to_csv(FINAL_RELATION_CSV, index=False)

***
# Visualization of Virus-Host Networks
***

In [None]:

# Filter edges: keep only "Successful" relations
successful_edges = df_relation_results[df_relation_results["Relation_Status"] == "Successful"]

# Build the full graph G
G = nx.Graph()
viruses = set(successful_edges["Virus_Entity"])
hosts = set(successful_edges["Host_Entity"])

# Nodes are added in sorted order so ties in the degree ranking below are
# broken the same way on every run. A name that occurs as both virus and host
# ends up in the 'host' group because that group is assigned last.
G.add_nodes_from(sorted(viruses), group='virus')
G.add_nodes_from(sorted(hosts), group='host')
for _, row in successful_edges.iterrows():
    G.add_edge(row["Virus_Entity"], row["Host_Entity"], status=row["Relation_Status"])


def _top_by_degree(graph, group, limit):
    """Return the ``limit`` highest-degree (node, degree) pairs of one node group."""
    members = [(node, degree) for node, degree in graph.degree
               if graph.nodes[node]['group'] == group]
    return sorted(members, key=lambda x: x[1], reverse=True)[:limit]


# Filter G to create a smaller graph H with only top nodes.
# Each ranking is restricted to its own group; ranking all nodes twice would
# make the "virus" list a mere prefix of the "host" list.
top_viruses = _top_by_degree(G, 'virus', 60)
top_hosts = _top_by_degree(G, 'host', 80)

top_nodes = [node for node, degree in top_viruses] + [node for node, degree in top_hosts]
H = G.subgraph(top_nodes)

# Update node sets to reflect the smaller graph H
viruses = {n for n in H.nodes() if H.nodes[n]['group'] == 'virus'}
hosts = {n for n in H.nodes() if H.nodes[n]['group'] == 'host'}

# Use H for plotting: viruses on an inner circle, hosts on an outer circle
inner_radius = 0.5
outer_radius = 2.0

pos = {}
for i, virus in enumerate(sorted(viruses)):
    angle = 2 * np.pi * i / len(viruses)
    pos[virus] = (inner_radius * np.cos(angle), inner_radius * np.sin(angle))

for i, host in enumerate(sorted(hosts)):
    angle = 2 * np.pi * i / len(hosts)
    pos[host] = (outer_radius * np.cos(angle), outer_radius * np.sin(angle))

# One polyline per edge; the None entries break the line between edges
edge_x, edge_y = [], []
for edge in H.edges(data=True):
    x0, y0 = pos[edge[0]]
    x1, y1 = pos[edge[1]]
    edge_x.extend([x0, x1, None])
    edge_y.extend([y0, y1, None])

edge_trace = go.Scatter(
    x=edge_x, y=edge_y,
    line=dict(width=0.5, color='green'),
    hoverinfo='none',
    mode='lines',
    name='Successful Relations'
)

def create_node_trace(nodes, color, font_size):
    """Build a labelled marker trace for ``nodes``; marker size is 5 x degree in H."""
    node_x, node_y, node_text, node_size = [], [], [], []
    for node in nodes:
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_text.append(node)
        node_size.append(H.degree(node) * 5)

    return go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        text=node_text,
        textposition="top center",
        hoverinfo='text',
        marker=dict(size=node_size, color=color, line_width=2),
        name=f"{color.title()} Nodes",
        textfont=dict(size=font_size)
    )

virus_trace = create_node_trace(sorted(viruses), 'gray', font_size=12)
host_trace = create_node_trace(sorted(hosts), 'blue', font_size=8)

fig = go.Figure(data=[edge_trace, virus_trace, host_trace],
                layout=go.Layout(
                    # title given as a dict: the flat `titlefont_size` shortcut
                    # is a deprecated layout key
                    title=dict(text='🧫 Virus–Host Network', font=dict(size=22)),
                    width=1400,
                    height=1000,
                    showlegend=True,
                    hovermode='closest',
                    margin=dict(b=20, l=5, r=5, t=40),
                    annotations=[dict(
                        text="⚪ Viruses (inner), 🔵 Hosts (outer) | Node size = degree",
                        showarrow=False,
                        xref="paper", yref="paper",
                        x=0.005, y=-0.002
                    )],
                    xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                    yaxis=dict(showgrid=False, zeroline=False, showticklabels=False))
)

fig.show()