<a href="https://colab.research.google.com/github/h5ng/GNN/blob/master/cmod_bert.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install fairseq==0.9 transformers==2.9



In [None]:
# Standard library
import argparse
import csv
import json
import logging
import os
import random
import sys

# Third-party
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from tqdm import tqdm, trange

# transformers 2.x module layout
from transformers import AdamW
from transformers.modeling_bert import BertForMaskedLM, BertOnlyMLMHead
from transformers.tokenization_bert import BertTokenizer

# Colab: mount Google Drive, which holds the datasets and the project's helper code.
from google.colab import drive
drive.mount('/content/drive')

# data_processors lives in the project's bert_aug folder on Drive, so it must be
# put on sys.path (after mounting) before it can be imported.
sys.path.append('/content/drive/MyDrive/transformers-data-augmentation/bert_aug')
from data_processors import get_task_processor

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Use CUDA when torch can see a GPU, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hugging Face checkpoint used for both the tokenizer and the masked LM.
BERT_MODEL = 'bert-base-uncased'

# Root-logger setup: timestamped INFO-level messages, which also surfaces the
# transformers loading messages seen in the outputs.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)
logger = logging.getLogger(__name__)

In [None]:
from argparse import Namespace

# Not used by the visible notebook; kept so this cell defines the same names.
parser = argparse.ArgumentParser()

# Run configuration, packed into a Namespace so the helper functions can read
# it as args.<name>, exactly as they would with parsed command-line arguments.
args = Namespace(
    data_dir='/content/drive/MyDrive/transformers-data-augmentation/datasets/TREC',
    output_dir='aug_data',        # relative to the kernel's working directory
    task_name='trec',
    max_seq_length=64,            # wordpieces incl. [CLS], label and [SEP]
    cache='transformers_cache',   # cache_dir for pretrained downloads
    train_batch_size=8,
    learning_rate=4e-5,
    num_train_epochs=10.0,
    warmup_proportion=0.1,        # NOTE(review): not read by the visible training code
    seed=42,
    sample_num=1,                 # sampled variants per masked sentence
    sample_ratio=7,               # ~1 in 7 word positions is masked when sampling
    gpu=0,                        # NOTE(review): unused; device comes from torch.cuda.is_available()
    temp=1.0,                     # softmax temperature applied before sampling
)

In [None]:
class InputFeatures:
    """A single set of features of data.

    Plain container for one masked-LM example, holding equal-length id lists:

    Attributes:
        init_ids: token ids of the unmasked sequence (padded).
        input_ids: token ids after masking/substitution (padded).
        input_mask: 1 for real tokens, 0 for padding.
        masked_lm_labels: original id at predicted positions, -100 elsewhere.
    """

    def __init__(self, init_ids, input_ids, input_mask, masked_lm_labels):
        # Stored as given; tensor conversion happens later in prepare_data().
        self.init_ids, self.input_ids, self.input_mask, self.masked_lm_labels = (
            init_ids, input_ids, input_mask, masked_lm_labels)
        
def compute_dev_loss(model, dev_dataloader, device=None):
    """Sum the masked-LM loss of `model` over every batch of `dev_dataloader`.

    The model is switched to eval mode (and left there). Autograd is disabled,
    so evaluation builds no graph and uses far less memory.

    Args:
        model: masked LM called as ``model(input_ids=..., attention_mask=...,
            masked_lm_labels=...)`` whose first output is the loss.
        dev_dataloader: iterable of 4-tuples of tensors
            ``(init_ids, input_ids, input_mask, masked_lm_labels)``, as built by
            ``prepare_data``.
        device: where to move each batch. Defaults to the device of the model's
            first parameter (CPU if the model has none).

    Returns:
        float: the per-batch losses summed (not averaged). The value therefore
        scales with the number of dev batches; compare only runs that use the
        same dev set and batch size.
    """
    if device is None:
        # Default to wherever the weights live so inputs and weights agree.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    sum_loss = 0.
    with torch.no_grad():
        for batch in dev_dataloader:
            _, input_ids, input_mask, masked_lm_labels = (t.to(device) for t in batch)
            outputs = model(input_ids=input_ids,
                            attention_mask=input_mask,
                            masked_lm_labels=masked_lm_labels)
            sum_loss += outputs[0].item()
    return sum_loss

