# In [None]:
import torch
import numpy as np
import re
import json
from transformers import BertModel, BertTokenizer
from torch import nn

def preprocess(text):
    """Return ``str(text)`` with Persian digits (۰-۹) replaced by ASCII digits (0-9).

    Non-string inputs are converted with ``str()`` first; all other characters are
    left untouched.
    """
    digit_table = str.maketrans('۱۲۳۴۵۶۷۸۹۰', '1234567890')
    return str(text).translate(digit_table)

def extract_numbers(text):
    """Return every maximal run of ASCII digits in ``str(text)``, in order of appearance.

    The runs are returned as strings; the list is empty when there are no ASCII
    digits. Non-ASCII digits (e.g. Persian ones) are not matched.
    """
    return re.findall(r'[0-9]+', str(text))

class ChatBot:
    """Slot-filling dialogue manager driven by the joint intent/slot ``model``.

    ``self.state`` is one of:

    * ``'WAIT_FOR_INSTRUCTION'`` - a new request is expected; its intent and any slots
      are predicted from the user's text.
    * ``'CONFIRM_OPERATION'`` - the intent probability was below
      ``min_intent_accuracy``; the user confirms (yes/no) or rephrases.
    * ``'GATHER_SLOTS'`` - the user is asked for the next missing required slot.
    * ``'CONFIRM_SLOT'`` - a slot was extracted with mean probability not above
      ``min_slots_accuracy``; the user confirms it (yes/no).
    * ``'SLOT_NOT_FOUND'`` - the requested slot was not found in the last answer.

    Relies on module-level names defined elsewhere in this script: ``tokenizer``,
    ``model``, ``device``, ``number2intent``, ``number2slot``, ``intent2slots``,
    ``intent2persian``, ``slot2persian``, ``preprocess`` and ``extract_numbers``.
    """

    def __init__(self, min_intent_accuracy, min_slots_accuracy) -> None:
        # An intent whose probability is below min_intent_accuracy must be confirmed;
        # a slot whose mean tag probability is not above min_slots_accuracy likewise.
        self.min_intent_accuracy = min_intent_accuracy
        self.min_slots_accuracy = min_slots_accuracy
        self.state = 'WAIT_FOR_INSTRUCTION'
        self.num_interactions = 0
        # Slots still missing for the detected intent; index 0 is prompted for next.
        self.required_slots = []
        # The original instruction plus a trailing space, used as model context
        # when tagging later answers.
        self.base_text = ''
        self.detected_intent = ''
        self.detected_slots = {}
        # Low-confidence slot awaiting confirmation: its name and extracted value.
        self.required_slot_not_accurate = None
        self.required_slot_not_accurate_text = None
        # True while the current prompt is a yes/no question (label starts with 'آیا').
        self.required_slot_yes_no = False

    # ------------------------------------------------------------------ helpers
    def _reset(self):
        """Forget the current conversation (state, slots and interaction count)."""
        self.__init__(self.min_intent_accuracy, self.min_slots_accuracy)

    def _finish(self):
        """Print the collected intent/slots, then start a new conversation."""
        self.show_results()
        self._reset()

    def _advance_or_finish(self):
        """Finish when no required slot is left, otherwise ask for the next one."""
        if self.required_slots:
            self.state = 'GATHER_SLOTS'
        else:
            self._finish()

    def _pending_slot(self):
        """Name of the low-confidence slot awaiting confirmation.

        Falls back to the next required slot when no name was recorded.
        """
        if self.required_slot_not_accurate is not None:
            return self.required_slot_not_accurate
        return self.required_slots[0]

    def _clear_pending(self):
        """Drop the low-confidence slot awaiting confirmation, if any."""
        self.required_slot_not_accurate = None
        self.required_slot_not_accurate_text = None

    def _handle_model_turn(self, text, model_text, base_text=''):
        """Tag ``model_text`` (with ``base_text`` as context) and feed the user's ``text``."""
        intent, intent_accuracy, slots, slots_accuracy, tokens = self.get_intents_slots(model_text, base_text)
        self.handle_state(text=text, intent=intent, intent_accuracy=intent_accuracy,
                          slots=slots, slots_accuracy=slots_accuracy, tokens=tokens)

    @staticmethod
    def _find_slot_span(slots, slot_name):
        """Return the half-open span ``(start, end)`` tagged as *slot_name*, or None.

        The span starts at the first 'b-' tag (or the first 'i-' tag when there is
        no 'b-' tag). It ends after the last 'i-' tag, or after the last 'b-' tag
        when there is no 'i-' tag.
        """
        begin_tag, inside_tag = f'b-{slot_name}', f'i-{slot_name}'
        has_begin = begin_tag in slots
        has_inside = inside_tag in slots
        if not (has_begin or has_inside):
            return None
        start = slots.index(begin_tag if has_begin else inside_tag)
        last_tag = inside_tag if has_inside else begin_tag
        end = len(slots) - slots[::-1].index(last_tag)
        # If 'i-' only occurs before 'b-', the span would be empty (and its mean
        # confidence NaN); keep at least the start token.
        return start, max(end, start + 1)

    @staticmethod
    def _decode_span(token_strings, start, end, numbers):
        """Rebuild the text of tokens ``[start, end)``.

        Like the rest of this bot, this assumes each unknown token stands for one of
        the ASCII numbers in the user's text, in order. The unknown tokens before
        *start* are counted so the span receives the matching numbers. WordPiece
        continuation pieces ('##...') are merged back into their words.
        """
        unk = tokenizer.unk_token
        number_index = token_strings[:start].count(unk)
        pieces = []
        for piece in token_strings[start:end]:
            if piece == unk:
                # Keep the unknown token if there are fewer numbers than unknown
                # tokens (e.g. an unknown non-digit character) instead of crashing.
                if number_index < len(numbers):
                    piece = numbers[number_index]
                number_index += 1
            pieces.append(piece)
        return tokenizer.convert_tokens_to_string(pieces)

    # ------------------------------------------------------------------ public API
    def show_results(self):
        """Print the detected intent and slot values, plus a closing summary line."""
        print(
            {
                "intent": self.detected_intent,
                "parameters": self.detected_slots
            }
        )
        print(f'chatbot -> پایان عملیات {intent2persian[self.detected_intent]} | تعداد تعاملات {self.num_interactions}.')
        print()

    def get_intents_slots(self, text, base_text=''):
        """Run the joint model on ``base_text + text``.

        ``base_text`` only provides context; the slot results cover the tokens of
        ``text`` alone.

        Returns:
            ``(intent, intent_accuracy, slots, slots_accuracy, tokens)``:
            - the predicted intent name;
            - its softmax probability;
            - the predicted slot tag per token of ``text``;
            - each tag's softmax probability;
            - the matching token ids.
        """
        # Drop the trailing [SEP]; [CLS] stays at position 0 for the intent head.
        tokens = torch.tensor(tokenizer.encode(base_text + text)[:-1]).view(1, -1)
        text_tokens_length = len(tokenizer.encode(text)[1:-1])
        inputs = {
            'input_ids': tokens.to(device),
            'token_type_ids': torch.zeros_like(tokens).to(device),
            'attention_mask': torch.ones_like(tokens).to(device),
        }
        # Inference only: don't build an autograd graph.
        with torch.no_grad():
            pred_intents_logits, pred_slots_logits = model(inputs)
        pred_intent = pred_intents_logits.argmax(dim=1)
        pred_slots = pred_slots_logits.argmax(dim=2)
        # Probability of the chosen class, with softmax computed once per head.
        intent_accuracy = pred_intents_logits.softmax(dim=1).gather(1, pred_intent.unsqueeze(1))[0, 0].item()
        slots_accuracy = pred_slots_logits.softmax(dim=2).gather(2, pred_slots.unsqueeze(2))[0, :, 0].tolist()
        intent = number2intent[pred_intent[0].item()]
        slots = [number2slot[x] for x in pred_slots[0].tolist()]
        # Token ids after [CLS]; these line up with the slot head's outputs.
        token_ids = tokens[0, 1:].tolist()
        # Keep only the positions that belong to `text`. An explicit start index
        # (rather than [-n:]) yields empty lists when `text` has no tokens.
        start = max(len(slots) - text_tokens_length, 0)
        return intent, intent_accuracy, slots[start:], slots_accuracy[start:], token_ids[start:]

    def fill_required_slots(self, text:str, slots:list, slots_accuracy:list, tokens:list):
        """Store the values of required slots tagged in ``slots``.

        Yes/no slots (label starting with 'آیا') are skipped; they are asked
        explicitly. A slot whose mean tag probability exceeds ``min_slots_accuracy``
        is written to ``detected_slots`` and removed from ``required_slots``.
        Otherwise the first such slot is remembered for confirmation in
        ``required_slot_not_accurate`` / ``required_slot_not_accurate_text``.

        Returns:
            'NOT_FOUND' if the first required slot (the one prompted for) is untagged,
            else 'NOT_ACCURATE' if some tagged slot was below the threshold,
            else 'FOUND'.
        """
        self._clear_pending()
        numbers = extract_numbers(text)
        token_strings = tokenizer.convert_ids_to_tokens([int(t) for t in tokens])
        required_slot_found = True
        required_slot_not_accurate = False
        # Iterate over a snapshot: accepted slots are removed from
        # self.required_slots inside the loop, which would otherwise skip elements.
        for position, required_slot in enumerate(list(self.required_slots)):
            if slot2persian[required_slot].startswith('آیا'):
                continue
            span = self._find_slot_span(slots, required_slot)
            if span is None:
                if position == 0:
                    required_slot_found = False
                continue
            start_index, end_index = span
            mean_accuracy = np.mean(slots_accuracy[start_index:end_index])
            value = self._decode_span(token_strings, start_index, end_index, numbers)
            if mean_accuracy > self.min_slots_accuracy:
                self.required_slots.remove(required_slot)
                self.detected_slots[required_slot] = value
            else:
                required_slot_not_accurate = True
                # Keep name and value together so confirmation fills the right slot.
                if self.required_slot_not_accurate is None:
                    self.required_slot_not_accurate = required_slot
                    self.required_slot_not_accurate_text = value
        if not required_slot_found:
            return 'NOT_FOUND'
        elif required_slot_not_accurate:
            return 'NOT_ACCURATE'
        else:
            return 'FOUND'

    def handle_state(self, **kwargs):
        """Advance the state machine with one user turn.

        ``text`` is always required. The WAIT_FOR_INSTRUCTION, GATHER_SLOTS and
        SLOT_NOT_FOUND states also need the outputs of :meth:`get_intents_slots`
        (``intent``, ``intent_accuracy``, ``slots``, ``slots_accuracy``, ``tokens``).
        """
        text = kwargs['text']
        if self.state == 'WAIT_FOR_INSTRUCTION':
            intent = kwargs['intent']
            self.detected_intent = intent
            self.base_text = f'{text} '
            self.required_slots = intent2slots[intent].copy()
            # Drop slots collected for a previously rejected intent (rephrase path).
            self.detected_slots = {}
            self.fill_required_slots(text, kwargs['slots'], kwargs['slots_accuracy'], kwargs['tokens'])
            if kwargs['intent_accuracy'] < self.min_intent_accuracy:
                self.state = 'CONFIRM_OPERATION'
            else:
                self._advance_or_finish()
        elif self.state == 'CONFIRM_OPERATION':
            if text == 'بله':
                self._advance_or_finish()
            elif text == 'خیر':
                self._reset()
            else:
                # Anything other than yes/no is treated as a rephrased instruction.
                self.state = 'WAIT_FOR_INSTRUCTION'
                self._handle_model_turn(text, text)
        elif self.state == 'GATHER_SLOTS':
            if self.required_slot_yes_no:
                if text in ('بله', 'خیر'):
                    self.detected_slots[self.required_slots.pop(0)] = text
                    self.required_slot_yes_no = False
            else:
                status = self.fill_required_slots(text, kwargs['slots'], kwargs['slots_accuracy'], kwargs['tokens'])
                if status == 'NOT_FOUND':
                    self.state = 'SLOT_NOT_FOUND'
                elif status == 'NOT_ACCURATE':
                    self.state = 'CONFIRM_SLOT'
            if not self.required_slots:
                self._finish()
        elif self.state == 'CONFIRM_SLOT':
            if text == 'بله':
                slot = self._pending_slot()
                self.detected_slots[slot] = self.required_slot_not_accurate_text
                if slot in self.required_slots:
                    self.required_slots.remove(slot)
                self._clear_pending()
                self._advance_or_finish()
            elif text == 'خیر':
                self._clear_pending()
                self.state = 'GATHER_SLOTS'
        elif self.state == 'SLOT_NOT_FOUND':
            # Retry the same answer as an ordinary slot-gathering turn.
            self.state = 'GATHER_SLOTS'
            self.handle_state(**kwargs)

    def get_prompt(self):
        """Return the bot's message for the current state.

        Side effect: in GATHER_SLOTS with a yes/no slot, sets ``required_slot_yes_no``.
        """
        if self.state == 'WAIT_FOR_INSTRUCTION':
            prompt = 'با سلام، چگونه میتوانم کمکتان کنم؟'
        elif self.state == 'CONFIRM_OPERATION':
            translated_intent = intent2persian[self.detected_intent]
            prompt = f'آیا قصد انجام عملیات {translated_intent} را دارید؟ پاسخ شما میتواند بله یا خیر یا ذکر عملیات مورد نظر با جزئیات بیشتر باشد.'
        elif self.state == 'GATHER_SLOTS':
            if slot2persian[self.required_slots[0]].startswith('آیا'):
                prompt = f'{slot2persian[self.required_slots[0]]} (با بله یا خیر پاسخ دهید)'
                self.required_slot_yes_no = True
            else:
                prompt = f'لطفا {slot2persian[self.required_slots[0]]} را وارد کنید. (میتوانید در صورت تمایل اطلاعات ضروری دیگری نیز وارد کنید)'
        elif self.state == 'CONFIRM_SLOT':
            # Ask about the slot that was actually extracted with low confidence.
            prompt = f'آیا {slot2persian[self._pending_slot()]} ، {self.required_slot_not_accurate_text} است؟ (با بله یا خیر پاسخ دهید)'
        elif self.state == 'SLOT_NOT_FOUND':
            prompt = f'{slot2persian[self.required_slots[0]]} به درستی وارد نشد. لطفا مقدار درست آنرا وارد کنید.'
        return prompt

    def chat(self):
        """Run the interactive loop on stdin until the user types 'خروج' (exit).

        'لغو' (cancel) resets the current conversation.
        """
        last_command_ran = True
        while True:
            # get_prompt() must run before handle_state(): it also sets
            # required_slot_yes_no for yes/no slot questions.
            prompt = self.get_prompt()
            text = preprocess(input()).strip()
            # The prompt is printed after input() returns so the transcript reads
            # prompt then answer; it is not repeated after an empty line.
            if last_command_ran:
                print(f'chatbot -> {prompt}')
            if text == '':
                last_command_ran = False
                continue
            print(f'user    -> {text}')
            self.num_interactions += 1
            if text == 'خروج':
                break
            if text == 'لغو':
                self._reset()
                last_command_ran = True
                continue
            if self.state == 'WAIT_FOR_INSTRUCTION':
                self._handle_model_turn(text, text)
            elif self.state in ('GATHER_SLOTS', 'SLOT_NOT_FOUND'):
                # Prefix the answer with the requested slot's label and give the
                # original instruction as context so the model tags it correctly.
                self._handle_model_turn(text, f'{slot2persian[self.required_slots[0]]} {text}', self.base_text)
            elif self.state in ('CONFIRM_OPERATION', 'CONFIRM_SLOT'):
                self.handle_state(text=text)
            last_command_ran = True


