In [None]:
import os

# Configure the CUDA caching allocator before torch is imported. The variable
# is only honored if it is set before the allocator is initialised, so setting
# it ahead of every torch/transformers import removes any ordering risk.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer

In [None]:
n_gpus = torch.cuda.device_count()
print("N GPUS: ", n_gpus)

# Per-GPU VRAM cap used when deciding where to place the model's layers.
model_vram_limit_mib = 8192  #12000
max_memory = f'{model_vram_limit_mib}MiB'
max_memory_dict = {i: max_memory for i in range(n_gpus)}

model_id = "meta-llama/Llama-3.2-3B-Instruct"

# Load model and tokenizer with memory limits.
# With no visible GPU, max_memory_dict is empty ({}); an empty max_memory gives
# the device planner no compute device at all, so pass None instead and let
# device_map="auto" choose placement on its own.
# NOTE(review): relies on accelerate's max_memory semantics — confirm against the installed version.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    max_memory=max_memory_dict or None,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Wrap the pre-loaded model/tokenizer in a text-generation pipeline; device
# placement was already decided by device_map above.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)


In [None]:
# Quick smoke test of the pipeline using a plain-text (non-chat) prompt.
prompt = "\n".join([
    "You are a pirate chatbot who always responds in pirate speak!",
    "User: Who are you?",
    "PirateBot:",
])

outputs = pipe(prompt, max_new_tokens=256)

# Show the text of the first returned sequence.
print(outputs[0]["generated_text"])


In [None]:
import pandas as pd

# System prompt for the extraction task. It asks the model to return ONE JSON
# object containing (1) a 3-5 sentence mechanistic "summary" with inline PMIDs
# and (2) a "mechanism_subtypes" list, each entry naming a subtype together
# with the PMIDs cited for it.
# NOTE: the schema key "mechanism subtype" contains a space. The frequency
# analysis cell reads exactly that key, so keep the two in sync if renamed.
system_prompt = """You are an expert immunologist and biomedical research assistant.
TASK: Analyze the provided text on a vaccine adjuvant's immune response. Extract and structure the key mechanistic information according to the specified JSON schema.

## Instructions for the "summary" field:
- **Synthesize the information into a cohesive, mechanistic narrative of approximately 3-5 sentences.**
- This summary should not be a simple list of facts. Instead, it should describe the sequence of immunological events initiated by the adjuvant.
- For example, describe how the adjuvant is initially sensed (e.g., by PRRs like TLRs), how this leads to innate cell activation (e.g., dendritic cells), and how this subsequently shapes the adaptive response (e.g., T cell polarization and antibody production).
- Integrate the corresponding PMIDs directly into the text immediately following the claims they support.

## Instructions for the "mechanism_subtypes" field:
- Identify **every distinct** immunological mechanism.
- For each identified subtype, list all unique PMIDs cited as evidence for it in the source text.
- Do not merge related subtypes; for example, if both "dendritic cell" and "TLR4" are mentioned, create separate entries for each.

## General Rules:
- **Strict JSON Output:** The entire response MUST be a single, valid JSON object with no surrounding text or explanations.
- **Source Adherence:** Use ONLY the information and PMIDs present in the provided text. Do not infer or add external knowledge.

## JSON Schema:
{
  "adjuvant": "<string>",
  "summary": "<A cohesive, mechanistic narrative of 3-5 sentences describing the sequence of immune events, with inline PMIDs.>",
  "mechanism_subtypes": [
    {
      "mechanism subtype": "<mechanism subtype_1>",
      "evidence_refs": ["########", "..."]
    },
    {
      "mechanism subtype": "<mechanism subtype_2>",
      "evidence_refs": ["########", "..."]
    },...
  ]
}

"""

# Load the collapsed adjuvant/mechanism table and expose the canonical adjuvant
# name under the shorter "adjuvant" column used throughout this notebook.
df = pd.read_csv(
    "outputs/Vaxjo_PMIDs_adjuvant_mechanism_collapsed.csv"
).rename(columns={"adjuvant_canonical": "adjuvant"})

# Single-row trial run before the full batch; change the position to try another row.
row = df.iloc[13]
adjuvant, mechanism = (str(row[col]) for col in ("adjuvant", "immune_response_mechanism"))

