# Setup

## Mount drive, download PG model from repo, import pretrained weights

In [None]:
from google.colab import drive

drive.mount('/content/drive', force_remount=True)

# Enter the foldername
FOLDERNAME = 'cs224n-project'

assert FOLDERNAME is not None, "[!] Enter the foldername."

%cd drive/MyDrive
%cd $FOLDERNAME
%ls .

%cd /content/
#install disfluency generator and copy over model parameters & weights
!git clone https://github.com/SALT-NLP/Disfluency-Generation-and-Detection.git

# Need to be in the disf_gen_coarse2fine folder
%cd /content/Disfluency-Generation-and-Detection/disf_gen_coarse2fine/

%cp /content/drive/MyDrive/cs224n-project/opt.json ./opt.json
%cp /content/drive/MyDrive/cs224n-project/m_30.pt ./m_30.pt

## Install necessary packages

In [None]:
# Following pip packages need to be installed:
# transformers is installed from the GitHub repository; sentencepiece and datasets come from PyPI.
!pip install git+https://github.com/huggingface/transformers sentencepiece datasets
# torchtext is pinned to 0.4.0; the dataset class in this notebook uses
# torchtext.data.Field / Example / Dataset.
!pip install torchtext==0.4.0
!pip install nltk

#Install necessary NLTK packages
# Presumably the resources behind word_tokenize ('punkt') and pos_tag
# ('averaged_perceptron_tagger'), both of which the helper cell calls.
import nltk
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')

## Modified PG model functions and helper methods

In [None]:
# helpers
from nltk.tokenize import word_tokenize
from nltk import pos_tag
import os, sys, re

def normalize(raw: str):
    """Lower-case ``raw`` and strip every character that is not an ASCII
    letter, a digit, a space or an apostrophe.

    Runs of removed characters are deleted outright (not replaced by a
    space), so surrounding spaces are left untouched.
    """
    lowered = raw.lower()
    return re.sub(r"[^A-Za-z0-9 ']+", "", lowered)

def format_disfl_input(dialog):
    """Turn a raw utterance into the three-line record read by ``modified_read_anno``.

    Returns ``[text, pos, tags]`` where ``text`` is the space-joined tokens of
    the normalized utterance, ``pos`` is the list of POS tags (one per token)
    and ``tags`` is a space-joined string of ``"O"`` labels, one per token.
    """
    tokens = word_tokenize(normalize(dialog))
    pos_labels = [label for _, label in pos_tag(tokens)]
    fluent_tags = ["O"] * len(tokens)

    return [" ".join(tokens), pos_labels, " ".join(fluent_tags)]

def format_audio_input(disfluency):
    """Render a prediction as text for the TTS step, marking disfluent spans.

    Reads ``disfluency.tgt`` (words) and ``disfluency.tgt_tags`` (one tag per
    word).  Words tagged ``"O"`` or ``"I"`` are kept; a ``" -- "`` marker is
    inserted wherever the tag switches from ``"O"`` to ``"I"`` or back.  Words
    with any other tag are dropped.  The pieces are joined with single spaces.
    """
    words = disfluency.tgt
    tags = disfluency.tgt_tags
    final_index = len(words) - 1
    # For each kept tag, the neighbouring tag that marks a span boundary.
    boundary_after = {"O": "I", "I": "O"}

    pieces = []
    for position, token in enumerate(words):
        current = tags[position]
        if current not in boundary_after:
            continue
        pieces.append(token)
        if position < final_index and tags[position + 1] == boundary_after[current]:
            pieces.append(" -- ")
    return " ".join(pieces)

