In [68]:
import os

from datetime import datetime, timedelta
from functools import wraps
from itertools import product
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

from bson import ObjectId
from dotenv import load_dotenv
from IPython.display import display
from pymongo import MongoClient, errors
from pymongo.database import Database
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchinfo import summary
from tqdm import tqdm

load_dotenv("../../.env")

True

# Data collection

In [2]:
# Name of the MongoDB database used by the query helpers below.
DATABASE = "insightfinder-dev"

# Collections read by the query helpers.
TOPICS_COLLECTION = "topics"
TOPIC_EMBEDDINGS_COLLECTION = "topic_embeddings"

# Connection target for MongoClient, read from the environment (None when the variable is unset).
MONGO_HOST = os.environ.get("MONGO_HOST")

In [4]:
def mongo_transaction(func):
    """
    Decorator that runs ``func`` against a freshly opened MongoDB connection.

    A new ``MongoClient`` is created for every call (from ``MONGO_HOST``) and closed
    when the call returns. The ``DATABASE`` handle is injected as the first positional
    argument of ``func``, so decorated functions are called *without* the ``db`` argument.
    Despite the name, no MongoDB session or multi-document transaction is started.

    Errors are reported on stdout and then re-raised, so a failed query stops the
    notebook at its real cause instead of silently handing ``None`` to the caller.
    """
    @wraps(func)  # keep the decorated function's name/docstring for help() and tracebacks
    def wrapper(*args, **kwargs):
        try:
            with MongoClient(MONGO_HOST) as mongo_client:
                db = mongo_client[DATABASE]
                return func(db, *args, **kwargs)
        except errors.PyMongoError as e:
            print(f"MongoDB error: {e}")
            raise
        except Exception as e:
            print(f"An unexpected error occurred: {e}")
            raise
    return wrapper


@mongo_transaction
def retrieve_topics_assignments(db: Database, start_date: str, end_date: str, skip_noise_topic: bool = True):
    """
    Fetch the documents of the topics collection whose topic window matches exactly.

    ``db`` is supplied by the ``mongo_transaction`` decorator; callers pass only the
    remaining arguments.

    Args:
        start_date: value that ``topic_start_date`` must equal.
        end_date: value that ``topic_end_date`` must equal.
        skip_noise_topic: when true, documents with ``topic_index == -1`` are excluded.

    Returns:
        The matching documents as a list.
    """
    noise_filter = {"topic_index": {"$ne": -1}} if skip_noise_topic else {}
    criteria = {
        "topic_start_date": {"$eq": start_date},
        "topic_end_date": {"$eq": end_date},
        **noise_filter,
    }
    cursor = db[TOPICS_COLLECTION].find(filter=criteria)
    return list(cursor)


@mongo_transaction
def get_topic_embeddings(db: Database, topic_ids: list[ObjectId]):
    """
    Fetch the documents of the topic-embeddings collection whose ``_id`` is in ``topic_ids``.

    ``db`` is supplied by the ``mongo_transaction`` decorator; callers pass only
    ``topic_ids``. Returns the matching documents as a list.
    """
    collection = db[TOPIC_EMBEDDINGS_COLLECTION]
    matches = collection.find(filter={"_id": {"$in": topic_ids}})
    return list(matches)


def get_topics_data(topic_start_date, topic_end_date):
    """
    Load the topic assignments of one topic window together with the topic embeddings.

    Args:
        topic_start_date: start of the topic window, forwarded to ``retrieve_topics_assignments``.
        topic_end_date: end of the topic window, forwarded to ``retrieve_topics_assignments``.

    Returns:
        A ``(topic_assignments, topic_embeddings)`` tuple: a DataFrame built from the
        assignment documents, and a dict mapping each embedding document's ``_id``
        to its ``embedding`` field.

    Raises:
        ValueError: if the assignment lookup yields no rows.
    """
    records = retrieve_topics_assignments(topic_start_date, topic_end_date)
    topic_assignments = pd.DataFrame(records)

    # An empty frame has no "topic_id" column; fail with a readable message
    # instead of the KeyError the column access below would raise.
    if topic_assignments.empty:
        raise ValueError(
            f"No topic assignments found for the window {topic_start_date} to {topic_end_date} "
            "(no matching documents, or the lookup failed)."
        )

    topic_ids = topic_assignments["topic_id"].unique().tolist()
    topic_embeddings = {
        doc["_id"]: doc["embedding"] for doc in get_topic_embeddings(topic_ids)
    }
    return topic_assignments, topic_embeddings

