# Graph DB Load

## Setup

In [42]:
%%capture
# Install the notebook's Python dependencies; %%capture suppresses pip's verbose output.
# NOTE(review): versions are unpinned — consider pinning to the versions this notebook was last run with.
%pip install kaggle sentence_transformers langchain openai tiktoken python-dotenv
%pip install --upgrade graphdatascience

In [7]:
# Show the installed graphdatascience client version (useful when reproducing this notebook).
import graphdatascience
graphdatascience.__version__

'1.8'

In [8]:
from graphdatascience import GraphDataScience
from dotenv import load_dotenv
import os
import pandas as pd
import numpy as np

In [9]:
# configure Kaggle API authentication first, per the instructions @ https://github.com/Kaggle/kaggle-api (README, "API credentials")
# Download the three H&M competition files into ./data, then unzip them without overwriting existing files (-n).
!kaggle competitions download -c h-and-m-personalized-fashion-recommendations -f articles.csv -p data
!kaggle competitions download -c h-and-m-personalized-fashion-recommendations -f customers.csv -p data
!kaggle competitions download -c h-and-m-personalized-fashion-recommendations -f transactions_train.csv -p data
!cd data && unzip -n '*.zip'

articles.csv.zip: Skipping, found more recently modified local copy (use --force to force download)
customers.csv.zip: Skipping, found more recently modified local copy (use --force to force download)
transactions_train.csv.zip: Skipping, found more recently modified local copy (use --force to force download)
Archive:  customers.csv.zip

Archive:  articles.csv.zip

Archive:  transactions_train.csv.zip

3 archives were successfully processed.


In [10]:
def camel_case(s):
    """Convert a snake_case name to camelCase, e.g. ``'article_id'`` -> ``'articleId'``.

    The whole name is lower-cased first, so ``'FN'`` becomes ``'fn'``.
    """
    head, *tail = s.lower().split('_')
    return head + ''.join(word.title() for word in tail)

def camel_case_dict(name_keys):
    """Map every original name to its camelCase form."""
    return {name: camel_case(name) for name in name_keys}

def camel_case_rename_cols(df):
    """Return a copy of ``df`` with all column names converted to camelCase."""
    return df.rename(columns=camel_case_dict(df.columns))

## Filtering & Sampling
There are two optional stages for paring down the articles and customers, respectively. Below are the two stages, in the order they are applied:
1. __Article Filtering__: Filter out transactions for intimate products, along with the customers who purchased those products. Removing only part of a customer's history could cause modeling difficulties due to unobserved/partial ground truth. For that reason, customers who purchased these products are removed from the dataset entirely. This stage is controlled by `FILTER_ARTICLES` and produces a list: `filtered_customer_ids`.
2. __Customer Sampling__: Sample a subset of customers. The full dataset includes 1.3 million customers and, as a result, tens of millions of transactions. This can be unwieldy in a demo setting, so we provide a way to sample a smaller number of customers here (`SAMPLE_NUM_CUSTOMERS`). Once customers are sampled, the article IDs are reduced again to only the articles purchased by the sampled customers. This produces two lists used to sample down the data later:
    - `customer_ids`
    - `article_ids`

In [11]:
# --- filtering / sampling configuration ---
RANDOM_SEED = 7474          # seed for reproducible customer sampling
FILTER_ARTICLES = False     # whether to filter out certain intimate products for demo purposes (real data problems)
SAMPLE_NUM_CUSTOMERS = 200  # number of customers to sample; set to 0 or less for no sampling

### Article Filtering

In [12]:
# Raw article catalogue with camelCase column names
init_article_df = pd.read_csv('data/articles.csv').pipe(camel_case_rename_cols)
init_article_df.shape

(105542, 25)

In [13]:
# Article ids kept after the optional intimate-product filter (every article when FILTER_ARTICLES is False)
if FILTER_ARTICLES:
    filtered_article_ids = init_article_df.loc[init_article_df.garmentGroupName != 'Under-, Nightwear', 'articleId']
else:
    filtered_article_ids = init_article_df.articleId

In [14]:
init_transaction_df = pd.read_csv('data/transactions_train.csv').pipe(camel_case_rename_cols)
# the source data has no transaction ids, so give every row a sequential one
init_transaction_df['txId'] = range(len(init_transaction_df))
init_transaction_df.shape

(31788324, 6)