class MainModel(nn.Module):
    """Joint intent classifier and slot tagger on top of a BERT-style encoder.

    Two linear heads share the encoder output:
    * ``intent_fc`` reads the first position (the [CLS] token);
    * ``slot_fc`` reads every following position.
    """

    def __init__(self, bert, num_slots, num_intents) -> None:
        super().__init__()
        hidden_size = bert.config.hidden_size
        # Registration order (bert, intent_fc, slot_fc) matches saved checkpoints.
        self.bert = bert
        self.intent_fc = nn.Linear(hidden_size, num_intents)
        self.slot_fc = nn.Linear(hidden_size, num_slots)

    def forward(self, x):
        """Encode the keyword inputs in *x* and return ``(intent_logits, slot_logits)``.

        *x* is unpacked as keyword arguments into the encoder (here: ``input_ids``,
        ``token_type_ids``, ``attention_mask``).
        """
        hidden_states = self.bert(**x).last_hidden_state
        cls_vectors = hidden_states[:, 0]
        token_vectors = hidden_states[:, 1:]
        intent_logits = self.intent_fc(cls_vectors)
        slot_logits = self.slot_fc(token_vectors)
        return intent_logits, slot_logits


# Label maps and intent -> required-slot lists saved at training time.
with open("nlu_project_saves/nlu_project_slot2number.json", "r", encoding='utf8') as file:
    slot2number = json.load(file)
