In [1]:
# Run main.py inside the notebook namespace.
# NOTE(review): the names used throughout this notebook (S3, DATA_DIR, ...)
# are not defined here — presumably main.py provides them; confirm.
%run main.py
# Mode 2 of the autoreload extension reloads imported modules before each
# cell execution, so edits to project modules are picked up without a restart.
%load_ext autoreload
%autoreload 2

# Create the working directories; -p makes this a no-op when they exist.
!mkdir -p {DATA_DIR} {BERT_DIR} {MODEL_DIR}
# Client object used by the later cells for downloads and uploads.
s3 = S3()

In [2]:
# Fetch every corpus archive that is missing locally.
# Each file is checked on its own: previously only NE5 was tested, so a
# missing BSNLP or FACTRU archive was never downloaded once NE5 existed.
for remote, local in [
    (S3_NE5, NE5),
    (S3_BSNLP, BSNLP),
    (S3_FACTRU, FACTRU),
]:
    if not exists(local):
        s3.download(remote, local)

In [3]:
# Fetch every pretrained BERT artifact that is missing locally.
# Each file is checked on its own: previously only BERT_VOCAB was tested, so
# a missing embedding or encoder file was never downloaded once the vocab
# file existed.
for remote, local in [
    (S3_BERT_VOCAB, BERT_VOCAB),
    (S3_BERT_EMB, BERT_EMB),
    (S3_BERT_ENCODER, BERT_ENCODER),
]:
    if not exists(local):
        s3.download(remote, local)

In [4]:
# Input vocabulary built from the contents of the BERT vocab file.
words_vocab = BERTVocab(list(load_lines(BERT_VOCAB)))

# Output vocabulary: BIO tags over the PER, LOC and ORG entity types.
tags_vocab = BIOTagsVocab([PER, LOC, ORG])

In [5]:
# Fix the random seeds before the model is built and the data is encoded.
# torch.manual_seed seeds PyTorch's random number generator.
torch.manual_seed(SEED)
# NOTE(review): `seed` is not defined in this notebook — presumably
# random.seed made available by main.py; confirm.
seed(SEED)

In [6]:
# Hyperparameters shared by the embedding and encoder constructors.
config = RuBERTConfig()

emb = BERTEmbedding(
    config.vocab_size,
    config.seq_len,
    config.emb_dim,
    config.dropout,
    config.norm_eps,
)
encoder = BERTEncoder(
    config.layers_num,
    config.emb_dim,
    config.heads_num,
    config.hidden_dim,
    config.dropout,
    config.norm_eps,
)
# The NER head has one output per tag in the BIO tag vocabulary.
ner = BERTNERHead(config.emb_dim, len(tags_vocab))
model = BERTNER(emb, encoder, ner)

# Freeze the embedding layer: its parameters receive no gradient updates.
for frozen in emb.parameters():
    frozen.requires_grad = False

# Load the pretrained weights into the embedding and encoder submodules;
# the NER head keeps its fresh initialisation.
load_model(model.emb, BERT_EMB)
load_model(model.encoder, BERT_ENCODER)

model = model.to(DEVICE)

In [7]:
# Corpora are read in this order, so the leading records come from ne5.
# adding bsnlp makes ORG 1% worse
corpora = [NE5, FACTRU]

records = []
for path in corpora:
    parsed = parse_jl(load_gz_lines(path))
    for item in log_progress(parsed, desc=path):
        # Convert the span annotation into a per-token BIO markup.
        span_markup = SpanMarkup.from_json(item)
        tokens = list(tokenize(span_markup.text))
        records.append(span_markup.to_bio(tokens))

# The first `size` records form the test split, the next `size` the dev
# split, and everything after that is used for training.
size = 100
markups = {
    TEST: records[:size],  # ne5 is better, use it for test
    DEV: records[size:size * 2],
    TRAIN: records[size * 2:],
}

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='data/ne5.jl.gz', max=1.0, style=Progres…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='data/factru.jl.gz', max=1.0, style=Prog…




In [8]:
# Encoder that turns the BIO markups into batches for the model.
encode = BERTNEREncoder(
    words_vocab,
    tags_vocab,
    seq_len=128,
    batch_size=32,
    shuffle_size=10000,
)

# Encode every split once up front and move each batch to the target device.
batches = {
    name: [batch.to(DEVICE) for batch in encode(markups[name])]
    for name in [TEST, DEV, TRAIN]
}

In [9]:
# One board for the run, with a separate section per data split.
board = Board(BOARD_NAME, RUNS_DIR)

sections = [
    (TRAIN, TRAIN_BOARD),
    (DEV, DEV_BOARD),
    (TEST, TEST_BOARD),
]
boards = {split: board.section(title) for split, title in sections}

In [10]:
# Two parameter groups with separate learning rates: BERT_LR for the
# encoder and LR for the NER head. The embedding parameters are not passed
# to the optimizer at all.
param_groups = [
    {'params': encoder.parameters(), 'lr': BERT_LR},
    {'params': ner.parameters(), 'lr': LR},
]
optimizer = optim.Adam(param_groups)

# Multiply every group's learning rate by LR_GAMMA on each scheduler step.
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, LR_GAMMA)

In [11]:
# One score accumulator per data split.
meters = {split: NERScoreMeter() for split in [TRAIN, TEST, DEV]}

for epoch in log_progress(range(EPOCHS)):
    # Training pass: one optimizer update per batch.
    model.train()
    train_meter = meters[TRAIN]
    for inputs in log_progress(batches[TRAIN], leave=False):
        optimizer.zero_grad()
        output = process_batch(model, ner.crf, inputs)
        output.loss.backward()
        optimizer.step()
        train_meter.add(NERBatchScore(output.loss))

    # Flush the accumulated training scores to the board, then start over.
    train_meter.write(boards[TRAIN])
    train_meter.reset()

    # Evaluation pass over the TEST and DEV splits, without gradient tracking.
    model.eval()
    with torch.no_grad():
        for split in [TEST, DEV]:
            meter = meters[split]
            for inputs in log_progress(batches[split], leave=False, desc=split):
                output = process_batch(model, ner.crf, inputs)
                # Replace the target and prediction fields with the split
                # targets and the CRF-decoded predictions before scoring.
                target, pred = output.target, output.pred
                output.target = split_masked(target.value, target.mask)
                output.pred = ner.crf.decode(pred.value, pred.mask)
                meter.add(score_batch(output, tags_vocab))

            meter.write(boards[split])
            meter.reset()

    # Step the learning-rate schedule and the board once per epoch.
    scheduler.step()
    board.step()

HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=84.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='test', max=7.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='dev', max=8.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, max=84.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='test', max=7.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='dev', max=8.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, max=84.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='test', max=7.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='dev', max=8.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, max=84.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='test', max=7.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='dev', max=8.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, max=84.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='test', max=7.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='dev', max=8.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, max=84.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='test', max=7.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='dev', max=8.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, max=84.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='test', max=7.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='dev', max=8.0, style=ProgressStyle(description_width='ini…




In [13]:
# Trained submodules with their local dump path and S3 destination.
# The embedding layer is not dumped here.
artifacts = [
    (model.encoder, MODEL_ENCODER, S3_MODEL_ENCODER),
    (model.ner, MODEL_NER, S3_MODEL_NER),
]

# Write every submodule to disk first, then upload the resulting files.
for module, local_path, remote_path in artifacts:
    dump_model(module, local_path)

for module, local_path, remote_path in artifacts:
    s3.upload(local_path, remote_path)