# Importing Libraries and defining the model

In [None]:
import spacy
from spacy.matcher import Matcher
import re
from fuzzywuzzy import process
nlp = spacy.load('en_core_web_md')
matcher = Matcher(nlp.vocab)
# Assumed form of the database for now 

'''
TABLE products (
    id INTEGER PRIMARY KEY,       -- Unique identifier for each product
    name TEXT,                    -- Descriptive name of the product
    category TEXT,                -- Broad category of the product (e.g., ring, necklace, bracelet)
    price NUMERIC,                -- Cost of the product
    description TEXT,             -- Detailed description of the product
    shape TEXT,                   -- Shape of the product (e.g., Round, Oval)
    style TEXT,                   -- Style of the product (e.g., Solitaire, Halo)
    metal TEXT,                   -- Metal type (e.g., White Gold, Yellow Gold)
    collection TEXT,              -- Collection (e.g., Classic, Contemporary)
    type TEXT,                    -- Type of jewelry (e.g., Engagement Ring, Wedding Ring, Jewelry)
    birthstone TEXT               -- Birthstone (e.g., January, February)
);
'''

# Function for defining the patterns 

In [2]:
def patterns_definition(KEY="LEMMA"):
    '''
    Input: Key (default is "LEMMA") options: "LEMMA"(for     form) or "LOWER" (for tokenized form)
    Output: The modified matcher object (based on key form[i.e. token or lemma]) with the patterns for ProductType, Material, Color, PriceVal, and PriceConstraint added to it.
    Note: You can change to get the patterns based on the tokens by changing the key from "LEMMA" to "LOWER" in the patterns.
    '''
    
    product_patterns = [
        [{KEY: "ring"}],
        [{KEY: "necklace"}],
        [{KEY: "bracelet"}],
        [{KEY: "earring"}],
        [{KEY: "pendant"}],
        [{KEY: "bangle"}],
        [{KEY: "cufflink"}]
    ]
    matcher.add("ProductType", product_patterns)

    shape_patterns = [
        [{KEY: "round"}],
        [{KEY: "oval"}],
        [{KEY: "princess"}],
        [{KEY: "pear"}],
        [{KEY: "cushion"}],
        [{KEY: "marquise"}],
        [{KEY: "radiant"}]
    ]
    matcher.add("Shape", shape_patterns)

    style_patterns = [
        [{KEY: "solitaire"}],
        [{KEY: "hidden"}, {KEY: "halo"}],
        [{KEY: "straight"}],
        [{KEY: "three"}, {KEY: "stone"}],
        [{KEY: "halo"}],
        [{KEY: "double"}, {KEY: "halo"}],
        [{KEY: "split"}, {KEY: "shank"}],
        [{KEY: "three"}, {KEY: "stone"}, {KEY: "halo"}],
        [{KEY: "wide"}, {KEY: "band"}]
    ]
    matcher.add("Style", style_patterns)

    metal_patterns = [
        [{KEY: "white"}, {KEY: "gold"}],
        [{KEY: "yellow"}, {KEY: "gold"}],
        [{KEY: "rose"}, {KEY: "gold"}],
        [{KEY: "platinum"}],
        [{KEY: "two"}, {KEY: "tone"}],
        [{KEY: "emerald"}]
    ]
    matcher.add("Metal", metal_patterns)

    collection_patterns = [
        [{KEY: "classic"}],
        [{KEY: "contemporary"}],
        [{KEY: "vintage"}, {KEY: "inspired"}],
        [{KEY: "lotus"}],
        [{KEY: "starlight"}],
        [{KEY: "floral"}],
        [{KEY: "bujukan"}],
        
    ]
    matcher.add("Collection", collection_patterns)

    type_patterns = [
        [{KEY: "engagement"}, {KEY: "ring"}],
        [{KEY: "wedding"}, {KEY: "ring"}],
        [{KEY: "jewelry"}]
    ]
    matcher.add("Type", type_patterns)

    birthstone_patterns = [
        [{KEY: "january"}],
        [{KEY: "february"}],
        [{KEY: "march"}],
        [{KEY: "april"}],
        [{KEY: "may"}],
        [{KEY: "june"}],
        [{KEY: "july"}],
        [{KEY: "august"}],
        [{KEY: "september"}],
        [{KEY: "october"}],
        [{KEY: "november"}],
        [{KEY: "december"}]
    ]
    matcher.add("Birthstone", birthstone_patterns)

    priceval_patterns = [{"LIKE_NUM": True}]
    matcher.add("PriceVal", [priceval_patterns])

    priceconstraint_patterns = [
        [{KEY: "under"}],
        [{KEY: "less"}, {KEY: "than"}],
        [{KEY: "above"}],
        [{KEY: "more"}, {KEY: "than"}],
        [{KEY: "not"}, {KEY: "more"}, {KEY: "than"}]
    ]
    matcher.add("PriceConstraint", priceconstraint_patterns)
    
    id_patterns = [{"LIKE_NUM": True}]
    matcher.add("ID", [id_patterns])

