In [1]:
import string
from pathlib import Path
from typing import Optional

import pandas as pd
import rich
from pydantic import BaseModel, TypeAdapter
from tqdm import tqdm

# Import a model to verify the setup
from twfy_vector_explorer import notebook_setup as notebook_setup
from vector_explorer.models import NgramVector

# Never truncate cell text in displayed frames, so long ngrams stay readable.
pd.options.display.max_colwidth = None

# Generic n-gram similarity search for exploration

In [5]:
alts = ["mental health"]

false_positives = []


# Characters that count as punctuation: ASCII punctuation plus curly quotes.
_PUNCTUATION_CHARS = frozenset(string.punctuation + "”“’‘")

# Connective words that add nothing at the start or end of an ngram.
_EDGE_STOPWORDS = frozenset(
    [
        "and",
        "or",
        "of",
        "the",
        "in",
        "to",
        "a",
        "on",
        "for",
        "with",
        "by",
        "from",
        "as",
        "at",
        "an",
        "is",
        "are",
        "were",
    ]
)


def fix_puntuation_items(t: str) -> str:
    """
    Tidy an ngram by dropping punctuation-only words and edge stopwords.

    1. Split ``t`` on single spaces. Discard every word made up entirely of
       punctuation ("-", "--", "“", "()"). Also discard the empty strings
       produced by repeated spaces.
    2. Strip connective words such as "and", "of", "the" from the start and
       end (case-insensitively). This makes "of mental ill health" and
       "mental ill health" normalise to the same text.

    Args:
        t: The ngram text.

    Returns:
        The remaining words re-joined with single spaces. This may be ""
        if nothing meaningful is left.
    """
    # Check every character rather than `w in string.punctuation`. That is a
    # *substring* test, which would drop "()" but keep "--" or "...".
    # all() over an empty word is True, so "" is dropped here too.
    words = [
        w for w in t.split(" ") if not all(ch in _PUNCTUATION_CHARS for ch in w)
    ]

    # Trim stopwords from both ends without re-slicing the list each time.
    start, end = 0, len(words)
    while start < end and words[start].lower() in _EDGE_STOPWORDS:
        start += 1
    while end > start and words[end - 1].lower() in _EDGE_STOPWORDS:
        end -= 1
    return " ".join(words[start:end])


def remove_future_contains(series: pd.Series) -> pd.Series:
    """
    Flag items that contain an earlier item of the series.

    This deals with results such as 'mental ill health',
    'of mental ill health' and 'with mental ill health'. We keep the first
    (earliest) variant and flag every later item that contains it, so
    earlier items always take priority.

    Containment is a plain substring test, so 'health' also matches
    'healthcare'. Empty strings and non-string values (e.g. NaN) never flag
    anything; an empty string would otherwise be "contained" in every
    later item.

    Args:
        series: Texts in priority order (earliest = most preferred).

    Returns:
        A boolean Series aligned to ``series.index``. It is True for items
        to drop, so ``series[~mask]`` keeps the wanted ones.
    """
    values = series.tolist()
    flagged = [False] * len(values)

    for i, item in enumerate(values):
        if not isinstance(item, str) or not item:
            continue
        for j in range(i + 1, len(values)):
            later = values[j]
            # Mark by *position*, not value. A value-based mask would also
            # flag the earlier copy of a duplicated string.
            if not flagged[j] and isinstance(later, str) and item in later:
                flagged[j] = True

    return pd.Series(flagged, index=series.index, dtype=bool)


class QueryMatch(BaseModel):
    """One related ngram found for a search query (built from a result row in `find_related_ngrams`)."""

    # ngram text after stripping and punctuation/stopword cleanup
    text: str
    # presumably how often the ngram occurs, as stored on NgramVector — TODO confirm
    count: int
    # vector distance from the query; results are sorted ascending, so a lower
    # value presumably means a closer match
    distance: float


class SearchQuery(BaseModel):
    """A search phrase together with the related ngrams found for it."""

    # the phrase that was searched for
    query: str
    # related matches, ordered closest (lowest distance) first
    nearest: list[QueryMatch]