# Chat-format input: the system turn defines the task, the user turn carries the data.
user_turn = f"Adjuvant: {adjuvant}\nImmune response mechanism:\n{mechanism}"
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": user_turn},
]

outputs = pipe(messages, max_new_tokens=4028)

# The last message of the returned conversation is the model's reply (intended to be raw JSON).
raw = outputs[0]["generated_text"][-1]["content"]
print(raw)


In [None]:
# Show the raw mechanism text of the example row, for comparison with the model output.
row = df.iloc[13]
print(str(row.loc["immune_response_mechanism"]))

In [None]:
# Batch pass: run the extraction prompt over every row of `df` and stream the
# raw model replies to disk as we go.
# Relies on `df`, `system_prompt`, and `pipe(...)` from earlier cells.

import json
import os

# Human-readable dump: one header line plus the reply per row.
OUT_TXT = "outputs/Vaxjo_PMIDs_mechanism_summary_raw_outputs_llama3.2.txt"
# Machine-friendly twin: one JSON record per line.
OUT_JSONL = "outputs/Vaxjo_PMIDs_mechanism_summary_raw_outputs_llama3.2.jsonl"

out_dir = os.path.dirname(OUT_TXT)
os.makedirs(out_dir if out_dir else ".", exist_ok=True)

# Generation budget per row.
GEN_MAX_NEW_TOKENS = 4028

# Both files are truncated ("w") at the start of the run. Each row is flushed
# immediately so a crash mid-run keeps everything written so far.
with open(OUT_TXT, "w", encoding="utf-8") as f_txt, open(OUT_JSONL, "w", encoding="utf-8") as f_jsonl:
    for idx, row in df.iterrows():
        adjuvant, mechanism = (
            str(row.get(key, "")) for key in ("adjuvant", "immune_response_mechanism")
        )

        user_msg = f"Adjuvant: {adjuvant}\nImmune response mechanism:\n{mechanism}"
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_msg},
        ]

        # Best-effort: any failure during generation or reply extraction is
        # recorded in place of the reply, and the loop moves on to the next row.
        try:
            outputs = pipe(messages, max_new_tokens=GEN_MAX_NEW_TOKENS)
            raw = outputs[0]["generated_text"][-1]["content"]
            status = "ok"
        except Exception as e:
            status = "error"
            raw = f"__ERROR__: {e}"

        # Human-readable record.
        header = f"===== ROW {idx} | {adjuvant} | {status} =====\n"
        f_txt.write(header)
        f_txt.write((raw or "").strip() + "\n\n")
        f_txt.flush()

        # JSONL record (keys in a fixed order: row_index, adjuvant, status, raw).
        record = {"row_index": int(idx), "adjuvant": adjuvant, "status": status, "raw": raw}
        f_jsonl.write(json.dumps(record, ensure_ascii=False) + "\n")
        f_jsonl.flush()

print(f"Saved outputs to:\n- {OUT_TXT}\n- {OUT_JSONL} (optional JSONL)")


In [None]:
# Visible marker that execution reached this point.
print("Done", end="\n")

In [None]:
# Standard library
import json
from collections import Counter

# Third-party
import matplotlib.pyplot as plt
import pandas as pd

# JSONL written by the batch-generation cell; each line is one row's
# {"row_index", "adjuvant", "status", "raw"} record.
INPUT_FILE = "outputs/Vaxjo_PMIDs_mechanism_summary_raw_outputs_llama3.2.jsonl"

def _parse_llm_json(raw_text):
    """Best-effort parse of an LLM reply that should be a single JSON object.

    The reply may be wrapped in Markdown code fences or surrounded by extra
    prose, which makes a plain ``json.loads`` fail. The text is tried as-is
    first, then the span between the first "{" and the last "}".

    Returns:
        The decoded value, or None if nothing parseable was found.
    """
    text = raw_text.strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end <= start:
        return None
    try:
        return json.loads(text[start:end + 1])
    except json.JSONDecodeError:
        return None


