In [1]:
import torch, json
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification

In [3]:
# Read the article corpus from its Parquet file using the pyarrow backend.
df = pd.read_parquet(
    "data/LRC_articles.parquet",
    engine="pyarrow",
)

In [14]:
# Use the first row's "content" value as a sample input for the classifier.
test_text = df["content"].iloc[0]

# Print the character count, then let the cell echo the raw text itself.
print(len(test_text))
test_text

6610


'The term “integrative therapies” describes the carefully designed regimen that combines both medicinal treatments and non-medicinal ones to maximize positive outcomes and minimize the negative impacts of cancer and its treatment on quality of life. The field of integrative oncology is targeted toward reducing the ferocious side effects of standard chemotherapy and radiation treatments that can occur when therapies intended to target cancer cells also damage healthy ones or disrupt a patient’s normal body processes.\n“Standard” or conventional treatments are medicinal treatments for which scientific evidence from highly structured clinical trials has demonstrated positive control over cancer’s progression. Such therapies include surgery, radiation, chemotherapy, bone marrow transplants, immunotherapies, and other techniques for which evidence shows positive impacts on the course of the disease.\nAdvertisement\nDespite proven and often curative health benefits from standard treatments, 

In [15]:
# Load the saved tokenizer and classifier, plus the JSON assets stored in the
# same directory.
save_dir = "mmf_deberta_v3/best"
tok = AutoTokenizer.from_pretrained(save_dir)
model = AutoModelForSequenceClassification.from_pretrained(save_dir)

# JSON text is UTF-8; pass the encoding explicitly so decoding does not fall
# back to the platform's locale default.
with open(f"{save_dir}/frames.json", encoding="utf-8") as f:
    # Iterated in parallel with the model's per-label outputs in Predict_frames.
    frames = json.load(f)
with open(f"{save_dir}/threshold.json", encoding="utf-8") as f:
    # Only the "global" entry is read; Predict_frames compares probabilities
    # against it.
    best_thresh = json.load(f)["global"]

def Predict_frames(article: str):
    """Classify one article and return the frames the model assigns to it.

    The text is tokenized (truncated/padded to 512 tokens), passed through the
    module-level ``model``, and each output logit is converted to a
    probability with a sigmoid. A frame counts as predicted when its
    probability is greater than or equal to the module-level ``best_thresh``.

    Args:
        article: Raw article text.

    Returns:
        A dict with two keys:
            "predicted_frames": names of the frames that met the threshold,
                in the order they appear in ``frames``.
            "details": one ``{"frame", "prob", "predicted"}`` dict per frame.

    Raises:
        ValueError: If the number of model outputs differs from the number of
            frame names; pairing them anyway would silently drop or mislabel
            scores.
    """
    encoded = tok(article, truncation=True, padding="max_length", max_length=512, return_tensors="pt")

    # The input tensors must sit on the same device as the model's weights,
    # otherwise the forward pass fails once the model is moved off the CPU.
    device = model.device
    inputs = {name: tensor.to(device) for name, tensor in encoded.items()}

    # Run model
    with torch.no_grad():
        logits = model(**inputs).logits
        # A single article was encoded, so row 0 holds all of its scores.
        probs = torch.sigmoid(logits).cpu().numpy()[0]

    # zip() below would quietly truncate to the shorter sequence, so refuse a
    # label list that does not line up with the model's outputs.
    if len(probs) != len(frames):
        raise ValueError(
            f"model produced {len(probs)} scores but {len(frames)} frame names are loaded"
        )

    # Apply threshold
    preds = (probs >= best_thresh).astype(int)

    # Collect results
    results = [
        {"frame": f, "prob": float(p), "predicted": bool(pred)}
        for f, p, pred in zip(frames, probs, preds)
    ]
    predicted_frames = [r["frame"] for r in results if r["predicted"]]
    return {"predicted_frames": predicted_frames, "details": results}


In [16]:
# Classify the sample article and print only the frames that met the threshold.
out = Predict_frames(article=test_text)
print(out["predicted_frames"])

['crime', 'health', 'policy', 'quality_life']
