In [1]:
import warnings
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.data_parallel as dp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils
import warnings

warnings.filterwarnings("ignore")

In [2]:
import torch
from google.cloud import storage
import tokenizers
from transformers import BertTokenizer
from tokenizers import BertWordPieceTokenizer
from torch.utils.data import Dataset, TensorDataset, DataLoader
from torch.utils.data.sampler import RandomSampler
import numpy as np
import random
import torch.nn as nn
import math
import os

In [3]:
#set hyperparameters
seq_length = 128
accum_multipler = 1
batch_size = 128
epochs = 1
warmup_ratio = 0.06
lr = 5e-4

data_size = os.stat("/mnt/d/data_masked_%s"%seq_length).st_size // (batch_size*4)

num_batches = int(math.ceil(data_size / batch_size))
tot_num_steps   = int(math.ceil((data_size / batch_size / accum_multipler)  * epochs))
warmup_steps = int(tot_num_steps * warmup_ratio)
data_size

1379612

In [4]:
print('num_batches:    ', num_batches)
print('data_size:      ', data_size)
print('seq_length:     ', seq_length)
print('lr:             ', lr)
print('epochs:         ', epochs)
print('tot_num_steps:  ', tot_num_steps)
print('warmup_steps:   ', warmup_steps)

num_batches:     10779
data_size:       1379612
seq_length:      128
lr:              0.0005
epochs:          1
tot_num_steps:   10779
warmup_steps:    646


In [5]:
#initialize device
import os
os.environ['TPU_IP_ADDRESS']= "10.240.178.50"
os.environ['XRT_TPU_CONFIG'] = "tpu_worker;0;10.240.178.50:8470"
device = xm.xla_device()
device

device(type='xla', index=1)

In [6]:
#initialize tokenizer
tokenizer = BertWordPieceTokenizer(vocab_file = 'tokenizer/vocab.txt')
tokenizer.add_special_tokens(["<nl>"])
tokenizer.enable_truncation(max_length=seq_length)
tokenizer.enable_padding(length=seq_length)

In [7]:
#initialize data path
data_original_fn = "/mnt/d/data_original_%s"%seq_length
data_masked_fn   = "/mnt/d/data_masked_%s"%seq_length

In [13]:
import random
i = random.randint(0, 100000)
with open(data_original_fn, "rb") as f:
    data = torch.tensor(np.fromfile(f,dtype=np.int32, count=seq_length, offset=seq_length*i*4))
    
with open(data_masked_fn, "rb") as f:
    data_masked = torch.tensor(np.fromfile(f,dtype=np.int32, count=seq_length, offset=seq_length*i*4))
    


In [55]:
#example
from termcolor import colored
tensor = torch.zeros(())
labels = tensor.new_full(data.shape, -100).int()
labels[data!=data_masked] = data[data!=data_masked]

attention_mask = torch.where(data!=0, torch.ones_like(data), torch.zeros_like(data))


for id, label in zip(data, labels):
    if not id:
        continue
    token = tokenizer.id_to_token(id)
    if label >= 0:
        token = colored(tokenizer.id_to_token(label), 'blue')
    print(token, end=" ")
print()
print()
for id, label in zip(data_masked, labels):
    if not id:
        continue
    token = tokenizer.id_to_token(id)
    if label >= 0:
        token = colored(token,'red')
    print(token, end=" ")


