In [None]:
%run main.py
%load_ext autoreload
%autoreload 2

!mkdir -p {DATA_DIR} {BERT_DIR} {MODEL_DIR}
s3 = S3()

In [None]:
# Fetch the syntax corpora: NEWS is used as the TEST split, FICTION as TRAIN.
# Each file is checked on its own. Otherwise an earlier interrupted run that
# left NEWS present but FICTION missing would never fetch FICTION.
# NOTE(review): this is an existence check only; if s3.download writes the
# target in place, a truncated file from an interrupted transfer would not be
# re-fetched — confirm.
for local_path, s3_key in [(NEWS, S3_NEWS), (FICTION, S3_FICTION)]:
    if not exists(local_path):
        s3.download(s3_key, local_path)

In [None]:
# Fetch the BERT vocabulary, embedding and encoder files. Each file is
# checked independently. Previously only BERT_VOCAB was tested, so a partial
# earlier download (e.g. vocab present, encoder missing) was silently skipped.
# NOTE(review): existence check only; a truncated file left by an interrupted
# transfer would not be re-fetched if s3.download writes in place — confirm.
for local_path, s3_key in [
    (BERT_VOCAB, S3_BERT_VOCAB),
    (BERT_EMB, S3_BERT_EMB),
    (BERT_ENCODER, S3_BERT_ENCODER),
]:
    if not exists(local_path):
        s3.download(s3_key, local_path)

In [None]:
# Load the BERT vocabulary from the local BERT_VOCAB file; it is passed to
# BERTSyntaxTrainEncoder to encode the words of each markup.
words_vocab = BERTVocab.load(BERT_VOCAB)

In [None]:
# Parse each gzipped JSON-lines corpus into SyntaxMarkup records, keyed by
# split name (NEWS -> TEST, FICTION -> TRAIN).
markups = {}
for path, name in [(NEWS, TEST), (FICTION, TRAIN)]:
    items = log_progress(parse_jl(load_gz_lines(path)), desc=path)
    markups[name] = [SyntaxMarkup.from_json(item) for item in items]

# Relation-label vocabulary: PAD first, followed by every distinct token.rel
# seen in either split, sorted so the ordering is deterministic across runs.
rels = {
    token.rel
    for name in [TEST, TRAIN]
    for markup in markups[name]
    for token in markup.tokens
}
rels = [PAD] + sorted(rels)
rels_vocab = Vocab(rels)

In [None]:
# Fix the RNG seeds before the batching and model-construction cells below so
# runs are repeatable. `seed` is presumably random.seed — confirm in main.py.
seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Encoder turning SyntaxMarkup records into model-ready batches.
encode = BERTSyntaxTrainEncoder(
    words_vocab,
    rels_vocab,
    seq_len=128,
    batch_size=16,
    sort_size=10000,
)

# Encode both splits once and move every batch to DEVICE up front, so the
# training loop does no per-step host-to-device transfers.
batches = {}
for name in [TEST, TRAIN]:
    encoded = log_progress(encode(markups[name]), desc=name)
    batches[name] = [batch.to(DEVICE) for batch in encoded]

In [None]:
# Assemble the syntax model from a RuBERT-sized configuration: embedding and
# encoder built from `config`, plus two task modules (BERTSyntaxHead and
# BERTSyntaxRel) whose hidden_dim is half the embedding width.
config = RuBERTConfig()
emb = BERTEmbedding.from_config(config)
encoder = BERTEncoder.from_config(config)
head = BERTSyntaxHead(
    input_dim=config.emb_dim,
    hidden_dim=config.emb_dim // 2,
)
rel = BERTSyntaxRel(
    input_dim=config.emb_dim,
    hidden_dim=config.emb_dim // 2,
    # one output per relation label, including PAD
    rel_dim=len(rels_vocab)
)
model = BERTSyntax(emb, encoder, head, rel)

# Freeze the embedding layer. Its parameters are also left out of the
# optimizer's parameter groups, so only encoder, head and rel are trained.
for param in emb.parameters():
    param.requires_grad = False

# Load the downloaded BERT embedding/encoder files (presumably pretrained
# weights). head and rel are not loaded here, so they start from their
# constructor initialisation.
model.emb.load(BERT_EMB)
model.encoder.load(BERT_ENCODER)
model = model.to(DEVICE)