patterns_definition()

# Function to preprocess text and extract entities

In [21]:
def text_to_entities(text: str) -> dict:
    '''
    Input: Raw Text exactly as it came from the user (str)
    Output: Entities of this text in the form of a Dictionary (Dict)
    Note: We are assuming that the entities are only: ProductType, Material, Color, PriceVal, PriceConstraint (until we get access to the database)
    Note: This function first uses the matcher based on lemmas to extract the entities, but if the entities are not found, it uses fuzzy matching to extract the entities.
    '''
    def fuzzy_match(entity, choices, threshold=95):
        def preprocess_for_fuzzy_matching(text):
            '''It removes all the special characters from the text and returns the preprocessed text.'''
            return re.sub(r'[^a-zA-Z\s]', '', text).strip()
        preprocessed_entity = preprocess_for_fuzzy_matching(entity)
        if not preprocessed_entity:
            return None
        match, score = process.extractOne(preprocessed_entity, choices)
        if score >= threshold:
            return match
        return None
    
    doc = nlp(text.lower())
    matches = matcher(doc)
    entities = {
        "ProductType": None,
        "Shape": [],
        "Style": [],
        "Metal": [],
        "Collection": [],
        "Type": None,
        "Birthstone": [],
        "PriceVal": None,
        "PriceConstraint": None,
        "ID": None
    }

    for match_id, start, end in matches:
        span = doc[start:end]
        label = nlp.vocab.strings[match_id]
        if label == "ProductType":
            entities["ProductType"] = span.lemma_
        elif label == "Shape":
            entities["Shape"].append(span.lemma_)
        elif label == "Style":
            entities["Style"].append(span.lemma_)
        elif label == "Metal":
            entities["Metal"].append(span.lemma_)
        elif label == "Collection":
            entities["Collection"].append(span.lemma_)
        elif label == "Type":
            entities["Type"] = span.lemma_
        elif label == "Birthstone":
            entities["Birthstone"].append(span.lemma_)
        elif label == "PriceVal":
            # Check context to determine if it's a price value or an ID
            if span.start > 0 and doc[span.start - 1].lemma_ in ["under", "less", "than", "above", "more", "not"]:
                entities["PriceVal"] = span.text
            else:
                entities["ID"] = span.text
        elif label == "PriceConstraint":
            entities["PriceConstraint"] = span.lemma_

    product_types = ["ring", "necklace", "bracelet", "earring", "pendant", "bangle", "cufflink"]
    shapes = ["round", "oval", "princess", "pear", "cushion", "marquise", "radiant"]
    styles = ["solitaire", "hidden halo", "straight", "three stone", "halo", "double halo", "split shank", "three stone halo", "wide band"]
    metals = ["white gold", "yellow gold", "rose gold", "platinum", "two tone","emerald"]
    collections = ["classic", "contemporary", "vintage inspired", "lotus", "starlight", "floral", "bujukan"]
    types = ["engagement ring", "wedding ring", "jewelry"]
    birthstones = ["january", "february", "march", "april", "may", "june", "july", "august", "september", "october", "november", "december"]

    if not entities["ProductType"]:
        entities["ProductType"] = fuzzy_match(doc.text, product_types)
    if not entities["Shape"]:
        entities["Shape"] = [fuzzy_match(token.lemma_, shapes) for token in doc if fuzzy_match(token.lemma_, shapes)]
    if not entities["Style"]:
        entities["Style"] = [fuzzy_match(token.lemma_, styles) for token in doc if fuzzy_match(token.lemma_, styles)]
    if not entities["Metal"]:
        entities["Metal"] = [fuzzy_match(token.lemma_, metals) for token in doc if fuzzy_match(token.lemma_, metals)]
    if not entities["Collection"]:
        entities["Collection"] = [fuzzy_match(token.lemma_, collections) for token in doc if fuzzy_match(token.lemma_, collections)]
    if not entities["Type"]:
        entities["Type"] = fuzzy_match(doc.text, types)
    if not entities["Birthstone"]:
        entities["Birthstone"] = [fuzzy_match(token.lemma_, birthstones) for token in doc if fuzzy_match(token.lemma_, birthstones)]

    return entities