[CLS] [34m#[0m 20 l : 好 簡 單 因 爲 強 國 vs 日 本 [34m香[0m [34m港[0m [34m人[0m 咁 拎 崇 日 點 會 話 日 本 嘢 唔 好 ？ [ sosad ] [ sosad ] <nl> # 21 m : d 度 縮 窮 閪 結 婚 ， [34m想[0m 一 圍 執 多 千 幾 蚊 ， 就 搵 [34m環[0m [34m保[0m 做 藉 口 冇 魚 翅 食 ， hiauntie 啦 ， 我 真 [34m係[0m 唔 [34m會[0m [34m比[0m 佢 地 得 逞 🤡 <nl> # 22 f : # 21 你 都 傻 豬 既 [34m🤡[0m [34m🤡[0m 🤡 你 講 到 所 有 環 保 拎 都 係 窮 閪 咁 😒 我 老 母 [34m已[0m [34m經[0m [34m同[0m 我 講 結 婚 唔 好 食 [SEP] 

[CLS] [31m[MASK][0m 20 l : 好 簡 單 因 爲 強 國 vs 日 本 [31m[MASK][0m [31m[MASK][0m [31m[MASK][0m 咁 拎 崇 日 點 會 話 日 本 嘢 唔 好 ？ [ sosad ] [ sosad ] <nl> # 21 m : d 度 縮 窮 閪 結 婚 ， [31m[MASK][0m 一 圍 執 多 千 幾 蚊 ， 就 搵 [31m[MASK][0m [31m[MASK][0m 做 藉 口 冇 魚 翅 食 ， hiauntie 啦 ， 我 真 [31m[MASK][0m 唔 [31m[MASK][0m [31m[MASK][0m 佢 地 得 逞 🤡 <nl> # 22 f : # 21 你 都 傻 豬 既 [31m[MASK][0m [31m[MASK][0m 🤡 你 講 到 所 有 環 保 拎 都 係 窮 閪 咁 😒 我 老 母 [31m[MASK][0m [31m[MASK][0m [31m[MASK][0m 我 講 結 婚 唔 好 食 [SEP] 

In [56]:
#define dataset
class textDataset(Dataset):
    def __init__(self, size):
        self.size = size
    def __len__(self):
        return self.size
    def __getitem__(self,i):
        with open(data_original_fn, "rb") as f:
            data = torch.tensor(np.fromfile(f,dtype=np.int32, count=seq_length, offset=seq_length*i*4))

        with open(data_masked_fn, "rb") as f:
            data_masked = torch.tensor(np.fromfile(f,dtype=np.int32, count=seq_length, offset=seq_length*i*4))
        
        attention_mask = torch.where(data!=0, torch.ones_like(data), torch.zeros_like(data))
        
        tensor = torch.zeros(())
        labels = tensor.new_full(data.shape, -100).int()
        labels[data!=data_masked] = data[data!=data_masked]
        
        unmask_no = int(round(attention_mask.sum().item()*0.15*0.1))
        unmask_indices = torch.randint(0,labels.shape[0],(unmask_no,))
        labels[unmask_indices] = data[unmask_indices]
                      
        return data_masked.long(), labels.long(), attention_mask.long(), data.long()

In [57]:
#set hyperparameters for network
from transformers import ElectraForMaskedLM, ElectraForPreTraining
from transformers import ElectraConfig
import torch.nn as nn

generator_config = ElectraConfig(
    max_position_embeddings=seq_length,
    num_hidden_layers=12,
    vocab_size=50000,
    embedding_size=128,
    hidden_size = 64,
    intermediate_size = 256,
    num_attention_heads=1,
)
discriminator_config = ElectraConfig(
    max_position_embeddings=seq_length,
    num_hidden_layers=12,
    vocab_size=50000,
    embedding_size=128,
    hidden_size=256,
    intermediate_size=1024,
    num_attention_heads=4,
)

generator = ElectraForMaskedLM(config=generator_config)
generator.to(device)
discriminator = ElectraForPreTraining(config=discriminator_config)
discriminator.to(device)
discriminator.electra.embeddings = generator.electra.embeddings

In [18]:
#initialize dataloader and sampler
dataset = textDataset(data_size)
sampler = torch.utils.data.distributed.DistributedSampler(
      dataset,
      num_replicas=xm.xrt_world_size(),
      rank=xm.get_ordinal(),
      shuffle=True)
dataloader = DataLoader(dataset, batch_size = batch_size, sampler=sampler)

In [16]:
#initialize optimizer and scheduler
from transformers import get_linear_schedule_with_warmup
from transformers import AdamW
generator_optimizer = AdamW(
    generator.parameters(), betas=(0.9, 0.999), 
    lr = lr, 
    weight_decay=0.01)
discriminator_optimizer = AdamW(
    discriminator.parameters(), betas=(0.9, 0.999), 
    lr = lr, 
    weight_decay=0.01)

total_steps = len(dataloader) * epochs
generator_scheduler = get_linear_schedule_with_warmup(generator_optimizer, 
                                            num_warmup_steps = warmup_steps,
                                            num_training_steps = tot_num_steps)
discriminator_scheduler = get_linear_schedule_with_warmup(discriminator_optimizer, 
                                            num_warmup_steps = warmup_steps,
                                            num_training_steps = tot_num_steps)

In [17]:
import time
import datetime

def format_time(elapsed):
    '''
    Takes a time in seconds and returns a string hh:mm:ss
    '''
    # Round to the nearest second.
    elapsed_rounded = int(round((elapsed)))
    
    # Format as hh:mm:ss
    return str(datetime.timedelta(seconds=elapsed_rounded))

In [None]:
#train
total_t0 = time.time()
for epoch_i in range(0, epochs):
    
    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
    print('Training...')

    t0 = time.time()

    generator_train_loss = 0
    discriminator_train_loss = 0

    generator.train()
    discriminator.train()
    generator.zero_grad()
    discriminator.zero_grad()
    for step, batch in enumerate(dataloader):
        
        #generator
        generator_input = batch[0].to(device)
        generator_labels = batch[1].to(device)
        generator_mask = batch[2].to(device)
        generator_original = batch[3].to(device)
        
        generator_loss, generator_scores = generator(generator_input, attention_mask=generator_mask, labels=generator_labels)
        generator_loss = generator_loss.mean()
        generator_train_loss += generator_loss.item()
        generator_loss.backward()
        torch.nn.utils.clip_grad_norm_(generator.parameters(), 1.0)
        
        #discriminator
        discriminator_input = torch.where(generator_labels>=0, torch.argmax(generator_scores,dim=2), generator_original)
        discriminator_labels = torch.where(discriminator_input==generator_original, 
                                           torch.zeros_like(generator_original), torch.ones_like(generator_original))
        discriminator_mask = generator_mask
        
        
        discriminator_loss, discriminator_scores = discriminator(discriminator_input, 
                                                    attention_mask=discriminator_mask, labels=discriminator_labels)
        discriminator_loss = discriminator_loss.mean()
        discriminator_train_loss += discriminator_loss.item()
        discriminator_loss.backward()
        torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1.0)
        
        if step % accum_multipler == 0 and (accum_multipler == 1 or step != 0):
            xm.optimizer_step(generator_optimizer)
            generator_scheduler.step()
            xm.optimizer_step(discriminator_optimizer)
            discriminator_scheduler.step()
            generator.zero_grad()
            discriminator.zero_grad()
        
        if step % 1 == 0 and not step == 0:
            elapsed = format_time(time.time() - t0)
            
            xm.master_print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.    Generator Loss: {:.3f}.    Discriminator Loss: {:.3f}.'
                  .format(step, 
                          len(dataloader), 
                          elapsed, 
                          generator_train_loss/1, discriminator_train_loss/1))

            generator_train_loss = 0
            discriminator_train_loss = 0


Training...
  Batch     1  of  10,779.    Elapsed: 0:03:49.    Generator Loss: 21.684.    Discriminator Loss: 1.401.
  Batch     2  of  10,779.    Elapsed: 0:07:53.    Generator Loss: 10.847.    Discriminator Loss: 0.694.
  Batch     3  of  10,779.    Elapsed: 0:19:42.    Generator Loss: 10.851.    Discriminator Loss: 0.689.
  Batch     4  of  10,779.    Elapsed: 0:31:49.    Generator Loss: 10.836.    Discriminator Loss: 0.683.