class HiddenPrints:
    """Context manager that discards everything written to ``sys.stdout``.

    On entry ``sys.stdout`` is pointed at ``os.devnull``; on exit the stream
    that was active at entry is restored, whether or not the body raised.
    Exceptions from the body are not suppressed.
    """
    #https://colab.research.google.com/github/jimit105/pytricks/blob/master/Hide%20print.ipynb
    def __enter__(self):
        self._original_stdout = sys.stdout
        # Keep our own reference to the sink so __exit__ closes exactly this
        # handle, even if the body reassigns sys.stdout in the meantime.
        self._devnull = open(os.devnull, 'w')
        sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stdout = self._original_stdout
        self._devnull.close()

In [None]:
from table.IO import make_src, make_tgt, merge_vocabs, join_dicts, _dynamic_dict
import torchtext
from collections import Counter, defaultdict

# Special vocabulary tokens, each paired with its integer id.
UNK_WORD, UNK = '<unk>', 0
PAD_WORD, PAD = '<pad>', 1
BOS_WORD, BOS = '<bos>', 2
EOS_WORD, EOS = '<eos>', 3
EOD_WORD, EOD = '<eod>', 4
IOD_WORD, IOD = '<iod>', 5

# Layout labels written into 'src_label' by modified_read_anno.
BOD_LABEL, NBOD_LABEL = 'E', 'N'

# Per-token tags: 'I' marks a disfluent token, 'O' a fluent one.
DISF_LABEL, FLT_LABEL = 'I', 'O'

def modified_read_anno(reader, opt):
    """Parse annotation lines into the per-sentence dicts used by the dataset.

    ``reader`` is iterated in groups of four entries: entry 0 holds the
    whitespace-separated words, entry 2 the matching ``I``/``O`` tags, and
    entries 1 and 3 are ignored.  Sentences whose tags are all ``I`` (which
    includes empty sentences) are dropped.

    ``opt.include_flt`` is forced to ``True`` so all-fluent sentences are
    kept.  When ``opt`` carries a truthy ``gold_diversity`` entry, each record
    also gets ``disf_frags``: the list of contiguous ``I``-tagged word runs.

    Every returned dict contains:
      * ``sent`` / ``sent_tag`` - the words and tags as read;
      * ``fsent`` / ``fsent_tag`` - the same with ``EOD_WORD`` (tagged ``I``)
        inserted after each disfluent run;
      * ``src_label`` - one leading label plus one label per fluent token,
        ``'E'`` when a disfluent token follows and ``'N'`` otherwise;
      * ``src`` - the fluent (``O``-tagged) words only.
    """
    opt.include_flt = True
    records = []
    for line_no, raw in enumerate(reader):
        slot = line_no % 4
        if slot == 0:
            records.append({'sent': raw.strip().split()})
        elif slot == 2:
            tags = raw.strip().split()
            assert (len(tags) == len(records[-1]['sent']))
            all_fluent = tags == ['O'] * len(tags)
            all_disfluent = tags == ['I'] * len(tags)
            if (not opt.include_flt and all_fluent) or all_disfluent:
                records.pop()
                continue
            assert (len(tags) > 0)
            records[-1]['sent_tag'] = tags
    # print(reader, ' all_size:', num_all, ' disf_size:',len(js_list))

    if vars(opt).get('gold_diversity'):
        for rec in records:
            fragments = []
            inside = False
            for word, tag in zip(rec['sent'], rec['sent_tag']):
                if tag == 'I':
                    if inside:
                        fragments[-1].append(word)
                    else:
                        fragments.append([word])
                        inside = True
                elif tag == 'O':
                    inside = False
                # any other tag leaves the current fragment state untouched
            rec['disf_frags'] = fragments

    for rec in records:
        tags = rec['sent_tag']
        full_words = []
        full_tags = []
        # Pair every tag with its successor; the end of the sentence counts
        # as 'O' so that a trailing disfluent run is closed as well.
        for word, tag, upcoming in zip(rec['sent'], tags, tags[1:] + ['O']):
            full_words.append(word)
            full_tags.append(tag)
            if tag == 'I' and upcoming == 'O':
                full_words.append(EOD_WORD)
                full_tags.append('I')
        rec['fsent'] = full_words
        rec['fsent_tag'] = full_tags

        assert (len(full_tags) > 0)
        labels = ['E' if full_tags[0] == 'I' else 'N']
        labels.extend(
            'E' if upcoming == 'I' else 'N'
            for tag, upcoming in zip(full_tags, full_tags[1:] + ['O'])
            if tag == 'O'
        )
        rec['src_label'] = labels
        rec['src'] = [word for word, tag in zip(full_words, full_tags) if tag == 'O']
    return records

