In [7]:
# Upper bound on simultaneous LLM requests: sizes both the httpx connection pool
# and the asyncio semaphore used when gathering tasks.
MAX_CONCURRENCY = 256

In [8]:
import os
import re
from datetime import datetime

# Per-provider sub-directory for each data "kind"; get_filenames joins the value
# under ../_data/<kind>/.
PROVIDERS = dict(
    google=dict(
        raw="Takeout/My Activity/Search",
        parsed="google/search_history",
        summary="google/search_history_summary",
        context="",
        interests="",
    ),
)


def get_filenames(
    kind="parsed", start_date=None, end_date=None, provider="google"
):
    """Return paths of per-day data files for a provider, optionally date-filtered.

    Walks ``../_data/<kind>/<PROVIDERS[provider][kind]>`` recursively and keeps
    files named ``YYYY-MM-DD.csv`` or ``YYYY-MM-DD.json``.

    Args:
        kind: Key into ``PROVIDERS[provider]``; also the first directory level
            under ``../_data``.
        start_date: Inclusive lower bound as ``"YYYY-MM-DD"``, or None for no
            lower bound.
        end_date: Inclusive upper bound as ``"YYYY-MM-DD"``, or None for no
            upper bound.
        provider: Key into ``PROVIDERS``.

    Returns:
        Sorted list of matching file paths ([] if the directory does not exist).

    Raises:
        KeyError: Unknown ``provider`` or ``kind``.
        ValueError: A bound (or, when filtering, a file-name date) is not a
            valid ``YYYY-MM-DD`` date.
    """
    directory = os.path.join("..", "_data", kind, PROVIDERS[provider][kind])
    start = None if start_date is None else datetime.strptime(start_date, "%Y-%m-%d")
    end = None if end_date is None else datetime.strptime(end_date, "%Y-%m-%d")
    file_pattern = re.compile(r"^(\d{4}-\d{2}-\d{2})\.(csv|json)$")

    def is_date_in_range(file_date):
        """True if ``file_date`` ("YYYY-MM-DD") lies within the inclusive bounds."""
        if start is None and end is None:
            # No filtering requested; accept without parsing the date.
            return True
        day = datetime.strptime(file_date, "%Y-%m-%d")
        # Check each bound on its own so a single open-ended bound works
        # (comparing a datetime against None would raise TypeError).
        return (start is None or start <= day) and (end is None or day <= end)

    filenames = []
    for root, _dirs, files in os.walk(directory):
        for file in files:
            match = file_pattern.match(file)
            if match and is_date_in_range(match.group(1)):
                filenames.append(os.path.join(root, file))

    # os.walk order depends on the filesystem; sort for reproducible runs.
    return sorted(filenames)

In [9]:
from openai import OpenAI
import os
import pandas as pd

In [10]:
from openai import AsyncOpenAI
import httpx

# Async OpenAI-compatible client for the self-hosted inference endpoint.
# The httpx pool allows MAX_CONCURRENCY connections (kept alive), matching the
# semaphore used when the tasks are gathered; the client timeout is 600 s.
# NOTE(review): no api_key is passed, so openai-python falls back to the
# OPENAI_API_KEY environment variable and sends it to this endpoint — confirm
# that is not a real OpenAI key.
custom_client = AsyncOpenAI(
    http_client=httpx.AsyncClient(
        limits=httpx.Limits(
            max_connections=MAX_CONCURRENCY,
            max_keepalive_connections=MAX_CONCURRENCY,
        ),
        timeout=60 * 10,
    ),
    # The proxy URL is tied to a single pod; allow overriding it via the
    # environment instead of editing the notebook. The default is unchanged.
    base_url=os.environ.get(
        "LLM_BASE_URL", "https://wao06rxq3acms1-8000.proxy.runpod.net/v1"
    ),
)

no_match = 0
errors = 0

async def summarize_interests(prompt):
  global custom_client

  summarization_prompt = "What interests can you find in the following search records? \n"

  try:
    answer1 = await custom_client.chat.completions.create(
      model="mistralai/Mistral-7B-Instruct-v0.2",
      messages=[
        {"role": "user", "content": summarization_prompt+prompt},
      ]
    )

    answer2 = await custom_client.chat.completions.create(
      model="mistralai/Mistral-7B-Instruct-v0.2",
      messages=[
        {"role": "user", "content": summarization_prompt+prompt}, 
        {"role": "assistant", "content": answer1.choices[0].message.content},
        {"role": "user", "content": "Summarize the previous answer as a comma-separated array of strings."},
      ]
    )

    raw = answer2.choices[0].message.content


    match = re.search(r'\[(.*?)\]', raw)
    if match:
        # If a match is found, split the substring by comma
        return match.group(1).replace("\"", "").replace("'","").split(",")
    else:
        global no_match
        no_match += 1
        return []

  except Exception as e:
    global errors
    errors += 1
    return []
  

In [11]:
from collections import defaultdict
from tqdm.asyncio import tqdm_asyncio


# Number of search-record titles sent to the LLM in one request.
chunk_size = 35

interests = defaultdict(list)  # NOTE(review): unused in the cells shown — confirm before removing
# date string ("YYYY-MM-DD") -> pending summarize_interests(...) coroutines.
# They are only created here; the next cell awaits them, and each can be awaited
# once, so re-run this cell before re-running the gather cell.
tasks_dict = defaultdict(list)

for filename in get_filenames():
    df = pd.read_csv(filename)
    # Take the day from the file name portably: get_filenames builds paths with
    # os.path.join, so splitting on "/" would break on Windows.
    date = os.path.splitext(os.path.basename(filename))[0]

    # Missing titles load as NaN floats, which would make "\n".join raise TypeError.
    inputs = df["title"].dropna().astype(str).tolist()

    for i in range(0, len(inputs), chunk_size):
        tasks_dict[date].append(summarize_interests("\n".join(inputs[i:i + chunk_size])))


In [None]:
from asyncio import Semaphore

wrapped_tasks = []

async def wrap_task_with_date(sem, date, t):
    async with sem:
        result = await t
        return (date, result)

sem = Semaphore(MAX_CONCURRENCY)
for date, tasks in tasks_dict.items():
    wrapped_tasks.extend([wrap_task_with_date(sem, date, task) for task in tasks])

# Await all wrapped tasks

import json

results_dict = defaultdict(list)
results = await tqdm_asyncio.gather(*wrapped_tasks, smoothing=0)
    
for date, result in results:
    results_dict[date].extend(result)
    json.dump(results_dict[date], open(f"../_data/interests/{date}.json", "w"))

In [None]:
# Requests whose reply had no bracketed list, and requests that raised.
(no_match, errors)

(109, 910)