Commit 5b4e887
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jun 1, 2022
1 parent 1d7dec9 commit 5b4e887
Showing 2 changed files with 78 additions and 39 deletions.
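The changes below are all line-length rewraps of the kind an automatic formatter applies when a statement exceeds the configured maximum width. The commit page does not say which hook ran; as one plausible example, autopep8 produces exactly this style of wrapping. A minimal sketch, assuming the autopep8 package (the hook choice and options are assumptions, not confirmed by this commit):

    # Hypothetical reproduction of the kind of fix in this commit.
    import autopep8

    source = (
        "agg_top_slides_heaps = self._aggregate_shallow_slides_heaps"
        "(world_size, shallow_top_slides_heaps)\n"
    )
    # fix_code rewraps statements longer than max_line_length;
    # aggressive=1 enables the non-whitespace fixes that do the wrapping.
    print(autopep8.fix_code(source, options={"aggressive": 1, "max_line_length": 79}))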
First changed file:
@@ -313,11 +313,15 @@ def gather_top_bottom_tiles_for_top_bottom_slides(self) -> None:
  world_size = torch.distributed.get_world_size()
  if world_size > 1:

- shallow_top_slides_heaps = self._shallow_copy_slides_heaps(slides_heaps=self.top_slides_heaps)
- shallow_bottom_slides_heaps = self._shallow_copy_slides_heaps(slides_heaps=self.bottom_slides_heaps)
+ shallow_top_slides_heaps = self._shallow_copy_slides_heaps(
+     slides_heaps=self.top_slides_heaps)
+ shallow_bottom_slides_heaps = self._shallow_copy_slides_heaps(
+     slides_heaps=self.bottom_slides_heaps)

- agg_top_slides_heaps = self._aggregate_shallow_slides_heaps(world_size, shallow_top_slides_heaps)
- agg_bottom_slides_heaps = self._aggregate_shallow_slides_heaps(world_size, shallow_bottom_slides_heaps)
+ agg_top_slides_heaps = self._aggregate_shallow_slides_heaps(
+     world_size, shallow_top_slides_heaps)
+ agg_bottom_slides_heaps = self._aggregate_shallow_slides_heaps(
+     world_size, shallow_bottom_slides_heaps)

  top_slides_top_tiles, top_slides_bottom_tiles = self._collect_tiles_for_selected_slides_on_device(
      new_slides_heaps=agg_top_slides_heaps, slides_heaps=self.top_slides_heaps
@@ -330,12 +334,17 @@ def gather_top_bottom_tiles_for_top_bottom_slides(self) -> None:
      new_top_slides_heaps=agg_top_slides_heaps, new_bottom_slides_heaps=agg_bottom_slides_heaps
  )

- top_tiles: TileDict = self._gather_dictionaries(world_size, top_slides_top_tiles) # type: ignore
- bottom_tiles: TileDict = self._gather_dictionaries(world_size, top_slides_bottom_tiles) # type: ignore
- self._update_shallow_slides_heaps_with_top_bottom_tiles(self.top_slides_heaps, top_tiles, bottom_tiles)
+ top_tiles: TileDict = self._gather_dictionaries(
+     world_size, top_slides_top_tiles) # type: ignore
+ bottom_tiles: TileDict = self._gather_dictionaries(
+     world_size, top_slides_bottom_tiles) # type: ignore
+ self._update_shallow_slides_heaps_with_top_bottom_tiles(
+     self.top_slides_heaps, top_tiles, bottom_tiles)

- top_tiles: TileDict = self._gather_dictionaries(world_size, bot_slides_top_tiles) # type: ignore
- bottom_tiles: TileDict = self._gather_dictionaries(world_size, bot_slides_bottom_tiles) # type: ignore
+ top_tiles: TileDict = self._gather_dictionaries(
+     world_size, bot_slides_top_tiles) # type: ignore
+ bottom_tiles: TileDict = self._gather_dictionaries(
+     world_size, bot_slides_bottom_tiles) # type: ignore
  self._update_shallow_slides_heaps_with_top_bottom_tiles(
      self.bottom_slides_heaps, top_tiles, bottom_tiles
  )
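Aside: _gather_dictionaries above merges per-rank tile dictionaries, but its body is outside this diff. A generic sketch of that pattern using torch.distributed.all_gather_object; the function name and the single-process gloo setup are illustrative assumptions, not the repo's code:

    import os
    import torch.distributed as dist

    def gather_dictionaries(world_size: int, local_dict: dict) -> dict:
        # Collect every rank's dictionary, then merge them key by key.
        gathered: list = [None] * world_size
        dist.all_gather_object(gathered, local_dict)
        merged: dict = {}
        for d in gathered:
            merged.update(d)
        return merged

    if __name__ == "__main__":
        # Single-process group purely so the sketch runs standalone.
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group("gloo", rank=0, world_size=1)
        print(gather_dictionaries(1, {"slide_0": [0.99]}))  # {'slide_0': [0.99]}
        dist.destroy_process_group()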
Second changed file:
@@ -56,7 +56,8 @@ def _batch_data(data: Dict, batch_idx: int, batch_size: int) -> Generator:
  """Helper function to generate smaller batches from a dictionary."""
  batch = {}
  for k in data:
- batch[k] = data[k][batch_idx * batch_size: (batch_idx + 1) * batch_size]
+ batch[k] = data[k][batch_idx *
+     batch_size: (batch_idx + 1) * batch_size]
  yield batch

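The helper above slices every value in the dictionary to the same index window. A self-contained usage sketch with made-up data:

    from typing import Dict, Generator

    def _batch_data(data: Dict, batch_idx: int, batch_size: int) -> Generator:
        """Helper function to generate smaller batches from a dictionary."""
        batch = {}
        for k in data:
            batch[k] = data[k][batch_idx * batch_size: (batch_idx + 1) * batch_size]
        yield batch

    samples = {"slide_id": list(range(10)), "prob": [i / 10 for i in range(10)]}
    print(next(_batch_data(samples, batch_idx=1, batch_size=4)))
    # -> {'slide_id': [4, 5, 6, 7], 'prob': [0.4, 0.5, 0.6, 0.7]}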
@@ -84,11 +85,14 @@ def _create_and_update_top_bottom_tiles_handler(
  :return: A top bottom tiles handler with selected top and bottom slides and corresponding top and bottom tiles.
  """

- handler = TopBottomTilesHandler(n_classes, n_top_slides=n_top_slides, n_top_tiles=n_top_tiles)
+ handler = TopBottomTilesHandler(
+     n_classes, n_top_slides=n_top_slides, n_top_tiles=n_top_tiles)

  for i in range(rank * n_batches, (rank + 1) * n_batches):
- batch_data = next(_batch_data(data, batch_idx=i, batch_size=batch_size))
- batch_results = next(_batch_data(results, batch_idx=i, batch_size=batch_size))
+ batch_data = next(_batch_data(
+     data, batch_idx=i, batch_size=batch_size))
+ batch_results = next(_batch_data(
+     results, batch_idx=i, batch_size=batch_size))
  handler.update_slides_selection(batch_data, batch_results)

  return handler
@@ -106,12 +110,14 @@ def _get_expected_slides_by_probability(
  :return: A list of selected slide ids.
  """

- class_indices = (results[ResultsKey.TRUE_LABEL].squeeze() == label).nonzero().squeeze(1)
+ class_indices = (
+     results[ResultsKey.TRUE_LABEL].squeeze() == label).nonzero().squeeze(1)
  class_prob = results[ResultsKey.CLASS_PROBS][class_indices, label]
  assert class_prob.shape == (len(class_indices),)
  n_top_slides = min(n_top_slides, len(class_prob))

- _, sorting_indices = class_prob.topk(n_top_slides, largest=top, sorted=True)
+ _, sorting_indices = class_prob.topk(
+     n_top_slides, largest=top, sorted=True)
  sorted_class_indices = class_indices[sorting_indices]

  # the order is reversed in the heaps
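To make the selection above concrete, here is the same nonzero/topk chain on small made-up tensors standing in for results[ResultsKey.TRUE_LABEL] and results[ResultsKey.CLASS_PROBS]:

    import torch

    true_labels = torch.tensor([0, 1, 1, 0, 1])          # per-slide true labels
    class_probs = torch.tensor([[0.90, 0.10],
                                [0.20, 0.80],
                                [0.40, 0.60],
                                [0.70, 0.30],
                                [0.05, 0.95]])           # per-slide class probabilities
    label, n_top_slides, top = 1, 2, True

    class_indices = (true_labels.squeeze() == label).nonzero().squeeze(1)  # tensor([1, 2, 4])
    class_prob = class_probs[class_indices, label]                         # tensor([0.80, 0.60, 0.95])
    _, sorting_indices = class_prob.topk(min(n_top_slides, len(class_prob)),
                                         largest=top, sorted=True)
    print(class_indices[sorting_indices])                                  # tensor([4, 1])

The ids come out most-confident first here; the comment above notes that the handler's heaps store them in the opposite order.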
@@ -145,7 +151,8 @@ def test_gather_shallow_slide_nodes(n_classes: int, rank: int = 0, world_size: i
  n_top_slides = 2

  torch.manual_seed(42)
- data = _create_mock_data(n_samples=batch_size * total_batches, n_tiles=n_tiles, device=device)
+ data = _create_mock_data(n_samples=batch_size *
+     total_batches, n_tiles=n_tiles, device=device)
  results = _create_mock_results(
      n_samples=batch_size * total_batches, n_tiles=n_tiles, n_classes=n_classes, device=device
  )
@@ -159,17 +166,21 @@ def test_gather_shallow_slide_nodes(n_classes: int, rank: int = 0, world_size: i

  if torch.distributed.is_initialized():
  if world_size > 1:
- shallow_top_slides_heaps = handler._aggregate_shallow_slides_heaps(world_size, shallow_top_slides_heaps)
+ shallow_top_slides_heaps = handler._aggregate_shallow_slides_heaps(
+     world_size, shallow_top_slides_heaps)
  shallow_bottom_slides_heaps = handler._aggregate_shallow_slides_heaps(
      world_size, shallow_bottom_slides_heaps
  )

  if rank == 0:
  for label in range(n_classes):
- expected_top_slides_ids = get_expected_top_slides_by_probability(results, n_top_slides, label)
- assert expected_top_slides_ids == [slide_node.slide_id for slide_node in shallow_top_slides_heaps[label]]
+ expected_top_slides_ids = get_expected_top_slides_by_probability(
+     results, n_top_slides, label)
+ assert expected_top_slides_ids == [
+     slide_node.slide_id for slide_node in shallow_top_slides_heaps[label]]

- expected_bottom_slides_ids = get_expected_bottom_slides_by_probability(results, n_top_slides, label)
+ expected_bottom_slides_ids = get_expected_bottom_slides_by_probability(
+     results, n_top_slides, label)
  assert expected_bottom_slides_ids == [
      slide_node.slide_id for slide_node in shallow_bottom_slides_heaps[label]
  ]
@@ -211,8 +222,10 @@ def assert_equal_top_bottom_attention_tiles(
      results[ResultsKey.BAG_ATTN][slide_batch_idx].squeeze(), k=n_top_tiles, largest=False, sorted=True
  )

- expected_top_tiles: List[torch.Tensor] = [tiles[tile_id] for tile_id in top_tiles_ids]
- expected_bottom_tiles: List[torch.Tensor] = [tiles[tile_id] for tile_id in bottom_tiles_ids]
+ expected_top_tiles: List[torch.Tensor] = [
+     tiles[tile_id] for tile_id in top_tiles_ids]
+ expected_bottom_tiles: List[torch.Tensor] = [
+     tiles[tile_id] for tile_id in bottom_tiles_ids]

  top_tiles = slide_nodes[i].top_tiles
  bottom_tiles = slide_nodes[i].bottom_tiles
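The expected bottom tiles above come from topk with largest=False; a tiny demonstration on made-up attention scores:

    import torch

    attn = torch.tensor([0.7, 0.1, 0.9, 0.3])
    # largest=False selects the k smallest values, ascending when sorted=True.
    values, ids = attn.topk(k=2, largest=False, sorted=True)
    print(values, ids)  # tensor([0.1000, 0.3000]) tensor([1, 3])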
@@ -222,7 +235,8 @@ def assert_equal_top_bottom_attention_tiles(
  assert expected_top_attns[j].item() == top_tiles[j].attn

  for j, expected_bottom_tile in enumerate(expected_bottom_tiles):
- assert torch.equal(expected_bottom_tile.cpu(), bottom_tiles[j].data)
+ assert torch.equal(expected_bottom_tile.cpu(),
+     bottom_tiles[j].data)
  assert expected_bottom_attns[j].item() == bottom_tiles[j].attn

@@ -251,7 +265,8 @@ def test_select_k_top_bottom_tiles_on_the_fly(
  n_top_slides = 2

  torch.manual_seed(42)
- data = _create_mock_data(n_samples=batch_size * total_batches, n_tiles=n_tiles, device=device)
+ data = _create_mock_data(n_samples=batch_size *
+     total_batches, n_tiles=n_tiles, device=device)
  results = _create_mock_results(
      n_samples=batch_size * total_batches, n_tiles=n_tiles, n_classes=n_classes, device=device
  )
@@ -263,18 +278,23 @@ def test_select_k_top_bottom_tiles_on_the_fly(

  if rank == 0:
  for label in range(n_classes):
- expected_top_slides_ids = get_expected_top_slides_by_probability(results, n_top_slides, label)
- assert expected_top_slides_ids == [slide_node.slide_id for slide_node in handler.top_slides_heaps[label]]
+ expected_top_slides_ids = get_expected_top_slides_by_probability(
+     results, n_top_slides, label)
+ assert expected_top_slides_ids == [
+     slide_node.slide_id for slide_node in handler.top_slides_heaps[label]]
  assert_equal_top_bottom_attention_tiles(
- expected_top_slides_ids, data, results, n_top_tiles, handler.top_slides_heaps[label]
+ expected_top_slides_ids, data, results, n_top_tiles, handler.top_slides_heaps[
+     label]
  )

- expected_bottom_slides_ids = get_expected_bottom_slides_by_probability(results, n_top_slides, label)
+ expected_bottom_slides_ids = get_expected_bottom_slides_by_probability(
+     results, n_top_slides, label)
  assert expected_bottom_slides_ids == [
      slide_node.slide_id for slide_node in handler.bottom_slides_heaps[label]
  ]
  assert_equal_top_bottom_attention_tiles(
- expected_bottom_slides_ids, data, results, n_top_tiles, handler.bottom_slides_heaps[label]
+ expected_bottom_slides_ids, data, results, n_top_tiles, handler.bottom_slides_heaps[
+     label]
  )

@@ -284,11 +304,15 @@ def test_select_k_top_bottom_tiles_on_the_fly(
  def test_select_k_top_bottom_tiles_on_the_fly_distributed() -> None:
  """These tests need to be called sequentially to prevent them from running in parallel"""
  # test with n_classes = 2
- run_distributed(test_select_k_top_bottom_tiles_on_the_fly, [2], world_size=1)
- run_distributed(test_select_k_top_bottom_tiles_on_the_fly, [2], world_size=2)
+ run_distributed(test_select_k_top_bottom_tiles_on_the_fly,
+     [2], world_size=1)
+ run_distributed(test_select_k_top_bottom_tiles_on_the_fly,
+     [2], world_size=2)
  # test with n_classes = 3
- run_distributed(test_select_k_top_bottom_tiles_on_the_fly, [3], world_size=1)
- run_distributed(test_select_k_top_bottom_tiles_on_the_fly, [3], world_size=2)
+ run_distributed(test_select_k_top_bottom_tiles_on_the_fly,
+     [3], world_size=1)
+ run_distributed(test_select_k_top_bottom_tiles_on_the_fly,
+     [3], world_size=2)

  @pytest.fixture
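run_distributed is a helper from this repo's test utilities, and its definition is outside this diff. A generic sketch of the same pattern with torch.multiprocessing.spawn, assuming each worker joins a gloo process group (the address and port here are illustrative):

    import os
    import torch.distributed as dist
    import torch.multiprocessing as mp

    def _worker(rank: int, world_size: int) -> None:
        # spawn passes the rank as the first argument automatically.
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29501")
        dist.init_process_group("gloo", rank=rank, world_size=world_size)
        print(f"rank {rank}/{world_size} initialized")
        dist.destroy_process_group()

    if __name__ == "__main__":
        for world_size in (1, 2):  # mirrors the sequential world_size=1 then 2 runs above
            mp.spawn(_worker, args=(world_size,), nprocs=world_size)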
@@ -298,11 +322,13 @@ def slide_node() -> SlideNode:
  tile_size = (3, 224, 224)
  n_top_tiles = 10
  slide_node = SlideNode(slide_id="slide_0", prob_score=0.5)
- top_attn_scores = [0.99, 0.98, 0.97, 0.96, 0.95, 0.94, 0.93, 0.92, 0.91, 0.90]
+ top_attn_scores = [0.99, 0.98, 0.97, 0.96,
+     0.95, 0.94, 0.93, 0.92, 0.91, 0.90]
  slide_node.top_tiles = [
      TileNode(attn=top_attn_scores[i], data=torch.randint(0, 255, tile_size)) for i in range(n_top_tiles)
  ]
- bottom_attn_scores = [0.09, 0.08, 0.07, 0.06, 0.05, 0.04, 0.03, 0.02, 0.01, 0.009]
+ bottom_attn_scores = [0.09, 0.08, 0.07, 0.06,
+     0.05, 0.04, 0.03, 0.02, 0.01, 0.009]
  slide_node.bottom_tiles = [
      TileNode(attn=bottom_attn_scores[i], data=torch.randint(0, 255, tile_size)) for i in range(n_top_tiles)
  ]
@@ -322,8 +348,12 @@ def assert_plot_tiles_figure(tiles_fig: plt.Figure, fig_name: str, test_output_d
  @pytest.mark.skipif(is_windows(), reason="Rendering is different on Windows")
  def test_plot_top_bottom_tiles(slide_node: SlideNode, test_output_dirs: OutputFolderForTests) -> None:

- top_tiles_fig = slide_node.plot_attention_tiles(tile_nodes=slide_node.top_tiles, case="TP")
- bottom_tiles_fig = slide_node.plot_attention_tiles(tile_nodes=slide_node.bottom_tiles, case="FN")
+ top_tiles_fig = slide_node.plot_attention_tiles(
+     tile_nodes=slide_node.top_tiles, case="TP")
+ bottom_tiles_fig = slide_node.plot_attention_tiles(
+     tile_nodes=slide_node.bottom_tiles, case="FN")

- assert_plot_tiles_figure(top_tiles_fig, "slide_0_top.png", test_output_dirs)
- assert_plot_tiles_figure(bottom_tiles_fig, "slide_0_bottom.png", test_output_dirs)
+ assert_plot_tiles_figure(
+     top_tiles_fig, "slide_0_top.png", test_output_dirs)
+ assert_plot_tiles_figure(
+     bottom_tiles_fig, "slide_0_bottom.png", test_output_dirs)