## vocab.py

### TorchVocab

In [3]:
class TorchVocab(object):

    def __init__(self, counter, max_size=None, min_freq=1, specials=['<pad>', '<oov>'],
                 vectors=None, unk_init=None, vectors_cache=None):

        self.freqs = counter
        counter = counter.copy()
        min_freq = max(min_freq, 1)

        self.itos = list(specials)
        # frequencies of special tokens are not counted when building vocabulary
        # in frequency order
        for tok in specials:
            del counter[tok]

        max_size = None if max_size is None else max_size + len(self.itos)

        # sort by frequency, then alphabetically
        words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
        words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)

        for word, freq in words_and_frequencies:
            if freq < min_freq or len(self.itos) == max_size:
                break
            self.itos.append(word)

        # stoi is simply a reverse dict for itos
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}

        self.vectors = None
        if vectors is not None:
            self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache)
        else:
            assert unk_init is None and vectors_cache is None

    def __eq__(self, other):
        if self.freqs != other.freqs:
            return False
        if self.stoi != other.stoi:
            return False
        if self.itos != other.itos:
            return False
        if self.vectors != other.vectors:
            return False
        return True

    def __len__(self):
        return len(self.itos)

    def vocab_rerank(self):
        self.stoi = {word: i for i, word in enumerate(self.itos)}

    def extend(self, v, sort=False):
        words = sorted(v.itos) if sort else v.itos
        for w in words:
            if w not in self.stoi:
                self.itos.append(w)
                self.stoi[w] = len(self.itos) - 1

In [4]:
class Vocab(TorchVocab):
    def __init__(self, counter, max_size=None, min_freq=1):
        self.pad_index = 0
        self.unk_index = 1
        self.eos_index = 2
        self.sos_index = 3
        self.mask_index = 4
        super().__init__(counter, specials=["<pad>", "<unk>", "<eos>", "<sos>", "<mask>"],
                         max_size=max_size, min_freq=min_freq)

    def to_seq(self, sentece, seq_len, with_eos=False, with_sos=False) -> list:
        pass

    def from_seq(self, seq, join=False, with_pad=False):
        pass

    @staticmethod
    def load_vocab(vocab_path: str) -> 'Vocab':
        with open(vocab_path, "rb") as f:
            return pickle.load(f)

    def save_vocab(self, vocab_path):
        with open(vocab_path, "wb") as f:
            pickle.dump(self, f)

### WordVocab

In [5]:
class WordVocab(Vocab):
    def __init__(self, texts, max_size=None, min_freq=1):
        print("Building Vocab")
        counter = Counter()
        for line in tqdm.tqdm(texts):
            if isinstance(line, list):
                words = line
            else:
                words = line.replace("\n", "").replace("\t", "").split()

            for word in words:
                counter[word] += 1
        super().__init__(counter, max_size=max_size, min_freq=min_freq)

    def to_seq(self, sentence, seq_len=None, with_eos=False, with_sos=False, with_len=False):
        if isinstance(sentence, str):
            sentence = sentence.split()

        seq = [self.stoi.get(word, self.unk_index) for word in sentence]

        if with_eos:
            seq += [self.eos_index]  # this would be index 1
        if with_sos:
            seq = [self.sos_index] + seq

        origin_seq_len = len(seq)

        if seq_len is None:
            pass
        elif len(seq) <= seq_len:
            seq += [self.pad_index for _ in range(seq_len - len(seq))]
        else:
            seq = seq[:seq_len]

        return (seq, origin_seq_len) if with_len else seq

    def from_seq(self, seq, join=False, with_pad=False):
        words = [self.itos[idx]
                 if idx < len(self.itos)
                 else "<%d>" % idx
                 for idx in seq
                 if not with_pad or idx != self.pad_index]

        return " ".join(words) if join else words

    @staticmethod
    def load_vocab(vocab_path: str) -> 'WordVocab':
        with open(vocab_path, "rb") as f:
            return pickle.load(f)

### 构建词典

In [6]:
corpus_path=args.train_dataset
with open(corpus_path, "r") as f:
    vocab = WordVocab(f, max_size=None, min_freq=1)

print("VOCAB SIZE:", len(vocab))
vocab.save_vocab(args.output_path)

2it [00:00, 544.54it/s]

Building Vocab
VOCAB SIZE: 15





### 加载词典

In [7]:
vocab=WordVocab.load_vocab(args.vocab_path)

### Loading DataSet

In [10]:
print("Loading Train Dataset", args.train_dataset)
train_dataset = BERTDataset(args.train_dataset, vocab, seq_len=args.seq_len, corpus_lines=args.corpus_lines)

