In [None]:
import os
import random
from pathlib import Path
from datetime import datetime, timezone, timedelta
import numpy as np
import torch

from transformers import ElectraConfig, ElectraTokenizerFast, ElectraForMaskedLM, ElectraForPreTraining

In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# .py modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# 1. Configuration

In [None]:
class MyConfig(dict):
    """Dictionary whose items can also be read and written as attributes.

    ``cfg.key`` returns ``cfg['key']`` and ``cfg.key = value`` stores
    ``cfg['key'] = value``. Names that already exist on ``dict`` (for example
    ``items`` or ``keys``) still resolve to the dict attribute when read,
    because ``__getattr__`` only runs after normal attribute lookup fails.
    """

    def __getattr__(self, name):
        """Return the item stored under ``name``.

        Raises:
            AttributeError: if ``name`` is not a key. Raising AttributeError
                (rather than letting KeyError escape) keeps ``hasattr``,
                ``getattr(obj, name, default)`` and ``copy.deepcopy`` working,
                since they all rely on AttributeError to detect a missing
                attribute.
        """
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        """Store ``value`` as the dictionary item ``name``."""
        self[name] = value

In [None]:
# Experiment configuration. The option names set here are validated, and
# further derived settings are added, by the "Check and Default" cell below.
c = MyConfig(
    device='cuda:0',
    base_run_name='vanilla',  # run_name = {base_run_name}_{seed}
    seed=11081,  # 11081 36 1188 76 1 4 4649 7 # None/False to randomly choose seed from [0,999999]

    adam_bias_correction=False,
    schedule='original_linear',
    sampling='fp32_gumbel',
    electra_mask_style=True,
    gen_smooth_label=False,
    disc_smooth_label=False,

    size='small',
    # datas=['openwebtext'],
    datas=['wikipedia'],
    logger="wandb",
    num_workers=3,
)


""" Vanilla ELECTRA settings
'adam_bias_correction': False,
'schedule': 'original_linear',
'sampling': 'fp32_gumbel',
'electra_mask_style': True,
'gen_smooth_label': False,
'disc_smooth_label': False,
'size': 'small',
'datas': ['openwebtext'],
"""

In [None]:
# Check and Default
# Fail fast on unsupported option values.
assert c.sampling in ['fp32_gumbel', 'fp16_gumbel', 'multinomial'], f"unsupported sampling: {c.sampling!r}"
assert c.schedule in ['original_linear', 'separate_linear', 'one_cycle', 'adjusted_one_cycle'], f"unsupported schedule: {c.schedule!r}"
for data in c.datas:
    assert data in ['wikipedia', 'bookcorpus', 'openwebtext'], f"unsupported dataset: {data!r}"
assert c.logger in ['wandb', 'neptune', None, False], f"unsupported logger: {c.logger!r}"

if not c.base_run_name:
    # Default run name: current time in UTC+8 formatted as MDDHHMMSS.
    # strftime is used instead of slicing str(datetime) because str() omits the
    # fractional part when microsecond == 0, which made a fixed slice return garbage.
    # NOTE(review): the leading digit of the month is dropped (kept for
    # compatibility with existing run names), so e.g. January and November
    # produce the same prefix -- confirm whether this is intended.
    c.base_run_name = datetime.now(timezone(timedelta(hours=+8))).strftime('%m%d%H%M%S')[1:]
# Only None/False request a random seed; 0 is a legitimate seed and must be kept.
if c.seed is None or c.seed is False:
    c.seed = random.randint(0, 999999)

c.run_name = f'{c.base_run_name}_{c.seed}'

# `True` is shorthand for the default label-smoothing value of 0.1;
# any other value (False or an explicit number) is left untouched.
if c.gen_smooth_label is True:
    c.gen_smooth_label = 0.1
if c.disc_smooth_label is True:
    c.disc_smooth_label = 0.1

# Setting of different sizes
# Every list below holds one value per model size, in the order
# (small, base, large); `i` picks the entry that matches c.size.
i = ['small', 'base', 'large'].index(c.size)
for _key, _by_size in [
    ('mask_prob', [0.15, 0.15, 0.25]),
    ('lr', [5e-4, 2e-4, 2e-4]),
    ('bs', [128, 256, 2048]),
    ('steps', [10**6, 766*1000, 400*1000]),
    ('max_length', [128, 512, 512]),
]:
    setattr(c, _key, _by_size[i])
generator_size_divisor = [4, 3, 4][i]

# Start from the published configs for the chosen size.
disc_config = ElectraConfig.from_pretrained(f'google/electra-{c.size}-discriminator')
gen_config = ElectraConfig.from_pretrained(f'google/electra-{c.size}-generator')
# The public electra-small model is actually small++ and does not scale down
# its generator, so the generator dimensions are derived from the
# discriminator's dimensions here.
gen_config.hidden_size = int(
    disc_config.hidden_size / generator_size_divisor
)
gen_config.num_attention_heads = (
    disc_config.num_attention_heads // generator_size_divisor
)
gen_config.intermediate_size = (
    disc_config.intermediate_size // generator_size_divisor
)
hf_tokenizer = ElectraTokenizerFast.from_pretrained(f"google/electra-{c.size}-generator")

# Path to data
# Create the working directories if they do not exist yet.
Path('./datasets').mkdir(exist_ok=True)
Path('./checkpoints/pretrain').mkdir(parents=True, exist_ok=True)
edl_cache_dir = Path("./datasets/electra_dataloader")
edl_cache_dir.mkdir(exist_ok=True)

# Print info
print(f"process id: {os.getpid()}")
print(c)

# 1. Load Data

In [None]:
# Display the tokenizer loaded in the configuration cell as this cell's output.
hf_tokenizer

In [None]:
%autoreload 2
# Project-local helper; imported here rather than in the top import cell.
from download_datasets import download_dset

# Build the training dataset from the config and tokenizer.
# NOTE(review): presumably downloads and preprocesses the corpora listed in
# c.datas using 16 worker processes -- confirm in download_datasets.py.
train_dset = download_dset(c, hf_tokenizer, num_proc=16)

In [None]:
%autoreload 2
# Project-local helper; imported here rather than in the top import cell.
from get_dataloaders import get_dataloaders

# Wrap the training dataset in dataloaders; the next cell reads `dls.train`.
# NOTE(review): batch size / worker count are presumably taken from `c`
# (c.bs, c.num_workers) -- confirm in get_dataloaders.py.
dls = get_dataloaders(c, hf_tokenizer, train_dset)

In [None]:
# Display the length of the training dataloader
# (presumably the number of batches per epoch -- TODO confirm).
len(dls.train)

In [None]:
# Display the dataset's `cache_files` attribute
# (presumably the on-disk cache files backing train_dset -- TODO confirm).
train_dset.cache_files