### Experimentation to take list of messages and group similarity


In [1]:
import os

from haystack import Pipeline, component
from haystack.components.builders.prompt_builder import PromptBuilder
from haystack_integrations.components.generators.ollama import OllamaGenerator
from haystack.dataclasses import ChatMessage
from typing import Optional, List, Dict
from pydantic import BaseModel, ConfigDict
from pprint import pprint
import json
from sentence_transformers import SentenceTransformer

os.environ["TOKENIZERS_PARALLELISM"] = "false"

2025-02-08 10:50:36.636378: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-02-08 10:50:36.643775: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-02-08 10:50:36.652266: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-02-08 10:50:36.654792: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-08 10:50:36.661290: I tensorflow/core/platform/cpu_feature_guar

In [2]:
hf_token = False 
with open("secrets") as file:
    for line in file.readlines():
        key,value = line.strip().split("=")
        if key == 'HF_TOKEN':
            hf_token = True
            os.environ[key]=value
assert hf_token, 'HF_TOKEN not found'

In [3]:
messages = [
    "How do I reset my password?",
    "I forgot my password, how can I recover it?",
    "What is the process to change my password?",
    "Can you help me with my account recovery?",
    "How do I update my profile information?",
    "I need to change my email address on my account.",
    "What are the store hours for the weekend?",
    "Is the store open on holidays?",
    "Can I return an item without a receipt?",
    "What is the return policy for online purchases?",
    "How do I track my order?",
    "My order hasn't arrived yet, what should I do?",
    "Can I change the shipping address for my order?",
    "How do I apply for a job at your company?",
    "Are there any job openings in the marketing department?",
    "What benefits do you offer to employees?",
    "How do I schedule an appointment?",
    "Can I reschedule my appointment online?",
    "What documents do I need to bring to my appointment?",
    "How do I cancel my subscription?",
    "What are the subscription plans available?",
    "Can I upgrade my subscription plan?",
    "How do I contact customer support?",
    "Is there a live chat option for customer support?",
    "What is the phone number for customer support?",
    "How do I download the mobile app?",
    "Is the mobile app available for both iOS and Android?",
    "How do I report a bug in the mobile app?",
    "Can I use the mobile app to make payments?",
    "What payment methods are accepted?",
    "How do I add a new payment method?",
    "Can I set up automatic payments?",
    "How do I delete my account?",
    "What happens to my data if I delete my account?",
    "Can I reactivate my account after deleting it?",
    "How do I change my notification settings?",
    "Can I turn off email notifications?",
    "How do I enable push notifications?",
    "What is the privacy policy of your company?",
    "How do you handle customer data?",
    "What security measures are in place to protect my information?"
]

In [4]:
# Load environment variable
my_env_var = os.getenv('GROQ_API')
print(my_env_var)

None


In [None]:
class Answer(BaseModel):
    groups: List[List[str]]
    model_config = ConfigDict(extra='forbid')

json_schema = Answer.model_json_schema()
pprint(json_schema)

@component
class OutputValidator:

    def __init__(self, pydantic_model: BaseModel):
        self.pydantic_model = pydantic_model
    
    @component.output_types(valid_replies = List[str], invalid_replies = Optional[List[str]], error_msg = Optional[str])
    def run(self, replies: List[str]):
        try:
            output_dict = json.loads(replies[0])
            print('replies = ', output_dict)
            self.pydantic_model.model_validate(output_dict)
            print('[OK] valid')
            return {'valid_replies': replies}
        except ValueError as e:
            print('[not OK] wrong format')
            print(replies)
            return {'invalid_replies': replies, 'error_msg': str(e)}

def extract_valid_replies(res):
    return json.loads(res['json_validator']['valid_replies'][0])

In [20]:
@component
class MessageCluster:

    def __init__(self, sent_transformer: SentenceTransformer, threshold: float = 0.5):
        self.sent_transformer = sent_transformer
        self.threshold = threshold
    
    @component.output_types(clusters = List[List[str]])
    def run(self, messages: List[str]):
            clusters = self._cluster_messages(messages, self.threshold)
            return clusters
    
    def _group_embeddings(self, sim_matrix, threshold):

        num_embeddings = sim_matrix.shape[0]
        visited = [False] * num_embeddings
        groups = []

        for i in range(num_embeddings):
            if not visited[i]:
                # sorted_row = sorted(sim_matrix[i], reverse=True)
                group = [i]
                visited[i] = True
                for j in range(num_embeddings):
                    if i != j and sim_matrix[i][j] >= threshold and not visited[j]:
                        group.append(j)
                        visited[j] = True
                groups.append(group)
        
        return groups

    def _cluster_messages(self, messages, threshold = 0.5):
        test_messages = messages.copy()

        model = SentenceTransformer("all-MiniLM-L6-v2")

        embeds = model.encode(test_messages)

        sim_scores = model.similarity(embeds, embeds)

        groups = self._group_embeddings(sim_scores, threshold=threshold)

        clustered_messages = []
        
        for group in groups:
            clustered_messages.append([messages[i] for i in group])
        return clustered_messages

