Skip to content

Commit

Permalink
fix(torch): fix oom in rerank endpoint (#699)
Browse files — browse the repository at this point in the history
  • Loading branch information
hanxiao committed Apr 30, 2022
1 parent dd50816 commit 8ac2e9b
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 16 deletions.
2 changes: 1 addition & 1 deletion server/clip_server/executors/clip_onnx.py
Expand Up @@ -29,7 +29,7 @@ def __init__(
name: str = 'ViT-B/32',
device: Optional[str] = None,
num_worker_preprocess: int = 4,
minibatch_size: int = 64,
minibatch_size: int = 16,
**kwargs,
):
super().__init__(**kwargs)
Expand Down
40 changes: 25 additions & 15 deletions server/clip_server/executors/clip_torch.py
Expand Up @@ -4,9 +4,9 @@
from typing import Optional, List, Tuple, Dict

import numpy as np
from jina import Executor, requests, DocumentArray, Document

import torch
from clip_server.model import clip
from jina import Executor, requests, DocumentArray


class CLIPEncoder(Executor):
Expand All @@ -16,13 +16,11 @@ def __init__(
device: Optional[str] = None,
jit: bool = False,
num_worker_preprocess: int = 4,
minibatch_size: int = 64,
minibatch_size: int = 16,
**kwargs,
):
super().__init__(**kwargs)

import torch

if not device:
self._device = 'cuda' if torch.cuda.is_available() else 'cpu'
else:
Expand Down Expand Up @@ -83,6 +81,8 @@ def _split_img_txt_da(d, _img_da, _txt_da):

@requests(on='/rerank')
async def rerank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
import torch

_source = parameters.get('source', 'matches')
_get = lambda d: getattr(d, _source)

Expand All @@ -107,12 +107,24 @@ async def rerank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
f'`d.{_source}` must have more than one Documents to do ranking'
)
else:
_img_da = self._preproc_image(_img_da)
_txt_da, texts = self._preproc_text(_txt_da)

logits_per_image, logits_per_text = self._model(
_img_da.tensors, _txt_da.tensors
_img_da = await self.encode(_img_da)
_txt_da = await self.encode(_txt_da)
_img_da.embeddings = torch.from_numpy(_img_da.embeddings)
_txt_da.embeddings = torch.from_numpy(_txt_da.embeddings)

# normalized features
image_features = _img_da.embeddings / _img_da.embeddings.norm(
dim=-1, keepdim=True
)
text_features = _txt_da.embeddings / _txt_da.embeddings.norm(
dim=-1, keepdim=True
)

# cosine similarity as logits
logit_scale = self._model.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()

probs_image = (
logits_per_image.softmax(dim=-1).cpu().detach().numpy().squeeze()
)
Expand All @@ -124,9 +136,8 @@ async def rerank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
elif len(_txt_da) == 1:
probs = probs_text

_txt_da.texts = texts
_img_da.tensors = None
_txt_da.tensors = None
_img_da.embeddings = None
_txt_da.embeddings = None

for c, v in zip(_get(d), probs):
c.scores['clip_score'].value = v
Expand All @@ -147,8 +158,6 @@ async def encode(self, docs: 'DocumentArray', **kwargs):
for d in docs:
self._split_img_txt_da(d, _img_da, _txt_da)

import torch

with torch.inference_mode():
# for image
if _img_da:
Expand Down Expand Up @@ -181,3 +190,4 @@ async def encode(self, docs: 'DocumentArray', **kwargs):

# drop tensors
docs.tensors = None
return docs

0 comments on commit 8ac2e9b

Please sign in to comment.