In [2]:
import re
import math
from gensim.models import KeyedVectors
from typing import Callable, Iterable, List, Set, Dict, Tuple, Optional

# Module-level cache of the loaded vectors. Word2Vec.__init__ stores the loaded
# model here and, while Word2Vec.DEBUG is true, reuses it instead of reloading.
model = None

class Word2Vec:
    """Thin wrapper around word vectors loaded from a word2vec-format file.

    The loaded vectors are also stored in the module-level ``model`` variable
    so that, while ``DEBUG`` is true, later instances reuse them instead of
    reading the file again.
    """

    # When true, __init__ reuses the module-level ``model`` if one is loaded.
    DEBUG = True
    # Assumed fraction of neighbours that are "[...]" entity tokens; used to
    # over-fetch in most_similar() so ``topn`` words remain after filtering.
    EXPECTED_ENTITY_CONTENT_RATE = 0.7
    model_filepath = None
    model = None

    def __init__(self, model_filepath, binary = True):
        """Load vectors from ``model_filepath`` (or reuse the cached ones).

        Args:
            model_filepath: path handed to ``KeyedVectors.load_word2vec_format``.
            binary: forwarded to ``load_word2vec_format``.

        Note: when ``DEBUG`` is true and a model is already cached, the cached
        model is reused regardless of ``model_filepath``; call ``reload()`` to
        force a fresh load.
        """
        global model
        self.model_filepath = model_filepath
        if self.DEBUG and model is not None:
            print("reusing loaded model.")
            self.model = model
            return
        self.model = KeyedVectors.load_word2vec_format(model_filepath, binary=binary)
        model = self.model

    def reload(self, model_filepath=None, binary = True):
        """Load the vectors again, replacing both the instance and cached model.

        Args:
            model_filepath: new path to load; when ``None`` the path stored on
                the instance is loaded again.
            binary: forwarded to ``load_word2vec_format``.
        """
        global model
        if model_filepath is not None:
            self.model_filepath = model_filepath
        # Always load from the stored path so a bare ``reload()`` does not hand
        # ``None`` to the loader.
        self.model = KeyedVectors.load_word2vec_format(self.model_filepath, binary=binary)
        model = self.model

    def _vectors(self):
        """Return the object exposing ``most_similar`` / ``similarity``.

        Uses ``self.model.wv`` when that attribute exists and falls back to
        ``self.model`` itself otherwise, so both layouts work.
        """
        return getattr(self.model, "wv", self.model)

    def most_similar(
        self,
        positive: str or List[str],
        topn: int=30,
    ) -> List[Tuple[str, float]]:
        """Return up to ``topn`` nearest words, with "[...]" tokens removed.

        Args:
            positive: a single word or a list of words. (The annotation
                ``str or List[str]`` evaluates to plain ``str`` at runtime;
                lists are accepted as well.)
            topn: maximum number of results to return.

        Returns:
            ``(word, score)`` pairs. If the lookup raises ``KeyError``, the
            input word(s) themselves are returned, each with score ``1.0``.
        """
        try:
            # Over-fetch: some of the neighbours are dropped by the filter below.
            similars = self._vectors().most_similar(
                positive=positive,
                topn=math.ceil(topn/(1-self.EXPECTED_ENTITY_CONTENT_RATE))
            )
        except KeyError:
            if isinstance(positive, str):
                return [(positive, 1.0)]
            return [(s, 1.0) for s in positive]
        # remove entities.
        return [pair for pair in similars if not pair[0].startswith("[")][0:topn]

    def similarity(
        self,
        w1,
        w2,
    ):
        """Return the similarity of ``w1`` and ``w2``, or ``0.0`` on ``KeyError``."""
        try:
            return self._vectors().similarity(w1, w2)
        except KeyError:
            return 0.0


# Path of the vector file to load; it is read in binary word2vec format below.
w2v_model_filepath = './data/entity_vector.model.bin'
# Shared wrapper instance used by the cells that follow.
word2vec = Word2Vec(w2v_model_filepath, binary=True)

In [41]:
import re
import math
import json
import statistics
from typing import Callable, Iterable, List, Set, Dict, Tuple, Optional

class MeCabTagger:
    """Placeholder tagger type; it currently holds no state and has no behaviour."""

    def __init__(self):
        """Create an empty tagger; nothing is initialised."""

# Scale factor applied to the argument before it is handed to math.erf.
INV_SQRT2 = 1/math.sqrt(2)


def erf(x, negative_boost:float=1.0):
    """Return ``math.erf(x / sqrt(2))``, multiplied by ``negative_boost`` when ``x < 0``.

    Args:
        x: input value.
        negative_boost: factor applied to the result only for negative ``x``.
    """
    value = math.erf(INV_SQRT2*x)
    return negative_boost*value if x < 0 else value

