In [1]:
from elasticsearch import Elasticsearch

In [2]:
es = Elasticsearch([{ 'host': 'elasticsearch' }])

In [3]:
mappings = {
    'properties': {
        'id': { 'type': 'integer' },
        'label': { 'type': 'text' },
        'group': { 'type': 'integer' },
        'feature': { 'type': 'dense_vector', 'dims': 128 }
    }
}

In [4]:
if es.indices.exists(index='matching'):
    es.indices.delete(index='matching')

In [5]:
es.indices.create(index='matching', body={ 'mappings': mappings })

{'acknowledged': True, 'shards_acknowledged': True, 'index': 'matching'}

### データの作成

データ定義にしたがってデータを登録します。  
`feature` を真面目に作っても良いのですが少しめんどいので numpy のランダムなベクトルで代用

In [6]:
import numpy as np

In [7]:
n_data = 100000

In [8]:
features = np.random.uniform(size=(n_data, 128))

一件だけ登録

この時特徴量を list にして python object に変換する必要あり。(たぶん serializer をカスタムすればいちいち list にしなくても済むはずなのであとで調べる)

In [9]:
es.index('matching', body={ 'id': 0, 'label': 'foo', 'feature': list(np.random.uniform(size=(128,)))})

{'_index': 'matching',
 '_type': '_doc',
 '_id': 'XSqdkG0Beps9V_M2vCEp',
 '_version': 1,
 'result': 'created',
 '_shards': {'total': 2, 'successful': 1, 'failed': 0},
 '_seq_no': 0,
 '_primary_term': 1}

まとめてデータを作るときは `elasticsearch.helper.buld` をつかうとよさ気

https://elasticsearch-py.readthedocs.io/en/master/helpers.html#bulk-helpers

iterator を渡すと各要素をドキュメントとして登録してくれるみたい

```python
def gendata():
    mywords = ['foo', 'bar', 'baz']
    for word in mywords:
        yield {
            "_index": "mywords",
            "_type": "document",
            "doc": {"word": word},
        }

bulk(es, gendata())
```

In [10]:
def generate_data(features):
    for i, f in enumerate(features):
        yield {
            '_index': 'matching',
            'id': i + 1,
            'feature': list(f),
            'group': i % 10,
            'label': f'name={i:04d}'
        }

In [11]:
from elasticsearch.helpers import bulk

In [12]:
bulk(es, generate_data(features))

(100000, [])

In [13]:
query_feature = np.random.uniform(size=(128,))

In [14]:
query_feature

