In [17]:
import jsonlines
from bunkatopics.datamodel import Document, Term

from bunkatopics.datamodel import TopicRanking, BourdieuDimension, Term
from pydantic import BaseModel, Field
import typing as t


# Every identifier in the data model is a plain string; the three aliases exist
# only so that field annotations say which kind of id they carry.
DOC_ID = TERM_ID = TOPIC_ID = str


class Document(BaseModel):
    """Notebook-local pydantic schema for one exported document record.

    This definition shadows the ``Document`` imported from
    ``bunkatopics.datamodel``; per the original inline note, the point of
    re-declaring it is to make ``topic_ranking`` optional. Only ``doc_id`` and
    ``content`` are required, every other field has a default.
    """

    # Required fields.
    doc_id: DOC_ID
    content: str

    # Optional numeric attributes; None when the record does not provide them.
    size: t.Union[float, None] = None
    x: t.Union[float, None] = None
    y: t.Union[float, None] = None

    # Topic-related attributes, all optional.
    topic_id: t.Union[TOPIC_ID, None] = None
    topic_ranking: t.Union[TopicRanking, None] = None  # deliberately optional
    term_id: t.Union[t.List[TERM_ID], None] = None

    # Embedding vector; repr=False keeps it out of the model's repr.
    embedding: t.Union[t.List[float], None] = Field(default=None, repr=False)

    # Defaults to an empty list when the record has no such entries.
    bourdieu_dimensions: t.List[BourdieuDimension] = []



# Define a function to read documents from a JSONL file
def read_documents_from_jsonl(file_path):
    """Load every record of a JSON Lines file as a ``Document``.

    Args:
        file_path: Location of the JSONL file; handed unchanged to
            ``jsonlines.open`` in read mode.

    Returns:
        A list holding one ``Document`` per record, in file order. Each record
        is unpacked into keyword arguments for the ``Document`` constructor.
    """
    with jsonlines.open(file_path, mode="r") as jsonl_reader:
        return [Document(**record) for record in jsonl_reader]

In [18]:
# Load the exported documents: one Document per record of the JSONL file.
documents = read_documents_from_jsonl(file_path="exports/bunka_docs_lemonde.jsonl")


In [19]:
import pandas as pd

# Only the id and the embedding are used downstream, so build the frame from
# those two attributes directly instead of dumping every model field for every
# document and then discarding most of the columns. Naming the columns
# explicitly also yields a well-formed (empty, two-column) frame when
# `documents` is empty, where selecting the columns afterwards would raise.
df_embedding = pd.DataFrame(
    {
        'doc_id': [doc.doc_id for doc in documents],
        'embedding': [doc.embedding for doc in documents],
    },
    columns=['doc_id', 'embedding'],
)

In [21]:
# Load the per-topic export (the first CSV column becomes the index) and derive
# a compact label: the first seven ' | '-separated pieces of the topic name,
# joined with '-'.
df_topics = pd.read_csv('exports/df_topics_top_docs.csv', index_col=[0]).assign(
    short_name=lambda frame: frame['topic_name'].map(
        lambda name: '-'.join(name.split(' | ')[:7])
    )
)

In [22]:
# Attach the topic columns to each embedding row; rows are matched on 'doc_id'.
df_final = df_embedding.merge(df_topics, on='doc_id')

In [23]:
# Sanity-check the merge by displaying the first three joined rows.
df_final.head(n=3)

Unnamed: 0,doc_id,embedding,content,ranking_per_topic,topic_id,topic_name,short_name
0,f9d2e3e9-81b5-4072-a,"[-0.03271692246198654, 0.06378008425235748, 0....","Discrètement, le gouvernement a prévu de relev...",810,bt-15,entreprises | logement | emploi | crise | chôm...,entreprises-logement-emploi-crise-chômage-entr...
1,7c5d0504-0854-4f2d-8,"[0.026198111474514008, 0.0498882457613945, -0....",GUIDE,161,bt-20,SEMAINE | GUIDE | week | GALERIES | end | TRAV...,SEMAINE-GUIDE-week-GALERIES-end-TRAVERS-entrées
2,8a73a0d0-296e-4de4-a,"[-0.05174738168716431, 0.0240201186388731, 0.0...",Le Sénat fait la chasse aux fraudeurs à la red...,997,bt-7,RÉFORME | LOI | ASSEMBLÉE | SYNDICATS | PROJET...,RÉFORME-LOI-ASSEMBLÉE-SYNDICATS-PROJET-députés...


In [24]:
import numpy as np
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.preprocessing import LabelEncoder


In [25]:
# Feature matrix: one row per document, built from the per-row embedding lists.
X = np.array(list(df_final['embedding']))

# Target column: the raw topic ids.
y = df_final['topic_id']

# Learn the topic-id vocabulary first, then map every topic id to its integer
# class label.
label_encoder = LabelEncoder().fit(y)
y_encoded = label_encoder.transform(y)