class RelatedWord:
    """Ranks words that are related to a set of search terms.

    Candidates come from ``w2v.most_similar``; each candidate is then scored by
    how its gap to every individual search term compares with the other
    candidates' gaps.
    """

    # Forward-reference strings: the class body does not need these names to be
    # defined at class-creation time.
    w2v: "Word2Vec" = None
    tagger: "MeCabTagger" = None

    def __init__(
        self,
        w2v: "Word2Vec",
        tagger: "MeCabTagger"
    ):
        """Store the similarity backend and the tagger.

        Args:
            w2v: object providing ``most_similar(positive=..., topn=...)`` and
                ``similarity(w1, w2)``.
            tagger: stored on the instance; not used by the methods in this class.
        """
        self.w2v = w2v
        self.tagger = tagger

    def get_related_words(
        self, searches: Iterable[str],
        related_word_topn: int = 100,
        limit: int = 50,
        negative_boost:float = 1.0,
    ) -> List[Tuple[str, float, float, Dict[str, Tuple[float, float]]]]:
        """Return candidate words ranked by their summed, erf-squashed z-scores.

        For every candidate ``(word, score)`` and every search term the "gap"
        ``score - similarity(word, term)`` is computed. Gaps are standardised
        per search term across all candidates, passed through ``erf`` (negative
        values are multiplied by ``negative_boost``) and summed.

        Args:
            searches: search terms. A single ``str`` is treated as one term;
                any other iterable is materialised once, so generators work.
            related_word_topn: number of candidates requested from
                ``w2v.most_similar``.
            limit: maximum number of results; ``<= 0`` returns all of them.
            negative_boost: multiplier forwarded to ``erf`` for negative z-scores.

        Returns:
            A list of ``(word, score, total, details)`` sorted by ``total`` in
            descending order, where ``details`` maps each search term to
            ``(score + gap, z_score)``. An empty list is returned when there
            are no candidates.
        """
        # ``searches`` is walked once per candidate and again for the
        # statistics, so it must be re-iterable.
        if isinstance(searches, str):
            searches = [searches]
        else:
            searches = list(searches)

        candidates = [
            (word, score, [score - self.w2v.similarity(word, search) for search in searches])
            for (word, score) in self.w2v.most_similar(
                positive=searches,
                topn=related_word_topn,
            )
        ]
        if not candidates:
            # statistics.mean() would raise on empty data.
            return []

        # One column of gaps per search term.
        columns = [
            [gaps[i] for (_, _, gaps) in candidates]
            for i in range(len(searches))
        ]
        means = [statistics.mean(column) for column in columns]
        # stdev() needs at least two samples; treat a single sample as zero spread.
        sigmas = [
            statistics.stdev(column) if len(column) > 1 else 0.0
            for column in columns
        ]

        results = []
        for (word, score, gaps) in candidates:
            # Zero spread (e.g. every candidate has the same gap) yields a
            # neutral z-score instead of a division by zero.
            z_scores = [
                (gap - means[i]) / sigmas[i] if sigmas[i] else 0.0
                for i, gap in enumerate(gaps)
            ]
            results.append((
                word,
                float(score),
                float(sum(erf(z, negative_boost) for z in z_scores)),
                # NOTE(review): score + gap equals 2*score - similarity; kept
                # as in the original, confirm that this is the intended value.
                {
                    search: (score + gaps[i], z_scores[i])
                    for i, search in enumerate(searches)
                },
            ))

        results.sort(key=lambda item: -item[2])
        if limit > 0:
            return results[0:limit]
        return results

# Build the related-word finder on the shared word2vec wrapper; no tagger is supplied.
related = RelatedWord(word2vec, None)

# Rank words related to both search terms, requesting 1000 candidate neighbours.
ret = related.get_related_words(["猫", "好き"], related_word_topn=1000)

# Show the ranked list (the default limit keeps at most 50 entries).
display(ret)



[('可愛い',
  0.6815619468688965,
  0.41070960734246537,
  {'猫': (0.7793261408805847, 0.6868090527692646),
   '好き': (0.8266812562942505, 0.0035575477664252053)}),
 ('かわいい',
  0.6648719310760498,
  0.31282481980796745,
  {'猫': (0.7692475914955139, 0.7608455432944335),
   '好き': (0.7974321246147156, -0.1354168511711263)}),
 ('女の子',
  0.67524653673172,
  0.2895898866829594,
  {'猫': (0.6927868127822876, -0.21155402974369697),
   '好き': (0.898339182138443, 0.8663808228229298)}),
 ('あたし',
  0.5921953320503235,
  0.28298631057813667,
  {'猫': (0.6469616293907166, 0.20531021829549934),
   '好き': (0.74846550822258, 0.1269487974184852)}),
 ('父さん',
  0.591326117515564,
  0.28063729696094336,
  {'猫': (0.6397260427474976, 0.1340183415443292),
   '好き': (0.7536529004573822, 0.19396890400840117)}),
 ('パパ',
  0.6101052761077881,
  0.27924371590849645,
  {'猫': (0.6453093886375427, -0.013750943186171741),
   '好き': (0.792320191860199, 0.4140434162813464)}),
 ('お父さん',
  0.5902565717697144,
  0.2785745812281401,
 