In [15]:
# Keep only customers whose *entire* purchase history consists of retained articles.
# A customer with even one purchase of an excluded article is dropped completely, so no
# customer is left with a partial history (see the Article Filtering notes above).
# When FILTER_ARTICLES is False every article is retained and no customer is excluded.
is_retained_purchase = init_transaction_df.articleId.isin(filtered_article_ids)
excluded_customer_ids = init_transaction_df.loc[~is_retained_purchase, 'customerId'].unique()
filtered_customer_ids = init_transaction_df.loc[
    is_retained_purchase & ~init_transaction_df.customerId.isin(excluded_customer_ids),
    'customerId'
].drop_duplicates()
filtered_customer_ids

0           000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...
2           00007d2de826758b65a93dd24ce629ed66842531df6699...
7           00083cda041544b2fbb0e0d2905ad17da7cf1007526fb4...
12          0008968c0d451dbc5a9968da03196fe20051965edde741...
14          000aa7f0dc06cd7174389e76c9e132a67860c5f65f9706...
                                  ...                        
31788165    fe99a0069d6b3c64c2707d0ce53b9311540917471d82df...
31788202    fecc5f77b5f7ee4570efde9ab05ec94d0de2bf80efb4f6...
31788208    fece2f68864c311a0b5208e2eb735b3dcde7e41461d327...
31788217    fee56cc5315dafb35a4490ccc6f711092cae913550c832...
31788275    ff5b8a8b26bf93a66290e9bd1b73393ac6a58968a78519...
Name: customerId, Length: 1362281, dtype: object

### Customer Sampling

In [16]:
# Optionally down-sample customers for a lighter-weight demo graph.
# The sample size is capped at the number of available customers, so an oversized
# SAMPLE_NUM_CUSTOMERS takes everyone instead of raising a ValueError from .sample().
customer_ids = filtered_customer_ids
if SAMPLE_NUM_CUSTOMERS > 0:
    sample_size = min(SAMPLE_NUM_CUSTOMERS, len(filtered_customer_ids))
    customer_ids = filtered_customer_ids.sample(n=sample_size, random_state=RANDOM_SEED).reset_index(drop=True)
customer_ids

0      fdbe75e71e134938025dbbb9bc495bd302d578b449ac96...
1      fb9310441b653525f1adad3fbe7ece522ba50e752cca62...
2      f0a8599239eea199f1440af86ab9df78cb5d4e85f532fd...
3      4dd8a1b3175c88f07b123b388a5c9b5dfe16b3ba6fdf62...
4      696093ad8815f16ab92c07eb32d69c2d1e90daef479de7...
                             ...                        
195    dc33f78aea18f5a6048e83f5f3908c1dc0ee410f9ed321...
196    95fc2a180dac51f99de9c85df4e649025ccec12b0fd195...
197    1961b430b57868ceb5863e521192181a26e3a4e3e30f11...
198    ca9aa63d9814d804b80fefdfdf20e733db6c60463c8564...
199    2afaeb2ae4826bd7bfdcb29e748da2b55d53ba390d8040...
Name: customerId, Length: 200, dtype: object

In [17]:
# Distinct articles bought by the selected customers
article_ids = init_transaction_df.loc[init_transaction_df.customerId.isin(customer_ids), 'articleId'].drop_duplicates()
article_ids

7734        615021023
7735        697489002
7736        661166005
29406       612891003
29407       504154015
              ...    
31779121    906639004
31779122    684238003
31779124    921266007
31779125    812530004
31779126    942187001
Name: articleId, Length: 3570, dtype: int64

## Sample Down Data
Now that we have the lists of customers and articles to include (`customer_ids` and `article_ids`, respectively), we can use them to filter the source data and stage it for loading.

In [18]:
# All transactions made by the selected customers
transaction_df = init_transaction_df.loc[init_transaction_df.customerId.isin(customer_ids)]
transaction_df

Unnamed: 0,tDat,customerId,articleId,price,salesChannelId,txId
7734,2018-09-20,2a55ebfc7d91dafb5d75d32d4027ee99005d2c98b40711...,615021023,0.016932,1,7734
7735,2018-09-20,2a55ebfc7d91dafb5d75d32d4027ee99005d2c98b40711...,697489002,0.013542,1,7735
7736,2018-09-20,2a55ebfc7d91dafb5d75d32d4027ee99005d2c98b40711...,661166005,0.050831,1,7736
29406,2018-09-20,99e17922e68b9c627587e57e420c887243240cb9d61e2f...,612891003,0.028797,2,29406
29407,2018-09-20,99e17922e68b9c627587e57e420c887243240cb9d61e2f...,504154015,0.015237,2,29407
...,...,...,...,...,...,...
31779123,2020-09-22,b6be55f233772b5fc4a1ebedf36542fb3e1b6c15c23c7e...,800691008,0.011847,2,31779123
31779124,2020-09-22,b6be55f233772b5fc4a1ebedf36542fb3e1b6c15c23c7e...,921266007,0.016932,2,31779124
31779125,2020-09-22,b6be55f233772b5fc4a1ebedf36542fb3e1b6c15c23c7e...,812530004,0.010153,2,31779125
31779126,2020-09-22,b6be55f233772b5fc4a1ebedf36542fb3e1b6c15c23c7e...,942187001,0.016932,2,31779126


