# MS Preprocessing
Note: To build the dataset for this task, a trained line classifier must be available and must first be used to generate line labels for the reports in seantis_kisim.csv. This corresponds to running inference on the "all" split of line-label/line-label_clean_dataset.

Necessary steps before this:
1. notebooks/00_preprocessing_old_project.ipynb
2. notebooks/01_classifying_text_lines.ipynb
3. scripts/line-label/inference.py --model_name "trained-line-classifier" --split all

The data/preprocessed/midatams/seantis_kisim.csv file was created by the original project (00_preprocessing_old_project.ipynb).
This file contains the longest report per rid, split into lines. Their approach:

1. Extract the longest diagnosis text per rid (the one with the most lines) from the csv. If a manually line-labelled text existed for the rid, they used that instead.
2. This results in a dataset with one text line per row and a label for each line.

Further processing:

3. Label this dataset with classifier 1 (clf 1, the line classifier). Merge diagnoses.csv and the line-labelled dataset on rid. Clean the labels and correct wrong SPMS and PPMS labels.
4. Get the list of eligible rids (rids whose text has at least one dm line).
5. Df1: concatenate all dm lines per eligible rid to form the text.
6. Df2: concatenate all text lines per eligible rid.
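
Steps 4 to 6 can be sketched on toy data as follows. This is a minimal illustration, not the notebook's actual code: the rows, the names `lines`, `df1_toy` and `df2_toy`, and the non-dm class name "other" are all assumptions made for the example.

```python
import pandas as pd

# Toy line-level data; "dm" marks a diagnosis line (all values are made up).
lines = pd.DataFrame({
    "rid": ["a", "a", "a", "b", "b", "c"],
    "text": ["Diagnose: RRMS", "Medikation: X", "Verlauf stabil",
             "Anamnese", "Befund", "Diagnose: PPMS"],
    "class2": ["dm", "other", "other", "other", "other", "dm"],
})

# Step 4: rids with at least one dm line are eligible.
eligible_rids = set(lines.loc[lines["class2"] == "dm", "rid"])
eligible = lines[lines["rid"].isin(eligible_rids)]

# Step 5 (Df1): only the dm lines, joined per rid.
df1_toy = (eligible[eligible["class2"] == "dm"]
           .groupby("rid")["text"].agg("\n".join).reset_index())

# Step 6 (Df2): all lines, joined per rid.
df2_toy = eligible.groupby("rid")["text"].agg("\n".join).reset_index()

print(sorted(eligible_rids))
print(df2_toy)
```

Here rid "b" has no dm line, so it is not eligible and appears in neither toy dataframe.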

In [None]:
import pandas as pd
import torch
import sys
import os
sys.path.append(os.getcwd()+"/../..")

from src import paths
from src.utils import ms_label2id, line_label_id2label

from datasets import DatasetDict, Dataset

from sklearn.model_selection import train_test_split

import json

In [None]:
# Line Labelled dataset Token Finetuned
data = torch.load(paths.RESULTS_PATH/"line-label/line-label_medbert-512_token_all.pt")

data_list = []
for obs in data:
    _df = pd.DataFrame(obs["text"], columns=["text"])
    _df["class2"] = obs["preds"]
    _df["rid"] = obs["rid"]
    data_list.append(_df)

data_token_df = pd.concat(data_list)

data_token_df = data_token_df[["rid", "text", "class2"]]


# Make directory if it doesn't exist
os.makedirs(paths.DATA_PATH_PREPROCESSED/"ms-diag", exist_ok=True)

data_token_df.to_csv(paths.DATA_PATH_PREPROCESSED/"ms-diag/line-label_medbert-512_token_finetuned_all.csv", index=False)

In [None]:
# Line Labelled dataset Line Finetuned
data_line = torch.load(paths.RESULTS_PATH/"line-label/line-label_medbert-512_class_all.pt")
data_line.pop("last_hidden_state")

data_line_df = pd.DataFrame(data_line)
data_line_df.rename(columns={"preds": "class2"}, inplace=True)
data_line_df.drop(columns=["labels"], inplace=True)
data_line_df = data_line_df[["rid", "index_within_rid", "text", "class2"]]

# Remap values in class2 from ids to label names using line_label_id2label
data_line_df["class2"] = data_line_df["class2"].map(line_label_id2label)

# Make directory if it doesn't exist
os.makedirs(paths.DATA_PATH_PREPROCESSED/"ms-diag", exist_ok=True)

data_line_df.to_csv(paths.DATA_PATH_PREPROCESSED/"ms-diag/line-label_medbert-512_class_finetuned_all.csv", index=False)

In [None]:
# The line-classifier output should have more rows, because long reports are not truncated
print("Length Line Data: ", len(data_line_df))
print("Length Token Data: ", len(data_token_df))

In [None]:
data_line_df[data_line_df["rid"] == "016B6D16-2BBA-4C05-A8E4-30F761C95813"]

In [None]:
data_token_df[data_token_df["rid"] == "016B6D16-2BBA-4C05-A8E4-30F761C95813"]

## Preparing CLF1 Dataset for MS Diag

- I decided to go with the line classifier: it performs slightly worse, but it has no truncation issues for longer reports.
- The steps are:

1. Merging the datasets on rid (one row is one line of text)
2. Cleaning up labels (using only confirmed diagnoses, remapping German labels, correcting wrong labels for low-count classes)
3. Construction of the "no_ms" class: all lines labelled as something other than "dm" are "no_ms".
4. Construct 2 datasets:
    - df1: line-level dataset. Each non-"dm" line is its own row (labelled "no_ms"); the "dm" lines of a rid are joined into a single row labelled with the MS type.
    - df2: contains the full text per rid. If at least one line was labelled "dm", the label is the MS type of that rid; otherwise it is "no_ms". Additional "no_ms" examples can be obtained by creating texts that do not contain the "dm" line (so there might be multiple examples from the same rid, once with the dm line and once without).
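
As a rough illustration of steps 2 and 3, the label clean-up can be written as a small pure-Python function. The records below are invented, the reliability value "suspected" is a placeholder, and `clean_labels` and `line_label` are hypothetical helpers that do not exist in the project.

```python
RRMS = "relapsing_remitting_multiple_sclerosis"
SPMS = "secondary_progressive_multiple_sclerosis"
PPMS = "primary_progressive_multiple_sclerosis"
ENGLISH_LABELS = {RRMS, SPMS, PPMS}

# One of the German labels that gets remapped in this notebook.
GERMAN_TO_ENGLISH = {"Schubförmig remittierende Multiple Sklerose (RRMS)": RRMS}

# Hypothetical toy records mimicking diagnoses.csv (values are made up).
diagnoses = [
    {"rid": "a", "disease": RRMS, "diagnosis_reliability": "confirmed"},
    {"rid": "b", "disease": "Schubförmig remittierende Multiple Sklerose (RRMS)",
     "diagnosis_reliability": "confirmed"},
    {"rid": "c", "disease": PPMS, "diagnosis_reliability": "suspected"},
    {"rid": "d", "disease": "Unklare Diagnose", "diagnosis_reliability": "confirmed"},
]


def clean_labels(records):
    """Return {rid: label} for confirmed diagnoses with a usable label."""
    cleaned = {}
    for rec in records:
        if rec["diagnosis_reliability"] != "confirmed":
            continue  # step 2: keep confirmed diagnoses only
        label = GERMAN_TO_ENGLISH.get(rec["disease"], rec["disease"])
        if label in ENGLISH_LABELS:  # drop labels that cannot be mapped
            cleaned[rec["rid"]] = label
    return cleaned


def line_label(ms_type, line_class):
    """Step 3: a line keeps the MS type only if it is a "dm" line."""
    return ms_type if line_class == "dm" else "no_ms"


labels_by_rid = clean_labels(diagnoses)
print(labels_by_rid)
```

In this toy run only rids "a" and "b" survive: "c" is not confirmed and "d" has a label with no mapping.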

In [None]:
# Line labelled dataset from classifier 1
df_text = pd.read_csv(os.path.join(paths.DATA_PATH_PREPROCESSED, "ms-diag", "line-label_medbert-512_class_finetuned_all.csv"))[["rid", "text", "class2", "index_within_rid"]]
df_labels = pd.read_csv(os.path.join(paths.DATA_PATH_SEANTIS, "diagnoses.csv"))

# The old approach only used confirmed diagnoses
df_labels = df_labels[df_labels["diagnosis_reliability"] == "confirmed"]
df_labels = df_labels[["research_id", "disease"]].rename(columns={"disease": "labels", "research_id": "rid"})

In [None]:
# Merge with diagnoses.csv
df_merged = pd.merge(df_text, df_labels, on="rid", how="inner")

In [None]:
# English Labels
english_labels = set(["relapsing_remitting_multiple_sclerosis", "secondary_progressive_multiple_sclerosis", "primary_progressive_multiple_sclerosis"])
other_labels = set(df_merged["labels"].unique()) - english_labels

In [None]:
other_labels

In [None]:
# Inspect the non-English labels together with their report text
pd.set_option('display.max_colwidth', None)
pd.set_option('display.max_rows', None)

for rid, rid_data in df_merged.groupby("rid"):
    if rid_data.labels.isin(other_labels).any():
        print(rid_data.labels.unique())
        print(rid_data["text"].str.cat(sep = " "))

In [None]:
# Remap non-English labels where possible (only in the labels column)
map_dict = {
    "Multiple Sklerose a.e. primär progredient": "primary_progressive_multiple_sclerosis",
    "Multiple Sklerose mit a.e. primär-progredientem Verlauf": "primary_progressive_multiple_sclerosis",
    "Schubförmig remittierende Multiple Sklerose (RRMS)": "relapsing_remitting_multiple_sclerosis",
}
df_merged["labels"] = df_merged["labels"].replace(map_dict)

In [None]:
# Remove all remaining non-English labels
df_merged = df_merged[df_merged["labels"].isin(english_labels)]

In [None]:
# Because the mapping was done manually, check whether the label matches the text for low-count classes such as SPMS
for rid, rid_data in df_merged[df_merged["labels"] == "secondary_progressive_multiple_sclerosis"].groupby("rid"):
    print(rid)
    print(rid_data["text"].str.cat(sep = "\n"))
    print("\n")

In [None]:
# Rids labelled SPMS whose text does not match that label (from the manual check above)
spms_wrong = [
    "2A9F4832-B09D-470A-B05F-519854310DBB",
    "39D432B0-902B-49D9-B727-12EDC053B09E",
    "AF834D8D-F7DB-4B22-BB01-29F10EE6A828",
    "B886879A-5109-46FD-A2B0-9DCA2DA733F8",
    "C0784569-1E15-4FBE-A4B2-F9473975D199",
]
# Number of rids currently labelled SPMS
df_merged[df_merged["labels"] == "secondary_progressive_multiple_sclerosis"].rid.unique().shape
# Because of this exclusion we end up with fewer training examples than in the original approach.
# Lines labelled RRMS whose text mentions SPMS ("spms" or "sekundär")
df_merged[(df_merged["labels"] == "relapsing_remitting_multiple_sclerosis") & df_merged["text"].str.lower().str.contains("spms|sekundär")]

In [None]:
# Drop entries with wrong label
df_merged = df_merged[~df_merged["rid"].isin(spms_wrong)]

# Remap entry 157 to secondary_progressive_multiple_sclerosis
df_merged.loc[157, "labels"] = "secondary_progressive_multiple_sclerosis"

In [None]:
# Check primary_progressive_multiple_sclerosis
for rid, rid_data in df_merged[df_merged["labels"] == "primary_progressive_multiple_sclerosis"].groupby("rid"):
    print(rid)
    print(rid_data["text"].str.cat(sep = "\n"))
    print("\n")

In [None]:
# Rids with diagnosis
rids_dm = set(df_merged[df_merged["class2"] == "dm"]["rid"].unique())

# Rids without diagnosis
rids_no_dm = set(df_merged["rid"].unique()) - rids_dm

In [None]:
# Set labels of rids without diagnosis to no_ms
df_merged.loc[df_merged["rid"].isin(rids_no_dm), "labels"] = "no_ms"

In [None]:
# Take all non-dm lines and set their label to no_ms
# (.copy() avoids pandas' SettingWithCopyWarning when assigning below)
df1_no_dm = df_merged[df_merged["class2"] != "dm"].copy()
df1_no_dm["labels"] = "no_ms"

# For the rids in rids_dm, join all lines with class2 == dm into one text per rid
df1_dm = df_merged[df_merged["class2"] == "dm"].groupby("rid").agg({"text": "\n".join, "labels": "first", "index_within_rid": "first"}).reset_index()

# Concat both dataframes
df1 = pd.concat([df1_no_dm, df1_dm]).drop(columns=["class2"])

In [None]:
# Concatenate all lines per rid; rids without a dm line were already set to no_ms above
df2 = df_merged.groupby("rid").agg({"text": "\n".join, "labels": "first"}).reset_index()

In [None]:
# df3 will be df2 but put first text line last
df3 = df2.copy()
df3["text"] = df3["text"].apply(lambda x: x.split("\n"))
df3["text"] = df3["text"].apply(lambda x: x[1:] + [x[0]])
df3["text"] = df3["text"].apply(lambda x: "\n".join(x))

In [None]:
df2.labels.value_counts(), df3.labels.value_counts()

In [None]:
len(df1.rid.unique()), len(df2.rid.unique()), len(df3.rid.unique())

In [None]:
# Train Val Test split
df2train, df2test = train_test_split(df2, test_size=0.3, random_state=42, stratify=df2["labels"])
df2train, df2val = train_test_split(df2train, test_size=0.1, random_state=42, stratify=df2train["labels"])

train_rids = set(df2train["rid"].unique())
val_rids = set(df2val["rid"].unique())
test_rids = set(df2test["rid"].unique())

df3train = df3[df3["rid"].isin(train_rids)]
df3val = df3[df3["rid"].isin(val_rids)]
df3test = df3[df3["rid"].isin(test_rids)]

df1train = df1[df1["rid"].isin(train_rids)]
df1val = df1[df1["rid"].isin(val_rids)]
df1test = df1[df1["rid"].isin(test_rids)]
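
The two nested `train_test_split` calls above give effective proportions that are easy to misread. A short worked calculation with exact fractions follows; the real row counts will differ slightly because scikit-learn rounds each split to whole rows.

```python
from fractions import Fraction

test_frac = Fraction(3, 10)         # test_size=0.3 in the first split
val_frac_of_rest = Fraction(1, 10)  # test_size=0.1 in the second split

rest = 1 - test_frac                # share left after removing the test set
val_frac = rest * val_frac_of_rest  # validation share of the full dataset
train_frac = rest - val_frac        # training share of the full dataset

for name, frac in [("train", train_frac), ("val", val_frac), ("test", test_frac)]:
    print(f"{name}: {float(frac):.2f}")
```

So roughly 63% of the rids end up in train, 7% in val and 30% in test.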

In [None]:
len(df1test.rid.unique()), len(df2test.rid.unique()), len(df3test.rid.unique()) 

In [None]:
# Create HuggingFace Dataset
def create_hf_dataset(train: pd.DataFrame, val: pd.DataFrame, test: pd.DataFrame) -> DatasetDict:
    """Create a HuggingFace DatasetDict from train, val and test dataframes.

    Remaps labels to ids and drops unnecessary columns.

    Args:
        train (pd.DataFrame): Training dataframe
        val (pd.DataFrame): Validation dataframe
        test (pd.DataFrame): Test dataframe

    Returns:
        DatasetDict: HuggingFace DatasetDict
    """
    dataset = DatasetDict({
        "train": Dataset.from_pandas(train),
        "val": Dataset.from_pandas(val),
        "test": Dataset.from_pandas(test),
    })

    # Map the labels to ids
    dataset = dataset.map(lambda e: {"labels": [ms_label2id[l] for l in e["labels"]]}, batched=True)

    # Drop __index_level_0__ column
    dataset = dataset.remove_columns(["__index_level_0__"])

    return dataset

dataset1 = create_hf_dataset(df1train, df1val, df1test)
dataset2 = create_hf_dataset(df2train, df2val, df2test)
dataset3 = create_hf_dataset(df3train, df3val, df3test)

# Save the dataset
dataset1.save_to_disk(os.path.join(paths.DATA_PATH_PREPROCESSED, "ms-diag/ms_diag_line"))
dataset2.save_to_disk(os.path.join(paths.DATA_PATH_PREPROCESSED, "ms-diag/ms_diag_all"))
dataset3.save_to_disk(os.path.join(paths.DATA_PATH_PREPROCESSED, "ms-diag/ms_diag_all_first_line_last"))

### Test Set no ms label

To later evaluate the validity of the "no_ms" label, I will manually check whether the "no_ms" labels in the test set are correct:

In [None]:
no_ms_text = dataset2["test"].filter(lambda e: e["labels"] == 3)["text"]
no_ms_text

# Summary

In [None]:
# Label distribution
print("Label distribution all:")
print(df2.labels.value_counts(), "\n\n")

print("Label distribution all_first_line_last:")
print(df3.labels.value_counts(), "\n\n")

print("Label distribution line:")
print(df1.labels.value_counts(), "\n\n")

## Pipeline Approach

To get a fair comparison, I need to retrain the line-classifier with the test rids excluded.
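
The idea is that no report used to test the MS-type classifier may have been seen while training the line classifier. A minimal pure-Python sketch of that filter, with invented rows and a made-up `test_rids_toy` set:

```python
# Toy line-label rows; rids, texts and labels are made up.
line_rows = [
    {"rid": "a", "text": "Diagnose: RRMS", "labels": "dm"},
    {"rid": "a", "text": "Verlauf stabil", "labels": "other"},
    {"rid": "b", "text": "Diagnose: PPMS", "labels": "dm"},
    {"rid": "c", "text": "Anamnese", "labels": "other"},
]
test_rids_toy = {"b"}  # stands in for the MS-diag test rids

# Rows of test rids are held out; everything else may be used for training.
held_out = [row for row in line_rows if row["rid"] in test_rids_toy]
trainable = [row for row in line_rows if row["rid"] not in test_rids_toy]

# Sanity check: no rid may appear on both sides.
overlap = {row["rid"] for row in trainable} & test_rids_toy
assert not overlap, f"leaking rids: {overlap}"

print(len(trainable), "trainable rows,", len(held_out), "held-out rows")
```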

In [None]:
# Save test rids to a file (sorted, so the file content is reproducible across runs)
with open(os.path.join(paths.DATA_PATH_PREPROCESSED, "ms-diag/test_rids.txt"), "w") as f:
    f.write("\n".join(sorted(test_rids)))

In [None]:
# Load line labelled dataset
from src.utils import load_line_label_data
from datasets import concatenate_datasets

line_labels = load_line_label_data()
line_labels_all = concatenate_datasets([line_labels["train"], line_labels["val"], line_labels["test"]])

# Create test set from line labelled dataset by using the test rids
line_labels_test = line_labels_all.filter(lambda e: e["rid"] in test_rids)

# Remove test rids from line labelled dataset
line_labels_all = line_labels_all.filter(lambda e: e["rid"] not in test_rids)

# Cast labels column to ClassLabel and split into train and test
line_labels_all = line_labels_all.class_encode_column("labels").train_test_split(test_size=0.1, shuffle=True, seed=42, stratify_by_column="labels")

# Assign correct splits
line_labels_all["val"] = line_labels_all["test"]
line_labels_all["test"] = line_labels_test
line_labels_all["all"] = line_labels["all"]

# Save the dataset
line_labels_all.save_to_disk(os.path.join(paths.DATA_PATH_PREPROCESSED, "line-label/line-label_clean_dataset_pipeline"))

# Prompting

The following cells define the task instruction, system prompt and examples for the MS type extraction task.

In [None]:
from src.utils import zero_shot_base, zero_shot_instruction, few_shot_base, few_shot_instruction

In [None]:
task_instruction = """Your task is to extract the diagnosis corresponding to a type of multiple sclerosis (MS) stated in a German medical report. The input for this task is a German medical report, and the output should be the type of MS.
There are three types of multiple sclerosis in German:
- primär progrediente Multiple Sklerose (PPMS)
- sekundär progrediente Multiple Sklerose (SPMS)
- schubförmig remittierende Multiple Sklerose (RRMS)

The type is provided in the text, and your task is to extract it. If you cannot match a type exactly, please answer with 'other'.
Your answer should solely consist of one of the following:
- primär progrediente Multiple Sklerose
- sekundär progrediente Multiple Sklerose
- schubförmige remittierende Multiple Sklerose
- other
"""

system_prompt = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.
Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.
Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct.
If you don’t know the answer to a question, please don’t share false information.
"""


with open(paths.DATA_PATH_PREPROCESSED/"ms-diag/task_instruction.txt", "w") as file:
    file.write(task_instruction)

with open(paths.DATA_PATH_PREPROCESSED/"ms-diag/system_prompt.txt", "w") as file:
    file.write(system_prompt)

In [None]:
# Examples

ppms_text = dataset2["test"].filter(lambda e: e["labels"] == 0)["text"][0][:200] + "..."
rrms_text = dataset2["test"].filter(lambda e: e["labels"] == 1)["text"][0][:200] + "..."
spms_text = dataset2["test"].filter(lambda e: e["labels"] == 2)["text"][0][:200] + "..."
other_text = dataset2["test"].filter(lambda e: e["labels"] == 3)["text"][0][:200] + "..."

examples = [{"text": ppms_text, "labels": "primär progrediente Multiple Sklerose"},
            {"text": rrms_text, "labels": "schubförmige remittierende Multiple Sklerose"},
            {"text": spms_text, "labels": "sekundär progrediente Multiple Sklerose"},
            {"text": other_text, "labels": "other"}]

# Write as UTF-8 and keep umlauts readable in the JSON file
with open(paths.DATA_PATH_PREPROCESSED/"ms-diag/examples.json", "w", encoding="utf-8") as file:
    json.dump(examples, file, ensure_ascii=False)
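
One way the saved instruction and examples might later be combined into a few-shot prompt is sketched below. The project's real templates live in src.utils (`few_shot_base`, `few_shot_instruction`) and are not shown in this notebook, so `build_few_shot_prompt` and its "Input:/Output:" layout are purely hypothetical, and the toy texts are invented.

```python
def build_few_shot_prompt(instruction, examples, report):
    """Hypothetical template: instruction, worked examples, then the query."""
    parts = [instruction.strip(), ""]
    for example in examples:
        parts.append(f"Input: {example['text']}")
        parts.append(f"Output: {example['labels']}")
        parts.append("")
    parts.append(f"Input: {report}")
    parts.append("Output:")  # left open for the model to complete
    return "\n".join(parts)


toy_instruction = "Extract the MS type from the German report."
toy_examples = [
    {"text": "Diagnose: primär progrediente MS...",
     "labels": "primär progrediente Multiple Sklerose"},
    {"text": "Kopfschmerzen seit 2 Wochen...", "labels": "other"},
]

prompt = build_few_shot_prompt(toy_instruction, toy_examples, "Schubförmige MS, ED 2010")
print(prompt)
```

The examples use the same "text" and "labels" keys as examples.json, so the saved file could be passed in directly after `json.load`.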