def modified_translate_opts(parser, model_path):
    """Register the generation options on ``parser``.

    ``model_path`` becomes the default of ``-model_path``; every other option
    has a fixed default.  Options are added in the order listed below, and
    nothing is returned.

    The upstream ``-output``, ``-max_lay_len`` and ``-max_tgt_len`` options
    are intentionally not registered here.
    """
    switch = {'action': 'store_true'}
    option_table = (
        ('-root_dir', dict(default='', help="Path to the root directory.")),
        ('-dataset', dict(default='swbd', help="Name of dataset.")),
        ('-tag_type', dict(default='IO', help="Type of tag system")),
        ('-model_path', dict(default=model_path, help='Path to model .pt file')),
        ('-split', dict(default="test", help="Path to the evaluation annotated data")),
        ('-run_from', dict(type=int, default=0, help='Only evaluate run.* >= run_from.')),
        ('-batch_size', dict(type=int, default=1, help='Batch size')),
        ('-beam_size', dict(type=int, default=0, help='Beam size')),
        ('-n_best', dict(type=int, default=1, help='N-best size')),
        ('-max_disf_len', dict(type=int, default=8, help='Maximum layout decoding length.')),
        ('-gpu', dict(type=str, default='0', help="Device to run on")),
        ('-gold_layout', dict(switch, help="Given the golden layout sequences for evaluation.")),
        ('-random_layout', dict(switch, help="Use random layout")),
        ('-flt_gen', dict(switch, help="Regenerate fluent part")),
        ('-gold_diversity', dict(switch, help="Report gold diversity")),
        ('-no_in_sent_word', dict(switch, help="no_in_sent_word")),
        ('-random_choose_topk', dict(switch, help="random_choose_topk")),
        ('-random_sample', dict(switch, help="random_sample")),
        ('-eval_diversity', dict(switch, help="evaluate diversity")),
        ('-gen_eod', dict(switch, help="generate eod")),
        ('-attn_ignore_small', dict(type=float, default=0, help="Ignore small attention scores.")),
        ('-include_flt', dict(switch, help='include fluent sentences during generation')),
        ('-sample_num', dict(type=int, default=1, help='Number of samples in each step')),
        ('-translate_num', dict(type=int, default=0, help='Number of translation sentences')),
        ('-queue_size', dict(type=int, default=50, help='Number of translation sentences')),
        ('-temperature', dict(type=float, default=1.0, help='temperature of flatting logits')),
        ('-random_mask_eod', dict(type=float, default=0.0,
                                  help="""During generation, mask EOD with thos prob""")),
        ('-output_file', dict(default='pred.txt',
                              help="""Path to output the predictions and score (each line will be the decoded sequence""")),
        ('-master_port', dict(default='12355', help="""Master Port""")),
    )
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)

