In [1]:
# Path to the TOML file that drives this notebook. Later cells load it with
# toml.load and read its "use_case", "fields" and "data_sources" entries.
CONFIG_FILE = "configs/ivf.toml"

# TODO: Do a first pass to get examples and fields automatically

In [2]:
import toml
import pandas as pd
from openai import OpenAI
from datasource import RedditDataSource
from processor import Processor, FILTER_PROMPT
import visualizer
from importlib import reload

# --- Analysis configuration -------------------------------------------------
config = toml.load(CONFIG_FILE)
ANALYSIS_USE_CASE = config["use_case"]
FIELDS = config["fields"]

# --- Credentials ------------------------------------------------------------
# API keys and endpoints are read from a separate secrets.toml file, one table
# per provider, rather than being written into the notebook.
secrets = toml.load("secrets.toml")

DEEPINFRA = "deepinfra"
OPENAI = "openai"


def _make_llm_client(provider):
    """Build an OpenAI-SDK client from the `provider` table of secrets.toml.

    A missing table or missing key is passed to the client as None.
    """
    provider_secrets = secrets.get(provider, {})
    return OpenAI(
        api_key=provider_secrets.get("api_key"),
        base_url=provider_secrets.get("base_url"),
    )


DEEPINFRA_LLM = _make_llm_client(DEEPINFRA)
OPENAI_LLM = _make_llm_client(OPENAI)

# --- Processing pipeline ----------------------------------------------------
# Two clients/models are wired in: a cheaper one for the filtering pass and a
# more accurate one for field extraction.
processor = Processor(
    use_case_description=ANALYSIS_USE_CASE,
    extraction_schema=FIELDS,
    filter_prompt=FILTER_PROMPT,
    filter_llm_client=DEEPINFRA_LLM,
    filter_model="meta-llama/Llama-3.3-70B-Instruct-Turbo",
    extract_llm_client=OPENAI_LLM,
    extract_model="gpt-4o",
)

In [3]:
# Build one data-source object per entry of the config's "data_sources" list.
# Each entry's "type" key selects the class; every other key in the entry is
# forwarded to that class's constructor as a keyword argument.
data_sources = []
for ds_conf in config["data_sources"]:
    # Pop "type" from a shallow copy, not from the entry itself: removing the
    # key from `config` would make a second run of this cell see type=None for
    # every source, report them all as unknown and leave `data_sources` empty.
    ds_kwargs = dict(ds_conf)
    ds_type = ds_kwargs.pop("type", None)
    if ds_type == "reddit":
        data_sources.append(RedditDataSource(**ds_kwargs))
    else:
        # Unrecognised (or missing) type: report it and skip the entry.
        print(f"Unknown data source type: {ds_type}")

In [4]:
# Fetch one frame per configured source, stack them under a fresh 0..n-1
# index, then keep only the first row seen for each 'id'.
dfs = [src.get_data() for src in data_sources]
stacked = pd.concat(dfs, ignore_index=True)
df = stacked.drop_duplicates(subset='id')

In [5]:
# Quick-test variant (100 random rows, fixed seed) -- swap in for the line below:
# df_filtered = processor.filter_data(df.sample(n=100, random_state=42))

# Step 1 -- first-level filtering over the whole frame.
df_filtered = processor.filter_data(df)

# Step 2 -- structured-field extraction on the filtered frame.
df_extracted = processor.extract_fields(df_filtered)

# Step 3 -- second-level filtering: keep rows flagged as relevant samples.
samples = df_extracted.loc[df_extracted['relevant_sample'] == True]

# Row counts at each stage of the funnel.
print(f"Samples: initial={len(df)}, after 1st filter={len(df_filtered[df_filtered['is_relevant']==True])}, final={len(samples)}")

Filtering data: 100%|██████████| 100/100 [00:51<00:00,  1.96it/s]
Filtering data: 100%|██████████| 618/618 [04:37<00:00,  2.23it/s]
Extracting fields: 100%|██████████| 139/139 [03:12<00:00,  1.38s/it]

Samples: initial=618, after 1st filter=139, final=117





In [7]:
# Re-import the visualizer module so edits made to it are picked up without
# restarting the kernel.
reload(visualizer)

# Column used by several of the views below.
sentiment_col = 'sentiment'

viz = visualizer.Visualization(samples, FIELDS)

# Per-field plots, then a few sample rows with extra columns shown.
viz.plot_all_fields(show_examples=True)
viz.show_samples(n=3, extra_fields=[sentiment_col, 'use_cases'])

# Time series, per-subreddit mean comparison, and correlation view.
viz.plot_by_time('created_utc', "Posts Over Time")
viz.plot_group_comparison('subreddit', sentiment_col, agg='mean')
viz.plot_correlation([sentiment_col])