Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mebristo committed Dec 6, 2021
1 parent 7ea73c8 commit c53764b
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 40 deletions.
2 changes: 1 addition & 1 deletion InnerEye/ML/Histopathology/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,5 +103,5 @@ def get_slide_labels(self) -> pd.Series:
def get_class_weights(self) -> torch.Tensor:
    """
    Compute 'balanced' class weights from the slide labels of this dataset.

    The distinct label values are taken from `get_slide_labels()` via `np.unique`, and one weight per
    distinct value is computed with scikit-learn's `compute_class_weight`.

    :return: A tensor holding one weight per distinct slide label.
    """
    slide_labels = self.get_slide_labels()
    classes = np.unique(slide_labels)
    # Keyword arguments are required here: recent scikit-learn releases no longer accept these
    # parameters positionally.
    class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=slide_labels)
    return torch.as_tensor(class_weights)
6 changes: 4 additions & 2 deletions InnerEye/ML/baselines_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,15 +191,17 @@ def compare_files(expected: Path, actual: Path) -> str:
If the files are not identical, an error message with details is returned. This handles known text file formats,
where it ignores differences in line breaks. All other files are treated as binary, and compared on a byte-by-byte
basis.
:param expected: A file that contains the expected contents. The type of comparison (text or binary) is chosen
based on the extension of this file.
:param actual: A file that contains the actual contents.
:return: An empty string if the files appear identical, or otherwise an error message with details.
"""

def print_lines(prefix: str, lines: List[str]) -> None:
    """
    Log a short preview of a list of lines at debug level: first a summary with the total number of
    lines, then at most the first five lines joined with the platform line separator.

    :param prefix: Text placed at the start of the summary message.
    :param lines: The lines to preview.
    """
    num_lines = len(lines)
    # Clamp the preview size so the summary never claims more lines than are actually present.
    count = min(5, num_lines)
    logging.debug(f"{prefix} {num_lines} lines, first {count} of those:")
    logging.debug(os.linesep.join(lines[:count]))

if expected.suffix in TEXT_FILE_SUFFIXES:
Expand Down
38 changes: 1 addition & 37 deletions Tests/ML/histopathology/models/test_encoders.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
import os
from pathlib import Path
from typing import Callable

import pytest
from torch import Tensor, float32, nn, rand
from torchvision.models import resnet18

from health_azure.utils import CheckpointDownloader, get_workspace
from InnerEye.Common import fixed_paths
from InnerEye.ML.Histopathology.models.encoders import (TileEncoder, HistoSSLEncoder, ImageNetEncoder,
ImageNetSimCLREncoder, InnerEyeSSLEncoder)
ImageNetSimCLREncoder)


def get_supervised_imagenet_encoder() -> TileEncoder:
Expand All @@ -20,21 +16,6 @@ def get_simclr_imagenet_encoder() -> TileEncoder:
return ImageNetSimCLREncoder(tile_size=224)


def get_simclr_crck_encoder(tmp_path: Path) -> TileEncoder:
    """
    Create an InnerEyeSSLEncoder from a checkpoint that is downloaded from an AzureML run.

    The checkpoint is fetched into a sub-folder of `tmp_path`. The working directory is changed to the
    repository root before the download is triggered.

    :param tmp_path: A folder under which the checkpoint download folder is created.
    :return: An encoder for tiles of size 224, loaded from the downloaded checkpoint.
    """
    workspace = get_workspace()
    checkpoint_dir = tmp_path / "downlads"
    checkpoint_dir.mkdir(exist_ok=True)
    checkpoint_downloader = CheckpointDownloader(
        aml_workspace=workspace,
        run_id="vsalva_ssl_crck:vsalva_ssl_crck_1630691119_af10db8a",
        checkpoint_filename="best_checkpoint.ckpt",
        download_dir=checkpoint_dir,
    )
    os.chdir(fixed_paths.repository_root_directory())
    # The return value of the download call is not needed, only its side effect.
    checkpoint_downloader.download_checkpoint_if_necessary()

    return InnerEyeSSLEncoder(
        pl_checkpoint_path=checkpoint_downloader.local_checkpoint_path,
        tile_size=224,
    )


def get_histo_ssl_encoder() -> TileEncoder:
    """
    Create a HistoSSLEncoder for tiles of size 224.
    """
    histo_ssl_encoder = HistoSSLEncoder(tile_size=224)
    return histo_ssl_encoder

Expand All @@ -57,20 +38,3 @@ def test_encoder(create_encoder_fn: Callable[[], TileEncoder]) -> None:
features = encoder(images)
assert isinstance(features, Tensor)
assert features.shape == (batch_size, encoder.num_encoding)

@pytest.mark.skip(reason="This checkpoint has extra keys with respect to the latest SimCLRInnerEye class, cannot be loaded with strict=True")
def test_simclr_crck_encoder(tmp_path: Path) -> None:
    """
    Check that the encoder built from the downloaded SSL checkpoint has no trainable parameters, and
    that it maps a random batch of images to a tensor of shape (batch size, encoder.num_encoding).
    """
    num_images = 10
    tile_encoder = get_simclr_crck_encoder(tmp_path)

    # Parameters can only be inspected if the encoder is a torch module.
    if isinstance(tile_encoder, nn.Module):
        for param_name, parameter in tile_encoder.named_parameters():
            assert not parameter.requires_grad, \
                f"Feature extractor has unfrozen parameters: {param_name}"

    image_batch = rand(num_images, *tile_encoder.input_dim, dtype=float32)
    encoded = tile_encoder(image_batch)

    assert isinstance(encoded, Tensor)
    expected_shape = (num_images, tile_encoder.num_encoding)
    assert encoded.shape == expected_shape

0 comments on commit c53764b

Please sign in to comment.