From 1e4a1aeebce9c11ec3372a716a1f17c31396b6b8 Mon Sep 17 00:00:00 2001
From: Wang Bo
Date: Tue, 2 Nov 2021 10:41:26 +0100
Subject: [PATCH] feat(tuner): add miner v1 (#180)

* feat(tuner): add base miner
* feat(tuner): initialize miner classes
* feat(tuner): add siamese miner
* feat(tuner): add siamese miner
* feat(tuner): refine mine method
* feat(tuner): put parameters as function argument
* feat(tuner): adjust docstrings
* feat(tuner): add triplet miner
* feat(tuner): add miner to framework
* feat(tuner): add miner to framework
* feat(tuner): fix type hints
* feat(tuner): fix type hints
* feat(tuner): revisit triplet miner
* feat(tuner): adjust type hints and variable naming
* feat(tuner): return idx instead of embeddings
* feat(tuner): return idx instead of embeddings
* feat(tuner): update type hints

Co-authored-by: Tadej Svetina

* feat(tuner): return list rather generator
* feat(tuner): return list rather generator

Co-authored-by: Tadej Svetina
---
 finetuner/helper.py                  |  3 ++
 finetuner/tuner/base.py              | 17 +++++-
 finetuner/tuner/pytorch/miner.py     | 44 ++++++++++++++++
 tests/unit/tuner/torch/test_miner.py | 77 ++++++++++++++++++++++++++++
 4 files changed, 140 insertions(+), 1 deletion(-)
 create mode 100644 finetuner/tuner/pytorch/miner.py
 create mode 100644 tests/unit/tuner/torch/test_miner.py

diff --git a/finetuner/helper.py b/finetuner/helper.py
index e53ebb371..5d8aff3ad 100644
--- a/finetuner/helper.py
+++ b/finetuner/helper.py
@@ -14,6 +14,9 @@
 AnyDNN = TypeVar(
     'AnyDNN'
 )  #: The type of any implementation of a Deep Neural Network object
+AnyTensor = TypeVar(
+    'AnyTensor'
+)  #: The type of any implementation of a tensor for model tuning
 AnyDataLoader = TypeVar(
     'AnyDataLoader'
 )  #: The type of any implementation of a data loader

diff --git a/finetuner/tuner/base.py b/finetuner/tuner/base.py
index 3bd5d24fb..ae6ea9ac4 100644
--- a/finetuner/tuner/base.py
+++ b/finetuner/tuner/base.py
@@ -1,6 +1,7 @@
 import abc
 import warnings
 from typing import (
+    Generator,
     Optional,
     Union,
     Tuple,
@@ -8,7 +9,7 @@
     Dict,
 )
 
-from ..helper import AnyDNN, AnyDataLoader, AnyOptimizer, DocumentArrayLike
+from ..helper import AnyDNN, AnyTensor, AnyDataLoader, AnyOptimizer, DocumentArrayLike
 from .summary import Summary
 
 
@@ -148,3 +149,17 @@ def __init__(
     ):
         super().__init__()
         self._inputs = inputs() if callable(inputs) else inputs
+
+
+class BaseMiner(abc.ABC):
+    @abc.abstractmethod
+    def mine(
+        self, embeddings: List[AnyTensor], labels: List[int]
+    ) -> List[Tuple[int, ...]]:
+        """Generate tuples/triplets from the input embeddings and labels.
+
+        :param embeddings: embeddings from the model, as a list of tensor objects.
+        :param labels: label of each embedding; embeddings with the same label belong to the same class.
+        :return: a list of index tuples/triplets.
+        """
+        ...

diff --git a/finetuner/tuner/pytorch/miner.py b/finetuner/tuner/pytorch/miner.py
new file mode 100644
index 000000000..fb10afacc
--- /dev/null
+++ b/finetuner/tuner/pytorch/miner.py
@@ -0,0 +1,44 @@
+import numpy as np
+from typing import List, Tuple
+from itertools import combinations
+
+from ..base import BaseMiner
+from ...helper import AnyTensor
+
+
+class SiameseMiner(BaseMiner):
+    def mine(
+        self, embeddings: List[AnyTensor], labels: List[int]
+    ) -> List[Tuple[int, ...]]:
+        """Generate pairs from the input embeddings and labels.
+
+        :param embeddings: embeddings from the model, as a list of tensor objects.
+        :param labels: label of each embedding; embeddings with the same label belong to the same class.
+        :return: a list of (index, index, label) tuples, where the label is 1 for a same-class pair and -1 otherwise.
+        """
+        assert len(embeddings) == len(labels)
+        return [
+            (left[0], right[0], 1) if left[1] == right[1] else (left[0], right[0], -1)
+            for left, right in combinations(enumerate(labels), 2)
+        ]
+
+
+class TripletMiner(BaseMiner):
+    def mine(
+        self, embeddings: List[AnyTensor], labels: List[int]
+    ) -> List[Tuple[int, ...]]:
+        """Generate triplets from the input embeddings and labels.
+
+        :param embeddings: embeddings from the model, as a list of tensor objects.
+        :param labels: label of each embedding; embeddings with the same label belong to the same class.
+        :return: a list of index triplets in the order (anchor, positive, negative).
+        """
+        assert len(embeddings) == len(labels)
+        labels1 = np.expand_dims(labels, 1)  # column vector of labels
+        labels2 = np.expand_dims(labels, 0)  # row vector of labels
+        matches = (labels1 == labels2).astype(int)  # matches[i, j] = 1 iff same class
+        diffs = matches ^ 1  # diffs[i, j] = 1 iff different class
+        np.fill_diagonal(matches, 0)  # an anchor cannot be its own positive
+        triplets = np.expand_dims(matches, 2) * np.expand_dims(diffs, 1)
+        idxes_anchor, idxes_pos, idxes_neg = np.where(triplets)
+        return list(zip(idxes_anchor, idxes_pos, idxes_neg))
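
An aside, not part of the patch itself: the counts asserted by the tests below, 28 pairs and 48 triplets, follow directly from the fixture labels. A minimal sketch to sanity-check them:

    from itertools import combinations

    labels = [1, 3, 1, 3, 2, 4, 2, 4]  # the fixture used by the tests below

    # SiameseMiner emits every unordered index pair: C(8, 2) = 28 tuples.
    assert len(list(combinations(range(len(labels)), 2))) == 28

    # TripletMiner emits one triplet per valid (anchor, positive, negative):
    # every class appears exactly twice, so each of the 8 anchors has exactly
    # 1 positive and 6 negatives, hence 8 * 1 * 6 = 48 triplets.
    idx = range(len(labels))
    triplets = [
        (a, p, n)
        for a in idx for p in idx for n in idx
        if a != p and labels[a] == labels[p] and labels[a] != labels[n]
    ]
    assert len(triplets) == 48
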
diff --git a/tests/unit/tuner/torch/test_miner.py b/tests/unit/tuner/torch/test_miner.py
new file mode 100644
index 000000000..f90b7769c
--- /dev/null
+++ b/tests/unit/tuner/torch/test_miner.py
@@ -0,0 +1,77 @@
+import pytest
+import torch
+
+from finetuner.tuner.pytorch.miner import SiameseMiner, TripletMiner
+
+BATCH_SIZE = 8
+NUM_DIM = 10
+
+
+@pytest.fixture
+def siamese_miner():
+    return SiameseMiner()
+
+
+@pytest.fixture
+def triplet_miner():
+    return TripletMiner()
+
+
+@pytest.fixture
+def embeddings():
+    return [torch.rand(NUM_DIM) for _ in range(BATCH_SIZE)]
+
+
+@pytest.fixture
+def labels():
+    return [1, 3, 1, 3, 2, 4, 2, 4]
+
+
+def test_siamese_miner(embeddings, labels, siamese_miner):
+    rv = siamese_miner.mine(embeddings, labels)
+    assert len(rv) == 28
+    for item in rv:
+        idx_left, idx_right, label = item
+        # look up the corresponding labels
+        label_left = labels[idx_left]
+        label_right = labels[idx_right]
+        if label_left == label_right:
+            expected_label = 1
+        else:
+            expected_label = -1
+        assert label == expected_label
+
+
+@pytest.mark.parametrize('cut_index', [0, 1])
+def test_siamese_miner_given_insufficient_inputs(
+    embeddings, labels, siamese_miner, cut_index
+):
+    embeddings = embeddings[:cut_index]
+    labels = labels[:cut_index]
+    rv = list(siamese_miner.mine(embeddings, labels))
+    assert len(rv) == 0
+
+
+def test_triplet_miner(embeddings, labels, triplet_miner):
+    rv = triplet_miner.mine(embeddings, labels)
+    assert len(rv) == 48
+    for item in rv:
+        idx_anchor, idx_pos, idx_neg = item
+        # look up the corresponding labels
+        label_anchor = labels[idx_anchor]
+        label_pos = labels[idx_pos]
+        label_neg = labels[idx_neg]
+        # given the (anchor, positive, negative) ordering, the first two
+        # labels must match and the third must differ
+        assert label_anchor == label_pos
+        assert label_anchor != label_neg
+
+
+@pytest.mark.parametrize('cut_index', [0, 1])
+def test_triplet_miner_given_insufficient_inputs(
+    embeddings, labels, triplet_miner, cut_index
+):
+    embeddings = embeddings[:cut_index]
+    labels = labels[:cut_index]
+    rv = list(triplet_miner.mine(embeddings, labels))
+    assert len(rv) == 0
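
For completeness, a minimal usage sketch of the new miners (the class and module names are those introduced by the patch; the embeddings are random, as in the unit tests):

    import torch
    from finetuner.tuner.pytorch.miner import SiameseMiner, TripletMiner

    embeddings = [torch.rand(10) for _ in range(8)]
    labels = [1, 3, 1, 3, 2, 4, 2, 4]

    pairs = SiameseMiner().mine(embeddings, labels)
    print(pairs[0])      # (0, 1, -1): indices 0 and 1 carry different labels
    triplets = TripletMiner().mine(embeddings, labels)
    print(triplets[0])   # an (anchor, positive, negative) index triplet

Note that this v1 mines from the labels alone: the embeddings argument is accepted for API compatibility but is not yet used to select hard pairs or triplets.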