In [8]:
import typing
import pathlib

import rich.progress
import pandas
import cltrier_lib

In [9]:
# Cache file for the LLM-enriched dataset. The plain (un-enriched) export path is
# derived from it elsewhere by stripping ".enriched", so this must stay a str.
EXPORT_PATH: str = (
    "../data/interim/"
    "twitter.german.dataset.enriched.csv"
)

In [None]:
# Join every reply to the post whose id equals the reply's conversation_id.
# Columns present in both files get the "_reply" / "_post" suffixes (e.g. id -> id_reply / id_post,
# author_id -> author_id_reply / author_id_post).
# validate="many_to_one" raises if a post id occurs more than once in the posts file, which would
# otherwise silently duplicate the matching replies.
# NOTE(review): with how="left", replies without a matching post keep NaN in all *_post columns
# (including text_post), and those rows are later sent to the LLM — confirm that is intended.
dataset: pandas.DataFrame = (
    pandas.merge(
        pandas.read_csv("../data/interim/twitter.german.replies.csv", index_col=0),
        pandas.read_csv("../data/interim/twitter.german.posts.csv", index_col=0),
        how="left",
        left_on="conversation_id",
        right_on="id",
        suffixes=("_reply", "_post"),
        validate="many_to_one",
    )
    # Unsuffixed author metadata — presumably only present in the posts file; verify.
    # NOTE(review): "author_post" (renamed from username) is not in the column selection below,
    # so it is dropped from the result.
    .rename(columns=dict(username="author_post", first_name="author_first_name_post", last_name="author_last_name_post", party="author_party_post"))
    [["id_post", "id_reply", "author_id_post", "author_id_reply", "author_first_name_post", "author_last_name_post", "author_party_post", "text_post", "text_reply"]]
)
dataset

In [30]:
# Export the merged dataset *without* LLM enrichment, next to the enriched cache file.
# The enrichment loop rebinds `dataset` to a frame carrying the topics_* columns, so re-running this
# cell after it would leak those columns into the un-enriched export. Drop them defensively; this is
# a no-op when they are absent (e.g. on a clean Restart & Run All).
dataset.drop(columns=["topics_post", "topics_reply"], errors="ignore").to_csv(EXPORT_PATH.replace(".enriched", ""))

In [27]:
# System prompt for topic extraction; the tweet itself is sent as a separate user message.
# Built from explicit pieces so the trailing spaces and line breaks of the prompt are visible.
topic_extraction_instruction: str = (
    "Your task is to extract the main topics of the given tweet. "
    "Summarize topics exceeding 10 characters. "
    "Keep the total number of topics to 3 or fewer. "
    "\n\n"
    "Respond only with the topic names separated by commas. "
    "Omit any justification. "
    "This is the tweet: "
    "\n"
)

In [None]:
# Enrich the dataset with LLM-extracted topics for the post text and the reply text.
# EXPORT_PATH acts as a resumable cache: a column already present there is loaded, not regenerated,
# and the file is rewritten after each newly generated column.
# NOTE(review): an existing cache file takes precedence over the freshly merged `dataset`;
# delete EXPORT_PATH to force a full rebuild.
for new_col, source_col, instruction in [
    ("topics_post", "text_post", topic_extraction_instruction),
    ("topics_reply", "text_reply", topic_extraction_instruction),
]:
    # Re-read the cache each iteration so fresh and resumed runs see identically parsed CSV data.
    if pathlib.Path(EXPORT_PATH).is_file():
        dataset = pandas.read_csv(EXPORT_PATH, index_col=0)

    if new_col not in dataset.columns:
        # Build the pipeline once per column instead of once per row (loop-invariant), and only
        # when inference is actually needed.
        pipeline = cltrier_lib.inference.Pipeline()

        # One chat per row: instruction as system message, tweet text as user message. The result
        # is indexed with [-1]; presumably its last message is the model's reply — verify.
        # NOTE(review): NaN source texts (e.g. replies without a matching post) are passed through
        # as-is — confirm the pipeline accepts them.
        predictions: typing.List[str] = [
            pipeline(
                chat=cltrier_lib.inference.schemas.Chat(messages=[
                    cltrier_lib.inference.schemas.Message(role="system", content=instruction),
                    cltrier_lib.inference.schemas.Message(role="user", content=content),
                ])
            )[-1].content
            for content in rich.progress.track(dataset[source_col], description=f"Extracting {new_col}")
        ]

        dataset = dataset.assign(**{new_col: predictions})
        # Checkpoint after every generated column so finished columns survive a later failure.
        dataset.to_csv(EXPORT_PATH)

    display(dataset[new_col].value_counts())
    

In [None]:
# Unique topics extracted from the posts (first-occurrence order), as input for the manual (human)
# filtering and reduction of topics and for comparison/unification with topics_reply to improve the
# dataset quality. Missing predictions (NaN, e.g. empty answers read back from CSV) and empty entries
# caused by stray or trailing commas are dropped, so they do not clutter the review list.
list(
    dataset["topics_post"]
    .str.split(",")
    .explode()
    .str.strip()
    .dropna()
    .loc[lambda topics: topics != ""]
    .drop_duplicates()
)