# Predict `GenItem` variant for Beauty or Steam dataset

In [None]:
import openai
import re
from typing import Literal
import pickle
import pandas as pd
import numpy as np
import json
from tqdm.notebook import tqdm
from openai import AsyncOpenAI
from collections import Counter

In [None]:
# Fill in name of finetuned_model.
MODEL_NAME = ""

# Dataset to predict for; it selects the input file names below.
DATASET: Literal["beauty", "steam"] = "beauty"

# Name of the pickle with the test data for the selected dataset.
TEST_DATA_PICKLE_NAME = f"test_data_{DATASET}.pickle"

# Name of the gzip-compressed CSV with the item embeddings for the selected dataset.
EMBEDDINGS_NAME = f"embeddings_{DATASET}.csv.gz"

# Fill in OpenAI key
OPENAI_KEY = ""

# Hyperparameters
TOP_K = 20
TEMPERATURE = 0
TOP_P = 1.0

# Prompting variant. The four options presumably correspond, in order, to
# sections 4.1 to 4.4 of the accompanying write-up (TODO confirm).
VARIANT: Literal["genitem", "genlist", "class", "rank"] = "genitem"
# Tag that is embedded in the names of the output pickles.
TOTAL_MODEL_NAME = f"{MODEL_NAME}_{VARIANT}_temp_{TEMPERATURE}_top_p_{TOP_P}"

In [None]:
def parse_genitem(completion: str) -> str:
    """Return the item name contained in a raw GenItem completion.

    The model is asked to answer with the item name only, so parsing
    amounts to removing leading and trailing whitespace.
    """
    item_name = completion.strip()
    return item_name

In [None]:
# System prompt for the GenItem variant: asks the model for one item name.
system_message = {
        "role": "system",
        "content": 
"""Provide a unique item recommendation that is complementary to the user's item list. 
Ensure the recommendation is from items included in the data you are fine-tuned with. List only the item name.
""",
    }
# User prompt template. The "{ITEMS}" placeholder is replaced with the item
# names of a session, one name per line, when the test prompts are built.
user_message = {
    "role": "user",
    "content": "The user's item list:\n{ITEMS}",
}
# Parser that is applied to every raw completion.
parse_method = parse_genitem

## Load test prompts

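We expect a pickle containing a pair `(TEST_PROMPTS, TEST_GROUND_TRUTHS)`. `TEST_PROMPTS` maps each session id to the list of item ids in that session's prompt. `TEST_GROUND_TRUTHS` is not used in this notebook; it is expected to map each session id to its ground-truth item ids:
```
(
    {SESSION_ID: PROMPT_ITEM_IDS, ...},
    {SESSION_ID: GROUND_TRUTH_ITEM_IDS, ...},
)
```
For example:
```
(
    {13: [1, 2, 3]},
    {13: [4]},
)
```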

In [None]:
# Load the test split. The pickle holds a pair; the first element maps each
# session id to the item ids of its prompt, the second element is not needed.
# NOTE: unpickling can execute arbitrary code, so only load trusted files.
# The `with` block closes the file handle, which a bare `open()` left open.
with open(TEST_DATA_PICKLE_NAME, "rb") as test_data_file:
    test_prompts, _ = pickle.load(test_data_file)
# Show the prompt of the first session as a quick sanity check.
test_prompts[next(iter(test_prompts))]

## Get embeddings and build lookup tables

In [None]:
# Read the gzip-compressed CSV with the catalog items and their embeddings.
product_embeddings = pd.read_csv(EMBEDDINGS_NAME, compression="gzip")
# Display the frame for a visual check.
product_embeddings

In [None]:
# Lookup tables between catalog item ids and item names.
product_id_to_name = product_embeddings.set_index("ItemId")["name"].to_dict()
product_name_to_id = product_embeddings.set_index("name")["ItemId"].to_dict()

# Embedding matrix with one row per catalog item id, in the same order as the
# keys of `product_id_to_name`. Every cell of the "embedding" column is
# decoded from JSON.
product_index_to_embedding = np.array(
    [
        np.array(json.loads(raw_embedding))
        for raw_embedding in (
            product_embeddings.set_index("ItemId")["embedding"].to_dict().values()
        )
    ]
)

# Translate between matrix row positions and catalog item ids.
product_index_to_id = list(product_id_to_name)
product_id_to_index = {
    item_id: row for row, item_id in enumerate(product_index_to_id)
}

## Compute test prompts

In [None]:
# One (session_id, chat_messages) pair per test session.
test_messages: list[tuple[int, list[dict[str, str]]]] = []

for session_id, prompt in test_prompts.items():
    # The session's item names, one per line, fill the "{ITEMS}" placeholder.
    item_names = "\n".join(product_id_to_name[item_id] for item_id in prompt)
    custom_user_message = {
        **user_message,
        "content": user_message["content"].replace("{ITEMS}", item_names),
    }
    test_messages.append((session_id, [system_message, custom_user_message]))
# Show the first prompt as a sanity check.
test_messages[0]

# Compute completions

In [None]:
import asyncio
import time
# Raw model answers as (session_id, completion_text) pairs.
completions: list[tuple[int, str]] = []


async def run_completions():
    """Request one chat completion per test session and store the answers.

    The requests of a batch are sent concurrently through the async API.
    Keep ``batch_size`` low enough to stay within the API rate limits.
    Every answer is appended to the module-level ``completions`` list as a
    ``(session_id, completion_text)`` pair, in the order of ``test_messages``.
    """
    batch_size = 150

    # The context manager closes the client's HTTP connections on exit, also
    # when one of the requests raises.
    async with AsyncOpenAI(api_key=OPENAI_KEY) as client:
        for start_batch in tqdm(range(0, len(test_messages), batch_size)):
            end_batch = start_batch + batch_size
            batch = test_messages[start_batch:end_batch]

            start_time = time.perf_counter()
            print(f"Completion batch {start_batch} - {end_batch}")

            requests = [
                client.chat.completions.create(
                    model=MODEL_NAME,
                    temperature=TEMPERATURE,
                    top_p=TOP_P,
                    messages=messages,
                )
                for _, messages in batch
            ]
            # gather() returns the responses in the order of the requests, so
            # they can be paired with the session ids of the batch.
            responses = await asyncio.gather(*requests)
            for (session_id, _), response in zip(batch, responses):
                # The message content can be None; store "" instead so every
                # entry is a string that the parser can handle.
                content = response.choices[0].message.content or ""
                completions.append((session_id, content))

            print(f"Finished batch {start_batch} - {end_batch}. Took {time.perf_counter() - start_time} seconds.")


# Top-level `await` is notebook (IPython) syntax; a plain script would need
# `asyncio.run(run_completions())` instead.
await run_completions()

In [None]:
# Persist the raw completions so the API calls need not be repeated.
# The `with` block makes sure the file is flushed and closed.
with open(f"completions_openai_{TOTAL_MODEL_NAME}.pkl", "wb") as completions_file:
    pickle.dump(completions, completions_file)

### Parse completions


In [None]:
# Parse every raw completion with the parser selected for this variant.
# For GenItem the parser returns a single item name per session.
parsed_completions: list[tuple[int, str]] = []
for session_id, response in tqdm(completions):
    parsed_response: str = parse_method(response)
    if parsed_response is None:
        # NOTE(review): this stops at the first completion that could not be
        # parsed and leaves all later sessions out -- confirm that `break`
        # (rather than `continue`) is intended.
        break
    parsed_completions.append((session_id, parsed_response))
# Show the first parsed completion as a sanity check.
parsed_completions[0]

In [None]:
def _as_counter(parsed) -> Counter:
    """Return a parsed completion as a ``Counter`` of item names.

    A single item name (``str``) is counted once, a list of names is
    tallied, and a mapping of name -> count is copied.
    """
    if isinstance(parsed, str):
        return Counter([parsed])
    return Counter(parsed)


# The mapping to catalog ids expects, per session, a Counter of the generated
# item names. `completions` is a list of (session_id, text) pairs and has no
# `.items()`, so the Counters are built from the parsed completions instead.
parsed_completions = [
    (session_id, _as_counter(parsed)) for session_id, parsed in parsed_completions
]
# Show the first entry as a sanity check.
parsed_completions[0]

# Completed product names to global product ids
First we try to map each completion to a product with exactly that name; if there is no exact match, we use embeddings to find the closest catalog item.

In [None]:
# Generated names without an exact match in the catalog.
unmappable_items: set = {
    item
    for _, items_counter in tqdm(parsed_completions)
    for item in items_counter
    if item not in product_name_to_id
}
print(f"No exact match for {len(unmappable_items)} items. Will use embedding based search to find closest item.")
len(unmappable_items)

In [None]:
import time
# Cache: generated item name -> embedding returned by the API.
unmappable_items_embeddings: dict[str, list[float]] = {}
# A list is needed so the names can be sliced into batches.
unmappable_items: list[str] = list(unmappable_items)


async def get_embeddings():
    """Embed every generated item name that has no exact catalog match.

    The names are sent to the embeddings endpoint in batches and the results
    are stored in the module-level ``unmappable_items_embeddings`` dict,
    keyed by item name.
    """
    # The embeddings endpoint does not accept empty strings, and the mapping
    # step handles "" without looking up an embedding, so skip them here.
    items_to_embed = [item for item in unmappable_items if item]

    batch_size = 2000
    # The context manager closes the client's HTTP connections on exit, also
    # when a request raises.
    async with AsyncOpenAI(api_key=OPENAI_KEY) as client:
        for start_batch in tqdm(range(0, len(items_to_embed), batch_size)):
            end_batch = start_batch + batch_size
            batch = items_to_embed[start_batch:end_batch]

            start_time = time.perf_counter()
            print(f"Embeddings batch {start_batch} - {end_batch}")
            response = await client.embeddings.create(input=batch, model="text-embedding-ada-002")
            # Pair each name with the embedding at the same position.
            for item, embedding in zip(batch, response.data):
                unmappable_items_embeddings[item] = embedding.embedding

            print(f"Finished batch {start_batch} - {end_batch}. Took {time.perf_counter() - start_time} seconds.")
# Top-level `await` is notebook (IPython) syntax; a plain script would need
# `asyncio.run(get_embeddings())` instead.
await get_embeddings()

Find the closest actual catalog item (with a global product id) for every completion that needs one. Try to prevent duplicate recommendations within a session.

In [None]:
import random

# Final result: session id -> list of recommended catalog item ids.
recommendations = {}
bug_item_list = []  # Never filled; kept so the notebook's globals stay the same.
num_sessions_done = 0
# Width of the catalog embeddings. Used for the placeholder embedding of an
# empty completion, so it always fits the catalog matrix.
embedding_dim = product_index_to_embedding.shape[1]
for session_id, value_counts in tqdm(parsed_completions):
    session_item_names = [product_id_to_name[item] for item in test_prompts[session_id]]
    # Names that may never be used as a replacement in this session: every
    # generated name and every item that is already part of the prompt.
    session_blocked_names = set(value_counts) | set(session_item_names)
    session_recommendations = []

    duplicate_replacements = []
    for item_name, count in value_counts.items():
        # If an item occurs more than once, we need its embedding to find
        # neighbouring items.
        # If an item is not in the catalog, we get a similar item that is in the catalog.
        if count > 1 or item_name not in product_name_to_id:
            if item_name == "":
                # An empty completion has no embedding in the cache, so we
                # use a zero embedding for it.
                item_embedding = np.zeros((1, embedding_dim))
            else:
                # Catalog items use their stored embedding; all other names
                # must be in the cache that was filled from the API earlier.
                if item_name in product_name_to_id:
                    item_embedding = product_index_to_embedding[product_id_to_index[product_name_to_id[item_name]]]
                else:
                    item_embedding = unmappable_items_embeddings[item_name]
                if isinstance(item_embedding, str):
                    item_embedding = json.loads(item_embedding)

                item_embedding = np.array([item_embedding], dtype=np.float64)

            # Dot product of the item embedding with every catalog embedding.
            predictions = (product_index_to_embedding @ item_embedding.T).T[0]

            # Take the best-scoring catalog items. TOP_K extra candidates are
            # fetched because some of them are filtered out below.
            top_k_item_ids_indices = predictions.argsort()[::-1][:count + TOP_K]
            top_k_item_ids = [
                product_index_to_id[item_index] for item_index in top_k_item_ids_indices
            ]

            # Names of the items that are not allowed to be added: the names
            # blocked for the whole session plus everything chosen so far.
            disallowed_items = session_blocked_names | {
                product_id_to_name[item]
                for item in session_recommendations + duplicate_replacements
            }

            # Filter out disallowed items.
            top_k_item_ids = [
                item
                for item in top_k_item_ids
                if product_id_to_name[item] not in disallowed_items
            ]

            # We add the item itself if it exists.
            item_exists: bool = item_name in product_name_to_id
            if item_exists:
                item_id = product_name_to_id[item_name]
                session_recommendations.append(item_id)

            # Truncate.
            # If an item appeared `count` times, it needs `count - int(item_exists)` replacements.
            # If the item exists, we have added it already, so we only need count - 1 replacements.
            # If the item does not exist, we need count replacements.
            top_k_item_ids = top_k_item_ids[: count - int(item_exists)]

            duplicate_replacements.extend(top_k_item_ids)

        else:
            # Simply add the id to the list of recommendations
            item_id = product_name_to_id[item_name]
            session_recommendations.append(item_id)

    # Exact matches come first, followed by the embedding-based replacements.
    session_recommendations.extend(duplicate_replacements)

    num_sessions_done += 1
    # Print the progress only once in a while to keep the output short.
    if random.randint(0, 100) == 50:
        print(f"Num sessions done: {num_sessions_done}")

    recommendations[session_id] = session_recommendations

# Save file

In [None]:
# Persist the recommendations (session id -> list of catalog item ids).
# The `with` block makes sure the file is flushed and closed.
with open(f"recs_openai_{TOTAL_MODEL_NAME}.pickle", "wb") as recommendations_file:
    pickle.dump(recommendations, recommendations_file)