def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, seed=12345):
    """Turn labelled examples into masked-LM features, one InputFeatures per example.

    Each example becomes ``[CLS] <label> <text_a wordpieces> [SEP]``. The class
    label is prepended to the text, so the masked LM sees it as context.

    The number of prediction targets is round(15% of the sequence length),
    clamped to [1, max_predictions_per_seq] and limited by the candidate count.
    Targets are drawn from all positions except [CLS], [SEP] and a label token
    at position 1. Each target becomes [MASK] with probability 0.8, stays
    unchanged with probability 0.1, or is replaced by the token at a random
    candidate position of the same sentence with probability 0.1.

    Args:
        examples: objects exposing ``guid``, ``label`` and ``text_a``, presumably
            the processor's InputExample. ``label`` must be a str because it is
            concatenated with ``text_a``.
        label_list: label strings. A token at position 1 found in this list is
            never chosen as a target.
        max_seq_length: length every id list is padded to; longer inputs are
            truncated.
        tokenizer: BERT tokenizer providing ``tokenize`` and
            ``convert_tokens_to_ids``.
        seed: seed for a private ``random.Random``, so masking is reproducible.

    Returns:
        list of InputFeatures with:
            init_ids: unmasked ids.
            input_ids: ids after masking/substitution.
            input_mask: 1 for real tokens, 0 for padding.
            masked_lm_labels: the original id at target positions, -100 elsewhere.
    """

    features = []
    # Masking hyper-parameters: fraction of the sequence to predict, and a hard cap.
    masked_lm_prob = 0.15
    max_predictions_per_seq = 20
    # Private RNG so masking is reproducible for a given seed and independent
    # of the global `random` state.
    rng = random.Random(seed)


    for (ex_index, example) in enumerate(examples):
        # Prepend the class label as plain text; it is tokenized with the sentence.
        modified_example = example.label + " " + example.text_a
        tokens_a = tokenizer.tokenize(modified_example)
        # Reserve room for [CLS] and [SEP]. The label is already inside tokens_a,
        # so "- 3" leaves one spare position rather than reserving one for the label.
        if len(tokens_a) > max_seq_length - 3:
            tokens_a = tokens_a[0:(max_seq_length - 3)]

        # Build [CLS] <label> <sentence wordpieces> [SEP].
        tokens = []
        tokens.append("[CLS]")
        for token in tokens_a:
            tokens.append(token)
        tokens.append("[SEP]")
        # -100 marks positions ignored by the MLM loss; overwritten per chosen target.
        masked_lm_labels = [-100] * max_seq_length
        
        cand_indexes = []
        for (i, token) in enumerate(tokens):
            # Skip [CLS], [SEP] and the prepended label (position 1).
            # NOTE(review): this compares a tokenizer wordpiece with the raw label
            # string. A label that tokenizes differently (case, surrounding
            # whitespace, several pieces) will not match and can be masked. Confirm.
            if token == "[CLS]" or token == "[SEP]" or (token in label_list and i == 1):
                continue
            cand_indexes.append(i)
        
        rng.shuffle(cand_indexes)
        len_cand = len(cand_indexes)
        output_tokens = list(tokens)
        
        # ~15% of all tokens (special tokens included), at least 1, at most max_predictions_per_seq.
        num_to_predict = min(max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob))))
        
        masked_lms_pos = []
        covered_indexes = set()

        # Take targets from the shuffled candidates until num_to_predict is reached.
        for index in cand_indexes:
          if len(masked_lms_pos) >= num_to_predict:
              break
          # Candidate indexes are unique, so this guard never actually skips.
          if index in covered_indexes:
              continue
          covered_indexes.add(index)

          masked_token = None
          
          # 80% of the time, replace with [MASK]
          if rng.random() < 0.8:
              masked_token = "[MASK]"
          else:
              # 10% of the time, keep original
              if rng.random() < 0.5:
                masked_token = tokens[index]
              # 10% of the time, replace with the token at a random candidate
              # position of the same sentence (not a random vocabulary word);
              # this may pick `index` itself.
              else:
                masked_token = tokens[cand_indexes[rng.randint(0, len_cand - 1)]]

          # The label is always the original token's id, whatever was substituted.
          masked_lm_labels[index] = tokenizer.convert_tokens_to_ids([tokens[index]])[0]
          output_tokens[index] = masked_token
          masked_lms_pos.append(index)

        # init_ids keep the original tokens; input_ids carry the substitutions.
        init_ids = tokenizer.convert_tokens_to_ids(tokens)
        input_ids = tokenizer.convert_tokens_to_ids(output_tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length (assumes id 0 is the tokenizer's
        # [PAD]; TODO confirm for other vocabularies).
        while len(input_ids) < max_seq_length:
            init_ids.append(0)
            input_ids.append(0)
            input_mask.append(0)

        assert len(init_ids) == max_seq_length
        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length

        # Log the first two examples for manual inspection.
        if ex_index < 2:
            logger.info("*** Example ***")
            logger.info("guid: %s" % (example.guid))
            logger.info("tokens: %s" % " ".join(
                [str(x) for x in tokens]))
            logger.info("init_ids: %s" % " ".join([str(x) for x in init_ids]))
            logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
            logger.info("masked_lm_labels: %s" % " ".join([str(x) for x in masked_lm_labels]))

        features.append(
            InputFeatures(init_ids=init_ids,
                          input_ids=input_ids,
                          input_mask=input_mask,
                          masked_lm_labels=masked_lm_labels))
    return features

def prepare_data(features):
    """Stack a list of InputFeatures into a TensorDataset.

    Each dataset item is the tuple
    ``(init_ids, input_ids, input_mask, masked_lm_labels)``, all ``torch.long``
    and in that order, which is how the training and sampling loops unpack it.
    """
    field_order = ('init_ids', 'input_ids', 'input_mask', 'masked_lm_labels')
    columns = [
        torch.tensor([getattr(feat, field) for feat in features], dtype=torch.long)
        for field in field_order
    ]
    return TensorDataset(*columns)

def rev_wordpiece(str):
    """Undo WordPiece splitting and strip the special and label tokens.

    Args:
        str: list of token strings laid out as
            ``[CLS] <label> tok_1 ... tok_k [SEP] [PAD]...``, as produced by
            ``convert_ids_to_tokens``. The parameter name is kept for
            compatibility even though it shadows the builtin. The list is not
            modified.

    Returns:
        The sentence as a space-joined string. Every ``##piece`` is glued onto
        the token before it, ``[PAD]`` entries are dropped, and the first two
        tokens ([CLS], label) and the last remaining token ([SEP]) are removed.
    """
    tokens = list(str)  # work on a copy so the caller's list stays intact
    # Walk right-to-left so deleting position i never shifts positions < i.
    for i in range(len(tokens) - 1, 0, -1):
        if tokens[i] == '[PAD]':
            del tokens[i]
        elif tokens[i].startswith('##'):
            tokens[i - 1] += tokens[i][2:]
            # Delete by position: list.remove() would drop the *first* equal
            # token and corrupt sentences with repeated subwords.
            del tokens[i]
    return " ".join(tokens[2:-1])


In [None]:
def train_cmodbert_and_augment(args, example_index):
    """Fine-tune a label-conditioned BERT masked LM and keep the best checkpoint.

    Training sentences are prefixed with their class label (see
    ``convert_examples_to_features``), and BERT learns to fill masked tokens.
    After every epoch the summed dev loss is computed. When it improves, the
    weights are saved to ``<output_dir>/best_cmodbert.pt``. No augmented data
    is produced by this function itself.

    Args:
        args: Namespace with data_dir, output_dir, task_name, max_seq_length,
            cache, train_batch_size, learning_rate, num_train_epochs and seed.
        example_index: index of a training example printed as a sanity check.

    Returns:
        (model, tokenizer, train_data, label_list):
            model: the model as of the last epoch, not necessarily the best
                checkpoint.
            tokenizer: the tokenizer with the label tokens added.
            train_data: the training TensorDataset.
            label_list: the task's labels.
    """
    task_name = args.task_name
    os.makedirs(args.output_dir, exist_ok=True)

    # Seed every RNG the pipeline touches.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    processor = get_task_processor(task_name, args.data_dir)
    label_list = processor.get_labels(task_name)

    # load train and dev data
    train_examples = processor.get_train_examples()
    dev_examples = processor.get_dev_examples()

    # Show one raw example so the label/text layout can be eyeballed.
    example = train_examples[example_index]
    print(example.guid)
    print(example.text_a)
    print(example.text_b)
    print(example.label)

    tokenizer = BertTokenizer.from_pretrained(BERT_MODEL,
                                              do_lower_case=True,
                                              cache_dir=args.cache)
    model = BertForMaskedLM.from_pretrained(BERT_MODEL,
                                            cache_dir=args.cache)

    # Register the labels as tokens so they can be prepended to sentences.
    # NOTE(review): the original author flagged this line as suspicious. The
    # load log prints each "Adding <label>" followed by a line break, which
    # suggests the labels carry trailing whitespace. Confirm what
    # processor.get_labels() returns.
    tokenizer.add_tokens(label_list)
    model.resize_token_embeddings(len(tokenizer))
    # NOTE(review): this swaps the pretrained MLM head for a newly constructed
    # one, discarding its loaded weights (and presumably its tie to the input
    # embeddings). Confirm this is intended.
    model.cls = BertOnlyMLMHead(model.config)

    model.to(device)

    # train data
    train_features = convert_examples_to_features(train_examples, label_list,
                                                  args.max_seq_length, tokenizer, args.seed)
    train_data = prepare_data(train_features)
    train_dataloader = DataLoader(train_data,
                                  sampler=RandomSampler(train_data),
                                  batch_size=args.train_batch_size)

    # dev data
    dev_features = convert_examples_to_features(dev_examples, label_list,
                                                args.max_seq_length, tokenizer, args.seed)
    dev_data = prepare_data(dev_features)
    dev_dataloader = DataLoader(dev_data,
                                sampler=SequentialSampler(dev_data),
                                batch_size=args.train_batch_size)

    num_train_steps = int(len(train_features) / args.train_batch_size * args.num_train_epochs)
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_features))
    logger.info("  Batch size = %d", args.train_batch_size)
    logger.info("  Num steps = %d", num_train_steps)

    # optimizer (no weight decay on biases / LayerNorm parameters)
    # NOTE(review): args.warmup_proportion and num_train_steps drive no LR
    # schedule here, so the learning rate stays constant.
    no_decay = ['bias', 'gamma', 'beta', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=1e-8)

    log_every = 50
    best_dev_loss = float('inf')
    for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
        avg_loss = 0.
        model.train()
        for step, batch in enumerate(train_dataloader):
            _, input_ids, input_mask, masked_lm_labels = (t.to(device) for t in batch)
            outputs = model(input_ids=input_ids,
                            attention_mask=input_mask,
                            masked_lm_labels=masked_lm_labels)
            loss = outputs[0]
            loss.backward()
            avg_loss += loss.item()
            optimizer.step()
            model.zero_grad()
            if (step + 1) % log_every == 0:
                print("avg_loss: {}".format(avg_loss / log_every))
                # Reset only after reporting, so the average covers the whole
                # window instead of just the last step.
                avg_loss = 0.

        # eval on dev after every epoch
        dev_loss = compute_dev_loss(model, dev_dataloader)
        print("Epoch {}, Dev loss {}".format(epoch, dev_loss))
        if dev_loss < best_dev_loss:
            best_dev_loss = dev_loss
            print("Saving model. Best dev so far {}".format(best_dev_loss))
            save_model_path = os.path.join(args.output_dir, 'best_cmodbert.pt')
            torch.save(model.state_dict(), save_model_path)

    return model, tokenizer, train_data, label_list