with open("nlu_project_saves/nlu_project_intent2number.json", "r", encoding='utf8') as file:
    intent2number = json.load(file)
with open('nlu_project_saves/nlu_project_intent2persian.json', 'r', encoding='utf8') as file:
    intent2persian = json.load(file)
with open('nlu_project_saves/nlu_project_slot2persian.json', 'r', encoding='utf8') as file:
    slot2persian = json.load(file)
# Use a context manager so the file handle is closed (it was leaked before).
with open('./nlu_project_saves/nlu_project_intents-slots.json', 'r', encoding='utf8') as file:
    intent2slots = {obj['intent']: obj['slots'] for obj in json.load(file)}

# Inverse maps: class index -> label name.
number2slot = {number: slot for slot, number in slot2number.items()}
number2intent = {number: intent for intent, number in intent2number.items()}

tokenizer = BertTokenizer.from_pretrained('HooshvareLab/bert-base-parsbert-uncased')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MainModel(BertModel.from_pretrained('./nlu_project_saves/nlu_project_bert/'), len(slot2number), len(intent2number)).to(device)
# Load the head weights onto the device actually in use; hard-coding 'cuda'
# made this fail on CPU-only machines.
model.intent_fc.load_state_dict(torch.load('nlu_project_saves/nlu_project_intent.pth', map_location=device))
model.slot_fc.load_state_dict(torch.load('nlu_project_saves/nlu_project_slot.pth', map_location=device))
# Inference only: make sure every submodule is in eval mode.
model.eval()


chatbot = ChatBot(min_intent_accuracy=0.9, min_slots_accuracy=0.8)
chatbot.chat()