In [None]:
import sentence_transformers
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import pandas as pd
import datetime
import annoy
from annoy import AnnoyIndex
from tqdm import tqdm
from bisect import bisect_left
import datetime
import os
import re

In [None]:
# Limit the CPU threads used for embedding. OMP_NUM_THREADS is read when the OpenMP
# runtime starts, and torch (imported by sentence_transformers in the imports cell)
# may already have initialised it, so also set torch's intra-op thread count directly.
N_THREADS = 9
os.environ['OMP_NUM_THREADS'] = str(N_THREADS)
import torch  # already loaded as a dependency of sentence_transformers

torch.set_num_threads(N_THREADS)
# LaBSE: multilingual sentence-embedding model loaded through sentence-transformers.
model = sentence_transformers.SentenceTransformer('LaBSE')

In [None]:
# Load Data and Run Sanity Checks
df = pd.read_csv("../Data/minimal_FactCheckData_local.csv.gz", compression="gzip")

# Unparseable publication dates become NaT (errors="coerce") and those rows are dropped.
df["datePublished"] = pd.to_datetime(df.datePublished, errors="coerce")
df = df.dropna(subset=["datePublished"])
assert df["claimReviewed"].isna().sum() == 0, "NAs"
assert df["claimReviewed"].eq("").sum() == 0, "Empties"
assert df["claimReviewed"].apply(lambda x: len(x) < 5).sum() == 0, "Length < 5"

docs = df["claimReviewed"].tolist()
ids = df["claim_minimal"].tolist()

# Chunk docs into batches of 10,000 so embedding can be resumed chunk by chunk.
CHUNK_SIZE = 10_000
# Single cache directory for both the "already embedded?" check and the save. It must
# match the folder that load_all_embeddings reads later ("../Data/embeddings/"). The
# save previously used "Embeddings" (capital E), which on case-sensitive file systems
# broke both the skip check and the later loading.
EMBEDDINGS_DIR = "../Data/embeddings"
os.makedirs(EMBEDDINGS_DIR, exist_ok=True)

docs_chunks = [docs[x:x + CHUNK_SIZE] for x in range(0, len(docs), CHUNK_SIZE)]
ids_chunks = [ids[x:x + CHUNK_SIZE] for x in range(0, len(ids), CHUNK_SIZE)]

# Embed each chunk and save it as a {claim_minimal id: vector} dict, one .npy file per chunk.
# NOTE: existing chunk files are reused as-is. If the input data changes, clear the
# directory first, or stale embeddings will be loaded.
for i in range(len(docs_chunks)):
    chunk_path = os.path.join(EMBEDDINGS_DIR, f"embeddings_{i+1}.npy")
    # Skip chunks that were embedded in an earlier run.
    if os.path.isfile(chunk_path):
        continue
    print(f"Embedding chunk {i+1} of {len(docs_chunks)}")
    print(f"Starting at {datetime.datetime.now()}")
    embeddings = model.encode(docs_chunks[i], show_progress_bar=True, batch_size=256)
    embeddings_dict = dict(zip(ids_chunks[i], embeddings))
    # np.save stores the dict as a pickled 0-d object array; read back with
    # np.load(..., allow_pickle=True).item().
    np.save(chunk_path, embeddings_dict)

In [None]:
df["datePublished"] = pd.to_datetime(df["datePublished"])
# Claim ids must be unique. DATES and the embedding dicts are keyed by them, and the
# Annoy index maps item i back to ORIGINAL_IDS[i]. Duplicates would otherwise surface
# later as an opaque "not the same ids" assertion.
assert df["claim_minimal"].is_unique, "claim_minimal contains duplicate ids"
# claim id -> publication timestamp; used to keep only earlier-published neighbours.
DATES = df.set_index("claim_minimal").datePublished.to_dict()

# Dataframe row order. Embedding dicts and Annoy item indices must follow this order.
ORIGINAL_IDS = df.claim_minimal.to_list()
# Must equal the encoder's output width (768 for LaBSE).
DEFAULT_DIMENSION = 768
# Number of Annoy trees: more trees give better recall but a larger index and slower build.
DEFAULT_TREES = 50
# Initial number of neighbours requested per item in get_edge_list.
DEFAULT_NEIGHBORS = 50
# Annoy's angular distance is sqrt(2 * (1 - cos)), so distance < sqrt(0.8)
# corresponds to cosine similarity > 0.6.
DEFAULT_THRESHOLD = np.sqrt(0.8)
# Factor by which the neighbour count grows while all returned neighbours are within the threshold.
DEFAULT_INCREASE = 2

