# Dataset Generation

## Setup

In [7]:
import os
import random
from glob import glob
from shutil import copytree, rmtree

import cv2 as cv
import numpy as np
import pandas as pd
from PIL import Image
from roifile import roiread

## Processing `cts-01` folder

In [2]:
# Each table groups the per-severity data numbers by patient: one inner list
# per patient, holding every data number that belongs to that patient.
# The position of the inner list (1-based) is used later as the patient number.
NORMAL_PATIENTS = [
    *([n] for n in range(1, 9)),
    [9, *range(71, 81)],
    [10, *range(81, 91)],
    *([n] for n in range(11, 17)),
    [17, *range(91, 101)],
    [18, *range(101, 111)],
    [19],
    [20],
    # Patients whose ten data numbers are consecutive.
    *(list(range(first, first + 10)) for first in (21, 31, 41, 51, 61, 111, 121)),
]
MILD_PATIENTS = [
    *([n] for n in range(1, 10)),
    *(list(range(first, first + 10)) for first in (10, 20, 30)),
]
MODERATE_PATIENTS = [
    *([n] for n in range(1, 18)),
    *(list(range(first, first + 10)) for first in (18, 28, 38, 48, 58, 68)),
]
SEVERE_PATIENTS = [
    *([n] for n in range(1, 12)),
    list(range(12, 22)),
]

In [3]:
# Re-organise the flat `cts-01` folder into
#   <TGT_DATASET>/<severity>/<severity>-<patient no>/<muscle>-<n>.tiff
#   <TGT_DATASET>/<severity>/<severity>-<patient no>/<muscle>-mask-<n>.tiff
# where the mask is the first ImageJ ROI of the source TIFF, filled with 255.
SRC_DATASET = "../assets/cts-01"
TGT_DATASET = "../assets/cts-01-processed"

# Always rebuild the target from scratch so re-running the cell is idempotent.
assert os.path.exists(SRC_DATASET)
if os.path.exists(TGT_DATASET):
    rmtree(TGT_DATASET)
os.makedirs(TGT_DATASET)

# (severity, data number) -> (severity, 1-based patient number,
#                             1-based index of the data within that patient)
patients_by_severity = {
    "normal": NORMAL_PATIENTS,
    "mild": MILD_PATIENTS,
    "moderate": MODERATE_PATIENTS,
    "severe": SEVERE_PATIENTS,
}
data2patient = dict()
for severity, patients in patients_by_severity.items():
    for i, datanos in enumerate(patients):
        for j, datano in enumerate(datanos):
            data2patient[(severity, datano)] = (severity, i + 1, j + 1)

# Severity abbreviation used as the first "_"-separated field of the file name.
abbv2sev = {"NL": "normal", "Mild": "mild", "Mod": "moderate", "Sev": "severe"}

src_files = glob(os.path.join(SRC_DATASET, "*.tif"))
for src_file in src_files:
    # File name fields, separated by "_": severity abbreviation, data number,
    # muscle code ("T" means thenar, anything else is treated as hypothenar).
    src_filename = os.path.splitext(os.path.basename(src_file))[0]
    splits = src_filename.split("_")
    severity = abbv2sev[splits[0]]
    datano = int(splits[1])
    muscle = "thenar" if splits[2] == "T" else "hypothenar"

    patient = data2patient[(severity, datano)]
    severity_path = os.path.join(TGT_DATASET, severity)
    patient_path = os.path.join(severity_path, f"{patient[0]}-{patient[1]}")
    # Creates the severity folder as well when it does not exist yet.
    os.makedirs(patient_path, exist_ok=True)

    # Re-save the image under its new name; the context manager releases the
    # file handle instead of leaving one open per processed file.
    tgt_file = os.path.join(patient_path, f"{muscle}-{patient[2]}.tiff")
    with Image.open(src_file) as src_img:
        src_img.save(tgt_file, format="TIFF")

    # Only the first ROI stored in the TIFF is used. Fail loudly when there is
    # none: otherwise `roi` would be unbound, or worse, still hold the ROI of
    # the previously processed file and yield a wrong mask.
    roi = next(iter(roiread(src_file)), None)
    if roi is None:
        raise ValueError(f"No ImageJ ROI found in {src_file}")
    # The integer coordinates are offset by the ROI's left/top to place them
    # in image space.
    coords = np.array(roi.integer_coordinates, "int32")
    coords[:, 0] += roi.left
    coords[:, 1] += roi.top

    # cv.imread signals failure by returning None rather than raising.
    org_img = cv.imread(src_file, flags=cv.IMREAD_GRAYSCALE)
    if org_img is None:
        raise ValueError(f"OpenCV could not read {src_file}")
    mask_arr = np.zeros_like(org_img)
    cv.fillPoly(mask_arr, [coords], (255,))
    mask_img = Image.fromarray(mask_arr)
    mask_file = os.path.join(patient_path, f"{muscle}-mask-{patient[2]}.tiff")
    mask_img.save(mask_file, format="TIFF")