In [None]:
# Fine-tune on the training split; training example 1 is printed as a sanity check.
model, tokenizer, train_data, label_list = train_cmodbert_and_augment(args, example_index=1)

03/11/2021 03:59:13 - INFO - transformers.tokenization_utils -   loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at transformers_cache/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
03/11/2021 03:59:13 - INFO - transformers.configuration_utils -   loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at transformers_cache/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.7156163d5fdc189c3016baca0775ffce230789d7fa2a42ef516483e4ca884517
03/11/2021 03:59:13 - INFO - transformers.configuration_utils -   Model config BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
 

/content/drive/MyDrive/transformers-data-augmentation/datasets/TREC
train-1
How long is human gestation ?
None
5


03/11/2021 03:59:13 - INFO - transformers.modeling_utils -   loading weights file https://cdn.huggingface.co/bert-base-uncased-pytorch_model.bin from cache at transformers_cache/f2ee78bdd635b758cc0a12352586868bef80e47401abe4c4fcc3832421e7338b.36ca03ab34a1a5d5fa7bc3d03d55c4fa650fed07220e2eeebc06ce58d0e9a157
03/11/2021 03:59:17 - INFO - transformers.modeling_utils -   Weights of BertForMaskedLM not initialized from pretrained model: ['cls.predictions.decoder.bias']
03/11/2021 03:59:17 - INFO - transformers.modeling_utils -   Weights from pretrained model not used in BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
03/11/2021 03:59:17 - INFO - transformers.tokenization_utils -   Adding 0
 to the vocabulary
