In [1]:
import jsonlines
from bunkatopics.datamodel import Document, Term

from bunkatopics.datamodel import TopicRanking, BourdieuDimension, Term
from pydantic import BaseModel, Field
import typing as t


# Identifier aliases: all three are plain strings, named separately only so that
# field annotations say which kind of id they hold.
DOC_ID = TERM_ID = TOPIC_ID = str


class Document(BaseModel):
    """One document record as read from the JSONL export.

    Instances are built by ``read_documents_from_jsonl`` from each JSON line,
    and later dumped to a DataFrame to extract ``doc_id`` and ``embedding``.

    NOTE(review): this local class shadows the ``Document`` imported from
    ``bunkatopics.datamodel`` in the import cell; the inline comment on
    ``topic_ranking`` suggests the redefinition exists to make that field
    optional — confirm against the bunkatopics version in use.
    """

    # Required identifier; the notebook later joins on it with the topics table.
    doc_id: DOC_ID
    # Raw document text.
    content: str
    size: t.Optional[float] = None
    # Presumably 2-D projection coordinates of the document — TODO confirm.
    x: t.Optional[float] = None
    y: t.Optional[float] = None
    topic_id: t.Optional[TOPIC_ID] = None
    topic_ranking: t.Optional[TopicRanking] = None  # Optional here so records without a ranking still validate
    term_id: t.Optional[t.List[TERM_ID]] = None
    # Embedding vector used as the classifier's feature row; excluded from repr
    # so printing a Document does not dump the whole vector.
    embedding: t.Optional[t.List[float]] = Field(None, repr=False)
    # Defaults to an empty list when absent from the input record.
    bourdieu_dimensions: t.List[BourdieuDimension] = []



def read_documents_from_jsonl(file_path):
    """Load every JSON object in a JSON Lines file as a ``Document``.

    Args:
        file_path: Path to a ``.jsonl`` file. Each decoded line is passed as
            keyword arguments to ``Document``, so validation errors propagate.

    Returns:
        A list of ``Document`` instances in file order.
    """
    with jsonlines.open(file_path, mode="r") as reader:
        # The comprehension is fully evaluated before the reader is closed.
        return [Document(**record) for record in reader]

In [2]:
DOCS_JSONL_PATH = "exports/bunka_docs_lemonde.jsonl"
documents = read_documents_from_jsonl(file_path=DOCS_JSONL_PATH)


In [3]:
import pandas as pd

# Keep only the two columns the rest of the notebook needs. Reading the attributes
# directly avoids serialising every field (content, topic_ranking, ...) through
# model_dump() only to throw it away, and `columns=` keeps the frame well-formed
# (zero rows, both columns present) when `documents` is empty.
df_embedding = pd.DataFrame(
    [(doc.doc_id, doc.embedding) for doc in documents],
    columns=['doc_id', 'embedding'],
)

In [4]:
df_topics = pd.read_csv('exports/df_topics_top_docs.csv', index_col=[0])


def _shorten_topic_name(topic_name):
    """Build a compact label from a ' | '-separated topic name.

    Keeps at most the first 7 terms and joins them with '-'.
    """
    leading_terms = topic_name.split(' | ')[:7]
    return '-'.join(leading_terms)


df_topics['short_name'] = df_topics['topic_name'].apply(_shorten_topic_name)

In [5]:
# Inner join on doc_id: documents without a row in df_topics are dropped.
df_final = df_embedding.merge(df_topics, on='doc_id', how='inner')

In [6]:
# Preview the first three merged rows
df_final.iloc[:3]

Unnamed: 0,doc_id,embedding,content,ranking_per_topic,topic_id,topic_name,short_name
0,f9d2e3e9-81b5-4072-a,"[-0.03271692246198654, 0.06378008425235748, 0....","Discrètement, le gouvernement a prévu de relev...",810,bt-15,entreprises | logement | emploi | crise | chôm...,entreprises-logement-emploi-crise-chômage-entr...
1,7c5d0504-0854-4f2d-8,"[0.026198111474514008, 0.0498882457613945, -0....",GUIDE,161,bt-20,SEMAINE | GUIDE | week | GALERIES | end | TRAV...,SEMAINE-GUIDE-week-GALERIES-end-TRAVERS-entrées
2,8a73a0d0-296e-4de4-a,"[-0.05174738168716431, 0.0240201186388731, 0.0...",Le Sénat fait la chasse aux fraudeurs à la red...,997,bt-7,RÉFORME | LOI | ASSEMBLÉE | SYNDICATS | PROJET...,RÉFORME-LOI-ASSEMBLÉE-SYNDICATS-PROJET-députés...


In [7]:
import numpy as np
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.preprocessing import LabelEncoder