In [19]:
# Catalogue rows for the articles that appear in the selected transactions
full_article_df = init_article_df.loc[init_article_df.articleId.isin(article_ids)]
full_article_df

Unnamed: 0,articleId,productCode,prodName,productTypeNo,productTypeName,productGroupName,graphicalAppearanceNo,graphicalAppearanceName,colourGroupCode,colourGroupName,...,departmentName,indexCode,indexName,indexGroupNo,indexGroupName,sectionNo,sectionName,garmentGroupNo,garmentGroupName,detailDesc
0,108775015,108775,Strap top,253,Vest top,Garment Upper body,1010016,Solid,9,Black,...,Jersey Basic,A,Ladieswear,1,Ladieswear,16,Womens Everyday Basics,1002,Jersey Basic,Jersey top with narrow shoulder straps.
1,108775044,108775,Strap top,253,Vest top,Garment Upper body,1010016,Solid,10,White,...,Jersey Basic,A,Ladieswear,1,Ladieswear,16,Womens Everyday Basics,1002,Jersey Basic,Jersey top with narrow shoulder straps.
6,111565001,111565,20 den 1p Stockings,304,Underwear Tights,Socks & Tights,1010016,Solid,9,Black,...,Tights basic,B,Lingeries/Tights,1,Ladieswear,62,"Womens Nightwear, Socks & Tigh",1021,Socks and Tights,"Semi shiny nylon stockings with a wide, reinfo..."
8,111586001,111586,Shape Up 30 den 1p Tights,273,Leggings/Tights,Garment Lower body,1010016,Solid,9,Black,...,Tights basic,B,Lingeries/Tights,1,Ladieswear,62,"Womens Nightwear, Socks & Tigh",1021,Socks and Tights,Tights with built-in support to lift the botto...
9,111593001,111593,Support 40 den 1p Tights,304,Underwear Tights,Socks & Tights,1010016,Solid,9,Black,...,Tights basic,B,Lingeries/Tights,1,Ladieswear,62,"Womens Nightwear, Socks & Tigh",1021,Socks and Tights,"Semi shiny tights that shape the tummy, thighs..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
105185,934054002,934054,Sicilly top,258,Blouse,Garment Upper body,1010016,Solid,6,Light Grey,...,Blouse,A,Ladieswear,1,Ladieswear,15,Womens Everyday Collection,1010,Blouses,V-neck blouse in woven fabric with a slight sh...
105221,934727002,934727,Kiara top,254,Top,Garment Upper body,1010021,Lace,9,Black,...,Jersey fancy,A,Ladieswear,1,Ladieswear,15,Womens Everyday Collection,1005,Jersey Fancy,Lace blouse with a stand-up collar and an open...
105314,936862001,936862,EDC Marla dress,265,Dress,Garment Full body,1010001,All over pattern,52,Pink,...,Campaigns,A,Ladieswear,1,Ladieswear,15,Womens Everyday Collection,1023,Special Offers,Calf-length dress in a patterned Tencel™ lyoce...
105319,936979001,936979,Class Filippa Necklace,77,Necklace,Accessories,1010016,Solid,5,Gold,...,Jewellery,C,Ladies Accessories,1,Ladieswear,66,Womens Small accessories,1019,Accessories,Metal chain necklace with a pendant. Adjustabl...


### Creating Product, Department, and Article Dataframes

In [20]:
# One row per productCode (first article of each product wins) with product-level attributes
product_df = (
    full_article_df
    .drop_duplicates(subset='productCode')
    [['productCode', 'prodName', 'productTypeNo', 'productTypeName',
      'productGroupName', 'garmentGroupNo', 'garmentGroupName', 'detailDesc']]
)
product_df

