In [None]:
import os
import random
from pathlib import Path
from datetime import datetime, timezone, timedelta
import numpy as np
import torch

from transformers import ElectraConfig, ElectraTokenizerFast, ElectraForMaskedLM, ElectraForPreTraining

# 1. Configuration

In [None]:
class MyConfig(dict):
    """Dictionary whose items can also be read and written as attributes.

    ``cfg.name`` reads ``cfg['name']`` and ``cfg.name = value`` stores
    ``cfg['name'] = value``. The object is still a regular ``dict``, so it can
    be unpacked (``{**cfg}``), printed, or iterated like one.
    """

    def __getattr__(self, name):
        """Return the item stored under ``name``.

        Python only calls this hook after normal attribute lookup has failed.

        Raises:
            AttributeError: if ``name`` is not a key. Raising ``AttributeError``
                (rather than letting ``KeyError`` escape) is what the attribute
                protocol expects, and it keeps ``hasattr``,
                ``getattr(obj, name, default)``, ``copy.deepcopy`` and pickling
                working, since they all probe for optional attributes.
        """
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        """Store ``value`` as the item ``name`` instead of as an instance attribute."""
        self[name] = value

In [None]:
# Experiment configuration. Every entry can be read as c['key'] or c.key.
c = MyConfig(
    # Where and under which name the run executes.
    device='cuda:0',
    base_run_name='vanilla',  # run_name = {base_run_name}_{seed}
    seed=11081,  # 11081 36 1188 76 1 4 4649 7 # None/False to randomly choose seed from [0,999999]

    # Training-recipe switches; allowed values are validated in the next cell.
    adam_bias_correction=False,
    schedule='original_linear',
    sampling='fp32_gumbel',
    electra_mask_style=True,
    gen_smooth_label=False,
    disc_smooth_label=False,

    # Model size, corpora, logging backend and worker count.
    size='small',
    # datas=['openwebtext'],
    datas=['wikipedia'],
    logger="wandb",
    num_workers=3,
)


""" Vanilla ELECTRA settings
'adam_bias_correction': False,
'schedule': 'original_linear',
'sampling': 'fp32_gumbel',
'electra_mask_style': True,
'gen_smooth_label': False,
'disc_smooth_label': False,
'size': 'small',
'datas': ['openwebtext'],
"""

In [None]:
# Check and Default
# Reject unsupported option values before any expensive work starts.
assert c.sampling in ['fp32_gumbel', 'fp16_gumbel', 'multinomial']
assert c.schedule in ['original_linear', 'separate_linear', 'one_cycle', 'adjusted_one_cycle']
for data in c.datas:
    assert data in ['wikipedia', 'bookcorpus', 'openwebtext']
assert c.logger in ['wandb', 'neptune', None, False]

# Default run name: current time in UTC+8 formatted as MMDDhhmmss.
# strftime is used instead of slicing str(datetime): the slice dropped the
# tens digit of the month (so October-December collided with other months)
# and cut at the wrong place whenever the microsecond field was 0, because
# str() omits the fractional part in that case.
if not c.base_run_name:
    c.base_run_name = datetime.now(timezone(timedelta(hours=+8))).strftime('%m%d%H%M%S')
# Only None/False request a random seed; 0 is a legitimate seed and is kept.
if c.seed is None or c.seed is False:
    c.seed = random.randint(0, 999999)

c.run_name = f'{c.base_run_name}_{c.seed}'

# `True` is shorthand for the default label-smoothing strength of 0.1.
if c.gen_smooth_label is True:
    c.gen_smooth_label = 0.1
if c.disc_smooth_label is True:
    c.disc_smooth_label = 0.1

# Setting of different sizes
# Each tuple below is ordered (small, base, large); `i` selects the column.
i = ['small', 'base', 'large'].index(c.size)
c.update(
    mask_prob=(0.15, 0.15, 0.25)[i],
    lr=(5e-4, 2e-4, 2e-4)[i],
    bs=(128, 256, 2048)[i],
    steps=(10**6, 766*1000, 400*1000)[i],
    max_length=(128, 512, 512)[i],
)
generator_size_divisor = (4, 3, 4)[i]

# Start from the published configs of the matching size.
disc_config = ElectraConfig.from_pretrained(f'google/electra-{c.size}-discriminator')
gen_config = ElectraConfig.from_pretrained(f'google/electra-{c.size}-generator')
# The public electra-small checkpoint is actually small++ and does not scale
# down the generator, so the generator dimensions are derived here from the
# discriminator's by dividing with generator_size_divisor.
gen_config.hidden_size = disc_config.hidden_size // generator_size_divisor
gen_config.num_attention_heads = disc_config.num_attention_heads // generator_size_divisor
gen_config.intermediate_size = disc_config.intermediate_size // generator_size_divisor
hf_tokenizer = ElectraTokenizerFast.from_pretrained(f"google/electra-{c.size}-generator")


# Print info
print(f"process id: {os.getpid()}", c, sep='\n')

# 2. Load Data

In [None]:
# Display the tokenizer loaded in the configuration cell as a quick sanity check.
hf_tokenizer

In [None]:
import datasets
def download_dataset(cache_dir):
    """Load the 'wikipedia' dataset (config '20200501.en') through `datasets`.

    Args:
        cache_dir: directory handed to `datasets.load_dataset` as `cache_dir`.

    Returns:
        The 'train' split of the loaded dataset.
    """
    return datasets.load_dataset('wikipedia', '20200501.en', cache_dir=cache_dir)['train']