## Processing `cts-02` folder

In [4]:
# Re-organise the `cts-02` folder (one sub-folder per case) into
#   <TGT_DATASET>/<severity>/<severity>-<patient no>/<muscle>-<n>.tiff
#   <TGT_DATASET>/<severity>/<severity>-<patient no>/<muscle>-mask-<n>.tiff
# where the mask is the first ImageJ ROI of the source TIFF, filled with 255.
SRC_DATASET = "../assets/cts-02"
TGT_DATASET = "../assets/cts-02-processed"

# Always rebuild the target from scratch so re-running the cell is idempotent.
assert os.path.exists(SRC_DATASET)
if os.path.exists(TGT_DATASET):
    rmtree(TGT_DATASET)
os.makedirs(TGT_DATASET)

df = pd.read_excel(
    os.path.join(SRC_DATASET, "Data list.xlsx"),
    names=["Case No", "Lesion Side", "Severity"],
)
abbv2side = {"Lt": "left", "Rt": "right"}
severity_count = {"normal": 0, "mild": 0, "moderate": 0, "severe": 0}
# (case number, side) -> (severity, running patient number within that severity).
# Every spreadsheet row, i.e. every (case, side) pair, becomes its own patient.
data2patient = dict()
for _, caseno, side, severity in df.itertuples():
    side = abbv2side[side[:2]]
    severity = severity.strip()
    severity_count[severity] += 1
    data2patient[(caseno, side)] = (severity, severity_count[severity])

src_files = glob(os.path.join(SRC_DATASET, "*", "*.tif"))
for src_file in src_files:
    # The case number is the second whitespace-separated token of the case
    # folder name, up to the first "_". Parse only the folder's own name so a
    # space anywhere else in the path cannot shift the token index.
    src_dirname = os.path.dirname(src_file)
    caseno = int(os.path.basename(src_dirname).split()[1].split("_")[0])

    # File name tokens: side abbreviation first, muscle code second ("HT" in it
    # means hypothenar), and the data number in parentheses in the last token.
    src_filename = os.path.splitext(os.path.basename(src_file))[0]
    splits = src_filename.split()
    side = abbv2side[splits[0][:2]]
    muscle = "hypothenar" if "HT" in splits[1] else "thenar"
    datano = int(splits[-1].split("(")[1].split(")")[0])

    patient = data2patient[(caseno, side)]
    severity_path = os.path.join(TGT_DATASET, patient[0])
    patient_path = os.path.join(severity_path, f"{patient[0]}-{patient[1]}")
    # Creates the severity folder as well when it does not exist yet.
    os.makedirs(patient_path, exist_ok=True)

    # Re-save the image under its new name; the context manager releases the
    # file handle instead of leaving one open per processed file.
    tgt_file = os.path.join(patient_path, f"{muscle}-{datano}.tiff")
    with Image.open(src_file) as src_img:
        src_img.save(tgt_file, format="TIFF")

    # Only the first ROI stored in the TIFF is used. Fail loudly when there is
    # none: otherwise `roi` would be unbound, or worse, still hold the ROI of
    # the previously processed file and yield a wrong mask.
    roi = next(iter(roiread(src_file)), None)
    if roi is None:
        raise ValueError(f"No ImageJ ROI found in {src_file}")
    # The integer coordinates are offset by the ROI's left/top to place them
    # in image space.
    coords = np.array(roi.integer_coordinates, "int32")
    coords[:, 0] += roi.left
    coords[:, 1] += roi.top

    # cv.imread signals failure by returning None rather than raising.
    org_img = cv.imread(src_file, flags=cv.IMREAD_GRAYSCALE)
    if org_img is None:
        raise ValueError(f"OpenCV could not read {src_file}")
    mask_arr = np.zeros_like(org_img)
    cv.fillPoly(mask_arr, [coords], (255,))
    mask_img = Image.fromarray(mask_arr)
    mask_file = os.path.join(patient_path, f"{muscle}-mask-{datano}.tiff")
    mask_img.save(mask_file, format="TIFF")

## Generating CTSDiag

In [9]:
# Merge the processed datasets into one train/test split, done per patient so
# that all data of a patient ends up on the same side of the split.
SRC_DATASETS = ["../assets/cts-01-processed", "../assets/cts-02-processed"]
TGT_DATASET = "../data/CTSDiag"
# (train fraction, test fraction); only the test fraction is read below, the
# training set receives whatever remains.
TRAIN_TEST_RATIO = (0.8, 0.2)
SEVERITIES = ["normal", "mild", "moderate", "severe"]

# Fixed seed so the shuffles, and therefore the split, are reproducible.
random.seed(42)

