In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
%cd /content/drive/MyDrive/WARNING_PRIVATE_FOLDER/gpt2-dialogue-generation-pytorch/
!pip install -r requirements.txt

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torch==1.7.1
  Downloading torch-1.7.1-cp39-cp39-manylinux1_x86_64.whl (776.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m776.8/776.8 MB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting transformers==4.12.5
  Downloading transformers-4.12.5-py3-none-any.whl (3.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m59.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets==1.16.1
  Downloading datasets-1.16.1-py3-none-any.whl (298 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m298.3/298.3 KB[0m [31m24.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting sentencepiece==0.1.96
  Downloading sentencepiece-0.1.96-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m49.5 MB/s[0m eta [36m0:00:00[0m


In [None]:
import json
import torch

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

from transformers import pipeline
summarizer = pipeline("summarization", model="philschmid/bart-large-cnn-samsum")

conversation = '''Jeff: Can I train a Transformers model on Amazon SageMaker? 
Philipp: Sure you can use the new Hugging Face Deep Learning Container. 
Jeff: ok.
Jeff: and how can I get started? 
Jeff: where can I find documentation? 
Philipp: ok, ok you can find everything here.                                   
'''
summarizer(conversation)[0]["summary_text"]

In [None]:
from datasets import *
from tqdm import tqdm


# For all
space = 'Ġ'
pre_quote = '’'
end_marks = ['.', ',', '?', '!', '...']
quotes = ['"', '\'']
abbreviations = ['s', 'd', 't', 'm', 're', 'll', 've', 'S', 'D', 'T', 'M', 'Re', 'Ll', 'Ve']

# For empathetic dialogues
exclude_symbol = "_conv"
comma_symbol = "_comma_"

# For persona chat
persona_chat_url = "https://s3.amazonaws.com/datasets.huggingface.co/personachat/personachat_self_original.json"
silence_symbol = "__ SILENCE __"


def load_daily():
    dataset = load_dataset('daily_dialog')
    test_dialogues = dataset['test']['dialog']
    
    return test_dialogues
    
    

def process_token_list(token_list):
    token_list[0] = token_list[0].capitalize()
    
    quote_count = 0
    for i, token in enumerate(token_list):
        if space in token:
            if token[1:] in end_marks or token[1:] in abbreviations:
                token_list[i] = token[1:]
                
            if token[1:] == quotes[1]:
                if i<len(token_list)-1:
                    if token_list[i+1] in abbreviations or (token_list[i+1][0] == space and token_list[i+1][1:] in abbreviations):
                        token_list[i] = token[1:]
                        
        if token[0] == space and token[1:] in quotes:
            if quote_count % 2 == 1:
                token_list[i] = token[1:]
                quote_count = 0
            else:
                if i<len(token_list)-1 and token_list[i+1][0] == space:
                    token_list[i+1] = token_list[i+1][1:]
                quote_count += 1
                
        if token in end_marks or token[1:] in end_marks:
            if i<len(token_list)-1:
                if token_list[i+1][0] != space:
                    token_list[i+1] = space + token_list[i+1].capitalize()
                else:
                    token_list[i+1] = space + token_list[i+1][1:].capitalize()
                
    new_token_list = [token for token in token_list if token != space and len(token)>0]
    if new_token_list[-1] not in end_marks:
        new_token_list.append(end_marks[0])
        
    return new_token_list


In [None]:
test_arr = load_daily()
load_daily()

In [None]:
import difflib
from transformers import GPT2Tokenizer, GPT2LMHeadModel, get_polynomial_decay_schedule_with_warmup

from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.nn import functional as F
from torch.utils.tensorboard import SummaryWriter
from itertools import chain

import torch
import os, sys
import numpy as np
import argparse
import copy
import math
import random

class Arguments:
    def __init__(self):
        self.seed = 0 
        self.mode="test" 
        self.data_dir="data" 
        self.model_type="gpt2" 
        self.bos_token="<bos>" 
        self.sp1_token="<sp1>" 
        self.sp2_token="<sp2>" 
        self.gpu="0" 
        self.max_len=1024 
        self.max_turns=4
        self.top_p=0.8 
        self.ckpt_dir="saved_models" 
        self.ckpt_name="best_ckpt_epoch=3_valid_loss=2.5211" 
        self.end_command="Abort!"



#원래 shell로 들어가는 파라미터를 정의합니다.

class Manager():
    def __init__(self, args, test_arr):
        self.args = args
        self.test_arr = test_arr

        if torch.cuda.is_available():
            self.args.device = torch.device(f"cuda:{self.args.gpu}")
        else:
            self.args.device = torch.device("cpu")
        
        # Tokenizer & Vocab
        print("Loading the tokenizer...")
        self.tokenizer = GPT2Tokenizer.from_pretrained(self.args.model_type)
        special_tokens = {
            'bos_token': self.args.bos_token,
            'additional_special_tokens': [self.args.sp1_token, self.args.sp2_token]
        }
        self.args.eos_token = self.tokenizer.eos_token
        num_new_tokens = self.tokenizer.add_special_tokens(special_tokens)
        vocab = self.tokenizer.get_vocab()
        self.args.vocab_size = len(vocab)
        self.args.bos_id = vocab[self.args.bos_token]
        self.args.eos_id = vocab[self.args.eos_token]
        self.args.sp1_id = vocab[self.args.sp1_token]
        self.args.sp2_id = vocab[self.args.sp2_token]
        
        # Load model    
        print("Loading the model...")
        self.fix_seed(self.args.seed)
        self.model = GPT2LMHeadModel.from_pretrained(self.args.model_type).to(self.args.device)
        self.model.resize_token_embeddings(self.args.vocab_size)
        
        self.args.max_len = min(self.args.max_len, self.model.config.n_ctx)
            
        
        
        if self.args.ckpt_name is not None:
            ckpt_path = f"{self.args.ckpt_dir}/{self.args.ckpt_name}.ckpt"
            if os.path.exists(ckpt_path):
                print("Loading the trained checkpoint...")
                ckpt = torch.load(ckpt_path, map_location=self.args.device)
                self.model.load_state_dict(ckpt['model_state_dict'])
                
                if self.args.mode == 'train':
                    print(f"The training restarts with the specified checkpoint: {self.args.ckpt_name}.ckpt.")
                    self.optim.load_state_dict(ckpt['optim_state_dict'])
                    self.sched.load_state_dict(ckpt['sched_state_dict'])
                    self.best_loss = ckpt['loss']
                    self.last_epoch = ckpt['epoch']
                else:
                    print("The inference will start with the specified checkpoint.")
            else:
                print(f"Cannot fine the specified checkpoint {ckpt_path}.")
                if self.args.mode == 'train':
                    print("Training will start with the initialized model.")
                else:
                    print("Cannot inference.")
                    exit()
              
        print("Setting finished.")
        
    def nucleus_sampling(self, input_ids, token_type_ids, input_len):
        output_ids = []
        for pos in range(input_len, self.args.max_len):
            output = self.model(input_ids=input_ids, token_type_ids=token_type_ids)[0][:, pos-1]  # (1, V)
            output = F.softmax(output, dim=-1)  # (1, V)
            
            sorted_probs, sorted_idxs = torch.sort(output, descending=True)
            cumsum_probs = torch.cumsum(sorted_probs, dim=-1)  # (1, V)
            idx_remove = cumsum_probs > self.args.top_p
            idx_remove[:, 1:] = idx_remove[:, :-1].clone()
            idx_remove[:, 0] = False
            sorted_probs[idx_remove] = 0.0
            sorted_probs /= torch.sum(sorted_probs, dim=-1, keepdim=True)  # (1, V)
            
            probs = torch.zeros(output.shape, device=self.args.device).scatter_(-1, sorted_idxs, sorted_probs)  # (1, V)
            idx = torch.multinomial(probs, 1)  # (1, 1)
            
            idx_item = idx.squeeze(-1).squeeze(-1).item()
            output_ids.append(idx_item)
            
            if idx_item == self.args.eos_id:
                break
                
            input_ids = torch.cat((input_ids, idx), dim=-1)
            next_type_id = torch.LongTensor([[self.args.sp2_id]]).to(self.args.device)
            token_type_ids = torch.cat((token_type_ids, next_type_id), dim=-1)
            assert input_ids.shape == token_type_ids.shape
            
        return output_ids

    def fix_seed(self, seed):
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        random.seed(seed)  

    def test(self):
        test_arr = self.test_arr
        args = self.args
        print("Let's start!")
        self.model.eval()
        self.fix_seed(self.args.seed)

        input_hists = []
        user = []
        ground_truth = []
        ex_cnt = -1
        similarity = 0

        context_for_json = []
        ans_for_json = []

        for utters in test_arr:
            ex_cnt += 1
            assert (len(utters) > 1 , "there's a short dialogue") 

            hists = []

            for u, utter in enumerate(utters):
                if u % 2 == 0:
                    hists.append([args.sp1_id] + self.tokenizer.encode(utter))
                else:
                    hists.append([args.sp2_id] + self.tokenizer.encode(utter))

            if len(hists) >= self.args.max_turns:
                    num_exceeded = len(hists) - self.args.max_turns + 1
                    hists = hists[num_exceeded:]

            if hists[-1][0] == args.sp2_id:
                hists = hists[:-1]
                ground_truth.append(utters[-1])
            elif hists[-1][0] == args.sp1_id:
                hists = hists[:-2]
                ground_truth.append(utters[-2])
            else:
                assert (False, "there's an error, sp1,2 token failed")
            


            with torch.no_grad():
                input_ids = [self.args.bos_id] + list(chain.from_iterable(hists)) + [self.args.sp2_id]
                start_sp_id = hists[0][0]
                next_sp_id = self.args.sp1_id if start_sp_id == self.args.sp2_id else self.args.sp2_id
                assert start_sp_id != next_sp_id
                token_type_ids = [[start_sp_id] * len(hist) if h % 2 == 0 else [next_sp_id] * len(hist) for h, hist in enumerate(hists)]
                assert len(token_type_ids) == len(hists)
                token_type_ids = [start_sp_id] + list(chain.from_iterable(token_type_ids)) + [self.args.sp2_id]
                assert len(input_ids) == len(token_type_ids)
                input_len = len(input_ids)


                input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(self.args.device)
                token_type_ids = torch.LongTensor(token_type_ids).unsqueeze(0).to(self.args.device)

                output_ids = self.nucleus_sampling(input_ids, token_type_ids, input_len)   
                res = self.tokenizer.decode(output_ids, skip_special_tokens=True)

                similarity += compute_similarity(res, ground_truth[ex_cnt])
                temp = []
               
                for u in hists[:-1]:
                    temp+=u

                print(f"multiturn : {self.tokenizer.decode(temp, skip_special_tokens=False)}")
                print(f"user : {self.tokenizer.decode(hists[-1], skip_special_tokens=False)}")
                print(f"res : {res}\n gt : {ground_truth[ex_cnt]}")

                context_for_json.append(self.tokenizer.decode(temp, skip_special_tokens=True) + self.tokenizer.decode(hists[-1], skip_special_tokens=True))
                ans_for_json.append(res)

                assert len(context_for_json) == len(ans_for_json)

                if len(context_for_json) == 300:
                    # 데이터 딕셔너리 생성
                    data = {"context": context_for_json, "ans": ans_for_json}

                    # JSON 파일로 저장
                    with open("test/multiturn_data.json", "w") as f:
                        json.dump(data, f)
                    

                # 예측한 문장과 ground truth를 비교할 수 있습니다.
                # 아직 눈으로 밖에 비교할 방법이 없음.

                #print(f"Bot: {res}")
                #input_hists.append([self.args.sp2_id] + self.tokenizer.encode(res))
        print(f"문자열 간에 유사도 : {similarity / len(test_arr)}")


def compute_similarity(string1, string2):

    matcher = difflib.SequenceMatcher(None, string1, string2)
    return matcher.ratio()

# 두 문자열을 비교하는 함수이나...
# 큰 의미는 없는것 같다.

      

In [None]:
args = Arguments()
args.ckpt_dir = f"{args.ckpt_dir}/{args.model_type}"
assert args.ckpt_name is not None, "Please specify the trained model checkpoint."
manager = Manager(args, test_arr)
manager.test()

Loading the tokenizer...
Loading the model...
Loading the trained checkpoint...
The inference will start with the specified checkpoint.
Setting finished.
Let's start!
multiturn : <sp2>  Do you really have all of these drugs? Where do you get them from?  <sp1>  I got my connections! Just tell me what you want and I ’ ll even give you one ounce for free.  <sp2>  Sounds good! Let ’ s see, I want. 
user : <sp1>  Yeah? 
res :  Would you please put your name on the package please?
 gt :  I want you to put your hands behind your head ! You are under arrest ! 
multiturn : <sp1> The taxi drivers are on strike again.  <sp2>  What for? 
user : <sp1>  They want the government to reduce the price of the gasoline. 
res :  What is the driving fee?
 gt :  It is really a hot potato . 
multiturn : <sp1>  Mainly because we've invested in a heat recovery system.  <sp2>  What does that mean exactly? 
user : <sp1>  Well, we use the exhaust gases from our printing presses to provide energy to heat our dryers

KeyboardInterrupt: ignored