From 4aca8cbcbcd3b9e7cf6ec1f106959fcd03cd57ef Mon Sep 17 00:00:00 2001
From: liujiachi1997 <69622171+liujiachi1997@users.noreply.github.com>
Date: Sun, 1 May 2022 15:56:07 +0800
Subject: [PATCH 1/2] Update dataloader.py

---
 uer/utils/dataloader.py | 45 ++++++++++++++++++++++++++---------------
 1 file changed, 29 insertions(+), 16 deletions(-)

diff --git a/uer/utils/dataloader.py b/uer/utils/dataloader.py
index 1eed3eb7..35cc8ba5 100644
--- a/uer/utils/dataloader.py
+++ b/uer/utils/dataloader.py
@@ -72,23 +72,27 @@ def __iter__(self):
             masked_words_num = 0
 
             for ins in instances:
+                src_single, pad_num = ins[0]
+                for _ in range(pad_num):
+                    src_single.append(self.vocab.get(PAD_TOKEN))
+
                 if len(ins) == 4:
-                    src.append(ins[0])
+                    src.append(src_single)
                     masked_words_num += len(ins[1])
-                    tgt_mlm.append([0] * len(ins[0]))
+                    tgt_mlm.append([0] * len(src_single))
                     for mask in ins[1]:
                         tgt_mlm[-1][mask[0]] = mask[1]
                     is_next.append(ins[2])
-                    seg.append([1] * ins[3][0] + [2] * (ins[3][1] - ins[3][0]) + [0] * (len(ins[0]) - ins[3][1]))
+                    seg.append([1] * ins[3][0] + [2] * (ins[3][1] - ins[3][0]) + [0] * (pad_num))
                 else:
-                    src_single, tgt_mlm_single = mask_seq(ins[0], self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
+                    tgt_mlm.append([0] * len(src_single))
+                    src_single, tgt_mlm_single = mask_seq(src_single, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
                     masked_words_num += len(tgt_mlm_single)
                     src.append(src_single)
-                    tgt_mlm.append([0] * len(ins[0]))
                     for mask in tgt_mlm_single:
                         tgt_mlm[-1][mask[0]] = mask[1]
                     is_next.append(ins[1])
-                    seg.append([1] * ins[2][0] + [2] * (ins[2][1] - ins[2][0]) + [0] * (len(ins[0]) - ins[2][1]))
+                    seg.append([1] * ins[2][0] + [2] * (ins[2][1] - ins[2][0]) + [0] * (pad_num))
 
             if masked_words_num == 0:
                 continue
@@ -118,21 +122,25 @@ def __iter__(self):
             masked_words_num = 0
 
             for ins in instances:
+                src_single, pad_num = ins[0]
+                for _ in range(pad_num):
+                    src_single.append(self.vocab.get(PAD_TOKEN))
+
                 if len(ins) == 3:
-                    src.append(ins[0])
+                    src.append(src_single)
                     masked_words_num += len(ins[1])
-                    tgt.append([0] * len(ins[0]))
+                    tgt.append([0] * len(src_single))
                     for mask in ins[1]:
                         tgt[-1][mask[0]] = mask[1]
-                    seg.append([1] * ins[2][0] + [0] * (len(ins[0]) - ins[2][0]))
+                    seg.append([1] * ins[2][0] + [0] * (pad_num))
                 else:
-                    src_single, tgt_single = mask_seq(ins[0], self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
+                    tgt.append([0] * len(src_single))
+                    src_single, tgt_single = mask_seq(src_single, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
                     masked_words_num += len(tgt_single)
                     src.append(src_single)
-                    tgt.append([0] * len(ins[0]))
                     for mask in tgt_single:
                         tgt[-1][mask[0]] = mask[1]
-                    seg.append([1] * ins[1][0] + [0] * (len(ins[0]) - ins[1][0]))
+                    seg.append([1] * ins[1][0] + [0] * (pad_num))
 
             if masked_words_num == 0:
                 continue
@@ -142,6 +150,7 @@ def __iter__(self):
                   torch.LongTensor(seg)
 
 
+
 class AlbertDataloader(BertDataloader):
     '''
     AlbertDataloader can reuse the code of BertDataloader.
@@ -166,12 +175,16 @@ def __iter__(self):
             seg = []
 
             for ins in instances:
-                src.append(ins[0][:-1])
-                tgt.append(ins[0][1:])
-                if ins[1] == len(ins[0]):
+                src_single, pad_num = ins[0]
+                if ins[1] == len(src_single):
                     seg.append([1] * (ins[1] - 1))
                 else:
-                    seg.append([1] * ins[1] + [0] * (len(ins[0]) - 1 - ins[1]))
+                    for _ in range(pad_num):
+                        src_single.append(self.vocab.get(PAD_TOKEN))
+                    seg.append([1] * ins[1] + [0] * (pad_num - 1))
+
+                src.append(src_single[:-1])
+                tgt.append(src_single[1:])
 
             yield torch.LongTensor(src), \
                   torch.LongTensor(tgt), \

From 93df5d423ca45acba2ac67c84d7523ea262414b8 Mon Sep 17 00:00:00 2001
From: liujiachi1997 <69622171+liujiachi1997@users.noreply.github.com>
Date: Sun, 1 May 2022 15:57:54 +0800
Subject: [PATCH 2/2] Update dataset.py

---
 uer/utils/dataset.py | 22 +++++++++++++++-------
 1 file changed, 15 insertions(+), 7 deletions(-)

diff --git a/uer/utils/dataset.py b/uer/utils/dataset.py
index dfa2f434..a9d8fc59 100644
--- a/uer/utils/dataset.py
+++ b/uer/utils/dataset.py
@@ -205,13 +205,16 @@ def create_ins_from_doc(self, all_documents, document_index):
                     src.append(self.vocab.get(SEP_TOKEN))
                     seg_pos.append(len(src))
 
-                    while len(src) != self.seq_length:
-                        src.append(self.vocab.get(PAD_TOKEN))
-
+                    pad_num = 0
+                    if len(src) != self.seq_length:
+                        pad_num = self.seq_length - len(src)
+
                     if not self.dynamic_masking:
                         src, tgt_mlm = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
+                        src = (src, pad_num)
                         instance = (src, tgt_mlm, is_random_next, seg_pos)
                     else:
+                        src = (src, pad_num)
                         instance = (src, is_random_next, seg_pos)
 
                     instances.append(instance)
@@ -290,8 +293,10 @@ def build_instances(self, all_documents):
 
             if not self.dynamic_masking:
                 src, tgt = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
+                src = (src, 0)
                 instance = (src, tgt, seg_pos)
             else:
+                src = (src, 0)
                 instance = (src, seg_pos)
 
             instances.append(instance)
@@ -299,13 +304,15 @@ def build_instances(self, all_documents):
         src = all_documents[instances_num * self.seq_length:]
         seg_pos = [len(src)]
 
-        while len(src) != self.seq_length:
-            src.append(self.vocab.get(PAD_TOKEN))
+        if len(src) != self.seq_length:
+            pad_num = self.seq_length - len(src)
 
         if not self.dynamic_masking:
             src, tgt = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
+            src = (src, pad_num)
            instance = (src, tgt, seg_pos)
         else:
+            src = (src, pad_num)
             instance = (src, seg_pos)
 
         instances.append(instance)
@@ -442,13 +449,14 @@ def worker(self, proc_id, start, end):
                 for i in range(instances_num):
                     src = document[i * (self.seq_length + 1): (i + 1) * (self.seq_length + 1)]
                     seg_pos = self.seq_length
+                    src = (src, 0)
                     pickle.dump((src, seg_pos), dataset_writer)
 
                 src = document[instances_num * (self.seq_length + 1):]
                 if len(src) > 0:
                     seg_pos = len(src)
-                    while len(src) != self.seq_length + 1:
-                        src.append(self.vocab.get(PAD_TOKEN))
+                    pad_num = self.seq_length - len(src) + 1
+                    src = (src, pad_num)
                     pickle.dump((src, seg_pos), dataset_writer)
 
             if pos >= end:
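The common idea behind both patches is to defer padding: the dataset stores an instance as (token_ids, pad_num) instead of padding every sequence to seq_length up front, and the dataloader appends the PAD ids and builds the segment vector when it assembles a batch. The following is a minimal standalone sketch of that pattern, not code from the patch; PAD_ID and SEQ_LENGTH are hypothetical stand-ins for self.vocab.get(PAD_TOKEN) and self.seq_length.

    # Sketch of the deferred-padding pattern introduced by the patches.
    PAD_ID = 0
    SEQ_LENGTH = 8

    def build_instance(token_ids):
        # Dataset side: record how much padding the sequence still needs.
        pad_num = SEQ_LENGTH - len(token_ids)
        return (token_ids, pad_num)

    def collate(instances):
        # Dataloader side: append PAD ids and derive segment ids,
        # using pad_num for the trailing zeros of seg.
        src, seg = [], []
        for token_ids, pad_num in instances:
            src.append(token_ids + [PAD_ID] * pad_num)
            seg.append([1] * len(token_ids) + [0] * pad_num)
        return src, seg

    print(collate([build_instance([101, 7, 8, 9, 102]),
                   build_instance([101, 5, 102])]))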