Skip to content

Commit

Permalink
feat(tuner): add miner v1 (#180)
Browse files Browse the repository at this point in the history
* feat(tuner): add base miner

* feat(tuner): initialize miner classes

* feat(tuner): add siamese miner

* feat(tuner): add siamese miner

* feat(tuner): refine mine method

* feat(tuner): put parameters as function argument

* feat(tuner): adjust docstrings

* feat(tuner): add triplet miner

* feat(tuner): add miner to framework

* feat(tuner): add miner to framework

* feat(tuner): fix type hints

* feat(tuner): fix type hints

* feat(tuner): revisit triplet miner

* feat(tuner): adjust type hints and variable naming

* feat(tuner): return idx instead of embeddings

* feat(tuner): return idx instead of embeddings

* feat(tuner): update type hints

Co-authored-by: Tadej Svetina <tadej.svetina@jina.ai>

* feat(tuner): return list rather than generator

* feat(tuner): return list rather than generator

Co-authored-by: Tadej Svetina <tadej.svetina@jina.ai>
  • Loading branch information
bwanglzu and Tadej Svetina committed Nov 2, 2021
1 parent ae8e399 commit 1e4a1ae
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 1 deletion.
3 changes: 3 additions & 0 deletions finetuner/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
AnyDNN = TypeVar(
    'AnyDNN'
)  #: The type of any implementation of a Deep Neural Network object
AnyTensor = TypeVar(
    'AnyTensor'
)  #: The type of any implementation of a tensor for model tuning
AnyDataLoader = TypeVar(
    'AnyDataLoader'
)  #: The type of any implementation of a data loader
Expand Down
17 changes: 16 additions & 1 deletion finetuner/tuner/base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import abc
import warnings
from typing import (
Generator,
Optional,
Union,
Tuple,
List,
Dict,
)

from ..helper import AnyDNN, AnyDataLoader, AnyOptimizer, DocumentArrayLike
from ..helper import AnyDNN, AnyTensor, AnyDataLoader, AnyOptimizer, DocumentArrayLike
from .summary import Summary


Expand Down Expand Up @@ -148,3 +149,17 @@ def __init__(
):
super().__init__()
self._inputs = inputs() if callable(inputs) else inputs


class BaseMiner(abc.ABC):
    """Interface for miners that build training tuples/triplets from a batch."""

    @abc.abstractmethod
    def mine(
        self, embeddings: List[AnyTensor], labels: List[int]
    ) -> List[Tuple[int, ...]]:
        """Generate tuples/triplets from input embeddings and labels.

        :param embeddings: embeddings from the model, one tensor per item in the
            batch.
        :param labels: integer label of each embedding; embeddings sharing a
            label belong to the same class.
        :return: tuples/triplets of indices into ``embeddings``/``labels``.
        """
        ...
44 changes: 44 additions & 0 deletions finetuner/tuner/pytorch/miner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import numpy as np
from typing import List, Tuple
from itertools import combinations

from ..base import BaseMiner
from ...helper import AnyTensor


class SiameseMiner(BaseMiner):
    def mine(
        self, embeddings: List[AnyTensor], labels: List[int]
    ) -> List[Tuple[int, ...]]:
        """Generate all index pairs from input embeddings and labels.

        :param embeddings: embeddings from the model, one tensor per item in the
            batch.
        :param labels: integer label of each embedding; embeddings sharing a
            label belong to the same class.
        :return: tuples ``(idx_left, idx_right, target)`` where ``target`` is
            ``1`` for a same-class pair and ``-1`` otherwise.
        """
        assert len(embeddings) == len(labels)
        pairs = []
        # Every unordered pair of batch positions becomes one training tuple.
        for (idx_left, label_left), (idx_right, label_right) in combinations(
            enumerate(labels), 2
        ):
            target = 1 if label_left == label_right else -1
            pairs.append((idx_left, idx_right, target))
        return pairs


class TripletMiner(BaseMiner):
    def mine(
        self, embeddings: List[AnyTensor], labels: List[int]
    ) -> List[Tuple[int, ...]]:
        """Generate triplets from input embeddings and labels.

        :param embeddings: embeddings from the model, one tensor per item in the
            batch.
        :param labels: integer label of each embedding; embeddings sharing a
            label belong to the same class.
        :return: triplets of indices ordered as (anchor, positive, negative).
        """
        assert len(embeddings) == len(labels)
        label_arr = np.asarray(labels)
        # Pairwise class-equality matrix: 1 where two items share a label.
        same_class = (label_arr[:, None] == label_arr[None, :]).astype(int)
        other_class = same_class ^ 1
        # An item must not serve as its own positive.
        np.fill_diagonal(same_class, 0)
        # valid[a, p, n] == 1 iff (a, p) share a class and (a, n) do not.
        valid = np.expand_dims(same_class, 2) * np.expand_dims(other_class, 1)
        idx_anchor, idx_pos, idx_neg = np.nonzero(valid)
        return list(zip(idx_anchor, idx_pos, idx_neg))
77 changes: 77 additions & 0 deletions tests/unit/tuner/torch/test_miner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import pytest
import torch

from finetuner.tuner.pytorch.miner import SiameseMiner, TripletMiner

# Number of embeddings per synthetic batch used by the miner tests.
BATCH_SIZE = 8
# Dimensionality of each randomly generated embedding vector.
NUM_DIM = 10


@pytest.fixture
def siamese_miner():
    """Provide a fresh SiameseMiner instance for each test."""
    miner = SiameseMiner()
    return miner


@pytest.fixture
def triplet_miner():
    """Provide a fresh TripletMiner instance for each test."""
    miner = TripletMiner()
    return miner


@pytest.fixture
def embeddings():
    """A batch of BATCH_SIZE random embedding vectors of NUM_DIM floats each."""
    batch = []
    for _ in range(BATCH_SIZE):
        batch.append(torch.rand(NUM_DIM))
    return batch


@pytest.fixture
def labels():
    """Class labels for the batch: four classes, two items per class."""
    # interleaved so that same-class items are not adjacent
    class_ids = (1, 3, 1, 3, 2, 4, 2, 4)
    return list(class_ids)


def test_siamese_miner(embeddings, labels, siamese_miner):
    """Each mined pair carries 1 for a same-class pair, -1 otherwise."""
    pairs = siamese_miner.mine(embeddings, labels)
    # C(8, 2) unordered pairs from a batch of 8
    assert len(pairs) == 28
    for idx_left, idx_right, target in pairs:
        expected = 1 if labels[idx_left] == labels[idx_right] else -1
        assert target == expected


@pytest.mark.parametrize('cut_index', [0, 1])
def test_siamese_miner_given_insufficient_inputs(
    embeddings, labels, siamese_miner, cut_index
):
    """With fewer than two items in the batch, no pairs can be formed."""
    truncated_embeddings = embeddings[:cut_index]
    truncated_labels = labels[:cut_index]
    pairs = list(siamese_miner.mine(truncated_embeddings, truncated_labels))
    assert not pairs


def test_triplet_miner(embeddings, labels, triplet_miner):
    """Each mined triplet is ordered (anchor, positive, negative)."""
    triplets = triplet_miner.mine(embeddings, labels)
    # 8 anchors x 1 positive x 6 negatives for two items per class
    assert len(triplets) == 48
    for idx_anchor, idx_pos, idx_neg in triplets:
        # anchor and positive share a class; negative belongs to another class
        assert labels[idx_anchor] == labels[idx_pos]
        assert labels[idx_anchor] != labels[idx_neg]


@pytest.mark.parametrize('cut_index', [0, 1])
def test_triplet_miner_given_insufficient_inputs(
    embeddings, labels, triplet_miner, cut_index
):
    """Triplet miner returns no triplets when the batch has fewer than 3 items.

    Fix: this test previously requested the ``siamese_miner`` fixture and
    called ``siamese_miner.mine``, so ``TripletMiner`` was never exercised on
    insufficient input; it now uses the ``triplet_miner`` fixture.
    """
    embeddings = embeddings[:cut_index]
    labels = labels[:cut_index]
    rv = list(triplet_miner.mine(embeddings, labels))
    assert len(rv) == 0

0 comments on commit 1e4a1ae

Please sign in to comment.