class ModifiedTableDataset(torchtext.data.Dataset):
    """torchtext dataset built from already-parsed annotation dicts.

    The constructor receives the list of per-sentence dicts directly (in this
    notebook, the output of ``modified_read_anno``) and converts the fields
    selected by ``opt`` into ``torchtext.data.Example`` objects.
    """

    @staticmethod
    def sort_key(ex):
        """Return the negated length of ``ex.src`` (or ``ex.sent`` when the
        example has no ``src`` attribute), so longer examples sort first."""
        if 'src' in ex.__dict__:
            return -len(ex.src)
        else:
            return -len(ex.sent)

    def __init__(self, anno, fields, opt, **kwargs):
        """
        Build the examples from a list of annotation dicts.

        anno: list of per-sentence dicts (used as ``js_list``); each dict must
            provide the keys read by ``_read_annotated_file`` for the selected
            mode ('src', or 'sent' and 'sent_tag').
        fields: mapping from field name to torchtext field; indexed with the
            example keys plus "indices".
        opt: option object; ``opt.disf_seg`` and ``opt.no_disf_trans`` select
            which fields are built.  NOTE: ``opt.no_disf_trans`` is overwritten
            with False when ``opt.disf_seg`` is falsy.
        **kwargs: accepted but not used in this constructor.
        """
        js_list = anno

        self.opt=opt

        if opt.disf_seg:
            sent_data = self._read_annotated_file(opt, js_list, 'sent')
            sent_examples = self._construct_examples(sent_data, 'sent')

            sent_tag_data = self._read_annotated_file(opt, js_list, 'sent_tag')
            sent_tag_examples = self._construct_examples(sent_tag_data, 'sent_tag')
        else:
            opt.no_disf_trans=False

        if opt.no_disf_trans:
            # Segmentation-only mode: examples carry 'sent' and 'sent_tag'.
            assert (opt.disf_seg)
            examples = [join_dicts(*it) for it in
                        zip(sent_examples, sent_tag_examples)]

        else:
            # Generation mode: examples carry only the fluent source tokens.
            src_data = self._read_annotated_file(opt, js_list, 'src')
            src_examples = self._construct_examples(src_data, 'src')

            assert(opt.disf_seg==False)
            examples = [join_dicts(*it) for it in
                        zip(src_examples)]

        # the examples should not contain None
        len_before_filter = len(examples)
        examples = list(filter(lambda x: all(
            (v is not None for k, v in x.items())), examples))
        len_after_filter = len(examples)
        num_filter = len_before_filter - len_after_filter
        if num_filter > 0:
            print('Filter #examples (with None): {} / {} = {:.2%}'.format(num_filter,
                                                                          len_before_filter,
                                                                  num_filter / len_before_filter))
        if not opt.no_disf_trans:
            # _dynamic_dict (from table.IO) returns a per-example source vocab
            # and the example dict; the vocabs are kept on the dataset.
            self.src_vocabs = []
            for ex_dict in examples:
                src_ex_vocab, ex_dict = _dynamic_dict(
                    ex_dict)
                self.src_vocabs.append(src_ex_vocab)

        # Peek at the first to see which fields are used.
        # NOTE(review): raises IndexError when `examples` is empty.
        ex = examples[0]
        keys = ex.keys()
        fields = [(k, fields[k])
                  for k in (list(keys) + ["indices"])]

        super(ModifiedTableDataset, self).__init__(
            self.construct_final(examples,fields,keys), fields, None)

    def __getattr__(self, attr):
        """Expose each field name as a generator over that attribute of all examples."""
        # avoid infinite recursion when fields isn't defined
        if 'fields' not in vars(self):
            raise AttributeError
        if attr in self.fields:
            return (getattr(x, attr) for x in self.examples)
        else:
            raise AttributeError

    def construct_final(self,examples,fields,keys):
        """Convert example dicts into ``torchtext.data.Example`` objects.

        Each example's values are taken in ``keys`` order and followed by the
        example's position ``i``, matching the trailing "indices" field.
        """
        exs=[]
        for i, ex in enumerate(examples):
            exs.append(torchtext.data.Example.fromlist(
                [ex[k] for k in keys] + [i],
                fields))
        return exs

    def filter_pred(self,example):
        """Return False for examples longer than the configured sequence limits.

        NOTE(review): ``self.test`` is not assigned anywhere in the visible
        code, and this method is not called in the visible code - confirm
        before relying on it.
        """
        if self.test:
            return True
        if not self.opt.no_disf_trans and (len(example.src)>self.opt.src_seq_length or len(example.tgt)>self.opt.tgt_seq_length):
            return False
        if self.opt.disf_seg and len(example.sent)>self.opt.tgt_seq_length:
            return False
        return True

    def _read_annotated_file(self, opt, js_list, field):
        """Extract one column named ``field`` from every dict in ``js_list``.

        Returns a list with one entry per dict.  Raises NotImplementedError
        for an unknown ``field``.
        """
        l=[]
        if field == 'src':
            for dic in js_list:
                l.append(dic['src'])
            return l
        elif field == 'sent':
            for dic in js_list:
                l.append(dic['sent'])
            return l
        elif field == 'sent_tag':
            # A fluent label is prepended to each tag sequence.
            for dic in js_list:
                l.append([FLT_LABEL]+dic['sent_tag'])
            return l
        elif field == 'src_label':
            for dic in js_list:
                l.append(dic['src_label'])
            return l
        elif field=="lay_index":
            # Leading 0, then a running 1-based counter for fluent ('O')
            # positions and 0 for every other position.
            for dic in js_list:
                line = [0]
                i=1
                for w, t in zip(dic['fsent'], dic['fsent_tag']):
                    if t == 'O':
                        line.append(i)
                        i += 1
                    else:
                        line.append(0)
                l.append(line)
            return l
        elif field=="tgt_mask":
            for dic in js_list:
                if 'no_connection_decoder' in opt.__dict__ and opt.no_connection_decoder:
                    # All ones: one leading entry plus one per token.
                    line=[1]+[1]*len(dic['fsent_tag'])
                else:
                    if 'decoder_word_input' in opt.__dict__ and opt.decoder_word_input:
                        # 0 at positions immediately before a disfluent ('I')
                        # token (including the leading slot), 1 elsewhere.
                        line = []
                        line.append(0 if dic['fsent_tag'][0] == 'I' else 1)
                        for i in range(len(dic['fsent_tag'])):
                            if dic['fsent_tag'][i] == 'O' and i < len(dic['fsent_tag']) - 1 and dic['fsent_tag'][
                                i + 1] == 'I':
                                line.append(0)
                            else:
                                line.append(1)
                    else:
                        # Leading 0, then 1 for disfluent tokens and 0 for fluent ones.
                        line = [0] + [1 if t == 'I' else 0 for t in dic['fsent_tag']]
                l.append(line)
            return l
        elif field=="tgt_loss_mask":
            # 0 for disfluent tokens, 1 for fluent ones, plus a trailing 1.
            for dic in js_list:
                line = [0 if t=='I' else 1 for t in dic['fsent_tag']] + [1]
                l.append(line)
            return l
        elif field=="tgt":
            for dic in js_list:
                l.append(dic['fsent'])
                '''line=[w if t=='I' else PAD_WORD for w, t in zip(dic['fsent'], dic['fsent_tag']) ]
                l.append(line)'''
            return l
        elif field=="tgt_loss":
            for dic in js_list:
                l.append(dic['fsent'])
            return l
        else:
            raise NotImplementedError


    def _construct_examples(self, lines, side):
        """Wrap each entry of ``lines`` in a single-key dict ``{side: entry}``."""
        l=[]
        for words in lines:
            example_dict = {side: words}
            l.append(example_dict)
        return l

    def save(self, path, remove_fields=True):
        """Serialize the dataset to ``path`` with ``torch.save``.

        When ``remove_fields`` is true, ``self.fields`` is replaced by an
        empty list first.  Relies on ``torch`` having been imported before
        this is called.
        """
        if remove_fields:
            self.fields = []
        torch.save(self, path)

    @staticmethod
    def load_fields(vocab):
        """Return fresh fields from ``get_fields()`` with the given vocabularies attached.

        ``vocab`` is anything ``dict()`` accepts that maps field names to
        vocab objects, e.g. the list of pairs produced by ``save_vocab``.
        """
        vocab = dict(vocab)
        fields = ModifiedTableDataset.get_fields()
        for k, v in vocab.items():
            # Hack. Can't pickle defaultdict :(
            # Restore the default-to-0 lookup for unknown tokens.
            v.stoi = defaultdict(lambda: 0, v.stoi)
            fields[k].vocab = v
        return fields

    @staticmethod
    def save_vocab(fields):
        """Collect ``(name, vocab)`` pairs for every field that has a vocab.

        Each vocab's ``stoi`` is converted to a plain dict in place.
        """
        vocab = []
        for k, f in fields.items():
            if 'vocab' in f.__dict__:
                f.vocab.stoi = dict(f.vocab.stoi)
                vocab.append((k, f.vocab))
        return vocab

    @staticmethod
    def get_fields(opt=None):
        """Create the name -> torchtext field mapping used by the dataset.

        Text fields are lower-cased according to ``opt.lower`` when ``opt`` is
        given and always lower-cased otherwise.  Relies on ``torch`` having
        been imported before this is called (for the dtype arguments).
        """
        fields = {}
        fields["sent"] = torchtext.data.Field(
            init_token=IOD_WORD, pad_token=PAD_WORD, include_lengths=True, lower=opt.lower if opt else True)
        fields["sent_tag"] = torchtext.data.Field(
            pad_token=PAD_WORD, lower=False)
        fields["src"] = torchtext.data.Field(
            init_token=BOS_WORD,pad_token=PAD_WORD, include_lengths=True,lower=opt.lower if opt else True)
        fields["src_label"] = torchtext.data.Field(
            pad_token=PAD_WORD, lower=False)
        fields["lay_index"] = torchtext.data.Field(
            use_vocab=False, pad_token=0)
        fields["tgt_mask"] = torchtext.data.Field(
            use_vocab=False, dtype=torch.float, pad_token=1)
        fields["tgt_loss_mask"] = torchtext.data.Field(
            use_vocab=False, dtype=torch.long, pad_token=1)
        fields["tgt"] = torchtext.data.Field(
            init_token=BOS_WORD, pad_token=PAD_WORD,lower=opt.lower if opt else True)

        fields["tgt_loss"] = torchtext.data.Field(
            eos_token=EOS_WORD, pad_token=PAD_WORD,lower=opt.lower if opt else True)

        fields["src_map"] = torchtext.data.Field(use_vocab=False, dtype=torch.float,
            postprocessing=make_src, sequential=False)
        fields["src_ex_vocab"] = torchtext.data.RawField()
        fields["alignment"] = torchtext.data.Field(use_vocab=False, dtype=torch.long,
            postprocessing=make_tgt, sequential=False)
        fields["indices"] = torchtext.data.Field(
            use_vocab=False, sequential=False)

        return fields

    @staticmethod
    def build_vocab(train, dev, test, opt):
        """Build the field vocabularies in place on ``train.fields``.

        Which vocabularies are built depends on ``opt.disf_seg`` and
        ``opt.no_disf_trans``; minimum frequencies come from
        ``opt.src_words_min_frequency`` and ``opt.tgt_words_min_frequency``.
        Nothing is returned.
        """
        fields = train.fields

        if opt.disf_seg:
            for field_name in ('sent', 'sent_tag'):
                fields[field_name].build_vocab(
                    train, min_freq=opt.src_words_min_frequency)

        if not opt.no_disf_trans:
            # NOTE(review): src_vocab_all is filled here but not read again
            # in the visible code.
            src_vocab_all = []
            # build vocabulary only based on the training set
            # the last one should be the variable 'train'
            for split in (dev, test, train,):
                fields['src'].build_vocab(split, min_freq=0)
                src_vocab_all.extend(list(fields['src'].vocab.stoi.keys()))

            # build vocabulary only based on the training set
            for field_name in ('src', 'src_label'):
                fields[field_name].build_vocab(
                    train, min_freq=opt.src_words_min_frequency)
            if opt.disf_seg:
                # Share one merged vocabulary between 'src' and 'sent'.
                src_merge_name_list = ['src', 'sent']
                src_merge = merge_vocabs([fields[field_name].vocab for field_name in src_merge_name_list],
                                        min_freq=opt.src_words_min_frequency)
                for field_name in src_merge_name_list:
                    fields[field_name].vocab = src_merge

            # build vocabulary only based on the training set
            for field_name in ('tgt', 'tgt_loss'):
                fields[field_name].build_vocab(
                    train, min_freq=opt.tgt_words_min_frequency)

            # Share one merged vocabulary between 'tgt' and 'tgt_loss'.
            tgt_merge_name_list = ['tgt', 'tgt_loss']
            tgt_merge = merge_vocabs([fields[field_name].vocab for field_name in tgt_merge_name_list],
                                     min_freq=opt.tgt_words_min_frequency)
            for field_name in tgt_merge_name_list:
                fields[field_name].vocab = tgt_merge

