In [30]:
import json

anno_path = "/root//autodl-tmp/datasets/mimic_cxr/annotation.json"
images_path = "/root/autodl-tmp/datasets/mimic_cxr/images.json"
with open(anno_path, "r") as f:
    anno = json.load(f)

In [35]:
import re
from pprint import pprint
from copy import deepcopy

def anno_merge(anno, tag_name="study_id"):
    cur_tag = ""
    anno_merged = []
    for item in anno:
        if item[tag_name] != cur_tag:
            cur_tag = item[tag_name]
            anno_merged.append(deepcopy(item))
        else:
            anno_merged[-1]['image_path'].extend(item['image_path'])
    return anno_merged


def anno_filter(anno):
    anno_filtered = []
    for item in anno:
        if len(item['image_path']) > 2:
            continue
        if re.search(r'compare|previous', item['report'], re.IGNORECASE):
            continue
        anno_filtered.append(deepcopy(item))
        anno_filtered[-1]['report'] = anno_filtered[-1]['report'].replace('\n', '')
    return anno_filtered

anno_merged_train = anno_merge(anno["train"], tag_name="study_id")
anno_filtered_train = anno_filter(anno_merged_train)
print("Trainset filter before", len(anno["train"]))
print("Trainset filter after", len(anno_filtered_train))
anno_merged_val = anno_merge(anno["val"], tag_name="study_id")
anno_filtered_val = anno_filter(anno_merged_val)
print("Valset filter before", len(anno["val"]))
print("Valset filter after", len(anno_filtered_val))
anno_merged_test = anno_merge(anno["test"], tag_name="study_id")
anno_filtered_test = anno_filter(anno_merged_test)
print("Testset filter before", len(anno["test"]))
print("Testset filter after", len(anno_filtered_test))
anno_filtered = {
    "train": anno_filtered_train,
    "val": anno_filtered_val,
    "test": anno_filtered_test
}
pprint(anno_filtered["train"][0])
anno_filtered_path = anno_path.replace("annotation.json", "annotation_filtered.json")
with open(anno_filtered_path, "w") as f:
    json.dump(anno_filtered, f, indent=4)

Trainset filter before 270790
Trainset filter after 106812
Valset filter before 2130
Valset filter after 844
Testset filter before 3858
Testset filter after 1429
{'id': '02aa804e-bde0afdd-112c0b34-7bc16630-4e384014',
 'image_path': ['p10/p10000032/s50414267/02aa804e-bde0afdd-112c0b34-7bc16630-4e384014.jpg',
                'p10/p10000032/s50414267/174413ec-4ec4c1f7-34ea26b7-c5f994f8-79ef1962.jpg'],
 'label_vec': [0.0,
               0.0,
               0.0,
               0.0,
               0.0,
               0.0,
               0.0,
               0.0,
               1.0,
               0.0,
               0.0,
               0.0,
               0.0,
               0.0],
 'report': 'There is no focal consolidation, pleural effusion or '
           'pneumothorax.  Bilateral nodular opacities that most likely '
           'represent nipple shadows. The cardiomediastinal silhouette is '
           'normal.  Clips project over the left lung, potentially within the '
           'breast. 

In [2]:
import os

def get_samples_by_disease(anno, disease, split="train", mode="pure"):
    label_names = [
        "Atelectasis",
        "Cardiomegaly",
        "Consolidation",
        "Edema",
        "Enlarged Cardiomediastinum",
        "Fracture",
        "Lung Lesion",
        "Lung Opacity",
        "Pleural Effusion",
        "Pneumonia",
        "Pneumothorax",
        "Pleural Other",
        "Support Devices",
        "No Finding"
    ]
    disease_samples = []
    disease_idx = label_names.index(disease)
    for item in anno[split]:
        item["image_path"] = [os.path.join("/root/autodl-tmp/datasets/mimic_cxr/images", image_path) for image_path in item["image_path"]]
        if item["label_vec"][disease_idx] == 1:
            if mode == "pure":
                if 1 not in [i for idx, i in enumerate(item["label_vec"]) if idx != disease_idx]:
                    disease_samples.append(item)
            else:
                disease_samples.append(item)
    return disease_samples

In [5]:
import os
import json

diseases = [
    "Atelectasis",
    "Cardiomegaly",
    "Consolidation",
    "Edema",
    "Enlarged Cardiomediastinum",
    "Fracture",
    "Lung Lesion",
    "Lung Opacity",
    "Pleural Effusion",
    "Pneumonia",
    "Pneumothorax",
    "Pleural Other",
    "Support Devices",
    "No Finding"
]

save_path = "/root/autodl-tmp/wh/med_report_R1/assets/disease_samples_test"
for disease in diseases:
    disease_samples = get_samples_by_disease(anno, disease, split="test", mode="pure")
    print(f"Number of samples for {disease}: {len(disease_samples)}")
    save_file = os.path.join(save_path, f"{disease}.json")
    with open(save_file, "w") as f:
        json.dump(disease_samples, f, indent=4)
    print(f"Saved {disease} samples to {save_file}")

Number of samples for Atelectasis: 69
Saved Atelectasis samples to /root/autodl-tmp/wh/med_report_R1/assets/disease_samples_test/Atelectasis.json
Number of samples for Cardiomegaly: 131
Saved Cardiomegaly samples to /root/autodl-tmp/wh/med_report_R1/assets/disease_samples_test/Cardiomegaly.json
Number of samples for Consolidation: 36
Saved Consolidation samples to /root/autodl-tmp/wh/med_report_R1/assets/disease_samples_test/Consolidation.json
Number of samples for Edema: 130
Saved Edema samples to /root/autodl-tmp/wh/med_report_R1/assets/disease_samples_test/Edema.json
Number of samples for Enlarged Cardiomediastinum: 11
Saved Enlarged Cardiomediastinum samples to /root/autodl-tmp/wh/med_report_R1/assets/disease_samples_test/Enlarged Cardiomediastinum.json
Number of samples for Fracture: 41
Saved Fracture samples to /root/autodl-tmp/wh/med_report_R1/assets/disease_samples_test/Fracture.json
Number of samples for Lung Lesion: 28
Saved Lung Lesion samples to /root/autodl-tmp/wh/med_repo