In [7]:
# Which analysis to run: the tag selects the TOML file under configs/ and also
# names any backup files written later in the notebook.
CONFIG_TAG = "ai_therapy"
CONFIG_FILE = "configs/{}.toml".format(CONFIG_TAG)

In [8]:
import toml
import pandas as pd
from openai import OpenAI
from datasource import RedditDataSource, HackerNewsDataSource

from processor import Processor, FILTER_PROMPT
import visualizer

secrets = toml.load("secrets.toml")
config = toml.load(CONFIG_FILE)

ANALYSIS_USE_CASE = config["use_case"]
FIELDS = config["fields"]

# Model names may be overridden per analysis via optional `filter_model` /
# `extract_model` keys in the config TOML; the defaults are the models this
# notebook has been run with.
FILTER_MODEL = config.get("filter_model", "meta-llama/Llama-3.3-70B-Instruct-Turbo")  # cheaper model for 1st-level filtering
EXTRACT_MODEL = config.get("extract_model", "meta-llama/Meta-Llama-3.1-405B-Instruct")  # more accurate model for extraction

DEEPINFRA_CONFIG = secrets.get("deepinfra", {})
# Fail fast: if api_key/base_url were passed as None, the OpenAI client would fall
# back to OPENAI_API_KEY / OPENAI_BASE_URL (or api.openai.com), sending the Llama
# model names above to the wrong endpoint — and that would only surface after the
# slow data-fetch cells had already run.
_missing_deepinfra = [key for key in ("api_key", "base_url") if not DEEPINFRA_CONFIG.get(key)]
if _missing_deepinfra:
    raise ValueError(
        "secrets.toml is missing [deepinfra] setting(s): " + ", ".join(_missing_deepinfra)
    )
DEEPINFRA_LLM = OpenAI(
    api_key=DEEPINFRA_CONFIG["api_key"],
    base_url=DEEPINFRA_CONFIG["base_url"],
)

# NOTE(review): this client is not passed to the Processor below (both stages use
# DEEPINFRA_LLM); it is kept for optional use elsewhere.
OPENAI_CONFIG = secrets.get("openai", {})
OPENAI_LLM = OpenAI(
    api_key=OPENAI_CONFIG.get("api_key"),
    base_url=OPENAI_CONFIG.get("base_url"),
)

# Two-stage pipeline: a cheap relevance filter followed by structured extraction
# against the FIELDS schema from the config.
processor = Processor(
    use_case_description=ANALYSIS_USE_CASE,
    filter_prompt=FILTER_PROMPT,
    extraction_schema=FIELDS,
    filter_llm_client=DEEPINFRA_LLM,
    filter_model=FILTER_MODEL,
    extract_llm_client=DEEPINFRA_LLM,
    extract_model=EXTRACT_MODEL,
)

In [3]:
# Map the `type` field of each [[data_sources]] config entry to the class that
# implements it; every other key in the entry is passed to that class as a kwarg.
DATA_SOURCE_CLASSES = {
    "reddit": RedditDataSource,
    "hackernews": HackerNewsDataSource,
}

data_sources = []
for ds_conf in config["data_sources"]:
    # Pop from a copy: mutating the shared config dict would strip "type" and make
    # a re-run of this cell treat every entry as unknown.
    ds_params = dict(ds_conf)
    ds_type = ds_params.pop("type", None)
    ds_class = DATA_SOURCE_CLASSES.get(ds_type)
    if ds_class is None:
        print(f"Unknown data source type: {ds_type}")
        continue
    data_sources.append(ds_class(**ds_params))

In [4]:
# Fetch every configured source and merge into a single frame, keeping the first
# row for each `id`.
# NOTE(review): ids from different sources share one namespace here — confirm
# Reddit and Hacker News ids cannot collide after concatenation.
if not data_sources:
    # pd.concat([]) would otherwise fail with the unhelpful "No objects to concatenate".
    raise ValueError("No data sources loaded; check the [[data_sources]] entries and their `type`.")
dfs = [source.get_data() for source in data_sources]
df = (
    pd.concat(dfs, ignore_index=True)
    .drop_duplicates(subset="id")
    .reset_index(drop=True)  # drop_duplicates leaves gaps; restore a contiguous index
)

In [9]:
# Set SAMPLE_SIZE to an int (e.g. 100) to run the LLM pipeline on a reproducible
# random subset while testing; None processes every post.
SAMPLE_SIZE = None
SAMPLE_SEED = 42

df_input = (
    df.sample(n=min(SAMPLE_SIZE, len(df)), random_state=SAMPLE_SEED)
    if SAMPLE_SIZE
    else df
)
df_filtered = processor.filter_data(df_input)  # 1st-level filtering
df_extracted = processor.extract_fields(df_filtered)  # extraction of structured data
# `.eq(True)` matches the original `== True` semantics (missing values count as not relevant).
samples = df_extracted[df_extracted["relevant_sample"].eq(True)]  # 2nd-level filtering
n_after_first_filter = int(df_filtered["is_relevant"].eq(True).sum())
print(f"Samples: initial={len(df_input)}, after 1st filter={n_after_first_filter}, final={len(samples)}")

Extracting fields: 100%|██████████| 420/420 [12:48<00:00,  1.83s/it]

Samples: initial=1000, after 1st filter=420, final=401





In [None]:
import os

# Back up the intermediate frames so the slow LLM steps need not be re-run.
# Opt-in: flip SAVE_OUTPUTS to True to write out/<tag>{,-filtered,-samples}.json.
SAVE_OUTPUTS = False
OUTPUT_DIR = "out"

if SAVE_OUTPUTS:
    # to_json does not create missing parent directories.
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    for suffix, frame in (("", df), ("-filtered", df_filtered), ("-samples", samples)):
        frame.to_json(f"{OUTPUT_DIR}/{CONFIG_TAG}{suffix}.json", orient="records", date_format="iso")

In [11]:
# Simple plots are created automatically by the visualizer for the relevant
# samples, driven by the same FIELDS schema used for extraction.
viz = visualizer.Visualization(samples, FIELDS)
viz.plot_all_fields(show_examples=True)  # presumably one plot per field in FIELDS — confirm in visualizer
viz.show_samples(n=3, extra_fields=['sentiment','use_cases'])  # a few example posts plus these extracted fields
viz.plot_by_time('created_utc', "Posts Over Time")
# Optional extra views, currently disabled.
# NOTE(review): 'subreddit' is presumably absent for Hacker News rows — confirm before enabling.
# viz.plot_group_comparison('subreddit', 'sentiment', agg='mean')
# viz.plot_correlation(['sentiment'])