feat(tuner): add default projection head for ssl (#316)
bwanglzu committed Jan 11, 2022
1 parent bc25c37 commit 554878e
Showing 26 changed files with 1,331 additions and 1,043 deletions.
4 changes: 2 additions & 2 deletions finetuner/__init__.py
@@ -92,7 +92,7 @@ def fit(
input_dtype: str = 'float32',
layer_name: Optional[str] = None,
freeze: Union[bool, List[str]] = False,
bottleneck_net: Optional['AnyDNN'] = None,
projection_head: Optional['AnyDNN'] = None,
) -> 'AnyDNN':
...

@@ -158,7 +158,7 @@ def fit(
input_dtype: str = 'float32',
layer_name: Optional[str] = None,
freeze: Union[bool, List[str]] = False,
bottleneck_net: Optional['AnyDNN'] = None,
projection_head: Optional['AnyDNN'] = None,
) -> 'AnyDNN':
...

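Downstream, the only user-facing change in finetuner.fit is the renamed keyword. A minimal sketch of a call that passes a head explicitly (the backbone, the synthetic data, and every argument other than projection_head are illustrative assumptions, not taken from this commit; a jina 2.x-style Document API is assumed, and any extra arguments self-supervised training may need, e.g. a loss or preprocessing, are omitted):

import numpy as np
import torch.nn as nn
from jina import Document, DocumentArray

import finetuner
from finetuner.tailor.pytorch.projection_head import ProjectionHead

# Unlabelled blobs: no class tag and no matches, i.e. the case handled by the new
# self-supervised InstanceDataset branch further down in this commit.
train_docs = DocumentArray(
    Document(blob=np.random.rand(784).astype('float32')) for _ in range(64)
)

backbone = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 128),  # the output of this layer is used as the embedding
)

tuned_model = finetuner.fit(
    backbone,
    train_data=train_docs,
    epochs=1,
    projection_head=ProjectionHead(in_features=128),  # formerly `bottleneck_net`
)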
5 changes: 5 additions & 0 deletions finetuner/excepts.py
@@ -0,0 +1,5 @@
"""This modules defines all kinds of exceptions raised in Finetuner."""


class DimensionMismatchException(Exception):
"""Dimensionality mismatch given input and output layers."""
6 changes: 3 additions & 3 deletions finetuner/tailor/__init__.py
@@ -30,7 +30,7 @@ def to_embedding_model(
input_size: Optional[Tuple[int, ...]] = None,
input_dtype: str = 'float32',
freeze: Union[bool, List[str]] = False,
bottleneck_net: Optional['AnyDNN'] = None,
projection_head: Optional['AnyDNN'] = None,
**kwargs
) -> 'AnyDNN':
"""Convert a general model from :py:attr:`.model` to an embedding model.
@@ -42,13 +42,13 @@ def to_embedding_model(
:param input_size: The input size of the DNN model.
:param input_dtype: The input data type of the DNN model.
:param freeze: If set to ``True``, will freeze all layers before :py:attr:`layer_name`. If set to a list of str, will freeze layers by name.
:param bottleneck_net: Attach a bottleneck net at the end of model, this module should always trainable.
:param projection_head: Attach a module at the end of the model; this module should always be trainable.
"""
ft = _get_tailor_class(model)

return ft(model, input_size, input_dtype).to_embedding_model(
layer_name=layer_name,
bottleneck_net=bottleneck_net,
projection_head=projection_head,
freeze=freeze,
)

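The tailor entry point can also be called directly. A sketch with the renamed argument (the toy model and all sizes are assumptions; per the docstring above, leaving layer_name as None means the last layer listed in .embedding_layers is used):

import torch.nn as nn

from finetuner.tailor import to_embedding_model
from finetuner.tailor.pytorch.projection_head import ProjectionHead

general_model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
)

# The last embedding layer (128-d) is kept, so the head's in_features must match it.
embed_model = to_embedding_model(
    general_model,
    input_size=(784,),
    projection_head=ProjectionHead(in_features=128),
)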
4 changes: 2 additions & 2 deletions finetuner/tailor/base.py
@@ -33,15 +33,15 @@ def to_embedding_model(
self,
layer_name: Optional[str] = None,
freeze: Union[bool, List[str]] = False,
bottleneck_net: Optional['AnyDNN'] = None,
projection_head: Optional['AnyDNN'] = None,
) -> 'AnyDNN':
"""Convert a general model from :py:attr:`.model` to an embedding model.
:param layer_name: the name of the layer that is used for output embeddings. All layers *after* that layer
will be removed. When set to ``None``, then the last layer listed in :py:attr:`.embedding_layers` will be used.
To see all available names you can check ``name`` field of :py:attr:`.embedding_layers`.
:param freeze: If set to ``True``, will freeze all layers before :py:attr:`layer_name`. If set to a list of str, will freeze layers by name.
:param bottleneck_net: Attach a bottleneck net at the end of model, this module should always trainable.
:param projection_head: Attach a module at the end of the model; this module should always be trainable.
:return: Converted embedding model.
"""
...
10 changes: 5 additions & 5 deletions finetuner/tailor/keras/__init__.py
@@ -74,7 +74,7 @@ def to_embedding_model(
self,
layer_name: Optional[str] = None,
freeze: Union[bool, List[str]] = False,
bottleneck_net: Optional['AnyDNN'] = None,
projection_head: Optional['AnyDNN'] = None,
) -> 'AnyDNN':

"""Convert a general model from :py:attr:`.model` to an embedding model.
@@ -83,7 +83,7 @@ def to_embedding_model(
will be removed. When set to ``None``, then the last layer listed in :py:attr:`.embedding_layers` will be used.
To see all available names you can check ``name`` field of :py:attr:`.embedding_layers`.
:param freeze: If set to ``True``, will freeze all layers before :py:attr:`layer_name`. If set to a list of str, will freeze layers by name.
:param bottleneck_net: Attach a bottleneck net at the end of model, this module should always trainable.
:param projection_head: Attach a module at the end of the model; this module should always be trainable.
:return: Converted embedding model.
"""
_all_embed_layers = {layer['name']: layer for layer in self.embedding_layers}
@@ -115,10 +115,10 @@ def to_embedding_model(
for layer in model.layers:
layer.trainable = False

if bottleneck_net:
# append bottleneck net at the end of embedding model.
if projection_head:
# append an MLP module at the end of the embedding model.
x = model.output
for layer in bottleneck_net.layers:
for layer in projection_head.layers:
x = layer(x)
model = tf.keras.Model(model.input, x)

34 changes: 34 additions & 0 deletions finetuner/tailor/keras/projection_head.py
@@ -0,0 +1,34 @@
import tensorflow as tf


class ProjectionHead(tf.keras.layers.Layer):
"""Projection head used internally for self-supervised training.
It is (by default) a simple 3-layer MLP attached on top of the embedding model for training purposes only.
After training, it should be cut off from the embedding model.
"""

EPSILON = 1e-5

def __init__(self, in_features: int, output_dim: int = 128, num_layers: int = 3):
super().__init__()
self.layers = []
for idx in range(num_layers - 1):
self.layers.append(
tf.keras.layers.Dense(
units=in_features,
bias_initializer='zeros',
)
)
self.layers.append(tf.keras.layers.BatchNormalization(epsilon=self.EPSILON))
self.layers.append(tf.keras.layers.ReLU())
self.layers.append(
tf.keras.layers.Dense(
units=output_dim,
bias_initializer='zeros',
)
)

def call(self, x):
for layer in self.layers:
x = layer(x)
return x
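A quick sanity check of the new Keras head (batch size and feature sizes are arbitrary):

import tensorflow as tf

from finetuner.tailor.keras.projection_head import ProjectionHead

head = ProjectionHead(in_features=256)   # two (Dense + BatchNorm + ReLU) blocks + final Dense
out = head(tf.random.normal((8, 256)))
print(out.shape)                         # (8, 128) -- output_dim defaults to 128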
14 changes: 8 additions & 6 deletions finetuner/tailor/paddle/__init__.py
@@ -138,15 +138,15 @@ def to_embedding_model(
self,
layer_name: Optional[str] = None,
freeze: Union[bool, List[str]] = False,
bottleneck_net: Optional['AnyDNN'] = None,
projection_head: Optional['AnyDNN'] = None,
) -> 'AnyDNN':
"""Convert a general model from :py:attr:`.model` to an embedding model.
:param layer_name: the name of the layer that is used for output embeddings. All layers *after* that layer
will be removed. When set to ``None``, then the last layer listed in :py:attr:`.embedding_layers` will be used.
To see all available names you can check ``name`` field of :py:attr:`.embedding_layers`.
:param freeze: If set to ``True``, will freeze all layers before :py:attr:`layer_name`. If set to a list of str, will freeze layers by name.
:param bottleneck_net: Attach a bottleneck net at the end of model, this module should always trainable.
:param projection_head: Attach a module at the end of the model; this module should always be trainable.
:return: Converted embedding model.
"""
model = copy.deepcopy(self._model)
@@ -192,11 +192,13 @@ def to_embedding_model(
if _relative_idx_to_embedding_layer is not None:
_relative_idx_to_embedding_layer += 1

if bottleneck_net:
model = nn.Sequential(
model,
bottleneck_net,
if projection_head:
embed_model_with_projection_head = nn.Sequential()
embed_model_with_projection_head.add_sublayer('embed_model', model)
embed_model_with_projection_head.add_sublayer(
'projection_head', projection_head
)
return embed_model_with_projection_head

return model

38 changes: 38 additions & 0 deletions finetuner/tailor/paddle/projection_head.py
@@ -0,0 +1,38 @@
import paddle.nn as nn


class ProjectionHead(nn.Layer):
"""Projection head used internally for self-supervised training.
It is (by default) a simple 3-layer MLP attached on top of the embedding model for training purposes only.
After training, it should be cut off from the embedding model.
"""

EPSILON = 1e-5

def __init__(self, in_features: int, output_dim: int = 128, num_layers: int = 3):
super().__init__()
self.head_layers = nn.LayerList()
for idx in range(num_layers - 1):
self.head_layers.append(
nn.Linear(
in_features=in_features,
out_features=in_features,
bias_attr=False,
)
)
self.head_layers.append(
nn.BatchNorm1D(num_features=in_features, epsilon=self.EPSILON)
)
self.head_layers.append(nn.ReLU())
self.head_layers.append(
nn.Linear(
in_features=in_features,
out_features=output_dim,
bias_attr=False,
)
)

def forward(self, x):
for layer in self.head_layers:
x = layer(x)
return x
15 changes: 9 additions & 6 deletions finetuner/tailor/pytorch/__init__.py
@@ -131,15 +131,15 @@ def to_embedding_model(
self,
layer_name: Optional[str] = None,
freeze: Union[bool, List[str]] = False,
bottleneck_net: Optional[nn.Module] = None,
projection_head: Optional[nn.Module] = None,
) -> 'AnyDNN':
"""Convert a general model from :py:attr:`.model` to an embedding model.
:param layer_name: the name of the layer that is used for output embeddings. All layers *after* that layer
will be removed. When set to ``None``, then the last layer listed in :py:attr:`.embedding_layers` will be used.
To see all available names you can check ``name`` field of :py:attr:`.embedding_layers`.
:param freeze: If set to ``True``, will freeze all layers before :py:attr:`layer_name`. If set to a list of str, will freeze layers by name.
:param bottleneck_net: Attach a bottleneck net at the end of model, this module should always trainable.
:param projection_head: Attach a module at the end of the model; this module should always be trainable.
:return: Converted embedding model.
"""

@@ -188,9 +188,12 @@ def to_embedding_model(
if _relative_idx_to_embedding_layer is not None:
_relative_idx_to_embedding_layer += 1

if bottleneck_net:
return nn.Sequential(
model,
bottleneck_net,
if projection_head:
embed_model_with_projection_head = nn.Sequential()
embed_model_with_projection_head.add_module('embed_model', model)
embed_model_with_projection_head.add_module(
'projection_head', projection_head
)
return embed_model_with_projection_head

return model
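Because the head is now attached as a named submodule rather than an anonymous element of nn.Sequential, the bare embedding model is easy to recover once training is done. A sketch that mirrors the composition above by hand (the toy embedding model and sizes are assumptions):

import torch
import torch.nn as nn

from finetuner.tailor.pytorch.projection_head import ProjectionHead

embed_model = nn.Linear(784, 128)

composed = nn.Sequential()
composed.add_module('embed_model', embed_model)
composed.add_module('projection_head', ProjectionHead(in_features=128, output_dim=64))

x = torch.randn(4, 784)
print(composed(x).shape)              # torch.Size([4, 64])  -- embedding + projection head
print(composed.embed_model(x).shape)  # torch.Size([4, 128]) -- head cut off after training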
30 changes: 30 additions & 0 deletions finetuner/tailor/pytorch/projection_head.py
@@ -0,0 +1,30 @@
import torch.nn as nn


class ProjectionHead(nn.Module):
"""Projection head used internally for self-supervised training.
It is (by default) a simple 3-layer MLP attached on top of the embedding model for training purposes only.
After training, it should be cut off from the embedding model.
"""

EPSILON = 1e-5

def __init__(self, in_features: int, output_dim: int = 128, num_layers: int = 3):
super().__init__()
self.head_layers = nn.ModuleList()
for idx in range(num_layers - 1):
self.head_layers.append(
nn.Linear(in_features=in_features, out_features=in_features, bias=False)
)
self.head_layers.append(
nn.BatchNorm1d(num_features=in_features, eps=self.EPSILON)
)
self.head_layers.append(nn.ReLU())
self.head_layers.append(
nn.Linear(in_features=in_features, out_features=output_dim, bias=False)
)

def forward(self, x):
for layer in self.head_layers:
x = layer(x)
return x
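The constructor arguments map directly onto the structure: num_layers - 1 blocks of Linear + BatchNorm1d + ReLU followed by a final Linear, with all Linear layers bias-free. A short check (sizes are arbitrary):

from finetuner.tailor.pytorch.projection_head import ProjectionHead

head = ProjectionHead(in_features=512, output_dim=32, num_layers=3)
print(len(head.head_layers))  # 7: two (Linear + BatchNorm1d + ReLU) blocks + one final Linear
print(head.head_layers[-1])   # Linear(in_features=512, out_features=32, bias=False)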
6 changes: 5 additions & 1 deletion finetuner/tuner/keras/__init__.py
@@ -9,6 +9,7 @@
from ... import __default_tag_key__
from ..base import BaseLoss, BaseTuner
from ..dataset import ClassDataset, SessionDataset
from ..dataset.datasets import InstanceDataset
from ..state import TunerState
from . import losses
from .data import KerasDataSequence
@@ -46,7 +47,10 @@ def _get_data_loader(
if __default_tag_key__ in data[0].tags:
dataset = ClassDataset(data, preprocess_fn=preprocess_fn)
else:
dataset = SessionDataset(data, preprocess_fn=preprocess_fn)
if len(data[0].matches) > 0:
dataset = SessionDataset(data, preprocess_fn=preprocess_fn)
else:
dataset = InstanceDataset(data, preprocess_fn=preprocess_fn)

batch_sampler = self._get_batch_sampler(
dataset,
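The branch above picks the dataset type from the first document alone: the default tag key selects ClassDataset, matches select SessionDataset, and plain documents fall through to the new self-supervised InstanceDataset. A sketch of the three input shapes (jina 2.x-style Documents assumed; the label value format is an assumption, since only the presence of the key drives the branch):

import numpy as np
from jina import Document, DocumentArray

from finetuner import __default_tag_key__

blob = np.random.rand(784).astype('float32')

# 1. label stored under the default tag key  -> ClassDataset
labelled = DocumentArray([Document(blob=blob, tags={__default_tag_key__: 1})])

# 2. no label tag, but matches attached      -> SessionDataset
session = DocumentArray([Document(blob=blob)])
session[0].matches.append(Document(blob=blob))

# 3. no tags and no matches                  -> InstanceDataset (self-supervised)
plain = DocumentArray([Document(blob=blob)])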
6 changes: 5 additions & 1 deletion finetuner/tuner/paddle/__init__.py
@@ -9,6 +9,7 @@

from ... import __default_tag_key__
from ..base import BaseTuner
from ..dataset.datasets import InstanceDataset
from ..state import TunerState
from . import losses
from .datasets import PaddleClassDataset, PaddleSessionDataset
@@ -62,7 +63,10 @@ def collate_fn_all(inputs):
if __default_tag_key__ in data[0].tags:
dataset = PaddleClassDataset(data, preprocess_fn=preprocess_fn)
else:
dataset = PaddleSessionDataset(data, preprocess_fn=preprocess_fn)
if len(data[0].matches) > 0:
dataset = PaddleSessionDataset(data, preprocess_fn=preprocess_fn)
else:
dataset = InstanceDataset(data, preprocess_fn=preprocess_fn)

batch_sampler = self._get_batch_sampler(
dataset,
7 changes: 5 additions & 2 deletions finetuner/tuner/pytorch/__init__.py
@@ -9,6 +9,7 @@

from ... import __default_tag_key__
from ..base import BaseTuner
from ..dataset.datasets import InstanceDataset
from ..state import TunerState
from . import losses
from .datasets import PytorchClassDataset, PytorchSessionDataset
@@ -62,7 +63,10 @@ def collate_fn_all(inputs):
if __default_tag_key__ in data[0].tags:
dataset = PytorchClassDataset(data, preprocess_fn=preprocess_fn)
else:
dataset = PytorchSessionDataset(data, preprocess_fn=preprocess_fn)
if len(data[0].matches) > 0:
dataset = PytorchSessionDataset(data, preprocess_fn=preprocess_fn)
else:
dataset = InstanceDataset(data, preprocess_fn=preprocess_fn)

batch_sampler = self._get_batch_sampler(
dataset,
@@ -170,7 +174,6 @@ def _fit(
collate_fn=collate_fn,
num_workers=num_workers,
)

# Set state
self.state = TunerState(num_epochs=epochs)
self._trigger_callbacks('on_fit_begin')