In [1]:
import os
import random
from pathlib import Path
from datetime import datetime, timezone, timedelta
import numpy as np
import torch

from transformers import ElectraConfig, ElectraTokenizerFast, ElectraForMaskedLM, ElectraForPreTraining

In [2]:
# Sanity check: report whether PyTorch can see a CUDA device. The configuration
# below targets 'cuda:0' and the Trainer is created with gpus=1.
torch.cuda.is_available()

True

In [3]:
# Load IPython's autoreload extension; later cells issue `%autoreload 2` before
# importing the project-local modules.
%load_ext autoreload

# 1. Configuration

In [4]:
class MyConfig(dict):
    """Dictionary whose items can also be read and written as attributes.

    ``cfg.name`` is equivalent to ``cfg['name']`` and ``cfg.name = v`` is
    equivalent to ``cfg['name'] = v``; values are always stored as dict items.
    """
    def __getattr__(self, name):
        """Return ``self[name]``.

        Raises:
            AttributeError: if ``name`` is not a key. ``__getattr__`` must raise
                AttributeError (not KeyError) so that ``hasattr``,
                ``getattr(obj, name, default)`` and ``copy.deepcopy`` work.
        """
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None
    def __setattr__(self, name, value):
        """Store ``value`` under the key ``name``."""
        self[name] = value

In [5]:
# Experiment configuration: the user-editable options of this run.
# The following "Check and Default" cell validates these values and derives the
# size-dependent hyper-parameters (mask_prob, lr, bs, steps, max_length) from 'size'.
c = MyConfig({
    'device': 'cuda:0',
    'base_run_name': 'vanilla',  # run_name = {base_run_name}_{seed}
    'seed': 11081,  # 11081 36 1188 76 1 4 4649 7 # None/False to randomly choose seed from [0,999999]

    # Training-recipe switches. 'schedule' and 'sampling' are restricted to the
    # values asserted in the check cell; a smooth-label flag set to True is
    # replaced there by the value 0.1.
    'adam_bias_correction': False,
    'schedule': 'original_linear',
    'sampling': 'fp32_gumbel',
    'electra_mask_style': True,
    'gen_smooth_label': False,
    'disc_smooth_label': False,

    # 'size' is one of 'small' / 'base' / 'large'; it also selects the
    # google/electra-{size}-* configs and tokenizer loaded in the check cell.
    'size': 'small',
    # 'datas' entries must be among 'wikipedia', 'bookcorpus', 'openwebtext'.
#     'datas': ['openwebtext'],
    'datas': ['wikipedia'],
    # 'logger' must be one of 'wandb', 'neptune', None, False.
    'logger': "wandb",
    'num_workers': 3,
})


# Reference note kept as a bare string literal: being the last expression of the
# cell, it is echoed as the cell output. Note that it lists 'openwebtext' as the
# dataset, whereas the active configuration above uses 'wikipedia'.
""" Vanilla ELECTRA settings
'adam_bias_correction': False,
'schedule': 'original_linear',
'sampling': 'fp32_gumbel',
'electra_mask_style': True,
'gen_smooth_label': False,
'disc_smooth_label': False,
'size': 'small',
'datas': ['openwebtext'],
"""

" Vanilla ELECTRA settings\n'adam_bias_correction': False,\n'schedule': 'original_linear',\n'sampling': 'fp32_gumbel',\n'electra_mask_style': True,\n'gen_smooth_label': False,\n'disc_smooth_label': False,\n'size': 'small',\n'datas': ['openwebtext'],\n"

In [6]:
# Check and Default
# Validates the user-editable options in `c`, fills in defaults, derives the
# size-dependent hyper-parameters, builds the generator/discriminator configs
# and the tokenizer, and creates the working directories.
assert c.sampling in ['fp32_gumbel', 'fp16_gumbel', 'multinomial']
assert c.schedule in ['original_linear', 'separate_linear', 'one_cycle', 'adjusted_one_cycle']
for data in c.datas:
    assert data in ['wikipedia', 'bookcorpus', 'openwebtext']
assert c.logger in ['wandb', 'neptune', None, False]

if not c.base_run_name:
    # Default run name: a UTC+8 timestamp formatted as MDDHHMMSS (the tens digit
    # of the month is dropped, exactly as the previous string slicing did).
    # strftime is used instead of slicing str(datetime) because the latter omits
    # the fractional part when microseconds are 0, which shifted the slice.
    c.base_run_name = datetime.now(timezone(timedelta(hours=+8))).strftime('%m%d%H%M%S')[1:]
# Only None/False request a random seed; 0 lies in [0, 999999] and is a
# legitimate explicit seed, so it must not be treated as "unset".
if c.seed is None or c.seed is False:
    c.seed = random.randint(0, 999999)

c.run_name = f'{c.base_run_name}_{c.seed}'

