# Create EMBED splits
This notebook creates the train/val/test split CSV files used throughout this project.

In [None]:
import pandas as pd
from pathlib import Path
import sys

# Absolute path to the repository root. It is appended to sys.path before the
# project-local imports below (data_handling, default_paths) — presumably so
# that they resolve from this checkout. The same variable is reused later in
# the notebook to locate data_handling/embed_full.csv.
path_to_root = "/vol/biomedic3/mb121/shift_identification/"
sys.path.append(path_to_root)

from data_handling.mammo import domain_maps, modelname_map, tissue_maps
from default_paths import EMBED_ROOT
from sklearn.model_selection import train_test_split
import numpy as np

## Create main EMBED csv 
These cells merge the original metadata and clinical CSV files, remove invalid views, convert density to a numerical scale, etc.

In [None]:
# Extra per-image DICOM fields taken from the full metadata table.
# Only these columns are parsed (usecols), so the remaining columns of the
# full table are never loaded into memory.
full_dicom_columns = [
    "InstanceNumber",
    "anon_dicom_path",
    "PixelSpacing",
    "ImagerPixelSpacing",
    "Rows",
    "Columns",
]
full_dicom = pd.read_csv(
    EMBED_ROOT / "tables/EMBED_OpenData_metadata.csv",
    usecols=full_dicom_columns,
    low_memory=False,
)[full_dicom_columns]  # usecols ignores list order; re-select to pin the column order


dicom = pd.read_csv(
    EMBED_ROOT / "tables/EMBED_OpenData_metadata_reduced.csv", low_memory=False
)
print(len(dicom))
# Inner join on the DICOM path. Comparing the two printed lengths shows
# whether the join dropped or duplicated any rows.
dicom = dicom.merge(full_dicom, on="anon_dicom_path")
print(len(dicom))

# Relative PNG path of each image: "<empi_anon>/<DICOM file name without .dcm>.png"
dicom_file_stem = (
    dicom["anon_dicom_path"].str.split("/").str[-1].str.split(".dcm").str[0]
)
dicom["image_path"] = (
    dicom["empi_anon"].astype("str") + "/" + dicom_file_stem + ".png"
)

In [None]:
# XCCL shouldn't be converted to CC so manually editing it
is_xccl = dicom["SeriesDescription"].isin(["RXCCL", "LXCCL"])
dicom.loc[is_xccl, "ViewPosition"] = "XCCL"

# Rows with a missing "ViewPosition" but a known "SeriesDescription": these are
# the ones subject to the data entry error, and the view can be recovered from
# the series description.
missing_view = dicom["ViewPosition"].isna() & dicom["SeriesDescription"].notna()

# Explicit copy: the frame is modified below, and writing into a slice of
# `dicom` would raise a SettingWithCopyWarning.
view_nan = dicom.loc[missing_view].copy()

# Rows that do not need fixing.
dicom_no_nans = dicom.loc[~missing_view]

# Infer the view from the series description; anything that is neither CC nor
# MLO gets None and is removed by the "invalid views" filter below.
view_nan["ViewPosition"] = view_nan["SeriesDescription"].apply(
    lambda x: "CC" if "CC" in x else ("MLO" if "MLO" in x else None)
)

# Re-assemble: untouched rows first, repaired rows appended at the end.
dicom = pd.concat([dicom_no_nans, view_nan], axis=0, ignore_index=True)

print(len(dicom))
# Remove any duplicated images
dicom = dicom.drop_duplicates(subset="anon_dicom_path")
# Remove spot compressed and magnified images
dicom = dicom[dicom.spot_mag.isna()]
# Remove invalid views
dicom = dicom[dicom.ViewPosition.isin(["CC", "MLO"])]
# Remove images from male clients
dicom = dicom[dicom.PatientSex == "F"]
print(len(dicom))

In [None]:
# Columns of the image-wise DICOM table that the rest of the notebook relies on.
# Extend this list if further DICOM fields become relevant.
dicom_columns_to_keep = [
    "empi_anon",
    "acc_anon",
    "image_path",
    "FinalImageType",
    "ImageLateralityFinal",
    "ViewPosition",
    "Manufacturer",
    "ManufacturerModelName",
]
dicom = dicom[dicom_columns_to_keep]

In [None]:
# Conversion dictionary to standardised naming of various fields in clinical metadata

