In [1]:
import math
import os
import random

import numpy as np
import tokenizers
import torch
from google.cloud import storage
from tokenizers import BertWordPieceTokenizer
from torch.utils.data import Dataset, TensorDataset, DataLoader
from torch.utils.data.sampler import RandomSampler
from transformers import BertTokenizer

In [10]:
# --- Run configuration -------------------------------------------------------
seq_length = 128        # token ids per stored example; also used as the models' max position count
accum_multipler = 1     # number of micro-batches accumulated per optimizer step
batch_size = 128        # examples per DataLoader batch
epochs = 1
warmup_ratio = 0.06     # fraction of the optimizer steps used for learning-rate warm-up
lr = 5e-4

# The token files are flat arrays of int32 ids with `seq_length` ids per
# example, so one example occupies seq_length * 4 bytes. (The divisor used to be
# batch_size * 4, which only gave the right count while batch_size == seq_length.)
bytes_per_example = seq_length * 4
data_size = os.stat("/mnt/d/data_masked_%s"%seq_length).st_size // bytes_per_example

# Derived schedule lengths.
num_batches = int(math.ceil(data_size / batch_size))
tot_num_steps   = int(math.ceil((data_size / batch_size / accum_multipler)  * epochs))
warmup_steps = int(tot_num_steps * warmup_ratio)
data_size

18746072

In [12]:
# Echo the run configuration so it is recorded alongside the notebook output.
run_summary = [
    ('num_batches:    ', num_batches),
    ('data_size:      ', data_size),
    ('seq_length:     ', seq_length),
    ('lr:             ', lr),
    ('epochs:         ', epochs),
    ('tot_num_steps:  ', tot_num_steps),
    ('warmup_steps:   ', warmup_steps),
]
for caption, value in run_summary:
    print(caption, value)

data_size:       3279552
seq_length:      128
lr:              0.0005
epochs:          2
tot_num_steps:   292906
warmup_steps:    17574


In [13]:
# Use CUDA when PyTorch can see a GPU, otherwise fall back to the CPU, and
# report which of the two was picked.
has_cuda = torch.cuda.is_available()
device = torch.device("cuda" if has_cuda else "cpu")
device_report = ('GPU:', torch.cuda.get_device_name(0)) if has_cuda else ('CPU',)
print(*device_report)

GPU: Tesla V100-SXM2-16GB


In [14]:
# Build the WordPiece tokenizer from the vocabulary file stored next to the notebook.
tokenizer = BertWordPieceTokenizer(vocab_file = 'tokenizer/vocab.txt')
# Register "<nl>" as a special token (presumably the corpus' line-break marker — confirm).
tokenizer.add_special_tokens(["<nl>"])
# Make every encoding exactly `seq_length` ids long: cut longer inputs, pad shorter ones.
tokenizer.enable_truncation(max_length=seq_length)
tokenizer.enable_padding(length=seq_length)

In [15]:
# Paths of the two token files for the chosen sequence length: the original
# ids and the masked copy of the same examples.
data_original_fn, data_masked_fn = (
    "/mnt/d/data_original_%s" % seq_length,
    "/mnt/d/data_masked_%s" % seq_length,
)

In [16]:
import random

# Pick one stored example at random for a visual sanity check of the masking.
# The index is drawn from the real number of examples (previously a fixed
# 0..100000 range, which reads past the end of smaller files and never looks
# beyond the first 100001 examples of larger ones).
i = random.randrange(data_size)

# Each example is `seq_length` int32 ids, hence 4 bytes per id.
row_offset = seq_length * i * 4

with open(data_original_fn, "rb") as f:
    data = torch.tensor(np.fromfile(f, dtype=np.int32, count=seq_length, offset=row_offset))

with open(data_masked_fn, "rb") as f:
    data_masked = torch.tensor(np.fromfile(f, dtype=np.int32, count=seq_length, offset=row_offset))

In [17]:
from termcolor import colored
# Positions where the masked file differs from the original are the prediction
# targets: they carry the original id, every other position carries -100.
is_masked = data != data_masked
labels = torch.full_like(data, -100)
labels[is_masked] = data[is_masked]

# 1 for real tokens, 0 where the id is 0 (presumably [PAD] — confirm against the vocab).
attention_mask = (data != 0).to(data.dtype)


