# Topic Transition Model Class


  MUKALMA - A Knowledge-Powered Conversational Agent
  Project Id: F21-20-R-KBCAgent

  Class that is responsible for keeping track of the topic a moving conversation is centered around.
  It detects changes in topics and updates its information about the current topic after every conversation turn.
  If it decides that the topic has changed, it will look for a new central topic.
  If the topic is the same, it will slightly tweak the focus of the conversation.

  @Author: Muhammad Farjad Ilyas
  @Date: 29th March 2022

### Imports

In [44]:
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from torch.cuda import is_available as is_cuda_available
from scipy.cluster.vq import kmeans

# -------------------------------------------------
import nltk
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from nltk.tokenize import word_tokenize, sent_tokenize
from scipy.cluster.vq import kmeans

# For Utilities
import math
import time
import numpy as np
# -------------------------------------------------

## TopicTransitionClass

### Functions

In [45]:
def list_sorted_args(l, reverse=False):
    return sorted(range(len(l)), key=l.__getitem__, reverse=reverse)

def find_highest_similarity_scores(scores, n=3):
    s_idxs = list_sorted_args(scores)
    s = [scores[i] for i in s_idxs]
    s_len = len(s)
    s_range = range(s_len)

    kclust = kmeans(np.matrix([s_range, s]).transpose(), n)
    assigned_clusters = [abs(kclust[0][:, 0] - e).argmin() for e in s_range]
    print(assigned_clusters)

    highest_cluster = assigned_clusters[-1]
    highest_idxs = []
    for i in range(s_len - 1, -1, -1):
        if assigned_clusters[i] != highest_cluster:
            return highest_idxs
        highest_idxs.append(s_idxs[i])
    return highest_idxs

In [46]:
class TopicTransitionModel:
    """
      Compares the current message with the previous message and a control message using sentence similarity.
      Uses the intuition that if the current message is closer to the generic control message that doesn't involve
      a particular topic, compared to a possibly topic-related previous message, then the topic may have changed.

      Updates the keywords that are relevant to the conversation by detecting topic changes. Gives a higher priority
      to keywords that have occurred recently by placing them earlier in the list of keywords in ascending order of
      indexing.
      
      Uses a filtering method to pass previous keywords that are still relevant to the conversation turn. This 
      helps the model simulate conversation context and aids in topic change
    """

    __control_msg = "Hey! How are you doing?"

    def __init__(self, model=None, model_path='../../models/all-MiniLM-L6-v2', use_cuda=False):
        # Keeping track of previous message and keywords that will pass forward
        self.sent_changed_topic = self.prev_msg = TopicTransitionModel.__control_msg
        self.prev_keywords = []
        self.pass_through = []

        # Keeping track of False Postives
        self.FALSE_TOPIC_CHANGE_LIMIT = 3
        self.false_topic_change = 0

        # Sentence transformer model
        self.model = SentenceTransformer(
            model_path, device=('cuda' if use_cuda and is_cuda_available() else 'cpu')
        ) if model is None else model

    def calc_sentence_similarity(self, msg, candidates):
        msg_embedding = self.model.encode([msg])
        candidate_embeddings = self.model.encode(candidates)
        distances = cosine_similarity(msg_embedding, candidate_embeddings).flatten()
        return distances

    def order_keywords_by_similarity(self, msg, keywords):
        s_scores = self.calc_sentence_similarity(msg, keywords)
        s_idxs = list_sorted_args(s_scores, reverse=True)
        t_keywords = [keywords[i] for i in s_idxs]
        return t_keywords

    def has_topic_changed(self, msg, prev_msg, control_msg, error_threshold = -0.02):
        distances = self.calc_sentence_similarity(msg, [prev_msg, control_msg])
        return (distances[1] - distances[0]) > error_threshold

    def update_topic(self, message, c_keywords):
        
        # Comparing the current message to the previous keywords
        if len(self.prev_keywords) > 0:
            self.prev_keywords = list(set(self.prev_keywords).difference(c_keywords)) 
            s_scores = self.calc_sentence_similarity(msg, self.prev_keywords)

            if len(self.prev_keywords) >= 3:
                self.pass_through = [self.prev_keywords[i] for i in find_highest_similarity_scores(s_scores, 2 if len(self.prev_keywords) <= 3 else 3)]
            else:
                self.pass_through = self.prev_keywords
        # End if
        
        # Calculating topic changes
        topic_change_from_prev_msg = self.has_topic_changed(msg, self.prev_msg, TopicTransitionModel.__control_msg)
        topic_change_from_prev_topic = self.has_topic_changed(msg, self.sent_changed_topic, TopicTransitionModel.__control_msg, error_threshold=0.05)

        # If topic changes
        if topic_change_from_prev_msg or topic_change_from_prev_topic:    
            self.sent_changed_topic = msg
            
        # Adding new keywords
        t_keywords = c_keywords + self.pass_through

        if len(t_keywords) > 0:
            t_keywords = self.order_keywords_by_similarity(msg, t_keywords)
        
        # Setting Previous keywords
        self.prev_msg = message
        self.prev_keywords = t_keywords
        
        return t_keywords
    # End of function
    
# End of class

## POS Tagging

In [47]:
def tag_sentence(message):
    tokenized = sent_tokenize(message)
    nouns = []
    for sentence in tokenized:
        wordsList = word_tokenize(sentence)
        tagged = nltk.pos_tag(wordsList)
        nouns.extend([tag[0] for tag in tagged if tag[1][:2] in ['NN', 'CD'] and tag[0].lower() not in ['hi', 'hey']])
    return nouns

### Testing

In [49]:
model = TopicTransitionModel()

# Main loop to simulate conversation
while True:
    # Taking user input
    msg = input('\nUser: ')
    t1 = time.time()
    
    # Breaking if user enters 'exit'
    if msg == 'exit':
        break
        
    topics = model.update_topic(msg, tag_sentence(msg))
    print (f"topics: {topics}")
        
    # Updating previous messages and state for the next turn
    print(f"time elapsed: {time.time() - t1}")
    print ("#" * 100)


User: I recently travelled to Paris
topics: ['Paris']
time elapsed: 0.11799860000610352
####################################################################################################

User: I went to Karachi and saw the beach and the sea
topics: ['Karachi', 'beach', 'sea', 'Paris']
time elapsed: 0.21702265739440918
####################################################################################################

User: I see. Have you gotten your COVID vaccine yet?
[0, 1, 1, 2]
topics: ['COVID', 'vaccine', 'sea']
time elapsed: 0.1700115203857422
####################################################################################################

User: No I haven't
[1, 1, 0]
topics: ['COVID']
time elapsed: 0.16199755668640137
####################################################################################################


KeyboardInterrupt: Interrupted by user