In [53]:
# Topic window to load; both values are passed unchanged to the MongoDB lookup.
topic_start_date = "2024-07-22"
topic_end_date = "2024-07-28"

topic_assignments, topic_embeddings = get_topics_data(
    topic_start_date=topic_start_date,
    topic_end_date=topic_end_date,
)

In [54]:
# Show the assignments frame returned by get_topics_data.
topic_assignments

Unnamed: 0,_id,document_id,document_date,topic_index,topic_label,assignment_probability,topic_start_date,topic_end_date,topic_id
0,66abfd045d3be4373f7709a6,66a4018aba06b56c721e77ba,2024-07-25,1,1_harris_biden_joe_kamala,1.000000,2024-07-22,2024-07-28,66abfd045d3be4373f770940
1,66abfd045d3be4373f7709a7,66a4016350e9b75b4a0f05ee,2024-07-22,4,4_film_staffel_filme_kopf,0.450914,2024-07-22,2024-07-28,66abfd045d3be4373f770943
2,66abfd045d3be4373f7709a8,66a40193fda8353911671a11,2024-07-25,26,26_app_blut_sichergestellt_behrde,0.734515,2024-07-22,2024-07-28,66abfd045d3be4373f770959
3,66abfd045d3be4373f7709a9,66a40188fda8353911671a0f,2024-07-25,20,20_quartal_prozent_zweiten quartal_wall,1.000000,2024-07-22,2024-07-28,66abfd045d3be4373f770953
4,66abfd045d3be4373f7709aa,66a4018dfda8353911671a10,2024-07-25,11,11_bank_wirtschaft_quartal_zweiten quartal,0.409256,2024-07-22,2024-07-28,66abfd045d3be4373f77094a
...,...,...,...,...,...,...,...,...,...
2412,66abfd045d3be4373f77153e,66a660c7ff627695746154fb,2024-07-26,76,76_schloss_liste_kleinstadt_ausgezeichnet,1.000000,2024-07-22,2024-07-28,66abfd045d3be4373f77098b
2413,66abfd045d3be4373f77153f,66a660bbff627695746154f8,2024-07-26,34,34_tote_toten_sturm_ums leben,0.073268,2024-07-22,2024-07-28,66abfd045d3be4373f770961
2414,66abfd045d3be4373f771540,66a660d4ff627695746154ff,2024-07-26,65,65_grenzen_migration_geflchtete_einfhrung,1.000000,2024-07-22,2024-07-28,66abfd045d3be4373f770980
2415,66abfd045d3be4373f771541,66a660deff62769574615502,2024-07-26,66,66_rede_treffen mit_israels_hamas,0.876987,2024-07-22,2024-07-28,66abfd045d3be4373f770981


# Data preparation

In [56]:
def build_dataset(topic_assignments, topic_embeddings, topic_start_date, topic_end_date):
    """
    Build one row per topic holding its daily document counts and its embedding.

    Args:
        topic_assignments: DataFrame with at least the columns ``topic_id``,
            ``document_date`` and ``document_id``.
        topic_embeddings: mapping from topic id to embedding.
        topic_start_date: first day of the window (anything ``pd.date_range`` accepts).
        topic_end_date: last day of the window, included.

    Returns:
        DataFrame with the columns ``topic_id``, ``date`` (list of ``YYYY-MM-DD``
        strings covering the window), ``count`` (distinct documents per day, 0 where
        a day has none) and ``embedding`` (NaN when the topic id is not in
        ``topic_embeddings``).
    """
    # Every calendar day of the window, formatted like the keys used in the merge below.
    all_days = pd.date_range(topic_start_date, topic_end_date).strftime("%Y-%m-%d").tolist()

    # Distinct documents per (topic, day) pair that actually occurs in the data.
    daily_counts = (
        topic_assignments[["topic_id", "document_date", "document_id"]]
        .rename(columns={"document_date": "date", "document_id": "count"})
        .groupby(by=["topic_id", "date"])["count"]
        .nunique()
        .reset_index()
    )

    # Full topic x day grid; the right join keeps every grid cell and drops
    # observed days that lie outside the window.
    observed_topics = daily_counts["topic_id"].unique().tolist()
    grid = pd.DataFrame(
        [(topic, day) for topic in observed_topics for day in all_days],
        columns=["topic_id", "date"],
    )
    filled = daily_counts.merge(grid, on=["topic_id", "date"], how="right")
    filled["count"] = filled["count"].fillna(value=0).astype(int)

    # Collapse to one row per topic with day-ordered lists.
    per_topic = (
        filled.sort_values(by=["topic_id", "date"])
        .groupby(by="topic_id")
        .agg({"date": list, "count": list})
        .reset_index()
    )

    per_topic["embedding"] = per_topic["topic_id"].map(topic_embeddings)
    return per_topic

