In [None]:
import os
import re
import numpy as np
import pandas as pd
import pydicom
from tqdm import trange, tqdm
from PIL import Image

import matplotlib.pyplot as plt

# Locations on disk. Every entry directly under dcm_dir is treated as one
# series UID; dcm_csv is the path the final pairing table is written to.
data_dir = '/opt/gpudata/midrc-sift'
dcm_dir = os.path.join(data_dir, 'dcm')
dcm_csv = os.path.join(data_dir, "annotated_dcms.csv")

# Sort the directory listing so series are always processed in the same order.
series_uids = os.listdir(dcm_dir)
series_uids.sort()

# Table whose "annotation" column lists the annotation DICOM filenames.
obj_ids = pd.read_csv(os.path.join(data_dir, 'obj_ids.csv'))

In [None]:
# Sanity check: every listed annotation filename must carry the .dcm extension.
assert obj_ids.loc[:, "annotation"].str.endswith(".dcm").all()

# Keep the annotation filenames in a set for fast membership tests later on.
anns = set(obj_ids.loc[:, "annotation"].tolist())

In [None]:
# For each series directory, recursively gather every *.dcm file as a
# (filename, full path) pair. dcms[k] corresponds to series_uids[k].
dcms = []
for series_uid in tqdm(series_uids):
    series_root = os.path.join(dcm_dir, series_uid)
    dcms.append([
        (name, os.path.join(parent, name))
        for parent, _subdirs, names in os.walk(series_root)
        for name in names
        if name.endswith(".dcm")
    ])

In [None]:
# Pair each series' annotation DICOM with the image DICOM it refers to.
# Produces one (series_uid, image_uid, image_fpath) tuple per series, in the
# same order as series_uids. Any series that cannot be paired unambiguously
# aborts the cell with an AssertionError naming that series.

# Guard against hidden notebook state: zip() would silently truncate if dcms
# was built from a different series_uids listing.
assert len(series_uids) == len(dcms), (
    f"dcms has {len(dcms)} entries but there are {len(series_uids)} series; "
    "re-run the file discovery cell"
)

im_dcms = []
for series_uid, series_dcms in tqdm(list(zip(series_uids, dcms))):
    # find annotation and check there's exactly one per series
    ann_idxs = [i for i, (fname, _) in enumerate(series_dcms) if fname in anns]
    assert len(ann_idxs) == 1, (
        f"{series_uid}: expected exactly one annotation dcm, found {len(ann_idxs)}"
    )
    ann_idx = ann_idxs[0]
    ann_fname = series_dcms[ann_idx][0]

    # every other dcm in the series is a candidate image
    candidates = [entry for i, entry in enumerate(series_dcms) if i != ann_idx]

    # pair annotation with image dcm using the common uid embedded in the
    # annotation filename ("...__<uid>__seg.dcm"). The dot before "dcm" is
    # escaped and the uid must be non-empty: an empty uid would be a substring
    # of every path and match all candidates.
    matches = re.findall(r"__([\d.]+?)__seg\.dcm", ann_fname)
    assert len(matches) == 1, (
        f"{series_uid}: cannot extract a unique image uid from {ann_fname!r}"
    )
    common_uid = matches[0]

    # first attempt: the uid appears somewhere in the candidate's path
    series_im_dcms = [
        (fname, fpath) for fname, fpath in candidates if common_uid in fpath
    ]

    # if no matches but there's only one other dcm, use that
    if not series_im_dcms and len(candidates) == 1:
        series_im_dcms = list(candidates)

    # if there are still no matches, read the dcm and try to match based on UIDs in the file
    if not series_im_dcms:
        for fname, fpath in candidates:
            # only the header is needed, so skip decoding the pixel data
            dcm = pydicom.dcmread(fpath, stop_before_pixels=True)
            # a file without SOPInstanceUID simply does not match
            if common_uid == getattr(dcm, "SOPInstanceUID", None):
                series_im_dcms.append((fname, fpath))

    # if there are still no matches (or the match is ambiguous), give up
    assert len(series_im_dcms) == 1, (
        f"{series_uid}: expected exactly one image dcm for uid {common_uid}, "
        f"found {len(series_im_dcms)}"
    )
    im_dcms.append((
        series_uid,
        common_uid,
        series_im_dcms[0][1], # fpath
    ))

In [None]:
# Tabulate the pairing results: one row per (series_uid, image_uid, fpath) tuple.
annotated_dcms = pd.DataFrame(
    data=im_dcms,
    columns=["series_uid", "image_uid", "fpath"],
)

In [None]:
# Persist the pairing table to dcm_csv, leaving out the DataFrame's row index.
annotated_dcms.to_csv(path_or_buf=dcm_csv, index=False)