# Finding Significance Classification

## Inference using the selected model

In order to predict the significance of a trial's findings, we have evaluated a few models. We settled on the BART zero-shot model to provide these predictions. As the database is rather large, we perform this classification step on the cluster.

In [None]:
! nvidia-smi

In [None]:
from pathlib import Path

In [None]:
# Parquet file holding the records to classify (read into `df_combined` below).
PATH_INFERENCE_DATASET = Path("ids_to_abstracts_for_inference.parquet")
# JSON Lines file that the inference step appends predictions to. Ids already
# present in it are skipped on a re-run, so it doubles as a resume cache.
PATH_INFERENCE_RESULTS_CACHE = Path("prediction_results.jsonl")
# Final parquet file written after the cached predictions are post-processed.
PATH_INFERENCE_RESULTS = Path("ids_to_significance_predictions_zeroshot.parquet")
# Hugging Face model id used for both the tokenizer and the zero-shot pipeline.
MODEL = "facebook/bart-large-mnli"
# Labels offered to the zero-shot classifier. The first entry is treated as the
# positive class when `has_significant_effect` is derived at the end.
CANDIDATE_LABELS = ["significant effect", "no significant effect"]
# Device string handed to the pipeline.
# NOTE(review): hard-coded GPU index; adjust to a GPU that exists on your host.
DEVICE = "cuda:3"
# Batch size passed to the pipeline when running inference.
BATCH_SIZE = 8

### Load the data

In [8]:
import pandas as pd

In [None]:
# Load the full inference dataset into memory. The inference call further down
# reads its `pm_id` and `abstract` columns.
df_combined = pd.read_parquet(PATH_INFERENCE_DATASET)

In [None]:
# Display the loaded frame as a quick visual sanity check.
df_combined

### Model Setup

In [7]:
from transformers import AutoTokenizer

# Tokenizer for MODEL, configured with a 1024-token maximum length.
# `truncation_side="left"` presumably makes over-long inputs lose tokens from
# the start rather than the end — TODO confirm against the transformers docs.
# NOTE(review): `truncation="only_first"` is normally a call-time argument;
# confirm it has any effect when passed to `from_pretrained`.
TOKENIZER = AutoTokenizer.from_pretrained(
    MODEL, truncation="only_first", truncation_side="left", model_max_length=1024
)

In [None]:
from transformers import pipeline
import torch

# Zero-shot classification pipeline built from MODEL and the tokenizer above,
# placed on DEVICE and loaded in bfloat16.
# NOTE(review): `use_flash_attention_2` is passed as a plain pipeline kwarg;
# confirm that the installed transformers version forwards it to model loading
# (it may need to go through `model_kwargs`) and that BART supports it.
pipe_zs = pipeline(
    "zero-shot-classification",
    model=MODEL,
    tokenizer=TOKENIZER,
    candidate_labels=CANDIDATE_LABELS,
    device=DEVICE,
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True,
)

In [None]:
import jsonlines
from transformers.pipelines import Pipeline
from tqdm import tqdm
from transformers.pipelines.pt_utils import KeyDataset
from datasets import Dataset


def predict_significance_labels(
    ds: Dataset,
    id_col: str,
    feature_col: str,
    pipe: Pipeline,
    device: str,
    batch_size: int,
    output_file: "str | Path",
    output_batch_size: int,
) -> None:
    """Run ``pipe`` over ``ds`` and append the predictions to a JSON Lines file.

    The function is resumable: ids already present in ``output_file`` are
    skipped, and new predictions are appended in chunks, so an interrupted
    run loses at most one unwritten chunk.

    Args:
        ds: Dataset containing at least ``id_col`` and ``feature_col``.
        id_col: Name of the column identifying each record. Its values must
            be hashable, because they are collected into a set.
        feature_col: Name of the column whose values are fed to ``pipe``.
        pipe: Pipeline whose per-record output is a mapping with ``"labels"``
            and ``"scores"`` entries.
        device: Device string forwarded to the pipeline call.
        batch_size: Inference batch size forwarded to the pipeline call.
        output_file: Path of the JSON Lines file to read from (for resuming)
            and append to. It is created if it does not exist.
        output_batch_size: Number of predictions to buffer before each append
            to ``output_file``.

    Returns:
        None. Results are written to ``output_file`` as objects with the keys
        ``id_col``, ``"labels"`` and ``"scores"``.
    """
    output_file = Path(output_file)

    # Collect already-processed ids in a set: the filter predicate below runs
    # once per dataset row, and a list would make each lookup a linear scan.
    processed_ids = set()
    if output_file.exists():
        with jsonlines.open(output_file) as reader:
            processed_ids = {item.get(id_col) for item in reader}

    filtered_ds = ds.filter(lambda row: row[id_col] not in processed_ids)

    # NOTE(review): `device` and `return_all_scores` are forwarded as
    # call-time kwargs exactly as before; confirm that the zero-shot pipeline
    # actually honours them in the installed transformers version.
    predictions = pipe(
        KeyDataset(filtered_ds, feature_col),
        batch_size=batch_size,
        device=device,
        return_all_scores=True,
    )

    batch = []
    # The pipeline yields results in dataset order, so zipping with the id
    # column pairs every prediction with its record.
    for doc_id, pred in tqdm(
        zip(filtered_ds[id_col], predictions),
        desc="Running inference",
        total=len(filtered_ds),
    ):
        batch.append(
            {id_col: doc_id, "labels": pred["labels"], "scores": pred["scores"]}
        )

        # `>=` rather than `==` so a non-positive chunk size still flushes.
        if len(batch) >= output_batch_size:
            with jsonlines.open(output_file, mode="a") as writer:
                writer.write_all(batch)
            batch = []

    # Flush whatever is left over; this also creates the file when there was
    # nothing to predict.
    with jsonlines.open(output_file, mode="a") as writer:
        writer.write_all(batch)

### Inference

In [None]:
# Wrap the DataFrame in a `datasets.Dataset`, the type that
# `predict_significance_labels` filters and feeds to the pipeline.
ds = Dataset.from_pandas(df_combined)

In [None]:
# Classify every abstract not yet present in the results cache. Predictions
# are appended to the cache file in chunks of 10,000, so this cell can be
# re-run after an interruption and will pick up where it left off.
predict_significance_labels(
    ds=ds,
    id_col="pm_id",
    feature_col="abstract",
    pipe=pipe_zs,
    device=DEVICE,
    batch_size=BATCH_SIZE,
    output_file=PATH_INFERENCE_RESULTS_CACHE,
    output_batch_size=10000,
)

In [None]:
# Flatten the cached JSON Lines predictions into the final results table.
df_results = pd.read_json(PATH_INFERENCE_RESULTS_CACHE, lines=True)

# The first entry of each `labels` list is taken as the predicted label
# (assumes the pipeline returns labels sorted by descending score — TODO confirm).
df_results["predicted_label"] = df_results["labels"].str[0]

# One probability column per candidate label. The score is looked up by the
# label's position in that row's `labels` list, because the order varies by row.
for label in CANDIDATE_LABELS[:2]:
    df_results[f"prob_{label}"] = df_results.apply(
        lambda row, label=label: row["scores"][row["labels"].index(label)],
        axis=1,
    ).astype(float)

# Boolean convenience flag: True when the first candidate label won.
df_results["has_significant_effect"] = df_results["predicted_label"].eq(
    CANDIDATE_LABELS[0]
)

# The raw list columns are no longer needed once they have been flattened.
df_results = df_results.drop(columns=["labels", "scores"])
df_results.to_parquet(PATH_INFERENCE_RESULTS, compression="gzip")