In [71]:
!pip install torch



In [72]:
! pip install --quiet "torchmetrics>=0.7, <0.12" "seaborn" "ipython[notebook]>=8.0.0, <8.9.0" "pytorch-lightning>=1.4, <1.9" "torchmetrics >=0.11.0" "setuptools==65.6.3" "pandas" "torchvision" "torch>=1.8.1, <1.14.0"

In [18]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer, BertTokenizerFast, AutoModel, GPT2Tokenizer, BertTokenizer, BertModel, BertLMHeadModel
import json
from torch import nn
import torch
import copy
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import os
import pandas as pd
import seaborn as sn
import random
from typing import Iterable
import numpy as np
from tqdm import tqdm
import csv

In [5]:
pwd = os.getcwd()
pwd

'/Users/jaysun/Desktop/Python_Projects/GPTRLFinetune'

In [7]:
gpt2 = GPT2LMHeadModel.from_pretrained(os.path.join(pwd, 'GPT-2/GPT2_finetune_1'))
bert = BertModel.from_pretrained('bert-base-chinese')
bert_lm = BertLMHeadModel.from_pretrained('bert-base-chinese')
# tokenizer = BertTokenizer(vocab_file = './vocab_small.txt')
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`
Some weights of the model checkpoint at bert-base-chinese were not used when i

In [17]:
bert_lm(input_ids = torch.tensor([[1,2,3,54], [3, 5, 5, 3]]), attention_mask = torch.tensor([[1, 1, 0, 1], [1, 1, 0, 1]]))['logits'].shape

torch.Size([2, 4, 21128])

In [45]:
class Reward(nn.Module):
    def __init__(self, gpt : GPT2LMHeadModel, question_mark_token, toxic_words : list, non_sense_response : list, eos_token = 0, device = "cpu", gpt_tokenizer = None, bos_token = 101) -> None:
        super(Reward, self).__init__()
        self.reward_coefficient = torch.ones(6, device = device) / 6
        self.gpt = copy.deepcopy(gpt)
        self.gpt = self.gpt.to(device)
        self.gpt_tokenizer = BertTokenizer(vocab_file = os.path.join(pwd, 'GPT-2/vocab_small.txt')) if gpt_tokenizer is None else gpt_tokenizer
        for p in self.gpt.parameters():
            p.requires_grad = False
        self.device = device
        self.eos_token = eos_token
        self.bos_token = bos_token
        self.question_mark_token = question_mark_token
        self.bert_tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')
        self.bert = AutoModel.from_pretrained('ckiplab/bert-base-chinese').to(device)
        for p in self.bert.parameters():
            p.requires_grad = False
        self.toxic_words = toxic_words
        self.non_sense_response = non_sense_response
    def update_model(self, gpt):
        self.gpt = copy.deepcopy(gpt)
        for p in self.gpt.parameters():
                p.requires_grad = False
        self.gpt.eval()
    def to_device(self, device):
        self.device = device
    def update_reward_coefficient(self, gamma):
        self.reward_coefficient =  copy.deepcopy(gamma)
    def forward(self, state : list):
        # state = [batch_size, {prev_utterance, response}] # elements of second dimension are dictionaries composed of previous utterance and subsequent response
        # prev_utterance = [prev_utterance_seq_length, ]
        # response = [response_seq_length, ]
        self.gpt.eval()
        for idx in range(len(state)):
            state[idx]['prev_utterance'] = state[idx]['prev_utterance'].to(self.device)
            state[idx]['response'] = state[idx]['response'].to(self.device)
        with torch.no_grad():
            reward = self.reward_coefficient[0] * (self.get_length_reward(state).exp()) # exp for avoiding +- offset each other, if you don't want it, go head delete .exp()
            reward += self.reward_coefficient[1] * (self.get_question_reward(state).exp())
            reward += self.reward_coefficient[2] * (self.get_coherence(state).exp())
            reward += self.reward_coefficient[3] * (self.get_toxicity(state))
            reward += self.reward_coefficient[4] * (self.get_ease_of_answering(state).exp())
            reward += self.reward_coefficient[5] * (self.get_reward_semantic_coherence(state).exp())

        return F.normalize(reward, dim = 0)
    def get_response_prob(self, state, require_grad = False, reverse = False): # reverse means prev_utterance = response, response = prev_utterance
        state = copy.deepcopy(state)
        if reverse:
            state = [{"prev_utterance" : state[idx]['response'], 'response' : state[idx]['prev_utterance']} for idx in range(len(state))]
        def get_prob():
            probability = torch.ones((len(state)), device = self.device)
            for index, state_dict in enumerate(state):
                utterance, response = state_dict['prev_utterance'].clone().detach(), state_dict['response'].clone().detach()
                probability[index] *= self.p_seq2seq(utterance, response)
            return F.normalize(probability, dim = 0)
        self.gpt.train(require_grad)
        if not require_grad:
            with torch.no_grad():
                return get_prob()
        else:
            return get_prob()
    def p_seq2seq(self, up, down):
        input_ids = up.clone().detach()
        probability = 1e20 # refrain from continuously multiplying number 0 < number < 1 such that probability would be too small, which in turn would cause precision problem
        for i in range(len(down)):
            logits = self.gpt(input_ids = input_ids)['logits'] 
            logits = F.softmax(logits, dim = -1)
            probability *= logits[-1, down[i]]
            input_ids = torch.cat((input_ids, torch.tensor([down[i]], device = self.device)), dim = -1)
            if down[i] == self.eos_token:
                    break
        return probability
    def get_length_reward(self, state):
        return F.normalize(torch.tensor([len(state[idx]['response']) for idx in range(len(state))], device = self.device, dtype = torch.float), dim = 0)
    def get_question_reward(self, state):
        return F.normalize(torch.tensor([1 if self.question_mark_token in state[idx]['response'] else 0 for idx in range(len(state))], device = self.device, dtype = torch.float), dim = 0)
    def transform_from_gpt_to_bert_tokens(self, input_id):
        original_string = self.gpt_tokenizer.decode(input_id, skip_special_tokens = True).replace(" ", "")
        return self.bert_tokenizer.encode(original_string, return_tensors = 'pt')[0].to(self.device)
    def get_coherence(self, state):
        state = [{"prev_utterance" : self.transform_from_gpt_to_bert_tokens(state[idx]['prev_utterance']), 'response' : self.transform_from_gpt_to_bert_tokens(state[idx]['response'])} for idx in range(len(state))]
        coherence = torch.zeros((len(state)), device = self.device, dtype = torch.float)
        cos = nn.CosineSimilarity(dim = -1)
        for idx in range(len(state)):
            utterance = self.bert(input_ids = state[idx]['prev_utterance'].unsqueeze(0))['last_hidden_state'][0][0]
            response = self.bert(input_ids = torch.cat((torch.tensor([self.bos_token], device = self.device), state[idx]['response'])).unsqueeze(0))['last_hidden_state'][0][0]
            sim = cos(utterance, response)
            coherence[idx] = sim
        return F.normalize(coherence, dim = 0)
    def get_toxicity(self, state):
        toxicity = []
        for idx, value in enumerate(state):
            counter = 0
            for word in self.toxic_words:
                if self.x_in_y(word, value['reponse']):
                    counter -= 1
            toxicity.append(counter)
        return torch.tensor(toxicity, device = self.device)
    def get_ease_of_answering(self, state):
        ease_of_answering = torch.zeros((len(state)), device = self.device, dtype = torch.float)
        for idx in range(len(state)):
            temp = 0
            for sentence in self.non_sense_response:
                temp += self.p_seq2seq(state[idx]['response'], sentence) / len(sentence)
            temp *= (- 1 / len(self.non_sense_response))
            ease_of_answering[idx] = temp
        return F.normalize(ease_of_answering, dim = 0)
    def get_reward_semantic_coherence(self, state):
        forward = self.get_response_prob(state)
        backward = self.get_response_prob(state, reverse = True)
        return F.normalize(torch.tensor([forward[idx] / len(state[idx]['response']) + backward[idx] / len(state[idx]['prev_utterance']) for idx in range(len(state))], device = self.device, dtype = torch.float), dim = 0)
    def x_in_y(self, query, base):
        try:
            l = len(query)
        except TypeError:
            l = 1
            query = type(base)((query,))

        for i in range(len(base)):
            if base[i : i+l] == query:
                return True
        return False
    @staticmethod
    def sentence2id(sentence, tokenizer, emotion_dict = {'其它': 0, '喜歡': 1, '悲傷': 2, '噁心': 3, '憤怒': 4, '開心': 5}, max_len = None, pad_to_max_len = False):

        utterance = sentence[: sentence.index(']') + 1]
        response = sentence[sentence.index(']') + 1:]
        prev_utterance = tokenizer.encode(utterance, return_tensors = 'pt')[0]
        response = tokenizer.encode(response, return_tensors = 'pt')[0][1:]
        if max_len is not None and pad_to_max_len:
            padding = max_len - prev_utterance.shape[-1]
            if padding > 0:
                prev_utterance = torch.cat((prev_utterance, torch.zeros(padding, dtype = torch.int64) - 1), dim = -1)
            padding = max_len - response.shape[-1]
            if padding > 0:
                response = torch.cat((response, torch.zeros(padding, dtype = torch.int64) - 1), dim = -1)
            
        return {'prev_utterance' : prev_utterance, 'response' : response}

In [29]:
class GPT2DataSet(Dataset):
    def __init__(self, tokenizer = None, max_len = 256, root_path = os.path.join(pwd, 'dataset'), train_path = 'single_emo_T_train.json', val_path = 'single_emo_T_valid.json', test_path = 'single_emo_T_test.json', status = 'train',
                 shuffle = True) -> None:
        self.file_path = os.path.join(root_path, (train_path if 'train' in status else (val_path if 'val' in status else test_path)))
        with open(self.file_path) as f:
            self.data = json.load(f)
        self.dataset = []
        self.tokenizer = BertTokenizer(vocab_file = os.path.join(pwd, 'GPT-2/vocab_small.txt')) if tokenizer is None else tokenizer
        if  status == 'test':
            self.data = [i[0] + i[1] for i in self.data]
        self.max_len = max_len
        for line in self.data:
            if '[' not in line or ']' not in line or line.count(']') != 1 or line.count('[') != 1:
                continue
            else:
                self.max_len = max(self.max_len, len(line))
                self.dataset.append(line)
        if shuffle:
            random.shuffle(self.dataset)
    def __len__(self) -> int:
        return len(self.dataset)
    def __getitem__(self, index) -> dict:
        return Reward.sentence2id(self.dataset[index], self.tokenizer, max_len = self.max_len)
    @staticmethod
    def get_toxic_ids_and_non_sense_response(tokenizer, 
                                         dirty_words = ['幹!','賤貨','米蟲','王八','王八蛋','不要臉','吃屎','敗類','智障','白癡','賤人','下流',
                                                        '死肥豬','人渣','神經病','賤','尼瑪','無恥','婊','娘炮','魯蛇','廢物', '腦殘'],   
                                         non_sense_sentences = ['嗯','嗯嗯','隨便啦','隨便啊','都可以','呵呵','哈哈','喔','笑死','是喔','好吧','我不知道',
                                                                '還好','是啊','對啊','我也是','嘿嘿']):
        # create toxic word list
        toxic_ids = []
        for i in range(len(dirty_words)):
            ids = tokenizer.encode(dirty_words[i])
            toxic_ids.append(ids[1:-1])
        # create non sense sentence list
        non_sense_ids = []
        for i in range(len(non_sense_sentences)):
            ids = tokenizer.encode(non_sense_sentences[i])
            non_sense_ids.append(ids[1:])
        return toxic_ids, non_sense_ids

In [80]:
g = GPT2DataSet()

In [81]:
a = DataLoader(g, batch_size = 16)

In [82]:
for i in a:
    print(a)
    break

RuntimeError: stack expects each tensor to be equal size, but got [25] at entry 0 and [42] at entry 1

In [None]:
tokenizer.encode("我愛你")

[101, 2769, 2695, 872, 102]

In [None]:
value, indices = F.softmax(gpt2(input_ids = torch.tensor([101, 2769, 2695, 872, 102]))['logits'], dim = -1).topk(3, dim = -1)

In [None]:
p

(tensor([[9.9955e-01, 2.8439e-05, 1.7044e-05],
         [1.3000e-01, 5.7022e-02, 4.0555e-02],
         [3.6167e-01, 6.0442e-02, 5.2951e-02],
         [1.8121e-01, 1.5723e-01, 9.7066e-02],
         [3.8228e-01, 2.7874e-01, 8.7628e-02]], grad_fn=<TopkBackward0>),
 tensor([[ 103,  138, 8024],
         [ 738,  947, 3221],
         [ 872, 4638, 2769],
         [ 138, 8024,  947],
         [ 134, 5965, 5083]]))

In [None]:
indices.shape

torch.Size([5, 3])

In [None]:
results = [[value[-1, i], indices[-1, i]] for i in range(3)]

In [None]:
results

[[tensor(0.3823, grad_fn=<SelectBackward0>), tensor(134)],
 [tensor(0.2787, grad_fn=<SelectBackward0>), tensor(5965)],
 [tensor(0.0876, grad_fn=<SelectBackward0>), tensor(5083)]]

In [None]:
gpt2(input_ids = torch.tensor([101,  791, 1921, 1921, 3706, 4696, 1962,  138, 1599, 3631,  140,  102,
         7433, 1921,  738, 3221, 2523, 5401, 4638,  101, 103, 2769, 738, 2682, 1391, 4125]))['logits'].argmax( dim = -1)

tensor([ 103, 1921, 4638, 3706,  679, 1962,  138, 1599, 3631,  140, 3221, 7433,
        1921,  738, 3221, 2523, 5401, 4638,  101,  103, 2769,  738, 2682, 1391,
        4125, 7102])

In [220]:
gpt2

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(13317, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dro

In [None]:
srf = gpt2.generate(torch.tensor([[101,
 791,
 1921,
 2769,
 1343,
 1912,
 7481,
 1062,
 1754,
 4381,
 138,
 2734,
 2584,
 140,
 102]]), max_length = 100, num_beams = 3, early_stopping=True)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:102 for open-end generation.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


In [None]:
srf.shape

torch.Size([1, 100])

In [None]:
tokenizer.decode(srf[0], skip_special_tokens = False)

'[CLS] 今 天 我 去 外 面 公 園 玩 [ 憤 怒 ] [SEP] 已 經 變 色 了 [CLS] [MASK] 你 是 最 棒 的 ， 摔 這 麼 美 的 ， 你 是 第 一 個 ！ [ 喜 歡 ] 你 是 最 棒 的 ， 摔 這 麼 美 的 ， 你 是 第 一 個 ！ [CLS] [MASK] 你 是 最 棒 的 ， 摔 這 麼 美 的 ， 你 是 第 一 個 ！ [ 喜 歡 ] 哈 哈 哈 哈 哈 哈 哈 哈 哈 哈 哈 哈 哈 哈'

In [None]:
tokenizer.all_special_ids

[100, 102, 0, 101, 103]

In [125]:
[1, 2,4,5,6][-1:]

[6]

In [None]:
tokenizer.all_special_tokens

['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]']

In [None]:
gpt2(input_ids = torch.tensor([[2,3,4]]))['logits'].shape

torch.Size([1, 3, 21128])

In [30]:
class BeamHypotheses(object):
    def __init__(self, num_beams, max_length = 200, length_penalty = 0.7):
        self.max_length = max_length
        self.length_penalty = length_penalty
        self.num_beams = num_beams # beam size
        self.beams = [] # best sequences and corresponding scores
        self.worst_score = None

    def __len__(self):
        return len(self.beams)

    def add(self, hyp, mask, sum_logprobs, cur_len, utter_length, prev_utterance, response, eos_token):
        score = sum_logprobs / (cur_len  ** self.length_penalty) # calculate penalized score
        if len(self) < self.num_beams or score > self.worst_score:
            if response[-1] != eos_token:
                response = torch.cat((response, torch.tensor([eos_token])))
            self.beams.append((score, hyp, mask, utter_length, sum_logprobs, prev_utterance, response))
            if len(self) > self.num_beams:
                sorted_scores = sorted([(s, idx) for idx, (s, _0, _1, _2, _3, _4, _5) in enumerate(self.beams)])
                del self.beams[sorted_scores[0][1]]
                self.worst_score = sorted_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score) if self.worst_score is not None else score

    def is_done(self, best_sum_logprobs, cur_len):
        if len(self) < self.num_beams:
            return False
        else:
            cur_score = best_sum_logprobs / ((cur_len) ** self.length_penalty)
            ret = self.worst_score >= cur_score
            return ret

In [36]:
class Q(nn.Module):
    def __init__(self, bert = None, bert_name = 'bert-base-chinese', bert_tokenizer = None, gpt_tokenizer = None, down_stream_features = 1024, only_down_stream = True, gamma = 0.9) -> None:
        super(Q, self).__init__()
        self.gamma = gamma
        self.bert_tokenizer = BertTokenizerFast.from_pretrained(bert_name) if bert_tokenizer is None else bert_tokenizer
        self.bert = AutoModel.from_pretrained(bert_name) if bert is None else copy.deepcopy(bert)
        self.gpt_tokenizer = BertTokenizer(vocab_file = os.path.join(pwd, 'GPT-2/vocab_small.txt')) if gpt_tokenizer is None else gpt_tokenizer
        self.only_down_stream = only_down_stream
        if only_down_stream:
           self.bert.eval()
           for p in self.bert.parameters():
               p.require_grad = False
        self.down_stream = nn.Sequential( nn.Linear(in_features = self.bert.pooler.dense.out_features, out_features = down_stream_features // 2),
                                          nn.BatchNorm1d(down_stream_features // 2),
                                          nn.ReLU(),
                                          nn.Linear(in_features = down_stream_features // 2, out_features = down_stream_features),
                                          nn.ReLU(),
                                          nn.Linear(in_features = down_stream_features, out_features = down_stream_features),
                                          nn.ReLU(),
                                          nn.Dropout(0.1),
                                          nn.Linear(in_features = down_stream_features, out_features = down_stream_features // 4),
                                          nn.BatchNorm1d(down_stream_features // 4),
                                          nn.ReLU(),
                                          nn.Linear(in_features = down_stream_features // 4, out_features = 1))
    def transform_from_gpt_to_bert_tokens(self, input_id):
        '''
            input_ids : [batch_size, sequences]
        '''
        original_string = [self.gpt_tokenizer.decode(input, skip_special_tokens = True).replace(" ", "") for input in input_id]
        return [self.bert_tokenizer.encode(string, return_tensors = 'pt')[0] for string in original_string]
    def get_processed(self, prev_utterance, bert_tokens = False, max_len = 256):
            if not bert_tokens:
                prev_utterance = self.transform_from_gpt_to_bert_tokens(prev_utterance)
            input_ids = torch.zeros((len(prev_utterance), max_len)) - 1
            for idx, utter in enumerate(prev_utterance):       
                input_ids[idx][:min(max_len, len(utter))] = utter[:min(max_len, len(utter))]
            mask = input_ids.ge(0)
            input_ids[~mask] = 0
            mask = mask.float()
            return input_ids, mask
    def forward(self, prev_utterance, response = None, mask = None, bert_tokens = False, max_len = 256, processed = False):
        '''
            prev_utterance : [batch_size, sequences]
            response : [batch_size, sequences]
            mask : [batch_size, sequences]
        '''
        prev_utterance = prev_utterance if response is None else [torch.cat((utt, res), dim = -1) for utt, res in zip(prev_utterance, response)]

        # creating mask
        if not processed or mask is None or not bert_tokens:
            prev_utterance, mask = self.get_processed(prev_utterance = prev_utterance, bert_tokens = bert_tokens, max_len = max_len)
        if self.only_down_stream:
            with torch.no_grad():
                up_stream = self.bert(input_ids = prev_utterance.to(torch.long), attention_mask = mask.to(torch.long))['last_hidden_state'][:, 0, :].view(len(prev_utterance), -1)
        else:
            up_stream = self.bert(input_ids = prev_utterance.to(torch.long), attention_mask = mask.to(torch.long))['last_hidden_state'][:, 0, :].view(len(prev_utterance), -1)
        return self.down_stream(up_stream).view(len(prev_utterance))

In [32]:
class GPT2Wrapper(nn.Module):
    def __init__(self, gpt = None, tokenizer = None, device = 'cpu'):
        super(GPT2Wrapper, self).__init__()
        self.gpt = gpt if gpt is not None else GPT2LMHeadModel.from_pretrained(os.path.join(pwd, './GPT-2/GPT2_finetune_1'))
        self.tokenizer = BertTokenizer(vocab_file = os.path.join(pwd, 'GPT-2/vocab_small.txt')) if tokenizer is None else tokenizer
        self.vocab_size = self.tokenizer.vocab_size
        self.special_tokens = {}
        for idx, key in enumerate(self.tokenizer.all_special_tokens):
            self.special_tokens[key] = self.tokenizer.all_special_ids[idx]
        self.device = device
    def forward(self, prev_utterance, response = None, beam = 3, max_len = 100, require_grad = True, device = 'mps'):
        self.device = device
        self.gpt.to(device)
        if not require_grad:
            with torch.no_grad():
                    rslt, msk, sco, smlgprbs, utrlen, rsltprvrnce, rsltrspse = self.beam_search(prev_utterance, response = response, beam = beam, max_len = max_len)
        else:
                rslt, msk, sco, smlgprbs, utrlen, rsltprvrnce, rsltrspse = self.beam_search(prev_utterance, response = response, beam = beam, max_len = max_len)
        return rslt, msk, sco, smlgprbs, utrlen, rsltprvrnce, rsltrspse
    def get_prob(self, result, mask, prev_utterance, response):
        utter_len = [len(utt) for utt in prev_utterance]
        logits = torch.log(F.softmax(self.gpt(input_ids = result, attention_mask = mask)['logits'], dim = -1))
        probs = torch.zeros(len(result))
        for idx, res in enumerate(response):
            res_len = res
            for idx_, digit in enumerate(res):
                probs[idx] += logits[idx][idx_ + utter_len[idx]][digit]
            probs[idx] /= res_len
        return probs
    def beam_search(self, prev_utterance, response = None, eos_word = '[SEP]', beam = 3, max_len = 200, length_penalty_for_hypothesis = 0.8, emotion = ['喜歡', '悲傷', '噁心', '憤怒', '開心', '其它']):
        '''
            prev_utterance = [seq_length_prev_utterance]
            response = [seq_length_response]
            if response is None, mask and utter_length must not be None, which means they have been processed, i.e. use the output of this function to generate next response
        '''
        eos_token = self.special_tokens[eos_word]
        # prev_utterance = prev_utterance if response is None else [torch.cat((utt, res)) for utt, res in zip(prev_utterance, response)]
        batch_size = len(prev_utterance)
        utter_length = torch.zeros((batch_size), dtype = torch.long)
        ''' create mask '''
        input_ids = torch.zeros((batch_size, max_len * 2), dtype = torch.long) - 1 # * 2 is for latter we will concatenate the generated sequences to them
        with torch.no_grad():
            for idx, utt in enumerate(prev_utterance):
                    if response is not None: # means prev_utterance has not been processed
                        if 138 not in response[idx][-5: ] and 140 not in response[idx][-5: ]:
                            response[idx] = torch.cat((response[idx][:-1] if response[idx][-1] == eos_token else response[idx], self.tokenizer.encode('[' + emotion[random.randint(0, len(emotion) - 1)] + ']', return_tensors = 'pt')[0][1:]))
                        elif response[idx][-1] != eos_token:
                            response[idx] = torch.cat((response[idx], torch.tensor([eos_token])))    
                        if 138 not in utt[-5: ] and 140 not in utt[-5: ]: # if emotion is not included in the sentence, randomly select a emotion type
                            prev_utterance[idx] = torch.cat((utt[:-1] if utt[-1] == eos_token else utt, self.tokenizer.encode('[' + emotion[random.randint(0, len(emotion) - 1)] + ']', return_tensors = 'pt')[0][1:]))
                        elif utt[-1] != eos_token:
                            prev_utterance[idx] = torch.cat((prev_utterance[idx], torch.tensor([eos_token])))   
                        prev_utterance[idx] = torch.cat((prev_utterance[idx], response[idx]))
                        utter_length[idx] = min(len(prev_utterance[idx]), max_len)            
                    else:
                        length = len(utt[utt > 0])
                        if 138 not in utt[length - 5:] and 140 not in utt[length - 5:]:
                            start = length - 1 if utt[length - 1] == eos_token else length
                            prev_utterance[idx][start: start + 5] = self.tokenizer.encode('[' + emotion[random.randint(0, len(emotion) - 1)] + ']', return_tensors = 'pt')[0][1:]
                        elif utt[length - 1] != eos_token:
                            prev_utterance[idx][length] = eos_token
                        utter_length[idx] = len(prev_utterance[idx][prev_utterance[idx] > 0])
                    input_ids[idx][: utter_length[idx]] = prev_utterance[idx][: utter_length[idx]]

        mask = input_ids.ge(0)
        input_ids[~mask] = 0
        mask = mask.float()
        prev_utterance = input_ids
        
        beams_score = torch.zeros((batch_size, beam))
        beams_score[:, 1:] = -1e9
        beams_score = beams_score.view(-1) # [batch_size * beam]
        beams_score = beams_score.to(self.device)
    
        utter_length = utter_length.unsqueeze(-1).expand((-1, beam)).contiguous().view(-1).to(self.device)
        original_utter_length = utter_length.detach().clone()
        
        input_ids = prev_utterance.unsqueeze(1).expand((-1, beam, -1))
        input_ids = input_ids.contiguous().view((batch_size * beam, max_len * 2)) # [batch_size * beam, max_len]
        input_ids = input_ids.to(self.device)
        mask = mask.unsqueeze(1).expand((-1, beam, -1))
        mask = mask.contiguous().view(-1, max_len * 2) # [batch_size * beam, max_len]
        mask = mask.to(self.device)
        '''
            prev_utterance = [batch_size, max_len * 2]
            input_ids = [batch_size * beam, max_len * 2]
            mask = [batch_size * beam, max_len * 2] 
        '''

        done = [False for _ in range(batch_size)]
        hyps = [BeamHypotheses(num_beams = beam, max_length = max_len, length_penalty = length_penalty_for_hypothesis) for _ in range(self.vocab_size)]
        for cur_len in range(max_len):
            # print(input_ids[3], mask[3])
            # global a, b
            # if cur_len == max_len + 1:
            #     a = input_ids
            #     b = mask
            #     return
            out = torch.log(F.softmax(self.gpt(input_ids = input_ids, attention_mask = mask)['logits'].gather(index = (utter_length - 1).unsqueeze(-1).unsqueeze(-1).expand((-1, -1, self.vocab_size)), dim = 1), dim = -1))
            # print(out.argmax(-1))
            out = out.contiguous().view((batch_size, -1)) # [batch_size, beam * vocab_suze]
            beams_score_next = beams_score.unsqueeze(-1).expand((-1, self.vocab_size)).contiguous().view(batch_size, -1) + out
            next_scores, next_tokens = beams_score_next.topk(beam * 2, dim = -1)
            # print(next_tokens)
            next_beams = [] 
            for batch in range(batch_size):
                next_beams_batch = []
                batch_is_done = True
                for score, token in zip(next_scores[batch], next_tokens[batch]):
                    if len(next_beams_batch) >= beam:
                        break
                    beam_index = token // self.vocab_size
                    real_token = token % self.vocab_size
                    beam_index_for_input_ids = batch * beam + beam_index
                    # print(score / ((cur_len) ** length_penalty_for_hypothesis), hyps[batch].worst_score, beam_index_for_input_ids)
                    if (real_token == eos_token or cur_len == max_len - 1 or hyps[batch].worst_score is None or \
                    score / ((cur_len) ** length_penalty_for_hypothesis) > hyps[batch].worst_score or len(hyps[batch].beams) < beam) and cur_len >= 5:
                        
                        mask_beam = mask[beam_index_for_input_ids].detach().clone()
                        input_id_beam = input_ids[beam_index_for_input_ids].detach().clone()
                        mask_beam[utter_length[beam_index_for_input_ids]] = 1
                        input_id_beam[utter_length[beam_index_for_input_ids]] = real_token

                        hyps[batch].add(hyp = input_id_beam, mask = mask_beam, sum_logprobs = score, 
                                        cur_len = cur_len, utter_length = utter_length[beam_index_for_input_ids] + 1, 
                                        prev_utterance = input_id_beam[:original_utter_length[beam_index_for_input_ids]].detach().clone(),
                                        response = input_id_beam[original_utter_length[beam_index_for_input_ids]: utter_length[beam_index_for_input_ids] + 1].detach().clone(),
                                        eos_token = eos_token)
                        if real_token != eos_token:
                            next_beams_batch.append((score, real_token, beam_index_for_input_ids))
                    else:
                        next_beams_batch.append((score, real_token, beam_index_for_input_ids))
                if batch_is_done:
                    batch_is_done = hyps[batch].is_done(score, cur_len)
                next_beams.extend(next_beams_batch)
                if cur_len >= 20:
                    done[batch] = batch_is_done if not done[batch] else True

            beams_score = beams_score.new([x[0] for x in next_beams])
            beam_token = input_ids.new([x[1] for x in next_beams])
            beam_idx = input_ids.new([x[2] for x in next_beams])
            utter_length = utter_length[beam_idx] + 1
            original_utter_length = original_utter_length[beam_idx]
            input_ids = input_ids[beam_idx, :]
            mask = mask[beam_idx, :]
            with torch.no_grad():
                input_ids = input_ids.scatter(dim = 1, index = utter_length.unsqueeze(-1) - 1, src = beam_token.unsqueeze(-1).expand((-1, 2 * max_len)))
                mask = mask.scatter(dim = 1, index = utter_length.unsqueeze(-1) - 1, src = torch.zeros((batch_size * beam, 2 * max_len)) + 1)
            if all(done):
                break
        results = []
        scores = []
        masks = []
        utter_len = []
        sum_logprobs = []
        result_prev_utterance = []
        result_response = []
        for batch in range(batch_size):
            # (score, hyp, mask, utter_length, sum_logprobs)
            results.append([])
            scores.append([])
            masks.append([])
            utter_len.append([])
            sum_logprobs.append([])
            result_prev_utterance.append([])
            result_response.append([])
            for x in hyps[batch].beams: 
                result_prev_utterance[-1].append(x[5])
                result_response[-1].append(x[6])     
                sum_logprobs[-1].append(x[4])
                utter_len[-1].append(x[3])
                masks[-1].append(x[2])
                results[-1].append(x[1])
                scores[-1].append(x[0])
            results[-1] = torch.stack(results[-1])
            scores[-1] = torch.stack(scores[-1])
            masks[-1] = torch.stack(masks[-1])
            utter_len[-1] = torch.stack(utter_len[-1])
            sum_logprobs[-1] = torch.stack(sum_logprobs[-1])
        results = torch.stack(results)
        masks = torch.stack(masks)
        scores = torch.stack(scores)
        utter_len = torch.stack(utter_len)
        sum_logprobs = torch.stack(sum_logprobs)
        return results, masks, scores, sum_logprobs, utter_len, result_prev_utterance, result_response

In [42]:
[10, 34,5,6,7,8,9][-3:]

[7, 8, 9]

In [38]:
def train_one_epoch(epoch: int, gpt: GPT2Wrapper, Q_A: Q, Q_B: Q,  optimizer: torch.optim.Optimizer = None, R: Reward = None,
                    dataset: Iterable = GPT2DataSet(), device: torch.device = 'cpu', batch_size = 4, max_len = 256, beam = 3, update_time_per_episode = 10, criterion = nn.MSELoss(), kl_control = 0.01,
                    gpt_loss_coefficient = 0.1):
    gpt.train()
    # two Qs for Double DQN
    Q_A.train()
    Q_B.train()
    kl_losses = []
    gpt_losses = []
    q_losses = []
    total_losses = []
    

    with tqdm(total = len(dataset) // batch_size) as t:
        for step in range(len(dataset) // batch_size):
            t.set_description(f"Epoch {epoch}")

            prev_utterance = []
            response = []
            for mini_step in range(step * batch_size, (step + 1) * batch_size):
                pair = dataset[mini_step]
                prev_utterance.append(pair['prev_utterance'])
                response.append(pair['prev_utterance'])
            utter_length = None
            generate_time = 0
            with torch.no_grad(): # generate episode
                while (utter_length is None or all(utter_length < max_len)) and generate_time < 5:
                    generate_time += 1
                    if utter_length is None:
                        rslt, msk, sco, smlgprbs, utrlen, rsltprvrnce, rsltrspse = gpt(prev_utterance = prev_utterance, response = response, beam = beam, max_len = max_len,
                                                                device = device)
                        # soft sampling
                        previous_Q_distribution= F.softmax((Q_A(prev_utterance = rslt.view(batch_size * beam, -1), response = None, mask = None, bert_tokens = False, max_len = max_len * 2, processed = False) + 
                        Q_B(prev_utterance = rslt.view(batch_size * beam, -1), response = None, mask = None, bert_tokens = False, max_len = max_len * 2, processed = False)).view(batch_size, -1), dim = -1)
                        select = torch.multinomial(previous_Q_distribution, 1)
                        previous_result = rslt.gather(index = select.unsqueeze(-1).expand(-1, -1, max_len * 2), dim = 1).squeeze(1)
                        results = previous_result.detach().clone()
                        masks = msk.gather(index = select.unsqueeze(-1).expand(-1, -1, max_len * 2), dim = 1).squeeze(1)
                        scores = sco.gather(index = select, dim = 1).squeeze(1)
                        utter_length = utrlen.gather(index = select, dim = 1).squeeze(1)
                        sum_logprobs = smlgprbs.gather(index = select, dim = 1).squeeze(1)
                        results_prev_utterance = [rsltprvrnce[i][int(select[i])] for i in range(2)]
                        results_response = [rsltprvrnce[i][int(select[i])] for i in range(2)]
                    else:
                        rslt, msk, sco, smlgprbs, utrlen, rsltprvrnce, rsltrspse = gpt(prev_utterance = previous_result, beam = beam, max_len = max_len,
                                                            device = device)

                        previous_Q_distribution = F.softmax((Q_A(prev_utterance = rslt.view(batch_size * beam, -1), response = None, mask = None, bert_tokens = False, max_len = max_len * 2, processed = False) + 
                        Q_B(prev_utterance = rslt.view(batch_size * beam, -1), response = None, mask = None, bert_tokens = False, max_len = max_len * 2, processed = False)).view(batch_size, -1), dim = -1)
                        select = torch.multinomial(previous_Q_distribution, 1)
                        previous_result = rslt.gather(index = select.unsqueeze(-1).expand(-1, -1, max_len * 2), dim = 1).squeeze(1)
                        
                        results = torch.cat((results, previous_result.detach().clone()), dim = 0)
                        masks = torch.cat((masks, msk.gather(index = select.unsqueeze(-1).expand(-1, -1, max_len * 2), dim = 1).squeeze(1)), dim = 0)
                        scores = torch.cat((scores, sco.gather(index = select, dim = 1).squeeze(1)), dim = 0)
                        utter_length = torch.cat((utter_length, utrlen.gather(index = select, dim = 1).squeeze(1)), dim = 0)
                        sum_logprobs = torch.cat((sum_logprobs, smlgprbs.gather(index = select, dim = 1).squeeze(1)), dim = 0)
                        results_prev_utterance.extend([rsltprvrnce[i][int(select[i])] for i in range(2)])
                        results_response.extend([rsltprvrnce[i][int(select[i])] for i in range(2)])
                reward = R([{'prev_utterance': utt, 'response': res} for utt, res in zip(results_prev_utterance, results_response)])
                results_Q, mask_Q = Q_A.get_processed(prev_utterance = results, bert_tokens = False, max_len = max_len * 2)
            for update_time in range(update_time_per_episode):
                if random.random() >= 0.5: # update Q_A
                    q_estimate = Q_A.forward(prev_utterance = results_Q[: len(results_Q) - batch_size], response = None, mask = mask_Q[: len(results_Q) - batch_size], bert_tokens = True, max_len = max_len * 2, processed = True)
                    q_target = reward + Q_A.gamma * Q_B.forward(prev_utterance = results_Q[batch_size: ], response = None, mask = mask_Q[batch_size: ], bert_tokens = True, max_len = max_len * 2, processed = True)
                else: # update Q_B
                    q_estimate = Q_B.forward(prev_utterance = results_Q[: len(results_Q) - batch_size], response = None, mask = mask_Q[: len(results_Q) - batch_size], bert_tokens = True, max_len = max_len * 2, processed = True)
                    q_target = reward + Q_B.gamma * Q_A.forward(prev_utterance = results_Q[batch_size: ], response = None, mask = mask_Q[batch_size: ], bert_tokens = True, max_len = max_len * 2, processed = True)
                q_loss = criterion(q_estimate, q_target) 
                kl_loss = - kl_control * torch.log(torch.mean(q_estimate.exp()))
                probs = gpt.get_prob(results[: len(results_Q) - batch_size], masks[: len(results_Q) - batch_size], results_prev_utterance[: len(results_Q) - batch_size], results_response[: len(results_Q) - batch_size])
                gpt_loss = torch.sum(probs.detach().clone().exp() * probs * q_target)
                total_loss = q_loss + kl_control * kl_loss + gpt_loss_coefficient * gpt_loss
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()
                
                kl_losses.append(kl_loss)
                gpt_losses.append(gpt_loss)
                q_losses.append(q_loss)
                total_losses.append(total_loss)
                t.set_postfix(klloss = np.mean(kl_losses[-1]), gpt_loss = np.mean(gpt_losses[-1]),
                          q_loss = np.mean(q_losses[-1]), total_loss = np.mean(total_losses[-1]))
            t.update(1)
    return np.mean(q_losses), np.mean(kl_losses), np.mean(gpt_losses), np.mean(total_losses)
            

In [34]:
def get_output_dir(pwd):
    runs_dir = os.path.join(pwd, 'runs')
    train_dir = os.path.join(runs_dir, 'train')
    if not os.path.exists(runs_dir):
        os.mkdir(runs_dir)
    if not os.path.exists(train_dir):
        os.mkdir(train_dir)
    counter = 1
    exp = f"exp{counter}"
    while exp in os.listdir(train_dir):
        counter += 1
        exp = f"exp{counter}"
    output_dir = os.path.join(train_dir, exp)
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
    if not os.path.exists(os.path.join(output_dir, 'models')):
        os.mkdir(os.path.join(output_dir, 'models'))
    with open(os.path.join(output_dir, 'log.csv'), 'w', newline = '') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['Epoch', 'KLLoss', 'GPTLoss', 'QLoss', 'TOTALLoss'])
    return output_dir

In [41]:
sum([p.numel() for p in gpt2.parameters()])

102068736

In [44]:
toxic_words, non_sense_response = GPT2DataSet.get_toxic_ids_and_non_sense_response(gpt2_tokenizer)
toxic_words

[[2402, 106],
 [6547, 6515],
 [5101, 6100],
 [4374, 1061],
 [4374, 1061, 6028],
 [679, 6206, 5622],
 [1391, 2241],
 [3134, 7546],
 [3255, 7397],
 [4635, 4622],
 [6547, 782],
 [678, 3837],
 [3647, 5503, 6500],
 [782, 3942],
 [4868, 5195, 4567],
 [6547],
 [2225, 4454],
 [4192, 2615],
 [2040],
 [2023, 4152],
 [7798, 6026],
 [2450, 4289],
 [5582, 3659]]

In [46]:
lr = 1e-4
batch_size = 4
weight_decay = 1e-4
epochs = 16
lr_drop = 8
Q_discounter_factor = 0.9
kl_control = 0.01
gpt_loss_coefficient = 0.1
pwd = os.getcwd()
data_root = os.path.join(pwd, 'dataset')
assert os.path.exists(data_root)
output_dir = get_output_dir(pwd)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
gpt2 = GPT2LMHeadModel.from_pretrained(os.path.join(pwd, 'GPT-2/GPT2_finetune_2'))
bert = BertModel.from_pretrained('bert-base-chinese')
gpt2_tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
gpt_wrapper = GPT2Wrapper(gpt2, tokenizer = gpt2_tokenizer, device = device)
Q_A = Q(gpt_tokenizer = gpt2_tokenizer, gamma = Q_discounter_factor, bert_name = 'bert-base-chinese')
Q_B = Q(gpt_tokenizer = gpt2_tokenizer, gamma = Q_discounter_factor, bert_name = 'bert-base-chinese')
toxic_words, non_sense_response = GPT2DataSet.get_toxic_ids_and_non_sense_response(gpt2_tokenizer)
R = Reward(gpt = gpt2, question_mark_token = 136, toxic_words = toxic_words, gpt_tokenizer = gpt2_tokenizer,
               non_sense_response = non_sense_response, eos_token = 102, device = device, bos_token = 101)
criterion = nn.MSELoss()
gpt_wrapper.to(device)
Q_A.to(device)
Q_B.to(device)
optimizer = torch.optim.Adam([{ 'params': [p for p in gpt_wrapper.parameters() if p.requires_grad]},
                              { 'params': [p for p in Q_A.parameters() if p.requires_grad]},
                              { 'params': [p for p in Q_B.parameters() if p.requires_grad]},
                              ], lr = lr, weight_decay = weight_decay)
print('total parameters: ', sum([p.numel() for p in gpt_wrapper.parameters() if p.requires_grad]) + 
      sum([p.numel() for p in Q_A.parameters() if p.requires_grad]) + 
      sum([p.numel() for p in Q_B.parameters() if p.requires_grad]))
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, lr_drop)
train_dataset = GPT2DataSet(tokenizer = gpt2_tokenizer, max_len = 256, root_path = data_root, status = 'train')
test_dataset = GPT2DataSet(tokenizer = gpt2_tokenizer, max_len = 256, root_path = data_root, status = 'test')
q_losses = []
gpt_losses = []
kl_losses = []
total_losses = []
print('Start training')
for epoch in range(epochs):
    q_loss, kl_loss, gpt_loss, total_loss = train_one_epoch(epoch = epoch, gpt = gpt_wrapper, Q_A = Q_A, Q_B = Q_B,
                                                            optimizer = optimizer, R = R, dataset = train_dataset, device = device, batch_size = batch_size,
                                                            max_len = 256, beam = 3, update_time_per_episode = 10, criterion = criterion, 
                                                            kl_control = kl_control, gpt_loss_coefficient = gpt_loss_coefficient)
    q_losses.append(q_loss)
    kl_losses.append(kl_loss)
    gpt_losses.append(gpt_loss)
    total_losses.append(total_losses)
    torch.save({
            'Q_A': Q_A.state_dict(),
            'Q_B': Q_B.state_dict(),
            'GPT': gpt_wrapper.state_dict()
        }, os.path.join(output_dir, 'models/latest.pth'))
    if min(q_losses) == q_loss:
        torch.save({
            'Q_A': Q_A.state_dict(),
            'Q_B': Q_B.state_dict(),
            'GPT': gpt_wrapper.state_dict()
        }, os.path.join(output_dir, 'models/best_q_loss.pth'))
    if min(kl_losses) == kl_loss:
        torch.save({
            'Q_A': Q_A.state_dict(),
            'Q_B': Q_B.state_dict(),
            'GPT': gpt_wrapper.state_dict()
        }, os.path.join(output_dir, 'models/best_kl_loss.pth'))
    if min(gpt_losses) == gpt_loss:
        torch.save({
            'Q_A': Q_A.state_dict(),
            'Q_B': Q_B.state_dict(),
            'GPT': gpt_wrapper.state_dict()
        }, os.path.join(output_dir, 'models/best_gpt_loss.pth'))
    if min(total_losses) == total_loss:
        torch.save({
            'Q_A': Q_A.state_dict(),
            'Q_B': Q_B.state_dict(),
            'GPT': gpt_wrapper.state_dict()
        }, os.path.join(output_dir, 'models/best_total_loss.pth'))
    with open(os.path.join(output_dir, 'log.csv'), 'a', newline = '') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow([epoch, kl_loss, gpt_loss, q_loss, total_loss])
        


Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predicti

total parameters:  311069698
Start training


Epoch 0:   0%|          | 0/1705075 [00:00<?, ?it/s]

: 

: 

In [13]:
os.mkdir(os.path.join(pwd, 'runs/train'))

In [14]:
os.listdir('./runs/train/')

[]

In [12]:
aq = Q()
bq = Q()

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.trans

In [37]:
torch.sum(rslt)

tensor(353404)

In [21]:
train_one_epoch(gpt = g, Q_A = aq, Q_B = bq)

tensor([22, 22, 22, 26, 26, 26, 32, 32, 32, 34, 34, 34])
tensor([22, 22, 22, 26, 26, 26, 32, 32, 32, 34, 34, 34])
tensor([22, 22, 22, 26, 26, 26, 32, 32, 32, 34, 34, 34])
tensor([22, 22, 22, 26, 26, 26, 32, 32, 32, 34, 34, 34])
tensor([22, 22, 22, 26, 26, 26, 32, 32, 32, 34, 34, 34])
tensor([22, 22, 22, 26, 26, 26, 32, 32, 32, 34, 34, 34])
tensor([22, 22, 22, 26, 26, 26, 32, 32, 32, 34, 34, 34])
tensor([22, 22, 22, 26, 26, 26, 32, 32, 32, 34, 34, 34])
tensor([22, 22, 22, 26, 26, 26, 32, 32, 32, 34, 34, 34])
tensor([22, 22, 22, 26, 26, 26, 32, 32, 32, 34, 34, 34])
tensor([22, 22, 22, 26, 26, 26, 32, 32, 32, 34, 34, 34])
tensor([22, 22, 22, 26, 26, 26, 32, 32, 32, 34, 34, 34])
tensor([22, 22, 22, 26, 26, 26, 32, 32, 32, 34, 34, 34])
tensor([22, 22, 22, 26, 26, 26, 32, 32, 32, 34, 34, 34])
tensor([22, 22, 22, 26, 26, 26, 32, 32, 32, 34, 34, 34])
tensor([22, 22, 22, 26, 26, 26, 32, 32, 32, 34, 34, 34])
tensor([22, 22, 22, 26, 26, 26, 32, 32, 32, 34, 34, 34])
tensor([22, 22, 22, 26, 26, 26,

KeyboardInterrupt: 

In [24]:
results[0]

tensor([ 101, 2371, 2218,  671,  943, 2099,  138, 1599, 3631,  140,  102,  101,
        2371, 2218,  671,  943, 2099,  138, 1599, 3631,  140,  102, 4638, 6929,
        1372, 3291, 3472, 8013,  101,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,   

In [30]:
g = GPT2Wrapper(gpt = gpt2, tokenizer = tokenizer)


In [69]:
Qa = 

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [71]:
Qa(prev_utterance = rslt.view(2 * 3, -1), response = None, mask = None, bert_tokens = False, max_len = 100 * 2, processed = False)

tensor([[-0.3341],
        [-0.6686],
        [-0.2113],
        [-0.6384],
        [-0.0766],
        [-0.3934]], grad_fn=<AddmmBackward0>)

In [76]:
dis = F.softmax(Qa(prev_utterance = rslt.view(2 * 3, -1), response = None, mask = None, bert_tokens = False, max_len = 100 * 2, processed = False).view(2, -1), dim = -1)

In [111]:
select = torch.multinomial(dis, 1)
select

tensor([[2],
        [1]])

In [109]:
sco, smlgprbs.shape

(tensor([[-0.7753, -0.7489, -0.7135],
         [-0.6005, -0.6241, -0.6006]], grad_fn=<StackBackward0>),
 torch.Size([2, 3]))

In [108]:
sco.gather(index = select, dim = 1).shape

torch.Size([2, 1])

In [31]:
rslt, msk, sco, smlgprbs, utrlen, rsltprvrnce, rsltrspse = g([torch.tensor([101,
 791,
 1921,
 2769,
 1343,
 1912,
 7481,
 1062,
 1754,
 4381,
 138,
 2734,
 2584,
 140,
 102]), torch.tensor([101, 2769, 2695, 872,  138,
 2734,
 2584,
 140, 102])], device = 'cpu')

tensor([15, 15, 15,  9,  9,  9])
tensor([15, 15, 15,  9,  9,  9])
tensor([15, 15, 15,  9,  9,  9])
tensor([15, 15, 15,  9,  9,  9])
tensor([15, 15, 15,  9,  9,  9])
tensor([15, 15, 15,  9,  9,  9])
tensor([15, 15, 15,  9,  9,  9])
tensor([15, 15, 15,  9,  9,  9])
tensor([15, 15, 15,  9,  9,  9])
tensor([15, 15, 15,  9,  9,  9])
tensor([15, 15, 15,  9,  9,  9])
tensor([15, 15, 15,  9,  9,  9])
tensor([15, 15, 15,  9,  9,  9])
tensor([15, 15, 15,  9,  9,  9])
tensor([15, 15, 15,  9,  9,  9])
tensor([15, 15, 15,  9,  9,  9])
tensor([15, 15, 15,  9,  9,  9])
tensor([15, 15, 15,  9,  9,  9])
tensor([15, 15, 15,  9,  9,  9])
tensor([15, 15, 15,  9,  9,  9])
tensor([15, 15, 15,  9,  9,  9])


In [35]:
rsltprvrnce, rsltrspse

([[tensor([ 101,  791, 1921, 2769, 1343, 1912, 7481, 1062, 1754, 4381,  138, 2734,
           2584,  140,  102]),
   tensor([ 101,  791, 1921, 2769, 1343, 1912, 7481, 1062, 1754, 4381,  138, 2734,
           2584,  140,  102]),
   tensor([ 101,  791, 1921, 2769, 1343, 1912, 7481, 1062, 1754, 4381,  138, 2734,
           2584,  140,  102])],
  [tensor([ 101, 2769, 2695,  872,  138, 2734, 2584,  140,  102]),
   tensor([ 101, 2769, 2695,  872,  138, 2734, 2584,  140,  102]),
   tensor([ 101, 2769, 2695,  872,  138, 2734, 2584,  140,  102])]],
 [[tensor([2111, 2094,  679, 4761, 6887,  872,  947, 3221, 6306, 1485, 8043, 8043,
           8043,  101,  102]),
   tensor([3221, 4158,  749, 1343, 1912, 7481, 7087, 7509, 1621, 8043, 8043, 8043,
            101,  103,  102]),
   tensor([2111, 2094,  679, 4761, 6887,  872,  947, 3221, 6306, 1485, 8043, 8043,
           8043,  101,  103,  102])],
  [tensor([ 872, 4638, 6929,  943, 3582, 3298,  102]),
   tensor([2769, 1652,  511, 1464,  872, 2695, 429

In [34]:
rslt[0]

tensor([[ 101,  791, 1921, 2769, 1343, 1912, 7481, 1062, 1754, 4381,  138, 2734,
         2584,  140,  102, 2111, 2094,  679, 4761, 6887,  872,  947, 3221, 6306,
         1485, 8043, 8043, 8043,  101,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,  

In [127]:
rsltrspse

[[tensor([8043,  792, 3221, 4158, 1567, 8043,  102]),
  tensor([8043,  792, 3221, 4158,  784, 7938, 8043, 6538,  982, 3043, 3043, 1765,
          1553, 7619,  746,  100,  101,  102]),
  tensor([8043,  792, 3221, 4158,  784, 7938, 8043, 6538,  982, 3043, 3043, 1765,
          1553, 7619,  746,  100,  101,  103,  102])],
 [tensor([3221, 2769, 2061, 2874, 7097,  749,  102]),
  tensor([3221, 2769, 2061, 2874,  749, 2769, 4638, 2894, 1462, 8024, 2769, 2894,
          1462, 4638, 2894, 1462, 4638, 2894, 1462, 4638,  102]),
  tensor([3221, 2769, 2061, 2874,  749, 2769, 4638, 2894, 1462, 8024, 2769, 2894,
          1462, 4638, 2894, 1462, 4638, 2894, 1462, 4638, 2894,  102])]]

In [126]:
[rsltrspse[index_selected][int(select[index_selected])] for index_selected in range(2)]

[tensor([8043,  792, 3221, 4158,  784, 7938, 8043, 6538,  982, 3043, 3043, 1765,
         1553, 7619,  746,  100,  101,  103,  102]),
 tensor([3221, 2769, 2061, 2874,  749, 2769, 4638, 2894, 1462, 8024, 2769, 2894,
         1462, 4638, 2894, 1462, 4638, 2894, 1462, 4638,  102])]

In [24]:
rslt, msk, sco, smlgprbs, utrlen = g(rslt[0], device = 'cpu')

'[CLS] 今 天 我 去 外 面 公 園 玩 [ 憤 怒 ] [SEP] ？ 介 是 為 啥 ？ [ 憤 怒 ] [SEP] 騙 人 的 嗎 ？ [CLS] [MASK] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] 

In [121]:
a[1][0].shape
r([[-2.8096, -6.8823, -6.8823],
        [-2.1762, -6.5803, -6.5984]], gra

torch.Size([200])

In [43]:
g = h.gather(index = torch.tensor([[2], [1], [0]]), dim = 1)
g, g.shape

(tensor([[ 0.9054],
         [ 2.1621],
         [-0.3211]]),
 torch.Size([3, 1]))

In [51]:
h.scatter(dim = 1, index = torch.tensor([[2], [1], [0]]), src = torch.zeros())

TypeError: zeros() received an invalid combination of arguments - got (), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of SymInts size, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)


In [212]:
gpt2(input_ids = a[3].unsqueeze(0), attention_mask = b[3].unsqueeze(0).to(torch.long))['logits'][:, 9, :].topk(6, dim = -1)

torch.return_types.topk(
values=tensor([[13.6946, 13.5733, 11.8984, 11.7583, 11.6346, 11.5276]],
       grad_fn=<TopkBackward0>),
indices=tensor([[ 947, 5582, 8043, 1557,  134, 4638]]))

In [219]:
gpt2(input_ids = torch.tensor([[ 101, 2769, 2695,  872,  138, 2734, 2584,  140,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,  102, 3221,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0]]), attention_mask = torch.tensor([[1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.]]))['logits'][:, 95:105, :].topk(6, dim = -1)

torch.return_types.topk(
values=tensor([[[19.3966, 16.6512, 15.6357, 14.6237, 14.0722, 13.7992],
         [17.7092, 17.1337, 17.0325, 16.8664, 16.7009, 16.6512],
         [18.5763, 15.8345, 15.3630, 15.3226, 14.5392, 13.3719],
         [18.1575, 17.4653, 15.9241, 14.5630, 14.4235, 14.4225],
         [15.6550, 15.1796, 15.0827, 13.2894, 12.9438, 11.5672],
         [14.6270, 12.9387, 12.3596, 12.2733, 11.1103, 10.4297],
         [19.9478, 15.4028, 13.7760, 13.4685, 13.0539, 12.9692],
         [16.9682, 13.9107, 13.6344, 13.5101, 12.8403, 11.6358],
         [17.5265, 12.5370, 12.4661, 11.6980, 11.4350, 11.1611],
         [14.6954, 11.4877, 11.4264, 11.2573, 11.2200, 10.9985]]],
       grad_fn=<TopkBackward0>),
indices=tensor([[[2241, 5175, 4649, 3221,  872, 1889],
         [1995, 6552, 2094, 2336, 5529, 6303],
         [ 872, 2094, 1162, 8043, 8013, 3221],
         [5175, 2094, 4863, 3506, 2241,  712],
         [3221, 8013,  511, 8024,  100, 4638],
         [1567,  943, 2769, 8043,  784, 

In [70]:
gpt2(input_ids = torch.tensor([[101, 2769, 2695, 872,  138,
 2734,
 2584,
 140, 102, 3221, 2769]]))['logits'][:, -1, :].topk(6, dim = -1)

torch.return_types.topk(
values=tensor([[19.6527, 13.3922, 12.5040, 12.1969, 10.9196,  9.9409]],
       grad_fn=<TopkBackward0>),
indices=tensor([[2061, 7464, 6873, 4638, 3553,  982]]))

In [97]:
tokenizer.decode(gpt2.generate(input_ids = torch.tensor([[101,
 791,
 1921,
 2769,
 1343,
 1912,
 7481,
 1062,
 1754,
 4381,
 138,
 2734,
 2584,
 140,
 102]]), num_beams = 3)[0]), gpt2.generate(input_ids = torch.tensor([[101,
 791,
 1921,
 2769,
 1343,
 1912,
 7481,
 1062,
 1754,
 4381,
 138,
 2734,
 2584,
 140,
 102]]), num_beams = 3)[0]

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


('[CLS] 今 天 我 去 外 面 公 園 玩 [ 憤 怒 ] [SEP] ？ 介 是 為 啥',
 tensor([ 101,  791, 1921, 2769, 1343, 1912, 7481, 1062, 1754, 4381,  138, 2734,
         2584,  140,  102, 8043,  792, 3221, 4158, 1567]))

In [104]:
input_ids = torch.tensor([101,
 791,
 1921,
 2769,
 1343,
 1912,
 7481,
 1062,
 1754,
 4381,
 138,
 2734,
 2584,
 140,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,  102,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0])
gpt2(input_ids = input_ids, attention_mask = (input_ids > 0).float())['logits'].argmax(-1)[99]

tensor(511)

In [None]:
t = []
t.extend([[1,2,3], [1,2,5]])
t.extend([[1,2,3], [1,2,5]])
t

[[1, 2, 3], [1, 2, 5], [1, 2, 3], [1, 2, 5]]

In [None]:
t

[[1, 2, 3], [1, 2, 5]]

In [None]:
value, indices = g.topk(2)

In [None]:
indices[1]

tensor([2, 1])

In [None]:
g

tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])

In [None]:
torch.cat((g, torch.zeros((5, 1)) + 1), dim = -1)

tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])

In [None]:
result, sr, score = gg(torch.tensor([[101,
 791,
 1921,
 2769,
 1343,
 1912,
 7481,
 1062,
 1754,
 4381,
 138,
 2734,
 2584,
 140,
 102]]), max_len = 100, beam = 3)

  value, indices = F.softmax(gpt2(input_ids = prev_utterance)['logits']).topk(beam, dim = -1)


In [None]:
result

[[[tensor(-5.5424, grad_fn=<AddBackward0>),
   tensor([ 101,  791, 1921, 2769, 1343, 1912, 7481, 1062, 1754, 4381,  138, 2734,
           2584,  140,  102, 2347, 5195, 6365, 5682,  749,  101,  103])],
  [tensor(-6.1180, grad_fn=<AddBackward0>),
   tensor([ 101,  791, 1921, 2769, 1343, 1912, 7481, 1062, 1754, 4381,  138, 2734,
           2584,  140,  102, 2347, 5195, 4255, 4634,  749,  101,  103])],
  [tensor(-6.3388, grad_fn=<AddBackward0>),
   tensor([ 101,  791, 1921, 2769, 1343, 1912, 7481, 1062, 1754, 4381,  138, 2734,
           2584,  140,  102, 2347, 5195, 6365, 5622,  749,  101,  103])]]]

In [None]:
score

tensor([[-5.5424, -6.1180, -6.3388]])

In [None]:
sr

[[tensor([2347, 5195, 6365, 5682,  749,  101,  103]),
  tensor([2347, 5195, 4255, 4634,  749,  101,  103]),
  tensor([2347, 5195, 6365, 5622,  749,  101,  103])]]

In [None]:
for i in sr[0]:
    print(tokenizer.decode(i[1], skip_special_tokens = True))

已 經 變 色 了 我
已 經 變 色 了 你
已 經 爆 發 了 我
已 經 爆 發 了 你
已 經 變 臉 了 我


In [None]:
value, indices = F.softmax(gpt2(input_ids = torch.tensor([101, 2769, 2695, 872, 102]))['logits'], dim = -1).topk(3, dim = -1)

In [None]:
torch.log(value), indices, value.shape

(tensor([[-1.7166e-05, -1.1302e+01, -1.3569e+01],
         [-2.4632e+00, -2.5122e+00, -2.8395e+00],
         [-4.5884e-01, -2.5241e+00, -3.6483e+00],
         [-1.1748e+00, -2.0745e+00, -2.3433e+00],
         [-1.3878e+00, -1.6446e+00, -2.4694e+00]], grad_fn=<LogBackward0>),
 tensor([[ 103,  140,  101],
         [ 738, 4638,  947],
         [ 872, 4638, 1391],
         [ 138,  101, 8024],
         [ 112,  138,  100]]),
 torch.Size([5, 3]))

In [None]:
torch.cat((torch.tensor([19, 19]), indices[-1][0].unsqueeze(0)))

tensor([ 19,  19, 112])

In [None]:
q = Q()

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
prev = [torch.tensor([101, 2769, 2695, 872, 102]), torch.tensor([101, 2769, 2695, 872, 34, 234, 45, 102])]

In [None]:
q(prev_utterance = prev)

tensor([[3.7675e-01],
        [2.7341e-04]], grad_fn=<AddmmBackward0>)

In [None]:
g = torch.randn(10)
h = torch.randn(20)

In [None]:
mask = g.ge(0)

In [None]:
mask.float()

tensor([0., 0., 1., 1., 0., 1., 1., 0., 1., 1.])

In [35]:
data = GPT2DataSet(shuffle = True)

In [40]:
data[4]

{'prev_utterance': tensor([ 101, 4412, 1762, 6917, 6206, 4500, 3265, 3706, 8043,  138, 1599, 3631,
          140,  102]),
 'response': tensor([4412, 1762, 3760,  749,  102])}

BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[-1.2095e-01,  5.9866e-01, -1.1736e+00,  ...,  1.3144e+00,
          -6.6073e-01,  3.1353e-01],
         [-2.5177e-01,  5.4246e-01, -1.5848e-01,  ...,  1.7940e-01,
          -5.1476e-01, -4.9664e-02],
         [-3.1723e-01,  4.4264e-01, -1.4355e-01,  ...,  1.8843e-01,
          -6.1138e-01,  1.0900e-03]]], grad_fn=<NativeLayerNormBackward0>), pooler_output=tensor([[ 0.9996,  0.9367,  0.9999,  0.9204,  0.8145, -0.3277, -0.8850, -0.9492,
          0.9857, -0.9993,  1.0000,  1.0000, -0.8858, -0.7393,  0.9996, -0.9996,
         -0.5293,  0.0339,  0.9951,  0.1820,  0.9798, -0.9997,  0.0852, -0.8749,
         -0.5547,  0.9978,  0.7787, -0.9091, -0.9961,  0.9996,  0.9721,  0.9997,
          0.8698, -0.9885, -0.9998,  0.5070,  0.5023,  0.9952,  0.8226, -0.8865,
         -0.9540,  0.4088, -0.2360, -0.9980, -0.6263,  0.6001, -1.0000, -0.9999,
         -0.9786,  0.9238, -0.7652, -1.0000,  0.9813, -0.9630, -0.8354,  0.9973,
  

In [None]:
q = Q()

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
q(prev_utterance = i['prev_utterance'], response = i['response'])

[tensor([ 101, 3193, 7953, 1962, 6629,  889, 1557,  511, 4495, 3189, 2571, 3556,
         100, 1962, 2533, 2345,  679, 1914, 1568,  102]), tensor([ 101, 6929, 7938, 3193, 8013, 8013, 2769, 4638, 5440, 6275, 1453, 6917,
        3760, 3300, 7274, 1993, 1450, 8013,  100,  679, 3632,  791, 2399, 8024,
        2769, 4634, 4412,  872, 3680, 3613, 6963, 1962, 3193,  102]), tensor([ 101,  791, 1921,  679, 4761, 6887, 5543,  678, 7938, 8024,  679, 6882,
        1921, 1348, 1107,  749, 5018,  671, 1842, 7434,  679, 3221, 7434, 5709,
        4638, 3564, 2094, 8024, 3800, 2692, 7344, 2207, 1102, 7441, 1521,  102]), tensor([ 101, 2207, 1520, 1520, 3136, 3136, 2769,  100, 6241, 1001, 6716, 3136,
         102])]


tensor([[-0.2735],
        [-0.2923],
        [-0.0200],
        [ 0.3614]], grad_fn=<AddmmBackward0>)

In [None]:
i

{'prev_utterance': tensor([[ 101, 3193, 7953,  ...,   -1,   -1,   -1],
         [ 101, 6929, 7938,  ...,   -1,   -1,   -1],
         [ 101,  791, 1921,  ...,   -1,   -1,   -1],
         [ 101, 2207, 1520,  ...,   -1,   -1,   -1]]),
 'response': tensor([[1962, 2533, 2345,  ...,   -1,   -1,   -1],
         [ 679, 3632,  791,  ...,   -1,   -1,   -1],
         [5018,  671, 1842,  ...,   -1,   -1,   -1],
         [6241, 1001, 6716,  ...,   -1,   -1,   -1]])}

In [None]:
gpt_tokenizer = BertTokenizer(vocab_file = './vocab_small.txt')

In [None]:
gpt_tokenizer.encode(["我我", "我我"], return_tensors = 'pt')

tensor([[101, 100, 100, 102]])

In [None]:
bert = BertModel.from_pretrained('bert-base-chinese')

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
bert.forward()

In [None]:
bert(torch.tensor([[2, 3, 4]]))['last_hidden_state'].shape

torch.Size([1, 3, 768])

In [None]:
bert.pooler.dense.out_features

768

In [None]:
class LitMNIST(LightningModule):
    def __init__(self, data_dir = '../LCCC-base', q_learning_rate = 1e-4, gpt_learning_rate = 1e-4, dqn_discounter = 0.95, dqn_alpha = 0.98):
        super().__init__()

        # Set our init args as class attributes
        self.data_dir = data_dir

        self.q_learning_rate = q_learning_rate
        self.gpt_learning_rate = gpt_learning_rate
        self.dqn_discounter = dqn_discounter
        self.dqn_alpha = dqn_alpha

        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.dims = (1, 28, 28)
        channels, width, height = self.dims
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        # Define PyTorch model
        self.Q_a = 

        self.val_accuracy = Accuracy(task="multiclass", num_classes=10)
        self.test_accuracy = Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.val_accuracy.update(preds, y)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.val_accuracy, prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.test_accuracy.update(preds, y)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", self.test_accuracy, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)

In [None]:
g = GPT2DataSet()

In [None]:
h = DataLoader(g, batch_size = 4)

In [None]:
i['prev_utterance'][i['prev_utterance'] >= 0]

tensor([ 101, 3193, 7953, 1962, 6629,  889, 1557,  511, 4495, 3189, 2571, 3556,
         102,    1,  101, 6929, 7938, 3193, 8013, 8013, 2769, 4638, 5440, 6275,
        1453, 6917, 3760, 3300, 7274, 1993, 1450, 8013,  102,    1,  101,  791,
        1921,  679, 4761, 6887, 5543,  678, 7938, 8024,  679, 6882, 1921, 1348,
        1107,  749,  102,    0,  101, 2207, 1520, 1520, 3136, 3136, 2769,  102,
           1])

In [None]:
g.tokenizer.encode('我按你', return_tensors = 'pt')[0]

tensor([ 101, 2769, 2902,  872,  102])

In [None]:
counter = 0
sample = []
for i in h:
    if counter == 1:
        break
    counter += 1
    print(i)

{'prev_utterance': tensor([[ 101,  872, 1343,  ...,   -1,   -1,   -1],
        [ 101, 2769, 4500,  ...,   -1,   -1,   -1],
        [ 101, 1490, 1490,  ...,   -1,   -1,   -1],
        [ 101, 6250, 2533,  ...,   -1,   -1,   -1]]), 'response': tensor([[ 6887,  3624,  8013,  ...,    -1,    -1,    -1],
        [10235,  8168,  3193,  ...,    -1,    -1,    -1],
        [ 2402,  2130,  6857,  ...,    -1,    -1,    -1],
        [  807,  6134,   749,  ...,    -1,    -1,    -1]])}


In [None]:
for i in h:
    print(i)

TypeError: object of type 'type' has no len()

In [None]:
with open('../LCCC-base/single_emo_T_test.json') as f:
    train = json.load(f)

In [None]:
train[130]

['期待趙麗穎[喜歡]', '期待']

In [None]:
sample = train[52:60]

In [None]:
state = []
for i in sample:
    state.append(Reward.sentence2id(i, tokenizer))

In [None]:
state[0]['response'] = state[1]['response']

In [None]:
reward = Reward(gpt = gpt2, question_mark_token = 8043, toxic_words = [], eos_token = 102, device = 'cpu', gpt_tokenizer = tokenizer, non_sense_response = [[6857, 3564, 2094, 1621,  102], [1962, 1595,  102]])

Some weights of the model checkpoint at ckiplab/bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at ckiplab/bert-base-chinese and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dens

In [None]:
state

[{'prev_utterance': tensor([ 101, 2769,  791, 1921, 2682, 6313, 2145, 2786, 1391, 6917, 3760, 1391,
          2768,  738,  102,    1]),
  'response': tensor([6250, 2533, 2380, 4891, 4289, 8024, 1190, 1726,  889, 8024, 5168, 6640,
           102])},
 {'prev_utterance': tensor([ 101,  671, 6662, 6624, 1962,  172,  102,    1]),
  'response': tensor([6250, 2533, 2380, 4891, 4289, 8024, 1190, 1726,  889, 8024, 5168, 6640,
           102])},
 {'prev_utterance': tensor([ 101, 3173, 2094, 5507, 2137, 3298,  976, 4525, 4994, 7427, 7797, 3221,
          2130, 5401, 4638,  102,    1], device='mps:0'),
  'response': tensor([1506, 1506, 1506, 1506, 1506, 1506, 1506, 3221, 8013, 7427, 7797, 3297,
          2130, 5401, 8013,  102], device='mps:0')},
 {'prev_utterance': tensor([ 101, 2769,  679, 1962, 8024, 5439, 2094, 6206, 4717,  749, 8024,  872,
           738, 2571, 4717, 1416,  102,    1], device='mps:0'),
  'response': tensor([ 872, 4412, 1762, 6917, 1377,  809, 4717, 8013, 8013,  102],
        

In [None]:
reward(state)

tensor([0.3175, 0.4140, 0.3186, 0.3201, 0.3121, 0.2846, 0.4882, 0.3271])