In [4]:
import ast  # for converting embeddings saved as strings back to arrays
from openai import OpenAI # for calling the OpenAI API
import pandas as pd  # for storing text and embeddings data
import tiktoken  # for counting tokens
import os # for getting API token from env variable OPENAI_API_KEY
from scipy import spatial  # for calculating vector similarities for search

# create a list of models 
GPT_MODELS = ["gpt-4o", "gpt-4o-mini"]
# models
EMBEDDING_MODEL = "text-embedding-3-small"



In [9]:
# imports
import mwclient  # for downloading example Wikipedia articles
import mwparserfromhell  # for splitting Wikipedia articles into sections
from openai import OpenAI  # for generating embeddings
import os  # for environment variables
import pandas as pd  # for DataFrames to store article sections and embeddings
import re  # for cutting <ref> links out of Wikipedia articles
import tiktoken  # for counting tokens



In [11]:
# get Wikipedia pages about the 2022 Winter Olympics

CATEGORY_TITLE = "Category:Cosmetics"
WIKI_SITE = "en.wikipedia.org"


def titles_from_category(
    category: mwclient.listing.Category, max_depth: int
) -> set[str]:
    """Return a set of page titles in a given Wiki category and its subcategories."""
    titles = set()
    for cm in category.members():
        if type(cm) == mwclient.page.Page:
            # ^type() used instead of isinstance() to catch match w/ no inheritance
            titles.add(cm.name)
        elif isinstance(cm, mwclient.listing.Category) and max_depth > 0:
            deeper_titles = titles_from_category(cm, max_depth=max_depth - 1)
            titles.update(deeper_titles)
    return titles


site = mwclient.Site(WIKI_SITE)
category_page = site.pages[CATEGORY_TITLE]
titles = titles_from_category(category_page, max_depth=1)
# ^note: max_depth=1 means we go one level deep in the category tree
print(f"Found {len(titles)} article titles in {CATEGORY_TITLE}.")


Found 704 article titles in Category:Cosmetics.


In [12]:
# define functions to split Wikipedia pages into sections

SECTIONS_TO_IGNORE = [
    "See also",
    "References",
    "External links",
    "Further reading",
    "Footnotes",
    "Bibliography",
    "Sources",
    "Citations",
    "Literature",
    "Footnotes",
    "Notes and references",
    "Photo gallery",
    "Works cited",
    "Photos",
    "Gallery",
    "Notes",
    "References and sources",
    "References and notes",
]


def all_subsections_from_section(
    section: mwparserfromhell.wikicode.Wikicode,
    parent_titles: list[str],
    sections_to_ignore: set[str],
) -> list[tuple[list[str], str]]:
    """
    From a Wikipedia section, return a flattened list of all nested subsections.
    Each subsection is a tuple, where:
        - the first element is a list of parent subtitles, starting with the page title
        - the second element is the text of the subsection (but not any children)
    """
    headings = [str(h) for h in section.filter_headings()]
    title = headings[0]
    if title.strip("=" + " ") in sections_to_ignore:
        # ^wiki headings are wrapped like "== Heading =="
        return []
    titles = parent_titles + [title]
    full_text = str(section)
    section_text = full_text.split(title)[1]
    if len(headings) == 1:
        return [(titles, section_text)]
    else:
        first_subtitle = headings[1]
        section_text = section_text.split(first_subtitle)[0]
        results = [(titles, section_text)]
        for subsection in section.get_sections(levels=[len(titles) + 1]):
            results.extend(all_subsections_from_section(subsection, titles, sections_to_ignore))
        return results


def all_subsections_from_title(
    title: str,
    sections_to_ignore: set[str] = SECTIONS_TO_IGNORE,
    site_name: str = WIKI_SITE,
) -> list[tuple[list[str], str]]:
    """From a Wikipedia page title, return a flattened list of all nested subsections.
    Each subsection is a tuple, where:
        - the first element is a list of parent subtitles, starting with the page title
        - the second element is the text of the subsection (but not any children)
    """
    site = mwclient.Site(site_name)
    page = site.pages[title]
    text = page.text()
    parsed_text = mwparserfromhell.parse(text)
    headings = [str(h) for h in parsed_text.filter_headings()]
    if headings:
        summary_text = str(parsed_text).split(headings[0])[0]
    else:
        summary_text = str(parsed_text)
    results = [([title], summary_text)]
    for subsection in parsed_text.get_sections(levels=[2]):
        results.extend(all_subsections_from_section(subsection, [title], sections_to_ignore))
    return results


