In [None]:
import os
import sys
sys.path.append('../')

import pandas as pd
import torch
from tqdm.auto import tqdm

%load_ext autoreload
%autoreload 2

import src.utils

In [None]:
# CSV file with the question data; it is read into a DataFrame with pd.read_csv
DF_DATA_LOCATION = "../data/discq_questions_final.csv"

# Folder with the document text files (files are opened by document id)
RELATION_TXT_FOLDER_LOC = "../data/relations_txt"

# Where to save the preprocessed data to
SAVE_LOCATION = "../data/clinical_qg_3_prior_2_future"

In [None]:
from src.t0_code.prepare_data_for_t0 import build_dataset

In [None]:
# Load the question CSV and display the resulting DataFrame
df = pd.read_csv(DF_DATA_LOCATION)
df

In [None]:
# Display the first row of the DataFrame
df.iloc[0]

In [None]:
# Pull the question, trigger text and document id out of the first row.
question, trigger, doc_id = (df.iloc[0][field] for field in ("question", "reasoning", "id"))

# The stored offsets are shifted by the length of the document id minus one;
# undo that shift for both ends of the span.
shift = len(doc_id) - 1
span = tuple(df.iloc[0][bound] - shift for bound in ("start_index", "end_index"))

In [None]:
# Read the text of the document the example row refers to.
# The folder comes from RELATION_TXT_FOLDER_LOC instead of a repeated path literal,
# so changing the config cell is enough to point the notebook at another location.
with open(os.path.join(RELATION_TXT_FOLDER_LOC, doc_id), "r") as f:
    doc_text = f.read()


def extract_sentence_with_trigger(doc_text, trigger_span, expected_trigger_text=None, n_prior_sentences=0, n_future_sentences=0):
    """
    Extracts the sentence (line) containing the trigger, plus optional context lines.

    The document is split into lines with their line endings kept, so character
    offsets into doc_text map directly onto lines. The line that contains the
    character at offset trigger_span[0] is selected.

    If expected_trigger_text is not None and the selected line does not contain it,
    the neighboring lines are searched, nearest first (the previous line is checked
    before the next line at equal distance), up to 19 lines away in each direction.
    Positions outside the document are skipped.

    Args:
        doc_text: Full text of the document.
        trigger_span: (start, end) character offsets of the trigger in doc_text.
            Only the start offset is used.
        expected_trigger_text: Text that the selected line must contain, or None to
            trust the span as is.
        n_prior_sentences: Number of sentences to extract before the trigger.
        n_future_sentences: Number of sentences to extract after the trigger.

    Returns:
        The selected lines joined with a single space. Every line keeps its own
        line ending.

    Raises:
        RuntimeError: If the start offset lies at or beyond the end of doc_text, or if
            expected_trigger_text is given and is found neither in the selected line
            nor in the searched neighbors.
    """
    max_neighbor_distance = 19

    doc_sentences = doc_text.splitlines(keepends=True)

    sentence_index = None
    sentence_start = 0
    for i, sentence_text in enumerate(doc_sentences):
        sentence_end = sentence_start + len(sentence_text)
        # A line covers the offsets [sentence_start, sentence_end). The offset equal to
        # sentence_end is the first character of the NEXT line, hence the strict "<".
        if trigger_span[0] < sentence_end:
            sentence_index = i
            break
        sentence_start = sentence_end

    if sentence_index is None:
        raise RuntimeError("Could not find sentence with trigger by span")

    if expected_trigger_text is not None and expected_trigger_text not in doc_sentences[sentence_index]:
        # Try to find the trigger in the neighboring sentences
        found_index = None
        for distance in range(1, max_neighbor_distance + 1):
            for candidate in (sentence_index - distance, sentence_index + distance):
                # Bounds check: a negative index would silently wrap around to the end of
                # the document and an index past the end would raise IndexError.
                if 0 <= candidate < len(doc_sentences) and expected_trigger_text in doc_sentences[candidate]:
                    found_index = candidate
                    break
            if found_index is not None:
                break

        if found_index is None:
            # The message does not name the document: this function only receives the text.
            raise RuntimeError(
                f"Could not find expected trigger text {expected_trigger_text} "
                f"within {max_neighbor_distance} sentences of the span position"
            )
        sentence_index = found_index

    sentence_index_start = max(0, sentence_index - n_prior_sentences)
    sentence_index_end = min(len(doc_sentences), sentence_index + n_future_sentences + 1)
    return " ".join(doc_sentences[sentence_index_start:sentence_index_end])


In [None]:
# Trigger sentence of the example, with three sentences before it and one after it.
sentence = extract_sentence_with_trigger(
    doc_text=doc_text,
    trigger_span=span,
    n_prior_sentences=3,
    n_future_sentences=1,
)
sentence

In [None]:
# Prompt template: the extracted text, then the instruction naming the trigger,
# then the cue after which the question is expected.
prompt = (
    "{text}\n"
    'After reading the above EMR, what question do you have about "{trigger}"?\n'
    "Question:"
)