def _collect_subtype_names(filepath):
    """Return every "mechanism subtype" name found in the JSONL file at `filepath`.

    Records whose generation failed (status != "ok"), replies that are not a
    JSON object, and malformed subtype entries are skipped. Skips of corrupt
    or unusable lines are reported with their line number.

    Raises:
        FileNotFoundError: if `filepath` does not exist.
    """
    names = []
    with open(filepath, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank lines (e.g. a trailing newline)
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                print(f"Warning: Skipping line {line_no} due to JSON decoding error.")
                continue
            if not isinstance(record, dict):
                print(f"Warning: Skipping line {line_no}: record is not a JSON object.")
                continue

            # Rows where generation raised only hold an "__ERROR__: ..." string.
            if record.get("status", "ok") != "ok":
                print(f"Warning: Skipping line {line_no} (generation status: {record.get('status')}).")
                continue

            # 'raw' holds the model's reply, which itself should be a JSON object.
            raw_output_str = record.get("raw")
            if not raw_output_str or not isinstance(raw_output_str, str):
                continue
            inner_data = _parse_llm_json(raw_output_str)
            if not isinstance(inner_data, dict):
                print(f"Warning: Skipping line {line_no}: model output is not a JSON object.")
                continue

            subtypes_list = inner_data.get("mechanism_subtypes") or []
            if not isinstance(subtypes_list, list):
                continue
            for subtype_info in subtypes_list:
                if not isinstance(subtype_info, dict):
                    continue
                # Key spelled with a space, matching the schema in the system prompt.
                subtype_name = subtype_info.get("mechanism subtype")
                if isinstance(subtype_name, str) and subtype_name.strip():
                    names.append(subtype_name.strip())
    return names


def _plot_top_subtypes(frequency_counts, output_image_file, top_n):
    """Draw and save a horizontal bar chart of the `top_n` most frequent subtypes."""
    freq_df = pd.DataFrame(frequency_counts.most_common(top_n), columns=['Subtype', 'Frequency'])

    plt.style.use('seaborn-v0_8-whitegrid')
    fig, ax = plt.subplots(figsize=(12, 10))
    ax.barh(freq_df['Subtype'], freq_df['Frequency'], color='skyblue')
    # Invert the y-axis so the most frequent subtype is on top.
    ax.invert_yaxis()

    ax.set_xlabel('Frequency Count', fontsize=12)
    ax.set_title(f'Frequency of Top {len(freq_df)} Identified Immune Mechanism Subtypes', fontsize=16, pad=20)
    ax.tick_params(axis='y', labelsize=10)

    # Count labels just to the right of each bar.
    for i, v in enumerate(freq_df['Frequency']):
        ax.text(v + 0.5, i, str(v), color='gray', va='center', fontweight='medium')

    fig.tight_layout()
    fig.savefig(output_image_file, dpi=300)
    return fig


def analyze_subtype_frequency(filepath,
                              output_image_file="outputs/Vaxjo_PMIDs_mechanism_subtype_frequency.png",
                              top_n_print=20, top_n_plot=25, make_plot=True):
    """
    Reads a JSONL file, analyzes the frequency of mechanism subtypes,
    and generates a report and a bar chart.

    Args:
        filepath: JSONL file with one {"row_index", "adjuvant", "status", "raw"}
            record per line, where "raw" is the model's JSON reply.
        output_image_file: where the bar chart is saved.
        top_n_print: how many of the most common subtypes to print.
        top_n_plot: how many of the most common subtypes to plot.
        make_plot: set False to only count and print (no matplotlib needed).

    Returns:
        A collections.Counter of subtype name -> count, or None if the file is
        missing or contains no subtypes.
    """
    print(f"Reading data from {filepath}...")

    try:
        all_subtypes = _collect_subtype_names(filepath)
    except FileNotFoundError:
        print(f"Error: The file '{filepath}' was not found.")
        return None

    if not all_subtypes:
        print("No mechanism subtypes were found in the file.")
        return None

    print("\n--- Analysis Complete ---")

    frequency_counts = Counter(all_subtypes)

    print(f"\nTop {top_n_print} Most Common Mechanism Subtypes:")
    for subtype, count in frequency_counts.most_common(top_n_print):
        print(f"- {subtype}: {count}")

    if make_plot:
        _plot_top_subtypes(frequency_counts, output_image_file, top_n_plot)
        print(f"\n📊 Bar chart saved as '{output_image_file}'")

    return frequency_counts


if __name__ == "__main__":
    # In a notebook kernel __name__ is "__main__", so this runs with the cell.
    analyze_subtype_frequency(filepath=INPUT_FILE)