In [13]:
# split pages into sections
# may take ~1 minute per 100 articles
wikipedia_sections = []
for title in titles:
    wikipedia_sections.extend(all_subsections_from_title(title))
print(f"Found {len(wikipedia_sections)} sections in {len(titles)} pages.")


Found 4460 sections in 704 pages.


In [14]:
# clean text
def clean_section(section: tuple[list[str], str]) -> tuple[list[str], str]:
    """
    Return a cleaned up section with:
        - <ref>xyz</ref> patterns removed
        - leading/trailing whitespace removed
    """
    titles, text = section
    text = re.sub(r"<ref.*?</ref>", "", text)
    text = text.strip()
    return (titles, text)


wikipedia_sections = [clean_section(ws) for ws in wikipedia_sections]

# filter out short/blank sections
def keep_section(section: tuple[list[str], str]) -> bool:
    """Return True if the section should be kept, False otherwise."""
    titles, text = section
    if len(text) < 16:
        return False
    else:
        return True


original_num_sections = len(wikipedia_sections)
wikipedia_sections = [ws for ws in wikipedia_sections if keep_section(ws)]
print(f"Filtered out {original_num_sections-len(wikipedia_sections)} sections, leaving {len(wikipedia_sections)} sections.")


Filtered out 302 sections, leaving 4158 sections.


In [15]:
# print example data
for ws in wikipedia_sections[:5]:
    print(ws[0])
    display(ws[1][:77] + "...")
    print()


['Sa Sa International Holdings']


'{{Short description|Hong Kong retail chain}}\n{{Use dmy dates|date=June 2023}}...'


['Beautycounter']


'{{Short description|American skincare and cosmetics company}}\n{{Use mdy dates...'


['Beautycounter', '==History==']


'Beautycounter was founded by [[Gregg Renfrew]] in 2013.<ref name=fc/> Renfrew...'


['Beautycounter', '==Legislation==']


'In 2014, Renfrew hired public health and environmental advocate Lindsay Dahl ...'


['Paintbrush']


'{{Short description|Brush for painting}}\n{{Other uses}}\n{{Multiple issues|\n{{...'




In [16]:
GPT_MODEL = "gpt-4o-mini"  # only matters insofar as it selects which tokenizer to use


def num_tokens(text: str, model: str = GPT_MODEL) -> int:
    """Return the number of tokens in a string."""
    encoding = tiktoken.encoding_for_model(model)
    return len(encoding.encode(text))


def halved_by_delimiter(string: str, delimiter: str = "\n") -> list[str, str]:
    """Split a string in two, on a delimiter, trying to balance tokens on each side."""
    chunks = string.split(delimiter)
    if len(chunks) == 1:
        return [string, ""]  # no delimiter found
    elif len(chunks) == 2:
        return chunks  # no need to search for halfway point
    else:
        total_tokens = num_tokens(string)
        halfway = total_tokens // 2
        best_diff = halfway
        for i, chunk in enumerate(chunks):
            left = delimiter.join(chunks[: i + 1])
            left_tokens = num_tokens(left)
            diff = abs(halfway - left_tokens)
            if diff >= best_diff:
                break
            else:
                best_diff = diff
        left = delimiter.join(chunks[:i])
        right = delimiter.join(chunks[i:])
        return [left, right]


def truncated_string(
    string: str,
    model: str,
    max_tokens: int,
    print_warning: bool = True,
) -> str:
    """Truncate a string to a maximum number of tokens."""
    encoding = tiktoken.encoding_for_model(model)
    encoded_string = encoding.encode(string)
    truncated_string = encoding.decode(encoded_string[:max_tokens])
    if print_warning and len(encoded_string) > max_tokens:
        print(f"Warning: Truncated string from {len(encoded_string)} tokens to {max_tokens} tokens.")
    return truncated_string