print("Loading Test Dataset", args.test_dataset)
test_dataset = BERTDataset(args.test_dataset, vocab,
                           seq_len=args.seq_len) if args.test_dataset is not None else None

print("Creating Dataloader")
train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers) \
    if test_dataset is not None else None

Loading Dataset: 2it [00:00, 871.00it/s]
Loading Dataset: 2it [00:00, 5551.69it/s]

Loading Train Dataset data/corpus.small
Loading Test Dataset data/corpus.small
Creating Dataloader





### Building Bert

## pretain.py

### BERTTrainer

In [23]:
from torch.optim import Adam
from torch.utils.data import DataLoader
import tqdm

class BERTTrainer:
    """
    BERTTrainer make the pretrained BERT model with two LM training method.

        1. Masked Language Model : 3.3.1 Task #1: Masked LM
        2. Next Sentence prediction : 3.3.2 Task #2: Next Sentence Prediction

    please check the details on README.md with simple example.

    """

    def __init__(self, bert: BERT, vocab_size: int,
                 train_dataloader: DataLoader, test_dataloader: DataLoader = None,
                 lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01,
                 with_cuda: bool = True, log_freq: int = 10):
        """
        :param bert: BERT model which you want to train
        :param vocab_size: total word vocab size
        :param train_dataloader: train dataset data loader
        :param test_dataloader: test dataset data loader [can be None]
        :param lr: learning rate of optimizer
        :param betas: Adam optimizer betas
        :param weight_decay: Adam optimizer weight decay param
        :param with_cuda: traning with cuda
        :param log_freq: logging frequency of the batch iteration
        """

        # Setup cuda device for BERT training, argument -c, --cuda should be true
        cuda_condition = torch.cuda.is_available() and with_cuda
        self.device = torch.device("cuda:0" if cuda_condition else "cpu")

        # This BERT model will be saved every epoch
        self.bert = bert
        # Initialize the BERT Language Model, with BERT model
        self.model = BERTLM(bert, vocab_size).to(self.device)

        # Distributed GPU training if CUDA can detect more than 1 GPU
        if torch.cuda.device_count() > 1:
            print("Using %d GPUS for BERT" % torch.cuda.device_count())
            self.model = nn.DataParallel(self.model)

        # Setting the train and test data loader
        self.train_data = train_dataloader
        self.test_data = test_dataloader

        # Setting the Adam optimizer with hyper-param
        self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)

        # Using Negative Log Likelihood Loss function for predicting the masked_token
        self.criterion = nn.NLLLoss(ignore_index=0)

        self.log_freq = log_freq

        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))

    def train(self, epoch):
        self.iteration(epoch, self.train_data)

    def test(self, epoch):
        self.iteration(epoch, self.test_data, train=False)

    def iteration(self, epoch, data_loader, train=True):
       
        str_code = "train" if train else "test"

        # Setting the tqdm progress bar
        data_iter = tqdm.tqdm(enumerate(data_loader),
                              desc="EP_%s:%d" % (str_code, epoch),
                              total=len(data_loader),
                              bar_format="{l_bar}{r_bar}")

        avg_loss = 0.0
        total_correct = 0
        total_element = 0

        for i, data in data_iter:
            # 0. batch_data will be sent into the device(GPU or cpu)
            data = {key: value.to(self.device) for key, value in data.items()}

            # 1. forward the next_sentence_prediction and masked_lm model
            next_sent_output, mask_lm_output = self.model.forward(data["bert_input"], data["segment_label"])

            # 2-1. NLL(negative log likelihood) loss of is_next classification result
            next_loss = self.criterion(next_sent_output, data["is_next"])

            # 2-2. NLLLoss of predicting masked token word
            mask_loss = self.criterion(mask_lm_output.transpose(1, 2), data["bert_label"])

            # 2-3. Adding next_loss and mask_loss : 3.4 Pre-training Procedure
            loss = next_loss + mask_loss

            # 3. backward and optimization only in train
            if train:
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()

            # next sentence prediction accuracy
            correct = next_sent_output.argmax(dim=-1).eq(data["is_next"]).sum().item()
            avg_loss += loss.item()
            total_correct += correct
            total_element += data["is_next"].nelement()

            post_fix = {
                "epoch": epoch,
                "iter": i,
                "avg_loss": avg_loss / (i + 1),
                "avg_acc": total_correct / total_element * 100,
                "loss": loss.item()
            }

            if i % self.log_freq == 0:
                data_iter.write(str(post_fix))

        print("EP%d_%s, avg_loss=" % (epoch, str_code), avg_loss / len(data_iter), "total_acc=",
              total_correct * 100.0 / total_element)

    def save(self, epoch, file_path="output/bert_trained.model"):
        output_path = file_path + ".ep%d" % epoch
        torch.save(self.bert.cpu(), output_path)
        self.bert.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path