## Create Models

### Dialog Chatbot Model

In [None]:
# create DialoGPT-large model
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Model identifier; the chat loop also compares against this exact string.
dialog_mname = "microsoft/DialoGPT-large"
# Tokenizer and causal language model loaded from the same identifier.
tokenizer = AutoTokenizer.from_pretrained(dialog_mname)
dialoGPT_model = AutoModelForCausalLM.from_pretrained(dialog_mname)

### TTS Model

In [None]:
# gTTS backs the optional `use_gtts` speech path in generate_audio.
!pip install gtts

In [None]:
# create TTS model
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from datasets import load_dataset
import torch
import soundfile as sf

# SpeechT5 text processor, text-to-speech model and HiFi-GAN vocoder;
# generate_audio uses all three together.
processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

# load xvector containing speaker's voice characteristics from a dataset
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
# Entry 7306 supplies the voice; unsqueeze(0) adds a leading dimension
# (presumably the batch axis expected by generate_speech - TODO confirm).
speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)

### Disfluency Generation Model

In [None]:
# create disfluency model
import torch
import table
import table.IO
import opts
import argparse
import glob

# Build the generation options and load the translator with stdout silenced.
with HiddenPrints():
    parser = argparse.ArgumentParser(description='generate.py')
    modified_translate_opts(parser, "m_30.pt")
    # No command line in a notebook: parse an empty argument list so every
    # option takes its default.
    opt = parser.parse_args([])
    dummy_parser = argparse.ArgumentParser(description='train.py')
    opt.dataset = opt.dataset + opt.tag_type
    opt.anno = os.path.join(opt.root_dir, opt.dataset, '{}.txt'.format(opt.split))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Beam search is run one example at a time.
    if opt.beam_size > 0:
        opt.batch_size = 1

    # Default model/training options, passed to the translator as a plain dict.
    opts.model_opts(dummy_parser)
    opts.train_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]

    # Fail loudly when the checkpoint is missing; otherwise `translator` would
    # stay undefined (or stale from an earlier run) and only break later.
    model_files = glob.glob(opt.model_path)
    if not model_files:
        raise FileNotFoundError(
            'No model file matches {!r}'.format(opt.model_path))

    for fn_model in model_files:
        opt.model = fn_model
        translator = table.Translator(opt, dummy_opt.__dict__)