03/11/2021 03:59:17 - INFO - transformers.tokenization_utils -   Adding 1
 to the vocabulary
03/11/2021 03:59:17 - INFO - transformers.tokenization_utils -   Adding 2
 to the vocabulary
03/11/2021 03:59:17 - INFO - transformers.tokenization_utils -   Adding 

inf
avg_loss: 0.14749713897705077
avg_loss: 0.1124178409576416
avg_loss: 0.1293332099914551
avg_loss: 0.1456339931488037
avg_loss: 0.13964256286621093
avg_loss: 0.09378203392028808
avg_loss: 0.08656529426574706
avg_loss: 0.07420523166656494
avg_loss: 0.07092886924743652
avg_loss: 0.10040390014648437
avg_loss: 0.046237034797668455
avg_loss: 0.10412364959716797
Epoch 0, Dev loss 321.07247614860535
Saving model. Best dev so far 321.07247614860535


Epoch:  10%|█         | 1/10 [01:36<14:24, 96.05s/it]

avg_loss: 0.07889072895050049
avg_loss: 0.06286872386932373
avg_loss: 0.06840457439422608
avg_loss: 0.10257291793823242
avg_loss: 0.05823854923248291
avg_loss: 0.06943430423736573
avg_loss: 0.08183790206909179
avg_loss: 0.07886219501495362
avg_loss: 0.0828990650177002
avg_loss: 0.0563795804977417
avg_loss: 0.052818312644958496
avg_loss: 0.04927875518798828
Epoch 1, Dev loss 308.95835864543915
Saving model. Best dev so far 308.95835864543915