# Adapter for validating/serialising a whole list of results in one call,
# e.g. `SearchList.dump_json(items, indent=2)`.
SearchList = TypeAdapter(list[SearchQuery])


def find_related_ngrams(
    search_query: str, threshold: float = 0.15, limit: int = 10
) -> Optional[SearchQuery]:
    """
    Find ngrams that are semantically close to ``search_query``.

    Steps:

    - Run a vector distance search over ``NgramVector``.
    - Clean the matched text with ``fix_puntuation_items`` and drop texts
      that become empty.
    - De-duplicate, keeping the closest row for each text.
    - Drop longer variants of earlier matches (``remove_future_contains``).
    - Remove the query itself (case-insensitively), so only *related*
      phrases remain.

    Args:
        search_query: The phrase to search for.
        threshold: Distance threshold passed to ``search_distance``.
        limit: Maximum number of matches to return.

    Returns:
        A ``SearchQuery`` with up to ``limit`` matches ordered by ascending
        distance, or ``None`` if the vector search returned no rows.
        ``nearest`` may be empty if every hit was filtered out.
    """
    raw = (
        NgramVector.objects.all()
        .search_distance(search_query, threshold=threshold)
        .df()
    )

    if raw.empty:
        return None

    # Sort *before* de-duplicating so that the closest (lowest-distance) row
    # is kept for a repeated text. A stable sort keeps ties in search order.
    cleaned = (
        raw.drop(columns=["embedding", "id"])
        .sort_values("distance", ascending=True, kind="stable")
        .assign(text=lambda d: d["text"].str.strip().apply(fix_puntuation_items))
        # Trimming edge stopwords can leave nothing (e.g. "of the"). An empty
        # string is not a useful match.
        .loc[lambda d: d["text"] != ""]
        .drop_duplicates(subset="text", keep="first")
        .reset_index(drop=True)
    )

    # Drop longer variants of earlier (closer) matches, e.g. keep
    # "mental ill health" and drop "of mental ill health".
    cleaned = cleaned[~remove_future_contains(cleaned["text"])]

    # No exact matches: we want *similar* phrases, not the query itself.
    query_key = search_query.strip().casefold()
    cleaned = cleaned[cleaned["text"].str.casefold() != query_key]

    return SearchQuery(
        query=search_query,
        nearest=cleaned.head(limit).to_dict(orient="records"),  # type: ignore
    )

In [6]:
items: list[SearchQuery] = []

queries_df = pd.read_csv(Path("..", "data", "sample_queries.csv"))

# read_csv turns blank cells into NaN, which is truthy, so `if not query`
# alone would still send NaN to the search. Drop missing values and
# whitespace-only entries up front.
queries = [
    q.strip() for q in queries_df["queries"].dropna().astype(str) if q.strip()
]

for query in tqdm(queries):
    result = find_related_ngrams(query)
    # None: the vector search found nothing at all.
    # Empty `nearest`: every hit was filtered out as a variant of the query.
    if result is not None and result.nearest:
        items.append(result)

# The full list runs to thousands of lines, so show a count and a sample.
rich.print(f"{len(items)} of {len(queries)} queries have related ngrams")
rich.print(items[:5])

100%|██████████| 997/997 [08:58<00:00,  1.85it/s]


In [8]:

# Serialise first, then write (dump_json returns bytes). If serialisation
# fails, an existing file is left intact instead of being truncated by an
# early open("wb").
search_matches_json = SearchList.dump_json(items, indent=2)
Path("..", "data", "search_matches.json").write_bytes(search_matches_json)

In [9]:
# One row per (query, match) pair, flattened for spreadsheet review.
SEARCH_MATCH_COLUMNS = ["query", "text", "count", "distance"]

rows = [
    {
        "query": item.query,
        "text": match.text,
        "count": match.count,
        "distance": match.distance,
    }
    for item in items
    for match in item.nearest
]

# Explicit columns keep the CSV header (and column order) even if `items`
# is empty; otherwise pandas would write a header-less file.
results_df = pd.DataFrame(rows, columns=SEARCH_MATCH_COLUMNS)

results_df.to_csv(Path("..", "data", "search_matches.csv"), index=False)