# Loss function handed to process_batch in the training loop.
criterion = masked_flatten_cross_entropy

In [None]:
# Fan metric writes out to TensorBoard (BOARD_NAME in RUNS_DIR) and LogBoard;
# each split gets its own named section of the combined board.
board = MultiBoard([TensorBoard(BOARD_NAME, RUNS_DIR), LogBoard()])
boards = {
    split: board.section(section_name)
    for split, section_name in [(TRAIN, TRAIN_BOARD), (TEST, TEST_BOARD)]
}

In [None]:
# Two parameter groups with separate learning rates:
#   - the BERT encoder is tuned with BERT_LR;
#   - the head and rel modules are trained with LR.
# The frozen embedding is intentionally absent. Every call to
# scheduler.step() (once per epoch in the training loop) multiplies both
# rates by LR_GAMMA.
optimizer = optim.Adam(
    [
        {'params': encoder.parameters(), 'lr': BERT_LR},
        {'params': chain(head.parameters(), rel.parameters()), 'lr': LR},
    ]
)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=LR_GAMMA)

In [None]:
# One score accumulator per split; each is flushed to its board section and
# reset at the end of every epoch.
meters = {split: SyntaxScoreMeter() for split in (TRAIN, TEST)}

for epoch in log_progress(range(EPOCHS)):
    # --- training pass: one optimizer step per batch, scored after the step
    model.train()
    for batch in log_progress(batches[TRAIN], leave=False):
        optimizer.zero_grad()
        batch = process_batch(model, criterion, batch)
        batch.loss.backward()
        optimizer.step()
        meters[TRAIN].add(score_syntax_batch(batch))

    meters[TRAIN].write(boards[TRAIN])
    meters[TRAIN].reset()

    # --- evaluation pass: no gradient tracking
    model.eval()
    with torch.no_grad():
        for batch in log_progress(batches[TEST], leave=False, desc=TEST):
            batch = process_batch(model, criterion, batch)
            meters[TEST].add(score_syntax_batch(batch))
        meters[TEST].write(boards[TEST])
        meters[TEST].reset()

    # --- end of epoch: decay learning rates, advance the board's step counter
    scheduler.step()
    board.step()

In [None]:
# Log from an earlier run of the training cell above (2 epochs shown).
# Columns: timestamp, epoch, value, section/metric tag (loss, UAS, LAS).
# [2020-04-23 08:34:18]    0 0.4612 01_train/01_loss
# [2020-04-23 08:34:18]    0 0.9047 01_train/02_uas
# [2020-04-23 08:34:18]    0 0.8783 01_train/03_las
# [2020-04-23 08:34:19]    0 0.3214 02_test/01_loss
# [2020-04-23 08:34:19]    0 0.9618 02_test/02_uas
# [2020-04-23 08:34:19]    0 0.9300 02_test/03_las
# [2020-04-23 08:39:07]    1 0.2017 01_train/01_loss
# [2020-04-23 08:39:07]    1 0.9511 01_train/02_uas
# [2020-04-23 08:39:07]    1 0.9348 01_train/03_las
# [2020-04-23 08:39:07]    1 0.3035 02_test/01_loss
# [2020-04-23 08:39:07]    1 0.9635 02_test/02_uas
# [2020-04-23 08:39:07]    1 0.9311 02_test/03_las

In [None]:
# Export the trained encoder, the head and rel modules, and the relation
# vocabulary to local files, then upload each one to S3.
# Off by default so that Restart & Run All does not write or upload anything;
# flip SAVE_MODEL to True to export. Kept as real (guarded) code rather than
# commented-out lines so that it stays syntax-checked.
SAVE_MODEL = False

if SAVE_MODEL:
    model.encoder.dump(MODEL_ENCODER)
    model.head.dump(MODEL_HEAD)
    model.rel.dump(MODEL_REL)
    rels_vocab.dump(RELS_VOCAB)

    s3.upload(MODEL_ENCODER, S3_MODEL_ENCODER)
    s3.upload(MODEL_HEAD, S3_MODEL_HEAD)
    s3.upload(MODEL_REL, S3_MODEL_REL)
    s3.upload(RELS_VOCAB, S3_RELS_VOCAB)