# Human reader BIRADS density assessment
dens_conversion = {1.0: "A", 2.0: "B", 3.0: "C", 4.0: "D"}

# Load in the clinical metadata
mag = pd.read_csv(EMBED_ROOT / "tables/EMBED_OpenData_clinical.csv", low_memory=False)
print(len(mag))

# Keep only cases with a valid BIRADS density assessment and convert the
# numeric code to its letter. The valid codes are taken from the conversion
# dictionary so the two cannot drift apart. `replace` returns a new frame, so
# nothing is written into a slice of the original table.
has_valid_density = mag["tissueden"].isin(list(dens_conversion))
mag = mag.loc[has_valid_density].replace({"tissueden": dens_conversion})


# Keep important study metadata tags to join up with final aggregated dataframe
# at end of script (one row per study / accession number)
mag = mag[["empi_anon", "tissueden", "study_date_anon", "acc_anon"]].drop_duplicates(
    subset="acc_anon"
)
print(len(mag))

# Convert to pandas datetime object; unparseable dates become NaT
mag = mag.assign(
    study_date_anon=pd.to_datetime(mag["study_date_anon"], errors="coerce")
)

In [None]:
# Number of retained images per scanner manufacturer.
dicom["Manufacturer"].value_counts()

In [None]:
# Join study-level clinical metadata with the image-level DICOM table.
# The inner join keeps only studies that have a valid link between the two
# (same accession number and same patient id); compare the printed counts.
n_images_before_join = len(dicom)
print(n_images_before_join)
df = mag.merge(dicom, how="inner", on=["acc_anon", "empi_anon"])
n_images_after_join = len(df)
print(n_images_after_join)

In [None]:
# Reuse the repository root defined at the top of the notebook. The splits
# section reads this CSV back via `path_to_root`, so the write and the read
# must point to the same location; keeping one definition prevents them from
# drifting apart.
path_to_repo_root = path_to_root
df.to_csv(Path(path_to_repo_root) / "data_handling" / "embed_full.csv", index=False)

## Create the splits

In [None]:
image_dir = EMBED_ROOT / Path("images/png/1024x768")

try:
    df = pd.read_csv(Path(path_to_root) / "data_handling" / "embed_full.csv")
except FileNotFoundError as err:
    # Fail here with the guidance message. Merely printing it would let the
    # cell continue with `df` undefined (or left over from an earlier cell).
    raise FileNotFoundError(
        """
        For running EMBED code you need to first generate the csv
        file used for this study by running the cells above
        """
    ) from err

# Keep the relative path, then turn image_path into a full path under image_dir.
df["shortimgpath"] = df["image_path"]
df["image_path"] = df["image_path"].apply(lambda x: image_dir / str(x))

# Manufacturer -> domain label, via the project-level mapping.
df["manufacturer_domain"] = df.Manufacturer.apply(lambda x: domain_maps[x])

# convert tissueden to trainable label
df["tissueden"] = df.tissueden.apply(lambda x: tissue_maps[x])

# Scanner model name -> simplified model label, via the project-level mapping.
df["SimpleModelLabel"] = df.ManufacturerModelName.apply(lambda x: modelname_map[x])
print(df.SimpleModelLabel.value_counts())

# Binary view label: MLO -> 0, anything else (CC after the earlier filtering) -> 1.
df["ViewLabel"] = df.ViewPosition.apply(lambda x: 0 if x == "MLO" else 1)

# Drop rows for which any of the labels or the image path is missing.
df = df.dropna(
    subset=[
        "tissueden",
        "SimpleModelLabel",
        "ViewLabel",
        "image_path",
    ]
)

# Relative frequency of each density label.
df["tissueden"].value_counts(normalize=True)

In [None]:
# Only 2D images are used for the splits.
df = df.loc[df.FinalImageType == "2D"]

# One density label per patient (the first one encountered), used to stratify
# the patient-level split.
patient_density = (
    df.groupby("empi_anon")["tissueden"].unique().apply(lambda x: x[0])
)
y = patient_density.values
print(np.bincount(y) / np.bincount(y).sum())

# Patient ids and stratification labels MUST come from the same object so that
# position i of both arrays refers to the same patient. `groupby` sorts its
# keys whereas `Series.unique()` keeps the order of appearance, so pairing
# `df.empi_anon.unique()` with `y` only lines up if df happens to be sorted by
# patient id.
# NOTE: this changes which patients land in each split compared with CSVs
# generated by pairing `df.empi_anon.unique()` with `y`.
train_id, val_id = train_test_split(
    patient_density.index.to_numpy(), test_size=0.4, random_state=33, stratify=y
)


