-
Notifications
You must be signed in to change notification settings - Fork 66
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat(tuner): add base miner * feat(tuner): initialize miner classes * feat(tuner): add siamese miner * feat(tuner): add siamese miner * feat(tuner): refine mine method * feat(tuner): put parameters as function argument * feat(tuner): adjust docstrings * feat(tuner): add triplet miner * feat(tuner): add miner to framework * feat(tuner): add miner to framework * feat(tuner): fix type hints * feat(tuner): fix type hints * feat(tuner): revisit triplet miner * feat(tuner): adjust type hints and variable naming * feat(tuner): return idx instead of embeddings * feat(tuner): return idx instead of embeddings * feat(tuner): update type hints Co-authored-by: Tadej Svetina <tadej.svetina@jina.ai> * feat(tuner): return list rather generator * feat(tuner): return list rather generator Co-authored-by: Tadej Svetina <tadej.svetina@jina.ai>
- Loading branch information
Showing
4 changed files
with
140 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import numpy as np | ||
from typing import List, Tuple | ||
from itertools import combinations | ||
|
||
from ..base import BaseMiner | ||
from ...helper import AnyTensor | ||
|
||
|
||
class SiameseMiner(BaseMiner): | ||
def mine( | ||
self, embeddings: List[AnyTensor], labels: List[int] | ||
) -> List[Tuple[int, ...]]: | ||
"""Generate tuples from input embeddings and labels. | ||
:param embeddings: embeddings from model, should be a list of Tensor objects. | ||
:param labels: labels of each embeddings, embeddings with same label indicates same class. | ||
:return: a pair of label indices and their label as tuple. | ||
""" | ||
assert len(embeddings) == len(labels) | ||
return [ | ||
(left[0], right[0], 1) if left[1] == right[1] else (left[0], right[0], -1) | ||
for left, right in combinations(enumerate(labels), 2) | ||
] | ||
|
||
|
||
class TripletMiner(BaseMiner): | ||
def mine( | ||
self, embeddings: List[AnyTensor], labels: List[int] | ||
) -> List[Tuple[int, ...]]: | ||
"""Generate triplets from input embeddings and labels. | ||
:param embeddings: embeddings from model, should be a list of Tensor objects. | ||
:param labels: labels of each embeddings, embeddings with same label indicates same class. | ||
:return: triplet of label indices follows the order of anchor, positive and negative. | ||
""" | ||
assert len(embeddings) == len(labels) | ||
labels1 = np.expand_dims(labels, 1) | ||
labels2 = np.expand_dims(labels, 0) | ||
matches = (labels1 == labels2).astype(int) | ||
diffs = matches ^ 1 | ||
np.fill_diagonal(matches, 0) | ||
triplets = np.expand_dims(matches, 2) * np.expand_dims(diffs, 1) | ||
idxes_anchor, idxes_pos, idxes_neg = np.where(triplets) | ||
return list(zip(idxes_anchor, idxes_pos, idxes_neg)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import pytest | ||
import torch | ||
|
||
from finetuner.tuner.pytorch.miner import SiameseMiner, TripletMiner | ||
|
||
BATCH_SIZE = 8 | ||
NUM_DIM = 10 | ||
|
||
|
||
@pytest.fixture | ||
def siamese_miner(): | ||
return SiameseMiner() | ||
|
||
|
||
@pytest.fixture | ||
def triplet_miner(): | ||
return TripletMiner() | ||
|
||
|
||
@pytest.fixture | ||
def embeddings(): | ||
return [torch.rand(NUM_DIM) for _ in range(BATCH_SIZE)] | ||
|
||
|
||
@pytest.fixture | ||
def labels(): | ||
return [1, 3, 1, 3, 2, 4, 2, 4] | ||
|
||
|
||
def test_siamese_miner(embeddings, labels, siamese_miner): | ||
rv = siamese_miner.mine(embeddings, labels) | ||
assert len(rv) == 28 | ||
for item in rv: | ||
idx_left, idx_right, label = item | ||
# find corresponded label idx | ||
label_left = labels[idx_left] | ||
label_right = labels[idx_right] | ||
if label_left == label_right: | ||
expected_label = 1 | ||
else: | ||
expected_label = -1 | ||
assert label == expected_label | ||
|
||
|
||
@pytest.mark.parametrize('cut_index', [0, 1]) | ||
def test_siamese_miner_given_insufficient_inputs( | ||
embeddings, labels, siamese_miner, cut_index | ||
): | ||
embeddings = embeddings[:cut_index] | ||
labels = labels[:cut_index] | ||
rv = list(siamese_miner.mine(embeddings, labels)) | ||
assert len(rv) == 0 | ||
|
||
|
||
def test_triplet_miner(embeddings, labels, triplet_miner): | ||
rv = triplet_miner.mine(embeddings, labels) | ||
assert len(rv) == 48 | ||
for item in rv: | ||
idx_anchor, idx_pos, idx_neg = item | ||
# find corresponded label idx | ||
label_anchor = labels[idx_anchor] | ||
label_pos = labels[idx_pos] | ||
label_neg = labels[idx_neg] | ||
# given ordered anchor, pos, neg, | ||
# assure first two labels are identical, first third label is different | ||
assert label_anchor == label_pos | ||
assert label_anchor != label_neg | ||
|
||
|
||
@pytest.mark.parametrize('cut_index', [0, 1]) | ||
def test_triplet_miner_given_insufficient_inputs( | ||
embeddings, labels, siamese_miner, cut_index | ||
): | ||
embeddings = embeddings[:cut_index] | ||
labels = labels[:cut_index] | ||
rv = list(siamese_miner.mine(embeddings, labels)) | ||
assert len(rv) == 0 |