## Helper methods to call models

In [None]:
from nltk.tokenize.treebank import TreebankWordDetokenizer
from IPython.display import Audio, display
from gtts import gTTS

# helpers
def generate_disfluency(inp):
    """Run the disfluency generator on one formatted utterance.

    ``inp`` is the ``[text, pos, tags]`` record produced by
    ``format_disfl_input``.  Returns the prediction for the first parsed
    sentence, or ``None`` when parsing yields no sentence (for example an
    empty utterance).

    Uses the module-level ``opt``, ``translator`` and ``device``.
    """
    js_list = modified_read_anno(inp, opt)
    if not js_list:
        # The dataset constructor indexes examples[0], so it cannot be built
        # from zero sentences; report "nothing generated" instead.
        return None

    data = ModifiedTableDataset(
            js_list, translator.fields, translator.model_opt)
    test_data = table.IO.OrderedIterator(
                dataset=data, device=device, batch_size=opt.batch_size, train=False, sort=True, sort_within_batch=False)

    # inference
    r_list = []
    with torch.no_grad():
        for batch in test_data:
            r_list += translator.translate(batch, js_list)

    # The iterator sorts examples, so restore the input order before indexing.
    r_list.sort(key=lambda x: x.idx)
    assert len(r_list) == len(js_list), 'len(r_list) != len(js_list): {} != {}'.format(
        len(r_list), len(js_list))

    return r_list[0]

