-
-
Notifications
You must be signed in to change notification settings - Fork 54
/
extract_tile_pw_gtex.py
276 lines (237 loc) · 7.91 KB
/
extract_tile_pw_gtex.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
import argparse
import os
import time
from typing import List, Tuple
import numpy as np
import pandas as pd
import requests
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from histolab.slide import SlideSet
from histolab.tiler import RandomTiler
# Base endpoint of the NCI Biospecimen Research Database image-download
# service; a GTEx sample id is appended to fetch the corresponding WSI.
URL_ROOT = "https://brd.nci.nih.gov/brd/imagedownload"
def download_wsi_gtex(dataset_dir: str, sample_ids: List[str]) -> None:
    """Download into ``dataset_dir`` all the GTEx WSIs corresponding to ``sample_ids``

    Slides already present as ``<sample_id>.svs`` are skipped, so the function
    can be re-run to resume an interrupted download session.

    Parameters
    ----------
    dataset_dir : str
        Path where to save the WSIs
    sample_ids : List[str]
        List of GTEx WSI ids
    """
    # Snapshot the directory contents once; the original rescanned the
    # directory with os.listdir on every iteration of the loop.
    already_downloaded = set(os.listdir(dataset_dir))
    for sample_id in tqdm(sample_ids):
        if f"{sample_id}.svs" in already_downloaded:
            continue
        with requests.get(f"{URL_ROOT}/{sample_id}", stream=True) as request:
            request.raise_for_status()
            with open(
                os.path.join(dataset_dir, f"{sample_id}.svs"), "wb"
            ) as output_file:
                # Stream in 8 KiB chunks to keep memory bounded for multi-GB slides.
                for chunk in request.iter_content(chunk_size=8192):
                    output_file.write(chunk)
        # Randomized pause between downloads to avoid hammering the server.
        time.sleep(np.random.randint(60, 100))
def extract_random_tiles(
    dataset_dir: str,
    processed_path: str,
    tile_size: Tuple[int, int],
    n_tiles: int,
    level: int,
    seed: int,
    check_tissue: bool,
) -> None:
    """Extract random tiles from every WSI in ``dataset_dir``.

    Tiles are written under ``processed_path``/tiles, each one prefixed with
    the name of the slide it came from.

    Parameters
    ----------
    dataset_dir : str
        Path were the WSIs are saved
    processed_path : str
        Path where to store the tiles (will be concatenated with /tiles)
    tile_size : Tuple[int, int]
        width and height of the cropped tiles
    n_tiles : int
        Maximum number of tiles to extract
    level : int
        Magnification level from which extract the tiles
    seed : int
        Seed for RandomState
    check_tissue : bool
        Whether to check if the tile has enough tissue to be saved
    """
    # Only .svs files in dataset_dir are considered valid slides.
    slide_collection = SlideSet(dataset_dir, processed_path, valid_extensions=[".svs"])
    for current_slide in tqdm(slide_collection.slides):
        # Prefixing with the slide name keeps tiles traceable to their WSI.
        tiler = RandomTiler(
            tile_size=tile_size,
            n_tiles=n_tiles,
            level=level,
            seed=seed,
            check_tissue=check_tissue,
            prefix=f"{current_slide.name}_",
        )
        tiler.extract(current_slide)