Unnamed: 0,productCode,prodName,productTypeNo,productTypeName,productGroupName,garmentGroupNo,garmentGroupName,detailDesc
0,108775,Strap top,253,Vest top,Garment Upper body,1002,Jersey Basic,Jersey top with narrow shoulder straps.
6,111565,20 den 1p Stockings,304,Underwear Tights,Socks & Tights,1021,Socks and Tights,"Semi shiny nylon stockings with a wide, reinfo..."
8,111586,Shape Up 30 den 1p Tights,273,Leggings/Tights,Garment Lower body,1021,Socks and Tights,Tights with built-in support to lift the botto...
9,111593,Support 40 den 1p Tights,304,Underwear Tights,Socks & Tights,1021,Socks and Tights,"Semi shiny tights that shape the tummy, thighs..."
17,118458,Jerry jogger bottoms,272,Trousers,Garment Lower body,1002,Jersey Basic,Trousers in sweatshirt fabric with an elastica...
...,...,...,...,...,...,...,...,...
105185,934054,Sicilly top,258,Blouse,Garment Upper body,1010,Blouses,V-neck blouse in woven fabric with a slight sh...
105221,934727,Kiara top,254,Top,Garment Upper body,1005,Jersey Fancy,Lace blouse with a stand-up collar and an open...
105314,936862,EDC Marla dress,265,Dress,Garment Full body,1023,Special Offers,Calf-length dress in a patterned Tencel™ lyoce...
105319,936979,Class Filippa Necklace,77,Necklace,Accessories,1019,Accessories,Metal chain necklace with a pendant. Adjustabl...


In [21]:
# One row per departmentNo (first occurrence wins).
# NOTE(review): if a departmentNo ever appears with more than one section, only the first is kept — confirm intended.
department_df = (
    full_article_df
    .drop_duplicates(subset='departmentNo')
    [['departmentNo', 'departmentName', 'sectionNo', 'sectionName']]
)
department_df

Unnamed: 0,departmentNo,departmentName,sectionNo,sectionName
0,1676,Jersey Basic,16,Womens Everyday Basics
6,3608,Tights basic,62,"Womens Nightwear, Socks & Tigh"
17,5883,Jersey Basic,26,Men Underwear
23,2032,Jersey,8,Mama
46,6515,Baby basics,44,Baby Essentials & Complements
...,...,...,...,...
80900,7848,Boys Local Relevance,48,Kids Local Relevance
84934,6525,Baby Girl Jersey Fancy,40,Baby Girl
85783,7912,Asia Assortment,15,Womens Everyday Collection
93600,5626,AK Dresses & Outdoor,70,Divided Asia keys


In [22]:
# Article-level attributes plus the productCode / departmentNo columns used for relationships later
article_df = full_article_df.loc[:, ['articleId', 'productCode', 'departmentNo', 'prodName', 'productTypeName',
                                     'graphicalAppearanceNo', 'graphicalAppearanceName',
                                     'colourGroupCode', 'colourGroupName']]
article_df

Unnamed: 0,articleId,productCode,departmentNo,prodName,productTypeName,graphicalAppearanceNo,graphicalAppearanceName,colourGroupCode,colourGroupName
0,108775015,108775,1676,Strap top,Vest top,1010016,Solid,9,Black
1,108775044,108775,1676,Strap top,Vest top,1010016,Solid,10,White
6,111565001,111565,3608,20 den 1p Stockings,Underwear Tights,1010016,Solid,9,Black
8,111586001,111586,3608,Shape Up 30 den 1p Tights,Leggings/Tights,1010016,Solid,9,Black
9,111593001,111593,3608,Support 40 den 1p Tights,Underwear Tights,1010016,Solid,9,Black
...,...,...,...,...,...,...,...,...,...
105185,934054002,934054,1522,Sicilly top,Blouse,1010016,Solid,6,Light Grey
105221,934727002,934727,1636,Kiara top,Top,1010021,Lace,9,Black
105314,936862001,936862,3090,EDC Marla dress,Dress,1010001,All over pattern,52,Pink
105319,936979001,936979,4344,Class Filippa Necklace,Necklace,1010016,Solid,5,Gold


### Create Customer Dataframe

In [23]:
# Customer attributes, restricted to the selected customers
customer_df = (
    pd.read_csv('data/customers.csv')
    .pipe(camel_case_rename_cols)
    .loc[lambda frame: frame.customerId.isin(customer_ids)]
)
customer_df