def generate_audio(inp, use_gtts = False):
    """Synthesize ``inp`` to speech, save it to disk and play it inline.

    ``inp`` is whitespace-tokenized text; it is detokenized before synthesis.
    With ``use_gtts`` the audio comes from gTTS and is written to
    ``speech.mp3``; otherwise SpeechT5 (module-level ``processor``,
    ``tts_model``, ``vocoder`` and ``speaker_embeddings``) is used and the
    result is written to ``speech.wav`` at 16 kHz.
    """
    inp = TreebankWordDetokenizer().detokenize(inp.split())
    if use_gtts:
        # gTTS writes MP3 data, so use a matching extension; IPython's Audio
        # derives the MIME type from the file name.
        audio_path = "speech.mp3"
        gTTS(inp).save(audio_path)
    else:
        audio_path = "speech.wav"
        inputs = processor(text=inp, return_tensors="pt")
        speech = tts_model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)
        sf.write(audio_path, speech.numpy(), samplerate=16000)
    display(Audio(audio_path, autoplay=True))
# Pipeline

In [None]:
from nltk.tokenize.treebank import TreebankWordDetokenizer
from nltk.tokenize import word_tokenize
import time
from transformers.utils import logging
import re

logging.set_verbosity_error()

use_gtts = input("Use gTTS? [Y/N]: ").strip().lower() == 'y'
k = int(input("Number of lines to chat for: ")) # number of lines to chat for