def train_test_df_patient_wise(
    dataset_df: pd.DataFrame,
    patient_col: str,
    label_col: str,
    test_size: float = 0.2,
    seed: int = 1234,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Split ``dataset_df`` into train/test partitions following a patient-wise protocol.

    The split is computed over unique patients, so every row belonging to a
    given patient lands in exactly one partition — avoiding patient-level
    leakage between train and test.

    Parameters
    ----------
    dataset_df : pd.DataFrame
        DataFrame containing the data to split
    patient_col : str
        Name of the patient column in ``dataset_df``
    label_col : str
        Name of the target column in ``dataset_df``
    test_size : float, optional
        Ratio of test set samples over the entire dataset, by default 0.2
    seed : int, optional
        Seed for RandomState, by default 1234

    Returns
    -------
    pd.DataFrame
        Training dataset
    pd.DataFrame
        Test dataset
    """
    # The groupby index holds the unique patient ids (sorted by groupby);
    # the label aggregation itself is not used by the split.
    unique_patients = dataset_df.groupby(patient_col)[label_col].unique().index.values
    train_patients, test_patients = train_test_split(
        unique_patients, test_size=test_size, random_state=seed
    )
    in_train = dataset_df[patient_col].isin(train_patients)
    in_test = dataset_df[patient_col].isin(test_patients)
    return dataset_df.loc[in_train], dataset_df.loc[in_test]
def split_tiles_patient_wise(
    tiles_dir: str,
    metadata_df: pd.DataFrame,
    train_csv_path: str,
    test_csv_path: str,
    label_col: str,
    patient_col: str,
    test_size: float = 0.2,
    seed: int = 1234,
) -> None:
    """Split a tile dataset into train-test following a patient-wise partitioning protocol.

    Writes two CSV files (train and test) linking each tile filename to its
    patient metadata.

    Parameters
    ----------
    tiles_dir : str
        Tile dataset directory.
    metadata_df : pd.DataFrame
        CSV of patient metadata.
    train_csv_path : str
        Path where to save the CSV file for the training set.
    test_csv_path : str
        Path where to save the CSV file for the test set.
    label_col : str
        Name of the target column in ``dataset_df``
    patient_col : str
        Name of the patient column in ``dataset_df``
    test_size : float, optional
        Ratio of test set samples over the entire dataset, by default 0.2
    seed : int, optional
        Seed for RandomState, by default 1234
    """
    tile_names = [
        name
        for name in os.listdir(tiles_dir)
        if os.path.splitext(name)[1] == ".png"
    ]
    # Tile files are named "<Tissue Sample ID>_...", so the sample id is the
    # portion before the first underscore.
    tiles_filenames_df = pd.DataFrame(
        {
            "tile_filename": tile_names,
            "Tissue Sample ID": [name.split("_")[0] for name in tile_names],
        }
    )
    tiles_metadata = metadata_df.join(
        tiles_filenames_df.set_index("Tissue Sample ID"), on="Tissue Sample ID"
    )
    train_df, test_df = train_test_df_patient_wise(
        tiles_metadata, patient_col, label_col, test_size, seed,
    )
    train_df.to_csv(train_csv_path, index=None)
    test_df.to_csv(test_csv_path, index=None)
def _str_to_bool(value: str) -> bool:
    """Parse a boolean CLI value.

    ``type=bool`` is a classic argparse pitfall: ``bool("False")`` is True,
    so any non-empty string — including ``--check_tissue False`` — would be
    treated as True. This parser accepts common truthy/falsy spellings and
    rejects anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ("true", "1", "yes", "y"):
        return True
    if lowered in ("false", "0", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


def main():
    """Run the full pipeline: download WSIs, extract tiles, split patient-wise."""
    parser = argparse.ArgumentParser(
        description="Retrieve a leakage-free dataset of tiles using a collection of WSI."
    )
    parser.add_argument(
        "--metadata_csv",
        type=str,
        default="examples/GTEx/GTEx_AIDP2021.csv",
        help="CSV with WSI metadata. Default examples/GTEx/GTEx_AIDP2021.csv.",
    )
    parser.add_argument(
        "--wsi_dataset_dir",
        type=str,
        default="WSI_GTEx",
        help="Path where to save the WSIs. Default WSI_GTEx.",
    )
    parser.add_argument(
        "--tile_dataset_dir",
        type=str,
        default="tiles_GTEx",
        # Original help text said "WSIs" — copy-paste error; this dir holds tiles.
        help="Path where to save the tiles. Default tiles_GTEx.",
    )
    parser.add_argument(
        "--tile_size",
        type=int,
        nargs=2,
        default=(512, 512),
        help="width and height of the cropped tiles. Default (512, 512).",
    )
    parser.add_argument(
        "--n_tiles",
        type=int,
        default=100,
        help="Maximum number of tiles to extract. Default 100.",
    )
    parser.add_argument(
        "--level",
        type=int,
        default=2,
        help="Magnification level from which extract the tiles. Default 2.",
    )
    parser.add_argument(
        "--seed", type=int, default=7, help="Seed for RandomState. Default 7.",
    )
    parser.add_argument(
        "--check_tissue",
        # was type=bool: any non-empty string (even "False") parsed as True
        type=_str_to_bool,
        default=True,
        help="Whether to check if the tile has enough tissue to be saved. Default True.",
    )
    args = parser.parse_args()

    metadata_csv = args.metadata_csv
    wsi_dataset_dir = args.wsi_dataset_dir
    tile_dataset_dir = args.tile_dataset_dir
    # argparse nargs=2 yields a list; downstream expects a (width, height) tuple.
    tile_size = tuple(args.tile_size)
    n_tiles = args.n_tiles
    level = args.level
    seed = args.seed
    check_tissue = args.check_tissue

    gtex_df = pd.read_csv(metadata_csv)
    # exist_ok=True: download_wsi_gtex skips already-downloaded slides, so the
    # script is meant to be re-runnable; plain makedirs crashed on a second run.
    os.makedirs(wsi_dataset_dir, exist_ok=True)
    sample_ids = gtex_df["Tissue Sample ID"].tolist()
    download_wsi_gtex(wsi_dataset_dir, sample_ids)
    extract_random_tiles(
        wsi_dataset_dir, tile_dataset_dir, tile_size, n_tiles, level, seed, check_tissue
    )
    split_tiles_patient_wise(
        tiles_dir=os.path.join(tile_dataset_dir, "tiles"),
        metadata_df=gtex_df,
        train_csv_path=os.path.join(
            tile_dataset_dir, f"train_tiles_PW_{os.path.basename(metadata_csv)}"
        ),
        test_csv_path=os.path.join(
            tile_dataset_dir, f"test_tiles_PW_{os.path.basename(metadata_csv)}"
        ),
        label_col="Tissue",
        patient_col="Subject ID",
        test_size=0.2,
        seed=1234,
    )
# Run the download -> tile-extraction -> patient-wise split pipeline when
# executed as a script (not when imported as a module).
if __name__ == "__main__":
    main()