def split_strings_from_subsection(
    subsection: tuple[list[str], str],
    max_tokens: int = 1000,
    model: str = GPT_MODEL,
    max_recursion: int = 5,
) -> list[str]:
    """
    Split a subsection into a list of subsections, each with no more than max_tokens.
    Each subsection is a tuple of parent titles [H1, H2, ...] and text (str).
    """
    titles, text = subsection
    string = "\n\n".join(titles + [text])
    num_tokens_in_string = num_tokens(string)
    # if length is fine, return string
    if num_tokens_in_string <= max_tokens:
        return [string]
    # if recursion hasn't found a split after X iterations, just truncate
    elif max_recursion == 0:
        return [truncated_string(string, model=model, max_tokens=max_tokens)]
    # otherwise, split in half and recurse
    else:
        titles, text = subsection
        for delimiter in ["\n\n", "\n", ". "]:
            left, right = halved_by_delimiter(text, delimiter=delimiter)
            if left == "" or right == "":
                # if either half is empty, retry with a more fine-grained delimiter
                continue
            else:
                # recurse on each half
                results = []
                for half in [left, right]:
                    half_subsection = (titles, half)
                    half_strings = split_strings_from_subsection(
                        half_subsection,
                        max_tokens=max_tokens,
                        model=model,
                        max_recursion=max_recursion - 1,
                    )
                    results.extend(half_strings)
                return results
    # otherwise no split was found, so just truncate (should be very rare)
    return [truncated_string(string, model=model, max_tokens=max_tokens)]


In [17]:
# split sections into chunks
MAX_TOKENS = 1600
wikipedia_strings = []
for section in wikipedia_sections:
    wikipedia_strings.extend(split_strings_from_subsection(section, max_tokens=MAX_TOKENS))

print(f"{len(wikipedia_sections)} Wikipedia sections split into {len(wikipedia_strings)} strings.")


4158 Wikipedia sections split into 4255 strings.


In [18]:
# print example data
print(wikipedia_strings[1])


Beautycounter

{{Short description|American skincare and cosmetics company}}
{{Use mdy dates|date=July 2024}}
{{infobox company
| name = Beautycounter
| logo = Beautycounter_company_logo.png
| founded = 2013
| defunct = {{End date|2024|04|27}}
| founder = [[Gregg Renfrew]]
| key_people = 
| hq_location = [[Santa Monica, California]]
| industry = [[Skin care]], [[cosmetics]]<ref name=CNBC1/>
| owner = 
| website = {{URL|beautycounter.com}}
}}

'''Beautycounter''' was an American [[Direct-to-consumer|direct to consumer]] and [[multi-level marketing]] company that sold skin care and cosmetic products.<ref name=CNBC1/> As of 2018, the company had 150 products with over 65,000 independent consultants, and with national retailers.<ref name=crackdown/> In April 2021, Beautycounter was acquired by [[The Carlyle Group]] in a deal that valued the company at $1 billion. In March 2024, Carlyle wrote off its investment in the company and the company went into administration in April 2024.


In [23]:
EMBEDDING_MODEL = "text-embedding-3-small"
BATCH_SIZE = 2000  # you can submit up to 2048 embedding inputs per request

embeddings = []
for batch_start in range(0, len(wikipedia_strings), BATCH_SIZE):
    batch_end = batch_start + BATCH_SIZE
    batch = wikipedia_strings[batch_start:batch_end]
    print(f"Batch {batch_start} to {batch_end-1}")
    response = client.embeddings.create(model=EMBEDDING_MODEL, input=batch)
    for i, be in enumerate(response.data):
        assert i == be.index  # double check embeddings are in same order as input
    batch_embeddings = [e.embedding for e in response.data]
    embeddings.extend(batch_embeddings)

df = pd.DataFrame({"text": wikipedia_strings, "embedding": embeddings})


Batch 0 to 1999
Batch 2000 to 3999
Batch 4000 to 5999


In [24]:
# save document chunks and embeddings

SAVE_PATH = "data/cosmetics.csv"

df.to_csv(SAVE_PATH, index=False)


