In [None]:
from datasets import load_dataset
from label_generation import generate_labels, save_generation_output
from label_clustering import cluster_labels_gpt, make_clustering_prompt
import os
import json

## 1. Data Pre-processing

In [None]:
# Pull the astronomy-query dataset from the Hugging Face Hub.
# Access requires authentication first, e.g. run `huggingface-cli login` in a terminal.
HF_DATASET_ID = "jhu-clsp/astro-llms-full-query-data"
ds = load_dataset(HF_DATASET_ID)

In [None]:
# Build two dicts keyed by "<thread_ts>_<segment_id>":
#   data_processed -> the full user query text (input to label generation)
#   label_gold     -> a one-element list holding the human "Open Coding" label (used in evaluation)
data_processed = {}
label_gold = {}
full_data = True  # set to False for a quick demo run on the first `n_demo_samples` rows
n_demo_samples = 25
train_split = ds['train']
# min() keeps the demo path from indexing past the end of a small split
n_samples = len(train_split) if full_data else min(n_demo_samples, len(train_split))
# Each query is treated as a single segment, so the segment index is always 0.
segment_id = 0
n_duplicates = 0
for i in range(n_samples):
    row = train_split[i]  # fetch each row once instead of re-indexing the split for every field
    thread_id = str(row['thread_ts']) + '_' + str(segment_id)
    if thread_id in data_processed:
        # A repeated thread_ts would silently replace an earlier row; count it so it is visible.
        n_duplicates += 1
    data_processed[thread_id] = row['full_user_query']
    label_gold[thread_id] = [row['Open Coding']]
if n_duplicates:
    print(f"Warning: {n_duplicates} row(s) reused an earlier thread id and overwrote it.")
# Show the first 3 items as a sanity check
print("Examples of processed data:")
list(data_processed.items())[:3]

## 2. Label Generation

In [None]:
# Pipeline configuration: models and output locations (edit the directories for your setup).
config = dict(
    # Model used for label generation
    model_name="gpt-4o-mini",
    # Directory passed to save_generation_output for generation results
    generation_output_dir="./results/generation",  # Change Me
    # Model used for clustering
    cluster_model_name="gpt-4o-mini",
    # Directory for clustering results (presumably read by cluster_labels_gpt via `config`)
    cluster_output_dir="./results/clustering",  # Change Me, set to current directory for demo
)

In [None]:
# Define the system prompt for label generation.
# `background` describes the deployment context and what "query type" means for this task.
background = """
Query type refers to the strategy, motivation or knowledge solicited by the user. It is NOT about the topic of the query content. 
An LLM-powered bot is deployed for scientists to query literature in astronomy and then to analyze scientists' initial interactions. 
"""

# The labelling objective; it is interpolated twice into `system_prompt` below.
coding_goal = "understanding the query type to the literature search bot from the astronomy scientists."

# System prompt passed to generate_labels. It asks for labels only, one
# "LABEL: [...]" line per label, and "LABEL: [Irrelevant]" for irrelevant input.
# NOTE(review): the \" escapes are redundant inside a triple-quoted string; they
# still produce plain double quotes, so the prompt text is unaffected.
system_prompt = f"""
{background}

We are using the queries to this bot to conduct INDUCTIVE Coding. The labeling aims to {coding_goal}

Instruction:
- Label the input only when it is HIGHLY RELEVANT and USEFUL for {coding_goal}.
- Then, define the phrase of the label. The label description should be observational, concise and clear.
- ONLY output the label and DO NOT output any explanation.

Format:
- Define the label using the format \"LABEL: [The phrase of the label]\". 
- If there are multiple labels, each label is a new line. 
- If the input is irrelevant, use \"LABEL: [Irrelevant]\". 
"""


In [None]:
# Generate labels for every query in `data_processed` using `system_prompt` and `config`.
# NOTE(review): presumably calls config["model_name"] once per query, so expect this cell
# to be slow/costly on the full dataset.
# Later cells expect gen_result[doc_id]["LLM_Annotation"] to be a list of segments,
# each carrying a "label" list.
gen_result = generate_labels(data_processed, system_prompt, config)

In [None]:
# Save generation results under config["generation_output_dir"]. The returned ID
# identifies this run; it is passed to clustering and used to locate the clustering output.
gen_result_id = save_generation_output(gen_result, config["generation_output_dir"])

## 3. Hierarchical Clustering

In [None]:
# Cluster the generated labels hierarchically, using the astrobot-specific clustering prompt.
dataset = "astrobot"
cluster_prompt = make_clustering_prompt(dataset=dataset)
cluster_result = cluster_labels_gpt(
    gen_result,
    cluster_prompt,
    config,
    gen_result_id=gen_result_id,  # ties the clustering output to this generation run
)