# Raw corpus used by the preprocessing step below.
wiki = download_dataset('../data/huggingface_datasets')

In [None]:
from pathlib import Path
from functools import partial
from _utils.electra_dataprocessor import ELECTRADataProcessor
# Root directory for preprocessed-data caches.
data_dir = Path('../data')

def preprocess(wiki, c, hf_tokenizer, num_proc):
    """Run ELECTRADataProcessor over the configured corpora and concatenate them.

    Only 'wikipedia' is handled here; the final assertion fails if `c.datas`
    names any corpus that was not processed.

    Args:
        wiki: the raw dataset to process; its length is part of the cache path.
        c: config providing `datas` and `max_length`.
        hf_tokenizer: tokenizer forwarded to ELECTRADataProcessor.
        num_proc: forwarded to the processor's `map` call.

    Returns:
        The result of `datasets.concatenate_datasets` over the processed corpora.
    """
    make_processor = partial(
        ELECTRADataProcessor, hf_tokenizer=hf_tokenizer, max_length=c.max_length)
    processed = []

    # Wikipedia
    if 'wikipedia' in c.datas:
        # Cache location depends on corpus size and sequence length.
        wiki_cache_dir = data_dir / "preprocess" / f"wiki_{len(wiki)}_{c.max_length}"
        wiki_cache_dir.mkdir(parents=True, exist_ok=True)
        arrow_file = (wiki_cache_dir / "electra.arrow").resolve()
        processed.append(
            make_processor(wiki).map(cache_file_name=str(arrow_file), num_proc=num_proc)
        )

    # Every requested corpus must have produced exactly one processed dataset.
    assert len(processed) == len(c.datas)

    return datasets.concatenate_datasets(processed)

In [None]:
# Build the training dataset; num_proc=16 is forwarded to the processor's map() call inside preprocess().
train_dset = preprocess(wiki, c, hf_tokenizer, num_proc=16)

In [None]:
from get_dataloaders import get_dataloader
# Create the training dataloader with device='cpu'; `dls.to(...)` is called on the chosen device just before fitting.
dl = get_dataloader(c, hf_tokenizer, train_dset, device='cpu')

from fastai.text.all import DataLoaders
# Wrap the single loader in a fastai DataLoaders container.
dls = DataLoaders(dl, path='.')

In [None]:
# Length of the training loader (presumably the number of batches per epoch; confirm against get_dataloader).
len(dls.train)

# 3. Train

In [None]:
# Seed & PyTorch benchmark
# Turn on cuDNN benchmark mode.
# NOTE(review): benchmark mode may make runs non-deterministic despite the seeding below; confirm this trade-off is intended.
torch.backends.cudnn.benchmark = True
dls[0].rng = random.Random(c.seed) # for fastai dataloader
# Seed the Python, NumPy and PyTorch global RNGs with the configured seed.
random.seed(c.seed)
np.random.seed(c.seed)
torch.manual_seed(c.seed)

In [None]:
from models import ELECTRAModel, ELECTRALoss
# Generator and Discriminator
# Both networks are built from configs only (no pretrained weights are loaded here).
generator = ElectraForMaskedLM(gen_config)
discriminator = ElectraForPreTraining(disc_config)
# Share one embeddings module between the discriminator and the generator.
discriminator.electra.embeddings = generator.electra.embeddings
# Tie the generator's LM head weight to its word-embedding matrix.
generator.generator_lm_head.weight = generator.electra.embeddings.word_embeddings.weight

# ELECTRA training loop
# Combine both networks into the project's ELECTRAModel and build the matching loss.
electra_model = ELECTRAModel(generator, discriminator, hf_tokenizer, sampling=c.sampling)
electra_loss_func = ELECTRALoss(gen_label_smooth=c.gen_smooth_label, disc_label_smooth=c.disc_smooth_label)

In [None]:
from pl_model import LitElectra
# Wrap model, loss, tokenizer and config in the project's LitElectra module, which is what the trainer fits.
model = LitElectra(electra_model, electra_loss_func, hf_tokenizer, config=c)

In [None]:
from pytorch_lightning.loggers import WandbLogger
# Log to Weights & Biases, recording a plain-dict copy of the config.
# NOTE(review): c.logger also accepts 'neptune', None and False, but this cell always creates a WandbLogger; confirm whether c.logger should be honoured here.
wandb_logger = WandbLogger(project='electra_pretrain_debug', config={**c})

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint
# Upper bound on epochs; also passed to the Trainer below.
max_epochs = 9999
period = - ( - max_epochs // 4) # ceil(max_epochs / 4) using integer arithmetic
# NOTE(review): `period` is specific to older pytorch-lightning releases; confirm it matches the installed version.
checkpoint_callback = ModelCheckpoint(
    save_last=True, 
    period=period,
    filename='pretrain-electra-{epoch}', # there seems to be no way to pass a % (percentage) here
)

In [None]:
import pytorch_lightning as pl
# Single-GPU trainer with 16-bit precision and gradient clipping at 1.0, logging every step to the W&B logger.
trainer = pl.Trainer(gpus=1, gradient_clip_val=1., precision=16,
                     max_epochs=max_epochs,
                     callbacks=[checkpoint_callback],
                     logger=wandb_logger,
                     log_every_n_steps=1,
                    )

In [None]:
# Move the fastai DataLoaders to the configured device before training.
dls.to(torch.device(c.device))
# Fit on the raw training loader `dl` (the same loader wrapped by `dls`).
trainer.fit(model, dl)