In [1]:
import pandas as pd
from collections import defaultdict

In [15]:
with open('review.sorted.uniq.refined.tsv.text.tok') as f:
    lines = [l.strip() for l in f.read().splitlines() if l.strip()]

In [16]:
len(lines)

100000

In [17]:
def get_term_frequency(document):
    term_freq = {}
    
    words = document.split()

    for w in words:
        term_freq[w] = 1 + (0 if term_freq.get(w) is None else term_freq[w])

    return term_freq

In [18]:
def get_context_counts(lines, vocab, w_size=2):
    context_cnt = defaultdict(int)
    
    for line in lines:
        words = line.split()
        
        for i, w in enumerate(words):
            if w in vocab:
                for c in words[i - w_size:i + w_size]:
                    if w != c:
                        context_cnt[(w, c)] += 1
            
    return context_cnt

In [19]:
def get_co_occurrence_df(context_cnt, vocab):
    data = []
    
    for word1 in vocab:
        row = []
        
        for word2 in vocab:
            try:
                count = context_cnt[(word1, word2)]
            except KeyError:
                count = 0
            row.append(count)
            
        data.append(row)
    
    return pd.DataFrame(data, index=vocab, columns=vocab)

In [20]:
term_freq = pd.Series(
    get_term_frequency(' '.join(lines))
).sort_values(ascending=False)

term_freq

.       165069
이        91963
고        77992
는        57894
네요       57552
         ...  
▁두두          1
플레임          1
함몰           1
▁브라네         1
요긴           1
Length: 53172, dtype: int64

In [21]:
vector_size = 1000

In [22]:
term_freq.index[:vector_size]

Index(['.', '이', '고', '는', '네요', '에', '하', '가', '도', '은',
       ...
       '끈', '▁퀄리티', '▁올리', '화', '▁의심', '▁생기', '▁달라요', '▁가지', '▁얼룩', '동안'],
      dtype='object', length=1000)

In [23]:
context_cnt = pd.Series(
    get_context_counts(
        lines,
        term_freq.index[:vector_size],
        w_size=4
    )
)

context_cnt

어요  ▁다    681
    녹      23
    아서    383
    ▁왓     58
    ▁.    174
         ... 
인지  딴       1
곳   걸       1
에서  딴       1
▁살  딴       1
되   부탁      1
Length: 2140653, dtype: int64

In [24]:
df = get_co_occurrence_df(context_cnt, term_freq.index[:vector_size])

df

Unnamed: 0,.,이,고,는,네요,에,하,가,도,은,...,끈,▁퀄리티,▁올리,화,▁의심,▁생기,▁달라요,▁가지,▁얼룩,동안
.,0,25370,16357,11938,29278,10167,12612,12501,12971,10852,...,104,162,76,88,177,133,174,57,96,74
이,22278,0,9372,9826,13150,6360,4457,4571,3985,5869,...,121,40,21,43,140,179,103,27,203,31
고,15417,9491,0,7297,3498,6534,15893,4763,19640,5261,...,68,45,116,32,31,129,48,330,75,61
는,11377,9541,7095,0,3639,7439,7768,4505,4009,3237,...,37,102,66,69,39,88,22,76,44,50
네요,29495,15165,4105,4148,0,3470,6690,6818,4169,3645,...,30,40,33,22,109,105,5,29,50,29
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
▁생기,133,179,128,86,100,62,14,71,56,17,...,0,0,0,0,0,0,0,0,3,0
▁달라요,167,116,49,25,8,11,41,64,33,18,...,0,0,0,1,0,0,0,0,0,0
▁가지,46,38,335,73,27,36,30,47,26,22,...,1,0,0,0,0,1,0,0,0,1
▁얼룩,64,193,67,48,45,129,9,10,62,72,...,1,0,0,1,0,3,1,0,0,0


In [25]:
import torch

In [28]:
print(torch.__version__)

1.10.0


In [29]:
def get_cosine_similarity(x1, x2):
    return (x1 * x2).sum() / ((x1**2).sum()**.5 * (x2**2).sum()**.5 + 1e-10)

In [41]:
def get_nearest(query, dataframe, metric, top_k, ascending=True):
    vector = torch.from_numpy(dataframe.loc[query].values).float()
    distances = dataframe.apply(
        lambda x: metric(vector, torch.from_numpy(x.values).float()),
        axis=1,
    )
    top_distances = distances.sort_values(ascending=ascending)[:top_k]

    print(', '.join([f'{k} ({v:.1f})' for k, v in top_distances.items()]))

In [42]:
print('\nCosine similarity:')
get_nearest('반품', df, get_cosine_similarity, 30, ascending=False)


Cosine similarity:
반품 (1.0), 교환 (0.9), ▁반품 (0.9), ▁그냥 (0.9), ▁황당 (0.9), ▁교환 (0.9), ▁참 (0.9), ▁걍 (0.9), 사용 (0.9), ▁찝찝 (0.9), ▁그래서 (0.9), 그냥 (0.9), ▁. (0.8), ▁허접 (0.8), ㅠㅠ (0.8), 정말 (0.8), ▁뭐 (0.8), 너무 (0.8), 다 (0.8), ㅠ (0.8), 그리고 (0.8), ▁일단 (0.8), ▁불편 (0.8), ㅜ (0.8), 배송 (0.8), ▁후회 (0.8), ㅋ (0.8), 진짜 (0.8), ▁진짜 (0.8), ㅡㅡ (0.8)
