fix: use cosine as the rank score (#708)
* fix: use cosine as the rank score

* fix: return cosine or probs as rank score

* fix: add logit scale

* fix: unittest

* fix: use softmax as para name

* fix: return both softmax and cosine

* fix: numpy softmax
numb3r3 committed May 9, 2022
1 parent 706fa62 commit 835eb13
Showing 7 changed files with 135 additions and 78 deletions.
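
In short: the `/rank` endpoint previously returned only softmax probabilities over temperature-scaled logits; after this commit every executor computes the plain cosine similarity first and returns both scores, `clip_score` (softmax) and `clip_score_cosine` (raw cosine). A minimal sketch of the shared math, using illustrative names rather than the executors' actual request flow:

```python
import numpy as np

def rank_scores(image_emb, text_embs, logit_scale=np.exp(4.60517)):
    """Score one image against N candidate texts; returns (cosine, softmax)."""
    # L2-normalize so the dot product equals cosine similarity
    image_emb = image_emb / np.linalg.norm(image_emb)
    text_embs = text_embs / np.linalg.norm(text_embs, axis=1, keepdims=True)
    cosine = text_embs @ image_emb    # shape (N,), comparable across queries
    z = logit_scale * cosine          # temperature-scaled logits
    e = np.exp(z - z.max())           # numerically stable softmax
    return cosine, e / e.sum()        # softmax sums to 1 over the N candidates
```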
56 changes: 31 additions & 25 deletions server/clip_server/executors/clip_onnx.py
@@ -10,7 +10,12 @@

from clip_server.model import clip
from clip_server.model.clip_onnx import CLIPOnnxModel
from clip_server.executors.helper import split_img_txt_da, preproc_image, preproc_text
from clip_server.executors.helper import (
split_img_txt_da,
preproc_image,
preproc_text,
numpy_softmax,
)


class CLIPEncoder(Executor):
@@ -20,7 +25,6 @@ def __init__(
device: Optional[str] = None,
num_worker_preprocess: int = 4,
minibatch_size: int = 16,
logit_scale: float = 4.60,
**kwargs,
):
super().__init__(**kwargs)
@@ -31,7 +35,8 @@ def __init__(
self._minibatch_size = minibatch_size

self._model = CLIPOnnxModel(name)
self._logit_scale = logit_scale
# Note: hard-coded here since all the pretrained CLIP models use the same logit_scale parameter
self._logit_scale = np.exp(4.60517)

import torch

@@ -80,14 +85,15 @@ def __init__(
@requests(on='/rank')
async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
_source = parameters.get('source', 'matches')
_get = lambda d: getattr(d, _source)

for d in docs:
_img_da = DocumentArray()
_txt_da = DocumentArray()
split_img_txt_da(d, _img_da, _txt_da)

for c in _get(d):
candidates = getattr(d, _source)

for c in candidates:
split_img_txt_da(c, _img_da, _txt_da)

if len(_img_da) != 1 and len(_txt_da) != 1:
@@ -98,7 +104,7 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
raise ValueError(
f'`d` and `d.{_source}` must be in different modality, one is image one is text'
)
elif len(_get(d)) <= 1:
elif len(candidates) <= 1:
raise ValueError(
f'`d.{_source}` must have more than one Documents to do ranking'
)
@@ -114,37 +120,37 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
_txt_da.embeddings, axis=1, keepdims=True
)

# cosine similarity as logits
logit_scale = np.exp(self._logit_scale)
logits_per_image = logit_scale * np.matmul(
image_features, text_features.T
)
logits_per_text = logits_per_image.T

def numpy_softmax(z):
s = np.max(z, axis=1)
s = s[:, np.newaxis]
e_x = np.exp(z - s)
div = np.sum(e_x, axis=1)
div = div[:, np.newaxis] # dito
return e_x / div
# paired cosine similarity
scores_per_text = np.matmul(image_features, text_features.T)
scores_per_image = scores_per_text.T

if len(_img_da) == 1:
probs = numpy_softmax(logits_per_image)[0]
cosine_scores = scores_per_text
elif len(_txt_da) == 1:
probs = numpy_softmax(logits_per_text)[0]
cosine_scores = scores_per_image

softmax_scores = numpy_softmax(self._logit_scale * cosine_scores)

# squeeze scores
cosine_scores = cosine_scores[0]
softmax_scores = softmax_scores[0]

# drop embeddings
_img_da.embeddings = None
_txt_da.embeddings = None

for c, v in zip(_get(d), probs):
c.scores['clip_score'].value = v
for c, p, o in zip(candidates, softmax_scores, cosine_scores):
c.scores['clip_score'].value = p
c.scores['clip_score'].op_name = 'softmax'

c.scores['clip_score_cosine'].value = o
c.scores['clip_score_cosine'].op_name = 'cosine'

setattr(
d,
_source,
sorted(
_get(d),
candidates,
key=lambda _m: _m.scores['clip_score'].value,
reverse=True,
),
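
For context on the constant that replaces the old `logit_scale` argument: 4.60517 ≈ ln(100), so `np.exp(4.60517)` recovers the scale of roughly 100 that the released CLIP checkpoints share (OpenAI's training code clamps the learned scale at 100). A quick check:

```python
import numpy as np

print(np.exp(4.60517))  # ≈ 100.0, the multiplier applied before softmax
print(np.log(100.0))    # ≈ 4.60517
```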
47 changes: 26 additions & 21 deletions server/clip_server/executors/clip_torch.py
@@ -3,7 +3,7 @@
from functools import partial

from multiprocessing.pool import ThreadPool
from typing import Optional, List, Tuple, Dict
from typing import Optional, Dict

import numpy as np
import torch
@@ -53,6 +53,7 @@ def __init__(
self._model, self._preprocess_tensor = clip.load(
name, device=self._device, jit=jit
)
self._logit_scale = self._model.logit_scale.exp()

self._pool = ThreadPool(processes=num_worker_preprocess)

@@ -61,14 +62,15 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
import torch

_source = parameters.get('source', 'matches')
_get = lambda d: getattr(d, _source)

for d in docs:
_img_da = DocumentArray()
_txt_da = DocumentArray()
split_img_txt_da(d, _img_da, _txt_da)

for c in _get(d):
candidates = getattr(d, _source)

for c in candidates:
split_img_txt_da(c, _img_da, _txt_da)

if len(_img_da) != 1 and len(_txt_da) != 1:
@@ -79,7 +81,7 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
raise ValueError(
f'`d` and `d.{_source}` must be in different modality, one is image one is text'
)
elif len(_get(d)) <= 1:
elif len(candidates) <= 1:
raise ValueError(
f'`d.{_source}` must have more than one Documents to do ranking'
)
@@ -97,34 +99,37 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
dim=-1, keepdim=True
)

# cosine similarity as logits
logit_scale = self._model.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
# paired cosine between image and text
scores_per_text = image_features @ text_features.t()
scores_per_image = scores_per_text.t()

if len(_img_da) == 1:
probs = (
logits_per_image.softmax(dim=-1)
.cpu()
.detach()
.numpy()
.squeeze()
)
cosine_scores = scores_per_text
elif len(_txt_da) == 1:
probs = (
logits_per_text.softmax(dim=-1).cpu().detach().numpy().squeeze()
)
cosine_scores = scores_per_image

softmax_scores = self._logit_scale * cosine_scores
softmax_scores = softmax_scores.softmax(dim=-1)

# squeeze scores
cosine_scores = cosine_scores.cpu().detach().numpy().squeeze()
softmax_scores = softmax_scores.cpu().detach().numpy().squeeze()

_img_da.embeddings = None
_txt_da.embeddings = None

for c, v in zip(_get(d), probs):
c.scores['clip_score'].value = v
for c, p, o in zip(candidates, softmax_scores, cosine_scores):
c.scores['clip_score'].value = p
c.scores['clip_score'].op_name = 'softmax'

c.scores['clip_score_cosine'].value = o
c.scores['clip_score_cosine'].op_name = 'cosine'

setattr(
d,
_source,
sorted(
_get(d),
candidates,
key=lambda _m: _m.scores['clip_score'].value,
reverse=True,
),
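
One behavioral nuance in the torch path: `logit_scale.exp()` is now computed once in `__init__` and cached, instead of being re-evaluated on every `/rank` call. The scoring itself reduces to the sketch below (illustrative tensors and a hard-coded scale, not the executor's request handling):

```python
import torch

image_features = torch.nn.functional.normalize(torch.randn(1, 512), dim=-1)
text_features = torch.nn.functional.normalize(torch.randn(4, 512), dim=-1)
logit_scale = torch.tensor(4.60517).exp()  # cached once at load time in the diff

cosine = (image_features @ text_features.t()).squeeze(0)  # raw rank scores
softmax = (logit_scale * cosine).softmax(dim=-1)          # sums to 1 over the 4 candidates
```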
69 changes: 37 additions & 32 deletions server/clip_server/executors/clip_trt.py
@@ -6,7 +6,12 @@

from clip_server.model import clip
from clip_server.model.clip_trt import CLIPTensorRTModel
from clip_server.executors.helper import split_img_txt_da, preproc_image, preproc_text
from clip_server.executors.helper import (
split_img_txt_da,
preproc_image,
preproc_text,
numpy_softmax,
)


class CLIPEncoder(Executor):
@@ -16,7 +21,6 @@ def __init__(
device: str = 'cuda',
num_worker_preprocess: int = 4,
minibatch_size: int = 64,
logit_scale: float = 4.60,
**kwargs,
):
super().__init__(**kwargs)
@@ -27,8 +31,6 @@ def __init__(
self._minibatch_size = minibatch_size
self._device = device

self._logit_scale = logit_scale

import torch

assert self._device.startswith('cuda'), (
@@ -44,19 +46,21 @@ def __init__(

self._model.start_engines()

# Note: hard-coded here since all the pretrained CLIP models use the same logit_scale parameter
self._logit_scale = np.exp(4.60517)

@requests(on='/rank')
async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
import torch

_source = parameters.get('source', 'matches')
_get = lambda d: getattr(d, _source)

for d in docs:
_img_da = DocumentArray()
_txt_da = DocumentArray()
split_img_txt_da(d, _img_da, _txt_da)

for c in _get(d):
candidates = getattr(d, _source)

for c in candidates:
split_img_txt_da(c, _img_da, _txt_da)

if len(_img_da) != 1 and len(_txt_da) != 1:
@@ -67,52 +71,53 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
raise ValueError(
f'`d` and `d.{_source}` must be in different modality, one is image one is text'
)
elif len(_get(d)) <= 1:
elif len(candidates) <= 1:
raise ValueError(
f'`d.{_source}` must have more than one Documents to do ranking'
)
else:
_img_da = await self.encode(_img_da)
_txt_da = await self.encode(_txt_da)
_img_da.embeddings = torch.from_numpy(_img_da.embeddings)
_txt_da.embeddings = torch.from_numpy(_txt_da.embeddings)

# normalized features
image_features = _img_da.embeddings / _img_da.embeddings.norm(
dim=-1, keepdim=True
image_features = _img_da.embeddings / np.linalg.norm(
_img_da.embeddings, axis=1, keepdims=True
)
text_features = _txt_da.embeddings / _txt_da.embeddings.norm(
dim=-1, keepdim=True
text_features = _txt_da.embeddings / np.linalg.norm(
_txt_da.embeddings, axis=1, keepdims=True
)

# cosine similarity as logits
logit_scale = np.exp(self._logit_scale)
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
# cosine similarity as rank score
scores_per_text = np.matmul(image_features, text_features.T)
scores_per_image = scores_per_text.T

if len(_img_da) == 1:
probs = (
logits_per_image.softmax(dim=-1)
.cpu()
.detach()
.numpy()
.squeeze()
)
cosine_scores = scores_per_text
elif len(_txt_da) == 1:
probs = (
logits_per_text.softmax(dim=-1).cpu().detach().numpy().squeeze()
)
cosine_scores = scores_per_image

softmax_scores = numpy_softmax(self._logit_scale * cosine_scores)

# squeeze scores
softmax_scores = softmax_scores[0]
cosine_scores = cosine_scores[0]

# drop embeddings
_img_da.embeddings = None
_txt_da.embeddings = None

for c, v in zip(_get(d), probs):
c.scores['clip_score'].value = v
for c, p, o in zip(candidates, softmax_scores, cosine_scores):
c.scores['clip_score'].value = p
c.scores['clip_score'].op_name = 'softmax'

c.scores['clip_score_cosine'].value = o
c.scores['clip_score_cosine'].op_name = 'cosine'

setattr(
d,
_source,
sorted(
_get(d),
candidates,
key=lambda _m: _m.scores['clip_score'].value,
reverse=True,
),
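
With all three executors aligned, downstream code can pick whichever score fits: the softmax value is normalized within one query's candidates, while the raw cosine is comparable across queries. A reading sketch, assuming a DocumentArray `docs` that has been through `/rank` with the default `matches` source:

```python
for match in docs[0].matches:
    print(
        match.scores['clip_score'].value,         # softmax probability, op_name 'softmax'
        match.scores['clip_score_cosine'].value,  # raw cosine, op_name 'cosine'
    )
```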
8 changes: 8 additions & 0 deletions server/clip_server/executors/helper.py
@@ -6,6 +6,14 @@
from docarray import Document, DocumentArray


def numpy_softmax(x: 'np.ndarray', axis: int = -1) -> 'np.ndarray':
max = np.max(x, axis=axis, keepdims=True)
e_x = np.exp(x - max)
div = np.sum(e_x, axis=axis, keepdims=True)
f_x = e_x / div
return f_x


def preproc_image(
da: 'DocumentArray',
preprocess_fn: Callable,
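
The shared helper subtracts the per-axis maximum before exponentiating, the standard numerically stable softmax: shifting the input leaves the result unchanged while keeping `exp` finite. A quick demonstration with illustrative values:

```python
import numpy as np
from clip_server.executors.helper import numpy_softmax

big = np.array([[1000.0, 1001.0, 1002.0]])
print(numpy_softmax(big))                # finite: [[0.0900 0.2447 0.6652]]
print(np.exp(big) / np.exp(big).sum())   # naive form overflows to nan
```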
19 changes: 19 additions & 0 deletions tests/test_helper.py
@@ -0,0 +1,19 @@
import pytest
import numpy as np
from clip_server.executors.helper import numpy_softmax


@pytest.mark.parametrize('shape', [(5, 10), (5, 10, 10)])
@pytest.mark.parametrize('axis', [-1, 1, 0])
def test_numpy_softmax(shape, axis):
import torch

logits = np.random.random(shape)

np_softmax = numpy_softmax(logits, axis=axis)
torch_softmax = torch.from_numpy(logits).softmax(dim=axis).numpy()
np.testing.assert_array_almost_equal(np_softmax, torch_softmax)

np_softmax = numpy_softmax(logits, axis=axis)
torch_softmax = torch.from_numpy(logits).softmax(dim=axis).numpy()
np.testing.assert_array_almost_equal(np_softmax, torch_softmax)