Epoch:  20%|██        | 2/10 [03:12<12:48, 96.02s/it]

avg_loss: 0.05796962738037109
avg_loss: 0.05191439151763916
avg_loss: 0.05106105327606201
avg_loss: 0.06225635528564453
avg_loss: 0.052340970039367676
avg_loss: 0.04236915588378906
avg_loss: 0.028450374603271485
avg_loss: 0.06262572288513184
avg_loss: 0.05459034919738769
avg_loss: 0.04307582855224609
avg_loss: 0.04313146114349365
avg_loss: 0.06143080234527588
Epoch 2, Dev loss 303.44585132598877
Saving model. Best dev so far 303.44585132598877


Epoch:  30%|███       | 3/10 [04:48<11:12, 96.02s/it]

avg_loss: 0.028343071937561037
avg_loss: 0.03181040048599243
avg_loss: 0.03698539733886719
avg_loss: 0.03525476455688477
avg_loss: 0.05465151786804199
avg_loss: 0.03251964092254639
avg_loss: 0.038784761428833005
avg_loss: 0.055782651901245116
avg_loss: 0.044208769798278806
avg_loss: 0.031033742427825927
avg_loss: 0.015695006847381593
avg_loss: 0.038946750164031985
Epoch 3, Dev loss 302.87026047706604
Saving model. Best dev so far 302.87026047706604


