<a href="https://colab.research.google.com/github/bichpham102/analyticsCaseStudies/blob/main/004_Unlabeled_text_classification_with_embeddings/Unlabeled_text_classification_with_embeddings.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Unlabeled text data classification with Embeddings & Cosine Similarity: Fashion example

---


## 1. Scenario:

- X is an emerging e-commerce platform specifically for Fashion items with many third-party sellers and thousands of SKUs.

- The company wants to build a classifier to classify each item into the correct category within a predefined 3-level category tree (***categories*** file). However, the data they have available (***items*** file) is unlabeled.
  - ***items*** file: [from Kaggle](https://www.kaggle.com/datasets/shivamb/fashion-clothing-products-catalog?resource=download), 12,491 rows x 8 columns. Each row represents a unique fashion item (SKU) described by 8 fields. (This dataset is very clean. In practice, you will have to clean the text yourself, which typically includes lowercasing, removing special characters and emojis, handling unicode characters, fixing typos, etc.)
  - ***categories*** file [(download here)](https://drive.google.com/file/d/13h9NlAB0zaK6yHgyerXAx-M6w4Xz5biF/view?usp=sharing): 100 rows x 4 columns. Each row represents a unique 3-level category.


- A reliable classifier matters a lot to X because it lets the company:

  *   Save their sellers time filling in the category for each item themselves, which can be excruciating with thousands of SKUs.
  *   Improve their search engine & the clarity of the site's structure.
  *   Thus, improve the customer experience in finding the most relevant items to their needs.
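
The text-cleaning steps mentioned for the ***items*** file (lowercasing, removing special characters and emojis, handling unicode) could be sketched roughly as below. This is only an illustration: the function name `clean_text` and the exact set of characters kept are assumptions, not part of this notebook, and typo correction is not covered.

```python
import re
import unicodedata

def clean_text(text):
    """Basic normalization for a product name (illustrative only)."""
    # Normalize unicode variants (e.g. full-width characters) to a canonical form
    text = unicodedata.normalize("NFKC", text)
    text = text.lower()
    # Replace everything except word characters, whitespace, ', & and - with a space
    text = re.sub(r"[^\w\s'&-]", " ", text)
    # Collapse repeated whitespace and trim the ends
    return re.sub(r"\s+", " ", text).strip()

cleaned = clean_text("  DKNY   Unisex Black & Grey *Printed* Trolley Bag!! 😀 ")
print(cleaned)
```

How aggressive the cleaning should be depends on the embedding model; pre-trained models often cope well with casing and punctuation, so it is worth comparing results with and without it.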

\

## 2. Objectives:
\
Develop a machine learning model designed to accurately classify fashion items into appropriate categories within a predefined 3-level category tree.

\
**Other requirements**

- Accurately classify unlabeled Fashion items in the ***items*** file into the category that best describes them.
- Cost-effective to build and maintain.
- As an emerging Fashion platform, it is important for X that the model can deal with novel items that it has not seen before (i.e., outside the ***items*** file).


\

## 3. Approach:
\

### **3.1 Potential approaches:**
\
Labeling the Data:

1. **Manual Labeling or Crowdsourcing**: Engage human labelers directly or through platforms to categorize items, then apply supervised learning models.
  - Pros: Provides accurate and directly applicable training data.
  - Cons: High labor cost and time-consuming for thousands of items.
2. **Label Only a Subset** and Use One-shot Learning or Semi-supervised Techniques: Begin with a small, representative sample of labeled data and extrapolate to larger, unlabeled sets.
  - Pros: Reduces the need for extensive manual labeling while retaining model effectiveness.
  - Cons: May compromise the model’s accuracy if the subset is not sufficiently representative.

\
Utilizing Unlabeled Data:

3. **Use ChatGPT API**: Implement advanced NLP features without internal model development.
  - Pros: Access to state-of-the-art language processing capabilities.
  - Cons: Potentially high ongoing costs, especially when scaling.
4. **Employ Embedding Models**: Utilize pre-trained language models to convert text into meaningful, contextual embeddings.
  - Pros: Cost-effective; leverages the generalization power of large language models.
  - Cons: May require additional steps to tailor to specific classification needs.

\



### **3.2 Chosen Approach: Using Embeddings and Cosine Similarity**

\

**Overview**
\
The model should take the name of a fashion item as ***input*** and return the correct category as ***output*** using Embeddings generated from OpenAI's `text-embedding-3-small` model and Cosine Similarity.

\
**Pros & Cons of the approach**

\

| Aspect | Strengths | Weaknesses |
| --- | --- | --- |
| Data Requirements | Efficient with unlabeled data; leverages semantic meanings in item names. | Highly dependent on the quality and descriptiveness of item names. |
| Cost | Cost-effective to build and maintain due to use of pre-trained models and minimal computational needs. | Dependency on external models may limit control over operational costs and updates. |
| Scalability | Good at handling novel items due to generalization capabilities of pre-trained embeddings. | May struggle with items that have names not well represented in the model's training data. |
| Implementation | Simple and fast to deploy, suitable for startups and rapid development cycles. | Error diagnosis and correction can be complex due to the opaque nature of embedding models. |
| Model Performance | Quick and straightforward method for classifying items using a similarity-based approach. | Shallow contextual understanding may not capture deep nuances needed for accurate classification. |
| Flexibility | Adaptable to various types of text data and robust against small changes in input style. | Risk of overfitting to specific linguistic patterns not universally applicable. |
| Maintenance | Low maintenance needs if embedding model remains effective for the application context. | Adjusting and updating the model relies on third-party developments (e.g., OpenAI updates). |


\

**The intuition:**

Think of embeddings as a way to turn the names of fashion items and their categories into points on a graph. The closer two points are on this graph, the more similar they are.

\
**Steps Explained:**

- **S1. Make Points for Categories**: Convert each category name into a point on our graph.
- **S2. Make Points for Items**: Convert each item name into a point on our graph too.
- **S3. Find the Nearest Category for Each Item**: By measuring the distance between an item's point and all category points (using something called cosine similarity), we find out which category is closest to the item. The closest category is considered the best match for that item.
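
To make the idea of "closeness" in S3 concrete, here is a minimal sketch of cosine similarity on made-up 3-dimensional vectors (real embeddings have many more dimensions). The helper name `cosine_similarity` and the toy vectors are assumptions for illustration only.

```python
import numpy as np

def cosine_similarity(a, b):
    """Cosine of the angle between vectors a and b (1 = same direction, 0 = orthogonal)."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Toy "embeddings", invented for this example
item = [1.0, 2.0, 0.0]
category_a = [2.0, 4.0, 0.0]   # points in the same direction as the item
category_b = [0.0, 0.0, 3.0]   # orthogonal to the item

sim_a = cosine_similarity(item, category_a)
sim_b = cosine_similarity(item, category_b)
print(sim_a, sim_b)
```

Here `category_a` would be picked as the best match, because its similarity to the item is the highest. Note that cosine similarity depends only on direction, not on vector length.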

\


## 4. Evaluation

\
With a similarity_score threshold of 0.5, the preliminary accuracy on a random sample of 837 items (10% of the items scoring above the threshold) is approximately 67%, using GPT-4 to judge whether each predicted label is correct.

Important observations:
- Many items fall outside the predefined category tree, which negatively impacts accuracy.
- Since the differences between some categories are nuanced (e.g., "Women's Camisoles" vs. "Women's Tank Tops"), the similarity scores are low for these items.

Total costs: $1.62




---



## 0. Setups

- In your Drive, create a folder named 'TDC__UnlabeledTextClassification_Fashion'.
- Download the two files above and put them in the folder.

In [None]:
# allow Google Colab to access Drive's content
from google.colab import drive
drive.mount('/content/drive')

# adds the directory '/content/drive/MyDrive' to the list that Python checks for modules and packages,
# allowing Python scripts to import modules from that specified directory.
import sys
sys.path.append('/content/drive/MyDrive/TDC__UnlabeledTextClassification_Fashion')


Mounted at /content/drive


In [None]:
import tensorflow as tf

# Check available GPUs
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        # If GPUs are available, set TensorFlow to use the GPU
        tf.config.set_visible_devices(gpus[0], 'GPU')  # Use the first GPU

        # Optionally, set memory growth to True to allocate only as much
        # GPU memory as needed at runtime
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)

    except RuntimeError as e:
        # Memory growth must be set before initializing GPUs
        print(e)
else:
    print("No GPU found. Using CPU instead.")


Import input files: items and cates

In [None]:
import pandas as pd
import numpy as np

items = 'myntra_products_catalog.csv'  # items file name
cates = 'category_tree.csv'  # categories file name

# import the items CSV file
items_df = pd.read_csv(f'/content/drive/MyDrive/TDC__UnlabeledTextClassification_Fashion/{items}')

# import the categories CSV file
cates_df = pd.read_csv(f'/content/drive/MyDrive/TDC__UnlabeledTextClassification_Fashion/{cates}')


## 1. Inspect and Transform inputs

### 1.1 Category dataframe

In [None]:
cates_df.head()

Unnamed: 0,category_id,category_level1,category_level2,category_level3
0,1,Men,Tops,T-Shirts
1,2,Men,Tops,Polo Shirts
2,3,Men,Tops,Dress Shirts
3,4,Men,Tops,Tank Tops
4,5,Men,Tops,Sweatshirts


From inspection, it seems the meaning of each unique category can be described by just *category_level1* and *category_level3*. So we create a new field - `category_name` from *category_level1* and *category_level3* only.

We will use `category_name` in Step 2 when we get embeddings for each category.

In [None]:
cates_df['category_name'] = cates_df['category_level1'] + "'s " + cates_df['category_level3']
cates_df.head()

Unnamed: 0,category_id,category_level1,category_level2,category_level3,category_name
0,1,Men,Tops,T-Shirts,Men's T-Shirts
1,2,Men,Tops,Polo Shirts,Men's Polo Shirts
2,3,Men,Tops,Dress Shirts,Men's Dress Shirts
3,4,Men,Tops,Tank Tops,Men's Tank Tops
4,5,Men,Tops,Sweatshirts,Men's Sweatshirts


### 1.2 Items dataframe

In [None]:
items_df.head()

Unnamed: 0,ProductID,ProductName,ProductBrand,Gender,Price (INR),NumImages,Description,PrimaryColor
0,10017413,DKNY Unisex Black & Grey Printed Medium Trolle...,DKNY,Unisex,11745,7,"Black and grey printed medium trolley bag, sec...",Black
1,10016283,EthnoVogue Women Beige & Grey Made to Measure ...,EthnoVogue,Women,5810,7,Beige & Grey made to measure kurta with churid...,Beige
2,10009781,SPYKAR Women Pink Alexa Super Skinny Fit High-...,SPYKAR,Women,899,7,Pink coloured wash 5-pocket high-rise cropped ...,Pink
3,10015921,Raymond Men Blue Self-Design Single-Breasted B...,Raymond,Men,5599,5,Blue self-design bandhgala suitBlue self-desig...,Blue
4,10017833,Parx Men Brown & Off-White Slim Fit Printed Ca...,Parx,Men,759,5,"Brown and off-white printed casual shirt, has ...",White


From inspection, it seems the meaning of each unique item can be **sufficiently** described by just *ProductName* and *Gender*. So we create a new field - `item_name` from *ProductName* and *Gender* only.

We will use `item_name` in Step 2 when we get embeddings for each item.

In [None]:
import re

# Define a function to apply to each row:
# if ProductName does not already mention Gender as a whole word, prepend Gender to the name; otherwise keep the name as is
def modify_item_name(row):
    # Whole-word match, so that e.g. "Men" is not found inside "Women"
    if re.search(rf"\b{re.escape(row['Gender'])}\b", row['ProductName']):
        return row['ProductName']
    else:
        return f"{row['Gender']} - {row['ProductName']}"

# Apply the function to each row
items_df['item_name'] = items_df.apply(modify_item_name, axis=1)

In [None]:
items_df.head()

Unnamed: 0,ProductID,ProductName,ProductBrand,Gender,Price (INR),NumImages,Description,PrimaryColor,item_name
0,10017413,DKNY Unisex Black & Grey Printed Medium Trolle...,DKNY,Unisex,11745,7,"Black and grey printed medium trolley bag, sec...",Black,DKNY Unisex Black & Grey Printed Medium Trolle...
1,10016283,EthnoVogue Women Beige & Grey Made to Measure ...,EthnoVogue,Women,5810,7,Beige & Grey made to measure kurta with churid...,Beige,EthnoVogue Women Beige & Grey Made to Measure ...
2,10009781,SPYKAR Women Pink Alexa Super Skinny Fit High-...,SPYKAR,Women,899,7,Pink coloured wash 5-pocket high-rise cropped ...,Pink,SPYKAR Women Pink Alexa Super Skinny Fit High-...
3,10015921,Raymond Men Blue Self-Design Single-Breasted B...,Raymond,Men,5599,5,Blue self-design bandhgala suitBlue self-desig...,Blue,Raymond Men Blue Self-Design Single-Breasted B...
4,10017833,Parx Men Brown & Off-White Slim Fit Printed Ca...,Parx,Men,759,5,"Brown and off-white printed casual shirt, has ...",White,Parx Men Brown & Off-White Slim Fit Printed Ca...


## 2. Get embeddings

In [None]:
!pip install openai

Collecting openai
  Downloading openai-1.23.2-py3-none-any.whl (311 kB)
Collecting httpx<1,>=0.23.0 (from openai)
  Downloading httpx-0.27.0-py3-none-any.whl (75 kB)
Collecting httpcore==1.* (from httpx<1,>=0.23.0->openai)
  Downloading httpcore-1.0.5-py3-none-any.whl (77 kB)
Collecting h11<0.15,>=0.13 (from httpcore==1.*->httpx<1,>=0.23.0->openai)
  Downloading h11-0.14.0-py3-none-any.whl (58 kB)
Installing collected packages: h11, httpcore, httpx, openai
Successfully installed h11-0.14.0 httpcore-1.0.5 h

In [None]:
import openai

# Create an OpenAI client object
client = openai.OpenAI(api_key = 'YOUR_OPENAI_API_KEY' )

def get_embedding(text, model="text-embedding-3-small"):
   text = text.replace("\n", " ")
   return client.embeddings.create(input = [text], model=model).data[0].embedding

In [None]:
from tqdm.auto import tqdm
tqdm.pandas()

cates_df['cate_embedding'] = cates_df['category_name'].progress_apply(lambda x: get_embedding(x, model='text-embedding-3-small'))
cates_df.head(2)

  0%|          | 0/100 [00:00<?, ?it/s]

Unnamed: 0,category_id,category_level1,category_level2,category_level3,category_name,cate_embedding
0,1,Men,Tops,T-Shirts,Men's T-Shirts,"[0.04717280715703964, 0.03839524835348129, -0...."
1,2,Men,Tops,Polo Shirts,Men's Polo Shirts,"[0.014527802355587482, 0.0327535904943943, -0...."


In [None]:
items_df['item_embedding'] = items_df['item_name'].progress_apply(lambda x: get_embedding(x, model='text-embedding-3-small'))
items_df.head(2)

  0%|          | 0/12491 [00:00<?, ?it/s]

Unnamed: 0,ProductID,ProductName,ProductBrand,Gender,Price (INR),NumImages,Description,PrimaryColor,item_name,item_embedding
0,10017413,DKNY Unisex Black & Grey Printed Medium Trolle...,DKNY,Unisex,11745,7,"Black and grey printed medium trolley bag, sec...",Black,DKNY Unisex Black & Grey Printed Medium Trolle...,"[0.0036523446906358004, -0.017720405012369156,..."
1,10016283,EthnoVogue Women Beige & Grey Made to Measure ...,EthnoVogue,Women,5810,7,Beige & Grey made to measure kurta with churid...,Beige,EthnoVogue Women Beige & Grey Made to Measure ...,"[0.017012350261211395, -0.012071176432073116, ..."
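
Calling the API once per row works, but for 12,491 items it means 12,491 requests. Since the embeddings endpoint accepts a list of inputs, one possible speed-up is to send names in batches. The sketch below is an assumption-laden illustration: `embed_in_batches`, `openai_embed_batch` and the batch size of 256 are hypothetical choices, and `client` is expected to be the `openai.OpenAI` client created earlier.

```python
def embed_in_batches(texts, embed_batch, batch_size=256):
    """Embed texts in chunks; embed_batch maps a list of strings to a list of vectors."""
    embeddings = []
    for start in range(0, len(texts), batch_size):
        # Same newline clean-up as get_embedding, applied to the whole chunk
        batch = [t.replace("\n", " ") for t in texts[start:start + batch_size]]
        embeddings.extend(embed_batch(batch))
    return embeddings

def openai_embed_batch(batch, model="text-embedding-3-small"):
    # Assumes `client` is the openai.OpenAI client defined earlier in the notebook
    response = client.embeddings.create(input=batch, model=model)
    # Sort by index so the output stays aligned with the input order
    return [d.embedding for d in sorted(response.data, key=lambda d: d.index)]

# Usage (needs a valid API key, so it is left commented out):
# items_df['item_embedding'] = embed_in_batches(items_df['item_name'].tolist(), openai_embed_batch)
```

Passing the embedding call in as a function keeps the batching logic testable without network access.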


## 3. Identify the best-matched category for each item with Cosine Similarity

In [None]:
from scipy.spatial.distance import cosine

def find_nearest_neighbor_label_and_similarity(text_embedding, label_embeddings, label_names):
    # Calculate cosine distance with all labels' embeddings
    distances = label_embeddings.apply(lambda x: cosine(text_embedding, x))
    # Find the index label and value of the category with minimum distance
    nearest_index = distances.idxmin()
    nearest_similarity = 1 - distances.loc[nearest_index]  # Convert distance to similarity
    # idxmin() returns an index label, so look the name up with .loc rather than .iloc
    return label_names.loc[nearest_index], nearest_similarity


In [None]:
# Calculate nearest neighbor labels and similarities
items_df[['pred_label', 'pred_label_similarity']] = items_df['item_embedding'].apply(
    lambda x: find_nearest_neighbor_label_and_similarity(x, cates_df['cate_embedding'], cates_df['category_name'])
).apply(pd.Series)
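
The row-by-row search above computes 100 distances per item in Python loops. As an alternative sketch (not the method used in this notebook), the same nearest-category lookup can be written as a single matrix product with NumPy. The function name `nearest_labels` and the toy data are assumptions for illustration.

```python
import numpy as np

def nearest_labels(item_embeddings, label_embeddings, label_names):
    """Return (best label, cosine similarity) per item using one matrix product."""
    items = np.asarray(item_embeddings, dtype=float)
    labels = np.asarray(label_embeddings, dtype=float)
    # L2-normalize rows so that a dot product equals cosine similarity
    items = items / np.linalg.norm(items, axis=1, keepdims=True)
    labels = labels / np.linalg.norm(labels, axis=1, keepdims=True)
    sims = items @ labels.T                  # shape: (n_items, n_labels)
    best = sims.argmax(axis=1)               # highest similarity = nearest category
    names = np.asarray(label_names)
    return names[best], sims[np.arange(len(items)), best]

# Toy 2-dimensional data, invented for this example
toy_items = [[1.0, 0.1], [0.1, 1.0]]
toy_labels = [[1.0, 0.0], [0.0, 1.0]]
pred, score = nearest_labels(toy_items, toy_labels, ["Bags", "Jeans"])
print(pred, score)

# With the notebook's data this might look like:
# pred, score = nearest_labels(items_df['item_embedding'].tolist(),
#                              cates_df['cate_embedding'].tolist(),
#                              cates_df['category_name'])
```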


In [None]:
items_df = items_df[['item_name','pred_label','pred_label_similarity','item_embedding']].copy()
items_df.head()

Unnamed: 0,item_name,pred_label,pred_label_similarity,item_embedding
0,DKNY Unisex Black & Grey Printed Medium Trolle...,Men's Bags,0.499529,"[0.0036523446906358004, -0.017720405012369156,..."
1,EthnoVogue Women Beige & Grey Made to Measure ...,Women's Jackets,0.496151,"[0.017012350261211395, -0.012071176432073116, ..."
2,SPYKAR Women Pink Alexa Super Skinny Fit High-...,Women's Jeans,0.551306,"[7.444973107340047e-06, -0.009631668217480183,..."
3,Raymond Men Blue Self-Design Single-Breasted B...,Men's Blazers,0.48166,"[0.018767036497592926, 0.007596811279654503, -..."
4,Parx Men Brown & Off-White Slim Fit Printed Ca...,Men's Polo Shirts,0.537286,"[0.028040817007422447, 0.013223097659647465, -..."


In [None]:
items_df.to_csv('/content/drive/MyDrive/TDC__UnlabeledTextClassification_Fashion/items_df_res.csv', index=False, encoding='utf-8-sig')

## 5. Evaluation

### 5.1 Some observations

- The model matches every item to some category within the predefined category tree. This significantly lowers overall accuracy when there are items that fall outside the categories.

In [None]:
items_df[items_df['pred_label_similarity'] <= 0.25].head()

Unnamed: 0,item_name,pred_label,pred_label_similarity,item_embedding
455,Qraa Men Set of 2 Intense Acne Clearing Face W...,Men's Camisoles,0.2493,"[-0.0033693143632262945, -0.026148486882448196..."
1057,Innisfree Unisex Green Tea Facial Mist 150 ml,Men's Raincoats,0.248253,"[-0.0022635001223534346, 0.039540983736515045,..."
3060,Schwarzkopf PROFESSIONAL Unisex Bonacure pH 4....,Men's Hats,0.225332,"[-0.012151699513196945, 0.030259646475315094, ..."
3092,Organic Harvest Unisex Sulphate Free Acne Cont...,Men's Crop Tops,0.235524,"[-0.01811647228896618, -8.517306559951976e-05,..."
3341,Organic Harvest Unisex Sulphate Free Fresh & G...,Women's Sweatshirts,0.244532,"[-0.002115577459335327, 0.0013206246076151729,..."
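
One way to handle such out-of-tree items is to stop forcing a match when the similarity is low. The sketch below reuses the 0.5 threshold from the evaluation; the function `apply_threshold`, the `final_label` column and the "Uncategorized" fallback are hypothetical names, and the toy rows are invented.

```python
import pandas as pd

def apply_threshold(df, threshold=0.5, fallback="Uncategorized"):
    """Replace low-similarity predictions with a fallback label (hypothetical name)."""
    out = df.copy()
    low = out["pred_label_similarity"] <= threshold
    # Keep the predicted label where the score is above the threshold
    out["final_label"] = out["pred_label"].where(~low, fallback)
    return out

toy = pd.DataFrame({
    "item_name": ["Leather Sandals", "Green Tea Facial Mist"],
    "pred_label": ["Men's Sandals", "Men's Raincoats"],
    "pred_label_similarity": [0.62, 0.25],
})
result = apply_threshold(toy)
print(result[["item_name", "final_label"]])
```

Items that land in the fallback bucket could then be reviewed manually or used to decide whether the category tree needs new entries.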


- Since the differences between some categories are nuanced (e.g., "Women's Camisoles" vs. "Women's Tank Tops"), the similarity scores are low for these items.

In [None]:
items_df[items_df['pred_label'].isin(["Women's Camisoles", "Women's Tank Tops"])].head(10)

Unnamed: 0,item_name,pred_label,pred_label_similarity,item_embedding
45,Women - Soie Nude-Coloured Solid Non-Wired Non...,Women's Camisoles,0.483369,"[0.008178001269698143, -0.02102738432586193, -..."
49,Women - PARFAIT Plus Size Blue Solid Underwire...,Women's Camisoles,0.462877,"[0.005946676712483168, -0.022673727944493294, ..."
52,Women - PARFAIT Plus Size Blue Solid Underwire...,Women's Tank Tops,0.488558,"[0.002587677212432027, -0.03328511491417885, -..."
64,Women - PARFAIT Plus Size Red Solid Underwired...,Women's Tank Tops,0.490826,"[0.014129752293229103, -0.03453654795885086, 0..."
78,Women - PARFAIT Plus Size Black Lace Non-Wired...,Women's Camisoles,0.449468,"[0.014254650101065636, -0.03465014323592186, -..."
80,Lady Lyka Women Pack of 2 Beginners Bras TEENA...,Women's Camisoles,0.445808,"[0.04246421903371811, -0.018130652606487274, -..."
84,Women - PARFAIT Plus Size Black Solid Non-Wire...,Women's Tank Tops,0.441572,"[0.016022294759750366, -0.008620144799351692, ..."
100,Women - PARFAIT Plus Size Beige Solid Underwir...,Women's Camisoles,0.462715,"[0.0023894228506833315, -0.025594085454940796,..."
126,Women - PARFAIT Plus Size Blue Lace Underwired...,Women's Camisoles,0.465153,"[0.015861237421631813, -0.02198561653494835, 0..."
127,Lady Lyka Women Peach and Burgundy Pack of 2 E...,Women's Camisoles,0.445377,"[0.05112874135375023, -0.015719115734100342, -..."
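
A possible way to flag such ambiguous cases is to look at the gap between the best and second-best category scores: a small gap suggests the model cannot clearly separate the two. This is a sketch only; `top2_margin` is a hypothetical helper and the scores are made up.

```python
import numpy as np

def top2_margin(similarities):
    """Return (index of best score, best score minus runner-up score) for one item."""
    sims = np.asarray(similarities, dtype=float)
    order = np.argsort(sims)[::-1]       # indices from highest to lowest score
    best, second = order[0], order[1]
    return int(best), float(sims[best] - sims[second])

# Made-up similarity scores of one item against three categories
idx, margin = top2_margin([0.48, 0.47, 0.20])
print(idx, round(margin, 2))
```

Items with a margin below some chosen cut-off could be routed to manual review instead of being assigned automatically.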


### 5.2 Check accuracy

Quickly check the accuracy on a 10% sample of the items with similarity above 0.5 by using GPT-4's ability to understand natural language. Generally, we want it to judge whether `Statement: '{item_name} is a {pred_label}'.` is correct and return True/False.


In [None]:
def check_statement(item_name, pred_label):
    response = client.chat.completions.create(
        model="gpt-4",
        messages=[
            {"role": "user", "content": f"Statement: '{item_name} is a {pred_label}'. Is this statement correct? (True/False)"}
        ],
        max_tokens=5
    )
    res = response.choices[0].message.content.strip().lower()
    # Treat any answer starting with "false" (e.g. 'False', 'False.') as incorrect;
    # everything else (e.g. 'True', 'True.') counts as correct
    return not res.startswith('false')



In [None]:
df_above50 = items_df[items_df['pred_label_similarity'] > 0.5].copy()
sample_df = df_above50.sample(n=int(0.1*len(df_above50)), random_state=42)
sample_df['is_correct'] = sample_df.progress_apply(lambda row: check_statement(row['item_name'], row['pred_label']), axis=1)

  0%|          | 0/837 [00:00<?, ?it/s]

In [None]:
sample_df.head(10)

Unnamed: 0,item_name,pred_label,pred_label_similarity,item_embedding,is_correct
7197,Women - Lavie Olive Green Textured Hobo Bag,Women's Bags,0.647549,"[0.02513021230697632, -0.041580911725759506, -...",True
12227,Pepe Jeans Men Off-White Printed Round Neck T-...,Men's T-Shirts,0.515711,"[0.034798763692379, 0.012704678811132908, -0.0...",True
3626,Free Authority Men White Batman Printed T-shirt,Men's T-Shirts,0.503615,"[0.008314426988363266, 0.028553415089845657, 0...",True
8279,Franco Leone Men Tan Brown Leather Sandals,Men's Sandals,0.624545,"[-0.020072249695658684, -0.0002564014575909823...",True
8252,Mufti Men Black & White Regular Fit Printed Ca...,Men's Dress Shirts,0.524133,"[0.03393406793475151, 0.019687695428729057, -0...",True
3164,DressBerry Women Blue Boyfriend Fit Mid-Rise C...,Women's Jeans,0.576775,"[0.02218696102499962, 0.007113466504961252, -0...",True
1815,Tulsattva Women Sea Green A-Line Dress,Women's Dress Shirts,0.516991,"[0.05970532447099686, 0.02357577160000801, 0.0...",False
8863,Allen Solly Men Black Slim Fit Solid Formal Tr...,Men's Trousers,0.575604,"[0.0047430433332920074, 0.022379977628588676, ...",True
9509,GNIST Women Black Solid Suede Open Toe Flats,Women's Flats,0.65101,"[0.007082749158143997, 0.01575683429837227, -0...",True
5833,GAP Girls Super Skinny Jeans with Fantastiflex,Women's Jeans,0.542189,"[0.015329558402299881, 0.022858276963233948, -...",False


In [None]:
# get accuracy
sample_accuracy = sample_df['is_correct'].sum() / len(sample_df)
print('sample_accuracy: ', sample_accuracy)

sample_accuracy:  0.6690561529271206
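
Because this accuracy comes from a sample, it carries some sampling uncertainty. A rough 95% interval under the normal approximation can be computed as below, using the sample size (837) and accuracy reported above. The interval reflects sampling variation only; it does not account for mistakes made by GPT-4 as the judge.

```python
import math

n = 837                      # sample size shown by the progress bar above
p = 0.6690561529271206       # sample_accuracy printed above
se = math.sqrt(p * (1 - p) / n)          # standard error of a proportion
low, high = p - 1.96 * se, p + 1.96 * se
print(f"accuracy = {p:.3f}, 95% interval = [{low:.3f}, {high:.3f}]")
```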


In [None]:
sample_df.to_csv('/content/drive/MyDrive/TDC__UnlabeledTextClassification_Fashion/sample_df_res.csv', index=False, encoding='utf-8-sig')

## 6. Visualization

In [None]:
filtered_df1 = items_df[items_df['pred_label'].isin(["Men's Oxfords"])][:10].copy()
filtered_df2 = items_df[items_df['pred_label'].isin(["Men's Jewelry"])][:10].copy()

cates_filtered_df1 = cates_df[cates_df['category_name'].isin(filtered_df1['pred_label'])].copy()
cates_filtered_df2 = cates_df[cates_df['category_name'].isin(filtered_df2['pred_label'])].copy()

In [None]:
import plotly.graph_objects as go

# Function to create Plotly 3D scatter plot
def create_3d_plot(df, embedding_col, name_col, color, name):
    # Use the first three embedding dimensions as x, y, z coordinates.
    # Note: this is a slice of the full embedding vector, not a dimensionality reduction.
    coords = df[embedding_col].tolist()  # list of lists, each with at least 3 numbers
    x = [coord[0] for coord in coords]
    y = [coord[1] for coord in coords]
    z = [coord[2] for coord in coords]
    labels = df[name_col].tolist()

    # Create a trace for the scatter plot
    trace = go.Scatter3d(
        x=x,
        y=y,
        z=z,
        mode='markers',
        marker=dict(size=5, color=color),  # You can adjust the size and color here
        text=labels,  # This is what will be shown when hovering
        name=name
    )
    return trace

# Data preparation
items1 = create_3d_plot(filtered_df1, 'item_embedding', 'item_name', 'pink', "Men's Oxfords items")
items2 = create_3d_plot(filtered_df2, 'item_embedding', 'item_name', 'lightblue', "Men's Jewelry items")
cate1 = create_3d_plot(cates_filtered_df1, 'cate_embedding', 'category_name', 'red', cates_filtered_df1['category_name'].values[0])
cate2 = create_3d_plot(cates_filtered_df2, 'cate_embedding', 'category_name', 'blue', cates_filtered_df2['category_name'].values[0])

# Plot configuration
fig = go.Figure(data=[items1, cate1, items2, cate2])
fig.update_layout(
    title="3D Visualization of Items and Categories",
    scene=dict(
        xaxis_title='X Axis',
        yaxis_title='Y Axis',
        zaxis_title='Z Axis'
    ),
    legend_title="Legend"
)
fig.show()
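
The plot above uses the first three coordinates of each embedding, which discards the remaining dimensions. One common alternative, shown here only as a sketch and not as what this notebook does, is to project the vectors onto their first three principal components (PCA) before plotting. The helper `pca_3d` is hypothetical, and random data stands in for real embeddings.

```python
import numpy as np

def pca_3d(embeddings):
    """Project row vectors onto their first three principal components."""
    X = np.asarray(embeddings, dtype=float)
    X = X - X.mean(axis=0)                       # center each dimension
    # Rows of Vt are principal directions, ordered by explained variance
    _, _, Vt = np.linalg.svd(X, full_matrices=False)
    return X @ Vt[:3].T

rng = np.random.default_rng(0)
demo = rng.normal(size=(20, 16))   # random stand-in for real embeddings
coords = pca_3d(demo)
print(coords.shape)
```

To keep items and categories in the same coordinate system, their embeddings would need to be stacked and projected together in a single call.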
