In [1]:
import os
import chromadb
from chromadb import Documents, Embeddings, EmbeddingFunction
import sys

from mlx_embedding_models.embedding import EmbeddingModel

sys.path.insert(0, "../")
from db import HindsightDB
from config import DATA_DIR
import utils

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
# Create the HindsightDB handle (class imported from ../db) that the following
# cells use for lock management and for frame / OCR queries.
db = HindsightDB()

In [22]:
# Release the lock named "chromadb" through HindsightDB.
# NOTE(review): the execution counts show this cell ran (In [22]) after the
# acquire cell that follows it (In [21]); on a top-to-bottom run the release
# would happen before the acquire. Consider moving this cell to the end.
db.release_lock("chromadb")

In [21]:
# Acquire the lock named "chromadb"; the recorded output shows the call
# returned True.
# NOTE(review): presumably this keeps another process from ingesting into
# ChromaDB concurrently — confirm against HindsightDB.acquire_lock.
db.acquire_lock("chromadb")

True

In [11]:
# Going by the method name, this fetches frames that have OCR results but are
# not yet flagged as processed into ChromaDB (the displayed rows all show
# chromadb_processed == 0) — confirm against HindsightDB.
# NOTE(review): frames_df is only displayed; the ingestion cells further down
# work from `frames` / `ocr_results_df` instead.
frames_df = db.get_non_chromadb_processed_frames_with_ocr()

In [12]:
# Display the pending-frames table for a quick visual check.
frames_df

Unnamed: 0,id,timestamp,path,application,chromadb_processed
0,100397,1718636489126,/Users/connorparish/code/hindsight/hindsight_s...,com-connor-hindsight,0
1,100398,1718636547237,/Users/connorparish/code/hindsight/hindsight_s...,com-android-systemui,0
2,100399,1718637248112,/Users/connorparish/code/hindsight/hindsight_s...,com-android-systemui,0
3,100400,1718637250186,/Users/connorparish/code/hindsight/hindsight_s...,com-connor-hindsight,0
4,100401,1718637252253,/Users/connorparish/code/hindsight/hindsight_s...,com-android-systemui,0
...,...,...,...,...,...
355,100751,1718640451932,/Users/connorparish/code/hindsight/hindsight_s...,com-google-android-apps-nexuslauncher,0
356,100753,1718640550797,/Users/connorparish/code/hindsight/hindsight_s...,com-connor-hindsight,0
357,100754,1718640558887,/Users/connorparish/code/hindsight/hindsight_s...,com-connor-hindsight,0
358,100755,1718640729905,/Users/connorparish/code/hindsight/hindsight_s...,com-android-systemui,0


In [13]:
# Load all frames, pass them through utils.add_datetimes (the 'datetime_local'
# column sorted on here is presumably added by that helper), and order them
# oldest-first by local time.
frames = (
    utils.add_datetimes(db.get_frames())
    .sort_values(by='datetime_local', ascending=True)
)

# OCR rows used by the ingestion loop; they are matched to frames through
# their 'frame_id' column.
ocr_results_df = db.get_frames_with_ocr()

In [14]:
class MLXEmbeddingFunction(EmbeddingFunction):
    """ChromaDB embedding function backed by an ``mlx_embedding_models`` model."""

    def __init__(self, model_id="bge-base"):
        """Load the model that ``EmbeddingModel.from_registry`` returns for ``model_id``."""
        loaded_model = EmbeddingModel.from_registry(model_id)
        self.embedding_model = loaded_model

    def __call__(self, input: Documents) -> Embeddings:
        """Encode ``input`` with the loaded model and return the result of ``.tolist()``."""
        encoded = self.embedding_model.encode(input)
        return encoded.tolist()

In [15]:
# Build the embedding function with its default model id ("bge-base"); it is
# handed to the Chroma collection when the collection is opened.
embedding_function = MLXEmbeddingFunction()



In [19]:
# The persistent Chroma store lives in the "chromadb" folder under DATA_DIR.
chroma_db_path = os.path.join(DATA_DIR, "chromadb")
chroma_client = chromadb.PersistentClient(path=chroma_db_path)

# Open the screenshots collection (creating it on first use) and bind the MLX
# embedding function to it.
chroma_collection = chroma_client.get_or_create_collection(
    "pixel_screenshots",
    embedding_function=embedding_function,
)

In [20]:
# Ids of frames already stored in the collection. Only the ids are needed, so
# pass include=[] to skip loading every document and metadata record (ids are
# always part of the response).
ingested_ids = [int(stored_id) for stored_id in chroma_collection.get(include=[])['ids']]

In [21]:
# Keep only the frames whose id is not already stored in the Chroma collection.
already_ingested = frames['id'].isin(ingested_ids)
frames = frames.loc[~already_ingested]

