In [5]:
# Notebook setup. `%run` executes main.py in this namespace; it is presumably
# where DATA_DIR, S3, BERTVocab and the other names used below come from.
%run main.py
# Enable IPython autoreload so edits to imported modules are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Create the local working directories (shell escape; `-p` tolerates
# directories that already exist).
!mkdir -p {DATA_DIR} {RUBERT_DIR} {MODEL_DIR}
# S3 client: used below to download data/pretrained weights and to upload
# training checkpoints.
s3 = S3()

In [7]:
# Fetch the test/train corpora from S3, skipping files that already exist
# locally. Each file is checked on its own: the previous version only checked
# TEST, so a partial earlier download (TEST present, TRAIN missing) was never
# completed.
for s3_path, local_path in [(S3_TEST, TEST), (S3_TRAIN, TRAIN)]:
    if not exists(local_path):
        s3.download(s3_path, local_path)

In [None]:
# Fetch the pretrained RuBERT artifacts (vocab, embedding, encoder, MLM head)
# from S3, skipping any that already exist locally. Every artifact is checked
# individually: the previous version only checked the vocab file, so an
# interrupted download left the weight files missing permanently.
for s3_path, local_path in [
    (S3_RUBERT_VOCAB, RUBERT_VOCAB),
    (S3_RUBERT_EMB, RUBERT_EMB),
    (S3_RUBERT_ENCODER, RUBERT_ENCODER),
    (S3_RUBERT_MLM, RUBERT_MLM),
]:
    if not exists(local_path):
        s3.download(s3_path, local_path)

In [None]:
# Load the RuBERT vocabulary from the local file; the MLM batch encoder
# below is built on it.
vocab = BERTVocab.load(RUBERT_VOCAB)

In [None]:
# Build the BERT masked-LM model from the RuBERT config and load the
# pretrained weights into each part.
config = RuBERTConfig()
# Pass the same config to every part so the embedding/encoder dimensions
# agree with the head, which is built from `config` explicitly. Previously
# from_config() was called with no argument, ignoring `config`.
emb = BERTEmbedding.from_config(config)
encoder = BERTEncoder.from_config(config)
head = BERTMLMHead(config.emb_dim, config.vocab_size)
model = BERTMLM(emb, encoder, head)

# Freeze the position embeddings: training below uses short sequences
# (seq_len=128), presumably so the pretrained position table is not skewed
# toward the first positions.
emb.position.weight.requires_grad = False

model.emb.load(RUBERT_EMB)
model.encoder.load(RUBERT_ENCODER)
model.head.load(RUBERT_MLM)
model = model.to(DEVICE)

# Loss function passed to process_batch / infer_batches; presumably a
# cross-entropy restricted to the masked positions (defined in main.py).
criterion = masked_flatten_cross_entropy

In [None]:
# Fix RNG seeds before building the batch stream. `seed` is presumably
# random.seed imported by main.py; it likely drives shuffling/masking
# in the encoder — confirm.
torch.manual_seed(1)
seed(1)

In [None]:
# Encoder turning raw text lines into MLM batches: 128-token sequences,
# 32 sequences per batch, shuffle_size=10000 (presumably a shuffle buffer).
encode = BERTMLMTrainEncoder(
    vocab,
    seq_len=128, batch_size=32, shuffle_size=10000,
)

# Held-out set: materialized eagerly as a list on DEVICE so the same batches
# can be re-scored at every evaluation step.
lines = load_lines(TEST)
batches = encode(lines)
test_batches = list(map(lambda test_batch: test_batch.to(DEVICE), batches))

# Training set: a lazy, single-pass generator; each batch is moved to DEVICE
# only when the training loop pulls it.
lines = load_lines(TRAIN)
batches = encode(lines)
train_batches = (train_batch.to(DEVICE) for train_batch in batches)

In [None]:
# TensorBoard writer under RUNS_DIR, with separate sections for train and
# test metrics (the meters in the training loop write into these);
# board.step() is called once per batch in that loop.
board = TensorBoard(BOARD_NAME, RUNS_DIR)
train_board = board.section(TRAIN_BOARD)
test_board = board.section(TEST_BOARD)

In [None]:
# Adam over all model parameters; the frozen position embeddings
# (requires_grad=False) receive no gradients.
optimizer = optim.Adam(model.parameters(), lr=0.0001)
# Mixed precision via apex amp. NOTE(review): `O2` is a bare name here; apex
# expects the string 'O2', so this presumably relies on main.py defining
# O2 = 'O2' — confirm.
model, optimizer = amp.initialize(model, optimizer, opt_level=O2)
# Multiply the LR by 0.999 on every scheduler.step(); the training loop calls
# it once per optimizer step (i.e. once per accumulated update).
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.999)

In [None]:
# Masked-LM training loop with gradient accumulation, periodic logging,
# evaluation on the held-out batches, and checkpoint upload to S3.
train_meter = MLMScoreMeter()
test_meter = MLMScoreMeter()

accum_steps = 64  # 64 batches x 32 sequences = 2048 sequences per update
log_steps = 256
eval_steps = 512
save_steps = eval_steps * 10

model.train()
optimizer.zero_grad()

for step, batch in log_progress(enumerate(train_batches)):
    batch = process_batch(model, criterion, batch)

    # Scale only the loss we backpropagate, so the accumulated gradients
    # average over accum_steps batches. batch.loss itself stays unscaled, so
    # the train scores reported from it are on the same scale as the test
    # scores; previously `batch.loss /= accum_steps` made the logged train
    # loss 64x smaller.
    step_loss = batch.loss / accum_steps
    with amp.scale_loss(step_loss, optimizer) as scaled:
        scaled.backward()

    score = score_mlm_batch(batch, ks=())
    train_meter.add(score)

    if every(step, log_steps):
        train_meter.write(train_board)
        train_meter.reset()

    if every(step, accum_steps):
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        # Evaluate right after a parameter update (eval_steps is a multiple
        # of accum_steps).
        if every(step, eval_steps):
            batches = infer_batches(model, criterion, test_batches)
            scores = score_mlm_batches(batches)
            test_meter.extend(scores)
            test_meter.write(test_board)
            test_meter.reset()

    if every(step, save_steps):
        model.emb.dump(MODEL_EMB)
        model.encoder.dump(MODEL_ENCODER)
        # The MLM head is `model.head` (BERTMLM(emb, encoder, head), loaded
        # via model.head.load); the previous `model.mlm` raised
        # AttributeError at the first checkpoint.
        model.head.dump(MODEL_MLM)

        s3.upload(MODEL_EMB, S3_MODEL_EMB)
        s3.upload(MODEL_ENCODER, S3_MODEL_ENCODER)
        s3.upload(MODEL_MLM, S3_MODEL_MLM)

    board.step()