# Measuring Document Similarity

In [4]:
from ekorpkit import eKonf

# On Colab, mount Google Drive before configuring the workspace.
if eKonf.is_colab():
    eKonf.mount_google_drive()

# NOTE(review): "exmaples" looks like a typo for "examples", but it is the
# project path actually in use (see the logged directories), so it is kept.
ws = eKonf.set_workspace(
    workspace="/workspace",
    project="ekorpkit-book/exmaples",
    task="esg",
    log_level="INFO",
    verbose=True,
)

# Report the library version and the resolved project directory.
print(f"version: {ws.version}")
print(f"project_dir: {ws.project_dir}")

INFO:matplotlib.font_manager:generated new fontManager
INFO:ekorpkit.base:Set environment variable EKORPKIT_DATA_ROOT=/workspace/data
INFO:ekorpkit.base:Set environment variable CACHED_PATH_CACHE_ROOT=/workspace/.cache/cached_path
INFO:ekorpkit.base:Set environment variable WANDB_DIR=/workspace/projects/ekorpkit-book/exmaples/logs
INFO:ekorpkit.base:Set environment variable WANDB_PROJECT=ekorpkit-book-exmaples
INFO:ekorpkit.base:Set environment variable WANDB_NOTEBOOK_NAME=/workspace/projects/ekorpkit-book/exmaples/logs/esg-nb
INFO:ekorpkit.base:Set environment variable WANDB_SILENT=False


version: 0.1.40.post0.dev83
project_dir: /workspace/projects/ekorpkit-book/exmaples
time: 1.64 s (started: 2023-02-01 09:36:13 +00:00)


## Load data to predict

In [5]:
# Location and file name of the news data to load, relative to the project.
news_data_dir = ws.project_dir / "esg/data/econ_news_kr/news_slice"
filename = "esg_news_valid_20221229.parquet"
valid_data = eKonf.load_data(filename, news_data_dir)

# Preview the first rows, restricted to the columns of interest.
cols = ["text", "filename", "chunk_id", "codes"]
valid_data.loc[:, cols].head()

Unnamed: 0,text,filename,chunk_id,codes
0,◆ 2020 경제기상도 / 업종별 전망 (반도체) ◆ 지난해 미·중 무역분쟁과 공...,02100101.20200101040200001.txt,0,660
2,"◆ 2020 경제기상도 / 업종별 전망 (가전) ◆ TV, 냉장고, 세탁기 등 전...",02100101.20200101040200002.txt,0,66570
3,"◆ 2020 경제기상도 / 업종별 전망 (가전) ◆ TV, 냉장고, 세탁기 등 전...",02100101.20200101040200002.txt,0,5930
4,◆ 2020 경제기상도 / 업종별 전망 (디스플레이) ◆ 액정표시장치(LCD) 시...,02100101.20200101040201001.txt,0,34220
5,디스플레이 업계 등에서는 삼성과 LG가 글로벌 디스플레이 시장에서 중국 업체의 L...,02100101.20200101040201001.txt,1,3550


time: 6.57 s (started: 2023-02-01 09:36:39 +00:00)


In [20]:
# Draw 10 texts at random. `random_state` is pinned so that the sample, and
# every embedding and similarity score derived from it, is reproducible
# after a kernel restart.
documents = valid_data["text"].sample(n=10, random_state=42).tolist()

time: 10.9 ms (started: 2023-01-19 08:56:22 +00:00)


## Predict similarity

In [21]:
from transformers import ElectraModel, ElectraTokenizer

# Load the Electra encoder and its tokenizer from one shared checkpoint id,
# so the two can never drift apart.
checkpoint = "entelecheia/ekonelectra-base-discriminator"
model = ElectraModel.from_pretrained(checkpoint)
tokenizer = ElectraTokenizer.from_pretrained(checkpoint)

Some weights of the model checkpoint at entelecheia/ekonelectra-base-discriminator were not used when initializing ElectraModel: ['discriminator_predictions.dense.weight', 'discriminator_predictions.dense.bias', 'discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense_prediction.bias']
- This IS expected if you are initializing ElectraModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


time: 2.78 s (started: 2023-01-19 08:56:23 +00:00)


In [22]:
# Encode every document on its own (one tokenizer call per document),
# requesting PyTorch tensors with truncation enabled.
encoded_docs = []
for text in documents:
    token_ids = tokenizer.encode(text, return_tensors="pt", truncation=True)
    encoded_docs.append(token_ids)


time: 31.7 ms (started: 2023-01-19 08:56:26 +00:00)


In [23]:
import torch

# Generate the embeddings for the documents.
# This is inference only, so run the encoder under `torch.no_grad()`:
# no autograd graph is built, which saves memory and time.
# From the first model output, keep only the hidden state at position 0
# (the leading token) and convert it to a NumPy array. Each entry keeps
# its leading batch axis -- assumes one sequence per encoded tensor.
with torch.no_grad():
    embeddings = [model(doc)[0][:, 0, :].numpy() for doc in encoded_docs]

time: 2.47 s (started: 2023-01-19 08:56:27 +00:00)


In [30]:
import numpy as np

import torch

# Mean-pooled embedding: the average hidden state over all tokens of a
# document. `embeddings` cannot be used for this because it already holds
# only the position-0 vector of each document, so averaging it over axis 0
# would merely return that same vector again. Re-run the encoder and average
# the full token states over the token axis instead.
with torch.no_grad():  # inference only; no autograd graph needed
    mean_embeddings = [
        model(doc)[0].mean(dim=1).squeeze(0).numpy() for doc in encoded_docs
    ]
# Get the CLS token embedding: drop the leading batch axis of each entry.
cls_embeddings = [embedding[0] for embedding in embeddings]

time: 1.38 ms (started: 2023-01-19 08:59:45 +00:00)


In [31]:
from sklearn.metrics.pairwise import cosine_similarity

# Pairwise cosine similarity: entry [i, j] compares the CLS embeddings of
# documents i and j.
similarity_matrix = cosine_similarity(X=cls_embeddings)

print(similarity_matrix)

[[1.0000001  0.968363   0.92001873 0.9597888  0.9621922  0.9525209
  0.96460295 0.9683971  0.9587421  0.9609047 ]
 [0.968363   1.         0.9243116  0.99047786 0.9818795  0.974076
  0.9874664  0.99173224 0.9730664  0.9540314 ]
 [0.92001873 0.9243116  1.0000002  0.90941644 0.9215709  0.91559356
  0.92173624 0.9292955  0.94094986 0.9056404 ]
 [0.9597888  0.99047786 0.90941644 1.0000002  0.9848912  0.97961116
  0.9907066  0.9894204  0.9725035  0.9612313 ]
 [0.9621922  0.9818795  0.9215709  0.9848912  1.         0.9741106
  0.9906047  0.9855628  0.9789446  0.9729687 ]
 [0.9525209  0.974076   0.91559356 0.97961116 0.9741106  1.
  0.9774029  0.9789716  0.97945994 0.96554124]
 [0.96460295 0.9874664  0.92173624 0.9907066  0.9906047  0.9774029
  0.9999999  0.99083006 0.97806597 0.9706383 ]
 [0.9683971  0.99173224 0.9292955  0.9894204  0.9855628  0.9789716
  0.99083006 0.9999997  0.9804057  0.9656478 ]
 [0.9587421  0.9730664  0.94094986 0.9725035  0.9789446  0.97945994
  0.97806597 0.9804057  1.