In [None]:
import os
import pickle
from glob import glob
from tqdm import tqdm
from datetime import datetime
from src.embedding_models.all_MiniLM_L6_v2 import All_MiniLM_L6_v2
from src.embedding_models.all_mpnet_base_v2 import all_mpnet_base_v2
from src.embedding_models.roberta import Roberta
from src.downloader import download_dataset

In [9]:
# Run timestamp (YYYYMMDD_HHMMSS); names this run's embeddings output subdirectory.
current_datetime = f"{datetime.now():%Y%m%d_%H%M%S}"

In [None]:
# --- Run configuration ---------------------------------------------------------
# Set to True to fetch the raw dataset into `dataset_dir` before embedding.
download = False

# Input / output locations (relative paths resolve against the working directory).
dataset_dir = "./dataset"
articles_dir = "./output/clean_plaintext_articles/"
embeddings_output_maindir = "./output/embeddings/"

# Embedding backend; one of: "all_MiniLM_L6_v2", "all_mpnet_base_v2", "roberta".
model_name = "roberta"

In [11]:
# Only hit the network when the config cell explicitly asks for a (re)download.
if download:
    # Populates `dataset_dir` via the project's downloader helper.
    download_dataset(dataset_dir)

In [None]:
# Instantiate the embedding backend chosen by `model_name` in the config cell.
# Reset first so a model left over from an earlier run is never reused by accident.
model = None
if model_name == "all_MiniLM_L6_v2":
    model = All_MiniLM_L6_v2()
elif model_name == "all_mpnet_base_v2":
    model = all_mpnet_base_v2()
elif model_name == "roberta":
    model = Roberta()
else:
    # Fail here with a clear message rather than later with `None.embed(...)`.
    raise ValueError(
        f"Unknown model_name {model_name!r}; expected one of "
        "'all_MiniLM_L6_v2', 'all_mpnet_base_v2', 'roberta'"
    )

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


In [None]:
# Embed every cleaned article, keyed by its file name without the ".txt" extension.
if model is None:
    # Without this check the first iteration fails with an opaque AttributeError.
    raise RuntimeError("No embedding model loaded; run the model-selection cell first.")

# glob() returns paths in arbitrary order; sort so the resulting dict (and the
# pickled output) is reproducible across runs and machines.
article_paths = sorted(glob(os.path.join(articles_dir, "*.txt")))
if not article_paths:
    # Otherwise an empty dict would be saved and overwrite latest_embeddings.pkl.
    raise FileNotFoundError(f"No .txt articles found in {articles_dir!r}")

article_embeddings = {}
# Note from Nat : if this is too slow we can try to do batch processing
for article_path in tqdm(article_paths, desc="embedding articles"):
    article_name = os.path.splitext(os.path.basename(article_path))[0]
    # errors='ignore' silently drops undecodable bytes instead of failing on them.
    with open(article_path, "r", encoding="utf-8", errors="ignore") as file:
        text = file.read()
    # Embed after the file is closed so the handle isn't held during the slow model call.
    article_embeddings[article_name] = model.embed(text)

100%|██████████| 4604/4604 [1:37:34<00:00,  1.27s/it]  


In [14]:
# Create a per-run output directory: <embeddings_output_maindir>/<model_name>/<current_datetime>/
embeddings_output_dir = os.path.join(embeddings_output_maindir, model_name, current_datetime)
# exist_ok=True replaces the exists()-then-makedirs() check, which was racy.
os.makedirs(embeddings_output_dir, exist_ok=True)

# Save the pickled embeddings for this run, and also as the shared "latest" snapshot.
# NOTE(review): latest_embeddings.pkl sits directly under the main output dir, so it is
# overwritten by whichever model ran last, regardless of model_name — confirm intended.
embeddings_output_path = os.path.join(embeddings_output_dir, "embeddings.pkl")
latest_embedding_output_path = os.path.join(embeddings_output_maindir, "latest_embeddings.pkl")
for output_path in (embeddings_output_path, latest_embedding_output_path):
    with open(output_path, "wb") as file:
        pickle.dump(article_embeddings, file)