In [25]:
# search function
def strings_ranked_by_relatedness(
    query: str,
    df: pd.DataFrame,
    relatedness_fn=lambda x, y: 1 - spatial.distance.cosine(x, y),
    top_n: int = 100
) -> tuple[list[str], list[float]]:
    """Returns a list of strings and relatednesses, sorted from most related to least."""
    query_embedding_response = client.embeddings.create(
        model=EMBEDDING_MODEL,
        input=query,
    )
    query_embedding = query_embedding_response.data[0].embedding
    strings_and_relatednesses = [
        (row["text"], relatedness_fn(query_embedding, row["embedding"]))
        for i, row in df.iterrows()
    ]
    strings_and_relatednesses.sort(key=lambda x: x[1], reverse=True)
    strings, relatednesses = zip(*strings_and_relatednesses)
    return strings[:top_n], relatednesses[:top_n]


In [40]:
# examples
strings, relatednesses = strings_ranked_by_relatedness("sunscreen", df, top_n=5)
for string, relatedness in zip(strings, relatednesses):
    print(f"{relatedness=:.3f}")
    display(string)

relatedness=0.652


'Sunscreen\n\n== Health effects ==\n\n{{See also|Health effects of sunlight exposure}}'

relatedness=0.636


'Sunscreen\n\n{{Short description|Skin product helping to prevent sunburn}}\n{{Redirect|Sunblock|the electronic music group|Sunblock (band)|the Ball Park Music song|Sunscreen (song)}}\n{{Distinguish|text=[[indoor tanning lotion]], the suntan lotion which intensifies sun exposure}}\n{{Use mdy dates|date=June 2015}}\n{{Infobox medical intervention\n| name         = Sunscreen\n| image        = Sunscreen on back under normal and UV light.jpg\n| caption      = Sunscreen on back under normal and UV light\n| alt          = \n| pronounce    =  \n| synonyms     = Sun screen, sunblock, sunburn cream, sun cream, block out\n| ICD10        = \n| ICD9         = \n| ICD9unlinked = \n| MeshID       = \n| LOINC        = \n| other_codes  = \n| MedlinePlus  = \n| eMedicine    = \n}}\n\n\'\'\'Sunscreen\'\'\', also known as \'\'\'sunblock\'\'\',{{efn|Sunblock and sunscreen are often used as synonyms. However, the term "sunblock" is controversial and banned in the EU and USA), sticks, powders and other topi

relatedness=0.607


'Sunscreen\n\n== Measurements of protection ==\n\n[[File:Sunburn blisters.jpg|thumb | Sunscreen helps prevent [[sunburn]], such as this, which has blistered.]]'

relatedness=0.604


'Sunscreen\n\n== History ==\n\n[[File:Malagasy Woman (26905387615).jpg|thumb|[[Malagasy people|Malagasy]] woman from [[Madagascar]] wearing [[masonjoany]], a traditional sunscreen whose use dates back to the 18th century]]\n[[File:Thanaka girls.JPG|thumb|Burmese girls wearing \'\'[[thanaka]]\'\' for sun protection and cosmetic purposes]] \nEarly civilizations used a variety of plant products to help protect the skin from sun damage. For example, [[ancient Greeks]] used olive oil for this purpose, and [[ancient Egypt]]ians used extracts of rice, jasmine, and lupine plants whose products are still used in skin care today. Zinc oxide paste has also been popular for skin protection for thousands of years. Among the nomadic sea-going [[Sama-Bajau people]] of the [[Philippines]], [[Malaysia]], and [[Indonesia]], a common type of sun protection is a paste called \'\'[[borak (cosmetic)|borak]]\'\' or \'\'[[borak (cosmetic)|burak]]\'\', which was made from water weeds, rice and spices. It is us

relatedness=0.603