def _print_tokens(token_ids, token_labels, colour, skip_padding=True, show_label=False):
    """Print one sequence as space-separated tokens, colouring labelled positions.

    token_ids / token_labels: 1-D integer tensors of equal length.
    colour: termcolor colour name used where the label is >= 0.
    skip_padding: drop positions whose id is 0.
    show_label: at labelled positions print the token named by the label
        instead of the token named by the id.
    """
    # .tolist() yields plain Python ints, which is what the tokenizer expects.
    for token_id, label in zip(token_ids.tolist(), token_labels.tolist()):
        if skip_padding and not token_id:
            continue
        text = tokenizer.id_to_token(token_id)
        if label >= 0:
            text = colored(tokenizer.id_to_token(label) if show_label else text, colour)
        print(text, end=" ")


# 1) Original sequence, padding included, masked positions in red.
_print_tokens(data, labels, 'red', skip_padding=False)
print()
print()

# 2) Original sequence without padding, target tokens in blue.
_print_tokens(data, labels, 'blue', show_label=True)
print()
print()

# 3) Masked sequence without padding, replaced positions in red.
_print_tokens(data_masked, labels, 'red')


[CLS] the full . she was watching him and he grinned at her . ‘ stop [31mfishing[0m for compliment ##s . ’ <nl> melan ##ie grinned back , her head on one side . ‘ so [31m,[0m what about you and fi ##zz beau ##mont then ? ’ <nl> it was his turn to hesitate . ‘ what [31mdo[0m you [31mmean[0m ? ’ <nl> she threw him a look full of misc ##hi ##ef [31m.[0m ‘ i asked her if [31myou[0m [31m’[0m d had [31ma[0m row or something . ’ <nl> ‘ [31mthat[0m [31mwas[0m very rude of you . ’ <nl> ‘ probably [31m.[0m but i wanted to know . ’ <nl> ‘ and what did she say [31m?[0m ’ <nl> melan ##ie ’ s eyes gle [SEP] 

[CLS] the full . she was watching him and he grinned at her . ‘ stop [34mfishing[0m for compliment ##s . ’ <nl> melan ##ie grinned back , her head on one side . ‘ so [34m,[0m what about you and fi ##zz beau ##mont then ? ’ <nl> it was his turn to hesitate . ‘ what [34mdo[0m you [34mmean[0m ? ’ <nl> she threw him a look full of misc ##hi ##ef [34m.[0m ‘ i asked 

In [18]:
class textDataset(Dataset):
    """Map-style dataset over the paired original / masked token files.

    Example ``i`` is the ``i``-th run of ``seq_length`` int32 ids in
    ``data_original_fn`` and in ``data_masked_fn`` (module-level names that are
    looked up on every access).

    Each item is a 4-tuple of int64 tensors of length ``seq_length``:
    ``(masked_ids, labels, attention_mask, original_ids)`` where

    * ``labels`` holds the original id wherever the two files differ and -100
      everywhere else (-100 is PyTorch's default cross-entropy ``ignore_index``);
    * ``attention_mask`` is 1 where the original id is non-zero, else 0.
    """

    def __init__(self, size):
        """size: number of examples to expose (the value returned by ``len``)."""
        self.size = size

    def __len__(self):
        return self.size

    @staticmethod
    def _read_example(path, i):
        """Return the ``i``-th example of ``path`` as an int32 tensor.

        Raises ValueError if the file ends before a full example could be read.
        """
        # 4 bytes per int32 id, seq_length ids per example.
        ids = np.fromfile(path, dtype=np.int32, count=seq_length, offset=seq_length * i * 4)
        if ids.size != seq_length:
            raise ValueError(
                "short read for example %d of %s: expected %d ids, got %d"
                % (i, path, seq_length, ids.size)
            )
        return torch.from_numpy(ids)

    def __getitem__(self, i):
        # Without this check an out-of-range index silently produced a short
        # (or empty) example, and plain iteration over the dataset never
        # stopped because IndexError was never raised.
        if not 0 <= i < self.size:
            raise IndexError("index %d out of range for dataset of size %d" % (i, self.size))

        data = self._read_example(data_original_fn, i)
        data_masked = self._read_example(data_masked_fn, i)

        # Only positions that were actually changed by masking become targets.
        changed = data != data_masked
        labels = torch.full_like(data, -100)
        labels[changed] = data[changed]

        attention_mask = data != 0

        return data_masked.long(), labels.long(), attention_mask.long(), data.long()

In [19]:
# Shuffled loader over every stored example.
dataset = textDataset(data_size)
dataloader = DataLoader(dataset, batch_size = batch_size, shuffle=True)

# Number of examples that contribute to one optimizer step.
print('Actual batch size:', batch_size * accum_multipler)

# torch.cuda.device_count() is 0 on a CPU-only machine, which used to raise
# ZeroDivisionError here even though the notebook supports a CPU fallback.
num_replicas = max(torch.cuda.device_count(), 1)
print('Batch size per GPU per pass:', batch_size // num_replicas)

Actual batch size: 128
Batch size per GPU per pass: 16


In [21]:
from transformers import ElectraForMaskedLM, ElectraForPreTraining
from transformers import ElectraConfig
import torch.nn as nn

# Settings that the generator and the discriminator have in common.
shared_config_kwargs = dict(
    max_position_embeddings=seq_length,
    num_hidden_layers=12,
    vocab_size=50000,
    embedding_size=128,
)

# The generator is the narrower network (hidden size 64, one attention head);
# the discriminator is four times wider (hidden size 256, four heads).
generator_config = ElectraConfig(
    hidden_size=64,
    intermediate_size=256,
    num_attention_heads=1,
    **shared_config_kwargs,
)
discriminator_config = ElectraConfig(
    hidden_size=256,
    intermediate_size=1024,
    num_attention_heads=4,
    **shared_config_kwargs,
)

# Wrap each freshly initialised model in DataParallel and move it to the
# selected device. The generator is built first, as before.
generator = nn.DataParallel(ElectraForMaskedLM(config=generator_config))
generator.to(device)
discriminator = nn.DataParallel(ElectraForPreTraining(config=discriminator_config))
discriminator.to(device)

DataParallel(
  (module): ElectraForPreTraining(
    (electra): ElectraModel(
      (embeddings): ElectraEmbeddings(
        (word_embeddings): Embedding(50000, 128, padding_idx=0)
        (position_embeddings): Embedding(128, 128)
        (token_type_embeddings): Embedding(2, 128)
        (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (embeddings_project): Linear(in_features=128, out_features=256, bias=True)
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=256, out_features=256, bias=True)
                (key): Linear(in_features=256, out_features=256, bias=True)
                (value): Linear(in_features=256, out_features=256, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfO

In [22]:
# Tie the input embeddings: from here on the discriminator's ELECTRA encoder
# holds the very same embeddings module object as the generator's, so the two
# models share (rather than copy) those parameters.
discriminator.module.electra.embeddings = generator.module.electra.embeddings

In [24]:
from transformers import get_linear_schedule_with_warmup
from transformers import AdamW
# If the discriminator has been given the generator's embedding module, those
# parameters appear in both models' .parameters(). Handing them to two
# optimizers would step (and weight-decay) them twice per update with two
# independent sets of Adam statistics, so each parameter is assigned to exactly
# one optimizer: shared ones stay with the generator. When nothing is shared
# this filter is a no-op.
generator_params = list(generator.parameters())
generator_param_ids = {id(p) for p in generator_params}
discriminator_params = [
    p for p in discriminator.parameters() if id(p) not in generator_param_ids
]

generator_optimizer = AdamW(
    generator_params, betas=(0.9, 0.999),
    lr = lr,
    weight_decay=0.01)
discriminator_optimizer = AdamW(
    discriminator_params, betas=(0.9, 0.999),
    lr = lr,
    weight_decay=0.01)

# NOTE(review): total_steps is not used below; the schedulers are sized with
# tot_num_steps, which also accounts for gradient accumulation.
total_steps = len(dataloader) * epochs

# Linear warm-up for warmup_steps, then linear decay over tot_num_steps.
generator_scheduler = get_linear_schedule_with_warmup(generator_optimizer,
                                            num_warmup_steps = warmup_steps,
                                            num_training_steps = tot_num_steps)
discriminator_scheduler = get_linear_schedule_with_warmup(discriminator_optimizer,
                                            num_warmup_steps = warmup_steps,
                                            num_training_steps = tot_num_steps)

In [25]:
import time
import datetime

def format_time(elapsed):
    """Render a duration given in seconds as an ``h:mm:ss`` string.

    The value is rounded to the nearest whole second first; the text is
    whatever ``str(datetime.timedelta)`` produces for that many seconds.
    """
    whole_seconds = int(round(elapsed))
    duration = datetime.timedelta(seconds=whole_seconds)
    return str(duration)

In [None]:

# Joint ELECTRA-style training: the generator fills in the masked positions,
# the discriminator learns to spot which tokens the generator changed.
total_t0 = time.time()
log_every = 200  # print running-average losses every this many batches

for epoch_i in range(0, epochs):

    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
    print('Training...')

    t0 = time.time()

    generator_train_loss = 0
    discriminator_train_loss = 0
    batches_since_report = 0

    generator.train()
    discriminator.train()
    generator.zero_grad()
    discriminator.zero_grad()

    last_step = len(dataloader) - 1
    for step, batch in enumerate(dataloader):
        # An optimizer step is taken after every `accum_multipler` micro-batches
        # and on the final batch, so trailing gradients are not thrown away.
        # (The old test `step % accum == 0 and step != 0` put accum + 1
        # micro-batches into the first update.)
        is_update_step = (step + 1) % accum_multipler == 0 or step == last_step

        #generator
        generator_input = batch[0].to(device)      # masked token ids
        generator_labels = batch[1].to(device)     # original id at masked positions, -100 elsewhere
        generator_mask = batch[2].to(device)       # attention mask
        generator_original = batch[3].to(device)   # unmasked token ids

        # Index the output instead of tuple-unpacking it so that both plain
        # tuples and dict-style model outputs are handled.
        generator_outputs = generator(generator_input, attention_mask=generator_mask, labels=generator_labels)
        # Averages whatever loss values the wrapped model returned into a scalar.
        generator_loss = generator_outputs[0].mean()
        generator_scores = generator_outputs[1]
        generator_train_loss += generator_loss.item()
        # Scale so that the accumulated gradient is the mean over micro-batches.
        (generator_loss / accum_multipler).backward()
        if is_update_step:
            # Clip only once the gradient for this update is fully accumulated.
            torch.nn.utils.clip_grad_norm_(generator.parameters(), 1.0)

        #discriminator
        # At the masked positions feed the generator's most likely token,
        # everywhere else the original token.
        discriminator_input = torch.where(generator_labels>=0, torch.argmax(generator_scores,dim=2), generator_original)
        # Target is 1 where the token now differs from the original, 0 otherwise
        # (a correct generator guess therefore counts as "original").
        discriminator_labels = torch.where(discriminator_input==generator_original,
                                           torch.zeros_like(generator_original), torch.ones_like(generator_original))
        discriminator_mask = generator_mask

        discriminator_outputs = discriminator(discriminator_input,
                                              attention_mask=discriminator_mask, labels=discriminator_labels)
        discriminator_loss = discriminator_outputs[0].mean()
        discriminator_scores = discriminator_outputs[1]
        discriminator_train_loss += discriminator_loss.item()
        (discriminator_loss / accum_multipler).backward()
        if is_update_step:
            torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1.0)

        if is_update_step:
            generator_optimizer.step()
            generator_scheduler.step()
            discriminator_optimizer.step()
            discriminator_scheduler.step()
            generator.zero_grad()
            discriminator.zero_grad()

        batches_since_report += 1
        if step % log_every == 0 and not step == 0:
            elapsed = format_time(time.time() - t0)
            # Divide by the number of batches actually summed (the first window
            # contains one more batch than the later ones).
            print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.    Generator Loss: {:.3f}.    Discriminator Loss: {:.3f}.'
                  .format(step,
                          len(dataloader),
                          elapsed,
                          generator_train_loss / batches_since_report,
                          discriminator_train_loss / batches_since_report))
            generator_train_loss = 0
            discriminator_train_loss = 0
            batches_since_report = 0
            



Training...




  Batch   200  of  25,622.    Elapsed: 0:03:17.    Generator Loss: 54.284.    Discriminator Loss: 2.914.
  Batch   400  of  25,622.    Elapsed: 0:06:02.    Generator Loss: 52.188.    Discriminator Loss: 1.462.
  Batch   600  of  25,622.    Elapsed: 0:08:48.    Generator Loss: 48.621.    Discriminator Loss: 0.640.
  Batch   800  of  25,622.    Elapsed: 0:11:35.    Generator Loss: 44.140.    Discriminator Loss: 0.409.
  Batch 1,000  of  25,622.    Elapsed: 0:14:24.    Generator Loss: 40.509.    Discriminator Loss: 0.550.
  Batch 1,200  of  25,622.    Elapsed: 0:17:08.    Generator Loss: 38.512.    Discriminator Loss: 0.437.
  Batch 1,400  of  25,622.    Elapsed: 0:19:54.    Generator Loss: 37.989.    Discriminator Loss: 0.432.
  Batch 1,600  of  25,622.    Elapsed: 0:22:44.    Generator Loss: 37.774.    Discriminator Loss: 0.426.
  Batch 1,800  of  25,622.    Elapsed: 0:25:30.    Generator Loss: 37.747.    Discriminator Loss: 0.428.
  Batch 2,000  of  25,622.    Elapsed: 0:28:16.    Gene

In [33]:
# Persist both networks. torch.save is given the model objects themselves
# (the DataParallel wrappers), not just their state dicts.
checkpoints = (
    (generator, 'electra_small_generator_%s.pth'),
    (discriminator, 'electra_small_discriminator_%s.pth'),
)
for model, filename_template in checkpoints:
    torch.save(model, filename_template % seq_length)