<a href="https://colab.research.google.com/github/crux82/ganbert-pytorch/blob/main/GANBERT_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GAN-BERT (in Pytorch and compatible with HuggingFace)

This is an implementation in Pytorch (and **HuggingFace**) of the GAN-BERT method from https://github.com/crux82/ganbert which is available in Tensorflow.

While the original GAN-BERT was an extension of BERT, this implementation can be adapted to several architectures, ranging from Roberta to Albert!

**NOTE**: given that this implementation is different from the original one in Tensorflow, some results can be slightly different (but it always improves on the original BERT implementation).


Let's GO!

Required Imports.

In [None]:
!pip install transformers==4.1.1
import tensorflow as tf
import torch
import io
import torch.nn.functional as F
import random
import numpy as np
import time
import math
import datetime
import torch.nn as nn
from transformers import *
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
#!pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 -f https://download.pytorch.org/whl/torch_stable.html
#!pip install sentencepiece

## Seed every random number generator the notebook relies on, so that a
## fresh "Restart & Run All" reproduces the same run.
seed_val = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
  _seed_fn(seed_val)
# The CUDA generators are seeded separately, and only when a GPU is present.
if torch.cuda.is_available():
  torch.cuda.manual_seed_all(seed_val)

Collecting transformers==4.1.1
[?25l  Downloading https://files.pythonhosted.org/packages/50/0c/7d5950fcd80b029be0a8891727ba21e0cd27692c407c51261c3c921f6da3/transformers-4.1.1-py3-none-any.whl (1.5MB)
[K     |████████████████████████████████| 1.5MB 20.6MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
[K     |████████████████████████████████| 890kB 44.3MB/s 
Collecting tokenizers==0.9.4
[?25l  Downloading https://files.pythonhosted.org/packages/fb/36/59e4a62254c5fcb43894c6b0e9403ec6f4238cc2422a003ed2e6279a1784/tokenizers-0.9.4-cp37-cp37m-manylinux2010_x86_64.whl (2.9MB)
[K     |████████████████████████████████| 2.9MB 24.3MB/s 
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Created wheel for sacremoses: filename=sacremoses-0.0.43-cp37-none-any.whl size=893262 sha256=5c9803dbe88

In [None]:
# Choose the compute device once; batches are moved onto it later on.
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_ok else "cpu")
if not _cuda_ok:
    # No accelerator: everything runs on the CPU.
    print('No GPU available, using the CPU instead.')
else:
    # Report how many GPUs are visible and the name of the first one.
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))

There are 1 GPU(s) available.
We will use the GPU: Tesla T4


### Input Parameters


In [None]:
# Every example is truncated / padded to exactly this many token ids
# (including the two special tokens added around the text).
max_seq_length = 64
# Batch size handed to every DataLoader built by generate_data_loader.
batch_size = 64
# AdamW learning rates: the discriminator optimizer also updates the transformer.
learning_rate_discriminator = 2e-5
learning_rate_generator = 2e-5
# Length of the uniform noise vector fed to the Generator.
noise_size = 100
# Small constant added inside the log() terms of the GAN losses to avoid log(0).
epsilon = 1e-8
# Dropout rate passed to both the Generator and the Discriminator.
out_dropout_rate = 0.1
# When True, labeled training examples are replicated by generate_data_loader.
apply_balance = True

# Number of passes over the training DataLoader.
num_train_epochs = 10

# When True, both optimizers follow a linear schedule with warmup.
apply_scheduler = True
# Fraction of the total training steps used for the warmup phase.
warmup_proportion = 0.1

# A progress line is printed every this many training batches.
print_each_n_step = 10
# When True (and CUDA is available) the transformer is wrapped in DataParallel.
multi_gpu = True

# Checkpoint name given to AutoModel / AutoTokenizer / AutoConfig.
model_name = "bert-base-cased"
#model_name = "roberta-base"
#model_name = "albert-base-v2"
#model_name = "xlm-roberta-base"
#model_name = "amazon/bort"

! git clone https://github.com/crux82/ganbert

# Dataset files, read from the local ./ganbert checkout.
labeled_file = "./ganbert/data/labeled.tsv"
unlabeled_file = "./ganbert/data/unlabeled.tsv"
test_filename = "./ganbert/data/test.tsv"

# All class names. The position of a name in this list is its integer class id.
# NOTE(review): "UNK_UNK" is presumably the placeholder label carried by the
# unlabeled examples - confirm against unlabeled.tsv.
label_list = ["UNK_UNK","ABBR_abb", "ABBR_exp", "DESC_def", "DESC_desc", "DESC_manner", "DESC_reason", "ENTY_animal", "ENTY_body", "ENTY_color", "ENTY_cremat", "ENTY_currency", "ENTY_dismed", "ENTY_event", "ENTY_food", "ENTY_instru", "ENTY_lang", "ENTY_letter", "ENTY_other", "ENTY_plant", "ENTY_product", "ENTY_religion", "ENTY_sport", "ENTY_substance", "ENTY_symbol", "ENTY_techmeth", "ENTY_termeq", "ENTY_veh", "ENTY_word", "HUM_desc", "HUM_gr", "HUM_ind", "HUM_title", "LOC_city", "LOC_country", "LOC_mount", "LOC_other", "LOC_state", "NUM_code", "NUM_count", "NUM_date", "NUM_dist", "NUM_money", "NUM_ord", "NUM_other", "NUM_perc", "NUM_period", "NUM_speed", "NUM_temp", "NUM_volsize", "NUM_weight"]

Cloning into 'ganbert'...
remote: Enumerating objects: 74, done.[K
remote: Counting objects: 100% (74/74), done.[K
remote: Compressing objects: 100% (59/59), done.[K
remote: Total 74 (delta 31), reused 49 (delta 15), pack-reused 0[K
Unpacking objects: 100% (74/74), done.


Load the input QC dataset (fine-grained)

In [None]:
def get_qc_examples(input_file):
  """Read a QC dataset file and return its examples.

  The first line of the file is skipped. Every other non-blank line is
  expected to look like ``COARSE:fine question text``: the first
  space-separated field holds the two-part label, the rest of the line is
  the question.

  Args:
    input_file: path of the file to read (decoded as UTF-8).

  Returns:
    A list of ``(question, label)`` tuples, where ``label`` is the two label
    parts joined by an underscore (``"COARSE_fine"``).

  Raises:
    IndexError: if the first field of a non-blank line contains no ``:``.
  """
  examples = []

  # The context manager closes the file; no explicit close() is needed.
  with open(input_file, 'r', encoding='utf-8') as f:
    file_as_list = f.read().splitlines()

  for line in file_as_list[1:]:
    # Blank (or whitespace-only) lines carry no example; skipping them avoids
    # an IndexError on files that contain empty lines.
    if not line.strip():
      continue

    split = line.split(" ")
    question = ' '.join(split[1:])

    inn_split = split[0].split(":")
    label = inn_split[0] + "_" + inn_split[1]
    examples.append((question, label))

  return examples

# Parse the three dataset splits with the same reader, in the original order.
labeled_examples, unlabeled_examples, test_examples = [
    get_qc_examples(split_path)
    for split_path in (labeled_file, unlabeled_file, test_filename)
]

# Size of the whole training pool (labeled + unlabeled) and the fraction of it
# that carries a usable label.
num_train_examples = len(labeled_examples) + len(unlabeled_examples)
label_mask_rate = len(labeled_examples)/num_train_examples

Load the Transformer Model

In [None]:
# Load the pretrained encoder and the tokenizer published under the same
# checkpoint name.
transformer = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Move the encoder to the GPU when one is available.
if torch.cuda.is_available():    
  transformer.cuda()

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=435779157.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=213450.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=435797.0, style=ProgressStyle(descripti…




Functions required to convert examples into Dataloader

In [None]:
def generate_data_loader(input_examples, label_masks, label_map, do_shuffle = False, balance_label_examples = False):
  '''
  Build a DataLoader from (text, label) examples and their label masks.

  Args:
    input_examples: sequence of (text, label) pairs.
    label_masks: one boolean per example; True marks an example whose label
      is used, False marks an example to be treated as NOT labeled.
    label_map: dict mapping every label string to its integer id.
    do_shuffle: use a RandomSampler when True, a SequentialSampler otherwise.
    balance_label_examples: when True, and the global label_mask_rate is
      strictly between 0 and 1, every labeled example is repeated
      max(1, int(log2(int(1 / label_mask_rate)))) times.

  Returns:
    A DataLoader over a TensorDataset whose items are
    (input_ids, attention_mask, label_id, label_mask).

  Relies on the notebook globals tokenizer, max_seq_length, batch_size and
  label_mask_rate.
  '''
  # The replication factor is the same for every labeled example, so it is
  # computed once instead of once per example.
  labeled_copies = 1
  if balance_label_examples and 0 < label_mask_rate < 1:
    labeled_copies = max(1, int(math.log(int(1/label_mask_rate), 2)))

  examples = []
  for index, ex in enumerate(input_examples):
    copies = labeled_copies if label_masks[index] else 1
    examples.extend([(ex, label_masks[index])] * copies)

  # Special tokens and the padding id come from the tokenizer, so that
  # checkpoints that do not use the BERT markers (e.g. RoBERTa's <s> / </s>
  # and padding id 1) are encoded correctly. The BERT values are kept as a
  # fallback for tokenizers that do not define them.
  cls_token = tokenizer.cls_token or "[CLS]"
  sep_token = tokenizer.sep_token or "[SEP]"
  pad_token_id = tokenizer.pad_token_id
  if pad_token_id is None:
    pad_token_id = 0

  input_ids_array = []
  input_mask_array = []
  label_mask_array = []
  label_id_array = []

  for (text, label_mask) in examples:
    # Leave room for the two special tokens.
    tokens_a = tokenizer.tokenize(text[0])[0:(max_seq_length - 2)]
    tokens = [cls_token] + tokens_a + [sep_token]

    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    # 1 for real tokens, 0 for the padding appended below.
    input_mask = [1] * len(input_ids)

    padding_length = max_seq_length - len(input_ids)
    input_ids = input_ids + [pad_token_id] * padding_length
    input_mask = input_mask + [0] * padding_length

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length

    input_ids_array.append(input_ids)
    input_mask_array.append(input_mask)
    label_mask_array.append(label_mask)
    label_id_array.append(label_map[text[1]])

  input_ids_array = torch.tensor(input_ids_array, dtype=torch.long)
  input_mask_array = torch.tensor(input_mask_array, dtype=torch.long)
  # Going through numpy keeps the original dtypes (bool masks, integer ids).
  label_masks = torch.tensor(np.array(label_mask_array))
  label_ids = torch.tensor(np.array(label_id_array))

  dataset = TensorDataset(input_ids_array, input_mask_array, label_ids, label_masks)

  if do_shuffle:
    sampler = RandomSampler
  else:
    sampler = SequentialSampler

  return DataLoader(
              dataset,  # The examples to iterate over.
              sampler = sampler(dataset),
              batch_size = batch_size) # Size of every batch.


def format_time(elapsed):
    '''
    Render a duration expressed in seconds as an h:mm:ss string.

    The value is rounded to a whole number of seconds before formatting.
    '''
    whole_seconds = int(round(elapsed))
    duration = datetime.timedelta(seconds=whole_seconds)
    return str(duration)

Convert the input examples into DataLoader

In [None]:
# Each label string maps to its position in label_list.
label_map = {label: i for (i, label) in enumerate(label_list)}

#------------------------------
#   Load the train dataset
#------------------------------
# Labeled examples come first and get a mask set to True.
train_examples = labeled_examples
train_label_masks = np.ones(len(labeled_examples), dtype=bool)
# Unlabeled examples, when there are any, are appended with a mask set to False.
if unlabeled_examples:
  train_examples = train_examples + unlabeled_examples
  train_label_masks = np.concatenate(
      [train_label_masks, np.zeros(len(unlabeled_examples), dtype=bool)])

train_dataloader = generate_data_loader(
    train_examples,
    train_label_masks,
    label_map,
    do_shuffle = True,
    balance_label_examples = apply_balance)

#------------------------------
#   Load the test dataset
#------------------------------
# Every test example is labeled, so its mask is True; the test set is neither
# shuffled nor balanced.
test_label_masks = np.ones(len(test_examples), dtype=bool)

test_dataloader = generate_data_loader(
    test_examples,
    test_label_masks,
    label_map,
    do_shuffle = False,
    balance_label_examples = False)

We define the Generator and Discriminator as discussed in https://www.aclweb.org/anthology/2020.acl-main.191/

In [None]:
class Generator(nn.Module):
    """Feed-forward network that turns a noise vector into a representation.

    Each hidden layer is Linear -> LeakyReLU(0.2) -> Dropout; a final Linear
    layer projects to ``output_size`` without activation.
    """

    def __init__(self, noise_size=100, output_size=512, hidden_sizes=[512], dropout_rate=0.1):
        super().__init__()
        # Widths of consecutive layers, starting from the noise input.
        widths = [noise_size] + hidden_sizes
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
            stack.append(nn.Dropout(dropout_rate))
        # Output projection from the last hidden width.
        stack.append(nn.Linear(widths[-1], output_size))
        self.layers = nn.Sequential(*stack)

    def forward(self, noise):
        """Return the representation produced for ``noise``."""
        return self.layers(noise)

class Discriminator(nn.Module):
    """Feed-forward classifier over sentence representations.

    Hidden layers are Linear -> LeakyReLU(0.2) -> Dropout. The output layer
    has ``num_labels + 1`` units: one per class plus one extra unit used to
    tell real examples from generated ones.
    """

    def __init__(self, input_size=512, hidden_sizes=[512], num_labels=2, dropout_rate=0.1):
        super().__init__()
        # Widths of consecutive layers, starting from the input representation.
        widths = [input_size] + hidden_sizes
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
            stack.append(nn.Dropout(dropout_rate))

        # Hidden stack; its output is also returned as the feature vector.
        self.layers = nn.Sequential(*stack)
        # +1 for the probability of this sample being fake/real.
        self.logit = nn.Linear(widths[-1], num_labels+1)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input_rep, dropout_rate=0.1):
        """Return ``(features, logits, probabilities)`` for ``input_rep``.

        ``dropout_rate`` is accepted but not used: no dropout is applied to
        the input itself.
        """
        features = self.layers(input_rep)
        scores = self.logit(features)
        return features, scores, self.softmax(scores)

We instantiate the Discriminator and Generator

In [None]:
# The config file is required to get the dimension of the vector produced by
# the underlying transformer
config = AutoConfig.from_pretrained(model_name)
hidden_size = int(config.hidden_size)
print(config)

# Both networks work in the transformer's representation space: the generator
# outputs vectors of size hidden_size and the discriminator consumes them.
generator = Generator(noise_size=noise_size, output_size=hidden_size, hidden_sizes=[hidden_size], dropout_rate=out_dropout_rate)
discriminator = Discriminator(input_size=hidden_size, hidden_sizes=[hidden_size], num_labels=len(label_list), dropout_rate=out_dropout_rate)

if torch.cuda.is_available():
  generator.cuda()
  discriminator.cuda()
  # Wrap the encoder only once: without this guard, re-running the cell would
  # nest a DataParallel inside another DataParallel.
  if multi_gpu and not isinstance(transformer, torch.nn.DataParallel):
    transformer = torch.nn.DataParallel(transformer)

BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "type_vocab_size": 2,
  "vocab_size": 28996
}



Let's go with the training procedure

In [None]:
# Per-epoch statistics, filled in by the training loop.
training_stats = []

# Measure the total training time for the whole run.
total_t0 = time.time()

# Model parameters: the discriminator optimizer also fine-tunes the transformer,
# the generator optimizer only updates the generator.
transformer_vars = list(transformer.parameters())
d_vars = transformer_vars + list(discriminator.parameters())
g_vars = list(generator.parameters())

# Optimizers
dis_optimizer = torch.optim.AdamW(d_vars, lr=learning_rate_discriminator)
gen_optimizer = torch.optim.AdamW(g_vars, lr=learning_rate_generator)

# Schedulers
if apply_scheduler:
  # Both schedulers are stepped once per batch, so the schedule length must be
  # the real number of batches. len(train_dataloader) accounts for the labeled
  # examples replicated by the balancing step, which len(train_examples) does
  # not; counting examples instead made the learning rate reach zero before
  # the last epoch was over.
  num_train_steps = len(train_dataloader) * num_train_epochs
  num_warmup_steps = int(num_train_steps * warmup_proportion)

  scheduler_d = get_linear_schedule_with_warmup(dis_optimizer,
                                           num_warmup_steps = num_warmup_steps,
                                           num_training_steps = num_train_steps)
  scheduler_g = get_linear_schedule_with_warmup(gen_optimizer,
                                           num_warmup_steps = num_warmup_steps,
                                           num_training_steps = num_train_steps)


# Training / evaluation loop.
# Discriminator output layout used throughout: columns [0, len(label_list))
# are the class logits and the LAST column is the "fake" class.
# For each epoch...
for epoch_i in range(0, num_train_epochs):

    # ========================================
    #               Training
    # ========================================

    # Perform one full pass over the training set.
    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, num_train_epochs))
    print('Training...')

    # Measure how long the training epoch takes.
    t0 = time.time()

    # Reset the total loss for this epoch.
    tr_g_loss = 0
    tr_d_loss = 0

    # Put the model into training mode.
    transformer.train()
    generator.train()
    discriminator.train()

    # For each batch of training data...
    for step, batch in enumerate(train_dataloader):

        # Progress update every print_each_n_step batches.
        if step % print_each_n_step == 0 and not step == 0:
            # Elapsed time since the start of the epoch, as h:mm:ss.
            elapsed = format_time(time.time() - t0)

            # Report progress.
            print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(train_dataloader), elapsed))

        # Unpack this training batch from our dataloader.
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_labels = batch[2].to(device)
        b_label_mask = batch[3].to(device)

        # Always clear any previously calculated gradients before performing a backward pass.
        dis_optimizer.zero_grad()
        gen_optimizer.zero_grad()

        noise = torch.zeros(b_input_ids.shape[0],noise_size, device=device).uniform_(0, 1).requires_grad_(True)

        # Transformer: the last element of the model outputs is used as the
        # representation of each example.
        # NOTE(review): presumably the pooled sentence vector - which output
        # this is depends on the chosen architecture; confirm for non-BERT models.
        model_outputs = transformer(b_input_ids, attention_mask=b_input_mask)
        hidden_states = model_outputs[-1]
        # Discriminator on real data
        D_real_features, D_real_logits, D_real_probs = discriminator(hidden_states)

        # Generator
        gen_rep = generator(noise)
        # Discriminator on generated data; detached so that the discriminator
        # loss does not back-propagate into the generator.
        D_fake_features, D_fake_logits, D_fake_probs = discriminator(gen_rep.detach())

        # ---- Discriminator loss ----
        # Class logits are every column except the last one. Slicing with
        # [:, 1:] would make the last class share its column with the "fake"
        # class and would leave column 0 unused.
        logits = D_real_logits[:, 0:-1]
        log_probs = F.log_softmax(logits, dim=-1)
        label2one_hot = F.one_hot(b_labels, len(label_list))
        per_example_loss = -torch.sum(label2one_hot * log_probs, dim=-1)
        # Only examples whose mask is True contribute to the supervised loss.
        per_example_loss = torch.masked_select(per_example_loss, b_label_mask)
        labeled_example_count = per_example_loss.numel()

        if labeled_example_count == 0:
            D_L_Supervised = 0
        else:
            D_L_Supervised = torch.div(torch.sum(per_example_loss), labeled_example_count)

        # Real examples should not be classified as fake, generated ones should.
        D_L_unsupervised1U = -1 * torch.mean(torch.log(1 - D_real_probs[:, -1] + epsilon))
        D_L_unsupervised2U = -1 * torch.mean(torch.log(D_fake_probs[:, -1] + epsilon))
        d_loss = D_L_Supervised + D_L_unsupervised1U + D_L_unsupervised2U

        # ---- Generator loss ----
        # Second pass on the non-detached representation so that gradients
        # reach the generator.
        D_fake_features, D_fake_logits, D_fake_probs = discriminator(gen_rep)
        g_loss = -1 * torch.mean(torch.log(1 - D_fake_probs[:,-1] + epsilon))
        # Feature matching between the mean real and mean generated features.
        g_feat_reg = torch.mean(torch.pow(torch.mean(D_real_features.detach(), dim=0) - torch.mean(D_fake_features, dim=0), 2))
        g_loss = g_loss + g_feat_reg

        # Generator backward
        g_loss.backward()
        gen_optimizer.step()

        # g_loss.backward() also left gradients on the discriminator weights
        # (the generator loss flows through the discriminator). Drop them so
        # that the discriminator is updated with the gradient of d_loss only.
        dis_optimizer.zero_grad()

        # Discriminator backward
        d_loss.backward()
        dis_optimizer.step()

        # Accumulate loss
        tr_g_loss += g_loss.item()
        tr_d_loss += d_loss.item()

        # Update the learning rate.
        if apply_scheduler:
            scheduler_d.step()
            scheduler_g.step()

    # Calculate the average loss over all of the batches.
    avg_train_loss_g = tr_g_loss / len(train_dataloader)
    avg_train_loss_d = tr_d_loss / len(train_dataloader)

    # Measure how long this epoch took.
    training_time = format_time(time.time() - t0)

    print("")
    print("  Average training loss generetor: {0:.2f}".format(avg_train_loss_g))
    print("  Average training loss discriminator: {0:.2f}".format(avg_train_loss_d))
    print("  Training epcoh took: {:}".format(training_time))

    # ========================================
    #     TEST ON THE EVALUATION DATASET
    # ========================================
    # After the completion of each training epoch, measure our performance on
    # our test set.

    print("")
    print("Running Test...")

    t0 = time.time()

    # Put the model in evaluation mode--the dropout layers behave differently
    # during evaluation.
    transformer.eval() #maybe redundant
    discriminator.eval()
    generator.eval()

    # Tracking variables
    total_test_loss = 0

    all_preds = []
    all_labels_ids = []

    # Cross-entropy over the class logits (it applies log-softmax itself).
    nll_loss = torch.nn.CrossEntropyLoss(ignore_index=-1)

    # Evaluate data for one epoch
    for batch in test_dataloader:

        # Unpack this test batch from our dataloader.
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_labels = batch[2].to(device)

        # Tell pytorch not to bother with constructing the compute graph during
        # the forward pass, since this is only needed for backprop (training).
        with torch.no_grad():
            model_outputs = transformer(b_input_ids, attention_mask=b_input_mask)
            hidden_states = model_outputs[-1]
            _, logits, _ = discriminator(hidden_states)
            # Same class columns as in training (the "fake" column is dropped).
            filtered_logits = logits[:, 0:-1]

            # Accumulate the test loss as a plain float. The loss is computed
            # on the logits: feeding softmax probabilities to a log-softmax
            # squashes them and keeps the loss close to log(number of classes).
            total_test_loss += nll_loss(filtered_logits, b_labels).item()

        # Accumulate the predictions and the input labels
        _, preds = torch.max(filtered_logits, 1)
        all_preds += preds.detach().cpu()
        all_labels_ids += b_labels.detach().cpu()

    # Report the final accuracy for this validation run.
    all_preds = torch.stack(all_preds).numpy()
    all_labels_ids = torch.stack(all_labels_ids).numpy()
    test_accuracy = np.sum(all_preds == all_labels_ids) / len(all_preds)
    print("  Accuracy: {0:.2f}".format(test_accuracy))

    # Calculate the average loss over all of the batches.
    avg_test_loss = total_test_loss / len(test_dataloader)

    # Measure how long the validation run took.
    test_time = format_time(time.time() - t0)

    print("  Test Loss: {0:.2f}".format(avg_test_loss))
    print("  Test took: {:}".format(test_time))

    # Record all statistics from this epoch.
    training_stats.append(
        {
            'epoch': epoch_i + 1,
            'Training Loss generator': avg_train_loss_g,
            'Training Loss discriminator': avg_train_loss_d,
            'Valid. Loss': avg_test_loss,
            'Valid. Accur.': test_accuracy,
            'Training Time': training_time,
            'Test Time': test_time
        }
    )


Training...
  Batch    10  of     92.    Elapsed: 0:00:07.
  Batch    20  of     92.    Elapsed: 0:00:13.
  Batch    30  of     92.    Elapsed: 0:00:20.
  Batch    40  of     92.    Elapsed: 0:00:27.
  Batch    50  of     92.    Elapsed: 0:00:34.
  Batch    60  of     92.    Elapsed: 0:00:41.
  Batch    70  of     92.    Elapsed: 0:00:48.
  Batch    80  of     92.    Elapsed: 0:00:55.
  Batch    90  of     92.    Elapsed: 0:01:02.

  Average training loss generetor: 0.11
  Average training loss discriminator: 7.56
  Training epcoh took: 0:01:04

Running Test...
  Accuracy: 0.11
  Test Loss: 3.91
  Test took: 0:00:02

Training...
  Batch    10  of     92.    Elapsed: 0:00:07.
  Batch    20  of     92.    Elapsed: 0:00:14.
  Batch    30  of     92.    Elapsed: 0:00:21.
  Batch    40  of     92.    Elapsed: 0:00:27.
  Batch    50  of     92.    Elapsed: 0:00:34.
  Batch    60  of     92.    Elapsed: 0:00:41.
  Batch    70  of     92.    Elapsed: 0:00:48.
  Batch    80  of     92.    Elap

In [None]:
# Print the statistics recorded for every epoch, one dict per line.
for stat in training_stats:
  print(stat)

print("\nTraining complete!")

# Wall-clock time elapsed since total_t0 was recorded, just before the
# optimizers were created.
print("Total training took {:} (h:mm:ss)".format(format_time(time.time()-total_t0)))

{'epoch': 1, 'Training Loss generator': 0.11162844915752826, 'Training Loss discriminator': 7.564802874689517, 'Valid. Loss': tensor(3.9065, device='cuda:0'), 'Valid. Accur.': 0.11, 'Training Time': '0:01:04', 'Test Time': '0:00:02'}
{'epoch': 2, 'Training Loss generator': 0.10451734479030837, 'Training Loss discriminator': 6.092482727506886, 'Valid. Loss': tensor(3.8542, device='cuda:0'), 'Valid. Accur.': 0.388, 'Training Time': '0:01:03', 'Test Time': '0:00:02'}
{'epoch': 3, 'Training Loss generator': 0.07300473927803662, 'Training Loss discriminator': 5.315364254557568, 'Valid. Loss': tensor(3.8153, device='cuda:0'), 'Valid. Accur.': 0.448, 'Training Time': '0:01:03', 'Test Time': '0:00:02'}
{'epoch': 4, 'Training Loss generator': 0.06840045377612114, 'Training Loss discriminator': 5.114511266998623, 'Valid. Loss': tensor(3.7553, device='cuda:0'), 'Valid. Accur.': 0.424, 'Training Time': '0:01:03', 'Test Time': '0:00:02'}
{'epoch': 5, 'Training Loss generator': 0.06866135958420194, 