In [8]:
# Features: one row per document built from its embedding list
# (a 2-D array when all embeddings have the same length).
X = np.asarray(df_final['embedding'].to_list())

# Target: each document's topic_id, mapped to integer class ids for XGBoost.
y = df_final['topic_id']
label_encoder = LabelEncoder().fit(y)
y_encoded = label_encoder.transform(y)

In [9]:
# Split data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=0.3, random_state=42)

# Create XGBoost classifier
model = XGBClassifier(objective='multi:softmax', num_class=len(set(y)))

# Train the model
model.fit(X_train, y_train)

# Make predictions on the test set
y_pred = model.predict(X_test)

In [14]:

# Decode the predicted class ids back to their original topic_id strings and display them.
# NOTE(review): the evaluation cell that follows recomputes y_pred_decoded identically,
# so this cell is only an inspection step and could be removed.
y_pred_decoded = label_encoder.inverse_transform(y_pred)
y_pred_decoded

array(['bt-23', 'bt-16', 'bt-9', ..., 'bt-9', 'bt-22', 'bt-21'],
      dtype=object)

In [15]:
# Map encoded class ids back to topic_id strings for the per-topic report.
y_pred_decoded, y_test_decoded = (
    label_encoder.inverse_transform(y_pred),
    label_encoder.inverse_transform(y_test),
)

# Headline metric on the held-out split (computed on the encoded labels).
accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy}')

# Row labels for the report: one short name per class, in label_encoder.classes_ order.
target_names = label_encoder.classes_
topic_short_dict = dict(zip(df_final['topic_id'], df_final['short_name']))
target_names_short_names = [topic_short_dict[topic] for topic in target_names]

Accuracy: 0.8179599804782821


In [16]:

# Per-topic precision/recall/F1 with short topic names as row labels.
# `labels` pins the row order to label_encoder.classes_ (the order in which
# target_names_short_names was built), so names stay aligned and the call does
# not fail when a topic is missing from both y_test_decoded and y_pred_decoded.
report = classification_report(
    y_test_decoded,
    y_pred_decoded,
    labels=label_encoder.classes_,
    target_names=target_names_short_names,
)
print(report)

                                                                     precision    recall  f1-score   support

     Révolution-COMBAT-démocratie-génération-histoire-femmes-presse       0.79      0.55      0.65       137
FAIM-MANIFESTATIONS-migrants-COUR-milliers-MANIFESTATION-avortement       0.87      0.63      0.73       167
         ANS-Questions-ordinateur-siècle-III-TÉLÉVISION-principales       0.83      0.73      0.78       175
               ERRATUM-SELECTION-LIGNE-tableau-informations-fr-JOUR       0.99      0.88      0.93       104
                                              INFORMATIONS-dépêches       1.00      1.00      1.00        21
       RECTIFICATIF-tirages-SPORT-ERRATUM-Résultats-samedi-mercredi       0.98      0.96      0.97        47
          MORTS-ACCIDENTS-INCENDIE-blessés-AVION-Airbus-inondations       0.88      0.75      0.81       199
          entreprises-logement-emploi-crise-chômage-entreprise-euro       0.76      0.70      0.73       245
     CONFÉRENCE-AL

In [13]:

# NOTE(review): this cell repeats the classification report of the preceding cell
# and carries an out-of-order execution count (In [13] after In [16]); it should be
# deleted so Restart & Run All yields a single, sequential report.
# `labels` pins rows to label_encoder.classes_ so target_names_short_names stays
# aligned even if a topic is absent from both y_test_decoded and y_pred_decoded.
report = classification_report(
    y_test_decoded,
    y_pred_decoded,
    labels=label_encoder.classes_,
    target_names=target_names_short_names,
)
print(report)

                                                                     precision    recall  f1-score   support

     Révolution-COMBAT-démocratie-génération-histoire-femmes-presse       0.79      0.55      0.65       137
FAIM-MANIFESTATIONS-migrants-COUR-milliers-MANIFESTATION-avortement       0.87      0.63      0.73       167
         ANS-Questions-ordinateur-siècle-III-TÉLÉVISION-principales       0.83      0.73      0.78       175
               ERRATUM-SELECTION-LIGNE-tableau-informations-fr-JOUR       0.99      0.88      0.93       104
                                              INFORMATIONS-dépêches       1.00      1.00      1.00        21
       RECTIFICATIF-tirages-SPORT-ERRATUM-Résultats-samedi-mercredi       0.98      0.96      0.97        47
          MORTS-ACCIDENTS-INCENDIE-blessés-AVION-Airbus-inondations       0.88      0.75      0.81       199
          entreprises-logement-emploi-crise-chômage-entreprise-euro       0.76      0.70      0.73       245
     CONFÉRENCE-AL