## 4. Evaluation

In [None]:
# Import only the metric helpers this notebook uses (explicit names instead of `import *`).
from metrics import (
    create_mapping,
    get_precision_by_theme,
    get_recall_by_theme,
    segment_precision,
    segment_recall,
    theme_precision,
    theme_recall,
)

In [None]:
# Compare the human "Open Coding" themes with the clusters produced above.
# Gold themes come from `label_gold`, so they cover exactly the rows that were labelled:
# the full split when full_data=True, and only the demo subset otherwise.
# Sorting keeps the printed lists stable across runs (set order varies with string hashing).
gold_themes = sorted({label.lower() for labels in label_gold.values() for label in labels})
# NOTE(review): assumes cluster_result[-1] is the top level of the hierarchy, keyed by theme name — confirm.
pred_themes = sorted({theme.lower() for theme in cluster_result[-1].keys()})
print("==== Gold Themes ====")
for theme in gold_themes:
    print(theme)
print("==== Predicted Themes ====")
for theme in pred_themes:
    print(theme)

### Theme Precision and Recall

In [None]:
# Theme-level metrics: gold and predicted themes are matched by cosine similarity,
# using the cutoff below (passed as cos_sim_thresh). Adjust it to suit your needs.
similarity_threshold = 0.4
theme_prec_score, theme_recall_score = (
    metric(gold_themes, pred_themes, cos_sim_thresh=similarity_threshold)
    for metric in (theme_precision, theme_recall)
)
for metric_label, metric_value in (("Theme Precision", theme_prec_score), ("Theme Recall", theme_recall_score)):
    print(f"{metric_label}: {metric_value:.2f}")

### Segment Precision and Recall

In [None]:
## Optional: reuse a previous generation run instead of regenerating labels.
## Uncomment and fill in the ID that save_generation_output returned for that run.
## The following cells also read f"clustering_{gen_result_id}", so clustering must already exist for that ID.
# gen_result_id = "<your_generation_result_id>" # Provide your generation result ID here
# gen_result_path = os.path.join(config["generation_output_dir"], f"generation_{gen_result_id}.json")
# with open(gen_result_path, 'r') as f:
#     gen_result = json.load(f)

#### Mapping themes to segments

In [None]:
# Load the label -> theme mapping written by the clustering step for this run.
# The path is built from config["cluster_output_dir"] (the setting passed to the clustering step),
# so changing that directory in the config cell no longer breaks this cell.
# NOTE(review): assumes cluster_labels_gpt writes to <cluster_output_dir>/clustering_<gen_result_id> — confirm.
clustering_dir = os.path.join(config["cluster_output_dir"], f"clustering_{gen_result_id}")
cluster_mapping = create_mapping(clustering_dir)

In [None]:
# Attach the final theme(s) of each generated label to its segment.
# Labels are looked up lower-cased in `cluster_mapping` (label -> final theme);
# labels with no mapping are recorded as "irrelevant".
# The theme list is rebuilt from scratch on every run, which keeps the cell idempotent.
# A segment with no labels gets an empty theme list instead of raising KeyError.
for doc_id, doc in gen_result.items():
    for segment in doc["LLM_Annotation"]:
        themes = set()
        for label in segment["label"]:
            theme = cluster_mapping.get(label.lower())
            themes.add("irrelevant" if theme is None else theme)
        segment["theme"] = list(themes)

In [None]:
# Attach the human gold label(s) to every generated segment.
# NOTE(review): looks up label_gold with "<doc_id>_<segment index>". This assumes the keys of
# gen_result are thread ids without the "_0" suffix added during pre-processing — confirm.
for doc_id, doc in gen_result.items():
    for seg_idx, segment in enumerate(doc["LLM_Annotation"]):
        segment["gold_label"] = label_gold[f"{doc_id}_{seg_idx}"]

#### Evaluate segment-level metrics

In [None]:
# Segment-level metrics computed from the themes and gold labels attached to gen_result.
# The threshold below is the similarity cutoff used for matching; adjust it to suit your needs.
similarity_threshold = 0.4
match_kwargs = {"similarity_threshold": similarity_threshold}
prec_by_theme = get_precision_by_theme(gold_themes, pred_themes, gen_result, **match_kwargs)
seg_prec_score = segment_precision(prec_by_theme)
recall_by_theme = get_recall_by_theme(gold_themes, pred_themes, gen_result, **match_kwargs)
seg_recall_score = segment_recall(recall_by_theme)
for metric_label, metric_value in (("Segment Precision", seg_prec_score), ("Segment Recall", seg_recall_score)):
    print(f"{metric_label}: {metric_value:.2f}")