In [58]:
dataset_df = build_dataset(
    topic_assignments=topic_assignments,
    topic_embeddings=topic_embeddings,
    topic_start_date=topic_start_date,
    topic_end_date=topic_end_date,
)
# Preview: one row per topic with its date list, count list and embedding.
dataset_df.head()

Unnamed: 0,topic_id,date,count,embedding
0,66abfd045d3be4373f77093f,"[2024-07-22, 2024-07-23, 2024-07-24, 2024-07-2...","[11, 21, 20, 22, 21, 11, 0]","[0.03168333164796693, 0.15174446942843162, -0...."
1,66abfd045d3be4373f770940,"[2024-07-22, 2024-07-23, 2024-07-24, 2024-07-2...","[29, 18, 9, 20, 6, 2, 0]","[0.00894042448896282, 0.19605032915144033, 0.0..."
2,66abfd045d3be4373f770941,"[2024-07-22, 2024-07-23, 2024-07-24, 2024-07-2...","[11, 11, 19, 12, 7, 10, 0]","[-0.0026579808226891884, 0.17735331124138265, ..."
3,66abfd045d3be4373f770942,"[2024-07-22, 2024-07-23, 2024-07-24, 2024-07-2...","[9, 12, 17, 10, 13, 0, 0]","[0.014365441002883016, 0.1931595825191055, -0...."
4,66abfd045d3be4373f770943,"[2024-07-22, 2024-07-23, 2024-07-24, 2024-07-2...","[7, 11, 8, 15, 9, 5, 0]","[-0.04624435288801056, 0.17007282070937704, -0..."


In [59]:
# Missing values per column; a non-zero "embedding" count means a topic_id had no entry in topic_embeddings.
dataset_df.isna().sum()

topic_id     0
date         0
count        0
embedding    0
dtype: int64

