In [None]:
import torch
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import os
# Configure PyTorch's CUDA allocator through its environment variable.
# NOTE(review): this assignment runs after `import torch`; it presumably still
# takes effect because no CUDA work has happened yet at this point — confirm,
# or move it above the torch import to be safe.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [None]:
from huggingface_hub import login

# Never paste a real token into the notebook source: notebooks get shared and
# committed, and the secret would travel with them. Read it from the HF_TOKEN
# environment variable, and fall back to a hidden interactive prompt.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    from getpass import getpass  # local import: only needed for the fallback

    hf_token = getpass("Hugging Face token: ")

login(token=hf_token)

print("Successfully logged in!")


In [None]:

# Count the CUDA devices visible to this process.
n_gpus = torch.cuda.device_count()
print("N GPUS: ", n_gpus)

# Per-GPU memory budget handed to the device-map planner.
model_vram_limit_mib = 8192  # alternative value noted by the author: 12000
max_memory = f'{model_vram_limit_mib}MiB'
max_memory_dict = {i: max_memory for i in range(n_gpus)}

model_id = "meta-llama/Llama-3.2-3B-Instruct"

# Load the model with the per-GPU limits.
# On a machine without GPUs `max_memory_dict` is empty. An empty mapping tells
# the planner that no device has any budget, so pass None in that case and let
# the library pick its own defaults. With one or more GPUs the dict is passed
# through unchanged.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    max_memory=max_memory_dict or None,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Build the text-generation pipeline around the already-loaded model and
# tokenizer so the weights are not loaded a second time.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer
)


In [None]:

# System prompt instructing the model to extract vaccine adjuvants and their
# immune-response mechanisms. The quoted fallback phrase is now properly
# closed (the closing quote was missing), so the model sees a well-delimited
# literal to echo when nothing is found.
system_prompt = (
        """You are a biomedical assistant with expertise in immunology. Extract all substances explicitly described as vaccine adjuvants from the input article.

For each adjuvant, extract the following:
- "adjuvant": The name of the adjuvant (e.g., Alum, MPLA, QS-21)
- "immune_response_mechanism": A brief description of how the adjuvant works to stimulate or enhance the immune response, as described in the text.

Guidelines:
- Include only substances that are explicitly described as adjuvants in the input.
- Do not include delivery systems (e.g., liposomes, virosomes, VLPs) unless they are clearly described as adjuvants.
- Do not infer or guess mechanisms that are not mentioned; leave them empty if not described.
- If no adjuvants are found, say 'No adjuvants mentioned.'
- Return valid JSON in the following format:


[
  {
    "adjuvant": "Alum",
    "immune_response_mechanism": "Activates NLRP3 inflammasome and forms a depot for slow antigen release."
  },
  {
    "adjuvant": "MPLA",
    "immune_response_mechanism": "Engages TLR4 pathway to promote Th1 responses."
  }
]"""
    )

# Read one paper's text from disk (point the path at your own file).
# Alternative input kept from the original notebook:
#   Dataset/Dataset_PMC_CleanedXML/PMC8707864.xml
with open(
    "Dataset/PMC_Filtered_Reviews_plaintext_gemini/PMC8483762.txt",
    encoding="utf-8",
) as f:
    paper_text = f.read()

# Chat-format conversation: the extraction instructions go in the system turn
# and the paper text in the user turn.
messages = [
    dict(role="system", content=system_prompt),
    dict(role="user", content=paper_text),
]

# Generate; raise max_new_tokens if the output gets cut off.
outputs = pipe(messages, max_new_tokens=256)

# Show only the content of the last message in the returned conversation.
print(outputs[0]["generated_text"][-1]["content"])


In [None]:
import json
import os
import csv
import torch
import gc
from datetime import datetime