In [25]:
model = SentenceTransformer("all-MiniLM-L6-v2")
cluster = MessageCluster(sent_transformer=model, threshold=0.5)
groups = cluster.run(messages)

for group in groups:
    print(group)
    print()

['How do I reset my password?', 'I forgot my password, how can I recover it?', 'What is the process to change my password?', 'Can you help me with my account recovery?', 'How do I delete my account?']

['How do I update my profile information?']

['I need to change my email address on my account.']

['What are the store hours for the weekend?', 'Is the store open on holidays?']

['Can I return an item without a receipt?', 'What is the return policy for online purchases?']

['How do I track my order?', "My order hasn't arrived yet, what should I do?"]

['Can I change the shipping address for my order?']

['How do I apply for a job at your company?']

['Are there any job openings in the marketing department?']

['What benefits do you offer to employees?']

['How do I schedule an appointment?', 'Can I reschedule my appointment online?']

['What documents do I need to bring to my appointment?']

['How do I cancel my subscription?', 'What are the subscription plans available?', 'Can I upgra

In [None]:

def rephrase_messages(messages):
  template = '''
    Given the following messages:
    {{messages}}
    Rephrase these messages into 1 message that clarify the question and keep it consise and contain most important key questions. 
    Just give the final message with no explanation and extra information.
    If you don't have enough information or the question is unclear, return 1 original message that makes the most sense.
  '''

  MODEL = 'llama3.2:3b'
  MAX_RUN = 10

  #component
  prompt_template = PromptBuilder(template=template)
  llm = OllamaGenerator(model = MODEL, url="http://localhost:11434") 
  pipe = Pipeline(max_runs_per_component=MAX_RUN)
  pipe.add_component('prompt', prompt_template)
  pipe.add_component('llm', llm)

  #make connection
  pipe.connect('prompt', 'llm')

  #run pipeline
  model = SentenceTransformer("all-MiniLM-L6-v2")
  cluster = MessageCluster(sent_transformer=model, threshold=0.6)
  groups = cluster.run(messages)

  rephrase_groups = []

  for group in groups:
    res = pipe.run({'prompt': {'messages': group}})
    if len(group) > 1:
      rephrase = res['llm']['replies']
    else:
      rephrase = group
    rephrase_groups.append(rephrase)
    # print(rephrase)
    # print(group)
    # print("#####")
  
  #build pipeline
  return rephrase_groups


In [50]:
rephrase_messages(messages)

['"How can I recover my password?"']
['How do I reset my password?', 'I forgot my password, how can I recover it?', 'What is the process to change my password?']
#####
['Can you help me with my account recovery?']
['Can you help me with my account recovery?']
#####
['How do I update my profile information?']
['How do I update my profile information?']
#####
['I need to change my email address on my account.']
['I need to change my email address on my account.']
#####
['What are the store hours for the weekend?']
['What are the store hours for the weekend?']
#####
['Is the store open on holidays?']
['Is the store open on holidays?']
#####
['Can I return an item without a receipt?']
['Can I return an item without a receipt?']
#####
['What is the return policy for online purchases?']
['What is the return policy for online purchases?']
#####
['How do I track my order?']
['How do I track my order?']
#####
["My order hasn't arrived yet, what should I do?"]
["My order hasn't arrived yet, what

[['"How can I recover my password?"'],
 ['Can you help me with my account recovery?'],
 ['How do I update my profile information?'],
 ['I need to change my email address on my account.'],
 ['What are the store hours for the weekend?'],
 ['Is the store open on holidays?'],
 ['Can I return an item without a receipt?'],
 ['What is the return policy for online purchases?'],
 ['How do I track my order?'],
 ["My order hasn't arrived yet, what should I do?"],
 ['Can I change the shipping address for my order?'],
 ['How do I apply for a job at your company?'],
 ['Are there any job openings in the marketing department?'],
 ['What benefits do you offer to employees?'],
 ['Can I reschedule my appointment online?'],
 ['What documents do I need to bring to my appointment?'],
 ['How do I cancel my subscription?'],
 ['"What are my current subscription plans available for upgrade?"'],
 ['How can I get in touch with customer support?'],
 ['Is there a live chat option for customer support?'],
 ['Is the 