feat: gpu support for eval
maximilianwerk committed Oct 15, 2021
1 parent 02f91d8 commit 25edf29
Showing 5 changed files with 28 additions and 9 deletions.
2 changes: 1 addition & 1 deletion finetuner/tuner/__init__.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Dict

from ..helper import AnyDNN, DocumentArrayLike, TunerReturnType, get_tuner_class

14 changes: 6 additions & 8 deletions finetuner/tuner/base.py
@@ -172,7 +172,7 @@ def _eval(
"""Evaluate the model on given labeled data"""
...

-def log_evaluation(self, docs, label):
+def log_evaluation(self, docs: DocumentArrayLike, label: str):
    if self.logger.logger.isEnabledFor(logging.DEBUG):
        if label not in self._catalogs:
            self._catalogs[label] = evaluation.extract_catalog(docs)
@@ -183,20 +183,18 @@ def log_evaluation(self, docs, label):
self.logger.debug(f'{label} {name}: {value}')

def _get_evaluation(self, docs, catalog):
-   self._calc_embeddings(docs)
-   self._calc_embeddings(catalog)
+   self.get_embeddings(docs)
+   self.get_embeddings(catalog)
    catalog.prune()
    to_be_scored_docs = evaluation.prepare_eval_docs(docs, catalog, limit=10)
    return {
        'hits': evaluation.get_hits_at_n(to_be_scored_docs),
        'ndcg': evaluation.get_ndcg_at_n(to_be_scored_docs),
    }

-def _calc_embeddings(self, docs):
-   blobs = docs.blobs
-   embeddings = self.embed_model(blobs)
-   for doc, embed in zip(docs, embeddings):
-       doc.embedding = np.array(embed)
+@abc.abstractmethod
+def get_embeddings(self, docs: DocumentArrayLike):
+   """Calculates and adds the embeddings for the given Documents."""


class BaseDataset:
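These hunks are the heart of the commit: embedding computation moves out of the framework-agnostic base class (the removed _calc_embeddings called the model on raw blobs, with no way to influence tensor placement) and behind an abstract get_embeddings that each framework-specific Tuner implements, so every backend can build its inputs in its own tensor type, on its own device. A self-contained sketch of the resulting contract; the Fake*/Sketch* names are hypothetical stand-ins, not finetuner API:

import abc

import numpy as np


class FakeDoc:
    """Hypothetical stand-in for a jina Document."""
    def __init__(self, blob):
        self.blob = blob
        self.embedding = None


class FakeDocs(list):
    """Hypothetical stand-in for a DocumentArray."""
    @property
    def blobs(self):
        return np.stack([d.blob for d in self])


class SketchTuner(abc.ABC):
    """Illustrative stand-in for the refactored BaseTuner."""
    def __init__(self, embed_model):
        self.embed_model = embed_model

    @abc.abstractmethod
    def get_embeddings(self, docs):
        """Calculates and adds the embeddings for the given Documents."""


class NumpyTuner(SketchTuner):
    def get_embeddings(self, docs):
        # The framework-specific step: turn blobs into model inputs, then
        # write the results back onto each Document in place.
        for doc, embed in zip(docs, self.embed_model(docs.blobs)):
            doc.embedding = np.array(embed)


docs = FakeDocs(FakeDoc(np.random.rand(16).astype('float32')) for _ in range(4))
NumpyTuner(lambda blobs: blobs @ np.ones((16, 8))).get_embeddings(docs)
assert all(doc.embedding.shape == (8,) for doc in docs)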
7 changes: 7 additions & 0 deletions finetuner/tuner/keras/__init__.py
@@ -1,5 +1,6 @@
from typing import Dict, Optional

+import numpy as np
import tensorflow as tf
from jina.logging.profile import ProgressBar
from tensorflow import keras
@@ -188,5 +189,11 @@ def fit(
'metric': {'train': metrics_train, 'eval': metrics_eval},
}

+def get_embeddings(self, data: DocumentArrayLike):
+   blobs = data.blobs
+   embeddings = self.embed_model(blobs)
+   for doc, embed in zip(data, embeddings):
+       doc.embedding = np.array(embed)

def save(self, *args, **kwargs):
    self.embed_model.save(*args, **kwargs)
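The Keras implementation can pass the raw numpy blobs straight to the model because TensorFlow converts numpy inputs to tensors itself and places the computation on a visible GPU automatically. A standalone check of that behavior (not part of the diff):

import numpy as np
import tensorflow as tf

# TensorFlow auto-converts numpy inputs and runs on a GPU when one is
# visible, which is why this get_embeddings needs no explicit tensor wrapping.
model = tf.keras.Sequential([tf.keras.layers.Dense(8)])
embeddings = model(np.random.rand(4, 16).astype('float32'))
print(embeddings.device)  # e.g. ends in 'GPU:0' on a GPU machine, 'CPU:0' otherwise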
7 changes: 7 additions & 0 deletions finetuner/tuner/paddle/__init__.py
@@ -1,5 +1,6 @@
from typing import Dict, Optional

+import numpy as np
import paddle
from jina.logging.profile import ProgressBar
from paddle import nn
@@ -176,5 +177,11 @@ def fit(
'metric': {'train': metrics_train, 'eval': metrics_eval},
}

+def get_embeddings(self, data: DocumentArrayLike):
+   blobs = data.blobs
+   embeddings = self.embed_model(paddle.Tensor(blobs))
+   for doc, embed in zip(data, embeddings):
+       doc.embedding = np.array(embed)

def save(self, *args, **kwargs):
    paddle.save(self.embed_model.state_dict(), *args, **kwargs)
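Paddle, by contrast, resolves placement from a process-wide default device rather than a per-tensor argument. A standalone sketch (not part of the diff) of how the GPU would be selected for the forward pass above; note it uses paddle.to_tensor, the idiomatic constructor, where the hunk uses paddle.Tensor:

import numpy as np
import paddle

# Switching evaluation to the GPU is a one-time global setting in Paddle.
paddle.set_device('gpu' if paddle.is_compiled_with_cuda() else 'cpu')

layer = paddle.nn.Linear(16, 8)
embeddings = layer(paddle.to_tensor(np.random.rand(4, 16).astype('float32')))
print(embeddings.place)  # e.g. Place(gpu:0) with a CUDA build, Place(cpu) otherwise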
7 changes: 7 additions & 0 deletions finetuner/tuner/pytorch/__init__.py
@@ -1,5 +1,6 @@
from typing import Dict, Optional

+import numpy as np
import torch
import torch.nn as nn
from jina.logging.profile import ProgressBar
@@ -189,5 +190,11 @@ def fit(
'metric': {'train': metrics_train, 'eval': metrics_eval},
}

+def get_embeddings(self, data: DocumentArrayLike):
+   blobs = data.blobs
+   embeddings = self.embed_model(torch.Tensor(blobs))
+   for doc, embed in zip(data, embeddings):
+       doc.embedding = np.array(embed)

def save(self, *args, **kwargs):
    torch.save(self.embed_model.state_dict(), *args, **kwargs)
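One subtlety in the PyTorch path: torch.Tensor(blobs) always allocates on the CPU, and np.array() cannot read a CUDA-resident tensor, so a GPU round trip needs explicit transfers in both directions. A standalone sketch of that round trip under those assumptions (the hunk above shows only the CPU form; where finetuner wires in the device is not visible in this diff):

import numpy as np
import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = nn.Linear(16, 8).to(device)

blobs = np.random.rand(4, 16).astype('float32')
with torch.no_grad():
    # Build the input on the model's device, bring the output back to CPU
    # before converting -- np.array() raises on a CUDA-resident tensor.
    embeddings = model(torch.as_tensor(blobs, device=device)).cpu()
for row in embeddings:
    print(np.array(row))  # mirrors doc.embedding = np.array(embed) above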