Unnamed: 0,customerId,fn,active,clubMemberStatus,fashionNewsFrequency,age,postalCode
6432,0132cd2eb3c6b1f66784f65f94ddd8352add2653e0caf5...,,,ACTIVE,NONE,49.0,49f7ec29bcacbbf2120af5162f9f99c212e9dd26b48d79...
7756,01713f103284ab2ba16f525802a6e33b69ffe216ad5b4d...,1.0,1.0,ACTIVE,Regularly,33.0,fdedb8bcffaf04924840f6657e10fe7b451f52c6bf1824...
28090,0544e776969215ad79e12fa0407b0e6a4cb53182513744...,1.0,1.0,ACTIVE,Regularly,24.0,d8ce1979cc44c786217c8130fe367e4b4d088c0bc88d65...
41270,07b80dee1cba3ff8ec0bc1ac13b1d37dd0d5cc0c3e934b...,,,ACTIVE,NONE,28.0,7dcb5c70c6e2515284cbc66bf92f2009bba0197320a276...
42401,07f0a353d7ff555a1c0d94369dfa40e8f402cb1a32507a...,,,ACTIVE,NONE,36.0,2de0cc9ede68389795f4dbe6949a1d23d2d17838e4eec6...
...,...,...,...,...,...,...,...
1359974,fdbe75e71e134938025dbbb9bc495bd302d578b449ac96...,1.0,1.0,ACTIVE,Regularly,22.0,4d15039e3b9f409da4f8de9bdb10ea8991e534c91b0020...
1361059,fdf1294f414faac2b00a725f5d80c34f98a744d9b8b3ce...,,,ACTIVE,NONE,32.0,0cd87888c3a13ebbb1e90cac6b9fbf34c51afa40865f55...
1366543,fef793ec3a7d62d782824517355d74ded50964dce33009...,,,ACTIVE,NONE,46.0,5799a39cffe701ebdb12181348bf10f9e23abcc3868c43...
1370498,ffb925b11e1bb2e375d22a02d67907994eb8cb92ec2e7d...,,,ACTIVE,NONE,34.0,ebdd8c5c893683c3cf52c011d4e35024e46d183c95f0fa...


## Graph Loading

In [24]:
load_dotenv('.env', override=True)

# AURA_DS is parsed as a plain boolean flag. The previous eval(...) executed arbitrary code
# taken from the environment and crashed with AttributeError when the variable was unset.
aura_ds = os.getenv('AURA_DS', 'false').strip().lower() in ('true', '1', 'yes')

# Use Neo4j URI and credentials according to our setup
gds = GraphDataScience(
    os.getenv('NEO4J_URI'),
    auth=(os.getenv('NEO4J_USERNAME'),
          os.getenv('NEO4J_PASSWORD')),
    aura_ds=aura_ds)

# Necessary if you enabled Arrow on the db - this is true for AuraDS
gds.set_database("neo4j")

In [25]:
# Sanity-check the connection by asking the server for its GDS version.
gds.version()

'2.4.7+36'

### Create Uniqueness Constraints (and Their Backing Indexes)

In [26]:
# one uniqueness constraint per node label, on the property each loader MERGEs on
for constraint_name, label, key_prop in [
    ('unique_department_no', 'Department', 'departmentNo'),
    ('unique_product_code', 'Product', 'productCode'),
    ('unique_article_id', 'Article', 'articleId'),
    ('unique_customer_id', 'Customer', 'customerId'),
]:
    gds.run_cypher(f'CREATE CONSTRAINT {constraint_name} IF NOT EXISTS FOR (n:{label}) REQUIRE n.{key_prop} IS UNIQUE')

In [27]:
from typing import Tuple, Union
from numpy.typing import ArrayLike

def make_map(x):
    """Normalize a label/key spec into a 2-tuple.

    A string is duplicated (``'articleId'`` -> ``('articleId', 'articleId')``); a tuple is
    returned unchanged. Callers index the result with ``[0]`` and ``[1]``: e.g.
    ``(source_label, target_label)`` or ``(node_property, record_column)``.

    Raises:
        TypeError: if ``x`` is neither a string nor a tuple.
        ValueError: if ``x`` is a tuple that does not have exactly two entries.
    """
    if isinstance(x, str):
        return x, x
    if isinstance(x, tuple):
        if len(x) != 2:
            raise ValueError(f"Tuple entry must have exactly 2 elements, got {len(x)}")
        return x
    raise TypeError("Entry must be of type string or tuple")


def make_set_clause(prop_names: ArrayLike, element_name='n', item_name='rec'):
    """Build a Cypher ``SET`` clause copying each named property from the record onto the element.

    Example: ``make_set_clause(['a', 'b'])`` -> ``'SET n.a = rec.a, n.b = rec.b'``.
    """
    assignments = ', '.join(f'{element_name}.{prop} = {item_name}.{prop}' for prop in prop_names)
    return f'SET {assignments}'


def make_node_merge_query(node_key_name: str, node_label: str, cols: ArrayLike):
    """Build an UNWIND/MERGE query that upserts ``node_label`` nodes keyed on ``node_key_name``.

    Every other column in ``cols`` is copied onto the node with a SET clause. The query
    expects a ``$recs`` list parameter and returns ``nodeLoadedCount`` (rows processed).
    """
    other_props = [col for col in cols if col != node_key_name]
    lines = ['UNWIND $recs AS rec',
             f'MERGE(n:{node_label} {{{node_key_name}: rec.{node_key_name}}})']
    if other_props:
        lines.append(make_set_clause(other_props))
    lines.append('RETURN count(n) AS nodeLoadedCount')
    return '\n'.join(lines)


