<a href="https://colab.research.google.com/github/jwlw2022/nlp-chatbot-project/blob/main/6864_NLP_Chatbot_Project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Installations

In [None]:
# Install the Python packages this notebook needs (run once per Colab session).
#   - pytorch-pretrained-bert: provides `cached_path`, used below to download PersonaChat.
#   - transformers: provides `AutoTokenizer`, used below to load the pretrained tokenizer.
#   - datasets, tqdm, sentencepiece: not imported in the cells shown here;
#     presumably used by later cells — verify before removing.
# NOTE(review): versions are unpinned; for reproducible runs pin them
# (e.g. `%pip install -q transformers==<version>`) and prefer `%pip` over `!pip`
# so the install targets the running kernel's environment.
!pip install pytorch-pretrained-bert
!pip -q install transformers
!pip -q install datasets
!pip -q install tqdm
!pip -q install sentencepiece 

[K     |████████████████████████████████| 2.1MB 13.7MB/s 
[K     |████████████████████████████████| 3.3MB 51.9MB/s 
[K     |████████████████████████████████| 901kB 65.2MB/s 
[K     |████████████████████████████████| 225kB 13.0MB/s 
[K     |████████████████████████████████| 245kB 15.7MB/s 
[K     |████████████████████████████████| 112kB 15.8MB/s 
[K     |████████████████████████████████| 1.2MB 13.7MB/s 
[?25h

# Pretrained tokenizer

In [None]:
import transformers

# Load the pretrained cased DistilBERT tokenizer. AutoTokenizer picks the concrete
# tokenizer class from the checkpoint, so only the checkpoint name is needed here.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path='distilbert-base-cased',
)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=411.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=213450.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=435797.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=29.0, style=ProgressStyle(description_w…




# Download and tokenize the PersonaChat dataset


In [None]:
import json
from pytorch_pretrained_bert import cached_path

url = "https://s3.amazonaws.com/datasets.huggingface.co/personachat/personachat_self_original.json"

# `cached_path` returns a local path to the downloaded JSON file; parse it
# straight from the open file handle.
personachat_file = cached_path(url)
with open(personachat_file, "r", encoding="utf-8") as f:
    dataset = json.load(f)

# Tokenize and encode the dataset with the tokenizer loaded above ('distilbert-base-cased').
def tokenize(obj, tok=None):
    """Recursively replace every string in a nested JSON-like structure with token ids.

    Strings become ``tok.convert_tokens_to_ids(tok.tokenize(s))``. Dicts keep their
    keys and have their values encoded. Any other iterable (list/tuple) becomes a
    list of encoded items. Non-iterable leaves (numbers, booleans, None) are
    returned unchanged.

    Args:
        obj: str, dict, iterable, or scalar — e.g. the parsed PersonaChat JSON.
        tok: object exposing ``tokenize`` and ``convert_tokens_to_ids``; defaults to
            the notebook-level ``tokenizer``.

    Returns:
        A structure mirroring ``obj`` with every string replaced by its id list.
    """
    if tok is None:
        tok = tokenizer
    if isinstance(obj, str):
        return tok.convert_tokens_to_ids(tok.tokenize(obj))
    if isinstance(obj, dict):
        return {key: tokenize(value, tok) for key, value in obj.items()}
    try:
        items = iter(obj)
    except TypeError:
        # Non-iterable leaf (int/float/bool/None): nothing to tokenize.
        return obj
    return [tokenize(item, tok) for item in items]
 
# Encode every string in the raw PersonaChat JSON as token ids. This rebinds
# `dataset`: after this line the raw-text JSON is no longer referenced.
dataset = tokenize(dataset)

100%|██████████| 209850483/209850483 [00:09<00:00, 22819213.98B/s]


In [None]:
#  transformer_chatbot
#  Copyright (C) 2018 Golovanov, Tselousov
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU Affero General Public License as published by
#  the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU Affero General Public License for more details.
#
#  You should have received a copy of the GNU Affero General Public License
#  along with this program.  If not, see <http://www.gnu.org/licenses/>.

import random
import torch
from torch.utils.data import Dataset
from .text import BPEVocab


class PersonaChatDataset(Dataset):
    """Persona-conditioned dialog dataset read from numbered, tab-separated text files.

    Each item is a ``(persona_info, h, y)`` triple of token-id lists, produced by
    randomly sampling persona sentences and a cut point in the dialog (see
    ``__getitem__``).

    NOTE(review): this cell's ``from .text import BPEVocab`` is a relative import,
    which fails when the cell runs as top-level notebook code. The class itself never
    uses ``BPEVocab``; it only needs the ``vocab`` object passed to ``__init__``.
    """

    @staticmethod
    def parse_data(path):
        """Parse one dialog file into a list of raw-text chats.

        Every non-empty line is ``"<turn index> <content>"``, and a turn index of 1
        starts a new chat. Content starting with ``'your persona:'`` is stored as a
        persona sentence. Otherwise, if the content has at least two tab-separated
        fields, the first two are appended to the dialog as consecutive utterances;
        extra fields are ignored.

        Args:
            path: path of a UTF-8 text file.

        Returns:
            list of ``{'persona_info': [str, ...], 'dialog': [str, ...]}`` dicts.

        Raises:
            ValueError: if a line's leading token is not an integer.
            IndexError: if persona/dialog content appears before the first turn-1 line.
        """
        data = []
        with open(path, 'r', encoding='utf-8') as file:
            # Iterate lazily rather than via readlines() to avoid holding every line twice.
            for raw_line in file:
                line = raw_line.strip()
                if not line:
                    continue

                space_idx = line.find(' ')
                dialog_idx = int(line if space_idx == -1 else line[:space_idx])
                if dialog_idx == 1:
                    data.append({'persona_info': [], 'dialog': []})

                # With no space, space_idx + 1 == 0 and the whole line is kept.
                dialog_line = [part.strip() for part in line[space_idx + 1:].split('\t')]

                if dialog_line[0].startswith('your persona:'):
                    persona_info = dialog_line[0].replace('your persona: ', '')
                    data[-1]['persona_info'].append(persona_info)
                elif len(dialog_line) > 1:
                    data[-1]['dialog'].extend(dialog_line[:2])

        return data

    @staticmethod
    def make_dataset(data, vocab, max_lengths):
        """Encode parsed chats into token ids.

        Args:
            data: output of ``parse_data``.
            vocab: object whose ``string2ids(str)`` returns a list of token ids.
            max_lengths: accepted for interface compatibility but unused here;
                truncation happens in ``__getitem__``.

        Returns:
            list of ``(persona_info, dialog)`` tuples of token-id lists. A dialog
            with an odd number of utterances drops its last one, so every dialog is
            made of complete (talker 1, talker 2) pairs.
        """
        dataset = []
        for chat in data:
            persona_info = [vocab.string2ids(s) for s in chat['persona_info']]
            dialog = [vocab.string2ids(s) for s in chat['dialog']]

            if len(dialog) % 2 == 1:
                dialog = dialog[:-1]

            dataset.append((persona_info, dialog))

        return dataset

    def __init__(self, paths, vocab, max_lengths=2048, min_infos=2):
        """
        Args:
            paths: one file path, or an iterable of paths whose chats are
                concatenated in order.
            vocab: provides ``string2ids`` plus the special ids ``info_bos_id``,
                ``info_eos_id``, ``talker1_bos_id``, ``talker1_eos_id``,
                ``talker2_bos_id``, ``talker2_eos_id``, ``bos_id`` and ``eos_id``.
            max_lengths: maximum length, in ids, of each sequence returned by
                ``__getitem__``.
            min_infos: minimum number of persona sentences sampled per item
                (capped at the number available); must be > 0.
        """
        assert min_infos > 0

        if isinstance(paths, str):
            paths = [paths]

        self.vocab = vocab
        self.max_lengths = max_lengths
        self.min_infos = min_infos

        # Bug fix: these calls previously went through the undefined name
        # `FacebookDataset`, raising NameError. Dispatching through `self` also lets
        # subclasses override the static parsers. `extend` avoids the quadratic
        # `sum(list_of_lists, [])`.
        parsed_data = []
        for path in paths:
            parsed_data.extend(self.parse_data(path))
        self.data = self.make_dataset(parsed_data, vocab, max_lengths)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Sample one training example from chat ``idx``.

        Returns:
            persona_info: a random subset of the chat's persona sentences (at least
                ``min(min_infos, n)`` of the ``n`` available). The subset is put in
                random order, concatenated, truncated and wrapped in
                ``info_bos_id``/``info_eos_id``. It is ``[]`` when the chat has no
                persona.
            h: the dialog before a random even cut point. Each utterance is wrapped
                in talker-1/talker-2 bos/eos ids alternately, and the result keeps
                only the last ``max_lengths`` ids.
            y: the utterance at the cut point (always a talker-2 turn), wrapped in
                ``bos_id``/``eos_id`` and truncated to ``max_lengths`` ids.

        Raises:
            ValueError: if the chat has no complete utterance pair.
        """
        persona_info, dialog = self.data[idx]

        if len(persona_info):
            n_info_samples = max(self.min_infos, random.randint(1, len(persona_info)))
            n_info_samples = min(n_info_samples, len(persona_info))
            persona_info = random.sample(persona_info, n_info_samples)
            # Redundant after random.sample, but kept so seeded runs reproduce.
            random.shuffle(persona_info)
            persona_info = sum(persona_info, [])
            persona_info = [self.vocab.info_bos_id] + persona_info[:self.max_lengths-2] + [self.vocab.info_eos_id]

        if not dialog:
            # randrange(2, 1, 2) would fail with an opaque "empty range" error.
            raise ValueError(f'chat {idx} has no complete utterance pair to predict')

        dialog_begin = 0
        # Even cut point, so y = dialog[dialog_end - 1] is always a talker-2 turn.
        dialog_end = random.randrange(2, len(dialog)+1, 2)

        h = []
        for i, ids in enumerate(dialog[dialog_begin:dialog_end-1], 1):
            if i % 2 == 1:
                ids = [self.vocab.talker1_bos_id] + ids + [self.vocab.talker1_eos_id]
            else:
                ids = [self.vocab.talker2_bos_id] + ids + [self.vocab.talker2_eos_id]
            h.extend(ids)
        # Keep the most recent context when the history is too long.
        h = h[-self.max_lengths:]

        y = [self.vocab.bos_id] + dialog[dialog_end-1] + [self.vocab.eos_id]
        y = y[:self.max_lengths]

        return persona_info, h, y