In [26]:
# Hold out 30% of the rows for evaluation; the fixed seed makes the split
# repeatable.
# NOTE(review): the split is not stratified by topic — confirm that is intended.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y_encoded,
    test_size=0.3,
    random_state=42,
)

# Multi-class XGBoost classifier with one class per distinct topic id, fitted on
# the training portion only.
model = XGBClassifier(objective='multi:softmax', num_class=len(set(y)))
model.fit(X_train, y_train)

# Encoded label predictions for the held-out rows.
y_pred = model.predict(X_test)

In [10]:

# Map the integer predictions back to their topic ids and display the result.
y_pred_decoded = label_encoder.inverse_transform(y=y_pred)
y_pred_decoded

array(['bt-23', 'bt-1', 'bt-0', ..., 'bt-10', 'bt-8', 'bt-0'],
      dtype=object)

In [17]:
# Recover the original topic ids for both the ground truth and the predictions.
y_test_decoded = label_encoder.inverse_transform(y_test)
y_pred_decoded = label_encoder.inverse_transform(y_pred)

# Overall accuracy on the held-out rows, computed on the encoded labels.
accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy}')

# Topic ids known to the encoder, in the order of their integer codes.
target_names = label_encoder.classes_

# topic_id -> short_name lookup; if a topic id appears on several rows, the
# last row's short name is the one kept.
topic_short_dict = dict(zip(df_final['topic_id'], df_final['short_name']))

# Short display name for every encoder class, in encoder order.
target_names_short_names = [topic_short_dict[topic] for topic in target_names]

Accuracy: 0.6428645563051375


array(['bt-0', 'bt-1', 'bt-10', 'bt-14', 'bt-15', 'bt-16', 'bt-18',
       'bt-2', 'bt-20', 'bt-21', 'bt-22', 'bt-23', 'bt-3', 'bt-4', 'bt-6',
       'bt-7', 'bt-8', 'bt-9'], dtype=object)

In [23]:

# Per-class precision / recall / F1 on the held-out rows, labelled with the
# short topic names. `labels` is passed explicitly so that each entry of
# `target_names_short_names` is paired with its own topic id even when a topic
# is missing from both the held-out labels and the predictions; without it
# scikit-learn infers the label set from the data and rejects a `target_names`
# list whose length does not match.
report = classification_report(
    y_test_decoded,
    y_pred_decoded,
    labels=target_names,
    target_names=target_names_short_names,
)
print(report)

                                                                     precision    recall  f1-score   support

     Révolution-COMBAT-démocratie-génération-histoire-femmes-presse       0.47      0.50      0.49      1890
FAIM-MANIFESTATIONS-migrants-COUR-milliers-MANIFESTATION-avortement       0.47      0.40      0.43      1368
         ANS-Questions-ordinateur-siècle-III-TÉLÉVISION-principales       0.54      0.45      0.49      1671
          MORTS-ACCIDENTS-INCENDIE-blessés-AVION-Airbus-inondations       0.74      0.56      0.64      1042
          entreprises-logement-emploi-crise-chômage-entreprise-euro       0.53      0.56      0.55      1754
     CONFÉRENCE-ALGÉRIE-visite-président-Calédonie-RELATIONS-traité       0.56      0.63      0.60      1871
                 COUPE-TOUR-CHAMPIONNATS-MONDE-Mondial-ÉQUIPE-Bleus       0.86      0.73      0.79       975
            gauche-droite-ÉLECTIONS-Pen-SOCIALISTES-MAJORITÉ-Macron       0.65      0.69      0.67      2066
                  

In [11]:

# Classification report for the held-out rows, using the short topic names as
# row labels. Passing `labels` pins the row order to the encoder's classes, so
# the names stay aligned with their topic ids even if some topic never occurs
# in the held-out labels or predictions (scikit-learn otherwise infers the
# labels from the data and raises when the counts differ).
report = classification_report(
    y_test_decoded,
    y_pred_decoded,
    labels=target_names,
    target_names=target_names_short_names,
)
print(report)

Accuracy: 0.6428645563051375
              precision    recall  f1-score   support

        bt-0       0.47      0.50      0.49      1890
        bt-1       0.47      0.40      0.43      1368
       bt-10       0.54      0.45      0.49      1671
       bt-14       0.74      0.56      0.64      1042
       bt-15       0.53      0.56      0.55      1754
       bt-16       0.56      0.63      0.60      1871
       bt-18       0.86      0.73      0.79       975
        bt-2       0.65      0.69      0.67      2066
       bt-20       0.91      0.65      0.76       724
       bt-21       0.77      0.80      0.78      1894
       bt-22       0.73      0.72      0.73      1687
       bt-23       0.58      0.60      0.59      1706
        bt-3       0.78      0.74      0.76      1719
        bt-4       0.94      0.62      0.75       796
        bt-6       0.69      0.79      0.74      2168
        bt-7       0.57      0.58      0.57      2050
        bt-8       0.62      0.76      0.68      229