def make_rel_merge_query(source_target_labels: Union[Tuple[str, str], str],
                         source_node_key: Union[Tuple[str, str], str],
                         target_node_key: Union[Tuple[str, str], str],
                         rel_type: str,
                         cols: ArrayLike,
                         rel_key: str = None):
    """Build an UNWIND/MATCH/MERGE query that upserts ``rel_type`` relationships.

    Source and target nodes are MATCHed (not created) by their key property; each key spec is
    either a name shared by node property and record column, or ``(node_property, record_column)``.
    If ``rel_key`` is given it becomes an identifying property inside the MERGE pattern.
    Remaining columns (not the key columns or ``rel_key``) are SET on the relationship.
    The query expects a ``$recs`` list parameter and returns ``relLoadedCount``.
    """
    labels = make_map(source_target_labels)
    src_key = make_map(source_node_key)
    tgt_key = make_map(target_node_key)

    rel_pattern = rel_type if rel_key is None else f'{rel_type} {{{rel_key}: rec.{rel_key}}}'
    lines = [
        '\tUNWIND $recs AS rec',
        f'    MATCH(s:{labels[0]} {{{src_key[0]}: rec.{src_key[1]}}})',
        f'    MATCH(t:{labels[1]} {{{tgt_key[0]}: rec.{tgt_key[1]}}})',
        f'\tMERGE(s)-[r:{rel_pattern}]->(t)',
    ]

    skip_cols = (rel_key, src_key[1], tgt_key[1])
    rel_props = [col for col in cols if col not in skip_cols]
    if rel_props:
        lines.append('\t' + make_set_clause(rel_props, 'r'))
    lines.append('\tRETURN count(r) AS relLoadedCount')
    return '\n'.join(lines)


def chunks(xs, n=50_000):
    """Split a sliceable sequence into consecutive pieces of at most ``n`` items.

    Values of ``n`` below 1 are treated as 1. The last piece may be shorter, and an
    empty input yields an empty list.
    """
    size = n if n > 1 else 1
    pieces = []
    for start in range(0, len(xs), size):
        pieces.append(xs[start:start + size])
    return pieces


def load_nodes(gds: GraphDataScience, node_df: pd.DataFrame, node_key_col: str, node_label: str, chunk_size=50_000):
    """MERGE one ``node_label`` node per row of ``node_df``, keyed on ``node_key_col``.

    All other columns are SET as node properties. Rows are sent in batches of ``chunk_size``
    through ``gds.run_cypher`` and progress is printed after each batch.

    Float NaN values (pandas' representation of empty CSV cells) are sent as ``null`` so
    that Neo4j leaves the property unset instead of storing a NaN value.
    """
    import math

    records = [
        {col: (None if isinstance(val, float) and math.isnan(val) else val) for col, val in rec.items()}
        for rec in node_df.to_dict('records')
    ]
    print(f'======  loading {node_label} nodes  ======')
    total = len(records)
    print(f'staging {total:,} records')
    query = make_node_merge_query(node_key_col, node_label, node_df.columns.copy())
    cumulative_count = 0
    for recs in chunks(records, chunk_size):
        res = gds.run_cypher(query, params={'recs': recs})
        # nodeLoadedCount counts processed rows, including rows that matched an existing node
        cumulative_count += res.iloc[0, 0]
        print(f'Loaded {cumulative_count:,} of {total:,} nodes')