In [None]:
# Fill the template with the extracted sentence and the trigger and print the result
print(prompt.format(text=sentence, trigger=trigger))

# Building the pre-processor based on the logic above

In [None]:
import datasets


def build_dataset(
    sql_file_path,
    documents_folder,
    n_prior_sentences=0,
    n_future_sentences=0,
    split_questions=False,
    verbosity=1,
    val_size=0.15,
    test_size=0.15,
    seed=0,
):
    """
    Builds train/validation/test datasets of (sentence, trigger, question) examples.

    The rows are loaded with src.utils.load_df and assigned to a split by document id
    (ids are partitioned with src.utils.split_ids). For every row the document text is
    read from documents_folder and the trigger sentence with its context is extracted
    with extract_sentence_with_trigger. Rows for which the extraction raises
    RuntimeError are skipped and counted.

    Args:
        sql_file_path: Location of the question data, passed to src.utils.load_df.
        documents_folder: Folder with the document text files; the file name is the row's id.
        n_prior_sentences: Number of sentences to keep before the trigger sentence.
        n_future_sentences: Number of sentences to keep after the trigger sentence.
        split_questions: If True, the question text of a row is split at "?" and every
            part longer than one character becomes its own example.
        verbosity: 0 prints nothing; above 0 prints the first two extraction errors and
            the error summary; above 1 also reports rows with multiple questions;
            above 2 prints every extraction error.
        val_size: Forwarded to src.utils.split_ids as val_split.
        test_size: Forwarded to src.utils.split_ids as test_split.
        seed: Forwarded to src.utils.split_ids as seed.

    Returns:
        A datasets.DatasetDict with the keys "train", "validation" and "test". Every
        example has the fields "sentence", "trigger" and "question".
    """
    # The notebook imports the helper module as `import src.utils`, so it has to be
    # addressed by its full name; a bare `utils` is not defined.
    df = src.utils.load_df(sql_file_path)

    # Sort the unique ids: the iteration order of a set of strings can differ between
    # interpreter runs, which would make the split irreproducible even with a fixed seed.
    unique_ids = sorted(set(df.id.values))
    tr_ids, val_ids, test_ids = src.utils.split_ids(
        unique_ids,
        seed=seed,
        val_split=val_size,
        test_split=test_size,
    )

    # Create each df
    split_frames = {
        "train": df[df.id.isin(tr_ids)],
        "validation": df[df.id.isin(val_ids)],
        "test": df[df.id.isin(test_ids)],
    }

    dataset_dict = {}

    n_errors = 0  # extraction failures, counted over all splits
    for split_name, split_df in split_frames.items():
        records = []

        for _, row in tqdm(split_df.iterrows(), total=len(split_df), desc=split_name):
            with open(os.path.join(documents_folder, row.id), "r") as f:
                doc_text = f.read()

            # The stored offsets are shifted by the length of the document id minus one.
            shift = len(row.id) - 1
            span = row.start_index - shift, row.end_index - shift

            try:
                sentence = extract_sentence_with_trigger(
                    doc_text=doc_text,
                    trigger_span=span,
                    n_prior_sentences=n_prior_sentences,
                    n_future_sentences=n_future_sentences,
                    expected_trigger_text=row.reasoning,
                )
            except RuntimeError as e:
                n_errors += 1
                if n_errors < 3 or verbosity > 2:
                    # Prefix with the id of the row being processed so the message
                    # identifies the right document.
                    if verbosity: print(f"[{row.id}] {e}")
                elif n_errors == 3:
                    if verbosity > 0: print("Too many errors, not printing any more")
                continue

            if split_questions:
                questions = [q.strip() + "?" for q in row.question.split("?") if len(q) > 1]
                if len(questions) > 1:
                    if verbosity > 1: print("Multiple questions found:", questions)
                for question in questions:
                    records.append({"sentence": sentence, "trigger": row.reasoning, "question": question})
            else:
                records.append({"sentence": sentence, "trigger": row.reasoning, "question": row.question})

        dataset_dict[split_name] = datasets.Dataset.from_pandas(pd.DataFrame(records))

    # Report the total once; the counter spans all splits.
    if n_errors > 0:
        if verbosity > 0: print(f"Found {n_errors} errors. These examples were not added to the dataset")

    return datasets.DatasetDict(dataset_dict)

In [None]:
# Preprocess all splits, keeping 3 sentences before and 2 sentences after each trigger sentence.
dataset = build_dataset(
    sql_file_path=DF_DATA_LOCATION,
    documents_folder=RELATION_TXT_FOLDER_LOC,
    n_prior_sentences=3,
    n_future_sentences=2,
)

In [None]:
# Display the first example of the train split
dataset["train"][0]

In [None]:
# save_to_disk needs the destination directory; write to the location from the config cell.
dataset.save_to_disk(SAVE_LOCATION)