Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor core embeddings based datasets #495

Merged
merged 1 commit into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion configs/vision/dino_vit/offline/mhist.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ data:
manifest_file: manifest.csv
split: train
target_transforms:
class_path: eva.core.data.transforms.ArrayToFloatTensor
class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: torch.float32
val:
class_path: eva.datasets.EmbeddingsClassificationDataset
init_args:
Expand Down
4 changes: 3 additions & 1 deletion configs/vision/dino_vit/offline/patch_camelyon.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ data:
manifest_file: manifest.csv
split: train
target_transforms:
class_path: eva.core.data.transforms.ArrayToFloatTensor
class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: torch.float32
val:
class_path: eva.datasets.EmbeddingsClassificationDataset
init_args:
Expand Down
4 changes: 3 additions & 1 deletion configs/vision/owkin/phikon/offline/mhist.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ data:
manifest_file: manifest.csv
split: train
target_transforms:
class_path: eva.core.data.transforms.ArrayToFloatTensor
class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: torch.float32
val:
class_path: eva.datasets.EmbeddingsClassificationDataset
init_args:
Expand Down
4 changes: 3 additions & 1 deletion configs/vision/owkin/phikon/offline/patch_camelyon.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ data:
manifest_file: manifest.csv
split: train
target_transforms:
class_path: eva.core.data.transforms.ArrayToFloatTensor
class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: torch.float32
val:
class_path: eva.datasets.EmbeddingsClassificationDataset
init_args:
Expand Down
4 changes: 3 additions & 1 deletion configs/vision/tests/offline/patch_camelyon.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ data:
manifest_file: manifest.csv
split: train
target_transforms:
class_path: eva.core.data.transforms.ArrayToFloatTensor
class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: torch.float32
val:
class_path: eva.datasets.EmbeddingsClassificationDataset
init_args:
Expand Down
4 changes: 2 additions & 2 deletions src/eva/core/data/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""Datasets API."""

from eva.core.data.datasets.base import Dataset
from eva.core.data.datasets.dataset import TorchDataset
from eva.core.data.datasets.embeddings import (
from eva.core.data.datasets.classification import (
EmbeddingsClassificationDataset,
MultiEmbeddingsClassificationDataset,
)
from eva.core.data.datasets.dataset import TorchDataset

__all__ = [
"Dataset",
Expand Down
8 changes: 8 additions & 0 deletions src/eva/core/data/datasets/classification/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Embedding classification datasets API."""

from eva.core.data.datasets.classification.embeddings import EmbeddingsClassificationDataset
from eva.core.data.datasets.classification.multi_embeddings import (
    MultiEmbeddingsClassificationDataset,
)

__all__ = ["EmbeddingsClassificationDataset", "MultiEmbeddingsClassificationDataset"]
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,21 @@
import os
from typing import Callable, Dict, Literal

import numpy as np
import torch
from typing_extensions import override

from eva.core.data.datasets.embeddings import base
from eva.core.data.datasets import embeddings as embeddings_base


class EmbeddingsClassificationDataset(base.EmbeddingsDataset):
class EmbeddingsClassificationDataset(embeddings_base.EmbeddingsDataset[torch.Tensor]):
"""Embeddings dataset class for classification tasks."""

def __init__(
self,
root: str,
manifest_file: str,
split: Literal["train", "val", "test"] | None = None,
column_mapping: Dict[str, str] = base.default_column_mapping,
column_mapping: Dict[str, str] = embeddings_base.default_column_mapping,
embeddings_transforms: Callable | None = None,
target_transforms: Callable | None = None,
) -> None:
Expand Down Expand Up @@ -63,9 +62,9 @@ def _load_embeddings(self, index: int) -> torch.Tensor:
return tensor.squeeze(0)

@override
def _load_target(self, index: int) -> np.ndarray:
def _load_target(self, index: int) -> torch.Tensor:
target = self._data.at[index, self._column_mapping["target"]]
return np.asarray(target, dtype=np.int64)
return torch.tensor(target, dtype=torch.int64)

@override
def __len__(self) -> int:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import torch
from typing_extensions import override

from eva.core.data.datasets.embeddings import base
from eva.core.data.datasets import embeddings as embeddings_base