In [62]:
class Dataset(torch.utils.data.Dataset):
    """
    Map-style dataset over a per-topic DataFrame.

    The base class is spelled out as ``torch.utils.data.Dataset`` because this class
    reuses the name ``Dataset``: naming the base explicitly keeps the cell idempotent,
    so re-running it does not subclass the previous definition of itself.

    Args:
        data: DataFrame with the columns ``topic_id``, ``count`` and ``embedding``;
            ``count`` and ``embedding`` must hold values ``torch.tensor`` can convert.
    """
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the row at position ``idx`` as a dict of ``topic_id``, ``counts`` and ``embedding``."""
        # Fetch the row once instead of one positional lookup per field.
        row = self.data.iloc[idx]
        return {
            # Stringified so the default collate function can batch it (e.g. ObjectId values).
            "topic_id": str(row["topic_id"]),
            "counts": torch.tensor(row["count"], dtype=torch.float32),
            "embedding": torch.tensor(row["embedding"], dtype=torch.float32),
        }

In [63]:
# Wrap the per-topic frame so a DataLoader can batch it.
dataset = Dataset(dataset_df)

In [64]:
# Number of rows in dataset_df, i.e. topics available for scoring.
len(dataset)

102

# Prepare model

In [66]:
class Model(nn.Module):
    """
    Two-branch network combining a count sequence with a topic embedding.

    The count sequence goes through a single-feature LSTM followed by a linear layer;
    the topic embedding goes through its own linear layer. Both branches are
    ReLU-activated, concatenated and projected to one output value per sample.

    Args:
        embedding_dim: size of the topic embedding vector.
        counts_dim: accepted but not used by the layers defined here.
        counts_hidden_size: hidden size of the LSTM and of the counts branch.
    """
    def __init__(self, embedding_dim, counts_dim, counts_hidden_size):
        super().__init__()
        # Attribute names and creation order define the state_dict keys of saved checkpoints.
        self.lstm = nn.LSTM(input_size=1, hidden_size=counts_hidden_size, batch_first=True)
        self.fc_counts = nn.Linear(counts_hidden_size, counts_hidden_size)
        self.fc_topic = nn.Linear(embedding_dim, embedding_dim)
        self.fc_combined = nn.Linear(embedding_dim + counts_hidden_size, 1)

    def forward(self, counts, topic_emb):
        """
        Args:
            counts: tensor of shape [batch_size, sequence_length].
            topic_emb: tensor of shape [batch_size, embedding_dim].

        Returns:
            The un-activated output of the final linear layer, shape [batch_size, 1].
        """
        # Each count becomes a one-feature time step: [batch_size, sequence_length, 1].
        sequence = counts[..., None]
        hidden_states, _ = self.lstm(sequence)

        # Summarise the sequence by the LSTM output at the last time step.
        last_state = hidden_states[:, -1]
        counts_branch = torch.relu(self.fc_counts(last_state))

        topic_branch = torch.relu(self.fc_topic(topic_emb))

        # Counts features first, topic features second, matching fc_combined's input layout.
        fused = torch.cat([counts_branch, topic_branch], dim=1)
        return self.fc_combined(fused)

In [69]:
model_path = "../data/models/v2/best_model.pth"

# The constructor arguments must describe the same layer sizes as the saved state_dict.
model = Model(embedding_dim=384, counts_dim=7, counts_hidden_size=64)

# map_location="cpu": the model is built on CPU, so a checkpoint saved from a GPU must be remapped.
# weights_only=True: restrict unpickling to tensors/containers rather than arbitrary objects.
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)

summary(model)

Layer (type:depth-idx)                   Param #
Model                                    --
├─LSTM: 1-1                              17,152
├─Linear: 1-2                            4,160
├─Linear: 1-3                            147,840
├─Linear: 1-4                            449
Total params: 169,601
Trainable params: 169,601
Non-trainable params: 0

# Inference

In [70]:
def get_predictions(model, dataloader):
    """
    Run ``model`` over every batch of ``dataloader`` and collect its binary predictions.

    Args:
        model: module called as ``model(batch["counts"], batch["embedding"])`` that
            returns one logit per sample.
        dataloader: iterable of dict batches with the keys ``"counts"`` and ``"embedding"``.

    Returns:
        ``(probabilities, predictions)``: two 1-D NumPy arrays with one entry per
        sample, in iteration order. ``probabilities`` holds the sigmoid of the logits,
        ``predictions`` holds 1 where the probability is above 0.5 and 0 otherwise.
    """
    model.eval()
    all_probs = []
    all_preds = []

    with torch.no_grad():
        for batch in dataloader:
            logits = model(batch["counts"], batch["embedding"])
            # reshape(-1) rather than squeeze(): a batch holding a single sample must
            # stay 1-D, otherwise tolist() yields a scalar and extend() fails.
            probs = torch.sigmoid(logits).reshape(-1)
            preds = probs > 0.5
            all_probs.extend(probs.cpu().tolist())
            all_preds.extend(preds.cpu().int().tolist())

    # Explicit dtypes keep the result types stable even when the loader is empty.
    probabilities = np.array(all_probs, dtype=float)
    predictions = np.array(all_preds, dtype=int)
    return probabilities, predictions

In [73]:
# No shuffling, so the outputs follow the row order of the dataset.
dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=False)
probabilities, predictions = get_predictions(model=model, dataloader=dataloader)

In [76]:
# Binary labels (probability > 0.5), one per dataset row.
predictions

array([1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0,
       0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1,
       0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0,
       0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0])

In [77]:
# One row per topic; relies on the predictions being in dataset order.
preds_df = pd.DataFrame(
    dict(
        topics=[sample["topic_id"] for sample in dataset],
        probability=probabilities,
        pred_matches_trending=predictions.astype(int),
    )
)

preds_df.head()

Unnamed: 0,topics,probability,pred_matches_trending
0,66abfd045d3be4373f77093f,0.614504,1
1,66abfd045d3be4373f770940,0.942212,1
2,66abfd045d3be4373f770941,0.784287,1
3,66abfd045d3be4373f770942,0.537381,1
4,66abfd045d3be4373f770943,0.064832,0
