In [1]:
%run main.py
%load_ext autoreload
%autoreload 2

!mkdir -p {DATA_DIR} {BERT_DIR} {MODEL_DIR}
s3 = S3()

In [2]:
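# NOTE: corpus download is currently disabled. The data-loading cell below reads
# NE5 whenever CUSTOM_TUNING is falsy, so uncomment this block if NE5 is not
# already available locally.
#if not exists(NE5):
#    s3.download(S3_NE5, NE5)
#    s3.download(S3_FACTRU, FACTRU)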

In [3]:
# Fetch the pretrained BERT artifacts from S3. Each file is checked on its own,
# so a partially populated directory (e.g. after an interrupted download) is
# completed instead of being skipped because the vocab file happens to exist.
for remote, local in [
    (S3_BERT_VOCAB, BERT_VOCAB),
    (S3_BERT_EMB, BERT_EMB),
    (S3_BERT_ENCODER, BERT_ENCODER),
]:
    if not exists(local):
        s3.download(remote, local)

In [4]:
# Vocabularies shared by the batch encoder, the NER head and the scorer:
# the BERT vocabulary is loaded from the BERT_VOCAB file, the BIO tag
# vocabulary is built from TAGS.
words_vocab = BERTVocab.load(BERT_VOCAB)
tags_vocab = BIOTagsVocab(TAGS)

In [5]:
# Seed the random number generators before the model is constructed in the
# next cell.
torch.manual_seed(SEED)
# NOTE(review): `seed` is presumably random.seed brought in by main.py — confirm.
# NumPy's RNG is not seeded here.
seed(SEED)

In [6]:
# Assemble the tagger from its three parts (embedding, encoder, NER head),
# all sized from the RuBERT config.
config = RuBERTConfig()
emb = BERTEmbedding.from_config(config)
encoder = BERTEncoder.from_config(config)
ner = BERTNERHead(config.emb_dim, len(tags_vocab))
model = BERTNER(emb, encoder, ner)

# Freeze the embedding layer: none of its parameters will receive gradients.
emb.requires_grad_(False)

# Load the pretrained weights for the embedding and encoder.
model.emb.load(BERT_EMB)
# Deliberately left as the cell's last expression so its value is displayed.
model.encoder.load(BERT_ENCODER)
# Moving the model to DEVICE is currently disabled:
# model = model.to(DEVICE)

BERTEncoder(
  (layers): ModuleList(
    (0): TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
      )
      (linear1): Linear(in_features=768, out_features=3072, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=3072, out_features=768, bias=True)
      (norm1): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (norm2): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2): Dropout(p=0.1, inplace=False)
    )
    (1): TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
      )
      (linear1): Linear(in_features=768, out_features=3072, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=3072, out_features=768

In [7]:
# Pick the corpus: custom texts when CUSTOM_TUNING is set, NE5 otherwise.
corpus_path = CUSTOM_TEXTS if CUSTOM_TUNING else NE5
lines = load_gz_lines(corpus_path)
items = log_progress(parse_jl(lines))

# Convert every span-annotated item into a BIO-tagged record over its tokens.
records = []
for item in items:
    markup = SpanMarkup.from_json(item)
    tokens = list(tokenize(markup.text))
    records.append(markup.to_bio(tokens))

# Hold out the first 20% of the records (in file order, no shuffling) for
# testing; the remainder is used for training.
size = round(len(records) * 0.2)
markups = {TEST: records[:size], TRAIN: records[size:]}

0it [00:00, ?it/s]

In [8]:
# Encoder that turns BIO records into training batches using both vocabularies.
encode = BERTNERTrainEncoder(
    words_vocab,
    tags_vocab,
    seq_len=128,
    batch_size=32,
    shuffle_size=10000,
)

# Materialise the batches of each split once, so every epoch iterates over the
# same lists. Moving batches to DEVICE (`_.to(DEVICE)`) is currently disabled.
batches = {}
for split in (TEST, TRAIN):
    batches[split] = list(encode(markups[split]))

In [9]:
# A single board that wraps a TensorBoard writer and a LogBoard.
board = MultiBoard([TensorBoard(BOARD_NAME, RUNS_DIR), LogBoard()])

# One board section per split; the training loop writes each meter to its section.
boards = {
    split: board.section(title)
    for split, title in ((TRAIN, TRAIN_BOARD), (TEST, TEST_BOARD))
}

In [10]:
# Two parameter groups with separate learning rates: BERT_LR for the encoder
# and LR for the NER head. The embedding layer is not handed to the optimizer.
param_groups = [
    {'params': encoder.parameters(), 'lr': BERT_LR},
    {'params': ner.parameters(), 'lr': LR},
]
optimizer = optim.Adam(param_groups)

# Learning rates are multiplied by LR_GAMMA on every scheduler.step(), which
# the training loop calls once per epoch.
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=LR_GAMMA)

In [None]:
# One score meter per split; each is written to its board section and reset
# once per epoch.
meters = {split: NERScoreMeter() for split in (TRAIN, TEST)}

for epoch in log_progress(range(EPOCHS)):
    # ---- training pass ----
    model.train()
    train_meter = meters[TRAIN]
    for batch in log_progress(batches[TRAIN], leave=False):
        optimizer.zero_grad()
        output = process_batch(model, ner.crf, batch)
        output.loss.backward()
        optimizer.step()

        # Only the loss is tracked on the training split.
        train_meter.add(NERBatchScore(output.loss))

    train_meter.write(boards[TRAIN])
    train_meter.reset()

    # ---- evaluation pass (no gradients) ----
    model.eval()
    test_meter = meters[TEST]
    with torch.no_grad():
        for batch in log_progress(batches[TEST], leave=False, desc=TEST):
            output = process_batch(model, ner.crf, batch)
            # Replace the masked target/prediction containers with the forms
            # the scorer consumes: targets via split_masked, predictions via
            # the CRF decoder.
            output.target = split_masked(output.target.value, output.target.mask)
            output.pred = ner.crf.decode(output.pred.value, output.pred.mask)
            test_meter.add(score_ner_batch(output, tags_vocab))

        test_meter.write(boards[TEST])
        test_meter.reset()

    # Advance the learning-rate schedule and the board step once per epoch.
    scheduler.step()
    board.step()

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

[2024-05-23 17:55:45]    0 84.6607 01_train/01_loss


test:   0%|          | 0/2 [00:00<?, ?it/s]

[2024-05-23 17:55:52]    0 22.8041 02_test/01_loss
[2024-05-23 17:55:53]    0 0.7960 02_test/02_ORG
[2024-05-23 17:55:53]    0 0.0000 02_test/03_NUM
[2024-05-23 17:55:53]    0 0.0000 02_test/04_NAME_EMPLOYEE
[2024-05-23 17:55:53]    0 0.8066 02_test/05_LINK
[2024-05-23 17:55:53]    0 0.5000 02_test/06_DATE
[2024-05-23 17:55:53]    0 0.1649 02_test/07_ACRONYM
[2024-05-23 17:55:53]    0 0.0000 02_test/08_MAIL
[2024-05-23 17:55:53]    0 0.9635 02_test/09_TELEPHONE
[2024-05-23 17:55:53]    0 0.7048 02_test/10_TECH
[2024-05-23 17:55:53]    0 0.0000 02_test/11_NAME
[2024-05-23 17:55:53]    0 0.0000 02_test/12_PERCENT


  0%|          | 0/7 [00:00<?, ?it/s]

In [None]:
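# Reference log kept from an earlier run (timestamps 2020-03-31) that reported
# PER/LOC/ORG scores; the tag set differs from the run output shown above.
# [2020-03-31 14:05:40]    0 14.3334 01_train/01_loss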
# [2020-03-31 14:05:43]    0 2.3965 02_test/01_loss
# [2020-03-31 14:05:43]    0 0.9962 02_test/02_PER
# [2020-03-31 14:05:43]    0 0.9807 02_test/03_LOC
# [2020-03-31 14:05:43]    0 0.9691 02_test/04_ORG
# [2020-03-31 14:06:10]    1 1.8448 01_train/01_loss
# [2020-03-31 14:06:13]    1 2.1326 02_test/01_loss
# [2020-03-31 14:06:13]    1 0.9975 02_test/02_PER
# [2020-03-31 14:06:13]    1 0.9862 02_test/03_LOC
# [2020-03-31 14:06:13]    1 0.9710 02_test/04_ORG
# [2020-03-31 14:06:40]    2 1.2753 01_train/01_loss
# [2020-03-31 14:06:43]    2 2.1436 02_test/01_loss
# [2020-03-31 14:06:43]    2 0.9972 02_test/02_PER
# [2020-03-31 14:06:43]    2 0.9867 02_test/03_LOC
# [2020-03-31 14:06:43]    2 0.9705 02_test/04_ORG
# [2020-03-31 14:07:10]    3 1.1283 01_train/01_loss
# [2020-03-31 14:07:13]    3 2.1885 02_test/01_loss
# [2020-03-31 14:07:13]    3 0.9975 02_test/02_PER
# [2020-03-31 14:07:13]    3 0.9867 02_test/03_LOC
# [2020-03-31 14:07:13]    3 0.9719 02_test/04_ORG
# [2020-03-31 14:07:40]    4 1.0464 01_train/01_loss

# [2020-03-31 14:07:43]    4 2.1705 02_test/01_loss
# [2020-03-31 14:07:43]    4 0.9977 02_test/02_PER
# [2020-03-31 14:07:43]    4 0.9862 02_test/03_LOC
# [2020-03-31 14:07:43]    4 0.9722 02_test/04_ORG

In [78]:
# Persist the trained weights: the encoder and the NER head are dumped to
# MODEL_ENCODER and MODEL_NER. These are the two modules whose parameters were
# handed to the optimizer; the embedding layer is not dumped here.
model.encoder.dump(MODEL_ENCODER)
ner.dump(MODEL_NER)

# Upload of the dumped files to S3 is currently disabled; uncomment to publish.
# s3.upload(MODEL_ENCODER, S3_MODEL_ENCODER)
# s3.upload(MODEL_NER, S3_MODEL_NER)