# Always rebuild the target from scratch; makedirs also creates a missing
# parent folder (e.g. `../data` on a fresh checkout).
assert all(os.path.exists(s) for s in SRC_DATASETS)
if os.path.exists(TGT_DATASET):
    rmtree(TGT_DATASET)
os.makedirs(TGT_DATASET)

# Patients are (source dataset, severity, source patient number) tuples,
# bucketed by whether they have exactly one data item or not.
single_data = {severity: [] for severity in SEVERITIES}
multiple_data = {severity: [] for severity in SEVERITIES}
severity_count = {severity: 0 for severity in SEVERITIES}
# Source patient -> (severity, patient number renumbered across all datasets).
src2tgt = dict()
for src_dataset in SRC_DATASETS:
    for severity in SEVERITIES:
        severity_path = os.path.join(src_dataset, severity)
        # Severity folders are created lazily by the processing cells, so a
        # dataset without patients of this severity simply has no folder.
        if not os.path.isdir(severity_path):
            continue
        # Only patient folders count; stray files would break the sort key.
        patient_dirs = sorted(
            (
                entry
                for entry in os.listdir(severity_path)
                if os.path.isdir(os.path.join(severity_path, entry))
            ),
            key=lambda p: int(p.split("-")[1]),
        )
        for patient_dir in patient_dirs:
            patient_path = os.path.join(severity_path, patient_dir)
            src_patient = (src_dataset, severity, int(patient_dir.split("-")[1]))
            # Four files per data item -- presumably image + mask for each of
            # the thenar and hypothenar muscles; confirm against the inputs.
            num_data = len(glob(os.path.join(patient_path, "*.tiff"))) // 4
            if num_data == 1:
                single_data[severity].append(src_patient)
            else:
                multiple_data[severity].append(src_patient)
            severity_count[severity] += 1
            src2tgt[src_patient] = (severity, severity_count[severity])

# stats[severity][kind] = (number of train patients, number of test patients)
stats = {severity: dict() for severity in SEVERITIES}
for severity in SEVERITIES:
    # "single" is shuffled before "multiple" for every severity; keep this
    # order, it determines which random numbers each shuffle consumes.
    for kind, patients in (
        ("single", single_data[severity]),
        ("multiple", multiple_data[severity]),
    ):
        shuffled = random.sample(patients, k=len(patients))
        num_tests = int(len(shuffled) * TRAIN_TEST_RATIO[1])
        num_trains = len(shuffled) - num_tests
        stats[severity][kind] = (num_trains, num_tests)

        for i, src_patient in enumerate(shuffled):
            split = "train" if i < num_trains else "test"
            src_dataset, src_severity, src_pno = src_patient
            src_patient_path = os.path.join(
                src_dataset, src_severity, f"{src_severity}-{src_pno}"
            )
            tgt_severity, tgt_pno = src2tgt[src_patient]
            tgt_patient_path = os.path.join(
                TGT_DATASET, split, tgt_severity, f"{tgt_severity}-{tgt_pno}"
            )
            copytree(src_patient_path, tgt_patient_path)
            print(f"{src_patient_path} -> {tgt_patient_path}")

print()
print(f"{'Severity':9s}{'Type':9s}{'Train':6s}{'Test':6s}{'Total':6s}")
for severity in SEVERITIES:
    for kind in ["single", "multiple"]:
        num_trains, num_tests = stats[severity][kind]
        num_total = num_trains + num_tests
        print(
            f"{severity.capitalize():9s}"
            + f"{kind.capitalize():9s}"
            + f"{str(num_trains):6s}"
            + f"{str(num_tests):6s}"
            + f"{str(num_total):6s}"
        )

../assets/cts-01-processed/normal/normal-4 -> ../data/CTSDiag/train/normal/normal-4
../assets/cts-01-processed/normal/normal-1 -> ../data/CTSDiag/train/normal/normal-1
../assets/cts-01-processed/normal/normal-14 -> ../data/CTSDiag/train/normal/normal-14
../assets/cts-01-processed/normal/normal-5 -> ../data/CTSDiag/train/normal/normal-5
../assets/cts-01-processed/normal/normal-20 -> ../data/CTSDiag/train/normal/normal-20
../assets/cts-01-processed/normal/normal-16 -> ../data/CTSDiag/train/normal/normal-16
../assets/cts-01-processed/normal/normal-3 -> ../data/CTSDiag/train/normal/normal-3
../assets/cts-01-processed/normal/normal-2 -> ../data/CTSDiag/train/normal/normal-2
../assets/cts-01-processed/normal/normal-11 -> ../data/CTSDiag/train/normal/normal-11
../assets/cts-01-processed/normal/normal-15 -> ../data/CTSDiag/train/normal/normal-15
../assets/cts-01-processed/normal/normal-13 -> ../data/CTSDiag/train/normal/normal-13
../assets/cts-01-processed/normal/normal-19 -> ../data/CTSDiag/t