### BERTLM

In [24]:
class BERTLM(nn.Module):
    """
    BERT Language Model
    Next Sentence Prediction Model + Masked Language Model
    """

    def __init__(self, bert: BERT, vocab_size):
        """
        :param bert: BERT model which should be trained
        :param vocab_size: total vocab size for masked_lm
        """

        super().__init__()
        self.bert = bert
        self.next_sentence = NextSentencePrediction(self.bert.hidden)
        self.mask_lm = MaskedLanguageModel(self.bert.hidden, vocab_size)

    def forward(self, x, segment_label):
        x = self.bert(x, segment_label)
        return self.next_sentence(x), self.mask_lm(x)


class NextSentencePrediction(nn.Module):
    """
    2-class classification model : is_next, is_not_next
    """

    def __init__(self, hidden):
        """
        :param hidden: BERT model output size
        """
        super().__init__()
        self.linear = nn.Linear(hidden, 2)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        return self.softmax(self.linear(x[:, 0]))


class MaskedLanguageModel(nn.Module):
    """
    predicting origin token from masked input sequence
    n-class classification problem, n-class = vocab_size
    """

    def __init__(self, hidden, vocab_size):
        """
        :param hidden: output size of BERT model
        :param vocab_size: total vocab size
        """
        super().__init__()
        self.linear = nn.Linear(hidden, vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        return self.softmax(self.linear(x))


## Start Trainning

In [25]:
print("Creating BERT Trainer")
trainer = BERTTrainer(bert, len(vocab), train_dataloader=train_data_loader, test_dataloader=test_data_loader,
                      lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
                      with_cuda=args.with_cuda, log_freq=args.log_freq)

print("Training Start")
for epoch in range(args.epochs):
    trainer.train(epoch)
    trainer.save(epoch, args.output_path)

    if test_data_loader is not None:
        trainer.test(epoch)

Creating BERT Trainer
Total Parameters: 6327057
Training Start


EP_train:0: 100%|| 1/1 [00:00<00:00,  2.86it/s]

{'epoch': 0, 'iter': 0, 'avg_loss': 2.986823081970215, 'avg_acc': 0.0, 'loss': 2.986823081970215}
EP0_train, avg_loss= 2.986823081970215 total_acc= 0.0
EP:0 Model Saved on: data/corpus.small.vocab.ep0



  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
EP_test:0: 100%|| 1/1 [00:00<00:00,  7.96it/s]

{'epoch': 0, 'iter': 0, 'avg_loss': 4.260068416595459, 'avg_acc': 50.0, 'loss': 4.260068416595459}
EP0_test, avg_loss= 4.260068416595459 total_acc= 50.0



EP_train:1: 100%|| 1/1 [00:00<00:00,  2.61it/s]

{'epoch': 1, 'iter': 0, 'avg_loss': 6.681731224060059, 'avg_acc': 50.0, 'loss': 6.681731224060059}
EP1_train, avg_loss= 6.681731224060059 total_acc= 50.0
EP:1 Model Saved on: data/corpus.small.vocab.ep1



EP_test:1: 100%|| 1/1 [00:00<00:00,  8.81it/s]

{'epoch': 1, 'iter': 0, 'avg_loss': 2.9084110260009766, 'avg_acc': 50.0, 'loss': 2.9084110260009766}
EP1_test, avg_loss= 2.9084110260009766 total_acc= 50.0



EP_train:2: 100%|| 1/1 [00:00<00:00,  1.68it/s]


{'epoch': 2, 'iter': 0, 'avg_loss': 9.12502670288086, 'avg_acc': 0.0, 'loss': 9.12502670288086}
EP2_train, avg_loss= 9.12502670288086 total_acc= 0.0
EP:2 Model Saved on: data/corpus.small.vocab.ep2


EP_test:2: 100%|| 1/1 [00:00<00:00,  4.99it/s]

{'epoch': 2, 'iter': 0, 'avg_loss': 0.13810335099697113, 'avg_acc': 50.0, 'loss': 0.13810335099697113}
EP2_test, avg_loss= 0.13810335099697113 total_acc= 50.0



EP_train:3: 100%|| 1/1 [00:00<00:00,  2.09it/s]


{'epoch': 3, 'iter': 0, 'avg_loss': 2.3999338150024414, 'avg_acc': 50.0, 'loss': 2.3999338150024414}
EP3_train, avg_loss= 2.3999338150024414 total_acc= 50.0
EP:3 Model Saved on: data/corpus.small.vocab.ep3


EP_test:3: 100%|| 1/1 [00:00<00:00,  4.54it/s]

{'epoch': 3, 'iter': 0, 'avg_loss': 1.0868234634399414, 'avg_acc': 0.0, 'loss': 1.0868234634399414}
EP3_test, avg_loss= 1.0868234634399414 total_acc= 0.0



EP_train:4: 100%|| 1/1 [00:00<00:00,  2.06it/s]


{'epoch': 4, 'iter': 0, 'avg_loss': 3.5171289443969727, 'avg_acc': 50.0, 'loss': 3.5171289443969727}
EP4_train, avg_loss= 3.5171289443969727 total_acc= 50.0
EP:4 Model Saved on: data/corpus.small.vocab.ep4


EP_test:4: 100%|| 1/1 [00:00<00:00,  8.06it/s]

{'epoch': 4, 'iter': 0, 'avg_loss': 0.08919079601764679, 'avg_acc': 100.0, 'loss': 0.08919079601764679}
EP4_test, avg_loss= 0.08919079601764679 total_acc= 100.0



EP_train:5: 100%|| 1/1 [00:00<00:00,  2.62it/s]


{'epoch': 5, 'iter': 0, 'avg_loss': 4.4835968017578125, 'avg_acc': 100.0, 'loss': 4.4835968017578125}
EP5_train, avg_loss= 4.4835968017578125 total_acc= 100.0
EP:5 Model Saved on: data/corpus.small.vocab.ep5


EP_test:5: 100%|| 1/1 [00:00<00:00,  9.82it/s]

{'epoch': 5, 'iter': 0, 'avg_loss': 3.861670970916748, 'avg_acc': 50.0, 'loss': 3.861670970916748}
EP5_test, avg_loss= 3.861670970916748 total_acc= 50.0



EP_train:6: 100%|| 1/1 [00:00<00:00,  2.91it/s]

{'epoch': 6, 'iter': 0, 'avg_loss': 0.00042926802416332066, 'avg_acc': 100.0, 'loss': 0.00042926802416332066}
EP6_train, avg_loss= 0.00042926802416332066 total_acc= 100.0
EP:6 Model Saved on: data/corpus.small.vocab.ep6



EP_test:6: 100%|| 1/1 [00:00<00:00,  9.97it/s]

{'epoch': 6, 'iter': 0, 'avg_loss': 0.01838160678744316, 'avg_acc': 0.0, 'loss': 0.01838160678744316}
EP6_test, avg_loss= 0.01838160678744316 total_acc= 0.0



EP_train:7: 100%|| 1/1 [00:00<00:00,  3.27it/s]

{'epoch': 7, 'iter': 0, 'avg_loss': 1.6109089851379395, 'avg_acc': 50.0, 'loss': 1.6109089851379395}
EP7_train, avg_loss= 1.6109089851379395 total_acc= 50.0
EP:7 Model Saved on: data/corpus.small.vocab.ep7



EP_test:7: 100%|| 1/1 [00:00<00:00, 10.21it/s]

{'epoch': 7, 'iter': 0, 'avg_loss': 3.692484140396118, 'avg_acc': 50.0, 'loss': 3.692484140396118}
EP7_test, avg_loss= 3.692484140396118 total_acc= 50.0



EP_train:8: 100%|| 1/1 [00:00<00:00,  3.28it/s]

{'epoch': 8, 'iter': 0, 'avg_loss': 5.374189376831055, 'avg_acc': 50.0, 'loss': 5.374189376831055}
EP8_train, avg_loss= 5.374189376831055 total_acc= 50.0
EP:8 Model Saved on: data/corpus.small.vocab.ep8



EP_test:8: 100%|| 1/1 [00:00<00:00, 13.11it/s]

{'epoch': 8, 'iter': 0, 'avg_loss': 6.1481781005859375, 'avg_acc': 100.0, 'loss': 6.1481781005859375}
EP8_test, avg_loss= 6.1481781005859375 total_acc= 100.0



EP_train:9: 100%|| 1/1 [00:00<00:00,  3.05it/s]


{'epoch': 9, 'iter': 0, 'avg_loss': 0.0006999903125688434, 'avg_acc': 50.0, 'loss': 0.0006999903125688434}
EP9_train, avg_loss= 0.0006999903125688434 total_acc= 50.0
EP:9 Model Saved on: data/corpus.small.vocab.ep9


EP_test:9: 100%|| 1/1 [00:00<00:00, 11.70it/s]

{'epoch': 9, 'iter': 0, 'avg_loss': 2.9055590629577637, 'avg_acc': 100.0, 'loss': 2.9055590629577637}
EP9_test, avg_loss= 2.9055590629577637 total_acc= 100.0