def load_embeddings(file_path, ids_order = ORIGINAL_IDS):
    """Read one chunk file of embeddings and return it re-keyed in ``ids_order`` order.

    Args:
        file_path: path to a ``.npy`` file written by ``np.save`` of an ``{id: vector}`` dict.
            NOTE: it is loaded with ``allow_pickle=True``, so only load trusted files.
        ids_order: ids that define the output order; ids not present in the file are skipped.
            The default is ORIGINAL_IDS as bound when this function was defined.

    Returns:
        A dict with the ids found in the file, ordered like ``ids_order``. If loading or
        re-keying raises, the error is printed and None is returned instead.
    """
    try:
        stored = np.load(file_path, allow_pickle=True).item()
        ordered = {}
        for claim_id in ids_order:
            if claim_id in stored:
                ordered[claim_id] = stored[claim_id]
        return ordered
    except Exception as err:
        print(f"Failed to load embeddings from {file_path} with error {err}")
        return None
    
def create_annoy_index(embeddings_dict, dimension=DEFAULT_DIMENSION, trees=DEFAULT_TREES):
    """Build an angular Annoy index in which item ``i`` is the i-th vector of ``embeddings_dict``.

    Args:
        embeddings_dict: dict id -> vector; its key order must equal ORIGINAL_IDS, because
            the integer Annoy item ids are positions in that order.
        dimension: vector length expected by the index.
        trees: number of trees passed to ``build``.

    Returns:
        The built AnnoyIndex.

    Raises:
        AssertionError: if the dict keys differ from ORIGINAL_IDS (content or order).
    """
    annoy_index = AnnoyIndex(dimension, 'angular')

    assert list(embeddings_dict) == ORIGINAL_IDS, "The embeddings_dict does not contain the same ids as the original dataframe. Error in create_annoy_index."

    vectors = embeddings_dict.values()
    for item_id, vector in tqdm(enumerate(vectors), total = len(embeddings_dict)):
        annoy_index.add_item(item_id, vector)

    annoy_index.build(trees)
    return annoy_index

def load_all_embeddings(folder_path):
    """Load and merge every ``embeddings*.npy`` chunk file found in ``folder_path``.

    Files are processed in ascending order of the first number in their name
    (``embeddings_2.npy`` before ``embeddings_10.npy``). When the chunks were written as
    consecutive slices of ORIGINAL_IDS, the merged dict therefore keeps that order.

    Args:
        folder_path: directory that contains the chunk files.

    Returns:
        dict id -> embedding vector, merged across all chunk files.

    Raises:
        RuntimeError: if a chunk file cannot be loaded (load_embeddings returned None).
    """
    # Filter before sorting so unrelated .npy files never reach the numeric sort key.
    chunk_files = [
        name for name in os.listdir(folder_path)
        if name.startswith('embeddings') and name.endswith('.npy')
    ]

    def chunk_number(name):
        # Names without digits sort after numbered ones instead of crashing the sort.
        # The name itself is a tiebreak, so the order never depends on os.listdir.
        match = re.search(r'\d+', name)
        return (match is None, int(match.group()) if match else 0, name)

    embeddings_dict = {}
    for file_name in sorted(chunk_files, key=chunk_number):
        file_path = os.path.join(folder_path, file_name)
        chunk = load_embeddings(file_path)
        if chunk is None:
            # load_embeddings has already printed the underlying error.
            raise RuntimeError(f"Could not load embeddings from {file_path}")
        embeddings_dict.update(chunk)

    return embeddings_dict

def logging(string, colour, logging_lvl = 0):
    """Print a timestamped progress message and optionally append it to ``log.txt``.

    NOTE: this function shadows the name of the standard-library ``logging`` module.

    Args:
        string: message prefix. "Starting at HH:MM:SS" is appended directly, so callers
            pass text ending in ". ".
        colour: "red" or "green" wraps the line in ANSI colour codes. Any other value
            prints the line without colour rather than printing nothing.
        logging_lvl: when greater than 1, the same line is also appended to "log.txt"
            in the current working directory.
    """
    now = datetime.datetime.now().strftime('%H:%M:%S')
    message = f"{string}Starting at {now}"
    ansi_codes = {"red": "\033[91m", "green": "\033[92m"}
    if colour in ansi_codes:
        print(f"{ansi_codes[colour]}{message}\033[0m")
    else:
        print(message)

    if logging_lvl > 1:
        # Write the same text as the console line (minus colour codes). The old format
        # added an extra ". " and produced lines like "X. . Starting at ...".
        with open("log.txt", "a") as f:
            f.write(f"{message}\n")

