# Data Preprocessing for VinDr-Mammo

## Step 0: Download the Dataset
1. Download the dataset from [VinDr-Mammo on PhysioNet](https://www.physionet.org/content/vindr-mammo/1.0.0/).
2. Place the images in the following directory:

    ```
    /workspace/data/VinDr-Mammo/images
    ```

## Step 1: Convert DICOM to NIfTI
The following step converts the DICOM files located in ``/workspace/data/VinDr-Mammo/images`` to NIfTI format and saves them in ``/workspace/data/VinDr-Mammo/nifti``.

In [None]:
import os
import SimpleITK as sitk
from tqdm import tqdm
import pydicom

In [None]:
root = '/workspace/data/VinDr-Mammo'
dcm_root = os.path.join(root, 'images')
result_root = os.path.join(root, 'nifti')

# DICOM tags used to name each output file "<laterality>_<view>.nii.gz".
VIEW_POSITION_TAG = (0x0018, 0x5101)     # (0018,5101) View Position
IMAGE_LATERALITY_TAG = (0x0020, 0x0062)  # (0020,0062) Image Laterality

pats = os.listdir(dcm_root)
for pat in tqdm(pats):
    output_folder = os.path.join(result_root, pat)
    os.makedirs(output_folder, exist_ok=True)

    for img in os.listdir(os.path.join(dcm_root, pat)):
        if "dicom" not in img:
            continue
        img_path = os.path.join(dcm_root, pat, img)
        # ``pydicom.read_file`` was removed in pydicom 3.0; ``dcmread`` is the supported
        # reader. Only header tags are needed here (SimpleITK decodes the pixels below),
        # so pixel data is not loaded.
        dcm_data = pydicom.dcmread(img_path, stop_before_pixels=True)
        view = dcm_data[VIEW_POSITION_TAG].value
        laterality = dcm_data[IMAGE_LATERALITY_TAG].value

        output_path = os.path.join(output_folder, laterality + "_" + view + ".nii.gz")
        img_sitk = sitk.ReadImage(img_path)
        sitk.WriteImage(img_sitk, output_path, True)

## Step 2: Get Bounding Box Image

In [None]:
import os
from tqdm import tqdm
import SimpleITK as sitk
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import ast
import math

In [None]:
# Step 2 input: NIfTI mammograms from Step 1. Output: one box mask per finding.
root = '/workspace/data/VinDr-Mammo'
img_root, result_root = (os.path.join(root, sub) for sub in ('nifti', 'bbox_nifti'))
os.makedirs(result_root, exist_ok=True)

In [None]:
findings_all = pd.read_csv(os.path.join(root, 'finding_annotations.csv'))
# Keep only rows that carry an actual finding. ``.copy()`` makes ``findings`` an
# independent frame, so adding a column below is not an assignment on a slice of
# ``findings_all`` (which triggers pandas' SettingWithCopyWarning).
findings = findings_all[findings_all['finding_categories'] != "['No Finding']"].copy()
# 1-based index of each finding within its image, in CSV row order.
findings['finding_num'] = findings.groupby('image_id').cumcount() + 1

In [None]:
# Persist the finding rows with their per-image index; the split section reads this file back.
findings_with_num_path = os.path.join(root, 'finding_annotations+finding_num.csv')
findings.to_csv(findings_with_num_path, index=False)

In [None]:
for row_idx, row in tqdm(findings.iterrows(), total=len(findings)):
    pat_id = row['study_id']
    view = row['view_position']
    LR = row['laterality']
    finding_num = row['finding_num']
    # Findings without a BI-RADS assessment cannot be labelled; skip them before
    # touching their coordinates.
    if pd.isna(row['finding_birads']):
        continue

    # Snap the float box outward to whole pixels. Clamp the lower edges at 0: a
    # slightly negative coordinate would otherwise be a negative (wrap-around)
    # slice start and silently produce an empty box.
    ymin = max(0, math.floor(row['ymin']))
    ymax = math.ceil(row['ymax'])
    xmin = max(0, math.floor(row['xmin']))
    xmax = math.ceil(row['xmax'])

    birads = row['finding_birads'].replace(" ", "-")
    finding_list = [s.replace(' ', '-') for s in ast.literal_eval(row['finding_categories'])]

    img_path = os.path.join(img_root, pat_id, LR + "_" + view + ".nii.gz")
    if not os.path.isfile(img_path):
        continue
    img_sitk = sitk.ReadImage(img_path)
    img_arr = sitk.GetArrayFromImage(img_sitk)

    # Box mask on the image grid; both box edges are inclusive.
    bbox = np.zeros_like(img_arr)
    bbox[:, ymin:ymax + 1, xmin:xmax + 1] = 1
    if not np.any(bbox):
        print(row_idx, pat_id, ",,,empty bbox,,,")

    output_folder = os.path.join(result_root, pat_id)
    os.makedirs(output_folder, exist_ok=True)

    filename = LR + "_" + view + "_bbox-" + str(finding_num) + "_" + birads + "_" + '+'.join(finding_list) + ".nii.gz"
    bbox_sitk = sitk.GetImageFromArray(bbox)
    # Give the mask the image's spacing/origin/direction so the two stay aligned.
    bbox_sitk.CopyInformation(img_sitk)
    sitk.WriteImage(bbox_sitk, os.path.join(output_folder, filename))

## Step 3: Crop Breast Mask
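Each image is cropped to the breast region and min-max normalized to [0, 1]. The breast region is the bounding box of the largest connected component of an Otsu threshold applied after Gaussian smoothing. The crops and their breast masks are saved to ``/workspace/data/VinDr-Mammo/cropped_nifti``.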

In [None]:
import os
import SimpleITK as sitk
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from skimage.filters import threshold_otsu, gaussian
from skimage.measure import label, regionprops
from skimage.exposure import equalize_adapthist


In [None]:
def crop(data, mask=None):
    """Crop ``data`` (and optionally ``mask``) to the box of the largest foreground region.

    The foreground is found by Otsu-thresholding a Gaussian-smoothed (sigma=10)
    copy of ``data``.

    Returns:
        ``(cropped_data, cropped_breast_mask)``, or, when ``mask`` is given,
        ``(cropped_data, cropped_breast_mask, cropped_mask)``. The breast mask is the
        whole uint8 Otsu mask restricted to that box, so smaller components that
        fall inside the box are kept too.
    """
    smoothed = gaussian(data, sigma=10)
    breast_mask = (smoothed > threshold_otsu(smoothed)).astype(np.uint8)
    largest = max(regionprops(label(breast_mask)), key=lambda region: region.area)
    top, left, bottom, right = largest.bbox
    window = (slice(top, bottom), slice(left, right))

    cropped = (data[window], breast_mask[window])
    if mask is not None:
        cropped += (mask[window],)
    return cropped

def minmax_normalization(data):
    """Linearly rescale ``data`` to [0, 1].

    A constant input (max == min) has no dynamic range. Instead of dividing by
    zero, which yields NaNs and a RuntimeWarning, an all-zero float array of the
    same shape is returned.
    """
    data_min = np.min(data)
    value_range = np.max(data) - data_min
    if value_range == 0:
        return np.zeros_like(data, dtype=float)
    return (data - data_min) / value_range

def truncation_normalization(data, mask):
    """Clip ``data`` to the breast-ROI [5th, 99th] percentile range and rescale to [0, 1].

    Percentiles are taken over pixels where ``mask != 0``. Pixels outside the mask
    are set to 0 in the result. If the two percentiles coincide (a flat ROI), the
    result is all zeros instead of NaNs from a division by zero.

    Raises:
        ValueError: if ``mask`` has no non-zero pixels.
    """
    roi_values = data[mask != 0]
    if roi_values.size == 0:
        raise ValueError("truncation_normalization: mask has no foreground pixels")

    p_min = np.percentile(roi_values, 5)
    p_max = np.percentile(roi_values, 99)
    truncated = np.clip(data, p_min, p_max)
    if p_max > p_min:
        normalized = (truncated - p_min) / (p_max - p_min)
    else:
        normalized = np.zeros_like(truncated, dtype=float)
    normalized[mask == 0] = 0

    return normalized

def clahe(data, clip_limit=0.01):
    """Apply contrast-limited adaptive histogram equalization via skimage ``equalize_adapthist``."""
    enhanced = equalize_adapthist(data, clip_limit=clip_limit)
    return enhanced

def save_comparison_images(original_arr, mammogram_arr, breast_mask, file_path, mass_mask=None):
    """Save a one-row PNG panel: original, processed mammogram, breast mask and, optionally, mass mask.

    Every panel after the first is displayed with a fixed [0, 1] gray range.
    """
    with_mass = mass_mask is not None
    # NOTE(review): figsize is (width, height) in inches, so this is a very tall
    # canvas for a single row; bbox_inches='tight' trims it. Confirm intended.
    fig, axs = plt.subplots(1, 4 if with_mass else 3, figsize=(20, 80) if with_mass else (20, 60))

    unit_range = {'vmax': 1, 'vmin': 0}
    panels = [(original_arr, {}), (mammogram_arr, unit_range), (breast_mask, unit_range)]
    if with_mass:
        panels.append((mass_mask, unit_range))

    for ax, (panel, display_range) in zip(axs, panels):
        ax.imshow(panel, cmap='gray', **display_range)
        ax.axis('off')

    plt.savefig(file_path, bbox_inches='tight', pad_inches=0)
    plt.close("all")

def Read_nifti(img_path):
    """Read an image file and return ``(first slice along array axis 0, spacing tuple)``."""
    image = sitk.ReadImage(img_path)
    spacing = image.GetSpacing()
    first_slice = sitk.GetArrayFromImage(image)[0]
    return first_slice, spacing

def Save_nifti(img_arr, save_path, spacing=None):
    """Write a numpy array to ``save_path`` with compression, optionally setting its spacing.

    NOTE(review): callers pass the spacing returned by ``Read_nifti`` (of the
    original file) for a 2-D array; confirm the spacing length is accepted.
    """
    image = sitk.GetImageFromArray(img_arr)
    if spacing is None:
        pass  # keep SimpleITK's default spacing
    else:
        image.SetSpacing(spacing)
    sitk.WriteImage(image, save_path, True)

In [None]:
# Step 3 input: one folder of NIfTI mammograms per study (written in Step 1).
root = '/workspace/data/VinDr-Mammo'
img_root = os.path.join(root, "nifti")
pats = os.listdir(img_root)

In [None]:
crop_error = []
cropped_root = os.path.join(root, 'cropped_nifti')
check_root = os.path.join(root, 'cropped_check_all')

for pat in tqdm(pats):
    for img in os.listdir(os.path.join(img_root, pat)):
        img_arr, spacing = Read_nifti(os.path.join(img_root, pat, img))

        mammogram, breast_mask = crop(img_arr)
        mammogram = minmax_normalization(mammogram)

        # A breast mask that is 1 everywhere is treated as a cropping failure.
        if np.all(breast_mask == 1):
            print("ERROR in cropping,,,", pat, img)
            crop_error.append(pat + "_" + img)
            continue

        stem = img.split(".")[0]
        pat_folder = os.path.join(cropped_root, pat)
        os.makedirs(pat_folder, exist_ok=True)

        Save_nifti(mammogram, os.path.join(pat_folder, img), spacing=spacing)
        Save_nifti(breast_mask, os.path.join(pat_folder, stem + "-breask_mask.nii.gz"), spacing=spacing)

        os.makedirs(check_root, exist_ok=True)
        save_comparison_images(img_arr, mammogram, breast_mask, os.path.join(check_root, pat + "_" + stem + ".png"))


## Step 3-1: Get Mass Mask using MedSAM
The mass masks used in the paper are not publicly available, so this step generates them with MedSAM, using each mass bounding box from Step 2 as the prompt. The code is adapted from the [MedSAM repository](https://github.com/bowang-lab/MedSAM.git).

In [None]:
# %% environment and functions
import numpy as np
import matplotlib.pyplot as plt
import os
import torch
from segment_anything import sam_model_registry
from skimage import io, transform
import torch.nn.functional as F
import SimpleITK as sitk

@torch.no_grad()
def medsam_inference(medsam_model, img_embed, box_1024, H, W):
    """Segment the region inside a box prompt with MedSAM and return a binary mask.

    Args:
        medsam_model: model exposing ``prompt_encoder`` and ``mask_decoder``.
        img_embed: embeddings from ``medsam_model.image_encoder``; its device is
            used for the box tensor.
        box_1024: box prompt(s) in the 1024x1024 input frame. A 2-D array (one
            row per box) gets a singleton axis inserted at position 1.
        H, W: size the predicted mask is upsampled to.

    Returns:
        ``uint8`` numpy array, 1 where the upsampled probability exceeds 0.5.
        Singleton axes are squeezed, so a single box yields an ``(H, W)`` mask.
    """
    boxes = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
    if boxes.dim() == 2:
        boxes = boxes.unsqueeze(1)  # (B, 4) -> (B, 1, 4)

    sparse, dense = medsam_model.prompt_encoder(points=None, boxes=boxes, masks=None)
    logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed,
        image_pe=medsam_model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse,
        dense_prompt_embeddings=dense,
        multimask_output=False,
    )

    # Upsample probabilities (not logits) to the requested size, then threshold.
    probs = F.interpolate(torch.sigmoid(logits), size=(H, W), mode="bilinear", align_corners=False)
    probs = probs.squeeze().cpu().numpy()
    return (probs > 0.5).astype(np.uint8)

def pad_image_to_square(image_array):
    """Zero-pad the first two axes of ``image_array`` so they have equal length.

    The padding is split evenly before/after each axis; for an odd amount the
    extra pixel goes after. Any further axes (e.g. channels) are not padded.
    """
    height, width = image_array.shape[0], image_array.shape[1]
    side = max(height, width)

    def _split(gap):
        return gap // 2, gap - gap // 2

    pad_width = [_split(side - height), _split(side - width)]
    pad_width += [(0, 0)] * (image_array.ndim - 2)
    return np.pad(image_array, pad_width, mode='constant', constant_values=0)

from skimage.transform import resize

def resize_image(image_array, output_shape=(1024, 1024), interpolation='linear'):
    """Resize a 2-D or 3-D array with ``skimage.transform.resize`` (``mode='reflect'``).

    Args:
        image_array: 2-D image, or 3-D array such as an (H, W, 3) channel stack.
        output_shape: target shape; only the first two entries are used for 2-D
            input, and it must have three entries for 3-D input.
        interpolation: ``'nearest'`` -> order 0 without anti-aliasing,
            ``'linear'`` -> order 1 with anti-aliasing; any other value -> order 1
            without anti-aliasing.

    Raises:
        ValueError: if a 3-D image is given with a ``output_shape`` that is not 3-D.
            (Previously an ``assert``, which is stripped under ``python -O``.)
    """
    if image_array.ndim == 2:
        output_shape = output_shape[:2]
    elif image_array.ndim == 3 and len(output_shape) != 3:
        raise ValueError("Output shape must be 3-dimensional for 3-dimensional images.")

    # Choose the interpolation scheme.
    anti_aliasing = interpolation == 'linear'
    order = 0 if interpolation == 'nearest' else 1

    return resize(image_array, output_shape, mode='reflect', anti_aliasing=anti_aliasing, order=order)

In [None]:
#%% load the MedSAM (ViT-B) checkpoint onto the GPU in inference mode
MedSAM_CKPT_PATH = "medsam_vit_b.pth"
device = "cuda:0"
medsam_model = sam_model_registry['vit_b'](checkpoint=MedSAM_CKPT_PATH).to(device)
medsam_model.eval()

In [None]:
root = '/workspace/data/VinDr-Mammo'
img_root = os.path.join(root, "nifti")         # full-resolution mammograms (Step 1)
bbox_root = os.path.join(root, "bbox_nifti")   # per-finding box masks (Step 2)
pats = os.listdir(bbox_root)

In [None]:
empty_bbox = []   # "<study>_<bbox file>" whose box mask is empty
error_crop = []   # "<study>_<bbox file>" whose box disappears after cropping


def _mask_extent(mask):
    """Return ``(row_min, row_max, col_min, col_max)`` of pixels equal to 1 (max inclusive)."""
    rows, cols = np.where(mask == 1)
    return rows.min(), rows.max(), cols.min(), cols.max()


for pat in tqdm(pats):
    for img in os.listdir(os.path.join(bbox_root, pat)):
        # Only mass findings are segmented.
        if "Mass" not in img:
            continue
        # "<side>_<view>_bbox-..." -> "<side>_<view>.nii.gz" (the source mammogram).
        img_filename = "_".join(img.split("_")[:2]) + ".nii.gz"
        img_arr, spacing = Read_nifti(os.path.join(img_root, pat, img_filename))
        bbox_arr, _ = Read_nifti(os.path.join(bbox_root, pat, img))

        if bbox_arr.sum() == 0:
            print(pat, img, ",,,ERROR,,,empty bbox")
            empty_bbox.append(pat + "_" + img)
            continue

        # Crop to the breast, then stack three normalizations as the model's 3 channels.
        mammogram, breast_mask, mass_bbox = crop(img_arr, bbox_arr)
        minmax_normalized = minmax_normalization(mammogram)
        trunc_normalized = truncation_normalization(mammogram, breast_mask)
        stacked = np.stack([minmax_normalized, trunc_normalized, clahe(trunc_normalized, 0.01)], axis=2)

        padded_mammogram = pad_image_to_square(stacked)
        padded_mass_bbox = pad_image_to_square(mass_bbox)
        padded_breast_mask = pad_image_to_square(breast_mask)

        if padded_mass_bbox.sum() == 0:
            print(pat, img, ",,,ERROR,,,in cropping")
            error_crop.append(pat + "_" + img)
            continue

        # (1, 3, 1024, 1024) float tensor for the image encoder.
        resized = resize_image(padded_mammogram, output_shape=(1024, 1024, 3), interpolation='linear')
        image_tensor = torch.tensor(resized).float().permute(2, 0, 1).unsqueeze(0).to(device)

        # Box prompt as (col_min, row_min, col_max, row_max), rescaled to the 1024 frame.
        row_lo, row_hi, col_lo, col_hi = _mask_extent(padded_mass_bbox)
        H, W = padded_mass_bbox.shape
        box_1024 = np.array([[col_lo, row_lo, col_hi, row_hi]]) / np.array([W, H, W, H]) * 1024

        with torch.no_grad():
            image_embedding = medsam_model.image_encoder(image_tensor)
        medsam_seg = medsam_inference(medsam_model, image_embedding, box_1024, H, W)

        # Trim the square padding back off using the breast-mask extent.
        row_lo, row_hi, col_lo, col_hi = _mask_extent(padded_breast_mask)
        rows, cols = slice(row_lo, row_hi + 1), slice(col_lo, col_hi + 1)
        mammogram_arr = padded_mammogram[rows, cols, 0]
        mass_bbox_arr = padded_mass_bbox[rows, cols]
        breast_mask_arr = padded_breast_mask[rows, cols]
        mass_mask_arr = medsam_seg[rows, cols]

        pat_folder = os.path.join(root, 'cropped_nifti_mass', pat)
        bbox_folder = os.path.join(root, 'cropped_bbox')
        mask_folder = os.path.join(root, 'cropped_mass')
        for folder in (pat_folder, bbox_folder, mask_folder):
            os.makedirs(folder, exist_ok=True)

        out_name = pat + "_" + img
        Save_nifti(mammogram_arr, os.path.join(pat_folder, img_filename), spacing=spacing)
        Save_nifti(breast_mask_arr, os.path.join(pat_folder, img_filename.split(".")[0] + "-breask_mask.nii.gz"), spacing=spacing)
        Save_nifti(mass_bbox_arr, os.path.join(bbox_folder, out_name), spacing=spacing)
        Save_nifti(mass_mask_arr, os.path.join(mask_folder, out_name), spacing=spacing)

        check_folder = os.path.join(root, 'cropped_check_mass')
        os.makedirs(check_folder, exist_ok=True)
        save_comparison_images(img_arr, mammogram_arr, breast_mask_arr,
                               os.path.join(check_folder, pat + "_" + img.split(".")[0] + ".png"),
                               mass_mask_arr)

## Step 4: Extract Radiomics Features
The `stable-diffusion-2-inpainting` model takes a 512×512 input. To ensure the mass fits within that input, each image is downsampled by a factor of 3, i.e., resampled to 3× its original pixel spacing. Radiomics features are then extracted from the resampled image and mass mask.

In [None]:
# Step 4 dependencies and global configuration.
import radiomics
from radiomics import featureextractor, firstorder, glcm, imageoperations, shape, glszm
import SimpleITK as sitk
import os
import numpy as np
import pandas as pd
import json
import warnings
# Escalate every warning to an exception from this point on. Statement order
# matters: warnings raised while importing the modules below, and by any code run
# later in this session, become errors. Presumably this is meant to mark images
# whose feature extraction warns as failures — confirm.
warnings.filterwarnings("error")
from tqdm import tqdm
import matplotlib.pyplot as plt
# Only let pyradiomics log messages at ERROR level or above.
logger = radiomics.logging.getLogger("radiomics")
logger.setLevel(radiomics.logging.ERROR)

def Shape_Feature_Extract(ID, image, ROI):
    """Compute all 2-D shape features of ``ROI`` and return them as a one-row DataFrame.

    The first column is ``ID``; every feature column is prefixed with ``Shape_``.
    """
    # ``shape2D`` is a submodule of ``radiomics``. Import it explicitly instead of
    # relying on some other import having already loaded it as a package attribute.
    from radiomics import shape2D

    extractor = shape2D.RadiomicsShape2D(image, ROI, force2D=True)
    extractor.enableAllFeatures()
    extractor.execute()

    result = pd.DataFrame([extractor.featureValues])
    result.insert(loc=0, column='ID', value=ID)
    result.columns = ['ID'] + ['Shape_' + name for name in result.columns[1:]]

    return result

def Hist_Feature_Extract(ID, image, ROI):
    """Compute all first-order (histogram) features inside ``ROI`` with 128 bins.

    Returns a one-row DataFrame whose first column is ``ID`` and whose feature
    columns are prefixed with ``Hist_``.
    """
    extractor = radiomics.firstorder.RadiomicsFirstOrder(
        image, ROI, binCount=128, interpolator=None, verbose=True)
    extractor.enableAllFeatures()
    extractor.execute()

    features = pd.DataFrame([extractor.featureValues]).add_prefix('Hist_')
    features.insert(loc=0, column='ID', value=ID)
    return features

def GLCM_Feature_Extract(ID, image, ROI):
    """Compute all gray-level co-occurrence matrix (GLCM) features inside ``ROI`` with 128 bins.

    Returns a one-row DataFrame whose first column is ``ID`` and whose feature
    columns are prefixed with ``GLCM_``.
    """
    # Reached through the ``radiomics`` package rather than a bare ``glcm`` name, so a
    # notebook variable called ``glcm`` cannot shadow it.
    extractor = radiomics.glcm.RadiomicsGLCM(
        image, ROI, binCount=128, interpolator=None, verbose=True)
    extractor.enableAllFeatures()
    extractor.execute()

    features = pd.DataFrame([extractor.featureValues]).add_prefix('GLCM_')
    features.insert(loc=0, column='ID', value=ID)
    return features

def GLSZM_Feature_Extract(ID, image, ROI):
    """Compute all gray-level size-zone matrix (GLSZM) features inside ``ROI`` with 128 bins.

    Returns a one-row DataFrame whose first column is ``ID`` and whose feature
    columns are prefixed with ``GLSZM_``.
    """
    # Reached through the ``radiomics`` package rather than a bare ``glszm`` name, so a
    # notebook variable called ``glszm`` cannot shadow it.
    extractor = radiomics.glszm.RadiomicsGLSZM(
        image, ROI, binCount=128, interpolator=None, verbose=True)
    extractor.enableAllFeatures()
    extractor.execute()

    features = pd.DataFrame([extractor.featureValues]).add_prefix('GLSZM_')
    features.insert(loc=0, column='ID', value=ID)
    return features

def resize_image_and_roi(img_path, seg_path, breast_mask_path, resize_factor=3, crop_size=512):
    """Downsample an image, its mass mask and its breast mask by ``resize_factor``.

    The output grid has ``size / resize_factor`` pixels (truncated to int) and
    ``spacing * resize_factor`` spacing, computed from the image, with each
    input's own origin and direction. The image uses linear interpolation and
    both masks use nearest neighbour. ``crop_size`` is accepted for interface
    compatibility but unused.

    Returns:
        ``(image, mass_mask, breast_mask)`` as resampled SimpleITK images.
    """
    image = sitk.ReadImage(img_path)
    mass_mask = sitk.ReadImage(seg_path)
    breast_mask = sitk.ReadImage(breast_mask_path)

    size = image.GetSize()
    spacing = image.GetSpacing()
    target_size = [int(size[0] / resize_factor), int(size[1] / resize_factor)]
    target_spacing = [spacing[0] * resize_factor, spacing[1] * resize_factor]

    def _resample(source, interpolator):
        # Identity transform; pixels outside the source become 0; pixel type kept.
        return sitk.Resample(source, target_size, sitk.Transform(), interpolator,
                             source.GetOrigin(), target_spacing, source.GetDirection(), 0,
                             source.GetPixelID())

    return (
        _resample(image, sitk.sitkLinear),
        _resample(mass_mask, sitk.sitkNearestNeighbor),
        _resample(breast_mask, sitk.sitkNearestNeighbor),
    )


In [None]:
root = '/workspace/data/VinDr-Mammo'
img_root = os.path.join(root, "cropped_nifti_mass")  # cropped mammograms + breast masks (Step 3-1)
roi_root = os.path.join(root, "cropped_mass")        # MedSAM mass masks (Step 3-1)
save_root = os.path.join(root, "Mass")
imgs = os.listdir(roi_root)
os.makedirs(save_root, exist_ok=True)

In [None]:
shape_storage = dict()
hist_storage = dict()
glcm_storage = dict()
glszm_storage = dict()

# Images that failed ("ERROR") or produced NaN features ("Nan"), keyed by mask filename.
Except_pat = dict()
except_list_path = os.path.join(save_root, "VinDr-Mammo_ExceptList_Extract2DpatchRadiomicsFeatures.json")

for sub in ("image", "mask", "breast_mask", "image_withMask"):
    os.makedirs(os.path.join(save_root, sub), exist_ok=True)


def _features_to_floats(feature_df):
    """Return ``({feature_name: float}, has_nan)`` for a one-row ``*_Feature_Extract`` result."""
    values = {name: float(value) for name, value in feature_df.iloc[0, 1:].items()}
    return values, any(np.isnan(v) for v in values.values())


def _dump_except_list():
    """Write ``Except_pat`` to the exception-list JSON file."""
    with open(except_list_path, 'w') as file:
        json.dump(Except_pat, file, indent=4)


for img in tqdm(imgs):
    try:
        # Mask filenames start with "<study_id>_<side>_<view>_": recover the matching
        # cropped mammogram and breast mask.
        pat_id = img.split("_")[0]
        img_name = "_".join(img.split("_")[1:3])
        img_path = os.path.join(img_root, pat_id, img_name + ".nii.gz")
        breast_mask_path = os.path.join(img_root, pat_id, img_name + "-breask_mask.nii.gz")
        seg_path = os.path.join(roi_root, img)

        img_sitk, seg_sitk, breast_mask_sitk = resize_image_and_roi(img_path, seg_path, breast_mask_path, resize_factor=3)

        # Distinct names so the imported radiomics ``shape``/``glcm``/``glszm`` modules
        # are not shadowed by the results.
        shape_df = Shape_Feature_Extract(img, img_sitk, seg_sitk)
        hist_df = Hist_Feature_Extract(img, img_sitk, seg_sitk)
        glcm_df = GLCM_Feature_Extract(img, img_sitk, seg_sitk)
        glszm_df = GLSZM_Feature_Extract(img, img_sitk, seg_sitk)
    except Exception:
        # Warnings escalated to errors (Warning subclasses Exception) also land here.
        # A bare ``except:`` would additionally swallow KeyboardInterrupt, making the
        # loop impossible to stop from the notebook.
        print("### Fatal ERROR! ###", img)
        Except_pat[img] = "ERROR"
        continue

    shape_feats, shape_nan = _features_to_floats(shape_df)
    hist_feats, hist_nan = _features_to_floats(hist_df)
    glcm_feats, glcm_nan = _features_to_floats(glcm_df)
    glszm_feats, glszm_nan = _features_to_floats(glszm_df)

    if shape_nan or hist_nan or glcm_nan or glszm_nan:
        print("## ERROR in Nan! ###", img)
        Except_pat[img] = "Nan"
        continue

    shape_storage[img] = shape_feats
    hist_storage[img] = hist_feats
    glcm_storage[img] = glcm_feats
    glszm_storage[img] = glszm_feats

    sitk.WriteImage(img_sitk, os.path.join(save_root, "image", img))
    sitk.WriteImage(seg_sitk, os.path.join(save_root, "mask", img))
    sitk.WriteImage(breast_mask_sitk, os.path.join(save_root, "breast_mask", img))

    # 3-slice volume: [image, breast mask, mass mask].
    stacked_arr = np.stack([sitk.GetArrayFromImage(img_sitk),
                            sitk.GetArrayFromImage(breast_mask_sitk),
                            sitk.GetArrayFromImage(seg_sitk)], axis=0)
    stacked_sitk = sitk.GetImageFromArray(stacked_arr)
    stacked_sitk.SetSpacing([seg_sitk.GetSpacing()[0], seg_sitk.GetSpacing()[1], 1])
    sitk.WriteImage(stacked_sitk, os.path.join(save_root, "image_withMask", img))

    # Rewritten after every successful image, presumably as a checkpoint in case the
    # run is interrupted.
    _dump_except_list()

print(Except_pat)
_dump_except_list()

df_shape = pd.DataFrame(shape_storage).T
df_hist = pd.DataFrame(hist_storage).T
df_glcm = pd.DataFrame(glcm_storage).T
df_glszm = pd.DataFrame(glszm_storage).T

df_shape.to_csv(os.path.join(save_root, "shape.csv"))
df_hist.to_csv(os.path.join(save_root, "hist.csv"))
df_glcm.to_csv(os.path.join(save_root, "glcm.csv"))
df_glszm.to_csv(os.path.join(save_root, "glszm.csv"))


### Split training/validation/test set

In [None]:
root = '/workspace/data/VinDr-Mammo'
mass_root = os.path.join(root, "Mass")

findings = pd.read_csv(os.path.join(root, 'finding_annotations+finding_num.csv'))

# Breast density letter -> ordinal 1..4.
density_mapping = {f'DENSITY {letter}': rank for rank, letter in enumerate('ABCD', start=1)}

In [None]:
json_path = os.path.join(mass_root, 'Mass_dataset_split.json')
with open(json_path) as split_file:
    Mass_dataset = json.load(split_file)

# Number of studies in train / val / test.
print(*(len(Mass_dataset[key]) for key in ('trainset', 'valset', 'testset')))

In [None]:
# Column 0 of every CSV is the mask filename (the unnamed index written by the
# extraction loop). The four tables were filled in the same order, so they can be
# concatenated side by side; the ID column is kept only once.
feature_tables = [
    pd.read_csv(os.path.join(mass_root, f'{group}.csv'))
    for group in ('shape', 'hist', 'glcm', 'glszm')
]
radiomics_features = pd.concat(
    [feature_tables[0]] + [table.iloc[:, 1:] for table in feature_tables[1:]],
    axis=1,
).rename(columns={'Unnamed: 0': 'ID'})

In [None]:
# For every radiomics row, look up the image's breast density, parse the BI-RADS
# category from the filename, and record which split its study belongs to.
# IDs look like "<study_id>_<side>_<view>_bbox-<n>_<birads>_<categories>.nii.gz",
# where the last '-'-separated token of <birads> is the numeric category.
train_ids = set(Mass_dataset['trainset'])
val_ids = set(Mass_dataset['valset'])
test_ids = set(Mass_dataset['testset'])

densities = []
assessments = []
splits = []
for _, row in radiomics_features.iterrows():
    row_splits = row['ID'].split("_")
    pat_id, side, view = row_splits[0], row_splits[1], row_splits[2]

    image_rows = findings[(findings['study_id'] == pat_id)
                          & (findings['laterality'] == side)
                          & (findings['view_position'] == view)]
    densities.append(density_mapping[image_rows['breast_density'].values[0]])
    assessments.append(int(row_splits[4].split("-")[-1]))

    if pat_id in train_ids:
        splits.append('training')
    elif pat_id in val_ids:
        splits.append('validation')
    elif pat_id in test_ids:
        splits.append('test')
    else:
        # Without this, ``splits`` would fall out of step with the other lists and the
        # column assignment below would fail with an unrelated length-mismatch error.
        raise ValueError(f"study {pat_id} ({row['ID']}) is not listed in Mass_dataset_split.json")

radiomics_features["density"] = densities
radiomics_features["assessment"] = assessments
radiomics_features["split"] = splits

In [None]:
# density: 1-2 (DENSITY A/B) -> 0 (low density), 3-4 (DENSITY C/D) -> 1 (high density).
radiomics_features["density"] = (radiomics_features["density"] >= 3).astype("int64")

# assessment: 3 -> 1 (benign); every other value except 1 -> 2 (malignant).
# NOTE(review): a value of 1 stays 1 and a BI-RADS 2 would become "malignant";
# confirm that only BI-RADS 3/4/5 masses reach this point.
radiomics_features["assessment"] = radiomics_features["assessment"].isin([1, 3]).map({True: 1, False: 2})

In [None]:
# One frame per split; the trailing 'split' column is dropped.
trainset, valset, testset = (
    radiomics_features[radiomics_features['split'] == name].iloc[:, :-1]
    for name in ('training', 'validation', 'test')
)

In [None]:
# Notebook display: (rows, columns) of each split.
trainset.shape, valset.shape, testset.shape

In [None]:
for name, subset in (('trainset', trainset), ('valset', valset), ('testset', testset)):
    subset.to_csv(os.path.join(mass_root, f'{name}.csv'), index=False)

### Min-Max Normalization

In [None]:
from sklearn.preprocessing import MinMaxScaler
import json

In [None]:
# Mass radiomics tables written by the split step.
root = "/workspace/data/VinDr-Mammo"
mass_root = os.path.join(root, 'Mass')

In [None]:
trainset, valset, testset = (
    pd.read_csv(os.path.join(mass_root, f"{name}.csv"))
    for name in ("trainset", "valset", "testset")
)

In [None]:
# Features to scale: every column except the leading ID and the trailing
# density/assessment labels.
columns_to_scale = trainset.columns[1:-2]
scaler = MinMaxScaler()
print(len(columns_to_scale))

In [None]:
# Fit on the training split only, then apply the same scaling to val/test.
trainset[columns_to_scale] = scaler.fit_transform(trainset[columns_to_scale])
for held_out in (valset, testset):
    held_out[columns_to_scale] = scaler.transform(held_out[columns_to_scale])

In [None]:
# Map the binary assessment label to floats: 1 (benign) -> 0.5, 2 (malignant) -> 1.0.
for subset in (trainset, valset, testset):
    labels = subset['assessment'].astype(float)
    subset['assessment'] = labels.where(labels != 1, 0.5).where(labels != 2, 1.0)

In [None]:
for filename, subset in (('trainset_normalized.csv', trainset),
                         ('valset_normalized.csv', valset),
                         ('testset_normalized.csv', testset)):
    subset.to_csv(os.path.join(mass_root, filename), index=False)


In [None]:
# Per-feature [min, max] learned by the scaler on the training split (presumably
# kept so the normalization can be reapplied or inverted later).
min_max_dict = {
    col: [col_min, col_max]
    for col, col_min, col_max in zip(columns_to_scale, scaler.data_min_, scaler.data_max_)
}

In [None]:
with open(os.path.join(mass_root, "min_max_dict.json"), mode='w') as json_out:
    json.dump(min_max_dict, json_out, indent=4)

In [None]:
# One all-zero feature row per density group, used as the "normal" (no mass)
# condition. Column layout: ID, features..., density, assessment.
# density uses the binarization above (0 = low / DENSITY A-B, 1 = high / DENSITY C-D).
# The rows were previously swapped (Low_Density carried 1, High_Density carried 0).
# assessment is 0 (benign and malignant masses are 0.5 and 1.0).
# The zero block is sized from the actual feature columns instead of a hard-coded 67.
n_features = len(columns_to_scale)
Normal_data_list = [
    ["Low_Density"] + [0] * n_features + [0] + [0],
    ["High_Density"] + [0] * n_features + [1] + [0],
]

In [None]:
normal_rows = pd.DataFrame(Normal_data_list, columns=trainset.columns)
new_trainset, new_valset, new_testset = (
    pd.concat([subset, normal_rows], ignore_index=True)
    for subset in (trainset, valset, testset)
)

In [None]:
# NOTE(review): these are written to ``root`` while the other split CSVs go to
# ``mass_root`` — confirm that is intended.
for name, subset in (("trainset", new_trainset), ("valset", new_valset), ("testset", new_testset)):
    subset.to_csv(os.path.join(root, f"{name}_normalized_6cls.csv"), index=False)

## Step 5: Preprocessing for Normal cases

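Only images annotated as "No Finding" are used as normal cases. This excludes not only mass findings but also other findings, such as suspicious calcifications, which are not covered in this paper. Studies already used in the mass dataset are excluded as well. As with the mass cases, images are downsampled by a factor of 3 (resampled to 3× the original pixel spacing).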

In [None]:
import os
import pandas as pd
import json
from tqdm import tqdm
import random
import numpy as np
import SimpleITK as sitk

In [None]:
def resize_image_and_roi(img_path, breast_mask_path, resize_factor=3, crop_size=512):
    """Downsample an image and its breast mask by ``resize_factor``.

    The output grid has ``size / resize_factor`` pixels (truncated to int) and
    ``spacing * resize_factor`` spacing, computed from the image, with each
    input's own origin and direction. The image uses linear interpolation and the
    mask nearest neighbour. ``crop_size`` is accepted for interface compatibility
    but unused.

    Returns:
        ``(image, breast_mask)`` as resampled SimpleITK images.
    """
    image = sitk.ReadImage(img_path)
    breast_mask = sitk.ReadImage(breast_mask_path)

    size = image.GetSize()
    spacing = image.GetSpacing()
    target_size = [int(size[0] / resize_factor), int(size[1] / resize_factor)]
    target_spacing = [spacing[0] * resize_factor, spacing[1] * resize_factor]

    def _resample(source, interpolator):
        # Identity transform; pixels outside the source become 0; pixel type kept.
        return sitk.Resample(source, target_size, sitk.Transform(), interpolator,
                             source.GetOrigin(), target_spacing, source.GetDirection(), 0,
                             source.GetPixelID())

    return _resample(image, sitk.sitkLinear), _resample(breast_mask, sitk.sitkNearestNeighbor)

In [None]:
root = '/workspace/data/VinDr-Mammo'
normal_root = os.path.join(root, "Normal")
os.makedirs(normal_root, exist_ok=True)

# Normal candidates: annotation rows whose only category is "No Finding".
findings = pd.read_csv(os.path.join(root, 'finding_annotations.csv'))
normal = findings.loc[findings['finding_categories'] == "['No Finding']"]

json_path = os.path.join(root, 'Mass', 'Mass_dataset_split.json')
with open(json_path) as json_file:
    Mass_dataset = json.load(json_file)

In [None]:
# Studies already used in the mass dataset are excluded from the normal set.
mass_study_ids = set(Mass_dataset['trainset']) | set(Mass_dataset['valset']) | set(Mass_dataset['testset'])

training_ids = set()
test_ids = set()
for _, row in tqdm(normal.iterrows()):
    pat_id = row['study_id']
    if pat_id in mass_study_ids:
        continue
    if row['split'] == 'training':
        training_ids.add(pat_id)
    elif row['split'] == 'test':
        test_ids.add(pat_id)

training_list = list(training_ids)
test_list = list(test_ids)
len(training_list), len(test_list)

In [None]:
# Hold out 10% of the normal training studies for validation.
# ``training_list`` is built from a set, so its order depends on string hashing
# (randomized per interpreter run). Sampling it with the unseeded global RNG made
# the split irreproducible. Sort first and draw from a dedicated, seeded RNG.
NORMAL_VAL_SEED = 0  # arbitrary fixed seed; change it to draw a different split
candidates = sorted(set(training_list))
val_num = int(len(candidates) * 0.1)
val_list = random.Random(NORMAL_VAL_SEED).sample(candidates, val_num)
val_set = set(val_list)
training_list = [pat for pat in candidates if pat not in val_set]

In [None]:
# Notebook display: number of normal studies in train / test / validation.
len(training_list), len(test_list), len(val_list)

In [None]:
Normal_dataset = {
    'trainset': training_list,
    'valset': val_list,
    'testset': test_list,
}

In [None]:
# json_path = os.path.join(root, 'Normal', 'Normal_dataset_split.json')
# with open(json_path, 'w') as json_file:
#     json.dump(Normal_dataset, json_file, indent=4)

In [None]:
img_root = os.path.join(root, 'cropped_nifti')

pats = os.listdir(img_root)

normal_ids = set(Normal_dataset['trainset']) | set(Normal_dataset['valset']) | set(Normal_dataset['testset'])
for pat in tqdm(pats):
    if pat not in normal_ids:
        continue
    for img in os.listdir(os.path.join(img_root, pat)):
        # Skip breast-mask companions and anything that is not a NIfTI file.
        if "-breask_mask.nii.gz" in img or ".nii.gz" not in img:
            continue
        stem = img.split(".")[0]
        img_sitk, breast_mask_sitk = resize_image_and_roi(
            os.path.join(img_root, pat, img),
            os.path.join(img_root, pat, stem + "-breask_mask.nii.gz"),
        )

        out_name = pat + "_" + img
        for sub, resampled in (("image", img_sitk), ("breast_mask", breast_mask_sitk)):
            os.makedirs(os.path.join(normal_root, sub), exist_ok=True)
            sitk.WriteImage(resampled, os.path.join(normal_root, sub, out_name))

## Step 6: Register Opposite-Side Image

To give the inpainting model additional information about normal tissue, we use the opposite-side (contralateral) image, since normal breast tissue is largely symmetric between the left and right breasts. The opposite-side image is rigidly registered using ANTs (see the separate code).