In [None]:
%run main.py
%load_ext autoreload
%autoreload 2

!mkdir -p {DATA_DIR} {RUBER_DIR} {MODEL_DIR}
s3 = S3()

In [None]:
# Fetch the test/train splits from S3, skipping files already present locally.
# Each file is checked on its own: checking only TEST would leave TRAIN missing
# forever if an earlier run was interrupted between the two downloads.
for s3_key, local_path in [(S3_TEST, TEST), (S3_TRAIN, TRAIN)]:
    if not exists(local_path):
        s3.download(s3_key, local_path)

In [None]:
# Fetch the pretrained RuBERT vocab and weight files from S3, skipping files
# already present locally. Each file is checked individually: checking only the
# vocab would skip weights that an interrupted earlier run never downloaded.
for s3_key, local_path in [
    (S3_RUBERT_VOCAB, RUBERT_VOCAB),
    (S3_RUBERT_EMB, RUBERT_EMB),
    (S3_RUBERT_ENCODER, RUBERT_ENCODER),
    (S3_RUBERT_MLM, RUBERT_MLM),
]:
    if not exists(local_path):
        s3.download(s3_key, local_path)

In [None]:
# Read every line of the RuBERT vocab file eagerly into a list, then wrap
# that list in a BERTVocab used by the encoder below.
items = [item for item in load_lines(RUBERT_VOCAB)]
vocab = BERTVocab(items)

In [None]:
# Target device for the model and for every batch (used by the `.to(device)`
# calls in later cells). CUDA0 comes from main.py; presumably the first CUDA
# GPU (TODO confirm).
device = CUDA0

In [None]:
# RuBERT-base hyperparameters. These must match the pretrained checkpoints
# loaded at the end of this cell. vocab_size is hard-coded and presumably equals
# the number of entries in RUBERT_VOCAB (TODO confirm against `vocab`).
config = BERTConfig(
    vocab_size=50106,
    seq_len=512,
    emb_dim=768,
    layers_num=12,
    heads_num=12,
    hidden_dim=3072,
    dropout=0.1,
    norm_eps=1e-12
)
emb = BERTEmbedding(
    config.vocab_size, config.seq_len, config.emb_dim,
    config.dropout, config.norm_eps
)
# Freeze the position embeddings: training runs on sequences much shorter than
# config.seq_len, so only the leading positions would otherwise be updated.
# Assigning `emb.position.requires_grad = False` silently does nothing when
# `position` is an nn.Module (it only sets a plain Python attribute);
# requires_grad_(False) works for both a Module and a Parameter.
emb.position.requires_grad_(False)
encoder = BERTEncoder(
    config.layers_num, config.emb_dim, config.heads_num, config.hidden_dim,
    config.dropout, config.norm_eps
)
mlm = BERTMLMHead(config.emb_dim, config.vocab_size)
model = BERTMLM(emb, encoder, mlm)

# Initialise each sub-module from its pretrained RuBERT checkpoint, then move
# the assembled model to the training device.
load_model(model.emb, RUBERT_EMB)
load_model(model.encoder, RUBERT_ENCODER)
load_model(model.mlm, RUBERT_MLM)
model = model.to(device)

In [None]:
# Fix the random seeds so the stochastic steps that follow (e.g. the encoder's
# shuffling, dropout) are reproducible. `seed` comes from main.py and is
# presumably random.seed (TODO confirm).
SEED = 1
torch.manual_seed(SEED)
seed(SEED)

In [None]:
encode = BERTMLMEncoder(vocab, seq_len=128, batch_size=32, shuffle_size=10000)

# Evaluation data: encode the whole TEST file once and keep every batch on the
# device in a list, so each evaluation pass iterates over the same batches.
lines = load_lines(TEST)
batches = encode(lines)
test_batches = list(map(lambda test_batch: test_batch.to(device), batches))

# Training data: a lazy, single-pass stream; each batch is moved to the device
# only when the training loop pulls it.
lines = load_lines(TRAIN)
batches = encode(lines)
train_batches = (train_batch.to(device) for train_batch in batches)

In [None]:
# Metric board for this run, split into one section for training metrics and
# one for evaluation metrics (train section is created first, as before).
board = Board(BOARD_NAME, RUNS_DIR)
train_board, test_board = [board.section(name) for name in (TRAIN_BOARD, TEST_BOARD)]

In [None]:
# Adam over all model parameters, wrapped by mixed-precision AMP at opt level
# O2. The scheduler is attached to the optimizer returned by amp.initialize and
# multiplies the learning rate by 0.999 on every scheduler.step() call.
optimizer = optim.Adam(model.parameters(), lr=1e-4)
model, optimizer = amp.initialize(model, optimizer, opt_level=O2)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)

In [None]:
train_meter = MLMScoreMeter()
test_meter = MLMScoreMeter()

accum_steps = 64  # 64 micro-batches x batch_size 32 = 2K sequences per update
log_steps = 256
eval_steps = 512
save_steps = eval_steps * 10

model.train()
optimizer.zero_grad()

# Count micro-batches from 1 so that `every(step, n)` fires after exactly n
# micro-batches. Counting from 0 makes the first window one batch off
# (65 batches if `every` skips step 0, or 1 batch if it fires on step 0).
for step, batch in log_progress(enumerate(train_batches, start=1)):
    batch = process_batch(model, criterion, batch)

    # Scale only the value that is backpropagated, so the accumulated gradient
    # is an average over the window. batch.loss itself stays unscaled for
    # scoring, which keeps train losses comparable with the (unscaled) test
    # losses written to the board below.
    with amp.scale_loss(batch.loss / accum_steps, optimizer) as scaled:
        scaled.backward()

    score = score_batch(batch, ks=())
    train_meter.add(score)

    if every(step, log_steps):
        train_meter.write(train_board)
        train_meter.reset()

    # One optimizer/scheduler update per accumulation window.
    if every(step, accum_steps):
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        if every(step, eval_steps):
            batches = infer_batches(model, criterion, test_batches)
            scores = score_batches(batches)
            test_meter.extend(scores)
            test_meter.write(test_board)
            test_meter.reset()
            # Re-enable training mode (dropout) in case infer_batches left the
            # model in eval mode; a no-op if it already restored it.
            model.train()

    if every(step, save_steps):
        # NOTE(review): under AMP O2 the model weights are presumably fp16, so
        # these checkpoints may be half precision; confirm that is intended.
        dump_model(model.emb, MODEL_EMB)
        dump_model(model.encoder, MODEL_ENCODER)
        dump_model(model.mlm, MODEL_MLM)

        s3.upload(MODEL_EMB, S3_MODEL_EMB)
        s3.upload(MODEL_ENCODER, S3_MODEL_ENCODER)
        s3.upload(MODEL_MLM, S3_MODEL_MLM)

    board.step()