# A smooth-label flag given as literal True means "use the default value 0.1";
# False and explicit numeric values are left untouched.
if c.gen_smooth_label is True:
    c.gen_smooth_label = 0.1
if c.disc_smooth_label is True:
    c.disc_smooth_label = 0.1

# Setting of different sizes
# Each list is indexed by the position of c.size in ['small', 'base', 'large'];
# an unknown size raises ValueError here.
i = ['small', 'base', 'large'].index(c.size)
c.mask_prob = [0.15, 0.15, 0.25][i]
c.lr = [5e-4, 2e-4, 2e-4][i]
c.bs = [128, 256, 2048][i]
c.steps = [10**6, 766*1000, 400*1000][i]
c.max_length = [128, 512, 512][i]
generator_size_divisor = [4, 3, 4][i]

disc_config = ElectraConfig.from_pretrained(f'google/electra-{c.size}-discriminator')
gen_config = ElectraConfig.from_pretrained(f'google/electra-{c.size}-generator')
# note that public electra-small model is actually small++ and don't scale down generator size
# The generator dimensions are therefore derived from the discriminator config,
# scaled down by generator_size_divisor (integer division throughout).
gen_config.hidden_size = disc_config.hidden_size//generator_size_divisor
gen_config.num_attention_heads = disc_config.num_attention_heads//generator_size_divisor
gen_config.intermediate_size = disc_config.intermediate_size//generator_size_divisor
hf_tokenizer = ElectraTokenizerFast.from_pretrained(f"google/electra-{c.size}-generator")

# Path to data
Path('./datasets').mkdir(exist_ok=True)
Path('./checkpoints/pretrain').mkdir(exist_ok=True, parents=True)
edl_cache_dir = Path("./datasets/electra_dataloader")
edl_cache_dir.mkdir(exist_ok=True)

# Print info
print(f"process id: {os.getpid()}")
print(c)

process id: 12434
{'device': 'cuda:0', 'base_run_name': 'vanilla', 'seed': 11081, 'adam_bias_correction': False, 'schedule': 'original_linear', 'sampling': 'fp32_gumbel', 'electra_mask_style': True, 'gen_smooth_label': False, 'disc_smooth_label': False, 'size': 'small', 'datas': ['wikipedia'], 'logger': 'wandb', 'num_workers': 3, 'run_name': 'vanilla_11081', 'mask_prob': 0.15, 'lr': 0.0005, 'bs': 128, 'steps': 1000000, 'max_length': 128}


# 2. Load Data

In [7]:
# Display the tokenizer's repr to confirm which pretrained tokenizer was loaded.
hf_tokenizer