Epoch:  40%|████      | 4/10 [06:23<09:36, 96.01s/it]

avg_loss: 0.024553453922271727
avg_loss: 0.026608052253723143
avg_loss: 0.020166246891021727
avg_loss: 0.03499277114868164
avg_loss: 0.01984520435333252
avg_loss: 0.010063151121139527
avg_loss: 0.02372631072998047
avg_loss: 0.019359079599380494
avg_loss: 0.03165896654129028
avg_loss: 0.02938525676727295
avg_loss: 0.01866966962814331
avg_loss: 0.037000298500061035
Epoch 4, Dev loss 301.32491624355316
Saving model. Best dev so far 301.32491624355316


Epoch:  50%|█████     | 5/10 [08:00<08:00, 96.05s/it]

avg_loss: 0.00367341548204422
avg_loss: 0.013570754528045655
avg_loss: 0.0069470185041427615
avg_loss: 0.01803031325340271
avg_loss: 0.020314912796020507
avg_loss: 0.011440095901489257
avg_loss: 0.018651736974716185
avg_loss: 0.01015558123588562
avg_loss: 0.009552507400512696
avg_loss: 0.009389273524284363
avg_loss: 0.01786839485168457
avg_loss: 0.03133715629577637


Epoch:  60%|██████    | 6/10 [09:34<06:22, 95.56s/it]

Epoch 5, Dev loss 302.43590664863586
avg_loss: 0.00635937511920929
avg_loss: 0.011653796434402466
avg_loss: 0.011780924797058105
avg_loss: 0.006674959659576416
avg_loss: 0.009018615484237671
avg_loss: 0.004539644420146942
avg_loss: 0.01310280203819275
avg_loss: 0.004178353250026703
avg_loss: 0.012614812850952149
avg_loss: 0.01304574966430664
avg_loss: 0.004071841835975647
avg_loss: 0.010236523151397704


Epoch:  70%|███████   | 7/10 [11:08<04:45, 95.20s/it]

Epoch 6, Dev loss 305.63548243045807
avg_loss: 0.004345260858535767
avg_loss: 0.003992502093315125
avg_loss: 0.0050913399457931515
avg_loss: 0.006967033743858337
avg_loss: 0.006056959629058838
avg_loss: 0.0037680011987686157
avg_loss: 0.009455618262290955
avg_loss: 0.007782201766967773
avg_loss: 0.002526736855506897
avg_loss: 0.007161176204681397
avg_loss: 0.005452333092689514
avg_loss: 0.0030105677247047424


Epoch:  80%|████████  | 8/10 [12:43<03:09, 94.95s/it]

Epoch 7, Dev loss 308.733332157135
avg_loss: 0.0045415091514587405
avg_loss: 0.0034392452239990234
avg_loss: 0.0035231485962867737
avg_loss: 0.0014805065095424652
avg_loss: 0.004940434098243713
avg_loss: 0.002244567275047302
avg_loss: 0.0023155677318572997
avg_loss: 0.0024924999475479125
avg_loss: 0.003725219368934631
avg_loss: 0.002823907732963562
avg_loss: 0.005153459310531616
avg_loss: 0.001724792569875717


Epoch:  90%|█████████ | 9/10 [14:17<01:34, 94.76s/it]

Epoch 8, Dev loss 311.9006907939911
avg_loss: 0.0017856520414352417
avg_loss: 0.0020895926654338837
avg_loss: 0.0010587909817695617
avg_loss: 0.001425417810678482
avg_loss: 0.0017901211977005005
avg_loss: 0.0009386606514453888
avg_loss: 0.0010705384612083434
avg_loss: 0.001415014863014221
avg_loss: 0.0010418861359357834
avg_loss: 0.002231907546520233
avg_loss: 0.0019370584189891814
avg_loss: 0.0009833110123872756