## Check Ambiguity

In [22]:
def check_ambiguity(entity_dict):
    '''
    Input: Dictionary of entities extracted from the text
    Output: A message indicating the ambiguity.
    Note: For now the ambiguity is handled for only two specific cases (when ProductType is not provided nor the ID of the product). We will be extending it soon.
    '''
    if not entity_dict["ProductType"] and not entity_dict["ID"]:
        return "The query is ambiguous. Please provide a product type or an ID."
    return ""

# Function to generate SQL queries from the entities 

In [17]:
def entities_to_SQL(entities):
    '''
    Input: Entities extracted from the text (Dict)
    Output: SQL Query to fetch the data from the database (str)
    Note: Here we are using a rule-based method to convert the entities to SQL Query.
    '''
    message = check_ambiguity(entities)
    if message:
        return message

    # if entities["ID"]:
    #     return f"SELECT * FROM products WHERE id={entities['ID']}"

    conditions = [f"category='{entities['ProductType']}'"]

    if entities["Shape"]:
        shape_conditions = " OR ".join([f"shape='{shape}'" for shape in entities["Shape"]])
        conditions.append(f"({shape_conditions})")

    if entities["Style"]:
        style_conditions = " OR ".join([f"style='{style}'" for style in entities["Style"]])
        conditions.append(f"({style_conditions})")

    if entities["Metal"]:
        metal_conditions = " OR ".join([f"metal='{metal}'" for metal in entities["Metal"]])
        conditions.append(f"({metal_conditions})")

    if entities["Collection"]:
        collection_conditions = " OR ".join([f"collection='{collection}'" for collection in entities["Collection"]])
        conditions.append(f"({collection_conditions})")

    if entities["Birthstone"]:
        birthstone_conditions = " OR ".join([f"birthstone='{birthstone}'" for birthstone in entities["Birthstone"]])
        conditions.append(f"({birthstone_conditions})")

    if entities["PriceConstraint"] and entities["PriceVal"]:
        if entities["PriceConstraint"] in ["under", "less than"]:
            conditions.append(f"price < {entities['PriceVal']}")
        elif entities["PriceConstraint"] in ["above", "more than"]:
            conditions.append(f"price > {entities['PriceVal']}")
        elif entities["PriceConstraint"] == "not more than":
            conditions.append(f"price <= {entities['PriceVal']}")

    where_clause = " AND ".join(conditions)
    sql_query = f"SELECT * FROM products WHERE {where_clause}"

    return sql_query

# Testing the Model 

In [20]:
examples = [
    "Show me round emerald engagement rings from the classic collection",
    "I want a two tone wide band ring for my anniversary",
    "Find me earrings with gold",
    "Gold necklace from the vintage inspired collection under 1500$",
    "Show me birthstone rings for July",
    "I want a white gold bracelet not more than 500$",
    "Engagement rings in platinum and gold",
    "Show me a pendant with a ruby stone above 2000$",
    "I need a floral diamond necklace",
    "Find me a two tone bangle",
    "Gold band more than 200$",
    "Show me necklaces with emerald and sapphire stones",
    "I want the product with id 123",
    "Show me product 456",
    "rang"
]

for ex in examples:
    entities = text_to_entities(ex)
    generated_sql = entities_to_SQL(entities)
    print(f"User Query: {ex}")
    print(f"Entities: {entities}")
    print(f"Generated SQL: {generated_sql}\n")


User Query: Show me round emerald engagement rings from the classic collection
Entities: {'ProductType': 'ring', 'Shape': ['round'], 'Style': [], 'Metal': ['emerald'], 'Collection': ['classic'], 'Type': 'engagement ring', 'Birthstone': [], 'PriceVal': None, 'PriceConstraint': None, 'ID': None}
Generated SQL: SELECT * FROM products WHERE category='ring' AND (shape='round') AND (metal='emerald') AND (collection='classic')

User Query: I want a two tone wide band ring for my anniversary
Entities: {'ProductType': 'ring', 'Shape': [], 'Style': ['wide band'], 'Metal': ['two tone'], 'Collection': [], 'Type': None, 'Birthstone': [], 'PriceVal': None, 'PriceConstraint': None, 'ID': 'two'}
Generated SQL: SELECT * FROM products WHERE category='ring' AND (style='wide band') AND (metal='two tone')

User Query: Find me earrings with gold
Entities: {'ProductType': 'earring', 'Shape': [], 'Style': [], 'Metal': [], 'Collection': [], 'Type': None, 'Birthstone': [], 'PriceVal': None, 'PriceConstraint': N