def load_rels(gds: GraphDataScience,
              rel_df: pd.DataFrame,
              source_target_labels: Union[Tuple[str, str], str],
              source_node_key: Union[Tuple[str, str], str],
              target_node_key: Union[Tuple[str, str], str],
              rel_type: str,
              rel_key: str = None,
              chunk_size=50_000):
    """MERGE one ``rel_type`` relationship per row of ``rel_df`` between already-existing nodes.

    Args:
        gds: connected GraphDataScience client used to run the Cypher.
        rel_df: one row per relationship; must contain the source/target key columns.
        source_target_labels: ``(source_label, target_label)``, or one label used for both.
        source_node_key: source key property name, or ``(node_property, record_column)``.
        target_node_key: target key property name, or ``(node_property, record_column)``.
        rel_type: relationship type to MERGE.
        rel_key: optional column used as an identifying property in the MERGE pattern.
            Without it, MERGE reuses any existing ``rel_type`` relationship between the
            same pair of nodes, so repeated pairs collapse into one relationship.
        chunk_size: number of rows sent per ``run_cypher`` call.

    Remaining columns are SET as relationship properties. Float NaN values are sent as
    ``null`` so that Neo4j does not store NaN as a property value.
    """
    import math

    records = [
        {col: (None if isinstance(val, float) and math.isnan(val) else val) for col, val in rec.items()}
        for rec in rel_df.to_dict('records')
    ]
    print(f'======  loading {rel_type} relationships  ======')
    total = len(records)
    print(f'staging {total:,} records')
    query = make_rel_merge_query(source_target_labels, source_node_key,
                                 target_node_key, rel_type, rel_df.columns.copy(), rel_key)
    cumulative_count = 0
    for recs in chunks(records, chunk_size):
        res = gds.run_cypher(query, params={'recs': recs})
        # relLoadedCount counts rows whose endpoints were matched (including rows merged onto
        # an existing relationship), not the number of newly created relationships
        cumulative_count += res.iloc[0, 0]
        print(f'Loaded {cumulative_count:,} of {total:,} relationships')

### Load Nodes

In [28]:
%%time
# MERGE one :Department node per departmentNo; the remaining columns become node properties.
load_nodes(gds, department_df, 'departmentNo', 'Department')

staging 212 records
Loaded 212 of 212 nodes
CPU times: user 7.98 ms, sys: 2.57 ms, total: 10.5 ms
Wall time: 324 ms


In [29]:
%%time
# MERGE one :Product node per productCode; the remaining columns become node properties.
load_nodes(gds, product_df, 'productCode', 'Product')

staging 2,583 records
Loaded 2,583 of 2,583 nodes
CPU times: user 77 ms, sys: 5.3 ms, total: 82.3 ms
Wall time: 830 ms


In [30]:
%%time
# productCode / departmentNo are not stored on Article nodes; they are represented by the
# VARIANT_OF / FROM_DEPARTMENT relationships loaded below.
load_nodes(gds, article_df.drop(columns=['productCode', 'departmentNo']), 'articleId', 'Article')

staging 3,570 records
Loaded 3,570 of 3,570 nodes
CPU times: user 79.4 ms, sys: 4.05 ms, total: 83.4 ms
Wall time: 456 ms


In [31]:
%%time
# MERGE one :Customer node per selected customer, keyed on customerId.
load_nodes(gds, customer_df, 'customerId', 'Customer')

staging 200 records
Loaded 200 of 200 nodes
CPU times: user 10.1 ms, sys: 2.28 ms, total: 12.3 ms
Wall time: 318 ms


### Load Relationships

In [32]:
%%time
# Link each Article to its Department. Both endpoints are MATCHed, so the nodes must already be loaded.
load_rels(gds, article_df[['articleId', 'departmentNo']], source_target_labels=('Article', 'Department'),
          source_node_key='articleId', target_node_key='departmentNo',
          rel_type='FROM_DEPARTMENT')

staging 3,570 records
Loaded 3,570 of 3,570 relationships
CPU times: user 68.9 ms, sys: 30.7 ms, total: 99.6 ms
Wall time: 423 ms


In [33]:
%%time
# Link each Article (a colour/pattern variant) to its Product. Both endpoints must already be loaded.
load_rels(gds, article_df[['articleId', 'productCode']], source_target_labels=('Article', 'Product'),
          source_node_key='articleId',target_node_key='productCode',
          rel_type='VARIANT_OF')

staging 3,570 records
Loaded 3,570 of 3,570 relationships
CPU times: user 37 ms, sys: 2.72 ms, total: 39.8 ms
Wall time: 289 ms


In [34]:
%%time
load_rels(gds, transaction_df, source_target_labels=('Customer', 'Article'),
          source_node_key='customerId', target_node_key='articleId',
          rel_type='PURCHASED')

staging 4,693 records
Loaded 4,693 of 4,693 relationships
CPU times: user 79.8 ms, sys: 3.59 ms, total: 83.4 ms
Wall time: 481 ms


## Text Embedding Loading
For now, we embed a concatenation of product fields (name, type, group, and description).

In [35]:
from langchain.embeddings import OpenAIEmbeddings, BedrockEmbeddings, SentenceTransformerEmbeddings

In [36]:
def load_embedding_model(embedding_model_name: str):
    """Return ``(embeddings, dimension)`` for the requested embedding backend.

    ``"openai"`` selects OpenAIEmbeddings and ``"aws"`` selects BedrockEmbeddings, both declared
    as 1536-dimensional. Any other value, including ``None``, falls back to the local
    sentence-transformers model ``all-MiniLM-L6-v2`` (384-dimensional).
    """
    if embedding_model_name == "openai":
        return OpenAIEmbeddings(), 1536
    if embedding_model_name == "aws":
        return BedrockEmbeddings(), 1536
    local_model = SentenceTransformerEmbeddings(
        model_name="all-MiniLM-L6-v2", cache_folder="/embedding_model"
    )
    return local_model, 384

