# Graph DB Load

## Setup

In [11]:
%%capture
%pip install kaggle sentence_transformers langchain openai tiktoken python-dotenv
%pip install --upgrade graphdatascience

In [12]:
import graphdatascience
graphdatascience.__version__

'1.8'

In [13]:
from graphdatascience import GraphDataScience
from dotenv import load_dotenv
import os
import pandas as pd
import numpy as np

In [14]:
# configure authentication per instruction @ https://github.com/Kaggle/kaggle-api/README.md
!kaggle competitions download -c h-and-m-personalized-fashion-recommendations -f articles.csv -p data
!kaggle competitions download -c h-and-m-personalized-fashion-recommendations -f customers.csv -p data
!kaggle competitions download -c h-and-m-personalized-fashion-recommendations -f transactions_train.csv -p data
!cd data && unzip -n '*.zip'

articles.csv.zip: Skipping, found more recently modified local copy (use --force to force download)
customers.csv.zip: Skipping, found more recently modified local copy (use --force to force download)
transactions_train.csv.zip: Skipping, found more recently modified local copy (use --force to force download)
Archive:  transactions_train.csv.zip

Archive:  customers.csv.zip

Archive:  articles.csv.zip

3 archives were successfully processed.


In [15]:
def camel_case(s):
    ss = s.lower().split('_')
    if len(ss) == 1:
        return ss[0]
    return ss[0] + ''.join(st.title() for st in ss[1:])

def camel_case_dict(name_keys):
    name_values = [camel_case(s) for s in name_keys]
    return dict(zip(name_keys, name_values))

def camel_case_rename_cols(df):
    col_map = camel_case_dict(df.columns)
    return df.rename(columns=col_map)
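The renaming helpers are easiest to follow by example. The snippet below restates `camel_case` and `camel_case_dict` in a simplified form so it runs on its own. The column names passed in are illustrative snake_case headers of the kind found in the Kaggle CSVs.

```python
def camel_case(s):
    # lower-case, split on underscores, then capitalise every part after the first
    parts = s.lower().split('_')
    return parts[0] + ''.join(p.title() for p in parts[1:])


def camel_case_dict(name_keys):
    # map each original column name to its camelCase form
    return {name: camel_case(name) for name in name_keys}


# illustrative column names (snake_case / upper case, as in the source CSVs)
print(camel_case_dict(['t_dat', 'customer_id', 'article_id', 'sales_channel_id', 'FN']))
```

Single-word names such as `FN` are simply lower-cased, which is why the customer columns come out as `fn` and `active`.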

## Filtering & Sampling
There are two optional stages for paring down articles and customers, respectively. They are applied in the order below:
1. __Article Filtering__: Filter out intimate products, along with the customers who purchased them. Because removing only part of a customer's history could cause modeling difficulties (the ground truth would be partially unobserved), customers who purchased these products are removed from the dataset entirely. __This produces one list: `filtered_customer_ids`.__
2. __Customer Sampling__: Sample a subset of customers. The full dataset includes about 1.36 million customers with transactions and, as a result, roughly 31.8 million transactions. This can be unwieldy in a demo setting, so this stage samples a smaller number of customers. Once customers are sampled, article IDs are pared down again to only the articles purchased by the sampled customers. This produces two lists used for sampling later:
    - `customer_ids`
    - `article_ids`
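The two stages can be sketched on toy data in plain Python. This is a minimal illustration of the logic only: the customer and article values are hypothetical, and `random.Random` stands in for pandas' `sample`, so the chosen customers will not match the notebook's.

```python
import random

# Toy (customerId, articleId) purchase pairs; hypothetical values
transactions = [
    ('c1', 'a1'), ('c1', 'a2'),
    ('c2', 'a3'),
    ('c3', 'a1'), ('c3', 'a4'),
    ('c4', 'a2'),
]
excluded_articles = {'a4'}  # e.g. articles in the 'Under-, Nightwear' garment group


def filter_customers(transactions, excluded_articles):
    """Stage 1: keep only customers who never bought an excluded article."""
    all_customers = {c for c, _ in transactions}
    tainted = {c for c, a in transactions if a in excluded_articles}
    return sorted(all_customers - tainted)


def sample_customers(customers, n, seed):
    """Stage 2a: sample n customers reproducibly (no sampling when n <= 0)."""
    if n <= 0:
        return list(customers)
    return random.Random(seed).sample(customers, n)


def articles_for(transactions, customers):
    """Stage 2b: keep only the articles bought by the given customers."""
    keep = set(customers)
    return sorted({a for c, a in transactions if c in keep})


filtered_customer_ids = filter_customers(transactions, excluded_articles)
customer_ids = sample_customers(filtered_customer_ids, n=2, seed=7474)
article_ids = articles_for(transactions, customer_ids)
print(filtered_customer_ids, customer_ids, article_ids)
```

Customer `c3` is dropped in stage 1 even though one of its purchases (`a1`) is allowed, which is what keeps every remaining customer's history complete.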

In [16]:
SAMPLE_NUM_CUSTOMERS = 2000 #set to 0 or less for no sampling
FILTER_ARTICLES = False #whether to filter out intimate products, and every customer who bought them, for demo purposes
RANDOM_SEED = 7474

### Article Filtering

In [17]:
init_article_df = camel_case_rename_cols(pd.read_csv('data/articles.csv'))
init_article_df.shape

(105542, 25)

In [18]:
# Filtering out some intimate products for demo purposes
filtered_article_ids = init_article_df.articleId
if FILTER_ARTICLES:
    filtered_article_ids = init_article_df[init_article_df.garmentGroupName != 'Under-, Nightwear'].articleId

In [19]:
init_transaction_df = camel_case_rename_cols(pd.read_csv('data/transactions_train.csv'))
# add a transaction ID (the source data has none) so each purchase can be identified individually
init_transaction_df['txId'] = range(init_transaction_df.shape[0])
init_transaction_df.shape

(31788324, 6)

In [20]:
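# drop every customer who bought any excluded article (a no-op when FILTER_ARTICLES is False)
excluded_customer_ids = init_transaction_df[~init_transaction_df.articleId.isin(filtered_article_ids)].customerId.unique()
filtered_customer_ids = init_transaction_df[~init_transaction_df.customerId.isin(excluded_customer_ids)].customerId.drop_duplicates()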
filtered_customer_ids

0           000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...
2           00007d2de826758b65a93dd24ce629ed66842531df6699...
7           00083cda041544b2fbb0e0d2905ad17da7cf1007526fb4...
12          0008968c0d451dbc5a9968da03196fe20051965edde741...
14          000aa7f0dc06cd7174389e76c9e132a67860c5f65f9706...
                                  ...                        
31788165    fe99a0069d6b3c64c2707d0ce53b9311540917471d82df...
31788202    fecc5f77b5f7ee4570efde9ab05ec94d0de2bf80efb4f6...
31788208    fece2f68864c311a0b5208e2eb735b3dcde7e41461d327...
31788217    fee56cc5315dafb35a4490ccc6f711092cae913550c832...
31788275    ff5b8a8b26bf93a66290e9bd1b73393ac6a58968a78519...
Name: customerId, Length: 1362281, dtype: object

### Customer Sampling

In [21]:
customer_ids = filtered_customer_ids
if SAMPLE_NUM_CUSTOMERS > 0:
    customer_ids = filtered_customer_ids.sample(n=SAMPLE_NUM_CUSTOMERS, random_state=RANDOM_SEED).reset_index(drop=True)
customer_ids

0       fdbe75e71e134938025dbbb9bc495bd302d578b449ac96...
1       fb9310441b653525f1adad3fbe7ece522ba50e752cca62...
2       f0a8599239eea199f1440af86ab9df78cb5d4e85f532fd...
3       4dd8a1b3175c88f07b123b388a5c9b5dfe16b3ba6fdf62...
4       696093ad8815f16ab92c07eb32d69c2d1e90daef479de7...
                              ...                        
1995    77f2b46bb15fb4c1251d82a37b2ed7d83f5a92bb5ff159...
1996    074148aa72ce41f82ca909e1912a18e2c055cae5b71390...
1997    b96b1cf69098b801738b20c4e922a46f5c713113442689...
1998    dfba337fefbf14b24281a0f931bbaba6e388ef4d1bc0e4...
1999    1abde4c8b89375315feca757811924a377dfe2ec8cc8a1...
Name: customerId, Length: 2000, dtype: object

In [22]:
article_ids = init_transaction_df[init_transaction_df.customerId.isin(customer_ids)].articleId.drop_duplicates()
article_ids

1559        662888002
1560        662888001
1561        651244002
1562        651244001
1588        633152003
              ...    
31775199    706271031
31779118    865926002
31779121    906639004
31779122    684238003
31779125    812530004
Name: articleId, Length: 21596, dtype: int64

## Sample Down Data
Now that we have the lists of customers and articles to include (`customer_ids` and `article_ids`, respectively), we can use them to filter the source data and stage it for loading.

In [23]:
transaction_df = init_transaction_df[init_transaction_df.customerId.isin(customer_ids)]
transaction_df

| | tDat | customerId | articleId | price | salesChannelId | txId |
|---|---|---|---|---|---|---|
| 1559 | 2018-09-20 | 080756754aef493b2b36f592eae744f2b9787dc55b635b... | 662888002 | 0.033881 | 2 | 1559 |
| 1560 | 2018-09-20 | 080756754aef493b2b36f592eae744f2b9787dc55b635b... | 662888001 | 0.033881 | 2 | 1560 |
| 1561 | 2018-09-20 | 080756754aef493b2b36f592eae744f2b9787dc55b635b... | 651244002 | 0.013542 | 2 | 1561 |
| 1562 | 2018-09-20 | 080756754aef493b2b36f592eae744f2b9787dc55b635b... | 651244001 | 0.006763 | 2 | 1562 |
| 1588 | 2018-09-20 | 0843d9fb6e4f3befa53ff3a8447b902b9f75bfa955a0f9... | 633152003 | 0.030492 | 1 | 1588 |
| ... | ... | ... | ... | ... | ... | ... |
| 31779124 | 2020-09-22 | b6be55f233772b5fc4a1ebedf36542fb3e1b6c15c23c7e... | 921266007 | 0.016932 | 2 | 31779124 |
| 31779125 | 2020-09-22 | b6be55f233772b5fc4a1ebedf36542fb3e1b6c15c23c7e... | 812530004 | 0.010153 | 2 | 31779125 |
| 31779126 | 2020-09-22 | b6be55f233772b5fc4a1ebedf36542fb3e1b6c15c23c7e... | 942187001 | 0.016932 | 2 | 31779126 |
| 31779127 | 2020-09-22 | b6be55f233772b5fc4a1ebedf36542fb3e1b6c15c23c7e... | 866731001 | 0.025407 | 2 | 31779127 |


In [24]:
full_article_df = init_article_df[init_article_df.articleId.isin(article_ids)]
full_article_df

| | articleId | productCode | prodName | productTypeNo | productTypeName | productGroupName | graphicalAppearanceNo | graphicalAppearanceName | colourGroupCode | colourGroupName | ... | departmentName | indexCode | indexName | indexGroupNo | indexGroupName | sectionNo | sectionName | garmentGroupNo | garmentGroupName | detailDesc |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 108775015 | 108775 | Strap top | 253 | Vest top | Garment Upper body | 1010016 | Solid | 9 | Black | ... | Jersey Basic | A | Ladieswear | 1 | Ladieswear | 16 | Womens Everyday Basics | 1002 | Jersey Basic | Jersey top with narrow shoulder straps. |
| 1 | 108775044 | 108775 | Strap top | 253 | Vest top | Garment Upper body | 1010016 | Solid | 10 | White | ... | Jersey Basic | A | Ladieswear | 1 | Ladieswear | 16 | Womens Everyday Basics | 1002 | Jersey Basic | Jersey top with narrow shoulder straps. |
| 3 | 110065001 | 110065 | OP T-shirt (Idro) | 306 | Bra | Underwear | 1010016 | Solid | 9 | Black | ... | Clean Lingerie | B | Lingeries/Tights | 1 | Ladieswear | 61 | Womens Lingerie | 1017 | Under-, Nightwear | Microfibre T-shirt bra with underwired, moulde... |
| 4 | 110065002 | 110065 | OP T-shirt (Idro) | 306 | Bra | Underwear | 1010016 | Solid | 10 | White | ... | Clean Lingerie | B | Lingeries/Tights | 1 | Ladieswear | 61 | Womens Lingerie | 1017 | Under-, Nightwear | Microfibre T-shirt bra with underwired, moulde... |
| 6 | 111565001 | 111565 | 20 den 1p Stockings | 304 | Underwear Tights | Socks & Tights | 1010016 | Solid | 9 | Black | ... | Tights basic | B | Lingeries/Tights | 1 | Ladieswear | 62 | Womens Nightwear, Socks & Tigh | 1021 | Socks and Tights | Semi shiny nylon stockings with a wide, reinfo... |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 105395 | 939927001 | 939927 | Dolphin | 265 | Dress | Garment Full body | 1010013 | Other pattern | 9 | Black | ... | Dress | A | Ladieswear | 1 | Ladieswear | 15 | Womens Everyday Collection | 1013 | Dresses Ladies | Short dress in an airy weave with a small stan... |
| 105444 | 942187001 | 942187 | ED Sasha tee | 255 | T-shirt | Garment Upper body | 1010016 | Solid | 9 | Black | ... | Jersey | A | Ladieswear | 1 | Ladieswear | 2 | H&M+ | 1005 | Jersey Fancy | Oversized, straight-cut T-shirt in a soft moda... |
| 105493 | 946282001 | 946282 | Linnea dress | 265 | Dress | Garment Full body | 1010021 | Lace | 9 | Black | ... | Dress | A | Ladieswear | 1 | Ladieswear | 15 | Womens Everyday Collection | 1013 | Dresses Ladies | Short dress in lace with flounces down the fro... |
| 105520 | 947599001 | 947599 | ED Duno 2p. | 254 | Top | Garment Upper body | 1010016 | Solid | 9 | Black | ... | Jersey | A | Ladieswear | 1 | Ladieswear | 2 | H&M+ | 1005 | Jersey Fancy | Long-sleeved tops in soft, organic cotton jers... |


### Create Product, Department, and Article Dataframes

In [25]:
product_df = full_article_df[['productCode', 'prodName',
                              'productTypeNo', 'productTypeName',
                              'productGroupName', 'garmentGroupNo', 'garmentGroupName',
                              'detailDesc']].drop_duplicates(subset='productCode')
product_df

| | productCode | prodName | productTypeNo | productTypeName | productGroupName | garmentGroupNo | garmentGroupName | detailDesc |
|---|---|---|---|---|---|---|---|---|
| 0 | 108775 | Strap top | 253 | Vest top | Garment Upper body | 1002 | Jersey Basic | Jersey top with narrow shoulder straps. |
| 3 | 110065 | OP T-shirt (Idro) | 306 | Bra | Underwear | 1017 | Under-, Nightwear | Microfibre T-shirt bra with underwired, moulde... |
| 6 | 111565 | 20 den 1p Stockings | 304 | Underwear Tights | Socks & Tights | 1021 | Socks and Tights | Semi shiny nylon stockings with a wide, reinfo... |
| 8 | 111586 | Shape Up 30 den 1p Tights | 273 | Leggings/Tights | Garment Lower body | 1021 | Socks and Tights | Tights with built-in support to lift the botto... |
| 9 | 111593 | Support 40 den 1p Tights | 304 | Underwear Tights | Socks & Tights | 1021 | Socks and Tights | Semi shiny tights that shape the tummy, thighs... |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 105395 | 939927 | Dolphin | 265 | Dress | Garment Full body | 1013 | Dresses Ladies | Short dress in an airy weave with a small stan... |
| 105444 | 942187 | ED Sasha tee | 255 | T-shirt | Garment Upper body | 1005 | Jersey Fancy | Oversized, straight-cut T-shirt in a soft moda... |
| 105493 | 946282 | Linnea dress | 265 | Dress | Garment Full body | 1013 | Dresses Ladies | Short dress in lace with flounces down the fro... |
| 105520 | 947599 | ED Duno 2p. | 254 | Top | Garment Upper body | 1005 | Jersey Fancy | Long-sleeved tops in soft, organic cotton jers... |


In [26]:
department_df = full_article_df[['departmentNo', 'departmentName', 'sectionNo', 'sectionName']]\
    .drop_duplicates(subset='departmentNo')
department_df

| | departmentNo | departmentName | sectionNo | sectionName |
|---|---|---|---|---|
| 0 | 1676 | Jersey Basic | 16 | Womens Everyday Basics |
| 3 | 1339 | Clean Lingerie | 61 | Womens Lingerie |
| 6 | 3608 | Tights basic | 62 | Womens Nightwear, Socks & Tigh |
| 17 | 5883 | Jersey Basic | 26 | Men Underwear |
| 23 | 2032 | Jersey | 8 | Mama |
| ... | ... | ... | ... | ... |
| 89929 | 7857 | Kids Boy Exclusive | 46 | Kids Boy |
| 92964 | 7510 | Woven | 28 | Men Edition |
| 97443 | 3420 | Small Accessories Extended | 66 | Womens Small accessories |
| 101971 | 8090 | Promotion/Other/Offer | 29 | Men Other |


In [27]:
article_df = full_article_df[['articleId', 'productCode', 'departmentNo', 'prodName', 'productTypeName',
                              'graphicalAppearanceNo', 'graphicalAppearanceName', 'colourGroupCode', 'colourGroupName']]
article_df

| | articleId | productCode | departmentNo | prodName | productTypeName | graphicalAppearanceNo | graphicalAppearanceName | colourGroupCode | colourGroupName |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 108775015 | 108775 | 1676 | Strap top | Vest top | 1010016 | Solid | 9 | Black |
| 1 | 108775044 | 108775 | 1676 | Strap top | Vest top | 1010016 | Solid | 10 | White |
| 3 | 110065001 | 110065 | 1339 | OP T-shirt (Idro) | Bra | 1010016 | Solid | 9 | Black |
| 4 | 110065002 | 110065 | 1339 | OP T-shirt (Idro) | Bra | 1010016 | Solid | 10 | White |
| 6 | 111565001 | 111565 | 3608 | 20 den 1p Stockings | Underwear Tights | 1010016 | Solid | 9 | Black |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 105395 | 939927001 | 939927 | 1322 | Dolphin | Dress | 1010013 | Other pattern | 9 | Black |
| 105444 | 942187001 | 942187 | 1919 | ED Sasha tee | T-shirt | 1010016 | Solid | 9 | Black |
| 105493 | 946282001 | 946282 | 1322 | Linnea dress | Dress | 1010021 | Lace | 9 | Black |
| 105520 | 947599001 | 947599 | 1919 | ED Duno 2p. | Top | 1010016 | Solid | 9 | Black |


### Create Customer Dataframe

In [28]:
customer_df = camel_case_rename_cols(pd.read_csv('data/customers.csv'))
customer_df = customer_df[customer_df.customerId.isin(customer_ids)]
customer_df

| | customerId | fn | active | clubMemberStatus | fashionNewsFrequency | age | postalCode |
|---|---|---|---|---|---|---|---|
| 86 | 0003e867a930d0d6842f923d6ba7c9b77aba33fe2a0fbf... | 1.0 | 1.0 | ACTIVE | Regularly | 33.0 | d647e4ede3d0eb4ce0750440a110350b5f4c758165d89d... |
| 425 | 00140d87c629b961e410e1d143084146c6fe71df40fe3d... | | | ACTIVE | NONE | 24.0 | d686e242886674f5bed783e6ceb2c52fe89f2c39996bbf... |
| 810 | 00264b7d4cd6498292e8a355b699c2d07725d123f04867... | 1.0 | 1.0 | ACTIVE | Regularly | 53.0 | 2c29ae653a9282cce4151bd87643c907644e09541abc28... |
| 1947 | 005c6d3bb66c86aab606814cd9995a12f99b3a44b58c72... | | | PRE-CREATE | NONE | | 177b4a2258a85a2247daaa7cdffba96a74c741ea8a6605... |
| 2155 | 006684ff58368b611db31b1ca782a87cad496e69835e42... | | | ACTIVE | NONE | 32.0 | 4296834187b1ffb908c0aa276b29a4b1af87cad557fb40... |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 1364911 | feac9822f51efc778acc044776b4b34e8e0a86615bf983... | | | ACTIVE | NONE | 48.0 | 8cecc780f67ff32def9c8e8dff5f454bce26a7cbd4c860... |
| 1366543 | fef793ec3a7d62d782824517355d74ded50964dce33009... | | | ACTIVE | NONE | 46.0 | 5799a39cffe701ebdb12181348bf10f9e23abcc3868c43... |
| 1367605 | ff2b58ad3e83f2e3499b3eda6ea99993b3bca10d8ceee4... | | | ACTIVE | NONE | 35.0 | 2c29ae653a9282cce4151bd87643c907644e09541abc28... |
| 1370498 | ffb925b11e1bb2e375d22a02d67907994eb8cb92ec2e7d... | | | ACTIVE | NONE | 34.0 | ebdd8c5c893683c3cf52c011d4e35024e46d183c95f0fa... |


## Graph Loading

In [29]:
load_dotenv('.env', override=True)

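# Connect with the Neo4j URI and credentials from .env
# Parse the AURA_DS flag ("true"/"false") without eval, which would execute arbitrary text
aura_ds = os.getenv('AURA_DS', 'false').strip().lower() == 'true'
gds = GraphDataScience(
    os.getenv('NEO4J_URI'),
    auth=(os.getenv('NEO4J_USERNAME'),
          os.getenv('NEO4J_PASSWORD')),
    aura_ds=aura_ds)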

# Necessary if you enabled Arrow on the db - this is true for AuraDS
gds.set_database("neo4j")

In [30]:
gds.version()

'2.5.1+39'

### Create Indexes

In [31]:
# one uniqueness constraint for each node label
gds.run_cypher('CREATE CONSTRAINT unique_department_no IF NOT EXISTS FOR (n:Department) REQUIRE n.departmentNo IS UNIQUE')
gds.run_cypher('CREATE CONSTRAINT unique_product_code IF NOT EXISTS FOR (n:Product) REQUIRE n.productCode IS UNIQUE')
gds.run_cypher('CREATE CONSTRAINT unique_article_id IF NOT EXISTS FOR (n:Article) REQUIRE n.articleId IS UNIQUE')
gds.run_cypher('CREATE CONSTRAINT unique_customer_id IF NOT EXISTS FOR (n:Customer) REQUIRE n.customerId IS UNIQUE')
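The four statements follow a single pattern, and in Neo4j each uniqueness constraint is backed by an index, which is what makes the later `MERGE` lookups on these keys fast. A sketch that generates the same statements from a label-to-key mapping is below; `NODE_KEYS` and `constraint_statements` are hypothetical names introduced for illustration, and the Cypher is only printed here rather than executed.

```python
# Hypothetical mapping of node label -> unique key property, mirroring the cell above
NODE_KEYS = {
    'Department': 'departmentNo',
    'Product': 'productCode',
    'Article': 'articleId',
    'Customer': 'customerId',
}


def constraint_statements(node_keys):
    """Build one idempotent uniqueness-constraint statement per node label."""
    statements = []
    for label, key in node_keys.items():
        # derive a constraint name such as 'unique_article_id' from the camelCase key
        snake = ''.join('_' + ch.lower() if ch.isupper() else ch for ch in key)
        statements.append(
            f'CREATE CONSTRAINT unique_{snake} IF NOT EXISTS '
            f'FOR (n:{label}) REQUIRE n.{key} IS UNIQUE'
        )
    return statements


for stmt in constraint_statements(NODE_KEYS):
    print(stmt)  # with a live connection: gds.run_cypher(stmt)
```

Because of `IF NOT EXISTS`, rerunning the statements against an already-initialised database is harmless.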

In [32]:
from typing import Tuple, Union
from numpy.typing import ArrayLike

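def make_map(x):
    # normalise an argument to a pair: (source label, target label) or (node property, record column)
    if isinstance(x, str):
        return x, x
    elif isinstance(x, tuple):
        return x
    else:
        raise TypeError("Entry must be of type string or tuple")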


def make_set_clause(prop_names: ArrayLike, element_name='n', item_name='rec'):
    clause_list = []
    for prop_name in prop_names:
        clause_list.append(f'{element_name}.{prop_name} = {item_name}.{prop_name}')
    return 'SET ' + ', '.join(clause_list)


def make_node_merge_query(node_key_name: str, node_label: str, cols: ArrayLike):
    template = f'''UNWIND $recs AS rec\nMERGE(n:{node_label} {{{node_key_name}: rec.{node_key_name}}})'''
    prop_names = [x for x in cols if x != node_key_name]
    if len(prop_names) > 0:
        template = template + '\n' + make_set_clause(prop_names)
    return template + '\nRETURN count(n) AS nodeLoadedCount'


def make_rel_merge_query(source_target_labels: Union[Tuple[str, str], str],
                         source_node_key: Union[Tuple[str, str], str],
                         target_node_key: Union[Tuple[str, str], str],
                         rel_type: str,
                         cols: ArrayLike,
                         rel_key: str = None):

    source_target_label_map = make_map(source_target_labels)
    source_node_key_map = make_map(source_node_key)
    target_node_key_map = make_map(target_node_key)

    merge_statement = f'MERGE(s)-[r:{rel_type}]->(t)'
    if rel_key is not None:
        merge_statement = f'MERGE(s)-[r:{rel_type} {{{rel_key}: rec.{rel_key}}}]->(t)'

    template = f'''\tUNWIND $recs AS rec
    MATCH(s:{source_target_label_map[0]} {{{source_node_key_map[0]}: rec.{source_node_key_map[1]}}})
    MATCH(t:{source_target_label_map[1]} {{{target_node_key_map[0]}: rec.{target_node_key_map[1]}}})\n\t''' + merge_statement
    prop_names = [x for x in cols if x not in [rel_key, source_node_key_map[1], target_node_key_map[1]]]
    if len(prop_names) > 0:
        template = template + '\n\t' + make_set_clause(prop_names, 'r')
    return template + '\n\tRETURN count(r) AS relLoadedCount'


def chunks(xs, n=50_000):
    n = max(1, n)
    return [xs[i:i + n] for i in range(0, len(xs), n)]


def load_nodes(gds: GraphDataScience, node_df: pd.DataFrame, node_key_col: str, node_label: str, chunk_size=50_000):
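    # replace NaN with None so missing values are left unset instead of stored as NaN properties
    records = node_df.astype(object).where(node_df.notna(), None).to_dict('records')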
    print(f'======  loading {node_label} nodes  ======')
    total = len(records)
    print(f'staging {total:,} records')
    query = make_node_merge_query(node_key_col, node_label, node_df.columns.copy())
    cumulative_count = 0
    for recs in chunks(records, chunk_size):
        res = gds.run_cypher(query, params={'recs': recs})
        cumulative_count += res.iloc[0, 0]
        print(f'Loaded {cumulative_count:,} of {total:,} nodes')


def load_rels(gds: GraphDataScience,
              rel_df: pd.DataFrame,
              source_target_labels: Union[Tuple[str, str], str],
              source_node_key: Union[Tuple[str, str], str],
              target_node_key: Union[Tuple[str, str], str],
              rel_type: str,
              rel_key: str = None,
              chunk_size=50_000):
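    # replace NaN with None so missing values are left unset instead of stored as NaN properties
    records = rel_df.astype(object).where(rel_df.notna(), None).to_dict('records')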
    print(f'======  loading {rel_type} relationships  ======')
    total = len(records)
    print(f'staging {total:,} records')
    query = make_rel_merge_query(source_target_labels, source_node_key,
                                 target_node_key, rel_type, rel_df.columns.copy(), rel_key)
    cumulative_count = 0
    for recs in chunks(records, chunk_size):
        res = gds.run_cypher(query, params={'recs': recs})
        cumulative_count += res.iloc[0, 0]
        print(f'Loaded {cumulative_count:,} of {total:,} relationships')
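To see what `load_nodes` actually sends to Neo4j, it helps to print the generated Cypher. The snippet restates simplified copies of `make_set_clause`, `make_node_merge_query`, and `chunks` (same logic, without type hints) so that it runs without a database, and uses the `Department` columns as input.

```python
def make_set_clause(prop_names, element_name='n', item_name='rec'):
    # one "element.prop = rec.prop" assignment per property
    return 'SET ' + ', '.join(f'{element_name}.{p} = {item_name}.{p}' for p in prop_names)


def make_node_merge_query(node_key_name, node_label, cols):
    # MERGE on the key property, then SET every remaining column
    query = f'UNWIND $recs AS rec\nMERGE(n:{node_label} {{{node_key_name}: rec.{node_key_name}}})'
    prop_names = [c for c in cols if c != node_key_name]
    if prop_names:
        query += '\n' + make_set_clause(prop_names)
    return query + '\nRETURN count(n) AS nodeLoadedCount'


def chunks(xs, n=50_000):
    # split a list into consecutive batches of at most n items
    n = max(1, n)
    return [xs[i:i + n] for i in range(0, len(xs), n)]


print(make_node_merge_query('departmentNo', 'Department',
                            ['departmentNo', 'departmentName', 'sectionNo', 'sectionName']))
print([len(batch) for batch in chunks(list(range(120_000)))])
```

Each batch from `chunks` becomes one `run_cypher` call whose `$recs` parameter holds up to 50,000 records, so 120,000 records would be sent as batches of 50,000, 50,000, and 20,000.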

### Load Nodes

In [33]:
%%time
load_nodes(gds, department_df, 'departmentNo', 'Department')

staging 276 records
Loaded 276 of 276 nodes
CPU times: user 10.6 ms, sys: 0 ns, total: 10.6 ms
Wall time: 384 ms


In [34]:
%%time
load_nodes(gds, product_df, 'productCode', 'Product')

staging 12,058 records
Loaded 12,058 of 12,058 nodes
CPU times: user 486 ms, sys: 4.01 ms, total: 490 ms
Wall time: 1.75 s


In [35]:
%%time
load_nodes(gds, article_df.drop(columns=['productCode', 'departmentNo']), 'articleId', 'Article')

staging 21,596 records
Loaded 21,596 of 21,596 nodes
CPU times: user 743 ms, sys: 60 µs, total: 743 ms
Wall time: 1.69 s


In [36]:
%%time
load_nodes(gds, customer_df, 'customerId', 'Customer')

staging 2,000 records
Loaded 2,000 of 2,000 nodes
CPU times: user 79.8 ms, sys: 0 ns, total: 79.8 ms
Wall time: 405 ms


### Load Relationships

In [37]:
%%time
load_rels(gds, article_df[['articleId', 'departmentNo']], source_target_labels=('Article', 'Department'),
          source_node_key='articleId', target_node_key='departmentNo',
          rel_type='FROM_DEPARTMENT')

staging 21,596 records
Loaded 21,596 of 21,596 relationships
CPU times: user 256 ms, sys: 336 µs, total: 256 ms
Wall time: 1.2 s


In [38]:
%%time
load_rels(gds, article_df[['articleId', 'productCode']], source_target_labels=('Article', 'Product'),
          source_node_key='articleId', target_node_key='productCode',
          rel_type='VARIANT_OF')

staging 21,596 records
Loaded 21,596 of 21,596 relationships
CPU times: user 252 ms, sys: 530 µs, total: 253 ms
Wall time: 1.14 s


In [39]:
%%time
load_rels(gds, transaction_df, source_target_labels=('Customer', 'Article'),
          source_node_key='customerId', target_node_key='articleId',
          rel_type='PURCHASED', rel_key='txId')  # key on txId so repeat purchases stay distinct

staging 48,064 records
Loaded 48,064 of 48,064 relationships
CPU times: user 1.43 s, sys: 0 ns, total: 1.43 s
Wall time: 3.46 s
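`MERGE` without a relationship key matches any existing `PURCHASED` relationship between the same customer and article, so repeat purchases of one article collapse into a single relationship whose properties reflect the last row written, while `count(r)` still counts every input row. Passing `rel_key='txId'` to `load_rels` gives each transaction its own relationship. The pure-Python simulation below illustrates the difference with a dictionary; it only models `MERGE` semantics and does not talk to Neo4j, and the purchase rows are hypothetical.

```python
# Hypothetical purchase rows: the same customer buys the same article twice
purchases = [
    {'txId': 0, 'customerId': 'c1', 'articleId': 'a1', 'price': 0.03},
    {'txId': 1, 'customerId': 'c1', 'articleId': 'a1', 'price': 0.02},
    {'txId': 2, 'customerId': 'c1', 'articleId': 'a2', 'price': 0.01},
]


def simulate_merge(rows, rel_key=None):
    """Mimic MERGE (s)-[r:PURCHASED {rel_key: ...}]->(t): one relationship per distinct pattern."""
    rels = {}
    for row in rows:
        pattern = (row['customerId'], row['articleId'])
        if rel_key is not None:
            pattern += (row[rel_key],)
        # MERGE matches an existing relationship with this pattern or creates one;
        # SET then overwrites its properties with the current row
        rels[pattern] = {k: v for k, v in row.items() if k not in ('customerId', 'articleId')}
    return rels


print(len(simulate_merge(purchases)))          # without a key
print(len(simulate_merge(purchases, 'txId')))  # keyed on txId
```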


## Text Embedding Loading
For now, we embed a concatenation of product fields (name, product type, garment group, and description).

In [40]:
from langchain.embeddings import OpenAIEmbeddings, BedrockEmbeddings, SentenceTransformerEmbeddings

In [41]:
def load_embedding_model(embedding_model_name: str):
    if embedding_model_name == "openai":
        embeddings = OpenAIEmbeddings()
        dimension = 1536
    elif embedding_model_name == "aws":
        embeddings = BedrockEmbeddings()
        dimension = 1536
    else:
        embeddings = SentenceTransformerEmbeddings(
            model_name="all-MiniLM-L6-v2", cache_folder="/embedding_model"
        )
        dimension = 384
    return embeddings, dimension

In [42]:
embedding_model, dimension = load_embedding_model(os.getenv('EMBEDDING_MODEL'))

In [43]:
product_emb_df = product_df[['productCode', 'prodName', 'productTypeName', 'garmentGroupName', 'detailDesc']]
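# copy the filtered slice so that adding the 'doc' column does not raise SettingWithCopyWarning
product_emb_df = product_emb_df[product_emb_df.detailDesc.notnull()].copy()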

In [44]:
def create_doc(row):
    return f'''
Name: {row.prodName}
Type: {row.productTypeName}
Group: {row.garmentGroupName}
Description: {row.detailDesc}
'''

product_emb_df['doc'] = product_emb_df.apply(create_doc, axis=1)
product_emb_df = product_emb_df.drop(columns=['prodName', 'productTypeName', 'garmentGroupName', 'detailDesc'])
product_emb_df

Unnamed: 0,productCode,doc
0,108775,\nName: Strap top\nType: Vest top\nGroup: Jers...
3,110065,\nName: OP T-shirt (Idro)\nType: Bra\nGroup: U...
6,111565,\nName: 20 den 1p Stockings\nType: Underwear T...
8,111586,\nName: Shape Up 30 den 1p Tights\nType: Leggi...
9,111593,\nName: Support 40 den 1p Tights\nType: Underw...
...,...,...
105395,939927,\nName: Dolphin\nType: Dress\nGroup: Dresses L...
105444,942187,\nName: ED Sasha tee\nType: T-shirt\nGroup: Je...
105493,946282,\nName: Linnea dress\nType: Dress\nGroup: Dres...
105520,947599,\nName: ED Duno 2p.\nType: Top\nGroup: Jersey ...


In [45]:
%%time

# Embeds one document per API call, which is slow (this run took about 50 minutes);
# LangChain's embed_documents accepts a list of texts and could be used to batch requests.
count = 0
embeddings = []
for txt in product_emb_df.doc:
    count += 1
    if count % 200 == 0:
        print(f'Embedded {count} of {product_emb_df.shape[0]}')
    embeddings.append(embedding_model.embed_query(txt))

Embedded 200 of 12021
Embedded 400 of 12021
Embedded 600 of 12021
Embedded 800 of 12021
Embedded 1000 of 12021
Embedded 1200 of 12021
Embedded 1400 of 12021
Embedded 1600 of 12021
Embedded 1800 of 12021
Embedded 2000 of 12021
Embedded 2200 of 12021
Embedded 2400 of 12021
Embedded 2600 of 12021
Embedded 2800 of 12021
Embedded 3000 of 12021
Embedded 3200 of 12021
Embedded 3400 of 12021
Embedded 3600 of 12021
Embedded 3800 of 12021
Embedded 4000 of 12021
Embedded 4200 of 12021
Embedded 4400 of 12021
Embedded 4600 of 12021
Embedded 4800 of 12021
Embedded 5000 of 12021
Embedded 5200 of 12021
Embedded 5400 of 12021
Embedded 5600 of 12021
Embedded 5800 of 12021
Embedded 6000 of 12021
Embedded 6200 of 12021
Embedded 6400 of 12021
Embedded 6600 of 12021
Embedded 6800 of 12021
Embedded 7000 of 12021
Embedded 7200 of 12021
Embedded 7400 of 12021
Embedded 7600 of 12021


Retrying langchain.embeddings.openai.embed_with_retry.<locals>._embed_with_retry in 4.0 seconds as it raised APIError: The server had an error while processing your request. Sorry about that! You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 99d77f92b53f790c51dfd29ef0a1aa0a in your message.) {
  "error": {
    "message": "The server had an error while processing your request. Sorry about that! You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 99d77f92b53f790c51dfd29ef0a1aa0a in your message.)",
    "type": "server_error",
    "param": null,
    "code": null
  }
}
 500 {'error': {'message': 'The server had an error while processing your request. Sorry about that! You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 99d77f92b53f

Embedded 7800 of 12021
Embedded 8000 of 12021
Embedded 8200 of 12021
Embedded 8400 of 12021
Embedded 8600 of 12021
Embedded 8800 of 12021
Embedded 9000 of 12021
Embedded 9200 of 12021
Embedded 9400 of 12021
Embedded 9600 of 12021
Embedded 9800 of 12021
Embedded 10000 of 12021
Embedded 10200 of 12021
Embedded 10400 of 12021
Embedded 10600 of 12021
Embedded 10800 of 12021
Embedded 11000 of 12021
Embedded 11200 of 12021
Embedded 11400 of 12021
Embedded 11600 of 12021
Embedded 11800 of 12021
Embedded 12000 of 12021
CPU times: user 19.1 s, sys: 1.3 s, total: 20.4 s
Wall time: 50min 14s
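Embedding one document per request is what makes this cell take so long. LangChain embedding classes also expose `embed_documents(texts)`, which takes a list of strings and returns one vector per string, so requests can be batched. The sketch below assumes only that interface; `embed_in_batches`, `fake_embed_documents`, and the batch size of 500 are hypothetical choices for illustration, and the stand-in embedder replaces `embedding_model.embed_documents` so the snippet runs offline.

```python
from typing import Callable, List


def embed_in_batches(texts: List[str],
                     embed_documents: Callable[[List[str]], List[List[float]]],
                     batch_size: int = 500) -> List[List[float]]:
    """Embed texts batch by batch, preserving input order."""
    embeddings: List[List[float]] = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        embeddings.extend(embed_documents(batch))  # one call per batch instead of per text
        print(f'Embedded {len(embeddings)} of {len(texts)}')
    return embeddings


# Stand-in embedder for illustration; in the notebook this would be
# embedding_model.embed_documents
def fake_embed_documents(batch):
    return [[float(len(t)), 0.0] for t in batch]


vectors = embed_in_batches(['a', 'bb', 'ccc'], fake_embed_documents, batch_size=2)
```

In the notebook this would be called as `embed_in_batches(product_emb_df.doc.tolist(), embedding_model.embed_documents)`; retries on transient server errors, like the one logged above, are still handled inside the LangChain client.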


In [46]:
product_emb_df['textEmbedding'] = embeddings

In [47]:
%%time
load_nodes(gds, product_emb_df[['productCode', 'textEmbedding']], 'productCode', 'Product')

staging 12,021 records
Loaded 12,021 of 12,021 nodes
CPU times: user 36.6 s, sys: 451 ms, total: 37.1 s
Wall time: 43.6 s


In [48]:
gds.run_cypher(f'CALL db.index.vector.createNodeIndex("product-text-embeddings", "Product", "textEmbedding", {dimension}, "cosine")')
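Once the index exists, it can be queried for the products nearest to a free-text description. A minimal sketch, assuming a Neo4j version that provides `db.index.vector.queryNodes` (the same 5.x vector-index procedures as `createNodeIndex` above) and that the query is embedded with the same model as the product documents; `find_similar_products` is a hypothetical helper name.

```python
from typing import Any


def find_similar_products(gds: Any, embedding_model: Any, query_text: str, k: int = 5):
    """Return the k Product nodes whose textEmbedding is closest to query_text."""
    query = '''
    CALL db.index.vector.queryNodes('product-text-embeddings', $k, $embedding)
    YIELD node, score
    RETURN node.productCode AS productCode, node.prodName AS prodName, score
    ORDER BY score DESC
    '''
    # embed the query with the same model that produced the product embeddings
    embedding = embedding_model.embed_query(query_text)
    return gds.run_cypher(query, params={'k': k, 'embedding': embedding})
```

For example, `find_similar_products(gds, embedding_model, 'black lace dress', k=5)` would return a DataFrame of product codes, names, and similarity scores.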

In [50]:
# Save in case we want it again
product_emb_df.to_csv('product_emb.csv', index=False)

## Convert Transaction Dates
The transaction dates (`tDat`) were loaded as strings and need to be converted to Cypher `date` values. This could also have been done earlier, before loading.

In [51]:
gds.run_cypher('''
MATCH (:Customer)-[r:PURCHASED]->()
SET r.tDat = date(r.tDat)
''')
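Alternatively, the conversion can be done before loading: the Neo4j Python driver sends `datetime.date` parameters as Cypher `Date` values, so records whose `tDat` is already a `date` should be stored as dates directly. A minimal sketch on hypothetical records shaped like `transaction_df.to_dict('records')`; `with_dates` is an illustrative helper name. In pandas, the equivalent column conversion would be `pd.to_datetime(transaction_df.tDat).dt.date`.

```python
from datetime import date

# Hypothetical records shaped like transaction_df.to_dict('records')
records = [
    {'tDat': '2018-09-20', 'customerId': 'c1', 'articleId': 662888002},
    {'tDat': '2020-09-22', 'customerId': 'c2', 'articleId': 812530004},
]


def with_dates(recs, col='tDat'):
    """Return copies of the records with an ISO 'YYYY-MM-DD' string column parsed to datetime.date."""
    return [{**rec, col: date.fromisoformat(rec[col])} for rec in recs]


converted = with_dates(records)
print(converted[0]['tDat'])
```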