In [22]:
def get_screenshot_preprompt(application, timestamp):
    """Build the header that precedes a screenshot's OCR text.

    Args:
        application: Identifier of the application shown in the screenshot.
        timestamp: Timestamp of the screenshot; interpolated as-is.

    Returns:
        A description line followed by a 20-dash separator line, each
        terminated by a newline.
    """
    description = f"Description: Text from a screenshot of {application} with UTC timestamp {timestamp}: \n"
    # The separator must end with a real newline. The previous "/n" was emitted
    # literally, so the dashes ran straight into the first line of OCR text.
    separator = "-" * 20 + "\n"
    return description + separator

def get_chromadb_text(ocr_result, application, timestamp):
    """Compose the document text stored in ChromaDB for one frame.

    The screenshot preprompt comes first, followed by the OCR results rendered
    to a string by ``utils.ocr_results_to_str``.
    """
    ocr_text = utils.ocr_results_to_str(ocr_result)
    header = get_screenshot_preprompt(application, timestamp)
    return header + ocr_text

def get_chromadb_metadata(row):
    """Return the ChromaDB metadata dict for a frame row.

    ``row`` is indexed by the keys 'id', 'application' and 'timestamp'; the id
    is stored under the metadata key 'frame_id'.
    """
    frame_id = row['id']
    application = row['application']
    timestamp = row['timestamp']
    return {
        "frame_id": frame_id,
        "application": application,
        "timestamp": timestamp,
    }

In [26]:
# Ingest the not-yet-stored frames into ChromaDB in fixed-size batches.
batch_size = 1000
num_batches = (len(frames) + batch_size - 1) // batch_size  # ceiling division
for batch_index in range(num_batches):
    print(batch_index)
    # Extract the batch
    start_index = batch_index * batch_size
    end_index = start_index + batch_size
    frames_batch = frames.iloc[start_index:end_index]

    # Select this batch's OCR rows in a single pass and group them by frame,
    # instead of re-scanning all of ocr_results_df once per frame.
    batch_ocr = ocr_results_df.loc[ocr_results_df['frame_id'].isin(frames_batch['id'])]
    ocr_by_frame = {
        frame_id: frame_ocr
        for frame_id, frame_ocr in batch_ocr.groupby('frame_id', sort=False)
    }

    documents = list()
    metadatas = list()
    ids = list()
    # A distinct loop variable is used here: the inner loop previously reused
    # `i`, shadowing the batch counter.
    for frame_index, row in frames_batch.iterrows():
        ocr_result = ocr_by_frame.get(row['id'])
        # Skip frames that have no OCR rows or whose OCR text is entirely None.
        if ocr_result is None or set(ocr_result['text']) == {None}:
            continue
        documents.append(get_chromadb_text(ocr_result=ocr_result, application=row['application'], timestamp=row['timestamp']))
        metadatas.append(get_chromadb_metadata(row))
        ids.append(str(row['id']))

    # Nothing to add for this batch (every frame was skipped).
    if len(documents) == 0:
        continue

    chroma_collection.add(
        documents=documents,
        metadatas=metadatas,
        ids=ids
    )

0
1
2
3
4


# Try querying

In [27]:
# Semantic search over the ingested screenshots: fetch the 10 closest
# documents for a single natural-language question.
results = chroma_collection.query(
    query_texts=["How much battery does Hindsight use?"],
    n_results=10,
    # Optional filters, left disabled:
    # where={"metadata_field": "is_equal_to_this"},
    # where_document={"$contains":"search_string"}
)

100%|██████████| 1/1 [00:00<00:00,  4.14it/s, seq_len=16]


In [28]:
# Show the raw query response; the recorded output contains 'ids',
# 'distances' and 'metadatas', each holding one inner list per query text.
results

{'ids': [['40563',
   '14911',
   '40481',
   '14829',
   '14830',
   '40482',
   '40415',
   '14763',
   '14760',
   '40412']],
 'distances': [[0.6643862724304199,
   0.6643862724304199,
   0.6859386563301086,
   0.6859386563301086,
   0.6958475112915039,
   0.6958475112915039,
   0.6958898305892944,
   0.6958898305892944,
   0.7418527007102966,
   0.7418527007102966]],
 'metadatas': [[{'application': 'com-google-android-as',
    'frame_id': 40563,
    'timestamp': 1716347348791},
   {'application': 'com-google-android-as',
    'frame_id': 14911,
    'timestamp': 1716347348791},
   {'application': 'com-google-android-apps-messaging',
    'frame_id': 40481,
    'timestamp': 1716347657684},
   {'application': 'com-google-android-apps-messaging',
    'frame_id': 14829,
    'timestamp': 1716347657684},
   {'application': 'com-google-android-apps-messaging',
    'frame_id': 14830,
    'timestamp': 1716347655656},
   {'application': 'com-google-android-apps-messaging',
    'frame_id': 40482