Epoch: 100%|██████████| 10/10 [15:51<00:00, 95.20s/it]

Epoch 9, Dev loss 312.17053961753845





In [None]:
# Generate augmented sentences with the best checkpoint. For every training
# sentence, replace a random subset of its word positions with [MASK] and
# sample replacements from the label-conditioned MLM. Write rows of
# (label, sentence) to a TSV file.
train_sampler = SequentialSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

best_model_path = os.path.join(args.output_dir, "best_cmodbert.pt")
if os.path.exists(best_model_path):
    # map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
    model.load_state_dict(torch.load(best_model_path, map_location=device))
else:
    raise ValueError("Unable to find the saved model at {}".format(best_model_path))

save_train_path = os.path.join(args.output_dir, "cmodbert_aug.tsv")
MASK_id = tokenizer.convert_tokens_to_ids(['[MASK]'])[0]

# Inference only: disable dropout and autograd bookkeeping.
model.eval()
# `with` guarantees the TSV is flushed and closed even if sampling fails midway;
# newline='' is what the csv module expects for files it writes.
with open(save_train_path, 'w', newline='') as save_train_file, torch.no_grad():
    tsv_writer = csv.writer(save_train_file, delimiter='\t')
    for batch in train_dataloader:
        batch = tuple(t.to(device) for t in batch)
        init_ids, _, input_mask, _ = batch
        input_lens = [int(mask.sum().item()) for mask in input_mask]
        # Each row is [CLS] label tok_1 .. tok_k [SEP] [PAD]..., with l real tokens.
        # Sample positions from [2, l - 1) so neither the label nor [SEP] is
        # masked; an empty sentence (l == 3) falls back to position 2.
        # Keep one int array per row: np.squeeze over these breaks on ragged
        # batches and on a final batch of size 1.
        masked_idx = [
            np.random.randint(2, max(l - 1, 3), max((l - 2) // args.sample_ratio, 1))
            for l in input_lens
        ]

        for ids, idx in zip(init_ids, masked_idx):
            ids[idx] = MASK_id

        outputs = model(input_ids=init_ids, attention_mask=input_mask)
        # assumes outputs[0] is (batch, seq_len, vocab), matching softmax over dim=2
        predictions = F.softmax(outputs[0] / args.temp, dim=2)

        for ids, idx, preds in zip(init_ids, masked_idx, predictions):
            # multinomial -> (seq_len, sample_num). Keep the masked rows, then
            # transpose so each row is one complete set of replacements.
            sampled = torch.multinomial(preds, args.sample_num, replacement=True)[idx].t()
            for pred in sampled:
                ids[idx] = pred
                new_str = tokenizer.convert_ids_to_tokens(ids.cpu().tolist())
                label = new_str[1]
                new_str = rev_wordpiece(new_str)
                tsv_writer.writerow([label, new_str])

  return array(a, dtype, copy=False, order=order)


In [None]:
# Sanity check: featurize a single training example (index 1) and inspect the
# logged masking. Pass the full label list, as training does, so the
# "never mask the prepended label" rule behaves the same as during training.
processor2 = get_task_processor(args.task_name, args.data_dir)
train_examples2 = processor2.get_train_examples()
train_features2 = convert_examples_to_features(train_examples2[1:2], label_list,
                                               args.max_seq_length, tokenizer, args.seed)


03/11/2021 04:15:42 - INFO - __main__ -   *** Example ***
03/11/2021 04:15:42 - INFO - __main__ -   guid: train-1
03/11/2021 04:15:42 - INFO - __main__ -   tokens: [CLS] 5 how long is human ge ##station ? [SEP]
03/11/2021 04:15:42 - INFO - __main__ -   init_ids: 101 1019 2129 2146 2003 2529 16216 20100 1029 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
03/11/2021 04:15:42 - INFO - __main__ -   input_ids: 101 1019 2129 2146 103 103 16216 20100 1029 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
03/11/2021 04:15:42 - INFO - __main__ -   input_mask: 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
03/11/2021 04:15:42 - INFO - __main__ -   masked_lm_labels: -100 -100 -100 -100 2003 2529 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -10