feat: refactor head layers #130

Merged 23 commits on Oct 17, 2021
Changes from 14 commits
8 changes: 4 additions & 4 deletions finetuner/__init__.py
@@ -20,7 +20,7 @@ def fit(
eval_data: Optional['DocumentArrayLike'] = None,
epochs: int = 10,
batch_size: int = 256,
head_layer: str = 'CosineLayer',
loss: str = 'CosineSiameseLoss',
learning_rate: float = 1e-3,
optimizer: str = 'adam',
optimizer_kwargs: Optional[Dict] = None,
@@ -37,7 +37,7 @@ def fit(
eval_data: Optional['DocumentArrayLike'] = None,
epochs: int = 10,
batch_size: int = 256,
head_layer: str = 'CosineLayer',
loss: str = 'CosineSiameseLoss',
learning_rate: float = 1e-3,
optimizer: str = 'adam',
optimizer_kwargs: Optional[Dict] = None,
@@ -61,7 +61,7 @@ def fit(
clear_labels_on_start: bool = False,
port_expose: Optional[int] = None,
runtime_backend: str = 'thread',
head_layer: str = 'CosineLayer',
loss: str = 'CosineSiameseLoss',
learning_rate: float = 1e-3,
optimizer: str = 'adam',
optimizer_kwargs: Optional[Dict] = None,
@@ -79,7 +79,7 @@ def fit(
clear_labels_on_start: bool = False,
port_expose: Optional[int] = None,
runtime_backend: str = 'thread',
head_layer: str = 'CosineLayer',
loss: str = 'CosineSiameseLoss',
learning_rate: float = 1e-3,
optimizer: str = 'adam',
optimizer_kwargs: Optional[Dict] = None,
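The user-facing change in finetuner/__init__.py is the renamed keyword: every fit overload now accepts loss='CosineSiameseLoss' where it previously accepted head_layer='CosineLayer'. A minimal calling sketch of the new signature follows; embed_model and train_docs are hypothetical placeholders, and the leading positional arguments are assumed from the collapsed part of the diff:

import finetuner

# embed_model: any supported DNN; train_docs: a DocumentArrayLike.
# Both are placeholders here -- only the renamed `loss` keyword is
# taken from this change (previously head_layer='CosineLayer').
result = finetuner.fit(
    embed_model,
    train_docs,
    epochs=10,
    batch_size=256,
    loss='CosineSiameseLoss',
)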
4 changes: 2 additions & 2 deletions finetuner/labeler/__init__.py
@@ -17,7 +17,7 @@ def fit(
clear_labels_on_start: bool = False,
port_expose: Optional[int] = None,
runtime_backend: str = 'thread',
head_layer: str = 'CosineLayer',
loss: str = 'CosineSiameseLoss',
**kwargs,
) -> None:
dam_path = tempfile.mkdtemp()
@@ -44,7 +44,7 @@ def get_embed_model(self):
uses=MyExecutor,
uses_with={
'dam_path': dam_path,
'head_layer': head_layer,
'loss': loss,
},
)
)
6 changes: 3 additions & 3 deletions finetuner/labeler/executor.py
@@ -13,13 +13,13 @@ def __init__(
self,
dam_path: str,
metric: str = 'cosine',
head_layer: str = 'CosineLayer',
loss: str = 'CosineSiameseLoss',
**kwargs,
):
super().__init__(**kwargs)
self._all_data = DocumentArrayMemmap(dam_path)
self._metric = metric
self._head_layer = head_layer
self._loss = loss

@abc.abstractmethod
def get_embed_model(self):
@@ -77,7 +77,7 @@ def fit(self, docs, parameters: Dict, **kwargs):
self._embed_model,
docs,
epochs=int(parameters.get('epochs', 10)),
head_layer=self._head_layer,
loss=self._loss,
)

@requests(on='/save')
4 changes: 2 additions & 2 deletions finetuner/tuner/__init__.py
@@ -29,7 +29,7 @@ def fit(
eval_data: Optional[DocumentArrayLike] = None,
epochs: int = 10,
batch_size: int = 256,
head_layer: str = 'CosineLayer',
loss: str = 'CosineSiameseLoss',
learning_rate: float = 1e-3,
optimizer: str = 'adam',
optimizer_kwargs: Optional[Dict] = None,
@@ -38,7 +38,7 @@
) -> TunerReturnType:
ft = get_tuner_class(embed_model)

return ft(embed_model, head_layer=head_layer).fit(
return ft(embed_model, loss=loss).fit(
train_data,
eval_data,
epochs=epochs,
65 changes: 9 additions & 56 deletions finetuner/tuner/base.py
@@ -12,40 +12,19 @@
from ..helper import AnyDNN, AnyDataLoader, AnyOptimizer, DocumentArrayLike


class BaseHead:
class BaseLoss:
arity: int

def __init__(self, arity_model: Optional[AnyDNN] = None):
super().__init__()
self._arity_model = arity_model

def forward(self, *inputs):
if self._arity_model:
inputs = self._arity_model(*inputs)
return self.get_output(*inputs)

@abc.abstractmethod
def get_output(self, *inputs):
...

@abc.abstractmethod
def loss_fn(self, pred_val, target_val):
...

@abc.abstractmethod
def metric_fn(self, pred_val, target_val):
...


class BaseTuner(abc.ABC):
def __init__(
self,
embed_model: Optional[AnyDNN] = None,
head_layer: Union[AnyDNN, str, None] = None,
loss: Union[AnyDNN, str, None] = None,
**kwargs,
):
self._embed_model = embed_model
self._head_layer = head_layer
self._loss = self._get_loss(loss)
self.logger = JinaLogger(self.__class__.__name__)

def _get_optimizer_kwargs(self, optimizer: str, custom_kwargs: Optional[Dict]):
@@ -89,16 +68,6 @@ def embed_model(self) -> AnyDNN:
"""Get the base model of this object."""
return self._embed_model

@property
@abc.abstractmethod
def wrapped_model(self) -> AnyDNN:
"""Get the wrapped model of this object.

A wrapped model is an :py:attr:`.embed_model` replicated by :py:attr:`.arity` times
with a ``head_layer`` that fuses all.
"""
...

@property
def arity(self) -> int:
"""Get the arity of this object.
@@ -107,13 +76,7 @@ def arity(self) -> int:
- ``arity = 2`` corresponds to the siamese network;
- ``arity = 3`` corresponds to the triplet network.
"""
return self.head_layer.arity

@property
@abc.abstractmethod
def head_layer(self) -> AnyDNN:
"""Get the head layer of this object."""
...
return self._loss.arity

@abc.abstractmethod
def _get_optimizer(
@@ -140,11 +103,12 @@ def fit(

@abc.abstractmethod
def save(self, *args, **kwargs):
"""Save the weights of the :py:attr:`.embed_model`.
"""Save the weights of the :py:attr:`.embed_model`."""
...

Note that, the :py:attr:`.head_layer` and :py:attr:`.wrapped_model` do not need to be stored,
as they are auxiliary layers for tuning :py:attr:`.embed_model`.
"""
@abc.abstractmethod
def _get_loss(self, loss: Union[str, AnyDNN, None]) -> BaseLoss:
"""Get the loss layer."""
...

@abc.abstractmethod
@@ -176,14 +140,3 @@ def __init__(
):
super().__init__()
self._inputs = inputs() if callable(inputs) else inputs


class BaseArityModel:
"""The helper class to copy the network for multi-inputs."""

def __init__(self, embed_model: AnyDNN):
super().__init__()
self._embed_model = embed_model

def forward(self, *args):
return tuple(self._embed_model(a) for a in args)
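BaseHead (with its get_output/loss_fn/metric_fn trio and the BaseArityModel replica wrapper) is collapsed into a single BaseLoss contract: a class exposing an arity attribute and a call that maps the per-input embeddings plus a label batch to a loss value, applied directly to the output of embed_model. A framework-agnostic sketch of that contract, using numpy purely for illustration; this class is not part of the PR, whose real losses live in the framework-specific losses modules:

import numpy as np

class CosineSiameseLossSketch:
    # Illustrative only: an `arity` attribute plus a call taking the
    # list of embeddings produced by embed_model and the label batch.
    arity = 2  # siamese pair, matching BaseTuner.arity

    def __call__(self, embeddings, target):
        left, right = embeddings  # two (N, D) arrays
        cos = np.sum(left * right, axis=1) / (
            np.linalg.norm(left, axis=1) * np.linalg.norm(right, axis=1)
        )
        # squared error between cosine similarity and the target value
        return float(np.mean((cos - target) ** 2))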
8 changes: 2 additions & 6 deletions finetuner/tuner/logger.py
@@ -2,9 +2,8 @@


class LogGenerator:
def __init__(self, name, losses, metrics, prefix: str = ''):
def __init__(self, name, losses, prefix: str = ''):
self._losses = losses
self._metrics = metrics
self._prefix = prefix
self._name = name

@@ -16,14 +15,11 @@ def __call__(self):
return f'{prefix}{self._name}: {self.get_statistic()}'

def get_statistic(self):
return f'L={self.mean_loss():>8} A={self.mean_metric():>4}'
return f'Loss={self.mean_loss():>8}'

def mean_loss(self):
return LogGenerator.get_log_value(self._losses)

def mean_metric(self):
return LogGenerator.get_log_value(self._metrics)

@staticmethod
def get_log_value(data):
mean = np.mean(data)
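With the metric column gone, LogGenerator aggregates only losses. A quick usage sketch; the exact numeric formatting comes from get_log_value, which is truncated in this hunk:

from finetuner.tuner.logger import LogGenerator

log = LogGenerator('T', [0.42, 0.38, 0.40])
print(log())  # something like "T: Loss=    0.40" -- no "A=..." metric column anymore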
73 changes: 22 additions & 51 deletions finetuner/tuner/paddle/__init__.py
@@ -1,40 +1,24 @@
from typing import Dict, Optional
from typing import Dict, Optional, Union

import paddle
from jina.logging.profile import ProgressBar
from paddle import nn
from paddle.io import DataLoader
from paddle.optimizer import Optimizer

from . import head_layers, datasets
from ..base import BaseTuner, BaseHead, BaseArityModel
from . import losses, datasets
from ..base import BaseTuner
from ...helper import DocumentArrayLike
from ..dataset.helper import get_dataset
from ..logger import LogGenerator


class _ArityModel(BaseArityModel, nn.Layer):
...


class PaddleTuner(BaseTuner):
@property
def head_layer(self) -> BaseHead:
if isinstance(self._head_layer, str):
return getattr(head_layers, self._head_layer)
elif isinstance(self._head_layer, nn.Layer):
return self._head_layer

@property
def wrapped_model(self) -> nn.Layer:
if self.embed_model is None:
raise ValueError('embed_model is not set')

if getattr(self, '_wrapped_model', None) is not None:
return self._wrapped_model

self._wrapped_model = self.head_layer(_ArityModel(self.embed_model))
return self._wrapped_model
def _get_loss(self, loss: Union[nn.Layer, str, None] = None):
if isinstance(loss, str):
return getattr(losses, loss)()
elif isinstance(loss, nn.Layer):
return loss

def _get_data_loader(self, inputs, batch_size: int, shuffle: bool):
ds = get_dataset(datasets, self.arity)
@@ -47,7 +31,7 @@ def _get_data_loader(self, inputs, batch_size: int, shuffle: bool):
def _get_optimizer(
self, optimizer: str, optimizer_kwargs: Optional[dict], learning_rate: float
) -> Optimizer:
params = self.wrapped_model.parameters()
params = self._embed_model.parameters()
optimizer_kwargs = self._get_optimizer_kwargs(optimizer, optimizer_kwargs)

if optimizer == 'adam':
@@ -71,54 +55,48 @@ def _get_optimizer(
)

def _eval(self, data, description: str = 'Evaluating', train_log: str = ''):
self.wrapped_model.eval()
self._embed_model.eval()

losses = []
metrics = []

log_generator = LogGenerator('E', losses, metrics, train_log)
log_generator = LogGenerator('E', losses, train_log)

with ProgressBar(description, message_on_done=log_generator) as p:
for inputs, label in data:
outputs = self.wrapped_model(*inputs)
loss = self.wrapped_model.loss_fn(outputs, label)
metric = self.wrapped_model.metric_fn(outputs, label)
embeddings = [self._embed_model(inpt) for inpt in inputs]
loss = self._loss(embeddings, label)

losses.append(loss.item())
metrics.append(metric.numpy())

p.update(message=log_generator())

return losses, metrics
return losses

def _train(self, data, optimizer: Optimizer, description: str):

self.wrapped_model.train()
self._embed_model.train()

losses = []
metrics = []

log_generator = LogGenerator('T', losses, metrics)
log_generator = LogGenerator('T', losses)

with ProgressBar(
description, message_on_done=log_generator, final_line_feed=False
) as p:
for inputs, label in data:
# forward step
outputs = self.wrapped_model(*inputs)
loss = self.wrapped_model.loss_fn(outputs, label)
metric = self.wrapped_model.metric_fn(outputs, label)
embeddings = [self._embed_model(inpt) for inpt in inputs]
loss = self._loss(embeddings, label)

optimizer.clear_grad()

loss.backward()
optimizer.step()

losses.append(loss.item())
metrics.append(metric.numpy())

p.update(message=log_generator())
return losses, metrics
return losses

def fit(
self,
@@ -143,35 +121,28 @@ def fit(
_optimizer = self._get_optimizer(optimizer, optimizer_kwargs, learning_rate)

losses_train = []
metrics_train = []
losses_eval = []
metrics_eval = []

for epoch in range(epochs):
_data = self._get_data_loader(
inputs=train_data, batch_size=batch_size, shuffle=False
)
lt, mt = self._train(
lt = self._train(
_data,
_optimizer,
description=f'Epoch {epoch + 1}/{epochs}',
)
losses_train.extend(lt)
metrics_train.extend(mt)

if eval_data:
_data = self._get_data_loader(
inputs=eval_data, batch_size=batch_size, shuffle=False
)

le, me = self._eval(_data, train_log=LogGenerator('T', lt, mt)())
le = self._eval(_data, train_log=LogGenerator('T', lt)())
losses_eval.extend(le)
metrics_eval.extend(me)

return {
'loss': {'train': losses_train, 'eval': losses_eval},
'metric': {'train': metrics_train, 'eval': metrics_eval},
}
return {'loss': {'train': losses_train, 'eval': losses_eval}}

def save(self, *args, **kwargs):
paddle.save(self.embed_model.state_dict(), *args, **kwargs)
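PaddleTuner._get_loss accepts either a string, which is looked up in the framework's losses module and instantiated, or an already-constructed nn.Layer, which is used as-is. A hedged sketch of plugging in a custom Paddle loss follows; the class below is hypothetical, and only its calling convention (a list of embeddings plus a label batch, as in _train/_eval above) is taken from this diff:

from paddle import nn
import paddle.nn.functional as F

class MyCustomLoss(nn.Layer):
    # Hypothetical user-defined loss, not part of this PR.
    arity = 2  # two embeddings per example

    def forward(self, embeddings, label):
        left, right = embeddings
        cos = F.cosine_similarity(left, right)
        # label is assumed to be a float tensor of target similarities
        return ((cos - label) ** 2).mean()

# tuner = PaddleTuner(embed_model, loss=MyCustomLoss())       # instance used as-is
# tuner = PaddleTuner(embed_model, loss='CosineSiameseLoss')  # resolved via the losses module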