def get_edge_list(embeddings_dict, index, start_neighbors=10, increase_rate=2, threshold=np.sqrt(0.8)):
    """Collect, for each item, its earlier-published neighbours within an Annoy distance threshold.

    Each item is first queried with ``start_neighbors`` neighbours. While the furthest
    returned neighbour is still closer than ``threshold``, the query is repeated with a
    larger neighbour count, so that (approximately) every neighbour under the threshold
    is captured. Growth stops once the request covers every item in the index.

    Args:
        embeddings_dict: dict id -> vector. Its key order must equal ORIGINAL_IDS, because
            Annoy item ``i`` is the i-th key (as built by create_annoy_index).
        index: built AnnoyIndex using the 'angular' metric.
        start_neighbors: initial number of neighbours requested per item.
        increase_rate: multiplicative growth factor for the neighbour count.
        threshold: exclusive upper bound on the Annoy angular distance for an edge.

    Returns:
        dict id -> {neighbour_id: cosine similarity}. It keeps only neighbours with
        distance < threshold whose DATES value is strictly earlier than the item's.
    """
    ids = list(embeddings_dict.keys())
    assert ids == ORIGINAL_IDS, "The embeddings_dict does not contain the same ids as the original dataframe. Error in get_edge_list."
    n_items = index.get_n_items()
    edge_list = {}

    for i, ind in tqdm(enumerate(ids), total = len(ids)):
        neighbors = start_neighbors
        nn, distances = index.get_nns_by_item(i, neighbors, include_distances=True)

        # Widen the query while even the furthest neighbour is inside the threshold.
        # Stop once the request already covers the whole index: asking for more cannot
        # return new items, and without this bound the loop never ends when every
        # returned item lies within the threshold.
        while distances and distances[-1] < threshold and neighbors < n_items:
            # max(...) keeps the count growing even when increase_rate <= 1.
            neighbors = max(neighbors + 1, int(neighbors * increase_rate))
            nn, distances = index.get_nns_by_item(i, neighbors, include_distances=True)

        # Distances come back in ascending order, so the first `cutoff` entries are
        # exactly the neighbours with distance < threshold.
        cutoff = bisect_left(distances, threshold)
        own_date = DATES[ind]
        # Annoy angular distance d = sqrt(2 * (1 - cos)), hence cos = (2 - d**2) / 2.
        # Only neighbours published strictly before this item are kept, which also drops
        # the item itself. NOTE(review): the original author annotated this filter with
        # "Exactly the wrong way around."; confirm the intended edge direction.
        edge_list[ind] = {
            ids[j]: (2 - d ** 2) / 2
            for j, d in zip(nn[:cutoff], distances[:cutoff])
            if DATES[ids[j]] < own_date
        }
    return edge_list

In [None]:
logging("", "green")

# Step 1: merge every cached embedding chunk into a single dict ordered like ORIGINAL_IDS.
logging("Loading Embeddings. ", colour="green")
embeddings = load_all_embeddings(folder_path="../Data/embeddings/")

# Step 2: index the vectors with Annoy for approximate nearest-neighbour queries.
logging("Creating Annoy index. ", colour="green")
annoy_index = create_annoy_index(
    embeddings_dict=embeddings,
    dimension=DEFAULT_DIMENSION,
    trees=DEFAULT_TREES,
)

# Step 3: query neighbours within the distance threshold to build the similarity graph.
logging("Creating Edge List. ", colour="green")
edge_list = get_edge_list(
    embeddings_dict=embeddings,
    index=annoy_index,
    start_neighbors=DEFAULT_NEIGHBORS,
    increase_rate=DEFAULT_INCREASE,
    threshold=DEFAULT_THRESHOLD,
)

In [None]:
# Persist the edge list with pickle so later steps can reuse it without recomputing.
import pickle

logging("Pickling Edge List. ", colour="green")
with open("../Data/edge_list.pkl", "wb") as out_file:
    pickle.dump(edge_list, out_file)