In [1]:
import os

os.environ["HF_HOME"] = "/projects/bhuang/.cache/huggingface"
# os.environ["OMP_NUM_THREADS"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [10]:
import numpy as np
import pandas as pd
from datasets import Dataset, load_dataset

In [3]:
input_files = {
    # "train": "/home/bhuang/icd_10/data/synthetic/synthetic-mistral_large_instruct_2407-240909-processed-train.jsonl",
    "train": "/home/bhuang/icd_10/data/synthetic/synthetic-mistral_large_instruct_2407-240909-processed-train-10k.jsonl",
    "validation": "/home/bhuang/icd_10/data/synthetic/synthetic-mistral_large_instruct_2407-240909-processed-validation.jsonl",
    "test": "/home/bhuang/icd_10/data/synthetic/synthetic-mistral_large_instruct_2407-240909-processed-test.jsonl",
    # "validation": "/home/bhuang/icd_10/data/synthetic/synthetic-mistral_large_instruct_2407-240913-processed-validation.jsonl",
    # "test": "/home/bhuang/icd_10/data/synthetic/synthetic-mistral_large_instruct_2407-240913-processed-test.jsonl",
}

dataset = load_dataset("json", data_files=input_files)
dataset

Generating train split: 10000 examples [00:00, 129206.58 examples/s]
Generating validation split: 1000 examples [00:00, 112261.23 examples/s]
Generating test split: 1000 examples [00:00, 99589.32 examples/s]


DatasetDict({
    train: Dataset({
        features: ['labels', 'text', 'has_diso'],
        num_rows: 10000
    })
    validation: Dataset({
        features: ['labels', 'text', 'has_diso'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['labels', 'text', 'has_diso'],
        num_rows: 1000
    })
})

In [4]:
train_dataset = dataset["train"]
# flatten
train_total_labels = sum(train_dataset["labels"], [])
# overall label frequecies
uniques, counts = np.unique(train_total_labels, return_counts=True)

In [5]:
# split labels into head/mid/tail
labels_dict = {"head": set(), "medium": set(), "tail": set()}

for u, c in zip(uniques, counts):
    if c >= 100:
        labels_dict["head"].add(u)
    if 20 <= c < 100:
        labels_dict["medium"].add(u)
    else:
        labels_dict["tail"].add(u)

[len(v) for v in labels_dict.values()]

[58, 390, 543]

In [None]:
# num of examples containing head/mid/tail labels
for name, labels in labels_dict.items():
    ds_ = train_dataset.filter(
        lambda x: any(l in labels for l in x["labels"]), num_proc=32
    )
    print(ds_.num_rows)

In [None]:
# split valid set
valid_dataset = dataset["validation"]

for name, labels in labels_dict.items():
    # todo: any or all
    ds_ = valid_dataset.filter(
        lambda x: any(l in labels for l in x["labels"]), num_proc=32
    )
    print(ds_.num_rows)

    output_file = f"/home/bhuang/icd_10/data/synthetic/synthetic-mistral_large_instruct_2407-240909-processed-validation-{name}.jsonl"
    ds_.to_json(output_file, orient="records", lines=True, force_ascii=False)

In [None]:
# split test set
test_dataset = dataset["test"]

for name, labels in labels_dict.items():
    # todo: any or all
    ds_ = test_dataset.filter(
        lambda x: any(l in labels for l in x["labels"]), num_proc=32
    )
    print(ds_.num_rows)

    output_file = f"/home/bhuang/icd_10/data/synthetic/synthetic-mistral_large_instruct_2407-240909-processed-test-{name}.jsonl"
    ds_.to_json(output_file, orient="records", lines=True, force_ascii=False)

In [11]:
# regenerate code_freqs for head/mid/tail

for name, labels in labels_dict.items():
    mask = np.isin(uniques, np.array(list(labels)))
    target_counts = counts[mask]
    target_uniques = uniques[mask]
    # renormalized freq
    target_freqs = target_counts / target_counts.sum()

    # save
    df = pd.DataFrame({"code": target_uniques, "freq": target_freqs})
    output_file = f"/home/bhuang/icd_10/data/synthetic/synthetic-mistral_large_instruct_2407-240909-processed-train-10k-code_freqs-{name}.jsonl"
    df.to_json(output_file, orient="records", lines=True, force_ascii=False)