class MultiEmbeddingsClassificationDataset(base.EmbeddingsDataset):
class MultiEmbeddingsClassificationDataset(embeddings_base.EmbeddingsDataset[torch.Tensor]):
"""Dataset class where each sample corresponds to multiple embeddings.

Example use case: Slide level dataset where each slide has multiple patch embeddings.
Expand All @@ -21,7 +21,7 @@ def __init__(
root: str,
manifest_file: str,
split: Literal["train", "val", "test"],
column_mapping: Dict[str, str] = base.default_column_mapping,
column_mapping: Dict[str, str] = embeddings_base.default_column_mapping,
embeddings_transforms: Callable | None = None,
target_transforms: Callable | None = None,
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@

import abc
import os
from typing import Callable, Dict, Literal, Tuple
from typing import Callable, Dict, Generic, Literal, Tuple, TypeVar

import numpy as np
import pandas as pd
import torch
from typing_extensions import override

from eva.core.data.datasets import base
from eva.core.utils import io

TargetType = TypeVar("TargetType")
"""The target data type."""


default_column_mapping: Dict[str, str] = {
"path": "embeddings",
"target": "target",
Expand All @@ -21,7 +24,7 @@
"""The default column mapping of the variables to the manifest columns."""


class EmbeddingsDataset(base.Dataset):
class EmbeddingsDataset(base.Dataset, Generic[TargetType]):
"""Abstract base class for embedding datasets."""

def __init__(
Expand Down Expand Up @@ -62,32 +65,6 @@ def __init__(

self._data: pd.DataFrame

@abc.abstractmethod
def _load_embeddings(self, index: int) -> torch.Tensor:
"""Returns the `index`'th embedding sample.

Args:
index: The index of the data sample to load.

Returns:
The embedding sample as a tensor.
"""

@abc.abstractmethod
def _load_target(self, index: int) -> np.ndarray:
"""Returns the `index`'th target sample.

Args:
index: The index of the data sample to load.

Returns:
The sample target as an array.
"""

@abc.abstractmethod
def __len__(self) -> int:
"""Returns the total length of the data."""

def filename(self, index: int) -> str:
"""Returns the filename of the `index`'th data sample.

Expand All @@ -105,7 +82,11 @@ def filename(self, index: int) -> str:
def setup(self):
self._data = self._load_manifest()

def __getitem__(self, index) -> Tuple[torch.Tensor, np.ndarray]:
@abc.abstractmethod
def __len__(self) -> int:
"""Returns the total length of the data."""

def __getitem__(self, index) -> Tuple[torch.Tensor, TargetType]:
"""Returns the `index`'th data sample.

Args:
Expand All @@ -118,6 +99,28 @@ def __getitem__(self, index) -> Tuple[torch.Tensor, np.ndarray]:
target = self._load_target(index)
return self._apply_transforms(embeddings, target)

@abc.abstractmethod
def _load_embeddings(self, index: int) -> torch.Tensor:
"""Returns the `index`'th embedding sample.

Args:
index: The index of the data sample to load.

Returns:
The embedding sample as a tensor.
"""

@abc.abstractmethod
def _load_target(self, index: int) -> TargetType:
"""Returns the `index`'th target sample.

Args:
index: The index of the data sample to load.

Returns:
The sample target.
"""

def _load_manifest(self) -> pd.DataFrame:
"""Loads manifest file and filters the data based on the split column.

Expand All @@ -132,8 +135,8 @@ def _load_manifest(self) -> pd.DataFrame:
return data

def _apply_transforms(
self, embeddings: torch.Tensor, target: np.ndarray
) -> Tuple[torch.Tensor, np.ndarray]:
self, embeddings: torch.Tensor, target: TargetType
) -> Tuple[torch.Tensor, TargetType]:
"""Applies the transforms to the provided data and returns them.

Args:
Expand Down
13 changes: 0 additions & 13 deletions src/eva/core/data/datasets/embeddings/__init__.py

This file was deleted.

10 changes: 0 additions & 10 deletions src/eva/core/data/datasets/embeddings/classification/__init__.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
import os
from typing import Literal

import numpy as np
import pytest
import torch

from eva.core.data.datasets.embeddings import classification
from eva.core.data.datasets import classification


@pytest.mark.parametrize("split", ["train", "val"])
Expand All @@ -21,7 +20,7 @@ def test_embedding_dataset(embeddings_dataset: classification.EmbeddingsClassifi
embeddings, target = sample
assert isinstance(embeddings, torch.Tensor)
assert embeddings.shape == (8,)
assert isinstance(target, np.ndarray)
assert isinstance(target, torch.Tensor)
assert target in [0, 1]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch.nn

from eva.core.data import transforms
from eva.core.data.datasets.embeddings import classification
from eva.core.data.datasets import classification


@pytest.mark.parametrize(
Expand Down
1 change: 0 additions & 1 deletion tests/eva/core/data/datasets/embeddings/__init__.py

This file was deleted.