PreTrainedTokenizerFast(name_or_path='google/electra-small-generator', vocab_size=30522, model_max_len=512, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

In [8]:
# Reload edited project modules automatically before executing code.
%autoreload 2
# download_dset is imported from the project-local download_datasets module (not shown here).
from download_datasets import download_dset

# Load (or download) the training data for the datasets listed in c.datas,
# using ./datasets as cache directory and 16 processes.
# NOTE(review): the exact preprocessing and return type are defined in
# download_datasets; confirm there.
train_dset = download_dset(c, hf_tokenizer, cache_dir='./datasets', num_proc=16)

load/download wiki dataset


Reusing dataset wikipedia (./datasets/wikipedia/20200501.en/1.0.0/4021357e28509391eab2f8300d9b689e7e8f3a877ebb3d354b01577d497ebc63)


load/create data from wiki dataset for ELECTRA
cache_file_name 1000_electra_wiki_128.arrow
cache_file_name /home/miyamonz/ghq/github.com/miyamonz/electra_pytorch/pretrain/datasets/wikipedia/20200501.en/1.0.0/4021357e28509391eab2f8300d9b689e7e8f3a877ebb3d354b01577d497ebc63/1000_electra_wiki_128.arrow


Loading cached processed dataset at /home/miyamonz/ghq/github.com/miyamonz/electra_pytorch/pretrain/datasets/wikipedia/20200501.en/1.0.0/4021357e28509391eab2f8300d9b689e7e8f3a877ebb3d354b01577d497ebc63/1000_electra_wiki_128_00000_of_00016.arrow
Loading cached processed dataset at /home/miyamonz/ghq/github.com/miyamonz/electra_pytorch/pretrain/datasets/wikipedia/20200501.en/1.0.0/4021357e28509391eab2f8300d9b689e7e8f3a877ebb3d354b01577d497ebc63/1000_electra_wiki_128_00001_of_00016.arrow
Loading cached processed dataset at /home/miyamonz/ghq/github.com/miyamonz/electra_pytorch/pretrain/datasets/wikipedia/20200501.en/1.0.0/4021357e28509391eab2f8300d9b689e7e8f3a877ebb3d354b01577d497ebc63/1000_electra_wiki_128_00002_of_00016.arrow
Loading cached processed dataset at /home/miyamonz/ghq/github.com/miyamonz/electra_pytorch/pretrain/datasets/wikipedia/20200501.en/1.0.0/4021357e28509391eab2f8300d9b689e7e8f3a877ebb3d354b01577d497ebc63/1000_electra_wiki_128_00003_of_00016.arrow
Loading cached proce

In [9]:
# Reload edited project modules automatically before executing code.
%autoreload 2
# get_dataloader is imported from the project-local get_dataloaders module (not shown here).
from get_dataloaders import get_dataloader
# Build the training dataloader; device='cpu' is passed to the helper.
# NOTE(review): presumably batches stay on CPU so the Trainer moves them to the GPU — confirm in get_dataloaders.
dl = get_dataloader(c, hf_tokenizer, train_dset, device='cpu')

# Wrap the single loader in fastai's DataLoaders; later cells access it as
# dls.train and dls[0].
from fastai.text.all import DataLoaders
dls = DataLoaders(dl, path='.')

train_dset Dataset({
    features: ['first_segment', 'input_ids', 'second_segment', 'sentA_length'],
    num_rows: 20344249
})
HF_Dataset
cols {'input_ids': <class 'fastai.text.data.TensorText'>, 'sentA_length': <function get_dataloader.<locals>.<lambda> at 0x7f31777b0e50>}
n_inp 2
MySortedDL
pad_idx 0
pad_idxs [0, 0]


  return torch.tensor(x, **format_kwargs)


In [10]:
# Length of the training dataloader (presumably the number of batches per epoch — verify against the loader).
len(dls.train)

158940

# 5. Train

In [11]:
# Seed & PyTorch benchmark
# Seeds every random number source used in this notebook with c.seed: the
# fastai dataloader's own RNG, Python's `random`, NumPy and PyTorch.
# NOTE(review): cudnn.benchmark=True lets cuDNN auto-select kernels, which may
# make runs not bit-for-bit reproducible despite the seeding — confirm this is acceptable.
torch.backends.cudnn.benchmark = True
dls[0].rng = random.Random(c.seed) # for fastai dataloader
random.seed(c.seed)
np.random.seed(c.seed)
# Last expression of the cell: its return value (a torch Generator) is echoed as the cell output.
torch.manual_seed(c.seed)

<torch._C.Generator at 0x7f31d91459b0>

In [12]:
# Generator and Discriminator
# Both models are freshly initialised from their configs (no pretrained weights are loaded here).
generator = ElectraForMaskedLM(gen_config)
discriminator = ElectraForPreTraining(disc_config)
# Share one embeddings module: the discriminator's embeddings attribute is
# pointed at the generator's embeddings object.
discriminator.electra.embeddings = generator.electra.embeddings
# Tie the generator's LM head weight to its word-embedding weight.
generator.generator_lm_head.weight = generator.electra.embeddings.word_embeddings.weight

In [13]:
# Reload edited project modules automatically before executing code.
%autoreload 2
# LitElectra is imported from the project-local pl_model module (not shown here).
from pl_model import LitElectra
# Combine generator, discriminator and tokenizer into the module that is handed
# to the Trainer; the sampling mode and the whole config `c` are passed along.
model = LitElectra(generator, discriminator, hf_tokenizer, sampling=c.sampling, config=c)

In [14]:
# Build the experiment logger requested by c.logger. The check cell allows
# 'wandb', 'neptune', None and False, so the choice must actually be honoured
# here instead of always creating a W&B logger.
from pytorch_lightning.loggers import WandbLogger
if c.logger == 'wandb':
    # Name the W&B run after c.run_name ({base_run_name}_{seed}) so runs can be
    # told apart in the dashboard.
    wandb_logger = WandbLogger(name=c.run_name)
elif not c.logger:
    # None/False: passing logger=False to pl.Trainer disables logging.
    wandb_logger = False
else:
    # e.g. 'neptune': accepted by the config check but not wired up in this notebook.
    raise NotImplementedError(f"logger {c.logger!r} is not supported by this notebook")

In [15]:
%autoreload 2
import pytorch_lightning as pl
trainer = pl.Trainer(gpus=1, gradient_clip_val=1., precision=16,
                     logger=wandb_logger,
                     log_every_n_steps=1,
                    )
trainer.fit(model, dl)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.
[34m[1mwandb[0m: Currently logged in as: [33mmiyamonz[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.19 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade



  | Name          | Type                  | Params
--------------------------------------------------------
0 | generator     | ElectraForMaskedLM    | 4.6 M 
1 | discriminator | ElectraForPreTraining | 13.5 M
--------------------------------------------------------
14.2 M    Trainable params
0         Non-trainable params
14.2 M    Total params


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…






1