def extract_and_save_adjuvant_response_text_from_string(
    pmc_id: str,
    text: str,
    system_prompt: str,
    pipe,
    output_file: str,
    max_new_tokens: int = 512
):
    """
    Run the chat pipeline on one text and append the reply to ``output_file``.

    The reply is written as a ``=== <pmc_id> ===`` header line followed by the
    stripped response and a blank line, so repeated calls build up one
    plain-text file of responses.

    Args:
        pmc_id (str): Identifier written into the section header.
        text (str): Document text sent as the user message.
        system_prompt (str): Instructions sent as the system message.
        pipe: Callable invoked as ``pipe(messages, max_new_tokens=...)``. Its
            result is indexed as ``[0]["generated_text"][-1]["content"]``.
        output_file (str): File to append to. Missing parent directories are
            created.
        max_new_tokens (int): Generation token limit forwarded to ``pipe``.

    Returns:
        None. The response is written to disk and a confirmation is printed.
    """
    # Create the destination directory up front so a missing folder fails (or
    # is fixed) before the expensive generation step, not after it.
    parent_dir = os.path.dirname(output_file)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": text},
    ]
    outputs = pipe(messages, max_new_tokens=max_new_tokens)
    # The last message of the returned conversation holds the generated reply.
    response_text = outputs[0]["generated_text"][-1]["content"]

    with open(output_file, "a", encoding="utf-8") as out_f:
        out_f.write(f"=== {pmc_id} ===\n")
        out_f.write(response_text.strip() + "\n\n")
    print(f"Saved response for {pmc_id} to {output_file}")

def process_all_papers(
    input_json_file: str,
    system_prompt: str,
    pipe,
    output_txt_file: str,
    log_csv_file: str,
    max_new_tokens: int = 512
):
    """
    Processes all abstracts from a JSON file and logs each result (success/failure) to a CSV.

    Every entry gets exactly one log row. A failure on one entry (including a
    malformed entry that is not a JSON object) is recorded with status
    "error" and does not stop the rest of the batch.

    Args:
        input_json_file (str): Path to input JSON file with {PMID: {title, abstract}} entries.
        system_prompt (str): System prompt for the LLM.
        pipe: Hugging Face pipeline object.
        output_txt_file (str): Path to save model responses (plain text).
        log_csv_file (str): Path to save logs (CSV).
        max_new_tokens (int): Generation token limit.

    Returns:
        None. Results are appended to ``output_txt_file`` and ``log_csv_file``.
    """
    # Read all abstracts into a dict
    with open(input_json_file, encoding="utf-8") as f:
        papers_dict = json.load(f)

    total = len(papers_dict)
    print(f"Found {total} papers. Starting extraction...\n")

    # Make sure both destinations can be opened on a fresh checkout.
    for target in (output_txt_file, log_csv_file):
        parent_dir = os.path.dirname(target)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    # Write the header for a new log, and also for an existing but empty one
    # (e.g. left behind by a run that died before writing anything).
    needs_header = (
        not os.path.exists(log_csv_file) or os.path.getsize(log_csv_file) == 0
    )
    with open(log_csv_file, "a", newline='', encoding="utf-8") as log_f:
        log_writer = csv.writer(log_f)
        if needs_header:
            log_writer.writerow(["PMID", "Status", "Message", "Timestamp"])

        for i, (pmcid, paper_data) in enumerate(papers_dict.items(), 1):
            status = "success"
            message = "Processed successfully"
            try:
                # Validate inside the try block so one malformed entry is
                # logged as an error instead of aborting the whole batch.
                if not isinstance(paper_data, dict):
                    raise TypeError(
                        f"Expected a JSON object for entry {pmcid}, "
                        f"got {type(paper_data).__name__}"
                    )
                abstract_text = paper_data.get("abstract", "") or ""
                extract_and_save_adjuvant_response_text_from_string(
                    pmcid, abstract_text, system_prompt, pipe, output_txt_file, max_new_tokens
                )
            except Exception as e:
                status = "error"
                message = str(e)[:500]  # keep log rows bounded
                print(f"❌ Error processing PMID {pmcid}: {message}")
            else:
                print(f"✅ [{i}/{total}] Processed PMID {pmcid}")
            finally:
                # Flush each row immediately so the log survives a crash, then
                # release cached GPU memory before the next paper.
                log_writer.writerow([pmcid, status, message, datetime.now().isoformat()])
                log_f.flush()
                torch.cuda.empty_cache()
                gc.collect()


In [None]:
import os
import gc
import csv
import torch
from datetime import datetime

# Inputs and outputs for the Vaxjo abstract batch run.
# Note: despite the .txt extension, the input file must contain JSON.
input_json_file = "Dataset/Vaxjo/All PMID abstracts.txt"
output_txt_file = "outputs/Vaxjo_PMIDs_adjuvant_responses.txt"
log_csv_file = "outputs/Vaxjo_PMIDs_adjuvant_log.csv"

# Run the extraction over every entry, passing arguments by keyword so the
# call site documents itself.
process_all_papers(
    input_json_file=input_json_file,
    system_prompt=system_prompt,
    pipe=pipe,
    output_txt_file=output_txt_file,
    log_csv_file=log_csv_file,
    max_new_tokens=512,
)