# The held-out 40% of patients form the pool for validation and test.
val_test_df = df.loc[df["empi_anon"].isin(val_id)]
# Keep only one study by patient (the first one encountered)
studies = (
    val_test_df.groupby("empi_anon")["acc_anon"].unique().apply(lambda x: x[0]).values
)
# For testing, filter out all studies that do not have exactly the expected
# 4 images (L/R, MLO/CC). These are the studies with failed images or images
# with unexpected content. This makes sure that the distributions of val and
# un-shifted test are the same; otherwise it might bias the results.
images_per_study = df.groupby("acc_anon")["shortimgpath"].nunique(dropna=False)
weird = images_per_study.index[images_per_study != 4]

val_test_df = val_test_df.loc[val_test_df["acc_anon"].isin(studies)]
# Explicit copy so that later cells can add columns without a SettingWithCopyWarning.
val_test_df = val_test_df.loc[~val_test_df["acc_anon"].isin(weird)].copy()

# Scanner model vs. density label counts in the val/test pool.
pd.crosstab(val_test_df["SimpleModelLabel"], val_test_df["tissueden"])

In [None]:
# Single stratification variable combining scanner model and density label.
# NOTE(review): the factor 10 only gives a unique code per (model, density)
# pair if SimpleModelLabel is always below 10 — confirm against modelname_map.
# `assign` builds a new frame instead of writing into a slice of `df`, which
# avoids a SettingWithCopyWarning.
val_test_df = val_test_df.assign(
    combined_var=val_test_df["SimpleModelLabel"] + 10 * val_test_df["tissueden"]
)
val_test_df["combined_var"].value_counts()

In [None]:
# One stratification code per study: the first combined_var value of the study.
study_strata = (
    val_test_df.groupby("acc_anon")["combined_var"]
    .unique()
    .map(lambda values: values[0])
)
ids = study_strata.index
y = study_strata.values

# test_size=1200 sets the size of the SECOND returned set, which is bound to
# val_id here: validation gets 1200 studies, test gets the remainder.
test_id, val_id = train_test_split(ids, test_size=1200, random_state=33, stratify=y)

n_train, n_val, n_test = train_id.shape[0], val_id.shape[0], test_id.shape[0]
print(f"N patients train: {n_train}, val: {n_val}, test {n_test}")

In [None]:
# Training set: every image of the training patients.
train_df = df.loc[df.empi_anon.isin(train_id)]
# Validation / test sets: selected by study id (acc_anon) from the val/test pool.
val_df = val_test_df.loc[val_test_df.acc_anon.isin(val_id)]
# Positional index of each row within the test set. `assign` returns a new
# frame, so no column is written into a slice of val_test_df
# (avoids a SettingWithCopyWarning).
test_df = val_test_df.loc[val_test_df.acc_anon.isin(test_id)].assign(
    idx_in_original_test=lambda frame: np.arange(len(frame))
)

In [None]:
# Density label distribution per scanner model in the test split (each row sums to 1).
pd.crosstab(
    index=test_df["SimpleModelLabel"],
    columns=test_df["tissueden"],
    normalize="index",
)

In [None]:
# Density label distribution per scanner model in the validation split (each row sums to 1).
pd.crosstab(
    index=val_df["SimpleModelLabel"],
    columns=val_df["tissueden"],
    normalize="index",
)

In [47]:
# Write the training split. The path is built from `path_to_root` instead of
# repeating the absolute path, so the output follows the configured repository
# root. The dataframe index is written as the first column.
train_df.to_csv(Path(path_to_root) / "experiments" / "train_embed.csv")

In [51]:
# Write the validation split. The path is built from `path_to_root` instead of
# repeating the absolute path, so the output follows the configured repository
# root. The dataframe index is written as the first column.
val_df.to_csv(Path(path_to_root) / "experiments" / "val_embed.csv")

In [50]:
# Write the test split. The path is built from `path_to_root` instead of
# repeating the absolute path, so the output follows the configured repository
# root. The dataframe index is written as the first column.
test_df.to_csv(Path(path_to_root) / "experiments" / "test_embed.csv")