array([2.25656104e-02, 4.12628106e-01, 2.88556042e-01, 7.56468340e-01,
       3.12808881e-01, 7.29584599e-01, 9.56357205e-01, 7.29698507e-02,
       4.09252755e-01, 2.69922480e-01, 9.25295814e-01, 7.60822976e-01,
       1.29845987e-01, 1.88688244e-01, 5.01047971e-02, 1.26361104e-01,
       5.39542362e-01, 9.60115683e-05, 8.53420984e-01, 7.94699255e-01,
       2.51482210e-02, 8.36036766e-01, 9.49654285e-01, 8.76074364e-01,
       7.78683647e-01, 9.86483510e-02, 4.49622957e-01, 7.05727890e-01,
       3.93871006e-02, 4.30947003e-01, 3.42438699e-01, 1.49028948e-01,
       2.94677986e-01, 2.84892139e-01, 1.94117328e-01, 4.80122977e-01,
       2.99955798e-01, 8.27019179e-01, 6.05190330e-01, 3.96328418e-01,
       2.62835416e-01, 4.20814684e-01, 4.02046669e-01, 4.50535804e-01,
       7.08852941e-01, 9.84423173e-01, 9.44775558e-01, 4.06187076e-01,
       9.29880828e-01, 8.77629954e-01, 3.34651475e-02, 1.99035568e-01,
       1.57964092e-02, 6.23365657e-01, 6.31790042e-01, 4.41137014e-01,
      

In [15]:
%%time
res = es.search(index='matching', body={
  "query": {
    "script_score": {
      "query": {
        "match_all": {}
      },
      "script": {
        "source": "cosineSimilarity(params.query_vec, doc['feature'])",
        "params": {
          "query_vec": query_feature.tolist()
        }
      }
    }
  }
})

CPU times: user 1.74 ms, sys: 36 µs, total: 1.77 ms
Wall time: 49.7 ms


In [16]:
%%time
res = es.search(index='matching', body={
  "query": {
    "script_score": {
      "query": {
        "match_all": {}
      },
      "script": {
        "source": "cosineSimilarity(params.query_vec, doc['feature'])",
        "params": {
          "query_vec": query_feature.tolist()
        }
      }
    }
  }
})

CPU times: user 1.51 ms, sys: 31 µs, total: 1.54 ms
Wall time: 41.4 ms


In [17]:
top_object = res['hits']['hits'][0]['_source']

In [18]:
import pandas as pd

In [19]:
top_df = pd.DataFrame(res['hits']['hits'])

In [20]:
top_df

Unnamed: 0,_index,_type,_id,_score,_source
0,matching,_doc,qCqdkG0Beps9V_M2wDPd,0.8389,"{'id': 4683, 'feature': [0.022118664913051478,..."
1,matching,_doc,UiqdkG0Beps9V_M2z4Ll,0.836838,"{'id': 24821, 'feature': [0.37636770550248566,..."
2,matching,_doc,FyuekG0Beps9V_M2AYPs,0.833041,"{'id': 90554, 'feature': [0.22601973490458882,..."
3,matching,_doc,aiqdkG0Beps9V_M21qXP,0.83149,"{'id': 33805, 'feature': [0.016277852989449837..."
4,matching,_doc,ASqdkG0Beps9V_M2z4Pl,0.830883,"{'id': 24996, 'feature': [0.16140512873908297,..."
5,matching,_doc,ZiudkG0Beps9V_M28S-b,0.829608,"{'id': 69129, 'feature': [0.5458508761317696, ..."
6,matching,_doc,5yudkG0Beps9V_M26gvZ,0.82946,"{'id': 60042, 'feature': [0.2763118424659108, ..."
7,matching,_doc,DiqdkG0Beps9V_M2276z,0.828975,"{'id': 40113, 'feature': [0.20523668568922793,..."
8,matching,_doc,TSqdkG0Beps9V_M25OkT,0.828411,"{'id': 51184, 'feature': [0.04993116913060569,..."
9,matching,_doc,xyqdkG0Beps9V_M2vy2y,0.828209,"{'id': 3178, 'feature': [0.31926823698365736, ..."


In [21]:
top_object['id'], top_object['label']

(4683, 'name=4682')

答え合わせ

In [22]:
from scipy.spatial.distance import cosine

In [23]:
%%time
cos_sim = [1 - cosine(x, query_feature) for x in features]
idx = np.argmax(cos_sim)
cos_sim[idx]

CPU times: user 3.14 s, sys: 7.92 ms, total: 3.15 s
Wall time: 3.15 s


0.8389003829058972

In [24]:
from joblib import Parallel, delayed

In [25]:
%%time
Parallel(n_jobs=-1)([delayed(cosine)(x, query_feature) for x in features])

CPU times: user 5.06 s, sys: 258 ms, total: 5.32 s
Wall time: 5.39 s


[0.20128730837368602,
 0.2878335010455836,
 0.2725500550443356,
 0.2821604798313322,
 0.22811671091162833,
 0.2430795728098376,
 0.26233453051938194,
 0.2581440695666827,
 0.2540080971499157,
 0.25013176292859873,
 0.236896781621591,
 0.25918948527820473,
 0.2879359039011584,
 0.2775695569322222,
 0.2314597267151529,
 0.23852917644874072,
 0.2588606091140008,
 0.24816265341115717,
 0.22751419751696245,
 0.2339253244419397,
 0.29256902189558665,
 0.2940712313914856,
 0.2780513431842686,
 0.23595917498034524,
 0.2880192398691451,
 0.2575326285109809,
 0.27389078611386963,
 0.2734453424816763,
 0.264194084748371,
 0.24918653522597534,
 0.2820932924254289,
 0.25292086810488634,
 0.28415665486277986,
 0.24226655822135035,
 0.28470408120544377,
 0.256759545997025,
 0.24503060186546288,
 0.22285642564527286,
 0.2460959180793214,
 0.24382858470596358,
 0.27431992395201343,
 0.23143744683198542,
 0.24911580841255587,
 0.2644551092695028,
 0.2061432073438474,
 0.26903641795896205,
 0.22900877070