From 8ac2e9bb68b96d1421f7e2ae6b01cec95aad3183 Mon Sep 17 00:00:00 2001
From: Han Xiao
Date: Sat, 30 Apr 2022 17:16:17 +0200
Subject: [PATCH] fix(torch): fix oom in rerank endpoint (#699)

---
 server/clip_server/executors/clip_onnx.py  |  2 +-
 server/clip_server/executors/clip_torch.py | 40 ++++++++++++++--------
 2 files changed, 26 insertions(+), 16 deletions(-)

diff --git a/server/clip_server/executors/clip_onnx.py b/server/clip_server/executors/clip_onnx.py
index 172cc8265..ab9984cc2 100644
--- a/server/clip_server/executors/clip_onnx.py
+++ b/server/clip_server/executors/clip_onnx.py
@@ -29,7 +29,7 @@ def __init__(
         name: str = 'ViT-B/32',
         device: Optional[str] = None,
         num_worker_preprocess: int = 4,
-        minibatch_size: int = 64,
+        minibatch_size: int = 16,
         **kwargs,
     ):
         super().__init__(**kwargs)
diff --git a/server/clip_server/executors/clip_torch.py b/server/clip_server/executors/clip_torch.py
index f373d6f03..e09eac953 100644
--- a/server/clip_server/executors/clip_torch.py
+++ b/server/clip_server/executors/clip_torch.py
@@ -4,9 +4,9 @@
 from typing import Optional, List, Tuple, Dict
 
 import numpy as np
-from jina import Executor, requests, DocumentArray, Document
-
+import torch
 from clip_server.model import clip
+from jina import Executor, requests, DocumentArray
 
 
 class CLIPEncoder(Executor):
@@ -16,13 +16,11 @@ def __init__(
         device: Optional[str] = None,
         jit: bool = False,
         num_worker_preprocess: int = 4,
-        minibatch_size: int = 64,
+        minibatch_size: int = 16,
         **kwargs,
     ):
         super().__init__(**kwargs)
 
-        import torch
-
         if not device:
             self._device = 'cuda' if torch.cuda.is_available() else 'cpu'
         else:
@@ -83,6 +81,8 @@ def _split_img_txt_da(d, _img_da, _txt_da):
 
     @requests(on='/rerank')
     async def rerank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
+        import torch
+
         _source = parameters.get('source', 'matches')
         _get = lambda d: getattr(d, _source)
 
@@ -107,12 +107,24 @@ async def rerank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
                     f'`d.{_source}` must have more than one Documents to do ranking'
                 )
             else:
-                _img_da = self._preproc_image(_img_da)
-                _txt_da, texts = self._preproc_text(_txt_da)
-
-                logits_per_image, logits_per_text = self._model(
-                    _img_da.tensors, _txt_da.tensors
+                _img_da = await self.encode(_img_da)
+                _txt_da = await self.encode(_txt_da)
+                _img_da.embeddings = torch.from_numpy(_img_da.embeddings)
+                _txt_da.embeddings = torch.from_numpy(_txt_da.embeddings)
+
+                # normalized features
+                image_features = _img_da.embeddings / _img_da.embeddings.norm(
+                    dim=-1, keepdim=True
                 )
+                text_features = _txt_da.embeddings / _txt_da.embeddings.norm(
+                    dim=-1, keepdim=True
+                )
+
+                # cosine similarity as logits
+                logit_scale = self._model.logit_scale.exp()
+                logits_per_image = logit_scale * image_features @ text_features.t()
+                logits_per_text = logits_per_image.t()
+
                 probs_image = (
                     logits_per_image.softmax(dim=-1).cpu().detach().numpy().squeeze()
                 )
@@ -124,9 +136,8 @@ async def rerank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
                 elif len(_txt_da) == 1:
                     probs = probs_text
 
-                _txt_da.texts = texts
-                _img_da.tensors = None
-                _txt_da.tensors = None
+                _img_da.embeddings = None
+                _txt_da.embeddings = None
 
                 for c, v in zip(_get(d), probs):
                     c.scores['clip_score'].value = v
@@ -147,8 +158,6 @@ async def encode(self, docs: 'DocumentArray', **kwargs):
         for d in docs:
             self._split_img_txt_da(d, _img_da, _txt_da)
 
-        import torch
-
         with torch.inference_mode():
             # for image
             if _img_da:
@@ -181,3 +190,4 @@ async def encode(self, docs: 'DocumentArray', **kwargs):
 
         # drop tensors
         docs.tensors = None
+        return docs
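
Editor's note (not part of the commit): the fix replaces the single large forward
pass that `/rerank` used to run over raw image/text tensors with the minibatched
`encode` path (now defaulting to `minibatch_size=16`), then recomputes the CLIP
logits from the returned embeddings. Below is a minimal standalone sketch of that
post-processing step; the helper name `clip_rerank_probs` and the
`logit_scale=100.0` default (an assumed approximation of
`self._model.logit_scale.exp()` in released CLIP checkpoints) are illustrative
and do not appear in the patch:

    import numpy as np
    import torch

    def clip_rerank_probs(
        image_embs: np.ndarray,  # (n_images, dim), e.g. from the /encode endpoint
        text_embs: np.ndarray,   # (n_texts, dim), e.g. from the /encode endpoint
        logit_scale: float = 100.0,  # assumed value of logit_scale.exp()
    ):
        """Derive softmax 'clip_score' probabilities from precomputed embeddings."""
        img = torch.from_numpy(image_embs)
        txt = torch.from_numpy(text_embs)

        # normalize so the dot product below is a cosine similarity
        img = img / img.norm(dim=-1, keepdim=True)
        txt = txt / txt.norm(dim=-1, keepdim=True)

        # cosine similarity as logits, mirroring the patched rerank endpoint
        logits_per_image = logit_scale * img @ txt.t()
        logits_per_text = logits_per_image.t()

        probs_image = logits_per_image.softmax(dim=-1).numpy().squeeze()
        probs_text = logits_per_text.softmax(dim=-1).numpy().squeeze()
        return probs_image, probs_text

Because the embeddings come out of `encode`, peak memory is bounded by one
minibatch of 16 inputs rather than the whole candidate set, which is what
resolves the OOM in the rerank endpoint.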