In [40]:
# EMBEDDING_MODEL picks the backend: "openai", "aws", or anything else for local sentence-transformers
embedding_model, dimension = load_embedding_model(embedding_model_name=os.getenv('EMBEDDING_MODEL'))

In [108]:
# Products with a description, restricted to the fields used to build the embedding text.
# .copy() makes this an independent frame, so adding new columns to it later does not raise
# pandas' SettingWithCopyWarning (a boolean-mask selection is otherwise flagged as a copy).
product_emb_df = product_df.loc[product_df.detailDesc.notnull(),
                                ['productCode', 'prodName', 'productTypeName', 'garmentGroupName', 'detailDesc']].copy()

In [109]:
def create_doc(row):
    """Render a product row as the labelled text block that gets embedded.

    The result starts with a newline and has one ``Label: value`` line each for
    name, type, garment group and description.
    """
    fields = (
        ('Name', row.prodName),
        ('Type', row.productTypeName),
        ('Group', row.garmentGroupName),
        ('Description', row.detailDesc),
    )
    return '\n' + ''.join(f'{label}: {value}\n' for label, value in fields)

# Keep only the key and the rendered text document for each product
product_emb_df = (
    product_emb_df
    .assign(doc=product_emb_df.apply(create_doc, axis=1))
    .drop(columns=['prodName', 'productTypeName', 'garmentGroupName', 'detailDesc'])
)
product_emb_df

Unnamed: 0,productCode,doc
0,108775,\nName: Strap top\nType: Vest top\nGroup: Jers...
6,111565,\nName: 20 den 1p Stockings\nType: Underwear T...
8,111586,\nName: Shape Up 30 den 1p Tights\nType: Leggi...
9,111593,\nName: Support 40 den 1p Tights\nType: Underw...
17,118458,\nName: Jerry jogger bottoms\nType: Trousers\n...
...,...,...
105185,934054,\nName: Sicilly top\nType: Blouse\nGroup: Blou...
105221,934727,\nName: Kiara top\nType: Top\nGroup: Jersey Fa...
105314,936862,\nName: EDC Marla dress\nType: Dress\nGroup: S...
105319,936979,\nName: Class Filippa Necklace\nType: Necklace...


In [111]:
%%time

#Yeah....we should do this in batching/chunks...I think there is a way with langchain
count = 0
embeddings = []
for txt in product_emb_df.doc:
    count += 1
    if count%200 == 0:
       print(f'Embedded {count} of {product_emb_df.shape[0]}')
    embeddings.append(embedding_model.embed_query(txt))

Embedded 200 of 2578
Embedded 400 of 2578
Embedded 600 of 2578
Embedded 800 of 2578
Embedded 1000 of 2578
Embedded 1200 of 2578
Embedded 1400 of 2578
Embedded 1600 of 2578
Embedded 1800 of 2578
Embedded 2000 of 2578
Embedded 2200 of 2578
Embedded 2400 of 2578
CPU times: user 14.1 s, sys: 3.17 s, total: 17.3 s
Wall time: 8min 3s


In [112]:
# Attach each product's embedding vector (same row order as the documents that were embedded)
product_emb_df = product_emb_df.assign(textEmbedding=embeddings)

In [114]:
%%time
# Re-MERGE Product nodes by productCode to add the textEmbedding property.
# Products without a detailDesc were filtered out above and get no embedding.
load_nodes(gds, product_emb_df[['productCode', 'textEmbedding']], 'productCode', 'Product')

staging 2,578 records
Loaded 2,578 of 2,578 nodes
CPU times: user 4.28 s, sys: 230 ms, total: 4.51 s
Wall time: 8.94 s


In [63]:
# Create the cosine vector index over Product.textEmbedding only if it does not exist yet,
# so re-running the notebook does not fail on an already-existing index. The index name and
# dimension are passed as query parameters rather than interpolated into the Cypher text.
VECTOR_INDEX_NAME = 'product-text-embeddings'
existing_index = gds.run_cypher(f"SHOW INDEXES YIELD name WHERE name = '{VECTOR_INDEX_NAME}' RETURN name")
if existing_index.empty:
    gds.run_cypher('CALL db.index.vector.createNodeIndex($name, "Product", "textEmbedding", $dimension, "cosine")',
                   params={'name': VECTOR_INDEX_NAME, 'dimension': dimension})