'Sunscreen\n\n== Active ingredients ==\n\n{{anchor|table}}\n{| class="wikitable"\n! UV-filter\n! Other names\n! Maximum concentration\n! Known permitting jurisdictions\n! Results of safety testing\n! UVA\n! UVB\n|-\n| [[p-Aminobenzoic acid]]\n| PABA\n| 15% (USA), (EU: banned from sale to consumers from 8 October 2009)\n| USA\n| Protects against skin tumors in mice.  Shown to increase DNA defects, and not [[Generally recognized as safe and effective|generally recognised as safe and effective]] according to the [[Food and Drug Administration|FDA]]<ref name="Center for Drug Evaluation and Research_2022" />\n|\n|  X\n|-\n| [[Padimate O]]\n| OD-PABA, octyldimethyl-PABA, σ-PABA\n| 8% (USA, AUS) 10% (JP)\n(Not currently supported in EU and may be delisted)\n| EU, USA, AUS, JP\n| \n|\n|  X\n|-\n| [[Phenylbenzimidazole sulfonic acid]]\n| Ensulizole, PBSA\n| 4% (USA, AUS) 8% (EU) 3% (JP)\n| EU, USA, AUS, JP\n| Genotoxic in bacteria\n|\n|  X\n|-\n| [[Cinoxate]]\n| 2-Ethoxyethyl p-methoxycinnamate

In [38]:
def num_tokens(text: str, model: str = GPT_MODELS[0]) -> int:
    """Return the number of tokens in a string."""
    encoding = tiktoken.encoding_for_model(model)
    return len(encoding.encode(text))


def query_message(
    query: str,
    df: pd.DataFrame,
    model: str,
    token_budget: int
) -> str:
    """Return a message for GPT, with relevant source texts pulled from a dataframe."""
    strings, relatednesses = strings_ranked_by_relatedness(query, df)
    introduction = (
    "Use the below articles to answer the subsequent question as best as you can. "
    "If you're not sure, provide your best guess based on the information."
)

    question = f"\n\nQuestion: {query}"
    message = introduction
    for string in strings:
        next_article = f'\n\nWikipedia article section:\n"""\n{string}\n"""'
        if (
            num_tokens(message + next_article + question, model=model)
            > token_budget
        ):
            break
        else:
            message += next_article
    return message + question


def ask(
    query: str,
    df: pd.DataFrame = df,
    model: str = GPT_MODELS[0],
    token_budget: int = 4096 - 500,
    print_message: bool = False,
) -> str:
    """Answers a query using GPT and a dataframe of relevant texts and embeddings."""
    message = query_message(query, df, model=model, token_budget=token_budget)
    if print_message:
        print(message)
    messages = [
        {"role": "system", "content": "Your answer"},
        {"role": "user", "content": message},
    ]
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=0
    )
    response_message = response.choices[0].message.content
    return response_message



In [39]:
ask('What UV filters are banned?')

'The UV filters that are banned include:\n\n1. **p-Aminobenzoic acid (PABA)** - Banned in the EU from sale to consumers since October 2009 and banned in the USA in 2021 due to safety concerns.\n2. **Trolamine salicylate** - Banned in the USA in 2021 due to safety concerns.\n3. **Oxybenzone** - Banned in Hawaii since 2018 due to environmental concerns, particularly its harmful effects on coral reefs.\n4. **Octinoxate** - Banned in Hawaii since 2021 due to environmental concerns, particularly its harmful effects on coral reefs.\n\nThese bans are primarily due to safety concerns for human health or environmental impacts, particularly on marine ecosystems.'

In [35]:
ask('What UV filters are banned?', print_message=True)

Use the below articles to answer the subsequent question. If the answer cannot be found in the articles, write "I could not find an answer."

Wikipedia article section:
"""
Sunscreen

== Active ingredients ==

{{anchor|table}}
{| class="wikitable"
! UV-filter
! Other names
! Maximum concentration
! Known permitting jurisdictions
! Results of safety testing
! UVA
! UVB
|-
| [[p-Aminobenzoic acid]]
| PABA
| 15% (USA), (EU: banned from sale to consumers from 8 October 2009)
| USA
| Protects against skin tumors in mice.  Shown to increase DNA defects, and not [[Generally recognized as safe and effective|generally recognised as safe and effective]] according to the [[Food and Drug Administration|FDA]]<ref name="Center for Drug Evaluation and Research_2022" />
|
|  X
|-
| [[Padimate O]]
| OD-PABA, octyldimethyl-PABA, σ-PABA
| 8% (USA, AUS) 10% (JP)
(Not currently supported in EU and may be delisted)
| EU, USA, AUS, JP
| 
|
|  X
|-
| [[Phenylbenzimidazole sulfonic acid]]
| Ensulizole, PBSA
| 

'I could not find an answer.'