detokenizer = TreebankWordDetokenizer()

for step in range(k):
    if dialog_mname != "microsoft/DialoGPT-large":
        # Only the DialoGPT path is implemented in this notebook.
        raise NotImplementedError(f"Unsupported dialog model: {dialog_mname}")

    # Read the user's line before silencing stdout so the prompt stays
    # visible outside notebook front ends as well.
    user_text = input(">> User: ")

    with HiddenPrints():
        new_user_input_ids = tokenizer.encode(user_text + tokenizer.eos_token, return_tensors='pt')

        # append the new user input tokens to the chat history
        bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids

        # generate a response while limiting the total chat history to 1000 tokens
        chat_history_ids = dialoGPT_model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)

        # keep only the newly generated tokens (everything after the prompt)
        response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)

    # Split the reply at sentence punctuation and process each utterance on
    # its own; the character class matches any one of ? . ! ,
    printed_parts = []
    spoken_parts = []
    for utterance in re.split(r'[?.!,]', response):
        if not normalize(utterance).strip():
            # Nothing left after normalization (e.g. the empty piece after a
            # trailing full stop); the generator cannot handle empty input.
            continue
        disfluency = generate_disfluency(format_disfl_input(utterance))
        if disfluency:
            printed_parts.append(detokenizer.detokenize(disfluency.tgt))
            spoken_parts.append(format_audio_input(disfluency))

    if printed_parts:
        output = ". ".join(printed_parts)
        print(f">> DialoGPT: {output}")
        # generate_audio re-tokenizes on whitespace, so keep the separator
        # as its own token.
        generate_audio(" . ".join(spoken_parts), use_gtts)
        time.sleep(2.5) # let audio load and autoplay